1mod conv;
111mod help;
112mod keywords;
113mod ray;
114mod storage;
115mod writer;
116
117use alloc::{string::String, vec::Vec};
118use core::fmt::Error as FmtError;
119
120use thiserror::Error;
121
122use crate::{back, ir, proc};
123
/// An HLSL register location for a resource: a register space plus a register index
/// (as in `register(t0, space1)`), with a few per-binding extras.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct BindTarget {
    /// Register space.
    pub space: u8,
    /// Register index within `space`.
    pub register: u32,
    /// Optional size for binding arrays.
    ///
    /// NOTE(review): meaning inferred from the name only — confirm against the writer.
    pub binding_array_size: Option<u32>,
    /// Optional `u32` index, presumably a key into
    /// [`Options::dynamic_storage_buffer_offsets_targets`] (which is keyed by `u32`).
    pub dynamic_storage_buffer_offsets_index: Option<u32>,
    /// Per-binding indexing-restriction flag; `false` when absent from serialized input.
    #[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
    pub restrict_indexing: bool,
}
143
/// Register location of a buffer that, presumably, holds dynamic storage buffer offsets
/// (see [`DynamicStorageBufferOffsetsTargets`]).
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct OffsetsBindTarget {
    /// Register space.
    pub space: u8,
    /// Register index within `space`.
    pub register: u32,
    /// Size value for this buffer.
    ///
    /// NOTE(review): units/meaning inferred from the name only (likely a count of offsets) —
    /// confirm against the writer.
    pub size: u32,
}
153
154#[cfg(any(feature = "serialize", feature = "deserialize"))]
155#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
156#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
157struct BindingMapSerialization {
158 resource_binding: crate::ResourceBinding,
159 bind_target: BindTarget,
160}
161
162#[cfg(feature = "deserialize")]
163fn deserialize_binding_map<'de, D>(deserializer: D) -> Result<BindingMap, D::Error>
164where
165 D: serde::Deserializer<'de>,
166{
167 use serde::Deserialize;
168
169 let vec = Vec::<BindingMapSerialization>::deserialize(deserializer)?;
170 let mut map = BindingMap::default();
171 for item in vec {
172 map.insert(item.resource_binding, item.bind_target);
173 }
174 Ok(map)
175}
176
177pub type BindingMap = alloc::collections::BTreeMap<crate::ResourceBinding, BindTarget>;
179
/// HLSL shader model targeted by the generated code.
///
/// Variants are declared in ascending version order, so the derived `PartialOrd`
/// compares shader models by version.
#[allow(non_snake_case, non_camel_case_types)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub enum ShaderModel {
    V5_0,
    V5_1,
    V6_0,
    V6_1,
    V6_2,
    V6_3,
    V6_4,
    V6_5,
    V6_6,
    V6_7,
}

impl ShaderModel {
    /// Returns the `major_minor` version suffix, e.g. `"6_0"` for [`ShaderModel::V6_0`].
    pub const fn to_str(self) -> &'static str {
        match self {
            ShaderModel::V5_0 => "5_0",
            ShaderModel::V5_1 => "5_1",
            ShaderModel::V6_0 => "6_0",
            ShaderModel::V6_1 => "6_1",
            ShaderModel::V6_2 => "6_2",
            ShaderModel::V6_3 => "6_3",
            ShaderModel::V6_4 => "6_4",
            ShaderModel::V6_5 => "6_5",
            ShaderModel::V6_6 => "6_6",
            ShaderModel::V6_7 => "6_7",
        }
    }
}
214
215impl crate::ShaderStage {
216 pub const fn to_hlsl_str(self) -> &'static str {
217 match self {
218 Self::Vertex => "vs",
219 Self::Fragment => "ps",
220 Self::Compute => "cs",
221 Self::Task | Self::Mesh => unreachable!(),
222 }
223 }
224}
225
226impl crate::ImageDimension {
227 const fn to_hlsl_str(self) -> &'static str {
228 match self {
229 Self::D1 => "1D",
230 Self::D2 => "2D",
231 Self::D3 => "3D",
232 Self::Cube => "Cube",
233 }
234 }
235}
236
/// Key for a sampler index buffer: one per bind group.
///
/// `Ord` is derived so the key can be used in [`SamplerIndexBufferBindingMap`] (a `BTreeMap`).
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct SamplerIndexBufferKey {
    /// Bind group number.
    pub group: u32,
}
243
244#[derive(Clone, Debug, Hash, PartialEq, Eq)]
245#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
246#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
247#[cfg_attr(feature = "deserialize", serde(default))]
248pub struct SamplerHeapBindTargets {
249 pub standard_samplers: BindTarget,
250 pub comparison_samplers: BindTarget,
251}
252
253impl Default for SamplerHeapBindTargets {
254 fn default() -> Self {
255 Self {
256 standard_samplers: BindTarget {
257 space: 0,
258 register: 0,
259 binding_array_size: None,
260 dynamic_storage_buffer_offsets_index: None,
261 restrict_indexing: false,
262 },
263 comparison_samplers: BindTarget {
264 space: 1,
265 register: 0,
266 binding_array_size: None,
267 dynamic_storage_buffer_offsets_index: None,
268 restrict_indexing: false,
269 },
270 }
271 }
272}
273
274#[cfg(any(feature = "serialize", feature = "deserialize"))]
275#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
276#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
277struct SamplerIndexBufferBindingSerialization {
278 group: u32,
279 bind_target: BindTarget,
280}
281
282#[cfg(feature = "deserialize")]
283fn deserialize_sampler_index_buffer_bindings<'de, D>(
284 deserializer: D,
285) -> Result<SamplerIndexBufferBindingMap, D::Error>
286where
287 D: serde::Deserializer<'de>,
288{
289 use serde::Deserialize;
290
291 let vec = Vec::<SamplerIndexBufferBindingSerialization>::deserialize(deserializer)?;
292 let mut map = SamplerIndexBufferBindingMap::default();
293 for item in vec {
294 map.insert(
295 SamplerIndexBufferKey { group: item.group },
296 item.bind_target,
297 );
298 }
299 Ok(map)
300}
301
302pub type SamplerIndexBufferBindingMap =
304 alloc::collections::BTreeMap<SamplerIndexBufferKey, BindTarget>;
305
306#[cfg(any(feature = "serialize", feature = "deserialize"))]
307#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
308#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
309struct DynamicStorageBufferOffsetTargetSerialization {
310 index: u32,
311 bind_target: OffsetsBindTarget,
312}
313
314#[cfg(feature = "deserialize")]
315fn deserialize_storage_buffer_offsets<'de, D>(
316 deserializer: D,
317) -> Result<DynamicStorageBufferOffsetsTargets, D::Error>
318where
319 D: serde::Deserializer<'de>,
320{
321 use serde::Deserialize;
322
323 let vec = Vec::<DynamicStorageBufferOffsetTargetSerialization>::deserialize(deserializer)?;
324 let mut map = DynamicStorageBufferOffsetsTargets::default();
325 for item in vec {
326 map.insert(item.index, item.bind_target);
327 }
328 Ok(map)
329}
330
331pub type DynamicStorageBufferOffsetsTargets = alloc::collections::BTreeMap<u32, OffsetsBindTarget>;
332
333type BackendResult = Result<(), Error>;
335
336#[derive(Clone, Debug, PartialEq, thiserror::Error)]
337#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
338#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
339pub enum EntryPointError {
340 #[error("mapping of {0:?} is missing")]
341 MissingBinding(crate::ResourceBinding),
342}
343
344#[derive(Clone, Debug, Hash, PartialEq, Eq)]
346#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
347#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
348#[cfg_attr(feature = "deserialize", serde(default))]
349pub struct Options {
350 pub shader_model: ShaderModel,
352 #[cfg_attr(
354 feature = "deserialize",
355 serde(deserialize_with = "deserialize_binding_map")
356 )]
357 pub binding_map: BindingMap,
358 pub fake_missing_bindings: bool,
360 pub special_constants_binding: Option<BindTarget>,
363 pub push_constants_target: Option<BindTarget>,
365 pub sampler_heap_target: SamplerHeapBindTargets,
367 #[cfg_attr(
369 feature = "deserialize",
370 serde(deserialize_with = "deserialize_sampler_index_buffer_bindings")
371 )]
372 pub sampler_buffer_binding_map: SamplerIndexBufferBindingMap,
373 #[cfg_attr(
375 feature = "deserialize",
376 serde(deserialize_with = "deserialize_storage_buffer_offsets")
377 )]
378 pub dynamic_storage_buffer_offsets_targets: DynamicStorageBufferOffsetsTargets,
379 pub zero_initialize_workgroup_memory: bool,
381 pub restrict_indexing: bool,
383 pub force_loop_bounding: bool,
386}
387
388impl Default for Options {
389 fn default() -> Self {
390 Options {
391 shader_model: ShaderModel::V5_1,
392 binding_map: BindingMap::default(),
393 fake_missing_bindings: true,
394 special_constants_binding: None,
395 sampler_heap_target: SamplerHeapBindTargets::default(),
396 sampler_buffer_binding_map: alloc::collections::BTreeMap::default(),
397 push_constants_target: None,
398 dynamic_storage_buffer_offsets_targets: alloc::collections::BTreeMap::new(),
399 zero_initialize_workgroup_memory: true,
400 restrict_indexing: true,
401 force_loop_bounding: true,
402 }
403 }
404}
405
406impl Options {
407 fn resolve_resource_binding(
408 &self,
409 res_binding: &crate::ResourceBinding,
410 ) -> Result<BindTarget, EntryPointError> {
411 match self.binding_map.get(res_binding) {
412 Some(target) => Ok(*target),
413 None if self.fake_missing_bindings => Ok(BindTarget {
414 space: res_binding.group as u8,
415 register: res_binding.binding,
416 binding_array_size: None,
417 dynamic_storage_buffer_offsets_index: None,
418 restrict_indexing: false,
419 }),
420 None => Err(EntryPointError::MissingBinding(*res_binding)),
421 }
422 }
423}
424
425#[derive(Default)]
427pub struct ReflectionInfo {
428 pub entry_point_names: Vec<Result<String, EntryPointError>>,
435}
436
437#[derive(Debug, Default, Clone)]
439#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
440#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
441#[cfg_attr(feature = "deserialize", serde(default))]
442pub struct PipelineOptions {
443 pub entry_point: Option<(ir::ShaderStage, String)>,
451}
452
453#[derive(Error, Debug)]
454pub enum Error {
455 #[error(transparent)]
456 IoError(#[from] FmtError),
457 #[error("A scalar with an unsupported width was requested: {0:?}")]
458 UnsupportedScalar(crate::Scalar),
459 #[error("{0}")]
460 Unimplemented(String), #[error("{0}")]
462 Custom(String),
463 #[error("overrides should not be present at this stage")]
464 Override,
465 #[error(transparent)]
466 ResolveArraySizeError(#[from] proc::ResolveArraySizeError),
467 #[error("entry point with stage {0:?} and name '{1}' not found")]
468 EntryPointNotFound(ir::ShaderStage, String),
469}
470
471#[derive(PartialEq, Eq, Hash)]
472enum WrappedType {
473 ZeroValue(help::WrappedZeroValue),
474 ArrayLength(help::WrappedArrayLength),
475 ImageSample(help::WrappedImageSample),
476 ImageQuery(help::WrappedImageQuery),
477 ImageLoadScalar(crate::Scalar),
478 Constructor(help::WrappedConstructor),
479 StructMatrixAccess(help::WrappedStructMatrixAccess),
480 MatCx2(help::WrappedMatCx2),
481 Math(help::WrappedMath),
482 UnaryOp(help::WrappedUnaryOp),
483 BinaryOp(help::WrappedBinaryOp),
484 Cast(help::WrappedCast),
485}
486
487#[derive(Default)]
488struct Wrapped {
489 types: crate::FastHashSet<WrappedType>,
490 sampler_heaps: bool,
492 sampler_index_buffers: crate::FastHashMap<SamplerIndexBufferKey, String>,
494}
495
496impl Wrapped {
497 fn insert(&mut self, r#type: WrappedType) -> bool {
498 self.types.insert(r#type)
499 }
500
501 fn clear(&mut self) {
502 self.types.clear();
503 }
504}
505
506pub struct FragmentEntryPoint<'a> {
515 module: &'a crate::Module,
516 func: &'a crate::Function,
517}
518
519impl<'a> FragmentEntryPoint<'a> {
520 pub fn new(module: &'a crate::Module, ep_name: &'a str) -> Option<Self> {
523 module
524 .entry_points
525 .iter()
526 .find(|ep| ep.name == ep_name)
527 .filter(|ep| ep.stage == crate::ShaderStage::Fragment)
528 .map(|ep| Self {
529 module,
530 func: &ep.function,
531 })
532 }
533}
534
535pub struct Writer<'a, W> {
536 out: W,
537 names: crate::FastHashMap<proc::NameKey, String>,
538 namer: proc::Namer,
539 options: &'a Options,
541 pipeline_options: &'a PipelineOptions,
543 entry_point_io: crate::FastHashMap<usize, writer::EntryPointInterface>,
545 named_expressions: crate::NamedExpressions,
547 wrapped: Wrapped,
548 written_committed_intersection: bool,
549 written_candidate_intersection: bool,
550 continue_ctx: back::continue_forward::ContinueCtx,
551
552 temp_access_chain: Vec<storage::SubAccess>,
570 need_bake_expressions: back::NeedBakeExpressions,
571}