bevy_render/render_resource/pipeline_cache.rs

1use crate::renderer::WgpuWrapper;
2use crate::{
3    render_resource::*,
4    renderer::{RenderAdapter, RenderDevice},
5    Extract,
6};
7use alloc::{borrow::Cow, sync::Arc};
8use bevy_asset::{AssetEvent, AssetId, Assets};
9use bevy_ecs::{
10    event::EventReader,
11    resource::Resource,
12    system::{Res, ResMut},
13};
14use bevy_platform::collections::{hash_map::EntryRef, HashMap, HashSet};
15use bevy_tasks::Task;
16use bevy_utils::default;
17use core::{future::Future, hash::Hash, mem, ops::Deref};
18use naga::valid::Capabilities;
19use std::sync::{Mutex, PoisonError};
20use thiserror::Error;
21use tracing::{debug, error};
22#[cfg(feature = "shader_format_spirv")]
23use wgpu::util::make_spirv;
24use wgpu::{
25    DownlevelFlags, Features, PipelineCompilationOptions,
26    VertexBufferLayout as RawVertexBufferLayout,
27};
28
/// A descriptor for a [`Pipeline`].
///
/// Used to store a heterogeneous collection of render and compute pipeline descriptors
/// together; both variants hold their descriptor in a `Box`.
#[derive(Debug)]
pub enum PipelineDescriptor {
    /// Descriptor of a render pipeline.
    RenderPipelineDescriptor(Box<RenderPipelineDescriptor>),
    /// Descriptor of a compute pipeline.
    ComputePipelineDescriptor(Box<ComputePipelineDescriptor>),
}
37
/// A pipeline defining the data layout and shader logic for a specific GPU task.
///
/// Used to store a heterogeneous collection of render and compute pipelines together.
#[derive(Debug)]
pub enum Pipeline {
    /// A GPU render pipeline object.
    RenderPipeline(RenderPipeline),
    /// A GPU compute pipeline object.
    ComputePipeline(ComputePipeline),
}
46
/// Raw index of a pipeline in the [`PipelineCache`]; wrapped by the typed
/// [`CachedRenderPipelineId`] and [`CachedComputePipelineId`] handles.
type CachedPipelineId = usize;
48
49/// Index of a cached render pipeline in a [`PipelineCache`].
50#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, PartialOrd, Ord)]
51pub struct CachedRenderPipelineId(CachedPipelineId);
52
53impl CachedRenderPipelineId {
54    /// An invalid cached render pipeline index, often used to initialize a variable.
55    pub const INVALID: Self = CachedRenderPipelineId(usize::MAX);
56
57    #[inline]
58    pub fn id(&self) -> usize {
59        self.0
60    }
61}
62
63/// Index of a cached compute pipeline in a [`PipelineCache`].
64#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
65pub struct CachedComputePipelineId(CachedPipelineId);
66
67impl CachedComputePipelineId {
68    /// An invalid cached compute pipeline index, often used to initialize a variable.
69    pub const INVALID: Self = CachedComputePipelineId(usize::MAX);
70
71    #[inline]
72    pub fn id(&self) -> usize {
73        self.0
74    }
75}
76
77pub struct CachedPipeline {
78    pub descriptor: PipelineDescriptor,
79    pub state: CachedPipelineState,
80}
81
/// State of a cached pipeline inserted into a [`PipelineCache`].
///
/// When a shader used by a pipeline is changed or removed, the pipeline is reset to
/// [`CachedPipelineState::Queued`] and put back on the cache's waiting list.
#[derive(Debug)]
pub enum CachedPipelineState {
    /// The pipeline GPU object is queued for creation.
    ///
    /// This is also the state reported for IDs that are still in the cache's queue of newly
    /// inserted pipelines.
    Queued,
    /// The pipeline GPU object is being created.
    ///
    /// Holds the task that yields either the created pipeline or the creation error.
    Creating(Task<Result<Pipeline, PipelineCacheError>>),
    /// The pipeline GPU object was created successfully and is available (allocated on the GPU).
    Ok(Pipeline),
    /// An error occurred while trying to create the pipeline GPU object.
    Err(PipelineCacheError),
}
94
95impl CachedPipelineState {
96    /// Convenience method to "unwrap" a pipeline state into its underlying GPU object.
97    ///
98    /// # Returns
99    ///
100    /// The method returns the allocated pipeline GPU object.
101    ///
102    /// # Panics
103    ///
104    /// This method panics if the pipeline GPU object is not available, either because it is
105    /// pending creation or because an error occurred while attempting to create GPU object.
106    pub fn unwrap(&self) -> &Pipeline {
107        match self {
108            CachedPipelineState::Ok(pipeline) => pipeline,
109            CachedPipelineState::Queued => {
110                panic!("Pipeline has not been compiled yet. It is still in the 'Queued' state.")
111            }
112            CachedPipelineState::Creating(..) => {
113                panic!("Pipeline has not been compiled yet. It is still in the 'Creating' state.")
114            }
115            CachedPipelineState::Err(err) => panic!("{}", err),
116        }
117    }
118}
119
/// Per-shader bookkeeping kept by [`ShaderCache`], keyed by the shader's asset id.
#[derive(Default)]
struct ShaderData {
    /// Pipelines that requested a module compiled from this shader; they are returned for
    /// re-queuing when this shader (or a shader it depends on) is cleared.
    pipelines: HashSet<CachedPipelineId>,
    /// Compiled modules, keyed by the caller-supplied shader defs they were compiled with.
    processed_shaders: HashMap<Box<[ShaderDefVal]>, Arc<WgpuWrapper<ShaderModule>>>,
    /// Imports of this shader that have been matched to a registered shader asset.
    resolved_imports: HashMap<ShaderImport, AssetId<Shader>>,
    /// Shaders that import this shader and must be cleared together with it.
    dependents: HashSet<AssetId<Shader>>,
}
127
/// Caches compiled shader modules and tracks the import graph between shader assets.
struct ShaderCache {
    /// Per-shader bookkeeping: compiled modules, pipelines using them, import graph edges.
    data: HashMap<AssetId<Shader>, ShaderData>,
    /// WESL module path -> asset providing it; consulted when resolving WESL imports.
    #[cfg(feature = "shader_format_wesl")]
    asset_paths: HashMap<wesl::syntax::ModulePath, AssetId<Shader>>,
    /// All shader assets currently registered with the cache.
    shaders: HashMap<AssetId<Shader>, Shader>,
    /// Each registered shader's own import path -> its asset id.
    import_path_shaders: HashMap<ShaderImport, AssetId<Shader>>,
    /// Shaders whose import is not registered yet, keyed by that missing import.
    waiting_on_import: HashMap<ShaderImport, Vec<AssetId<Shader>>>,
    /// `naga_oil` composer used for shader sources that are neither SPIR-V nor WESL.
    composer: naga_oil::compose::Composer,
}
137
/// A named shader definition with a typed value, forwarded to the shader compiler.
#[derive(Clone, PartialEq, Eq, Debug, Hash)]
pub enum ShaderDefVal {
    /// A boolean definition.
    Bool(String, bool),
    /// A signed integer definition.
    Int(String, i32),
    /// An unsigned integer definition.
    UInt(String, u32),
}

/// A bare name becomes a boolean definition set to `true`.
impl From<&str> for ShaderDefVal {
    fn from(key: &str) -> Self {
        Self::Bool(String::from(key), true)
    }
}

/// An owned name becomes a boolean definition set to `true`.
impl From<String> for ShaderDefVal {
    fn from(key: String) -> Self {
        Self::Bool(key, true)
    }
}

impl ShaderDefVal {
    /// Returns the definition's value formatted with `Display`: `"true"`/`"false"` for
    /// booleans, the decimal number for integers.
    pub fn value_as_string(&self) -> String {
        let value: &dyn core::fmt::Display = match self {
            Self::Bool(_, flag) => flag,
            Self::Int(_, signed) => signed,
            Self::UInt(_, unsigned) => unsigned,
        };
        value.to_string()
    }
}
166
167impl ShaderCache {
168    fn new(render_device: &RenderDevice, render_adapter: &RenderAdapter) -> Self {
169        let capabilities = get_capabilities(
170            render_device.features(),
171            render_adapter.get_downlevel_capabilities().flags,
172        );
173
174        #[cfg(debug_assertions)]
175        let composer = naga_oil::compose::Composer::default();
176        #[cfg(not(debug_assertions))]
177        let composer = naga_oil::compose::Composer::non_validating();
178
179        let composer = composer.with_capabilities(capabilities);
180
181        Self {
182            composer,
183            data: Default::default(),
184            #[cfg(feature = "shader_format_wesl")]
185            asset_paths: Default::default(),
186            shaders: Default::default(),
187            import_path_shaders: Default::default(),
188            waiting_on_import: Default::default(),
189        }
190    }
191
192    fn add_import_to_composer(
193        composer: &mut naga_oil::compose::Composer,
194        import_path_shaders: &HashMap<ShaderImport, AssetId<Shader>>,
195        shaders: &HashMap<AssetId<Shader>, Shader>,
196        import: &ShaderImport,
197    ) -> Result<(), PipelineCacheError> {
198        if !composer.contains_module(&import.module_name()) {
199            if let Some(shader_handle) = import_path_shaders.get(import) {
200                if let Some(shader) = shaders.get(shader_handle) {
201                    for import in &shader.imports {
202                        Self::add_import_to_composer(
203                            composer,
204                            import_path_shaders,
205                            shaders,
206                            import,
207                        )?;
208                    }
209
210                    composer.add_composable_module(shader.into())?;
211                } else {
212                    Err(PipelineCacheError::ShaderImportNotYetAvailable)?;
213                }
214            } else {
215                Err(PipelineCacheError::ShaderImportNotYetAvailable)?;
216            }
217            // if we fail to add a module the composer will tell us what is missing
218        }
219
220        Ok(())
221    }
222
223    fn get(
224        &mut self,
225        render_device: &RenderDevice,
226        pipeline: CachedPipelineId,
227        id: AssetId<Shader>,
228        shader_defs: &[ShaderDefVal],
229    ) -> Result<Arc<WgpuWrapper<ShaderModule>>, PipelineCacheError> {
230        let shader = self
231            .shaders
232            .get(&id)
233            .ok_or(PipelineCacheError::ShaderNotLoaded(id))?;
234
235        let data = self.data.entry(id).or_default();
236        let n_asset_imports = shader
237            .imports()
238            .filter(|import| matches!(import, ShaderImport::AssetPath(_)))
239            .count();
240        let n_resolved_asset_imports = data
241            .resolved_imports
242            .keys()
243            .filter(|import| matches!(import, ShaderImport::AssetPath(_)))
244            .count();
245        if n_asset_imports != n_resolved_asset_imports {
246            return Err(PipelineCacheError::ShaderImportNotYetAvailable);
247        }
248
249        data.pipelines.insert(pipeline);
250
251        // PERF: this shader_defs clone isn't great. use raw_entry_mut when it stabilizes
252        let module = match data.processed_shaders.entry_ref(shader_defs) {
253            EntryRef::Occupied(entry) => entry.into_mut(),
254            EntryRef::Vacant(entry) => {
255                let mut shader_defs = shader_defs.to_vec();
256                #[cfg(all(feature = "webgl", target_arch = "wasm32", not(feature = "webgpu")))]
257                {
258                    shader_defs.push("NO_ARRAY_TEXTURES_SUPPORT".into());
259                    shader_defs.push("NO_CUBE_ARRAY_TEXTURES_SUPPORT".into());
260                    shader_defs.push("SIXTEEN_BYTE_ALIGNMENT".into());
261                }
262
263                if cfg!(target_abi = "sim") {
264                    shader_defs.push("NO_CUBE_ARRAY_TEXTURES_SUPPORT".into());
265                }
266
267                shader_defs.push(ShaderDefVal::UInt(
268                    String::from("AVAILABLE_STORAGE_BUFFER_BINDINGS"),
269                    render_device.limits().max_storage_buffers_per_shader_stage,
270                ));
271
272                debug!(
273                    "processing shader {}, with shader defs {:?}",
274                    id, shader_defs
275                );
276                let shader_source = match &shader.source {
277                    #[cfg(feature = "shader_format_spirv")]
278                    Source::SpirV(data) => make_spirv(data),
279                    #[cfg(feature = "shader_format_wesl")]
280                    Source::Wesl(_) => {
281                        if let ShaderImport::AssetPath(path) = shader.import_path() {
282                            let shader_resolver =
283                                ShaderResolver::new(&self.asset_paths, &self.shaders);
284                            let module_path = wesl::syntax::ModulePath::from_path(path);
285                            let mut compiler_options = wesl::CompileOptions {
286                                imports: true,
287                                condcomp: true,
288                                lower: true,
289                                ..default()
290                            };
291
292                            for shader_def in shader_defs {
293                                match shader_def {
294                                    ShaderDefVal::Bool(key, value) => {
295                                        compiler_options.features.insert(key.clone(), value);
296                                    }
297                                    _ => debug!(
298                                        "ShaderDefVal::Int and ShaderDefVal::UInt are not supported in wesl",
299                                    ),
300                                }
301                            }
302
303                            let compiled = wesl::compile(
304                                &module_path,
305                                &shader_resolver,
306                                &wesl::EscapeMangler,
307                                &compiler_options,
308                            )
309                            .unwrap();
310
311                            let naga = naga::front::wgsl::parse_str(&compiled.to_string()).unwrap();
312                            ShaderSource::Naga(Cow::Owned(naga))
313                        } else {
314                            panic!("Wesl shaders must be imported from a file");
315                        }
316                    }
317                    #[cfg(not(feature = "shader_format_spirv"))]
318                    Source::SpirV(_) => {
319                        unimplemented!(
320                            "Enable feature \"shader_format_spirv\" to use SPIR-V shaders"
321                        )
322                    }
323                    _ => {
324                        for import in shader.imports() {
325                            Self::add_import_to_composer(
326                                &mut self.composer,
327                                &self.import_path_shaders,
328                                &self.shaders,
329                                import,
330                            )?;
331                        }
332
333                        let shader_defs = shader_defs
334                            .into_iter()
335                            .chain(shader.shader_defs.iter().cloned())
336                            .map(|def| match def {
337                                ShaderDefVal::Bool(k, v) => {
338                                    (k, naga_oil::compose::ShaderDefValue::Bool(v))
339                                }
340                                ShaderDefVal::Int(k, v) => {
341                                    (k, naga_oil::compose::ShaderDefValue::Int(v))
342                                }
343                                ShaderDefVal::UInt(k, v) => {
344                                    (k, naga_oil::compose::ShaderDefValue::UInt(v))
345                                }
346                            })
347                            .collect::<std::collections::HashMap<_, _>>();
348
349                        let naga = self.composer.make_naga_module(
350                            naga_oil::compose::NagaModuleDescriptor {
351                                shader_defs,
352                                ..shader.into()
353                            },
354                        )?;
355
356                        #[cfg(not(feature = "decoupled_naga"))]
357                        {
358                            ShaderSource::Naga(Cow::Owned(naga))
359                        }
360
361                        #[cfg(feature = "decoupled_naga")]
362                        {
363                            let mut validator = naga::valid::Validator::new(
364                                naga::valid::ValidationFlags::all(),
365                                self.composer.capabilities,
366                            );
367                            let module_info = validator.validate(&naga).unwrap();
368                            let wgsl = Cow::Owned(
369                                naga::back::wgsl::write_string(
370                                    &naga,
371                                    &module_info,
372                                    naga::back::wgsl::WriterFlags::empty(),
373                                )
374                                .unwrap(),
375                            );
376                            ShaderSource::Wgsl(wgsl)
377                        }
378                    }
379                };
380
381                let module_descriptor = ShaderModuleDescriptor {
382                    label: None,
383                    source: shader_source,
384                };
385
386                render_device
387                    .wgpu_device()
388                    .push_error_scope(wgpu::ErrorFilter::Validation);
389
390                let shader_module = match shader.validate_shader {
391                    ValidateShader::Enabled => {
392                        render_device.create_and_validate_shader_module(module_descriptor)
393                    }
394                    // SAFETY: we are interfacing with shader code, which may contain undefined behavior,
395                    // such as indexing out of bounds.
396                    // The checks required are prohibitively expensive and a poor default for game engines.
397                    ValidateShader::Disabled => unsafe {
398                        render_device.create_shader_module(module_descriptor)
399                    },
400                };
401
402                let error = render_device.wgpu_device().pop_error_scope();
403
404                // `now_or_never` will return Some if the future is ready and None otherwise.
405                // On native platforms, wgpu will yield the error immediately while on wasm it may take longer since the browser APIs are asynchronous.
406                // So to keep the complexity of the ShaderCache low, we will only catch this error early on native platforms,
407                // and on wasm the error will be handled by wgpu and crash the application.
408                if let Some(Some(wgpu::Error::Validation { description, .. })) =
409                    bevy_tasks::futures::now_or_never(error)
410                {
411                    return Err(PipelineCacheError::CreateShaderModule(description));
412                }
413
414                entry.insert(Arc::new(WgpuWrapper::new(shader_module)))
415            }
416        };
417
418        Ok(module.clone())
419    }
420
421    fn clear(&mut self, id: AssetId<Shader>) -> Vec<CachedPipelineId> {
422        let mut shaders_to_clear = vec![id];
423        let mut pipelines_to_queue = Vec::new();
424        while let Some(handle) = shaders_to_clear.pop() {
425            if let Some(data) = self.data.get_mut(&handle) {
426                data.processed_shaders.clear();
427                pipelines_to_queue.extend(data.pipelines.iter().copied());
428                shaders_to_clear.extend(data.dependents.iter().copied());
429
430                if let Some(Shader { import_path, .. }) = self.shaders.get(&handle) {
431                    self.composer
432                        .remove_composable_module(&import_path.module_name());
433                }
434            }
435        }
436
437        pipelines_to_queue
438    }
439
440    fn set_shader(&mut self, id: AssetId<Shader>, shader: Shader) -> Vec<CachedPipelineId> {
441        let pipelines_to_queue = self.clear(id);
442        let path = shader.import_path();
443        self.import_path_shaders.insert(path.clone(), id);
444        if let Some(waiting_shaders) = self.waiting_on_import.get_mut(path) {
445            for waiting_shader in waiting_shaders.drain(..) {
446                // resolve waiting shader import
447                let data = self.data.entry(waiting_shader).or_default();
448                data.resolved_imports.insert(path.clone(), id);
449                // add waiting shader as dependent of this shader
450                let data = self.data.entry(id).or_default();
451                data.dependents.insert(waiting_shader);
452            }
453        }
454
455        for import in shader.imports() {
456            if let Some(import_id) = self.import_path_shaders.get(import).copied() {
457                // resolve import because it is currently available
458                let data = self.data.entry(id).or_default();
459                data.resolved_imports.insert(import.clone(), import_id);
460                // add this shader as a dependent of the import
461                let data = self.data.entry(import_id).or_default();
462                data.dependents.insert(id);
463            } else {
464                let waiting = self.waiting_on_import.entry(import.clone()).or_default();
465                waiting.push(id);
466            }
467        }
468
469        #[cfg(feature = "shader_format_wesl")]
470        if let Source::Wesl(_) = shader.source {
471            if let ShaderImport::AssetPath(path) = shader.import_path() {
472                self.asset_paths
473                    .insert(wesl::syntax::ModulePath::from_path(path), id);
474            }
475        }
476        self.shaders.insert(id, shader);
477        pipelines_to_queue
478    }
479
480    fn remove(&mut self, id: AssetId<Shader>) -> Vec<CachedPipelineId> {
481        let pipelines_to_queue = self.clear(id);
482        if let Some(shader) = self.shaders.remove(&id) {
483            self.import_path_shaders.remove(shader.import_path());
484        }
485
486        pipelines_to_queue
487    }
488}
489
/// Resolves WESL module paths to shader sources.
///
/// Borrows a module-path -> asset-id map together with the map of loaded shaders.
#[cfg(feature = "shader_format_wesl")]
pub struct ShaderResolver<'a> {
    asset_paths: &'a HashMap<wesl::syntax::ModulePath, AssetId<Shader>>,
    shaders: &'a HashMap<AssetId<Shader>, Shader>,
}

#[cfg(feature = "shader_format_wesl")]
impl<'a> ShaderResolver<'a> {
    /// Creates a resolver over the given module-path map and shader map.
    pub fn new(
        asset_paths: &'a HashMap<wesl::syntax::ModulePath, AssetId<Shader>>,
        shaders: &'a HashMap<AssetId<Shader>, Shader>,
    ) -> Self {
        Self { asset_paths, shaders }
    }
}
508
#[cfg(feature = "shader_format_wesl")]
impl<'a> wesl::Resolver for ShaderResolver<'a> {
    /// Returns the source of the WESL module registered for `module_path`.
    ///
    /// # Errors
    ///
    /// Returns `wesl::ResolveError::ModuleNotFound` if no asset is registered for
    /// `module_path`, or if the registered asset is not present in the shader map.
    fn resolve_source(
        &self,
        module_path: &wesl::syntax::ModulePath,
    ) -> Result<Cow<str>, wesl::ResolveError> {
        let asset_id = self.asset_paths.get(module_path).ok_or_else(|| {
            wesl::ResolveError::ModuleNotFound(module_path.clone(), "Invalid asset id".to_string())
        })?;

        // `asset_paths` and `shaders` are separate maps, so a module path can still point at
        // an asset that is no longer loaded; report that instead of panicking.
        let shader = self.shaders.get(asset_id).ok_or_else(|| {
            wesl::ResolveError::ModuleNotFound(
                module_path.clone(),
                "Shader asset is not loaded".to_string(),
            )
        })?;
        Ok(Cow::Borrowed(shader.source.as_str()))
    }
}
523
/// Key identifying a pipeline layout: its bind group layout ids (in order) and its push
/// constant ranges.
type LayoutCacheKey = (Vec<BindGroupLayoutId>, Vec<PushConstantRange>);
/// Deduplicates pipeline layouts, so requests with the same key share one GPU object.
#[derive(Default)]
struct LayoutCache {
    /// Created pipeline layouts, keyed by [`LayoutCacheKey`].
    layouts: HashMap<LayoutCacheKey, Arc<WgpuWrapper<PipelineLayout>>>,
}
529
530impl LayoutCache {
531    fn get(
532        &mut self,
533        render_device: &RenderDevice,
534        bind_group_layouts: &[BindGroupLayout],
535        push_constant_ranges: Vec<PushConstantRange>,
536    ) -> Arc<WgpuWrapper<PipelineLayout>> {
537        let bind_group_ids = bind_group_layouts.iter().map(BindGroupLayout::id).collect();
538        self.layouts
539            .entry((bind_group_ids, push_constant_ranges))
540            .or_insert_with_key(|(_, push_constant_ranges)| {
541                let bind_group_layouts = bind_group_layouts
542                    .iter()
543                    .map(BindGroupLayout::value)
544                    .collect::<Vec<_>>();
545                Arc::new(WgpuWrapper::new(render_device.create_pipeline_layout(
546                    &PipelineLayoutDescriptor {
547                        bind_group_layouts: &bind_group_layouts,
548                        push_constant_ranges,
549                        ..default()
550                    },
551                )))
552            })
553            .clone()
554    }
555}
556
/// Cache for render and compute pipelines.
///
/// The cache stores existing render and compute pipelines allocated on the GPU, as well as
/// pending creation. Pipelines inserted into the cache are identified by a unique ID, which
/// can be used to retrieve the actual GPU object once it's ready. The creation of the GPU
/// pipeline object is deferred to the [`RenderSet::Render`] step, just before the render
/// graph starts being processed, as this requires access to the GPU.
///
/// Note that the cache does not perform automatic deduplication of identical pipelines. It is
/// up to the user not to insert the same pipeline twice to avoid wasting GPU resources.
///
/// [`RenderSet::Render`]: crate::RenderSet::Render
#[derive(Resource)]
pub struct PipelineCache {
    /// Shared pipeline-layout cache; cloned into pipeline-creation tasks.
    layout_cache: Arc<Mutex<LayoutCache>>,
    /// Shared shader cache; cloned into pipeline-creation tasks, which lock it.
    shader_cache: Arc<Mutex<ShaderCache>>,
    device: RenderDevice,
    /// Pipelines indexed by their id. Ids past the end are still in `new_pipelines`.
    pipelines: Vec<CachedPipeline>,
    /// Ids of pipelines awaiting (re)processing, e.g. after one of their shaders changed.
    waiting_pipelines: HashSet<CachedPipelineId>,
    /// Pipelines queued through `&self`; their ids continue after `pipelines.len()`.
    new_pipelines: Mutex<Vec<CachedPipeline>>,
    /// If `true`, disables asynchronous pipeline compilation.
    /// This has no effect on macOS, wasm, or without the `multi_threaded` feature.
    synchronous_pipeline_compilation: bool,
}
581
582impl PipelineCache {
583    /// Returns an iterator over the pipelines in the pipeline cache.
584    pub fn pipelines(&self) -> impl Iterator<Item = &CachedPipeline> {
585        self.pipelines.iter()
586    }
587
588    /// Returns a iterator of the IDs of all currently waiting pipelines.
589    pub fn waiting_pipelines(&self) -> impl Iterator<Item = CachedPipelineId> + '_ {
590        self.waiting_pipelines.iter().copied()
591    }
592
593    /// Create a new pipeline cache associated with the given render device.
594    pub fn new(
595        device: RenderDevice,
596        render_adapter: RenderAdapter,
597        synchronous_pipeline_compilation: bool,
598    ) -> Self {
599        Self {
600            shader_cache: Arc::new(Mutex::new(ShaderCache::new(&device, &render_adapter))),
601            device,
602            layout_cache: default(),
603            waiting_pipelines: default(),
604            new_pipelines: default(),
605            pipelines: default(),
606            synchronous_pipeline_compilation,
607        }
608    }
609
610    /// Get the state of a cached render pipeline.
611    ///
612    /// See [`PipelineCache::queue_render_pipeline()`].
613    #[inline]
614    pub fn get_render_pipeline_state(&self, id: CachedRenderPipelineId) -> &CachedPipelineState {
615        // If the pipeline id isn't in `pipelines`, it's queued in `new_pipelines`
616        self.pipelines
617            .get(id.0)
618            .map_or(&CachedPipelineState::Queued, |pipeline| &pipeline.state)
619    }
620
621    /// Get the state of a cached compute pipeline.
622    ///
623    /// See [`PipelineCache::queue_compute_pipeline()`].
624    #[inline]
625    pub fn get_compute_pipeline_state(&self, id: CachedComputePipelineId) -> &CachedPipelineState {
626        // If the pipeline id isn't in `pipelines`, it's queued in `new_pipelines`
627        self.pipelines
628            .get(id.0)
629            .map_or(&CachedPipelineState::Queued, |pipeline| &pipeline.state)
630    }
631
632    /// Get the render pipeline descriptor a cached render pipeline was inserted from.
633    ///
634    /// See [`PipelineCache::queue_render_pipeline()`].
635    ///
636    /// **Note**: Be careful calling this method. It will panic if called with a pipeline that
637    /// has been queued but has not yet been processed by [`PipelineCache::process_queue()`].
638    #[inline]
639    pub fn get_render_pipeline_descriptor(
640        &self,
641        id: CachedRenderPipelineId,
642    ) -> &RenderPipelineDescriptor {
643        match &self.pipelines[id.0].descriptor {
644            PipelineDescriptor::RenderPipelineDescriptor(descriptor) => descriptor,
645            PipelineDescriptor::ComputePipelineDescriptor(_) => unreachable!(),
646        }
647    }
648
649    /// Get the compute pipeline descriptor a cached render pipeline was inserted from.
650    ///
651    /// See [`PipelineCache::queue_compute_pipeline()`].
652    ///
653    /// **Note**: Be careful calling this method. It will panic if called with a pipeline that
654    /// has been queued but has not yet been processed by [`PipelineCache::process_queue()`].
655    #[inline]
656    pub fn get_compute_pipeline_descriptor(
657        &self,
658        id: CachedComputePipelineId,
659    ) -> &ComputePipelineDescriptor {
660        match &self.pipelines[id.0].descriptor {
661            PipelineDescriptor::RenderPipelineDescriptor(_) => unreachable!(),
662            PipelineDescriptor::ComputePipelineDescriptor(descriptor) => descriptor,
663        }
664    }
665
666    /// Try to retrieve a render pipeline GPU object from a cached ID.
667    ///
668    /// # Returns
669    ///
670    /// This method returns a successfully created render pipeline if any, or `None` if the pipeline
671    /// was not created yet or if there was an error during creation. You can check the actual creation
672    /// state with [`PipelineCache::get_render_pipeline_state()`].
673    #[inline]
674    pub fn get_render_pipeline(&self, id: CachedRenderPipelineId) -> Option<&RenderPipeline> {
675        if let CachedPipelineState::Ok(Pipeline::RenderPipeline(pipeline)) =
676            &self.pipelines.get(id.0)?.state
677        {
678            Some(pipeline)
679        } else {
680            None
681        }
682    }
683
684    /// Wait for a render pipeline to finish compiling.
685    #[inline]
686    pub fn block_on_render_pipeline(&mut self, id: CachedRenderPipelineId) {
687        if self.pipelines.len() <= id.0 {
688            self.process_queue();
689        }
690
691        let state = &mut self.pipelines[id.0].state;
692        if let CachedPipelineState::Creating(task) = state {
693            *state = match bevy_tasks::block_on(task) {
694                Ok(p) => CachedPipelineState::Ok(p),
695                Err(e) => CachedPipelineState::Err(e),
696            };
697        }
698    }
699
700    /// Try to retrieve a compute pipeline GPU object from a cached ID.
701    ///
702    /// # Returns
703    ///
704    /// This method returns a successfully created compute pipeline if any, or `None` if the pipeline
705    /// was not created yet or if there was an error during creation. You can check the actual creation
706    /// state with [`PipelineCache::get_compute_pipeline_state()`].
707    #[inline]
708    pub fn get_compute_pipeline(&self, id: CachedComputePipelineId) -> Option<&ComputePipeline> {
709        if let CachedPipelineState::Ok(Pipeline::ComputePipeline(pipeline)) =
710            &self.pipelines.get(id.0)?.state
711        {
712            Some(pipeline)
713        } else {
714            None
715        }
716    }
717
718    /// Insert a render pipeline into the cache, and queue its creation.
719    ///
720    /// The pipeline is always inserted and queued for creation. There is no attempt to deduplicate it with
721    /// an already cached pipeline.
722    ///
723    /// # Returns
724    ///
725    /// This method returns the unique render shader ID of the cached pipeline, which can be used to query
726    /// the caching state with [`get_render_pipeline_state()`] and to retrieve the created GPU pipeline once
727    /// it's ready with [`get_render_pipeline()`].
728    ///
729    /// [`get_render_pipeline_state()`]: PipelineCache::get_render_pipeline_state
730    /// [`get_render_pipeline()`]: PipelineCache::get_render_pipeline
731    pub fn queue_render_pipeline(
732        &self,
733        descriptor: RenderPipelineDescriptor,
734    ) -> CachedRenderPipelineId {
735        let mut new_pipelines = self
736            .new_pipelines
737            .lock()
738            .unwrap_or_else(PoisonError::into_inner);
739        let id = CachedRenderPipelineId(self.pipelines.len() + new_pipelines.len());
740        new_pipelines.push(CachedPipeline {
741            descriptor: PipelineDescriptor::RenderPipelineDescriptor(Box::new(descriptor)),
742            state: CachedPipelineState::Queued,
743        });
744        id
745    }
746
747    /// Insert a compute pipeline into the cache, and queue its creation.
748    ///
749    /// The pipeline is always inserted and queued for creation. There is no attempt to deduplicate it with
750    /// an already cached pipeline.
751    ///
752    /// # Returns
753    ///
754    /// This method returns the unique compute shader ID of the cached pipeline, which can be used to query
755    /// the caching state with [`get_compute_pipeline_state()`] and to retrieve the created GPU pipeline once
756    /// it's ready with [`get_compute_pipeline()`].
757    ///
758    /// [`get_compute_pipeline_state()`]: PipelineCache::get_compute_pipeline_state
759    /// [`get_compute_pipeline()`]: PipelineCache::get_compute_pipeline
760    pub fn queue_compute_pipeline(
761        &self,
762        descriptor: ComputePipelineDescriptor,
763    ) -> CachedComputePipelineId {
764        let mut new_pipelines = self
765            .new_pipelines
766            .lock()
767            .unwrap_or_else(PoisonError::into_inner);
768        let id = CachedComputePipelineId(self.pipelines.len() + new_pipelines.len());
769        new_pipelines.push(CachedPipeline {
770            descriptor: PipelineDescriptor::ComputePipelineDescriptor(Box::new(descriptor)),
771            state: CachedPipelineState::Queued,
772        });
773        id
774    }
775
776    fn set_shader(&mut self, id: AssetId<Shader>, shader: &Shader) {
777        let mut shader_cache = self.shader_cache.lock().unwrap();
778        let pipelines_to_queue = shader_cache.set_shader(id, shader.clone());
779        for cached_pipeline in pipelines_to_queue {
780            self.pipelines[cached_pipeline].state = CachedPipelineState::Queued;
781            self.waiting_pipelines.insert(cached_pipeline);
782        }
783    }
784
785    fn remove_shader(&mut self, shader: AssetId<Shader>) {
786        let mut shader_cache = self.shader_cache.lock().unwrap();
787        let pipelines_to_queue = shader_cache.remove(shader);
788        for cached_pipeline in pipelines_to_queue {
789            self.pipelines[cached_pipeline].state = CachedPipelineState::Queued;
790            self.waiting_pipelines.insert(cached_pipeline);
791        }
792    }
793
    /// Builds the future that creates the GPU render pipeline described by `descriptor`.
    ///
    /// The future is passed to `create_pipeline_task` together with
    /// `self.synchronous_pipeline_compilation`. That function decides whether the future runs to
    /// completion immediately or is spawned as a task.
    ///
    /// `id` is the pipeline's index in `self.pipelines`. It is forwarded to the shader cache with
    /// every shader lookup. Presumably this lets the cache record which pipelines use which shader,
    /// since `set_shader`/`remove_shader` re-queue the ids the cache returns.
    ///
    /// Returns the initial [`CachedPipelineState`] for the pipeline. A failed shader lookup makes
    /// the future resolve to `Err`, which ends up as an error state.
    fn start_create_render_pipeline(
        &mut self,
        id: CachedPipelineId,
        descriptor: RenderPipelineDescriptor,
    ) -> CachedPipelineState {
        // Clone the shared handles so the `async move` block below can own them.
        let device = self.device.clone();
        let shader_cache = self.shader_cache.clone();
        let layout_cache = self.layout_cache.clone();

        create_pipeline_task(
            async move {
                // Lock order: shader cache first, then layout cache (the compute path uses the
                // same order).
                let mut shader_cache = shader_cache.lock().unwrap();
                let mut layout_cache = layout_cache.lock().unwrap();

                // Look up the vertex shader module; any error aborts pipeline creation.
                let vertex_module = match shader_cache.get(
                    &device,
                    id,
                    descriptor.vertex.shader.id(),
                    &descriptor.vertex.shader_defs,
                ) {
                    Ok(module) => module,
                    Err(err) => return Err(err),
                };

                // A fragment module is only looked up when the pipeline has a fragment stage.
                let fragment_module = match &descriptor.fragment {
                    Some(fragment) => {
                        match shader_cache.get(
                            &device,
                            id,
                            fragment.shader.id(),
                            &fragment.shader_defs,
                        ) {
                            Ok(module) => Some(module),
                            Err(err) => return Err(err),
                        }
                    }
                    None => None,
                };

                // With no bind group layouts and no push constant ranges, no explicit pipeline
                // layout is passed to wgpu (`layout: None` below).
                let layout =
                    if descriptor.layout.is_empty() && descriptor.push_constant_ranges.is_empty() {
                        None
                    } else {
                        Some(layout_cache.get(
                            &device,
                            &descriptor.layout,
                            descriptor.push_constant_ranges.to_vec(),
                        ))
                    };

                // Release both cache locks before building the wgpu descriptor and creating the
                // pipeline (presumably so other pipeline tasks are not blocked while the driver
                // compiles this one).
                drop((shader_cache, layout_cache));

                // Convert to wgpu's vertex buffer layout type; `attributes` borrows from
                // `descriptor`.
                let vertex_buffer_layouts = descriptor
                    .vertex
                    .buffers
                    .iter()
                    .map(|layout| RawVertexBufferLayout {
                        array_stride: layout.array_stride,
                        attributes: &layout.attributes,
                        step_mode: layout.step_mode,
                    })
                    .collect::<Vec<_>>();

                // `fragment_module` is `Some` exactly when `descriptor.fragment` is `Some` (see
                // above), so this `unwrap` cannot fail.
                let fragment_data = descriptor.fragment.as_ref().map(|fragment| {
                    (
                        fragment_module.unwrap(),
                        fragment.entry_point.deref(),
                        fragment.targets.as_slice(),
                    )
                });

                // TODO: Expose the rest of this somehow
                let compilation_options = PipelineCompilationOptions {
                    constants: &default(),
                    zero_initialize_workgroup_memory: descriptor.zero_initialize_workgroup_memory,
                };

                // Shadows `descriptor` with the wgpu descriptor. Its fields borrow from the Bevy
                // descriptor, which is shadowed but still alive.
                let descriptor = RawRenderPipelineDescriptor {
                    multiview: None,
                    depth_stencil: descriptor.depth_stencil.clone(),
                    label: descriptor.label.as_deref(),
                    layout: layout.as_ref().map(|layout| -> &PipelineLayout { layout }),
                    multisample: descriptor.multisample,
                    primitive: descriptor.primitive,
                    vertex: RawVertexState {
                        buffers: &vertex_buffer_layouts,
                        entry_point: Some(descriptor.vertex.entry_point.deref()),
                        module: &vertex_module,
                        // TODO: Should this be the same as the fragment compilation options?
                        compilation_options: compilation_options.clone(),
                    },
                    fragment: fragment_data
                        .as_ref()
                        .map(|(module, entry_point, targets)| RawFragmentState {
                            entry_point: Some(entry_point),
                            module,
                            targets,
                            // TODO: Should this be the same as the vertex compilation options?
                            compilation_options,
                        }),
                    cache: None,
                };

                Ok(Pipeline::RenderPipeline(
                    device.create_render_pipeline(&descriptor),
                ))
            },
            self.synchronous_pipeline_compilation,
        )
    }
904
905    fn start_create_compute_pipeline(
906        &mut self,
907        id: CachedPipelineId,
908        descriptor: ComputePipelineDescriptor,
909    ) -> CachedPipelineState {
910        let device = self.device.clone();
911        let shader_cache = self.shader_cache.clone();
912        let layout_cache = self.layout_cache.clone();
913
914        create_pipeline_task(
915            async move {
916                let mut shader_cache = shader_cache.lock().unwrap();
917                let mut layout_cache = layout_cache.lock().unwrap();
918
919                let compute_module = match shader_cache.get(
920                    &device,
921                    id,
922                    descriptor.shader.id(),
923                    &descriptor.shader_defs,
924                ) {
925                    Ok(module) => module,
926                    Err(err) => return Err(err),
927                };
928
929                let layout =
930                    if descriptor.layout.is_empty() && descriptor.push_constant_ranges.is_empty() {
931                        None
932                    } else {
933                        Some(layout_cache.get(
934                            &device,
935                            &descriptor.layout,
936                            descriptor.push_constant_ranges.to_vec(),
937                        ))
938                    };
939
940                drop((shader_cache, layout_cache));
941
942                let descriptor = RawComputePipelineDescriptor {
943                    label: descriptor.label.as_deref(),
944                    layout: layout.as_ref().map(|layout| -> &PipelineLayout { layout }),
945                    module: &compute_module,
946                    entry_point: Some(&descriptor.entry_point),
947                    // TODO: Expose the rest of this somehow
948                    compilation_options: PipelineCompilationOptions {
949                        constants: &default(),
950                        zero_initialize_workgroup_memory: descriptor
951                            .zero_initialize_workgroup_memory,
952                    },
953                    cache: None,
954                };
955
956                Ok(Pipeline::ComputePipeline(
957                    device.create_compute_pipeline(&descriptor),
958                ))
959            },
960            self.synchronous_pipeline_compilation,
961        )
962    }
963
964    /// Process the pipeline queue and create all pending pipelines if possible.
965    ///
966    /// This is generally called automatically during the [`RenderSet::Render`] step, but can
967    /// be called manually to force creation at a different time.
968    ///
969    /// [`RenderSet::Render`]: crate::RenderSet::Render
970    pub fn process_queue(&mut self) {
971        let mut waiting_pipelines = mem::take(&mut self.waiting_pipelines);
972        let mut pipelines = mem::take(&mut self.pipelines);
973
974        {
975            let mut new_pipelines = self
976                .new_pipelines
977                .lock()
978                .unwrap_or_else(PoisonError::into_inner);
979            for new_pipeline in new_pipelines.drain(..) {
980                let id = pipelines.len();
981                pipelines.push(new_pipeline);
982                waiting_pipelines.insert(id);
983            }
984        }
985
986        for id in waiting_pipelines {
987            self.process_pipeline(&mut pipelines[id], id);
988        }
989
990        self.pipelines = pipelines;
991    }
992
993    fn process_pipeline(&mut self, cached_pipeline: &mut CachedPipeline, id: usize) {
994        match &mut cached_pipeline.state {
995            CachedPipelineState::Queued => {
996                cached_pipeline.state = match &cached_pipeline.descriptor {
997                    PipelineDescriptor::RenderPipelineDescriptor(descriptor) => {
998                        self.start_create_render_pipeline(id, *descriptor.clone())
999                    }
1000                    PipelineDescriptor::ComputePipelineDescriptor(descriptor) => {
1001                        self.start_create_compute_pipeline(id, *descriptor.clone())
1002                    }
1003                };
1004            }
1005
1006            CachedPipelineState::Creating(task) => match bevy_tasks::futures::check_ready(task) {
1007                Some(Ok(pipeline)) => {
1008                    cached_pipeline.state = CachedPipelineState::Ok(pipeline);
1009                    return;
1010                }
1011                Some(Err(err)) => cached_pipeline.state = CachedPipelineState::Err(err),
1012                _ => (),
1013            },
1014
1015            CachedPipelineState::Err(err) => match err {
1016                // Retry
1017                PipelineCacheError::ShaderNotLoaded(_)
1018                | PipelineCacheError::ShaderImportNotYetAvailable => {
1019                    cached_pipeline.state = CachedPipelineState::Queued;
1020                }
1021
1022                // Shader could not be processed ... retrying won't help
1023                PipelineCacheError::ProcessShaderError(err) => {
1024                    let error_detail =
1025                        err.emit_to_string(&self.shader_cache.lock().unwrap().composer);
1026                    error!("failed to process shader:\n{}", error_detail);
1027                    return;
1028                }
1029                PipelineCacheError::CreateShaderModule(description) => {
1030                    error!("failed to create shader module: {}", description);
1031                    return;
1032                }
1033            },
1034
1035            CachedPipelineState::Ok(_) => return,
1036        }
1037
1038        // Retry
1039        self.waiting_pipelines.insert(id);
1040    }
1041
1042    pub(crate) fn process_pipeline_queue_system(mut cache: ResMut<Self>) {
1043        cache.process_queue();
1044    }
1045
1046    pub(crate) fn extract_shaders(
1047        mut cache: ResMut<Self>,
1048        shaders: Extract<Res<Assets<Shader>>>,
1049        mut events: Extract<EventReader<AssetEvent<Shader>>>,
1050    ) {
1051        for event in events.read() {
1052            #[expect(
1053                clippy::match_same_arms,
1054                reason = "LoadedWithDependencies is marked as a TODO, so it's likely this will no longer lint soon."
1055            )]
1056            match event {
1057                // PERF: Instead of blocking waiting for the shader cache lock, try again next frame if the lock is currently held
1058                AssetEvent::Added { id } | AssetEvent::Modified { id } => {
1059                    if let Some(shader) = shaders.get(*id) {
1060                        cache.set_shader(*id, shader);
1061                    }
1062                }
1063                AssetEvent::Removed { id } => cache.remove_shader(*id),
1064                AssetEvent::Unused { .. } => {}
1065                AssetEvent::LoadedWithDependencies { .. } => {
1066                    // TODO: handle this
1067                }
1068            }
1069        }
1070    }
1071}
1072
#[cfg(all(
    not(target_arch = "wasm32"),
    not(target_os = "macos"),
    feature = "multi_threaded"
))]
/// Runs or spawns the pipeline-creation future `task`.
///
/// When `sync` is true, the future is driven to completion on the calling thread and its result
/// becomes an `Ok`/`Err` state. Otherwise it is spawned on the [`AsyncComputeTaskPool`] and the
/// pipeline is reported as `Creating`.
///
/// [`AsyncComputeTaskPool`]: bevy_tasks::AsyncComputeTaskPool
fn create_pipeline_task(
    task: impl Future<Output = Result<Pipeline, PipelineCacheError>> + Send + 'static,
    sync: bool,
) -> CachedPipelineState {
    if sync {
        let outcome = futures_lite::future::block_on(task);
        match outcome {
            Err(err) => CachedPipelineState::Err(err),
            Ok(pipeline) => CachedPipelineState::Ok(pipeline),
        }
    } else {
        let pool = bevy_tasks::AsyncComputeTaskPool::get();
        CachedPipelineState::Creating(pool.spawn(task))
    }
}
1091
1092#[cfg(any(
1093    target_arch = "wasm32",
1094    target_os = "macos",
1095    not(feature = "multi_threaded")
1096))]
1097fn create_pipeline_task(
1098    task: impl Future<Output = Result<Pipeline, PipelineCacheError>> + Send + 'static,
1099    _sync: bool,
1100) -> CachedPipelineState {
1101    match futures_lite::future::block_on(task) {
1102        Ok(pipeline) => CachedPipelineState::Ok(pipeline),
1103        Err(err) => CachedPipelineState::Err(err),
1104    }
1105}
1106
/// Type of error returned by a [`PipelineCache`] when the creation of a GPU pipeline object failed.
#[derive(Error, Debug)]
pub enum PipelineCacheError {
    /// A shader needed by the pipeline could not be loaded (yet).
    ///
    /// `PipelineCache::process_pipeline` treats this as retryable and re-queues the pipeline.
    #[error(
        "Pipeline could not be compiled because the following shader could not be loaded: {0:?}"
    )]
    ShaderNotLoaded(AssetId<Shader>),
    /// `naga_oil` reported an error while processing the shader.
    ///
    /// It is logged and not retried.
    #[error(transparent)]
    ProcessShaderError(#[from] naga_oil::compose::ComposerError),
    /// A shader import required by the pipeline is not available yet.
    ///
    /// It is treated as retryable.
    #[error("Shader import not yet available.")]
    ShaderImportNotYetAvailable,
    /// Creating the shader module failed. Carries a textual description; it is logged and not
    /// retried.
    #[error("Could not create shader module: {0}")]
    CreateShaderModule(String),
}
1121
1122// TODO: This needs to be kept up to date with the capabilities in the `create_validator` function in wgpu-core
1123// https://github.com/gfx-rs/wgpu/blob/trunk/wgpu-core/src/device/mod.rs#L449
1124// We can't use the `wgpu-core` function to detect the device's capabilities because `wgpu-core` isn't included in WebGPU builds.
1125/// Get the device's capabilities for use in `naga_oil`.
1126fn get_capabilities(features: Features, downlevel: DownlevelFlags) -> Capabilities {
1127    let mut capabilities = Capabilities::empty();
1128    capabilities.set(
1129        Capabilities::PUSH_CONSTANT,
1130        features.contains(Features::PUSH_CONSTANTS),
1131    );
1132    capabilities.set(
1133        Capabilities::FLOAT64,
1134        features.contains(Features::SHADER_F64),
1135    );
1136    capabilities.set(
1137        Capabilities::PRIMITIVE_INDEX,
1138        features.contains(Features::SHADER_PRIMITIVE_INDEX),
1139    );
1140    capabilities.set(
1141        Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
1142        features.contains(Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
1143    );
1144    capabilities.set(
1145        Capabilities::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING,
1146        features.contains(Features::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING),
1147    );
1148    // TODO: This needs a proper wgpu feature
1149    capabilities.set(
1150        Capabilities::SAMPLER_NON_UNIFORM_INDEXING,
1151        features.contains(Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
1152    );
1153    capabilities.set(
1154        Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS,
1155        features.contains(Features::TEXTURE_FORMAT_16BIT_NORM),
1156    );
1157    capabilities.set(
1158        Capabilities::MULTIVIEW,
1159        features.contains(Features::MULTIVIEW),
1160    );
1161    capabilities.set(
1162        Capabilities::EARLY_DEPTH_TEST,
1163        features.contains(Features::SHADER_EARLY_DEPTH_TEST),
1164    );
1165    capabilities.set(
1166        Capabilities::SHADER_INT64,
1167        features.contains(Features::SHADER_INT64),
1168    );
1169    capabilities.set(
1170        Capabilities::SHADER_INT64_ATOMIC_MIN_MAX,
1171        features.intersects(
1172            Features::SHADER_INT64_ATOMIC_MIN_MAX | Features::SHADER_INT64_ATOMIC_ALL_OPS,
1173        ),
1174    );
1175    capabilities.set(
1176        Capabilities::SHADER_INT64_ATOMIC_ALL_OPS,
1177        features.contains(Features::SHADER_INT64_ATOMIC_ALL_OPS),
1178    );
1179    capabilities.set(
1180        Capabilities::MULTISAMPLED_SHADING,
1181        downlevel.contains(DownlevelFlags::MULTISAMPLED_SHADING),
1182    );
1183    capabilities.set(
1184        Capabilities::DUAL_SOURCE_BLENDING,
1185        features.contains(Features::DUAL_SOURCE_BLENDING),
1186    );
1187    capabilities.set(
1188        Capabilities::CUBE_ARRAY_TEXTURES,
1189        downlevel.contains(DownlevelFlags::CUBE_ARRAY_TEXTURES),
1190    );
1191    capabilities.set(
1192        Capabilities::SUBGROUP,
1193        features.intersects(Features::SUBGROUP | Features::SUBGROUP_VERTEX),
1194    );
1195    capabilities.set(
1196        Capabilities::SUBGROUP_BARRIER,
1197        features.intersects(Features::SUBGROUP_BARRIER),
1198    );
1199    capabilities.set(
1200        Capabilities::SUBGROUP_VERTEX_STAGE,
1201        features.contains(Features::SUBGROUP_VERTEX),
1202    );
1203    capabilities.set(
1204        Capabilities::SHADER_FLOAT32_ATOMIC,
1205        features.contains(Features::SHADER_FLOAT32_ATOMIC),
1206    );
1207    capabilities.set(
1208        Capabilities::TEXTURE_ATOMIC,
1209        features.contains(Features::TEXTURE_ATOMIC),
1210    );
1211    capabilities.set(
1212        Capabilities::TEXTURE_INT64_ATOMIC,
1213        features.contains(Features::TEXTURE_INT64_ATOMIC),
1214    );
1215
1216    capabilities
1217}