1use super::{compose::validate_compose, FunctionInfo, ModuleInfo, ShaderStages, TypeFlags};
2use crate::arena::UniqueArena;
3
4use crate::{
5 arena::Handle,
6 proc::{IndexableLengthError, ResolveError},
7};
8
9#[derive(Clone, Debug, thiserror::Error)]
10#[cfg_attr(test, derive(PartialEq))]
11pub enum ExpressionError {
12 #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")]
13 NotInScope,
14 #[error("Base type {0:?} is not compatible with this expression")]
15 InvalidBaseType(Handle<crate::Expression>),
16 #[error("Accessing with index {0:?} can't be done")]
17 InvalidIndexType(Handle<crate::Expression>),
18 #[error("Accessing {0:?} via a negative index is invalid")]
19 NegativeIndex(Handle<crate::Expression>),
20 #[error("Accessing index {1} is out of {0:?} bounds")]
21 IndexOutOfBounds(Handle<crate::Expression>, u32),
22 #[error("Function argument {0:?} doesn't exist")]
23 FunctionArgumentDoesntExist(u32),
24 #[error("Loading of {0:?} can't be done")]
25 InvalidPointerType(Handle<crate::Expression>),
26 #[error("Array length of {0:?} can't be done")]
27 InvalidArrayType(Handle<crate::Expression>),
28 #[error("Get intersection of {0:?} can't be done")]
29 InvalidRayQueryType(Handle<crate::Expression>),
30 #[error("Splatting {0:?} can't be done")]
31 InvalidSplatType(Handle<crate::Expression>),
32 #[error("Swizzling {0:?} can't be done")]
33 InvalidVectorType(Handle<crate::Expression>),
34 #[error("Swizzle component {0:?} is outside of vector size {1:?}")]
35 InvalidSwizzleComponent(crate::SwizzleComponent, crate::VectorSize),
36 #[error(transparent)]
37 Compose(#[from] super::ComposeError),
38 #[error(transparent)]
39 IndexableLength(#[from] IndexableLengthError),
40 #[error("Operation {0:?} can't work with {1:?}")]
41 InvalidUnaryOperandType(crate::UnaryOperator, Handle<crate::Expression>),
42 #[error("Operation {0:?} can't work with {1:?} and {2:?}")]
43 InvalidBinaryOperandTypes(
44 crate::BinaryOperator,
45 Handle<crate::Expression>,
46 Handle<crate::Expression>,
47 ),
48 #[error("Selecting is not possible")]
49 InvalidSelectTypes,
50 #[error("Relational argument {0:?} is not a boolean vector")]
51 InvalidBooleanVector(Handle<crate::Expression>),
52 #[error("Relational argument {0:?} is not a float")]
53 InvalidFloatArgument(Handle<crate::Expression>),
54 #[error("Type resolution failed")]
55 Type(#[from] ResolveError),
56 #[error("Not a global variable")]
57 ExpectedGlobalVariable,
58 #[error("Not a global variable or a function argument")]
59 ExpectedGlobalOrArgument,
60 #[error("Needs to be an binding array instead of {0:?}")]
61 ExpectedBindingArrayType(Handle<crate::Type>),
62 #[error("Needs to be an image instead of {0:?}")]
63 ExpectedImageType(Handle<crate::Type>),
64 #[error("Needs to be an image instead of {0:?}")]
65 ExpectedSamplerType(Handle<crate::Type>),
66 #[error("Unable to operate on image class {0:?}")]
67 InvalidImageClass(crate::ImageClass),
68 #[error("Derivatives can only be taken from scalar and vector floats")]
69 InvalidDerivative,
70 #[error("Image array index parameter is misplaced")]
71 InvalidImageArrayIndex,
72 #[error("Inappropriate sample or level-of-detail index for texel access")]
73 InvalidImageOtherIndex,
74 #[error("Image array index type of {0:?} is not an integer scalar")]
75 InvalidImageArrayIndexType(Handle<crate::Expression>),
76 #[error("Image sample or level-of-detail index's type of {0:?} is not an integer scalar")]
77 InvalidImageOtherIndexType(Handle<crate::Expression>),
78 #[error("Image coordinate type of {1:?} does not match dimension {0:?}")]
79 InvalidImageCoordinateType(crate::ImageDimension, Handle<crate::Expression>),
80 #[error("Comparison sampling mismatch: image has class {image:?}, but the sampler is comparison={sampler}, and the reference was provided={has_ref}")]
81 ComparisonSamplingMismatch {
82 image: crate::ImageClass,
83 sampler: bool,
84 has_ref: bool,
85 },
86 #[error("Sample offset must be a const-expression")]
87 InvalidSampleOffsetExprType,
88 #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")]
89 InvalidSampleOffset(crate::ImageDimension, Handle<crate::Expression>),
90 #[error("Depth reference {0:?} is not a scalar float")]
91 InvalidDepthReference(Handle<crate::Expression>),
92 #[error("Depth sample level can only be Auto or Zero")]
93 InvalidDepthSampleLevel,
94 #[error("Gather level can only be Zero")]
95 InvalidGatherLevel,
96 #[error("Gather component {0:?} doesn't exist in the image")]
97 InvalidGatherComponent(crate::SwizzleComponent),
98 #[error("Gather can't be done for image dimension {0:?}")]
99 InvalidGatherDimension(crate::ImageDimension),
100 #[error("Sample level (exact) type {0:?} is not a scalar float")]
101 InvalidSampleLevelExactType(Handle<crate::Expression>),
102 #[error("Sample level (bias) type {0:?} is not a scalar float")]
103 InvalidSampleLevelBiasType(Handle<crate::Expression>),
104 #[error("Sample level (gradient) of {1:?} doesn't match the image dimension {0:?}")]
105 InvalidSampleLevelGradientType(crate::ImageDimension, Handle<crate::Expression>),
106 #[error("Unable to cast")]
107 InvalidCastArgument,
108 #[error("Invalid argument count for {0:?}")]
109 WrongArgumentCount(crate::MathFunction),
110 #[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")]
111 InvalidArgumentType(crate::MathFunction, u32, Handle<crate::Expression>),
112 #[error(
113 "workgroupUniformLoad result type can't be {0:?}. It can only be a constructible type."
114 )]
115 InvalidWorkGroupUniformLoadResultType(Handle<crate::Type>),
116 #[error("Shader requires capability {0:?}")]
117 MissingCapabilities(super::Capabilities),
118 #[error(transparent)]
119 Literal(#[from] LiteralError),
120 #[error("{0:?} is not supported for Width {2} {1:?} arguments yet, see https://github.com/gfx-rs/wgpu/issues/5276")]
121 UnsupportedWidth(crate::MathFunction, crate::ScalarKind, crate::Bytes),
122}
123
124#[derive(Clone, Debug, thiserror::Error)]
125#[cfg_attr(test, derive(PartialEq))]
126pub enum ConstExpressionError {
127 #[error("The expression is not a constant or override expression")]
128 NonConstOrOverride,
129 #[error("The expression is not a fully evaluated constant expression")]
130 NonFullyEvaluatedConst,
131 #[error(transparent)]
132 Compose(#[from] super::ComposeError),
133 #[error("Splatting {0:?} can't be done")]
134 InvalidSplatType(Handle<crate::Expression>),
135 #[error("Type resolution failed")]
136 Type(#[from] ResolveError),
137 #[error(transparent)]
138 Literal(#[from] LiteralError),
139 #[error(transparent)]
140 Width(#[from] super::r#type::WidthError),
141}
142
143#[derive(Clone, Debug, thiserror::Error)]
144#[cfg_attr(test, derive(PartialEq))]
145pub enum LiteralError {
146 #[error("Float literal is NaN")]
147 NaN,
148 #[error("Float literal is infinite")]
149 Infinity,
150 #[error(transparent)]
151 Width(#[from] super::r#type::WidthError),
152}
153
154struct ExpressionTypeResolver<'a> {
155 root: Handle<crate::Expression>,
156 types: &'a UniqueArena<crate::Type>,
157 info: &'a FunctionInfo,
158}
159
160impl<'a> std::ops::Index<Handle<crate::Expression>> for ExpressionTypeResolver<'a> {
161 type Output = crate::TypeInner;
162
163 #[allow(clippy::panic)]
164 fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
165 if handle < self.root {
166 self.info[handle].ty.inner_with(self.types)
167 } else {
168 panic!(
170 "Depends on {:?}, which has not been processed yet",
171 self.root
172 )
173 }
174 }
175}
176
177impl super::Validator {
178 pub(super) fn validate_const_expression(
179 &self,
180 handle: Handle<crate::Expression>,
181 gctx: crate::proc::GlobalCtx,
182 mod_info: &ModuleInfo,
183 global_expr_kind: &crate::proc::ExpressionKindTracker,
184 ) -> Result<(), ConstExpressionError> {
185 use crate::Expression as E;
186
187 if !global_expr_kind.is_const_or_override(handle) {
188 return Err(ConstExpressionError::NonConstOrOverride);
189 }
190
191 match gctx.global_expressions[handle] {
192 E::Literal(literal) => {
193 self.validate_literal(literal)?;
194 }
195 E::Constant(_) | E::ZeroValue(_) => {}
196 E::Compose { ref components, ty } => {
197 validate_compose(
198 ty,
199 gctx,
200 components.iter().map(|&handle| mod_info[handle].clone()),
201 )?;
202 }
203 E::Splat { value, .. } => match *mod_info[value].inner_with(gctx.types) {
204 crate::TypeInner::Scalar { .. } => {}
205 _ => return Err(ConstExpressionError::InvalidSplatType(value)),
206 },
207 _ if global_expr_kind.is_const(handle) || !self.allow_overrides => {
208 return Err(ConstExpressionError::NonFullyEvaluatedConst)
209 }
210 _ => {}
212 }
213
214 Ok(())
215 }
216
217 #[allow(clippy::too_many_arguments)]
218 pub(super) fn validate_expression(
219 &self,
220 root: Handle<crate::Expression>,
221 expression: &crate::Expression,
222 function: &crate::Function,
223 module: &crate::Module,
224 info: &FunctionInfo,
225 mod_info: &ModuleInfo,
226 global_expr_kind: &crate::proc::ExpressionKindTracker,
227 ) -> Result<ShaderStages, ExpressionError> {
228 use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti};
229
230 let resolver = ExpressionTypeResolver {
231 root,
232 types: &module.types,
233 info,
234 };
235
236 let stages = match *expression {
237 E::Access { base, index } => {
238 let base_type = &resolver[base];
239 match *base_type {
240 Ti::Matrix { .. }
241 | Ti::Vector { .. }
242 | Ti::Array { .. }
243 | Ti::Pointer { .. }
244 | Ti::ValuePointer { size: Some(_), .. }
245 | Ti::BindingArray { .. } => false,
246 ref other => {
247 log::error!("Indexing of {:?}", other);
248 return Err(ExpressionError::InvalidBaseType(base));
249 }
250 };
251 match resolver[index] {
252 Ti::Scalar(Sc {
254 kind: Sk::Sint | Sk::Uint,
255 ..
256 }) => {}
257 ref other => {
258 log::error!("Indexing by {:?}", other);
259 return Err(ExpressionError::InvalidIndexType(index));
260 }
261 }
262
263 if let crate::proc::IndexableLength::Known(known_length) =
266 base_type.indexable_length(module)?
267 {
268 match module
269 .to_ctx()
270 .eval_expr_to_u32_from(index, &function.expressions)
271 {
272 Ok(value) => {
273 if value >= known_length {
274 return Err(ExpressionError::IndexOutOfBounds(base, value));
275 }
276 }
277 Err(crate::proc::U32EvalError::Negative) => {
278 return Err(ExpressionError::NegativeIndex(base))
279 }
280 Err(crate::proc::U32EvalError::NonConst) => {}
281 }
282 }
283
284 ShaderStages::all()
285 }
286 E::AccessIndex { base, index } => {
287 fn resolve_index_limit(
288 module: &crate::Module,
289 top: Handle<crate::Expression>,
290 ty: &crate::TypeInner,
291 top_level: bool,
292 ) -> Result<u32, ExpressionError> {
293 let limit = match *ty {
294 Ti::Vector { size, .. }
295 | Ti::ValuePointer {
296 size: Some(size), ..
297 } => size as u32,
298 Ti::Matrix { columns, .. } => columns as u32,
299 Ti::Array {
300 size: crate::ArraySize::Constant(len),
301 ..
302 } => len.get(),
303 Ti::Array { .. } | Ti::BindingArray { .. } => u32::MAX, Ti::Pointer { base, .. } if top_level => {
305 resolve_index_limit(module, top, &module.types[base].inner, false)?
306 }
307 Ti::Struct { ref members, .. } => members.len() as u32,
308 ref other => {
309 log::error!("Indexing of {:?}", other);
310 return Err(ExpressionError::InvalidBaseType(top));
311 }
312 };
313 Ok(limit)
314 }
315
316 let limit = resolve_index_limit(module, base, &resolver[base], true)?;
317 if index >= limit {
318 return Err(ExpressionError::IndexOutOfBounds(base, limit));
319 }
320 ShaderStages::all()
321 }
322 E::Splat { size: _, value } => match resolver[value] {
323 Ti::Scalar { .. } => ShaderStages::all(),
324 ref other => {
325 log::error!("Splat scalar type {:?}", other);
326 return Err(ExpressionError::InvalidSplatType(value));
327 }
328 },
329 E::Swizzle {
330 size,
331 vector,
332 pattern,
333 } => {
334 let vec_size = match resolver[vector] {
335 Ti::Vector { size: vec_size, .. } => vec_size,
336 ref other => {
337 log::error!("Swizzle vector type {:?}", other);
338 return Err(ExpressionError::InvalidVectorType(vector));
339 }
340 };
341 for &sc in pattern[..size as usize].iter() {
342 if sc as u8 >= vec_size as u8 {
343 return Err(ExpressionError::InvalidSwizzleComponent(sc, vec_size));
344 }
345 }
346 ShaderStages::all()
347 }
348 E::Literal(literal) => {
349 self.validate_literal(literal)?;
350 ShaderStages::all()
351 }
352 E::Constant(_) | E::Override(_) | E::ZeroValue(_) => ShaderStages::all(),
353 E::Compose { ref components, ty } => {
354 validate_compose(
355 ty,
356 module.to_ctx(),
357 components.iter().map(|&handle| info[handle].ty.clone()),
358 )?;
359 ShaderStages::all()
360 }
361 E::FunctionArgument(index) => {
362 if index >= function.arguments.len() as u32 {
363 return Err(ExpressionError::FunctionArgumentDoesntExist(index));
364 }
365 ShaderStages::all()
366 }
367 E::GlobalVariable(_handle) => ShaderStages::all(),
368 E::LocalVariable(_handle) => ShaderStages::all(),
369 E::Load { pointer } => {
370 match resolver[pointer] {
371 Ti::Pointer { base, .. }
372 if self.types[base.index()]
373 .flags
374 .contains(TypeFlags::SIZED | TypeFlags::DATA) => {}
375 Ti::ValuePointer { .. } => {}
376 ref other => {
377 log::error!("Loading {:?}", other);
378 return Err(ExpressionError::InvalidPointerType(pointer));
379 }
380 }
381 ShaderStages::all()
382 }
383 E::ImageSample {
384 image,
385 sampler,
386 gather,
387 coordinate,
388 array_index,
389 offset,
390 level,
391 depth_ref,
392 } => {
393 let image_ty = Self::global_var_ty(module, function, image)?;
395 let sampler_ty = Self::global_var_ty(module, function, sampler)?;
396
397 let comparison = match module.types[sampler_ty].inner {
398 Ti::Sampler { comparison } => comparison,
399 _ => return Err(ExpressionError::ExpectedSamplerType(sampler_ty)),
400 };
401
402 let (class, dim) = match module.types[image_ty].inner {
403 Ti::Image {
404 class,
405 arrayed,
406 dim,
407 } => {
408 if arrayed != array_index.is_some() {
410 return Err(ExpressionError::InvalidImageArrayIndex);
411 }
412 if let Some(expr) = array_index {
413 match resolver[expr] {
414 Ti::Scalar(Sc {
415 kind: Sk::Sint | Sk::Uint,
416 ..
417 }) => {}
418 _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
419 }
420 }
421 (class, dim)
422 }
423 _ => return Err(ExpressionError::ExpectedImageType(image_ty)),
424 };
425
426 let image_depth = match class {
428 crate::ImageClass::Sampled {
429 kind: crate::ScalarKind::Float,
430 multi: false,
431 } => false,
432 crate::ImageClass::Sampled {
433 kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint,
434 multi: false,
435 } if gather.is_some() => false,
436 crate::ImageClass::Depth { multi: false } => true,
437 _ => return Err(ExpressionError::InvalidImageClass(class)),
438 };
439 if comparison != depth_ref.is_some() || (comparison && !image_depth) {
440 return Err(ExpressionError::ComparisonSamplingMismatch {
441 image: class,
442 sampler: comparison,
443 has_ref: depth_ref.is_some(),
444 });
445 }
446
447 let num_components = match dim {
449 crate::ImageDimension::D1 => 1,
450 crate::ImageDimension::D2 => 2,
451 crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3,
452 };
453 match resolver[coordinate] {
454 Ti::Scalar(Sc {
455 kind: Sk::Float, ..
456 }) if num_components == 1 => {}
457 Ti::Vector {
458 size,
459 scalar:
460 Sc {
461 kind: Sk::Float, ..
462 },
463 } if size as u32 == num_components => {}
464 _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)),
465 }
466
467 if let Some(const_expr) = offset {
469 if !global_expr_kind.is_const(const_expr) {
470 return Err(ExpressionError::InvalidSampleOffsetExprType);
471 }
472
473 match *mod_info[const_expr].inner_with(&module.types) {
474 Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {}
475 Ti::Vector {
476 size,
477 scalar: Sc { kind: Sk::Sint, .. },
478 } if size as u32 == num_components => {}
479 _ => {
480 return Err(ExpressionError::InvalidSampleOffset(dim, const_expr));
481 }
482 }
483 }
484
485 if let Some(expr) = depth_ref {
487 match resolver[expr] {
488 Ti::Scalar(Sc {
489 kind: Sk::Float, ..
490 }) => {}
491 _ => return Err(ExpressionError::InvalidDepthReference(expr)),
492 }
493 match level {
494 crate::SampleLevel::Auto | crate::SampleLevel::Zero => {}
495 _ => return Err(ExpressionError::InvalidDepthSampleLevel),
496 }
497 }
498
499 if let Some(component) = gather {
500 match dim {
501 crate::ImageDimension::D2 | crate::ImageDimension::Cube => {}
502 crate::ImageDimension::D1 | crate::ImageDimension::D3 => {
503 return Err(ExpressionError::InvalidGatherDimension(dim))
504 }
505 };
506 let max_component = match class {
507 crate::ImageClass::Depth { .. } => crate::SwizzleComponent::X,
508 _ => crate::SwizzleComponent::W,
509 };
510 if component > max_component {
511 return Err(ExpressionError::InvalidGatherComponent(component));
512 }
513 match level {
514 crate::SampleLevel::Zero => {}
515 _ => return Err(ExpressionError::InvalidGatherLevel),
516 }
517 }
518
519 match level {
521 crate::SampleLevel::Auto => ShaderStages::FRAGMENT,
522 crate::SampleLevel::Zero => ShaderStages::all(),
523 crate::SampleLevel::Exact(expr) => {
524 match resolver[expr] {
525 Ti::Scalar(Sc {
526 kind: Sk::Float, ..
527 }) => {}
528 _ => return Err(ExpressionError::InvalidSampleLevelExactType(expr)),
529 }
530 ShaderStages::all()
531 }
532 crate::SampleLevel::Bias(expr) => {
533 match resolver[expr] {
534 Ti::Scalar(Sc {
535 kind: Sk::Float, ..
536 }) => {}
537 _ => return Err(ExpressionError::InvalidSampleLevelBiasType(expr)),
538 }
539 ShaderStages::FRAGMENT
540 }
541 crate::SampleLevel::Gradient { x, y } => {
542 match resolver[x] {
543 Ti::Scalar(Sc {
544 kind: Sk::Float, ..
545 }) if num_components == 1 => {}
546 Ti::Vector {
547 size,
548 scalar:
549 Sc {
550 kind: Sk::Float, ..
551 },
552 } if size as u32 == num_components => {}
553 _ => {
554 return Err(ExpressionError::InvalidSampleLevelGradientType(dim, x))
555 }
556 }
557 match resolver[y] {
558 Ti::Scalar(Sc {
559 kind: Sk::Float, ..
560 }) if num_components == 1 => {}
561 Ti::Vector {
562 size,
563 scalar:
564 Sc {
565 kind: Sk::Float, ..
566 },
567 } if size as u32 == num_components => {}
568 _ => {
569 return Err(ExpressionError::InvalidSampleLevelGradientType(dim, y))
570 }
571 }
572 ShaderStages::all()
573 }
574 }
575 }
576 E::ImageLoad {
577 image,
578 coordinate,
579 array_index,
580 sample,
581 level,
582 } => {
583 let ty = Self::global_var_ty(module, function, image)?;
584 match module.types[ty].inner {
585 Ti::Image {
586 class,
587 arrayed,
588 dim,
589 } => {
590 match resolver[coordinate].image_storage_coordinates() {
591 Some(coord_dim) if coord_dim == dim => {}
592 _ => {
593 return Err(ExpressionError::InvalidImageCoordinateType(
594 dim, coordinate,
595 ))
596 }
597 };
598 if arrayed != array_index.is_some() {
599 return Err(ExpressionError::InvalidImageArrayIndex);
600 }
601 if let Some(expr) = array_index {
602 match resolver[expr] {
603 Ti::Scalar(Sc {
604 kind: Sk::Sint | Sk::Uint,
605 width: _,
606 }) => {}
607 _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
608 }
609 }
610
611 match (sample, class.is_multisampled()) {
612 (None, false) => {}
613 (Some(sample), true) => {
614 if resolver[sample].scalar_kind() != Some(Sk::Sint) {
615 return Err(ExpressionError::InvalidImageOtherIndexType(
616 sample,
617 ));
618 }
619 }
620 _ => {
621 return Err(ExpressionError::InvalidImageOtherIndex);
622 }
623 }
624
625 match (level, class.is_mipmapped()) {
626 (None, false) => {}
627 (Some(level), true) => {
628 if resolver[level].scalar_kind() != Some(Sk::Sint) {
629 return Err(ExpressionError::InvalidImageOtherIndexType(level));
630 }
631 }
632 _ => {
633 return Err(ExpressionError::InvalidImageOtherIndex);
634 }
635 }
636 }
637 _ => return Err(ExpressionError::ExpectedImageType(ty)),
638 }
639 ShaderStages::all()
640 }
641 E::ImageQuery { image, query } => {
642 let ty = Self::global_var_ty(module, function, image)?;
643 match module.types[ty].inner {
644 Ti::Image { class, arrayed, .. } => {
645 let good = match query {
646 crate::ImageQuery::NumLayers => arrayed,
647 crate::ImageQuery::Size { level: None } => true,
648 crate::ImageQuery::Size { level: Some(_) }
649 | crate::ImageQuery::NumLevels => class.is_mipmapped(),
650 crate::ImageQuery::NumSamples => class.is_multisampled(),
651 };
652 if !good {
653 return Err(ExpressionError::InvalidImageClass(class));
654 }
655 }
656 _ => return Err(ExpressionError::ExpectedImageType(ty)),
657 }
658 ShaderStages::all()
659 }
660 E::Unary { op, expr } => {
661 use crate::UnaryOperator as Uo;
662 let inner = &resolver[expr];
663 match (op, inner.scalar_kind()) {
664 (Uo::Negate, Some(Sk::Float | Sk::Sint))
665 | (Uo::LogicalNot, Some(Sk::Bool))
666 | (Uo::BitwiseNot, Some(Sk::Sint | Sk::Uint)) => {}
667 other => {
668 log::error!("Op {:?} kind {:?}", op, other);
669 return Err(ExpressionError::InvalidUnaryOperandType(op, expr));
670 }
671 }
672 ShaderStages::all()
673 }
674 E::Binary { op, left, right } => {
675 use crate::BinaryOperator as Bo;
676 let left_inner = &resolver[left];
677 let right_inner = &resolver[right];
678 let good = match op {
679 Bo::Add | Bo::Subtract => match *left_inner {
680 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
681 Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
682 Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
683 },
684 Ti::Matrix { .. } => left_inner == right_inner,
685 _ => false,
686 },
687 Bo::Divide | Bo::Modulo => match *left_inner {
688 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
689 Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
690 Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
691 },
692 _ => false,
693 },
694 Bo::Multiply => {
695 let kind_allowed = match left_inner.scalar_kind() {
696 Some(Sk::Uint | Sk::Sint | Sk::Float) => true,
697 Some(Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat) | None => false,
698 };
699 let types_match = match (left_inner, right_inner) {
700 (&Ti::Scalar(scalar1), &Ti::Scalar(scalar2))
702 | (
703 &Ti::Vector {
704 scalar: scalar1, ..
705 },
706 &Ti::Scalar(scalar2),
707 )
708 | (
709 &Ti::Scalar(scalar1),
710 &Ti::Vector {
711 scalar: scalar2, ..
712 },
713 ) => scalar1 == scalar2,
714 (
716 &Ti::Scalar(Sc {
717 kind: Sk::Float, ..
718 }),
719 &Ti::Matrix { .. },
720 )
721 | (
722 &Ti::Matrix { .. },
723 &Ti::Scalar(Sc {
724 kind: Sk::Float, ..
725 }),
726 ) => true,
727 (
729 &Ti::Vector {
730 size: size1,
731 scalar: scalar1,
732 },
733 &Ti::Vector {
734 size: size2,
735 scalar: scalar2,
736 },
737 ) => scalar1 == scalar2 && size1 == size2,
738 (
740 &Ti::Matrix { columns, .. },
741 &Ti::Vector {
742 size,
743 scalar:
744 Sc {
745 kind: Sk::Float, ..
746 },
747 },
748 ) => columns == size,
749 (
751 &Ti::Vector {
752 size,
753 scalar:
754 Sc {
755 kind: Sk::Float, ..
756 },
757 },
758 &Ti::Matrix { rows, .. },
759 ) => size == rows,
760 (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => {
761 columns == rows
762 }
763 _ => false,
764 };
765 let left_width = left_inner.scalar_width().unwrap_or(0);
766 let right_width = right_inner.scalar_width().unwrap_or(0);
767 kind_allowed && types_match && left_width == right_width
768 }
769 Bo::Equal | Bo::NotEqual => left_inner.is_sized() && left_inner == right_inner,
770 Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => {
771 match *left_inner {
772 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
773 Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
774 Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
775 },
776 ref other => {
777 log::error!("Op {:?} left type {:?}", op, other);
778 false
779 }
780 }
781 }
782 Bo::LogicalAnd | Bo::LogicalOr => match *left_inner {
783 Ti::Scalar(Sc { kind: Sk::Bool, .. })
784 | Ti::Vector {
785 scalar: Sc { kind: Sk::Bool, .. },
786 ..
787 } => left_inner == right_inner,
788 ref other => {
789 log::error!("Op {:?} left type {:?}", op, other);
790 false
791 }
792 },
793 Bo::And | Bo::InclusiveOr => match *left_inner {
794 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
795 Sk::Bool | Sk::Sint | Sk::Uint => left_inner == right_inner,
796 Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
797 },
798 ref other => {
799 log::error!("Op {:?} left type {:?}", op, other);
800 false
801 }
802 },
803 Bo::ExclusiveOr => match *left_inner {
804 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
805 Sk::Sint | Sk::Uint => left_inner == right_inner,
806 Sk::Bool | Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
807 },
808 ref other => {
809 log::error!("Op {:?} left type {:?}", op, other);
810 false
811 }
812 },
813 Bo::ShiftLeft | Bo::ShiftRight => {
814 let (base_size, base_scalar) = match *left_inner {
815 Ti::Scalar(scalar) => (Ok(None), scalar),
816 Ti::Vector { size, scalar } => (Ok(Some(size)), scalar),
817 ref other => {
818 log::error!("Op {:?} base type {:?}", op, other);
819 (Err(()), Sc::BOOL)
820 }
821 };
822 let shift_size = match *right_inner {
823 Ti::Scalar(Sc { kind: Sk::Uint, .. }) => Ok(None),
824 Ti::Vector {
825 size,
826 scalar: Sc { kind: Sk::Uint, .. },
827 } => Ok(Some(size)),
828 ref other => {
829 log::error!("Op {:?} shift type {:?}", op, other);
830 Err(())
831 }
832 };
833 match base_scalar.kind {
834 Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size,
835 Sk::Float | Sk::AbstractInt | Sk::AbstractFloat | Sk::Bool => false,
836 }
837 }
838 };
839 if !good {
840 log::error!(
841 "Left: {:?} of type {:?}",
842 function.expressions[left],
843 left_inner
844 );
845 log::error!(
846 "Right: {:?} of type {:?}",
847 function.expressions[right],
848 right_inner
849 );
850 return Err(ExpressionError::InvalidBinaryOperandTypes(op, left, right));
851 }
852 ShaderStages::all()
853 }
854 E::Select {
855 condition,
856 accept,
857 reject,
858 } => {
859 let accept_inner = &resolver[accept];
860 let reject_inner = &resolver[reject];
861 let condition_good = match resolver[condition] {
862 Ti::Scalar(Sc {
863 kind: Sk::Bool,
864 width: _,
865 }) => {
866 match *accept_inner {
869 Ti::Scalar { .. } | Ti::Vector { .. } => true,
870 _ => false,
871 }
872 }
873 Ti::Vector {
874 size,
875 scalar:
876 Sc {
877 kind: Sk::Bool,
878 width: _,
879 },
880 } => match *accept_inner {
881 Ti::Vector {
882 size: other_size, ..
883 } => size == other_size,
884 _ => false,
885 },
886 _ => false,
887 };
888 if !condition_good || accept_inner != reject_inner {
889 return Err(ExpressionError::InvalidSelectTypes);
890 }
891 ShaderStages::all()
892 }
893 E::Derivative { expr, .. } => {
894 match resolver[expr] {
895 Ti::Scalar(Sc {
896 kind: Sk::Float, ..
897 })
898 | Ti::Vector {
899 scalar:
900 Sc {
901 kind: Sk::Float, ..
902 },
903 ..
904 } => {}
905 _ => return Err(ExpressionError::InvalidDerivative),
906 }
907 ShaderStages::FRAGMENT
908 }
909 E::Relational { fun, argument } => {
910 use crate::RelationalFunction as Rf;
911 let argument_inner = &resolver[argument];
912 match fun {
913 Rf::All | Rf::Any => match *argument_inner {
914 Ti::Vector {
915 scalar: Sc { kind: Sk::Bool, .. },
916 ..
917 } => {}
918 ref other => {
919 log::error!("All/Any of type {:?}", other);
920 return Err(ExpressionError::InvalidBooleanVector(argument));
921 }
922 },
923 Rf::IsNan | Rf::IsInf => match *argument_inner {
924 Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
925 if scalar.kind == Sk::Float => {}
926 ref other => {
927 log::error!("Float test of type {:?}", other);
928 return Err(ExpressionError::InvalidFloatArgument(argument));
929 }
930 },
931 }
932 ShaderStages::all()
933 }
934 E::Math {
935 fun,
936 arg,
937 arg1,
938 arg2,
939 arg3,
940 } => {
941 use crate::MathFunction as Mf;
942
943 let resolve = |arg| &resolver[arg];
944 let arg_ty = resolve(arg);
945 let arg1_ty = arg1.map(resolve);
946 let arg2_ty = arg2.map(resolve);
947 let arg3_ty = arg3.map(resolve);
948 match fun {
949 Mf::Abs => {
950 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
951 return Err(ExpressionError::WrongArgumentCount(fun));
952 }
953 let good = match *arg_ty {
954 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
955 scalar.kind != Sk::Bool
956 }
957 _ => false,
958 };
959 if !good {
960 return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
961 }
962 }
963 Mf::Min | Mf::Max => {
964 let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
965 (Some(ty1), None, None) => ty1,
966 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
967 };
968 let good = match *arg_ty {
969 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
970 scalar.kind != Sk::Bool
971 }
972 _ => false,
973 };
974 if !good {
975 return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
976 }
977 if arg1_ty != arg_ty {
978 return Err(ExpressionError::InvalidArgumentType(
979 fun,
980 1,
981 arg1.unwrap(),
982 ));
983 }
984 }
985 Mf::Clamp => {
986 let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
987 (Some(ty1), Some(ty2), None) => (ty1, ty2),
988 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
989 };
990 let good = match *arg_ty {
991 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
992 scalar.kind != Sk::Bool
993 }
994 _ => false,
995 };
996 if !good {
997 return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
998 }
999 if arg1_ty != arg_ty {
1000 return Err(ExpressionError::InvalidArgumentType(
1001 fun,
1002 1,
1003 arg1.unwrap(),
1004 ));
1005 }
1006 if arg2_ty != arg_ty {
1007 return Err(ExpressionError::InvalidArgumentType(
1008 fun,
1009 2,
1010 arg2.unwrap(),
1011 ));
1012 }
1013 }
1014 Mf::Saturate
1015 | Mf::Cos
1016 | Mf::Cosh
1017 | Mf::Sin
1018 | Mf::Sinh
1019 | Mf::Tan
1020 | Mf::Tanh
1021 | Mf::Acos
1022 | Mf::Asin
1023 | Mf::Atan
1024 | Mf::Asinh
1025 | Mf::Acosh
1026 | Mf::Atanh
1027 | Mf::Radians
1028 | Mf::Degrees
1029 | Mf::Ceil
1030 | Mf::Floor
1031 | Mf::Round
1032 | Mf::Fract
1033 | Mf::Trunc
1034 | Mf::Exp
1035 | Mf::Exp2
1036 | Mf::Log
1037 | Mf::Log2
1038 | Mf::Length
1039 | Mf::Sqrt
1040 | Mf::InverseSqrt => {
1041 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1042 return Err(ExpressionError::WrongArgumentCount(fun));
1043 }
1044 match *arg_ty {
1045 Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1046 if scalar.kind == Sk::Float => {}
1047 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1048 }
1049 }
1050 Mf::Sign => {
1051 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1052 return Err(ExpressionError::WrongArgumentCount(fun));
1053 }
1054 match *arg_ty {
1055 Ti::Scalar(Sc {
1056 kind: Sk::Float | Sk::Sint,
1057 ..
1058 })
1059 | Ti::Vector {
1060 scalar:
1061 Sc {
1062 kind: Sk::Float | Sk::Sint,
1063 ..
1064 },
1065 ..
1066 } => {}
1067 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1068 }
1069 }
1070 Mf::Atan2 | Mf::Pow | Mf::Distance | Mf::Step => {
1071 let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1072 (Some(ty1), None, None) => ty1,
1073 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1074 };
1075 match *arg_ty {
1076 Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1077 if scalar.kind == Sk::Float => {}
1078 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1079 }
1080 if arg1_ty != arg_ty {
1081 return Err(ExpressionError::InvalidArgumentType(
1082 fun,
1083 1,
1084 arg1.unwrap(),
1085 ));
1086 }
1087 }
1088 Mf::Modf | Mf::Frexp => {
1089 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1090 return Err(ExpressionError::WrongArgumentCount(fun));
1091 }
1092 if !matches!(*arg_ty,
1093 Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1094 if scalar.kind == Sk::Float)
1095 {
1096 return Err(ExpressionError::InvalidArgumentType(fun, 1, arg));
1097 }
1098 }
1099 Mf::Ldexp => {
1100 let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1101 (Some(ty1), None, None) => ty1,
1102 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1103 };
1104 let size0 = match *arg_ty {
1105 Ti::Scalar(Sc {
1106 kind: Sk::Float, ..
1107 }) => None,
1108 Ti::Vector {
1109 scalar:
1110 Sc {
1111 kind: Sk::Float, ..
1112 },
1113 size,
1114 } => Some(size),
1115 _ => {
1116 return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
1117 }
1118 };
1119 let good = match *arg1_ty {
1120 Ti::Scalar(Sc { kind: Sk::Sint, .. }) if size0.is_none() => true,
1121 Ti::Vector {
1122 size,
1123 scalar: Sc { kind: Sk::Sint, .. },
1124 } if Some(size) == size0 => true,
1125 _ => false,
1126 };
1127 if !good {
1128 return Err(ExpressionError::InvalidArgumentType(
1129 fun,
1130 1,
1131 arg1.unwrap(),
1132 ));
1133 }
1134 }
1135 Mf::Dot => {
1136 let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1137 (Some(ty1), None, None) => ty1,
1138 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1139 };
1140 match *arg_ty {
1141 Ti::Vector {
1142 scalar:
1143 Sc {
1144 kind: Sk::Float | Sk::Sint | Sk::Uint,
1145 ..
1146 },
1147 ..
1148 } => {}
1149 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1150 }
1151 if arg1_ty != arg_ty {
1152 return Err(ExpressionError::InvalidArgumentType(
1153 fun,
1154 1,
1155 arg1.unwrap(),
1156 ));
1157 }
1158 }
1159 Mf::Outer | Mf::Reflect => {
1160 let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1161 (Some(ty1), None, None) => ty1,
1162 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1163 };
1164 match *arg_ty {
1165 Ti::Vector {
1166 scalar:
1167 Sc {
1168 kind: Sk::Float, ..
1169 },
1170 ..
1171 } => {}
1172 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1173 }
1174 if arg1_ty != arg_ty {
1175 return Err(ExpressionError::InvalidArgumentType(
1176 fun,
1177 1,
1178 arg1.unwrap(),
1179 ));
1180 }
1181 }
1182 Mf::Cross => {
1183 let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1184 (Some(ty1), None, None) => ty1,
1185 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1186 };
1187 match *arg_ty {
1188 Ti::Vector {
1189 scalar:
1190 Sc {
1191 kind: Sk::Float, ..
1192 },
1193 size: crate::VectorSize::Tri,
1194 } => {}
1195 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1196 }
1197 if arg1_ty != arg_ty {
1198 return Err(ExpressionError::InvalidArgumentType(
1199 fun,
1200 1,
1201 arg1.unwrap(),
1202 ));
1203 }
1204 }
1205 Mf::Refract => {
1206 let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1207 (Some(ty1), Some(ty2), None) => (ty1, ty2),
1208 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1209 };
1210
1211 match *arg_ty {
1212 Ti::Vector {
1213 scalar:
1214 Sc {
1215 kind: Sk::Float, ..
1216 },
1217 ..
1218 } => {}
1219 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1220 }
1221
1222 if arg1_ty != arg_ty {
1223 return Err(ExpressionError::InvalidArgumentType(
1224 fun,
1225 1,
1226 arg1.unwrap(),
1227 ));
1228 }
1229
1230 match (arg_ty, arg2_ty) {
1231 (
1232 &Ti::Vector {
1233 scalar:
1234 Sc {
1235 width: vector_width,
1236 ..
1237 },
1238 ..
1239 },
1240 &Ti::Scalar(Sc {
1241 width: scalar_width,
1242 kind: Sk::Float,
1243 }),
1244 ) if vector_width == scalar_width => {}
1245 _ => {
1246 return Err(ExpressionError::InvalidArgumentType(
1247 fun,
1248 2,
1249 arg2.unwrap(),
1250 ))
1251 }
1252 }
1253 }
1254 Mf::Normalize => {
1255 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1256 return Err(ExpressionError::WrongArgumentCount(fun));
1257 }
1258 match *arg_ty {
1259 Ti::Vector {
1260 scalar:
1261 Sc {
1262 kind: Sk::Float, ..
1263 },
1264 ..
1265 } => {}
1266 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1267 }
1268 }
1269 Mf::FaceForward | Mf::Fma | Mf::SmoothStep => {
1270 let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1271 (Some(ty1), Some(ty2), None) => (ty1, ty2),
1272 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1273 };
1274 match *arg_ty {
1275 Ti::Scalar(Sc {
1276 kind: Sk::Float, ..
1277 })
1278 | Ti::Vector {
1279 scalar:
1280 Sc {
1281 kind: Sk::Float, ..
1282 },
1283 ..
1284 } => {}
1285 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1286 }
1287 if arg1_ty != arg_ty {
1288 return Err(ExpressionError::InvalidArgumentType(
1289 fun,
1290 1,
1291 arg1.unwrap(),
1292 ));
1293 }
1294 if arg2_ty != arg_ty {
1295 return Err(ExpressionError::InvalidArgumentType(
1296 fun,
1297 2,
1298 arg2.unwrap(),
1299 ));
1300 }
1301 }
1302 Mf::Mix => {
1303 let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1304 (Some(ty1), Some(ty2), None) => (ty1, ty2),
1305 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1306 };
1307 let arg_width = match *arg_ty {
1308 Ti::Scalar(Sc {
1309 kind: Sk::Float,
1310 width,
1311 })
1312 | Ti::Vector {
1313 scalar:
1314 Sc {
1315 kind: Sk::Float,
1316 width,
1317 },
1318 ..
1319 } => width,
1320 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1321 };
1322 if arg1_ty != arg_ty {
1323 return Err(ExpressionError::InvalidArgumentType(
1324 fun,
1325 1,
1326 arg1.unwrap(),
1327 ));
1328 }
1329 match *arg2_ty {
1331 Ti::Scalar(Sc {
1332 kind: Sk::Float,
1333 width,
1334 }) if width == arg_width => {}
1335 _ if arg2_ty == arg_ty => {}
1336 _ => {
1337 return Err(ExpressionError::InvalidArgumentType(
1338 fun,
1339 2,
1340 arg2.unwrap(),
1341 ));
1342 }
1343 }
1344 }
1345 Mf::Inverse | Mf::Determinant => {
1346 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1347 return Err(ExpressionError::WrongArgumentCount(fun));
1348 }
1349 let good = match *arg_ty {
1350 Ti::Matrix { columns, rows, .. } => columns == rows,
1351 _ => false,
1352 };
1353 if !good {
1354 return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
1355 }
1356 }
1357 Mf::Transpose => {
1358 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1359 return Err(ExpressionError::WrongArgumentCount(fun));
1360 }
1361 match *arg_ty {
1362 Ti::Matrix { .. } => {}
1363 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1364 }
1365 }
1366 Mf::CountLeadingZeros
1368 | Mf::CountTrailingZeros
1369 | Mf::CountOneBits
1370 | Mf::ReverseBits
1371 | Mf::FirstLeadingBit
1372 | Mf::FirstTrailingBit => {
1373 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1374 return Err(ExpressionError::WrongArgumentCount(fun));
1375 }
1376 match *arg_ty {
1377 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
1378 Sk::Sint | Sk::Uint => {
1379 if scalar.width != 4 {
1380 return Err(ExpressionError::UnsupportedWidth(
1381 fun,
1382 scalar.kind,
1383 scalar.width,
1384 ));
1385 }
1386 }
1387 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1388 },
1389 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1390 }
1391 }
1392 Mf::InsertBits => {
1393 let (arg1_ty, arg2_ty, arg3_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1394 (Some(ty1), Some(ty2), Some(ty3)) => (ty1, ty2, ty3),
1395 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1396 };
1397 match *arg_ty {
1398 Ti::Scalar(Sc {
1399 kind: Sk::Sint | Sk::Uint,
1400 ..
1401 })
1402 | Ti::Vector {
1403 scalar:
1404 Sc {
1405 kind: Sk::Sint | Sk::Uint,
1406 ..
1407 },
1408 ..
1409 } => {}
1410 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1411 }
1412 if arg1_ty != arg_ty {
1413 return Err(ExpressionError::InvalidArgumentType(
1414 fun,
1415 1,
1416 arg1.unwrap(),
1417 ));
1418 }
1419 match *arg2_ty {
1420 Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1421 _ => {
1422 return Err(ExpressionError::InvalidArgumentType(
1423 fun,
1424 2,
1425 arg2.unwrap(),
1426 ))
1427 }
1428 }
1429 match *arg3_ty {
1430 Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1431 _ => {
1432 return Err(ExpressionError::InvalidArgumentType(
1433 fun,
1434 2,
1435 arg3.unwrap(),
1436 ))
1437 }
1438 }
1439 for &arg in [arg_ty, arg1_ty, arg2_ty, arg3_ty].iter() {
1441 match *arg {
1442 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
1443 if scalar.width != 4 {
1444 return Err(ExpressionError::UnsupportedWidth(
1445 fun,
1446 scalar.kind,
1447 scalar.width,
1448 ));
1449 }
1450 }
1451 _ => {}
1452 }
1453 }
1454 }
1455 Mf::ExtractBits => {
1456 let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1457 (Some(ty1), Some(ty2), None) => (ty1, ty2),
1458 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1459 };
1460 match *arg_ty {
1461 Ti::Scalar(Sc {
1462 kind: Sk::Sint | Sk::Uint,
1463 ..
1464 })
1465 | Ti::Vector {
1466 scalar:
1467 Sc {
1468 kind: Sk::Sint | Sk::Uint,
1469 ..
1470 },
1471 ..
1472 } => {}
1473 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1474 }
1475 match *arg1_ty {
1476 Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1477 _ => {
1478 return Err(ExpressionError::InvalidArgumentType(
1479 fun,
1480 2,
1481 arg1.unwrap(),
1482 ))
1483 }
1484 }
1485 match *arg2_ty {
1486 Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1487 _ => {
1488 return Err(ExpressionError::InvalidArgumentType(
1489 fun,
1490 2,
1491 arg2.unwrap(),
1492 ))
1493 }
1494 }
1495 for &arg in [arg_ty, arg1_ty, arg2_ty].iter() {
1497 match *arg {
1498 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
1499 if scalar.width != 4 {
1500 return Err(ExpressionError::UnsupportedWidth(
1501 fun,
1502 scalar.kind,
1503 scalar.width,
1504 ));
1505 }
1506 }
1507 _ => {}
1508 }
1509 }
1510 }
1511 Mf::Pack2x16unorm | Mf::Pack2x16snorm | Mf::Pack2x16float => {
1512 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1513 return Err(ExpressionError::WrongArgumentCount(fun));
1514 }
1515 match *arg_ty {
1516 Ti::Vector {
1517 size: crate::VectorSize::Bi,
1518 scalar:
1519 Sc {
1520 kind: Sk::Float, ..
1521 },
1522 } => {}
1523 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1524 }
1525 }
1526 Mf::Pack4x8snorm | Mf::Pack4x8unorm => {
1527 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1528 return Err(ExpressionError::WrongArgumentCount(fun));
1529 }
1530 match *arg_ty {
1531 Ti::Vector {
1532 size: crate::VectorSize::Quad,
1533 scalar:
1534 Sc {
1535 kind: Sk::Float, ..
1536 },
1537 } => {}
1538 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1539 }
1540 }
1541 mf @ (Mf::Pack4xI8 | Mf::Pack4xU8) => {
1542 let scalar_kind = match mf {
1543 Mf::Pack4xI8 => Sk::Sint,
1544 Mf::Pack4xU8 => Sk::Uint,
1545 _ => unreachable!(),
1546 };
1547 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1548 return Err(ExpressionError::WrongArgumentCount(fun));
1549 }
1550 match *arg_ty {
1551 Ti::Vector {
1552 size: crate::VectorSize::Quad,
1553 scalar: Sc { kind, .. },
1554 } if kind == scalar_kind => {}
1555 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1556 }
1557 }
1558 Mf::Unpack2x16float
1559 | Mf::Unpack2x16snorm
1560 | Mf::Unpack2x16unorm
1561 | Mf::Unpack4x8snorm
1562 | Mf::Unpack4x8unorm
1563 | Mf::Unpack4xI8
1564 | Mf::Unpack4xU8 => {
1565 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1566 return Err(ExpressionError::WrongArgumentCount(fun));
1567 }
1568 match *arg_ty {
1569 Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1570 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1571 }
1572 }
1573 }
1574 ShaderStages::all()
1575 }
1576 E::As {
1577 expr,
1578 kind,
1579 convert,
1580 } => {
1581 let mut base_scalar = match resolver[expr] {
1582 crate::TypeInner::Scalar(scalar) | crate::TypeInner::Vector { scalar, .. } => {
1583 scalar
1584 }
1585 crate::TypeInner::Matrix { scalar, .. } => scalar,
1586 _ => return Err(ExpressionError::InvalidCastArgument),
1587 };
1588 base_scalar.kind = kind;
1589 if let Some(width) = convert {
1590 base_scalar.width = width;
1591 }
1592 if self.check_width(base_scalar).is_err() {
1593 return Err(ExpressionError::InvalidCastArgument);
1594 }
1595 ShaderStages::all()
1596 }
1597 E::CallResult(function) => mod_info.functions[function.index()].available_stages,
1598 E::AtomicResult { .. } => {
1599 ShaderStages::all()
1604 }
1605 E::WorkGroupUniformLoadResult { ty } => {
1606 if self.types[ty.index()]
1607 .flags
1608 .contains(TypeFlags::SIZED | TypeFlags::CONSTRUCTIBLE)
1611 {
1612 ShaderStages::COMPUTE
1613 } else {
1614 return Err(ExpressionError::InvalidWorkGroupUniformLoadResultType(ty));
1615 }
1616 }
1617 E::ArrayLength(expr) => match resolver[expr] {
1618 Ti::Pointer { base, .. } => {
1619 let base_ty = &resolver.types[base];
1620 if let Ti::Array {
1621 size: crate::ArraySize::Dynamic,
1622 ..
1623 } = base_ty.inner
1624 {
1625 ShaderStages::all()
1626 } else {
1627 return Err(ExpressionError::InvalidArrayType(expr));
1628 }
1629 }
1630 ref other => {
1631 log::error!("Array length of {:?}", other);
1632 return Err(ExpressionError::InvalidArrayType(expr));
1633 }
1634 },
1635 E::RayQueryProceedResult => ShaderStages::all(),
1636 E::RayQueryGetIntersection {
1637 query,
1638 committed: _,
1639 } => match resolver[query] {
1640 Ti::Pointer {
1641 base,
1642 space: crate::AddressSpace::Function,
1643 } => match resolver.types[base].inner {
1644 Ti::RayQuery => ShaderStages::all(),
1645 ref other => {
1646 log::error!("Intersection result of a pointer to {:?}", other);
1647 return Err(ExpressionError::InvalidRayQueryType(query));
1648 }
1649 },
1650 ref other => {
1651 log::error!("Intersection result of {:?}", other);
1652 return Err(ExpressionError::InvalidRayQueryType(query));
1653 }
1654 },
1655 E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages,
1656 };
1657 Ok(stages)
1658 }
1659
1660 fn global_var_ty(
1661 module: &crate::Module,
1662 function: &crate::Function,
1663 expr: Handle<crate::Expression>,
1664 ) -> Result<Handle<crate::Type>, ExpressionError> {
1665 use crate::Expression as Ex;
1666
1667 match function.expressions[expr] {
1668 Ex::GlobalVariable(var_handle) => Ok(module.global_variables[var_handle].ty),
1669 Ex::FunctionArgument(i) => Ok(function.arguments[i as usize].ty),
1670 Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
1671 match function.expressions[base] {
1672 Ex::GlobalVariable(var_handle) => {
1673 let array_ty = module.global_variables[var_handle].ty;
1674
1675 match module.types[array_ty].inner {
1676 crate::TypeInner::BindingArray { base, .. } => Ok(base),
1677 _ => Err(ExpressionError::ExpectedBindingArrayType(array_ty)),
1678 }
1679 }
1680 _ => Err(ExpressionError::ExpectedGlobalVariable),
1681 }
1682 }
1683 _ => Err(ExpressionError::ExpectedGlobalVariable),
1684 }
1685 }
1686
1687 pub fn validate_literal(&self, literal: crate::Literal) -> Result<(), LiteralError> {
1688 self.check_width(literal.scalar())?;
1689 check_literal_value(literal)?;
1690
1691 Ok(())
1692 }
1693}
1694
1695pub fn check_literal_value(literal: crate::Literal) -> Result<(), LiteralError> {
1696 let is_nan = match literal {
1697 crate::Literal::F64(v) => v.is_nan(),
1698 crate::Literal::F32(v) => v.is_nan(),
1699 _ => false,
1700 };
1701 if is_nan {
1702 return Err(LiteralError::NaN);
1703 }
1704
1705 let is_infinite = match literal {
1706 crate::Literal::F64(v) => v.is_infinite(),
1707 crate::Literal::F32(v) => v.is_infinite(),
1708 _ => false,
1709 };
1710 if is_infinite {
1711 return Err(LiteralError::Infinity);
1712 }
1713
1714 Ok(())
1715}
1716
#[cfg(test)]
/// Test helper: validates a module whose only function emits `expr`.
///
/// The function body is a single `Emit` covering every expression in the
/// function's arena, so the function-level validator (with
/// `ValidationFlags::EXPRESSIONS`) inspects `expr` under the capabilities `caps`.
fn validate_with_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    use crate::span::Span;

    let mut function = crate::Function::default();
    function.expressions.append(expr, Span::default());
    let emitted = function.expressions.range_from(0);
    function
        .body
        .push(crate::Statement::Emit(emitted), Span::default());

    let mut module = crate::Module::default();
    module.functions.append(function, Span::default());

    super::Validator::new(super::ValidationFlags::EXPRESSIONS, caps).validate(&module)
}
1739
#[cfg(test)]
/// Test helper: validates a module whose global expression arena contains
/// only `expr`.
///
/// Only `ValidationFlags::CONSTANTS` is enabled, so `expr` is checked as a
/// constant expression under the capabilities `caps`.
fn validate_with_const_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    let mut module = crate::Module::default();
    module
        .global_expressions
        .append(expr, crate::span::Span::default());

    super::Validator::new(super::ValidationFlags::CONSTANTS, caps).validate(&module)
}
1755
#[test]
/// An `f64` literal inside a function body must be rejected unless the
/// validator has the `FLOAT64` capability.
fn f64_runtime_literals() {
    // Any f64 value works here; 0.5772156649 is the Euler–Mascheroni constant.
    let make_expr = || crate::Expression::Literal(crate::Literal::F64(0.57721_56649));
    let without_f64 = super::Capabilities::default();
    let with_f64 = super::Capabilities::default() | super::Capabilities::FLOAT64;

    let error = validate_with_expression(make_expr(), without_f64)
        .unwrap_err()
        .into_inner();
    assert!(matches!(
        error,
        crate::valid::ValidationError::Function {
            source: super::FunctionError::Expression {
                source: ExpressionError::Literal(LiteralError::Width(
                    super::r#type::WidthError::MissingCapability {
                        name: "f64",
                        flag: "FLOAT64",
                    }
                )),
                ..
            },
            ..
        }
    ));

    assert!(validate_with_expression(make_expr(), with_f64).is_ok());
}
1786
#[test]
/// An `f64` literal in the global (constant) expression arena must be
/// rejected unless the validator has the `FLOAT64` capability.
fn f64_const_literals() {
    // Any f64 value works here; 0.5772156649 is the Euler–Mascheroni constant.
    let make_expr = || crate::Expression::Literal(crate::Literal::F64(0.57721_56649));
    let without_f64 = super::Capabilities::default();
    let with_f64 = super::Capabilities::default() | super::Capabilities::FLOAT64;

    let error = validate_with_const_expression(make_expr(), without_f64)
        .unwrap_err()
        .into_inner();
    assert!(matches!(
        error,
        crate::valid::ValidationError::ConstExpression {
            source: ConstExpressionError::Literal(LiteralError::Width(
                super::r#type::WidthError::MissingCapability {
                    name: "f64",
                    flag: "FLOAT64",
                }
            )),
            ..
        }
    ));

    assert!(validate_with_const_expression(make_expr(), with_f64).is_ok());
}