// bevy_render/render_resource/pipeline_cache.rs

1use crate::renderer::WgpuWrapper;
2use crate::{
3    render_resource::*,
4    renderer::{RenderAdapter, RenderDevice},
5    Extract,
6};
7use alloc::{borrow::Cow, sync::Arc};
8use bevy_asset::{AssetEvent, AssetId, Assets};
9use bevy_ecs::{
10    event::EventReader,
11    system::{Res, ResMut, Resource},
12};
13use bevy_tasks::Task;
14use bevy_utils::{
15    default,
16    hashbrown::hash_map::EntryRef,
17    tracing::{debug, error},
18    HashMap, HashSet,
19};
20use core::{future::Future, hash::Hash, mem, ops::Deref};
21use derive_more::derive::{Display, Error, From};
22use naga::valid::Capabilities;
23use std::sync::{Mutex, PoisonError};
24#[cfg(feature = "shader_format_spirv")]
25use wgpu::util::make_spirv;
26use wgpu::{
27    DownlevelFlags, Features, PipelineCompilationOptions,
28    VertexBufferLayout as RawVertexBufferLayout,
29};
30
/// A descriptor for a [`Pipeline`].
///
/// Used to store a heterogeneous collection of render and compute pipeline descriptors together.
#[derive(Debug)]
pub enum PipelineDescriptor {
    RenderPipelineDescriptor(Box<RenderPipelineDescriptor>),
    ComputePipelineDescriptor(Box<ComputePipelineDescriptor>),
}
39
/// A pipeline defining the data layout and shader logic for a specific GPU task.
///
/// Used to store a heterogeneous collection of render and compute pipelines together.
#[derive(Debug)]
pub enum Pipeline {
    RenderPipeline(RenderPipeline),
    ComputePipeline(ComputePipeline),
}
48
/// Index shared by the render and compute pipeline id newtypes below.
type CachedPipelineId = usize;

/// Index of a cached render pipeline in a [`PipelineCache`].
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, PartialOrd, Ord)]
pub struct CachedRenderPipelineId(CachedPipelineId);

impl CachedRenderPipelineId {
    /// An invalid cached render pipeline index, often used to initialize a variable.
    pub const INVALID: Self = Self(usize::MAX);

    /// Returns the raw index of this cached render pipeline.
    #[inline]
    pub fn id(&self) -> usize {
        self.0
    }
}
64
65/// Index of a cached compute pipeline in a [`PipelineCache`].
66#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
67pub struct CachedComputePipelineId(CachedPipelineId);
68
69impl CachedComputePipelineId {
70    /// An invalid cached compute pipeline index, often used to initialize a variable.
71    pub const INVALID: Self = CachedComputePipelineId(usize::MAX);
72
73    #[inline]
74    pub fn id(&self) -> usize {
75        self.0
76    }
77}
78
/// A pipeline stored in a [`PipelineCache`], together with its creation state.
pub struct CachedPipeline {
    /// The descriptor this pipeline was (or will be) created from.
    pub descriptor: PipelineDescriptor,
    /// Current creation state of the GPU pipeline object.
    pub state: CachedPipelineState,
}
83
/// State of a cached pipeline inserted into a [`PipelineCache`].
#[derive(Debug)]
pub enum CachedPipelineState {
    /// The pipeline GPU object is queued for creation.
    Queued,
    /// The pipeline GPU object is being created.
    ///
    /// Holds the in-flight task that resolves to the finished pipeline or a creation error.
    Creating(Task<Result<Pipeline, PipelineCacheError>>),
    /// The pipeline GPU object was created successfully and is available (allocated on the GPU).
    Ok(Pipeline),
    /// An error occurred while trying to create the pipeline GPU object.
    Err(PipelineCacheError),
}
96
97impl CachedPipelineState {
98    /// Convenience method to "unwrap" a pipeline state into its underlying GPU object.
99    ///
100    /// # Returns
101    ///
102    /// The method returns the allocated pipeline GPU object.
103    ///
104    /// # Panics
105    ///
106    /// This method panics if the pipeline GPU object is not available, either because it is
107    /// pending creation or because an error occurred while attempting to create GPU object.
108    pub fn unwrap(&self) -> &Pipeline {
109        match self {
110            CachedPipelineState::Ok(pipeline) => pipeline,
111            CachedPipelineState::Queued => {
112                panic!("Pipeline has not been compiled yet. It is still in the 'Queued' state.")
113            }
114            CachedPipelineState::Creating(..) => {
115                panic!("Pipeline has not been compiled yet. It is still in the 'Creating' state.")
116            }
117            CachedPipelineState::Err(err) => panic!("{}", err),
118        }
119    }
120}
121
/// Per-shader bookkeeping used by [`ShaderCache`].
#[derive(Default)]
struct ShaderData {
    /// Cached pipelines that used this shader and must be re-queued when it changes.
    pipelines: HashSet<CachedPipelineId>,
    /// Compiled shader modules, keyed by the shader defs they were processed with.
    processed_shaders: HashMap<Box<[ShaderDefVal]>, Arc<WgpuWrapper<ShaderModule>>>,
    /// Imports of this shader that have been resolved to a loaded shader asset.
    resolved_imports: HashMap<ShaderImport, AssetId<Shader>>,
    /// Shaders that import this one; they are invalidated (transitively) when it changes.
    dependents: HashSet<AssetId<Shader>>,
}
129
/// Cache of loaded shaders, their processed modules, and their import relationships.
struct ShaderCache {
    /// Per-shader bookkeeping (processed modules, dependents, dependent pipelines).
    data: HashMap<AssetId<Shader>, ShaderData>,
    /// All shader assets currently known to the cache.
    shaders: HashMap<AssetId<Shader>, Shader>,
    /// Maps an import path to the shader asset that provides it.
    import_path_shaders: HashMap<ShaderImport, AssetId<Shader>>,
    /// Shaders waiting on an import that has not been loaded yet.
    waiting_on_import: HashMap<ShaderImport, Vec<AssetId<Shader>>>,
    /// naga_oil composer used to resolve imports and build naga modules.
    composer: naga_oil::compose::Composer,
}
137
/// A shader definition value: a named boolean, signed, or unsigned integer flag
/// passed to shader preprocessing.
#[derive(Clone, PartialEq, Eq, Debug, Hash)]
pub enum ShaderDefVal {
    Bool(String, bool),
    Int(String, i32),
    UInt(String, u32),
}

impl From<&str> for ShaderDefVal {
    /// A bare `&str` key becomes a boolean def set to `true`.
    fn from(key: &str) -> Self {
        Self::Bool(String::from(key), true)
    }
}

impl From<String> for ShaderDefVal {
    /// A bare `String` key becomes a boolean def set to `true`.
    fn from(key: String) -> Self {
        Self::Bool(key, true)
    }
}

impl ShaderDefVal {
    /// Renders the def's value (not its name) as a string.
    pub fn value_as_string(&self) -> String {
        match self {
            Self::Bool(_, value) => value.to_string(),
            Self::Int(_, value) => value.to_string(),
            Self::UInt(_, value) => value.to_string(),
        }
    }
}
166
impl ShaderCache {
    /// Creates an empty shader cache whose composer capabilities match the given
    /// device features and adapter downlevel flags.
    fn new(render_device: &RenderDevice, render_adapter: &RenderAdapter) -> Self {
        let capabilities = get_capabilities(
            render_device.features(),
            render_adapter.get_downlevel_capabilities().flags,
        );

        // Debug builds use a validating composer; release builds use a non-validating one.
        #[cfg(debug_assertions)]
        let composer = naga_oil::compose::Composer::default();
        #[cfg(not(debug_assertions))]
        let composer = naga_oil::compose::Composer::non_validating();

        let composer = composer.with_capabilities(capabilities);

        Self {
            composer,
            data: Default::default(),
            shaders: Default::default(),
            import_path_shaders: Default::default(),
            waiting_on_import: Default::default(),
        }
    }

    /// Recursively registers `import` (and its own imports, depth-first) as composable
    /// modules on the composer, if they are loaded and not already registered.
    fn add_import_to_composer(
        composer: &mut naga_oil::compose::Composer,
        import_path_shaders: &HashMap<ShaderImport, AssetId<Shader>>,
        shaders: &HashMap<AssetId<Shader>, Shader>,
        import: &ShaderImport,
    ) -> Result<(), PipelineCacheError> {
        if !composer.contains_module(&import.module_name()) {
            if let Some(shader_handle) = import_path_shaders.get(import) {
                if let Some(shader) = shaders.get(shader_handle) {
                    // Register transitive imports first so this module can resolve them.
                    for import in &shader.imports {
                        Self::add_import_to_composer(
                            composer,
                            import_path_shaders,
                            shaders,
                            import,
                        )?;
                    }

                    composer.add_composable_module(shader.into())?;
                }
            }
            // if we fail to add a module the composer will tell us what is missing
        }

        Ok(())
    }

    /// Returns the processed shader module for `id` built with `shader_defs`, compiling
    /// and caching it on first use, and records `pipeline` as a dependent of the shader.
    ///
    /// # Errors
    ///
    /// - [`PipelineCacheError::ShaderNotLoaded`] if the shader asset is not in the cache.
    /// - [`PipelineCacheError::ShaderImportNotYetAvailable`] if any asset-path import of the
    ///   shader has not been resolved yet.
    /// - Composer errors from naga_oil, or [`PipelineCacheError::CreateShaderModule`] if wgpu
    ///   validation of the created module fails (caught early on native platforms only).
    #[allow(clippy::result_large_err)]
    fn get(
        &mut self,
        render_device: &RenderDevice,
        pipeline: CachedPipelineId,
        id: AssetId<Shader>,
        shader_defs: &[ShaderDefVal],
    ) -> Result<Arc<WgpuWrapper<ShaderModule>>, PipelineCacheError> {
        let shader = self
            .shaders
            .get(&id)
            .ok_or(PipelineCacheError::ShaderNotLoaded(id))?;
        let data = self.data.entry(id).or_default();
        // Bail out until every asset-path import of this shader has been resolved.
        let n_asset_imports = shader
            .imports()
            .filter(|import| matches!(import, ShaderImport::AssetPath(_)))
            .count();
        let n_resolved_asset_imports = data
            .resolved_imports
            .keys()
            .filter(|import| matches!(import, ShaderImport::AssetPath(_)))
            .count();
        if n_asset_imports != n_resolved_asset_imports {
            return Err(PipelineCacheError::ShaderImportNotYetAvailable);
        }

        data.pipelines.insert(pipeline);

        // PERF: this shader_defs clone isn't great. use raw_entry_mut when it stabilizes
        let module = match data.processed_shaders.entry_ref(shader_defs) {
            EntryRef::Occupied(entry) => entry.into_mut(),
            EntryRef::Vacant(entry) => {
                // First time this shader is processed with this exact set of defs:
                // augment the defs with platform-derived ones, compose, and compile.
                let mut shader_defs = shader_defs.to_vec();
                #[cfg(all(feature = "webgl", target_arch = "wasm32", not(feature = "webgpu")))]
                {
                    shader_defs.push("NO_ARRAY_TEXTURES_SUPPORT".into());
                    shader_defs.push("NO_CUBE_ARRAY_TEXTURES_SUPPORT".into());
                    shader_defs.push("SIXTEEN_BYTE_ALIGNMENT".into());
                }

                if cfg!(feature = "ios_simulator") {
                    shader_defs.push("NO_CUBE_ARRAY_TEXTURES_SUPPORT".into());
                }

                shader_defs.push(ShaderDefVal::UInt(
                    String::from("AVAILABLE_STORAGE_BUFFER_BINDINGS"),
                    render_device.limits().max_storage_buffers_per_shader_stage,
                ));

                debug!(
                    "processing shader {:?}, with shader defs {:?}",
                    id, shader_defs
                );
                let shader_source = match &shader.source {
                    #[cfg(feature = "shader_format_spirv")]
                    Source::SpirV(data) => make_spirv(data),
                    #[cfg(not(feature = "shader_format_spirv"))]
                    Source::SpirV(_) => {
                        unimplemented!(
                            "Enable feature \"shader_format_spirv\" to use SPIR-V shaders"
                        )
                    }
                    _ => {
                        // WGSL/GLSL path: resolve imports via the composer, then build a
                        // naga module with the combined def set.
                        for import in shader.imports() {
                            Self::add_import_to_composer(
                                &mut self.composer,
                                &self.import_path_shaders,
                                &self.shaders,
                                import,
                            )?;
                        }

                        let shader_defs = shader_defs
                            .into_iter()
                            .chain(shader.shader_defs.iter().cloned())
                            .map(|def| match def {
                                ShaderDefVal::Bool(k, v) => {
                                    (k, naga_oil::compose::ShaderDefValue::Bool(v))
                                }
                                ShaderDefVal::Int(k, v) => {
                                    (k, naga_oil::compose::ShaderDefValue::Int(v))
                                }
                                ShaderDefVal::UInt(k, v) => {
                                    (k, naga_oil::compose::ShaderDefValue::UInt(v))
                                }
                            })
                            .collect::<std::collections::HashMap<_, _>>();

                        let naga = self.composer.make_naga_module(
                            naga_oil::compose::NagaModuleDescriptor {
                                shader_defs,
                                ..shader.into()
                            },
                        )?;

                        ShaderSource::Naga(Cow::Owned(naga))
                    }
                };

                let module_descriptor = ShaderModuleDescriptor {
                    label: None,
                    source: shader_source,
                };

                // Capture wgpu validation errors for this module creation only.
                render_device
                    .wgpu_device()
                    .push_error_scope(wgpu::ErrorFilter::Validation);
                let shader_module = render_device.create_shader_module(module_descriptor);
                let error = render_device.wgpu_device().pop_error_scope();

                // `now_or_never` will return Some if the future is ready and None otherwise.
                // On native platforms, wgpu will yield the error immediately while on wasm it may take longer since the browser APIs are asynchronous.
                // So to keep the complexity of the ShaderCache low, we will only catch this error early on native platforms,
                // and on wasm the error will be handled by wgpu and crash the application.
                if let Some(Some(wgpu::Error::Validation { description, .. })) =
                    bevy_utils::futures::now_or_never(error)
                {
                    return Err(PipelineCacheError::CreateShaderModule(description));
                }

                entry.insert(Arc::new(WgpuWrapper::new(shader_module)))
            }
        };

        Ok(module.clone())
    }

    /// Invalidates the processed modules of `id` and of every shader that (transitively)
    /// depends on it, returning the ids of all cached pipelines that must be re-queued.
    fn clear(&mut self, id: AssetId<Shader>) -> Vec<CachedPipelineId> {
        let mut shaders_to_clear = vec![id];
        let mut pipelines_to_queue = Vec::new();
        while let Some(handle) = shaders_to_clear.pop() {
            if let Some(data) = self.data.get_mut(&handle) {
                data.processed_shaders.clear();
                pipelines_to_queue.extend(data.pipelines.iter().copied());
                // Dependent shaders embed this one, so they must be cleared too.
                shaders_to_clear.extend(data.dependents.iter().copied());

                if let Some(Shader { import_path, .. }) = self.shaders.get(&handle) {
                    self.composer
                        .remove_composable_module(&import_path.module_name());
                }
            }
        }

        pipelines_to_queue
    }

    /// Inserts or replaces the shader `id`, resolving pending imports in both directions
    /// (shaders waiting on this path, and this shader's own imports), and returns the
    /// pipelines that must be re-queued because they used the previous version.
    fn set_shader(&mut self, id: AssetId<Shader>, shader: Shader) -> Vec<CachedPipelineId> {
        let pipelines_to_queue = self.clear(id);
        let path = shader.import_path();
        self.import_path_shaders.insert(path.clone(), id);
        if let Some(waiting_shaders) = self.waiting_on_import.get_mut(path) {
            for waiting_shader in waiting_shaders.drain(..) {
                // resolve waiting shader import
                let data = self.data.entry(waiting_shader).or_default();
                data.resolved_imports.insert(path.clone(), id);
                // add waiting shader as dependent of this shader
                let data = self.data.entry(id).or_default();
                data.dependents.insert(waiting_shader);
            }
        }

        for import in shader.imports() {
            if let Some(import_id) = self.import_path_shaders.get(import).copied() {
                // resolve import because it is currently available
                let data = self.data.entry(id).or_default();
                data.resolved_imports.insert(import.clone(), import_id);
                // add this shader as a dependent of the import
                let data = self.data.entry(import_id).or_default();
                data.dependents.insert(id);
            } else {
                let waiting = self.waiting_on_import.entry(import.clone()).or_default();
                waiting.push(id);
            }
        }

        self.shaders.insert(id, shader);
        pipelines_to_queue
    }

    /// Removes the shader `id` from the cache and returns the pipelines that must be
    /// re-queued because they used it.
    fn remove(&mut self, id: AssetId<Shader>) -> Vec<CachedPipelineId> {
        let pipelines_to_queue = self.clear(id);
        if let Some(shader) = self.shaders.remove(&id) {
            self.import_path_shaders.remove(shader.import_path());
        }

        pipelines_to_queue
    }
}
405
/// Key identifying a pipeline layout: the bind group layout ids plus push constant ranges.
type LayoutCacheKey = (Vec<BindGroupLayoutId>, Vec<PushConstantRange>);
/// Cache of pipeline layouts, deduplicated by [`LayoutCacheKey`].
#[derive(Default)]
struct LayoutCache {
    layouts: HashMap<LayoutCacheKey, Arc<WgpuWrapper<PipelineLayout>>>,
}
411
impl LayoutCache {
    /// Returns the cached pipeline layout for the given bind group layouts and push constant
    /// ranges, creating (and caching) it on first use.
    fn get(
        &mut self,
        render_device: &RenderDevice,
        bind_group_layouts: &[BindGroupLayout],
        push_constant_ranges: Vec<PushConstantRange>,
    ) -> Arc<WgpuWrapper<PipelineLayout>> {
        let bind_group_ids = bind_group_layouts.iter().map(BindGroupLayout::id).collect();
        self.layouts
            .entry((bind_group_ids, push_constant_ranges))
            // `or_insert_with_key` lets the closure borrow the push constant ranges that were
            // moved into the key, avoiding another clone.
            .or_insert_with_key(|(_, push_constant_ranges)| {
                let bind_group_layouts = bind_group_layouts
                    .iter()
                    .map(BindGroupLayout::value)
                    .collect::<Vec<_>>();
                Arc::new(WgpuWrapper::new(render_device.create_pipeline_layout(
                    &PipelineLayoutDescriptor {
                        bind_group_layouts: &bind_group_layouts,
                        push_constant_ranges,
                        ..default()
                    },
                )))
            })
            .clone()
    }
}
438
/// Cache for render and compute pipelines.
///
/// The cache stores existing render and compute pipelines allocated on the GPU, as well as
/// pending creation. Pipelines inserted into the cache are identified by a unique ID, which
/// can be used to retrieve the actual GPU object once it's ready. The creation of the GPU
/// pipeline object is deferred to the [`RenderSet::Render`] step, just before the render
/// graph starts being processed, as this requires access to the GPU.
///
/// Note that the cache does not perform automatic deduplication of identical pipelines. It is
/// up to the user not to insert the same pipeline twice to avoid wasting GPU resources.
///
/// [`RenderSet::Render`]: crate::RenderSet::Render
#[derive(Resource)]
pub struct PipelineCache {
    /// Shared layout cache, also locked by the async pipeline-creation tasks.
    layout_cache: Arc<Mutex<LayoutCache>>,
    /// Shared shader cache, also locked by the async pipeline-creation tasks.
    shader_cache: Arc<Mutex<ShaderCache>>,
    device: RenderDevice,
    /// All pipelines known to the cache, indexed by their cached pipeline id.
    pipelines: Vec<CachedPipeline>,
    /// Ids of pipelines whose GPU object still needs to be (re)created.
    waiting_pipelines: HashSet<CachedPipelineId>,
    /// Pipelines queued via `queue_*` but not yet appended to `pipelines`;
    /// behind a `Mutex` so queuing only needs `&self`. Their ids continue after
    /// `pipelines.len()` (see `queue_render_pipeline`).
    new_pipelines: Mutex<Vec<CachedPipeline>>,
    /// If `true`, disables asynchronous pipeline compilation.
    /// This has no effect on macOS, wasm, or without the `multi_threaded` feature.
    synchronous_pipeline_compilation: bool,
}
463
464impl PipelineCache {
465    /// Returns an iterator over the pipelines in the pipeline cache.
466    pub fn pipelines(&self) -> impl Iterator<Item = &CachedPipeline> {
467        self.pipelines.iter()
468    }
469
    /// Returns an iterator of the IDs of all currently waiting pipelines.
    pub fn waiting_pipelines(&self) -> impl Iterator<Item = CachedPipelineId> + '_ {
        self.waiting_pipelines.iter().copied()
    }
474
475    /// Create a new pipeline cache associated with the given render device.
476    pub fn new(
477        device: RenderDevice,
478        render_adapter: RenderAdapter,
479        synchronous_pipeline_compilation: bool,
480    ) -> Self {
481        Self {
482            shader_cache: Arc::new(Mutex::new(ShaderCache::new(&device, &render_adapter))),
483            device,
484            layout_cache: default(),
485            waiting_pipelines: default(),
486            new_pipelines: default(),
487            pipelines: default(),
488            synchronous_pipeline_compilation,
489        }
490    }
491
    /// Get the state of a cached render pipeline.
    ///
    /// See [`PipelineCache::queue_render_pipeline()`].
    ///
    /// Panics if `id` is out of bounds of the internal pipeline list (e.g. the pipeline
    /// was queued but the queue has not been processed yet).
    #[inline]
    pub fn get_render_pipeline_state(&self, id: CachedRenderPipelineId) -> &CachedPipelineState {
        &self.pipelines[id.0].state
    }
499
    /// Get the state of a cached compute pipeline.
    ///
    /// See [`PipelineCache::queue_compute_pipeline()`].
    ///
    /// Panics if `id` is out of bounds of the internal pipeline list (e.g. the pipeline
    /// was queued but the queue has not been processed yet).
    #[inline]
    pub fn get_compute_pipeline_state(&self, id: CachedComputePipelineId) -> &CachedPipelineState {
        &self.pipelines[id.0].state
    }
507
508    /// Get the render pipeline descriptor a cached render pipeline was inserted from.
509    ///
510    /// See [`PipelineCache::queue_render_pipeline()`].
511    #[inline]
512    pub fn get_render_pipeline_descriptor(
513        &self,
514        id: CachedRenderPipelineId,
515    ) -> &RenderPipelineDescriptor {
516        match &self.pipelines[id.0].descriptor {
517            PipelineDescriptor::RenderPipelineDescriptor(descriptor) => descriptor,
518            PipelineDescriptor::ComputePipelineDescriptor(_) => unreachable!(),
519        }
520    }
521
522    /// Get the compute pipeline descriptor a cached render pipeline was inserted from.
523    ///
524    /// See [`PipelineCache::queue_compute_pipeline()`].
525    #[inline]
526    pub fn get_compute_pipeline_descriptor(
527        &self,
528        id: CachedComputePipelineId,
529    ) -> &ComputePipelineDescriptor {
530        match &self.pipelines[id.0].descriptor {
531            PipelineDescriptor::RenderPipelineDescriptor(_) => unreachable!(),
532            PipelineDescriptor::ComputePipelineDescriptor(descriptor) => descriptor,
533        }
534    }
535
536    /// Try to retrieve a render pipeline GPU object from a cached ID.
537    ///
538    /// # Returns
539    ///
540    /// This method returns a successfully created render pipeline if any, or `None` if the pipeline
541    /// was not created yet or if there was an error during creation. You can check the actual creation
542    /// state with [`PipelineCache::get_render_pipeline_state()`].
543    #[inline]
544    pub fn get_render_pipeline(&self, id: CachedRenderPipelineId) -> Option<&RenderPipeline> {
545        if let CachedPipelineState::Ok(Pipeline::RenderPipeline(pipeline)) =
546            &self.pipelines[id.0].state
547        {
548            Some(pipeline)
549        } else {
550            None
551        }
552    }
553
    /// Wait for a render pipeline to finish compiling.
    #[inline]
    pub fn block_on_render_pipeline(&mut self, id: CachedRenderPipelineId) {
        // The pipeline may still be sitting in `new_pipelines`; process the queue
        // first so `id` is addressable in `pipelines`.
        if self.pipelines.len() <= id.0 {
            self.process_queue();
        }

        let state = &mut self.pipelines[id.0].state;
        if let CachedPipelineState::Creating(task) = state {
            // Block this thread until the creation task resolves, then store its result.
            *state = match bevy_tasks::block_on(task) {
                Ok(p) => CachedPipelineState::Ok(p),
                Err(e) => CachedPipelineState::Err(e),
            };
        }
    }
569
570    /// Try to retrieve a compute pipeline GPU object from a cached ID.
571    ///
572    /// # Returns
573    ///
574    /// This method returns a successfully created compute pipeline if any, or `None` if the pipeline
575    /// was not created yet or if there was an error during creation. You can check the actual creation
576    /// state with [`PipelineCache::get_compute_pipeline_state()`].
577    #[inline]
578    pub fn get_compute_pipeline(&self, id: CachedComputePipelineId) -> Option<&ComputePipeline> {
579        if let CachedPipelineState::Ok(Pipeline::ComputePipeline(pipeline)) =
580            &self.pipelines[id.0].state
581        {
582            Some(pipeline)
583        } else {
584            None
585        }
586    }
587
    /// Insert a render pipeline into the cache, and queue its creation.
    ///
    /// The pipeline is always inserted and queued for creation. There is no attempt to deduplicate it with
    /// an already cached pipeline.
    ///
    /// # Returns
    ///
    /// This method returns the unique render shader ID of the cached pipeline, which can be used to query
    /// the caching state with [`get_render_pipeline_state()`] and to retrieve the created GPU pipeline once
    /// it's ready with [`get_render_pipeline()`].
    ///
    /// [`get_render_pipeline_state()`]: PipelineCache::get_render_pipeline_state
    /// [`get_render_pipeline()`]: PipelineCache::get_render_pipeline
    pub fn queue_render_pipeline(
        &self,
        descriptor: RenderPipelineDescriptor,
    ) -> CachedRenderPipelineId {
        // Tolerate a poisoned mutex: the queue is just a list of descriptors, so its
        // contents remain valid even if another thread panicked while holding the lock.
        let mut new_pipelines = self
            .new_pipelines
            .lock()
            .unwrap_or_else(PoisonError::into_inner);
        // The id continues after the already-processed pipelines plus the ones still
        // waiting in this queue, so ids stay stable once the queue is processed.
        let id = CachedRenderPipelineId(self.pipelines.len() + new_pipelines.len());
        new_pipelines.push(CachedPipeline {
            descriptor: PipelineDescriptor::RenderPipelineDescriptor(Box::new(descriptor)),
            state: CachedPipelineState::Queued,
        });
        id
    }
616
    /// Insert a compute pipeline into the cache, and queue its creation.
    ///
    /// The pipeline is always inserted and queued for creation. There is no attempt to deduplicate it with
    /// an already cached pipeline.
    ///
    /// # Returns
    ///
    /// This method returns the unique compute shader ID of the cached pipeline, which can be used to query
    /// the caching state with [`get_compute_pipeline_state()`] and to retrieve the created GPU pipeline once
    /// it's ready with [`get_compute_pipeline()`].
    ///
    /// [`get_compute_pipeline_state()`]: PipelineCache::get_compute_pipeline_state
    /// [`get_compute_pipeline()`]: PipelineCache::get_compute_pipeline
    pub fn queue_compute_pipeline(
        &self,
        descriptor: ComputePipelineDescriptor,
    ) -> CachedComputePipelineId {
        // Tolerate a poisoned mutex: the queue is just a list of descriptors, so its
        // contents remain valid even if another thread panicked while holding the lock.
        let mut new_pipelines = self
            .new_pipelines
            .lock()
            .unwrap_or_else(PoisonError::into_inner);
        // The id continues after the already-processed pipelines plus the ones still
        // waiting in this queue, so ids stay stable once the queue is processed.
        let id = CachedComputePipelineId(self.pipelines.len() + new_pipelines.len());
        new_pipelines.push(CachedPipeline {
            descriptor: PipelineDescriptor::ComputePipelineDescriptor(Box::new(descriptor)),
            state: CachedPipelineState::Queued,
        });
        id
    }
645
646    fn set_shader(&mut self, id: AssetId<Shader>, shader: &Shader) {
647        let mut shader_cache = self.shader_cache.lock().unwrap();
648        let pipelines_to_queue = shader_cache.set_shader(id, shader.clone());
649        for cached_pipeline in pipelines_to_queue {
650            self.pipelines[cached_pipeline].state = CachedPipelineState::Queued;
651            self.waiting_pipelines.insert(cached_pipeline);
652        }
653    }
654
655    fn remove_shader(&mut self, shader: AssetId<Shader>) {
656        let mut shader_cache = self.shader_cache.lock().unwrap();
657        let pipelines_to_queue = shader_cache.remove(shader);
658        for cached_pipeline in pipelines_to_queue {
659            self.pipelines[cached_pipeline].state = CachedPipelineState::Queued;
660            self.waiting_pipelines.insert(cached_pipeline);
661        }
662    }
663
    /// Kicks off creation of the GPU render pipeline described by `descriptor`, returning
    /// the resulting state (`Creating` with a task, or an immediate result — presumably
    /// depending on `synchronous_pipeline_compilation`; see `create_pipeline_task`).
    fn start_create_render_pipeline(
        &mut self,
        id: CachedPipelineId,
        descriptor: RenderPipelineDescriptor,
    ) -> CachedPipelineState {
        // Clone the handles so the async block owns everything it touches.
        let device = self.device.clone();
        let shader_cache = self.shader_cache.clone();
        let layout_cache = self.layout_cache.clone();

        create_pipeline_task(
            async move {
                let mut shader_cache = shader_cache.lock().unwrap();
                let mut layout_cache = layout_cache.lock().unwrap();

                // Resolve (and possibly compile) the vertex shader module.
                let vertex_module = match shader_cache.get(
                    &device,
                    id,
                    descriptor.vertex.shader.id(),
                    &descriptor.vertex.shader_defs,
                ) {
                    Ok(module) => module,
                    Err(err) => return Err(err),
                };

                // Resolve the optional fragment shader module.
                let fragment_module = match &descriptor.fragment {
                    Some(fragment) => {
                        match shader_cache.get(
                            &device,
                            id,
                            fragment.shader.id(),
                            &fragment.shader_defs,
                        ) {
                            Ok(module) => Some(module),
                            Err(err) => return Err(err),
                        }
                    }
                    None => None,
                };

                // No explicit layout when both bind groups and push constants are empty
                // (wgpu will then derive one).
                let layout =
                    if descriptor.layout.is_empty() && descriptor.push_constant_ranges.is_empty() {
                        None
                    } else {
                        Some(layout_cache.get(
                            &device,
                            &descriptor.layout,
                            descriptor.push_constant_ranges.to_vec(),
                        ))
                    };

                // Release the shared-cache locks before the (potentially slow) wgpu call.
                drop((shader_cache, layout_cache));

                let vertex_buffer_layouts = descriptor
                    .vertex
                    .buffers
                    .iter()
                    .map(|layout| RawVertexBufferLayout {
                        array_stride: layout.array_stride,
                        attributes: &layout.attributes,
                        step_mode: layout.step_mode,
                    })
                    .collect::<Vec<_>>();

                let fragment_data = descriptor.fragment.as_ref().map(|fragment| {
                    (
                        // Safe: `fragment_module` is `Some` whenever `descriptor.fragment` is.
                        fragment_module.unwrap(),
                        fragment.entry_point.deref(),
                        fragment.targets.as_slice(),
                    )
                });

                // TODO: Expose the rest of this somehow
                let compilation_options = PipelineCompilationOptions {
                    constants: &default(),
                    zero_initialize_workgroup_memory: descriptor.zero_initialize_workgroup_memory,
                };

                let descriptor = RawRenderPipelineDescriptor {
                    multiview: None,
                    depth_stencil: descriptor.depth_stencil.clone(),
                    label: descriptor.label.as_deref(),
                    layout: layout.as_ref().map(|layout| -> &PipelineLayout { layout }),
                    multisample: descriptor.multisample,
                    primitive: descriptor.primitive,
                    vertex: RawVertexState {
                        buffers: &vertex_buffer_layouts,
                        entry_point: Some(descriptor.vertex.entry_point.deref()),
                        module: &vertex_module,
                        // TODO: Should this be the same as the fragment compilation options?
                        compilation_options: compilation_options.clone(),
                    },
                    fragment: fragment_data
                        .as_ref()
                        .map(|(module, entry_point, targets)| RawFragmentState {
                            entry_point: Some(entry_point),
                            module,
                            targets,
                            // TODO: Should this be the same as the vertex compilation options?
                            compilation_options,
                        }),
                    cache: None,
                };

                Ok(Pipeline::RenderPipeline(
                    device.create_render_pipeline(&descriptor),
                ))
            },
            self.synchronous_pipeline_compilation,
        )
    }
774
775    fn start_create_compute_pipeline(
776        &mut self,
777        id: CachedPipelineId,
778        descriptor: ComputePipelineDescriptor,
779    ) -> CachedPipelineState {
780        let device = self.device.clone();
781        let shader_cache = self.shader_cache.clone();
782        let layout_cache = self.layout_cache.clone();
783
784        create_pipeline_task(
785            async move {
786                let mut shader_cache = shader_cache.lock().unwrap();
787                let mut layout_cache = layout_cache.lock().unwrap();
788
789                let compute_module = match shader_cache.get(
790                    &device,
791                    id,
792                    descriptor.shader.id(),
793                    &descriptor.shader_defs,
794                ) {
795                    Ok(module) => module,
796                    Err(err) => return Err(err),
797                };
798
799                let layout =
800                    if descriptor.layout.is_empty() && descriptor.push_constant_ranges.is_empty() {
801                        None
802                    } else {
803                        Some(layout_cache.get(
804                            &device,
805                            &descriptor.layout,
806                            descriptor.push_constant_ranges.to_vec(),
807                        ))
808                    };
809
810                drop((shader_cache, layout_cache));
811
812                let descriptor = RawComputePipelineDescriptor {
813                    label: descriptor.label.as_deref(),
814                    layout: layout.as_ref().map(|layout| -> &PipelineLayout { layout }),
815                    module: &compute_module,
816                    entry_point: Some(&descriptor.entry_point),
817                    // TODO: Expose the rest of this somehow
818                    compilation_options: PipelineCompilationOptions {
819                        constants: &default(),
820                        zero_initialize_workgroup_memory: descriptor
821                            .zero_initialize_workgroup_memory,
822                    },
823                    cache: None,
824                };
825
826                Ok(Pipeline::ComputePipeline(
827                    device.create_compute_pipeline(&descriptor),
828                ))
829            },
830            self.synchronous_pipeline_compilation,
831        )
832    }
833
834    /// Process the pipeline queue and create all pending pipelines if possible.
835    ///
836    /// This is generally called automatically during the [`RenderSet::Render`] step, but can
837    /// be called manually to force creation at a different time.
838    ///
839    /// [`RenderSet::Render`]: crate::RenderSet::Render
840    pub fn process_queue(&mut self) {
841        let mut waiting_pipelines = mem::take(&mut self.waiting_pipelines);
842        let mut pipelines = mem::take(&mut self.pipelines);
843
844        {
845            let mut new_pipelines = self
846                .new_pipelines
847                .lock()
848                .unwrap_or_else(PoisonError::into_inner);
849            for new_pipeline in new_pipelines.drain(..) {
850                let id = pipelines.len();
851                pipelines.push(new_pipeline);
852                waiting_pipelines.insert(id);
853            }
854        }
855
856        for id in waiting_pipelines {
857            self.process_pipeline(&mut pipelines[id], id);
858        }
859
860        self.pipelines = pipelines;
861    }
862
    /// Advances a single cached pipeline's state machine by one step.
    ///
    /// Transitions:
    /// - `Queued`: start (a)sync creation; the result becomes the new state.
    /// - `Creating`: poll the in-flight task once, non-blocking.
    /// - `Err`: re-queue retryable errors; log fatal ones and stop tracking.
    /// - `Ok`: nothing to do.
    ///
    /// Every path that falls through to the bottom re-inserts `id` into
    /// `waiting_pipelines` so `process_queue` revisits it on the next call;
    /// the explicit `return`s are the terminal (done or unrecoverable) cases.
    fn process_pipeline(&mut self, cached_pipeline: &mut CachedPipeline, id: usize) {
        match &mut cached_pipeline.state {
            CachedPipelineState::Queued => {
                // Clone the boxed descriptor so the (possibly async) creation
                // task can own its own copy.
                cached_pipeline.state = match &cached_pipeline.descriptor {
                    PipelineDescriptor::RenderPipelineDescriptor(descriptor) => {
                        self.start_create_render_pipeline(id, *descriptor.clone())
                    }
                    PipelineDescriptor::ComputePipelineDescriptor(descriptor) => {
                        self.start_create_compute_pipeline(id, *descriptor.clone())
                    }
                };
            }

            CachedPipelineState::Creating(ref mut task) => {
                match bevy_utils::futures::check_ready(task) {
                    Some(Ok(pipeline)) => {
                        // Finished successfully: no re-queue needed.
                        cached_pipeline.state = CachedPipelineState::Ok(pipeline);
                        return;
                    }
                    // Failed: fall through so the `Err` branch handles it next call.
                    Some(Err(err)) => cached_pipeline.state = CachedPipelineState::Err(err),
                    // Still pending: fall through so the task is polled again next call.
                    _ => (),
                }
            }

            CachedPipelineState::Err(err) => match err {
                // Retry
                PipelineCacheError::ShaderNotLoaded(_)
                | PipelineCacheError::ShaderImportNotYetAvailable => {
                    cached_pipeline.state = CachedPipelineState::Queued;
                }

                // Shader could not be processed ... retrying won't help
                PipelineCacheError::ProcessShaderError(err) => {
                    let error_detail =
                        err.emit_to_string(&self.shader_cache.lock().unwrap().composer);
                    error!("failed to process shader:\n{}", error_detail);
                    return;
                }
                PipelineCacheError::CreateShaderModule(description) => {
                    error!("failed to create shader module: {}", description);
                    return;
                }
            },

            CachedPipelineState::Ok(_) => return,
        }

        // Retry
        self.waiting_pipelines.insert(id);
    }
913
    /// ECS system wrapper that drives [`PipelineCache::process_queue`] each frame.
    pub(crate) fn process_pipeline_queue_system(mut cache: ResMut<Self>) {
        cache.process_queue();
    }
917
918    pub(crate) fn extract_shaders(
919        mut cache: ResMut<Self>,
920        shaders: Extract<Res<Assets<Shader>>>,
921        mut events: Extract<EventReader<AssetEvent<Shader>>>,
922    ) {
923        for event in events.read() {
924            #[allow(clippy::match_same_arms)]
925            match event {
926                // PERF: Instead of blocking waiting for the shader cache lock, try again next frame if the lock is currently held
927                AssetEvent::Added { id } | AssetEvent::Modified { id } => {
928                    if let Some(shader) = shaders.get(*id) {
929                        cache.set_shader(*id, shader);
930                    }
931                }
932                AssetEvent::Removed { id } => cache.remove_shader(*id),
933                AssetEvent::Unused { .. } => {}
934                AssetEvent::LoadedWithDependencies { .. } => {
935                    // TODO: handle this
936                }
937            }
938        }
939    }
940}
941
/// Turns a pipeline-creation future into a [`CachedPipelineState`].
///
/// With `sync == false` the future is spawned onto the async compute task pool
/// and tracked via [`CachedPipelineState::Creating`]; otherwise it is driven to
/// completion on the current thread.
#[cfg(all(
    not(target_arch = "wasm32"),
    not(target_os = "macos"),
    feature = "multi_threaded"
))]
fn create_pipeline_task(
    task: impl Future<Output = Result<Pipeline, PipelineCacheError>> + Send + 'static,
    sync: bool,
) -> CachedPipelineState {
    if sync {
        futures_lite::future::block_on(task)
            .map_or_else(CachedPipelineState::Err, CachedPipelineState::Ok)
    } else {
        CachedPipelineState::Creating(bevy_tasks::AsyncComputeTaskPool::get().spawn(task))
    }
}
960
961#[cfg(any(
962    target_arch = "wasm32",
963    target_os = "macos",
964    not(feature = "multi_threaded")
965))]
966fn create_pipeline_task(
967    task: impl Future<Output = Result<Pipeline, PipelineCacheError>> + Send + 'static,
968    _sync: bool,
969) -> CachedPipelineState {
970    match futures_lite::future::block_on(task) {
971        Ok(pipeline) => CachedPipelineState::Ok(pipeline),
972        Err(err) => CachedPipelineState::Err(err),
973    }
974}
975
/// Type of error returned by a [`PipelineCache`] when the creation of a GPU pipeline object failed.
#[derive(Error, Display, Debug, From)]
pub enum PipelineCacheError {
    /// The shader asset backing the pipeline hasn't been loaded yet;
    /// pipeline creation is re-queued and retried on a later frame.
    #[display(
        "Pipeline could not be compiled because the following shader could not be loaded: {_0:?}"
    )]
    #[error(ignore)]
    ShaderNotLoaded(AssetId<Shader>),

    /// `naga_oil` failed to compose/process the shader source; this is fatal
    /// for the pipeline (retrying won't help) and is logged once.
    ProcessShaderError(naga_oil::compose::ComposerError),
    /// A shader import couldn't be resolved yet; re-queued and retried later.
    #[display("Shader import not yet available.")]
    ShaderImportNotYetAvailable,
    /// The shader module itself couldn't be created; fatal, logged once.
    #[display("Could not create shader module: {_0}")]
    #[error(ignore)]
    CreateShaderModule(String),
}
992
993// TODO: This needs to be kept up to date with the capabilities in the `create_validator` function in wgpu-core
994// https://github.com/gfx-rs/wgpu/blob/trunk/wgpu-core/src/device/mod.rs#L449
995// We can't use the `wgpu-core` function to detect the device's capabilities because `wgpu-core` isn't included in WebGPU builds.
996/// Get the device's capabilities for use in `naga_oil`.
997fn get_capabilities(features: Features, downlevel: DownlevelFlags) -> Capabilities {
998    let mut capabilities = Capabilities::empty();
999    capabilities.set(
1000        Capabilities::PUSH_CONSTANT,
1001        features.contains(Features::PUSH_CONSTANTS),
1002    );
1003    capabilities.set(
1004        Capabilities::FLOAT64,
1005        features.contains(Features::SHADER_F64),
1006    );
1007    capabilities.set(
1008        Capabilities::PRIMITIVE_INDEX,
1009        features.contains(Features::SHADER_PRIMITIVE_INDEX),
1010    );
1011    capabilities.set(
1012        Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
1013        features.contains(Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
1014    );
1015    capabilities.set(
1016        Capabilities::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING,
1017        features.contains(Features::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING),
1018    );
1019    // TODO: This needs a proper wgpu feature
1020    capabilities.set(
1021        Capabilities::SAMPLER_NON_UNIFORM_INDEXING,
1022        features.contains(Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
1023    );
1024    capabilities.set(
1025        Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS,
1026        features.contains(Features::TEXTURE_FORMAT_16BIT_NORM),
1027    );
1028    capabilities.set(
1029        Capabilities::MULTIVIEW,
1030        features.contains(Features::MULTIVIEW),
1031    );
1032    capabilities.set(
1033        Capabilities::EARLY_DEPTH_TEST,
1034        features.contains(Features::SHADER_EARLY_DEPTH_TEST),
1035    );
1036    capabilities.set(
1037        Capabilities::SHADER_INT64,
1038        features.contains(Features::SHADER_INT64),
1039    );
1040    capabilities.set(
1041        Capabilities::SHADER_INT64_ATOMIC_MIN_MAX,
1042        features.intersects(
1043            Features::SHADER_INT64_ATOMIC_MIN_MAX | Features::SHADER_INT64_ATOMIC_ALL_OPS,
1044        ),
1045    );
1046    capabilities.set(
1047        Capabilities::SHADER_INT64_ATOMIC_ALL_OPS,
1048        features.contains(Features::SHADER_INT64_ATOMIC_ALL_OPS),
1049    );
1050    capabilities.set(
1051        Capabilities::MULTISAMPLED_SHADING,
1052        downlevel.contains(DownlevelFlags::MULTISAMPLED_SHADING),
1053    );
1054    capabilities.set(
1055        Capabilities::DUAL_SOURCE_BLENDING,
1056        features.contains(Features::DUAL_SOURCE_BLENDING),
1057    );
1058    capabilities.set(
1059        Capabilities::CUBE_ARRAY_TEXTURES,
1060        downlevel.contains(DownlevelFlags::CUBE_ARRAY_TEXTURES),
1061    );
1062    capabilities.set(
1063        Capabilities::SUBGROUP,
1064        features.intersects(Features::SUBGROUP | Features::SUBGROUP_VERTEX),
1065    );
1066    capabilities.set(
1067        Capabilities::SUBGROUP_BARRIER,
1068        features.intersects(Features::SUBGROUP_BARRIER),
1069    );
1070    capabilities.set(
1071        Capabilities::SUBGROUP_VERTEX_STAGE,
1072        features.contains(Features::SUBGROUP_VERTEX),
1073    );
1074
1075    capabilities
1076}