naga/valid/
expression.rs

1use super::{compose::validate_compose, FunctionInfo, ModuleInfo, ShaderStages, TypeFlags};
2use crate::arena::UniqueArena;
3
4use crate::{
5    arena::Handle,
6    proc::{IndexableLengthError, ResolveError},
7};
8
9#[derive(Clone, Debug, thiserror::Error)]
10#[cfg_attr(test, derive(PartialEq))]
11pub enum ExpressionError {
12    #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")]
13    NotInScope,
14    #[error("Base type {0:?} is not compatible with this expression")]
15    InvalidBaseType(Handle<crate::Expression>),
16    #[error("Accessing with index {0:?} can't be done")]
17    InvalidIndexType(Handle<crate::Expression>),
18    #[error("Accessing {0:?} via a negative index is invalid")]
19    NegativeIndex(Handle<crate::Expression>),
20    #[error("Accessing index {1} is out of {0:?} bounds")]
21    IndexOutOfBounds(Handle<crate::Expression>, u32),
22    #[error("Function argument {0:?} doesn't exist")]
23    FunctionArgumentDoesntExist(u32),
24    #[error("Loading of {0:?} can't be done")]
25    InvalidPointerType(Handle<crate::Expression>),
26    #[error("Array length of {0:?} can't be done")]
27    InvalidArrayType(Handle<crate::Expression>),
28    #[error("Get intersection of {0:?} can't be done")]
29    InvalidRayQueryType(Handle<crate::Expression>),
30    #[error("Splatting {0:?} can't be done")]
31    InvalidSplatType(Handle<crate::Expression>),
32    #[error("Swizzling {0:?} can't be done")]
33    InvalidVectorType(Handle<crate::Expression>),
34    #[error("Swizzle component {0:?} is outside of vector size {1:?}")]
35    InvalidSwizzleComponent(crate::SwizzleComponent, crate::VectorSize),
36    #[error(transparent)]
37    Compose(#[from] super::ComposeError),
38    #[error(transparent)]
39    IndexableLength(#[from] IndexableLengthError),
40    #[error("Operation {0:?} can't work with {1:?}")]
41    InvalidUnaryOperandType(crate::UnaryOperator, Handle<crate::Expression>),
42    #[error("Operation {0:?} can't work with {1:?} and {2:?}")]
43    InvalidBinaryOperandTypes(
44        crate::BinaryOperator,
45        Handle<crate::Expression>,
46        Handle<crate::Expression>,
47    ),
48    #[error("Selecting is not possible")]
49    InvalidSelectTypes,
50    #[error("Relational argument {0:?} is not a boolean vector")]
51    InvalidBooleanVector(Handle<crate::Expression>),
52    #[error("Relational argument {0:?} is not a float")]
53    InvalidFloatArgument(Handle<crate::Expression>),
54    #[error("Type resolution failed")]
55    Type(#[from] ResolveError),
56    #[error("Not a global variable")]
57    ExpectedGlobalVariable,
58    #[error("Not a global variable or a function argument")]
59    ExpectedGlobalOrArgument,
60    #[error("Needs to be an binding array instead of {0:?}")]
61    ExpectedBindingArrayType(Handle<crate::Type>),
62    #[error("Needs to be an image instead of {0:?}")]
63    ExpectedImageType(Handle<crate::Type>),
64    #[error("Needs to be an image instead of {0:?}")]
65    ExpectedSamplerType(Handle<crate::Type>),
66    #[error("Unable to operate on image class {0:?}")]
67    InvalidImageClass(crate::ImageClass),
68    #[error("Derivatives can only be taken from scalar and vector floats")]
69    InvalidDerivative,
70    #[error("Image array index parameter is misplaced")]
71    InvalidImageArrayIndex,
72    #[error("Inappropriate sample or level-of-detail index for texel access")]
73    InvalidImageOtherIndex,
74    #[error("Image array index type of {0:?} is not an integer scalar")]
75    InvalidImageArrayIndexType(Handle<crate::Expression>),
76    #[error("Image sample or level-of-detail index's type of {0:?} is not an integer scalar")]
77    InvalidImageOtherIndexType(Handle<crate::Expression>),
78    #[error("Image coordinate type of {1:?} does not match dimension {0:?}")]
79    InvalidImageCoordinateType(crate::ImageDimension, Handle<crate::Expression>),
80    #[error("Comparison sampling mismatch: image has class {image:?}, but the sampler is comparison={sampler}, and the reference was provided={has_ref}")]
81    ComparisonSamplingMismatch {
82        image: crate::ImageClass,
83        sampler: bool,
84        has_ref: bool,
85    },
86    #[error("Sample offset must be a const-expression")]
87    InvalidSampleOffsetExprType,
88    #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")]
89    InvalidSampleOffset(crate::ImageDimension, Handle<crate::Expression>),
90    #[error("Depth reference {0:?} is not a scalar float")]
91    InvalidDepthReference(Handle<crate::Expression>),
92    #[error("Depth sample level can only be Auto or Zero")]
93    InvalidDepthSampleLevel,
94    #[error("Gather level can only be Zero")]
95    InvalidGatherLevel,
96    #[error("Gather component {0:?} doesn't exist in the image")]
97    InvalidGatherComponent(crate::SwizzleComponent),
98    #[error("Gather can't be done for image dimension {0:?}")]
99    InvalidGatherDimension(crate::ImageDimension),
100    #[error("Sample level (exact) type {0:?} is not a scalar float")]
101    InvalidSampleLevelExactType(Handle<crate::Expression>),
102    #[error("Sample level (bias) type {0:?} is not a scalar float")]
103    InvalidSampleLevelBiasType(Handle<crate::Expression>),
104    #[error("Sample level (gradient) of {1:?} doesn't match the image dimension {0:?}")]
105    InvalidSampleLevelGradientType(crate::ImageDimension, Handle<crate::Expression>),
106    #[error("Unable to cast")]
107    InvalidCastArgument,
108    #[error("Invalid argument count for {0:?}")]
109    WrongArgumentCount(crate::MathFunction),
110    #[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")]
111    InvalidArgumentType(crate::MathFunction, u32, Handle<crate::Expression>),
112    #[error(
113        "workgroupUniformLoad result type can't be {0:?}. It can only be a constructible type."
114    )]
115    InvalidWorkGroupUniformLoadResultType(Handle<crate::Type>),
116    #[error("Shader requires capability {0:?}")]
117    MissingCapabilities(super::Capabilities),
118    #[error(transparent)]
119    Literal(#[from] LiteralError),
120    #[error("{0:?} is not supported for Width {2} {1:?} arguments yet, see https://github.com/gfx-rs/wgpu/issues/5276")]
121    UnsupportedWidth(crate::MathFunction, crate::ScalarKind, crate::Bytes),
122}
123
124#[derive(Clone, Debug, thiserror::Error)]
125#[cfg_attr(test, derive(PartialEq))]
126pub enum ConstExpressionError {
127    #[error("The expression is not a constant or override expression")]
128    NonConstOrOverride,
129    #[error("The expression is not a fully evaluated constant expression")]
130    NonFullyEvaluatedConst,
131    #[error(transparent)]
132    Compose(#[from] super::ComposeError),
133    #[error("Splatting {0:?} can't be done")]
134    InvalidSplatType(Handle<crate::Expression>),
135    #[error("Type resolution failed")]
136    Type(#[from] ResolveError),
137    #[error(transparent)]
138    Literal(#[from] LiteralError),
139    #[error(transparent)]
140    Width(#[from] super::r#type::WidthError),
141}
142
143#[derive(Clone, Debug, thiserror::Error)]
144#[cfg_attr(test, derive(PartialEq))]
145pub enum LiteralError {
146    #[error("Float literal is NaN")]
147    NaN,
148    #[error("Float literal is infinite")]
149    Infinity,
150    #[error(transparent)]
151    Width(#[from] super::r#type::WidthError),
152}
153
154struct ExpressionTypeResolver<'a> {
155    root: Handle<crate::Expression>,
156    types: &'a UniqueArena<crate::Type>,
157    info: &'a FunctionInfo,
158}
159
160impl<'a> std::ops::Index<Handle<crate::Expression>> for ExpressionTypeResolver<'a> {
161    type Output = crate::TypeInner;
162
163    #[allow(clippy::panic)]
164    fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
165        if handle < self.root {
166            self.info[handle].ty.inner_with(self.types)
167        } else {
168            // `Validator::validate_module_handles` should have caught this.
169            panic!(
170                "Depends on {:?}, which has not been processed yet",
171                self.root
172            )
173        }
174    }
175}
176
177impl super::Validator {
178    pub(super) fn validate_const_expression(
179        &self,
180        handle: Handle<crate::Expression>,
181        gctx: crate::proc::GlobalCtx,
182        mod_info: &ModuleInfo,
183        global_expr_kind: &crate::proc::ExpressionKindTracker,
184    ) -> Result<(), ConstExpressionError> {
185        use crate::Expression as E;
186
187        if !global_expr_kind.is_const_or_override(handle) {
188            return Err(ConstExpressionError::NonConstOrOverride);
189        }
190
191        match gctx.global_expressions[handle] {
192            E::Literal(literal) => {
193                self.validate_literal(literal)?;
194            }
195            E::Constant(_) | E::ZeroValue(_) => {}
196            E::Compose { ref components, ty } => {
197                validate_compose(
198                    ty,
199                    gctx,
200                    components.iter().map(|&handle| mod_info[handle].clone()),
201                )?;
202            }
203            E::Splat { value, .. } => match *mod_info[value].inner_with(gctx.types) {
204                crate::TypeInner::Scalar { .. } => {}
205                _ => return Err(ConstExpressionError::InvalidSplatType(value)),
206            },
207            _ if global_expr_kind.is_const(handle) || !self.allow_overrides => {
208                return Err(ConstExpressionError::NonFullyEvaluatedConst)
209            }
210            // the constant evaluator will report errors about override-expressions
211            _ => {}
212        }
213
214        Ok(())
215    }
216
217    #[allow(clippy::too_many_arguments)]
218    pub(super) fn validate_expression(
219        &self,
220        root: Handle<crate::Expression>,
221        expression: &crate::Expression,
222        function: &crate::Function,
223        module: &crate::Module,
224        info: &FunctionInfo,
225        mod_info: &ModuleInfo,
226        global_expr_kind: &crate::proc::ExpressionKindTracker,
227    ) -> Result<ShaderStages, ExpressionError> {
228        use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti};
229
230        let resolver = ExpressionTypeResolver {
231            root,
232            types: &module.types,
233            info,
234        };
235
236        let stages = match *expression {
237            E::Access { base, index } => {
238                let base_type = &resolver[base];
239                match *base_type {
240                    Ti::Matrix { .. }
241                    | Ti::Vector { .. }
242                    | Ti::Array { .. }
243                    | Ti::Pointer { .. }
244                    | Ti::ValuePointer { size: Some(_), .. }
245                    | Ti::BindingArray { .. } => false,
246                    ref other => {
247                        log::error!("Indexing of {:?}", other);
248                        return Err(ExpressionError::InvalidBaseType(base));
249                    }
250                };
251                match resolver[index] {
252                    //TODO: only allow one of these
253                    Ti::Scalar(Sc {
254                        kind: Sk::Sint | Sk::Uint,
255                        ..
256                    }) => {}
257                    ref other => {
258                        log::error!("Indexing by {:?}", other);
259                        return Err(ExpressionError::InvalidIndexType(index));
260                    }
261                }
262
263                // If we know both the length and the index, we can do the
264                // bounds check now.
265                if let crate::proc::IndexableLength::Known(known_length) =
266                    base_type.indexable_length(module)?
267                {
268                    match module
269                        .to_ctx()
270                        .eval_expr_to_u32_from(index, &function.expressions)
271                    {
272                        Ok(value) => {
273                            if value >= known_length {
274                                return Err(ExpressionError::IndexOutOfBounds(base, value));
275                            }
276                        }
277                        Err(crate::proc::U32EvalError::Negative) => {
278                            return Err(ExpressionError::NegativeIndex(base))
279                        }
280                        Err(crate::proc::U32EvalError::NonConst) => {}
281                    }
282                }
283
284                ShaderStages::all()
285            }
286            E::AccessIndex { base, index } => {
287                fn resolve_index_limit(
288                    module: &crate::Module,
289                    top: Handle<crate::Expression>,
290                    ty: &crate::TypeInner,
291                    top_level: bool,
292                ) -> Result<u32, ExpressionError> {
293                    let limit = match *ty {
294                        Ti::Vector { size, .. }
295                        | Ti::ValuePointer {
296                            size: Some(size), ..
297                        } => size as u32,
298                        Ti::Matrix { columns, .. } => columns as u32,
299                        Ti::Array {
300                            size: crate::ArraySize::Constant(len),
301                            ..
302                        } => len.get(),
303                        Ti::Array { .. } | Ti::BindingArray { .. } => u32::MAX, // can't statically know, but need run-time checks
304                        Ti::Pointer { base, .. } if top_level => {
305                            resolve_index_limit(module, top, &module.types[base].inner, false)?
306                        }
307                        Ti::Struct { ref members, .. } => members.len() as u32,
308                        ref other => {
309                            log::error!("Indexing of {:?}", other);
310                            return Err(ExpressionError::InvalidBaseType(top));
311                        }
312                    };
313                    Ok(limit)
314                }
315
316                let limit = resolve_index_limit(module, base, &resolver[base], true)?;
317                if index >= limit {
318                    return Err(ExpressionError::IndexOutOfBounds(base, limit));
319                }
320                ShaderStages::all()
321            }
322            E::Splat { size: _, value } => match resolver[value] {
323                Ti::Scalar { .. } => ShaderStages::all(),
324                ref other => {
325                    log::error!("Splat scalar type {:?}", other);
326                    return Err(ExpressionError::InvalidSplatType(value));
327                }
328            },
329            E::Swizzle {
330                size,
331                vector,
332                pattern,
333            } => {
334                let vec_size = match resolver[vector] {
335                    Ti::Vector { size: vec_size, .. } => vec_size,
336                    ref other => {
337                        log::error!("Swizzle vector type {:?}", other);
338                        return Err(ExpressionError::InvalidVectorType(vector));
339                    }
340                };
341                for &sc in pattern[..size as usize].iter() {
342                    if sc as u8 >= vec_size as u8 {
343                        return Err(ExpressionError::InvalidSwizzleComponent(sc, vec_size));
344                    }
345                }
346                ShaderStages::all()
347            }
348            E::Literal(literal) => {
349                self.validate_literal(literal)?;
350                ShaderStages::all()
351            }
352            E::Constant(_) | E::Override(_) | E::ZeroValue(_) => ShaderStages::all(),
353            E::Compose { ref components, ty } => {
354                validate_compose(
355                    ty,
356                    module.to_ctx(),
357                    components.iter().map(|&handle| info[handle].ty.clone()),
358                )?;
359                ShaderStages::all()
360            }
361            E::FunctionArgument(index) => {
362                if index >= function.arguments.len() as u32 {
363                    return Err(ExpressionError::FunctionArgumentDoesntExist(index));
364                }
365                ShaderStages::all()
366            }
367            E::GlobalVariable(_handle) => ShaderStages::all(),
368            E::LocalVariable(_handle) => ShaderStages::all(),
369            E::Load { pointer } => {
370                match resolver[pointer] {
371                    Ti::Pointer { base, .. }
372                        if self.types[base.index()]
373                            .flags
374                            .contains(TypeFlags::SIZED | TypeFlags::DATA) => {}
375                    Ti::ValuePointer { .. } => {}
376                    ref other => {
377                        log::error!("Loading {:?}", other);
378                        return Err(ExpressionError::InvalidPointerType(pointer));
379                    }
380                }
381                ShaderStages::all()
382            }
383            E::ImageSample {
384                image,
385                sampler,
386                gather,
387                coordinate,
388                array_index,
389                offset,
390                level,
391                depth_ref,
392            } => {
393                // check the validity of expressions
394                let image_ty = Self::global_var_ty(module, function, image)?;
395                let sampler_ty = Self::global_var_ty(module, function, sampler)?;
396
397                let comparison = match module.types[sampler_ty].inner {
398                    Ti::Sampler { comparison } => comparison,
399                    _ => return Err(ExpressionError::ExpectedSamplerType(sampler_ty)),
400                };
401
402                let (class, dim) = match module.types[image_ty].inner {
403                    Ti::Image {
404                        class,
405                        arrayed,
406                        dim,
407                    } => {
408                        // check the array property
409                        if arrayed != array_index.is_some() {
410                            return Err(ExpressionError::InvalidImageArrayIndex);
411                        }
412                        if let Some(expr) = array_index {
413                            match resolver[expr] {
414                                Ti::Scalar(Sc {
415                                    kind: Sk::Sint | Sk::Uint,
416                                    ..
417                                }) => {}
418                                _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
419                            }
420                        }
421                        (class, dim)
422                    }
423                    _ => return Err(ExpressionError::ExpectedImageType(image_ty)),
424                };
425
426                // check sampling and comparison properties
427                let image_depth = match class {
428                    crate::ImageClass::Sampled {
429                        kind: crate::ScalarKind::Float,
430                        multi: false,
431                    } => false,
432                    crate::ImageClass::Sampled {
433                        kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint,
434                        multi: false,
435                    } if gather.is_some() => false,
436                    crate::ImageClass::Depth { multi: false } => true,
437                    _ => return Err(ExpressionError::InvalidImageClass(class)),
438                };
439                if comparison != depth_ref.is_some() || (comparison && !image_depth) {
440                    return Err(ExpressionError::ComparisonSamplingMismatch {
441                        image: class,
442                        sampler: comparison,
443                        has_ref: depth_ref.is_some(),
444                    });
445                }
446
447                // check texture coordinates type
448                let num_components = match dim {
449                    crate::ImageDimension::D1 => 1,
450                    crate::ImageDimension::D2 => 2,
451                    crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3,
452                };
453                match resolver[coordinate] {
454                    Ti::Scalar(Sc {
455                        kind: Sk::Float, ..
456                    }) if num_components == 1 => {}
457                    Ti::Vector {
458                        size,
459                        scalar:
460                            Sc {
461                                kind: Sk::Float, ..
462                            },
463                    } if size as u32 == num_components => {}
464                    _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)),
465                }
466
467                // check constant offset
468                if let Some(const_expr) = offset {
469                    if !global_expr_kind.is_const(const_expr) {
470                        return Err(ExpressionError::InvalidSampleOffsetExprType);
471                    }
472
473                    match *mod_info[const_expr].inner_with(&module.types) {
474                        Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {}
475                        Ti::Vector {
476                            size,
477                            scalar: Sc { kind: Sk::Sint, .. },
478                        } if size as u32 == num_components => {}
479                        _ => {
480                            return Err(ExpressionError::InvalidSampleOffset(dim, const_expr));
481                        }
482                    }
483                }
484
485                // check depth reference type
486                if let Some(expr) = depth_ref {
487                    match resolver[expr] {
488                        Ti::Scalar(Sc {
489                            kind: Sk::Float, ..
490                        }) => {}
491                        _ => return Err(ExpressionError::InvalidDepthReference(expr)),
492                    }
493                    match level {
494                        crate::SampleLevel::Auto | crate::SampleLevel::Zero => {}
495                        _ => return Err(ExpressionError::InvalidDepthSampleLevel),
496                    }
497                }
498
499                if let Some(component) = gather {
500                    match dim {
501                        crate::ImageDimension::D2 | crate::ImageDimension::Cube => {}
502                        crate::ImageDimension::D1 | crate::ImageDimension::D3 => {
503                            return Err(ExpressionError::InvalidGatherDimension(dim))
504                        }
505                    };
506                    let max_component = match class {
507                        crate::ImageClass::Depth { .. } => crate::SwizzleComponent::X,
508                        _ => crate::SwizzleComponent::W,
509                    };
510                    if component > max_component {
511                        return Err(ExpressionError::InvalidGatherComponent(component));
512                    }
513                    match level {
514                        crate::SampleLevel::Zero => {}
515                        _ => return Err(ExpressionError::InvalidGatherLevel),
516                    }
517                }
518
519                // check level properties
520                match level {
521                    crate::SampleLevel::Auto => ShaderStages::FRAGMENT,
522                    crate::SampleLevel::Zero => ShaderStages::all(),
523                    crate::SampleLevel::Exact(expr) => {
524                        match resolver[expr] {
525                            Ti::Scalar(Sc {
526                                kind: Sk::Float, ..
527                            }) => {}
528                            _ => return Err(ExpressionError::InvalidSampleLevelExactType(expr)),
529                        }
530                        ShaderStages::all()
531                    }
532                    crate::SampleLevel::Bias(expr) => {
533                        match resolver[expr] {
534                            Ti::Scalar(Sc {
535                                kind: Sk::Float, ..
536                            }) => {}
537                            _ => return Err(ExpressionError::InvalidSampleLevelBiasType(expr)),
538                        }
539                        ShaderStages::FRAGMENT
540                    }
541                    crate::SampleLevel::Gradient { x, y } => {
542                        match resolver[x] {
543                            Ti::Scalar(Sc {
544                                kind: Sk::Float, ..
545                            }) if num_components == 1 => {}
546                            Ti::Vector {
547                                size,
548                                scalar:
549                                    Sc {
550                                        kind: Sk::Float, ..
551                                    },
552                            } if size as u32 == num_components => {}
553                            _ => {
554                                return Err(ExpressionError::InvalidSampleLevelGradientType(dim, x))
555                            }
556                        }
557                        match resolver[y] {
558                            Ti::Scalar(Sc {
559                                kind: Sk::Float, ..
560                            }) if num_components == 1 => {}
561                            Ti::Vector {
562                                size,
563                                scalar:
564                                    Sc {
565                                        kind: Sk::Float, ..
566                                    },
567                            } if size as u32 == num_components => {}
568                            _ => {
569                                return Err(ExpressionError::InvalidSampleLevelGradientType(dim, y))
570                            }
571                        }
572                        ShaderStages::all()
573                    }
574                }
575            }
576            E::ImageLoad {
577                image,
578                coordinate,
579                array_index,
580                sample,
581                level,
582            } => {
583                let ty = Self::global_var_ty(module, function, image)?;
584                match module.types[ty].inner {
585                    Ti::Image {
586                        class,
587                        arrayed,
588                        dim,
589                    } => {
590                        match resolver[coordinate].image_storage_coordinates() {
591                            Some(coord_dim) if coord_dim == dim => {}
592                            _ => {
593                                return Err(ExpressionError::InvalidImageCoordinateType(
594                                    dim, coordinate,
595                                ))
596                            }
597                        };
598                        if arrayed != array_index.is_some() {
599                            return Err(ExpressionError::InvalidImageArrayIndex);
600                        }
601                        if let Some(expr) = array_index {
602                            match resolver[expr] {
603                                Ti::Scalar(Sc {
604                                    kind: Sk::Sint | Sk::Uint,
605                                    width: _,
606                                }) => {}
607                                _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
608                            }
609                        }
610
611                        match (sample, class.is_multisampled()) {
612                            (None, false) => {}
613                            (Some(sample), true) => {
614                                if resolver[sample].scalar_kind() != Some(Sk::Sint) {
615                                    return Err(ExpressionError::InvalidImageOtherIndexType(
616                                        sample,
617                                    ));
618                                }
619                            }
620                            _ => {
621                                return Err(ExpressionError::InvalidImageOtherIndex);
622                            }
623                        }
624
625                        match (level, class.is_mipmapped()) {
626                            (None, false) => {}
627                            (Some(level), true) => {
628                                if resolver[level].scalar_kind() != Some(Sk::Sint) {
629                                    return Err(ExpressionError::InvalidImageOtherIndexType(level));
630                                }
631                            }
632                            _ => {
633                                return Err(ExpressionError::InvalidImageOtherIndex);
634                            }
635                        }
636                    }
637                    _ => return Err(ExpressionError::ExpectedImageType(ty)),
638                }
639                ShaderStages::all()
640            }
641            E::ImageQuery { image, query } => {
642                let ty = Self::global_var_ty(module, function, image)?;
643                match module.types[ty].inner {
644                    Ti::Image { class, arrayed, .. } => {
645                        let good = match query {
646                            crate::ImageQuery::NumLayers => arrayed,
647                            crate::ImageQuery::Size { level: None } => true,
648                            crate::ImageQuery::Size { level: Some(_) }
649                            | crate::ImageQuery::NumLevels => class.is_mipmapped(),
650                            crate::ImageQuery::NumSamples => class.is_multisampled(),
651                        };
652                        if !good {
653                            return Err(ExpressionError::InvalidImageClass(class));
654                        }
655                    }
656                    _ => return Err(ExpressionError::ExpectedImageType(ty)),
657                }
658                ShaderStages::all()
659            }
660            E::Unary { op, expr } => {
661                use crate::UnaryOperator as Uo;
662                let inner = &resolver[expr];
663                match (op, inner.scalar_kind()) {
664                    (Uo::Negate, Some(Sk::Float | Sk::Sint))
665                    | (Uo::LogicalNot, Some(Sk::Bool))
666                    | (Uo::BitwiseNot, Some(Sk::Sint | Sk::Uint)) => {}
667                    other => {
668                        log::error!("Op {:?} kind {:?}", op, other);
669                        return Err(ExpressionError::InvalidUnaryOperandType(op, expr));
670                    }
671                }
672                ShaderStages::all()
673            }
            // Binary operators: both operands are resolved to their `TypeInner`s and
            // checked against the per-operator rules below. If the rule fails, both
            // operand expressions and their types are logged before returning
            // `InvalidBinaryOperandTypes`.
            E::Binary { op, left, right } => {
                use crate::BinaryOperator as Bo;
                let left_inner = &resolver[left];
                let right_inner = &resolver[right];
                // `good` is true when the operand types are acceptable for `op`.
                let good = match op {
                    // Add/Subtract: Uint/Sint/Float scalars or vectors, or matrices;
                    // both operands must have exactly the same type.
                    Bo::Add | Bo::Subtract => match *left_inner {
                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
                            Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
                            Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
                        },
                        Ti::Matrix { .. } => left_inner == right_inner,
                        _ => false,
                    },
                    // Divide/Modulo: same as Add/Subtract, except that matrices are
                    // not accepted.
                    Bo::Divide | Bo::Modulo => match *left_inner {
                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
                            Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
                            Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
                        },
                        _ => false,
                    },
                    // Multiply allows mixed shapes. Three conditions must all hold:
                    // the left operand's scalar kind is Uint/Sint/Float, the operand
                    // shapes are compatible (arms below), and the scalar widths agree.
                    Bo::Multiply => {
                        let kind_allowed = match left_inner.scalar_kind() {
                            Some(Sk::Uint | Sk::Sint | Sk::Float) => true,
                            Some(Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat) | None => false,
                        };
                        let types_match = match (left_inner, right_inner) {
                            // Straight scalar and mixed scalar/vector.
                            (&Ti::Scalar(scalar1), &Ti::Scalar(scalar2))
                            | (
                                &Ti::Vector {
                                    scalar: scalar1, ..
                                },
                                &Ti::Scalar(scalar2),
                            )
                            | (
                                &Ti::Scalar(scalar1),
                                &Ti::Vector {
                                    scalar: scalar2, ..
                                },
                            ) => scalar1 == scalar2,
                            // Scalar/matrix.
                            // Only a Float scalar may be combined with a matrix.
                            (
                                &Ti::Scalar(Sc {
                                    kind: Sk::Float, ..
                                }),
                                &Ti::Matrix { .. },
                            )
                            | (
                                &Ti::Matrix { .. },
                                &Ti::Scalar(Sc {
                                    kind: Sk::Float, ..
                                }),
                            ) => true,
                            // Vector/vector.
                            // Component-wise: same scalar and same size.
                            (
                                &Ti::Vector {
                                    size: size1,
                                    scalar: scalar1,
                                },
                                &Ti::Vector {
                                    size: size2,
                                    scalar: scalar2,
                                },
                            ) => scalar1 == scalar2 && size1 == size2,
                            // Matrix * vector.
                            // The vector size must equal the matrix column count.
                            (
                                &Ti::Matrix { columns, .. },
                                &Ti::Vector {
                                    size,
                                    scalar:
                                        Sc {
                                            kind: Sk::Float, ..
                                        },
                                },
                            ) => columns == size,
                            // Vector * matrix.
                            // The vector size must equal the matrix row count.
                            (
                                &Ti::Vector {
                                    size,
                                    scalar:
                                        Sc {
                                            kind: Sk::Float, ..
                                        },
                                },
                                &Ti::Matrix { rows, .. },
                            ) => size == rows,
                            // Matrix * matrix: left column count must equal right row count.
                            (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => {
                                columns == rows
                            }
                            _ => false,
                        };
                        // Types with no scalar width are treated as width 0 here.
                        let left_width = left_inner.scalar_width().unwrap_or(0);
                        let right_width = right_inner.scalar_width().unwrap_or(0);
                        kind_allowed && types_match && left_width == right_width
                    }
                    // Equality: identical operand types, and the type must be sized.
                    Bo::Equal | Bo::NotEqual => left_inner.is_sized() && left_inner == right_inner,
                    // Ordering comparisons: Uint/Sint/Float scalars or vectors with
                    // identical operand types.
                    Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => {
                        match *left_inner {
                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
                                Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
                                Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
                            },
                            ref other => {
                                log::error!("Op {:?} left type {:?}", op, other);
                                false
                            }
                        }
                    }
                    // Logical and/or: Bool scalars or vectors with identical types.
                    Bo::LogicalAnd | Bo::LogicalOr => match *left_inner {
                        Ti::Scalar(Sc { kind: Sk::Bool, .. })
                        | Ti::Vector {
                            scalar: Sc { kind: Sk::Bool, .. },
                            ..
                        } => left_inner == right_inner,
                        ref other => {
                            log::error!("Op {:?} left type {:?}", op, other);
                            false
                        }
                    },
                    // Bitwise and/or: Bool, Sint or Uint scalars or vectors with
                    // identical types.
                    Bo::And | Bo::InclusiveOr => match *left_inner {
                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
                            Sk::Bool | Sk::Sint | Sk::Uint => left_inner == right_inner,
                            Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
                        },
                        ref other => {
                            log::error!("Op {:?} left type {:?}", op, other);
                            false
                        }
                    },
                    // Exclusive or: Sint or Uint scalars or vectors only (no Bool),
                    // with identical types.
                    Bo::ExclusiveOr => match *left_inner {
                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
                            Sk::Sint | Sk::Uint => left_inner == right_inner,
                            Sk::Bool | Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
                        },
                        ref other => {
                            log::error!("Op {:?} left type {:?}", op, other);
                            false
                        }
                    },
                    // Shifts: the base must be a Sint/Uint scalar or vector, and the
                    // shift amount a Uint scalar or vector of the same shape (`None`
                    // for scalars, `Some(size)` for vectors). Scalar widths of the
                    // two operands are not compared here.
                    Bo::ShiftLeft | Bo::ShiftRight => {
                        let (base_size, base_scalar) = match *left_inner {
                            Ti::Scalar(scalar) => (Ok(None), scalar),
                            Ti::Vector { size, scalar } => (Ok(Some(size)), scalar),
                            ref other => {
                                log::error!("Op {:?} base type {:?}", op, other);
                                // `Sc::BOOL` is only a placeholder. This path is rejected
                                // below regardless, because `base_size.is_ok()` is false.
                                (Err(()), Sc::BOOL)
                            }
                        };
                        let shift_size = match *right_inner {
                            Ti::Scalar(Sc { kind: Sk::Uint, .. }) => Ok(None),
                            Ti::Vector {
                                size,
                                scalar: Sc { kind: Sk::Uint, .. },
                            } => Ok(Some(size)),
                            ref other => {
                                log::error!("Op {:?} shift type {:?}", op, other);
                                Err(())
                            }
                        };
                        match base_scalar.kind {
                            Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size,
                            Sk::Float | Sk::AbstractInt | Sk::AbstractFloat | Sk::Bool => false,
                        }
                    }
                };
                if !good {
                    // Log both operands (expression and resolved type) to aid debugging.
                    log::error!(
                        "Left: {:?} of type {:?}",
                        function.expressions[left],
                        left_inner
                    );
                    log::error!(
                        "Right: {:?} of type {:?}",
                        function.expressions[right],
                        right_inner
                    );
                    return Err(ExpressionError::InvalidBinaryOperandTypes(op, left, right));
                }
                ShaderStages::all()
            }
854            E::Select {
855                condition,
856                accept,
857                reject,
858            } => {
859                let accept_inner = &resolver[accept];
860                let reject_inner = &resolver[reject];
861                let condition_good = match resolver[condition] {
862                    Ti::Scalar(Sc {
863                        kind: Sk::Bool,
864                        width: _,
865                    }) => {
866                        // When `condition` is a single boolean, `accept` and
867                        // `reject` can be vectors or scalars.
868                        match *accept_inner {
869                            Ti::Scalar { .. } | Ti::Vector { .. } => true,
870                            _ => false,
871                        }
872                    }
873                    Ti::Vector {
874                        size,
875                        scalar:
876                            Sc {
877                                kind: Sk::Bool,
878                                width: _,
879                            },
880                    } => match *accept_inner {
881                        Ti::Vector {
882                            size: other_size, ..
883                        } => size == other_size,
884                        _ => false,
885                    },
886                    _ => false,
887                };
888                if !condition_good || accept_inner != reject_inner {
889                    return Err(ExpressionError::InvalidSelectTypes);
890                }
891                ShaderStages::all()
892            }
893            E::Derivative { expr, .. } => {
894                match resolver[expr] {
895                    Ti::Scalar(Sc {
896                        kind: Sk::Float, ..
897                    })
898                    | Ti::Vector {
899                        scalar:
900                            Sc {
901                                kind: Sk::Float, ..
902                            },
903                        ..
904                    } => {}
905                    _ => return Err(ExpressionError::InvalidDerivative),
906                }
907                ShaderStages::FRAGMENT
908            }
909            E::Relational { fun, argument } => {
910                use crate::RelationalFunction as Rf;
911                let argument_inner = &resolver[argument];
912                match fun {
913                    Rf::All | Rf::Any => match *argument_inner {
914                        Ti::Vector {
915                            scalar: Sc { kind: Sk::Bool, .. },
916                            ..
917                        } => {}
918                        ref other => {
919                            log::error!("All/Any of type {:?}", other);
920                            return Err(ExpressionError::InvalidBooleanVector(argument));
921                        }
922                    },
923                    Rf::IsNan | Rf::IsInf => match *argument_inner {
924                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
925                            if scalar.kind == Sk::Float => {}
926                        ref other => {
927                            log::error!("Float test of type {:?}", other);
928                            return Err(ExpressionError::InvalidFloatArgument(argument));
929                        }
930                    },
931                }
932                ShaderStages::all()
933            }
934            E::Math {
935                fun,
936                arg,
937                arg1,
938                arg2,
939                arg3,
940            } => {
941                use crate::MathFunction as Mf;
942
943                let resolve = |arg| &resolver[arg];
944                let arg_ty = resolve(arg);
945                let arg1_ty = arg1.map(resolve);
946                let arg2_ty = arg2.map(resolve);
947                let arg3_ty = arg3.map(resolve);
948                match fun {
949                    Mf::Abs => {
950                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
951                            return Err(ExpressionError::WrongArgumentCount(fun));
952                        }
953                        let good = match *arg_ty {
954                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
955                                scalar.kind != Sk::Bool
956                            }
957                            _ => false,
958                        };
959                        if !good {
960                            return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
961                        }
962                    }
963                    Mf::Min | Mf::Max => {
964                        let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
965                            (Some(ty1), None, None) => ty1,
966                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
967                        };
968                        let good = match *arg_ty {
969                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
970                                scalar.kind != Sk::Bool
971                            }
972                            _ => false,
973                        };
974                        if !good {
975                            return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
976                        }
977                        if arg1_ty != arg_ty {
978                            return Err(ExpressionError::InvalidArgumentType(
979                                fun,
980                                1,
981                                arg1.unwrap(),
982                            ));
983                        }
984                    }
985                    Mf::Clamp => {
986                        let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
987                            (Some(ty1), Some(ty2), None) => (ty1, ty2),
988                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
989                        };
990                        let good = match *arg_ty {
991                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
992                                scalar.kind != Sk::Bool
993                            }
994                            _ => false,
995                        };
996                        if !good {
997                            return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
998                        }
999                        if arg1_ty != arg_ty {
1000                            return Err(ExpressionError::InvalidArgumentType(
1001                                fun,
1002                                1,
1003                                arg1.unwrap(),
1004                            ));
1005                        }
1006                        if arg2_ty != arg_ty {
1007                            return Err(ExpressionError::InvalidArgumentType(
1008                                fun,
1009                                2,
1010                                arg2.unwrap(),
1011                            ));
1012                        }
1013                    }
1014                    Mf::Saturate
1015                    | Mf::Cos
1016                    | Mf::Cosh
1017                    | Mf::Sin
1018                    | Mf::Sinh
1019                    | Mf::Tan
1020                    | Mf::Tanh
1021                    | Mf::Acos
1022                    | Mf::Asin
1023                    | Mf::Atan
1024                    | Mf::Asinh
1025                    | Mf::Acosh
1026                    | Mf::Atanh
1027                    | Mf::Radians
1028                    | Mf::Degrees
1029                    | Mf::Ceil
1030                    | Mf::Floor
1031                    | Mf::Round
1032                    | Mf::Fract
1033                    | Mf::Trunc
1034                    | Mf::Exp
1035                    | Mf::Exp2
1036                    | Mf::Log
1037                    | Mf::Log2
1038                    | Mf::Length
1039                    | Mf::Sqrt
1040                    | Mf::InverseSqrt => {
1041                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1042                            return Err(ExpressionError::WrongArgumentCount(fun));
1043                        }
1044                        match *arg_ty {
1045                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1046                                if scalar.kind == Sk::Float => {}
1047                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1048                        }
1049                    }
1050                    Mf::Sign => {
1051                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1052                            return Err(ExpressionError::WrongArgumentCount(fun));
1053                        }
1054                        match *arg_ty {
1055                            Ti::Scalar(Sc {
1056                                kind: Sk::Float | Sk::Sint,
1057                                ..
1058                            })
1059                            | Ti::Vector {
1060                                scalar:
1061                                    Sc {
1062                                        kind: Sk::Float | Sk::Sint,
1063                                        ..
1064                                    },
1065                                ..
1066                            } => {}
1067                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1068                        }
1069                    }
1070                    Mf::Atan2 | Mf::Pow | Mf::Distance | Mf::Step => {
1071                        let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1072                            (Some(ty1), None, None) => ty1,
1073                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1074                        };
1075                        match *arg_ty {
1076                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1077                                if scalar.kind == Sk::Float => {}
1078                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1079                        }
1080                        if arg1_ty != arg_ty {
1081                            return Err(ExpressionError::InvalidArgumentType(
1082                                fun,
1083                                1,
1084                                arg1.unwrap(),
1085                            ));
1086                        }
1087                    }
1088                    Mf::Modf | Mf::Frexp => {
1089                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1090                            return Err(ExpressionError::WrongArgumentCount(fun));
1091                        }
1092                        if !matches!(*arg_ty,
1093                                     Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1094                                     if scalar.kind == Sk::Float)
1095                        {
1096                            return Err(ExpressionError::InvalidArgumentType(fun, 1, arg));
1097                        }
1098                    }
1099                    Mf::Ldexp => {
1100                        let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1101                            (Some(ty1), None, None) => ty1,
1102                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1103                        };
1104                        let size0 = match *arg_ty {
1105                            Ti::Scalar(Sc {
1106                                kind: Sk::Float, ..
1107                            }) => None,
1108                            Ti::Vector {
1109                                scalar:
1110                                    Sc {
1111                                        kind: Sk::Float, ..
1112                                    },
1113                                size,
1114                            } => Some(size),
1115                            _ => {
1116                                return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
1117                            }
1118                        };
1119                        let good = match *arg1_ty {
1120                            Ti::Scalar(Sc { kind: Sk::Sint, .. }) if size0.is_none() => true,
1121                            Ti::Vector {
1122                                size,
1123                                scalar: Sc { kind: Sk::Sint, .. },
1124                            } if Some(size) == size0 => true,
1125                            _ => false,
1126                        };
1127                        if !good {
1128                            return Err(ExpressionError::InvalidArgumentType(
1129                                fun,
1130                                1,
1131                                arg1.unwrap(),
1132                            ));
1133                        }
1134                    }
1135                    Mf::Dot => {
1136                        let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1137                            (Some(ty1), None, None) => ty1,
1138                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1139                        };
1140                        match *arg_ty {
1141                            Ti::Vector {
1142                                scalar:
1143                                    Sc {
1144                                        kind: Sk::Float | Sk::Sint | Sk::Uint,
1145                                        ..
1146                                    },
1147                                ..
1148                            } => {}
1149                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1150                        }
1151                        if arg1_ty != arg_ty {
1152                            return Err(ExpressionError::InvalidArgumentType(
1153                                fun,
1154                                1,
1155                                arg1.unwrap(),
1156                            ));
1157                        }
1158                    }
1159                    Mf::Outer | Mf::Reflect => {
1160                        let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1161                            (Some(ty1), None, None) => ty1,
1162                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1163                        };
1164                        match *arg_ty {
1165                            Ti::Vector {
1166                                scalar:
1167                                    Sc {
1168                                        kind: Sk::Float, ..
1169                                    },
1170                                ..
1171                            } => {}
1172                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1173                        }
1174                        if arg1_ty != arg_ty {
1175                            return Err(ExpressionError::InvalidArgumentType(
1176                                fun,
1177                                1,
1178                                arg1.unwrap(),
1179                            ));
1180                        }
1181                    }
1182                    Mf::Cross => {
1183                        let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1184                            (Some(ty1), None, None) => ty1,
1185                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1186                        };
1187                        match *arg_ty {
1188                            Ti::Vector {
1189                                scalar:
1190                                    Sc {
1191                                        kind: Sk::Float, ..
1192                                    },
1193                                size: crate::VectorSize::Tri,
1194                            } => {}
1195                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1196                        }
1197                        if arg1_ty != arg_ty {
1198                            return Err(ExpressionError::InvalidArgumentType(
1199                                fun,
1200                                1,
1201                                arg1.unwrap(),
1202                            ));
1203                        }
1204                    }
1205                    Mf::Refract => {
1206                        let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1207                            (Some(ty1), Some(ty2), None) => (ty1, ty2),
1208                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1209                        };
1210
1211                        match *arg_ty {
1212                            Ti::Vector {
1213                                scalar:
1214                                    Sc {
1215                                        kind: Sk::Float, ..
1216                                    },
1217                                ..
1218                            } => {}
1219                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1220                        }
1221
1222                        if arg1_ty != arg_ty {
1223                            return Err(ExpressionError::InvalidArgumentType(
1224                                fun,
1225                                1,
1226                                arg1.unwrap(),
1227                            ));
1228                        }
1229
1230                        match (arg_ty, arg2_ty) {
1231                            (
1232                                &Ti::Vector {
1233                                    scalar:
1234                                        Sc {
1235                                            width: vector_width,
1236                                            ..
1237                                        },
1238                                    ..
1239                                },
1240                                &Ti::Scalar(Sc {
1241                                    width: scalar_width,
1242                                    kind: Sk::Float,
1243                                }),
1244                            ) if vector_width == scalar_width => {}
1245                            _ => {
1246                                return Err(ExpressionError::InvalidArgumentType(
1247                                    fun,
1248                                    2,
1249                                    arg2.unwrap(),
1250                                ))
1251                            }
1252                        }
1253                    }
1254                    Mf::Normalize => {
1255                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1256                            return Err(ExpressionError::WrongArgumentCount(fun));
1257                        }
1258                        match *arg_ty {
1259                            Ti::Vector {
1260                                scalar:
1261                                    Sc {
1262                                        kind: Sk::Float, ..
1263                                    },
1264                                ..
1265                            } => {}
1266                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1267                        }
1268                    }
1269                    Mf::FaceForward | Mf::Fma | Mf::SmoothStep => {
1270                        let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1271                            (Some(ty1), Some(ty2), None) => (ty1, ty2),
1272                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1273                        };
1274                        match *arg_ty {
1275                            Ti::Scalar(Sc {
1276                                kind: Sk::Float, ..
1277                            })
1278                            | Ti::Vector {
1279                                scalar:
1280                                    Sc {
1281                                        kind: Sk::Float, ..
1282                                    },
1283                                ..
1284                            } => {}
1285                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1286                        }
1287                        if arg1_ty != arg_ty {
1288                            return Err(ExpressionError::InvalidArgumentType(
1289                                fun,
1290                                1,
1291                                arg1.unwrap(),
1292                            ));
1293                        }
1294                        if arg2_ty != arg_ty {
1295                            return Err(ExpressionError::InvalidArgumentType(
1296                                fun,
1297                                2,
1298                                arg2.unwrap(),
1299                            ));
1300                        }
1301                    }
1302                    Mf::Mix => {
1303                        let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1304                            (Some(ty1), Some(ty2), None) => (ty1, ty2),
1305                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1306                        };
1307                        let arg_width = match *arg_ty {
1308                            Ti::Scalar(Sc {
1309                                kind: Sk::Float,
1310                                width,
1311                            })
1312                            | Ti::Vector {
1313                                scalar:
1314                                    Sc {
1315                                        kind: Sk::Float,
1316                                        width,
1317                                    },
1318                                ..
1319                            } => width,
1320                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1321                        };
1322                        if arg1_ty != arg_ty {
1323                            return Err(ExpressionError::InvalidArgumentType(
1324                                fun,
1325                                1,
1326                                arg1.unwrap(),
1327                            ));
1328                        }
1329                        // the last argument can always be a scalar
1330                        match *arg2_ty {
1331                            Ti::Scalar(Sc {
1332                                kind: Sk::Float,
1333                                width,
1334                            }) if width == arg_width => {}
1335                            _ if arg2_ty == arg_ty => {}
1336                            _ => {
1337                                return Err(ExpressionError::InvalidArgumentType(
1338                                    fun,
1339                                    2,
1340                                    arg2.unwrap(),
1341                                ));
1342                            }
1343                        }
1344                    }
1345                    Mf::Inverse | Mf::Determinant => {
1346                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1347                            return Err(ExpressionError::WrongArgumentCount(fun));
1348                        }
1349                        let good = match *arg_ty {
1350                            Ti::Matrix { columns, rows, .. } => columns == rows,
1351                            _ => false,
1352                        };
1353                        if !good {
1354                            return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
1355                        }
1356                    }
1357                    Mf::Transpose => {
1358                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1359                            return Err(ExpressionError::WrongArgumentCount(fun));
1360                        }
1361                        match *arg_ty {
1362                            Ti::Matrix { .. } => {}
1363                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1364                        }
1365                    }
1366                    // Remove once fixed https://github.com/gfx-rs/wgpu/issues/5276
1367                    Mf::CountLeadingZeros
1368                    | Mf::CountTrailingZeros
1369                    | Mf::CountOneBits
1370                    | Mf::ReverseBits
1371                    | Mf::FirstLeadingBit
1372                    | Mf::FirstTrailingBit => {
1373                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1374                            return Err(ExpressionError::WrongArgumentCount(fun));
1375                        }
1376                        match *arg_ty {
1377                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
1378                                Sk::Sint | Sk::Uint => {
1379                                    if scalar.width != 4 {
1380                                        return Err(ExpressionError::UnsupportedWidth(
1381                                            fun,
1382                                            scalar.kind,
1383                                            scalar.width,
1384                                        ));
1385                                    }
1386                                }
1387                                _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1388                            },
1389                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1390                        }
1391                    }
1392                    Mf::InsertBits => {
1393                        let (arg1_ty, arg2_ty, arg3_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1394                            (Some(ty1), Some(ty2), Some(ty3)) => (ty1, ty2, ty3),
1395                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1396                        };
1397                        match *arg_ty {
1398                            Ti::Scalar(Sc {
1399                                kind: Sk::Sint | Sk::Uint,
1400                                ..
1401                            })
1402                            | Ti::Vector {
1403                                scalar:
1404                                    Sc {
1405                                        kind: Sk::Sint | Sk::Uint,
1406                                        ..
1407                                    },
1408                                ..
1409                            } => {}
1410                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1411                        }
1412                        if arg1_ty != arg_ty {
1413                            return Err(ExpressionError::InvalidArgumentType(
1414                                fun,
1415                                1,
1416                                arg1.unwrap(),
1417                            ));
1418                        }
1419                        match *arg2_ty {
1420                            Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1421                            _ => {
1422                                return Err(ExpressionError::InvalidArgumentType(
1423                                    fun,
1424                                    2,
1425                                    arg2.unwrap(),
1426                                ))
1427                            }
1428                        }
1429                        match *arg3_ty {
1430                            Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1431                            _ => {
1432                                return Err(ExpressionError::InvalidArgumentType(
1433                                    fun,
1434                                    2,
1435                                    arg3.unwrap(),
1436                                ))
1437                            }
1438                        }
1439                        // Remove once fixed https://github.com/gfx-rs/wgpu/issues/5276
1440                        for &arg in [arg_ty, arg1_ty, arg2_ty, arg3_ty].iter() {
1441                            match *arg {
1442                                Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
1443                                    if scalar.width != 4 {
1444                                        return Err(ExpressionError::UnsupportedWidth(
1445                                            fun,
1446                                            scalar.kind,
1447                                            scalar.width,
1448                                        ));
1449                                    }
1450                                }
1451                                _ => {}
1452                            }
1453                        }
1454                    }
1455                    Mf::ExtractBits => {
1456                        let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1457                            (Some(ty1), Some(ty2), None) => (ty1, ty2),
1458                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1459                        };
1460                        match *arg_ty {
1461                            Ti::Scalar(Sc {
1462                                kind: Sk::Sint | Sk::Uint,
1463                                ..
1464                            })
1465                            | Ti::Vector {
1466                                scalar:
1467                                    Sc {
1468                                        kind: Sk::Sint | Sk::Uint,
1469                                        ..
1470                                    },
1471                                ..
1472                            } => {}
1473                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1474                        }
1475                        match *arg1_ty {
1476                            Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1477                            _ => {
1478                                return Err(ExpressionError::InvalidArgumentType(
1479                                    fun,
1480                                    2,
1481                                    arg1.unwrap(),
1482                                ))
1483                            }
1484                        }
1485                        match *arg2_ty {
1486                            Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1487                            _ => {
1488                                return Err(ExpressionError::InvalidArgumentType(
1489                                    fun,
1490                                    2,
1491                                    arg2.unwrap(),
1492                                ))
1493                            }
1494                        }
1495                        // Remove once fixed https://github.com/gfx-rs/wgpu/issues/5276
1496                        for &arg in [arg_ty, arg1_ty, arg2_ty].iter() {
1497                            match *arg {
1498                                Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
1499                                    if scalar.width != 4 {
1500                                        return Err(ExpressionError::UnsupportedWidth(
1501                                            fun,
1502                                            scalar.kind,
1503                                            scalar.width,
1504                                        ));
1505                                    }
1506                                }
1507                                _ => {}
1508                            }
1509                        }
1510                    }
1511                    Mf::Pack2x16unorm | Mf::Pack2x16snorm | Mf::Pack2x16float => {
1512                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1513                            return Err(ExpressionError::WrongArgumentCount(fun));
1514                        }
1515                        match *arg_ty {
1516                            Ti::Vector {
1517                                size: crate::VectorSize::Bi,
1518                                scalar:
1519                                    Sc {
1520                                        kind: Sk::Float, ..
1521                                    },
1522                            } => {}
1523                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1524                        }
1525                    }
1526                    Mf::Pack4x8snorm | Mf::Pack4x8unorm => {
1527                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1528                            return Err(ExpressionError::WrongArgumentCount(fun));
1529                        }
1530                        match *arg_ty {
1531                            Ti::Vector {
1532                                size: crate::VectorSize::Quad,
1533                                scalar:
1534                                    Sc {
1535                                        kind: Sk::Float, ..
1536                                    },
1537                            } => {}
1538                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1539                        }
1540                    }
1541                    mf @ (Mf::Pack4xI8 | Mf::Pack4xU8) => {
1542                        let scalar_kind = match mf {
1543                            Mf::Pack4xI8 => Sk::Sint,
1544                            Mf::Pack4xU8 => Sk::Uint,
1545                            _ => unreachable!(),
1546                        };
1547                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1548                            return Err(ExpressionError::WrongArgumentCount(fun));
1549                        }
1550                        match *arg_ty {
1551                            Ti::Vector {
1552                                size: crate::VectorSize::Quad,
1553                                scalar: Sc { kind, .. },
1554                            } if kind == scalar_kind => {}
1555                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1556                        }
1557                    }
1558                    Mf::Unpack2x16float
1559                    | Mf::Unpack2x16snorm
1560                    | Mf::Unpack2x16unorm
1561                    | Mf::Unpack4x8snorm
1562                    | Mf::Unpack4x8unorm
1563                    | Mf::Unpack4xI8
1564                    | Mf::Unpack4xU8 => {
1565                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1566                            return Err(ExpressionError::WrongArgumentCount(fun));
1567                        }
1568                        match *arg_ty {
1569                            Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1570                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1571                        }
1572                    }
1573                }
1574                ShaderStages::all()
1575            }
1576            E::As {
1577                expr,
1578                kind,
1579                convert,
1580            } => {
1581                let mut base_scalar = match resolver[expr] {
1582                    crate::TypeInner::Scalar(scalar) | crate::TypeInner::Vector { scalar, .. } => {
1583                        scalar
1584                    }
1585                    crate::TypeInner::Matrix { scalar, .. } => scalar,
1586                    _ => return Err(ExpressionError::InvalidCastArgument),
1587                };
1588                base_scalar.kind = kind;
1589                if let Some(width) = convert {
1590                    base_scalar.width = width;
1591                }
1592                if self.check_width(base_scalar).is_err() {
1593                    return Err(ExpressionError::InvalidCastArgument);
1594                }
1595                ShaderStages::all()
1596            }
1597            E::CallResult(function) => mod_info.functions[function.index()].available_stages,
1598            E::AtomicResult { .. } => {
1599                // These expressions are validated when we check the `Atomic` statement
1600                // that refers to them, because we have all the information we need at
1601                // that point. The checks driven by `Validator::needs_visit` ensure
1602                // that this expression is indeed visited by one `Atomic` statement.
1603                ShaderStages::all()
1604            }
1605            E::WorkGroupUniformLoadResult { ty } => {
1606                if self.types[ty.index()]
1607                    .flags
1608                    // Sized | Constructible is exactly the types currently supported by
1609                    // WorkGroupUniformLoad
1610                    .contains(TypeFlags::SIZED | TypeFlags::CONSTRUCTIBLE)
1611                {
1612                    ShaderStages::COMPUTE
1613                } else {
1614                    return Err(ExpressionError::InvalidWorkGroupUniformLoadResultType(ty));
1615                }
1616            }
1617            E::ArrayLength(expr) => match resolver[expr] {
1618                Ti::Pointer { base, .. } => {
1619                    let base_ty = &resolver.types[base];
1620                    if let Ti::Array {
1621                        size: crate::ArraySize::Dynamic,
1622                        ..
1623                    } = base_ty.inner
1624                    {
1625                        ShaderStages::all()
1626                    } else {
1627                        return Err(ExpressionError::InvalidArrayType(expr));
1628                    }
1629                }
1630                ref other => {
1631                    log::error!("Array length of {:?}", other);
1632                    return Err(ExpressionError::InvalidArrayType(expr));
1633                }
1634            },
1635            E::RayQueryProceedResult => ShaderStages::all(),
1636            E::RayQueryGetIntersection {
1637                query,
1638                committed: _,
1639            } => match resolver[query] {
1640                Ti::Pointer {
1641                    base,
1642                    space: crate::AddressSpace::Function,
1643                } => match resolver.types[base].inner {
1644                    Ti::RayQuery => ShaderStages::all(),
1645                    ref other => {
1646                        log::error!("Intersection result of a pointer to {:?}", other);
1647                        return Err(ExpressionError::InvalidRayQueryType(query));
1648                    }
1649                },
1650                ref other => {
1651                    log::error!("Intersection result of {:?}", other);
1652                    return Err(ExpressionError::InvalidRayQueryType(query));
1653                }
1654            },
1655            E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages,
1656        };
1657        Ok(stages)
1658    }
1659
1660    fn global_var_ty(
1661        module: &crate::Module,
1662        function: &crate::Function,
1663        expr: Handle<crate::Expression>,
1664    ) -> Result<Handle<crate::Type>, ExpressionError> {
1665        use crate::Expression as Ex;
1666
1667        match function.expressions[expr] {
1668            Ex::GlobalVariable(var_handle) => Ok(module.global_variables[var_handle].ty),
1669            Ex::FunctionArgument(i) => Ok(function.arguments[i as usize].ty),
1670            Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
1671                match function.expressions[base] {
1672                    Ex::GlobalVariable(var_handle) => {
1673                        let array_ty = module.global_variables[var_handle].ty;
1674
1675                        match module.types[array_ty].inner {
1676                            crate::TypeInner::BindingArray { base, .. } => Ok(base),
1677                            _ => Err(ExpressionError::ExpectedBindingArrayType(array_ty)),
1678                        }
1679                    }
1680                    _ => Err(ExpressionError::ExpectedGlobalVariable),
1681                }
1682            }
1683            _ => Err(ExpressionError::ExpectedGlobalVariable),
1684        }
1685    }
1686
1687    pub fn validate_literal(&self, literal: crate::Literal) -> Result<(), LiteralError> {
1688        self.check_width(literal.scalar())?;
1689        check_literal_value(literal)?;
1690
1691        Ok(())
1692    }
1693}
1694
1695pub fn check_literal_value(literal: crate::Literal) -> Result<(), LiteralError> {
1696    let is_nan = match literal {
1697        crate::Literal::F64(v) => v.is_nan(),
1698        crate::Literal::F32(v) => v.is_nan(),
1699        _ => false,
1700    };
1701    if is_nan {
1702        return Err(LiteralError::NaN);
1703    }
1704
1705    let is_infinite = match literal {
1706        crate::Literal::F64(v) => v.is_infinite(),
1707        crate::Literal::F32(v) => v.is_infinite(),
1708        _ => false,
1709    };
1710    if is_infinite {
1711        return Err(LiteralError::Infinity);
1712    }
1713
1714    Ok(())
1715}
1716
/// Validate a module whose single function evaluates `expr`.
///
/// `expr` is appended to a fresh function's expression arena and covered by
/// one `Emit` statement; that function is added to a default module, which is
/// validated with only `super::ValidationFlags::EXPRESSIONS` enabled and the
/// given capabilities.
///
/// The validation result is returned unchanged, so callers can assert either
/// success or a specific error.
#[cfg(test)]
fn validate_with_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    use crate::span::Span;

    let mut function = crate::Function::default();
    function.expressions.append(expr, Span::default());
    // Emit everything in the arena (just `expr`) so the validator visits it.
    function.body.push(
        crate::Statement::Emit(function.expressions.range_from(0)),
        Span::default(),
    );

    let mut module = crate::Module::default();
    module.functions.append(function, Span::default());

    let mut validator = super::Validator::new(super::ValidationFlags::EXPRESSIONS, caps);

    validator.validate(&module)
}
1739
/// Validate a default module with `expr` appended to its global (constant)
/// expression arena.
///
/// The module is validated with only `super::ValidationFlags::CONSTANTS`
/// enabled and the given capabilities. The validation result is returned
/// unchanged, so callers can assert either success or a specific error.
#[cfg(test)]
fn validate_with_const_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    use crate::span::Span;

    let mut module = crate::Module::default();
    module.global_expressions.append(expr, Span::default());

    let mut validator = super::Validator::new(super::ValidationFlags::CONSTANTS, caps);

    validator.validate(&module)
}
1755
/// Using F64 in a function's expression arena is forbidden unless the
/// validator has the `FLOAT64` capability.
#[test]
fn f64_runtime_literals() {
    // Without FLOAT64, the literal's width must be rejected.
    let result = validate_with_expression(
        crate::Expression::Literal(crate::Literal::F64(0.57721_56649)),
        super::Capabilities::default(),
    );
    let error = result.unwrap_err().into_inner();
    assert!(
        matches!(
            error,
            crate::valid::ValidationError::Function {
                source: super::FunctionError::Expression {
                    source: ExpressionError::Literal(LiteralError::Width(
                        super::r#type::WidthError::MissingCapability {
                            name: "f64",
                            flag: "FLOAT64",
                        }
                    )),
                    ..
                },
                ..
            }
        ),
        "expected a missing FLOAT64 capability error, got {:?}",
        error
    );

    // With FLOAT64, the same module must validate; report the error otherwise.
    let result = validate_with_expression(
        crate::Expression::Literal(crate::Literal::F64(0.57721_56649)),
        super::Capabilities::default() | super::Capabilities::FLOAT64,
    );
    if let Err(err) = result {
        panic!(
            "f64 literal should validate with FLOAT64 enabled, got {:?}",
            err.into_inner()
        );
    }
}
1786
/// Using F64 in a module's constant expression arena is forbidden unless the
/// validator has the `FLOAT64` capability.
#[test]
fn f64_const_literals() {
    // Without FLOAT64, the literal's width must be rejected.
    let result = validate_with_const_expression(
        crate::Expression::Literal(crate::Literal::F64(0.57721_56649)),
        super::Capabilities::default(),
    );
    let error = result.unwrap_err().into_inner();
    assert!(
        matches!(
            error,
            crate::valid::ValidationError::ConstExpression {
                source: ConstExpressionError::Literal(LiteralError::Width(
                    super::r#type::WidthError::MissingCapability {
                        name: "f64",
                        flag: "FLOAT64",
                    }
                )),
                ..
            }
        ),
        "expected a missing FLOAT64 capability error, got {:?}",
        error
    );

    // With FLOAT64, the same module must validate; report the error otherwise.
    let result = validate_with_const_expression(
        crate::Expression::Literal(crate::Literal::F64(0.57721_56649)),
        super::Capabilities::default() | super::Capabilities::FLOAT64,
    );
    if let Err(err) = result {
        panic!(
            "f64 constant literal should validate with FLOAT64 enabled, got {:?}",
            err.into_inner()
        );
    }
}