naga/valid/
expression.rs

1use super::{compose::validate_compose, FunctionInfo, ModuleInfo, ShaderStages, TypeFlags};
2use crate::arena::UniqueArena;
3
4use crate::{
5    arena::Handle,
6    proc::{IndexableLengthError, ResolveError},
7};
8
9#[derive(Clone, Debug, thiserror::Error)]
10#[cfg_attr(test, derive(PartialEq))]
11pub enum ExpressionError {
12    #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")]
13    NotInScope,
14    #[error("Base type {0:?} is not compatible with this expression")]
15    InvalidBaseType(Handle<crate::Expression>),
16    #[error("Accessing with index {0:?} can't be done")]
17    InvalidIndexType(Handle<crate::Expression>),
18    #[error("Accessing {0:?} via a negative index is invalid")]
19    NegativeIndex(Handle<crate::Expression>),
20    #[error("Accessing index {1} is out of {0:?} bounds")]
21    IndexOutOfBounds(Handle<crate::Expression>, u32),
22    #[error("Function argument {0:?} doesn't exist")]
23    FunctionArgumentDoesntExist(u32),
24    #[error("Loading of {0:?} can't be done")]
25    InvalidPointerType(Handle<crate::Expression>),
26    #[error("Array length of {0:?} can't be done")]
27    InvalidArrayType(Handle<crate::Expression>),
28    #[error("Get intersection of {0:?} can't be done")]
29    InvalidRayQueryType(Handle<crate::Expression>),
30    #[error("Splatting {0:?} can't be done")]
31    InvalidSplatType(Handle<crate::Expression>),
32    #[error("Swizzling {0:?} can't be done")]
33    InvalidVectorType(Handle<crate::Expression>),
34    #[error("Swizzle component {0:?} is outside of vector size {1:?}")]
35    InvalidSwizzleComponent(crate::SwizzleComponent, crate::VectorSize),
36    #[error(transparent)]
37    Compose(#[from] super::ComposeError),
38    #[error(transparent)]
39    IndexableLength(#[from] IndexableLengthError),
40    #[error("Operation {0:?} can't work with {1:?}")]
41    InvalidUnaryOperandType(crate::UnaryOperator, Handle<crate::Expression>),
42    #[error(
43        "Operation {:?} can't work with {:?} (of type {:?}) and {:?} (of type {:?})",
44        op,
45        lhs_expr,
46        lhs_type,
47        rhs_expr,
48        rhs_type
49    )]
50    InvalidBinaryOperandTypes {
51        op: crate::BinaryOperator,
52        lhs_expr: Handle<crate::Expression>,
53        lhs_type: crate::TypeInner,
54        rhs_expr: Handle<crate::Expression>,
55        rhs_type: crate::TypeInner,
56    },
57    #[error("Expected selection argument types to match, but reject value of type {reject:?} does not match accept value of value {accept:?}")]
58    SelectValuesTypeMismatch {
59        accept: crate::TypeInner,
60        reject: crate::TypeInner,
61    },
62    #[error("Expected selection condition to be a boolean value, got {actual:?}")]
63    SelectConditionNotABool { actual: crate::TypeInner },
64    #[error("Relational argument {0:?} is not a boolean vector")]
65    InvalidBooleanVector(Handle<crate::Expression>),
66    #[error("Relational argument {0:?} is not a float")]
67    InvalidFloatArgument(Handle<crate::Expression>),
68    #[error("Type resolution failed")]
69    Type(#[from] ResolveError),
70    #[error("Not a global variable")]
71    ExpectedGlobalVariable,
72    #[error("Not a global variable or a function argument")]
73    ExpectedGlobalOrArgument,
74    #[error("Needs to be an binding array instead of {0:?}")]
75    ExpectedBindingArrayType(Handle<crate::Type>),
76    #[error("Needs to be an image instead of {0:?}")]
77    ExpectedImageType(Handle<crate::Type>),
78    #[error("Needs to be an image instead of {0:?}")]
79    ExpectedSamplerType(Handle<crate::Type>),
80    #[error("Unable to operate on image class {0:?}")]
81    InvalidImageClass(crate::ImageClass),
82    #[error("Image atomics are not supported for storage format {0:?}")]
83    InvalidImageFormat(crate::StorageFormat),
84    #[error("Image atomics require atomic storage access, {0:?} is insufficient")]
85    InvalidImageStorageAccess(crate::StorageAccess),
86    #[error("Derivatives can only be taken from scalar and vector floats")]
87    InvalidDerivative,
88    #[error("Image array index parameter is misplaced")]
89    InvalidImageArrayIndex,
90    #[error("Inappropriate sample or level-of-detail index for texel access")]
91    InvalidImageOtherIndex,
92    #[error("Image array index type of {0:?} is not an integer scalar")]
93    InvalidImageArrayIndexType(Handle<crate::Expression>),
94    #[error("Image sample or level-of-detail index's type of {0:?} is not an integer scalar")]
95    InvalidImageOtherIndexType(Handle<crate::Expression>),
96    #[error("Image coordinate type of {1:?} does not match dimension {0:?}")]
97    InvalidImageCoordinateType(crate::ImageDimension, Handle<crate::Expression>),
98    #[error("Comparison sampling mismatch: image has class {image:?}, but the sampler is comparison={sampler}, and the reference was provided={has_ref}")]
99    ComparisonSamplingMismatch {
100        image: crate::ImageClass,
101        sampler: bool,
102        has_ref: bool,
103    },
104    #[error("Sample offset must be a const-expression")]
105    InvalidSampleOffsetExprType,
106    #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")]
107    InvalidSampleOffset(crate::ImageDimension, Handle<crate::Expression>),
108    #[error("Depth reference {0:?} is not a scalar float")]
109    InvalidDepthReference(Handle<crate::Expression>),
110    #[error("Depth sample level can only be Auto or Zero")]
111    InvalidDepthSampleLevel,
112    #[error("Gather level can only be Zero")]
113    InvalidGatherLevel,
114    #[error("Gather component {0:?} doesn't exist in the image")]
115    InvalidGatherComponent(crate::SwizzleComponent),
116    #[error("Gather can't be done for image dimension {0:?}")]
117    InvalidGatherDimension(crate::ImageDimension),
118    #[error("Sample level (exact) type {0:?} has an invalid type")]
119    InvalidSampleLevelExactType(Handle<crate::Expression>),
120    #[error("Sample level (bias) type {0:?} is not a scalar float")]
121    InvalidSampleLevelBiasType(Handle<crate::Expression>),
122    #[error("Bias can't be done for image dimension {0:?}")]
123    InvalidSampleLevelBiasDimension(crate::ImageDimension),
124    #[error("Sample level (gradient) of {1:?} doesn't match the image dimension {0:?}")]
125    InvalidSampleLevelGradientType(crate::ImageDimension, Handle<crate::Expression>),
126    #[error("Unable to cast")]
127    InvalidCastArgument,
128    #[error("Invalid argument count for {0:?}")]
129    WrongArgumentCount(crate::MathFunction),
130    #[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")]
131    InvalidArgumentType(crate::MathFunction, u32, Handle<crate::Expression>),
132    #[error(
133        "workgroupUniformLoad result type can't be {0:?}. It can only be a constructible type."
134    )]
135    InvalidWorkGroupUniformLoadResultType(Handle<crate::Type>),
136    #[error("Shader requires capability {0:?}")]
137    MissingCapabilities(super::Capabilities),
138    #[error(transparent)]
139    Literal(#[from] LiteralError),
140    #[error("{0:?} is not supported for Width {2} {1:?} arguments yet, see https://github.com/gfx-rs/wgpu/issues/5276")]
141    UnsupportedWidth(crate::MathFunction, crate::ScalarKind, crate::Bytes),
142}
143
144#[derive(Clone, Debug, thiserror::Error)]
145#[cfg_attr(test, derive(PartialEq))]
146pub enum ConstExpressionError {
147    #[error("The expression is not a constant or override expression")]
148    NonConstOrOverride,
149    #[error("The expression is not a fully evaluated constant expression")]
150    NonFullyEvaluatedConst,
151    #[error(transparent)]
152    Compose(#[from] super::ComposeError),
153    #[error("Splatting {0:?} can't be done")]
154    InvalidSplatType(Handle<crate::Expression>),
155    #[error("Type resolution failed")]
156    Type(#[from] ResolveError),
157    #[error(transparent)]
158    Literal(#[from] LiteralError),
159    #[error(transparent)]
160    Width(#[from] super::r#type::WidthError),
161}
162
163#[derive(Clone, Debug, thiserror::Error)]
164#[cfg_attr(test, derive(PartialEq))]
165pub enum LiteralError {
166    #[error("Float literal is NaN")]
167    NaN,
168    #[error("Float literal is infinite")]
169    Infinity,
170    #[error(transparent)]
171    Width(#[from] super::r#type::WidthError),
172}
173
174struct ExpressionTypeResolver<'a> {
175    root: Handle<crate::Expression>,
176    types: &'a UniqueArena<crate::Type>,
177    info: &'a FunctionInfo,
178}
179
180impl std::ops::Index<Handle<crate::Expression>> for ExpressionTypeResolver<'_> {
181    type Output = crate::TypeInner;
182
183    #[allow(clippy::panic)]
184    fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
185        if handle < self.root {
186            self.info[handle].ty.inner_with(self.types)
187        } else {
188            // `Validator::validate_module_handles` should have caught this.
189            panic!(
190                "Depends on {:?}, which has not been processed yet",
191                self.root
192            )
193        }
194    }
195}
196
197impl super::Validator {
198    pub(super) fn validate_const_expression(
199        &self,
200        handle: Handle<crate::Expression>,
201        gctx: crate::proc::GlobalCtx,
202        mod_info: &ModuleInfo,
203        global_expr_kind: &crate::proc::ExpressionKindTracker,
204    ) -> Result<(), ConstExpressionError> {
205        use crate::Expression as E;
206
207        if !global_expr_kind.is_const_or_override(handle) {
208            return Err(ConstExpressionError::NonConstOrOverride);
209        }
210
211        match gctx.global_expressions[handle] {
212            E::Literal(literal) => {
213                self.validate_literal(literal)?;
214            }
215            E::Constant(_) | E::ZeroValue(_) => {}
216            E::Compose { ref components, ty } => {
217                validate_compose(
218                    ty,
219                    gctx,
220                    components.iter().map(|&handle| mod_info[handle].clone()),
221                )?;
222            }
223            E::Splat { value, .. } => match *mod_info[value].inner_with(gctx.types) {
224                crate::TypeInner::Scalar { .. } => {}
225                _ => return Err(ConstExpressionError::InvalidSplatType(value)),
226            },
227            _ if global_expr_kind.is_const(handle) || !self.allow_overrides => {
228                return Err(ConstExpressionError::NonFullyEvaluatedConst)
229            }
230            // the constant evaluator will report errors about override-expressions
231            _ => {}
232        }
233
234        Ok(())
235    }
236
237    #[allow(clippy::too_many_arguments)]
238    pub(super) fn validate_expression(
239        &self,
240        root: Handle<crate::Expression>,
241        expression: &crate::Expression,
242        function: &crate::Function,
243        module: &crate::Module,
244        info: &FunctionInfo,
245        mod_info: &ModuleInfo,
246        global_expr_kind: &crate::proc::ExpressionKindTracker,
247    ) -> Result<ShaderStages, ExpressionError> {
248        use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti};
249
250        let resolver = ExpressionTypeResolver {
251            root,
252            types: &module.types,
253            info,
254        };
255
256        let stages = match *expression {
257            E::Access { base, index } => {
258                let base_type = &resolver[base];
259                match *base_type {
260                    Ti::Matrix { .. }
261                    | Ti::Vector { .. }
262                    | Ti::Array { .. }
263                    | Ti::Pointer { .. }
264                    | Ti::ValuePointer { size: Some(_), .. }
265                    | Ti::BindingArray { .. } => {}
266                    ref other => {
267                        log::error!("Indexing of {:?}", other);
268                        return Err(ExpressionError::InvalidBaseType(base));
269                    }
270                };
271                match resolver[index] {
272                    //TODO: only allow one of these
273                    Ti::Scalar(Sc {
274                        kind: Sk::Sint | Sk::Uint,
275                        ..
276                    }) => {}
277                    ref other => {
278                        log::error!("Indexing by {:?}", other);
279                        return Err(ExpressionError::InvalidIndexType(index));
280                    }
281                }
282
283                // If we know both the length and the index, we can do the
284                // bounds check now.
285                if let crate::proc::IndexableLength::Known(known_length) =
286                    base_type.indexable_length(module)?
287                {
288                    match module
289                        .to_ctx()
290                        .eval_expr_to_u32_from(index, &function.expressions)
291                    {
292                        Ok(value) => {
293                            if value >= known_length {
294                                return Err(ExpressionError::IndexOutOfBounds(base, value));
295                            }
296                        }
297                        Err(crate::proc::U32EvalError::Negative) => {
298                            return Err(ExpressionError::NegativeIndex(base))
299                        }
300                        Err(crate::proc::U32EvalError::NonConst) => {}
301                    }
302                }
303
304                ShaderStages::all()
305            }
306            E::AccessIndex { base, index } => {
307                fn resolve_index_limit(
308                    module: &crate::Module,
309                    top: Handle<crate::Expression>,
310                    ty: &crate::TypeInner,
311                    top_level: bool,
312                ) -> Result<u32, ExpressionError> {
313                    let limit = match *ty {
314                        Ti::Vector { size, .. }
315                        | Ti::ValuePointer {
316                            size: Some(size), ..
317                        } => size as u32,
318                        Ti::Matrix { columns, .. } => columns as u32,
319                        Ti::Array {
320                            size: crate::ArraySize::Constant(len),
321                            ..
322                        } => len.get(),
323                        Ti::Array { .. } | Ti::BindingArray { .. } => u32::MAX, // can't statically know, but need run-time checks
324                        Ti::Pointer { base, .. } if top_level => {
325                            resolve_index_limit(module, top, &module.types[base].inner, false)?
326                        }
327                        Ti::Struct { ref members, .. } => members.len() as u32,
328                        ref other => {
329                            log::error!("Indexing of {:?}", other);
330                            return Err(ExpressionError::InvalidBaseType(top));
331                        }
332                    };
333                    Ok(limit)
334                }
335
336                let limit = resolve_index_limit(module, base, &resolver[base], true)?;
337                if index >= limit {
338                    return Err(ExpressionError::IndexOutOfBounds(base, limit));
339                }
340                ShaderStages::all()
341            }
342            E::Splat { size: _, value } => match resolver[value] {
343                Ti::Scalar { .. } => ShaderStages::all(),
344                ref other => {
345                    log::error!("Splat scalar type {:?}", other);
346                    return Err(ExpressionError::InvalidSplatType(value));
347                }
348            },
349            E::Swizzle {
350                size,
351                vector,
352                pattern,
353            } => {
354                let vec_size = match resolver[vector] {
355                    Ti::Vector { size: vec_size, .. } => vec_size,
356                    ref other => {
357                        log::error!("Swizzle vector type {:?}", other);
358                        return Err(ExpressionError::InvalidVectorType(vector));
359                    }
360                };
361                for &sc in pattern[..size as usize].iter() {
362                    if sc as u8 >= vec_size as u8 {
363                        return Err(ExpressionError::InvalidSwizzleComponent(sc, vec_size));
364                    }
365                }
366                ShaderStages::all()
367            }
368            E::Literal(literal) => {
369                self.validate_literal(literal)?;
370                ShaderStages::all()
371            }
372            E::Constant(_) | E::Override(_) | E::ZeroValue(_) => ShaderStages::all(),
373            E::Compose { ref components, ty } => {
374                validate_compose(
375                    ty,
376                    module.to_ctx(),
377                    components.iter().map(|&handle| info[handle].ty.clone()),
378                )?;
379                ShaderStages::all()
380            }
381            E::FunctionArgument(index) => {
382                if index >= function.arguments.len() as u32 {
383                    return Err(ExpressionError::FunctionArgumentDoesntExist(index));
384                }
385                ShaderStages::all()
386            }
387            E::GlobalVariable(_handle) => ShaderStages::all(),
388            E::LocalVariable(_handle) => ShaderStages::all(),
389            E::Load { pointer } => {
390                match resolver[pointer] {
391                    Ti::Pointer { base, .. }
392                        if self.types[base.index()]
393                            .flags
394                            .contains(TypeFlags::SIZED | TypeFlags::DATA) => {}
395                    Ti::ValuePointer { .. } => {}
396                    ref other => {
397                        log::error!("Loading {:?}", other);
398                        return Err(ExpressionError::InvalidPointerType(pointer));
399                    }
400                }
401                ShaderStages::all()
402            }
403            E::ImageSample {
404                image,
405                sampler,
406                gather,
407                coordinate,
408                array_index,
409                offset,
410                level,
411                depth_ref,
412            } => {
413                // check the validity of expressions
414                let image_ty = Self::global_var_ty(module, function, image)?;
415                let sampler_ty = Self::global_var_ty(module, function, sampler)?;
416
417                let comparison = match module.types[sampler_ty].inner {
418                    Ti::Sampler { comparison } => comparison,
419                    _ => return Err(ExpressionError::ExpectedSamplerType(sampler_ty)),
420                };
421
422                let (class, dim) = match module.types[image_ty].inner {
423                    Ti::Image {
424                        class,
425                        arrayed,
426                        dim,
427                    } => {
428                        // check the array property
429                        if arrayed != array_index.is_some() {
430                            return Err(ExpressionError::InvalidImageArrayIndex);
431                        }
432                        if let Some(expr) = array_index {
433                            match resolver[expr] {
434                                Ti::Scalar(Sc {
435                                    kind: Sk::Sint | Sk::Uint,
436                                    ..
437                                }) => {}
438                                _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
439                            }
440                        }
441                        (class, dim)
442                    }
443                    _ => return Err(ExpressionError::ExpectedImageType(image_ty)),
444                };
445
446                // check sampling and comparison properties
447                let image_depth = match class {
448                    crate::ImageClass::Sampled {
449                        kind: crate::ScalarKind::Float,
450                        multi: false,
451                    } => false,
452                    crate::ImageClass::Sampled {
453                        kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint,
454                        multi: false,
455                    } if gather.is_some() => false,
456                    crate::ImageClass::Depth { multi: false } => true,
457                    _ => return Err(ExpressionError::InvalidImageClass(class)),
458                };
459                if comparison != depth_ref.is_some() || (comparison && !image_depth) {
460                    return Err(ExpressionError::ComparisonSamplingMismatch {
461                        image: class,
462                        sampler: comparison,
463                        has_ref: depth_ref.is_some(),
464                    });
465                }
466
467                // check texture coordinates type
468                let num_components = match dim {
469                    crate::ImageDimension::D1 => 1,
470                    crate::ImageDimension::D2 => 2,
471                    crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3,
472                };
473                match resolver[coordinate] {
474                    Ti::Scalar(Sc {
475                        kind: Sk::Float, ..
476                    }) if num_components == 1 => {}
477                    Ti::Vector {
478                        size,
479                        scalar:
480                            Sc {
481                                kind: Sk::Float, ..
482                            },
483                    } if size as u32 == num_components => {}
484                    _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)),
485                }
486
487                // check constant offset
488                if let Some(const_expr) = offset {
489                    if !global_expr_kind.is_const(const_expr) {
490                        return Err(ExpressionError::InvalidSampleOffsetExprType);
491                    }
492
493                    match *mod_info[const_expr].inner_with(&module.types) {
494                        Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {}
495                        Ti::Vector {
496                            size,
497                            scalar: Sc { kind: Sk::Sint, .. },
498                        } if size as u32 == num_components => {}
499                        _ => {
500                            return Err(ExpressionError::InvalidSampleOffset(dim, const_expr));
501                        }
502                    }
503                }
504
505                // check depth reference type
506                if let Some(expr) = depth_ref {
507                    match resolver[expr] {
508                        Ti::Scalar(Sc {
509                            kind: Sk::Float, ..
510                        }) => {}
511                        _ => return Err(ExpressionError::InvalidDepthReference(expr)),
512                    }
513                    match level {
514                        crate::SampleLevel::Auto | crate::SampleLevel::Zero => {}
515                        _ => return Err(ExpressionError::InvalidDepthSampleLevel),
516                    }
517                }
518
519                if let Some(component) = gather {
520                    match dim {
521                        crate::ImageDimension::D2 | crate::ImageDimension::Cube => {}
522                        crate::ImageDimension::D1 | crate::ImageDimension::D3 => {
523                            return Err(ExpressionError::InvalidGatherDimension(dim))
524                        }
525                    };
526                    let max_component = match class {
527                        crate::ImageClass::Depth { .. } => crate::SwizzleComponent::X,
528                        _ => crate::SwizzleComponent::W,
529                    };
530                    if component > max_component {
531                        return Err(ExpressionError::InvalidGatherComponent(component));
532                    }
533                    match level {
534                        crate::SampleLevel::Zero => {}
535                        _ => return Err(ExpressionError::InvalidGatherLevel),
536                    }
537                }
538
539                // check level properties
540                match level {
541                    crate::SampleLevel::Auto => ShaderStages::FRAGMENT,
542                    crate::SampleLevel::Zero => ShaderStages::all(),
543                    crate::SampleLevel::Exact(expr) => {
544                        match class {
545                            crate::ImageClass::Depth { .. } => match resolver[expr] {
546                                Ti::Scalar(Sc {
547                                    kind: Sk::Sint | Sk::Uint,
548                                    ..
549                                }) => {}
550                                _ => {
551                                    return Err(ExpressionError::InvalidSampleLevelExactType(expr))
552                                }
553                            },
554                            _ => match resolver[expr] {
555                                Ti::Scalar(Sc {
556                                    kind: Sk::Float, ..
557                                }) => {}
558                                _ => {
559                                    return Err(ExpressionError::InvalidSampleLevelExactType(expr))
560                                }
561                            },
562                        }
563                        ShaderStages::all()
564                    }
565                    crate::SampleLevel::Bias(expr) => {
566                        match resolver[expr] {
567                            Ti::Scalar(Sc {
568                                kind: Sk::Float, ..
569                            }) => {}
570                            _ => return Err(ExpressionError::InvalidSampleLevelBiasType(expr)),
571                        }
572                        match class {
573                            crate::ImageClass::Sampled {
574                                kind: Sk::Float,
575                                multi: false,
576                            } => {
577                                if dim == crate::ImageDimension::D1 {
578                                    return Err(ExpressionError::InvalidSampleLevelBiasDimension(
579                                        dim,
580                                    ));
581                                }
582                            }
583                            _ => return Err(ExpressionError::InvalidImageClass(class)),
584                        }
585                        ShaderStages::FRAGMENT
586                    }
587                    crate::SampleLevel::Gradient { x, y } => {
588                        match resolver[x] {
589                            Ti::Scalar(Sc {
590                                kind: Sk::Float, ..
591                            }) if num_components == 1 => {}
592                            Ti::Vector {
593                                size,
594                                scalar:
595                                    Sc {
596                                        kind: Sk::Float, ..
597                                    },
598                            } if size as u32 == num_components => {}
599                            _ => {
600                                return Err(ExpressionError::InvalidSampleLevelGradientType(dim, x))
601                            }
602                        }
603                        match resolver[y] {
604                            Ti::Scalar(Sc {
605                                kind: Sk::Float, ..
606                            }) if num_components == 1 => {}
607                            Ti::Vector {
608                                size,
609                                scalar:
610                                    Sc {
611                                        kind: Sk::Float, ..
612                                    },
613                            } if size as u32 == num_components => {}
614                            _ => {
615                                return Err(ExpressionError::InvalidSampleLevelGradientType(dim, y))
616                            }
617                        }
618                        ShaderStages::all()
619                    }
620                }
621            }
622            E::ImageLoad {
623                image,
624                coordinate,
625                array_index,
626                sample,
627                level,
628            } => {
629                let ty = Self::global_var_ty(module, function, image)?;
630                match module.types[ty].inner {
631                    Ti::Image {
632                        class,
633                        arrayed,
634                        dim,
635                    } => {
636                        match resolver[coordinate].image_storage_coordinates() {
637                            Some(coord_dim) if coord_dim == dim => {}
638                            _ => {
639                                return Err(ExpressionError::InvalidImageCoordinateType(
640                                    dim, coordinate,
641                                ))
642                            }
643                        };
644                        if arrayed != array_index.is_some() {
645                            return Err(ExpressionError::InvalidImageArrayIndex);
646                        }
647                        if let Some(expr) = array_index {
648                            match resolver[expr] {
649                                Ti::Scalar(Sc {
650                                    kind: Sk::Sint | Sk::Uint,
651                                    width: _,
652                                }) => {}
653                                _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
654                            }
655                        }
656
657                        match (sample, class.is_multisampled()) {
658                            (None, false) => {}
659                            (Some(sample), true) => {
660                                if resolver[sample].scalar_kind() != Some(Sk::Sint) {
661                                    return Err(ExpressionError::InvalidImageOtherIndexType(
662                                        sample,
663                                    ));
664                                }
665                            }
666                            _ => {
667                                return Err(ExpressionError::InvalidImageOtherIndex);
668                            }
669                        }
670
671                        match (level, class.is_mipmapped()) {
672                            (None, false) => {}
673                            (Some(level), true) => {
674                                if resolver[level].scalar_kind() != Some(Sk::Sint) {
675                                    return Err(ExpressionError::InvalidImageOtherIndexType(level));
676                                }
677                            }
678                            _ => {
679                                return Err(ExpressionError::InvalidImageOtherIndex);
680                            }
681                        }
682                    }
683                    _ => return Err(ExpressionError::ExpectedImageType(ty)),
684                }
685                ShaderStages::all()
686            }
687            E::ImageQuery { image, query } => {
688                let ty = Self::global_var_ty(module, function, image)?;
689                match module.types[ty].inner {
690                    Ti::Image { class, arrayed, .. } => {
691                        let good = match query {
692                            crate::ImageQuery::NumLayers => arrayed,
693                            crate::ImageQuery::Size { level: None } => true,
694                            crate::ImageQuery::Size { level: Some(_) }
695                            | crate::ImageQuery::NumLevels => class.is_mipmapped(),
696                            crate::ImageQuery::NumSamples => class.is_multisampled(),
697                        };
698                        if !good {
699                            return Err(ExpressionError::InvalidImageClass(class));
700                        }
701                    }
702                    _ => return Err(ExpressionError::ExpectedImageType(ty)),
703                }
704                ShaderStages::all()
705            }
706            E::Unary { op, expr } => {
707                use crate::UnaryOperator as Uo;
708                let inner = &resolver[expr];
709                match (op, inner.scalar_kind()) {
710                    (Uo::Negate, Some(Sk::Float | Sk::Sint))
711                    | (Uo::LogicalNot, Some(Sk::Bool))
712                    | (Uo::BitwiseNot, Some(Sk::Sint | Sk::Uint)) => {}
713                    other => {
714                        log::error!("Op {:?} kind {:?}", op, other);
715                        return Err(ExpressionError::InvalidUnaryOperandType(op, expr));
716                    }
717                }
718                ShaderStages::all()
719            }
            E::Binary { op, left, right } => {
                // Type-check a binary operation. Both operand types are
                // resolved, `op` selects a rule that yields one `good` flag, and
                // a failing rule logs both operands and returns
                // `InvalidBinaryOperandTypes`.
                use crate::BinaryOperator as Bo;
                let left_inner = &resolver[left];
                let right_inner = &resolver[right];
                let good = match op {
                    // `Add`/`Subtract`: integer or float scalars/vectors, or
                    // matrices. Every accepted case requires identical operand
                    // types.
                    Bo::Add | Bo::Subtract => match *left_inner {
                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
                            Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
                            Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
                        },
                        Ti::Matrix { .. } => left_inner == right_inner,
                        _ => false,
                    },
                    // `Divide`/`Modulo`: same as `Add`/`Subtract`, except that
                    // matrices are rejected.
                    Bo::Divide | Bo::Modulo => match *left_inner {
                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
                            Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
                            Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
                        },
                        _ => false,
                    },
                    // `Multiply` allows mixed shapes. Three conditions must hold:
                    //   1. the left operand's scalar kind is integer or float;
                    //   2. the pair of shapes is one of the combinations below;
                    //   3. both sides report the same scalar width.
                    Bo::Multiply => {
                        let kind_allowed = match left_inner.scalar_kind() {
                            Some(Sk::Uint | Sk::Sint | Sk::Float) => true,
                            Some(Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat) | None => false,
                        };
                        let types_match = match (left_inner, right_inner) {
                            // Straight scalar and mixed scalar/vector.
                            (&Ti::Scalar(scalar1), &Ti::Scalar(scalar2))
                            | (
                                &Ti::Vector {
                                    scalar: scalar1, ..
                                },
                                &Ti::Scalar(scalar2),
                            )
                            | (
                                &Ti::Scalar(scalar1),
                                &Ti::Vector {
                                    scalar: scalar2, ..
                                },
                            ) => scalar1 == scalar2,
                            // Scalar/matrix.
                            (
                                &Ti::Scalar(Sc {
                                    kind: Sk::Float, ..
                                }),
                                &Ti::Matrix { .. },
                            )
                            | (
                                &Ti::Matrix { .. },
                                &Ti::Scalar(Sc {
                                    kind: Sk::Float, ..
                                }),
                            ) => true,
                            // Vector/vector.
                            (
                                &Ti::Vector {
                                    size: size1,
                                    scalar: scalar1,
                                },
                                &Ti::Vector {
                                    size: size2,
                                    scalar: scalar2,
                                },
                            ) => scalar1 == scalar2 && size1 == size2,
                            // Matrix * vector.
                            // The float vector's size must equal the matrix's
                            // column count.
                            (
                                &Ti::Matrix { columns, .. },
                                &Ti::Vector {
                                    size,
                                    scalar:
                                        Sc {
                                            kind: Sk::Float, ..
                                        },
                                },
                            ) => columns == size,
                            // Vector * matrix.
                            // The float vector's size must equal the matrix's
                            // row count.
                            (
                                &Ti::Vector {
                                    size,
                                    scalar:
                                        Sc {
                                            kind: Sk::Float, ..
                                        },
                                },
                                &Ti::Matrix { rows, .. },
                            ) => size == rows,
                            // Matrix * matrix: the left column count must equal
                            // the right row count.
                            (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => {
                                columns == rows
                            }
                            _ => false,
                        };
                        // A type with no scalar width counts as width 0 here.
                        let left_width = left_inner.scalar_width().unwrap_or(0);
                        let right_width = right_inner.scalar_width().unwrap_or(0);
                        kind_allowed && types_match && left_width == right_width
                    }
                    // `Equal`/`NotEqual`: the left type must be sized and both
                    // sides must have identical types.
                    Bo::Equal | Bo::NotEqual => left_inner.is_sized() && left_inner == right_inner,
                    // Ordering comparisons: integer or float scalars/vectors of
                    // identical type.
                    Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => {
                        match *left_inner {
                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
                                Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
                                Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
                            },
                            ref other => {
                                log::error!("Op {:?} left type {:?}", op, other);
                                false
                            }
                        }
                    }
                    // `LogicalAnd`/`LogicalOr`: boolean scalars or vectors of
                    // identical type.
                    Bo::LogicalAnd | Bo::LogicalOr => match *left_inner {
                        Ti::Scalar(Sc { kind: Sk::Bool, .. })
                        | Ti::Vector {
                            scalar: Sc { kind: Sk::Bool, .. },
                            ..
                        } => left_inner == right_inner,
                        ref other => {
                            log::error!("Op {:?} left type {:?}", op, other);
                            false
                        }
                    },
                    // `And`/`InclusiveOr`: boolean or integer scalars/vectors of
                    // identical type.
                    Bo::And | Bo::InclusiveOr => match *left_inner {
                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
                            Sk::Bool | Sk::Sint | Sk::Uint => left_inner == right_inner,
                            Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
                        },
                        ref other => {
                            log::error!("Op {:?} left type {:?}", op, other);
                            false
                        }
                    },
                    // `ExclusiveOr`: integer scalars/vectors of identical type
                    // only. Unlike `And`/`InclusiveOr`, booleans are rejected.
                    Bo::ExclusiveOr => match *left_inner {
                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
                            Sk::Sint | Sk::Uint => left_inner == right_inner,
                            Sk::Bool | Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
                        },
                        ref other => {
                            log::error!("Op {:?} left type {:?}", op, other);
                            false
                        }
                    },
                    // Shifts: the base must be an integer scalar or vector, and
                    // the shift amount a `Uint` value of the same shape (scalar
                    // with scalar, or vector with a vector of equal size). The
                    // widths of the two operands are not compared.
                    Bo::ShiftLeft | Bo::ShiftRight => {
                        // `Err(())` marks an unusable base. `Sc::BOOL` is only a
                        // placeholder that makes the final `match` reject it.
                        let (base_size, base_scalar) = match *left_inner {
                            Ti::Scalar(scalar) => (Ok(None), scalar),
                            Ti::Vector { size, scalar } => (Ok(Some(size)), scalar),
                            ref other => {
                                log::error!("Op {:?} base type {:?}", op, other);
                                (Err(()), Sc::BOOL)
                            }
                        };
                        // `Err(())` marks a shift amount that is not a `Uint`
                        // scalar or vector.
                        let shift_size = match *right_inner {
                            Ti::Scalar(Sc { kind: Sk::Uint, .. }) => Ok(None),
                            Ti::Vector {
                                size,
                                scalar: Sc { kind: Sk::Uint, .. },
                            } => Ok(Some(size)),
                            ref other => {
                                log::error!("Op {:?} shift type {:?}", op, other);
                                Err(())
                            }
                        };
                        // `is_ok()` keeps two `Err(())` results from comparing
                        // equal.
                        match base_scalar.kind {
                            Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size,
                            Sk::Float | Sk::AbstractInt | Sk::AbstractFloat | Sk::Bool => false,
                        }
                    }
                };
                if !good {
                    // Log both operands (expression and resolved type) before
                    // reporting the error.
                    log::error!(
                        "Left: {:?} of type {:?}",
                        function.expressions[left],
                        left_inner
                    );
                    log::error!(
                        "Right: {:?} of type {:?}",
                        function.expressions[right],
                        right_inner
                    );
                    return Err(ExpressionError::InvalidBinaryOperandTypes {
                        op,
                        lhs_expr: left,
                        lhs_type: left_inner.clone(),
                        rhs_expr: right,
                        rhs_type: right_inner.clone(),
                    });
                }
                ShaderStages::all()
            }
906            E::Select {
907                condition,
908                accept,
909                reject,
910            } => {
911                let accept_inner = &resolver[accept];
912                let reject_inner = &resolver[reject];
913                let condition_ty = &resolver[condition];
914                let condition_good = match *condition_ty {
915                    Ti::Scalar(Sc {
916                        kind: Sk::Bool,
917                        width: _,
918                    }) => {
919                        // When `condition` is a single boolean, `accept` and
920                        // `reject` can be vectors or scalars.
921                        match *accept_inner {
922                            Ti::Scalar { .. } | Ti::Vector { .. } => true,
923                            _ => false,
924                        }
925                    }
926                    Ti::Vector {
927                        size,
928                        scalar:
929                            Sc {
930                                kind: Sk::Bool,
931                                width: _,
932                            },
933                    } => match *accept_inner {
934                        Ti::Vector {
935                            size: other_size, ..
936                        } => size == other_size,
937                        _ => false,
938                    },
939                    _ => false,
940                };
941                if accept_inner != reject_inner {
942                    return Err(ExpressionError::SelectValuesTypeMismatch {
943                        accept: accept_inner.clone(),
944                        reject: reject_inner.clone(),
945                    });
946                }
947                if !condition_good {
948                    return Err(ExpressionError::SelectConditionNotABool {
949                        actual: condition_ty.clone(),
950                    });
951                }
952                ShaderStages::all()
953            }
954            E::Derivative { expr, .. } => {
955                match resolver[expr] {
956                    Ti::Scalar(Sc {
957                        kind: Sk::Float, ..
958                    })
959                    | Ti::Vector {
960                        scalar:
961                            Sc {
962                                kind: Sk::Float, ..
963                            },
964                        ..
965                    } => {}
966                    _ => return Err(ExpressionError::InvalidDerivative),
967                }
968                ShaderStages::FRAGMENT
969            }
970            E::Relational { fun, argument } => {
971                use crate::RelationalFunction as Rf;
972                let argument_inner = &resolver[argument];
973                match fun {
974                    Rf::All | Rf::Any => match *argument_inner {
975                        Ti::Vector {
976                            scalar: Sc { kind: Sk::Bool, .. },
977                            ..
978                        } => {}
979                        ref other => {
980                            log::error!("All/Any of type {:?}", other);
981                            return Err(ExpressionError::InvalidBooleanVector(argument));
982                        }
983                    },
984                    Rf::IsNan | Rf::IsInf => match *argument_inner {
985                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
986                            if scalar.kind == Sk::Float => {}
987                        ref other => {
988                            log::error!("Float test of type {:?}", other);
989                            return Err(ExpressionError::InvalidFloatArgument(argument));
990                        }
991                    },
992                }
993                ShaderStages::all()
994            }
995            E::Math {
996                fun,
997                arg,
998                arg1,
999                arg2,
1000                arg3,
1001            } => {
1002                use crate::MathFunction as Mf;
1003
1004                let resolve = |arg| &resolver[arg];
1005                let arg_ty = resolve(arg);
1006                let arg1_ty = arg1.map(resolve);
1007                let arg2_ty = arg2.map(resolve);
1008                let arg3_ty = arg3.map(resolve);
1009                match fun {
1010                    Mf::Abs => {
1011                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1012                            return Err(ExpressionError::WrongArgumentCount(fun));
1013                        }
1014                        let good = match *arg_ty {
1015                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
1016                                scalar.kind != Sk::Bool
1017                            }
1018                            _ => false,
1019                        };
1020                        if !good {
1021                            return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
1022                        }
1023                    }
1024                    Mf::Min | Mf::Max => {
1025                        let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1026                            (Some(ty1), None, None) => ty1,
1027                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1028                        };
1029                        let good = match *arg_ty {
1030                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
1031                                scalar.kind != Sk::Bool
1032                            }
1033                            _ => false,
1034                        };
1035                        if !good {
1036                            return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
1037                        }
1038                        if arg1_ty != arg_ty {
1039                            return Err(ExpressionError::InvalidArgumentType(
1040                                fun,
1041                                1,
1042                                arg1.unwrap(),
1043                            ));
1044                        }
1045                    }
1046                    Mf::Clamp => {
1047                        let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1048                            (Some(ty1), Some(ty2), None) => (ty1, ty2),
1049                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1050                        };
1051                        let good = match *arg_ty {
1052                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
1053                                scalar.kind != Sk::Bool
1054                            }
1055                            _ => false,
1056                        };
1057                        if !good {
1058                            return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
1059                        }
1060                        if arg1_ty != arg_ty {
1061                            return Err(ExpressionError::InvalidArgumentType(
1062                                fun,
1063                                1,
1064                                arg1.unwrap(),
1065                            ));
1066                        }
1067                        if arg2_ty != arg_ty {
1068                            return Err(ExpressionError::InvalidArgumentType(
1069                                fun,
1070                                2,
1071                                arg2.unwrap(),
1072                            ));
1073                        }
1074                    }
1075                    Mf::Saturate
1076                    | Mf::Cos
1077                    | Mf::Cosh
1078                    | Mf::Sin
1079                    | Mf::Sinh
1080                    | Mf::Tan
1081                    | Mf::Tanh
1082                    | Mf::Acos
1083                    | Mf::Asin
1084                    | Mf::Atan
1085                    | Mf::Asinh
1086                    | Mf::Acosh
1087                    | Mf::Atanh
1088                    | Mf::Radians
1089                    | Mf::Degrees
1090                    | Mf::Ceil
1091                    | Mf::Floor
1092                    | Mf::Round
1093                    | Mf::Fract
1094                    | Mf::Trunc
1095                    | Mf::Exp
1096                    | Mf::Exp2
1097                    | Mf::Log
1098                    | Mf::Log2
1099                    | Mf::Length
1100                    | Mf::Sqrt
1101                    | Mf::InverseSqrt => {
1102                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1103                            return Err(ExpressionError::WrongArgumentCount(fun));
1104                        }
1105                        match *arg_ty {
1106                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1107                                if scalar.kind == Sk::Float => {}
1108                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1109                        }
1110                    }
1111                    Mf::Sign => {
1112                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1113                            return Err(ExpressionError::WrongArgumentCount(fun));
1114                        }
1115                        match *arg_ty {
1116                            Ti::Scalar(Sc {
1117                                kind: Sk::Float | Sk::Sint,
1118                                ..
1119                            })
1120                            | Ti::Vector {
1121                                scalar:
1122                                    Sc {
1123                                        kind: Sk::Float | Sk::Sint,
1124                                        ..
1125                                    },
1126                                ..
1127                            } => {}
1128                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1129                        }
1130                    }
1131                    Mf::Atan2 | Mf::Pow | Mf::Distance | Mf::Step => {
1132                        let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1133                            (Some(ty1), None, None) => ty1,
1134                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1135                        };
1136                        match *arg_ty {
1137                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1138                                if scalar.kind == Sk::Float => {}
1139                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1140                        }
1141                        if arg1_ty != arg_ty {
1142                            return Err(ExpressionError::InvalidArgumentType(
1143                                fun,
1144                                1,
1145                                arg1.unwrap(),
1146                            ));
1147                        }
1148                    }
1149                    Mf::Modf | Mf::Frexp => {
1150                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1151                            return Err(ExpressionError::WrongArgumentCount(fun));
1152                        }
1153                        if !matches!(*arg_ty,
1154                                     Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1155                                     if scalar.kind == Sk::Float)
1156                        {
1157                            return Err(ExpressionError::InvalidArgumentType(fun, 1, arg));
1158                        }
1159                    }
1160                    Mf::Ldexp => {
1161                        let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1162                            (Some(ty1), None, None) => ty1,
1163                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1164                        };
1165                        let size0 = match *arg_ty {
1166                            Ti::Scalar(Sc {
1167                                kind: Sk::Float, ..
1168                            }) => None,
1169                            Ti::Vector {
1170                                scalar:
1171                                    Sc {
1172                                        kind: Sk::Float, ..
1173                                    },
1174                                size,
1175                            } => Some(size),
1176                            _ => {
1177                                return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
1178                            }
1179                        };
1180                        let good = match *arg1_ty {
1181                            Ti::Scalar(Sc { kind: Sk::Sint, .. }) if size0.is_none() => true,
1182                            Ti::Vector {
1183                                size,
1184                                scalar: Sc { kind: Sk::Sint, .. },
1185                            } if Some(size) == size0 => true,
1186                            _ => false,
1187                        };
1188                        if !good {
1189                            return Err(ExpressionError::InvalidArgumentType(
1190                                fun,
1191                                1,
1192                                arg1.unwrap(),
1193                            ));
1194                        }
1195                    }
1196                    Mf::Dot => {
1197                        let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1198                            (Some(ty1), None, None) => ty1,
1199                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1200                        };
1201                        match *arg_ty {
1202                            Ti::Vector {
1203                                scalar:
1204                                    Sc {
1205                                        kind: Sk::Float | Sk::Sint | Sk::Uint,
1206                                        ..
1207                                    },
1208                                ..
1209                            } => {}
1210                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1211                        }
1212                        if arg1_ty != arg_ty {
1213                            return Err(ExpressionError::InvalidArgumentType(
1214                                fun,
1215                                1,
1216                                arg1.unwrap(),
1217                            ));
1218                        }
1219                    }
1220                    Mf::Outer | Mf::Reflect => {
1221                        let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1222                            (Some(ty1), None, None) => ty1,
1223                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1224                        };
1225                        match *arg_ty {
1226                            Ti::Vector {
1227                                scalar:
1228                                    Sc {
1229                                        kind: Sk::Float, ..
1230                                    },
1231                                ..
1232                            } => {}
1233                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1234                        }
1235                        if arg1_ty != arg_ty {
1236                            return Err(ExpressionError::InvalidArgumentType(
1237                                fun,
1238                                1,
1239                                arg1.unwrap(),
1240                            ));
1241                        }
1242                    }
1243                    Mf::Cross => {
1244                        let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1245                            (Some(ty1), None, None) => ty1,
1246                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1247                        };
1248                        match *arg_ty {
1249                            Ti::Vector {
1250                                scalar:
1251                                    Sc {
1252                                        kind: Sk::Float, ..
1253                                    },
1254                                size: crate::VectorSize::Tri,
1255                            } => {}
1256                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1257                        }
1258                        if arg1_ty != arg_ty {
1259                            return Err(ExpressionError::InvalidArgumentType(
1260                                fun,
1261                                1,
1262                                arg1.unwrap(),
1263                            ));
1264                        }
1265                    }
1266                    Mf::Refract => {
1267                        let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1268                            (Some(ty1), Some(ty2), None) => (ty1, ty2),
1269                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1270                        };
1271
1272                        match *arg_ty {
1273                            Ti::Vector {
1274                                scalar:
1275                                    Sc {
1276                                        kind: Sk::Float, ..
1277                                    },
1278                                ..
1279                            } => {}
1280                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1281                        }
1282
1283                        if arg1_ty != arg_ty {
1284                            return Err(ExpressionError::InvalidArgumentType(
1285                                fun,
1286                                1,
1287                                arg1.unwrap(),
1288                            ));
1289                        }
1290
1291                        match (arg_ty, arg2_ty) {
1292                            (
1293                                &Ti::Vector {
1294                                    scalar:
1295                                        Sc {
1296                                            width: vector_width,
1297                                            ..
1298                                        },
1299                                    ..
1300                                },
1301                                &Ti::Scalar(Sc {
1302                                    width: scalar_width,
1303                                    kind: Sk::Float,
1304                                }),
1305                            ) if vector_width == scalar_width => {}
1306                            _ => {
1307                                return Err(ExpressionError::InvalidArgumentType(
1308                                    fun,
1309                                    2,
1310                                    arg2.unwrap(),
1311                                ))
1312                            }
1313                        }
1314                    }
1315                    Mf::Normalize => {
1316                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1317                            return Err(ExpressionError::WrongArgumentCount(fun));
1318                        }
1319                        match *arg_ty {
1320                            Ti::Vector {
1321                                scalar:
1322                                    Sc {
1323                                        kind: Sk::Float, ..
1324                                    },
1325                                ..
1326                            } => {}
1327                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1328                        }
1329                    }
1330                    Mf::FaceForward | Mf::Fma | Mf::SmoothStep => {
1331                        let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1332                            (Some(ty1), Some(ty2), None) => (ty1, ty2),
1333                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1334                        };
1335                        match *arg_ty {
1336                            Ti::Scalar(Sc {
1337                                kind: Sk::Float, ..
1338                            })
1339                            | Ti::Vector {
1340                                scalar:
1341                                    Sc {
1342                                        kind: Sk::Float, ..
1343                                    },
1344                                ..
1345                            } => {}
1346                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1347                        }
1348                        if arg1_ty != arg_ty {
1349                            return Err(ExpressionError::InvalidArgumentType(
1350                                fun,
1351                                1,
1352                                arg1.unwrap(),
1353                            ));
1354                        }
1355                        if arg2_ty != arg_ty {
1356                            return Err(ExpressionError::InvalidArgumentType(
1357                                fun,
1358                                2,
1359                                arg2.unwrap(),
1360                            ));
1361                        }
1362                    }
1363                    Mf::Mix => {
1364                        let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1365                            (Some(ty1), Some(ty2), None) => (ty1, ty2),
1366                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1367                        };
1368                        let arg_width = match *arg_ty {
1369                            Ti::Scalar(Sc {
1370                                kind: Sk::Float,
1371                                width,
1372                            })
1373                            | Ti::Vector {
1374                                scalar:
1375                                    Sc {
1376                                        kind: Sk::Float,
1377                                        width,
1378                                    },
1379                                ..
1380                            } => width,
1381                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1382                        };
1383                        if arg1_ty != arg_ty {
1384                            return Err(ExpressionError::InvalidArgumentType(
1385                                fun,
1386                                1,
1387                                arg1.unwrap(),
1388                            ));
1389                        }
1390                        // the last argument can always be a scalar
1391                        match *arg2_ty {
1392                            Ti::Scalar(Sc {
1393                                kind: Sk::Float,
1394                                width,
1395                            }) if width == arg_width => {}
1396                            _ if arg2_ty == arg_ty => {}
1397                            _ => {
1398                                return Err(ExpressionError::InvalidArgumentType(
1399                                    fun,
1400                                    2,
1401                                    arg2.unwrap(),
1402                                ));
1403                            }
1404                        }
1405                    }
1406                    Mf::Inverse | Mf::Determinant => {
1407                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1408                            return Err(ExpressionError::WrongArgumentCount(fun));
1409                        }
1410                        let good = match *arg_ty {
1411                            Ti::Matrix { columns, rows, .. } => columns == rows,
1412                            _ => false,
1413                        };
1414                        if !good {
1415                            return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
1416                        }
1417                    }
1418                    Mf::Transpose => {
1419                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1420                            return Err(ExpressionError::WrongArgumentCount(fun));
1421                        }
1422                        match *arg_ty {
1423                            Ti::Matrix { .. } => {}
1424                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1425                        }
1426                    }
1427                    Mf::QuantizeToF16 => {
1428                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1429                            return Err(ExpressionError::WrongArgumentCount(fun));
1430                        }
1431                        match *arg_ty {
1432                            Ti::Scalar(Sc {
1433                                kind: Sk::Float,
1434                                width: 4,
1435                            })
1436                            | Ti::Vector {
1437                                scalar:
1438                                    Sc {
1439                                        kind: Sk::Float,
1440                                        width: 4,
1441                                    },
1442                                ..
1443                            } => {}
1444                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1445                        }
1446                    }
1447                    // Remove once fixed https://github.com/gfx-rs/wgpu/issues/5276
1448                    Mf::CountLeadingZeros
1449                    | Mf::CountTrailingZeros
1450                    | Mf::CountOneBits
1451                    | Mf::ReverseBits
1452                    | Mf::FirstLeadingBit
1453                    | Mf::FirstTrailingBit => {
1454                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1455                            return Err(ExpressionError::WrongArgumentCount(fun));
1456                        }
1457                        match *arg_ty {
1458                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
1459                                Sk::Sint | Sk::Uint => {
1460                                    if scalar.width != 4 {
1461                                        return Err(ExpressionError::UnsupportedWidth(
1462                                            fun,
1463                                            scalar.kind,
1464                                            scalar.width,
1465                                        ));
1466                                    }
1467                                }
1468                                _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1469                            },
1470                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1471                        }
1472                    }
1473                    Mf::InsertBits => {
1474                        let (arg1_ty, arg2_ty, arg3_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1475                            (Some(ty1), Some(ty2), Some(ty3)) => (ty1, ty2, ty3),
1476                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1477                        };
1478                        match *arg_ty {
1479                            Ti::Scalar(Sc {
1480                                kind: Sk::Sint | Sk::Uint,
1481                                ..
1482                            })
1483                            | Ti::Vector {
1484                                scalar:
1485                                    Sc {
1486                                        kind: Sk::Sint | Sk::Uint,
1487                                        ..
1488                                    },
1489                                ..
1490                            } => {}
1491                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1492                        }
1493                        if arg1_ty != arg_ty {
1494                            return Err(ExpressionError::InvalidArgumentType(
1495                                fun,
1496                                1,
1497                                arg1.unwrap(),
1498                            ));
1499                        }
1500                        match *arg2_ty {
1501                            Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1502                            _ => {
1503                                return Err(ExpressionError::InvalidArgumentType(
1504                                    fun,
1505                                    2,
1506                                    arg2.unwrap(),
1507                                ))
1508                            }
1509                        }
1510                        match *arg3_ty {
1511                            Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1512                            _ => {
1513                                return Err(ExpressionError::InvalidArgumentType(
1514                                    fun,
1515                                    2,
1516                                    arg3.unwrap(),
1517                                ))
1518                            }
1519                        }
1520                        // Remove once fixed https://github.com/gfx-rs/wgpu/issues/5276
1521                        for &arg in [arg_ty, arg1_ty, arg2_ty, arg3_ty].iter() {
1522                            match *arg {
1523                                Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
1524                                    if scalar.width != 4 {
1525                                        return Err(ExpressionError::UnsupportedWidth(
1526                                            fun,
1527                                            scalar.kind,
1528                                            scalar.width,
1529                                        ));
1530                                    }
1531                                }
1532                                _ => {}
1533                            }
1534                        }
1535                    }
1536                    Mf::ExtractBits => {
1537                        let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1538                            (Some(ty1), Some(ty2), None) => (ty1, ty2),
1539                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1540                        };
1541                        match *arg_ty {
1542                            Ti::Scalar(Sc {
1543                                kind: Sk::Sint | Sk::Uint,
1544                                ..
1545                            })
1546                            | Ti::Vector {
1547                                scalar:
1548                                    Sc {
1549                                        kind: Sk::Sint | Sk::Uint,
1550                                        ..
1551                                    },
1552                                ..
1553                            } => {}
1554                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1555                        }
1556                        match *arg1_ty {
1557                            Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1558                            _ => {
1559                                return Err(ExpressionError::InvalidArgumentType(
1560                                    fun,
1561                                    2,
1562                                    arg1.unwrap(),
1563                                ))
1564                            }
1565                        }
1566                        match *arg2_ty {
1567                            Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1568                            _ => {
1569                                return Err(ExpressionError::InvalidArgumentType(
1570                                    fun,
1571                                    2,
1572                                    arg2.unwrap(),
1573                                ))
1574                            }
1575                        }
1576                        // Remove once fixed https://github.com/gfx-rs/wgpu/issues/5276
1577                        for &arg in [arg_ty, arg1_ty, arg2_ty].iter() {
1578                            match *arg {
1579                                Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
1580                                    if scalar.width != 4 {
1581                                        return Err(ExpressionError::UnsupportedWidth(
1582                                            fun,
1583                                            scalar.kind,
1584                                            scalar.width,
1585                                        ));
1586                                    }
1587                                }
1588                                _ => {}
1589                            }
1590                        }
1591                    }
1592                    Mf::Pack2x16unorm | Mf::Pack2x16snorm | Mf::Pack2x16float => {
1593                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1594                            return Err(ExpressionError::WrongArgumentCount(fun));
1595                        }
1596                        match *arg_ty {
1597                            Ti::Vector {
1598                                size: crate::VectorSize::Bi,
1599                                scalar:
1600                                    Sc {
1601                                        kind: Sk::Float, ..
1602                                    },
1603                            } => {}
1604                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1605                        }
1606                    }
1607                    Mf::Pack4x8snorm | Mf::Pack4x8unorm => {
1608                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1609                            return Err(ExpressionError::WrongArgumentCount(fun));
1610                        }
1611                        match *arg_ty {
1612                            Ti::Vector {
1613                                size: crate::VectorSize::Quad,
1614                                scalar:
1615                                    Sc {
1616                                        kind: Sk::Float, ..
1617                                    },
1618                            } => {}
1619                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1620                        }
1621                    }
1622                    mf @ (Mf::Pack4xI8 | Mf::Pack4xU8) => {
1623                        let scalar_kind = match mf {
1624                            Mf::Pack4xI8 => Sk::Sint,
1625                            Mf::Pack4xU8 => Sk::Uint,
1626                            _ => unreachable!(),
1627                        };
1628                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1629                            return Err(ExpressionError::WrongArgumentCount(fun));
1630                        }
1631                        match *arg_ty {
1632                            Ti::Vector {
1633                                size: crate::VectorSize::Quad,
1634                                scalar: Sc { kind, .. },
1635                            } if kind == scalar_kind => {}
1636                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1637                        }
1638                    }
1639                    Mf::Unpack2x16float
1640                    | Mf::Unpack2x16snorm
1641                    | Mf::Unpack2x16unorm
1642                    | Mf::Unpack4x8snorm
1643                    | Mf::Unpack4x8unorm
1644                    | Mf::Unpack4xI8
1645                    | Mf::Unpack4xU8 => {
1646                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1647                            return Err(ExpressionError::WrongArgumentCount(fun));
1648                        }
1649                        match *arg_ty {
1650                            Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1651                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1652                        }
1653                    }
1654                }
1655                ShaderStages::all()
1656            }
1657            E::As {
1658                expr,
1659                kind,
1660                convert,
1661            } => {
1662                let mut base_scalar = match resolver[expr] {
1663                    crate::TypeInner::Scalar(scalar) | crate::TypeInner::Vector { scalar, .. } => {
1664                        scalar
1665                    }
1666                    crate::TypeInner::Matrix { scalar, .. } => scalar,
1667                    _ => return Err(ExpressionError::InvalidCastArgument),
1668                };
1669                base_scalar.kind = kind;
1670                if let Some(width) = convert {
1671                    base_scalar.width = width;
1672                }
1673                if self.check_width(base_scalar).is_err() {
1674                    return Err(ExpressionError::InvalidCastArgument);
1675                }
1676                ShaderStages::all()
1677            }
1678            E::CallResult(function) => mod_info.functions[function.index()].available_stages,
1679            E::AtomicResult { .. } => {
1680                // These expressions are validated when we check the `Atomic` statement
1681                // that refers to them, because we have all the information we need at
1682                // that point. The checks driven by `Validator::needs_visit` ensure
1683                // that this expression is indeed visited by one `Atomic` statement.
1684                ShaderStages::all()
1685            }
1686            E::WorkGroupUniformLoadResult { ty } => {
1687                if self.types[ty.index()]
1688                    .flags
1689                    // Sized | Constructible is exactly the types currently supported by
1690                    // WorkGroupUniformLoad
1691                    .contains(TypeFlags::SIZED | TypeFlags::CONSTRUCTIBLE)
1692                {
1693                    ShaderStages::COMPUTE
1694                } else {
1695                    return Err(ExpressionError::InvalidWorkGroupUniformLoadResultType(ty));
1696                }
1697            }
1698            E::ArrayLength(expr) => match resolver[expr] {
1699                Ti::Pointer { base, .. } => {
1700                    let base_ty = &resolver.types[base];
1701                    if let Ti::Array {
1702                        size: crate::ArraySize::Dynamic,
1703                        ..
1704                    } = base_ty.inner
1705                    {
1706                        ShaderStages::all()
1707                    } else {
1708                        return Err(ExpressionError::InvalidArrayType(expr));
1709                    }
1710                }
1711                ref other => {
1712                    log::error!("Array length of {:?}", other);
1713                    return Err(ExpressionError::InvalidArrayType(expr));
1714                }
1715            },
1716            E::RayQueryProceedResult => ShaderStages::all(),
1717            E::RayQueryGetIntersection {
1718                query,
1719                committed: _,
1720            } => match resolver[query] {
1721                Ti::Pointer {
1722                    base,
1723                    space: crate::AddressSpace::Function,
1724                } => match resolver.types[base].inner {
1725                    Ti::RayQuery => ShaderStages::all(),
1726                    ref other => {
1727                        log::error!("Intersection result of a pointer to {:?}", other);
1728                        return Err(ExpressionError::InvalidRayQueryType(query));
1729                    }
1730                },
1731                ref other => {
1732                    log::error!("Intersection result of {:?}", other);
1733                    return Err(ExpressionError::InvalidRayQueryType(query));
1734                }
1735            },
1736            E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages,
1737        };
1738        Ok(stages)
1739    }
1740
1741    fn global_var_ty(
1742        module: &crate::Module,
1743        function: &crate::Function,
1744        expr: Handle<crate::Expression>,
1745    ) -> Result<Handle<crate::Type>, ExpressionError> {
1746        use crate::Expression as Ex;
1747
1748        match function.expressions[expr] {
1749            Ex::GlobalVariable(var_handle) => Ok(module.global_variables[var_handle].ty),
1750            Ex::FunctionArgument(i) => Ok(function.arguments[i as usize].ty),
1751            Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
1752                match function.expressions[base] {
1753                    Ex::GlobalVariable(var_handle) => {
1754                        let array_ty = module.global_variables[var_handle].ty;
1755
1756                        match module.types[array_ty].inner {
1757                            crate::TypeInner::BindingArray { base, .. } => Ok(base),
1758                            _ => Err(ExpressionError::ExpectedBindingArrayType(array_ty)),
1759                        }
1760                    }
1761                    _ => Err(ExpressionError::ExpectedGlobalVariable),
1762                }
1763            }
1764            _ => Err(ExpressionError::ExpectedGlobalVariable),
1765        }
1766    }
1767
1768    pub fn validate_literal(&self, literal: crate::Literal) -> Result<(), LiteralError> {
1769        self.check_width(literal.scalar())?;
1770        check_literal_value(literal)?;
1771
1772        Ok(())
1773    }
1774}
1775
1776pub fn check_literal_value(literal: crate::Literal) -> Result<(), LiteralError> {
1777    let is_nan = match literal {
1778        crate::Literal::F64(v) => v.is_nan(),
1779        crate::Literal::F32(v) => v.is_nan(),
1780        _ => false,
1781    };
1782    if is_nan {
1783        return Err(LiteralError::NaN);
1784    }
1785
1786    let is_infinite = match literal {
1787        crate::Literal::F64(v) => v.is_infinite(),
1788        crate::Literal::F32(v) => v.is_infinite(),
1789        _ => false,
1790    };
1791    if is_infinite {
1792        return Err(LiteralError::Infinity);
1793    }
1794
1795    Ok(())
1796}
1797
/// Test helper: build a module containing one function whose body emits
/// `expr`, then validate it with `ValidationFlags::EXPRESSIONS` and `caps`.
///
/// Returns the validator's result unchanged, so callers may expect success
/// or failure.
#[cfg(test)]
fn validate_with_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    use crate::span::Span;

    let mut function = crate::Function::default();
    function.expressions.append(expr, Span::default());

    // Emit every expression in the arena, so the expression is in scope.
    let emitted = function.expressions.range_from(0);
    function
        .body
        .push(crate::Statement::Emit(emitted), Span::default());

    let mut module = crate::Module::default();
    module.functions.append(function, Span::default());

    let mut validator = super::Validator::new(super::ValidationFlags::EXPRESSIONS, caps);
    validator.validate(&module)
}
1820
/// Test helper: build a module whose global expression arena holds only
/// `expr`, then validate it with `ValidationFlags::CONSTANTS` and `caps`.
///
/// Returns the validator's result unchanged, so callers may expect success
/// or failure.
#[cfg(test)]
fn validate_with_const_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    let mut module = crate::Module::default();
    module
        .global_expressions
        .append(expr, crate::span::Span::default());

    let mut validator = super::Validator::new(super::ValidationFlags::CONSTANTS, caps);
    validator.validate(&module)
}
1836
/// An F64 literal in a function's expression arena is rejected unless
/// `Capabilities::FLOAT64` is enabled.
#[test]
fn f64_runtime_literals() {
    let make_literal = || crate::Expression::Literal(crate::Literal::F64(0.57721_56649));

    // Default capabilities: the literal's width check fails.
    let without_f64 = validate_with_expression(make_literal(), super::Capabilities::default());
    let error = without_f64.unwrap_err().into_inner();
    assert!(matches!(
        error,
        crate::valid::ValidationError::Function {
            source: super::FunctionError::Expression {
                source: ExpressionError::Literal(LiteralError::Width(
                    super::r#type::WidthError::MissingCapability {
                        name: "f64",
                        flag: "FLOAT64",
                    }
                )),
                ..
            },
            ..
        }
    ));

    // Opting into FLOAT64 makes the same module valid.
    let caps = super::Capabilities::default() | super::Capabilities::FLOAT64;
    let with_f64 = validate_with_expression(make_literal(), caps);
    assert!(with_f64.is_ok());
}
1867
/// An F64 literal in a module's global expression arena is rejected unless
/// `Capabilities::FLOAT64` is enabled.
#[test]
fn f64_const_literals() {
    let make_literal = || crate::Expression::Literal(crate::Literal::F64(0.57721_56649));

    // Default capabilities: constant-expression validation reports the width error.
    let error = validate_with_const_expression(make_literal(), super::Capabilities::default())
        .unwrap_err()
        .into_inner();
    assert!(matches!(
        error,
        crate::valid::ValidationError::ConstExpression {
            source: ConstExpressionError::Literal(LiteralError::Width(
                super::r#type::WidthError::MissingCapability {
                    name: "f64",
                    flag: "FLOAT64",
                }
            )),
            ..
        }
    ));

    // Opting into FLOAT64 makes the same module valid.
    let caps = super::Capabilities::default() | super::Capabilities::FLOAT64;
    assert!(validate_with_const_expression(make_literal(), caps).is_ok());
}