1use alloc::borrow::Cow;
2
3use bevy_asset::{Asset, Handle};
4use bevy_ecs::system::SystemParamItem;
5use bevy_platform::{collections::HashSet, hash::FixedHasher};
6use bevy_reflect::{impl_type_path, Reflect};
7use bevy_render::{
8 alpha::AlphaMode,
9 mesh::MeshVertexBufferLayoutRef,
10 render_resource::{
11 AsBindGroup, AsBindGroupError, BindGroupLayout, BindGroupLayoutEntry, BindlessDescriptor,
12 BindlessResourceType, BindlessSlabResourceLimit, RenderPipelineDescriptor, Shader,
13 ShaderRef, SpecializedMeshPipelineError, UnpreparedBindGroup,
14 },
15 renderer::RenderDevice,
16};
17
18use crate::{Material, MaterialPipeline, MaterialPipelineKey, MeshPipeline, MeshPipelineKey};
19
/// Pipeline data handed to [`MaterialExtension::specialize`].
///
/// [`ExtendedMaterial`] fills this in from its own `MaterialPipeline` before
/// invoking the extension's specialization, copying over the mesh pipeline,
/// the material bind group layout, the shader handles and the `bindless` flag.
pub struct MaterialExtensionPipeline {
    /// The mesh pipeline, copied from the extended material's pipeline.
    pub mesh_pipeline: MeshPipeline,
    /// The extended material's bind group layout.
    pub material_layout: BindGroupLayout,
    /// The vertex shader handle (if any) copied from the extended material's pipeline.
    pub vertex_shader: Option<Handle<Shader>>,
    /// The fragment shader handle (if any) copied from the extended material's pipeline.
    pub fragment_shader: Option<Handle<Shader>>,
    /// The `bindless` flag copied from the extended material's pipeline.
    pub bindless: bool,
}
27
/// Pipeline key handed to [`MaterialExtension::specialize`].
pub struct MaterialExtensionKey<E: MaterialExtension> {
    /// The mesh pipeline key, shared with the base material's specialization.
    pub mesh_key: MeshPipelineKey,
    /// The extension's share of the extended material's bind group data
    /// (the second element of `ExtendedMaterial`'s `Data` tuple).
    pub bind_group_data: E::Data,
}
32
/// Shaders, bind group data and pipeline specialization that can be layered on top of a
/// base [`Material`] by wrapping both in an [`ExtendedMaterial`].
///
/// Every shader hook returns [`ShaderRef::Default`] unless overridden. For any hook that
/// returns [`ShaderRef::Default`], [`ExtendedMaterial`] uses the base material's shader
/// for that stage instead.
pub trait MaterialExtension: Asset + AsBindGroup + Clone + Sized {
    /// Returns the shader [`ExtendedMaterial`] should use in place of the base material's
    /// [`Material::vertex_shader`].
    ///
    /// The default, [`ShaderRef::Default`], keeps the base material's shader.
    fn vertex_shader() -> ShaderRef {
        ShaderRef::Default
    }

    /// Returns the shader [`ExtendedMaterial`] should use in place of the base material's
    /// [`Material::fragment_shader`].
    ///
    /// The default, [`ShaderRef::Default`], keeps the base material's shader.
    fn fragment_shader() -> ShaderRef {
        ShaderRef::Default
    }

    /// Returns an alpha mode that overrides the base material's, or `None` to keep the
    /// base material's per-instance [`Material::alpha_mode`].
    ///
    /// Defaults to `None`.
    fn alpha_mode() -> Option<AlphaMode> {
        None
    }

    /// Returns the shader [`ExtendedMaterial`] should use in place of the base material's
    /// [`Material::prepass_vertex_shader`].
    ///
    /// The default, [`ShaderRef::Default`], keeps the base material's shader.
    fn prepass_vertex_shader() -> ShaderRef {
        ShaderRef::Default
    }

    /// Returns the shader [`ExtendedMaterial`] should use in place of the base material's
    /// [`Material::prepass_fragment_shader`].
    ///
    /// The default, [`ShaderRef::Default`], keeps the base material's shader.
    fn prepass_fragment_shader() -> ShaderRef {
        ShaderRef::Default
    }

    /// Returns the shader [`ExtendedMaterial`] should use in place of the base material's
    /// [`Material::deferred_vertex_shader`].
    ///
    /// The default, [`ShaderRef::Default`], keeps the base material's shader.
    fn deferred_vertex_shader() -> ShaderRef {
        ShaderRef::Default
    }

    /// Returns the shader [`ExtendedMaterial`] should use in place of the base material's
    /// [`Material::deferred_fragment_shader`].
    ///
    /// The default, [`ShaderRef::Default`], keeps the base material's shader.
    fn deferred_fragment_shader() -> ShaderRef {
        ShaderRef::Default
    }

    /// Returns the shader [`ExtendedMaterial`] should use in place of the base material's
    /// `Material::meshlet_mesh_fragment_shader`.
    ///
    /// The default, [`ShaderRef::Default`], keeps the base material's shader.
    #[cfg(feature = "meshlet")]
    fn meshlet_mesh_fragment_shader() -> ShaderRef {
        ShaderRef::Default
    }

    /// Returns the shader [`ExtendedMaterial`] should use in place of the base material's
    /// `Material::meshlet_mesh_prepass_fragment_shader`.
    ///
    /// The default, [`ShaderRef::Default`], keeps the base material's shader.
    #[cfg(feature = "meshlet")]
    fn meshlet_mesh_prepass_fragment_shader() -> ShaderRef {
        ShaderRef::Default
    }

    /// Returns the shader [`ExtendedMaterial`] should use in place of the base material's
    /// `Material::meshlet_mesh_deferred_fragment_shader`.
    ///
    /// The default, [`ShaderRef::Default`], keeps the base material's shader.
    #[cfg(feature = "meshlet")]
    fn meshlet_mesh_deferred_fragment_shader() -> ShaderRef {
        ShaderRef::Default
    }

    /// Customizes the [`RenderPipelineDescriptor`] for the extended material.
    ///
    /// [`ExtendedMaterial`] calls this only after the base material's
    /// [`Material::specialize`] has returned `Ok`. It passes the same `descriptor`, so
    /// this hook sees (and may override) the base material's changes.
    /// `key.bind_group_data` is this extension's own bind group data.
    ///
    /// The default implementation leaves the descriptor untouched.
    #[expect(
        unused_variables,
        reason = "The parameters here are intentionally unused by the default implementation; however, putting underscores here will result in the underscores being copied by rust-analyzer's tab completion."
    )]
    #[inline]
    fn specialize(
        pipeline: &MaterialExtensionPipeline,
        descriptor: &mut RenderPipelineDescriptor,
        layout: &MeshVertexBufferLayoutRef,
        key: MaterialExtensionKey<Self>,
    ) -> Result<(), SpecializedMeshPipelineError> {
        Ok(())
    }
}
116
/// A material that combines a base [`Material`] `B` with a [`MaterialExtension`] `E`.
///
/// - Bind group layout entries of `B` and `E` are merged. When both declare the same
///   binding number, `B`'s entry is kept.
/// - Bindless resource slots are merged. When both use the same bindless index, `B`'s
///   resource is kept unless it is empty.
/// - Each shader stage uses `E`'s shader unless `E` returns [`ShaderRef::Default`], in
///   which case `B`'s shader is used.
#[derive(Asset, Clone, Debug, Reflect)]
// `TypePath` is not derived here; it is implemented by the `impl_type_path!`
// invocation that follows this type.
#[reflect(type_path = false)]
#[reflect(Clone)]
pub struct ExtendedMaterial<B: Material, E: MaterialExtension> {
    /// The base material being extended.
    pub base: B,
    /// The extension layered on top of `base`.
    pub extension: E,
}
139
140impl<B, E> Default for ExtendedMaterial<B, E>
141where
142 B: Material + Default,
143 E: MaterialExtension + Default,
144{
145 fn default() -> Self {
146 Self {
147 base: B::default(),
148 extension: E::default(),
149 }
150 }
151}
152
// Implement `TypePath` by hand, because `ExtendedMaterial` opts out of deriving it with
// `#[reflect(type_path = false)]`. This registers the type under the
// `bevy_pbr::extended_material` module path.
impl_type_path!((in bevy_pbr::extended_material) ExtendedMaterial<B: Material, E: MaterialExtension>);
156
157impl<B: Material, E: MaterialExtension> AsBindGroup for ExtendedMaterial<B, E> {
158 type Data = (<B as AsBindGroup>::Data, <E as AsBindGroup>::Data);
159 type Param = (<B as AsBindGroup>::Param, <E as AsBindGroup>::Param);
160
161 fn bindless_slot_count() -> Option<BindlessSlabResourceLimit> {
162 match (B::bindless_slot_count()?, E::bindless_slot_count()?) {
166 (BindlessSlabResourceLimit::Auto, BindlessSlabResourceLimit::Auto) => {
167 Some(BindlessSlabResourceLimit::Auto)
168 }
169 (BindlessSlabResourceLimit::Auto, BindlessSlabResourceLimit::Custom(limit))
170 | (BindlessSlabResourceLimit::Custom(limit), BindlessSlabResourceLimit::Auto) => {
171 Some(BindlessSlabResourceLimit::Custom(limit))
172 }
173 (
174 BindlessSlabResourceLimit::Custom(base_limit),
175 BindlessSlabResourceLimit::Custom(extended_limit),
176 ) => Some(BindlessSlabResourceLimit::Custom(
177 base_limit.min(extended_limit),
178 )),
179 }
180 }
181
182 fn unprepared_bind_group(
183 &self,
184 layout: &BindGroupLayout,
185 render_device: &RenderDevice,
186 (base_param, extended_param): &mut SystemParamItem<'_, '_, Self::Param>,
187 mut force_non_bindless: bool,
188 ) -> Result<UnpreparedBindGroup<Self::Data>, AsBindGroupError> {
189 force_non_bindless = force_non_bindless || Self::bindless_slot_count().is_none();
190
191 let UnpreparedBindGroup {
193 mut bindings,
194 data: base_data,
195 } = B::unprepared_bind_group(
196 &self.base,
197 layout,
198 render_device,
199 base_param,
200 force_non_bindless,
201 )?;
202 let extended_bindgroup = E::unprepared_bind_group(
203 &self.extension,
204 layout,
205 render_device,
206 extended_param,
207 force_non_bindless,
208 )?;
209
210 bindings.extend(extended_bindgroup.bindings.0);
211
212 Ok(UnpreparedBindGroup {
213 bindings,
214 data: (base_data, extended_bindgroup.data),
215 })
216 }
217
218 fn bind_group_layout_entries(
219 render_device: &RenderDevice,
220 mut force_non_bindless: bool,
221 ) -> Vec<BindGroupLayoutEntry>
222 where
223 Self: Sized,
224 {
225 force_non_bindless = force_non_bindless || Self::bindless_slot_count().is_none();
226
227 let mut entries = vec![];
233 let mut seen_bindings = HashSet::<_>::with_hasher(FixedHasher);
234 for entry in B::bind_group_layout_entries(render_device, force_non_bindless)
235 .into_iter()
236 .chain(E::bind_group_layout_entries(render_device, force_non_bindless).into_iter())
237 {
238 if seen_bindings.insert(entry.binding) {
239 entries.push(entry);
240 }
241 }
242 entries
243 }
244
245 fn bindless_descriptor() -> Option<BindlessDescriptor> {
246 let base_bindless_descriptor = B::bindless_descriptor()?;
248 let extended_bindless_descriptor = E::bindless_descriptor()?;
249
250 let mut buffers = base_bindless_descriptor.buffers.to_vec();
253 let mut index_tables = base_bindless_descriptor.index_tables.to_vec();
254
255 buffers.extend(extended_bindless_descriptor.buffers.iter().cloned());
256 index_tables.extend(extended_bindless_descriptor.index_tables.iter().cloned());
257
258 let max_bindless_index = base_bindless_descriptor
262 .resources
263 .len()
264 .max(extended_bindless_descriptor.resources.len());
265 let mut resources = Vec::with_capacity(max_bindless_index);
266 for bindless_index in 0..max_bindless_index {
267 match base_bindless_descriptor.resources.get(bindless_index) {
270 None | Some(&BindlessResourceType::None) => resources.push(
271 extended_bindless_descriptor
272 .resources
273 .get(bindless_index)
274 .copied()
275 .unwrap_or(BindlessResourceType::None),
276 ),
277 Some(&resource_type) => resources.push(resource_type),
278 }
279 }
280
281 Some(BindlessDescriptor {
282 resources: Cow::Owned(resources),
283 buffers: Cow::Owned(buffers),
284 index_tables: Cow::Owned(index_tables),
285 })
286 }
287}
288
289impl<B: Material, E: MaterialExtension> Material for ExtendedMaterial<B, E> {
290 fn vertex_shader() -> ShaderRef {
291 match E::vertex_shader() {
292 ShaderRef::Default => B::vertex_shader(),
293 specified => specified,
294 }
295 }
296
297 fn fragment_shader() -> ShaderRef {
298 match E::fragment_shader() {
299 ShaderRef::Default => B::fragment_shader(),
300 specified => specified,
301 }
302 }
303
304 fn alpha_mode(&self) -> AlphaMode {
305 match E::alpha_mode() {
306 Some(specified) => specified,
307 None => B::alpha_mode(&self.base),
308 }
309 }
310
311 fn opaque_render_method(&self) -> crate::OpaqueRendererMethod {
312 B::opaque_render_method(&self.base)
313 }
314
315 fn depth_bias(&self) -> f32 {
316 B::depth_bias(&self.base)
317 }
318
319 fn reads_view_transmission_texture(&self) -> bool {
320 B::reads_view_transmission_texture(&self.base)
321 }
322
323 fn prepass_vertex_shader() -> ShaderRef {
324 match E::prepass_vertex_shader() {
325 ShaderRef::Default => B::prepass_vertex_shader(),
326 specified => specified,
327 }
328 }
329
330 fn prepass_fragment_shader() -> ShaderRef {
331 match E::prepass_fragment_shader() {
332 ShaderRef::Default => B::prepass_fragment_shader(),
333 specified => specified,
334 }
335 }
336
337 fn deferred_vertex_shader() -> ShaderRef {
338 match E::deferred_vertex_shader() {
339 ShaderRef::Default => B::deferred_vertex_shader(),
340 specified => specified,
341 }
342 }
343
344 fn deferred_fragment_shader() -> ShaderRef {
345 match E::deferred_fragment_shader() {
346 ShaderRef::Default => B::deferred_fragment_shader(),
347 specified => specified,
348 }
349 }
350
351 #[cfg(feature = "meshlet")]
352 fn meshlet_mesh_fragment_shader() -> ShaderRef {
353 match E::meshlet_mesh_fragment_shader() {
354 ShaderRef::Default => B::meshlet_mesh_fragment_shader(),
355 specified => specified,
356 }
357 }
358
359 #[cfg(feature = "meshlet")]
360 fn meshlet_mesh_prepass_fragment_shader() -> ShaderRef {
361 match E::meshlet_mesh_prepass_fragment_shader() {
362 ShaderRef::Default => B::meshlet_mesh_prepass_fragment_shader(),
363 specified => specified,
364 }
365 }
366
367 #[cfg(feature = "meshlet")]
368 fn meshlet_mesh_deferred_fragment_shader() -> ShaderRef {
369 match E::meshlet_mesh_deferred_fragment_shader() {
370 ShaderRef::Default => B::meshlet_mesh_deferred_fragment_shader(),
371 specified => specified,
372 }
373 }
374
375 fn specialize(
376 pipeline: &MaterialPipeline<Self>,
377 descriptor: &mut RenderPipelineDescriptor,
378 layout: &MeshVertexBufferLayoutRef,
379 key: MaterialPipelineKey<Self>,
380 ) -> Result<(), SpecializedMeshPipelineError> {
381 let MaterialPipeline::<Self> {
383 mesh_pipeline,
384 material_layout,
385 vertex_shader,
386 fragment_shader,
387 bindless,
388 ..
389 } = pipeline.clone();
390 let base_pipeline = MaterialPipeline::<B> {
391 mesh_pipeline,
392 material_layout,
393 vertex_shader,
394 fragment_shader,
395 bindless,
396 marker: Default::default(),
397 };
398 let base_key = MaterialPipelineKey::<B> {
399 mesh_key: key.mesh_key,
400 bind_group_data: key.bind_group_data.0,
401 };
402 B::specialize(&base_pipeline, descriptor, layout, base_key)?;
403
404 let MaterialPipeline::<Self> {
406 mesh_pipeline,
407 material_layout,
408 vertex_shader,
409 fragment_shader,
410 bindless,
411 ..
412 } = pipeline.clone();
413
414 E::specialize(
415 &MaterialExtensionPipeline {
416 mesh_pipeline,
417 material_layout,
418 vertex_shader,
419 fragment_shader,
420 bindless,
421 },
422 descriptor,
423 layout,
424 MaterialExtensionKey {
425 mesh_key: key.mesh_key,
426 bind_group_data: key.bind_group_data.1,
427 },
428 )
429 }
430}