naga_oil/
derive.rs

1use indexmap::IndexMap;
2use naga::{
3    Arena, AtomicFunction, Block, Constant, EntryPoint, Expression, Function, FunctionArgument,
4    FunctionResult, GatherMode, GlobalVariable, Handle, ImageQuery, LocalVariable, Module,
5    Override, SampleLevel, Span, Statement, StructMember, SwitchCase, Type, TypeInner, UniqueArena,
6};
7use std::{cell::RefCell, rc::Rc};
8
/// Accumulates items imported from one or more source [`Module`]s into a single
/// derived module, remapping every handle along the way.
///
/// A source is attached with `set_shader_source`. The `*_map` fields translate
/// handles of the current source into handles of the derived arenas and are reset by
/// `clear_shader_source`, while the derived arenas themselves (`types`, `constants`,
/// `globals`, `functions`, `pipeline_overrides`, `global_expressions`) keep their
/// contents across sources.
#[derive(Debug, Default)]
pub struct DerivedModule<'a> {
    /// Source module currently being imported from, if any.
    shader: Option<&'a Module>,
    /// Added to both ends of every span copied from the current source (see `map_span`).
    span_offset: usize,

    /// Maps the original type handle to the mangled type handle.
    type_map: IndexMap<Handle<Type>, Handle<Type>>,
    /// Maps the original const handle to the mangled const handle.
    const_map: IndexMap<Handle<Constant>, Handle<Constant>>,
    /// Maps the original pipeline override handle to the mangled pipeline override handle.
    pipeline_override_map: IndexMap<Handle<Override>, Handle<Override>>,
    /// Contains both const expressions and pipeline override constant expressions.
    /// The expressions are stored together because that's what Naga expects.
    /// Held in `Rc<RefCell<..>>` so it can be handed to `import_expression`
    /// alongside `&mut self`.
    global_expressions: Rc<RefCell<Arena<Expression>>>,
    /// Maps the original expression handle to the new expression handle for const expressions and pipeline override expressions.
    /// The expressions are stored together because that's what Naga expects.
    global_expression_map: Rc<RefCell<IndexMap<Handle<Expression>, Handle<Expression>>>>,
    /// Maps the original global variable handle to the derived global variable handle.
    global_map: IndexMap<Handle<GlobalVariable>, Handle<GlobalVariable>>,
    /// Derived function handles keyed by a string — presumably the (mangled) function
    /// name; it is populated outside this view. TODO confirm.
    function_map: IndexMap<String, Handle<Function>>,
    /// Type arena of the derived module.
    types: UniqueArena<Type>,
    /// Constant arena of the derived module.
    constants: Arena<Constant>,
    /// Global variable arena of the derived module.
    globals: Arena<GlobalVariable>,
    /// Function arena of the derived module.
    functions: Arena<Function>,
    /// Pipeline override arena of the derived module.
    pipeline_overrides: Arena<Override>,
    /// Special types already registered in the derived module; `set_shader_source`
    /// maps a source's special types onto these instead of importing duplicates.
    special_types: naga::SpecialTypes,
}
35
36impl<'a> DerivedModule<'a> {
37    // set source context for import operations
38    pub fn set_shader_source(&mut self, shader: &'a Module, span_offset: usize) {
39        self.clear_shader_source();
40        self.shader = Some(shader);
41        self.span_offset = span_offset;
42
43        // eagerly import special types
44        if let Some(h_special_type) = shader.special_types.ray_desc.as_ref() {
45            if let Some(derived_special_type) = self.special_types.ray_desc.as_ref() {
46                self.type_map.insert(*h_special_type, *derived_special_type);
47            } else {
48                self.special_types.ray_desc = Some(self.import_type(h_special_type));
49            }
50        }
51        if let Some(h_special_type) = shader.special_types.ray_intersection.as_ref() {
52            if let Some(derived_special_type) = self.special_types.ray_intersection.as_ref() {
53                self.type_map.insert(*h_special_type, *derived_special_type);
54            } else {
55                self.special_types.ray_intersection = Some(self.import_type(h_special_type));
56            }
57        }
58        for (predeclared, h_predeclared_type) in shader.special_types.predeclared_types.iter() {
59            if let Some(derived_special_type) =
60                self.special_types.predeclared_types.get(predeclared)
61            {
62                self.type_map
63                    .insert(*h_predeclared_type, *derived_special_type);
64            } else {
65                let new_h = self.import_type(h_predeclared_type);
66                self.special_types
67                    .predeclared_types
68                    .insert(predeclared.clone(), new_h);
69            }
70        }
71    }
72
73    // detach source context
74    pub fn clear_shader_source(&mut self) {
75        self.shader = None;
76        self.type_map.clear();
77        self.const_map.clear();
78        self.global_map.clear();
79        self.global_expression_map.borrow_mut().clear();
80        self.pipeline_override_map.clear();
81    }
82
83    pub fn map_span(&self, span: Span) -> Span {
84        let span = span.to_range();
85        match span {
86            Some(rng) => Span::new(
87                (rng.start + self.span_offset) as u32,
88                (rng.end + self.span_offset) as u32,
89            ),
90            None => Span::UNDEFINED,
91        }
92    }
93
94    // remap a type from source context into our derived context
95    pub fn import_type(&mut self, h_type: &Handle<Type>) -> Handle<Type> {
96        self.rename_type(h_type, None)
97    }
98
99    // remap a type from source context into our derived context, and rename it
100    pub fn rename_type(&mut self, h_type: &Handle<Type>, name: Option<String>) -> Handle<Type> {
101        self.type_map.get(h_type).copied().unwrap_or_else(|| {
102            let ty = self
103                .shader
104                .as_ref()
105                .unwrap()
106                .types
107                .get_handle(*h_type)
108                .unwrap();
109
110            let name = match name {
111                Some(name) => Some(name),
112                None => ty.name.clone(),
113            };
114
115            let new_type = Type {
116                name,
117                inner: match &ty.inner {
118                    TypeInner::Scalar { .. }
119                    | TypeInner::Vector { .. }
120                    | TypeInner::Matrix { .. }
121                    | TypeInner::ValuePointer { .. }
122                    | TypeInner::Image { .. }
123                    | TypeInner::Sampler { .. }
124                    | TypeInner::Atomic { .. }
125                    | TypeInner::AccelerationStructure
126                    | TypeInner::RayQuery => ty.inner.clone(),
127
128                    TypeInner::Pointer { base, space } => TypeInner::Pointer {
129                        base: self.import_type(base),
130                        space: *space,
131                    },
132                    TypeInner::Struct { members, span } => {
133                        let members = members
134                            .iter()
135                            .map(|m| StructMember {
136                                name: m.name.clone(),
137                                ty: self.import_type(&m.ty),
138                                binding: m.binding.clone(),
139                                offset: m.offset,
140                            })
141                            .collect();
142                        TypeInner::Struct {
143                            members,
144                            span: *span,
145                        }
146                    }
147                    TypeInner::Array { base, size, stride } => TypeInner::Array {
148                        base: self.import_type(base),
149                        size: *size,
150                        stride: *stride,
151                    },
152                    TypeInner::BindingArray { base, size } => TypeInner::BindingArray {
153                        base: self.import_type(base),
154                        size: *size,
155                    },
156                },
157            };
158            let span = self.shader.as_ref().unwrap().types.get_span(*h_type);
159            let new_h = self.types.insert(new_type, self.map_span(span));
160            self.type_map.insert(*h_type, new_h);
161            new_h
162        })
163    }
164
165    // remap a const from source context into our derived context
166    pub fn import_const(&mut self, h_const: &Handle<Constant>) -> Handle<Constant> {
167        self.const_map.get(h_const).copied().unwrap_or_else(|| {
168            let c = self
169                .shader
170                .as_ref()
171                .unwrap()
172                .constants
173                .try_get(*h_const)
174                .unwrap();
175
176            let new_const = Constant {
177                name: c.name.clone(),
178                ty: self.import_type(&c.ty),
179                init: self.import_global_expression(c.init),
180            };
181
182            let span = self.shader.as_ref().unwrap().constants.get_span(*h_const);
183            let new_h = self
184                .constants
185                .fetch_or_append(new_const, self.map_span(span));
186            self.const_map.insert(*h_const, new_h);
187            new_h
188        })
189    }
190
191    // remap a global from source context into our derived context
192    pub fn import_global(&mut self, h_global: &Handle<GlobalVariable>) -> Handle<GlobalVariable> {
193        self.global_map.get(h_global).copied().unwrap_or_else(|| {
194            let gv = self
195                .shader
196                .as_ref()
197                .unwrap()
198                .global_variables
199                .try_get(*h_global)
200                .unwrap();
201
202            let new_global = GlobalVariable {
203                name: gv.name.clone(),
204                space: gv.space,
205                binding: gv.binding.clone(),
206                ty: self.import_type(&gv.ty),
207                init: gv.init.map(|c| self.import_global_expression(c)),
208            };
209
210            let span = self
211                .shader
212                .as_ref()
213                .unwrap()
214                .global_variables
215                .get_span(*h_global);
216            let new_h = self
217                .globals
218                .fetch_or_append(new_global, self.map_span(span));
219            self.global_map.insert(*h_global, new_h);
220            new_h
221        })
222    }
223
224    // remap either a const or pipeline override expression from source context into our derived context
225    pub fn import_global_expression(&mut self, h_expr: Handle<Expression>) -> Handle<Expression> {
226        self.import_expression(
227            h_expr,
228            &self.shader.as_ref().unwrap().global_expressions,
229            self.global_expression_map.clone(),
230            self.global_expressions.clone(),
231            false,
232            true,
233        )
234    }
235
236    // remap a pipeline override from source context into our derived context
237    pub fn import_pipeline_override(&mut self, h_override: &Handle<Override>) -> Handle<Override> {
238        self.pipeline_override_map
239            .get(h_override)
240            .copied()
241            .unwrap_or_else(|| {
242                let pipeline_override = self
243                    .shader
244                    .as_ref()
245                    .unwrap()
246                    .overrides
247                    .try_get(*h_override)
248                    .unwrap();
249
250                let new_override = Override {
251                    name: pipeline_override.name.clone(),
252                    id: pipeline_override.id,
253                    ty: self.import_type(&pipeline_override.ty),
254                    init: pipeline_override
255                        .init
256                        .map(|init| self.import_global_expression(init)),
257                };
258
259                let span = self
260                    .shader
261                    .as_ref()
262                    .unwrap()
263                    .overrides
264                    .get_span(*h_override);
265                let new_h = self
266                    .pipeline_overrides
267                    .fetch_or_append(new_override, self.map_span(span));
268                self.pipeline_override_map.insert(*h_override, new_h);
269                new_h
270            })
271    }
272
    /// Rebuilds `block` for the derived module.
    ///
    /// - Expressions referenced by statements are remapped with `import_expression`
    ///   (`non_emitting_only = false`, `unique = false`).
    /// - Call targets are remapped with `map_function_handle`.
    /// - Nested blocks (`Block`, `If`, `Switch`, `Loop`) are rebuilt recursively.
    /// - Every statement span is shifted with `map_span`.
    ///
    /// `old_expressions` is the source function's expression arena, and
    /// `already_imported` maps its handles to handles in `new_expressions`. Both `Rc`s
    /// are passed on to the recursive calls, so nested blocks share one derived arena
    /// and one handle map.
    fn import_block(
        &mut self,
        block: &Block,
        old_expressions: &Arena<Expression>,
        already_imported: Rc<RefCell<IndexMap<Handle<Expression>, Handle<Expression>>>>,
        new_expressions: Rc<RefCell<Arena<Expression>>>,
    ) -> Block {
        // Import one expression handle (passed by reference) into the derived arena.
        macro_rules! map_expr {
            ($e:expr) => {
                self.import_expression(
                    *$e,
                    old_expressions,
                    already_imported.clone(),
                    new_expressions.clone(),
                    false,
                    false,
                )
            };
        }

        // Same as `map_expr!` for an `Option<Handle<Expression>>`.
        macro_rules! map_expr_opt {
            ($e:expr) => {
                $e.as_ref().map(|expr| map_expr!(expr))
            };
        }

        // Recursively rebuild a nested block against the same arenas and handle map.
        macro_rules! map_block {
            ($b:expr) => {
                self.import_block(
                    $b,
                    old_expressions,
                    already_imported.clone(),
                    new_expressions.clone(),
                )
            };
        }

        let statements = block
            .iter()
            .map(|stmt| {
                match stmt {
                    // remap function calls
                    Statement::Call {
                        function,
                        arguments,
                        result,
                    } => Statement::Call {
                        function: self.map_function_handle(function),
                        arguments: arguments.iter().map(|expr| map_expr!(expr)).collect(),
                        result: result.as_ref().map(|result| map_expr!(result)),
                    },

                    // recursively
                    Statement::Block(b) => Statement::Block(map_block!(b)),
                    Statement::If {
                        condition,
                        accept,
                        reject,
                    } => Statement::If {
                        condition: map_expr!(condition),
                        accept: map_block!(accept),
                        reject: map_block!(reject),
                    },
                    Statement::Switch { selector, cases } => Statement::Switch {
                        selector: map_expr!(selector),
                        cases: cases
                            .iter()
                            .map(|case| SwitchCase {
                                value: case.value,
                                body: map_block!(&case.body),
                                fall_through: case.fall_through,
                            })
                            .collect(),
                    },
                    Statement::Loop {
                        body,
                        continuing,
                        break_if,
                    } => Statement::Loop {
                        body: map_block!(body),
                        continuing: map_block!(continuing),
                        break_if: map_expr_opt!(break_if),
                    },

                    // map expressions
                    Statement::Emit(exprs) => {
                        // The derived `Emit` range is everything appended to `new_expressions`
                        // after `old_length`. Pass 1 (`non_emitting_only = true`) appends first
                        // the external expressions reachable from the emitted ones (literals,
                        // constants, globals, locals, arguments, overrides, ...), so they land
                        // *before* `old_length` and stay out of the range.
                        // iterate once to add expressions that should NOT be part of the emit statement
                        for expr in exprs.clone() {
                            self.import_expression(
                                expr,
                                old_expressions,
                                already_imported.clone(),
                                new_expressions.clone(),
                                true,
                                false,
                            );
                        }
                        let old_length = new_expressions.borrow().len();
                        // Pass 2 appends the emitted expressions themselves.
                        // NOTE(review): this assumes every non-external expression reached here
                        // is either already imported or part of this `Emit` range.
                        // iterate again to add expressions that should be part of the emit statement
                        for expr in exprs.clone() {
                            map_expr!(&expr);
                        }

                        Statement::Emit(new_expressions.borrow().range_from(old_length))
                    }
                    Statement::Store { pointer, value } => Statement::Store {
                        pointer: map_expr!(pointer),
                        value: map_expr!(value),
                    },
                    Statement::ImageStore {
                        image,
                        coordinate,
                        array_index,
                        value,
                    } => Statement::ImageStore {
                        image: map_expr!(image),
                        coordinate: map_expr!(coordinate),
                        array_index: map_expr_opt!(array_index),
                        value: map_expr!(value),
                    },
                    Statement::Atomic {
                        pointer,
                        fun,
                        value,
                        result,
                    } => {
                        // Only `Exchange { compare: Some(_) }` carries an expression handle.
                        // It is remapped before `pointer`/`value`, so any expression it needs
                        // is appended first.
                        let fun = match fun {
                            AtomicFunction::Exchange {
                                compare: Some(compare_expr),
                            } => AtomicFunction::Exchange {
                                compare: Some(map_expr!(compare_expr)),
                            },
                            fun => *fun,
                        };
                        Statement::Atomic {
                            pointer: map_expr!(pointer),
                            fun,
                            value: map_expr!(value),
                            result: map_expr_opt!(result),
                        }
                    }
                    Statement::WorkGroupUniformLoad { pointer, result } => {
                        Statement::WorkGroupUniformLoad {
                            pointer: map_expr!(pointer),
                            result: map_expr!(result),
                        }
                    }
                    Statement::Return { value } => Statement::Return {
                        value: map_expr_opt!(value),
                    },
                    Statement::RayQuery { query, fun } => Statement::RayQuery {
                        query: map_expr!(query),
                        fun: match fun {
                            naga::RayQueryFunction::Initialize {
                                acceleration_structure,
                                descriptor,
                            } => naga::RayQueryFunction::Initialize {
                                acceleration_structure: map_expr!(acceleration_structure),
                                descriptor: map_expr!(descriptor),
                            },
                            naga::RayQueryFunction::Proceed { result } => {
                                naga::RayQueryFunction::Proceed {
                                    result: map_expr!(result),
                                }
                            }
                            naga::RayQueryFunction::Terminate => naga::RayQueryFunction::Terminate,
                        },
                    },
                    Statement::SubgroupBallot { result, predicate } => Statement::SubgroupBallot {
                        result: map_expr!(result),
                        predicate: map_expr_opt!(predicate),
                    },
                    Statement::SubgroupGather {
                        mut mode,
                        argument,
                        result,
                    } => {
                        // `mut mode` binds `mode` by value, so the source handle inside it can
                        // be overwritten in place with the remapped one.
                        match mode {
                            GatherMode::BroadcastFirst => (),
                            GatherMode::Broadcast(ref mut h_src)
                            | GatherMode::Shuffle(ref mut h_src)
                            | GatherMode::ShuffleDown(ref mut h_src)
                            | GatherMode::ShuffleUp(ref mut h_src)
                            | GatherMode::ShuffleXor(ref mut h_src) => *h_src = map_expr!(h_src),
                        };
                        Statement::SubgroupGather {
                            mode,
                            argument: map_expr!(argument),
                            result: map_expr!(result),
                        }
                    }
                    Statement::SubgroupCollectiveOperation {
                        op,
                        collective_op,
                        argument,
                        result,
                    } => Statement::SubgroupCollectiveOperation {
                        op: *op,
                        collective_op: *collective_op,
                        argument: map_expr!(argument),
                        result: map_expr!(result),
                    },
                    Statement::ImageAtomic {
                        image,
                        coordinate,
                        array_index,
                        fun,
                        value,
                    } => {
                        // As in `Atomic`: only `Exchange { compare: Some(_) }` needs remapping,
                        // and it is remapped first.
                        let fun = match fun {
                            AtomicFunction::Exchange {
                                compare: Some(compare_expr),
                            } => AtomicFunction::Exchange {
                                compare: Some(map_expr!(compare_expr)),
                            },
                            fun => *fun,
                        };
                        Statement::ImageAtomic {
                            image: map_expr!(image),
                            coordinate: map_expr!(coordinate),
                            array_index: map_expr_opt!(array_index),
                            fun,
                            value: map_expr!(value),
                        }
                    }
                    // else just copy
                    Statement::Break
                    | Statement::Continue
                    | Statement::Kill
                    | Statement::Barrier(_) => stmt.clone(),
                }
            })
            .collect();

        let mut new_block = Block::from_vec(statements);

        // `statements` was built 1:1 and in order from `block`, so zipping pairs each new
        // statement with its source span. The `unwrap` panics if a span slot is missing.
        for ((_, new_span), (_, old_span)) in new_block.span_iter_mut().zip(block.span_iter()) {
            *new_span.unwrap() = self.map_span(*old_span);
        }

        new_block
    }
516
    /// Imports `h_expr` from `old_expressions` into `new_expressions` and returns the
    /// derived handle.
    ///
    /// Every expression `h_expr` refers to is imported first (depth-first), so children
    /// precede their parent in the derived arena. Results are cached in
    /// `already_imported`, and a cached handle is returned immediately. Referenced types,
    /// constants, globals and overrides go through the matching `import_*` method, and
    /// `CallResult` function handles go through `map_function_handle`.
    ///
    /// - `non_emitting_only`: when `true`, only expressions flagged `is_external` below
    ///   are appended and cached. These are literals, zero values, constants, globals,
    ///   local variables, function arguments and overrides. Any other expression is only
    ///   traversed to reach its external children: it is neither appended nor cached, and
    ///   the *source* handle `h_expr` is returned as a placeholder the caller must ignore.
    /// - `unique`: when `true`, an equal (`==`) expression already in `new_expressions` is
    ///   reused instead of appending a duplicate.
    ///
    /// Panics if `h_expr` is not a valid handle in `old_expressions`.
    fn import_expression(
        &mut self,
        h_expr: Handle<Expression>,
        old_expressions: &Arena<Expression>,
        already_imported: Rc<RefCell<IndexMap<Handle<Expression>, Handle<Expression>>>>,
        new_expressions: Rc<RefCell<Arena<Expression>>>,
        non_emitting_only: bool, // only brings items that should NOT be emitted into scope
        unique: bool,            // ensure expressions are unique with custom comparison
    ) -> Handle<Expression> {
        if let Some(h_new) = already_imported.borrow().get(&h_expr) {
            return *h_new;
        }

        // Import a child expression with the same mode flags as this call.
        macro_rules! map_expr {
            ($e:expr) => {
                self.import_expression(
                    *$e,
                    old_expressions,
                    already_imported.clone(),
                    new_expressions.clone(),
                    non_emitting_only,
                    unique,
                )
            };
        }

        // Same as `map_expr!` for an `Option<Handle<Expression>>`.
        macro_rules! map_expr_opt {
            ($e:expr) => {
                $e.as_ref().map(|expr| {
                    self.import_expression(
                        *expr,
                        old_expressions,
                        already_imported.clone(),
                        new_expressions.clone(),
                        non_emitting_only,
                        unique,
                    )
                })
            };
        }

        // Set for expressions that are treated as not belonging to an `Emit` range; only
        // these are appended when `non_emitting_only` is true.
        let mut is_external = false;
        let expr = old_expressions.try_get(h_expr).unwrap();
        let expr = match expr {
            Expression::Literal(_) => {
                is_external = true;
                expr.clone()
            }
            Expression::ZeroValue(zv) => {
                is_external = true;
                Expression::ZeroValue(self.import_type(zv))
            }
            Expression::CallResult(f) => Expression::CallResult(self.map_function_handle(f)),
            Expression::Constant(c) => {
                is_external = true;
                Expression::Constant(self.import_const(c))
            }
            Expression::Compose { ty, components } => Expression::Compose {
                ty: self.import_type(ty),
                components: components.iter().map(|expr| map_expr!(expr)).collect(),
            },
            Expression::GlobalVariable(gv) => {
                is_external = true;
                Expression::GlobalVariable(self.import_global(gv))
            }
            Expression::ImageSample {
                image,
                sampler,
                gather,
                coordinate,
                array_index,
                offset,
                level,
                depth_ref,
            } => Expression::ImageSample {
                image: map_expr!(image),
                sampler: map_expr!(sampler),
                gather: *gather,
                coordinate: map_expr!(coordinate),
                array_index: map_expr_opt!(array_index),
                // `offset` is imported into the global expression arena, not `new_expressions`.
                offset: offset.map(|c| self.import_global_expression(c)),
                level: match level {
                    SampleLevel::Auto | SampleLevel::Zero => *level,
                    SampleLevel::Exact(expr) => SampleLevel::Exact(map_expr!(expr)),
                    SampleLevel::Bias(expr) => SampleLevel::Bias(map_expr!(expr)),
                    SampleLevel::Gradient { x, y } => SampleLevel::Gradient {
                        x: map_expr!(x),
                        y: map_expr!(y),
                    },
                },
                depth_ref: map_expr_opt!(depth_ref),
            },
            Expression::Access { base, index } => Expression::Access {
                base: map_expr!(base),
                index: map_expr!(index),
            },
            Expression::AccessIndex { base, index } => Expression::AccessIndex {
                base: map_expr!(base),
                index: *index,
            },
            Expression::Splat { size, value } => Expression::Splat {
                size: *size,
                value: map_expr!(value),
            },
            Expression::Swizzle {
                size,
                vector,
                pattern,
            } => Expression::Swizzle {
                size: *size,
                vector: map_expr!(vector),
                pattern: *pattern,
            },
            Expression::Load { pointer } => Expression::Load {
                pointer: map_expr!(pointer),
            },
            Expression::ImageLoad {
                image,
                coordinate,
                array_index,
                sample,
                level,
            } => Expression::ImageLoad {
                image: map_expr!(image),
                coordinate: map_expr!(coordinate),
                array_index: map_expr_opt!(array_index),
                sample: map_expr_opt!(sample),
                level: map_expr_opt!(level),
            },
            Expression::ImageQuery { image, query } => Expression::ImageQuery {
                image: map_expr!(image),
                query: match query {
                    ImageQuery::Size { level } => ImageQuery::Size {
                        level: map_expr_opt!(level),
                    },
                    _ => *query,
                },
            },
            Expression::Unary { op, expr } => Expression::Unary {
                op: *op,
                expr: map_expr!(expr),
            },
            Expression::Binary { op, left, right } => Expression::Binary {
                op: *op,
                left: map_expr!(left),
                right: map_expr!(right),
            },
            Expression::Select {
                condition,
                accept,
                reject,
            } => Expression::Select {
                condition: map_expr!(condition),
                accept: map_expr!(accept),
                reject: map_expr!(reject),
            },
            Expression::Derivative { axis, expr, ctrl } => Expression::Derivative {
                axis: *axis,
                expr: map_expr!(expr),
                ctrl: *ctrl,
            },
            Expression::Relational { fun, argument } => Expression::Relational {
                fun: *fun,
                argument: map_expr!(argument),
            },
            Expression::Math {
                fun,
                arg,
                arg1,
                arg2,
                arg3,
            } => Expression::Math {
                fun: *fun,
                arg: map_expr!(arg),
                arg1: map_expr_opt!(arg1),
                arg2: map_expr_opt!(arg2),
                arg3: map_expr_opt!(arg3),
            },
            Expression::As {
                expr,
                kind,
                convert,
            } => Expression::As {
                expr: map_expr!(expr),
                kind: *kind,
                convert: *convert,
            },
            Expression::ArrayLength(expr) => Expression::ArrayLength(map_expr!(expr)),

            // Copied verbatim: their indices are kept as-is in the derived function
            // (`localize_function` asserts local variable handles are preserved).
            Expression::LocalVariable(_) | Expression::FunctionArgument(_) => {
                is_external = true;
                expr.clone()
            }

            Expression::AtomicResult { ty, comparison } => Expression::AtomicResult {
                ty: self.import_type(ty),
                comparison: *comparison,
            },
            Expression::WorkGroupUniformLoadResult { ty } => {
                Expression::WorkGroupUniformLoadResult {
                    ty: self.import_type(ty),
                }
            }
            Expression::RayQueryProceedResult => expr.clone(),
            Expression::RayQueryGetIntersection { query, committed } => {
                Expression::RayQueryGetIntersection {
                    query: map_expr!(query),
                    committed: *committed,
                }
            }
            Expression::Override(h_override) => {
                is_external = true;
                Expression::Override(self.import_pipeline_override(h_override))
            }
            Expression::SubgroupBallotResult => expr.clone(),
            Expression::SubgroupOperationResult { ty } => Expression::SubgroupOperationResult {
                ty: self.import_type(ty),
            },
        };

        if !non_emitting_only || is_external {
            let span = old_expressions.get_span(h_expr);
            let h_new = if unique {
                // Reuse an equal expression already in the derived arena, if any.
                new_expressions.borrow_mut().fetch_if_or_append(
                    expr,
                    self.map_span(span),
                    |lhs, rhs| lhs == rhs,
                )
            } else {
                new_expressions
                    .borrow_mut()
                    .append(expr, self.map_span(span))
            };

            already_imported.borrow_mut().insert(h_expr, h_new);
            h_new
        } else {
            // Non-external expression in `non_emitting_only` mode: it was traversed only to
            // import its external children. Return the source handle as a placeholder.
            h_expr
        }
    }
757
758    // remap function global references (global vars, consts, types) into our derived context
759    pub fn localize_function(&mut self, func: &Function) -> Function {
760        let arguments = func
761            .arguments
762            .iter()
763            .map(|arg| FunctionArgument {
764                name: arg.name.clone(),
765                ty: self.import_type(&arg.ty),
766                binding: arg.binding.clone(),
767            })
768            .collect();
769
770        let result = func.result.as_ref().map(|r| FunctionResult {
771            ty: self.import_type(&r.ty),
772            binding: r.binding.clone(),
773        });
774
775        let expressions = Rc::new(RefCell::new(Arena::new()));
776        let expr_map = Rc::new(RefCell::new(IndexMap::new()));
777
778        let mut local_variables = Arena::new();
779        for (h_l, l) in func.local_variables.iter() {
780            let new_local = LocalVariable {
781                name: l.name.clone(),
782                ty: self.import_type(&l.ty),
783                init: l.init.map(|c| {
784                    self.import_expression(
785                        c,
786                        &func.expressions,
787                        expr_map.clone(),
788                        expressions.clone(),
789                        false,
790                        true,
791                    )
792                }),
793            };
794            let span = func.local_variables.get_span(h_l);
795            let new_h = local_variables.append(new_local, self.map_span(span));
796            assert_eq!(h_l, new_h);
797        }
798
799        let body = self.import_block(
800            &func.body,
801            &func.expressions,
802            expr_map.clone(),
803            expressions.clone(),
804        );
805
806        let named_expressions = func
807            .named_expressions
808            .iter()
809            .flat_map(|(h_expr, name)| {
810                expr_map
811                    .borrow()
812                    .get(h_expr)
813                    .map(|new_h| (*new_h, name.clone()))
814            })
815            .collect::<IndexMap<_, _, std::hash::BuildHasherDefault<rustc_hash::FxHasher>>>();
816
817        Function {
818            name: func.name.clone(),
819            arguments,
820            result,
821            local_variables,
822            expressions: Rc::try_unwrap(expressions).unwrap().into_inner(),
823            named_expressions,
824            body,
825            diagnostic_filter_leaf: None,
826        }
827    }
828
829    // import a function defined in the source shader context.
830    // func name may be already defined, the returned handle will refer to the new function.
831    // the previously defined function will still be valid.
832    pub fn import_function(&mut self, func: &Function, span: Span) -> Handle<Function> {
833        let name = func.name.as_ref().unwrap().clone();
834        let mapped_func = self.localize_function(func);
835        let new_span = self.map_span(span);
836        let new_h = self.functions.append(mapped_func, new_span);
837        self.function_map.insert(name, new_h);
838        new_h
839    }
840
841    // get the derived handle corresponding to the given source function handle
842    // requires func to be named
843    pub fn map_function_handle(&mut self, h_func: &Handle<Function>) -> Handle<Function> {
844        let functions = &self.shader.as_ref().unwrap().functions;
845        let func = functions.try_get(*h_func).unwrap();
846        let name = func.name.as_ref().unwrap();
847        self.function_map.get(name).copied().unwrap_or_else(|| {
848            let span = functions.get_span(*h_func);
849            self.import_function(func, span)
850        })
851    }
852
853    /// swap an already imported function for a new one.
854    /// note span cannot be updated
855    pub fn import_function_if_new(&mut self, func: &Function, span: Span) -> Handle<Function> {
856        let name = func.name.as_ref().unwrap().clone();
857        if let Some(h) = self.function_map.get(&name) {
858            return *h;
859        }
860
861        self.import_function(func, span)
862    }
863
864    /// get any required special types for this module
865    pub fn has_required_special_types(&self) -> bool {
866        !self.special_types.predeclared_types.is_empty()
867            || self.special_types.ray_desc.is_some()
868            || self.special_types.ray_intersection.is_some()
869    }
870
871    pub fn into_module_with_entrypoints(mut self) -> naga::Module {
872        let entry_points = self
873            .shader
874            .unwrap()
875            .entry_points
876            .iter()
877            .map(|ep| EntryPoint {
878                name: ep.name.clone(),
879                stage: ep.stage,
880                early_depth_test: ep.early_depth_test,
881                workgroup_size: ep.workgroup_size,
882                function: self.localize_function(&ep.function),
883                workgroup_size_overrides: ep.workgroup_size_overrides,
884            })
885            .collect();
886
887        naga::Module {
888            entry_points,
889            ..self.into()
890        }
891    }
892}
893
894impl From<DerivedModule<'_>> for naga::Module {
895    fn from(derived: DerivedModule) -> Self {
896        naga::Module {
897            types: derived.types,
898            constants: derived.constants,
899            global_variables: derived.globals,
900            global_expressions: Rc::try_unwrap(derived.global_expressions)
901                .unwrap()
902                .into_inner(),
903            functions: derived.functions,
904            special_types: derived.special_types,
905            entry_points: Default::default(),
906            overrides: derived.pipeline_overrides,
907            diagnostic_filters: Default::default(),
908            diagnostic_filter_leaf: None,
909        }
910    }
911}