naga/proc/
mod.rs

/*!
[`Module`](super::Module) processing functionality.
*/
4
5mod constant_evaluator;
6mod emitter;
7pub mod index;
8mod layouter;
9mod namer;
10mod terminator;
11mod typifier;
12
13pub use constant_evaluator::{
14    ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker,
15};
16pub use emitter::Emitter;
17pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
18pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
19pub use namer::{EntryPointIndex, NameKey, Namer};
20pub use terminator::ensure_block_returns;
21pub use typifier::{ResolveContext, ResolveError, TypeResolution};
22
23impl From<super::StorageFormat> for super::Scalar {
24    fn from(format: super::StorageFormat) -> Self {
25        use super::{ScalarKind as Sk, StorageFormat as Sf};
26        let kind = match format {
27            Sf::R8Unorm => Sk::Float,
28            Sf::R8Snorm => Sk::Float,
29            Sf::R8Uint => Sk::Uint,
30            Sf::R8Sint => Sk::Sint,
31            Sf::R16Uint => Sk::Uint,
32            Sf::R16Sint => Sk::Sint,
33            Sf::R16Float => Sk::Float,
34            Sf::Rg8Unorm => Sk::Float,
35            Sf::Rg8Snorm => Sk::Float,
36            Sf::Rg8Uint => Sk::Uint,
37            Sf::Rg8Sint => Sk::Sint,
38            Sf::R32Uint => Sk::Uint,
39            Sf::R32Sint => Sk::Sint,
40            Sf::R32Float => Sk::Float,
41            Sf::Rg16Uint => Sk::Uint,
42            Sf::Rg16Sint => Sk::Sint,
43            Sf::Rg16Float => Sk::Float,
44            Sf::Rgba8Unorm => Sk::Float,
45            Sf::Rgba8Snorm => Sk::Float,
46            Sf::Rgba8Uint => Sk::Uint,
47            Sf::Rgba8Sint => Sk::Sint,
48            Sf::Bgra8Unorm => Sk::Float,
49            Sf::Rgb10a2Uint => Sk::Uint,
50            Sf::Rgb10a2Unorm => Sk::Float,
51            Sf::Rg11b10Ufloat => Sk::Float,
52            Sf::Rg32Uint => Sk::Uint,
53            Sf::Rg32Sint => Sk::Sint,
54            Sf::Rg32Float => Sk::Float,
55            Sf::Rgba16Uint => Sk::Uint,
56            Sf::Rgba16Sint => Sk::Sint,
57            Sf::Rgba16Float => Sk::Float,
58            Sf::Rgba32Uint => Sk::Uint,
59            Sf::Rgba32Sint => Sk::Sint,
60            Sf::Rgba32Float => Sk::Float,
61            Sf::R16Unorm => Sk::Float,
62            Sf::R16Snorm => Sk::Float,
63            Sf::Rg16Unorm => Sk::Float,
64            Sf::Rg16Snorm => Sk::Float,
65            Sf::Rgba16Unorm => Sk::Float,
66            Sf::Rgba16Snorm => Sk::Float,
67        };
68        super::Scalar { kind, width: 4 }
69    }
70}
71
72impl super::ScalarKind {
73    pub const fn is_numeric(self) -> bool {
74        match self {
75            crate::ScalarKind::Sint
76            | crate::ScalarKind::Uint
77            | crate::ScalarKind::Float
78            | crate::ScalarKind::AbstractInt
79            | crate::ScalarKind::AbstractFloat => true,
80            crate::ScalarKind::Bool => false,
81        }
82    }
83}
84
85impl super::Scalar {
86    pub const I32: Self = Self {
87        kind: crate::ScalarKind::Sint,
88        width: 4,
89    };
90    pub const U32: Self = Self {
91        kind: crate::ScalarKind::Uint,
92        width: 4,
93    };
94    pub const F32: Self = Self {
95        kind: crate::ScalarKind::Float,
96        width: 4,
97    };
98    pub const F64: Self = Self {
99        kind: crate::ScalarKind::Float,
100        width: 8,
101    };
102    pub const I64: Self = Self {
103        kind: crate::ScalarKind::Sint,
104        width: 8,
105    };
106    pub const U64: Self = Self {
107        kind: crate::ScalarKind::Uint,
108        width: 8,
109    };
110    pub const BOOL: Self = Self {
111        kind: crate::ScalarKind::Bool,
112        width: crate::BOOL_WIDTH,
113    };
114    pub const ABSTRACT_INT: Self = Self {
115        kind: crate::ScalarKind::AbstractInt,
116        width: crate::ABSTRACT_WIDTH,
117    };
118    pub const ABSTRACT_FLOAT: Self = Self {
119        kind: crate::ScalarKind::AbstractFloat,
120        width: crate::ABSTRACT_WIDTH,
121    };
122
123    pub const fn is_abstract(self) -> bool {
124        match self.kind {
125            crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => true,
126            crate::ScalarKind::Sint
127            | crate::ScalarKind::Uint
128            | crate::ScalarKind::Float
129            | crate::ScalarKind::Bool => false,
130        }
131    }
132
133    /// Construct a float `Scalar` with the given width.
134    ///
135    /// This is especially common when dealing with
136    /// `TypeInner::Matrix`, where the scalar kind is implicit.
137    pub const fn float(width: crate::Bytes) -> Self {
138        Self {
139            kind: crate::ScalarKind::Float,
140            width,
141        }
142    }
143
144    pub const fn to_inner_scalar(self) -> crate::TypeInner {
145        crate::TypeInner::Scalar(self)
146    }
147
148    pub const fn to_inner_vector(self, size: crate::VectorSize) -> crate::TypeInner {
149        crate::TypeInner::Vector { size, scalar: self }
150    }
151
152    pub const fn to_inner_atomic(self) -> crate::TypeInner {
153        crate::TypeInner::Atomic(self)
154    }
155}
156
157#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
158pub enum HashableLiteral {
159    F64(u64),
160    F32(u32),
161    U32(u32),
162    I32(i32),
163    U64(u64),
164    I64(i64),
165    Bool(bool),
166    AbstractInt(i64),
167    AbstractFloat(u64),
168}
169
170impl From<crate::Literal> for HashableLiteral {
171    fn from(l: crate::Literal) -> Self {
172        match l {
173            crate::Literal::F64(v) => Self::F64(v.to_bits()),
174            crate::Literal::F32(v) => Self::F32(v.to_bits()),
175            crate::Literal::U32(v) => Self::U32(v),
176            crate::Literal::I32(v) => Self::I32(v),
177            crate::Literal::U64(v) => Self::U64(v),
178            crate::Literal::I64(v) => Self::I64(v),
179            crate::Literal::Bool(v) => Self::Bool(v),
180            crate::Literal::AbstractInt(v) => Self::AbstractInt(v),
181            crate::Literal::AbstractFloat(v) => Self::AbstractFloat(v.to_bits()),
182        }
183    }
184}
185
186impl crate::Literal {
187    pub const fn new(value: u8, scalar: crate::Scalar) -> Option<Self> {
188        match (value, scalar.kind, scalar.width) {
189            (value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)),
190            (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)),
191            (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)),
192            (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)),
193            (value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)),
194            (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)),
195            (1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)),
196            (0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)),
197            _ => None,
198        }
199    }
200
201    pub const fn zero(scalar: crate::Scalar) -> Option<Self> {
202        Self::new(0, scalar)
203    }
204
205    pub const fn one(scalar: crate::Scalar) -> Option<Self> {
206        Self::new(1, scalar)
207    }
208
209    pub const fn width(&self) -> crate::Bytes {
210        match *self {
211            Self::F64(_) | Self::I64(_) | Self::U64(_) => 8,
212            Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
213            Self::Bool(_) => crate::BOOL_WIDTH,
214            Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH,
215        }
216    }
217    pub const fn scalar(&self) -> crate::Scalar {
218        match *self {
219            Self::F64(_) => crate::Scalar::F64,
220            Self::F32(_) => crate::Scalar::F32,
221            Self::U32(_) => crate::Scalar::U32,
222            Self::I32(_) => crate::Scalar::I32,
223            Self::U64(_) => crate::Scalar::U64,
224            Self::I64(_) => crate::Scalar::I64,
225            Self::Bool(_) => crate::Scalar::BOOL,
226            Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT,
227            Self::AbstractFloat(_) => crate::Scalar::ABSTRACT_FLOAT,
228        }
229    }
230    pub const fn scalar_kind(&self) -> crate::ScalarKind {
231        self.scalar().kind
232    }
233    pub const fn ty_inner(&self) -> crate::TypeInner {
234        crate::TypeInner::Scalar(self.scalar())
235    }
236}
237
238pub const POINTER_SPAN: u32 = 4;
239
240impl super::TypeInner {
241    /// Return the scalar type of `self`.
242    ///
243    /// If `inner` is a scalar, vector, or matrix type, return
244    /// its scalar type. Otherwise, return `None`.
245    pub const fn scalar(&self) -> Option<super::Scalar> {
246        use crate::TypeInner as Ti;
247        match *self {
248            Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => Some(scalar),
249            Ti::Matrix { scalar, .. } => Some(scalar),
250            _ => None,
251        }
252    }
253
254    pub fn scalar_kind(&self) -> Option<super::ScalarKind> {
255        self.scalar().map(|scalar| scalar.kind)
256    }
257
258    /// Returns the scalar width in bytes
259    pub fn scalar_width(&self) -> Option<u8> {
260        self.scalar().map(|scalar| scalar.width)
261    }
262
263    pub const fn pointer_space(&self) -> Option<crate::AddressSpace> {
264        match *self {
265            Self::Pointer { space, .. } => Some(space),
266            Self::ValuePointer { space, .. } => Some(space),
267            _ => None,
268        }
269    }
270
271    pub fn is_atomic_pointer(&self, types: &crate::UniqueArena<crate::Type>) -> bool {
272        match *self {
273            crate::TypeInner::Pointer { base, .. } => match types[base].inner {
274                crate::TypeInner::Atomic { .. } => true,
275                _ => false,
276            },
277            _ => false,
278        }
279    }
280
281    /// Get the size of this type.
282    pub fn size(&self, _gctx: GlobalCtx) -> u32 {
283        match *self {
284            Self::Scalar(scalar) | Self::Atomic(scalar) => scalar.width as u32,
285            Self::Vector { size, scalar } => size as u32 * scalar.width as u32,
286            // matrices are treated as arrays of aligned columns
287            Self::Matrix {
288                columns,
289                rows,
290                scalar,
291            } => Alignment::from(rows) * scalar.width as u32 * columns as u32,
292            Self::Pointer { .. } | Self::ValuePointer { .. } => POINTER_SPAN,
293            Self::Array {
294                base: _,
295                size,
296                stride,
297            } => {
298                let count = match size {
299                    super::ArraySize::Constant(count) => count.get(),
300                    // A dynamically-sized array has to have at least one element
301                    super::ArraySize::Dynamic => 1,
302                };
303                count * stride
304            }
305            Self::Struct { span, .. } => span,
306            Self::Image { .. }
307            | Self::Sampler { .. }
308            | Self::AccelerationStructure
309            | Self::RayQuery
310            | Self::BindingArray { .. } => 0,
311        }
312    }
313
314    /// Return the canonical form of `self`, or `None` if it's already in
315    /// canonical form.
316    ///
317    /// Certain types have multiple representations in `TypeInner`. This
318    /// function converts all forms of equivalent types to a single
319    /// representative of their class, so that simply applying `Eq` to the
320    /// result indicates whether the types are equivalent, as far as Naga IR is
321    /// concerned.
322    pub fn canonical_form(
323        &self,
324        types: &crate::UniqueArena<crate::Type>,
325    ) -> Option<crate::TypeInner> {
326        use crate::TypeInner as Ti;
327        match *self {
328            Ti::Pointer { base, space } => match types[base].inner {
329                Ti::Scalar(scalar) => Some(Ti::ValuePointer {
330                    size: None,
331                    scalar,
332                    space,
333                }),
334                Ti::Vector { size, scalar } => Some(Ti::ValuePointer {
335                    size: Some(size),
336                    scalar,
337                    space,
338                }),
339                _ => None,
340            },
341            _ => None,
342        }
343    }
344
345    /// Compare `self` and `rhs` as types.
346    ///
347    /// This is mostly the same as `<TypeInner as Eq>::eq`, but it treats
348    /// `ValuePointer` and `Pointer` types as equivalent.
349    ///
350    /// When you know that one side of the comparison is never a pointer, it's
351    /// fine to not bother with canonicalization, and just compare `TypeInner`
352    /// values with `==`.
353    pub fn equivalent(
354        &self,
355        rhs: &crate::TypeInner,
356        types: &crate::UniqueArena<crate::Type>,
357    ) -> bool {
358        let left = self.canonical_form(types);
359        let right = rhs.canonical_form(types);
360        left.as_ref().unwrap_or(self) == right.as_ref().unwrap_or(rhs)
361    }
362
363    pub fn is_dynamically_sized(&self, types: &crate::UniqueArena<crate::Type>) -> bool {
364        use crate::TypeInner as Ti;
365        match *self {
366            Ti::Array { size, .. } => size == crate::ArraySize::Dynamic,
367            Ti::Struct { ref members, .. } => members
368                .last()
369                .map(|last| types[last.ty].inner.is_dynamically_sized(types))
370                .unwrap_or(false),
371            _ => false,
372        }
373    }
374
375    pub fn components(&self) -> Option<u32> {
376        Some(match *self {
377            Self::Vector { size, .. } => size as u32,
378            Self::Matrix { columns, .. } => columns as u32,
379            Self::Array {
380                size: crate::ArraySize::Constant(len),
381                ..
382            } => len.get(),
383            Self::Struct { ref members, .. } => members.len() as u32,
384            _ => return None,
385        })
386    }
387
388    pub fn component_type(&self, index: usize) -> Option<TypeResolution> {
389        Some(match *self {
390            Self::Vector { scalar, .. } => TypeResolution::Value(crate::TypeInner::Scalar(scalar)),
391            Self::Matrix { rows, scalar, .. } => {
392                TypeResolution::Value(crate::TypeInner::Vector { size: rows, scalar })
393            }
394            Self::Array {
395                base,
396                size: crate::ArraySize::Constant(_),
397                ..
398            } => TypeResolution::Handle(base),
399            Self::Struct { ref members, .. } => TypeResolution::Handle(members[index].ty),
400            _ => return None,
401        })
402    }
403}
404
405impl super::AddressSpace {
406    pub fn access(self) -> crate::StorageAccess {
407        use crate::StorageAccess as Sa;
408        match self {
409            crate::AddressSpace::Function
410            | crate::AddressSpace::Private
411            | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE,
412            crate::AddressSpace::Uniform => Sa::LOAD,
413            crate::AddressSpace::Storage { access } => access,
414            crate::AddressSpace::Handle => Sa::LOAD,
415            crate::AddressSpace::PushConstant => Sa::LOAD,
416        }
417    }
418}
419
420impl super::MathFunction {
421    pub const fn argument_count(&self) -> usize {
422        match *self {
423            // comparison
424            Self::Abs => 1,
425            Self::Min => 2,
426            Self::Max => 2,
427            Self::Clamp => 3,
428            Self::Saturate => 1,
429            // trigonometry
430            Self::Cos => 1,
431            Self::Cosh => 1,
432            Self::Sin => 1,
433            Self::Sinh => 1,
434            Self::Tan => 1,
435            Self::Tanh => 1,
436            Self::Acos => 1,
437            Self::Asin => 1,
438            Self::Atan => 1,
439            Self::Atan2 => 2,
440            Self::Asinh => 1,
441            Self::Acosh => 1,
442            Self::Atanh => 1,
443            Self::Radians => 1,
444            Self::Degrees => 1,
445            // decomposition
446            Self::Ceil => 1,
447            Self::Floor => 1,
448            Self::Round => 1,
449            Self::Fract => 1,
450            Self::Trunc => 1,
451            Self::Modf => 1,
452            Self::Frexp => 1,
453            Self::Ldexp => 2,
454            // exponent
455            Self::Exp => 1,
456            Self::Exp2 => 1,
457            Self::Log => 1,
458            Self::Log2 => 1,
459            Self::Pow => 2,
460            // geometry
461            Self::Dot => 2,
462            Self::Outer => 2,
463            Self::Cross => 2,
464            Self::Distance => 2,
465            Self::Length => 1,
466            Self::Normalize => 1,
467            Self::FaceForward => 3,
468            Self::Reflect => 2,
469            Self::Refract => 3,
470            // computational
471            Self::Sign => 1,
472            Self::Fma => 3,
473            Self::Mix => 3,
474            Self::Step => 2,
475            Self::SmoothStep => 3,
476            Self::Sqrt => 1,
477            Self::InverseSqrt => 1,
478            Self::Inverse => 1,
479            Self::Transpose => 1,
480            Self::Determinant => 1,
481            // bits
482            Self::CountTrailingZeros => 1,
483            Self::CountLeadingZeros => 1,
484            Self::CountOneBits => 1,
485            Self::ReverseBits => 1,
486            Self::ExtractBits => 3,
487            Self::InsertBits => 4,
488            Self::FirstTrailingBit => 1,
489            Self::FirstLeadingBit => 1,
490            // data packing
491            Self::Pack4x8snorm => 1,
492            Self::Pack4x8unorm => 1,
493            Self::Pack2x16snorm => 1,
494            Self::Pack2x16unorm => 1,
495            Self::Pack2x16float => 1,
496            Self::Pack4xI8 => 1,
497            Self::Pack4xU8 => 1,
498            // data unpacking
499            Self::Unpack4x8snorm => 1,
500            Self::Unpack4x8unorm => 1,
501            Self::Unpack2x16snorm => 1,
502            Self::Unpack2x16unorm => 1,
503            Self::Unpack2x16float => 1,
504            Self::Unpack4xI8 => 1,
505            Self::Unpack4xU8 => 1,
506        }
507    }
508}
509
510impl crate::Expression {
511    /// Returns true if the expression is considered emitted at the start of a function.
512    pub const fn needs_pre_emit(&self) -> bool {
513        match *self {
514            Self::Literal(_)
515            | Self::Constant(_)
516            | Self::Override(_)
517            | Self::ZeroValue(_)
518            | Self::FunctionArgument(_)
519            | Self::GlobalVariable(_)
520            | Self::LocalVariable(_) => true,
521            _ => false,
522        }
523    }
524
525    /// Return true if this expression is a dynamic array/vector/matrix index,
526    /// for [`Access`].
527    ///
528    /// This method returns true if this expression is a dynamically computed
529    /// index, and as such can only be used to index matrices when they appear
530    /// behind a pointer. See the documentation for [`Access`] for details.
531    ///
532    /// Note, this does not check the _type_ of the given expression. It's up to
533    /// the caller to establish that the `Access` expression is well-typed
534    /// through other means, like [`ResolveContext`].
535    ///
536    /// [`Access`]: crate::Expression::Access
537    /// [`ResolveContext`]: crate::proc::ResolveContext
538    pub const fn is_dynamic_index(&self) -> bool {
539        match *self {
540            Self::Literal(_) | Self::ZeroValue(_) | Self::Constant(_) => false,
541            _ => true,
542        }
543    }
544}
545
546impl crate::Function {
547    /// Return the global variable being accessed by the expression `pointer`.
548    ///
549    /// Assuming that `pointer` is a series of `Access` and `AccessIndex`
550    /// expressions that ultimately access some part of a `GlobalVariable`,
551    /// return a handle for that global.
552    ///
553    /// If the expression does not ultimately access a global variable, return
554    /// `None`.
555    pub fn originating_global(
556        &self,
557        mut pointer: crate::Handle<crate::Expression>,
558    ) -> Option<crate::Handle<crate::GlobalVariable>> {
559        loop {
560            pointer = match self.expressions[pointer] {
561                crate::Expression::Access { base, .. } => base,
562                crate::Expression::AccessIndex { base, .. } => base,
563                crate::Expression::GlobalVariable(handle) => return Some(handle),
564                crate::Expression::LocalVariable(_) => return None,
565                crate::Expression::FunctionArgument(_) => return None,
566                // There are no other expressions that produce pointer values.
567                _ => unreachable!(),
568            }
569        }
570    }
571}
572
573impl crate::SampleLevel {
574    pub const fn implicit_derivatives(&self) -> bool {
575        match *self {
576            Self::Auto | Self::Bias(_) => true,
577            Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false,
578        }
579    }
580}
581
582impl crate::Binding {
583    pub const fn to_built_in(&self) -> Option<crate::BuiltIn> {
584        match *self {
585            crate::Binding::BuiltIn(built_in) => Some(built_in),
586            Self::Location { .. } => None,
587        }
588    }
589}
590
591impl super::SwizzleComponent {
592    pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W];
593
594    pub const fn index(&self) -> u32 {
595        match *self {
596            Self::X => 0,
597            Self::Y => 1,
598            Self::Z => 2,
599            Self::W => 3,
600        }
601    }
602    pub const fn from_index(idx: u32) -> Self {
603        match idx {
604            0 => Self::X,
605            1 => Self::Y,
606            2 => Self::Z,
607            _ => Self::W,
608        }
609    }
610}
611
612impl super::ImageClass {
613    pub const fn is_multisampled(self) -> bool {
614        match self {
615            crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi,
616            crate::ImageClass::Storage { .. } => false,
617        }
618    }
619
620    pub const fn is_mipmapped(self) -> bool {
621        match self {
622            crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi,
623            crate::ImageClass::Storage { .. } => false,
624        }
625    }
626
627    pub const fn is_depth(self) -> bool {
628        matches!(self, crate::ImageClass::Depth { .. })
629    }
630}
631
632impl crate::Module {
633    pub const fn to_ctx(&self) -> GlobalCtx<'_> {
634        GlobalCtx {
635            types: &self.types,
636            constants: &self.constants,
637            overrides: &self.overrides,
638            global_expressions: &self.global_expressions,
639        }
640    }
641}
642
643#[derive(Debug)]
644pub(super) enum U32EvalError {
645    NonConst,
646    Negative,
647}
648
649#[derive(Clone, Copy)]
650pub struct GlobalCtx<'a> {
651    pub types: &'a crate::UniqueArena<crate::Type>,
652    pub constants: &'a crate::Arena<crate::Constant>,
653    pub overrides: &'a crate::Arena<crate::Override>,
654    pub global_expressions: &'a crate::Arena<crate::Expression>,
655}
656
657impl GlobalCtx<'_> {
658    /// Try to evaluate the expression in `self.global_expressions` using its `handle` and return it as a `u32`.
659    #[allow(dead_code)]
660    pub(super) fn eval_expr_to_u32(
661        &self,
662        handle: crate::Handle<crate::Expression>,
663    ) -> Result<u32, U32EvalError> {
664        self.eval_expr_to_u32_from(handle, self.global_expressions)
665    }
666
667    /// Try to evaluate the expression in the `arena` using its `handle` and return it as a `u32`.
668    pub(super) fn eval_expr_to_u32_from(
669        &self,
670        handle: crate::Handle<crate::Expression>,
671        arena: &crate::Arena<crate::Expression>,
672    ) -> Result<u32, U32EvalError> {
673        match self.eval_expr_to_literal_from(handle, arena) {
674            Some(crate::Literal::U32(value)) => Ok(value),
675            Some(crate::Literal::I32(value)) => {
676                value.try_into().map_err(|_| U32EvalError::Negative)
677            }
678            _ => Err(U32EvalError::NonConst),
679        }
680    }
681
682    /// Try to evaluate the expression in the `arena` using its `handle` and return it as a `bool`.
683    #[allow(dead_code)]
684    pub(super) fn eval_expr_to_bool_from(
685        &self,
686        handle: crate::Handle<crate::Expression>,
687        arena: &crate::Arena<crate::Expression>,
688    ) -> Option<bool> {
689        match self.eval_expr_to_literal_from(handle, arena) {
690            Some(crate::Literal::Bool(value)) => Some(value),
691            _ => None,
692        }
693    }
694
695    #[allow(dead_code)]
696    pub(crate) fn eval_expr_to_literal(
697        &self,
698        handle: crate::Handle<crate::Expression>,
699    ) -> Option<crate::Literal> {
700        self.eval_expr_to_literal_from(handle, self.global_expressions)
701    }
702
703    fn eval_expr_to_literal_from(
704        &self,
705        handle: crate::Handle<crate::Expression>,
706        arena: &crate::Arena<crate::Expression>,
707    ) -> Option<crate::Literal> {
708        fn get(
709            gctx: GlobalCtx,
710            handle: crate::Handle<crate::Expression>,
711            arena: &crate::Arena<crate::Expression>,
712        ) -> Option<crate::Literal> {
713            match arena[handle] {
714                crate::Expression::Literal(literal) => Some(literal),
715                crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner {
716                    crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar),
717                    _ => None,
718                },
719                _ => None,
720            }
721        }
722        match arena[handle] {
723            crate::Expression::Constant(c) => {
724                get(*self, self.constants[c].init, self.global_expressions)
725            }
726            _ => get(*self, handle, arena),
727        }
728    }
729}
730
731/// Return an iterator over the individual components assembled by a
732/// `Compose` expression.
733///
734/// Given `ty` and `components` from an `Expression::Compose`, return an
735/// iterator over the components of the resulting value.
736///
737/// Normally, this would just be an iterator over `components`. However,
738/// `Compose` expressions can concatenate vectors, in which case the i'th
739/// value being composed is not generally the i'th element of `components`.
740/// This function consults `ty` to decide if this concatenation is occurring,
741/// and returns an iterator that produces the components of the result of
742/// the `Compose` expression in either case.
743pub fn flatten_compose<'arenas>(
744    ty: crate::Handle<crate::Type>,
745    components: &'arenas [crate::Handle<crate::Expression>],
746    expressions: &'arenas crate::Arena<crate::Expression>,
747    types: &'arenas crate::UniqueArena<crate::Type>,
748) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas {
749    // Returning `impl Iterator` is a bit tricky. We may or may not
750    // want to flatten the components, but we have to settle on a
751    // single concrete type to return. This function returns a single
752    // iterator chain that handles both the flattening and
753    // non-flattening cases.
754    let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner {
755        (size as usize, true)
756    } else {
757        (components.len(), false)
758    };
759
760    /// Flatten `Compose` expressions if `is_vector` is true.
761    fn flatten_compose<'c>(
762        component: &'c crate::Handle<crate::Expression>,
763        is_vector: bool,
764        expressions: &'c crate::Arena<crate::Expression>,
765    ) -> &'c [crate::Handle<crate::Expression>] {
766        if is_vector {
767            if let crate::Expression::Compose {
768                ty: _,
769                components: ref subcomponents,
770            } = expressions[*component]
771            {
772                return subcomponents;
773            }
774        }
775        std::slice::from_ref(component)
776    }
777
778    /// Flatten `Splat` expressions if `is_vector` is true.
779    fn flatten_splat<'c>(
780        component: &'c crate::Handle<crate::Expression>,
781        is_vector: bool,
782        expressions: &'c crate::Arena<crate::Expression>,
783    ) -> impl Iterator<Item = crate::Handle<crate::Expression>> {
784        let mut expr = *component;
785        let mut count = 1;
786        if is_vector {
787            if let crate::Expression::Splat { size, value } = expressions[expr] {
788                expr = value;
789                count = size as usize;
790            }
791        }
792        std::iter::repeat(expr).take(count)
793    }
794
795    // Expressions like `vec4(vec3(vec2(6, 7), 8), 9)` require us to
796    // flatten up to two levels of `Compose` expressions.
797    //
798    // Expressions like `vec4(vec3(1.0), 1.0)` require us to flatten
799    // `Splat` expressions. Fortunately, the operand of a `Splat` must
800    // be a scalar, so we can stop there.
801    components
802        .iter()
803        .flat_map(move |component| flatten_compose(component, is_vector, expressions))
804        .flat_map(move |component| flatten_compose(component, is_vector, expressions))
805        .flat_map(move |component| flatten_splat(component, is_vector, expressions))
806        .take(size)
807}
808
#[test]
fn test_matrix_size() {
    // Three `vec3<f32>` columns, each padded to the alignment of a
    // three-row column (4 elements * 4 bytes = 16), give 48 bytes.
    let mat3x3 = crate::TypeInner::Matrix {
        columns: crate::VectorSize::Tri,
        rows: crate::VectorSize::Tri,
        scalar: crate::Scalar::F32,
    };
    let module = crate::Module::default();
    let size = mat3x3.size(module.to_ctx());
    assert_eq!(size, 48);
}