// naga/proc/constant_evaluator.rs

1use std::iter;
2
3use arrayvec::ArrayVec;
4
5use crate::{
6    arena::{Arena, Handle, HandleVec, UniqueArena},
7    ArraySize, BinaryOperator, Constant, Expression, Literal, Override, ScalarKind, Span, Type,
8    TypeInner, UnaryOperator,
9};
10
/// Lets a macro emit a literal dollar sign (`$`) into the code it generates.
///
/// `macro_rules!` has no escape for `$`, so a macro that wants to define *another*
/// `macro_rules!` item (which needs its own `$name:fragment` metavariables) has no
/// direct way to write them. This macro takes the arm list of a helper macro and
/// invokes that helper with a lone `$` token. An arm written as
/// `($d:tt) => { ... }` therefore sees `$d` bound to `$`, and every `$d` in its
/// body expands to a real dollar sign.
///
/// Technique stolen directly from
/// <https://github.com/rust-lang/rust/issues/35853#issuecomment-415993963>.
macro_rules! with_dollar_sign {
    ($($arm:tt)*) => {
        macro_rules! __with_dollar_sign { $($arm)* }
        __with_dollar_sign!($);
    };
}
22
/// Generates a component-wise "extractor" used to implement numeric built-ins.
///
/// An invocation of the form
///
/// ```text
/// gen_component_wise_extractor! {
///     extractor_fn -> Target,
///     literals: [LiteralVariant => TargetVariant: rust_ty, ...],
///     scalar_kinds: [ScalarKindVariant, ...],
/// }
/// ```
///
/// emits:
///
/// - an enum `Target<const N: usize>` with one `TargetVariant([rust_ty; N])` variant per
///   listed literal, plus `From<Target<1>> for Expression`;
/// - a function `extractor_fn`. It evaluates `N` argument expressions either as scalars of
///   one listed literal variant, or as same-typed vectors whose scalar kind is listed. It
///   hands the values to a handler and registers the handler's result;
/// - a convenience macro `extractor_fn!` that applies one handler body to every variant.
macro_rules! gen_component_wise_extractor {
    (
        $ident:ident -> $target:ident,
        literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
        scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
    ) => {
        /// A subset of [`Literal`]s intended to be used for implementing numeric built-ins.
        #[derive(Debug)]
        #[cfg_attr(test, derive(PartialEq))]
        enum $target<const N: usize> {
            $(
                #[doc = concat!(
                    "Maps to [`Literal::",
                    stringify!($literal),
                    "`]",
                )]
                $mapping([$ty; N]),
            )+
        }

        impl From<$target<1>> for Expression {
            fn from(value: $target<1>) -> Self {
                match value {
                    $(
                        $target::$mapping([value]) => {
                            Expression::Literal(Literal::$literal(value))
                        }
                    )+
                }
            }
        }

        #[doc = concat!(
            "Attempts to evaluate multiple `exprs` as a combined [`",
            stringify!($target),
            "`] to pass to `handler`. ",
        )]
        /// If `exprs` are vectors of the same length, `handler` is called for each corresponding
        /// component of each vector.
        ///
        /// `handler`'s output is registered as a new expression. If `exprs` are vectors of the
        /// same length, a new vector expression is registered, composed of each component emitted
        /// by `handler`.
        ///
        /// # Errors
        ///
        /// Returns [`ConstantEvaluatorError::InvalidMathArg`] in either of these cases:
        ///
        /// - the first expression is neither one of this extractor's literal variants nor a
        ///   vector `Compose` of one of its scalar kinds;
        /// - a later expression does not match the first (same literal variant, or same
        ///   vector type).
        ///
        /// Errors from sanitizing an argument or from `handler` are propagated.
        ///
        /// # Panics
        ///
        /// Panics if `N` is zero.
        fn $ident<const N: usize, const M: usize, F>(
            eval: &mut ConstantEvaluator<'_>,
            span: Span,
            exprs: [Handle<Expression>; N],
            mut handler: F,
        ) -> Result<Handle<Expression>, ConstantEvaluatorError>
        where
            $target<M>: Into<Expression>,
            F: FnMut($target<N>) -> Result<$target<M>, ConstantEvaluatorError> + Clone,
        {
            assert!(N > 0);
            let err = ConstantEvaluatorError::InvalidMathArg;
            let mut exprs = exprs.into_iter();

            // NOTE(review): `eval_zero_value_and_splat` is defined elsewhere; it presumably
            // rewrites `ZeroValue`/`Splat` into the `Literal`/`Compose` forms matched below.
            // Confirm.
            macro_rules! sanitize {
                ($expr:expr) => {
                    eval.eval_zero_value_and_splat($expr, span)
                        .map(|expr| &eval.expressions[expr])
                };
            }

            let new_expr = match sanitize!(exprs.next().unwrap())? {
                // Scalar case: every argument must be a literal of the same variant as the
                // first one. All `N` values are handed to `handler` together.
                $(
                    &Expression::Literal(Literal::$literal(x)) => iter::once(Ok(x))
                        .chain(exprs.map(|expr| {
                            sanitize!(expr).and_then(|expr| match expr {
                                &Expression::Literal(Literal::$literal(x)) => Ok(x),
                                _ => Err(err.clone()),
                            })
                        }))
                        .collect::<Result<ArrayVec<_, N>, _>>()
                        .map(|a| a.into_inner().unwrap())
                        .map($target::$mapping)
                        .and_then(|comps| Ok(handler(comps)?.into())),
                )+
                // Vector case: every argument must be a `Compose` of the same vector type.
                // Each argument's components are gathered via `flatten_compose`. This
                // function then recurses once per component index, passing the matching
                // component of every argument.
                &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
                    &TypeInner::Vector { size, scalar } => match scalar.kind {
                        $(ScalarKind::$scalar_kind)|* => {
                            let first_ty = ty;
                            // One group of component handles per argument.
                            let mut component_groups =
                                ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
                            component_groups.push(crate::proc::flatten_compose(
                                first_ty,
                                components,
                                eval.expressions,
                                eval.types,
                            ).collect());
                            // The remaining `N - 1` arguments. The outer capacity is bounded
                            // by the argument count `N`, not by the vector width, so this
                            // cannot overflow for any `N`.
                            component_groups.extend(
                                exprs
                                    .map(|expr| {
                                        sanitize!(expr).and_then(|expr| match expr {
                                            &Expression::Compose { ty, ref components }
                                                if &eval.types[ty].inner
                                                    == &eval.types[first_ty].inner =>
                                            {
                                                Ok(crate::proc::flatten_compose(
                                                    ty,
                                                    components,
                                                    eval.expressions,
                                                    eval.types,
                                                ).collect())
                                            }
                                            _ => Err(err.clone()),
                                        })
                                    })
                                    .collect::<Result<ArrayVec<_, N>, _>>()?,
                            );
                            let component_groups = component_groups.into_inner().unwrap();
                            let mut new_components =
                                ArrayVec::<_, { crate::VectorSize::MAX }>::new();
                            // Evaluate component `idx` of every argument together.
                            for idx in 0..(size as u8).into() {
                                let group = component_groups
                                    .iter()
                                    .map(|cs| cs[idx])
                                    .collect::<ArrayVec<_, N>>()
                                    .into_inner()
                                    .unwrap();
                                new_components.push($ident(
                                    eval,
                                    span,
                                    group,
                                    handler.clone(),
                                )?);
                            }
                            Ok(Expression::Compose {
                                ty: first_ty,
                                components: new_components.into_iter().collect(),
                            })
                        }
                        _ => return Err(err),
                    },
                    _ => return Err(err),
                },
                _ => return Err(err),
            }?;
            eval.register_evaluated_expr(new_expr, span)
        }

        with_dollar_sign! {
            ($d:tt) => {
                #[allow(unused)]
                #[doc = concat!(
                    "A convenience macro for using the same RHS for each [`",
                    stringify!($target),
                    "`] variant in a call to [`",
                    stringify!($ident),
                    "`].",
                )]
                macro_rules! $ident {
                    (
                        $eval:expr,
                        $span:expr,
                        [$d ($d expr:expr),+ $d (,)?],
                        |$d ($d arg:ident),+| $d tt:tt
                    ) => {
                        $ident($eval, $span, [$d ($d expr),+], |args| match args {
                            $(
                                $target::$mapping([$d ($d arg),+]) => {
                                    let res = $d tt;
                                    Result::map(res, $target::$mapping)
                                },
                            )+
                        })
                    };
                }
            };
        }
    };
}
196
197gen_component_wise_extractor! {
198    component_wise_scalar -> Scalar,
199    literals: [
200        AbstractFloat => AbstractFloat: f64,
201        F32 => F32: f32,
202        AbstractInt => AbstractInt: i64,
203        U32 => U32: u32,
204        I32 => I32: i32,
205        U64 => U64: u64,
206        I64 => I64: i64,
207    ],
208    scalar_kinds: [
209        Float,
210        AbstractFloat,
211        Sint,
212        Uint,
213        AbstractInt,
214    ],
215}
216
217gen_component_wise_extractor! {
218    component_wise_float -> Float,
219    literals: [
220        AbstractFloat => Abstract: f64,
221        F32 => F32: f32,
222    ],
223    scalar_kinds: [
224        Float,
225        AbstractFloat,
226    ],
227}
228
229gen_component_wise_extractor! {
230    component_wise_concrete_int -> ConcreteInt,
231    literals: [
232        U32 => U32: u32,
233        I32 => I32: i32,
234    ],
235    scalar_kinds: [
236        Sint,
237        Uint,
238    ],
239}
240
241gen_component_wise_extractor! {
242    component_wise_signed -> Signed,
243    literals: [
244        AbstractFloat => AbstractFloat: f64,
245        AbstractInt => AbstractInt: i64,
246        F32 => F32: f32,
247        I32 => I32: i32,
248    ],
249    scalar_kinds: [
250        Sint,
251        AbstractInt,
252        Float,
253        AbstractFloat,
254    ],
255}
256
257#[derive(Debug)]
258enum Behavior<'a> {
259    Wgsl(WgslRestrictions<'a>),
260    Glsl(GlslRestrictions<'a>),
261}
262
263impl Behavior<'_> {
264    /// Returns `true` if the inner WGSL/GLSL restrictions are runtime restrictions.
265    const fn has_runtime_restrictions(&self) -> bool {
266        matches!(
267            self,
268            &Behavior::Wgsl(WgslRestrictions::Runtime(_))
269                | &Behavior::Glsl(GlslRestrictions::Runtime(_))
270        )
271    }
272}
273
/// A context for evaluating constant expressions.
///
/// A `ConstantEvaluator` points at an expression arena to which it can append
/// newly evaluated expressions. You pass [`try_eval_and_append`] whatever kind
/// of Naga [`Expression`] you like. If its value can be computed at compile
/// time, `try_eval_and_append` appends an expression representing the computed
/// value - a tree of [`Literal`], [`Compose`], [`ZeroValue`], and [`Swizzle`]
/// expressions - to the arena. See the [`try_eval_and_append`] method for details.
///
/// A `ConstantEvaluator` also holds whatever information we need to carry out
/// that evaluation: types, other constants, and so on.
///
/// [`try_eval_and_append`]: ConstantEvaluator::try_eval_and_append
/// [`Compose`]: Expression::Compose
/// [`ZeroValue`]: Expression::ZeroValue
/// [`Literal`]: Expression::Literal
/// [`Swizzle`]: Expression::Swizzle
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
    /// Which language's evaluation rules we should follow. This also records whether
    /// override- and runtime-expressions may be appended unevaluated, and (for function
    /// arenas) the function-local state.
    behavior: Behavior<'a>,

    /// The module's type arena.
    ///
    /// Because expressions like [`Splat`] contain type handles, we need to be
    /// able to add new types to produce those expressions.
    ///
    /// [`Splat`]: Expression::Splat
    types: &'a mut UniqueArena<Type>,

    /// The module's constant arena.
    constants: &'a Arena<Constant>,

    /// The module's override arena.
    overrides: &'a Arena<Override>,

    /// The arena to which we are contributing expressions. This is either the module's
    /// global expression arena or a function's expression arena.
    expressions: &'a mut Arena<Expression>,

    /// Tracks the constness of expressions residing in [`Self::expressions`], one
    /// [`ExpressionKind`] per expression handle.
    expression_kind_tracker: &'a mut ExpressionKindTracker,
}
316
/// How restricted WGSL evaluation is in the current context.
///
/// Each variant lists the expression kinds it accepts. An override-expression under
/// `Const`, or a runtime-expression under `Const`/`Override`, is rejected with an error.
#[derive(Debug)]
enum WgslRestrictions<'a> {
    /// Only const-expressions are accepted.
    ///
    /// - const-expressions will be evaluated and inserted in the arena
    ///
    /// Holds `Some` function-local data when evaluating into a function's expression
    /// arena. Holds `None` when evaluating into the module's global expression arena.
    Const(Option<FunctionLocalData<'a>>),
    /// Module scope inside an override context (`for_wgsl_module` with `in_override_ctx`).
    ///
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    Override,
    /// Function scope where non-const expressions are allowed.
    ///
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    /// - runtime-expressions will be inserted in the arena
    Runtime(FunctionLocalData<'a>),
}
329
/// How restricted GLSL evaluation is in the current context.
///
/// Override-expressions are not expected under GLSL: `try_eval_and_append` hits
/// `unreachable!()` if it meets one under either variant.
#[derive(Debug)]
enum GlslRestrictions<'a> {
    /// Module scope.
    ///
    /// - const-expressions will be evaluated and inserted in the arena
    Const,
    /// Function scope.
    ///
    /// - const-expressions will be evaluated and inserted in the arena
    /// - runtime-expressions will be inserted in the arena
    Runtime(FunctionLocalData<'a>),
}
339
/// State needed when the evaluator appends to a function's expression arena rather than
/// the module's global expression arena.
#[derive(Debug)]
struct FunctionLocalData<'a> {
    /// Global constant expressions.
    ///
    /// Constant initializers live here and are deep-copied into the function's arena
    /// before use (see `ConstantEvaluator::check_and_get`).
    global_expressions: &'a Arena<Expression>,
    /// NOTE(review): presumably used to emit `Emit` statements for appended runtime
    /// expressions. Confirm against `append_expr`.
    emitter: &'a mut super::Emitter,
    /// NOTE(review): presumably the block that receives those statements. Confirm against
    /// `append_expr`.
    block: &'a mut crate::Block,
}
347
/// How "constant" an expression is, from most to least constant.
///
/// The variant order is significant. The derived `Ord` makes later variants compare
/// greater, and [`ExpressionKindTracker`] combines operand kinds with `max`. An expression
/// is therefore never more constant than its least constant operand. Do not reorder.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
    /// A const-expression whose constant evaluation Naga implements.
    ///
    /// `try_eval_and_append` evaluates it and returns any evaluation error as-is.
    ImplConst,
    /// A const-expression whose constant evaluation may not be implemented.
    ///
    /// In a context with runtime restrictions, a `NotImplemented` or `InvalidBinaryOpArgs`
    /// failure makes `try_eval_and_append` append it unevaluated as a runtime expression.
    Const,
    /// An expression depending on pipeline-overridable constants.
    Override,
    /// An expression that can only be evaluated at runtime.
    Runtime,
}
356
/// Records the [`ExpressionKind`] of each expression in an expression arena, indexed by
/// expression handle.
#[derive(Debug)]
pub struct ExpressionKindTracker {
    inner: HandleVec<Expression, ExpressionKind>,
}
361
362impl ExpressionKindTracker {
363    pub const fn new() -> Self {
364        Self {
365            inner: HandleVec::new(),
366        }
367    }
368
369    /// Forces the the expression to not be const
370    pub fn force_non_const(&mut self, value: Handle<Expression>) {
371        self.inner[value] = ExpressionKind::Runtime;
372    }
373
374    pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
375        self.inner.insert(value, expr_type);
376    }
377
378    pub fn is_const(&self, h: Handle<Expression>) -> bool {
379        matches!(
380            self.type_of(h),
381            ExpressionKind::Const | ExpressionKind::ImplConst
382        )
383    }
384
385    /// Returns `true` if naga can also evaluate expression as const
386    pub fn is_impl_const(&self, h: Handle<Expression>) -> bool {
387        matches!(self.type_of(h), ExpressionKind::ImplConst)
388    }
389
390    pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
391        matches!(
392            self.type_of(h),
393            ExpressionKind::Const | ExpressionKind::Override | ExpressionKind::ImplConst
394        )
395    }
396
397    fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
398        self.inner[value]
399    }
400
401    pub fn from_arena(arena: &Arena<Expression>) -> Self {
402        let mut tracker = Self {
403            inner: HandleVec::with_capacity(arena.len()),
404        };
405        for (handle, expr) in arena.iter() {
406            tracker
407                .inner
408                .insert(handle, tracker.type_of_with_expr(expr));
409        }
410        tracker
411    }
412
413    fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
414        use crate::MathFunction as Mf;
415        match *expr {
416            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
417                ExpressionKind::ImplConst
418            }
419            Expression::Override(_) => ExpressionKind::Override,
420            Expression::Compose { ref components, .. } => {
421                let mut expr_type = ExpressionKind::ImplConst;
422                for component in components {
423                    expr_type = expr_type.max(self.type_of(*component))
424                }
425                expr_type
426            }
427            Expression::Splat { value, .. } => self.type_of(value),
428            Expression::AccessIndex { base, .. } => self.type_of(base),
429            Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
430            Expression::Swizzle { vector, .. } => self.type_of(vector),
431            Expression::Unary { expr, .. } => self.type_of(expr),
432            Expression::Binary { left, right, .. } => self
433                .type_of(left)
434                .max(self.type_of(right))
435                .max(ExpressionKind::Const),
436            Expression::Math {
437                fun,
438                arg,
439                arg1,
440                arg2,
441                arg3,
442            } => self
443                .type_of(arg)
444                .max(
445                    arg1.map(|arg| self.type_of(arg))
446                        .unwrap_or(ExpressionKind::Const),
447                )
448                .max(
449                    arg2.map(|arg| self.type_of(arg))
450                        .unwrap_or(ExpressionKind::Const),
451                )
452                .max(
453                    arg3.map(|arg| self.type_of(arg))
454                        .unwrap_or(ExpressionKind::Const),
455                )
456                .max(
457                    if matches!(
458                        fun,
459                        Mf::Dot
460                            | Mf::Outer
461                            | Mf::Cross
462                            | Mf::Distance
463                            | Mf::Length
464                            | Mf::Normalize
465                            | Mf::FaceForward
466                            | Mf::Reflect
467                            | Mf::Refract
468                            | Mf::Ldexp
469                            | Mf::Modf
470                            | Mf::Mix
471                            | Mf::Frexp
472                    ) {
473                        ExpressionKind::Const
474                    } else {
475                        ExpressionKind::ImplConst
476                    },
477                ),
478            Expression::As { convert, expr, .. } => self.type_of(expr).max(if convert.is_some() {
479                ExpressionKind::ImplConst
480            } else {
481                ExpressionKind::Const
482            }),
483            Expression::Select {
484                condition,
485                accept,
486                reject,
487            } => self
488                .type_of(condition)
489                .max(self.type_of(accept))
490                .max(self.type_of(reject))
491                .max(ExpressionKind::Const),
492            Expression::Relational { argument, .. } => self.type_of(argument),
493            Expression::ArrayLength(expr) => self.type_of(expr),
494            _ => ExpressionKind::Runtime,
495        }
496    }
497}
498
/// Errors produced while evaluating, or refusing to evaluate, an expression as a constant.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantEvaluatorError {
    #[error("Constants cannot access function arguments")]
    FunctionArg,
    #[error("Constants cannot access global variables")]
    GlobalVariable,
    #[error("Constants cannot access local variables")]
    LocalVariable,
    #[error("Cannot get the array length of a non array type")]
    InvalidArrayLengthArg,
    #[error("Constants cannot get the array length of a dynamically sized array")]
    ArrayLengthDynamic,
    #[error("Constants cannot call functions")]
    Call,
    #[error("Constants don't support workGroupUniformLoad")]
    WorkGroupUniformLoadResult,
    #[error("Constants don't support atomic functions")]
    Atomic,
    #[error("Constants don't support derivative functions")]
    Derivative,
    #[error("Constants don't support load expressions")]
    Load,
    #[error("Constants don't support image expressions")]
    ImageExpression,
    #[error("Constants don't support ray query expressions")]
    RayQueryExpression,
    #[error("Constants don't support subgroup expressions")]
    SubgroupExpression,
    #[error("Cannot access the type")]
    InvalidAccessBase,
    #[error("Cannot access at the index")]
    InvalidAccessIndex,
    #[error("Cannot access with index of type")]
    InvalidAccessIndexTy,
    #[error("Constants don't support array length expressions")]
    ArrayLength,
    #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
    InvalidCastArg { from: String, to: String },
    #[error("Cannot apply the unary op to the argument")]
    InvalidUnaryOpArg,
    /// In a context with runtime restrictions, this error on a `Const`-kind expression makes
    /// `try_eval_and_append` fall back to appending it as a runtime expression.
    #[error("Cannot apply the binary op to the arguments")]
    InvalidBinaryOpArgs,
    /// Returned by the component-wise extractors when arguments are not of a supported
    /// literal variant or vector type, or do not match each other.
    #[error("Cannot apply math function to type")]
    InvalidMathArg,
    #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
    InvalidMathArgCount(crate::MathFunction, usize, usize),
    #[error("value of `low` is greater than `high` for clamp built-in function")]
    InvalidClamp,
    #[error("Splat is defined only on scalar values")]
    SplatScalarOnly,
    #[error("Can only swizzle vector constants")]
    SwizzleVectorOnly,
    #[error("swizzle component not present in source expression")]
    SwizzleOutOfBounds,
    #[error("Type is not constructible")]
    TypeNotConstructible,
    /// Returned by `ConstantEvaluator::check` when an operand's recorded kind is not
    /// `ImplConst`/`Const`.
    #[error("Subexpression(s) are not constant")]
    SubexpressionsAreNotConstant,
    /// Constant evaluation of this construct is missing. For `Const`-kind expressions in a
    /// context with runtime restrictions, `try_eval_and_append` falls back to appending the
    /// expression as a runtime expression.
    #[error("Not implemented as constant expression: {0}")]
    NotImplemented(String),
    #[error("{0} operation overflowed")]
    Overflow(String),
    #[error(
        "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
    )]
    AutomaticConversionLossy {
        value: String,
        to_type: &'static str,
    },
    #[error("abstract floating-point values cannot be automatically converted to integers")]
    AutomaticConversionFloatToInt { to_type: &'static str },
    #[error("Division by zero")]
    DivisionByZero,
    #[error("Remainder by zero")]
    RemainderByZero,
    /// NOTE(review): the message mentions 32 bits although 64-bit literals (`U64`/`I64`)
    /// exist. Confirm how 64-bit shifts are reported.
    #[error("RHS of shift operation is greater than or equal to 32")]
    ShiftedMoreThan32Bits,
    #[error(transparent)]
    Literal(#[from] crate::valid::LiteralError),
    /// Returned by `try_eval_and_append_impl` when asked to evaluate an
    /// `Expression::Override` directly.
    #[error("Can't use pipeline-overridable constants in const-expressions")]
    Override,
    /// Returned by `try_eval_and_append` for a runtime-expression in a context without
    /// runtime restrictions.
    #[error("Unexpected runtime-expression")]
    RuntimeExpr,
    /// Returned by `try_eval_and_append` for an override-expression under WGSL's const
    /// restrictions.
    #[error("Unexpected override-expression")]
    OverrideExpr,
}
586
587impl<'a> ConstantEvaluator<'a> {
588    /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
589    /// constant expression arena.
590    ///
591    /// Report errors according to WGSL's rules for constant evaluation.
592    pub fn for_wgsl_module(
593        module: &'a mut crate::Module,
594        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
595        in_override_ctx: bool,
596    ) -> Self {
597        Self::for_module(
598            Behavior::Wgsl(if in_override_ctx {
599                WgslRestrictions::Override
600            } else {
601                WgslRestrictions::Const(None)
602            }),
603            module,
604            global_expression_kind_tracker,
605        )
606    }
607
608    /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
609    /// constant expression arena.
610    ///
611    /// Report errors according to GLSL's rules for constant evaluation.
612    pub fn for_glsl_module(
613        module: &'a mut crate::Module,
614        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
615    ) -> Self {
616        Self::for_module(
617            Behavior::Glsl(GlslRestrictions::Const),
618            module,
619            global_expression_kind_tracker,
620        )
621    }
622
623    fn for_module(
624        behavior: Behavior<'a>,
625        module: &'a mut crate::Module,
626        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
627    ) -> Self {
628        Self {
629            behavior,
630            types: &mut module.types,
631            constants: &module.constants,
632            overrides: &module.overrides,
633            expressions: &mut module.global_expressions,
634            expression_kind_tracker: global_expression_kind_tracker,
635        }
636    }
637
638    /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
639    /// expression arena.
640    ///
641    /// Report errors according to WGSL's rules for constant evaluation.
642    pub fn for_wgsl_function(
643        module: &'a mut crate::Module,
644        expressions: &'a mut Arena<Expression>,
645        local_expression_kind_tracker: &'a mut ExpressionKindTracker,
646        emitter: &'a mut super::Emitter,
647        block: &'a mut crate::Block,
648        is_const: bool,
649    ) -> Self {
650        let local_data = FunctionLocalData {
651            global_expressions: &module.global_expressions,
652            emitter,
653            block,
654        };
655        Self {
656            behavior: Behavior::Wgsl(if is_const {
657                WgslRestrictions::Const(Some(local_data))
658            } else {
659                WgslRestrictions::Runtime(local_data)
660            }),
661            types: &mut module.types,
662            constants: &module.constants,
663            overrides: &module.overrides,
664            expressions,
665            expression_kind_tracker: local_expression_kind_tracker,
666        }
667    }
668
669    /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
670    /// expression arena.
671    ///
672    /// Report errors according to GLSL's rules for constant evaluation.
673    pub fn for_glsl_function(
674        module: &'a mut crate::Module,
675        expressions: &'a mut Arena<Expression>,
676        local_expression_kind_tracker: &'a mut ExpressionKindTracker,
677        emitter: &'a mut super::Emitter,
678        block: &'a mut crate::Block,
679    ) -> Self {
680        Self {
681            behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
682                global_expressions: &module.global_expressions,
683                emitter,
684                block,
685            })),
686            types: &mut module.types,
687            constants: &module.constants,
688            overrides: &module.overrides,
689            expressions,
690            expression_kind_tracker: local_expression_kind_tracker,
691        }
692    }
693
694    pub fn to_ctx(&self) -> crate::proc::GlobalCtx {
695        crate::proc::GlobalCtx {
696            types: self.types,
697            constants: self.constants,
698            overrides: self.overrides,
699            global_expressions: match self.function_local_data() {
700                Some(data) => data.global_expressions,
701                None => self.expressions,
702            },
703        }
704    }
705
706    fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
707        if !self.expression_kind_tracker.is_const(expr) {
708            log::debug!("check: SubexpressionsAreNotConstant");
709            return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
710        }
711        Ok(())
712    }
713
714    fn check_and_get(
715        &mut self,
716        expr: Handle<Expression>,
717    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
718        match self.expressions[expr] {
719            Expression::Constant(c) => {
720                // Are we working in a function's expression arena, or the
721                // module's constant expression arena?
722                if let Some(function_local_data) = self.function_local_data() {
723                    // Deep-copy the constant's value into our arena.
724                    self.copy_from(
725                        self.constants[c].init,
726                        function_local_data.global_expressions,
727                    )
728                } else {
729                    // "See through" the constant and use its initializer.
730                    Ok(self.constants[c].init)
731                }
732            }
733            _ => {
734                self.check(expr)?;
735                Ok(expr)
736            }
737        }
738    }
739
740    /// Try to evaluate `expr` at compile time.
741    ///
742    /// The `expr` argument can be any sort of Naga [`Expression`] you like. If
743    /// we can determine its value at compile time, we append an expression
744    /// representing its value - a tree of [`Literal`], [`Compose`],
745    /// [`ZeroValue`], and [`Swizzle`] expressions - to the expression arena
746    /// `self` contributes to.
747    ///
748    /// If `expr`'s value cannot be determined at compile time, and `self` is
749    /// contributing to some function's expression arena, then append `expr` to
750    /// that arena unchanged (and thus unevaluated). Otherwise, `self` must be
751    /// contributing to the module's constant expression arena; since `expr`'s
752    /// value is not a constant, return an error.
753    ///
754    /// We only consider `expr` itself, without recursing into its operands. Its
755    /// operands must all have been produced by prior calls to
756    /// `try_eval_and_append`, to ensure that they have already been reduced to
757    /// an evaluated form if possible.
758    ///
759    /// [`Literal`]: Expression::Literal
760    /// [`Compose`]: Expression::Compose
761    /// [`ZeroValue`]: Expression::ZeroValue
762    /// [`Swizzle`]: Expression::Swizzle
763    pub fn try_eval_and_append(
764        &mut self,
765        expr: Expression,
766        span: Span,
767    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
768        match self.expression_kind_tracker.type_of_with_expr(&expr) {
769            ExpressionKind::ImplConst => self.try_eval_and_append_impl(&expr, span),
770            ExpressionKind::Const => {
771                let eval_result = self.try_eval_and_append_impl(&expr, span);
772                // We should be able to evaluate `Const` expressions at this
773                // point. If we failed to, then that probably means we just
774                // haven't implemented that part of constant evaluation. Work
775                // around this by simply emitting it as a run-time expression.
776                if self.behavior.has_runtime_restrictions()
777                    && matches!(
778                        eval_result,
779                        Err(ConstantEvaluatorError::NotImplemented(_)
780                            | ConstantEvaluatorError::InvalidBinaryOpArgs,)
781                    )
782                {
783                    Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
784                } else {
785                    eval_result
786                }
787            }
788            ExpressionKind::Override => match self.behavior {
789                Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
790                    Ok(self.append_expr(expr, span, ExpressionKind::Override))
791                }
792                Behavior::Wgsl(WgslRestrictions::Const(_)) => {
793                    Err(ConstantEvaluatorError::OverrideExpr)
794                }
795                Behavior::Glsl(_) => {
796                    unreachable!()
797                }
798            },
799            ExpressionKind::Runtime => {
800                if self.behavior.has_runtime_restrictions() {
801                    Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
802                } else {
803                    Err(ConstantEvaluatorError::RuntimeExpr)
804                }
805            }
806        }
807    }
808
809    /// Is the [`Self::expressions`] arena the global module expression arena?
810    const fn is_global_arena(&self) -> bool {
811        matches!(
812            self.behavior,
813            Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
814                | Behavior::Glsl(GlslRestrictions::Const)
815        )
816    }
817
818    const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
819        match self.behavior {
820            Behavior::Wgsl(
821                WgslRestrictions::Runtime(ref function_local_data)
822                | WgslRestrictions::Const(Some(ref function_local_data)),
823            )
824            | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
825                Some(function_local_data)
826            }
827            _ => None,
828        }
829    }
830
831    fn try_eval_and_append_impl(
832        &mut self,
833        expr: &Expression,
834        span: Span,
835    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
836        log::trace!("try_eval_and_append: {:?}", expr);
837        match *expr {
838            Expression::Constant(c) if self.is_global_arena() => {
839                // "See through" the constant and use its initializer.
840                // This is mainly done to avoid having constants pointing to other constants.
841                Ok(self.constants[c].init)
842            }
843            Expression::Override(_) => Err(ConstantEvaluatorError::Override),
844            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
845                self.register_evaluated_expr(expr.clone(), span)
846            }
847            Expression::Compose { ty, ref components } => {
848                let components = components
849                    .iter()
850                    .map(|component| self.check_and_get(*component))
851                    .collect::<Result<Vec<_>, _>>()?;
852                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
853            }
854            Expression::Splat { size, value } => {
855                let value = self.check_and_get(value)?;
856                self.register_evaluated_expr(Expression::Splat { size, value }, span)
857            }
858            Expression::AccessIndex { base, index } => {
859                let base = self.check_and_get(base)?;
860
861                self.access(base, index as usize, span)
862            }
863            Expression::Access { base, index } => {
864                let base = self.check_and_get(base)?;
865                let index = self.check_and_get(index)?;
866
867                self.access(base, self.constant_index(index)?, span)
868            }
869            Expression::Swizzle {
870                size,
871                vector,
872                pattern,
873            } => {
874                let vector = self.check_and_get(vector)?;
875
876                self.swizzle(size, span, vector, pattern)
877            }
878            Expression::Unary { expr, op } => {
879                let expr = self.check_and_get(expr)?;
880
881                self.unary_op(op, expr, span)
882            }
883            Expression::Binary { left, right, op } => {
884                let left = self.check_and_get(left)?;
885                let right = self.check_and_get(right)?;
886
887                self.binary_op(op, left, right, span)
888            }
889            Expression::Math {
890                fun,
891                arg,
892                arg1,
893                arg2,
894                arg3,
895            } => {
896                let arg = self.check_and_get(arg)?;
897                let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
898                let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
899                let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;
900
901                self.math(arg, arg1, arg2, arg3, fun, span)
902            }
903            Expression::As {
904                convert,
905                expr,
906                kind,
907            } => {
908                let expr = self.check_and_get(expr)?;
909
910                match convert {
911                    Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
912                    None => Err(ConstantEvaluatorError::NotImplemented(
913                        "bitcast built-in function".into(),
914                    )),
915                }
916            }
917            Expression::Select { .. } => Err(ConstantEvaluatorError::NotImplemented(
918                "select built-in function".into(),
919            )),
920            Expression::Relational { fun, .. } => Err(ConstantEvaluatorError::NotImplemented(
921                format!("{fun:?} built-in function"),
922            )),
923            Expression::ArrayLength(expr) => match self.behavior {
924                Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
925                Behavior::Glsl(_) => {
926                    let expr = self.check_and_get(expr)?;
927                    self.array_length(expr, span)
928                }
929            },
930            Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
931            Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
932            Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
933            Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
934            Expression::WorkGroupUniformLoadResult { .. } => {
935                Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
936            }
937            Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
938            Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
939            Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
940            Expression::ImageSample { .. }
941            | Expression::ImageLoad { .. }
942            | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
943            Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => {
944                Err(ConstantEvaluatorError::RayQueryExpression)
945            }
946            Expression::SubgroupBallotResult { .. } => {
947                Err(ConstantEvaluatorError::SubgroupExpression)
948            }
949            Expression::SubgroupOperationResult { .. } => {
950                Err(ConstantEvaluatorError::SubgroupExpression)
951            }
952        }
953    }
954
955    /// Splat `value` to `size`, without using [`Splat`] expressions.
956    ///
957    /// This constructs [`Compose`] or [`ZeroValue`] expressions to
958    /// build a vector with the given `size` whose components are all
959    /// `value`.
960    ///
961    /// Use `span` as the span of the inserted expressions and
962    /// resulting types.
963    ///
964    /// [`Splat`]: Expression::Splat
965    /// [`Compose`]: Expression::Compose
966    /// [`ZeroValue`]: Expression::ZeroValue
967    fn splat(
968        &mut self,
969        value: Handle<Expression>,
970        size: crate::VectorSize,
971        span: Span,
972    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
973        match self.expressions[value] {
974            Expression::Literal(literal) => {
975                let scalar = literal.scalar();
976                let ty = self.types.insert(
977                    Type {
978                        name: None,
979                        inner: TypeInner::Vector { size, scalar },
980                    },
981                    span,
982                );
983                let expr = Expression::Compose {
984                    ty,
985                    components: vec![value; size as usize],
986                };
987                self.register_evaluated_expr(expr, span)
988            }
989            Expression::ZeroValue(ty) => {
990                let inner = match self.types[ty].inner {
991                    TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
992                    _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
993                };
994                let res_ty = self.types.insert(Type { name: None, inner }, span);
995                let expr = Expression::ZeroValue(res_ty);
996                self.register_evaluated_expr(expr, span)
997            }
998            _ => Err(ConstantEvaluatorError::SplatScalarOnly),
999        }
1000    }
1001
1002    fn swizzle(
1003        &mut self,
1004        size: crate::VectorSize,
1005        span: Span,
1006        src_constant: Handle<Expression>,
1007        pattern: [crate::SwizzleComponent; 4],
1008    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1009        let mut get_dst_ty = |ty| match self.types[ty].inner {
1010            TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
1011                Type {
1012                    name: None,
1013                    inner: TypeInner::Vector { size, scalar },
1014                },
1015                span,
1016            )),
1017            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1018        };
1019
1020        match self.expressions[src_constant] {
1021            Expression::ZeroValue(ty) => {
1022                let dst_ty = get_dst_ty(ty)?;
1023                let expr = Expression::ZeroValue(dst_ty);
1024                self.register_evaluated_expr(expr, span)
1025            }
1026            Expression::Splat { value, .. } => {
1027                let expr = Expression::Splat { size, value };
1028                self.register_evaluated_expr(expr, span)
1029            }
1030            Expression::Compose { ty, ref components } => {
1031                let dst_ty = get_dst_ty(ty)?;
1032
1033                let mut flattened = [src_constant; 4]; // dummy value
1034                let len =
1035                    crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1036                        .zip(flattened.iter_mut())
1037                        .map(|(component, elt)| *elt = component)
1038                        .count();
1039                let flattened = &flattened[..len];
1040
1041                let swizzled_components = pattern[..size as usize]
1042                    .iter()
1043                    .map(|&sc| {
1044                        let sc = sc as usize;
1045                        if let Some(elt) = flattened.get(sc) {
1046                            Ok(*elt)
1047                        } else {
1048                            Err(ConstantEvaluatorError::SwizzleOutOfBounds)
1049                        }
1050                    })
1051                    .collect::<Result<Vec<Handle<Expression>>, _>>()?;
1052                let expr = Expression::Compose {
1053                    ty: dst_ty,
1054                    components: swizzled_components,
1055                };
1056                self.register_evaluated_expr(expr, span)
1057            }
1058            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1059        }
1060    }
1061
1062    fn math(
1063        &mut self,
1064        arg: Handle<Expression>,
1065        arg1: Option<Handle<Expression>>,
1066        arg2: Option<Handle<Expression>>,
1067        arg3: Option<Handle<Expression>>,
1068        fun: crate::MathFunction,
1069        span: Span,
1070    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1071        let expected = fun.argument_count();
1072        let given = Some(arg)
1073            .into_iter()
1074            .chain(arg1)
1075            .chain(arg2)
1076            .chain(arg3)
1077            .count();
1078        if expected != given {
1079            return Err(ConstantEvaluatorError::InvalidMathArgCount(
1080                fun, expected, given,
1081            ));
1082        }
1083
1084        // NOTE: We try to match the declaration order of `MathFunction` here.
1085        match fun {
1086            // comparison
1087            crate::MathFunction::Abs => {
1088                component_wise_scalar(self, span, [arg], |args| match args {
1089                    Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
1090                    Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
1091                    Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.abs()])),
1092                    Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
1093                    Scalar::U32([e]) => Ok(Scalar::U32([e])), // TODO: just re-use the expression, ezpz
1094                    Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
1095                    Scalar::U64([e]) => Ok(Scalar::U64([e])),
1096                })
1097            }
1098            crate::MathFunction::Min => {
1099                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1100                    Ok([e1.min(e2)])
1101                })
1102            }
1103            crate::MathFunction::Max => {
1104                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1105                    Ok([e1.max(e2)])
1106                })
1107            }
1108            crate::MathFunction::Clamp => {
1109                component_wise_scalar!(
1110                    self,
1111                    span,
1112                    [arg, arg1.unwrap(), arg2.unwrap()],
1113                    |e, low, high| {
1114                        if low > high {
1115                            Err(ConstantEvaluatorError::InvalidClamp)
1116                        } else {
1117                            Ok([e.clamp(low, high)])
1118                        }
1119                    }
1120                )
1121            }
1122            crate::MathFunction::Saturate => {
1123                component_wise_float!(self, span, [arg], |e| { Ok([e.clamp(0., 1.)]) })
1124            }
1125
1126            // trigonometry
1127            crate::MathFunction::Cos => {
1128                component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
1129            }
1130            crate::MathFunction::Cosh => {
1131                component_wise_float!(self, span, [arg], |e| { Ok([e.cosh()]) })
1132            }
1133            crate::MathFunction::Sin => {
1134                component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
1135            }
1136            crate::MathFunction::Sinh => {
1137                component_wise_float!(self, span, [arg], |e| { Ok([e.sinh()]) })
1138            }
1139            crate::MathFunction::Tan => {
1140                component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
1141            }
1142            crate::MathFunction::Tanh => {
1143                component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
1144            }
1145            crate::MathFunction::Acos => {
1146                component_wise_float!(self, span, [arg], |e| { Ok([e.acos()]) })
1147            }
1148            crate::MathFunction::Asin => {
1149                component_wise_float!(self, span, [arg], |e| { Ok([e.asin()]) })
1150            }
1151            crate::MathFunction::Atan => {
1152                component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
1153            }
1154            crate::MathFunction::Asinh => {
1155                component_wise_float!(self, span, [arg], |e| { Ok([e.asinh()]) })
1156            }
1157            crate::MathFunction::Acosh => {
1158                component_wise_float!(self, span, [arg], |e| { Ok([e.acosh()]) })
1159            }
1160            crate::MathFunction::Atanh => {
1161                component_wise_float!(self, span, [arg], |e| { Ok([e.atanh()]) })
1162            }
1163            crate::MathFunction::Radians => {
1164                component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
1165            }
1166            crate::MathFunction::Degrees => {
1167                component_wise_float!(self, span, [arg], |e| { Ok([e.to_degrees()]) })
1168            }
1169
1170            // decomposition
1171            crate::MathFunction::Ceil => {
1172                component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
1173            }
1174            crate::MathFunction::Floor => {
1175                component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
1176            }
1177            crate::MathFunction::Round => {
1178                // TODO: this hit stable on 1.77, but MSRV hasn't caught up yet
1179                // This polyfill is shamelessly [~~stolen from~~ inspired by `ndarray-image`][polyfill source],
1180                // which has licensing compatible with ours. See also
1181                // <https://github.com/rust-lang/rust/issues/96710>.
1182                //
1183                // [polyfill source]: https://github.com/imeka/ndarray-ndimage/blob/8b14b4d6ecfbc96a8a052f802e342a7049c68d8f/src/lib.rs#L98
1184                fn round_ties_even(x: f64) -> f64 {
1185                    let i = x as i64;
1186                    let f = (x - i as f64).abs();
1187                    if f == 0.5 {
1188                        if i & 1 == 1 {
1189                            // -1.5, 1.5, 3.5, ...
1190                            (x.abs() + 0.5).copysign(x)
1191                        } else {
1192                            (x.abs() - 0.5).copysign(x)
1193                        }
1194                    } else {
1195                        x.round()
1196                    }
1197                }
1198                component_wise_float(self, span, [arg], |e| match e {
1199                    Float::Abstract([e]) => Ok(Float::Abstract([round_ties_even(e)])),
1200                    Float::F32([e]) => Ok(Float::F32([(round_ties_even(e as f64) as f32)])),
1201                })
1202            }
1203            crate::MathFunction::Fract => {
1204                component_wise_float!(self, span, [arg], |e| {
1205                    // N.B., Rust's definition of `fract` is `e - e.trunc()`, so we can't use that
1206                    // here.
1207                    Ok([e - e.floor()])
1208                })
1209            }
1210            crate::MathFunction::Trunc => {
1211                component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
1212            }
1213
1214            // exponent
1215            crate::MathFunction::Exp => {
1216                component_wise_float!(self, span, [arg], |e| { Ok([e.exp()]) })
1217            }
1218            crate::MathFunction::Exp2 => {
1219                component_wise_float!(self, span, [arg], |e| { Ok([e.exp2()]) })
1220            }
1221            crate::MathFunction::Log => {
1222                component_wise_float!(self, span, [arg], |e| { Ok([e.ln()]) })
1223            }
1224            crate::MathFunction::Log2 => {
1225                component_wise_float!(self, span, [arg], |e| { Ok([e.log2()]) })
1226            }
1227            crate::MathFunction::Pow => {
1228                component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1229                    Ok([e1.powf(e2)])
1230                })
1231            }
1232
1233            // computational
1234            crate::MathFunction::Sign => {
1235                component_wise_signed!(self, span, [arg], |e| { Ok([e.signum()]) })
1236            }
1237            crate::MathFunction::Fma => {
1238                component_wise_float!(
1239                    self,
1240                    span,
1241                    [arg, arg1.unwrap(), arg2.unwrap()],
1242                    |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
1243                )
1244            }
1245            crate::MathFunction::Step => {
1246                component_wise_float!(self, span, [arg, arg1.unwrap()], |edge, x| {
1247                    Ok([if edge <= x { 1.0 } else { 0.0 }])
1248                })
1249            }
1250            crate::MathFunction::Sqrt => {
1251                component_wise_float!(self, span, [arg], |e| { Ok([e.sqrt()]) })
1252            }
1253            crate::MathFunction::InverseSqrt => {
1254                component_wise_float!(self, span, [arg], |e| { Ok([1. / e.sqrt()]) })
1255            }
1256
1257            // bits
1258            crate::MathFunction::CountTrailingZeros => {
1259                component_wise_concrete_int!(self, span, [arg], |e| {
1260                    #[allow(clippy::useless_conversion)]
1261                    Ok([e
1262                        .trailing_zeros()
1263                        .try_into()
1264                        .expect("bit count overflowed 32 bits, somehow!?")])
1265                })
1266            }
1267            crate::MathFunction::CountLeadingZeros => {
1268                component_wise_concrete_int!(self, span, [arg], |e| {
1269                    #[allow(clippy::useless_conversion)]
1270                    Ok([e
1271                        .leading_zeros()
1272                        .try_into()
1273                        .expect("bit count overflowed 32 bits, somehow!?")])
1274                })
1275            }
1276            crate::MathFunction::CountOneBits => {
1277                component_wise_concrete_int!(self, span, [arg], |e| {
1278                    #[allow(clippy::useless_conversion)]
1279                    Ok([e
1280                        .count_ones()
1281                        .try_into()
1282                        .expect("bit count overflowed 32 bits, somehow!?")])
1283                })
1284            }
1285            crate::MathFunction::ReverseBits => {
1286                component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
1287            }
1288            crate::MathFunction::FirstTrailingBit => {
1289                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
1290            }
1291            crate::MathFunction::FirstLeadingBit => {
1292                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
1293            }
1294
1295            fun => Err(ConstantEvaluatorError::NotImplemented(format!(
1296                "{fun:?} built-in function"
1297            ))),
1298        }
1299    }
1300
1301    fn array_length(
1302        &mut self,
1303        array: Handle<Expression>,
1304        span: Span,
1305    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1306        match self.expressions[array] {
1307            Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
1308                match self.types[ty].inner {
1309                    TypeInner::Array { size, .. } => match size {
1310                        ArraySize::Constant(len) => {
1311                            let expr = Expression::Literal(Literal::U32(len.get()));
1312                            self.register_evaluated_expr(expr, span)
1313                        }
1314                        ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
1315                    },
1316                    _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
1317                }
1318            }
1319            _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
1320        }
1321    }
1322
1323    fn access(
1324        &mut self,
1325        base: Handle<Expression>,
1326        index: usize,
1327        span: Span,
1328    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1329        match self.expressions[base] {
1330            Expression::ZeroValue(ty) => {
1331                let ty_inner = &self.types[ty].inner;
1332                let components = ty_inner
1333                    .components()
1334                    .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
1335
1336                if index >= components as usize {
1337                    Err(ConstantEvaluatorError::InvalidAccessBase)
1338                } else {
1339                    let ty_res = ty_inner
1340                        .component_type(index)
1341                        .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
1342                    let ty = match ty_res {
1343                        crate::proc::TypeResolution::Handle(ty) => ty,
1344                        crate::proc::TypeResolution::Value(inner) => {
1345                            self.types.insert(Type { name: None, inner }, span)
1346                        }
1347                    };
1348                    self.register_evaluated_expr(Expression::ZeroValue(ty), span)
1349                }
1350            }
1351            Expression::Splat { size, value } => {
1352                if index >= size as usize {
1353                    Err(ConstantEvaluatorError::InvalidAccessBase)
1354                } else {
1355                    Ok(value)
1356                }
1357            }
1358            Expression::Compose { ty, ref components } => {
1359                let _ = self.types[ty]
1360                    .inner
1361                    .components()
1362                    .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
1363
1364                crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1365                    .nth(index)
1366                    .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
1367            }
1368            _ => Err(ConstantEvaluatorError::InvalidAccessBase),
1369        }
1370    }
1371
1372    fn constant_index(&self, expr: Handle<Expression>) -> Result<usize, ConstantEvaluatorError> {
1373        match self.expressions[expr] {
1374            Expression::ZeroValue(ty)
1375                if matches!(
1376                    self.types[ty].inner,
1377                    TypeInner::Scalar(crate::Scalar {
1378                        kind: ScalarKind::Uint,
1379                        ..
1380                    })
1381                ) =>
1382            {
1383                Ok(0)
1384            }
1385            Expression::Literal(Literal::U32(index)) => Ok(index as usize),
1386            _ => Err(ConstantEvaluatorError::InvalidAccessIndexTy),
1387        }
1388    }
1389
1390    /// Lower [`ZeroValue`] and [`Splat`] expressions to [`Literal`] and [`Compose`] expressions.
1391    ///
1392    /// [`ZeroValue`]: Expression::ZeroValue
1393    /// [`Splat`]: Expression::Splat
1394    /// [`Literal`]: Expression::Literal
1395    /// [`Compose`]: Expression::Compose
1396    fn eval_zero_value_and_splat(
1397        &mut self,
1398        expr: Handle<Expression>,
1399        span: Span,
1400    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1401        match self.expressions[expr] {
1402            Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
1403            Expression::Splat { size, value } => self.splat(value, size, span),
1404            _ => Ok(expr),
1405        }
1406    }
1407
1408    /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
1409    ///
1410    /// [`ZeroValue`]: Expression::ZeroValue
1411    /// [`Literal`]: Expression::Literal
1412    /// [`Compose`]: Expression::Compose
1413    fn eval_zero_value(
1414        &mut self,
1415        expr: Handle<Expression>,
1416        span: Span,
1417    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1418        match self.expressions[expr] {
1419            Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
1420            _ => Ok(expr),
1421        }
1422    }
1423
1424    /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
1425    ///
1426    /// [`ZeroValue`]: Expression::ZeroValue
1427    /// [`Literal`]: Expression::Literal
1428    /// [`Compose`]: Expression::Compose
1429    fn eval_zero_value_impl(
1430        &mut self,
1431        ty: Handle<Type>,
1432        span: Span,
1433    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1434        match self.types[ty].inner {
1435            TypeInner::Scalar(scalar) => {
1436                let expr = Expression::Literal(
1437                    Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
1438                );
1439                self.register_evaluated_expr(expr, span)
1440            }
1441            TypeInner::Vector { size, scalar } => {
1442                let scalar_ty = self.types.insert(
1443                    Type {
1444                        name: None,
1445                        inner: TypeInner::Scalar(scalar),
1446                    },
1447                    span,
1448                );
1449                let el = self.eval_zero_value_impl(scalar_ty, span)?;
1450                let expr = Expression::Compose {
1451                    ty,
1452                    components: vec![el; size as usize],
1453                };
1454                self.register_evaluated_expr(expr, span)
1455            }
1456            TypeInner::Matrix {
1457                columns,
1458                rows,
1459                scalar,
1460            } => {
1461                let vec_ty = self.types.insert(
1462                    Type {
1463                        name: None,
1464                        inner: TypeInner::Vector { size: rows, scalar },
1465                    },
1466                    span,
1467                );
1468                let el = self.eval_zero_value_impl(vec_ty, span)?;
1469                let expr = Expression::Compose {
1470                    ty,
1471                    components: vec![el; columns as usize],
1472                };
1473                self.register_evaluated_expr(expr, span)
1474            }
1475            TypeInner::Array {
1476                base,
1477                size: ArraySize::Constant(size),
1478                ..
1479            } => {
1480                let el = self.eval_zero_value_impl(base, span)?;
1481                let expr = Expression::Compose {
1482                    ty,
1483                    components: vec![el; size.get() as usize],
1484                };
1485                self.register_evaluated_expr(expr, span)
1486            }
1487            TypeInner::Struct { ref members, .. } => {
1488                let types: Vec<_> = members.iter().map(|m| m.ty).collect();
1489                let mut components = Vec::with_capacity(members.len());
1490                for ty in types {
1491                    components.push(self.eval_zero_value_impl(ty, span)?);
1492                }
1493                let expr = Expression::Compose { ty, components };
1494                self.register_evaluated_expr(expr, span)
1495            }
1496            _ => Err(ConstantEvaluatorError::TypeNotConstructible),
1497        }
1498    }
1499
    /// Convert the scalar components of `expr` to `target`.
    ///
    /// `expr` is first passed through `eval_zero_value` (presumably expanding a
    /// `ZeroValue` into an explicit value — confirm). The result must then be:
    ///
    /// - a `Literal`, converted directly according to the match table below;
    /// - a `Compose` whose type is a vector or matrix, converted component by
    ///   component into a new vector/matrix type whose scalar is `target`; or
    /// - a `Splat`, whose splatted value is converted.
    ///
    /// Any other expression yields [`ConstantEvaluatorError::InvalidCastArg`].
    /// Refused literal conversions also yield `InvalidCastArg`. These include
    /// `f64`/`i64`/`u64` to a 32-bit type, `i64`/`u64` to `f64`, any 64-bit or
    /// abstract value to `bool`, and any target not listed in the table (for
    /// example abstract int). Conversions out of abstract types go through
    /// `try_from_abstract` and propagate its error.
    ///
    /// Treat `span` as the location of the resulting expression.
    pub fn cast(
        &mut self,
        expr: Handle<Expression>,
        target: crate::Scalar,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        use crate::Scalar as Sc;

        let expr = self.eval_zero_value(expr, span)?;

        // Build the `InvalidCastArg` error lazily, only on the failure paths.
        // The target is spelled in WGSL syntax when the WGSL frontend is enabled.
        let make_error = || -> Result<_, ConstantEvaluatorError> {
            let from = format!("{:?} {:?}", expr, self.expressions[expr]);

            #[cfg(feature = "wgsl-in")]
            let to = target.to_wgsl();

            #[cfg(not(feature = "wgsl-in"))]
            let to = format!("{target:?}");

            Err(ConstantEvaluatorError::InvalidCastArg { from, to })
        };

        let expr = match self.expressions[expr] {
            Expression::Literal(literal) => {
                // Note: Rust's float-to-int `as` casts saturate at the target's
                // bounds and map NaN to 0; int-to-int `as` casts reinterpret or
                // truncate bits.
                let literal = match target {
                    Sc::I32 => Literal::I32(match literal {
                        Literal::I32(v) => v,
                        Literal::U32(v) => v as i32,
                        Literal::F32(v) => v as i32,
                        Literal::Bool(v) => v as i32,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
                    }),
                    Sc::U32 => Literal::U32(match literal {
                        Literal::I32(v) => v as u32,
                        Literal::U32(v) => v,
                        Literal::F32(v) => v as u32,
                        Literal::Bool(v) => v as u32,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
                    }),
                    Sc::I64 => Literal::I64(match literal {
                        Literal::I32(v) => v as i64,
                        Literal::U32(v) => v as i64,
                        Literal::F32(v) => v as i64,
                        Literal::Bool(v) => v as i64,
                        Literal::F64(v) => v as i64,
                        Literal::I64(v) => v,
                        Literal::U64(v) => v as i64,
                        Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
                    }),
                    Sc::U64 => Literal::U64(match literal {
                        Literal::I32(v) => v as u64,
                        Literal::U32(v) => v as u64,
                        Literal::F32(v) => v as u64,
                        Literal::Bool(v) => v as u64,
                        Literal::F64(v) => v as u64,
                        Literal::I64(v) => v as u64,
                        Literal::U64(v) => v,
                        Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
                    }),
                    Sc::F32 => Literal::F32(match literal {
                        Literal::I32(v) => v as f32,
                        Literal::U32(v) => v as f32,
                        Literal::F32(v) => v,
                        // Rust has no direct `bool as f32`; go through an integer.
                        Literal::Bool(v) => v as u32 as f32,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
                    }),
                    Sc::F64 => Literal::F64(match literal {
                        Literal::I32(v) => v as f64,
                        Literal::U32(v) => v as f64,
                        Literal::F32(v) => v as f64,
                        Literal::F64(v) => v,
                        // Rust has no direct `bool as f64`; go through an integer.
                        Literal::Bool(v) => v as u32 as f64,
                        Literal::I64(_) | Literal::U64(_) => return make_error(),
                        Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
                    }),
                    Sc::BOOL => Literal::Bool(match literal {
                        Literal::I32(v) => v != 0,
                        Literal::U32(v) => v != 0,
                        Literal::F32(v) => v != 0.0,
                        Literal::Bool(v) => v,
                        Literal::F64(_)
                        | Literal::I64(_)
                        | Literal::U64(_)
                        | Literal::AbstractInt(_)
                        | Literal::AbstractFloat(_) => {
                            return make_error();
                        }
                    }),
                    Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
                        Literal::AbstractInt(v) => {
                            // Overflow is forbidden, but inexact conversions
                            // are fine. The range of f64 is far larger than
                            // that of i64, so we don't have to check anything
                            // here.
                            v as f64
                        }
                        Literal::AbstractFloat(v) => v,
                        _ => return make_error(),
                    }),
                    // Every other target (e.g. abstract int) is refused.
                    _ => {
                        log::debug!("Constant evaluator refused to convert value to {target:?}");
                        return make_error();
                    }
                };
                Expression::Literal(literal)
            }
            Expression::Compose {
                ty,
                components: ref src_components,
            } => {
                // Only vectors and matrices are converted here; arrays are
                // handled by `cast_array`.
                let ty_inner = match self.types[ty].inner {
                    TypeInner::Vector { size, .. } => TypeInner::Vector {
                        size,
                        scalar: target,
                    },
                    TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
                        columns,
                        rows,
                        scalar: target,
                    },
                    _ => return make_error(),
                };

                // Clone the handle list so the recursive `cast` calls can
                // borrow `self` mutably.
                let mut components = src_components.clone();
                for component in &mut components {
                    *component = self.cast(*component, target, span)?;
                }

                let ty = self.types.insert(
                    Type {
                        name: None,
                        inner: ty_inner,
                    },
                    span,
                );

                Expression::Compose { ty, components }
            }
            Expression::Splat { size, value } => {
                // Convert the splatted scalar at its own span, then rebuild the
                // `Splat` around the converted value.
                let value_span = self.expressions.get_span(value);
                let cast_value = self.cast(value, target, value_span)?;
                Expression::Splat {
                    size,
                    value: cast_value,
                }
            }
            _ => return make_error(),
        };

        self.register_evaluated_expr(expr, span)
    }
1669
1670    /// Convert the scalar leaves of  `expr` to `target`, handling arrays.
1671    ///
1672    /// `expr` must be a `Compose` expression whose type is a scalar, vector,
1673    /// matrix, or nested arrays of such.
1674    ///
1675    /// This is basically the same as the [`cast`] method, except that that
1676    /// should only handle Naga [`As`] expressions, which cannot convert arrays.
1677    ///
1678    /// Treat `span` as the location of the resulting expression.
1679    ///
1680    /// [`cast`]: ConstantEvaluator::cast
1681    /// [`As`]: crate::Expression::As
1682    pub fn cast_array(
1683        &mut self,
1684        expr: Handle<Expression>,
1685        target: crate::Scalar,
1686        span: Span,
1687    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1688        let Expression::Compose { ty, ref components } = self.expressions[expr] else {
1689            return self.cast(expr, target, span);
1690        };
1691
1692        let TypeInner::Array {
1693            base: _,
1694            size,
1695            stride: _,
1696        } = self.types[ty].inner
1697        else {
1698            return self.cast(expr, target, span);
1699        };
1700
1701        let mut components = components.clone();
1702        for component in &mut components {
1703            *component = self.cast_array(*component, target, span)?;
1704        }
1705
1706        let first = components.first().unwrap();
1707        let new_base = match self.resolve_type(*first)? {
1708            crate::proc::TypeResolution::Handle(ty) => ty,
1709            crate::proc::TypeResolution::Value(inner) => {
1710                self.types.insert(Type { name: None, inner }, span)
1711            }
1712        };
1713        let new_base_stride = self.types[new_base].inner.size(self.to_ctx());
1714        let new_array_ty = self.types.insert(
1715            Type {
1716                name: None,
1717                inner: TypeInner::Array {
1718                    base: new_base,
1719                    size,
1720                    stride: new_base_stride,
1721                },
1722            },
1723            span,
1724        );
1725
1726        let compose = Expression::Compose {
1727            ty: new_array_ty,
1728            components,
1729        };
1730        self.register_evaluated_expr(compose, span)
1731    }
1732
1733    fn unary_op(
1734        &mut self,
1735        op: UnaryOperator,
1736        expr: Handle<Expression>,
1737        span: Span,
1738    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1739        let expr = self.eval_zero_value_and_splat(expr, span)?;
1740
1741        let expr = match self.expressions[expr] {
1742            Expression::Literal(value) => Expression::Literal(match op {
1743                UnaryOperator::Negate => match value {
1744                    Literal::I32(v) => Literal::I32(v.wrapping_neg()),
1745                    Literal::F32(v) => Literal::F32(-v),
1746                    Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
1747                    Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
1748                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1749                },
1750                UnaryOperator::LogicalNot => match value {
1751                    Literal::Bool(v) => Literal::Bool(!v),
1752                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1753                },
1754                UnaryOperator::BitwiseNot => match value {
1755                    Literal::I32(v) => Literal::I32(!v),
1756                    Literal::U32(v) => Literal::U32(!v),
1757                    Literal::AbstractInt(v) => Literal::AbstractInt(!v),
1758                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1759                },
1760            }),
1761            Expression::Compose {
1762                ty,
1763                components: ref src_components,
1764            } => {
1765                match self.types[ty].inner {
1766                    TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
1767                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1768                }
1769
1770                let mut components = src_components.clone();
1771                for component in &mut components {
1772                    *component = self.unary_op(op, *component, span)?;
1773                }
1774
1775                Expression::Compose { ty, components }
1776            }
1777            _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1778        };
1779
1780        self.register_evaluated_expr(expr, span)
1781    }
1782
    /// Constant-fold the binary operator `op` applied to `left` and `right`.
    ///
    /// Both operands are first passed through `eval_zero_value_and_splat`. The
    /// resulting pair is then handled as follows:
    ///
    /// - Two `Literal`s are folded directly.
    ///   - Comparison operators compare the `Literal` values and produce a
    ///     `bool` literal.
    ///   - Every other operator requires both literals to have the same scalar
    ///     type (`i32`, `u32`, `f32`, abstract int or abstract float; `bool`
    ///     for the logical operators).
    ///   - The exception is shifts, which take an `i32` or `u32` left operand
    ///     and a `u32` shift amount.
    ///   - Integer arithmetic reports overflow, division by zero and remainder
    ///     by zero as errors rather than wrapping.
    /// - A `Compose` and a `Literal`, in either order, apply `op` between the
    ///   literal and each component of the `Compose`, keeping its type.
    /// - Two `Compose`s whose types are vectors of the same size are folded
    ///   pairwise by [`binary_op_vector`](Self::binary_op_vector).
    ///
    /// Any other combination yields
    /// [`ConstantEvaluatorError::InvalidBinaryOpArgs`]. The folded expression is
    /// registered at `span`.
    fn binary_op(
        &mut self,
        op: BinaryOperator,
        left: Handle<Expression>,
        right: Handle<Expression>,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        let left = self.eval_zero_value_and_splat(left, span)?;
        let right = self.eval_zero_value_and_splat(right, span)?;

        let expr = match (&self.expressions[left], &self.expressions[right]) {
            (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
                // NOTE(review): comparisons rely on `Literal`'s own
                // `PartialEq`/`PartialOrd` and do not check here that both
                // operands are the same variant; presumably the caller
                // guarantees that — confirm.
                let literal = match op {
                    BinaryOperator::Equal => Literal::Bool(left_value == right_value),
                    BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
                    BinaryOperator::Less => Literal::Bool(left_value < right_value),
                    BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
                    BinaryOperator::Greater => Literal::Bool(left_value > right_value),
                    BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),

                    _ => match (left_value, right_value) {
                        (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
                            BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
                                ConstantEvaluatorError::Overflow("addition".into())
                            })?,
                            BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
                                ConstantEvaluatorError::Overflow("subtraction".into())
                            })?,
                            BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
                                ConstantEvaluatorError::Overflow("multiplication".into())
                            })?,
                            // `checked_div`/`checked_rem` fail both for a zero
                            // divisor and for `i32::MIN / -1`; distinguish the two.
                            BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
                                if b == 0 {
                                    ConstantEvaluatorError::DivisionByZero
                                } else {
                                    ConstantEvaluatorError::Overflow("division".into())
                                }
                            })?,
                            BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
                                if b == 0 {
                                    ConstantEvaluatorError::RemainderByZero
                                } else {
                                    ConstantEvaluatorError::Overflow("remainder".into())
                                }
                            })?,
                            BinaryOperator::And => a & b,
                            BinaryOperator::ExclusiveOr => a ^ b,
                            BinaryOperator::InclusiveOr => a | b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
                            BinaryOperator::ShiftLeft => {
                                // The shift is exact only if more than `b` leading
                                // bits equal the sign bit. Inverting a negative
                                // `a` turns its leading ones into leading zeros,
                                // so one `leading_zeros` test covers both signs.
                                // NOTE(review): `leading_zeros()` is at most 32,
                                // so any `b >= 32` is reported as `Overflow` here
                                // and the `ShiftedMoreThan32Bits` error below is
                                // never reached for `<<`.
                                if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
                                    return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
                                }
                                a.checked_shl(b)
                                    .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
                            }
                            BinaryOperator::ShiftRight => a
                                .checked_shr(b)
                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
                            BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
                                ConstantEvaluatorError::Overflow("addition".into())
                            })?,
                            BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
                                ConstantEvaluatorError::Overflow("subtraction".into())
                            })?,
                            BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
                                ConstantEvaluatorError::Overflow("multiplication".into())
                            })?,
                            BinaryOperator::Divide => a
                                .checked_div(b)
                                .ok_or(ConstantEvaluatorError::DivisionByZero)?,
                            BinaryOperator::Modulo => a
                                .checked_rem(b)
                                .ok_or(ConstantEvaluatorError::RemainderByZero)?,
                            BinaryOperator::And => a & b,
                            BinaryOperator::ExclusiveOr => a ^ b,
                            BinaryOperator::InclusiveOr => a | b,
                            // Compute `a << b` as `a * 2^b`, so that shifting out
                            // set bits shows up as multiplication overflow.
                            BinaryOperator::ShiftLeft => a
                                .checked_mul(
                                    1u32.checked_shl(b)
                                        .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                                )
                                .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
                            BinaryOperator::ShiftRight => a
                                .checked_shr(b)
                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        // Float arithmetic is plain IEEE arithmetic with no checks
                        // here; the resulting literal is still passed to
                        // `check_literal_value` when it is registered.
                        (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
                            BinaryOperator::Add => a + b,
                            BinaryOperator::Subtract => a - b,
                            BinaryOperator::Multiply => a * b,
                            BinaryOperator::Divide => a / b,
                            BinaryOperator::Modulo => a % b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
                            Literal::AbstractInt(match op {
                                BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
                                    ConstantEvaluatorError::Overflow("addition".into())
                                })?,
                                BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
                                    ConstantEvaluatorError::Overflow("subtraction".into())
                                })?,
                                BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
                                    ConstantEvaluatorError::Overflow("multiplication".into())
                                })?,
                                BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
                                    if b == 0 {
                                        ConstantEvaluatorError::DivisionByZero
                                    } else {
                                        ConstantEvaluatorError::Overflow("division".into())
                                    }
                                })?,
                                BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
                                    if b == 0 {
                                        ConstantEvaluatorError::RemainderByZero
                                    } else {
                                        ConstantEvaluatorError::Overflow("remainder".into())
                                    }
                                })?,
                                BinaryOperator::And => a & b,
                                BinaryOperator::ExclusiveOr => a ^ b,
                                BinaryOperator::InclusiveOr => a | b,
                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                            })
                        }
                        (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
                            Literal::AbstractFloat(match op {
                                BinaryOperator::Add => a + b,
                                BinaryOperator::Subtract => a - b,
                                BinaryOperator::Multiply => a * b,
                                BinaryOperator::Divide => a / b,
                                BinaryOperator::Modulo => a % b,
                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                            })
                        }
                        (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
                            BinaryOperator::LogicalAnd => a && b,
                            BinaryOperator::LogicalOr => a || b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                    },
                };
                Expression::Literal(literal)
            }
            // NOTE(review): in the two mixed `Compose`/`Literal` arms the result
            // keeps the `Compose`'s `ty` even for comparison operators, whose
            // components are `bool`s; presumably callers never reach this with a
            // comparison — confirm.
            (
                &Expression::Compose {
                    components: ref src_components,
                    ty,
                },
                &Expression::Literal(_),
            ) => {
                let mut components = src_components.clone();
                for component in &mut components {
                    *component = self.binary_op(op, *component, right, span)?;
                }
                Expression::Compose { ty, components }
            }
            (
                &Expression::Literal(_),
                &Expression::Compose {
                    components: ref src_components,
                    ty,
                },
            ) => {
                let mut components = src_components.clone();
                for component in &mut components {
                    *component = self.binary_op(op, left, *component, span)?;
                }
                Expression::Compose { ty, components }
            }
            (
                &Expression::Compose {
                    components: ref left_components,
                    ty: left_ty,
                },
                &Expression::Compose {
                    components: ref right_components,
                    ty: right_ty,
                },
            ) => {
                // We have to make a copy of the component lists, because the
                // call to `binary_op_vector` needs `&mut self`, but `self` owns
                // the component lists.
                let left_flattened = crate::proc::flatten_compose(
                    left_ty,
                    left_components,
                    self.expressions,
                    self.types,
                );
                let right_flattened = crate::proc::flatten_compose(
                    right_ty,
                    right_components,
                    self.expressions,
                    self.types,
                );

                // `flatten_compose` doesn't return an `ExactSizeIterator`, so
                // make a reasonable guess of the capacity we'll need.
                let mut flattened = Vec::with_capacity(left_components.len());
                flattened.extend(left_flattened.zip(right_flattened));

                // Only same-sized vector pairs are folded; every other pair of
                // composites (e.g. matrices) is rejected.
                match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
                    (
                        &TypeInner::Vector {
                            size: left_size, ..
                        },
                        &TypeInner::Vector {
                            size: right_size, ..
                        },
                    ) if left_size == right_size => {
                        self.binary_op_vector(op, left_size, &flattened, left_ty, span)?
                    }
                    _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                }
            }
            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
        };

        self.register_evaluated_expr(expr, span)
    }
2011
2012    fn binary_op_vector(
2013        &mut self,
2014        op: BinaryOperator,
2015        size: crate::VectorSize,
2016        components: &[(Handle<Expression>, Handle<Expression>)],
2017        left_ty: Handle<Type>,
2018        span: Span,
2019    ) -> Result<Expression, ConstantEvaluatorError> {
2020        let ty = match op {
2021            // Relational operators produce vectors of booleans.
2022            BinaryOperator::Equal
2023            | BinaryOperator::NotEqual
2024            | BinaryOperator::Less
2025            | BinaryOperator::LessEqual
2026            | BinaryOperator::Greater
2027            | BinaryOperator::GreaterEqual => self.types.insert(
2028                Type {
2029                    name: None,
2030                    inner: TypeInner::Vector {
2031                        size,
2032                        scalar: crate::Scalar::BOOL,
2033                    },
2034                },
2035                span,
2036            ),
2037
2038            // Other operators produce the same type as their left
2039            // operand.
2040            BinaryOperator::Add
2041            | BinaryOperator::Subtract
2042            | BinaryOperator::Multiply
2043            | BinaryOperator::Divide
2044            | BinaryOperator::Modulo
2045            | BinaryOperator::And
2046            | BinaryOperator::ExclusiveOr
2047            | BinaryOperator::InclusiveOr
2048            | BinaryOperator::LogicalAnd
2049            | BinaryOperator::LogicalOr
2050            | BinaryOperator::ShiftLeft
2051            | BinaryOperator::ShiftRight => left_ty,
2052        };
2053
2054        let components = components
2055            .iter()
2056            .map(|&(left, right)| self.binary_op(op, left, right, span))
2057            .collect::<Result<Vec<_>, _>>()?;
2058
2059        Ok(Expression::Compose { ty, components })
2060    }
2061
2062    /// Deep copy `expr` from `expressions` into `self.expressions`.
2063    ///
2064    /// Return the root of the new copy.
2065    ///
2066    /// This is used when we're evaluating expressions in a function's
2067    /// expression arena that refer to a constant: we need to copy the
2068    /// constant's value into the function's arena so we can operate on it.
2069    fn copy_from(
2070        &mut self,
2071        expr: Handle<Expression>,
2072        expressions: &Arena<Expression>,
2073    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2074        let span = expressions.get_span(expr);
2075        match expressions[expr] {
2076            ref expr @ (Expression::Literal(_)
2077            | Expression::Constant(_)
2078            | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
2079            Expression::Compose { ty, ref components } => {
2080                let mut components = components.clone();
2081                for component in &mut components {
2082                    *component = self.copy_from(*component, expressions)?;
2083                }
2084                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
2085            }
2086            Expression::Splat { size, value } => {
2087                let value = self.copy_from(value, expressions)?;
2088                self.register_evaluated_expr(Expression::Splat { size, value }, span)
2089            }
2090            _ => {
2091                log::debug!("copy_from: SubexpressionsAreNotConstant");
2092                Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
2093            }
2094        }
2095    }
2096
2097    fn register_evaluated_expr(
2098        &mut self,
2099        expr: Expression,
2100        span: Span,
2101    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2102        // It suffices to only check literals, since we only register one
2103        // expression at a time, `Compose` expressions can only refer to other
2104        // expressions, and `ZeroValue` expressions are always okay.
2105        if let Expression::Literal(literal) = expr {
2106            crate::valid::check_literal_value(literal)?;
2107        }
2108
2109        Ok(self.append_expr(expr, span, ExpressionKind::Const))
2110    }
2111
2112    fn append_expr(
2113        &mut self,
2114        expr: Expression,
2115        span: Span,
2116        expr_type: ExpressionKind,
2117    ) -> Handle<Expression> {
2118        let h = match self.behavior {
2119            Behavior::Wgsl(
2120                WgslRestrictions::Runtime(ref mut function_local_data)
2121                | WgslRestrictions::Const(Some(ref mut function_local_data)),
2122            )
2123            | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
2124                let is_running = function_local_data.emitter.is_running();
2125                let needs_pre_emit = expr.needs_pre_emit();
2126                if is_running && needs_pre_emit {
2127                    function_local_data
2128                        .block
2129                        .extend(function_local_data.emitter.finish(self.expressions));
2130                    let h = self.expressions.append(expr, span);
2131                    function_local_data.emitter.start(self.expressions);
2132                    h
2133                } else {
2134                    self.expressions.append(expr, span)
2135                }
2136            }
2137            _ => self.expressions.append(expr, span),
2138        };
2139        self.expression_kind_tracker.insert(h, expr_type);
2140        h
2141    }
2142
2143    fn resolve_type(
2144        &self,
2145        expr: Handle<Expression>,
2146    ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
2147        use crate::proc::TypeResolution as Tr;
2148        use crate::Expression as Ex;
2149        let resolution = match self.expressions[expr] {
2150            Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
2151            Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
2152            Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
2153            Ex::Splat { size, value } => {
2154                let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
2155                    return Err(ConstantEvaluatorError::SplatScalarOnly);
2156                };
2157                Tr::Value(TypeInner::Vector { scalar, size })
2158            }
2159            _ => {
2160                log::debug!("resolve_type: SubexpressionsAreNotConstant");
2161                return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
2162            }
2163        };
2164
2165        Ok(resolution)
2166    }
2167}
2168
2169fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
2170    // NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, a value
2171    // of 1 means the least significant bit is set. Therefore, an input of `0x[80 00…]` would
2172    // return a right-to-left bit index of 0.
2173    let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
2174        match e {
2175            idx @ 0..=31 => idx,
2176            32 => u32::MAX,
2177            _ => unreachable!(),
2178        }
2179    };
2180    match concrete_int {
2181        ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
2182        ConcreteInt::I32([e]) => {
2183            ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
2184        }
2185    }
2186}
2187
#[test]
fn first_trailing_bit_smoke() {
    // Signed inputs as `(input, expected)`: zero has no set bit and yields -1;
    // otherwise the result is the index of the lowest set bit.
    let signed_cases: [(i32, i32); 6] = [
        (0, -1),
        (1, 0),
        (2, 1),
        (-1, 0),
        (i32::MIN, 31),
        (i32::MAX, 0),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected]),
        );
    }
    for bit in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([1 << bit])),
            ConcreteInt::I32([bit])
        )
    }

    // Unsigned inputs as `(input, expected)`: zero yields `u32::MAX`.
    let unsigned_cases: [(u32, u32); 5] = [
        (0, u32::MAX),
        (1, 0),
        (2, 1),
        (1 << 31, 31),
        (u32::MAX, 0),
    ];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected]),
        );
    }
    for bit in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([1 << bit])),
            ConcreteInt::U32([bit])
        )
    }
}
2248
2249fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
2250    // NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, 1 means
2251    // the least significant bit is set. Therefore, an input of 1 would return a right-to-left bit
2252    // index of 0.
2253    let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
2254        match e {
2255            idx @ 0..=31 => 31 - idx,
2256            32 => u32::MAX,
2257            _ => unreachable!(),
2258        }
2259    };
2260    match concrete_int {
2261        ConcreteInt::I32([e]) => ConcreteInt::I32([{
2262            let rtl_bit_index = if e.is_negative() {
2263                e.leading_ones()
2264            } else {
2265                e.leading_zeros()
2266            };
2267            rtl_to_ltr_bit_idx(rtl_bit_index) as i32
2268        }]),
2269        ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
2270    }
2271}
2272
#[test]
fn first_leading_bit_smoke() {
    // Signed inputs as `(input, expected)`: the result is the index of the highest
    // bit that differs from the sign bit, or -1 when every bit matches it.
    let signed_cases: [(i32, i32); 7] = [
        (-1, -1),
        (0, -1),
        (1, 0),
        (-2, 0),
        (1234 + 4567, 12),
        (i32::MAX, 30),
        (i32::MIN, 30),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // NOTE: Ignore the sign bit, which is a separate (above) case.
    for bit in 0..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([1 << bit])),
            ConcreteInt::I32([bit])
        );
    }
    for bit in 1..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([-(1 << bit)])),
            ConcreteInt::I32([bit - 1])
        );
    }

    // Unsigned inputs as `(input, expected)`: zero yields `u32::MAX`.
    let unsigned_cases: [(u32, u32); 3] = [(0, u32::MAX), (1, 0), (u32::MAX, 31)];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for bit in 0..32 {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([1 << bit])),
            ConcreteInt::U32([bit])
        )
    }
}
2336
2337/// Trait for conversions of abstract values to concrete types.
2338trait TryFromAbstract<T>: Sized {
2339    /// Convert an abstract literal `value` to `Self`.
2340    ///
2341    /// Since Naga's `AbstractInt` and `AbstractFloat` exist to support
2342    /// WGSL, we follow WGSL's conversion rules here:
2343    ///
2344    /// - WGSL §6.1.2. Conversion Rank says that automatic conversions
2345    ///   to integers are either lossless or an error.
2346    ///
2347    /// - WGSL §14.6.4 Floating Point Conversion says that conversions
2348    ///   to floating point in constant expressions and override
2349    ///   expressions are errors if the value is out of range for the
2350    ///   destination type, but rounding is okay.
2351    ///
2352    /// [`AbstractInt`]: crate::Literal::AbstractInt
2353    /// [`Float`]: crate::Literal::Float
2354    fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
2355}
2356
2357impl TryFromAbstract<i64> for i32 {
2358    fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
2359        i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2360            value: format!("{value:?}"),
2361            to_type: "i32",
2362        })
2363    }
2364}
2365
2366impl TryFromAbstract<i64> for u32 {
2367    fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
2368        u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2369            value: format!("{value:?}"),
2370            to_type: "u32",
2371        })
2372    }
2373}
2374
2375impl TryFromAbstract<i64> for u64 {
2376    fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
2377        u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2378            value: format!("{value:?}"),
2379            to_type: "u64",
2380        })
2381    }
2382}
2383
2384impl TryFromAbstract<i64> for i64 {
2385    fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
2386        Ok(value)
2387    }
2388}
2389
2390impl TryFromAbstract<i64> for f32 {
2391    fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
2392        let f = value as f32;
2393        // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of
2394        // `f32` is roughly ±3.4 × 10³⁸, so there's no opportunity for
2395        // overflow here.
2396        Ok(f)
2397    }
2398}
2399
2400impl TryFromAbstract<f64> for f32 {
2401    fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
2402        let f = value as f32;
2403        if f.is_infinite() {
2404            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
2405                value: format!("{value:?}"),
2406                to_type: "f32",
2407            });
2408        }
2409        Ok(f)
2410    }
2411}
2412
2413impl TryFromAbstract<i64> for f64 {
2414    fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
2415        let f = value as f64;
2416        // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of
2417        // `f64` is roughly ±1.8 × 10³⁰⁸, so there's no opportunity for
2418        // overflow here.
2419        Ok(f)
2420    }
2421}
2422
2423impl TryFromAbstract<f64> for f64 {
2424    fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
2425        Ok(value)
2426    }
2427}
2428
2429impl TryFromAbstract<f64> for i32 {
2430    fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> {
2431        Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "i32" })
2432    }
2433}
2434
2435impl TryFromAbstract<f64> for u32 {
2436    fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> {
2437        Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "u32" })
2438    }
2439}
2440
2441impl TryFromAbstract<f64> for i64 {
2442    fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> {
2443        Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "i64" })
2444    }
2445}
2446
2447impl TryFromAbstract<f64> for u64 {
2448    fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> {
2449        Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "u64" })
2450    }
2451}
2452
#[cfg(test)]
mod tests {
    use std::vec;

    use crate::{
        Arena, Constant, Expression, Literal, ScalarKind, Type, TypeInner, UnaryOperator,
        UniqueArena, VectorSize,
    };

    use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};

    /// Negating and bitwise-inverting an `i32` constant, and bitwise-inverting a
    /// `vec2<i32>` constant, fold to literals (component-wise for the vector).
    #[test]
    fn unary_op() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let scalar_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let h1 = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(8)), Default::default()),
            },
            Default::default(),
        );

        let vec_h = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec![constants[h].init, constants[h1].init],
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let expr = global_expressions.append(Expression::Constant(h), Default::default());
        let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());

        let expr2 = Expression::Unary {
            op: UnaryOperator::Negate,
            expr,
        };

        let expr3 = Expression::Unary {
            op: UnaryOperator::BitwiseNot,
            expr,
        };

        let expr4 = Expression::Unary {
            op: UnaryOperator::BitwiseNot,
            expr: expr1,
        };

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
        };

        let res1 = solver
            .try_eval_and_append(expr2, Default::default())
            .unwrap();
        let res2 = solver
            .try_eval_and_append(expr3, Default::default())
            .unwrap();
        let res3 = solver
            .try_eval_and_append(expr4, Default::default())
            .unwrap();

        assert_eq!(
            global_expressions[res1],
            Expression::Literal(Literal::I32(-4))
        );

        assert_eq!(
            global_expressions[res2],
            Expression::Literal(Literal::I32(!4))
        );

        let res3_inner = &global_expressions[res3];

        match *res3_inner {
            Expression::Compose {
                ref ty,
                ref components,
            } => {
                assert_eq!(*ty, vec_ty);
                let mut components_iter = components.iter().copied();
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::I32(!4))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::I32(!8))
                );
                assert!(components_iter.next().is_none());
            }
            _ => panic!("Expected vector"),
        }
    }

    /// Converting a non-zero `i32` constant to `bool` folds to `true`.
    #[test]
    fn cast() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let scalar_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let expr = global_expressions.append(Expression::Constant(h), Default::default());

        let root = Expression::As {
            expr,
            kind: ScalarKind::Bool,
            convert: Some(crate::BOOL_WIDTH),
        };

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
        };

        let res = solver
            .try_eval_and_append(root, Default::default())
            .unwrap();

        assert_eq!(
            global_expressions[res],
            Expression::Literal(Literal::Bool(true))
        );
    }

    /// Indexing a constant `mat2x3<f32>` yields its column as a `Compose`, and
    /// indexing that column yields a single literal.
    #[test]
    fn access() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let matrix_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Matrix {
                    columns: VectorSize::Bi,
                    rows: VectorSize::Tri,
                    scalar: crate::Scalar::F32,
                },
            },
            Default::default(),
        );

        let vec_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Tri,
                    scalar: crate::Scalar::F32,
                },
            },
            Default::default(),
        );

        // Column 0 holds 0.0, 1.0, 2.0; column 1 holds 3.0, 4.0, 5.0.
        let vec1_components: Vec<_> = (0..3)
            .map(|i| {
                global_expressions.append(
                    Expression::Literal(Literal::F32(i as f32)),
                    Default::default(),
                )
            })
            .collect();
        let vec2_components: Vec<_> = (3..6)
            .map(|i| {
                global_expressions.append(
                    Expression::Literal(Literal::F32(i as f32)),
                    Default::default(),
                )
            })
            .collect();

        let vec1 = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec1_components,
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let vec2 = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec2_components,
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: matrix_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: matrix_ty,
                        components: vec![constants[vec1].init, constants[vec2].init],
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let base = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
        };

        let root1 = Expression::AccessIndex { base, index: 1 };

        let res1 = solver
            .try_eval_and_append(root1, Default::default())
            .unwrap();

        let root2 = Expression::AccessIndex {
            base: res1,
            index: 2,
        };

        let res2 = solver
            .try_eval_and_append(root2, Default::default())
            .unwrap();

        match global_expressions[res1] {
            Expression::Compose {
                ref ty,
                ref components,
            } => {
                assert_eq!(*ty, vec_ty);
                let mut components_iter = components.iter().copied();
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(3.))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(4.))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(5.))
                );
                assert!(components_iter.next().is_none());
            }
            _ => panic!("Expected vector"),
        }

        assert_eq!(
            global_expressions[res2],
            Expression::Literal(Literal::F32(5.))
        );
    }

    /// A `Compose` whose components reference a named constant can itself be
    /// negated, yielding a `Compose` of negated literals.
    #[test]
    fn compose_of_constants() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec2_i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: i32_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
        };

        let solved_compose = solver
            .try_eval_and_append(
                Expression::Compose {
                    ty: vec2_i32_ty,
                    components: vec![h_expr, h_expr],
                },
                Default::default(),
            )
            .unwrap();
        let solved_negate = solver
            .try_eval_and_append(
                Expression::Unary {
                    op: UnaryOperator::Negate,
                    expr: solved_compose,
                },
                Default::default(),
            )
            .unwrap();

        let pass = match global_expressions[solved_negate] {
            Expression::Compose { ty, ref components } => {
                ty == vec2_i32_ty
                    && components.iter().all(|&component| {
                        let component = &global_expressions[component];
                        matches!(*component, Expression::Literal(Literal::I32(-4)))
                    })
            }
            _ => false,
        };
        assert!(pass, "unexpected evaluation result");
    }

    /// A `Splat` of a named constant can be negated, yielding a `Compose` of
    /// negated literals.
    #[test]
    fn splat_of_constant() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec2_i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: i32_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
        };

        let solved_splat = solver
            .try_eval_and_append(
                Expression::Splat {
                    size: VectorSize::Bi,
                    value: h_expr,
                },
                Default::default(),
            )
            .unwrap();
        let solved_negate = solver
            .try_eval_and_append(
                Expression::Unary {
                    op: UnaryOperator::Negate,
                    expr: solved_splat,
                },
                Default::default(),
            )
            .unwrap();

        let pass = match global_expressions[solved_negate] {
            Expression::Compose { ty, ref components } => {
                ty == vec2_i32_ty
                    && components.iter().all(|&component| {
                        let component = &global_expressions[component];
                        matches!(*component, Expression::Literal(Literal::I32(-4)))
                    })
            }
            _ => false,
        };
        assert!(pass, "unexpected evaluation result");
    }
}