naga_oil/compose/
mod.rs

1use indexmap::IndexMap;
2/// the compose module allows construction of shaders from modules (which are themselves shaders).
3///
4/// it does this by treating shaders as modules, and
5/// - building each module independently to naga IR
6/// - creating "header" files for each supported language, which are used to build dependent modules/shaders
7/// - making final shaders by combining the shader IR with the IR for imported modules
8///
9/// for multiple small shaders with large common imports, this can be faster than parsing the full source for each shader, and it allows for constructing shaders in a cleaner modular manner with better scope control.
10///
11/// ## imports
12///
/// shaders can be added to the composer as modules. this makes their types, constants, variables and functions available to modules/shaders that import them. note that importing a module will affect the final shader's global state if the module defines global variables with bindings.
14///
15/// modules must include a `#define_import_path` directive that names the module.
16///
17/// ```ignore
18/// #define_import_path my_module
19///
20/// fn my_func() -> f32 {
21///     return 1.0;
22/// }
23/// ```
24///
25/// shaders can then import the module with an `#import` directive (with an optional `as` name). at point of use, imported items must be qualified:
26///
27/// ```ignore
28/// #import my_module
29/// #import my_other_module as Mod2
30///
31/// fn main() -> f32 {
32///     let x = my_module::my_func();
33///     let y = Mod2::my_other_func();
34///     return x*y;
35/// }
36/// ```
37///
38/// or import a comma-separated list of individual items with a `#from` directive. at point of use, imported items must be prefixed with `::` :
39///
40/// ```ignore
41/// #from my_module import my_func, my_const
42///
43/// fn main() -> f32 {
44///     return ::my_func(::my_const);
45/// }
46/// ```
47///
48/// imports can be nested - modules may import other modules, but not recursively. when a new module is added, all its `#import`s must already have been added.
49/// the same module can be imported multiple times by different modules in the import tree.
50/// there is no overlap of namespaces, so the same function names (or type, constant, or variable names) may be used in different modules.
51///
52/// note: when importing an item with the `#from` directive, the final shader will include the required dependencies (bindings, globals, consts, other functions) of the imported item, but will not include the rest of the imported module. it will however still include all of any modules imported by the imported module. this is probably not desired in general and may be fixed in a future version. currently for a more complete culling of unused dependencies the `prune` module can be used.
53///
54/// ## overriding functions
55///
56/// virtual functions can be declared with the `virtual` keyword:
57/// ```ignore
58/// virtual fn point_light(world_position: vec3<f32>) -> vec3<f32> { ... }
59/// ```
60/// virtual functions defined in imported modules can then be overridden using the `override` keyword:
61///
62/// ```ignore
63/// #import bevy_pbr::lighting as Lighting
64///
65/// override fn Lighting::point_light (world_position: vec3<f32>) -> vec3<f32> {
66///     let original = Lighting::point_light(world_position);
67///     let quantized = vec3<u32>(original * 3.0);
68///     return vec3<f32>(quantized) / 3.0;
69/// }
70/// ```
71///
72/// override function definitions cause *all* calls to the original function in the entire shader scope to be replaced by calls to the new function, with the exception of calls within the override function itself.
73///
74/// the function signature of the override must match the base function.
75///
76/// overrides can be specified at any point in the final shader's import tree.
77///
78/// multiple overrides can be applied to the same function. for example, given :
79/// - a module `a` containing a function `f`,
80/// - a module `b` that imports `a`, and containing an `override a::f` function,
81/// - a module `c` that imports `a` and `b`, and containing an `override a::f` function,
82///
83/// then b and c both specify an override for `a::f`.
84/// the `override fn a::f` declared in module `b` may call to `a::f` within its body.
85/// the `override fn a::f` declared in module 'c' may call to `a::f` within its body, but the call will be redirected to `b::f`.
86/// any other calls to `a::f` (within modules 'a' or `b`, or anywhere else) will end up redirected to `c::f`
87/// in this way a chain or stack of overrides can be applied.
88///
89/// different overrides of the same function can be specified in different import branches. the final stack will be ordered based on the first occurrence of the override in the import tree (using a depth first search).
90///
91/// note that imports into a module/shader are processed in order, but are processed before the body of the current shader/module regardless of where they occur in that module, so there is no way to import a module containing an override and inject a call into the override stack prior to that imported override. you can instead create two modules each containing an override and import them into a parent module/shader to order them as required.
92/// override functions can currently only be defined in wgsl.
93///
94/// if the `override_any` crate feature is enabled, then the `virtual` keyword is not required for the function being overridden.
95///
96/// ## languages
97///
/// modules can be written in GLSL or WGSL. shaders with entry points can be imported as modules (provided they have a `#define_import_path` directive). entry points are available to call from imported modules either via their name (for WGSL) or via `module::main` (for GLSL).
99///
100/// final shaders can also be written in GLSL or WGSL. for GLSL users must specify whether the shader is a vertex shader or fragment shader via the `ShaderType` argument (GLSL compute shaders are not supported).
101///
102/// ## preprocessing
103///
104/// when generating a final shader or adding a composable module, a set of `shader_def` string/value pairs must be provided. The value can be a bool (`ShaderDefValue::Bool`), an i32 (`ShaderDefValue::Int`) or a u32 (`ShaderDefValue::UInt`).
105///
106/// these allow conditional compilation of parts of modules and the final shader. conditional compilation is performed with `#if` / `#ifdef` / `#ifndef`, `#else` and `#endif` preprocessor directives:
107///
108/// ```ignore
109/// fn get_number() -> f32 {
110///     #ifdef BIG_NUMBER
111///         return 999.0;
112///     #else
113///         return 0.999;
114///     #endif
115/// }
116/// ```
117/// the `#ifdef` directive matches when the def name exists in the input binding set (regardless of value). the `#ifndef` directive is the reverse.
118///
119/// the `#if` directive requires a def name, an operator, and a value for comparison:
120/// - the def name must be a provided `shader_def` name.
121/// - the operator must be one of `==`, `!=`, `>=`, `>`, `<`, `<=`
/// - the value must be an integer literal if comparing to a `ShaderDefValue::Int`, or `true` or `false` if comparing to a `ShaderDefValue::Bool`.
123///
124/// shader defs can also be used in the shader source with `#SHADER_DEF` or `#{SHADER_DEF}`, and will be substituted for their value.
125///
126/// ## error reporting
127///
128/// codespan reporting for errors is available using the error `emit_to_string` method. this requires validation to be enabled, which is true by default. `Composer::non_validating()` produces a non-validating composer that is not able to give accurate error reporting.
129///
130use naga::EntryPoint;
131use regex::Regex;
132use std::collections::{hash_map::Entry, BTreeMap, HashMap, HashSet};
133use tracing::{debug, trace};
134
135use crate::{
136    compose::preprocess::{PreprocessOutput, PreprocessorMetaData},
137    derive::DerivedModule,
138    redirect::Redirector,
139};
140
141pub use self::error::{ComposerError, ComposerErrorInner, ErrSource};
142use self::preprocess::Preprocessor;
143
144pub mod comment_strip_iter;
145pub mod error;
146pub mod parse_imports;
147pub mod preprocess;
148mod test;
149pub mod tokenizer;
150
/// Source language of a module or shader handled by the composer.
#[derive(Hash, PartialEq, Eq, Clone, Copy, Debug, Default)]
pub enum ShaderLanguage {
    /// WGSL source; this is the default language.
    #[default]
    Wgsl,
    /// GLSL source; only present when the `glsl` crate feature is enabled.
    #[cfg(feature = "glsl")]
    Glsl,
}

/// Kind of final shader being built.
///
/// GLSL shaders carry their stage in the variant, WGSL shaders do not.
#[derive(Hash, PartialEq, Eq, Clone, Copy, Debug, Default)]
pub enum ShaderType {
    /// a WGSL shader; this is the default type.
    #[default]
    Wgsl,
    /// a GLSL vertex shader (requires the `glsl` feature).
    #[cfg(feature = "glsl")]
    GlslVertex,
    /// a GLSL fragment shader (requires the `glsl` feature).
    #[cfg(feature = "glsl")]
    GlslFragment,
}

impl From<ShaderType> for ShaderLanguage {
    /// Maps a shader type onto the language it is written in:
    /// both GLSL stages collapse to `Glsl`, `Wgsl` stays `Wgsl`.
    fn from(ty: ShaderType) -> Self {
        match ty {
            ShaderType::Wgsl => Self::Wgsl,
            #[cfg(feature = "glsl")]
            ShaderType::GlslVertex | ShaderType::GlslFragment => Self::Glsl,
        }
    }
}
178
/// Value bound to a shader def name.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
pub enum ShaderDefValue {
    /// boolean value
    Bool(bool),
    /// signed 32-bit integer value
    Int(i32),
    /// unsigned 32-bit integer value
    UInt(u32),
}

impl Default for ShaderDefValue {
    /// The default shader def value is `Bool(true)`.
    fn default() -> Self {
        Self::Bool(true)
    }
}

impl ShaderDefValue {
    /// Renders the wrapped value with its standard `Display` formatting
    /// (`true`/`false` for bools, decimal digits for integers).
    fn value_as_string(&self) -> String {
        match *self {
            Self::Bool(flag) => flag.to_string(),
            Self::Int(signed) => signed.to_string(),
            Self::UInt(unsigned) => unsigned.to_string(),
        }
    }
}
201
/// An owned set of shader def name/value pairs.
///
/// Backed by a `BTreeMap`, so iteration (and therefore the derived `Hash`)
/// follows the sorted order of the def names rather than insertion order.
#[derive(Clone, PartialEq, Eq, Hash, Debug, Default)]
pub struct OwnedShaderDefs(BTreeMap<String, ShaderDefValue>);
204
205#[derive(Clone, PartialEq, Eq, Hash, Debug)]
206struct ModuleKey(OwnedShaderDefs);
207
208impl ModuleKey {
209    fn from_members(key: &HashMap<String, ShaderDefValue>, universe: &[String]) -> Self {
210        let mut acc = OwnedShaderDefs::default();
211        for item in universe {
212            if let Some(value) = key.get(item) {
213                acc.0.insert(item.to_owned(), *value);
214            }
215        }
216        ModuleKey(acc)
217    }
218}
219
/// a module built with a specific set of shader_defs
#[derive(Default, Debug)]
pub struct ComposableModule {
    /// module decoration, attached to all items from this module in the final source.
    // NOTE(review): the original comment said "prefixed", but
    // `Composer::decorated_name` appends the decoration after the item name - confirm wording.
    pub decorated_name: String,
    /// module names required as imports, optionally with a list of items to import
    pub imports: Vec<ImportDefinition>,
    /// types exported
    pub owned_types: HashSet<String>,
    /// constants exported
    pub owned_constants: HashSet<String>,
    /// vars exported
    pub owned_vars: HashSet<String>,
    /// functions exported
    pub owned_functions: HashSet<String>,
    /// local functions that can be overridden
    pub virtual_functions: HashSet<String>,
    /// overriding functions defined in this module
    /// target function -> Vec<replacement functions>
    /// (an `IndexMap`, so targets keep their insertion order)
    pub override_functions: IndexMap<String, Vec<String>>,
    // naga module, built against headers for any imports
    module_ir: naga::Module,
    // headers in different shader languages, used for building modules/shaders that import this module
    // headers contain types, constants, global vars and empty function definitions -
    // just enough to convert source strings that want to import this module into naga IR
    // headers: HashMap<ShaderLanguage, String>,
    header_ir: naga::Module,
    // character offset of the start of the owned module string
    start_offset: usize,
}
250
/// data used to build a ComposableModule
///
/// one definition can hold several built `ComposableModule`s, one per distinct
/// combination of the shader defs listed in `effective_defs` (see `modules`).
#[derive(Debug)]
pub struct ComposableModuleDefinition {
    /// name of the module
    pub name: String,
    /// shader text (with auto bindings replaced - we do this on module add as we only want to do it once to avoid burning slots)
    pub sanitized_source: String,
    /// language
    pub language: ShaderLanguage,
    /// source path for error display
    pub file_path: String,
    /// shader def values bound to this module
    pub shader_defs: HashMap<String, ShaderDefValue>,
    // list of shader_defs that can affect this module;
    // only these names are used when forming the `ModuleKey` for `modules`
    effective_defs: Vec<String>,
    // full list of possible imports (regardless of shader_def configuration)
    all_imports: HashSet<String>,
    // additional imports to add (as though they were included in the source after any other imports)
    additional_imports: Vec<ImportDefinition>,
    // built composable modules for a given set of shader defs
    modules: HashMap<ModuleKey, ComposableModule>,
    // used in spans when this module is included
    module_index: usize,
    // preprocessor meta data
    // metadata: PreprocessorMetaData,
}
276
277impl ComposableModuleDefinition {
278    fn get_module(
279        &self,
280        shader_defs: &HashMap<String, ShaderDefValue>,
281    ) -> Option<&ComposableModule> {
282        self.modules
283            .get(&ModuleKey::from_members(shader_defs, &self.effective_defs))
284    }
285
286    fn insert_module(
287        &mut self,
288        shader_defs: &HashMap<String, ShaderDefValue>,
289        module: ComposableModule,
290    ) -> &ComposableModule {
291        match self
292            .modules
293            .entry(ModuleKey::from_members(shader_defs, &self.effective_defs))
294        {
295            Entry::Occupied(_) => panic!("entry already populated"),
296            Entry::Vacant(v) => v.insert(module),
297        }
298    }
299}
300
/// A single import request: a module name plus an optional list of items.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ImportDefinition {
    /// name of the module to import
    pub import: String,
    /// individual items to import from the module
    // NOTE(review): presumably an empty list means the whole module is imported - confirm against the import handling code.
    pub items: Vec<String>,
}
306
/// An `ImportDefinition` paired with an offset.
// NOTE(review): `offset` is presumably the position of the import within the
// source text it was parsed from - confirm against the code that builds this type.
#[derive(Debug, Clone)]
pub struct ImportDefWithOffset {
    definition: ImportDefinition,
    offset: usize,
}
312
/// module composer.
/// stores any modules that can be imported into a shader
/// and builds the final shader
#[derive(Debug)]
pub struct Composer {
    /// whether validation is enabled (`Composer::default()` sets this to `true`)
    pub validate: bool,
    /// module definitions, looked up by import name
    pub module_sets: HashMap<String, ComposableModuleDefinition>,
    /// index -> name map
    // NOTE(review): presumably keyed by `ComposableModuleDefinition::module_index` - confirm.
    pub module_index: HashMap<usize, String>,
    /// capabilities handed to the naga validator
    pub capabilities: naga::valid::Capabilities,
    preprocessor: Preprocessor,
    // matches either decoration prefix (`DECORATION_PRE` or `DECORATION_OVERRIDE_PRE`)
    check_decoration_regex: Regex,
    // matches a decorated item name; used by `undecorate` to restore `module::item`
    undecorate_regex: Regex,
    // matches a `virtual fn <name>(` declaration
    virtual_fn_regex: Regex,
    // matches an `override fn <decorated name>(` declaration
    override_fn_regex: Regex,
    // matches an override decoration; used by `undecorate` to restore `override fn module::`
    undecorate_override_regex: Regex,
    // matches the literal text `@binding(auto)`
    auto_binding_regex: Regex,
    // next index to substitute for `@binding(auto)`
    auto_binding_index: u32,
}
331
// shift for module index
// 21 gives
//   max size for shader of 2m characters (2^21 = 2,097,152)
//   max 2048 modules (2^11)
// NOTE(review): 21 + 11 = 32, so this presumably packs a module index and a
// character offset into a 32-bit span value - confirm at the use sites.
const SPAN_SHIFT: usize = 21;
337
338impl Default for Composer {
339    fn default() -> Self {
340        Self {
341            validate: true,
342            capabilities: Default::default(),
343            module_sets: Default::default(),
344            module_index: Default::default(),
345            preprocessor: Preprocessor::default(),
346            check_decoration_regex: Regex::new(
347                format!(
348                    "({}|{})",
349                    regex_syntax::escape(DECORATION_PRE),
350                    regex_syntax::escape(DECORATION_OVERRIDE_PRE)
351                )
352                .as_str(),
353            )
354            .unwrap(),
355            undecorate_regex: Regex::new(
356                format!(
357                    r"(\x1B\[\d+\w)?([\w\d_]+){}([A-Z0-9]*){}",
358                    regex_syntax::escape(DECORATION_PRE),
359                    regex_syntax::escape(DECORATION_POST)
360                )
361                .as_str(),
362            )
363            .unwrap(),
364            virtual_fn_regex: Regex::new(
365                r"(?P<lead>[\s]*virtual\s+fn\s+)(?P<function>[^\s]+)(?P<trail>\s*)\(",
366            )
367            .unwrap(),
368            override_fn_regex: Regex::new(
369                format!(
370                    r"(override\s+fn\s+)([^\s]+){}([\w\d]+){}(\s*)\(",
371                    regex_syntax::escape(DECORATION_PRE),
372                    regex_syntax::escape(DECORATION_POST)
373                )
374                .as_str(),
375            )
376            .unwrap(),
377            undecorate_override_regex: Regex::new(
378                format!(
379                    "{}([A-Z0-9]*){}",
380                    regex_syntax::escape(DECORATION_OVERRIDE_PRE),
381                    regex_syntax::escape(DECORATION_POST)
382                )
383                .as_str(),
384            )
385            .unwrap(),
386            auto_binding_regex: Regex::new(r"@binding\(auto\)").unwrap(),
387            auto_binding_index: 0,
388        }
389    }
390}
391
// markers wrapped around the base32-encoded module name when an item name is
// decorated: `<item><DECORATION_PRE><encoded module><DECORATION_POST>`
const DECORATION_PRE: &str = "X_naga_oil_mod_X";
const DECORATION_POST: &str = "X";

// prefix marker used for override decorations.
// must be same length as DECORATION_PRE for spans to work
const DECORATION_OVERRIDE_PRE: &str = "X_naga_oil_vrt_X";
397
// output of `Composer::create_module_ir`
struct IrBuildResult {
    // naga module parsed from the generated header followed by the module source
    module: naga::Module,
    // length of the text that precedes the module's own source in the parsed string
    start_offset: usize,
    // overrides gathered from the imported modules: target function -> replacement functions
    override_functions: IndexMap<String, Vec<String>>,
}
403
404impl Composer {
405    pub fn decorated_name(module_name: Option<&str>, item_name: &str) -> String {
406        match module_name {
407            Some(module_name) => format!("{}{}", item_name, Self::decorate(module_name)),
408            None => item_name.to_owned(),
409        }
410    }
411
412    fn decorate(module: &str) -> String {
413        let encoded = data_encoding::BASE32_NOPAD.encode(module.as_bytes());
414        format!("{DECORATION_PRE}{encoded}{DECORATION_POST}")
415    }
416
417    fn decode(from: &str) -> String {
418        String::from_utf8(data_encoding::BASE32_NOPAD.decode(from.as_bytes()).unwrap()).unwrap()
419    }
420
421    /// Shorthand for creating a naga validator.
422    fn create_validator(&self) -> naga::valid::Validator {
423        naga::valid::Validator::new(naga::valid::ValidationFlags::all(), self.capabilities)
424    }
425
426    fn undecorate(&self, string: &str) -> String {
427        let undecor = self
428            .undecorate_regex
429            .replace_all(string, |caps: &regex::Captures| {
430                format!(
431                    "{}{}::{}",
432                    caps.get(1).map(|cc| cc.as_str()).unwrap_or(""),
433                    Self::decode(caps.get(3).unwrap().as_str()),
434                    caps.get(2).unwrap().as_str()
435                )
436            });
437
438        let undecor =
439            self.undecorate_override_regex
440                .replace_all(&undecor, |caps: &regex::Captures| {
441                    format!(
442                        "override fn {}::",
443                        Self::decode(caps.get(1).unwrap().as_str())
444                    )
445                });
446
447        undecor.to_string()
448    }
449
450    fn sanitize_and_set_auto_bindings(&mut self, source: &str) -> String {
451        let mut substituted_source = source.replace("\r\n", "\n").replace('\r', "\n");
452        if !substituted_source.ends_with('\n') {
453            substituted_source.push('\n');
454        }
455
456        // replace @binding(auto) with an incrementing index
457        struct AutoBindingReplacer<'a> {
458            auto: &'a mut u32,
459        }
460
461        impl regex::Replacer for AutoBindingReplacer<'_> {
462            fn replace_append(&mut self, _: &regex::Captures<'_>, dst: &mut String) {
463                dst.push_str(&format!("@binding({})", self.auto));
464                *self.auto += 1;
465            }
466        }
467
468        let substituted_source = self.auto_binding_regex.replace_all(
469            &substituted_source,
470            AutoBindingReplacer {
471                auto: &mut self.auto_binding_index,
472            },
473        );
474
475        substituted_source.into_owned()
476    }
477
    /// Validates `naga_module` and writes it out as source text in `language`.
    ///
    /// `create_module_ir` uses the returned text as the header placed in front of
    /// a module's own source.
    ///
    /// - WGSL: the module is written with explicit types.
    /// - GLSL (feature `glsl`): a dummy vertex entry point is pushed onto
    ///   `naga_module` (so the caller's module is mutated), the module is
    ///   validated again and written as desktop GLSL 450 including unused items;
    ///   the first line and the last three lines of the output are then dropped.
    ///
    /// `header_for` is only used in the trace log emitted for GLSL headers.
    ///
    /// # Errors
    ///
    /// - `ComposerErrorInner::HeaderValidationError` if validation fails;
    /// - `ComposerErrorInner::WgslBackError` / `ComposerErrorInner::GlslBackError`
    ///   if the respective writer fails.
    fn naga_to_string(
        &self,
        naga_module: &mut naga::Module,
        language: ShaderLanguage,
        #[allow(unused)] header_for: &str, // Only used when GLSL is enabled
    ) -> Result<String, ComposerErrorInner> {
        // TODO: cache headers again
        let info = self
            .create_validator()
            .validate(naga_module)
            .map_err(ComposerErrorInner::HeaderValidationError)?;

        match language {
            ShaderLanguage::Wgsl => naga::back::wgsl::write_string(
                naga_module,
                &info,
                naga::back::wgsl::WriterFlags::EXPLICIT_TYPES,
            )
            .map_err(ComposerErrorInner::WgslBackError),
            #[cfg(feature = "glsl")]
            ShaderLanguage::Glsl => {
                // result type (vec4<f32>) for the dummy entry point below
                let vec4 = naga_module.types.insert(
                    naga::Type {
                        name: None,
                        inner: naga::TypeInner::Vector {
                            size: naga::VectorSize::Quad,
                            scalar: naga::Scalar::F32,
                        },
                    },
                    naga::Span::UNDEFINED,
                );
                // add a dummy entry point for glsl headers
                let dummy_entry_point = "dummy_module_entry_point".to_owned();
                let func = naga::Function {
                    name: Some(dummy_entry_point.clone()),
                    arguments: Default::default(),
                    result: Some(naga::FunctionResult {
                        ty: vec4,
                        binding: Some(naga::Binding::BuiltIn(naga::BuiltIn::Position {
                            invariant: false,
                        })),
                    }),
                    local_variables: Default::default(),
                    expressions: Default::default(),
                    named_expressions: Default::default(),
                    body: Default::default(),
                    diagnostic_filter_leaf: Default::default(),
                };
                let ep = EntryPoint {
                    name: dummy_entry_point.clone(),
                    stage: naga::ShaderStage::Vertex,
                    function: func,
                    early_depth_test: None,
                    workgroup_size: [0, 0, 0],
                    workgroup_size_overrides: None,
                };

                naga_module.entry_points.push(ep);

                // validate again: the module changed, so the `info` computed above
                // is shadowed by one that covers the new entry point
                let info = self
                    .create_validator()
                    .validate(naga_module)
                    .map_err(ComposerErrorInner::HeaderValidationError)?;

                let mut string = String::new();
                let options = naga::back::glsl::Options {
                    version: naga::back::glsl::Version::Desktop(450),
                    writer_flags: naga::back::glsl::WriterFlags::INCLUDE_UNUSED_ITEMS,
                    ..Default::default()
                };
                let pipeline_options = naga::back::glsl::PipelineOptions {
                    shader_stage: naga::ShaderStage::Vertex,
                    entry_point: dummy_entry_point,
                    multiview: None,
                };
                let mut writer = naga::back::glsl::Writer::new(
                    &mut string,
                    naga_module,
                    &info,
                    &options,
                    &pipeline_options,
                    naga::proc::BoundsCheckPolicies::default(),
                )
                .map_err(ComposerErrorInner::GlslBackError)?;

                writer.write().map_err(ComposerErrorInner::GlslBackError)?;

                // strip version decl and main() impl
                // NOTE(review): this slice panics if the writer ever emits fewer than
                // four lines; it relies on the output shape of the glsl backend.
                let lines: Vec<_> = string.lines().collect();
                let string = lines[1..lines.len() - 3].join("\n");
                trace!("glsl header for {}:\n\"\n{:?}\n\"", header_for, string);

                Ok(string)
            }
        }
    }
574
    // build naga module for a given shader_def configuration. builds a minimal self-contained module built against headers for imports
    //
    // steps:
    //   1. add every import to a header module and gather the override functions
    //      declared by the imported modules
    //   2. write the header module out as source text in `language`
    //   3. append `source` to the header text and parse the combined string
    //
    // arguments:
    //   name        - module name, used in logs and in error sources
    //   source      - the module's own source text, appended after the header
    //   language    - language of `source`; also the language the header is written in
    //   imports     - imports to make available through the header
    //   shader_defs - shader defs used to select the built variant of each import
    //
    // returns the parsed module, the length of the text preceding `source` in the
    // parsed string, and the gathered overrides.
    //
    // errors: header generation and parse failures are returned as `ComposerError`
    // with an `ErrSource::Module` source (offset 0 for header errors, `start_offset`
    // for parse errors).
    //
    // panics if an import (or its variant for `shader_defs`) has not been built yet.
    fn create_module_ir(
        &self,
        name: &str,
        source: String,
        language: ShaderLanguage,
        imports: &[ImportDefinition],
        shader_defs: &HashMap<String, ShaderDefValue>,
    ) -> Result<IrBuildResult, ComposerError> {
        debug!("creating IR for {} with defs: {:?}", name, shader_defs);

        // glsl sources need the version declaration as their first line
        let mut module_string = match language {
            ShaderLanguage::Wgsl => String::new(),
            #[cfg(feature = "glsl")]
            ShaderLanguage::Glsl => String::from("#version 450\n"),
        };

        let mut override_functions: IndexMap<String, Vec<String>> = IndexMap::default();
        let mut added_imports: HashSet<String> = HashSet::new();
        let mut header_module = DerivedModule::default();

        for import in imports {
            // NOTE(review): `added_imports` is only written to by `add_import`;
            // presumably it records every module already pulled into the header - confirm.
            if added_imports.contains(&import.import) {
                continue;
            }
            // add to header module
            self.add_import(
                &mut header_module,
                import,
                shader_defs,
                true,
                &mut added_imports,
            );

            // we must have ensured these exist with Composer::ensure_imports()
            // (both unwraps below rely on that)
            trace!("looking for {}", import.import);
            let import_module_set = self.module_sets.get(&import.import).unwrap();
            trace!("with defs {:?}", shader_defs);
            let module = import_module_set.get_module(shader_defs).unwrap();
            trace!("ok");

            // gather overrides
            // replacements are appended to any existing list for the same target,
            // skipping ones already present, so the first-seen order is kept
            if !module.override_functions.is_empty() {
                for (original, replacements) in &module.override_functions {
                    match override_functions.entry(original.clone()) {
                        indexmap::map::Entry::Occupied(o) => {
                            let existing = o.into_mut();
                            let new_replacements: Vec<_> = replacements
                                .iter()
                                .filter(|rep| !existing.contains(rep))
                                .cloned()
                                .collect();
                            existing.extend(new_replacements);
                        }
                        indexmap::map::Entry::Vacant(v) => {
                            v.insert(replacements.clone());
                        }
                    }
                }
            }
        }

        // write the combined header out as source text in the module's language
        let composed_header = self
            .naga_to_string(&mut header_module.into(), language, name)
            .map_err(|inner| ComposerError {
                inner,
                source: ErrSource::Module {
                    name: name.to_owned(),
                    offset: 0,
                    defs: shader_defs.clone(),
                },
            })?;
        module_string.push_str(&composed_header);

        // everything before this offset is generated text; the module's own source starts here
        let start_offset = module_string.len();

        module_string.push_str(&source);

        trace!(
            "parsing {}: {}, header len {}, total len {}",
            name,
            module_string,
            start_offset,
            module_string.len()
        );
        // parse header + source as one string; parse errors carry `start_offset`
        // in their `ErrSource::Module`
        let module = match language {
            ShaderLanguage::Wgsl => naga::front::wgsl::parse_str(&module_string).map_err(|e| {
                debug!("full err'd source file: \n---\n{}\n---", module_string);
                ComposerError {
                    inner: ComposerErrorInner::WgslParseError(e),
                    source: ErrSource::Module {
                        name: name.to_owned(),
                        offset: start_offset,
                        defs: shader_defs.clone(),
                    },
                }
            })?,
            #[cfg(feature = "glsl")]
            ShaderLanguage::Glsl => naga::front::glsl::Frontend::default()
                .parse(
                    &naga::front::glsl::Options {
                        stage: naga::ShaderStage::Vertex,
                        defines: Default::default(),
                    },
                    &module_string,
                )
                .map_err(|e| {
                    debug!("full err'd source file: \n---\n{}\n---", module_string);
                    ComposerError {
                        inner: ComposerErrorInner::GlslParseError(e),
                        source: ErrSource::Module {
                            name: name.to_owned(),
                            offset: start_offset,
                            defs: shader_defs.clone(),
                        },
                    }
                })?,
        };

        Ok(IrBuildResult {
            module,
            start_offset,
            override_functions,
        })
    }
700
701    // check that identifiers exported by a module do not get modified in string export
702    fn validate_identifiers(
703        source_ir: &naga::Module,
704        lang: ShaderLanguage,
705        header: &str,
706        module_decoration: &str,
707        owned_types: &HashSet<String>,
708    ) -> Result<(), ComposerErrorInner> {
709        // TODO: remove this once glsl front support is complete
710        #[cfg(feature = "glsl")]
711        if lang == ShaderLanguage::Glsl {
712            return Ok(());
713        }
714
715        let recompiled = match lang {
716            ShaderLanguage::Wgsl => naga::front::wgsl::parse_str(header).unwrap(),
717            #[cfg(feature = "glsl")]
718            ShaderLanguage::Glsl => naga::front::glsl::Frontend::default()
719                .parse(
720                    &naga::front::glsl::Options {
721                        stage: naga::ShaderStage::Vertex,
722                        defines: Default::default(),
723                    },
724                    &format!("{}\n{}", header, "void main() {}"),
725                )
726                .map_err(|e| {
727                    debug!("full err'd source file: \n---\n{header}\n---");
728                    ComposerErrorInner::GlslParseError(e)
729                })?,
730        };
731
732        let recompiled_types: IndexMap<_, _> = recompiled
733            .types
734            .iter()
735            .flat_map(|(h, ty)| ty.name.as_deref().map(|name| (name, h)))
736            .collect();
737        for (h, ty) in source_ir.types.iter() {
738            if let Some(name) = &ty.name {
739                let decorated_type_name = format!("{name}{module_decoration}");
740                if !owned_types.contains(&decorated_type_name) {
741                    continue;
742                }
743                match recompiled_types.get(decorated_type_name.as_str()) {
744                    Some(recompiled_h) => {
745                        if let naga::TypeInner::Struct { members, .. } = &ty.inner {
746                            let recompiled_ty = recompiled.types.get_handle(*recompiled_h).unwrap();
747                            let naga::TypeInner::Struct {
748                                members: recompiled_members,
749                                ..
750                            } = &recompiled_ty.inner
751                            else {
752                                panic!();
753                            };
754                            for (member, recompiled_member) in
755                                members.iter().zip(recompiled_members)
756                            {
757                                if member.name != recompiled_member.name {
758                                    return Err(ComposerErrorInner::InvalidIdentifier {
759                                        original: member.name.clone().unwrap_or_default(),
760                                        at: source_ir.types.get_span(h),
761                                    });
762                                }
763                            }
764                        }
765                    }
766                    None => {
767                        return Err(ComposerErrorInner::InvalidIdentifier {
768                            original: name.clone(),
769                            at: source_ir.types.get_span(h),
770                        })
771                    }
772                }
773            }
774        }
775
776        let recompiled_consts: HashSet<_> = recompiled
777            .constants
778            .iter()
779            .flat_map(|(_, c)| c.name.as_deref())
780            .filter(|name| name.ends_with(module_decoration))
781            .collect();
782        for (h, c) in source_ir.constants.iter() {
783            if let Some(name) = &c.name {
784                if name.ends_with(module_decoration) && !recompiled_consts.contains(name.as_str()) {
785                    return Err(ComposerErrorInner::InvalidIdentifier {
786                        original: name.clone(),
787                        at: source_ir.constants.get_span(h),
788                    });
789                }
790            }
791        }
792
793        let recompiled_globals: HashSet<_> = recompiled
794            .global_variables
795            .iter()
796            .flat_map(|(_, c)| c.name.as_deref())
797            .filter(|name| name.ends_with(module_decoration))
798            .collect();
799        for (h, gv) in source_ir.global_variables.iter() {
800            if let Some(name) = &gv.name {
801                if name.ends_with(module_decoration) && !recompiled_globals.contains(name.as_str())
802                {
803                    return Err(ComposerErrorInner::InvalidIdentifier {
804                        original: name.clone(),
805                        at: source_ir.global_variables.get_span(h),
806                    });
807                }
808            }
809        }
810
811        let recompiled_fns: HashSet<_> = recompiled
812            .functions
813            .iter()
814            .flat_map(|(_, c)| c.name.as_deref())
815            .filter(|name| name.ends_with(module_decoration))
816            .collect();
817        for (h, f) in source_ir.functions.iter() {
818            if let Some(name) = &f.name {
819                if name.ends_with(module_decoration) && !recompiled_fns.contains(name.as_str()) {
820                    return Err(ComposerErrorInner::InvalidIdentifier {
821                        original: name.clone(),
822                        at: source_ir.functions.get_span(h),
823                    });
824                }
825            }
826        }
827
828        Ok(())
829    }
830
    // build a ComposableModule from a ComposableModuleDefinition, for a given set of shader defs
    // - rewrite `virtual` / `override` function declarations in the source into plain `fn`
    //   declarations, recording which functions were virtual and which overrides were declared
    // - build the naga IR (against headers) via `create_module_ir`
    // - rename and record any types/vars/constants/functions that are defined within this module
    //   (names that do not already contain `DECORATION_PRE` get `module_decoration` appended)
    // - build a header IR (owned declarations plus body-less function stubs) and a full module IR
    // - when `self.validate && create_headers`, write the header out for each supported language
    //   and check with `validate_identifiers` that the owned identifiers survive the round trip
    //
    // arguments:
    // - `module_definition`: supplies the module name, language and additional imports
    // - `module_decoration`: suffix appended to every owned item name; also stored as the
    //   resulting module's `decorated_name`
    // - `shader_defs`: passed to `create_module_ir` and embedded in reported errors
    // - `create_headers`: only consulted (together with `self.validate`) to decide whether the
    //   header round-trip validation runs
    // - `demote_entrypoints`: when true, entry points are renamed, stripped of their bindings and
    //   added as ordinary owned functions
    // - `source`: the module source - presumably already preprocessed by the caller; confirm
    // - `imports`: imports found in the source; `module_definition.additional_imports` are appended
    //
    // errors: returns a `ComposerError` if an override targets an unknown module or a non-virtual
    // function (unless the `override_any` feature is enabled), if `create_module_ir` fails, or if
    // header generation / identifier validation fails.
    #[allow(clippy::too_many_arguments)]
    fn create_composable_module(
        &mut self,
        module_definition: &ComposableModuleDefinition,
        module_decoration: String,
        shader_defs: &HashMap<String, ShaderDefValue>,
        create_headers: bool,
        demote_entrypoints: bool,
        source: &str,
        imports: Vec<ImportDefWithOffset>,
    ) -> Result<ComposableModule, ComposerError> {
        // the offsets are not needed from here on; keep only the import definitions
        let mut imports: Vec<_> = imports
            .into_iter()
            .map(|import_with_offset| import_with_offset.definition)
            .collect();
        imports.extend(module_definition.additional_imports.to_vec());

        trace!(
            "create composable module {}: source len {}",
            module_definition.name,
            source.len()
        );

        // record virtual/overridable functions
        // each match is replaced by a plain `fn <name>(` declaration: capture 2 is the function
        // name, and the replacement is padded with spaces based on the lengths of captures 1 and 3
        // (capture 1's length minus the 3 bytes of "fn ").
        // NOTE(review): the padding presumably keeps the source length and byte offsets unchanged
        // so that spans still line up - confirm against the definition of `virtual_fn_regex`.
        let mut virtual_functions: HashSet<String> = Default::default();
        let source = self
            .virtual_fn_regex
            .replace_all(source, |cap: &regex::Captures| {
                let target_function = cap.get(2).unwrap().as_str().to_owned();

                let replacement_str = format!(
                    "{}fn {}{}(",
                    " ".repeat(cap.get(1).unwrap().range().len() - 3),
                    target_function,
                    " ".repeat(cap.get(3).unwrap().range().len()),
                );

                virtual_functions.insert(target_function);

                replacement_str
            });

        // record and rename override functions
        // maps the renamed local override function -> decorated name of the function it overrides
        let mut local_override_functions: IndexMap<String, String> = Default::default();

        // the replacement closure below cannot return an error, so a failure is stashed here and
        // reported once the replacement has finished. each failing match overwrites the previous
        // value, so only the last error is reported.
        #[cfg(not(feature = "override_any"))]
        let mut override_error = None;

        let source =
            self.override_fn_regex
                .replace_all(&source, |cap: &regex::Captures| {
                    // capture 3 is the target module as written in the source (it is passed
                    // through `Self::decode` before being looked up), capture 2 the function name
                    let target_module = cap.get(3).unwrap().as_str().to_owned();
                    let target_function = cap.get(2).unwrap().as_str().to_owned();

                    #[cfg(not(feature = "override_any"))]
                    {
                        // errors raised at this point are reported with offset 0
                        let wrap_err = |inner: ComposerErrorInner| -> ComposerError {
                            ComposerError {
                                inner,
                                source: ErrSource::Module {
                                    name: module_definition.name.to_owned(),
                                    offset: 0,
                                    defs: shader_defs.clone(),
                                },
                            }
                        };

                        // ensure overrides are applied to virtual functions
                        let raw_module_name = Self::decode(&target_module);
                        let module_set = self.module_sets.get(&raw_module_name);

                        match module_set {
                            None => {
                                // TODO this should be unreachable?
                                let pos = cap.get(3).unwrap().start();
                                override_error = Some(wrap_err(
                                    ComposerErrorInner::ImportNotFound(raw_module_name, pos),
                                ));
                            }
                            Some(module_set) => {
                                // NOTE(review): panics if the target module has not been built for
                                // these shader defs - presumably guaranteed by the caller having
                                // run `ensure_imports` first; confirm.
                                let module = module_set.get_module(shader_defs).unwrap();
                                if !module.virtual_functions.contains(&target_function) {
                                    let pos = cap.get(2).unwrap().start();
                                    override_error =
                                        Some(wrap_err(ComposerErrorInner::OverrideNotVirtual {
                                            name: target_function.clone(),
                                            pos,
                                        }));
                                }
                            }
                        }
                    }

                    // decorated name of the function being overridden
                    let base_name = format!(
                        "{}{}{}{}",
                        target_function.as_str(),
                        DECORATION_PRE,
                        target_module.as_str(),
                        DECORATION_POST,
                    );
                    // name given to the overriding function defined in this module
                    let rename = format!(
                        "{}{}{}{}",
                        target_function.as_str(),
                        DECORATION_OVERRIDE_PRE,
                        target_module.as_str(),
                        DECORATION_POST,
                    );

                    // same space-padding scheme as for virtual functions above, here using the
                    // lengths of captures 1 and 4
                    let replacement_str = format!(
                        "{}fn {}{}(",
                        " ".repeat(cap.get(1).unwrap().range().len() - 3),
                        rename,
                        " ".repeat(cap.get(4).unwrap().range().len()),
                    );

                    local_override_functions.insert(rename, base_name);

                    replacement_str
                })
                .to_string();

        // report any error stashed by the replacement closure above
        #[cfg(not(feature = "override_any"))]
        if let Some(err) = override_error {
            return Err(err);
        }

        trace!("local overrides: {:?}", local_override_functions);
        trace!(
            "create composable module {}: source len {}",
            module_definition.name,
            source.len()
        );

        // build the IR from the rewritten source (the owned `source` string is moved into the call)
        let IrBuildResult {
            module: mut source_ir,
            start_offset,
            mut override_functions,
        } = self.create_module_ir(
            &module_definition.name,
            source,
            module_definition.language,
            &imports,
            shader_defs,
        )?;

        // from here on errors need to be reported using the modified source with start_offset
        let wrap_err = |inner: ComposerErrorInner| -> ComposerError {
            ComposerError {
                inner,
                source: ErrSource::Module {
                    name: module_definition.name.to_owned(),
                    offset: start_offset,
                    defs: shader_defs.clone(),
                },
            }
        };

        // add our local override to the total set of overrides for the given function
        // (the recorded override name carries this module's decoration, matching the rename that
        // is applied to owned functions below)
        for (rename, base_name) in &local_override_functions {
            override_functions
                .entry(base_name.clone())
                .or_default()
                .push(format!("{rename}{module_decoration}"));
        }

        // rename and record owned items (except types which can't be mutably accessed)
        // an item counts as owned when its name does not contain `DECORATION_PRE`
        let mut owned_constants = IndexMap::new();
        for (h, c) in source_ir.constants.iter_mut() {
            if let Some(name) = c.name.as_mut() {
                if !name.contains(DECORATION_PRE) {
                    *name = format!("{name}{module_decoration}");
                    owned_constants.insert(name.clone(), h);
                }
            }
        }

        // These are naga/wgpu's pipeline override constants, not naga_oil's overrides
        // NOTE(review): unlike the other owned_* collections, this one is only used to copy the
        // overrides into the header/module builders below; it is not stored on the
        // `ComposableModule` constructed at the end of this function.
        let mut owned_pipeline_overrides = IndexMap::new();
        for (h, po) in source_ir.overrides.iter_mut() {
            if let Some(name) = po.name.as_mut() {
                if !name.contains(DECORATION_PRE) {
                    *name = format!("{name}{module_decoration}");
                    owned_pipeline_overrides.insert(name.clone(), h);
                }
            }
        }

        let mut owned_vars = IndexMap::new();
        for (h, gv) in source_ir.global_variables.iter_mut() {
            if let Some(name) = gv.name.as_mut() {
                if !name.contains(DECORATION_PRE) {
                    *name = format!("{name}{module_decoration}");

                    owned_vars.insert(name.clone(), h);
                }
            }
        }

        // decorated function name -> (handle in source_ir if it has one, header stub)
        let mut owned_functions = IndexMap::new();
        for (h_f, f) in source_ir.functions.iter_mut() {
            if let Some(name) = f.name.as_mut() {
                if !name.contains(DECORATION_PRE) {
                    *name = format!("{name}{module_decoration}");

                    // create dummy header function
                    // (same name, arguments and result; no locals, expressions or body)
                    let header_function = naga::Function {
                        name: Some(name.clone()),
                        arguments: f.arguments.to_vec(),
                        result: f.result.clone(),
                        local_variables: Default::default(),
                        expressions: Default::default(),
                        named_expressions: Default::default(),
                        body: Default::default(),
                        diagnostic_filter_leaf: None,
                    };

                    // record owned function
                    owned_functions.insert(name.clone(), (Some(h_f), header_function));
                }
            }
        }

        if demote_entrypoints {
            // make normal functions out of the source entry points
            for ep in &mut source_ir.entry_points {
                // an entry point function without a name is treated as "main"
                ep.function.name = Some(format!(
                    "{}{}",
                    ep.function.name.as_deref().unwrap_or("main"),
                    module_decoration,
                ));
                // header stub for the demoted entry point: argument and result bindings are
                // dropped
                let header_function = naga::Function {
                    name: ep.function.name.clone(),
                    arguments: ep
                        .function
                        .arguments
                        .iter()
                        .cloned()
                        .map(|arg| naga::FunctionArgument {
                            name: arg.name,
                            ty: arg.ty,
                            binding: None,
                        })
                        .collect(),
                    result: ep.function.result.clone().map(|res| naga::FunctionResult {
                        ty: res.ty,
                        binding: None,
                    }),
                    local_variables: Default::default(),
                    expressions: Default::default(),
                    named_expressions: Default::default(),
                    body: Default::default(),
                    diagnostic_filter_leaf: None,
                };

                // entry points are not part of `source_ir.functions`, so no handle is recorded;
                // the name was set to `Some(..)` just above, so the unwrap cannot fail
                owned_functions.insert(ep.function.name.clone().unwrap(), (None, header_function));
            }
        };

        // two derived modules are built from the same source IR: the full module and its header
        let mut module_builder = DerivedModule::default();
        let mut header_builder = DerivedModule::default();
        module_builder.set_shader_source(&source_ir, 0);
        header_builder.set_shader_source(&source_ir, 0);

        let mut owned_types = HashSet::new();
        for (h, ty) in source_ir.types.iter() {
            if let Some(name) = &ty.name {
                // we need to exclude autogenerated struct names, i.e. those that begin with "__"
                // "__" is a reserved prefix for naga so user variables cannot use it.
                if !name.contains(DECORATION_PRE) && !name.starts_with("__") {
                    let name = format!("{name}{module_decoration}");
                    owned_types.insert(name.clone());
                    // copy and rename types
                    module_builder.rename_type(&h, Some(name.clone()));
                    header_builder.rename_type(&h, Some(name));
                    continue;
                }
            }

            // copy all required types
            // (types that are not owned are imported explicitly into the module builder only)
            module_builder.import_type(&h);
        }

        // copy owned constants, pipeline overrides and globals into header and module
        for h in owned_constants.values() {
            header_builder.import_const(h);
            module_builder.import_const(h);
        }

        for h in owned_pipeline_overrides.values() {
            header_builder.import_pipeline_override(h);
            module_builder.import_pipeline_override(h);
        }

        for h in owned_vars.values() {
            header_builder.import_global(h);
            module_builder.import_global(h);
        }

        // only stubs of owned functions into the header
        for (h_f, f) in owned_functions.values() {
            // demoted entry points have no handle and therefore no span
            let span = h_f
                .map(|h_f| source_ir.functions.get_span(h_f))
                .unwrap_or(naga::Span::UNDEFINED);
            header_builder.import_function(f, span); // header stub function
        }
        // all functions into the module (note source_ir only contains stubs for imported functions)
        for (h_f, f) in source_ir.functions.iter() {
            let span = source_ir.functions.get_span(h_f);
            module_builder.import_function(f, span);
        }
        // including entry points as vanilla functions if required
        if demote_entrypoints {
            for ep in &source_ir.entry_points {
                // clone the full entry point function (body included) and strip its bindings
                let mut f = ep.function.clone();
                f.arguments = f
                    .arguments
                    .into_iter()
                    .map(|arg| naga::FunctionArgument {
                        name: arg.name,
                        ty: arg.ty,
                        binding: None,
                    })
                    .collect();
                f.result = f.result.map(|res| naga::FunctionResult {
                    ty: res.ty,
                    binding: None,
                });

                module_builder.import_function(&f, naga::Span::UNDEFINED);
                // todo figure out how to get span info for entrypoints
            }
        }

        let module_ir = module_builder.into_module_with_entrypoints();
        let mut header_ir: naga::Module = header_builder.into();

        if self.validate && create_headers {
            // check that identifiers haven't been renamed
            // (without the "glsl" feature the array below has a single element, hence the allow)
            #[allow(clippy::single_element_loop)]
            for language in [
                ShaderLanguage::Wgsl,
                #[cfg(feature = "glsl")]
                ShaderLanguage::Glsl,
            ] {
                // write the header out in the target language ...
                let header = self
                    .naga_to_string(&mut header_ir, language, &module_definition.name)
                    .map_err(wrap_err)?;
                // ... and compare the identifiers it contains against the source IR
                Self::validate_identifiers(
                    &source_ir,
                    language,
                    &header,
                    &module_decoration,
                    &owned_types,
                )
                .map_err(wrap_err)?;
            }
        }

        // only the (decorated) names of owned constants/vars/functions are kept
        let composable_module = ComposableModule {
            decorated_name: module_decoration,
            imports,
            owned_types,
            owned_constants: owned_constants.into_keys().collect(),
            owned_vars: owned_vars.into_keys().collect(),
            owned_functions: owned_functions.into_keys().collect(),
            virtual_functions,
            override_functions,
            module_ir,
            header_ir,
            start_offset,
        };

        Ok(composable_module)
    }
1209
1210    // shunt all data owned by a composable into a derived module
1211    fn add_composable_data<'a>(
1212        derived: &mut DerivedModule<'a>,
1213        composable: &'a ComposableModule,
1214        items: Option<&Vec<String>>,
1215        span_offset: usize,
1216        header: bool,
1217    ) {
1218        let items: Option<HashSet<String>> = items.map(|items| {
1219            items
1220                .iter()
1221                .map(|item| format!("{}{}", item, composable.decorated_name))
1222                .collect()
1223        });
1224        let items = items.as_ref();
1225
1226        let source_ir = match header {
1227            true => &composable.header_ir,
1228            false => &composable.module_ir,
1229        };
1230
1231        derived.set_shader_source(source_ir, span_offset);
1232
1233        for (h, ty) in source_ir.types.iter() {
1234            if let Some(name) = &ty.name {
1235                if composable.owned_types.contains(name)
1236                    && items.map_or(true, |items| items.contains(name))
1237                {
1238                    derived.import_type(&h);
1239                }
1240            }
1241        }
1242
1243        for (h, c) in source_ir.constants.iter() {
1244            if let Some(name) = &c.name {
1245                if composable.owned_constants.contains(name)
1246                    && items.map_or(true, |items| items.contains(name))
1247                {
1248                    derived.import_const(&h);
1249                }
1250            }
1251        }
1252
1253        for (h, po) in source_ir.overrides.iter() {
1254            if let Some(name) = &po.name {
1255                if composable.owned_functions.contains(name)
1256                    && items.map_or(true, |items| items.contains(name))
1257                {
1258                    derived.import_pipeline_override(&h);
1259                }
1260            }
1261        }
1262
1263        for (h, v) in source_ir.global_variables.iter() {
1264            if let Some(name) = &v.name {
1265                if composable.owned_vars.contains(name)
1266                    && items.map_or(true, |items| items.contains(name))
1267                {
1268                    derived.import_global(&h);
1269                }
1270            }
1271        }
1272
1273        for (h_f, f) in source_ir.functions.iter() {
1274            if let Some(name) = &f.name {
1275                if composable.owned_functions.contains(name)
1276                    && (items.map_or(true, |items| items.contains(name))
1277                        || composable
1278                            .override_functions
1279                            .values()
1280                            .any(|v| v.contains(name)))
1281                {
1282                    let span = composable.module_ir.functions.get_span(h_f);
1283                    derived.import_function_if_new(f, span);
1284                }
1285            }
1286        }
1287
1288        derived.clear_shader_source();
1289    }
1290
1291    // add an import (and recursive imports) into a derived module
1292    fn add_import<'a>(
1293        &'a self,
1294        derived: &mut DerivedModule<'a>,
1295        import: &ImportDefinition,
1296        shader_defs: &HashMap<String, ShaderDefValue>,
1297        header: bool,
1298        already_added: &mut HashSet<String>,
1299    ) {
1300        if already_added.contains(&import.import) {
1301            trace!("skipping {}, already added", import.import);
1302            return;
1303        }
1304
1305        let import_module_set = self.module_sets.get(&import.import).unwrap();
1306        let module = import_module_set.get_module(shader_defs).unwrap();
1307
1308        for import in &module.imports {
1309            self.add_import(derived, import, shader_defs, header, already_added);
1310        }
1311
1312        Self::add_composable_data(
1313            derived,
1314            module,
1315            Some(&import.items),
1316            import_module_set.module_index << SPAN_SHIFT,
1317            header,
1318        );
1319    }
1320
1321    fn ensure_import(
1322        &mut self,
1323        module_set: &ComposableModuleDefinition,
1324        shader_defs: &HashMap<String, ShaderDefValue>,
1325    ) -> Result<ComposableModule, EnsureImportsError> {
1326        let PreprocessOutput {
1327            preprocessed_source,
1328            imports,
1329        } = self
1330            .preprocessor
1331            .preprocess(&module_set.sanitized_source, shader_defs)
1332            .map_err(|inner| {
1333                EnsureImportsError::from(ComposerError {
1334                    inner,
1335                    source: ErrSource::Module {
1336                        name: module_set.name.to_owned(),
1337                        offset: 0,
1338                        defs: shader_defs.clone(),
1339                    },
1340                })
1341            })?;
1342
1343        self.ensure_imports(imports.iter().map(|import| &import.definition), shader_defs)?;
1344        self.ensure_imports(&module_set.additional_imports, shader_defs)?;
1345
1346        self.create_composable_module(
1347            module_set,
1348            Self::decorate(&module_set.name),
1349            shader_defs,
1350            true,
1351            true,
1352            &preprocessed_source,
1353            imports,
1354        )
1355        .map_err(|err| err.into())
1356    }
1357
1358    // build required ComposableModules for a given set of shader_defs
1359    fn ensure_imports<'a>(
1360        &mut self,
1361        imports: impl IntoIterator<Item = &'a ImportDefinition>,
1362        shader_defs: &HashMap<String, ShaderDefValue>,
1363    ) -> Result<(), EnsureImportsError> {
1364        for ImportDefinition { import, .. } in imports.into_iter() {
1365            let Some(module_set) = self.module_sets.get(import) else {
1366                return Err(EnsureImportsError::MissingImport(import.to_owned()));
1367            };
1368            if module_set.get_module(shader_defs).is_some() {
1369                continue;
1370            }
1371
1372            // we need to build the module
1373            // take the set so we can recurse without borrowing
1374            let (set_key, mut module_set) = self.module_sets.remove_entry(import).unwrap();
1375
1376            match self.ensure_import(&module_set, shader_defs) {
1377                Ok(module) => {
1378                    module_set.insert_module(shader_defs, module);
1379                    self.module_sets.insert(set_key, module_set);
1380                }
1381                Err(e) => {
1382                    self.module_sets.insert(set_key, module_set);
1383                    return Err(e);
1384                }
1385            }
1386        }
1387
1388        Ok(())
1389    }
1390}
1391
/// Error raised while making sure the modules required by an import list have been built.
pub enum EnsureImportsError {
    /// An import names a module that has not been added to the composer; holds the import name.
    MissingImport(String),
    /// Building a required module failed with the wrapped composer error.
    ComposerError(ComposerError),
}
1396
1397impl EnsureImportsError {
1398    fn into_composer_error(self, err_source: ErrSource) -> ComposerError {
1399        match self {
1400            EnsureImportsError::MissingImport(import) => ComposerError {
1401                inner: ComposerErrorInner::ImportNotFound(import.to_owned(), 0),
1402                source: err_source,
1403            },
1404            EnsureImportsError::ComposerError(err) => err,
1405        }
1406    }
1407}
1408
1409impl From<ComposerError> for EnsureImportsError {
1410    fn from(value: ComposerError) -> Self {
1411        EnsureImportsError::ComposerError(value)
1412    }
1413}
1414
/// Describes a module to be added to a composer with `Composer::add_composable_module`.
#[derive(Default)]
pub struct ComposableModuleDescriptor<'a> {
    /// The module's shader source. It is rejected if it contains the composer's decoration
    /// strings.
    pub source: &'a str,
    /// Path of the source file; stored with the module definition and attached to errors.
    pub file_path: &'a str,
    /// The language the source is written in.
    pub language: ShaderLanguage,
    /// Module name to use. When set, it takes precedence over the name found in the source; if
    /// neither is present, adding the module fails with `NoModuleName`.
    pub as_name: Option<String>,
    /// Imports to add on top of those found in the source.
    pub additional_imports: &'a [ImportDefinition],
    /// Shader defs for the module; the shader defs of every imported module are merged into
    /// these when the module is added.
    pub shader_defs: HashMap<String, ShaderDefValue>,
}
1424
/// Describes a shader to be built into a final naga module.
// NOTE(review): the code consuming this descriptor is outside this part of the file; the field
// notes below are inferred from the field names and from `ComposableModuleDescriptor` - confirm.
#[derive(Default)]
pub struct NagaModuleDescriptor<'a> {
    /// The shader source.
    pub source: &'a str,
    /// Path of the source file (presumably used for error reporting).
    pub file_path: &'a str,
    /// The kind of shader being built.
    pub shader_type: ShaderType,
    /// Shader defs to build the shader with.
    pub shader_defs: HashMap<String, ShaderDefValue>,
    /// Imports to add on top of those found in the source.
    pub additional_imports: &'a [ImportDefinition],
}
1433
1434// public api
1435impl Composer {
1436    /// create a non-validating composer.
1437    /// validation errors in the final shader will not be caught, and errors resulting from their
1438    /// use will have bad span data, so codespan reporting will fail.
1439    /// use default() to create a validating composer.
1440    pub fn non_validating() -> Self {
1441        Self {
1442            validate: false,
1443            ..Default::default()
1444        }
1445    }
1446
1447    /// specify capabilities to be used for naga module generation.
1448    /// purges any existing modules
1449    /// See https://github.com/gfx-rs/wgpu/blob/d9c054c645af0ea9ef81617c3e762fbf0f3fecda/wgpu-core/src/device/mod.rs#L515
1450    /// for how to set the subgroup_stages value.
1451    pub fn with_capabilities(self, capabilities: naga::valid::Capabilities) -> Self {
1452        Self {
1453            capabilities,
1454            validate: self.validate,
1455            ..Default::default()
1456        }
1457    }
1458
1459    /// check if a module with the given name has been added
1460    pub fn contains_module(&self, module_name: &str) -> bool {
1461        self.module_sets.contains_key(module_name)
1462    }
1463
1464    /// add a composable module to the composer.
1465    /// all modules imported by this module must already have been added
1466    pub fn add_composable_module(
1467        &mut self,
1468        desc: ComposableModuleDescriptor,
1469    ) -> Result<&ComposableModuleDefinition, ComposerError> {
1470        let ComposableModuleDescriptor {
1471            source,
1472            file_path,
1473            language,
1474            as_name,
1475            additional_imports,
1476            mut shader_defs,
1477        } = desc;
1478
1479        // reject a module containing the DECORATION strings
1480        if let Some(decor) = self.check_decoration_regex.find(source) {
1481            return Err(ComposerError {
1482                inner: ComposerErrorInner::DecorationInSource(decor.range()),
1483                source: ErrSource::Constructing {
1484                    path: file_path.to_owned(),
1485                    source: source.to_owned(),
1486                    offset: 0,
1487                },
1488            });
1489        }
1490
1491        let substituted_source = self.sanitize_and_set_auto_bindings(source);
1492
1493        let PreprocessorMetaData {
1494            name: module_name,
1495            mut imports,
1496            mut effective_defs,
1497            ..
1498        } = self
1499            .preprocessor
1500            .get_preprocessor_metadata(&substituted_source, false)
1501            .map_err(|inner| ComposerError {
1502                inner,
1503                source: ErrSource::Constructing {
1504                    path: file_path.to_owned(),
1505                    source: source.to_owned(),
1506                    offset: 0,
1507                },
1508            })?;
1509        let module_name = as_name.or(module_name);
1510        if module_name.is_none() {
1511            return Err(ComposerError {
1512                inner: ComposerErrorInner::NoModuleName,
1513                source: ErrSource::Constructing {
1514                    path: file_path.to_owned(),
1515                    source: source.to_owned(),
1516                    offset: 0,
1517                },
1518            });
1519        }
1520        let module_name = module_name.unwrap();
1521
1522        debug!(
1523            "adding module definition for {} with defs: {:?}",
1524            module_name, shader_defs
1525        );
1526
1527        // add custom imports
1528        let additional_imports = additional_imports.to_vec();
1529        imports.extend(
1530            additional_imports
1531                .iter()
1532                .cloned()
1533                .map(|def| ImportDefWithOffset {
1534                    definition: def,
1535                    offset: 0,
1536                }),
1537        );
1538
1539        for import in &imports {
1540            // we require modules already added so that we can capture the shader_defs that may impact us by impacting our dependencies
1541            let module_set = self
1542                .module_sets
1543                .get(&import.definition.import)
1544                .ok_or_else(|| ComposerError {
1545                    inner: ComposerErrorInner::ImportNotFound(
1546                        import.definition.import.clone(),
1547                        import.offset,
1548                    ),
1549                    source: ErrSource::Constructing {
1550                        path: file_path.to_owned(),
1551                        source: substituted_source.to_owned(),
1552                        offset: 0,
1553                    },
1554                })?;
1555            effective_defs.extend(module_set.effective_defs.iter().cloned());
1556            shader_defs.extend(
1557                module_set
1558                    .shader_defs
1559                    .iter()
1560                    .map(|def| (def.0.clone(), *def.1)),
1561            );
1562        }
1563
1564        // remove defs that are already specified through our imports
1565        effective_defs.retain(|name| !shader_defs.contains_key(name));
1566
1567        // can't gracefully report errors for more modules. perhaps this should be a warning
1568        assert!((self.module_sets.len() as u32) < u32::MAX >> SPAN_SHIFT);
1569        let module_index = self.module_sets.len() + 1;
1570
1571        let module_set = ComposableModuleDefinition {
1572            name: module_name.clone(),
1573            sanitized_source: substituted_source,
1574            file_path: file_path.to_owned(),
1575            language,
1576            effective_defs: effective_defs.into_iter().collect(),
1577            all_imports: imports.into_iter().map(|id| id.definition.import).collect(),
1578            additional_imports,
1579            shader_defs,
1580            module_index,
1581            modules: Default::default(),
1582        };
1583
1584        // invalidate dependent modules if this module already exists
1585        self.remove_composable_module(&module_name);
1586
1587        self.module_sets.insert(module_name.clone(), module_set);
1588        self.module_index.insert(module_index, module_name.clone());
1589        Ok(self.module_sets.get(&module_name).unwrap())
1590    }
1591
1592    /// remove a composable module. also removes modules that depend on this module, as we cannot be sure about
1593    /// the completeness of their effective shader defs any more...
1594    pub fn remove_composable_module(&mut self, module_name: &str) {
1595        // todo this could be improved by making effective defs an Option<HashSet> and populating on demand?
1596        let mut dependent_sets = Vec::new();
1597
1598        if self.module_sets.remove(module_name).is_some() {
1599            dependent_sets.extend(self.module_sets.iter().filter_map(|(dependent_name, set)| {
1600                if set.all_imports.contains(module_name) {
1601                    Some(dependent_name.clone())
1602                } else {
1603                    None
1604                }
1605            }));
1606        }
1607
1608        for dependent_set in dependent_sets {
1609            self.remove_composable_module(&dependent_set);
1610        }
1611    }
1612
    /// build a naga shader module
    ///
    /// the source in `desc` is passed through `sanitize_and_set_auto_bindings`, preprocessed with the
    /// supplied shader defs (extended by the `defines` reported in the preprocessor metadata and by the
    /// shader defs recorded on every imported module), combined with its imports, and finally validated
    /// if `self.validate` is set.
    ///
    /// # Errors
    ///
    /// returns a `ComposerError` when
    /// - reading the preprocessor metadata or preprocessing the source fails,
    /// - an import (from the source or from `additional_imports`) is not present in the composer
    ///   (`ImportNotFound`),
    /// - an imported module records a shader def whose value differs from the one already in effect
    ///   (`InconsistentShaderDefValue`),
    /// - `ensure_imports`, `create_composable_module` or applying function overrides fails, or
    /// - validation is enabled and the resulting module does not validate (`ShaderValidationError`).
    ///
    /// # Panics
    ///
    /// while building the error for a failed validation whose span maps to an imported module, the
    /// module index, module set and built module are looked up with `unwrap`; a missing entry panics.
    pub fn make_naga_module(
        &mut self,
        desc: NagaModuleDescriptor,
    ) -> Result<naga::Module, ComposerError> {
        let NagaModuleDescriptor {
            source,
            file_path,
            shader_type,
            mut shader_defs,
            additional_imports,
        } = desc;

        let sanitized_source = self.sanitize_and_set_auto_bindings(source);

        // fetch the (optional) module name and the defines declared by the source;
        // the defines are merged into the caller-supplied shader defs
        let PreprocessorMetaData { name, defines, .. } = self
            .preprocessor
            .get_preprocessor_metadata(&sanitized_source, true)
            .map_err(|inner| ComposerError {
                inner,
                source: ErrSource::Constructing {
                    path: file_path.to_owned(),
                    source: sanitized_source.to_owned(),
                    offset: 0,
                },
            })?;
        shader_defs.extend(defines);

        // a top-level shader does not need a name; fall back to the empty string
        let name = name.unwrap_or_default();

        // first preprocessing pass: only the import list is used here, so that the imports can be
        // checked and their shader defs gathered before the source is preprocessed for real below
        let PreprocessOutput { imports, .. } = self
            .preprocessor
            .preprocess(&sanitized_source, &shader_defs)
            .map_err(|inner| ComposerError {
                inner,
                source: ErrSource::Constructing {
                    path: file_path.to_owned(),
                    source: sanitized_source.to_owned(),
                    offset: 0,
                },
            })?;

        // make sure imports have been added
        // and gather additional defs specified at module level
        // (additional imports have no position in the source, so they are reported at offset 0)
        for (import_name, offset) in imports
            .iter()
            .map(|id| (&id.definition.import, id.offset))
            .chain(additional_imports.iter().map(|ai| (&ai.import, 0)))
        {
            if let Some(module_set) = self.module_sets.get(import_name) {
                for (def, value) in &module_set.shader_defs {
                    // the module's value is inserted unconditionally; a differing prior value is an error
                    if let Some(prior_value) = shader_defs.insert(def.clone(), *value) {
                        if prior_value != *value {
                            return Err(ComposerError {
                                inner: ComposerErrorInner::InconsistentShaderDefValue {
                                    def: def.clone(),
                                },
                                source: ErrSource::Constructing {
                                    path: file_path.to_owned(),
                                    source: sanitized_source.to_owned(),
                                    offset: 0,
                                },
                            });
                        }
                    }
                }
            } else {
                return Err(ComposerError {
                    inner: ComposerErrorInner::ImportNotFound(import_name.clone(), offset),
                    source: ErrSource::Constructing {
                        path: file_path.to_owned(),
                        source: sanitized_source.to_owned(),
                        offset: 0,
                    },
                });
            }
        }
        // run `ensure_imports` for the source's own imports and for the additional imports, using the
        // now-complete set of shader defs
        self.ensure_imports(
            imports.iter().map(|import| &import.definition),
            &shader_defs,
        )
        .map_err(|err| {
            err.into_composer_error(ErrSource::Constructing {
                path: file_path.to_owned(),
                source: sanitized_source.to_owned(),
                offset: 0,
            })
        })?;
        self.ensure_imports(additional_imports, &shader_defs)
            .map_err(|err| {
                err.into_composer_error(ErrSource::Constructing {
                    path: file_path.to_owned(),
                    source: sanitized_source.to_owned(),
                    offset: 0,
                })
            })?;

        // temporary definition for the top-level shader; it is never stored in `module_sets`
        // and uses module index 0
        let definition = ComposableModuleDefinition {
            name,
            sanitized_source: sanitized_source.clone(),
            language: shader_type.into(),
            file_path: file_path.to_owned(),
            module_index: 0,
            additional_imports: additional_imports.to_vec(),
            // we don't care about these for creating a top-level module
            effective_defs: Default::default(),
            all_imports: Default::default(),
            shader_defs: Default::default(),
            modules: Default::default(),
        };

        // second preprocessing pass: `shader_defs` now also holds the defs gathered from the imported
        // modules above, so this output is the one that gets compiled
        let PreprocessOutput {
            preprocessed_source,
            imports,
        } = self
            .preprocessor
            .preprocess(&sanitized_source, &shader_defs)
            .map_err(|inner| ComposerError {
                inner,
                source: ErrSource::Constructing {
                    path: file_path.to_owned(),
                    // last use of `sanitized_source`, so it is moved rather than cloned
                    source: sanitized_source,
                    offset: 0,
                },
            })?;

        let composable = self
            .create_composable_module(
                &definition,
                String::from(""),
                &shader_defs,
                false,
                false,
                &preprocessed_source,
                imports,
            )
            // re-wrap the error so it refers to this shader's path and preprocessed source,
            // keeping the offset reported by the inner error
            .map_err(|e| ComposerError {
                inner: e.inner,
                source: ErrSource::Constructing {
                    path: definition.file_path.to_owned(),
                    source: preprocessed_source.clone(),
                    offset: e.source.offset(),
                },
            })?;

        let mut derived = DerivedModule::default();

        // add every import of the composed shader to the derived module;
        // `already_added` is shared across the calls
        let mut already_added = Default::default();
        for import in &composable.imports {
            self.add_import(
                &mut derived,
                import,
                &shader_defs,
                false,
                &mut already_added,
            );
        }

        Self::add_composable_data(&mut derived, &composable, None, 0, false);

        // glsl shader types name their stage explicitly; for anything else `stage` is `None`
        // and each entry point keeps the stage it declares itself
        let stage = match shader_type {
            #[cfg(feature = "glsl")]
            ShaderType::GlslVertex => Some(naga::ShaderStage::Vertex),
            #[cfg(feature = "glsl")]
            ShaderType::GlslFragment => Some(naga::ShaderStage::Fragment),
            _ => None,
        };

        // rebuild the entry points against the derived module, mapping each function via
        // `localize_function`
        let mut entry_points = Vec::default();
        derived.set_shader_source(&composable.module_ir, 0);
        for ep in &composable.module_ir.entry_points {
            let mapped_func = derived.localize_function(&ep.function);
            entry_points.push(EntryPoint {
                name: ep.name.clone(),
                function: mapped_func,
                stage: stage.unwrap_or(ep.stage),
                early_depth_test: ep.early_depth_test,
                workgroup_size: ep.workgroup_size,
                workgroup_size_overrides: ep.workgroup_size_overrides,
            });
        }

        let mut naga_module = naga::Module {
            entry_points,
            ..derived.into()
        };

        // apply overrides
        if !composable.override_functions.is_empty() {
            let mut redirect = Redirector::new(naga_module);

            for (base_function, overrides) in composable.override_functions {
                // replacements already applied for this base function are passed to every later
                // `redirect_function` call (NOTE(review): presumably so they are left untouched - confirm)
                let mut omit = HashSet::default();

                // overrides are chained: each replacement becomes the function redirected next
                let mut original = base_function;
                for replacement in overrides {
                    let (_h_orig, _h_replace) = redirect
                        .redirect_function(&original, &replacement, &omit)
                        .map_err(|e| ComposerError {
                            inner: e.into(),
                            source: ErrSource::Constructing {
                                path: file_path.to_owned(),
                                source: preprocessed_source.clone(),
                                offset: composable.start_offset,
                            },
                        })?;
                    omit.insert(replacement.clone());
                    original = replacement;
                }
            }

            naga_module = redirect.into_module().map_err(|e| ComposerError {
                inner: e.into(),
                source: ErrSource::Constructing {
                    path: file_path.to_owned(),
                    source: preprocessed_source.clone(),
                    offset: composable.start_offset,
                },
            })?;
        }

        // validation
        if self.validate {
            let info = self.create_validator().validate(&naga_module);
            match info {
                Ok(_) => Ok(naga_module),
                Err(e) => {
                    // use the last span of the error to work out which source it belongs to
                    let original_span = e.spans().last();
                    let err_source = match original_span.and_then(|(span, _)| span.to_range()) {
                        Some(rng) => {
                            // the bits above SPAN_SHIFT give the module index; 0 is the top-level
                            // shader built here (see `module_index: 0` on `definition`)
                            let module_index = rng.start >> SPAN_SHIFT;
                            match module_index {
                                0 => ErrSource::Constructing {
                                    path: file_path.to_owned(),
                                    source: preprocessed_source.clone(),
                                    offset: composable.start_offset,
                                },
                                _ => {
                                    // the error lies in an imported module: report it against that
                                    // module. these lookups panic if the module cannot be found.
                                    let module_name =
                                        self.module_index.get(&module_index).unwrap().clone();
                                    let offset = self
                                        .module_sets
                                        .get(&module_name)
                                        .unwrap()
                                        .get_module(&shader_defs)
                                        .unwrap()
                                        .start_offset;
                                    ErrSource::Module {
                                        name: module_name,
                                        offset,
                                        defs: shader_defs.clone(),
                                    }
                                }
                            }
                        }
                        // no usable span: attribute the error to the top-level shader
                        None => ErrSource::Constructing {
                            path: file_path.to_owned(),
                            source: preprocessed_source.clone(),
                            offset: composable.start_offset,
                        },
                    };

                    Err(ComposerError {
                        inner: ComposerErrorInner::ShaderValidationError(e),
                        source: err_source,
                    })
                }
            }
        } else {
            Ok(naga_module)
        }
    }
1885}
1886
/// shared default-configured preprocessor, created lazily on first access.
/// used by `get_preprocessor_data` so that callers do not need a `Composer` instance.
static PREPROCESSOR: once_cell::sync::Lazy<Preprocessor> =
    once_cell::sync::Lazy::new(Preprocessor::default);
1889
1890/// Get module name and all required imports (ignoring shader_defs) from a shader string
1891pub fn get_preprocessor_data(
1892    source: &str,
1893) -> (
1894    Option<String>,
1895    Vec<ImportDefinition>,
1896    HashMap<String, ShaderDefValue>,
1897) {
1898    if let Ok(PreprocessorMetaData {
1899        name,
1900        imports,
1901        defines,
1902        ..
1903    }) = PREPROCESSOR.get_preprocessor_metadata(source, true)
1904    {
1905        (
1906            name,
1907            imports
1908                .into_iter()
1909                .map(|import_with_offset| import_with_offset.definition)
1910                .collect(),
1911            defines,
1912        )
1913    } else {
1914        // if errors occur we return nothing; the actual error will be displayed when the caller attempts to use the shader
1915        Default::default()
1916    }
1917}