1use alloc::{
2 format,
3 string::{String, ToString},
4 vec::Vec,
5};
6use core::{fmt, mem};
7
8use super::{
9 help,
10 help::{
11 WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
12 WrappedZeroValue,
13 },
14 storage::StoreValue,
15 BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel,
16};
17use crate::{
18 back::{self, get_entry_points, Baked},
19 common,
20 proc::{self, index, NameKey},
21 valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
22};
23
/// Semantic prefix for user-defined varyings (`LOC0`, `LOC1`, ...); see
/// `write_semantic` for the cases that use `SV_Target` instead.
const LOCATION_SEMANTIC: &str = "LOC";

/// Struct type name of the optional special-constants constant buffer.
const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
/// Variable name of the optional special-constants constant buffer.
const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
/// `int` field of the special-constants buffer.
const SPECIAL_FIRST_VERTEX: &str = "first_vertex";
/// `int` field of the special-constants buffer.
const SPECIAL_FIRST_INSTANCE: &str = "first_instance";
/// `uint` field of the special-constants buffer.
const SPECIAL_OTHER: &str = "other";

/// Heap that non-comparison samplers are fetched from in `write_global_sampler`.
pub(crate) const SAMPLER_HEAP_VAR: &str = "nagaSamplerHeap";
/// Heap that comparison samplers are fetched from in `write_global_sampler`.
pub(crate) const COMPARISON_SAMPLER_HEAP_VAR: &str = "nagaComparisonSamplerHeap";

// Identifiers of `naga_*` helper functions; `pub(crate)` so that other parts of
// the backend refer to exactly the same names.
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits";
pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
47
48struct EpStructMember {
49 name: String,
50 ty: Handle<crate::Type>,
51 binding: Option<crate::Binding>,
54 index: u32,
55}
56
57struct EntryPointBinding {
60 arg_name: String,
63 ty_name: String,
65 members: Vec<EpStructMember>,
67}
68
69pub(super) struct EntryPointInterface {
70 input: Option<EntryPointBinding>,
75 output: Option<EntryPointBinding>,
79}
80
81#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
82enum InterfaceKey {
83 Location(u32),
84 BuiltIn(crate::BuiltIn),
85 Other,
86}
87
88impl InterfaceKey {
89 const fn new(binding: Option<&crate::Binding>) -> Self {
90 match binding {
91 Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
92 Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
93 None => Self::Other,
94 }
95 }
96}
97
/// Direction of an entry-point interface struct.
#[derive(Clone, Copy, PartialEq)]
enum Io {
    /// Entry-point arguments.
    Input,
    /// Entry-point results.
    Output,
}
103
104const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
105 let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
106 return false;
107 };
108 matches!(
109 builtin,
110 crate::BuiltIn::SubgroupSize
111 | crate::BuiltIn::SubgroupInvocationId
112 | crate::BuiltIn::NumSubgroups
113 | crate::BuiltIn::SubgroupId
114 )
115}
116
/// Names needed to emit an access into a binding array of samplers.
///
/// NOTE(review): the field meanings below are inferred from the field names and
/// from the sampler-heap constants in this file. Confirm them against the code
/// that builds this struct.
struct BindingArraySamplerInfo {
    /// Name of the sampler index buffer; presumably the per-group buffer that
    /// maps binding registers to heap slots.
    sampler_index_buffer_name: String,
    /// Name of the constant holding the array's base index; presumably the
    /// `static const uint` emitted by `write_global_sampler` for binding arrays.
    binding_array_base_index_name: String,
    /// Name of the sampler heap variable; presumably `SAMPLER_HEAP_VAR` or
    /// `COMPARISON_SAMPLER_HEAP_VAR`.
    sampler_heap_name: &'static str,
}
126
127impl<'a, W: fmt::Write> super::Writer<'a, W> {
128 pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
129 Self {
130 out,
131 names: crate::FastHashMap::default(),
132 namer: proc::Namer::default(),
133 options,
134 pipeline_options,
135 entry_point_io: crate::FastHashMap::default(),
136 named_expressions: crate::NamedExpressions::default(),
137 wrapped: super::Wrapped::default(),
138 written_committed_intersection: false,
139 written_candidate_intersection: false,
140 continue_ctx: back::continue_forward::ContinueCtx::default(),
141 temp_access_chain: Vec::new(),
142 need_bake_expressions: Default::default(),
143 }
144 }
145
146 fn reset(&mut self, module: &Module) {
147 self.names.clear();
148 self.namer.reset(
149 module,
150 &super::keywords::RESERVED_SET,
151 super::keywords::RESERVED_CASE_INSENSITIVE,
152 super::keywords::RESERVED_PREFIXES,
153 &mut self.names,
154 );
155 self.entry_point_io.clear();
156 self.named_expressions.clear();
157 self.wrapped.clear();
158 self.written_committed_intersection = false;
159 self.written_candidate_intersection = false;
160 self.continue_ctx.clear();
161 self.need_bake_expressions.clear();
162 }
163
164 fn gen_force_bounded_loop_statements(
172 &mut self,
173 level: back::Level,
174 ) -> Option<(String, String)> {
175 if !self.options.force_loop_bounding {
176 return None;
177 }
178
179 let loop_bound_name = self.namer.call("loop_bound");
180 let max = u32::MAX;
181 let decl = format!("{level}uint2 {loop_bound_name} = uint2({max}u, {max}u);");
184 let level = level.next();
185 let break_and_inc = format!(
186 "{level}if (all({loop_bound_name} == uint2(0u, 0u))) {{ break; }}
187{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
188 );
189
190 Some((decl, break_and_inc))
191 }
192
193 fn update_expressions_to_bake(
198 &mut self,
199 module: &Module,
200 func: &crate::Function,
201 info: &valid::FunctionInfo,
202 ) {
203 use crate::Expression;
204 self.need_bake_expressions.clear();
205 for (exp_handle, expr) in func.expressions.iter() {
206 let expr_info = &info[exp_handle];
207 let min_ref_count = func.expressions[exp_handle].bake_ref_count();
208 if min_ref_count <= expr_info.ref_count {
209 self.need_bake_expressions.insert(exp_handle);
210 }
211
212 if let Expression::Math { fun, arg, arg1, .. } = *expr {
213 match fun {
214 crate::MathFunction::Asinh
215 | crate::MathFunction::Acosh
216 | crate::MathFunction::Atanh
217 | crate::MathFunction::Unpack2x16float
218 | crate::MathFunction::Unpack2x16snorm
219 | crate::MathFunction::Unpack2x16unorm
220 | crate::MathFunction::Unpack4x8snorm
221 | crate::MathFunction::Unpack4x8unorm
222 | crate::MathFunction::Unpack4xI8
223 | crate::MathFunction::Unpack4xU8
224 | crate::MathFunction::Pack2x16float
225 | crate::MathFunction::Pack2x16snorm
226 | crate::MathFunction::Pack2x16unorm
227 | crate::MathFunction::Pack4x8snorm
228 | crate::MathFunction::Pack4x8unorm
229 | crate::MathFunction::Pack4xI8
230 | crate::MathFunction::Pack4xU8
231 | crate::MathFunction::Pack4xI8Clamp
232 | crate::MathFunction::Pack4xU8Clamp => {
233 self.need_bake_expressions.insert(arg);
234 }
235 crate::MathFunction::CountLeadingZeros => {
236 let inner = info[exp_handle].ty.inner_with(&module.types);
237 if let Some(ScalarKind::Sint) = inner.scalar_kind() {
238 self.need_bake_expressions.insert(arg);
239 }
240 }
241 crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
242 self.need_bake_expressions.insert(arg);
243 self.need_bake_expressions.insert(arg1.unwrap());
244 }
245 _ => {}
246 }
247 }
248
249 if let Expression::Derivative { axis, ctrl, expr } = *expr {
250 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
251 if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
252 self.need_bake_expressions.insert(expr);
253 }
254 }
255
256 if let Expression::GlobalVariable(_) = *expr {
257 let inner = info[exp_handle].ty.inner_with(&module.types);
258
259 if let TypeInner::Sampler { .. } = *inner {
260 self.need_bake_expressions.insert(exp_handle);
261 }
262 }
263 }
264 for statement in func.body.iter() {
265 match *statement {
266 crate::Statement::SubgroupCollectiveOperation {
267 op: _,
268 collective_op: crate::CollectiveOperation::InclusiveScan,
269 argument,
270 result: _,
271 } => {
272 self.need_bake_expressions.insert(argument);
273 }
274 crate::Statement::Atomic {
275 fun: crate::AtomicFunction::Exchange { compare: Some(cmp) },
276 ..
277 } => {
278 self.need_bake_expressions.insert(cmp);
279 }
280 _ => {}
281 }
282 }
283 }
284
    /// Translates `module` into HLSL and writes the result to `self.out`.
    ///
    /// Output is emitted in this order:
    /// 1. the special-constants cbuffer, if configured;
    /// 2. dynamic storage buffer offset cbuffers;
    /// 3. `matCx2` helper typedefs/functions;
    /// 4. sized structs;
    /// 5. special and wrapped helper functions;
    /// 6. named constants and global variables;
    /// 7. entry-point interface structs;
    /// 8. regular functions;
    /// 9. the selected entry points.
    ///
    /// Unless `Options::fake_missing_bindings` is set, functions and entry
    /// points that use a global whose resource binding cannot be resolved are
    /// skipped. For a skipped entry point, the binding error is recorded in the
    /// returned `ReflectionInfo::entry_point_names`.
    ///
    /// `fragment_entry_point` is forwarded to the entry-point interface writer;
    /// for vertex outputs it is used to drop locations that fragment shader
    /// does not read.
    ///
    /// # Errors
    /// - `Error::EntryPointNotFound` if the entry point selected in the pipeline
    ///   options does not exist.
    /// - Any error returned by the individual writers, including formatting
    ///   errors.
    pub fn write(
        &mut self,
        module: &Module,
        module_info: &valid::ModuleInfo,
        fragment_entry_point: Option<&FragmentEntryPoint<'_>>,
    ) -> Result<super::ReflectionInfo, Error> {
        self.reset(module);

        // Write the special constants buffer, if a binding was provided for it.
        if let Some(ref bt) = self.options.special_constants_binding {
            writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
            writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
            writeln!(self.out, "}};")?;
            write!(
                self.out,
                "ConstantBuffer<{}> {}: register(b{}",
                SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
            )?;
            // Only a non-zero register space is spelled out.
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            writeln!(self.out, ");")?;

            // Extra newline for readability
            writeln!(self.out)?;
        }

        // One struct of `bt.size` uint offsets, plus a constant buffer of that
        // struct, for each configured target.
        for (group, bt) in self.options.dynamic_storage_buffer_offsets_targets.iter() {
            writeln!(self.out, "struct __dynamic_buffer_offsetsTy{} {{", group)?;
            for i in 0..bt.size {
                writeln!(self.out, "{}uint _{};", back::INDENT, i)?;
            }
            writeln!(self.out, "}};")?;
            writeln!(
                self.out,
                "ConstantBuffer<__dynamic_buffer_offsetsTy{}> __dynamic_buffer_offsets{}: register(b{}, space{});",
                group, group, bt.register, bt.space
            )?;

            // Extra newline for readability
            writeln!(self.out)?;
        }

        // Save all entry point output types, so that structs used as an entry
        // point result can be written with output semantics.
        let ep_results = module
            .entry_points
            .iter()
            .map(|ep| (ep.stage, ep.function.result.clone()))
            .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();

        self.write_all_mat_cx2_typedefs_and_functions(module)?;

        // Write all structs
        for (handle, ty) in module.types.iter() {
            if let TypeInner::Struct { ref members, span } = ty.inner {
                // Structs ending in a runtime-sized member are not declared.
                // NOTE(review): presumably these are only reachable through
                // storage buffers, which are emitted as `ByteAddressBuffer`.
                if module.types[members.last().unwrap().ty]
                    .inner
                    .is_dynamically_sized(&module.types)
                {
                    continue;
                }

                let ep_result = ep_results.iter().find(|e| {
                    if let Some(ref result) = e.1 {
                        result.ty == handle
                    } else {
                        false
                    }
                });

                self.write_struct(
                    module,
                    handle,
                    members,
                    span,
                    ep_result.map(|r| (r.0, Io::Output)),
                )?;
                writeln!(self.out)?;
            }
        }

        self.write_special_functions(module)?;

        self.write_wrapped_expression_functions(module, &module.global_expressions, None)?;
        self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;

        // Write all named constants
        let mut constants = module
            .constants
            .iter()
            .filter(|&(_, c)| c.name.is_some())
            .peekable();
        while let Some((handle, _)) = constants.next() {
            self.write_global_constant(module, handle)?;
            // Add extra newline for readability on last iteration
            if constants.peek().is_none() {
                writeln!(self.out)?;
            }
        }

        // Write all globals
        for (ty, _) in module.global_variables.iter() {
            self.write_global(module, ty)?;
        }

        if !module.global_variables.is_empty() {
            // Add extra newline for readability
            writeln!(self.out)?;
        }

        // The entry points to translate: all of them, or only the one selected
        // in the pipeline options.
        let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;

        // Write all entry points wrapped structs
        for index in ep_range.clone() {
            let ep = &module.entry_points[index];
            let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            let ep_io = self.write_ep_interface(
                module,
                &ep.function,
                ep.stage,
                &ep_name,
                fragment_entry_point,
            )?;
            self.entry_point_io.insert(index, ep_io);
        }

        // Write all regular functions
        for (handle, function) in module.functions.iter() {
            let info = &module_info[handle];

            // Skip the function if any global it uses has an unresolvable binding.
            if !self.options.fake_missing_bindings {
                if let Some((var_handle, _)) =
                    module
                        .global_variables
                        .iter()
                        .find(|&(var_handle, var)| match var.binding {
                            Some(ref binding) if !info[var_handle].is_empty() => {
                                self.options.resolve_resource_binding(binding).is_err()
                            }
                            _ => false,
                        })
                {
                    log::info!(
                        "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
                        handle,
                        function.name,
                        var_handle
                    );
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::Function(handle),
                info,
                expressions: &function.expressions,
                named_expressions: &function.named_expressions,
            };
            let name = self.names[&NameKey::Function(handle)].clone();

            self.write_wrapped_functions(module, &ctx)?;

            self.write_function(module, name.as_str(), function, &ctx, info)?;

            writeln!(self.out)?;
        }

        let mut translated_ep_names = Vec::with_capacity(ep_range.len());

        // Write all entry points
        for index in ep_range {
            let ep = &module.entry_points[index];
            let info = module_info.get_entry_point(index);

            // If a global this entry point uses has an unresolvable binding,
            // record the error instead of the entry point's name and skip it.
            if !self.options.fake_missing_bindings {
                let mut ep_error = None;
                for (var_handle, var) in module.global_variables.iter() {
                    match var.binding {
                        Some(ref binding) if !info[var_handle].is_empty() => {
                            if let Err(err) = self.options.resolve_resource_binding(binding) {
                                ep_error = Some(err);
                                break;
                            }
                        }
                        _ => {}
                    }
                }
                if let Some(err) = ep_error {
                    translated_ep_names.push(Err(err));
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::EntryPoint(index as u16),
                info,
                expressions: &ep.function.expressions,
                named_expressions: &ep.function.named_expressions,
            };

            self.write_wrapped_functions(module, &ctx)?;

            if ep.stage == ShaderStage::Compute {
                // HLSL expresses the workgroup size as `numthreads`
                let num_threads = ep.workgroup_size;
                writeln!(
                    self.out,
                    "[numthreads({}, {}, {})]",
                    num_threads[0], num_threads[1], num_threads[2]
                )?;
            }

            let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            self.write_function(module, &name, &ep.function, &ctx, info)?;

            // Separator between entry points (compares against the module's
            // full entry-point list, not `ep_range`).
            if index < module.entry_points.len() - 1 {
                writeln!(self.out)?;
            }

            translated_ep_names.push(Ok(name));
        }

        Ok(super::ReflectionInfo {
            entry_point_names: translated_ep_names,
        })
    }
517
518 fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
519 match *binding {
520 crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
521 write!(self.out, "precise ")?;
522 }
523 crate::Binding::Location {
524 interpolation,
525 sampling,
526 ..
527 } => {
528 if let Some(interpolation) = interpolation {
529 if let Some(string) = interpolation.to_hlsl_str() {
530 write!(self.out, "{string} ")?
531 }
532 }
533
534 if let Some(sampling) = sampling {
535 if let Some(string) = sampling.to_hlsl_str() {
536 write!(self.out, "{string} ")?
537 }
538 }
539 }
540 crate::Binding::BuiltIn(_) => {}
541 }
542
543 Ok(())
544 }
545
546 fn write_semantic(
549 &mut self,
550 binding: &Option<crate::Binding>,
551 stage: Option<(ShaderStage, Io)>,
552 ) -> BackendResult {
553 match *binding {
554 Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
555 let builtin_str = builtin.to_hlsl_str()?;
556 write!(self.out, " : {builtin_str}")?;
557 }
558 Some(crate::Binding::Location {
559 blend_src: Some(1), ..
560 }) => {
561 write!(self.out, " : SV_Target1")?;
562 }
563 Some(crate::Binding::Location { location, .. }) => {
564 if stage == Some((ShaderStage::Fragment, Io::Output)) {
565 write!(self.out, " : SV_Target{location}")?;
566 } else {
567 write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
568 }
569 }
570 _ => {}
571 }
572
573 Ok(())
574 }
575
    /// Writes an HLSL struct named `struct_name` that holds the flattened
    /// entry-point interface `members`. Returns the binding information needed
    /// later to read from or write to it.
    ///
    /// Fields are emitted in `InterfaceKey` order: locations ascending, then
    /// built-ins. Subgroup built-ins are not emitted as fields. If a
    /// `SubgroupId` member is present, an extra
    /// `uint __local_invocation_index : SV_GroupIndex` field is added instead.
    ///
    /// The returned `members` are sorted back by their original `index` for
    /// inputs, and left in `InterfaceKey` order for outputs.
    fn write_interface_struct(
        &mut self,
        module: &Module,
        shader_stage: (ShaderStage, Io),
        struct_name: String,
        mut members: Vec<EpStructMember>,
    ) -> Result<EntryPointBinding, Error> {
        // Sort so that user-defined locations come first, in ascending order,
        // followed by built-ins. For example, a vertex output struct and a
        // fragment input struct then list matching locations in the same order.
        members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));

        write!(self.out, "struct {struct_name}")?;
        writeln!(self.out, " {{")?;
        for m in members.iter() {
            // Every interface member is expected to carry a binding.
            debug_assert!(m.binding.is_some());

            // Subgroup built-ins are computed at entry instead of being read
            // from the struct.
            if is_subgroup_builtin_binding(&m.binding) {
                continue;
            }
            write!(self.out, "{}", back::INDENT)?;
            if let Some(ref binding) = m.binding {
                self.write_modifier(binding)?;
            }
            self.write_type(module, m.ty)?;
            write!(self.out, " {}", &m.name)?;
            self.write_semantic(&m.binding, Some(shader_stage))?;
            writeln!(self.out, ";")?;
        }
        // `SubgroupId` is derived from the local invocation index, so that
        // index must be provided as an input.
        if members.iter().any(|arg| {
            matches!(
                arg.binding,
                Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId))
            )
        }) {
            writeln!(
                self.out,
                "{}uint __local_invocation_index : SV_GroupIndex;",
                back::INDENT
            )?;
        }
        writeln!(self.out, "}};")?;
        writeln!(self.out)?;

        // See the ordering notes on the `EntryPointInterface` fields.
        match shader_stage.1 {
            Io::Input => {
                // Restore the original argument order.
                members.sort_by_key(|m| m.index);
            }
            Io::Output => {
                // Keep the order sorted by binding.
            }
        }

        Ok(EntryPointBinding {
            arg_name: self.namer.call(struct_name.to_lowercase().as_str()),
            ty_name: struct_name,
            members,
        })
    }
639
640 fn write_ep_input_struct(
644 &mut self,
645 module: &Module,
646 func: &crate::Function,
647 stage: ShaderStage,
648 entry_point_name: &str,
649 ) -> Result<EntryPointBinding, Error> {
650 let struct_name = format!("{stage:?}Input_{entry_point_name}");
651
652 let mut fake_members = Vec::new();
653 for arg in func.arguments.iter() {
654 match module.types[arg.ty].inner {
659 TypeInner::Struct { ref members, .. } => {
660 for member in members.iter() {
661 let name = self.namer.call_or(&member.name, "member");
662 let index = fake_members.len() as u32;
663 fake_members.push(EpStructMember {
664 name,
665 ty: member.ty,
666 binding: member.binding.clone(),
667 index,
668 });
669 }
670 }
671 _ => {
672 let member_name = self.namer.call_or(&arg.name, "member");
673 let index = fake_members.len() as u32;
674 fake_members.push(EpStructMember {
675 name: member_name,
676 ty: arg.ty,
677 binding: arg.binding.clone(),
678 index,
679 });
680 }
681 }
682 }
683
684 self.write_interface_struct(module, (stage, Io::Input), struct_name, fake_members)
685 }
686
    /// Writes the flattened output struct `{stage:?}Output_{entry_point_name}`
    /// for an entry point's `result`.
    ///
    /// Only struct results are flattened. Any other result type is logged as
    /// an error and yields a struct with no fields.
    ///
    /// When `stage` is `Vertex` and `frag_ep` is given, location-bound members
    /// that the fragment entry point does not read are omitted. Built-ins and
    /// members without a binding are always kept. Each field's `index` is its
    /// member index in the result struct.
    fn write_ep_output_struct(
        &mut self,
        module: &Module,
        result: &crate::FunctionResult,
        stage: ShaderStage,
        entry_point_name: &str,
        frag_ep: Option<&FragmentEntryPoint<'_>>,
    ) -> Result<EntryPointBinding, Error> {
        let struct_name = format!("{stage:?}Output_{entry_point_name}");

        let empty = [];
        let members = match module.types[result.ty].inner {
            TypeInner::Struct { ref members, .. } => members,
            ref other => {
                log::error!("Unexpected {:?} output type without a binding", other);
                &empty[..]
            }
        };

        // For a vertex shader linked to a known fragment shader, collect the
        // locations the fragment shader reads. They are sorted so the filter
        // below can use `binary_search`.
        let fs_input_locs = if let (Some(frag_ep), ShaderStage::Vertex) = (frag_ep, stage) {
            let mut fs_input_locs = Vec::new();
            for arg in frag_ep.func.arguments.iter() {
                let mut push_if_location = |binding: &Option<crate::Binding>| match *binding {
                    Some(crate::Binding::Location { location, .. }) => fs_input_locs.push(location),
                    Some(crate::Binding::BuiltIn(_)) | None => {}
                };

                // Struct arguments contribute the bindings of their members.
                match frag_ep.module.types[arg.ty].inner {
                    TypeInner::Struct { ref members, .. } => {
                        for member in members.iter() {
                            push_if_location(&member.binding);
                        }
                    }
                    _ => push_if_location(&arg.binding),
                }
            }
            fs_input_locs.sort();
            Some(fs_input_locs)
        } else {
            None
        };

        let mut fake_members = Vec::new();
        for (index, member) in members.iter().enumerate() {
            // Drop location outputs the fragment shader does not consume.
            if let Some(ref fs_input_locs) = fs_input_locs {
                match member.binding {
                    Some(crate::Binding::Location { location, .. }) => {
                        if fs_input_locs.binary_search(&location).is_err() {
                            continue;
                        }
                    }
                    Some(crate::Binding::BuiltIn(_)) | None => {}
                }
            }

            let member_name = self.namer.call_or(&member.name, "member");
            fake_members.push(EpStructMember {
                name: member_name,
                ty: member.ty,
                binding: member.binding.clone(),
                // Index within the result struct, not within `fake_members`.
                index: index as u32,
            });
        }

        self.write_interface_struct(module, (stage, Io::Output), struct_name, fake_members)
    }
762
763 fn write_ep_interface(
767 &mut self,
768 module: &Module,
769 func: &crate::Function,
770 stage: ShaderStage,
771 ep_name: &str,
772 frag_ep: Option<&FragmentEntryPoint<'_>>,
773 ) -> Result<EntryPointInterface, Error> {
774 Ok(EntryPointInterface {
775 input: if !func.arguments.is_empty()
776 && (stage == ShaderStage::Fragment
777 || func
778 .arguments
779 .iter()
780 .any(|arg| is_subgroup_builtin_binding(&arg.binding)))
781 {
782 Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
783 } else {
784 None
785 },
786 output: match func.result {
787 Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
788 Some(self.write_ep_output_struct(module, fr, stage, ep_name, frag_ep)?)
789 }
790 _ => None,
791 },
792 })
793 }
794
795 fn write_ep_argument_initialization(
796 &mut self,
797 ep: &crate::EntryPoint,
798 ep_input: &EntryPointBinding,
799 fake_member: &EpStructMember,
800 ) -> BackendResult {
801 match fake_member.binding {
802 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
803 write!(self.out, "WaveGetLaneCount()")?
804 }
805 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
806 write!(self.out, "WaveGetLaneIndex()")?
807 }
808 Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
809 self.out,
810 "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
811 ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
812 )?,
813 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
814 write!(
815 self.out,
816 "{}.__local_invocation_index / WaveGetLaneCount()",
817 ep_input.arg_name
818 )?;
819 }
820 _ => {
821 write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
822 }
823 }
824 Ok(())
825 }
826
827 fn write_ep_arguments_initialization(
829 &mut self,
830 module: &Module,
831 func: &crate::Function,
832 ep_index: u16,
833 ) -> BackendResult {
834 let ep = &module.entry_points[ep_index as usize];
835 let ep_input = match self
836 .entry_point_io
837 .get_mut(&(ep_index as usize))
838 .unwrap()
839 .input
840 .take()
841 {
842 Some(ep_input) => ep_input,
843 None => return Ok(()),
844 };
845 let mut fake_iter = ep_input.members.iter();
846 for (arg_index, arg) in func.arguments.iter().enumerate() {
847 write!(self.out, "{}", back::INDENT)?;
848 self.write_type(module, arg.ty)?;
849 let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
850 write!(self.out, " {arg_name}")?;
851 match module.types[arg.ty].inner {
852 TypeInner::Array { base, size, .. } => {
853 self.write_array_size(module, base, size)?;
854 write!(self.out, " = ")?;
855 self.write_ep_argument_initialization(
856 ep,
857 &ep_input,
858 fake_iter.next().unwrap(),
859 )?;
860 writeln!(self.out, ";")?;
861 }
862 TypeInner::Struct { ref members, .. } => {
863 write!(self.out, " = {{ ")?;
864 for index in 0..members.len() {
865 if index != 0 {
866 write!(self.out, ", ")?;
867 }
868 self.write_ep_argument_initialization(
869 ep,
870 &ep_input,
871 fake_iter.next().unwrap(),
872 )?;
873 }
874 writeln!(self.out, " }};")?;
875 }
876 _ => {
877 write!(self.out, " = ")?;
878 self.write_ep_argument_initialization(
879 ep,
880 &ep_input,
881 fake_iter.next().unwrap(),
882 )?;
883 writeln!(self.out, ";")?;
884 }
885 }
886 }
887 assert!(fake_iter.next().is_none());
888 Ok(())
889 }
890
    /// Writes the declaration of global variable `handle`.
    ///
    /// A global whose resource binding cannot be resolved is skipped (logged at
    /// `info` level). Samplers and sampler binding arrays are delegated to
    /// `write_global_sampler`. Otherwise the declaration depends on the address
    /// space:
    /// - `Private`: `static T name = init;` (or a default initializer);
    /// - `WorkGroup`: `groupshared T name;`;
    /// - `Uniform`: `cbuffer name : register(bN) { T name; }`;
    /// - `Storage`: `ByteAddressBuffer` (`t` register), or `RWByteAddressBuffer`
    ///   (`u` register) when the access includes `STORE`;
    /// - `Handle`: the resource type, in a `t` register (`u` for storage images);
    /// - `PushConstant`: `ConstantBuffer<T> name: register(bN)` using
    ///   `Options::push_constants_target`.
    ///
    /// # Errors
    /// `Error::Unimplemented` for push constants whose type is not a struct.
    ///
    /// # Panics
    /// For the `Function` address space, or for push constants when no
    /// `push_constants_target` is configured.
    fn write_global(
        &mut self,
        module: &Module,
        handle: Handle<crate::GlobalVariable>,
    ) -> BackendResult {
        let global = &module.global_variables[handle];
        let inner = &module.types[global.ty].inner;

        // Globals whose binding can't be resolved are omitted entirely.
        if let Some(ref binding) = global.binding {
            if let Err(err) = self.options.resolve_resource_binding(binding) {
                log::info!(
                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
                    handle,
                    global.name,
                    err,
                );
                return Ok(());
            }
        }

        // For binding arrays, classify the global by its element type.
        let handle_ty = match *inner {
            TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
            _ => inner,
        };

        // Samplers are declared through the sampler heaps; defer entirely.
        let is_sampler = matches!(*handle_ty, TypeInner::Sampler { .. });

        if is_sampler {
            return self.write_global_sampler(module, handle, global);
        }

        // Write the declaration prefix and pick the register class letter
        // (empty for globals that are never bound to a register).
        let register_ty = match global.space {
            crate::AddressSpace::Function => unreachable!("Function address space"),
            crate::AddressSpace::Private => {
                write!(self.out, "static ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::WorkGroup => {
                write!(self.out, "groupshared ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::Uniform => {
                // The cbuffer body `{ T name; }` is written at the end.
                write!(self.out, "cbuffer")?;
                "b"
            }
            crate::AddressSpace::Storage { access } => {
                let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
                    ("RW", "u")
                } else {
                    ("", "t")
                };
                write!(self.out, "{prefix}ByteAddressBuffer")?;
                register
            }
            crate::AddressSpace::Handle => {
                let register = match *handle_ty {
                    // Storage images always go in a `u` register.
                    TypeInner::Image {
                        class: crate::ImageClass::Storage { .. },
                        ..
                    } => "u",
                    _ => "t",
                };
                self.write_type(module, global.ty)?;
                register
            }
            crate::AddressSpace::PushConstant => {
                // The push-constant type becomes the generic argument of
                // `ConstantBuffer`; it is written just below.
                write!(self.out, "ConstantBuffer<")?;
                "b"
            }
        };

        if global.space == crate::AddressSpace::PushConstant {
            self.write_global_type(module, global.ty)?;

            // Array dimensions are not part of the written type, so add them.
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            // Close the angled brackets for the generic argument
            write!(self.out, ">")?;
        }

        let name = &self.names[&NameKey::GlobalVariable(handle)];
        write!(self.out, " {name}")?;

        // Push constants are placed at the register given by
        // `Options::push_constants_target`.
        if global.space == crate::AddressSpace::PushConstant {
            match module.types[global.ty].inner {
                TypeInner::Struct { .. } => {}
                _ => {
                    return Err(Error::Unimplemented(format!(
                        "push-constant '{name}' has non-struct type; tracked by: https://github.com/gfx-rs/wgpu/issues/5683"
                    )));
                }
            }

            let target = self
                .options
                .push_constants_target
                .as_ref()
                .expect("No bind target was defined for the push constants block");
            write!(self.out, ": register(b{}", target.register)?;
            if target.space != 0 {
                write!(self.out, ", space{}", target.space)?;
            }
            write!(self.out, ")")?;
        }

        if let Some(ref binding) = global.binding {
            // Resolution already succeeded at the top of this function.
            let bt = self.options.resolve_resource_binding(binding).unwrap();

            // Binding-array dimensions follow the name; a size from the bind
            // target overrides the IR size.
            if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
                if let Some(overridden_size) = bt.binding_array_size {
                    write!(self.out, "[{overridden_size}]")?;
                } else {
                    self.write_array_size(module, base, size)?;
                }
            }

            write!(self.out, " : register({}{}", register_ty, bt.register)?;
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            write!(self.out, ")")?;
        } else {
            // Array dimensions follow the name in HLSL.
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }
            // Private globals always get an initializer.
            if global.space == crate::AddressSpace::Private {
                write!(self.out, " = ")?;
                if let Some(init) = global.init {
                    self.write_const_expression(module, init, &module.global_expressions)?;
                } else {
                    self.write_default_init(module, global.ty)?;
                }
            }
        }

        if global.space == crate::AddressSpace::Uniform {
            // cbuffer body: a single field with the same name as the cbuffer.
            write!(self.out, " {{ ")?;

            self.write_global_type(module, global.ty)?;

            write!(
                self.out,
                " {}",
                &self.names[&NameKey::GlobalVariable(handle)]
            )?;

            // Array dimensions follow the field name.
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            writeln!(self.out, "; }}")?;
        } else {
            writeln!(self.out, ";")?;
        }

        Ok(())
    }
1070
1071 fn write_global_sampler(
1072 &mut self,
1073 module: &Module,
1074 handle: Handle<crate::GlobalVariable>,
1075 global: &crate::GlobalVariable,
1076 ) -> BackendResult {
1077 let binding = *global.binding.as_ref().unwrap();
1078
1079 let key = super::SamplerIndexBufferKey {
1080 group: binding.group,
1081 };
1082 self.write_wrapped_sampler_buffer(key)?;
1083
1084 let bt = self.options.resolve_resource_binding(&binding).unwrap();
1086
1087 match module.types[global.ty].inner {
1088 TypeInner::Sampler { comparison } => {
1089 write!(self.out, "static const ")?;
1096 self.write_type(module, global.ty)?;
1097
1098 let heap_var = if comparison {
1099 COMPARISON_SAMPLER_HEAP_VAR
1100 } else {
1101 SAMPLER_HEAP_VAR
1102 };
1103
1104 let index_buffer_name = &self.wrapped.sampler_index_buffers[&key];
1105 let name = &self.names[&NameKey::GlobalVariable(handle)];
1106 writeln!(
1107 self.out,
1108 " {name} = {heap_var}[{index_buffer_name}[{register}]];",
1109 register = bt.register
1110 )?;
1111 }
1112 TypeInner::BindingArray { .. } => {
1113 let name = &self.names[&NameKey::GlobalVariable(handle)];
1119 writeln!(
1120 self.out,
1121 "static const uint {name} = {register};",
1122 register = bt.register
1123 )?;
1124 }
1125 _ => unreachable!(),
1126 };
1127
1128 Ok(())
1129 }
1130
1131 fn write_global_constant(
1136 &mut self,
1137 module: &Module,
1138 handle: Handle<crate::Constant>,
1139 ) -> BackendResult {
1140 write!(self.out, "static const ")?;
1141 let constant = &module.constants[handle];
1142 self.write_type(module, constant.ty)?;
1143 let name = &self.names[&NameKey::Constant(handle)];
1144 write!(self.out, " {name}")?;
1145 if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner {
1147 self.write_array_size(module, base, size)?;
1148 }
1149 write!(self.out, " = ")?;
1150 self.write_const_expression(module, constant.init, &module.global_expressions)?;
1151 writeln!(self.out, ";")?;
1152 Ok(())
1153 }
1154
1155 pub(super) fn write_array_size(
1156 &mut self,
1157 module: &Module,
1158 base: Handle<crate::Type>,
1159 size: crate::ArraySize,
1160 ) -> BackendResult {
1161 write!(self.out, "[")?;
1162
1163 match size.resolve(module.to_ctx())? {
1164 proc::IndexableLength::Known(size) => {
1165 write!(self.out, "{size}")?;
1166 }
1167 proc::IndexableLength::Dynamic => unreachable!(),
1168 }
1169
1170 write!(self.out, "]")?;
1171
1172 if let TypeInner::Array {
1173 base: next_base,
1174 size: next_size,
1175 ..
1176 } = module.types[base].inner
1177 {
1178 self.write_array_size(module, next_base, next_size)?;
1179 }
1180
1181 Ok(())
1182 }
1183
    /// Writes the HLSL declaration of struct type `handle`, given its IR
    /// `members` and total `span` in bytes.
    ///
    /// For members without an interface binding, `int` padding fields are
    /// inserted so that each member starts at its IR `offset`. When the last
    /// member has no binding, trailing padding up to `span` is added as well.
    /// `shader_stage` is forwarded to `write_semantic` to choose location
    /// semantics.
    ///
    /// # Errors
    /// Propagates errors from sizing member types and from the type writers.
    ///
    /// # Panics
    /// If `members` is empty.
    fn write_struct(
        &mut self,
        module: &Module,
        handle: Handle<crate::Type>,
        members: &[crate::StructMember],
        span: u32,
        shader_stage: Option<(ShaderStage, Io)>,
    ) -> BackendResult {
        // Write struct name
        let struct_name = &self.names[&NameKey::Type(handle)];
        writeln!(self.out, "struct {struct_name} {{")?;

        // Byte offset just past the previous member, as sized by `size_hlsl`.
        let mut last_offset = 0;
        for (index, member) in members.iter().enumerate() {
            // Fill any gap before a non-interface member with 4-byte `int`
            // padding fields. NOTE(review): the `/ 4` assumes gaps are whole
            // multiples of 4 bytes.
            if member.binding.is_none() && member.offset > last_offset {
                let padding = (member.offset - last_offset) / 4;
                for i in 0..padding {
                    writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
                }
            }
            let ty_inner = &module.types[member.ty].inner;
            last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?;

            // The indentation is only for readability
            write!(self.out, "{}", back::INDENT)?;

            match module.types[member.ty].inner {
                TypeInner::Array { base, size, .. } => {
                    // HLSL arrays are written as `type name[size]`
                    self.write_global_type(module, member.ty)?;

                    // Write `name`
                    write!(
                        self.out,
                        " {}",
                        &self.names[&NameKey::StructMember(handle, index as u32)]
                    )?;
                    // Write [size]
                    self.write_array_size(module, base, size)?;
                }
                // Two-row matrices outside the interface are split into one
                // 2-component vector field per column, named `<member>_<i>`.
                TypeInner::Matrix {
                    rows,
                    columns,
                    scalar,
                } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
                    let vec_ty = TypeInner::Vector { size: rows, scalar };
                    let field_name_key = NameKey::StructMember(handle, index as u32);

                    for i in 0..columns as u8 {
                        if i != 0 {
                            write!(self.out, "; ")?;
                        }
                        self.write_value_type(module, &vec_ty)?;
                        write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
                    }
                }
                _ => {
                    // Write modifier before type
                    if let Some(ref binding) = member.binding {
                        self.write_modifier(binding)?;
                    }

                    // All remaining matrix members are declared `row_major`.
                    if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
                        write!(self.out, "row_major ")?;
                    }

                    // Write the member type and name
                    self.write_type(module, member.ty)?;
                    write!(
                        self.out,
                        " {}",
                        &self.names[&NameKey::StructMember(handle, index as u32)]
                    )?;
                }
            }

            self.write_semantic(&member.binding, shader_stage)?;
            writeln!(self.out, ";")?;
        }

        // Trailing padding so the struct covers its full IR `span`.
        if members.last().unwrap().binding.is_none() && span > last_offset {
            let padding = (span - last_offset) / 4;
            for i in 0..padding {
                writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
            }
        }

        writeln!(self.out, "}};")?;
        Ok(())
    }
1288
1289 pub(super) fn write_global_type(
1294 &mut self,
1295 module: &Module,
1296 ty: Handle<crate::Type>,
1297 ) -> BackendResult {
1298 let matrix_data = get_inner_matrix_data(module, ty);
1299
1300 if let Some(MatrixType {
1303 columns,
1304 rows: crate::VectorSize::Bi,
1305 width: 4,
1306 }) = matrix_data
1307 {
1308 write!(self.out, "__mat{}x2", columns as u8)?;
1309 } else {
1310 if matrix_data.is_some() {
1314 write!(self.out, "row_major ")?;
1315 }
1316
1317 self.write_type(module, ty)?;
1318 }
1319
1320 Ok(())
1321 }
1322
1323 pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
1328 let inner = &module.types[ty].inner;
1329 match *inner {
1330 TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
1331 TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
1333 self.write_type(module, base)?
1334 }
1335 ref other => self.write_value_type(module, other)?,
1336 }
1337
1338 Ok(())
1339 }
1340
1341 pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
1346 match *inner {
1347 TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
1348 write!(self.out, "{}", scalar.to_hlsl_str()?)?;
1349 }
1350 TypeInner::Vector { size, scalar } => {
1351 write!(
1352 self.out,
1353 "{}{}",
1354 scalar.to_hlsl_str()?,
1355 common::vector_size_str(size)
1356 )?;
1357 }
1358 TypeInner::Matrix {
1359 columns,
1360 rows,
1361 scalar,
1362 } => {
1363 write!(
1368 self.out,
1369 "{}{}x{}",
1370 scalar.to_hlsl_str()?,
1371 common::vector_size_str(columns),
1372 common::vector_size_str(rows),
1373 )?;
1374 }
1375 TypeInner::Image {
1376 dim,
1377 arrayed,
1378 class,
1379 } => {
1380 self.write_image_type(dim, arrayed, class)?;
1381 }
1382 TypeInner::Sampler { comparison } => {
1383 let sampler = if comparison {
1384 "SamplerComparisonState"
1385 } else {
1386 "SamplerState"
1387 };
1388 write!(self.out, "{sampler}")?;
1389 }
1390 TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
1394 self.write_array_size(module, base, size)?;
1395 }
1396 TypeInner::AccelerationStructure { .. } => {
1397 write!(self.out, "RaytracingAccelerationStructure")?;
1398 }
1399 TypeInner::RayQuery { .. } => {
1400 write!(self.out, "RayQuery<RAY_FLAG_NONE>")?;
1402 }
1403 _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
1404 }
1405
1406 Ok(())
1407 }
1408
/// Writes the complete HLSL definition of `func`.
///
/// The output consists of:
/// * the return type,
/// * the name,
/// * the parameter list,
/// * entry-point semantics,
/// * local variable declarations,
/// * the body.
///
/// * `name` - identifier emitted for the function.
/// * `func_ctx` - says whether `func` is a regular function or an entry point
///   (and which one), and provides expression type/name lookups.
/// * `info` - validator analysis of `func`, handed to
///   `update_expressions_to_bake` before anything is written.
///
/// `self.named_expressions` is cleared after the closing brace is written.
fn write_function(
    &mut self,
    module: &Module,
    name: &str,
    func: &crate::Function,
    func_ctx: &back::FunctionCtx<'_>,
    info: &valid::FunctionInfo,
) -> BackendResult {
    // Decide up front which expressions get baked into temporaries. This
    // presumably fills `self.need_bake_expressions`, which `Statement::Emit`
    // consults.
    self.update_expressions_to_bake(module, func, info);

    if let Some(ref result) = func.result {
        // An array return type gets a `typedef` named `ret_<name>` (made unique
        // by the namer), written ahead of the function. Array sizes are
        // written after a declarator name, so the typedef gives the size
        // somewhere to go.
        let array_return_type = match module.types[result.ty].inner {
            TypeInner::Array { base, size, .. } => {
                let array_return_type = self.namer.call(&format!("ret_{name}"));
                write!(self.out, "typedef ")?;
                self.write_type(module, result.ty)?;
                write!(self.out, " {}", array_return_type)?;
                self.write_array_size(module, base, size)?;
                writeln!(self.out, ";")?;
                Some(array_return_type)
            }
            _ => None,
        };

        // An invariant position output has its modifier written before the
        // return type.
        if let Some(
            ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }),
        ) = result.binding
        {
            self.write_modifier(binding)?;
        }

        // Choose the return type:
        // * regular functions use the array typedef when there is one;
        // * entry points use their generated output struct when they have one;
        // * otherwise the IR type is written.
        match func_ctx.ty {
            back::FunctionType::Function(_) => {
                if let Some(array_return_type) = array_return_type {
                    write!(self.out, "{array_return_type}")?;
                } else {
                    self.write_type(module, result.ty)?;
                }
            }
            back::FunctionType::EntryPoint(index) => {
                // Every entry point is expected to have an `entry_point_io` record here.
                if let Some(ref ep_output) =
                    self.entry_point_io.get(&(index as usize)).unwrap().output
                {
                    write!(self.out, "{}", ep_output.ty_name)?;
                } else {
                    self.write_type(module, result.ty)?;
                }
            }
        }
    } else {
        write!(self.out, "void")?;
    }

    write!(self.out, " {name}(")?;

    // Computed once. When true, it adds an extra parameter to the signature and
    // an initialization prologue to the body.
    let need_workgroup_variables_initialization =
        self.need_workgroup_variables_initialization(func_ctx, module);

    match func_ctx.ty {
        back::FunctionType::Function(handle) => {
            for (index, arg) in func.arguments.iter().enumerate() {
                if index != 0 {
                    write!(self.out, ", ")?;
                }
                // Pointer arguments become `inout` parameters of the pointee type.
                let arg_ty = match module.types[arg.ty].inner {
                    TypeInner::Pointer { base, .. } => {
                        write!(self.out, "inout ")?;
                        base
                    }
                    _ => arg.ty,
                };
                self.write_type(module, arg_ty)?;

                let argument_name =
                    &self.names[&NameKey::FunctionArgument(handle, index as u32)];

                // The leading space separates the type from the name.
                write!(self.out, " {argument_name}")?;
                // Array sizes follow the parameter name.
                if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
                    self.write_array_size(module, base, size)?;
                }
            }
        }
        back::FunctionType::EntryPoint(ep_index) => {
            // An entry point takes either a single generated input struct, or
            // each IR argument individually with its input semantic attached.
            if let Some(ref ep_input) =
                self.entry_point_io.get(&(ep_index as usize)).unwrap().input
            {
                write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
            } else {
                let stage = module.entry_points[ep_index as usize].stage;
                for (index, arg) in func.arguments.iter().enumerate() {
                    if index != 0 {
                        write!(self.out, ", ")?;
                    }
                    self.write_type(module, arg.ty)?;

                    let argument_name =
                        &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];

                    write!(self.out, " {argument_name}")?;
                    if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
                        self.write_array_size(module, base, size)?;
                    }

                    self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
                }
            }
            // The workgroup-memory prologue tests the local invocation id, so
            // it is appended as an extra parameter. The separating comma is
            // written only if some parameter already precedes it.
            if need_workgroup_variables_initialization {
                if self
                    .entry_point_io
                    .get(&(ep_index as usize))
                    .unwrap()
                    .input
                    .is_some()
                    || !func.arguments.is_empty()
                {
                    write!(self.out, ", ")?;
                }
                write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
            }
        }
    }
    write!(self.out, ")")?;

    // An entry point result carries its output semantic after the parameter list.
    if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
        let stage = module.entry_points[index as usize].stage;
        if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
            self.write_semantic(binding, Some((stage, Io::Output)))?;
        }
    }

    writeln!(self.out)?;
    writeln!(self.out, "{{")?;

    if need_workgroup_variables_initialization {
        self.write_workgroup_variables_initialization(func_ctx, module)?;
    }

    if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
        self.write_ep_arguments_initialization(module, func, index)?;
    }

    // Declare every local variable at the top of the body.
    for (handle, local) in func.local_variables.iter() {
        write!(self.out, "{}", back::INDENT)?;

        self.write_type(module, local.ty)?;
        write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
        // Array sizes follow the variable name.
        if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
            self.write_array_size(module, base, size)?;
        }

        match module.types[local.ty].inner {
            // Ray query locals are declared without any initializer.
            TypeInner::RayQuery { .. } => {}
            _ => {
                write!(self.out, " = ")?;
                if let Some(init) = local.init {
                    self.write_expr(module, init, func_ctx)?;
                } else {
                    // No IR initializer: emit the type's default initializer.
                    self.write_default_init(module, local.ty)?;
                }
            }
        }
        writeln!(self.out, ";")?
    }

    if !func.local_variables.is_empty() {
        writeln!(self.out)?;
    }

    // Top-level statements are always written at indentation level 1.
    for sta in func.body.iter() {
        self.write_stmt(module, sta, func_ctx, back::Level(1))?;
    }

    writeln!(self.out, "}}")?;

    // Discard expression names recorded while writing this function's body.
    self.named_expressions.clear();

    Ok(())
}
1615
1616 fn need_workgroup_variables_initialization(
1617 &mut self,
1618 func_ctx: &back::FunctionCtx,
1619 module: &Module,
1620 ) -> bool {
1621 self.options.zero_initialize_workgroup_memory
1622 && func_ctx.ty.is_compute_entry_point(module)
1623 && module.global_variables.iter().any(|(handle, var)| {
1624 !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
1625 })
1626 }
1627
1628 fn write_workgroup_variables_initialization(
1629 &mut self,
1630 func_ctx: &back::FunctionCtx,
1631 module: &Module,
1632 ) -> BackendResult {
1633 let level = back::Level(1);
1634
1635 writeln!(
1636 self.out,
1637 "{level}if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {{"
1638 )?;
1639
1640 let vars = module.global_variables.iter().filter(|&(handle, var)| {
1641 !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
1642 });
1643
1644 for (handle, var) in vars {
1645 let name = &self.names[&NameKey::GlobalVariable(handle)];
1646 write!(self.out, "{}{} = ", level.next(), name)?;
1647 self.write_default_init(module, var.ty)?;
1648 writeln!(self.out, ";")?;
1649 }
1650
1651 writeln!(self.out, "{level}}}")?;
1652 self.write_control_barrier(crate::Barrier::WORK_GROUP, level)
1653 }
1654
/// Writes an IR switch statement as HLSL at indentation `level`.
///
/// * `selector` - expression the cases are matched against.
/// * `cases` - IR cases in order; `fall_through` marks a case that continues
///   into the next one.
///
/// One of two shapes is produced:
/// * If every case except the last is an empty fall-through, only one body can
///   ever run. That body is wrapped in `do { ... } while(false);` instead of a
///   `switch`, and the selector is not written at all.
/// * Otherwise a real `switch` is written. A non-empty body that falls through
///   is not emitted as an HLSL fall-through; instead, the bodies it would fall
///   into are duplicated after it (see the inline comments).
///
/// A `continue` inside the switch may be forwarded through a bool flag managed
/// by `self.continue_ctx`. The flag is declared in the prologue and checked in
/// the epilogue of this function.
fn write_switch(
    &mut self,
    module: &Module,
    func_ctx: &back::FunctionCtx<'_>,
    level: back::Level,
    selector: Handle<crate::Expression>,
    cases: &[crate::SwitchCase],
) -> BackendResult {
    // indent_level_1: case labels (or the do-while body); indent_level_2: case bodies.
    let indent_level_1 = level.next();
    let indent_level_2 = indent_level_1.next();

    // When `continue_ctx` hands out a forwarding flag, declare it before the switch.
    // A `continue` inside the switch is then written as `<flag> = true; break;`
    // (see `Statement::Continue` in `write_stmt`). The epilogue below re-issues
    // the forwarded control flow after the switch.
    if let Some(variable) = self.continue_ctx.enter_switch(&mut self.namer) {
        writeln!(self.out, "{level}bool {variable} = false;",)?;
    };

    // True when every case but the last is an empty fall-through. This is also
    // true for zero or one case. Presumably this avoids HLSL compilers (e.g.
    // FXC) that mishandle such switches — TODO confirm.
    let one_body = cases
        .iter()
        .rev()
        .skip(1)
        .all(|case| case.fall_through && case.body.is_empty());
    if one_body {
        // Omitting the selector assumes two things (TODO confirm both):
        // * evaluating it has no side effects;
        // * the IR guarantees this body runs for every selector value, e.g.
        //   via a required `default` case.
        // `do { } while(false)` keeps `break` meaningful inside the body.
        writeln!(self.out, "{level}do {{")?;
        if let Some(case) = cases.last() {
            for sta in case.body.iter() {
                self.write_stmt(module, sta, func_ctx, indent_level_1)?;
            }
        }
        writeln!(self.out, "{level}}} while(false);")?;
    } else {
        write!(self.out, "{level}")?;
        write!(self.out, "switch(")?;
        self.write_expr(module, selector, func_ctx)?;
        writeln!(self.out, ") {{")?;

        for (i, case) in cases.iter().enumerate() {
            match case.value {
                crate::SwitchValue::I32(value) => {
                    write!(self.out, "{indent_level_1}case {value}:")?
                }
                crate::SwitchValue::U32(value) => {
                    write!(self.out, "{indent_level_1}case {value}u:")?
                }
                crate::SwitchValue::Default => write!(self.out, "{indent_level_1}default:")?,
            }

            // Every case except an empty fall-through label gets its own scope.
            // Bodies may be written more than once (see below), and the
            // duplicated statements would otherwise redeclare the same baked
            // temporary names in a single scope.
            let write_block_braces = !(case.fall_through && case.body.is_empty());
            if write_block_braces {
                writeln!(self.out, " {{")?;
            } else {
                writeln!(self.out)?;
            }

            if case.fall_through && !case.body.is_empty() {
                // Fall-through out of a non-empty body is emulated. This case's
                // body is written, followed by copies of every following body
                // up to and including the first case that does not fall
                // through. The output can therefore grow quadratically in the
                // number of cases.
                // The `unwrap` relies on such a non-fall-through case existing
                // later in `cases` — expected to be guaranteed by the IR.
                let curr_len = i + 1;
                let end_case_idx = curr_len
                    + cases
                        .iter()
                        .skip(curr_len)
                        .position(|case| !case.fall_through)
                        .unwrap();
                let indent_level_3 = indent_level_2.next();
                for case in &cases[i..=end_case_idx] {
                    writeln!(self.out, "{indent_level_2}{{")?;
                    let prev_len = self.named_expressions.len();
                    for sta in case.body.iter() {
                        self.write_stmt(module, sta, func_ctx, indent_level_3)?;
                    }
                    // Drop names recorded by this copy so later copies re-emit them.
                    self.named_expressions.truncate(prev_len);
                    writeln!(self.out, "{indent_level_2}}}")?;
                }

                // Close the emulated chain with `break;` unless it already ends in a terminator.
                let last_case = &cases[end_case_idx];
                if last_case.body.last().is_none_or(|s| !s.is_terminator()) {
                    writeln!(self.out, "{indent_level_2}break;")?;
                }
            } else {
                for sta in case.body.iter() {
                    self.write_stmt(module, sta, func_ctx, indent_level_2)?;
                }
                // Empty fall-through labels get no `break;`. Other cases get one
                // unless the body already ends in a terminator.
                if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator()) {
                    writeln!(self.out, "{indent_level_2}break;")?;
                }
            }

            if write_block_braces {
                writeln!(self.out, "{indent_level_1}}}")?;
            }
        }

        writeln!(self.out, "{level}}}")?;
    }

    // Epilogue: if a forwarding flag was set inside the switch, re-issue the
    // forwarded `continue`/`break` now that we are outside it.
    use back::continue_forward::ExitControlFlow;
    let op = match self.continue_ctx.exit_switch() {
        ExitControlFlow::None => None,
        ExitControlFlow::Continue { variable } => Some(("continue", variable)),
        ExitControlFlow::Break { variable } => Some(("break", variable)),
    };
    if let Some((control_flow, variable)) = op {
        writeln!(self.out, "{level}if ({variable}) {{")?;
        writeln!(self.out, "{indent_level_1}{control_flow};")?;
        writeln!(self.out, "{level}}}")?;
    }

    Ok(())
}
1799
1800 fn write_stmt(
1805 &mut self,
1806 module: &Module,
1807 stmt: &crate::Statement,
1808 func_ctx: &back::FunctionCtx<'_>,
1809 level: back::Level,
1810 ) -> BackendResult {
1811 use crate::Statement;
1812
1813 match *stmt {
1814 Statement::Emit(ref range) => {
1815 for handle in range.clone() {
1816 let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
1817 let expr_name = if ptr_class.is_some() {
1818 None
1822 } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
1823 Some(self.namer.call(name))
1828 } else if self.need_bake_expressions.contains(&handle) {
1829 Some(Baked(handle).to_string())
1830 } else {
1831 None
1832 };
1833
1834 if let Some(name) = expr_name {
1835 write!(self.out, "{level}")?;
1836 self.write_named_expr(module, handle, name, handle, func_ctx)?;
1837 }
1838 }
1839 }
1840 Statement::Block(ref block) => {
1842 write!(self.out, "{level}")?;
1843 writeln!(self.out, "{{")?;
1844 for sta in block.iter() {
1845 self.write_stmt(module, sta, func_ctx, level.next())?
1847 }
1848 writeln!(self.out, "{level}}}")?
1849 }
1850 Statement::If {
1852 condition,
1853 ref accept,
1854 ref reject,
1855 } => {
1856 write!(self.out, "{level}")?;
1857 write!(self.out, "if (")?;
1858 self.write_expr(module, condition, func_ctx)?;
1859 writeln!(self.out, ") {{")?;
1860
1861 let l2 = level.next();
1862 for sta in accept {
1863 self.write_stmt(module, sta, func_ctx, l2)?;
1865 }
1866
1867 if !reject.is_empty() {
1870 writeln!(self.out, "{level}}} else {{")?;
1871
1872 for sta in reject {
1873 self.write_stmt(module, sta, func_ctx, l2)?;
1875 }
1876 }
1877
1878 writeln!(self.out, "{level}}}")?
1879 }
1880 Statement::Kill => writeln!(self.out, "{level}discard;")?,
1882 Statement::Return { value: None } => {
1883 writeln!(self.out, "{level}return;")?;
1884 }
1885 Statement::Return { value: Some(expr) } => {
1886 let base_ty_res = &func_ctx.info[expr].ty;
1887 let mut resolved = base_ty_res.inner_with(&module.types);
1888 if let TypeInner::Pointer { base, space: _ } = *resolved {
1889 resolved = &module.types[base].inner;
1890 }
1891
1892 if let TypeInner::Struct { .. } = *resolved {
1893 let ty = base_ty_res.handle().unwrap();
1895 let struct_name = &self.names[&NameKey::Type(ty)];
1896 let variable_name = self.namer.call(&struct_name.to_lowercase());
1897 write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
1898 self.write_expr(module, expr, func_ctx)?;
1899 writeln!(self.out, ";")?;
1900
1901 let ep_output = match func_ctx.ty {
1903 back::FunctionType::Function(_) => None,
1904 back::FunctionType::EntryPoint(index) => self
1905 .entry_point_io
1906 .get(&(index as usize))
1907 .unwrap()
1908 .output
1909 .as_ref(),
1910 };
1911 let final_name = match ep_output {
1912 Some(ep_output) => {
1913 let final_name = self.namer.call(&variable_name);
1914 write!(
1915 self.out,
1916 "{}const {} {} = {{ ",
1917 level, ep_output.ty_name, final_name,
1918 )?;
1919 for (index, m) in ep_output.members.iter().enumerate() {
1920 if index != 0 {
1921 write!(self.out, ", ")?;
1922 }
1923 let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
1924 write!(self.out, "{variable_name}.{member_name}")?;
1925 }
1926 writeln!(self.out, " }};")?;
1927 final_name
1928 }
1929 None => variable_name,
1930 };
1931 writeln!(self.out, "{level}return {final_name};")?;
1932 } else {
1933 write!(self.out, "{level}return ")?;
1934 self.write_expr(module, expr, func_ctx)?;
1935 writeln!(self.out, ";")?
1936 }
1937 }
1938 Statement::Store { pointer, value } => {
1939 let ty_inner = func_ctx.resolve_type(pointer, &module.types);
1940 if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
1941 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
1942 self.write_storage_store(
1943 module,
1944 var_handle,
1945 StoreValue::Expression(value),
1946 func_ctx,
1947 level,
1948 )?;
1949 } else {
1950 struct MatrixAccess {
1956 base: Handle<crate::Expression>,
1957 index: u32,
1958 }
1959 enum Index {
1960 Expression(Handle<crate::Expression>),
1961 Static(u32),
1962 }
1963
1964 let get_members = |expr: Handle<crate::Expression>| {
1965 let resolved = func_ctx.resolve_type(expr, &module.types);
1966 match *resolved {
1967 TypeInner::Pointer { base, .. } => match module.types[base].inner {
1968 TypeInner::Struct { ref members, .. } => Some(members),
1969 _ => None,
1970 },
1971 _ => None,
1972 }
1973 };
1974
1975 let mut matrix = None;
1976 let mut vector = None;
1977 let mut scalar = None;
1978
1979 let mut current_expr = pointer;
1980 for _ in 0..3 {
1981 let resolved = func_ctx.resolve_type(current_expr, &module.types);
1982
1983 match (resolved, &func_ctx.expressions[current_expr]) {
1984 (
1985 &TypeInner::Pointer { base: ty, .. },
1986 &crate::Expression::AccessIndex { base, index },
1987 ) if matches!(
1988 module.types[ty].inner,
1989 TypeInner::Matrix {
1990 rows: crate::VectorSize::Bi,
1991 ..
1992 }
1993 ) && get_members(base)
1994 .map(|members| members[index as usize].binding.is_none())
1995 == Some(true) =>
1996 {
1997 matrix = Some(MatrixAccess { base, index });
1998 break;
1999 }
2000 (
2001 &TypeInner::ValuePointer {
2002 size: Some(crate::VectorSize::Bi),
2003 ..
2004 },
2005 &crate::Expression::Access { base, index },
2006 ) => {
2007 vector = Some(Index::Expression(index));
2008 current_expr = base;
2009 }
2010 (
2011 &TypeInner::ValuePointer {
2012 size: Some(crate::VectorSize::Bi),
2013 ..
2014 },
2015 &crate::Expression::AccessIndex { base, index },
2016 ) => {
2017 vector = Some(Index::Static(index));
2018 current_expr = base;
2019 }
2020 (
2021 &TypeInner::ValuePointer { size: None, .. },
2022 &crate::Expression::Access { base, index },
2023 ) => {
2024 scalar = Some(Index::Expression(index));
2025 current_expr = base;
2026 }
2027 (
2028 &TypeInner::ValuePointer { size: None, .. },
2029 &crate::Expression::AccessIndex { base, index },
2030 ) => {
2031 scalar = Some(Index::Static(index));
2032 current_expr = base;
2033 }
2034 _ => break,
2035 }
2036 }
2037
2038 write!(self.out, "{level}")?;
2039
2040 if let Some(MatrixAccess { index, base }) = matrix {
2041 let base_ty_res = &func_ctx.info[base].ty;
2042 let resolved = base_ty_res.inner_with(&module.types);
2043 let ty = match *resolved {
2044 TypeInner::Pointer { base, .. } => base,
2045 _ => base_ty_res.handle().unwrap(),
2046 };
2047
2048 if let Some(Index::Static(vec_index)) = vector {
2049 self.write_expr(module, base, func_ctx)?;
2050 write!(
2051 self.out,
2052 ".{}_{}",
2053 &self.names[&NameKey::StructMember(ty, index)],
2054 vec_index
2055 )?;
2056
2057 if let Some(scalar_index) = scalar {
2058 write!(self.out, "[")?;
2059 match scalar_index {
2060 Index::Static(index) => {
2061 write!(self.out, "{index}")?;
2062 }
2063 Index::Expression(index) => {
2064 self.write_expr(module, index, func_ctx)?;
2065 }
2066 }
2067 write!(self.out, "]")?;
2068 }
2069
2070 write!(self.out, " = ")?;
2071 self.write_expr(module, value, func_ctx)?;
2072 writeln!(self.out, ";")?;
2073 } else {
2074 let access = WrappedStructMatrixAccess { ty, index };
2075 match (&vector, &scalar) {
2076 (&Some(_), &Some(_)) => {
2077 self.write_wrapped_struct_matrix_set_scalar_function_name(
2078 access,
2079 )?;
2080 }
2081 (&Some(_), &None) => {
2082 self.write_wrapped_struct_matrix_set_vec_function_name(access)?;
2083 }
2084 (&None, _) => {
2085 self.write_wrapped_struct_matrix_set_function_name(access)?;
2086 }
2087 }
2088
2089 write!(self.out, "(")?;
2090 self.write_expr(module, base, func_ctx)?;
2091 write!(self.out, ", ")?;
2092 self.write_expr(module, value, func_ctx)?;
2093
2094 if let Some(Index::Expression(vec_index)) = vector {
2095 write!(self.out, ", ")?;
2096 self.write_expr(module, vec_index, func_ctx)?;
2097
2098 if let Some(scalar_index) = scalar {
2099 write!(self.out, ", ")?;
2100 match scalar_index {
2101 Index::Static(index) => {
2102 write!(self.out, "{index}")?;
2103 }
2104 Index::Expression(index) => {
2105 self.write_expr(module, index, func_ctx)?;
2106 }
2107 }
2108 }
2109 }
2110 writeln!(self.out, ");")?;
2111 }
2112 } else {
2113 struct MatrixData {
2116 columns: crate::VectorSize,
2117 base: Handle<crate::Expression>,
2118 }
2119
2120 enum Index {
2121 Expression(Handle<crate::Expression>),
2122 Static(u32),
2123 }
2124
2125 let mut matrix = None;
2126 let mut vector = None;
2127 let mut scalar = None;
2128
2129 let mut current_expr = pointer;
2130 for _ in 0..3 {
2131 let resolved = func_ctx.resolve_type(current_expr, &module.types);
2132 match (resolved, &func_ctx.expressions[current_expr]) {
2133 (
2134 &TypeInner::ValuePointer {
2135 size: Some(crate::VectorSize::Bi),
2136 ..
2137 },
2138 &crate::Expression::Access { base, index },
2139 ) => {
2140 vector = Some(index);
2141 current_expr = base;
2142 }
2143 (
2144 &TypeInner::ValuePointer { size: None, .. },
2145 &crate::Expression::Access { base, index },
2146 ) => {
2147 scalar = Some(Index::Expression(index));
2148 current_expr = base;
2149 }
2150 (
2151 &TypeInner::ValuePointer { size: None, .. },
2152 &crate::Expression::AccessIndex { base, index },
2153 ) => {
2154 scalar = Some(Index::Static(index));
2155 current_expr = base;
2156 }
2157 _ => {
2158 if let Some(MatrixType {
2159 columns,
2160 rows: crate::VectorSize::Bi,
2161 width: 4,
2162 }) = get_inner_matrix_of_struct_array_member(
2163 module,
2164 current_expr,
2165 func_ctx,
2166 true,
2167 ) {
2168 matrix = Some(MatrixData {
2169 columns,
2170 base: current_expr,
2171 });
2172 }
2173
2174 break;
2175 }
2176 }
2177 }
2178
2179 if let (Some(MatrixData { columns, base }), Some(vec_index)) =
2180 (matrix, vector)
2181 {
2182 if scalar.is_some() {
2183 write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
2184 } else {
2185 write!(self.out, "__set_col_of_mat{}x2", columns as u8)?;
2186 }
2187 write!(self.out, "(")?;
2188 self.write_expr(module, base, func_ctx)?;
2189 write!(self.out, ", ")?;
2190 self.write_expr(module, vec_index, func_ctx)?;
2191
2192 if let Some(scalar_index) = scalar {
2193 write!(self.out, ", ")?;
2194 match scalar_index {
2195 Index::Static(index) => {
2196 write!(self.out, "{index}")?;
2197 }
2198 Index::Expression(index) => {
2199 self.write_expr(module, index, func_ctx)?;
2200 }
2201 }
2202 }
2203
2204 write!(self.out, ", ")?;
2205 self.write_expr(module, value, func_ctx)?;
2206
2207 writeln!(self.out, ");")?;
2208 } else {
2209 self.write_expr(module, pointer, func_ctx)?;
2210 write!(self.out, " = ")?;
2211
2212 if let Some(MatrixType {
2217 columns,
2218 rows: crate::VectorSize::Bi,
2219 width: 4,
2220 }) = get_inner_matrix_of_struct_array_member(
2221 module, pointer, func_ctx, false,
2222 ) {
2223 let mut resolved = func_ctx.resolve_type(pointer, &module.types);
2224 if let TypeInner::Pointer { base, .. } = *resolved {
2225 resolved = &module.types[base].inner;
2226 }
2227
2228 write!(self.out, "(__mat{}x2", columns as u8)?;
2229 if let TypeInner::Array { base, size, .. } = *resolved {
2230 self.write_array_size(module, base, size)?;
2231 }
2232 write!(self.out, ")")?;
2233 }
2234
2235 self.write_expr(module, value, func_ctx)?;
2236 writeln!(self.out, ";")?
2237 }
2238 }
2239 }
2240 }
2241 Statement::Loop {
2242 ref body,
2243 ref continuing,
2244 break_if,
2245 } => {
2246 let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
2247 let gate_name = (!continuing.is_empty() || break_if.is_some())
2248 .then(|| self.namer.call("loop_init"));
2249
2250 if let Some((ref decl, _)) = force_loop_bound_statements {
2251 writeln!(self.out, "{decl}")?;
2252 }
2253 if let Some(ref gate_name) = gate_name {
2254 writeln!(self.out, "{level}bool {gate_name} = true;")?;
2255 }
2256
2257 self.continue_ctx.enter_loop();
2258 writeln!(self.out, "{level}while(true) {{")?;
2259 if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
2260 writeln!(self.out, "{break_and_inc}")?;
2261 }
2262 let l2 = level.next();
2263 if let Some(gate_name) = gate_name {
2264 writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
2265 let l3 = l2.next();
2266 for sta in continuing.iter() {
2267 self.write_stmt(module, sta, func_ctx, l3)?;
2268 }
2269 if let Some(condition) = break_if {
2270 write!(self.out, "{l3}if (")?;
2271 self.write_expr(module, condition, func_ctx)?;
2272 writeln!(self.out, ") {{")?;
2273 writeln!(self.out, "{}break;", l3.next())?;
2274 writeln!(self.out, "{l3}}}")?;
2275 }
2276 writeln!(self.out, "{l2}}}")?;
2277 writeln!(self.out, "{l2}{gate_name} = false;")?;
2278 }
2279
2280 for sta in body.iter() {
2281 self.write_stmt(module, sta, func_ctx, l2)?;
2282 }
2283
2284 writeln!(self.out, "{level}}}")?;
2285 self.continue_ctx.exit_loop();
2286 }
2287 Statement::Break => writeln!(self.out, "{level}break;")?,
2288 Statement::Continue => {
2289 if let Some(variable) = self.continue_ctx.continue_encountered() {
2290 writeln!(self.out, "{level}{variable} = true;")?;
2291 writeln!(self.out, "{level}break;")?
2292 } else {
2293 writeln!(self.out, "{level}continue;")?
2294 }
2295 }
2296 Statement::ControlBarrier(barrier) => {
2297 self.write_control_barrier(barrier, level)?;
2298 }
2299 Statement::MemoryBarrier(barrier) => {
2300 self.write_memory_barrier(barrier, level)?;
2301 }
2302 Statement::ImageStore {
2303 image,
2304 coordinate,
2305 array_index,
2306 value,
2307 } => {
2308 write!(self.out, "{level}")?;
2309 self.write_expr(module, image, func_ctx)?;
2310
2311 write!(self.out, "[")?;
2312 if let Some(index) = array_index {
2313 write!(self.out, "int3(")?;
2315 self.write_expr(module, coordinate, func_ctx)?;
2316 write!(self.out, ", ")?;
2317 self.write_expr(module, index, func_ctx)?;
2318 write!(self.out, ")")?;
2319 } else {
2320 self.write_expr(module, coordinate, func_ctx)?;
2321 }
2322 write!(self.out, "]")?;
2323
2324 write!(self.out, " = ")?;
2325 self.write_expr(module, value, func_ctx)?;
2326 writeln!(self.out, ";")?;
2327 }
2328 Statement::Call {
2329 function,
2330 ref arguments,
2331 result,
2332 } => {
2333 write!(self.out, "{level}")?;
2334 if let Some(expr) = result {
2335 write!(self.out, "const ")?;
2336 let name = Baked(expr).to_string();
2337 let expr_ty = &func_ctx.info[expr].ty;
2338 let ty_inner = match *expr_ty {
2339 proc::TypeResolution::Handle(handle) => {
2340 self.write_type(module, handle)?;
2341 &module.types[handle].inner
2342 }
2343 proc::TypeResolution::Value(ref value) => {
2344 self.write_value_type(module, value)?;
2345 value
2346 }
2347 };
2348 write!(self.out, " {name}")?;
2349 if let TypeInner::Array { base, size, .. } = *ty_inner {
2350 self.write_array_size(module, base, size)?;
2351 }
2352 write!(self.out, " = ")?;
2353 self.named_expressions.insert(expr, name);
2354 }
2355 let func_name = &self.names[&NameKey::Function(function)];
2356 write!(self.out, "{func_name}(")?;
2357 for (index, argument) in arguments.iter().enumerate() {
2358 if index != 0 {
2359 write!(self.out, ", ")?;
2360 }
2361 self.write_expr(module, *argument, func_ctx)?;
2362 }
2363 writeln!(self.out, ");")?
2364 }
2365 Statement::Atomic {
2366 pointer,
2367 ref fun,
2368 value,
2369 result,
2370 } => {
2371 write!(self.out, "{level}")?;
2372 let res_var_info = if let Some(res_handle) = result {
2373 let name = Baked(res_handle).to_string();
2374 match func_ctx.info[res_handle].ty {
2375 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2376 proc::TypeResolution::Value(ref value) => {
2377 self.write_value_type(module, value)?
2378 }
2379 };
2380 write!(self.out, " {name}; ")?;
2381 self.named_expressions.insert(res_handle, name.clone());
2382 Some((res_handle, name))
2383 } else {
2384 None
2385 };
2386 let pointer_space = func_ctx
2387 .resolve_type(pointer, &module.types)
2388 .pointer_space()
2389 .unwrap();
2390 let fun_str = fun.to_hlsl_suffix();
2391 let compare_expr = match *fun {
2392 crate::AtomicFunction::Exchange { compare: Some(cmp) } => Some(cmp),
2393 _ => None,
2394 };
2395 match pointer_space {
2396 crate::AddressSpace::WorkGroup => {
2397 write!(self.out, "Interlocked{fun_str}(")?;
2398 self.write_expr(module, pointer, func_ctx)?;
2399 self.emit_hlsl_atomic_tail(
2400 module,
2401 func_ctx,
2402 fun,
2403 compare_expr,
2404 value,
2405 &res_var_info,
2406 )?;
2407 }
2408 crate::AddressSpace::Storage { .. } => {
2409 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2410 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2411 let width = match func_ctx.resolve_type(value, &module.types) {
2412 &TypeInner::Scalar(Scalar { width: 8, .. }) => "64",
2413 _ => "",
2414 };
2415 write!(self.out, "{var_name}.Interlocked{fun_str}{width}(")?;
2416 let chain = mem::take(&mut self.temp_access_chain);
2417 self.write_storage_address(module, &chain, func_ctx)?;
2418 self.temp_access_chain = chain;
2419 self.emit_hlsl_atomic_tail(
2420 module,
2421 func_ctx,
2422 fun,
2423 compare_expr,
2424 value,
2425 &res_var_info,
2426 )?;
2427 }
2428 ref other => {
2429 return Err(Error::Custom(format!(
2430 "invalid address space {other:?} for atomic statement"
2431 )))
2432 }
2433 }
2434 if let Some(cmp) = compare_expr {
2435 if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
2436 write!(
2437 self.out,
2438 "{level}{res_name}.exchanged = ({res_name}.old_value == "
2439 )?;
2440 self.write_expr(module, cmp, func_ctx)?;
2441 writeln!(self.out, ");")?;
2442 }
2443 }
2444 }
2445 Statement::ImageAtomic {
2446 image,
2447 coordinate,
2448 array_index,
2449 fun,
2450 value,
2451 } => {
2452 write!(self.out, "{level}")?;
2453
2454 let fun_str = fun.to_hlsl_suffix();
2455 write!(self.out, "Interlocked{fun_str}(")?;
2456 self.write_expr(module, image, func_ctx)?;
2457 write!(self.out, "[")?;
2458 self.write_texture_coordinates(
2459 "int",
2460 coordinate,
2461 array_index,
2462 None,
2463 module,
2464 func_ctx,
2465 )?;
2466 write!(self.out, "],")?;
2467
2468 self.write_expr(module, value, func_ctx)?;
2469 writeln!(self.out, ");")?;
2470 }
2471 Statement::WorkGroupUniformLoad { pointer, result } => {
2472 self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2473 write!(self.out, "{level}")?;
2474 let name = Baked(result).to_string();
2475 self.write_named_expr(module, pointer, name, result, func_ctx)?;
2476
2477 self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2478 }
2479 Statement::Switch {
2480 selector,
2481 ref cases,
2482 } => {
2483 self.write_switch(module, func_ctx, level, selector, cases)?;
2484 }
2485 Statement::RayQuery { query, ref fun } => match *fun {
2486 RayQueryFunction::Initialize {
2487 acceleration_structure,
2488 descriptor,
2489 } => {
2490 write!(self.out, "{level}")?;
2491 self.write_expr(module, query, func_ctx)?;
2492 write!(self.out, ".TraceRayInline(")?;
2493 self.write_expr(module, acceleration_structure, func_ctx)?;
2494 write!(self.out, ", ")?;
2495 self.write_expr(module, descriptor, func_ctx)?;
2496 write!(self.out, ".flags, ")?;
2497 self.write_expr(module, descriptor, func_ctx)?;
2498 write!(self.out, ".cull_mask, ")?;
2499 write!(self.out, "RayDescFromRayDesc_(")?;
2500 self.write_expr(module, descriptor, func_ctx)?;
2501 writeln!(self.out, "));")?;
2502 }
2503 RayQueryFunction::Proceed { result } => {
2504 write!(self.out, "{level}")?;
2505 let name = Baked(result).to_string();
2506 write!(self.out, "const bool {name} = ")?;
2507 self.named_expressions.insert(result, name);
2508 self.write_expr(module, query, func_ctx)?;
2509 writeln!(self.out, ".Proceed();")?;
2510 }
2511 RayQueryFunction::GenerateIntersection { hit_t } => {
2512 write!(self.out, "{level}")?;
2513 self.write_expr(module, query, func_ctx)?;
2514 write!(self.out, ".CommitProceduralPrimitiveHit(")?;
2515 self.write_expr(module, hit_t, func_ctx)?;
2516 writeln!(self.out, ");")?;
2517 }
2518 RayQueryFunction::ConfirmIntersection => {
2519 write!(self.out, "{level}")?;
2520 self.write_expr(module, query, func_ctx)?;
2521 writeln!(self.out, ".CommitNonOpaqueTriangleHit();")?;
2522 }
2523 RayQueryFunction::Terminate => {
2524 write!(self.out, "{level}")?;
2525 self.write_expr(module, query, func_ctx)?;
2526 writeln!(self.out, ".Abort();")?;
2527 }
2528 },
2529 Statement::SubgroupBallot { result, predicate } => {
2530 write!(self.out, "{level}")?;
2531 let name = Baked(result).to_string();
2532 write!(self.out, "const uint4 {name} = ")?;
2533 self.named_expressions.insert(result, name);
2534
2535 write!(self.out, "WaveActiveBallot(")?;
2536 match predicate {
2537 Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
2538 None => write!(self.out, "true")?,
2539 }
2540 writeln!(self.out, ");")?;
2541 }
2542 Statement::SubgroupCollectiveOperation {
2543 op,
2544 collective_op,
2545 argument,
2546 result,
2547 } => {
2548 write!(self.out, "{level}")?;
2549 write!(self.out, "const ")?;
2550 let name = Baked(result).to_string();
2551 match func_ctx.info[result].ty {
2552 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2553 proc::TypeResolution::Value(ref value) => {
2554 self.write_value_type(module, value)?
2555 }
2556 };
2557 write!(self.out, " {name} = ")?;
2558 self.named_expressions.insert(result, name);
2559
2560 match (collective_op, op) {
2561 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
2562 write!(self.out, "WaveActiveAllTrue(")?
2563 }
2564 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
2565 write!(self.out, "WaveActiveAnyTrue(")?
2566 }
2567 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
2568 write!(self.out, "WaveActiveSum(")?
2569 }
2570 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
2571 write!(self.out, "WaveActiveProduct(")?
2572 }
2573 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
2574 write!(self.out, "WaveActiveMax(")?
2575 }
2576 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
2577 write!(self.out, "WaveActiveMin(")?
2578 }
2579 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
2580 write!(self.out, "WaveActiveBitAnd(")?
2581 }
2582 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
2583 write!(self.out, "WaveActiveBitOr(")?
2584 }
2585 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
2586 write!(self.out, "WaveActiveBitXor(")?
2587 }
2588 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
2589 write!(self.out, "WavePrefixSum(")?
2590 }
2591 (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
2592 write!(self.out, "WavePrefixProduct(")?
2593 }
2594 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
2595 self.write_expr(module, argument, func_ctx)?;
2596 write!(self.out, " + WavePrefixSum(")?;
2597 }
2598 (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
2599 self.write_expr(module, argument, func_ctx)?;
2600 write!(self.out, " * WavePrefixProduct(")?;
2601 }
2602 _ => unimplemented!(),
2603 }
2604 self.write_expr(module, argument, func_ctx)?;
2605 writeln!(self.out, ");")?;
2606 }
2607 Statement::SubgroupGather {
2608 mode,
2609 argument,
2610 result,
2611 } => {
2612 write!(self.out, "{level}")?;
2613 write!(self.out, "const ")?;
2614 let name = Baked(result).to_string();
2615 match func_ctx.info[result].ty {
2616 proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2617 proc::TypeResolution::Value(ref value) => {
2618 self.write_value_type(module, value)?
2619 }
2620 };
2621 write!(self.out, " {name} = ")?;
2622 self.named_expressions.insert(result, name);
2623 match mode {
2624 crate::GatherMode::BroadcastFirst => {
2625 write!(self.out, "WaveReadLaneFirst(")?;
2626 self.write_expr(module, argument, func_ctx)?;
2627 }
2628 crate::GatherMode::QuadBroadcast(index) => {
2629 write!(self.out, "QuadReadLaneAt(")?;
2630 self.write_expr(module, argument, func_ctx)?;
2631 write!(self.out, ", ")?;
2632 self.write_expr(module, index, func_ctx)?;
2633 }
2634 crate::GatherMode::QuadSwap(direction) => {
2635 match direction {
2636 crate::Direction::X => {
2637 write!(self.out, "QuadReadAcrossX(")?;
2638 }
2639 crate::Direction::Y => {
2640 write!(self.out, "QuadReadAcrossY(")?;
2641 }
2642 crate::Direction::Diagonal => {
2643 write!(self.out, "QuadReadAcrossDiagonal(")?;
2644 }
2645 }
2646 self.write_expr(module, argument, func_ctx)?;
2647 }
2648 _ => {
2649 write!(self.out, "WaveReadLaneAt(")?;
2650 self.write_expr(module, argument, func_ctx)?;
2651 write!(self.out, ", ")?;
2652 match mode {
2653 crate::GatherMode::BroadcastFirst => unreachable!(),
2654 crate::GatherMode::Broadcast(index)
2655 | crate::GatherMode::Shuffle(index) => {
2656 self.write_expr(module, index, func_ctx)?;
2657 }
2658 crate::GatherMode::ShuffleDown(index) => {
2659 write!(self.out, "WaveGetLaneIndex() + ")?;
2660 self.write_expr(module, index, func_ctx)?;
2661 }
2662 crate::GatherMode::ShuffleUp(index) => {
2663 write!(self.out, "WaveGetLaneIndex() - ")?;
2664 self.write_expr(module, index, func_ctx)?;
2665 }
2666 crate::GatherMode::ShuffleXor(index) => {
2667 write!(self.out, "WaveGetLaneIndex() ^ ")?;
2668 self.write_expr(module, index, func_ctx)?;
2669 }
2670 crate::GatherMode::QuadBroadcast(_) => unreachable!(),
2671 crate::GatherMode::QuadSwap(_) => unreachable!(),
2672 }
2673 }
2674 }
2675 writeln!(self.out, ");")?;
2676 }
2677 }
2678
2679 Ok(())
2680 }
2681
2682 fn write_const_expression(
2683 &mut self,
2684 module: &Module,
2685 expr: Handle<crate::Expression>,
2686 arena: &crate::Arena<crate::Expression>,
2687 ) -> BackendResult {
2688 self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
2689 writer.write_const_expression(module, expr, arena)
2690 })
2691 }
2692
2693 pub(super) fn write_literal(&mut self, literal: crate::Literal) -> BackendResult {
2694 match literal {
2695 crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
2696 crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
2697 crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
2698 crate::Literal::U32(value) => write!(self.out, "{value}u")?,
2699 crate::Literal::I32(value) if value == i32::MIN => {
2705 write!(self.out, "int({} - 1)", value + 1)?
2706 }
2707 crate::Literal::I32(value) => write!(self.out, "int({value})")?,
2711 crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
2712 crate::Literal::I64(value) if value == i64::MIN => {
2714 write!(self.out, "({}L - 1L)", value + 1)?;
2715 }
2716 crate::Literal::I64(value) => write!(self.out, "{value}L")?,
2717 crate::Literal::Bool(value) => write!(self.out, "{value}")?,
2718 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
2719 return Err(Error::Custom(
2720 "Abstract types should not appear in IR presented to backends".into(),
2721 ));
2722 }
2723 }
2724 Ok(())
2725 }
2726
2727 fn write_possibly_const_expression<E>(
2728 &mut self,
2729 module: &Module,
2730 expr: Handle<crate::Expression>,
2731 expressions: &crate::Arena<crate::Expression>,
2732 write_expression: E,
2733 ) -> BackendResult
2734 where
2735 E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
2736 {
2737 use crate::Expression;
2738
2739 match expressions[expr] {
2740 Expression::Literal(literal) => {
2741 self.write_literal(literal)?;
2742 }
2743 Expression::Constant(handle) => {
2744 let constant = &module.constants[handle];
2745 if constant.name.is_some() {
2746 write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
2747 } else {
2748 self.write_const_expression(module, constant.init, &module.global_expressions)?;
2749 }
2750 }
2751 Expression::ZeroValue(ty) => {
2752 self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
2753 write!(self.out, "()")?;
2754 }
2755 Expression::Compose { ty, ref components } => {
2756 match module.types[ty].inner {
2757 TypeInner::Struct { .. } | TypeInner::Array { .. } => {
2758 self.write_wrapped_constructor_function_name(
2759 module,
2760 WrappedConstructor { ty },
2761 )?;
2762 }
2763 _ => {
2764 self.write_type(module, ty)?;
2765 }
2766 };
2767 write!(self.out, "(")?;
2768 for (index, component) in components.iter().enumerate() {
2769 if index != 0 {
2770 write!(self.out, ", ")?;
2771 }
2772 write_expression(self, *component)?;
2773 }
2774 write!(self.out, ")")?;
2775 }
2776 Expression::Splat { size, value } => {
2777 let number_of_components = match size {
2781 crate::VectorSize::Bi => "xx",
2782 crate::VectorSize::Tri => "xxx",
2783 crate::VectorSize::Quad => "xxxx",
2784 };
2785 write!(self.out, "(")?;
2786 write_expression(self, value)?;
2787 write!(self.out, ").{number_of_components}")?
2788 }
2789 _ => {
2790 return Err(Error::Override);
2791 }
2792 }
2793
2794 Ok(())
2795 }
2796
2797 pub(super) fn write_expr(
2802 &mut self,
2803 module: &Module,
2804 expr: Handle<crate::Expression>,
2805 func_ctx: &back::FunctionCtx<'_>,
2806 ) -> BackendResult {
2807 use crate::Expression;
2808
2809 let ff_input = if self.options.special_constants_binding.is_some() {
2811 func_ctx.is_fixed_function_input(expr, module)
2812 } else {
2813 None
2814 };
2815 let closing_bracket = match ff_input {
2816 Some(crate::BuiltIn::VertexIndex) => {
2817 write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
2818 ")"
2819 }
2820 Some(crate::BuiltIn::InstanceIndex) => {
2821 write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
2822 ")"
2823 }
2824 Some(crate::BuiltIn::NumWorkGroups) => {
2825 write!(
2829 self.out,
2830 "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
2831 )?;
2832 return Ok(());
2833 }
2834 _ => "",
2835 };
2836
2837 if let Some(name) = self.named_expressions.get(&expr) {
2838 write!(self.out, "{name}{closing_bracket}")?;
2839 return Ok(());
2840 }
2841
2842 let expression = &func_ctx.expressions[expr];
2843
2844 match *expression {
2845 Expression::Literal(_)
2846 | Expression::Constant(_)
2847 | Expression::ZeroValue(_)
2848 | Expression::Compose { .. }
2849 | Expression::Splat { .. } => {
2850 self.write_possibly_const_expression(
2851 module,
2852 expr,
2853 func_ctx.expressions,
2854 |writer, expr| writer.write_expr(module, expr, func_ctx),
2855 )?;
2856 }
2857 Expression::Override(_) => return Err(Error::Override),
2858 Expression::Binary {
2865 op:
2866 op @ crate::BinaryOperator::Add
2867 | op @ crate::BinaryOperator::Subtract
2868 | op @ crate::BinaryOperator::Multiply,
2869 left,
2870 right,
2871 } if matches!(
2872 func_ctx.resolve_type(expr, &module.types).scalar(),
2873 Some(Scalar::I32)
2874 ) =>
2875 {
2876 write!(self.out, "asint(asuint(",)?;
2877 self.write_expr(module, left, func_ctx)?;
2878 write!(self.out, ") {} asuint(", back::binary_operation_str(op))?;
2879 self.write_expr(module, right, func_ctx)?;
2880 write!(self.out, "))")?;
2881 }
2882 Expression::Binary {
2885 op: crate::BinaryOperator::Multiply,
2886 left,
2887 right,
2888 } if func_ctx.resolve_type(left, &module.types).is_matrix()
2889 || func_ctx.resolve_type(right, &module.types).is_matrix() =>
2890 {
2891 write!(self.out, "mul(")?;
2893 self.write_expr(module, right, func_ctx)?;
2894 write!(self.out, ", ")?;
2895 self.write_expr(module, left, func_ctx)?;
2896 write!(self.out, ")")?;
2897 }
2898
2899 Expression::Binary {
2911 op: crate::BinaryOperator::Divide,
2912 left,
2913 right,
2914 } if matches!(
2915 func_ctx.resolve_type(expr, &module.types).scalar_kind(),
2916 Some(ScalarKind::Sint | ScalarKind::Uint)
2917 ) =>
2918 {
2919 write!(self.out, "{DIV_FUNCTION}(")?;
2920 self.write_expr(module, left, func_ctx)?;
2921 write!(self.out, ", ")?;
2922 self.write_expr(module, right, func_ctx)?;
2923 write!(self.out, ")")?;
2924 }
2925
2926 Expression::Binary {
2927 op: crate::BinaryOperator::Modulo,
2928 left,
2929 right,
2930 } if matches!(
2931 func_ctx.resolve_type(expr, &module.types).scalar_kind(),
2932 Some(ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float)
2933 ) =>
2934 {
2935 write!(self.out, "{MOD_FUNCTION}(")?;
2936 self.write_expr(module, left, func_ctx)?;
2937 write!(self.out, ", ")?;
2938 self.write_expr(module, right, func_ctx)?;
2939 write!(self.out, ")")?;
2940 }
2941
2942 Expression::Binary { op, left, right } => {
2943 write!(self.out, "(")?;
2944 self.write_expr(module, left, func_ctx)?;
2945 write!(self.out, " {} ", back::binary_operation_str(op))?;
2946 self.write_expr(module, right, func_ctx)?;
2947 write!(self.out, ")")?;
2948 }
2949 Expression::Access { base, index } => {
2950 if let Some(crate::AddressSpace::Storage { .. }) =
2951 func_ctx.resolve_type(expr, &module.types).pointer_space()
2952 {
2953 } else {
2955 if let Some(MatrixType {
2962 columns,
2963 rows: crate::VectorSize::Bi,
2964 width: 4,
2965 }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
2966 {
2967 write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
2968 self.write_expr(module, base, func_ctx)?;
2969 write!(self.out, ", ")?;
2970 self.write_expr(module, index, func_ctx)?;
2971 write!(self.out, ")")?;
2972 return Ok(());
2973 }
2974
2975 let resolved = func_ctx.resolve_type(base, &module.types);
2976
2977 let (indexing_binding_array, non_uniform_qualifier) = match *resolved {
2978 TypeInner::BindingArray { .. } => {
2979 let uniformity = &func_ctx.info[index].uniformity;
2980
2981 (true, uniformity.non_uniform_result.is_some())
2982 }
2983 _ => (false, false),
2984 };
2985
2986 self.write_expr(module, base, func_ctx)?;
2987
2988 let array_sampler_info = self.sampler_binding_array_info_from_expression(
2989 module, func_ctx, base, resolved,
2990 );
2991
2992 if let Some(ref info) = array_sampler_info {
2993 write!(self.out, "{}[", info.sampler_heap_name)?;
2994 } else {
2995 write!(self.out, "[")?;
2996 }
2997
2998 let needs_bound_check = self.options.restrict_indexing
2999 && !indexing_binding_array
3000 && match resolved.pointer_space() {
3001 Some(
3002 crate::AddressSpace::Function
3003 | crate::AddressSpace::Private
3004 | crate::AddressSpace::WorkGroup
3005 | crate::AddressSpace::PushConstant,
3006 )
3007 | None => true,
3008 Some(crate::AddressSpace::Uniform) => {
3009 let var_handle = self.fill_access_chain(module, base, func_ctx)?;
3011 let bind_target = self
3012 .options
3013 .resolve_resource_binding(
3014 module.global_variables[var_handle]
3015 .binding
3016 .as_ref()
3017 .unwrap(),
3018 )
3019 .unwrap();
3020 bind_target.restrict_indexing
3021 }
3022 Some(
3023 crate::AddressSpace::Handle | crate::AddressSpace::Storage { .. },
3024 ) => unreachable!(),
3025 };
3026 let restriction_needed = if needs_bound_check {
3028 index::access_needs_check(
3029 base,
3030 index::GuardedIndex::Expression(index),
3031 module,
3032 func_ctx.expressions,
3033 func_ctx.info,
3034 )
3035 } else {
3036 None
3037 };
3038 if let Some(limit) = restriction_needed {
3039 write!(self.out, "min(uint(")?;
3040 self.write_expr(module, index, func_ctx)?;
3041 write!(self.out, "), ")?;
3042 match limit {
3043 index::IndexableLength::Known(limit) => {
3044 write!(self.out, "{}u", limit - 1)?;
3045 }
3046 index::IndexableLength::Dynamic => unreachable!(),
3047 }
3048 write!(self.out, ")")?;
3049 } else {
3050 if non_uniform_qualifier {
3051 write!(self.out, "NonUniformResourceIndex(")?;
3052 }
3053 if let Some(ref info) = array_sampler_info {
3054 write!(
3055 self.out,
3056 "{}[{} + ",
3057 info.sampler_index_buffer_name, info.binding_array_base_index_name,
3058 )?;
3059 }
3060 self.write_expr(module, index, func_ctx)?;
3061 if array_sampler_info.is_some() {
3062 write!(self.out, "]")?;
3063 }
3064 if non_uniform_qualifier {
3065 write!(self.out, ")")?;
3066 }
3067 }
3068
3069 write!(self.out, "]")?;
3070 }
3071 }
3072 Expression::AccessIndex { base, index } => {
3073 if let Some(crate::AddressSpace::Storage { .. }) =
3074 func_ctx.resolve_type(expr, &module.types).pointer_space()
3075 {
3076 } else {
3078 if let Some(MatrixType {
3081 rows: crate::VectorSize::Bi,
3082 width: 4,
3083 ..
3084 }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3085 {
3086 self.write_expr(module, base, func_ctx)?;
3087 write!(self.out, "._{index}")?;
3088 return Ok(());
3089 }
3090
3091 let base_ty_res = &func_ctx.info[base].ty;
3092 let mut resolved = base_ty_res.inner_with(&module.types);
3093 let base_ty_handle = match *resolved {
3094 TypeInner::Pointer { base, .. } => {
3095 resolved = &module.types[base].inner;
3096 Some(base)
3097 }
3098 _ => base_ty_res.handle(),
3099 };
3100
3101 if let TypeInner::Struct { ref members, .. } = *resolved {
3107 let member = &members[index as usize];
3108
3109 match module.types[member.ty].inner {
3110 TypeInner::Matrix {
3111 rows: crate::VectorSize::Bi,
3112 ..
3113 } if member.binding.is_none() => {
3114 let ty = base_ty_handle.unwrap();
3115 self.write_wrapped_struct_matrix_get_function_name(
3116 WrappedStructMatrixAccess { ty, index },
3117 )?;
3118 write!(self.out, "(")?;
3119 self.write_expr(module, base, func_ctx)?;
3120 write!(self.out, ")")?;
3121 return Ok(());
3122 }
3123 _ => {}
3124 }
3125 }
3126
3127 let array_sampler_info = self.sampler_binding_array_info_from_expression(
3128 module, func_ctx, base, resolved,
3129 );
3130
3131 if let Some(ref info) = array_sampler_info {
3132 write!(
3133 self.out,
3134 "{}[{}",
3135 info.sampler_heap_name, info.sampler_index_buffer_name
3136 )?;
3137 }
3138
3139 self.write_expr(module, base, func_ctx)?;
3140
3141 match *resolved {
3142 TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
3148 write!(self.out, ".{}", back::COMPONENTS[index as usize])?
3150 }
3151 TypeInner::Matrix { .. }
3152 | TypeInner::Array { .. }
3153 | TypeInner::BindingArray { .. } => {
3154 if let Some(ref info) = array_sampler_info {
3155 write!(
3156 self.out,
3157 "[{} + {index}]",
3158 info.binding_array_base_index_name
3159 )?;
3160 } else {
3161 write!(self.out, "[{index}]")?;
3162 }
3163 }
3164 TypeInner::Struct { .. } => {
3165 let ty = base_ty_handle.unwrap();
3168
3169 write!(
3170 self.out,
3171 ".{}",
3172 &self.names[&NameKey::StructMember(ty, index)]
3173 )?
3174 }
3175 ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
3176 }
3177
3178 if array_sampler_info.is_some() {
3179 write!(self.out, "]")?;
3180 }
3181 }
3182 }
3183 Expression::FunctionArgument(pos) => {
3184 let key = func_ctx.argument_key(pos);
3185 let name = &self.names[&key];
3186 write!(self.out, "{name}")?;
3187 }
3188 Expression::ImageSample {
3189 coordinate,
3190 image,
3191 sampler,
3192 clamp_to_edge: true,
3193 gather: None,
3194 array_index: None,
3195 offset: None,
3196 level: crate::SampleLevel::Zero,
3197 depth_ref: None,
3198 } => {
3199 write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
3200 self.write_expr(module, image, func_ctx)?;
3201 write!(self.out, ", ")?;
3202 self.write_expr(module, sampler, func_ctx)?;
3203 write!(self.out, ", ")?;
3204 self.write_expr(module, coordinate, func_ctx)?;
3205 write!(self.out, ")")?;
3206 }
3207 Expression::ImageSample {
3208 image,
3209 sampler,
3210 gather,
3211 coordinate,
3212 array_index,
3213 offset,
3214 level,
3215 depth_ref,
3216 clamp_to_edge,
3217 } => {
3218 if clamp_to_edge {
3219 return Err(Error::Custom(
3220 "ImageSample::clamp_to_edge should have been validated out".to_string(),
3221 ));
3222 }
3223
3224 use crate::SampleLevel as Sl;
3225 const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
3226
3227 let (base_str, component_str) = match gather {
3228 Some(component) => ("Gather", COMPONENTS[component as usize]),
3229 None => ("Sample", ""),
3230 };
3231 let cmp_str = match depth_ref {
3232 Some(_) => "Cmp",
3233 None => "",
3234 };
3235 let level_str = match level {
3236 Sl::Zero if gather.is_none() => "LevelZero",
3237 Sl::Auto | Sl::Zero => "",
3238 Sl::Exact(_) => "Level",
3239 Sl::Bias(_) => "Bias",
3240 Sl::Gradient { .. } => "Grad",
3241 };
3242
3243 self.write_expr(module, image, func_ctx)?;
3244 write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
3245 self.write_expr(module, sampler, func_ctx)?;
3246 write!(self.out, ", ")?;
3247 self.write_texture_coordinates(
3248 "float",
3249 coordinate,
3250 array_index,
3251 None,
3252 module,
3253 func_ctx,
3254 )?;
3255
3256 if let Some(depth_ref) = depth_ref {
3257 write!(self.out, ", ")?;
3258 self.write_expr(module, depth_ref, func_ctx)?;
3259 }
3260
3261 match level {
3262 Sl::Auto | Sl::Zero => {}
3263 Sl::Exact(expr) => {
3264 write!(self.out, ", ")?;
3265 self.write_expr(module, expr, func_ctx)?;
3266 }
3267 Sl::Bias(expr) => {
3268 write!(self.out, ", ")?;
3269 self.write_expr(module, expr, func_ctx)?;
3270 }
3271 Sl::Gradient { x, y } => {
3272 write!(self.out, ", ")?;
3273 self.write_expr(module, x, func_ctx)?;
3274 write!(self.out, ", ")?;
3275 self.write_expr(module, y, func_ctx)?;
3276 }
3277 }
3278
3279 if let Some(offset) = offset {
3280 write!(self.out, ", ")?;
3281 write!(self.out, "int2(")?; self.write_const_expression(module, offset, func_ctx.expressions)?;
3283 write!(self.out, ")")?;
3284 }
3285
3286 write!(self.out, ")")?;
3287 }
3288 Expression::ImageQuery { image, query } => {
3289 if let TypeInner::Image {
3291 dim,
3292 arrayed,
3293 class,
3294 } = *func_ctx.resolve_type(image, &module.types)
3295 {
3296 let wrapped_image_query = WrappedImageQuery {
3297 dim,
3298 arrayed,
3299 class,
3300 query: query.into(),
3301 };
3302
3303 self.write_wrapped_image_query_function_name(wrapped_image_query)?;
3304 write!(self.out, "(")?;
3305 self.write_expr(module, image, func_ctx)?;
3307 if let crate::ImageQuery::Size { level: Some(level) } = query {
3308 write!(self.out, ", ")?;
3309 self.write_expr(module, level, func_ctx)?;
3310 }
3311 write!(self.out, ")")?;
3312 }
3313 }
3314 Expression::ImageLoad {
3315 image,
3316 coordinate,
3317 array_index,
3318 sample,
3319 level,
3320 } => self.write_image_load(
3321 &module,
3322 expr,
3323 func_ctx,
3324 image,
3325 coordinate,
3326 array_index,
3327 sample,
3328 level,
3329 )?,
3330 Expression::GlobalVariable(handle) => {
3331 let global_variable = &module.global_variables[handle];
3332 let ty = &module.types[global_variable.ty].inner;
3333
3334 let is_binding_array_of_samplers = match *ty {
3339 TypeInner::BindingArray { base, .. } => {
3340 let base_ty = &module.types[base].inner;
3341 matches!(*base_ty, TypeInner::Sampler { .. })
3342 }
3343 _ => false,
3344 };
3345
3346 let is_storage_space =
3347 matches!(global_variable.space, crate::AddressSpace::Storage { .. });
3348
3349 if !is_binding_array_of_samplers && !is_storage_space {
3350 let name = &self.names[&NameKey::GlobalVariable(handle)];
3351 write!(self.out, "{name}")?;
3352 }
3353 }
3354 Expression::LocalVariable(handle) => {
3355 write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
3356 }
3357 Expression::Load { pointer } => {
3358 match func_ctx
3359 .resolve_type(pointer, &module.types)
3360 .pointer_space()
3361 {
3362 Some(crate::AddressSpace::Storage { .. }) => {
3363 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
3364 let result_ty = func_ctx.info[expr].ty.clone();
3365 self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
3366 }
3367 _ => {
3368 let mut close_paren = false;
3369
3370 if let Some(MatrixType {
3375 rows: crate::VectorSize::Bi,
3376 width: 4,
3377 ..
3378 }) = get_inner_matrix_of_struct_array_member(
3379 module, pointer, func_ctx, false,
3380 )
3381 .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
3382 {
3383 let mut resolved = func_ctx.resolve_type(pointer, &module.types);
3384 if let TypeInner::Pointer { base, .. } = *resolved {
3385 resolved = &module.types[base].inner;
3386 }
3387
3388 write!(self.out, "((")?;
3389 if let TypeInner::Array { base, size, .. } = *resolved {
3390 self.write_type(module, base)?;
3391 self.write_array_size(module, base, size)?;
3392 } else {
3393 self.write_value_type(module, resolved)?;
3394 }
3395 write!(self.out, ")")?;
3396 close_paren = true;
3397 }
3398
3399 self.write_expr(module, pointer, func_ctx)?;
3400
3401 if close_paren {
3402 write!(self.out, ")")?;
3403 }
3404 }
3405 }
3406 }
3407 Expression::Unary { op, expr } => {
3408 let op_str = match op {
3410 crate::UnaryOperator::Negate => {
3411 match func_ctx.resolve_type(expr, &module.types).scalar() {
3412 Some(Scalar::I32) => NEG_FUNCTION,
3413 _ => "-",
3414 }
3415 }
3416 crate::UnaryOperator::LogicalNot => "!",
3417 crate::UnaryOperator::BitwiseNot => "~",
3418 };
3419 write!(self.out, "{op_str}(")?;
3420 self.write_expr(module, expr, func_ctx)?;
3421 write!(self.out, ")")?;
3422 }
3423 Expression::As {
3424 expr,
3425 kind,
3426 convert,
3427 } => {
3428 let inner = func_ctx.resolve_type(expr, &module.types);
3429 if inner.scalar_kind() == Some(ScalarKind::Float)
3430 && (kind == ScalarKind::Sint || kind == ScalarKind::Uint)
3431 && convert.is_some()
3432 {
3433 let fun_name = match (kind, convert) {
3437 (ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
3438 (ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
3439 (ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
3440 (ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
3441 _ => unreachable!(),
3442 };
3443 write!(self.out, "{fun_name}(")?;
3444 self.write_expr(module, expr, func_ctx)?;
3445 write!(self.out, ")")?;
3446 } else {
3447 let close_paren = match convert {
3448 Some(dst_width) => {
3449 let scalar = Scalar {
3450 kind,
3451 width: dst_width,
3452 };
3453 match *inner {
3454 TypeInner::Vector { size, .. } => {
3455 write!(
3456 self.out,
3457 "{}{}(",
3458 scalar.to_hlsl_str()?,
3459 common::vector_size_str(size)
3460 )?;
3461 }
3462 TypeInner::Scalar(_) => {
3463 write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
3464 }
3465 TypeInner::Matrix { columns, rows, .. } => {
3466 write!(
3467 self.out,
3468 "{}{}x{}(",
3469 scalar.to_hlsl_str()?,
3470 common::vector_size_str(columns),
3471 common::vector_size_str(rows)
3472 )?;
3473 }
3474 _ => {
3475 return Err(Error::Unimplemented(format!(
3476 "write_expr expression::as {inner:?}"
3477 )));
3478 }
3479 };
3480 true
3481 }
3482 None => {
3483 if inner.scalar_width() == Some(8) {
3484 false
3485 } else {
3486 write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
3487 true
3488 }
3489 }
3490 };
3491 self.write_expr(module, expr, func_ctx)?;
3492 if close_paren {
3493 write!(self.out, ")")?;
3494 }
3495 }
3496 }
3497 Expression::Math {
3498 fun,
3499 arg,
3500 arg1,
3501 arg2,
3502 arg3,
3503 } => {
3504 use crate::MathFunction as Mf;
3505
3506 enum Function {
3507 Asincosh { is_sin: bool },
3508 Atanh,
3509 Pack2x16float,
3510 Pack2x16snorm,
3511 Pack2x16unorm,
3512 Pack4x8snorm,
3513 Pack4x8unorm,
3514 Pack4xI8,
3515 Pack4xU8,
3516 Pack4xI8Clamp,
3517 Pack4xU8Clamp,
3518 Unpack2x16float,
3519 Unpack2x16snorm,
3520 Unpack2x16unorm,
3521 Unpack4x8snorm,
3522 Unpack4x8unorm,
3523 Unpack4xI8,
3524 Unpack4xU8,
3525 Dot4I8Packed,
3526 Dot4U8Packed,
3527 QuantizeToF16,
3528 Regular(&'static str),
3529 MissingIntOverload(&'static str),
3530 MissingIntReturnType(&'static str),
3531 CountTrailingZeros,
3532 CountLeadingZeros,
3533 }
3534
3535 let fun = match fun {
3536 Mf::Abs => match func_ctx.resolve_type(arg, &module.types).scalar() {
3538 Some(Scalar::I32) => Function::Regular(ABS_FUNCTION),
3539 _ => Function::Regular("abs"),
3540 },
3541 Mf::Min => Function::Regular("min"),
3542 Mf::Max => Function::Regular("max"),
3543 Mf::Clamp => Function::Regular("clamp"),
3544 Mf::Saturate => Function::Regular("saturate"),
3545 Mf::Cos => Function::Regular("cos"),
3547 Mf::Cosh => Function::Regular("cosh"),
3548 Mf::Sin => Function::Regular("sin"),
3549 Mf::Sinh => Function::Regular("sinh"),
3550 Mf::Tan => Function::Regular("tan"),
3551 Mf::Tanh => Function::Regular("tanh"),
3552 Mf::Acos => Function::Regular("acos"),
3553 Mf::Asin => Function::Regular("asin"),
3554 Mf::Atan => Function::Regular("atan"),
3555 Mf::Atan2 => Function::Regular("atan2"),
3556 Mf::Asinh => Function::Asincosh { is_sin: true },
3557 Mf::Acosh => Function::Asincosh { is_sin: false },
3558 Mf::Atanh => Function::Atanh,
3559 Mf::Radians => Function::Regular("radians"),
3560 Mf::Degrees => Function::Regular("degrees"),
3561 Mf::Ceil => Function::Regular("ceil"),
3563 Mf::Floor => Function::Regular("floor"),
3564 Mf::Round => Function::Regular("round"),
3565 Mf::Fract => Function::Regular("frac"),
3566 Mf::Trunc => Function::Regular("trunc"),
3567 Mf::Modf => Function::Regular(MODF_FUNCTION),
3568 Mf::Frexp => Function::Regular(FREXP_FUNCTION),
3569 Mf::Ldexp => Function::Regular("ldexp"),
3570 Mf::Exp => Function::Regular("exp"),
3572 Mf::Exp2 => Function::Regular("exp2"),
3573 Mf::Log => Function::Regular("log"),
3574 Mf::Log2 => Function::Regular("log2"),
3575 Mf::Pow => Function::Regular("pow"),
3576 Mf::Dot => Function::Regular("dot"),
3578 Mf::Dot4I8Packed => Function::Dot4I8Packed,
3579 Mf::Dot4U8Packed => Function::Dot4U8Packed,
3580 Mf::Cross => Function::Regular("cross"),
3582 Mf::Distance => Function::Regular("distance"),
3583 Mf::Length => Function::Regular("length"),
3584 Mf::Normalize => Function::Regular("normalize"),
3585 Mf::FaceForward => Function::Regular("faceforward"),
3586 Mf::Reflect => Function::Regular("reflect"),
3587 Mf::Refract => Function::Regular("refract"),
3588 Mf::Sign => Function::Regular("sign"),
3590 Mf::Fma => Function::Regular("mad"),
3591 Mf::Mix => Function::Regular("lerp"),
3592 Mf::Step => Function::Regular("step"),
3593 Mf::SmoothStep => Function::Regular("smoothstep"),
3594 Mf::Sqrt => Function::Regular("sqrt"),
3595 Mf::InverseSqrt => Function::Regular("rsqrt"),
3596 Mf::Transpose => Function::Regular("transpose"),
3598 Mf::Determinant => Function::Regular("determinant"),
3599 Mf::QuantizeToF16 => Function::QuantizeToF16,
3600 Mf::CountTrailingZeros => Function::CountTrailingZeros,
3602 Mf::CountLeadingZeros => Function::CountLeadingZeros,
3603 Mf::CountOneBits => Function::MissingIntOverload("countbits"),
3604 Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
3605 Mf::FirstTrailingBit => Function::MissingIntReturnType("firstbitlow"),
3606 Mf::FirstLeadingBit => Function::MissingIntReturnType("firstbithigh"),
3607 Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION),
3608 Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION),
3609 Mf::Pack2x16float => Function::Pack2x16float,
3611 Mf::Pack2x16snorm => Function::Pack2x16snorm,
3612 Mf::Pack2x16unorm => Function::Pack2x16unorm,
3613 Mf::Pack4x8snorm => Function::Pack4x8snorm,
3614 Mf::Pack4x8unorm => Function::Pack4x8unorm,
3615 Mf::Pack4xI8 => Function::Pack4xI8,
3616 Mf::Pack4xU8 => Function::Pack4xU8,
3617 Mf::Pack4xI8Clamp => Function::Pack4xI8Clamp,
3618 Mf::Pack4xU8Clamp => Function::Pack4xU8Clamp,
3619 Mf::Unpack2x16float => Function::Unpack2x16float,
3621 Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
3622 Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
3623 Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
3624 Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
3625 Mf::Unpack4xI8 => Function::Unpack4xI8,
3626 Mf::Unpack4xU8 => Function::Unpack4xU8,
3627 _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
3628 };
3629
3630 match fun {
3631 Function::Asincosh { is_sin } => {
3632 write!(self.out, "log(")?;
3633 self.write_expr(module, arg, func_ctx)?;
3634 write!(self.out, " + sqrt(")?;
3635 self.write_expr(module, arg, func_ctx)?;
3636 write!(self.out, " * ")?;
3637 self.write_expr(module, arg, func_ctx)?;
3638 match is_sin {
3639 true => write!(self.out, " + 1.0))")?,
3640 false => write!(self.out, " - 1.0))")?,
3641 }
3642 }
3643 Function::Atanh => {
3644 write!(self.out, "0.5 * log((1.0 + ")?;
3645 self.write_expr(module, arg, func_ctx)?;
3646 write!(self.out, ") / (1.0 - ")?;
3647 self.write_expr(module, arg, func_ctx)?;
3648 write!(self.out, "))")?;
3649 }
3650 Function::Pack2x16float => {
3651 write!(self.out, "(f32tof16(")?;
3652 self.write_expr(module, arg, func_ctx)?;
3653 write!(self.out, "[0]) | f32tof16(")?;
3654 self.write_expr(module, arg, func_ctx)?;
3655 write!(self.out, "[1]) << 16)")?;
3656 }
3657 Function::Pack2x16snorm => {
3658 let scale = 32767;
3659
3660 write!(self.out, "uint((int(round(clamp(")?;
3661 self.write_expr(module, arg, func_ctx)?;
3662 write!(
3663 self.out,
3664 "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
3665 )?;
3666 self.write_expr(module, arg, func_ctx)?;
3667 write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
3668 }
3669 Function::Pack2x16unorm => {
3670 let scale = 65535;
3671
3672 write!(self.out, "(uint(round(clamp(")?;
3673 self.write_expr(module, arg, func_ctx)?;
3674 write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3675 self.write_expr(module, arg, func_ctx)?;
3676 write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
3677 }
3678 Function::Pack4x8snorm => {
3679 let scale = 127;
3680
3681 write!(self.out, "uint((int(round(clamp(")?;
3682 self.write_expr(module, arg, func_ctx)?;
3683 write!(
3684 self.out,
3685 "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
3686 )?;
3687 self.write_expr(module, arg, func_ctx)?;
3688 write!(
3689 self.out,
3690 "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
3691 )?;
3692 self.write_expr(module, arg, func_ctx)?;
3693 write!(
3694 self.out,
3695 "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
3696 )?;
3697 self.write_expr(module, arg, func_ctx)?;
3698 write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
3699 }
3700 Function::Pack4x8unorm => {
3701 let scale = 255;
3702
3703 write!(self.out, "(uint(round(clamp(")?;
3704 self.write_expr(module, arg, func_ctx)?;
3705 write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3706 self.write_expr(module, arg, func_ctx)?;
3707 write!(
3708 self.out,
3709 "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
3710 )?;
3711 self.write_expr(module, arg, func_ctx)?;
3712 write!(
3713 self.out,
3714 "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
3715 )?;
3716 self.write_expr(module, arg, func_ctx)?;
3717 write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
3718 }
3719 fun @ (Function::Pack4xI8
3720 | Function::Pack4xU8
3721 | Function::Pack4xI8Clamp
3722 | Function::Pack4xU8Clamp) => {
3723 let was_signed =
3724 matches!(fun, Function::Pack4xI8 | Function::Pack4xI8Clamp);
3725 let clamp_bounds = match fun {
3726 Function::Pack4xI8Clamp => Some(("-128", "127")),
3727 Function::Pack4xU8Clamp => Some(("0", "255")),
3728 _ => None,
3729 };
3730 if was_signed {
3731 write!(self.out, "uint(")?;
3732 }
3733 let write_arg = |this: &mut Self| -> BackendResult {
3734 if let Some((min, max)) = clamp_bounds {
3735 write!(this.out, "clamp(")?;
3736 this.write_expr(module, arg, func_ctx)?;
3737 write!(this.out, ", {min}, {max})")?;
3738 } else {
3739 this.write_expr(module, arg, func_ctx)?;
3740 }
3741 Ok(())
3742 };
3743 write!(self.out, "(")?;
3744 write_arg(self)?;
3745 write!(self.out, "[0] & 0xFF) | ((")?;
3746 write_arg(self)?;
3747 write!(self.out, "[1] & 0xFF) << 8) | ((")?;
3748 write_arg(self)?;
3749 write!(self.out, "[2] & 0xFF) << 16) | ((")?;
3750 write_arg(self)?;
3751 write!(self.out, "[3] & 0xFF) << 24)")?;
3752 if was_signed {
3753 write!(self.out, ")")?;
3754 }
3755 }
3756
3757 Function::Unpack2x16float => {
3758 write!(self.out, "float2(f16tof32(")?;
3759 self.write_expr(module, arg, func_ctx)?;
3760 write!(self.out, "), f16tof32((")?;
3761 self.write_expr(module, arg, func_ctx)?;
3762 write!(self.out, ") >> 16))")?;
3763 }
3764 Function::Unpack2x16snorm => {
3765 let scale = 32767;
3766
3767 write!(self.out, "(float2(int2(")?;
3768 self.write_expr(module, arg, func_ctx)?;
3769 write!(self.out, " << 16, ")?;
3770 self.write_expr(module, arg, func_ctx)?;
3771 write!(self.out, ") >> 16) / {scale}.0)")?;
3772 }
3773 Function::Unpack2x16unorm => {
3774 let scale = 65535;
3775
3776 write!(self.out, "(float2(")?;
3777 self.write_expr(module, arg, func_ctx)?;
3778 write!(self.out, " & 0xFFFF, ")?;
3779 self.write_expr(module, arg, func_ctx)?;
3780 write!(self.out, " >> 16) / {scale}.0)")?;
3781 }
3782 Function::Unpack4x8snorm => {
3783 let scale = 127;
3784
3785 write!(self.out, "(float4(int4(")?;
3786 self.write_expr(module, arg, func_ctx)?;
3787 write!(self.out, " << 24, ")?;
3788 self.write_expr(module, arg, func_ctx)?;
3789 write!(self.out, " << 16, ")?;
3790 self.write_expr(module, arg, func_ctx)?;
3791 write!(self.out, " << 8, ")?;
3792 self.write_expr(module, arg, func_ctx)?;
3793 write!(self.out, ") >> 24) / {scale}.0)")?;
3794 }
3795 Function::Unpack4x8unorm => {
3796 let scale = 255;
3797
3798 write!(self.out, "(float4(")?;
3799 self.write_expr(module, arg, func_ctx)?;
3800 write!(self.out, " & 0xFF, ")?;
3801 self.write_expr(module, arg, func_ctx)?;
3802 write!(self.out, " >> 8 & 0xFF, ")?;
3803 self.write_expr(module, arg, func_ctx)?;
3804 write!(self.out, " >> 16 & 0xFF, ")?;
3805 self.write_expr(module, arg, func_ctx)?;
3806 write!(self.out, " >> 24) / {scale}.0)")?;
3807 }
3808 fun @ (Function::Unpack4xI8 | Function::Unpack4xU8) => {
3809 write!(self.out, "(")?;
3810 if matches!(fun, Function::Unpack4xU8) {
3811 write!(self.out, "u")?;
3812 }
3813 write!(self.out, "int4(")?;
3814 self.write_expr(module, arg, func_ctx)?;
3815 write!(self.out, ", ")?;
3816 self.write_expr(module, arg, func_ctx)?;
3817 write!(self.out, " >> 8, ")?;
3818 self.write_expr(module, arg, func_ctx)?;
3819 write!(self.out, " >> 16, ")?;
3820 self.write_expr(module, arg, func_ctx)?;
3821 write!(self.out, " >> 24) << 24 >> 24)")?;
3822 }
3823 fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
3824 let arg1 = arg1.unwrap();
3825
3826 if self.options.shader_model >= ShaderModel::V6_4 {
3827 let function_name = match fun {
3829 Function::Dot4I8Packed => "dot4add_i8packed",
3830 Function::Dot4U8Packed => "dot4add_u8packed",
3831 _ => unreachable!(),
3832 };
3833 write!(self.out, "{function_name}(")?;
3834 self.write_expr(module, arg, func_ctx)?;
3835 write!(self.out, ", ")?;
3836 self.write_expr(module, arg1, func_ctx)?;
3837 write!(self.out, ", 0)")?;
3838 } else {
3839 write!(self.out, "dot(")?;
3841
3842 if matches!(fun, Function::Dot4U8Packed) {
3843 write!(self.out, "u")?;
3844 }
3845 write!(self.out, "int4(")?;
3846 self.write_expr(module, arg, func_ctx)?;
3847 write!(self.out, ", ")?;
3848 self.write_expr(module, arg, func_ctx)?;
3849 write!(self.out, " >> 8, ")?;
3850 self.write_expr(module, arg, func_ctx)?;
3851 write!(self.out, " >> 16, ")?;
3852 self.write_expr(module, arg, func_ctx)?;
3853 write!(self.out, " >> 24) << 24 >> 24, ")?;
3854
3855 if matches!(fun, Function::Dot4U8Packed) {
3856 write!(self.out, "u")?;
3857 }
3858 write!(self.out, "int4(")?;
3859 self.write_expr(module, arg1, func_ctx)?;
3860 write!(self.out, ", ")?;
3861 self.write_expr(module, arg1, func_ctx)?;
3862 write!(self.out, " >> 8, ")?;
3863 self.write_expr(module, arg1, func_ctx)?;
3864 write!(self.out, " >> 16, ")?;
3865 self.write_expr(module, arg1, func_ctx)?;
3866 write!(self.out, " >> 24) << 24 >> 24)")?;
3867 }
3868 }
3869 Function::QuantizeToF16 => {
3870 write!(self.out, "f16tof32(f32tof16(")?;
3871 self.write_expr(module, arg, func_ctx)?;
3872 write!(self.out, "))")?;
3873 }
3874 Function::Regular(fun_name) => {
3875 write!(self.out, "{fun_name}(")?;
3876 self.write_expr(module, arg, func_ctx)?;
3877 if let Some(arg) = arg1 {
3878 write!(self.out, ", ")?;
3879 self.write_expr(module, arg, func_ctx)?;
3880 }
3881 if let Some(arg) = arg2 {
3882 write!(self.out, ", ")?;
3883 self.write_expr(module, arg, func_ctx)?;
3884 }
3885 if let Some(arg) = arg3 {
3886 write!(self.out, ", ")?;
3887 self.write_expr(module, arg, func_ctx)?;
3888 }
3889 write!(self.out, ")")?
3890 }
3891 Function::MissingIntOverload(fun_name) => {
3894 let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
3895 if let Some(Scalar::I32) = scalar_kind {
3896 write!(self.out, "asint({fun_name}(asuint(")?;
3897 self.write_expr(module, arg, func_ctx)?;
3898 write!(self.out, ")))")?;
3899 } else {
3900 write!(self.out, "{fun_name}(")?;
3901 self.write_expr(module, arg, func_ctx)?;
3902 write!(self.out, ")")?;
3903 }
3904 }
3905 Function::MissingIntReturnType(fun_name) => {
3908 let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
3909 if let Some(Scalar::I32) = scalar_kind {
3910 write!(self.out, "asint({fun_name}(")?;
3911 self.write_expr(module, arg, func_ctx)?;
3912 write!(self.out, "))")?;
3913 } else {
3914 write!(self.out, "{fun_name}(")?;
3915 self.write_expr(module, arg, func_ctx)?;
3916 write!(self.out, ")")?;
3917 }
3918 }
3919 Function::CountTrailingZeros => {
3920 match *func_ctx.resolve_type(arg, &module.types) {
3921 TypeInner::Vector { size, scalar } => {
3922 let s = match size {
3923 crate::VectorSize::Bi => ".xx",
3924 crate::VectorSize::Tri => ".xxx",
3925 crate::VectorSize::Quad => ".xxxx",
3926 };
3927
3928 let scalar_width_bits = scalar.width * 8;
3929
3930 if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
3931 write!(
3932 self.out,
3933 "min(({scalar_width_bits}u){s}, firstbitlow("
3934 )?;
3935 self.write_expr(module, arg, func_ctx)?;
3936 write!(self.out, "))")?;
3937 } else {
3938 write!(
3940 self.out,
3941 "asint(min(({scalar_width_bits}u){s}, firstbitlow("
3942 )?;
3943 self.write_expr(module, arg, func_ctx)?;
3944 write!(self.out, ")))")?;
3945 }
3946 }
3947 TypeInner::Scalar(scalar) => {
3948 let scalar_width_bits = scalar.width * 8;
3949
3950 if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
3951 write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?;
3952 self.write_expr(module, arg, func_ctx)?;
3953 write!(self.out, "))")?;
3954 } else {
3955 write!(
3957 self.out,
3958 "asint(min({scalar_width_bits}u, firstbitlow("
3959 )?;
3960 self.write_expr(module, arg, func_ctx)?;
3961 write!(self.out, ")))")?;
3962 }
3963 }
3964 _ => unreachable!(),
3965 }
3966
3967 return Ok(());
3968 }
3969 Function::CountLeadingZeros => {
3970 match *func_ctx.resolve_type(arg, &module.types) {
3971 TypeInner::Vector { size, scalar } => {
3972 let s = match size {
3973 crate::VectorSize::Bi => ".xx",
3974 crate::VectorSize::Tri => ".xxx",
3975 crate::VectorSize::Quad => ".xxxx",
3976 };
3977
3978 let constant = scalar.width * 8 - 1;
3980
3981 if scalar.kind == ScalarKind::Uint {
3982 write!(self.out, "(({constant}u){s} - firstbithigh(")?;
3983 self.write_expr(module, arg, func_ctx)?;
3984 write!(self.out, "))")?;
3985 } else {
3986 let conversion_func = match scalar.width {
3987 4 => "asint",
3988 _ => "",
3989 };
3990 write!(self.out, "(")?;
3991 self.write_expr(module, arg, func_ctx)?;
3992 write!(
3993 self.out,
3994 " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh("
3995 )?;
3996 self.write_expr(module, arg, func_ctx)?;
3997 write!(self.out, ")))")?;
3998 }
3999 }
4000 TypeInner::Scalar(scalar) => {
4001 let constant = scalar.width * 8 - 1;
4003
4004 if let ScalarKind::Uint = scalar.kind {
4005 write!(self.out, "({constant}u - firstbithigh(")?;
4006 self.write_expr(module, arg, func_ctx)?;
4007 write!(self.out, "))")?;
4008 } else {
4009 let conversion_func = match scalar.width {
4010 4 => "asint",
4011 _ => "",
4012 };
4013 write!(self.out, "(")?;
4014 self.write_expr(module, arg, func_ctx)?;
4015 write!(
4016 self.out,
4017 " < 0 ? 0 : {constant} - {conversion_func}(firstbithigh("
4018 )?;
4019 self.write_expr(module, arg, func_ctx)?;
4020 write!(self.out, ")))")?;
4021 }
4022 }
4023 _ => unreachable!(),
4024 }
4025
4026 return Ok(());
4027 }
4028 }
4029 }
4030 Expression::Swizzle {
4031 size,
4032 vector,
4033 pattern,
4034 } => {
4035 self.write_expr(module, vector, func_ctx)?;
4036 write!(self.out, ".")?;
4037 for &sc in pattern[..size as usize].iter() {
4038 self.out.write_char(back::COMPONENTS[sc as usize])?;
4039 }
4040 }
4041 Expression::ArrayLength(expr) => {
4042 let var_handle = match func_ctx.expressions[expr] {
4043 Expression::AccessIndex { base, index: _ } => {
4044 match func_ctx.expressions[base] {
4045 Expression::GlobalVariable(handle) => handle,
4046 _ => unreachable!(),
4047 }
4048 }
4049 Expression::GlobalVariable(handle) => handle,
4050 _ => unreachable!(),
4051 };
4052
4053 let var = &module.global_variables[var_handle];
4054 let (offset, stride) = match module.types[var.ty].inner {
4055 TypeInner::Array { stride, .. } => (0, stride),
4056 TypeInner::Struct { ref members, .. } => {
4057 let last = members.last().unwrap();
4058 let stride = match module.types[last.ty].inner {
4059 TypeInner::Array { stride, .. } => stride,
4060 _ => unreachable!(),
4061 };
4062 (last.offset, stride)
4063 }
4064 _ => unreachable!(),
4065 };
4066
4067 let storage_access = match var.space {
4068 crate::AddressSpace::Storage { access } => access,
4069 _ => crate::StorageAccess::default(),
4070 };
4071 let wrapped_array_length = WrappedArrayLength {
4072 writable: storage_access.contains(crate::StorageAccess::STORE),
4073 };
4074
4075 write!(self.out, "((")?;
4076 self.write_wrapped_array_length_function_name(wrapped_array_length)?;
4077 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4078 write!(self.out, "({var_name}) - {offset}) / {stride})")?
4079 }
4080 Expression::Derivative { axis, ctrl, expr } => {
4081 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
4082 if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
4083 let tail = match ctrl {
4084 Ctrl::Coarse => "coarse",
4085 Ctrl::Fine => "fine",
4086 Ctrl::None => unreachable!(),
4087 };
4088 write!(self.out, "abs(ddx_{tail}(")?;
4089 self.write_expr(module, expr, func_ctx)?;
4090 write!(self.out, ")) + abs(ddy_{tail}(")?;
4091 self.write_expr(module, expr, func_ctx)?;
4092 write!(self.out, "))")?
4093 } else {
4094 let fun_str = match (axis, ctrl) {
4095 (Axis::X, Ctrl::Coarse) => "ddx_coarse",
4096 (Axis::X, Ctrl::Fine) => "ddx_fine",
4097 (Axis::X, Ctrl::None) => "ddx",
4098 (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
4099 (Axis::Y, Ctrl::Fine) => "ddy_fine",
4100 (Axis::Y, Ctrl::None) => "ddy",
4101 (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
4102 (Axis::Width, Ctrl::None) => "fwidth",
4103 };
4104 write!(self.out, "{fun_str}(")?;
4105 self.write_expr(module, expr, func_ctx)?;
4106 write!(self.out, ")")?
4107 }
4108 }
4109 Expression::Relational { fun, argument } => {
4110 use crate::RelationalFunction as Rf;
4111
4112 let fun_str = match fun {
4113 Rf::All => "all",
4114 Rf::Any => "any",
4115 Rf::IsNan => "isnan",
4116 Rf::IsInf => "isinf",
4117 };
4118 write!(self.out, "{fun_str}(")?;
4119 self.write_expr(module, argument, func_ctx)?;
4120 write!(self.out, ")")?
4121 }
4122 Expression::Select {
4123 condition,
4124 accept,
4125 reject,
4126 } => {
4127 write!(self.out, "(")?;
4128 self.write_expr(module, condition, func_ctx)?;
4129 write!(self.out, " ? ")?;
4130 self.write_expr(module, accept, func_ctx)?;
4131 write!(self.out, " : ")?;
4132 self.write_expr(module, reject, func_ctx)?;
4133 write!(self.out, ")")?
4134 }
4135 Expression::RayQueryGetIntersection { query, committed } => {
4136 if committed {
4137 write!(self.out, "GetCommittedIntersection(")?;
4138 self.write_expr(module, query, func_ctx)?;
4139 write!(self.out, ")")?;
4140 } else {
4141 write!(self.out, "GetCandidateIntersection(")?;
4142 self.write_expr(module, query, func_ctx)?;
4143 write!(self.out, ")")?;
4144 }
4145 }
4146 Expression::RayQueryVertexPositions { .. } => unreachable!(),
4148 Expression::CallResult(_)
4150 | Expression::AtomicResult { .. }
4151 | Expression::WorkGroupUniformLoadResult { .. }
4152 | Expression::RayQueryProceedResult
4153 | Expression::SubgroupBallotResult
4154 | Expression::SubgroupOperationResult { .. } => {}
4155 }
4156
4157 if !closing_bracket.is_empty() {
4158 write!(self.out, "{closing_bracket}")?;
4159 }
4160 Ok(())
4161 }
4162
4163 #[allow(clippy::too_many_arguments)]
4164 fn write_image_load(
4165 &mut self,
4166 module: &&Module,
4167 expr: Handle<crate::Expression>,
4168 func_ctx: &back::FunctionCtx,
4169 image: Handle<crate::Expression>,
4170 coordinate: Handle<crate::Expression>,
4171 array_index: Option<Handle<crate::Expression>>,
4172 sample: Option<Handle<crate::Expression>>,
4173 level: Option<Handle<crate::Expression>>,
4174 ) -> Result<(), Error> {
4175 let mut wrapping_type = None;
4176 match *func_ctx.resolve_type(image, &module.types) {
4177 TypeInner::Image {
4178 class: crate::ImageClass::Storage { format, .. },
4179 ..
4180 } => {
4181 if format.single_component() {
4182 wrapping_type = Some(Scalar::from(format));
4183 }
4184 }
4185 _ => {}
4186 }
4187 if let Some(scalar) = wrapping_type {
4188 write!(
4189 self.out,
4190 "{}{}(",
4191 help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
4192 scalar.to_hlsl_str()?
4193 )?;
4194 }
4195 self.write_expr(module, image, func_ctx)?;
4197 write!(self.out, ".Load(")?;
4198
4199 self.write_texture_coordinates("int", coordinate, array_index, level, module, func_ctx)?;
4200
4201 if let Some(sample) = sample {
4202 write!(self.out, ", ")?;
4203 self.write_expr(module, sample, func_ctx)?;
4204 }
4205
4206 write!(self.out, ")")?;
4208
4209 if wrapping_type.is_some() {
4210 write!(self.out, ")")?;
4211 }
4212
4213 if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
4215 write!(self.out, ".x")?;
4216 }
4217 Ok(())
4218 }
4219
4220 fn sampler_binding_array_info_from_expression(
4223 &mut self,
4224 module: &Module,
4225 func_ctx: &back::FunctionCtx<'_>,
4226 base: Handle<crate::Expression>,
4227 resolved: &TypeInner,
4228 ) -> Option<BindingArraySamplerInfo> {
4229 if let TypeInner::BindingArray {
4230 base: base_ty_handle,
4231 ..
4232 } = *resolved
4233 {
4234 let base_ty = &module.types[base_ty_handle].inner;
4235 if let TypeInner::Sampler { comparison, .. } = *base_ty {
4236 let base = &func_ctx.expressions[base];
4237
4238 if let crate::Expression::GlobalVariable(handle) = *base {
4239 let variable = &module.global_variables[handle];
4240
4241 let sampler_heap_name = match comparison {
4242 true => COMPARISON_SAMPLER_HEAP_VAR,
4243 false => SAMPLER_HEAP_VAR,
4244 };
4245
4246 return Some(BindingArraySamplerInfo {
4247 sampler_heap_name,
4248 sampler_index_buffer_name: self
4249 .wrapped
4250 .sampler_index_buffers
4251 .get(&super::SamplerIndexBufferKey {
4252 group: variable.binding.unwrap().group,
4253 })
4254 .unwrap()
4255 .clone(),
4256 binding_array_base_index_name: self.names[&NameKey::GlobalVariable(handle)]
4257 .clone(),
4258 });
4259 }
4260 }
4261 }
4262
4263 None
4264 }
4265
4266 fn write_named_expr(
4267 &mut self,
4268 module: &Module,
4269 handle: Handle<crate::Expression>,
4270 name: String,
4271 named: Handle<crate::Expression>,
4274 ctx: &back::FunctionCtx,
4275 ) -> BackendResult {
4276 match ctx.info[named].ty {
4277 proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
4278 TypeInner::Struct { .. } => {
4279 let ty_name = &self.names[&NameKey::Type(ty_handle)];
4280 write!(self.out, "{ty_name}")?;
4281 }
4282 _ => {
4283 self.write_type(module, ty_handle)?;
4284 }
4285 },
4286 proc::TypeResolution::Value(ref inner) => {
4287 self.write_value_type(module, inner)?;
4288 }
4289 }
4290
4291 let resolved = ctx.resolve_type(named, &module.types);
4292
4293 write!(self.out, " {name}")?;
4294 if let TypeInner::Array { base, size, .. } = *resolved {
4296 self.write_array_size(module, base, size)?;
4297 }
4298 write!(self.out, " = ")?;
4299 self.write_expr(module, handle, ctx)?;
4300 writeln!(self.out, ";")?;
4301 self.named_expressions.insert(named, name);
4302
4303 Ok(())
4304 }
4305
4306 pub(super) fn write_default_init(
4308 &mut self,
4309 module: &Module,
4310 ty: Handle<crate::Type>,
4311 ) -> BackendResult {
4312 write!(self.out, "(")?;
4313 self.write_type(module, ty)?;
4314 if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
4315 self.write_array_size(module, base, size)?;
4316 }
4317 write!(self.out, ")0")?;
4318 Ok(())
4319 }
4320
4321 fn write_control_barrier(
4322 &mut self,
4323 barrier: crate::Barrier,
4324 level: back::Level,
4325 ) -> BackendResult {
4326 if barrier.contains(crate::Barrier::STORAGE) {
4327 writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4328 }
4329 if barrier.contains(crate::Barrier::WORK_GROUP) {
4330 writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
4331 }
4332 if barrier.contains(crate::Barrier::SUB_GROUP) {
4333 }
4335 if barrier.contains(crate::Barrier::TEXTURE) {
4336 writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4337 }
4338 Ok(())
4339 }
4340
4341 fn write_memory_barrier(
4342 &mut self,
4343 barrier: crate::Barrier,
4344 level: back::Level,
4345 ) -> BackendResult {
4346 if barrier.contains(crate::Barrier::STORAGE) {
4347 writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4348 }
4349 if barrier.contains(crate::Barrier::WORK_GROUP) {
4350 writeln!(self.out, "{level}GroupMemoryBarrier();")?;
4351 }
4352 if barrier.contains(crate::Barrier::SUB_GROUP) {
4353 }
4355 if barrier.contains(crate::Barrier::TEXTURE) {
4356 writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4357 }
4358 Ok(())
4359 }
4360
4361 fn emit_hlsl_atomic_tail(
4363 &mut self,
4364 module: &Module,
4365 func_ctx: &back::FunctionCtx<'_>,
4366 fun: &crate::AtomicFunction,
4367 compare_expr: Option<Handle<crate::Expression>>,
4368 value: Handle<crate::Expression>,
4369 res_var_info: &Option<(Handle<crate::Expression>, String)>,
4370 ) -> BackendResult {
4371 if let Some(cmp) = compare_expr {
4372 write!(self.out, ", ")?;
4373 self.write_expr(module, cmp, func_ctx)?;
4374 }
4375 write!(self.out, ", ")?;
4376 if let crate::AtomicFunction::Subtract = *fun {
4377 write!(self.out, "-")?;
4379 }
4380 self.write_expr(module, value, func_ctx)?;
4381 if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
4382 write!(self.out, ", ")?;
4383 if compare_expr.is_some() {
4384 write!(self.out, "{res_name}.old_value")?;
4385 } else {
4386 write!(self.out, "{res_name}")?;
4387 }
4388 }
4389 writeln!(self.out, ");")?;
4390 Ok(())
4391 }
4392}
4393
4394pub(super) struct MatrixType {
4395 pub(super) columns: crate::VectorSize,
4396 pub(super) rows: crate::VectorSize,
4397 pub(super) width: crate::Bytes,
4398}
4399
4400pub(super) fn get_inner_matrix_data(
4401 module: &Module,
4402 handle: Handle<crate::Type>,
4403) -> Option<MatrixType> {
4404 match module.types[handle].inner {
4405 TypeInner::Matrix {
4406 columns,
4407 rows,
4408 scalar,
4409 } => Some(MatrixType {
4410 columns,
4411 rows,
4412 width: scalar.width,
4413 }),
4414 TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
4415 _ => None,
4416 }
4417}
4418
4419pub(super) fn get_inner_matrix_of_struct_array_member(
4424 module: &Module,
4425 base: Handle<crate::Expression>,
4426 func_ctx: &back::FunctionCtx<'_>,
4427 direct: bool,
4428) -> Option<MatrixType> {
4429 let mut mat_data = None;
4430 let mut array_base = None;
4431
4432 let mut current_base = base;
4433 loop {
4434 let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4435 if let TypeInner::Pointer { base, .. } = *resolved {
4436 resolved = &module.types[base].inner;
4437 };
4438
4439 match *resolved {
4440 TypeInner::Matrix {
4441 columns,
4442 rows,
4443 scalar,
4444 } => {
4445 mat_data = Some(MatrixType {
4446 columns,
4447 rows,
4448 width: scalar.width,
4449 })
4450 }
4451 TypeInner::Array { base, .. } => {
4452 array_base = Some(base);
4453 }
4454 TypeInner::Struct { .. } => {
4455 if let Some(array_base) = array_base {
4456 if direct {
4457 return mat_data;
4458 } else {
4459 return get_inner_matrix_data(module, array_base);
4460 }
4461 }
4462
4463 break;
4464 }
4465 _ => break,
4466 }
4467
4468 current_base = match func_ctx.expressions[current_base] {
4469 crate::Expression::Access { base, .. } => base,
4470 crate::Expression::AccessIndex { base, .. } => base,
4471 _ => break,
4472 };
4473 }
4474 None
4475}
4476
4477fn get_inner_matrix_of_global_uniform(
4482 module: &Module,
4483 base: Handle<crate::Expression>,
4484 func_ctx: &back::FunctionCtx<'_>,
4485) -> Option<MatrixType> {
4486 let mut mat_data = None;
4487 let mut array_base = None;
4488
4489 let mut current_base = base;
4490 loop {
4491 let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4492 if let TypeInner::Pointer { base, .. } = *resolved {
4493 resolved = &module.types[base].inner;
4494 };
4495
4496 match *resolved {
4497 TypeInner::Matrix {
4498 columns,
4499 rows,
4500 scalar,
4501 } => {
4502 mat_data = Some(MatrixType {
4503 columns,
4504 rows,
4505 width: scalar.width,
4506 })
4507 }
4508 TypeInner::Array { base, .. } => {
4509 array_base = Some(base);
4510 }
4511 _ => break,
4512 }
4513
4514 current_base = match func_ctx.expressions[current_base] {
4515 crate::Expression::Access { base, .. } => base,
4516 crate::Expression::AccessIndex { base, .. } => base,
4517 crate::Expression::GlobalVariable(handle)
4518 if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
4519 {
4520 return mat_data.or_else(|| {
4521 array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
4522 })
4523 }
4524 _ => break,
4525 };
4526 }
4527 None
4528}