1use indexmap::IndexMap;
2use naga::{
3 Arena, AtomicFunction, Block, Constant, EntryPoint, Expression, Function, FunctionArgument,
4 FunctionResult, GatherMode, GlobalVariable, Handle, ImageQuery, LocalVariable, Module,
5 Override, SampleLevel, Span, Statement, StructMember, SwitchCase, Type, TypeInner, UniqueArena,
6};
7use std::{cell::RefCell, rc::Rc};
8
/// Accumulates items copied out of a source [`Module`] into a new, derived module.
///
/// Handles belonging to the currently selected source module (see `set_shader_source`) are
/// translated into handles of the arenas owned by this struct. The `*_map` fields remember
/// which source handles have already been translated.
#[derive(Debug, Default)]
pub struct DerivedModule<'a> {
    /// Source module currently being imported from; `None` until `set_shader_source` is called
    /// and again after `clear_shader_source`.
    shader: Option<&'a Module>,
    /// Offset added to the start and end of every defined span by `map_span`.
    span_offset: usize,

    /// Source type handle -> derived type handle (cleared by `clear_shader_source`).
    type_map: IndexMap<Handle<Type>, Handle<Type>>,
    /// Source constant handle -> derived constant handle (cleared by `clear_shader_source`).
    const_map: IndexMap<Handle<Constant>, Handle<Constant>>,
    /// Source override handle -> derived override handle (cleared by `clear_shader_source`).
    pipeline_override_map: IndexMap<Handle<Override>, Handle<Override>>,
    /// Global expressions of the derived module. Held behind `Rc<RefCell<..>>` because it is
    /// passed as an argument to `import_expression`, which also takes `&mut self`.
    global_expressions: Rc<RefCell<Arena<Expression>>>,
    /// Source global-expression handle -> derived global-expression handle (cleared by
    /// `clear_shader_source`). Shared for the same reason as `global_expressions`.
    global_expression_map: Rc<RefCell<IndexMap<Handle<Expression>, Handle<Expression>>>>,
    /// Source global-variable handle -> derived global-variable handle (cleared by
    /// `clear_shader_source`).
    global_map: IndexMap<Handle<GlobalVariable>, Handle<GlobalVariable>>,
    /// Function name -> derived function handle. Unlike the handle maps above, this one is
    /// NOT cleared by `clear_shader_source`.
    function_map: IndexMap<String, Handle<Function>>,
    /// Types of the derived module.
    types: UniqueArena<Type>,
    /// Constants of the derived module.
    constants: Arena<Constant>,
    /// Global variables of the derived module.
    globals: Arena<GlobalVariable>,
    /// Functions of the derived module.
    functions: Arena<Function>,
    /// Pipeline overrides of the derived module.
    pipeline_overrides: Arena<Override>,
    /// Special types of the derived module, filled in by `set_shader_source`.
    special_types: naga::SpecialTypes,
}
35
36impl<'a> DerivedModule<'a> {
37 pub fn set_shader_source(&mut self, shader: &'a Module, span_offset: usize) {
39 self.clear_shader_source();
40 self.shader = Some(shader);
41 self.span_offset = span_offset;
42
43 if let Some(h_special_type) = shader.special_types.ray_desc.as_ref() {
45 if let Some(derived_special_type) = self.special_types.ray_desc.as_ref() {
46 self.type_map.insert(*h_special_type, *derived_special_type);
47 } else {
48 self.special_types.ray_desc = Some(self.import_type(h_special_type));
49 }
50 }
51 if let Some(h_special_type) = shader.special_types.ray_intersection.as_ref() {
52 if let Some(derived_special_type) = self.special_types.ray_intersection.as_ref() {
53 self.type_map.insert(*h_special_type, *derived_special_type);
54 } else {
55 self.special_types.ray_intersection = Some(self.import_type(h_special_type));
56 }
57 }
58 if let Some(h_special_type) = shader.special_types.ray_vertex_return.as_ref() {
59 if let Some(derived_special_type) = self.special_types.ray_vertex_return.as_ref() {
60 self.type_map.insert(*h_special_type, *derived_special_type);
61 } else {
62 self.special_types.ray_vertex_return = Some(self.import_type(h_special_type));
63 }
64 }
65 for (predeclared, h_predeclared_type) in shader.special_types.predeclared_types.iter() {
66 if let Some(derived_special_type) =
67 self.special_types.predeclared_types.get(predeclared)
68 {
69 self.type_map
70 .insert(*h_predeclared_type, *derived_special_type);
71 } else {
72 let new_h = self.import_type(h_predeclared_type);
73 self.special_types
74 .predeclared_types
75 .insert(predeclared.clone(), new_h);
76 }
77 }
78 }
79
80 pub fn clear_shader_source(&mut self) {
82 self.shader = None;
83 self.type_map.clear();
84 self.const_map.clear();
85 self.global_map.clear();
86 self.global_expression_map.borrow_mut().clear();
87 self.pipeline_override_map.clear();
88 }
89
90 pub fn map_span(&self, span: Span) -> Span {
91 let span = span.to_range();
92 match span {
93 Some(rng) => Span::new(
94 (rng.start + self.span_offset) as u32,
95 (rng.end + self.span_offset) as u32,
96 ),
97 None => Span::UNDEFINED,
98 }
99 }
100
    /// Imports the source type `h_type` into the derived module, keeping its original name.
    ///
    /// Equivalent to `rename_type(h_type, None)`; returns the derived type handle.
    pub fn import_type(&mut self, h_type: &Handle<Type>) -> Handle<Type> {
        self.rename_type(h_type, None)
    }
105
106 pub fn rename_type(&mut self, h_type: &Handle<Type>, name: Option<String>) -> Handle<Type> {
108 self.type_map.get(h_type).copied().unwrap_or_else(|| {
109 let ty = self
110 .shader
111 .as_ref()
112 .unwrap()
113 .types
114 .get_handle(*h_type)
115 .unwrap();
116
117 let name = match name {
118 Some(name) => Some(name),
119 None => ty.name.clone(),
120 };
121
122 let new_type = Type {
123 name,
124 inner: match &ty.inner {
125 TypeInner::Scalar { .. }
126 | TypeInner::Vector { .. }
127 | TypeInner::Matrix { .. }
128 | TypeInner::ValuePointer { .. }
129 | TypeInner::Image { .. }
130 | TypeInner::Sampler { .. }
131 | TypeInner::Atomic { .. }
132 | TypeInner::AccelerationStructure { .. }
133 | TypeInner::RayQuery { .. } => ty.inner.clone(),
134
135 TypeInner::Pointer { base, space } => TypeInner::Pointer {
136 base: self.import_type(base),
137 space: *space,
138 },
139 TypeInner::Struct { members, span } => {
140 let members = members
141 .iter()
142 .map(|m| StructMember {
143 name: m.name.clone(),
144 ty: self.import_type(&m.ty),
145 binding: m.binding.clone(),
146 offset: m.offset,
147 })
148 .collect();
149 TypeInner::Struct {
150 members,
151 span: *span,
152 }
153 }
154 TypeInner::Array { base, size, stride } => TypeInner::Array {
155 base: self.import_type(base),
156 size: *size,
157 stride: *stride,
158 },
159 TypeInner::BindingArray { base, size } => TypeInner::BindingArray {
160 base: self.import_type(base),
161 size: *size,
162 },
163 },
164 };
165 let span = self.shader.as_ref().unwrap().types.get_span(*h_type);
166 let new_h = self.types.insert(new_type, self.map_span(span));
167 self.type_map.insert(*h_type, new_h);
168 new_h
169 })
170 }
171
172 pub fn import_const(&mut self, h_const: &Handle<Constant>) -> Handle<Constant> {
174 self.const_map.get(h_const).copied().unwrap_or_else(|| {
175 let c = self
176 .shader
177 .as_ref()
178 .unwrap()
179 .constants
180 .try_get(*h_const)
181 .unwrap();
182
183 let new_const = Constant {
184 name: c.name.clone(),
185 ty: self.import_type(&c.ty),
186 init: self.import_global_expression(c.init),
187 };
188
189 let span = self.shader.as_ref().unwrap().constants.get_span(*h_const);
190 let new_h = self
191 .constants
192 .fetch_or_append(new_const, self.map_span(span));
193 self.const_map.insert(*h_const, new_h);
194 new_h
195 })
196 }
197
198 pub fn import_global(&mut self, h_global: &Handle<GlobalVariable>) -> Handle<GlobalVariable> {
200 self.global_map.get(h_global).copied().unwrap_or_else(|| {
201 let gv = self
202 .shader
203 .as_ref()
204 .unwrap()
205 .global_variables
206 .try_get(*h_global)
207 .unwrap();
208
209 let new_global = GlobalVariable {
210 name: gv.name.clone(),
211 space: gv.space,
212 binding: gv.binding,
213 ty: self.import_type(&gv.ty),
214 init: gv.init.map(|c| self.import_global_expression(c)),
215 };
216
217 let span = self
218 .shader
219 .as_ref()
220 .unwrap()
221 .global_variables
222 .get_span(*h_global);
223 let new_h = self
224 .globals
225 .fetch_or_append(new_global, self.map_span(span));
226 self.global_map.insert(*h_global, new_h);
227 new_h
228 })
229 }
230
    /// Imports `h_expr`, a handle into the source module's global expression arena, into the
    /// derived module's global expression arena and returns the derived handle.
    ///
    /// # Panics
    ///
    /// Panics if no shader source is set.
    pub fn import_global_expression(&mut self, h_expr: Handle<Expression>) -> Handle<Expression> {
        self.import_expression(
            h_expr,
            &self.shader.as_ref().unwrap().global_expressions,
            self.global_expression_map.clone(),
            self.global_expressions.clone(),
            // non_emitting_only = false: every expression kind is imported.
            false,
            // unique = true: insertion goes through `fetch_if_or_append` with an equality
            // predicate rather than a plain `append`.
            true,
        )
    }
242
243 pub fn import_pipeline_override(&mut self, h_override: &Handle<Override>) -> Handle<Override> {
245 self.pipeline_override_map
246 .get(h_override)
247 .copied()
248 .unwrap_or_else(|| {
249 let pipeline_override = self
250 .shader
251 .as_ref()
252 .unwrap()
253 .overrides
254 .try_get(*h_override)
255 .unwrap();
256
257 let new_override = Override {
258 name: pipeline_override.name.clone(),
259 id: pipeline_override.id,
260 ty: self.import_type(&pipeline_override.ty),
261 init: pipeline_override
262 .init
263 .map(|init| self.import_global_expression(init)),
264 };
265
266 let span = self
267 .shader
268 .as_ref()
269 .unwrap()
270 .overrides
271 .get_span(*h_override);
272 let new_h = self
273 .pipeline_overrides
274 .fetch_or_append(new_override, self.map_span(span));
275 self.pipeline_override_map.insert(*h_override, new_h);
276 new_h
277 })
278 }
279
    /// Rebuilds `block` for the derived module.
    ///
    /// Every statement is recreated with its expression handles translated through
    /// `import_expression` (appending to `new_expressions` and recording translations in
    /// `already_imported`), its function handles translated through `map_function_handle`,
    /// and nested blocks rebuilt recursively. Statement spans are copied over through
    /// `map_span`.
    ///
    /// * `block` - the block to rebuild.
    /// * `old_expressions` - the expression arena that `block`'s handles index.
    /// * `already_imported` - old expression handle -> new expression handle.
    /// * `new_expressions` - the expression arena being built.
    fn import_block(
        &mut self,
        block: &Block,
        old_expressions: &Arena<Expression>,
        already_imported: Rc<RefCell<IndexMap<Handle<Expression>, Handle<Expression>>>>,
        new_expressions: Rc<RefCell<Arena<Expression>>>,
    ) -> Block {
        // Translate one expression handle (all kinds, appended without deduplication).
        macro_rules! map_expr {
            ($e:expr) => {
                self.import_expression(
                    *$e,
                    old_expressions,
                    already_imported.clone(),
                    new_expressions.clone(),
                    false,
                    false,
                )
            };
        }

        // Translate an optional expression handle.
        macro_rules! map_expr_opt {
            ($e:expr) => {
                $e.as_ref().map(|expr| map_expr!(expr))
            };
        }

        // Rebuild a nested block against the same arenas and translation map.
        macro_rules! map_block {
            ($b:expr) => {
                self.import_block(
                    $b,
                    old_expressions,
                    already_imported.clone(),
                    new_expressions.clone(),
                )
            };
        }

        let statements = block
            .iter()
            .map(|stmt| {
                match stmt {
                    // The callee is translated by name, importing it if necessary.
                    Statement::Call {
                        function,
                        arguments,
                        result,
                    } => Statement::Call {
                        function: self.map_function_handle(function),
                        arguments: arguments.iter().map(|expr| map_expr!(expr)).collect(),
                        result: result.as_ref().map(|result| map_expr!(result)),
                    },

                    // Control-flow statements: translate operands and recurse into bodies.
                    Statement::Block(b) => Statement::Block(map_block!(b)),
                    Statement::If {
                        condition,
                        accept,
                        reject,
                    } => Statement::If {
                        condition: map_expr!(condition),
                        accept: map_block!(accept),
                        reject: map_block!(reject),
                    },
                    Statement::Switch { selector, cases } => Statement::Switch {
                        selector: map_expr!(selector),
                        cases: cases
                            .iter()
                            .map(|case| SwitchCase {
                                value: case.value,
                                body: map_block!(&case.body),
                                fall_through: case.fall_through,
                            })
                            .collect(),
                    },
                    Statement::Loop {
                        body,
                        continuing,
                        break_if,
                    } => Statement::Loop {
                        body: map_block!(body),
                        continuing: map_block!(continuing),
                        break_if: map_expr_opt!(break_if),
                    },

                    Statement::Emit(exprs) => {
                        // First pass (non_emitting_only = true): only the expression kinds
                        // that `import_expression` flags as external are appended, so they
                        // land before `old_length` and stay outside the emitted range.
                        for expr in exprs.clone() {
                            self.import_expression(
                                expr,
                                old_expressions,
                                already_imported.clone(),
                                new_expressions.clone(),
                                true,
                                false,
                            );
                        }
                        let old_length = new_expressions.borrow().len();
                        // Second pass: append whatever is still missing.
                        for expr in exprs.clone() {
                            map_expr!(&expr);
                        }

                        // The new emit range is everything appended since `old_length`.
                        Statement::Emit(new_expressions.borrow().range_from(old_length))
                    }
                    Statement::Store { pointer, value } => Statement::Store {
                        pointer: map_expr!(pointer),
                        value: map_expr!(value),
                    },
                    Statement::ImageStore {
                        image,
                        coordinate,
                        array_index,
                        value,
                    } => Statement::ImageStore {
                        image: map_expr!(image),
                        coordinate: map_expr!(coordinate),
                        array_index: map_expr_opt!(array_index),
                        value: map_expr!(value),
                    },
                    Statement::Atomic {
                        pointer,
                        fun,
                        value,
                        result,
                    } => {
                        // Only `Exchange` with a compare operand carries an expression handle;
                        // every other atomic function is copied as is.
                        let fun = match fun {
                            AtomicFunction::Exchange {
                                compare: Some(compare_expr),
                            } => AtomicFunction::Exchange {
                                compare: Some(map_expr!(compare_expr)),
                            },
                            fun => *fun,
                        };
                        Statement::Atomic {
                            pointer: map_expr!(pointer),
                            fun,
                            value: map_expr!(value),
                            result: map_expr_opt!(result),
                        }
                    }
                    Statement::WorkGroupUniformLoad { pointer, result } => {
                        Statement::WorkGroupUniformLoad {
                            pointer: map_expr!(pointer),
                            result: map_expr!(result),
                        }
                    }
                    Statement::Return { value } => Statement::Return {
                        value: map_expr_opt!(value),
                    },
                    Statement::RayQuery { query, fun } => Statement::RayQuery {
                        query: map_expr!(query),
                        fun: match fun {
                            naga::RayQueryFunction::Initialize {
                                acceleration_structure,
                                descriptor,
                            } => naga::RayQueryFunction::Initialize {
                                acceleration_structure: map_expr!(acceleration_structure),
                                descriptor: map_expr!(descriptor),
                            },
                            naga::RayQueryFunction::Proceed { result } => {
                                naga::RayQueryFunction::Proceed {
                                    result: map_expr!(result),
                                }
                            }
                            naga::RayQueryFunction::Terminate => naga::RayQueryFunction::Terminate,
                            naga::RayQueryFunction::GenerateIntersection { hit_t } => {
                                naga::RayQueryFunction::GenerateIntersection {
                                    hit_t: map_expr!(hit_t),
                                }
                            }
                            naga::RayQueryFunction::ConfirmIntersection => {
                                naga::RayQueryFunction::ConfirmIntersection
                            }
                        },
                    },
                    Statement::SubgroupBallot { result, predicate } => Statement::SubgroupBallot {
                        result: map_expr!(result),
                        predicate: map_expr_opt!(predicate),
                    },
                    Statement::SubgroupGather {
                        mut mode,
                        argument,
                        result,
                    } => {
                        // `mode` is bound by value, so the handle inside it is rewritten in
                        // this local copy without touching the source statement.
                        match mode {
                            GatherMode::BroadcastFirst => (),
                            GatherMode::Broadcast(ref mut h_src)
                            | GatherMode::QuadBroadcast(ref mut h_src)
                            | GatherMode::Shuffle(ref mut h_src)
                            | GatherMode::ShuffleDown(ref mut h_src)
                            | GatherMode::ShuffleUp(ref mut h_src)
                            | GatherMode::ShuffleXor(ref mut h_src) => *h_src = map_expr!(h_src),
                            GatherMode::QuadSwap(_) => (),
                        };
                        Statement::SubgroupGather {
                            mode,
                            argument: map_expr!(argument),
                            result: map_expr!(result),
                        }
                    }
                    Statement::SubgroupCollectiveOperation {
                        op,
                        collective_op,
                        argument,
                        result,
                    } => Statement::SubgroupCollectiveOperation {
                        op: *op,
                        collective_op: *collective_op,
                        argument: map_expr!(argument),
                        result: map_expr!(result),
                    },
                    Statement::ImageAtomic {
                        image,
                        coordinate,
                        array_index,
                        fun,
                        value,
                    } => {
                        // Same treatment of the atomic function as in `Statement::Atomic`.
                        let fun = match fun {
                            AtomicFunction::Exchange {
                                compare: Some(compare_expr),
                            } => AtomicFunction::Exchange {
                                compare: Some(map_expr!(compare_expr)),
                            },
                            fun => *fun,
                        };
                        Statement::ImageAtomic {
                            image: map_expr!(image),
                            coordinate: map_expr!(coordinate),
                            array_index: map_expr_opt!(array_index),
                            fun,
                            value: map_expr!(value),
                        }
                    }
                    // These statements are cloned unchanged.
                    Statement::Break
                    | Statement::Continue
                    | Statement::Kill
                    | Statement::MemoryBarrier(_)
                    | Statement::ControlBarrier(_) => stmt.clone(),
                }
            })
            .collect();

        let mut new_block = Block::from_vec(statements);

        // Statements were produced one-to-one, so zipping pairs each new statement with its
        // source statement. The `unwrap` panics if `span_iter_mut` yields no span slot.
        for ((_, new_span), (_, old_span)) in new_block.span_iter_mut().zip(block.span_iter()) {
            *new_span.unwrap() = self.map_span(*old_span);
        }

        new_block
    }
534
    /// Imports the expression `h_expr` from `old_expressions` into `new_expressions`,
    /// recursively importing every expression, type, constant, global variable, override and
    /// function it refers to.
    ///
    /// * `already_imported` - old handle -> new handle; consulted first and updated whenever
    ///   an expression is added to `new_expressions`.
    /// * `non_emitting_only` - when `true`, only expressions flagged `is_external` below
    ///   (literals, zero values, constants, global variables, overrides, local variables and
    ///   function arguments) are added. Any other expression is still traversed, but is not
    ///   added or recorded.
    /// * `unique` - when `true`, the expression is added with `fetch_if_or_append` and an
    ///   equality predicate; when `false`, it is added with a plain `append`.
    ///
    /// Returns the handle in `new_expressions`. The exception is an expression skipped
    /// because of `non_emitting_only`: in that case the ORIGINAL `h_expr` is returned, which
    /// is not a handle into `new_expressions`.
    ///
    /// # Panics
    ///
    /// Panics if `h_expr` is not a valid handle into `old_expressions`.
    fn import_expression(
        &mut self,
        h_expr: Handle<Expression>,
        old_expressions: &Arena<Expression>,
        already_imported: Rc<RefCell<IndexMap<Handle<Expression>, Handle<Expression>>>>,
        new_expressions: Rc<RefCell<Arena<Expression>>>,
        non_emitting_only: bool,
        unique: bool,
    ) -> Handle<Expression> {
        if let Some(h_new) = already_imported.borrow().get(&h_expr) {
            return *h_new;
        }

        // Recurse into a sub-expression with the same arenas, map and flags.
        macro_rules! map_expr {
            ($e:expr) => {
                self.import_expression(
                    *$e,
                    old_expressions,
                    already_imported.clone(),
                    new_expressions.clone(),
                    non_emitting_only,
                    unique,
                )
            };
        }

        // Same as `map_expr!`, for an optional sub-expression.
        macro_rules! map_expr_opt {
            ($e:expr) => {
                $e.as_ref().map(|expr| {
                    self.import_expression(
                        *expr,
                        old_expressions,
                        already_imported.clone(),
                        new_expressions.clone(),
                        non_emitting_only,
                        unique,
                    )
                })
            };
        }

        // Set for the expression kinds that are added even when `non_emitting_only` is true.
        let mut is_external = false;
        let expr = old_expressions.try_get(h_expr).unwrap();
        let expr = match expr {
            Expression::Literal(_) => {
                is_external = true;
                expr.clone()
            }
            Expression::ZeroValue(zv) => {
                is_external = true;
                Expression::ZeroValue(self.import_type(zv))
            }
            Expression::CallResult(f) => Expression::CallResult(self.map_function_handle(f)),
            Expression::Constant(c) => {
                is_external = true;
                Expression::Constant(self.import_const(c))
            }
            Expression::Compose { ty, components } => Expression::Compose {
                ty: self.import_type(ty),
                components: components.iter().map(|expr| map_expr!(expr)).collect(),
            },
            Expression::GlobalVariable(gv) => {
                is_external = true;
                Expression::GlobalVariable(self.import_global(gv))
            }
            Expression::ImageSample {
                image,
                sampler,
                gather,
                coordinate,
                array_index,
                offset,
                level,
                depth_ref,
                clamp_to_edge,
            } => Expression::ImageSample {
                image: map_expr!(image),
                sampler: map_expr!(sampler),
                gather: *gather,
                coordinate: map_expr!(coordinate),
                array_index: map_expr_opt!(array_index),
                offset: map_expr_opt!(offset),
                level: match level {
                    SampleLevel::Auto | SampleLevel::Zero => *level,
                    SampleLevel::Exact(expr) => SampleLevel::Exact(map_expr!(expr)),
                    SampleLevel::Bias(expr) => SampleLevel::Bias(map_expr!(expr)),
                    SampleLevel::Gradient { x, y } => SampleLevel::Gradient {
                        x: map_expr!(x),
                        y: map_expr!(y),
                    },
                },
                depth_ref: map_expr_opt!(depth_ref),
                clamp_to_edge: *clamp_to_edge,
            },
            Expression::Access { base, index } => Expression::Access {
                base: map_expr!(base),
                index: map_expr!(index),
            },
            Expression::AccessIndex { base, index } => Expression::AccessIndex {
                base: map_expr!(base),
                index: *index,
            },
            Expression::Splat { size, value } => Expression::Splat {
                size: *size,
                value: map_expr!(value),
            },
            Expression::Swizzle {
                size,
                vector,
                pattern,
            } => Expression::Swizzle {
                size: *size,
                vector: map_expr!(vector),
                pattern: *pattern,
            },
            Expression::Load { pointer } => Expression::Load {
                pointer: map_expr!(pointer),
            },
            Expression::ImageLoad {
                image,
                coordinate,
                array_index,
                sample,
                level,
            } => Expression::ImageLoad {
                image: map_expr!(image),
                coordinate: map_expr!(coordinate),
                array_index: map_expr_opt!(array_index),
                sample: map_expr_opt!(sample),
                level: map_expr_opt!(level),
            },
            Expression::ImageQuery { image, query } => Expression::ImageQuery {
                image: map_expr!(image),
                query: match query {
                    // `Size` is the only query handled as carrying an expression handle.
                    ImageQuery::Size { level } => ImageQuery::Size {
                        level: map_expr_opt!(level),
                    },
                    _ => *query,
                },
            },
            Expression::Unary { op, expr } => Expression::Unary {
                op: *op,
                expr: map_expr!(expr),
            },
            Expression::Binary { op, left, right } => Expression::Binary {
                op: *op,
                left: map_expr!(left),
                right: map_expr!(right),
            },
            Expression::Select {
                condition,
                accept,
                reject,
            } => Expression::Select {
                condition: map_expr!(condition),
                accept: map_expr!(accept),
                reject: map_expr!(reject),
            },
            Expression::Derivative { axis, expr, ctrl } => Expression::Derivative {
                axis: *axis,
                expr: map_expr!(expr),
                ctrl: *ctrl,
            },
            Expression::Relational { fun, argument } => Expression::Relational {
                fun: *fun,
                argument: map_expr!(argument),
            },
            Expression::Math {
                fun,
                arg,
                arg1,
                arg2,
                arg3,
            } => Expression::Math {
                fun: *fun,
                arg: map_expr!(arg),
                arg1: map_expr_opt!(arg1),
                arg2: map_expr_opt!(arg2),
                arg3: map_expr_opt!(arg3),
            },
            Expression::As {
                expr,
                kind,
                convert,
            } => Expression::As {
                expr: map_expr!(expr),
                kind: *kind,
                convert: *convert,
            },
            Expression::ArrayLength(expr) => Expression::ArrayLength(map_expr!(expr)),

            // Local variable and argument handles are copied unchanged. `localize_function`
            // asserts that local variable handles are identical in the localized function.
            Expression::LocalVariable(_) | Expression::FunctionArgument(_) => {
                is_external = true;
                expr.clone()
            }

            Expression::AtomicResult { ty, comparison } => Expression::AtomicResult {
                ty: self.import_type(ty),
                comparison: *comparison,
            },
            Expression::WorkGroupUniformLoadResult { ty } => {
                Expression::WorkGroupUniformLoadResult {
                    ty: self.import_type(ty),
                }
            }
            Expression::RayQueryProceedResult => expr.clone(),
            Expression::RayQueryGetIntersection { query, committed } => {
                Expression::RayQueryGetIntersection {
                    query: map_expr!(query),
                    committed: *committed,
                }
            }
            Expression::RayQueryVertexPositions { query, committed } => {
                Expression::RayQueryVertexPositions {
                    query: map_expr!(query),
                    committed: *committed,
                }
            }
            Expression::Override(h_override) => {
                is_external = true;
                Expression::Override(self.import_pipeline_override(h_override))
            }
            Expression::SubgroupBallotResult => expr.clone(),
            Expression::SubgroupOperationResult { ty } => Expression::SubgroupOperationResult {
                ty: self.import_type(ty),
            },
        };

        if !non_emitting_only || is_external {
            let span = old_expressions.get_span(h_expr);
            let h_new = if unique {
                new_expressions.borrow_mut().fetch_if_or_append(
                    expr,
                    self.map_span(span),
                    |lhs, rhs| lhs == rhs,
                )
            } else {
                new_expressions
                    .borrow_mut()
                    .append(expr, self.map_span(span))
            };

            already_imported.borrow_mut().insert(h_expr, h_new);
            h_new
        } else {
            // Skipped in `non_emitting_only` mode: the rebuilt expression is dropped and the
            // old handle is handed back. The one caller in this file that passes
            // `non_emitting_only = true` (the `Emit` arm of `import_block`) ignores it.
            h_expr
        }
    }
783
784 pub fn localize_function(&mut self, func: &Function) -> Function {
786 let arguments = func
787 .arguments
788 .iter()
789 .map(|arg| FunctionArgument {
790 name: arg.name.clone(),
791 ty: self.import_type(&arg.ty),
792 binding: arg.binding.clone(),
793 })
794 .collect();
795
796 let result = func.result.as_ref().map(|r| FunctionResult {
797 ty: self.import_type(&r.ty),
798 binding: r.binding.clone(),
799 });
800
801 let expressions = Rc::new(RefCell::new(Arena::new()));
802 let expr_map = Rc::new(RefCell::new(IndexMap::new()));
803
804 let mut local_variables = Arena::new();
805 for (h_l, l) in func.local_variables.iter() {
806 let new_local = LocalVariable {
807 name: l.name.clone(),
808 ty: self.import_type(&l.ty),
809 init: l.init.map(|c| {
810 self.import_expression(
811 c,
812 &func.expressions,
813 expr_map.clone(),
814 expressions.clone(),
815 false,
816 true,
817 )
818 }),
819 };
820 let span = func.local_variables.get_span(h_l);
821 let new_h = local_variables.append(new_local, self.map_span(span));
822 assert_eq!(h_l, new_h);
823 }
824
825 let body = self.import_block(
826 &func.body,
827 &func.expressions,
828 expr_map.clone(),
829 expressions.clone(),
830 );
831
832 let named_expressions = func
833 .named_expressions
834 .iter()
835 .flat_map(|(h_expr, name)| {
836 expr_map
837 .borrow()
838 .get(h_expr)
839 .map(|new_h| (*new_h, name.clone()))
840 })
841 .collect::<IndexMap<_, _, std::hash::BuildHasherDefault<rustc_hash::FxHasher>>>();
842
843 Function {
844 name: func.name.clone(),
845 arguments,
846 result,
847 local_variables,
848 expressions: Rc::try_unwrap(expressions).unwrap().into_inner(),
849 named_expressions,
850 body,
851 diagnostic_filter_leaf: None,
852 }
853 }
854
855 pub fn import_function(&mut self, func: &Function, span: Span) -> Handle<Function> {
859 let name = func.name.as_ref().unwrap().clone();
860 let mapped_func = self.localize_function(func);
861 let new_span = self.map_span(span);
862 let new_h = self.functions.append(mapped_func, new_span);
863 self.function_map.insert(name, new_h);
864 new_h
865 }
866
867 pub fn map_function_handle(&mut self, h_func: &Handle<Function>) -> Handle<Function> {
870 let functions = &self.shader.as_ref().unwrap().functions;
871 let func = functions.try_get(*h_func).unwrap();
872 let name = func.name.as_ref().unwrap();
873 self.function_map.get(name).copied().unwrap_or_else(|| {
874 let span = functions.get_span(*h_func);
875 self.import_function(func, span)
876 })
877 }
878
879 pub fn import_function_if_new(&mut self, func: &Function, span: Span) -> Handle<Function> {
882 let name = func.name.as_ref().unwrap().clone();
883 if let Some(h) = self.function_map.get(&name) {
884 return *h;
885 }
886
887 self.import_function(func, span)
888 }
889
890 pub fn has_required_special_types(&self) -> bool {
892 !self.special_types.predeclared_types.is_empty()
893 || self.special_types.ray_desc.is_some()
894 || self.special_types.ray_intersection.is_some()
895 || self.special_types.ray_vertex_return.is_some()
896 }
897
898 pub fn into_module_with_entrypoints(mut self) -> naga::Module {
899 let entry_points = self
900 .shader
901 .unwrap()
902 .entry_points
903 .iter()
904 .map(|ep| EntryPoint {
905 name: ep.name.clone(),
906 stage: ep.stage,
907 early_depth_test: ep.early_depth_test,
908 workgroup_size: ep.workgroup_size,
909 function: self.localize_function(&ep.function),
910 workgroup_size_overrides: ep.workgroup_size_overrides,
911 })
912 .collect();
913
914 naga::Module {
915 entry_points,
916 ..self.into()
917 }
918 }
919}
920
921impl From<DerivedModule<'_>> for naga::Module {
922 fn from(derived: DerivedModule) -> Self {
923 naga::Module {
924 types: derived.types,
925 constants: derived.constants,
926 global_variables: derived.globals,
927 global_expressions: Rc::try_unwrap(derived.global_expressions)
928 .unwrap()
929 .into_inner(),
930 functions: derived.functions,
931 special_types: derived.special_types,
932 entry_points: Default::default(),
933 overrides: derived.pipeline_overrides,
934 diagnostic_filters: Default::default(),
935 diagnostic_filter_leaf: None,
936 doc_comments: None,
937 }
938 }
939}