naga_oil/compose/
mod.rs

1use indexmap::IndexMap;
2/// the compose module allows construction of shaders from modules (which are themselves shaders).
3///
4/// it does this by treating shaders as modules, and
5/// - building each module independently to naga IR
6/// - creating "header" files for each supported language, which are used to build dependent modules/shaders
7/// - making final shaders by combining the shader IR with the IR for imported modules
8///
9/// for multiple small shaders with large common imports, this can be faster than parsing the full source for each shader, and it allows for constructing shaders in a cleaner modular manner with better scope control.
10///
11/// ## imports
12///
/// shaders can be added to the composer as modules. this makes their types, constants, variables and functions available to modules/shaders that import them. note that importing a module will affect the final shader's global state if the module defines global variables with bindings.
14///
15/// modules must include a `#define_import_path` directive that names the module.
16///
17/// ```ignore
18/// #define_import_path my_module
19///
20/// fn my_func() -> f32 {
21///     return 1.0;
22/// }
23/// ```
24///
25/// shaders can then import the module with an `#import` directive (with an optional `as` name). at point of use, imported items must be qualified:
26///
27/// ```ignore
28/// #import my_module
29/// #import my_other_module as Mod2
30///
31/// fn main() -> f32 {
32///     let x = my_module::my_func();
33///     let y = Mod2::my_other_func();
34///     return x*y;
35/// }
36/// ```
37///
38/// or import a comma-separated list of individual items with a `#from` directive. at point of use, imported items must be prefixed with `::` :
39///
40/// ```ignore
41/// #from my_module import my_func, my_const
42///
43/// fn main() -> f32 {
44///     return ::my_func(::my_const);
45/// }
46/// ```
47///
48/// imports can be nested - modules may import other modules, but not recursively. when a new module is added, all its `#import`s must already have been added.
49/// the same module can be imported multiple times by different modules in the import tree.
50/// there is no overlap of namespaces, so the same function names (or type, constant, or variable names) may be used in different modules.
51///
/// note: when importing an item with the `#from` directive, the final shader will include the required dependencies (bindings, globals, consts, other functions) of the imported item, but will not include the rest of the imported module. it will, however, still include the whole of any modules imported by the imported module. this is probably not desired in general and may be fixed in a future version. currently, for a more complete culling of unused dependencies, the `prune` module can be used.
53///
54/// ## overriding functions
55///
56/// virtual functions can be declared with the `virtual` keyword:
57/// ```ignore
58/// virtual fn point_light(world_position: vec3<f32>) -> vec3<f32> { ... }
59/// ```
60/// virtual functions defined in imported modules can then be overridden using the `override` keyword:
61///
62/// ```ignore
63/// #import bevy_pbr::lighting as Lighting
64///
65/// override fn Lighting::point_light (world_position: vec3<f32>) -> vec3<f32> {
66///     let original = Lighting::point_light(world_position);
67///     let quantized = vec3<u32>(original * 3.0);
68///     return vec3<f32>(quantized) / 3.0;
69/// }
70/// ```
71///
72/// override function definitions cause *all* calls to the original function in the entire shader scope to be replaced by calls to the new function, with the exception of calls within the override function itself.
73///
74/// the function signature of the override must match the base function.
75///
76/// overrides can be specified at any point in the final shader's import tree.
77///
78/// multiple overrides can be applied to the same function. for example, given :
79/// - a module `a` containing a function `f`,
80/// - a module `b` that imports `a`, and containing an `override a::f` function,
81/// - a module `c` that imports `a` and `b`, and containing an `override a::f` function,
82///
83/// then b and c both specify an override for `a::f`.
/// the `override fn a::f` declared in module `b` may call `a::f` within its body.
/// the `override fn a::f` declared in module `c` may call `a::f` within its body, but the call will be redirected to `b::f`.
/// any other calls to `a::f` (within modules `a` or `b`, or anywhere else) will end up redirected to `c::f`.
87/// in this way a chain or stack of overrides can be applied.
88///
89/// different overrides of the same function can be specified in different import branches. the final stack will be ordered based on the first occurrence of the override in the import tree (using a depth first search).
90///
91/// note that imports into a module/shader are processed in order, but are processed before the body of the current shader/module regardless of where they occur in that module, so there is no way to import a module containing an override and inject a call into the override stack prior to that imported override. you can instead create two modules each containing an override and import them into a parent module/shader to order them as required.
92/// override functions can currently only be defined in wgsl.
93///
94/// if the `override_any` crate feature is enabled, then the `virtual` keyword is not required for the function being overridden.
95///
96/// ## languages
97///
/// modules can be written in GLSL or WGSL. shaders with entry points can be imported as modules (provided they have a `#define_import_path` directive). entry points are available to call from imported modules either via their name (for WGSL) or via `module::main` (for GLSL).
99///
/// final shaders can also be written in GLSL or WGSL. for GLSL, users must specify whether the shader is a vertex shader or fragment shader via the `ShaderType` argument (GLSL compute shaders are not supported).
101///
102/// ## preprocessing
103///
104/// when generating a final shader or adding a composable module, a set of `shader_def` string/value pairs must be provided. The value can be a bool (`ShaderDefValue::Bool`), an i32 (`ShaderDefValue::Int`) or a u32 (`ShaderDefValue::UInt`).
105///
106/// these allow conditional compilation of parts of modules and the final shader. conditional compilation is performed with `#if` / `#ifdef` / `#ifndef`, `#else` and `#endif` preprocessor directives:
107///
108/// ```ignore
109/// fn get_number() -> f32 {
110///     #ifdef BIG_NUMBER
111///         return 999.0;
112///     #else
113///         return 0.999;
114///     #endif
115/// }
116/// ```
/// the `#ifdef` directive matches when the def name exists in the provided shader defs (regardless of value). the `#ifndef` directive is the reverse.
118///
119/// the `#if` directive requires a def name, an operator, and a value for comparison:
120/// - the def name must be a provided `shader_def` name.
121/// - the operator must be one of `==`, `!=`, `>=`, `>`, `<`, `<=`
/// - the value must be an integer literal if comparing to a `ShaderDefValue::Int`, or `true` or `false` if comparing to a `ShaderDefValue::Bool`.
123///
124/// shader defs can also be used in the shader source with `#SHADER_DEF` or `#{SHADER_DEF}`, and will be substituted for their value.
125///
126/// ## error reporting
127///
128/// codespan reporting for errors is available using the error `emit_to_string` method. this requires validation to be enabled, which is true by default. `Composer::non_validating()` produces a non-validating composer that is not able to give accurate error reporting.
129///
130use naga::EntryPoint;
131use regex::Regex;
132use std::collections::{hash_map::Entry, BTreeMap, HashMap, HashSet};
133use tracing::{debug, trace};
134
135use crate::{
136    compose::preprocess::{PreprocessOutput, PreprocessorMetaData},
137    derive::DerivedModule,
138    redirect::Redirector,
139};
140
141pub use self::error::{ComposerError, ComposerErrorInner, ErrSource};
142use self::preprocess::Preprocessor;
143
144pub mod comment_strip_iter;
145pub mod error;
146pub mod parse_imports;
147pub mod preprocess;
148mod test;
149pub mod tokenizer;
150
/// Source language of a module or shader.
///
/// WGSL is always available and is the default; GLSL needs the `glsl` crate feature.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShaderLanguage {
    /// WGSL source (the default).
    #[default]
    Wgsl,
    /// GLSL source.
    #[cfg(feature = "glsl")]
    Glsl,
}
158
/// The kind of final shader being composed.
///
/// GLSL sources must state whether they are vertex or fragment shaders (GLSL compute
/// shaders are not supported); the GLSL variants need the `glsl` crate feature.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShaderType {
    /// A WGSL shader (the default).
    #[default]
    Wgsl,
    /// A GLSL vertex shader.
    #[cfg(feature = "glsl")]
    GlslVertex,
    /// A GLSL fragment shader.
    #[cfg(feature = "glsl")]
    GlslFragment,
}
168
169impl From<ShaderType> for ShaderLanguage {
170    fn from(ty: ShaderType) -> Self {
171        match ty {
172            ShaderType::Wgsl => ShaderLanguage::Wgsl,
173            #[cfg(feature = "glsl")]
174            ShaderType::GlslVertex | ShaderType::GlslFragment => ShaderLanguage::Glsl,
175        }
176    }
177}
178
/// Value bound to a shader def: a boolean flag, a signed integer or an unsigned integer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShaderDefValue {
    /// Boolean flag.
    Bool(bool),
    /// Signed 32-bit integer.
    Int(i32),
    /// Unsigned 32-bit integer.
    UInt(u32),
}
185
186impl Default for ShaderDefValue {
187    fn default() -> Self {
188        ShaderDefValue::Bool(true)
189    }
190}
191
192impl ShaderDefValue {
193    fn value_as_string(&self) -> String {
194        match self {
195            ShaderDefValue::Bool(val) => val.to_string(),
196            ShaderDefValue::Int(val) => val.to_string(),
197            ShaderDefValue::UInt(val) => val.to_string(),
198        }
199    }
200}
201
202#[derive(Clone, PartialEq, Eq, Hash, Debug, Default)]
203pub struct OwnedShaderDefs(BTreeMap<String, ShaderDefValue>);
204
205#[derive(Clone, PartialEq, Eq, Hash, Debug)]
206struct ModuleKey(OwnedShaderDefs);
207
208impl ModuleKey {
209    fn from_members(key: &HashMap<String, ShaderDefValue>, universe: &[String]) -> Self {
210        let mut acc = OwnedShaderDefs::default();
211        for item in universe {
212            if let Some(value) = key.get(item) {
213                acc.0.insert(item.to_owned(), *value);
214            }
215        }
216        ModuleKey(acc)
217    }
218}
219
220// a module built with a specific set of shader_defs
221#[derive(Default, Debug)]
222pub struct ComposableModule {
223    // module decoration, prefixed to all items from this module in the final source
224    pub decorated_name: String,
225    // module names required as imports, optionally with a list of items to import
226    pub imports: Vec<ImportDefinition>,
227    // types exported
228    pub owned_types: HashSet<String>,
229    // constants exported
230    pub owned_constants: HashSet<String>,
231    // vars exported
232    pub owned_vars: HashSet<String>,
233    // functions exported
234    pub owned_functions: HashSet<String>,
235    // local functions that can be overridden
236    pub virtual_functions: HashSet<String>,
237    // overriding functions defined in this module
238    // target function -> Vec<replacement functions>
239    pub override_functions: IndexMap<String, Vec<String>>,
240    // naga module, built against headers for any imports
241    module_ir: naga::Module,
242    // headers in different shader languages, used for building modules/shaders that import this module
243    // headers contain types, constants, global vars and empty function definitions -
244    // just enough to convert source strings that want to import this module into naga IR
245    // headers: HashMap<ShaderLanguage, String>,
246    header_ir: naga::Module,
247    // character offset of the start of the owned module string
248    start_offset: usize,
249}
250
251// data used to build a ComposableModule
252#[derive(Debug)]
253pub struct ComposableModuleDefinition {
254    pub name: String,
255    // shader text (with auto bindings replaced - we do this on module add as we only want to do it once to avoid burning slots)
256    pub sanitized_source: String,
257    // language
258    pub language: ShaderLanguage,
259    // source path for error display
260    pub file_path: String,
261    // shader def values bound to this module
262    pub shader_defs: HashMap<String, ShaderDefValue>,
263    // list of shader_defs that can affect this module
264    effective_defs: Vec<String>,
265    // full list of possible imports (regardless of shader_def configuration)
266    all_imports: HashSet<String>,
267    // additional imports to add (as though they were included in the source after any other imports)
268    additional_imports: Vec<ImportDefinition>,
269    // built composable modules for a given set of shader defs
270    modules: HashMap<ModuleKey, ComposableModule>,
271    // used in spans when this module is included
272    module_index: usize,
273    // preprocessor meta data
274    // metadata: PreprocessorMetaData,
275}
276
277impl ComposableModuleDefinition {
278    fn get_module(
279        &self,
280        shader_defs: &HashMap<String, ShaderDefValue>,
281    ) -> Option<&ComposableModule> {
282        self.modules
283            .get(&ModuleKey::from_members(shader_defs, &self.effective_defs))
284    }
285
286    fn insert_module(
287        &mut self,
288        shader_defs: &HashMap<String, ShaderDefValue>,
289        module: ComposableModule,
290    ) -> &ComposableModule {
291        match self
292            .modules
293            .entry(ModuleKey::from_members(shader_defs, &self.effective_defs))
294        {
295            Entry::Occupied(_) => panic!("entry already populated"),
296            Entry::Vacant(v) => v.insert(module),
297        }
298    }
299}
300
/// A single import: the imported module's name and, optionally, a list of specific
/// items to import from it.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct ImportDefinition {
    /// Name of the imported module.
    pub import: String,
    /// Specific items to import, if any.
    pub items: Vec<String>,
}
306
307#[derive(Debug, Clone)]
308pub struct ImportDefWithOffset {
309    definition: ImportDefinition,
310    offset: usize,
311}
312
313/// module composer.
314/// stores any modules that can be imported into a shader
315/// and builds the final shader
316#[derive(Debug)]
317pub struct Composer {
318    pub validate: bool,
319    pub module_sets: HashMap<String, ComposableModuleDefinition>,
320    pub module_index: HashMap<usize, String>,
321    pub capabilities: naga::valid::Capabilities,
322    preprocessor: Preprocessor,
323    check_decoration_regex: Regex,
324    undecorate_regex: Regex,
325    virtual_fn_regex: Regex,
326    override_fn_regex: Regex,
327    undecorate_override_regex: Regex,
328    auto_binding_regex: Regex,
329    auto_binding_index: u32,
330}
331
/// Shift applied to a module's index when it is combined into span offsets.
///
/// A shift of 21 gives:
/// - a maximum shader size of 2^21 (~2M) characters
/// - assuming 32-bit span offsets, a maximum of 2^(32 - 21) = 2048 modules
const SPAN_SHIFT: usize = 21;
337
338impl Default for Composer {
339    fn default() -> Self {
340        Self {
341            validate: true,
342            capabilities: Default::default(),
343            module_sets: Default::default(),
344            module_index: Default::default(),
345            preprocessor: Preprocessor::default(),
346            check_decoration_regex: Regex::new(
347                format!(
348                    "({}|{})",
349                    regex_syntax::escape(DECORATION_PRE),
350                    regex_syntax::escape(DECORATION_OVERRIDE_PRE)
351                )
352                .as_str(),
353            )
354            .unwrap(),
355            undecorate_regex: Regex::new(
356                format!(
357                    r"(\x1B\[\d+\w)?([\w\d_]+){}([A-Z0-9]*){}",
358                    regex_syntax::escape(DECORATION_PRE),
359                    regex_syntax::escape(DECORATION_POST)
360                )
361                .as_str(),
362            )
363            .unwrap(),
364            virtual_fn_regex: Regex::new(
365                r"(?P<lead>[\s]*virtual\s+fn\s+)(?P<function>[^\s]+)(?P<trail>\s*)\(",
366            )
367            .unwrap(),
368            override_fn_regex: Regex::new(
369                format!(
370                    r"(override\s+fn\s+)([^\s]+){}([\w\d]+){}(\s*)\(",
371                    regex_syntax::escape(DECORATION_PRE),
372                    regex_syntax::escape(DECORATION_POST)
373                )
374                .as_str(),
375            )
376            .unwrap(),
377            undecorate_override_regex: Regex::new(
378                format!(
379                    "{}([A-Z0-9]*){}",
380                    regex_syntax::escape(DECORATION_OVERRIDE_PRE),
381                    regex_syntax::escape(DECORATION_POST)
382                )
383                .as_str(),
384            )
385            .unwrap(),
386            auto_binding_regex: Regex::new(r"@binding\(auto\)").unwrap(),
387            auto_binding_index: 0,
388        }
389    }
390}
391
/// Marker placed between an item name and its base32-encoded module name.
const DECORATION_PRE: &str = "X_naga_oil_mod_X";
/// Terminates the encoded module name in a decorated identifier.
const DECORATION_POST: &str = "X";

/// Marker used instead of `DECORATION_PRE` for override functions.
///
/// Must be the same length as `DECORATION_PRE` for spans to work. The assertion below
/// turns a violation into a compile error instead of silently corrupted spans.
const DECORATION_OVERRIDE_PRE: &str = "X_naga_oil_vrt_X";

const _: () = assert!(
    DECORATION_OVERRIDE_PRE.len() == DECORATION_PRE.len(),
    "DECORATION_OVERRIDE_PRE must be the same length as DECORATION_PRE for spans to work"
);
397
398struct IrBuildResult {
399    module: naga::Module,
400    start_offset: usize,
401    override_functions: IndexMap<String, Vec<String>>,
402}
403
404impl Composer {
405    pub fn decorated_name(module_name: Option<&str>, item_name: &str) -> String {
406        match module_name {
407            Some(module_name) => format!("{}{}", item_name, Self::decorate(module_name)),
408            None => item_name.to_owned(),
409        }
410    }
411
412    fn decorate(module: &str) -> String {
413        let encoded = data_encoding::BASE32_NOPAD.encode(module.as_bytes());
414        format!("{DECORATION_PRE}{encoded}{DECORATION_POST}")
415    }
416
417    fn decode(from: &str) -> String {
418        String::from_utf8(data_encoding::BASE32_NOPAD.decode(from.as_bytes()).unwrap()).unwrap()
419    }
420
421    /// Shorthand for creating a naga validator.
422    fn create_validator(&self) -> naga::valid::Validator {
423        naga::valid::Validator::new(naga::valid::ValidationFlags::all(), self.capabilities)
424    }
425
426    fn undecorate(&self, string: &str) -> String {
427        let undecor = self
428            .undecorate_regex
429            .replace_all(string, |caps: &regex::Captures| {
430                format!(
431                    "{}{}::{}",
432                    caps.get(1).map(|cc| cc.as_str()).unwrap_or(""),
433                    Self::decode(caps.get(3).unwrap().as_str()),
434                    caps.get(2).unwrap().as_str()
435                )
436            });
437
438        let undecor =
439            self.undecorate_override_regex
440                .replace_all(&undecor, |caps: &regex::Captures| {
441                    format!(
442                        "override fn {}::",
443                        Self::decode(caps.get(1).unwrap().as_str())
444                    )
445                });
446
447        undecor.to_string()
448    }
449
450    fn sanitize_and_set_auto_bindings(&mut self, source: &str) -> String {
451        let mut substituted_source = source.replace("\r\n", "\n").replace('\r', "\n");
452        if !substituted_source.ends_with('\n') {
453            substituted_source.push('\n');
454        }
455
456        // replace @binding(auto) with an incrementing index
457        struct AutoBindingReplacer<'a> {
458            auto: &'a mut u32,
459        }
460
461        impl regex::Replacer for AutoBindingReplacer<'_> {
462            fn replace_append(&mut self, _: &regex::Captures<'_>, dst: &mut String) {
463                dst.push_str(&format!("@binding({})", self.auto));
464                *self.auto += 1;
465            }
466        }
467
468        let substituted_source = self.auto_binding_regex.replace_all(
469            &substituted_source,
470            AutoBindingReplacer {
471                auto: &mut self.auto_binding_index,
472            },
473        );
474
475        substituted_source.into_owned()
476    }
477
478    fn naga_to_string(
479        &self,
480        naga_module: &mut naga::Module,
481        language: ShaderLanguage,
482        #[allow(unused)] header_for: &str, // Only used when GLSL is enabled
483    ) -> Result<String, ComposerErrorInner> {
484        // TODO: cache headers again
485        let info = self
486            .create_validator()
487            .validate(naga_module)
488            .map_err(ComposerErrorInner::HeaderValidationError)?;
489
490        match language {
491            ShaderLanguage::Wgsl => naga::back::wgsl::write_string(
492                naga_module,
493                &info,
494                naga::back::wgsl::WriterFlags::EXPLICIT_TYPES,
495            )
496            .map_err(ComposerErrorInner::WgslBackError),
497            #[cfg(feature = "glsl")]
498            ShaderLanguage::Glsl => {
499                let vec4 = naga_module.types.insert(
500                    naga::Type {
501                        name: None,
502                        inner: naga::TypeInner::Vector {
503                            size: naga::VectorSize::Quad,
504                            scalar: naga::Scalar::F32,
505                        },
506                    },
507                    naga::Span::UNDEFINED,
508                );
509                // add a dummy entry point for glsl headers
510                let dummy_entry_point = "dummy_module_entry_point".to_owned();
511                let func = naga::Function {
512                    name: Some(dummy_entry_point.clone()),
513                    arguments: Default::default(),
514                    result: Some(naga::FunctionResult {
515                        ty: vec4,
516                        binding: Some(naga::Binding::BuiltIn(naga::BuiltIn::Position {
517                            invariant: false,
518                        })),
519                    }),
520                    local_variables: Default::default(),
521                    expressions: Default::default(),
522                    named_expressions: Default::default(),
523                    body: Default::default(),
524                    diagnostic_filter_leaf: Default::default(),
525                };
526                let ep = EntryPoint {
527                    name: dummy_entry_point.clone(),
528                    stage: naga::ShaderStage::Vertex,
529                    function: func,
530                    early_depth_test: None,
531                    workgroup_size: [0, 0, 0],
532                    workgroup_size_overrides: None,
533                };
534
535                naga_module.entry_points.push(ep);
536
537                let info = self
538                    .create_validator()
539                    .validate(naga_module)
540                    .map_err(ComposerErrorInner::HeaderValidationError)?;
541
542                let mut string = String::new();
543                let options = naga::back::glsl::Options {
544                    version: naga::back::glsl::Version::Desktop(450),
545                    writer_flags: naga::back::glsl::WriterFlags::INCLUDE_UNUSED_ITEMS,
546                    ..Default::default()
547                };
548                let pipeline_options = naga::back::glsl::PipelineOptions {
549                    shader_stage: naga::ShaderStage::Vertex,
550                    entry_point: dummy_entry_point,
551                    multiview: None,
552                };
553                let mut writer = naga::back::glsl::Writer::new(
554                    &mut string,
555                    naga_module,
556                    &info,
557                    &options,
558                    &pipeline_options,
559                    naga::proc::BoundsCheckPolicies::default(),
560                )
561                .map_err(ComposerErrorInner::GlslBackError)?;
562
563                writer.write().map_err(ComposerErrorInner::GlslBackError)?;
564
565                // strip version decl and main() impl
566                let lines: Vec<_> = string.lines().collect();
567                let string = lines[1..lines.len() - 3].join("\n");
568                trace!("glsl header for {}:\n\"\n{:?}\n\"", header_for, string);
569
570                Ok(string)
571            }
572        }
573    }
574
575    // build naga module for a given shader_def configuration. builds a minimal self-contained module built against headers for imports
576    fn create_module_ir(
577        &self,
578        name: &str,
579        source: String,
580        language: ShaderLanguage,
581        imports: &[ImportDefinition],
582        shader_defs: &HashMap<String, ShaderDefValue>,
583    ) -> Result<IrBuildResult, ComposerError> {
584        debug!("creating IR for {} with defs: {:?}", name, shader_defs);
585
586        let mut module_string = match language {
587            ShaderLanguage::Wgsl => String::new(),
588            #[cfg(feature = "glsl")]
589            ShaderLanguage::Glsl => String::from("#version 450\n"),
590        };
591
592        let mut override_functions: IndexMap<String, Vec<String>> = IndexMap::default();
593        let mut added_imports: HashSet<String> = HashSet::new();
594        let mut header_module = DerivedModule::default();
595
596        for import in imports {
597            if added_imports.contains(&import.import) {
598                continue;
599            }
600            // add to header module
601            self.add_import(
602                &mut header_module,
603                import,
604                shader_defs,
605                true,
606                &mut added_imports,
607            );
608
609            // // we must have ensured these exist with Composer::ensure_imports()
610            trace!("looking for {}", import.import);
611            let import_module_set = self.module_sets.get(&import.import).unwrap();
612            trace!("with defs {:?}", shader_defs);
613            let module = import_module_set.get_module(shader_defs).unwrap();
614            trace!("ok");
615
616            // gather overrides
617            if !module.override_functions.is_empty() {
618                for (original, replacements) in &module.override_functions {
619                    match override_functions.entry(original.clone()) {
620                        indexmap::map::Entry::Occupied(o) => {
621                            let existing = o.into_mut();
622                            let new_replacements: Vec<_> = replacements
623                                .iter()
624                                .filter(|rep| !existing.contains(rep))
625                                .cloned()
626                                .collect();
627                            existing.extend(new_replacements);
628                        }
629                        indexmap::map::Entry::Vacant(v) => {
630                            v.insert(replacements.clone());
631                        }
632                    }
633                }
634            }
635        }
636
637        let composed_header = self
638            .naga_to_string(&mut header_module.into(), language, name)
639            .map_err(|inner| ComposerError {
640                inner,
641                source: ErrSource::Module {
642                    name: name.to_owned(),
643                    offset: 0,
644                    defs: shader_defs.clone(),
645                },
646            })?;
647        module_string.push_str(&composed_header);
648
649        let start_offset = module_string.len();
650
651        module_string.push_str(&source);
652
653        trace!(
654            "parsing {}: {}, header len {}, total len {}",
655            name,
656            module_string,
657            start_offset,
658            module_string.len()
659        );
660        let module = match language {
661            ShaderLanguage::Wgsl => naga::front::wgsl::parse_str(&module_string).map_err(|e| {
662                debug!("full err'd source file: \n---\n{}\n---", module_string);
663                ComposerError {
664                    inner: ComposerErrorInner::WgslParseError(e),
665                    source: ErrSource::Module {
666                        name: name.to_owned(),
667                        offset: start_offset,
668                        defs: shader_defs.clone(),
669                    },
670                }
671            })?,
672            #[cfg(feature = "glsl")]
673            ShaderLanguage::Glsl => naga::front::glsl::Frontend::default()
674                .parse(
675                    &naga::front::glsl::Options {
676                        stage: naga::ShaderStage::Vertex,
677                        defines: Default::default(),
678                    },
679                    &module_string,
680                )
681                .map_err(|e| {
682                    debug!("full err'd source file: \n---\n{}\n---", module_string);
683                    ComposerError {
684                        inner: ComposerErrorInner::GlslParseError(e),
685                        source: ErrSource::Module {
686                            name: name.to_owned(),
687                            offset: start_offset,
688                            defs: shader_defs.clone(),
689                        },
690                    }
691                })?,
692        };
693
694        Ok(IrBuildResult {
695            module,
696            start_offset,
697            override_functions,
698        })
699    }
700
701    // check that identifiers exported by a module do not get modified in string export
702    fn validate_identifiers(
703        source_ir: &naga::Module,
704        lang: ShaderLanguage,
705        header: &str,
706        module_decoration: &str,
707        owned_types: &HashSet<String>,
708    ) -> Result<(), ComposerErrorInner> {
709        // TODO: remove this once glsl front support is complete
710        #[cfg(feature = "glsl")]
711        if lang == ShaderLanguage::Glsl {
712            return Ok(());
713        }
714
715        let recompiled = match lang {
716            ShaderLanguage::Wgsl => naga::front::wgsl::parse_str(header).unwrap(),
717            #[cfg(feature = "glsl")]
718            ShaderLanguage::Glsl => naga::front::glsl::Frontend::default()
719                .parse(
720                    &naga::front::glsl::Options {
721                        stage: naga::ShaderStage::Vertex,
722                        defines: Default::default(),
723                    },
724                    &format!("{}\n{}", header, "void main() {}"),
725                )
726                .map_err(|e| {
727                    debug!("full err'd source file: \n---\n{header}\n---");
728                    ComposerErrorInner::GlslParseError(e)
729                })?,
730        };
731
732        let recompiled_types: IndexMap<_, _> = recompiled
733            .types
734            .iter()
735            .flat_map(|(h, ty)| ty.name.as_deref().map(|name| (name, h)))
736            .collect();
737        for (h, ty) in source_ir.types.iter() {
738            if let Some(name) = &ty.name {
739                let decorated_type_name = format!("{name}{module_decoration}");
740                if !owned_types.contains(&decorated_type_name) {
741                    continue;
742                }
743                match recompiled_types.get(decorated_type_name.as_str()) {
744                    Some(recompiled_h) => {
745                        if let naga::TypeInner::Struct { members, .. } = &ty.inner {
746                            let recompiled_ty = recompiled.types.get_handle(*recompiled_h).unwrap();
747                            let naga::TypeInner::Struct {
748                                members: recompiled_members,
749                                ..
750                            } = &recompiled_ty.inner
751                            else {
752                                panic!();
753                            };
754                            for (member, recompiled_member) in
755                                members.iter().zip(recompiled_members)
756                            {
757                                if member.name != recompiled_member.name {
758                                    return Err(ComposerErrorInner::InvalidIdentifier {
759                                        original: member.name.clone().unwrap_or_default(),
760                                        at: source_ir.types.get_span(h),
761                                    });
762                                }
763                            }
764                        }
765                    }
766                    None => {
767                        return Err(ComposerErrorInner::InvalidIdentifier {
768                            original: name.clone(),
769                            at: source_ir.types.get_span(h),
770                        })
771                    }
772                }
773            }
774        }
775
776        let recompiled_consts: HashSet<_> = recompiled
777            .constants
778            .iter()
779            .flat_map(|(_, c)| c.name.as_deref())
780            .filter(|name| name.ends_with(module_decoration))
781            .collect();
782        for (h, c) in source_ir.constants.iter() {
783            if let Some(name) = &c.name {
784                if name.ends_with(module_decoration) && !recompiled_consts.contains(name.as_str()) {
785                    return Err(ComposerErrorInner::InvalidIdentifier {
786                        original: name.clone(),
787                        at: source_ir.constants.get_span(h),
788                    });
789                }
790            }
791        }
792
793        let recompiled_globals: HashSet<_> = recompiled
794            .global_variables
795            .iter()
796            .flat_map(|(_, c)| c.name.as_deref())
797            .filter(|name| name.ends_with(module_decoration))
798            .collect();
799        for (h, gv) in source_ir.global_variables.iter() {
800            if let Some(name) = &gv.name {
801                if name.ends_with(module_decoration) && !recompiled_globals.contains(name.as_str())
802                {
803                    return Err(ComposerErrorInner::InvalidIdentifier {
804                        original: name.clone(),
805                        at: source_ir.global_variables.get_span(h),
806                    });
807                }
808            }
809        }
810
811        let recompiled_fns: HashSet<_> = recompiled
812            .functions
813            .iter()
814            .flat_map(|(_, c)| c.name.as_deref())
815            .filter(|name| name.ends_with(module_decoration))
816            .collect();
817        for (h, f) in source_ir.functions.iter() {
818            if let Some(name) = &f.name {
819                if name.ends_with(module_decoration) && !recompiled_fns.contains(name.as_str()) {
820                    return Err(ComposerErrorInner::InvalidIdentifier {
821                        original: name.clone(),
822                        at: source_ir.functions.get_span(h),
823                    });
824                }
825            }
826        }
827
828        Ok(())
829    }
830
    // build a ComposableModule from a ComposableModuleDefinition, for a given set of shader defs
    // - build the naga IR (against headers)
    // - record any types/vars/constants/functions that are defined within this module
    // - build headers for each supported language
    //
    // parameters:
    // - `module_decoration`: suffix appended to the name of every item this module owns, i.e.
    //   every named item whose name does not already contain `DECORATION_PRE`
    // - `create_headers`: together with `self.validate`, gates the header round-trip check that
    //   calls `validate_identifiers` (the check is also skipped when special types are required)
    // - `demote_entrypoints`: when true, entry points are renamed with `module_decoration` and
    //   emitted as ordinary functions (bindings stripped) in both the header and the module IR
    // - `source`: the module source to compile
    // - `imports`: imports found in the source; `module_definition.additional_imports` are
    //   appended to them
    //
    // override errors detected before the IR is built are reported at offset 0; errors raised
    // after `create_module_ir` are reported relative to the `start_offset` it returns.
    #[allow(clippy::too_many_arguments)]
    fn create_composable_module(
        &mut self,
        module_definition: &ComposableModuleDefinition,
        module_decoration: String,
        shader_defs: &HashMap<String, ShaderDefValue>,
        create_headers: bool,
        demote_entrypoints: bool,
        source: &str,
        imports: Vec<ImportDefWithOffset>,
    ) -> Result<ComposableModule, ComposerError> {
        // drop the source offsets and add any configured additional imports
        let mut imports: Vec<_> = imports
            .into_iter()
            .map(|import_with_offset| import_with_offset.definition)
            .collect();
        imports.extend(module_definition.additional_imports.to_vec());

        trace!(
            "create composable module {}: source len {}",
            module_definition.name,
            source.len()
        );

        // record virtual/overridable functions
        // each match of `virtual_fn_regex` (group 2 = function name) is rewritten as a plain
        // `fn` declaration; groups 1 and 3 are replaced by text of the same length (spaces plus
        // "fn "), presumably so that source offsets stay valid
        let mut virtual_functions: HashSet<String> = Default::default();
        let source = self
            .virtual_fn_regex
            .replace_all(source, |cap: &regex::Captures| {
                let target_function = cap.get(2).unwrap().as_str().to_owned();

                let replacement_str = format!(
                    "{}fn {}{}(",
                    " ".repeat(cap.get(1).unwrap().range().len() - 3),
                    target_function,
                    " ".repeat(cap.get(3).unwrap().range().len()),
                );

                virtual_functions.insert(target_function);

                replacement_str
            });

        // record and rename override functions
        // each match of `override_fn_regex` (group 2 = target function, group 3 = target
        // module) is renamed to `function + DECORATION_OVERRIDE_PRE + module + DECORATION_POST`;
        // the map stores rename -> decorated name of the function being overridden
        let mut local_override_functions: IndexMap<String, String> = Default::default();

        // the replacement closure cannot return a Result, so an error is stashed here (a later
        // error overwrites an earlier one) and returned once the replacement has finished
        #[cfg(not(feature = "override_any"))]
        let mut override_error = None;

        let source =
            self.override_fn_regex
                .replace_all(&source, |cap: &regex::Captures| {
                    let target_module = cap.get(3).unwrap().as_str().to_owned();
                    let target_function = cap.get(2).unwrap().as_str().to_owned();

                    // without `override_any`, the target module must be known and the target
                    // function must have been declared virtual there
                    #[cfg(not(feature = "override_any"))]
                    {
                        let wrap_err = |inner: ComposerErrorInner| -> ComposerError {
                            ComposerError {
                                inner,
                                source: ErrSource::Module {
                                    name: module_definition.name.to_owned(),
                                    offset: 0,
                                    defs: shader_defs.clone(),
                                },
                            }
                        };

                        // ensure overrides are applied to virtual functions
                        let raw_module_name = Self::decode(&target_module);
                        let module_set = self.module_sets.get(&raw_module_name);

                        match module_set {
                            None => {
                                // TODO this should be unreachable?
                                let pos = cap.get(3).unwrap().start();
                                override_error = Some(wrap_err(
                                    ComposerErrorInner::ImportNotFound(raw_module_name, pos),
                                ));
                            }
                            Some(module_set) => {
                                let module = module_set.get_module(shader_defs).unwrap();
                                if !module.virtual_functions.contains(&target_function) {
                                    let pos = cap.get(2).unwrap().start();
                                    override_error =
                                        Some(wrap_err(ComposerErrorInner::OverrideNotVirtual {
                                            name: target_function.clone(),
                                            pos,
                                        }));
                                }
                            }
                        }
                    }

                    // decorated name of the function being overridden
                    let base_name = format!(
                        "{}{}{}{}",
                        target_function.as_str(),
                        DECORATION_PRE,
                        target_module.as_str(),
                        DECORATION_POST,
                    );
                    // name given to this module's replacement implementation
                    let rename = format!(
                        "{}{}{}{}",
                        target_function.as_str(),
                        DECORATION_OVERRIDE_PRE,
                        target_module.as_str(),
                        DECORATION_POST,
                    );

                    let replacement_str = format!(
                        "{}fn {}{}(",
                        " ".repeat(cap.get(1).unwrap().range().len() - 3),
                        rename,
                        " ".repeat(cap.get(4).unwrap().range().len()),
                    );

                    local_override_functions.insert(rename, base_name);

                    replacement_str
                })
                .to_string();

        #[cfg(not(feature = "override_any"))]
        if let Some(err) = override_error {
            return Err(err);
        }

        trace!("local overrides: {:?}", local_override_functions);
        // source length after the virtual/override rewrites
        trace!(
            "create composable module {}: source len {}",
            module_definition.name,
            source.len()
        );

        // compile the rewritten source; `override_functions` maps the decorated name of an
        // overridden function to the list of functions overriding it
        let IrBuildResult {
            module: mut source_ir,
            start_offset,
            mut override_functions,
        } = self.create_module_ir(
            &module_definition.name,
            source,
            module_definition.language,
            &imports,
            shader_defs,
        )?;

        // from here on errors need to be reported using the modified source with start_offset
        let wrap_err = |inner: ComposerErrorInner| -> ComposerError {
            ComposerError {
                inner,
                source: ErrSource::Module {
                    name: module_definition.name.to_owned(),
                    offset: start_offset,
                    defs: shader_defs.clone(),
                },
            }
        };

        // add our local override to the total set of overrides for the given function
        for (rename, base_name) in &local_override_functions {
            override_functions
                .entry(base_name.clone())
                .or_default()
                .push(format!("{rename}{module_decoration}"));
        }

        // rename and record owned items (except types which can't be mutably accessed)
        // names already containing DECORATION_PRE come from imports and are left untouched;
        // every other named item is owned by this module and gets `module_decoration` appended
        let mut owned_constants = IndexMap::new();
        for (h, c) in source_ir.constants.iter_mut() {
            if let Some(name) = c.name.as_mut() {
                if !name.contains(DECORATION_PRE) {
                    *name = format!("{name}{module_decoration}");
                    owned_constants.insert(name.clone(), h);
                }
            }
        }

        // These are naga/wgpu's pipeline override constants, not naga_oil's overrides
        let mut owned_pipeline_overrides = IndexMap::new();
        for (h, po) in source_ir.overrides.iter_mut() {
            if let Some(name) = po.name.as_mut() {
                if !name.contains(DECORATION_PRE) {
                    *name = format!("{name}{module_decoration}");
                    owned_pipeline_overrides.insert(name.clone(), h);
                }
            }
        }

        let mut owned_vars = IndexMap::new();
        for (h, gv) in source_ir.global_variables.iter_mut() {
            if let Some(name) = gv.name.as_mut() {
                if !name.contains(DECORATION_PRE) {
                    *name = format!("{name}{module_decoration}");

                    owned_vars.insert(name.clone(), h);
                }
            }
        }

        // owned functions map to (source handle, body-less copy used as the header stub)
        let mut owned_functions = IndexMap::new();
        for (h_f, f) in source_ir.functions.iter_mut() {
            if let Some(name) = f.name.as_mut() {
                if !name.contains(DECORATION_PRE) {
                    *name = format!("{name}{module_decoration}");

                    // create dummy header function
                    let header_function = naga::Function {
                        name: Some(name.clone()),
                        arguments: f.arguments.to_vec(),
                        result: f.result.clone(),
                        local_variables: Default::default(),
                        expressions: Default::default(),
                        named_expressions: Default::default(),
                        body: Default::default(),
                        diagnostic_filter_leaf: None,
                    };

                    // record owned function
                    owned_functions.insert(name.clone(), (Some(h_f), header_function));
                }
            }
        }

        if demote_entrypoints {
            // make normal functions out of the source entry points
            // (an unnamed entry point is called "main"); stubs have their bindings removed
            for ep in &mut source_ir.entry_points {
                ep.function.name = Some(format!(
                    "{}{}",
                    ep.function.name.as_deref().unwrap_or("main"),
                    module_decoration,
                ));
                let header_function = naga::Function {
                    name: ep.function.name.clone(),
                    arguments: ep
                        .function
                        .arguments
                        .iter()
                        .cloned()
                        .map(|arg| naga::FunctionArgument {
                            name: arg.name,
                            ty: arg.ty,
                            binding: None,
                        })
                        .collect(),
                    result: ep.function.result.clone().map(|res| naga::FunctionResult {
                        ty: res.ty,
                        binding: None,
                    }),
                    local_variables: Default::default(),
                    expressions: Default::default(),
                    named_expressions: Default::default(),
                    body: Default::default(),
                    diagnostic_filter_leaf: None,
                };

                // no source handle: entry points live outside `source_ir.functions`
                owned_functions.insert(ep.function.name.clone().unwrap(), (None, header_function));
            }
        };

        // module_builder receives the full module IR; header_builder receives only the owned
        // declarations and function stubs
        let mut module_builder = DerivedModule::default();
        let mut header_builder = DerivedModule::default();
        module_builder.set_shader_source(&source_ir, 0);
        header_builder.set_shader_source(&source_ir, 0);

        // gather special types to exclude from owned types
        let mut special_types: HashSet<&naga::Handle<naga::Type>> = HashSet::new();
        special_types.extend(source_ir.special_types.predeclared_types.values());
        special_types.extend(
            [
                source_ir.special_types.ray_desc.as_ref(),
                source_ir.special_types.ray_intersection.as_ref(),
            ]
            .iter()
            .flatten(),
        );

        // as the header of imports that use special types includes the special type definitions explicitly,
        // we also exclude anything with a name matching the known special type names
        let special_type_names = special_types
            .iter()
            .flat_map(|h| source_ir.types.get_handle(**h).unwrap().name.clone())
            .collect::<HashSet<_>>();

        let mut owned_types = HashSet::new();
        for (h, ty) in source_ir.types.iter() {
            if let Some(name) = &ty.name {
                // we exclude any special types, these are added back later
                if special_types.contains(&h) || special_type_names.contains(name) {
                    continue;
                }

                if !name.contains(DECORATION_PRE) {
                    let name = format!("{name}{module_decoration}");
                    owned_types.insert(name.clone());
                    // copy and rename types
                    module_builder.rename_type(&h, Some(name.clone()));
                    header_builder.rename_type(&h, Some(name));
                    continue;
                }
            }

            // copy all required types
            // (unnamed types and types from imports go into the module IR only)
            module_builder.import_type(&h);
        }

        // copy owned constants, pipeline overrides and globals into header and module
        for h in owned_constants.values() {
            header_builder.import_const(h);
            module_builder.import_const(h);
        }

        for h in owned_pipeline_overrides.values() {
            header_builder.import_pipeline_override(h);
            module_builder.import_pipeline_override(h);
        }

        for h in owned_vars.values() {
            header_builder.import_global(h);
            module_builder.import_global(h);
        }

        // only stubs of owned functions into the header
        for (h_f, f) in owned_functions.values() {
            let span = h_f
                .map(|h_f| source_ir.functions.get_span(h_f))
                .unwrap_or(naga::Span::UNDEFINED);
            header_builder.import_function(f, span); // header stub function
        }
        // all functions into the module (note source_ir only contains stubs for imported functions)
        for (h_f, f) in source_ir.functions.iter() {
            let span = source_ir.functions.get_span(h_f);
            module_builder.import_function(f, span);
        }
        // including entry points as vanilla functions if required
        if demote_entrypoints {
            for ep in &source_ir.entry_points {
                let mut f = ep.function.clone();
                f.arguments = f
                    .arguments
                    .into_iter()
                    .map(|arg| naga::FunctionArgument {
                        name: arg.name,
                        ty: arg.ty,
                        binding: None,
                    })
                    .collect();
                f.result = f.result.map(|res| naga::FunctionResult {
                    ty: res.ty,
                    binding: None,
                });

                module_builder.import_function(&f, naga::Span::UNDEFINED);
                // todo figure out how to get span info for entrypoints
            }
        }

        let has_special_types = module_builder.has_required_special_types();
        let module_ir = module_builder.into_module_with_entrypoints();
        // mutable because `naga_to_string` takes the header module by `&mut`
        let mut header_ir: naga::Module = header_builder.into();

        // note: we cannot validate when special types are used, as writeback isn't supported
        if self.validate && create_headers && !has_special_types {
            // check that identifiers haven't been renamed
            // (the header is written out in each language and checked by validate_identifiers)
            // the loop has a single element when the "glsl" feature is disabled
            #[allow(clippy::single_element_loop)]
            for language in [
                ShaderLanguage::Wgsl,
                #[cfg(feature = "glsl")]
                ShaderLanguage::Glsl,
            ] {
                let header = self
                    .naga_to_string(&mut header_ir, language, &module_definition.name)
                    .map_err(wrap_err)?;
                Self::validate_identifiers(
                    &source_ir,
                    language,
                    &header,
                    &module_decoration,
                    &owned_types,
                )
                .map_err(wrap_err)?;
            }
        }

        // note: `owned_pipeline_overrides` is not recorded on the ComposableModule
        let composable_module = ComposableModule {
            decorated_name: module_decoration,
            imports,
            owned_types,
            owned_constants: owned_constants.into_keys().collect(),
            owned_vars: owned_vars.into_keys().collect(),
            owned_functions: owned_functions.into_keys().collect(),
            virtual_functions,
            override_functions,
            module_ir,
            header_ir,
            start_offset,
        };

        Ok(composable_module)
    }
1233
1234    // shunt all data owned by a composable into a derived module
1235    fn add_composable_data<'a>(
1236        derived: &mut DerivedModule<'a>,
1237        composable: &'a ComposableModule,
1238        items: Option<&Vec<String>>,
1239        span_offset: usize,
1240        header: bool,
1241    ) {
1242        let items: Option<HashSet<String>> = items.map(|items| {
1243            items
1244                .iter()
1245                .map(|item| format!("{}{}", item, composable.decorated_name))
1246                .collect()
1247        });
1248        let items = items.as_ref();
1249
1250        let source_ir = match header {
1251            true => &composable.header_ir,
1252            false => &composable.module_ir,
1253        };
1254
1255        derived.set_shader_source(source_ir, span_offset);
1256
1257        for (h, ty) in source_ir.types.iter() {
1258            if let Some(name) = &ty.name {
1259                if composable.owned_types.contains(name)
1260                    && items.is_none_or(|items| items.contains(name))
1261                {
1262                    derived.import_type(&h);
1263                }
1264            }
1265        }
1266
1267        for (h, c) in source_ir.constants.iter() {
1268            if let Some(name) = &c.name {
1269                if composable.owned_constants.contains(name)
1270                    && items.is_none_or(|items| items.contains(name))
1271                {
1272                    derived.import_const(&h);
1273                }
1274            }
1275        }
1276
1277        for (h, po) in source_ir.overrides.iter() {
1278            if let Some(name) = &po.name {
1279                if composable.owned_functions.contains(name)
1280                    && items.is_none_or(|items| items.contains(name))
1281                {
1282                    derived.import_pipeline_override(&h);
1283                }
1284            }
1285        }
1286
1287        for (h, v) in source_ir.global_variables.iter() {
1288            if let Some(name) = &v.name {
1289                if composable.owned_vars.contains(name)
1290                    && items.is_none_or(|items| items.contains(name))
1291                {
1292                    derived.import_global(&h);
1293                }
1294            }
1295        }
1296
1297        for (h_f, f) in source_ir.functions.iter() {
1298            if let Some(name) = &f.name {
1299                if composable.owned_functions.contains(name)
1300                    && (items.is_none_or(|items| items.contains(name))
1301                        || composable
1302                            .override_functions
1303                            .values()
1304                            .any(|v| v.contains(name)))
1305                {
1306                    let span = composable.module_ir.functions.get_span(h_f);
1307                    derived.import_function_if_new(f, span);
1308                }
1309            }
1310        }
1311
1312        derived.clear_shader_source();
1313    }
1314
1315    // add an import (and recursive imports) into a derived module
1316    fn add_import<'a>(
1317        &'a self,
1318        derived: &mut DerivedModule<'a>,
1319        import: &ImportDefinition,
1320        shader_defs: &HashMap<String, ShaderDefValue>,
1321        header: bool,
1322        already_added: &mut HashSet<String>,
1323    ) {
1324        if already_added.contains(&import.import) {
1325            trace!("skipping {}, already added", import.import);
1326            return;
1327        }
1328
1329        let import_module_set = self.module_sets.get(&import.import).unwrap();
1330        let module = import_module_set.get_module(shader_defs).unwrap();
1331
1332        for import in &module.imports {
1333            self.add_import(derived, import, shader_defs, header, already_added);
1334        }
1335
1336        Self::add_composable_data(
1337            derived,
1338            module,
1339            Some(&import.items),
1340            import_module_set.module_index << SPAN_SHIFT,
1341            header,
1342        );
1343    }
1344
1345    fn ensure_import(
1346        &mut self,
1347        module_set: &ComposableModuleDefinition,
1348        shader_defs: &HashMap<String, ShaderDefValue>,
1349    ) -> Result<ComposableModule, EnsureImportsError> {
1350        let PreprocessOutput {
1351            preprocessed_source,
1352            imports,
1353        } = self
1354            .preprocessor
1355            .preprocess(&module_set.sanitized_source, shader_defs)
1356            .map_err(|inner| {
1357                EnsureImportsError::from(ComposerError {
1358                    inner,
1359                    source: ErrSource::Module {
1360                        name: module_set.name.to_owned(),
1361                        offset: 0,
1362                        defs: shader_defs.clone(),
1363                    },
1364                })
1365            })?;
1366
1367        self.ensure_imports(imports.iter().map(|import| &import.definition), shader_defs)?;
1368        self.ensure_imports(&module_set.additional_imports, shader_defs)?;
1369
1370        self.create_composable_module(
1371            module_set,
1372            Self::decorate(&module_set.name),
1373            shader_defs,
1374            true,
1375            true,
1376            &preprocessed_source,
1377            imports,
1378        )
1379        .map_err(|err| err.into())
1380    }
1381
1382    // build required ComposableModules for a given set of shader_defs
1383    fn ensure_imports<'a>(
1384        &mut self,
1385        imports: impl IntoIterator<Item = &'a ImportDefinition>,
1386        shader_defs: &HashMap<String, ShaderDefValue>,
1387    ) -> Result<(), EnsureImportsError> {
1388        for ImportDefinition { import, .. } in imports.into_iter() {
1389            let Some(module_set) = self.module_sets.get(import) else {
1390                return Err(EnsureImportsError::MissingImport(import.to_owned()));
1391            };
1392            if module_set.get_module(shader_defs).is_some() {
1393                continue;
1394            }
1395
1396            // we need to build the module
1397            // take the set so we can recurse without borrowing
1398            let (set_key, mut module_set) = self.module_sets.remove_entry(import).unwrap();
1399
1400            match self.ensure_import(&module_set, shader_defs) {
1401                Ok(module) => {
1402                    module_set.insert_module(shader_defs, module);
1403                    self.module_sets.insert(set_key, module_set);
1404                }
1405                Err(e) => {
1406                    self.module_sets.insert(set_key, module_set);
1407                    return Err(e);
1408                }
1409            }
1410        }
1411
1412        Ok(())
1413    }
1414}
1415
1416pub enum EnsureImportsError {
1417    MissingImport(String),
1418    ComposerError(ComposerError),
1419}
1420
1421impl EnsureImportsError {
1422    fn into_composer_error(self, err_source: ErrSource) -> ComposerError {
1423        match self {
1424            EnsureImportsError::MissingImport(import) => ComposerError {
1425                inner: ComposerErrorInner::ImportNotFound(import.to_owned(), 0),
1426                source: err_source,
1427            },
1428            EnsureImportsError::ComposerError(err) => err,
1429        }
1430    }
1431}
1432
1433impl From<ComposerError> for EnsureImportsError {
1434    fn from(value: ComposerError) -> Self {
1435        EnsureImportsError::ComposerError(value)
1436    }
1437}
1438
1439#[derive(Default)]
1440pub struct ComposableModuleDescriptor<'a> {
1441    pub source: &'a str,
1442    pub file_path: &'a str,
1443    pub language: ShaderLanguage,
1444    pub as_name: Option<String>,
1445    pub additional_imports: &'a [ImportDefinition],
1446    pub shader_defs: HashMap<String, ShaderDefValue>,
1447}
1448
1449#[derive(Default)]
1450pub struct NagaModuleDescriptor<'a> {
1451    pub source: &'a str,
1452    pub file_path: &'a str,
1453    pub shader_type: ShaderType,
1454    pub shader_defs: HashMap<String, ShaderDefValue>,
1455    pub additional_imports: &'a [ImportDefinition],
1456}
1457
1458// public api
1459impl Composer {
1460    /// create a non-validating composer.
1461    /// validation errors in the final shader will not be caught, and errors resulting from their
1462    /// use will have bad span data, so codespan reporting will fail.
1463    /// use default() to create a validating composer.
1464    pub fn non_validating() -> Self {
1465        Self {
1466            validate: false,
1467            ..Default::default()
1468        }
1469    }
1470
1471    /// specify capabilities to be used for naga module generation.
1472    /// purges any existing modules
1473    /// See https://github.com/gfx-rs/wgpu/blob/d9c054c645af0ea9ef81617c3e762fbf0f3fecda/wgpu-core/src/device/mod.rs#L515
1474    /// for how to set the subgroup_stages value.
1475    pub fn with_capabilities(self, capabilities: naga::valid::Capabilities) -> Self {
1476        Self {
1477            capabilities,
1478            validate: self.validate,
1479            ..Default::default()
1480        }
1481    }
1482
1483    /// check if a module with the given name has been added
1484    pub fn contains_module(&self, module_name: &str) -> bool {
1485        self.module_sets.contains_key(module_name)
1486    }
1487
1488    /// add a composable module to the composer.
1489    /// all modules imported by this module must already have been added
1490    pub fn add_composable_module(
1491        &mut self,
1492        desc: ComposableModuleDescriptor,
1493    ) -> Result<&ComposableModuleDefinition, ComposerError> {
1494        let ComposableModuleDescriptor {
1495            source,
1496            file_path,
1497            language,
1498            as_name,
1499            additional_imports,
1500            mut shader_defs,
1501        } = desc;
1502
1503        // reject a module containing the DECORATION strings
1504        if let Some(decor) = self.check_decoration_regex.find(source) {
1505            return Err(ComposerError {
1506                inner: ComposerErrorInner::DecorationInSource(decor.range()),
1507                source: ErrSource::Constructing {
1508                    path: file_path.to_owned(),
1509                    source: source.to_owned(),
1510                    offset: 0,
1511                },
1512            });
1513        }
1514
1515        let substituted_source = self.sanitize_and_set_auto_bindings(source);
1516
1517        let PreprocessorMetaData {
1518            name: module_name,
1519            mut imports,
1520            mut effective_defs,
1521            ..
1522        } = self
1523            .preprocessor
1524            .get_preprocessor_metadata(&substituted_source, false)
1525            .map_err(|inner| ComposerError {
1526                inner,
1527                source: ErrSource::Constructing {
1528                    path: file_path.to_owned(),
1529                    source: source.to_owned(),
1530                    offset: 0,
1531                },
1532            })?;
1533        let module_name = as_name.or(module_name);
1534        if module_name.is_none() {
1535            return Err(ComposerError {
1536                inner: ComposerErrorInner::NoModuleName,
1537                source: ErrSource::Constructing {
1538                    path: file_path.to_owned(),
1539                    source: source.to_owned(),
1540                    offset: 0,
1541                },
1542            });
1543        }
1544        let module_name = module_name.unwrap();
1545
1546        debug!(
1547            "adding module definition for {} with defs: {:?}",
1548            module_name, shader_defs
1549        );
1550
1551        // add custom imports
1552        let additional_imports = additional_imports.to_vec();
1553        imports.extend(
1554            additional_imports
1555                .iter()
1556                .cloned()
1557                .map(|def| ImportDefWithOffset {
1558                    definition: def,
1559                    offset: 0,
1560                }),
1561        );
1562
1563        for import in &imports {
1564            // we require modules already added so that we can capture the shader_defs that may impact us by impacting our dependencies
1565            let module_set = self
1566                .module_sets
1567                .get(&import.definition.import)
1568                .ok_or_else(|| ComposerError {
1569                    inner: ComposerErrorInner::ImportNotFound(
1570                        import.definition.import.clone(),
1571                        import.offset,
1572                    ),
1573                    source: ErrSource::Constructing {
1574                        path: file_path.to_owned(),
1575                        source: substituted_source.to_owned(),
1576                        offset: 0,
1577                    },
1578                })?;
1579            effective_defs.extend(module_set.effective_defs.iter().cloned());
1580            shader_defs.extend(
1581                module_set
1582                    .shader_defs
1583                    .iter()
1584                    .map(|def| (def.0.clone(), *def.1)),
1585            );
1586        }
1587
1588        // remove defs that are already specified through our imports
1589        effective_defs.retain(|name| !shader_defs.contains_key(name));
1590
1591        // can't gracefully report errors for more modules. perhaps this should be a warning
1592        assert!((self.module_sets.len() as u32) < u32::MAX >> SPAN_SHIFT);
1593        let module_index = self.module_sets.len() + 1;
1594
1595        let module_set = ComposableModuleDefinition {
1596            name: module_name.clone(),
1597            sanitized_source: substituted_source,
1598            file_path: file_path.to_owned(),
1599            language,
1600            effective_defs: effective_defs.into_iter().collect(),
1601            all_imports: imports.into_iter().map(|id| id.definition.import).collect(),
1602            additional_imports,
1603            shader_defs,
1604            module_index,
1605            modules: Default::default(),
1606        };
1607
1608        // invalidate dependent modules if this module already exists
1609        self.remove_composable_module(&module_name);
1610
1611        self.module_sets.insert(module_name.clone(), module_set);
1612        self.module_index.insert(module_index, module_name.clone());
1613        Ok(self.module_sets.get(&module_name).unwrap())
1614    }
1615
1616    /// remove a composable module. also removes modules that depend on this module, as we cannot be sure about
1617    /// the completeness of their effective shader defs any more...
1618    pub fn remove_composable_module(&mut self, module_name: &str) {
1619        // todo this could be improved by making effective defs an Option<HashSet> and populating on demand?
1620        let mut dependent_sets = Vec::new();
1621
1622        if self.module_sets.remove(module_name).is_some() {
1623            dependent_sets.extend(self.module_sets.iter().filter_map(|(dependent_name, set)| {
1624                if set.all_imports.contains(module_name) {
1625                    Some(dependent_name.clone())
1626                } else {
1627                    None
1628                }
1629            }));
1630        }
1631
1632        for dependent_set in dependent_sets {
1633            self.remove_composable_module(&dependent_set);
1634        }
1635    }
1636
1637    /// build a naga shader module
1638    pub fn make_naga_module(
1639        &mut self,
1640        desc: NagaModuleDescriptor,
1641    ) -> Result<naga::Module, ComposerError> {
1642        let NagaModuleDescriptor {
1643            source,
1644            file_path,
1645            shader_type,
1646            mut shader_defs,
1647            additional_imports,
1648        } = desc;
1649
1650        let sanitized_source = self.sanitize_and_set_auto_bindings(source);
1651
1652        let PreprocessorMetaData { name, defines, .. } = self
1653            .preprocessor
1654            .get_preprocessor_metadata(&sanitized_source, true)
1655            .map_err(|inner| ComposerError {
1656                inner,
1657                source: ErrSource::Constructing {
1658                    path: file_path.to_owned(),
1659                    source: sanitized_source.to_owned(),
1660                    offset: 0,
1661                },
1662            })?;
1663        shader_defs.extend(defines);
1664
1665        let name = name.unwrap_or_default();
1666
1667        let PreprocessOutput { imports, .. } = self
1668            .preprocessor
1669            .preprocess(&sanitized_source, &shader_defs)
1670            .map_err(|inner| ComposerError {
1671                inner,
1672                source: ErrSource::Constructing {
1673                    path: file_path.to_owned(),
1674                    source: sanitized_source.to_owned(),
1675                    offset: 0,
1676                },
1677            })?;
1678
1679        // make sure imports have been added
1680        // and gather additional defs specified at module level
1681        for (import_name, offset) in imports
1682            .iter()
1683            .map(|id| (&id.definition.import, id.offset))
1684            .chain(additional_imports.iter().map(|ai| (&ai.import, 0)))
1685        {
1686            if let Some(module_set) = self.module_sets.get(import_name) {
1687                for (def, value) in &module_set.shader_defs {
1688                    if let Some(prior_value) = shader_defs.insert(def.clone(), *value) {
1689                        if prior_value != *value {
1690                            return Err(ComposerError {
1691                                inner: ComposerErrorInner::InconsistentShaderDefValue {
1692                                    def: def.clone(),
1693                                },
1694                                source: ErrSource::Constructing {
1695                                    path: file_path.to_owned(),
1696                                    source: sanitized_source.to_owned(),
1697                                    offset: 0,
1698                                },
1699                            });
1700                        }
1701                    }
1702                }
1703            } else {
1704                return Err(ComposerError {
1705                    inner: ComposerErrorInner::ImportNotFound(import_name.clone(), offset),
1706                    source: ErrSource::Constructing {
1707                        path: file_path.to_owned(),
1708                        source: sanitized_source.to_owned(),
1709                        offset: 0,
1710                    },
1711                });
1712            }
1713        }
1714        self.ensure_imports(
1715            imports.iter().map(|import| &import.definition),
1716            &shader_defs,
1717        )
1718        .map_err(|err| {
1719            err.into_composer_error(ErrSource::Constructing {
1720                path: file_path.to_owned(),
1721                source: sanitized_source.to_owned(),
1722                offset: 0,
1723            })
1724        })?;
1725        self.ensure_imports(additional_imports, &shader_defs)
1726            .map_err(|err| {
1727                err.into_composer_error(ErrSource::Constructing {
1728                    path: file_path.to_owned(),
1729                    source: sanitized_source.to_owned(),
1730                    offset: 0,
1731                })
1732            })?;
1733
1734        let definition = ComposableModuleDefinition {
1735            name,
1736            sanitized_source: sanitized_source.clone(),
1737            language: shader_type.into(),
1738            file_path: file_path.to_owned(),
1739            module_index: 0,
1740            additional_imports: additional_imports.to_vec(),
1741            // we don't care about these for creating a top-level module
1742            effective_defs: Default::default(),
1743            all_imports: Default::default(),
1744            shader_defs: Default::default(),
1745            modules: Default::default(),
1746        };
1747
1748        let PreprocessOutput {
1749            preprocessed_source,
1750            imports,
1751        } = self
1752            .preprocessor
1753            .preprocess(&sanitized_source, &shader_defs)
1754            .map_err(|inner| ComposerError {
1755                inner,
1756                source: ErrSource::Constructing {
1757                    path: file_path.to_owned(),
1758                    source: sanitized_source,
1759                    offset: 0,
1760                },
1761            })?;
1762
1763        let composable = self
1764            .create_composable_module(
1765                &definition,
1766                String::from(""),
1767                &shader_defs,
1768                false,
1769                false,
1770                &preprocessed_source,
1771                imports,
1772            )
1773            .map_err(|e| ComposerError {
1774                inner: e.inner,
1775                source: ErrSource::Constructing {
1776                    path: definition.file_path.to_owned(),
1777                    source: preprocessed_source.clone(),
1778                    offset: e.source.offset(),
1779                },
1780            })?;
1781
1782        let mut derived = DerivedModule::default();
1783
1784        let mut already_added = Default::default();
1785        for import in &composable.imports {
1786            self.add_import(
1787                &mut derived,
1788                import,
1789                &shader_defs,
1790                false,
1791                &mut already_added,
1792            );
1793        }
1794
1795        Self::add_composable_data(&mut derived, &composable, None, 0, false);
1796
1797        let stage = match shader_type {
1798            #[cfg(feature = "glsl")]
1799            ShaderType::GlslVertex => Some(naga::ShaderStage::Vertex),
1800            #[cfg(feature = "glsl")]
1801            ShaderType::GlslFragment => Some(naga::ShaderStage::Fragment),
1802            _ => None,
1803        };
1804
1805        let mut entry_points = Vec::default();
1806        derived.set_shader_source(&composable.module_ir, 0);
1807        for ep in &composable.module_ir.entry_points {
1808            let mapped_func = derived.localize_function(&ep.function);
1809            entry_points.push(EntryPoint {
1810                name: ep.name.clone(),
1811                function: mapped_func,
1812                stage: stage.unwrap_or(ep.stage),
1813                early_depth_test: ep.early_depth_test,
1814                workgroup_size: ep.workgroup_size,
1815                workgroup_size_overrides: ep.workgroup_size_overrides,
1816            });
1817        }
1818        let mut naga_module = naga::Module {
1819            entry_points,
1820            ..derived.into()
1821        };
1822
1823        // apply overrides
1824        if !composable.override_functions.is_empty() {
1825            let mut redirect = Redirector::new(naga_module);
1826
1827            for (base_function, overrides) in composable.override_functions {
1828                let mut omit = HashSet::default();
1829
1830                let mut original = base_function;
1831                for replacement in overrides {
1832                    let (_h_orig, _h_replace) = redirect
1833                        .redirect_function(&original, &replacement, &omit)
1834                        .map_err(|e| ComposerError {
1835                            inner: e.into(),
1836                            source: ErrSource::Constructing {
1837                                path: file_path.to_owned(),
1838                                source: preprocessed_source.clone(),
1839                                offset: composable.start_offset,
1840                            },
1841                        })?;
1842                    omit.insert(replacement.clone());
1843                    original = replacement;
1844                }
1845            }
1846
1847            naga_module = redirect.into_module().map_err(|e| ComposerError {
1848                inner: e.into(),
1849                source: ErrSource::Constructing {
1850                    path: file_path.to_owned(),
1851                    source: preprocessed_source.clone(),
1852                    offset: composable.start_offset,
1853                },
1854            })?;
1855        }
1856
1857        // validation
1858        if self.validate {
1859            let info = self.create_validator().validate(&naga_module);
1860            match info {
1861                Ok(_) => Ok(naga_module),
1862                Err(e) => {
1863                    let original_span = e.spans().last();
1864                    let err_source = match original_span.and_then(|(span, _)| span.to_range()) {
1865                        Some(rng) => {
1866                            let module_index = rng.start >> SPAN_SHIFT;
1867                            match module_index {
1868                                0 => ErrSource::Constructing {
1869                                    path: file_path.to_owned(),
1870                                    source: preprocessed_source.clone(),
1871                                    offset: composable.start_offset,
1872                                },
1873                                _ => {
1874                                    let module_name =
1875                                        self.module_index.get(&module_index).unwrap().clone();
1876                                    let offset = self
1877                                        .module_sets
1878                                        .get(&module_name)
1879                                        .unwrap()
1880                                        .get_module(&shader_defs)
1881                                        .unwrap()
1882                                        .start_offset;
1883                                    ErrSource::Module {
1884                                        name: module_name,
1885                                        offset,
1886                                        defs: shader_defs.clone(),
1887                                    }
1888                                }
1889                            }
1890                        }
1891                        None => ErrSource::Constructing {
1892                            path: file_path.to_owned(),
1893                            source: preprocessed_source.clone(),
1894                            offset: composable.start_offset,
1895                        },
1896                    };
1897
1898                    Err(ComposerError {
1899                        inner: ComposerErrorInner::ShaderValidationError(e),
1900                        source: err_source,
1901                    })
1902                }
1903            }
1904        } else {
1905            Ok(naga_module)
1906        }
1907    }
1908}
1909
/// process-wide preprocessor used by the free function `get_preprocessor_data`, which has no
/// `Composer` (and so no composer-owned preprocessor) to call. built lazily on first access
/// via `Preprocessor::default`.
static PREPROCESSOR: once_cell::sync::Lazy<Preprocessor> =
    once_cell::sync::Lazy::new(Preprocessor::default);
1912
1913/// Get module name and all required imports (ignoring shader_defs) from a shader string
1914pub fn get_preprocessor_data(
1915    source: &str,
1916) -> (
1917    Option<String>,
1918    Vec<ImportDefinition>,
1919    HashMap<String, ShaderDefValue>,
1920) {
1921    if let Ok(PreprocessorMetaData {
1922        name,
1923        imports,
1924        defines,
1925        ..
1926    }) = PREPROCESSOR.get_preprocessor_metadata(source, true)
1927    {
1928        (
1929            name,
1930            imports
1931                .into_iter()
1932                .map(|import_with_offset| import_with_offset.definition)
1933                .collect(),
1934            defines,
1935        )
1936    } else {
1937        // if errors occur we return nothing; the actual error will be displayed when the caller attempts to use the shader
1938        Default::default()
1939    }
1940}