1use super::{compose::validate_compose, FunctionInfo, ModuleInfo, ShaderStages, TypeFlags};
2use crate::arena::UniqueArena;
3
4use crate::{
5 arena::Handle,
6 proc::{IndexableLengthError, ResolveError},
7};
8
9#[derive(Clone, Debug, thiserror::Error)]
10#[cfg_attr(test, derive(PartialEq))]
11pub enum ExpressionError {
12 #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")]
13 NotInScope,
14 #[error("Base type {0:?} is not compatible with this expression")]
15 InvalidBaseType(Handle<crate::Expression>),
16 #[error("Accessing with index {0:?} can't be done")]
17 InvalidIndexType(Handle<crate::Expression>),
18 #[error("Accessing {0:?} via a negative index is invalid")]
19 NegativeIndex(Handle<crate::Expression>),
20 #[error("Accessing index {1} is out of {0:?} bounds")]
21 IndexOutOfBounds(Handle<crate::Expression>, u32),
22 #[error("Function argument {0:?} doesn't exist")]
23 FunctionArgumentDoesntExist(u32),
24 #[error("Loading of {0:?} can't be done")]
25 InvalidPointerType(Handle<crate::Expression>),
26 #[error("Array length of {0:?} can't be done")]
27 InvalidArrayType(Handle<crate::Expression>),
28 #[error("Get intersection of {0:?} can't be done")]
29 InvalidRayQueryType(Handle<crate::Expression>),
30 #[error("Splatting {0:?} can't be done")]
31 InvalidSplatType(Handle<crate::Expression>),
32 #[error("Swizzling {0:?} can't be done")]
33 InvalidVectorType(Handle<crate::Expression>),
34 #[error("Swizzle component {0:?} is outside of vector size {1:?}")]
35 InvalidSwizzleComponent(crate::SwizzleComponent, crate::VectorSize),
36 #[error(transparent)]
37 Compose(#[from] super::ComposeError),
38 #[error(transparent)]
39 IndexableLength(#[from] IndexableLengthError),
40 #[error("Operation {0:?} can't work with {1:?}")]
41 InvalidUnaryOperandType(crate::UnaryOperator, Handle<crate::Expression>),
42 #[error(
43 "Operation {:?} can't work with {:?} (of type {:?}) and {:?} (of type {:?})",
44 op,
45 lhs_expr,
46 lhs_type,
47 rhs_expr,
48 rhs_type
49 )]
50 InvalidBinaryOperandTypes {
51 op: crate::BinaryOperator,
52 lhs_expr: Handle<crate::Expression>,
53 lhs_type: crate::TypeInner,
54 rhs_expr: Handle<crate::Expression>,
55 rhs_type: crate::TypeInner,
56 },
57 #[error("Expected selection argument types to match, but reject value of type {reject:?} does not match accept value of value {accept:?}")]
58 SelectValuesTypeMismatch {
59 accept: crate::TypeInner,
60 reject: crate::TypeInner,
61 },
62 #[error("Expected selection condition to be a boolean value, got {actual:?}")]
63 SelectConditionNotABool { actual: crate::TypeInner },
64 #[error("Relational argument {0:?} is not a boolean vector")]
65 InvalidBooleanVector(Handle<crate::Expression>),
66 #[error("Relational argument {0:?} is not a float")]
67 InvalidFloatArgument(Handle<crate::Expression>),
68 #[error("Type resolution failed")]
69 Type(#[from] ResolveError),
70 #[error("Not a global variable")]
71 ExpectedGlobalVariable,
72 #[error("Not a global variable or a function argument")]
73 ExpectedGlobalOrArgument,
74 #[error("Needs to be an binding array instead of {0:?}")]
75 ExpectedBindingArrayType(Handle<crate::Type>),
76 #[error("Needs to be an image instead of {0:?}")]
77 ExpectedImageType(Handle<crate::Type>),
78 #[error("Needs to be an image instead of {0:?}")]
79 ExpectedSamplerType(Handle<crate::Type>),
80 #[error("Unable to operate on image class {0:?}")]
81 InvalidImageClass(crate::ImageClass),
82 #[error("Image atomics are not supported for storage format {0:?}")]
83 InvalidImageFormat(crate::StorageFormat),
84 #[error("Image atomics require atomic storage access, {0:?} is insufficient")]
85 InvalidImageStorageAccess(crate::StorageAccess),
86 #[error("Derivatives can only be taken from scalar and vector floats")]
87 InvalidDerivative,
88 #[error("Image array index parameter is misplaced")]
89 InvalidImageArrayIndex,
90 #[error("Inappropriate sample or level-of-detail index for texel access")]
91 InvalidImageOtherIndex,
92 #[error("Image array index type of {0:?} is not an integer scalar")]
93 InvalidImageArrayIndexType(Handle<crate::Expression>),
94 #[error("Image sample or level-of-detail index's type of {0:?} is not an integer scalar")]
95 InvalidImageOtherIndexType(Handle<crate::Expression>),
96 #[error("Image coordinate type of {1:?} does not match dimension {0:?}")]
97 InvalidImageCoordinateType(crate::ImageDimension, Handle<crate::Expression>),
98 #[error("Comparison sampling mismatch: image has class {image:?}, but the sampler is comparison={sampler}, and the reference was provided={has_ref}")]
99 ComparisonSamplingMismatch {
100 image: crate::ImageClass,
101 sampler: bool,
102 has_ref: bool,
103 },
104 #[error("Sample offset must be a const-expression")]
105 InvalidSampleOffsetExprType,
106 #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")]
107 InvalidSampleOffset(crate::ImageDimension, Handle<crate::Expression>),
108 #[error("Depth reference {0:?} is not a scalar float")]
109 InvalidDepthReference(Handle<crate::Expression>),
110 #[error("Depth sample level can only be Auto or Zero")]
111 InvalidDepthSampleLevel,
112 #[error("Gather level can only be Zero")]
113 InvalidGatherLevel,
114 #[error("Gather component {0:?} doesn't exist in the image")]
115 InvalidGatherComponent(crate::SwizzleComponent),
116 #[error("Gather can't be done for image dimension {0:?}")]
117 InvalidGatherDimension(crate::ImageDimension),
118 #[error("Sample level (exact) type {0:?} has an invalid type")]
119 InvalidSampleLevelExactType(Handle<crate::Expression>),
120 #[error("Sample level (bias) type {0:?} is not a scalar float")]
121 InvalidSampleLevelBiasType(Handle<crate::Expression>),
122 #[error("Bias can't be done for image dimension {0:?}")]
123 InvalidSampleLevelBiasDimension(crate::ImageDimension),
124 #[error("Sample level (gradient) of {1:?} doesn't match the image dimension {0:?}")]
125 InvalidSampleLevelGradientType(crate::ImageDimension, Handle<crate::Expression>),
126 #[error("Unable to cast")]
127 InvalidCastArgument,
128 #[error("Invalid argument count for {0:?}")]
129 WrongArgumentCount(crate::MathFunction),
130 #[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")]
131 InvalidArgumentType(crate::MathFunction, u32, Handle<crate::Expression>),
132 #[error(
133 "workgroupUniformLoad result type can't be {0:?}. It can only be a constructible type."
134 )]
135 InvalidWorkGroupUniformLoadResultType(Handle<crate::Type>),
136 #[error("Shader requires capability {0:?}")]
137 MissingCapabilities(super::Capabilities),
138 #[error(transparent)]
139 Literal(#[from] LiteralError),
140 #[error("{0:?} is not supported for Width {2} {1:?} arguments yet, see https://github.com/gfx-rs/wgpu/issues/5276")]
141 UnsupportedWidth(crate::MathFunction, crate::ScalarKind, crate::Bytes),
142}
143
144#[derive(Clone, Debug, thiserror::Error)]
145#[cfg_attr(test, derive(PartialEq))]
146pub enum ConstExpressionError {
147 #[error("The expression is not a constant or override expression")]
148 NonConstOrOverride,
149 #[error("The expression is not a fully evaluated constant expression")]
150 NonFullyEvaluatedConst,
151 #[error(transparent)]
152 Compose(#[from] super::ComposeError),
153 #[error("Splatting {0:?} can't be done")]
154 InvalidSplatType(Handle<crate::Expression>),
155 #[error("Type resolution failed")]
156 Type(#[from] ResolveError),
157 #[error(transparent)]
158 Literal(#[from] LiteralError),
159 #[error(transparent)]
160 Width(#[from] super::r#type::WidthError),
161}
162
163#[derive(Clone, Debug, thiserror::Error)]
164#[cfg_attr(test, derive(PartialEq))]
165pub enum LiteralError {
166 #[error("Float literal is NaN")]
167 NaN,
168 #[error("Float literal is infinite")]
169 Infinity,
170 #[error(transparent)]
171 Width(#[from] super::r#type::WidthError),
172}
173
174struct ExpressionTypeResolver<'a> {
175 root: Handle<crate::Expression>,
176 types: &'a UniqueArena<crate::Type>,
177 info: &'a FunctionInfo,
178}
179
180impl std::ops::Index<Handle<crate::Expression>> for ExpressionTypeResolver<'_> {
181 type Output = crate::TypeInner;
182
183 #[allow(clippy::panic)]
184 fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
185 if handle < self.root {
186 self.info[handle].ty.inner_with(self.types)
187 } else {
188 panic!(
190 "Depends on {:?}, which has not been processed yet",
191 self.root
192 )
193 }
194 }
195}
196
197impl super::Validator {
198 pub(super) fn validate_const_expression(
199 &self,
200 handle: Handle<crate::Expression>,
201 gctx: crate::proc::GlobalCtx,
202 mod_info: &ModuleInfo,
203 global_expr_kind: &crate::proc::ExpressionKindTracker,
204 ) -> Result<(), ConstExpressionError> {
205 use crate::Expression as E;
206
207 if !global_expr_kind.is_const_or_override(handle) {
208 return Err(ConstExpressionError::NonConstOrOverride);
209 }
210
211 match gctx.global_expressions[handle] {
212 E::Literal(literal) => {
213 self.validate_literal(literal)?;
214 }
215 E::Constant(_) | E::ZeroValue(_) => {}
216 E::Compose { ref components, ty } => {
217 validate_compose(
218 ty,
219 gctx,
220 components.iter().map(|&handle| mod_info[handle].clone()),
221 )?;
222 }
223 E::Splat { value, .. } => match *mod_info[value].inner_with(gctx.types) {
224 crate::TypeInner::Scalar { .. } => {}
225 _ => return Err(ConstExpressionError::InvalidSplatType(value)),
226 },
227 _ if global_expr_kind.is_const(handle) || !self.allow_overrides => {
228 return Err(ConstExpressionError::NonFullyEvaluatedConst)
229 }
230 _ => {}
232 }
233
234 Ok(())
235 }
236
237 #[allow(clippy::too_many_arguments)]
238 pub(super) fn validate_expression(
239 &self,
240 root: Handle<crate::Expression>,
241 expression: &crate::Expression,
242 function: &crate::Function,
243 module: &crate::Module,
244 info: &FunctionInfo,
245 mod_info: &ModuleInfo,
246 global_expr_kind: &crate::proc::ExpressionKindTracker,
247 ) -> Result<ShaderStages, ExpressionError> {
248 use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti};
249
250 let resolver = ExpressionTypeResolver {
251 root,
252 types: &module.types,
253 info,
254 };
255
256 let stages = match *expression {
257 E::Access { base, index } => {
258 let base_type = &resolver[base];
259 match *base_type {
260 Ti::Matrix { .. }
261 | Ti::Vector { .. }
262 | Ti::Array { .. }
263 | Ti::Pointer { .. }
264 | Ti::ValuePointer { size: Some(_), .. }
265 | Ti::BindingArray { .. } => {}
266 ref other => {
267 log::error!("Indexing of {:?}", other);
268 return Err(ExpressionError::InvalidBaseType(base));
269 }
270 };
271 match resolver[index] {
272 Ti::Scalar(Sc {
274 kind: Sk::Sint | Sk::Uint,
275 ..
276 }) => {}
277 ref other => {
278 log::error!("Indexing by {:?}", other);
279 return Err(ExpressionError::InvalidIndexType(index));
280 }
281 }
282
283 if let crate::proc::IndexableLength::Known(known_length) =
286 base_type.indexable_length(module)?
287 {
288 match module
289 .to_ctx()
290 .eval_expr_to_u32_from(index, &function.expressions)
291 {
292 Ok(value) => {
293 if value >= known_length {
294 return Err(ExpressionError::IndexOutOfBounds(base, value));
295 }
296 }
297 Err(crate::proc::U32EvalError::Negative) => {
298 return Err(ExpressionError::NegativeIndex(base))
299 }
300 Err(crate::proc::U32EvalError::NonConst) => {}
301 }
302 }
303
304 ShaderStages::all()
305 }
306 E::AccessIndex { base, index } => {
307 fn resolve_index_limit(
308 module: &crate::Module,
309 top: Handle<crate::Expression>,
310 ty: &crate::TypeInner,
311 top_level: bool,
312 ) -> Result<u32, ExpressionError> {
313 let limit = match *ty {
314 Ti::Vector { size, .. }
315 | Ti::ValuePointer {
316 size: Some(size), ..
317 } => size as u32,
318 Ti::Matrix { columns, .. } => columns as u32,
319 Ti::Array {
320 size: crate::ArraySize::Constant(len),
321 ..
322 } => len.get(),
323 Ti::Array { .. } | Ti::BindingArray { .. } => u32::MAX, Ti::Pointer { base, .. } if top_level => {
325 resolve_index_limit(module, top, &module.types[base].inner, false)?
326 }
327 Ti::Struct { ref members, .. } => members.len() as u32,
328 ref other => {
329 log::error!("Indexing of {:?}", other);
330 return Err(ExpressionError::InvalidBaseType(top));
331 }
332 };
333 Ok(limit)
334 }
335
336 let limit = resolve_index_limit(module, base, &resolver[base], true)?;
337 if index >= limit {
338 return Err(ExpressionError::IndexOutOfBounds(base, limit));
339 }
340 ShaderStages::all()
341 }
342 E::Splat { size: _, value } => match resolver[value] {
343 Ti::Scalar { .. } => ShaderStages::all(),
344 ref other => {
345 log::error!("Splat scalar type {:?}", other);
346 return Err(ExpressionError::InvalidSplatType(value));
347 }
348 },
349 E::Swizzle {
350 size,
351 vector,
352 pattern,
353 } => {
354 let vec_size = match resolver[vector] {
355 Ti::Vector { size: vec_size, .. } => vec_size,
356 ref other => {
357 log::error!("Swizzle vector type {:?}", other);
358 return Err(ExpressionError::InvalidVectorType(vector));
359 }
360 };
361 for &sc in pattern[..size as usize].iter() {
362 if sc as u8 >= vec_size as u8 {
363 return Err(ExpressionError::InvalidSwizzleComponent(sc, vec_size));
364 }
365 }
366 ShaderStages::all()
367 }
368 E::Literal(literal) => {
369 self.validate_literal(literal)?;
370 ShaderStages::all()
371 }
372 E::Constant(_) | E::Override(_) | E::ZeroValue(_) => ShaderStages::all(),
373 E::Compose { ref components, ty } => {
374 validate_compose(
375 ty,
376 module.to_ctx(),
377 components.iter().map(|&handle| info[handle].ty.clone()),
378 )?;
379 ShaderStages::all()
380 }
381 E::FunctionArgument(index) => {
382 if index >= function.arguments.len() as u32 {
383 return Err(ExpressionError::FunctionArgumentDoesntExist(index));
384 }
385 ShaderStages::all()
386 }
387 E::GlobalVariable(_handle) => ShaderStages::all(),
388 E::LocalVariable(_handle) => ShaderStages::all(),
389 E::Load { pointer } => {
390 match resolver[pointer] {
391 Ti::Pointer { base, .. }
392 if self.types[base.index()]
393 .flags
394 .contains(TypeFlags::SIZED | TypeFlags::DATA) => {}
395 Ti::ValuePointer { .. } => {}
396 ref other => {
397 log::error!("Loading {:?}", other);
398 return Err(ExpressionError::InvalidPointerType(pointer));
399 }
400 }
401 ShaderStages::all()
402 }
403 E::ImageSample {
404 image,
405 sampler,
406 gather,
407 coordinate,
408 array_index,
409 offset,
410 level,
411 depth_ref,
412 } => {
413 let image_ty = Self::global_var_ty(module, function, image)?;
415 let sampler_ty = Self::global_var_ty(module, function, sampler)?;
416
417 let comparison = match module.types[sampler_ty].inner {
418 Ti::Sampler { comparison } => comparison,
419 _ => return Err(ExpressionError::ExpectedSamplerType(sampler_ty)),
420 };
421
422 let (class, dim) = match module.types[image_ty].inner {
423 Ti::Image {
424 class,
425 arrayed,
426 dim,
427 } => {
428 if arrayed != array_index.is_some() {
430 return Err(ExpressionError::InvalidImageArrayIndex);
431 }
432 if let Some(expr) = array_index {
433 match resolver[expr] {
434 Ti::Scalar(Sc {
435 kind: Sk::Sint | Sk::Uint,
436 ..
437 }) => {}
438 _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
439 }
440 }
441 (class, dim)
442 }
443 _ => return Err(ExpressionError::ExpectedImageType(image_ty)),
444 };
445
446 let image_depth = match class {
448 crate::ImageClass::Sampled {
449 kind: crate::ScalarKind::Float,
450 multi: false,
451 } => false,
452 crate::ImageClass::Sampled {
453 kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint,
454 multi: false,
455 } if gather.is_some() => false,
456 crate::ImageClass::Depth { multi: false } => true,
457 _ => return Err(ExpressionError::InvalidImageClass(class)),
458 };
459 if comparison != depth_ref.is_some() || (comparison && !image_depth) {
460 return Err(ExpressionError::ComparisonSamplingMismatch {
461 image: class,
462 sampler: comparison,
463 has_ref: depth_ref.is_some(),
464 });
465 }
466
467 let num_components = match dim {
469 crate::ImageDimension::D1 => 1,
470 crate::ImageDimension::D2 => 2,
471 crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3,
472 };
473 match resolver[coordinate] {
474 Ti::Scalar(Sc {
475 kind: Sk::Float, ..
476 }) if num_components == 1 => {}
477 Ti::Vector {
478 size,
479 scalar:
480 Sc {
481 kind: Sk::Float, ..
482 },
483 } if size as u32 == num_components => {}
484 _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)),
485 }
486
487 if let Some(const_expr) = offset {
489 if !global_expr_kind.is_const(const_expr) {
490 return Err(ExpressionError::InvalidSampleOffsetExprType);
491 }
492
493 match *mod_info[const_expr].inner_with(&module.types) {
494 Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {}
495 Ti::Vector {
496 size,
497 scalar: Sc { kind: Sk::Sint, .. },
498 } if size as u32 == num_components => {}
499 _ => {
500 return Err(ExpressionError::InvalidSampleOffset(dim, const_expr));
501 }
502 }
503 }
504
505 if let Some(expr) = depth_ref {
507 match resolver[expr] {
508 Ti::Scalar(Sc {
509 kind: Sk::Float, ..
510 }) => {}
511 _ => return Err(ExpressionError::InvalidDepthReference(expr)),
512 }
513 match level {
514 crate::SampleLevel::Auto | crate::SampleLevel::Zero => {}
515 _ => return Err(ExpressionError::InvalidDepthSampleLevel),
516 }
517 }
518
519 if let Some(component) = gather {
520 match dim {
521 crate::ImageDimension::D2 | crate::ImageDimension::Cube => {}
522 crate::ImageDimension::D1 | crate::ImageDimension::D3 => {
523 return Err(ExpressionError::InvalidGatherDimension(dim))
524 }
525 };
526 let max_component = match class {
527 crate::ImageClass::Depth { .. } => crate::SwizzleComponent::X,
528 _ => crate::SwizzleComponent::W,
529 };
530 if component > max_component {
531 return Err(ExpressionError::InvalidGatherComponent(component));
532 }
533 match level {
534 crate::SampleLevel::Zero => {}
535 _ => return Err(ExpressionError::InvalidGatherLevel),
536 }
537 }
538
539 match level {
541 crate::SampleLevel::Auto => ShaderStages::FRAGMENT,
542 crate::SampleLevel::Zero => ShaderStages::all(),
543 crate::SampleLevel::Exact(expr) => {
544 match class {
545 crate::ImageClass::Depth { .. } => match resolver[expr] {
546 Ti::Scalar(Sc {
547 kind: Sk::Sint | Sk::Uint,
548 ..
549 }) => {}
550 _ => {
551 return Err(ExpressionError::InvalidSampleLevelExactType(expr))
552 }
553 },
554 _ => match resolver[expr] {
555 Ti::Scalar(Sc {
556 kind: Sk::Float, ..
557 }) => {}
558 _ => {
559 return Err(ExpressionError::InvalidSampleLevelExactType(expr))
560 }
561 },
562 }
563 ShaderStages::all()
564 }
565 crate::SampleLevel::Bias(expr) => {
566 match resolver[expr] {
567 Ti::Scalar(Sc {
568 kind: Sk::Float, ..
569 }) => {}
570 _ => return Err(ExpressionError::InvalidSampleLevelBiasType(expr)),
571 }
572 match class {
573 crate::ImageClass::Sampled {
574 kind: Sk::Float,
575 multi: false,
576 } => {
577 if dim == crate::ImageDimension::D1 {
578 return Err(ExpressionError::InvalidSampleLevelBiasDimension(
579 dim,
580 ));
581 }
582 }
583 _ => return Err(ExpressionError::InvalidImageClass(class)),
584 }
585 ShaderStages::FRAGMENT
586 }
587 crate::SampleLevel::Gradient { x, y } => {
588 match resolver[x] {
589 Ti::Scalar(Sc {
590 kind: Sk::Float, ..
591 }) if num_components == 1 => {}
592 Ti::Vector {
593 size,
594 scalar:
595 Sc {
596 kind: Sk::Float, ..
597 },
598 } if size as u32 == num_components => {}
599 _ => {
600 return Err(ExpressionError::InvalidSampleLevelGradientType(dim, x))
601 }
602 }
603 match resolver[y] {
604 Ti::Scalar(Sc {
605 kind: Sk::Float, ..
606 }) if num_components == 1 => {}
607 Ti::Vector {
608 size,
609 scalar:
610 Sc {
611 kind: Sk::Float, ..
612 },
613 } if size as u32 == num_components => {}
614 _ => {
615 return Err(ExpressionError::InvalidSampleLevelGradientType(dim, y))
616 }
617 }
618 ShaderStages::all()
619 }
620 }
621 }
622 E::ImageLoad {
623 image,
624 coordinate,
625 array_index,
626 sample,
627 level,
628 } => {
629 let ty = Self::global_var_ty(module, function, image)?;
630 match module.types[ty].inner {
631 Ti::Image {
632 class,
633 arrayed,
634 dim,
635 } => {
636 match resolver[coordinate].image_storage_coordinates() {
637 Some(coord_dim) if coord_dim == dim => {}
638 _ => {
639 return Err(ExpressionError::InvalidImageCoordinateType(
640 dim, coordinate,
641 ))
642 }
643 };
644 if arrayed != array_index.is_some() {
645 return Err(ExpressionError::InvalidImageArrayIndex);
646 }
647 if let Some(expr) = array_index {
648 match resolver[expr] {
649 Ti::Scalar(Sc {
650 kind: Sk::Sint | Sk::Uint,
651 width: _,
652 }) => {}
653 _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
654 }
655 }
656
657 match (sample, class.is_multisampled()) {
658 (None, false) => {}
659 (Some(sample), true) => {
660 if resolver[sample].scalar_kind() != Some(Sk::Sint) {
661 return Err(ExpressionError::InvalidImageOtherIndexType(
662 sample,
663 ));
664 }
665 }
666 _ => {
667 return Err(ExpressionError::InvalidImageOtherIndex);
668 }
669 }
670
671 match (level, class.is_mipmapped()) {
672 (None, false) => {}
673 (Some(level), true) => {
674 if resolver[level].scalar_kind() != Some(Sk::Sint) {
675 return Err(ExpressionError::InvalidImageOtherIndexType(level));
676 }
677 }
678 _ => {
679 return Err(ExpressionError::InvalidImageOtherIndex);
680 }
681 }
682 }
683 _ => return Err(ExpressionError::ExpectedImageType(ty)),
684 }
685 ShaderStages::all()
686 }
687 E::ImageQuery { image, query } => {
688 let ty = Self::global_var_ty(module, function, image)?;
689 match module.types[ty].inner {
690 Ti::Image { class, arrayed, .. } => {
691 let good = match query {
692 crate::ImageQuery::NumLayers => arrayed,
693 crate::ImageQuery::Size { level: None } => true,
694 crate::ImageQuery::Size { level: Some(_) }
695 | crate::ImageQuery::NumLevels => class.is_mipmapped(),
696 crate::ImageQuery::NumSamples => class.is_multisampled(),
697 };
698 if !good {
699 return Err(ExpressionError::InvalidImageClass(class));
700 }
701 }
702 _ => return Err(ExpressionError::ExpectedImageType(ty)),
703 }
704 ShaderStages::all()
705 }
706 E::Unary { op, expr } => {
707 use crate::UnaryOperator as Uo;
708 let inner = &resolver[expr];
709 match (op, inner.scalar_kind()) {
710 (Uo::Negate, Some(Sk::Float | Sk::Sint))
711 | (Uo::LogicalNot, Some(Sk::Bool))
712 | (Uo::BitwiseNot, Some(Sk::Sint | Sk::Uint)) => {}
713 other => {
714 log::error!("Op {:?} kind {:?}", op, other);
715 return Err(ExpressionError::InvalidUnaryOperandType(op, expr));
716 }
717 }
718 ShaderStages::all()
719 }
720 E::Binary { op, left, right } => {
721 use crate::BinaryOperator as Bo;
722 let left_inner = &resolver[left];
723 let right_inner = &resolver[right];
724 let good = match op {
725 Bo::Add | Bo::Subtract => match *left_inner {
726 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
727 Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
728 Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
729 },
730 Ti::Matrix { .. } => left_inner == right_inner,
731 _ => false,
732 },
733 Bo::Divide | Bo::Modulo => match *left_inner {
734 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
735 Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
736 Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
737 },
738 _ => false,
739 },
740 Bo::Multiply => {
741 let kind_allowed = match left_inner.scalar_kind() {
742 Some(Sk::Uint | Sk::Sint | Sk::Float) => true,
743 Some(Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat) | None => false,
744 };
745 let types_match = match (left_inner, right_inner) {
746 (&Ti::Scalar(scalar1), &Ti::Scalar(scalar2))
748 | (
749 &Ti::Vector {
750 scalar: scalar1, ..
751 },
752 &Ti::Scalar(scalar2),
753 )
754 | (
755 &Ti::Scalar(scalar1),
756 &Ti::Vector {
757 scalar: scalar2, ..
758 },
759 ) => scalar1 == scalar2,
760 (
762 &Ti::Scalar(Sc {
763 kind: Sk::Float, ..
764 }),
765 &Ti::Matrix { .. },
766 )
767 | (
768 &Ti::Matrix { .. },
769 &Ti::Scalar(Sc {
770 kind: Sk::Float, ..
771 }),
772 ) => true,
773 (
775 &Ti::Vector {
776 size: size1,
777 scalar: scalar1,
778 },
779 &Ti::Vector {
780 size: size2,
781 scalar: scalar2,
782 },
783 ) => scalar1 == scalar2 && size1 == size2,
784 (
786 &Ti::Matrix { columns, .. },
787 &Ti::Vector {
788 size,
789 scalar:
790 Sc {
791 kind: Sk::Float, ..
792 },
793 },
794 ) => columns == size,
795 (
797 &Ti::Vector {
798 size,
799 scalar:
800 Sc {
801 kind: Sk::Float, ..
802 },
803 },
804 &Ti::Matrix { rows, .. },
805 ) => size == rows,
806 (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => {
807 columns == rows
808 }
809 _ => false,
810 };
811 let left_width = left_inner.scalar_width().unwrap_or(0);
812 let right_width = right_inner.scalar_width().unwrap_or(0);
813 kind_allowed && types_match && left_width == right_width
814 }
815 Bo::Equal | Bo::NotEqual => left_inner.is_sized() && left_inner == right_inner,
816 Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => {
817 match *left_inner {
818 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
819 Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
820 Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
821 },
822 ref other => {
823 log::error!("Op {:?} left type {:?}", op, other);
824 false
825 }
826 }
827 }
828 Bo::LogicalAnd | Bo::LogicalOr => match *left_inner {
829 Ti::Scalar(Sc { kind: Sk::Bool, .. })
830 | Ti::Vector {
831 scalar: Sc { kind: Sk::Bool, .. },
832 ..
833 } => left_inner == right_inner,
834 ref other => {
835 log::error!("Op {:?} left type {:?}", op, other);
836 false
837 }
838 },
839 Bo::And | Bo::InclusiveOr => match *left_inner {
840 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
841 Sk::Bool | Sk::Sint | Sk::Uint => left_inner == right_inner,
842 Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
843 },
844 ref other => {
845 log::error!("Op {:?} left type {:?}", op, other);
846 false
847 }
848 },
849 Bo::ExclusiveOr => match *left_inner {
850 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
851 Sk::Sint | Sk::Uint => left_inner == right_inner,
852 Sk::Bool | Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
853 },
854 ref other => {
855 log::error!("Op {:?} left type {:?}", op, other);
856 false
857 }
858 },
859 Bo::ShiftLeft | Bo::ShiftRight => {
860 let (base_size, base_scalar) = match *left_inner {
861 Ti::Scalar(scalar) => (Ok(None), scalar),
862 Ti::Vector { size, scalar } => (Ok(Some(size)), scalar),
863 ref other => {
864 log::error!("Op {:?} base type {:?}", op, other);
865 (Err(()), Sc::BOOL)
866 }
867 };
868 let shift_size = match *right_inner {
869 Ti::Scalar(Sc { kind: Sk::Uint, .. }) => Ok(None),
870 Ti::Vector {
871 size,
872 scalar: Sc { kind: Sk::Uint, .. },
873 } => Ok(Some(size)),
874 ref other => {
875 log::error!("Op {:?} shift type {:?}", op, other);
876 Err(())
877 }
878 };
879 match base_scalar.kind {
880 Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size,
881 Sk::Float | Sk::AbstractInt | Sk::AbstractFloat | Sk::Bool => false,
882 }
883 }
884 };
885 if !good {
886 log::error!(
887 "Left: {:?} of type {:?}",
888 function.expressions[left],
889 left_inner
890 );
891 log::error!(
892 "Right: {:?} of type {:?}",
893 function.expressions[right],
894 right_inner
895 );
896 return Err(ExpressionError::InvalidBinaryOperandTypes {
897 op,
898 lhs_expr: left,
899 lhs_type: left_inner.clone(),
900 rhs_expr: right,
901 rhs_type: right_inner.clone(),
902 });
903 }
904 ShaderStages::all()
905 }
906 E::Select {
907 condition,
908 accept,
909 reject,
910 } => {
911 let accept_inner = &resolver[accept];
912 let reject_inner = &resolver[reject];
913 let condition_ty = &resolver[condition];
914 let condition_good = match *condition_ty {
915 Ti::Scalar(Sc {
916 kind: Sk::Bool,
917 width: _,
918 }) => {
919 match *accept_inner {
922 Ti::Scalar { .. } | Ti::Vector { .. } => true,
923 _ => false,
924 }
925 }
926 Ti::Vector {
927 size,
928 scalar:
929 Sc {
930 kind: Sk::Bool,
931 width: _,
932 },
933 } => match *accept_inner {
934 Ti::Vector {
935 size: other_size, ..
936 } => size == other_size,
937 _ => false,
938 },
939 _ => false,
940 };
941 if accept_inner != reject_inner {
942 return Err(ExpressionError::SelectValuesTypeMismatch {
943 accept: accept_inner.clone(),
944 reject: reject_inner.clone(),
945 });
946 }
947 if !condition_good {
948 return Err(ExpressionError::SelectConditionNotABool {
949 actual: condition_ty.clone(),
950 });
951 }
952 ShaderStages::all()
953 }
954 E::Derivative { expr, .. } => {
955 match resolver[expr] {
956 Ti::Scalar(Sc {
957 kind: Sk::Float, ..
958 })
959 | Ti::Vector {
960 scalar:
961 Sc {
962 kind: Sk::Float, ..
963 },
964 ..
965 } => {}
966 _ => return Err(ExpressionError::InvalidDerivative),
967 }
968 ShaderStages::FRAGMENT
969 }
970 E::Relational { fun, argument } => {
971 use crate::RelationalFunction as Rf;
972 let argument_inner = &resolver[argument];
973 match fun {
974 Rf::All | Rf::Any => match *argument_inner {
975 Ti::Vector {
976 scalar: Sc { kind: Sk::Bool, .. },
977 ..
978 } => {}
979 ref other => {
980 log::error!("All/Any of type {:?}", other);
981 return Err(ExpressionError::InvalidBooleanVector(argument));
982 }
983 },
984 Rf::IsNan | Rf::IsInf => match *argument_inner {
985 Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
986 if scalar.kind == Sk::Float => {}
987 ref other => {
988 log::error!("Float test of type {:?}", other);
989 return Err(ExpressionError::InvalidFloatArgument(argument));
990 }
991 },
992 }
993 ShaderStages::all()
994 }
995 E::Math {
996 fun,
997 arg,
998 arg1,
999 arg2,
1000 arg3,
1001 } => {
1002 use crate::MathFunction as Mf;
1003
1004 let resolve = |arg| &resolver[arg];
1005 let arg_ty = resolve(arg);
1006 let arg1_ty = arg1.map(resolve);
1007 let arg2_ty = arg2.map(resolve);
1008 let arg3_ty = arg3.map(resolve);
1009 match fun {
1010 Mf::Abs => {
1011 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1012 return Err(ExpressionError::WrongArgumentCount(fun));
1013 }
1014 let good = match *arg_ty {
1015 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
1016 scalar.kind != Sk::Bool
1017 }
1018 _ => false,
1019 };
1020 if !good {
1021 return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
1022 }
1023 }
1024 Mf::Min | Mf::Max => {
1025 let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1026 (Some(ty1), None, None) => ty1,
1027 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1028 };
1029 let good = match *arg_ty {
1030 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
1031 scalar.kind != Sk::Bool
1032 }
1033 _ => false,
1034 };
1035 if !good {
1036 return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
1037 }
1038 if arg1_ty != arg_ty {
1039 return Err(ExpressionError::InvalidArgumentType(
1040 fun,
1041 1,
1042 arg1.unwrap(),
1043 ));
1044 }
1045 }
1046 Mf::Clamp => {
1047 let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1048 (Some(ty1), Some(ty2), None) => (ty1, ty2),
1049 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1050 };
1051 let good = match *arg_ty {
1052 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
1053 scalar.kind != Sk::Bool
1054 }
1055 _ => false,
1056 };
1057 if !good {
1058 return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
1059 }
1060 if arg1_ty != arg_ty {
1061 return Err(ExpressionError::InvalidArgumentType(
1062 fun,
1063 1,
1064 arg1.unwrap(),
1065 ));
1066 }
1067 if arg2_ty != arg_ty {
1068 return Err(ExpressionError::InvalidArgumentType(
1069 fun,
1070 2,
1071 arg2.unwrap(),
1072 ));
1073 }
1074 }
1075 Mf::Saturate
1076 | Mf::Cos
1077 | Mf::Cosh
1078 | Mf::Sin
1079 | Mf::Sinh
1080 | Mf::Tan
1081 | Mf::Tanh
1082 | Mf::Acos
1083 | Mf::Asin
1084 | Mf::Atan
1085 | Mf::Asinh
1086 | Mf::Acosh
1087 | Mf::Atanh
1088 | Mf::Radians
1089 | Mf::Degrees
1090 | Mf::Ceil
1091 | Mf::Floor
1092 | Mf::Round
1093 | Mf::Fract
1094 | Mf::Trunc
1095 | Mf::Exp
1096 | Mf::Exp2
1097 | Mf::Log
1098 | Mf::Log2
1099 | Mf::Length
1100 | Mf::Sqrt
1101 | Mf::InverseSqrt => {
1102 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1103 return Err(ExpressionError::WrongArgumentCount(fun));
1104 }
1105 match *arg_ty {
1106 Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1107 if scalar.kind == Sk::Float => {}
1108 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1109 }
1110 }
1111 Mf::Sign => {
1112 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1113 return Err(ExpressionError::WrongArgumentCount(fun));
1114 }
1115 match *arg_ty {
1116 Ti::Scalar(Sc {
1117 kind: Sk::Float | Sk::Sint,
1118 ..
1119 })
1120 | Ti::Vector {
1121 scalar:
1122 Sc {
1123 kind: Sk::Float | Sk::Sint,
1124 ..
1125 },
1126 ..
1127 } => {}
1128 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1129 }
1130 }
1131 Mf::Atan2 | Mf::Pow | Mf::Distance | Mf::Step => {
1132 let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1133 (Some(ty1), None, None) => ty1,
1134 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1135 };
1136 match *arg_ty {
1137 Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1138 if scalar.kind == Sk::Float => {}
1139 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1140 }
1141 if arg1_ty != arg_ty {
1142 return Err(ExpressionError::InvalidArgumentType(
1143 fun,
1144 1,
1145 arg1.unwrap(),
1146 ));
1147 }
1148 }
1149 Mf::Modf | Mf::Frexp => {
1150 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1151 return Err(ExpressionError::WrongArgumentCount(fun));
1152 }
1153 if !matches!(*arg_ty,
1154 Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1155 if scalar.kind == Sk::Float)
1156 {
1157 return Err(ExpressionError::InvalidArgumentType(fun, 1, arg));
1158 }
1159 }
1160 Mf::Ldexp => {
1161 let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1162 (Some(ty1), None, None) => ty1,
1163 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1164 };
1165 let size0 = match *arg_ty {
1166 Ti::Scalar(Sc {
1167 kind: Sk::Float, ..
1168 }) => None,
1169 Ti::Vector {
1170 scalar:
1171 Sc {
1172 kind: Sk::Float, ..
1173 },
1174 size,
1175 } => Some(size),
1176 _ => {
1177 return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
1178 }
1179 };
1180 let good = match *arg1_ty {
1181 Ti::Scalar(Sc { kind: Sk::Sint, .. }) if size0.is_none() => true,
1182 Ti::Vector {
1183 size,
1184 scalar: Sc { kind: Sk::Sint, .. },
1185 } if Some(size) == size0 => true,
1186 _ => false,
1187 };
1188 if !good {
1189 return Err(ExpressionError::InvalidArgumentType(
1190 fun,
1191 1,
1192 arg1.unwrap(),
1193 ));
1194 }
1195 }
1196 Mf::Dot => {
1197 let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1198 (Some(ty1), None, None) => ty1,
1199 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1200 };
1201 match *arg_ty {
1202 Ti::Vector {
1203 scalar:
1204 Sc {
1205 kind: Sk::Float | Sk::Sint | Sk::Uint,
1206 ..
1207 },
1208 ..
1209 } => {}
1210 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1211 }
1212 if arg1_ty != arg_ty {
1213 return Err(ExpressionError::InvalidArgumentType(
1214 fun,
1215 1,
1216 arg1.unwrap(),
1217 ));
1218 }
1219 }
1220 Mf::Outer | Mf::Reflect => {
1221 let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1222 (Some(ty1), None, None) => ty1,
1223 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1224 };
1225 match *arg_ty {
1226 Ti::Vector {
1227 scalar:
1228 Sc {
1229 kind: Sk::Float, ..
1230 },
1231 ..
1232 } => {}
1233 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1234 }
1235 if arg1_ty != arg_ty {
1236 return Err(ExpressionError::InvalidArgumentType(
1237 fun,
1238 1,
1239 arg1.unwrap(),
1240 ));
1241 }
1242 }
1243 Mf::Cross => {
1244 let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1245 (Some(ty1), None, None) => ty1,
1246 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1247 };
1248 match *arg_ty {
1249 Ti::Vector {
1250 scalar:
1251 Sc {
1252 kind: Sk::Float, ..
1253 },
1254 size: crate::VectorSize::Tri,
1255 } => {}
1256 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1257 }
1258 if arg1_ty != arg_ty {
1259 return Err(ExpressionError::InvalidArgumentType(
1260 fun,
1261 1,
1262 arg1.unwrap(),
1263 ));
1264 }
1265 }
1266 Mf::Refract => {
1267 let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1268 (Some(ty1), Some(ty2), None) => (ty1, ty2),
1269 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1270 };
1271
1272 match *arg_ty {
1273 Ti::Vector {
1274 scalar:
1275 Sc {
1276 kind: Sk::Float, ..
1277 },
1278 ..
1279 } => {}
1280 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1281 }
1282
1283 if arg1_ty != arg_ty {
1284 return Err(ExpressionError::InvalidArgumentType(
1285 fun,
1286 1,
1287 arg1.unwrap(),
1288 ));
1289 }
1290
1291 match (arg_ty, arg2_ty) {
1292 (
1293 &Ti::Vector {
1294 scalar:
1295 Sc {
1296 width: vector_width,
1297 ..
1298 },
1299 ..
1300 },
1301 &Ti::Scalar(Sc {
1302 width: scalar_width,
1303 kind: Sk::Float,
1304 }),
1305 ) if vector_width == scalar_width => {}
1306 _ => {
1307 return Err(ExpressionError::InvalidArgumentType(
1308 fun,
1309 2,
1310 arg2.unwrap(),
1311 ))
1312 }
1313 }
1314 }
1315 Mf::Normalize => {
1316 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1317 return Err(ExpressionError::WrongArgumentCount(fun));
1318 }
1319 match *arg_ty {
1320 Ti::Vector {
1321 scalar:
1322 Sc {
1323 kind: Sk::Float, ..
1324 },
1325 ..
1326 } => {}
1327 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1328 }
1329 }
1330 Mf::FaceForward | Mf::Fma | Mf::SmoothStep => {
1331 let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1332 (Some(ty1), Some(ty2), None) => (ty1, ty2),
1333 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1334 };
1335 match *arg_ty {
1336 Ti::Scalar(Sc {
1337 kind: Sk::Float, ..
1338 })
1339 | Ti::Vector {
1340 scalar:
1341 Sc {
1342 kind: Sk::Float, ..
1343 },
1344 ..
1345 } => {}
1346 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1347 }
1348 if arg1_ty != arg_ty {
1349 return Err(ExpressionError::InvalidArgumentType(
1350 fun,
1351 1,
1352 arg1.unwrap(),
1353 ));
1354 }
1355 if arg2_ty != arg_ty {
1356 return Err(ExpressionError::InvalidArgumentType(
1357 fun,
1358 2,
1359 arg2.unwrap(),
1360 ));
1361 }
1362 }
1363 Mf::Mix => {
1364 let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1365 (Some(ty1), Some(ty2), None) => (ty1, ty2),
1366 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1367 };
1368 let arg_width = match *arg_ty {
1369 Ti::Scalar(Sc {
1370 kind: Sk::Float,
1371 width,
1372 })
1373 | Ti::Vector {
1374 scalar:
1375 Sc {
1376 kind: Sk::Float,
1377 width,
1378 },
1379 ..
1380 } => width,
1381 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1382 };
1383 if arg1_ty != arg_ty {
1384 return Err(ExpressionError::InvalidArgumentType(
1385 fun,
1386 1,
1387 arg1.unwrap(),
1388 ));
1389 }
1390 match *arg2_ty {
1392 Ti::Scalar(Sc {
1393 kind: Sk::Float,
1394 width,
1395 }) if width == arg_width => {}
1396 _ if arg2_ty == arg_ty => {}
1397 _ => {
1398 return Err(ExpressionError::InvalidArgumentType(
1399 fun,
1400 2,
1401 arg2.unwrap(),
1402 ));
1403 }
1404 }
1405 }
1406 Mf::Inverse | Mf::Determinant => {
1407 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1408 return Err(ExpressionError::WrongArgumentCount(fun));
1409 }
1410 let good = match *arg_ty {
1411 Ti::Matrix { columns, rows, .. } => columns == rows,
1412 _ => false,
1413 };
1414 if !good {
1415 return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
1416 }
1417 }
1418 Mf::Transpose => {
1419 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1420 return Err(ExpressionError::WrongArgumentCount(fun));
1421 }
1422 match *arg_ty {
1423 Ti::Matrix { .. } => {}
1424 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1425 }
1426 }
1427 Mf::QuantizeToF16 => {
1428 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1429 return Err(ExpressionError::WrongArgumentCount(fun));
1430 }
1431 match *arg_ty {
1432 Ti::Scalar(Sc {
1433 kind: Sk::Float,
1434 width: 4,
1435 })
1436 | Ti::Vector {
1437 scalar:
1438 Sc {
1439 kind: Sk::Float,
1440 width: 4,
1441 },
1442 ..
1443 } => {}
1444 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1445 }
1446 }
1447 Mf::CountLeadingZeros
1449 | Mf::CountTrailingZeros
1450 | Mf::CountOneBits
1451 | Mf::ReverseBits
1452 | Mf::FirstLeadingBit
1453 | Mf::FirstTrailingBit => {
1454 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1455 return Err(ExpressionError::WrongArgumentCount(fun));
1456 }
1457 match *arg_ty {
1458 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
1459 Sk::Sint | Sk::Uint => {
1460 if scalar.width != 4 {
1461 return Err(ExpressionError::UnsupportedWidth(
1462 fun,
1463 scalar.kind,
1464 scalar.width,
1465 ));
1466 }
1467 }
1468 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1469 },
1470 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1471 }
1472 }
1473 Mf::InsertBits => {
1474 let (arg1_ty, arg2_ty, arg3_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1475 (Some(ty1), Some(ty2), Some(ty3)) => (ty1, ty2, ty3),
1476 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1477 };
1478 match *arg_ty {
1479 Ti::Scalar(Sc {
1480 kind: Sk::Sint | Sk::Uint,
1481 ..
1482 })
1483 | Ti::Vector {
1484 scalar:
1485 Sc {
1486 kind: Sk::Sint | Sk::Uint,
1487 ..
1488 },
1489 ..
1490 } => {}
1491 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1492 }
1493 if arg1_ty != arg_ty {
1494 return Err(ExpressionError::InvalidArgumentType(
1495 fun,
1496 1,
1497 arg1.unwrap(),
1498 ));
1499 }
1500 match *arg2_ty {
1501 Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1502 _ => {
1503 return Err(ExpressionError::InvalidArgumentType(
1504 fun,
1505 2,
1506 arg2.unwrap(),
1507 ))
1508 }
1509 }
1510 match *arg3_ty {
1511 Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1512 _ => {
1513 return Err(ExpressionError::InvalidArgumentType(
1514 fun,
1515 2,
1516 arg3.unwrap(),
1517 ))
1518 }
1519 }
1520 for &arg in [arg_ty, arg1_ty, arg2_ty, arg3_ty].iter() {
1522 match *arg {
1523 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
1524 if scalar.width != 4 {
1525 return Err(ExpressionError::UnsupportedWidth(
1526 fun,
1527 scalar.kind,
1528 scalar.width,
1529 ));
1530 }
1531 }
1532 _ => {}
1533 }
1534 }
1535 }
1536 Mf::ExtractBits => {
1537 let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1538 (Some(ty1), Some(ty2), None) => (ty1, ty2),
1539 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1540 };
1541 match *arg_ty {
1542 Ti::Scalar(Sc {
1543 kind: Sk::Sint | Sk::Uint,
1544 ..
1545 })
1546 | Ti::Vector {
1547 scalar:
1548 Sc {
1549 kind: Sk::Sint | Sk::Uint,
1550 ..
1551 },
1552 ..
1553 } => {}
1554 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1555 }
1556 match *arg1_ty {
1557 Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1558 _ => {
1559 return Err(ExpressionError::InvalidArgumentType(
1560 fun,
1561 2,
1562 arg1.unwrap(),
1563 ))
1564 }
1565 }
1566 match *arg2_ty {
1567 Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1568 _ => {
1569 return Err(ExpressionError::InvalidArgumentType(
1570 fun,
1571 2,
1572 arg2.unwrap(),
1573 ))
1574 }
1575 }
1576 for &arg in [arg_ty, arg1_ty, arg2_ty].iter() {
1578 match *arg {
1579 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
1580 if scalar.width != 4 {
1581 return Err(ExpressionError::UnsupportedWidth(
1582 fun,
1583 scalar.kind,
1584 scalar.width,
1585 ));
1586 }
1587 }
1588 _ => {}
1589 }
1590 }
1591 }
1592 Mf::Pack2x16unorm | Mf::Pack2x16snorm | Mf::Pack2x16float => {
1593 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1594 return Err(ExpressionError::WrongArgumentCount(fun));
1595 }
1596 match *arg_ty {
1597 Ti::Vector {
1598 size: crate::VectorSize::Bi,
1599 scalar:
1600 Sc {
1601 kind: Sk::Float, ..
1602 },
1603 } => {}
1604 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1605 }
1606 }
1607 Mf::Pack4x8snorm | Mf::Pack4x8unorm => {
1608 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1609 return Err(ExpressionError::WrongArgumentCount(fun));
1610 }
1611 match *arg_ty {
1612 Ti::Vector {
1613 size: crate::VectorSize::Quad,
1614 scalar:
1615 Sc {
1616 kind: Sk::Float, ..
1617 },
1618 } => {}
1619 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1620 }
1621 }
1622 mf @ (Mf::Pack4xI8 | Mf::Pack4xU8) => {
1623 let scalar_kind = match mf {
1624 Mf::Pack4xI8 => Sk::Sint,
1625 Mf::Pack4xU8 => Sk::Uint,
1626 _ => unreachable!(),
1627 };
1628 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1629 return Err(ExpressionError::WrongArgumentCount(fun));
1630 }
1631 match *arg_ty {
1632 Ti::Vector {
1633 size: crate::VectorSize::Quad,
1634 scalar: Sc { kind, .. },
1635 } if kind == scalar_kind => {}
1636 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1637 }
1638 }
1639 Mf::Unpack2x16float
1640 | Mf::Unpack2x16snorm
1641 | Mf::Unpack2x16unorm
1642 | Mf::Unpack4x8snorm
1643 | Mf::Unpack4x8unorm
1644 | Mf::Unpack4xI8
1645 | Mf::Unpack4xU8 => {
1646 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1647 return Err(ExpressionError::WrongArgumentCount(fun));
1648 }
1649 match *arg_ty {
1650 Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1651 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1652 }
1653 }
1654 }
1655 ShaderStages::all()
1656 }
1657 E::As {
1658 expr,
1659 kind,
1660 convert,
1661 } => {
1662 let mut base_scalar = match resolver[expr] {
1663 crate::TypeInner::Scalar(scalar) | crate::TypeInner::Vector { scalar, .. } => {
1664 scalar
1665 }
1666 crate::TypeInner::Matrix { scalar, .. } => scalar,
1667 _ => return Err(ExpressionError::InvalidCastArgument),
1668 };
1669 base_scalar.kind = kind;
1670 if let Some(width) = convert {
1671 base_scalar.width = width;
1672 }
1673 if self.check_width(base_scalar).is_err() {
1674 return Err(ExpressionError::InvalidCastArgument);
1675 }
1676 ShaderStages::all()
1677 }
1678 E::CallResult(function) => mod_info.functions[function.index()].available_stages,
1679 E::AtomicResult { .. } => {
1680 ShaderStages::all()
1685 }
1686 E::WorkGroupUniformLoadResult { ty } => {
1687 if self.types[ty.index()]
1688 .flags
1689 .contains(TypeFlags::SIZED | TypeFlags::CONSTRUCTIBLE)
1692 {
1693 ShaderStages::COMPUTE
1694 } else {
1695 return Err(ExpressionError::InvalidWorkGroupUniformLoadResultType(ty));
1696 }
1697 }
1698 E::ArrayLength(expr) => match resolver[expr] {
1699 Ti::Pointer { base, .. } => {
1700 let base_ty = &resolver.types[base];
1701 if let Ti::Array {
1702 size: crate::ArraySize::Dynamic,
1703 ..
1704 } = base_ty.inner
1705 {
1706 ShaderStages::all()
1707 } else {
1708 return Err(ExpressionError::InvalidArrayType(expr));
1709 }
1710 }
1711 ref other => {
1712 log::error!("Array length of {:?}", other);
1713 return Err(ExpressionError::InvalidArrayType(expr));
1714 }
1715 },
1716 E::RayQueryProceedResult => ShaderStages::all(),
1717 E::RayQueryGetIntersection {
1718 query,
1719 committed: _,
1720 } => match resolver[query] {
1721 Ti::Pointer {
1722 base,
1723 space: crate::AddressSpace::Function,
1724 } => match resolver.types[base].inner {
1725 Ti::RayQuery => ShaderStages::all(),
1726 ref other => {
1727 log::error!("Intersection result of a pointer to {:?}", other);
1728 return Err(ExpressionError::InvalidRayQueryType(query));
1729 }
1730 },
1731 ref other => {
1732 log::error!("Intersection result of {:?}", other);
1733 return Err(ExpressionError::InvalidRayQueryType(query));
1734 }
1735 },
1736 E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages,
1737 };
1738 Ok(stages)
1739 }
1740
1741 fn global_var_ty(
1742 module: &crate::Module,
1743 function: &crate::Function,
1744 expr: Handle<crate::Expression>,
1745 ) -> Result<Handle<crate::Type>, ExpressionError> {
1746 use crate::Expression as Ex;
1747
1748 match function.expressions[expr] {
1749 Ex::GlobalVariable(var_handle) => Ok(module.global_variables[var_handle].ty),
1750 Ex::FunctionArgument(i) => Ok(function.arguments[i as usize].ty),
1751 Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
1752 match function.expressions[base] {
1753 Ex::GlobalVariable(var_handle) => {
1754 let array_ty = module.global_variables[var_handle].ty;
1755
1756 match module.types[array_ty].inner {
1757 crate::TypeInner::BindingArray { base, .. } => Ok(base),
1758 _ => Err(ExpressionError::ExpectedBindingArrayType(array_ty)),
1759 }
1760 }
1761 _ => Err(ExpressionError::ExpectedGlobalVariable),
1762 }
1763 }
1764 _ => Err(ExpressionError::ExpectedGlobalVariable),
1765 }
1766 }
1767
1768 pub fn validate_literal(&self, literal: crate::Literal) -> Result<(), LiteralError> {
1769 self.check_width(literal.scalar())?;
1770 check_literal_value(literal)?;
1771
1772 Ok(())
1773 }
1774}
1775
1776pub fn check_literal_value(literal: crate::Literal) -> Result<(), LiteralError> {
1777 let is_nan = match literal {
1778 crate::Literal::F64(v) => v.is_nan(),
1779 crate::Literal::F32(v) => v.is_nan(),
1780 _ => false,
1781 };
1782 if is_nan {
1783 return Err(LiteralError::NaN);
1784 }
1785
1786 let is_infinite = match literal {
1787 crate::Literal::F64(v) => v.is_infinite(),
1788 crate::Literal::F32(v) => v.is_infinite(),
1789 _ => false,
1790 };
1791 if is_infinite {
1792 return Err(LiteralError::Infinity);
1793 }
1794
1795 Ok(())
1796}
1797
/// Test helper: validate a throwaway module whose only function contains `expr`.
///
/// The expression is appended to a default function, covered by a single
/// `Emit` statement, and the function is added to a default module. The module
/// is then validated with only `ValidationFlags::EXPRESSIONS` enabled and the
/// given capabilities.
#[cfg(test)]
fn validate_with_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    use crate::span::Span;

    let mut host = crate::Function::default();
    host.expressions.append(expr, Span::default());
    // Emit everything in the expression arena, i.e. just `expr`.
    let emitted = host.expressions.range_from(0);
    host.body
        .push(crate::Statement::Emit(emitted), Span::default());

    let mut module = crate::Module::default();
    module.functions.append(host, Span::default());

    super::Validator::new(super::ValidationFlags::EXPRESSIONS, caps).validate(&module)
}
1820
/// Test helper: validate a throwaway module whose only global expression is
/// `expr`.
///
/// The expression is appended to the `global_expressions` of a default module,
/// which is then validated with only `ValidationFlags::CONSTANTS` enabled and
/// the given capabilities.
#[cfg(test)]
fn validate_with_const_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    use crate::span::Span;

    let module = {
        let mut building = crate::Module::default();
        building.global_expressions.append(expr, Span::default());
        building
    };

    super::Validator::new(super::ValidationFlags::CONSTANTS, caps).validate(&module)
}
1836
/// An `f64` literal inside a function body must be rejected with a
/// missing-capability width error unless `Capabilities::FLOAT64` is enabled.
#[test]
fn f64_runtime_literals() {
    // Fresh copy of the same literal expression for each validation run.
    let literal = || crate::Expression::Literal(crate::Literal::F64(0.57721_56649));

    // Default capabilities: the literal's width is not supported.
    let error = validate_with_expression(literal(), super::Capabilities::default())
        .unwrap_err()
        .into_inner();
    let is_missing_f64 = matches!(
        error,
        crate::valid::ValidationError::Function {
            source: super::FunctionError::Expression {
                source: ExpressionError::Literal(LiteralError::Width(
                    super::r#type::WidthError::MissingCapability {
                        name: "f64",
                        flag: "FLOAT64",
                    }
                ),),
                ..
            },
            ..
        }
    );
    assert!(is_missing_f64);

    // With FLOAT64 added, the very same literal validates.
    let with_f64 = super::Capabilities::default() | super::Capabilities::FLOAT64;
    assert!(validate_with_expression(literal(), with_f64).is_ok());
}
1867
/// An `f64` literal used as a global (constant) expression must be rejected
/// with a missing-capability width error unless `Capabilities::FLOAT64` is
/// enabled.
#[test]
fn f64_const_literals() {
    // Fresh copy of the same literal expression for each validation run.
    let literal = || crate::Expression::Literal(crate::Literal::F64(0.57721_56649));

    // Default capabilities: the literal's width is not supported.
    let error = validate_with_const_expression(literal(), super::Capabilities::default())
        .unwrap_err()
        .into_inner();
    let is_missing_f64 = matches!(
        error,
        crate::valid::ValidationError::ConstExpression {
            source: ConstExpressionError::Literal(LiteralError::Width(
                super::r#type::WidthError::MissingCapability {
                    name: "f64",
                    flag: "FLOAT64",
                }
            )),
            ..
        }
    );
    assert!(is_missing_f64);

    // With FLOAT64 added, the very same literal validates.
    let with_f64 = super::Capabilities::default() | super::Capabilities::FLOAT64;
    assert!(validate_with_const_expression(literal(), with_f64).is_ok());
}