1use alloc::borrow::Cow;
2
3use bevy_asset::Asset;
4use bevy_ecs::system::SystemParamItem;
5use bevy_mesh::MeshVertexBufferLayoutRef;
6use bevy_platform::{collections::HashSet, hash::FixedHasher};
7use bevy_reflect::{impl_type_path, Reflect};
8use bevy_render::{
9 alpha::AlphaMode,
10 render_resource::{
11 AsBindGroup, AsBindGroupError, BindGroupLayout, BindGroupLayoutEntry, BindlessDescriptor,
12 BindlessResourceType, BindlessSlabResourceLimit, RenderPipelineDescriptor,
13 SpecializedMeshPipelineError, UnpreparedBindGroup,
14 },
15 renderer::RenderDevice,
16};
17use bevy_shader::ShaderRef;
18
19use crate::{Material, MaterialPipeline, MaterialPipelineKey, MeshPipeline, MeshPipelineKey};
20
/// The pipeline data handed to [`MaterialExtension::specialize`].
///
/// [`ExtendedMaterial`]'s `specialize` builds one of these for every call by cloning the
/// mesh pipeline out of the [`MaterialPipeline`] it received.
pub struct MaterialExtensionPipeline {
    /// A clone of the [`MaterialPipeline`]'s mesh pipeline.
    pub mesh_pipeline: MeshPipeline,
}
24
/// The specialization key handed to [`MaterialExtension::specialize`].
pub struct MaterialExtensionKey<E: MaterialExtension> {
    /// The mesh pipeline key; this is the same key the base material's `specialize` receives.
    pub mesh_key: MeshPipelineKey,
    /// The extension's own bind group data (its [`AsBindGroup::Data`]).
    pub bind_group_data: E::Data,
}
29
/// An extension layered on top of a base [`Material`]; the two are combined with
/// [`ExtendedMaterial`].
///
/// Every shader function defaults to [`ShaderRef::Default`]. For each stage,
/// [`ExtendedMaterial`] uses the extension's shader when one is specified, and otherwise
/// falls back to the base material's shader for that stage.
pub trait MaterialExtension: Asset + AsBindGroup + Clone + Sized {
    /// Returns this extension's vertex shader.
    ///
    /// The default, [`ShaderRef::Default`], makes [`ExtendedMaterial`] use the base
    /// material's vertex shader.
    fn vertex_shader() -> ShaderRef {
        ShaderRef::Default
    }

    /// Returns this extension's fragment shader.
    ///
    /// The default, [`ShaderRef::Default`], makes [`ExtendedMaterial`] use the base
    /// material's fragment shader.
    fn fragment_shader() -> ShaderRef {
        ShaderRef::Default
    }

    /// Returns an alpha mode that overrides the base material's, or `None` (the default)
    /// to keep the base material's alpha mode.
    ///
    /// This takes no `self`, so one override applies to every instance of the extension.
    fn alpha_mode() -> Option<AlphaMode> {
        None
    }

    /// Returns this extension's prepass vertex shader.
    ///
    /// The default, [`ShaderRef::Default`], falls back to the base material's.
    fn prepass_vertex_shader() -> ShaderRef {
        ShaderRef::Default
    }

    /// Returns this extension's prepass fragment shader.
    ///
    /// The default, [`ShaderRef::Default`], falls back to the base material's.
    fn prepass_fragment_shader() -> ShaderRef {
        ShaderRef::Default
    }

    /// Returns this extension's deferred vertex shader.
    ///
    /// The default, [`ShaderRef::Default`], falls back to the base material's.
    fn deferred_vertex_shader() -> ShaderRef {
        ShaderRef::Default
    }

    /// Returns this extension's deferred fragment shader.
    ///
    /// The default, [`ShaderRef::Default`], falls back to the base material's.
    fn deferred_fragment_shader() -> ShaderRef {
        ShaderRef::Default
    }

    /// Returns this extension's meshlet mesh fragment shader (only with the `meshlet`
    /// feature). The default, [`ShaderRef::Default`], falls back to the base material's.
    #[cfg(feature = "meshlet")]
    fn meshlet_mesh_fragment_shader() -> ShaderRef {
        ShaderRef::Default
    }

    /// Returns this extension's meshlet mesh prepass fragment shader (only with the
    /// `meshlet` feature). The default, [`ShaderRef::Default`], falls back to the base
    /// material's.
    #[cfg(feature = "meshlet")]
    fn meshlet_mesh_prepass_fragment_shader() -> ShaderRef {
        ShaderRef::Default
    }

    /// Returns this extension's meshlet mesh deferred fragment shader (only with the
    /// `meshlet` feature). The default, [`ShaderRef::Default`], falls back to the base
    /// material's.
    #[cfg(feature = "meshlet")]
    fn meshlet_mesh_deferred_fragment_shader() -> ShaderRef {
        ShaderRef::Default
    }

    /// Customizes the render pipeline descriptor for this extension.
    ///
    /// [`ExtendedMaterial`] calls this only after the base material's `specialize` has
    /// returned `Ok`, and it passes the same `descriptor`. Changes made here are therefore
    /// applied on top of the base material's changes. `key` carries the mesh pipeline key
    /// and this extension's bind group data.
    ///
    /// The default implementation leaves the descriptor untouched and returns `Ok(())`.
    #[expect(
        unused_variables,
        reason = "The parameters here are intentionally unused by the default implementation; however, putting underscores here will result in the underscores being copied by rust-analyzer's tab completion."
    )]
    #[inline]
    fn specialize(
        pipeline: &MaterialExtensionPipeline,
        descriptor: &mut RenderPipelineDescriptor,
        layout: &MeshVertexBufferLayoutRef,
        key: MaterialExtensionKey<Self>,
    ) -> Result<(), SpecializedMeshPipelineError> {
        Ok(())
    }
}
113
/// A material that layers a [`MaterialExtension`] `E` on top of a base [`Material`] `B`.
///
/// For each shader stage, the extension's shader wins when it specifies one; otherwise the
/// base material's shader is used. The bind group layout entries of both halves are merged
/// into a single layout.
#[derive(Asset, Clone, Debug, Reflect)]
// `TypePath` is not derived; it is implemented by hand with `impl_type_path!`.
#[reflect(type_path = false)]
#[reflect(Clone)]
pub struct ExtendedMaterial<B: Material, E: MaterialExtension> {
    /// The base material.
    pub base: B,
    /// The extension applied on top of `base`.
    pub extension: E,
}
136
137impl<B, E> Default for ExtendedMaterial<B, E>
138where
139 B: Material + Default,
140 E: MaterialExtension + Default,
141{
142 fn default() -> Self {
143 Self {
144 base: B::default(),
145 extension: E::default(),
146 }
147 }
148}
149
/// The combined bind group data ([`AsBindGroup::Data`]) of an [`ExtendedMaterial`]: the
/// base material's data followed by the extension's.
// NOTE(review): presumably `packed` is there so the combined data carries no padding bytes;
// confirm the intent before changing the representation.
// Fields of a `packed` struct may be misaligned, so read them by value rather than taking
// references to them.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[repr(C, packed)]
pub struct MaterialExtensionBindGroupData<B, E> {
    /// Bind group data of the base material.
    pub base: B,
    /// Bind group data of the extension.
    pub extension: E,
}
156
// Implements `TypePath` for `ExtendedMaterial` by hand, under the
// `bevy_pbr::extended_material` module path. The derived impl is switched off with
// `#[reflect(type_path = false)]` on the struct.
impl_type_path!((in bevy_pbr::extended_material) ExtendedMaterial<B: Material, E: MaterialExtension>);
160
161impl<B: Material, E: MaterialExtension> AsBindGroup for ExtendedMaterial<B, E> {
162 type Data = MaterialExtensionBindGroupData<B::Data, E::Data>;
163 type Param = (<B as AsBindGroup>::Param, <E as AsBindGroup>::Param);
164
165 fn bindless_slot_count() -> Option<BindlessSlabResourceLimit> {
166 match (B::bindless_slot_count()?, E::bindless_slot_count()?) {
170 (BindlessSlabResourceLimit::Auto, BindlessSlabResourceLimit::Auto) => {
171 Some(BindlessSlabResourceLimit::Auto)
172 }
173 (BindlessSlabResourceLimit::Auto, BindlessSlabResourceLimit::Custom(limit))
174 | (BindlessSlabResourceLimit::Custom(limit), BindlessSlabResourceLimit::Auto) => {
175 Some(BindlessSlabResourceLimit::Custom(limit))
176 }
177 (
178 BindlessSlabResourceLimit::Custom(base_limit),
179 BindlessSlabResourceLimit::Custom(extended_limit),
180 ) => Some(BindlessSlabResourceLimit::Custom(
181 base_limit.min(extended_limit),
182 )),
183 }
184 }
185
186 fn bind_group_data(&self) -> Self::Data {
187 MaterialExtensionBindGroupData {
188 base: self.base.bind_group_data(),
189 extension: self.extension.bind_group_data(),
190 }
191 }
192
193 fn unprepared_bind_group(
194 &self,
195 layout: &BindGroupLayout,
196 render_device: &RenderDevice,
197 (base_param, extended_param): &mut SystemParamItem<'_, '_, Self::Param>,
198 mut force_non_bindless: bool,
199 ) -> Result<UnpreparedBindGroup, AsBindGroupError> {
200 force_non_bindless = force_non_bindless || Self::bindless_slot_count().is_none();
201
202 let UnpreparedBindGroup { mut bindings } = B::unprepared_bind_group(
204 &self.base,
205 layout,
206 render_device,
207 base_param,
208 force_non_bindless,
209 )?;
210 let extended_bindgroup = E::unprepared_bind_group(
211 &self.extension,
212 layout,
213 render_device,
214 extended_param,
215 force_non_bindless,
216 )?;
217
218 bindings.extend(extended_bindgroup.bindings.0);
219
220 Ok(UnpreparedBindGroup { bindings })
221 }
222
223 fn bind_group_layout_entries(
224 render_device: &RenderDevice,
225 mut force_non_bindless: bool,
226 ) -> Vec<BindGroupLayoutEntry>
227 where
228 Self: Sized,
229 {
230 force_non_bindless = force_non_bindless || Self::bindless_slot_count().is_none();
231
232 let mut entries = vec![];
238 let mut seen_bindings = HashSet::<_>::with_hasher(FixedHasher);
239 for entry in B::bind_group_layout_entries(render_device, force_non_bindless)
240 .into_iter()
241 .chain(E::bind_group_layout_entries(render_device, force_non_bindless).into_iter())
242 {
243 if seen_bindings.insert(entry.binding) {
244 entries.push(entry);
245 }
246 }
247 entries
248 }
249
250 fn bindless_descriptor() -> Option<BindlessDescriptor> {
251 let base_bindless_descriptor = B::bindless_descriptor()?;
253 let extended_bindless_descriptor = E::bindless_descriptor()?;
254
255 let mut buffers = base_bindless_descriptor.buffers.to_vec();
258 let mut index_tables = base_bindless_descriptor.index_tables.to_vec();
259
260 buffers.extend(extended_bindless_descriptor.buffers.iter().cloned());
261 index_tables.extend(extended_bindless_descriptor.index_tables.iter().cloned());
262
263 let max_bindless_index = base_bindless_descriptor
267 .resources
268 .len()
269 .max(extended_bindless_descriptor.resources.len());
270 let mut resources = Vec::with_capacity(max_bindless_index);
271 for bindless_index in 0..max_bindless_index {
272 match base_bindless_descriptor.resources.get(bindless_index) {
275 None | Some(&BindlessResourceType::None) => resources.push(
276 extended_bindless_descriptor
277 .resources
278 .get(bindless_index)
279 .copied()
280 .unwrap_or(BindlessResourceType::None),
281 ),
282 Some(&resource_type) => resources.push(resource_type),
283 }
284 }
285
286 Some(BindlessDescriptor {
287 resources: Cow::Owned(resources),
288 buffers: Cow::Owned(buffers),
289 index_tables: Cow::Owned(index_tables),
290 })
291 }
292}
293
294impl<B: Material, E: MaterialExtension> Material for ExtendedMaterial<B, E> {
295 fn vertex_shader() -> ShaderRef {
296 match E::vertex_shader() {
297 ShaderRef::Default => B::vertex_shader(),
298 specified => specified,
299 }
300 }
301
302 fn fragment_shader() -> ShaderRef {
303 match E::fragment_shader() {
304 ShaderRef::Default => B::fragment_shader(),
305 specified => specified,
306 }
307 }
308
309 fn alpha_mode(&self) -> AlphaMode {
310 match E::alpha_mode() {
311 Some(specified) => specified,
312 None => B::alpha_mode(&self.base),
313 }
314 }
315
316 fn opaque_render_method(&self) -> crate::OpaqueRendererMethod {
317 B::opaque_render_method(&self.base)
318 }
319
320 fn depth_bias(&self) -> f32 {
321 B::depth_bias(&self.base)
322 }
323
324 fn reads_view_transmission_texture(&self) -> bool {
325 B::reads_view_transmission_texture(&self.base)
326 }
327
328 fn prepass_vertex_shader() -> ShaderRef {
329 match E::prepass_vertex_shader() {
330 ShaderRef::Default => B::prepass_vertex_shader(),
331 specified => specified,
332 }
333 }
334
335 fn prepass_fragment_shader() -> ShaderRef {
336 match E::prepass_fragment_shader() {
337 ShaderRef::Default => B::prepass_fragment_shader(),
338 specified => specified,
339 }
340 }
341
342 fn deferred_vertex_shader() -> ShaderRef {
343 match E::deferred_vertex_shader() {
344 ShaderRef::Default => B::deferred_vertex_shader(),
345 specified => specified,
346 }
347 }
348
349 fn deferred_fragment_shader() -> ShaderRef {
350 match E::deferred_fragment_shader() {
351 ShaderRef::Default => B::deferred_fragment_shader(),
352 specified => specified,
353 }
354 }
355
356 #[cfg(feature = "meshlet")]
357 fn meshlet_mesh_fragment_shader() -> ShaderRef {
358 match E::meshlet_mesh_fragment_shader() {
359 ShaderRef::Default => B::meshlet_mesh_fragment_shader(),
360 specified => specified,
361 }
362 }
363
364 #[cfg(feature = "meshlet")]
365 fn meshlet_mesh_prepass_fragment_shader() -> ShaderRef {
366 match E::meshlet_mesh_prepass_fragment_shader() {
367 ShaderRef::Default => B::meshlet_mesh_prepass_fragment_shader(),
368 specified => specified,
369 }
370 }
371
372 #[cfg(feature = "meshlet")]
373 fn meshlet_mesh_deferred_fragment_shader() -> ShaderRef {
374 match E::meshlet_mesh_deferred_fragment_shader() {
375 ShaderRef::Default => B::meshlet_mesh_deferred_fragment_shader(),
376 specified => specified,
377 }
378 }
379
380 fn specialize(
381 pipeline: &MaterialPipeline,
382 descriptor: &mut RenderPipelineDescriptor,
383 layout: &MeshVertexBufferLayoutRef,
384 key: MaterialPipelineKey<Self>,
385 ) -> Result<(), SpecializedMeshPipelineError> {
386 let base_key = MaterialPipelineKey::<B> {
388 mesh_key: key.mesh_key,
389 bind_group_data: key.bind_group_data.base,
390 };
391 B::specialize(pipeline, descriptor, layout, base_key)?;
392
393 E::specialize(
395 &MaterialExtensionPipeline {
396 mesh_pipeline: pipeline.mesh_pipeline.clone(),
397 },
398 descriptor,
399 layout,
400 MaterialExtensionKey {
401 mesh_key: key.mesh_key,
402 bind_group_data: key.bind_group_data.extension,
403 },
404 )
405 }
406}