1use super::{ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
9use crate::span::{AddSpan as _, WithSpan};
10use crate::{
11 arena::{Arena, Handle},
12 proc::{ResolveContext, TypeResolution},
13};
14use std::ops;
15
/// Handle of the expression blamed for a non-uniform value, or `None` when the
/// value is considered uniform.
pub type NonUniformResult = Option<Handle<crate::Expression>>;
17
/// When `true`, the `DERIVATIVE` and `IMPLICIT_LEVEL` flags of
/// [`UniformityRequirements`] are defined as `0`, i.e. they become empty flags
/// and can never trigger a uniformity requirement.
const DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE: bool = true;
21
bitflags::bitflags! {
    /// Kinds of operations that demand uniform control flow at the point where
    /// they are executed.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct UniformityRequirements: u8 {
        /// Attached to barrier and workgroup-uniform-load statements.
        const WORK_GROUP_BARRIER = 0x1;
        /// Attached to derivative expressions; collapses to the empty flag
        /// while `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is `true`.
        const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 };
        /// Attached to image samples whose level uses implicit derivatives;
        /// collapses to the empty flag while
        /// `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is `true`.
        const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 };
    }
}
33
/// Uniformity summary of an expression, statement block or function.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(test, derive(PartialEq))]
pub struct Uniformity {
    /// The expression that makes the value non-uniform, if there is one.
    /// `None` means the value is treated as uniform.
    pub non_uniform_result: NonUniformResult,
    /// Requirements that executing this item places on the uniformity of the
    /// control flow reaching it.
    pub requirements: UniformityRequirements,
}
55
56impl Uniformity {
57 const fn new() -> Self {
58 Uniformity {
59 non_uniform_result: None,
60 requirements: UniformityRequirements::empty(),
61 }
62 }
63}
64
bitflags::bitflags! {
    /// Ways in which a block may leave the function early while a uniformity
    /// disruptor is already in effect.
    #[derive(Clone, Copy, Debug, PartialEq)]
    struct ExitFlags: u8 {
        /// A `Return` statement was seen under a disruptor.
        const MAY_RETURN = 0x1;
        /// A `Kill` statement was seen under a disruptor, or a callee that may
        /// kill was invoked.
        const MAY_KILL = 0x2;
    }
}
79
/// Uniformity of a statement or block together with its early-exit behavior.
#[cfg_attr(test, derive(Debug, PartialEq))]
struct FunctionUniformity {
    /// Uniformity of the produced value and accumulated requirements.
    result: Uniformity,
    /// Early exits (return / kill) that may happen inside.
    exit: ExitFlags,
}
86
87impl ops::BitOr for FunctionUniformity {
88 type Output = Self;
89 fn bitor(self, other: Self) -> Self {
90 FunctionUniformity {
91 result: Uniformity {
92 non_uniform_result: self
93 .result
94 .non_uniform_result
95 .or(other.result.non_uniform_result),
96 requirements: self.result.requirements | other.result.requirements,
97 },
98 exit: self.exit | other.exit,
99 }
100 }
101}
102
103impl FunctionUniformity {
104 const fn new() -> Self {
105 FunctionUniformity {
106 result: Uniformity::new(),
107 exit: ExitFlags::empty(),
108 }
109 }
110
111 const fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
113 if self.exit.contains(ExitFlags::MAY_RETURN) {
114 Some(UniformityDisruptor::Return)
115 } else if self.exit.contains(ExitFlags::MAY_KILL) {
116 Some(UniformityDisruptor::Discard)
117 } else {
118 None
119 }
120 }
121}
122
bitflags::bitflags! {
    /// How a function uses a particular global variable.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct GlobalUse: u8 {
        /// The global is referenced as an ordinary operand (see `add_ref`).
        const READ = 0x1;
        /// The global is the target of a store, image store or atomic.
        const WRITE = 0x2;
        /// The global is only inspected, e.g. by an image query or an array
        /// length expression.
        const QUERY = 0x4;
    }
}
137
/// An (image, sampler) pair of global variables that are sampled together.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct SamplingKey {
    /// Global variable holding the image.
    pub image: Handle<crate::GlobalVariable>,
    /// Global variable holding the sampler.
    pub sampler: Handle<crate::GlobalVariable>,
}
145
/// Analysis results for a single expression of a function.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ExpressionInfo {
    /// Uniformity of the expression's value and the requirements it imposes.
    pub uniformity: Uniformity,

    /// Number of references to this expression recorded by the analysis
    /// (incremented by the `add_ref*` helpers).
    pub ref_count: usize,

    /// The global variable this expression refers to, if any. Set for
    /// `GlobalVariable` expressions and carried through `Access` /
    /// `AccessIndex` chains rooted at one.
    assignable_global: Option<Handle<crate::GlobalVariable>>,

    /// Resolved type of the expression.
    pub ty: TypeResolution,
}
180
181impl ExpressionInfo {
182 const fn new() -> Self {
183 ExpressionInfo {
184 uniformity: Uniformity::new(),
185 ref_count: 0,
186 assignable_global: None,
187 ty: TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
189 kind: crate::ScalarKind::Bool,
190 width: 0,
191 })),
192 }
193 }
194}
195
/// Origin of an image or sampler operand: either a module-level global
/// variable or one of the enclosing function's arguments.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
enum GlobalOrArgument {
    /// The operand is this global variable (possibly indexed into).
    Global(Handle<crate::GlobalVariable>),
    /// The operand is the function argument with this index.
    Argument(u32),
}
203
204impl GlobalOrArgument {
205 fn from_expression(
206 expression_arena: &Arena<crate::Expression>,
207 expression: Handle<crate::Expression>,
208 ) -> Result<GlobalOrArgument, ExpressionError> {
209 Ok(match expression_arena[expression] {
210 crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
211 crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
212 crate::Expression::Access { base, .. }
213 | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
214 crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
215 _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
216 },
217 _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
218 })
219 }
220}
221
/// An (image, sampler) pair used for sampling where at least one side is a
/// function argument, so it cannot be expressed as a [`SamplingKey`] yet.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
struct Sampling {
    /// Where the image operand comes from.
    image: GlobalOrArgument,
    /// Where the sampler operand comes from.
    sampler: GlobalOrArgument,
}
229
/// Analysis results for one function.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct FunctionInfo {
    /// Validation flags the analysis was run with; `process_block` consults
    /// `CONTROL_FLOW_UNIFORMITY` from here.
    #[allow(dead_code)]
    flags: ValidationFlags,
    /// Shader stages this function may be used in. Initialized to all stages
    /// by `process_function`.
    // NOTE(review): presumably narrowed by validation code outside this file — confirm.
    pub available_stages: ShaderStages,
    /// Uniformity of the function's result, taken from its body block.
    pub uniformity: Uniformity,
    /// Whether the function body may execute a kill under a disruptor (or
    /// calls a function that may kill).
    pub may_kill: bool,

    /// All (image, sampler) global pairs sampled by this function, including
    /// those inherited from callees.
    pub sampling_set: crate::FastHashSet<SamplingKey>,

    /// Per-global usage flags, indexed by the global variable's handle index.
    global_uses: Box<[GlobalUse]>,

    /// Per-expression analysis results, indexed by the expression's handle
    /// index.
    expressions: Box<[ExpressionInfo]>,

    /// Sampling operations that involve at least one function argument and
    /// therefore still have to be resolved by callers (see `process_call`).
    sampling: crate::FastHashSet<Sampling>,

    /// Initialized to `false` by `process_function`.
    // NOTE(review): presumably set by entry-point validation elsewhere — confirm.
    pub dual_source_blending: bool,
}
293
294impl FunctionInfo {
295 pub const fn global_variable_count(&self) -> usize {
296 self.global_uses.len()
297 }
298 pub const fn expression_count(&self) -> usize {
299 self.expressions.len()
300 }
301 pub fn dominates_global_use(&self, other: &Self) -> bool {
302 for (self_global_uses, other_global_uses) in
303 self.global_uses.iter().zip(other.global_uses.iter())
304 {
305 if !self_global_uses.contains(*other_global_uses) {
306 return false;
307 }
308 }
309 true
310 }
311}
312
313impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
314 type Output = GlobalUse;
315 fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
316 &self.global_uses[handle.index()]
317 }
318}
319
320impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
321 type Output = ExpressionInfo;
322 fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
323 &self.expressions[handle.index()]
324 }
325}
326
/// The reason control flow is no longer considered uniform.
#[derive(Clone, Copy, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum UniformityDisruptor {
    /// A branch condition or switch selector depends on this non-uniform
    /// expression.
    #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
    Expression(Handle<crate::Expression>),
    /// A return may already have been taken under a disruptor.
    #[error("There is a Return earlier in the control flow of the function")]
    Return,
    /// A kill (discard) may already have been executed under a disruptor.
    #[error("There is a Discard earlier in the entry point across all called functions")]
    Discard,
}
338
339impl FunctionInfo {
340 #[must_use]
348 fn add_ref_impl(
349 &mut self,
350 expr: Handle<crate::Expression>,
351 global_use: GlobalUse,
352 ) -> NonUniformResult {
353 let info = &mut self.expressions[expr.index()];
354 info.ref_count += 1;
355 if let Some(global) = info.assignable_global {
357 self.global_uses[global.index()] |= global_use;
358 }
359 info.uniformity.non_uniform_result
360 }
361
362 #[must_use]
369 fn add_ref(&mut self, expr: Handle<crate::Expression>) -> NonUniformResult {
370 self.add_ref_impl(expr, GlobalUse::READ)
371 }
372
373 #[must_use]
392 fn add_assignable_ref(
393 &mut self,
394 expr: Handle<crate::Expression>,
395 assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
396 ) -> NonUniformResult {
397 let info = &mut self.expressions[expr.index()];
398 info.ref_count += 1;
399 if let Some(global) = info.assignable_global {
402 if let Some(_old) = assignable_global.replace(global) {
403 unreachable!()
404 }
405 }
406 info.uniformity.non_uniform_result
407 }
408
409 fn process_call(
411 &mut self,
412 callee: &Self,
413 arguments: &[Handle<crate::Expression>],
414 expression_arena: &Arena<crate::Expression>,
415 ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
416 self.sampling_set
417 .extend(callee.sampling_set.iter().cloned());
418 for sampling in callee.sampling.iter() {
419 let image_storage = match sampling.image {
422 GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
423 GlobalOrArgument::Argument(i) => {
424 let handle = arguments[i as usize];
425 GlobalOrArgument::from_expression(expression_arena, handle).map_err(
426 |source| {
427 FunctionError::Expression { handle, source }
428 .with_span_handle(handle, expression_arena)
429 },
430 )?
431 }
432 };
433
434 let sampler_storage = match sampling.sampler {
435 GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
436 GlobalOrArgument::Argument(i) => {
437 let handle = arguments[i as usize];
438 GlobalOrArgument::from_expression(expression_arena, handle).map_err(
439 |source| {
440 FunctionError::Expression { handle, source }
441 .with_span_handle(handle, expression_arena)
442 },
443 )?
444 }
445 };
446
447 match (image_storage, sampler_storage) {
452 (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
453 self.sampling_set.insert(SamplingKey { image, sampler });
454 }
455 (image, sampler) => {
456 self.sampling.insert(Sampling { image, sampler });
457 }
458 }
459 }
460
461 for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) {
463 *mine |= *other;
464 }
465
466 Ok(FunctionUniformity {
467 result: callee.uniformity.clone(),
468 exit: if callee.may_kill {
469 ExitFlags::MAY_KILL
470 } else {
471 ExitFlags::empty()
472 },
473 })
474 }
475
    /// Computes the [`ExpressionInfo`] for the expression `handle` and stores
    /// it in `self.expressions`.
    ///
    /// For every operand the expression references, the operand's reference
    /// count is bumped (and global usage recorded) through the `add_ref*`
    /// helpers. The expression's own uniformity is derived from its operands;
    /// its type is resolved through `resolve_context`.
    ///
    /// `other_functions` provides the infos of functions referenced by
    /// `CallResult` expressions; `capabilities` is checked when a binding
    /// array is indexed with a non-uniform index.
    ///
    /// # Errors
    ///
    /// * `ExpressionError::MissingCapabilities` when a binding array is
    ///   indexed non-uniformly without the needed capability.
    /// * Whatever `GlobalOrArgument::from_expression` reports for the image
    ///   and sampler operands of an image sample.
    /// * Any error produced while resolving the expression's type.
    // `Option::or` evaluates its argument eagerly, which is what is wanted
    // here: every `add_ref` must run for its reference-counting side effect
    // even when an earlier operand already yielded `Some`.
    #[allow(clippy::or_fun_call)]
    fn process_expression(
        &mut self,
        handle: Handle<crate::Expression>,
        expression_arena: &Arena<crate::Expression>,
        other_functions: &[FunctionInfo],
        resolve_context: &ResolveContext,
        capabilities: super::Capabilities,
    ) -> Result<(), ExpressionError> {
        use crate::{Expression as E, SampleLevel as Sl};

        let expression = &expression_arena[handle];
        // Filled in by `GlobalVariable` directly, or inherited from the base
        // of an `Access` / `AccessIndex` via `add_assignable_ref`.
        let mut assignable_global = None;
        let uniformity = match *expression {
            E::Access { base, index } => {
                let base_ty = self[base].ty.inner_with(resolve_context.types);

                // Work out which capability a non-uniform index into this
                // base would need; only binding arrays need one.
                let mut needed_caps = super::Capabilities::empty();
                let is_binding_array = match *base_ty {
                    crate::TypeInner::BindingArray {
                        base: array_element_ty_handle,
                        ..
                    } => {
                        let ub_st = super::Capabilities::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING;
                        let st_sb = super::Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING;
                        let sampler = super::Capabilities::SAMPLER_NON_UNIFORM_INDEXING;

                        let array_element_ty =
                            &resolve_context.types[array_element_ty_handle].inner;

                        needed_caps |= match *array_element_ty {
                            // Storage images share a capability with uniform
                            // buffers; every other image class shares one
                            // with storage buffers.
                            crate::TypeInner::Image { class, .. } => match class {
                                crate::ImageClass::Storage { .. } => ub_st,
                                _ => st_sb,
                            },
                            crate::TypeInner::Sampler { .. } => sampler,
                            _ => {
                                // Any other element type is expected to be a
                                // buffer, told apart by the address space of
                                // the global the array lives in. Other shapes
                                // are treated as impossible here.
                                if let E::GlobalVariable(global_handle) = expression_arena[base] {
                                    let global = &resolve_context.global_vars[global_handle];
                                    match global.space {
                                        crate::AddressSpace::Uniform => ub_st,
                                        crate::AddressSpace::Storage { .. } => st_sb,
                                        _ => unreachable!(),
                                    }
                                } else {
                                    unreachable!()
                                }
                            }
                        };

                        true
                    }
                    _ => false,
                };

                // Reject non-uniform indexing of a binding array when the
                // required capability is missing.
                if self[index].uniformity.non_uniform_result.is_some()
                    && !capabilities.contains(needed_caps)
                    && is_binding_array
                {
                    return Err(ExpressionError::MissingCapabilities(needed_caps));
                }

                Uniformity {
                    // The base is a sub-reference (its global is inherited,
                    // not marked as used); the index is an ordinary read.
                    non_uniform_result: self
                        .add_assignable_ref(base, &mut assignable_global)
                        .or(self.add_ref(index)),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::AccessIndex { base, .. } => Uniformity {
                non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
                requirements: UniformityRequirements::empty(),
            },
            E::Splat { size: _, value } => Uniformity {
                non_uniform_result: self.add_ref(value),
                requirements: UniformityRequirements::empty(),
            },
            E::Swizzle { vector, .. } => Uniformity {
                non_uniform_result: self.add_ref(vector),
                requirements: UniformityRequirements::empty(),
            },
            // Constant-like expressions are always uniform.
            E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(),
            E::Compose { ref components, .. } => {
                // Every component is referenced; the first non-uniform one
                // (in component order) is reported as the cause.
                let non_uniform_result = components
                    .iter()
                    .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
                Uniformity {
                    non_uniform_result,
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::FunctionArgument(index) => {
                let arg = &resolve_context.arguments[index as usize];
                // Only the per-workgroup built-ins listed here are treated as
                // uniform; every other argument is non-uniform.
                let uniform = match arg.binding {
                    Some(crate::Binding::BuiltIn(
                        crate::BuiltIn::WorkGroupId
                        | crate::BuiltIn::WorkGroupSize
                        | crate::BuiltIn::NumWorkGroups,
                    )) => true,
                    _ => false,
                };
                Uniformity {
                    non_uniform_result: if uniform { None } else { Some(handle) },
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::GlobalVariable(gh) => {
                use crate::AddressSpace as As;
                assignable_global = Some(gh);
                let var = &resolve_context.global_vars[gh];
                // Uniformity is decided purely by the address space.
                let uniform = match var.space {
                    As::Function | As::Private => false,
                    As::WorkGroup => true,
                    As::Uniform | As::PushConstant => true,
                    // Storage counts as uniform only without store access.
                    As::Storage { access } => !access.contains(crate::StorageAccess::STORE),
                    As::Handle => false,
                };
                Uniformity {
                    non_uniform_result: if uniform { None } else { Some(handle) },
                    requirements: UniformityRequirements::empty(),
                }
            }
            // Local variables are always treated as non-uniform.
            E::LocalVariable(_) => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::Load { pointer } => Uniformity {
                non_uniform_result: self.add_ref(pointer),
                requirements: UniformityRequirements::empty(),
            },
            E::ImageSample {
                image,
                sampler,
                gather: _,
                coordinate,
                array_index,
                offset: _,
                level,
                depth_ref,
            } => {
                let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?;
                let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?;

                // Record which image is sampled with which sampler: fully
                // global pairs are final, pairs involving an argument are
                // kept for callers to resolve.
                match (image_storage, sampler_storage) {
                    (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
                        self.sampling_set.insert(SamplingKey { image, sampler });
                    }
                    _ => {
                        self.sampling.insert(Sampling {
                            image: image_storage,
                            sampler: sampler_storage,
                        });
                    }
                }

                let array_nur = array_index.and_then(|h| self.add_ref(h));
                let level_nur = match level {
                    Sl::Auto | Sl::Zero => None,
                    Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
                    Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
                };
                let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self
                        .add_ref(image)
                        .or(self.add_ref(sampler))
                        .or(self.add_ref(coordinate))
                        .or(array_nur)
                        .or(level_nur)
                        .or(dref_nur),
                    // Implicit-derivative sampling requires uniform control flow.
                    requirements: if level.implicit_derivatives() {
                        UniformityRequirements::IMPLICIT_LEVEL
                    } else {
                        UniformityRequirements::empty()
                    },
                }
            }
            E::ImageLoad {
                image,
                coordinate,
                array_index,
                sample,
                level,
            } => {
                let array_nur = array_index.and_then(|h| self.add_ref(h));
                let sample_nur = sample.and_then(|h| self.add_ref(h));
                let level_nur = level.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self
                        .add_ref(image)
                        .or(self.add_ref(coordinate))
                        .or(array_nur)
                        .or(sample_nur)
                        .or(level_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::ImageQuery { image, query } => {
                let query_nur = match query {
                    crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
                    _ => None,
                };
                Uniformity {
                    // The image is only queried, not read.
                    non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::Unary { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::empty(),
            },
            E::Binary { left, right, .. } => Uniformity {
                non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
                requirements: UniformityRequirements::empty(),
            },
            E::Select {
                condition,
                accept,
                reject,
            } => Uniformity {
                non_uniform_result: self
                    .add_ref(condition)
                    .or(self.add_ref(accept))
                    .or(self.add_ref(reject)),
                requirements: UniformityRequirements::empty(),
            },
            E::Derivative { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                // Derivatives require uniform control flow (unless the flag
                // is compiled down to empty).
                requirements: UniformityRequirements::DERIVATIVE,
            },
            E::Relational { argument, .. } => Uniformity {
                non_uniform_result: self.add_ref(argument),
                requirements: UniformityRequirements::empty(),
            },
            E::Math {
                fun: _,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                let arg1_nur = arg1.and_then(|h| self.add_ref(h));
                let arg2_nur = arg2.and_then(|h| self.add_ref(h));
                let arg3_nur = arg3.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur).or(arg3_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::As { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::empty(),
            },
            // A call result has the uniformity recorded for the callee.
            E::CallResult(function) => other_functions[function.index()].uniformity.clone(),
            E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            // Treated as uniform.
            E::WorkGroupUniformLoadResult { .. } => Uniformity {
                non_uniform_result: None,
                requirements: UniformityRequirements::empty(),
            },
            E::ArrayLength(expr) => Uniformity {
                // The array is only queried, not read.
                non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
                requirements: UniformityRequirements::empty(),
            },
            E::RayQueryGetIntersection {
                query,
                committed: _,
            } => Uniformity {
                non_uniform_result: self.add_ref(query),
                requirements: UniformityRequirements::empty(),
            },
            E::SubgroupBallotResult => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::SubgroupOperationResult { .. } => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
        };

        // Operand types are looked up from the infos already stored in `self`.
        let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
        // Overwrite the placeholder; references are counted from zero as
        // later expressions and statements use this one.
        self.expressions[handle.index()] = ExpressionInfo {
            uniformity,
            ref_count: 0,
            assignable_global,
            ty,
        };
        Ok(())
    }
806
    /// Analyzes the statements of a block in order and returns their combined
    /// uniformity.
    ///
    /// `disruptor` is the reason (if any) why control flow reaching this block
    /// is already non-uniform. While walking the block it is extended with
    /// early exits caused by preceding statements. Operands of statements are
    /// reference-counted, and global usage is recorded, through the `add_ref*`
    /// helpers. `other_functions` provides the infos of called functions.
    ///
    /// # Errors
    ///
    /// * `FunctionError::NonUniformControlFlow` when an emitted expression
    ///   has uniformity requirements, a disruptor is in effect, and the
    ///   `CONTROL_FLOW_UNIFORMITY` validation flag is set.
    /// * Any error propagated from nested blocks or from `process_call`.
    // `Option::or` evaluates its argument eagerly; the calls passed to it
    // here are wanted unconditionally.
    #[allow(clippy::or_fun_call)]
    fn process_block(
        &mut self,
        statements: &crate::Block,
        other_functions: &[FunctionInfo],
        mut disruptor: Option<UniformityDisruptor>,
        expression_arena: &Arena<crate::Expression>,
    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
        use crate::Statement as S;

        let mut combined_uniformity = FunctionUniformity::new();
        for statement in statements {
            let uniformity = match *statement {
                S::Emit(ref range) => {
                    // Collect the requirements of all emitted expressions,
                    // failing on the first one that cannot be satisfied.
                    let mut requirements = UniformityRequirements::empty();
                    for expr in range.clone() {
                        let req = self.expressions[expr.index()].uniformity.requirements;
                        if self
                            .flags
                            .contains(ValidationFlags::CONTROL_FLOW_UNIFORMITY)
                            && !req.is_empty()
                        {
                            if let Some(cause) = disruptor {
                                return Err(FunctionError::NonUniformControlFlow(req, expr, cause)
                                    .with_span_handle(expr, expression_arena));
                            }
                        }
                        requirements |= req;
                    }
                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                S::Break | S::Continue => FunctionUniformity::new(),
                S::Kill => FunctionUniformity {
                    result: Uniformity::new(),
                    // A kill is only recorded as an exit when it happens
                    // under an existing disruptor.
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_KILL
                    } else {
                        ExitFlags::empty()
                    },
                },
                S::Barrier(_) => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                    },
                    exit: ExitFlags::empty(),
                },
                S::WorkGroupUniformLoad { pointer, .. } => {
                    // The pointer is referenced, but its uniformity is not
                    // used for the result.
                    let _condition_nur = self.add_ref(pointer);

                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                S::Block(ref b) => {
                    self.process_block(b, other_functions, disruptor, expression_arena)?
                }
                S::If {
                    condition,
                    ref accept,
                    ref reject,
                } => {
                    // A non-uniform condition disrupts both branches, unless
                    // an earlier disruptor already applies.
                    let condition_nur = self.add_ref(condition);
                    let branch_disruptor =
                        disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
                    let accept_uniformity = self.process_block(
                        accept,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                    )?;
                    let reject_uniformity = self.process_block(
                        reject,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                    )?;
                    accept_uniformity | reject_uniformity
                }
                S::Switch {
                    selector,
                    ref cases,
                } => {
                    let selector_nur = self.add_ref(selector);
                    let branch_disruptor =
                        disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
                    let mut uniformity = FunctionUniformity::new();
                    let mut case_disruptor = branch_disruptor;
                    for case in cases.iter() {
                        let case_uniformity = self.process_block(
                            &case.body,
                            other_functions,
                            case_disruptor,
                            expression_arena,
                        )?;
                        // A fall-through case passes its early exits on to
                        // the next case; otherwise the next case starts over
                        // from the switch-level disruptor.
                        case_disruptor = if case.fall_through {
                            case_disruptor.or(case_uniformity.exit_disruptor())
                        } else {
                            branch_disruptor
                        };
                        uniformity = uniformity | case_uniformity;
                    }
                    uniformity
                }
                S::Loop {
                    ref body,
                    ref continuing,
                    break_if,
                } => {
                    let body_uniformity =
                        self.process_block(body, other_functions, disruptor, expression_arena)?;
                    // Early exits in the body disrupt the continuing block.
                    let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
                    let continuing_uniformity = self.process_block(
                        continuing,
                        other_functions,
                        continuing_disruptor,
                        expression_arena,
                    )?;
                    // The break condition is only reference-counted.
                    if let Some(expr) = break_if {
                        let _ = self.add_ref(expr);
                    }
                    body_uniformity | continuing_uniformity
                }
                S::Return { value } => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
                        requirements: UniformityRequirements::empty(),
                    },
                    // A return is only recorded as an exit when it happens
                    // under an existing disruptor.
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_RETURN
                    } else {
                        ExitFlags::empty()
                    },
                },
                S::Store { pointer, value } => {
                    // The pointer's global (if any) is marked as written.
                    let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::ImageStore {
                    image,
                    coordinate,
                    array_index,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::WRITE);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(coordinate);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::Call {
                    function,
                    ref arguments,
                    result: _,
                } => {
                    for &argument in arguments {
                        let _ = self.add_ref(argument);
                    }
                    let info = &other_functions[function.index()];
                    self.process_call(info, arguments, expression_arena)?
                }
                S::Atomic {
                    pointer,
                    ref fun,
                    value,
                    result: _,
                } => {
                    let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
                        let _ = self.add_ref(cmp);
                    }
                    FunctionUniformity::new()
                }
                S::RayQuery { query, ref fun } => {
                    let _ = self.add_ref(query);
                    if let crate::RayQueryFunction::Initialize {
                        acceleration_structure,
                        descriptor,
                    } = *fun
                    {
                        let _ = self.add_ref(acceleration_structure);
                        let _ = self.add_ref(descriptor);
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupBallot {
                    result: _,
                    predicate,
                } => {
                    if let Some(predicate) = predicate {
                        let _ = self.add_ref(predicate);
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupCollectiveOperation {
                    op: _,
                    collective_op: _,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    FunctionUniformity::new()
                }
                S::SubgroupGather {
                    mode,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    match mode {
                        crate::GatherMode::BroadcastFirst => {}
                        crate::GatherMode::Broadcast(index)
                        | crate::GatherMode::Shuffle(index)
                        | crate::GatherMode::ShuffleDown(index)
                        | crate::GatherMode::ShuffleUp(index)
                        | crate::GatherMode::ShuffleXor(index) => {
                            let _ = self.add_ref(index);
                        }
                    }
                    FunctionUniformity::new()
                }
            };

            // Early exits in this statement disrupt every statement after it.
            disruptor = disruptor.or(uniformity.exit_disruptor());
            combined_uniformity = combined_uniformity | uniformity;
        }
        Ok(combined_uniformity)
    }
1080}
1081
1082impl ModuleInfo {
1083 pub(super) fn process_const_expression(
1085 &mut self,
1086 handle: Handle<crate::Expression>,
1087 resolve_context: &ResolveContext,
1088 gctx: crate::proc::GlobalCtx,
1089 ) -> Result<(), super::ConstExpressionError> {
1090 self.const_expression_types[handle.index()] =
1091 resolve_context.resolve(&gctx.global_expressions[handle], |h| Ok(&self[h]))?;
1092 Ok(())
1093 }
1094
1095 pub(super) fn process_function(
1098 &self,
1099 fun: &crate::Function,
1100 module: &crate::Module,
1101 flags: ValidationFlags,
1102 capabilities: super::Capabilities,
1103 ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
1104 let mut info = FunctionInfo {
1105 flags,
1106 available_stages: ShaderStages::all(),
1107 uniformity: Uniformity::new(),
1108 may_kill: false,
1109 sampling_set: crate::FastHashSet::default(),
1110 global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
1111 expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
1112 sampling: crate::FastHashSet::default(),
1113 dual_source_blending: false,
1114 };
1115 let resolve_context =
1116 ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments);
1117
1118 for (handle, _) in fun.expressions.iter() {
1119 if let Err(source) = info.process_expression(
1120 handle,
1121 &fun.expressions,
1122 &self.functions,
1123 &resolve_context,
1124 capabilities,
1125 ) {
1126 return Err(FunctionError::Expression { handle, source }
1127 .with_span_handle(handle, &fun.expressions));
1128 }
1129 }
1130
1131 for (_, expr) in fun.local_variables.iter() {
1132 if let Some(init) = expr.init {
1133 let _ = info.add_ref(init);
1134 }
1135 }
1136
1137 let uniformity = info.process_block(&fun.body, &self.functions, None, &fun.expressions)?;
1138 info.uniformity = uniformity.result;
1139 info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
1140
1141 Ok(info)
1142 }
1143
1144 pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
1145 &self.entry_points[index]
1146 }
1147}
1148
/// Exercises expression and block analysis on a small hand-built function:
/// reference counting, global usage tracking, and uniformity / exit-flag
/// propagation through `If`, `Return`, `Store` and `Kill` statements.
#[test]
fn uniform_control_flow() {
    use crate::{Expression as E, Statement as S};

    let mut type_arena = crate::UniqueArena::new();
    let ty = type_arena.insert(
        crate::Type {
            name: None,
            inner: crate::TypeInner::Vector {
                size: crate::VectorSize::Bi,
                scalar: crate::Scalar::F32,
            },
        },
        Default::default(),
    );
    let mut global_var_arena = Arena::new();
    // A `Handle`-space global is treated as non-uniform by the analysis.
    let non_uniform_global = global_var_arena.append(
        crate::GlobalVariable {
            name: None,
            init: None,
            ty,
            space: crate::AddressSpace::Handle,
            binding: None,
        },
        Default::default(),
    );
    // A `Uniform`-space global is treated as uniform.
    let uniform_global = global_var_arena.append(
        crate::GlobalVariable {
            name: None,
            init: None,
            ty,
            binding: None,
            space: crate::AddressSpace::Uniform,
        },
        Default::default(),
    );

    let mut expressions = Arena::new();
    let constant_expr = expressions.append(E::Literal(crate::Literal::U32(0)), Default::default());
    let derivative_expr = expressions.append(
        E::Derivative {
            axis: crate::DerivativeAxis::X,
            ctrl: crate::DerivativeControl::None,
            expr: constant_expr,
        },
        Default::default(),
    );
    let emit_range_constant_derivative = expressions.range_from(0);
    let non_uniform_global_expr =
        expressions.append(E::GlobalVariable(non_uniform_global), Default::default());
    let uniform_global_expr =
        expressions.append(E::GlobalVariable(uniform_global), Default::default());
    let emit_range_globals = expressions.range_from(2);

    let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default());
    let access_expr = expressions.append(
        E::AccessIndex {
            base: non_uniform_global_expr,
            index: 1,
        },
        Default::default(),
    );
    let emit_range_query_access_globals = expressions.range_from(2);

    let mut info = FunctionInfo {
        flags: ValidationFlags::all(),
        available_stages: ShaderStages::all(),
        uniformity: Uniformity::new(),
        may_kill: false,
        sampling_set: crate::FastHashSet::default(),
        global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
        expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
        sampling: crate::FastHashSet::default(),
        dual_source_blending: false,
    };
    let resolve_context = ResolveContext {
        constants: &Arena::new(),
        overrides: &Arena::new(),
        types: &type_arena,
        special_types: &crate::SpecialTypes::default(),
        global_vars: &global_var_arena,
        local_vars: &Arena::new(),
        functions: &Arena::new(),
        arguments: &[],
    };
    for (handle, _) in expressions.iter() {
        info.process_expression(
            handle,
            &expressions,
            &[],
            &resolve_context,
            super::Capabilities::empty(),
        )
        .unwrap();
    }
    // Each global expression is referenced once: by the access and by the
    // array-length query respectively.
    assert_eq!(info[non_uniform_global_expr].ref_count, 1);
    assert_eq!(info[uniform_global_expr].ref_count, 1);
    assert_eq!(info[query_expr].ref_count, 0);
    assert_eq!(info[access_expr].ref_count, 0);
    // Forming a sub-reference does not count as using the global; the
    // array-length query does.
    assert_eq!(info[non_uniform_global], GlobalUse::empty());
    assert_eq!(info[uniform_global], GlobalUse::QUERY);

    // Case 1: a derivative under a uniform condition is accepted.
    let stmt_emit1 = S::Emit(emit_range_globals.clone());
    let stmt_if_uniform = S::If {
        condition: uniform_global_expr,
        accept: crate::Block::new(),
        reject: vec![
            S::Emit(emit_range_constant_derivative.clone()),
            S::Store {
                pointer: constant_expr,
                value: derivative_expr,
            },
        ]
        .into(),
    };
    assert_eq!(
        info.process_block(
            &vec![stmt_emit1, stmt_if_uniform].into(),
            &[],
            None,
            &expressions
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: None,
                requirements: UniformityRequirements::DERIVATIVE,
            },
            exit: ExitFlags::empty(),
        }),
    );
    // One reference from the derivative, one from the store pointer.
    assert_eq!(info[constant_expr].ref_count, 2);
    // Using the global as the `If` condition adds a read.
    assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);

    // Case 2: a derivative under a non-uniform condition. This is only an
    // error when the derivative requirement is not compiled out.
    let stmt_emit2 = S::Emit(emit_range_globals.clone());
    let stmt_if_non_uniform = S::If {
        condition: non_uniform_global_expr,
        accept: vec![
            S::Emit(emit_range_constant_derivative),
            S::Store {
                pointer: constant_expr,
                value: derivative_expr,
            },
        ]
        .into(),
        reject: crate::Block::new(),
    };
    {
        let block_info = info.process_block(
            &vec![stmt_emit2, stmt_if_non_uniform].into(),
            &[],
            None,
            &expressions,
        );
        if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE {
            // No error: the store in the branch runs and references the
            // derivative a second time.
            assert_eq!(info[derivative_expr].ref_count, 2);
        } else {
            assert_eq!(
                block_info,
                Err(FunctionError::NonUniformControlFlow(
                    UniformityRequirements::DERIVATIVE,
                    derivative_expr,
                    UniformityDisruptor::Expression(non_uniform_global_expr)
                )
                .with_span()),
            );
            assert_eq!(info[derivative_expr].ref_count, 1);
        }
    }
    assert_eq!(info[non_uniform_global], GlobalUse::READ);

    // Case 3: returning a non-uniform value under an existing disruptor.
    let stmt_emit3 = S::Emit(emit_range_globals);
    let stmt_return_non_uniform = S::Return {
        value: Some(non_uniform_global_expr),
    };
    assert_eq!(
        info.process_block(
            &vec![stmt_emit3, stmt_return_non_uniform].into(),
            &[],
            Some(UniformityDisruptor::Return),
            &expressions
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_global_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::MAY_RETURN,
        }),
    );
    assert_eq!(info[non_uniform_global_expr].ref_count, 3);

    // Case 4: store through an access, kill, then return the access; all
    // under an existing disruptor, so both exit flags get set.
    let stmt_emit4 = S::Emit(emit_range_query_access_globals);
    let stmt_assign = S::Store {
        pointer: access_expr,
        value: query_expr,
    };
    let stmt_return_pointer = S::Return {
        value: Some(access_expr),
    };
    let stmt_kill = S::Kill;
    assert_eq!(
        info.process_block(
            &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(),
            &[],
            Some(UniformityDisruptor::Discard),
            &expressions
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_global_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::all(),
        }),
    );
    // The store through the access marks the underlying global as written.
    assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
}