1use std::iter;
2
3use arrayvec::ArrayVec;
4
5use crate::{
6 arena::{Arena, Handle, HandleVec, UniqueArena},
7 ArraySize, BinaryOperator, Constant, Expression, Literal, Override, ScalarKind, Span, Type,
8 TypeInner, UnaryOperator,
9};
10
/// Lets a macro expansion define a new `macro_rules!` macro of its own.
///
/// Inside another macro's transcriber a literal `$` cannot be written for the
/// inner macro's metavariables. This helper defines a throwaway
/// `__with_dollar_sign` macro whose rules are `$body`, then immediately invokes
/// it with a single `$` token. `$body` can therefore accept the dollar sign as
/// a metavariable (by convention `($d:tt) => { ... $d name ... }`) and use
/// `$d` wherever the inner macro needs `$`.
macro_rules! with_dollar_sign {
    ($($body:tt)*) => {
        // Define the helper from the caller-supplied rules...
        macro_rules! __with_dollar_sign { $($body)* }
        // ...and hand it a bare `$` token to bind.
        __with_dollar_sign!($);
    }
}
22
/// Generates the machinery for evaluating a built-in component-wise over a
/// fixed set of scalar literal types.
///
/// Invocation shape:
///
/// ```text
/// gen_component_wise_extractor! {
///     function_name -> EnumName,
///     literals: [LiteralVariant => EnumVariant: rust_type, ...],
///     scalar_kinds: [ScalarKindVariant, ...],
/// }
/// ```
///
/// The expansion contains:
/// - `enum EnumName<const N: usize>`, with one variant per entry in `literals`.
///   Each variant holds `[rust_type; N]`, i.e. the `N` arguments of one call.
/// - `impl From<EnumName<1>> for Expression`, which turns a single result back
///   into an `Expression::Literal`.
/// - `fn function_name(eval, span, exprs, handler)`, which gathers `N` argument
///   expressions into an `EnumName<N>` and passes it to `handler`. Vectors whose
///   scalar kind is listed in `scalar_kinds` are handled by recursing once per
///   component.
/// - A `function_name!` convenience macro that applies one closure body to
///   every variant.
macro_rules! gen_component_wise_extractor {
    (
        $ident:ident -> $target:ident,
        literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
        scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
    ) => {
        // One variant per supported literal type; each carries all `N`
        // arguments of a single component-wise call.
        #[derive(Debug)]
        #[cfg_attr(test, derive(PartialEq))]
        enum $target<const N: usize> {
            $(
                #[doc = concat!(
                    "Maps to [`Literal::",
                    stringify!($literal),
                    "`]",
                )]
                $mapping([$ty; N]),
            )+
        }

        // A single-valued result converts straight back into a literal expression.
        impl From<$target<1>> for Expression {
            fn from(value: $target<1>) -> Self {
                match value {
                    $(
                        $target::$mapping([value]) => {
                            Expression::Literal(Literal::$literal(value))
                        }
                    )+
                }
            }
        }

        // Each argument is passed through `eval_zero_value_and_splat`. The
        // arguments must then be either all literals of the first argument's
        // variant, or all `Compose`s of the first argument's vector type.
        // Anything else yields `InvalidMathArg`. Panics if `N == 0`.
        #[doc = concat!(
            "Attempts to evaluate multiple `exprs` as a combined [`",
            stringify!($target),
            "`] to pass to `handler`. ",
        )]
        fn $ident<const N: usize, const M: usize, F>(
            eval: &mut ConstantEvaluator<'_>,
            span: Span,
            exprs: [Handle<Expression>; N],
            mut handler: F,
        ) -> Result<Handle<Expression>, ConstantEvaluatorError>
        where
            $target<M>: Into<Expression>,
            F: FnMut($target<N>) -> Result<$target<M>, ConstantEvaluatorError> + Clone,
        {
            assert!(N > 0);
            let err = ConstantEvaluatorError::InvalidMathArg;
            let mut exprs = exprs.into_iter();

            // Normalizes one argument and borrows the resulting expression.
            macro_rules! sanitize {
                ($expr:expr) => {
                    eval.eval_zero_value_and_splat($expr, span)
                        .map(|expr| &eval.expressions[expr])
                };
            }

            // The first argument decides which variant (or which vector type)
            // all remaining arguments must match.
            let new_expr = match sanitize!(exprs.next().unwrap())? {
                $(
                    // Scalar case: collect exactly `N` literals of the same variant.
                    &Expression::Literal(Literal::$literal(x)) => iter::once(Ok(x))
                        .chain(exprs.map(|expr| {
                            sanitize!(expr).and_then(|expr| match expr {
                                &Expression::Literal(Literal::$literal(x)) => Ok(x),
                                _ => Err(err.clone()),
                            })
                        }))
                        .collect::<Result<ArrayVec<_, N>, _>>()
                        .map(|a| a.into_inner().unwrap())
                        .map($target::$mapping)
                        .and_then(|comps| Ok(handler(comps)?.into())),
                )+
                // Vector case: flatten every argument into its scalar
                // components, evaluate per component position, then recompose.
                &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
                    &TypeInner::Vector { size, scalar } => match scalar.kind {
                        $(ScalarKind::$scalar_kind)|* => {
                            let first_ty = ty;
                            let mut component_groups =
                                ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
                            component_groups.push(crate::proc::flatten_compose(
                                first_ty,
                                components,
                                eval.expressions,
                                eval.types,
                            ).collect());
                            // Remaining arguments must be `Compose`s of an
                            // identical vector type.
                            component_groups.extend(
                                exprs
                                    .map(|expr| {
                                        sanitize!(expr).and_then(|expr| match expr {
                                            &Expression::Compose { ty, ref components }
                                                if &eval.types[ty].inner
                                                    == &eval.types[first_ty].inner =>
                                            {
                                                Ok(crate::proc::flatten_compose(
                                                    ty,
                                                    components,
                                                    eval.expressions,
                                                    eval.types,
                                                ).collect())
                                            }
                                            _ => Err(err.clone()),
                                        })
                                    })
                                    .collect::<Result<ArrayVec<_, { crate::VectorSize::MAX }>, _>>(
                                    )?,
                            );
                            let component_groups = component_groups.into_inner().unwrap();
                            let mut new_components =
                                ArrayVec::<_, { crate::VectorSize::MAX }>::new();
                            // For each component index, take that component from
                            // every argument and evaluate the scalar call recursively.
                            for idx in 0..(size as u8).into() {
                                let group = component_groups
                                    .iter()
                                    .map(|cs| cs.get(idx).cloned().ok_or(err.clone()))
                                    .collect::<Result<ArrayVec<_, N>, _>>()?
                                    .into_inner()
                                    .unwrap();
                                new_components.push($ident(
                                    eval,
                                    span,
                                    group,
                                    handler.clone(),
                                )?);
                            }
                            // The result reuses the first argument's vector type.
                            Ok(Expression::Compose {
                                ty: first_ty,
                                components: new_components.into_iter().collect(),
                            })
                        }
                        _ => return Err(err),
                    },
                    _ => return Err(err),
                },
                _ => return Err(err),
            }?;
            eval.register_evaluated_expr(new_expr, span)
        }

        // `with_dollar_sign!` supplies `$d` so the nested macro can declare its
        // own repetitions and metavariables.
        with_dollar_sign! {
            ($d:tt) => {
                #[allow(unused)]
                #[doc = concat!(
                    "A convenience macro for using the same RHS for each [`",
                    stringify!($target),
                    "`] variant in a call to [`",
                    stringify!($ident),
                    "`].",
                )]
                macro_rules! $ident {
                    (
                        $eval:expr,
                        $span:expr,
                        [$d ($d expr:expr),+ $d (,)?],
                        |$d ($d arg:ident),+| $d tt:tt
                    ) => {
                        $ident($eval, $span, [$d ($d expr),+], |args| match args {
                            $(
                                $target::$mapping([$d ($d arg),+]) => {
                                    let res = $d tt;
                                    Result::map(res, $target::$mapping)
                                },
                            )+
                        })
                    };
                }
            };
        }
    };
}
196
// `Scalar<N>` / `component_wise_scalar`: accepts every numeric literal type,
// and vectors whose scalar kind is float, abstract float, signed, unsigned or
// abstract int.
gen_component_wise_extractor! {
    component_wise_scalar -> Scalar,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        F32 => F32: f32,
        AbstractInt => AbstractInt: i64,
        U32 => U32: u32,
        I32 => I32: i32,
        U64 => U64: u64,
        I64 => I64: i64,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
        Sint,
        Uint,
        AbstractInt,
    ],
}
216
// `Float<N>` / `component_wise_float`: floating-point arguments only
// (abstract float or `f32`), as scalars or vectors.
gen_component_wise_extractor! {
    component_wise_float -> Float,
    literals: [
        AbstractFloat => Abstract: f64,
        F32 => F32: f32,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
    ],
}
228
// `ConcreteInt<N>` / `component_wise_concrete_int`: 32-bit signed and
// unsigned integers only; used by the bit-manipulation built-ins.
gen_component_wise_extractor! {
    component_wise_concrete_int -> ConcreteInt,
    literals: [
        U32 => U32: u32,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        Uint,
    ],
}
240
// `Signed<N>` / `component_wise_signed`: types that can hold negative values
// (abstract float/int, `f32`, `i32`); used by `sign`.
gen_component_wise_extractor! {
    component_wise_signed -> Signed,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        AbstractInt => AbstractInt: i64,
        F32 => F32: f32,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        AbstractInt,
        Float,
        AbstractFloat,
    ],
}
256
/// The front-end language being evaluated for, together with the kind of
/// context (module-level const/override, or function body) it is evaluated in.
#[derive(Debug)]
enum Behavior<'a> {
    /// Evaluating on behalf of the WGSL front end.
    Wgsl(WgslRestrictions<'a>),
    /// Evaluating on behalf of the GLSL front end.
    Glsl(GlslRestrictions<'a>),
}
262
263impl Behavior<'_> {
264 const fn has_runtime_restrictions(&self) -> bool {
266 matches!(
267 self,
268 &Behavior::Wgsl(WgslRestrictions::Runtime(_))
269 | &Behavior::Glsl(GlslRestrictions::Runtime(_))
270 )
271 }
272}
273
/// Folds expressions at compile time where possible and appends the results
/// to an expression arena.
///
/// Depending on [`Behavior`], `expressions` is either the module's global
/// expression arena (module-scope evaluation) or a function's local arena.
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
    /// Front-end language rules plus the kind of context being evaluated in.
    behavior: Behavior<'a>,

    /// The module's type arena; helpers such as `splat` and `swizzle` insert
    /// new vector types into it.
    types: &'a mut UniqueArena<Type>,

    /// The module's constants; a `Constant` expression resolves to its `init`.
    constants: &'a Arena<Constant>,

    /// The module's pipeline-overridable constants.
    overrides: &'a Arena<Override>,

    /// The arena that evaluated (or deferred) expressions are appended to.
    expressions: &'a mut Arena<Expression>,

    /// The [`ExpressionKind`] of every expression already in `expressions`.
    expression_kind_tracker: &'a mut ExpressionKindTracker,
}
316
/// Which expression kinds a WGSL evaluation context accepts.
#[derive(Debug)]
enum WgslRestrictions<'a> {
    /// Only const-expressions; override- and runtime-expressions are errors.
    /// `Some` when evaluating into a function's arena, `None` at module scope.
    Const(Option<FunctionLocalData<'a>>),
    /// Module scope where override-expressions are also appended (unevaluated).
    Override,
    /// Function body: override- and runtime-expressions are appended as-is.
    Runtime(FunctionLocalData<'a>),
}
329
/// Which expression kinds a GLSL evaluation context accepts.
#[derive(Debug)]
enum GlslRestrictions<'a> {
    /// Module scope: only const-expressions are accepted.
    Const,
    /// Function body: runtime-expressions are appended as-is.
    Runtime(FunctionLocalData<'a>),
}
339
/// Extra state needed when evaluating into a function's local expression arena.
#[derive(Debug)]
struct FunctionLocalData<'a> {
    /// The module's global expression arena. Constant initializers live here
    /// and are copied into the local arena when referenced.
    global_expressions: &'a Arena<Expression>,
    /// NOTE(review): presumably used when appending expressions to emit them
    /// into `block`; confirm against `append_expr`.
    emitter: &'a mut super::Emitter,
    /// The function block being built alongside the expressions.
    block: &'a mut crate::Block,
}
347
/// Classification of an expression by how early its value can be known.
///
/// Variants are declared from "most constant" to "least constant". The derived
/// `Ord` is relied upon: the kind of a compound expression is the `max` of its
/// parts. Do not reorder.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
    /// A const-expression that the evaluator is expected to fold; evaluation
    /// errors are always reported.
    ImplConst,
    /// A const-expression that may not be foldable here. In function bodies,
    /// `try_eval_and_append` may fall back to appending it as a runtime
    /// expression.
    Const,
    /// Depends on pipeline-overridable constants.
    Override,
    /// Only known at runtime.
    Runtime,
}
356
/// Records the [`ExpressionKind`] of each expression in one arena, indexed by
/// the expression's handle.
#[derive(Debug)]
pub struct ExpressionKindTracker {
    inner: HandleVec<Expression, ExpressionKind>,
}
361
362impl ExpressionKindTracker {
363 pub const fn new() -> Self {
364 Self {
365 inner: HandleVec::new(),
366 }
367 }
368
369 pub fn force_non_const(&mut self, value: Handle<Expression>) {
371 self.inner[value] = ExpressionKind::Runtime;
372 }
373
374 pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
375 self.inner.insert(value, expr_type);
376 }
377
378 pub fn is_const(&self, h: Handle<Expression>) -> bool {
379 matches!(
380 self.type_of(h),
381 ExpressionKind::Const | ExpressionKind::ImplConst
382 )
383 }
384
385 pub fn is_impl_const(&self, h: Handle<Expression>) -> bool {
387 matches!(self.type_of(h), ExpressionKind::ImplConst)
388 }
389
390 pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
391 matches!(
392 self.type_of(h),
393 ExpressionKind::Const | ExpressionKind::Override | ExpressionKind::ImplConst
394 )
395 }
396
397 fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
398 self.inner[value]
399 }
400
401 pub fn from_arena(arena: &Arena<Expression>) -> Self {
402 let mut tracker = Self {
403 inner: HandleVec::with_capacity(arena.len()),
404 };
405 for (handle, expr) in arena.iter() {
406 tracker
407 .inner
408 .insert(handle, tracker.type_of_with_expr(expr));
409 }
410 tracker
411 }
412
413 fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
414 use crate::MathFunction as Mf;
415 match *expr {
416 Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
417 ExpressionKind::ImplConst
418 }
419 Expression::Override(_) => ExpressionKind::Override,
420 Expression::Compose { ref components, .. } => {
421 let mut expr_type = ExpressionKind::ImplConst;
422 for component in components {
423 expr_type = expr_type.max(self.type_of(*component))
424 }
425 expr_type
426 }
427 Expression::Splat { value, .. } => self.type_of(value),
428 Expression::AccessIndex { base, .. } => self.type_of(base),
429 Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
430 Expression::Swizzle { vector, .. } => self.type_of(vector),
431 Expression::Unary { expr, .. } => self.type_of(expr),
432 Expression::Binary { left, right, .. } => self
433 .type_of(left)
434 .max(self.type_of(right))
435 .max(ExpressionKind::Const),
436 Expression::Math {
437 fun,
438 arg,
439 arg1,
440 arg2,
441 arg3,
442 } => self
443 .type_of(arg)
444 .max(
445 arg1.map(|arg| self.type_of(arg))
446 .unwrap_or(ExpressionKind::Const),
447 )
448 .max(
449 arg2.map(|arg| self.type_of(arg))
450 .unwrap_or(ExpressionKind::Const),
451 )
452 .max(
453 arg3.map(|arg| self.type_of(arg))
454 .unwrap_or(ExpressionKind::Const),
455 )
456 .max(
457 if matches!(
458 fun,
459 Mf::Dot
460 | Mf::Outer
461 | Mf::Cross
462 | Mf::Distance
463 | Mf::Length
464 | Mf::Normalize
465 | Mf::FaceForward
466 | Mf::Reflect
467 | Mf::Refract
468 | Mf::Ldexp
469 | Mf::Modf
470 | Mf::Mix
471 | Mf::Frexp
472 ) {
473 ExpressionKind::Const
474 } else {
475 ExpressionKind::ImplConst
476 },
477 ),
478 Expression::As { convert, expr, .. } => self.type_of(expr).max(if convert.is_some() {
479 ExpressionKind::ImplConst
480 } else {
481 ExpressionKind::Const
482 }),
483 Expression::Select {
484 condition,
485 accept,
486 reject,
487 } => self
488 .type_of(condition)
489 .max(self.type_of(accept))
490 .max(self.type_of(reject))
491 .max(ExpressionKind::Const),
492 Expression::Relational { argument, .. } => self.type_of(argument),
493 Expression::ArrayLength(expr) => self.type_of(expr),
494 _ => ExpressionKind::Runtime,
495 }
496 }
497}
498
/// Reasons why [`ConstantEvaluator`] could not evaluate an expression.
///
/// Most variants name an expression form that may not appear in, or is not
/// supported by, constant evaluation. The `#[error]` strings are the
/// user-facing messages.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantEvaluatorError {
    #[error("Constants cannot access function arguments")]
    FunctionArg,
    #[error("Constants cannot access global variables")]
    GlobalVariable,
    #[error("Constants cannot access local variables")]
    LocalVariable,
    #[error("Cannot get the array length of a non array type")]
    InvalidArrayLengthArg,
    #[error("Constants cannot get the array length of a dynamically sized array")]
    ArrayLengthDynamic,
    #[error("Cannot call arrayLength on array sized by override-expression")]
    ArrayLengthOverridden,
    #[error("Constants cannot call functions")]
    Call,
    #[error("Constants don't support workGroupUniformLoad")]
    WorkGroupUniformLoadResult,
    #[error("Constants don't support atomic functions")]
    Atomic,
    #[error("Constants don't support derivative functions")]
    Derivative,
    #[error("Constants don't support load expressions")]
    Load,
    #[error("Constants don't support image expressions")]
    ImageExpression,
    #[error("Constants don't support ray query expressions")]
    RayQueryExpression,
    #[error("Constants don't support subgroup expressions")]
    SubgroupExpression,
    #[error("Cannot access the type")]
    InvalidAccessBase,
    #[error("Cannot access at the index")]
    InvalidAccessIndex,
    #[error("Cannot access with index of type")]
    InvalidAccessIndexTy,
    /// `ArrayLength` in a WGSL const context.
    #[error("Constants don't support array length expressions")]
    ArrayLength,
    #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
    InvalidCastArg { from: String, to: String },
    #[error("Cannot apply the unary op to the argument")]
    InvalidUnaryOpArg,
    /// In function bodies, `try_eval_and_append` may defer a `Const`
    /// expression that fails with this error to runtime.
    #[error("Cannot apply the binary op to the arguments")]
    InvalidBinaryOpArgs,
    #[error("Cannot apply math function to type")]
    InvalidMathArg,
    /// Payload: function, expected argument count, supplied argument count.
    #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
    InvalidMathArgCount(crate::MathFunction, usize, usize),
    #[error("value of `low` is greater than `high` for clamp built-in function")]
    InvalidClamp,
    #[error("Splat is defined only on scalar values")]
    SplatScalarOnly,
    #[error("Can only swizzle vector constants")]
    SwizzleVectorOnly,
    #[error("swizzle component not present in source expression")]
    SwizzleOutOfBounds,
    #[error("Type is not constructible")]
    TypeNotConstructible,
    /// An operand was not a const-expression (see `check`).
    #[error("Subexpression(s) are not constant")]
    SubexpressionsAreNotConstant,
    /// A form with no constant-folding implementation here. In function
    /// bodies, `try_eval_and_append` may defer a `Const` expression failing
    /// with this error to runtime.
    #[error("Not implemented as constant expression: {0}")]
    NotImplemented(String),
    /// An operation overflowed; the payload names the operation.
    #[error("{0} operation overflowed")]
    Overflow(String),
    #[error(
        "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
    )]
    AutomaticConversionLossy {
        value: String,
        to_type: &'static str,
    },
    #[error("abstract floating-point values cannot be automatically converted to integers")]
    AutomaticConversionFloatToInt { to_type: &'static str },
    #[error("Division by zero")]
    DivisionByZero,
    #[error("Remainder by zero")]
    RemainderByZero,
    #[error("RHS of shift operation is greater than or equal to 32")]
    ShiftedMoreThan32Bits,
    /// A literal value failed validation.
    #[error(transparent)]
    Literal(#[from] crate::valid::LiteralError),
    /// An `Override` expression reached constant evaluation.
    #[error("Can't use pipeline-overridable constants in const-expressions")]
    Override,
    /// A runtime-expression appeared where only const/override ones are accepted.
    #[error("Unexpected runtime-expression")]
    RuntimeExpr,
    /// An override-expression appeared in a WGSL const-only context.
    #[error("Unexpected override-expression")]
    OverrideExpr,
}
588
589impl<'a> ConstantEvaluator<'a> {
590 pub fn for_wgsl_module(
595 module: &'a mut crate::Module,
596 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
597 in_override_ctx: bool,
598 ) -> Self {
599 Self::for_module(
600 Behavior::Wgsl(if in_override_ctx {
601 WgslRestrictions::Override
602 } else {
603 WgslRestrictions::Const(None)
604 }),
605 module,
606 global_expression_kind_tracker,
607 )
608 }
609
610 pub fn for_glsl_module(
615 module: &'a mut crate::Module,
616 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
617 ) -> Self {
618 Self::for_module(
619 Behavior::Glsl(GlslRestrictions::Const),
620 module,
621 global_expression_kind_tracker,
622 )
623 }
624
625 fn for_module(
626 behavior: Behavior<'a>,
627 module: &'a mut crate::Module,
628 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
629 ) -> Self {
630 Self {
631 behavior,
632 types: &mut module.types,
633 constants: &module.constants,
634 overrides: &module.overrides,
635 expressions: &mut module.global_expressions,
636 expression_kind_tracker: global_expression_kind_tracker,
637 }
638 }
639
640 pub fn for_wgsl_function(
645 module: &'a mut crate::Module,
646 expressions: &'a mut Arena<Expression>,
647 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
648 emitter: &'a mut super::Emitter,
649 block: &'a mut crate::Block,
650 is_const: bool,
651 ) -> Self {
652 let local_data = FunctionLocalData {
653 global_expressions: &module.global_expressions,
654 emitter,
655 block,
656 };
657 Self {
658 behavior: Behavior::Wgsl(if is_const {
659 WgslRestrictions::Const(Some(local_data))
660 } else {
661 WgslRestrictions::Runtime(local_data)
662 }),
663 types: &mut module.types,
664 constants: &module.constants,
665 overrides: &module.overrides,
666 expressions,
667 expression_kind_tracker: local_expression_kind_tracker,
668 }
669 }
670
671 pub fn for_glsl_function(
676 module: &'a mut crate::Module,
677 expressions: &'a mut Arena<Expression>,
678 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
679 emitter: &'a mut super::Emitter,
680 block: &'a mut crate::Block,
681 ) -> Self {
682 Self {
683 behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
684 global_expressions: &module.global_expressions,
685 emitter,
686 block,
687 })),
688 types: &mut module.types,
689 constants: &module.constants,
690 overrides: &module.overrides,
691 expressions,
692 expression_kind_tracker: local_expression_kind_tracker,
693 }
694 }
695
696 pub fn to_ctx(&self) -> crate::proc::GlobalCtx {
697 crate::proc::GlobalCtx {
698 types: self.types,
699 constants: self.constants,
700 overrides: self.overrides,
701 global_expressions: match self.function_local_data() {
702 Some(data) => data.global_expressions,
703 None => self.expressions,
704 },
705 }
706 }
707
708 fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
709 if !self.expression_kind_tracker.is_const(expr) {
710 log::debug!("check: SubexpressionsAreNotConstant");
711 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
712 }
713 Ok(())
714 }
715
716 fn check_and_get(
717 &mut self,
718 expr: Handle<Expression>,
719 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
720 match self.expressions[expr] {
721 Expression::Constant(c) => {
722 if let Some(function_local_data) = self.function_local_data() {
725 self.copy_from(
727 self.constants[c].init,
728 function_local_data.global_expressions,
729 )
730 } else {
731 Ok(self.constants[c].init)
733 }
734 }
735 _ => {
736 self.check(expr)?;
737 Ok(expr)
738 }
739 }
740 }
741
742 pub fn try_eval_and_append(
766 &mut self,
767 expr: Expression,
768 span: Span,
769 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
770 match self.expression_kind_tracker.type_of_with_expr(&expr) {
771 ExpressionKind::ImplConst => self.try_eval_and_append_impl(&expr, span),
772 ExpressionKind::Const => {
773 let eval_result = self.try_eval_and_append_impl(&expr, span);
774 if self.behavior.has_runtime_restrictions()
779 && matches!(
780 eval_result,
781 Err(ConstantEvaluatorError::NotImplemented(_)
782 | ConstantEvaluatorError::InvalidBinaryOpArgs,)
783 )
784 {
785 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
786 } else {
787 eval_result
788 }
789 }
790 ExpressionKind::Override => match self.behavior {
791 Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
792 Ok(self.append_expr(expr, span, ExpressionKind::Override))
793 }
794 Behavior::Wgsl(WgslRestrictions::Const(_)) => {
795 Err(ConstantEvaluatorError::OverrideExpr)
796 }
797 Behavior::Glsl(_) => {
798 unreachable!()
799 }
800 },
801 ExpressionKind::Runtime => {
802 if self.behavior.has_runtime_restrictions() {
803 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
804 } else {
805 Err(ConstantEvaluatorError::RuntimeExpr)
806 }
807 }
808 }
809 }
810
811 const fn is_global_arena(&self) -> bool {
813 matches!(
814 self.behavior,
815 Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
816 | Behavior::Glsl(GlslRestrictions::Const)
817 )
818 }
819
820 const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
821 match self.behavior {
822 Behavior::Wgsl(
823 WgslRestrictions::Runtime(ref function_local_data)
824 | WgslRestrictions::Const(Some(ref function_local_data)),
825 )
826 | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
827 Some(function_local_data)
828 }
829 _ => None,
830 }
831 }
832
/// Evaluates `expr` as a const-expression and appends the result to
/// `self.expressions`.
///
/// Operands are first resolved with `check_and_get`, which rejects non-const
/// operands and replaces `Constant`s with their initializers. The operation
/// is then folded by the matching helper (`access`, `swizzle`, `unary_op`,
/// `binary_op`, `math`, `cast`, `array_length`).
///
/// # Errors
///
/// Returns the variant naming any expression form that cannot appear in a
/// const-expression. Returns `NotImplemented` for bitcasts, `select` and
/// relational functions. Errors from the helpers are propagated.
fn try_eval_and_append_impl(
    &mut self,
    expr: &Expression,
    span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
    log::trace!("try_eval_and_append: {:?}", expr);
    match *expr {
        // At module scope the initializer already lives in this arena, so its
        // handle is returned directly. This guarded arm must stay before the
        // general `Constant(_)` arm below.
        Expression::Constant(c) if self.is_global_arena() => {
            Ok(self.constants[c].init)
        }
        Expression::Override(_) => Err(ConstantEvaluatorError::Override),
        // Already fully evaluated; just record it.
        Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
            self.register_evaluated_expr(expr.clone(), span)
        }
        Expression::Compose { ty, ref components } => {
            let components = components
                .iter()
                .map(|component| self.check_and_get(*component))
                .collect::<Result<Vec<_>, _>>()?;
            self.register_evaluated_expr(Expression::Compose { ty, components }, span)
        }
        Expression::Splat { size, value } => {
            let value = self.check_and_get(value)?;
            self.register_evaluated_expr(Expression::Splat { size, value }, span)
        }
        Expression::AccessIndex { base, index } => {
            let base = self.check_and_get(base)?;

            self.access(base, index as usize, span)
        }
        Expression::Access { base, index } => {
            let base = self.check_and_get(base)?;
            let index = self.check_and_get(index)?;

            self.access(base, self.constant_index(index)?, span)
        }
        Expression::Swizzle {
            size,
            vector,
            pattern,
        } => {
            let vector = self.check_and_get(vector)?;

            self.swizzle(size, span, vector, pattern)
        }
        Expression::Unary { expr, op } => {
            let expr = self.check_and_get(expr)?;

            self.unary_op(op, expr, span)
        }
        Expression::Binary { left, right, op } => {
            let left = self.check_and_get(left)?;
            let right = self.check_and_get(right)?;

            self.binary_op(op, left, right, span)
        }
        Expression::Math {
            fun,
            arg,
            arg1,
            arg2,
            arg3,
        } => {
            let arg = self.check_and_get(arg)?;
            let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
            let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
            let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;

            self.math(arg, arg1, arg2, arg3, fun, span)
        }
        Expression::As {
            convert,
            expr,
            kind,
        } => {
            let expr = self.check_and_get(expr)?;

            // `convert: None` is a bitcast, which is not folded here.
            match convert {
                Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
                None => Err(ConstantEvaluatorError::NotImplemented(
                    "bitcast built-in function".into(),
                )),
            }
        }
        Expression::Select { .. } => Err(ConstantEvaluatorError::NotImplemented(
            "select built-in function".into(),
        )),
        Expression::Relational { fun, .. } => Err(ConstantEvaluatorError::NotImplemented(
            format!("{fun:?} built-in function"),
        )),
        // GLSL folds array lengths; WGSL rejects them in const contexts.
        Expression::ArrayLength(expr) => match self.behavior {
            Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
            Behavior::Glsl(_) => {
                let expr = self.check_and_get(expr)?;
                self.array_length(expr, span)
            }
        },
        // Everything below can never be part of a const-expression.
        Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
        Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
        Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
        Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
        Expression::WorkGroupUniformLoadResult { .. } => {
            Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
        }
        Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
        Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
        Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
        Expression::ImageSample { .. }
        | Expression::ImageLoad { .. }
        | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
        Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => {
            Err(ConstantEvaluatorError::RayQueryExpression)
        }
        Expression::SubgroupBallotResult { .. } => {
            Err(ConstantEvaluatorError::SubgroupExpression)
        }
        Expression::SubgroupOperationResult { .. } => {
            Err(ConstantEvaluatorError::SubgroupExpression)
        }
    }
}
956
957 fn splat(
970 &mut self,
971 value: Handle<Expression>,
972 size: crate::VectorSize,
973 span: Span,
974 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
975 match self.expressions[value] {
976 Expression::Literal(literal) => {
977 let scalar = literal.scalar();
978 let ty = self.types.insert(
979 Type {
980 name: None,
981 inner: TypeInner::Vector { size, scalar },
982 },
983 span,
984 );
985 let expr = Expression::Compose {
986 ty,
987 components: vec![value; size as usize],
988 };
989 self.register_evaluated_expr(expr, span)
990 }
991 Expression::ZeroValue(ty) => {
992 let inner = match self.types[ty].inner {
993 TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
994 _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
995 };
996 let res_ty = self.types.insert(Type { name: None, inner }, span);
997 let expr = Expression::ZeroValue(res_ty);
998 self.register_evaluated_expr(expr, span)
999 }
1000 _ => Err(ConstantEvaluatorError::SplatScalarOnly),
1001 }
1002 }
1003
1004 fn swizzle(
1005 &mut self,
1006 size: crate::VectorSize,
1007 span: Span,
1008 src_constant: Handle<Expression>,
1009 pattern: [crate::SwizzleComponent; 4],
1010 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1011 let mut get_dst_ty = |ty| match self.types[ty].inner {
1012 TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
1013 Type {
1014 name: None,
1015 inner: TypeInner::Vector { size, scalar },
1016 },
1017 span,
1018 )),
1019 _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1020 };
1021
1022 match self.expressions[src_constant] {
1023 Expression::ZeroValue(ty) => {
1024 let dst_ty = get_dst_ty(ty)?;
1025 let expr = Expression::ZeroValue(dst_ty);
1026 self.register_evaluated_expr(expr, span)
1027 }
1028 Expression::Splat { value, .. } => {
1029 let expr = Expression::Splat { size, value };
1030 self.register_evaluated_expr(expr, span)
1031 }
1032 Expression::Compose { ty, ref components } => {
1033 let dst_ty = get_dst_ty(ty)?;
1034
1035 let mut flattened = [src_constant; 4]; let len =
1037 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1038 .zip(flattened.iter_mut())
1039 .map(|(component, elt)| *elt = component)
1040 .count();
1041 let flattened = &flattened[..len];
1042
1043 let swizzled_components = pattern[..size as usize]
1044 .iter()
1045 .map(|&sc| {
1046 let sc = sc as usize;
1047 if let Some(elt) = flattened.get(sc) {
1048 Ok(*elt)
1049 } else {
1050 Err(ConstantEvaluatorError::SwizzleOutOfBounds)
1051 }
1052 })
1053 .collect::<Result<Vec<Handle<Expression>>, _>>()?;
1054 let expr = Expression::Compose {
1055 ty: dst_ty,
1056 components: swizzled_components,
1057 };
1058 self.register_evaluated_expr(expr, span)
1059 }
1060 _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1061 }
1062 }
1063
1064 fn math(
1065 &mut self,
1066 arg: Handle<Expression>,
1067 arg1: Option<Handle<Expression>>,
1068 arg2: Option<Handle<Expression>>,
1069 arg3: Option<Handle<Expression>>,
1070 fun: crate::MathFunction,
1071 span: Span,
1072 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1073 let expected = fun.argument_count();
1074 let given = Some(arg)
1075 .into_iter()
1076 .chain(arg1)
1077 .chain(arg2)
1078 .chain(arg3)
1079 .count();
1080 if expected != given {
1081 return Err(ConstantEvaluatorError::InvalidMathArgCount(
1082 fun, expected, given,
1083 ));
1084 }
1085
1086 match fun {
1088 crate::MathFunction::Abs => {
1090 component_wise_scalar(self, span, [arg], |args| match args {
1091 Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
1092 Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
1093 Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.abs()])),
1094 Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
1095 Scalar::U32([e]) => Ok(Scalar::U32([e])), Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
1097 Scalar::U64([e]) => Ok(Scalar::U64([e])),
1098 })
1099 }
1100 crate::MathFunction::Min => {
1101 component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1102 Ok([e1.min(e2)])
1103 })
1104 }
1105 crate::MathFunction::Max => {
1106 component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1107 Ok([e1.max(e2)])
1108 })
1109 }
1110 crate::MathFunction::Clamp => {
1111 component_wise_scalar!(
1112 self,
1113 span,
1114 [arg, arg1.unwrap(), arg2.unwrap()],
1115 |e, low, high| {
1116 if low > high {
1117 Err(ConstantEvaluatorError::InvalidClamp)
1118 } else {
1119 Ok([e.clamp(low, high)])
1120 }
1121 }
1122 )
1123 }
1124 crate::MathFunction::Saturate => {
1125 component_wise_float!(self, span, [arg], |e| { Ok([e.clamp(0., 1.)]) })
1126 }
1127
1128 crate::MathFunction::Cos => {
1130 component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
1131 }
1132 crate::MathFunction::Cosh => {
1133 component_wise_float!(self, span, [arg], |e| { Ok([e.cosh()]) })
1134 }
1135 crate::MathFunction::Sin => {
1136 component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
1137 }
1138 crate::MathFunction::Sinh => {
1139 component_wise_float!(self, span, [arg], |e| { Ok([e.sinh()]) })
1140 }
1141 crate::MathFunction::Tan => {
1142 component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
1143 }
1144 crate::MathFunction::Tanh => {
1145 component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
1146 }
1147 crate::MathFunction::Acos => {
1148 component_wise_float!(self, span, [arg], |e| { Ok([e.acos()]) })
1149 }
1150 crate::MathFunction::Asin => {
1151 component_wise_float!(self, span, [arg], |e| { Ok([e.asin()]) })
1152 }
1153 crate::MathFunction::Atan => {
1154 component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
1155 }
1156 crate::MathFunction::Asinh => {
1157 component_wise_float!(self, span, [arg], |e| { Ok([e.asinh()]) })
1158 }
1159 crate::MathFunction::Acosh => {
1160 component_wise_float!(self, span, [arg], |e| { Ok([e.acosh()]) })
1161 }
1162 crate::MathFunction::Atanh => {
1163 component_wise_float!(self, span, [arg], |e| { Ok([e.atanh()]) })
1164 }
1165 crate::MathFunction::Radians => {
1166 component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
1167 }
1168 crate::MathFunction::Degrees => {
1169 component_wise_float!(self, span, [arg], |e| { Ok([e.to_degrees()]) })
1170 }
1171
1172 crate::MathFunction::Ceil => {
1174 component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
1175 }
1176 crate::MathFunction::Floor => {
1177 component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
1178 }
1179 crate::MathFunction::Round => {
1180 fn round_ties_even(x: f64) -> f64 {
1187 let i = x as i64;
1188 let f = (x - i as f64).abs();
1189 if f == 0.5 {
1190 if i & 1 == 1 {
1191 (x.abs() + 0.5).copysign(x)
1193 } else {
1194 (x.abs() - 0.5).copysign(x)
1195 }
1196 } else {
1197 x.round()
1198 }
1199 }
1200 component_wise_float(self, span, [arg], |e| match e {
1201 Float::Abstract([e]) => Ok(Float::Abstract([round_ties_even(e)])),
1202 Float::F32([e]) => Ok(Float::F32([(round_ties_even(e as f64) as f32)])),
1203 })
1204 }
1205 crate::MathFunction::Fract => {
1206 component_wise_float!(self, span, [arg], |e| {
1207 Ok([e - e.floor()])
1210 })
1211 }
1212 crate::MathFunction::Trunc => {
1213 component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
1214 }
1215
1216 crate::MathFunction::Exp => {
1218 component_wise_float!(self, span, [arg], |e| { Ok([e.exp()]) })
1219 }
1220 crate::MathFunction::Exp2 => {
1221 component_wise_float!(self, span, [arg], |e| { Ok([e.exp2()]) })
1222 }
1223 crate::MathFunction::Log => {
1224 component_wise_float!(self, span, [arg], |e| { Ok([e.ln()]) })
1225 }
1226 crate::MathFunction::Log2 => {
1227 component_wise_float!(self, span, [arg], |e| { Ok([e.log2()]) })
1228 }
1229 crate::MathFunction::Pow => {
1230 component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1231 Ok([e1.powf(e2)])
1232 })
1233 }
1234
1235 crate::MathFunction::Sign => {
1237 component_wise_signed!(self, span, [arg], |e| { Ok([e.signum()]) })
1238 }
1239 crate::MathFunction::Fma => {
1240 component_wise_float!(
1241 self,
1242 span,
1243 [arg, arg1.unwrap(), arg2.unwrap()],
1244 |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
1245 )
1246 }
1247 crate::MathFunction::Step => {
1248 component_wise_float!(self, span, [arg, arg1.unwrap()], |edge, x| {
1249 Ok([if edge <= x { 1.0 } else { 0.0 }])
1250 })
1251 }
1252 crate::MathFunction::Sqrt => {
1253 component_wise_float!(self, span, [arg], |e| { Ok([e.sqrt()]) })
1254 }
1255 crate::MathFunction::InverseSqrt => {
1256 component_wise_float!(self, span, [arg], |e| { Ok([1. / e.sqrt()]) })
1257 }
1258
1259 crate::MathFunction::CountTrailingZeros => {
1261 component_wise_concrete_int!(self, span, [arg], |e| {
1262 #[allow(clippy::useless_conversion)]
1263 Ok([e
1264 .trailing_zeros()
1265 .try_into()
1266 .expect("bit count overflowed 32 bits, somehow!?")])
1267 })
1268 }
1269 crate::MathFunction::CountLeadingZeros => {
1270 component_wise_concrete_int!(self, span, [arg], |e| {
1271 #[allow(clippy::useless_conversion)]
1272 Ok([e
1273 .leading_zeros()
1274 .try_into()
1275 .expect("bit count overflowed 32 bits, somehow!?")])
1276 })
1277 }
1278 crate::MathFunction::CountOneBits => {
1279 component_wise_concrete_int!(self, span, [arg], |e| {
1280 #[allow(clippy::useless_conversion)]
1281 Ok([e
1282 .count_ones()
1283 .try_into()
1284 .expect("bit count overflowed 32 bits, somehow!?")])
1285 })
1286 }
1287 crate::MathFunction::ReverseBits => {
1288 component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
1289 }
1290 crate::MathFunction::FirstTrailingBit => {
1291 component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
1292 }
1293 crate::MathFunction::FirstLeadingBit => {
1294 component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
1295 }
1296
1297 fun => Err(ConstantEvaluatorError::NotImplemented(format!(
1298 "{fun:?} built-in function"
1299 ))),
1300 }
1301 }
1302
1303 fn array_length(
1304 &mut self,
1305 array: Handle<Expression>,
1306 span: Span,
1307 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1308 match self.expressions[array] {
1309 Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
1310 match self.types[ty].inner {
1311 TypeInner::Array { size, .. } => match size {
1312 ArraySize::Constant(len) => {
1313 let expr = Expression::Literal(Literal::U32(len.get()));
1314 self.register_evaluated_expr(expr, span)
1315 }
1316 ArraySize::Pending(_) => Err(ConstantEvaluatorError::ArrayLengthOverridden),
1317 ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
1318 },
1319 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
1320 }
1321 }
1322 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
1323 }
1324 }
1325
1326 fn access(
1327 &mut self,
1328 base: Handle<Expression>,
1329 index: usize,
1330 span: Span,
1331 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1332 match self.expressions[base] {
1333 Expression::ZeroValue(ty) => {
1334 let ty_inner = &self.types[ty].inner;
1335 let components = ty_inner
1336 .components()
1337 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
1338
1339 if index >= components as usize {
1340 Err(ConstantEvaluatorError::InvalidAccessBase)
1341 } else {
1342 let ty_res = ty_inner
1343 .component_type(index)
1344 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
1345 let ty = match ty_res {
1346 crate::proc::TypeResolution::Handle(ty) => ty,
1347 crate::proc::TypeResolution::Value(inner) => {
1348 self.types.insert(Type { name: None, inner }, span)
1349 }
1350 };
1351 self.register_evaluated_expr(Expression::ZeroValue(ty), span)
1352 }
1353 }
1354 Expression::Splat { size, value } => {
1355 if index >= size as usize {
1356 Err(ConstantEvaluatorError::InvalidAccessBase)
1357 } else {
1358 Ok(value)
1359 }
1360 }
1361 Expression::Compose { ty, ref components } => {
1362 let _ = self.types[ty]
1363 .inner
1364 .components()
1365 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
1366
1367 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1368 .nth(index)
1369 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
1370 }
1371 _ => Err(ConstantEvaluatorError::InvalidAccessBase),
1372 }
1373 }
1374
1375 fn constant_index(&self, expr: Handle<Expression>) -> Result<usize, ConstantEvaluatorError> {
1376 match self.expressions[expr] {
1377 Expression::ZeroValue(ty)
1378 if matches!(
1379 self.types[ty].inner,
1380 TypeInner::Scalar(crate::Scalar {
1381 kind: ScalarKind::Uint,
1382 ..
1383 })
1384 ) =>
1385 {
1386 Ok(0)
1387 }
1388 Expression::Literal(Literal::U32(index)) => Ok(index as usize),
1389 _ => Err(ConstantEvaluatorError::InvalidAccessIndexTy),
1390 }
1391 }
1392
1393 fn eval_zero_value_and_splat(
1400 &mut self,
1401 expr: Handle<Expression>,
1402 span: Span,
1403 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1404 match self.expressions[expr] {
1405 Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
1406 Expression::Splat { size, value } => self.splat(value, size, span),
1407 _ => Ok(expr),
1408 }
1409 }
1410
1411 fn eval_zero_value(
1417 &mut self,
1418 expr: Handle<Expression>,
1419 span: Span,
1420 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1421 match self.expressions[expr] {
1422 Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
1423 _ => Ok(expr),
1424 }
1425 }
1426
1427 fn eval_zero_value_impl(
1433 &mut self,
1434 ty: Handle<Type>,
1435 span: Span,
1436 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1437 match self.types[ty].inner {
1438 TypeInner::Scalar(scalar) => {
1439 let expr = Expression::Literal(
1440 Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
1441 );
1442 self.register_evaluated_expr(expr, span)
1443 }
1444 TypeInner::Vector { size, scalar } => {
1445 let scalar_ty = self.types.insert(
1446 Type {
1447 name: None,
1448 inner: TypeInner::Scalar(scalar),
1449 },
1450 span,
1451 );
1452 let el = self.eval_zero_value_impl(scalar_ty, span)?;
1453 let expr = Expression::Compose {
1454 ty,
1455 components: vec![el; size as usize],
1456 };
1457 self.register_evaluated_expr(expr, span)
1458 }
1459 TypeInner::Matrix {
1460 columns,
1461 rows,
1462 scalar,
1463 } => {
1464 let vec_ty = self.types.insert(
1465 Type {
1466 name: None,
1467 inner: TypeInner::Vector { size: rows, scalar },
1468 },
1469 span,
1470 );
1471 let el = self.eval_zero_value_impl(vec_ty, span)?;
1472 let expr = Expression::Compose {
1473 ty,
1474 components: vec![el; columns as usize],
1475 };
1476 self.register_evaluated_expr(expr, span)
1477 }
1478 TypeInner::Array {
1479 base,
1480 size: ArraySize::Constant(size),
1481 ..
1482 } => {
1483 let el = self.eval_zero_value_impl(base, span)?;
1484 let expr = Expression::Compose {
1485 ty,
1486 components: vec![el; size.get() as usize],
1487 };
1488 self.register_evaluated_expr(expr, span)
1489 }
1490 TypeInner::Struct { ref members, .. } => {
1491 let types: Vec<_> = members.iter().map(|m| m.ty).collect();
1492 let mut components = Vec::with_capacity(members.len());
1493 for ty in types {
1494 components.push(self.eval_zero_value_impl(ty, span)?);
1495 }
1496 let expr = Expression::Compose { ty, components };
1497 self.register_evaluated_expr(expr, span)
1498 }
1499 _ => Err(ConstantEvaluatorError::TypeNotConstructible),
1500 }
1501 }
1502
    /// Cast the constant expression `expr` to the scalar type `target`.
    ///
    /// A `ZeroValue` operand is expanded first (via `eval_zero_value`). Then:
    ///
    /// - a `Literal` is converted directly, per the target tables below;
    /// - a vector or matrix `Compose` has each component cast recursively and
    ///   is re-typed with `target` as its scalar;
    /// - a `Splat` has its single value cast.
    ///
    /// Any other expression, any other composite type, and any unsupported
    /// literal conversion return [`ConstantEvaluatorError::InvalidCastArg`].
    /// Conversions from abstract literals go through `TryFromAbstract` and
    /// may fail with the error it reports.
    pub fn cast(
        &mut self,
        expr: Handle<Expression>,
        target: crate::Scalar,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        use crate::Scalar as Sc;

        let expr = self.eval_zero_value(expr, span)?;

        // Builds the `InvalidCastArg` error. `from` is the handle plus the
        // debug form of its expression. `to` is the target in WGSL syntax when
        // the `wgsl-in` feature is enabled, and its debug form otherwise.
        let make_error = || -> Result<_, ConstantEvaluatorError> {
            let from = format!("{:?} {:?}", expr, self.expressions[expr]);

            #[cfg(feature = "wgsl-in")]
            let to = target.to_wgsl();

            #[cfg(not(feature = "wgsl-in"))]
            let to = format!("{target:?}");

            Err(ConstantEvaluatorError::InvalidCastArg { from, to })
        };

        let expr = match self.expressions[expr] {
            Expression::Literal(literal) => {
                // Concrete-to-concrete conversions use Rust `as` casts, so
                // float-to-integer saturates and NaN becomes 0. Abstract
                // sources go through `TryFromAbstract`, which can reject values.
                let literal = match target {
                    // 64-bit concrete sources are rejected for the 32-bit targets.
                    Sc::I32 => Literal::I32(match literal {
                        Literal::I32(v) => v,
                        Literal::U32(v) => v as i32,
                        Literal::F32(v) => v as i32,
                        Literal::Bool(v) => v as i32,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
                    }),
                    Sc::U32 => Literal::U32(match literal {
                        Literal::I32(v) => v as u32,
                        Literal::U32(v) => v,
                        Literal::F32(v) => v as u32,
                        Literal::Bool(v) => v as u32,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
                    }),
                    // The 64-bit integer targets accept every concrete source.
                    Sc::I64 => Literal::I64(match literal {
                        Literal::I32(v) => v as i64,
                        Literal::U32(v) => v as i64,
                        Literal::F32(v) => v as i64,
                        Literal::Bool(v) => v as i64,
                        Literal::F64(v) => v as i64,
                        Literal::I64(v) => v,
                        Literal::U64(v) => v as i64,
                        Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
                    }),
                    Sc::U64 => Literal::U64(match literal {
                        Literal::I32(v) => v as u64,
                        Literal::U32(v) => v as u64,
                        Literal::F32(v) => v as u64,
                        Literal::Bool(v) => v as u64,
                        Literal::F64(v) => v as u64,
                        Literal::I64(v) => v as u64,
                        Literal::U64(v) => v,
                        Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
                    }),
                    Sc::F32 => Literal::F32(match literal {
                        Literal::I32(v) => v as f32,
                        Literal::U32(v) => v as f32,
                        Literal::F32(v) => v,
                        Literal::Bool(v) => v as u32 as f32,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
                    }),
                    // f64 accepts 32-bit sources but not the 64-bit integers.
                    Sc::F64 => Literal::F64(match literal {
                        Literal::I32(v) => v as f64,
                        Literal::U32(v) => v as f64,
                        Literal::F32(v) => v as f64,
                        Literal::F64(v) => v,
                        Literal::Bool(v) => v as u32 as f64,
                        Literal::I64(_) | Literal::U64(_) => return make_error(),
                        Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
                    }),
                    // Non-zero maps to `true`. Only 32-bit and bool sources are accepted.
                    Sc::BOOL => Literal::Bool(match literal {
                        Literal::I32(v) => v != 0,
                        Literal::U32(v) => v != 0,
                        Literal::F32(v) => v != 0.0,
                        Literal::Bool(v) => v,
                        Literal::F64(_)
                        | Literal::I64(_)
                        | Literal::U64(_)
                        | Literal::AbstractInt(_)
                        | Literal::AbstractFloat(_) => {
                            return make_error();
                        }
                    }),
                    // Only abstract sources may become AbstractFloat.
                    Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
                        Literal::AbstractInt(v) => {
                            // `as` rounds to the nearest f64 for magnitudes
                            // beyond 2^53.
                            v as f64
                        }
                        Literal::AbstractFloat(v) => v,
                        _ => return make_error(),
                    }),
                    _ => {
                        log::debug!("Constant evaluator refused to convert value to {target:?}");
                        return make_error();
                    }
                };
                Expression::Literal(literal)
            }
            Expression::Compose {
                ty,
                components: ref src_components,
            } => {
                // Only vectors and matrices are cast component-wise here;
                // other composites (e.g. arrays, structs) are rejected.
                let ty_inner = match self.types[ty].inner {
                    TypeInner::Vector { size, .. } => TypeInner::Vector {
                        size,
                        scalar: target,
                    },
                    TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
                        columns,
                        rows,
                        scalar: target,
                    },
                    _ => return make_error(),
                };

                let mut components = src_components.clone();
                for component in &mut components {
                    *component = self.cast(*component, target, span)?;
                }

                // Intern the re-scalared vector/matrix type.
                let ty = self.types.insert(
                    Type {
                        name: None,
                        inner: ty_inner,
                    },
                    span,
                );

                Expression::Compose { ty, components }
            }
            Expression::Splat { size, value } => {
                // The inner value is cast under its own span.
                let value_span = self.expressions.get_span(value);
                let cast_value = self.cast(value, target, value_span)?;
                Expression::Splat {
                    size,
                    value: cast_value,
                }
            }
            _ => return make_error(),
        };

        self.register_evaluated_expr(expr, span)
    }
1672
1673 pub fn cast_array(
1686 &mut self,
1687 expr: Handle<Expression>,
1688 target: crate::Scalar,
1689 span: Span,
1690 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1691 let Expression::Compose { ty, ref components } = self.expressions[expr] else {
1692 return self.cast(expr, target, span);
1693 };
1694
1695 let TypeInner::Array {
1696 base: _,
1697 size,
1698 stride: _,
1699 } = self.types[ty].inner
1700 else {
1701 return self.cast(expr, target, span);
1702 };
1703
1704 let mut components = components.clone();
1705 for component in &mut components {
1706 *component = self.cast_array(*component, target, span)?;
1707 }
1708
1709 let first = components.first().unwrap();
1710 let new_base = match self.resolve_type(*first)? {
1711 crate::proc::TypeResolution::Handle(ty) => ty,
1712 crate::proc::TypeResolution::Value(inner) => {
1713 self.types.insert(Type { name: None, inner }, span)
1714 }
1715 };
1716 let new_base_stride = self.types[new_base].inner.size(self.to_ctx());
1717 let new_array_ty = self.types.insert(
1718 Type {
1719 name: None,
1720 inner: TypeInner::Array {
1721 base: new_base,
1722 size,
1723 stride: new_base_stride,
1724 },
1725 },
1726 span,
1727 );
1728
1729 let compose = Expression::Compose {
1730 ty: new_array_ty,
1731 components,
1732 };
1733 self.register_evaluated_expr(compose, span)
1734 }
1735
1736 fn unary_op(
1737 &mut self,
1738 op: UnaryOperator,
1739 expr: Handle<Expression>,
1740 span: Span,
1741 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1742 let expr = self.eval_zero_value_and_splat(expr, span)?;
1743
1744 let expr = match self.expressions[expr] {
1745 Expression::Literal(value) => Expression::Literal(match op {
1746 UnaryOperator::Negate => match value {
1747 Literal::I32(v) => Literal::I32(v.wrapping_neg()),
1748 Literal::I64(v) => Literal::I64(v.wrapping_neg()),
1749 Literal::F32(v) => Literal::F32(-v),
1750 Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
1751 Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
1752 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1753 },
1754 UnaryOperator::LogicalNot => match value {
1755 Literal::Bool(v) => Literal::Bool(!v),
1756 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1757 },
1758 UnaryOperator::BitwiseNot => match value {
1759 Literal::I32(v) => Literal::I32(!v),
1760 Literal::I64(v) => Literal::I64(!v),
1761 Literal::U32(v) => Literal::U32(!v),
1762 Literal::U64(v) => Literal::U64(!v),
1763 Literal::AbstractInt(v) => Literal::AbstractInt(!v),
1764 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1765 },
1766 }),
1767 Expression::Compose {
1768 ty,
1769 components: ref src_components,
1770 } => {
1771 match self.types[ty].inner {
1772 TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
1773 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1774 }
1775
1776 let mut components = src_components.clone();
1777 for component in &mut components {
1778 *component = self.unary_op(op, *component, span)?;
1779 }
1780
1781 Expression::Compose { ty, components }
1782 }
1783 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1784 };
1785
1786 self.register_evaluated_expr(expr, span)
1787 }
1788
1789 fn binary_op(
1790 &mut self,
1791 op: BinaryOperator,
1792 left: Handle<Expression>,
1793 right: Handle<Expression>,
1794 span: Span,
1795 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1796 let left = self.eval_zero_value_and_splat(left, span)?;
1797 let right = self.eval_zero_value_and_splat(right, span)?;
1798
1799 let expr = match (&self.expressions[left], &self.expressions[right]) {
1800 (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
1801 let literal = match op {
1802 BinaryOperator::Equal => Literal::Bool(left_value == right_value),
1803 BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
1804 BinaryOperator::Less => Literal::Bool(left_value < right_value),
1805 BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
1806 BinaryOperator::Greater => Literal::Bool(left_value > right_value),
1807 BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),
1808
1809 _ => match (left_value, right_value) {
1810 (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
1811 BinaryOperator::Add => a.wrapping_add(b),
1812 BinaryOperator::Subtract => a.wrapping_sub(b),
1813 BinaryOperator::Multiply => a.wrapping_mul(b),
1814 BinaryOperator::Divide => {
1815 if b == 0 {
1816 return Err(ConstantEvaluatorError::DivisionByZero);
1817 } else {
1818 a.wrapping_div(b)
1819 }
1820 }
1821 BinaryOperator::Modulo => {
1822 if b == 0 {
1823 return Err(ConstantEvaluatorError::RemainderByZero);
1824 } else {
1825 a.wrapping_rem(b)
1826 }
1827 }
1828 BinaryOperator::And => a & b,
1829 BinaryOperator::ExclusiveOr => a ^ b,
1830 BinaryOperator::InclusiveOr => a | b,
1831 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
1832 }),
1833 (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
1834 BinaryOperator::ShiftLeft => {
1835 if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
1836 return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
1837 }
1838 a.checked_shl(b)
1839 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
1840 }
1841 BinaryOperator::ShiftRight => a
1842 .checked_shr(b)
1843 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
1844 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
1845 }),
1846 (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
1847 BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
1848 ConstantEvaluatorError::Overflow("addition".into())
1849 })?,
1850 BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
1851 ConstantEvaluatorError::Overflow("subtraction".into())
1852 })?,
1853 BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
1854 ConstantEvaluatorError::Overflow("multiplication".into())
1855 })?,
1856 BinaryOperator::Divide => a
1857 .checked_div(b)
1858 .ok_or(ConstantEvaluatorError::DivisionByZero)?,
1859 BinaryOperator::Modulo => a
1860 .checked_rem(b)
1861 .ok_or(ConstantEvaluatorError::RemainderByZero)?,
1862 BinaryOperator::And => a & b,
1863 BinaryOperator::ExclusiveOr => a ^ b,
1864 BinaryOperator::InclusiveOr => a | b,
1865 BinaryOperator::ShiftLeft => a
1866 .checked_mul(
1867 1u32.checked_shl(b)
1868 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
1869 )
1870 .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
1871 BinaryOperator::ShiftRight => a
1872 .checked_shr(b)
1873 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
1874 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
1875 }),
1876 (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
1877 BinaryOperator::Add => a + b,
1878 BinaryOperator::Subtract => a - b,
1879 BinaryOperator::Multiply => a * b,
1880 BinaryOperator::Divide => a / b,
1881 BinaryOperator::Modulo => a % b,
1882 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
1883 }),
1884 (Literal::AbstractInt(a), Literal::U32(b)) => {
1885 Literal::AbstractInt(match op {
1886 BinaryOperator::ShiftLeft => {
1887 if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
1888 return Err(ConstantEvaluatorError::Overflow(
1889 "<<".to_string(),
1890 ));
1891 }
1892 a.checked_shl(b).unwrap_or(0)
1893 }
1894 BinaryOperator::ShiftRight => a.checked_shr(b).unwrap_or(0),
1895 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
1896 })
1897 }
1898 (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
1899 Literal::AbstractInt(match op {
1900 BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
1901 ConstantEvaluatorError::Overflow("addition".into())
1902 })?,
1903 BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
1904 ConstantEvaluatorError::Overflow("subtraction".into())
1905 })?,
1906 BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
1907 ConstantEvaluatorError::Overflow("multiplication".into())
1908 })?,
1909 BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
1910 if b == 0 {
1911 ConstantEvaluatorError::DivisionByZero
1912 } else {
1913 ConstantEvaluatorError::Overflow("division".into())
1914 }
1915 })?,
1916 BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
1917 if b == 0 {
1918 ConstantEvaluatorError::RemainderByZero
1919 } else {
1920 ConstantEvaluatorError::Overflow("remainder".into())
1921 }
1922 })?,
1923 BinaryOperator::And => a & b,
1924 BinaryOperator::ExclusiveOr => a ^ b,
1925 BinaryOperator::InclusiveOr => a | b,
1926 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
1927 })
1928 }
1929 (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
1930 Literal::AbstractFloat(match op {
1931 BinaryOperator::Add => a + b,
1932 BinaryOperator::Subtract => a - b,
1933 BinaryOperator::Multiply => a * b,
1934 BinaryOperator::Divide => a / b,
1935 BinaryOperator::Modulo => a % b,
1936 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
1937 })
1938 }
1939 (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
1940 BinaryOperator::LogicalAnd => a && b,
1941 BinaryOperator::LogicalOr => a || b,
1942 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
1943 }),
1944 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
1945 },
1946 };
1947 Expression::Literal(literal)
1948 }
1949 (
1950 &Expression::Compose {
1951 components: ref src_components,
1952 ty,
1953 },
1954 &Expression::Literal(_),
1955 ) => {
1956 let mut components = src_components.clone();
1957 for component in &mut components {
1958 *component = self.binary_op(op, *component, right, span)?;
1959 }
1960 Expression::Compose { ty, components }
1961 }
1962 (
1963 &Expression::Literal(_),
1964 &Expression::Compose {
1965 components: ref src_components,
1966 ty,
1967 },
1968 ) => {
1969 let mut components = src_components.clone();
1970 for component in &mut components {
1971 *component = self.binary_op(op, left, *component, span)?;
1972 }
1973 Expression::Compose { ty, components }
1974 }
1975 (
1976 &Expression::Compose {
1977 components: ref left_components,
1978 ty: left_ty,
1979 },
1980 &Expression::Compose {
1981 components: ref right_components,
1982 ty: right_ty,
1983 },
1984 ) => {
1985 let left_flattened = crate::proc::flatten_compose(
1989 left_ty,
1990 left_components,
1991 self.expressions,
1992 self.types,
1993 );
1994 let right_flattened = crate::proc::flatten_compose(
1995 right_ty,
1996 right_components,
1997 self.expressions,
1998 self.types,
1999 );
2000
2001 let mut flattened = Vec::with_capacity(left_components.len());
2004 flattened.extend(left_flattened.zip(right_flattened));
2005
2006 match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
2007 (
2008 &TypeInner::Vector {
2009 size: left_size, ..
2010 },
2011 &TypeInner::Vector {
2012 size: right_size, ..
2013 },
2014 ) if left_size == right_size => {
2015 self.binary_op_vector(op, left_size, &flattened, left_ty, span)?
2016 }
2017 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2018 }
2019 }
2020 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2021 };
2022
2023 self.register_evaluated_expr(expr, span)
2024 }
2025
2026 fn binary_op_vector(
2027 &mut self,
2028 op: BinaryOperator,
2029 size: crate::VectorSize,
2030 components: &[(Handle<Expression>, Handle<Expression>)],
2031 left_ty: Handle<Type>,
2032 span: Span,
2033 ) -> Result<Expression, ConstantEvaluatorError> {
2034 let ty = match op {
2035 BinaryOperator::Equal
2037 | BinaryOperator::NotEqual
2038 | BinaryOperator::Less
2039 | BinaryOperator::LessEqual
2040 | BinaryOperator::Greater
2041 | BinaryOperator::GreaterEqual => self.types.insert(
2042 Type {
2043 name: None,
2044 inner: TypeInner::Vector {
2045 size,
2046 scalar: crate::Scalar::BOOL,
2047 },
2048 },
2049 span,
2050 ),
2051
2052 BinaryOperator::Add
2055 | BinaryOperator::Subtract
2056 | BinaryOperator::Multiply
2057 | BinaryOperator::Divide
2058 | BinaryOperator::Modulo
2059 | BinaryOperator::And
2060 | BinaryOperator::ExclusiveOr
2061 | BinaryOperator::InclusiveOr
2062 | BinaryOperator::LogicalAnd
2063 | BinaryOperator::LogicalOr
2064 | BinaryOperator::ShiftLeft
2065 | BinaryOperator::ShiftRight => left_ty,
2066 };
2067
2068 let components = components
2069 .iter()
2070 .map(|&(left, right)| self.binary_op(op, left, right, span))
2071 .collect::<Result<Vec<_>, _>>()?;
2072
2073 Ok(Expression::Compose { ty, components })
2074 }
2075
2076 fn copy_from(
2084 &mut self,
2085 expr: Handle<Expression>,
2086 expressions: &Arena<Expression>,
2087 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2088 let span = expressions.get_span(expr);
2089 match expressions[expr] {
2090 ref expr @ (Expression::Literal(_)
2091 | Expression::Constant(_)
2092 | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
2093 Expression::Compose { ty, ref components } => {
2094 let mut components = components.clone();
2095 for component in &mut components {
2096 *component = self.copy_from(*component, expressions)?;
2097 }
2098 self.register_evaluated_expr(Expression::Compose { ty, components }, span)
2099 }
2100 Expression::Splat { size, value } => {
2101 let value = self.copy_from(value, expressions)?;
2102 self.register_evaluated_expr(Expression::Splat { size, value }, span)
2103 }
2104 _ => {
2105 log::debug!("copy_from: SubexpressionsAreNotConstant");
2106 Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
2107 }
2108 }
2109 }
2110
2111 fn register_evaluated_expr(
2112 &mut self,
2113 expr: Expression,
2114 span: Span,
2115 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2116 if let Expression::Literal(literal) = expr {
2120 crate::valid::check_literal_value(literal)?;
2121 }
2122
2123 Ok(self.append_expr(expr, span, ExpressionKind::Const))
2124 }
2125
    /// Append `expr` to the expression arena and record `expr_type` for it in
    /// the expression kind tracker.
    ///
    /// This applies inside a function body: WGSL runtime or function-local
    /// const contexts, or the GLSL runtime context. There, if the function's
    /// emitter is currently running and `expr.needs_pre_emit()` is true, the
    /// expression is appended outside the current emit range. To do that:
    ///
    /// - the pending emit statement is pushed onto the block;
    /// - the expression is appended;
    /// - the emitter is restarted.
    ///
    /// In every other context the expression is simply appended.
    fn append_expr(
        &mut self,
        expr: Expression,
        span: Span,
        expr_type: ExpressionKind,
    ) -> Handle<Expression> {
        let h = match self.behavior {
            Behavior::Wgsl(
                WgslRestrictions::Runtime(ref mut function_local_data)
                | WgslRestrictions::Const(Some(ref mut function_local_data)),
            )
            | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
                let is_running = function_local_data.emitter.is_running();
                let needs_pre_emit = expr.needs_pre_emit();
                if is_running && needs_pre_emit {
                    // Close the current emit range so the new expression is
                    // not covered by it, then reopen a range after it.
                    function_local_data
                        .block
                        .extend(function_local_data.emitter.finish(self.expressions));
                    let h = self.expressions.append(expr, span);
                    function_local_data.emitter.start(self.expressions);
                    h
                } else {
                    self.expressions.append(expr, span)
                }
            }
            _ => self.expressions.append(expr, span),
        };
        self.expression_kind_tracker.insert(h, expr_type);
        h
    }
2156
2157 fn resolve_type(
2158 &self,
2159 expr: Handle<Expression>,
2160 ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
2161 use crate::proc::TypeResolution as Tr;
2162 use crate::Expression as Ex;
2163 let resolution = match self.expressions[expr] {
2164 Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
2165 Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
2166 Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
2167 Ex::Splat { size, value } => {
2168 let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
2169 return Err(ConstantEvaluatorError::SplatScalarOnly);
2170 };
2171 Tr::Value(TypeInner::Vector { scalar, size })
2172 }
2173 _ => {
2174 log::debug!("resolve_type: SubexpressionsAreNotConstant");
2175 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
2176 }
2177 };
2178
2179 Ok(resolution)
2180 }
2181}
2182
2183fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
2184 let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
2188 match e {
2189 idx @ 0..=31 => idx,
2190 32 => u32::MAX,
2191 _ => unreachable!(),
2192 }
2193 };
2194 match concrete_int {
2195 ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
2196 ConcreteInt::I32([e]) => {
2197 ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
2198 }
2199 }
2200}
2201
#[test]
fn first_trailing_bit_smoke() {
    // (input, expected lowest-set-bit index), with -1 meaning "no bit set".
    let signed_cases: [(i32, i32); 6] = [
        (0, -1),
        (1, 0),
        (2, 1),
        (-1, 0),
        (i32::MIN, 31),
        (i32::MAX, 0),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected]),
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        )
    }

    let unsigned_cases: [(u32, u32); 5] = [
        (0, u32::MAX),
        (1, 0),
        (2, 1),
        (1 << 31, 31),
        (u32::MAX, 0),
    ];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected]),
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        )
    }
}
2262
2263fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
2264 let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
2268 match e {
2269 idx @ 0..=31 => 31 - idx,
2270 32 => u32::MAX,
2271 _ => unreachable!(),
2272 }
2273 };
2274 match concrete_int {
2275 ConcreteInt::I32([e]) => ConcreteInt::I32([{
2276 let rtl_bit_index = if e.is_negative() {
2277 e.leading_ones()
2278 } else {
2279 e.leading_zeros()
2280 };
2281 rtl_to_ltr_bit_idx(rtl_bit_index) as i32
2282 }]),
2283 ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
2284 }
2285}
2286
#[test]
fn first_leading_bit_smoke() {
    // (input, expected index), with -1 meaning "no bit differs from the sign".
    let signed_cases: [(i32, i32); 7] = [
        (-1, -1),
        (0, -1),
        (1, 0),
        (-2, 0),
        (1234 + 4567, 12),
        (i32::MAX, 30),
        (i32::MIN, 30),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // Positive powers of two: the set bit itself is the answer (bit 31 is
    // the sign bit, so it is excluded).
    for idx in 0..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }
    // Negated powers of two: the most significant zero bit is at `idx - 1`.
    for idx in 1..(32 - 1) {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([-(1 << idx)])),
            ConcreteInt::I32([idx - 1])
        );
    }

    let unsigned_cases: [(u32, u32); 3] = [(0, u32::MAX), (1, 0), (u32::MAX, 31)];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        )
    }
}
2350
/// Fallible conversion from an abstract literal payload (`T` is `i64` for
/// `AbstractInt` or `f64` for `AbstractFloat`, as used by `cast`) into a
/// concrete scalar type.
///
/// The implementations visible in this file report unrepresentable values
/// as `ConstantEvaluatorError::AutomaticConversionLossy`.
trait TryFromAbstract<T>: Sized {
    /// Convert `value`, or return an error if it cannot be represented.
    fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
}
2370
2371impl TryFromAbstract<i64> for i32 {
2372 fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
2373 i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2374 value: format!("{value:?}"),
2375 to_type: "i32",
2376 })
2377 }
2378}
2379
2380impl TryFromAbstract<i64> for u32 {
2381 fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
2382 u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2383 value: format!("{value:?}"),
2384 to_type: "u32",
2385 })
2386 }
2387}
2388
2389impl TryFromAbstract<i64> for u64 {
2390 fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
2391 u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2392 value: format!("{value:?}"),
2393 to_type: "u64",
2394 })
2395 }
2396}
2397
2398impl TryFromAbstract<i64> for i64 {
2399 fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
2400 Ok(value)
2401 }
2402}
2403
2404impl TryFromAbstract<i64> for f32 {
2405 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
2406 let f = value as f32;
2407 Ok(f)
2411 }
2412}
2413
2414impl TryFromAbstract<f64> for f32 {
2415 fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
2416 let f = value as f32;
2417 if f.is_infinite() {
2418 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
2419 value: format!("{value:?}"),
2420 to_type: "f32",
2421 });
2422 }
2423 Ok(f)
2424 }
2425}
2426
2427impl TryFromAbstract<i64> for f64 {
2428 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
2429 let f = value as f64;
2430 Ok(f)
2434 }
2435}
2436
2437impl TryFromAbstract<f64> for f64 {
2438 fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
2439 Ok(value)
2440 }
2441}
2442
2443impl TryFromAbstract<f64> for i32 {
2444 fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> {
2445 Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "i32" })
2446 }
2447}
2448
2449impl TryFromAbstract<f64> for u32 {
2450 fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> {
2451 Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "u32" })
2452 }
2453}
2454
2455impl TryFromAbstract<f64> for i64 {
2456 fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> {
2457 Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "i64" })
2458 }
2459}
2460
2461impl TryFromAbstract<f64> for u64 {
2462 fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> {
2463 Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "u64" })
2464 }
2465}
2466
#[cfg(test)]
mod tests {
    use std::vec;

    use crate::{
        arena::Handle, Arena, Constant, Expression, Literal, ScalarKind, Type, TypeInner,
        UnaryOperator, UniqueArena, VectorSize,
    };

    use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};

    // Builds a `ConstantEvaluator` with WGSL `const` restrictions over the given
    // local arenas. The `ExpressionKindTracker` is built from the expressions
    // present at the call site and lives in a temporary. Using the macro as the
    // initializer of a `let` extends that temporary's lifetime, which is why this
    // is a macro rather than a function (a function could not hand out a borrow
    // of a tracker it created itself).
    macro_rules! wgsl_const_evaluator {
        ($types:ident, $constants:ident, $overrides:ident, $expressions:ident) => {
            ConstantEvaluator {
                behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
                types: &mut $types,
                constants: &$constants,
                overrides: &$overrides,
                // Evaluated before the next field takes `&mut $expressions`.
                expression_kind_tracker: &mut ExpressionKindTracker::from_arena(&$expressions),
                expressions: &mut $expressions,
            }
        };
    }

    /// Interns an unnamed type with a default span.
    fn add_type(types: &mut UniqueArena<Type>, inner: TypeInner) -> Handle<Type> {
        types.insert(Type { name: None, inner }, Default::default())
    }

    /// Appends an unnamed constant of type `ty` initialized by `init`.
    fn add_constant(
        constants: &mut Arena<Constant>,
        ty: Handle<Type>,
        init: Handle<Expression>,
    ) -> Handle<Constant> {
        constants.append(Constant { name: None, ty, init }, Default::default())
    }

    /// Appends `literal` as a global expression and wraps it in an unnamed
    /// constant of type `ty`.
    fn add_literal_constant(
        constants: &mut Arena<Constant>,
        expressions: &mut Arena<Expression>,
        ty: Handle<Type>,
        literal: Literal,
    ) -> Handle<Constant> {
        let init = expressions.append(Expression::Literal(literal), Default::default());
        add_constant(constants, ty, init)
    }

    /// Appends a `Compose` of `components` as a global expression and wraps it
    /// in an unnamed constant of the same type `ty`.
    fn add_compose_constant(
        constants: &mut Arena<Constant>,
        expressions: &mut Arena<Expression>,
        ty: Handle<Type>,
        components: Vec<Handle<Expression>>,
    ) -> Handle<Constant> {
        let init = expressions.append(Expression::Compose { ty, components }, Default::default());
        add_constant(constants, ty, init)
    }

    /// Asserts that `expr` is a `Compose` of type `ty` whose components are
    /// exactly `expected`, in order.
    fn assert_compose(
        expressions: &Arena<Expression>,
        expr: Handle<Expression>,
        ty: Handle<Type>,
        expected: &[Expression],
    ) {
        match expressions[expr] {
            Expression::Compose {
                ty: actual_ty,
                ref components,
            } => {
                assert_eq!(actual_ty, ty);
                assert_eq!(components.len(), expected.len(), "component count");
                for (&component, expected) in components.iter().zip(expected) {
                    assert_eq!(expressions[component], *expected);
                }
            }
            ref other => panic!("expected a `Compose`, got {other:?}"),
        }
    }

    #[test]
    fn unary_op() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let scalar_ty = add_type(&mut types, TypeInner::Scalar(crate::Scalar::I32));
        let vec_ty = add_type(
            &mut types,
            TypeInner::Vector {
                size: VectorSize::Bi,
                scalar: crate::Scalar::I32,
            },
        );

        let h = add_literal_constant(
            &mut constants,
            &mut global_expressions,
            scalar_ty,
            Literal::I32(4),
        );
        let h1 = add_literal_constant(
            &mut constants,
            &mut global_expressions,
            scalar_ty,
            Literal::I32(8),
        );
        // Read the component handles before `constants` is borrowed mutably.
        let vec_components = vec![constants[h].init, constants[h1].init];
        let vec_h = add_compose_constant(
            &mut constants,
            &mut global_expressions,
            vec_ty,
            vec_components,
        );

        let expr = global_expressions.append(Expression::Constant(h), Default::default());
        let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());

        let negate_scalar = Expression::Unary {
            op: UnaryOperator::Negate,
            expr,
        };
        let not_scalar = Expression::Unary {
            op: UnaryOperator::BitwiseNot,
            expr,
        };
        let not_vector = Expression::Unary {
            op: UnaryOperator::BitwiseNot,
            expr: expr1,
        };

        let mut solver = wgsl_const_evaluator!(types, constants, overrides, global_expressions);

        let res1 = solver
            .try_eval_and_append(negate_scalar, Default::default())
            .unwrap();
        let res2 = solver
            .try_eval_and_append(not_scalar, Default::default())
            .unwrap();
        let res3 = solver
            .try_eval_and_append(not_vector, Default::default())
            .unwrap();

        assert_eq!(
            global_expressions[res1],
            Expression::Literal(Literal::I32(-4))
        );
        assert_eq!(
            global_expressions[res2],
            Expression::Literal(Literal::I32(!4))
        );
        assert_compose(
            &global_expressions,
            res3,
            vec_ty,
            &[
                Expression::Literal(Literal::I32(!4)),
                Expression::Literal(Literal::I32(!8)),
            ],
        );
    }

    #[test]
    fn cast() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let scalar_ty = add_type(&mut types, TypeInner::Scalar(crate::Scalar::I32));
        let h = add_literal_constant(
            &mut constants,
            &mut global_expressions,
            scalar_ty,
            Literal::I32(4),
        );
        let expr = global_expressions.append(Expression::Constant(h), Default::default());

        let root = Expression::As {
            expr,
            kind: ScalarKind::Bool,
            convert: Some(crate::BOOL_WIDTH),
        };

        let mut solver = wgsl_const_evaluator!(types, constants, overrides, global_expressions);

        let res = solver
            .try_eval_and_append(root, Default::default())
            .unwrap();

        assert_eq!(
            global_expressions[res],
            Expression::Literal(Literal::Bool(true))
        );
    }

    #[test]
    fn access() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let matrix_ty = add_type(
            &mut types,
            TypeInner::Matrix {
                columns: VectorSize::Bi,
                rows: VectorSize::Tri,
                scalar: crate::Scalar::F32,
            },
        );
        let vec_ty = add_type(
            &mut types,
            TypeInner::Vector {
                size: VectorSize::Tri,
                scalar: crate::Scalar::F32,
            },
        );

        // Column 0 holds 0.0, 1.0, 2.0 and column 1 holds 3.0, 4.0, 5.0.
        let mut f32_literal = |value: f32| {
            global_expressions.append(Expression::Literal(Literal::F32(value)), Default::default())
        };
        let vec1_components: Vec<_> = (0..3).map(|i| f32_literal(i as f32)).collect();
        let vec2_components: Vec<_> = (3..6).map(|i| f32_literal(i as f32)).collect();

        let vec1 = add_compose_constant(
            &mut constants,
            &mut global_expressions,
            vec_ty,
            vec1_components,
        );
        let vec2 = add_compose_constant(
            &mut constants,
            &mut global_expressions,
            vec_ty,
            vec2_components,
        );
        // Read the column handles before `constants` is borrowed mutably.
        let columns = vec![constants[vec1].init, constants[vec2].init];
        let h = add_compose_constant(&mut constants, &mut global_expressions, matrix_ty, columns);

        let base = global_expressions.append(Expression::Constant(h), Default::default());

        let mut solver = wgsl_const_evaluator!(types, constants, overrides, global_expressions);

        let res1 = solver
            .try_eval_and_append(Expression::AccessIndex { base, index: 1 }, Default::default())
            .unwrap();
        let res2 = solver
            .try_eval_and_append(
                Expression::AccessIndex {
                    base: res1,
                    index: 2,
                },
                Default::default(),
            )
            .unwrap();

        assert_compose(
            &global_expressions,
            res1,
            vec_ty,
            &[
                Expression::Literal(Literal::F32(3.)),
                Expression::Literal(Literal::F32(4.)),
                Expression::Literal(Literal::F32(5.)),
            ],
        );
        assert_eq!(
            global_expressions[res2],
            Expression::Literal(Literal::F32(5.))
        );
    }

    #[test]
    fn compose_of_constants() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let i32_ty = add_type(&mut types, TypeInner::Scalar(crate::Scalar::I32));
        let vec2_i32_ty = add_type(
            &mut types,
            TypeInner::Vector {
                size: VectorSize::Bi,
                scalar: crate::Scalar::I32,
            },
        );

        let h = add_literal_constant(
            &mut constants,
            &mut global_expressions,
            i32_ty,
            Literal::I32(4),
        );
        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());

        let mut solver = wgsl_const_evaluator!(types, constants, overrides, global_expressions);

        let solved_compose = solver
            .try_eval_and_append(
                Expression::Compose {
                    ty: vec2_i32_ty,
                    components: vec![h_expr, h_expr],
                },
                Default::default(),
            )
            .unwrap();
        let solved_negate = solver
            .try_eval_and_append(
                Expression::Unary {
                    op: UnaryOperator::Negate,
                    expr: solved_compose,
                },
                Default::default(),
            )
            .unwrap();

        assert_compose(
            &global_expressions,
            solved_negate,
            vec2_i32_ty,
            &[
                Expression::Literal(Literal::I32(-4)),
                Expression::Literal(Literal::I32(-4)),
            ],
        );
    }

    #[test]
    fn splat_of_constant() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let i32_ty = add_type(&mut types, TypeInner::Scalar(crate::Scalar::I32));
        let vec2_i32_ty = add_type(
            &mut types,
            TypeInner::Vector {
                size: VectorSize::Bi,
                scalar: crate::Scalar::I32,
            },
        );

        let h = add_literal_constant(
            &mut constants,
            &mut global_expressions,
            i32_ty,
            Literal::I32(4),
        );
        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());

        let mut solver = wgsl_const_evaluator!(types, constants, overrides, global_expressions);

        let solved_splat = solver
            .try_eval_and_append(
                Expression::Splat {
                    size: VectorSize::Bi,
                    value: h_expr,
                },
                Default::default(),
            )
            .unwrap();
        let solved_negate = solver
            .try_eval_and_append(
                Expression::Unary {
                    op: UnaryOperator::Negate,
                    expr: solved_splat,
                },
                Default::default(),
            )
            .unwrap();

        assert_compose(
            &global_expressions,
            solved_negate,
            vec2_i32_ty,
            &[
                Expression::Literal(Literal::I32(-4)),
                Expression::Literal(Literal::I32(-4)),
            ],
        );
    }
}