bevy_pbr/
extended_material.rs

1use alloc::borrow::Cow;
2
3use bevy_asset::{Asset, Handle};
4use bevy_ecs::system::SystemParamItem;
5use bevy_platform::{collections::HashSet, hash::FixedHasher};
6use bevy_reflect::{impl_type_path, Reflect};
7use bevy_render::{
8    alpha::AlphaMode,
9    mesh::MeshVertexBufferLayoutRef,
10    render_resource::{
11        AsBindGroup, AsBindGroupError, BindGroupLayout, BindGroupLayoutEntry, BindlessDescriptor,
12        BindlessResourceType, BindlessSlabResourceLimit, RenderPipelineDescriptor, Shader,
13        ShaderRef, SpecializedMeshPipelineError, UnpreparedBindGroup,
14    },
15    renderer::RenderDevice,
16};
17
18use crate::{Material, MaterialPipeline, MaterialPipelineKey, MeshPipeline, MeshPipelineKey};
19
20pub struct MaterialExtensionPipeline {
21    pub mesh_pipeline: MeshPipeline,
22    pub material_layout: BindGroupLayout,
23    pub vertex_shader: Option<Handle<Shader>>,
24    pub fragment_shader: Option<Handle<Shader>>,
25    pub bindless: bool,
26}
27
28pub struct MaterialExtensionKey<E: MaterialExtension> {
29    pub mesh_key: MeshPipelineKey,
30    pub bind_group_data: E::Data,
31}
32
33/// A subset of the `Material` trait for defining extensions to a base `Material`, such as the builtin `StandardMaterial`.
34///
35/// A user type implementing the trait should be used as the `E` generic param in an `ExtendedMaterial` struct.
36pub trait MaterialExtension: Asset + AsBindGroup + Clone + Sized {
37    /// Returns this material's vertex shader. If [`ShaderRef::Default`] is returned, the base material mesh vertex shader
38    /// will be used.
39    fn vertex_shader() -> ShaderRef {
40        ShaderRef::Default
41    }
42
43    /// Returns this material's fragment shader. If [`ShaderRef::Default`] is returned, the base material mesh fragment shader
44    /// will be used.
45    fn fragment_shader() -> ShaderRef {
46        ShaderRef::Default
47    }
48
49    // Returns this material’s AlphaMode. If None is returned, the base material alpha mode will be used.
50    fn alpha_mode() -> Option<AlphaMode> {
51        None
52    }
53
54    /// Returns this material's prepass vertex shader. If [`ShaderRef::Default`] is returned, the base material prepass vertex shader
55    /// will be used.
56    fn prepass_vertex_shader() -> ShaderRef {
57        ShaderRef::Default
58    }
59
60    /// Returns this material's prepass fragment shader. If [`ShaderRef::Default`] is returned, the base material prepass fragment shader
61    /// will be used.
62    fn prepass_fragment_shader() -> ShaderRef {
63        ShaderRef::Default
64    }
65
66    /// Returns this material's deferred vertex shader. If [`ShaderRef::Default`] is returned, the base material deferred vertex shader
67    /// will be used.
68    fn deferred_vertex_shader() -> ShaderRef {
69        ShaderRef::Default
70    }
71
72    /// Returns this material's prepass fragment shader. If [`ShaderRef::Default`] is returned, the base material deferred fragment shader
73    /// will be used.
74    fn deferred_fragment_shader() -> ShaderRef {
75        ShaderRef::Default
76    }
77
78    /// Returns this material's [`crate::meshlet::MeshletMesh`] fragment shader. If [`ShaderRef::Default`] is returned,
79    /// the default meshlet mesh fragment shader will be used.
80    #[cfg(feature = "meshlet")]
81    fn meshlet_mesh_fragment_shader() -> ShaderRef {
82        ShaderRef::Default
83    }
84
85    /// Returns this material's [`crate::meshlet::MeshletMesh`] prepass fragment shader. If [`ShaderRef::Default`] is returned,
86    /// the default meshlet mesh prepass fragment shader will be used.
87    #[cfg(feature = "meshlet")]
88    fn meshlet_mesh_prepass_fragment_shader() -> ShaderRef {
89        ShaderRef::Default
90    }
91
92    /// Returns this material's [`crate::meshlet::MeshletMesh`] deferred fragment shader. If [`ShaderRef::Default`] is returned,
93    /// the default meshlet mesh deferred fragment shader will be used.
94    #[cfg(feature = "meshlet")]
95    fn meshlet_mesh_deferred_fragment_shader() -> ShaderRef {
96        ShaderRef::Default
97    }
98
99    /// Customizes the default [`RenderPipelineDescriptor`] for a specific entity using the entity's
100    /// [`MaterialPipelineKey`] and [`MeshVertexBufferLayoutRef`] as input.
101    /// Specialization for the base material is applied before this function is called.
102    #[expect(
103        unused_variables,
104        reason = "The parameters here are intentionally unused by the default implementation; however, putting underscores here will result in the underscores being copied by rust-analyzer's tab completion."
105    )]
106    #[inline]
107    fn specialize(
108        pipeline: &MaterialExtensionPipeline,
109        descriptor: &mut RenderPipelineDescriptor,
110        layout: &MeshVertexBufferLayoutRef,
111        key: MaterialExtensionKey<Self>,
112    ) -> Result<(), SpecializedMeshPipelineError> {
113        Ok(())
114    }
115}
116
117/// A material that extends a base [`Material`] with additional shaders and data.
118///
119/// The data from both materials will be combined and made available to the shader
120/// so that shader functions built for the base material (and referencing the base material
121/// bindings) will work as expected, and custom alterations based on custom data can also be used.
122///
123/// If the extension `E` returns a non-default result from `vertex_shader()` it will be used in place of the base
124/// material's vertex shader.
125///
126/// If the extension `E` returns a non-default result from `fragment_shader()` it will be used in place of the base
127/// fragment shader.
128///
129/// When used with `StandardMaterial` as the base, all the standard material fields are
130/// present, so the `pbr_fragment` shader functions can be called from the extension shader (see
131/// the `extended_material` example).
132#[derive(Asset, Clone, Debug, Reflect)]
133#[reflect(type_path = false)]
134#[reflect(Clone)]
135pub struct ExtendedMaterial<B: Material, E: MaterialExtension> {
136    pub base: B,
137    pub extension: E,
138}
139
140impl<B, E> Default for ExtendedMaterial<B, E>
141where
142    B: Material + Default,
143    E: MaterialExtension + Default,
144{
145    fn default() -> Self {
146        Self {
147            base: B::default(),
148            extension: E::default(),
149        }
150    }
151}
152
153// We don't use the `TypePath` derive here due to a bug where `#[reflect(type_path = false)]`
154// causes the `TypePath` derive to not generate an implementation.
155impl_type_path!((in bevy_pbr::extended_material) ExtendedMaterial<B: Material, E: MaterialExtension>);
156
157impl<B: Material, E: MaterialExtension> AsBindGroup for ExtendedMaterial<B, E> {
158    type Data = (<B as AsBindGroup>::Data, <E as AsBindGroup>::Data);
159    type Param = (<B as AsBindGroup>::Param, <E as AsBindGroup>::Param);
160
161    fn bindless_slot_count() -> Option<BindlessSlabResourceLimit> {
162        // We only enable bindless if both the base material and its extension
163        // are bindless. If we do enable bindless, we choose the smaller of the
164        // two slab size limits.
165        match (B::bindless_slot_count()?, E::bindless_slot_count()?) {
166            (BindlessSlabResourceLimit::Auto, BindlessSlabResourceLimit::Auto) => {
167                Some(BindlessSlabResourceLimit::Auto)
168            }
169            (BindlessSlabResourceLimit::Auto, BindlessSlabResourceLimit::Custom(limit))
170            | (BindlessSlabResourceLimit::Custom(limit), BindlessSlabResourceLimit::Auto) => {
171                Some(BindlessSlabResourceLimit::Custom(limit))
172            }
173            (
174                BindlessSlabResourceLimit::Custom(base_limit),
175                BindlessSlabResourceLimit::Custom(extended_limit),
176            ) => Some(BindlessSlabResourceLimit::Custom(
177                base_limit.min(extended_limit),
178            )),
179        }
180    }
181
182    fn unprepared_bind_group(
183        &self,
184        layout: &BindGroupLayout,
185        render_device: &RenderDevice,
186        (base_param, extended_param): &mut SystemParamItem<'_, '_, Self::Param>,
187        mut force_non_bindless: bool,
188    ) -> Result<UnpreparedBindGroup<Self::Data>, AsBindGroupError> {
189        force_non_bindless = force_non_bindless || Self::bindless_slot_count().is_none();
190
191        // add together the bindings of the base material and the user material
192        let UnpreparedBindGroup {
193            mut bindings,
194            data: base_data,
195        } = B::unprepared_bind_group(
196            &self.base,
197            layout,
198            render_device,
199            base_param,
200            force_non_bindless,
201        )?;
202        let extended_bindgroup = E::unprepared_bind_group(
203            &self.extension,
204            layout,
205            render_device,
206            extended_param,
207            force_non_bindless,
208        )?;
209
210        bindings.extend(extended_bindgroup.bindings.0);
211
212        Ok(UnpreparedBindGroup {
213            bindings,
214            data: (base_data, extended_bindgroup.data),
215        })
216    }
217
218    fn bind_group_layout_entries(
219        render_device: &RenderDevice,
220        mut force_non_bindless: bool,
221    ) -> Vec<BindGroupLayoutEntry>
222    where
223        Self: Sized,
224    {
225        force_non_bindless = force_non_bindless || Self::bindless_slot_count().is_none();
226
227        // Add together the bindings of the standard material and the user
228        // material, skipping duplicate bindings. Duplicate bindings will occur
229        // when bindless mode is on, because of the common bindless resource
230        // arrays, and we need to eliminate the duplicates or `wgpu` will
231        // complain.
232        let mut entries = vec![];
233        let mut seen_bindings = HashSet::<_>::with_hasher(FixedHasher);
234        for entry in B::bind_group_layout_entries(render_device, force_non_bindless)
235            .into_iter()
236            .chain(E::bind_group_layout_entries(render_device, force_non_bindless).into_iter())
237        {
238            if seen_bindings.insert(entry.binding) {
239                entries.push(entry);
240            }
241        }
242        entries
243    }
244
245    fn bindless_descriptor() -> Option<BindlessDescriptor> {
246        // We're going to combine the two bindless descriptors.
247        let base_bindless_descriptor = B::bindless_descriptor()?;
248        let extended_bindless_descriptor = E::bindless_descriptor()?;
249
250        // Combining the buffers and index tables is straightforward.
251
252        let mut buffers = base_bindless_descriptor.buffers.to_vec();
253        let mut index_tables = base_bindless_descriptor.index_tables.to_vec();
254
255        buffers.extend(extended_bindless_descriptor.buffers.iter().cloned());
256        index_tables.extend(extended_bindless_descriptor.index_tables.iter().cloned());
257
258        // Combining the resources is a little trickier because the resource
259        // array is indexed by bindless index, so we have to merge the two
260        // arrays, not just concatenate them.
261        let max_bindless_index = base_bindless_descriptor
262            .resources
263            .len()
264            .max(extended_bindless_descriptor.resources.len());
265        let mut resources = Vec::with_capacity(max_bindless_index);
266        for bindless_index in 0..max_bindless_index {
267            // In the event of a conflicting bindless index, we choose the
268            // base's binding.
269            match base_bindless_descriptor.resources.get(bindless_index) {
270                None | Some(&BindlessResourceType::None) => resources.push(
271                    extended_bindless_descriptor
272                        .resources
273                        .get(bindless_index)
274                        .copied()
275                        .unwrap_or(BindlessResourceType::None),
276                ),
277                Some(&resource_type) => resources.push(resource_type),
278            }
279        }
280
281        Some(BindlessDescriptor {
282            resources: Cow::Owned(resources),
283            buffers: Cow::Owned(buffers),
284            index_tables: Cow::Owned(index_tables),
285        })
286    }
287}
288
289impl<B: Material, E: MaterialExtension> Material for ExtendedMaterial<B, E> {
290    fn vertex_shader() -> ShaderRef {
291        match E::vertex_shader() {
292            ShaderRef::Default => B::vertex_shader(),
293            specified => specified,
294        }
295    }
296
297    fn fragment_shader() -> ShaderRef {
298        match E::fragment_shader() {
299            ShaderRef::Default => B::fragment_shader(),
300            specified => specified,
301        }
302    }
303
304    fn alpha_mode(&self) -> AlphaMode {
305        match E::alpha_mode() {
306            Some(specified) => specified,
307            None => B::alpha_mode(&self.base),
308        }
309    }
310
311    fn opaque_render_method(&self) -> crate::OpaqueRendererMethod {
312        B::opaque_render_method(&self.base)
313    }
314
315    fn depth_bias(&self) -> f32 {
316        B::depth_bias(&self.base)
317    }
318
319    fn reads_view_transmission_texture(&self) -> bool {
320        B::reads_view_transmission_texture(&self.base)
321    }
322
323    fn prepass_vertex_shader() -> ShaderRef {
324        match E::prepass_vertex_shader() {
325            ShaderRef::Default => B::prepass_vertex_shader(),
326            specified => specified,
327        }
328    }
329
330    fn prepass_fragment_shader() -> ShaderRef {
331        match E::prepass_fragment_shader() {
332            ShaderRef::Default => B::prepass_fragment_shader(),
333            specified => specified,
334        }
335    }
336
337    fn deferred_vertex_shader() -> ShaderRef {
338        match E::deferred_vertex_shader() {
339            ShaderRef::Default => B::deferred_vertex_shader(),
340            specified => specified,
341        }
342    }
343
344    fn deferred_fragment_shader() -> ShaderRef {
345        match E::deferred_fragment_shader() {
346            ShaderRef::Default => B::deferred_fragment_shader(),
347            specified => specified,
348        }
349    }
350
351    #[cfg(feature = "meshlet")]
352    fn meshlet_mesh_fragment_shader() -> ShaderRef {
353        match E::meshlet_mesh_fragment_shader() {
354            ShaderRef::Default => B::meshlet_mesh_fragment_shader(),
355            specified => specified,
356        }
357    }
358
359    #[cfg(feature = "meshlet")]
360    fn meshlet_mesh_prepass_fragment_shader() -> ShaderRef {
361        match E::meshlet_mesh_prepass_fragment_shader() {
362            ShaderRef::Default => B::meshlet_mesh_prepass_fragment_shader(),
363            specified => specified,
364        }
365    }
366
367    #[cfg(feature = "meshlet")]
368    fn meshlet_mesh_deferred_fragment_shader() -> ShaderRef {
369        match E::meshlet_mesh_deferred_fragment_shader() {
370            ShaderRef::Default => B::meshlet_mesh_deferred_fragment_shader(),
371            specified => specified,
372        }
373    }
374
375    fn specialize(
376        pipeline: &MaterialPipeline<Self>,
377        descriptor: &mut RenderPipelineDescriptor,
378        layout: &MeshVertexBufferLayoutRef,
379        key: MaterialPipelineKey<Self>,
380    ) -> Result<(), SpecializedMeshPipelineError> {
381        // Call the base material's specialize function
382        let MaterialPipeline::<Self> {
383            mesh_pipeline,
384            material_layout,
385            vertex_shader,
386            fragment_shader,
387            bindless,
388            ..
389        } = pipeline.clone();
390        let base_pipeline = MaterialPipeline::<B> {
391            mesh_pipeline,
392            material_layout,
393            vertex_shader,
394            fragment_shader,
395            bindless,
396            marker: Default::default(),
397        };
398        let base_key = MaterialPipelineKey::<B> {
399            mesh_key: key.mesh_key,
400            bind_group_data: key.bind_group_data.0,
401        };
402        B::specialize(&base_pipeline, descriptor, layout, base_key)?;
403
404        // Call the extended material's specialize function afterwards
405        let MaterialPipeline::<Self> {
406            mesh_pipeline,
407            material_layout,
408            vertex_shader,
409            fragment_shader,
410            bindless,
411            ..
412        } = pipeline.clone();
413
414        E::specialize(
415            &MaterialExtensionPipeline {
416                mesh_pipeline,
417                material_layout,
418                vertex_shader,
419                fragment_shader,
420                bindless,
421            },
422            descriptor,
423            layout,
424            MaterialExtensionKey {
425                mesh_key: key.mesh_key,
426                bind_group_data: key.bind_group_data.1,
427            },
428        )
429    }
430}