naga_oil/compose/
mod.rs

1use indexmap::IndexMap;
2/// the compose module allows construction of shaders from modules (which are themselves shaders).
3///
4/// it does this by treating shaders as modules, and
5/// - building each module independently to naga IR
6/// - creating "header" files for each supported language, which are used to build dependent modules/shaders
7/// - making final shaders by combining the shader IR with the IR for imported modules
8///
9/// for multiple small shaders with large common imports, this can be faster than parsing the full source for each shader, and it allows for constructing shaders in a cleaner modular manner with better scope control.
10///
11/// ## imports
12///
/// shaders can be added to the composer as modules. this makes their types, constants, variables and functions available to modules/shaders that import them. note that importing a module will affect the final shader's global state if the module defines global variables with bindings.
14///
15/// modules must include a `#define_import_path` directive that names the module.
16///
17/// ```ignore
18/// #define_import_path my_module
19///
20/// fn my_func() -> f32 {
21///     return 1.0;
22/// }
23/// ```
24///
25/// shaders can then import the module with an `#import` directive (with an optional `as` name). at point of use, imported items must be qualified:
26///
27/// ```ignore
28/// #import my_module
29/// #import my_other_module as Mod2
30///
31/// fn main() -> f32 {
32///     let x = my_module::my_func();
33///     let y = Mod2::my_other_func();
34///     return x*y;
35/// }
36/// ```
37///
38/// or import a comma-separated list of individual items with a `#from` directive. at point of use, imported items must be prefixed with `::` :
39///
40/// ```ignore
41/// #from my_module import my_func, my_const
42///
43/// fn main() -> f32 {
44///     return ::my_func(::my_const);
45/// }
46/// ```
47///
48/// imports can be nested - modules may import other modules, but not recursively. when a new module is added, all its `#import`s must already have been added.
49/// the same module can be imported multiple times by different modules in the import tree.
50/// there is no overlap of namespaces, so the same function names (or type, constant, or variable names) may be used in different modules.
51///
52/// note: when importing an item with the `#from` directive, the final shader will include the required dependencies (bindings, globals, consts, other functions) of the imported item, but will not include the rest of the imported module. it will however still include all of any modules imported by the imported module. this is probably not desired in general and may be fixed in a future version. currently for a more complete culling of unused dependencies the `prune` module can be used.
53///
54/// ## overriding functions
55///
56/// virtual functions can be declared with the `virtual` keyword:
57/// ```ignore
58/// virtual fn point_light(world_position: vec3<f32>) -> vec3<f32> { ... }
59/// ```
60/// virtual functions defined in imported modules can then be overridden using the `override` keyword:
61///
62/// ```ignore
63/// #import bevy_pbr::lighting as Lighting
64///
65/// override fn Lighting::point_light (world_position: vec3<f32>) -> vec3<f32> {
66///     let original = Lighting::point_light(world_position);
67///     let quantized = vec3<u32>(original * 3.0);
68///     return vec3<f32>(quantized) / 3.0;
69/// }
70/// ```
71///
72/// override function definitions cause *all* calls to the original function in the entire shader scope to be replaced by calls to the new function, with the exception of calls within the override function itself.
73///
74/// the function signature of the override must match the base function.
75///
76/// overrides can be specified at any point in the final shader's import tree.
77///
78/// multiple overrides can be applied to the same function. for example, given :
79/// - a module `a` containing a function `f`,
80/// - a module `b` that imports `a`, and containing an `override a::f` function,
81/// - a module `c` that imports `a` and `b`, and containing an `override a::f` function,
82///
83/// then b and c both specify an override for `a::f`.
84/// the `override fn a::f` declared in module `b` may call to `a::f` within its body.
/// the `override fn a::f` declared in module `c` may call `a::f` within its body, but the call will be redirected to `b::f`.
/// any other calls to `a::f` (within modules `a` or `b`, or anywhere else) will end up redirected to `c::f`.
87/// in this way a chain or stack of overrides can be applied.
88///
89/// different overrides of the same function can be specified in different import branches. the final stack will be ordered based on the first occurrence of the override in the import tree (using a depth first search).
90///
91/// note that imports into a module/shader are processed in order, but are processed before the body of the current shader/module regardless of where they occur in that module, so there is no way to import a module containing an override and inject a call into the override stack prior to that imported override. you can instead create two modules each containing an override and import them into a parent module/shader to order them as required.
92/// override functions can currently only be defined in wgsl.
93///
94/// if the `override_any` crate feature is enabled, then the `virtual` keyword is not required for the function being overridden.
95///
96/// ## languages
97///
/// modules can be written in GLSL or WGSL. shaders with entry points can be imported as modules (provided they have a `#define_import_path` directive). entry points are available to call from imported modules either via their name (for WGSL) or via `module::main` (for GLSL).
99///
/// final shaders can also be written in GLSL or WGSL. for GLSL, users must specify whether the shader is a vertex shader or fragment shader via the `ShaderType` argument (GLSL compute shaders are not supported).
101///
102/// ## preprocessing
103///
104/// when generating a final shader or adding a composable module, a set of `shader_def` string/value pairs must be provided. The value can be a bool (`ShaderDefValue::Bool`), an i32 (`ShaderDefValue::Int`) or a u32 (`ShaderDefValue::UInt`).
105///
106/// these allow conditional compilation of parts of modules and the final shader. conditional compilation is performed with `#if` / `#ifdef` / `#ifndef`, `#else` and `#endif` preprocessor directives:
107///
108/// ```ignore
109/// fn get_number() -> f32 {
110///     #ifdef BIG_NUMBER
111///         return 999.0;
112///     #else
113///         return 0.999;
114///     #endif
115/// }
116/// ```
/// the `#ifdef` directive matches when the def name exists in the provided `shader_defs` (regardless of its value). the `#ifndef` directive is the reverse.
118///
119/// the `#if` directive requires a def name, an operator, and a value for comparison:
120/// - the def name must be a provided `shader_def` name.
121/// - the operator must be one of `==`, `!=`, `>=`, `>`, `<`, `<=`
/// - the value must be an integer literal if comparing to a `ShaderDefValue::Int`, or `true` or `false` if comparing to a `ShaderDefValue::Bool`.
123///
124/// shader defs can also be used in the shader source with `#SHADER_DEF` or `#{SHADER_DEF}`, and will be substituted for their value.
125///
126/// ## error reporting
127///
128/// codespan reporting for errors is available using the error `emit_to_string` method. this requires validation to be enabled, which is true by default. `Composer::non_validating()` produces a non-validating composer that is not able to give accurate error reporting.
129///
130use naga::EntryPoint;
131use regex::Regex;
132use std::collections::{hash_map::Entry, BTreeMap, HashMap, HashSet};
133use std::sync::LazyLock;
134use tracing::{debug, trace};
135
136use crate::{
137    compose::preprocess::{PreprocessOutput, PreprocessorMetaData},
138    derive::DerivedModule,
139    redirect::Redirector,
140};
141
142pub use self::error::{ComposerError, ComposerErrorInner, ErrSource};
143use self::preprocess::Preprocessor;
144pub use self::wgsl_directives::{
145    DiagnosticDirective, EnableDirective, RequiresDirective, WgslDirectives,
146};
147
148pub mod comment_strip_iter;
149pub mod error;
150pub mod parse_imports;
151pub mod preprocess;
152mod test;
153pub mod tokenizer;
154pub mod wgsl_directives;
155
/// Source language a module or shader is written in.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShaderLanguage {
    /// WGSL source (the default).
    #[default]
    Wgsl,
    /// GLSL source; only available with the `glsl` crate feature.
    #[cfg(feature = "glsl")]
    Glsl,
}

/// Kind of final shader being composed.
///
/// GLSL shaders additionally record whether they are vertex or fragment shaders.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShaderType {
    /// A WGSL shader (the default).
    #[default]
    Wgsl,
    /// A GLSL vertex shader.
    #[cfg(feature = "glsl")]
    GlslVertex,
    /// A GLSL fragment shader.
    #[cfg(feature = "glsl")]
    GlslFragment,
}

impl From<ShaderType> for ShaderLanguage {
    /// Maps a shader type to the language its source is written in.
    fn from(shader_type: ShaderType) -> Self {
        match shader_type {
            ShaderType::Wgsl => Self::Wgsl,
            #[cfg(feature = "glsl")]
            ShaderType::GlslVertex | ShaderType::GlslFragment => Self::Glsl,
        }
    }
}
183
/// Value bound to a shader def name for preprocessing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShaderDefValue {
    /// Boolean def; compared against `true` / `false`.
    Bool(bool),
    /// Signed integer def.
    Int(i32),
    /// Unsigned integer def.
    UInt(u32),
}

impl Default for ShaderDefValue {
    /// A def with no explicit value behaves as `Bool(true)`.
    fn default() -> Self {
        Self::Bool(true)
    }
}

impl ShaderDefValue {
    /// Renders the value as its `Display` text (e.g. `true`, `-3`, `7`).
    fn value_as_string(&self) -> String {
        match *self {
            Self::Bool(flag) => flag.to_string(),
            Self::Int(signed) => signed.to_string(),
            Self::UInt(unsigned) => unsigned.to_string(),
        }
    }
}
206
207#[derive(Clone, PartialEq, Eq, Hash, Debug, Default)]
208pub struct OwnedShaderDefs(BTreeMap<String, ShaderDefValue>);
209
210#[derive(Clone, PartialEq, Eq, Hash, Debug)]
211struct ModuleKey(OwnedShaderDefs);
212
213impl ModuleKey {
214    fn from_members(key: &HashMap<String, ShaderDefValue>, universe: &[String]) -> Self {
215        let mut acc = OwnedShaderDefs::default();
216        for item in universe {
217            if let Some(value) = key.get(item) {
218                acc.0.insert(item.to_owned(), *value);
219            }
220        }
221        ModuleKey(acc)
222    }
223}
224
225// a module built with a specific set of shader_defs
226#[derive(Default, Debug)]
227pub struct ComposableModule {
228    // module decoration, prefixed to all items from this module in the final source
229    pub decorated_name: String,
230    // module names required as imports, optionally with a list of items to import
231    pub imports: Vec<ImportDefinition>,
232    // types exported
233    pub owned_types: HashSet<String>,
234    // constants exported
235    pub owned_constants: HashSet<String>,
236    // vars exported
237    pub owned_vars: HashSet<String>,
238    // functions exported
239    pub owned_functions: HashSet<String>,
240    // local functions that can be overridden
241    pub virtual_functions: HashSet<String>,
242    // overriding functions defined in this module
243    // target function -> Vec<replacement functions>
244    pub override_functions: IndexMap<String, Vec<String>>,
245    // naga module, built against headers for any imports
246    module_ir: naga::Module,
247    // headers in different shader languages, used for building modules/shaders that import this module
248    // headers contain types, constants, global vars and empty function definitions -
249    // just enough to convert source strings that want to import this module into naga IR
250    // headers: HashMap<ShaderLanguage, String>,
251    header_ir: naga::Module,
252    // character offset of the start of the owned module string
253    start_offset: usize,
254}
255
256// data used to build a ComposableModule
257#[derive(Debug)]
258pub struct ComposableModuleDefinition {
259    pub name: String,
260    // shader text (with auto bindings replaced - we do this on module add as we only want to do it once to avoid burning slots)
261    pub sanitized_source: String,
262    // language
263    pub language: ShaderLanguage,
264    // source path for error display
265    pub file_path: String,
266    // shader def values bound to this module
267    pub shader_defs: HashMap<String, ShaderDefValue>,
268    // list of shader_defs that can affect this module
269    effective_defs: Vec<String>,
270    // full list of possible imports (regardless of shader_def configuration)
271    all_imports: HashSet<String>,
272    // additional imports to add (as though they were included in the source after any other imports)
273    additional_imports: Vec<ImportDefinition>,
274    // built composable modules for a given set of shader defs
275    modules: HashMap<ModuleKey, ComposableModule>,
276    // used in spans when this module is included
277    module_index: usize,
278}
279
280impl ComposableModuleDefinition {
281    fn get_module(
282        &self,
283        shader_defs: &HashMap<String, ShaderDefValue>,
284    ) -> Option<&ComposableModule> {
285        self.modules
286            .get(&ModuleKey::from_members(shader_defs, &self.effective_defs))
287    }
288
289    fn insert_module(
290        &mut self,
291        shader_defs: &HashMap<String, ShaderDefValue>,
292        module: ComposableModule,
293    ) -> &ComposableModule {
294        match self
295            .modules
296            .entry(ModuleKey::from_members(shader_defs, &self.effective_defs))
297        {
298            Entry::Occupied(_) => panic!("entry already populated"),
299            Entry::Vacant(v) => v.insert(module),
300        }
301    }
302}
303
/// A single `#import` / `#from` request.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ImportDefinition {
    /// Import path of the requested module.
    pub import: String,
    /// Individual items requested; empty when the whole module is imported.
    pub items: Vec<String>,
}

/// An [`ImportDefinition`] together with the offset at which it was found.
#[derive(Debug, Clone)]
pub struct ImportDefWithOffset {
    definition: ImportDefinition,
    offset: usize,
}
315
316/// module composer.
317/// stores any modules that can be imported into a shader
318/// and builds the final shader
319#[derive(Debug)]
320pub struct Composer {
321    pub validate: bool,
322    pub module_sets: HashMap<String, ComposableModuleDefinition>,
323    pub module_index: HashMap<usize, String>,
324    pub capabilities: naga::valid::Capabilities,
325    preprocessor: Preprocessor,
326    check_decoration_regex: Regex,
327    undecorate_regex: Regex,
328    virtual_fn_regex: Regex,
329    override_fn_regex: Regex,
330    undecorate_override_regex: Regex,
331    auto_binding_regex: Regex,
332    auto_binding_index: u32,
333}
334
/// Bit shift separating the module index from the in-module offset in span positions.
///
/// 21 bits of offset allow sources of up to 2^21 (~2M) characters. The remaining bits hold
/// the module index: 2048 modules, assuming 32-bit span positions (2^(32-21)).
const SPAN_SHIFT: usize = 21;
340
341impl Default for Composer {
342    fn default() -> Self {
343        Self {
344            validate: true,
345            capabilities: Default::default(),
346            module_sets: Default::default(),
347            module_index: Default::default(),
348            preprocessor: Preprocessor::default(),
349            check_decoration_regex: Regex::new(
350                format!(
351                    "({}|{})",
352                    regex::escape(DECORATION_PRE),
353                    regex::escape(DECORATION_OVERRIDE_PRE)
354                )
355                .as_str(),
356            )
357            .unwrap(),
358            undecorate_regex: Regex::new(
359                format!(
360                    r"(\x1B\[\d+\w)?([\w\d_]+){}([A-Z0-9]*){}",
361                    regex::escape(DECORATION_PRE),
362                    regex::escape(DECORATION_POST)
363                )
364                .as_str(),
365            )
366            .unwrap(),
367            virtual_fn_regex: Regex::new(
368                r"(?P<lead>[\s]*virtual\s+fn\s+)(?P<function>[^\s]+)(?P<trail>\s*)\(",
369            )
370            .unwrap(),
371            override_fn_regex: Regex::new(
372                format!(
373                    r"(override\s+fn\s+)([^\s]+){}([\w\d]+){}(\s*)\(",
374                    regex::escape(DECORATION_PRE),
375                    regex::escape(DECORATION_POST)
376                )
377                .as_str(),
378            )
379            .unwrap(),
380            undecorate_override_regex: Regex::new(
381                format!(
382                    "{}([A-Z0-9]*){}",
383                    regex::escape(DECORATION_OVERRIDE_PRE),
384                    regex::escape(DECORATION_POST)
385                )
386                .as_str(),
387            )
388            .unwrap(),
389            auto_binding_regex: Regex::new(r"@binding\(auto\)").unwrap(),
390            auto_binding_index: 0,
391        }
392    }
393}
394
/// Marker placed between an item name and its base32-encoded module name.
///
/// A decorated item reads `{item}{DECORATION_PRE}{base32(module)}{DECORATION_POST}`.
const DECORATION_PRE: &str = "X_naga_oil_mod_X";
/// Terminator for a module decoration.
const DECORATION_POST: &str = "X";

/// Prefix used to mark override decorations.
///
/// It must be the same length as `DECORATION_PRE` for spans to work.
const DECORATION_OVERRIDE_PRE: &str = "X_naga_oil_vrt_X";

// Enforce the length invariant at compile time rather than relying on the comment alone:
// editing either marker to a different length would silently corrupt span offsets.
const _: () = assert!(DECORATION_PRE.len() == DECORATION_OVERRIDE_PRE.len());
400
401struct IrBuildResult {
402    module: naga::Module,
403    start_offset: usize,
404    override_functions: IndexMap<String, Vec<String>>,
405}
406
407impl Composer {
408    pub fn decorated_name(module_name: Option<&str>, item_name: &str) -> String {
409        match module_name {
410            Some(module_name) => format!("{}{}", item_name, Self::decorate(module_name)),
411            None => item_name.to_owned(),
412        }
413    }
414
415    fn decorate(module: &str) -> String {
416        let encoded = data_encoding::BASE32_NOPAD.encode(module.as_bytes());
417        format!("{DECORATION_PRE}{encoded}{DECORATION_POST}")
418    }
419
420    fn decode(from: &str) -> String {
421        String::from_utf8(data_encoding::BASE32_NOPAD.decode(from.as_bytes()).unwrap()).unwrap()
422    }
423
424    /// Shorthand for creating a naga validator.
425    fn create_validator(&self) -> naga::valid::Validator {
426        naga::valid::Validator::new(naga::valid::ValidationFlags::all(), self.capabilities)
427    }
428
429    fn undecorate(&self, string: &str) -> String {
430        let undecor = self
431            .undecorate_regex
432            .replace_all(string, |caps: &regex::Captures| {
433                format!(
434                    "{}{}::{}",
435                    caps.get(1).map(|cc| cc.as_str()).unwrap_or(""),
436                    Self::decode(caps.get(3).unwrap().as_str()),
437                    caps.get(2).unwrap().as_str()
438                )
439            });
440
441        let undecor =
442            self.undecorate_override_regex
443                .replace_all(&undecor, |caps: &regex::Captures| {
444                    format!(
445                        "override fn {}::",
446                        Self::decode(caps.get(1).unwrap().as_str())
447                    )
448                });
449
450        undecor.to_string()
451    }
452
453    fn sanitize_and_set_auto_bindings(&mut self, source: &str) -> String {
454        let mut substituted_source = source.replace("\r\n", "\n").replace('\r', "\n");
455        if !substituted_source.ends_with('\n') {
456            substituted_source.push('\n');
457        }
458
459        // replace @binding(auto) with an incrementing index
460        struct AutoBindingReplacer<'a> {
461            auto: &'a mut u32,
462        }
463
464        impl regex::Replacer for AutoBindingReplacer<'_> {
465            fn replace_append(&mut self, _: &regex::Captures<'_>, dst: &mut String) {
466                dst.push_str(&format!("@binding({})", self.auto));
467                *self.auto += 1;
468            }
469        }
470
471        let substituted_source = self.auto_binding_regex.replace_all(
472            &substituted_source,
473            AutoBindingReplacer {
474                auto: &mut self.auto_binding_index,
475            },
476        );
477
478        substituted_source.into_owned()
479    }
480
481    fn naga_to_string(
482        &self,
483        naga_module: &mut naga::Module,
484        language: ShaderLanguage,
485        #[allow(unused)] header_for: &str, // Only used when GLSL is enabled
486    ) -> Result<String, ComposerErrorInner> {
487        // TODO: cache headers again
488        let info = self
489            .create_validator()
490            .validate(naga_module)
491            .map_err(ComposerErrorInner::HeaderValidationError)?;
492
493        match language {
494            ShaderLanguage::Wgsl => naga::back::wgsl::write_string(
495                naga_module,
496                &info,
497                naga::back::wgsl::WriterFlags::EXPLICIT_TYPES,
498            )
499            .map_err(ComposerErrorInner::WgslBackError),
500            #[cfg(feature = "glsl")]
501            ShaderLanguage::Glsl => {
502                let vec4 = naga_module.types.insert(
503                    naga::Type {
504                        name: None,
505                        inner: naga::TypeInner::Vector {
506                            size: naga::VectorSize::Quad,
507                            scalar: naga::Scalar::F32,
508                        },
509                    },
510                    naga::Span::UNDEFINED,
511                );
512                // add a dummy entry point for glsl headers
513                let dummy_entry_point = "dummy_module_entry_point".to_owned();
514                let func = naga::Function {
515                    name: Some(dummy_entry_point.clone()),
516                    arguments: Default::default(),
517                    result: Some(naga::FunctionResult {
518                        ty: vec4,
519                        binding: Some(naga::Binding::BuiltIn(naga::BuiltIn::Position {
520                            invariant: false,
521                        })),
522                    }),
523                    local_variables: Default::default(),
524                    expressions: Default::default(),
525                    named_expressions: Default::default(),
526                    body: Default::default(),
527                    diagnostic_filter_leaf: Default::default(),
528                };
529                let ep = EntryPoint {
530                    name: dummy_entry_point.clone(),
531                    stage: naga::ShaderStage::Vertex,
532                    function: func,
533                    early_depth_test: None,
534                    workgroup_size: [0, 0, 0],
535                    workgroup_size_overrides: None,
536                };
537
538                naga_module.entry_points.push(ep);
539
540                let info = self
541                    .create_validator()
542                    .validate(naga_module)
543                    .map_err(ComposerErrorInner::HeaderValidationError)?;
544
545                let mut string = String::new();
546                let options = naga::back::glsl::Options {
547                    version: naga::back::glsl::Version::Desktop(450),
548                    writer_flags: naga::back::glsl::WriterFlags::INCLUDE_UNUSED_ITEMS,
549                    ..Default::default()
550                };
551                let pipeline_options = naga::back::glsl::PipelineOptions {
552                    shader_stage: naga::ShaderStage::Vertex,
553                    entry_point: dummy_entry_point,
554                    multiview: None,
555                };
556                let mut writer = naga::back::glsl::Writer::new(
557                    &mut string,
558                    naga_module,
559                    &info,
560                    &options,
561                    &pipeline_options,
562                    naga::proc::BoundsCheckPolicies::default(),
563                )
564                .map_err(ComposerErrorInner::GlslBackError)?;
565
566                writer.write().map_err(ComposerErrorInner::GlslBackError)?;
567
568                // strip version decl and main() impl
569                let lines: Vec<_> = string.lines().collect();
570                let string = lines[1..lines.len() - 3].join("\n");
571                trace!("glsl header for {}:\n\"\n{:?}\n\"", header_for, string);
572
573                Ok(string)
574            }
575        }
576    }
577
578    // build naga module for a given shader_def configuration. builds a minimal self-contained module built against headers for imports
579    fn create_module_ir(
580        &self,
581        name: &str,
582        source: String,
583        language: ShaderLanguage,
584        imports: &[ImportDefinition],
585        shader_defs: &HashMap<String, ShaderDefValue>,
586        wgsl_directives: Option<WgslDirectives>,
587    ) -> Result<IrBuildResult, ComposerError> {
588        debug!("creating IR for {} with defs: {:?}", name, shader_defs);
589
590        let mut wgsl_string = String::new();
591        if let Some(wgsl_directives) = &wgsl_directives {
592            trace!("adding WGSL directives for {}", name);
593            wgsl_string = wgsl_directives.to_wgsl_string();
594        }
595        let mut module_string = match language {
596            ShaderLanguage::Wgsl => wgsl_string,
597            #[cfg(feature = "glsl")]
598            ShaderLanguage::Glsl => String::from("#version 450\n"),
599        };
600
601        let mut override_functions: IndexMap<String, Vec<String>> = IndexMap::default();
602        let mut added_imports: HashSet<String> = HashSet::new();
603        let mut header_module = DerivedModule::default();
604
605        for import in imports {
606            if added_imports.contains(&import.import) {
607                continue;
608            }
609            // add to header module
610            self.add_import(
611                &mut header_module,
612                import,
613                shader_defs,
614                true,
615                &mut added_imports,
616            );
617
618            // // we must have ensured these exist with Composer::ensure_imports()
619            trace!("looking for {}", import.import);
620            let import_module_set = self.module_sets.get(&import.import).unwrap();
621            trace!("with defs {:?}", shader_defs);
622            let module = import_module_set.get_module(shader_defs).unwrap();
623            trace!("ok");
624
625            // gather overrides
626            if !module.override_functions.is_empty() {
627                for (original, replacements) in &module.override_functions {
628                    match override_functions.entry(original.clone()) {
629                        indexmap::map::Entry::Occupied(o) => {
630                            let existing = o.into_mut();
631                            let new_replacements: Vec<_> = replacements
632                                .iter()
633                                .filter(|rep| !existing.contains(rep))
634                                .cloned()
635                                .collect();
636                            existing.extend(new_replacements);
637                        }
638                        indexmap::map::Entry::Vacant(v) => {
639                            v.insert(replacements.clone());
640                        }
641                    }
642                }
643            }
644        }
645
646        let composed_header = self
647            .naga_to_string(&mut header_module.into(), language, name)
648            .map_err(|inner| ComposerError {
649                inner,
650                source: ErrSource::Module {
651                    name: name.to_owned(),
652                    offset: 0,
653                    defs: shader_defs.clone(),
654                },
655            })?;
656        module_string.push_str(&composed_header);
657
658        let start_offset = module_string.len();
659
660        module_string.push_str(&source);
661
662        trace!(
663            "parsing {}: {}, header len {}, total len {}",
664            name,
665            module_string,
666            start_offset,
667            module_string.len()
668        );
669        let module = match language {
670            ShaderLanguage::Wgsl => naga::front::wgsl::parse_str(&module_string).map_err(|e| {
671                debug!("full err'd source file: \n---\n{}\n---", module_string);
672                ComposerError {
673                    inner: ComposerErrorInner::WgslParseError(e),
674                    source: ErrSource::Module {
675                        name: name.to_owned(),
676                        offset: start_offset,
677                        defs: shader_defs.clone(),
678                    },
679                }
680            })?,
681            #[cfg(feature = "glsl")]
682            ShaderLanguage::Glsl => naga::front::glsl::Frontend::default()
683                .parse(
684                    &naga::front::glsl::Options {
685                        stage: naga::ShaderStage::Vertex,
686                        defines: Default::default(),
687                    },
688                    &module_string,
689                )
690                .map_err(|e| {
691                    debug!("full err'd source file: \n---\n{}\n---", module_string);
692                    ComposerError {
693                        inner: ComposerErrorInner::GlslParseError(e),
694                        source: ErrSource::Module {
695                            name: name.to_owned(),
696                            offset: start_offset,
697                            defs: shader_defs.clone(),
698                        },
699                    }
700                })?,
701        };
702
703        Ok(IrBuildResult {
704            module,
705            start_offset,
706            override_functions,
707        })
708    }
709
710    // check that identifiers exported by a module do not get modified in string export
711    fn validate_identifiers(
712        source_ir: &naga::Module,
713        lang: ShaderLanguage,
714        header: &str,
715        module_decoration: &str,
716        owned_types: &HashSet<String>,
717    ) -> Result<(), ComposerErrorInner> {
718        // TODO: remove this once glsl front support is complete
719        #[cfg(feature = "glsl")]
720        if lang == ShaderLanguage::Glsl {
721            return Ok(());
722        }
723
724        let recompiled = match lang {
725            ShaderLanguage::Wgsl => naga::front::wgsl::parse_str(header).unwrap(),
726            #[cfg(feature = "glsl")]
727            ShaderLanguage::Glsl => naga::front::glsl::Frontend::default()
728                .parse(
729                    &naga::front::glsl::Options {
730                        stage: naga::ShaderStage::Vertex,
731                        defines: Default::default(),
732                    },
733                    &format!("{}\n{}", header, "void main() {}"),
734                )
735                .map_err(|e| {
736                    debug!("full err'd source file: \n---\n{header}\n---");
737                    ComposerErrorInner::GlslParseError(e)
738                })?,
739        };
740
741        let recompiled_types: IndexMap<_, _> = recompiled
742            .types
743            .iter()
744            .flat_map(|(h, ty)| ty.name.as_deref().map(|name| (name, h)))
745            .collect();
746        for (h, ty) in source_ir.types.iter() {
747            if let Some(name) = &ty.name {
748                let decorated_type_name = format!("{name}{module_decoration}");
749                if !owned_types.contains(&decorated_type_name) {
750                    continue;
751                }
752                match recompiled_types.get(decorated_type_name.as_str()) {
753                    Some(recompiled_h) => {
754                        if let naga::TypeInner::Struct { members, .. } = &ty.inner {
755                            let recompiled_ty = recompiled.types.get_handle(*recompiled_h).unwrap();
756                            let naga::TypeInner::Struct {
757                                members: recompiled_members,
758                                ..
759                            } = &recompiled_ty.inner
760                            else {
761                                panic!();
762                            };
763                            for (member, recompiled_member) in
764                                members.iter().zip(recompiled_members)
765                            {
766                                if member.name != recompiled_member.name {
767                                    return Err(ComposerErrorInner::InvalidIdentifier {
768                                        original: member.name.clone().unwrap_or_default(),
769                                        at: source_ir.types.get_span(h),
770                                    });
771                                }
772                            }
773                        }
774                    }
775                    None => {
776                        return Err(ComposerErrorInner::InvalidIdentifier {
777                            original: name.clone(),
778                            at: source_ir.types.get_span(h),
779                        })
780                    }
781                }
782            }
783        }
784
785        let recompiled_consts: HashSet<_> = recompiled
786            .constants
787            .iter()
788            .flat_map(|(_, c)| c.name.as_deref())
789            .filter(|name| name.ends_with(module_decoration))
790            .collect();
791        for (h, c) in source_ir.constants.iter() {
792            if let Some(name) = &c.name {
793                if name.ends_with(module_decoration) && !recompiled_consts.contains(name.as_str()) {
794                    return Err(ComposerErrorInner::InvalidIdentifier {
795                        original: name.clone(),
796                        at: source_ir.constants.get_span(h),
797                    });
798                }
799            }
800        }
801
802        let recompiled_globals: HashSet<_> = recompiled
803            .global_variables
804            .iter()
805            .flat_map(|(_, c)| c.name.as_deref())
806            .filter(|name| name.ends_with(module_decoration))
807            .collect();
808        for (h, gv) in source_ir.global_variables.iter() {
809            if let Some(name) = &gv.name {
810                if name.ends_with(module_decoration) && !recompiled_globals.contains(name.as_str())
811                {
812                    return Err(ComposerErrorInner::InvalidIdentifier {
813                        original: name.clone(),
814                        at: source_ir.global_variables.get_span(h),
815                    });
816                }
817            }
818        }
819
820        let recompiled_fns: HashSet<_> = recompiled
821            .functions
822            .iter()
823            .flat_map(|(_, c)| c.name.as_deref())
824            .filter(|name| name.ends_with(module_decoration))
825            .collect();
826        for (h, f) in source_ir.functions.iter() {
827            if let Some(name) = &f.name {
828                if name.ends_with(module_decoration) && !recompiled_fns.contains(name.as_str()) {
829                    return Err(ComposerErrorInner::InvalidIdentifier {
830                        original: name.clone(),
831                        at: source_ir.functions.get_span(h),
832                    });
833                }
834            }
835        }
836
837        Ok(())
838    }
839
    // build a ComposableModule from a ComposableModuleDefinition, for a given set of shader defs
    // - build the naga IR (against headers)
    // - record any types/vars/constants/functions that are defined within this module
    // - build headers for each supported language
    //
    // parameters:
    // - `module_decoration`: suffix appended to the name of every item this module owns
    //   (constants, pipeline overrides, globals, functions, types); it is also stored as the
    //   resulting module's `decorated_name`.
    // - `create_headers`: together with `self.validate`, enables the header round-trip
    //   identifier check at the end (the header IR itself is always built).
    // - `demote_entrypoints`: turn the source's entry points into ordinary functions named
    //   `{entry point name, or "main"}{module_decoration}`, with their IO bindings removed.
    // - `source`: module source text to compile; `imports`: the imports found in it (their
    //   offsets are dropped here).
    // - `wgsl_directives`: forwarded unchanged to `create_module_ir`.
    //
    // errors: override validation failures (only without the `override_any` feature), IR build
    // errors from `create_module_ir`, and header round-trip failures when validating.
    #[allow(clippy::too_many_arguments)]
    fn create_composable_module(
        &mut self,
        module_definition: &ComposableModuleDefinition,
        module_decoration: String,
        shader_defs: &HashMap<String, ShaderDefValue>,
        create_headers: bool,
        demote_entrypoints: bool,
        source: &str,
        imports: Vec<ImportDefWithOffset>,
        wgsl_directives: Option<WgslDirectives>,
    ) -> Result<ComposableModule, ComposerError> {
        // keep only the import definitions (offsets are not needed past this point), plus any
        // imports the module definition adds on top of its source
        let mut imports: Vec<_> = imports
            .into_iter()
            .map(|import_with_offset| import_with_offset.definition)
            .collect();
        imports.extend(module_definition.additional_imports.to_vec());

        trace!(
            "create composable module {}: source len {}",
            module_definition.name,
            source.len()
        );

        // record virtual/overridable functions
        // each match is rewritten to a plain `fn` declaration: the prefix (group 1) becomes
        // spaces + "fn " of the same width, and the whitespace before `(` (group 3) becomes the
        // same number of spaces, so those segments keep their original length.
        let mut virtual_functions: HashSet<String> = Default::default();
        let source = self
            .virtual_fn_regex
            .replace_all(source, |cap: &regex::Captures| {
                let target_function = cap.get(2).unwrap().as_str().to_owned();

                let replacement_str = format!(
                    "{}fn {}{}(",
                    " ".repeat(cap.get(1).unwrap().range().len() - 3),
                    target_function,
                    " ".repeat(cap.get(3).unwrap().range().len()),
                );

                virtual_functions.insert(target_function);

                replacement_str
            });

        // record and rename override functions
        // maps the renamed local override fn (`{fn}{DECORATION_OVERRIDE_PRE}{module}{DECORATION_POST}`)
        // to the decorated name of the function it overrides (`{fn}{DECORATION_PRE}{module}{DECORATION_POST}`)
        let mut local_override_functions: IndexMap<String, String> = Default::default();

        // set from inside the replace_all closure; if several overrides are invalid, only the
        // last error recorded survives
        #[cfg(not(feature = "override_any"))]
        let mut override_error = None;

        let source =
            self.override_fn_regex
                .replace_all(&source, |cap: &regex::Captures| {
                    // group 3: target module (as written in the source, decoded below),
                    // group 2: overridden function name
                    let target_module = cap.get(3).unwrap().as_str().to_owned();
                    let target_function = cap.get(2).unwrap().as_str().to_owned();

                    #[cfg(not(feature = "override_any"))]
                    {
                        let wrap_err = |inner: ComposerErrorInner| -> ComposerError {
                            ComposerError {
                                inner,
                                source: ErrSource::Module {
                                    name: module_definition.name.to_owned(),
                                    offset: 0,
                                    defs: shader_defs.clone(),
                                },
                            }
                        };

                        // ensure overrides are applied to virtual functions
                        let raw_module_name = Self::decode(&target_module);
                        let module_set = self.module_sets.get(&raw_module_name);

                        match module_set {
                            None => {
                                // TODO this should be unreachable?
                                let pos = cap.get(3).unwrap().start();
                                override_error = Some(wrap_err(
                                    ComposerErrorInner::ImportNotFound(raw_module_name, pos),
                                ));
                            }
                            Some(module_set) => {
                                // panics if the target module has not been built for these
                                // shader defs
                                let module = module_set.get_module(shader_defs).unwrap();
                                if !module.virtual_functions.contains(&target_function) {
                                    let pos = cap.get(2).unwrap().start();
                                    override_error =
                                        Some(wrap_err(ComposerErrorInner::OverrideNotVirtual {
                                            name: target_function.clone(),
                                            pos,
                                        }));
                                }
                            }
                        }
                    }

                    let base_name = format!(
                        "{}{}{}{}",
                        target_function.as_str(),
                        DECORATION_PRE,
                        target_module.as_str(),
                        DECORATION_POST,
                    );
                    let rename = format!(
                        "{}{}{}{}",
                        target_function.as_str(),
                        DECORATION_OVERRIDE_PRE,
                        target_module.as_str(),
                        DECORATION_POST,
                    );

                    let replacement_str = format!(
                        "{}fn {}{}(",
                        " ".repeat(cap.get(1).unwrap().range().len() - 3),
                        rename,
                        " ".repeat(cap.get(4).unwrap().range().len()),
                    );

                    local_override_functions.insert(rename, base_name);

                    replacement_str
                })
                .to_string();

        #[cfg(not(feature = "override_any"))]
        if let Some(err) = override_error {
            return Err(err);
        }

        trace!("local overrides: {:?}", local_override_functions);
        trace!(
            "create composable module {}: source len {}",
            module_definition.name,
            source.len()
        );

        let IrBuildResult {
            module: mut source_ir,
            start_offset,
            mut override_functions,
        } = self.create_module_ir(
            &module_definition.name,
            source,
            module_definition.language,
            &imports,
            shader_defs,
            wgsl_directives,
        )?;

        // from here on errors need to be reported using the modified source with start_offset
        let wrap_err = |inner: ComposerErrorInner| -> ComposerError {
            ComposerError {
                inner,
                source: ErrSource::Module {
                    name: module_definition.name.to_owned(),
                    offset: start_offset,
                    defs: shader_defs.clone(),
                },
            }
        };

        // add our local override to the total set of overrides for the given function
        for (rename, base_name) in &local_override_functions {
            override_functions
                .entry(base_name.clone())
                .or_default()
                .push(format!("{rename}{module_decoration}"));
        }

        // rename and record owned items (except types which can't be mutably accessed)
        // an item is treated as owned when its name does not already contain DECORATION_PRE,
        // i.e. it did not come from an imported module
        let mut owned_constants = IndexMap::new();
        for (h, c) in source_ir.constants.iter_mut() {
            if let Some(name) = c.name.as_mut() {
                if !name.contains(DECORATION_PRE) {
                    *name = format!("{name}{module_decoration}");
                    owned_constants.insert(name.clone(), h);
                }
            }
        }

        // These are naga/wgpu's pipeline override constants, not naga_oil's overrides
        let mut owned_pipeline_overrides = IndexMap::new();
        for (h, po) in source_ir.overrides.iter_mut() {
            if let Some(name) = po.name.as_mut() {
                if !name.contains(DECORATION_PRE) {
                    *name = format!("{name}{module_decoration}");
                    owned_pipeline_overrides.insert(name.clone(), h);
                }
            }
        }

        let mut owned_vars = IndexMap::new();
        for (h, gv) in source_ir.global_variables.iter_mut() {
            if let Some(name) = gv.name.as_mut() {
                if !name.contains(DECORATION_PRE) {
                    *name = format!("{name}{module_decoration}");

                    owned_vars.insert(name.clone(), h);
                }
            }
        }

        // owned functions map: decorated name -> (source handle, header stub)
        let mut owned_functions = IndexMap::new();
        for (h_f, f) in source_ir.functions.iter_mut() {
            if let Some(name) = f.name.as_mut() {
                if !name.contains(DECORATION_PRE) {
                    *name = format!("{name}{module_decoration}");

                    // create dummy header function
                    // (same name, arguments and result; no locals, expressions or body)
                    let header_function = naga::Function {
                        name: Some(name.clone()),
                        arguments: f.arguments.to_vec(),
                        result: f.result.clone(),
                        local_variables: Default::default(),
                        expressions: Default::default(),
                        named_expressions: Default::default(),
                        body: Default::default(),
                        diagnostic_filter_leaf: None,
                    };

                    // record owned function
                    owned_functions.insert(name.clone(), (Some(h_f), header_function));
                }
            }
        }

        if demote_entrypoints {
            // make normal functions out of the source entry points
            // entry points have no handle in `source_ir.functions`, hence `None` below
            for ep in &mut source_ir.entry_points {
                ep.function.name = Some(format!(
                    "{}{}",
                    ep.function.name.as_deref().unwrap_or("main"),
                    module_decoration,
                ));
                // the header stub drops argument/result bindings
                let header_function = naga::Function {
                    name: ep.function.name.clone(),
                    arguments: ep
                        .function
                        .arguments
                        .iter()
                        .cloned()
                        .map(|arg| naga::FunctionArgument {
                            name: arg.name,
                            ty: arg.ty,
                            binding: None,
                        })
                        .collect(),
                    result: ep.function.result.clone().map(|res| naga::FunctionResult {
                        ty: res.ty,
                        binding: None,
                    }),
                    local_variables: Default::default(),
                    expressions: Default::default(),
                    named_expressions: Default::default(),
                    body: Default::default(),
                    diagnostic_filter_leaf: None,
                };

                owned_functions.insert(ep.function.name.clone().unwrap(), (None, header_function));
            }
        };

        // two derived modules are built from the same (renamed) source IR: the full module and
        // a header containing only owned declarations and function stubs
        let mut module_builder = DerivedModule::default();
        let mut header_builder = DerivedModule::default();
        module_builder.set_shader_source(&source_ir, 0);
        header_builder.set_shader_source(&source_ir, 0);

        // gather special types to exclude from owned types
        let mut special_types: HashSet<&naga::Handle<naga::Type>> = HashSet::new();
        special_types.extend(source_ir.special_types.predeclared_types.values());
        special_types.extend(
            [
                source_ir.special_types.ray_desc.as_ref(),
                source_ir.special_types.ray_intersection.as_ref(),
                source_ir.special_types.ray_vertex_return.as_ref(),
            ]
            .iter()
            .flatten(),
        );

        // as the header of imports that use special types includes the special type definitions explicitly,
        // we also exclude anything with a name matching the known special type names
        let special_type_names = special_types
            .iter()
            .flat_map(|h| source_ir.types.get_handle(**h).unwrap().name.clone())
            .collect::<HashSet<_>>();

        let mut owned_types = HashSet::new();
        for (h, ty) in source_ir.types.iter() {
            if let Some(name) = &ty.name {
                // we exclude any special types, these are added back later
                if special_types.contains(&h) || special_type_names.contains(name) {
                    continue;
                }

                if !name.contains(DECORATION_PRE) {
                    let name = format!("{name}{module_decoration}");
                    owned_types.insert(name.clone());
                    // copy and rename types
                    module_builder.rename_type(&h, Some(name.clone()));
                    header_builder.rename_type(&h, Some(name));
                    continue;
                }
            }

            // copy all required types
            // (unnamed or imported types go into the module only, not the header)
            module_builder.import_type(&h);
        }

        // copy owned constants, pipeline overrides and globals into header and module
        for h in owned_constants.values() {
            header_builder.import_const(h);
            module_builder.import_const(h);
        }

        for h in owned_pipeline_overrides.values() {
            header_builder.import_pipeline_override(h);
            module_builder.import_pipeline_override(h);
        }

        for h in owned_vars.values() {
            header_builder.import_global(h);
            module_builder.import_global(h);
        }

        // only stubs of owned functions into the header
        for (h_f, f) in owned_functions.values() {
            let span = h_f
                .map(|h_f| source_ir.functions.get_span(h_f))
                .unwrap_or(naga::Span::UNDEFINED);
            header_builder.import_function(f, span); // header stub function
        }
        // all functions into the module (note source_ir only contains stubs for imported functions)
        for (h_f, f) in source_ir.functions.iter() {
            let span = source_ir.functions.get_span(h_f);
            module_builder.import_function(f, span);
        }
        // including entry points as vanilla functions if required (bindings stripped)
        if demote_entrypoints {
            for ep in &source_ir.entry_points {
                let mut f = ep.function.clone();
                f.arguments = f
                    .arguments
                    .into_iter()
                    .map(|arg| naga::FunctionArgument {
                        name: arg.name,
                        ty: arg.ty,
                        binding: None,
                    })
                    .collect();
                f.result = f.result.map(|res| naga::FunctionResult {
                    ty: res.ty,
                    binding: None,
                });

                module_builder.import_function(&f, naga::Span::UNDEFINED);
                // todo figure out how to get span info for entrypoints
            }
        }

        // must be queried before `module_builder` is consumed
        let has_special_types = module_builder.has_required_special_types();
        let module_ir = module_builder.into_module_with_entrypoints();
        let mut header_ir: naga::Module = header_builder.into();

        // note: we cannot validate when special types are used, as writeback isn't supported
        if self.validate && create_headers && !has_special_types {
            // check that identifiers haven't been renamed
            // (render the header for each language and check owned names survive re-parsing)
            // single_element_loop: only Wgsl is listed when the `glsl` feature is off
            #[allow(clippy::single_element_loop)]
            for language in [
                ShaderLanguage::Wgsl,
                #[cfg(feature = "glsl")]
                ShaderLanguage::Glsl,
            ] {
                let header = self
                    .naga_to_string(&mut header_ir, language, &module_definition.name)
                    .map_err(wrap_err)?;
                Self::validate_identifiers(
                    &source_ir,
                    language,
                    &header,
                    &module_decoration,
                    &owned_types,
                )
                .map_err(wrap_err)?;
            }
        }

        let composable_module = ComposableModule {
            decorated_name: module_decoration,
            imports,
            owned_types,
            owned_constants: owned_constants.into_keys().collect(),
            owned_vars: owned_vars.into_keys().collect(),
            owned_functions: owned_functions.into_keys().collect(),
            virtual_functions,
            override_functions,
            module_ir,
            header_ir,
            start_offset,
        };

        Ok(composable_module)
    }
1245
1246    // shunt all data owned by a composable into a derived module
1247    fn add_composable_data<'a>(
1248        derived: &mut DerivedModule<'a>,
1249        composable: &'a ComposableModule,
1250        items: Option<&Vec<String>>,
1251        span_offset: usize,
1252        header: bool,
1253    ) {
1254        let items: Option<HashSet<String>> = items.map(|items| {
1255            items
1256                .iter()
1257                .map(|item| format!("{}{}", item, composable.decorated_name))
1258                .collect()
1259        });
1260        let items = items.as_ref();
1261
1262        let source_ir = match header {
1263            true => &composable.header_ir,
1264            false => &composable.module_ir,
1265        };
1266
1267        derived.set_shader_source(source_ir, span_offset);
1268
1269        for (h, ty) in source_ir.types.iter() {
1270            if let Some(name) = &ty.name {
1271                if composable.owned_types.contains(name)
1272                    && items.is_none_or(|items| items.contains(name))
1273                {
1274                    derived.import_type(&h);
1275                }
1276            }
1277        }
1278
1279        for (h, c) in source_ir.constants.iter() {
1280            if let Some(name) = &c.name {
1281                if composable.owned_constants.contains(name)
1282                    && items.is_none_or(|items| items.contains(name))
1283                {
1284                    derived.import_const(&h);
1285                }
1286            }
1287        }
1288
1289        for (h, po) in source_ir.overrides.iter() {
1290            if let Some(name) = &po.name {
1291                if composable.owned_functions.contains(name)
1292                    && items.is_none_or(|items| items.contains(name))
1293                {
1294                    derived.import_pipeline_override(&h);
1295                }
1296            }
1297        }
1298
1299        for (h, v) in source_ir.global_variables.iter() {
1300            if let Some(name) = &v.name {
1301                if composable.owned_vars.contains(name)
1302                    && items.is_none_or(|items| items.contains(name))
1303                {
1304                    derived.import_global(&h);
1305                }
1306            }
1307        }
1308
1309        for (h_f, f) in source_ir.functions.iter() {
1310            if let Some(name) = &f.name {
1311                if composable.owned_functions.contains(name)
1312                    && (items.is_none_or(|items| items.contains(name))
1313                        || composable
1314                            .override_functions
1315                            .values()
1316                            .any(|v| v.contains(name)))
1317                {
1318                    let span = composable.module_ir.functions.get_span(h_f);
1319                    derived.import_function_if_new(f, span);
1320                }
1321            }
1322        }
1323
1324        derived.clear_shader_source();
1325    }
1326
1327    // add an import (and recursive imports) into a derived module
1328    fn add_import<'a>(
1329        &'a self,
1330        derived: &mut DerivedModule<'a>,
1331        import: &ImportDefinition,
1332        shader_defs: &HashMap<String, ShaderDefValue>,
1333        header: bool,
1334        already_added: &mut HashSet<String>,
1335    ) {
1336        if already_added.contains(&import.import) {
1337            trace!("skipping {}, already added", import.import);
1338            return;
1339        }
1340
1341        let import_module_set = self.module_sets.get(&import.import).unwrap();
1342        let module = import_module_set.get_module(shader_defs).unwrap();
1343
1344        for import in &module.imports {
1345            self.add_import(derived, import, shader_defs, header, already_added);
1346        }
1347
1348        Self::add_composable_data(
1349            derived,
1350            module,
1351            Some(&import.items),
1352            import_module_set.module_index << SPAN_SHIFT,
1353            header,
1354        );
1355    }
1356
1357    fn ensure_import(
1358        &mut self,
1359        module_set: &ComposableModuleDefinition,
1360        shader_defs: &HashMap<String, ShaderDefValue>,
1361    ) -> Result<ComposableModule, EnsureImportsError> {
1362        let PreprocessOutput {
1363            preprocessed_source,
1364            imports,
1365        } = self
1366            .preprocessor
1367            .preprocess(&module_set.sanitized_source, shader_defs)
1368            .map_err(|inner| {
1369                EnsureImportsError::from(ComposerError {
1370                    inner,
1371                    source: ErrSource::Module {
1372                        name: module_set.name.to_owned(),
1373                        offset: 0,
1374                        defs: shader_defs.clone(),
1375                    },
1376                })
1377            })?;
1378
1379        self.ensure_imports(imports.iter().map(|import| &import.definition), shader_defs)?;
1380        self.ensure_imports(&module_set.additional_imports, shader_defs)?;
1381
1382        self.create_composable_module(
1383            module_set,
1384            Self::decorate(&module_set.name),
1385            shader_defs,
1386            true,
1387            true,
1388            &preprocessed_source,
1389            imports,
1390            None,
1391        )
1392        .map_err(|err| err.into())
1393    }
1394
1395    // build required ComposableModules for a given set of shader_defs
1396    fn ensure_imports<'a>(
1397        &mut self,
1398        imports: impl IntoIterator<Item = &'a ImportDefinition>,
1399        shader_defs: &HashMap<String, ShaderDefValue>,
1400    ) -> Result<(), EnsureImportsError> {
1401        for ImportDefinition { import, .. } in imports.into_iter() {
1402            let Some(module_set) = self.module_sets.get(import) else {
1403                return Err(EnsureImportsError::MissingImport(import.to_owned()));
1404            };
1405            if module_set.get_module(shader_defs).is_some() {
1406                continue;
1407            }
1408
1409            // we need to build the module
1410            // take the set so we can recurse without borrowing
1411            let (set_key, mut module_set) = self.module_sets.remove_entry(import).unwrap();
1412
1413            match self.ensure_import(&module_set, shader_defs) {
1414                Ok(module) => {
1415                    module_set.insert_module(shader_defs, module);
1416                    self.module_sets.insert(set_key, module_set);
1417                }
1418                Err(e) => {
1419                    self.module_sets.insert(set_key, module_set);
1420                    return Err(e);
1421                }
1422            }
1423        }
1424
1425        Ok(())
1426    }
1427}
1428
/// Failure while making sure the modules required by a set of imports are built.
pub enum EnsureImportsError {
    /// An import named a module that has not been added to the composer (carries the import name).
    MissingImport(String),
    /// Building one of the required modules failed.
    ComposerError(ComposerError),
}
1433
1434impl EnsureImportsError {
1435    fn into_composer_error(self, err_source: ErrSource) -> ComposerError {
1436        match self {
1437            EnsureImportsError::MissingImport(import) => ComposerError {
1438                inner: ComposerErrorInner::ImportNotFound(import.to_owned(), 0),
1439                source: err_source,
1440            },
1441            EnsureImportsError::ComposerError(err) => err,
1442        }
1443    }
1444}
1445
1446impl From<ComposerError> for EnsureImportsError {
1447    fn from(value: ComposerError) -> Self {
1448        EnsureImportsError::ComposerError(value)
1449    }
1450}
1451
/// Input to [`Composer::add_composable_module`].
#[derive(Default)]
pub struct ComposableModuleDescriptor<'a> {
    /// Module source text. It is rejected if it contains naga_oil's internal decoration strings.
    pub source: &'a str,
    /// Path reported in errors raised while the module definition is being added.
    pub file_path: &'a str,
    /// Language the module source is written in.
    pub language: ShaderLanguage,
    /// Name to register the module under. When `Some`, it takes precedence over the name
    /// declared in the source's `#define_import_path`.
    pub as_name: Option<String>,
    /// Extra imports, treated as if they were declared in the source (at offset 0).
    pub additional_imports: &'a [ImportDefinition],
    /// Shader defs supplied with the module definition.
    /// NOTE(review): how these combine with per-build defs is handled outside this view; confirm.
    pub shader_defs: HashMap<String, ShaderDefValue>,
}
1461
/// Describes a top-level shader to be composed into a naga module.
/// NOTE(review): presumably consumed by `Composer::make_naga_module` (not visible here); confirm.
#[derive(Default)]
pub struct NagaModuleDescriptor<'a> {
    /// Shader source text.
    pub source: &'a str,
    /// Path of the shader, presumably used when reporting errors (confirm).
    pub file_path: &'a str,
    /// Kind/language of the shader source.
    pub shader_type: ShaderType,
    /// Shader defs, presumably used to evaluate the preprocessor conditionals (confirm).
    pub shader_defs: HashMap<String, ShaderDefValue>,
    /// Extra imports, presumably treated as if they were declared in the source (confirm).
    pub additional_imports: &'a [ImportDefinition],
}
1470
1471// public api
1472impl Composer {
1473    /// create a non-validating composer.
1474    /// validation errors in the final shader will not be caught, and errors resulting from their
1475    /// use will have bad span data, so codespan reporting will fail.
1476    /// use default() to create a validating composer.
1477    pub fn non_validating() -> Self {
1478        Self {
1479            validate: false,
1480            ..Default::default()
1481        }
1482    }
1483
1484    /// specify capabilities to be used for naga module generation.
1485    /// purges any existing modules
1486    /// See https://github.com/gfx-rs/wgpu/blob/d9c054c645af0ea9ef81617c3e762fbf0f3fecda/wgpu-core/src/device/mod.rs#L515
1487    /// for how to set the subgroup_stages value.
1488    pub fn with_capabilities(self, capabilities: naga::valid::Capabilities) -> Self {
1489        Self {
1490            capabilities,
1491            validate: self.validate,
1492            ..Default::default()
1493        }
1494    }
1495
1496    /// check if a module with the given name has been added
1497    pub fn contains_module(&self, module_name: &str) -> bool {
1498        self.module_sets.contains_key(module_name)
1499    }
1500
1501    /// add a composable module to the composer.
1502    /// all modules imported by this module must already have been added
1503    pub fn add_composable_module(
1504        &mut self,
1505        desc: ComposableModuleDescriptor,
1506    ) -> Result<&ComposableModuleDefinition, ComposerError> {
1507        let ComposableModuleDescriptor {
1508            source,
1509            file_path,
1510            language,
1511            as_name,
1512            additional_imports,
1513            mut shader_defs,
1514        } = desc;
1515
1516        // reject a module containing the DECORATION strings
1517        if let Some(decor) = self.check_decoration_regex.find(source) {
1518            return Err(ComposerError {
1519                inner: ComposerErrorInner::DecorationInSource(decor.range()),
1520                source: ErrSource::Constructing {
1521                    path: file_path.to_owned(),
1522                    source: source.to_owned(),
1523                    offset: 0,
1524                },
1525            });
1526        }
1527
1528        let substituted_source = self.sanitize_and_set_auto_bindings(source);
1529
1530        let PreprocessorMetaData {
1531            name: module_name,
1532            mut imports,
1533            mut effective_defs,
1534            cleaned_source,
1535            ..
1536        } = self
1537            .preprocessor
1538            .get_preprocessor_metadata(&substituted_source, false)
1539            .map_err(|inner| ComposerError {
1540                inner,
1541                source: ErrSource::Constructing {
1542                    path: file_path.to_owned(),
1543                    source: source.to_owned(),
1544                    offset: 0,
1545                },
1546            })?;
1547        let module_name = as_name.or(module_name);
1548        if module_name.is_none() {
1549            return Err(ComposerError {
1550                inner: ComposerErrorInner::NoModuleName,
1551                source: ErrSource::Constructing {
1552                    path: file_path.to_owned(),
1553                    source: source.to_owned(),
1554                    offset: 0,
1555                },
1556            });
1557        }
1558        let module_name = module_name.unwrap();
1559
1560        debug!(
1561            "adding module definition for {} with defs: {:?}",
1562            module_name, shader_defs
1563        );
1564
1565        // add custom imports
1566        let additional_imports = additional_imports.to_vec();
1567        imports.extend(
1568            additional_imports
1569                .iter()
1570                .cloned()
1571                .map(|def| ImportDefWithOffset {
1572                    definition: def,
1573                    offset: 0,
1574                }),
1575        );
1576
1577        for import in &imports {
1578            // we require modules already added so that we can capture the shader_defs that may impact us by impacting our dependencies
1579            let module_set = self
1580                .module_sets
1581                .get(&import.definition.import)
1582                .ok_or_else(|| ComposerError {
1583                    inner: ComposerErrorInner::ImportNotFound(
1584                        import.definition.import.clone(),
1585                        import.offset,
1586                    ),
1587                    source: ErrSource::Constructing {
1588                        path: file_path.to_owned(),
1589                        source: substituted_source.to_owned(),
1590                        offset: 0,
1591                    },
1592                })?;
1593            effective_defs.extend(module_set.effective_defs.iter().cloned());
1594            shader_defs.extend(
1595                module_set
1596                    .shader_defs
1597                    .iter()
1598                    .map(|def| (def.0.clone(), *def.1)),
1599            );
1600        }
1601
1602        // remove defs that are already specified through our imports
1603        effective_defs.retain(|name| !shader_defs.contains_key(name));
1604
1605        // can't gracefully report errors for more modules. perhaps this should be a warning
1606        assert!((self.module_sets.len() as u32) < u32::MAX >> SPAN_SHIFT);
1607        let module_index = self.module_sets.len() + 1;
1608
1609        let module_set = ComposableModuleDefinition {
1610            name: module_name.clone(),
1611            sanitized_source: cleaned_source,
1612            file_path: file_path.to_owned(),
1613            language,
1614            effective_defs: effective_defs.into_iter().collect(),
1615            all_imports: imports.into_iter().map(|id| id.definition.import).collect(),
1616            additional_imports,
1617            shader_defs,
1618            module_index,
1619            modules: Default::default(),
1620        };
1621
1622        // invalidate dependent modules if this module already exists
1623        self.remove_composable_module(&module_name);
1624
1625        self.module_sets.insert(module_name.clone(), module_set);
1626        self.module_index.insert(module_index, module_name.clone());
1627        Ok(self.module_sets.get(&module_name).unwrap())
1628    }
1629
1630    /// remove a composable module. also removes modules that depend on this module, as we cannot be sure about
1631    /// the completeness of their effective shader defs any more...
1632    pub fn remove_composable_module(&mut self, module_name: &str) {
1633        // todo this could be improved by making effective defs an Option<HashSet> and populating on demand?
1634        let mut dependent_sets = Vec::new();
1635
1636        if self.module_sets.remove(module_name).is_some() {
1637            dependent_sets.extend(self.module_sets.iter().filter_map(|(dependent_name, set)| {
1638                if set.all_imports.contains(module_name) {
1639                    Some(dependent_name.clone())
1640                } else {
1641                    None
1642                }
1643            }));
1644        }
1645
1646        for dependent_set in dependent_sets {
1647            self.remove_composable_module(&dependent_set);
1648        }
1649    }
1650
1651    /// build a naga shader module
1652    pub fn make_naga_module(
1653        &mut self,
1654        desc: NagaModuleDescriptor,
1655    ) -> Result<naga::Module, ComposerError> {
1656        let NagaModuleDescriptor {
1657            source,
1658            file_path,
1659            shader_type,
1660            mut shader_defs,
1661            additional_imports,
1662        } = desc;
1663
1664        let sanitized_source = self.sanitize_and_set_auto_bindings(source);
1665
1666        let PreprocessorMetaData {
1667            name,
1668            defines,
1669            wgsl_directives,
1670            cleaned_source,
1671            ..
1672        } = self
1673            .preprocessor
1674            .get_preprocessor_metadata(&sanitized_source, true)
1675            .map_err(|inner| ComposerError {
1676                inner,
1677                source: ErrSource::Constructing {
1678                    path: file_path.to_owned(),
1679                    source: sanitized_source.to_owned(),
1680                    offset: 0,
1681                },
1682            })?;
1683        shader_defs.extend(defines);
1684
1685        let name = name.unwrap_or_default();
1686
1687        let PreprocessOutput { imports, .. } = self
1688            .preprocessor
1689            .preprocess(&cleaned_source, &shader_defs)
1690            .map_err(|inner| ComposerError {
1691                inner,
1692                source: ErrSource::Constructing {
1693                    path: file_path.to_owned(),
1694                    source: sanitized_source.to_owned(),
1695                    offset: 0,
1696                },
1697            })?;
1698
1699        // make sure imports have been added
1700        // and gather additional defs specified at module level
1701        for (import_name, offset) in imports
1702            .iter()
1703            .map(|id| (&id.definition.import, id.offset))
1704            .chain(additional_imports.iter().map(|ai| (&ai.import, 0)))
1705        {
1706            if let Some(module_set) = self.module_sets.get(import_name) {
1707                for (def, value) in &module_set.shader_defs {
1708                    if let Some(prior_value) = shader_defs.insert(def.clone(), *value) {
1709                        if prior_value != *value {
1710                            return Err(ComposerError {
1711                                inner: ComposerErrorInner::InconsistentShaderDefValue {
1712                                    def: def.clone(),
1713                                },
1714                                source: ErrSource::Constructing {
1715                                    path: file_path.to_owned(),
1716                                    source: sanitized_source.to_owned(),
1717                                    offset: 0,
1718                                },
1719                            });
1720                        }
1721                    }
1722                }
1723            } else {
1724                return Err(ComposerError {
1725                    inner: ComposerErrorInner::ImportNotFound(import_name.clone(), offset),
1726                    source: ErrSource::Constructing {
1727                        path: file_path.to_owned(),
1728                        source: sanitized_source.to_owned(),
1729                        offset: 0,
1730                    },
1731                });
1732            }
1733        }
1734        self.ensure_imports(
1735            imports.iter().map(|import| &import.definition),
1736            &shader_defs,
1737        )
1738        .map_err(|err| {
1739            err.into_composer_error(ErrSource::Constructing {
1740                path: file_path.to_owned(),
1741                source: sanitized_source.to_owned(),
1742                offset: 0,
1743            })
1744        })?;
1745        self.ensure_imports(additional_imports, &shader_defs)
1746            .map_err(|err| {
1747                err.into_composer_error(ErrSource::Constructing {
1748                    path: file_path.to_owned(),
1749                    source: sanitized_source.to_owned(),
1750                    offset: 0,
1751                })
1752            })?;
1753
1754        let definition = ComposableModuleDefinition {
1755            name,
1756            sanitized_source: cleaned_source.clone(),
1757            language: shader_type.into(),
1758            file_path: file_path.to_owned(),
1759            module_index: 0,
1760            additional_imports: additional_imports.to_vec(),
1761            // we don't care about these for creating a top-level module
1762            effective_defs: Default::default(),
1763            all_imports: Default::default(),
1764            shader_defs: Default::default(),
1765            modules: Default::default(),
1766        };
1767
1768        let PreprocessOutput {
1769            preprocessed_source,
1770            imports,
1771        } = self
1772            .preprocessor
1773            .preprocess(&cleaned_source, &shader_defs)
1774            .map_err(|inner| ComposerError {
1775                inner,
1776                source: ErrSource::Constructing {
1777                    path: file_path.to_owned(),
1778                    source: sanitized_source,
1779                    offset: 0,
1780                },
1781            })?;
1782
1783        let composable = self
1784            .create_composable_module(
1785                &definition,
1786                String::from(""),
1787                &shader_defs,
1788                false,
1789                false,
1790                &preprocessed_source,
1791                imports,
1792                Some(wgsl_directives),
1793            )
1794            .map_err(|e| ComposerError {
1795                inner: e.inner,
1796                source: ErrSource::Constructing {
1797                    path: definition.file_path.to_owned(),
1798                    source: preprocessed_source.clone(),
1799                    offset: e.source.offset(),
1800                },
1801            })?;
1802
1803        let mut derived = DerivedModule::default();
1804
1805        let mut already_added = Default::default();
1806        for import in &composable.imports {
1807            self.add_import(
1808                &mut derived,
1809                import,
1810                &shader_defs,
1811                false,
1812                &mut already_added,
1813            );
1814        }
1815
1816        Self::add_composable_data(&mut derived, &composable, None, 0, false);
1817
1818        let stage = match shader_type {
1819            #[cfg(feature = "glsl")]
1820            ShaderType::GlslVertex => Some(naga::ShaderStage::Vertex),
1821            #[cfg(feature = "glsl")]
1822            ShaderType::GlslFragment => Some(naga::ShaderStage::Fragment),
1823            _ => None,
1824        };
1825
1826        let mut entry_points = Vec::default();
1827        derived.set_shader_source(&composable.module_ir, 0);
1828        for ep in &composable.module_ir.entry_points {
1829            let mapped_func = derived.localize_function(&ep.function);
1830            entry_points.push(EntryPoint {
1831                name: ep.name.clone(),
1832                function: mapped_func,
1833                stage: stage.unwrap_or(ep.stage),
1834                early_depth_test: ep.early_depth_test,
1835                workgroup_size: ep.workgroup_size,
1836                workgroup_size_overrides: ep.workgroup_size_overrides,
1837            });
1838        }
1839        let mut naga_module = naga::Module {
1840            entry_points,
1841            ..derived.into()
1842        };
1843
1844        // apply overrides
1845        if !composable.override_functions.is_empty() {
1846            let mut redirect = Redirector::new(naga_module);
1847
1848            for (base_function, overrides) in composable.override_functions {
1849                let mut omit = HashSet::default();
1850
1851                let mut original = base_function;
1852                for replacement in overrides {
1853                    let (_h_orig, _h_replace) = redirect
1854                        .redirect_function(&original, &replacement, &omit)
1855                        .map_err(|e| ComposerError {
1856                            inner: e.into(),
1857                            source: ErrSource::Constructing {
1858                                path: file_path.to_owned(),
1859                                source: preprocessed_source.clone(),
1860                                offset: composable.start_offset,
1861                            },
1862                        })?;
1863                    omit.insert(replacement.clone());
1864                    original = replacement;
1865                }
1866            }
1867
1868            naga_module = redirect.into_module().map_err(|e| ComposerError {
1869                inner: e.into(),
1870                source: ErrSource::Constructing {
1871                    path: file_path.to_owned(),
1872                    source: preprocessed_source.clone(),
1873                    offset: composable.start_offset,
1874                },
1875            })?;
1876        }
1877
1878        // validation
1879        if self.validate {
1880            let info = self.create_validator().validate(&naga_module);
1881            match info {
1882                Ok(_) => Ok(naga_module),
1883                Err(e) => {
1884                    let original_span = e.spans().last();
1885                    let err_source = match original_span.and_then(|(span, _)| span.to_range()) {
1886                        Some(rng) => {
1887                            let module_index = rng.start >> SPAN_SHIFT;
1888                            match module_index {
1889                                0 => ErrSource::Constructing {
1890                                    path: file_path.to_owned(),
1891                                    source: preprocessed_source.clone(),
1892                                    offset: composable.start_offset,
1893                                },
1894                                _ => {
1895                                    let module_name =
1896                                        self.module_index.get(&module_index).unwrap().clone();
1897                                    let offset = self
1898                                        .module_sets
1899                                        .get(&module_name)
1900                                        .unwrap()
1901                                        .get_module(&shader_defs)
1902                                        .unwrap()
1903                                        .start_offset;
1904                                    ErrSource::Module {
1905                                        name: module_name,
1906                                        offset,
1907                                        defs: shader_defs.clone(),
1908                                    }
1909                                }
1910                            }
1911                        }
1912                        None => ErrSource::Constructing {
1913                            path: file_path.to_owned(),
1914                            source: preprocessed_source.clone(),
1915                            offset: composable.start_offset,
1916                        },
1917                    };
1918
1919                    Err(ComposerError {
1920                        inner: ComposerErrorInner::ShaderValidationError(e),
1921                        source: err_source,
1922                    })
1923                }
1924            }
1925        } else {
1926            Ok(naga_module)
1927        }
1928    }
1929}
1930
1931static PREPROCESSOR: LazyLock<Preprocessor> = LazyLock::new(Preprocessor::default);
1932
1933/// Get module name and all required imports (ignoring shader_defs) from a shader string
1934pub fn get_preprocessor_data(
1935    source: &str,
1936) -> (
1937    Option<String>,
1938    Vec<ImportDefinition>,
1939    HashMap<String, ShaderDefValue>,
1940) {
1941    if let Ok(PreprocessorMetaData {
1942        name,
1943        imports,
1944        defines,
1945        ..
1946    }) = PREPROCESSOR.get_preprocessor_metadata(source, true)
1947    {
1948        (
1949            name,
1950            imports
1951                .into_iter()
1952                .map(|import_with_offset| import_with_offset.definition)
1953                .collect(),
1954            defines,
1955        )
1956    } else {
1957        // if errors occur we return nothing; the actual error will be displayed when the caller attempts to use the shader
1958        Default::default()
1959    }
1960}