naga/back/
pipeline_constants.rs

1use super::PipelineConstants;
2use crate::{
3    arena::HandleVec,
4    proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
5    valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
6    Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
7    Span, Statement, TypeInner, WithSpan,
8};
9use std::{borrow::Cow, collections::HashSet, mem};
10use thiserror::Error;
11
12#[derive(Error, Debug, Clone)]
13#[cfg_attr(test, derive(PartialEq))]
14pub enum PipelineConstantError {
15    #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
16    MissingValue(String),
17    #[error(
18        "Source f64 value needs to be finite ({}) for number destinations",
19        "NaNs and Inifinites are not allowed"
20    )]
21    SrcNeedsToBeFinite,
22    #[error("Source f64 value doesn't fit in destination")]
23    DstRangeTooSmall,
24    #[error(transparent)]
25    ConstantEvaluatorError(#[from] ConstantEvaluatorError),
26    #[error(transparent)]
27    ValidationError(#[from] WithSpan<ValidationError>),
28    #[error("workgroup_size override isn't strictly positive")]
29    NegativeWorkgroupSize,
30}
31
32/// Replace all overrides in `module` with constants.
33///
34/// If no changes are needed, this just returns `Cow::Borrowed`
35/// references to `module` and `module_info`. Otherwise, it clones
36/// `module`, edits its [`global_expressions`] arena to contain only
37/// fully-evaluated expressions, and returns `Cow::Owned` values
38/// holding the simplified module and its validation results.
39///
40/// In either case, the module returned has an empty `overrides`
41/// arena, and the `global_expressions` arena contains only
42/// fully-evaluated expressions.
43///
44/// [`global_expressions`]: Module::global_expressions
45pub fn process_overrides<'a>(
46    module: &'a Module,
47    module_info: &'a ModuleInfo,
48    pipeline_constants: &PipelineConstants,
49) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> {
50    if module.overrides.is_empty() {
51        return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
52    }
53
54    let mut module = module.clone();
55
56    // A map from override handles to the handles of the constants
57    // we've replaced them with.
58    let mut override_map = HandleVec::with_capacity(module.overrides.len());
59
60    // A map from `module`'s original global expression handles to
61    // handles in the new, simplified global expression arena.
62    let mut adjusted_global_expressions = HandleVec::with_capacity(module.global_expressions.len());
63
64    // The set of constants whose initializer handles we've already
65    // updated to refer to the newly built global expression arena.
66    //
67    // All constants in `module` must have their `init` handles
68    // updated to point into the new, simplified global expression
69    // arena. Some of these we can most easily handle as a side effect
70    // during the simplification process, but we must handle the rest
71    // in a final fixup pass, guided by `adjusted_global_expressions`. We
72    // add their handles to this set, so that the final fixup step can
73    // leave them alone.
74    let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len());
75
76    let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
77
78    // An iterator through the original overrides table, consumed in
79    // approximate tandem with the global expressions.
80    let mut override_iter = module.overrides.drain();
81
82    // Do two things in tandem:
83    //
84    // - Rebuild the global expression arena from scratch, fully
85    //   evaluating all expressions, and replacing each `Override`
86    //   expression in `module.global_expressions` with a `Constant`
87    //   expression.
88    //
89    // - Build a new `Constant` in `module.constants` to take the
90    //   place of each `Override`.
91    //
92    // Build a map from old global expression handles to their
93    // fully-evaluated counterparts in `adjusted_global_expressions` as we
94    // go.
95    //
96    // Why in tandem? Overrides refer to expressions, and expressions
97    // refer to overrides, so we can't disentangle the two into
98    // separate phases. However, we can take advantage of the fact
99    // that the overrides and expressions must form a DAG, and work
100    // our way from the leaves to the roots, replacing and evaluating
101    // as we go.
102    //
103    // Although the two loops are nested, this is really two
104    // alternating phases: we adjust and evaluate constant expressions
105    // until we hit an `Override` expression, at which point we switch
106    // to building `Constant`s for `Overrides` until we've handled the
107    // one used by the expression. Then we switch back to processing
108    // expressions. Because we know they form a DAG, we know the
109    // `Override` expressions we encounter can only have initializers
110    // referring to global expressions we've already simplified.
111    for (old_h, expr, span) in module.global_expressions.drain() {
112        let mut expr = match expr {
113            Expression::Override(h) => {
114                let c_h = if let Some(new_h) = override_map.get(h) {
115                    *new_h
116                } else {
117                    let mut new_h = None;
118                    for entry in override_iter.by_ref() {
119                        let stop = entry.0 == h;
120                        new_h = Some(process_override(
121                            entry,
122                            pipeline_constants,
123                            &mut module,
124                            &mut override_map,
125                            &adjusted_global_expressions,
126                            &mut adjusted_constant_initializers,
127                            &mut global_expression_kind_tracker,
128                        )?);
129                        if stop {
130                            break;
131                        }
132                    }
133                    new_h.unwrap()
134                };
135                Expression::Constant(c_h)
136            }
137            Expression::Constant(c_h) => {
138                if adjusted_constant_initializers.insert(c_h) {
139                    let init = &mut module.constants[c_h].init;
140                    *init = adjusted_global_expressions[*init];
141                }
142                expr
143            }
144            expr => expr,
145        };
146        let mut evaluator = ConstantEvaluator::for_wgsl_module(
147            &mut module,
148            &mut global_expression_kind_tracker,
149            false,
150        );
151        adjust_expr(&adjusted_global_expressions, &mut expr);
152        let h = evaluator.try_eval_and_append(expr, span)?;
153        adjusted_global_expressions.insert(old_h, h);
154    }
155
156    // Finish processing any overrides we didn't visit in the loop above.
157    for entry in override_iter {
158        process_override(
159            entry,
160            pipeline_constants,
161            &mut module,
162            &mut override_map,
163            &adjusted_global_expressions,
164            &mut adjusted_constant_initializers,
165            &mut global_expression_kind_tracker,
166        )?;
167    }
168
169    // Update the initialization expression handles of all `Constant`s
170    // and `GlobalVariable`s. Skip `Constant`s we'd already updated en
171    // passant.
172    for (_, c) in module
173        .constants
174        .iter_mut()
175        .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h))
176    {
177        c.init = adjusted_global_expressions[c.init];
178    }
179
180    for (_, v) in module.global_variables.iter_mut() {
181        if let Some(ref mut init) = v.init {
182            *init = adjusted_global_expressions[*init];
183        }
184    }
185
186    let mut functions = mem::take(&mut module.functions);
187    for (_, function) in functions.iter_mut() {
188        process_function(&mut module, &override_map, function)?;
189    }
190    module.functions = functions;
191
192    let mut entry_points = mem::take(&mut module.entry_points);
193    for ep in entry_points.iter_mut() {
194        process_function(&mut module, &override_map, &mut ep.function)?;
195        process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?;
196    }
197    module.entry_points = entry_points;
198
199    process_pending(&mut module, &override_map, &adjusted_global_expressions)?;
200
201    // Now that we've rewritten all the expressions, we need to
202    // recompute their types and other metadata. For the time being,
203    // do a full re-validation.
204    let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
205    let module_info = validator.validate_no_overrides(&module)?;
206
207    Ok((Cow::Owned(module), Cow::Owned(module_info)))
208}
209
210fn process_pending(
211    module: &mut Module,
212    override_map: &HandleVec<Override, Handle<Constant>>,
213    adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
214) -> Result<(), PipelineConstantError> {
215    for (handle, ty) in module.types.clone().iter() {
216        if let TypeInner::Array {
217            base,
218            size: crate::ArraySize::Pending(size),
219            stride,
220        } = ty.inner
221        {
222            let expr = match size {
223                crate::PendingArraySize::Expression(size_expr) => {
224                    adjusted_global_expressions[size_expr]
225                }
226                crate::PendingArraySize::Override(size_override) => {
227                    module.constants[override_map[size_override]].init
228                }
229            };
230            let value = module
231                .to_ctx()
232                .eval_expr_to_u32(expr)
233                .map(|n| {
234                    if n == 0 {
235                        Err(PipelineConstantError::ValidationError(
236                            WithSpan::new(ValidationError::ArraySizeError { handle: expr })
237                                .with_span(
238                                    module.global_expressions.get_span(expr),
239                                    "evaluated to zero",
240                                ),
241                        ))
242                    } else {
243                        Ok(std::num::NonZeroU32::new(n).unwrap())
244                    }
245                })
246                .map_err(|_| {
247                    PipelineConstantError::ValidationError(
248                        WithSpan::new(ValidationError::ArraySizeError { handle: expr })
249                            .with_span(module.global_expressions.get_span(expr), "negative"),
250                    )
251                })??;
252            module.types.replace(
253                handle,
254                crate::Type {
255                    name: None,
256                    inner: TypeInner::Array {
257                        base,
258                        size: crate::ArraySize::Constant(value),
259                        stride,
260                    },
261                },
262            );
263        }
264    }
265    Ok(())
266}
267
268fn process_workgroup_size_override(
269    module: &mut Module,
270    adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
271    ep: &mut crate::EntryPoint,
272) -> Result<(), PipelineConstantError> {
273    match ep.workgroup_size_overrides {
274        None => {}
275        Some(overrides) => {
276            overrides.iter().enumerate().try_for_each(
277                |(i, overridden)| -> Result<(), PipelineConstantError> {
278                    match *overridden {
279                        None => Ok(()),
280                        Some(h) => {
281                            ep.workgroup_size[i] = module
282                                .to_ctx()
283                                .eval_expr_to_u32(adjusted_global_expressions[h])
284                                .map(|n| {
285                                    if n == 0 {
286                                        Err(PipelineConstantError::NegativeWorkgroupSize)
287                                    } else {
288                                        Ok(n)
289                                    }
290                                })
291                                .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)??;
292                            Ok(())
293                        }
294                    }
295                },
296            )?;
297            ep.workgroup_size_overrides = None;
298        }
299    }
300    Ok(())
301}
302
303/// Add a [`Constant`] to `module` for the override `old_h`.
304///
305/// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`.
306fn process_override(
307    (old_h, override_, span): (Handle<Override>, Override, Span),
308    pipeline_constants: &PipelineConstants,
309    module: &mut Module,
310    override_map: &mut HandleVec<Override, Handle<Constant>>,
311    adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
312    adjusted_constant_initializers: &mut HashSet<Handle<Constant>>,
313    global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker,
314) -> Result<Handle<Constant>, PipelineConstantError> {
315    // Determine which key to use for `override_` in `pipeline_constants`.
316    let key = if let Some(id) = override_.id {
317        Cow::Owned(id.to_string())
318    } else if let Some(ref name) = override_.name {
319        Cow::Borrowed(name)
320    } else {
321        unreachable!();
322    };
323
324    // Generate a global expression for `override_`'s value, either
325    // from the provided `pipeline_constants` table or its initializer
326    // in the module.
327    let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
328        let literal = match module.types[override_.ty].inner {
329            TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
330            _ => unreachable!(),
331        };
332        let expr = module
333            .global_expressions
334            .append(Expression::Literal(literal), Span::UNDEFINED);
335        global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const);
336        expr
337    } else if let Some(init) = override_.init {
338        adjusted_global_expressions[init]
339    } else {
340        return Err(PipelineConstantError::MissingValue(key.to_string()));
341    };
342
343    // Generate a new `Constant` to represent the override's value.
344    let constant = Constant {
345        name: override_.name,
346        ty: override_.ty,
347        init,
348    };
349    let h = module.constants.append(constant, span);
350    override_map.insert(old_h, h);
351    adjusted_constant_initializers.insert(h);
352    Ok(h)
353}
354
355/// Replace all override expressions in `function` with fully-evaluated constants.
356///
357/// Replace all `Expression::Override`s in `function`'s expression arena with
358/// the corresponding `Expression::Constant`s, as given in `override_map`.
359/// Replace any expressions whose values are now known with their fully
360/// evaluated form.
361///
362/// If `h` is a `Handle<Override>`, then `override_map[h]` is the
363/// `Handle<Constant>` for the override's final value.
364fn process_function(
365    module: &mut Module,
366    override_map: &HandleVec<Override, Handle<Constant>>,
367    function: &mut Function,
368) -> Result<(), ConstantEvaluatorError> {
369    // A map from original local expression handles to
370    // handles in the new, local expression arena.
371    let mut adjusted_local_expressions = HandleVec::with_capacity(function.expressions.len());
372
373    let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
374
375    let mut expressions = mem::take(&mut function.expressions);
376
377    // Dummy `emitter` and `block` for the constant evaluator.
378    // We can ignore the concept of emitting expressions here since
379    // expressions have already been covered by a `Statement::Emit`
380    // in the frontend.
381    // The only thing we might have to do is remove some expressions
382    // that have been covered by a `Statement::Emit`. See the docs of
383    // `filter_emits_in_block` for the reasoning.
384    let mut emitter = Emitter::default();
385    let mut block = Block::new();
386
387    let mut evaluator = ConstantEvaluator::for_wgsl_function(
388        module,
389        &mut function.expressions,
390        &mut local_expression_kind_tracker,
391        &mut emitter,
392        &mut block,
393        false,
394    );
395
396    for (old_h, mut expr, span) in expressions.drain() {
397        if let Expression::Override(h) = expr {
398            expr = Expression::Constant(override_map[h]);
399        }
400        adjust_expr(&adjusted_local_expressions, &mut expr);
401        let h = evaluator.try_eval_and_append(expr, span)?;
402        adjusted_local_expressions.insert(old_h, h);
403    }
404
405    adjust_block(&adjusted_local_expressions, &mut function.body);
406
407    filter_emits_in_block(&mut function.body, &function.expressions);
408
409    // Update local expression initializers.
410    for (_, local) in function.local_variables.iter_mut() {
411        if let &mut Some(ref mut init) = &mut local.init {
412            *init = adjusted_local_expressions[*init];
413        }
414    }
415
416    // We've changed the keys of `function.named_expression`, so we have to
417    // rebuild it from scratch.
418    let named_expressions = mem::take(&mut function.named_expressions);
419    for (expr_h, name) in named_expressions {
420        function
421            .named_expressions
422            .insert(adjusted_local_expressions[expr_h], name);
423    }
424
425    Ok(())
426}
427
428/// Replace every expression handle in `expr` with its counterpart
429/// given by `new_pos`.
430fn adjust_expr(new_pos: &HandleVec<Expression, Handle<Expression>>, expr: &mut Expression) {
431    let adjust = |expr: &mut Handle<Expression>| {
432        *expr = new_pos[*expr];
433    };
434    match *expr {
435        Expression::Compose {
436            ref mut components,
437            ty: _,
438        } => {
439            for c in components.iter_mut() {
440                adjust(c);
441            }
442        }
443        Expression::Access {
444            ref mut base,
445            ref mut index,
446        } => {
447            adjust(base);
448            adjust(index);
449        }
450        Expression::AccessIndex {
451            ref mut base,
452            index: _,
453        } => {
454            adjust(base);
455        }
456        Expression::Splat {
457            ref mut value,
458            size: _,
459        } => {
460            adjust(value);
461        }
462        Expression::Swizzle {
463            ref mut vector,
464            size: _,
465            pattern: _,
466        } => {
467            adjust(vector);
468        }
469        Expression::Load { ref mut pointer } => {
470            adjust(pointer);
471        }
472        Expression::ImageSample {
473            ref mut image,
474            ref mut sampler,
475            ref mut coordinate,
476            ref mut array_index,
477            ref mut offset,
478            ref mut level,
479            ref mut depth_ref,
480            gather: _,
481        } => {
482            adjust(image);
483            adjust(sampler);
484            adjust(coordinate);
485            if let Some(e) = array_index.as_mut() {
486                adjust(e);
487            }
488            if let Some(e) = offset.as_mut() {
489                adjust(e);
490            }
491            match *level {
492                crate::SampleLevel::Exact(ref mut expr)
493                | crate::SampleLevel::Bias(ref mut expr) => {
494                    adjust(expr);
495                }
496                crate::SampleLevel::Gradient {
497                    ref mut x,
498                    ref mut y,
499                } => {
500                    adjust(x);
501                    adjust(y);
502                }
503                _ => {}
504            }
505            if let Some(e) = depth_ref.as_mut() {
506                adjust(e);
507            }
508        }
509        Expression::ImageLoad {
510            ref mut image,
511            ref mut coordinate,
512            ref mut array_index,
513            ref mut sample,
514            ref mut level,
515        } => {
516            adjust(image);
517            adjust(coordinate);
518            if let Some(e) = array_index.as_mut() {
519                adjust(e);
520            }
521            if let Some(e) = sample.as_mut() {
522                adjust(e);
523            }
524            if let Some(e) = level.as_mut() {
525                adjust(e);
526            }
527        }
528        Expression::ImageQuery {
529            ref mut image,
530            ref mut query,
531        } => {
532            adjust(image);
533            match *query {
534                crate::ImageQuery::Size { ref mut level } => {
535                    if let Some(e) = level.as_mut() {
536                        adjust(e);
537                    }
538                }
539                crate::ImageQuery::NumLevels
540                | crate::ImageQuery::NumLayers
541                | crate::ImageQuery::NumSamples => {}
542            }
543        }
544        Expression::Unary {
545            ref mut expr,
546            op: _,
547        } => {
548            adjust(expr);
549        }
550        Expression::Binary {
551            ref mut left,
552            ref mut right,
553            op: _,
554        } => {
555            adjust(left);
556            adjust(right);
557        }
558        Expression::Select {
559            ref mut condition,
560            ref mut accept,
561            ref mut reject,
562        } => {
563            adjust(condition);
564            adjust(accept);
565            adjust(reject);
566        }
567        Expression::Derivative {
568            ref mut expr,
569            axis: _,
570            ctrl: _,
571        } => {
572            adjust(expr);
573        }
574        Expression::Relational {
575            ref mut argument,
576            fun: _,
577        } => {
578            adjust(argument);
579        }
580        Expression::Math {
581            ref mut arg,
582            ref mut arg1,
583            ref mut arg2,
584            ref mut arg3,
585            fun: _,
586        } => {
587            adjust(arg);
588            if let Some(e) = arg1.as_mut() {
589                adjust(e);
590            }
591            if let Some(e) = arg2.as_mut() {
592                adjust(e);
593            }
594            if let Some(e) = arg3.as_mut() {
595                adjust(e);
596            }
597        }
598        Expression::As {
599            ref mut expr,
600            kind: _,
601            convert: _,
602        } => {
603            adjust(expr);
604        }
605        Expression::ArrayLength(ref mut expr) => {
606            adjust(expr);
607        }
608        Expression::RayQueryGetIntersection {
609            ref mut query,
610            committed: _,
611        } => {
612            adjust(query);
613        }
614        Expression::Literal(_)
615        | Expression::FunctionArgument(_)
616        | Expression::GlobalVariable(_)
617        | Expression::LocalVariable(_)
618        | Expression::CallResult(_)
619        | Expression::RayQueryProceedResult
620        | Expression::Constant(_)
621        | Expression::Override(_)
622        | Expression::ZeroValue(_)
623        | Expression::AtomicResult {
624            ty: _,
625            comparison: _,
626        }
627        | Expression::WorkGroupUniformLoadResult { ty: _ }
628        | Expression::SubgroupBallotResult
629        | Expression::SubgroupOperationResult { .. } => {}
630    }
631}
632
633/// Replace every expression handle in `block` with its counterpart
634/// given by `new_pos`.
635fn adjust_block(new_pos: &HandleVec<Expression, Handle<Expression>>, block: &mut Block) {
636    for stmt in block.iter_mut() {
637        adjust_stmt(new_pos, stmt);
638    }
639}
640
641/// Replace every expression handle in `stmt` with its counterpart
642/// given by `new_pos`.
643fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut Statement) {
644    let adjust = |expr: &mut Handle<Expression>| {
645        *expr = new_pos[*expr];
646    };
647    match *stmt {
648        Statement::Emit(ref mut range) => {
649            if let Some((mut first, mut last)) = range.first_and_last() {
650                adjust(&mut first);
651                adjust(&mut last);
652                *range = Range::new_from_bounds(first, last);
653            }
654        }
655        Statement::Block(ref mut block) => {
656            adjust_block(new_pos, block);
657        }
658        Statement::If {
659            ref mut condition,
660            ref mut accept,
661            ref mut reject,
662        } => {
663            adjust(condition);
664            adjust_block(new_pos, accept);
665            adjust_block(new_pos, reject);
666        }
667        Statement::Switch {
668            ref mut selector,
669            ref mut cases,
670        } => {
671            adjust(selector);
672            for case in cases.iter_mut() {
673                adjust_block(new_pos, &mut case.body);
674            }
675        }
676        Statement::Loop {
677            ref mut body,
678            ref mut continuing,
679            ref mut break_if,
680        } => {
681            adjust_block(new_pos, body);
682            adjust_block(new_pos, continuing);
683            if let Some(e) = break_if.as_mut() {
684                adjust(e);
685            }
686        }
687        Statement::Return { ref mut value } => {
688            if let Some(e) = value.as_mut() {
689                adjust(e);
690            }
691        }
692        Statement::Store {
693            ref mut pointer,
694            ref mut value,
695        } => {
696            adjust(pointer);
697            adjust(value);
698        }
699        Statement::ImageStore {
700            ref mut image,
701            ref mut coordinate,
702            ref mut array_index,
703            ref mut value,
704        } => {
705            adjust(image);
706            adjust(coordinate);
707            if let Some(e) = array_index.as_mut() {
708                adjust(e);
709            }
710            adjust(value);
711        }
712        Statement::Atomic {
713            ref mut pointer,
714            ref mut value,
715            ref mut result,
716            ref mut fun,
717        } => {
718            adjust(pointer);
719            adjust(value);
720            if let Some(ref mut result) = *result {
721                adjust(result);
722            }
723            match *fun {
724                crate::AtomicFunction::Exchange {
725                    compare: Some(ref mut compare),
726                } => {
727                    adjust(compare);
728                }
729                crate::AtomicFunction::Add
730                | crate::AtomicFunction::Subtract
731                | crate::AtomicFunction::And
732                | crate::AtomicFunction::ExclusiveOr
733                | crate::AtomicFunction::InclusiveOr
734                | crate::AtomicFunction::Min
735                | crate::AtomicFunction::Max
736                | crate::AtomicFunction::Exchange { compare: None } => {}
737            }
738        }
739        Statement::ImageAtomic {
740            ref mut image,
741            ref mut coordinate,
742            ref mut array_index,
743            fun: _,
744            ref mut value,
745        } => {
746            adjust(image);
747            adjust(coordinate);
748            if let Some(ref mut array_index) = *array_index {
749                adjust(array_index);
750            }
751            adjust(value);
752        }
753        Statement::WorkGroupUniformLoad {
754            ref mut pointer,
755            ref mut result,
756        } => {
757            adjust(pointer);
758            adjust(result);
759        }
760        Statement::SubgroupBallot {
761            ref mut result,
762            ref mut predicate,
763        } => {
764            if let Some(ref mut predicate) = *predicate {
765                adjust(predicate);
766            }
767            adjust(result);
768        }
769        Statement::SubgroupCollectiveOperation {
770            ref mut argument,
771            ref mut result,
772            ..
773        } => {
774            adjust(argument);
775            adjust(result);
776        }
777        Statement::SubgroupGather {
778            ref mut mode,
779            ref mut argument,
780            ref mut result,
781        } => {
782            match *mode {
783                crate::GatherMode::BroadcastFirst => {}
784                crate::GatherMode::Broadcast(ref mut index)
785                | crate::GatherMode::Shuffle(ref mut index)
786                | crate::GatherMode::ShuffleDown(ref mut index)
787                | crate::GatherMode::ShuffleUp(ref mut index)
788                | crate::GatherMode::ShuffleXor(ref mut index) => {
789                    adjust(index);
790                }
791            }
792            adjust(argument);
793            adjust(result)
794        }
795        Statement::Call {
796            ref mut arguments,
797            ref mut result,
798            function: _,
799        } => {
800            for argument in arguments.iter_mut() {
801                adjust(argument);
802            }
803            if let Some(e) = result.as_mut() {
804                adjust(e);
805            }
806        }
807        Statement::RayQuery {
808            ref mut query,
809            ref mut fun,
810        } => {
811            adjust(query);
812            match *fun {
813                crate::RayQueryFunction::Initialize {
814                    ref mut acceleration_structure,
815                    ref mut descriptor,
816                } => {
817                    adjust(acceleration_structure);
818                    adjust(descriptor);
819                }
820                crate::RayQueryFunction::Proceed { ref mut result } => {
821                    adjust(result);
822                }
823                crate::RayQueryFunction::Terminate => {}
824            }
825        }
826        Statement::Break | Statement::Continue | Statement::Kill | Statement::Barrier(_) => {}
827    }
828}
829
830/// Adjust [`Emit`] statements in `block` to skip [`needs_pre_emit`] expressions we have introduced.
831///
832/// According to validation, [`Emit`] statements must not cover any expressions
833/// for which [`Expression::needs_pre_emit`] returns true. All expressions built
834/// by successful constant evaluation fall into that category, meaning that
835/// `process_function` will usually rewrite [`Override`] expressions and those
836/// that use their values into pre-emitted expressions, leaving any [`Emit`]
837/// statements that cover them invalid.
838///
839/// This function rewrites all [`Emit`] statements into zero or more new
840/// [`Emit`] statements covering only those expressions in the original range
841/// that are not pre-emitted.
842///
843/// [`Emit`]: Statement::Emit
844/// [`needs_pre_emit`]: Expression::needs_pre_emit
845/// [`Override`]: Expression::Override
846fn filter_emits_in_block(block: &mut Block, expressions: &Arena<Expression>) {
847    let original = mem::replace(block, Block::with_capacity(block.len()));
848    for (stmt, span) in original.span_into_iter() {
849        match stmt {
850            Statement::Emit(range) => {
851                let mut current = None;
852                for expr_h in range {
853                    if expressions[expr_h].needs_pre_emit() {
854                        if let Some((first, last)) = current {
855                            block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
856                        }
857
858                        current = None;
859                    } else if let Some((_, ref mut last)) = current {
860                        *last = expr_h;
861                    } else {
862                        current = Some((expr_h, expr_h));
863                    }
864                }
865                if let Some((first, last)) = current {
866                    block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
867                }
868            }
869            Statement::Block(mut child) => {
870                filter_emits_in_block(&mut child, expressions);
871                block.push(Statement::Block(child), span);
872            }
873            Statement::If {
874                condition,
875                mut accept,
876                mut reject,
877            } => {
878                filter_emits_in_block(&mut accept, expressions);
879                filter_emits_in_block(&mut reject, expressions);
880                block.push(
881                    Statement::If {
882                        condition,
883                        accept,
884                        reject,
885                    },
886                    span,
887                );
888            }
889            Statement::Switch {
890                selector,
891                mut cases,
892            } => {
893                for case in &mut cases {
894                    filter_emits_in_block(&mut case.body, expressions);
895                }
896                block.push(Statement::Switch { selector, cases }, span);
897            }
898            Statement::Loop {
899                mut body,
900                mut continuing,
901                break_if,
902            } => {
903                filter_emits_in_block(&mut body, expressions);
904                filter_emits_in_block(&mut continuing, expressions);
905                block.push(
906                    Statement::Loop {
907                        body,
908                        continuing,
909                        break_if,
910                    },
911                    span,
912                );
913            }
914            stmt => block.push(stmt.clone(), span),
915        }
916    }
917}
918
919fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
920    // note that in rust 0.0 == -0.0
921    match scalar {
922        Scalar::BOOL => {
923            // https://webidl.spec.whatwg.org/#js-boolean
924            let value = value != 0.0 && !value.is_nan();
925            Ok(Literal::Bool(value))
926        }
927        Scalar::I32 => {
928            // https://webidl.spec.whatwg.org/#js-long
929            if !value.is_finite() {
930                return Err(PipelineConstantError::SrcNeedsToBeFinite);
931            }
932
933            let value = value.trunc();
934            if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
935                return Err(PipelineConstantError::DstRangeTooSmall);
936            }
937
938            let value = value as i32;
939            Ok(Literal::I32(value))
940        }
941        Scalar::U32 => {
942            // https://webidl.spec.whatwg.org/#js-unsigned-long
943            if !value.is_finite() {
944                return Err(PipelineConstantError::SrcNeedsToBeFinite);
945            }
946
947            let value = value.trunc();
948            if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) {
949                return Err(PipelineConstantError::DstRangeTooSmall);
950            }
951
952            let value = value as u32;
953            Ok(Literal::U32(value))
954        }
955        Scalar::F32 => {
956            // https://webidl.spec.whatwg.org/#js-float
957            if !value.is_finite() {
958                return Err(PipelineConstantError::SrcNeedsToBeFinite);
959            }
960
961            let value = value as f32;
962            if !value.is_finite() {
963                return Err(PipelineConstantError::DstRangeTooSmall);
964            }
965
966            Ok(Literal::F32(value))
967        }
968        Scalar::F64 => {
969            // https://webidl.spec.whatwg.org/#js-double
970            if !value.is_finite() {
971                return Err(PipelineConstantError::SrcNeedsToBeFinite);
972            }
973
974            Ok(Literal::F64(value))
975        }
976        _ => unreachable!(),
977    }
978}
979
#[test]
fn test_map_value_to_literal() {
    // WebIDL boolean conversion: only the zeros and NaN are falsy.
    for (input, expected) in [
        (0.0, false),
        (-0.0, false),
        (f64::NAN, false),
        (1.0, true),
        (f64::INFINITY, true),
        (f64::NEG_INFINITY, true),
    ] {
        assert_eq!(
            map_value_to_literal(input, Scalar::BOOL),
            Ok(Literal::Bool(expected))
        );
    }

    // No numeric destination accepts a non-finite source.
    let non_finite = [f64::NAN, f64::INFINITY, f64::NEG_INFINITY];
    for scalar in [Scalar::I32, Scalar::U32, Scalar::F32, Scalar::F64] {
        for input in non_finite {
            assert_eq!(
                map_value_to_literal(input, scalar),
                Err(PipelineConstantError::SrcNeedsToBeFinite)
            );
        }
    }

    // Just below the f32 rounding boundary (still rounds to f32::MAX) and
    // exactly on it (rounds up to infinity).
    let rounds_to_f32_max = f64::from_bits(0x47efffffefffffff);
    let rounds_past_f32_max = f64::from_bits(0x47effffff0000000);

    // (input, destination, expected result), exercising each type's range ends.
    let range_cases = [
        // i32
        (f64::from(i32::MIN), Scalar::I32, Ok(Literal::I32(i32::MIN))),
        (f64::from(i32::MAX), Scalar::I32, Ok(Literal::I32(i32::MAX))),
        (
            f64::from(i32::MIN) - 1.0,
            Scalar::I32,
            Err(PipelineConstantError::DstRangeTooSmall),
        ),
        (
            f64::from(i32::MAX) + 1.0,
            Scalar::I32,
            Err(PipelineConstantError::DstRangeTooSmall),
        ),
        // u32
        (f64::from(u32::MIN), Scalar::U32, Ok(Literal::U32(u32::MIN))),
        (f64::from(u32::MAX), Scalar::U32, Ok(Literal::U32(u32::MAX))),
        (
            f64::from(u32::MIN) - 1.0,
            Scalar::U32,
            Err(PipelineConstantError::DstRangeTooSmall),
        ),
        (
            f64::from(u32::MAX) + 1.0,
            Scalar::U32,
            Err(PipelineConstantError::DstRangeTooSmall),
        ),
        // f32
        (f64::from(f32::MIN), Scalar::F32, Ok(Literal::F32(f32::MIN))),
        (f64::from(f32::MAX), Scalar::F32, Ok(Literal::F32(f32::MAX))),
        (-rounds_to_f32_max, Scalar::F32, Ok(Literal::F32(f32::MIN))),
        (rounds_to_f32_max, Scalar::F32, Ok(Literal::F32(f32::MAX))),
        (
            -rounds_past_f32_max,
            Scalar::F32,
            Err(PipelineConstantError::DstRangeTooSmall),
        ),
        (
            rounds_past_f32_max,
            Scalar::F32,
            Err(PipelineConstantError::DstRangeTooSmall),
        ),
        // f64
        (f64::MIN, Scalar::F64, Ok(Literal::F64(f64::MIN))),
        (f64::MAX, Scalar::F64, Ok(Literal::F64(f64::MAX))),
    ];
    for (input, scalar, expected) in range_cases {
        assert_eq!(map_value_to_literal(input, scalar), expected);
    }
}