1use crate::shader::*;
2use alloc::sync::Arc;
3use bevy_asset::AssetId;
4use bevy_platform::collections::{hash_map::EntryRef, HashMap, HashSet};
5use core::hash::Hash;
6use naga::valid::Capabilities;
7use thiserror::Error;
8use tracing::{debug, error};
9use wgpu_types::{DownlevelFlags, Features};
10
11#[cfg_attr(
23 not(feature = "decoupled_naga"),
24 expect(
25 clippy::large_enum_variant,
26 reason = "naga modules are the most common use, and are large"
27 )
28)]
29#[derive(Clone, Debug)]
30pub enum ShaderCacheSource<'a> {
31 SpirV(&'a [u8]),
33 Wgsl(String),
35 #[cfg(not(feature = "decoupled_naga"))]
37 Naga(naga::Module),
38}
39
/// Index identifying a pipeline that requested shaders from the cache.
pub type CachedPipelineId = usize;
41
42struct ShaderData<ShaderModule> {
43 pipelines: HashSet<CachedPipelineId>,
44 processed_shaders: HashMap<Box<[ShaderDefVal]>, Arc<ShaderModule>>,
45 resolved_imports: HashMap<ShaderImport, AssetId<Shader>>,
46 dependents: HashSet<AssetId<Shader>>,
47}
48
49impl<T> Default for ShaderData<T> {
50 fn default() -> Self {
51 Self {
52 pipelines: Default::default(),
53 processed_shaders: Default::default(),
54 resolved_imports: Default::default(),
55 dependents: Default::default(),
56 }
57 }
58}
59
60pub struct ShaderCache<ShaderModule, RenderDevice> {
61 data: HashMap<AssetId<Shader>, ShaderData<ShaderModule>>,
62 load_module: fn(
63 &RenderDevice,
64 ShaderCacheSource,
65 &ValidateShader,
66 ) -> Result<ShaderModule, PipelineCacheError>,
67 #[cfg(feature = "shader_format_wesl")]
68 asset_paths: HashMap<wesl::syntax::ModulePath, AssetId<Shader>>,
69 shaders: HashMap<AssetId<Shader>, Shader>,
70 import_path_shaders: HashMap<ShaderImport, AssetId<Shader>>,
71 waiting_on_import: HashMap<ShaderImport, Vec<AssetId<Shader>>>,
72 pub composer: naga_oil::compose::Composer,
73}
74
75#[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, Eq, Debug, Hash)]
76pub enum ShaderDefVal {
77 Bool(String, bool),
78 Int(String, i32),
79 UInt(String, u32),
80}
81
82impl From<&str> for ShaderDefVal {
83 fn from(key: &str) -> Self {
84 ShaderDefVal::Bool(key.to_string(), true)
85 }
86}
87
88impl From<String> for ShaderDefVal {
89 fn from(key: String) -> Self {
90 ShaderDefVal::Bool(key, true)
91 }
92}
93
94impl ShaderDefVal {
95 pub fn value_as_string(&self) -> String {
96 match self {
97 ShaderDefVal::Bool(_, def) => def.to_string(),
98 ShaderDefVal::Int(_, def) => def.to_string(),
99 ShaderDefVal::UInt(_, def) => def.to_string(),
100 }
101 }
102}
103
104impl<ShaderModule, RenderDevice> ShaderCache<ShaderModule, RenderDevice> {
105 pub fn new(
106 features: Features,
107 downlevel: DownlevelFlags,
108 load_module: fn(
109 &RenderDevice,
110 ShaderCacheSource,
111 &ValidateShader,
112 ) -> Result<ShaderModule, PipelineCacheError>,
113 ) -> Self {
114 let capabilities = get_capabilities(features, downlevel);
115 #[cfg(debug_assertions)]
116 let composer = naga_oil::compose::Composer::default();
117 #[cfg(not(debug_assertions))]
118 let composer = naga_oil::compose::Composer::non_validating();
119
120 let composer = composer.with_capabilities(capabilities);
121
122 Self {
123 composer,
124 load_module,
125 data: Default::default(),
126 #[cfg(feature = "shader_format_wesl")]
127 asset_paths: Default::default(),
128 shaders: Default::default(),
129 import_path_shaders: Default::default(),
130 waiting_on_import: Default::default(),
131 }
132 }
133
134 #[expect(
135 clippy::result_large_err,
136 reason = "See https://github.com/bevyengine/bevy/issues/19220"
137 )]
138 fn add_import_to_composer(
139 composer: &mut naga_oil::compose::Composer,
140 import_path_shaders: &HashMap<ShaderImport, AssetId<Shader>>,
141 shaders: &HashMap<AssetId<Shader>, Shader>,
142 import: &ShaderImport,
143 ) -> Result<(), PipelineCacheError> {
144 if composer.contains_module(&import.module_name()) {
146 return Ok(());
147 }
148
149 let shader = import_path_shaders
151 .get(import)
152 .and_then(|handle| shaders.get(handle))
153 .ok_or(PipelineCacheError::ShaderImportNotYetAvailable)?;
154
155 for import in &shader.imports {
157 Self::add_import_to_composer(composer, import_path_shaders, shaders, import)?;
158 }
159
160 composer.add_composable_module(shader.into())?;
161 Ok(())
164 }
165
166 #[expect(
167 clippy::result_large_err,
168 reason = "See https://github.com/bevyengine/bevy/issues/19220"
169 )]
170 pub fn get(
171 &mut self,
172 render_device: &RenderDevice,
173 pipeline: CachedPipelineId,
174 id: AssetId<Shader>,
175 shader_defs: &[ShaderDefVal],
176 ) -> Result<Arc<ShaderModule>, PipelineCacheError> {
177 let shader = self
178 .shaders
179 .get(&id)
180 .ok_or(PipelineCacheError::ShaderNotLoaded(id))?;
181
182 let data = self.data.entry(id).or_default();
183 let n_asset_imports = shader
184 .imports()
185 .filter(|import| matches!(import, ShaderImport::AssetPath(_)))
186 .count();
187 let n_resolved_asset_imports = data
188 .resolved_imports
189 .keys()
190 .filter(|import| matches!(import, ShaderImport::AssetPath(_)))
191 .count();
192 if n_asset_imports != n_resolved_asset_imports {
193 return Err(PipelineCacheError::ShaderImportNotYetAvailable);
194 }
195
196 data.pipelines.insert(pipeline);
197
198 let module = match data.processed_shaders.entry_ref(shader_defs) {
200 EntryRef::Occupied(entry) => entry.into_mut(),
201 EntryRef::Vacant(entry) => {
202 debug!(
203 "processing shader {}, with shader defs {:?}",
204 id, shader_defs
205 );
206 let shader_source = match &shader.source {
207 Source::SpirV(data) => ShaderCacheSource::SpirV(data.as_ref()),
208 #[cfg(feature = "shader_format_wesl")]
209 Source::Wesl(_) => {
210 if let ShaderImport::AssetPath(path) = shader.import_path() {
211 let shader_resolver =
212 ShaderResolver::new(&self.asset_paths, &self.shaders);
213 let module_path = wesl::syntax::ModulePath::from_path(path);
214 let mut compiler_options = wesl::CompileOptions {
215 imports: true,
216 condcomp: true,
217 lower: true,
218 ..Default::default()
219 };
220
221 for shader_def in shader_defs {
222 match shader_def {
223 ShaderDefVal::Bool(key, value) => {
224 compiler_options.features.insert(key.clone(), *value);
225 }
226 _ => debug!(
227 "ShaderDefVal::Int and ShaderDefVal::UInt are not supported in wesl",
228 ),
229 }
230 }
231
232 let compiled = wesl::compile(
233 &module_path,
234 &shader_resolver,
235 &wesl::EscapeMangler,
236 &compiler_options,
237 )
238 .unwrap();
239
240 ShaderCacheSource::Wgsl(compiled.to_string())
241 } else {
242 panic!("Wesl shaders must be imported from a file");
243 }
244 }
245 _ => {
246 for import in shader.imports() {
247 Self::add_import_to_composer(
248 &mut self.composer,
249 &self.import_path_shaders,
250 &self.shaders,
251 import,
252 )?;
253 }
254
255 let shader_defs = shader_defs
256 .iter()
257 .chain(shader.shader_defs.iter())
258 .map(|def| match def.clone() {
259 ShaderDefVal::Bool(k, v) => {
260 (k, naga_oil::compose::ShaderDefValue::Bool(v))
261 }
262 ShaderDefVal::Int(k, v) => {
263 (k, naga_oil::compose::ShaderDefValue::Int(v))
264 }
265 ShaderDefVal::UInt(k, v) => {
266 (k, naga_oil::compose::ShaderDefValue::UInt(v))
267 }
268 })
269 .collect::<std::collections::HashMap<_, _>>();
270
271 let naga = self.composer.make_naga_module(
272 naga_oil::compose::NagaModuleDescriptor {
273 shader_defs,
274 ..shader.into()
275 },
276 )?;
277
278 #[cfg(not(feature = "decoupled_naga"))]
279 {
280 ShaderCacheSource::Naga(naga)
281 }
282
283 #[cfg(feature = "decoupled_naga")]
284 {
285 let mut validator = naga::valid::Validator::new(
286 naga::valid::ValidationFlags::all(),
287 self.composer.capabilities,
288 );
289 let module_info = validator.validate(&naga).unwrap();
290 let wgsl = naga::back::wgsl::write_string(
291 &naga,
292 &module_info,
293 naga::back::wgsl::WriterFlags::empty(),
294 )
295 .unwrap();
296 ShaderCacheSource::Wgsl(wgsl)
297 }
298 }
299 };
300
301 let shader_module =
302 (self.load_module)(render_device, shader_source, &shader.validate_shader)?;
303
304 entry.insert(Arc::new(shader_module))
305 }
306 };
307
308 Ok(module.clone())
309 }
310
311 fn clear(&mut self, id: AssetId<Shader>) -> Vec<CachedPipelineId> {
312 let mut shaders_to_clear = vec![id];
313 let mut pipelines_to_queue = Vec::new();
314 while let Some(handle) = shaders_to_clear.pop() {
315 if let Some(data) = self.data.get_mut(&handle) {
316 data.processed_shaders.clear();
317 pipelines_to_queue.extend(data.pipelines.iter().copied());
318 shaders_to_clear.extend(data.dependents.iter().copied());
319
320 if let Some(Shader { import_path, .. }) = self.shaders.get(&handle) {
321 self.composer
322 .remove_composable_module(&import_path.module_name());
323 }
324 }
325 }
326
327 pipelines_to_queue
328 }
329
330 pub fn set_shader(&mut self, id: AssetId<Shader>, shader: Shader) -> Vec<CachedPipelineId> {
331 let pipelines_to_queue = self.clear(id);
332 let path = shader.import_path();
333 self.import_path_shaders.insert(path.clone(), id);
334 if let Some(waiting_shaders) = self.waiting_on_import.get_mut(path) {
335 for waiting_shader in waiting_shaders.drain(..) {
336 let data = self.data.entry(waiting_shader).or_default();
338 data.resolved_imports.insert(path.clone(), id);
339 let data = self.data.entry(id).or_default();
341 data.dependents.insert(waiting_shader);
342 }
343 }
344
345 for import in shader.imports() {
346 if let Some(import_id) = self.import_path_shaders.get(import).copied() {
347 let data = self.data.entry(id).or_default();
349 data.resolved_imports.insert(import.clone(), import_id);
350 let data = self.data.entry(import_id).or_default();
352 data.dependents.insert(id);
353 } else {
354 let waiting = self.waiting_on_import.entry(import.clone()).or_default();
355 waiting.push(id);
356 }
357 }
358
359 #[cfg(feature = "shader_format_wesl")]
360 if let Source::Wesl(_) = shader.source
361 && let ShaderImport::AssetPath(path) = shader.import_path()
362 {
363 self.asset_paths
364 .insert(wesl::syntax::ModulePath::from_path(path), id);
365 }
366 self.shaders.insert(id, shader);
367 pipelines_to_queue
368 }
369
370 pub fn remove(&mut self, id: AssetId<Shader>) -> Vec<CachedPipelineId> {
371 let pipelines_to_queue = self.clear(id);
372 if let Some(shader) = self.shaders.remove(&id) {
373 self.import_path_shaders.remove(shader.import_path());
374 }
375
376 pipelines_to_queue
377 }
378}
379
/// Source resolver for the WESL compiler, backed by the cache's shader maps.
///
/// Module paths are looked up in `asset_paths` to find an asset id, and asset ids are looked
/// up in `shaders` to find the source.
#[cfg(feature = "shader_format_wesl")]
pub struct ShaderResolver<'a> {
    asset_paths: &'a HashMap<wesl::syntax::ModulePath, AssetId<Shader>>,
    shaders: &'a HashMap<AssetId<Shader>, Shader>,
}
385
#[cfg(feature = "shader_format_wesl")]
impl<'a> ShaderResolver<'a> {
    /// Builds a resolver that borrows the path-to-asset and asset-to-shader maps.
    pub fn new(
        asset_paths: &'a HashMap<wesl::syntax::ModulePath, AssetId<Shader>>,
        shaders: &'a HashMap<AssetId<Shader>, Shader>,
    ) -> Self {
        ShaderResolver {
            shaders,
            asset_paths,
        }
    }
}
398
#[cfg(feature = "shader_format_wesl")]
impl<'a> wesl::Resolver for ShaderResolver<'a> {
    /// Returns the source text of the shader registered under `module_path`.
    ///
    /// # Errors
    ///
    /// Returns `ResolveError::ModuleNotFound` if no asset is registered for the path, or if the
    /// registered asset id no longer has a shader (previously this case panicked).
    fn resolve_source(
        &self,
        module_path: &wesl::syntax::ModulePath,
    ) -> Result<alloc::borrow::Cow<'_, str>, wesl::ResolveError> {
        let asset_id = self.asset_paths.get(module_path).ok_or_else(|| {
            wesl::ResolveError::ModuleNotFound(module_path.clone(), "Invalid asset id".to_string())
        })?;

        let shader = self.shaders.get(asset_id).ok_or_else(|| {
            wesl::ResolveError::ModuleNotFound(
                module_path.clone(),
                "Shader asset not loaded".to_string(),
            )
        })?;
        Ok(alloc::borrow::Cow::Borrowed(shader.source.as_str()))
    }
}
413
414#[cfg_attr(
416 not(target_arch = "wasm32"),
417 expect(
418 clippy::large_enum_variant,
419 reason = "See https://github.com/bevyengine/bevy/issues/19220"
420 )
421)]
422#[derive(Error, Debug)]
423pub enum PipelineCacheError {
424 #[error(
425 "Pipeline could not be compiled because the following shader could not be loaded: {0:?}"
426 )]
427 ShaderNotLoaded(AssetId<Shader>),
428 #[error(transparent)]
429 ProcessShaderError(#[from] naga_oil::compose::ComposerError),
430 #[error("Shader import not yet available.")]
431 ShaderImportNotYetAvailable,
432 #[error("Could not create shader module: {0}")]
433 CreateShaderModule(String),
434}
435
436fn get_capabilities(features: Features, downlevel: DownlevelFlags) -> Capabilities {
441 let mut capabilities = Capabilities::empty();
442 capabilities.set(
443 Capabilities::PUSH_CONSTANT,
444 features.contains(Features::PUSH_CONSTANTS),
445 );
446 capabilities.set(
447 Capabilities::FLOAT64,
448 features.contains(Features::SHADER_F64),
449 );
450 capabilities.set(
451 Capabilities::PRIMITIVE_INDEX,
452 features.contains(Features::SHADER_PRIMITIVE_INDEX),
453 );
454 capabilities.set(
455 Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
456 features.contains(Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
457 );
458 capabilities.set(
459 Capabilities::STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING,
460 features.contains(Features::STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING),
461 );
462 capabilities.set(
463 Capabilities::UNIFORM_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
464 features.contains(Features::UNIFORM_BUFFER_BINDING_ARRAYS),
465 );
466 capabilities.set(
468 Capabilities::SAMPLER_NON_UNIFORM_INDEXING,
469 features.contains(Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
470 );
471 capabilities.set(
472 Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS,
473 features.contains(Features::TEXTURE_FORMAT_16BIT_NORM),
474 );
475 capabilities.set(
476 Capabilities::MULTIVIEW,
477 features.contains(Features::MULTIVIEW),
478 );
479 capabilities.set(
480 Capabilities::EARLY_DEPTH_TEST,
481 features.contains(Features::SHADER_EARLY_DEPTH_TEST),
482 );
483 capabilities.set(
484 Capabilities::SHADER_INT64,
485 features.contains(Features::SHADER_INT64),
486 );
487 capabilities.set(
488 Capabilities::SHADER_INT64_ATOMIC_MIN_MAX,
489 features.intersects(
490 Features::SHADER_INT64_ATOMIC_MIN_MAX | Features::SHADER_INT64_ATOMIC_ALL_OPS,
491 ),
492 );
493 capabilities.set(
494 Capabilities::SHADER_INT64_ATOMIC_ALL_OPS,
495 features.contains(Features::SHADER_INT64_ATOMIC_ALL_OPS),
496 );
497 capabilities.set(
498 Capabilities::MULTISAMPLED_SHADING,
499 downlevel.contains(DownlevelFlags::MULTISAMPLED_SHADING),
500 );
501 capabilities.set(
502 Capabilities::RAY_QUERY,
503 features.contains(Features::EXPERIMENTAL_RAY_QUERY),
504 );
505 capabilities.set(
506 Capabilities::DUAL_SOURCE_BLENDING,
507 features.contains(Features::DUAL_SOURCE_BLENDING),
508 );
509 capabilities.set(
510 Capabilities::CLIP_DISTANCE,
511 features.contains(Features::CLIP_DISTANCES),
512 );
513 capabilities.set(
514 Capabilities::CUBE_ARRAY_TEXTURES,
515 downlevel.contains(DownlevelFlags::CUBE_ARRAY_TEXTURES),
516 );
517 capabilities.set(
518 Capabilities::SUBGROUP,
519 features.intersects(Features::SUBGROUP | Features::SUBGROUP_VERTEX),
520 );
521 capabilities.set(
522 Capabilities::SUBGROUP_BARRIER,
523 features.intersects(Features::SUBGROUP_BARRIER),
524 );
525 capabilities.set(
526 Capabilities::SUBGROUP_VERTEX_STAGE,
527 features.contains(Features::SUBGROUP_VERTEX),
528 );
529 capabilities.set(
530 Capabilities::SHADER_FLOAT32_ATOMIC,
531 features.contains(Features::SHADER_FLOAT32_ATOMIC),
532 );
533 capabilities.set(
534 Capabilities::TEXTURE_ATOMIC,
535 features.contains(Features::TEXTURE_ATOMIC),
536 );
537 capabilities.set(
538 Capabilities::TEXTURE_INT64_ATOMIC,
539 features.contains(Features::TEXTURE_INT64_ATOMIC),
540 );
541 capabilities.set(
542 Capabilities::SHADER_FLOAT16,
543 features.contains(Features::SHADER_F16),
544 );
545 capabilities.set(
546 Capabilities::RAY_HIT_VERTEX_POSITION,
547 features.intersects(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN),
548 );
549
550 capabilities
551}