bevy_shader/
shader_cache.rs

1use crate::shader::*;
2use alloc::sync::Arc;
3use bevy_asset::AssetId;
4use bevy_platform::collections::{hash_map::EntryRef, HashMap, HashSet};
5use core::hash::Hash;
6use naga::valid::Capabilities;
7use thiserror::Error;
8use tracing::{debug, error};
9use wgpu_types::{DownlevelFlags, Features};
10
11/// Source of a shader module.
12///
13/// The source will be parsed and validated.
14///
15/// Any necessary shader translation (e.g. from WGSL to SPIR-V or vice versa)
16/// will be done internally by wgpu.
17///
18/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
19/// only WGSL source code strings are accepted.
20///
21/// This is roughly equivalent to `wgpu::ShaderSource`
22#[cfg_attr(
23    not(feature = "decoupled_naga"),
24    expect(
25        clippy::large_enum_variant,
26        reason = "naga modules are the most common use, and are large"
27    )
28)]
29#[derive(Clone, Debug)]
30pub enum ShaderCacheSource<'a> {
31    /// SPIR-V module represented as a slice of words.
32    SpirV(&'a [u8]),
33    /// WGSL module as a string slice.
34    Wgsl(String),
35    /// Naga module.
36    #[cfg(not(feature = "decoupled_naga"))]
37    Naga(naga::Module),
38}
39
/// Identifier of a cached pipeline; `ShaderCache` records it for every shader the pipeline
/// requests so the pipeline can be re-queued when that shader is invalidated.
pub type CachedPipelineId = usize;
41
42struct ShaderData<ShaderModule> {
43    pipelines: HashSet<CachedPipelineId>,
44    processed_shaders: HashMap<Box<[ShaderDefVal]>, Arc<ShaderModule>>,
45    resolved_imports: HashMap<ShaderImport, AssetId<Shader>>,
46    dependents: HashSet<AssetId<Shader>>,
47}
48
49impl<T> Default for ShaderData<T> {
50    fn default() -> Self {
51        Self {
52            pipelines: Default::default(),
53            processed_shaders: Default::default(),
54            resolved_imports: Default::default(),
55            dependents: Default::default(),
56        }
57    }
58}
59
60pub struct ShaderCache<ShaderModule, RenderDevice> {
61    data: HashMap<AssetId<Shader>, ShaderData<ShaderModule>>,
62    load_module: fn(
63        &RenderDevice,
64        ShaderCacheSource,
65        &ValidateShader,
66    ) -> Result<ShaderModule, PipelineCacheError>,
67    #[cfg(feature = "shader_format_wesl")]
68    asset_paths: HashMap<wesl::syntax::ModulePath, AssetId<Shader>>,
69    shaders: HashMap<AssetId<Shader>, Shader>,
70    import_path_shaders: HashMap<ShaderImport, AssetId<Shader>>,
71    waiting_on_import: HashMap<ShaderImport, Vec<AssetId<Shader>>>,
72    pub composer: naga_oil::compose::Composer,
73}
74
75#[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, Eq, Debug, Hash)]
76pub enum ShaderDefVal {
77    Bool(String, bool),
78    Int(String, i32),
79    UInt(String, u32),
80}
81
82impl From<&str> for ShaderDefVal {
83    fn from(key: &str) -> Self {
84        ShaderDefVal::Bool(key.to_string(), true)
85    }
86}
87
88impl From<String> for ShaderDefVal {
89    fn from(key: String) -> Self {
90        ShaderDefVal::Bool(key, true)
91    }
92}
93
94impl ShaderDefVal {
95    pub fn value_as_string(&self) -> String {
96        match self {
97            ShaderDefVal::Bool(_, def) => def.to_string(),
98            ShaderDefVal::Int(_, def) => def.to_string(),
99            ShaderDefVal::UInt(_, def) => def.to_string(),
100        }
101    }
102}
103
104impl<ShaderModule, RenderDevice> ShaderCache<ShaderModule, RenderDevice> {
105    pub fn new(
106        features: Features,
107        downlevel: DownlevelFlags,
108        load_module: fn(
109            &RenderDevice,
110            ShaderCacheSource,
111            &ValidateShader,
112        ) -> Result<ShaderModule, PipelineCacheError>,
113    ) -> Self {
114        let capabilities = get_capabilities(features, downlevel);
115        #[cfg(debug_assertions)]
116        let composer = naga_oil::compose::Composer::default();
117        #[cfg(not(debug_assertions))]
118        let composer = naga_oil::compose::Composer::non_validating();
119
120        let composer = composer.with_capabilities(capabilities);
121
122        Self {
123            composer,
124            load_module,
125            data: Default::default(),
126            #[cfg(feature = "shader_format_wesl")]
127            asset_paths: Default::default(),
128            shaders: Default::default(),
129            import_path_shaders: Default::default(),
130            waiting_on_import: Default::default(),
131        }
132    }
133
134    #[expect(
135        clippy::result_large_err,
136        reason = "See https://github.com/bevyengine/bevy/issues/19220"
137    )]
138    fn add_import_to_composer(
139        composer: &mut naga_oil::compose::Composer,
140        import_path_shaders: &HashMap<ShaderImport, AssetId<Shader>>,
141        shaders: &HashMap<AssetId<Shader>, Shader>,
142        import: &ShaderImport,
143    ) -> Result<(), PipelineCacheError> {
144        // Early out if we've already imported this module
145        if composer.contains_module(&import.module_name()) {
146            return Ok(());
147        }
148
149        // Check if the import is available (this handles the recursive import case)
150        let shader = import_path_shaders
151            .get(import)
152            .and_then(|handle| shaders.get(handle))
153            .ok_or(PipelineCacheError::ShaderImportNotYetAvailable)?;
154
155        // Recurse down to ensure all import dependencies are met
156        for import in &shader.imports {
157            Self::add_import_to_composer(composer, import_path_shaders, shaders, import)?;
158        }
159
160        composer.add_composable_module(shader.into())?;
161        // if we fail to add a module the composer will tell us what is missing
162
163        Ok(())
164    }
165
166    #[expect(
167        clippy::result_large_err,
168        reason = "See https://github.com/bevyengine/bevy/issues/19220"
169    )]
170    pub fn get(
171        &mut self,
172        render_device: &RenderDevice,
173        pipeline: CachedPipelineId,
174        id: AssetId<Shader>,
175        shader_defs: &[ShaderDefVal],
176    ) -> Result<Arc<ShaderModule>, PipelineCacheError> {
177        let shader = self
178            .shaders
179            .get(&id)
180            .ok_or(PipelineCacheError::ShaderNotLoaded(id))?;
181
182        let data = self.data.entry(id).or_default();
183        let n_asset_imports = shader
184            .imports()
185            .filter(|import| matches!(import, ShaderImport::AssetPath(_)))
186            .count();
187        let n_resolved_asset_imports = data
188            .resolved_imports
189            .keys()
190            .filter(|import| matches!(import, ShaderImport::AssetPath(_)))
191            .count();
192        if n_asset_imports != n_resolved_asset_imports {
193            return Err(PipelineCacheError::ShaderImportNotYetAvailable);
194        }
195
196        data.pipelines.insert(pipeline);
197
198        // PERF: this shader_defs clone isn't great. use raw_entry_mut when it stabilizes
199        let module = match data.processed_shaders.entry_ref(shader_defs) {
200            EntryRef::Occupied(entry) => entry.into_mut(),
201            EntryRef::Vacant(entry) => {
202                debug!(
203                    "processing shader {}, with shader defs {:?}",
204                    id, shader_defs
205                );
206                let shader_source = match &shader.source {
207                    Source::SpirV(data) => ShaderCacheSource::SpirV(data.as_ref()),
208                    #[cfg(feature = "shader_format_wesl")]
209                    Source::Wesl(_) => {
210                        if let ShaderImport::AssetPath(path) = shader.import_path() {
211                            let shader_resolver =
212                                ShaderResolver::new(&self.asset_paths, &self.shaders);
213                            let module_path = wesl::syntax::ModulePath::from_path(path);
214                            let mut compiler_options = wesl::CompileOptions {
215                                imports: true,
216                                condcomp: true,
217                                lower: true,
218                                ..Default::default()
219                            };
220
221                            for shader_def in shader_defs {
222                                match shader_def {
223                                    ShaderDefVal::Bool(key, value) => {
224                                        compiler_options.features.insert(key.clone(), *value);
225                                    }
226                                    _ => debug!(
227                                        "ShaderDefVal::Int and ShaderDefVal::UInt are not supported in wesl",
228                                    ),
229                                }
230                            }
231
232                            let compiled = wesl::compile(
233                                &module_path,
234                                &shader_resolver,
235                                &wesl::EscapeMangler,
236                                &compiler_options,
237                            )
238                            .unwrap();
239
240                            ShaderCacheSource::Wgsl(compiled.to_string())
241                        } else {
242                            panic!("Wesl shaders must be imported from a file");
243                        }
244                    }
245                    _ => {
246                        for import in shader.imports() {
247                            Self::add_import_to_composer(
248                                &mut self.composer,
249                                &self.import_path_shaders,
250                                &self.shaders,
251                                import,
252                            )?;
253                        }
254
255                        let shader_defs = shader_defs
256                            .iter()
257                            .chain(shader.shader_defs.iter())
258                            .map(|def| match def.clone() {
259                                ShaderDefVal::Bool(k, v) => {
260                                    (k, naga_oil::compose::ShaderDefValue::Bool(v))
261                                }
262                                ShaderDefVal::Int(k, v) => {
263                                    (k, naga_oil::compose::ShaderDefValue::Int(v))
264                                }
265                                ShaderDefVal::UInt(k, v) => {
266                                    (k, naga_oil::compose::ShaderDefValue::UInt(v))
267                                }
268                            })
269                            .collect::<std::collections::HashMap<_, _>>();
270
271                        let naga = self.composer.make_naga_module(
272                            naga_oil::compose::NagaModuleDescriptor {
273                                shader_defs,
274                                ..shader.into()
275                            },
276                        )?;
277
278                        #[cfg(not(feature = "decoupled_naga"))]
279                        {
280                            ShaderCacheSource::Naga(naga)
281                        }
282
283                        #[cfg(feature = "decoupled_naga")]
284                        {
285                            let mut validator = naga::valid::Validator::new(
286                                naga::valid::ValidationFlags::all(),
287                                self.composer.capabilities,
288                            );
289                            let module_info = validator.validate(&naga).unwrap();
290                            let wgsl = naga::back::wgsl::write_string(
291                                &naga,
292                                &module_info,
293                                naga::back::wgsl::WriterFlags::empty(),
294                            )
295                            .unwrap();
296                            ShaderCacheSource::Wgsl(wgsl)
297                        }
298                    }
299                };
300
301                let shader_module =
302                    (self.load_module)(render_device, shader_source, &shader.validate_shader)?;
303
304                entry.insert(Arc::new(shader_module))
305            }
306        };
307
308        Ok(module.clone())
309    }
310
311    fn clear(&mut self, id: AssetId<Shader>) -> Vec<CachedPipelineId> {
312        let mut shaders_to_clear = vec![id];
313        let mut pipelines_to_queue = Vec::new();
314        while let Some(handle) = shaders_to_clear.pop() {
315            if let Some(data) = self.data.get_mut(&handle) {
316                data.processed_shaders.clear();
317                pipelines_to_queue.extend(data.pipelines.iter().copied());
318                shaders_to_clear.extend(data.dependents.iter().copied());
319
320                if let Some(Shader { import_path, .. }) = self.shaders.get(&handle) {
321                    self.composer
322                        .remove_composable_module(&import_path.module_name());
323                }
324            }
325        }
326
327        pipelines_to_queue
328    }
329
330    pub fn set_shader(&mut self, id: AssetId<Shader>, shader: Shader) -> Vec<CachedPipelineId> {
331        let pipelines_to_queue = self.clear(id);
332        let path = shader.import_path();
333        self.import_path_shaders.insert(path.clone(), id);
334        if let Some(waiting_shaders) = self.waiting_on_import.get_mut(path) {
335            for waiting_shader in waiting_shaders.drain(..) {
336                // resolve waiting shader import
337                let data = self.data.entry(waiting_shader).or_default();
338                data.resolved_imports.insert(path.clone(), id);
339                // add waiting shader as dependent of this shader
340                let data = self.data.entry(id).or_default();
341                data.dependents.insert(waiting_shader);
342            }
343        }
344
345        for import in shader.imports() {
346            if let Some(import_id) = self.import_path_shaders.get(import).copied() {
347                // resolve import because it is currently available
348                let data = self.data.entry(id).or_default();
349                data.resolved_imports.insert(import.clone(), import_id);
350                // add this shader as a dependent of the import
351                let data = self.data.entry(import_id).or_default();
352                data.dependents.insert(id);
353            } else {
354                let waiting = self.waiting_on_import.entry(import.clone()).or_default();
355                waiting.push(id);
356            }
357        }
358
359        #[cfg(feature = "shader_format_wesl")]
360        if let Source::Wesl(_) = shader.source
361            && let ShaderImport::AssetPath(path) = shader.import_path()
362        {
363            self.asset_paths
364                .insert(wesl::syntax::ModulePath::from_path(path), id);
365        }
366        self.shaders.insert(id, shader);
367        pipelines_to_queue
368    }
369
370    pub fn remove(&mut self, id: AssetId<Shader>) -> Vec<CachedPipelineId> {
371        let pipelines_to_queue = self.clear(id);
372        if let Some(shader) = self.shaders.remove(&id) {
373            self.import_path_shaders.remove(shader.import_path());
374        }
375
376        pipelines_to_queue
377    }
378}
379
/// Resolves WESL module paths to the source of shader assets registered in the shader cache.
#[cfg(feature = "shader_format_wesl")]
pub struct ShaderResolver<'a> {
    /// WESL module path -> asset id of the shader providing that module.
    asset_paths: &'a HashMap<wesl::syntax::ModulePath, AssetId<Shader>>,
    /// Registered shaders, looked up by asset id.
    shaders: &'a HashMap<AssetId<Shader>, Shader>,
}

#[cfg(feature = "shader_format_wesl")]
impl<'a> ShaderResolver<'a> {
    /// Creates a resolver that borrows the module-path index and the shader map.
    pub fn new(
        asset_paths: &'a HashMap<wesl::syntax::ModulePath, AssetId<Shader>>,
        shaders: &'a HashMap<AssetId<Shader>, Shader>,
    ) -> Self {
        ShaderResolver {
            shaders,
            asset_paths,
        }
    }
}
398
#[cfg(feature = "shader_format_wesl")]
impl<'a> wesl::Resolver for ShaderResolver<'a> {
    /// Returns the source of the shader registered for `module_path`.
    ///
    /// # Errors
    ///
    /// `ModuleNotFound` if no shader is registered for the module path, or if the registered
    /// asset id has no matching shader.
    fn resolve_source(
        &self,
        module_path: &wesl::syntax::ModulePath,
    ) -> Result<alloc::borrow::Cow<'_, str>, wesl::ResolveError> {
        let asset_id = self.asset_paths.get(module_path).ok_or_else(|| {
            wesl::ResolveError::ModuleNotFound(module_path.clone(), "Invalid asset id".to_string())
        })?;

        // Defensive: if the path index and the shader map ever disagree (e.g. a stale entry),
        // report a missing module instead of panicking mid-compilation.
        let shader = self.shaders.get(asset_id).ok_or_else(|| {
            wesl::ResolveError::ModuleNotFound(
                module_path.clone(),
                "Shader not loaded".to_string(),
            )
        })?;
        Ok(alloc::borrow::Cow::Borrowed(shader.source.as_str()))
    }
}
413
414/// Type of error returned by a `PipelineCache` when the creation of a GPU pipeline object failed.
415#[cfg_attr(
416    not(target_arch = "wasm32"),
417    expect(
418        clippy::large_enum_variant,
419        reason = "See https://github.com/bevyengine/bevy/issues/19220"
420    )
421)]
422#[derive(Error, Debug)]
423pub enum PipelineCacheError {
424    #[error(
425        "Pipeline could not be compiled because the following shader could not be loaded: {0:?}"
426    )]
427    ShaderNotLoaded(AssetId<Shader>),
428    #[error(transparent)]
429    ProcessShaderError(#[from] naga_oil::compose::ComposerError),
430    #[error("Shader import not yet available.")]
431    ShaderImportNotYetAvailable,
432    #[error("Could not create shader module: {0}")]
433    CreateShaderModule(String),
434}
435
436// TODO: This needs to be kept up to date with the capabilities in the `create_validator` function in wgpu-core
437// https://github.com/gfx-rs/wgpu/blob/trunk/wgpu-core/src/device/mod.rs#L449
438// We can't use the `wgpu-core` function to detect the device's capabilities because `wgpu-core` isn't included in WebGPU builds.
439/// Get the device's capabilities for use in `naga_oil`.
440fn get_capabilities(features: Features, downlevel: DownlevelFlags) -> Capabilities {
441    let mut capabilities = Capabilities::empty();
442    capabilities.set(
443        Capabilities::PUSH_CONSTANT,
444        features.contains(Features::PUSH_CONSTANTS),
445    );
446    capabilities.set(
447        Capabilities::FLOAT64,
448        features.contains(Features::SHADER_F64),
449    );
450    capabilities.set(
451        Capabilities::PRIMITIVE_INDEX,
452        features.contains(Features::SHADER_PRIMITIVE_INDEX),
453    );
454    capabilities.set(
455        Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
456        features.contains(Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
457    );
458    capabilities.set(
459        Capabilities::STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING,
460        features.contains(Features::STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING),
461    );
462    capabilities.set(
463        Capabilities::UNIFORM_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
464        features.contains(Features::UNIFORM_BUFFER_BINDING_ARRAYS),
465    );
466    // TODO: This needs a proper wgpu feature
467    capabilities.set(
468        Capabilities::SAMPLER_NON_UNIFORM_INDEXING,
469        features.contains(Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
470    );
471    capabilities.set(
472        Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS,
473        features.contains(Features::TEXTURE_FORMAT_16BIT_NORM),
474    );
475    capabilities.set(
476        Capabilities::MULTIVIEW,
477        features.contains(Features::MULTIVIEW),
478    );
479    capabilities.set(
480        Capabilities::EARLY_DEPTH_TEST,
481        features.contains(Features::SHADER_EARLY_DEPTH_TEST),
482    );
483    capabilities.set(
484        Capabilities::SHADER_INT64,
485        features.contains(Features::SHADER_INT64),
486    );
487    capabilities.set(
488        Capabilities::SHADER_INT64_ATOMIC_MIN_MAX,
489        features.intersects(
490            Features::SHADER_INT64_ATOMIC_MIN_MAX | Features::SHADER_INT64_ATOMIC_ALL_OPS,
491        ),
492    );
493    capabilities.set(
494        Capabilities::SHADER_INT64_ATOMIC_ALL_OPS,
495        features.contains(Features::SHADER_INT64_ATOMIC_ALL_OPS),
496    );
497    capabilities.set(
498        Capabilities::MULTISAMPLED_SHADING,
499        downlevel.contains(DownlevelFlags::MULTISAMPLED_SHADING),
500    );
501    capabilities.set(
502        Capabilities::RAY_QUERY,
503        features.contains(Features::EXPERIMENTAL_RAY_QUERY),
504    );
505    capabilities.set(
506        Capabilities::DUAL_SOURCE_BLENDING,
507        features.contains(Features::DUAL_SOURCE_BLENDING),
508    );
509    capabilities.set(
510        Capabilities::CLIP_DISTANCE,
511        features.contains(Features::CLIP_DISTANCES),
512    );
513    capabilities.set(
514        Capabilities::CUBE_ARRAY_TEXTURES,
515        downlevel.contains(DownlevelFlags::CUBE_ARRAY_TEXTURES),
516    );
517    capabilities.set(
518        Capabilities::SUBGROUP,
519        features.intersects(Features::SUBGROUP | Features::SUBGROUP_VERTEX),
520    );
521    capabilities.set(
522        Capabilities::SUBGROUP_BARRIER,
523        features.intersects(Features::SUBGROUP_BARRIER),
524    );
525    capabilities.set(
526        Capabilities::SUBGROUP_VERTEX_STAGE,
527        features.contains(Features::SUBGROUP_VERTEX),
528    );
529    capabilities.set(
530        Capabilities::SHADER_FLOAT32_ATOMIC,
531        features.contains(Features::SHADER_FLOAT32_ATOMIC),
532    );
533    capabilities.set(
534        Capabilities::TEXTURE_ATOMIC,
535        features.contains(Features::TEXTURE_ATOMIC),
536    );
537    capabilities.set(
538        Capabilities::TEXTURE_INT64_ATOMIC,
539        features.contains(Features::TEXTURE_INT64_ATOMIC),
540    );
541    capabilities.set(
542        Capabilities::SHADER_FLOAT16,
543        features.contains(Features::SHADER_F16),
544    );
545    capabilities.set(
546        Capabilities::RAY_HIT_VERTEX_POSITION,
547        features.intersects(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN),
548    );
549
550    capabilities
551}