1use alloc::borrow::Cow;
2
3use bevy_asset::Asset;
4use bevy_ecs::system::SystemParamItem;
5use bevy_mesh::MeshVertexBufferLayoutRef;
6use bevy_platform::{collections::HashSet, hash::FixedHasher};
7use bevy_reflect::{impl_type_path, Reflect};
8use bevy_render::{
9 alpha::AlphaMode,
10 render_resource::{
11 AsBindGroup, AsBindGroupError, BindGroupLayout, BindGroupLayoutEntry, BindlessDescriptor,
12 BindlessResourceType, BindlessSlabResourceLimit, RenderPipelineDescriptor,
13 SpecializedMeshPipelineError, UnpreparedBindGroup,
14 },
15 renderer::RenderDevice,
16};
17use bevy_shader::ShaderRef;
18
19use crate::{Material, MaterialPipeline, MaterialPipelineKey, MeshPipeline, MeshPipelineKey};
20
21pub struct MaterialExtensionPipeline {
22 pub mesh_pipeline: MeshPipeline,
23}
24
25pub struct MaterialExtensionKey<E: MaterialExtension> {
26 pub mesh_key: MeshPipelineKey,
27 pub bind_group_data: E::Data,
28}
29
30pub trait MaterialExtension: Asset + AsBindGroup + Clone + Sized {
34 fn vertex_shader() -> ShaderRef {
37 ShaderRef::Default
38 }
39
40 fn fragment_shader() -> ShaderRef {
43 ShaderRef::Default
44 }
45
46 fn alpha_mode() -> Option<AlphaMode> {
48 None
49 }
50
51 #[inline]
54 fn enable_prepass() -> bool {
55 true
56 }
57
58 #[inline]
60 fn enable_shadows() -> bool {
61 true
62 }
63
64 fn prepass_vertex_shader() -> ShaderRef {
67 ShaderRef::Default
68 }
69
70 fn prepass_fragment_shader() -> ShaderRef {
73 ShaderRef::Default
74 }
75
76 fn deferred_vertex_shader() -> ShaderRef {
79 ShaderRef::Default
80 }
81
82 fn deferred_fragment_shader() -> ShaderRef {
85 ShaderRef::Default
86 }
87
88 #[cfg(feature = "meshlet")]
91 fn meshlet_mesh_fragment_shader() -> ShaderRef {
92 ShaderRef::Default
93 }
94
95 #[cfg(feature = "meshlet")]
98 fn meshlet_mesh_prepass_fragment_shader() -> ShaderRef {
99 ShaderRef::Default
100 }
101
102 #[cfg(feature = "meshlet")]
105 fn meshlet_mesh_deferred_fragment_shader() -> ShaderRef {
106 ShaderRef::Default
107 }
108
109 #[expect(
113 unused_variables,
114 reason = "The parameters here are intentionally unused by the default implementation; however, putting underscores here will result in the underscores being copied by rust-analyzer's tab completion."
115 )]
116 #[inline]
117 fn specialize(
118 pipeline: &MaterialExtensionPipeline,
119 descriptor: &mut RenderPipelineDescriptor,
120 layout: &MeshVertexBufferLayoutRef,
121 key: MaterialExtensionKey<Self>,
122 ) -> Result<(), SpecializedMeshPipelineError> {
123 Ok(())
124 }
125}
126
127#[derive(Asset, Clone, Debug, Reflect)]
143#[reflect(type_path = false)]
144#[reflect(Clone)]
145pub struct ExtendedMaterial<B: Material, E: MaterialExtension> {
146 pub base: B,
147 pub extension: E,
148}
149
150impl<B, E> Default for ExtendedMaterial<B, E>
151where
152 B: Material + Default,
153 E: MaterialExtension + Default,
154{
155 fn default() -> Self {
156 Self {
157 base: B::default(),
158 extension: E::default(),
159 }
160 }
161}
162
/// Bind group data of an [`ExtendedMaterial`]: the base material's data
/// followed by the extension's data.
///
/// `repr(C, packed)` keeps the two fields in declaration order with no
/// padding between or after them.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[repr(C, packed)]
pub struct MaterialExtensionBindGroupData<B, E> {
    /// Bind group data produced by the base material.
    pub base: B,
    /// Bind group data produced by the extension.
    pub extension: E,
}
169
170impl_type_path!((in bevy_pbr::extended_material) ExtendedMaterial<B: Material, E: MaterialExtension>);
173
174impl<B: Material, E: MaterialExtension> AsBindGroup for ExtendedMaterial<B, E> {
175 type Data = MaterialExtensionBindGroupData<B::Data, E::Data>;
176 type Param = (<B as AsBindGroup>::Param, <E as AsBindGroup>::Param);
177
178 fn bindless_slot_count() -> Option<BindlessSlabResourceLimit> {
179 match (B::bindless_slot_count()?, E::bindless_slot_count()?) {
183 (BindlessSlabResourceLimit::Auto, BindlessSlabResourceLimit::Auto) => {
184 Some(BindlessSlabResourceLimit::Auto)
185 }
186 (BindlessSlabResourceLimit::Auto, BindlessSlabResourceLimit::Custom(limit))
187 | (BindlessSlabResourceLimit::Custom(limit), BindlessSlabResourceLimit::Auto) => {
188 Some(BindlessSlabResourceLimit::Custom(limit))
189 }
190 (
191 BindlessSlabResourceLimit::Custom(base_limit),
192 BindlessSlabResourceLimit::Custom(extended_limit),
193 ) => Some(BindlessSlabResourceLimit::Custom(
194 base_limit.min(extended_limit),
195 )),
196 }
197 }
198
199 fn bindless_supported(render_device: &RenderDevice) -> bool {
200 B::bindless_supported(render_device) && E::bindless_supported(render_device)
201 }
202
203 fn label() -> &'static str {
204 E::label()
205 }
206
207 fn bind_group_data(&self) -> Self::Data {
208 MaterialExtensionBindGroupData {
209 base: self.base.bind_group_data(),
210 extension: self.extension.bind_group_data(),
211 }
212 }
213
214 fn unprepared_bind_group(
215 &self,
216 layout: &BindGroupLayout,
217 render_device: &RenderDevice,
218 (base_param, extended_param): &mut SystemParamItem<'_, '_, Self::Param>,
219 mut force_non_bindless: bool,
220 ) -> Result<UnpreparedBindGroup, AsBindGroupError> {
221 force_non_bindless = force_non_bindless || Self::bindless_slot_count().is_none();
222
223 let UnpreparedBindGroup { mut bindings } = B::unprepared_bind_group(
225 &self.base,
226 layout,
227 render_device,
228 base_param,
229 force_non_bindless,
230 )?;
231 let UnpreparedBindGroup {
232 bindings: extension_bindings,
233 } = E::unprepared_bind_group(
234 &self.extension,
235 layout,
236 render_device,
237 extended_param,
238 force_non_bindless,
239 )?;
240
241 bindings.extend(extension_bindings.0);
242
243 Ok(UnpreparedBindGroup { bindings })
244 }
245
246 fn bind_group_layout_entries(
247 render_device: &RenderDevice,
248 mut force_non_bindless: bool,
249 ) -> Vec<BindGroupLayoutEntry>
250 where
251 Self: Sized,
252 {
253 force_non_bindless = force_non_bindless || Self::bindless_slot_count().is_none();
254
255 let base_entries = B::bind_group_layout_entries(render_device, force_non_bindless);
261 let extension_entries = E::bind_group_layout_entries(render_device, force_non_bindless);
262
263 let mut seen_bindings = HashSet::<u32>::with_hasher(FixedHasher);
264
265 base_entries
266 .into_iter()
267 .chain(extension_entries)
268 .filter(|entry| seen_bindings.insert(entry.binding))
269 .collect()
270 }
271
272 fn bindless_descriptor() -> Option<BindlessDescriptor> {
273 let base_bindless_descriptor = B::bindless_descriptor()?;
275 let extended_bindless_descriptor = E::bindless_descriptor()?;
276
277 let mut buffers = base_bindless_descriptor.buffers.to_vec();
280 let mut index_tables = base_bindless_descriptor.index_tables.to_vec();
281
282 buffers.extend(extended_bindless_descriptor.buffers.iter().cloned());
283 index_tables.extend(extended_bindless_descriptor.index_tables.iter().cloned());
284
285 let max_bindless_index = base_bindless_descriptor
289 .resources
290 .len()
291 .max(extended_bindless_descriptor.resources.len());
292 let mut resources = Vec::with_capacity(max_bindless_index);
293 for bindless_index in 0..max_bindless_index {
294 match base_bindless_descriptor.resources.get(bindless_index) {
297 None | Some(&BindlessResourceType::None) => resources.push(
298 extended_bindless_descriptor
299 .resources
300 .get(bindless_index)
301 .copied()
302 .unwrap_or(BindlessResourceType::None),
303 ),
304 Some(&resource_type) => resources.push(resource_type),
305 }
306 }
307
308 Some(BindlessDescriptor {
309 resources: Cow::Owned(resources),
310 buffers: Cow::Owned(buffers),
311 index_tables: Cow::Owned(index_tables),
312 })
313 }
314}
315
316impl<B: Material, E: MaterialExtension> Material for ExtendedMaterial<B, E> {
317 fn vertex_shader() -> ShaderRef {
318 match E::vertex_shader() {
319 ShaderRef::Default => B::vertex_shader(),
320 specified => specified,
321 }
322 }
323
324 fn fragment_shader() -> ShaderRef {
325 match E::fragment_shader() {
326 ShaderRef::Default => B::fragment_shader(),
327 specified => specified,
328 }
329 }
330
331 fn alpha_mode(&self) -> AlphaMode {
332 match E::alpha_mode() {
333 Some(specified) => specified,
334 None => B::alpha_mode(&self.base),
335 }
336 }
337
338 fn opaque_render_method(&self) -> crate::OpaqueRendererMethod {
339 B::opaque_render_method(&self.base)
340 }
341
342 fn depth_bias(&self) -> f32 {
343 B::depth_bias(&self.base)
344 }
345
346 fn reads_view_transmission_texture(&self) -> bool {
347 B::reads_view_transmission_texture(&self.base)
348 }
349
350 fn enable_prepass() -> bool {
351 E::enable_prepass()
352 }
353
354 fn enable_shadows() -> bool {
355 E::enable_shadows()
356 }
357
358 fn prepass_vertex_shader() -> ShaderRef {
359 match E::prepass_vertex_shader() {
360 ShaderRef::Default => B::prepass_vertex_shader(),
361 specified => specified,
362 }
363 }
364
365 fn prepass_fragment_shader() -> ShaderRef {
366 match E::prepass_fragment_shader() {
367 ShaderRef::Default => B::prepass_fragment_shader(),
368 specified => specified,
369 }
370 }
371
372 fn deferred_vertex_shader() -> ShaderRef {
373 match E::deferred_vertex_shader() {
374 ShaderRef::Default => B::deferred_vertex_shader(),
375 specified => specified,
376 }
377 }
378
379 fn deferred_fragment_shader() -> ShaderRef {
380 match E::deferred_fragment_shader() {
381 ShaderRef::Default => B::deferred_fragment_shader(),
382 specified => specified,
383 }
384 }
385
386 #[cfg(feature = "meshlet")]
387 fn meshlet_mesh_fragment_shader() -> ShaderRef {
388 match E::meshlet_mesh_fragment_shader() {
389 ShaderRef::Default => B::meshlet_mesh_fragment_shader(),
390 specified => specified,
391 }
392 }
393
394 #[cfg(feature = "meshlet")]
395 fn meshlet_mesh_prepass_fragment_shader() -> ShaderRef {
396 match E::meshlet_mesh_prepass_fragment_shader() {
397 ShaderRef::Default => B::meshlet_mesh_prepass_fragment_shader(),
398 specified => specified,
399 }
400 }
401
402 #[cfg(feature = "meshlet")]
403 fn meshlet_mesh_deferred_fragment_shader() -> ShaderRef {
404 match E::meshlet_mesh_deferred_fragment_shader() {
405 ShaderRef::Default => B::meshlet_mesh_deferred_fragment_shader(),
406 specified => specified,
407 }
408 }
409
410 fn specialize(
411 pipeline: &MaterialPipeline,
412 descriptor: &mut RenderPipelineDescriptor,
413 layout: &MeshVertexBufferLayoutRef,
414 key: MaterialPipelineKey<Self>,
415 ) -> Result<(), SpecializedMeshPipelineError> {
416 let base_key = MaterialPipelineKey::<B> {
418 mesh_key: key.mesh_key,
419 bind_group_data: key.bind_group_data.base,
420 };
421 B::specialize(pipeline, descriptor, layout, base_key)?;
422
423 E::specialize(
425 &MaterialExtensionPipeline {
426 mesh_pipeline: pipeline.mesh_pipeline.clone(),
427 },
428 descriptor,
429 layout,
430 MaterialExtensionKey {
431 mesh_key: key.mesh_key,
432 bind_group_data: key.bind_group_data.extension,
433 },
434 )
435 }
436}