naga/back/hlsl/
writer.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec::Vec,
5};
6use core::{fmt, mem};
7
8use super::{
9    help,
10    help::{
11        WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
12        WrappedZeroValue,
13    },
14    storage::StoreValue,
15    BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel,
16};
17use crate::{
18    back::{self, get_entry_points, Baked},
19    common,
20    proc::{self, index, NameKey},
21    valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
22};
23
/// Prefix of the HLSL semantic written for user-defined locations; the
/// location number is appended to it by `write_semantic`.
const LOCATION_SEMANTIC: &str = "LOC";
/// Name of the struct type declared for the special constants buffer.
const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
/// Name of the `ConstantBuffer` variable declared with that struct type.
const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
/// Name of the first `int` member of the special constants struct.
const SPECIAL_FIRST_VERTEX: &str = "first_vertex";
/// Name of the second `int` member of the special constants struct.
const SPECIAL_FIRST_INSTANCE: &str = "first_instance";
/// Name of the trailing `uint` member of the special constants struct.
const SPECIAL_OTHER: &str = "other";

// Identifiers shared with the rest of the crate.
// NOTE(review): judging by their values these are the names of generated HLSL
// helper functions and sampler-heap variables; none of them is used in this
// part of the file, so confirm at their use sites.
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits";
pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits";
pub(crate) const SAMPLER_HEAP_VAR: &str = "nagaSamplerHeap";
pub(crate) const COMPARISON_SAMPLER_HEAP_VAR: &str = "nagaComparisonSamplerHeap";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
47
/// One member of a flattened entry point interface struct.
struct EpStructMember {
    /// Name of the member inside the generated struct.
    name: String,
    /// Type of the member.
    ty: Handle<crate::Type>,
    // technically, this should always be `Some`
    // (we `debug_assert!` this in `write_interface_struct`)
    binding: Option<crate::Binding>,
    /// Position of the member before sorting by binding. For inputs it is
    /// used by `write_interface_struct` to restore the original order.
    index: u32,
}
56
/// Information required to generate (and later refer to) the wrapper
/// struct that holds all flattened arguments or results of an entry point.
struct EntryPointBinding {
    /// Name of the fake EP argument that contains the struct
    /// with all the flattened input data.
    arg_name: String,
    /// Generated structure name
    ty_name: String,
    /// Members of generated structure
    members: Vec<EpStructMember>,
}
68
/// The generated input and output wrapper structs of one entry point.
pub(super) struct EntryPointInterface {
    /// If `Some`, the input of an entry point is gathered in a special
    /// struct whose members are written out sorted by binding.
    /// The `EntryPointBinding::members` array itself is sorted by index,
    /// so that we can walk it in `write_ep_arguments_initialization`.
    input: Option<EntryPointBinding>,
    /// If `Some`, the output of an entry point is flattened.
    /// The `EntryPointBinding::members` array is sorted by binding,
    /// so that we can walk it in the `Statement::Return` handler.
    output: Option<EntryPointBinding>,
}
80
/// Sort key for the members of an interface struct.
///
/// The derived `Ord` compares variants in declaration order first, so
/// `Location` keys (ordered by location number) sort before `BuiltIn` keys,
/// which sort before `Other`. Do not reorder the variants.
#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
enum InterfaceKey {
    /// A user-defined location.
    Location(u32),
    /// A built-in binding.
    BuiltIn(crate::BuiltIn),
    /// No binding at all.
    Other,
}
87
88impl InterfaceKey {
89    const fn new(binding: Option<&crate::Binding>) -> Self {
90        match binding {
91            Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
92            Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
93            None => Self::Other,
94        }
95    }
96}
97
/// Direction of an entry point interface: its arguments or its result.
#[derive(Copy, Clone, PartialEq)]
enum Io {
    /// The interface describes entry point inputs.
    Input,
    /// The interface describes entry point outputs.
    Output,
}
103
104const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
105    let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
106        return false;
107    };
108    matches!(
109        builtin,
110        crate::BuiltIn::SubgroupSize
111            | crate::BuiltIn::SubgroupInvocationId
112            | crate::BuiltIn::NumSubgroups
113            | crate::BuiltIn::SubgroupId
114    )
115}
116
/// Information for how to generate a `binding_array<sampler>` access.
///
/// All three fields are names of HLSL variables.
struct BindingArraySamplerInfo {
    /// Variable name of the sampler heap
    sampler_heap_name: &'static str,
    /// Variable name of the sampler index buffer
    sampler_index_buffer_name: String,
    /// Variable name of the base index _into_ the sampler index buffer
    binding_array_base_index_name: String,
}
126
127impl<'a, W: fmt::Write> super::Writer<'a, W> {
128    pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
129        Self {
130            out,
131            names: crate::FastHashMap::default(),
132            namer: proc::Namer::default(),
133            options,
134            pipeline_options,
135            entry_point_io: crate::FastHashMap::default(),
136            named_expressions: crate::NamedExpressions::default(),
137            wrapped: super::Wrapped::default(),
138            written_committed_intersection: false,
139            written_candidate_intersection: false,
140            continue_ctx: back::continue_forward::ContinueCtx::default(),
141            temp_access_chain: Vec::new(),
142            need_bake_expressions: Default::default(),
143        }
144    }
145
146    fn reset(&mut self, module: &Module) {
147        self.names.clear();
148        self.namer.reset(
149            module,
150            &super::keywords::RESERVED_SET,
151            super::keywords::RESERVED_CASE_INSENSITIVE,
152            super::keywords::RESERVED_PREFIXES,
153            &mut self.names,
154        );
155        self.entry_point_io.clear();
156        self.named_expressions.clear();
157        self.wrapped.clear();
158        self.written_committed_intersection = false;
159        self.written_candidate_intersection = false;
160        self.continue_ctx.clear();
161        self.need_bake_expressions.clear();
162    }
163
164    /// Generates statements to be inserted immediately before and at the very
165    /// start of the body of each loop, to defeat infinite loop reasoning.
166    /// The 0th item of the returned tuple should be inserted immediately prior
167    /// to the loop and the 1st item should be inserted at the very start of
168    /// the loop body.
169    ///
170    /// See [`back::msl::Writer::gen_force_bounded_loop_statements`] for details.
171    fn gen_force_bounded_loop_statements(
172        &mut self,
173        level: back::Level,
174    ) -> Option<(String, String)> {
175        if !self.options.force_loop_bounding {
176            return None;
177        }
178
179        let loop_bound_name = self.namer.call("loop_bound");
180        let max = u32::MAX;
181        // Count down from u32::MAX rather than up from 0 to avoid hang on
182        // certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
183        let decl = format!("{level}uint2 {loop_bound_name} = uint2({max}u, {max}u);");
184        let level = level.next();
185        let break_and_inc = format!(
186            "{level}if (all({loop_bound_name} == uint2(0u, 0u))) {{ break; }}
187{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
188        );
189
190        Some((decl, break_and_inc))
191    }
192
193    /// Helper method used to find which expressions of a given function require baking
194    ///
195    /// # Notes
196    /// Clears `need_bake_expressions` set before adding to it
197    fn update_expressions_to_bake(
198        &mut self,
199        module: &Module,
200        func: &crate::Function,
201        info: &valid::FunctionInfo,
202    ) {
203        use crate::Expression;
204        self.need_bake_expressions.clear();
205        for (exp_handle, expr) in func.expressions.iter() {
206            let expr_info = &info[exp_handle];
207            let min_ref_count = func.expressions[exp_handle].bake_ref_count();
208            if min_ref_count <= expr_info.ref_count {
209                self.need_bake_expressions.insert(exp_handle);
210            }
211
212            if let Expression::Math { fun, arg, arg1, .. } = *expr {
213                match fun {
214                    crate::MathFunction::Asinh
215                    | crate::MathFunction::Acosh
216                    | crate::MathFunction::Atanh
217                    | crate::MathFunction::Unpack2x16float
218                    | crate::MathFunction::Unpack2x16snorm
219                    | crate::MathFunction::Unpack2x16unorm
220                    | crate::MathFunction::Unpack4x8snorm
221                    | crate::MathFunction::Unpack4x8unorm
222                    | crate::MathFunction::Unpack4xI8
223                    | crate::MathFunction::Unpack4xU8
224                    | crate::MathFunction::Pack2x16float
225                    | crate::MathFunction::Pack2x16snorm
226                    | crate::MathFunction::Pack2x16unorm
227                    | crate::MathFunction::Pack4x8snorm
228                    | crate::MathFunction::Pack4x8unorm
229                    | crate::MathFunction::Pack4xI8
230                    | crate::MathFunction::Pack4xU8
231                    | crate::MathFunction::Pack4xI8Clamp
232                    | crate::MathFunction::Pack4xU8Clamp => {
233                        self.need_bake_expressions.insert(arg);
234                    }
235                    crate::MathFunction::CountLeadingZeros => {
236                        let inner = info[exp_handle].ty.inner_with(&module.types);
237                        if let Some(ScalarKind::Sint) = inner.scalar_kind() {
238                            self.need_bake_expressions.insert(arg);
239                        }
240                    }
241                    crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
242                        self.need_bake_expressions.insert(arg);
243                        self.need_bake_expressions.insert(arg1.unwrap());
244                    }
245                    _ => {}
246                }
247            }
248
249            if let Expression::Derivative { axis, ctrl, expr } = *expr {
250                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
251                if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
252                    self.need_bake_expressions.insert(expr);
253                }
254            }
255
256            if let Expression::GlobalVariable(_) = *expr {
257                let inner = info[exp_handle].ty.inner_with(&module.types);
258
259                if let TypeInner::Sampler { .. } = *inner {
260                    self.need_bake_expressions.insert(exp_handle);
261                }
262            }
263        }
264        for statement in func.body.iter() {
265            match *statement {
266                crate::Statement::SubgroupCollectiveOperation {
267                    op: _,
268                    collective_op: crate::CollectiveOperation::InclusiveScan,
269                    argument,
270                    result: _,
271                } => {
272                    self.need_bake_expressions.insert(argument);
273                }
274                crate::Statement::Atomic {
275                    fun: crate::AtomicFunction::Exchange { compare: Some(cmp) },
276                    ..
277                } => {
278                    self.need_bake_expressions.insert(cmp);
279                }
280                _ => {}
281            }
282        }
283    }
284
    /// Writes `module` as HLSL into the output.
    ///
    /// The output is emitted in this order: the special constants buffer (if
    /// configured), the dynamic storage buffer offset buffers, the `matCx2`
    /// typedefs and functions, all structs, the special functions, the wrapped
    /// helpers for global expressions, named constants, global variables, the
    /// entry point interface structs, regular functions and finally the entry
    /// points themselves.
    ///
    /// `fragment_entry_point` is only forwarded to `write_ep_interface`.
    ///
    /// On success the returned [`super::ReflectionInfo`] holds one item per
    /// selected entry point: its translated name, or the binding resolution
    /// error that caused it to be skipped.
    ///
    /// # Errors
    /// Returns `Error::EntryPointNotFound` if the entry point requested in the
    /// pipeline options cannot be found, and propagates any error raised while
    /// writing.
    ///
    /// # Panics
    /// Panics if the module contains a struct type without members
    /// (`members.last().unwrap()`).
    pub fn write(
        &mut self,
        module: &Module,
        module_info: &valid::ModuleInfo,
        fragment_entry_point: Option<&FragmentEntryPoint<'_>>,
    ) -> Result<super::ReflectionInfo, Error> {
        self.reset(module);

        // Write special constants, if needed
        if let Some(ref bt) = self.options.special_constants_binding {
            writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
            writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
            writeln!(self.out, "}};")?;
            write!(
                self.out,
                "ConstantBuffer<{}> {}: register(b{}",
                SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
            )?;
            // The register space is only spelled out when it is not 0.
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            writeln!(self.out, ");")?;

            // Extra newline for readability
            writeln!(self.out)?;
        }

        // One constant buffer of `uint` offsets per configured bind group.
        for (group, bt) in self.options.dynamic_storage_buffer_offsets_targets.iter() {
            writeln!(self.out, "struct __dynamic_buffer_offsetsTy{} {{", group)?;
            for i in 0..bt.size {
                writeln!(self.out, "{}uint _{};", back::INDENT, i)?;
            }
            writeln!(self.out, "}};")?;
            writeln!(
                self.out,
                "ConstantBuffer<__dynamic_buffer_offsetsTy{}> __dynamic_buffer_offsets{}: register(b{}, space{});",
                group, group, bt.register, bt.space
            )?;

            // Extra newline for readability
            writeln!(self.out)?;
        }

        // Save all entry point output types
        let ep_results = module
            .entry_points
            .iter()
            .map(|ep| (ep.stage, ep.function.result.clone()))
            .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();

        self.write_all_mat_cx2_typedefs_and_functions(module)?;

        // Write all structs
        for (handle, ty) in module.types.iter() {
            if let TypeInner::Struct { ref members, span } = ty.inner {
                if module.types[members.last().unwrap().ty]
                    .inner
                    .is_dynamically_sized(&module.types)
                {
                    // unsized arrays can only be in storage buffers,
                    // for which we use `ByteAddressBuffer` anyway.
                    continue;
                }

                // If this struct is the result type of some entry point, the
                // stage of the first such entry point is passed along.
                let ep_result = ep_results.iter().find(|e| {
                    if let Some(ref result) = e.1 {
                        result.ty == handle
                    } else {
                        false
                    }
                });

                self.write_struct(
                    module,
                    handle,
                    members,
                    span,
                    ep_result.map(|r| (r.0, Io::Output)),
                )?;
                writeln!(self.out)?;
            }
        }

        self.write_special_functions(module)?;

        self.write_wrapped_expression_functions(module, &module.global_expressions, None)?;
        self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;

        // Write all named constants
        let mut constants = module
            .constants
            .iter()
            .filter(|&(_, c)| c.name.is_some())
            .peekable();
        while let Some((handle, _)) = constants.next() {
            self.write_global_constant(module, handle)?;
            // Add extra newline for readability on last iteration
            if constants.peek().is_none() {
                writeln!(self.out)?;
            }
        }

        // Write all globals
        for (ty, _) in module.global_variables.iter() {
            self.write_global(module, ty)?;
        }

        if !module.global_variables.is_empty() {
            // Add extra newline for readability
            writeln!(self.out)?;
        }

        let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;

        // Write all entry points wrapped structs
        for index in ep_range.clone() {
            let ep = &module.entry_points[index];
            let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            let ep_io = self.write_ep_interface(
                module,
                &ep.function,
                ep.stage,
                &ep_name,
                fragment_entry_point,
            )?;
            self.entry_point_io.insert(index, ep_io);
        }

        // Write all regular functions
        for (handle, function) in module.functions.iter() {
            let info = &module_info[handle];

            // Check if all of the globals are accessible
            if !self.options.fake_missing_bindings {
                // Look for a used global whose resource binding cannot be
                // resolved; such a function is skipped (and logged).
                if let Some((var_handle, _)) =
                    module
                        .global_variables
                        .iter()
                        .find(|&(var_handle, var)| match var.binding {
                            Some(ref binding) if !info[var_handle].is_empty() => {
                                self.options.resolve_resource_binding(binding).is_err()
                            }
                            _ => false,
                        })
                {
                    log::info!(
                        "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
                        handle,
                        function.name,
                        var_handle
                    );
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::Function(handle),
                info,
                expressions: &function.expressions,
                named_expressions: &function.named_expressions,
            };
            let name = self.names[&NameKey::Function(handle)].clone();

            self.write_wrapped_functions(module, &ctx)?;

            self.write_function(module, name.as_str(), function, &ctx, info)?;

            writeln!(self.out)?;
        }

        let mut translated_ep_names = Vec::with_capacity(ep_range.len());

        // Write all entry points
        for index in ep_range {
            let ep = &module.entry_points[index];
            let info = module_info.get_entry_point(index);

            if !self.options.fake_missing_bindings {
                // Unlike regular functions, an entry point with an
                // unresolvable binding is reported through the result list
                // instead of being dropped silently.
                let mut ep_error = None;
                for (var_handle, var) in module.global_variables.iter() {
                    match var.binding {
                        Some(ref binding) if !info[var_handle].is_empty() => {
                            if let Err(err) = self.options.resolve_resource_binding(binding) {
                                ep_error = Some(err);
                                break;
                            }
                        }
                        _ => {}
                    }
                }
                if let Some(err) = ep_error {
                    translated_ep_names.push(Err(err));
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::EntryPoint(index as u16),
                info,
                expressions: &ep.function.expressions,
                named_expressions: &ep.function.named_expressions,
            };

            self.write_wrapped_functions(module, &ctx)?;

            if ep.stage == ShaderStage::Compute {
                // HLSL is calling workgroup size "num threads"
                let num_threads = ep.workgroup_size;
                writeln!(
                    self.out,
                    "[numthreads({}, {}, {})]",
                    num_threads[0], num_threads[1], num_threads[2]
                )?;
            }

            let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            self.write_function(module, &name, &ep.function, &ctx, info)?;

            // Separate entry points with a blank line, except after the
            // module's last one.
            if index < module.entry_points.len() - 1 {
                writeln!(self.out)?;
            }

            translated_ep_names.push(Ok(name));
        }

        Ok(super::ReflectionInfo {
            entry_point_names: translated_ep_names,
        })
    }
517
518    fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
519        match *binding {
520            crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
521                write!(self.out, "precise ")?;
522            }
523            crate::Binding::Location {
524                interpolation,
525                sampling,
526                ..
527            } => {
528                if let Some(interpolation) = interpolation {
529                    if let Some(string) = interpolation.to_hlsl_str() {
530                        write!(self.out, "{string} ")?
531                    }
532                }
533
534                if let Some(sampling) = sampling {
535                    if let Some(string) = sampling.to_hlsl_str() {
536                        write!(self.out, "{string} ")?
537                    }
538                }
539            }
540            crate::Binding::BuiltIn(_) => {}
541        }
542
543        Ok(())
544    }
545
546    //TODO: we could force fragment outputs to always go through `entry_point_io.output` path
547    // if they are struct, so that the `stage` argument here could be omitted.
548    fn write_semantic(
549        &mut self,
550        binding: &Option<crate::Binding>,
551        stage: Option<(ShaderStage, Io)>,
552    ) -> BackendResult {
553        match *binding {
554            Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
555                let builtin_str = builtin.to_hlsl_str()?;
556                write!(self.out, " : {builtin_str}")?;
557            }
558            Some(crate::Binding::Location {
559                blend_src: Some(1), ..
560            }) => {
561                write!(self.out, " : SV_Target1")?;
562            }
563            Some(crate::Binding::Location { location, .. }) => {
564                if stage == Some((ShaderStage::Fragment, Io::Output)) {
565                    write!(self.out, " : SV_Target{location}")?;
566                } else {
567                    write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
568                }
569            }
570            _ => {}
571        }
572
573        Ok(())
574    }
575
576    fn write_interface_struct(
577        &mut self,
578        module: &Module,
579        shader_stage: (ShaderStage, Io),
580        struct_name: String,
581        mut members: Vec<EpStructMember>,
582    ) -> Result<EntryPointBinding, Error> {
583        // Sort the members so that first come the user-defined varyings
584        // in ascending locations, and then built-ins. This allows VS and FS
585        // interfaces to match with regards to order.
586        members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));
587
588        write!(self.out, "struct {struct_name}")?;
589        writeln!(self.out, " {{")?;
590        for m in members.iter() {
591            // Sanity check that each IO member is a built-in or is assigned a
592            // location. Also see note about nesting in `write_ep_input_struct`.
593            debug_assert!(m.binding.is_some());
594
595            if is_subgroup_builtin_binding(&m.binding) {
596                continue;
597            }
598            write!(self.out, "{}", back::INDENT)?;
599            if let Some(ref binding) = m.binding {
600                self.write_modifier(binding)?;
601            }
602            self.write_type(module, m.ty)?;
603            write!(self.out, " {}", &m.name)?;
604            self.write_semantic(&m.binding, Some(shader_stage))?;
605            writeln!(self.out, ";")?;
606        }
607        if members.iter().any(|arg| {
608            matches!(
609                arg.binding,
610                Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId))
611            )
612        }) {
613            writeln!(
614                self.out,
615                "{}uint __local_invocation_index : SV_GroupIndex;",
616                back::INDENT
617            )?;
618        }
619        writeln!(self.out, "}};")?;
620        writeln!(self.out)?;
621
622        // See ordering notes on EntryPointInterface fields
623        match shader_stage.1 {
624            Io::Input => {
625                // bring back the original order
626                members.sort_by_key(|m| m.index);
627            }
628            Io::Output => {
629                // keep it sorted by binding
630            }
631        }
632
633        Ok(EntryPointBinding {
634            arg_name: self.namer.call(struct_name.to_lowercase().as_str()),
635            ty_name: struct_name,
636            members,
637        })
638    }
639
640    /// Flatten all entry point arguments into a single struct.
641    /// This is needed since we need to re-order them: first placing user locations,
642    /// then built-ins.
643    fn write_ep_input_struct(
644        &mut self,
645        module: &Module,
646        func: &crate::Function,
647        stage: ShaderStage,
648        entry_point_name: &str,
649    ) -> Result<EntryPointBinding, Error> {
650        let struct_name = format!("{stage:?}Input_{entry_point_name}");
651
652        let mut fake_members = Vec::new();
653        for arg in func.arguments.iter() {
654            // NOTE: We don't need to handle nesting structs. All members must
655            // be either built-ins or assigned a location. I.E. `binding` is
656            // `Some`. This is checked in `VaryingContext::validate`. See:
657            // https://gpuweb.github.io/gpuweb/wgsl/#input-output-locations
658            match module.types[arg.ty].inner {
659                TypeInner::Struct { ref members, .. } => {
660                    for member in members.iter() {
661                        let name = self.namer.call_or(&member.name, "member");
662                        let index = fake_members.len() as u32;
663                        fake_members.push(EpStructMember {
664                            name,
665                            ty: member.ty,
666                            binding: member.binding.clone(),
667                            index,
668                        });
669                    }
670                }
671                _ => {
672                    let member_name = self.namer.call_or(&arg.name, "member");
673                    let index = fake_members.len() as u32;
674                    fake_members.push(EpStructMember {
675                        name: member_name,
676                        ty: arg.ty,
677                        binding: arg.binding.clone(),
678                        index,
679                    });
680                }
681            }
682        }
683
684        self.write_interface_struct(module, (stage, Io::Input), struct_name, fake_members)
685    }
686
687    /// Flatten all entry point results into a single struct.
688    /// This is needed since we need to re-order them: first placing user locations,
689    /// then built-ins.
690    fn write_ep_output_struct(
691        &mut self,
692        module: &Module,
693        result: &crate::FunctionResult,
694        stage: ShaderStage,
695        entry_point_name: &str,
696        frag_ep: Option<&FragmentEntryPoint<'_>>,
697    ) -> Result<EntryPointBinding, Error> {
698        let struct_name = format!("{stage:?}Output_{entry_point_name}");
699
700        let empty = [];
701        let members = match module.types[result.ty].inner {
702            TypeInner::Struct { ref members, .. } => members,
703            ref other => {
704                log::error!("Unexpected {:?} output type without a binding", other);
705                &empty[..]
706            }
707        };
708
709        // Gather list of fragment input locations. We use this below to remove user-defined
710        // varyings from VS outputs that aren't in the FS inputs. This makes the VS interface match
711        // as long as the FS inputs are a subset of the VS outputs. This is only applied if the
712        // writer is supplied with information about the fragment entry point.
713        let fs_input_locs = if let (Some(frag_ep), ShaderStage::Vertex) = (frag_ep, stage) {
714            let mut fs_input_locs = Vec::new();
715            for arg in frag_ep.func.arguments.iter() {
716                let mut push_if_location = |binding: &Option<crate::Binding>| match *binding {
717                    Some(crate::Binding::Location { location, .. }) => fs_input_locs.push(location),
718                    Some(crate::Binding::BuiltIn(_)) | None => {}
719                };
720
721                // NOTE: We don't need to handle struct nesting. See note in
722                // `write_ep_input_struct`.
723                match frag_ep.module.types[arg.ty].inner {
724                    TypeInner::Struct { ref members, .. } => {
725                        for member in members.iter() {
726                            push_if_location(&member.binding);
727                        }
728                    }
729                    _ => push_if_location(&arg.binding),
730                }
731            }
732            fs_input_locs.sort();
733            Some(fs_input_locs)
734        } else {
735            None
736        };
737
738        let mut fake_members = Vec::new();
739        for (index, member) in members.iter().enumerate() {
740            if let Some(ref fs_input_locs) = fs_input_locs {
741                match member.binding {
742                    Some(crate::Binding::Location { location, .. }) => {
743                        if fs_input_locs.binary_search(&location).is_err() {
744                            continue;
745                        }
746                    }
747                    Some(crate::Binding::BuiltIn(_)) | None => {}
748                }
749            }
750
751            let member_name = self.namer.call_or(&member.name, "member");
752            fake_members.push(EpStructMember {
753                name: member_name,
754                ty: member.ty,
755                binding: member.binding.clone(),
756                index: index as u32,
757            });
758        }
759
760        self.write_interface_struct(module, (stage, Io::Output), struct_name, fake_members)
761    }
762
763    /// Writes special interface structures for an entry point. The special structures have
764    /// all the fields flattened into them and sorted by binding. They are needed to emulate
765    /// subgroup built-ins and to make the interfaces between VS outputs and FS inputs match.
766    fn write_ep_interface(
767        &mut self,
768        module: &Module,
769        func: &crate::Function,
770        stage: ShaderStage,
771        ep_name: &str,
772        frag_ep: Option<&FragmentEntryPoint<'_>>,
773    ) -> Result<EntryPointInterface, Error> {
774        Ok(EntryPointInterface {
775            input: if !func.arguments.is_empty()
776                && (stage == ShaderStage::Fragment
777                    || func
778                        .arguments
779                        .iter()
780                        .any(|arg| is_subgroup_builtin_binding(&arg.binding)))
781            {
782                Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
783            } else {
784                None
785            },
786            output: match func.result {
787                Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
788                    Some(self.write_ep_output_struct(module, fr, stage, ep_name, frag_ep)?)
789                }
790                _ => None,
791            },
792        })
793    }
794
795    fn write_ep_argument_initialization(
796        &mut self,
797        ep: &crate::EntryPoint,
798        ep_input: &EntryPointBinding,
799        fake_member: &EpStructMember,
800    ) -> BackendResult {
801        match fake_member.binding {
802            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
803                write!(self.out, "WaveGetLaneCount()")?
804            }
805            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
806                write!(self.out, "WaveGetLaneIndex()")?
807            }
808            Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
809                self.out,
810                "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
811                ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
812            )?,
813            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
814                write!(
815                    self.out,
816                    "{}.__local_invocation_index / WaveGetLaneCount()",
817                    ep_input.arg_name
818                )?;
819            }
820            _ => {
821                write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
822            }
823        }
824        Ok(())
825    }
826
827    /// Write an entry point preface that initializes the arguments as specified in IR.
828    fn write_ep_arguments_initialization(
829        &mut self,
830        module: &Module,
831        func: &crate::Function,
832        ep_index: u16,
833    ) -> BackendResult {
834        let ep = &module.entry_points[ep_index as usize];
835        let ep_input = match self
836            .entry_point_io
837            .get_mut(&(ep_index as usize))
838            .unwrap()
839            .input
840            .take()
841        {
842            Some(ep_input) => ep_input,
843            None => return Ok(()),
844        };
845        let mut fake_iter = ep_input.members.iter();
846        for (arg_index, arg) in func.arguments.iter().enumerate() {
847            write!(self.out, "{}", back::INDENT)?;
848            self.write_type(module, arg.ty)?;
849            let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
850            write!(self.out, " {arg_name}")?;
851            match module.types[arg.ty].inner {
852                TypeInner::Array { base, size, .. } => {
853                    self.write_array_size(module, base, size)?;
854                    write!(self.out, " = ")?;
855                    self.write_ep_argument_initialization(
856                        ep,
857                        &ep_input,
858                        fake_iter.next().unwrap(),
859                    )?;
860                    writeln!(self.out, ";")?;
861                }
862                TypeInner::Struct { ref members, .. } => {
863                    write!(self.out, " = {{ ")?;
864                    for index in 0..members.len() {
865                        if index != 0 {
866                            write!(self.out, ", ")?;
867                        }
868                        self.write_ep_argument_initialization(
869                            ep,
870                            &ep_input,
871                            fake_iter.next().unwrap(),
872                        )?;
873                    }
874                    writeln!(self.out, " }};")?;
875                }
876                _ => {
877                    write!(self.out, " = ")?;
878                    self.write_ep_argument_initialization(
879                        ep,
880                        &ep_input,
881                        fake_iter.next().unwrap(),
882                    )?;
883                    writeln!(self.out, ";")?;
884                }
885            }
886        }
887        assert!(fake_iter.next().is_none());
888        Ok(())
889    }
890
    /// Helper method used to write global variables
    ///
    /// Emits one declaration for `handle`: a storage-class or resource-type prefix
    /// chosen from the variable's address space, the variable name, any array
    /// dimensions, an optional `register(...)` binding and, for `Private` globals
    /// without a binding, an initializer.
    ///
    /// Globals whose resource binding cannot be resolved through `self.options`
    /// are skipped: an `info` log line is emitted, nothing is written, and `Ok(())`
    /// is returned. Samplers (and binding arrays of samplers) are delegated to
    /// `write_global_sampler`.
    ///
    /// # Errors
    /// Returns `Error::Unimplemented` for a push constant whose type is not a
    /// struct, and propagates any error from the output writer or the helper
    /// methods it calls.
    ///
    /// # Panics
    /// Panics if the global is in the `Function` address space, or if a push
    /// constant is encountered while `options.push_constants_target` is `None`.
    ///
    /// # Notes
    /// Every declaration written by this method itself ends in a newline
    fn write_global(
        &mut self,
        module: &Module,
        handle: Handle<crate::GlobalVariable>,
    ) -> BackendResult {
        let global = &module.global_variables[handle];
        let inner = &module.types[global.ty].inner;

        // Bail out early (without writing anything) if the binding cannot be
        // resolved with the current options.
        if let Some(ref binding) = global.binding {
            if let Err(err) = self.options.resolve_resource_binding(binding) {
                log::info!(
                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
                    handle,
                    global.name,
                    err,
                );
                return Ok(());
            }
        }

        // For binding arrays, classify the global by its element type.
        let handle_ty = match *inner {
            TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
            _ => inner,
        };

        // Samplers are handled entirely differently, so defer entirely to that method.
        let is_sampler = matches!(*handle_ty, TypeInner::Sampler { .. });

        if is_sampler {
            return self.write_global_sampler(module, handle, global);
        }

        // Write the declaration prefix and pick the register class letter used
        // later in `register(...)`. An empty string means no register class.
        // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-variable-register
        let register_ty = match global.space {
            crate::AddressSpace::Function => unreachable!("Function address space"),
            crate::AddressSpace::Private => {
                write!(self.out, "static ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::WorkGroup => {
                write!(self.out, "groupshared ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::Uniform => {
                // constant buffer declarations are expected to be inlined, e.g.
                // `cbuffer foo: register(b0) { field1: type1; }`
                write!(self.out, "cbuffer")?;
                "b"
            }
            crate::AddressSpace::Storage { access } => {
                // Writable storage buffers are UAVs (`u`), read-only ones are SRVs (`t`).
                let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
                    ("RW", "u")
                } else {
                    ("", "t")
                };
                write!(self.out, "{prefix}ByteAddressBuffer")?;
                register
            }
            crate::AddressSpace::Handle => {
                let register = match *handle_ty {
                    // all storage textures are UAV, unconditionally
                    TypeInner::Image {
                        class: crate::ImageClass::Storage { .. },
                        ..
                    } => "u",
                    _ => "t",
                };
                self.write_type(module, global.ty)?;
                register
            }
            crate::AddressSpace::PushConstant => {
                // The type of the push constants will be wrapped in `ConstantBuffer`
                write!(self.out, "ConstantBuffer<")?;
                "b"
            }
        };

        // If the global is a push constant write the type now because it will be a
        // generic argument to `ConstantBuffer`
        if global.space == crate::AddressSpace::PushConstant {
            self.write_global_type(module, global.ty)?;

            // need to write the array size if the type was emitted with `write_type`
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            // Close the angled brackets for the generic argument
            write!(self.out, ">")?;
        }

        let name = &self.names[&NameKey::GlobalVariable(handle)];
        write!(self.out, " {name}")?;

        // Push constants need to be assigned a binding explicitly by the consumer
        // since naga has no way to know the binding from the shader alone
        if global.space == crate::AddressSpace::PushConstant {
            // Note: by this point the `ConstantBuffer<...>` prefix and the name
            // have already been written to the output.
            match module.types[global.ty].inner {
                TypeInner::Struct { .. } => {}
                _ => {
                    return Err(Error::Unimplemented(format!(
                        "push-constant '{name}' has non-struct type; tracked by: https://github.com/gfx-rs/wgpu/issues/5683"
                    )));
                }
            }

            let target = self
                .options
                .push_constants_target
                .as_ref()
                .expect("No bind target was defined for the push constants block");
            write!(self.out, ": register(b{}", target.register)?;
            // Register space 0 is left implicit.
            if target.space != 0 {
                write!(self.out, ", space{}", target.space)?;
            }
            write!(self.out, ")")?;
        }

        if let Some(ref binding) = global.binding {
            // The same lookup already succeeded at the top of this function
            // (otherwise we returned early), so it is expected not to fail here.
            let bt = self.options.resolve_resource_binding(binding).unwrap();

            // need to write the binding array size if the type was emitted with `write_type`
            if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
                // A size supplied by the bind target takes precedence over the
                // size recorded in the type.
                if let Some(overridden_size) = bt.binding_array_size {
                    write!(self.out, "[{overridden_size}]")?;
                } else {
                    self.write_array_size(module, base, size)?;
                }
            }

            write!(self.out, " : register({}{}", register_ty, bt.register)?;
            // Register space 0 is left implicit.
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            write!(self.out, ")")?;
        } else {
            // need to write the array size if the type was emitted with `write_type`
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }
            // Private globals always get an initializer: the explicit one if
            // present, otherwise a default initializer for the type.
            if global.space == crate::AddressSpace::Private {
                write!(self.out, " = ")?;
                if let Some(init) = global.init {
                    self.write_const_expression(module, init, &module.global_expressions)?;
                } else {
                    self.write_default_init(module, global.ty)?;
                }
            }
        }

        if global.space == crate::AddressSpace::Uniform {
            // Inline the single field of the `cbuffer`: `{ <type> <name>[size]; }`.
            // The field reuses the global's own name.
            write!(self.out, " {{ ")?;

            self.write_global_type(module, global.ty)?;

            write!(
                self.out,
                " {}",
                &self.names[&NameKey::GlobalVariable(handle)]
            )?;

            // need to write the array size if the type was emitted with `write_type`
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            writeln!(self.out, "; }}")?;
        } else {
            writeln!(self.out, ";")?;
        }

        Ok(())
    }
1070
1071    fn write_global_sampler(
1072        &mut self,
1073        module: &Module,
1074        handle: Handle<crate::GlobalVariable>,
1075        global: &crate::GlobalVariable,
1076    ) -> BackendResult {
1077        let binding = *global.binding.as_ref().unwrap();
1078
1079        let key = super::SamplerIndexBufferKey {
1080            group: binding.group,
1081        };
1082        self.write_wrapped_sampler_buffer(key)?;
1083
1084        // This was already validated, so we can confidently unwrap it.
1085        let bt = self.options.resolve_resource_binding(&binding).unwrap();
1086
1087        match module.types[global.ty].inner {
1088            TypeInner::Sampler { comparison } => {
1089                // If we are generating a static access, we create a variable for the sampler.
1090                //
1091                // This prevents the DXIL from containing multiple lookups for the sampler, which
1092                // the backend compiler will then have to eliminate. AMD does seem to be able to
1093                // eliminate these, but better safe than sorry.
1094
1095                write!(self.out, "static const ")?;
1096                self.write_type(module, global.ty)?;
1097
1098                let heap_var = if comparison {
1099                    COMPARISON_SAMPLER_HEAP_VAR
1100                } else {
1101                    SAMPLER_HEAP_VAR
1102                };
1103
1104                let index_buffer_name = &self.wrapped.sampler_index_buffers[&key];
1105                let name = &self.names[&NameKey::GlobalVariable(handle)];
1106                writeln!(
1107                    self.out,
1108                    " {name} = {heap_var}[{index_buffer_name}[{register}]];",
1109                    register = bt.register
1110                )?;
1111            }
1112            TypeInner::BindingArray { .. } => {
1113                // If we are generating a binding array, we cannot directly access the sampler as the index
1114                // into the sampler index buffer is unknown at compile time. Instead we generate a constant
1115                // that represents the "base" index into the sampler index buffer. This constant is added
1116                // to the user provided index to get the final index into the sampler index buffer.
1117
1118                let name = &self.names[&NameKey::GlobalVariable(handle)];
1119                writeln!(
1120                    self.out,
1121                    "static const uint {name} = {register};",
1122                    register = bt.register
1123                )?;
1124            }
1125            _ => unreachable!(),
1126        };
1127
1128        Ok(())
1129    }
1130
1131    /// Helper method used to write global constants
1132    ///
1133    /// # Notes
1134    /// Ends in a newline
1135    fn write_global_constant(
1136        &mut self,
1137        module: &Module,
1138        handle: Handle<crate::Constant>,
1139    ) -> BackendResult {
1140        write!(self.out, "static const ")?;
1141        let constant = &module.constants[handle];
1142        self.write_type(module, constant.ty)?;
1143        let name = &self.names[&NameKey::Constant(handle)];
1144        write!(self.out, " {name}")?;
1145        // Write size for array type
1146        if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner {
1147            self.write_array_size(module, base, size)?;
1148        }
1149        write!(self.out, " = ")?;
1150        self.write_const_expression(module, constant.init, &module.global_expressions)?;
1151        writeln!(self.out, ";")?;
1152        Ok(())
1153    }
1154
1155    pub(super) fn write_array_size(
1156        &mut self,
1157        module: &Module,
1158        base: Handle<crate::Type>,
1159        size: crate::ArraySize,
1160    ) -> BackendResult {
1161        write!(self.out, "[")?;
1162
1163        match size.resolve(module.to_ctx())? {
1164            proc::IndexableLength::Known(size) => {
1165                write!(self.out, "{size}")?;
1166            }
1167            proc::IndexableLength::Dynamic => unreachable!(),
1168        }
1169
1170        write!(self.out, "]")?;
1171
1172        if let TypeInner::Array {
1173            base: next_base,
1174            size: next_size,
1175            ..
1176        } = module.types[base].inner
1177        {
1178            self.write_array_size(module, next_base, next_size)?;
1179        }
1180
1181        Ok(())
1182    }
1183
    /// Helper method used to write structs
    ///
    /// Writes `struct <name> { ... };` with one declaration per entry of `members`.
    /// For members without a binding, `int` padding fields are inserted wherever a
    /// member's offset lies past the end of the previous member (as measured with
    /// `size_hlsl`), and after the last member up to `span`.
    ///
    /// `shader_stage` is forwarded to `write_semantic` for every member.
    ///
    /// # Errors
    /// Propagates errors from `size_hlsl`, the type/semantic helpers and the
    /// output writer.
    ///
    /// # Panics
    /// Panics if `members` is empty (the trailing padding check unwraps
    /// `members.last()`).
    ///
    /// # Notes
    /// Ends in a newline
    fn write_struct(
        &mut self,
        module: &Module,
        handle: Handle<crate::Type>,
        members: &[crate::StructMember],
        span: u32,
        shader_stage: Option<(ShaderStage, Io)>,
    ) -> BackendResult {
        // Write struct name
        let struct_name = &self.names[&NameKey::Type(handle)];
        writeln!(self.out, "struct {struct_name} {{")?;

        // End offset (offset + HLSL size) of the most recently written member.
        let mut last_offset = 0;
        for (index, member) in members.iter().enumerate() {
            if member.binding.is_none() && member.offset > last_offset {
                // using int as padding should work as long as the backend
                // doesn't support a type that's less than 4 bytes in size
                // (Error::UnsupportedScalar catches this)
                let padding = (member.offset - last_offset) / 4;
                for i in 0..padding {
                    writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
                }
            }
            let ty_inner = &module.types[member.ty].inner;
            last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?;

            // The indentation is only for readability
            write!(self.out, "{}", back::INDENT)?;

            match module.types[member.ty].inner {
                TypeInner::Array { base, size, .. } => {
                    // HLSL arrays are written as `type name[size]`

                    self.write_global_type(module, member.ty)?;

                    // Write `name`
                    write!(
                        self.out,
                        " {}",
                        &self.names[&NameKey::StructMember(handle, index as u32)]
                    )?;
                    // Write [size]
                    self.write_array_size(module, base, size)?;
                }
                // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
                // See the module-level block comment in mod.rs for details.
                TypeInner::Matrix {
                    rows,
                    columns,
                    scalar,
                } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
                    let vec_ty = TypeInner::Vector { size: rows, scalar };
                    let field_name_key = NameKey::StructMember(handle, index as u32);

                    // One vector field per column, named `<member>_<column index>`.
                    // Fields are separated by "; "; the terminator of the last one
                    // comes from the shared `;` written after the match.
                    for i in 0..columns as u8 {
                        if i != 0 {
                            write!(self.out, "; ")?;
                        }
                        self.write_value_type(module, &vec_ty)?;
                        write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
                    }
                }
                _ => {
                    // Write modifier before type
                    if let Some(ref binding) = member.binding {
                        self.write_modifier(binding)?;
                    }

                    // Even though Naga IR matrices are column-major, we must describe
                    // matrices passed from the CPU as being in row-major order.
                    // See the module-level block comment in mod.rs for details.
                    if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
                        write!(self.out, "row_major ")?;
                    }

                    // Write the member type and name
                    self.write_type(module, member.ty)?;
                    write!(
                        self.out,
                        " {}",
                        &self.names[&NameKey::StructMember(handle, index as u32)]
                    )?;
                }
            }

            // The semantic (if any) follows the declaration, then the terminator.
            self.write_semantic(&member.binding, shader_stage)?;
            writeln!(self.out, ";")?;
        }

        // add padding at the end since sizes of types don't get rounded up to their alignment in HLSL
        if members.last().unwrap().binding.is_none() && span > last_offset {
            let padding = (span - last_offset) / 4;
            for i in 0..padding {
                writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
            }
        }

        writeln!(self.out, "}};")?;
        Ok(())
    }
1288
1289    /// Helper method used to write global/structs non image/sampler types
1290    ///
1291    /// # Notes
1292    /// Adds no trailing or leading whitespace
1293    pub(super) fn write_global_type(
1294        &mut self,
1295        module: &Module,
1296        ty: Handle<crate::Type>,
1297    ) -> BackendResult {
1298        let matrix_data = get_inner_matrix_data(module, ty);
1299
1300        // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
1301        // See the module-level block comment in mod.rs for details.
1302        if let Some(MatrixType {
1303            columns,
1304            rows: crate::VectorSize::Bi,
1305            width: 4,
1306        }) = matrix_data
1307        {
1308            write!(self.out, "__mat{}x2", columns as u8)?;
1309        } else {
1310            // Even though Naga IR matrices are column-major, we must describe
1311            // matrices passed from the CPU as being in row-major order.
1312            // See the module-level block comment in mod.rs for details.
1313            if matrix_data.is_some() {
1314                write!(self.out, "row_major ")?;
1315            }
1316
1317            self.write_type(module, ty)?;
1318        }
1319
1320        Ok(())
1321    }
1322
1323    /// Helper method used to write non image/sampler types
1324    ///
1325    /// # Notes
1326    /// Adds no trailing or leading whitespace
1327    pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
1328        let inner = &module.types[ty].inner;
1329        match *inner {
1330            TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
1331            // hlsl array has the size separated from the base type
1332            TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
1333                self.write_type(module, base)?
1334            }
1335            ref other => self.write_value_type(module, other)?,
1336        }
1337
1338        Ok(())
1339    }
1340
1341    /// Helper method used to write value types
1342    ///
1343    /// # Notes
1344    /// Adds no trailing or leading whitespace
1345    pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
1346        match *inner {
1347            TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
1348                write!(self.out, "{}", scalar.to_hlsl_str()?)?;
1349            }
1350            TypeInner::Vector { size, scalar } => {
1351                write!(
1352                    self.out,
1353                    "{}{}",
1354                    scalar.to_hlsl_str()?,
1355                    common::vector_size_str(size)
1356                )?;
1357            }
1358            TypeInner::Matrix {
1359                columns,
1360                rows,
1361                scalar,
1362            } => {
1363                // The IR supports only float matrix
1364                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-matrix
1365
1366                // Because of the implicit transpose all matrices have in HLSL, we need to transpose the size as well.
1367                write!(
1368                    self.out,
1369                    "{}{}x{}",
1370                    scalar.to_hlsl_str()?,
1371                    common::vector_size_str(columns),
1372                    common::vector_size_str(rows),
1373                )?;
1374            }
1375            TypeInner::Image {
1376                dim,
1377                arrayed,
1378                class,
1379            } => {
1380                self.write_image_type(dim, arrayed, class)?;
1381            }
1382            TypeInner::Sampler { comparison } => {
1383                let sampler = if comparison {
1384                    "SamplerComparisonState"
1385                } else {
1386                    "SamplerState"
1387                };
1388                write!(self.out, "{sampler}")?;
1389            }
1390            // HLSL arrays are written as `type name[size]`
1391            // Current code is written arrays only as `[size]`
1392            // Base `type` and `name` should be written outside
1393            TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
1394                self.write_array_size(module, base, size)?;
1395            }
1396            TypeInner::AccelerationStructure { .. } => {
1397                write!(self.out, "RaytracingAccelerationStructure")?;
1398            }
1399            TypeInner::RayQuery { .. } => {
1400                // these are constant flags, there are dynamic flags also but constant flags are not supported by naga
1401                write!(self.out, "RayQuery<RAY_FLAG_NONE>")?;
1402            }
1403            _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
1404        }
1405
1406        Ok(())
1407    }
1408
1409    /// Helper method used to write functions
1410    /// # Notes
1411    /// Ends in a newline
1412    fn write_function(
1413        &mut self,
1414        module: &Module,
1415        name: &str,
1416        func: &crate::Function,
1417        func_ctx: &back::FunctionCtx<'_>,
1418        info: &valid::FunctionInfo,
1419    ) -> BackendResult {
1420        // Function Declaration Syntax - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-function-syntax
1421
1422        self.update_expressions_to_bake(module, func, info);
1423
1424        if let Some(ref result) = func.result {
1425            // Write typedef if return type is an array
1426            let array_return_type = match module.types[result.ty].inner {
1427                TypeInner::Array { base, size, .. } => {
1428                    let array_return_type = self.namer.call(&format!("ret_{name}"));
1429                    write!(self.out, "typedef ")?;
1430                    self.write_type(module, result.ty)?;
1431                    write!(self.out, " {}", array_return_type)?;
1432                    self.write_array_size(module, base, size)?;
1433                    writeln!(self.out, ";")?;
1434                    Some(array_return_type)
1435                }
1436                _ => None,
1437            };
1438
1439            // Write modifier
1440            if let Some(
1441                ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }),
1442            ) = result.binding
1443            {
1444                self.write_modifier(binding)?;
1445            }
1446
1447            // Write return type
1448            match func_ctx.ty {
1449                back::FunctionType::Function(_) => {
1450                    if let Some(array_return_type) = array_return_type {
1451                        write!(self.out, "{array_return_type}")?;
1452                    } else {
1453                        self.write_type(module, result.ty)?;
1454                    }
1455                }
1456                back::FunctionType::EntryPoint(index) => {
1457                    if let Some(ref ep_output) =
1458                        self.entry_point_io.get(&(index as usize)).unwrap().output
1459                    {
1460                        write!(self.out, "{}", ep_output.ty_name)?;
1461                    } else {
1462                        self.write_type(module, result.ty)?;
1463                    }
1464                }
1465            }
1466        } else {
1467            write!(self.out, "void")?;
1468        }
1469
1470        // Write function name
1471        write!(self.out, " {name}(")?;
1472
1473        let need_workgroup_variables_initialization =
1474            self.need_workgroup_variables_initialization(func_ctx, module);
1475
1476        // Write function arguments for non entry point functions
1477        match func_ctx.ty {
1478            back::FunctionType::Function(handle) => {
1479                for (index, arg) in func.arguments.iter().enumerate() {
1480                    if index != 0 {
1481                        write!(self.out, ", ")?;
1482                    }
1483                    // Write argument type
1484                    let arg_ty = match module.types[arg.ty].inner {
1485                        // pointers in function arguments are expected and resolve to `inout`
1486                        TypeInner::Pointer { base, .. } => {
1487                            //TODO: can we narrow this down to just `in` when possible?
1488                            write!(self.out, "inout ")?;
1489                            base
1490                        }
1491                        _ => arg.ty,
1492                    };
1493                    self.write_type(module, arg_ty)?;
1494
1495                    let argument_name =
1496                        &self.names[&NameKey::FunctionArgument(handle, index as u32)];
1497
1498                    // Write argument name. Space is important.
1499                    write!(self.out, " {argument_name}")?;
1500                    if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
1501                        self.write_array_size(module, base, size)?;
1502                    }
1503                }
1504            }
1505            back::FunctionType::EntryPoint(ep_index) => {
1506                if let Some(ref ep_input) =
1507                    self.entry_point_io.get(&(ep_index as usize)).unwrap().input
1508                {
1509                    write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
1510                } else {
1511                    let stage = module.entry_points[ep_index as usize].stage;
1512                    for (index, arg) in func.arguments.iter().enumerate() {
1513                        if index != 0 {
1514                            write!(self.out, ", ")?;
1515                        }
1516                        self.write_type(module, arg.ty)?;
1517
1518                        let argument_name =
1519                            &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
1520
1521                        write!(self.out, " {argument_name}")?;
1522                        if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
1523                            self.write_array_size(module, base, size)?;
1524                        }
1525
1526                        self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
1527                    }
1528                }
1529                if need_workgroup_variables_initialization {
1530                    if self
1531                        .entry_point_io
1532                        .get(&(ep_index as usize))
1533                        .unwrap()
1534                        .input
1535                        .is_some()
1536                        || !func.arguments.is_empty()
1537                    {
1538                        write!(self.out, ", ")?;
1539                    }
1540                    write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
1541                }
1542            }
1543        }
1544        // Ends of arguments
1545        write!(self.out, ")")?;
1546
        // Write the semantic if it is present
1548        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
1549            let stage = module.entry_points[index as usize].stage;
1550            if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
1551                self.write_semantic(binding, Some((stage, Io::Output)))?;
1552            }
1553        }
1554
1555        // Function body start
1556        writeln!(self.out)?;
1557        writeln!(self.out, "{{")?;
1558
1559        if need_workgroup_variables_initialization {
1560            self.write_workgroup_variables_initialization(func_ctx, module)?;
1561        }
1562
1563        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
1564            self.write_ep_arguments_initialization(module, func, index)?;
1565        }
1566
1567        // Write function local variables
1568        for (handle, local) in func.local_variables.iter() {
1569            // Write indentation (only for readability)
1570            write!(self.out, "{}", back::INDENT)?;
1571
1572            // Write the local name
1573            // The leading space is important
1574            self.write_type(module, local.ty)?;
1575            write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
1576            // Write size for array type
1577            if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
1578                self.write_array_size(module, base, size)?;
1579            }
1580
1581            match module.types[local.ty].inner {
1582                // from https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#tracerayinline-example-1 it seems that ray queries shouldn't be zeroed
1583                TypeInner::RayQuery { .. } => {}
1584                _ => {
1585                    write!(self.out, " = ")?;
1586                    // Write the local initializer if needed
1587                    if let Some(init) = local.init {
1588                        self.write_expr(module, init, func_ctx)?;
1589                    } else {
1590                        // Zero initialize local variables
1591                        self.write_default_init(module, local.ty)?;
1592                    }
1593                }
1594            }
1595            // Finish the local with `;` and add a newline (only for readability)
1596            writeln!(self.out, ";")?
1597        }
1598
1599        if !func.local_variables.is_empty() {
1600            writeln!(self.out)?;
1601        }
1602
1603        // Write the function body (statement list)
1604        for sta in func.body.iter() {
1605            // The indentation should always be 1 when writing the function body
1606            self.write_stmt(module, sta, func_ctx, back::Level(1))?;
1607        }
1608
1609        writeln!(self.out, "}}")?;
1610
1611        self.named_expressions.clear();
1612
1613        Ok(())
1614    }
1615
1616    fn need_workgroup_variables_initialization(
1617        &mut self,
1618        func_ctx: &back::FunctionCtx,
1619        module: &Module,
1620    ) -> bool {
1621        self.options.zero_initialize_workgroup_memory
1622            && func_ctx.ty.is_compute_entry_point(module)
1623            && module.global_variables.iter().any(|(handle, var)| {
1624                !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
1625            })
1626    }
1627
1628    fn write_workgroup_variables_initialization(
1629        &mut self,
1630        func_ctx: &back::FunctionCtx,
1631        module: &Module,
1632    ) -> BackendResult {
1633        let level = back::Level(1);
1634
1635        writeln!(
1636            self.out,
1637            "{level}if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {{"
1638        )?;
1639
1640        let vars = module.global_variables.iter().filter(|&(handle, var)| {
1641            !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
1642        });
1643
1644        for (handle, var) in vars {
1645            let name = &self.names[&NameKey::GlobalVariable(handle)];
1646            write!(self.out, "{}{} = ", level.next(), name)?;
1647            self.write_default_init(module, var.ty)?;
1648            writeln!(self.out, ";")?;
1649        }
1650
1651        writeln!(self.out, "{level}}}")?;
1652        self.write_control_barrier(crate::Barrier::WORK_GROUP, level)
1653    }
1654
1655    /// Helper method used to write switches
1656    fn write_switch(
1657        &mut self,
1658        module: &Module,
1659        func_ctx: &back::FunctionCtx<'_>,
1660        level: back::Level,
1661        selector: Handle<crate::Expression>,
1662        cases: &[crate::SwitchCase],
1663    ) -> BackendResult {
1664        // Write all cases
1665        let indent_level_1 = level.next();
1666        let indent_level_2 = indent_level_1.next();
1667
1668        // See docs of `back::continue_forward` module.
1669        if let Some(variable) = self.continue_ctx.enter_switch(&mut self.namer) {
1670            writeln!(self.out, "{level}bool {variable} = false;",)?;
1671        };
1672
1673        // Check if there is only one body, by seeing if all except the last case are fall through
1674        // with empty bodies. FXC doesn't handle these switches correctly, so
1675        // we generate a `do {} while(false);` loop instead. There must be a default case, so there
1676        // is no need to check if one of the cases would have matched.
1677        let one_body = cases
1678            .iter()
1679            .rev()
1680            .skip(1)
1681            .all(|case| case.fall_through && case.body.is_empty());
1682        if one_body {
1683            // Start the do-while
1684            writeln!(self.out, "{level}do {{")?;
1685            // Note: Expressions have no side-effects so we don't need to emit selector expression.
1686
1687            // Body
1688            if let Some(case) = cases.last() {
1689                for sta in case.body.iter() {
1690                    self.write_stmt(module, sta, func_ctx, indent_level_1)?;
1691                }
1692            }
1693            // End do-while
1694            writeln!(self.out, "{level}}} while(false);")?;
1695        } else {
1696            // Start the switch
1697            write!(self.out, "{level}")?;
1698            write!(self.out, "switch(")?;
1699            self.write_expr(module, selector, func_ctx)?;
1700            writeln!(self.out, ") {{")?;
1701
1702            for (i, case) in cases.iter().enumerate() {
1703                match case.value {
1704                    crate::SwitchValue::I32(value) => {
1705                        write!(self.out, "{indent_level_1}case {value}:")?
1706                    }
1707                    crate::SwitchValue::U32(value) => {
1708                        write!(self.out, "{indent_level_1}case {value}u:")?
1709                    }
1710                    crate::SwitchValue::Default => write!(self.out, "{indent_level_1}default:")?,
1711                }
1712
1713                // The new block is not only stylistic, it plays a role here:
1714                // We might end up having to write the same case body
1715                // multiple times due to FXC not supporting fallthrough.
1716                // Therefore, some `Expression`s written by `Statement::Emit`
1717                // will end up having the same name (`_expr<handle_index>`).
1718                // So we need to put each case in its own scope.
1719                let write_block_braces = !(case.fall_through && case.body.is_empty());
1720                if write_block_braces {
1721                    writeln!(self.out, " {{")?;
1722                } else {
1723                    writeln!(self.out)?;
1724                }
1725
1726                // Although FXC does support a series of case clauses before
1727                // a block[^yes], it does not support fallthrough from a
1728                // non-empty case block to the next[^no]. If this case has a
1729                // non-empty body with a fallthrough, emulate that by
1730                // duplicating the bodies of all the cases it would fall
1731                // into as extensions of this case's own body. This makes
1732                // the HLSL output potentially quadratic in the size of the
1733                // Naga IR.
1734                //
1735                // [^yes]: ```hlsl
1736                // case 1:
1737                // case 2: do_stuff()
1738                // ```
1739                // [^no]: ```hlsl
1740                // case 1: do_this();
1741                // case 2: do_that();
1742                // ```
1743                if case.fall_through && !case.body.is_empty() {
1744                    let curr_len = i + 1;
1745                    let end_case_idx = curr_len
1746                        + cases
1747                            .iter()
1748                            .skip(curr_len)
1749                            .position(|case| !case.fall_through)
1750                            .unwrap();
1751                    let indent_level_3 = indent_level_2.next();
1752                    for case in &cases[i..=end_case_idx] {
1753                        writeln!(self.out, "{indent_level_2}{{")?;
1754                        let prev_len = self.named_expressions.len();
1755                        for sta in case.body.iter() {
1756                            self.write_stmt(module, sta, func_ctx, indent_level_3)?;
1757                        }
1758                        // Clear all named expressions that were previously inserted by the statements in the block
1759                        self.named_expressions.truncate(prev_len);
1760                        writeln!(self.out, "{indent_level_2}}}")?;
1761                    }
1762
1763                    let last_case = &cases[end_case_idx];
1764                    if last_case.body.last().is_none_or(|s| !s.is_terminator()) {
1765                        writeln!(self.out, "{indent_level_2}break;")?;
1766                    }
1767                } else {
1768                    for sta in case.body.iter() {
1769                        self.write_stmt(module, sta, func_ctx, indent_level_2)?;
1770                    }
1771                    if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator()) {
1772                        writeln!(self.out, "{indent_level_2}break;")?;
1773                    }
1774                }
1775
1776                if write_block_braces {
1777                    writeln!(self.out, "{indent_level_1}}}")?;
1778                }
1779            }
1780
1781            writeln!(self.out, "{level}}}")?;
1782        }
1783
1784        // Handle any forwarded continue statements.
1785        use back::continue_forward::ExitControlFlow;
1786        let op = match self.continue_ctx.exit_switch() {
1787            ExitControlFlow::None => None,
1788            ExitControlFlow::Continue { variable } => Some(("continue", variable)),
1789            ExitControlFlow::Break { variable } => Some(("break", variable)),
1790        };
1791        if let Some((control_flow, variable)) = op {
1792            writeln!(self.out, "{level}if ({variable}) {{")?;
1793            writeln!(self.out, "{indent_level_1}{control_flow};")?;
1794            writeln!(self.out, "{level}}}")?;
1795        }
1796
1797        Ok(())
1798    }
1799
1800    /// Helper method used to write statements
1801    ///
1802    /// # Notes
1803    /// Always adds a newline
1804    fn write_stmt(
1805        &mut self,
1806        module: &Module,
1807        stmt: &crate::Statement,
1808        func_ctx: &back::FunctionCtx<'_>,
1809        level: back::Level,
1810    ) -> BackendResult {
1811        use crate::Statement;
1812
1813        match *stmt {
1814            Statement::Emit(ref range) => {
1815                for handle in range.clone() {
1816                    let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
1817                    let expr_name = if ptr_class.is_some() {
1818                        // HLSL can't save a pointer-valued expression in a variable,
1819                        // but we shouldn't ever need to: they should never be named expressions,
1820                        // and none of the expression types flagged by bake_ref_count can be pointer-valued.
1821                        None
1822                    } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
                        // The front end provides names for all variables at the start of writing,
                        // but we write them step by step, so we need to recache them.
                        // Otherwise, we could accidentally write a variable name instead of the full expression.
                        // Also, we use sanitized names! This protects the backend from generating
                        // variables whose names are reserved keywords.
1827                        Some(self.namer.call(name))
1828                    } else if self.need_bake_expressions.contains(&handle) {
1829                        Some(Baked(handle).to_string())
1830                    } else {
1831                        None
1832                    };
1833
1834                    if let Some(name) = expr_name {
1835                        write!(self.out, "{level}")?;
1836                        self.write_named_expr(module, handle, name, handle, func_ctx)?;
1837                    }
1838                }
1839            }
1840            // TODO: copy-paste from glsl-out
1841            Statement::Block(ref block) => {
1842                write!(self.out, "{level}")?;
1843                writeln!(self.out, "{{")?;
1844                for sta in block.iter() {
1845                    // Increase the indentation to help with readability
1846                    self.write_stmt(module, sta, func_ctx, level.next())?
1847                }
1848                writeln!(self.out, "{level}}}")?
1849            }
1850            // TODO: copy-paste from glsl-out
1851            Statement::If {
1852                condition,
1853                ref accept,
1854                ref reject,
1855            } => {
1856                write!(self.out, "{level}")?;
1857                write!(self.out, "if (")?;
1858                self.write_expr(module, condition, func_ctx)?;
1859                writeln!(self.out, ") {{")?;
1860
1861                let l2 = level.next();
1862                for sta in accept {
1863                    // Increase indentation to help with readability
1864                    self.write_stmt(module, sta, func_ctx, l2)?;
1865                }
1866
1867                // If there are no statements in the reject block we skip writing it
1868                // This is only for readability
1869                if !reject.is_empty() {
1870                    writeln!(self.out, "{level}}} else {{")?;
1871
1872                    for sta in reject {
1873                        // Increase indentation to help with readability
1874                        self.write_stmt(module, sta, func_ctx, l2)?;
1875                    }
1876                }
1877
1878                writeln!(self.out, "{level}}}")?
1879            }
1880            // TODO: copy-paste from glsl-out
1881            Statement::Kill => writeln!(self.out, "{level}discard;")?,
1882            Statement::Return { value: None } => {
1883                writeln!(self.out, "{level}return;")?;
1884            }
1885            Statement::Return { value: Some(expr) } => {
1886                let base_ty_res = &func_ctx.info[expr].ty;
1887                let mut resolved = base_ty_res.inner_with(&module.types);
1888                if let TypeInner::Pointer { base, space: _ } = *resolved {
1889                    resolved = &module.types[base].inner;
1890                }
1891
1892                if let TypeInner::Struct { .. } = *resolved {
                    // We can safely unwrap here, since we know we are working with a struct
1894                    let ty = base_ty_res.handle().unwrap();
1895                    let struct_name = &self.names[&NameKey::Type(ty)];
1896                    let variable_name = self.namer.call(&struct_name.to_lowercase());
1897                    write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
1898                    self.write_expr(module, expr, func_ctx)?;
1899                    writeln!(self.out, ";")?;
1900
1901                    // for entry point returns, we may need to reshuffle the outputs into a different struct
1902                    let ep_output = match func_ctx.ty {
1903                        back::FunctionType::Function(_) => None,
1904                        back::FunctionType::EntryPoint(index) => self
1905                            .entry_point_io
1906                            .get(&(index as usize))
1907                            .unwrap()
1908                            .output
1909                            .as_ref(),
1910                    };
1911                    let final_name = match ep_output {
1912                        Some(ep_output) => {
1913                            let final_name = self.namer.call(&variable_name);
1914                            write!(
1915                                self.out,
1916                                "{}const {} {} = {{ ",
1917                                level, ep_output.ty_name, final_name,
1918                            )?;
1919                            for (index, m) in ep_output.members.iter().enumerate() {
1920                                if index != 0 {
1921                                    write!(self.out, ", ")?;
1922                                }
1923                                let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
1924                                write!(self.out, "{variable_name}.{member_name}")?;
1925                            }
1926                            writeln!(self.out, " }};")?;
1927                            final_name
1928                        }
1929                        None => variable_name,
1930                    };
1931                    writeln!(self.out, "{level}return {final_name};")?;
1932                } else {
1933                    write!(self.out, "{level}return ")?;
1934                    self.write_expr(module, expr, func_ctx)?;
1935                    writeln!(self.out, ";")?
1936                }
1937            }
1938            Statement::Store { pointer, value } => {
1939                let ty_inner = func_ctx.resolve_type(pointer, &module.types);
1940                if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
1941                    let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
1942                    self.write_storage_store(
1943                        module,
1944                        var_handle,
1945                        StoreValue::Expression(value),
1946                        func_ctx,
1947                        level,
1948                    )?;
1949                } else {
1950                    // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
1951                    // See the module-level block comment in mod.rs for details.
1952                    //
1953                    // We handle matrix Stores here directly (including sub accesses for Vectors and Scalars).
1954                    // Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads).
1955                    struct MatrixAccess {
1956                        base: Handle<crate::Expression>,
1957                        index: u32,
1958                    }
1959                    enum Index {
1960                        Expression(Handle<crate::Expression>),
1961                        Static(u32),
1962                    }
1963
1964                    let get_members = |expr: Handle<crate::Expression>| {
1965                        let resolved = func_ctx.resolve_type(expr, &module.types);
1966                        match *resolved {
1967                            TypeInner::Pointer { base, .. } => match module.types[base].inner {
1968                                TypeInner::Struct { ref members, .. } => Some(members),
1969                                _ => None,
1970                            },
1971                            _ => None,
1972                        }
1973                    };
1974
1975                    let mut matrix = None;
1976                    let mut vector = None;
1977                    let mut scalar = None;
1978
1979                    let mut current_expr = pointer;
1980                    for _ in 0..3 {
1981                        let resolved = func_ctx.resolve_type(current_expr, &module.types);
1982
1983                        match (resolved, &func_ctx.expressions[current_expr]) {
1984                            (
1985                                &TypeInner::Pointer { base: ty, .. },
1986                                &crate::Expression::AccessIndex { base, index },
1987                            ) if matches!(
1988                                module.types[ty].inner,
1989                                TypeInner::Matrix {
1990                                    rows: crate::VectorSize::Bi,
1991                                    ..
1992                                }
1993                            ) && get_members(base)
1994                                .map(|members| members[index as usize].binding.is_none())
1995                                == Some(true) =>
1996                            {
1997                                matrix = Some(MatrixAccess { base, index });
1998                                break;
1999                            }
2000                            (
2001                                &TypeInner::ValuePointer {
2002                                    size: Some(crate::VectorSize::Bi),
2003                                    ..
2004                                },
2005                                &crate::Expression::Access { base, index },
2006                            ) => {
2007                                vector = Some(Index::Expression(index));
2008                                current_expr = base;
2009                            }
2010                            (
2011                                &TypeInner::ValuePointer {
2012                                    size: Some(crate::VectorSize::Bi),
2013                                    ..
2014                                },
2015                                &crate::Expression::AccessIndex { base, index },
2016                            ) => {
2017                                vector = Some(Index::Static(index));
2018                                current_expr = base;
2019                            }
2020                            (
2021                                &TypeInner::ValuePointer { size: None, .. },
2022                                &crate::Expression::Access { base, index },
2023                            ) => {
2024                                scalar = Some(Index::Expression(index));
2025                                current_expr = base;
2026                            }
2027                            (
2028                                &TypeInner::ValuePointer { size: None, .. },
2029                                &crate::Expression::AccessIndex { base, index },
2030                            ) => {
2031                                scalar = Some(Index::Static(index));
2032                                current_expr = base;
2033                            }
2034                            _ => break,
2035                        }
2036                    }
2037
2038                    write!(self.out, "{level}")?;
2039
2040                    if let Some(MatrixAccess { index, base }) = matrix {
2041                        let base_ty_res = &func_ctx.info[base].ty;
2042                        let resolved = base_ty_res.inner_with(&module.types);
2043                        let ty = match *resolved {
2044                            TypeInner::Pointer { base, .. } => base,
2045                            _ => base_ty_res.handle().unwrap(),
2046                        };
2047
2048                        if let Some(Index::Static(vec_index)) = vector {
2049                            self.write_expr(module, base, func_ctx)?;
2050                            write!(
2051                                self.out,
2052                                ".{}_{}",
2053                                &self.names[&NameKey::StructMember(ty, index)],
2054                                vec_index
2055                            )?;
2056
2057                            if let Some(scalar_index) = scalar {
2058                                write!(self.out, "[")?;
2059                                match scalar_index {
2060                                    Index::Static(index) => {
2061                                        write!(self.out, "{index}")?;
2062                                    }
2063                                    Index::Expression(index) => {
2064                                        self.write_expr(module, index, func_ctx)?;
2065                                    }
2066                                }
2067                                write!(self.out, "]")?;
2068                            }
2069
2070                            write!(self.out, " = ")?;
2071                            self.write_expr(module, value, func_ctx)?;
2072                            writeln!(self.out, ";")?;
2073                        } else {
2074                            let access = WrappedStructMatrixAccess { ty, index };
2075                            match (&vector, &scalar) {
2076                                (&Some(_), &Some(_)) => {
2077                                    self.write_wrapped_struct_matrix_set_scalar_function_name(
2078                                        access,
2079                                    )?;
2080                                }
2081                                (&Some(_), &None) => {
2082                                    self.write_wrapped_struct_matrix_set_vec_function_name(access)?;
2083                                }
2084                                (&None, _) => {
2085                                    self.write_wrapped_struct_matrix_set_function_name(access)?;
2086                                }
2087                            }
2088
2089                            write!(self.out, "(")?;
2090                            self.write_expr(module, base, func_ctx)?;
2091                            write!(self.out, ", ")?;
2092                            self.write_expr(module, value, func_ctx)?;
2093
2094                            if let Some(Index::Expression(vec_index)) = vector {
2095                                write!(self.out, ", ")?;
2096                                self.write_expr(module, vec_index, func_ctx)?;
2097
2098                                if let Some(scalar_index) = scalar {
2099                                    write!(self.out, ", ")?;
2100                                    match scalar_index {
2101                                        Index::Static(index) => {
2102                                            write!(self.out, "{index}")?;
2103                                        }
2104                                        Index::Expression(index) => {
2105                                            self.write_expr(module, index, func_ctx)?;
2106                                        }
2107                                    }
2108                                }
2109                            }
2110                            writeln!(self.out, ");")?;
2111                        }
2112                    } else {
2113                        // We handle `Store`s to __matCx2 column vectors and scalar elements via
2114                        // the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2.
2115                        struct MatrixData {
2116                            columns: crate::VectorSize,
2117                            base: Handle<crate::Expression>,
2118                        }
2119
2120                        enum Index {
2121                            Expression(Handle<crate::Expression>),
2122                            Static(u32),
2123                        }
2124
2125                        let mut matrix = None;
2126                        let mut vector = None;
2127                        let mut scalar = None;
2128
2129                        let mut current_expr = pointer;
2130                        for _ in 0..3 {
2131                            let resolved = func_ctx.resolve_type(current_expr, &module.types);
2132                            match (resolved, &func_ctx.expressions[current_expr]) {
2133                                (
2134                                    &TypeInner::ValuePointer {
2135                                        size: Some(crate::VectorSize::Bi),
2136                                        ..
2137                                    },
2138                                    &crate::Expression::Access { base, index },
2139                                ) => {
2140                                    vector = Some(index);
2141                                    current_expr = base;
2142                                }
2143                                (
2144                                    &TypeInner::ValuePointer { size: None, .. },
2145                                    &crate::Expression::Access { base, index },
2146                                ) => {
2147                                    scalar = Some(Index::Expression(index));
2148                                    current_expr = base;
2149                                }
2150                                (
2151                                    &TypeInner::ValuePointer { size: None, .. },
2152                                    &crate::Expression::AccessIndex { base, index },
2153                                ) => {
2154                                    scalar = Some(Index::Static(index));
2155                                    current_expr = base;
2156                                }
2157                                _ => {
2158                                    if let Some(MatrixType {
2159                                        columns,
2160                                        rows: crate::VectorSize::Bi,
2161                                        width: 4,
2162                                    }) = get_inner_matrix_of_struct_array_member(
2163                                        module,
2164                                        current_expr,
2165                                        func_ctx,
2166                                        true,
2167                                    ) {
2168                                        matrix = Some(MatrixData {
2169                                            columns,
2170                                            base: current_expr,
2171                                        });
2172                                    }
2173
2174                                    break;
2175                                }
2176                            }
2177                        }
2178
2179                        if let (Some(MatrixData { columns, base }), Some(vec_index)) =
2180                            (matrix, vector)
2181                        {
2182                            if scalar.is_some() {
2183                                write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
2184                            } else {
2185                                write!(self.out, "__set_col_of_mat{}x2", columns as u8)?;
2186                            }
2187                            write!(self.out, "(")?;
2188                            self.write_expr(module, base, func_ctx)?;
2189                            write!(self.out, ", ")?;
2190                            self.write_expr(module, vec_index, func_ctx)?;
2191
2192                            if let Some(scalar_index) = scalar {
2193                                write!(self.out, ", ")?;
2194                                match scalar_index {
2195                                    Index::Static(index) => {
2196                                        write!(self.out, "{index}")?;
2197                                    }
2198                                    Index::Expression(index) => {
2199                                        self.write_expr(module, index, func_ctx)?;
2200                                    }
2201                                }
2202                            }
2203
2204                            write!(self.out, ", ")?;
2205                            self.write_expr(module, value, func_ctx)?;
2206
2207                            writeln!(self.out, ");")?;
2208                        } else {
2209                            self.write_expr(module, pointer, func_ctx)?;
2210                            write!(self.out, " = ")?;
2211
2212                            // We cast the RHS of this store in cases where the LHS
2213                            // is a struct member with type:
2214                            //  - matCx2 or
2215                            //  - a (possibly nested) array of matCx2's
2216                            if let Some(MatrixType {
2217                                columns,
2218                                rows: crate::VectorSize::Bi,
2219                                width: 4,
2220                            }) = get_inner_matrix_of_struct_array_member(
2221                                module, pointer, func_ctx, false,
2222                            ) {
2223                                let mut resolved = func_ctx.resolve_type(pointer, &module.types);
2224                                if let TypeInner::Pointer { base, .. } = *resolved {
2225                                    resolved = &module.types[base].inner;
2226                                }
2227
2228                                write!(self.out, "(__mat{}x2", columns as u8)?;
2229                                if let TypeInner::Array { base, size, .. } = *resolved {
2230                                    self.write_array_size(module, base, size)?;
2231                                }
2232                                write!(self.out, ")")?;
2233                            }
2234
2235                            self.write_expr(module, value, func_ctx)?;
2236                            writeln!(self.out, ";")?
2237                        }
2238                    }
2239                }
2240            }
2241            Statement::Loop {
2242                ref body,
2243                ref continuing,
2244                break_if,
2245            } => {
2246                let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
2247                let gate_name = (!continuing.is_empty() || break_if.is_some())
2248                    .then(|| self.namer.call("loop_init"));
2249
2250                if let Some((ref decl, _)) = force_loop_bound_statements {
2251                    writeln!(self.out, "{decl}")?;
2252                }
2253                if let Some(ref gate_name) = gate_name {
2254                    writeln!(self.out, "{level}bool {gate_name} = true;")?;
2255                }
2256
2257                self.continue_ctx.enter_loop();
2258                writeln!(self.out, "{level}while(true) {{")?;
2259                if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
2260                    writeln!(self.out, "{break_and_inc}")?;
2261                }
2262                let l2 = level.next();
2263                if let Some(gate_name) = gate_name {
2264                    writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
2265                    let l3 = l2.next();
2266                    for sta in continuing.iter() {
2267                        self.write_stmt(module, sta, func_ctx, l3)?;
2268                    }
2269                    if let Some(condition) = break_if {
2270                        write!(self.out, "{l3}if (")?;
2271                        self.write_expr(module, condition, func_ctx)?;
2272                        writeln!(self.out, ") {{")?;
2273                        writeln!(self.out, "{}break;", l3.next())?;
2274                        writeln!(self.out, "{l3}}}")?;
2275                    }
2276                    writeln!(self.out, "{l2}}}")?;
2277                    writeln!(self.out, "{l2}{gate_name} = false;")?;
2278                }
2279
2280                for sta in body.iter() {
2281                    self.write_stmt(module, sta, func_ctx, l2)?;
2282                }
2283
2284                writeln!(self.out, "{level}}}")?;
2285                self.continue_ctx.exit_loop();
2286            }
2287            Statement::Break => writeln!(self.out, "{level}break;")?,
2288            Statement::Continue => {
2289                if let Some(variable) = self.continue_ctx.continue_encountered() {
2290                    writeln!(self.out, "{level}{variable} = true;")?;
2291                    writeln!(self.out, "{level}break;")?
2292                } else {
2293                    writeln!(self.out, "{level}continue;")?
2294                }
2295            }
2296            Statement::ControlBarrier(barrier) => {
2297                self.write_control_barrier(barrier, level)?;
2298            }
2299            Statement::MemoryBarrier(barrier) => {
2300                self.write_memory_barrier(barrier, level)?;
2301            }
2302            Statement::ImageStore {
2303                image,
2304                coordinate,
2305                array_index,
2306                value,
2307            } => {
2308                write!(self.out, "{level}")?;
2309                self.write_expr(module, image, func_ctx)?;
2310
2311                write!(self.out, "[")?;
2312                if let Some(index) = array_index {
2313                    // Array index accepted only for texture_storage_2d_array, so we can safely use int3(coordinate, array_index) here
2314                    write!(self.out, "int3(")?;
2315                    self.write_expr(module, coordinate, func_ctx)?;
2316                    write!(self.out, ", ")?;
2317                    self.write_expr(module, index, func_ctx)?;
2318                    write!(self.out, ")")?;
2319                } else {
2320                    self.write_expr(module, coordinate, func_ctx)?;
2321                }
2322                write!(self.out, "]")?;
2323
2324                write!(self.out, " = ")?;
2325                self.write_expr(module, value, func_ctx)?;
2326                writeln!(self.out, ";")?;
2327            }
2328            Statement::Call {
2329                function,
2330                ref arguments,
2331                result,
2332            } => {
2333                write!(self.out, "{level}")?;
2334                if let Some(expr) = result {
2335                    write!(self.out, "const ")?;
2336                    let name = Baked(expr).to_string();
2337                    let expr_ty = &func_ctx.info[expr].ty;
2338                    let ty_inner = match *expr_ty {
2339                        proc::TypeResolution::Handle(handle) => {
2340                            self.write_type(module, handle)?;
2341                            &module.types[handle].inner
2342                        }
2343                        proc::TypeResolution::Value(ref value) => {
2344                            self.write_value_type(module, value)?;
2345                            value
2346                        }
2347                    };
2348                    write!(self.out, " {name}")?;
2349                    if let TypeInner::Array { base, size, .. } = *ty_inner {
2350                        self.write_array_size(module, base, size)?;
2351                    }
2352                    write!(self.out, " = ")?;
2353                    self.named_expressions.insert(expr, name);
2354                }
2355                let func_name = &self.names[&NameKey::Function(function)];
2356                write!(self.out, "{func_name}(")?;
2357                for (index, argument) in arguments.iter().enumerate() {
2358                    if index != 0 {
2359                        write!(self.out, ", ")?;
2360                    }
2361                    self.write_expr(module, *argument, func_ctx)?;
2362                }
2363                writeln!(self.out, ");")?
2364            }
2365            Statement::Atomic {
2366                pointer,
2367                ref fun,
2368                value,
2369                result,
2370            } => {
2371                write!(self.out, "{level}")?;
2372                let res_var_info = if let Some(res_handle) = result {
2373                    let name = Baked(res_handle).to_string();
2374                    match func_ctx.info[res_handle].ty {
2375                        proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2376                        proc::TypeResolution::Value(ref value) => {
2377                            self.write_value_type(module, value)?
2378                        }
2379                    };
2380                    write!(self.out, " {name}; ")?;
2381                    self.named_expressions.insert(res_handle, name.clone());
2382                    Some((res_handle, name))
2383                } else {
2384                    None
2385                };
2386                let pointer_space = func_ctx
2387                    .resolve_type(pointer, &module.types)
2388                    .pointer_space()
2389                    .unwrap();
2390                let fun_str = fun.to_hlsl_suffix();
2391                let compare_expr = match *fun {
2392                    crate::AtomicFunction::Exchange { compare: Some(cmp) } => Some(cmp),
2393                    _ => None,
2394                };
2395                match pointer_space {
2396                    crate::AddressSpace::WorkGroup => {
2397                        write!(self.out, "Interlocked{fun_str}(")?;
2398                        self.write_expr(module, pointer, func_ctx)?;
2399                        self.emit_hlsl_atomic_tail(
2400                            module,
2401                            func_ctx,
2402                            fun,
2403                            compare_expr,
2404                            value,
2405                            &res_var_info,
2406                        )?;
2407                    }
2408                    crate::AddressSpace::Storage { .. } => {
2409                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2410                        let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2411                        let width = match func_ctx.resolve_type(value, &module.types) {
2412                            &TypeInner::Scalar(Scalar { width: 8, .. }) => "64",
2413                            _ => "",
2414                        };
2415                        write!(self.out, "{var_name}.Interlocked{fun_str}{width}(")?;
2416                        let chain = mem::take(&mut self.temp_access_chain);
2417                        self.write_storage_address(module, &chain, func_ctx)?;
2418                        self.temp_access_chain = chain;
2419                        self.emit_hlsl_atomic_tail(
2420                            module,
2421                            func_ctx,
2422                            fun,
2423                            compare_expr,
2424                            value,
2425                            &res_var_info,
2426                        )?;
2427                    }
2428                    ref other => {
2429                        return Err(Error::Custom(format!(
2430                            "invalid address space {other:?} for atomic statement"
2431                        )))
2432                    }
2433                }
2434                if let Some(cmp) = compare_expr {
2435                    if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
2436                        write!(
2437                            self.out,
2438                            "{level}{res_name}.exchanged = ({res_name}.old_value == "
2439                        )?;
2440                        self.write_expr(module, cmp, func_ctx)?;
2441                        writeln!(self.out, ");")?;
2442                    }
2443                }
2444            }
2445            Statement::ImageAtomic {
2446                image,
2447                coordinate,
2448                array_index,
2449                fun,
2450                value,
2451            } => {
2452                write!(self.out, "{level}")?;
2453
2454                let fun_str = fun.to_hlsl_suffix();
2455                write!(self.out, "Interlocked{fun_str}(")?;
2456                self.write_expr(module, image, func_ctx)?;
2457                write!(self.out, "[")?;
2458                self.write_texture_coordinates(
2459                    "int",
2460                    coordinate,
2461                    array_index,
2462                    None,
2463                    module,
2464                    func_ctx,
2465                )?;
2466                write!(self.out, "],")?;
2467
2468                self.write_expr(module, value, func_ctx)?;
2469                writeln!(self.out, ");")?;
2470            }
2471            Statement::WorkGroupUniformLoad { pointer, result } => {
2472                self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2473                write!(self.out, "{level}")?;
2474                let name = Baked(result).to_string();
2475                self.write_named_expr(module, pointer, name, result, func_ctx)?;
2476
2477                self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2478            }
2479            Statement::Switch {
2480                selector,
2481                ref cases,
2482            } => {
2483                self.write_switch(module, func_ctx, level, selector, cases)?;
2484            }
2485            Statement::RayQuery { query, ref fun } => match *fun {
2486                RayQueryFunction::Initialize {
2487                    acceleration_structure,
2488                    descriptor,
2489                } => {
2490                    write!(self.out, "{level}")?;
2491                    self.write_expr(module, query, func_ctx)?;
2492                    write!(self.out, ".TraceRayInline(")?;
2493                    self.write_expr(module, acceleration_structure, func_ctx)?;
2494                    write!(self.out, ", ")?;
2495                    self.write_expr(module, descriptor, func_ctx)?;
2496                    write!(self.out, ".flags, ")?;
2497                    self.write_expr(module, descriptor, func_ctx)?;
2498                    write!(self.out, ".cull_mask, ")?;
2499                    write!(self.out, "RayDescFromRayDesc_(")?;
2500                    self.write_expr(module, descriptor, func_ctx)?;
2501                    writeln!(self.out, "));")?;
2502                }
2503                RayQueryFunction::Proceed { result } => {
2504                    write!(self.out, "{level}")?;
2505                    let name = Baked(result).to_string();
2506                    write!(self.out, "const bool {name} = ")?;
2507                    self.named_expressions.insert(result, name);
2508                    self.write_expr(module, query, func_ctx)?;
2509                    writeln!(self.out, ".Proceed();")?;
2510                }
2511                RayQueryFunction::GenerateIntersection { hit_t } => {
2512                    write!(self.out, "{level}")?;
2513                    self.write_expr(module, query, func_ctx)?;
2514                    write!(self.out, ".CommitProceduralPrimitiveHit(")?;
2515                    self.write_expr(module, hit_t, func_ctx)?;
2516                    writeln!(self.out, ");")?;
2517                }
2518                RayQueryFunction::ConfirmIntersection => {
2519                    write!(self.out, "{level}")?;
2520                    self.write_expr(module, query, func_ctx)?;
2521                    writeln!(self.out, ".CommitNonOpaqueTriangleHit();")?;
2522                }
2523                RayQueryFunction::Terminate => {
2524                    write!(self.out, "{level}")?;
2525                    self.write_expr(module, query, func_ctx)?;
2526                    writeln!(self.out, ".Abort();")?;
2527                }
2528            },
2529            Statement::SubgroupBallot { result, predicate } => {
2530                write!(self.out, "{level}")?;
2531                let name = Baked(result).to_string();
2532                write!(self.out, "const uint4 {name} = ")?;
2533                self.named_expressions.insert(result, name);
2534
2535                write!(self.out, "WaveActiveBallot(")?;
2536                match predicate {
2537                    Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
2538                    None => write!(self.out, "true")?,
2539                }
2540                writeln!(self.out, ");")?;
2541            }
2542            Statement::SubgroupCollectiveOperation {
2543                op,
2544                collective_op,
2545                argument,
2546                result,
2547            } => {
2548                write!(self.out, "{level}")?;
2549                write!(self.out, "const ")?;
2550                let name = Baked(result).to_string();
2551                match func_ctx.info[result].ty {
2552                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2553                    proc::TypeResolution::Value(ref value) => {
2554                        self.write_value_type(module, value)?
2555                    }
2556                };
2557                write!(self.out, " {name} = ")?;
2558                self.named_expressions.insert(result, name);
2559
2560                match (collective_op, op) {
2561                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
2562                        write!(self.out, "WaveActiveAllTrue(")?
2563                    }
2564                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
2565                        write!(self.out, "WaveActiveAnyTrue(")?
2566                    }
2567                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
2568                        write!(self.out, "WaveActiveSum(")?
2569                    }
2570                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
2571                        write!(self.out, "WaveActiveProduct(")?
2572                    }
2573                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
2574                        write!(self.out, "WaveActiveMax(")?
2575                    }
2576                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
2577                        write!(self.out, "WaveActiveMin(")?
2578                    }
2579                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
2580                        write!(self.out, "WaveActiveBitAnd(")?
2581                    }
2582                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
2583                        write!(self.out, "WaveActiveBitOr(")?
2584                    }
2585                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
2586                        write!(self.out, "WaveActiveBitXor(")?
2587                    }
2588                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
2589                        write!(self.out, "WavePrefixSum(")?
2590                    }
2591                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
2592                        write!(self.out, "WavePrefixProduct(")?
2593                    }
2594                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
2595                        self.write_expr(module, argument, func_ctx)?;
2596                        write!(self.out, " + WavePrefixSum(")?;
2597                    }
2598                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
2599                        self.write_expr(module, argument, func_ctx)?;
2600                        write!(self.out, " * WavePrefixProduct(")?;
2601                    }
2602                    _ => unimplemented!(),
2603                }
2604                self.write_expr(module, argument, func_ctx)?;
2605                writeln!(self.out, ");")?;
2606            }
2607            Statement::SubgroupGather {
2608                mode,
2609                argument,
2610                result,
2611            } => {
2612                write!(self.out, "{level}")?;
2613                write!(self.out, "const ")?;
2614                let name = Baked(result).to_string();
2615                match func_ctx.info[result].ty {
2616                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2617                    proc::TypeResolution::Value(ref value) => {
2618                        self.write_value_type(module, value)?
2619                    }
2620                };
2621                write!(self.out, " {name} = ")?;
2622                self.named_expressions.insert(result, name);
2623                match mode {
2624                    crate::GatherMode::BroadcastFirst => {
2625                        write!(self.out, "WaveReadLaneFirst(")?;
2626                        self.write_expr(module, argument, func_ctx)?;
2627                    }
2628                    crate::GatherMode::QuadBroadcast(index) => {
2629                        write!(self.out, "QuadReadLaneAt(")?;
2630                        self.write_expr(module, argument, func_ctx)?;
2631                        write!(self.out, ", ")?;
2632                        self.write_expr(module, index, func_ctx)?;
2633                    }
2634                    crate::GatherMode::QuadSwap(direction) => {
2635                        match direction {
2636                            crate::Direction::X => {
2637                                write!(self.out, "QuadReadAcrossX(")?;
2638                            }
2639                            crate::Direction::Y => {
2640                                write!(self.out, "QuadReadAcrossY(")?;
2641                            }
2642                            crate::Direction::Diagonal => {
2643                                write!(self.out, "QuadReadAcrossDiagonal(")?;
2644                            }
2645                        }
2646                        self.write_expr(module, argument, func_ctx)?;
2647                    }
2648                    _ => {
2649                        write!(self.out, "WaveReadLaneAt(")?;
2650                        self.write_expr(module, argument, func_ctx)?;
2651                        write!(self.out, ", ")?;
2652                        match mode {
2653                            crate::GatherMode::BroadcastFirst => unreachable!(),
2654                            crate::GatherMode::Broadcast(index)
2655                            | crate::GatherMode::Shuffle(index) => {
2656                                self.write_expr(module, index, func_ctx)?;
2657                            }
2658                            crate::GatherMode::ShuffleDown(index) => {
2659                                write!(self.out, "WaveGetLaneIndex() + ")?;
2660                                self.write_expr(module, index, func_ctx)?;
2661                            }
2662                            crate::GatherMode::ShuffleUp(index) => {
2663                                write!(self.out, "WaveGetLaneIndex() - ")?;
2664                                self.write_expr(module, index, func_ctx)?;
2665                            }
2666                            crate::GatherMode::ShuffleXor(index) => {
2667                                write!(self.out, "WaveGetLaneIndex() ^ ")?;
2668                                self.write_expr(module, index, func_ctx)?;
2669                            }
2670                            crate::GatherMode::QuadBroadcast(_) => unreachable!(),
2671                            crate::GatherMode::QuadSwap(_) => unreachable!(),
2672                        }
2673                    }
2674                }
2675                writeln!(self.out, ");")?;
2676            }
2677        }
2678
2679        Ok(())
2680    }
2681
2682    fn write_const_expression(
2683        &mut self,
2684        module: &Module,
2685        expr: Handle<crate::Expression>,
2686        arena: &crate::Arena<crate::Expression>,
2687    ) -> BackendResult {
2688        self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
2689            writer.write_const_expression(module, expr, arena)
2690        })
2691    }
2692
2693    pub(super) fn write_literal(&mut self, literal: crate::Literal) -> BackendResult {
2694        match literal {
2695            crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
2696            crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
2697            crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
2698            crate::Literal::U32(value) => write!(self.out, "{value}u")?,
2699            // `-2147483648` is parsed by some compilers as unary negation of
2700            // positive 2147483648, which is too large for an int, causing
2701            // issues for some compilers. Neither DXC nor FXC appear to have
2702            // this problem, but this is not specified and could change. We
2703            // therefore use `-2147483647 - 1` as a precaution.
2704            crate::Literal::I32(value) if value == i32::MIN => {
2705                write!(self.out, "int({} - 1)", value + 1)?
2706            }
2707            // HLSL has no suffix for explicit i32 literals, but not using any suffix
2708            // makes the type ambiguous which prevents overload resolution from
2709            // working. So we explicitly use the int() constructor syntax.
2710            crate::Literal::I32(value) => write!(self.out, "int({value})")?,
2711            crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
2712            // I64 version of the minimum I32 value issue described above.
2713            crate::Literal::I64(value) if value == i64::MIN => {
2714                write!(self.out, "({}L - 1L)", value + 1)?;
2715            }
2716            crate::Literal::I64(value) => write!(self.out, "{value}L")?,
2717            crate::Literal::Bool(value) => write!(self.out, "{value}")?,
2718            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
2719                return Err(Error::Custom(
2720                    "Abstract types should not appear in IR presented to backends".into(),
2721                ));
2722            }
2723        }
2724        Ok(())
2725    }
2726
2727    fn write_possibly_const_expression<E>(
2728        &mut self,
2729        module: &Module,
2730        expr: Handle<crate::Expression>,
2731        expressions: &crate::Arena<crate::Expression>,
2732        write_expression: E,
2733    ) -> BackendResult
2734    where
2735        E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
2736    {
2737        use crate::Expression;
2738
2739        match expressions[expr] {
2740            Expression::Literal(literal) => {
2741                self.write_literal(literal)?;
2742            }
2743            Expression::Constant(handle) => {
2744                let constant = &module.constants[handle];
2745                if constant.name.is_some() {
2746                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
2747                } else {
2748                    self.write_const_expression(module, constant.init, &module.global_expressions)?;
2749                }
2750            }
2751            Expression::ZeroValue(ty) => {
2752                self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
2753                write!(self.out, "()")?;
2754            }
2755            Expression::Compose { ty, ref components } => {
2756                match module.types[ty].inner {
2757                    TypeInner::Struct { .. } | TypeInner::Array { .. } => {
2758                        self.write_wrapped_constructor_function_name(
2759                            module,
2760                            WrappedConstructor { ty },
2761                        )?;
2762                    }
2763                    _ => {
2764                        self.write_type(module, ty)?;
2765                    }
2766                };
2767                write!(self.out, "(")?;
2768                for (index, component) in components.iter().enumerate() {
2769                    if index != 0 {
2770                        write!(self.out, ", ")?;
2771                    }
2772                    write_expression(self, *component)?;
2773                }
2774                write!(self.out, ")")?;
2775            }
2776            Expression::Splat { size, value } => {
2777                // hlsl is not supported one value constructor
2778                // if we write, for example, int4(0), dxc returns error:
2779                // error: too few elements in vector initialization (expected 4 elements, have 1)
2780                let number_of_components = match size {
2781                    crate::VectorSize::Bi => "xx",
2782                    crate::VectorSize::Tri => "xxx",
2783                    crate::VectorSize::Quad => "xxxx",
2784                };
2785                write!(self.out, "(")?;
2786                write_expression(self, value)?;
2787                write!(self.out, ").{number_of_components}")?
2788            }
2789            _ => {
2790                return Err(Error::Override);
2791            }
2792        }
2793
2794        Ok(())
2795    }
2796
2797    /// Helper method to write expressions
2798    ///
2799    /// # Notes
2800    /// Doesn't add any newlines or leading/trailing spaces
2801    pub(super) fn write_expr(
2802        &mut self,
2803        module: &Module,
2804        expr: Handle<crate::Expression>,
2805        func_ctx: &back::FunctionCtx<'_>,
2806    ) -> BackendResult {
2807        use crate::Expression;
2808
2809        // Handle the special semantics of vertex_index/instance_index
2810        let ff_input = if self.options.special_constants_binding.is_some() {
2811            func_ctx.is_fixed_function_input(expr, module)
2812        } else {
2813            None
2814        };
2815        let closing_bracket = match ff_input {
2816            Some(crate::BuiltIn::VertexIndex) => {
2817                write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
2818                ")"
2819            }
2820            Some(crate::BuiltIn::InstanceIndex) => {
2821                write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
2822                ")"
2823            }
2824            Some(crate::BuiltIn::NumWorkGroups) => {
2825                // Note: despite their names (`FIRST_VERTEX` and `FIRST_INSTANCE`),
2826                // in compute shaders the special constants contain the number
2827                // of workgroups, which we are using here.
2828                write!(
2829                    self.out,
2830                    "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
2831                )?;
2832                return Ok(());
2833            }
2834            _ => "",
2835        };
2836
2837        if let Some(name) = self.named_expressions.get(&expr) {
2838            write!(self.out, "{name}{closing_bracket}")?;
2839            return Ok(());
2840        }
2841
2842        let expression = &func_ctx.expressions[expr];
2843
2844        match *expression {
2845            Expression::Literal(_)
2846            | Expression::Constant(_)
2847            | Expression::ZeroValue(_)
2848            | Expression::Compose { .. }
2849            | Expression::Splat { .. } => {
2850                self.write_possibly_const_expression(
2851                    module,
2852                    expr,
2853                    func_ctx.expressions,
2854                    |writer, expr| writer.write_expr(module, expr, func_ctx),
2855                )?;
2856            }
2857            Expression::Override(_) => return Err(Error::Override),
2858            // Avoid undefined behaviour for addition, subtraction, and
2859            // multiplication of signed integers by casting operands to
2860            // unsigned, performing the operation, then casting the result back
2861            // to signed.
2862            // TODO(#7109): This relies on the asint()/asuint() functions which only work
2863            // for 32-bit types, so we must find another solution for different bit widths.
2864            Expression::Binary {
2865                op:
2866                    op @ crate::BinaryOperator::Add
2867                    | op @ crate::BinaryOperator::Subtract
2868                    | op @ crate::BinaryOperator::Multiply,
2869                left,
2870                right,
2871            } if matches!(
2872                func_ctx.resolve_type(expr, &module.types).scalar(),
2873                Some(Scalar::I32)
2874            ) =>
2875            {
2876                write!(self.out, "asint(asuint(",)?;
2877                self.write_expr(module, left, func_ctx)?;
2878                write!(self.out, ") {} asuint(", back::binary_operation_str(op))?;
2879                self.write_expr(module, right, func_ctx)?;
2880                write!(self.out, "))")?;
2881            }
2882            // All of the multiplication can be expressed as `mul`,
2883            // except vector * vector, which needs to use the "*" operator.
2884            Expression::Binary {
2885                op: crate::BinaryOperator::Multiply,
2886                left,
2887                right,
2888            } if func_ctx.resolve_type(left, &module.types).is_matrix()
2889                || func_ctx.resolve_type(right, &module.types).is_matrix() =>
2890            {
2891                // We intentionally flip the order of multiplication as our matrices are implicitly transposed.
2892                write!(self.out, "mul(")?;
2893                self.write_expr(module, right, func_ctx)?;
2894                write!(self.out, ", ")?;
2895                self.write_expr(module, left, func_ctx)?;
2896                write!(self.out, ")")?;
2897            }
2898
2899            // WGSL says that floating-point division by zero should return
2900            // infinity. Microsoft's Direct3D 11 functional specification
2901            // (https://microsoft.github.io/DirectX-Specs/d3d/archive/D3D11_3_FunctionalSpec.htm)
2902            // says:
2903            //
2904            //     Divide by 0 produces +/- INF, except 0/0 which results in NaN.
2905            //
2906            // which is what we want. The DXIL specification for the FDiv
2907            // instruction corroborates this:
2908            //
2909            // https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/DXIL.rst#fdiv
2910            Expression::Binary {
2911                op: crate::BinaryOperator::Divide,
2912                left,
2913                right,
2914            } if matches!(
2915                func_ctx.resolve_type(expr, &module.types).scalar_kind(),
2916                Some(ScalarKind::Sint | ScalarKind::Uint)
2917            ) =>
2918            {
2919                write!(self.out, "{DIV_FUNCTION}(")?;
2920                self.write_expr(module, left, func_ctx)?;
2921                write!(self.out, ", ")?;
2922                self.write_expr(module, right, func_ctx)?;
2923                write!(self.out, ")")?;
2924            }
2925
2926            Expression::Binary {
2927                op: crate::BinaryOperator::Modulo,
2928                left,
2929                right,
2930            } if matches!(
2931                func_ctx.resolve_type(expr, &module.types).scalar_kind(),
2932                Some(ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float)
2933            ) =>
2934            {
2935                write!(self.out, "{MOD_FUNCTION}(")?;
2936                self.write_expr(module, left, func_ctx)?;
2937                write!(self.out, ", ")?;
2938                self.write_expr(module, right, func_ctx)?;
2939                write!(self.out, ")")?;
2940            }
2941
2942            Expression::Binary { op, left, right } => {
2943                write!(self.out, "(")?;
2944                self.write_expr(module, left, func_ctx)?;
2945                write!(self.out, " {} ", back::binary_operation_str(op))?;
2946                self.write_expr(module, right, func_ctx)?;
2947                write!(self.out, ")")?;
2948            }
2949            Expression::Access { base, index } => {
2950                if let Some(crate::AddressSpace::Storage { .. }) =
2951                    func_ctx.resolve_type(expr, &module.types).pointer_space()
2952                {
2953                    // do nothing, the chain is written on `Load`/`Store`
2954                } else {
2955                    // We use the function __get_col_of_matCx2 here in cases
2956                    // where `base`s type resolves to a matCx2 and is part of a
2957                    // struct member with type of (possibly nested) array of matCx2's.
2958                    //
2959                    // Note that this only works for `Load`s and we handle
2960                    // `Store`s differently in `Statement::Store`.
2961                    if let Some(MatrixType {
2962                        columns,
2963                        rows: crate::VectorSize::Bi,
2964                        width: 4,
2965                    }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
2966                    {
2967                        write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
2968                        self.write_expr(module, base, func_ctx)?;
2969                        write!(self.out, ", ")?;
2970                        self.write_expr(module, index, func_ctx)?;
2971                        write!(self.out, ")")?;
2972                        return Ok(());
2973                    }
2974
2975                    let resolved = func_ctx.resolve_type(base, &module.types);
2976
2977                    let (indexing_binding_array, non_uniform_qualifier) = match *resolved {
2978                        TypeInner::BindingArray { .. } => {
2979                            let uniformity = &func_ctx.info[index].uniformity;
2980
2981                            (true, uniformity.non_uniform_result.is_some())
2982                        }
2983                        _ => (false, false),
2984                    };
2985
2986                    self.write_expr(module, base, func_ctx)?;
2987
2988                    let array_sampler_info = self.sampler_binding_array_info_from_expression(
2989                        module, func_ctx, base, resolved,
2990                    );
2991
2992                    if let Some(ref info) = array_sampler_info {
2993                        write!(self.out, "{}[", info.sampler_heap_name)?;
2994                    } else {
2995                        write!(self.out, "[")?;
2996                    }
2997
2998                    let needs_bound_check = self.options.restrict_indexing
2999                        && !indexing_binding_array
3000                        && match resolved.pointer_space() {
3001                            Some(
3002                                crate::AddressSpace::Function
3003                                | crate::AddressSpace::Private
3004                                | crate::AddressSpace::WorkGroup
3005                                | crate::AddressSpace::PushConstant,
3006                            )
3007                            | None => true,
3008                            Some(crate::AddressSpace::Uniform) => {
3009                                // check if BindTarget.restrict_indexing is set, this is used for dynamic buffers
3010                                let var_handle = self.fill_access_chain(module, base, func_ctx)?;
3011                                let bind_target = self
3012                                    .options
3013                                    .resolve_resource_binding(
3014                                        module.global_variables[var_handle]
3015                                            .binding
3016                                            .as_ref()
3017                                            .unwrap(),
3018                                    )
3019                                    .unwrap();
3020                                bind_target.restrict_indexing
3021                            }
3022                            Some(
3023                                crate::AddressSpace::Handle | crate::AddressSpace::Storage { .. },
3024                            ) => unreachable!(),
3025                        };
3026                    // Decide whether this index needs to be clamped to fall within range.
3027                    let restriction_needed = if needs_bound_check {
3028                        index::access_needs_check(
3029                            base,
3030                            index::GuardedIndex::Expression(index),
3031                            module,
3032                            func_ctx.expressions,
3033                            func_ctx.info,
3034                        )
3035                    } else {
3036                        None
3037                    };
3038                    if let Some(limit) = restriction_needed {
3039                        write!(self.out, "min(uint(")?;
3040                        self.write_expr(module, index, func_ctx)?;
3041                        write!(self.out, "), ")?;
3042                        match limit {
3043                            index::IndexableLength::Known(limit) => {
3044                                write!(self.out, "{}u", limit - 1)?;
3045                            }
3046                            index::IndexableLength::Dynamic => unreachable!(),
3047                        }
3048                        write!(self.out, ")")?;
3049                    } else {
3050                        if non_uniform_qualifier {
3051                            write!(self.out, "NonUniformResourceIndex(")?;
3052                        }
3053                        if let Some(ref info) = array_sampler_info {
3054                            write!(
3055                                self.out,
3056                                "{}[{} + ",
3057                                info.sampler_index_buffer_name, info.binding_array_base_index_name,
3058                            )?;
3059                        }
3060                        self.write_expr(module, index, func_ctx)?;
3061                        if array_sampler_info.is_some() {
3062                            write!(self.out, "]")?;
3063                        }
3064                        if non_uniform_qualifier {
3065                            write!(self.out, ")")?;
3066                        }
3067                    }
3068
3069                    write!(self.out, "]")?;
3070                }
3071            }
3072            Expression::AccessIndex { base, index } => {
3073                if let Some(crate::AddressSpace::Storage { .. }) =
3074                    func_ctx.resolve_type(expr, &module.types).pointer_space()
3075                {
3076                    // do nothing, the chain is written on `Load`/`Store`
3077                } else {
3078                    // We write the matrix column access in a special way since
3079                    // the type of `base` is our special __matCx2 struct.
3080                    if let Some(MatrixType {
3081                        rows: crate::VectorSize::Bi,
3082                        width: 4,
3083                        ..
3084                    }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3085                    {
3086                        self.write_expr(module, base, func_ctx)?;
3087                        write!(self.out, "._{index}")?;
3088                        return Ok(());
3089                    }
3090
3091                    let base_ty_res = &func_ctx.info[base].ty;
3092                    let mut resolved = base_ty_res.inner_with(&module.types);
3093                    let base_ty_handle = match *resolved {
3094                        TypeInner::Pointer { base, .. } => {
3095                            resolved = &module.types[base].inner;
3096                            Some(base)
3097                        }
3098                        _ => base_ty_res.handle(),
3099                    };
3100
3101                    // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
3102                    // See the module-level block comment in mod.rs for details.
3103                    //
3104                    // We handle matrix reconstruction here for Loads.
3105                    // Stores are handled directly by `Statement::Store`.
3106                    if let TypeInner::Struct { ref members, .. } = *resolved {
3107                        let member = &members[index as usize];
3108
3109                        match module.types[member.ty].inner {
3110                            TypeInner::Matrix {
3111                                rows: crate::VectorSize::Bi,
3112                                ..
3113                            } if member.binding.is_none() => {
3114                                let ty = base_ty_handle.unwrap();
3115                                self.write_wrapped_struct_matrix_get_function_name(
3116                                    WrappedStructMatrixAccess { ty, index },
3117                                )?;
3118                                write!(self.out, "(")?;
3119                                self.write_expr(module, base, func_ctx)?;
3120                                write!(self.out, ")")?;
3121                                return Ok(());
3122                            }
3123                            _ => {}
3124                        }
3125                    }
3126
3127                    let array_sampler_info = self.sampler_binding_array_info_from_expression(
3128                        module, func_ctx, base, resolved,
3129                    );
3130
3131                    if let Some(ref info) = array_sampler_info {
3132                        write!(
3133                            self.out,
3134                            "{}[{}",
3135                            info.sampler_heap_name, info.sampler_index_buffer_name
3136                        )?;
3137                    }
3138
3139                    self.write_expr(module, base, func_ctx)?;
3140
3141                    match *resolved {
3142                        // We specifically lift the ValuePointer to this case. While `[0]` is valid
3143                        // HLSL for any vector behind a value pointer, FXC completely miscompiles
3144                        // it and generates completely nonsensical DXBC.
3145                        //
3146                        // See https://github.com/gfx-rs/naga/issues/2095 for more details.
3147                        TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
3148                            // Write vector access as a swizzle
3149                            write!(self.out, ".{}", back::COMPONENTS[index as usize])?
3150                        }
3151                        TypeInner::Matrix { .. }
3152                        | TypeInner::Array { .. }
3153                        | TypeInner::BindingArray { .. } => {
3154                            if let Some(ref info) = array_sampler_info {
3155                                write!(
3156                                    self.out,
3157                                    "[{} + {index}]",
3158                                    info.binding_array_base_index_name
3159                                )?;
3160                            } else {
3161                                write!(self.out, "[{index}]")?;
3162                            }
3163                        }
3164                        TypeInner::Struct { .. } => {
3165                            // This will never panic in case the type is a `Struct`, this is not true
3166                            // for other types so we can only check while inside this match arm
3167                            let ty = base_ty_handle.unwrap();
3168
3169                            write!(
3170                                self.out,
3171                                ".{}",
3172                                &self.names[&NameKey::StructMember(ty, index)]
3173                            )?
3174                        }
3175                        ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
3176                    }
3177
3178                    if array_sampler_info.is_some() {
3179                        write!(self.out, "]")?;
3180                    }
3181                }
3182            }
3183            Expression::FunctionArgument(pos) => {
3184                let key = func_ctx.argument_key(pos);
3185                let name = &self.names[&key];
3186                write!(self.out, "{name}")?;
3187            }
3188            Expression::ImageSample {
3189                coordinate,
3190                image,
3191                sampler,
3192                clamp_to_edge: true,
3193                gather: None,
3194                array_index: None,
3195                offset: None,
3196                level: crate::SampleLevel::Zero,
3197                depth_ref: None,
3198            } => {
3199                write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
3200                self.write_expr(module, image, func_ctx)?;
3201                write!(self.out, ", ")?;
3202                self.write_expr(module, sampler, func_ctx)?;
3203                write!(self.out, ", ")?;
3204                self.write_expr(module, coordinate, func_ctx)?;
3205                write!(self.out, ")")?;
3206            }
3207            Expression::ImageSample {
3208                image,
3209                sampler,
3210                gather,
3211                coordinate,
3212                array_index,
3213                offset,
3214                level,
3215                depth_ref,
3216                clamp_to_edge,
3217            } => {
3218                if clamp_to_edge {
3219                    return Err(Error::Custom(
3220                        "ImageSample::clamp_to_edge should have been validated out".to_string(),
3221                    ));
3222                }
3223
3224                use crate::SampleLevel as Sl;
3225                const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
3226
3227                let (base_str, component_str) = match gather {
3228                    Some(component) => ("Gather", COMPONENTS[component as usize]),
3229                    None => ("Sample", ""),
3230                };
3231                let cmp_str = match depth_ref {
3232                    Some(_) => "Cmp",
3233                    None => "",
3234                };
3235                let level_str = match level {
3236                    Sl::Zero if gather.is_none() => "LevelZero",
3237                    Sl::Auto | Sl::Zero => "",
3238                    Sl::Exact(_) => "Level",
3239                    Sl::Bias(_) => "Bias",
3240                    Sl::Gradient { .. } => "Grad",
3241                };
3242
3243                self.write_expr(module, image, func_ctx)?;
3244                write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
3245                self.write_expr(module, sampler, func_ctx)?;
3246                write!(self.out, ", ")?;
3247                self.write_texture_coordinates(
3248                    "float",
3249                    coordinate,
3250                    array_index,
3251                    None,
3252                    module,
3253                    func_ctx,
3254                )?;
3255
3256                if let Some(depth_ref) = depth_ref {
3257                    write!(self.out, ", ")?;
3258                    self.write_expr(module, depth_ref, func_ctx)?;
3259                }
3260
3261                match level {
3262                    Sl::Auto | Sl::Zero => {}
3263                    Sl::Exact(expr) => {
3264                        write!(self.out, ", ")?;
3265                        self.write_expr(module, expr, func_ctx)?;
3266                    }
3267                    Sl::Bias(expr) => {
3268                        write!(self.out, ", ")?;
3269                        self.write_expr(module, expr, func_ctx)?;
3270                    }
3271                    Sl::Gradient { x, y } => {
3272                        write!(self.out, ", ")?;
3273                        self.write_expr(module, x, func_ctx)?;
3274                        write!(self.out, ", ")?;
3275                        self.write_expr(module, y, func_ctx)?;
3276                    }
3277                }
3278
3279                if let Some(offset) = offset {
3280                    write!(self.out, ", ")?;
3281                    write!(self.out, "int2(")?; // work around https://github.com/microsoft/DirectXShaderCompiler/issues/5082#issuecomment-1540147807
3282                    self.write_const_expression(module, offset, func_ctx.expressions)?;
3283                    write!(self.out, ")")?;
3284                }
3285
3286                write!(self.out, ")")?;
3287            }
3288            Expression::ImageQuery { image, query } => {
3289                // use wrapped image query function
3290                if let TypeInner::Image {
3291                    dim,
3292                    arrayed,
3293                    class,
3294                } = *func_ctx.resolve_type(image, &module.types)
3295                {
3296                    let wrapped_image_query = WrappedImageQuery {
3297                        dim,
3298                        arrayed,
3299                        class,
3300                        query: query.into(),
3301                    };
3302
3303                    self.write_wrapped_image_query_function_name(wrapped_image_query)?;
3304                    write!(self.out, "(")?;
3305                    // Image always first param
3306                    self.write_expr(module, image, func_ctx)?;
3307                    if let crate::ImageQuery::Size { level: Some(level) } = query {
3308                        write!(self.out, ", ")?;
3309                        self.write_expr(module, level, func_ctx)?;
3310                    }
3311                    write!(self.out, ")")?;
3312                }
3313            }
3314            Expression::ImageLoad {
3315                image,
3316                coordinate,
3317                array_index,
3318                sample,
3319                level,
3320            } => self.write_image_load(
3321                &module,
3322                expr,
3323                func_ctx,
3324                image,
3325                coordinate,
3326                array_index,
3327                sample,
3328                level,
3329            )?,
3330            Expression::GlobalVariable(handle) => {
3331                let global_variable = &module.global_variables[handle];
3332                let ty = &module.types[global_variable.ty].inner;
3333
3334                // In the case of binding arrays of samplers, we need to not write anything
3335                // as we are in the wrong position to fully write the expression.
3336                //
3337                // The entire writing is done by AccessIndex.
3338                let is_binding_array_of_samplers = match *ty {
3339                    TypeInner::BindingArray { base, .. } => {
3340                        let base_ty = &module.types[base].inner;
3341                        matches!(*base_ty, TypeInner::Sampler { .. })
3342                    }
3343                    _ => false,
3344                };
3345
3346                let is_storage_space =
3347                    matches!(global_variable.space, crate::AddressSpace::Storage { .. });
3348
3349                if !is_binding_array_of_samplers && !is_storage_space {
3350                    let name = &self.names[&NameKey::GlobalVariable(handle)];
3351                    write!(self.out, "{name}")?;
3352                }
3353            }
3354            Expression::LocalVariable(handle) => {
3355                write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
3356            }
3357            Expression::Load { pointer } => {
3358                match func_ctx
3359                    .resolve_type(pointer, &module.types)
3360                    .pointer_space()
3361                {
3362                    Some(crate::AddressSpace::Storage { .. }) => {
3363                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
3364                        let result_ty = func_ctx.info[expr].ty.clone();
3365                        self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
3366                    }
3367                    _ => {
3368                        let mut close_paren = false;
3369
3370                        // We cast the value loaded to a native HLSL floatCx2
3371                        // in cases where it is of type:
3372                        //  - __matCx2 or
3373                        //  - a (possibly nested) array of __matCx2's
3374                        if let Some(MatrixType {
3375                            rows: crate::VectorSize::Bi,
3376                            width: 4,
3377                            ..
3378                        }) = get_inner_matrix_of_struct_array_member(
3379                            module, pointer, func_ctx, false,
3380                        )
3381                        .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
3382                        {
3383                            let mut resolved = func_ctx.resolve_type(pointer, &module.types);
3384                            if let TypeInner::Pointer { base, .. } = *resolved {
3385                                resolved = &module.types[base].inner;
3386                            }
3387
3388                            write!(self.out, "((")?;
3389                            if let TypeInner::Array { base, size, .. } = *resolved {
3390                                self.write_type(module, base)?;
3391                                self.write_array_size(module, base, size)?;
3392                            } else {
3393                                self.write_value_type(module, resolved)?;
3394                            }
3395                            write!(self.out, ")")?;
3396                            close_paren = true;
3397                        }
3398
3399                        self.write_expr(module, pointer, func_ctx)?;
3400
3401                        if close_paren {
3402                            write!(self.out, ")")?;
3403                        }
3404                    }
3405                }
3406            }
3407            Expression::Unary { op, expr } => {
3408                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-operators#unary-operators
3409                let op_str = match op {
3410                    crate::UnaryOperator::Negate => {
3411                        match func_ctx.resolve_type(expr, &module.types).scalar() {
3412                            Some(Scalar::I32) => NEG_FUNCTION,
3413                            _ => "-",
3414                        }
3415                    }
3416                    crate::UnaryOperator::LogicalNot => "!",
3417                    crate::UnaryOperator::BitwiseNot => "~",
3418                };
3419                write!(self.out, "{op_str}(")?;
3420                self.write_expr(module, expr, func_ctx)?;
3421                write!(self.out, ")")?;
3422            }
3423            Expression::As {
3424                expr,
3425                kind,
3426                convert,
3427            } => {
3428                let inner = func_ctx.resolve_type(expr, &module.types);
3429                if inner.scalar_kind() == Some(ScalarKind::Float)
3430                    && (kind == ScalarKind::Sint || kind == ScalarKind::Uint)
3431                    && convert.is_some()
3432                {
3433                    // Use helper functions for float to int casts in order to
3434                    // avoid undefined behaviour when value is out of range for
3435                    // the target type.
3436                    let fun_name = match (kind, convert) {
3437                        (ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
3438                        (ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
3439                        (ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
3440                        (ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
3441                        _ => unreachable!(),
3442                    };
3443                    write!(self.out, "{fun_name}(")?;
3444                    self.write_expr(module, expr, func_ctx)?;
3445                    write!(self.out, ")")?;
3446                } else {
3447                    let close_paren = match convert {
3448                        Some(dst_width) => {
3449                            let scalar = Scalar {
3450                                kind,
3451                                width: dst_width,
3452                            };
3453                            match *inner {
3454                                TypeInner::Vector { size, .. } => {
3455                                    write!(
3456                                        self.out,
3457                                        "{}{}(",
3458                                        scalar.to_hlsl_str()?,
3459                                        common::vector_size_str(size)
3460                                    )?;
3461                                }
3462                                TypeInner::Scalar(_) => {
3463                                    write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
3464                                }
3465                                TypeInner::Matrix { columns, rows, .. } => {
3466                                    write!(
3467                                        self.out,
3468                                        "{}{}x{}(",
3469                                        scalar.to_hlsl_str()?,
3470                                        common::vector_size_str(columns),
3471                                        common::vector_size_str(rows)
3472                                    )?;
3473                                }
3474                                _ => {
3475                                    return Err(Error::Unimplemented(format!(
3476                                        "write_expr expression::as {inner:?}"
3477                                    )));
3478                                }
3479                            };
3480                            true
3481                        }
3482                        None => {
3483                            if inner.scalar_width() == Some(8) {
3484                                false
3485                            } else {
3486                                write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
3487                                true
3488                            }
3489                        }
3490                    };
3491                    self.write_expr(module, expr, func_ctx)?;
3492                    if close_paren {
3493                        write!(self.out, ")")?;
3494                    }
3495                }
3496            }
3497            Expression::Math {
3498                fun,
3499                arg,
3500                arg1,
3501                arg2,
3502                arg3,
3503            } => {
3504                use crate::MathFunction as Mf;
3505
3506                enum Function { // How a math function is lowered to HLSL text; built and then matched on right below.
3507                    Asincosh { is_sin: bool }, // asinh (`is_sin`) or acosh: `log(x + sqrt(x * x + 1.0))` / `... - 1.0`
3508                    Atanh, // `0.5 * log((1.0 + x) / (1.0 - x))`
3509                    Pack2x16float, // `(f32tof16(x[0]) | f32tof16(x[1]) << 16)`
3510                    Pack2x16snorm, // clamp to [-1, 1], scale by 32767, round, mask to 16 bits, combine
3511                    Pack2x16unorm, // clamp to [0, 1], scale by 65535, round, combine
3512                    Pack4x8snorm, // clamp to [-1, 1], scale by 127, round, mask to 8 bits, combine
3513                    Pack4x8unorm, // clamp to [0, 1], scale by 255, round, combine
3514                    Pack4xI8, // low 8 bits of each of the four components, combined and wrapped in `uint(...)`
3515                    Pack4xU8, // low 8 bits of each of the four components, combined
3516                    Pack4xI8Clamp, // like `Pack4xI8`, with the argument clamped to [-128, 127] first
3517                    Pack4xU8Clamp, // like `Pack4xU8`, with the argument clamped to [0, 255] first
3518                    Unpack2x16float, // `float2(f16tof32(x), f16tof32((x) >> 16))`
3519                    Unpack2x16snorm, // `int2` shifted left then right by 16, divided by 32767.0
3520                    Unpack2x16unorm, // low half masked with 0xFFFF, high half shifted down, divided by 65535.0
3521                    Unpack4x8snorm, // `int4` of left shifts, shifted right by 24, divided by 127.0
3522                    Unpack4x8unorm, // `float4` of shifted and masked bytes, divided by 255.0
3523                    Unpack4xI8, // `int4(x, x >> 8, x >> 16, x >> 24) << 24 >> 24`
3524                    Unpack4xU8, // same as `Unpack4xI8` but constructed as `uint4`
3525                    Dot4I8Packed, // `dot4add_i8packed` on SM >= 6.4, otherwise `dot` of two unpacked `int4`s
3526                    Dot4U8Packed, // `dot4add_u8packed` on SM >= 6.4, otherwise `dot` of two unpacked `uint4`s
3527                    QuantizeToF16, // `f16tof32(f32tof16(x))`
3528                    Regular(&'static str), // plain call `name(arg[, arg1[, arg2[, arg3]]])`
3529                    MissingIntOverload(&'static str), // `asint(name(asuint(x)))` for `Scalar::I32` arguments, plain call otherwise
3530                    MissingIntReturnType(&'static str), // `asint(name(x))` for `Scalar::I32` arguments, plain call otherwise
3531                    CountTrailingZeros, // `min(bit_width, firstbitlow(x))`, wrapped in `asint` for 32-bit non-`Uint` scalars
3532                    CountLeadingZeros, // `(bit_width - 1) - firstbithigh(x)`; non-`Uint` inputs yield 0 when `x < 0`
3533                }
3534
3535                let fun = match fun {
3536                    // comparison
3537                    Mf::Abs => match func_ctx.resolve_type(arg, &module.types).scalar() {
3538                        Some(Scalar::I32) => Function::Regular(ABS_FUNCTION),
3539                        _ => Function::Regular("abs"),
3540                    },
3541                    Mf::Min => Function::Regular("min"),
3542                    Mf::Max => Function::Regular("max"),
3543                    Mf::Clamp => Function::Regular("clamp"),
3544                    Mf::Saturate => Function::Regular("saturate"),
3545                    // trigonometry
3546                    Mf::Cos => Function::Regular("cos"),
3547                    Mf::Cosh => Function::Regular("cosh"),
3548                    Mf::Sin => Function::Regular("sin"),
3549                    Mf::Sinh => Function::Regular("sinh"),
3550                    Mf::Tan => Function::Regular("tan"),
3551                    Mf::Tanh => Function::Regular("tanh"),
3552                    Mf::Acos => Function::Regular("acos"),
3553                    Mf::Asin => Function::Regular("asin"),
3554                    Mf::Atan => Function::Regular("atan"),
3555                    Mf::Atan2 => Function::Regular("atan2"),
3556                    Mf::Asinh => Function::Asincosh { is_sin: true },
3557                    Mf::Acosh => Function::Asincosh { is_sin: false },
3558                    Mf::Atanh => Function::Atanh,
3559                    Mf::Radians => Function::Regular("radians"),
3560                    Mf::Degrees => Function::Regular("degrees"),
3561                    // decomposition
3562                    Mf::Ceil => Function::Regular("ceil"),
3563                    Mf::Floor => Function::Regular("floor"),
3564                    Mf::Round => Function::Regular("round"),
3565                    Mf::Fract => Function::Regular("frac"),
3566                    Mf::Trunc => Function::Regular("trunc"),
3567                    Mf::Modf => Function::Regular(MODF_FUNCTION),
3568                    Mf::Frexp => Function::Regular(FREXP_FUNCTION),
3569                    Mf::Ldexp => Function::Regular("ldexp"),
3570                    // exponent
3571                    Mf::Exp => Function::Regular("exp"),
3572                    Mf::Exp2 => Function::Regular("exp2"),
3573                    Mf::Log => Function::Regular("log"),
3574                    Mf::Log2 => Function::Regular("log2"),
3575                    Mf::Pow => Function::Regular("pow"),
3576                    // geometry
3577                    Mf::Dot => Function::Regular("dot"),
3578                    Mf::Dot4I8Packed => Function::Dot4I8Packed,
3579                    Mf::Dot4U8Packed => Function::Dot4U8Packed,
3580                    //Mf::Outer => ,
3581                    Mf::Cross => Function::Regular("cross"),
3582                    Mf::Distance => Function::Regular("distance"),
3583                    Mf::Length => Function::Regular("length"),
3584                    Mf::Normalize => Function::Regular("normalize"),
3585                    Mf::FaceForward => Function::Regular("faceforward"),
3586                    Mf::Reflect => Function::Regular("reflect"),
3587                    Mf::Refract => Function::Regular("refract"),
3588                    // computational
3589                    Mf::Sign => Function::Regular("sign"),
3590                    Mf::Fma => Function::Regular("mad"),
3591                    Mf::Mix => Function::Regular("lerp"),
3592                    Mf::Step => Function::Regular("step"),
3593                    Mf::SmoothStep => Function::Regular("smoothstep"),
3594                    Mf::Sqrt => Function::Regular("sqrt"),
3595                    Mf::InverseSqrt => Function::Regular("rsqrt"),
3596                    //Mf::Inverse =>,
3597                    Mf::Transpose => Function::Regular("transpose"),
3598                    Mf::Determinant => Function::Regular("determinant"),
3599                    Mf::QuantizeToF16 => Function::QuantizeToF16,
3600                    // bits
3601                    Mf::CountTrailingZeros => Function::CountTrailingZeros,
3602                    Mf::CountLeadingZeros => Function::CountLeadingZeros,
3603                    Mf::CountOneBits => Function::MissingIntOverload("countbits"),
3604                    Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
3605                    Mf::FirstTrailingBit => Function::MissingIntReturnType("firstbitlow"),
3606                    Mf::FirstLeadingBit => Function::MissingIntReturnType("firstbithigh"),
3607                    Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION),
3608                    Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION),
3609                    // Data Packing
3610                    Mf::Pack2x16float => Function::Pack2x16float,
3611                    Mf::Pack2x16snorm => Function::Pack2x16snorm,
3612                    Mf::Pack2x16unorm => Function::Pack2x16unorm,
3613                    Mf::Pack4x8snorm => Function::Pack4x8snorm,
3614                    Mf::Pack4x8unorm => Function::Pack4x8unorm,
3615                    Mf::Pack4xI8 => Function::Pack4xI8,
3616                    Mf::Pack4xU8 => Function::Pack4xU8,
3617                    Mf::Pack4xI8Clamp => Function::Pack4xI8Clamp,
3618                    Mf::Pack4xU8Clamp => Function::Pack4xU8Clamp,
3619                    // Data Unpacking
3620                    Mf::Unpack2x16float => Function::Unpack2x16float,
3621                    Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
3622                    Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
3623                    Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
3624                    Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
3625                    Mf::Unpack4xI8 => Function::Unpack4xI8,
3626                    Mf::Unpack4xU8 => Function::Unpack4xU8,
3627                    _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
3628                };
3629
3630                match fun {
3631                    Function::Asincosh { is_sin } => {
3632                        write!(self.out, "log(")?;
3633                        self.write_expr(module, arg, func_ctx)?;
3634                        write!(self.out, " + sqrt(")?;
3635                        self.write_expr(module, arg, func_ctx)?;
3636                        write!(self.out, " * ")?;
3637                        self.write_expr(module, arg, func_ctx)?;
3638                        match is_sin {
3639                            true => write!(self.out, " + 1.0))")?,
3640                            false => write!(self.out, " - 1.0))")?,
3641                        }
3642                    }
3643                    Function::Atanh => {
3644                        write!(self.out, "0.5 * log((1.0 + ")?;
3645                        self.write_expr(module, arg, func_ctx)?;
3646                        write!(self.out, ") / (1.0 - ")?;
3647                        self.write_expr(module, arg, func_ctx)?;
3648                        write!(self.out, "))")?;
3649                    }
3650                    Function::Pack2x16float => {
3651                        write!(self.out, "(f32tof16(")?;
3652                        self.write_expr(module, arg, func_ctx)?;
3653                        write!(self.out, "[0]) | f32tof16(")?;
3654                        self.write_expr(module, arg, func_ctx)?;
3655                        write!(self.out, "[1]) << 16)")?;
3656                    }
3657                    Function::Pack2x16snorm => {
3658                        let scale = 32767;
3659
3660                        write!(self.out, "uint((int(round(clamp(")?;
3661                        self.write_expr(module, arg, func_ctx)?;
3662                        write!(
3663                            self.out,
3664                            "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
3665                        )?;
3666                        self.write_expr(module, arg, func_ctx)?;
3667                        write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
3668                    }
3669                    Function::Pack2x16unorm => {
3670                        let scale = 65535;
3671
3672                        write!(self.out, "(uint(round(clamp(")?;
3673                        self.write_expr(module, arg, func_ctx)?;
3674                        write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3675                        self.write_expr(module, arg, func_ctx)?;
3676                        write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
3677                    }
3678                    Function::Pack4x8snorm => {
3679                        let scale = 127;
3680
3681                        write!(self.out, "uint((int(round(clamp(")?;
3682                        self.write_expr(module, arg, func_ctx)?;
3683                        write!(
3684                            self.out,
3685                            "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
3686                        )?;
3687                        self.write_expr(module, arg, func_ctx)?;
3688                        write!(
3689                            self.out,
3690                            "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
3691                        )?;
3692                        self.write_expr(module, arg, func_ctx)?;
3693                        write!(
3694                            self.out,
3695                            "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
3696                        )?;
3697                        self.write_expr(module, arg, func_ctx)?;
3698                        write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
3699                    }
3700                    Function::Pack4x8unorm => {
3701                        let scale = 255;
3702
3703                        write!(self.out, "(uint(round(clamp(")?;
3704                        self.write_expr(module, arg, func_ctx)?;
3705                        write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3706                        self.write_expr(module, arg, func_ctx)?;
3707                        write!(
3708                            self.out,
3709                            "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
3710                        )?;
3711                        self.write_expr(module, arg, func_ctx)?;
3712                        write!(
3713                            self.out,
3714                            "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
3715                        )?;
3716                        self.write_expr(module, arg, func_ctx)?;
3717                        write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
3718                    }
3719                    fun @ (Function::Pack4xI8
3720                    | Function::Pack4xU8
3721                    | Function::Pack4xI8Clamp
3722                    | Function::Pack4xU8Clamp) => {
3723                        let was_signed =
3724                            matches!(fun, Function::Pack4xI8 | Function::Pack4xI8Clamp);
3725                        let clamp_bounds = match fun {
3726                            Function::Pack4xI8Clamp => Some(("-128", "127")),
3727                            Function::Pack4xU8Clamp => Some(("0", "255")),
3728                            _ => None,
3729                        };
3730                        if was_signed {
3731                            write!(self.out, "uint(")?;
3732                        }
3733                        let write_arg = |this: &mut Self| -> BackendResult {
3734                            if let Some((min, max)) = clamp_bounds {
3735                                write!(this.out, "clamp(")?;
3736                                this.write_expr(module, arg, func_ctx)?;
3737                                write!(this.out, ", {min}, {max})")?;
3738                            } else {
3739                                this.write_expr(module, arg, func_ctx)?;
3740                            }
3741                            Ok(())
3742                        };
3743                        write!(self.out, "(")?;
3744                        write_arg(self)?;
3745                        write!(self.out, "[0] & 0xFF) | ((")?;
3746                        write_arg(self)?;
3747                        write!(self.out, "[1] & 0xFF) << 8) | ((")?;
3748                        write_arg(self)?;
3749                        write!(self.out, "[2] & 0xFF) << 16) | ((")?;
3750                        write_arg(self)?;
3751                        write!(self.out, "[3] & 0xFF) << 24)")?;
3752                        if was_signed {
3753                            write!(self.out, ")")?;
3754                        }
3755                    }
3756
3757                    Function::Unpack2x16float => {
3758                        write!(self.out, "float2(f16tof32(")?;
3759                        self.write_expr(module, arg, func_ctx)?;
3760                        write!(self.out, "), f16tof32((")?;
3761                        self.write_expr(module, arg, func_ctx)?;
3762                        write!(self.out, ") >> 16))")?;
3763                    }
3764                    Function::Unpack2x16snorm => {
3765                        let scale = 32767;
3766
3767                        write!(self.out, "(float2(int2(")?;
3768                        self.write_expr(module, arg, func_ctx)?;
3769                        write!(self.out, " << 16, ")?;
3770                        self.write_expr(module, arg, func_ctx)?;
3771                        write!(self.out, ") >> 16) / {scale}.0)")?;
3772                    }
3773                    Function::Unpack2x16unorm => {
3774                        let scale = 65535;
3775
3776                        write!(self.out, "(float2(")?;
3777                        self.write_expr(module, arg, func_ctx)?;
3778                        write!(self.out, " & 0xFFFF, ")?;
3779                        self.write_expr(module, arg, func_ctx)?;
3780                        write!(self.out, " >> 16) / {scale}.0)")?;
3781                    }
3782                    Function::Unpack4x8snorm => {
3783                        let scale = 127;
3784
3785                        write!(self.out, "(float4(int4(")?;
3786                        self.write_expr(module, arg, func_ctx)?;
3787                        write!(self.out, " << 24, ")?;
3788                        self.write_expr(module, arg, func_ctx)?;
3789                        write!(self.out, " << 16, ")?;
3790                        self.write_expr(module, arg, func_ctx)?;
3791                        write!(self.out, " << 8, ")?;
3792                        self.write_expr(module, arg, func_ctx)?;
3793                        write!(self.out, ") >> 24) / {scale}.0)")?;
3794                    }
3795                    Function::Unpack4x8unorm => {
3796                        let scale = 255;
3797
3798                        write!(self.out, "(float4(")?;
3799                        self.write_expr(module, arg, func_ctx)?;
3800                        write!(self.out, " & 0xFF, ")?;
3801                        self.write_expr(module, arg, func_ctx)?;
3802                        write!(self.out, " >> 8 & 0xFF, ")?;
3803                        self.write_expr(module, arg, func_ctx)?;
3804                        write!(self.out, " >> 16 & 0xFF, ")?;
3805                        self.write_expr(module, arg, func_ctx)?;
3806                        write!(self.out, " >> 24) / {scale}.0)")?;
3807                    }
3808                    fun @ (Function::Unpack4xI8 | Function::Unpack4xU8) => {
3809                        write!(self.out, "(")?;
3810                        if matches!(fun, Function::Unpack4xU8) {
3811                            write!(self.out, "u")?;
3812                        }
3813                        write!(self.out, "int4(")?;
3814                        self.write_expr(module, arg, func_ctx)?;
3815                        write!(self.out, ", ")?;
3816                        self.write_expr(module, arg, func_ctx)?;
3817                        write!(self.out, " >> 8, ")?;
3818                        self.write_expr(module, arg, func_ctx)?;
3819                        write!(self.out, " >> 16, ")?;
3820                        self.write_expr(module, arg, func_ctx)?;
3821                        write!(self.out, " >> 24) << 24 >> 24)")?;
3822                    }
3823                    fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
3824                        let arg1 = arg1.unwrap();
3825
3826                        if self.options.shader_model >= ShaderModel::V6_4 {
3827                            // Intrinsics `dot4add_{i, u}8packed` are available in SM 6.4 and later.
3828                            let function_name = match fun {
3829                                Function::Dot4I8Packed => "dot4add_i8packed",
3830                                Function::Dot4U8Packed => "dot4add_u8packed",
3831                                _ => unreachable!(),
3832                            };
3833                            write!(self.out, "{function_name}(")?;
3834                            self.write_expr(module, arg, func_ctx)?;
3835                            write!(self.out, ", ")?;
3836                            self.write_expr(module, arg1, func_ctx)?;
3837                            write!(self.out, ", 0)")?;
3838                        } else {
3839                            // Fall back to a polyfill, as the `dot4add_{i, u}8packed` intrinsics are not available before SM 6.4.
3840                            write!(self.out, "dot(")?;
3841
3842                            if matches!(fun, Function::Dot4U8Packed) {
3843                                write!(self.out, "u")?;
3844                            }
3845                            write!(self.out, "int4(")?;
3846                            self.write_expr(module, arg, func_ctx)?;
3847                            write!(self.out, ", ")?;
3848                            self.write_expr(module, arg, func_ctx)?;
3849                            write!(self.out, " >> 8, ")?;
3850                            self.write_expr(module, arg, func_ctx)?;
3851                            write!(self.out, " >> 16, ")?;
3852                            self.write_expr(module, arg, func_ctx)?;
3853                            write!(self.out, " >> 24) << 24 >> 24, ")?;
3854
3855                            if matches!(fun, Function::Dot4U8Packed) {
3856                                write!(self.out, "u")?;
3857                            }
3858                            write!(self.out, "int4(")?;
3859                            self.write_expr(module, arg1, func_ctx)?;
3860                            write!(self.out, ", ")?;
3861                            self.write_expr(module, arg1, func_ctx)?;
3862                            write!(self.out, " >> 8, ")?;
3863                            self.write_expr(module, arg1, func_ctx)?;
3864                            write!(self.out, " >> 16, ")?;
3865                            self.write_expr(module, arg1, func_ctx)?;
3866                            write!(self.out, " >> 24) << 24 >> 24)")?;
3867                        }
3868                    }
3869                    Function::QuantizeToF16 => {
3870                        write!(self.out, "f16tof32(f32tof16(")?;
3871                        self.write_expr(module, arg, func_ctx)?;
3872                        write!(self.out, "))")?;
3873                    }
3874                    Function::Regular(fun_name) => {
3875                        write!(self.out, "{fun_name}(")?;
3876                        self.write_expr(module, arg, func_ctx)?;
3877                        if let Some(arg) = arg1 {
3878                            write!(self.out, ", ")?;
3879                            self.write_expr(module, arg, func_ctx)?;
3880                        }
3881                        if let Some(arg) = arg2 {
3882                            write!(self.out, ", ")?;
3883                            self.write_expr(module, arg, func_ctx)?;
3884                        }
3885                        if let Some(arg) = arg3 {
3886                            write!(self.out, ", ")?;
3887                            self.write_expr(module, arg, func_ctx)?;
3888                        }
3889                        write!(self.out, ")")?
3890                    }
3891                    // These overloads are only missing on FXC, so this is only needed for 32bit types,
3892                    // as non-32bit types are DXC only.
3893                    Function::MissingIntOverload(fun_name) => {
3894                        let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
3895                        if let Some(Scalar::I32) = scalar_kind {
3896                            write!(self.out, "asint({fun_name}(asuint(")?;
3897                            self.write_expr(module, arg, func_ctx)?;
3898                            write!(self.out, ")))")?;
3899                        } else {
3900                            write!(self.out, "{fun_name}(")?;
3901                            self.write_expr(module, arg, func_ctx)?;
3902                            write!(self.out, ")")?;
3903                        }
3904                    }
3905                    // These overloads are only missing on FXC, so this is only needed for 32bit types,
3906                    // as non-32bit types are DXC only.
3907                    Function::MissingIntReturnType(fun_name) => {
3908                        let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
3909                        if let Some(Scalar::I32) = scalar_kind {
3910                            write!(self.out, "asint({fun_name}(")?;
3911                            self.write_expr(module, arg, func_ctx)?;
3912                            write!(self.out, "))")?;
3913                        } else {
3914                            write!(self.out, "{fun_name}(")?;
3915                            self.write_expr(module, arg, func_ctx)?;
3916                            write!(self.out, ")")?;
3917                        }
3918                    }
3919                    Function::CountTrailingZeros => {
3920                        match *func_ctx.resolve_type(arg, &module.types) {
3921                            TypeInner::Vector { size, scalar } => {
3922                                let s = match size {
3923                                    crate::VectorSize::Bi => ".xx",
3924                                    crate::VectorSize::Tri => ".xxx",
3925                                    crate::VectorSize::Quad => ".xxxx",
3926                                };
3927
3928                                let scalar_width_bits = scalar.width * 8;
3929
3930                                if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
3931                                    write!(
3932                                        self.out,
3933                                        "min(({scalar_width_bits}u){s}, firstbitlow("
3934                                    )?;
3935                                    self.write_expr(module, arg, func_ctx)?;
3936                                    write!(self.out, "))")?;
3937                                } else {
3938                                    // This is only needed for the FXC path, on 32bit signed integers.
3939                                    write!(
3940                                        self.out,
3941                                        "asint(min(({scalar_width_bits}u){s}, firstbitlow("
3942                                    )?;
3943                                    self.write_expr(module, arg, func_ctx)?;
3944                                    write!(self.out, ")))")?;
3945                                }
3946                            }
3947                            TypeInner::Scalar(scalar) => {
3948                                let scalar_width_bits = scalar.width * 8;
3949
3950                                if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
3951                                    write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?;
3952                                    self.write_expr(module, arg, func_ctx)?;
3953                                    write!(self.out, "))")?;
3954                                } else {
3955                                    // This is only needed for the FXC path, on 32bit signed integers.
3956                                    write!(
3957                                        self.out,
3958                                        "asint(min({scalar_width_bits}u, firstbitlow("
3959                                    )?;
3960                                    self.write_expr(module, arg, func_ctx)?;
3961                                    write!(self.out, ")))")?;
3962                                }
3963                            }
3964                            _ => unreachable!(),
3965                        }
3966
3967                        return Ok(());
3968                    }
3969                    Function::CountLeadingZeros => {
3970                        match *func_ctx.resolve_type(arg, &module.types) {
3971                            TypeInner::Vector { size, scalar } => {
3972                                let s = match size {
3973                                    crate::VectorSize::Bi => ".xx",
3974                                    crate::VectorSize::Tri => ".xxx",
3975                                    crate::VectorSize::Quad => ".xxxx",
3976                                };
3977
3978                                // scalar width in bits, minus 1
3979                                let constant = scalar.width * 8 - 1;
3980
3981                                if scalar.kind == ScalarKind::Uint {
3982                                    write!(self.out, "(({constant}u){s} - firstbithigh(")?;
3983                                    self.write_expr(module, arg, func_ctx)?;
3984                                    write!(self.out, "))")?;
3985                                } else {
3986                                    let conversion_func = match scalar.width {
3987                                        4 => "asint",
3988                                        _ => "",
3989                                    };
3990                                    write!(self.out, "(")?;
3991                                    self.write_expr(module, arg, func_ctx)?;
3992                                    write!(
3993                                        self.out,
3994                                        " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh("
3995                                    )?;
3996                                    self.write_expr(module, arg, func_ctx)?;
3997                                    write!(self.out, ")))")?;
3998                                }
3999                            }
4000                            TypeInner::Scalar(scalar) => {
4001                                // scalar width in bits, minus 1
4002                                let constant = scalar.width * 8 - 1;
4003
4004                                if let ScalarKind::Uint = scalar.kind {
4005                                    write!(self.out, "({constant}u - firstbithigh(")?;
4006                                    self.write_expr(module, arg, func_ctx)?;
4007                                    write!(self.out, "))")?;
4008                                } else {
4009                                    let conversion_func = match scalar.width {
4010                                        4 => "asint",
4011                                        _ => "",
4012                                    };
4013                                    write!(self.out, "(")?;
4014                                    self.write_expr(module, arg, func_ctx)?;
4015                                    write!(
4016                                        self.out,
4017                                        " < 0 ? 0 : {constant} - {conversion_func}(firstbithigh("
4018                                    )?;
4019                                    self.write_expr(module, arg, func_ctx)?;
4020                                    write!(self.out, ")))")?;
4021                                }
4022                            }
4023                            _ => unreachable!(),
4024                        }
4025
4026                        return Ok(());
4027                    }
4028                }
4029            }
4030            Expression::Swizzle {
4031                size,
4032                vector,
4033                pattern,
4034            } => {
4035                self.write_expr(module, vector, func_ctx)?;
4036                write!(self.out, ".")?;
4037                for &sc in pattern[..size as usize].iter() {
4038                    self.out.write_char(back::COMPONENTS[sc as usize])?;
4039                }
4040            }
4041            Expression::ArrayLength(expr) => {
4042                let var_handle = match func_ctx.expressions[expr] {
4043                    Expression::AccessIndex { base, index: _ } => {
4044                        match func_ctx.expressions[base] {
4045                            Expression::GlobalVariable(handle) => handle,
4046                            _ => unreachable!(),
4047                        }
4048                    }
4049                    Expression::GlobalVariable(handle) => handle,
4050                    _ => unreachable!(),
4051                };
4052
4053                let var = &module.global_variables[var_handle];
4054                let (offset, stride) = match module.types[var.ty].inner {
4055                    TypeInner::Array { stride, .. } => (0, stride),
4056                    TypeInner::Struct { ref members, .. } => {
4057                        let last = members.last().unwrap();
4058                        let stride = match module.types[last.ty].inner {
4059                            TypeInner::Array { stride, .. } => stride,
4060                            _ => unreachable!(),
4061                        };
4062                        (last.offset, stride)
4063                    }
4064                    _ => unreachable!(),
4065                };
4066
4067                let storage_access = match var.space {
4068                    crate::AddressSpace::Storage { access } => access,
4069                    _ => crate::StorageAccess::default(),
4070                };
4071                let wrapped_array_length = WrappedArrayLength {
4072                    writable: storage_access.contains(crate::StorageAccess::STORE),
4073                };
4074
4075                write!(self.out, "((")?;
4076                self.write_wrapped_array_length_function_name(wrapped_array_length)?;
4077                let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4078                write!(self.out, "({var_name}) - {offset}) / {stride})")?
4079            }
4080            Expression::Derivative { axis, ctrl, expr } => {
4081                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
4082                if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
4083                    let tail = match ctrl {
4084                        Ctrl::Coarse => "coarse",
4085                        Ctrl::Fine => "fine",
4086                        Ctrl::None => unreachable!(),
4087                    };
4088                    write!(self.out, "abs(ddx_{tail}(")?;
4089                    self.write_expr(module, expr, func_ctx)?;
4090                    write!(self.out, ")) + abs(ddy_{tail}(")?;
4091                    self.write_expr(module, expr, func_ctx)?;
4092                    write!(self.out, "))")?
4093                } else {
4094                    let fun_str = match (axis, ctrl) {
4095                        (Axis::X, Ctrl::Coarse) => "ddx_coarse",
4096                        (Axis::X, Ctrl::Fine) => "ddx_fine",
4097                        (Axis::X, Ctrl::None) => "ddx",
4098                        (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
4099                        (Axis::Y, Ctrl::Fine) => "ddy_fine",
4100                        (Axis::Y, Ctrl::None) => "ddy",
4101                        (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
4102                        (Axis::Width, Ctrl::None) => "fwidth",
4103                    };
4104                    write!(self.out, "{fun_str}(")?;
4105                    self.write_expr(module, expr, func_ctx)?;
4106                    write!(self.out, ")")?
4107                }
4108            }
4109            Expression::Relational { fun, argument } => {
4110                use crate::RelationalFunction as Rf;
4111
4112                let fun_str = match fun {
4113                    Rf::All => "all",
4114                    Rf::Any => "any",
4115                    Rf::IsNan => "isnan",
4116                    Rf::IsInf => "isinf",
4117                };
4118                write!(self.out, "{fun_str}(")?;
4119                self.write_expr(module, argument, func_ctx)?;
4120                write!(self.out, ")")?
4121            }
4122            Expression::Select {
4123                condition,
4124                accept,
4125                reject,
4126            } => {
4127                write!(self.out, "(")?;
4128                self.write_expr(module, condition, func_ctx)?;
4129                write!(self.out, " ? ")?;
4130                self.write_expr(module, accept, func_ctx)?;
4131                write!(self.out, " : ")?;
4132                self.write_expr(module, reject, func_ctx)?;
4133                write!(self.out, ")")?
4134            }
4135            Expression::RayQueryGetIntersection { query, committed } => {
4136                if committed {
4137                    write!(self.out, "GetCommittedIntersection(")?;
4138                    self.write_expr(module, query, func_ctx)?;
4139                    write!(self.out, ")")?;
4140                } else {
4141                    write!(self.out, "GetCandidateIntersection(")?;
4142                    self.write_expr(module, query, func_ctx)?;
4143                    write!(self.out, ")")?;
4144                }
4145            }
4146            // Not supported yet
4147            Expression::RayQueryVertexPositions { .. } => unreachable!(),
4148            // Nothing to do here, since call expression already cached
4149            Expression::CallResult(_)
4150            | Expression::AtomicResult { .. }
4151            | Expression::WorkGroupUniformLoadResult { .. }
4152            | Expression::RayQueryProceedResult
4153            | Expression::SubgroupBallotResult
4154            | Expression::SubgroupOperationResult { .. } => {}
4155        }
4156
4157        if !closing_bracket.is_empty() {
4158            write!(self.out, "{closing_bracket}")?;
4159        }
4160        Ok(())
4161    }
4162
4163    #[allow(clippy::too_many_arguments)]
4164    fn write_image_load(
4165        &mut self,
4166        module: &&Module,
4167        expr: Handle<crate::Expression>,
4168        func_ctx: &back::FunctionCtx,
4169        image: Handle<crate::Expression>,
4170        coordinate: Handle<crate::Expression>,
4171        array_index: Option<Handle<crate::Expression>>,
4172        sample: Option<Handle<crate::Expression>>,
4173        level: Option<Handle<crate::Expression>>,
4174    ) -> Result<(), Error> {
4175        let mut wrapping_type = None;
4176        match *func_ctx.resolve_type(image, &module.types) {
4177            TypeInner::Image {
4178                class: crate::ImageClass::Storage { format, .. },
4179                ..
4180            } => {
4181                if format.single_component() {
4182                    wrapping_type = Some(Scalar::from(format));
4183                }
4184            }
4185            _ => {}
4186        }
4187        if let Some(scalar) = wrapping_type {
4188            write!(
4189                self.out,
4190                "{}{}(",
4191                help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
4192                scalar.to_hlsl_str()?
4193            )?;
4194        }
4195        // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-load
4196        self.write_expr(module, image, func_ctx)?;
4197        write!(self.out, ".Load(")?;
4198
4199        self.write_texture_coordinates("int", coordinate, array_index, level, module, func_ctx)?;
4200
4201        if let Some(sample) = sample {
4202            write!(self.out, ", ")?;
4203            self.write_expr(module, sample, func_ctx)?;
4204        }
4205
4206        // close bracket for Load function
4207        write!(self.out, ")")?;
4208
4209        if wrapping_type.is_some() {
4210            write!(self.out, ")")?;
4211        }
4212
4213        // return x component if return type is scalar
4214        if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
4215            write!(self.out, ".x")?;
4216        }
4217        Ok(())
4218    }
4219
4220    /// Find the [`BindingArraySamplerInfo`] from an expression so that such an access
4221    /// can be generated later.
4222    fn sampler_binding_array_info_from_expression(
4223        &mut self,
4224        module: &Module,
4225        func_ctx: &back::FunctionCtx<'_>,
4226        base: Handle<crate::Expression>,
4227        resolved: &TypeInner,
4228    ) -> Option<BindingArraySamplerInfo> {
4229        if let TypeInner::BindingArray {
4230            base: base_ty_handle,
4231            ..
4232        } = *resolved
4233        {
4234            let base_ty = &module.types[base_ty_handle].inner;
4235            if let TypeInner::Sampler { comparison, .. } = *base_ty {
4236                let base = &func_ctx.expressions[base];
4237
4238                if let crate::Expression::GlobalVariable(handle) = *base {
4239                    let variable = &module.global_variables[handle];
4240
4241                    let sampler_heap_name = match comparison {
4242                        true => COMPARISON_SAMPLER_HEAP_VAR,
4243                        false => SAMPLER_HEAP_VAR,
4244                    };
4245
4246                    return Some(BindingArraySamplerInfo {
4247                        sampler_heap_name,
4248                        sampler_index_buffer_name: self
4249                            .wrapped
4250                            .sampler_index_buffers
4251                            .get(&super::SamplerIndexBufferKey {
4252                                group: variable.binding.unwrap().group,
4253                            })
4254                            .unwrap()
4255                            .clone(),
4256                        binding_array_base_index_name: self.names[&NameKey::GlobalVariable(handle)]
4257                            .clone(),
4258                    });
4259                }
4260            }
4261        }
4262
4263        None
4264    }
4265
4266    fn write_named_expr(
4267        &mut self,
4268        module: &Module,
4269        handle: Handle<crate::Expression>,
4270        name: String,
4271        // The expression which is being named.
4272        // Generally, this is the same as handle, except in WorkGroupUniformLoad
4273        named: Handle<crate::Expression>,
4274        ctx: &back::FunctionCtx,
4275    ) -> BackendResult {
4276        match ctx.info[named].ty {
4277            proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
4278                TypeInner::Struct { .. } => {
4279                    let ty_name = &self.names[&NameKey::Type(ty_handle)];
4280                    write!(self.out, "{ty_name}")?;
4281                }
4282                _ => {
4283                    self.write_type(module, ty_handle)?;
4284                }
4285            },
4286            proc::TypeResolution::Value(ref inner) => {
4287                self.write_value_type(module, inner)?;
4288            }
4289        }
4290
4291        let resolved = ctx.resolve_type(named, &module.types);
4292
4293        write!(self.out, " {name}")?;
4294        // If rhs is a array type, we should write array size
4295        if let TypeInner::Array { base, size, .. } = *resolved {
4296            self.write_array_size(module, base, size)?;
4297        }
4298        write!(self.out, " = ")?;
4299        self.write_expr(module, handle, ctx)?;
4300        writeln!(self.out, ";")?;
4301        self.named_expressions.insert(named, name);
4302
4303        Ok(())
4304    }
4305
4306    /// Helper function that write default zero initialization
4307    pub(super) fn write_default_init(
4308        &mut self,
4309        module: &Module,
4310        ty: Handle<crate::Type>,
4311    ) -> BackendResult {
4312        write!(self.out, "(")?;
4313        self.write_type(module, ty)?;
4314        if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
4315            self.write_array_size(module, base, size)?;
4316        }
4317        write!(self.out, ")0")?;
4318        Ok(())
4319    }
4320
4321    fn write_control_barrier(
4322        &mut self,
4323        barrier: crate::Barrier,
4324        level: back::Level,
4325    ) -> BackendResult {
4326        if barrier.contains(crate::Barrier::STORAGE) {
4327            writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4328        }
4329        if barrier.contains(crate::Barrier::WORK_GROUP) {
4330            writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
4331        }
4332        if barrier.contains(crate::Barrier::SUB_GROUP) {
4333            // Does not exist in DirectX
4334        }
4335        if barrier.contains(crate::Barrier::TEXTURE) {
4336            writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4337        }
4338        Ok(())
4339    }
4340
4341    fn write_memory_barrier(
4342        &mut self,
4343        barrier: crate::Barrier,
4344        level: back::Level,
4345    ) -> BackendResult {
4346        if barrier.contains(crate::Barrier::STORAGE) {
4347            writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4348        }
4349        if barrier.contains(crate::Barrier::WORK_GROUP) {
4350            writeln!(self.out, "{level}GroupMemoryBarrier();")?;
4351        }
4352        if barrier.contains(crate::Barrier::SUB_GROUP) {
4353            // Does not exist in DirectX
4354        }
4355        if barrier.contains(crate::Barrier::TEXTURE) {
4356            writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4357        }
4358        Ok(())
4359    }
4360
4361    /// Helper to emit the shared tail of an HLSL atomic call (arguments, value, result)
4362    fn emit_hlsl_atomic_tail(
4363        &mut self,
4364        module: &Module,
4365        func_ctx: &back::FunctionCtx<'_>,
4366        fun: &crate::AtomicFunction,
4367        compare_expr: Option<Handle<crate::Expression>>,
4368        value: Handle<crate::Expression>,
4369        res_var_info: &Option<(Handle<crate::Expression>, String)>,
4370    ) -> BackendResult {
4371        if let Some(cmp) = compare_expr {
4372            write!(self.out, ", ")?;
4373            self.write_expr(module, cmp, func_ctx)?;
4374        }
4375        write!(self.out, ", ")?;
4376        if let crate::AtomicFunction::Subtract = *fun {
4377            // we just wrote `InterlockedAdd`, so negate the argument
4378            write!(self.out, "-")?;
4379        }
4380        self.write_expr(module, value, func_ctx)?;
4381        if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
4382            write!(self.out, ", ")?;
4383            if compare_expr.is_some() {
4384                write!(self.out, "{res_name}.old_value")?;
4385            } else {
4386                write!(self.out, "{res_name}")?;
4387            }
4388        }
4389        writeln!(self.out, ");")?;
4390        Ok(())
4391    }
4392}
4393
4394pub(super) struct MatrixType {
4395    pub(super) columns: crate::VectorSize,
4396    pub(super) rows: crate::VectorSize,
4397    pub(super) width: crate::Bytes,
4398}
4399
4400pub(super) fn get_inner_matrix_data(
4401    module: &Module,
4402    handle: Handle<crate::Type>,
4403) -> Option<MatrixType> {
4404    match module.types[handle].inner {
4405        TypeInner::Matrix {
4406            columns,
4407            rows,
4408            scalar,
4409        } => Some(MatrixType {
4410            columns,
4411            rows,
4412            width: scalar.width,
4413        }),
4414        TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
4415        _ => None,
4416    }
4417}
4418
4419/// Returns the matrix data if the access chain starting at `base`:
4420/// - starts with an expression with resolved type of [`TypeInner::Matrix`] if `direct = true`
4421/// - contains one or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
4422/// - ends at an expression with resolved type of [`TypeInner::Struct`]
4423pub(super) fn get_inner_matrix_of_struct_array_member(
4424    module: &Module,
4425    base: Handle<crate::Expression>,
4426    func_ctx: &back::FunctionCtx<'_>,
4427    direct: bool,
4428) -> Option<MatrixType> {
4429    let mut mat_data = None;
4430    let mut array_base = None;
4431
4432    let mut current_base = base;
4433    loop {
4434        let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4435        if let TypeInner::Pointer { base, .. } = *resolved {
4436            resolved = &module.types[base].inner;
4437        };
4438
4439        match *resolved {
4440            TypeInner::Matrix {
4441                columns,
4442                rows,
4443                scalar,
4444            } => {
4445                mat_data = Some(MatrixType {
4446                    columns,
4447                    rows,
4448                    width: scalar.width,
4449                })
4450            }
4451            TypeInner::Array { base, .. } => {
4452                array_base = Some(base);
4453            }
4454            TypeInner::Struct { .. } => {
4455                if let Some(array_base) = array_base {
4456                    if direct {
4457                        return mat_data;
4458                    } else {
4459                        return get_inner_matrix_data(module, array_base);
4460                    }
4461                }
4462
4463                break;
4464            }
4465            _ => break,
4466        }
4467
4468        current_base = match func_ctx.expressions[current_base] {
4469            crate::Expression::Access { base, .. } => base,
4470            crate::Expression::AccessIndex { base, .. } => base,
4471            _ => break,
4472        };
4473    }
4474    None
4475}
4476
4477/// Returns the matrix data if the access chain starting at `base`:
4478/// - starts with an expression with resolved type of [`TypeInner::Matrix`]
4479/// - contains zero or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
4480/// - ends with an [`Expression::GlobalVariable`](crate::Expression::GlobalVariable) in [`AddressSpace::Uniform`](crate::AddressSpace::Uniform)
4481fn get_inner_matrix_of_global_uniform(
4482    module: &Module,
4483    base: Handle<crate::Expression>,
4484    func_ctx: &back::FunctionCtx<'_>,
4485) -> Option<MatrixType> {
4486    let mut mat_data = None;
4487    let mut array_base = None;
4488
4489    let mut current_base = base;
4490    loop {
4491        let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4492        if let TypeInner::Pointer { base, .. } = *resolved {
4493            resolved = &module.types[base].inner;
4494        };
4495
4496        match *resolved {
4497            TypeInner::Matrix {
4498                columns,
4499                rows,
4500                scalar,
4501            } => {
4502                mat_data = Some(MatrixType {
4503                    columns,
4504                    rows,
4505                    width: scalar.width,
4506                })
4507            }
4508            TypeInner::Array { base, .. } => {
4509                array_base = Some(base);
4510            }
4511            _ => break,
4512        }
4513
4514        current_base = match func_ctx.expressions[current_base] {
4515            crate::Expression::Access { base, .. } => base,
4516            crate::Expression::AccessIndex { base, .. } => base,
4517            crate::Expression::GlobalVariable(handle)
4518                if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
4519            {
4520                return mat_data.or_else(|| {
4521                    array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
4522                })
4523            }
4524            _ => break,
4525        };
4526    }
4527    None
4528}