bevy_pbr/render/
gpu_preprocess.rs

//! GPU mesh preprocessing.
//!
//! This is an optional pass that uses a compute shader to reduce the amount of
//! data that has to be transferred from the CPU to the GPU. When enabled,
//! instead of transferring [`MeshUniform`]s to the GPU, we transfer the smaller
//! [`MeshInputUniform`]s instead and use the GPU to calculate the remaining
//! derived fields in [`MeshUniform`].
8
9use core::num::NonZero;
10
11use bevy_app::{App, Plugin};
12use bevy_asset::{load_internal_asset, Handle};
13use bevy_ecs::{
14    component::Component,
15    entity::Entity,
16    query::{Has, QueryState, Without},
17    schedule::{common_conditions::resource_exists, IntoSystemConfigs as _},
18    system::{lifetimeless::Read, Commands, Res, ResMut, Resource},
19    world::{FromWorld, World},
20};
21use bevy_render::{
22    batching::gpu_preprocessing::{
23        BatchedInstanceBuffers, GpuPreprocessingSupport, IndirectParameters,
24        IndirectParametersBuffer, PreprocessWorkItem,
25    },
26    graph::CameraDriverLabel,
27    render_graph::{Node, NodeRunError, RenderGraph, RenderGraphContext},
28    render_resource::{
29        binding_types::{storage_buffer, storage_buffer_read_only, uniform_buffer},
30        BindGroup, BindGroupEntries, BindGroupLayout, BindingResource, BufferBinding,
31        CachedComputePipelineId, ComputePassDescriptor, ComputePipelineDescriptor,
32        DynamicBindGroupLayoutEntries, PipelineCache, Shader, ShaderStages, ShaderType,
33        SpecializedComputePipeline, SpecializedComputePipelines,
34    },
35    renderer::{RenderContext, RenderDevice, RenderQueue},
36    view::{GpuCulling, ViewUniform, ViewUniformOffset, ViewUniforms},
37    Render, RenderApp, RenderSet,
38};
39use bevy_utils::tracing::warn;
40use bitflags::bitflags;
41use smallvec::{smallvec, SmallVec};
42
43use crate::{
44    graph::NodePbr, MeshCullingData, MeshCullingDataBuffer, MeshInputUniform, MeshUniform,
45};
46
/// The handle to the `mesh_preprocess.wgsl` compute shader.
///
/// This is a weak handle built from a fixed ID. The plugin's `build` method
/// registers the shader source under this handle via `load_internal_asset!`,
/// and the pipeline descriptor produced by `specialize` refers to it.
pub const MESH_PREPROCESS_SHADER_HANDLE: Handle<Shader> =
    Handle::weak_from_u128(16991728318640779533);
50
/// The GPU workgroup size.
///
/// The preprocessing node divides each view's work item count by this value
/// (rounding up) to obtain the number of workgroups to dispatch.
///
/// NOTE(review): presumably this must match the workgroup size declared in
/// `mesh_preprocess.wgsl` — confirm against the shader.
const WORKGROUP_SIZE: usize = 64;
53
/// A plugin that builds mesh uniforms on GPU.
///
/// This will only be added if the platform supports compute shaders (e.g. not
/// on WebGL 2).
///
/// `build` always registers the preprocessing shader. `finish` adds the render
/// graph node, resources, and systems only when GPU instance buffer building is
/// both requested and supported.
pub struct GpuMeshPreprocessPlugin {
    /// Whether we're building [`MeshUniform`]s on GPU.
    ///
    /// This requires compute shader support and so will be forcibly disabled if
    /// the platform doesn't support those.
    ///
    /// When this is `false`, `finish` returns early and the preprocessing pass
    /// is never added to the render graph.
    pub use_gpu_instance_buffer_builder: bool,
}
65
/// The render node for the mesh uniform building pass.
pub struct GpuPreprocessNode {
    /// Query over the views to preprocess.
    ///
    /// For each matching view this yields the entity, its
    /// [`PreprocessBindGroup`], its [`ViewUniformOffset`], and whether the
    /// view has the [`GpuCulling`] component. Views carrying
    /// [`SkipGpuPreprocess`] are filtered out.
    view_query: QueryState<
        (
            Entity,
            Read<PreprocessBindGroup>,
            Read<ViewUniformOffset>,
            Has<GpuCulling>,
        ),
        Without<SkipGpuPreprocess>,
    >,
}
78
/// The compute shader pipelines for the mesh uniform building pass.
///
/// Both variants share the same shader and differ only in their bind group
/// layout and in the [`PreprocessPipelineKey`] they are specialized with.
#[derive(Resource)]
pub struct PreprocessPipelines {
    /// The pipeline used for CPU culling. This pipeline doesn't populate
    /// indirect parameters.
    ///
    /// Specialized with an empty [`PreprocessPipelineKey`].
    pub direct: PreprocessPipeline,
    /// The pipeline used for GPU culling. This pipeline populates indirect
    /// parameters.
    ///
    /// Specialized with [`PreprocessPipelineKey::GPU_CULLING`].
    pub gpu_culling: PreprocessPipeline,
}
89
/// The pipeline for the GPU mesh preprocessing shader.
///
/// One instance exists per shader variant; see [`PreprocessPipelines`].
pub struct PreprocessPipeline {
    /// The bind group layout for the compute shader.
    pub bind_group_layout: BindGroupLayout,
    /// The pipeline ID for the compute shader.
    ///
    /// This gets filled in `prepare_preprocess_pipelines`. It starts out as
    /// `None`, and once set it is not specialized again.
    pub pipeline_id: Option<CachedComputePipelineId>,
}
99
bitflags! {
    /// Specifies variants of the mesh preprocessing shader.
    #[derive(Clone, Copy, PartialEq, Eq, Hash)]
    pub struct PreprocessPipelineKey: u8 {
        /// Whether GPU culling is in use.
        ///
        /// When this flag is set, pipeline specialization adds the `INDIRECT`
        /// and `FRUSTUM_CULLING` shader defs.
        const GPU_CULLING = 1;
    }
}
110
/// The compute shader bind group for the mesh uniform building pass.
///
/// This goes on the view. It is inserted by `prepare_preprocess_bind_groups`
/// and bound at group index 0 by [`GpuPreprocessNode`] when it dispatches the
/// preprocessing shader for that view.
#[derive(Component, Clone)]
pub struct PreprocessBindGroup(BindGroup);
116
/// Stops the [`GpuPreprocessNode`] from running the preprocessing pass for
/// this view.
///
/// Views with this component are excluded from the node's view query. This is
/// useful to avoid duplicating effort if the bind group is shared between
/// views.
#[derive(Component)]
pub struct SkipGpuPreprocess;
121
122impl Plugin for GpuMeshPreprocessPlugin {
123    fn build(&self, app: &mut App) {
124        load_internal_asset!(
125            app,
126            MESH_PREPROCESS_SHADER_HANDLE,
127            "mesh_preprocess.wgsl",
128            Shader::from_wgsl
129        );
130    }
131
132    fn finish(&self, app: &mut App) {
133        let Some(render_app) = app.get_sub_app_mut(RenderApp) else {
134            return;
135        };
136
137        // This plugin does nothing if GPU instance buffer building isn't in
138        // use.
139        let gpu_preprocessing_support = render_app.world().resource::<GpuPreprocessingSupport>();
140        if !self.use_gpu_instance_buffer_builder
141            || *gpu_preprocessing_support == GpuPreprocessingSupport::None
142        {
143            return;
144        }
145
146        // Stitch the node in.
147        let gpu_preprocess_node = GpuPreprocessNode::from_world(render_app.world_mut());
148        let mut render_graph = render_app.world_mut().resource_mut::<RenderGraph>();
149        render_graph.add_node(NodePbr::GpuPreprocess, gpu_preprocess_node);
150        render_graph.add_node_edge(NodePbr::GpuPreprocess, CameraDriverLabel);
151
152        render_app
153            .init_resource::<PreprocessPipelines>()
154            .init_resource::<SpecializedComputePipelines<PreprocessPipeline>>()
155            .add_systems(
156                Render,
157                (
158                    prepare_preprocess_pipelines.in_set(RenderSet::Prepare),
159                    prepare_preprocess_bind_groups
160                        .run_if(
161                            resource_exists::<BatchedInstanceBuffers<MeshUniform, MeshInputUniform>>,
162                        )
163                        .in_set(RenderSet::PrepareBindGroups),
164                    write_mesh_culling_data_buffer.in_set(RenderSet::PrepareResourcesFlush),
165                )
166            );
167    }
168}
169
170impl FromWorld for GpuPreprocessNode {
171    fn from_world(world: &mut World) -> Self {
172        Self {
173            view_query: QueryState::new(world),
174        }
175    }
176}
177
178impl Node for GpuPreprocessNode {
179    fn update(&mut self, world: &mut World) {
180        self.view_query.update_archetypes(world);
181    }
182
183    fn run<'w>(
184        &self,
185        _: &mut RenderGraphContext,
186        render_context: &mut RenderContext<'w>,
187        world: &'w World,
188    ) -> Result<(), NodeRunError> {
189        // Grab the [`BatchedInstanceBuffers`].
190        let BatchedInstanceBuffers {
191            work_item_buffers: ref index_buffers,
192            ..
193        } = world.resource::<BatchedInstanceBuffers<MeshUniform, MeshInputUniform>>();
194
195        let pipeline_cache = world.resource::<PipelineCache>();
196        let preprocess_pipelines = world.resource::<PreprocessPipelines>();
197
198        let mut compute_pass =
199            render_context
200                .command_encoder()
201                .begin_compute_pass(&ComputePassDescriptor {
202                    label: Some("mesh preprocessing"),
203                    timestamp_writes: None,
204                });
205
206        // Run the compute passes.
207        for (view, bind_group, view_uniform_offset, gpu_culling) in
208            self.view_query.iter_manual(world)
209        {
210            // Grab the index buffer for this view.
211            let Some(index_buffer) = index_buffers.get(&view) else {
212                warn!("The preprocessing index buffer wasn't present");
213                continue;
214            };
215
216            // Select the right pipeline, depending on whether GPU culling is in
217            // use.
218            let maybe_pipeline_id = if gpu_culling {
219                preprocess_pipelines.gpu_culling.pipeline_id
220            } else {
221                preprocess_pipelines.direct.pipeline_id
222            };
223
224            // Fetch the pipeline.
225            let Some(preprocess_pipeline_id) = maybe_pipeline_id else {
226                warn!("The build mesh uniforms pipeline wasn't ready");
227                return Ok(());
228            };
229
230            let Some(preprocess_pipeline) =
231                pipeline_cache.get_compute_pipeline(preprocess_pipeline_id)
232            else {
233                // This will happen while the pipeline is being compiled and is fine.
234                return Ok(());
235            };
236
237            compute_pass.set_pipeline(preprocess_pipeline);
238
239            let mut dynamic_offsets: SmallVec<[u32; 1]> = smallvec![];
240            if gpu_culling {
241                dynamic_offsets.push(view_uniform_offset.offset);
242            }
243            compute_pass.set_bind_group(0, &bind_group.0, &dynamic_offsets);
244
245            let workgroup_count = index_buffer.buffer.len().div_ceil(WORKGROUP_SIZE);
246            compute_pass.dispatch_workgroups(workgroup_count as u32, 1, 1);
247        }
248
249        Ok(())
250    }
251}
252
253impl PreprocessPipelines {
254    pub(crate) fn pipelines_are_loaded(&self, pipeline_cache: &PipelineCache) -> bool {
255        self.direct.is_loaded(pipeline_cache) && self.gpu_culling.is_loaded(pipeline_cache)
256    }
257}
258
259impl PreprocessPipeline {
260    fn is_loaded(&self, pipeline_cache: &PipelineCache) -> bool {
261        self.pipeline_id
262            .is_some_and(|pipeline_id| pipeline_cache.get_compute_pipeline(pipeline_id).is_some())
263    }
264}
265
266impl SpecializedComputePipeline for PreprocessPipeline {
267    type Key = PreprocessPipelineKey;
268
269    fn specialize(&self, key: Self::Key) -> ComputePipelineDescriptor {
270        let mut shader_defs = vec![];
271        if key.contains(PreprocessPipelineKey::GPU_CULLING) {
272            shader_defs.push("INDIRECT".into());
273            shader_defs.push("FRUSTUM_CULLING".into());
274        }
275
276        ComputePipelineDescriptor {
277            label: Some(
278                format!(
279                    "mesh preprocessing ({})",
280                    if key.contains(PreprocessPipelineKey::GPU_CULLING) {
281                        "GPU culling"
282                    } else {
283                        "direct"
284                    }
285                )
286                .into(),
287            ),
288            layout: vec![self.bind_group_layout.clone()],
289            push_constant_ranges: vec![],
290            shader: MESH_PREPROCESS_SHADER_HANDLE,
291            shader_defs,
292            entry_point: "main".into(),
293            zero_initialize_workgroup_memory: false,
294        }
295    }
296}
297
298impl FromWorld for PreprocessPipelines {
299    fn from_world(world: &mut World) -> Self {
300        let render_device = world.resource::<RenderDevice>();
301
302        // GPU culling bind group parameters are a superset of those in the CPU
303        // culling (direct) shader.
304        let direct_bind_group_layout_entries = preprocess_direct_bind_group_layout_entries();
305        let gpu_culling_bind_group_layout_entries = preprocess_direct_bind_group_layout_entries()
306            .extend_sequential((
307                // `indirect_parameters`
308                storage_buffer::<IndirectParameters>(/* has_dynamic_offset= */ false),
309                // `mesh_culling_data`
310                storage_buffer_read_only::<MeshCullingData>(/* has_dynamic_offset= */ false),
311                // `view`
312                uniform_buffer::<ViewUniform>(/* has_dynamic_offset= */ true),
313            ));
314
315        let direct_bind_group_layout = render_device.create_bind_group_layout(
316            "build mesh uniforms direct bind group layout",
317            &direct_bind_group_layout_entries,
318        );
319        let gpu_culling_bind_group_layout = render_device.create_bind_group_layout(
320            "build mesh uniforms GPU culling bind group layout",
321            &gpu_culling_bind_group_layout_entries,
322        );
323
324        PreprocessPipelines {
325            direct: PreprocessPipeline {
326                bind_group_layout: direct_bind_group_layout,
327                pipeline_id: None,
328            },
329            gpu_culling: PreprocessPipeline {
330                bind_group_layout: gpu_culling_bind_group_layout,
331                pipeline_id: None,
332            },
333        }
334    }
335}
336
337fn preprocess_direct_bind_group_layout_entries() -> DynamicBindGroupLayoutEntries {
338    DynamicBindGroupLayoutEntries::sequential(
339        ShaderStages::COMPUTE,
340        (
341            // `current_input`
342            storage_buffer_read_only::<MeshInputUniform>(false),
343            // `previous_input`
344            storage_buffer_read_only::<MeshInputUniform>(false),
345            // `indices`
346            storage_buffer_read_only::<PreprocessWorkItem>(false),
347            // `output`
348            storage_buffer::<MeshUniform>(false),
349        ),
350    )
351}
352
353/// A system that specializes the `mesh_preprocess.wgsl` pipelines if necessary.
354pub fn prepare_preprocess_pipelines(
355    pipeline_cache: Res<PipelineCache>,
356    mut pipelines: ResMut<SpecializedComputePipelines<PreprocessPipeline>>,
357    mut preprocess_pipelines: ResMut<PreprocessPipelines>,
358) {
359    preprocess_pipelines.direct.prepare(
360        &pipeline_cache,
361        &mut pipelines,
362        PreprocessPipelineKey::empty(),
363    );
364    preprocess_pipelines.gpu_culling.prepare(
365        &pipeline_cache,
366        &mut pipelines,
367        PreprocessPipelineKey::GPU_CULLING,
368    );
369}
370
371impl PreprocessPipeline {
372    fn prepare(
373        &mut self,
374        pipeline_cache: &PipelineCache,
375        pipelines: &mut SpecializedComputePipelines<PreprocessPipeline>,
376        key: PreprocessPipelineKey,
377    ) {
378        if self.pipeline_id.is_some() {
379            return;
380        }
381
382        let preprocess_pipeline_id = pipelines.specialize(pipeline_cache, self, key);
383        self.pipeline_id = Some(preprocess_pipeline_id);
384    }
385}
386
387/// A system that attaches the mesh uniform buffers to the bind groups for the
388/// variants of the mesh preprocessing compute shader.
389pub fn prepare_preprocess_bind_groups(
390    mut commands: Commands,
391    render_device: Res<RenderDevice>,
392    batched_instance_buffers: Res<BatchedInstanceBuffers<MeshUniform, MeshInputUniform>>,
393    indirect_parameters_buffer: Res<IndirectParametersBuffer>,
394    mesh_culling_data_buffer: Res<MeshCullingDataBuffer>,
395    view_uniforms: Res<ViewUniforms>,
396    pipelines: Res<PreprocessPipelines>,
397) {
398    // Grab the `BatchedInstanceBuffers`.
399    let BatchedInstanceBuffers {
400        data_buffer: ref data_buffer_vec,
401        work_item_buffers: ref index_buffers,
402        current_input_buffer: ref current_input_buffer_vec,
403        previous_input_buffer: ref previous_input_buffer_vec,
404    } = batched_instance_buffers.into_inner();
405
406    let (Some(current_input_buffer), Some(previous_input_buffer), Some(data_buffer)) = (
407        current_input_buffer_vec.buffer(),
408        previous_input_buffer_vec.buffer(),
409        data_buffer_vec.buffer(),
410    ) else {
411        return;
412    };
413
414    for (view, index_buffer_vec) in index_buffers {
415        let Some(index_buffer) = index_buffer_vec.buffer.buffer() else {
416            continue;
417        };
418
419        // Don't use `as_entire_binding()` here; the shader reads the array
420        // length and the underlying buffer may be longer than the actual size
421        // of the vector.
422        let index_buffer_size = NonZero::<u64>::try_from(
423            index_buffer_vec.buffer.len() as u64 * u64::from(PreprocessWorkItem::min_size()),
424        )
425        .ok();
426
427        let bind_group = if index_buffer_vec.gpu_culling {
428            let (
429                Some(indirect_parameters_buffer),
430                Some(mesh_culling_data_buffer),
431                Some(view_uniforms_binding),
432            ) = (
433                indirect_parameters_buffer.buffer(),
434                mesh_culling_data_buffer.buffer(),
435                view_uniforms.uniforms.binding(),
436            )
437            else {
438                continue;
439            };
440
441            PreprocessBindGroup(render_device.create_bind_group(
442                "preprocess_gpu_culling_bind_group",
443                &pipelines.gpu_culling.bind_group_layout,
444                &BindGroupEntries::sequential((
445                    current_input_buffer.as_entire_binding(),
446                    previous_input_buffer.as_entire_binding(),
447                    BindingResource::Buffer(BufferBinding {
448                        buffer: index_buffer,
449                        offset: 0,
450                        size: index_buffer_size,
451                    }),
452                    data_buffer.as_entire_binding(),
453                    indirect_parameters_buffer.as_entire_binding(),
454                    mesh_culling_data_buffer.as_entire_binding(),
455                    view_uniforms_binding,
456                )),
457            ))
458        } else {
459            PreprocessBindGroup(render_device.create_bind_group(
460                "preprocess_direct_bind_group",
461                &pipelines.direct.bind_group_layout,
462                &BindGroupEntries::sequential((
463                    current_input_buffer.as_entire_binding(),
464                    previous_input_buffer.as_entire_binding(),
465                    BindingResource::Buffer(BufferBinding {
466                        buffer: index_buffer,
467                        offset: 0,
468                        size: index_buffer_size,
469                    }),
470                    data_buffer.as_entire_binding(),
471                )),
472            ))
473        };
474
475        commands.entity(*view).insert(bind_group);
476    }
477}
478
479/// Writes the information needed to do GPU mesh culling to the GPU.
480pub fn write_mesh_culling_data_buffer(
481    render_device: Res<RenderDevice>,
482    render_queue: Res<RenderQueue>,
483    mut mesh_culling_data_buffer: ResMut<MeshCullingDataBuffer>,
484) {
485    mesh_culling_data_buffer.write_buffer(&render_device, &render_queue);
486    mesh_culling_data_buffer.clear();
487}