1use super::PipelineConstants;
2use crate::{
3 arena::HandleVec,
4 proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
5 valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
6 Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
7 Span, Statement, TypeInner, WithSpan,
8};
9use std::{borrow::Cow, collections::HashSet, mem};
10use thiserror::Error;
11
12#[derive(Error, Debug, Clone)]
13#[cfg_attr(test, derive(PartialEq))]
14pub enum PipelineConstantError {
15 #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
16 MissingValue(String),
17 #[error(
18 "Source f64 value needs to be finite ({}) for number destinations",
19 "NaNs and Inifinites are not allowed"
20 )]
21 SrcNeedsToBeFinite,
22 #[error("Source f64 value doesn't fit in destination")]
23 DstRangeTooSmall,
24 #[error(transparent)]
25 ConstantEvaluatorError(#[from] ConstantEvaluatorError),
26 #[error(transparent)]
27 ValidationError(#[from] WithSpan<ValidationError>),
28 #[error("workgroup_size override isn't strictly positive")]
29 NegativeWorkgroupSize,
30}
31
/// Replace all overrides in `module` with constants.
///
/// If `module` has no overrides, this returns `Cow::Borrowed` references to
/// `module` and `module_info` unchanged. Otherwise it clones `module` and does
/// the following:
///
/// - It rebuilds the `global_expressions` arena by running every expression
///   through a [`ConstantEvaluator`]. Each [`Expression::Override`] is turned
///   into an [`Expression::Constant`]. That constant is created by
///   `process_override` from the value in `pipeline_constants` (keyed by the
///   override's id, or else its name), or from the override's initializer.
/// - It remaps constant and global-variable initializers to the rebuilt arena.
/// - It rewrites every function and entry point (including `workgroup_size`
///   overrides).
/// - It resolves pending array sizes.
/// - It re-validates the result with `validate_no_overrides`.
///
/// The rebuilt module and its new `ModuleInfo` are returned as `Cow::Owned`.
///
/// # Errors
///
/// - [`PipelineConstantError::MissingValue`] if an override has neither a
///   pipeline constant nor an initializer.
/// - Value-conversion errors from `map_value_to_literal`.
/// - Constant-evaluation errors.
/// - Validation errors.
pub fn process_overrides<'a>(
    module: &'a Module,
    module_info: &'a ModuleInfo,
    pipeline_constants: &PipelineConstants,
) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> {
    // Nothing to substitute: hand back the inputs without cloning.
    if module.overrides.is_empty() {
        return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
    }

    let mut module = module.clone();

    // Maps each original override handle to the handle of the `Constant`
    // that replaces it.
    let mut override_map = HandleVec::with_capacity(module.overrides.len());

    // Maps each original global expression handle to its handle in the
    // rebuilt (evaluated) global expression arena.
    let mut adjusted_global_expressions = HandleVec::with_capacity(module.global_expressions.len());

    // Constants whose `init` handle already points into the rebuilt arena.
    // Some constants are fixed up while the arena is being rebuilt (including
    // every constant created from an override). The final fixup pass below
    // skips the ones recorded here so that no `init` is remapped twice.
    let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len());

    let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();

    // Iterator over the original overrides, consumed in step with the global
    // expressions below.
    let mut override_iter = module.overrides.drain();

    // Rebuild the global expression arena in order, evaluating each expression.
    // When an `Expression::Override` is met whose constant does not exist yet,
    // pull overrides from `override_iter` (in arena order) up to and including
    // the referenced one, converting each into a `Constant`.
    //
    // NOTE(review): `new_h.unwrap()` assumes the referenced override is still
    // in `override_iter`. `process_override` also indexes
    // `adjusted_global_expressions` with the override's `init`. Both
    // presumably rely on arena ordering guaranteed upstream (an override's
    // init precedes its uses) — confirm.
    for (old_h, expr, span) in module.global_expressions.drain() {
        let mut expr = match expr {
            Expression::Override(h) => {
                let c_h = if let Some(new_h) = override_map.get(h) {
                    *new_h
                } else {
                    let mut new_h = None;
                    for entry in override_iter.by_ref() {
                        let stop = entry.0 == h;
                        new_h = Some(process_override(
                            entry,
                            pipeline_constants,
                            &mut module,
                            &mut override_map,
                            &adjusted_global_expressions,
                            &mut adjusted_constant_initializers,
                            &mut global_expression_kind_tracker,
                        )?);
                        if stop {
                            break;
                        }
                    }
                    new_h.unwrap()
                };
                Expression::Constant(c_h)
            }
            Expression::Constant(c_h) => {
                // On the first reference to this constant, remap its
                // initializer into the rebuilt arena now, rather than in the
                // final fixup pass. Presumably this is so the evaluator sees a
                // valid `init` when it looks through the constant.
                if adjusted_constant_initializers.insert(c_h) {
                    let init = &mut module.constants[c_h].init;
                    *init = adjusted_global_expressions[*init];
                }
                expr
            }
            expr => expr,
        };
        let mut evaluator = ConstantEvaluator::for_wgsl_module(
            &mut module,
            &mut global_expression_kind_tracker,
            false,
        );
        // Operands refer to old handles; point them at the rebuilt arena
        // before evaluating.
        adjust_expr(&adjusted_global_expressions, &mut expr);
        let h = evaluator.try_eval_and_append(expr, span)?;
        adjusted_global_expressions.insert(old_h, h);
    }

    // Convert the overrides that no global expression referred to.
    for entry in override_iter {
        process_override(
            entry,
            pipeline_constants,
            &mut module,
            &mut override_map,
            &adjusted_global_expressions,
            &mut adjusted_constant_initializers,
            &mut global_expression_kind_tracker,
        )?;
    }

    // Remap the initializers of every `Constant` not already handled above.
    for (_, c) in module
        .constants
        .iter_mut()
        .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h))
    {
        c.init = adjusted_global_expressions[c.init];
    }

    // Remap global variable initializers.
    for (_, v) in module.global_variables.iter_mut() {
        if let Some(ref mut init) = v.init {
            *init = adjusted_global_expressions[*init];
        }
    }

    // Functions are moved out temporarily so that `process_function` can
    // borrow `module` mutably.
    let mut functions = mem::take(&mut module.functions);
    for (_, function) in functions.iter_mut() {
        process_function(&mut module, &override_map, function)?;
    }
    module.functions = functions;

    // Entry points likewise.
    let mut entry_points = mem::take(&mut module.entry_points);
    for ep in entry_points.iter_mut() {
        process_function(&mut module, &override_map, &mut ep.function)?;
        process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?;
    }
    module.entry_points = entry_points;

    process_pending(&mut module, &override_map, &adjusted_global_expressions)?;

    // Every expression has been rewritten, so recompute the module info with
    // a full re-validation.
    let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
    let module_info = validator.validate_no_overrides(&module)?;

    Ok((Cow::Owned(module), Cow::Owned(module_info)))
}
209
210fn process_pending(
211 module: &mut Module,
212 override_map: &HandleVec<Override, Handle<Constant>>,
213 adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
214) -> Result<(), PipelineConstantError> {
215 for (handle, ty) in module.types.clone().iter() {
216 if let TypeInner::Array {
217 base,
218 size: crate::ArraySize::Pending(size),
219 stride,
220 } = ty.inner
221 {
222 let expr = match size {
223 crate::PendingArraySize::Expression(size_expr) => {
224 adjusted_global_expressions[size_expr]
225 }
226 crate::PendingArraySize::Override(size_override) => {
227 module.constants[override_map[size_override]].init
228 }
229 };
230 let value = module
231 .to_ctx()
232 .eval_expr_to_u32(expr)
233 .map(|n| {
234 if n == 0 {
235 Err(PipelineConstantError::ValidationError(
236 WithSpan::new(ValidationError::ArraySizeError { handle: expr })
237 .with_span(
238 module.global_expressions.get_span(expr),
239 "evaluated to zero",
240 ),
241 ))
242 } else {
243 Ok(std::num::NonZeroU32::new(n).unwrap())
244 }
245 })
246 .map_err(|_| {
247 PipelineConstantError::ValidationError(
248 WithSpan::new(ValidationError::ArraySizeError { handle: expr })
249 .with_span(module.global_expressions.get_span(expr), "negative"),
250 )
251 })??;
252 module.types.replace(
253 handle,
254 crate::Type {
255 name: None,
256 inner: TypeInner::Array {
257 base,
258 size: crate::ArraySize::Constant(value),
259 stride,
260 },
261 },
262 );
263 }
264 }
265 Ok(())
266}
267
268fn process_workgroup_size_override(
269 module: &mut Module,
270 adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
271 ep: &mut crate::EntryPoint,
272) -> Result<(), PipelineConstantError> {
273 match ep.workgroup_size_overrides {
274 None => {}
275 Some(overrides) => {
276 overrides.iter().enumerate().try_for_each(
277 |(i, overridden)| -> Result<(), PipelineConstantError> {
278 match *overridden {
279 None => Ok(()),
280 Some(h) => {
281 ep.workgroup_size[i] = module
282 .to_ctx()
283 .eval_expr_to_u32(adjusted_global_expressions[h])
284 .map(|n| {
285 if n == 0 {
286 Err(PipelineConstantError::NegativeWorkgroupSize)
287 } else {
288 Ok(n)
289 }
290 })
291 .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)??;
292 Ok(())
293 }
294 }
295 },
296 )?;
297 ep.workgroup_size_overrides = None;
298 }
299 }
300 Ok(())
301}
302
303fn process_override(
307 (old_h, override_, span): (Handle<Override>, Override, Span),
308 pipeline_constants: &PipelineConstants,
309 module: &mut Module,
310 override_map: &mut HandleVec<Override, Handle<Constant>>,
311 adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
312 adjusted_constant_initializers: &mut HashSet<Handle<Constant>>,
313 global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker,
314) -> Result<Handle<Constant>, PipelineConstantError> {
315 let key = if let Some(id) = override_.id {
317 Cow::Owned(id.to_string())
318 } else if let Some(ref name) = override_.name {
319 Cow::Borrowed(name)
320 } else {
321 unreachable!();
322 };
323
324 let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
328 let literal = match module.types[override_.ty].inner {
329 TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
330 _ => unreachable!(),
331 };
332 let expr = module
333 .global_expressions
334 .append(Expression::Literal(literal), Span::UNDEFINED);
335 global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const);
336 expr
337 } else if let Some(init) = override_.init {
338 adjusted_global_expressions[init]
339 } else {
340 return Err(PipelineConstantError::MissingValue(key.to_string()));
341 };
342
343 let constant = Constant {
345 name: override_.name,
346 ty: override_.ty,
347 init,
348 };
349 let h = module.constants.append(constant, span);
350 override_map.insert(old_h, h);
351 adjusted_constant_initializers.insert(h);
352 Ok(h)
353}
354
355fn process_function(
365 module: &mut Module,
366 override_map: &HandleVec<Override, Handle<Constant>>,
367 function: &mut Function,
368) -> Result<(), ConstantEvaluatorError> {
369 let mut adjusted_local_expressions = HandleVec::with_capacity(function.expressions.len());
372
373 let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
374
375 let mut expressions = mem::take(&mut function.expressions);
376
377 let mut emitter = Emitter::default();
385 let mut block = Block::new();
386
387 let mut evaluator = ConstantEvaluator::for_wgsl_function(
388 module,
389 &mut function.expressions,
390 &mut local_expression_kind_tracker,
391 &mut emitter,
392 &mut block,
393 false,
394 );
395
396 for (old_h, mut expr, span) in expressions.drain() {
397 if let Expression::Override(h) = expr {
398 expr = Expression::Constant(override_map[h]);
399 }
400 adjust_expr(&adjusted_local_expressions, &mut expr);
401 let h = evaluator.try_eval_and_append(expr, span)?;
402 adjusted_local_expressions.insert(old_h, h);
403 }
404
405 adjust_block(&adjusted_local_expressions, &mut function.body);
406
407 filter_emits_in_block(&mut function.body, &function.expressions);
408
409 for (_, local) in function.local_variables.iter_mut() {
411 if let &mut Some(ref mut init) = &mut local.init {
412 *init = adjusted_local_expressions[*init];
413 }
414 }
415
416 let named_expressions = mem::take(&mut function.named_expressions);
419 for (expr_h, name) in named_expressions {
420 function
421 .named_expressions
422 .insert(adjusted_local_expressions[expr_h], name);
423 }
424
425 Ok(())
426}
427
/// Replace every operand expression handle in `expr` with its counterpart in
/// `new_pos`.
///
/// Optional operands are remapped only when present. Non-expression handles
/// (types, constants, functions, globals, locals) and variants with no
/// expression operands are left untouched.
///
/// NOTE(review): indexing `new_pos` presumably panics for an unmapped handle.
/// Both callers in this file remap each expression before inserting its own
/// new handle, while walking the arena in order, so operands are expected to
/// be mapped already.
fn adjust_expr(new_pos: &HandleVec<Expression, Handle<Expression>>, expr: &mut Expression) {
    let adjust = |expr: &mut Handle<Expression>| {
        *expr = new_pos[*expr];
    };
    match *expr {
        Expression::Compose {
            ref mut components,
            ty: _,
        } => {
            for c in components.iter_mut() {
                adjust(c);
            }
        }
        Expression::Access {
            ref mut base,
            ref mut index,
        } => {
            adjust(base);
            adjust(index);
        }
        Expression::AccessIndex {
            ref mut base,
            index: _,
        } => {
            adjust(base);
        }
        Expression::Splat {
            ref mut value,
            size: _,
        } => {
            adjust(value);
        }
        Expression::Swizzle {
            ref mut vector,
            size: _,
            pattern: _,
        } => {
            adjust(vector);
        }
        Expression::Load { ref mut pointer } => {
            adjust(pointer);
        }
        Expression::ImageSample {
            ref mut image,
            ref mut sampler,
            ref mut coordinate,
            ref mut array_index,
            ref mut offset,
            ref mut level,
            ref mut depth_ref,
            gather: _,
        } => {
            adjust(image);
            adjust(sampler);
            adjust(coordinate);
            if let Some(e) = array_index.as_mut() {
                adjust(e);
            }
            if let Some(e) = offset.as_mut() {
                adjust(e);
            }
            match *level {
                crate::SampleLevel::Exact(ref mut expr)
                | crate::SampleLevel::Bias(ref mut expr) => {
                    adjust(expr);
                }
                crate::SampleLevel::Gradient {
                    ref mut x,
                    ref mut y,
                } => {
                    adjust(x);
                    adjust(y);
                }
                // Remaining levels carry no expression handles.
                _ => {}
            }
            if let Some(e) = depth_ref.as_mut() {
                adjust(e);
            }
        }
        Expression::ImageLoad {
            ref mut image,
            ref mut coordinate,
            ref mut array_index,
            ref mut sample,
            ref mut level,
        } => {
            adjust(image);
            adjust(coordinate);
            if let Some(e) = array_index.as_mut() {
                adjust(e);
            }
            if let Some(e) = sample.as_mut() {
                adjust(e);
            }
            if let Some(e) = level.as_mut() {
                adjust(e);
            }
        }
        Expression::ImageQuery {
            ref mut image,
            ref mut query,
        } => {
            adjust(image);
            match *query {
                crate::ImageQuery::Size { ref mut level } => {
                    if let Some(e) = level.as_mut() {
                        adjust(e);
                    }
                }
                crate::ImageQuery::NumLevels
                | crate::ImageQuery::NumLayers
                | crate::ImageQuery::NumSamples => {}
            }
        }
        Expression::Unary {
            ref mut expr,
            op: _,
        } => {
            adjust(expr);
        }
        Expression::Binary {
            ref mut left,
            ref mut right,
            op: _,
        } => {
            adjust(left);
            adjust(right);
        }
        Expression::Select {
            ref mut condition,
            ref mut accept,
            ref mut reject,
        } => {
            adjust(condition);
            adjust(accept);
            adjust(reject);
        }
        Expression::Derivative {
            ref mut expr,
            axis: _,
            ctrl: _,
        } => {
            adjust(expr);
        }
        Expression::Relational {
            ref mut argument,
            fun: _,
        } => {
            adjust(argument);
        }
        Expression::Math {
            ref mut arg,
            ref mut arg1,
            ref mut arg2,
            ref mut arg3,
            fun: _,
        } => {
            adjust(arg);
            if let Some(e) = arg1.as_mut() {
                adjust(e);
            }
            if let Some(e) = arg2.as_mut() {
                adjust(e);
            }
            if let Some(e) = arg3.as_mut() {
                adjust(e);
            }
        }
        Expression::As {
            ref mut expr,
            kind: _,
            convert: _,
        } => {
            adjust(expr);
        }
        Expression::ArrayLength(ref mut expr) => {
            adjust(expr);
        }
        Expression::RayQueryGetIntersection {
            ref mut query,
            committed: _,
        } => {
            adjust(query);
        }
        // These variants hold no expression handles. They are listed
        // explicitly (no catch-all) so a new variant forces this match to be
        // revisited.
        Expression::Literal(_)
        | Expression::FunctionArgument(_)
        | Expression::GlobalVariable(_)
        | Expression::LocalVariable(_)
        | Expression::CallResult(_)
        | Expression::RayQueryProceedResult
        | Expression::Constant(_)
        | Expression::Override(_)
        | Expression::ZeroValue(_)
        | Expression::AtomicResult {
            ty: _,
            comparison: _,
        }
        | Expression::WorkGroupUniformLoadResult { ty: _ }
        | Expression::SubgroupBallotResult
        | Expression::SubgroupOperationResult { .. } => {}
    }
}
632
633fn adjust_block(new_pos: &HandleVec<Expression, Handle<Expression>>, block: &mut Block) {
636 for stmt in block.iter_mut() {
637 adjust_stmt(new_pos, stmt);
638 }
639}
640
/// Replace every expression handle in `stmt` with its counterpart in
/// `new_pos`, recursing into nested blocks via `adjust_block`.
///
/// For `Emit`, only the first and last handles of the range are remapped.
/// The rebuilt range covers everything between the new endpoints. That may
/// include expressions appended by the constant evaluator, which
/// `process_function` subsequently passes through `filter_emits_in_block`.
fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut Statement) {
    let adjust = |expr: &mut Handle<Expression>| {
        *expr = new_pos[*expr];
    };
    match *stmt {
        Statement::Emit(ref mut range) => {
            if let Some((mut first, mut last)) = range.first_and_last() {
                adjust(&mut first);
                adjust(&mut last);
                *range = Range::new_from_bounds(first, last);
            }
        }
        Statement::Block(ref mut block) => {
            adjust_block(new_pos, block);
        }
        Statement::If {
            ref mut condition,
            ref mut accept,
            ref mut reject,
        } => {
            adjust(condition);
            adjust_block(new_pos, accept);
            adjust_block(new_pos, reject);
        }
        Statement::Switch {
            ref mut selector,
            ref mut cases,
        } => {
            adjust(selector);
            for case in cases.iter_mut() {
                adjust_block(new_pos, &mut case.body);
            }
        }
        Statement::Loop {
            ref mut body,
            ref mut continuing,
            ref mut break_if,
        } => {
            adjust_block(new_pos, body);
            adjust_block(new_pos, continuing);
            if let Some(e) = break_if.as_mut() {
                adjust(e);
            }
        }
        Statement::Return { ref mut value } => {
            if let Some(e) = value.as_mut() {
                adjust(e);
            }
        }
        Statement::Store {
            ref mut pointer,
            ref mut value,
        } => {
            adjust(pointer);
            adjust(value);
        }
        Statement::ImageStore {
            ref mut image,
            ref mut coordinate,
            ref mut array_index,
            ref mut value,
        } => {
            adjust(image);
            adjust(coordinate);
            if let Some(e) = array_index.as_mut() {
                adjust(e);
            }
            adjust(value);
        }
        Statement::Atomic {
            ref mut pointer,
            ref mut value,
            ref mut result,
            ref mut fun,
        } => {
            adjust(pointer);
            adjust(value);
            if let Some(ref mut result) = *result {
                adjust(result);
            }
            // Only compare-exchange carries an extra expression operand.
            match *fun {
                crate::AtomicFunction::Exchange {
                    compare: Some(ref mut compare),
                } => {
                    adjust(compare);
                }
                crate::AtomicFunction::Add
                | crate::AtomicFunction::Subtract
                | crate::AtomicFunction::And
                | crate::AtomicFunction::ExclusiveOr
                | crate::AtomicFunction::InclusiveOr
                | crate::AtomicFunction::Min
                | crate::AtomicFunction::Max
                | crate::AtomicFunction::Exchange { compare: None } => {}
            }
        }
        Statement::ImageAtomic {
            ref mut image,
            ref mut coordinate,
            ref mut array_index,
            fun: _,
            ref mut value,
        } => {
            adjust(image);
            adjust(coordinate);
            if let Some(ref mut array_index) = *array_index {
                adjust(array_index);
            }
            adjust(value);
        }
        Statement::WorkGroupUniformLoad {
            ref mut pointer,
            ref mut result,
        } => {
            adjust(pointer);
            adjust(result);
        }
        Statement::SubgroupBallot {
            ref mut result,
            ref mut predicate,
        } => {
            if let Some(ref mut predicate) = *predicate {
                adjust(predicate);
            }
            adjust(result);
        }
        Statement::SubgroupCollectiveOperation {
            ref mut argument,
            ref mut result,
            ..
        } => {
            adjust(argument);
            adjust(result);
        }
        Statement::SubgroupGather {
            ref mut mode,
            ref mut argument,
            ref mut result,
        } => {
            match *mode {
                crate::GatherMode::BroadcastFirst => {}
                crate::GatherMode::Broadcast(ref mut index)
                | crate::GatherMode::Shuffle(ref mut index)
                | crate::GatherMode::ShuffleDown(ref mut index)
                | crate::GatherMode::ShuffleUp(ref mut index)
                | crate::GatherMode::ShuffleXor(ref mut index) => {
                    adjust(index);
                }
            }
            adjust(argument);
            adjust(result)
        }
        Statement::Call {
            ref mut arguments,
            ref mut result,
            function: _,
        } => {
            for argument in arguments.iter_mut() {
                adjust(argument);
            }
            if let Some(e) = result.as_mut() {
                adjust(e);
            }
        }
        Statement::RayQuery {
            ref mut query,
            ref mut fun,
        } => {
            adjust(query);
            match *fun {
                crate::RayQueryFunction::Initialize {
                    ref mut acceleration_structure,
                    ref mut descriptor,
                } => {
                    adjust(acceleration_structure);
                    adjust(descriptor);
                }
                crate::RayQueryFunction::Proceed { ref mut result } => {
                    adjust(result);
                }
                crate::RayQueryFunction::Terminate => {}
            }
        }
        // No expression handles to remap.
        Statement::Break | Statement::Continue | Statement::Kill | Statement::Barrier(_) => {}
    }
}
829
830fn filter_emits_in_block(block: &mut Block, expressions: &Arena<Expression>) {
847 let original = mem::replace(block, Block::with_capacity(block.len()));
848 for (stmt, span) in original.span_into_iter() {
849 match stmt {
850 Statement::Emit(range) => {
851 let mut current = None;
852 for expr_h in range {
853 if expressions[expr_h].needs_pre_emit() {
854 if let Some((first, last)) = current {
855 block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
856 }
857
858 current = None;
859 } else if let Some((_, ref mut last)) = current {
860 *last = expr_h;
861 } else {
862 current = Some((expr_h, expr_h));
863 }
864 }
865 if let Some((first, last)) = current {
866 block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
867 }
868 }
869 Statement::Block(mut child) => {
870 filter_emits_in_block(&mut child, expressions);
871 block.push(Statement::Block(child), span);
872 }
873 Statement::If {
874 condition,
875 mut accept,
876 mut reject,
877 } => {
878 filter_emits_in_block(&mut accept, expressions);
879 filter_emits_in_block(&mut reject, expressions);
880 block.push(
881 Statement::If {
882 condition,
883 accept,
884 reject,
885 },
886 span,
887 );
888 }
889 Statement::Switch {
890 selector,
891 mut cases,
892 } => {
893 for case in &mut cases {
894 filter_emits_in_block(&mut case.body, expressions);
895 }
896 block.push(Statement::Switch { selector, cases }, span);
897 }
898 Statement::Loop {
899 mut body,
900 mut continuing,
901 break_if,
902 } => {
903 filter_emits_in_block(&mut body, expressions);
904 filter_emits_in_block(&mut continuing, expressions);
905 block.push(
906 Statement::Loop {
907 body,
908 continuing,
909 break_if,
910 },
911 span,
912 );
913 }
914 stmt => block.push(stmt.clone(), span),
915 }
916 }
917}
918
919fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
920 match scalar {
922 Scalar::BOOL => {
923 let value = value != 0.0 && !value.is_nan();
925 Ok(Literal::Bool(value))
926 }
927 Scalar::I32 => {
928 if !value.is_finite() {
930 return Err(PipelineConstantError::SrcNeedsToBeFinite);
931 }
932
933 let value = value.trunc();
934 if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
935 return Err(PipelineConstantError::DstRangeTooSmall);
936 }
937
938 let value = value as i32;
939 Ok(Literal::I32(value))
940 }
941 Scalar::U32 => {
942 if !value.is_finite() {
944 return Err(PipelineConstantError::SrcNeedsToBeFinite);
945 }
946
947 let value = value.trunc();
948 if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) {
949 return Err(PipelineConstantError::DstRangeTooSmall);
950 }
951
952 let value = value as u32;
953 Ok(Literal::U32(value))
954 }
955 Scalar::F32 => {
956 if !value.is_finite() {
958 return Err(PipelineConstantError::SrcNeedsToBeFinite);
959 }
960
961 let value = value as f32;
962 if !value.is_finite() {
963 return Err(PipelineConstantError::DstRangeTooSmall);
964 }
965
966 Ok(Literal::F32(value))
967 }
968 Scalar::F64 => {
969 if !value.is_finite() {
971 return Err(PipelineConstantError::SrcNeedsToBeFinite);
972 }
973
974 Ok(Literal::F64(value))
975 }
976 _ => unreachable!(),
977 }
978}
979
/// Exercises `map_value_to_literal` for each supported scalar type:
/// - bool coercion;
/// - rejection of non-finite inputs;
/// - exact range boundaries;
/// - truncation toward zero for integer destinations;
/// - f32 overflow on narrowing.
#[test]
fn test_map_value_to_literal() {
    // Only ±0.0 and NaN map to `false`; everything else is `true`.
    let bool_test_cases = [
        (0.0, false),
        (-0.0, false),
        (f64::NAN, false),
        (1.0, true),
        (0.5, true),
        (-1.0, true),
        (f64::INFINITY, true),
        (f64::NEG_INFINITY, true),
    ];
    for (value, out) in bool_test_cases {
        let res = Ok(Literal::Bool(out));
        assert_eq!(map_value_to_literal(value, Scalar::BOOL), res);
    }

    // Numeric destinations reject NaN and infinities.
    for scalar in [Scalar::I32, Scalar::U32, Scalar::F32, Scalar::F64] {
        for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let res = Err(PipelineConstantError::SrcNeedsToBeFinite);
            assert_eq!(map_value_to_literal(value, scalar), res);
        }
    }

    // i32 range boundaries.
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN), Scalar::I32),
        Ok(Literal::I32(i32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX), Scalar::I32),
        Ok(Literal::I32(i32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // i32 truncates toward zero, so fractional values just past a bound
    // still fit.
    assert_eq!(map_value_to_literal(1.9, Scalar::I32), Ok(Literal::I32(1)));
    assert_eq!(map_value_to_literal(-1.9, Scalar::I32), Ok(Literal::I32(-1)));
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN) - 0.5, Scalar::I32),
        Ok(Literal::I32(i32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX) + 0.5, Scalar::I32),
        Ok(Literal::I32(i32::MAX))
    );

    // u32 range boundaries.
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN), Scalar::U32),
        Ok(Literal::U32(u32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX), Scalar::U32),
        Ok(Literal::U32(u32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // u32 truncation: -0.5 truncates to -0.0, which is accepted as 0.
    assert_eq!(map_value_to_literal(2.5, Scalar::U32), Ok(Literal::U32(2)));
    assert_eq!(map_value_to_literal(-0.5, Scalar::U32), Ok(Literal::U32(0)));
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX) + 0.5, Scalar::U32),
        Ok(Literal::U32(u32::MAX))
    );

    // f32 boundaries, including values that round to f32::MIN/MAX and the
    // first values that round to ±∞.
    assert_eq!(
        map_value_to_literal(f64::from(f32::MIN), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(f32::MAX), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(map_value_to_literal(0.5, Scalar::F32), Ok(Literal::F32(0.5)));

    // f64 passes finite values through unchanged.
    assert_eq!(
        map_value_to_literal(f64::MIN, Scalar::F64),
        Ok(Literal::F64(f64::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::MAX, Scalar::F64),
        Ok(Literal::F64(f64::MAX))
    );
}