// naga_oil/derive.rs

1use indexmap::IndexMap;
2use naga::{
3    Arena, AtomicFunction, Block, Constant, EntryPoint, Expression, Function, FunctionArgument,
4    FunctionResult, GatherMode, GlobalVariable, Handle, ImageQuery, LocalVariable, Module,
5    Override, SampleLevel, Span, Statement, StructMember, SwitchCase, Type, TypeInner, UniqueArena,
6};
7use std::{cell::RefCell, rc::Rc};
8
/// Accumulates items imported from one or more source naga `Module`s into its own arenas,
/// remapping every handle from the source module's arenas to the corresponding handle here.
///
/// A source module is attached with `set_shader_source` and detached with
/// `clear_shader_source`; the per-source handle maps are reset whenever the source changes,
/// while the derived arenas keep growing.
#[derive(Debug, Default)]
pub struct DerivedModule<'a> {
    /// The source module currently being imported from, or `None` when detached.
    shader: Option<&'a Module>,
    /// Added to the byte range of every source span by `map_span`.
    span_offset: usize,

    /// Maps the original type handle to the corresponding handle in `types`.
    type_map: IndexMap<Handle<Type>, Handle<Type>>,
    /// Maps the original const handle to the corresponding handle in `constants`.
    const_map: IndexMap<Handle<Constant>, Handle<Constant>>,
    /// Maps the original pipeline override handle to the corresponding handle in
    /// `pipeline_overrides`.
    pipeline_override_map: IndexMap<Handle<Override>, Handle<Override>>,
    /// Contains both const expressions and pipeline override constant expressions.
    /// The expressions are stored together because that's what Naga expects.
    global_expressions: Rc<RefCell<Arena<Expression>>>,
    /// Maps the original expression handle to the new expression handle for const expressions and pipeline override expressions.
    /// The expressions are stored together because that's what Naga expects.
    global_expression_map: Rc<RefCell<IndexMap<Handle<Expression>, Handle<Expression>>>>,
    /// Maps the original global variable handle to the corresponding handle in `globals`.
    global_map: IndexMap<Handle<GlobalVariable>, Handle<GlobalVariable>>,
    /// Maps a function name to the most recently imported function with that name.
    /// Unlike the maps above, this is not reset by `clear_shader_source`.
    function_map: IndexMap<String, Handle<Function>>,
    /// Derived arenas that imported items are stored in.
    types: UniqueArena<Type>,
    constants: Arena<Constant>,
    globals: Arena<GlobalVariable>,
    functions: Arena<Function>,
    pipeline_overrides: Arena<Override>,
}
34
35impl<'a> DerivedModule<'a> {
36    // set source context for import operations
37    pub fn set_shader_source(&mut self, shader: &'a Module, span_offset: usize) {
38        self.clear_shader_source();
39        self.shader = Some(shader);
40        self.span_offset = span_offset;
41    }
42
43    // detach source context
44    pub fn clear_shader_source(&mut self) {
45        self.shader = None;
46        self.type_map.clear();
47        self.const_map.clear();
48        self.global_map.clear();
49        self.global_expression_map.borrow_mut().clear();
50        self.pipeline_override_map.clear();
51    }
52
53    pub fn map_span(&self, span: Span) -> Span {
54        let span = span.to_range();
55        match span {
56            Some(rng) => Span::new(
57                (rng.start + self.span_offset) as u32,
58                (rng.end + self.span_offset) as u32,
59            ),
60            None => Span::UNDEFINED,
61        }
62    }
63
64    // remap a type from source context into our derived context
65    pub fn import_type(&mut self, h_type: &Handle<Type>) -> Handle<Type> {
66        self.rename_type(h_type, None)
67    }
68
69    // remap a type from source context into our derived context, and rename it
70    pub fn rename_type(&mut self, h_type: &Handle<Type>, name: Option<String>) -> Handle<Type> {
71        self.type_map.get(h_type).copied().unwrap_or_else(|| {
72            let ty = self
73                .shader
74                .as_ref()
75                .unwrap()
76                .types
77                .get_handle(*h_type)
78                .unwrap();
79
80            let name = match name {
81                Some(name) => Some(name),
82                None => ty.name.clone(),
83            };
84
85            let new_type = Type {
86                name,
87                inner: match &ty.inner {
88                    TypeInner::Scalar { .. }
89                    | TypeInner::Vector { .. }
90                    | TypeInner::Matrix { .. }
91                    | TypeInner::ValuePointer { .. }
92                    | TypeInner::Image { .. }
93                    | TypeInner::Sampler { .. }
94                    | TypeInner::Atomic { .. }
95                    | TypeInner::AccelerationStructure
96                    | TypeInner::RayQuery => ty.inner.clone(),
97
98                    TypeInner::Pointer { base, space } => TypeInner::Pointer {
99                        base: self.import_type(base),
100                        space: *space,
101                    },
102                    TypeInner::Struct { members, span } => {
103                        let members = members
104                            .iter()
105                            .map(|m| StructMember {
106                                name: m.name.clone(),
107                                ty: self.import_type(&m.ty),
108                                binding: m.binding.clone(),
109                                offset: m.offset,
110                            })
111                            .collect();
112                        TypeInner::Struct {
113                            members,
114                            span: *span,
115                        }
116                    }
117                    TypeInner::Array { base, size, stride } => TypeInner::Array {
118                        base: self.import_type(base),
119                        size: *size,
120                        stride: *stride,
121                    },
122                    TypeInner::BindingArray { base, size } => TypeInner::BindingArray {
123                        base: self.import_type(base),
124                        size: *size,
125                    },
126                },
127            };
128            let span = self.shader.as_ref().unwrap().types.get_span(*h_type);
129            let new_h = self.types.insert(new_type, self.map_span(span));
130            self.type_map.insert(*h_type, new_h);
131            new_h
132        })
133    }
134
135    // remap a const from source context into our derived context
136    pub fn import_const(&mut self, h_const: &Handle<Constant>) -> Handle<Constant> {
137        self.const_map.get(h_const).copied().unwrap_or_else(|| {
138            let c = self
139                .shader
140                .as_ref()
141                .unwrap()
142                .constants
143                .try_get(*h_const)
144                .unwrap();
145
146            let new_const = Constant {
147                name: c.name.clone(),
148                ty: self.import_type(&c.ty),
149                init: self.import_global_expression(c.init),
150            };
151
152            let span = self.shader.as_ref().unwrap().constants.get_span(*h_const);
153            let new_h = self
154                .constants
155                .fetch_or_append(new_const, self.map_span(span));
156            self.const_map.insert(*h_const, new_h);
157            new_h
158        })
159    }
160
161    // remap a global from source context into our derived context
162    pub fn import_global(&mut self, h_global: &Handle<GlobalVariable>) -> Handle<GlobalVariable> {
163        self.global_map.get(h_global).copied().unwrap_or_else(|| {
164            let gv = self
165                .shader
166                .as_ref()
167                .unwrap()
168                .global_variables
169                .try_get(*h_global)
170                .unwrap();
171
172            let new_global = GlobalVariable {
173                name: gv.name.clone(),
174                space: gv.space,
175                binding: gv.binding.clone(),
176                ty: self.import_type(&gv.ty),
177                init: gv.init.map(|c| self.import_global_expression(c)),
178            };
179
180            let span = self
181                .shader
182                .as_ref()
183                .unwrap()
184                .global_variables
185                .get_span(*h_global);
186            let new_h = self
187                .globals
188                .fetch_or_append(new_global, self.map_span(span));
189            self.global_map.insert(*h_global, new_h);
190            new_h
191        })
192    }
193
194    // remap either a const or pipeline override expression from source context into our derived context
195    pub fn import_global_expression(&mut self, h_expr: Handle<Expression>) -> Handle<Expression> {
196        self.import_expression(
197            h_expr,
198            &self.shader.as_ref().unwrap().global_expressions,
199            self.global_expression_map.clone(),
200            self.global_expressions.clone(),
201            false,
202            true,
203        )
204    }
205
206    // remap a pipeline override from source context into our derived context
207    pub fn import_pipeline_override(&mut self, h_override: &Handle<Override>) -> Handle<Override> {
208        self.pipeline_override_map
209            .get(h_override)
210            .copied()
211            .unwrap_or_else(|| {
212                let pipeline_override = self
213                    .shader
214                    .as_ref()
215                    .unwrap()
216                    .overrides
217                    .try_get(*h_override)
218                    .unwrap();
219
220                let new_override = Override {
221                    name: pipeline_override.name.clone(),
222                    id: pipeline_override.id,
223                    ty: self.import_type(&pipeline_override.ty),
224                    init: pipeline_override
225                        .init
226                        .map(|init| self.import_global_expression(init)),
227                };
228
229                let span = self
230                    .shader
231                    .as_ref()
232                    .unwrap()
233                    .overrides
234                    .get_span(*h_override);
235                let new_h = self
236                    .pipeline_overrides
237                    .fetch_or_append(new_override, self.map_span(span));
238                self.pipeline_override_map.insert(*h_override, new_h);
239                new_h
240            })
241    }
242
    /// Remaps every statement of `block` into the derived context, recursing into nested
    /// blocks.
    ///
    /// `old_expressions` is the source function's expression arena. Expressions are
    /// imported into `new_expressions`, and `already_imported` records the source → new
    /// handle mapping. Both are shared (via `Rc`) with the recursive calls and with
    /// `import_expression`. Function handles go through `map_function_handle`, and statement
    /// spans are shifted with `map_span`.
    fn import_block(
        &mut self,
        block: &Block,
        old_expressions: &Arena<Expression>,
        already_imported: Rc<RefCell<IndexMap<Handle<Expression>, Handle<Expression>>>>,
        new_expressions: Rc<RefCell<Arena<Expression>>>,
    ) -> Block {
        // Import an expression operand of a statement. It is fully imported
        // (`non_emitting_only = false`) and always appended, never deduplicated (`unique = false`).
        macro_rules! map_expr {
            ($e:expr) => {
                self.import_expression(
                    *$e,
                    old_expressions,
                    already_imported.clone(),
                    new_expressions.clone(),
                    false,
                    false,
                )
            };
        }

        macro_rules! map_expr_opt {
            ($e:expr) => {
                $e.as_ref().map(|expr| map_expr!(expr))
            };
        }

        // Recurse into a nested block, sharing the same expression arena and mapping.
        macro_rules! map_block {
            ($b:expr) => {
                self.import_block(
                    $b,
                    old_expressions,
                    already_imported.clone(),
                    new_expressions.clone(),
                )
            };
        }

        let statements = block
            .iter()
            .map(|stmt| {
                match stmt {
                    // remap function calls
                    Statement::Call {
                        function,
                        arguments,
                        result,
                    } => Statement::Call {
                        function: self.map_function_handle(function),
                        arguments: arguments.iter().map(|expr| map_expr!(expr)).collect(),
                        result: result.as_ref().map(|result| map_expr!(result)),
                    },

                    // recursively
                    Statement::Block(b) => Statement::Block(map_block!(b)),
                    Statement::If {
                        condition,
                        accept,
                        reject,
                    } => Statement::If {
                        condition: map_expr!(condition),
                        accept: map_block!(accept),
                        reject: map_block!(reject),
                    },
                    Statement::Switch { selector, cases } => Statement::Switch {
                        selector: map_expr!(selector),
                        cases: cases
                            .iter()
                            .map(|case| SwitchCase {
                                value: case.value,
                                body: map_block!(&case.body),
                                fall_through: case.fall_through,
                            })
                            .collect(),
                    },
                    Statement::Loop {
                        body,
                        continuing,
                        break_if,
                    } => Statement::Loop {
                        body: map_block!(body),
                        continuing: map_block!(continuing),
                        break_if: map_expr_opt!(break_if),
                    },

                    // map expressions
                    Statement::Emit(exprs) => {
                        // Pass 1: import only the operands that must NOT be covered by an
                        // `Emit` (the "external" expressions, see `import_expression`), so they
                        // are appended before `old_length`. The returned handles are ignored on
                        // purpose: in this mode a non-external expression yields its source handle.
                        for expr in exprs.clone() {
                            self.import_expression(
                                expr,
                                old_expressions,
                                already_imported.clone(),
                                new_expressions.clone(),
                                true,
                                false,
                            );
                        }
                        // Everything appended from here on forms the new emit range.
                        let old_length = new_expressions.borrow().len();
                        // Pass 2: import the emitted expressions themselves.
                        // NOTE(review): expressions already present in `already_imported` are
                        // returned from the cache and not re-appended, so they fall outside the
                        // new range.
                        for expr in exprs.clone() {
                            map_expr!(&expr);
                        }

                        Statement::Emit(new_expressions.borrow().range_from(old_length))
                    }
                    Statement::Store { pointer, value } => Statement::Store {
                        pointer: map_expr!(pointer),
                        value: map_expr!(value),
                    },
                    Statement::ImageStore {
                        image,
                        coordinate,
                        array_index,
                        value,
                    } => Statement::ImageStore {
                        image: map_expr!(image),
                        coordinate: map_expr!(coordinate),
                        array_index: map_expr_opt!(array_index),
                        value: map_expr!(value),
                    },
                    Statement::Atomic {
                        pointer,
                        fun,
                        value,
                        result,
                    } => {
                        // Only `Exchange` with a `compare` operand carries an expression
                        // handle; every other atomic function is copied as-is.
                        let fun = match fun {
                            AtomicFunction::Exchange {
                                compare: Some(compare_expr),
                            } => AtomicFunction::Exchange {
                                compare: Some(map_expr!(compare_expr)),
                            },
                            fun => *fun,
                        };
                        Statement::Atomic {
                            pointer: map_expr!(pointer),
                            fun,
                            value: map_expr!(value),
                            result: map_expr_opt!(result),
                        }
                    }
                    Statement::WorkGroupUniformLoad { pointer, result } => {
                        Statement::WorkGroupUniformLoad {
                            pointer: map_expr!(pointer),
                            result: map_expr!(result),
                        }
                    }
                    Statement::Return { value } => Statement::Return {
                        value: map_expr_opt!(value),
                    },
                    Statement::RayQuery { query, fun } => Statement::RayQuery {
                        query: map_expr!(query),
                        fun: match fun {
                            naga::RayQueryFunction::Initialize {
                                acceleration_structure,
                                descriptor,
                            } => naga::RayQueryFunction::Initialize {
                                acceleration_structure: map_expr!(acceleration_structure),
                                descriptor: map_expr!(descriptor),
                            },
                            naga::RayQueryFunction::Proceed { result } => {
                                naga::RayQueryFunction::Proceed {
                                    result: map_expr!(result),
                                }
                            }
                            naga::RayQueryFunction::Terminate => naga::RayQueryFunction::Terminate,
                        },
                    },
                    Statement::SubgroupBallot { result, predicate } => Statement::SubgroupBallot {
                        result: map_expr!(result),
                        predicate: map_expr_opt!(predicate),
                    },
                    // `mut mode` binds a copy of the gather mode, so its embedded source
                    // handle can be rewritten in place below.
                    Statement::SubgroupGather {
                        mut mode,
                        argument,
                        result,
                    } => {
                        match mode {
                            GatherMode::BroadcastFirst => (),
                            GatherMode::Broadcast(ref mut h_src)
                            | GatherMode::Shuffle(ref mut h_src)
                            | GatherMode::ShuffleDown(ref mut h_src)
                            | GatherMode::ShuffleUp(ref mut h_src)
                            | GatherMode::ShuffleXor(ref mut h_src) => *h_src = map_expr!(h_src),
                        };
                        Statement::SubgroupGather {
                            mode,
                            argument: map_expr!(argument),
                            result: map_expr!(result),
                        }
                    }
                    Statement::SubgroupCollectiveOperation {
                        op,
                        collective_op,
                        argument,
                        result,
                    } => Statement::SubgroupCollectiveOperation {
                        op: *op,
                        collective_op: *collective_op,
                        argument: map_expr!(argument),
                        result: map_expr!(result),
                    },
                    Statement::ImageAtomic {
                        image,
                        coordinate,
                        array_index,
                        fun,
                        value,
                    } => {
                        // Same handling of `Exchange { compare }` as for `Statement::Atomic`.
                        let fun = match fun {
                            AtomicFunction::Exchange {
                                compare: Some(compare_expr),
                            } => AtomicFunction::Exchange {
                                compare: Some(map_expr!(compare_expr)),
                            },
                            fun => *fun,
                        };
                        Statement::ImageAtomic {
                            image: map_expr!(image),
                            coordinate: map_expr!(coordinate),
                            array_index: map_expr_opt!(array_index),
                            fun,
                            value: map_expr!(value),
                        }
                    }
                    // else just copy
                    Statement::Break
                    | Statement::Continue
                    | Statement::Kill
                    | Statement::Barrier(_) => stmt.clone(),
                }
            })
            .collect();

        let mut new_block = Block::from_vec(statements);

        // Statements were mapped 1:1 and in order, so zipping pairs each new statement with
        // its source statement. Copy over the source span, shifted by `map_span`.
        // NOTE(review): the `unwrap` assumes `span_iter_mut` always yields `Some` span slots
        // for this naga version; confirm.
        for ((_, new_span), (_, old_span)) in new_block.span_iter_mut().zip(block.span_iter()) {
            *new_span.unwrap() = self.map_span(*old_span);
        }

        new_block
    }
486
    /// Imports the expression `h_expr` from `old_expressions` into `new_expressions` and
    /// returns its new handle.
    ///
    /// Operands are imported first, and in the field order written below; that order
    /// determines the handle numbering in `new_expressions`. Types, constants, globals,
    /// overrides and functions referenced by the expression are remapped through the
    /// corresponding `import_*` / `map_function_handle` methods.
    ///
    /// * `already_imported`: source → new handle cache. A hit returns immediately.
    /// * `non_emitting_only`: when `true`, only "external" expressions (literals, zero
    ///   values, constants, globals, local variables, function arguments, overrides) are
    ///   appended and cached. Any other expression is built (so its operands get imported)
    ///   and then discarded, and the SOURCE handle `h_expr` is returned unchanged. That
    ///   handle is not valid in `new_expressions`, so callers using this mode must ignore the
    ///   result.
    /// * `unique`: when `true`, an equal expression that already exists in `new_expressions`
    ///   is reused instead of appending a duplicate.
    fn import_expression(
        &mut self,
        h_expr: Handle<Expression>,
        old_expressions: &Arena<Expression>,
        already_imported: Rc<RefCell<IndexMap<Handle<Expression>, Handle<Expression>>>>,
        new_expressions: Rc<RefCell<Arena<Expression>>>,
        non_emitting_only: bool, // only brings items that should NOT be emitted into scope
        unique: bool,            // ensure expressions are unique with custom comparison
    ) -> Handle<Expression> {
        if let Some(h_new) = already_imported.borrow().get(&h_expr) {
            return *h_new;
        }

        // Operands are imported with the same `non_emitting_only` / `unique` mode as the parent.
        macro_rules! map_expr {
            ($e:expr) => {
                self.import_expression(
                    *$e,
                    old_expressions,
                    already_imported.clone(),
                    new_expressions.clone(),
                    non_emitting_only,
                    unique,
                )
            };
        }

        macro_rules! map_expr_opt {
            ($e:expr) => {
                $e.as_ref().map(|expr| {
                    self.import_expression(
                        *expr,
                        old_expressions,
                        already_imported.clone(),
                        new_expressions.clone(),
                        non_emitting_only,
                        unique,
                    )
                })
            };
        }

        // Set to `true` for the expressions that must not be covered by an `Emit`
        // (see `non_emitting_only`).
        let mut is_external = false;
        let expr = old_expressions.try_get(h_expr).unwrap();
        let expr = match expr {
            Expression::Literal(_) => {
                is_external = true;
                expr.clone()
            }
            Expression::ZeroValue(zv) => {
                is_external = true;
                Expression::ZeroValue(self.import_type(zv))
            }
            Expression::CallResult(f) => Expression::CallResult(self.map_function_handle(f)),
            Expression::Constant(c) => {
                is_external = true;
                Expression::Constant(self.import_const(c))
            }
            Expression::Compose { ty, components } => Expression::Compose {
                ty: self.import_type(ty),
                components: components.iter().map(|expr| map_expr!(expr)).collect(),
            },
            Expression::GlobalVariable(gv) => {
                is_external = true;
                Expression::GlobalVariable(self.import_global(gv))
            }
            Expression::ImageSample {
                image,
                sampler,
                gather,
                coordinate,
                array_index,
                offset,
                level,
                depth_ref,
            } => Expression::ImageSample {
                image: map_expr!(image),
                sampler: map_expr!(sampler),
                gather: *gather,
                coordinate: map_expr!(coordinate),
                array_index: map_expr_opt!(array_index),
                // `offset` is imported as a global expression, not from `old_expressions`.
                offset: offset.map(|c| self.import_global_expression(c)),
                level: match level {
                    SampleLevel::Auto | SampleLevel::Zero => *level,
                    SampleLevel::Exact(expr) => SampleLevel::Exact(map_expr!(expr)),
                    SampleLevel::Bias(expr) => SampleLevel::Bias(map_expr!(expr)),
                    SampleLevel::Gradient { x, y } => SampleLevel::Gradient {
                        x: map_expr!(x),
                        y: map_expr!(y),
                    },
                },
                depth_ref: map_expr_opt!(depth_ref),
            },
            Expression::Access { base, index } => Expression::Access {
                base: map_expr!(base),
                index: map_expr!(index),
            },
            Expression::AccessIndex { base, index } => Expression::AccessIndex {
                base: map_expr!(base),
                index: *index,
            },
            Expression::Splat { size, value } => Expression::Splat {
                size: *size,
                value: map_expr!(value),
            },
            Expression::Swizzle {
                size,
                vector,
                pattern,
            } => Expression::Swizzle {
                size: *size,
                vector: map_expr!(vector),
                pattern: *pattern,
            },
            Expression::Load { pointer } => Expression::Load {
                pointer: map_expr!(pointer),
            },
            Expression::ImageLoad {
                image,
                coordinate,
                array_index,
                sample,
                level,
            } => Expression::ImageLoad {
                image: map_expr!(image),
                coordinate: map_expr!(coordinate),
                array_index: map_expr_opt!(array_index),
                sample: map_expr_opt!(sample),
                level: map_expr_opt!(level),
            },
            Expression::ImageQuery { image, query } => Expression::ImageQuery {
                image: map_expr!(image),
                query: match query {
                    ImageQuery::Size { level } => ImageQuery::Size {
                        level: map_expr_opt!(level),
                    },
                    _ => *query,
                },
            },
            Expression::Unary { op, expr } => Expression::Unary {
                op: *op,
                expr: map_expr!(expr),
            },
            Expression::Binary { op, left, right } => Expression::Binary {
                op: *op,
                left: map_expr!(left),
                right: map_expr!(right),
            },
            Expression::Select {
                condition,
                accept,
                reject,
            } => Expression::Select {
                condition: map_expr!(condition),
                accept: map_expr!(accept),
                reject: map_expr!(reject),
            },
            Expression::Derivative { axis, expr, ctrl } => Expression::Derivative {
                axis: *axis,
                expr: map_expr!(expr),
                ctrl: *ctrl,
            },
            Expression::Relational { fun, argument } => Expression::Relational {
                fun: *fun,
                argument: map_expr!(argument),
            },
            Expression::Math {
                fun,
                arg,
                arg1,
                arg2,
                arg3,
            } => Expression::Math {
                fun: *fun,
                arg: map_expr!(arg),
                arg1: map_expr_opt!(arg1),
                arg2: map_expr_opt!(arg2),
                arg3: map_expr_opt!(arg3),
            },
            Expression::As {
                expr,
                kind,
                convert,
            } => Expression::As {
                expr: map_expr!(expr),
                kind: *kind,
                convert: *convert,
            },
            Expression::ArrayLength(expr) => Expression::ArrayLength(map_expr!(expr)),

            // Copied unchanged. `localize_function` keeps local-variable handles identical
            // (it asserts this) and imports arguments in their original order.
            Expression::LocalVariable(_) | Expression::FunctionArgument(_) => {
                is_external = true;
                expr.clone()
            }

            Expression::AtomicResult { ty, comparison } => Expression::AtomicResult {
                ty: self.import_type(ty),
                comparison: *comparison,
            },
            Expression::WorkGroupUniformLoadResult { ty } => {
                Expression::WorkGroupUniformLoadResult {
                    ty: self.import_type(ty),
                }
            }
            Expression::RayQueryProceedResult => expr.clone(),
            Expression::RayQueryGetIntersection { query, committed } => {
                Expression::RayQueryGetIntersection {
                    query: map_expr!(query),
                    committed: *committed,
                }
            }
            Expression::Override(h_override) => {
                is_external = true;
                Expression::Override(self.import_pipeline_override(h_override))
            }
            Expression::SubgroupBallotResult => expr.clone(),
            Expression::SubgroupOperationResult { ty } => Expression::SubgroupOperationResult {
                ty: self.import_type(ty),
            },
        };

        if !non_emitting_only || is_external {
            let span = old_expressions.get_span(h_expr);
            let h_new = if unique {
                // Reuse an equal expression already in the target arena, if any.
                new_expressions.borrow_mut().fetch_if_or_append(
                    expr,
                    self.map_span(span),
                    |lhs, rhs| lhs == rhs,
                )
            } else {
                new_expressions
                    .borrow_mut()
                    .append(expr, self.map_span(span))
            };

            already_imported.borrow_mut().insert(h_expr, h_new);
            h_new
        } else {
            // Pre-pass for a non-external expression: nothing was appended or cached.
            // This is the SOURCE handle, and callers in this mode discard it.
            h_expr
        }
    }
727
728    // remap function global references (global vars, consts, types) into our derived context
729    pub fn localize_function(&mut self, func: &Function) -> Function {
730        let arguments = func
731            .arguments
732            .iter()
733            .map(|arg| FunctionArgument {
734                name: arg.name.clone(),
735                ty: self.import_type(&arg.ty),
736                binding: arg.binding.clone(),
737            })
738            .collect();
739
740        let result = func.result.as_ref().map(|r| FunctionResult {
741            ty: self.import_type(&r.ty),
742            binding: r.binding.clone(),
743        });
744
745        let expressions = Rc::new(RefCell::new(Arena::new()));
746        let expr_map = Rc::new(RefCell::new(IndexMap::new()));
747
748        let mut local_variables = Arena::new();
749        for (h_l, l) in func.local_variables.iter() {
750            let new_local = LocalVariable {
751                name: l.name.clone(),
752                ty: self.import_type(&l.ty),
753                init: l.init.map(|c| {
754                    self.import_expression(
755                        c,
756                        &func.expressions,
757                        expr_map.clone(),
758                        expressions.clone(),
759                        false,
760                        true,
761                    )
762                }),
763            };
764            let span = func.local_variables.get_span(h_l);
765            let new_h = local_variables.append(new_local, self.map_span(span));
766            assert_eq!(h_l, new_h);
767        }
768
769        let body = self.import_block(
770            &func.body,
771            &func.expressions,
772            expr_map.clone(),
773            expressions.clone(),
774        );
775
776        let named_expressions = func
777            .named_expressions
778            .iter()
779            .flat_map(|(h_expr, name)| {
780                expr_map
781                    .borrow()
782                    .get(h_expr)
783                    .map(|new_h| (*new_h, name.clone()))
784            })
785            .collect::<IndexMap<_, _, std::hash::BuildHasherDefault<rustc_hash::FxHasher>>>();
786
787        Function {
788            name: func.name.clone(),
789            arguments,
790            result,
791            local_variables,
792            expressions: Rc::try_unwrap(expressions).unwrap().into_inner(),
793            named_expressions,
794            body,
795            diagnostic_filter_leaf: None,
796        }
797    }
798
799    // import a function defined in the source shader context.
800    // func name may be already defined, the returned handle will refer to the new function.
801    // the previously defined function will still be valid.
802    pub fn import_function(&mut self, func: &Function, span: Span) -> Handle<Function> {
803        let name = func.name.as_ref().unwrap().clone();
804        let mapped_func = self.localize_function(func);
805        let new_span = self.map_span(span);
806        let new_h = self.functions.append(mapped_func, new_span);
807        self.function_map.insert(name, new_h);
808        new_h
809    }
810
811    // get the derived handle corresponding to the given source function handle
812    // requires func to be named
813    pub fn map_function_handle(&mut self, h_func: &Handle<Function>) -> Handle<Function> {
814        let functions = &self.shader.as_ref().unwrap().functions;
815        let func = functions.try_get(*h_func).unwrap();
816        let name = func.name.as_ref().unwrap();
817        self.function_map.get(name).copied().unwrap_or_else(|| {
818            let span = functions.get_span(*h_func);
819            self.import_function(func, span)
820        })
821    }
822
823    /// swap an already imported function for a new one.
824    /// note span cannot be updated
825    pub fn import_function_if_new(&mut self, func: &Function, span: Span) -> Handle<Function> {
826        let name = func.name.as_ref().unwrap().clone();
827        if let Some(h) = self.function_map.get(&name) {
828            return *h;
829        }
830
831        self.import_function(func, span)
832    }
833
834    pub fn into_module_with_entrypoints(mut self) -> naga::Module {
835        let entry_points = self
836            .shader
837            .unwrap()
838            .entry_points
839            .iter()
840            .map(|ep| EntryPoint {
841                name: ep.name.clone(),
842                stage: ep.stage,
843                early_depth_test: ep.early_depth_test,
844                workgroup_size: ep.workgroup_size,
845                function: self.localize_function(&ep.function),
846                workgroup_size_overrides: ep.workgroup_size_overrides,
847            })
848            .collect();
849
850        naga::Module {
851            entry_points,
852            ..self.into()
853        }
854    }
855}
856
857impl From<DerivedModule<'_>> for naga::Module {
858    fn from(derived: DerivedModule) -> Self {
859        naga::Module {
860            types: derived.types,
861            constants: derived.constants,
862            global_variables: derived.globals,
863            global_expressions: Rc::try_unwrap(derived.global_expressions)
864                .unwrap()
865                .into_inner(),
866            functions: derived.functions,
867            special_types: Default::default(),
868            entry_points: Default::default(),
869            overrides: derived.pipeline_overrides,
870            diagnostic_filters: Default::default(),
871            diagnostic_filter_leaf: None,
872        }
873    }
874}