bevy_render/render_resource/
shader.rs

1use super::ShaderDefVal;
2use crate::define_atomic_id;
3use alloc::borrow::Cow;
4use bevy_asset::{io::Reader, Asset, AssetLoader, AssetPath, Handle, LoadContext};
5use bevy_reflect::TypePath;
6use core::marker::Copy;
7use thiserror::Error;
8
// Declares the `ShaderId` identifier type through the crate-local `define_atomic_id!` macro.
// NOTE(review): the macro body is not visible in this file; see its definition for the
// exact API that gets generated.
define_atomic_id!(ShaderId);
10
/// Errors reported by naga while parsing a shader source or validating a shader module.
#[derive(Error, Debug)]
pub enum ShaderReflectError {
    /// The WGSL front end failed to parse the source.
    #[error(transparent)]
    WgslParse(#[from] naga::front::wgsl::ParseError),
    /// The GLSL front end reported one or more errors.
    ///
    /// Only present when the `shader_format_glsl` feature is enabled.
    #[cfg(feature = "shader_format_glsl")]
    #[error("GLSL Parse Error: {0:?}")]
    GlslParse(Vec<naga::front::glsl::Error>),
    /// The SPIR-V front end failed to parse the binary.
    ///
    /// Only present when the `shader_format_spirv` feature is enabled.
    #[cfg(feature = "shader_format_spirv")]
    #[error(transparent)]
    SpirVParse(#[from] naga::front::spv::Error),
    /// naga's validator rejected the module; the error carries span information.
    #[error(transparent)]
    Validation(#[from] naga::WithSpan<naga::valid::ValidationError>),
}
24
/// Describes whether or not to perform runtime checks on shaders.
/// Runtime checks can be enabled for safety at the cost of speed.
/// By default no runtime checks will be performed.
///
/// # Panics
/// Because no runtime checks are performed for spirv,
/// enabling `ValidateShader` for spirv will cause a panic.
#[derive(Clone, Debug, Default)]
pub enum ValidateShader {
    #[default]
    /// No runtime checks for soundness (e.g. bound checking) are performed.
    ///
    /// This is suitable for trusted shaders, written by your program or dependencies you trust.
    Disabled,
    /// Enables runtime checks for soundness (e.g. bound checking).
    ///
    /// While this can have a meaningful impact on performance,
    /// this setting should *always* be enabled when loading untrusted shaders.
    /// This might occur if you are creating a shader playground, running user-generated shaders
    /// (as in `VRChat`), or writing a web browser in Bevy.
    Enabled,
}
47
/// A shader, as defined by its [`ShaderSource`](wgpu::ShaderSource) and [`ShaderStage`](naga::ShaderStage).
/// This is an "unprocessed" shader. It can contain preprocessor directives.
#[derive(Asset, TypePath, Debug, Clone)]
pub struct Shader {
    /// The path this shader was created with. It is passed to `naga_oil` as the module's
    /// file path when a module descriptor is built from this shader.
    pub path: String,
    /// The shader source, in one of the formats described by [`Source`].
    pub source: Source,
    /// The import path of this shader: either a custom name or an asset path.
    pub import_path: ShaderImport,
    /// The imports of this shader, as collected from the source text by the constructors.
    pub imports: Vec<ShaderImport>,
    /// Extra imports not specified in the source string.
    pub additional_imports: Vec<naga_oil::compose::ImportDefinition>,
    /// Any shader defs that will be included when this module is used.
    pub shader_defs: Vec<ShaderDefVal>,
    /// We must store strong handles to our dependencies to stop them
    /// from being immediately dropped if we are the only user.
    pub file_dependencies: Vec<Handle<Shader>>,
    /// Enable or disable runtime shader validation, trading safety against speed.
    ///
    /// Please read the [`ValidateShader`] docs for a discussion of the tradeoffs involved.
    pub validate_shader: ValidateShader,
}
68
69impl Shader {
70    fn preprocess(source: &str, path: &str) -> (ShaderImport, Vec<ShaderImport>) {
71        let (import_path, imports, _) = naga_oil::compose::get_preprocessor_data(source);
72
73        let import_path = import_path
74            .map(ShaderImport::Custom)
75            .unwrap_or_else(|| ShaderImport::AssetPath(path.to_owned()));
76
77        let imports = imports
78            .into_iter()
79            .map(|import| {
80                if import.import.starts_with('\"') {
81                    let import = import
82                        .import
83                        .chars()
84                        .skip(1)
85                        .take_while(|c| *c != '\"')
86                        .collect();
87                    ShaderImport::AssetPath(import)
88                } else {
89                    ShaderImport::Custom(import.import)
90                }
91            })
92            .collect();
93
94        (import_path, imports)
95    }
96
97    pub fn from_wgsl(source: impl Into<Cow<'static, str>>, path: impl Into<String>) -> Shader {
98        let source = source.into();
99        let path = path.into();
100        let (import_path, imports) = Shader::preprocess(&source, &path);
101        Shader {
102            path,
103            imports,
104            import_path,
105            source: Source::Wgsl(source),
106            additional_imports: Default::default(),
107            shader_defs: Default::default(),
108            file_dependencies: Default::default(),
109            validate_shader: ValidateShader::Disabled,
110        }
111    }
112
113    pub fn from_wgsl_with_defs(
114        source: impl Into<Cow<'static, str>>,
115        path: impl Into<String>,
116        shader_defs: Vec<ShaderDefVal>,
117    ) -> Shader {
118        Self {
119            shader_defs,
120            ..Self::from_wgsl(source, path)
121        }
122    }
123
124    pub fn from_glsl(
125        source: impl Into<Cow<'static, str>>,
126        stage: naga::ShaderStage,
127        path: impl Into<String>,
128    ) -> Shader {
129        let source = source.into();
130        let path = path.into();
131        let (import_path, imports) = Shader::preprocess(&source, &path);
132        Shader {
133            path,
134            imports,
135            import_path,
136            source: Source::Glsl(source, stage),
137            additional_imports: Default::default(),
138            shader_defs: Default::default(),
139            file_dependencies: Default::default(),
140            validate_shader: ValidateShader::Disabled,
141        }
142    }
143
144    pub fn from_spirv(source: impl Into<Cow<'static, [u8]>>, path: impl Into<String>) -> Shader {
145        let path = path.into();
146        Shader {
147            path: path.clone(),
148            imports: Vec::new(),
149            import_path: ShaderImport::AssetPath(path),
150            source: Source::SpirV(source.into()),
151            additional_imports: Default::default(),
152            shader_defs: Default::default(),
153            file_dependencies: Default::default(),
154            validate_shader: ValidateShader::Disabled,
155        }
156    }
157
158    #[cfg(feature = "shader_format_wesl")]
159    pub fn from_wesl(source: impl Into<Cow<'static, str>>, path: impl Into<String>) -> Shader {
160        let source = source.into();
161        let path = path.into();
162        let (import_path, imports) = Shader::preprocess(&source, &path);
163
164        match import_path {
165            ShaderImport::AssetPath(asset_path) => {
166                // Create the shader import path - always starting with "/"
167                let shader_path = std::path::Path::new("/").join(&asset_path);
168
169                // Convert to a string with forward slashes and without extension
170                let import_path_str = shader_path
171                    .with_extension("")
172                    .to_string_lossy()
173                    .replace('\\', "/");
174
175                let import_path = ShaderImport::AssetPath(import_path_str.to_string());
176
177                Shader {
178                    path,
179                    imports,
180                    import_path,
181                    source: Source::Wesl(source),
182                    additional_imports: Default::default(),
183                    shader_defs: Default::default(),
184                    file_dependencies: Default::default(),
185                    validate_shader: ValidateShader::Disabled,
186                }
187            }
188            ShaderImport::Custom(_) => {
189                panic!("Wesl shaders must be imported from an asset path");
190            }
191        }
192    }
193
194    pub fn set_import_path<P: Into<String>>(&mut self, import_path: P) {
195        self.import_path = ShaderImport::Custom(import_path.into());
196    }
197
198    #[must_use]
199    pub fn with_import_path<P: Into<String>>(mut self, import_path: P) -> Self {
200        self.set_import_path(import_path);
201        self
202    }
203
204    #[inline]
205    pub fn import_path(&self) -> &ShaderImport {
206        &self.import_path
207    }
208
209    pub fn imports(&self) -> impl ExactSizeIterator<Item = &ShaderImport> {
210        self.imports.iter()
211    }
212}
213
214impl<'a> From<&'a Shader> for naga_oil::compose::ComposableModuleDescriptor<'a> {
215    fn from(shader: &'a Shader) -> Self {
216        let shader_defs = shader
217            .shader_defs
218            .iter()
219            .map(|def| match def {
220                ShaderDefVal::Bool(name, b) => {
221                    (name.clone(), naga_oil::compose::ShaderDefValue::Bool(*b))
222                }
223                ShaderDefVal::Int(name, i) => {
224                    (name.clone(), naga_oil::compose::ShaderDefValue::Int(*i))
225                }
226                ShaderDefVal::UInt(name, i) => {
227                    (name.clone(), naga_oil::compose::ShaderDefValue::UInt(*i))
228                }
229            })
230            .collect();
231
232        let as_name = match &shader.import_path {
233            ShaderImport::AssetPath(asset_path) => Some(format!("\"{asset_path}\"")),
234            ShaderImport::Custom(_) => None,
235        };
236
237        naga_oil::compose::ComposableModuleDescriptor {
238            source: shader.source.as_str(),
239            file_path: &shader.path,
240            language: (&shader.source).into(),
241            additional_imports: &shader.additional_imports,
242            shader_defs,
243            as_name,
244        }
245    }
246}
247
248impl<'a> From<&'a Shader> for naga_oil::compose::NagaModuleDescriptor<'a> {
249    fn from(shader: &'a Shader) -> Self {
250        naga_oil::compose::NagaModuleDescriptor {
251            source: shader.source.as_str(),
252            file_path: &shader.path,
253            shader_type: (&shader.source).into(),
254            ..Default::default()
255        }
256    }
257}
258
/// The source of a [`Shader`], tagged with the format it is written in.
#[derive(Debug, Clone)]
pub enum Source {
    /// WGSL source text.
    Wgsl(Cow<'static, str>),
    /// WESL source text.
    Wesl(Cow<'static, str>),
    /// GLSL source text, together with the shader stage it is written for.
    Glsl(Cow<'static, str>, naga::ShaderStage),
    /// SPIR-V binary data.
    SpirV(Cow<'static, [u8]>),
    // TODO: consider the following
    // PrecompiledSpirVMacros(HashMap<HashSet<String>, Vec<u32>>)
    // NagaModule(Module) ... Module impls Serialize/Deserialize
}
269
270impl Source {
271    pub fn as_str(&self) -> &str {
272        match self {
273            Source::Wgsl(s) | Source::Wesl(s) | Source::Glsl(s, _) => s,
274            Source::SpirV(_) => panic!("spirv not yet implemented"),
275        }
276    }
277}
278
279impl From<&Source> for naga_oil::compose::ShaderLanguage {
280    fn from(value: &Source) -> Self {
281        match value {
282            Source::Wgsl(_) => naga_oil::compose::ShaderLanguage::Wgsl,
283            #[cfg(any(feature = "shader_format_glsl", target_arch = "wasm32"))]
284            Source::Glsl(_, _) => naga_oil::compose::ShaderLanguage::Glsl,
285            #[cfg(all(not(feature = "shader_format_glsl"), not(target_arch = "wasm32")))]
286            Source::Glsl(_, _) => panic!(
287                "GLSL is not supported in this configuration; use the feature `shader_format_glsl`"
288            ),
289            Source::SpirV(_) => panic!("spirv not yet implemented"),
290            Source::Wesl(_) => panic!("wesl not yet implemented"),
291        }
292    }
293}
294
295impl From<&Source> for naga_oil::compose::ShaderType {
296    fn from(value: &Source) -> Self {
297        match value {
298            Source::Wgsl(_) => naga_oil::compose::ShaderType::Wgsl,
299            #[cfg(any(feature = "shader_format_glsl", target_arch = "wasm32"))]
300            Source::Glsl(_, shader_stage) => match shader_stage {
301                naga::ShaderStage::Vertex => naga_oil::compose::ShaderType::GlslVertex,
302                naga::ShaderStage::Fragment => naga_oil::compose::ShaderType::GlslFragment,
303                naga::ShaderStage::Compute => panic!("glsl compute not yet implemented"),
304            },
305            #[cfg(all(not(feature = "shader_format_glsl"), not(target_arch = "wasm32")))]
306            Source::Glsl(_, _) => panic!(
307                "GLSL is not supported in this configuration; use the feature `shader_format_glsl`"
308            ),
309            Source::SpirV(_) => panic!("spirv not yet implemented"),
310            Source::Wesl(_) => panic!("wesl not yet implemented"),
311        }
312    }
313}
314
315#[derive(Default)]
316pub struct ShaderLoader;
317
318#[non_exhaustive]
319#[derive(Debug, Error)]
320pub enum ShaderLoaderError {
321    #[error("Could not load shader: {0}")]
322    Io(#[from] std::io::Error),
323    #[error("Could not parse shader: {0}")]
324    Parse(#[from] alloc::string::FromUtf8Error),
325}
326
327impl AssetLoader for ShaderLoader {
328    type Asset = Shader;
329    type Settings = ();
330    type Error = ShaderLoaderError;
331    async fn load(
332        &self,
333        reader: &mut dyn Reader,
334        _settings: &Self::Settings,
335        load_context: &mut LoadContext<'_>,
336    ) -> Result<Shader, Self::Error> {
337        let ext = load_context.path().extension().unwrap().to_str().unwrap();
338        let path = load_context.asset_path().to_string();
339        // On windows, the path will inconsistently use \ or /.
340        // TODO: remove this once AssetPath forces cross-platform "slash" consistency. See #10511
341        let path = path.replace(std::path::MAIN_SEPARATOR, "/");
342        let mut bytes = Vec::new();
343        reader.read_to_end(&mut bytes).await?;
344        let mut shader = match ext {
345            "spv" => Shader::from_spirv(bytes, load_context.path().to_string_lossy()),
346            "wgsl" => Shader::from_wgsl(String::from_utf8(bytes)?, path),
347            "vert" => Shader::from_glsl(String::from_utf8(bytes)?, naga::ShaderStage::Vertex, path),
348            "frag" => {
349                Shader::from_glsl(String::from_utf8(bytes)?, naga::ShaderStage::Fragment, path)
350            }
351            "comp" => {
352                Shader::from_glsl(String::from_utf8(bytes)?, naga::ShaderStage::Compute, path)
353            }
354            #[cfg(feature = "shader_format_wesl")]
355            "wesl" => Shader::from_wesl(String::from_utf8(bytes)?, path),
356            _ => panic!("unhandled extension: {ext}"),
357        };
358
359        // collect and store file dependencies
360        for import in &shader.imports {
361            if let ShaderImport::AssetPath(asset_path) = import {
362                shader.file_dependencies.push(load_context.load(asset_path));
363            }
364        }
365        Ok(shader)
366    }
367
368    fn extensions(&self) -> &[&str] {
369        &["spv", "wgsl", "vert", "frag", "comp", "wesl"]
370    }
371}
372
/// Identifies a shader module for import purposes, either by asset path or by a custom name.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub enum ShaderImport {
    /// The module is addressed by an asset path.
    AssetPath(String),
    /// The module is addressed by a custom name.
    Custom(String),
}

impl ShaderImport {
    /// Returns the module name of this import.
    ///
    /// A custom name is borrowed unchanged; an asset path is returned as a newly allocated
    /// string wrapped in double quotes.
    pub fn module_name(&self) -> Cow<'_, String> {
        match self {
            Self::Custom(name) => Cow::Borrowed(name),
            Self::AssetPath(asset_path) => {
                let mut quoted = String::with_capacity(asset_path.len() + 2);
                quoted.push('"');
                quoted.push_str(asset_path);
                quoted.push('"');
                Cow::Owned(quoted)
            }
        }
    }
}
387
/// A reference to a shader asset.
///
/// Can be created with `From`/`Into` from a [`Handle<Shader>`], an [`AssetPath<'static>`]
/// or a `&'static str` asset path.
pub enum ShaderRef {
    /// Use the "default" shader for the current context.
    Default,
    /// A handle to a shader stored in the [`Assets<Shader>`](bevy_asset::Assets) resource.
    Handle(Handle<Shader>),
    /// An asset path leading to a shader.
    Path(AssetPath<'static>),
}
397
398impl From<Handle<Shader>> for ShaderRef {
399    fn from(handle: Handle<Shader>) -> Self {
400        Self::Handle(handle)
401    }
402}
403
404impl From<AssetPath<'static>> for ShaderRef {
405    fn from(path: AssetPath<'static>) -> Self {
406        Self::Path(path)
407    }
408}
409
410impl From<&'static str> for ShaderRef {
411    fn from(path: &'static str) -> Self {
412        Self::Path(AssetPath::from(path))
413    }
414}