// naga_oil/derive.rs

1use indexmap::IndexMap;
2use naga::{
3    Arena, AtomicFunction, Block, Constant, EntryPoint, Expression, Function, FunctionArgument,
4    FunctionResult, GatherMode, GlobalVariable, Handle, ImageQuery, LocalVariable, Module,
5    Override, SampleLevel, Span, Statement, StructMember, SwitchCase, Type, TypeInner, UniqueArena,
6};
7use std::{cell::RefCell, rc::Rc};
8
9#[derive(Debug, Default)]
10pub struct DerivedModule<'a> {
11    shader: Option<&'a Module>,
12    span_offset: usize,
13
14    /// Maps the original type handle to the the mangled type handle.
15    type_map: IndexMap<Handle<Type>, Handle<Type>>,
16    /// Maps the original const handle to the the mangled const handle.
17    const_map: IndexMap<Handle<Constant>, Handle<Constant>>,
18    /// Maps the original pipeline override handle to the the mangled pipeline override handle.
19    pipeline_override_map: IndexMap<Handle<Override>, Handle<Override>>,
20    /// Contains both const expressions and pipeline override constant expressions.
21    /// The expressions are stored together because that's what Naga expects.
22    global_expressions: Rc<RefCell<Arena<Expression>>>,
23    /// Maps the original expression handle to the new expression handle for const expressions and pipeline override expressions.
24    /// The expressions are stored together because that's what Naga expects.
25    global_expression_map: Rc<RefCell<IndexMap<Handle<Expression>, Handle<Expression>>>>,
26    global_map: IndexMap<Handle<GlobalVariable>, Handle<GlobalVariable>>,
27    function_map: IndexMap<String, Handle<Function>>,
28    types: UniqueArena<Type>,
29    constants: Arena<Constant>,
30    globals: Arena<GlobalVariable>,
31    functions: Arena<Function>,
32    pipeline_overrides: Arena<Override>,
33    special_types: naga::SpecialTypes,
34}
35
36impl<'a> DerivedModule<'a> {
37    // set source context for import operations
38    pub fn set_shader_source(&mut self, shader: &'a Module, span_offset: usize) {
39        self.clear_shader_source();
40        self.shader = Some(shader);
41        self.span_offset = span_offset;
42
43        // eagerly import special types
44        if let Some(h_special_type) = shader.special_types.ray_desc.as_ref() {
45            if let Some(derived_special_type) = self.special_types.ray_desc.as_ref() {
46                self.type_map.insert(*h_special_type, *derived_special_type);
47            } else {
48                self.special_types.ray_desc = Some(self.import_type(h_special_type));
49            }
50        }
51        if let Some(h_special_type) = shader.special_types.ray_intersection.as_ref() {
52            if let Some(derived_special_type) = self.special_types.ray_intersection.as_ref() {
53                self.type_map.insert(*h_special_type, *derived_special_type);
54            } else {
55                self.special_types.ray_intersection = Some(self.import_type(h_special_type));
56            }
57        }
58        if let Some(h_special_type) = shader.special_types.ray_vertex_return.as_ref() {
59            if let Some(derived_special_type) = self.special_types.ray_vertex_return.as_ref() {
60                self.type_map.insert(*h_special_type, *derived_special_type);
61            } else {
62                self.special_types.ray_vertex_return = Some(self.import_type(h_special_type));
63            }
64        }
65        for (predeclared, h_predeclared_type) in shader.special_types.predeclared_types.iter() {
66            if let Some(derived_special_type) =
67                self.special_types.predeclared_types.get(predeclared)
68            {
69                self.type_map
70                    .insert(*h_predeclared_type, *derived_special_type);
71            } else {
72                let new_h = self.import_type(h_predeclared_type);
73                self.special_types
74                    .predeclared_types
75                    .insert(predeclared.clone(), new_h);
76            }
77        }
78    }
79
80    // detach source context
81    pub fn clear_shader_source(&mut self) {
82        self.shader = None;
83        self.type_map.clear();
84        self.const_map.clear();
85        self.global_map.clear();
86        self.global_expression_map.borrow_mut().clear();
87        self.pipeline_override_map.clear();
88    }
89
90    pub fn map_span(&self, span: Span) -> Span {
91        let span = span.to_range();
92        match span {
93            Some(rng) => Span::new(
94                (rng.start + self.span_offset) as u32,
95                (rng.end + self.span_offset) as u32,
96            ),
97            None => Span::UNDEFINED,
98        }
99    }
100
101    // remap a type from source context into our derived context
102    pub fn import_type(&mut self, h_type: &Handle<Type>) -> Handle<Type> {
103        self.rename_type(h_type, None)
104    }
105
106    // remap a type from source context into our derived context, and rename it
107    pub fn rename_type(&mut self, h_type: &Handle<Type>, name: Option<String>) -> Handle<Type> {
108        self.type_map.get(h_type).copied().unwrap_or_else(|| {
109            let ty = self
110                .shader
111                .as_ref()
112                .unwrap()
113                .types
114                .get_handle(*h_type)
115                .unwrap();
116
117            let name = match name {
118                Some(name) => Some(name),
119                None => ty.name.clone(),
120            };
121
122            let new_type = Type {
123                name,
124                inner: match &ty.inner {
125                    TypeInner::Scalar { .. }
126                    | TypeInner::Vector { .. }
127                    | TypeInner::Matrix { .. }
128                    | TypeInner::ValuePointer { .. }
129                    | TypeInner::Image { .. }
130                    | TypeInner::Sampler { .. }
131                    | TypeInner::Atomic { .. }
132                    | TypeInner::AccelerationStructure { .. }
133                    | TypeInner::RayQuery { .. } => ty.inner.clone(),
134
135                    TypeInner::Pointer { base, space } => TypeInner::Pointer {
136                        base: self.import_type(base),
137                        space: *space,
138                    },
139                    TypeInner::Struct { members, span } => {
140                        let members = members
141                            .iter()
142                            .map(|m| StructMember {
143                                name: m.name.clone(),
144                                ty: self.import_type(&m.ty),
145                                binding: m.binding.clone(),
146                                offset: m.offset,
147                            })
148                            .collect();
149                        TypeInner::Struct {
150                            members,
151                            span: *span,
152                        }
153                    }
154                    TypeInner::Array { base, size, stride } => TypeInner::Array {
155                        base: self.import_type(base),
156                        size: *size,
157                        stride: *stride,
158                    },
159                    TypeInner::BindingArray { base, size } => TypeInner::BindingArray {
160                        base: self.import_type(base),
161                        size: *size,
162                    },
163                },
164            };
165            let span = self.shader.as_ref().unwrap().types.get_span(*h_type);
166            let new_h = self.types.insert(new_type, self.map_span(span));
167            self.type_map.insert(*h_type, new_h);
168            new_h
169        })
170    }
171
172    // remap a const from source context into our derived context
173    pub fn import_const(&mut self, h_const: &Handle<Constant>) -> Handle<Constant> {
174        self.const_map.get(h_const).copied().unwrap_or_else(|| {
175            let c = self
176                .shader
177                .as_ref()
178                .unwrap()
179                .constants
180                .try_get(*h_const)
181                .unwrap();
182
183            let new_const = Constant {
184                name: c.name.clone(),
185                ty: self.import_type(&c.ty),
186                init: self.import_global_expression(c.init),
187            };
188
189            let span = self.shader.as_ref().unwrap().constants.get_span(*h_const);
190            let new_h = self
191                .constants
192                .fetch_or_append(new_const, self.map_span(span));
193            self.const_map.insert(*h_const, new_h);
194            new_h
195        })
196    }
197
198    // remap a global from source context into our derived context
199    pub fn import_global(&mut self, h_global: &Handle<GlobalVariable>) -> Handle<GlobalVariable> {
200        self.global_map.get(h_global).copied().unwrap_or_else(|| {
201            let gv = self
202                .shader
203                .as_ref()
204                .unwrap()
205                .global_variables
206                .try_get(*h_global)
207                .unwrap();
208
209            let new_global = GlobalVariable {
210                name: gv.name.clone(),
211                space: gv.space,
212                binding: gv.binding,
213                ty: self.import_type(&gv.ty),
214                init: gv.init.map(|c| self.import_global_expression(c)),
215            };
216
217            let span = self
218                .shader
219                .as_ref()
220                .unwrap()
221                .global_variables
222                .get_span(*h_global);
223            let new_h = self
224                .globals
225                .fetch_or_append(new_global, self.map_span(span));
226            self.global_map.insert(*h_global, new_h);
227            new_h
228        })
229    }
230
231    // remap either a const or pipeline override expression from source context into our derived context
232    pub fn import_global_expression(&mut self, h_expr: Handle<Expression>) -> Handle<Expression> {
233        self.import_expression(
234            h_expr,
235            &self.shader.as_ref().unwrap().global_expressions,
236            self.global_expression_map.clone(),
237            self.global_expressions.clone(),
238            false,
239            true,
240        )
241    }
242
243    // remap a pipeline override from source context into our derived context
244    pub fn import_pipeline_override(&mut self, h_override: &Handle<Override>) -> Handle<Override> {
245        self.pipeline_override_map
246            .get(h_override)
247            .copied()
248            .unwrap_or_else(|| {
249                let pipeline_override = self
250                    .shader
251                    .as_ref()
252                    .unwrap()
253                    .overrides
254                    .try_get(*h_override)
255                    .unwrap();
256
257                let new_override = Override {
258                    name: pipeline_override.name.clone(),
259                    id: pipeline_override.id,
260                    ty: self.import_type(&pipeline_override.ty),
261                    init: pipeline_override
262                        .init
263                        .map(|init| self.import_global_expression(init)),
264                };
265
266                let span = self
267                    .shader
268                    .as_ref()
269                    .unwrap()
270                    .overrides
271                    .get_span(*h_override);
272                let new_h = self
273                    .pipeline_overrides
274                    .fetch_or_append(new_override, self.map_span(span));
275                self.pipeline_override_map.insert(*h_override, new_h);
276                new_h
277            })
278    }
279
280    // remap a block
281    fn import_block(
282        &mut self,
283        block: &Block,
284        old_expressions: &Arena<Expression>,
285        already_imported: Rc<RefCell<IndexMap<Handle<Expression>, Handle<Expression>>>>,
286        new_expressions: Rc<RefCell<Arena<Expression>>>,
287    ) -> Block {
288        macro_rules! map_expr {
289            ($e:expr) => {
290                self.import_expression(
291                    *$e,
292                    old_expressions,
293                    already_imported.clone(),
294                    new_expressions.clone(),
295                    false,
296                    false,
297                )
298            };
299        }
300
301        macro_rules! map_expr_opt {
302            ($e:expr) => {
303                $e.as_ref().map(|expr| map_expr!(expr))
304            };
305        }
306
307        macro_rules! map_block {
308            ($b:expr) => {
309                self.import_block(
310                    $b,
311                    old_expressions,
312                    already_imported.clone(),
313                    new_expressions.clone(),
314                )
315            };
316        }
317
318        let statements = block
319            .iter()
320            .map(|stmt| {
321                match stmt {
322                    // remap function calls
323                    Statement::Call {
324                        function,
325                        arguments,
326                        result,
327                    } => Statement::Call {
328                        function: self.map_function_handle(function),
329                        arguments: arguments.iter().map(|expr| map_expr!(expr)).collect(),
330                        result: result.as_ref().map(|result| map_expr!(result)),
331                    },
332
333                    // recursively
334                    Statement::Block(b) => Statement::Block(map_block!(b)),
335                    Statement::If {
336                        condition,
337                        accept,
338                        reject,
339                    } => Statement::If {
340                        condition: map_expr!(condition),
341                        accept: map_block!(accept),
342                        reject: map_block!(reject),
343                    },
344                    Statement::Switch { selector, cases } => Statement::Switch {
345                        selector: map_expr!(selector),
346                        cases: cases
347                            .iter()
348                            .map(|case| SwitchCase {
349                                value: case.value,
350                                body: map_block!(&case.body),
351                                fall_through: case.fall_through,
352                            })
353                            .collect(),
354                    },
355                    Statement::Loop {
356                        body,
357                        continuing,
358                        break_if,
359                    } => Statement::Loop {
360                        body: map_block!(body),
361                        continuing: map_block!(continuing),
362                        break_if: map_expr_opt!(break_if),
363                    },
364
365                    // map expressions
366                    Statement::Emit(exprs) => {
367                        // iterate once to add expressions that should NOT be part of the emit statement
368                        for expr in exprs.clone() {
369                            self.import_expression(
370                                expr,
371                                old_expressions,
372                                already_imported.clone(),
373                                new_expressions.clone(),
374                                true,
375                                false,
376                            );
377                        }
378                        let old_length = new_expressions.borrow().len();
379                        // iterate again to add expressions that should be part of the emit statement
380                        for expr in exprs.clone() {
381                            map_expr!(&expr);
382                        }
383
384                        Statement::Emit(new_expressions.borrow().range_from(old_length))
385                    }
386                    Statement::Store { pointer, value } => Statement::Store {
387                        pointer: map_expr!(pointer),
388                        value: map_expr!(value),
389                    },
390                    Statement::ImageStore {
391                        image,
392                        coordinate,
393                        array_index,
394                        value,
395                    } => Statement::ImageStore {
396                        image: map_expr!(image),
397                        coordinate: map_expr!(coordinate),
398                        array_index: map_expr_opt!(array_index),
399                        value: map_expr!(value),
400                    },
401                    Statement::Atomic {
402                        pointer,
403                        fun,
404                        value,
405                        result,
406                    } => {
407                        let fun = match fun {
408                            AtomicFunction::Exchange {
409                                compare: Some(compare_expr),
410                            } => AtomicFunction::Exchange {
411                                compare: Some(map_expr!(compare_expr)),
412                            },
413                            fun => *fun,
414                        };
415                        Statement::Atomic {
416                            pointer: map_expr!(pointer),
417                            fun,
418                            value: map_expr!(value),
419                            result: map_expr_opt!(result),
420                        }
421                    }
422                    Statement::WorkGroupUniformLoad { pointer, result } => {
423                        Statement::WorkGroupUniformLoad {
424                            pointer: map_expr!(pointer),
425                            result: map_expr!(result),
426                        }
427                    }
428                    Statement::Return { value } => Statement::Return {
429                        value: map_expr_opt!(value),
430                    },
431                    Statement::RayQuery { query, fun } => Statement::RayQuery {
432                        query: map_expr!(query),
433                        fun: match fun {
434                            naga::RayQueryFunction::Initialize {
435                                acceleration_structure,
436                                descriptor,
437                            } => naga::RayQueryFunction::Initialize {
438                                acceleration_structure: map_expr!(acceleration_structure),
439                                descriptor: map_expr!(descriptor),
440                            },
441                            naga::RayQueryFunction::Proceed { result } => {
442                                naga::RayQueryFunction::Proceed {
443                                    result: map_expr!(result),
444                                }
445                            }
446                            naga::RayQueryFunction::Terminate => naga::RayQueryFunction::Terminate,
447                            naga::RayQueryFunction::GenerateIntersection { hit_t } => {
448                                naga::RayQueryFunction::GenerateIntersection {
449                                    hit_t: map_expr!(hit_t),
450                                }
451                            }
452                            naga::RayQueryFunction::ConfirmIntersection => {
453                                naga::RayQueryFunction::ConfirmIntersection
454                            }
455                        },
456                    },
457                    Statement::SubgroupBallot { result, predicate } => Statement::SubgroupBallot {
458                        result: map_expr!(result),
459                        predicate: map_expr_opt!(predicate),
460                    },
461                    Statement::SubgroupGather {
462                        mut mode,
463                        argument,
464                        result,
465                    } => {
466                        match mode {
467                            GatherMode::BroadcastFirst => (),
468                            GatherMode::Broadcast(ref mut h_src)
469                            | GatherMode::QuadBroadcast(ref mut h_src)
470                            | GatherMode::Shuffle(ref mut h_src)
471                            | GatherMode::ShuffleDown(ref mut h_src)
472                            | GatherMode::ShuffleUp(ref mut h_src)
473                            | GatherMode::ShuffleXor(ref mut h_src) => *h_src = map_expr!(h_src),
474                            GatherMode::QuadSwap(_) => (),
475                        };
476                        Statement::SubgroupGather {
477                            mode,
478                            argument: map_expr!(argument),
479                            result: map_expr!(result),
480                        }
481                    }
482                    Statement::SubgroupCollectiveOperation {
483                        op,
484                        collective_op,
485                        argument,
486                        result,
487                    } => Statement::SubgroupCollectiveOperation {
488                        op: *op,
489                        collective_op: *collective_op,
490                        argument: map_expr!(argument),
491                        result: map_expr!(result),
492                    },
493                    Statement::ImageAtomic {
494                        image,
495                        coordinate,
496                        array_index,
497                        fun,
498                        value,
499                    } => {
500                        let fun = match fun {
501                            AtomicFunction::Exchange {
502                                compare: Some(compare_expr),
503                            } => AtomicFunction::Exchange {
504                                compare: Some(map_expr!(compare_expr)),
505                            },
506                            fun => *fun,
507                        };
508                        Statement::ImageAtomic {
509                            image: map_expr!(image),
510                            coordinate: map_expr!(coordinate),
511                            array_index: map_expr_opt!(array_index),
512                            fun,
513                            value: map_expr!(value),
514                        }
515                    }
516                    // else just copy
517                    Statement::Break
518                    | Statement::Continue
519                    | Statement::Kill
520                    | Statement::MemoryBarrier(_)
521                    | Statement::ControlBarrier(_) => stmt.clone(),
522                }
523            })
524            .collect();
525
526        let mut new_block = Block::from_vec(statements);
527
528        for ((_, new_span), (_, old_span)) in new_block.span_iter_mut().zip(block.span_iter()) {
529            *new_span.unwrap() = self.map_span(*old_span);
530        }
531
532        new_block
533    }
534
535    fn import_expression(
536        &mut self,
537        h_expr: Handle<Expression>,
538        old_expressions: &Arena<Expression>,
539        already_imported: Rc<RefCell<IndexMap<Handle<Expression>, Handle<Expression>>>>,
540        new_expressions: Rc<RefCell<Arena<Expression>>>,
541        non_emitting_only: bool, // only brings items that should NOT be emitted into scope
542        unique: bool,            // ensure expressions are unique with custom comparison
543    ) -> Handle<Expression> {
544        if let Some(h_new) = already_imported.borrow().get(&h_expr) {
545            return *h_new;
546        }
547
548        macro_rules! map_expr {
549            ($e:expr) => {
550                self.import_expression(
551                    *$e,
552                    old_expressions,
553                    already_imported.clone(),
554                    new_expressions.clone(),
555                    non_emitting_only,
556                    unique,
557                )
558            };
559        }
560
561        macro_rules! map_expr_opt {
562            ($e:expr) => {
563                $e.as_ref().map(|expr| {
564                    self.import_expression(
565                        *expr,
566                        old_expressions,
567                        already_imported.clone(),
568                        new_expressions.clone(),
569                        non_emitting_only,
570                        unique,
571                    )
572                })
573            };
574        }
575
576        let mut is_external = false;
577        let expr = old_expressions.try_get(h_expr).unwrap();
578        let expr = match expr {
579            Expression::Literal(_) => {
580                is_external = true;
581                expr.clone()
582            }
583            Expression::ZeroValue(zv) => {
584                is_external = true;
585                Expression::ZeroValue(self.import_type(zv))
586            }
587            Expression::CallResult(f) => Expression::CallResult(self.map_function_handle(f)),
588            Expression::Constant(c) => {
589                is_external = true;
590                Expression::Constant(self.import_const(c))
591            }
592            Expression::Compose { ty, components } => Expression::Compose {
593                ty: self.import_type(ty),
594                components: components.iter().map(|expr| map_expr!(expr)).collect(),
595            },
596            Expression::GlobalVariable(gv) => {
597                is_external = true;
598                Expression::GlobalVariable(self.import_global(gv))
599            }
600            Expression::ImageSample {
601                image,
602                sampler,
603                gather,
604                coordinate,
605                array_index,
606                offset,
607                level,
608                depth_ref,
609                clamp_to_edge,
610            } => Expression::ImageSample {
611                image: map_expr!(image),
612                sampler: map_expr!(sampler),
613                gather: *gather,
614                coordinate: map_expr!(coordinate),
615                array_index: map_expr_opt!(array_index),
616                offset: map_expr_opt!(offset),
617                level: match level {
618                    SampleLevel::Auto | SampleLevel::Zero => *level,
619                    SampleLevel::Exact(expr) => SampleLevel::Exact(map_expr!(expr)),
620                    SampleLevel::Bias(expr) => SampleLevel::Bias(map_expr!(expr)),
621                    SampleLevel::Gradient { x, y } => SampleLevel::Gradient {
622                        x: map_expr!(x),
623                        y: map_expr!(y),
624                    },
625                },
626                depth_ref: map_expr_opt!(depth_ref),
627                clamp_to_edge: *clamp_to_edge,
628            },
629            Expression::Access { base, index } => Expression::Access {
630                base: map_expr!(base),
631                index: map_expr!(index),
632            },
633            Expression::AccessIndex { base, index } => Expression::AccessIndex {
634                base: map_expr!(base),
635                index: *index,
636            },
637            Expression::Splat { size, value } => Expression::Splat {
638                size: *size,
639                value: map_expr!(value),
640            },
641            Expression::Swizzle {
642                size,
643                vector,
644                pattern,
645            } => Expression::Swizzle {
646                size: *size,
647                vector: map_expr!(vector),
648                pattern: *pattern,
649            },
650            Expression::Load { pointer } => Expression::Load {
651                pointer: map_expr!(pointer),
652            },
653            Expression::ImageLoad {
654                image,
655                coordinate,
656                array_index,
657                sample,
658                level,
659            } => Expression::ImageLoad {
660                image: map_expr!(image),
661                coordinate: map_expr!(coordinate),
662                array_index: map_expr_opt!(array_index),
663                sample: map_expr_opt!(sample),
664                level: map_expr_opt!(level),
665            },
666            Expression::ImageQuery { image, query } => Expression::ImageQuery {
667                image: map_expr!(image),
668                query: match query {
669                    ImageQuery::Size { level } => ImageQuery::Size {
670                        level: map_expr_opt!(level),
671                    },
672                    _ => *query,
673                },
674            },
675            Expression::Unary { op, expr } => Expression::Unary {
676                op: *op,
677                expr: map_expr!(expr),
678            },
679            Expression::Binary { op, left, right } => Expression::Binary {
680                op: *op,
681                left: map_expr!(left),
682                right: map_expr!(right),
683            },
684            Expression::Select {
685                condition,
686                accept,
687                reject,
688            } => Expression::Select {
689                condition: map_expr!(condition),
690                accept: map_expr!(accept),
691                reject: map_expr!(reject),
692            },
693            Expression::Derivative { axis, expr, ctrl } => Expression::Derivative {
694                axis: *axis,
695                expr: map_expr!(expr),
696                ctrl: *ctrl,
697            },
698            Expression::Relational { fun, argument } => Expression::Relational {
699                fun: *fun,
700                argument: map_expr!(argument),
701            },
702            Expression::Math {
703                fun,
704                arg,
705                arg1,
706                arg2,
707                arg3,
708            } => Expression::Math {
709                fun: *fun,
710                arg: map_expr!(arg),
711                arg1: map_expr_opt!(arg1),
712                arg2: map_expr_opt!(arg2),
713                arg3: map_expr_opt!(arg3),
714            },
715            Expression::As {
716                expr,
717                kind,
718                convert,
719            } => Expression::As {
720                expr: map_expr!(expr),
721                kind: *kind,
722                convert: *convert,
723            },
724            Expression::ArrayLength(expr) => Expression::ArrayLength(map_expr!(expr)),
725
726            Expression::LocalVariable(_) | Expression::FunctionArgument(_) => {
727                is_external = true;
728                expr.clone()
729            }
730
731            Expression::AtomicResult { ty, comparison } => Expression::AtomicResult {
732                ty: self.import_type(ty),
733                comparison: *comparison,
734            },
735            Expression::WorkGroupUniformLoadResult { ty } => {
736                Expression::WorkGroupUniformLoadResult {
737                    ty: self.import_type(ty),
738                }
739            }
740            Expression::RayQueryProceedResult => expr.clone(),
741            Expression::RayQueryGetIntersection { query, committed } => {
742                Expression::RayQueryGetIntersection {
743                    query: map_expr!(query),
744                    committed: *committed,
745                }
746            }
747            Expression::RayQueryVertexPositions { query, committed } => {
748                Expression::RayQueryVertexPositions {
749                    query: map_expr!(query),
750                    committed: *committed,
751                }
752            }
753            Expression::Override(h_override) => {
754                is_external = true;
755                Expression::Override(self.import_pipeline_override(h_override))
756            }
757            Expression::SubgroupBallotResult => expr.clone(),
758            Expression::SubgroupOperationResult { ty } => Expression::SubgroupOperationResult {
759                ty: self.import_type(ty),
760            },
761        };
762
763        if !non_emitting_only || is_external {
764            let span = old_expressions.get_span(h_expr);
765            let h_new = if unique {
766                new_expressions.borrow_mut().fetch_if_or_append(
767                    expr,
768                    self.map_span(span),
769                    |lhs, rhs| lhs == rhs,
770                )
771            } else {
772                new_expressions
773                    .borrow_mut()
774                    .append(expr, self.map_span(span))
775            };
776
777            already_imported.borrow_mut().insert(h_expr, h_new);
778            h_new
779        } else {
780            h_expr
781        }
782    }
783
784    // remap function global references (global vars, consts, types) into our derived context
785    pub fn localize_function(&mut self, func: &Function) -> Function {
786        let arguments = func
787            .arguments
788            .iter()
789            .map(|arg| FunctionArgument {
790                name: arg.name.clone(),
791                ty: self.import_type(&arg.ty),
792                binding: arg.binding.clone(),
793            })
794            .collect();
795
796        let result = func.result.as_ref().map(|r| FunctionResult {
797            ty: self.import_type(&r.ty),
798            binding: r.binding.clone(),
799        });
800
801        let expressions = Rc::new(RefCell::new(Arena::new()));
802        let expr_map = Rc::new(RefCell::new(IndexMap::new()));
803
804        let mut local_variables = Arena::new();
805        for (h_l, l) in func.local_variables.iter() {
806            let new_local = LocalVariable {
807                name: l.name.clone(),
808                ty: self.import_type(&l.ty),
809                init: l.init.map(|c| {
810                    self.import_expression(
811                        c,
812                        &func.expressions,
813                        expr_map.clone(),
814                        expressions.clone(),
815                        false,
816                        true,
817                    )
818                }),
819            };
820            let span = func.local_variables.get_span(h_l);
821            let new_h = local_variables.append(new_local, self.map_span(span));
822            assert_eq!(h_l, new_h);
823        }
824
825        let body = self.import_block(
826            &func.body,
827            &func.expressions,
828            expr_map.clone(),
829            expressions.clone(),
830        );
831
832        let named_expressions = func
833            .named_expressions
834            .iter()
835            .flat_map(|(h_expr, name)| {
836                expr_map
837                    .borrow()
838                    .get(h_expr)
839                    .map(|new_h| (*new_h, name.clone()))
840            })
841            .collect::<IndexMap<_, _, std::hash::BuildHasherDefault<rustc_hash::FxHasher>>>();
842
843        Function {
844            name: func.name.clone(),
845            arguments,
846            result,
847            local_variables,
848            expressions: Rc::try_unwrap(expressions).unwrap().into_inner(),
849            named_expressions,
850            body,
851            diagnostic_filter_leaf: None,
852        }
853    }
854
855    // import a function defined in the source shader context.
856    // func name may be already defined, the returned handle will refer to the new function.
857    // the previously defined function will still be valid.
858    pub fn import_function(&mut self, func: &Function, span: Span) -> Handle<Function> {
859        let name = func.name.as_ref().unwrap().clone();
860        let mapped_func = self.localize_function(func);
861        let new_span = self.map_span(span);
862        let new_h = self.functions.append(mapped_func, new_span);
863        self.function_map.insert(name, new_h);
864        new_h
865    }
866
867    // get the derived handle corresponding to the given source function handle
868    // requires func to be named
869    pub fn map_function_handle(&mut self, h_func: &Handle<Function>) -> Handle<Function> {
870        let functions = &self.shader.as_ref().unwrap().functions;
871        let func = functions.try_get(*h_func).unwrap();
872        let name = func.name.as_ref().unwrap();
873        self.function_map.get(name).copied().unwrap_or_else(|| {
874            let span = functions.get_span(*h_func);
875            self.import_function(func, span)
876        })
877    }
878
879    /// swap an already imported function for a new one.
880    /// note span cannot be updated
881    pub fn import_function_if_new(&mut self, func: &Function, span: Span) -> Handle<Function> {
882        let name = func.name.as_ref().unwrap().clone();
883        if let Some(h) = self.function_map.get(&name) {
884            return *h;
885        }
886
887        self.import_function(func, span)
888    }
889
890    /// get any required special types for this module
891    pub fn has_required_special_types(&self) -> bool {
892        !self.special_types.predeclared_types.is_empty()
893            || self.special_types.ray_desc.is_some()
894            || self.special_types.ray_intersection.is_some()
895            || self.special_types.ray_vertex_return.is_some()
896    }
897
898    pub fn into_module_with_entrypoints(mut self) -> naga::Module {
899        let entry_points = self
900            .shader
901            .unwrap()
902            .entry_points
903            .iter()
904            .map(|ep| EntryPoint {
905                name: ep.name.clone(),
906                stage: ep.stage,
907                early_depth_test: ep.early_depth_test,
908                workgroup_size: ep.workgroup_size,
909                function: self.localize_function(&ep.function),
910                workgroup_size_overrides: ep.workgroup_size_overrides,
911            })
912            .collect();
913
914        naga::Module {
915            entry_points,
916            ..self.into()
917        }
918    }
919}
920
921impl From<DerivedModule<'_>> for naga::Module {
922    fn from(derived: DerivedModule) -> Self {
923        naga::Module {
924            types: derived.types,
925            constants: derived.constants,
926            global_variables: derived.globals,
927            global_expressions: Rc::try_unwrap(derived.global_expressions)
928                .unwrap()
929                .into_inner(),
930            functions: derived.functions,
931            special_types: derived.special_types,
932            entry_points: Default::default(),
933            overrides: derived.pipeline_overrides,
934            diagnostic_filters: Default::default(),
935            diagnostic_filter_leaf: None,
936            doc_comments: None,
937        }
938    }
939}