bevy_shader/
shader.rs

1use super::ShaderDefVal;
2use alloc::borrow::Cow;
3use bevy_asset::{io::Reader, Asset, AssetLoader, AssetPath, Handle, LoadContext};
4use bevy_reflect::TypePath;
5use core::{marker::Copy, num::NonZero};
6use thiserror::Error;
7
/// A process-wide unique identifier for a shader.
///
/// IDs are handed out by [`ShaderId::new`] from a global atomic counter that starts at 1,
/// so the wrapped value is never zero.
#[derive(Copy, Clone, Hash, Eq, PartialEq, PartialOrd, Ord, Debug)]
pub struct ShaderId(NonZero<u32>);

impl ShaderId {
    /// Allocates the next unused ID.
    ///
    /// # Panics
    /// Panics once the `u32` ID space is exhausted. Exhaustion is permanent: the counter is
    /// never allowed to wrap, so every later call panics as well instead of silently
    /// re-issuing IDs that are already in use.
    #[expect(
        clippy::new_without_default,
        reason = "Implementing the `Default` trait on atomic IDs would imply that two `<AtomicIdType>::default()` equal each other. By only implementing `new()`, we indicate that each atomic ID created will be unique."
    )]
    pub fn new() -> Self {
        use core::sync::atomic::{AtomicU32, Ordering};
        static COUNTER: AtomicU32 = AtomicU32::new(1);
        // `checked_add` refuses to advance past `u32::MAX`, so the counter saturates rather
        // than wrapping through 0 back to values that were already handed out.
        let Ok(counter) =
            COUNTER.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                current.checked_add(1)
            })
        else {
            panic!("The system ran out of unique `{}`s.", stringify!(ShaderId));
        };
        Self(NonZero::<u32>::new(counter).unwrap_or_else(|| {
            panic!("The system ran out of unique `{}`s.", stringify!(ShaderId));
        }))
    }
}

impl From<ShaderId> for NonZero<u32> {
    /// Extracts the raw non-zero value of the ID.
    fn from(value: ShaderId) -> Self {
        value.0
    }
}

impl From<NonZero<u32>> for ShaderId {
    /// Wraps a raw non-zero value as an ID without touching the global counter.
    fn from(value: NonZero<u32>) -> Self {
        Self(value)
    }
}
35
36#[derive(Error, Debug)]
37pub enum ShaderReflectError {
38    #[error(transparent)]
39    WgslParse(#[from] naga::front::wgsl::ParseError),
40    #[cfg(feature = "shader_format_glsl")]
41    #[error("GLSL Parse Error: {0:?}")]
42    GlslParse(Vec<naga::front::glsl::Error>),
43    #[cfg(feature = "shader_format_spirv")]
44    #[error(transparent)]
45    SpirVParse(#[from] naga::front::spv::Error),
46    #[error(transparent)]
47    Validation(#[from] naga::WithSpan<naga::valid::ValidationError>),
48}
49
/// Selects whether runtime checks are performed on a shader.
///
/// Runtime checks buy safety at the cost of speed. The default is
/// [`ValidateShader::Disabled`], i.e. no runtime checks.
///
/// # Panics
/// Because no runtime checks are performed for spirv,
/// enabling `ValidateShader` for spirv will cause a panic
#[derive(Debug, Clone, Default)]
pub enum ValidateShader {
    /// No runtime checks for soundness (e.g. bound checking) are performed.
    ///
    /// This is suitable for trusted shaders, written by your program or dependencies you trust.
    #[default]
    Disabled,
    /// Enables runtime checks for soundness (e.g. bound checking).
    ///
    /// While this can have a meaningful impact on performance,
    /// this setting should *always* be enabled when loading untrusted shaders.
    /// This might occur if you are creating a shader playground, running user-generated shaders
    /// (as in `VRChat`), or writing a web browser in Bevy.
    Enabled,
}
72
73/// An "unprocessed" shader. It can contain preprocessor directives.
74#[derive(Asset, TypePath, Debug, Clone)]
75pub struct Shader {
76    pub path: String,
77    pub source: Source,
78    pub import_path: ShaderImport,
79    pub imports: Vec<ShaderImport>,
80    // extra imports not specified in the source string
81    pub additional_imports: Vec<naga_oil::compose::ImportDefinition>,
82    // any shader defs that will be included when this module is used
83    pub shader_defs: Vec<ShaderDefVal>,
84    // we must store strong handles to our dependencies to stop them
85    // from being immediately dropped if we are the only user.
86    pub file_dependencies: Vec<Handle<Shader>>,
87    /// Enable or disable runtime shader validation, trading safety against speed.
88    ///
89    /// Please read the [`ValidateShader`] docs for a discussion of the tradeoffs involved.
90    pub validate_shader: ValidateShader,
91}
92
93impl Shader {
94    fn preprocess(source: &str, path: &str) -> (ShaderImport, Vec<ShaderImport>) {
95        let (import_path, imports, _) = naga_oil::compose::get_preprocessor_data(source);
96
97        let import_path = import_path
98            .map(ShaderImport::Custom)
99            .unwrap_or_else(|| ShaderImport::AssetPath(path.to_owned()));
100
101        let imports = imports
102            .into_iter()
103            .map(|import| {
104                if import.import.starts_with('\"') {
105                    let import = import
106                        .import
107                        .chars()
108                        .skip(1)
109                        .take_while(|c| *c != '\"')
110                        .collect();
111                    ShaderImport::AssetPath(import)
112                } else {
113                    ShaderImport::Custom(import.import)
114                }
115            })
116            .collect();
117
118        (import_path, imports)
119    }
120
121    pub fn from_wgsl(source: impl Into<Cow<'static, str>>, path: impl Into<String>) -> Shader {
122        let source = source.into();
123        let path = path.into();
124        let (import_path, imports) = Shader::preprocess(&source, &path);
125        Shader {
126            path,
127            imports,
128            import_path,
129            source: Source::Wgsl(source),
130            additional_imports: Default::default(),
131            shader_defs: Default::default(),
132            file_dependencies: Default::default(),
133            validate_shader: ValidateShader::Disabled,
134        }
135    }
136
137    pub fn from_wgsl_with_defs(
138        source: impl Into<Cow<'static, str>>,
139        path: impl Into<String>,
140        shader_defs: Vec<ShaderDefVal>,
141    ) -> Shader {
142        Self {
143            shader_defs,
144            ..Self::from_wgsl(source, path)
145        }
146    }
147
148    pub fn from_glsl(
149        source: impl Into<Cow<'static, str>>,
150        stage: naga::ShaderStage,
151        path: impl Into<String>,
152    ) -> Shader {
153        let source = source.into();
154        let path = path.into();
155        let (import_path, imports) = Shader::preprocess(&source, &path);
156        Shader {
157            path,
158            imports,
159            import_path,
160            source: Source::Glsl(source, stage),
161            additional_imports: Default::default(),
162            shader_defs: Default::default(),
163            file_dependencies: Default::default(),
164            validate_shader: ValidateShader::Disabled,
165        }
166    }
167
168    pub fn from_spirv(source: impl Into<Cow<'static, [u8]>>, path: impl Into<String>) -> Shader {
169        let path = path.into();
170        Shader {
171            path: path.clone(),
172            imports: Vec::new(),
173            import_path: ShaderImport::AssetPath(path),
174            source: Source::SpirV(source.into()),
175            additional_imports: Default::default(),
176            shader_defs: Default::default(),
177            file_dependencies: Default::default(),
178            validate_shader: ValidateShader::Disabled,
179        }
180    }
181
182    #[cfg(feature = "shader_format_wesl")]
183    pub fn from_wesl(source: impl Into<Cow<'static, str>>, path: impl Into<String>) -> Shader {
184        let source = source.into();
185        let path = path.into();
186        let (import_path, imports) = Shader::preprocess(&source, &path);
187
188        match import_path {
189            ShaderImport::AssetPath(asset_path) => {
190                // Create the shader import path - always starting with "/"
191                let shader_path = std::path::Path::new("/").join(&asset_path);
192
193                // Convert to a string with forward slashes and without extension
194                let import_path_str = shader_path
195                    .with_extension("")
196                    .to_string_lossy()
197                    .replace('\\', "/");
198
199                let import_path = ShaderImport::AssetPath(import_path_str.to_string());
200
201                Shader {
202                    path,
203                    imports,
204                    import_path,
205                    source: Source::Wesl(source),
206                    additional_imports: Default::default(),
207                    shader_defs: Default::default(),
208                    file_dependencies: Default::default(),
209                    validate_shader: ValidateShader::Disabled,
210                }
211            }
212            ShaderImport::Custom(_) => {
213                panic!("Wesl shaders must be imported from an asset path");
214            }
215        }
216    }
217
218    pub fn set_import_path<P: Into<String>>(&mut self, import_path: P) {
219        self.import_path = ShaderImport::Custom(import_path.into());
220    }
221
222    #[must_use]
223    pub fn with_import_path<P: Into<String>>(mut self, import_path: P) -> Self {
224        self.set_import_path(import_path);
225        self
226    }
227
228    #[inline]
229    pub fn import_path(&self) -> &ShaderImport {
230        &self.import_path
231    }
232
233    pub fn imports(&self) -> impl ExactSizeIterator<Item = &ShaderImport> {
234        self.imports.iter()
235    }
236}
237
238impl<'a> From<&'a Shader> for naga_oil::compose::ComposableModuleDescriptor<'a> {
239    fn from(shader: &'a Shader) -> Self {
240        let shader_defs = shader
241            .shader_defs
242            .iter()
243            .map(|def| match def {
244                ShaderDefVal::Bool(name, b) => {
245                    (name.clone(), naga_oil::compose::ShaderDefValue::Bool(*b))
246                }
247                ShaderDefVal::Int(name, i) => {
248                    (name.clone(), naga_oil::compose::ShaderDefValue::Int(*i))
249                }
250                ShaderDefVal::UInt(name, i) => {
251                    (name.clone(), naga_oil::compose::ShaderDefValue::UInt(*i))
252                }
253            })
254            .collect();
255
256        let as_name = match &shader.import_path {
257            ShaderImport::AssetPath(asset_path) => Some(format!("\"{asset_path}\"")),
258            ShaderImport::Custom(_) => None,
259        };
260
261        naga_oil::compose::ComposableModuleDescriptor {
262            source: shader.source.as_str(),
263            file_path: &shader.path,
264            language: (&shader.source).into(),
265            additional_imports: &shader.additional_imports,
266            shader_defs,
267            as_name,
268        }
269    }
270}
271
272impl<'a> From<&'a Shader> for naga_oil::compose::NagaModuleDescriptor<'a> {
273    fn from(shader: &'a Shader) -> Self {
274        naga_oil::compose::NagaModuleDescriptor {
275            source: shader.source.as_str(),
276            file_path: &shader.path,
277            shader_type: (&shader.source).into(),
278            ..Default::default()
279        }
280    }
281}
282
283#[derive(Debug, Clone)]
284pub enum Source {
285    Wgsl(Cow<'static, str>),
286    Wesl(Cow<'static, str>),
287    Glsl(Cow<'static, str>, naga::ShaderStage),
288    SpirV(Cow<'static, [u8]>),
289    // TODO: consider the following
290    // PrecompiledSpirVMacros(HashMap<HashSet<String>, Vec<u32>>)
291    // NagaModule(Module) ... Module impls Serialize/Deserialize
292}
293
294impl Source {
295    pub fn as_str(&self) -> &str {
296        match self {
297            Source::Wgsl(s) | Source::Wesl(s) | Source::Glsl(s, _) => s,
298            Source::SpirV(_) => panic!("spirv not yet implemented"),
299        }
300    }
301}
302
303impl From<&Source> for naga_oil::compose::ShaderLanguage {
304    fn from(value: &Source) -> Self {
305        match value {
306            Source::Wgsl(_) => naga_oil::compose::ShaderLanguage::Wgsl,
307            #[cfg(any(feature = "shader_format_glsl", target_arch = "wasm32"))]
308            Source::Glsl(_, _) => naga_oil::compose::ShaderLanguage::Glsl,
309            #[cfg(all(not(feature = "shader_format_glsl"), not(target_arch = "wasm32")))]
310            Source::Glsl(_, _) => panic!(
311                "GLSL is not supported in this configuration; use the feature `shader_format_glsl`"
312            ),
313            Source::SpirV(_) => panic!("spirv not yet implemented"),
314            Source::Wesl(_) => panic!("wesl not yet implemented"),
315        }
316    }
317}
318
319impl From<&Source> for naga_oil::compose::ShaderType {
320    fn from(value: &Source) -> Self {
321        match value {
322            Source::Wgsl(_) => naga_oil::compose::ShaderType::Wgsl,
323            #[cfg(any(feature = "shader_format_glsl", target_arch = "wasm32"))]
324            Source::Glsl(_, shader_stage) => match shader_stage {
325                naga::ShaderStage::Vertex => naga_oil::compose::ShaderType::GlslVertex,
326                naga::ShaderStage::Fragment => naga_oil::compose::ShaderType::GlslFragment,
327                naga::ShaderStage::Compute => panic!("glsl compute not yet implemented"),
328                naga::ShaderStage::Task => panic!("task shaders not yet implemented"),
329                naga::ShaderStage::Mesh => panic!("mesh shaders not yet implemented"),
330            },
331            #[cfg(all(not(feature = "shader_format_glsl"), not(target_arch = "wasm32")))]
332            Source::Glsl(_, _) => panic!(
333                "GLSL is not supported in this configuration; use the feature `shader_format_glsl`"
334            ),
335            Source::SpirV(_) => panic!("spirv not yet implemented"),
336            Source::Wesl(_) => panic!("wesl not yet implemented"),
337        }
338    }
339}
340
/// Asset loader that produces [`Shader`] assets; it carries no state of its own.
#[derive(Default)]
pub struct ShaderLoader;
343
344#[non_exhaustive]
345#[derive(Debug, Error)]
346pub enum ShaderLoaderError {
347    #[error("Could not load shader: {0}")]
348    Io(#[from] std::io::Error),
349    #[error("Could not parse shader: {0}")]
350    Parse(#[from] alloc::string::FromUtf8Error),
351}
352
353/// Settings for loading shaders.
354#[derive(serde::Serialize, serde::Deserialize, Debug, Default)]
355pub struct ShaderSettings {
356    /// The `#define` specified for this shader.
357    pub shader_defs: Vec<ShaderDefVal>,
358}
359
360impl AssetLoader for ShaderLoader {
361    type Asset = Shader;
362    type Settings = ShaderSettings;
363    type Error = ShaderLoaderError;
364    async fn load(
365        &self,
366        reader: &mut dyn Reader,
367        settings: &Self::Settings,
368        load_context: &mut LoadContext<'_>,
369    ) -> Result<Shader, Self::Error> {
370        let ext = load_context.path().extension().unwrap().to_str().unwrap();
371        let path = load_context.asset_path().to_string();
372        // On windows, the path will inconsistently use \ or /.
373        // TODO: remove this once AssetPath forces cross-platform "slash" consistency. See #10511
374        let path = path.replace(std::path::MAIN_SEPARATOR, "/");
375        let mut bytes = Vec::new();
376        reader.read_to_end(&mut bytes).await?;
377        if ext != "wgsl" && !settings.shader_defs.is_empty() {
378            tracing::warn!(
379                "Tried to load a non-wgsl shader with shader defs, this isn't supported: \
380                    The shader defs will be ignored."
381            );
382        }
383        let mut shader = match ext {
384            "spv" => Shader::from_spirv(bytes, load_context.path().to_string_lossy()),
385            "wgsl" => Shader::from_wgsl_with_defs(
386                String::from_utf8(bytes)?,
387                path,
388                settings.shader_defs.clone(),
389            ),
390            "vert" => Shader::from_glsl(String::from_utf8(bytes)?, naga::ShaderStage::Vertex, path),
391            "frag" => {
392                Shader::from_glsl(String::from_utf8(bytes)?, naga::ShaderStage::Fragment, path)
393            }
394            "comp" => {
395                Shader::from_glsl(String::from_utf8(bytes)?, naga::ShaderStage::Compute, path)
396            }
397            #[cfg(feature = "shader_format_wesl")]
398            "wesl" => Shader::from_wesl(String::from_utf8(bytes)?, path),
399            _ => panic!("unhandled extension: {ext}"),
400        };
401
402        // collect and store file dependencies
403        for import in &shader.imports {
404            if let ShaderImport::AssetPath(asset_path) = import {
405                shader.file_dependencies.push(load_context.load(asset_path));
406            }
407        }
408        Ok(shader)
409    }
410
411    fn extensions(&self) -> &[&str] {
412        &["spv", "wgsl", "vert", "frag", "comp", "wesl"]
413    }
414}
415
/// The name by which a shader module is imported.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ShaderImport {
    /// The module is identified by an asset path.
    AssetPath(String),
    /// The module is identified by a custom name.
    Custom(String),
}

impl ShaderImport {
    /// Returns the module name for this import.
    ///
    /// Asset paths are wrapped in double quotes (allocating a new string); custom names
    /// are borrowed unchanged.
    pub fn module_name(&self) -> Cow<'_, String> {
        match self {
            Self::Custom(name) => Cow::Borrowed(name),
            Self::AssetPath(asset_path) => {
                let mut quoted = String::with_capacity(asset_path.len() + 2);
                quoted.push('"');
                quoted.push_str(asset_path);
                quoted.push('"');
                Cow::Owned(quoted)
            }
        }
    }
}
430
431/// A reference to a shader asset.
432#[derive(Default)]
433pub enum ShaderRef {
434    /// Use the "default" shader for the current context.
435    #[default]
436    Default,
437    /// A handle to a shader stored in the [`Assets<Shader>`](bevy_asset::Assets) resource
438    Handle(Handle<Shader>),
439    /// An asset path leading to a shader
440    Path(AssetPath<'static>),
441}
442
443impl From<Handle<Shader>> for ShaderRef {
444    fn from(handle: Handle<Shader>) -> Self {
445        Self::Handle(handle)
446    }
447}
448
449impl From<AssetPath<'static>> for ShaderRef {
450    fn from(path: AssetPath<'static>) -> Self {
451        Self::Path(path)
452    }
453}
454
455impl From<&'static str> for ShaderRef {
456    fn from(path: &'static str) -> Self {
457        Self::Path(AssetPath::from(path))
458    }
459}