naga/proc/
typifier.rs

1use crate::arena::{Arena, Handle, UniqueArena};
2
3use thiserror::Error;
4
5/// The result of computing an expression's type.
6///
7/// This is the (Rust) type returned by [`ResolveContext::resolve`] to represent
8/// the (Naga) type it ascribes to some expression.
9///
10/// You might expect such a function to simply return a `Handle<Type>`. However,
11/// we want type resolution to be a read-only process, and that would limit the
12/// possible results to types already present in the expression's associated
13/// `UniqueArena<Type>`. Naga IR does have certain expressions whose types are
14/// not certain to be present.
15///
16/// So instead, type resolution returns a `TypeResolution` enum: either a
17/// [`Handle`], referencing some type in the arena, or a [`Value`], holding a
18/// free-floating [`TypeInner`]. This extends the range to cover anything that
19/// can be represented with a `TypeInner` referring to the existing arena.
20///
21/// What sorts of expressions can have types not available in the arena?
22///
23/// -   An [`Access`] or [`AccessIndex`] expression applied to a [`Vector`] or
24///     [`Matrix`] must have a [`Scalar`] or [`Vector`] type. But since `Vector`
25///     and `Matrix` represent their element and column types implicitly, not
26///     via a handle, there may not be a suitable type in the expression's
27///     associated arena. Instead, resolving such an expression returns a
28///     `TypeResolution::Value(TypeInner::X { ... })`, where `X` is `Scalar` or
29///     `Vector`.
30///
31/// -   Similarly, the type of an [`Access`] or [`AccessIndex`] expression
32///     applied to a *pointer to* a vector or matrix must produce a *pointer to*
33///     a scalar or vector type. These cannot be represented with a
34///     [`TypeInner::Pointer`], since the `Pointer`'s `base` must point into the
35///     arena, and as before, we cannot assume that a suitable scalar or vector
36///     type is there. So we take things one step further and provide
37///     [`TypeInner::ValuePointer`], specifically for the case of pointers to
38///     scalars or vectors. This type fits in a `TypeInner` and is exactly
39///     equivalent to a `Pointer` to a `Vector` or `Scalar`.
40///
41/// So, for example, the type of an `Access` expression applied to a value of type:
42///
43/// ```ignore
44/// TypeInner::Matrix { columns, rows, width }
45/// ```
46///
47/// might be:
48///
49/// ```ignore
50/// TypeResolution::Value(TypeInner::Vector {
51///     size: rows,
52///     kind: ScalarKind::Float,
53///     width,
54/// })
55/// ```
56///
57/// and the type of an access to a pointer of address space `space` to such a
58/// matrix might be:
59///
60/// ```ignore
61/// TypeResolution::Value(TypeInner::ValuePointer {
62///     size: Some(rows),
63///     kind: ScalarKind::Float,
64///     width,
65///     space,
66/// })
67/// ```
68///
69/// [`Handle`]: TypeResolution::Handle
70/// [`Value`]: TypeResolution::Value
71///
72/// [`Access`]: crate::Expression::Access
73/// [`AccessIndex`]: crate::Expression::AccessIndex
74///
75/// [`TypeInner`]: crate::TypeInner
76/// [`Matrix`]: crate::TypeInner::Matrix
77/// [`Pointer`]: crate::TypeInner::Pointer
78/// [`Scalar`]: crate::TypeInner::Scalar
79/// [`ValuePointer`]: crate::TypeInner::ValuePointer
80/// [`Vector`]: crate::TypeInner::Vector
81///
82/// [`TypeInner::Pointer`]: crate::TypeInner::Pointer
83/// [`TypeInner::ValuePointer`]: crate::TypeInner::ValuePointer
84#[derive(Debug, PartialEq)]
85#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
86#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
87pub enum TypeResolution {
88    /// A type stored in the associated arena.
89    Handle(Handle<crate::Type>),
90
91    /// A free-floating [`TypeInner`], representing a type that may not be
92    /// available in the associated arena. However, the `TypeInner` itself may
93    /// contain `Handle<Type>` values referring to types from the arena.
94    ///
95    /// The inner type must only be one of the following variants:
96    /// - TypeInner::Pointer
97    /// - TypeInner::ValuePointer
98    /// - TypeInner::Matrix (generated by matrix multiplication)
99    /// - TypeInner::Vector
100    /// - TypeInner::Scalar
101    ///
102    /// [`TypeInner`]: crate::TypeInner
103    Value(crate::TypeInner),
104}
105
106impl TypeResolution {
107    pub const fn handle(&self) -> Option<Handle<crate::Type>> {
108        match *self {
109            Self::Handle(handle) => Some(handle),
110            Self::Value(_) => None,
111        }
112    }
113
114    pub fn inner_with<'a>(&'a self, arena: &'a UniqueArena<crate::Type>) -> &'a crate::TypeInner {
115        match *self {
116            Self::Handle(handle) => &arena[handle].inner,
117            Self::Value(ref inner) => inner,
118        }
119    }
120}
121
122// Clone is only implemented for numeric variants of `TypeInner`.
123impl Clone for TypeResolution {
124    fn clone(&self) -> Self {
125        use crate::TypeInner as Ti;
126        match *self {
127            Self::Handle(handle) => Self::Handle(handle),
128            Self::Value(ref v) => Self::Value(match *v {
129                Ti::Scalar(scalar) => Ti::Scalar(scalar),
130                Ti::Vector { size, scalar } => Ti::Vector { size, scalar },
131                Ti::Matrix {
132                    rows,
133                    columns,
134                    scalar,
135                } => Ti::Matrix {
136                    rows,
137                    columns,
138                    scalar,
139                },
140                Ti::Pointer { base, space } => Ti::Pointer { base, space },
141                Ti::ValuePointer {
142                    size,
143                    scalar,
144                    space,
145                } => Ti::ValuePointer {
146                    size,
147                    scalar,
148                    space,
149                },
150                _ => unreachable!("Unexpected clone type: {:?}", v),
151            }),
152        }
153    }
154}
155
156#[derive(Clone, Debug, Error, PartialEq)]
157pub enum ResolveError {
158    #[error("Index {index} is out of bounds for expression {expr:?}")]
159    OutOfBoundsIndex {
160        expr: Handle<crate::Expression>,
161        index: u32,
162    },
163    #[error("Invalid access into expression {expr:?}, indexed: {indexed}")]
164    InvalidAccess {
165        expr: Handle<crate::Expression>,
166        indexed: bool,
167    },
168    #[error("Invalid sub-access into type {ty:?}, indexed: {indexed}")]
169    InvalidSubAccess {
170        ty: Handle<crate::Type>,
171        indexed: bool,
172    },
173    #[error("Invalid scalar {0:?}")]
174    InvalidScalar(Handle<crate::Expression>),
175    #[error("Invalid vector {0:?}")]
176    InvalidVector(Handle<crate::Expression>),
177    #[error("Invalid pointer {0:?}")]
178    InvalidPointer(Handle<crate::Expression>),
179    #[error("Invalid image {0:?}")]
180    InvalidImage(Handle<crate::Expression>),
181    #[error("Function {name} not defined")]
182    FunctionNotDefined { name: String },
183    #[error("Function without return type")]
184    FunctionReturnsVoid,
185    #[error("Incompatible operands: {0}")]
186    IncompatibleOperands(String),
187    #[error("Function argument {0} doesn't exist")]
188    FunctionArgumentNotFound(u32),
189    #[error("Special type is not registered within the module")]
190    MissingSpecialType,
191}
192
193pub struct ResolveContext<'a> {
194    pub constants: &'a Arena<crate::Constant>,
195    pub overrides: &'a Arena<crate::Override>,
196    pub types: &'a UniqueArena<crate::Type>,
197    pub special_types: &'a crate::SpecialTypes,
198    pub global_vars: &'a Arena<crate::GlobalVariable>,
199    pub local_vars: &'a Arena<crate::LocalVariable>,
200    pub functions: &'a Arena<crate::Function>,
201    pub arguments: &'a [crate::FunctionArgument],
202}
203
204impl<'a> ResolveContext<'a> {
205    /// Initialize a resolve context from the module.
206    pub const fn with_locals(
207        module: &'a crate::Module,
208        local_vars: &'a Arena<crate::LocalVariable>,
209        arguments: &'a [crate::FunctionArgument],
210    ) -> Self {
211        Self {
212            constants: &module.constants,
213            overrides: &module.overrides,
214            types: &module.types,
215            special_types: &module.special_types,
216            global_vars: &module.global_variables,
217            local_vars,
218            functions: &module.functions,
219            arguments,
220        }
221    }
222
223    /// Determine the type of `expr`.
224    ///
225    /// The `past` argument must be a closure that can resolve the types of any
226    /// expressions that `expr` refers to. These can be gathered by caching the
227    /// results of prior calls to `resolve`, perhaps as done by the
228    /// [`front::Typifier`] utility type.
229    ///
230    /// Type resolution is a read-only process: this method takes `self` by
231    /// shared reference. However, this means that we cannot add anything to
232    /// `self.types` that we might need to describe `expr`. To work around this,
233    /// this method returns a [`TypeResolution`], rather than simply returning a
234    /// `Handle<Type>`; see the documentation for [`TypeResolution`] for
235    /// details.
236    ///
237    /// [`front::Typifier`]: crate::front::Typifier
238    pub fn resolve(
239        &self,
240        expr: &crate::Expression,
241        past: impl Fn(Handle<crate::Expression>) -> Result<&'a TypeResolution, ResolveError>,
242    ) -> Result<TypeResolution, ResolveError> {
243        use crate::TypeInner as Ti;
244        let types = self.types;
245        Ok(match *expr {
246            crate::Expression::Access { base, .. } => match *past(base)?.inner_with(types) {
247                // Arrays and matrices can only be indexed dynamically behind a
248                // pointer, but that's a validation error, not a type error, so
249                // go ahead provide a type here.
250                Ti::Array { base, .. } => TypeResolution::Handle(base),
251                Ti::Matrix { rows, scalar, .. } => {
252                    TypeResolution::Value(Ti::Vector { size: rows, scalar })
253                }
254                Ti::Vector { size: _, scalar } => TypeResolution::Value(Ti::Scalar(scalar)),
255                Ti::ValuePointer {
256                    size: Some(_),
257                    scalar,
258                    space,
259                } => TypeResolution::Value(Ti::ValuePointer {
260                    size: None,
261                    scalar,
262                    space,
263                }),
264                Ti::Pointer { base, space } => {
265                    TypeResolution::Value(match types[base].inner {
266                        Ti::Array { base, .. } => Ti::Pointer { base, space },
267                        Ti::Vector { size: _, scalar } => Ti::ValuePointer {
268                            size: None,
269                            scalar,
270                            space,
271                        },
272                        // Matrices are only dynamically indexed behind a pointer
273                        Ti::Matrix {
274                            columns: _,
275                            rows,
276                            scalar,
277                        } => Ti::ValuePointer {
278                            size: Some(rows),
279                            scalar,
280                            space,
281                        },
282                        Ti::BindingArray { base, .. } => Ti::Pointer { base, space },
283                        ref other => {
284                            log::error!("Access sub-type {:?}", other);
285                            return Err(ResolveError::InvalidSubAccess {
286                                ty: base,
287                                indexed: false,
288                            });
289                        }
290                    })
291                }
292                Ti::BindingArray { base, .. } => TypeResolution::Handle(base),
293                ref other => {
294                    log::error!("Access type {:?}", other);
295                    return Err(ResolveError::InvalidAccess {
296                        expr: base,
297                        indexed: false,
298                    });
299                }
300            },
301            crate::Expression::AccessIndex { base, index } => {
302                match *past(base)?.inner_with(types) {
303                    Ti::Vector { size, scalar } => {
304                        if index >= size as u32 {
305                            return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
306                        }
307                        TypeResolution::Value(Ti::Scalar(scalar))
308                    }
309                    Ti::Matrix {
310                        columns,
311                        rows,
312                        scalar,
313                    } => {
314                        if index >= columns as u32 {
315                            return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
316                        }
317                        TypeResolution::Value(crate::TypeInner::Vector { size: rows, scalar })
318                    }
319                    Ti::Array { base, .. } => TypeResolution::Handle(base),
320                    Ti::Struct { ref members, .. } => {
321                        let member = members
322                            .get(index as usize)
323                            .ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?;
324                        TypeResolution::Handle(member.ty)
325                    }
326                    Ti::ValuePointer {
327                        size: Some(size),
328                        scalar,
329                        space,
330                    } => {
331                        if index >= size as u32 {
332                            return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
333                        }
334                        TypeResolution::Value(Ti::ValuePointer {
335                            size: None,
336                            scalar,
337                            space,
338                        })
339                    }
340                    Ti::Pointer {
341                        base: ty_base,
342                        space,
343                    } => TypeResolution::Value(match types[ty_base].inner {
344                        Ti::Array { base, .. } => Ti::Pointer { base, space },
345                        Ti::Vector { size, scalar } => {
346                            if index >= size as u32 {
347                                return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
348                            }
349                            Ti::ValuePointer {
350                                size: None,
351                                scalar,
352                                space,
353                            }
354                        }
355                        Ti::Matrix {
356                            rows,
357                            columns,
358                            scalar,
359                        } => {
360                            if index >= columns as u32 {
361                                return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
362                            }
363                            Ti::ValuePointer {
364                                size: Some(rows),
365                                scalar,
366                                space,
367                            }
368                        }
369                        Ti::Struct { ref members, .. } => {
370                            let member = members
371                                .get(index as usize)
372                                .ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?;
373                            Ti::Pointer {
374                                base: member.ty,
375                                space,
376                            }
377                        }
378                        Ti::BindingArray { base, .. } => Ti::Pointer { base, space },
379                        ref other => {
380                            log::error!("Access index sub-type {:?}", other);
381                            return Err(ResolveError::InvalidSubAccess {
382                                ty: ty_base,
383                                indexed: true,
384                            });
385                        }
386                    }),
387                    Ti::BindingArray { base, .. } => TypeResolution::Handle(base),
388                    ref other => {
389                        log::error!("Access index type {:?}", other);
390                        return Err(ResolveError::InvalidAccess {
391                            expr: base,
392                            indexed: true,
393                        });
394                    }
395                }
396            }
397            crate::Expression::Splat { size, value } => match *past(value)?.inner_with(types) {
398                Ti::Scalar(scalar) => TypeResolution::Value(Ti::Vector { size, scalar }),
399                ref other => {
400                    log::error!("Scalar type {:?}", other);
401                    return Err(ResolveError::InvalidScalar(value));
402                }
403            },
404            crate::Expression::Swizzle {
405                size,
406                vector,
407                pattern: _,
408            } => match *past(vector)?.inner_with(types) {
409                Ti::Vector { size: _, scalar } => {
410                    TypeResolution::Value(Ti::Vector { size, scalar })
411                }
412                ref other => {
413                    log::error!("Vector type {:?}", other);
414                    return Err(ResolveError::InvalidVector(vector));
415                }
416            },
417            crate::Expression::Literal(lit) => TypeResolution::Value(lit.ty_inner()),
418            crate::Expression::Constant(h) => TypeResolution::Handle(self.constants[h].ty),
419            crate::Expression::Override(h) => TypeResolution::Handle(self.overrides[h].ty),
420            crate::Expression::ZeroValue(ty) => TypeResolution::Handle(ty),
421            crate::Expression::Compose { ty, .. } => TypeResolution::Handle(ty),
422            crate::Expression::FunctionArgument(index) => {
423                let arg = self
424                    .arguments
425                    .get(index as usize)
426                    .ok_or(ResolveError::FunctionArgumentNotFound(index))?;
427                TypeResolution::Handle(arg.ty)
428            }
429            crate::Expression::GlobalVariable(h) => {
430                let var = &self.global_vars[h];
431                if var.space == crate::AddressSpace::Handle {
432                    TypeResolution::Handle(var.ty)
433                } else {
434                    TypeResolution::Value(Ti::Pointer {
435                        base: var.ty,
436                        space: var.space,
437                    })
438                }
439            }
440            crate::Expression::LocalVariable(h) => {
441                let var = &self.local_vars[h];
442                TypeResolution::Value(Ti::Pointer {
443                    base: var.ty,
444                    space: crate::AddressSpace::Function,
445                })
446            }
447            crate::Expression::Load { pointer } => match *past(pointer)?.inner_with(types) {
448                Ti::Pointer { base, space: _ } => {
449                    if let Ti::Atomic(scalar) = types[base].inner {
450                        TypeResolution::Value(Ti::Scalar(scalar))
451                    } else {
452                        TypeResolution::Handle(base)
453                    }
454                }
455                Ti::ValuePointer {
456                    size,
457                    scalar,
458                    space: _,
459                } => TypeResolution::Value(match size {
460                    Some(size) => Ti::Vector { size, scalar },
461                    None => Ti::Scalar(scalar),
462                }),
463                ref other => {
464                    log::error!("Pointer type {:?}", other);
465                    return Err(ResolveError::InvalidPointer(pointer));
466                }
467            },
468            crate::Expression::ImageSample {
469                image,
470                gather: Some(_),
471                ..
472            } => match *past(image)?.inner_with(types) {
473                Ti::Image { class, .. } => TypeResolution::Value(Ti::Vector {
474                    scalar: crate::Scalar {
475                        kind: match class {
476                            crate::ImageClass::Sampled { kind, multi: _ } => kind,
477                            _ => crate::ScalarKind::Float,
478                        },
479                        width: 4,
480                    },
481                    size: crate::VectorSize::Quad,
482                }),
483                ref other => {
484                    log::error!("Image type {:?}", other);
485                    return Err(ResolveError::InvalidImage(image));
486                }
487            },
488            crate::Expression::ImageSample { image, .. }
489            | crate::Expression::ImageLoad { image, .. } => match *past(image)?.inner_with(types) {
490                Ti::Image { class, .. } => TypeResolution::Value(match class {
491                    crate::ImageClass::Depth { multi: _ } => Ti::Scalar(crate::Scalar::F32),
492                    crate::ImageClass::Sampled { kind, multi: _ } => Ti::Vector {
493                        scalar: crate::Scalar { kind, width: 4 },
494                        size: crate::VectorSize::Quad,
495                    },
496                    crate::ImageClass::Storage { format, .. } => Ti::Vector {
497                        scalar: format.into(),
498                        size: crate::VectorSize::Quad,
499                    },
500                }),
501                ref other => {
502                    log::error!("Image type {:?}", other);
503                    return Err(ResolveError::InvalidImage(image));
504                }
505            },
506            crate::Expression::ImageQuery { image, query } => TypeResolution::Value(match query {
507                crate::ImageQuery::Size { level: _ } => match *past(image)?.inner_with(types) {
508                    Ti::Image { dim, .. } => match dim {
509                        crate::ImageDimension::D1 => Ti::Scalar(crate::Scalar::U32),
510                        crate::ImageDimension::D2 | crate::ImageDimension::Cube => Ti::Vector {
511                            size: crate::VectorSize::Bi,
512                            scalar: crate::Scalar::U32,
513                        },
514                        crate::ImageDimension::D3 => Ti::Vector {
515                            size: crate::VectorSize::Tri,
516                            scalar: crate::Scalar::U32,
517                        },
518                    },
519                    ref other => {
520                        log::error!("Image type {:?}", other);
521                        return Err(ResolveError::InvalidImage(image));
522                    }
523                },
524                crate::ImageQuery::NumLevels
525                | crate::ImageQuery::NumLayers
526                | crate::ImageQuery::NumSamples => Ti::Scalar(crate::Scalar::U32),
527            }),
528            crate::Expression::Unary { expr, .. } => past(expr)?.clone(),
529            crate::Expression::Binary { op, left, right } => match op {
530                crate::BinaryOperator::Add
531                | crate::BinaryOperator::Subtract
532                | crate::BinaryOperator::Divide
533                | crate::BinaryOperator::Modulo => past(left)?.clone(),
534                crate::BinaryOperator::Multiply => {
535                    let (res_left, res_right) = (past(left)?, past(right)?);
536                    match (res_left.inner_with(types), res_right.inner_with(types)) {
537                        (
538                            &Ti::Matrix {
539                                columns: _,
540                                rows,
541                                scalar,
542                            },
543                            &Ti::Matrix { columns, .. },
544                        ) => TypeResolution::Value(Ti::Matrix {
545                            columns,
546                            rows,
547                            scalar,
548                        }),
549                        (
550                            &Ti::Matrix {
551                                columns: _,
552                                rows,
553                                scalar,
554                            },
555                            &Ti::Vector { .. },
556                        ) => TypeResolution::Value(Ti::Vector { size: rows, scalar }),
557                        (
558                            &Ti::Vector { .. },
559                            &Ti::Matrix {
560                                columns,
561                                rows: _,
562                                scalar,
563                            },
564                        ) => TypeResolution::Value(Ti::Vector {
565                            size: columns,
566                            scalar,
567                        }),
568                        (&Ti::Scalar { .. }, _) => res_right.clone(),
569                        (_, &Ti::Scalar { .. }) => res_left.clone(),
570                        (&Ti::Vector { .. }, &Ti::Vector { .. }) => res_left.clone(),
571                        (tl, tr) => {
572                            return Err(ResolveError::IncompatibleOperands(format!(
573                                "{tl:?} * {tr:?}"
574                            )))
575                        }
576                    }
577                }
578                crate::BinaryOperator::Equal
579                | crate::BinaryOperator::NotEqual
580                | crate::BinaryOperator::Less
581                | crate::BinaryOperator::LessEqual
582                | crate::BinaryOperator::Greater
583                | crate::BinaryOperator::GreaterEqual
584                | crate::BinaryOperator::LogicalAnd
585                | crate::BinaryOperator::LogicalOr => {
586                    let scalar = crate::Scalar::BOOL;
587                    let inner = match *past(left)?.inner_with(types) {
588                        Ti::Scalar { .. } => Ti::Scalar(scalar),
589                        Ti::Vector { size, .. } => Ti::Vector { size, scalar },
590                        ref other => {
591                            return Err(ResolveError::IncompatibleOperands(format!(
592                                "{op:?}({other:?}, _)"
593                            )))
594                        }
595                    };
596                    TypeResolution::Value(inner)
597                }
598                crate::BinaryOperator::And
599                | crate::BinaryOperator::ExclusiveOr
600                | crate::BinaryOperator::InclusiveOr
601                | crate::BinaryOperator::ShiftLeft
602                | crate::BinaryOperator::ShiftRight => past(left)?.clone(),
603            },
604            crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty),
605            crate::Expression::SubgroupOperationResult { ty } => TypeResolution::Handle(ty),
606            crate::Expression::WorkGroupUniformLoadResult { ty } => TypeResolution::Handle(ty),
607            crate::Expression::Select { accept, .. } => past(accept)?.clone(),
608            crate::Expression::Derivative { expr, .. } => past(expr)?.clone(),
609            crate::Expression::Relational { fun, argument } => match fun {
610                crate::RelationalFunction::All | crate::RelationalFunction::Any => {
611                    TypeResolution::Value(Ti::Scalar(crate::Scalar::BOOL))
612                }
613                crate::RelationalFunction::IsNan | crate::RelationalFunction::IsInf => {
614                    match *past(argument)?.inner_with(types) {
615                        Ti::Scalar { .. } => TypeResolution::Value(Ti::Scalar(crate::Scalar::BOOL)),
616                        Ti::Vector { size, .. } => TypeResolution::Value(Ti::Vector {
617                            scalar: crate::Scalar::BOOL,
618                            size,
619                        }),
620                        ref other => {
621                            return Err(ResolveError::IncompatibleOperands(format!(
622                                "{fun:?}({other:?})"
623                            )))
624                        }
625                    }
626                }
627            },
628            crate::Expression::Math {
629                fun,
630                arg,
631                arg1,
632                arg2: _,
633                arg3: _,
634            } => {
635                use crate::MathFunction as Mf;
636                let res_arg = past(arg)?;
637                match fun {
638                    Mf::Abs
639                    | Mf::Min
640                    | Mf::Max
641                    | Mf::Clamp
642                    | Mf::Saturate
643                    | Mf::Cos
644                    | Mf::Cosh
645                    | Mf::Sin
646                    | Mf::Sinh
647                    | Mf::Tan
648                    | Mf::Tanh
649                    | Mf::Acos
650                    | Mf::Asin
651                    | Mf::Atan
652                    | Mf::Atan2
653                    | Mf::Asinh
654                    | Mf::Acosh
655                    | Mf::Atanh
656                    | Mf::Radians
657                    | Mf::Degrees
658                    | Mf::Ceil
659                    | Mf::Floor
660                    | Mf::Round
661                    | Mf::Fract
662                    | Mf::Trunc
663                    | Mf::Ldexp
664                    | Mf::Exp
665                    | Mf::Exp2
666                    | Mf::Log
667                    | Mf::Log2
668                    | Mf::Pow => res_arg.clone(),
669                    Mf::Modf | Mf::Frexp => {
670                        let (size, width) = match res_arg.inner_with(types) {
671                            &Ti::Scalar(crate::Scalar {
672                                kind: crate::ScalarKind::Float,
673                                width,
674                            }) => (None, width),
675                            &Ti::Vector {
676                                scalar:
677                                    crate::Scalar {
678                                        kind: crate::ScalarKind::Float,
679                                        width,
680                                    },
681                                size,
682                            } => (Some(size), width),
683                            ref other => {
684                                return Err(ResolveError::IncompatibleOperands(format!(
685                                    "{fun:?}({other:?}, _)"
686                                )))
687                            }
688                        };
689                        let result = self
690                            .special_types
691                            .predeclared_types
692                            .get(&if fun == Mf::Modf {
693                                crate::PredeclaredType::ModfResult { size, width }
694                            } else {
695                                crate::PredeclaredType::FrexpResult { size, width }
696                            })
697                            .ok_or(ResolveError::MissingSpecialType)?;
698                        TypeResolution::Handle(*result)
699                    }
700                    Mf::Dot => match *res_arg.inner_with(types) {
701                        Ti::Vector { size: _, scalar } => TypeResolution::Value(Ti::Scalar(scalar)),
702                        ref other => {
703                            return Err(ResolveError::IncompatibleOperands(format!(
704                                "{fun:?}({other:?}, _)"
705                            )))
706                        }
707                    },
708                    Mf::Outer => {
709                        let arg1 = arg1.ok_or_else(|| {
710                            ResolveError::IncompatibleOperands(format!("{fun:?}(_, None)"))
711                        })?;
712                        match (res_arg.inner_with(types), past(arg1)?.inner_with(types)) {
713                            (
714                                &Ti::Vector {
715                                    size: columns,
716                                    scalar,
717                                },
718                                &Ti::Vector { size: rows, .. },
719                            ) => TypeResolution::Value(Ti::Matrix {
720                                columns,
721                                rows,
722                                scalar,
723                            }),
724                            (left, right) => {
725                                return Err(ResolveError::IncompatibleOperands(format!(
726                                    "{fun:?}({left:?}, {right:?})"
727                                )))
728                            }
729                        }
730                    }
731                    Mf::Cross => res_arg.clone(),
732                    Mf::Distance | Mf::Length => match *res_arg.inner_with(types) {
733                        Ti::Scalar(scalar) | Ti::Vector { scalar, size: _ } => {
734                            TypeResolution::Value(Ti::Scalar(scalar))
735                        }
736                        ref other => {
737                            return Err(ResolveError::IncompatibleOperands(format!(
738                                "{fun:?}({other:?})"
739                            )))
740                        }
741                    },
742                    Mf::Normalize | Mf::FaceForward | Mf::Reflect | Mf::Refract => res_arg.clone(),
743                    // computational
744                    Mf::Sign
745                    | Mf::Fma
746                    | Mf::Mix
747                    | Mf::Step
748                    | Mf::SmoothStep
749                    | Mf::Sqrt
750                    | Mf::InverseSqrt => res_arg.clone(),
751                    Mf::Transpose => match *res_arg.inner_with(types) {
752                        Ti::Matrix {
753                            columns,
754                            rows,
755                            scalar,
756                        } => TypeResolution::Value(Ti::Matrix {
757                            columns: rows,
758                            rows: columns,
759                            scalar,
760                        }),
761                        ref other => {
762                            return Err(ResolveError::IncompatibleOperands(format!(
763                                "{fun:?}({other:?})"
764                            )))
765                        }
766                    },
767                    Mf::Inverse => match *res_arg.inner_with(types) {
768                        Ti::Matrix {
769                            columns,
770                            rows,
771                            scalar,
772                        } if columns == rows => TypeResolution::Value(Ti::Matrix {
773                            columns,
774                            rows,
775                            scalar,
776                        }),
777                        ref other => {
778                            return Err(ResolveError::IncompatibleOperands(format!(
779                                "{fun:?}({other:?})"
780                            )))
781                        }
782                    },
783                    Mf::Determinant => match *res_arg.inner_with(types) {
784                        Ti::Matrix { scalar, .. } => TypeResolution::Value(Ti::Scalar(scalar)),
785                        ref other => {
786                            return Err(ResolveError::IncompatibleOperands(format!(
787                                "{fun:?}({other:?})"
788                            )))
789                        }
790                    },
791                    // bits
792                    Mf::CountTrailingZeros
793                    | Mf::CountLeadingZeros
794                    | Mf::CountOneBits
795                    | Mf::ReverseBits
796                    | Mf::ExtractBits
797                    | Mf::InsertBits
798                    | Mf::FirstTrailingBit
799                    | Mf::FirstLeadingBit => match *res_arg.inner_with(types) {
800                        Ti::Scalar(
801                            scalar @ crate::Scalar {
802                                kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
803                                ..
804                            },
805                        ) => TypeResolution::Value(Ti::Scalar(scalar)),
806                        Ti::Vector {
807                            size,
808                            scalar:
809                                scalar @ crate::Scalar {
810                                    kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
811                                    ..
812                                },
813                        } => TypeResolution::Value(Ti::Vector { size, scalar }),
814                        ref other => {
815                            return Err(ResolveError::IncompatibleOperands(format!(
816                                "{fun:?}({other:?})"
817                            )))
818                        }
819                    },
820                    // data packing
821                    Mf::Pack4x8snorm
822                    | Mf::Pack4x8unorm
823                    | Mf::Pack2x16snorm
824                    | Mf::Pack2x16unorm
825                    | Mf::Pack2x16float
826                    | Mf::Pack4xI8
827                    | Mf::Pack4xU8 => TypeResolution::Value(Ti::Scalar(crate::Scalar::U32)),
828                    // data unpacking
829                    Mf::Unpack4x8snorm | Mf::Unpack4x8unorm => TypeResolution::Value(Ti::Vector {
830                        size: crate::VectorSize::Quad,
831                        scalar: crate::Scalar::F32,
832                    }),
833                    Mf::Unpack2x16snorm | Mf::Unpack2x16unorm | Mf::Unpack2x16float => {
834                        TypeResolution::Value(Ti::Vector {
835                            size: crate::VectorSize::Bi,
836                            scalar: crate::Scalar::F32,
837                        })
838                    }
839                    Mf::Unpack4xI8 => TypeResolution::Value(Ti::Vector {
840                        size: crate::VectorSize::Quad,
841                        scalar: crate::Scalar::I32,
842                    }),
843                    Mf::Unpack4xU8 => TypeResolution::Value(Ti::Vector {
844                        size: crate::VectorSize::Quad,
845                        scalar: crate::Scalar::U32,
846                    }),
847                }
848            }
849            crate::Expression::As {
850                expr,
851                kind,
852                convert,
853            } => match *past(expr)?.inner_with(types) {
854                Ti::Scalar(crate::Scalar { width, .. }) => {
855                    TypeResolution::Value(Ti::Scalar(crate::Scalar {
856                        kind,
857                        width: convert.unwrap_or(width),
858                    }))
859                }
860                Ti::Vector {
861                    size,
862                    scalar: crate::Scalar { kind: _, width },
863                } => TypeResolution::Value(Ti::Vector {
864                    size,
865                    scalar: crate::Scalar {
866                        kind,
867                        width: convert.unwrap_or(width),
868                    },
869                }),
870                Ti::Matrix {
871                    columns,
872                    rows,
873                    mut scalar,
874                } => {
875                    if let Some(width) = convert {
876                        scalar.width = width;
877                    }
878                    TypeResolution::Value(Ti::Matrix {
879                        columns,
880                        rows,
881                        scalar,
882                    })
883                }
884                ref other => {
885                    return Err(ResolveError::IncompatibleOperands(format!(
886                        "{other:?} as {kind:?}"
887                    )))
888                }
889            },
890            crate::Expression::CallResult(function) => {
891                let result = self.functions[function]
892                    .result
893                    .as_ref()
894                    .ok_or(ResolveError::FunctionReturnsVoid)?;
895                TypeResolution::Handle(result.ty)
896            }
897            crate::Expression::ArrayLength(_) => {
898                TypeResolution::Value(Ti::Scalar(crate::Scalar::U32))
899            }
900            crate::Expression::RayQueryProceedResult => {
901                TypeResolution::Value(Ti::Scalar(crate::Scalar::BOOL))
902            }
903            crate::Expression::RayQueryGetIntersection { .. } => {
904                let result = self
905                    .special_types
906                    .ray_intersection
907                    .ok_or(ResolveError::MissingSpecialType)?;
908                TypeResolution::Handle(result)
909            }
910            crate::Expression::SubgroupBallotResult => TypeResolution::Value(Ti::Vector {
911                scalar: crate::Scalar::U32,
912                size: crate::VectorSize::Quad,
913            }),
914        })
915    }
916}
917
/// Pins the in-memory size of [`ResolveError`] to 32 bytes.
///
/// Any change that alters the size of the error type (for example, adding
/// a variant with a larger payload) will make this test fail. That makes
/// the size change deliberate rather than accidental.
#[test]
fn test_error_size() {
    const EXPECTED_BYTES: usize = 32;
    let actual_bytes = std::mem::size_of::<ResolveError>();
    assert_eq!(actual_bytes, EXPECTED_BYTES);
}