bevy_pbr/
extended_material.rs

1use alloc::borrow::Cow;
2
3use bevy_asset::Asset;
4use bevy_ecs::system::SystemParamItem;
5use bevy_mesh::MeshVertexBufferLayoutRef;
6use bevy_platform::{collections::HashSet, hash::FixedHasher};
7use bevy_reflect::{impl_type_path, Reflect};
8use bevy_render::{
9    alpha::AlphaMode,
10    render_resource::{
11        AsBindGroup, AsBindGroupError, BindGroupLayout, BindGroupLayoutEntry, BindlessDescriptor,
12        BindlessResourceType, BindlessSlabResourceLimit, RenderPipelineDescriptor,
13        SpecializedMeshPipelineError, UnpreparedBindGroup,
14    },
15    renderer::RenderDevice,
16};
17use bevy_shader::ShaderRef;
18
19use crate::{Material, MaterialPipeline, MaterialPipelineKey, MeshPipeline, MeshPipelineKey};
20
21pub struct MaterialExtensionPipeline {
22    pub mesh_pipeline: MeshPipeline,
23}
24
25pub struct MaterialExtensionKey<E: MaterialExtension> {
26    pub mesh_key: MeshPipelineKey,
27    pub bind_group_data: E::Data,
28}
29
30/// A subset of the `Material` trait for defining extensions to a base `Material`, such as the builtin `StandardMaterial`.
31///
32/// A user type implementing the trait should be used as the `E` generic param in an `ExtendedMaterial` struct.
33pub trait MaterialExtension: Asset + AsBindGroup + Clone + Sized {
34    /// Returns this material's vertex shader. If [`ShaderRef::Default`] is returned, the base material mesh vertex shader
35    /// will be used.
36    fn vertex_shader() -> ShaderRef {
37        ShaderRef::Default
38    }
39
40    /// Returns this material's fragment shader. If [`ShaderRef::Default`] is returned, the base material mesh fragment shader
41    /// will be used.
42    fn fragment_shader() -> ShaderRef {
43        ShaderRef::Default
44    }
45
46    // Returns this material’s AlphaMode. If None is returned, the base material alpha mode will be used.
47    fn alpha_mode() -> Option<AlphaMode> {
48        None
49    }
50
51    /// Controls if the prepass is enabled for the Material.
52    /// For more information about what a prepass is, see the [`bevy_core_pipeline::prepass`] docs.
53    #[inline]
54    fn enable_prepass() -> bool {
55        true
56    }
57
58    /// Controls if shadows are enabled for the Material.
59    #[inline]
60    fn enable_shadows() -> bool {
61        true
62    }
63
64    /// Returns this material's prepass vertex shader. If [`ShaderRef::Default`] is returned, the base material prepass vertex shader
65    /// will be used.
66    fn prepass_vertex_shader() -> ShaderRef {
67        ShaderRef::Default
68    }
69
70    /// Returns this material's prepass fragment shader. If [`ShaderRef::Default`] is returned, the base material prepass fragment shader
71    /// will be used.
72    fn prepass_fragment_shader() -> ShaderRef {
73        ShaderRef::Default
74    }
75
76    /// Returns this material's deferred vertex shader. If [`ShaderRef::Default`] is returned, the base material deferred vertex shader
77    /// will be used.
78    fn deferred_vertex_shader() -> ShaderRef {
79        ShaderRef::Default
80    }
81
82    /// Returns this material's prepass fragment shader. If [`ShaderRef::Default`] is returned, the base material deferred fragment shader
83    /// will be used.
84    fn deferred_fragment_shader() -> ShaderRef {
85        ShaderRef::Default
86    }
87
88    /// Returns this material's [`crate::meshlet::MeshletMesh`] fragment shader. If [`ShaderRef::Default`] is returned,
89    /// the default meshlet mesh fragment shader will be used.
90    #[cfg(feature = "meshlet")]
91    fn meshlet_mesh_fragment_shader() -> ShaderRef {
92        ShaderRef::Default
93    }
94
95    /// Returns this material's [`crate::meshlet::MeshletMesh`] prepass fragment shader. If [`ShaderRef::Default`] is returned,
96    /// the default meshlet mesh prepass fragment shader will be used.
97    #[cfg(feature = "meshlet")]
98    fn meshlet_mesh_prepass_fragment_shader() -> ShaderRef {
99        ShaderRef::Default
100    }
101
102    /// Returns this material's [`crate::meshlet::MeshletMesh`] deferred fragment shader. If [`ShaderRef::Default`] is returned,
103    /// the default meshlet mesh deferred fragment shader will be used.
104    #[cfg(feature = "meshlet")]
105    fn meshlet_mesh_deferred_fragment_shader() -> ShaderRef {
106        ShaderRef::Default
107    }
108
109    /// Customizes the default [`RenderPipelineDescriptor`] for a specific entity using the entity's
110    /// [`MaterialPipelineKey`] and [`MeshVertexBufferLayoutRef`] as input.
111    /// Specialization for the base material is applied before this function is called.
112    #[expect(
113        unused_variables,
114        reason = "The parameters here are intentionally unused by the default implementation; however, putting underscores here will result in the underscores being copied by rust-analyzer's tab completion."
115    )]
116    #[inline]
117    fn specialize(
118        pipeline: &MaterialExtensionPipeline,
119        descriptor: &mut RenderPipelineDescriptor,
120        layout: &MeshVertexBufferLayoutRef,
121        key: MaterialExtensionKey<Self>,
122    ) -> Result<(), SpecializedMeshPipelineError> {
123        Ok(())
124    }
125}
126
127/// A material that extends a base [`Material`] with additional shaders and data.
128///
129/// The data from both materials will be combined and made available to the shader
130/// so that shader functions built for the base material (and referencing the base material
131/// bindings) will work as expected, and custom alterations based on custom data can also be used.
132///
133/// If the extension `E` returns a non-default result from `vertex_shader()` it will be used in place of the base
134/// material's vertex shader.
135///
136/// If the extension `E` returns a non-default result from `fragment_shader()` it will be used in place of the base
137/// fragment shader.
138///
139/// When used with `StandardMaterial` as the base, all the standard material fields are
140/// present, so the `pbr_fragment` shader functions can be called from the extension shader (see
141/// the `extended_material` example).
142#[derive(Asset, Clone, Debug, Reflect)]
143#[reflect(type_path = false)]
144#[reflect(Clone)]
145pub struct ExtendedMaterial<B: Material, E: MaterialExtension> {
146    pub base: B,
147    pub extension: E,
148}
149
150impl<B, E> Default for ExtendedMaterial<B, E>
151where
152    B: Material + Default,
153    E: MaterialExtension + Default,
154{
155    fn default() -> Self {
156        Self {
157            base: B::default(),
158            extension: E::default(),
159        }
160    }
161}
162
/// The bind group data of an extended material: the base material's data followed by
/// the extension's data.
///
/// The struct is `#[repr(C, packed)]`, so the two parts are stored back to back in
/// declaration order without padding. Fields may therefore be unaligned: copy them out
/// by value instead of taking references to them.
#[derive(Clone, Copy, Hash, Eq, PartialEq)]
#[repr(C, packed)]
pub struct MaterialExtensionBindGroupData<B, E> {
    /// Bind group data of the base material.
    pub base: B,
    /// Bind group data of the extension.
    pub extension: E,
}
169
170// We don't use the `TypePath` derive here due to a bug where `#[reflect(type_path = false)]`
171// causes the `TypePath` derive to not generate an implementation.
172impl_type_path!((in bevy_pbr::extended_material) ExtendedMaterial<B: Material, E: MaterialExtension>);
173
174impl<B: Material, E: MaterialExtension> AsBindGroup for ExtendedMaterial<B, E> {
175    type Data = MaterialExtensionBindGroupData<B::Data, E::Data>;
176    type Param = (<B as AsBindGroup>::Param, <E as AsBindGroup>::Param);
177
178    fn bindless_slot_count() -> Option<BindlessSlabResourceLimit> {
179        // We only enable bindless if both the base material and its extension
180        // are bindless. If we do enable bindless, we choose the smaller of the
181        // two slab size limits.
182        match (B::bindless_slot_count()?, E::bindless_slot_count()?) {
183            (BindlessSlabResourceLimit::Auto, BindlessSlabResourceLimit::Auto) => {
184                Some(BindlessSlabResourceLimit::Auto)
185            }
186            (BindlessSlabResourceLimit::Auto, BindlessSlabResourceLimit::Custom(limit))
187            | (BindlessSlabResourceLimit::Custom(limit), BindlessSlabResourceLimit::Auto) => {
188                Some(BindlessSlabResourceLimit::Custom(limit))
189            }
190            (
191                BindlessSlabResourceLimit::Custom(base_limit),
192                BindlessSlabResourceLimit::Custom(extended_limit),
193            ) => Some(BindlessSlabResourceLimit::Custom(
194                base_limit.min(extended_limit),
195            )),
196        }
197    }
198
199    fn bindless_supported(render_device: &RenderDevice) -> bool {
200        B::bindless_supported(render_device) && E::bindless_supported(render_device)
201    }
202
203    fn label() -> &'static str {
204        E::label()
205    }
206
207    fn bind_group_data(&self) -> Self::Data {
208        MaterialExtensionBindGroupData {
209            base: self.base.bind_group_data(),
210            extension: self.extension.bind_group_data(),
211        }
212    }
213
214    fn unprepared_bind_group(
215        &self,
216        layout: &BindGroupLayout,
217        render_device: &RenderDevice,
218        (base_param, extended_param): &mut SystemParamItem<'_, '_, Self::Param>,
219        mut force_non_bindless: bool,
220    ) -> Result<UnpreparedBindGroup, AsBindGroupError> {
221        force_non_bindless = force_non_bindless || Self::bindless_slot_count().is_none();
222
223        // add together the bindings of the base material and the extension
224        let UnpreparedBindGroup { mut bindings } = B::unprepared_bind_group(
225            &self.base,
226            layout,
227            render_device,
228            base_param,
229            force_non_bindless,
230        )?;
231        let UnpreparedBindGroup {
232            bindings: extension_bindings,
233        } = E::unprepared_bind_group(
234            &self.extension,
235            layout,
236            render_device,
237            extended_param,
238            force_non_bindless,
239        )?;
240
241        bindings.extend(extension_bindings.0);
242
243        Ok(UnpreparedBindGroup { bindings })
244    }
245
246    fn bind_group_layout_entries(
247        render_device: &RenderDevice,
248        mut force_non_bindless: bool,
249    ) -> Vec<BindGroupLayoutEntry>
250    where
251        Self: Sized,
252    {
253        force_non_bindless = force_non_bindless || Self::bindless_slot_count().is_none();
254
255        // Add together the bindings of the standard material and the user
256        // material, skipping duplicate bindings. Duplicate bindings will occur
257        // when bindless mode is on, because of the common bindless resource
258        // arrays, and we need to eliminate the duplicates or `wgpu` will
259        // complain.
260        let base_entries = B::bind_group_layout_entries(render_device, force_non_bindless);
261        let extension_entries = E::bind_group_layout_entries(render_device, force_non_bindless);
262
263        let mut seen_bindings = HashSet::<u32>::with_hasher(FixedHasher);
264
265        base_entries
266            .into_iter()
267            .chain(extension_entries)
268            .filter(|entry| seen_bindings.insert(entry.binding))
269            .collect()
270    }
271
272    fn bindless_descriptor() -> Option<BindlessDescriptor> {
273        // We're going to combine the two bindless descriptors.
274        let base_bindless_descriptor = B::bindless_descriptor()?;
275        let extended_bindless_descriptor = E::bindless_descriptor()?;
276
277        // Combining the buffers and index tables is straightforward.
278
279        let mut buffers = base_bindless_descriptor.buffers.to_vec();
280        let mut index_tables = base_bindless_descriptor.index_tables.to_vec();
281
282        buffers.extend(extended_bindless_descriptor.buffers.iter().cloned());
283        index_tables.extend(extended_bindless_descriptor.index_tables.iter().cloned());
284
285        // Combining the resources is a little trickier because the resource
286        // array is indexed by bindless index, so we have to merge the two
287        // arrays, not just concatenate them.
288        let max_bindless_index = base_bindless_descriptor
289            .resources
290            .len()
291            .max(extended_bindless_descriptor.resources.len());
292        let mut resources = Vec::with_capacity(max_bindless_index);
293        for bindless_index in 0..max_bindless_index {
294            // In the event of a conflicting bindless index, we choose the
295            // base's binding.
296            match base_bindless_descriptor.resources.get(bindless_index) {
297                None | Some(&BindlessResourceType::None) => resources.push(
298                    extended_bindless_descriptor
299                        .resources
300                        .get(bindless_index)
301                        .copied()
302                        .unwrap_or(BindlessResourceType::None),
303                ),
304                Some(&resource_type) => resources.push(resource_type),
305            }
306        }
307
308        Some(BindlessDescriptor {
309            resources: Cow::Owned(resources),
310            buffers: Cow::Owned(buffers),
311            index_tables: Cow::Owned(index_tables),
312        })
313    }
314}
315
316impl<B: Material, E: MaterialExtension> Material for ExtendedMaterial<B, E> {
317    fn vertex_shader() -> ShaderRef {
318        match E::vertex_shader() {
319            ShaderRef::Default => B::vertex_shader(),
320            specified => specified,
321        }
322    }
323
324    fn fragment_shader() -> ShaderRef {
325        match E::fragment_shader() {
326            ShaderRef::Default => B::fragment_shader(),
327            specified => specified,
328        }
329    }
330
331    fn alpha_mode(&self) -> AlphaMode {
332        match E::alpha_mode() {
333            Some(specified) => specified,
334            None => B::alpha_mode(&self.base),
335        }
336    }
337
338    fn opaque_render_method(&self) -> crate::OpaqueRendererMethod {
339        B::opaque_render_method(&self.base)
340    }
341
342    fn depth_bias(&self) -> f32 {
343        B::depth_bias(&self.base)
344    }
345
346    fn reads_view_transmission_texture(&self) -> bool {
347        B::reads_view_transmission_texture(&self.base)
348    }
349
350    fn enable_prepass() -> bool {
351        E::enable_prepass()
352    }
353
354    fn enable_shadows() -> bool {
355        E::enable_shadows()
356    }
357
358    fn prepass_vertex_shader() -> ShaderRef {
359        match E::prepass_vertex_shader() {
360            ShaderRef::Default => B::prepass_vertex_shader(),
361            specified => specified,
362        }
363    }
364
365    fn prepass_fragment_shader() -> ShaderRef {
366        match E::prepass_fragment_shader() {
367            ShaderRef::Default => B::prepass_fragment_shader(),
368            specified => specified,
369        }
370    }
371
372    fn deferred_vertex_shader() -> ShaderRef {
373        match E::deferred_vertex_shader() {
374            ShaderRef::Default => B::deferred_vertex_shader(),
375            specified => specified,
376        }
377    }
378
379    fn deferred_fragment_shader() -> ShaderRef {
380        match E::deferred_fragment_shader() {
381            ShaderRef::Default => B::deferred_fragment_shader(),
382            specified => specified,
383        }
384    }
385
386    #[cfg(feature = "meshlet")]
387    fn meshlet_mesh_fragment_shader() -> ShaderRef {
388        match E::meshlet_mesh_fragment_shader() {
389            ShaderRef::Default => B::meshlet_mesh_fragment_shader(),
390            specified => specified,
391        }
392    }
393
394    #[cfg(feature = "meshlet")]
395    fn meshlet_mesh_prepass_fragment_shader() -> ShaderRef {
396        match E::meshlet_mesh_prepass_fragment_shader() {
397            ShaderRef::Default => B::meshlet_mesh_prepass_fragment_shader(),
398            specified => specified,
399        }
400    }
401
402    #[cfg(feature = "meshlet")]
403    fn meshlet_mesh_deferred_fragment_shader() -> ShaderRef {
404        match E::meshlet_mesh_deferred_fragment_shader() {
405            ShaderRef::Default => B::meshlet_mesh_deferred_fragment_shader(),
406            specified => specified,
407        }
408    }
409
410    fn specialize(
411        pipeline: &MaterialPipeline,
412        descriptor: &mut RenderPipelineDescriptor,
413        layout: &MeshVertexBufferLayoutRef,
414        key: MaterialPipelineKey<Self>,
415    ) -> Result<(), SpecializedMeshPipelineError> {
416        // Call the base material's specialize function
417        let base_key = MaterialPipelineKey::<B> {
418            mesh_key: key.mesh_key,
419            bind_group_data: key.bind_group_data.base,
420        };
421        B::specialize(pipeline, descriptor, layout, base_key)?;
422
423        // Call the extended material's specialize function afterwards
424        E::specialize(
425            &MaterialExtensionPipeline {
426                mesh_pipeline: pipeline.mesh_pipeline.clone(),
427            },
428            descriptor,
429            layout,
430            MaterialExtensionKey {
431                mesh_key: key.mesh_key,
432                bind_group_data: key.bind_group_data.extension,
433            },
434        )
435    }
436}