bevy_render/render_resource/shader.rs

1use super::ShaderDefVal;
2use crate::define_atomic_id;
3use alloc::borrow::Cow;
4use bevy_asset::{io::Reader, Asset, AssetLoader, AssetPath, Handle, LoadContext};
5use bevy_reflect::TypePath;
6use core::marker::Copy;
7use derive_more::derive::{Display, Error, From};
8
// Declares the `ShaderId` type with this crate's `define_atomic_id!` macro (defined elsewhere).
// Presumably a unique identifier allocated from an atomic counter; see the macro to confirm.
define_atomic_id!(ShaderId);
10
/// Errors reported while parsing or validating shader code with `naga`.
///
/// The GLSL and SPIR-V variants only exist when the matching shader-format feature is enabled.
#[derive(Error, Display, Debug, From)]
pub enum ShaderReflectError {
    /// The WGSL front end failed to parse the source.
    WgslParse(naga::front::wgsl::ParseError),
    /// The GLSL front end failed to parse the source; every reported error is kept.
    #[cfg(feature = "shader_format_glsl")]
    #[display("GLSL Parse Error: {_0:?}")]
    // A `Vec` does not implement `Error`, so derive_more must not use it as the error source.
    #[error(ignore)]
    GlslParse(Vec<naga::front::glsl::Error>),
    /// The SPIR-V front end failed to parse the binary.
    #[cfg(feature = "shader_format_spirv")]
    SpirVParse(naga::front::spv::Error),
    /// The parsed module failed `naga` validation.
    Validation(naga::WithSpan<naga::valid::ValidationError>),
}
/// A shader, as defined by its [`ShaderSource`](wgpu::ShaderSource) and [`ShaderStage`](naga::ShaderStage).
///
/// This is an "unprocessed" shader: it can still contain preprocessor directives.
#[derive(Asset, TypePath, Debug, Clone)]
pub struct Shader {
    /// The path this shader was created from (for loaded assets, derived from the asset path).
    pub path: String,
    /// The shader code, tagged with its language.
    pub source: Source,
    /// The name other shaders import this one by: a custom import path if the source declares
    /// one or [`Shader::set_import_path`] was called, otherwise the asset `path`.
    pub import_path: ShaderImport,
    /// The imports referenced by this shader's source.
    pub imports: Vec<ShaderImport>,
    /// Extra imports not specified in the source string.
    pub additional_imports: Vec<naga_oil::compose::ImportDefinition>,
    /// Any shader defs that will be included when this module is used.
    pub shader_defs: Vec<ShaderDefVal>,
    /// Strong handles to this shader's file dependencies. They are stored so the dependencies
    /// are not immediately dropped if this shader is their only user.
    pub file_dependencies: Vec<Handle<Shader>>,
}
38
39impl Shader {
40    fn preprocess(source: &str, path: &str) -> (ShaderImport, Vec<ShaderImport>) {
41        let (import_path, imports, _) = naga_oil::compose::get_preprocessor_data(source);
42
43        let import_path = import_path
44            .map(ShaderImport::Custom)
45            .unwrap_or_else(|| ShaderImport::AssetPath(path.to_owned()));
46
47        let imports = imports
48            .into_iter()
49            .map(|import| {
50                if import.import.starts_with('\"') {
51                    let import = import
52                        .import
53                        .chars()
54                        .skip(1)
55                        .take_while(|c| *c != '\"')
56                        .collect();
57                    ShaderImport::AssetPath(import)
58                } else {
59                    ShaderImport::Custom(import.import)
60                }
61            })
62            .collect();
63
64        (import_path, imports)
65    }
66
67    pub fn from_wgsl(source: impl Into<Cow<'static, str>>, path: impl Into<String>) -> Shader {
68        let source = source.into();
69        let path = path.into();
70        let (import_path, imports) = Shader::preprocess(&source, &path);
71        Shader {
72            path,
73            imports,
74            import_path,
75            source: Source::Wgsl(source),
76            additional_imports: Default::default(),
77            shader_defs: Default::default(),
78            file_dependencies: Default::default(),
79        }
80    }
81
82    pub fn from_wgsl_with_defs(
83        source: impl Into<Cow<'static, str>>,
84        path: impl Into<String>,
85        shader_defs: Vec<ShaderDefVal>,
86    ) -> Shader {
87        Self {
88            shader_defs,
89            ..Self::from_wgsl(source, path)
90        }
91    }
92
93    pub fn from_glsl(
94        source: impl Into<Cow<'static, str>>,
95        stage: naga::ShaderStage,
96        path: impl Into<String>,
97    ) -> Shader {
98        let source = source.into();
99        let path = path.into();
100        let (import_path, imports) = Shader::preprocess(&source, &path);
101        Shader {
102            path,
103            imports,
104            import_path,
105            source: Source::Glsl(source, stage),
106            additional_imports: Default::default(),
107            shader_defs: Default::default(),
108            file_dependencies: Default::default(),
109        }
110    }
111
112    pub fn from_spirv(source: impl Into<Cow<'static, [u8]>>, path: impl Into<String>) -> Shader {
113        let path = path.into();
114        Shader {
115            path: path.clone(),
116            imports: Vec::new(),
117            import_path: ShaderImport::AssetPath(path),
118            source: Source::SpirV(source.into()),
119            additional_imports: Default::default(),
120            shader_defs: Default::default(),
121            file_dependencies: Default::default(),
122        }
123    }
124
125    pub fn set_import_path<P: Into<String>>(&mut self, import_path: P) {
126        self.import_path = ShaderImport::Custom(import_path.into());
127    }
128
129    #[must_use]
130    pub fn with_import_path<P: Into<String>>(mut self, import_path: P) -> Self {
131        self.set_import_path(import_path);
132        self
133    }
134
135    #[inline]
136    pub fn import_path(&self) -> &ShaderImport {
137        &self.import_path
138    }
139
140    pub fn imports(&self) -> impl ExactSizeIterator<Item = &ShaderImport> {
141        self.imports.iter()
142    }
143}
144
145impl<'a> From<&'a Shader> for naga_oil::compose::ComposableModuleDescriptor<'a> {
146    fn from(shader: &'a Shader) -> Self {
147        let shader_defs = shader
148            .shader_defs
149            .iter()
150            .map(|def| match def {
151                ShaderDefVal::Bool(name, b) => {
152                    (name.clone(), naga_oil::compose::ShaderDefValue::Bool(*b))
153                }
154                ShaderDefVal::Int(name, i) => {
155                    (name.clone(), naga_oil::compose::ShaderDefValue::Int(*i))
156                }
157                ShaderDefVal::UInt(name, i) => {
158                    (name.clone(), naga_oil::compose::ShaderDefValue::UInt(*i))
159                }
160            })
161            .collect();
162
163        let as_name = match &shader.import_path {
164            ShaderImport::AssetPath(asset_path) => Some(format!("\"{asset_path}\"")),
165            ShaderImport::Custom(_) => None,
166        };
167
168        naga_oil::compose::ComposableModuleDescriptor {
169            source: shader.source.as_str(),
170            file_path: &shader.path,
171            language: (&shader.source).into(),
172            additional_imports: &shader.additional_imports,
173            shader_defs,
174            as_name,
175        }
176    }
177}
178
179impl<'a> From<&'a Shader> for naga_oil::compose::NagaModuleDescriptor<'a> {
180    fn from(shader: &'a Shader) -> Self {
181        naga_oil::compose::NagaModuleDescriptor {
182            source: shader.source.as_str(),
183            file_path: &shader.path,
184            shader_type: (&shader.source).into(),
185            ..Default::default()
186        }
187    }
188}
189
/// The raw code of a [`Shader`], tagged with its language.
#[derive(Debug, Clone)]
pub enum Source {
    /// WGSL source text.
    Wgsl(Cow<'static, str>),
    /// GLSL source text, together with the pipeline stage it is written for.
    Glsl(Cow<'static, str>, naga::ShaderStage),
    /// A SPIR-V binary.
    SpirV(Cow<'static, [u8]>),
    // TODO: consider the following
    // PrecompiledSpirVMacros(HashMap<HashSet<String>, Vec<u32>>)
    // NagaModule(Module) ... Module impls Serialize/Deserialize
}
199
200impl Source {
201    pub fn as_str(&self) -> &str {
202        match self {
203            Source::Wgsl(s) | Source::Glsl(s, _) => s,
204            Source::SpirV(_) => panic!("spirv not yet implemented"),
205        }
206    }
207}
208
209impl From<&Source> for naga_oil::compose::ShaderLanguage {
210    fn from(value: &Source) -> Self {
211        match value {
212            Source::Wgsl(_) => naga_oil::compose::ShaderLanguage::Wgsl,
213            #[cfg(any(feature = "shader_format_glsl", target_arch = "wasm32"))]
214            Source::Glsl(_, _) => naga_oil::compose::ShaderLanguage::Glsl,
215            #[cfg(all(not(feature = "shader_format_glsl"), not(target_arch = "wasm32")))]
216            Source::Glsl(_, _) => panic!(
217                "GLSL is not supported in this configuration; use the feature `shader_format_glsl`"
218            ),
219            Source::SpirV(_) => panic!("spirv not yet implemented"),
220        }
221    }
222}
223
224impl From<&Source> for naga_oil::compose::ShaderType {
225    fn from(value: &Source) -> Self {
226        match value {
227            Source::Wgsl(_) => naga_oil::compose::ShaderType::Wgsl,
228            #[cfg(any(feature = "shader_format_glsl", target_arch = "wasm32"))]
229            Source::Glsl(_, shader_stage) => match shader_stage {
230                naga::ShaderStage::Vertex => naga_oil::compose::ShaderType::GlslVertex,
231                naga::ShaderStage::Fragment => naga_oil::compose::ShaderType::GlslFragment,
232                naga::ShaderStage::Compute => panic!("glsl compute not yet implemented"),
233            },
234            #[cfg(all(not(feature = "shader_format_glsl"), not(target_arch = "wasm32")))]
235            Source::Glsl(_, _) => panic!(
236                "GLSL is not supported in this configuration; use the feature `shader_format_glsl`"
237            ),
238            Source::SpirV(_) => panic!("spirv not yet implemented"),
239        }
240    }
241}
242
/// [`AssetLoader`] for [`Shader`] assets stored as `.spv`, `.wgsl`, `.vert`, `.frag` or
/// `.comp` files.
#[derive(Default)]
pub struct ShaderLoader;
245
/// Errors returned by [`ShaderLoader`].
#[non_exhaustive]
#[derive(Debug, Error, Display, From)]
pub enum ShaderLoaderError {
    /// An I/O error occurred while loading the shader.
    #[display("Could not load shader: {_0}")]
    Io(std::io::Error),
    /// A text shader (WGSL or GLSL) was not valid UTF-8.
    #[display("Could not parse shader: {_0}")]
    Parse(alloc::string::FromUtf8Error),
}
254
255impl AssetLoader for ShaderLoader {
256    type Asset = Shader;
257    type Settings = ();
258    type Error = ShaderLoaderError;
259    async fn load(
260        &self,
261        reader: &mut dyn Reader,
262        _settings: &Self::Settings,
263        load_context: &mut LoadContext<'_>,
264    ) -> Result<Shader, Self::Error> {
265        let ext = load_context.path().extension().unwrap().to_str().unwrap();
266        let path = load_context.asset_path().to_string();
267        // On windows, the path will inconsistently use \ or /.
268        // TODO: remove this once AssetPath forces cross-platform "slash" consistency. See #10511
269        let path = path.replace(std::path::MAIN_SEPARATOR, "/");
270        let mut bytes = Vec::new();
271        reader.read_to_end(&mut bytes).await?;
272        let mut shader = match ext {
273            "spv" => Shader::from_spirv(bytes, load_context.path().to_string_lossy()),
274            "wgsl" => Shader::from_wgsl(String::from_utf8(bytes)?, path),
275            "vert" => Shader::from_glsl(String::from_utf8(bytes)?, naga::ShaderStage::Vertex, path),
276            "frag" => {
277                Shader::from_glsl(String::from_utf8(bytes)?, naga::ShaderStage::Fragment, path)
278            }
279            "comp" => {
280                Shader::from_glsl(String::from_utf8(bytes)?, naga::ShaderStage::Compute, path)
281            }
282            _ => panic!("unhandled extension: {ext}"),
283        };
284
285        // collect and store file dependencies
286        for import in &shader.imports {
287            if let ShaderImport::AssetPath(asset_path) = import {
288                shader.file_dependencies.push(load_context.load(asset_path));
289            }
290        }
291        Ok(shader)
292    }
293
294    fn extensions(&self) -> &[&str] {
295        &["spv", "wgsl", "vert", "frag", "comp"]
296    }
297}
298
/// Identifies a shader module for import resolution.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub enum ShaderImport {
    /// An asset path; written as a quoted string in import directives
    /// (e.g. `"path/to/shader.wgsl"`).
    AssetPath(String),
    /// A custom module name, e.g. one declared by the shader source or set with
    /// [`Shader::set_import_path`].
    Custom(String),
}
304
305impl ShaderImport {
306    pub fn module_name(&self) -> Cow<'_, String> {
307        match self {
308            ShaderImport::AssetPath(s) => Cow::Owned(format!("\"{s}\"")),
309            ShaderImport::Custom(s) => Cow::Borrowed(s),
310        }
311    }
312}
313
/// A reference to a shader asset.
///
/// Can be created through the `From` conversions from a [`Handle`] to a [`Shader`], an
/// [`AssetPath`], or a `&'static str` path.
pub enum ShaderRef {
    /// Use the "default" shader for the current context.
    Default,
    /// A handle to a shader stored in the [`Assets<Shader>`](bevy_asset::Assets) resource.
    Handle(Handle<Shader>),
    /// An asset path leading to a shader.
    Path(AssetPath<'static>),
}
323
324impl From<Handle<Shader>> for ShaderRef {
325    fn from(handle: Handle<Shader>) -> Self {
326        Self::Handle(handle)
327    }
328}
329
330impl From<AssetPath<'static>> for ShaderRef {
331    fn from(path: AssetPath<'static>) -> Self {
332        Self::Path(path)
333    }
334}
335
336impl From<&'static str> for ShaderRef {
337    fn from(path: &'static str) -> Self {
338        Self::Path(AssetPath::from(path))
339    }
340}