bevy_render/render_resource/pipeline_cache.rs

use crate::renderer::WgpuWrapper;
use crate::{
    render_resource::*,
    renderer::{RenderAdapter, RenderDevice},
    Extract,
};
use alloc::{borrow::Cow, sync::Arc};
use bevy_asset::{AssetEvent, AssetId, Assets};
use bevy_ecs::{
    event::EventReader,
    resource::Resource,
    system::{Res, ResMut},
};
use bevy_platform::collections::{hash_map::EntryRef, HashMap, HashSet};
use bevy_tasks::Task;
use bevy_utils::default;
use core::{future::Future, hash::Hash, mem, ops::Deref};
use naga::valid::Capabilities;
use std::sync::{Mutex, PoisonError};
use thiserror::Error;
use tracing::{debug, error};
#[cfg(feature = "shader_format_spirv")]
use wgpu::util::make_spirv;
use wgpu::{
    DownlevelFlags, Features, PipelineCompilationOptions,
    VertexBufferLayout as RawVertexBufferLayout,
};

/// A descriptor for a [`Pipeline`].
///
/// Used to store a heterogeneous collection of render and compute pipeline descriptors together.
#[derive(Debug)]
pub enum PipelineDescriptor {
    RenderPipelineDescriptor(Box<RenderPipelineDescriptor>),
    ComputePipelineDescriptor(Box<ComputePipelineDescriptor>),
}

/// A pipeline defining the data layout and shader logic for a specific GPU task.
///
/// Used to store a heterogeneous collection of render and compute pipelines together.
#[derive(Debug)]
pub enum Pipeline {
    RenderPipeline(RenderPipeline),
    ComputePipeline(ComputePipeline),
}

type CachedPipelineId = usize;

/// Index of a cached render pipeline in a [`PipelineCache`].
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, PartialOrd, Ord)]
pub struct CachedRenderPipelineId(CachedPipelineId);

impl CachedRenderPipelineId {
    /// An invalid cached render pipeline index, often used to initialize a variable.
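    ///
    /// A minimal sketch of the "initialize now, queue later" pattern this constant supports
    /// (`MyPipelineId` is a hypothetical wrapper type):
    ///
    /// ```ignore
    /// // Placeholder until the real pipeline is queued via `PipelineCache::queue_render_pipeline`.
    /// let my_pipeline_id = MyPipelineId(CachedRenderPipelineId::INVALID);
    /// ```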
    pub const INVALID: Self = CachedRenderPipelineId(usize::MAX);

    #[inline]
    pub fn id(&self) -> usize {
        self.0
    }
}

/// Index of a cached compute pipeline in a [`PipelineCache`].
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
pub struct CachedComputePipelineId(CachedPipelineId);

impl CachedComputePipelineId {
    /// An invalid cached compute pipeline index, often used to initialize a variable.
    pub const INVALID: Self = CachedComputePipelineId(usize::MAX);

    #[inline]
    pub fn id(&self) -> usize {
        self.0
    }
}

pub struct CachedPipeline {
    pub descriptor: PipelineDescriptor,
    pub state: CachedPipelineState,
}

/// State of a cached pipeline inserted into a [`PipelineCache`].
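///
/// A rough sketch of polling the state without panicking, assuming `pipeline_cache` is a
/// [`PipelineCache`] and `id` was returned by one of its `queue_*` methods:
///
/// ```ignore
/// match pipeline_cache.get_render_pipeline_state(id) {
///     CachedPipelineState::Ok(_) => { /* ready to use via `get_render_pipeline` */ }
///     CachedPipelineState::Queued | CachedPipelineState::Creating(_) => {
///         // Not ready yet; check again on a later frame.
///     }
///     CachedPipelineState::Err(err) => { /* log or otherwise handle `err` */ }
/// }
/// ```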
#[derive(Debug)]
pub enum CachedPipelineState {
    /// The pipeline GPU object is queued for creation.
    Queued,
    /// The pipeline GPU object is being created.
    Creating(Task<Result<Pipeline, PipelineCacheError>>),
    /// The pipeline GPU object was created successfully and is available (allocated on the GPU).
    Ok(Pipeline),
    /// An error occurred while trying to create the pipeline GPU object.
    Err(PipelineCacheError),
}

impl CachedPipelineState {
    /// Convenience method to "unwrap" a pipeline state into its underlying GPU object.
    ///
    /// # Returns
    ///
    /// The method returns the allocated pipeline GPU object.
    ///
    /// # Panics
    ///
    /// This method panics if the pipeline GPU object is not available, either because it is
    /// pending creation or because an error occurred while attempting to create the GPU object.
    pub fn unwrap(&self) -> &Pipeline {
        match self {
            CachedPipelineState::Ok(pipeline) => pipeline,
            CachedPipelineState::Queued => {
                panic!("Pipeline has not been compiled yet. It is still in the 'Queued' state.")
            }
            CachedPipelineState::Creating(..) => {
                panic!("Pipeline has not been compiled yet. It is still in the 'Creating' state.")
            }
            CachedPipelineState::Err(err) => panic!("{}", err),
        }
    }
}

#[derive(Default)]
struct ShaderData {
    pipelines: HashSet<CachedPipelineId>,
    processed_shaders: HashMap<Box<[ShaderDefVal]>, Arc<WgpuWrapper<ShaderModule>>>,
    resolved_imports: HashMap<ShaderImport, AssetId<Shader>>,
    dependents: HashSet<AssetId<Shader>>,
}

struct ShaderCache {
    data: HashMap<AssetId<Shader>, ShaderData>,
    #[cfg(feature = "shader_format_wesl")]
    asset_paths: HashMap<wesl::syntax::ModulePath, AssetId<Shader>>,
    shaders: HashMap<AssetId<Shader>, Shader>,
    import_path_shaders: HashMap<ShaderImport, AssetId<Shader>>,
    waiting_on_import: HashMap<ShaderImport, Vec<AssetId<Shader>>>,
    composer: naga_oil::compose::Composer,
}

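/// A shader define (name/value pair) used when processing a [`Shader`] into a shader module.
///
/// A rough sketch of building a list of defs, relying only on the variants and `From`
/// impls defined below (the define names here are made up):
///
/// ```ignore
/// let shader_defs: Vec<ShaderDefVal> = vec![
///     "MY_FLAG".into(), // shorthand for ShaderDefVal::Bool("MY_FLAG".to_string(), true)
///     ShaderDefVal::UInt("MAX_LIGHTS".into(), 64),
/// ];
/// ```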
#[derive(Clone, PartialEq, Eq, Debug, Hash)]
pub enum ShaderDefVal {
    Bool(String, bool),
    Int(String, i32),
    UInt(String, u32),
}

impl From<&str> for ShaderDefVal {
    fn from(key: &str) -> Self {
        ShaderDefVal::Bool(key.to_string(), true)
    }
}

impl From<String> for ShaderDefVal {
    fn from(key: String) -> Self {
        ShaderDefVal::Bool(key, true)
    }
}

impl ShaderDefVal {
    pub fn value_as_string(&self) -> String {
        match self {
            ShaderDefVal::Bool(_, def) => def.to_string(),
            ShaderDefVal::Int(_, def) => def.to_string(),
            ShaderDefVal::UInt(_, def) => def.to_string(),
        }
    }
}

impl ShaderCache {
    fn new(render_device: &RenderDevice, render_adapter: &RenderAdapter) -> Self {
        let capabilities = get_capabilities(
            render_device.features(),
            render_adapter.get_downlevel_capabilities().flags,
        );

        #[cfg(debug_assertions)]
        let composer = naga_oil::compose::Composer::default();
        #[cfg(not(debug_assertions))]
        let composer = naga_oil::compose::Composer::non_validating();

        let composer = composer.with_capabilities(capabilities);

        Self {
            composer,
            data: Default::default(),
            #[cfg(feature = "shader_format_wesl")]
            asset_paths: Default::default(),
            shaders: Default::default(),
            import_path_shaders: Default::default(),
            waiting_on_import: Default::default(),
        }
    }

    fn add_import_to_composer(
        composer: &mut naga_oil::compose::Composer,
        import_path_shaders: &HashMap<ShaderImport, AssetId<Shader>>,
        shaders: &HashMap<AssetId<Shader>, Shader>,
        import: &ShaderImport,
    ) -> Result<(), PipelineCacheError> {
        if !composer.contains_module(&import.module_name()) {
            if let Some(shader_handle) = import_path_shaders.get(import) {
                if let Some(shader) = shaders.get(shader_handle) {
                    for import in &shader.imports {
                        Self::add_import_to_composer(
                            composer,
                            import_path_shaders,
                            shaders,
                            import,
                        )?;
                    }

                    composer.add_composable_module(shader.into())?;
                }
            }
            // if we fail to add a module the composer will tell us what is missing
        }

        Ok(())
    }

    fn get(
        &mut self,
        render_device: &RenderDevice,
        pipeline: CachedPipelineId,
        id: AssetId<Shader>,
        shader_defs: &[ShaderDefVal],
    ) -> Result<Arc<WgpuWrapper<ShaderModule>>, PipelineCacheError> {
        let shader = self
            .shaders
            .get(&id)
            .ok_or(PipelineCacheError::ShaderNotLoaded(id))?;

        let data = self.data.entry(id).or_default();
        let n_asset_imports = shader
            .imports()
            .filter(|import| matches!(import, ShaderImport::AssetPath(_)))
            .count();
        let n_resolved_asset_imports = data
            .resolved_imports
            .keys()
            .filter(|import| matches!(import, ShaderImport::AssetPath(_)))
            .count();
        if n_asset_imports != n_resolved_asset_imports {
            return Err(PipelineCacheError::ShaderImportNotYetAvailable);
        }

        data.pipelines.insert(pipeline);

        // PERF: this shader_defs clone isn't great. use raw_entry_mut when it stabilizes
        let module = match data.processed_shaders.entry_ref(shader_defs) {
            EntryRef::Occupied(entry) => entry.into_mut(),
            EntryRef::Vacant(entry) => {
                let mut shader_defs = shader_defs.to_vec();
                #[cfg(all(feature = "webgl", target_arch = "wasm32", not(feature = "webgpu")))]
                {
                    shader_defs.push("NO_ARRAY_TEXTURES_SUPPORT".into());
                    shader_defs.push("NO_CUBE_ARRAY_TEXTURES_SUPPORT".into());
                    shader_defs.push("SIXTEEN_BYTE_ALIGNMENT".into());
                }

                if cfg!(target_abi = "sim") {
                    shader_defs.push("NO_CUBE_ARRAY_TEXTURES_SUPPORT".into());
                }

                shader_defs.push(ShaderDefVal::UInt(
                    String::from("AVAILABLE_STORAGE_BUFFER_BINDINGS"),
                    render_device.limits().max_storage_buffers_per_shader_stage,
                ));

                debug!(
                    "processing shader {}, with shader defs {:?}",
                    id, shader_defs
                );
                let shader_source = match &shader.source {
                    #[cfg(feature = "shader_format_spirv")]
                    Source::SpirV(data) => make_spirv(data),
                    #[cfg(feature = "shader_format_wesl")]
                    Source::Wesl(_) => {
                        if let ShaderImport::AssetPath(path) = shader.import_path() {
                            let shader_resolver =
                                ShaderResolver::new(&self.asset_paths, &self.shaders);
                            let module_path = wesl::syntax::ModulePath::from_path(path);
                            let mut compiler_options = wesl::CompileOptions {
                                imports: true,
                                condcomp: true,
                                lower: true,
                                ..default()
                            };

                            for shader_def in shader_defs {
                                match shader_def {
                                    ShaderDefVal::Bool(key, value) => {
                                        compiler_options.features.insert(key.clone(), value);
                                    }
                                    _ => debug!(
                                        "ShaderDefVal::Int and ShaderDefVal::UInt are not supported in wesl",
                                    ),
                                }
                            }

                            let compiled = wesl::compile(
                                &module_path,
                                &shader_resolver,
                                &wesl::EscapeMangler,
                                &compiler_options,
                            )
                            .unwrap();

                            let naga = naga::front::wgsl::parse_str(&compiled.to_string()).unwrap();
                            ShaderSource::Naga(Cow::Owned(naga))
                        } else {
                            panic!("Wesl shaders must be imported from a file");
                        }
                    }
                    #[cfg(not(feature = "shader_format_spirv"))]
                    Source::SpirV(_) => {
                        unimplemented!(
                            "Enable feature \"shader_format_spirv\" to use SPIR-V shaders"
                        )
                    }
                    _ => {
                        for import in shader.imports() {
                            Self::add_import_to_composer(
                                &mut self.composer,
                                &self.import_path_shaders,
                                &self.shaders,
                                import,
                            )?;
                        }

                        let shader_defs = shader_defs
                            .into_iter()
                            .chain(shader.shader_defs.iter().cloned())
                            .map(|def| match def {
                                ShaderDefVal::Bool(k, v) => {
                                    (k, naga_oil::compose::ShaderDefValue::Bool(v))
                                }
                                ShaderDefVal::Int(k, v) => {
                                    (k, naga_oil::compose::ShaderDefValue::Int(v))
                                }
                                ShaderDefVal::UInt(k, v) => {
                                    (k, naga_oil::compose::ShaderDefValue::UInt(v))
                                }
                            })
                            .collect::<std::collections::HashMap<_, _>>();

                        let naga = self.composer.make_naga_module(
                            naga_oil::compose::NagaModuleDescriptor {
                                shader_defs,
                                ..shader.into()
                            },
                        )?;

                        #[cfg(not(feature = "decoupled_naga"))]
                        {
                            ShaderSource::Naga(Cow::Owned(naga))
                        }

                        #[cfg(feature = "decoupled_naga")]
                        {
                            let mut validator = naga::valid::Validator::new(
                                naga::valid::ValidationFlags::all(),
                                self.composer.capabilities,
                            );
                            let module_info = validator.validate(&naga).unwrap();
                            let wgsl = Cow::Owned(
                                naga::back::wgsl::write_string(
                                    &naga,
                                    &module_info,
                                    naga::back::wgsl::WriterFlags::empty(),
                                )
                                .unwrap(),
                            );
                            ShaderSource::Wgsl(wgsl)
                        }
                    }
                };

                let module_descriptor = ShaderModuleDescriptor {
                    label: None,
                    source: shader_source,
                };

                render_device
                    .wgpu_device()
                    .push_error_scope(wgpu::ErrorFilter::Validation);

                let shader_module = match shader.validate_shader {
                    ValidateShader::Enabled => {
                        render_device.create_and_validate_shader_module(module_descriptor)
                    }
                    // SAFETY: we are interfacing with shader code, which may contain undefined behavior,
                    // such as indexing out of bounds.
                    // The checks required are prohibitively expensive and a poor default for game engines.
                    ValidateShader::Disabled => unsafe {
                        render_device.create_shader_module(module_descriptor)
                    },
                };

                let error = render_device.wgpu_device().pop_error_scope();

                // `now_or_never` will return Some if the future is ready and None otherwise.
                // On native platforms, wgpu will yield the error immediately while on wasm it may take longer since the browser APIs are asynchronous.
                // So to keep the complexity of the ShaderCache low, we will only catch this error early on native platforms,
                // and on wasm the error will be handled by wgpu and crash the application.
                if let Some(Some(wgpu::Error::Validation { description, .. })) =
                    bevy_tasks::futures::now_or_never(error)
                {
                    return Err(PipelineCacheError::CreateShaderModule(description));
                }

                entry.insert(Arc::new(WgpuWrapper::new(shader_module)))
            }
        };

        Ok(module.clone())
    }

    fn clear(&mut self, id: AssetId<Shader>) -> Vec<CachedPipelineId> {
        let mut shaders_to_clear = vec![id];
        let mut pipelines_to_queue = Vec::new();
        while let Some(handle) = shaders_to_clear.pop() {
            if let Some(data) = self.data.get_mut(&handle) {
                data.processed_shaders.clear();
                pipelines_to_queue.extend(data.pipelines.iter().copied());
                shaders_to_clear.extend(data.dependents.iter().copied());

                if let Some(Shader { import_path, .. }) = self.shaders.get(&handle) {
                    self.composer
                        .remove_composable_module(&import_path.module_name());
                }
            }
        }

        pipelines_to_queue
    }

    fn set_shader(&mut self, id: AssetId<Shader>, shader: Shader) -> Vec<CachedPipelineId> {
        let pipelines_to_queue = self.clear(id);
        let path = shader.import_path();
        self.import_path_shaders.insert(path.clone(), id);
        if let Some(waiting_shaders) = self.waiting_on_import.get_mut(path) {
            for waiting_shader in waiting_shaders.drain(..) {
                // resolve waiting shader import
                let data = self.data.entry(waiting_shader).or_default();
                data.resolved_imports.insert(path.clone(), id);
                // add waiting shader as dependent of this shader
                let data = self.data.entry(id).or_default();
                data.dependents.insert(waiting_shader);
            }
        }

        for import in shader.imports() {
            if let Some(import_id) = self.import_path_shaders.get(import).copied() {
                // resolve import because it is currently available
                let data = self.data.entry(id).or_default();
                data.resolved_imports.insert(import.clone(), import_id);
                // add this shader as a dependent of the import
                let data = self.data.entry(import_id).or_default();
                data.dependents.insert(id);
            } else {
                let waiting = self.waiting_on_import.entry(import.clone()).or_default();
                waiting.push(id);
            }
        }

        #[cfg(feature = "shader_format_wesl")]
        if let Source::Wesl(_) = shader.source {
            if let ShaderImport::AssetPath(path) = shader.import_path() {
                self.asset_paths
                    .insert(wesl::syntax::ModulePath::from_path(path), id);
            }
        }
        self.shaders.insert(id, shader);
        pipelines_to_queue
    }

    fn remove(&mut self, id: AssetId<Shader>) -> Vec<CachedPipelineId> {
        let pipelines_to_queue = self.clear(id);
        if let Some(shader) = self.shaders.remove(&id) {
            self.import_path_shaders.remove(shader.import_path());
        }

        pipelines_to_queue
    }
}

#[cfg(feature = "shader_format_wesl")]
pub struct ShaderResolver<'a> {
    asset_paths: &'a HashMap<wesl::syntax::ModulePath, AssetId<Shader>>,
    shaders: &'a HashMap<AssetId<Shader>, Shader>,
}

#[cfg(feature = "shader_format_wesl")]
impl<'a> ShaderResolver<'a> {
    pub fn new(
        asset_paths: &'a HashMap<wesl::syntax::ModulePath, AssetId<Shader>>,
        shaders: &'a HashMap<AssetId<Shader>, Shader>,
    ) -> Self {
        Self {
            asset_paths,
            shaders,
        }
    }
}

#[cfg(feature = "shader_format_wesl")]
impl<'a> wesl::Resolver for ShaderResolver<'a> {
    fn resolve_source(
        &self,
        module_path: &wesl::syntax::ModulePath,
    ) -> Result<Cow<str>, wesl::ResolveError> {
        let asset_id = self.asset_paths.get(module_path).ok_or_else(|| {
            wesl::ResolveError::ModuleNotFound(module_path.clone(), "Invalid asset id".to_string())
        })?;

        let shader = self.shaders.get(asset_id).unwrap();
        Ok(Cow::Borrowed(shader.source.as_str()))
    }
}

type LayoutCacheKey = (Vec<BindGroupLayoutId>, Vec<PushConstantRange>);
#[derive(Default)]
struct LayoutCache {
    layouts: HashMap<LayoutCacheKey, Arc<WgpuWrapper<PipelineLayout>>>,
}

impl LayoutCache {
    fn get(
        &mut self,
        render_device: &RenderDevice,
        bind_group_layouts: &[BindGroupLayout],
        push_constant_ranges: Vec<PushConstantRange>,
    ) -> Arc<WgpuWrapper<PipelineLayout>> {
        let bind_group_ids = bind_group_layouts.iter().map(BindGroupLayout::id).collect();
        self.layouts
            .entry((bind_group_ids, push_constant_ranges))
            .or_insert_with_key(|(_, push_constant_ranges)| {
                let bind_group_layouts = bind_group_layouts
                    .iter()
                    .map(BindGroupLayout::value)
                    .collect::<Vec<_>>();
                Arc::new(WgpuWrapper::new(render_device.create_pipeline_layout(
                    &PipelineLayoutDescriptor {
                        bind_group_layouts: &bind_group_layouts,
                        push_constant_ranges,
                        ..default()
                    },
                )))
            })
            .clone()
    }
}

/// Cache for render and compute pipelines.
///
/// The cache stores existing render and compute pipelines allocated on the GPU, as well as
/// pipelines pending creation. Pipelines inserted into the cache are identified by a unique ID,
/// which can be used to retrieve the actual GPU object once it's ready. The creation of the GPU
/// pipeline object is deferred to the [`RenderSet::Render`] step, just before the render
/// graph starts being processed, as this requires access to the GPU.
///
/// Note that the cache does not perform automatic deduplication of identical pipelines. It is
/// up to the user not to insert the same pipeline twice to avoid wasting GPU resources.
///
/// [`RenderSet::Render`]: crate::RenderSet::Render
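///
/// # Example
///
/// A rough sketch of the typical flow, assuming a hypothetical `MyPipeline` resource in the
/// render world that stores a prepared [`RenderPipelineDescriptor`] and a
/// [`CachedRenderPipelineId`]:
///
/// ```ignore
/// // At setup time: queue the pipeline and remember its ID.
/// fn queue_my_pipeline(pipeline_cache: Res<PipelineCache>, mut my_pipeline: ResMut<MyPipeline>) {
///     my_pipeline.id = pipeline_cache.queue_render_pipeline(my_pipeline.descriptor.clone());
/// }
///
/// // At render time: the GPU object may or may not be ready yet.
/// fn render_my_pass(pipeline_cache: Res<PipelineCache>, my_pipeline: Res<MyPipeline>) {
///     if let Some(pipeline) = pipeline_cache.get_render_pipeline(my_pipeline.id) {
///         // Bind `pipeline` on a render pass and draw.
///     }
/// }
/// ```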
#[derive(Resource)]
pub struct PipelineCache {
    layout_cache: Arc<Mutex<LayoutCache>>,
    shader_cache: Arc<Mutex<ShaderCache>>,
    device: RenderDevice,
    pipelines: Vec<CachedPipeline>,
    waiting_pipelines: HashSet<CachedPipelineId>,
    new_pipelines: Mutex<Vec<CachedPipeline>>,
    /// If `true`, disables asynchronous pipeline compilation.
    /// This has no effect on macOS, wasm, or without the `multi_threaded` feature.
    synchronous_pipeline_compilation: bool,
}

impl PipelineCache {
    /// Returns an iterator over the pipelines in the pipeline cache.
    pub fn pipelines(&self) -> impl Iterator<Item = &CachedPipeline> {
        self.pipelines.iter()
    }

    /// Returns an iterator over the IDs of all currently waiting pipelines.
    pub fn waiting_pipelines(&self) -> impl Iterator<Item = CachedPipelineId> + '_ {
        self.waiting_pipelines.iter().copied()
    }

    /// Create a new pipeline cache associated with the given render device.
    pub fn new(
        device: RenderDevice,
        render_adapter: RenderAdapter,
        synchronous_pipeline_compilation: bool,
    ) -> Self {
        Self {
            shader_cache: Arc::new(Mutex::new(ShaderCache::new(&device, &render_adapter))),
            device,
            layout_cache: default(),
            waiting_pipelines: default(),
            new_pipelines: default(),
            pipelines: default(),
            synchronous_pipeline_compilation,
        }
    }

    /// Get the state of a cached render pipeline.
    ///
    /// See [`PipelineCache::queue_render_pipeline()`].
    #[inline]
    pub fn get_render_pipeline_state(&self, id: CachedRenderPipelineId) -> &CachedPipelineState {
        // If the pipeline id isn't in `pipelines`, it's queued in `new_pipelines`
        self.pipelines
            .get(id.0)
            .map_or(&CachedPipelineState::Queued, |pipeline| &pipeline.state)
    }

    /// Get the state of a cached compute pipeline.
    ///
    /// See [`PipelineCache::queue_compute_pipeline()`].
    #[inline]
    pub fn get_compute_pipeline_state(&self, id: CachedComputePipelineId) -> &CachedPipelineState {
        // If the pipeline id isn't in `pipelines`, it's queued in `new_pipelines`
        self.pipelines
            .get(id.0)
            .map_or(&CachedPipelineState::Queued, |pipeline| &pipeline.state)
    }

    /// Get the render pipeline descriptor a cached render pipeline was inserted from.
    ///
    /// See [`PipelineCache::queue_render_pipeline()`].
    ///
    /// **Note**: Be careful calling this method. It will panic if called with a pipeline that
    /// has been queued but has not yet been processed by [`PipelineCache::process_queue()`].
    #[inline]
    pub fn get_render_pipeline_descriptor(
        &self,
        id: CachedRenderPipelineId,
    ) -> &RenderPipelineDescriptor {
        match &self.pipelines[id.0].descriptor {
            PipelineDescriptor::RenderPipelineDescriptor(descriptor) => descriptor,
            PipelineDescriptor::ComputePipelineDescriptor(_) => unreachable!(),
        }
    }

    /// Get the compute pipeline descriptor a cached compute pipeline was inserted from.
    ///
    /// See [`PipelineCache::queue_compute_pipeline()`].
    ///
    /// **Note**: Be careful calling this method. It will panic if called with a pipeline that
    /// has been queued but has not yet been processed by [`PipelineCache::process_queue()`].
    #[inline]
    pub fn get_compute_pipeline_descriptor(
        &self,
        id: CachedComputePipelineId,
    ) -> &ComputePipelineDescriptor {
        match &self.pipelines[id.0].descriptor {
            PipelineDescriptor::RenderPipelineDescriptor(_) => unreachable!(),
            PipelineDescriptor::ComputePipelineDescriptor(descriptor) => descriptor,
        }
    }

    /// Try to retrieve a render pipeline GPU object from a cached ID.
    ///
    /// # Returns
    ///
    /// This method returns a successfully created render pipeline if any, or `None` if the pipeline
    /// was not created yet or if there was an error during creation. You can check the actual creation
    /// state with [`PipelineCache::get_render_pipeline_state()`].
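    ///
    /// # Example
    ///
    /// A minimal sketch of using the returned pipeline inside a render graph node or phase;
    /// `pass` is assumed to be a tracked render pass and `id` a previously queued pipeline:
    ///
    /// ```ignore
    /// if let Some(pipeline) = pipeline_cache.get_render_pipeline(id) {
    ///     pass.set_render_pipeline(pipeline);
    ///     pass.draw(0..3, 0..1);
    /// } else {
    ///     // Not ready yet (or creation failed); skip drawing this frame.
    /// }
    /// ```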
    #[inline]
    pub fn get_render_pipeline(&self, id: CachedRenderPipelineId) -> Option<&RenderPipeline> {
        if let CachedPipelineState::Ok(Pipeline::RenderPipeline(pipeline)) =
            &self.pipelines.get(id.0)?.state
        {
            Some(pipeline)
        } else {
            None
        }
    }

    /// Wait for a render pipeline to finish compiling.
    #[inline]
    pub fn block_on_render_pipeline(&mut self, id: CachedRenderPipelineId) {
        if self.pipelines.len() <= id.0 {
            self.process_queue();
        }

        let state = &mut self.pipelines[id.0].state;
        if let CachedPipelineState::Creating(task) = state {
            *state = match bevy_tasks::block_on(task) {
                Ok(p) => CachedPipelineState::Ok(p),
                Err(e) => CachedPipelineState::Err(e),
            };
        }
    }

    /// Try to retrieve a compute pipeline GPU object from a cached ID.
    ///
    /// # Returns
    ///
    /// This method returns a successfully created compute pipeline if any, or `None` if the pipeline
    /// was not created yet or if there was an error during creation. You can check the actual creation
    /// state with [`PipelineCache::get_compute_pipeline_state()`].
    #[inline]
    pub fn get_compute_pipeline(&self, id: CachedComputePipelineId) -> Option<&ComputePipeline> {
        if let CachedPipelineState::Ok(Pipeline::ComputePipeline(pipeline)) =
            &self.pipelines.get(id.0)?.state
        {
            Some(pipeline)
        } else {
            None
        }
    }

    /// Insert a render pipeline into the cache, and queue its creation.
    ///
    /// The pipeline is always inserted and queued for creation. There is no attempt to deduplicate it with
    /// an already cached pipeline.
    ///
    /// # Returns
    ///
    /// This method returns the unique render pipeline ID of the cached pipeline, which can be used to query
    /// the caching state with [`get_render_pipeline_state()`] and to retrieve the created GPU pipeline once
    /// it's ready with [`get_render_pipeline()`].
    ///
    /// [`get_render_pipeline_state()`]: PipelineCache::get_render_pipeline_state
    /// [`get_render_pipeline()`]: PipelineCache::get_render_pipeline
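    ///
    /// # Example
    ///
    /// A rough sketch of queueing from a render-world system, where `MyPipelineId` and
    /// `build_descriptor()` are hypothetical stand-ins for your own resource and descriptor
    /// construction:
    ///
    /// ```ignore
    /// fn prepare_my_pipeline(pipeline_cache: Res<PipelineCache>, mut id: ResMut<MyPipelineId>) {
    ///     if id.0 == CachedRenderPipelineId::INVALID {
    ///         // Queue once; the GPU object is created later when the queue is processed.
    ///         id.0 = pipeline_cache.queue_render_pipeline(build_descriptor());
    ///     }
    /// }
    /// ```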
    pub fn queue_render_pipeline(
        &self,
        descriptor: RenderPipelineDescriptor,
    ) -> CachedRenderPipelineId {
        let mut new_pipelines = self
            .new_pipelines
            .lock()
            .unwrap_or_else(PoisonError::into_inner);
        let id = CachedRenderPipelineId(self.pipelines.len() + new_pipelines.len());
        new_pipelines.push(CachedPipeline {
            descriptor: PipelineDescriptor::RenderPipelineDescriptor(Box::new(descriptor)),
            state: CachedPipelineState::Queued,
        });
        id
    }

    /// Insert a compute pipeline into the cache, and queue its creation.
    ///
    /// The pipeline is always inserted and queued for creation. There is no attempt to deduplicate it with
    /// an already cached pipeline.
    ///
    /// # Returns
    ///
    /// This method returns the unique compute pipeline ID of the cached pipeline, which can be used to query
    /// the caching state with [`get_compute_pipeline_state()`] and to retrieve the created GPU pipeline once
    /// it's ready with [`get_compute_pipeline()`].
    ///
    /// [`get_compute_pipeline_state()`]: PipelineCache::get_compute_pipeline_state
    /// [`get_compute_pipeline()`]: PipelineCache::get_compute_pipeline
    pub fn queue_compute_pipeline(
        &self,
        descriptor: ComputePipelineDescriptor,
    ) -> CachedComputePipelineId {
        let mut new_pipelines = self
            .new_pipelines
            .lock()
            .unwrap_or_else(PoisonError::into_inner);
        let id = CachedComputePipelineId(self.pipelines.len() + new_pipelines.len());
        new_pipelines.push(CachedPipeline {
            descriptor: PipelineDescriptor::ComputePipelineDescriptor(Box::new(descriptor)),
            state: CachedPipelineState::Queued,
        });
        id
    }

    fn set_shader(&mut self, id: AssetId<Shader>, shader: &Shader) {
        let mut shader_cache = self.shader_cache.lock().unwrap();
        let pipelines_to_queue = shader_cache.set_shader(id, shader.clone());
        for cached_pipeline in pipelines_to_queue {
            self.pipelines[cached_pipeline].state = CachedPipelineState::Queued;
            self.waiting_pipelines.insert(cached_pipeline);
        }
    }

    fn remove_shader(&mut self, shader: AssetId<Shader>) {
        let mut shader_cache = self.shader_cache.lock().unwrap();
        let pipelines_to_queue = shader_cache.remove(shader);
        for cached_pipeline in pipelines_to_queue {
            self.pipelines[cached_pipeline].state = CachedPipelineState::Queued;
            self.waiting_pipelines.insert(cached_pipeline);
        }
    }

    fn start_create_render_pipeline(
        &mut self,
        id: CachedPipelineId,
        descriptor: RenderPipelineDescriptor,
    ) -> CachedPipelineState {
        let device = self.device.clone();
        let shader_cache = self.shader_cache.clone();
        let layout_cache = self.layout_cache.clone();

        create_pipeline_task(
            async move {
                let mut shader_cache = shader_cache.lock().unwrap();
                let mut layout_cache = layout_cache.lock().unwrap();

                let vertex_module = match shader_cache.get(
                    &device,
                    id,
                    descriptor.vertex.shader.id(),
                    &descriptor.vertex.shader_defs,
                ) {
                    Ok(module) => module,
                    Err(err) => return Err(err),
                };

                let fragment_module = match &descriptor.fragment {
                    Some(fragment) => {
                        match shader_cache.get(
                            &device,
                            id,
                            fragment.shader.id(),
                            &fragment.shader_defs,
                        ) {
                            Ok(module) => Some(module),
                            Err(err) => return Err(err),
                        }
                    }
                    None => None,
                };

                let layout =
                    if descriptor.layout.is_empty() && descriptor.push_constant_ranges.is_empty() {
                        None
                    } else {
                        Some(layout_cache.get(
                            &device,
                            &descriptor.layout,
                            descriptor.push_constant_ranges.to_vec(),
                        ))
                    };

                drop((shader_cache, layout_cache));

                let vertex_buffer_layouts = descriptor
                    .vertex
                    .buffers
                    .iter()
                    .map(|layout| RawVertexBufferLayout {
                        array_stride: layout.array_stride,
                        attributes: &layout.attributes,
                        step_mode: layout.step_mode,
                    })
                    .collect::<Vec<_>>();

                let fragment_data = descriptor.fragment.as_ref().map(|fragment| {
                    (
                        fragment_module.unwrap(),
                        fragment.entry_point.deref(),
                        fragment.targets.as_slice(),
                    )
                });

                // TODO: Expose the rest of this somehow
                let compilation_options = PipelineCompilationOptions {
                    constants: &default(),
                    zero_initialize_workgroup_memory: descriptor.zero_initialize_workgroup_memory,
                };

                let descriptor = RawRenderPipelineDescriptor {
                    multiview: None,
                    depth_stencil: descriptor.depth_stencil.clone(),
                    label: descriptor.label.as_deref(),
                    layout: layout.as_ref().map(|layout| -> &PipelineLayout { layout }),
                    multisample: descriptor.multisample,
                    primitive: descriptor.primitive,
                    vertex: RawVertexState {
                        buffers: &vertex_buffer_layouts,
                        entry_point: Some(descriptor.vertex.entry_point.deref()),
                        module: &vertex_module,
                        // TODO: Should this be the same as the fragment compilation options?
                        compilation_options: compilation_options.clone(),
                    },
                    fragment: fragment_data
                        .as_ref()
                        .map(|(module, entry_point, targets)| RawFragmentState {
                            entry_point: Some(entry_point),
                            module,
                            targets,
                            // TODO: Should this be the same as the vertex compilation options?
                            compilation_options,
                        }),
                    cache: None,
                };

                Ok(Pipeline::RenderPipeline(
                    device.create_render_pipeline(&descriptor),
                ))
            },
            self.synchronous_pipeline_compilation,
        )
    }

    fn start_create_compute_pipeline(
        &mut self,
        id: CachedPipelineId,
        descriptor: ComputePipelineDescriptor,
    ) -> CachedPipelineState {
        let device = self.device.clone();
        let shader_cache = self.shader_cache.clone();
        let layout_cache = self.layout_cache.clone();

        create_pipeline_task(
            async move {
                let mut shader_cache = shader_cache.lock().unwrap();
                let mut layout_cache = layout_cache.lock().unwrap();

                let compute_module = match shader_cache.get(
                    &device,
                    id,
                    descriptor.shader.id(),
                    &descriptor.shader_defs,
                ) {
                    Ok(module) => module,
                    Err(err) => return Err(err),
                };

                let layout =
                    if descriptor.layout.is_empty() && descriptor.push_constant_ranges.is_empty() {
                        None
                    } else {
                        Some(layout_cache.get(
                            &device,
                            &descriptor.layout,
                            descriptor.push_constant_ranges.to_vec(),
                        ))
                    };

                drop((shader_cache, layout_cache));

                let descriptor = RawComputePipelineDescriptor {
                    label: descriptor.label.as_deref(),
                    layout: layout.as_ref().map(|layout| -> &PipelineLayout { layout }),
                    module: &compute_module,
                    entry_point: Some(&descriptor.entry_point),
                    // TODO: Expose the rest of this somehow
                    compilation_options: PipelineCompilationOptions {
                        constants: &default(),
                        zero_initialize_workgroup_memory: descriptor
                            .zero_initialize_workgroup_memory,
                    },
                    cache: None,
                };

                Ok(Pipeline::ComputePipeline(
                    device.create_compute_pipeline(&descriptor),
                ))
            },
            self.synchronous_pipeline_compilation,
        )
    }

    /// Process the pipeline queue and create all pending pipelines if possible.
    ///
    /// This is generally called automatically during the [`RenderSet::Render`] step, but can
    /// be called manually to force creation at a different time.
    ///
    /// [`RenderSet::Render`]: crate::RenderSet::Render
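    ///
    /// # Example
    ///
    /// A rough sketch of forcing a queued pipeline to be created immediately instead of
    /// waiting for the regular [`RenderSet::Render`] processing (`descriptor` is assumed to
    /// already exist and `pipeline_cache` to be mutably accessible):
    ///
    /// ```ignore
    /// let id = pipeline_cache.queue_render_pipeline(descriptor);
    /// // Move the pipeline out of the internal queue and start creating it ...
    /// pipeline_cache.process_queue();
    /// // ... then block until compilation finishes (successfully or not).
    /// pipeline_cache.block_on_render_pipeline(id);
    /// let pipeline = pipeline_cache.get_render_pipeline(id); // `None` only if creation failed
    /// ```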
    pub fn process_queue(&mut self) {
        let mut waiting_pipelines = mem::take(&mut self.waiting_pipelines);
        let mut pipelines = mem::take(&mut self.pipelines);

        {
            let mut new_pipelines = self
                .new_pipelines
                .lock()
                .unwrap_or_else(PoisonError::into_inner);
            for new_pipeline in new_pipelines.drain(..) {
                let id = pipelines.len();
                pipelines.push(new_pipeline);
                waiting_pipelines.insert(id);
            }
        }

        for id in waiting_pipelines {
            self.process_pipeline(&mut pipelines[id], id);
        }

        self.pipelines = pipelines;
    }

    fn process_pipeline(&mut self, cached_pipeline: &mut CachedPipeline, id: usize) {
        match &mut cached_pipeline.state {
            CachedPipelineState::Queued => {
                cached_pipeline.state = match &cached_pipeline.descriptor {
                    PipelineDescriptor::RenderPipelineDescriptor(descriptor) => {
                        self.start_create_render_pipeline(id, *descriptor.clone())
                    }
                    PipelineDescriptor::ComputePipelineDescriptor(descriptor) => {
                        self.start_create_compute_pipeline(id, *descriptor.clone())
                    }
                };
            }

            CachedPipelineState::Creating(task) => match bevy_tasks::futures::check_ready(task) {
                Some(Ok(pipeline)) => {
                    cached_pipeline.state = CachedPipelineState::Ok(pipeline);
                    return;
                }
                Some(Err(err)) => cached_pipeline.state = CachedPipelineState::Err(err),
                _ => (),
            },

            CachedPipelineState::Err(err) => match err {
                // Retry
                PipelineCacheError::ShaderNotLoaded(_)
                | PipelineCacheError::ShaderImportNotYetAvailable => {
                    cached_pipeline.state = CachedPipelineState::Queued;
                }

                // Shader could not be processed ... retrying won't help
                PipelineCacheError::ProcessShaderError(err) => {
                    let error_detail =
                        err.emit_to_string(&self.shader_cache.lock().unwrap().composer);
                    error!("failed to process shader:\n{}", error_detail);
                    return;
                }
                PipelineCacheError::CreateShaderModule(description) => {
                    error!("failed to create shader module: {}", description);
                    return;
                }
            },

            CachedPipelineState::Ok(_) => return,
        }

        // Retry
        self.waiting_pipelines.insert(id);
    }

    pub(crate) fn process_pipeline_queue_system(mut cache: ResMut<Self>) {
        cache.process_queue();
    }

    pub(crate) fn extract_shaders(
        mut cache: ResMut<Self>,
        shaders: Extract<Res<Assets<Shader>>>,
        mut events: Extract<EventReader<AssetEvent<Shader>>>,
    ) {
        for event in events.read() {
            #[expect(
                clippy::match_same_arms,
                reason = "LoadedWithDependencies is marked as a TODO, so it's likely this will no longer lint soon."
            )]
            match event {
                // PERF: Instead of blocking waiting for the shader cache lock, try again next frame if the lock is currently held
                AssetEvent::Added { id } | AssetEvent::Modified { id } => {
                    if let Some(shader) = shaders.get(*id) {
                        cache.set_shader(*id, shader);
                    }
                }
                AssetEvent::Removed { id } => cache.remove_shader(*id),
                AssetEvent::Unused { .. } => {}
                AssetEvent::LoadedWithDependencies { .. } => {
                    // TODO: handle this
                }
            }
        }
    }
}

#[cfg(all(
    not(target_arch = "wasm32"),
    not(target_os = "macos"),
    feature = "multi_threaded"
))]
fn create_pipeline_task(
    task: impl Future<Output = Result<Pipeline, PipelineCacheError>> + Send + 'static,
    sync: bool,
) -> CachedPipelineState {
    if !sync {
        return CachedPipelineState::Creating(bevy_tasks::AsyncComputeTaskPool::get().spawn(task));
    }

    match futures_lite::future::block_on(task) {
        Ok(pipeline) => CachedPipelineState::Ok(pipeline),
        Err(err) => CachedPipelineState::Err(err),
    }
}

#[cfg(any(
    target_arch = "wasm32",
    target_os = "macos",
    not(feature = "multi_threaded")
))]
fn create_pipeline_task(
    task: impl Future<Output = Result<Pipeline, PipelineCacheError>> + Send + 'static,
    _sync: bool,
) -> CachedPipelineState {
    match futures_lite::future::block_on(task) {
        Ok(pipeline) => CachedPipelineState::Ok(pipeline),
        Err(err) => CachedPipelineState::Err(err),
    }
}

/// Type of error returned by a [`PipelineCache`] when the creation of a GPU pipeline object failed.
#[derive(Error, Debug)]
pub enum PipelineCacheError {
    #[error(
        "Pipeline could not be compiled because the following shader could not be loaded: {0:?}"
    )]
    ShaderNotLoaded(AssetId<Shader>),
    #[error(transparent)]
    ProcessShaderError(#[from] naga_oil::compose::ComposerError),
    #[error("Shader import not yet available.")]
    ShaderImportNotYetAvailable,
    #[error("Could not create shader module: {0}")]
    CreateShaderModule(String),
}

// TODO: This needs to be kept up to date with the capabilities in the `create_validator` function in wgpu-core
// https://github.com/gfx-rs/wgpu/blob/trunk/wgpu-core/src/device/mod.rs#L449
// We can't use the `wgpu-core` function to detect the device's capabilities because `wgpu-core` isn't included in WebGPU builds.
/// Get the device's capabilities for use in `naga_oil`.
fn get_capabilities(features: Features, downlevel: DownlevelFlags) -> Capabilities {
    let mut capabilities = Capabilities::empty();
    capabilities.set(
        Capabilities::PUSH_CONSTANT,
        features.contains(Features::PUSH_CONSTANTS),
    );
    capabilities.set(
        Capabilities::FLOAT64,
        features.contains(Features::SHADER_F64),
    );
    capabilities.set(
        Capabilities::PRIMITIVE_INDEX,
        features.contains(Features::SHADER_PRIMITIVE_INDEX),
    );
    capabilities.set(
        Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
        features.contains(Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
    );
    capabilities.set(
        Capabilities::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING,
        features.contains(Features::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING),
    );
    // TODO: This needs a proper wgpu feature
    capabilities.set(
        Capabilities::SAMPLER_NON_UNIFORM_INDEXING,
        features.contains(Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
    );
    capabilities.set(
        Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS,
        features.contains(Features::TEXTURE_FORMAT_16BIT_NORM),
    );
    capabilities.set(
        Capabilities::MULTIVIEW,
        features.contains(Features::MULTIVIEW),
    );
    capabilities.set(
        Capabilities::EARLY_DEPTH_TEST,
        features.contains(Features::SHADER_EARLY_DEPTH_TEST),
    );
    capabilities.set(
        Capabilities::SHADER_INT64,
        features.contains(Features::SHADER_INT64),
    );
    capabilities.set(
        Capabilities::SHADER_INT64_ATOMIC_MIN_MAX,
        features.intersects(
            Features::SHADER_INT64_ATOMIC_MIN_MAX | Features::SHADER_INT64_ATOMIC_ALL_OPS,
        ),
    );
    capabilities.set(
        Capabilities::SHADER_INT64_ATOMIC_ALL_OPS,
        features.contains(Features::SHADER_INT64_ATOMIC_ALL_OPS),
    );
    capabilities.set(
        Capabilities::MULTISAMPLED_SHADING,
        downlevel.contains(DownlevelFlags::MULTISAMPLED_SHADING),
    );
    capabilities.set(
        Capabilities::DUAL_SOURCE_BLENDING,
        features.contains(Features::DUAL_SOURCE_BLENDING),
    );
    capabilities.set(
        Capabilities::CUBE_ARRAY_TEXTURES,
        downlevel.contains(DownlevelFlags::CUBE_ARRAY_TEXTURES),
    );
    capabilities.set(
        Capabilities::SUBGROUP,
        features.intersects(Features::SUBGROUP | Features::SUBGROUP_VERTEX),
    );
    capabilities.set(
        Capabilities::SUBGROUP_BARRIER,
        features.intersects(Features::SUBGROUP_BARRIER),
    );
    capabilities.set(
        Capabilities::SUBGROUP_VERTEX_STAGE,
        features.contains(Features::SUBGROUP_VERTEX),
    );
    capabilities.set(
        Capabilities::SHADER_FLOAT32_ATOMIC,
        features.contains(Features::SHADER_FLOAT32_ATOMIC),
    );
    capabilities.set(
        Capabilities::TEXTURE_ATOMIC,
        features.contains(Features::TEXTURE_ATOMIC),
    );
    capabilities.set(
        Capabilities::TEXTURE_INT64_ATOMIC,
        features.contains(Features::TEXTURE_INT64_ATOMIC),
    );

    capabilities
}