1use super::{ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
9use crate::diagnostic_filter::{DiagnosticFilterNode, StandardFilterableTriggeringRule};
10use crate::span::{AddSpan as _, WithSpan};
11use crate::{
12 arena::{Arena, Handle},
13 proc::{ResolveContext, TypeResolution},
14};
15use std::ops;
16
/// Result of asking "is this value uniform?".
///
/// `None` means no source of non-uniformity was recorded; `Some(handle)`
/// carries the handle of an expression recorded as the non-uniform source.
pub type NonUniformResult = Option<Handle<crate::Expression>>;

// When `true`, the `DERIVATIVE` and `IMPLICIT_LEVEL` bits of
// `UniformityRequirements` are defined as `0`, i.e. they become empty flags.
const DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE: bool = true;
20
bitflags::bitflags! {
    /// Kinds of uniformity requirements that expressions and statements can
    /// impose on the control flow they execute in.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct UniformityRequirements: u8 {
        /// Attached to barrier and workgroup-uniform-load statements.
        const WORK_GROUP_BARRIER = 0x1;
        /// Attached to derivative expressions. Evaluates to `0` (no bit)
        /// while `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is `true`.
        const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 };
        /// Attached to image samples whose level uses implicit derivatives.
        /// Evaluates to `0` (no bit) while
        /// `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is `true`.
        const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 };
    }
}
32
/// Uniformity characteristics of a single expression, statement, or function.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(test, derive(PartialEq))]
pub struct Uniformity {
    /// `Some(handle)` names an expression recorded as the source of
    /// non-uniformity for this value; `None` means none was recorded.
    pub non_uniform_result: NonUniformResult,
    /// Uniformity requirements accumulated for this item.
    pub requirements: UniformityRequirements,
}
54
55impl Uniformity {
56 const fn new() -> Self {
57 Uniformity {
58 non_uniform_result: None,
59 requirements: UniformityRequirements::empty(),
60 }
61 }
62}
63
bitflags::bitflags! {
    /// Ways in which analyzed statements may leave the function early.
    #[derive(Clone, Copy, Debug, PartialEq)]
    struct ExitFlags: u8 {
        /// Recorded for a `Return` statement that is analyzed while a
        /// uniformity disruptor is already present.
        const MAY_RETURN = 0x1;
        /// Recorded for a `Kill` statement analyzed while a disruptor is
        /// present, and for calls to functions whose `may_kill` is set.
        const MAY_KILL = 0x2;
    }
}
78
/// Uniformity summary of a statement or block, as produced by
/// `FunctionInfo::process_block`.
#[cfg_attr(test, derive(Debug, PartialEq))]
struct FunctionUniformity {
    /// Combined uniformity of the analyzed statements.
    result: Uniformity,
    /// Early-exit possibilities collected from the analyzed statements.
    exit: ExitFlags,
}
85
86impl ops::BitOr for FunctionUniformity {
87 type Output = Self;
88 fn bitor(self, other: Self) -> Self {
89 FunctionUniformity {
90 result: Uniformity {
91 non_uniform_result: self
92 .result
93 .non_uniform_result
94 .or(other.result.non_uniform_result),
95 requirements: self.result.requirements | other.result.requirements,
96 },
97 exit: self.exit | other.exit,
98 }
99 }
100}
101
102impl FunctionUniformity {
103 const fn new() -> Self {
104 FunctionUniformity {
105 result: Uniformity::new(),
106 exit: ExitFlags::empty(),
107 }
108 }
109
110 const fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
112 if self.exit.contains(ExitFlags::MAY_RETURN) {
113 Some(UniformityDisruptor::Return)
114 } else if self.exit.contains(ExitFlags::MAY_KILL) {
115 Some(UniformityDisruptor::Discard)
116 } else {
117 None
118 }
119 }
120}
121
bitflags::bitflags! {
    /// Ways in which a function uses a global variable.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct GlobalUse: u8 {
        /// Recorded by ordinary value references (`add_ref`).
        const READ = 0x1;
        /// Recorded for the target of stores, image stores, and atomics.
        const WRITE = 0x2;
        /// Recorded for image queries and array-length queries.
        const QUERY = 0x4;
        /// Recorded for the image operand of image atomics.
        const ATOMIC = 0x8;
    }
}
138
/// An (image, sampler) pair of global variables that are sampled together.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct SamplingKey {
    /// The image global variable.
    pub image: Handle<crate::GlobalVariable>,
    /// The sampler global variable.
    pub sampler: Handle<crate::GlobalVariable>,
}
146
/// Analysis results stored per expression of a function.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ExpressionInfo {
    /// Uniformity computed for this expression.
    pub uniformity: Uniformity,

    /// Number of recorded references to this expression; incremented by
    /// `add_ref_impl` and `add_assignable_ref`.
    pub ref_count: usize,

    /// The global variable this expression refers to, if any. Set directly
    /// for `GlobalVariable` expressions and propagated from the base of
    /// `Access` / `AccessIndex` expressions.
    assignable_global: Option<Handle<crate::GlobalVariable>>,

    /// The resolved type of this expression.
    pub ty: TypeResolution,
}
181
182impl ExpressionInfo {
183 const fn new() -> Self {
184 ExpressionInfo {
185 uniformity: Uniformity::new(),
186 ref_count: 0,
187 assignable_global: None,
188 ty: TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
190 kind: crate::ScalarKind::Bool,
191 width: 0,
192 })),
193 }
194 }
195}
196
/// Identifies where an image or sampler operand comes from: a global
/// variable, or a function argument identified by its index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
enum GlobalOrArgument {
    /// The operand is (or indexes directly into) this global variable.
    Global(Handle<crate::GlobalVariable>),
    /// The operand is the function argument with this index.
    Argument(u32),
}
204
205impl GlobalOrArgument {
206 fn from_expression(
207 expression_arena: &Arena<crate::Expression>,
208 expression: Handle<crate::Expression>,
209 ) -> Result<GlobalOrArgument, ExpressionError> {
210 Ok(match expression_arena[expression] {
211 crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
212 crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
213 crate::Expression::Access { base, .. }
214 | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
215 crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
216 _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
217 },
218 _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
219 })
220 }
221}
222
/// An (image, sampler) pair whose operands may each be either a global
/// variable or a function argument.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
struct Sampling {
    /// Origin of the image operand.
    image: GlobalOrArgument,
    /// Origin of the sampler operand.
    sampler: GlobalOrArgument,
}
230
/// Analysis results for a single function.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct FunctionInfo {
    /// Validation flags the analysis was run with; consulted when deciding
    /// whether to report control-flow uniformity violations.
    #[allow(dead_code)]
    flags: ValidationFlags,
    /// Shader stages recorded for this function. The analysis visible here
    /// only initializes it to `ShaderStages::all()`.
    pub available_stages: ShaderStages,
    /// Uniformity of the function as a whole (result of analyzing its body).
    pub uniformity: Uniformity,
    /// Whether the body analysis recorded `ExitFlags::MAY_KILL`.
    pub may_kill: bool,

    /// (image, sampler) pairs sampled by this function, including those
    /// inherited from callees, where both operands are global variables.
    pub sampling_set: crate::FastHashSet<SamplingKey>,

    /// Per-global usage flags, indexed by the global variable's handle index.
    global_uses: Box<[GlobalUse]>,

    /// Per-expression analysis results, indexed by the expression's handle
    /// index.
    expressions: Box<[ExpressionInfo]>,

    /// (image, sampler) pairs sampled by this function where at least one
    /// operand is a function argument rather than a global variable.
    sampling: crate::FastHashSet<Sampling>,

    /// NOTE(review): the analysis visible here only initializes this to
    /// `false`; presumably it is set elsewhere -- confirm against callers.
    pub dual_source_blending: bool,

    /// Leaf of the diagnostic filter tree that applies to this function;
    /// passed to `DiagnosticFilterNode::search` when reporting diagnostics.
    diagnostic_filter_leaf: Option<Handle<DiagnosticFilterNode>>,
}
301
302impl FunctionInfo {
303 pub const fn global_variable_count(&self) -> usize {
304 self.global_uses.len()
305 }
306 pub const fn expression_count(&self) -> usize {
307 self.expressions.len()
308 }
309 pub fn dominates_global_use(&self, other: &Self) -> bool {
310 for (self_global_uses, other_global_uses) in
311 self.global_uses.iter().zip(other.global_uses.iter())
312 {
313 if !self_global_uses.contains(*other_global_uses) {
314 return false;
315 }
316 }
317 true
318 }
319}
320
321impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
322 type Output = GlobalUse;
323 fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
324 &self.global_uses[handle.index()]
325 }
326}
327
328impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
329 type Output = ExpressionInfo;
330 fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
331 &self.expressions[handle.index()]
332 }
333}
334
/// Reason why control flow stopped being uniform.
#[derive(Clone, Copy, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum UniformityDisruptor {
    /// A branch or switch depends on this non-uniform expression.
    #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
    Expression(Handle<crate::Expression>),
    /// An earlier statement may have returned (`ExitFlags::MAY_RETURN`).
    #[error("There is a Return earlier in the control flow of the function")]
    Return,
    /// An earlier statement or call may have discarded (`ExitFlags::MAY_KILL`).
    #[error("There is a Discard earlier in the entry point across all called functions")]
    Discard,
}
346
347impl FunctionInfo {
348 #[must_use]
356 fn add_ref_impl(
357 &mut self,
358 expr: Handle<crate::Expression>,
359 global_use: GlobalUse,
360 ) -> NonUniformResult {
361 let info = &mut self.expressions[expr.index()];
362 info.ref_count += 1;
363 if let Some(global) = info.assignable_global {
365 self.global_uses[global.index()] |= global_use;
366 }
367 info.uniformity.non_uniform_result
368 }
369
370 #[must_use]
377 fn add_ref(&mut self, expr: Handle<crate::Expression>) -> NonUniformResult {
378 self.add_ref_impl(expr, GlobalUse::READ)
379 }
380
381 #[must_use]
400 fn add_assignable_ref(
401 &mut self,
402 expr: Handle<crate::Expression>,
403 assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
404 ) -> NonUniformResult {
405 let info = &mut self.expressions[expr.index()];
406 info.ref_count += 1;
407 if let Some(global) = info.assignable_global {
410 if let Some(_old) = assignable_global.replace(global) {
411 unreachable!()
412 }
413 }
414 info.uniformity.non_uniform_result
415 }
416
417 fn process_call(
419 &mut self,
420 callee: &Self,
421 arguments: &[Handle<crate::Expression>],
422 expression_arena: &Arena<crate::Expression>,
423 ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
424 self.sampling_set
425 .extend(callee.sampling_set.iter().cloned());
426 for sampling in callee.sampling.iter() {
427 let image_storage = match sampling.image {
430 GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
431 GlobalOrArgument::Argument(i) => {
432 let Some(handle) = arguments.get(i as usize).cloned() else {
433 break;
435 };
436 GlobalOrArgument::from_expression(expression_arena, handle).map_err(
437 |source| {
438 FunctionError::Expression { handle, source }
439 .with_span_handle(handle, expression_arena)
440 },
441 )?
442 }
443 };
444
445 let sampler_storage = match sampling.sampler {
446 GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
447 GlobalOrArgument::Argument(i) => {
448 let Some(handle) = arguments.get(i as usize).cloned() else {
449 break;
451 };
452 GlobalOrArgument::from_expression(expression_arena, handle).map_err(
453 |source| {
454 FunctionError::Expression { handle, source }
455 .with_span_handle(handle, expression_arena)
456 },
457 )?
458 }
459 };
460
461 match (image_storage, sampler_storage) {
466 (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
467 self.sampling_set.insert(SamplingKey { image, sampler });
468 }
469 (image, sampler) => {
470 self.sampling.insert(Sampling { image, sampler });
471 }
472 }
473 }
474
475 for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) {
477 *mine |= *other;
478 }
479
480 Ok(FunctionUniformity {
481 result: callee.uniformity.clone(),
482 exit: if callee.may_kill {
483 ExitFlags::MAY_KILL
484 } else {
485 ExitFlags::empty()
486 },
487 })
488 }
489
    /// Computes the [`ExpressionInfo`] for `handle` and stores it in
    /// `self.expressions`, replacing whatever entry was there before.
    ///
    /// Along the way this:
    /// * bumps the reference count of every operand expression via
    ///   `add_ref` / `add_ref_impl` / `add_assignable_ref`,
    /// * records global usage flags for operands that refer to globals,
    /// * records sampled (image, sampler) pairs for `ImageSample`,
    /// * resolves the expression's type through `resolve_context`, reading
    ///   operand types from entries already stored in `self`.
    ///
    /// Operand info is read from `self`, so operands must have been processed
    /// before `handle`.
    ///
    /// # Errors
    ///
    /// * `ExpressionError::MissingCapabilities` when a binding array is
    ///   indexed with a non-uniform index and `capabilities` lacks the
    ///   needed non-uniform-indexing capability.
    /// * `ExpressionError::ExpectedGlobalOrArgument` (from
    ///   `GlobalOrArgument::from_expression`) for unsupported image/sampler
    ///   operands of `ImageSample`.
    /// * Whatever error `resolve_context.resolve` yields, converted by `?`.
    // `Option::or` evaluates its argument eagerly; throughout this function
    // that is relied upon so every operand's `add_ref` runs regardless of
    // earlier results, hence the lint is silenced.
    #[allow(clippy::or_fun_call)]
    fn process_expression(
        &mut self,
        handle: Handle<crate::Expression>,
        expression_arena: &Arena<crate::Expression>,
        other_functions: &[FunctionInfo],
        resolve_context: &ResolveContext,
        capabilities: super::Capabilities,
    ) -> Result<(), ExpressionError> {
        use crate::{Expression as E, SampleLevel as Sl};

        let expression = &expression_arena[handle];
        // Filled in by `GlobalVariable`, or propagated from the base of
        // `Access` / `AccessIndex` through `add_assignable_ref`.
        let mut assignable_global = None;
        let uniformity = match *expression {
            E::Access { base, index } => {
                let base_ty = self[base].ty.inner_with(resolve_context.types);

                // Capabilities required if the index turns out non-uniform.
                let mut needed_caps = super::Capabilities::empty();
                let is_binding_array = match *base_ty {
                    crate::TypeInner::BindingArray {
                        base: array_element_ty_handle,
                        ..
                    } => {
                        let ub_st = super::Capabilities::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING;
                        let st_sb = super::Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING;
                        let sampler = super::Capabilities::SAMPLER_NON_UNIFORM_INDEXING;

                        let array_element_ty =
                            &resolve_context.types[array_element_ty_handle].inner;

                        needed_caps |= match *array_element_ty {
                            crate::TypeInner::Image { class, .. } => match class {
                                crate::ImageClass::Storage { .. } => ub_st,
                                _ => st_sb,
                            },
                            crate::TypeInner::Sampler { .. } => sampler,
                            _ => {
                                // Non-image, non-sampler element: decide by
                                // the address space of the base global.
                                if let E::GlobalVariable(global_handle) = expression_arena[base] {
                                    let global = &resolve_context.global_vars[global_handle];
                                    match global.space {
                                        crate::AddressSpace::Uniform => ub_st,
                                        crate::AddressSpace::Storage { .. } => st_sb,
                                        _ => unreachable!(),
                                    }
                                } else {
                                    unreachable!()
                                }
                            }
                        };

                        true
                    }
                    _ => false,
                };

                // Non-uniform indexing of a binding array needs the matching
                // capability.
                if self[index].uniformity.non_uniform_result.is_some()
                    && !capabilities.contains(needed_caps)
                    && is_binding_array
                {
                    return Err(ExpressionError::MissingCapabilities(needed_caps));
                }

                Uniformity {
                    non_uniform_result: self
                        .add_assignable_ref(base, &mut assignable_global)
                        .or(self.add_ref(index)),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::AccessIndex { base, .. } => Uniformity {
                non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
                requirements: UniformityRequirements::empty(),
            },
            E::Splat { size: _, value } => Uniformity {
                non_uniform_result: self.add_ref(value),
                requirements: UniformityRequirements::empty(),
            },
            E::Swizzle { vector, .. } => Uniformity {
                non_uniform_result: self.add_ref(vector),
                requirements: UniformityRequirements::empty(),
            },
            // Operand-free constant-like expressions: neutral uniformity.
            E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(),
            E::Compose { ref components, .. } => {
                // Every component is referenced; the first non-uniform one
                // found is kept as the result.
                let non_uniform_result = components
                    .iter()
                    .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
                Uniformity {
                    non_uniform_result,
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::FunctionArgument(index) => {
                let arg = &resolve_context.arguments[index as usize];
                // Only these workgroup-related built-ins are treated as
                // uniform; every other argument is its own non-uniform source.
                let uniform = match arg.binding {
                    Some(crate::Binding::BuiltIn(
                        crate::BuiltIn::WorkGroupId
                        | crate::BuiltIn::WorkGroupSize
                        | crate::BuiltIn::NumWorkGroups,
                    )) => true,
                    _ => false,
                };
                Uniformity {
                    non_uniform_result: if uniform { None } else { Some(handle) },
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::GlobalVariable(gh) => {
                use crate::AddressSpace as As;
                assignable_global = Some(gh);
                let var = &resolve_context.global_vars[gh];
                // Uniformity is decided purely by address space; storage
                // counts as uniform only when it is not writable.
                let uniform = match var.space {
                    As::Function | As::Private => false,
                    As::WorkGroup => true,
                    As::Uniform | As::PushConstant => true,
                    As::Storage { access } => !access.contains(crate::StorageAccess::STORE),
                    As::Handle => false,
                };
                Uniformity {
                    non_uniform_result: if uniform { None } else { Some(handle) },
                    requirements: UniformityRequirements::empty(),
                }
            }
            // Local variables are always their own non-uniform source.
            E::LocalVariable(_) => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::Load { pointer } => Uniformity {
                non_uniform_result: self.add_ref(pointer),
                requirements: UniformityRequirements::empty(),
            },
            E::ImageSample {
                image,
                sampler,
                gather: _,
                coordinate,
                array_index,
                offset: _,
                level,
                depth_ref,
            } => {
                let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?;
                let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?;

                // Fully-global pairs go to `sampling_set`; anything involving
                // a function argument goes to `sampling`.
                match (image_storage, sampler_storage) {
                    (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
                        self.sampling_set.insert(SamplingKey { image, sampler });
                    }
                    _ => {
                        self.sampling.insert(Sampling {
                            image: image_storage,
                            sampler: sampler_storage,
                        });
                    }
                }

                // "nur" is short for "non-uniform result".
                let array_nur = array_index.and_then(|h| self.add_ref(h));
                let level_nur = match level {
                    Sl::Auto | Sl::Zero => None,
                    Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
                    Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
                };
                let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self
                        .add_ref(image)
                        .or(self.add_ref(sampler))
                        .or(self.add_ref(coordinate))
                        .or(array_nur)
                        .or(level_nur)
                        .or(dref_nur),
                    requirements: if level.implicit_derivatives() {
                        UniformityRequirements::IMPLICIT_LEVEL
                    } else {
                        UniformityRequirements::empty()
                    },
                }
            }
            E::ImageLoad {
                image,
                coordinate,
                array_index,
                sample,
                level,
            } => {
                let array_nur = array_index.and_then(|h| self.add_ref(h));
                let sample_nur = sample.and_then(|h| self.add_ref(h));
                let level_nur = level.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self
                        .add_ref(image)
                        .or(self.add_ref(coordinate))
                        .or(array_nur)
                        .or(sample_nur)
                        .or(level_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::ImageQuery { image, query } => {
                // Only a size query with an explicit level has an operand.
                let query_nur = match query {
                    crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
                    _ => None,
                };
                Uniformity {
                    // The image is marked as queried rather than read.
                    non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::Unary { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::empty(),
            },
            E::Binary { left, right, .. } => Uniformity {
                non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
                requirements: UniformityRequirements::empty(),
            },
            E::Select {
                condition,
                accept,
                reject,
            } => Uniformity {
                non_uniform_result: self
                    .add_ref(condition)
                    .or(self.add_ref(accept))
                    .or(self.add_ref(reject)),
                requirements: UniformityRequirements::empty(),
            },
            E::Derivative { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::DERIVATIVE,
            },
            E::Relational { argument, .. } => Uniformity {
                non_uniform_result: self.add_ref(argument),
                requirements: UniformityRequirements::empty(),
            },
            E::Math {
                fun: _,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                let arg1_nur = arg1.and_then(|h| self.add_ref(h));
                let arg2_nur = arg2.and_then(|h| self.add_ref(h));
                let arg3_nur = arg3.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur).or(arg3_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::As { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::empty(),
            },
            // A call result inherits the callee's whole-function uniformity.
            E::CallResult(function) => other_functions[function.index()].uniformity.clone(),
            E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::WorkGroupUniformLoadResult { .. } => Uniformity {
                non_uniform_result: None,
                requirements: UniformityRequirements::empty(),
            },
            E::ArrayLength(expr) => Uniformity {
                // The array is marked as queried rather than read.
                non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
                requirements: UniformityRequirements::empty(),
            },
            E::RayQueryGetIntersection {
                query,
                committed: _,
            } => Uniformity {
                non_uniform_result: self.add_ref(query),
                requirements: UniformityRequirements::empty(),
            },
            E::SubgroupBallotResult => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::SubgroupOperationResult { .. } => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
        };

        // Operand types are looked up in the entries already stored in `self`.
        let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
        // The stored entry starts with a zero reference count; later users of
        // this expression increment it.
        self.expressions[handle.index()] = ExpressionInfo {
            uniformity,
            ref_count: 0,
            assignable_global,
            ty,
        };
        Ok(())
    }
820
    /// Analyzes a block (a sequence of statements) and returns its combined
    /// [`FunctionUniformity`].
    ///
    /// `disruptor` describes why control flow entering the block is already
    /// non-uniform, or is `None` when no disruption has been recorded. It is
    /// updated after each statement with that statement's exit disruptor, and
    /// is extended with the condition/selector when descending into `If` and
    /// `Switch` bodies.
    ///
    /// Besides computing uniformity, this records a reference (and global
    /// usage flags) for every expression operand used by the statements, and
    /// merges callee information for `Call` statements via `process_call`.
    ///
    /// # Errors
    ///
    /// * For `Emit`: when `ValidationFlags::CONTROL_FLOW_UNIFORMITY` is set,
    ///   an emitted expression has requirements, and a disruptor is present,
    ///   a `FunctionError::NonUniformControlFlow` diagnostic is handed to
    ///   `report_diag` with the severity found in the diagnostic filter tree;
    ///   any error that call returns is propagated.
    /// * Errors from nested blocks and from `process_call` are propagated.
    // `Option::or` evaluates its argument eagerly, so the lint about
    // function calls inside `or` is silenced for this function.
    #[allow(clippy::or_fun_call)]
    fn process_block(
        &mut self,
        statements: &crate::Block,
        other_functions: &[FunctionInfo],
        mut disruptor: Option<UniformityDisruptor>,
        expression_arena: &Arena<crate::Expression>,
        diagnostic_filter_arena: &Arena<DiagnosticFilterNode>,
    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
        use crate::Statement as S;

        let mut combined_uniformity = FunctionUniformity::new();
        for statement in statements {
            let uniformity = match *statement {
                S::Emit(ref range) => {
                    // Collect the requirements of all emitted expressions and
                    // report those that run under a disruptor.
                    let mut requirements = UniformityRequirements::empty();
                    for expr in range.clone() {
                        let req = self.expressions[expr.index()].uniformity.requirements;
                        if self
                            .flags
                            .contains(ValidationFlags::CONTROL_FLOW_UNIFORMITY)
                            && !req.is_empty()
                        {
                            if let Some(cause) = disruptor {
                                let severity = DiagnosticFilterNode::search(
                                    self.diagnostic_filter_leaf,
                                    diagnostic_filter_arena,
                                    StandardFilterableTriggeringRule::DerivativeUniformity,
                                );
                                severity.report_diag(
                                    FunctionError::NonUniformControlFlow(req, expr, cause)
                                        .with_span_handle(expr, expression_arena),
                                    // The callback logs the diagnostic at the
                                    // level it is given.
                                    |e, level| log::log!(level, "{e}"),
                                )?;
                            }
                        }
                        requirements |= req;
                    }
                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                S::Break | S::Continue => FunctionUniformity::new(),
                // A kill only counts as a possible early exit when control
                // flow is already disrupted.
                S::Kill => FunctionUniformity {
                    result: Uniformity::new(),
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_KILL
                    } else {
                        ExitFlags::empty()
                    },
                },
                S::Barrier(_) => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                    },
                    exit: ExitFlags::empty(),
                },
                S::WorkGroupUniformLoad { pointer, .. } => {
                    // The pointer is referenced, but its uniformity is
                    // deliberately not used in the result.
                    let _condition_nur = self.add_ref(pointer);

                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                S::Block(ref b) => self.process_block(
                    b,
                    other_functions,
                    disruptor,
                    expression_arena,
                    diagnostic_filter_arena,
                )?,
                S::If {
                    condition,
                    ref accept,
                    ref reject,
                } => {
                    let condition_nur = self.add_ref(condition);
                    // An existing disruptor takes precedence; otherwise a
                    // non-uniform condition becomes the disruptor for both
                    // branches.
                    let branch_disruptor =
                        disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
                    let accept_uniformity = self.process_block(
                        accept,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    let reject_uniformity = self.process_block(
                        reject,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    accept_uniformity | reject_uniformity
                }
                S::Switch {
                    selector,
                    ref cases,
                } => {
                    let selector_nur = self.add_ref(selector);
                    let branch_disruptor =
                        disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
                    let mut uniformity = FunctionUniformity::new();
                    let mut case_disruptor = branch_disruptor;
                    for case in cases.iter() {
                        let case_uniformity = self.process_block(
                            &case.body,
                            other_functions,
                            case_disruptor,
                            expression_arena,
                            diagnostic_filter_arena,
                        )?;
                        // A fall-through case carries its exit disruptor into
                        // the next case; otherwise the next case starts again
                        // from the switch-level disruptor.
                        case_disruptor = if case.fall_through {
                            case_disruptor.or(case_uniformity.exit_disruptor())
                        } else {
                            branch_disruptor
                        };
                        uniformity = uniformity | case_uniformity;
                    }
                    uniformity
                }
                S::Loop {
                    ref body,
                    ref continuing,
                    break_if,
                } => {
                    let body_uniformity = self.process_block(
                        body,
                        other_functions,
                        disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    // Early exits in the body disrupt the continuing block.
                    let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
                    let continuing_uniformity = self.process_block(
                        continuing,
                        other_functions,
                        continuing_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    // `break_if` is only referenced; its uniformity is not
                    // folded into the result.
                    if let Some(expr) = break_if {
                        let _ = self.add_ref(expr);
                    }
                    body_uniformity | continuing_uniformity
                }
                // The returned value's non-uniform source becomes the
                // statement's result; the return only counts as an early exit
                // when control flow is already disrupted.
                S::Return { value } => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
                        requirements: UniformityRequirements::empty(),
                    },
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_RETURN
                    } else {
                        ExitFlags::empty()
                    },
                },
                S::Store { pointer, value } => {
                    // The destination is marked as written, the value as read.
                    let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::ImageStore {
                    image,
                    coordinate,
                    array_index,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::WRITE);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(coordinate);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::Call {
                    function,
                    ref arguments,
                    result: _,
                } => {
                    for &argument in arguments {
                        let _ = self.add_ref(argument);
                    }
                    let info = &other_functions[function.index()];
                    self.process_call(info, arguments, expression_arena)?
                }
                S::Atomic {
                    pointer,
                    ref fun,
                    value,
                    result: _,
                } => {
                    // An atomic counts as both a read and a write of its
                    // target.
                    let _ = self.add_ref_impl(pointer, GlobalUse::READ | GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
                        let _ = self.add_ref(cmp);
                    }
                    FunctionUniformity::new()
                }
                S::ImageAtomic {
                    image,
                    coordinate,
                    array_index,
                    fun: _,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::ATOMIC);
                    let _ = self.add_ref(coordinate);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::RayQuery { query, ref fun } => {
                    let _ = self.add_ref(query);
                    // Only `Initialize` carries additional operands here.
                    if let crate::RayQueryFunction::Initialize {
                        acceleration_structure,
                        descriptor,
                    } = *fun
                    {
                        let _ = self.add_ref(acceleration_structure);
                        let _ = self.add_ref(descriptor);
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupBallot {
                    result: _,
                    predicate,
                } => {
                    if let Some(predicate) = predicate {
                        let _ = self.add_ref(predicate);
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupCollectiveOperation {
                    op: _,
                    collective_op: _,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    FunctionUniformity::new()
                }
                S::SubgroupGather {
                    mode,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    // Every mode except `BroadcastFirst` has an index operand.
                    match mode {
                        crate::GatherMode::BroadcastFirst => {}
                        crate::GatherMode::Broadcast(index)
                        | crate::GatherMode::Shuffle(index)
                        | crate::GatherMode::ShuffleDown(index)
                        | crate::GatherMode::ShuffleUp(index)
                        | crate::GatherMode::ShuffleXor(index) => {
                            let _ = self.add_ref(index);
                        }
                    }
                    FunctionUniformity::new()
                }
            };

            // A statement that may exit early disrupts everything after it.
            disruptor = disruptor.or(uniformity.exit_disruptor());
            combined_uniformity = combined_uniformity | uniformity;
        }
        Ok(combined_uniformity)
    }
1136}
1137
1138impl ModuleInfo {
1139 pub(super) fn process_const_expression(
1141 &mut self,
1142 handle: Handle<crate::Expression>,
1143 resolve_context: &ResolveContext,
1144 gctx: crate::proc::GlobalCtx,
1145 ) -> Result<(), super::ConstExpressionError> {
1146 self.const_expression_types[handle.index()] =
1147 resolve_context.resolve(&gctx.global_expressions[handle], |h| Ok(&self[h]))?;
1148 Ok(())
1149 }
1150
1151 pub(super) fn process_function(
1154 &self,
1155 fun: &crate::Function,
1156 module: &crate::Module,
1157 flags: ValidationFlags,
1158 capabilities: super::Capabilities,
1159 ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
1160 let mut info = FunctionInfo {
1161 flags,
1162 available_stages: ShaderStages::all(),
1163 uniformity: Uniformity::new(),
1164 may_kill: false,
1165 sampling_set: crate::FastHashSet::default(),
1166 global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
1167 expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
1168 sampling: crate::FastHashSet::default(),
1169 dual_source_blending: false,
1170 diagnostic_filter_leaf: fun.diagnostic_filter_leaf,
1171 };
1172 let resolve_context =
1173 ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments);
1174
1175 for (handle, _) in fun.expressions.iter() {
1176 if let Err(source) = info.process_expression(
1177 handle,
1178 &fun.expressions,
1179 &self.functions,
1180 &resolve_context,
1181 capabilities,
1182 ) {
1183 return Err(FunctionError::Expression { handle, source }
1184 .with_span_handle(handle, &fun.expressions));
1185 }
1186 }
1187
1188 for (_, expr) in fun.local_variables.iter() {
1189 if let Some(init) = expr.init {
1190 let _ = info.add_ref(init);
1191 }
1192 }
1193
1194 let uniformity = info.process_block(
1195 &fun.body,
1196 &self.functions,
1197 None,
1198 &fun.expressions,
1199 &module.diagnostic_filters,
1200 )?;
1201 info.uniformity = uniformity.result;
1202 info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
1203
1204 Ok(info)
1205 }
1206
1207 pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
1208 &self.entry_points[index]
1209 }
1210}
1211
1212#[test]
1213fn uniform_control_flow() {
1214 use crate::{Expression as E, Statement as S};
1215
1216 let mut type_arena = crate::UniqueArena::new();
1217 let ty = type_arena.insert(
1218 crate::Type {
1219 name: None,
1220 inner: crate::TypeInner::Vector {
1221 size: crate::VectorSize::Bi,
1222 scalar: crate::Scalar::F32,
1223 },
1224 },
1225 Default::default(),
1226 );
1227 let mut global_var_arena = Arena::new();
1228 let non_uniform_global = global_var_arena.append(
1229 crate::GlobalVariable {
1230 name: None,
1231 init: None,
1232 ty,
1233 space: crate::AddressSpace::Handle,
1234 binding: None,
1235 },
1236 Default::default(),
1237 );
1238 let uniform_global = global_var_arena.append(
1239 crate::GlobalVariable {
1240 name: None,
1241 init: None,
1242 ty,
1243 binding: None,
1244 space: crate::AddressSpace::Uniform,
1245 },
1246 Default::default(),
1247 );
1248
1249 let mut expressions = Arena::new();
1250 let constant_expr = expressions.append(E::Literal(crate::Literal::U32(0)), Default::default());
1252 let derivative_expr = expressions.append(
1254 E::Derivative {
1255 axis: crate::DerivativeAxis::X,
1256 ctrl: crate::DerivativeControl::None,
1257 expr: constant_expr,
1258 },
1259 Default::default(),
1260 );
1261 let emit_range_constant_derivative = expressions.range_from(0);
1262 let non_uniform_global_expr =
1263 expressions.append(E::GlobalVariable(non_uniform_global), Default::default());
1264 let uniform_global_expr =
1265 expressions.append(E::GlobalVariable(uniform_global), Default::default());
1266 let emit_range_globals = expressions.range_from(2);
1267
1268 let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default());
1270 let access_expr = expressions.append(
1272 E::AccessIndex {
1273 base: non_uniform_global_expr,
1274 index: 1,
1275 },
1276 Default::default(),
1277 );
1278 let emit_range_query_access_globals = expressions.range_from(2);
1279
1280 let mut info = FunctionInfo {
1281 flags: ValidationFlags::all(),
1282 available_stages: ShaderStages::all(),
1283 uniformity: Uniformity::new(),
1284 may_kill: false,
1285 sampling_set: crate::FastHashSet::default(),
1286 global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
1287 expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
1288 sampling: crate::FastHashSet::default(),
1289 dual_source_blending: false,
1290 diagnostic_filter_leaf: None,
1291 };
1292 let resolve_context = ResolveContext {
1293 constants: &Arena::new(),
1294 overrides: &Arena::new(),
1295 types: &type_arena,
1296 special_types: &crate::SpecialTypes::default(),
1297 global_vars: &global_var_arena,
1298 local_vars: &Arena::new(),
1299 functions: &Arena::new(),
1300 arguments: &[],
1301 };
1302 for (handle, _) in expressions.iter() {
1303 info.process_expression(
1304 handle,
1305 &expressions,
1306 &[],
1307 &resolve_context,
1308 super::Capabilities::empty(),
1309 )
1310 .unwrap();
1311 }
1312 assert_eq!(info[non_uniform_global_expr].ref_count, 1);
1313 assert_eq!(info[uniform_global_expr].ref_count, 1);
1314 assert_eq!(info[query_expr].ref_count, 0);
1315 assert_eq!(info[access_expr].ref_count, 0);
1316 assert_eq!(info[non_uniform_global], GlobalUse::empty());
1317 assert_eq!(info[uniform_global], GlobalUse::QUERY);
1318
1319 let stmt_emit1 = S::Emit(emit_range_globals.clone());
1320 let stmt_if_uniform = S::If {
1321 condition: uniform_global_expr,
1322 accept: crate::Block::new(),
1323 reject: vec![
1324 S::Emit(emit_range_constant_derivative.clone()),
1325 S::Store {
1326 pointer: constant_expr,
1327 value: derivative_expr,
1328 },
1329 ]
1330 .into(),
1331 };
1332 assert_eq!(
1333 info.process_block(
1334 &vec![stmt_emit1, stmt_if_uniform].into(),
1335 &[],
1336 None,
1337 &expressions,
1338 &Arena::new(),
1339 ),
1340 Ok(FunctionUniformity {
1341 result: Uniformity {
1342 non_uniform_result: None,
1343 requirements: UniformityRequirements::DERIVATIVE,
1344 },
1345 exit: ExitFlags::empty(),
1346 }),
1347 );
1348 assert_eq!(info[constant_expr].ref_count, 2);
1349 assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);
1350
1351 let stmt_emit2 = S::Emit(emit_range_globals.clone());
1352 let stmt_if_non_uniform = S::If {
1353 condition: non_uniform_global_expr,
1354 accept: vec![
1355 S::Emit(emit_range_constant_derivative),
1356 S::Store {
1357 pointer: constant_expr,
1358 value: derivative_expr,
1359 },
1360 ]
1361 .into(),
1362 reject: crate::Block::new(),
1363 };
1364 {
1365 let block_info = info.process_block(
1366 &vec![stmt_emit2.clone(), stmt_if_non_uniform.clone()].into(),
1367 &[],
1368 None,
1369 &expressions,
1370 &Arena::new(),
1371 );
1372 if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE {
1373 assert_eq!(info[derivative_expr].ref_count, 2);
1374 } else {
1375 assert_eq!(
1376 block_info,
1377 Err(FunctionError::NonUniformControlFlow(
1378 UniformityRequirements::DERIVATIVE,
1379 derivative_expr,
1380 UniformityDisruptor::Expression(non_uniform_global_expr)
1381 )
1382 .with_span()),
1383 );
1384 assert_eq!(info[derivative_expr].ref_count, 1);
1385
1386 let mut diagnostic_filters = Arena::new();
1388 let diagnostic_filter_leaf = diagnostic_filters.append(
1389 DiagnosticFilterNode {
1390 inner: crate::diagnostic_filter::DiagnosticFilter {
1391 new_severity: crate::diagnostic_filter::Severity::Off,
1392 triggering_rule:
1393 crate::diagnostic_filter::FilterableTriggeringRule::Standard(
1394 StandardFilterableTriggeringRule::DerivativeUniformity,
1395 ),
1396 },
1397 parent: None,
1398 },
1399 crate::Span::default(),
1400 );
1401 let mut info = FunctionInfo {
1402 diagnostic_filter_leaf: Some(diagnostic_filter_leaf),
1403 ..info.clone()
1404 };
1405
1406 let block_info = info.process_block(
1407 &vec![stmt_emit2, stmt_if_non_uniform].into(),
1408 &[],
1409 None,
1410 &expressions,
1411 &diagnostic_filters,
1412 );
1413 assert_eq!(
1414 block_info,
1415 Ok(FunctionUniformity {
1416 result: Uniformity {
1417 non_uniform_result: None,
1418 requirements: UniformityRequirements::DERIVATIVE,
1419 },
1420 exit: ExitFlags::empty()
1421 }),
1422 );
1423 assert_eq!(info[derivative_expr].ref_count, 2);
1424 }
1425 }
1426 assert_eq!(info[non_uniform_global], GlobalUse::READ);
1427
1428 let stmt_emit3 = S::Emit(emit_range_globals);
1429 let stmt_return_non_uniform = S::Return {
1430 value: Some(non_uniform_global_expr),
1431 };
1432 assert_eq!(
1433 info.process_block(
1434 &vec![stmt_emit3, stmt_return_non_uniform].into(),
1435 &[],
1436 Some(UniformityDisruptor::Return),
1437 &expressions,
1438 &Arena::new(),
1439 ),
1440 Ok(FunctionUniformity {
1441 result: Uniformity {
1442 non_uniform_result: Some(non_uniform_global_expr),
1443 requirements: UniformityRequirements::empty(),
1444 },
1445 exit: ExitFlags::MAY_RETURN,
1446 }),
1447 );
1448 assert_eq!(info[non_uniform_global_expr].ref_count, 3);
1449
1450 let stmt_emit4 = S::Emit(emit_range_query_access_globals);
1452 let stmt_assign = S::Store {
1453 pointer: access_expr,
1454 value: query_expr,
1455 };
1456 let stmt_return_pointer = S::Return {
1457 value: Some(access_expr),
1458 };
1459 let stmt_kill = S::Kill;
1460 assert_eq!(
1461 info.process_block(
1462 &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(),
1463 &[],
1464 Some(UniformityDisruptor::Discard),
1465 &expressions,
1466 &Arena::new(),
1467 ),
1468 Ok(FunctionUniformity {
1469 result: Uniformity {
1470 non_uniform_result: Some(non_uniform_global_expr),
1471 requirements: UniformityRequirements::empty(),
1472 },
1473 exit: ExitFlags::all(),
1474 }),
1475 );
1476 assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
1477}