naga/valid/
analyzer.rs

1//! Module analyzer.
2//!
3//! Figures out the following properties:
4//! - control flow uniformity
5//! - texture/sampler pairs
6//! - expression reference counts
7
8use super::{ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
9use crate::diagnostic_filter::{DiagnosticFilterNode, StandardFilterableTriggeringRule};
10use crate::span::{AddSpan as _, WithSpan};
11use crate::{
12    arena::{Arena, Handle},
13    proc::{ResolveContext, TypeResolution},
14};
15use std::ops;
16
17pub type NonUniformResult = Option<Handle<crate::Expression>>;
18
/// Switch that relaxes the derivative-related uniformity requirements.
///
/// When `true`, the `DERIVATIVE` and `IMPLICIT_LEVEL` flags of
/// `UniformityRequirements` are defined as `0`. Derivative operations and
/// implicit-level image sampling therefore contribute no uniform-control-flow
/// requirement bits.
const DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE: bool = true;
20
21bitflags::bitflags! {
22    /// Kinds of expressions that require uniform control flow.
23    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
24    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
25    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
26    pub struct UniformityRequirements: u8 {
27        const WORK_GROUP_BARRIER = 0x1;
28        const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 };
29        const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 };
30    }
31}
32
33/// Uniform control flow characteristics.
34#[derive(Clone, Debug)]
35#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
36#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
37#[cfg_attr(test, derive(PartialEq))]
38pub struct Uniformity {
39    /// A child expression with non-uniform result.
40    ///
41    /// This means, when the relevant invocations are scheduled on a compute unit,
42    /// they have to use vector registers to store an individual value
43    /// per invocation.
44    ///
45    /// Whenever the control flow is conditioned on such value,
46    /// the hardware needs to keep track of the mask of invocations,
47    /// and process all branches of the control flow.
48    ///
49    /// Any operations that depend on non-uniform results also produce non-uniform.
50    pub non_uniform_result: NonUniformResult,
51    /// If this expression requires uniform control flow, store the reason here.
52    pub requirements: UniformityRequirements,
53}
54
55impl Uniformity {
56    const fn new() -> Self {
57        Uniformity {
58            non_uniform_result: None,
59            requirements: UniformityRequirements::empty(),
60        }
61    }
62}
63
64bitflags::bitflags! {
65    #[derive(Clone, Copy, Debug, PartialEq)]
66    struct ExitFlags: u8 {
67        /// Control flow may return from the function, which makes all the
68        /// subsequent statements within the current function (only!)
69        /// to be executed in a non-uniform control flow.
70        const MAY_RETURN = 0x1;
71        /// Control flow may be killed. Anything after [`Statement::Kill`] is
72        /// considered inside non-uniform context.
73        ///
74        /// [`Statement::Kill`]: crate::Statement::Kill
75        const MAY_KILL = 0x2;
76    }
77}
78
79/// Uniformity characteristics of a function.
80#[cfg_attr(test, derive(Debug, PartialEq))]
81struct FunctionUniformity {
82    result: Uniformity,
83    exit: ExitFlags,
84}
85
86impl ops::BitOr for FunctionUniformity {
87    type Output = Self;
88    fn bitor(self, other: Self) -> Self {
89        FunctionUniformity {
90            result: Uniformity {
91                non_uniform_result: self
92                    .result
93                    .non_uniform_result
94                    .or(other.result.non_uniform_result),
95                requirements: self.result.requirements | other.result.requirements,
96            },
97            exit: self.exit | other.exit,
98        }
99    }
100}
101
102impl FunctionUniformity {
103    const fn new() -> Self {
104        FunctionUniformity {
105            result: Uniformity::new(),
106            exit: ExitFlags::empty(),
107        }
108    }
109
110    /// Returns a disruptor based on the stored exit flags, if any.
111    const fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
112        if self.exit.contains(ExitFlags::MAY_RETURN) {
113            Some(UniformityDisruptor::Return)
114        } else if self.exit.contains(ExitFlags::MAY_KILL) {
115            Some(UniformityDisruptor::Discard)
116        } else {
117            None
118        }
119    }
120}
121
122bitflags::bitflags! {
123    /// Indicates how a global variable is used.
124    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
125    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
126    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
127    pub struct GlobalUse: u8 {
128        /// Data will be read from the variable.
129        const READ = 0x1;
130        /// Data will be written to the variable.
131        const WRITE = 0x2;
132        /// The information about the data is queried.
133        const QUERY = 0x4;
134        /// Atomic operations will be performed on the variable.
135        const ATOMIC = 0x8;
136    }
137}
138
139#[derive(Clone, Debug, Eq, Hash, PartialEq)]
140#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
141#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
142pub struct SamplingKey {
143    pub image: Handle<crate::GlobalVariable>,
144    pub sampler: Handle<crate::GlobalVariable>,
145}
146
147#[derive(Clone, Debug)]
148#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
149#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
150/// Information about an expression in a function body.
151pub struct ExpressionInfo {
152    /// Whether this expression is uniform, and why.
153    ///
154    /// If this expression's value is not uniform, this is the handle
155    /// of the expression from which this one's non-uniformity
156    /// originates. Otherwise, this is `None`.
157    pub uniformity: Uniformity,
158
159    /// The number of statements and other expressions using this
160    /// expression's value.
161    pub ref_count: usize,
162
163    /// The global variable into which this expression produces a pointer.
164    ///
165    /// This is `None` unless this expression is either a
166    /// [`GlobalVariable`], or an [`Access`] or [`AccessIndex`] that
167    /// ultimately refers to some part of a global.
168    ///
169    /// [`Load`] expressions applied to pointer-typed arguments could
170    /// refer to globals, but we leave this as `None` for them.
171    ///
172    /// [`GlobalVariable`]: crate::Expression::GlobalVariable
173    /// [`Access`]: crate::Expression::Access
174    /// [`AccessIndex`]: crate::Expression::AccessIndex
175    /// [`Load`]: crate::Expression::Load
176    assignable_global: Option<Handle<crate::GlobalVariable>>,
177
178    /// The type of this expression.
179    pub ty: TypeResolution,
180}
181
182impl ExpressionInfo {
183    const fn new() -> Self {
184        ExpressionInfo {
185            uniformity: Uniformity::new(),
186            ref_count: 0,
187            assignable_global: None,
188            // this doesn't matter at this point, will be overwritten
189            ty: TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
190                kind: crate::ScalarKind::Bool,
191                width: 0,
192            })),
193        }
194    }
195}
196
197#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
198#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
199#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
200enum GlobalOrArgument {
201    Global(Handle<crate::GlobalVariable>),
202    Argument(u32),
203}
204
205impl GlobalOrArgument {
206    fn from_expression(
207        expression_arena: &Arena<crate::Expression>,
208        expression: Handle<crate::Expression>,
209    ) -> Result<GlobalOrArgument, ExpressionError> {
210        Ok(match expression_arena[expression] {
211            crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
212            crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
213            crate::Expression::Access { base, .. }
214            | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
215                crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
216                _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
217            },
218            _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
219        })
220    }
221}
222
223#[derive(Debug, Clone, PartialEq, Eq, Hash)]
224#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
225#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
226struct Sampling {
227    image: GlobalOrArgument,
228    sampler: GlobalOrArgument,
229}
230
231#[derive(Debug, Clone)]
232#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
233#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
234pub struct FunctionInfo {
235    /// Validation flags.
236    #[allow(dead_code)]
237    flags: ValidationFlags,
238    /// Set of shader stages where calling this function is valid.
239    pub available_stages: ShaderStages,
240    /// Uniformity characteristics.
241    pub uniformity: Uniformity,
242    /// Function may kill the invocation.
243    pub may_kill: bool,
244
245    /// All pairs of (texture, sampler) globals that may be used together in
246    /// sampling operations by this function and its callees. This includes
247    /// pairings that arise when this function passes textures and samplers as
248    /// arguments to its callees.
249    ///
250    /// This table does not include uses of textures and samplers passed as
251    /// arguments to this function itself, since we do not know which globals
252    /// those will be. However, this table *is* exhaustive when computed for an
253    /// entry point function: entry points never receive textures or samplers as
254    /// arguments, so all an entry point's sampling can be reported in terms of
255    /// globals.
256    ///
257    /// The GLSL back end uses this table to construct reflection info that
258    /// clients need to construct texture-combined sampler values.
259    pub sampling_set: crate::FastHashSet<SamplingKey>,
260
261    /// How this function and its callees use this module's globals.
262    ///
263    /// This is indexed by `Handle<GlobalVariable>` indices. However,
264    /// `FunctionInfo` implements `std::ops::Index<Handle<GlobalVariable>>`,
265    /// so you can simply index this struct with a global handle to retrieve
266    /// its usage information.
267    global_uses: Box<[GlobalUse]>,
268
269    /// Information about each expression in this function's body.
270    ///
271    /// This is indexed by `Handle<Expression>` indices. However, `FunctionInfo`
272    /// implements `std::ops::Index<Handle<Expression>>`, so you can simply
273    /// index this struct with an expression handle to retrieve its
274    /// `ExpressionInfo`.
275    expressions: Box<[ExpressionInfo]>,
276
277    /// All (texture, sampler) pairs that may be used together in sampling
278    /// operations by this function and its callees, whether they are accessed
279    /// as globals or passed as arguments.
280    ///
281    /// Participants are represented by [`GlobalVariable`] handles whenever
282    /// possible, and otherwise by indices of this function's arguments.
283    ///
284    /// When analyzing a function call, we combine this data about the callee
285    /// with the actual arguments being passed to produce the callers' own
286    /// `sampling_set` and `sampling` tables.
287    ///
288    /// [`GlobalVariable`]: crate::GlobalVariable
289    sampling: crate::FastHashSet<Sampling>,
290
291    /// Indicates that the function is using dual source blending.
292    pub dual_source_blending: bool,
293
294    /// The leaf of all module-wide diagnostic filter rules tree parsed from directives in this
295    /// module.
296    ///
297    /// See [`DiagnosticFilterNode`] for details on how the tree is represented and used in
298    /// validation.
299    diagnostic_filter_leaf: Option<Handle<DiagnosticFilterNode>>,
300}
301
302impl FunctionInfo {
303    pub const fn global_variable_count(&self) -> usize {
304        self.global_uses.len()
305    }
306    pub const fn expression_count(&self) -> usize {
307        self.expressions.len()
308    }
309    pub fn dominates_global_use(&self, other: &Self) -> bool {
310        for (self_global_uses, other_global_uses) in
311            self.global_uses.iter().zip(other.global_uses.iter())
312        {
313            if !self_global_uses.contains(*other_global_uses) {
314                return false;
315            }
316        }
317        true
318    }
319}
320
321impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
322    type Output = GlobalUse;
323    fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
324        &self.global_uses[handle.index()]
325    }
326}
327
328impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
329    type Output = ExpressionInfo;
330    fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
331        &self.expressions[handle.index()]
332    }
333}
334
335/// Disruptor of the uniform control flow.
336#[derive(Clone, Copy, Debug, thiserror::Error)]
337#[cfg_attr(test, derive(PartialEq))]
338pub enum UniformityDisruptor {
339    #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
340    Expression(Handle<crate::Expression>),
341    #[error("There is a Return earlier in the control flow of the function")]
342    Return,
343    #[error("There is a Discard earlier in the entry point across all called functions")]
344    Discard,
345}
346
347impl FunctionInfo {
348    /// Record a use of `expr` of the sort given by `global_use`.
349    ///
350    /// Bump `expr`'s reference count, and return its uniformity.
351    ///
352    /// If `expr` is a pointer to a global variable, or some part of
353    /// a global variable, add `global_use` to that global's set of
354    /// uses.
355    #[must_use]
356    fn add_ref_impl(
357        &mut self,
358        expr: Handle<crate::Expression>,
359        global_use: GlobalUse,
360    ) -> NonUniformResult {
361        let info = &mut self.expressions[expr.index()];
362        info.ref_count += 1;
363        // mark the used global as read
364        if let Some(global) = info.assignable_global {
365            self.global_uses[global.index()] |= global_use;
366        }
367        info.uniformity.non_uniform_result
368    }
369
370    /// Record a use of `expr` for its value.
371    ///
372    /// This is used for almost all expression references. Anything
373    /// that writes to the value `expr` points to, or otherwise wants
374    /// contribute flags other than `GlobalUse::READ`, should use
375    /// `add_ref_impl` directly.
376    #[must_use]
377    fn add_ref(&mut self, expr: Handle<crate::Expression>) -> NonUniformResult {
378        self.add_ref_impl(expr, GlobalUse::READ)
379    }
380
381    /// Record a use of `expr`, and indicate which global variable it
382    /// refers to, if any.
383    ///
384    /// Bump `expr`'s reference count, and return its uniformity.
385    ///
386    /// If `expr` is a pointer to a global variable, or some part
387    /// thereof, store that global in `*assignable_global`. Leave the
388    /// global's uses unchanged.
389    ///
390    /// This is used to determine the [`assignable_global`] for
391    /// [`Access`] and [`AccessIndex`] expressions that ultimately
392    /// refer to a global variable. Those expressions don't contribute
393    /// any usage to the global themselves; that depends on how other
394    /// expressions use them.
395    ///
396    /// [`assignable_global`]: ExpressionInfo::assignable_global
397    /// [`Access`]: crate::Expression::Access
398    /// [`AccessIndex`]: crate::Expression::AccessIndex
399    #[must_use]
400    fn add_assignable_ref(
401        &mut self,
402        expr: Handle<crate::Expression>,
403        assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
404    ) -> NonUniformResult {
405        let info = &mut self.expressions[expr.index()];
406        info.ref_count += 1;
407        // propagate the assignable global up the chain, till it either hits
408        // a value-type expression, or the assignment statement.
409        if let Some(global) = info.assignable_global {
410            if let Some(_old) = assignable_global.replace(global) {
411                unreachable!()
412            }
413        }
414        info.uniformity.non_uniform_result
415    }
416
417    /// Inherit information from a called function.
418    fn process_call(
419        &mut self,
420        callee: &Self,
421        arguments: &[Handle<crate::Expression>],
422        expression_arena: &Arena<crate::Expression>,
423    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
424        self.sampling_set
425            .extend(callee.sampling_set.iter().cloned());
426        for sampling in callee.sampling.iter() {
427            // If the callee was passed the texture or sampler as an argument,
428            // we may now be able to determine which globals those referred to.
429            let image_storage = match sampling.image {
430                GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
431                GlobalOrArgument::Argument(i) => {
432                    let Some(handle) = arguments.get(i as usize).cloned() else {
433                        // Argument count mismatch, will be reported later by validate_call
434                        break;
435                    };
436                    GlobalOrArgument::from_expression(expression_arena, handle).map_err(
437                        |source| {
438                            FunctionError::Expression { handle, source }
439                                .with_span_handle(handle, expression_arena)
440                        },
441                    )?
442                }
443            };
444
445            let sampler_storage = match sampling.sampler {
446                GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
447                GlobalOrArgument::Argument(i) => {
448                    let Some(handle) = arguments.get(i as usize).cloned() else {
449                        // Argument count mismatch, will be reported later by validate_call
450                        break;
451                    };
452                    GlobalOrArgument::from_expression(expression_arena, handle).map_err(
453                        |source| {
454                            FunctionError::Expression { handle, source }
455                                .with_span_handle(handle, expression_arena)
456                        },
457                    )?
458                }
459            };
460
461            // If we've managed to pin both the image and sampler down to
462            // specific globals, record that in our `sampling_set`. Otherwise,
463            // record as much as we do know in our own `sampling` table, for our
464            // callers to sort out.
465            match (image_storage, sampler_storage) {
466                (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
467                    self.sampling_set.insert(SamplingKey { image, sampler });
468                }
469                (image, sampler) => {
470                    self.sampling.insert(Sampling { image, sampler });
471                }
472            }
473        }
474
475        // Inherit global use from our callees.
476        for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) {
477            *mine |= *other;
478        }
479
480        Ok(FunctionUniformity {
481            result: callee.uniformity.clone(),
482            exit: if callee.may_kill {
483                ExitFlags::MAY_KILL
484            } else {
485                ExitFlags::empty()
486            },
487        })
488    }
489
490    /// Compute the [`ExpressionInfo`] for `handle`.
491    ///
492    /// Replace the dummy entry in [`self.expressions`] for `handle`
493    /// with a real `ExpressionInfo` value describing that expression.
494    ///
495    /// This function is called as part of a forward sweep through the
496    /// arena, so we can assume that all earlier expressions in the
497    /// arena already have valid info. Since expressions only depend
498    /// on earlier expressions, this includes all our subexpressions.
499    ///
500    /// Adjust the reference counts on all expressions we use.
501    ///
502    /// Also populate the [`sampling_set`], [`sampling`] and
503    /// [`global_uses`] fields of `self`.
504    ///
505    /// [`self.expressions`]: FunctionInfo::expressions
506    /// [`sampling_set`]: FunctionInfo::sampling_set
507    /// [`sampling`]: FunctionInfo::sampling
508    /// [`global_uses`]: FunctionInfo::global_uses
509    #[allow(clippy::or_fun_call)]
510    fn process_expression(
511        &mut self,
512        handle: Handle<crate::Expression>,
513        expression_arena: &Arena<crate::Expression>,
514        other_functions: &[FunctionInfo],
515        resolve_context: &ResolveContext,
516        capabilities: super::Capabilities,
517    ) -> Result<(), ExpressionError> {
518        use crate::{Expression as E, SampleLevel as Sl};
519
520        let expression = &expression_arena[handle];
521        let mut assignable_global = None;
522        let uniformity = match *expression {
523            E::Access { base, index } => {
524                let base_ty = self[base].ty.inner_with(resolve_context.types);
525
526                // build up the caps needed if this is indexed non-uniformly
527                let mut needed_caps = super::Capabilities::empty();
528                let is_binding_array = match *base_ty {
529                    crate::TypeInner::BindingArray {
530                        base: array_element_ty_handle,
531                        ..
532                    } => {
533                        // these are nasty aliases, but these idents are too long and break rustfmt
534                        let ub_st = super::Capabilities::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING;
535                        let st_sb = super::Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING;
536                        let sampler = super::Capabilities::SAMPLER_NON_UNIFORM_INDEXING;
537
538                        // We're a binding array, so lets use the type of _what_ we are array of to determine if we can non-uniformly index it.
539                        let array_element_ty =
540                            &resolve_context.types[array_element_ty_handle].inner;
541
542                        needed_caps |= match *array_element_ty {
543                            // If we're an image, use the appropriate limit.
544                            crate::TypeInner::Image { class, .. } => match class {
545                                crate::ImageClass::Storage { .. } => ub_st,
546                                _ => st_sb,
547                            },
548                            crate::TypeInner::Sampler { .. } => sampler,
549                            // If we're anything but an image, assume we're a buffer and use the address space.
550                            _ => {
551                                if let E::GlobalVariable(global_handle) = expression_arena[base] {
552                                    let global = &resolve_context.global_vars[global_handle];
553                                    match global.space {
554                                        crate::AddressSpace::Uniform => ub_st,
555                                        crate::AddressSpace::Storage { .. } => st_sb,
556                                        _ => unreachable!(),
557                                    }
558                                } else {
559                                    unreachable!()
560                                }
561                            }
562                        };
563
564                        true
565                    }
566                    _ => false,
567                };
568
569                if self[index].uniformity.non_uniform_result.is_some()
570                    && !capabilities.contains(needed_caps)
571                    && is_binding_array
572                {
573                    return Err(ExpressionError::MissingCapabilities(needed_caps));
574                }
575
576                Uniformity {
577                    non_uniform_result: self
578                        .add_assignable_ref(base, &mut assignable_global)
579                        .or(self.add_ref(index)),
580                    requirements: UniformityRequirements::empty(),
581                }
582            }
583            E::AccessIndex { base, .. } => Uniformity {
584                non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
585                requirements: UniformityRequirements::empty(),
586            },
587            // always uniform
588            E::Splat { size: _, value } => Uniformity {
589                non_uniform_result: self.add_ref(value),
590                requirements: UniformityRequirements::empty(),
591            },
592            E::Swizzle { vector, .. } => Uniformity {
593                non_uniform_result: self.add_ref(vector),
594                requirements: UniformityRequirements::empty(),
595            },
596            E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(),
597            E::Compose { ref components, .. } => {
598                let non_uniform_result = components
599                    .iter()
600                    .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
601                Uniformity {
602                    non_uniform_result,
603                    requirements: UniformityRequirements::empty(),
604                }
605            }
606            // depends on the builtin
607            E::FunctionArgument(index) => {
608                let arg = &resolve_context.arguments[index as usize];
609                let uniform = match arg.binding {
610                    Some(crate::Binding::BuiltIn(
611                        // per-work-group built-ins are uniform
612                        crate::BuiltIn::WorkGroupId
613                        | crate::BuiltIn::WorkGroupSize
614                        | crate::BuiltIn::NumWorkGroups,
615                    )) => true,
616                    _ => false,
617                };
618                Uniformity {
619                    non_uniform_result: if uniform { None } else { Some(handle) },
620                    requirements: UniformityRequirements::empty(),
621                }
622            }
623            // depends on the address space
624            E::GlobalVariable(gh) => {
625                use crate::AddressSpace as As;
626                assignable_global = Some(gh);
627                let var = &resolve_context.global_vars[gh];
628                let uniform = match var.space {
629                    // local data is non-uniform
630                    As::Function | As::Private => false,
631                    // workgroup memory is exclusively accessed by the group
632                    As::WorkGroup => true,
633                    // uniform data
634                    As::Uniform | As::PushConstant => true,
635                    // storage data is only uniform when read-only
636                    As::Storage { access } => !access.contains(crate::StorageAccess::STORE),
637                    As::Handle => false,
638                };
639                Uniformity {
640                    non_uniform_result: if uniform { None } else { Some(handle) },
641                    requirements: UniformityRequirements::empty(),
642                }
643            }
644            E::LocalVariable(_) => Uniformity {
645                non_uniform_result: Some(handle),
646                requirements: UniformityRequirements::empty(),
647            },
648            E::Load { pointer } => Uniformity {
649                non_uniform_result: self.add_ref(pointer),
650                requirements: UniformityRequirements::empty(),
651            },
652            E::ImageSample {
653                image,
654                sampler,
655                gather: _,
656                coordinate,
657                array_index,
658                offset: _,
659                level,
660                depth_ref,
661            } => {
662                let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?;
663                let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?;
664
665                match (image_storage, sampler_storage) {
666                    (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
667                        self.sampling_set.insert(SamplingKey { image, sampler });
668                    }
669                    _ => {
670                        self.sampling.insert(Sampling {
671                            image: image_storage,
672                            sampler: sampler_storage,
673                        });
674                    }
675                }
676
677                // "nur" == "Non-Uniform Result"
678                let array_nur = array_index.and_then(|h| self.add_ref(h));
679                let level_nur = match level {
680                    Sl::Auto | Sl::Zero => None,
681                    Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
682                    Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
683                };
684                let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
685                Uniformity {
686                    non_uniform_result: self
687                        .add_ref(image)
688                        .or(self.add_ref(sampler))
689                        .or(self.add_ref(coordinate))
690                        .or(array_nur)
691                        .or(level_nur)
692                        .or(dref_nur),
693                    requirements: if level.implicit_derivatives() {
694                        UniformityRequirements::IMPLICIT_LEVEL
695                    } else {
696                        UniformityRequirements::empty()
697                    },
698                }
699            }
700            E::ImageLoad {
701                image,
702                coordinate,
703                array_index,
704                sample,
705                level,
706            } => {
707                let array_nur = array_index.and_then(|h| self.add_ref(h));
708                let sample_nur = sample.and_then(|h| self.add_ref(h));
709                let level_nur = level.and_then(|h| self.add_ref(h));
710                Uniformity {
711                    non_uniform_result: self
712                        .add_ref(image)
713                        .or(self.add_ref(coordinate))
714                        .or(array_nur)
715                        .or(sample_nur)
716                        .or(level_nur),
717                    requirements: UniformityRequirements::empty(),
718                }
719            }
720            E::ImageQuery { image, query } => {
721                let query_nur = match query {
722                    crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
723                    _ => None,
724                };
725                Uniformity {
726                    non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
727                    requirements: UniformityRequirements::empty(),
728                }
729            }
730            E::Unary { expr, .. } => Uniformity {
731                non_uniform_result: self.add_ref(expr),
732                requirements: UniformityRequirements::empty(),
733            },
734            E::Binary { left, right, .. } => Uniformity {
735                non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
736                requirements: UniformityRequirements::empty(),
737            },
738            E::Select {
739                condition,
740                accept,
741                reject,
742            } => Uniformity {
743                non_uniform_result: self
744                    .add_ref(condition)
745                    .or(self.add_ref(accept))
746                    .or(self.add_ref(reject)),
747                requirements: UniformityRequirements::empty(),
748            },
749            // explicit derivatives require uniform
750            E::Derivative { expr, .. } => Uniformity {
751                //Note: taking a derivative of a uniform doesn't make it non-uniform
752                non_uniform_result: self.add_ref(expr),
753                requirements: UniformityRequirements::DERIVATIVE,
754            },
755            E::Relational { argument, .. } => Uniformity {
756                non_uniform_result: self.add_ref(argument),
757                requirements: UniformityRequirements::empty(),
758            },
759            E::Math {
760                fun: _,
761                arg,
762                arg1,
763                arg2,
764                arg3,
765            } => {
766                let arg1_nur = arg1.and_then(|h| self.add_ref(h));
767                let arg2_nur = arg2.and_then(|h| self.add_ref(h));
768                let arg3_nur = arg3.and_then(|h| self.add_ref(h));
769                Uniformity {
770                    non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur).or(arg3_nur),
771                    requirements: UniformityRequirements::empty(),
772                }
773            }
774            E::As { expr, .. } => Uniformity {
775                non_uniform_result: self.add_ref(expr),
776                requirements: UniformityRequirements::empty(),
777            },
778            E::CallResult(function) => other_functions[function.index()].uniformity.clone(),
779            E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity {
780                non_uniform_result: Some(handle),
781                requirements: UniformityRequirements::empty(),
782            },
783            E::WorkGroupUniformLoadResult { .. } => Uniformity {
784                // The result of WorkGroupUniformLoad is always uniform by definition
785                non_uniform_result: None,
786                // The call is what cares about uniformity, not the expression
787                // This expression is never emitted, so this requirement should never be used anyway?
788                requirements: UniformityRequirements::empty(),
789            },
790            E::ArrayLength(expr) => Uniformity {
791                non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
792                requirements: UniformityRequirements::empty(),
793            },
794            E::RayQueryGetIntersection {
795                query,
796                committed: _,
797            } => Uniformity {
798                non_uniform_result: self.add_ref(query),
799                requirements: UniformityRequirements::empty(),
800            },
801            E::SubgroupBallotResult => Uniformity {
802                non_uniform_result: Some(handle),
803                requirements: UniformityRequirements::empty(),
804            },
805            E::SubgroupOperationResult { .. } => Uniformity {
806                non_uniform_result: Some(handle),
807                requirements: UniformityRequirements::empty(),
808            },
809        };
810
811        let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
812        self.expressions[handle.index()] = ExpressionInfo {
813            uniformity,
814            ref_count: 0,
815            assignable_global,
816            ty,
817        };
818        Ok(())
819    }
820
    /// Analyzes the uniformity requirements of a block (a sequence of statements)
    /// and validates uniform control flow where required.
    ///
    /// `disruptor` is the reason the control flow *entering* this block is
    /// non-uniform; the parent control flow is uniform if `disruptor.is_none()`.
    /// After each statement, any disruptor implied by that statement's exit
    /// (`exit_disruptor`) is merged in. As a result, a statement that may return
    /// or discard under non-uniform control flow makes every later statement in
    /// the block non-uniform as well.
    ///
    /// Returns the uniformity characteristics at the *function* level, combined
    /// over all statements of the block:
    /// - the union of the requirements of the emitted expressions and nested
    ///   blocks, i.e. whether the function must be called in uniform control flow;
    /// - a non-uniform value returned by the block, if any;
    /// - the exit flags (`MAY_RETURN`, `MAY_KILL`) recorded by the statements.
    ///
    /// As a side effect, every expression used by a statement is registered via
    /// `add_ref` / `add_ref_impl`. This updates its reference count and the
    /// usage flags of any global variable it refers to.
    ///
    /// # Errors
    ///
    /// Suppose `ValidationFlags::CONTROL_FLOW_UNIFORMITY` is set and an emitted
    /// expression has uniformity requirements while the current flow is
    /// non-uniform. Then a `NonUniformControlFlow` diagnostic is reported.
    /// Whether it becomes an `Err` depends on the severity resolved for the
    /// `DerivativeUniformity` rule from the diagnostic filter tree.
    /// Errors from nested blocks and from `process_call` are propagated.
    #[allow(clippy::or_fun_call)]
    fn process_block(
        &mut self,
        statements: &crate::Block,
        other_functions: &[FunctionInfo],
        mut disruptor: Option<UniformityDisruptor>,
        expression_arena: &Arena<crate::Expression>,
        diagnostic_filter_arena: &Arena<DiagnosticFilterNode>,
    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
        use crate::Statement as S;

        let mut combined_uniformity = FunctionUniformity::new();
        for statement in statements {
            let uniformity = match *statement {
                S::Emit(ref range) => {
                    // Check each emitted expression's requirements against the
                    // control flow reaching this statement.
                    let mut requirements = UniformityRequirements::empty();
                    for expr in range.clone() {
                        let req = self.expressions[expr.index()].uniformity.requirements;
                        if self
                            .flags
                            .contains(ValidationFlags::CONTROL_FLOW_UNIFORMITY)
                            && !req.is_empty()
                        {
                            if let Some(cause) = disruptor {
                                // The severity (error, warning, off, ...) comes from the
                                // `derivative_uniformity` filter in effect for this function.
                                let severity = DiagnosticFilterNode::search(
                                    self.diagnostic_filter_leaf,
                                    diagnostic_filter_arena,
                                    StandardFilterableTriggeringRule::DerivativeUniformity,
                                );
                                severity.report_diag(
                                    FunctionError::NonUniformControlFlow(req, expr, cause)
                                        .with_span_handle(expr, expression_arena),
                                    // TODO: Yes, this isn't contextualized with source, because
                                    // the user is supposed to render what would normally be an
                                    // error here. Once we actually support warning-level
                                    // diagnostic items, then we won't need this non-compliant hack:
                                    // <https://github.com/gfx-rs/wgpu/issues/6458>
                                    |e, level| log::log!(level, "{e}"),
                                )?;
                            }
                        }
                        // The requirement is propagated upwards even when no
                        // diagnostic was raised here.
                        requirements |= req;
                    }
                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                // Breaks and continues set no exit flags, so this analysis never
                // treats them as disruptors of the statements that follow.
                S::Break | S::Continue => FunctionUniformity::new(),
                // A discard only counts as a control-flow exit when it happens
                // under non-uniform control flow.
                S::Kill => FunctionUniformity {
                    result: Uniformity::new(),
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_KILL
                    } else {
                        ExitFlags::empty()
                    },
                },
                // Barriers are not checked here; the requirement is only
                // propagated to the function level.
                S::Barrier(_) => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                    },
                    exit: ExitFlags::empty(),
                },
                S::WorkGroupUniformLoad { pointer, .. } => {
                    let _condition_nur = self.add_ref(pointer);

                    // Don't check that this call occurs in uniform control flow until Naga implements WGSL's standard
                    // uniformity analysis (https://github.com/gfx-rs/naga/issues/1744).
                    // The uniformity analysis Naga uses now is less accurate than the one in the WGSL standard,
                    // causing Naga to reject correct uses of `workgroupUniformLoad` in some interesting programs.

                    /*
                    if self
                        .flags
                        .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY)
                    {
                        let condition_nur = self.add_ref(pointer);
                        let this_disruptor =
                            disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
                        if let Some(cause) = this_disruptor {
                            return Err(FunctionError::NonUniformWorkgroupUniformLoad(cause)
                                .with_span_static(*span, "WorkGroupUniformLoad"));
                        }
                    } */
                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                S::Block(ref b) => self.process_block(
                    b,
                    other_functions,
                    disruptor,
                    expression_arena,
                    diagnostic_filter_arena,
                )?,
                S::If {
                    condition,
                    ref accept,
                    ref reject,
                } => {
                    // Both branches run under the incoming disruptor, or under the
                    // condition's non-uniformity if the incoming flow was uniform.
                    let condition_nur = self.add_ref(condition);
                    let branch_disruptor =
                        disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
                    let accept_uniformity = self.process_block(
                        accept,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    let reject_uniformity = self.process_block(
                        reject,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    accept_uniformity | reject_uniformity
                }
                S::Switch {
                    selector,
                    ref cases,
                } => {
                    let selector_nur = self.add_ref(selector);
                    let branch_disruptor =
                        disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
                    let mut uniformity = FunctionUniformity::new();
                    let mut case_disruptor = branch_disruptor;
                    for case in cases.iter() {
                        let case_uniformity = self.process_block(
                            &case.body,
                            other_functions,
                            case_disruptor,
                            expression_arena,
                            diagnostic_filter_arena,
                        )?;
                        // A falling-through case carries its own exit disruptor into
                        // the next case; otherwise the next case starts afresh from
                        // the selector's disruptor.
                        case_disruptor = if case.fall_through {
                            case_disruptor.or(case_uniformity.exit_disruptor())
                        } else {
                            branch_disruptor
                        };
                        uniformity = uniformity | case_uniformity;
                    }
                    uniformity
                }
                S::Loop {
                    ref body,
                    ref continuing,
                    break_if,
                } => {
                    let body_uniformity = self.process_block(
                        body,
                        other_functions,
                        disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    // The continuing block also sees any exit disruptor from the body.
                    let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
                    let continuing_uniformity = self.process_block(
                        continuing,
                        other_functions,
                        continuing_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    // NOTE(review): `break_if` is only reference-counted here; its
                    // non-uniformity does not feed into the analysis.
                    if let Some(expr) = break_if {
                        let _ = self.add_ref(expr);
                    }
                    body_uniformity | continuing_uniformity
                }
                // As with `Kill`, a return only counts as an exit under non-uniform
                // flow; the returned value's non-uniformity is always reported.
                S::Return { value } => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
                        requirements: UniformityRequirements::empty(),
                    },
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_RETURN
                    } else {
                        ExitFlags::empty()
                    },
                },
                // Here and below, the used expressions are already emitted,
                // and their results do not affect the function return value,
                // so we can ignore their non-uniformity.
                S::Store { pointer, value } => {
                    let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::ImageStore {
                    image,
                    coordinate,
                    array_index,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::WRITE);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(coordinate);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::Call {
                    function,
                    ref arguments,
                    result: _,
                } => {
                    for &argument in arguments {
                        let _ = self.add_ref(argument);
                    }
                    let info = &other_functions[function.index()];
                    //Note: the result is validated by the Validator, not here
                    self.process_call(info, arguments, expression_arena)?
                }
                S::Atomic {
                    pointer,
                    ref fun,
                    value,
                    result: _,
                } => {
                    let _ = self.add_ref_impl(pointer, GlobalUse::READ | GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
                        let _ = self.add_ref(cmp);
                    }
                    FunctionUniformity::new()
                }
                S::ImageAtomic {
                    image,
                    coordinate,
                    array_index,
                    fun: _,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::ATOMIC);
                    let _ = self.add_ref(coordinate);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::RayQuery { query, ref fun } => {
                    let _ = self.add_ref(query);
                    if let crate::RayQueryFunction::Initialize {
                        acceleration_structure,
                        descriptor,
                    } = *fun
                    {
                        let _ = self.add_ref(acceleration_structure);
                        let _ = self.add_ref(descriptor);
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupBallot {
                    result: _,
                    predicate,
                } => {
                    if let Some(predicate) = predicate {
                        let _ = self.add_ref(predicate);
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupCollectiveOperation {
                    op: _,
                    collective_op: _,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    FunctionUniformity::new()
                }
                S::SubgroupGather {
                    mode,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    match mode {
                        crate::GatherMode::BroadcastFirst => {}
                        crate::GatherMode::Broadcast(index)
                        | crate::GatherMode::Shuffle(index)
                        | crate::GatherMode::ShuffleDown(index)
                        | crate::GatherMode::ShuffleUp(index)
                        | crate::GatherMode::ShuffleXor(index) => {
                            let _ = self.add_ref(index);
                        }
                    }
                    FunctionUniformity::new()
                }
            };

            // Any exit taken under non-uniform flow disrupts the rest of the block.
            disruptor = disruptor.or(uniformity.exit_disruptor());
            combined_uniformity = combined_uniformity | uniformity;
        }
        Ok(combined_uniformity)
    }
1136}
1137
1138impl ModuleInfo {
1139    /// Populates `self.const_expression_types`
1140    pub(super) fn process_const_expression(
1141        &mut self,
1142        handle: Handle<crate::Expression>,
1143        resolve_context: &ResolveContext,
1144        gctx: crate::proc::GlobalCtx,
1145    ) -> Result<(), super::ConstExpressionError> {
1146        self.const_expression_types[handle.index()] =
1147            resolve_context.resolve(&gctx.global_expressions[handle], |h| Ok(&self[h]))?;
1148        Ok(())
1149    }
1150
1151    /// Builds the `FunctionInfo` based on the function, and validates the
1152    /// uniform control flow if required by the expressions of this function.
1153    pub(super) fn process_function(
1154        &self,
1155        fun: &crate::Function,
1156        module: &crate::Module,
1157        flags: ValidationFlags,
1158        capabilities: super::Capabilities,
1159    ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
1160        let mut info = FunctionInfo {
1161            flags,
1162            available_stages: ShaderStages::all(),
1163            uniformity: Uniformity::new(),
1164            may_kill: false,
1165            sampling_set: crate::FastHashSet::default(),
1166            global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
1167            expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
1168            sampling: crate::FastHashSet::default(),
1169            dual_source_blending: false,
1170            diagnostic_filter_leaf: fun.diagnostic_filter_leaf,
1171        };
1172        let resolve_context =
1173            ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments);
1174
1175        for (handle, _) in fun.expressions.iter() {
1176            if let Err(source) = info.process_expression(
1177                handle,
1178                &fun.expressions,
1179                &self.functions,
1180                &resolve_context,
1181                capabilities,
1182            ) {
1183                return Err(FunctionError::Expression { handle, source }
1184                    .with_span_handle(handle, &fun.expressions));
1185            }
1186        }
1187
1188        for (_, expr) in fun.local_variables.iter() {
1189            if let Some(init) = expr.init {
1190                let _ = info.add_ref(init);
1191            }
1192        }
1193
1194        let uniformity = info.process_block(
1195            &fun.body,
1196            &self.functions,
1197            None,
1198            &fun.expressions,
1199            &module.diagnostic_filters,
1200        )?;
1201        info.uniformity = uniformity.result;
1202        info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
1203
1204        Ok(info)
1205    }
1206
1207    pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
1208        &self.entry_points[index]
1209    }
1210}
1211
// Exercises `FunctionInfo::process_expression` and `FunctionInfo::process_block`
// on a hand-built expression arena. It checks reference counts, global usage flags,
// uniformity requirements, and exit flags, in both uniform and non-uniform flow.
#[test]
fn uniform_control_flow() {
    use crate::{Expression as E, Statement as S};

    let mut type_arena = crate::UniqueArena::new();
    let ty = type_arena.insert(
        crate::Type {
            name: None,
            inner: crate::TypeInner::Vector {
                size: crate::VectorSize::Bi,
                scalar: crate::Scalar::F32,
            },
        },
        Default::default(),
    );
    let mut global_var_arena = Arena::new();
    // `Handle`-space globals are treated as non-uniform by the analyzer.
    let non_uniform_global = global_var_arena.append(
        crate::GlobalVariable {
            name: None,
            init: None,
            ty,
            space: crate::AddressSpace::Handle,
            binding: None,
        },
        Default::default(),
    );
    // `Uniform`-space globals are treated as uniform.
    let uniform_global = global_var_arena.append(
        crate::GlobalVariable {
            name: None,
            init: None,
            ty,
            binding: None,
            space: crate::AddressSpace::Uniform,
        },
        Default::default(),
    );

    let mut expressions = Arena::new();
    // A uniform value to take the derivative of.
    let constant_expr = expressions.append(E::Literal(crate::Literal::U32(0)), Default::default());
    // A derivative, which requires uniform control flow (when enabled).
    let derivative_expr = expressions.append(
        E::Derivative {
            axis: crate::DerivativeAxis::X,
            ctrl: crate::DerivativeControl::None,
            expr: constant_expr,
        },
        Default::default(),
    );
    let emit_range_constant_derivative = expressions.range_from(0);
    let non_uniform_global_expr =
        expressions.append(E::GlobalVariable(non_uniform_global), Default::default());
    let uniform_global_expr =
        expressions.append(E::GlobalVariable(uniform_global), Default::default());
    let emit_range_globals = expressions.range_from(2);

    // checks the QUERY flag
    let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default());
    // checks the transitive WRITE flag
    let access_expr = expressions.append(
        E::AccessIndex {
            base: non_uniform_global_expr,
            index: 1,
        },
        Default::default(),
    );
    let emit_range_query_access_globals = expressions.range_from(2);

    let mut info = FunctionInfo {
        flags: ValidationFlags::all(),
        available_stages: ShaderStages::all(),
        uniformity: Uniformity::new(),
        may_kill: false,
        sampling_set: crate::FastHashSet::default(),
        global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
        expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
        sampling: crate::FastHashSet::default(),
        dual_source_blending: false,
        diagnostic_filter_leaf: None,
    };
    let resolve_context = ResolveContext {
        constants: &Arena::new(),
        overrides: &Arena::new(),
        types: &type_arena,
        special_types: &crate::SpecialTypes::default(),
        global_vars: &global_var_arena,
        local_vars: &Arena::new(),
        functions: &Arena::new(),
        arguments: &[],
    };
    for (handle, _) in expressions.iter() {
        info.process_expression(
            handle,
            &expressions,
            &[],
            &resolve_context,
            super::Capabilities::empty(),
        )
        .unwrap();
    }
    assert_eq!(info[non_uniform_global_expr].ref_count, 1);
    assert_eq!(info[uniform_global_expr].ref_count, 1);
    assert_eq!(info[query_expr].ref_count, 0);
    assert_eq!(info[access_expr].ref_count, 0);
    assert_eq!(info[non_uniform_global], GlobalUse::empty());
    assert_eq!(info[uniform_global], GlobalUse::QUERY);

    // A derivative under a uniform condition is always accepted.
    let stmt_emit1 = S::Emit(emit_range_globals.clone());
    let stmt_if_uniform = S::If {
        condition: uniform_global_expr,
        accept: crate::Block::new(),
        reject: vec![
            S::Emit(emit_range_constant_derivative.clone()),
            S::Store {
                pointer: constant_expr,
                value: derivative_expr,
            },
        ]
        .into(),
    };
    assert_eq!(
        info.process_block(
            &vec![stmt_emit1, stmt_if_uniform].into(),
            &[],
            None,
            &expressions,
            &Arena::new(),
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: None,
                requirements: UniformityRequirements::DERIVATIVE,
            },
            exit: ExitFlags::empty(),
        }),
    );
    assert_eq!(info[constant_expr].ref_count, 2);
    assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);

    // A derivative under a non-uniform condition.
    let stmt_emit2 = S::Emit(emit_range_globals.clone());
    let stmt_if_non_uniform = S::If {
        condition: non_uniform_global_expr,
        accept: vec![
            S::Emit(emit_range_constant_derivative),
            S::Store {
                pointer: constant_expr,
                value: derivative_expr,
            },
        ]
        .into(),
        reject: crate::Block::new(),
    };
    {
        let block_info = info.process_block(
            &vec![stmt_emit2.clone(), stmt_if_non_uniform.clone()].into(),
            &[],
            None,
            &expressions,
            &Arena::new(),
        );
        if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE {
            // With the fragment-stage requirements compiled out, `DERIVATIVE` is
            // empty, so the derivative in non-uniform flow must be accepted.
            assert_eq!(
                block_info,
                Ok(FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::DERIVATIVE,
                    },
                    exit: ExitFlags::empty(),
                }),
            );
            assert_eq!(info[derivative_expr].ref_count, 2);
        } else {
            assert_eq!(
                block_info,
                Err(FunctionError::NonUniformControlFlow(
                    UniformityRequirements::DERIVATIVE,
                    derivative_expr,
                    UniformityDisruptor::Expression(non_uniform_global_expr)
                )
                .with_span()),
            );
            assert_eq!(info[derivative_expr].ref_count, 1);

            // Test that the same block passes when the `derivative_uniformity`
            // diagnostic rule is turned off by a diagnostic filter.
            let mut diagnostic_filters = Arena::new();
            let diagnostic_filter_leaf = diagnostic_filters.append(
                DiagnosticFilterNode {
                    inner: crate::diagnostic_filter::DiagnosticFilter {
                        new_severity: crate::diagnostic_filter::Severity::Off,
                        triggering_rule:
                            crate::diagnostic_filter::FilterableTriggeringRule::Standard(
                                StandardFilterableTriggeringRule::DerivativeUniformity,
                            ),
                    },
                    parent: None,
                },
                crate::Span::default(),
            );
            let mut info = FunctionInfo {
                diagnostic_filter_leaf: Some(diagnostic_filter_leaf),
                ..info.clone()
            };

            let block_info = info.process_block(
                &vec![stmt_emit2, stmt_if_non_uniform].into(),
                &[],
                None,
                &expressions,
                &diagnostic_filters,
            );
            assert_eq!(
                block_info,
                Ok(FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::DERIVATIVE,
                    },
                    exit: ExitFlags::empty()
                }),
            );
            assert_eq!(info[derivative_expr].ref_count, 2);
        }
    }
    assert_eq!(info[non_uniform_global], GlobalUse::READ);

    // Returning a non-uniform value under a disruptor sets `MAY_RETURN`.
    let stmt_emit3 = S::Emit(emit_range_globals);
    let stmt_return_non_uniform = S::Return {
        value: Some(non_uniform_global_expr),
    };
    assert_eq!(
        info.process_block(
            &vec![stmt_emit3, stmt_return_non_uniform].into(),
            &[],
            Some(UniformityDisruptor::Return),
            &expressions,
            &Arena::new(),
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_global_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::MAY_RETURN,
        }),
    );
    assert_eq!(info[non_uniform_global_expr].ref_count, 3);

    // Check that non-uniformity reaches through a pointer: `access_expr` is
    // derived from a non-uniform global, and storing through it marks the global
    // as written.
    let stmt_emit4 = S::Emit(emit_range_query_access_globals);
    let stmt_assign = S::Store {
        pointer: access_expr,
        value: query_expr,
    };
    let stmt_return_pointer = S::Return {
        value: Some(access_expr),
    };
    let stmt_kill = S::Kill;
    assert_eq!(
        info.process_block(
            &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(),
            &[],
            Some(UniformityDisruptor::Discard),
            &expressions,
            &Arena::new(),
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_global_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::all(),
        }),
    );
    assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
}