naga/back/spv/
writer.rs

1use alloc::{string::String, vec, vec::Vec};
2
3use hashbrown::hash_map::Entry;
4use spirv::Word;
5
6use super::{
7    block::DebugInfoInner,
8    helpers::{contains_builtin, global_needs_wrapper, map_storage_class},
9    Block, BlockContext, CachedConstant, CachedExpressions, DebugInfo, EntryPointContext, Error,
10    Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction, LocalImageType,
11    LocalType, LocalVariable, LogicalLayout, LookupFunctionType, LookupType, NumericType, Options,
12    PhysicalLayout, PipelineOptions, ResultMember, Writer, WriterFlags, BITS_PER_BYTE,
13};
14use crate::{
15    arena::{Handle, HandleVec, UniqueArena},
16    back::spv::{BindingInfo, WrappedFunction},
17    path_like::PathLike,
18    proc::{Alignment, TypeResolution},
19    valid::{FunctionInfo, ModuleInfo},
20};
21
22struct FunctionInterface<'a> {
23    varying_ids: &'a mut Vec<Word>,
24    stage: crate::ShaderStage,
25}
26
27impl Function {
28    pub(super) fn to_words(&self, sink: &mut impl Extend<Word>) {
29        self.signature.as_ref().unwrap().to_words(sink);
30        for argument in self.parameters.iter() {
31            argument.instruction.to_words(sink);
32        }
33        for (index, block) in self.blocks.iter().enumerate() {
34            Instruction::label(block.label_id).to_words(sink);
35            if index == 0 {
36                for local_var in self.variables.values() {
37                    local_var.instruction.to_words(sink);
38                }
39                for local_var in self.force_loop_bounding_vars.iter() {
40                    local_var.instruction.to_words(sink);
41                }
42                for internal_var in self.spilled_composites.values() {
43                    internal_var.instruction.to_words(sink);
44                }
45            }
46            for instruction in block.body.iter() {
47                instruction.to_words(sink);
48            }
49        }
50        Instruction::function_end().to_words(sink);
51    }
52}
53
54impl Writer {
55    pub fn new(options: &Options) -> Result<Self, Error> {
56        let (major, minor) = options.lang_version;
57        if major != 1 {
58            return Err(Error::UnsupportedVersion(major, minor));
59        }
60
61        let mut capabilities_used = crate::FastIndexSet::default();
62        capabilities_used.insert(spirv::Capability::Shader);
63
64        let mut id_gen = IdGenerator::default();
65        let gl450_ext_inst_id = id_gen.next();
66        let void_type = id_gen.next();
67
68        Ok(Writer {
69            physical_layout: PhysicalLayout::new(major, minor),
70            logical_layout: LogicalLayout::default(),
71            id_gen,
72            capabilities_available: options.capabilities.clone(),
73            capabilities_used,
74            extensions_used: crate::FastIndexSet::default(),
75            debugs: vec![],
76            annotations: vec![],
77            flags: options.flags,
78            bounds_check_policies: options.bounds_check_policies,
79            zero_initialize_workgroup_memory: options.zero_initialize_workgroup_memory,
80            force_loop_bounding: options.force_loop_bounding,
81            void_type,
82            lookup_type: crate::FastHashMap::default(),
83            lookup_function: crate::FastHashMap::default(),
84            lookup_function_type: crate::FastHashMap::default(),
85            wrapped_functions: crate::FastHashMap::default(),
86            constant_ids: HandleVec::new(),
87            cached_constants: crate::FastHashMap::default(),
88            global_variables: HandleVec::new(),
89            binding_map: options.binding_map.clone(),
90            saved_cached: CachedExpressions::default(),
91            gl450_ext_inst_id,
92            temp_list: Vec::new(),
93            ray_get_committed_intersection_function: None,
94            ray_get_candidate_intersection_function: None,
95        })
96    }
97
98    /// Returns `(major, minor)` of the SPIR-V language version.
99    pub const fn lang_version(&self) -> (u8, u8) {
100        self.physical_layout.lang_version()
101    }
102
103    /// Reset `Writer` to its initial state, retaining any allocations.
104    ///
105    /// Why not just implement `Recyclable` for `Writer`? By design,
106    /// `Recyclable::recycle` requires ownership of the value, not just
107    /// `&mut`; see the trait documentation. But we need to use this method
108    /// from functions like `Writer::write`, which only have `&mut Writer`.
109    /// Workarounds include unsafe code (`core::ptr::read`, then `write`, ugh)
110    /// or something like a `Default` impl that returns an oddly-initialized
111    /// `Writer`, which is worse.
112    fn reset(&mut self) {
113        use super::recyclable::Recyclable;
114        use core::mem::take;
115
116        let mut id_gen = IdGenerator::default();
117        let gl450_ext_inst_id = id_gen.next();
118        let void_type = id_gen.next();
119
120        // Every field of the old writer that is not determined by the `Options`
121        // passed to `Writer::new` should be reset somehow.
122        let fresh = Writer {
123            // Copied from the old Writer:
124            flags: self.flags,
125            bounds_check_policies: self.bounds_check_policies,
126            zero_initialize_workgroup_memory: self.zero_initialize_workgroup_memory,
127            force_loop_bounding: self.force_loop_bounding,
128            capabilities_available: take(&mut self.capabilities_available),
129            binding_map: take(&mut self.binding_map),
130
131            // Initialized afresh:
132            id_gen,
133            void_type,
134            gl450_ext_inst_id,
135
136            // Recycled:
137            capabilities_used: take(&mut self.capabilities_used).recycle(),
138            extensions_used: take(&mut self.extensions_used).recycle(),
139            physical_layout: self.physical_layout.clone().recycle(),
140            logical_layout: take(&mut self.logical_layout).recycle(),
141            debugs: take(&mut self.debugs).recycle(),
142            annotations: take(&mut self.annotations).recycle(),
143            lookup_type: take(&mut self.lookup_type).recycle(),
144            lookup_function: take(&mut self.lookup_function).recycle(),
145            lookup_function_type: take(&mut self.lookup_function_type).recycle(),
146            wrapped_functions: take(&mut self.wrapped_functions).recycle(),
147            constant_ids: take(&mut self.constant_ids).recycle(),
148            cached_constants: take(&mut self.cached_constants).recycle(),
149            global_variables: take(&mut self.global_variables).recycle(),
150            saved_cached: take(&mut self.saved_cached).recycle(),
151            temp_list: take(&mut self.temp_list).recycle(),
152            ray_get_candidate_intersection_function: None,
153            ray_get_committed_intersection_function: None,
154        };
155
156        *self = fresh;
157
158        self.capabilities_used.insert(spirv::Capability::Shader);
159    }
160
161    /// Indicate that the code requires any one of the listed capabilities.
162    ///
163    /// If nothing in `capabilities` appears in the available capabilities
164    /// specified in the [`Options`] from which this `Writer` was created,
165    /// return an error. The `what` string is used in the error message to
166    /// explain what provoked the requirement. (If no available capabilities were
167    /// given, assume everything is available.)
168    ///
169    /// The first acceptable capability will be added to this `Writer`'s
170    /// [`capabilities_used`] table, and an `OpCapability` emitted for it in the
171    /// result. For this reason, more specific capabilities should be listed
172    /// before more general.
173    ///
174    /// [`capabilities_used`]: Writer::capabilities_used
175    pub(super) fn require_any(
176        &mut self,
177        what: &'static str,
178        capabilities: &[spirv::Capability],
179    ) -> Result<(), Error> {
180        match *capabilities {
181            [] => Ok(()),
182            [first, ..] => {
183                // Find the first acceptable capability, or return an error if
184                // there is none.
185                let selected = match self.capabilities_available {
186                    None => first,
187                    Some(ref available) => {
188                        match capabilities
189                            .iter()
190                            // need explicit type for hashbrown::HashSet::contains fn call to keep rustc happy
191                            .find(|cap| available.contains::<spirv::Capability>(cap))
192                        {
193                            Some(&cap) => cap,
194                            None => {
195                                return Err(Error::MissingCapabilities(what, capabilities.to_vec()))
196                            }
197                        }
198                    }
199                };
200                self.capabilities_used.insert(selected);
201                Ok(())
202            }
203        }
204    }
205
206    /// Indicate that the code requires all of the listed capabilities.
207    ///
208    /// If all entries of `capabilities` appear in the available capabilities
209    /// specified in the [`Options`] from which this `Writer` was created
210    /// (including the case where [`Options::capabilities`] is `None`), add
211    /// them all to this `Writer`'s [`capabilities_used`] table, and return
212    /// `Ok(())`. If at least one of the listed capabilities is not available,
213    /// do not add anything to the `capabilities_used` table, and return the
214    /// first unavailable requested capability, wrapped in `Err()`.
215    ///
216    /// This method is does not return an [`enum@Error`] in case of failure
217    /// because it may be used in cases where the caller can recover (e.g.,
218    /// with a polyfill) if the requested capabilities are not available. In
219    /// this case, it would be unnecessary work to find *all* the unavailable
220    /// requested capabilities, and to allocate a `Vec` for them, just so we
221    /// could return an [`Error::MissingCapabilities`]).
222    ///
223    /// [`capabilities_used`]: Writer::capabilities_used
224    pub(super) fn require_all(
225        &mut self,
226        capabilities: &[spirv::Capability],
227    ) -> Result<(), spirv::Capability> {
228        if let Some(ref available) = self.capabilities_available {
229            for requested in capabilities {
230                if !available.contains(requested) {
231                    return Err(*requested);
232                }
233            }
234        }
235
236        for requested in capabilities {
237            self.capabilities_used.insert(*requested);
238        }
239
240        Ok(())
241    }
242
243    /// Indicate that the code uses the given extension.
244    pub(super) fn use_extension(&mut self, extension: &'static str) {
245        self.extensions_used.insert(extension);
246    }
247
248    pub(super) fn get_type_id(&mut self, lookup_ty: LookupType) -> Word {
249        match self.lookup_type.entry(lookup_ty) {
250            Entry::Occupied(e) => *e.get(),
251            Entry::Vacant(e) => {
252                let local = match lookup_ty {
253                    LookupType::Handle(_handle) => unreachable!("Handles are populated at start"),
254                    LookupType::Local(local) => local,
255                };
256
257                let id = self.id_gen.next();
258                e.insert(id);
259                self.write_type_declaration_local(id, local);
260                id
261            }
262        }
263    }
264
265    pub(super) fn get_handle_type_id(&mut self, handle: Handle<crate::Type>) -> Word {
266        self.get_type_id(LookupType::Handle(handle))
267    }
268
269    pub(super) fn get_expression_lookup_type(&mut self, tr: &TypeResolution) -> LookupType {
270        match *tr {
271            TypeResolution::Handle(ty_handle) => LookupType::Handle(ty_handle),
272            TypeResolution::Value(ref inner) => {
273                let inner_local_type = self.localtype_from_inner(inner).unwrap();
274                LookupType::Local(inner_local_type)
275            }
276        }
277    }
278
279    pub(super) fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word {
280        let lookup_ty = self.get_expression_lookup_type(tr);
281        self.get_type_id(lookup_ty)
282    }
283
284    pub(super) fn get_localtype_id(&mut self, local: LocalType) -> Word {
285        self.get_type_id(LookupType::Local(local))
286    }
287
288    pub(super) fn get_pointer_type_id(&mut self, base: Word, class: spirv::StorageClass) -> Word {
289        self.get_type_id(LookupType::Local(LocalType::Pointer { base, class }))
290    }
291
292    pub(super) fn get_handle_pointer_type_id(
293        &mut self,
294        base: Handle<crate::Type>,
295        class: spirv::StorageClass,
296    ) -> Word {
297        let base_id = self.get_handle_type_id(base);
298        self.get_pointer_type_id(base_id, class)
299    }
300
301    pub(super) fn get_ray_query_pointer_id(&mut self) -> Word {
302        let rq_id = self.get_type_id(LookupType::Local(LocalType::RayQuery));
303        self.get_pointer_type_id(rq_id, spirv::StorageClass::Function)
304    }
305
306    /// Return a SPIR-V type for a pointer to `resolution`.
307    ///
308    /// The given `resolution` must be one that we can represent
309    /// either as a `LocalType::Pointer` or `LocalType::LocalPointer`.
310    pub(super) fn get_resolution_pointer_id(
311        &mut self,
312        resolution: &TypeResolution,
313        class: spirv::StorageClass,
314    ) -> Word {
315        let resolution_type_id = self.get_expression_type_id(resolution);
316        self.get_pointer_type_id(resolution_type_id, class)
317    }
318
319    pub(super) fn get_numeric_type_id(&mut self, numeric: NumericType) -> Word {
320        self.get_type_id(LocalType::Numeric(numeric).into())
321    }
322
323    pub(super) fn get_u32_type_id(&mut self) -> Word {
324        self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32))
325    }
326
327    pub(super) fn get_f32_type_id(&mut self) -> Word {
328        self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::F32))
329    }
330
331    pub(super) fn get_vec2u_type_id(&mut self) -> Word {
332        self.get_numeric_type_id(NumericType::Vector {
333            size: crate::VectorSize::Bi,
334            scalar: crate::Scalar::U32,
335        })
336    }
337
338    pub(super) fn get_vec2f_type_id(&mut self) -> Word {
339        self.get_numeric_type_id(NumericType::Vector {
340            size: crate::VectorSize::Bi,
341            scalar: crate::Scalar::F32,
342        })
343    }
344
345    pub(super) fn get_vec3u_type_id(&mut self) -> Word {
346        self.get_numeric_type_id(NumericType::Vector {
347            size: crate::VectorSize::Tri,
348            scalar: crate::Scalar::U32,
349        })
350    }
351
352    pub(super) fn get_f32_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
353        let f32_id = self.get_f32_type_id();
354        self.get_pointer_type_id(f32_id, class)
355    }
356
357    pub(super) fn get_vec2u_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
358        let vec2u_id = self.get_numeric_type_id(NumericType::Vector {
359            size: crate::VectorSize::Bi,
360            scalar: crate::Scalar::U32,
361        });
362        self.get_pointer_type_id(vec2u_id, class)
363    }
364
365    pub(super) fn get_vec3u_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
366        let vec3u_id = self.get_numeric_type_id(NumericType::Vector {
367            size: crate::VectorSize::Tri,
368            scalar: crate::Scalar::U32,
369        });
370        self.get_pointer_type_id(vec3u_id, class)
371    }
372
373    pub(super) fn get_bool_type_id(&mut self) -> Word {
374        self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::BOOL))
375    }
376
377    pub(super) fn get_vec2_bool_type_id(&mut self) -> Word {
378        self.get_numeric_type_id(NumericType::Vector {
379            size: crate::VectorSize::Bi,
380            scalar: crate::Scalar::BOOL,
381        })
382    }
383
384    pub(super) fn get_vec3_bool_type_id(&mut self) -> Word {
385        self.get_numeric_type_id(NumericType::Vector {
386            size: crate::VectorSize::Tri,
387            scalar: crate::Scalar::BOOL,
388        })
389    }
390
391    pub(super) fn decorate(&mut self, id: Word, decoration: spirv::Decoration, operands: &[Word]) {
392        self.annotations
393            .push(Instruction::decorate(id, decoration, operands));
394    }
395
396    /// Return `inner` as a `LocalType`, if that's possible.
397    ///
398    /// If `inner` can be represented as a `LocalType`, return
399    /// `Some(local_type)`.
400    ///
401    /// Otherwise, return `None`. In this case, the type must always be looked
402    /// up using a `LookupType::Handle`.
403    fn localtype_from_inner(&mut self, inner: &crate::TypeInner) -> Option<LocalType> {
404        Some(match *inner {
405            crate::TypeInner::Scalar(_)
406            | crate::TypeInner::Atomic(_)
407            | crate::TypeInner::Vector { .. }
408            | crate::TypeInner::Matrix { .. } => {
409                // We expect `NumericType::from_inner` to handle all
410                // these cases, so unwrap.
411                LocalType::Numeric(NumericType::from_inner(inner).unwrap())
412            }
413            crate::TypeInner::Pointer { base, space } => {
414                let base_type_id = self.get_handle_type_id(base);
415                LocalType::Pointer {
416                    base: base_type_id,
417                    class: map_storage_class(space),
418                }
419            }
420            crate::TypeInner::ValuePointer {
421                size,
422                scalar,
423                space,
424            } => {
425                let base_numeric_type = match size {
426                    Some(size) => NumericType::Vector { size, scalar },
427                    None => NumericType::Scalar(scalar),
428                };
429                LocalType::Pointer {
430                    base: self.get_numeric_type_id(base_numeric_type),
431                    class: map_storage_class(space),
432                }
433            }
434            crate::TypeInner::Image {
435                dim,
436                arrayed,
437                class,
438            } => LocalType::Image(LocalImageType::from_inner(dim, arrayed, class)),
439            crate::TypeInner::Sampler { comparison: _ } => LocalType::Sampler,
440            crate::TypeInner::AccelerationStructure { .. } => LocalType::AccelerationStructure,
441            crate::TypeInner::RayQuery { .. } => LocalType::RayQuery,
442            crate::TypeInner::Array { .. }
443            | crate::TypeInner::Struct { .. }
444            | crate::TypeInner::BindingArray { .. } => return None,
445        })
446    }
447
448    /// Emits code for any wrapper functions required by the expressions in ir_function.
449    /// The IDs of any emitted functions will be stored in [`Self::wrapped_functions`].
450    fn write_wrapped_functions(
451        &mut self,
452        ir_function: &crate::Function,
453        info: &FunctionInfo,
454        ir_module: &crate::Module,
455    ) -> Result<(), Error> {
456        log::trace!("Generating wrapped functions for {:?}", ir_function.name);
457
458        for (expr_handle, expr) in ir_function.expressions.iter() {
459            match *expr {
460                crate::Expression::Binary { op, left, right } => {
461                    let expr_ty_inner = info[expr_handle].ty.inner_with(&ir_module.types);
462                    if let Some(expr_ty) = NumericType::from_inner(expr_ty_inner) {
463                        match (op, expr_ty.scalar().kind) {
464                            // Division and modulo are undefined behaviour when the
465                            // dividend is the minimum representable value and the divisor
466                            // is negative one, or when the divisor is zero. These wrapped
467                            // functions override the divisor to one in these cases,
468                            // matching the WGSL spec.
469                            (
470                                crate::BinaryOperator::Divide | crate::BinaryOperator::Modulo,
471                                crate::ScalarKind::Sint | crate::ScalarKind::Uint,
472                            ) => {
473                                self.write_wrapped_binary_op(
474                                    op,
475                                    expr_ty,
476                                    &info[left].ty,
477                                    &info[right].ty,
478                                )?;
479                            }
480                            _ => {}
481                        }
482                    }
483                }
484                _ => {}
485            }
486        }
487
488        Ok(())
489    }
490
491    /// Write a SPIR-V function that performs the operator `op` with Naga IR semantics.
492    ///
493    /// Define a function that performs an integer division or modulo operation,
494    /// except that using a divisor of zero or causing signed overflow with a
495    /// divisor of -1 returns the numerator unchanged, rather than exhibiting
496    /// undefined behavior.
497    ///
498    /// Store the generated function's id in the [`wrapped_functions`] table.
499    ///
500    /// The operator `op` must be either [`Divide`] or [`Modulo`].
501    ///
502    /// # Panics
503    ///
504    /// The `return_type`, `left_type` or `right_type` arguments must all be
505    /// integer scalars or vectors. If not, this function panics.
506    ///
507    /// [`wrapped_functions`]: Writer::wrapped_functions
508    /// [`Divide`]: crate::BinaryOperator::Divide
509    /// [`Modulo`]: crate::BinaryOperator::Modulo
510    fn write_wrapped_binary_op(
511        &mut self,
512        op: crate::BinaryOperator,
513        return_type: NumericType,
514        left_type: &TypeResolution,
515        right_type: &TypeResolution,
516    ) -> Result<(), Error> {
517        let return_type_id = self.get_localtype_id(LocalType::Numeric(return_type));
518        let left_type_id = self.get_expression_type_id(left_type);
519        let right_type_id = self.get_expression_type_id(right_type);
520
521        // Check if we've already emitted this function.
522        let wrapped = WrappedFunction::BinaryOp {
523            op,
524            left_type_id,
525            right_type_id,
526        };
527        let function_id = match self.wrapped_functions.entry(wrapped) {
528            Entry::Occupied(_) => return Ok(()),
529            Entry::Vacant(e) => *e.insert(self.id_gen.next()),
530        };
531
532        let scalar = return_type.scalar();
533
534        if self.flags.contains(WriterFlags::DEBUG) {
535            let function_name = match op {
536                crate::BinaryOperator::Divide => "naga_div",
537                crate::BinaryOperator::Modulo => "naga_mod",
538                _ => unreachable!(),
539            };
540            self.debugs
541                .push(Instruction::name(function_id, function_name));
542        }
543        let mut function = Function::default();
544
545        let function_type_id = self.get_function_type(LookupFunctionType {
546            parameter_type_ids: vec![left_type_id, right_type_id],
547            return_type_id,
548        });
549        function.signature = Some(Instruction::function(
550            return_type_id,
551            function_id,
552            spirv::FunctionControl::empty(),
553            function_type_id,
554        ));
555
556        let lhs_id = self.id_gen.next();
557        let rhs_id = self.id_gen.next();
558        if self.flags.contains(WriterFlags::DEBUG) {
559            self.debugs.push(Instruction::name(lhs_id, "lhs"));
560            self.debugs.push(Instruction::name(rhs_id, "rhs"));
561        }
562        let left_par = Instruction::function_parameter(left_type_id, lhs_id);
563        let right_par = Instruction::function_parameter(right_type_id, rhs_id);
564        for instruction in [left_par, right_par] {
565            function.parameters.push(FunctionArgument {
566                instruction,
567                handle_id: 0,
568            });
569        }
570
571        let label_id = self.id_gen.next();
572        let mut block = Block::new(label_id);
573
574        let bool_type = return_type.with_scalar(crate::Scalar::BOOL);
575        let bool_type_id = self.get_numeric_type_id(bool_type);
576
577        let maybe_splat_const = |writer: &mut Self, const_id| match return_type {
578            NumericType::Scalar(_) => const_id,
579            NumericType::Vector { size, .. } => {
580                let constituent_ids = [const_id; crate::VectorSize::MAX];
581                writer.get_constant_composite(
582                    LookupType::Local(LocalType::Numeric(return_type)),
583                    &constituent_ids[..size as usize],
584                )
585            }
586            NumericType::Matrix { .. } => unreachable!(),
587        };
588
589        let const_zero_id = self.get_constant_scalar_with(0, scalar)?;
590        let composite_zero_id = maybe_splat_const(self, const_zero_id);
591        let rhs_eq_zero_id = self.id_gen.next();
592        block.body.push(Instruction::binary(
593            spirv::Op::IEqual,
594            bool_type_id,
595            rhs_eq_zero_id,
596            rhs_id,
597            composite_zero_id,
598        ));
599        let divisor_selector_id = match scalar.kind {
600            crate::ScalarKind::Sint => {
601                let (const_min_id, const_neg_one_id) = match scalar.width {
602                    4 => Ok((
603                        self.get_constant_scalar(crate::Literal::I32(i32::MIN)),
604                        self.get_constant_scalar(crate::Literal::I32(-1i32)),
605                    )),
606                    8 => Ok((
607                        self.get_constant_scalar(crate::Literal::I64(i64::MIN)),
608                        self.get_constant_scalar(crate::Literal::I64(-1i64)),
609                    )),
610                    _ => Err(Error::Validation("Unexpected scalar width")),
611                }?;
612                let composite_min_id = maybe_splat_const(self, const_min_id);
613                let composite_neg_one_id = maybe_splat_const(self, const_neg_one_id);
614
615                let lhs_eq_int_min_id = self.id_gen.next();
616                block.body.push(Instruction::binary(
617                    spirv::Op::IEqual,
618                    bool_type_id,
619                    lhs_eq_int_min_id,
620                    lhs_id,
621                    composite_min_id,
622                ));
623                let rhs_eq_neg_one_id = self.id_gen.next();
624                block.body.push(Instruction::binary(
625                    spirv::Op::IEqual,
626                    bool_type_id,
627                    rhs_eq_neg_one_id,
628                    rhs_id,
629                    composite_neg_one_id,
630                ));
631                let lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next();
632                block.body.push(Instruction::binary(
633                    spirv::Op::LogicalAnd,
634                    bool_type_id,
635                    lhs_eq_int_min_and_rhs_eq_neg_one_id,
636                    lhs_eq_int_min_id,
637                    rhs_eq_neg_one_id,
638                ));
639                let rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next();
640                block.body.push(Instruction::binary(
641                    spirv::Op::LogicalOr,
642                    bool_type_id,
643                    rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id,
644                    rhs_eq_zero_id,
645                    lhs_eq_int_min_and_rhs_eq_neg_one_id,
646                ));
647                rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id
648            }
649            crate::ScalarKind::Uint => rhs_eq_zero_id,
650            _ => unreachable!(),
651        };
652
653        let const_one_id = self.get_constant_scalar_with(1, scalar)?;
654        let composite_one_id = maybe_splat_const(self, const_one_id);
655        let divisor_id = self.id_gen.next();
656        block.body.push(Instruction::select(
657            right_type_id,
658            divisor_id,
659            divisor_selector_id,
660            composite_one_id,
661            rhs_id,
662        ));
663        let op = match (op, scalar.kind) {
664            (crate::BinaryOperator::Divide, crate::ScalarKind::Sint) => spirv::Op::SDiv,
665            (crate::BinaryOperator::Divide, crate::ScalarKind::Uint) => spirv::Op::UDiv,
666            (crate::BinaryOperator::Modulo, crate::ScalarKind::Sint) => spirv::Op::SRem,
667            (crate::BinaryOperator::Modulo, crate::ScalarKind::Uint) => spirv::Op::UMod,
668            _ => unreachable!(),
669        };
670        let return_id = self.id_gen.next();
671        block.body.push(Instruction::binary(
672            op,
673            return_type_id,
674            return_id,
675            lhs_id,
676            divisor_id,
677        ));
678
679        function.consume(block, Instruction::return_value(return_id));
680        function.to_words(&mut self.logical_layout.function_definitions);
681        Ok(())
682    }
683
684    fn write_function(
685        &mut self,
686        ir_function: &crate::Function,
687        info: &FunctionInfo,
688        ir_module: &crate::Module,
689        mut interface: Option<FunctionInterface>,
690        debug_info: &Option<DebugInfoInner>,
691    ) -> Result<Word, Error> {
692        self.write_wrapped_functions(ir_function, info, ir_module)?;
693
694        log::trace!("Generating code for {:?}", ir_function.name);
695        let mut function = Function::default();
696
697        let prelude_id = self.id_gen.next();
698        let mut prelude = Block::new(prelude_id);
699        let mut ep_context = EntryPointContext {
700            argument_ids: Vec::new(),
701            results: Vec::new(),
702        };
703
704        let mut local_invocation_id = None;
705
706        let mut parameter_type_ids = Vec::with_capacity(ir_function.arguments.len());
707        for argument in ir_function.arguments.iter() {
708            let class = spirv::StorageClass::Input;
709            let handle_ty = ir_module.types[argument.ty].inner.is_handle();
710            let argument_type_id = if handle_ty {
711                self.get_handle_pointer_type_id(argument.ty, spirv::StorageClass::UniformConstant)
712            } else {
713                self.get_handle_type_id(argument.ty)
714            };
715
716            if let Some(ref mut iface) = interface {
717                let id = if let Some(ref binding) = argument.binding {
718                    let name = argument.name.as_deref();
719
720                    let varying_id = self.write_varying(
721                        ir_module,
722                        iface.stage,
723                        class,
724                        name,
725                        argument.ty,
726                        binding,
727                    )?;
728                    iface.varying_ids.push(varying_id);
729                    let id = self.id_gen.next();
730                    prelude
731                        .body
732                        .push(Instruction::load(argument_type_id, id, varying_id, None));
733
734                    if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
735                        local_invocation_id = Some(id);
736                    }
737
738                    id
739                } else if let crate::TypeInner::Struct { ref members, .. } =
740                    ir_module.types[argument.ty].inner
741                {
742                    let struct_id = self.id_gen.next();
743                    let mut constituent_ids = Vec::with_capacity(members.len());
744                    for member in members {
745                        let type_id = self.get_handle_type_id(member.ty);
746                        let name = member.name.as_deref();
747                        let binding = member.binding.as_ref().unwrap();
748                        let varying_id = self.write_varying(
749                            ir_module,
750                            iface.stage,
751                            class,
752                            name,
753                            member.ty,
754                            binding,
755                        )?;
756                        iface.varying_ids.push(varying_id);
757                        let id = self.id_gen.next();
758                        prelude
759                            .body
760                            .push(Instruction::load(type_id, id, varying_id, None));
761                        constituent_ids.push(id);
762
763                        if binding == &crate::Binding::BuiltIn(crate::BuiltIn::GlobalInvocationId) {
764                            local_invocation_id = Some(id);
765                        }
766                    }
767                    prelude.body.push(Instruction::composite_construct(
768                        argument_type_id,
769                        struct_id,
770                        &constituent_ids,
771                    ));
772                    struct_id
773                } else {
774                    unreachable!("Missing argument binding on an entry point");
775                };
776                ep_context.argument_ids.push(id);
777            } else {
778                let argument_id = self.id_gen.next();
779                let instruction = Instruction::function_parameter(argument_type_id, argument_id);
780                if self.flags.contains(WriterFlags::DEBUG) {
781                    if let Some(ref name) = argument.name {
782                        self.debugs.push(Instruction::name(argument_id, name));
783                    }
784                }
785                function.parameters.push(FunctionArgument {
786                    instruction,
787                    handle_id: if handle_ty {
788                        let id = self.id_gen.next();
789                        prelude.body.push(Instruction::load(
790                            self.get_handle_type_id(argument.ty),
791                            id,
792                            argument_id,
793                            None,
794                        ));
795                        id
796                    } else {
797                        0
798                    },
799                });
800                parameter_type_ids.push(argument_type_id);
801            };
802        }
803
804        let return_type_id = match ir_function.result {
805            Some(ref result) => {
806                if let Some(ref mut iface) = interface {
807                    let mut has_point_size = false;
808                    let class = spirv::StorageClass::Output;
809                    if let Some(ref binding) = result.binding {
810                        has_point_size |=
811                            *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
812                        let type_id = self.get_handle_type_id(result.ty);
813                        let varying_id = self.write_varying(
814                            ir_module,
815                            iface.stage,
816                            class,
817                            None,
818                            result.ty,
819                            binding,
820                        )?;
821                        iface.varying_ids.push(varying_id);
822                        ep_context.results.push(ResultMember {
823                            id: varying_id,
824                            type_id,
825                            built_in: binding.to_built_in(),
826                        });
827                    } else if let crate::TypeInner::Struct { ref members, .. } =
828                        ir_module.types[result.ty].inner
829                    {
830                        for member in members {
831                            let type_id = self.get_handle_type_id(member.ty);
832                            let name = member.name.as_deref();
833                            let binding = member.binding.as_ref().unwrap();
834                            has_point_size |=
835                                *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
836                            let varying_id = self.write_varying(
837                                ir_module,
838                                iface.stage,
839                                class,
840                                name,
841                                member.ty,
842                                binding,
843                            )?;
844                            iface.varying_ids.push(varying_id);
845                            ep_context.results.push(ResultMember {
846                                id: varying_id,
847                                type_id,
848                                built_in: binding.to_built_in(),
849                            });
850                        }
851                    } else {
852                        unreachable!("Missing result binding on an entry point");
853                    }
854
855                    if self.flags.contains(WriterFlags::FORCE_POINT_SIZE)
856                        && iface.stage == crate::ShaderStage::Vertex
857                        && !has_point_size
858                    {
859                        // add point size artificially
860                        let varying_id = self.id_gen.next();
861                        let pointer_type_id = self.get_f32_pointer_type_id(class);
862                        Instruction::variable(pointer_type_id, varying_id, class, None)
863                            .to_words(&mut self.logical_layout.declarations);
864                        self.decorate(
865                            varying_id,
866                            spirv::Decoration::BuiltIn,
867                            &[spirv::BuiltIn::PointSize as u32],
868                        );
869                        iface.varying_ids.push(varying_id);
870
871                        let default_value_id = self.get_constant_scalar(crate::Literal::F32(1.0));
872                        prelude
873                            .body
874                            .push(Instruction::store(varying_id, default_value_id, None));
875                    }
876                    self.void_type
877                } else {
878                    self.get_handle_type_id(result.ty)
879                }
880            }
881            None => self.void_type,
882        };
883
884        let lookup_function_type = LookupFunctionType {
885            parameter_type_ids,
886            return_type_id,
887        };
888
889        let function_id = self.id_gen.next();
890        if self.flags.contains(WriterFlags::DEBUG) {
891            if let Some(ref name) = ir_function.name {
892                self.debugs.push(Instruction::name(function_id, name));
893            }
894        }
895
896        let function_type = self.get_function_type(lookup_function_type);
897        function.signature = Some(Instruction::function(
898            return_type_id,
899            function_id,
900            spirv::FunctionControl::empty(),
901            function_type,
902        ));
903
904        if interface.is_some() {
905            function.entry_point_context = Some(ep_context);
906        }
907
908        // fill up the `GlobalVariable::access_id`
909        for gv in self.global_variables.iter_mut() {
910            gv.reset_for_function();
911        }
912        for (handle, var) in ir_module.global_variables.iter() {
913            if info[handle].is_empty() {
914                continue;
915            }
916
917            let mut gv = self.global_variables[handle].clone();
918            if let Some(ref mut iface) = interface {
919                // Have to include global variables in the interface
920                if self.physical_layout.version >= 0x10400 {
921                    iface.varying_ids.push(gv.var_id);
922                }
923            }
924
925            // Handle globals are pre-emitted and should be loaded automatically.
926            //
927            // Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing.
928            match ir_module.types[var.ty].inner {
929                crate::TypeInner::BindingArray { .. } => {
930                    gv.access_id = gv.var_id;
931                }
932                _ => {
933                    if var.space == crate::AddressSpace::Handle {
934                        let var_type_id = self.get_handle_type_id(var.ty);
935                        let id = self.id_gen.next();
936                        prelude
937                            .body
938                            .push(Instruction::load(var_type_id, id, gv.var_id, None));
939                        gv.access_id = gv.var_id;
940                        gv.handle_id = id;
941                    } else if global_needs_wrapper(ir_module, var) {
942                        let class = map_storage_class(var.space);
943                        let pointer_type_id = self.get_handle_pointer_type_id(var.ty, class);
944                        let index_id = self.get_index_constant(0);
945                        let id = self.id_gen.next();
946                        prelude.body.push(Instruction::access_chain(
947                            pointer_type_id,
948                            id,
949                            gv.var_id,
950                            &[index_id],
951                        ));
952                        gv.access_id = id;
953                    } else {
954                        // by default, the variable ID is accessed as is
955                        gv.access_id = gv.var_id;
956                    };
957                }
958            }
959
960            // work around borrow checking in the presence of `self.xxx()` calls
961            self.global_variables[handle] = gv;
962        }
963
964        // Create a `BlockContext` for generating SPIR-V for the function's
965        // body.
966        let mut context = BlockContext {
967            ir_module,
968            ir_function,
969            fun_info: info,
970            function: &mut function,
971            // Re-use the cached expression table from prior functions.
972            cached: core::mem::take(&mut self.saved_cached),
973
974            // Steal the Writer's temp list for a bit.
975            temp_list: core::mem::take(&mut self.temp_list),
976            force_loop_bounding: self.force_loop_bounding,
977            writer: self,
978            expression_constness: super::ExpressionConstnessTracker::from_arena(
979                &ir_function.expressions,
980            ),
981        };
982
983        // fill up the pre-emitted and const expressions
984        context.cached.reset(ir_function.expressions.len());
985        for (handle, expr) in ir_function.expressions.iter() {
986            if (expr.needs_pre_emit() && !matches!(*expr, crate::Expression::LocalVariable(_)))
987                || context.expression_constness.is_const(handle)
988            {
989                context.cache_expression_value(handle, &mut prelude)?;
990            }
991        }
992
993        for (handle, variable) in ir_function.local_variables.iter() {
994            let id = context.gen_id();
995
996            if context.writer.flags.contains(WriterFlags::DEBUG) {
997                if let Some(ref name) = variable.name {
998                    context.writer.debugs.push(Instruction::name(id, name));
999                }
1000            }
1001
1002            let init_word = variable.init.map(|constant| context.cached[constant]);
1003            let pointer_type_id = context
1004                .writer
1005                .get_handle_pointer_type_id(variable.ty, spirv::StorageClass::Function);
1006            let instruction = Instruction::variable(
1007                pointer_type_id,
1008                id,
1009                spirv::StorageClass::Function,
1010                init_word.or_else(|| match ir_module.types[variable.ty].inner {
1011                    crate::TypeInner::RayQuery { .. } => None,
1012                    _ => {
1013                        let type_id = context.get_handle_type_id(variable.ty);
1014                        Some(context.writer.write_constant_null(type_id))
1015                    }
1016                }),
1017            );
1018            context
1019                .function
1020                .variables
1021                .insert(handle, LocalVariable { id, instruction });
1022        }
1023
1024        for (handle, expr) in ir_function.expressions.iter() {
1025            match *expr {
1026                crate::Expression::LocalVariable(_) => {
1027                    // Cache the `OpVariable` instruction we generated above as
1028                    // the value of this expression.
1029                    context.cache_expression_value(handle, &mut prelude)?;
1030                }
1031                crate::Expression::Access { base, .. }
1032                | crate::Expression::AccessIndex { base, .. } => {
1033                    // Count references to `base` by `Access` and `AccessIndex`
1034                    // instructions. See `access_uses` for details.
1035                    *context.function.access_uses.entry(base).or_insert(0) += 1;
1036                }
1037                _ => {}
1038            }
1039        }
1040
1041        let next_id = context.gen_id();
1042
1043        context
1044            .function
1045            .consume(prelude, Instruction::branch(next_id));
1046
1047        let workgroup_vars_init_exit_block_id =
1048            match (context.writer.zero_initialize_workgroup_memory, interface) {
1049                (
1050                    super::ZeroInitializeWorkgroupMemoryMode::Polyfill,
1051                    Some(
1052                        ref mut interface @ FunctionInterface {
1053                            stage: crate::ShaderStage::Compute,
1054                            ..
1055                        },
1056                    ),
1057                ) => context.writer.generate_workgroup_vars_init_block(
1058                    next_id,
1059                    ir_module,
1060                    info,
1061                    local_invocation_id,
1062                    interface,
1063                    context.function,
1064                ),
1065                _ => None,
1066            };
1067
1068        let main_id = if let Some(exit_id) = workgroup_vars_init_exit_block_id {
1069            exit_id
1070        } else {
1071            next_id
1072        };
1073
1074        context.write_function_body(main_id, debug_info.as_ref())?;
1075
1076        // Consume the `BlockContext`, ending its borrows and letting the
1077        // `Writer` steal back its cached expression table and temp_list.
1078        let BlockContext {
1079            cached, temp_list, ..
1080        } = context;
1081        self.saved_cached = cached;
1082        self.temp_list = temp_list;
1083
1084        function.to_words(&mut self.logical_layout.function_definitions);
1085
1086        Ok(function_id)
1087    }
1088
1089    fn write_execution_mode(
1090        &mut self,
1091        function_id: Word,
1092        mode: spirv::ExecutionMode,
1093    ) -> Result<(), Error> {
1094        //self.check(mode.required_capabilities())?;
1095        Instruction::execution_mode(function_id, mode, &[])
1096            .to_words(&mut self.logical_layout.execution_modes);
1097        Ok(())
1098    }
1099
1100    // TODO Move to instructions module
1101    fn write_entry_point(
1102        &mut self,
1103        entry_point: &crate::EntryPoint,
1104        info: &FunctionInfo,
1105        ir_module: &crate::Module,
1106        debug_info: &Option<DebugInfoInner>,
1107    ) -> Result<Instruction, Error> {
1108        let mut interface_ids = Vec::new();
1109        let function_id = self.write_function(
1110            &entry_point.function,
1111            info,
1112            ir_module,
1113            Some(FunctionInterface {
1114                varying_ids: &mut interface_ids,
1115                stage: entry_point.stage,
1116            }),
1117            debug_info,
1118        )?;
1119
1120        let exec_model = match entry_point.stage {
1121            crate::ShaderStage::Vertex => spirv::ExecutionModel::Vertex,
1122            crate::ShaderStage::Fragment => {
1123                self.write_execution_mode(function_id, spirv::ExecutionMode::OriginUpperLeft)?;
1124                match entry_point.early_depth_test {
1125                    Some(crate::EarlyDepthTest::Force) => {
1126                        self.write_execution_mode(
1127                            function_id,
1128                            spirv::ExecutionMode::EarlyFragmentTests,
1129                        )?;
1130                    }
1131                    Some(crate::EarlyDepthTest::Allow { conservative }) => {
1132                        // TODO: Consider emitting EarlyAndLateFragmentTestsAMD here, if available.
1133                        // https://github.khronos.org/SPIRV-Registry/extensions/AMD/SPV_AMD_shader_early_and_late_fragment_tests.html
1134                        // This permits early depth tests even if the shader writes to a storage
1135                        // binding
1136                        match conservative {
1137                            crate::ConservativeDepth::GreaterEqual => self.write_execution_mode(
1138                                function_id,
1139                                spirv::ExecutionMode::DepthGreater,
1140                            )?,
1141                            crate::ConservativeDepth::LessEqual => self.write_execution_mode(
1142                                function_id,
1143                                spirv::ExecutionMode::DepthLess,
1144                            )?,
1145                            crate::ConservativeDepth::Unchanged => self.write_execution_mode(
1146                                function_id,
1147                                spirv::ExecutionMode::DepthUnchanged,
1148                            )?,
1149                        }
1150                    }
1151                    None => {}
1152                }
1153                if let Some(ref result) = entry_point.function.result {
1154                    if contains_builtin(
1155                        result.binding.as_ref(),
1156                        result.ty,
1157                        &ir_module.types,
1158                        crate::BuiltIn::FragDepth,
1159                    ) {
1160                        self.write_execution_mode(
1161                            function_id,
1162                            spirv::ExecutionMode::DepthReplacing,
1163                        )?;
1164                    }
1165                }
1166                spirv::ExecutionModel::Fragment
1167            }
1168            crate::ShaderStage::Compute => {
1169                let execution_mode = spirv::ExecutionMode::LocalSize;
1170                //self.check(execution_mode.required_capabilities())?;
1171                Instruction::execution_mode(
1172                    function_id,
1173                    execution_mode,
1174                    &entry_point.workgroup_size,
1175                )
1176                .to_words(&mut self.logical_layout.execution_modes);
1177                spirv::ExecutionModel::GLCompute
1178            }
1179            crate::ShaderStage::Task | crate::ShaderStage::Mesh => unreachable!(),
1180        };
1181        //self.check(exec_model.required_capabilities())?;
1182
1183        Ok(Instruction::entry_point(
1184            exec_model,
1185            function_id,
1186            &entry_point.name,
1187            interface_ids.as_slice(),
1188        ))
1189    }
1190
1191    fn make_scalar(&mut self, id: Word, scalar: crate::Scalar) -> Instruction {
1192        use crate::ScalarKind as Sk;
1193
1194        let bits = (scalar.width * BITS_PER_BYTE) as u32;
1195        match scalar.kind {
1196            Sk::Sint | Sk::Uint => {
1197                let signedness = if scalar.kind == Sk::Sint {
1198                    super::instructions::Signedness::Signed
1199                } else {
1200                    super::instructions::Signedness::Unsigned
1201                };
1202                let cap = match bits {
1203                    8 => Some(spirv::Capability::Int8),
1204                    16 => Some(spirv::Capability::Int16),
1205                    64 => Some(spirv::Capability::Int64),
1206                    _ => None,
1207                };
1208                if let Some(cap) = cap {
1209                    self.capabilities_used.insert(cap);
1210                }
1211                Instruction::type_int(id, bits, signedness)
1212            }
1213            Sk::Float => {
1214                if bits == 64 {
1215                    self.capabilities_used.insert(spirv::Capability::Float64);
1216                }
1217                if bits == 16 {
1218                    self.capabilities_used.insert(spirv::Capability::Float16);
1219                    self.capabilities_used
1220                        .insert(spirv::Capability::StorageBuffer16BitAccess);
1221                    self.capabilities_used
1222                        .insert(spirv::Capability::UniformAndStorageBuffer16BitAccess);
1223                    self.capabilities_used
1224                        .insert(spirv::Capability::StorageInputOutput16);
1225                }
1226                Instruction::type_float(id, bits)
1227            }
1228            Sk::Bool => Instruction::type_bool(id),
1229            Sk::AbstractInt | Sk::AbstractFloat => {
1230                unreachable!("abstract types should never reach the backend");
1231            }
1232        }
1233    }
1234
1235    fn request_type_capabilities(&mut self, inner: &crate::TypeInner) -> Result<(), Error> {
1236        match *inner {
1237            crate::TypeInner::Image {
1238                dim,
1239                arrayed,
1240                class,
1241            } => {
1242                let sampled = match class {
1243                    crate::ImageClass::Sampled { .. } => true,
1244                    crate::ImageClass::Depth { .. } => true,
1245                    crate::ImageClass::Storage { format, .. } => {
1246                        self.request_image_format_capabilities(format.into())?;
1247                        false
1248                    }
1249                };
1250
1251                match dim {
1252                    crate::ImageDimension::D1 => {
1253                        if sampled {
1254                            self.require_any("sampled 1D images", &[spirv::Capability::Sampled1D])?;
1255                        } else {
1256                            self.require_any("1D storage images", &[spirv::Capability::Image1D])?;
1257                        }
1258                    }
1259                    crate::ImageDimension::Cube if arrayed => {
1260                        if sampled {
1261                            self.require_any(
1262                                "sampled cube array images",
1263                                &[spirv::Capability::SampledCubeArray],
1264                            )?;
1265                        } else {
1266                            self.require_any(
1267                                "cube array storage images",
1268                                &[spirv::Capability::ImageCubeArray],
1269                            )?;
1270                        }
1271                    }
1272                    _ => {}
1273                }
1274            }
1275            crate::TypeInner::AccelerationStructure { .. } => {
1276                self.require_any("Acceleration Structure", &[spirv::Capability::RayQueryKHR])?;
1277            }
1278            crate::TypeInner::RayQuery { .. } => {
1279                self.require_any("Ray Query", &[spirv::Capability::RayQueryKHR])?;
1280            }
1281            crate::TypeInner::Atomic(crate::Scalar { width: 8, kind: _ }) => {
1282                self.require_any("64 bit integer atomics", &[spirv::Capability::Int64Atomics])?;
1283            }
1284            crate::TypeInner::Atomic(crate::Scalar {
1285                width: 4,
1286                kind: crate::ScalarKind::Float,
1287            }) => {
1288                self.require_any(
1289                    "32 bit floating-point atomics",
1290                    &[spirv::Capability::AtomicFloat32AddEXT],
1291                )?;
1292                self.use_extension("SPV_EXT_shader_atomic_float_add");
1293            }
1294            // 16 bit floating-point support requires Float16 capability
1295            crate::TypeInner::Matrix {
1296                scalar: crate::Scalar::F16,
1297                ..
1298            }
1299            | crate::TypeInner::Vector {
1300                scalar: crate::Scalar::F16,
1301                ..
1302            }
1303            | crate::TypeInner::Scalar(crate::Scalar::F16) => {
1304                self.require_any("16 bit floating-point", &[spirv::Capability::Float16])?;
1305                self.use_extension("SPV_KHR_16bit_storage");
1306            }
1307            _ => {}
1308        }
1309        Ok(())
1310    }
1311
1312    fn write_numeric_type_declaration_local(&mut self, id: Word, numeric: NumericType) {
1313        let instruction = match numeric {
1314            NumericType::Scalar(scalar) => self.make_scalar(id, scalar),
1315            NumericType::Vector { size, scalar } => {
1316                let scalar_id = self.get_numeric_type_id(NumericType::Scalar(scalar));
1317                Instruction::type_vector(id, scalar_id, size)
1318            }
1319            NumericType::Matrix {
1320                columns,
1321                rows,
1322                scalar,
1323            } => {
1324                let column_id =
1325                    self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
1326                Instruction::type_matrix(id, column_id, columns)
1327            }
1328        };
1329
1330        instruction.to_words(&mut self.logical_layout.declarations);
1331    }
1332
1333    fn write_type_declaration_local(&mut self, id: Word, local_ty: LocalType) {
1334        let instruction = match local_ty {
1335            LocalType::Numeric(numeric) => {
1336                self.write_numeric_type_declaration_local(id, numeric);
1337                return;
1338            }
1339            LocalType::Pointer { base, class } => Instruction::type_pointer(id, class, base),
1340            LocalType::Image(image) => {
1341                let local_type = LocalType::Numeric(NumericType::Scalar(image.sampled_type));
1342                let type_id = self.get_localtype_id(local_type);
1343                Instruction::type_image(id, type_id, image.dim, image.flags, image.image_format)
1344            }
1345            LocalType::Sampler => Instruction::type_sampler(id),
1346            LocalType::SampledImage { image_type_id } => {
1347                Instruction::type_sampled_image(id, image_type_id)
1348            }
1349            LocalType::BindingArray { base, size } => {
1350                let inner_ty = self.get_handle_type_id(base);
1351                let scalar_id = self.get_constant_scalar(crate::Literal::U32(size));
1352                Instruction::type_array(id, inner_ty, scalar_id)
1353            }
1354            LocalType::AccelerationStructure => Instruction::type_acceleration_structure(id),
1355            LocalType::RayQuery => Instruction::type_ray_query(id),
1356        };
1357
1358        instruction.to_words(&mut self.logical_layout.declarations);
1359    }
1360
1361    fn write_type_declaration_arena(
1362        &mut self,
1363        module: &crate::Module,
1364        handle: Handle<crate::Type>,
1365    ) -> Result<Word, Error> {
1366        let ty = &module.types[handle];
1367        // If it's a type that needs SPIR-V capabilities, request them now.
1368        // This needs to happen regardless of the LocalType lookup succeeding,
1369        // because some types which map to the same LocalType have different
1370        // capability requirements. See https://github.com/gfx-rs/wgpu/issues/5569
1371        self.request_type_capabilities(&ty.inner)?;
1372        let id = if let Some(local) = self.localtype_from_inner(&ty.inner) {
1373            // This type can be represented as a `LocalType`, so check if we've
1374            // already written an instruction for it. If not, do so now, with
1375            // `write_type_declaration_local`.
1376            match self.lookup_type.entry(LookupType::Local(local)) {
1377                // We already have an id for this `LocalType`.
1378                Entry::Occupied(e) => *e.get(),
1379
1380                // It's a type we haven't seen before.
1381                Entry::Vacant(e) => {
1382                    let id = self.id_gen.next();
1383                    e.insert(id);
1384
1385                    self.write_type_declaration_local(id, local);
1386
1387                    id
1388                }
1389            }
1390        } else {
1391            use spirv::Decoration;
1392
1393            let id = self.id_gen.next();
1394            let instruction = match ty.inner {
1395                crate::TypeInner::Array { base, size, stride } => {
1396                    self.decorate(id, Decoration::ArrayStride, &[stride]);
1397
1398                    let type_id = self.get_handle_type_id(base);
1399                    match size.resolve(module.to_ctx())? {
1400                        crate::proc::IndexableLength::Known(length) => {
1401                            let length_id = self.get_index_constant(length);
1402                            Instruction::type_array(id, type_id, length_id)
1403                        }
1404                        crate::proc::IndexableLength::Dynamic => {
1405                            Instruction::type_runtime_array(id, type_id)
1406                        }
1407                    }
1408                }
1409                crate::TypeInner::BindingArray { base, size } => {
1410                    let type_id = self.get_handle_type_id(base);
1411                    match size.resolve(module.to_ctx())? {
1412                        crate::proc::IndexableLength::Known(length) => {
1413                            let length_id = self.get_index_constant(length);
1414                            Instruction::type_array(id, type_id, length_id)
1415                        }
1416                        crate::proc::IndexableLength::Dynamic => {
1417                            Instruction::type_runtime_array(id, type_id)
1418                        }
1419                    }
1420                }
1421                crate::TypeInner::Struct {
1422                    ref members,
1423                    span: _,
1424                } => {
1425                    let mut has_runtime_array = false;
1426                    let mut member_ids = Vec::with_capacity(members.len());
1427                    for (index, member) in members.iter().enumerate() {
1428                        let member_ty = &module.types[member.ty];
1429                        match member_ty.inner {
1430                            crate::TypeInner::Array {
1431                                base: _,
1432                                size: crate::ArraySize::Dynamic,
1433                                stride: _,
1434                            } => {
1435                                has_runtime_array = true;
1436                            }
1437                            _ => (),
1438                        }
1439                        self.decorate_struct_member(id, index, member, &module.types)?;
1440                        let member_id = self.get_handle_type_id(member.ty);
1441                        member_ids.push(member_id);
1442                    }
1443                    if has_runtime_array {
1444                        self.decorate(id, Decoration::Block, &[]);
1445                    }
1446                    Instruction::type_struct(id, member_ids.as_slice())
1447                }
1448
1449                // These all have TypeLocal representations, so they should have been
1450                // handled by `write_type_declaration_local` above.
1451                crate::TypeInner::Scalar(_)
1452                | crate::TypeInner::Atomic(_)
1453                | crate::TypeInner::Vector { .. }
1454                | crate::TypeInner::Matrix { .. }
1455                | crate::TypeInner::Pointer { .. }
1456                | crate::TypeInner::ValuePointer { .. }
1457                | crate::TypeInner::Image { .. }
1458                | crate::TypeInner::Sampler { .. }
1459                | crate::TypeInner::AccelerationStructure { .. }
1460                | crate::TypeInner::RayQuery { .. } => unreachable!(),
1461            };
1462
1463            instruction.to_words(&mut self.logical_layout.declarations);
1464            id
1465        };
1466
1467        // Add this handle as a new alias for that type.
1468        self.lookup_type.insert(LookupType::Handle(handle), id);
1469
1470        if self.flags.contains(WriterFlags::DEBUG) {
1471            if let Some(ref name) = ty.name {
1472                self.debugs.push(Instruction::name(id, name));
1473            }
1474        }
1475
1476        Ok(id)
1477    }
1478
1479    fn request_image_format_capabilities(
1480        &mut self,
1481        format: spirv::ImageFormat,
1482    ) -> Result<(), Error> {
1483        use spirv::ImageFormat as If;
1484        match format {
1485            If::Rg32f
1486            | If::Rg16f
1487            | If::R11fG11fB10f
1488            | If::R16f
1489            | If::Rgba16
1490            | If::Rgb10A2
1491            | If::Rg16
1492            | If::Rg8
1493            | If::R16
1494            | If::R8
1495            | If::Rgba16Snorm
1496            | If::Rg16Snorm
1497            | If::Rg8Snorm
1498            | If::R16Snorm
1499            | If::R8Snorm
1500            | If::Rg32i
1501            | If::Rg16i
1502            | If::Rg8i
1503            | If::R16i
1504            | If::R8i
1505            | If::Rgb10a2ui
1506            | If::Rg32ui
1507            | If::Rg16ui
1508            | If::Rg8ui
1509            | If::R16ui
1510            | If::R8ui => self.require_any(
1511                "storage image format",
1512                &[spirv::Capability::StorageImageExtendedFormats],
1513            ),
1514            If::R64ui | If::R64i => {
1515                self.use_extension("SPV_EXT_shader_image_int64");
1516                self.require_any(
1517                    "64-bit integer storage image format",
1518                    &[spirv::Capability::Int64ImageEXT],
1519                )
1520            }
1521            If::Unknown
1522            | If::Rgba32f
1523            | If::Rgba16f
1524            | If::R32f
1525            | If::Rgba8
1526            | If::Rgba8Snorm
1527            | If::Rgba32i
1528            | If::Rgba16i
1529            | If::Rgba8i
1530            | If::R32i
1531            | If::Rgba32ui
1532            | If::Rgba16ui
1533            | If::Rgba8ui
1534            | If::R32ui => Ok(()),
1535        }
1536    }
1537
1538    pub(super) fn get_index_constant(&mut self, index: Word) -> Word {
1539        self.get_constant_scalar(crate::Literal::U32(index))
1540    }
1541
1542    pub(super) fn get_constant_scalar_with(
1543        &mut self,
1544        value: u8,
1545        scalar: crate::Scalar,
1546    ) -> Result<Word, Error> {
1547        Ok(
1548            self.get_constant_scalar(crate::Literal::new(value, scalar).ok_or(
1549                Error::Validation("Unexpected kind and/or width for Literal"),
1550            )?),
1551        )
1552    }
1553
1554    pub(super) fn get_constant_scalar(&mut self, value: crate::Literal) -> Word {
1555        let scalar = CachedConstant::Literal(value.into());
1556        if let Some(&id) = self.cached_constants.get(&scalar) {
1557            return id;
1558        }
1559        let id = self.id_gen.next();
1560        self.write_constant_scalar(id, &value, None);
1561        self.cached_constants.insert(scalar, id);
1562        id
1563    }
1564
1565    fn write_constant_scalar(
1566        &mut self,
1567        id: Word,
1568        value: &crate::Literal,
1569        debug_name: Option<&String>,
1570    ) {
1571        if self.flags.contains(WriterFlags::DEBUG) {
1572            if let Some(name) = debug_name {
1573                self.debugs.push(Instruction::name(id, name));
1574            }
1575        }
1576        let type_id = self.get_numeric_type_id(NumericType::Scalar(value.scalar()));
1577        let instruction = match *value {
1578            crate::Literal::F64(value) => {
1579                let bits = value.to_bits();
1580                Instruction::constant_64bit(type_id, id, bits as u32, (bits >> 32) as u32)
1581            }
1582            crate::Literal::F32(value) => Instruction::constant_32bit(type_id, id, value.to_bits()),
1583            crate::Literal::F16(value) => {
1584                let low = value.to_bits();
1585                Instruction::constant_16bit(type_id, id, low as u32)
1586            }
1587            crate::Literal::U32(value) => Instruction::constant_32bit(type_id, id, value),
1588            crate::Literal::I32(value) => Instruction::constant_32bit(type_id, id, value as u32),
1589            crate::Literal::U64(value) => {
1590                Instruction::constant_64bit(type_id, id, value as u32, (value >> 32) as u32)
1591            }
1592            crate::Literal::I64(value) => {
1593                Instruction::constant_64bit(type_id, id, value as u32, (value >> 32) as u32)
1594            }
1595            crate::Literal::Bool(true) => Instruction::constant_true(type_id, id),
1596            crate::Literal::Bool(false) => Instruction::constant_false(type_id, id),
1597            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1598                unreachable!("Abstract types should not appear in IR presented to backends");
1599            }
1600        };
1601
1602        instruction.to_words(&mut self.logical_layout.declarations);
1603    }
1604
1605    pub(super) fn get_constant_composite(
1606        &mut self,
1607        ty: LookupType,
1608        constituent_ids: &[Word],
1609    ) -> Word {
1610        let composite = CachedConstant::Composite {
1611            ty,
1612            constituent_ids: constituent_ids.to_vec(),
1613        };
1614        if let Some(&id) = self.cached_constants.get(&composite) {
1615            return id;
1616        }
1617        let id = self.id_gen.next();
1618        self.write_constant_composite(id, ty, constituent_ids, None);
1619        self.cached_constants.insert(composite, id);
1620        id
1621    }
1622
1623    fn write_constant_composite(
1624        &mut self,
1625        id: Word,
1626        ty: LookupType,
1627        constituent_ids: &[Word],
1628        debug_name: Option<&String>,
1629    ) {
1630        if self.flags.contains(WriterFlags::DEBUG) {
1631            if let Some(name) = debug_name {
1632                self.debugs.push(Instruction::name(id, name));
1633            }
1634        }
1635        let type_id = self.get_type_id(ty);
1636        Instruction::constant_composite(type_id, id, constituent_ids)
1637            .to_words(&mut self.logical_layout.declarations);
1638    }
1639
1640    pub(super) fn get_constant_null(&mut self, type_id: Word) -> Word {
1641        let null = CachedConstant::ZeroValue(type_id);
1642        if let Some(&id) = self.cached_constants.get(&null) {
1643            return id;
1644        }
1645        let id = self.write_constant_null(type_id);
1646        self.cached_constants.insert(null, id);
1647        id
1648    }
1649
1650    pub(super) fn write_constant_null(&mut self, type_id: Word) -> Word {
1651        let null_id = self.id_gen.next();
1652        Instruction::constant_null(type_id, null_id)
1653            .to_words(&mut self.logical_layout.declarations);
1654        null_id
1655    }
1656
1657    fn write_constant_expr(
1658        &mut self,
1659        handle: Handle<crate::Expression>,
1660        ir_module: &crate::Module,
1661        mod_info: &ModuleInfo,
1662    ) -> Result<Word, Error> {
1663        let id = match ir_module.global_expressions[handle] {
1664            crate::Expression::Literal(literal) => self.get_constant_scalar(literal),
1665            crate::Expression::Constant(constant) => {
1666                let constant = &ir_module.constants[constant];
1667                self.constant_ids[constant.init]
1668            }
1669            crate::Expression::ZeroValue(ty) => {
1670                let type_id = self.get_handle_type_id(ty);
1671                self.get_constant_null(type_id)
1672            }
1673            crate::Expression::Compose { ty, ref components } => {
1674                let component_ids: Vec<_> = crate::proc::flatten_compose(
1675                    ty,
1676                    components,
1677                    &ir_module.global_expressions,
1678                    &ir_module.types,
1679                )
1680                .map(|component| self.constant_ids[component])
1681                .collect();
1682                self.get_constant_composite(LookupType::Handle(ty), component_ids.as_slice())
1683            }
1684            crate::Expression::Splat { size, value } => {
1685                let value_id = self.constant_ids[value];
1686                let component_ids = &[value_id; 4][..size as usize];
1687
1688                let ty = self.get_expression_lookup_type(&mod_info[handle]);
1689
1690                self.get_constant_composite(ty, component_ids)
1691            }
1692            _ => {
1693                return Err(Error::Override);
1694            }
1695        };
1696
1697        self.constant_ids[handle] = id;
1698
1699        Ok(id)
1700    }
1701
1702    pub(super) fn write_control_barrier(&mut self, flags: crate::Barrier, block: &mut Block) {
1703        let memory_scope = if flags.contains(crate::Barrier::STORAGE) {
1704            spirv::Scope::Device
1705        } else if flags.contains(crate::Barrier::SUB_GROUP) {
1706            spirv::Scope::Subgroup
1707        } else {
1708            spirv::Scope::Workgroup
1709        };
1710        let mut semantics = spirv::MemorySemantics::ACQUIRE_RELEASE;
1711        semantics.set(
1712            spirv::MemorySemantics::UNIFORM_MEMORY,
1713            flags.contains(crate::Barrier::STORAGE),
1714        );
1715        semantics.set(
1716            spirv::MemorySemantics::WORKGROUP_MEMORY,
1717            flags.contains(crate::Barrier::WORK_GROUP),
1718        );
1719        semantics.set(
1720            spirv::MemorySemantics::SUBGROUP_MEMORY,
1721            flags.contains(crate::Barrier::SUB_GROUP),
1722        );
1723        semantics.set(
1724            spirv::MemorySemantics::IMAGE_MEMORY,
1725            flags.contains(crate::Barrier::TEXTURE),
1726        );
1727        let exec_scope_id = if flags.contains(crate::Barrier::SUB_GROUP) {
1728            self.get_index_constant(spirv::Scope::Subgroup as u32)
1729        } else {
1730            self.get_index_constant(spirv::Scope::Workgroup as u32)
1731        };
1732        let mem_scope_id = self.get_index_constant(memory_scope as u32);
1733        let semantics_id = self.get_index_constant(semantics.bits());
1734        block.body.push(Instruction::control_barrier(
1735            exec_scope_id,
1736            mem_scope_id,
1737            semantics_id,
1738        ));
1739    }
1740
1741    pub(super) fn write_memory_barrier(&mut self, flags: crate::Barrier, block: &mut Block) {
1742        let mut semantics = spirv::MemorySemantics::ACQUIRE_RELEASE;
1743        semantics.set(
1744            spirv::MemorySemantics::UNIFORM_MEMORY,
1745            flags.contains(crate::Barrier::STORAGE),
1746        );
1747        semantics.set(
1748            spirv::MemorySemantics::WORKGROUP_MEMORY,
1749            flags.contains(crate::Barrier::WORK_GROUP),
1750        );
1751        semantics.set(
1752            spirv::MemorySemantics::SUBGROUP_MEMORY,
1753            flags.contains(crate::Barrier::SUB_GROUP),
1754        );
1755        semantics.set(
1756            spirv::MemorySemantics::IMAGE_MEMORY,
1757            flags.contains(crate::Barrier::TEXTURE),
1758        );
1759        let mem_scope_id = if flags.contains(crate::Barrier::STORAGE) {
1760            self.get_index_constant(spirv::Scope::Device as u32)
1761        } else if flags.contains(crate::Barrier::SUB_GROUP) {
1762            self.get_index_constant(spirv::Scope::Subgroup as u32)
1763        } else {
1764            self.get_index_constant(spirv::Scope::Workgroup as u32)
1765        };
1766        let semantics_id = self.get_index_constant(semantics.bits());
1767        block
1768            .body
1769            .push(Instruction::memory_barrier(mem_scope_id, semantics_id));
1770    }
1771
    /// Emit the prologue that zero-initializes the workgroup variables an entry point uses.
    ///
    /// Each `WorkGroup`-space global whose usage flags in `info` are non-empty
    /// gets an `OpStore` of an `OpConstantNull` of its type. The stores sit in a
    /// conditional block that runs only when the invocation's local invocation
    /// id equals `(0, 0, 0)`. A workgroup control barrier
    /// (`crate::Barrier::WORK_GROUP`) follows the conditional.
    ///
    /// * `entry_id` - label id for the first block emitted into `function`.
    /// * `local_invocation_id` - id of a `vec3<u32>` local-invocation-id value
    ///   the caller already has; it is used directly as an `OpIEqual` operand.
    ///   When `None`, a new `Input` variable decorated with the
    ///   `LocalInvocationId` built-in is declared, pushed onto
    ///   `interface.varying_ids`, and loaded in the first block.
    ///
    /// Returns `None` without emitting anything when no workgroup variable needs
    /// initializing. Otherwise returns a fresh label id. The last emitted block
    /// branches to that label, so the caller must start its next block with it.
    fn generate_workgroup_vars_init_block(
        &mut self,
        entry_id: Word,
        ir_module: &crate::Module,
        info: &FunctionInfo,
        local_invocation_id: Option<Word>,
        interface: &mut FunctionInterface,
        function: &mut Function,
    ) -> Option<Word> {
        // One null-store per used workgroup variable; these become the body of the `if`.
        let body = ir_module
            .global_variables
            .iter()
            .filter(|&(handle, var)| {
                !info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
            })
            .map(|(handle, var)| {
                // It's safe to use `var_id` here, not `access_id`, because only
                // variables in the `Uniform` and `StorageBuffer` address spaces
                // get wrapped, and we're initializing `WorkGroup` variables.
                let var_id = self.global_variables[handle].var_id;
                let var_type_id = self.get_handle_type_id(var.ty);
                let init_word = self.get_constant_null(var_type_id);
                Instruction::store(var_id, init_word, None)
            })
            .collect::<Vec<_>>();

        if body.is_empty() {
            return None;
        }

        let uint3_type_id = self.get_vec3u_type_id();

        let mut pre_if_block = Block::new(entry_id);

        // Reuse the caller's local invocation id value, or declare and load our own.
        let local_invocation_id = if let Some(local_invocation_id) = local_invocation_id {
            local_invocation_id
        } else {
            let varying_id = self.id_gen.next();
            let class = spirv::StorageClass::Input;
            let pointer_type_id = self.get_vec3u_pointer_type_id(class);

            Instruction::variable(pointer_type_id, varying_id, class, None)
                .to_words(&mut self.logical_layout.declarations);

            self.decorate(
                varying_id,
                spirv::Decoration::BuiltIn,
                &[spirv::BuiltIn::LocalInvocationId as u32],
            );

            // The new input variable must be listed in the entry point's interface.
            interface.varying_ids.push(varying_id);
            let id = self.id_gen.next();
            pre_if_block
                .body
                .push(Instruction::load(uint3_type_id, id, varying_id, None));

            id
        };

        // condition = all(local_invocation_id == vec3<u32>(0))
        let zero_id = self.get_constant_null(uint3_type_id);
        let bool3_type_id = self.get_vec3_bool_type_id();

        let eq_id = self.id_gen.next();
        pre_if_block.body.push(Instruction::binary(
            spirv::Op::IEqual,
            bool3_type_id,
            eq_id,
            local_invocation_id,
            zero_id,
        ));

        let condition_id = self.id_gen.next();
        let bool_type_id = self.get_bool_type_id();
        pre_if_block.body.push(Instruction::relational(
            spirv::Op::All,
            bool_type_id,
            condition_id,
            eq_id,
        ));

        // Structured selection: pre_if_block branches to accept_block or
        // straight to the merge block (post_if_block).
        let merge_id = self.id_gen.next();
        pre_if_block.body.push(Instruction::selection_merge(
            merge_id,
            spirv::SelectionControl::NONE,
        ));

        let accept_id = self.id_gen.next();
        function.consume(
            pre_if_block,
            Instruction::branch_conditional(condition_id, accept_id, merge_id),
        );

        let accept_block = Block {
            label_id: accept_id,
            body,
        };
        function.consume(accept_block, Instruction::branch(merge_id));

        let mut post_if_block = Block::new(merge_id);

        // Every invocation in the workgroup waits here until the stores above are done.
        self.write_control_barrier(crate::Barrier::WORK_GROUP, &mut post_if_block);

        // The caller continues emitting code in the block labelled `next_id`.
        let next_id = self.id_gen.next();
        function.consume(post_if_block, Instruction::branch(next_id));
        Some(next_id)
    }
1878
    /// Generate an `OpVariable` for one value in an [`EntryPoint`]'s IO interface.
    ///
    /// The [`Binding`]s of the arguments and result of an [`EntryPoint`]'s
    /// [`Function`] describe a SPIR-V shader interface. In SPIR-V, the
    /// interface is represented by global variables in the `Input` and `Output`
    /// storage classes, with decorations indicating which builtin or location
    /// each variable corresponds to.
    ///
    /// This function emits a single global `OpVariable` for a single value from
    /// the interface, and adds appropriate decorations to indicate which
    /// builtin or location it represents, how it should be interpolated, and so
    /// on. The `class` argument gives the variable's SPIR-V storage class,
    /// which should be either [`Input`] or [`Output`].
    ///
    /// The variable gets an `OpName` only when both [`WriterFlags::DEBUG`] and
    /// [`WriterFlags::LABEL_VARYINGS`] are set and `debug_name` is provided.
    ///
    /// Returns the id of the new variable. Errors from `require_any` (used for
    /// capability-gated sampling modes and built-ins) are propagated. An
    /// abstract scalar type on a fragment-stage `Input` built-in yields
    /// [`Error::Validation`].
    ///
    /// [`Binding`]: crate::Binding
    /// [`Function`]: crate::Function
    /// [`EntryPoint`]: crate::EntryPoint
    /// [`Input`]: spirv::StorageClass::Input
    /// [`Output`]: spirv::StorageClass::Output
    fn write_varying(
        &mut self,
        ir_module: &crate::Module,
        stage: crate::ShaderStage,
        class: spirv::StorageClass,
        debug_name: Option<&str>,
        ty: Handle<crate::Type>,
        binding: &crate::Binding,
    ) -> Result<Word, Error> {
        let id = self.id_gen.next();
        let pointer_type_id = self.get_handle_pointer_type_id(ty, class);
        Instruction::variable(pointer_type_id, id, class, None)
            .to_words(&mut self.logical_layout.declarations);

        // Both flags must be set: `contains` on a union checks every bit.
        if self
            .flags
            .contains(WriterFlags::DEBUG | WriterFlags::LABEL_VARYINGS)
        {
            if let Some(name) = debug_name {
                self.debugs.push(Instruction::name(id, name));
            }
        }

        use spirv::{BuiltIn, Decoration};

        match *binding {
            crate::Binding::Location {
                location,
                interpolation,
                sampling,
                blend_src,
            } => {
                self.decorate(id, Decoration::Location, &[location]);

                let no_decorations =
                    // VUID-StandaloneSpirv-Flat-06202
                    // > The Flat, NoPerspective, Sample, and Centroid decorations
                    // > must not be used on variables with the Input storage class in a vertex shader
                    (class == spirv::StorageClass::Input && stage == crate::ShaderStage::Vertex) ||
                    // VUID-StandaloneSpirv-Flat-06201
                    // > The Flat, NoPerspective, Sample, and Centroid decorations
                    // > must not be used on variables with the Output storage class in a fragment shader
                    (class == spirv::StorageClass::Output && stage == crate::ShaderStage::Fragment);

                // Everywhere else, translate the IR interpolation/sampling qualifiers.
                if !no_decorations {
                    match interpolation {
                        // Perspective-correct interpolation is the default in SPIR-V.
                        None | Some(crate::Interpolation::Perspective) => (),
                        Some(crate::Interpolation::Flat) => {
                            self.decorate(id, Decoration::Flat, &[]);
                        }
                        Some(crate::Interpolation::Linear) => {
                            self.decorate(id, Decoration::NoPerspective, &[]);
                        }
                    }
                    match sampling {
                        // Center sampling is the default in SPIR-V.
                        None
                        | Some(
                            crate::Sampling::Center
                            | crate::Sampling::First
                            | crate::Sampling::Either,
                        ) => (),
                        Some(crate::Sampling::Centroid) => {
                            self.decorate(id, Decoration::Centroid, &[]);
                        }
                        Some(crate::Sampling::Sample) => {
                            self.require_any(
                                "per-sample interpolation",
                                &[spirv::Capability::SampleRateShading],
                            )?;
                            self.decorate(id, Decoration::Sample, &[]);
                        }
                    }
                }
                // `blend_src` is emitted as the SPIR-V `Index` decoration
                // (the blend-equation input index).
                if let Some(blend_src) = blend_src {
                    self.decorate(id, Decoration::Index, &[blend_src]);
                }
            }
            crate::Binding::BuiltIn(built_in) => {
                use crate::BuiltIn as Bi;
                // Map the IR built-in to its SPIR-V counterpart, requesting any
                // capability it needs along the way.
                let built_in = match built_in {
                    Bi::Position { invariant } => {
                        if invariant {
                            self.decorate(id, Decoration::Invariant, &[]);
                        }

                        // The same IR built-in is `Position` on outputs and
                        // `FragCoord` on inputs.
                        if class == spirv::StorageClass::Output {
                            BuiltIn::Position
                        } else {
                            BuiltIn::FragCoord
                        }
                    }
                    Bi::ViewIndex => {
                        self.require_any("`view_index` built-in", &[spirv::Capability::MultiView])?;
                        BuiltIn::ViewIndex
                    }
                    // vertex
                    Bi::BaseInstance => BuiltIn::BaseInstance,
                    Bi::BaseVertex => BuiltIn::BaseVertex,
                    Bi::ClipDistance => {
                        self.require_any(
                            "`clip_distance` built-in",
                            &[spirv::Capability::ClipDistance],
                        )?;
                        BuiltIn::ClipDistance
                    }
                    Bi::CullDistance => {
                        self.require_any(
                            "`cull_distance` built-in",
                            &[spirv::Capability::CullDistance],
                        )?;
                        BuiltIn::CullDistance
                    }
                    Bi::InstanceIndex => BuiltIn::InstanceIndex,
                    Bi::PointSize => BuiltIn::PointSize,
                    Bi::VertexIndex => BuiltIn::VertexIndex,
                    Bi::DrawID => BuiltIn::DrawIndex,
                    // fragment
                    Bi::FragDepth => BuiltIn::FragDepth,
                    Bi::PointCoord => BuiltIn::PointCoord,
                    Bi::FrontFacing => BuiltIn::FrontFacing,
                    Bi::PrimitiveIndex => {
                        self.require_any(
                            "`primitive_index` built-in",
                            &[spirv::Capability::Geometry],
                        )?;
                        BuiltIn::PrimitiveId
                    }
                    Bi::SampleIndex => {
                        self.require_any(
                            "`sample_index` built-in",
                            &[spirv::Capability::SampleRateShading],
                        )?;

                        BuiltIn::SampleId
                    }
                    Bi::SampleMask => BuiltIn::SampleMask,
                    // compute
                    Bi::GlobalInvocationId => BuiltIn::GlobalInvocationId,
                    Bi::LocalInvocationId => BuiltIn::LocalInvocationId,
                    Bi::LocalInvocationIndex => BuiltIn::LocalInvocationIndex,
                    Bi::WorkGroupId => BuiltIn::WorkgroupId,
                    Bi::WorkGroupSize => BuiltIn::WorkgroupSize,
                    Bi::NumWorkGroups => BuiltIn::NumWorkgroups,
                    // Subgroup
                    Bi::NumSubgroups => {
                        self.require_any(
                            "`num_subgroups` built-in",
                            &[spirv::Capability::GroupNonUniform],
                        )?;
                        BuiltIn::NumSubgroups
                    }
                    Bi::SubgroupId => {
                        self.require_any(
                            "`subgroup_id` built-in",
                            &[spirv::Capability::GroupNonUniform],
                        )?;
                        BuiltIn::SubgroupId
                    }
                    Bi::SubgroupSize => {
                        self.require_any(
                            "`subgroup_size` built-in",
                            &[
                                spirv::Capability::GroupNonUniform,
                                spirv::Capability::SubgroupBallotKHR,
                            ],
                        )?;
                        BuiltIn::SubgroupSize
                    }
                    Bi::SubgroupInvocationId => {
                        self.require_any(
                            "`subgroup_invocation_id` built-in",
                            &[
                                spirv::Capability::GroupNonUniform,
                                spirv::Capability::SubgroupBallotKHR,
                            ],
                        )?;
                        BuiltIn::SubgroupLocalInvocationId
                    }
                };

                self.decorate(id, Decoration::BuiltIn, &[built_in as u32]);

                use crate::ScalarKind as Sk;

                // Per the Vulkan spec, `VUID-StandaloneSpirv-Flat-04744`:
                //
                // > Any variable with integer or double-precision floating-
                // > point type and with Input storage class in a fragment
                // > shader, must be decorated Flat
                //
                // NOTE(review): only integer and bool scalars/vectors are
                // marked flat here. `Sk::Float` is never flat regardless of
                // width, so the double-precision half of the rule is not
                // implemented. This is harmless only while no built-in has an
                // `f64` type; confirm if that changes.
                if class == spirv::StorageClass::Input && stage == crate::ShaderStage::Fragment {
                    let is_flat = match ir_module.types[ty].inner {
                        crate::TypeInner::Scalar(scalar)
                        | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
                            Sk::Uint | Sk::Sint | Sk::Bool => true,
                            Sk::Float => false,
                            Sk::AbstractInt | Sk::AbstractFloat => {
                                return Err(Error::Validation(
                                    "Abstract types should not appear in IR presented to backends",
                                ))
                            }
                        },
                        _ => false,
                    };

                    if is_flat {
                        self.decorate(id, Decoration::Flat, &[]);
                    }
                }
            }
        }

        Ok(id)
    }
2113
2114    fn write_global_variable(
2115        &mut self,
2116        ir_module: &crate::Module,
2117        global_variable: &crate::GlobalVariable,
2118    ) -> Result<Word, Error> {
2119        use spirv::Decoration;
2120
2121        let id = self.id_gen.next();
2122        let class = map_storage_class(global_variable.space);
2123
2124        //self.check(class.required_capabilities())?;
2125
2126        if self.flags.contains(WriterFlags::DEBUG) {
2127            if let Some(ref name) = global_variable.name {
2128                self.debugs.push(Instruction::name(id, name));
2129            }
2130        }
2131
2132        let storage_access = match global_variable.space {
2133            crate::AddressSpace::Storage { access } => Some(access),
2134            _ => match ir_module.types[global_variable.ty].inner {
2135                crate::TypeInner::Image {
2136                    class: crate::ImageClass::Storage { access, .. },
2137                    ..
2138                } => Some(access),
2139                _ => None,
2140            },
2141        };
2142        if let Some(storage_access) = storage_access {
2143            if !storage_access.contains(crate::StorageAccess::LOAD) {
2144                self.decorate(id, Decoration::NonReadable, &[]);
2145            }
2146            if !storage_access.contains(crate::StorageAccess::STORE) {
2147                self.decorate(id, Decoration::NonWritable, &[]);
2148            }
2149        }
2150
2151        // Note: we should be able to substitute `binding_array<Foo, 0>`,
2152        // but there is still code that tries to register the pre-substituted type,
2153        // and it is failing on 0.
2154        let mut substitute_inner_type_lookup = None;
2155        if let Some(ref res_binding) = global_variable.binding {
2156            self.decorate(id, Decoration::DescriptorSet, &[res_binding.group]);
2157            self.decorate(id, Decoration::Binding, &[res_binding.binding]);
2158
2159            if let Some(&BindingInfo {
2160                binding_array_size: Some(remapped_binding_array_size),
2161            }) = self.binding_map.get(res_binding)
2162            {
2163                if let crate::TypeInner::BindingArray { base, .. } =
2164                    ir_module.types[global_variable.ty].inner
2165                {
2166                    let binding_array_type_id =
2167                        self.get_type_id(LookupType::Local(LocalType::BindingArray {
2168                            base,
2169                            size: remapped_binding_array_size,
2170                        }));
2171                    substitute_inner_type_lookup = Some(LookupType::Local(LocalType::Pointer {
2172                        base: binding_array_type_id,
2173                        class,
2174                    }));
2175                }
2176            }
2177        };
2178
2179        let init_word = global_variable
2180            .init
2181            .map(|constant| self.constant_ids[constant]);
2182        let inner_type_id = self.get_type_id(
2183            substitute_inner_type_lookup.unwrap_or(LookupType::Handle(global_variable.ty)),
2184        );
2185
2186        // generate the wrapping structure if needed
2187        let pointer_type_id = if global_needs_wrapper(ir_module, global_variable) {
2188            let wrapper_type_id = self.id_gen.next();
2189
2190            self.decorate(wrapper_type_id, Decoration::Block, &[]);
2191            let member = crate::StructMember {
2192                name: None,
2193                ty: global_variable.ty,
2194                binding: None,
2195                offset: 0,
2196            };
2197            self.decorate_struct_member(wrapper_type_id, 0, &member, &ir_module.types)?;
2198
2199            Instruction::type_struct(wrapper_type_id, &[inner_type_id])
2200                .to_words(&mut self.logical_layout.declarations);
2201
2202            let pointer_type_id = self.id_gen.next();
2203            Instruction::type_pointer(pointer_type_id, class, wrapper_type_id)
2204                .to_words(&mut self.logical_layout.declarations);
2205
2206            pointer_type_id
2207        } else {
2208            // This is a global variable in the Storage address space. The only
2209            // way it could have `global_needs_wrapper() == false` is if it has
2210            // a runtime-sized or binding array.
2211            // Runtime-sized arrays were decorated when iterating through struct content.
2212            // Now binding arrays require Block decorating.
2213            if let crate::AddressSpace::Storage { .. } = global_variable.space {
2214                match ir_module.types[global_variable.ty].inner {
2215                    crate::TypeInner::BindingArray { base, .. } => {
2216                        let ty = &ir_module.types[base];
2217                        let mut should_decorate = true;
2218                        // Check if the type has a runtime array.
2219                        // A normal runtime array gets validated out,
2220                        // so only structs can be with runtime arrays
2221                        if let crate::TypeInner::Struct { ref members, .. } = ty.inner {
2222                            // only the last member in a struct can be dynamically sized
2223                            if let Some(last_member) = members.last() {
2224                                if let &crate::TypeInner::Array {
2225                                    size: crate::ArraySize::Dynamic,
2226                                    ..
2227                                } = &ir_module.types[last_member.ty].inner
2228                                {
2229                                    should_decorate = false;
2230                                }
2231                            }
2232                        }
2233                        if should_decorate {
2234                            let decorated_id = self.get_handle_type_id(base);
2235                            self.decorate(decorated_id, Decoration::Block, &[]);
2236                        }
2237                    }
2238                    _ => (),
2239                };
2240            }
2241            if substitute_inner_type_lookup.is_some() {
2242                inner_type_id
2243            } else {
2244                self.get_handle_pointer_type_id(global_variable.ty, class)
2245            }
2246        };
2247
2248        let init_word = match (global_variable.space, self.zero_initialize_workgroup_memory) {
2249            (crate::AddressSpace::Private, _)
2250            | (crate::AddressSpace::WorkGroup, super::ZeroInitializeWorkgroupMemoryMode::Native) => {
2251                init_word.or_else(|| Some(self.get_constant_null(inner_type_id)))
2252            }
2253            _ => init_word,
2254        };
2255
2256        Instruction::variable(pointer_type_id, id, class, init_word)
2257            .to_words(&mut self.logical_layout.declarations);
2258        Ok(id)
2259    }
2260
2261    /// Write the necessary decorations for a struct member.
2262    ///
2263    /// Emit decorations for the `index`'th member of the struct type
2264    /// designated by `struct_id`, described by `member`.
2265    fn decorate_struct_member(
2266        &mut self,
2267        struct_id: Word,
2268        index: usize,
2269        member: &crate::StructMember,
2270        arena: &UniqueArena<crate::Type>,
2271    ) -> Result<(), Error> {
2272        use spirv::Decoration;
2273
2274        self.annotations.push(Instruction::member_decorate(
2275            struct_id,
2276            index as u32,
2277            Decoration::Offset,
2278            &[member.offset],
2279        ));
2280
2281        if self.flags.contains(WriterFlags::DEBUG) {
2282            if let Some(ref name) = member.name {
2283                self.debugs
2284                    .push(Instruction::member_name(struct_id, index as u32, name));
2285            }
2286        }
2287
2288        // Matrices and (potentially nested) arrays of matrices both require decorations,
2289        // so "see through" any arrays to determine if they're needed.
2290        let mut member_array_subty_inner = &arena[member.ty].inner;
2291        while let crate::TypeInner::Array { base, .. } = *member_array_subty_inner {
2292            member_array_subty_inner = &arena[base].inner;
2293        }
2294
2295        if let crate::TypeInner::Matrix {
2296            columns: _,
2297            rows,
2298            scalar,
2299        } = *member_array_subty_inner
2300        {
2301            let byte_stride = Alignment::from(rows) * scalar.width as u32;
2302            self.annotations.push(Instruction::member_decorate(
2303                struct_id,
2304                index as u32,
2305                Decoration::ColMajor,
2306                &[],
2307            ));
2308            self.annotations.push(Instruction::member_decorate(
2309                struct_id,
2310                index as u32,
2311                Decoration::MatrixStride,
2312                &[byte_stride],
2313            ));
2314        }
2315
2316        Ok(())
2317    }
2318
2319    pub(super) fn get_function_type(&mut self, lookup_function_type: LookupFunctionType) -> Word {
2320        match self
2321            .lookup_function_type
2322            .entry(lookup_function_type.clone())
2323        {
2324            Entry::Occupied(e) => *e.get(),
2325            Entry::Vacant(_) => {
2326                let id = self.id_gen.next();
2327                let instruction = Instruction::type_function(
2328                    id,
2329                    lookup_function_type.return_type_id,
2330                    &lookup_function_type.parameter_type_ids,
2331                );
2332                instruction.to_words(&mut self.logical_layout.declarations);
2333                self.lookup_function_type.insert(lookup_function_type, id);
2334                id
2335            }
2336        }
2337    }
2338
2339    fn write_physical_layout(&mut self) {
2340        self.physical_layout.bound = self.id_gen.0 + 1;
2341    }
2342
2343    fn write_logical_layout(
2344        &mut self,
2345        ir_module: &crate::Module,
2346        mod_info: &ModuleInfo,
2347        ep_index: Option<usize>,
2348        debug_info: &Option<DebugInfo>,
2349    ) -> Result<(), Error> {
2350        fn has_view_index_check(
2351            ir_module: &crate::Module,
2352            binding: Option<&crate::Binding>,
2353            ty: Handle<crate::Type>,
2354        ) -> bool {
2355            match ir_module.types[ty].inner {
2356                crate::TypeInner::Struct { ref members, .. } => members.iter().any(|member| {
2357                    has_view_index_check(ir_module, member.binding.as_ref(), member.ty)
2358                }),
2359                _ => binding == Some(&crate::Binding::BuiltIn(crate::BuiltIn::ViewIndex)),
2360            }
2361        }
2362
2363        let has_storage_buffers =
2364            ir_module
2365                .global_variables
2366                .iter()
2367                .any(|(_, var)| match var.space {
2368                    crate::AddressSpace::Storage { .. } => true,
2369                    _ => false,
2370                });
2371        let has_view_index = ir_module
2372            .entry_points
2373            .iter()
2374            .flat_map(|entry| entry.function.arguments.iter())
2375            .any(|arg| has_view_index_check(ir_module, arg.binding.as_ref(), arg.ty));
2376        let mut has_ray_query = ir_module.special_types.ray_desc.is_some()
2377            | ir_module.special_types.ray_intersection.is_some();
2378        let has_vertex_return = ir_module.special_types.ray_vertex_return.is_some();
2379
2380        for (_, &crate::Type { ref inner, .. }) in ir_module.types.iter() {
2381            // spirv does not know whether these have vertex return - that is done by us
2382            if let &crate::TypeInner::AccelerationStructure { .. }
2383            | &crate::TypeInner::RayQuery { .. } = inner
2384            {
2385                has_ray_query = true
2386            }
2387        }
2388
2389        if self.physical_layout.version < 0x10300 && has_storage_buffers {
2390            // enable the storage buffer class on < SPV-1.3
2391            Instruction::extension("SPV_KHR_storage_buffer_storage_class")
2392                .to_words(&mut self.logical_layout.extensions);
2393        }
2394        if has_view_index {
2395            Instruction::extension("SPV_KHR_multiview")
2396                .to_words(&mut self.logical_layout.extensions)
2397        }
2398        if has_ray_query {
2399            Instruction::extension("SPV_KHR_ray_query")
2400                .to_words(&mut self.logical_layout.extensions)
2401        }
2402        if has_vertex_return {
2403            Instruction::extension("SPV_KHR_ray_tracing_position_fetch")
2404                .to_words(&mut self.logical_layout.extensions);
2405        }
2406        Instruction::type_void(self.void_type).to_words(&mut self.logical_layout.declarations);
2407        Instruction::ext_inst_import(self.gl450_ext_inst_id, "GLSL.std.450")
2408            .to_words(&mut self.logical_layout.ext_inst_imports);
2409
2410        let mut debug_info_inner = None;
2411        if self.flags.contains(WriterFlags::DEBUG) {
2412            if let Some(debug_info) = debug_info.as_ref() {
2413                let source_file_id = self.id_gen.next();
2414                self.debugs.push(Instruction::string(
2415                    &debug_info.file_name.to_string_lossy(),
2416                    source_file_id,
2417                ));
2418
2419                debug_info_inner = Some(DebugInfoInner {
2420                    source_code: debug_info.source_code,
2421                    source_file_id,
2422                });
2423                self.debugs.append(&mut Instruction::source_auto_continued(
2424                    debug_info.language,
2425                    0,
2426                    &debug_info_inner,
2427                ));
2428            }
2429        }
2430
2431        // write all types
2432        for (handle, _) in ir_module.types.iter() {
2433            self.write_type_declaration_arena(ir_module, handle)?;
2434        }
2435
2436        // write all const-expressions as constants
2437        self.constant_ids
2438            .resize(ir_module.global_expressions.len(), 0);
2439        for (handle, _) in ir_module.global_expressions.iter() {
2440            self.write_constant_expr(handle, ir_module, mod_info)?;
2441        }
2442        debug_assert!(self.constant_ids.iter().all(|&id| id != 0));
2443
2444        // write the name of constants on their respective const-expression initializer
2445        if self.flags.contains(WriterFlags::DEBUG) {
2446            for (_, constant) in ir_module.constants.iter() {
2447                if let Some(ref name) = constant.name {
2448                    let id = self.constant_ids[constant.init];
2449                    self.debugs.push(Instruction::name(id, name));
2450                }
2451            }
2452        }
2453
2454        // write all global variables
2455        for (handle, var) in ir_module.global_variables.iter() {
2456            // If a single entry point was specified, only write `OpVariable` instructions
2457            // for the globals it actually uses. Emit dummies for the others,
2458            // to preserve the indices in `global_variables`.
2459            let gvar = match ep_index {
2460                Some(index) if mod_info.get_entry_point(index)[handle].is_empty() => {
2461                    GlobalVariable::dummy()
2462                }
2463                _ => {
2464                    let id = self.write_global_variable(ir_module, var)?;
2465                    GlobalVariable::new(id)
2466                }
2467            };
2468            self.global_variables.insert(handle, gvar);
2469        }
2470
2471        // write all functions
2472        for (handle, ir_function) in ir_module.functions.iter() {
2473            let info = &mod_info[handle];
2474            if let Some(index) = ep_index {
2475                let ep_info = mod_info.get_entry_point(index);
2476                // If this function uses globals that we omitted from the SPIR-V
2477                // because the entry point and its callees didn't use them,
2478                // then we must skip it.
2479                if !ep_info.dominates_global_use(info) {
2480                    log::info!("Skip function {:?}", ir_function.name);
2481                    continue;
2482                }
2483
2484                // Skip functions that that are not compatible with this entry point's stage.
2485                //
2486                // When validation is enabled, it rejects modules whose entry points try to call
2487                // incompatible functions, so if we got this far, then any functions incompatible
2488                // with our selected entry point must not be used.
2489                //
2490                // When validation is disabled, `fun_info.available_stages` is always just
2491                // `ShaderStages::all()`, so this will write all functions in the module, and
2492                // the downstream GLSL compiler will catch any problems.
2493                if !info.available_stages.contains(ep_info.available_stages) {
2494                    continue;
2495                }
2496            }
2497            let id = self.write_function(ir_function, info, ir_module, None, &debug_info_inner)?;
2498            self.lookup_function.insert(handle, id);
2499        }
2500
2501        // write all or one entry points
2502        for (index, ir_ep) in ir_module.entry_points.iter().enumerate() {
2503            if ep_index.is_some() && ep_index != Some(index) {
2504                continue;
2505            }
2506            let info = mod_info.get_entry_point(index);
2507            let ep_instruction =
2508                self.write_entry_point(ir_ep, info, ir_module, &debug_info_inner)?;
2509            ep_instruction.to_words(&mut self.logical_layout.entry_points);
2510        }
2511
2512        for capability in self.capabilities_used.iter() {
2513            Instruction::capability(*capability).to_words(&mut self.logical_layout.capabilities);
2514        }
2515        for extension in self.extensions_used.iter() {
2516            Instruction::extension(extension).to_words(&mut self.logical_layout.extensions);
2517        }
2518        if ir_module.entry_points.is_empty() {
2519            // SPIR-V doesn't like modules without entry points
2520            Instruction::capability(spirv::Capability::Linkage)
2521                .to_words(&mut self.logical_layout.capabilities);
2522        }
2523
2524        let addressing_model = spirv::AddressingModel::Logical;
2525        let memory_model = spirv::MemoryModel::GLSL450;
2526        //self.check(addressing_model.required_capabilities())?;
2527        //self.check(memory_model.required_capabilities())?;
2528
2529        Instruction::memory_model(addressing_model, memory_model)
2530            .to_words(&mut self.logical_layout.memory_model);
2531
2532        if self.flags.contains(WriterFlags::DEBUG) {
2533            for debug in self.debugs.iter() {
2534                debug.to_words(&mut self.logical_layout.debugs);
2535            }
2536        }
2537
2538        for annotation in self.annotations.iter() {
2539            annotation.to_words(&mut self.logical_layout.annotations);
2540        }
2541
2542        Ok(())
2543    }
2544
2545    pub fn write(
2546        &mut self,
2547        ir_module: &crate::Module,
2548        info: &ModuleInfo,
2549        pipeline_options: Option<&PipelineOptions>,
2550        debug_info: &Option<DebugInfo>,
2551        words: &mut Vec<Word>,
2552    ) -> Result<(), Error> {
2553        self.reset();
2554
2555        // Try to find the entry point and corresponding index
2556        let ep_index = match pipeline_options {
2557            Some(po) => {
2558                let index = ir_module
2559                    .entry_points
2560                    .iter()
2561                    .position(|ep| po.shader_stage == ep.stage && po.entry_point == ep.name)
2562                    .ok_or(Error::EntryPointNotFound)?;
2563                Some(index)
2564            }
2565            None => None,
2566        };
2567
2568        self.write_logical_layout(ir_module, info, ep_index, debug_info)?;
2569        self.write_physical_layout();
2570
2571        self.physical_layout.in_words(words);
2572        self.logical_layout.in_words(words);
2573        Ok(())
2574    }
2575
2576    /// Return the set of capabilities the last module written used.
2577    pub const fn get_capabilities_used(&self) -> &crate::FastIndexSet<spirv::Capability> {
2578        &self.capabilities_used
2579    }
2580
2581    pub fn decorate_non_uniform_binding_array_access(&mut self, id: Word) -> Result<(), Error> {
2582        self.require_any("NonUniformEXT", &[spirv::Capability::ShaderNonUniform])?;
2583        self.use_extension("SPV_EXT_descriptor_indexing");
2584        self.decorate(id, spirv::Decoration::NonUniform, &[]);
2585        Ok(())
2586    }
2587}
2588
#[test]
fn test_write_physical_layout() {
    let mut writer = Writer::new(&Options::default()).unwrap();
    // A freshly constructed writer has not computed its ID bound yet.
    assert_eq!(writer.physical_layout.bound, 0);

    writer.write_physical_layout();
    // NOTE(review): 3 implies `Writer::new` already allocated two IDs
    // (bound = highest allocated ID + 1); confirm against `Writer::new`.
    let expected_bound = 3;
    assert_eq!(writer.physical_layout.bound, expected_bound);
}