naga/back/
pipeline_constants.rs

1use super::PipelineConstants;
2use crate::{
3    arena::HandleVec,
4    proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
5    valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
6    Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
7    Span, Statement, TypeInner, WithSpan,
8};
9use std::{borrow::Cow, collections::HashSet, mem};
10use thiserror::Error;
11
12#[derive(Error, Debug, Clone)]
13#[cfg_attr(test, derive(PartialEq))]
14pub enum PipelineConstantError {
15    #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
16    MissingValue(String),
17    #[error(
18        "Source f64 value needs to be finite ({}) for number destinations",
19        "NaNs and Inifinites are not allowed"
20    )]
21    SrcNeedsToBeFinite,
22    #[error("Source f64 value doesn't fit in destination")]
23    DstRangeTooSmall,
24    #[error(transparent)]
25    ConstantEvaluatorError(#[from] ConstantEvaluatorError),
26    #[error(transparent)]
27    ValidationError(#[from] WithSpan<ValidationError>),
28}
29
30/// Replace all overrides in `module` with constants.
31///
32/// If no changes are needed, this just returns `Cow::Borrowed`
33/// references to `module` and `module_info`. Otherwise, it clones
34/// `module`, edits its [`global_expressions`] arena to contain only
35/// fully-evaluated expressions, and returns `Cow::Owned` values
36/// holding the simplified module and its validation results.
37///
38/// In either case, the module returned has an empty `overrides`
39/// arena, and the `global_expressions` arena contains only
40/// fully-evaluated expressions.
41///
42/// [`global_expressions`]: Module::global_expressions
43pub fn process_overrides<'a>(
44    module: &'a Module,
45    module_info: &'a ModuleInfo,
46    pipeline_constants: &PipelineConstants,
47) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> {
48    if module.overrides.is_empty() {
49        return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
50    }
51
52    let mut module = module.clone();
53
54    // A map from override handles to the handles of the constants
55    // we've replaced them with.
56    let mut override_map = HandleVec::with_capacity(module.overrides.len());
57
58    // A map from `module`'s original global expression handles to
59    // handles in the new, simplified global expression arena.
60    let mut adjusted_global_expressions = HandleVec::with_capacity(module.global_expressions.len());
61
62    // The set of constants whose initializer handles we've already
63    // updated to refer to the newly built global expression arena.
64    //
65    // All constants in `module` must have their `init` handles
66    // updated to point into the new, simplified global expression
67    // arena. Some of these we can most easily handle as a side effect
68    // during the simplification process, but we must handle the rest
69    // in a final fixup pass, guided by `adjusted_global_expressions`. We
70    // add their handles to this set, so that the final fixup step can
71    // leave them alone.
72    let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len());
73
74    let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
75
76    // An iterator through the original overrides table, consumed in
77    // approximate tandem with the global expressions.
78    let mut override_iter = module.overrides.drain();
79
80    // Do two things in tandem:
81    //
82    // - Rebuild the global expression arena from scratch, fully
83    //   evaluating all expressions, and replacing each `Override`
84    //   expression in `module.global_expressions` with a `Constant`
85    //   expression.
86    //
87    // - Build a new `Constant` in `module.constants` to take the
88    //   place of each `Override`.
89    //
90    // Build a map from old global expression handles to their
91    // fully-evaluated counterparts in `adjusted_global_expressions` as we
92    // go.
93    //
94    // Why in tandem? Overrides refer to expressions, and expressions
95    // refer to overrides, so we can't disentangle the two into
96    // separate phases. However, we can take advantage of the fact
97    // that the overrides and expressions must form a DAG, and work
98    // our way from the leaves to the roots, replacing and evaluating
99    // as we go.
100    //
101    // Although the two loops are nested, this is really two
102    // alternating phases: we adjust and evaluate constant expressions
103    // until we hit an `Override` expression, at which point we switch
104    // to building `Constant`s for `Overrides` until we've handled the
105    // one used by the expression. Then we switch back to processing
106    // expressions. Because we know they form a DAG, we know the
107    // `Override` expressions we encounter can only have initializers
108    // referring to global expressions we've already simplified.
109    for (old_h, expr, span) in module.global_expressions.drain() {
110        let mut expr = match expr {
111            Expression::Override(h) => {
112                let c_h = if let Some(new_h) = override_map.get(h) {
113                    *new_h
114                } else {
115                    let mut new_h = None;
116                    for entry in override_iter.by_ref() {
117                        let stop = entry.0 == h;
118                        new_h = Some(process_override(
119                            entry,
120                            pipeline_constants,
121                            &mut module,
122                            &mut override_map,
123                            &adjusted_global_expressions,
124                            &mut adjusted_constant_initializers,
125                            &mut global_expression_kind_tracker,
126                        )?);
127                        if stop {
128                            break;
129                        }
130                    }
131                    new_h.unwrap()
132                };
133                Expression::Constant(c_h)
134            }
135            Expression::Constant(c_h) => {
136                if adjusted_constant_initializers.insert(c_h) {
137                    let init = &mut module.constants[c_h].init;
138                    *init = adjusted_global_expressions[*init];
139                }
140                expr
141            }
142            expr => expr,
143        };
144        let mut evaluator = ConstantEvaluator::for_wgsl_module(
145            &mut module,
146            &mut global_expression_kind_tracker,
147            false,
148        );
149        adjust_expr(&adjusted_global_expressions, &mut expr);
150        let h = evaluator.try_eval_and_append(expr, span)?;
151        adjusted_global_expressions.insert(old_h, h);
152    }
153
154    // Finish processing any overrides we didn't visit in the loop above.
155    for entry in override_iter {
156        process_override(
157            entry,
158            pipeline_constants,
159            &mut module,
160            &mut override_map,
161            &adjusted_global_expressions,
162            &mut adjusted_constant_initializers,
163            &mut global_expression_kind_tracker,
164        )?;
165    }
166
167    // Update the initialization expression handles of all `Constant`s
168    // and `GlobalVariable`s. Skip `Constant`s we'd already updated en
169    // passant.
170    for (_, c) in module
171        .constants
172        .iter_mut()
173        .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h))
174    {
175        c.init = adjusted_global_expressions[c.init];
176    }
177
178    for (_, v) in module.global_variables.iter_mut() {
179        if let Some(ref mut init) = v.init {
180            *init = adjusted_global_expressions[*init];
181        }
182    }
183
184    let mut functions = mem::take(&mut module.functions);
185    for (_, function) in functions.iter_mut() {
186        process_function(&mut module, &override_map, function)?;
187    }
188    module.functions = functions;
189
190    let mut entry_points = mem::take(&mut module.entry_points);
191    for ep in entry_points.iter_mut() {
192        process_function(&mut module, &override_map, &mut ep.function)?;
193    }
194    module.entry_points = entry_points;
195
196    // Now that we've rewritten all the expressions, we need to
197    // recompute their types and other metadata. For the time being,
198    // do a full re-validation.
199    let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
200    let module_info = validator.validate_no_overrides(&module)?;
201
202    Ok((Cow::Owned(module), Cow::Owned(module_info)))
203}
204
205/// Add a [`Constant`] to `module` for the override `old_h`.
206///
207/// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`.
208fn process_override(
209    (old_h, override_, span): (Handle<Override>, Override, Span),
210    pipeline_constants: &PipelineConstants,
211    module: &mut Module,
212    override_map: &mut HandleVec<Override, Handle<Constant>>,
213    adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
214    adjusted_constant_initializers: &mut HashSet<Handle<Constant>>,
215    global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker,
216) -> Result<Handle<Constant>, PipelineConstantError> {
217    // Determine which key to use for `override_` in `pipeline_constants`.
218    let key = if let Some(id) = override_.id {
219        Cow::Owned(id.to_string())
220    } else if let Some(ref name) = override_.name {
221        Cow::Borrowed(name)
222    } else {
223        unreachable!();
224    };
225
226    // Generate a global expression for `override_`'s value, either
227    // from the provided `pipeline_constants` table or its initializer
228    // in the module.
229    let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
230        let literal = match module.types[override_.ty].inner {
231            TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
232            _ => unreachable!(),
233        };
234        let expr = module
235            .global_expressions
236            .append(Expression::Literal(literal), Span::UNDEFINED);
237        global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const);
238        expr
239    } else if let Some(init) = override_.init {
240        adjusted_global_expressions[init]
241    } else {
242        return Err(PipelineConstantError::MissingValue(key.to_string()));
243    };
244
245    // Generate a new `Constant` to represent the override's value.
246    let constant = Constant {
247        name: override_.name,
248        ty: override_.ty,
249        init,
250    };
251    let h = module.constants.append(constant, span);
252    override_map.insert(old_h, h);
253    adjusted_constant_initializers.insert(h);
254    Ok(h)
255}
256
257/// Replace all override expressions in `function` with fully-evaluated constants.
258///
259/// Replace all `Expression::Override`s in `function`'s expression arena with
260/// the corresponding `Expression::Constant`s, as given in `override_map`.
261/// Replace any expressions whose values are now known with their fully
262/// evaluated form.
263///
264/// If `h` is a `Handle<Override>`, then `override_map[h]` is the
265/// `Handle<Constant>` for the override's final value.
266fn process_function(
267    module: &mut Module,
268    override_map: &HandleVec<Override, Handle<Constant>>,
269    function: &mut Function,
270) -> Result<(), ConstantEvaluatorError> {
271    // A map from original local expression handles to
272    // handles in the new, local expression arena.
273    let mut adjusted_local_expressions = HandleVec::with_capacity(function.expressions.len());
274
275    let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
276
277    let mut expressions = mem::take(&mut function.expressions);
278
279    // Dummy `emitter` and `block` for the constant evaluator.
280    // We can ignore the concept of emitting expressions here since
281    // expressions have already been covered by a `Statement::Emit`
282    // in the frontend.
283    // The only thing we might have to do is remove some expressions
284    // that have been covered by a `Statement::Emit`. See the docs of
285    // `filter_emits_in_block` for the reasoning.
286    let mut emitter = Emitter::default();
287    let mut block = Block::new();
288
289    let mut evaluator = ConstantEvaluator::for_wgsl_function(
290        module,
291        &mut function.expressions,
292        &mut local_expression_kind_tracker,
293        &mut emitter,
294        &mut block,
295        false,
296    );
297
298    for (old_h, mut expr, span) in expressions.drain() {
299        if let Expression::Override(h) = expr {
300            expr = Expression::Constant(override_map[h]);
301        }
302        adjust_expr(&adjusted_local_expressions, &mut expr);
303        let h = evaluator.try_eval_and_append(expr, span)?;
304        adjusted_local_expressions.insert(old_h, h);
305    }
306
307    adjust_block(&adjusted_local_expressions, &mut function.body);
308
309    filter_emits_in_block(&mut function.body, &function.expressions);
310
311    // Update local expression initializers.
312    for (_, local) in function.local_variables.iter_mut() {
313        if let &mut Some(ref mut init) = &mut local.init {
314            *init = adjusted_local_expressions[*init];
315        }
316    }
317
318    // We've changed the keys of `function.named_expression`, so we have to
319    // rebuild it from scratch.
320    let named_expressions = mem::take(&mut function.named_expressions);
321    for (expr_h, name) in named_expressions {
322        function
323            .named_expressions
324            .insert(adjusted_local_expressions[expr_h], name);
325    }
326
327    Ok(())
328}
329
330/// Replace every expression handle in `expr` with its counterpart
331/// given by `new_pos`.
332fn adjust_expr(new_pos: &HandleVec<Expression, Handle<Expression>>, expr: &mut Expression) {
333    let adjust = |expr: &mut Handle<Expression>| {
334        *expr = new_pos[*expr];
335    };
336    match *expr {
337        Expression::Compose {
338            ref mut components,
339            ty: _,
340        } => {
341            for c in components.iter_mut() {
342                adjust(c);
343            }
344        }
345        Expression::Access {
346            ref mut base,
347            ref mut index,
348        } => {
349            adjust(base);
350            adjust(index);
351        }
352        Expression::AccessIndex {
353            ref mut base,
354            index: _,
355        } => {
356            adjust(base);
357        }
358        Expression::Splat {
359            ref mut value,
360            size: _,
361        } => {
362            adjust(value);
363        }
364        Expression::Swizzle {
365            ref mut vector,
366            size: _,
367            pattern: _,
368        } => {
369            adjust(vector);
370        }
371        Expression::Load { ref mut pointer } => {
372            adjust(pointer);
373        }
374        Expression::ImageSample {
375            ref mut image,
376            ref mut sampler,
377            ref mut coordinate,
378            ref mut array_index,
379            ref mut offset,
380            ref mut level,
381            ref mut depth_ref,
382            gather: _,
383        } => {
384            adjust(image);
385            adjust(sampler);
386            adjust(coordinate);
387            if let Some(e) = array_index.as_mut() {
388                adjust(e);
389            }
390            if let Some(e) = offset.as_mut() {
391                adjust(e);
392            }
393            match *level {
394                crate::SampleLevel::Exact(ref mut expr)
395                | crate::SampleLevel::Bias(ref mut expr) => {
396                    adjust(expr);
397                }
398                crate::SampleLevel::Gradient {
399                    ref mut x,
400                    ref mut y,
401                } => {
402                    adjust(x);
403                    adjust(y);
404                }
405                _ => {}
406            }
407            if let Some(e) = depth_ref.as_mut() {
408                adjust(e);
409            }
410        }
411        Expression::ImageLoad {
412            ref mut image,
413            ref mut coordinate,
414            ref mut array_index,
415            ref mut sample,
416            ref mut level,
417        } => {
418            adjust(image);
419            adjust(coordinate);
420            if let Some(e) = array_index.as_mut() {
421                adjust(e);
422            }
423            if let Some(e) = sample.as_mut() {
424                adjust(e);
425            }
426            if let Some(e) = level.as_mut() {
427                adjust(e);
428            }
429        }
430        Expression::ImageQuery {
431            ref mut image,
432            ref mut query,
433        } => {
434            adjust(image);
435            match *query {
436                crate::ImageQuery::Size { ref mut level } => {
437                    if let Some(e) = level.as_mut() {
438                        adjust(e);
439                    }
440                }
441                crate::ImageQuery::NumLevels
442                | crate::ImageQuery::NumLayers
443                | crate::ImageQuery::NumSamples => {}
444            }
445        }
446        Expression::Unary {
447            ref mut expr,
448            op: _,
449        } => {
450            adjust(expr);
451        }
452        Expression::Binary {
453            ref mut left,
454            ref mut right,
455            op: _,
456        } => {
457            adjust(left);
458            adjust(right);
459        }
460        Expression::Select {
461            ref mut condition,
462            ref mut accept,
463            ref mut reject,
464        } => {
465            adjust(condition);
466            adjust(accept);
467            adjust(reject);
468        }
469        Expression::Derivative {
470            ref mut expr,
471            axis: _,
472            ctrl: _,
473        } => {
474            adjust(expr);
475        }
476        Expression::Relational {
477            ref mut argument,
478            fun: _,
479        } => {
480            adjust(argument);
481        }
482        Expression::Math {
483            ref mut arg,
484            ref mut arg1,
485            ref mut arg2,
486            ref mut arg3,
487            fun: _,
488        } => {
489            adjust(arg);
490            if let Some(e) = arg1.as_mut() {
491                adjust(e);
492            }
493            if let Some(e) = arg2.as_mut() {
494                adjust(e);
495            }
496            if let Some(e) = arg3.as_mut() {
497                adjust(e);
498            }
499        }
500        Expression::As {
501            ref mut expr,
502            kind: _,
503            convert: _,
504        } => {
505            adjust(expr);
506        }
507        Expression::ArrayLength(ref mut expr) => {
508            adjust(expr);
509        }
510        Expression::RayQueryGetIntersection {
511            ref mut query,
512            committed: _,
513        } => {
514            adjust(query);
515        }
516        Expression::Literal(_)
517        | Expression::FunctionArgument(_)
518        | Expression::GlobalVariable(_)
519        | Expression::LocalVariable(_)
520        | Expression::CallResult(_)
521        | Expression::RayQueryProceedResult
522        | Expression::Constant(_)
523        | Expression::Override(_)
524        | Expression::ZeroValue(_)
525        | Expression::AtomicResult {
526            ty: _,
527            comparison: _,
528        }
529        | Expression::WorkGroupUniformLoadResult { ty: _ }
530        | Expression::SubgroupBallotResult
531        | Expression::SubgroupOperationResult { .. } => {}
532    }
533}
534
535/// Replace every expression handle in `block` with its counterpart
536/// given by `new_pos`.
537fn adjust_block(new_pos: &HandleVec<Expression, Handle<Expression>>, block: &mut Block) {
538    for stmt in block.iter_mut() {
539        adjust_stmt(new_pos, stmt);
540    }
541}
542
543/// Replace every expression handle in `stmt` with its counterpart
544/// given by `new_pos`.
545fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut Statement) {
546    let adjust = |expr: &mut Handle<Expression>| {
547        *expr = new_pos[*expr];
548    };
549    match *stmt {
550        Statement::Emit(ref mut range) => {
551            if let Some((mut first, mut last)) = range.first_and_last() {
552                adjust(&mut first);
553                adjust(&mut last);
554                *range = Range::new_from_bounds(first, last);
555            }
556        }
557        Statement::Block(ref mut block) => {
558            adjust_block(new_pos, block);
559        }
560        Statement::If {
561            ref mut condition,
562            ref mut accept,
563            ref mut reject,
564        } => {
565            adjust(condition);
566            adjust_block(new_pos, accept);
567            adjust_block(new_pos, reject);
568        }
569        Statement::Switch {
570            ref mut selector,
571            ref mut cases,
572        } => {
573            adjust(selector);
574            for case in cases.iter_mut() {
575                adjust_block(new_pos, &mut case.body);
576            }
577        }
578        Statement::Loop {
579            ref mut body,
580            ref mut continuing,
581            ref mut break_if,
582        } => {
583            adjust_block(new_pos, body);
584            adjust_block(new_pos, continuing);
585            if let Some(e) = break_if.as_mut() {
586                adjust(e);
587            }
588        }
589        Statement::Return { ref mut value } => {
590            if let Some(e) = value.as_mut() {
591                adjust(e);
592            }
593        }
594        Statement::Store {
595            ref mut pointer,
596            ref mut value,
597        } => {
598            adjust(pointer);
599            adjust(value);
600        }
601        Statement::ImageStore {
602            ref mut image,
603            ref mut coordinate,
604            ref mut array_index,
605            ref mut value,
606        } => {
607            adjust(image);
608            adjust(coordinate);
609            if let Some(e) = array_index.as_mut() {
610                adjust(e);
611            }
612            adjust(value);
613        }
614        Statement::Atomic {
615            ref mut pointer,
616            ref mut value,
617            ref mut result,
618            ref mut fun,
619        } => {
620            adjust(pointer);
621            adjust(value);
622            if let Some(ref mut result) = *result {
623                adjust(result);
624            }
625            match *fun {
626                crate::AtomicFunction::Exchange {
627                    compare: Some(ref mut compare),
628                } => {
629                    adjust(compare);
630                }
631                crate::AtomicFunction::Add
632                | crate::AtomicFunction::Subtract
633                | crate::AtomicFunction::And
634                | crate::AtomicFunction::ExclusiveOr
635                | crate::AtomicFunction::InclusiveOr
636                | crate::AtomicFunction::Min
637                | crate::AtomicFunction::Max
638                | crate::AtomicFunction::Exchange { compare: None } => {}
639            }
640        }
641        Statement::WorkGroupUniformLoad {
642            ref mut pointer,
643            ref mut result,
644        } => {
645            adjust(pointer);
646            adjust(result);
647        }
648        Statement::SubgroupBallot {
649            ref mut result,
650            ref mut predicate,
651        } => {
652            if let Some(ref mut predicate) = *predicate {
653                adjust(predicate);
654            }
655            adjust(result);
656        }
657        Statement::SubgroupCollectiveOperation {
658            ref mut argument,
659            ref mut result,
660            ..
661        } => {
662            adjust(argument);
663            adjust(result);
664        }
665        Statement::SubgroupGather {
666            ref mut mode,
667            ref mut argument,
668            ref mut result,
669        } => {
670            match *mode {
671                crate::GatherMode::BroadcastFirst => {}
672                crate::GatherMode::Broadcast(ref mut index)
673                | crate::GatherMode::Shuffle(ref mut index)
674                | crate::GatherMode::ShuffleDown(ref mut index)
675                | crate::GatherMode::ShuffleUp(ref mut index)
676                | crate::GatherMode::ShuffleXor(ref mut index) => {
677                    adjust(index);
678                }
679            }
680            adjust(argument);
681            adjust(result)
682        }
683        Statement::Call {
684            ref mut arguments,
685            ref mut result,
686            function: _,
687        } => {
688            for argument in arguments.iter_mut() {
689                adjust(argument);
690            }
691            if let Some(e) = result.as_mut() {
692                adjust(e);
693            }
694        }
695        Statement::RayQuery {
696            ref mut query,
697            ref mut fun,
698        } => {
699            adjust(query);
700            match *fun {
701                crate::RayQueryFunction::Initialize {
702                    ref mut acceleration_structure,
703                    ref mut descriptor,
704                } => {
705                    adjust(acceleration_structure);
706                    adjust(descriptor);
707                }
708                crate::RayQueryFunction::Proceed { ref mut result } => {
709                    adjust(result);
710                }
711                crate::RayQueryFunction::Terminate => {}
712            }
713        }
714        Statement::Break | Statement::Continue | Statement::Kill | Statement::Barrier(_) => {}
715    }
716}
717
718/// Adjust [`Emit`] statements in `block` to skip [`needs_pre_emit`] expressions we have introduced.
719///
720/// According to validation, [`Emit`] statements must not cover any expressions
721/// for which [`Expression::needs_pre_emit`] returns true. All expressions built
722/// by successful constant evaluation fall into that category, meaning that
723/// `process_function` will usually rewrite [`Override`] expressions and those
724/// that use their values into pre-emitted expressions, leaving any [`Emit`]
725/// statements that cover them invalid.
726///
727/// This function rewrites all [`Emit`] statements into zero or more new
728/// [`Emit`] statements covering only those expressions in the original range
729/// that are not pre-emitted.
730///
731/// [`Emit`]: Statement::Emit
732/// [`needs_pre_emit`]: Expression::needs_pre_emit
733/// [`Override`]: Expression::Override
734fn filter_emits_in_block(block: &mut Block, expressions: &Arena<Expression>) {
735    let original = mem::replace(block, Block::with_capacity(block.len()));
736    for (stmt, span) in original.span_into_iter() {
737        match stmt {
738            Statement::Emit(range) => {
739                let mut current = None;
740                for expr_h in range {
741                    if expressions[expr_h].needs_pre_emit() {
742                        if let Some((first, last)) = current {
743                            block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
744                        }
745
746                        current = None;
747                    } else if let Some((_, ref mut last)) = current {
748                        *last = expr_h;
749                    } else {
750                        current = Some((expr_h, expr_h));
751                    }
752                }
753                if let Some((first, last)) = current {
754                    block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
755                }
756            }
757            Statement::Block(mut child) => {
758                filter_emits_in_block(&mut child, expressions);
759                block.push(Statement::Block(child), span);
760            }
761            Statement::If {
762                condition,
763                mut accept,
764                mut reject,
765            } => {
766                filter_emits_in_block(&mut accept, expressions);
767                filter_emits_in_block(&mut reject, expressions);
768                block.push(
769                    Statement::If {
770                        condition,
771                        accept,
772                        reject,
773                    },
774                    span,
775                );
776            }
777            Statement::Switch {
778                selector,
779                mut cases,
780            } => {
781                for case in &mut cases {
782                    filter_emits_in_block(&mut case.body, expressions);
783                }
784                block.push(Statement::Switch { selector, cases }, span);
785            }
786            Statement::Loop {
787                mut body,
788                mut continuing,
789                break_if,
790            } => {
791                filter_emits_in_block(&mut body, expressions);
792                filter_emits_in_block(&mut continuing, expressions);
793                block.push(
794                    Statement::Loop {
795                        body,
796                        continuing,
797                        break_if,
798                    },
799                    span,
800                );
801            }
802            stmt => block.push(stmt.clone(), span),
803        }
804    }
805}
806
807fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
808    // note that in rust 0.0 == -0.0
809    match scalar {
810        Scalar::BOOL => {
811            // https://webidl.spec.whatwg.org/#js-boolean
812            let value = value != 0.0 && !value.is_nan();
813            Ok(Literal::Bool(value))
814        }
815        Scalar::I32 => {
816            // https://webidl.spec.whatwg.org/#js-long
817            if !value.is_finite() {
818                return Err(PipelineConstantError::SrcNeedsToBeFinite);
819            }
820
821            let value = value.trunc();
822            if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
823                return Err(PipelineConstantError::DstRangeTooSmall);
824            }
825
826            let value = value as i32;
827            Ok(Literal::I32(value))
828        }
829        Scalar::U32 => {
830            // https://webidl.spec.whatwg.org/#js-unsigned-long
831            if !value.is_finite() {
832                return Err(PipelineConstantError::SrcNeedsToBeFinite);
833            }
834
835            let value = value.trunc();
836            if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) {
837                return Err(PipelineConstantError::DstRangeTooSmall);
838            }
839
840            let value = value as u32;
841            Ok(Literal::U32(value))
842        }
843        Scalar::F32 => {
844            // https://webidl.spec.whatwg.org/#js-float
845            if !value.is_finite() {
846                return Err(PipelineConstantError::SrcNeedsToBeFinite);
847            }
848
849            let value = value as f32;
850            if !value.is_finite() {
851                return Err(PipelineConstantError::DstRangeTooSmall);
852            }
853
854            Ok(Literal::F32(value))
855        }
856        Scalar::F64 => {
857            // https://webidl.spec.whatwg.org/#js-double
858            if !value.is_finite() {
859                return Err(PipelineConstantError::SrcNeedsToBeFinite);
860            }
861
862            Ok(Literal::F64(value))
863        }
864        _ => unreachable!(),
865    }
866}
867
/// Check `map_value_to_literal` against the WebIDL conversion rules it implements,
/// including truncation of fractional inputs for integer destinations.
#[test]
fn test_map_value_to_literal() {
    let bool_test_cases = [
        (0.0, false),
        (-0.0, false),
        (f64::NAN, false),
        (0.5, true),
        (-1.0, true),
        (1.0, true),
        (f64::INFINITY, true),
        (f64::NEG_INFINITY, true),
    ];
    for (value, out) in bool_test_cases {
        let res = Ok(Literal::Bool(out));
        assert_eq!(map_value_to_literal(value, Scalar::BOOL), res);
    }

    for scalar in [Scalar::I32, Scalar::U32, Scalar::F32, Scalar::F64] {
        for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let res = Err(PipelineConstantError::SrcNeedsToBeFinite);
            assert_eq!(map_value_to_literal(value, scalar), res);
        }
    }

    // i32
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN), Scalar::I32),
        Ok(Literal::I32(i32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX), Scalar::I32),
        Ok(Literal::I32(i32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    // Fractional values are truncated toward zero before the range check,
    // so values just past an integer bound still fit.
    assert_eq!(map_value_to_literal(1.9, Scalar::I32), Ok(Literal::I32(1)));
    assert_eq!(map_value_to_literal(-1.9, Scalar::I32), Ok(Literal::I32(-1)));
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX) + 0.5, Scalar::I32),
        Ok(Literal::I32(i32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN) - 0.5, Scalar::I32),
        Ok(Literal::I32(i32::MIN))
    );

    // u32
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN), Scalar::U32),
        Ok(Literal::U32(u32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX), Scalar::U32),
        Ok(Literal::U32(u32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    // Truncation toward zero: small negative fractions become 0.
    assert_eq!(map_value_to_literal(-0.5, Scalar::U32), Ok(Literal::U32(0)));
    assert_eq!(map_value_to_literal(1.9, Scalar::U32), Ok(Literal::U32(1)));
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX) + 0.5, Scalar::U32),
        Ok(Literal::U32(u32::MAX))
    );

    // f32
    assert_eq!(map_value_to_literal(1.5, Scalar::F32), Ok(Literal::F32(1.5)));
    assert_eq!(
        map_value_to_literal(f64::from(f32::MIN), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(f32::MAX), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f64
    assert_eq!(
        map_value_to_literal(f64::MIN, Scalar::F64),
        Ok(Literal::F64(f64::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::MAX, Scalar::F64),
        Ok(Literal::F64(f64::MAX))
    );
}