1use indexmap::IndexMap;
2use naga::{
3 Arena, AtomicFunction, Block, Constant, EntryPoint, Expression, Function, FunctionArgument,
4 FunctionResult, GatherMode, GlobalVariable, Handle, ImageQuery, LocalVariable, Module,
5 Override, SampleLevel, Span, Statement, StructMember, SwitchCase, Type, TypeInner, UniqueArena,
6};
7use std::{cell::RefCell, rc::Rc};
8
9#[derive(Debug, Default)]
10pub struct DerivedModule<'a> {
11 shader: Option<&'a Module>,
12 span_offset: usize,
13
14 type_map: IndexMap<Handle<Type>, Handle<Type>>,
16 const_map: IndexMap<Handle<Constant>, Handle<Constant>>,
18 pipeline_override_map: IndexMap<Handle<Override>, Handle<Override>>,
20 global_expressions: Rc<RefCell<Arena<Expression>>>,
23 global_expression_map: Rc<RefCell<IndexMap<Handle<Expression>, Handle<Expression>>>>,
26 global_map: IndexMap<Handle<GlobalVariable>, Handle<GlobalVariable>>,
27 function_map: IndexMap<String, Handle<Function>>,
28 types: UniqueArena<Type>,
29 constants: Arena<Constant>,
30 globals: Arena<GlobalVariable>,
31 functions: Arena<Function>,
32 pipeline_overrides: Arena<Override>,
33 special_types: naga::SpecialTypes,
34}
35
36impl<'a> DerivedModule<'a> {
37 pub fn set_shader_source(&mut self, shader: &'a Module, span_offset: usize) {
39 self.clear_shader_source();
40 self.shader = Some(shader);
41 self.span_offset = span_offset;
42
43 if let Some(h_special_type) = shader.special_types.ray_desc.as_ref() {
45 if let Some(derived_special_type) = self.special_types.ray_desc.as_ref() {
46 self.type_map.insert(*h_special_type, *derived_special_type);
47 } else {
48 self.special_types.ray_desc = Some(self.import_type(h_special_type));
49 }
50 }
51 if let Some(h_special_type) = shader.special_types.ray_intersection.as_ref() {
52 if let Some(derived_special_type) = self.special_types.ray_intersection.as_ref() {
53 self.type_map.insert(*h_special_type, *derived_special_type);
54 } else {
55 self.special_types.ray_intersection = Some(self.import_type(h_special_type));
56 }
57 }
58 for (predeclared, h_predeclared_type) in shader.special_types.predeclared_types.iter() {
59 if let Some(derived_special_type) =
60 self.special_types.predeclared_types.get(predeclared)
61 {
62 self.type_map
63 .insert(*h_predeclared_type, *derived_special_type);
64 } else {
65 let new_h = self.import_type(h_predeclared_type);
66 self.special_types
67 .predeclared_types
68 .insert(predeclared.clone(), new_h);
69 }
70 }
71 }
72
73 pub fn clear_shader_source(&mut self) {
75 self.shader = None;
76 self.type_map.clear();
77 self.const_map.clear();
78 self.global_map.clear();
79 self.global_expression_map.borrow_mut().clear();
80 self.pipeline_override_map.clear();
81 }
82
83 pub fn map_span(&self, span: Span) -> Span {
84 let span = span.to_range();
85 match span {
86 Some(rng) => Span::new(
87 (rng.start + self.span_offset) as u32,
88 (rng.end + self.span_offset) as u32,
89 ),
90 None => Span::UNDEFINED,
91 }
92 }
93
94 pub fn import_type(&mut self, h_type: &Handle<Type>) -> Handle<Type> {
96 self.rename_type(h_type, None)
97 }
98
99 pub fn rename_type(&mut self, h_type: &Handle<Type>, name: Option<String>) -> Handle<Type> {
101 self.type_map.get(h_type).copied().unwrap_or_else(|| {
102 let ty = self
103 .shader
104 .as_ref()
105 .unwrap()
106 .types
107 .get_handle(*h_type)
108 .unwrap();
109
110 let name = match name {
111 Some(name) => Some(name),
112 None => ty.name.clone(),
113 };
114
115 let new_type = Type {
116 name,
117 inner: match &ty.inner {
118 TypeInner::Scalar { .. }
119 | TypeInner::Vector { .. }
120 | TypeInner::Matrix { .. }
121 | TypeInner::ValuePointer { .. }
122 | TypeInner::Image { .. }
123 | TypeInner::Sampler { .. }
124 | TypeInner::Atomic { .. }
125 | TypeInner::AccelerationStructure
126 | TypeInner::RayQuery => ty.inner.clone(),
127
128 TypeInner::Pointer { base, space } => TypeInner::Pointer {
129 base: self.import_type(base),
130 space: *space,
131 },
132 TypeInner::Struct { members, span } => {
133 let members = members
134 .iter()
135 .map(|m| StructMember {
136 name: m.name.clone(),
137 ty: self.import_type(&m.ty),
138 binding: m.binding.clone(),
139 offset: m.offset,
140 })
141 .collect();
142 TypeInner::Struct {
143 members,
144 span: *span,
145 }
146 }
147 TypeInner::Array { base, size, stride } => TypeInner::Array {
148 base: self.import_type(base),
149 size: *size,
150 stride: *stride,
151 },
152 TypeInner::BindingArray { base, size } => TypeInner::BindingArray {
153 base: self.import_type(base),
154 size: *size,
155 },
156 },
157 };
158 let span = self.shader.as_ref().unwrap().types.get_span(*h_type);
159 let new_h = self.types.insert(new_type, self.map_span(span));
160 self.type_map.insert(*h_type, new_h);
161 new_h
162 })
163 }
164
165 pub fn import_const(&mut self, h_const: &Handle<Constant>) -> Handle<Constant> {
167 self.const_map.get(h_const).copied().unwrap_or_else(|| {
168 let c = self
169 .shader
170 .as_ref()
171 .unwrap()
172 .constants
173 .try_get(*h_const)
174 .unwrap();
175
176 let new_const = Constant {
177 name: c.name.clone(),
178 ty: self.import_type(&c.ty),
179 init: self.import_global_expression(c.init),
180 };
181
182 let span = self.shader.as_ref().unwrap().constants.get_span(*h_const);
183 let new_h = self
184 .constants
185 .fetch_or_append(new_const, self.map_span(span));
186 self.const_map.insert(*h_const, new_h);
187 new_h
188 })
189 }
190
191 pub fn import_global(&mut self, h_global: &Handle<GlobalVariable>) -> Handle<GlobalVariable> {
193 self.global_map.get(h_global).copied().unwrap_or_else(|| {
194 let gv = self
195 .shader
196 .as_ref()
197 .unwrap()
198 .global_variables
199 .try_get(*h_global)
200 .unwrap();
201
202 let new_global = GlobalVariable {
203 name: gv.name.clone(),
204 space: gv.space,
205 binding: gv.binding.clone(),
206 ty: self.import_type(&gv.ty),
207 init: gv.init.map(|c| self.import_global_expression(c)),
208 };
209
210 let span = self
211 .shader
212 .as_ref()
213 .unwrap()
214 .global_variables
215 .get_span(*h_global);
216 let new_h = self
217 .globals
218 .fetch_or_append(new_global, self.map_span(span));
219 self.global_map.insert(*h_global, new_h);
220 new_h
221 })
222 }
223
224 pub fn import_global_expression(&mut self, h_expr: Handle<Expression>) -> Handle<Expression> {
226 self.import_expression(
227 h_expr,
228 &self.shader.as_ref().unwrap().global_expressions,
229 self.global_expression_map.clone(),
230 self.global_expressions.clone(),
231 false,
232 true,
233 )
234 }
235
236 pub fn import_pipeline_override(&mut self, h_override: &Handle<Override>) -> Handle<Override> {
238 self.pipeline_override_map
239 .get(h_override)
240 .copied()
241 .unwrap_or_else(|| {
242 let pipeline_override = self
243 .shader
244 .as_ref()
245 .unwrap()
246 .overrides
247 .try_get(*h_override)
248 .unwrap();
249
250 let new_override = Override {
251 name: pipeline_override.name.clone(),
252 id: pipeline_override.id,
253 ty: self.import_type(&pipeline_override.ty),
254 init: pipeline_override
255 .init
256 .map(|init| self.import_global_expression(init)),
257 };
258
259 let span = self
260 .shader
261 .as_ref()
262 .unwrap()
263 .overrides
264 .get_span(*h_override);
265 let new_h = self
266 .pipeline_overrides
267 .fetch_or_append(new_override, self.map_span(span));
268 self.pipeline_override_map.insert(*h_override, new_h);
269 new_h
270 })
271 }
272
273 fn import_block(
275 &mut self,
276 block: &Block,
277 old_expressions: &Arena<Expression>,
278 already_imported: Rc<RefCell<IndexMap<Handle<Expression>, Handle<Expression>>>>,
279 new_expressions: Rc<RefCell<Arena<Expression>>>,
280 ) -> Block {
281 macro_rules! map_expr {
282 ($e:expr) => {
283 self.import_expression(
284 *$e,
285 old_expressions,
286 already_imported.clone(),
287 new_expressions.clone(),
288 false,
289 false,
290 )
291 };
292 }
293
294 macro_rules! map_expr_opt {
295 ($e:expr) => {
296 $e.as_ref().map(|expr| map_expr!(expr))
297 };
298 }
299
300 macro_rules! map_block {
301 ($b:expr) => {
302 self.import_block(
303 $b,
304 old_expressions,
305 already_imported.clone(),
306 new_expressions.clone(),
307 )
308 };
309 }
310
311 let statements = block
312 .iter()
313 .map(|stmt| {
314 match stmt {
315 Statement::Call {
317 function,
318 arguments,
319 result,
320 } => Statement::Call {
321 function: self.map_function_handle(function),
322 arguments: arguments.iter().map(|expr| map_expr!(expr)).collect(),
323 result: result.as_ref().map(|result| map_expr!(result)),
324 },
325
326 Statement::Block(b) => Statement::Block(map_block!(b)),
328 Statement::If {
329 condition,
330 accept,
331 reject,
332 } => Statement::If {
333 condition: map_expr!(condition),
334 accept: map_block!(accept),
335 reject: map_block!(reject),
336 },
337 Statement::Switch { selector, cases } => Statement::Switch {
338 selector: map_expr!(selector),
339 cases: cases
340 .iter()
341 .map(|case| SwitchCase {
342 value: case.value,
343 body: map_block!(&case.body),
344 fall_through: case.fall_through,
345 })
346 .collect(),
347 },
348 Statement::Loop {
349 body,
350 continuing,
351 break_if,
352 } => Statement::Loop {
353 body: map_block!(body),
354 continuing: map_block!(continuing),
355 break_if: map_expr_opt!(break_if),
356 },
357
358 Statement::Emit(exprs) => {
360 for expr in exprs.clone() {
362 self.import_expression(
363 expr,
364 old_expressions,
365 already_imported.clone(),
366 new_expressions.clone(),
367 true,
368 false,
369 );
370 }
371 let old_length = new_expressions.borrow().len();
372 for expr in exprs.clone() {
374 map_expr!(&expr);
375 }
376
377 Statement::Emit(new_expressions.borrow().range_from(old_length))
378 }
379 Statement::Store { pointer, value } => Statement::Store {
380 pointer: map_expr!(pointer),
381 value: map_expr!(value),
382 },
383 Statement::ImageStore {
384 image,
385 coordinate,
386 array_index,
387 value,
388 } => Statement::ImageStore {
389 image: map_expr!(image),
390 coordinate: map_expr!(coordinate),
391 array_index: map_expr_opt!(array_index),
392 value: map_expr!(value),
393 },
394 Statement::Atomic {
395 pointer,
396 fun,
397 value,
398 result,
399 } => {
400 let fun = match fun {
401 AtomicFunction::Exchange {
402 compare: Some(compare_expr),
403 } => AtomicFunction::Exchange {
404 compare: Some(map_expr!(compare_expr)),
405 },
406 fun => *fun,
407 };
408 Statement::Atomic {
409 pointer: map_expr!(pointer),
410 fun,
411 value: map_expr!(value),
412 result: map_expr_opt!(result),
413 }
414 }
415 Statement::WorkGroupUniformLoad { pointer, result } => {
416 Statement::WorkGroupUniformLoad {
417 pointer: map_expr!(pointer),
418 result: map_expr!(result),
419 }
420 }
421 Statement::Return { value } => Statement::Return {
422 value: map_expr_opt!(value),
423 },
424 Statement::RayQuery { query, fun } => Statement::RayQuery {
425 query: map_expr!(query),
426 fun: match fun {
427 naga::RayQueryFunction::Initialize {
428 acceleration_structure,
429 descriptor,
430 } => naga::RayQueryFunction::Initialize {
431 acceleration_structure: map_expr!(acceleration_structure),
432 descriptor: map_expr!(descriptor),
433 },
434 naga::RayQueryFunction::Proceed { result } => {
435 naga::RayQueryFunction::Proceed {
436 result: map_expr!(result),
437 }
438 }
439 naga::RayQueryFunction::Terminate => naga::RayQueryFunction::Terminate,
440 },
441 },
442 Statement::SubgroupBallot { result, predicate } => Statement::SubgroupBallot {
443 result: map_expr!(result),
444 predicate: map_expr_opt!(predicate),
445 },
446 Statement::SubgroupGather {
447 mut mode,
448 argument,
449 result,
450 } => {
451 match mode {
452 GatherMode::BroadcastFirst => (),
453 GatherMode::Broadcast(ref mut h_src)
454 | GatherMode::Shuffle(ref mut h_src)
455 | GatherMode::ShuffleDown(ref mut h_src)
456 | GatherMode::ShuffleUp(ref mut h_src)
457 | GatherMode::ShuffleXor(ref mut h_src) => *h_src = map_expr!(h_src),
458 };
459 Statement::SubgroupGather {
460 mode,
461 argument: map_expr!(argument),
462 result: map_expr!(result),
463 }
464 }
465 Statement::SubgroupCollectiveOperation {
466 op,
467 collective_op,
468 argument,
469 result,
470 } => Statement::SubgroupCollectiveOperation {
471 op: *op,
472 collective_op: *collective_op,
473 argument: map_expr!(argument),
474 result: map_expr!(result),
475 },
476 Statement::ImageAtomic {
477 image,
478 coordinate,
479 array_index,
480 fun,
481 value,
482 } => {
483 let fun = match fun {
484 AtomicFunction::Exchange {
485 compare: Some(compare_expr),
486 } => AtomicFunction::Exchange {
487 compare: Some(map_expr!(compare_expr)),
488 },
489 fun => *fun,
490 };
491 Statement::ImageAtomic {
492 image: map_expr!(image),
493 coordinate: map_expr!(coordinate),
494 array_index: map_expr_opt!(array_index),
495 fun,
496 value: map_expr!(value),
497 }
498 }
499 Statement::Break
501 | Statement::Continue
502 | Statement::Kill
503 | Statement::Barrier(_) => stmt.clone(),
504 }
505 })
506 .collect();
507
508 let mut new_block = Block::from_vec(statements);
509
510 for ((_, new_span), (_, old_span)) in new_block.span_iter_mut().zip(block.span_iter()) {
511 *new_span.unwrap() = self.map_span(*old_span);
512 }
513
514 new_block
515 }
516
517 fn import_expression(
518 &mut self,
519 h_expr: Handle<Expression>,
520 old_expressions: &Arena<Expression>,
521 already_imported: Rc<RefCell<IndexMap<Handle<Expression>, Handle<Expression>>>>,
522 new_expressions: Rc<RefCell<Arena<Expression>>>,
523 non_emitting_only: bool, unique: bool, ) -> Handle<Expression> {
526 if let Some(h_new) = already_imported.borrow().get(&h_expr) {
527 return *h_new;
528 }
529
530 macro_rules! map_expr {
531 ($e:expr) => {
532 self.import_expression(
533 *$e,
534 old_expressions,
535 already_imported.clone(),
536 new_expressions.clone(),
537 non_emitting_only,
538 unique,
539 )
540 };
541 }
542
543 macro_rules! map_expr_opt {
544 ($e:expr) => {
545 $e.as_ref().map(|expr| {
546 self.import_expression(
547 *expr,
548 old_expressions,
549 already_imported.clone(),
550 new_expressions.clone(),
551 non_emitting_only,
552 unique,
553 )
554 })
555 };
556 }
557
558 let mut is_external = false;
559 let expr = old_expressions.try_get(h_expr).unwrap();
560 let expr = match expr {
561 Expression::Literal(_) => {
562 is_external = true;
563 expr.clone()
564 }
565 Expression::ZeroValue(zv) => {
566 is_external = true;
567 Expression::ZeroValue(self.import_type(zv))
568 }
569 Expression::CallResult(f) => Expression::CallResult(self.map_function_handle(f)),
570 Expression::Constant(c) => {
571 is_external = true;
572 Expression::Constant(self.import_const(c))
573 }
574 Expression::Compose { ty, components } => Expression::Compose {
575 ty: self.import_type(ty),
576 components: components.iter().map(|expr| map_expr!(expr)).collect(),
577 },
578 Expression::GlobalVariable(gv) => {
579 is_external = true;
580 Expression::GlobalVariable(self.import_global(gv))
581 }
582 Expression::ImageSample {
583 image,
584 sampler,
585 gather,
586 coordinate,
587 array_index,
588 offset,
589 level,
590 depth_ref,
591 } => Expression::ImageSample {
592 image: map_expr!(image),
593 sampler: map_expr!(sampler),
594 gather: *gather,
595 coordinate: map_expr!(coordinate),
596 array_index: map_expr_opt!(array_index),
597 offset: offset.map(|c| self.import_global_expression(c)),
598 level: match level {
599 SampleLevel::Auto | SampleLevel::Zero => *level,
600 SampleLevel::Exact(expr) => SampleLevel::Exact(map_expr!(expr)),
601 SampleLevel::Bias(expr) => SampleLevel::Bias(map_expr!(expr)),
602 SampleLevel::Gradient { x, y } => SampleLevel::Gradient {
603 x: map_expr!(x),
604 y: map_expr!(y),
605 },
606 },
607 depth_ref: map_expr_opt!(depth_ref),
608 },
609 Expression::Access { base, index } => Expression::Access {
610 base: map_expr!(base),
611 index: map_expr!(index),
612 },
613 Expression::AccessIndex { base, index } => Expression::AccessIndex {
614 base: map_expr!(base),
615 index: *index,
616 },
617 Expression::Splat { size, value } => Expression::Splat {
618 size: *size,
619 value: map_expr!(value),
620 },
621 Expression::Swizzle {
622 size,
623 vector,
624 pattern,
625 } => Expression::Swizzle {
626 size: *size,
627 vector: map_expr!(vector),
628 pattern: *pattern,
629 },
630 Expression::Load { pointer } => Expression::Load {
631 pointer: map_expr!(pointer),
632 },
633 Expression::ImageLoad {
634 image,
635 coordinate,
636 array_index,
637 sample,
638 level,
639 } => Expression::ImageLoad {
640 image: map_expr!(image),
641 coordinate: map_expr!(coordinate),
642 array_index: map_expr_opt!(array_index),
643 sample: map_expr_opt!(sample),
644 level: map_expr_opt!(level),
645 },
646 Expression::ImageQuery { image, query } => Expression::ImageQuery {
647 image: map_expr!(image),
648 query: match query {
649 ImageQuery::Size { level } => ImageQuery::Size {
650 level: map_expr_opt!(level),
651 },
652 _ => *query,
653 },
654 },
655 Expression::Unary { op, expr } => Expression::Unary {
656 op: *op,
657 expr: map_expr!(expr),
658 },
659 Expression::Binary { op, left, right } => Expression::Binary {
660 op: *op,
661 left: map_expr!(left),
662 right: map_expr!(right),
663 },
664 Expression::Select {
665 condition,
666 accept,
667 reject,
668 } => Expression::Select {
669 condition: map_expr!(condition),
670 accept: map_expr!(accept),
671 reject: map_expr!(reject),
672 },
673 Expression::Derivative { axis, expr, ctrl } => Expression::Derivative {
674 axis: *axis,
675 expr: map_expr!(expr),
676 ctrl: *ctrl,
677 },
678 Expression::Relational { fun, argument } => Expression::Relational {
679 fun: *fun,
680 argument: map_expr!(argument),
681 },
682 Expression::Math {
683 fun,
684 arg,
685 arg1,
686 arg2,
687 arg3,
688 } => Expression::Math {
689 fun: *fun,
690 arg: map_expr!(arg),
691 arg1: map_expr_opt!(arg1),
692 arg2: map_expr_opt!(arg2),
693 arg3: map_expr_opt!(arg3),
694 },
695 Expression::As {
696 expr,
697 kind,
698 convert,
699 } => Expression::As {
700 expr: map_expr!(expr),
701 kind: *kind,
702 convert: *convert,
703 },
704 Expression::ArrayLength(expr) => Expression::ArrayLength(map_expr!(expr)),
705
706 Expression::LocalVariable(_) | Expression::FunctionArgument(_) => {
707 is_external = true;
708 expr.clone()
709 }
710
711 Expression::AtomicResult { ty, comparison } => Expression::AtomicResult {
712 ty: self.import_type(ty),
713 comparison: *comparison,
714 },
715 Expression::WorkGroupUniformLoadResult { ty } => {
716 Expression::WorkGroupUniformLoadResult {
717 ty: self.import_type(ty),
718 }
719 }
720 Expression::RayQueryProceedResult => expr.clone(),
721 Expression::RayQueryGetIntersection { query, committed } => {
722 Expression::RayQueryGetIntersection {
723 query: map_expr!(query),
724 committed: *committed,
725 }
726 }
727 Expression::Override(h_override) => {
728 is_external = true;
729 Expression::Override(self.import_pipeline_override(h_override))
730 }
731 Expression::SubgroupBallotResult => expr.clone(),
732 Expression::SubgroupOperationResult { ty } => Expression::SubgroupOperationResult {
733 ty: self.import_type(ty),
734 },
735 };
736
737 if !non_emitting_only || is_external {
738 let span = old_expressions.get_span(h_expr);
739 let h_new = if unique {
740 new_expressions.borrow_mut().fetch_if_or_append(
741 expr,
742 self.map_span(span),
743 |lhs, rhs| lhs == rhs,
744 )
745 } else {
746 new_expressions
747 .borrow_mut()
748 .append(expr, self.map_span(span))
749 };
750
751 already_imported.borrow_mut().insert(h_expr, h_new);
752 h_new
753 } else {
754 h_expr
755 }
756 }
757
758 pub fn localize_function(&mut self, func: &Function) -> Function {
760 let arguments = func
761 .arguments
762 .iter()
763 .map(|arg| FunctionArgument {
764 name: arg.name.clone(),
765 ty: self.import_type(&arg.ty),
766 binding: arg.binding.clone(),
767 })
768 .collect();
769
770 let result = func.result.as_ref().map(|r| FunctionResult {
771 ty: self.import_type(&r.ty),
772 binding: r.binding.clone(),
773 });
774
775 let expressions = Rc::new(RefCell::new(Arena::new()));
776 let expr_map = Rc::new(RefCell::new(IndexMap::new()));
777
778 let mut local_variables = Arena::new();
779 for (h_l, l) in func.local_variables.iter() {
780 let new_local = LocalVariable {
781 name: l.name.clone(),
782 ty: self.import_type(&l.ty),
783 init: l.init.map(|c| {
784 self.import_expression(
785 c,
786 &func.expressions,
787 expr_map.clone(),
788 expressions.clone(),
789 false,
790 true,
791 )
792 }),
793 };
794 let span = func.local_variables.get_span(h_l);
795 let new_h = local_variables.append(new_local, self.map_span(span));
796 assert_eq!(h_l, new_h);
797 }
798
799 let body = self.import_block(
800 &func.body,
801 &func.expressions,
802 expr_map.clone(),
803 expressions.clone(),
804 );
805
806 let named_expressions = func
807 .named_expressions
808 .iter()
809 .flat_map(|(h_expr, name)| {
810 expr_map
811 .borrow()
812 .get(h_expr)
813 .map(|new_h| (*new_h, name.clone()))
814 })
815 .collect::<IndexMap<_, _, std::hash::BuildHasherDefault<rustc_hash::FxHasher>>>();
816
817 Function {
818 name: func.name.clone(),
819 arguments,
820 result,
821 local_variables,
822 expressions: Rc::try_unwrap(expressions).unwrap().into_inner(),
823 named_expressions,
824 body,
825 diagnostic_filter_leaf: None,
826 }
827 }
828
829 pub fn import_function(&mut self, func: &Function, span: Span) -> Handle<Function> {
833 let name = func.name.as_ref().unwrap().clone();
834 let mapped_func = self.localize_function(func);
835 let new_span = self.map_span(span);
836 let new_h = self.functions.append(mapped_func, new_span);
837 self.function_map.insert(name, new_h);
838 new_h
839 }
840
841 pub fn map_function_handle(&mut self, h_func: &Handle<Function>) -> Handle<Function> {
844 let functions = &self.shader.as_ref().unwrap().functions;
845 let func = functions.try_get(*h_func).unwrap();
846 let name = func.name.as_ref().unwrap();
847 self.function_map.get(name).copied().unwrap_or_else(|| {
848 let span = functions.get_span(*h_func);
849 self.import_function(func, span)
850 })
851 }
852
853 pub fn import_function_if_new(&mut self, func: &Function, span: Span) -> Handle<Function> {
856 let name = func.name.as_ref().unwrap().clone();
857 if let Some(h) = self.function_map.get(&name) {
858 return *h;
859 }
860
861 self.import_function(func, span)
862 }
863
864 pub fn has_required_special_types(&self) -> bool {
866 !self.special_types.predeclared_types.is_empty()
867 || self.special_types.ray_desc.is_some()
868 || self.special_types.ray_intersection.is_some()
869 }
870
871 pub fn into_module_with_entrypoints(mut self) -> naga::Module {
872 let entry_points = self
873 .shader
874 .unwrap()
875 .entry_points
876 .iter()
877 .map(|ep| EntryPoint {
878 name: ep.name.clone(),
879 stage: ep.stage,
880 early_depth_test: ep.early_depth_test,
881 workgroup_size: ep.workgroup_size,
882 function: self.localize_function(&ep.function),
883 workgroup_size_overrides: ep.workgroup_size_overrides,
884 })
885 .collect();
886
887 naga::Module {
888 entry_points,
889 ..self.into()
890 }
891 }
892}
893
894impl From<DerivedModule<'_>> for naga::Module {
895 fn from(derived: DerivedModule) -> Self {
896 naga::Module {
897 types: derived.types,
898 constants: derived.constants,
899 global_variables: derived.globals,
900 global_expressions: Rc::try_unwrap(derived.global_expressions)
901 .unwrap()
902 .into_inner(),
903 functions: derived.functions,
904 special_types: derived.special_types,
905 entry_points: Default::default(),
906 overrides: derived.pipeline_overrides,
907 diagnostic_filters: Default::default(),
908 diagnostic_filter_leaf: None,
909 }
910 }
911}