1use indexmap::IndexMap;
2use naga::{
3 Arena, AtomicFunction, Block, Constant, EntryPoint, Expression, Function, FunctionArgument,
4 FunctionResult, GatherMode, GlobalVariable, Handle, ImageQuery, LocalVariable, Module,
5 Override, SampleLevel, Span, Statement, StructMember, SwitchCase, Type, TypeInner, UniqueArena,
6};
7use std::{cell::RefCell, rc::Rc};
8
/// Incrementally builds a new `naga::Module` out of items copied from one or
/// more source modules.
///
/// A source module is selected with `set_shader_source`; the `import_*`
/// methods then copy types, constants, globals, overrides, expressions and
/// functions out of it, remapping every handle into this module's own arenas.
/// The handle maps below cache those remappings so that an item already
/// imported from the current source is not copied again.
#[derive(Debug, Default)]
pub struct DerivedModule<'a> {
    /// Module currently being imported from; `None` until `set_shader_source`
    /// is called, and again after `clear_shader_source`.
    shader: Option<&'a Module>,
    /// Byte offset added to every span copied out of `shader` (see `map_span`).
    span_offset: usize,

    /// Source type handle -> handle in `types`.
    type_map: IndexMap<Handle<Type>, Handle<Type>>,
    /// Source constant handle -> handle in `constants`.
    const_map: IndexMap<Handle<Constant>, Handle<Constant>>,
    /// Source override handle -> handle in `pipeline_overrides`.
    pipeline_override_map: IndexMap<Handle<Override>, Handle<Override>>,
    /// Global expression arena of the module being built. Kept behind
    /// `Rc<RefCell<_>>` because `import_expression` takes its destination
    /// arena in that form.
    global_expressions: Rc<RefCell<Arena<Expression>>>,
    /// Source global expression handle -> handle in `global_expressions`.
    global_expression_map: Rc<RefCell<IndexMap<Handle<Expression>, Handle<Expression>>>>,
    /// Source global variable handle -> handle in `globals`.
    global_map: IndexMap<Handle<GlobalVariable>, Handle<GlobalVariable>>,
    /// Function name -> handle in `functions`. Keyed by name rather than by
    /// source handle, and not cleared when the shader source changes.
    function_map: IndexMap<String, Handle<Function>>,
    // Arenas of the module being built; moved into a `naga::Module` by the
    // `From` impl.
    types: UniqueArena<Type>,
    constants: Arena<Constant>,
    globals: Arena<GlobalVariable>,
    functions: Arena<Function>,
    pipeline_overrides: Arena<Override>,
}
34
35impl<'a> DerivedModule<'a> {
36 pub fn set_shader_source(&mut self, shader: &'a Module, span_offset: usize) {
38 self.clear_shader_source();
39 self.shader = Some(shader);
40 self.span_offset = span_offset;
41 }
42
43 pub fn clear_shader_source(&mut self) {
45 self.shader = None;
46 self.type_map.clear();
47 self.const_map.clear();
48 self.global_map.clear();
49 self.global_expression_map.borrow_mut().clear();
50 self.pipeline_override_map.clear();
51 }
52
53 pub fn map_span(&self, span: Span) -> Span {
54 let span = span.to_range();
55 match span {
56 Some(rng) => Span::new(
57 (rng.start + self.span_offset) as u32,
58 (rng.end + self.span_offset) as u32,
59 ),
60 None => Span::UNDEFINED,
61 }
62 }
63
64 pub fn import_type(&mut self, h_type: &Handle<Type>) -> Handle<Type> {
66 self.rename_type(h_type, None)
67 }
68
69 pub fn rename_type(&mut self, h_type: &Handle<Type>, name: Option<String>) -> Handle<Type> {
71 self.type_map.get(h_type).copied().unwrap_or_else(|| {
72 let ty = self
73 .shader
74 .as_ref()
75 .unwrap()
76 .types
77 .get_handle(*h_type)
78 .unwrap();
79
80 let name = match name {
81 Some(name) => Some(name),
82 None => ty.name.clone(),
83 };
84
85 let new_type = Type {
86 name,
87 inner: match &ty.inner {
88 TypeInner::Scalar { .. }
89 | TypeInner::Vector { .. }
90 | TypeInner::Matrix { .. }
91 | TypeInner::ValuePointer { .. }
92 | TypeInner::Image { .. }
93 | TypeInner::Sampler { .. }
94 | TypeInner::Atomic { .. }
95 | TypeInner::AccelerationStructure
96 | TypeInner::RayQuery => ty.inner.clone(),
97
98 TypeInner::Pointer { base, space } => TypeInner::Pointer {
99 base: self.import_type(base),
100 space: *space,
101 },
102 TypeInner::Struct { members, span } => {
103 let members = members
104 .iter()
105 .map(|m| StructMember {
106 name: m.name.clone(),
107 ty: self.import_type(&m.ty),
108 binding: m.binding.clone(),
109 offset: m.offset,
110 })
111 .collect();
112 TypeInner::Struct {
113 members,
114 span: *span,
115 }
116 }
117 TypeInner::Array { base, size, stride } => TypeInner::Array {
118 base: self.import_type(base),
119 size: *size,
120 stride: *stride,
121 },
122 TypeInner::BindingArray { base, size } => TypeInner::BindingArray {
123 base: self.import_type(base),
124 size: *size,
125 },
126 },
127 };
128 let span = self.shader.as_ref().unwrap().types.get_span(*h_type);
129 let new_h = self.types.insert(new_type, self.map_span(span));
130 self.type_map.insert(*h_type, new_h);
131 new_h
132 })
133 }
134
135 pub fn import_const(&mut self, h_const: &Handle<Constant>) -> Handle<Constant> {
137 self.const_map.get(h_const).copied().unwrap_or_else(|| {
138 let c = self
139 .shader
140 .as_ref()
141 .unwrap()
142 .constants
143 .try_get(*h_const)
144 .unwrap();
145
146 let new_const = Constant {
147 name: c.name.clone(),
148 ty: self.import_type(&c.ty),
149 init: self.import_global_expression(c.init),
150 };
151
152 let span = self.shader.as_ref().unwrap().constants.get_span(*h_const);
153 let new_h = self
154 .constants
155 .fetch_or_append(new_const, self.map_span(span));
156 self.const_map.insert(*h_const, new_h);
157 new_h
158 })
159 }
160
161 pub fn import_global(&mut self, h_global: &Handle<GlobalVariable>) -> Handle<GlobalVariable> {
163 self.global_map.get(h_global).copied().unwrap_or_else(|| {
164 let gv = self
165 .shader
166 .as_ref()
167 .unwrap()
168 .global_variables
169 .try_get(*h_global)
170 .unwrap();
171
172 let new_global = GlobalVariable {
173 name: gv.name.clone(),
174 space: gv.space,
175 binding: gv.binding.clone(),
176 ty: self.import_type(&gv.ty),
177 init: gv.init.map(|c| self.import_global_expression(c)),
178 };
179
180 let span = self
181 .shader
182 .as_ref()
183 .unwrap()
184 .global_variables
185 .get_span(*h_global);
186 let new_h = self
187 .globals
188 .fetch_or_append(new_global, self.map_span(span));
189 self.global_map.insert(*h_global, new_h);
190 new_h
191 })
192 }
193
194 pub fn import_global_expression(&mut self, h_expr: Handle<Expression>) -> Handle<Expression> {
196 self.import_expression(
197 h_expr,
198 &self.shader.as_ref().unwrap().global_expressions,
199 self.global_expression_map.clone(),
200 self.global_expressions.clone(),
201 false,
202 true,
203 )
204 }
205
206 pub fn import_pipeline_override(&mut self, h_override: &Handle<Override>) -> Handle<Override> {
208 self.pipeline_override_map
209 .get(h_override)
210 .copied()
211 .unwrap_or_else(|| {
212 let pipeline_override = self
213 .shader
214 .as_ref()
215 .unwrap()
216 .overrides
217 .try_get(*h_override)
218 .unwrap();
219
220 let new_override = Override {
221 name: pipeline_override.name.clone(),
222 id: pipeline_override.id,
223 ty: self.import_type(&pipeline_override.ty),
224 init: pipeline_override
225 .init
226 .map(|init| self.import_global_expression(init)),
227 };
228
229 let span = self
230 .shader
231 .as_ref()
232 .unwrap()
233 .overrides
234 .get_span(*h_override);
235 let new_h = self
236 .pipeline_overrides
237 .fetch_or_append(new_override, self.map_span(span));
238 self.pipeline_override_map.insert(*h_override, new_h);
239 new_h
240 })
241 }
242
243 fn import_block(
245 &mut self,
246 block: &Block,
247 old_expressions: &Arena<Expression>,
248 already_imported: Rc<RefCell<IndexMap<Handle<Expression>, Handle<Expression>>>>,
249 new_expressions: Rc<RefCell<Arena<Expression>>>,
250 ) -> Block {
251 macro_rules! map_expr {
252 ($e:expr) => {
253 self.import_expression(
254 *$e,
255 old_expressions,
256 already_imported.clone(),
257 new_expressions.clone(),
258 false,
259 false,
260 )
261 };
262 }
263
264 macro_rules! map_expr_opt {
265 ($e:expr) => {
266 $e.as_ref().map(|expr| map_expr!(expr))
267 };
268 }
269
270 macro_rules! map_block {
271 ($b:expr) => {
272 self.import_block(
273 $b,
274 old_expressions,
275 already_imported.clone(),
276 new_expressions.clone(),
277 )
278 };
279 }
280
281 let statements = block
282 .iter()
283 .map(|stmt| {
284 match stmt {
285 Statement::Call {
287 function,
288 arguments,
289 result,
290 } => Statement::Call {
291 function: self.map_function_handle(function),
292 arguments: arguments.iter().map(|expr| map_expr!(expr)).collect(),
293 result: result.as_ref().map(|result| map_expr!(result)),
294 },
295
296 Statement::Block(b) => Statement::Block(map_block!(b)),
298 Statement::If {
299 condition,
300 accept,
301 reject,
302 } => Statement::If {
303 condition: map_expr!(condition),
304 accept: map_block!(accept),
305 reject: map_block!(reject),
306 },
307 Statement::Switch { selector, cases } => Statement::Switch {
308 selector: map_expr!(selector),
309 cases: cases
310 .iter()
311 .map(|case| SwitchCase {
312 value: case.value,
313 body: map_block!(&case.body),
314 fall_through: case.fall_through,
315 })
316 .collect(),
317 },
318 Statement::Loop {
319 body,
320 continuing,
321 break_if,
322 } => Statement::Loop {
323 body: map_block!(body),
324 continuing: map_block!(continuing),
325 break_if: map_expr_opt!(break_if),
326 },
327
328 Statement::Emit(exprs) => {
330 for expr in exprs.clone() {
332 self.import_expression(
333 expr,
334 old_expressions,
335 already_imported.clone(),
336 new_expressions.clone(),
337 true,
338 false,
339 );
340 }
341 let old_length = new_expressions.borrow().len();
342 for expr in exprs.clone() {
344 map_expr!(&expr);
345 }
346
347 Statement::Emit(new_expressions.borrow().range_from(old_length))
348 }
349 Statement::Store { pointer, value } => Statement::Store {
350 pointer: map_expr!(pointer),
351 value: map_expr!(value),
352 },
353 Statement::ImageStore {
354 image,
355 coordinate,
356 array_index,
357 value,
358 } => Statement::ImageStore {
359 image: map_expr!(image),
360 coordinate: map_expr!(coordinate),
361 array_index: map_expr_opt!(array_index),
362 value: map_expr!(value),
363 },
364 Statement::Atomic {
365 pointer,
366 fun,
367 value,
368 result,
369 } => {
370 let fun = match fun {
371 AtomicFunction::Exchange {
372 compare: Some(compare_expr),
373 } => AtomicFunction::Exchange {
374 compare: Some(map_expr!(compare_expr)),
375 },
376 fun => *fun,
377 };
378 Statement::Atomic {
379 pointer: map_expr!(pointer),
380 fun,
381 value: map_expr!(value),
382 result: map_expr_opt!(result),
383 }
384 }
385 Statement::WorkGroupUniformLoad { pointer, result } => {
386 Statement::WorkGroupUniformLoad {
387 pointer: map_expr!(pointer),
388 result: map_expr!(result),
389 }
390 }
391 Statement::Return { value } => Statement::Return {
392 value: map_expr_opt!(value),
393 },
394 Statement::RayQuery { query, fun } => Statement::RayQuery {
395 query: map_expr!(query),
396 fun: match fun {
397 naga::RayQueryFunction::Initialize {
398 acceleration_structure,
399 descriptor,
400 } => naga::RayQueryFunction::Initialize {
401 acceleration_structure: map_expr!(acceleration_structure),
402 descriptor: map_expr!(descriptor),
403 },
404 naga::RayQueryFunction::Proceed { result } => {
405 naga::RayQueryFunction::Proceed {
406 result: map_expr!(result),
407 }
408 }
409 naga::RayQueryFunction::Terminate => naga::RayQueryFunction::Terminate,
410 },
411 },
412 Statement::SubgroupBallot { result, predicate } => Statement::SubgroupBallot {
413 result: map_expr!(result),
414 predicate: map_expr_opt!(predicate),
415 },
416 Statement::SubgroupGather {
417 mut mode,
418 argument,
419 result,
420 } => {
421 match mode {
422 GatherMode::BroadcastFirst => (),
423 GatherMode::Broadcast(ref mut h_src)
424 | GatherMode::Shuffle(ref mut h_src)
425 | GatherMode::ShuffleDown(ref mut h_src)
426 | GatherMode::ShuffleUp(ref mut h_src)
427 | GatherMode::ShuffleXor(ref mut h_src) => *h_src = map_expr!(h_src),
428 };
429 Statement::SubgroupGather {
430 mode,
431 argument: map_expr!(argument),
432 result: map_expr!(result),
433 }
434 }
435 Statement::SubgroupCollectiveOperation {
436 op,
437 collective_op,
438 argument,
439 result,
440 } => Statement::SubgroupCollectiveOperation {
441 op: *op,
442 collective_op: *collective_op,
443 argument: map_expr!(argument),
444 result: map_expr!(result),
445 },
446 Statement::Break
448 | Statement::Continue
449 | Statement::Kill
450 | Statement::Barrier(_) => stmt.clone(),
451 }
452 })
453 .collect();
454
455 let mut new_block = Block::from_vec(statements);
456
457 for ((_, new_span), (_, old_span)) in new_block.span_iter_mut().zip(block.span_iter()) {
458 *new_span.unwrap() = self.map_span(*old_span);
459 }
460
461 new_block
462 }
463
464 fn import_expression(
465 &mut self,
466 h_expr: Handle<Expression>,
467 old_expressions: &Arena<Expression>,
468 already_imported: Rc<RefCell<IndexMap<Handle<Expression>, Handle<Expression>>>>,
469 new_expressions: Rc<RefCell<Arena<Expression>>>,
470 non_emitting_only: bool, unique: bool, ) -> Handle<Expression> {
473 if let Some(h_new) = already_imported.borrow().get(&h_expr) {
474 return *h_new;
475 }
476
477 macro_rules! map_expr {
478 ($e:expr) => {
479 self.import_expression(
480 *$e,
481 old_expressions,
482 already_imported.clone(),
483 new_expressions.clone(),
484 non_emitting_only,
485 unique,
486 )
487 };
488 }
489
490 macro_rules! map_expr_opt {
491 ($e:expr) => {
492 $e.as_ref().map(|expr| {
493 self.import_expression(
494 *expr,
495 old_expressions,
496 already_imported.clone(),
497 new_expressions.clone(),
498 non_emitting_only,
499 unique,
500 )
501 })
502 };
503 }
504
505 let mut is_external = false;
506 let expr = old_expressions.try_get(h_expr).unwrap();
507 let expr = match expr {
508 Expression::Literal(_) => {
509 is_external = true;
510 expr.clone()
511 }
512 Expression::ZeroValue(zv) => {
513 is_external = true;
514 Expression::ZeroValue(self.import_type(zv))
515 }
516 Expression::CallResult(f) => Expression::CallResult(self.map_function_handle(f)),
517 Expression::Constant(c) => {
518 is_external = true;
519 Expression::Constant(self.import_const(c))
520 }
521 Expression::Compose { ty, components } => Expression::Compose {
522 ty: self.import_type(ty),
523 components: components.iter().map(|expr| map_expr!(expr)).collect(),
524 },
525 Expression::GlobalVariable(gv) => {
526 is_external = true;
527 Expression::GlobalVariable(self.import_global(gv))
528 }
529 Expression::ImageSample {
530 image,
531 sampler,
532 gather,
533 coordinate,
534 array_index,
535 offset,
536 level,
537 depth_ref,
538 } => Expression::ImageSample {
539 image: map_expr!(image),
540 sampler: map_expr!(sampler),
541 gather: *gather,
542 coordinate: map_expr!(coordinate),
543 array_index: map_expr_opt!(array_index),
544 offset: offset.map(|c| self.import_global_expression(c)),
545 level: match level {
546 SampleLevel::Auto | SampleLevel::Zero => *level,
547 SampleLevel::Exact(expr) => SampleLevel::Exact(map_expr!(expr)),
548 SampleLevel::Bias(expr) => SampleLevel::Bias(map_expr!(expr)),
549 SampleLevel::Gradient { x, y } => SampleLevel::Gradient {
550 x: map_expr!(x),
551 y: map_expr!(y),
552 },
553 },
554 depth_ref: map_expr_opt!(depth_ref),
555 },
556 Expression::Access { base, index } => Expression::Access {
557 base: map_expr!(base),
558 index: map_expr!(index),
559 },
560 Expression::AccessIndex { base, index } => Expression::AccessIndex {
561 base: map_expr!(base),
562 index: *index,
563 },
564 Expression::Splat { size, value } => Expression::Splat {
565 size: *size,
566 value: map_expr!(value),
567 },
568 Expression::Swizzle {
569 size,
570 vector,
571 pattern,
572 } => Expression::Swizzle {
573 size: *size,
574 vector: map_expr!(vector),
575 pattern: *pattern,
576 },
577 Expression::Load { pointer } => Expression::Load {
578 pointer: map_expr!(pointer),
579 },
580 Expression::ImageLoad {
581 image,
582 coordinate,
583 array_index,
584 sample,
585 level,
586 } => Expression::ImageLoad {
587 image: map_expr!(image),
588 coordinate: map_expr!(coordinate),
589 array_index: map_expr_opt!(array_index),
590 sample: map_expr_opt!(sample),
591 level: map_expr_opt!(level),
592 },
593 Expression::ImageQuery { image, query } => Expression::ImageQuery {
594 image: map_expr!(image),
595 query: match query {
596 ImageQuery::Size { level } => ImageQuery::Size {
597 level: map_expr_opt!(level),
598 },
599 _ => *query,
600 },
601 },
602 Expression::Unary { op, expr } => Expression::Unary {
603 op: *op,
604 expr: map_expr!(expr),
605 },
606 Expression::Binary { op, left, right } => Expression::Binary {
607 op: *op,
608 left: map_expr!(left),
609 right: map_expr!(right),
610 },
611 Expression::Select {
612 condition,
613 accept,
614 reject,
615 } => Expression::Select {
616 condition: map_expr!(condition),
617 accept: map_expr!(accept),
618 reject: map_expr!(reject),
619 },
620 Expression::Derivative { axis, expr, ctrl } => Expression::Derivative {
621 axis: *axis,
622 expr: map_expr!(expr),
623 ctrl: *ctrl,
624 },
625 Expression::Relational { fun, argument } => Expression::Relational {
626 fun: *fun,
627 argument: map_expr!(argument),
628 },
629 Expression::Math {
630 fun,
631 arg,
632 arg1,
633 arg2,
634 arg3,
635 } => Expression::Math {
636 fun: *fun,
637 arg: map_expr!(arg),
638 arg1: map_expr_opt!(arg1),
639 arg2: map_expr_opt!(arg2),
640 arg3: map_expr_opt!(arg3),
641 },
642 Expression::As {
643 expr,
644 kind,
645 convert,
646 } => Expression::As {
647 expr: map_expr!(expr),
648 kind: *kind,
649 convert: *convert,
650 },
651 Expression::ArrayLength(expr) => Expression::ArrayLength(map_expr!(expr)),
652
653 Expression::LocalVariable(_) | Expression::FunctionArgument(_) => {
654 is_external = true;
655 expr.clone()
656 }
657
658 Expression::AtomicResult { ty, comparison } => Expression::AtomicResult {
659 ty: self.import_type(ty),
660 comparison: *comparison,
661 },
662 Expression::WorkGroupUniformLoadResult { ty } => {
663 Expression::WorkGroupUniformLoadResult {
664 ty: self.import_type(ty),
665 }
666 }
667 Expression::RayQueryProceedResult => expr.clone(),
668 Expression::RayQueryGetIntersection { query, committed } => {
669 Expression::RayQueryGetIntersection {
670 query: map_expr!(query),
671 committed: *committed,
672 }
673 }
674 Expression::Override(h_override) => {
675 is_external = true;
676 Expression::Override(self.import_pipeline_override(h_override))
677 }
678 Expression::SubgroupBallotResult => expr.clone(),
679 Expression::SubgroupOperationResult { ty } => Expression::SubgroupOperationResult {
680 ty: self.import_type(ty),
681 },
682 };
683
684 if !non_emitting_only || is_external {
685 let span = old_expressions.get_span(h_expr);
686 let h_new = if unique {
687 new_expressions.borrow_mut().fetch_if_or_append(
688 expr,
689 self.map_span(span),
690 |lhs, rhs| lhs == rhs,
691 )
692 } else {
693 new_expressions
694 .borrow_mut()
695 .append(expr, self.map_span(span))
696 };
697
698 already_imported.borrow_mut().insert(h_expr, h_new);
699 h_new
700 } else {
701 h_expr
702 }
703 }
704
705 pub fn localize_function(&mut self, func: &Function) -> Function {
707 let arguments = func
708 .arguments
709 .iter()
710 .map(|arg| FunctionArgument {
711 name: arg.name.clone(),
712 ty: self.import_type(&arg.ty),
713 binding: arg.binding.clone(),
714 })
715 .collect();
716
717 let result = func.result.as_ref().map(|r| FunctionResult {
718 ty: self.import_type(&r.ty),
719 binding: r.binding.clone(),
720 });
721
722 let expressions = Rc::new(RefCell::new(Arena::new()));
723 let expr_map = Rc::new(RefCell::new(IndexMap::new()));
724
725 let mut local_variables = Arena::new();
726 for (h_l, l) in func.local_variables.iter() {
727 let new_local = LocalVariable {
728 name: l.name.clone(),
729 ty: self.import_type(&l.ty),
730 init: l.init.map(|c| {
731 self.import_expression(
732 c,
733 &func.expressions,
734 expr_map.clone(),
735 expressions.clone(),
736 false,
737 true,
738 )
739 }),
740 };
741 let span = func.local_variables.get_span(h_l);
742 let new_h = local_variables.append(new_local, self.map_span(span));
743 assert_eq!(h_l, new_h);
744 }
745
746 let body = self.import_block(
747 &func.body,
748 &func.expressions,
749 expr_map.clone(),
750 expressions.clone(),
751 );
752
753 let named_expressions = func
754 .named_expressions
755 .iter()
756 .flat_map(|(h_expr, name)| {
757 expr_map
758 .borrow()
759 .get(h_expr)
760 .map(|new_h| (*new_h, name.clone()))
761 })
762 .collect::<IndexMap<_, _, std::hash::BuildHasherDefault<rustc_hash::FxHasher>>>();
763
764 Function {
765 name: func.name.clone(),
766 arguments,
767 result,
768 local_variables,
769 expressions: Rc::try_unwrap(expressions).unwrap().into_inner(),
770 named_expressions,
771 body,
772 }
773 }
774
775 pub fn import_function(&mut self, func: &Function, span: Span) -> Handle<Function> {
779 let name = func.name.as_ref().unwrap().clone();
780 let mapped_func = self.localize_function(func);
781 let new_span = self.map_span(span);
782 let new_h = self.functions.append(mapped_func, new_span);
783 self.function_map.insert(name, new_h);
784 new_h
785 }
786
787 pub fn map_function_handle(&mut self, h_func: &Handle<Function>) -> Handle<Function> {
790 let functions = &self.shader.as_ref().unwrap().functions;
791 let func = functions.try_get(*h_func).unwrap();
792 let name = func.name.as_ref().unwrap();
793 self.function_map.get(name).copied().unwrap_or_else(|| {
794 let span = functions.get_span(*h_func);
795 self.import_function(func, span)
796 })
797 }
798
799 pub fn import_function_if_new(&mut self, func: &Function, span: Span) -> Handle<Function> {
802 let name = func.name.as_ref().unwrap().clone();
803 if let Some(h) = self.function_map.get(&name) {
804 return *h;
805 }
806
807 self.import_function(func, span)
808 }
809
810 pub fn into_module_with_entrypoints(mut self) -> naga::Module {
811 let entry_points = self
812 .shader
813 .unwrap()
814 .entry_points
815 .iter()
816 .map(|ep| EntryPoint {
817 name: ep.name.clone(),
818 stage: ep.stage,
819 early_depth_test: ep.early_depth_test,
820 workgroup_size: ep.workgroup_size,
821 function: self.localize_function(&ep.function),
822 })
823 .collect();
824
825 naga::Module {
826 entry_points,
827 ..self.into()
828 }
829 }
830}
831
832impl<'a> From<DerivedModule<'a>> for naga::Module {
833 fn from(derived: DerivedModule) -> Self {
834 naga::Module {
835 types: derived.types,
836 constants: derived.constants,
837 global_variables: derived.globals,
838 global_expressions: Rc::try_unwrap(derived.global_expressions)
839 .unwrap()
840 .into_inner(),
841 functions: derived.functions,
842 special_types: Default::default(),
843 entry_points: Default::default(),
844 overrides: derived.pipeline_overrides,
845 }
846 }
847}