naga/proc/constant_evaluator.rs

1use std::iter;
2
3use arrayvec::ArrayVec;
4
5use crate::{
6    arena::{Arena, Handle, HandleVec, UniqueArena},
7    ArraySize, BinaryOperator, Constant, Expression, Literal, Override, ScalarKind, Span, Type,
8    TypeInner, UnaryOperator,
9};
10
/// A macro that allows dollar signs (`$`) to be emitted by other macros. Useful for generating
/// `macro_rules!` items that, in turn, emit their own `macro_rules!` items.
///
/// The input is used verbatim as the rule set of a helper macro, `__with_dollar_sign`, which is
/// then invoked with a single literal `$` token. A rule such as `($d:tt) => { ... }` therefore
/// binds `$d` to `$`, and `$d` can be written wherever the generated code needs a `$` that the
/// *outer* macro must not interpret (e.g. `$d ($d x:expr),+`).
///
/// Each use redefines `__with_dollar_sign`; later definitions shadow earlier ones
/// (`macro_rules!` textual scoping).
///
/// Technique stolen directly from
/// <https://github.com/rust-lang/rust/issues/35853#issuecomment-415993963>.
macro_rules! with_dollar_sign {
    ($($body:tt)*) => {
        // Define the helper macro from the caller-supplied rules...
        macro_rules! __with_dollar_sign { $($body)* }
        // ...and immediately invoke it with `$` as its only input token.
        __with_dollar_sign!($);
    }
}
22
/// Generates a family of helpers for applying a numeric built-in component-wise.
///
/// An invocation `$ident -> $target, literals: [...], scalar_kinds: [...]` emits:
///
/// - `enum $target<const N: usize>`: one variant per `literals` entry, each holding `N` values
///   of that entry's Rust type;
/// - `impl From<$target<1>> for Expression`, turning a single value back into a
///   [`Literal`] expression;
/// - `fn $ident`, which gathers `N` operand expressions into a `$target<N>`, passes it to a
///   handler, and registers the handler's result in the evaluator's arena;
/// - `macro_rules! $ident`, a shorthand for calling `fn $ident` with one closure body that is
///   reused for every `$target` variant.
///
/// Each `literals` entry has the form `LiteralVariant => TargetVariant: rust_type`.
/// `scalar_kinds` lists the [`ScalarKind`]s whose *vectors* `fn $ident` accepts and splits into
/// components; it does not affect which scalar literals are accepted.
macro_rules! gen_component_wise_extractor {
    (
        $ident:ident -> $target:ident,
        literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
        scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
    ) => {
        /// A subset of [`Literal`]s intended to be used for implementing numeric built-ins.
        #[derive(Debug)]
        #[cfg_attr(test, derive(PartialEq))]
        enum $target<const N: usize> {
            $(
                #[doc = concat!(
                    "Maps to [`Literal::",
                    stringify!($literal),
                    "`]",
                )]
                $mapping([$ty; N]),
            )+
        }

        // A one-component value converts back into the matching literal expression.
        impl From<$target<1>> for Expression {
            fn from(value: $target<1>) -> Self {
                match value {
                    $(
                        $target::$mapping([value]) => {
                            Expression::Literal(Literal::$literal(value))
                        }
                    )+
                }
            }
        }

        #[doc = concat!(
            "Attempts to evaluate multiple `exprs` as a combined [`",
            stringify!($target),
            "`] to pass to `handler`. ",
        )]
        /// If `exprs` are vectors of the same length, `handler` is called for each corresponding
        /// component of each vector.
        ///
        /// `handler`'s output is registered as a new expression. If `exprs` are vectors of the
        /// same length, a new vector expression is registered, composed of each component emitted
        /// by `handler`.
        ///
        /// Operands of an unsupported kind, mismatched literal variants, or mismatched vector
        /// types produce [`ConstantEvaluatorError::InvalidMathArg`].
        fn $ident<const N: usize, const M: usize, F>(
            eval: &mut ConstantEvaluator<'_>,
            span: Span,
            exprs: [Handle<Expression>; N],
            mut handler: F,
        ) -> Result<Handle<Expression>, ConstantEvaluatorError>
        where
            $target<M>: Into<Expression>,
            F: FnMut($target<N>) -> Result<$target<M>, ConstantEvaluatorError> + Clone,
        {
            assert!(N > 0);
            let err = ConstantEvaluatorError::InvalidMathArg;
            let mut exprs = exprs.into_iter();

            // Evaluates an operand with `eval_zero_value_and_splat`, then borrows the resulting
            // expression. NOTE(review): presumably this rewrites `ZeroValue`/`Splat` operands
            // into `Literal`/`Compose` form so the patterns below can match them — confirm.
            macro_rules! sanitize {
                ($expr:expr) => {
                    eval.eval_zero_value_and_splat($expr, span)
                        .map(|expr| &eval.expressions[expr])
                };
            }

            // The first operand selects the path: a scalar literal, or a vector `Compose`.
            let new_expr = match sanitize!(exprs.next().unwrap())? {
                $(
                    // Scalar path: every remaining operand must be a literal of this same
                    // variant. Exactly `N` values are collected on success, so the `ArrayVec`
                    // is full and `into_inner` cannot fail.
                    &Expression::Literal(Literal::$literal(x)) => iter::once(Ok(x))
                        .chain(exprs.map(|expr| {
                            sanitize!(expr).and_then(|expr| match expr {
                                &Expression::Literal(Literal::$literal(x)) => Ok(x),
                                _ => Err(err.clone()),
                            })
                        }))
                        .collect::<Result<ArrayVec<_, N>, _>>()
                        .map(|a| a.into_inner().unwrap())
                        .map($target::$mapping)
                        .and_then(|comps| Ok(handler(comps)?.into())),
                )+
                // Vector path: the first operand must be a vector whose scalar kind is listed in
                // `scalar_kinds`; every other operand must be a `Compose` whose type is equal to
                // the first one's.
                &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
                    &TypeInner::Vector { size, scalar } => match scalar.kind {
                        $(ScalarKind::$scalar_kind)|* => {
                            let first_ty = ty;
                            // One flattened component list per operand.
                            let mut component_groups =
                                ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
                            component_groups.push(crate::proc::flatten_compose(
                                first_ty,
                                components,
                                eval.expressions,
                                eval.types,
                            ).collect());
                            component_groups.extend(
                                exprs
                                    .map(|expr| {
                                        sanitize!(expr).and_then(|expr| match expr {
                                            &Expression::Compose { ty, ref components }
                                                if &eval.types[ty].inner
                                                    == &eval.types[first_ty].inner =>
                                            {
                                                Ok(crate::proc::flatten_compose(
                                                    ty,
                                                    components,
                                                    eval.expressions,
                                                    eval.types,
                                                ).collect())
                                            }
                                            _ => Err(err.clone()),
                                        })
                                    })
                                    .collect::<Result<ArrayVec<_, { crate::VectorSize::MAX }>, _>>(
                                    )?,
                            );
                            let component_groups = component_groups.into_inner().unwrap();
                            // Transpose: for each component index, gather that component from
                            // every operand and recurse on the group. A missing component is an
                            // `InvalidMathArg` error.
                            let mut new_components =
                                ArrayVec::<_, { crate::VectorSize::MAX }>::new();
                            for idx in 0..(size as u8).into() {
                                let group = component_groups
                                    .iter()
                                    .map(|cs| cs.get(idx).cloned().ok_or(err.clone()))
                                    .collect::<Result<ArrayVec<_, N>, _>>()?
                                    .into_inner()
                                    .unwrap();
                                new_components.push($ident(
                                    eval,
                                    span,
                                    group,
                                    handler.clone(),
                                )?);
                            }
                            // Re-assemble the per-component results with the first operand's type.
                            Ok(Expression::Compose {
                                ty: first_ty,
                                components: new_components.into_iter().collect(),
                            })
                        }
                        _ => return Err(err),
                    },
                    _ => return Err(err),
                },
                _ => return Err(err),
            }?;
            eval.register_evaluated_expr(new_expr, span)
        }

        // `$d` is bound to a literal `$` by `with_dollar_sign!`, so `$d (...)` below emits
        // repetition syntax for the *generated* macro instead of being expanded here.
        // NOTE(review): `$eval` and `$span` are not escaped via `$d`; this relies on the outer
        // transcriber emitting metavariables it does not know verbatim. `$d eval` would be
        // more explicit.
        with_dollar_sign! {
            ($d:tt) => {
                #[allow(unused)]
                #[doc = concat!(
                    "A convenience macro for using the same RHS for each [`",
                    stringify!($target),
                    "`] variant in a call to [`",
                    stringify!($ident),
                    "`].",
                )]
                // Usage: `$ident!(eval, span, [e0, e1, ...], |a0, a1, ...| body)`. `body` is
                // expanded once per variant with the arguments bound to that variant's values.
                // It must produce a `Result` of an array, and the `Ok` value is wrapped back
                // into the same variant.
                macro_rules! $ident {
                    (
                        $eval:expr,
                        $span:expr,
                        [$d ($d expr:expr),+ $d (,)?],
                        |$d ($d arg:ident),+| $d tt:tt
                    ) => {
                        $ident($eval, $span, [$d ($d expr),+], |args| match args {
                            $(
                                $target::$mapping([$d ($d arg),+]) => {
                                    let res = $d tt;
                                    Result::map(res, $target::$mapping)
                                },
                            )+
                        })
                    };
                }
            };
        }
    };
}
196
// Component-wise helper families used to implement numeric built-ins.
//
// In every family, `literals` decides which scalar literals are accepted. `scalar_kinds` only
// gates which vector (`Compose`) operands are split into components. Each component must
// then still be one of the listed literal variants, or `InvalidMathArg` results.

// All listed integer and floating-point literals: abstract, 32-bit and 64-bit.
gen_component_wise_extractor! {
    component_wise_scalar -> Scalar,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        F32 => F32: f32,
        AbstractInt => AbstractInt: i64,
        U32 => U32: u32,
        I32 => I32: i32,
        U64 => U64: u64,
        I64 => I64: i64,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
        Sint,
        Uint,
        AbstractInt,
    ],
}

// Floating-point only: `f32` and abstract float. Note that the `AbstractFloat` literal maps to
// the `Float::Abstract` variant.
gen_component_wise_extractor! {
    component_wise_float -> Float,
    literals: [
        AbstractFloat => Abstract: f64,
        F32 => F32: f32,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
    ],
}

// Concrete 32-bit integers only: `u32` and `i32`. No abstract or 64-bit integers.
gen_component_wise_extractor! {
    component_wise_concrete_int -> ConcreteInt,
    literals: [
        U32 => U32: u32,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        Uint,
    ],
}

// Signed numbers: abstract and 32-bit floats and signed integers. `I64` is not included.
gen_component_wise_extractor! {
    component_wise_signed -> Signed,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        AbstractInt => AbstractInt: i64,
        F32 => F32: f32,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        AbstractInt,
        Float,
        AbstractFloat,
    ],
}
256
257#[derive(Debug)]
258enum Behavior<'a> {
259    Wgsl(WgslRestrictions<'a>),
260    Glsl(GlslRestrictions<'a>),
261}
262
263impl Behavior<'_> {
264    /// Returns `true` if the inner WGSL/GLSL restrictions are runtime restrictions.
265    const fn has_runtime_restrictions(&self) -> bool {
266        matches!(
267            self,
268            &Behavior::Wgsl(WgslRestrictions::Runtime(_))
269                | &Behavior::Glsl(GlslRestrictions::Runtime(_))
270        )
271    }
272}
273
/// A context for evaluating constant expressions.
///
/// A `ConstantEvaluator` points at an expression arena to which it can append
/// newly evaluated expressions: you pass [`try_eval_and_append`] whatever kind
/// of Naga [`Expression`] you like, and if its value can be computed at compile
/// time, `try_eval_and_append` appends an expression representing the computed
/// value - a tree of [`Literal`], [`Compose`], [`ZeroValue`], and [`Swizzle`]
/// expressions - to the arena. See the [`try_eval_and_append`] method for details.
///
/// A `ConstantEvaluator` also holds whatever information we need to carry out
/// that evaluation: types, other constants, and so on.
///
/// The constructor decides two things. The first is which arena is targeted:
/// the module's global expression arena ([`for_wgsl_module`],
/// [`for_glsl_module`]) or a function's expression arena ([`for_wgsl_function`],
/// [`for_glsl_function`]). The second is which kinds of expression are accepted.
///
/// [`try_eval_and_append`]: ConstantEvaluator::try_eval_and_append
/// [`for_wgsl_module`]: ConstantEvaluator::for_wgsl_module
/// [`for_glsl_module`]: ConstantEvaluator::for_glsl_module
/// [`for_wgsl_function`]: ConstantEvaluator::for_wgsl_function
/// [`for_glsl_function`]: ConstantEvaluator::for_glsl_function
/// [`Compose`]: Expression::Compose
/// [`ZeroValue`]: Expression::ZeroValue
/// [`Literal`]: Expression::Literal
/// [`Swizzle`]: Expression::Swizzle
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
    /// Which language's evaluation rules we should follow, and what the current
    /// context permits.
    behavior: Behavior<'a>,

    /// The module's type arena.
    ///
    /// Because expressions like [`Splat`] contain type handles, we need to be
    /// able to add new types to produce those expressions.
    ///
    /// [`Splat`]: Expression::Splat
    types: &'a mut UniqueArena<Type>,

    /// The module's constant arena.
    ///
    /// `Expression::Constant` handles are resolved through this arena to the
    /// constant's initializer.
    constants: &'a Arena<Constant>,

    /// The module's override arena.
    overrides: &'a Arena<Override>,

    /// The arena to which we are contributing expressions.
    ///
    /// This is either the module's global expression arena or a function's
    /// expression arena, depending on how the evaluator was constructed.
    expressions: &'a mut Arena<Expression>,

    /// Tracks the constness of expressions residing in [`Self::expressions`]
    expression_kind_tracker: &'a mut ExpressionKindTracker,
}
316
/// What a [`ConstantEvaluator`] following WGSL rules accepts, and which arena
/// it contributes to.
#[derive(Debug)]
enum WgslRestrictions<'a> {
    /// A const context.
    ///
    /// `None` means the module's global expression arena. `Some` means a
    /// function's expression arena, as created by
    /// `ConstantEvaluator::for_wgsl_function` with `is_const` set.
    ///
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions are rejected (`ConstantEvaluatorError::OverrideExpr`)
    /// - runtime-expressions are rejected (`ConstantEvaluatorError::RuntimeExpr`)
    Const(Option<FunctionLocalData<'a>>),
    /// An override context in the module's global expression arena.
    ///
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    /// - runtime-expressions are rejected (`ConstantEvaluatorError::RuntimeExpr`)
    Override,
    /// A function's expression arena outside of a const context.
    ///
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    /// - runtime-expressions will be inserted in the arena
    Runtime(FunctionLocalData<'a>),
}
329
/// What a [`ConstantEvaluator`] following GLSL rules accepts, and which arena
/// it contributes to.
///
/// GLSL has no override-expressions. `ConstantEvaluator::try_eval_and_append`
/// treats encountering one under GLSL rules as unreachable (it panics).
#[derive(Debug)]
enum GlslRestrictions<'a> {
    /// The module's global expression arena.
    ///
    /// - const-expressions will be evaluated and inserted in the arena
    /// - runtime-expressions are rejected (`ConstantEvaluatorError::RuntimeExpr`)
    Const,
    /// A function's expression arena.
    ///
    /// - const-expressions will be evaluated and inserted in the arena
    /// - runtime-expressions will be inserted in the arena
    Runtime(FunctionLocalData<'a>),
}
339
/// State carried when the evaluator contributes to a function's expression
/// arena rather than the module's global one.
#[derive(Debug)]
struct FunctionLocalData<'a> {
    /// The module's global expression arena.
    ///
    /// Constant initializers live here and are deep-copied into the function's
    /// arena when a constant is referenced.
    global_expressions: &'a Arena<Expression>,
    /// NOTE(review): presumably used to emit appended runtime expressions into
    /// `block` — confirm against `append_expr`.
    emitter: &'a mut super::Emitter,
    /// The function block currently being built (see the note on `emitter`).
    block: &'a mut crate::Block,
}
347
/// How "constant" an expression is, from most to least constant.
///
/// The derived [`Ord`] is load-bearing. [`ExpressionKindTracker`] computes an
/// expression's kind as the maximum of its operands' kinds and a floor set by
/// the operation. Do not reorder the variants.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
    /// A const-expression that naga's constant evaluator implements.
    ///
    /// Evaluation errors for these are always reported to the caller.
    ImplConst,
    /// A const-expression whose evaluation naga may not implement.
    ///
    /// When runtime expressions are allowed, this fallback applies: if
    /// evaluation fails with `NotImplemented` or `InvalidBinaryOpArgs`, the
    /// expression is appended as a runtime expression instead.
    Const,
    /// An override-expression.
    Override,
    /// An expression that can only be evaluated at runtime.
    Runtime,
}
356
357#[derive(Debug)]
358pub struct ExpressionKindTracker {
359    inner: HandleVec<Expression, ExpressionKind>,
360}
361
362impl ExpressionKindTracker {
363    pub const fn new() -> Self {
364        Self {
365            inner: HandleVec::new(),
366        }
367    }
368
369    /// Forces the the expression to not be const
370    pub fn force_non_const(&mut self, value: Handle<Expression>) {
371        self.inner[value] = ExpressionKind::Runtime;
372    }
373
374    pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
375        self.inner.insert(value, expr_type);
376    }
377
378    pub fn is_const(&self, h: Handle<Expression>) -> bool {
379        matches!(
380            self.type_of(h),
381            ExpressionKind::Const | ExpressionKind::ImplConst
382        )
383    }
384
385    /// Returns `true` if naga can also evaluate expression as const
386    pub fn is_impl_const(&self, h: Handle<Expression>) -> bool {
387        matches!(self.type_of(h), ExpressionKind::ImplConst)
388    }
389
390    pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
391        matches!(
392            self.type_of(h),
393            ExpressionKind::Const | ExpressionKind::Override | ExpressionKind::ImplConst
394        )
395    }
396
397    fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
398        self.inner[value]
399    }
400
401    pub fn from_arena(arena: &Arena<Expression>) -> Self {
402        let mut tracker = Self {
403            inner: HandleVec::with_capacity(arena.len()),
404        };
405        for (handle, expr) in arena.iter() {
406            tracker
407                .inner
408                .insert(handle, tracker.type_of_with_expr(expr));
409        }
410        tracker
411    }
412
413    fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
414        use crate::MathFunction as Mf;
415        match *expr {
416            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
417                ExpressionKind::ImplConst
418            }
419            Expression::Override(_) => ExpressionKind::Override,
420            Expression::Compose { ref components, .. } => {
421                let mut expr_type = ExpressionKind::ImplConst;
422                for component in components {
423                    expr_type = expr_type.max(self.type_of(*component))
424                }
425                expr_type
426            }
427            Expression::Splat { value, .. } => self.type_of(value),
428            Expression::AccessIndex { base, .. } => self.type_of(base),
429            Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
430            Expression::Swizzle { vector, .. } => self.type_of(vector),
431            Expression::Unary { expr, .. } => self.type_of(expr),
432            Expression::Binary { left, right, .. } => self
433                .type_of(left)
434                .max(self.type_of(right))
435                .max(ExpressionKind::Const),
436            Expression::Math {
437                fun,
438                arg,
439                arg1,
440                arg2,
441                arg3,
442            } => self
443                .type_of(arg)
444                .max(
445                    arg1.map(|arg| self.type_of(arg))
446                        .unwrap_or(ExpressionKind::Const),
447                )
448                .max(
449                    arg2.map(|arg| self.type_of(arg))
450                        .unwrap_or(ExpressionKind::Const),
451                )
452                .max(
453                    arg3.map(|arg| self.type_of(arg))
454                        .unwrap_or(ExpressionKind::Const),
455                )
456                .max(
457                    if matches!(
458                        fun,
459                        Mf::Dot
460                            | Mf::Outer
461                            | Mf::Cross
462                            | Mf::Distance
463                            | Mf::Length
464                            | Mf::Normalize
465                            | Mf::FaceForward
466                            | Mf::Reflect
467                            | Mf::Refract
468                            | Mf::Ldexp
469                            | Mf::Modf
470                            | Mf::Mix
471                            | Mf::Frexp
472                    ) {
473                        ExpressionKind::Const
474                    } else {
475                        ExpressionKind::ImplConst
476                    },
477                ),
478            Expression::As { convert, expr, .. } => self.type_of(expr).max(if convert.is_some() {
479                ExpressionKind::ImplConst
480            } else {
481                ExpressionKind::Const
482            }),
483            Expression::Select {
484                condition,
485                accept,
486                reject,
487            } => self
488                .type_of(condition)
489                .max(self.type_of(accept))
490                .max(self.type_of(reject))
491                .max(ExpressionKind::Const),
492            Expression::Relational { argument, .. } => self.type_of(argument),
493            Expression::ArrayLength(expr) => self.type_of(expr),
494            _ => ExpressionKind::Runtime,
495        }
496    }
497}
498
/// Errors reported by [`ConstantEvaluator`] when an expression cannot be
/// evaluated, or is not allowed, in the current context.
///
/// Each variant's `#[error(...)]` string is its `Display` text.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantEvaluatorError {
    #[error("Constants cannot access function arguments")]
    FunctionArg,
    #[error("Constants cannot access global variables")]
    GlobalVariable,
    #[error("Constants cannot access local variables")]
    LocalVariable,
    #[error("Cannot get the array length of a non array type")]
    InvalidArrayLengthArg,
    #[error("Constants cannot get the array length of a dynamically sized array")]
    ArrayLengthDynamic,
    #[error("Cannot call arrayLength on array sized by override-expression")]
    ArrayLengthOverridden,
    #[error("Constants cannot call functions")]
    Call,
    #[error("Constants don't support workGroupUniformLoad")]
    WorkGroupUniformLoadResult,
    #[error("Constants don't support atomic functions")]
    Atomic,
    #[error("Constants don't support derivative functions")]
    Derivative,
    #[error("Constants don't support load expressions")]
    Load,
    #[error("Constants don't support image expressions")]
    ImageExpression,
    #[error("Constants don't support ray query expressions")]
    RayQueryExpression,
    #[error("Constants don't support subgroup expressions")]
    SubgroupExpression,
    #[error("Cannot access the type")]
    InvalidAccessBase,
    #[error("Cannot access at the index")]
    InvalidAccessIndex,
    #[error("Cannot access with index of type")]
    InvalidAccessIndexTy,
    #[error("Constants don't support array length expressions")]
    ArrayLength,
    #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
    InvalidCastArg { from: String, to: String },
    #[error("Cannot apply the unary op to the argument")]
    InvalidUnaryOpArg,
    /// When runtime expressions are allowed, a `Const`-kind expression that
    /// fails with this error is appended as a runtime expression instead.
    #[error("Cannot apply the binary op to the arguments")]
    InvalidBinaryOpArgs,
    /// Also produced by the `component_wise_*` helpers when operands do not
    /// have the expected literal kind or vector shape.
    #[error("Cannot apply math function to type")]
    InvalidMathArg,
    /// Fields: the function, the expected argument count, and the supplied
    /// argument count.
    #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
    InvalidMathArgCount(crate::MathFunction, usize, usize),
    #[error("value of `low` is greater than `high` for clamp built-in function")]
    InvalidClamp,
    #[error("Splat is defined only on scalar values")]
    SplatScalarOnly,
    #[error("Can only swizzle vector constants")]
    SwizzleVectorOnly,
    #[error("swizzle component not present in source expression")]
    SwizzleOutOfBounds,
    #[error("Type is not constructible")]
    TypeNotConstructible,
    /// An operand is not recorded as a const-expression in the evaluator's
    /// expression kind tracker.
    #[error("Subexpression(s) are not constant")]
    SubexpressionsAreNotConstant,
    /// When runtime expressions are allowed, a `Const`-kind expression that
    /// fails with this error is appended as a runtime expression instead.
    #[error("Not implemented as constant expression: {0}")]
    NotImplemented(String),
    #[error("{0} operation overflowed")]
    Overflow(String),
    #[error(
        "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
    )]
    AutomaticConversionLossy {
        value: String,
        to_type: &'static str,
    },
    #[error("abstract floating-point values cannot be automatically converted to integers")]
    AutomaticConversionFloatToInt { to_type: &'static str },
    #[error("Division by zero")]
    DivisionByZero,
    #[error("Remainder by zero")]
    RemainderByZero,
    /// NOTE(review): the message names a 32-bit limit. If shifts of 64-bit
    /// integers also report this variant, the text may be inaccurate — confirm.
    #[error("RHS of shift operation is greater than or equal to 32")]
    ShiftedMoreThan32Bits,
    /// Wraps a `crate::valid::LiteralError`; `#[from]` lets `?` convert it.
    #[error(transparent)]
    Literal(#[from] crate::valid::LiteralError),
    /// An `Expression::Override` was asked to be evaluated.
    #[error("Can't use pipeline-overridable constants in const-expressions")]
    Override,
    /// A runtime-expression was appended where runtime expressions are not
    /// allowed.
    #[error("Unexpected runtime-expression")]
    RuntimeExpr,
    /// An override-expression was appended in a WGSL const context.
    #[error("Unexpected override-expression")]
    OverrideExpr,
}
588
589impl<'a> ConstantEvaluator<'a> {
590    /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
591    /// constant expression arena.
592    ///
593    /// Report errors according to WGSL's rules for constant evaluation.
594    pub fn for_wgsl_module(
595        module: &'a mut crate::Module,
596        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
597        in_override_ctx: bool,
598    ) -> Self {
599        Self::for_module(
600            Behavior::Wgsl(if in_override_ctx {
601                WgslRestrictions::Override
602            } else {
603                WgslRestrictions::Const(None)
604            }),
605            module,
606            global_expression_kind_tracker,
607        )
608    }
609
610    /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
611    /// constant expression arena.
612    ///
613    /// Report errors according to GLSL's rules for constant evaluation.
614    pub fn for_glsl_module(
615        module: &'a mut crate::Module,
616        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
617    ) -> Self {
618        Self::for_module(
619            Behavior::Glsl(GlslRestrictions::Const),
620            module,
621            global_expression_kind_tracker,
622        )
623    }
624
625    fn for_module(
626        behavior: Behavior<'a>,
627        module: &'a mut crate::Module,
628        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
629    ) -> Self {
630        Self {
631            behavior,
632            types: &mut module.types,
633            constants: &module.constants,
634            overrides: &module.overrides,
635            expressions: &mut module.global_expressions,
636            expression_kind_tracker: global_expression_kind_tracker,
637        }
638    }
639
640    /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
641    /// expression arena.
642    ///
643    /// Report errors according to WGSL's rules for constant evaluation.
644    pub fn for_wgsl_function(
645        module: &'a mut crate::Module,
646        expressions: &'a mut Arena<Expression>,
647        local_expression_kind_tracker: &'a mut ExpressionKindTracker,
648        emitter: &'a mut super::Emitter,
649        block: &'a mut crate::Block,
650        is_const: bool,
651    ) -> Self {
652        let local_data = FunctionLocalData {
653            global_expressions: &module.global_expressions,
654            emitter,
655            block,
656        };
657        Self {
658            behavior: Behavior::Wgsl(if is_const {
659                WgslRestrictions::Const(Some(local_data))
660            } else {
661                WgslRestrictions::Runtime(local_data)
662            }),
663            types: &mut module.types,
664            constants: &module.constants,
665            overrides: &module.overrides,
666            expressions,
667            expression_kind_tracker: local_expression_kind_tracker,
668        }
669    }
670
671    /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
672    /// expression arena.
673    ///
674    /// Report errors according to GLSL's rules for constant evaluation.
675    pub fn for_glsl_function(
676        module: &'a mut crate::Module,
677        expressions: &'a mut Arena<Expression>,
678        local_expression_kind_tracker: &'a mut ExpressionKindTracker,
679        emitter: &'a mut super::Emitter,
680        block: &'a mut crate::Block,
681    ) -> Self {
682        Self {
683            behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
684                global_expressions: &module.global_expressions,
685                emitter,
686                block,
687            })),
688            types: &mut module.types,
689            constants: &module.constants,
690            overrides: &module.overrides,
691            expressions,
692            expression_kind_tracker: local_expression_kind_tracker,
693        }
694    }
695
696    pub fn to_ctx(&self) -> crate::proc::GlobalCtx {
697        crate::proc::GlobalCtx {
698            types: self.types,
699            constants: self.constants,
700            overrides: self.overrides,
701            global_expressions: match self.function_local_data() {
702                Some(data) => data.global_expressions,
703                None => self.expressions,
704            },
705        }
706    }
707
708    fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
709        if !self.expression_kind_tracker.is_const(expr) {
710            log::debug!("check: SubexpressionsAreNotConstant");
711            return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
712        }
713        Ok(())
714    }
715
716    fn check_and_get(
717        &mut self,
718        expr: Handle<Expression>,
719    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
720        match self.expressions[expr] {
721            Expression::Constant(c) => {
722                // Are we working in a function's expression arena, or the
723                // module's constant expression arena?
724                if let Some(function_local_data) = self.function_local_data() {
725                    // Deep-copy the constant's value into our arena.
726                    self.copy_from(
727                        self.constants[c].init,
728                        function_local_data.global_expressions,
729                    )
730                } else {
731                    // "See through" the constant and use its initializer.
732                    Ok(self.constants[c].init)
733                }
734            }
735            _ => {
736                self.check(expr)?;
737                Ok(expr)
738            }
739        }
740    }
741
742    /// Try to evaluate `expr` at compile time.
743    ///
744    /// The `expr` argument can be any sort of Naga [`Expression`] you like. If
745    /// we can determine its value at compile time, we append an expression
746    /// representing its value - a tree of [`Literal`], [`Compose`],
747    /// [`ZeroValue`], and [`Swizzle`] expressions - to the expression arena
748    /// `self` contributes to.
749    ///
750    /// If `expr`'s value cannot be determined at compile time, and `self` is
751    /// contributing to some function's expression arena, then append `expr` to
752    /// that arena unchanged (and thus unevaluated). Otherwise, `self` must be
753    /// contributing to the module's constant expression arena; since `expr`'s
754    /// value is not a constant, return an error.
755    ///
756    /// We only consider `expr` itself, without recursing into its operands. Its
757    /// operands must all have been produced by prior calls to
758    /// `try_eval_and_append`, to ensure that they have already been reduced to
759    /// an evaluated form if possible.
760    ///
761    /// [`Literal`]: Expression::Literal
762    /// [`Compose`]: Expression::Compose
763    /// [`ZeroValue`]: Expression::ZeroValue
764    /// [`Swizzle`]: Expression::Swizzle
765    pub fn try_eval_and_append(
766        &mut self,
767        expr: Expression,
768        span: Span,
769    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
770        match self.expression_kind_tracker.type_of_with_expr(&expr) {
771            ExpressionKind::ImplConst => self.try_eval_and_append_impl(&expr, span),
772            ExpressionKind::Const => {
773                let eval_result = self.try_eval_and_append_impl(&expr, span);
774                // We should be able to evaluate `Const` expressions at this
775                // point. If we failed to, then that probably means we just
776                // haven't implemented that part of constant evaluation. Work
777                // around this by simply emitting it as a run-time expression.
778                if self.behavior.has_runtime_restrictions()
779                    && matches!(
780                        eval_result,
781                        Err(ConstantEvaluatorError::NotImplemented(_)
782                            | ConstantEvaluatorError::InvalidBinaryOpArgs,)
783                    )
784                {
785                    Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
786                } else {
787                    eval_result
788                }
789            }
790            ExpressionKind::Override => match self.behavior {
791                Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
792                    Ok(self.append_expr(expr, span, ExpressionKind::Override))
793                }
794                Behavior::Wgsl(WgslRestrictions::Const(_)) => {
795                    Err(ConstantEvaluatorError::OverrideExpr)
796                }
797                Behavior::Glsl(_) => {
798                    unreachable!()
799                }
800            },
801            ExpressionKind::Runtime => {
802                if self.behavior.has_runtime_restrictions() {
803                    Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
804                } else {
805                    Err(ConstantEvaluatorError::RuntimeExpr)
806                }
807            }
808        }
809    }
810
811    /// Is the [`Self::expressions`] arena the global module expression arena?
812    const fn is_global_arena(&self) -> bool {
813        matches!(
814            self.behavior,
815            Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
816                | Behavior::Glsl(GlslRestrictions::Const)
817        )
818    }
819
820    const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
821        match self.behavior {
822            Behavior::Wgsl(
823                WgslRestrictions::Runtime(ref function_local_data)
824                | WgslRestrictions::Const(Some(ref function_local_data)),
825            )
826            | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
827                Some(function_local_data)
828            }
829            _ => None,
830        }
831    }
832
    /// Evaluate `expr` as far as this evaluator supports and return a handle
    /// to the result.
    ///
    /// Operand handles are first passed through `check_and_get`.
    /// - [`Literal`](Expression::Literal), [`ZeroValue`](Expression::ZeroValue)
    ///   and [`Constant`](Expression::Constant) expressions are registered as-is.
    ///   The exception is a constant in the global arena, whose initializer
    ///   handle is returned directly without appending anything.
    /// - `Compose` and `Splat` expressions are registered with their checked
    ///   operand handles.
    /// - Accesses, swizzles, unary/binary operations, math built-ins and
    ///   conversions are handed to the matching helper (`access`, `swizzle`,
    ///   `unary_op`, `binary_op`, `math`, `cast`). `array_length` is used only
    ///   for GLSL input.
    ///
    /// Expressions that cannot be evaluated here (overrides, loads, local and
    /// global variables, derivatives, calls, atomics, function arguments,
    /// image, ray-query and subgroup results) return a dedicated
    /// [`ConstantEvaluatorError`]. `select`, relational built-ins and bitcasts
    /// return [`ConstantEvaluatorError::NotImplemented`].
    fn try_eval_and_append_impl(
        &mut self,
        expr: &Expression,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        log::trace!("try_eval_and_append: {:?}", expr);
        match *expr {
            // This guarded arm must stay ahead of the unguarded `Constant(_)`
            // arm below, which handles constants outside the global arena.
            Expression::Constant(c) if self.is_global_arena() => {
                // "See through" the constant and use its initializer.
                // This is mainly done to avoid having constants pointing to other constants.
                Ok(self.constants[c].init)
            }
            Expression::Override(_) => Err(ConstantEvaluatorError::Override),
            // Already in evaluated form: register unchanged.
            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
                self.register_evaluated_expr(expr.clone(), span)
            }
            Expression::Compose { ty, ref components } => {
                let components = components
                    .iter()
                    .map(|component| self.check_and_get(*component))
                    .collect::<Result<Vec<_>, _>>()?;
                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
            }
            Expression::Splat { size, value } => {
                let value = self.check_and_get(value)?;
                self.register_evaluated_expr(Expression::Splat { size, value }, span)
            }
            Expression::AccessIndex { base, index } => {
                let base = self.check_and_get(base)?;

                self.access(base, index as usize, span)
            }
            Expression::Access { base, index } => {
                let base = self.check_and_get(base)?;
                let index = self.check_and_get(index)?;

                // The index expression must reduce to a constant index
                // (see `constant_index`) before the access can be folded.
                self.access(base, self.constant_index(index)?, span)
            }
            Expression::Swizzle {
                size,
                vector,
                pattern,
            } => {
                let vector = self.check_and_get(vector)?;

                self.swizzle(size, span, vector, pattern)
            }
            Expression::Unary { expr, op } => {
                let expr = self.check_and_get(expr)?;

                self.unary_op(op, expr, span)
            }
            Expression::Binary { left, right, op } => {
                let left = self.check_and_get(left)?;
                let right = self.check_and_get(right)?;

                self.binary_op(op, left, right, span)
            }
            Expression::Math {
                fun,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                let arg = self.check_and_get(arg)?;
                let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
                let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
                let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;

                self.math(arg, arg1, arg2, arg3, fun, span)
            }
            Expression::As {
                convert,
                expr,
                kind,
            } => {
                let expr = self.check_and_get(expr)?;

                // `convert: Some(width)` is a value conversion; `None` is a
                // bitcast, which is not folded yet.
                match convert {
                    Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
                    None => Err(ConstantEvaluatorError::NotImplemented(
                        "bitcast built-in function".into(),
                    )),
                }
            }
            Expression::Select { .. } => Err(ConstantEvaluatorError::NotImplemented(
                "select built-in function".into(),
            )),
            Expression::Relational { fun, .. } => Err(ConstantEvaluatorError::NotImplemented(
                format!("{fun:?} built-in function"),
            )),
            // Array lengths are only folded for GLSL input; for WGSL this is
            // always an error.
            Expression::ArrayLength(expr) => match self.behavior {
                Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
                Behavior::Glsl(_) => {
                    let expr = self.check_and_get(expr)?;
                    self.array_length(expr, span)
                }
            },
            // Everything below can never be evaluated at this stage; each
            // maps to its own error variant.
            Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
            Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
            Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
            Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
            Expression::WorkGroupUniformLoadResult { .. } => {
                Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
            }
            Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
            Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
            Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
            Expression::ImageSample { .. }
            | Expression::ImageLoad { .. }
            | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
            Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => {
                Err(ConstantEvaluatorError::RayQueryExpression)
            }
            Expression::SubgroupBallotResult { .. } => {
                Err(ConstantEvaluatorError::SubgroupExpression)
            }
            Expression::SubgroupOperationResult { .. } => {
                Err(ConstantEvaluatorError::SubgroupExpression)
            }
        }
    }
956
957    /// Splat `value` to `size`, without using [`Splat`] expressions.
958    ///
959    /// This constructs [`Compose`] or [`ZeroValue`] expressions to
960    /// build a vector with the given `size` whose components are all
961    /// `value`.
962    ///
963    /// Use `span` as the span of the inserted expressions and
964    /// resulting types.
965    ///
966    /// [`Splat`]: Expression::Splat
967    /// [`Compose`]: Expression::Compose
968    /// [`ZeroValue`]: Expression::ZeroValue
969    fn splat(
970        &mut self,
971        value: Handle<Expression>,
972        size: crate::VectorSize,
973        span: Span,
974    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
975        match self.expressions[value] {
976            Expression::Literal(literal) => {
977                let scalar = literal.scalar();
978                let ty = self.types.insert(
979                    Type {
980                        name: None,
981                        inner: TypeInner::Vector { size, scalar },
982                    },
983                    span,
984                );
985                let expr = Expression::Compose {
986                    ty,
987                    components: vec![value; size as usize],
988                };
989                self.register_evaluated_expr(expr, span)
990            }
991            Expression::ZeroValue(ty) => {
992                let inner = match self.types[ty].inner {
993                    TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
994                    _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
995                };
996                let res_ty = self.types.insert(Type { name: None, inner }, span);
997                let expr = Expression::ZeroValue(res_ty);
998                self.register_evaluated_expr(expr, span)
999            }
1000            _ => Err(ConstantEvaluatorError::SplatScalarOnly),
1001        }
1002    }
1003
1004    fn swizzle(
1005        &mut self,
1006        size: crate::VectorSize,
1007        span: Span,
1008        src_constant: Handle<Expression>,
1009        pattern: [crate::SwizzleComponent; 4],
1010    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1011        let mut get_dst_ty = |ty| match self.types[ty].inner {
1012            TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
1013                Type {
1014                    name: None,
1015                    inner: TypeInner::Vector { size, scalar },
1016                },
1017                span,
1018            )),
1019            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1020        };
1021
1022        match self.expressions[src_constant] {
1023            Expression::ZeroValue(ty) => {
1024                let dst_ty = get_dst_ty(ty)?;
1025                let expr = Expression::ZeroValue(dst_ty);
1026                self.register_evaluated_expr(expr, span)
1027            }
1028            Expression::Splat { value, .. } => {
1029                let expr = Expression::Splat { size, value };
1030                self.register_evaluated_expr(expr, span)
1031            }
1032            Expression::Compose { ty, ref components } => {
1033                let dst_ty = get_dst_ty(ty)?;
1034
1035                let mut flattened = [src_constant; 4]; // dummy value
1036                let len =
1037                    crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1038                        .zip(flattened.iter_mut())
1039                        .map(|(component, elt)| *elt = component)
1040                        .count();
1041                let flattened = &flattened[..len];
1042
1043                let swizzled_components = pattern[..size as usize]
1044                    .iter()
1045                    .map(|&sc| {
1046                        let sc = sc as usize;
1047                        if let Some(elt) = flattened.get(sc) {
1048                            Ok(*elt)
1049                        } else {
1050                            Err(ConstantEvaluatorError::SwizzleOutOfBounds)
1051                        }
1052                    })
1053                    .collect::<Result<Vec<Handle<Expression>>, _>>()?;
1054                let expr = Expression::Compose {
1055                    ty: dst_ty,
1056                    components: swizzled_components,
1057                };
1058                self.register_evaluated_expr(expr, span)
1059            }
1060            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1061        }
1062    }
1063
1064    fn math(
1065        &mut self,
1066        arg: Handle<Expression>,
1067        arg1: Option<Handle<Expression>>,
1068        arg2: Option<Handle<Expression>>,
1069        arg3: Option<Handle<Expression>>,
1070        fun: crate::MathFunction,
1071        span: Span,
1072    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1073        let expected = fun.argument_count();
1074        let given = Some(arg)
1075            .into_iter()
1076            .chain(arg1)
1077            .chain(arg2)
1078            .chain(arg3)
1079            .count();
1080        if expected != given {
1081            return Err(ConstantEvaluatorError::InvalidMathArgCount(
1082                fun, expected, given,
1083            ));
1084        }
1085
1086        // NOTE: We try to match the declaration order of `MathFunction` here.
1087        match fun {
1088            // comparison
1089            crate::MathFunction::Abs => {
1090                component_wise_scalar(self, span, [arg], |args| match args {
1091                    Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
1092                    Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
1093                    Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.abs()])),
1094                    Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
1095                    Scalar::U32([e]) => Ok(Scalar::U32([e])), // TODO: just re-use the expression, ezpz
1096                    Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
1097                    Scalar::U64([e]) => Ok(Scalar::U64([e])),
1098                })
1099            }
1100            crate::MathFunction::Min => {
1101                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1102                    Ok([e1.min(e2)])
1103                })
1104            }
1105            crate::MathFunction::Max => {
1106                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1107                    Ok([e1.max(e2)])
1108                })
1109            }
1110            crate::MathFunction::Clamp => {
1111                component_wise_scalar!(
1112                    self,
1113                    span,
1114                    [arg, arg1.unwrap(), arg2.unwrap()],
1115                    |e, low, high| {
1116                        if low > high {
1117                            Err(ConstantEvaluatorError::InvalidClamp)
1118                        } else {
1119                            Ok([e.clamp(low, high)])
1120                        }
1121                    }
1122                )
1123            }
1124            crate::MathFunction::Saturate => {
1125                component_wise_float!(self, span, [arg], |e| { Ok([e.clamp(0., 1.)]) })
1126            }
1127
1128            // trigonometry
1129            crate::MathFunction::Cos => {
1130                component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
1131            }
1132            crate::MathFunction::Cosh => {
1133                component_wise_float!(self, span, [arg], |e| { Ok([e.cosh()]) })
1134            }
1135            crate::MathFunction::Sin => {
1136                component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
1137            }
1138            crate::MathFunction::Sinh => {
1139                component_wise_float!(self, span, [arg], |e| { Ok([e.sinh()]) })
1140            }
1141            crate::MathFunction::Tan => {
1142                component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
1143            }
1144            crate::MathFunction::Tanh => {
1145                component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
1146            }
1147            crate::MathFunction::Acos => {
1148                component_wise_float!(self, span, [arg], |e| { Ok([e.acos()]) })
1149            }
1150            crate::MathFunction::Asin => {
1151                component_wise_float!(self, span, [arg], |e| { Ok([e.asin()]) })
1152            }
1153            crate::MathFunction::Atan => {
1154                component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
1155            }
1156            crate::MathFunction::Asinh => {
1157                component_wise_float!(self, span, [arg], |e| { Ok([e.asinh()]) })
1158            }
1159            crate::MathFunction::Acosh => {
1160                component_wise_float!(self, span, [arg], |e| { Ok([e.acosh()]) })
1161            }
1162            crate::MathFunction::Atanh => {
1163                component_wise_float!(self, span, [arg], |e| { Ok([e.atanh()]) })
1164            }
1165            crate::MathFunction::Radians => {
1166                component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
1167            }
1168            crate::MathFunction::Degrees => {
1169                component_wise_float!(self, span, [arg], |e| { Ok([e.to_degrees()]) })
1170            }
1171
1172            // decomposition
1173            crate::MathFunction::Ceil => {
1174                component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
1175            }
1176            crate::MathFunction::Floor => {
1177                component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
1178            }
1179            crate::MathFunction::Round => {
1180                // TODO: this hit stable on 1.77, but MSRV hasn't caught up yet
1181                // This polyfill is shamelessly [~~stolen from~~ inspired by `ndarray-image`][polyfill source],
1182                // which has licensing compatible with ours. See also
1183                // <https://github.com/rust-lang/rust/issues/96710>.
1184                //
1185                // [polyfill source]: https://github.com/imeka/ndarray-ndimage/blob/8b14b4d6ecfbc96a8a052f802e342a7049c68d8f/src/lib.rs#L98
1186                fn round_ties_even(x: f64) -> f64 {
1187                    let i = x as i64;
1188                    let f = (x - i as f64).abs();
1189                    if f == 0.5 {
1190                        if i & 1 == 1 {
1191                            // -1.5, 1.5, 3.5, ...
1192                            (x.abs() + 0.5).copysign(x)
1193                        } else {
1194                            (x.abs() - 0.5).copysign(x)
1195                        }
1196                    } else {
1197                        x.round()
1198                    }
1199                }
1200                component_wise_float(self, span, [arg], |e| match e {
1201                    Float::Abstract([e]) => Ok(Float::Abstract([round_ties_even(e)])),
1202                    Float::F32([e]) => Ok(Float::F32([(round_ties_even(e as f64) as f32)])),
1203                })
1204            }
1205            crate::MathFunction::Fract => {
1206                component_wise_float!(self, span, [arg], |e| {
1207                    // N.B., Rust's definition of `fract` is `e - e.trunc()`, so we can't use that
1208                    // here.
1209                    Ok([e - e.floor()])
1210                })
1211            }
1212            crate::MathFunction::Trunc => {
1213                component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
1214            }
1215
1216            // exponent
1217            crate::MathFunction::Exp => {
1218                component_wise_float!(self, span, [arg], |e| { Ok([e.exp()]) })
1219            }
1220            crate::MathFunction::Exp2 => {
1221                component_wise_float!(self, span, [arg], |e| { Ok([e.exp2()]) })
1222            }
1223            crate::MathFunction::Log => {
1224                component_wise_float!(self, span, [arg], |e| { Ok([e.ln()]) })
1225            }
1226            crate::MathFunction::Log2 => {
1227                component_wise_float!(self, span, [arg], |e| { Ok([e.log2()]) })
1228            }
1229            crate::MathFunction::Pow => {
1230                component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1231                    Ok([e1.powf(e2)])
1232                })
1233            }
1234
1235            // computational
1236            crate::MathFunction::Sign => {
1237                component_wise_signed!(self, span, [arg], |e| { Ok([e.signum()]) })
1238            }
1239            crate::MathFunction::Fma => {
1240                component_wise_float!(
1241                    self,
1242                    span,
1243                    [arg, arg1.unwrap(), arg2.unwrap()],
1244                    |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
1245                )
1246            }
1247            crate::MathFunction::Step => {
1248                component_wise_float!(self, span, [arg, arg1.unwrap()], |edge, x| {
1249                    Ok([if edge <= x { 1.0 } else { 0.0 }])
1250                })
1251            }
1252            crate::MathFunction::Sqrt => {
1253                component_wise_float!(self, span, [arg], |e| { Ok([e.sqrt()]) })
1254            }
1255            crate::MathFunction::InverseSqrt => {
1256                component_wise_float!(self, span, [arg], |e| { Ok([1. / e.sqrt()]) })
1257            }
1258
1259            // bits
1260            crate::MathFunction::CountTrailingZeros => {
1261                component_wise_concrete_int!(self, span, [arg], |e| {
1262                    #[allow(clippy::useless_conversion)]
1263                    Ok([e
1264                        .trailing_zeros()
1265                        .try_into()
1266                        .expect("bit count overflowed 32 bits, somehow!?")])
1267                })
1268            }
1269            crate::MathFunction::CountLeadingZeros => {
1270                component_wise_concrete_int!(self, span, [arg], |e| {
1271                    #[allow(clippy::useless_conversion)]
1272                    Ok([e
1273                        .leading_zeros()
1274                        .try_into()
1275                        .expect("bit count overflowed 32 bits, somehow!?")])
1276                })
1277            }
1278            crate::MathFunction::CountOneBits => {
1279                component_wise_concrete_int!(self, span, [arg], |e| {
1280                    #[allow(clippy::useless_conversion)]
1281                    Ok([e
1282                        .count_ones()
1283                        .try_into()
1284                        .expect("bit count overflowed 32 bits, somehow!?")])
1285                })
1286            }
1287            crate::MathFunction::ReverseBits => {
1288                component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
1289            }
1290            crate::MathFunction::FirstTrailingBit => {
1291                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
1292            }
1293            crate::MathFunction::FirstLeadingBit => {
1294                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
1295            }
1296
1297            fun => Err(ConstantEvaluatorError::NotImplemented(format!(
1298                "{fun:?} built-in function"
1299            ))),
1300        }
1301    }
1302
1303    fn array_length(
1304        &mut self,
1305        array: Handle<Expression>,
1306        span: Span,
1307    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1308        match self.expressions[array] {
1309            Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
1310                match self.types[ty].inner {
1311                    TypeInner::Array { size, .. } => match size {
1312                        ArraySize::Constant(len) => {
1313                            let expr = Expression::Literal(Literal::U32(len.get()));
1314                            self.register_evaluated_expr(expr, span)
1315                        }
1316                        ArraySize::Pending(_) => Err(ConstantEvaluatorError::ArrayLengthOverridden),
1317                        ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
1318                    },
1319                    _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
1320                }
1321            }
1322            _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
1323        }
1324    }
1325
1326    fn access(
1327        &mut self,
1328        base: Handle<Expression>,
1329        index: usize,
1330        span: Span,
1331    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1332        match self.expressions[base] {
1333            Expression::ZeroValue(ty) => {
1334                let ty_inner = &self.types[ty].inner;
1335                let components = ty_inner
1336                    .components()
1337                    .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
1338
1339                if index >= components as usize {
1340                    Err(ConstantEvaluatorError::InvalidAccessBase)
1341                } else {
1342                    let ty_res = ty_inner
1343                        .component_type(index)
1344                        .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
1345                    let ty = match ty_res {
1346                        crate::proc::TypeResolution::Handle(ty) => ty,
1347                        crate::proc::TypeResolution::Value(inner) => {
1348                            self.types.insert(Type { name: None, inner }, span)
1349                        }
1350                    };
1351                    self.register_evaluated_expr(Expression::ZeroValue(ty), span)
1352                }
1353            }
1354            Expression::Splat { size, value } => {
1355                if index >= size as usize {
1356                    Err(ConstantEvaluatorError::InvalidAccessBase)
1357                } else {
1358                    Ok(value)
1359                }
1360            }
1361            Expression::Compose { ty, ref components } => {
1362                let _ = self.types[ty]
1363                    .inner
1364                    .components()
1365                    .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
1366
1367                crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1368                    .nth(index)
1369                    .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
1370            }
1371            _ => Err(ConstantEvaluatorError::InvalidAccessBase),
1372        }
1373    }
1374
1375    fn constant_index(&self, expr: Handle<Expression>) -> Result<usize, ConstantEvaluatorError> {
1376        match self.expressions[expr] {
1377            Expression::ZeroValue(ty)
1378                if matches!(
1379                    self.types[ty].inner,
1380                    TypeInner::Scalar(crate::Scalar {
1381                        kind: ScalarKind::Uint,
1382                        ..
1383                    })
1384                ) =>
1385            {
1386                Ok(0)
1387            }
1388            Expression::Literal(Literal::U32(index)) => Ok(index as usize),
1389            _ => Err(ConstantEvaluatorError::InvalidAccessIndexTy),
1390        }
1391    }
1392
1393    /// Lower [`ZeroValue`] and [`Splat`] expressions to [`Literal`] and [`Compose`] expressions.
1394    ///
1395    /// [`ZeroValue`]: Expression::ZeroValue
1396    /// [`Splat`]: Expression::Splat
1397    /// [`Literal`]: Expression::Literal
1398    /// [`Compose`]: Expression::Compose
1399    fn eval_zero_value_and_splat(
1400        &mut self,
1401        expr: Handle<Expression>,
1402        span: Span,
1403    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1404        match self.expressions[expr] {
1405            Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
1406            Expression::Splat { size, value } => self.splat(value, size, span),
1407            _ => Ok(expr),
1408        }
1409    }
1410
1411    /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
1412    ///
1413    /// [`ZeroValue`]: Expression::ZeroValue
1414    /// [`Literal`]: Expression::Literal
1415    /// [`Compose`]: Expression::Compose
1416    fn eval_zero_value(
1417        &mut self,
1418        expr: Handle<Expression>,
1419        span: Span,
1420    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1421        match self.expressions[expr] {
1422            Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
1423            _ => Ok(expr),
1424        }
1425    }
1426
1427    /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
1428    ///
1429    /// [`ZeroValue`]: Expression::ZeroValue
1430    /// [`Literal`]: Expression::Literal
1431    /// [`Compose`]: Expression::Compose
1432    fn eval_zero_value_impl(
1433        &mut self,
1434        ty: Handle<Type>,
1435        span: Span,
1436    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1437        match self.types[ty].inner {
1438            TypeInner::Scalar(scalar) => {
1439                let expr = Expression::Literal(
1440                    Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
1441                );
1442                self.register_evaluated_expr(expr, span)
1443            }
1444            TypeInner::Vector { size, scalar } => {
1445                let scalar_ty = self.types.insert(
1446                    Type {
1447                        name: None,
1448                        inner: TypeInner::Scalar(scalar),
1449                    },
1450                    span,
1451                );
1452                let el = self.eval_zero_value_impl(scalar_ty, span)?;
1453                let expr = Expression::Compose {
1454                    ty,
1455                    components: vec![el; size as usize],
1456                };
1457                self.register_evaluated_expr(expr, span)
1458            }
1459            TypeInner::Matrix {
1460                columns,
1461                rows,
1462                scalar,
1463            } => {
1464                let vec_ty = self.types.insert(
1465                    Type {
1466                        name: None,
1467                        inner: TypeInner::Vector { size: rows, scalar },
1468                    },
1469                    span,
1470                );
1471                let el = self.eval_zero_value_impl(vec_ty, span)?;
1472                let expr = Expression::Compose {
1473                    ty,
1474                    components: vec![el; columns as usize],
1475                };
1476                self.register_evaluated_expr(expr, span)
1477            }
1478            TypeInner::Array {
1479                base,
1480                size: ArraySize::Constant(size),
1481                ..
1482            } => {
1483                let el = self.eval_zero_value_impl(base, span)?;
1484                let expr = Expression::Compose {
1485                    ty,
1486                    components: vec![el; size.get() as usize],
1487                };
1488                self.register_evaluated_expr(expr, span)
1489            }
1490            TypeInner::Struct { ref members, .. } => {
1491                let types: Vec<_> = members.iter().map(|m| m.ty).collect();
1492                let mut components = Vec::with_capacity(members.len());
1493                for ty in types {
1494                    components.push(self.eval_zero_value_impl(ty, span)?);
1495                }
1496                let expr = Expression::Compose { ty, components };
1497                self.register_evaluated_expr(expr, span)
1498            }
1499            _ => Err(ConstantEvaluatorError::TypeNotConstructible),
1500        }
1501    }
1502
1503    /// Convert the scalar components of `expr` to `target`.
1504    ///
1505    /// Treat `span` as the location of the resulting expression.
1506    pub fn cast(
1507        &mut self,
1508        expr: Handle<Expression>,
1509        target: crate::Scalar,
1510        span: Span,
1511    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1512        use crate::Scalar as Sc;
1513
1514        let expr = self.eval_zero_value(expr, span)?;
1515
1516        let make_error = || -> Result<_, ConstantEvaluatorError> {
1517            let from = format!("{:?} {:?}", expr, self.expressions[expr]);
1518
1519            #[cfg(feature = "wgsl-in")]
1520            let to = target.to_wgsl();
1521
1522            #[cfg(not(feature = "wgsl-in"))]
1523            let to = format!("{target:?}");
1524
1525            Err(ConstantEvaluatorError::InvalidCastArg { from, to })
1526        };
1527
1528        let expr = match self.expressions[expr] {
1529            Expression::Literal(literal) => {
1530                let literal = match target {
1531                    Sc::I32 => Literal::I32(match literal {
1532                        Literal::I32(v) => v,
1533                        Literal::U32(v) => v as i32,
1534                        Literal::F32(v) => v as i32,
1535                        Literal::Bool(v) => v as i32,
1536                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
1537                            return make_error();
1538                        }
1539                        Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
1540                        Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
1541                    }),
1542                    Sc::U32 => Literal::U32(match literal {
1543                        Literal::I32(v) => v as u32,
1544                        Literal::U32(v) => v,
1545                        Literal::F32(v) => v as u32,
1546                        Literal::Bool(v) => v as u32,
1547                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
1548                            return make_error();
1549                        }
1550                        Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
1551                        Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
1552                    }),
1553                    Sc::I64 => Literal::I64(match literal {
1554                        Literal::I32(v) => v as i64,
1555                        Literal::U32(v) => v as i64,
1556                        Literal::F32(v) => v as i64,
1557                        Literal::Bool(v) => v as i64,
1558                        Literal::F64(v) => v as i64,
1559                        Literal::I64(v) => v,
1560                        Literal::U64(v) => v as i64,
1561                        Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
1562                        Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
1563                    }),
1564                    Sc::U64 => Literal::U64(match literal {
1565                        Literal::I32(v) => v as u64,
1566                        Literal::U32(v) => v as u64,
1567                        Literal::F32(v) => v as u64,
1568                        Literal::Bool(v) => v as u64,
1569                        Literal::F64(v) => v as u64,
1570                        Literal::I64(v) => v as u64,
1571                        Literal::U64(v) => v,
1572                        Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
1573                        Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
1574                    }),
1575                    Sc::F32 => Literal::F32(match literal {
1576                        Literal::I32(v) => v as f32,
1577                        Literal::U32(v) => v as f32,
1578                        Literal::F32(v) => v,
1579                        Literal::Bool(v) => v as u32 as f32,
1580                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
1581                            return make_error();
1582                        }
1583                        Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
1584                        Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
1585                    }),
1586                    Sc::F64 => Literal::F64(match literal {
1587                        Literal::I32(v) => v as f64,
1588                        Literal::U32(v) => v as f64,
1589                        Literal::F32(v) => v as f64,
1590                        Literal::F64(v) => v,
1591                        Literal::Bool(v) => v as u32 as f64,
1592                        Literal::I64(_) | Literal::U64(_) => return make_error(),
1593                        Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
1594                        Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
1595                    }),
1596                    Sc::BOOL => Literal::Bool(match literal {
1597                        Literal::I32(v) => v != 0,
1598                        Literal::U32(v) => v != 0,
1599                        Literal::F32(v) => v != 0.0,
1600                        Literal::Bool(v) => v,
1601                        Literal::F64(_)
1602                        | Literal::I64(_)
1603                        | Literal::U64(_)
1604                        | Literal::AbstractInt(_)
1605                        | Literal::AbstractFloat(_) => {
1606                            return make_error();
1607                        }
1608                    }),
1609                    Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
1610                        Literal::AbstractInt(v) => {
1611                            // Overflow is forbidden, but inexact conversions
1612                            // are fine. The range of f64 is far larger than
1613                            // that of i64, so we don't have to check anything
1614                            // here.
1615                            v as f64
1616                        }
1617                        Literal::AbstractFloat(v) => v,
1618                        _ => return make_error(),
1619                    }),
1620                    _ => {
1621                        log::debug!("Constant evaluator refused to convert value to {target:?}");
1622                        return make_error();
1623                    }
1624                };
1625                Expression::Literal(literal)
1626            }
1627            Expression::Compose {
1628                ty,
1629                components: ref src_components,
1630            } => {
1631                let ty_inner = match self.types[ty].inner {
1632                    TypeInner::Vector { size, .. } => TypeInner::Vector {
1633                        size,
1634                        scalar: target,
1635                    },
1636                    TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
1637                        columns,
1638                        rows,
1639                        scalar: target,
1640                    },
1641                    _ => return make_error(),
1642                };
1643
1644                let mut components = src_components.clone();
1645                for component in &mut components {
1646                    *component = self.cast(*component, target, span)?;
1647                }
1648
1649                let ty = self.types.insert(
1650                    Type {
1651                        name: None,
1652                        inner: ty_inner,
1653                    },
1654                    span,
1655                );
1656
1657                Expression::Compose { ty, components }
1658            }
1659            Expression::Splat { size, value } => {
1660                let value_span = self.expressions.get_span(value);
1661                let cast_value = self.cast(value, target, value_span)?;
1662                Expression::Splat {
1663                    size,
1664                    value: cast_value,
1665                }
1666            }
1667            _ => return make_error(),
1668        };
1669
1670        self.register_evaluated_expr(expr, span)
1671    }
1672
1673    /// Convert the scalar leaves of  `expr` to `target`, handling arrays.
1674    ///
1675    /// `expr` must be a `Compose` expression whose type is a scalar, vector,
1676    /// matrix, or nested arrays of such.
1677    ///
1678    /// This is basically the same as the [`cast`] method, except that that
1679    /// should only handle Naga [`As`] expressions, which cannot convert arrays.
1680    ///
1681    /// Treat `span` as the location of the resulting expression.
1682    ///
1683    /// [`cast`]: ConstantEvaluator::cast
1684    /// [`As`]: crate::Expression::As
1685    pub fn cast_array(
1686        &mut self,
1687        expr: Handle<Expression>,
1688        target: crate::Scalar,
1689        span: Span,
1690    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1691        let Expression::Compose { ty, ref components } = self.expressions[expr] else {
1692            return self.cast(expr, target, span);
1693        };
1694
1695        let TypeInner::Array {
1696            base: _,
1697            size,
1698            stride: _,
1699        } = self.types[ty].inner
1700        else {
1701            return self.cast(expr, target, span);
1702        };
1703
1704        let mut components = components.clone();
1705        for component in &mut components {
1706            *component = self.cast_array(*component, target, span)?;
1707        }
1708
1709        let first = components.first().unwrap();
1710        let new_base = match self.resolve_type(*first)? {
1711            crate::proc::TypeResolution::Handle(ty) => ty,
1712            crate::proc::TypeResolution::Value(inner) => {
1713                self.types.insert(Type { name: None, inner }, span)
1714            }
1715        };
1716        let new_base_stride = self.types[new_base].inner.size(self.to_ctx());
1717        let new_array_ty = self.types.insert(
1718            Type {
1719                name: None,
1720                inner: TypeInner::Array {
1721                    base: new_base,
1722                    size,
1723                    stride: new_base_stride,
1724                },
1725            },
1726            span,
1727        );
1728
1729        let compose = Expression::Compose {
1730            ty: new_array_ty,
1731            components,
1732        };
1733        self.register_evaluated_expr(compose, span)
1734    }
1735
1736    fn unary_op(
1737        &mut self,
1738        op: UnaryOperator,
1739        expr: Handle<Expression>,
1740        span: Span,
1741    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1742        let expr = self.eval_zero_value_and_splat(expr, span)?;
1743
1744        let expr = match self.expressions[expr] {
1745            Expression::Literal(value) => Expression::Literal(match op {
1746                UnaryOperator::Negate => match value {
1747                    Literal::I32(v) => Literal::I32(v.wrapping_neg()),
1748                    Literal::I64(v) => Literal::I64(v.wrapping_neg()),
1749                    Literal::F32(v) => Literal::F32(-v),
1750                    Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
1751                    Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
1752                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1753                },
1754                UnaryOperator::LogicalNot => match value {
1755                    Literal::Bool(v) => Literal::Bool(!v),
1756                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1757                },
1758                UnaryOperator::BitwiseNot => match value {
1759                    Literal::I32(v) => Literal::I32(!v),
1760                    Literal::I64(v) => Literal::I64(!v),
1761                    Literal::U32(v) => Literal::U32(!v),
1762                    Literal::U64(v) => Literal::U64(!v),
1763                    Literal::AbstractInt(v) => Literal::AbstractInt(!v),
1764                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1765                },
1766            }),
1767            Expression::Compose {
1768                ty,
1769                components: ref src_components,
1770            } => {
1771                match self.types[ty].inner {
1772                    TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
1773                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1774                }
1775
1776                let mut components = src_components.clone();
1777                for component in &mut components {
1778                    *component = self.unary_op(op, *component, span)?;
1779                }
1780
1781                Expression::Compose { ty, components }
1782            }
1783            _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1784        };
1785
1786        self.register_evaluated_expr(expr, span)
1787    }
1788
1789    fn binary_op(
1790        &mut self,
1791        op: BinaryOperator,
1792        left: Handle<Expression>,
1793        right: Handle<Expression>,
1794        span: Span,
1795    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1796        let left = self.eval_zero_value_and_splat(left, span)?;
1797        let right = self.eval_zero_value_and_splat(right, span)?;
1798
1799        let expr = match (&self.expressions[left], &self.expressions[right]) {
1800            (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
1801                let literal = match op {
1802                    BinaryOperator::Equal => Literal::Bool(left_value == right_value),
1803                    BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
1804                    BinaryOperator::Less => Literal::Bool(left_value < right_value),
1805                    BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
1806                    BinaryOperator::Greater => Literal::Bool(left_value > right_value),
1807                    BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),
1808
1809                    _ => match (left_value, right_value) {
1810                        (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
1811                            BinaryOperator::Add => a.wrapping_add(b),
1812                            BinaryOperator::Subtract => a.wrapping_sub(b),
1813                            BinaryOperator::Multiply => a.wrapping_mul(b),
1814                            BinaryOperator::Divide => {
1815                                if b == 0 {
1816                                    return Err(ConstantEvaluatorError::DivisionByZero);
1817                                } else {
1818                                    a.wrapping_div(b)
1819                                }
1820                            }
1821                            BinaryOperator::Modulo => {
1822                                if b == 0 {
1823                                    return Err(ConstantEvaluatorError::RemainderByZero);
1824                                } else {
1825                                    a.wrapping_rem(b)
1826                                }
1827                            }
1828                            BinaryOperator::And => a & b,
1829                            BinaryOperator::ExclusiveOr => a ^ b,
1830                            BinaryOperator::InclusiveOr => a | b,
1831                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
1832                        }),
1833                        (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
1834                            BinaryOperator::ShiftLeft => {
1835                                if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
1836                                    return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
1837                                }
1838                                a.checked_shl(b)
1839                                    .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
1840                            }
1841                            BinaryOperator::ShiftRight => a
1842                                .checked_shr(b)
1843                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
1844                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
1845                        }),
1846                        (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
1847                            BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
1848                                ConstantEvaluatorError::Overflow("addition".into())
1849                            })?,
1850                            BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
1851                                ConstantEvaluatorError::Overflow("subtraction".into())
1852                            })?,
1853                            BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
1854                                ConstantEvaluatorError::Overflow("multiplication".into())
1855                            })?,
1856                            BinaryOperator::Divide => a
1857                                .checked_div(b)
1858                                .ok_or(ConstantEvaluatorError::DivisionByZero)?,
1859                            BinaryOperator::Modulo => a
1860                                .checked_rem(b)
1861                                .ok_or(ConstantEvaluatorError::RemainderByZero)?,
1862                            BinaryOperator::And => a & b,
1863                            BinaryOperator::ExclusiveOr => a ^ b,
1864                            BinaryOperator::InclusiveOr => a | b,
1865                            BinaryOperator::ShiftLeft => a
1866                                .checked_mul(
1867                                    1u32.checked_shl(b)
1868                                        .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
1869                                )
1870                                .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
1871                            BinaryOperator::ShiftRight => a
1872                                .checked_shr(b)
1873                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
1874                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
1875                        }),
1876                        (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
1877                            BinaryOperator::Add => a + b,
1878                            BinaryOperator::Subtract => a - b,
1879                            BinaryOperator::Multiply => a * b,
1880                            BinaryOperator::Divide => a / b,
1881                            BinaryOperator::Modulo => a % b,
1882                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
1883                        }),
1884                        (Literal::AbstractInt(a), Literal::U32(b)) => {
1885                            Literal::AbstractInt(match op {
1886                                BinaryOperator::ShiftLeft => {
1887                                    if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
1888                                        return Err(ConstantEvaluatorError::Overflow(
1889                                            "<<".to_string(),
1890                                        ));
1891                                    }
1892                                    a.checked_shl(b).unwrap_or(0)
1893                                }
1894                                BinaryOperator::ShiftRight => a.checked_shr(b).unwrap_or(0),
1895                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
1896                            })
1897                        }
1898                        (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
1899                            Literal::AbstractInt(match op {
1900                                BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
1901                                    ConstantEvaluatorError::Overflow("addition".into())
1902                                })?,
1903                                BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
1904                                    ConstantEvaluatorError::Overflow("subtraction".into())
1905                                })?,
1906                                BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
1907                                    ConstantEvaluatorError::Overflow("multiplication".into())
1908                                })?,
1909                                BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
1910                                    if b == 0 {
1911                                        ConstantEvaluatorError::DivisionByZero
1912                                    } else {
1913                                        ConstantEvaluatorError::Overflow("division".into())
1914                                    }
1915                                })?,
1916                                BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
1917                                    if b == 0 {
1918                                        ConstantEvaluatorError::RemainderByZero
1919                                    } else {
1920                                        ConstantEvaluatorError::Overflow("remainder".into())
1921                                    }
1922                                })?,
1923                                BinaryOperator::And => a & b,
1924                                BinaryOperator::ExclusiveOr => a ^ b,
1925                                BinaryOperator::InclusiveOr => a | b,
1926                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
1927                            })
1928                        }
1929                        (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
1930                            Literal::AbstractFloat(match op {
1931                                BinaryOperator::Add => a + b,
1932                                BinaryOperator::Subtract => a - b,
1933                                BinaryOperator::Multiply => a * b,
1934                                BinaryOperator::Divide => a / b,
1935                                BinaryOperator::Modulo => a % b,
1936                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
1937                            })
1938                        }
1939                        (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
1940                            BinaryOperator::LogicalAnd => a && b,
1941                            BinaryOperator::LogicalOr => a || b,
1942                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
1943                        }),
1944                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
1945                    },
1946                };
1947                Expression::Literal(literal)
1948            }
1949            (
1950                &Expression::Compose {
1951                    components: ref src_components,
1952                    ty,
1953                },
1954                &Expression::Literal(_),
1955            ) => {
1956                let mut components = src_components.clone();
1957                for component in &mut components {
1958                    *component = self.binary_op(op, *component, right, span)?;
1959                }
1960                Expression::Compose { ty, components }
1961            }
1962            (
1963                &Expression::Literal(_),
1964                &Expression::Compose {
1965                    components: ref src_components,
1966                    ty,
1967                },
1968            ) => {
1969                let mut components = src_components.clone();
1970                for component in &mut components {
1971                    *component = self.binary_op(op, left, *component, span)?;
1972                }
1973                Expression::Compose { ty, components }
1974            }
1975            (
1976                &Expression::Compose {
1977                    components: ref left_components,
1978                    ty: left_ty,
1979                },
1980                &Expression::Compose {
1981                    components: ref right_components,
1982                    ty: right_ty,
1983                },
1984            ) => {
1985                // We have to make a copy of the component lists, because the
1986                // call to `binary_op_vector` needs `&mut self`, but `self` owns
1987                // the component lists.
1988                let left_flattened = crate::proc::flatten_compose(
1989                    left_ty,
1990                    left_components,
1991                    self.expressions,
1992                    self.types,
1993                );
1994                let right_flattened = crate::proc::flatten_compose(
1995                    right_ty,
1996                    right_components,
1997                    self.expressions,
1998                    self.types,
1999                );
2000
2001                // `flatten_compose` doesn't return an `ExactSizeIterator`, so
2002                // make a reasonable guess of the capacity we'll need.
2003                let mut flattened = Vec::with_capacity(left_components.len());
2004                flattened.extend(left_flattened.zip(right_flattened));
2005
2006                match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
2007                    (
2008                        &TypeInner::Vector {
2009                            size: left_size, ..
2010                        },
2011                        &TypeInner::Vector {
2012                            size: right_size, ..
2013                        },
2014                    ) if left_size == right_size => {
2015                        self.binary_op_vector(op, left_size, &flattened, left_ty, span)?
2016                    }
2017                    _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2018                }
2019            }
2020            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2021        };
2022
2023        self.register_evaluated_expr(expr, span)
2024    }
2025
2026    fn binary_op_vector(
2027        &mut self,
2028        op: BinaryOperator,
2029        size: crate::VectorSize,
2030        components: &[(Handle<Expression>, Handle<Expression>)],
2031        left_ty: Handle<Type>,
2032        span: Span,
2033    ) -> Result<Expression, ConstantEvaluatorError> {
2034        let ty = match op {
2035            // Relational operators produce vectors of booleans.
2036            BinaryOperator::Equal
2037            | BinaryOperator::NotEqual
2038            | BinaryOperator::Less
2039            | BinaryOperator::LessEqual
2040            | BinaryOperator::Greater
2041            | BinaryOperator::GreaterEqual => self.types.insert(
2042                Type {
2043                    name: None,
2044                    inner: TypeInner::Vector {
2045                        size,
2046                        scalar: crate::Scalar::BOOL,
2047                    },
2048                },
2049                span,
2050            ),
2051
2052            // Other operators produce the same type as their left
2053            // operand.
2054            BinaryOperator::Add
2055            | BinaryOperator::Subtract
2056            | BinaryOperator::Multiply
2057            | BinaryOperator::Divide
2058            | BinaryOperator::Modulo
2059            | BinaryOperator::And
2060            | BinaryOperator::ExclusiveOr
2061            | BinaryOperator::InclusiveOr
2062            | BinaryOperator::LogicalAnd
2063            | BinaryOperator::LogicalOr
2064            | BinaryOperator::ShiftLeft
2065            | BinaryOperator::ShiftRight => left_ty,
2066        };
2067
2068        let components = components
2069            .iter()
2070            .map(|&(left, right)| self.binary_op(op, left, right, span))
2071            .collect::<Result<Vec<_>, _>>()?;
2072
2073        Ok(Expression::Compose { ty, components })
2074    }
2075
2076    /// Deep copy `expr` from `expressions` into `self.expressions`.
2077    ///
2078    /// Return the root of the new copy.
2079    ///
2080    /// This is used when we're evaluating expressions in a function's
2081    /// expression arena that refer to a constant: we need to copy the
2082    /// constant's value into the function's arena so we can operate on it.
2083    fn copy_from(
2084        &mut self,
2085        expr: Handle<Expression>,
2086        expressions: &Arena<Expression>,
2087    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2088        let span = expressions.get_span(expr);
2089        match expressions[expr] {
2090            ref expr @ (Expression::Literal(_)
2091            | Expression::Constant(_)
2092            | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
2093            Expression::Compose { ty, ref components } => {
2094                let mut components = components.clone();
2095                for component in &mut components {
2096                    *component = self.copy_from(*component, expressions)?;
2097                }
2098                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
2099            }
2100            Expression::Splat { size, value } => {
2101                let value = self.copy_from(value, expressions)?;
2102                self.register_evaluated_expr(Expression::Splat { size, value }, span)
2103            }
2104            _ => {
2105                log::debug!("copy_from: SubexpressionsAreNotConstant");
2106                Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
2107            }
2108        }
2109    }
2110
2111    fn register_evaluated_expr(
2112        &mut self,
2113        expr: Expression,
2114        span: Span,
2115    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2116        // It suffices to only check literals, since we only register one
2117        // expression at a time, `Compose` expressions can only refer to other
2118        // expressions, and `ZeroValue` expressions are always okay.
2119        if let Expression::Literal(literal) = expr {
2120            crate::valid::check_literal_value(literal)?;
2121        }
2122
2123        Ok(self.append_expr(expr, span, ExpressionKind::Const))
2124    }
2125
2126    fn append_expr(
2127        &mut self,
2128        expr: Expression,
2129        span: Span,
2130        expr_type: ExpressionKind,
2131    ) -> Handle<Expression> {
2132        let h = match self.behavior {
2133            Behavior::Wgsl(
2134                WgslRestrictions::Runtime(ref mut function_local_data)
2135                | WgslRestrictions::Const(Some(ref mut function_local_data)),
2136            )
2137            | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
2138                let is_running = function_local_data.emitter.is_running();
2139                let needs_pre_emit = expr.needs_pre_emit();
2140                if is_running && needs_pre_emit {
2141                    function_local_data
2142                        .block
2143                        .extend(function_local_data.emitter.finish(self.expressions));
2144                    let h = self.expressions.append(expr, span);
2145                    function_local_data.emitter.start(self.expressions);
2146                    h
2147                } else {
2148                    self.expressions.append(expr, span)
2149                }
2150            }
2151            _ => self.expressions.append(expr, span),
2152        };
2153        self.expression_kind_tracker.insert(h, expr_type);
2154        h
2155    }
2156
2157    fn resolve_type(
2158        &self,
2159        expr: Handle<Expression>,
2160    ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
2161        use crate::proc::TypeResolution as Tr;
2162        use crate::Expression as Ex;
2163        let resolution = match self.expressions[expr] {
2164            Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
2165            Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
2166            Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
2167            Ex::Splat { size, value } => {
2168                let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
2169                    return Err(ConstantEvaluatorError::SplatScalarOnly);
2170                };
2171                Tr::Value(TypeInner::Vector { scalar, size })
2172            }
2173            _ => {
2174                log::debug!("resolve_type: SubexpressionsAreNotConstant");
2175                return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
2176            }
2177        };
2178
2179        Ok(resolution)
2180    }
2181}
2182
2183fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
2184    // NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, a value
2185    // of 1 means the least significant bit is set. Therefore, an input of `0x[80 00…]` would
2186    // return a right-to-left bit index of 0.
2187    let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
2188        match e {
2189            idx @ 0..=31 => idx,
2190            32 => u32::MAX,
2191            _ => unreachable!(),
2192        }
2193    };
2194    match concrete_int {
2195        ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
2196        ConcreteInt::I32([e]) => {
2197            ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
2198        }
2199    }
2200}
2201
/// Table-driven checks of `first_trailing_bit` for both signed and unsigned inputs.
#[test]
fn first_trailing_bit_smoke() {
    let signed_cases = [
        (0, -1),
        (1, 0),
        (2, 1),
        (-1, 0),
        (i32::MIN, 31),
        (i32::MAX, 0),
    ];
    for &(input, expected) in &signed_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected]),
        );
    }
    // Every single-bit value reports exactly that bit, including the sign bit.
    for bit in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([1 << bit])),
            ConcreteInt::I32([bit])
        );
    }

    let unsigned_cases = [
        (0, u32::MAX),
        (1, 0),
        (2, 1),
        (1 << 31, 31),
        (u32::MAX, 0),
    ];
    for &(input, expected) in &unsigned_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected]),
        );
    }
    for bit in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([1 << bit])),
            ConcreteInt::U32([bit])
        );
    }
}
2262
2263fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
2264    // NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, 1 means
2265    // the least significant bit is set. Therefore, an input of 1 would return a right-to-left bit
2266    // index of 0.
2267    let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
2268        match e {
2269            idx @ 0..=31 => 31 - idx,
2270            32 => u32::MAX,
2271            _ => unreachable!(),
2272        }
2273    };
2274    match concrete_int {
2275        ConcreteInt::I32([e]) => ConcreteInt::I32([{
2276            let rtl_bit_index = if e.is_negative() {
2277                e.leading_ones()
2278            } else {
2279                e.leading_zeros()
2280            };
2281            rtl_to_ltr_bit_idx(rtl_bit_index) as i32
2282        }]),
2283        ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
2284    }
2285}
2286
/// Table-driven checks of `first_leading_bit` for both signed and unsigned inputs.
#[test]
fn first_leading_bit_smoke() {
    let signed_cases = [
        (-1, -1),
        (0, -1),
        (1, 0),
        (-2, 0),
        (1234 + 4567, 12),
        (i32::MAX, 30),
        (i32::MIN, 30),
    ];
    for &(input, expected) in &signed_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // NOTE: The sign bit is skipped here; `i32::MIN` is covered by the table above.
    for bit in 0..31 {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([1 << bit])),
            ConcreteInt::I32([bit])
        );
    }
    // For `-(2^bit)`, the highest bit differing from the (set) sign bit is `bit - 1`.
    for bit in 1..31 {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([-(1 << bit)])),
            ConcreteInt::I32([bit - 1])
        );
    }

    let unsigned_cases = [(0, u32::MAX), (1, 0), (u32::MAX, 31)];
    for &(input, expected) in &unsigned_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for bit in 0..32 {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([1 << bit])),
            ConcreteInt::U32([bit])
        );
    }
}
2350
2351/// Trait for conversions of abstract values to concrete types.
2352trait TryFromAbstract<T>: Sized {
2353    /// Convert an abstract literal `value` to `Self`.
2354    ///
2355    /// Since Naga's `AbstractInt` and `AbstractFloat` exist to support
2356    /// WGSL, we follow WGSL's conversion rules here:
2357    ///
2358    /// - WGSL §6.1.2. Conversion Rank says that automatic conversions
2359    ///   to integers are either lossless or an error.
2360    ///
2361    /// - WGSL §14.6.4 Floating Point Conversion says that conversions
2362    ///   to floating point in constant expressions and override
2363    ///   expressions are errors if the value is out of range for the
2364    ///   destination type, but rounding is okay.
2365    ///
2366    /// [`AbstractInt`]: crate::Literal::AbstractInt
2367    /// [`Float`]: crate::Literal::Float
2368    fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
2369}
2370
2371impl TryFromAbstract<i64> for i32 {
2372    fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
2373        i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2374            value: format!("{value:?}"),
2375            to_type: "i32",
2376        })
2377    }
2378}
2379
2380impl TryFromAbstract<i64> for u32 {
2381    fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
2382        u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2383            value: format!("{value:?}"),
2384            to_type: "u32",
2385        })
2386    }
2387}
2388
2389impl TryFromAbstract<i64> for u64 {
2390    fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
2391        u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2392            value: format!("{value:?}"),
2393            to_type: "u64",
2394        })
2395    }
2396}
2397
2398impl TryFromAbstract<i64> for i64 {
2399    fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
2400        Ok(value)
2401    }
2402}
2403
2404impl TryFromAbstract<i64> for f32 {
2405    fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
2406        let f = value as f32;
2407        // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of
2408        // `f32` is roughly ±3.4 × 10³⁸, so there's no opportunity for
2409        // overflow here.
2410        Ok(f)
2411    }
2412}
2413
2414impl TryFromAbstract<f64> for f32 {
2415    fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
2416        let f = value as f32;
2417        if f.is_infinite() {
2418            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
2419                value: format!("{value:?}"),
2420                to_type: "f32",
2421            });
2422        }
2423        Ok(f)
2424    }
2425}
2426
2427impl TryFromAbstract<i64> for f64 {
2428    fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
2429        let f = value as f64;
2430        // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of
2431        // `f64` is roughly ±1.8 × 10³⁰⁸, so there's no opportunity for
2432        // overflow here.
2433        Ok(f)
2434    }
2435}
2436
2437impl TryFromAbstract<f64> for f64 {
2438    fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
2439        Ok(value)
2440    }
2441}
2442
2443impl TryFromAbstract<f64> for i32 {
2444    fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> {
2445        Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "i32" })
2446    }
2447}
2448
2449impl TryFromAbstract<f64> for u32 {
2450    fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> {
2451        Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "u32" })
2452    }
2453}
2454
2455impl TryFromAbstract<f64> for i64 {
2456    fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> {
2457        Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "i64" })
2458    }
2459}
2460
2461impl TryFromAbstract<f64> for u64 {
2462    fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> {
2463        Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "u64" })
2464    }
2465}
2466
#[cfg(test)]
mod tests {
    use std::vec;

    use crate::{
        Arena, Constant, Expression, Literal, ScalarKind, Type, TypeInner, UnaryOperator,
        UniqueArena, VectorSize,
    };

    use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};

    /// Negating / bitwise-NOTing an `i32` constant folds to a literal, and bitwise-NOTing a
    /// `vec2<i32>` constant folds to a `Compose` of literals.
    #[test]
    fn unary_op() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let scalar_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let h1 = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(8)), Default::default()),
            },
            Default::default(),
        );

        let vec_h = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec![constants[h].init, constants[h1].init],
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let expr = global_expressions.append(Expression::Constant(h), Default::default());
        let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());

        let expr2 = Expression::Unary {
            op: UnaryOperator::Negate,
            expr,
        };

        let expr3 = Expression::Unary {
            op: UnaryOperator::BitwiseNot,
            expr,
        };

        let expr4 = Expression::Unary {
            op: UnaryOperator::BitwiseNot,
            expr: expr1,
        };

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
        };

        let res1 = solver
            .try_eval_and_append(expr2, Default::default())
            .unwrap();
        let res2 = solver
            .try_eval_and_append(expr3, Default::default())
            .unwrap();
        let res3 = solver
            .try_eval_and_append(expr4, Default::default())
            .unwrap();

        assert_eq!(
            global_expressions[res1],
            Expression::Literal(Literal::I32(-4))
        );

        assert_eq!(
            global_expressions[res2],
            Expression::Literal(Literal::I32(!4))
        );

        let res3_inner = &global_expressions[res3];

        match *res3_inner {
            Expression::Compose {
                ref ty,
                ref components,
            } => {
                assert_eq!(*ty, vec_ty);
                let mut components_iter = components.iter().copied();
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::I32(!4))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::I32(!8))
                );
                assert!(components_iter.next().is_none());
            }
            _ => panic!("Expected vector"),
        }
    }

    /// Casting a non-zero `i32` constant to `bool` folds to `true`.
    #[test]
    fn cast() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let scalar_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let expr = global_expressions.append(Expression::Constant(h), Default::default());

        let root = Expression::As {
            expr,
            kind: ScalarKind::Bool,
            convert: Some(crate::BOOL_WIDTH),
        };

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
        };

        let res = solver
            .try_eval_and_append(root, Default::default())
            .unwrap();

        assert_eq!(
            global_expressions[res],
            Expression::Literal(Literal::Bool(true))
        );
    }

    /// Indexing a constant `mat2x3<f32>` yields its column as a `Compose`, and indexing that
    /// column yields a scalar literal.
    #[test]
    fn access() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let matrix_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Matrix {
                    columns: VectorSize::Bi,
                    rows: VectorSize::Tri,
                    scalar: crate::Scalar::F32,
                },
            },
            Default::default(),
        );

        let vec_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Tri,
                    scalar: crate::Scalar::F32,
                },
            },
            Default::default(),
        );

        let mut vec1_components = Vec::with_capacity(3);
        let mut vec2_components = Vec::with_capacity(3);

        for i in 0..3 {
            let h = global_expressions.append(
                Expression::Literal(Literal::F32(i as f32)),
                Default::default(),
            );

            vec1_components.push(h)
        }

        for i in 3..6 {
            let h = global_expressions.append(
                Expression::Literal(Literal::F32(i as f32)),
                Default::default(),
            );

            vec2_components.push(h)
        }

        let vec1 = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec1_components,
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let vec2 = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec2_components,
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: matrix_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: matrix_ty,
                        components: vec![constants[vec1].init, constants[vec2].init],
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let base = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
        };

        let root1 = Expression::AccessIndex { base, index: 1 };

        let res1 = solver
            .try_eval_and_append(root1, Default::default())
            .unwrap();

        let root2 = Expression::AccessIndex {
            base: res1,
            index: 2,
        };

        let res2 = solver
            .try_eval_and_append(root2, Default::default())
            .unwrap();

        match global_expressions[res1] {
            Expression::Compose {
                ref ty,
                ref components,
            } => {
                assert_eq!(*ty, vec_ty);
                let mut components_iter = components.iter().copied();
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(3.))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(4.))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(5.))
                );
                assert!(components_iter.next().is_none());
            }
            _ => panic!("Expected vector"),
        }

        assert_eq!(
            global_expressions[res2],
            Expression::Literal(Literal::F32(5.))
        );
    }

    /// A `Compose` built from constants can itself be evaluated further: negating
    /// `vec2<i32>(c, c)` with `c = 4` yields a `Compose` of `-4` literals.
    #[test]
    fn compose_of_constants() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec2_i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: i32_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
        };

        let solved_compose = solver
            .try_eval_and_append(
                Expression::Compose {
                    ty: vec2_i32_ty,
                    components: vec![h_expr, h_expr],
                },
                Default::default(),
            )
            .unwrap();
        let solved_negate = solver
            .try_eval_and_append(
                Expression::Unary {
                    op: UnaryOperator::Negate,
                    expr: solved_compose,
                },
                Default::default(),
            )
            .unwrap();

        let pass = match global_expressions[solved_negate] {
            Expression::Compose { ty, ref components } => {
                ty == vec2_i32_ty
                    && components.iter().all(|&component| {
                        let component = &global_expressions[component];
                        matches!(*component, Expression::Literal(Literal::I32(-4)))
                    })
            }
            _ => false,
        };
        assert!(pass, "unexpected evaluation result");
    }

    /// A `Splat` of a constant can itself be evaluated further: negating `vec2(c)` with
    /// `c = 4` yields a `Compose` of `-4` literals.
    #[test]
    fn splat_of_constant() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec2_i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: i32_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
        };

        let solved_compose = solver
            .try_eval_and_append(
                Expression::Splat {
                    size: VectorSize::Bi,
                    value: h_expr,
                },
                Default::default(),
            )
            .unwrap();
        let solved_negate = solver
            .try_eval_and_append(
                Expression::Unary {
                    op: UnaryOperator::Negate,
                    expr: solved_compose,
                },
                Default::default(),
            )
            .unwrap();

        let pass = match global_expressions[solved_negate] {
            Expression::Compose { ty, ref components } => {
                ty == vec2_i32_ty
                    && components.iter().all(|&component| {
                        let component = &global_expressions[component];
                        matches!(*component, Expression::Literal(Literal::I32(-4)))
                    })
            }
            _ => false,
        };
        assert!(pass, "unexpected evaluation result");
    }
}