bevy_pbr/
extended_material.rs

1use alloc::borrow::Cow;
2
3use bevy_asset::Asset;
4use bevy_ecs::system::SystemParamItem;
5use bevy_mesh::MeshVertexBufferLayoutRef;
6use bevy_platform::{collections::HashSet, hash::FixedHasher};
7use bevy_reflect::{impl_type_path, Reflect};
8use bevy_render::{
9    alpha::AlphaMode,
10    render_resource::{
11        AsBindGroup, AsBindGroupError, BindGroupLayout, BindGroupLayoutEntry, BindlessDescriptor,
12        BindlessResourceType, BindlessSlabResourceLimit, RenderPipelineDescriptor,
13        SpecializedMeshPipelineError, UnpreparedBindGroup,
14    },
15    renderer::RenderDevice,
16};
17use bevy_shader::ShaderRef;
18
19use crate::{Material, MaterialPipeline, MaterialPipelineKey, MeshPipeline, MeshPipelineKey};
20
21pub struct MaterialExtensionPipeline {
22    pub mesh_pipeline: MeshPipeline,
23}
24
25pub struct MaterialExtensionKey<E: MaterialExtension> {
26    pub mesh_key: MeshPipelineKey,
27    pub bind_group_data: E::Data,
28}
29
30/// A subset of the `Material` trait for defining extensions to a base `Material`, such as the builtin `StandardMaterial`.
31///
32/// A user type implementing the trait should be used as the `E` generic param in an `ExtendedMaterial` struct.
33pub trait MaterialExtension: Asset + AsBindGroup + Clone + Sized {
34    /// Returns this material's vertex shader. If [`ShaderRef::Default`] is returned, the base material mesh vertex shader
35    /// will be used.
36    fn vertex_shader() -> ShaderRef {
37        ShaderRef::Default
38    }
39
40    /// Returns this material's fragment shader. If [`ShaderRef::Default`] is returned, the base material mesh fragment shader
41    /// will be used.
42    fn fragment_shader() -> ShaderRef {
43        ShaderRef::Default
44    }
45
46    // Returns this material’s AlphaMode. If None is returned, the base material alpha mode will be used.
47    fn alpha_mode() -> Option<AlphaMode> {
48        None
49    }
50
51    /// Returns this material's prepass vertex shader. If [`ShaderRef::Default`] is returned, the base material prepass vertex shader
52    /// will be used.
53    fn prepass_vertex_shader() -> ShaderRef {
54        ShaderRef::Default
55    }
56
57    /// Returns this material's prepass fragment shader. If [`ShaderRef::Default`] is returned, the base material prepass fragment shader
58    /// will be used.
59    fn prepass_fragment_shader() -> ShaderRef {
60        ShaderRef::Default
61    }
62
63    /// Returns this material's deferred vertex shader. If [`ShaderRef::Default`] is returned, the base material deferred vertex shader
64    /// will be used.
65    fn deferred_vertex_shader() -> ShaderRef {
66        ShaderRef::Default
67    }
68
69    /// Returns this material's prepass fragment shader. If [`ShaderRef::Default`] is returned, the base material deferred fragment shader
70    /// will be used.
71    fn deferred_fragment_shader() -> ShaderRef {
72        ShaderRef::Default
73    }
74
75    /// Returns this material's [`crate::meshlet::MeshletMesh`] fragment shader. If [`ShaderRef::Default`] is returned,
76    /// the default meshlet mesh fragment shader will be used.
77    #[cfg(feature = "meshlet")]
78    fn meshlet_mesh_fragment_shader() -> ShaderRef {
79        ShaderRef::Default
80    }
81
82    /// Returns this material's [`crate::meshlet::MeshletMesh`] prepass fragment shader. If [`ShaderRef::Default`] is returned,
83    /// the default meshlet mesh prepass fragment shader will be used.
84    #[cfg(feature = "meshlet")]
85    fn meshlet_mesh_prepass_fragment_shader() -> ShaderRef {
86        ShaderRef::Default
87    }
88
89    /// Returns this material's [`crate::meshlet::MeshletMesh`] deferred fragment shader. If [`ShaderRef::Default`] is returned,
90    /// the default meshlet mesh deferred fragment shader will be used.
91    #[cfg(feature = "meshlet")]
92    fn meshlet_mesh_deferred_fragment_shader() -> ShaderRef {
93        ShaderRef::Default
94    }
95
96    /// Customizes the default [`RenderPipelineDescriptor`] for a specific entity using the entity's
97    /// [`MaterialPipelineKey`] and [`MeshVertexBufferLayoutRef`] as input.
98    /// Specialization for the base material is applied before this function is called.
99    #[expect(
100        unused_variables,
101        reason = "The parameters here are intentionally unused by the default implementation; however, putting underscores here will result in the underscores being copied by rust-analyzer's tab completion."
102    )]
103    #[inline]
104    fn specialize(
105        pipeline: &MaterialExtensionPipeline,
106        descriptor: &mut RenderPipelineDescriptor,
107        layout: &MeshVertexBufferLayoutRef,
108        key: MaterialExtensionKey<Self>,
109    ) -> Result<(), SpecializedMeshPipelineError> {
110        Ok(())
111    }
112}
113
114/// A material that extends a base [`Material`] with additional shaders and data.
115///
116/// The data from both materials will be combined and made available to the shader
117/// so that shader functions built for the base material (and referencing the base material
118/// bindings) will work as expected, and custom alterations based on custom data can also be used.
119///
120/// If the extension `E` returns a non-default result from `vertex_shader()` it will be used in place of the base
121/// material's vertex shader.
122///
123/// If the extension `E` returns a non-default result from `fragment_shader()` it will be used in place of the base
124/// fragment shader.
125///
126/// When used with `StandardMaterial` as the base, all the standard material fields are
127/// present, so the `pbr_fragment` shader functions can be called from the extension shader (see
128/// the `extended_material` example).
129#[derive(Asset, Clone, Debug, Reflect)]
130#[reflect(type_path = false)]
131#[reflect(Clone)]
132pub struct ExtendedMaterial<B: Material, E: MaterialExtension> {
133    pub base: B,
134    pub extension: E,
135}
136
137impl<B, E> Default for ExtendedMaterial<B, E>
138where
139    B: Material + Default,
140    E: MaterialExtension + Default,
141{
142    fn default() -> Self {
143        Self {
144            base: B::default(),
145            extension: E::default(),
146        }
147    }
148}
149
/// The [`AsBindGroup::Data`] of an [`ExtendedMaterial`]: the base material's bind group data
/// paired with the extension's.
///
/// When the extended material is specialized, the two halves are split back apart and forwarded
/// to the base material's and the extension's `specialize` functions respectively.
///
/// NOTE(review): `#[repr(C, packed)]` presumably exists to keep the combined key free of padding;
/// confirm before changing the representation or the field order.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[repr(C, packed)]
pub struct MaterialExtensionBindGroupData<B, E> {
    /// The base material's bind group data.
    pub base: B,
    /// The extension's bind group data.
    pub extension: E,
}
156
157// We don't use the `TypePath` derive here due to a bug where `#[reflect(type_path = false)]`
158// causes the `TypePath` derive to not generate an implementation.
159impl_type_path!((in bevy_pbr::extended_material) ExtendedMaterial<B: Material, E: MaterialExtension>);
160
161impl<B: Material, E: MaterialExtension> AsBindGroup for ExtendedMaterial<B, E> {
162    type Data = MaterialExtensionBindGroupData<B::Data, E::Data>;
163    type Param = (<B as AsBindGroup>::Param, <E as AsBindGroup>::Param);
164
165    fn bindless_slot_count() -> Option<BindlessSlabResourceLimit> {
166        // We only enable bindless if both the base material and its extension
167        // are bindless. If we do enable bindless, we choose the smaller of the
168        // two slab size limits.
169        match (B::bindless_slot_count()?, E::bindless_slot_count()?) {
170            (BindlessSlabResourceLimit::Auto, BindlessSlabResourceLimit::Auto) => {
171                Some(BindlessSlabResourceLimit::Auto)
172            }
173            (BindlessSlabResourceLimit::Auto, BindlessSlabResourceLimit::Custom(limit))
174            | (BindlessSlabResourceLimit::Custom(limit), BindlessSlabResourceLimit::Auto) => {
175                Some(BindlessSlabResourceLimit::Custom(limit))
176            }
177            (
178                BindlessSlabResourceLimit::Custom(base_limit),
179                BindlessSlabResourceLimit::Custom(extended_limit),
180            ) => Some(BindlessSlabResourceLimit::Custom(
181                base_limit.min(extended_limit),
182            )),
183        }
184    }
185
186    fn bind_group_data(&self) -> Self::Data {
187        MaterialExtensionBindGroupData {
188            base: self.base.bind_group_data(),
189            extension: self.extension.bind_group_data(),
190        }
191    }
192
193    fn unprepared_bind_group(
194        &self,
195        layout: &BindGroupLayout,
196        render_device: &RenderDevice,
197        (base_param, extended_param): &mut SystemParamItem<'_, '_, Self::Param>,
198        mut force_non_bindless: bool,
199    ) -> Result<UnpreparedBindGroup, AsBindGroupError> {
200        force_non_bindless = force_non_bindless || Self::bindless_slot_count().is_none();
201
202        // add together the bindings of the base material and the user material
203        let UnpreparedBindGroup { mut bindings } = B::unprepared_bind_group(
204            &self.base,
205            layout,
206            render_device,
207            base_param,
208            force_non_bindless,
209        )?;
210        let extended_bindgroup = E::unprepared_bind_group(
211            &self.extension,
212            layout,
213            render_device,
214            extended_param,
215            force_non_bindless,
216        )?;
217
218        bindings.extend(extended_bindgroup.bindings.0);
219
220        Ok(UnpreparedBindGroup { bindings })
221    }
222
223    fn bind_group_layout_entries(
224        render_device: &RenderDevice,
225        mut force_non_bindless: bool,
226    ) -> Vec<BindGroupLayoutEntry>
227    where
228        Self: Sized,
229    {
230        force_non_bindless = force_non_bindless || Self::bindless_slot_count().is_none();
231
232        // Add together the bindings of the standard material and the user
233        // material, skipping duplicate bindings. Duplicate bindings will occur
234        // when bindless mode is on, because of the common bindless resource
235        // arrays, and we need to eliminate the duplicates or `wgpu` will
236        // complain.
237        let mut entries = vec![];
238        let mut seen_bindings = HashSet::<_>::with_hasher(FixedHasher);
239        for entry in B::bind_group_layout_entries(render_device, force_non_bindless)
240            .into_iter()
241            .chain(E::bind_group_layout_entries(render_device, force_non_bindless).into_iter())
242        {
243            if seen_bindings.insert(entry.binding) {
244                entries.push(entry);
245            }
246        }
247        entries
248    }
249
250    fn bindless_descriptor() -> Option<BindlessDescriptor> {
251        // We're going to combine the two bindless descriptors.
252        let base_bindless_descriptor = B::bindless_descriptor()?;
253        let extended_bindless_descriptor = E::bindless_descriptor()?;
254
255        // Combining the buffers and index tables is straightforward.
256
257        let mut buffers = base_bindless_descriptor.buffers.to_vec();
258        let mut index_tables = base_bindless_descriptor.index_tables.to_vec();
259
260        buffers.extend(extended_bindless_descriptor.buffers.iter().cloned());
261        index_tables.extend(extended_bindless_descriptor.index_tables.iter().cloned());
262
263        // Combining the resources is a little trickier because the resource
264        // array is indexed by bindless index, so we have to merge the two
265        // arrays, not just concatenate them.
266        let max_bindless_index = base_bindless_descriptor
267            .resources
268            .len()
269            .max(extended_bindless_descriptor.resources.len());
270        let mut resources = Vec::with_capacity(max_bindless_index);
271        for bindless_index in 0..max_bindless_index {
272            // In the event of a conflicting bindless index, we choose the
273            // base's binding.
274            match base_bindless_descriptor.resources.get(bindless_index) {
275                None | Some(&BindlessResourceType::None) => resources.push(
276                    extended_bindless_descriptor
277                        .resources
278                        .get(bindless_index)
279                        .copied()
280                        .unwrap_or(BindlessResourceType::None),
281                ),
282                Some(&resource_type) => resources.push(resource_type),
283            }
284        }
285
286        Some(BindlessDescriptor {
287            resources: Cow::Owned(resources),
288            buffers: Cow::Owned(buffers),
289            index_tables: Cow::Owned(index_tables),
290        })
291    }
292}
293
294impl<B: Material, E: MaterialExtension> Material for ExtendedMaterial<B, E> {
295    fn vertex_shader() -> ShaderRef {
296        match E::vertex_shader() {
297            ShaderRef::Default => B::vertex_shader(),
298            specified => specified,
299        }
300    }
301
302    fn fragment_shader() -> ShaderRef {
303        match E::fragment_shader() {
304            ShaderRef::Default => B::fragment_shader(),
305            specified => specified,
306        }
307    }
308
309    fn alpha_mode(&self) -> AlphaMode {
310        match E::alpha_mode() {
311            Some(specified) => specified,
312            None => B::alpha_mode(&self.base),
313        }
314    }
315
316    fn opaque_render_method(&self) -> crate::OpaqueRendererMethod {
317        B::opaque_render_method(&self.base)
318    }
319
320    fn depth_bias(&self) -> f32 {
321        B::depth_bias(&self.base)
322    }
323
324    fn reads_view_transmission_texture(&self) -> bool {
325        B::reads_view_transmission_texture(&self.base)
326    }
327
328    fn prepass_vertex_shader() -> ShaderRef {
329        match E::prepass_vertex_shader() {
330            ShaderRef::Default => B::prepass_vertex_shader(),
331            specified => specified,
332        }
333    }
334
335    fn prepass_fragment_shader() -> ShaderRef {
336        match E::prepass_fragment_shader() {
337            ShaderRef::Default => B::prepass_fragment_shader(),
338            specified => specified,
339        }
340    }
341
342    fn deferred_vertex_shader() -> ShaderRef {
343        match E::deferred_vertex_shader() {
344            ShaderRef::Default => B::deferred_vertex_shader(),
345            specified => specified,
346        }
347    }
348
349    fn deferred_fragment_shader() -> ShaderRef {
350        match E::deferred_fragment_shader() {
351            ShaderRef::Default => B::deferred_fragment_shader(),
352            specified => specified,
353        }
354    }
355
356    #[cfg(feature = "meshlet")]
357    fn meshlet_mesh_fragment_shader() -> ShaderRef {
358        match E::meshlet_mesh_fragment_shader() {
359            ShaderRef::Default => B::meshlet_mesh_fragment_shader(),
360            specified => specified,
361        }
362    }
363
364    #[cfg(feature = "meshlet")]
365    fn meshlet_mesh_prepass_fragment_shader() -> ShaderRef {
366        match E::meshlet_mesh_prepass_fragment_shader() {
367            ShaderRef::Default => B::meshlet_mesh_prepass_fragment_shader(),
368            specified => specified,
369        }
370    }
371
372    #[cfg(feature = "meshlet")]
373    fn meshlet_mesh_deferred_fragment_shader() -> ShaderRef {
374        match E::meshlet_mesh_deferred_fragment_shader() {
375            ShaderRef::Default => B::meshlet_mesh_deferred_fragment_shader(),
376            specified => specified,
377        }
378    }
379
380    fn specialize(
381        pipeline: &MaterialPipeline,
382        descriptor: &mut RenderPipelineDescriptor,
383        layout: &MeshVertexBufferLayoutRef,
384        key: MaterialPipelineKey<Self>,
385    ) -> Result<(), SpecializedMeshPipelineError> {
386        // Call the base material's specialize function
387        let base_key = MaterialPipelineKey::<B> {
388            mesh_key: key.mesh_key,
389            bind_group_data: key.bind_group_data.base,
390        };
391        B::specialize(pipeline, descriptor, layout, base_key)?;
392
393        // Call the extended material's specialize function afterwards
394        E::specialize(
395            &MaterialExtensionPipeline {
396                mesh_pipeline: pipeline.mesh_pipeline.clone(),
397            },
398            descriptor,
399            layout,
400            MaterialExtensionKey {
401                mesh_key: key.mesh_key,
402                bind_group_data: key.bind_group_data.extension,
403            },
404        )
405    }
406}