1use std::iter;
2
3use arrayvec::ArrayVec;
4
5use crate::{
6 arena::{Arena, Handle, HandleVec, UniqueArena},
7 ArraySize, BinaryOperator, Constant, Expression, Literal, Override, ScalarKind, Span, Type,
8 TypeInner, UnaryOperator,
9};
10
/// Expands `$body` as the arms of a throwaway helper macro and immediately
/// invokes that helper with a literal `$` token.
///
/// A `macro_rules!` transcriber cannot directly write `$name:frag` matchers
/// for a *nested* `macro_rules!`. Callers instead write `($d:tt) => { ... }`
/// and use `$d` wherever the nested macro needs a `$`, as
/// `gen_component_wise_extractor!` does below.
macro_rules! with_dollar_sign {
    ($($body:tt)*) => {
        // Define the helper from the caller's arms, then call it with `$` so
        // the caller's `$d:tt` binds to a dollar token.
        macro_rules! __with_dollar_sign { $($body)* }
        __with_dollar_sign!($);
    }
}
22
/// Generates the machinery for evaluating a math built-in component-wise over
/// a fixed family of literal kinds.
///
/// An invocation
/// `gen_component_wise_extractor! { $ident -> $target, literals: [...], scalar_kinds: [...] }`
/// produces:
///
/// - `enum $target<const N: usize>`, with one `$mapping([$ty; N])` variant per
///   listed `Literal::$literal`;
/// - `impl From<$target<1>> for Expression`, turning a single value back into
///   an `Expression::Literal`;
/// - `fn $ident(eval, span, exprs, handler)`, which resolves the `N` argument
///   expressions to same-kind literals (or to same-typed vector `Compose`s,
///   recursing per component) and hands them to `handler`;
/// - a same-named convenience `macro_rules! $ident` that applies one closure
///   body to every `$target` variant.
///
/// `scalar_kinds` lists the vector scalar kinds accepted in the `Compose` case.
macro_rules! gen_component_wise_extractor {
    (
        $ident:ident -> $target:ident,
        literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
        scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
    ) => {
        // Values extracted from `N` arguments of a single literal kind.
        #[derive(Debug)]
        #[cfg_attr(test, derive(PartialEq))]
        enum $target<const N: usize> {
            $(
                #[doc = concat!(
                    "Maps to [`Literal::",
                    stringify!($literal),
                    "`]",
                )]
                $mapping([$ty; N]),
            )+
        }

        // A single extracted value converts back into a literal expression.
        impl From<$target<1>> for Expression {
            fn from(value: $target<1>) -> Self {
                match value {
                    $(
                        $target::$mapping([value]) => {
                            Expression::Literal(Literal::$literal(value))
                        }
                    )+
                }
            }
        }

        #[doc = concat!(
            "Attempts to evaluate multiple `exprs` as a combined [`",
            stringify!($target),
            "`] to pass to `handler`. ",
        )]
        fn $ident<const N: usize, const M: usize, F>(
            eval: &mut ConstantEvaluator<'_>,
            span: Span,
            exprs: [Handle<Expression>; N],
            mut handler: F,
        ) -> Result<Handle<Expression>, ConstantEvaluatorError>
        where
            $target<M>: Into<Expression>,
            F: FnMut($target<N>) -> Result<$target<M>, ConstantEvaluatorError> + Clone,
        {
            // The first argument decides which literal kind (or vector type)
            // every other argument must match.
            assert!(N > 0);
            let err = ConstantEvaluatorError::InvalidMathArg;
            let mut exprs = exprs.into_iter();

            // Runs an argument through `eval_zero_value_and_splat` (defined
            // elsewhere; presumably rewrites `ZeroValue`/`Splat` into literal
            // or `Compose` form — confirm) and borrows the resulting expression.
            macro_rules! sanitize {
                ($expr:expr) => {
                    eval.eval_zero_value_and_splat($expr, span)
                        .map(|expr| &eval.expressions[expr])
                };
            }

            let new_expr = match sanitize!(exprs.next().unwrap())? {
                $(
                    // Scalar case: every argument must be a literal of this
                    // same kind; gather exactly `N` of them for `handler`.
                    &Expression::Literal(Literal::$literal(x)) => iter::once(Ok(x))
                        .chain(exprs.map(|expr| {
                            sanitize!(expr).and_then(|expr| match expr {
                                &Expression::Literal(Literal::$literal(x)) => Ok(x),
                                _ => Err(err.clone()),
                            })
                        }))
                        .collect::<Result<ArrayVec<_, N>, _>>()
                        .map(|a| a.into_inner().unwrap())
                        .map($target::$mapping)
                        .and_then(|comps| Ok(handler(comps)?.into())),
                )+
                // Vector case: every argument must be a `Compose` with the same
                // type as the first; `handler` is applied per component by
                // recursing on the i-th component of every argument.
                &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
                    &TypeInner::Vector { size, scalar } => match scalar.kind {
                        $(ScalarKind::$scalar_kind)|* => {
                            let first_ty = ty;
                            let mut component_groups =
                                ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
                            component_groups.push(crate::proc::flatten_compose(
                                first_ty,
                                components,
                                eval.expressions,
                                eval.types,
                            ).collect());
                            // NOTE(review): the remaining `N - 1` groups are
                            // collected into an `ArrayVec` of capacity
                            // `VectorSize::MAX`, which would panic if `N - 1`
                            // ever exceeded that capacity.
                            component_groups.extend(
                                exprs
                                    .map(|expr| {
                                        sanitize!(expr).and_then(|expr| match expr {
                                            &Expression::Compose { ty, ref components }
                                                if &eval.types[ty].inner
                                                    == &eval.types[first_ty].inner =>
                                            {
                                                Ok(crate::proc::flatten_compose(
                                                    ty,
                                                    components,
                                                    eval.expressions,
                                                    eval.types,
                                                ).collect())
                                            }
                                            _ => Err(err.clone()),
                                        })
                                    })
                                    .collect::<Result<ArrayVec<_, { crate::VectorSize::MAX }>, _>>(
                                    )?,
                            );
                            let component_groups = component_groups.into_inner().unwrap();
                            let mut new_components =
                                ArrayVec::<_, { crate::VectorSize::MAX }>::new();
                            for idx in 0..(size as u8).into() {
                                // Transpose: the idx-th component of every argument.
                                let group = component_groups
                                    .iter()
                                    .map(|cs| cs[idx])
                                    .collect::<ArrayVec<_, N>>()
                                    .into_inner()
                                    .unwrap();
                                new_components.push($ident(
                                    eval,
                                    span,
                                    group,
                                    handler.clone(),
                                )?);
                            }
                            // NOTE(review): the result reuses the first
                            // argument's vector type, which assumes `handler`
                            // keeps the scalar kind it was given.
                            Ok(Expression::Compose {
                                ty: first_ty,
                                components: new_components.into_iter().collect(),
                            })
                        }
                        _ => return Err(err),
                    },
                    _ => return Err(err),
                },
                _ => return Err(err),
            }?;
            eval.register_evaluated_expr(new_expr, span)
        }

        // Defines `$ident!(eval, span, [exprs...], |args...| body)`, which
        // runs `body` for whichever `$target` variant the arguments resolve to
        // and re-wraps its `Ok([...])` array in that same variant.
        with_dollar_sign! {
            ($d:tt) => {
                #[allow(unused)]
                #[doc = concat!(
                    "A convenience macro for using the same RHS for each [`",
                    stringify!($target),
                    "`] variant in a call to [`",
                    stringify!($ident),
                    "`].",
                )]
                macro_rules! $ident {
                    (
                        $eval:expr,
                        $span:expr,
                        [$d ($d expr:expr),+ $d (,)?],
                        |$d ($d arg:ident),+| $d tt:tt
                    ) => {
                        $ident($eval, $span, [$d ($d expr),+], |args| match args {
                            $(
                                $target::$mapping([$d ($d arg),+]) => {
                                    let res = $d tt;
                                    Result::map(res, $target::$mapping)
                                },
                            )+
                        })
                    };
                }
            };
        }
    };
}
196
// Any numeric scalar: used by `Abs`, `Min`, `Max` and `Clamp`.
gen_component_wise_extractor! {
    component_wise_scalar -> Scalar,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        F32 => F32: f32,
        AbstractInt => AbstractInt: i64,
        U32 => U32: u32,
        I32 => I32: i32,
        U64 => U64: u64,
        I64 => I64: i64,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
        Sint,
        Uint,
        AbstractInt,
    ],
}

// Floating-point scalars only: used by the trigonometric, exponential and
// rounding built-ins.
gen_component_wise_extractor! {
    component_wise_float -> Float,
    literals: [
        AbstractFloat => Abstract: f64,
        F32 => F32: f32,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
    ],
}

// 32-bit concrete integers only: used by the bit-manipulation built-ins.
gen_component_wise_extractor! {
    component_wise_concrete_int -> ConcreteInt,
    literals: [
        U32 => U32: u32,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        Uint,
    ],
}

// Signed scalars (integer or float): used by `Sign`.
gen_component_wise_extractor! {
    component_wise_signed -> Signed,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        AbstractInt => AbstractInt: i64,
        F32 => F32: f32,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        AbstractInt,
        Float,
        AbstractFloat,
    ],
}
256
/// Which language's evaluation rules apply, together with the restrictions of
/// the current evaluation context.
#[derive(Debug)]
enum Behavior<'a> {
    /// WGSL rules.
    Wgsl(WgslRestrictions<'a>),
    /// GLSL rules.
    Glsl(GlslRestrictions<'a>),
}
262
263impl Behavior<'_> {
264 const fn has_runtime_restrictions(&self) -> bool {
266 matches!(
267 self,
268 &Behavior::Wgsl(WgslRestrictions::Runtime(_))
269 | &Behavior::Glsl(GlslRestrictions::Runtime(_))
270 )
271 }
272}
273
/// Evaluates expressions at compile time where possible, appending the
/// results to an expression arena.
///
/// It works either on a module's global expression arena (`for_wgsl_module`,
/// `for_glsl_module`) or on a function's local expression arena
/// (`for_wgsl_function`, `for_glsl_function`).
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
    /// Language rules and context restrictions in effect.
    behavior: Behavior<'a>,

    /// The module's type arena; vector types needed by results (e.g. in
    /// `splat` and `swizzle`) are interned here.
    types: &'a mut UniqueArena<Type>,

    /// The module's constants, used to look through `Expression::Constant`.
    constants: &'a Arena<Constant>,

    /// The module's pipeline-overridable constants.
    overrides: &'a Arena<Override>,

    /// The arena that results are appended to: the module's global
    /// expressions, or a function's local expressions.
    expressions: &'a mut Arena<Expression>,

    /// The [`ExpressionKind`] of each expression in `expressions`.
    expression_kind_tracker: &'a mut ExpressionKindTracker,
}
316
/// Restrictions of the current context when evaluating WGSL.
#[derive(Debug)]
enum WgslRestrictions<'a> {
    /// A const-expression context: `Some` inside a function body, `None` for
    /// the module's global expressions.
    Const(Option<FunctionLocalData<'a>>),
    /// A module-scope context in which override-expressions are accepted
    /// (and appended unevaluated).
    Override,
    /// A function body in which runtime expressions are accepted.
    Runtime(FunctionLocalData<'a>),
}
329
/// Restrictions of the current context when evaluating GLSL.
#[derive(Debug)]
enum GlslRestrictions<'a> {
    /// Constant expressions in the module's global expression arena.
    Const,
    /// A function body in which runtime expressions are accepted.
    Runtime(FunctionLocalData<'a>),
}
339
/// Extra state needed when evaluating into a function's local expression arena.
#[derive(Debug)]
struct FunctionLocalData<'a> {
    /// The module's global expression arena; constant initializers are copied
    /// from here into the local arena (see `check_and_get`).
    global_expressions: &'a Arena<Expression>,
    /// NOTE(review): presumably used when appending runtime expressions to
    /// emit them into `block` — confirm against `append_expr`.
    emitter: &'a mut super::Emitter,
    /// NOTE(review): presumably the function block that emitted statements go
    /// into — confirm against `append_expr`.
    block: &'a mut crate::Block,
}
347
/// How "constant" an expression is.
///
/// Variants are declared from most to least constant. The derived `Ord`
/// follows that order, and `ExpressionKindTracker` combines operand kinds with
/// `max`, so an expression is only as constant as its least constant operand.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
    /// A constant expression that this evaluator is expected to evaluate
    /// itself. `try_eval_and_append` propagates its evaluation errors instead
    /// of falling back to a runtime expression.
    ImplConst,
    /// A constant expression that this evaluator may not implement.
    Const,
    /// Depends on pipeline-overridable constants.
    Override,
    /// Only computable at runtime.
    Runtime,
}
356
/// Records the [`ExpressionKind`] of every expression in an arena, indexed by
/// expression handle.
#[derive(Debug)]
pub struct ExpressionKindTracker {
    inner: HandleVec<Expression, ExpressionKind>,
}
361
362impl ExpressionKindTracker {
363 pub const fn new() -> Self {
364 Self {
365 inner: HandleVec::new(),
366 }
367 }
368
369 pub fn force_non_const(&mut self, value: Handle<Expression>) {
371 self.inner[value] = ExpressionKind::Runtime;
372 }
373
374 pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
375 self.inner.insert(value, expr_type);
376 }
377
378 pub fn is_const(&self, h: Handle<Expression>) -> bool {
379 matches!(
380 self.type_of(h),
381 ExpressionKind::Const | ExpressionKind::ImplConst
382 )
383 }
384
385 pub fn is_impl_const(&self, h: Handle<Expression>) -> bool {
387 matches!(self.type_of(h), ExpressionKind::ImplConst)
388 }
389
390 pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
391 matches!(
392 self.type_of(h),
393 ExpressionKind::Const | ExpressionKind::Override | ExpressionKind::ImplConst
394 )
395 }
396
397 fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
398 self.inner[value]
399 }
400
401 pub fn from_arena(arena: &Arena<Expression>) -> Self {
402 let mut tracker = Self {
403 inner: HandleVec::with_capacity(arena.len()),
404 };
405 for (handle, expr) in arena.iter() {
406 tracker
407 .inner
408 .insert(handle, tracker.type_of_with_expr(expr));
409 }
410 tracker
411 }
412
413 fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
414 use crate::MathFunction as Mf;
415 match *expr {
416 Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
417 ExpressionKind::ImplConst
418 }
419 Expression::Override(_) => ExpressionKind::Override,
420 Expression::Compose { ref components, .. } => {
421 let mut expr_type = ExpressionKind::ImplConst;
422 for component in components {
423 expr_type = expr_type.max(self.type_of(*component))
424 }
425 expr_type
426 }
427 Expression::Splat { value, .. } => self.type_of(value),
428 Expression::AccessIndex { base, .. } => self.type_of(base),
429 Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
430 Expression::Swizzle { vector, .. } => self.type_of(vector),
431 Expression::Unary { expr, .. } => self.type_of(expr),
432 Expression::Binary { left, right, .. } => self
433 .type_of(left)
434 .max(self.type_of(right))
435 .max(ExpressionKind::Const),
436 Expression::Math {
437 fun,
438 arg,
439 arg1,
440 arg2,
441 arg3,
442 } => self
443 .type_of(arg)
444 .max(
445 arg1.map(|arg| self.type_of(arg))
446 .unwrap_or(ExpressionKind::Const),
447 )
448 .max(
449 arg2.map(|arg| self.type_of(arg))
450 .unwrap_or(ExpressionKind::Const),
451 )
452 .max(
453 arg3.map(|arg| self.type_of(arg))
454 .unwrap_or(ExpressionKind::Const),
455 )
456 .max(
457 if matches!(
458 fun,
459 Mf::Dot
460 | Mf::Outer
461 | Mf::Cross
462 | Mf::Distance
463 | Mf::Length
464 | Mf::Normalize
465 | Mf::FaceForward
466 | Mf::Reflect
467 | Mf::Refract
468 | Mf::Ldexp
469 | Mf::Modf
470 | Mf::Mix
471 | Mf::Frexp
472 ) {
473 ExpressionKind::Const
474 } else {
475 ExpressionKind::ImplConst
476 },
477 ),
478 Expression::As { convert, expr, .. } => self.type_of(expr).max(if convert.is_some() {
479 ExpressionKind::ImplConst
480 } else {
481 ExpressionKind::Const
482 }),
483 Expression::Select {
484 condition,
485 accept,
486 reject,
487 } => self
488 .type_of(condition)
489 .max(self.type_of(accept))
490 .max(self.type_of(reject))
491 .max(ExpressionKind::Const),
492 Expression::Relational { argument, .. } => self.type_of(argument),
493 Expression::ArrayLength(expr) => self.type_of(expr),
494 _ => ExpressionKind::Runtime,
495 }
496 }
497}
498
/// Reasons why an expression could not be evaluated (or accepted) by the
/// [`ConstantEvaluator`].
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantEvaluatorError {
    #[error("Constants cannot access function arguments")]
    FunctionArg,
    #[error("Constants cannot access global variables")]
    GlobalVariable,
    #[error("Constants cannot access local variables")]
    LocalVariable,
    #[error("Cannot get the array length of a non array type")]
    InvalidArrayLengthArg,
    #[error("Constants cannot get the array length of a dynamically sized array")]
    ArrayLengthDynamic,
    #[error("Constants cannot call functions")]
    Call,
    #[error("Constants don't support workGroupUniformLoad")]
    WorkGroupUniformLoadResult,
    #[error("Constants don't support atomic functions")]
    Atomic,
    #[error("Constants don't support derivative functions")]
    Derivative,
    #[error("Constants don't support load expressions")]
    Load,
    #[error("Constants don't support image expressions")]
    ImageExpression,
    #[error("Constants don't support ray query expressions")]
    RayQueryExpression,
    #[error("Constants don't support subgroup expressions")]
    SubgroupExpression,
    #[error("Cannot access the type")]
    InvalidAccessBase,
    #[error("Cannot access at the index")]
    InvalidAccessIndex,
    #[error("Cannot access with index of type")]
    InvalidAccessIndexTy,
    /// Returned for `ArrayLength` under WGSL rules (see `try_eval_and_append_impl`).
    #[error("Constants don't support array length expressions")]
    ArrayLength,
    #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
    InvalidCastArg { from: String, to: String },
    #[error("Cannot apply the unary op to the argument")]
    InvalidUnaryOpArg,
    #[error("Cannot apply the binary op to the arguments")]
    InvalidBinaryOpArgs,
    /// Arguments of a math built-in have unsupported or mismatched kinds.
    #[error("Cannot apply math function to type")]
    InvalidMathArg,
    /// The function, the expected argument count, and the supplied count.
    #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
    InvalidMathArgCount(crate::MathFunction, usize, usize),
    #[error("value of `low` is greater than `high` for clamp built-in function")]
    InvalidClamp,
    #[error("Splat is defined only on scalar values")]
    SplatScalarOnly,
    #[error("Can only swizzle vector constants")]
    SwizzleVectorOnly,
    #[error("swizzle component not present in source expression")]
    SwizzleOutOfBounds,
    #[error("Type is not constructible")]
    TypeNotConstructible,
    /// An operand is not a constant expression (see `check`).
    #[error("Subexpression(s) are not constant")]
    SubexpressionsAreNotConstant,
    /// The operation is valid but this evaluator cannot compute it. In
    /// runtime contexts `try_eval_and_append` may fall back to appending the
    /// expression unevaluated.
    #[error("Not implemented as constant expression: {0}")]
    NotImplemented(String),
    #[error("{0} operation overflowed")]
    Overflow(String),
    #[error(
        "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
    )]
    AutomaticConversionLossy {
        value: String,
        to_type: &'static str,
    },
    #[error("abstract floating-point values cannot be automatically converted to integers")]
    AutomaticConversionFloatToInt { to_type: &'static str },
    #[error("Division by zero")]
    DivisionByZero,
    #[error("Remainder by zero")]
    RemainderByZero,
    #[error("RHS of shift operation is greater than or equal to 32")]
    ShiftedMoreThan32Bits,
    #[error(transparent)]
    Literal(#[from] crate::valid::LiteralError),
    /// An `Expression::Override` reached evaluation.
    #[error("Can't use pipeline-overridable constants in const-expressions")]
    Override,
    /// A runtime expression in a context that does not allow runtime expressions.
    #[error("Unexpected runtime-expression")]
    RuntimeExpr,
    /// An override-expression in a WGSL const context.
    #[error("Unexpected override-expression")]
    OverrideExpr,
}
586
587impl<'a> ConstantEvaluator<'a> {
588 pub fn for_wgsl_module(
593 module: &'a mut crate::Module,
594 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
595 in_override_ctx: bool,
596 ) -> Self {
597 Self::for_module(
598 Behavior::Wgsl(if in_override_ctx {
599 WgslRestrictions::Override
600 } else {
601 WgslRestrictions::Const(None)
602 }),
603 module,
604 global_expression_kind_tracker,
605 )
606 }
607
608 pub fn for_glsl_module(
613 module: &'a mut crate::Module,
614 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
615 ) -> Self {
616 Self::for_module(
617 Behavior::Glsl(GlslRestrictions::Const),
618 module,
619 global_expression_kind_tracker,
620 )
621 }
622
623 fn for_module(
624 behavior: Behavior<'a>,
625 module: &'a mut crate::Module,
626 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
627 ) -> Self {
628 Self {
629 behavior,
630 types: &mut module.types,
631 constants: &module.constants,
632 overrides: &module.overrides,
633 expressions: &mut module.global_expressions,
634 expression_kind_tracker: global_expression_kind_tracker,
635 }
636 }
637
638 pub fn for_wgsl_function(
643 module: &'a mut crate::Module,
644 expressions: &'a mut Arena<Expression>,
645 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
646 emitter: &'a mut super::Emitter,
647 block: &'a mut crate::Block,
648 is_const: bool,
649 ) -> Self {
650 let local_data = FunctionLocalData {
651 global_expressions: &module.global_expressions,
652 emitter,
653 block,
654 };
655 Self {
656 behavior: Behavior::Wgsl(if is_const {
657 WgslRestrictions::Const(Some(local_data))
658 } else {
659 WgslRestrictions::Runtime(local_data)
660 }),
661 types: &mut module.types,
662 constants: &module.constants,
663 overrides: &module.overrides,
664 expressions,
665 expression_kind_tracker: local_expression_kind_tracker,
666 }
667 }
668
669 pub fn for_glsl_function(
674 module: &'a mut crate::Module,
675 expressions: &'a mut Arena<Expression>,
676 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
677 emitter: &'a mut super::Emitter,
678 block: &'a mut crate::Block,
679 ) -> Self {
680 Self {
681 behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
682 global_expressions: &module.global_expressions,
683 emitter,
684 block,
685 })),
686 types: &mut module.types,
687 constants: &module.constants,
688 overrides: &module.overrides,
689 expressions,
690 expression_kind_tracker: local_expression_kind_tracker,
691 }
692 }
693
694 pub fn to_ctx(&self) -> crate::proc::GlobalCtx {
695 crate::proc::GlobalCtx {
696 types: self.types,
697 constants: self.constants,
698 overrides: self.overrides,
699 global_expressions: match self.function_local_data() {
700 Some(data) => data.global_expressions,
701 None => self.expressions,
702 },
703 }
704 }
705
706 fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
707 if !self.expression_kind_tracker.is_const(expr) {
708 log::debug!("check: SubexpressionsAreNotConstant");
709 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
710 }
711 Ok(())
712 }
713
714 fn check_and_get(
715 &mut self,
716 expr: Handle<Expression>,
717 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
718 match self.expressions[expr] {
719 Expression::Constant(c) => {
720 if let Some(function_local_data) = self.function_local_data() {
723 self.copy_from(
725 self.constants[c].init,
726 function_local_data.global_expressions,
727 )
728 } else {
729 Ok(self.constants[c].init)
731 }
732 }
733 _ => {
734 self.check(expr)?;
735 Ok(expr)
736 }
737 }
738 }
739
740 pub fn try_eval_and_append(
764 &mut self,
765 expr: Expression,
766 span: Span,
767 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
768 match self.expression_kind_tracker.type_of_with_expr(&expr) {
769 ExpressionKind::ImplConst => self.try_eval_and_append_impl(&expr, span),
770 ExpressionKind::Const => {
771 let eval_result = self.try_eval_and_append_impl(&expr, span);
772 if self.behavior.has_runtime_restrictions()
777 && matches!(
778 eval_result,
779 Err(ConstantEvaluatorError::NotImplemented(_)
780 | ConstantEvaluatorError::InvalidBinaryOpArgs,)
781 )
782 {
783 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
784 } else {
785 eval_result
786 }
787 }
788 ExpressionKind::Override => match self.behavior {
789 Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
790 Ok(self.append_expr(expr, span, ExpressionKind::Override))
791 }
792 Behavior::Wgsl(WgslRestrictions::Const(_)) => {
793 Err(ConstantEvaluatorError::OverrideExpr)
794 }
795 Behavior::Glsl(_) => {
796 unreachable!()
797 }
798 },
799 ExpressionKind::Runtime => {
800 if self.behavior.has_runtime_restrictions() {
801 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
802 } else {
803 Err(ConstantEvaluatorError::RuntimeExpr)
804 }
805 }
806 }
807 }
808
809 const fn is_global_arena(&self) -> bool {
811 matches!(
812 self.behavior,
813 Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
814 | Behavior::Glsl(GlslRestrictions::Const)
815 )
816 }
817
818 const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
819 match self.behavior {
820 Behavior::Wgsl(
821 WgslRestrictions::Runtime(ref function_local_data)
822 | WgslRestrictions::Const(Some(ref function_local_data)),
823 )
824 | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
825 Some(function_local_data)
826 }
827 _ => None,
828 }
829 }
830
    /// Evaluates `expr` as a constant and registers the result, without the
    /// kind-based dispatch done by `try_eval_and_append`.
    ///
    /// Operands are first resolved with `check_and_get`, which may append
    /// copies of constant initializers. The order of those calls therefore
    /// affects the handles that are created. Expressions that can never be
    /// constant return the matching [`ConstantEvaluatorError`].
    fn try_eval_and_append_impl(
        &mut self,
        expr: &Expression,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        log::trace!("try_eval_and_append: {:?}", expr);
        match *expr {
            // In the global arena, look straight through a constant to its
            // initializer expression.
            Expression::Constant(c) if self.is_global_arena() => {
                Ok(self.constants[c].init)
            }
            Expression::Override(_) => Err(ConstantEvaluatorError::Override),
            // Already-constant leaves (and constants in a function arena) are
            // registered unchanged.
            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
                self.register_evaluated_expr(expr.clone(), span)
            }
            Expression::Compose { ty, ref components } => {
                let components = components
                    .iter()
                    .map(|component| self.check_and_get(*component))
                    .collect::<Result<Vec<_>, _>>()?;
                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
            }
            Expression::Splat { size, value } => {
                let value = self.check_and_get(value)?;
                self.register_evaluated_expr(Expression::Splat { size, value }, span)
            }
            Expression::AccessIndex { base, index } => {
                let base = self.check_and_get(base)?;

                self.access(base, index as usize, span)
            }
            Expression::Access { base, index } => {
                let base = self.check_and_get(base)?;
                let index = self.check_and_get(index)?;

                // The dynamic index must itself evaluate to a constant.
                self.access(base, self.constant_index(index)?, span)
            }
            Expression::Swizzle {
                size,
                vector,
                pattern,
            } => {
                let vector = self.check_and_get(vector)?;

                self.swizzle(size, span, vector, pattern)
            }
            Expression::Unary { expr, op } => {
                let expr = self.check_and_get(expr)?;

                self.unary_op(op, expr, span)
            }
            Expression::Binary { left, right, op } => {
                let left = self.check_and_get(left)?;
                let right = self.check_and_get(right)?;

                self.binary_op(op, left, right, span)
            }
            Expression::Math {
                fun,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                let arg = self.check_and_get(arg)?;
                let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
                let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
                let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;

                self.math(arg, arg1, arg2, arg3, fun, span)
            }
            Expression::As {
                convert,
                expr,
                kind,
            } => {
                let expr = self.check_and_get(expr)?;

                // `convert: None` is a bitcast, which is not evaluated here.
                match convert {
                    Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
                    None => Err(ConstantEvaluatorError::NotImplemented(
                        "bitcast built-in function".into(),
                    )),
                }
            }
            Expression::Select { .. } => Err(ConstantEvaluatorError::NotImplemented(
                "select built-in function".into(),
            )),
            Expression::Relational { fun, .. } => Err(ConstantEvaluatorError::NotImplemented(
                format!("{fun:?} built-in function"),
            )),
            // WGSL forbids `arrayLength` in constants. GLSL's `.length()` on a
            // fixed-size array is evaluated.
            Expression::ArrayLength(expr) => match self.behavior {
                Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
                Behavior::Glsl(_) => {
                    let expr = self.check_and_get(expr)?;
                    self.array_length(expr, span)
                }
            },
            // Everything below can never be a constant expression.
            Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
            Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
            Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
            Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
            Expression::WorkGroupUniformLoadResult { .. } => {
                Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
            }
            Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
            Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
            Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
            Expression::ImageSample { .. }
            | Expression::ImageLoad { .. }
            | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
            Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => {
                Err(ConstantEvaluatorError::RayQueryExpression)
            }
            Expression::SubgroupBallotResult { .. } => {
                Err(ConstantEvaluatorError::SubgroupExpression)
            }
            Expression::SubgroupOperationResult { .. } => {
                Err(ConstantEvaluatorError::SubgroupExpression)
            }
        }
    }
954
955 fn splat(
968 &mut self,
969 value: Handle<Expression>,
970 size: crate::VectorSize,
971 span: Span,
972 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
973 match self.expressions[value] {
974 Expression::Literal(literal) => {
975 let scalar = literal.scalar();
976 let ty = self.types.insert(
977 Type {
978 name: None,
979 inner: TypeInner::Vector { size, scalar },
980 },
981 span,
982 );
983 let expr = Expression::Compose {
984 ty,
985 components: vec![value; size as usize],
986 };
987 self.register_evaluated_expr(expr, span)
988 }
989 Expression::ZeroValue(ty) => {
990 let inner = match self.types[ty].inner {
991 TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
992 _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
993 };
994 let res_ty = self.types.insert(Type { name: None, inner }, span);
995 let expr = Expression::ZeroValue(res_ty);
996 self.register_evaluated_expr(expr, span)
997 }
998 _ => Err(ConstantEvaluatorError::SplatScalarOnly),
999 }
1000 }
1001
1002 fn swizzle(
1003 &mut self,
1004 size: crate::VectorSize,
1005 span: Span,
1006 src_constant: Handle<Expression>,
1007 pattern: [crate::SwizzleComponent; 4],
1008 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1009 let mut get_dst_ty = |ty| match self.types[ty].inner {
1010 TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
1011 Type {
1012 name: None,
1013 inner: TypeInner::Vector { size, scalar },
1014 },
1015 span,
1016 )),
1017 _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1018 };
1019
1020 match self.expressions[src_constant] {
1021 Expression::ZeroValue(ty) => {
1022 let dst_ty = get_dst_ty(ty)?;
1023 let expr = Expression::ZeroValue(dst_ty);
1024 self.register_evaluated_expr(expr, span)
1025 }
1026 Expression::Splat { value, .. } => {
1027 let expr = Expression::Splat { size, value };
1028 self.register_evaluated_expr(expr, span)
1029 }
1030 Expression::Compose { ty, ref components } => {
1031 let dst_ty = get_dst_ty(ty)?;
1032
1033 let mut flattened = [src_constant; 4]; let len =
1035 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1036 .zip(flattened.iter_mut())
1037 .map(|(component, elt)| *elt = component)
1038 .count();
1039 let flattened = &flattened[..len];
1040
1041 let swizzled_components = pattern[..size as usize]
1042 .iter()
1043 .map(|&sc| {
1044 let sc = sc as usize;
1045 if let Some(elt) = flattened.get(sc) {
1046 Ok(*elt)
1047 } else {
1048 Err(ConstantEvaluatorError::SwizzleOutOfBounds)
1049 }
1050 })
1051 .collect::<Result<Vec<Handle<Expression>>, _>>()?;
1052 let expr = Expression::Compose {
1053 ty: dst_ty,
1054 components: swizzled_components,
1055 };
1056 self.register_evaluated_expr(expr, span)
1057 }
1058 _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1059 }
1060 }
1061
1062 fn math(
1063 &mut self,
1064 arg: Handle<Expression>,
1065 arg1: Option<Handle<Expression>>,
1066 arg2: Option<Handle<Expression>>,
1067 arg3: Option<Handle<Expression>>,
1068 fun: crate::MathFunction,
1069 span: Span,
1070 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1071 let expected = fun.argument_count();
1072 let given = Some(arg)
1073 .into_iter()
1074 .chain(arg1)
1075 .chain(arg2)
1076 .chain(arg3)
1077 .count();
1078 if expected != given {
1079 return Err(ConstantEvaluatorError::InvalidMathArgCount(
1080 fun, expected, given,
1081 ));
1082 }
1083
1084 match fun {
1086 crate::MathFunction::Abs => {
1088 component_wise_scalar(self, span, [arg], |args| match args {
1089 Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
1090 Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
1091 Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.abs()])),
1092 Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
1093 Scalar::U32([e]) => Ok(Scalar::U32([e])), Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
1095 Scalar::U64([e]) => Ok(Scalar::U64([e])),
1096 })
1097 }
1098 crate::MathFunction::Min => {
1099 component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1100 Ok([e1.min(e2)])
1101 })
1102 }
1103 crate::MathFunction::Max => {
1104 component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1105 Ok([e1.max(e2)])
1106 })
1107 }
1108 crate::MathFunction::Clamp => {
1109 component_wise_scalar!(
1110 self,
1111 span,
1112 [arg, arg1.unwrap(), arg2.unwrap()],
1113 |e, low, high| {
1114 if low > high {
1115 Err(ConstantEvaluatorError::InvalidClamp)
1116 } else {
1117 Ok([e.clamp(low, high)])
1118 }
1119 }
1120 )
1121 }
1122 crate::MathFunction::Saturate => {
1123 component_wise_float!(self, span, [arg], |e| { Ok([e.clamp(0., 1.)]) })
1124 }
1125
1126 crate::MathFunction::Cos => {
1128 component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
1129 }
1130 crate::MathFunction::Cosh => {
1131 component_wise_float!(self, span, [arg], |e| { Ok([e.cosh()]) })
1132 }
1133 crate::MathFunction::Sin => {
1134 component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
1135 }
1136 crate::MathFunction::Sinh => {
1137 component_wise_float!(self, span, [arg], |e| { Ok([e.sinh()]) })
1138 }
1139 crate::MathFunction::Tan => {
1140 component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
1141 }
1142 crate::MathFunction::Tanh => {
1143 component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
1144 }
1145 crate::MathFunction::Acos => {
1146 component_wise_float!(self, span, [arg], |e| { Ok([e.acos()]) })
1147 }
1148 crate::MathFunction::Asin => {
1149 component_wise_float!(self, span, [arg], |e| { Ok([e.asin()]) })
1150 }
1151 crate::MathFunction::Atan => {
1152 component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
1153 }
1154 crate::MathFunction::Asinh => {
1155 component_wise_float!(self, span, [arg], |e| { Ok([e.asinh()]) })
1156 }
1157 crate::MathFunction::Acosh => {
1158 component_wise_float!(self, span, [arg], |e| { Ok([e.acosh()]) })
1159 }
1160 crate::MathFunction::Atanh => {
1161 component_wise_float!(self, span, [arg], |e| { Ok([e.atanh()]) })
1162 }
1163 crate::MathFunction::Radians => {
1164 component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
1165 }
1166 crate::MathFunction::Degrees => {
1167 component_wise_float!(self, span, [arg], |e| { Ok([e.to_degrees()]) })
1168 }
1169
1170 crate::MathFunction::Ceil => {
1172 component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
1173 }
1174 crate::MathFunction::Floor => {
1175 component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
1176 }
1177 crate::MathFunction::Round => {
1178 fn round_ties_even(x: f64) -> f64 {
1185 let i = x as i64;
1186 let f = (x - i as f64).abs();
1187 if f == 0.5 {
1188 if i & 1 == 1 {
1189 (x.abs() + 0.5).copysign(x)
1191 } else {
1192 (x.abs() - 0.5).copysign(x)
1193 }
1194 } else {
1195 x.round()
1196 }
1197 }
1198 component_wise_float(self, span, [arg], |e| match e {
1199 Float::Abstract([e]) => Ok(Float::Abstract([round_ties_even(e)])),
1200 Float::F32([e]) => Ok(Float::F32([(round_ties_even(e as f64) as f32)])),
1201 })
1202 }
1203 crate::MathFunction::Fract => {
1204 component_wise_float!(self, span, [arg], |e| {
1205 Ok([e - e.floor()])
1208 })
1209 }
1210 crate::MathFunction::Trunc => {
1211 component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
1212 }
1213
1214 crate::MathFunction::Exp => {
1216 component_wise_float!(self, span, [arg], |e| { Ok([e.exp()]) })
1217 }
1218 crate::MathFunction::Exp2 => {
1219 component_wise_float!(self, span, [arg], |e| { Ok([e.exp2()]) })
1220 }
1221 crate::MathFunction::Log => {
1222 component_wise_float!(self, span, [arg], |e| { Ok([e.ln()]) })
1223 }
1224 crate::MathFunction::Log2 => {
1225 component_wise_float!(self, span, [arg], |e| { Ok([e.log2()]) })
1226 }
1227 crate::MathFunction::Pow => {
1228 component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1229 Ok([e1.powf(e2)])
1230 })
1231 }
1232
1233 crate::MathFunction::Sign => {
1235 component_wise_signed!(self, span, [arg], |e| { Ok([e.signum()]) })
1236 }
1237 crate::MathFunction::Fma => {
1238 component_wise_float!(
1239 self,
1240 span,
1241 [arg, arg1.unwrap(), arg2.unwrap()],
1242 |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
1243 )
1244 }
1245 crate::MathFunction::Step => {
1246 component_wise_float!(self, span, [arg, arg1.unwrap()], |edge, x| {
1247 Ok([if edge <= x { 1.0 } else { 0.0 }])
1248 })
1249 }
1250 crate::MathFunction::Sqrt => {
1251 component_wise_float!(self, span, [arg], |e| { Ok([e.sqrt()]) })
1252 }
1253 crate::MathFunction::InverseSqrt => {
1254 component_wise_float!(self, span, [arg], |e| { Ok([1. / e.sqrt()]) })
1255 }
1256
1257 crate::MathFunction::CountTrailingZeros => {
1259 component_wise_concrete_int!(self, span, [arg], |e| {
1260 #[allow(clippy::useless_conversion)]
1261 Ok([e
1262 .trailing_zeros()
1263 .try_into()
1264 .expect("bit count overflowed 32 bits, somehow!?")])
1265 })
1266 }
1267 crate::MathFunction::CountLeadingZeros => {
1268 component_wise_concrete_int!(self, span, [arg], |e| {
1269 #[allow(clippy::useless_conversion)]
1270 Ok([e
1271 .leading_zeros()
1272 .try_into()
1273 .expect("bit count overflowed 32 bits, somehow!?")])
1274 })
1275 }
1276 crate::MathFunction::CountOneBits => {
1277 component_wise_concrete_int!(self, span, [arg], |e| {
1278 #[allow(clippy::useless_conversion)]
1279 Ok([e
1280 .count_ones()
1281 .try_into()
1282 .expect("bit count overflowed 32 bits, somehow!?")])
1283 })
1284 }
1285 crate::MathFunction::ReverseBits => {
1286 component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
1287 }
1288 crate::MathFunction::FirstTrailingBit => {
1289 component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
1290 }
1291 crate::MathFunction::FirstLeadingBit => {
1292 component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
1293 }
1294
1295 fun => Err(ConstantEvaluatorError::NotImplemented(format!(
1296 "{fun:?} built-in function"
1297 ))),
1298 }
1299 }
1300
1301 fn array_length(
1302 &mut self,
1303 array: Handle<Expression>,
1304 span: Span,
1305 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1306 match self.expressions[array] {
1307 Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
1308 match self.types[ty].inner {
1309 TypeInner::Array { size, .. } => match size {
1310 ArraySize::Constant(len) => {
1311 let expr = Expression::Literal(Literal::U32(len.get()));
1312 self.register_evaluated_expr(expr, span)
1313 }
1314 ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
1315 },
1316 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
1317 }
1318 }
1319 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
1320 }
1321 }
1322
1323 fn access(
1324 &mut self,
1325 base: Handle<Expression>,
1326 index: usize,
1327 span: Span,
1328 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1329 match self.expressions[base] {
1330 Expression::ZeroValue(ty) => {
1331 let ty_inner = &self.types[ty].inner;
1332 let components = ty_inner
1333 .components()
1334 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
1335
1336 if index >= components as usize {
1337 Err(ConstantEvaluatorError::InvalidAccessBase)
1338 } else {
1339 let ty_res = ty_inner
1340 .component_type(index)
1341 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
1342 let ty = match ty_res {
1343 crate::proc::TypeResolution::Handle(ty) => ty,
1344 crate::proc::TypeResolution::Value(inner) => {
1345 self.types.insert(Type { name: None, inner }, span)
1346 }
1347 };
1348 self.register_evaluated_expr(Expression::ZeroValue(ty), span)
1349 }
1350 }
1351 Expression::Splat { size, value } => {
1352 if index >= size as usize {
1353 Err(ConstantEvaluatorError::InvalidAccessBase)
1354 } else {
1355 Ok(value)
1356 }
1357 }
1358 Expression::Compose { ty, ref components } => {
1359 let _ = self.types[ty]
1360 .inner
1361 .components()
1362 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
1363
1364 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1365 .nth(index)
1366 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
1367 }
1368 _ => Err(ConstantEvaluatorError::InvalidAccessBase),
1369 }
1370 }
1371
1372 fn constant_index(&self, expr: Handle<Expression>) -> Result<usize, ConstantEvaluatorError> {
1373 match self.expressions[expr] {
1374 Expression::ZeroValue(ty)
1375 if matches!(
1376 self.types[ty].inner,
1377 TypeInner::Scalar(crate::Scalar {
1378 kind: ScalarKind::Uint,
1379 ..
1380 })
1381 ) =>
1382 {
1383 Ok(0)
1384 }
1385 Expression::Literal(Literal::U32(index)) => Ok(index as usize),
1386 _ => Err(ConstantEvaluatorError::InvalidAccessIndexTy),
1387 }
1388 }
1389
1390 fn eval_zero_value_and_splat(
1397 &mut self,
1398 expr: Handle<Expression>,
1399 span: Span,
1400 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1401 match self.expressions[expr] {
1402 Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
1403 Expression::Splat { size, value } => self.splat(value, size, span),
1404 _ => Ok(expr),
1405 }
1406 }
1407
1408 fn eval_zero_value(
1414 &mut self,
1415 expr: Handle<Expression>,
1416 span: Span,
1417 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1418 match self.expressions[expr] {
1419 Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
1420 _ => Ok(expr),
1421 }
1422 }
1423
1424 fn eval_zero_value_impl(
1430 &mut self,
1431 ty: Handle<Type>,
1432 span: Span,
1433 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1434 match self.types[ty].inner {
1435 TypeInner::Scalar(scalar) => {
1436 let expr = Expression::Literal(
1437 Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
1438 );
1439 self.register_evaluated_expr(expr, span)
1440 }
1441 TypeInner::Vector { size, scalar } => {
1442 let scalar_ty = self.types.insert(
1443 Type {
1444 name: None,
1445 inner: TypeInner::Scalar(scalar),
1446 },
1447 span,
1448 );
1449 let el = self.eval_zero_value_impl(scalar_ty, span)?;
1450 let expr = Expression::Compose {
1451 ty,
1452 components: vec![el; size as usize],
1453 };
1454 self.register_evaluated_expr(expr, span)
1455 }
1456 TypeInner::Matrix {
1457 columns,
1458 rows,
1459 scalar,
1460 } => {
1461 let vec_ty = self.types.insert(
1462 Type {
1463 name: None,
1464 inner: TypeInner::Vector { size: rows, scalar },
1465 },
1466 span,
1467 );
1468 let el = self.eval_zero_value_impl(vec_ty, span)?;
1469 let expr = Expression::Compose {
1470 ty,
1471 components: vec![el; columns as usize],
1472 };
1473 self.register_evaluated_expr(expr, span)
1474 }
1475 TypeInner::Array {
1476 base,
1477 size: ArraySize::Constant(size),
1478 ..
1479 } => {
1480 let el = self.eval_zero_value_impl(base, span)?;
1481 let expr = Expression::Compose {
1482 ty,
1483 components: vec![el; size.get() as usize],
1484 };
1485 self.register_evaluated_expr(expr, span)
1486 }
1487 TypeInner::Struct { ref members, .. } => {
1488 let types: Vec<_> = members.iter().map(|m| m.ty).collect();
1489 let mut components = Vec::with_capacity(members.len());
1490 for ty in types {
1491 components.push(self.eval_zero_value_impl(ty, span)?);
1492 }
1493 let expr = Expression::Compose { ty, components };
1494 self.register_evaluated_expr(expr, span)
1495 }
1496 _ => Err(ConstantEvaluatorError::TypeNotConstructible),
1497 }
1498 }
1499
    /// Converts the constant expression `expr` to the scalar type `target`.
    ///
    /// A `ZeroValue` is first expanded with `eval_zero_value`. Then:
    /// - a `Literal` is converted directly;
    /// - a vector or matrix `Compose` has each component cast recursively and
    ///   gets a new type of the same shape with `target` as its scalar;
    /// - a `Splat` has its splatted value cast, using that value's own span.
    ///
    /// Conversions between concrete types use Rust `as` casts, so they never
    /// fail here. Conversions from abstract literals go through
    /// `TryFromAbstract::try_from_abstract` and may fail.
    ///
    /// # Errors
    ///
    /// [`ConstantEvaluatorError::InvalidCastArg`] for unsupported conversions.
    /// Examples are any `F64`, `I64` or `U64` literal cast to `i32`, `u32`,
    /// `f32` or `bool`, any concrete literal cast to `AbstractFloat`, and any
    /// target scalar not listed below. Errors from `try_from_abstract` and
    /// from registering the result are propagated.
    pub fn cast(
        &mut self,
        expr: Handle<Expression>,
        target: crate::Scalar,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        use crate::Scalar as Sc;

        let expr = self.eval_zero_value(expr, span)?;

        // Built lazily so the debug strings are only formatted on error paths.
        let make_error = || -> Result<_, ConstantEvaluatorError> {
            let from = format!("{:?} {:?}", expr, self.expressions[expr]);

            // With the WGSL frontend compiled in, name the target in WGSL syntax.
            #[cfg(feature = "wgsl-in")]
            let to = target.to_wgsl();

            #[cfg(not(feature = "wgsl-in"))]
            let to = format!("{target:?}");

            Err(ConstantEvaluatorError::InvalidCastArg { from, to })
        };

        let expr = match self.expressions[expr] {
            Expression::Literal(literal) => {
                let literal = match target {
                    Sc::I32 => Literal::I32(match literal {
                        Literal::I32(v) => v,
                        Literal::U32(v) => v as i32,
                        Literal::F32(v) => v as i32,
                        Literal::Bool(v) => v as i32,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
                    }),
                    Sc::U32 => Literal::U32(match literal {
                        Literal::I32(v) => v as u32,
                        Literal::U32(v) => v,
                        Literal::F32(v) => v as u32,
                        Literal::Bool(v) => v as u32,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
                    }),
                    Sc::I64 => Literal::I64(match literal {
                        Literal::I32(v) => v as i64,
                        Literal::U32(v) => v as i64,
                        Literal::F32(v) => v as i64,
                        Literal::Bool(v) => v as i64,
                        Literal::F64(v) => v as i64,
                        Literal::I64(v) => v,
                        Literal::U64(v) => v as i64,
                        Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
                    }),
                    Sc::U64 => Literal::U64(match literal {
                        Literal::I32(v) => v as u64,
                        Literal::U32(v) => v as u64,
                        Literal::F32(v) => v as u64,
                        Literal::Bool(v) => v as u64,
                        Literal::F64(v) => v as u64,
                        Literal::I64(v) => v as u64,
                        Literal::U64(v) => v,
                        Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
                    }),
                    Sc::F32 => Literal::F32(match literal {
                        Literal::I32(v) => v as f32,
                        Literal::U32(v) => v as f32,
                        Literal::F32(v) => v,
                        Literal::Bool(v) => v as u32 as f32,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
                    }),
                    Sc::F64 => Literal::F64(match literal {
                        Literal::I32(v) => v as f64,
                        Literal::U32(v) => v as f64,
                        Literal::F32(v) => v as f64,
                        Literal::F64(v) => v,
                        Literal::Bool(v) => v as u32 as f64,
                        // NOTE(review): `I64`/`U64` -> `F64` is rejected while
                        // `F64` -> `I64`/`U64` above is accepted; confirm intended.
                        Literal::I64(_) | Literal::U64(_) => return make_error(),
                        Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
                    }),
                    Sc::BOOL => Literal::Bool(match literal {
                        Literal::I32(v) => v != 0,
                        Literal::U32(v) => v != 0,
                        Literal::F32(v) => v != 0.0,
                        Literal::Bool(v) => v,
                        Literal::F64(_)
                        | Literal::I64(_)
                        | Literal::U64(_)
                        | Literal::AbstractInt(_)
                        | Literal::AbstractFloat(_) => {
                            return make_error();
                        }
                    }),
                    // Only abstract literals can become an abstract float.
                    Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
                        Literal::AbstractInt(v) => {
                            // May round, but cannot overflow: every `i64` lies
                            // well within the range of `f64`.
                            v as f64
                        }
                        Literal::AbstractFloat(v) => v,
                        _ => return make_error(),
                    }),
                    _ => {
                        log::debug!("Constant evaluator refused to convert value to {target:?}");
                        return make_error();
                    }
                };
                Expression::Literal(literal)
            }
            Expression::Compose {
                ty,
                components: ref src_components,
            } => {
                // Keep the vector/matrix shape, swapping in the target scalar.
                let ty_inner = match self.types[ty].inner {
                    TypeInner::Vector { size, .. } => TypeInner::Vector {
                        size,
                        scalar: target,
                    },
                    TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
                        columns,
                        rows,
                        scalar: target,
                    },
                    _ => return make_error(),
                };

                // Cloned because casting each component needs `&mut self`.
                let mut components = src_components.clone();
                for component in &mut components {
                    *component = self.cast(*component, target, span)?;
                }

                let ty = self.types.insert(
                    Type {
                        name: None,
                        inner: ty_inner,
                    },
                    span,
                );

                Expression::Compose { ty, components }
            }
            Expression::Splat { size, value } => {
                let value_span = self.expressions.get_span(value);
                let cast_value = self.cast(value, target, value_span)?;
                Expression::Splat {
                    size,
                    value: cast_value,
                }
            }
            _ => return make_error(),
        };

        self.register_evaluated_expr(expr, span)
    }
1669
1670 pub fn cast_array(
1683 &mut self,
1684 expr: Handle<Expression>,
1685 target: crate::Scalar,
1686 span: Span,
1687 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1688 let Expression::Compose { ty, ref components } = self.expressions[expr] else {
1689 return self.cast(expr, target, span);
1690 };
1691
1692 let TypeInner::Array {
1693 base: _,
1694 size,
1695 stride: _,
1696 } = self.types[ty].inner
1697 else {
1698 return self.cast(expr, target, span);
1699 };
1700
1701 let mut components = components.clone();
1702 for component in &mut components {
1703 *component = self.cast_array(*component, target, span)?;
1704 }
1705
1706 let first = components.first().unwrap();
1707 let new_base = match self.resolve_type(*first)? {
1708 crate::proc::TypeResolution::Handle(ty) => ty,
1709 crate::proc::TypeResolution::Value(inner) => {
1710 self.types.insert(Type { name: None, inner }, span)
1711 }
1712 };
1713 let new_base_stride = self.types[new_base].inner.size(self.to_ctx());
1714 let new_array_ty = self.types.insert(
1715 Type {
1716 name: None,
1717 inner: TypeInner::Array {
1718 base: new_base,
1719 size,
1720 stride: new_base_stride,
1721 },
1722 },
1723 span,
1724 );
1725
1726 let compose = Expression::Compose {
1727 ty: new_array_ty,
1728 components,
1729 };
1730 self.register_evaluated_expr(compose, span)
1731 }
1732
1733 fn unary_op(
1734 &mut self,
1735 op: UnaryOperator,
1736 expr: Handle<Expression>,
1737 span: Span,
1738 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1739 let expr = self.eval_zero_value_and_splat(expr, span)?;
1740
1741 let expr = match self.expressions[expr] {
1742 Expression::Literal(value) => Expression::Literal(match op {
1743 UnaryOperator::Negate => match value {
1744 Literal::I32(v) => Literal::I32(v.wrapping_neg()),
1745 Literal::F32(v) => Literal::F32(-v),
1746 Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
1747 Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
1748 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1749 },
1750 UnaryOperator::LogicalNot => match value {
1751 Literal::Bool(v) => Literal::Bool(!v),
1752 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1753 },
1754 UnaryOperator::BitwiseNot => match value {
1755 Literal::I32(v) => Literal::I32(!v),
1756 Literal::U32(v) => Literal::U32(!v),
1757 Literal::AbstractInt(v) => Literal::AbstractInt(!v),
1758 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1759 },
1760 }),
1761 Expression::Compose {
1762 ty,
1763 components: ref src_components,
1764 } => {
1765 match self.types[ty].inner {
1766 TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
1767 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1768 }
1769
1770 let mut components = src_components.clone();
1771 for component in &mut components {
1772 *component = self.unary_op(op, *component, span)?;
1773 }
1774
1775 Expression::Compose { ty, components }
1776 }
1777 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1778 };
1779
1780 self.register_evaluated_expr(expr, span)
1781 }
1782
    /// Evaluates the binary operation `op` on two constant expressions.
    ///
    /// Zero values and splats are expanded first. Then:
    /// - Two literals are folded directly. Comparisons yield a `bool` literal.
    ///   Integer arithmetic is checked, and overflow, division/remainder by
    ///   zero and oversized shifts are reported as errors. Float arithmetic is
    ///   unchecked here.
    /// - A `Compose` with a literal (on either side) applies `op` between each
    ///   component and the literal, keeping the compose's type.
    /// - Two `Compose`s are combined component-wise only when both are vectors
    ///   of the same size (see `binary_op_vector`).
    ///
    /// Any other combination yields [`ConstantEvaluatorError::InvalidBinaryOpArgs`].
    fn binary_op(
        &mut self,
        op: BinaryOperator,
        left: Handle<Expression>,
        right: Handle<Expression>,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        let left = self.eval_zero_value_and_splat(left, span)?;
        let right = self.eval_zero_value_and_splat(right, span)?;

        let expr = match (&self.expressions[left], &self.expressions[right]) {
            (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
                let literal = match op {
                    // NOTE(review): comparisons rely on `Literal`'s `PartialEq` /
                    // `PartialOrd` and do not check that both operands have the
                    // same variant; confirm callers guarantee matching types.
                    BinaryOperator::Equal => Literal::Bool(left_value == right_value),
                    BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
                    BinaryOperator::Less => Literal::Bool(left_value < right_value),
                    BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
                    BinaryOperator::Greater => Literal::Bool(left_value > right_value),
                    BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),

                    _ => match (left_value, right_value) {
                        // `checked_div`/`checked_rem` fail on a zero divisor and on
                        // `i32::MIN / -1`; the `b == 0` test tells the two apart.
                        (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
                            BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
                                ConstantEvaluatorError::Overflow("addition".into())
                            })?,
                            BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
                                ConstantEvaluatorError::Overflow("subtraction".into())
                            })?,
                            BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
                                ConstantEvaluatorError::Overflow("multiplication".into())
                            })?,
                            BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
                                if b == 0 {
                                    ConstantEvaluatorError::DivisionByZero
                                } else {
                                    ConstantEvaluatorError::Overflow("division".into())
                                }
                            })?,
                            BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
                                if b == 0 {
                                    ConstantEvaluatorError::RemainderByZero
                                } else {
                                    ConstantEvaluatorError::Overflow("remainder".into())
                                }
                            })?,
                            BinaryOperator::And => a & b,
                            BinaryOperator::ExclusiveOr => a ^ b,
                            BinaryOperator::InclusiveOr => a | b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        // Shifts take an `i32` value and a `u32` shift amount.
                        (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
                            BinaryOperator::ShiftLeft => {
                                // Reject the shift if it would push bits other
                                // than copies of the sign bit into or past the sign
                                // position. `!a` turns the leading sign ones of a
                                // negative value into countable leading zeros.
                                if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
                                    return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
                                }
                                a.checked_shl(b)
                                    .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
                            }
                            BinaryOperator::ShiftRight => a
                                .checked_shr(b)
                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
                            BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
                                ConstantEvaluatorError::Overflow("addition".into())
                            })?,
                            BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
                                ConstantEvaluatorError::Overflow("subtraction".into())
                            })?,
                            BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
                                ConstantEvaluatorError::Overflow("multiplication".into())
                            })?,
                            // For unsigned operands these only fail on a zero divisor.
                            BinaryOperator::Divide => a
                                .checked_div(b)
                                .ok_or(ConstantEvaluatorError::DivisionByZero)?,
                            BinaryOperator::Modulo => a
                                .checked_rem(b)
                                .ok_or(ConstantEvaluatorError::RemainderByZero)?,
                            BinaryOperator::And => a & b,
                            BinaryOperator::ExclusiveOr => a ^ b,
                            BinaryOperator::InclusiveOr => a | b,
                            // `a << b` computed as `a * 2^b`, so losing any set bit
                            // is reported as overflow.
                            BinaryOperator::ShiftLeft => a
                                .checked_mul(
                                    1u32.checked_shl(b)
                                        .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                                )
                                .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
                            BinaryOperator::ShiftRight => a
                                .checked_shr(b)
                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
                            BinaryOperator::Add => a + b,
                            BinaryOperator::Subtract => a - b,
                            BinaryOperator::Multiply => a * b,
                            BinaryOperator::Divide => a / b,
                            BinaryOperator::Modulo => a % b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
                            Literal::AbstractInt(match op {
                                BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
                                    ConstantEvaluatorError::Overflow("addition".into())
                                })?,
                                BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
                                    ConstantEvaluatorError::Overflow("subtraction".into())
                                })?,
                                BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
                                    ConstantEvaluatorError::Overflow("multiplication".into())
                                })?,
                                BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
                                    if b == 0 {
                                        ConstantEvaluatorError::DivisionByZero
                                    } else {
                                        ConstantEvaluatorError::Overflow("division".into())
                                    }
                                })?,
                                BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
                                    if b == 0 {
                                        ConstantEvaluatorError::RemainderByZero
                                    } else {
                                        ConstantEvaluatorError::Overflow("remainder".into())
                                    }
                                })?,
                                BinaryOperator::And => a & b,
                                BinaryOperator::ExclusiveOr => a ^ b,
                                BinaryOperator::InclusiveOr => a | b,
                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                            })
                        }
                        (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
                            Literal::AbstractFloat(match op {
                                BinaryOperator::Add => a + b,
                                BinaryOperator::Subtract => a - b,
                                BinaryOperator::Multiply => a * b,
                                BinaryOperator::Divide => a / b,
                                BinaryOperator::Modulo => a % b,
                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                            })
                        }
                        (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
                            BinaryOperator::LogicalAnd => a && b,
                            BinaryOperator::LogicalOr => a || b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                    },
                };
                Expression::Literal(literal)
            }
            // Composite op scalar: apply `op` to each component and the scalar.
            // NOTE(review): the result keeps the compose's `ty` even for
            // comparison operators (unlike `binary_op_vector`); confirm front
            // ends never reach this arm with a comparison.
            (
                &Expression::Compose {
                    components: ref src_components,
                    ty,
                },
                &Expression::Literal(_),
            ) => {
                let mut components = src_components.clone();
                for component in &mut components {
                    *component = self.binary_op(op, *component, right, span)?;
                }
                Expression::Compose { ty, components }
            }
            // Scalar op composite: mirror of the arm above.
            (
                &Expression::Literal(_),
                &Expression::Compose {
                    components: ref src_components,
                    ty,
                },
            ) => {
                let mut components = src_components.clone();
                for component in &mut components {
                    *component = self.binary_op(op, left, *component, span)?;
                }
                Expression::Compose { ty, components }
            }
            (
                &Expression::Compose {
                    components: ref left_components,
                    ty: left_ty,
                },
                &Expression::Compose {
                    components: ref right_components,
                    ty: right_ty,
                },
            ) => {
                // The flattened pairs are collected into an owned `Vec` before
                // calling `binary_op_vector`, which needs `&mut self`.
                let left_flattened = crate::proc::flatten_compose(
                    left_ty,
                    left_components,
                    self.expressions,
                    self.types,
                );
                let right_flattened = crate::proc::flatten_compose(
                    right_ty,
                    right_components,
                    self.expressions,
                    self.types,
                );

                // Capacity is only a guess: the flattened length may differ
                // from the number of top-level components.
                let mut flattened = Vec::with_capacity(left_components.len());
                flattened.extend(left_flattened.zip(right_flattened));

                // Only same-sized vectors are supported; matrices and other
                // composites are rejected.
                match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
                    (
                        &TypeInner::Vector {
                            size: left_size, ..
                        },
                        &TypeInner::Vector {
                            size: right_size, ..
                        },
                    ) if left_size == right_size => {
                        self.binary_op_vector(op, left_size, &flattened, left_ty, span)?
                    }
                    _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                }
            }
            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
        };

        self.register_evaluated_expr(expr, span)
    }
2011
2012 fn binary_op_vector(
2013 &mut self,
2014 op: BinaryOperator,
2015 size: crate::VectorSize,
2016 components: &[(Handle<Expression>, Handle<Expression>)],
2017 left_ty: Handle<Type>,
2018 span: Span,
2019 ) -> Result<Expression, ConstantEvaluatorError> {
2020 let ty = match op {
2021 BinaryOperator::Equal
2023 | BinaryOperator::NotEqual
2024 | BinaryOperator::Less
2025 | BinaryOperator::LessEqual
2026 | BinaryOperator::Greater
2027 | BinaryOperator::GreaterEqual => self.types.insert(
2028 Type {
2029 name: None,
2030 inner: TypeInner::Vector {
2031 size,
2032 scalar: crate::Scalar::BOOL,
2033 },
2034 },
2035 span,
2036 ),
2037
2038 BinaryOperator::Add
2041 | BinaryOperator::Subtract
2042 | BinaryOperator::Multiply
2043 | BinaryOperator::Divide
2044 | BinaryOperator::Modulo
2045 | BinaryOperator::And
2046 | BinaryOperator::ExclusiveOr
2047 | BinaryOperator::InclusiveOr
2048 | BinaryOperator::LogicalAnd
2049 | BinaryOperator::LogicalOr
2050 | BinaryOperator::ShiftLeft
2051 | BinaryOperator::ShiftRight => left_ty,
2052 };
2053
2054 let components = components
2055 .iter()
2056 .map(|&(left, right)| self.binary_op(op, left, right, span))
2057 .collect::<Result<Vec<_>, _>>()?;
2058
2059 Ok(Expression::Compose { ty, components })
2060 }
2061
2062 fn copy_from(
2070 &mut self,
2071 expr: Handle<Expression>,
2072 expressions: &Arena<Expression>,
2073 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2074 let span = expressions.get_span(expr);
2075 match expressions[expr] {
2076 ref expr @ (Expression::Literal(_)
2077 | Expression::Constant(_)
2078 | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
2079 Expression::Compose { ty, ref components } => {
2080 let mut components = components.clone();
2081 for component in &mut components {
2082 *component = self.copy_from(*component, expressions)?;
2083 }
2084 self.register_evaluated_expr(Expression::Compose { ty, components }, span)
2085 }
2086 Expression::Splat { size, value } => {
2087 let value = self.copy_from(value, expressions)?;
2088 self.register_evaluated_expr(Expression::Splat { size, value }, span)
2089 }
2090 _ => {
2091 log::debug!("copy_from: SubexpressionsAreNotConstant");
2092 Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
2093 }
2094 }
2095 }
2096
2097 fn register_evaluated_expr(
2098 &mut self,
2099 expr: Expression,
2100 span: Span,
2101 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2102 if let Expression::Literal(literal) = expr {
2106 crate::valid::check_literal_value(literal)?;
2107 }
2108
2109 Ok(self.append_expr(expr, span, ExpressionKind::Const))
2110 }
2111
    /// Appends `expr` to the expression arena and records `expr_type` for the
    /// new handle in `expression_kind_tracker`.
    ///
    /// When evaluating with function-local data, the emitter may be running
    /// and `expr` may report `needs_pre_emit()`. In that case the emitter is
    /// finished first (its statements are appended to the function's block),
    /// the expression is appended, and the emitter is restarted. Presumably
    /// this keeps the new expression out of any `Emit` range; confirm against
    /// `Emitter`.
    ///
    /// "Function-local data" means WGSL runtime evaluation, WGSL const
    /// evaluation with function data, or GLSL runtime evaluation.
    fn append_expr(
        &mut self,
        expr: Expression,
        span: Span,
        expr_type: ExpressionKind,
    ) -> Handle<Expression> {
        let h = match self.behavior {
            Behavior::Wgsl(
                WgslRestrictions::Runtime(ref mut function_local_data)
                | WgslRestrictions::Const(Some(ref mut function_local_data)),
            )
            | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
                let is_running = function_local_data.emitter.is_running();
                let needs_pre_emit = expr.needs_pre_emit();
                if is_running && needs_pre_emit {
                    // Order matters: flush the pending range, append, then
                    // restart the emitter after the new expression.
                    function_local_data
                        .block
                        .extend(function_local_data.emitter.finish(self.expressions));
                    let h = self.expressions.append(expr, span);
                    function_local_data.emitter.start(self.expressions);
                    h
                } else {
                    self.expressions.append(expr, span)
                }
            }
            // Without function-local data there is no emitter to manage.
            _ => self.expressions.append(expr, span),
        };
        self.expression_kind_tracker.insert(h, expr_type);
        h
    }
2142
2143 fn resolve_type(
2144 &self,
2145 expr: Handle<Expression>,
2146 ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
2147 use crate::proc::TypeResolution as Tr;
2148 use crate::Expression as Ex;
2149 let resolution = match self.expressions[expr] {
2150 Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
2151 Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
2152 Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
2153 Ex::Splat { size, value } => {
2154 let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
2155 return Err(ConstantEvaluatorError::SplatScalarOnly);
2156 };
2157 Tr::Value(TypeInner::Vector { scalar, size })
2158 }
2159 _ => {
2160 log::debug!("resolve_type: SubexpressionsAreNotConstant");
2161 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
2162 }
2163 };
2164
2165 Ok(resolution)
2166 }
2167}
2168
2169fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
2170 let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
2174 match e {
2175 idx @ 0..=31 => idx,
2176 32 => u32::MAX,
2177 _ => unreachable!(),
2178 }
2179 };
2180 match concrete_int {
2181 ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
2182 ConcreteInt::I32([e]) => {
2183 ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
2184 }
2185 }
2186}
2187
#[test]
fn first_trailing_bit_smoke() {
    // (input, expected) for signed inputs. Zero has no set bit, so it maps to -1.
    let signed: [(i32, i32); 6] = [
        (0, -1),
        (1, 0),
        (2, 1),
        (-1, 0),
        (i32::MIN, 31),
        (i32::MAX, 0),
    ];
    for (input, expected) in signed {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // A single set bit is its own trailing bit.
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }

    // (input, expected) for unsigned inputs. Zero maps to all ones.
    let unsigned: [(u32, u32); 5] = [(0, u32::MAX), (1, 0), (2, 1), (1 << 31, 31), (u32::MAX, 0)];
    for (input, expected) in unsigned {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        );
    }
}
2248
2249fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
2250 let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
2254 match e {
2255 idx @ 0..=31 => 31 - idx,
2256 32 => u32::MAX,
2257 _ => unreachable!(),
2258 }
2259 };
2260 match concrete_int {
2261 ConcreteInt::I32([e]) => ConcreteInt::I32([{
2262 let rtl_bit_index = if e.is_negative() {
2263 e.leading_ones()
2264 } else {
2265 e.leading_zeros()
2266 };
2267 rtl_to_ltr_bit_idx(rtl_bit_index) as i32
2268 }]),
2269 ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
2270 }
2271}
2272
#[test]
fn first_leading_bit_smoke() {
    // (input, expected) for signed inputs. All-zeros and all-ones have no bit
    // that differs from the sign, so both map to -1.
    let signed: [(i32, i32); 7] = [
        (-1, -1),
        (0, -1),
        (1, 0),
        (-2, 0),
        (1234 + 4567, 12),
        (i32::MAX, 30),
        (i32::MIN, 30),
    ];
    for (input, expected) in signed {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // Bit 31 is excluded: `1 << 31` is `i32::MIN`, which is checked above.
    for idx in 0..31 {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }
    // `-(1 << idx)` is all ones from bit `idx` upward, so the first bit that
    // differs from the sign is `idx - 1`.
    for idx in 1..31 {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([-(1 << idx)])),
            ConcreteInt::I32([idx - 1])
        );
    }

    // (input, expected) for unsigned inputs. Zero maps to all ones.
    let unsigned: [(u32, u32); 3] = [(0, u32::MAX), (1, 0), (u32::MAX, 31)];
    for (input, expected) in unsigned {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        );
    }
}
2336
/// Fallible conversion from an abstract literal payload into a concrete type.
///
/// `T` is the representation of the abstract value: `i64` for
/// `Literal::AbstractInt` and `f64` for `Literal::AbstractFloat`. These are
/// the impls found below and the arguments passed in `ConstantEvaluator::cast`.
trait TryFromAbstract<T>: Sized {
    /// Converts `value` to `Self`. Returns a [`ConstantEvaluatorError`] when
    /// the conversion is not allowed or would lose information.
    fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
}
2356
2357impl TryFromAbstract<i64> for i32 {
2358 fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
2359 i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2360 value: format!("{value:?}"),
2361 to_type: "i32",
2362 })
2363 }
2364}
2365
2366impl TryFromAbstract<i64> for u32 {
2367 fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
2368 u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2369 value: format!("{value:?}"),
2370 to_type: "u32",
2371 })
2372 }
2373}
2374
2375impl TryFromAbstract<i64> for u64 {
2376 fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
2377 u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2378 value: format!("{value:?}"),
2379 to_type: "u64",
2380 })
2381 }
2382}
2383
2384impl TryFromAbstract<i64> for i64 {
2385 fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
2386 Ok(value)
2387 }
2388}
2389
2390impl TryFromAbstract<i64> for f32 {
2391 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
2392 let f = value as f32;
2393 Ok(f)
2397 }
2398}
2399
2400impl TryFromAbstract<f64> for f32 {
2401 fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
2402 let f = value as f32;
2403 if f.is_infinite() {
2404 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
2405 value: format!("{value:?}"),
2406 to_type: "f32",
2407 });
2408 }
2409 Ok(f)
2410 }
2411}
2412
2413impl TryFromAbstract<i64> for f64 {
2414 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
2415 let f = value as f64;
2416 Ok(f)
2420 }
2421}
2422
2423impl TryFromAbstract<f64> for f64 {
2424 fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
2425 Ok(value)
2426 }
2427}
2428
2429impl TryFromAbstract<f64> for i32 {
2430 fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> {
2431 Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "i32" })
2432 }
2433}
2434
2435impl TryFromAbstract<f64> for u32 {
2436 fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> {
2437 Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "u32" })
2438 }
2439}
2440
2441impl TryFromAbstract<f64> for i64 {
2442 fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> {
2443 Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "i64" })
2444 }
2445}
2446
2447impl TryFromAbstract<f64> for u64 {
2448 fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> {
2449 Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "u64" })
2450 }
2451}
2452
// Unit tests for `ConstantEvaluator`.
//
// Each test builds its own type, constant, override and global-expression
// arenas. It then runs `try_eval_and_append` with
// `Behavior::Wgsl(WgslRestrictions::Const(None))` and inspects the expressions
// written back into the global-expression arena.
#[cfg(test)]
mod tests {
    use std::vec;

    use crate::{
        Arena, Constant, Expression, Literal, ScalarKind, Type, TypeInner, UnaryOperator,
        UniqueArena, VectorSize,
    };

    use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};

    /// Negation and bitwise-not of an `i32` constant (4), and bitwise-not of a
    /// `vec2<i32>` constant (4, 8), must fold to literals / a `Compose` of
    /// literals.
    #[test]
    fn unary_op() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let scalar_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        // Scalar constant with value 4.
        let h = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        // Scalar constant with value 8.
        let h1 = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(8)), Default::default()),
            },
            Default::default(),
        );

        // Vector constant (4, 8), built from the two scalar initializers.
        let vec_h = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec![constants[h].init, constants[h1].init],
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let expr = global_expressions.append(Expression::Constant(h), Default::default());
        let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());

        // The expressions under test are not appended here. Instead,
        // `try_eval_and_append` evaluates them and returns the handle of the
        // appended result.
        let expr2 = Expression::Unary {
            op: UnaryOperator::Negate,
            expr,
        };

        let expr3 = Expression::Unary {
            op: UnaryOperator::BitwiseNot,
            expr,
        };

        let expr4 = Expression::Unary {
            op: UnaryOperator::BitwiseNot,
            expr: expr1,
        };

        // Built only after every pre-existing expression has been appended,
        // presumably so the tracker covers all of their handles.
        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
        };

        let res1 = solver
            .try_eval_and_append(expr2, Default::default())
            .unwrap();
        let res2 = solver
            .try_eval_and_append(expr3, Default::default())
            .unwrap();
        let res3 = solver
            .try_eval_and_append(expr4, Default::default())
            .unwrap();

        // `solver` is no longer used, so its `&mut` borrow of
        // `global_expressions` has ended and the arena can be read directly.
        assert_eq!(
            global_expressions[res1],
            Expression::Literal(Literal::I32(-4))
        );

        assert_eq!(
            global_expressions[res2],
            Expression::Literal(Literal::I32(!4))
        );

        let res3_inner = &global_expressions[res3];

        // The vector result must be a `Compose` of exactly (!4, !8).
        match *res3_inner {
            Expression::Compose {
                ref ty,
                ref components,
            } => {
                assert_eq!(*ty, vec_ty);
                let mut components_iter = components.iter().copied();
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::I32(!4))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::I32(!8))
                );
                assert!(components_iter.next().is_none());
            }
            _ => panic!("Expected vector"),
        }
    }

    /// Converting the `i32` constant 4 to `bool` must fold to `true`.
    #[test]
    fn cast() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let scalar_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let expr = global_expressions.append(Expression::Constant(h), Default::default());

        // A converting cast (`convert: Some(..)`) to bool.
        let root = Expression::As {
            expr,
            kind: ScalarKind::Bool,
            convert: Some(crate::BOOL_WIDTH),
        };

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
        };

        let res = solver
            .try_eval_and_append(root, Default::default())
            .unwrap();

        assert_eq!(
            global_expressions[res],
            Expression::Literal(Literal::Bool(true))
        );
    }

    /// Indexing a matrix constant with columns (0, 1, 2) and (3, 4, 5) by
    /// column 1 must fold to the column `Compose` (3, 4, 5). Indexing that
    /// result by 2 must fold to the literal 5.
    #[test]
    fn access() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        // Two columns, three rows of f32.
        let matrix_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Matrix {
                    columns: VectorSize::Bi,
                    rows: VectorSize::Tri,
                    scalar: crate::Scalar::F32,
                },
            },
            Default::default(),
        );

        // Column type: three-component f32 vector.
        let vec_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Tri,
                    scalar: crate::Scalar::F32,
                },
            },
            Default::default(),
        );

        let mut vec1_components = Vec::with_capacity(3);
        let mut vec2_components = Vec::with_capacity(3);

        // First column: 0.0, 1.0, 2.0.
        for i in 0..3 {
            let h = global_expressions.append(
                Expression::Literal(Literal::F32(i as f32)),
                Default::default(),
            );

            vec1_components.push(h)
        }

        // Second column: 3.0, 4.0, 5.0.
        for i in 3..6 {
            let h = global_expressions.append(
                Expression::Literal(Literal::F32(i as f32)),
                Default::default(),
            );

            vec2_components.push(h)
        }

        let vec1 = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec1_components,
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let vec2 = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec2_components,
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        // Matrix constant whose columns are the two vector initializers.
        let h = constants.append(
            Constant {
                name: None,
                ty: matrix_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: matrix_ty,
                        components: vec![constants[vec1].init, constants[vec2].init],
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let base = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
        };

        // Column 1 of the matrix.
        let root1 = Expression::AccessIndex { base, index: 1 };

        let res1 = solver
            .try_eval_and_append(root1, Default::default())
            .unwrap();

        // Component 2 of the already-evaluated column.
        let root2 = Expression::AccessIndex {
            base: res1,
            index: 2,
        };

        let res2 = solver
            .try_eval_and_append(root2, Default::default())
            .unwrap();

        match global_expressions[res1] {
            Expression::Compose {
                ref ty,
                ref components,
            } => {
                assert_eq!(*ty, vec_ty);
                let mut components_iter = components.iter().copied();
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(3.))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(4.))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(5.))
                );
                assert!(components_iter.next().is_none());
            }
            _ => panic!("Expected vector"),
        }

        assert_eq!(
            global_expressions[res2],
            Expression::Literal(Literal::F32(5.))
        );
    }

    /// A `vec2<i32>` composed of two references to the constant 4, once
    /// negated, must fold to a `Compose` whose components are all `-4`.
    #[test]
    fn compose_of_constants() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec2_i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: i32_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
        };

        // Compose of the same constant twice; the result feeds the negation.
        let solved_compose = solver
            .try_eval_and_append(
                Expression::Compose {
                    ty: vec2_i32_ty,
                    components: vec![h_expr, h_expr],
                },
                Default::default(),
            )
            .unwrap();
        let solved_negate = solver
            .try_eval_and_append(
                Expression::Unary {
                    op: UnaryOperator::Negate,
                    expr: solved_compose,
                },
                Default::default(),
            )
            .unwrap();

        let pass = match global_expressions[solved_negate] {
            Expression::Compose { ty, ref components } => {
                ty == vec2_i32_ty
                    && components.iter().all(|&component| {
                        let component = &global_expressions[component];
                        matches!(*component, Expression::Literal(Literal::I32(-4)))
                    })
            }
            _ => false,
        };
        if !pass {
            panic!("unexpected evaluation result")
        }
    }

    /// Splatting the constant 4 to `vec2<i32>` and negating it must produce a
    /// `Compose` (not a `Splat`) whose components are all `-4`.
    #[test]
    fn splat_of_constant() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec2_i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: i32_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
        };

        let solved_compose = solver
            .try_eval_and_append(
                Expression::Splat {
                    size: VectorSize::Bi,
                    value: h_expr,
                },
                Default::default(),
            )
            .unwrap();
        let solved_negate = solver
            .try_eval_and_append(
                Expression::Unary {
                    op: UnaryOperator::Negate,
                    expr: solved_compose,
                },
                Default::default(),
            )
            .unwrap();

        // The check below accepts only a `Compose` result.
        let pass = match global_expressions[solved_negate] {
            Expression::Compose { ty, ref components } => {
                ty == vec2_i32_ty
                    && components.iter().all(|&component| {
                        let component = &global_expressions[component];
                        matches!(*component, Expression::Literal(Literal::I32(-4)))
                    })
            }
            _ => false,
        };
        if !pass {
            panic!("unexpected evaluation result")
        }
    }
}