// naga_oil/derive.rs

1use indexmap::IndexMap;
2use naga::{
3    Arena, AtomicFunction, Block, Constant, EntryPoint, Expression, Function, FunctionArgument,
4    FunctionResult, GatherMode, GlobalVariable, Handle, ImageQuery, LocalVariable, Module,
5    Override, SampleLevel, Span, Statement, StructMember, SwitchCase, Type, TypeInner, UniqueArena,
6};
7use std::{cell::RefCell, rc::Rc};
8
9#[derive(Debug, Default)]
10pub struct DerivedModule<'a> {
11    shader: Option<&'a Module>,
12    span_offset: usize,
13
14    /// Maps the original type handle to the the mangled type handle.
15    type_map: IndexMap<Handle<Type>, Handle<Type>>,
16    /// Maps the original const handle to the the mangled const handle.
17    const_map: IndexMap<Handle<Constant>, Handle<Constant>>,
18    /// Maps the original pipeline override handle to the the mangled pipeline override handle.
19    pipeline_override_map: IndexMap<Handle<Override>, Handle<Override>>,
20    /// Contains both const expressions and pipeline override constant expressions.
21    /// The expressions are stored together because that's what Naga expects.
22    global_expressions: Rc<RefCell<Arena<Expression>>>,
23    /// Maps the original expression handle to the new expression handle for const expressions and pipeline override expressions.
24    /// The expressions are stored together because that's what Naga expects.
25    global_expression_map: Rc<RefCell<IndexMap<Handle<Expression>, Handle<Expression>>>>,
26    global_map: IndexMap<Handle<GlobalVariable>, Handle<GlobalVariable>>,
27    function_map: IndexMap<String, Handle<Function>>,
28    types: UniqueArena<Type>,
29    constants: Arena<Constant>,
30    globals: Arena<GlobalVariable>,
31    functions: Arena<Function>,
32    pipeline_overrides: Arena<Override>,
33}
34
35impl<'a> DerivedModule<'a> {
36    // set source context for import operations
37    pub fn set_shader_source(&mut self, shader: &'a Module, span_offset: usize) {
38        self.clear_shader_source();
39        self.shader = Some(shader);
40        self.span_offset = span_offset;
41    }
42
43    // detach source context
44    pub fn clear_shader_source(&mut self) {
45        self.shader = None;
46        self.type_map.clear();
47        self.const_map.clear();
48        self.global_map.clear();
49        self.global_expression_map.borrow_mut().clear();
50        self.pipeline_override_map.clear();
51    }
52
53    pub fn map_span(&self, span: Span) -> Span {
54        let span = span.to_range();
55        match span {
56            Some(rng) => Span::new(
57                (rng.start + self.span_offset) as u32,
58                (rng.end + self.span_offset) as u32,
59            ),
60            None => Span::UNDEFINED,
61        }
62    }
63
64    // remap a type from source context into our derived context
65    pub fn import_type(&mut self, h_type: &Handle<Type>) -> Handle<Type> {
66        self.rename_type(h_type, None)
67    }
68
69    // remap a type from source context into our derived context, and rename it
70    pub fn rename_type(&mut self, h_type: &Handle<Type>, name: Option<String>) -> Handle<Type> {
71        self.type_map.get(h_type).copied().unwrap_or_else(|| {
72            let ty = self
73                .shader
74                .as_ref()
75                .unwrap()
76                .types
77                .get_handle(*h_type)
78                .unwrap();
79
80            let name = match name {
81                Some(name) => Some(name),
82                None => ty.name.clone(),
83            };
84
85            let new_type = Type {
86                name,
87                inner: match &ty.inner {
88                    TypeInner::Scalar { .. }
89                    | TypeInner::Vector { .. }
90                    | TypeInner::Matrix { .. }
91                    | TypeInner::ValuePointer { .. }
92                    | TypeInner::Image { .. }
93                    | TypeInner::Sampler { .. }
94                    | TypeInner::Atomic { .. }
95                    | TypeInner::AccelerationStructure
96                    | TypeInner::RayQuery => ty.inner.clone(),
97
98                    TypeInner::Pointer { base, space } => TypeInner::Pointer {
99                        base: self.import_type(base),
100                        space: *space,
101                    },
102                    TypeInner::Struct { members, span } => {
103                        let members = members
104                            .iter()
105                            .map(|m| StructMember {
106                                name: m.name.clone(),
107                                ty: self.import_type(&m.ty),
108                                binding: m.binding.clone(),
109                                offset: m.offset,
110                            })
111                            .collect();
112                        TypeInner::Struct {
113                            members,
114                            span: *span,
115                        }
116                    }
117                    TypeInner::Array { base, size, stride } => TypeInner::Array {
118                        base: self.import_type(base),
119                        size: *size,
120                        stride: *stride,
121                    },
122                    TypeInner::BindingArray { base, size } => TypeInner::BindingArray {
123                        base: self.import_type(base),
124                        size: *size,
125                    },
126                },
127            };
128            let span = self.shader.as_ref().unwrap().types.get_span(*h_type);
129            let new_h = self.types.insert(new_type, self.map_span(span));
130            self.type_map.insert(*h_type, new_h);
131            new_h
132        })
133    }
134
135    // remap a const from source context into our derived context
136    pub fn import_const(&mut self, h_const: &Handle<Constant>) -> Handle<Constant> {
137        self.const_map.get(h_const).copied().unwrap_or_else(|| {
138            let c = self
139                .shader
140                .as_ref()
141                .unwrap()
142                .constants
143                .try_get(*h_const)
144                .unwrap();
145
146            let new_const = Constant {
147                name: c.name.clone(),
148                ty: self.import_type(&c.ty),
149                init: self.import_global_expression(c.init),
150            };
151
152            let span = self.shader.as_ref().unwrap().constants.get_span(*h_const);
153            let new_h = self
154                .constants
155                .fetch_or_append(new_const, self.map_span(span));
156            self.const_map.insert(*h_const, new_h);
157            new_h
158        })
159    }
160
161    // remap a global from source context into our derived context
162    pub fn import_global(&mut self, h_global: &Handle<GlobalVariable>) -> Handle<GlobalVariable> {
163        self.global_map.get(h_global).copied().unwrap_or_else(|| {
164            let gv = self
165                .shader
166                .as_ref()
167                .unwrap()
168                .global_variables
169                .try_get(*h_global)
170                .unwrap();
171
172            let new_global = GlobalVariable {
173                name: gv.name.clone(),
174                space: gv.space,
175                binding: gv.binding.clone(),
176                ty: self.import_type(&gv.ty),
177                init: gv.init.map(|c| self.import_global_expression(c)),
178            };
179
180            let span = self
181                .shader
182                .as_ref()
183                .unwrap()
184                .global_variables
185                .get_span(*h_global);
186            let new_h = self
187                .globals
188                .fetch_or_append(new_global, self.map_span(span));
189            self.global_map.insert(*h_global, new_h);
190            new_h
191        })
192    }
193
194    // remap either a const or pipeline override expression from source context into our derived context
195    pub fn import_global_expression(&mut self, h_expr: Handle<Expression>) -> Handle<Expression> {
196        self.import_expression(
197            h_expr,
198            &self.shader.as_ref().unwrap().global_expressions,
199            self.global_expression_map.clone(),
200            self.global_expressions.clone(),
201            false,
202            true,
203        )
204    }
205
206    // remap a pipeline override from source context into our derived context
207    pub fn import_pipeline_override(&mut self, h_override: &Handle<Override>) -> Handle<Override> {
208        self.pipeline_override_map
209            .get(h_override)
210            .copied()
211            .unwrap_or_else(|| {
212                let pipeline_override = self
213                    .shader
214                    .as_ref()
215                    .unwrap()
216                    .overrides
217                    .try_get(*h_override)
218                    .unwrap();
219
220                let new_override = Override {
221                    name: pipeline_override.name.clone(),
222                    id: pipeline_override.id,
223                    ty: self.import_type(&pipeline_override.ty),
224                    init: pipeline_override
225                        .init
226                        .map(|init| self.import_global_expression(init)),
227                };
228
229                let span = self
230                    .shader
231                    .as_ref()
232                    .unwrap()
233                    .overrides
234                    .get_span(*h_override);
235                let new_h = self
236                    .pipeline_overrides
237                    .fetch_or_append(new_override, self.map_span(span));
238                self.pipeline_override_map.insert(*h_override, new_h);
239                new_h
240            })
241    }
242
243    // remap a block
244    fn import_block(
245        &mut self,
246        block: &Block,
247        old_expressions: &Arena<Expression>,
248        already_imported: Rc<RefCell<IndexMap<Handle<Expression>, Handle<Expression>>>>,
249        new_expressions: Rc<RefCell<Arena<Expression>>>,
250    ) -> Block {
251        macro_rules! map_expr {
252            ($e:expr) => {
253                self.import_expression(
254                    *$e,
255                    old_expressions,
256                    already_imported.clone(),
257                    new_expressions.clone(),
258                    false,
259                    false,
260                )
261            };
262        }
263
264        macro_rules! map_expr_opt {
265            ($e:expr) => {
266                $e.as_ref().map(|expr| map_expr!(expr))
267            };
268        }
269
270        macro_rules! map_block {
271            ($b:expr) => {
272                self.import_block(
273                    $b,
274                    old_expressions,
275                    already_imported.clone(),
276                    new_expressions.clone(),
277                )
278            };
279        }
280
281        let statements = block
282            .iter()
283            .map(|stmt| {
284                match stmt {
285                    // remap function calls
286                    Statement::Call {
287                        function,
288                        arguments,
289                        result,
290                    } => Statement::Call {
291                        function: self.map_function_handle(function),
292                        arguments: arguments.iter().map(|expr| map_expr!(expr)).collect(),
293                        result: result.as_ref().map(|result| map_expr!(result)),
294                    },
295
296                    // recursively
297                    Statement::Block(b) => Statement::Block(map_block!(b)),
298                    Statement::If {
299                        condition,
300                        accept,
301                        reject,
302                    } => Statement::If {
303                        condition: map_expr!(condition),
304                        accept: map_block!(accept),
305                        reject: map_block!(reject),
306                    },
307                    Statement::Switch { selector, cases } => Statement::Switch {
308                        selector: map_expr!(selector),
309                        cases: cases
310                            .iter()
311                            .map(|case| SwitchCase {
312                                value: case.value,
313                                body: map_block!(&case.body),
314                                fall_through: case.fall_through,
315                            })
316                            .collect(),
317                    },
318                    Statement::Loop {
319                        body,
320                        continuing,
321                        break_if,
322                    } => Statement::Loop {
323                        body: map_block!(body),
324                        continuing: map_block!(continuing),
325                        break_if: map_expr_opt!(break_if),
326                    },
327
328                    // map expressions
329                    Statement::Emit(exprs) => {
330                        // iterate once to add expressions that should NOT be part of the emit statement
331                        for expr in exprs.clone() {
332                            self.import_expression(
333                                expr,
334                                old_expressions,
335                                already_imported.clone(),
336                                new_expressions.clone(),
337                                true,
338                                false,
339                            );
340                        }
341                        let old_length = new_expressions.borrow().len();
342                        // iterate again to add expressions that should be part of the emit statement
343                        for expr in exprs.clone() {
344                            map_expr!(&expr);
345                        }
346
347                        Statement::Emit(new_expressions.borrow().range_from(old_length))
348                    }
349                    Statement::Store { pointer, value } => Statement::Store {
350                        pointer: map_expr!(pointer),
351                        value: map_expr!(value),
352                    },
353                    Statement::ImageStore {
354                        image,
355                        coordinate,
356                        array_index,
357                        value,
358                    } => Statement::ImageStore {
359                        image: map_expr!(image),
360                        coordinate: map_expr!(coordinate),
361                        array_index: map_expr_opt!(array_index),
362                        value: map_expr!(value),
363                    },
364                    Statement::Atomic {
365                        pointer,
366                        fun,
367                        value,
368                        result,
369                    } => {
370                        let fun = match fun {
371                            AtomicFunction::Exchange {
372                                compare: Some(compare_expr),
373                            } => AtomicFunction::Exchange {
374                                compare: Some(map_expr!(compare_expr)),
375                            },
376                            fun => *fun,
377                        };
378                        Statement::Atomic {
379                            pointer: map_expr!(pointer),
380                            fun,
381                            value: map_expr!(value),
382                            result: map_expr_opt!(result),
383                        }
384                    }
385                    Statement::WorkGroupUniformLoad { pointer, result } => {
386                        Statement::WorkGroupUniformLoad {
387                            pointer: map_expr!(pointer),
388                            result: map_expr!(result),
389                        }
390                    }
391                    Statement::Return { value } => Statement::Return {
392                        value: map_expr_opt!(value),
393                    },
394                    Statement::RayQuery { query, fun } => Statement::RayQuery {
395                        query: map_expr!(query),
396                        fun: match fun {
397                            naga::RayQueryFunction::Initialize {
398                                acceleration_structure,
399                                descriptor,
400                            } => naga::RayQueryFunction::Initialize {
401                                acceleration_structure: map_expr!(acceleration_structure),
402                                descriptor: map_expr!(descriptor),
403                            },
404                            naga::RayQueryFunction::Proceed { result } => {
405                                naga::RayQueryFunction::Proceed {
406                                    result: map_expr!(result),
407                                }
408                            }
409                            naga::RayQueryFunction::Terminate => naga::RayQueryFunction::Terminate,
410                        },
411                    },
412                    Statement::SubgroupBallot { result, predicate } => Statement::SubgroupBallot {
413                        result: map_expr!(result),
414                        predicate: map_expr_opt!(predicate),
415                    },
416                    Statement::SubgroupGather {
417                        mut mode,
418                        argument,
419                        result,
420                    } => {
421                        match mode {
422                            GatherMode::BroadcastFirst => (),
423                            GatherMode::Broadcast(ref mut h_src)
424                            | GatherMode::Shuffle(ref mut h_src)
425                            | GatherMode::ShuffleDown(ref mut h_src)
426                            | GatherMode::ShuffleUp(ref mut h_src)
427                            | GatherMode::ShuffleXor(ref mut h_src) => *h_src = map_expr!(h_src),
428                        };
429                        Statement::SubgroupGather {
430                            mode,
431                            argument: map_expr!(argument),
432                            result: map_expr!(result),
433                        }
434                    }
435                    Statement::SubgroupCollectiveOperation {
436                        op,
437                        collective_op,
438                        argument,
439                        result,
440                    } => Statement::SubgroupCollectiveOperation {
441                        op: *op,
442                        collective_op: *collective_op,
443                        argument: map_expr!(argument),
444                        result: map_expr!(result),
445                    },
446                    // else just copy
447                    Statement::Break
448                    | Statement::Continue
449                    | Statement::Kill
450                    | Statement::Barrier(_) => stmt.clone(),
451                }
452            })
453            .collect();
454
455        let mut new_block = Block::from_vec(statements);
456
457        for ((_, new_span), (_, old_span)) in new_block.span_iter_mut().zip(block.span_iter()) {
458            *new_span.unwrap() = self.map_span(*old_span);
459        }
460
461        new_block
462    }
463
464    fn import_expression(
465        &mut self,
466        h_expr: Handle<Expression>,
467        old_expressions: &Arena<Expression>,
468        already_imported: Rc<RefCell<IndexMap<Handle<Expression>, Handle<Expression>>>>,
469        new_expressions: Rc<RefCell<Arena<Expression>>>,
470        non_emitting_only: bool, // only brings items that should NOT be emitted into scope
471        unique: bool,            // ensure expressions are unique with custom comparison
472    ) -> Handle<Expression> {
473        if let Some(h_new) = already_imported.borrow().get(&h_expr) {
474            return *h_new;
475        }
476
477        macro_rules! map_expr {
478            ($e:expr) => {
479                self.import_expression(
480                    *$e,
481                    old_expressions,
482                    already_imported.clone(),
483                    new_expressions.clone(),
484                    non_emitting_only,
485                    unique,
486                )
487            };
488        }
489
490        macro_rules! map_expr_opt {
491            ($e:expr) => {
492                $e.as_ref().map(|expr| {
493                    self.import_expression(
494                        *expr,
495                        old_expressions,
496                        already_imported.clone(),
497                        new_expressions.clone(),
498                        non_emitting_only,
499                        unique,
500                    )
501                })
502            };
503        }
504
505        let mut is_external = false;
506        let expr = old_expressions.try_get(h_expr).unwrap();
507        let expr = match expr {
508            Expression::Literal(_) => {
509                is_external = true;
510                expr.clone()
511            }
512            Expression::ZeroValue(zv) => {
513                is_external = true;
514                Expression::ZeroValue(self.import_type(zv))
515            }
516            Expression::CallResult(f) => Expression::CallResult(self.map_function_handle(f)),
517            Expression::Constant(c) => {
518                is_external = true;
519                Expression::Constant(self.import_const(c))
520            }
521            Expression::Compose { ty, components } => Expression::Compose {
522                ty: self.import_type(ty),
523                components: components.iter().map(|expr| map_expr!(expr)).collect(),
524            },
525            Expression::GlobalVariable(gv) => {
526                is_external = true;
527                Expression::GlobalVariable(self.import_global(gv))
528            }
529            Expression::ImageSample {
530                image,
531                sampler,
532                gather,
533                coordinate,
534                array_index,
535                offset,
536                level,
537                depth_ref,
538            } => Expression::ImageSample {
539                image: map_expr!(image),
540                sampler: map_expr!(sampler),
541                gather: *gather,
542                coordinate: map_expr!(coordinate),
543                array_index: map_expr_opt!(array_index),
544                offset: offset.map(|c| self.import_global_expression(c)),
545                level: match level {
546                    SampleLevel::Auto | SampleLevel::Zero => *level,
547                    SampleLevel::Exact(expr) => SampleLevel::Exact(map_expr!(expr)),
548                    SampleLevel::Bias(expr) => SampleLevel::Bias(map_expr!(expr)),
549                    SampleLevel::Gradient { x, y } => SampleLevel::Gradient {
550                        x: map_expr!(x),
551                        y: map_expr!(y),
552                    },
553                },
554                depth_ref: map_expr_opt!(depth_ref),
555            },
556            Expression::Access { base, index } => Expression::Access {
557                base: map_expr!(base),
558                index: map_expr!(index),
559            },
560            Expression::AccessIndex { base, index } => Expression::AccessIndex {
561                base: map_expr!(base),
562                index: *index,
563            },
564            Expression::Splat { size, value } => Expression::Splat {
565                size: *size,
566                value: map_expr!(value),
567            },
568            Expression::Swizzle {
569                size,
570                vector,
571                pattern,
572            } => Expression::Swizzle {
573                size: *size,
574                vector: map_expr!(vector),
575                pattern: *pattern,
576            },
577            Expression::Load { pointer } => Expression::Load {
578                pointer: map_expr!(pointer),
579            },
580            Expression::ImageLoad {
581                image,
582                coordinate,
583                array_index,
584                sample,
585                level,
586            } => Expression::ImageLoad {
587                image: map_expr!(image),
588                coordinate: map_expr!(coordinate),
589                array_index: map_expr_opt!(array_index),
590                sample: map_expr_opt!(sample),
591                level: map_expr_opt!(level),
592            },
593            Expression::ImageQuery { image, query } => Expression::ImageQuery {
594                image: map_expr!(image),
595                query: match query {
596                    ImageQuery::Size { level } => ImageQuery::Size {
597                        level: map_expr_opt!(level),
598                    },
599                    _ => *query,
600                },
601            },
602            Expression::Unary { op, expr } => Expression::Unary {
603                op: *op,
604                expr: map_expr!(expr),
605            },
606            Expression::Binary { op, left, right } => Expression::Binary {
607                op: *op,
608                left: map_expr!(left),
609                right: map_expr!(right),
610            },
611            Expression::Select {
612                condition,
613                accept,
614                reject,
615            } => Expression::Select {
616                condition: map_expr!(condition),
617                accept: map_expr!(accept),
618                reject: map_expr!(reject),
619            },
620            Expression::Derivative { axis, expr, ctrl } => Expression::Derivative {
621                axis: *axis,
622                expr: map_expr!(expr),
623                ctrl: *ctrl,
624            },
625            Expression::Relational { fun, argument } => Expression::Relational {
626                fun: *fun,
627                argument: map_expr!(argument),
628            },
629            Expression::Math {
630                fun,
631                arg,
632                arg1,
633                arg2,
634                arg3,
635            } => Expression::Math {
636                fun: *fun,
637                arg: map_expr!(arg),
638                arg1: map_expr_opt!(arg1),
639                arg2: map_expr_opt!(arg2),
640                arg3: map_expr_opt!(arg3),
641            },
642            Expression::As {
643                expr,
644                kind,
645                convert,
646            } => Expression::As {
647                expr: map_expr!(expr),
648                kind: *kind,
649                convert: *convert,
650            },
651            Expression::ArrayLength(expr) => Expression::ArrayLength(map_expr!(expr)),
652
653            Expression::LocalVariable(_) | Expression::FunctionArgument(_) => {
654                is_external = true;
655                expr.clone()
656            }
657
658            Expression::AtomicResult { ty, comparison } => Expression::AtomicResult {
659                ty: self.import_type(ty),
660                comparison: *comparison,
661            },
662            Expression::WorkGroupUniformLoadResult { ty } => {
663                Expression::WorkGroupUniformLoadResult {
664                    ty: self.import_type(ty),
665                }
666            }
667            Expression::RayQueryProceedResult => expr.clone(),
668            Expression::RayQueryGetIntersection { query, committed } => {
669                Expression::RayQueryGetIntersection {
670                    query: map_expr!(query),
671                    committed: *committed,
672                }
673            }
674            Expression::Override(h_override) => {
675                is_external = true;
676                Expression::Override(self.import_pipeline_override(h_override))
677            }
678            Expression::SubgroupBallotResult => expr.clone(),
679            Expression::SubgroupOperationResult { ty } => Expression::SubgroupOperationResult {
680                ty: self.import_type(ty),
681            },
682        };
683
684        if !non_emitting_only || is_external {
685            let span = old_expressions.get_span(h_expr);
686            let h_new = if unique {
687                new_expressions.borrow_mut().fetch_if_or_append(
688                    expr,
689                    self.map_span(span),
690                    |lhs, rhs| lhs == rhs,
691                )
692            } else {
693                new_expressions
694                    .borrow_mut()
695                    .append(expr, self.map_span(span))
696            };
697
698            already_imported.borrow_mut().insert(h_expr, h_new);
699            h_new
700        } else {
701            h_expr
702        }
703    }
704
705    // remap function global references (global vars, consts, types) into our derived context
706    pub fn localize_function(&mut self, func: &Function) -> Function {
707        let arguments = func
708            .arguments
709            .iter()
710            .map(|arg| FunctionArgument {
711                name: arg.name.clone(),
712                ty: self.import_type(&arg.ty),
713                binding: arg.binding.clone(),
714            })
715            .collect();
716
717        let result = func.result.as_ref().map(|r| FunctionResult {
718            ty: self.import_type(&r.ty),
719            binding: r.binding.clone(),
720        });
721
722        let expressions = Rc::new(RefCell::new(Arena::new()));
723        let expr_map = Rc::new(RefCell::new(IndexMap::new()));
724
725        let mut local_variables = Arena::new();
726        for (h_l, l) in func.local_variables.iter() {
727            let new_local = LocalVariable {
728                name: l.name.clone(),
729                ty: self.import_type(&l.ty),
730                init: l.init.map(|c| {
731                    self.import_expression(
732                        c,
733                        &func.expressions,
734                        expr_map.clone(),
735                        expressions.clone(),
736                        false,
737                        true,
738                    )
739                }),
740            };
741            let span = func.local_variables.get_span(h_l);
742            let new_h = local_variables.append(new_local, self.map_span(span));
743            assert_eq!(h_l, new_h);
744        }
745
746        let body = self.import_block(
747            &func.body,
748            &func.expressions,
749            expr_map.clone(),
750            expressions.clone(),
751        );
752
753        let named_expressions = func
754            .named_expressions
755            .iter()
756            .flat_map(|(h_expr, name)| {
757                expr_map
758                    .borrow()
759                    .get(h_expr)
760                    .map(|new_h| (*new_h, name.clone()))
761            })
762            .collect::<IndexMap<_, _, std::hash::BuildHasherDefault<rustc_hash::FxHasher>>>();
763
764        Function {
765            name: func.name.clone(),
766            arguments,
767            result,
768            local_variables,
769            expressions: Rc::try_unwrap(expressions).unwrap().into_inner(),
770            named_expressions,
771            body,
772        }
773    }
774
775    // import a function defined in the source shader context.
776    // func name may be already defined, the returned handle will refer to the new function.
777    // the previously defined function will still be valid.
778    pub fn import_function(&mut self, func: &Function, span: Span) -> Handle<Function> {
779        let name = func.name.as_ref().unwrap().clone();
780        let mapped_func = self.localize_function(func);
781        let new_span = self.map_span(span);
782        let new_h = self.functions.append(mapped_func, new_span);
783        self.function_map.insert(name, new_h);
784        new_h
785    }
786
787    // get the derived handle corresponding to the given source function handle
788    // requires func to be named
789    pub fn map_function_handle(&mut self, h_func: &Handle<Function>) -> Handle<Function> {
790        let functions = &self.shader.as_ref().unwrap().functions;
791        let func = functions.try_get(*h_func).unwrap();
792        let name = func.name.as_ref().unwrap();
793        self.function_map.get(name).copied().unwrap_or_else(|| {
794            let span = functions.get_span(*h_func);
795            self.import_function(func, span)
796        })
797    }
798
799    /// swap an already imported function for a new one.
800    /// note span cannot be updated
801    pub fn import_function_if_new(&mut self, func: &Function, span: Span) -> Handle<Function> {
802        let name = func.name.as_ref().unwrap().clone();
803        if let Some(h) = self.function_map.get(&name) {
804            return *h;
805        }
806
807        self.import_function(func, span)
808    }
809
810    pub fn into_module_with_entrypoints(mut self) -> naga::Module {
811        let entry_points = self
812            .shader
813            .unwrap()
814            .entry_points
815            .iter()
816            .map(|ep| EntryPoint {
817                name: ep.name.clone(),
818                stage: ep.stage,
819                early_depth_test: ep.early_depth_test,
820                workgroup_size: ep.workgroup_size,
821                function: self.localize_function(&ep.function),
822            })
823            .collect();
824
825        naga::Module {
826            entry_points,
827            ..self.into()
828        }
829    }
830}
831
832impl<'a> From<DerivedModule<'a>> for naga::Module {
833    fn from(derived: DerivedModule) -> Self {
834        naga::Module {
835            types: derived.types,
836            constants: derived.constants,
837            global_variables: derived.globals,
838            global_expressions: Rc::try_unwrap(derived.global_expressions)
839                .unwrap()
840                .into_inner(),
841            functions: derived.functions,
842            special_types: Default::default(),
843            entry_points: Default::default(),
844            overrides: derived.pipeline_overrides,
845        }
846    }
847}