1mod analyzer;
6mod compose;
7mod expression;
8mod function;
9mod handles;
10mod interface;
11mod r#type;
12
13use crate::{
14 arena::{Handle, HandleSet},
15 proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution},
16 FastHashSet,
17};
18use bit_set::BitSet;
19use std::ops;
20
21use crate::span::{AddSpan as _, WithSpan};
25pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
26pub use compose::ComposeError;
27pub use expression::{check_literal_value, LiteralError};
28pub use expression::{ConstExpressionError, ExpressionError};
29pub use function::{CallError, FunctionError, LocalVariableError};
30pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
31pub use r#type::{Disalignment, TypeError, TypeFlags, WidthError};
32
33use self::handles::InvalidHandleError;
34
35bitflags::bitflags! {
36 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
50 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
51 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
52 pub struct ValidationFlags: u8 {
53 const EXPRESSIONS = 0x1;
55 const BLOCKS = 0x2;
57 const CONTROL_FLOW_UNIFORMITY = 0x4;
59 const STRUCT_LAYOUTS = 0x8;
61 const CONSTANTS = 0x10;
63 const BINDINGS = 0x20;
65 }
66}
67
68impl Default for ValidationFlags {
69 fn default() -> Self {
70 Self::all()
71 }
72}
73
74bitflags::bitflags! {
75 #[must_use]
77 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
78 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
79 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
80 pub struct Capabilities: u32 {
81 const PUSH_CONSTANT = 0x1;
85 const FLOAT64 = 0x2;
87 const PRIMITIVE_INDEX = 0x4;
91 const SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 0x8;
93 const UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING = 0x10;
95 const SAMPLER_NON_UNIFORM_INDEXING = 0x20;
97 const CLIP_DISTANCE = 0x40;
101 const CULL_DISTANCE = 0x80;
105 const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 0x100;
107 const MULTIVIEW = 0x200;
111 const EARLY_DEPTH_TEST = 0x400;
113 const MULTISAMPLED_SHADING = 0x800;
118 const RAY_QUERY = 0x1000;
120 const DUAL_SOURCE_BLENDING = 0x2000;
122 const CUBE_ARRAY_TEXTURES = 0x4000;
124 const SHADER_INT64 = 0x8000;
126 const SUBGROUP = 0x10000;
130 const SUBGROUP_BARRIER = 0x20000;
132 const SUBGROUP_VERTEX_STAGE = 0x40000;
134 const SHADER_INT64_ATOMIC_MIN_MAX = 0x80000;
144 const SHADER_INT64_ATOMIC_ALL_OPS = 0x100000;
146 }
147}
148
149impl Default for Capabilities {
150 fn default() -> Self {
151 Self::MULTISAMPLED_SHADING | Self::CUBE_ARRAY_TEXTURES
152 }
153}
154
155bitflags::bitflags! {
156 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
158 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
159 #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
160 pub struct SubgroupOperationSet: u8 {
161 const BASIC = 1 << 0;
163 const VOTE = 1 << 1;
165 const ARITHMETIC = 1 << 2;
167 const BALLOT = 1 << 3;
169 const SHUFFLE = 1 << 4;
171 const SHUFFLE_RELATIVE = 1 << 5;
173 }
181}
182
183impl super::SubgroupOperation {
184 const fn required_operations(&self) -> SubgroupOperationSet {
185 use SubgroupOperationSet as S;
186 match *self {
187 Self::All | Self::Any => S::VOTE,
188 Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => {
189 S::ARITHMETIC
190 }
191 }
192 }
193}
194
195impl super::GatherMode {
196 const fn required_operations(&self) -> SubgroupOperationSet {
197 use SubgroupOperationSet as S;
198 match *self {
199 Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
200 Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
201 Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
202 }
203 }
204}
205
206bitflags::bitflags! {
207 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
209 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
210 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
211 pub struct ShaderStages: u8 {
212 const VERTEX = 0x1;
213 const FRAGMENT = 0x2;
214 const COMPUTE = 0x4;
215 }
216}
217
218#[derive(Debug, Clone)]
219#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
220#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
221pub struct ModuleInfo {
222 type_flags: Vec<TypeFlags>,
223 functions: Vec<FunctionInfo>,
224 entry_points: Vec<FunctionInfo>,
225 const_expression_types: Box<[TypeResolution]>,
226}
227
228impl ops::Index<Handle<crate::Type>> for ModuleInfo {
229 type Output = TypeFlags;
230 fn index(&self, handle: Handle<crate::Type>) -> &Self::Output {
231 &self.type_flags[handle.index()]
232 }
233}
234
235impl ops::Index<Handle<crate::Function>> for ModuleInfo {
236 type Output = FunctionInfo;
237 fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
238 &self.functions[handle.index()]
239 }
240}
241
242impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
243 type Output = TypeResolution;
244 fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
245 &self.const_expression_types[handle.index()]
246 }
247}
248
249#[derive(Debug)]
250pub struct Validator {
251 flags: ValidationFlags,
252 capabilities: Capabilities,
253 subgroup_stages: ShaderStages,
254 subgroup_operations: SubgroupOperationSet,
255 types: Vec<r#type::TypeInfo>,
256 layouter: Layouter,
257 location_mask: BitSet,
258 ep_resource_bindings: FastHashSet<crate::ResourceBinding>,
259 #[allow(dead_code)]
260 switch_values: FastHashSet<crate::SwitchValue>,
261 valid_expression_list: Vec<Handle<crate::Expression>>,
262 valid_expression_set: HandleSet<crate::Expression>,
263 override_ids: FastHashSet<u16>,
264 allow_overrides: bool,
265
266 needs_visit: HandleSet<crate::Expression>,
285}
286
287#[derive(Clone, Debug, thiserror::Error)]
288#[cfg_attr(test, derive(PartialEq))]
289pub enum ConstantError {
290 #[error("Initializer must be a const-expression")]
291 InitializerExprType,
292 #[error("The type doesn't match the constant")]
293 InvalidType,
294 #[error("The type is not constructible")]
295 NonConstructibleType,
296}
297
298#[derive(Clone, Debug, thiserror::Error)]
299#[cfg_attr(test, derive(PartialEq))]
300pub enum OverrideError {
301 #[error("Override name and ID are missing")]
302 MissingNameAndID,
303 #[error("Override ID must be unique")]
304 DuplicateID,
305 #[error("Initializer must be a const-expression or override-expression")]
306 InitializerExprType,
307 #[error("The type doesn't match the override")]
308 InvalidType,
309 #[error("The type is not constructible")]
310 NonConstructibleType,
311 #[error("The type is not a scalar")]
312 TypeNotScalar,
313 #[error("Override declarations are not allowed")]
314 NotAllowed,
315}
316
317#[derive(Clone, Debug, thiserror::Error)]
318#[cfg_attr(test, derive(PartialEq))]
319pub enum ValidationError {
320 #[error(transparent)]
321 InvalidHandle(#[from] InvalidHandleError),
322 #[error(transparent)]
323 Layouter(#[from] LayoutError),
324 #[error("Type {handle:?} '{name}' is invalid")]
325 Type {
326 handle: Handle<crate::Type>,
327 name: String,
328 source: TypeError,
329 },
330 #[error("Constant expression {handle:?} is invalid")]
331 ConstExpression {
332 handle: Handle<crate::Expression>,
333 source: ConstExpressionError,
334 },
335 #[error("Constant {handle:?} '{name}' is invalid")]
336 Constant {
337 handle: Handle<crate::Constant>,
338 name: String,
339 source: ConstantError,
340 },
341 #[error("Override {handle:?} '{name}' is invalid")]
342 Override {
343 handle: Handle<crate::Override>,
344 name: String,
345 source: OverrideError,
346 },
347 #[error("Global variable {handle:?} '{name}' is invalid")]
348 GlobalVariable {
349 handle: Handle<crate::GlobalVariable>,
350 name: String,
351 source: GlobalVariableError,
352 },
353 #[error("Function {handle:?} '{name}' is invalid")]
354 Function {
355 handle: Handle<crate::Function>,
356 name: String,
357 source: FunctionError,
358 },
359 #[error("Entry point {name} at {stage:?} is invalid")]
360 EntryPoint {
361 stage: crate::ShaderStage,
362 name: String,
363 source: EntryPointError,
364 },
365 #[error("Module is corrupted")]
366 Corrupted,
367}
368
369impl crate::TypeInner {
370 const fn is_sized(&self) -> bool {
371 match *self {
372 Self::Scalar { .. }
373 | Self::Vector { .. }
374 | Self::Matrix { .. }
375 | Self::Array {
376 size: crate::ArraySize::Constant(_),
377 ..
378 }
379 | Self::Atomic { .. }
380 | Self::Pointer { .. }
381 | Self::ValuePointer { .. }
382 | Self::Struct { .. } => true,
383 Self::Array { .. }
384 | Self::Image { .. }
385 | Self::Sampler { .. }
386 | Self::AccelerationStructure
387 | Self::RayQuery
388 | Self::BindingArray { .. } => false,
389 }
390 }
391
392 const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
394 match *self {
395 Self::Scalar(crate::Scalar {
396 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
397 ..
398 }) => Some(crate::ImageDimension::D1),
399 Self::Vector {
400 size: crate::VectorSize::Bi,
401 scalar:
402 crate::Scalar {
403 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
404 ..
405 },
406 } => Some(crate::ImageDimension::D2),
407 Self::Vector {
408 size: crate::VectorSize::Tri,
409 scalar:
410 crate::Scalar {
411 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
412 ..
413 },
414 } => Some(crate::ImageDimension::D3),
415 _ => None,
416 }
417 }
418}
419
420impl Validator {
421 pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
423 let subgroup_operations = if capabilities.contains(Capabilities::SUBGROUP) {
424 use SubgroupOperationSet as S;
425 S::BASIC | S::VOTE | S::ARITHMETIC | S::BALLOT | S::SHUFFLE | S::SHUFFLE_RELATIVE
426 } else {
427 SubgroupOperationSet::empty()
428 };
429 let subgroup_stages = {
430 let mut stages = ShaderStages::empty();
431 if capabilities.contains(Capabilities::SUBGROUP_VERTEX_STAGE) {
432 stages |= ShaderStages::VERTEX;
433 }
434 if capabilities.contains(Capabilities::SUBGROUP) {
435 stages |= ShaderStages::FRAGMENT | ShaderStages::COMPUTE;
436 }
437 stages
438 };
439
440 Validator {
441 flags,
442 capabilities,
443 subgroup_stages,
444 subgroup_operations,
445 types: Vec::new(),
446 layouter: Layouter::default(),
447 location_mask: BitSet::new(),
448 ep_resource_bindings: FastHashSet::default(),
449 switch_values: FastHashSet::default(),
450 valid_expression_list: Vec::new(),
451 valid_expression_set: HandleSet::new(),
452 override_ids: FastHashSet::default(),
453 allow_overrides: true,
454 needs_visit: HandleSet::new(),
455 }
456 }
457
458 pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self {
459 self.subgroup_stages = stages;
460 self
461 }
462
463 pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self {
464 self.subgroup_operations = operations;
465 self
466 }
467
468 pub fn reset(&mut self) {
470 self.types.clear();
471 self.layouter.clear();
472 self.location_mask.clear();
473 self.ep_resource_bindings.clear();
474 self.switch_values.clear();
475 self.valid_expression_list.clear();
476 self.valid_expression_set.clear();
477 self.override_ids.clear();
478 }
479
480 fn validate_constant(
481 &self,
482 handle: Handle<crate::Constant>,
483 gctx: crate::proc::GlobalCtx,
484 mod_info: &ModuleInfo,
485 global_expr_kind: &ExpressionKindTracker,
486 ) -> Result<(), ConstantError> {
487 let con = &gctx.constants[handle];
488
489 let type_info = &self.types[con.ty.index()];
490 if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
491 return Err(ConstantError::NonConstructibleType);
492 }
493
494 if !global_expr_kind.is_const(con.init) {
495 return Err(ConstantError::InitializerExprType);
496 }
497
498 let decl_ty = &gctx.types[con.ty].inner;
499 let init_ty = mod_info[con.init].inner_with(gctx.types);
500 if !decl_ty.equivalent(init_ty, gctx.types) {
501 return Err(ConstantError::InvalidType);
502 }
503
504 Ok(())
505 }
506
507 fn validate_override(
508 &mut self,
509 handle: Handle<crate::Override>,
510 gctx: crate::proc::GlobalCtx,
511 mod_info: &ModuleInfo,
512 ) -> Result<(), OverrideError> {
513 if !self.allow_overrides {
514 return Err(OverrideError::NotAllowed);
515 }
516
517 let o = &gctx.overrides[handle];
518
519 if o.name.is_none() && o.id.is_none() {
520 return Err(OverrideError::MissingNameAndID);
521 }
522
523 if let Some(id) = o.id {
524 if !self.override_ids.insert(id) {
525 return Err(OverrideError::DuplicateID);
526 }
527 }
528
529 let type_info = &self.types[o.ty.index()];
530 if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
531 return Err(OverrideError::NonConstructibleType);
532 }
533
534 let decl_ty = &gctx.types[o.ty].inner;
535 match decl_ty {
536 &crate::TypeInner::Scalar(
537 crate::Scalar::BOOL
538 | crate::Scalar::I32
539 | crate::Scalar::U32
540 | crate::Scalar::F32
541 | crate::Scalar::F64,
542 ) => {}
543 _ => return Err(OverrideError::TypeNotScalar),
544 }
545
546 if let Some(init) = o.init {
547 let init_ty = mod_info[init].inner_with(gctx.types);
548 if !decl_ty.equivalent(init_ty, gctx.types) {
549 return Err(OverrideError::InvalidType);
550 }
551 }
552
553 Ok(())
554 }
555
556 pub fn validate(
558 &mut self,
559 module: &crate::Module,
560 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
561 self.allow_overrides = true;
562 self.validate_impl(module)
563 }
564
565 pub fn validate_no_overrides(
569 &mut self,
570 module: &crate::Module,
571 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
572 self.allow_overrides = false;
573 self.validate_impl(module)
574 }
575
576 fn validate_impl(
577 &mut self,
578 module: &crate::Module,
579 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
580 self.reset();
581 self.reset_types(module.types.len());
582
583 Self::validate_module_handles(module).map_err(|e| e.with_span())?;
584
585 self.layouter.update(module.to_ctx()).map_err(|e| {
586 let handle = e.ty;
587 ValidationError::from(e).with_span_handle(handle, &module.types)
588 })?;
589
590 let placeholder = TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
592 kind: crate::ScalarKind::Bool,
593 width: 0,
594 }));
595
596 let mut mod_info = ModuleInfo {
597 type_flags: Vec::with_capacity(module.types.len()),
598 functions: Vec::with_capacity(module.functions.len()),
599 entry_points: Vec::with_capacity(module.entry_points.len()),
600 const_expression_types: vec![placeholder; module.global_expressions.len()]
601 .into_boxed_slice(),
602 };
603
604 for (handle, ty) in module.types.iter() {
605 let ty_info = self
606 .validate_type(handle, module.to_ctx())
607 .map_err(|source| {
608 ValidationError::Type {
609 handle,
610 name: ty.name.clone().unwrap_or_default(),
611 source,
612 }
613 .with_span_handle(handle, &module.types)
614 })?;
615 mod_info.type_flags.push(ty_info.flags);
616 self.types[handle.index()] = ty_info;
617 }
618
619 {
620 let t = crate::Arena::new();
621 let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]);
622 for (handle, _) in module.global_expressions.iter() {
623 mod_info
624 .process_const_expression(handle, &resolve_context, module.to_ctx())
625 .map_err(|source| {
626 ValidationError::ConstExpression { handle, source }
627 .with_span_handle(handle, &module.global_expressions)
628 })?
629 }
630 }
631
632 let global_expr_kind = ExpressionKindTracker::from_arena(&module.global_expressions);
633
634 if self.flags.contains(ValidationFlags::CONSTANTS) {
635 for (handle, _) in module.global_expressions.iter() {
636 self.validate_const_expression(
637 handle,
638 module.to_ctx(),
639 &mod_info,
640 &global_expr_kind,
641 )
642 .map_err(|source| {
643 ValidationError::ConstExpression { handle, source }
644 .with_span_handle(handle, &module.global_expressions)
645 })?
646 }
647
648 for (handle, constant) in module.constants.iter() {
649 self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind)
650 .map_err(|source| {
651 ValidationError::Constant {
652 handle,
653 name: constant.name.clone().unwrap_or_default(),
654 source,
655 }
656 .with_span_handle(handle, &module.constants)
657 })?
658 }
659
660 for (handle, override_) in module.overrides.iter() {
661 self.validate_override(handle, module.to_ctx(), &mod_info)
662 .map_err(|source| {
663 ValidationError::Override {
664 handle,
665 name: override_.name.clone().unwrap_or_default(),
666 source,
667 }
668 .with_span_handle(handle, &module.overrides)
669 })?
670 }
671 }
672
673 for (var_handle, var) in module.global_variables.iter() {
674 self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind)
675 .map_err(|source| {
676 ValidationError::GlobalVariable {
677 handle: var_handle,
678 name: var.name.clone().unwrap_or_default(),
679 source,
680 }
681 .with_span_handle(var_handle, &module.global_variables)
682 })?;
683 }
684
685 for (handle, fun) in module.functions.iter() {
686 match self.validate_function(fun, module, &mod_info, false, &global_expr_kind) {
687 Ok(info) => mod_info.functions.push(info),
688 Err(error) => {
689 return Err(error.and_then(|source| {
690 ValidationError::Function {
691 handle,
692 name: fun.name.clone().unwrap_or_default(),
693 source,
694 }
695 .with_span_handle(handle, &module.functions)
696 }))
697 }
698 }
699 }
700
701 let mut ep_map = FastHashSet::default();
702 for ep in module.entry_points.iter() {
703 if !ep_map.insert((ep.stage, &ep.name)) {
704 return Err(ValidationError::EntryPoint {
705 stage: ep.stage,
706 name: ep.name.clone(),
707 source: EntryPointError::Conflict,
708 }
709 .with_span()); }
711
712 match self.validate_entry_point(ep, module, &mod_info, &global_expr_kind) {
713 Ok(info) => mod_info.entry_points.push(info),
714 Err(error) => {
715 return Err(error.and_then(|source| {
716 ValidationError::EntryPoint {
717 stage: ep.stage,
718 name: ep.name.clone(),
719 source,
720 }
721 .with_span()
722 }));
723 }
724 }
725 }
726
727 Ok(mod_info)
728 }
729}
730
731fn validate_atomic_compare_exchange_struct(
732 types: &crate::UniqueArena<crate::Type>,
733 members: &[crate::StructMember],
734 scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
735) -> bool {
736 members.len() == 2
737 && members[0].name.as_deref() == Some("old_value")
738 && scalar_predicate(&types[members[0].ty].inner)
739 && members[1].name.as_deref() == Some("exchanged")
740 && types[members[1].ty].inner == crate::TypeInner::Scalar(crate::Scalar::BOOL)
741}