naga_oil/compose/
mod.rs

1use indexmap::IndexMap;
2/// the compose module allows construction of shaders from modules (which are themselves shaders).
3///
4/// it does this by treating shaders as modules, and
5/// - building each module independently to naga IR
6/// - creating "header" files for each supported language, which are used to build dependent modules/shaders
7/// - making final shaders by combining the shader IR with the IR for imported modules
8///
9/// for multiple small shaders with large common imports, this can be faster than parsing the full source for each shader, and it allows for constructing shaders in a cleaner modular manner with better scope control.
10///
11/// ## imports
12///
/// shaders can be added to the composer as modules. this makes their types, constants, variables and functions available to modules/shaders that import them. note that importing a module will affect the final shader's global state if the module defines global variables with bindings.
14///
15/// modules must include a `#define_import_path` directive that names the module.
16///
17/// ```ignore
18/// #define_import_path my_module
19///
20/// fn my_func() -> f32 {
21///     return 1.0;
22/// }
23/// ```
24///
25/// shaders can then import the module with an `#import` directive (with an optional `as` name). at point of use, imported items must be qualified:
26///
27/// ```ignore
28/// #import my_module
29/// #import my_other_module as Mod2
30///
31/// fn main() -> f32 {
32///     let x = my_module::my_func();
33///     let y = Mod2::my_other_func();
34///     return x*y;
35/// }
36/// ```
37///
38/// or import a comma-separated list of individual items with a `#from` directive. at point of use, imported items must be prefixed with `::` :
39///
40/// ```ignore
41/// #from my_module import my_func, my_const
42///
43/// fn main() -> f32 {
44///     return ::my_func(::my_const);
45/// }
46/// ```
47///
48/// imports can be nested - modules may import other modules, but not recursively. when a new module is added, all its `#import`s must already have been added.
49/// the same module can be imported multiple times by different modules in the import tree.
50/// there is no overlap of namespaces, so the same function names (or type, constant, or variable names) may be used in different modules.
51///
/// note: when importing an item with the `#from` directive, the final shader will include the required dependencies (bindings, globals, consts, other functions) of the imported item, but will not include the rest of the imported module. it will, however, still include the whole of any modules imported by the imported module. this is probably not desired in general and may be fixed in a future version. currently, for more complete culling of unused dependencies, the `prune` module can be used.
53///
54/// ## overriding functions
55///
56/// virtual functions can be declared with the `virtual` keyword:
57/// ```ignore
58/// virtual fn point_light(world_position: vec3<f32>) -> vec3<f32> { ... }
59/// ```
60/// virtual functions defined in imported modules can then be overridden using the `override` keyword:
61///
62/// ```ignore
63/// #import bevy_pbr::lighting as Lighting
64///
65/// override fn Lighting::point_light (world_position: vec3<f32>) -> vec3<f32> {
66///     let original = Lighting::point_light(world_position);
67///     let quantized = vec3<u32>(original * 3.0);
68///     return vec3<f32>(quantized) / 3.0;
69/// }
70/// ```
71///
72/// override function definitions cause *all* calls to the original function in the entire shader scope to be replaced by calls to the new function, with the exception of calls within the override function itself.
73///
74/// the function signature of the override must match the base function.
75///
76/// overrides can be specified at any point in the final shader's import tree.
77///
78/// multiple overrides can be applied to the same function. for example, given :
79/// - a module `a` containing a function `f`,
80/// - a module `b` that imports `a`, and containing an `override a::f` function,
81/// - a module `c` that imports `a` and `b`, and containing an `override a::f` function,
82///
83/// then b and c both specify an override for `a::f`.
/// the `override fn a::f` declared in module `b` may call `a::f` within its body.
/// the `override fn a::f` declared in module `c` may call `a::f` within its body, but the call will be redirected to `b::f`.
/// any other calls to `a::f` (within modules `a` or `b`, or anywhere else) will end up redirected to `c::f`.
87/// in this way a chain or stack of overrides can be applied.
88///
/// different overrides of the same function can be specified in different import branches. the final stack will be ordered based on the first occurrence of the override in the import tree (using a depth-first search).
90///
91/// note that imports into a module/shader are processed in order, but are processed before the body of the current shader/module regardless of where they occur in that module, so there is no way to import a module containing an override and inject a call into the override stack prior to that imported override. you can instead create two modules each containing an override and import them into a parent module/shader to order them as required.
92/// override functions can currently only be defined in wgsl.
93///
94/// if the `override_any` crate feature is enabled, then the `virtual` keyword is not required for the function being overridden.
95///
96/// ## languages
97///
/// modules can be written in GLSL or WGSL. shaders with entry points can be imported as modules (provided they have a `#define_import_path` directive). entry points are available to call from imported modules either via their name (for WGSL) or via `module::main` (for GLSL).
99///
/// final shaders can also be written in GLSL or WGSL. for GLSL, users must specify whether the shader is a vertex shader or fragment shader via the `ShaderType` argument (GLSL compute shaders are not supported).
101///
102/// ## preprocessing
103///
104/// when generating a final shader or adding a composable module, a set of `shader_def` string/value pairs must be provided. The value can be a bool (`ShaderDefValue::Bool`), an i32 (`ShaderDefValue::Int`) or a u32 (`ShaderDefValue::UInt`).
105///
106/// these allow conditional compilation of parts of modules and the final shader. conditional compilation is performed with `#if` / `#ifdef` / `#ifndef`, `#else` and `#endif` preprocessor directives:
107///
108/// ```ignore
109/// fn get_number() -> f32 {
110///     #ifdef BIG_NUMBER
111///         return 999.0;
112///     #else
113///         return 0.999;
114///     #endif
115/// }
116/// ```
/// the `#ifdef` directive matches when the def name exists in the provided shader defs (regardless of value). the `#ifndef` directive is the reverse.
118///
119/// the `#if` directive requires a def name, an operator, and a value for comparison:
120/// - the def name must be a provided `shader_def` name.
121/// - the operator must be one of `==`, `!=`, `>=`, `>`, `<`, `<=`
/// - the value must be an integer literal if comparing to a `ShaderDefValue::Int`, or `true` or `false` if comparing to a `ShaderDefValue::Bool`.
123///
/// shader defs can also be used in the shader source with `#SHADER_DEF` or `#{SHADER_DEF}`, and will be replaced with their value.
125///
126/// ## error reporting
127///
128/// codespan reporting for errors is available using the error `emit_to_string` method. this requires validation to be enabled, which is true by default. `Composer::non_validating()` produces a non-validating composer that is not able to give accurate error reporting.
129///
130use naga::EntryPoint;
131use regex::Regex;
132use std::collections::{hash_map::Entry, BTreeMap, HashMap, HashSet};
133use tracing::{debug, trace};
134
135use crate::{
136    compose::preprocess::{PreprocessOutput, PreprocessorMetaData},
137    derive::DerivedModule,
138    redirect::Redirector,
139};
140
141pub use self::error::{ComposerError, ComposerErrorInner, ErrSource};
142use self::preprocess::Preprocessor;
143
144pub mod comment_strip_iter;
145pub mod error;
146pub mod parse_imports;
147pub mod preprocess;
148mod test;
149pub mod tokenizer;
150
/// Source language of a shader or composable module.
#[derive(Hash, PartialEq, Eq, Clone, Copy, Debug, Default)]
pub enum ShaderLanguage {
    /// WGSL source (the default).
    #[default]
    Wgsl,
    /// GLSL source; only available with the `glsl` crate feature.
    #[cfg(feature = "glsl")]
    Glsl,
}

/// Kind of shader being composed. GLSL shaders additionally specify their stage
/// (vertex or fragment); GLSL compute shaders are not supported.
#[derive(Hash, PartialEq, Eq, Clone, Copy, Debug, Default)]
pub enum ShaderType {
    /// A WGSL shader (the default).
    #[default]
    Wgsl,
    /// A GLSL vertex shader.
    #[cfg(feature = "glsl")]
    GlslVertex,
    /// A GLSL fragment shader.
    #[cfg(feature = "glsl")]
    GlslFragment,
}

impl From<ShaderType> for ShaderLanguage {
    /// Maps a shader type to the language its source is written in, discarding any
    /// stage information.
    fn from(shader_type: ShaderType) -> Self {
        match shader_type {
            #[cfg(feature = "glsl")]
            ShaderType::GlslVertex | ShaderType::GlslFragment => Self::Glsl,
            ShaderType::Wgsl => Self::Wgsl,
        }
    }
}
178
/// Value bound to a shader def name: a boolean, a signed or an unsigned integer.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
pub enum ShaderDefValue {
    Bool(bool),
    Int(i32),
    UInt(u32),
}

impl Default for ShaderDefValue {
    /// The default shader def value is `Bool(true)`.
    fn default() -> Self {
        Self::Bool(true)
    }
}

impl ShaderDefValue {
    /// Renders the inner value via its `Display` implementation (`true`, `-3`, `7`, ...).
    fn value_as_string(&self) -> String {
        match *self {
            Self::Bool(flag) => flag.to_string(),
            Self::Int(signed) => signed.to_string(),
            Self::UInt(unsigned) => unsigned.to_string(),
        }
    }
}

/// Owned, name-ordered map of shader defs. Backed by a `BTreeMap`, which - unlike
/// `HashMap` - implements `Hash`, so the whole set can be used as a map key.
#[derive(Clone, PartialEq, Eq, Hash, Debug, Default)]
pub struct OwnedShaderDefs(BTreeMap<String, ShaderDefValue>);

/// Key identifying one build of a module: the subset of the requested shader defs
/// that can affect that module.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
struct ModuleKey(OwnedShaderDefs);

impl ModuleKey {
    /// Builds a key from the entries of `key` whose names appear in `universe`.
    /// Defs not listed in `universe` are ignored, as are `universe` names that
    /// have no value in `key`.
    fn from_members(key: &HashMap<String, ShaderDefValue>, universe: &[String]) -> Self {
        let relevant = universe
            .iter()
            .filter_map(|name| key.get(name).map(|value| (name.clone(), *value)))
            .collect();
        Self(OwnedShaderDefs(relevant))
    }
}
216        ModuleKey(acc)
217    }
218}
219
220// a module built with a specific set of shader_defs
221#[derive(Default, Debug)]
222pub struct ComposableModule {
223    // module decoration, prefixed to all items from this module in the final source
224    pub decorated_name: String,
225    // module names required as imports, optionally with a list of items to import
226    pub imports: Vec<ImportDefinition>,
227    // types exported
228    pub owned_types: HashSet<String>,
229    // constants exported
230    pub owned_constants: HashSet<String>,
231    // vars exported
232    pub owned_vars: HashSet<String>,
233    // functions exported
234    pub owned_functions: HashSet<String>,
235    // local functions that can be overridden
236    pub virtual_functions: HashSet<String>,
237    // overriding functions defined in this module
238    // target function -> Vec<replacement functions>
239    pub override_functions: IndexMap<String, Vec<String>>,
240    // naga module, built against headers for any imports
241    module_ir: naga::Module,
242    // headers in different shader languages, used for building modules/shaders that import this module
243    // headers contain types, constants, global vars and empty function definitions -
244    // just enough to convert source strings that want to import this module into naga IR
245    // headers: HashMap<ShaderLanguage, String>,
246    header_ir: naga::Module,
247    // character offset of the start of the owned module string
248    start_offset: usize,
249}
250
251// data used to build a ComposableModule
252#[derive(Debug)]
253pub struct ComposableModuleDefinition {
254    pub name: String,
255    // shader text (with auto bindings replaced - we do this on module add as we only want to do it once to avoid burning slots)
256    pub sanitized_source: String,
257    // language
258    pub language: ShaderLanguage,
259    // source path for error display
260    pub file_path: String,
261    // shader def values bound to this module
262    pub shader_defs: HashMap<String, ShaderDefValue>,
263    // list of shader_defs that can affect this module
264    effective_defs: Vec<String>,
265    // full list of possible imports (regardless of shader_def configuration)
266    all_imports: HashSet<String>,
267    // additional imports to add (as though they were included in the source after any other imports)
268    additional_imports: Vec<ImportDefinition>,
269    // built composable modules for a given set of shader defs
270    modules: HashMap<ModuleKey, ComposableModule>,
271    // used in spans when this module is included
272    module_index: usize,
273    // preprocessor meta data
274    // metadata: PreprocessorMetaData,
275}
276
277impl ComposableModuleDefinition {
278    fn get_module(
279        &self,
280        shader_defs: &HashMap<String, ShaderDefValue>,
281    ) -> Option<&ComposableModule> {
282        self.modules
283            .get(&ModuleKey::from_members(shader_defs, &self.effective_defs))
284    }
285
286    fn insert_module(
287        &mut self,
288        shader_defs: &HashMap<String, ShaderDefValue>,
289        module: ComposableModule,
290    ) -> &ComposableModule {
291        match self
292            .modules
293            .entry(ModuleKey::from_members(shader_defs, &self.effective_defs))
294        {
295            Entry::Occupied(_) => panic!("entry already populated"),
296            Entry::Vacant(v) => v.insert(module),
297        }
298    }
299}
300
301#[derive(Debug, Clone, Default, PartialEq, Eq)]
302pub struct ImportDefinition {
303    pub import: String,
304    pub items: Vec<String>,
305}
306
307#[derive(Debug, Clone)]
308pub struct ImportDefWithOffset {
309    definition: ImportDefinition,
310    offset: usize,
311}
312
313/// module composer.
314/// stores any modules that can be imported into a shader
315/// and builds the final shader
316#[derive(Debug)]
317pub struct Composer {
318    pub validate: bool,
319    pub module_sets: HashMap<String, ComposableModuleDefinition>,
320    pub module_index: HashMap<usize, String>,
321    pub capabilities: naga::valid::Capabilities,
322    preprocessor: Preprocessor,
323    check_decoration_regex: Regex,
324    undecorate_regex: Regex,
325    virtual_fn_regex: Regex,
326    override_fn_regex: Regex,
327    undecorate_override_regex: Regex,
328    auto_binding_regex: Regex,
329    auto_binding_index: u32,
330}
331
332// shift for module index
333// 21 gives
334//   max size for shader of 2m characters
335//   max 2048 modules
336const SPAN_SHIFT: usize = 21;
337
338impl Default for Composer {
339    fn default() -> Self {
340        Self {
341            validate: true,
342            capabilities: Default::default(),
343            module_sets: Default::default(),
344            module_index: Default::default(),
345            preprocessor: Preprocessor::default(),
346            check_decoration_regex: Regex::new(
347                format!(
348                    "({}|{})",
349                    regex_syntax::escape(DECORATION_PRE),
350                    regex_syntax::escape(DECORATION_OVERRIDE_PRE)
351                )
352                .as_str(),
353            )
354            .unwrap(),
355            undecorate_regex: Regex::new(
356                format!(
357                    r"(\x1B\[\d+\w)?([\w\d_]+){}([A-Z0-9]*){}",
358                    regex_syntax::escape(DECORATION_PRE),
359                    regex_syntax::escape(DECORATION_POST)
360                )
361                .as_str(),
362            )
363            .unwrap(),
364            virtual_fn_regex: Regex::new(
365                r"(?P<lead>[\s]*virtual\s+fn\s+)(?P<function>[^\s]+)(?P<trail>\s*)\(",
366            )
367            .unwrap(),
368            override_fn_regex: Regex::new(
369                format!(
370                    r"(override\s+fn\s+)([^\s]+){}([\w\d]+){}(\s*)\(",
371                    regex_syntax::escape(DECORATION_PRE),
372                    regex_syntax::escape(DECORATION_POST)
373                )
374                .as_str(),
375            )
376            .unwrap(),
377            undecorate_override_regex: Regex::new(
378                format!(
379                    "{}([A-Z0-9]*){}",
380                    regex_syntax::escape(DECORATION_OVERRIDE_PRE),
381                    regex_syntax::escape(DECORATION_POST)
382                )
383                .as_str(),
384            )
385            .unwrap(),
386            auto_binding_regex: Regex::new(r"@binding\(auto\)").unwrap(),
387            auto_binding_index: 0,
388        }
389    }
390}
391
392const DECORATION_PRE: &str = "X_naga_oil_mod_X";
393const DECORATION_POST: &str = "X";
394
395// must be same length as DECORATION_PRE for spans to work
396const DECORATION_OVERRIDE_PRE: &str = "X_naga_oil_vrt_X";
397
398struct IrBuildResult {
399    module: naga::Module,
400    start_offset: usize,
401    override_functions: IndexMap<String, Vec<String>>,
402}
403
404impl Composer {
405    pub fn decorated_name(module_name: Option<&str>, item_name: &str) -> String {
406        match module_name {
407            Some(module_name) => format!("{}{}", item_name, Self::decorate(module_name)),
408            None => item_name.to_owned(),
409        }
410    }
411
412    fn decorate(module: &str) -> String {
413        let encoded = data_encoding::BASE32_NOPAD.encode(module.as_bytes());
414        format!("{DECORATION_PRE}{encoded}{DECORATION_POST}")
415    }
416
417    fn decode(from: &str) -> String {
418        String::from_utf8(data_encoding::BASE32_NOPAD.decode(from.as_bytes()).unwrap()).unwrap()
419    }
420
421    /// Shorthand for creating a naga validator.
422    fn create_validator(&self) -> naga::valid::Validator {
423        naga::valid::Validator::new(naga::valid::ValidationFlags::all(), self.capabilities)
424    }
425
426    fn undecorate(&self, string: &str) -> String {
427        let undecor = self
428            .undecorate_regex
429            .replace_all(string, |caps: &regex::Captures| {
430                format!(
431                    "{}{}::{}",
432                    caps.get(1).map(|cc| cc.as_str()).unwrap_or(""),
433                    Self::decode(caps.get(3).unwrap().as_str()),
434                    caps.get(2).unwrap().as_str()
435                )
436            });
437
438        let undecor =
439            self.undecorate_override_regex
440                .replace_all(&undecor, |caps: &regex::Captures| {
441                    format!(
442                        "override fn {}::",
443                        Self::decode(caps.get(1).unwrap().as_str())
444                    )
445                });
446
447        undecor.to_string()
448    }
449
450    fn sanitize_and_set_auto_bindings(&mut self, source: &str) -> String {
451        let mut substituted_source = source.replace("\r\n", "\n").replace('\r', "\n");
452        if !substituted_source.ends_with('\n') {
453            substituted_source.push('\n');
454        }
455
456        // replace @binding(auto) with an incrementing index
457        struct AutoBindingReplacer<'a> {
458            auto: &'a mut u32,
459        }
460
461        impl<'a> regex::Replacer for AutoBindingReplacer<'a> {
462            fn replace_append(&mut self, _: &regex::Captures<'_>, dst: &mut String) {
463                dst.push_str(&format!("@binding({})", self.auto));
464                *self.auto += 1;
465            }
466        }
467
468        let substituted_source = self.auto_binding_regex.replace_all(
469            &substituted_source,
470            AutoBindingReplacer {
471                auto: &mut self.auto_binding_index,
472            },
473        );
474
475        substituted_source.into_owned()
476    }
477
    /// Validates `naga_module` and writes it out as `language` source text, for use as
    /// an import header.
    ///
    /// - WGSL: written with `WriterFlags::EXPLICIT_TYPES`.
    /// - GLSL: a `vec4<f32>` type and a dummy vertex entry point returning it (bound to
    ///   `@builtin(position)`) are first added to `naga_module` - note the module is
    ///   mutated in place - then the module is re-validated and written as GLSL 450 with
    ///   unused items included. The first output line (the version declaration) and the
    ///   last three lines (assumed to be the dummy `main()`) are stripped.
    ///
    /// `header_for` is only used for trace logging in the GLSL path.
    ///
    /// # Errors
    /// `HeaderValidationError` if validation fails; `WgslBackError` / `GlslBackError` if
    /// the backend cannot write the module.
    fn naga_to_string(
        &self,
        naga_module: &mut naga::Module,
        language: ShaderLanguage,
        #[allow(unused)] header_for: &str, // Only used when GLSL is enabled
    ) -> Result<String, ComposerErrorInner> {
        // TODO: cache headers again
        // NOTE(review): in the GLSL path this `info` is unused - that branch shadows it
        // with a second validation after adding the dummy entry point.
        let info = self
            .create_validator()
            .validate(naga_module)
            .map_err(ComposerErrorInner::HeaderValidationError)?;

        match language {
            ShaderLanguage::Wgsl => naga::back::wgsl::write_string(
                naga_module,
                &info,
                naga::back::wgsl::WriterFlags::EXPLICIT_TYPES,
            )
            .map_err(ComposerErrorInner::WgslBackError),
            #[cfg(feature = "glsl")]
            ShaderLanguage::Glsl => {
                // result type for the dummy entry point below
                let vec4 = naga_module.types.insert(
                    naga::Type {
                        name: None,
                        inner: naga::TypeInner::Vector {
                            size: naga::VectorSize::Quad,
                            scalar: naga::Scalar::F32,
                        },
                    },
                    naga::Span::UNDEFINED,
                );
                // add a dummy entry point for glsl headers
                let dummy_entry_point = "dummy_module_entry_point".to_owned();
                let func = naga::Function {
                    name: Some(dummy_entry_point.clone()),
                    arguments: Default::default(),
                    result: Some(naga::FunctionResult {
                        ty: vec4,
                        binding: Some(naga::Binding::BuiltIn(naga::BuiltIn::Position {
                            invariant: false,
                        })),
                    }),
                    local_variables: Default::default(),
                    expressions: Default::default(),
                    named_expressions: Default::default(),
                    body: Default::default(),
                };
                let ep = EntryPoint {
                    name: dummy_entry_point.clone(),
                    stage: naga::ShaderStage::Vertex,
                    function: func,
                    early_depth_test: None,
                    workgroup_size: [0, 0, 0],
                };

                naga_module.entry_points.push(ep);

                // re-validate: the module now contains the dummy entry point
                let info = self
                    .create_validator()
                    .validate(naga_module)
                    .map_err(ComposerErrorInner::HeaderValidationError)?;

                let mut string = String::new();
                let options = naga::back::glsl::Options {
                    version: naga::back::glsl::Version::Desktop(450),
                    writer_flags: naga::back::glsl::WriterFlags::INCLUDE_UNUSED_ITEMS,
                    ..Default::default()
                };
                let pipeline_options = naga::back::glsl::PipelineOptions {
                    shader_stage: naga::ShaderStage::Vertex,
                    entry_point: dummy_entry_point,
                    multiview: None,
                };
                let mut writer = naga::back::glsl::Writer::new(
                    &mut string,
                    naga_module,
                    &info,
                    &options,
                    &pipeline_options,
                    naga::proc::BoundsCheckPolicies::default(),
                )
                .map_err(ComposerErrorInner::GlslBackError)?;

                writer.write().map_err(ComposerErrorInner::GlslBackError)?;

                // strip version decl and main() impl
                // NOTE(review): relies on the writer emitting the version as the first
                // line and main() as the last three; this slice panics if the output has
                // fewer than four lines.
                let lines: Vec<_> = string.lines().collect();
                let string = lines[1..lines.len() - 3].join("\n");
                trace!("glsl header for {}:\n\"\n{:?}\n\"", header_for, string);

                Ok(string)
            }
        }
    }
572
573    // build naga module for a given shader_def configuration. builds a minimal self-contained module built against headers for imports
574    fn create_module_ir(
575        &self,
576        name: &str,
577        source: String,
578        language: ShaderLanguage,
579        imports: &[ImportDefinition],
580        shader_defs: &HashMap<String, ShaderDefValue>,
581    ) -> Result<IrBuildResult, ComposerError> {
582        debug!("creating IR for {} with defs: {:?}", name, shader_defs);
583
584        let mut module_string = match language {
585            ShaderLanguage::Wgsl => String::new(),
586            #[cfg(feature = "glsl")]
587            ShaderLanguage::Glsl => String::from("#version 450\n"),
588        };
589
590        let mut override_functions: IndexMap<String, Vec<String>> = IndexMap::default();
591        let mut added_imports: HashSet<String> = HashSet::new();
592        let mut header_module = DerivedModule::default();
593
594        for import in imports {
595            if added_imports.contains(&import.import) {
596                continue;
597            }
598            // add to header module
599            self.add_import(
600                &mut header_module,
601                import,
602                shader_defs,
603                true,
604                &mut added_imports,
605            );
606
607            // // we must have ensured these exist with Composer::ensure_imports()
608            trace!("looking for {}", import.import);
609            let import_module_set = self.module_sets.get(&import.import).unwrap();
610            trace!("with defs {:?}", shader_defs);
611            let module = import_module_set.get_module(shader_defs).unwrap();
612            trace!("ok");
613
614            // gather overrides
615            if !module.override_functions.is_empty() {
616                for (original, replacements) in &module.override_functions {
617                    match override_functions.entry(original.clone()) {
618                        indexmap::map::Entry::Occupied(o) => {
619                            let existing = o.into_mut();
620                            let new_replacements: Vec<_> = replacements
621                                .iter()
622                                .filter(|rep| !existing.contains(rep))
623                                .cloned()
624                                .collect();
625                            existing.extend(new_replacements);
626                        }
627                        indexmap::map::Entry::Vacant(v) => {
628                            v.insert(replacements.clone());
629                        }
630                    }
631                }
632            }
633        }
634
635        let composed_header = self
636            .naga_to_string(&mut header_module.into(), language, name)
637            .map_err(|inner| ComposerError {
638                inner,
639                source: ErrSource::Module {
640                    name: name.to_owned(),
641                    offset: 0,
642                    defs: shader_defs.clone(),
643                },
644            })?;
645        module_string.push_str(&composed_header);
646
647        let start_offset = module_string.len();
648
649        module_string.push_str(&source);
650
651        trace!(
652            "parsing {}: {}, header len {}, total len {}",
653            name,
654            module_string,
655            start_offset,
656            module_string.len()
657        );
658        let module = match language {
659            ShaderLanguage::Wgsl => naga::front::wgsl::parse_str(&module_string).map_err(|e| {
660                debug!("full err'd source file: \n---\n{}\n---", module_string);
661                ComposerError {
662                    inner: ComposerErrorInner::WgslParseError(e),
663                    source: ErrSource::Module {
664                        name: name.to_owned(),
665                        offset: start_offset,
666                        defs: shader_defs.clone(),
667                    },
668                }
669            })?,
670            #[cfg(feature = "glsl")]
671            ShaderLanguage::Glsl => naga::front::glsl::Frontend::default()
672                .parse(
673                    &naga::front::glsl::Options {
674                        stage: naga::ShaderStage::Vertex,
675                        defines: Default::default(),
676                    },
677                    &module_string,
678                )
679                .map_err(|e| {
680                    debug!("full err'd source file: \n---\n{}\n---", module_string);
681                    ComposerError {
682                        inner: ComposerErrorInner::GlslParseError(e),
683                        source: ErrSource::Module {
684                            name: name.to_owned(),
685                            offset: start_offset,
686                            defs: shader_defs.clone(),
687                        },
688                    }
689                })?,
690        };
691
692        Ok(IrBuildResult {
693            module,
694            start_offset,
695            override_functions,
696        })
697    }
698
699    // check that identifiers exported by a module do not get modified in string export
700    fn validate_identifiers(
701        source_ir: &naga::Module,
702        lang: ShaderLanguage,
703        header: &str,
704        module_decoration: &str,
705        owned_types: &HashSet<String>,
706    ) -> Result<(), ComposerErrorInner> {
707        // TODO: remove this once glsl front support is complete
708        #[cfg(feature = "glsl")]
709        if lang == ShaderLanguage::Glsl {
710            return Ok(());
711        }
712
713        let recompiled = match lang {
714            ShaderLanguage::Wgsl => naga::front::wgsl::parse_str(header).unwrap(),
715            #[cfg(feature = "glsl")]
716            ShaderLanguage::Glsl => naga::front::glsl::Frontend::default()
717                .parse(
718                    &naga::front::glsl::Options {
719                        stage: naga::ShaderStage::Vertex,
720                        defines: Default::default(),
721                    },
722                    &format!("{}\n{}", header, "void main() {}"),
723                )
724                .map_err(|e| {
725                    debug!("full err'd source file: \n---\n{header}\n---");
726                    ComposerErrorInner::GlslParseError(e)
727                })?,
728        };
729
730        let recompiled_types: IndexMap<_, _> = recompiled
731            .types
732            .iter()
733            .flat_map(|(h, ty)| ty.name.as_deref().map(|name| (name, h)))
734            .collect();
735        for (h, ty) in source_ir.types.iter() {
736            if let Some(name) = &ty.name {
737                let decorated_type_name = format!("{name}{module_decoration}");
738                if !owned_types.contains(&decorated_type_name) {
739                    continue;
740                }
741                match recompiled_types.get(decorated_type_name.as_str()) {
742                    Some(recompiled_h) => {
743                        if let naga::TypeInner::Struct { members, .. } = &ty.inner {
744                            let recompiled_ty = recompiled.types.get_handle(*recompiled_h).unwrap();
745                            let naga::TypeInner::Struct {
746                                members: recompiled_members,
747                                ..
748                            } = &recompiled_ty.inner
749                            else {
750                                panic!();
751                            };
752                            for (member, recompiled_member) in
753                                members.iter().zip(recompiled_members)
754                            {
755                                if member.name != recompiled_member.name {
756                                    return Err(ComposerErrorInner::InvalidIdentifier {
757                                        original: member.name.clone().unwrap_or_default(),
758                                        at: source_ir.types.get_span(h),
759                                    });
760                                }
761                            }
762                        }
763                    }
764                    None => {
765                        return Err(ComposerErrorInner::InvalidIdentifier {
766                            original: name.clone(),
767                            at: source_ir.types.get_span(h),
768                        })
769                    }
770                }
771            }
772        }
773
774        let recompiled_consts: HashSet<_> = recompiled
775            .constants
776            .iter()
777            .flat_map(|(_, c)| c.name.as_deref())
778            .filter(|name| name.ends_with(module_decoration))
779            .collect();
780        for (h, c) in source_ir.constants.iter() {
781            if let Some(name) = &c.name {
782                if name.ends_with(module_decoration) && !recompiled_consts.contains(name.as_str()) {
783                    return Err(ComposerErrorInner::InvalidIdentifier {
784                        original: name.clone(),
785                        at: source_ir.constants.get_span(h),
786                    });
787                }
788            }
789        }
790
791        let recompiled_globals: HashSet<_> = recompiled
792            .global_variables
793            .iter()
794            .flat_map(|(_, c)| c.name.as_deref())
795            .filter(|name| name.ends_with(module_decoration))
796            .collect();
797        for (h, gv) in source_ir.global_variables.iter() {
798            if let Some(name) = &gv.name {
799                if name.ends_with(module_decoration) && !recompiled_globals.contains(name.as_str())
800                {
801                    return Err(ComposerErrorInner::InvalidIdentifier {
802                        original: name.clone(),
803                        at: source_ir.global_variables.get_span(h),
804                    });
805                }
806            }
807        }
808
809        let recompiled_fns: HashSet<_> = recompiled
810            .functions
811            .iter()
812            .flat_map(|(_, c)| c.name.as_deref())
813            .filter(|name| name.ends_with(module_decoration))
814            .collect();
815        for (h, f) in source_ir.functions.iter() {
816            if let Some(name) = &f.name {
817                if name.ends_with(module_decoration) && !recompiled_fns.contains(name.as_str()) {
818                    return Err(ComposerErrorInner::InvalidIdentifier {
819                        original: name.clone(),
820                        at: source_ir.functions.get_span(h),
821                    });
822                }
823            }
824        }
825
826        Ok(())
827    }
828
    // build a ComposableModule from a ComposableModuleDefinition, for a given set of shader defs
    // - rewrite virtual / override function declarations into plain functions
    // - build the naga IR (against headers)
    // - record any types/vars/constants/functions that are defined within this module
    // - build headers for each supported language
    //
    // parameters:
    // - `module_definition`: supplies the module name, language and `additional_imports` (which
    //   are appended to `imports`)
    // - `module_decoration`: suffix appended to the name of every item this module owns; stored
    //   as the result's `decorated_name`
    // - `shader_defs`: forwarded to `create_module_ir`, and used to pick the variant of a module
    //   targeted by an override function
    // - `create_headers`: together with `self.validate`, enables the identifier-validation pass
    //   at the end. `header_ir` itself is built whether or not this is set
    // - `demote_entrypoints`: when set, entry points are renamed with `module_decoration`,
    //   recorded as owned functions, and copied into `module_ir` as plain functions with their
    //   argument/result bindings removed
    // - `source`: the module source (the caller `ensure_import` passes preprocessed source)
    // - `imports`: the module's imports; their offsets are dropped here
    //
    // override-validation errors (raised before the IR exists) are reported at offset 0; errors
    // after the IR is built are reported relative to the `start_offset` from `create_module_ir`.
    #[allow(clippy::too_many_arguments)]
    fn create_composable_module(
        &mut self,
        module_definition: &ComposableModuleDefinition,
        module_decoration: String,
        shader_defs: &HashMap<String, ShaderDefValue>,
        create_headers: bool,
        demote_entrypoints: bool,
        source: &str,
        imports: Vec<ImportDefWithOffset>,
    ) -> Result<ComposableModule, ComposerError> {
        let mut imports: Vec<_> = imports
            .into_iter()
            .map(|import_with_offset| import_with_offset.definition)
            .collect();
        imports.extend(module_definition.additional_imports.to_vec());

        trace!(
            "create composable module {}: source len {}",
            module_definition.name,
            source.len()
        );

        // record virtual/overridable functions.
        // each match of `virtual_fn_regex` (presumably a `virtual fn name(` declaration; the regex
        // is defined elsewhere) is rewritten to `fn name(`, with spaces padding out the removed
        // text. NOTE(review): the padding appears intended to keep the rewritten text the same
        // length as the match so source offsets stay valid - confirm against the regex.
        let mut virtual_functions: HashSet<String> = Default::default();
        let source = self
            .virtual_fn_regex
            .replace_all(source, |cap: &regex::Captures| {
                let target_function = cap.get(2).unwrap().as_str().to_owned();

                let replacement_str = format!(
                    "{}fn {}{}(",
                    " ".repeat(cap.get(1).unwrap().range().len() - 3),
                    target_function,
                    " ".repeat(cap.get(3).unwrap().range().len()),
                );

                virtual_functions.insert(target_function);

                replacement_str
            });

        // record and rename override functions.
        // for each match of `override_fn_regex`, capture 2 is the function name and capture 3 the
        // target module as written in the (already substituted) source; `Self::decode` turns the
        // latter back into the module name. the declaration becomes a plain function named
        // `{function}{DECORATION_OVERRIDE_PRE}{module}{DECORATION_POST}`, and
        // `local_override_functions` maps that new name to the overridden function's decorated
        // name `{function}{DECORATION_PRE}{module}{DECORATION_POST}`.
        let mut local_override_functions: IndexMap<String, String> = Default::default();

        // the replacement closure must return a string and cannot propagate errors, so failures
        // are stashed here. if several overrides fail, only the last one is reported.
        #[cfg(not(feature = "override_any"))]
        let mut override_error = None;

        let source =
            self.override_fn_regex
                .replace_all(&source, |cap: &regex::Captures| {
                    let target_module = cap.get(3).unwrap().as_str().to_owned();
                    let target_function = cap.get(2).unwrap().as_str().to_owned();

                    #[cfg(not(feature = "override_any"))]
                    {
                        let wrap_err = |inner: ComposerErrorInner| -> ComposerError {
                            ComposerError {
                                inner,
                                source: ErrSource::Module {
                                    name: module_definition.name.to_owned(),
                                    offset: 0,
                                    defs: shader_defs.clone(),
                                },
                            }
                        };

                        // ensure overrides are applied to virtual functions
                        // (without the `override_any` feature, the target module must have
                        // declared the overridden function virtual)
                        let raw_module_name = Self::decode(&target_module);
                        let module_set = self.module_sets.get(&raw_module_name);

                        match module_set {
                            None => {
                                // TODO this should be unreachable?
                                let pos = cap.get(3).unwrap().start();
                                override_error = Some(wrap_err(
                                    ComposerErrorInner::ImportNotFound(raw_module_name, pos),
                                ));
                            }
                            Some(module_set) => {
                                // NOTE(review): unwraps, so the target module's variant for
                                // `shader_defs` must already be built - presumably ensured by
                                // the caller; confirm.
                                let module = module_set.get_module(shader_defs).unwrap();
                                if !module.virtual_functions.contains(&target_function) {
                                    let pos = cap.get(2).unwrap().start();
                                    override_error =
                                        Some(wrap_err(ComposerErrorInner::OverrideNotVirtual {
                                            name: target_function.clone(),
                                            pos,
                                        }));
                                }
                            }
                        }
                    }

                    // decorated name of the function being overridden
                    let base_name = format!(
                        "{}{}{}{}",
                        target_function.as_str(),
                        DECORATION_PRE,
                        target_module.as_str(),
                        DECORATION_POST,
                    );
                    // name given to this module's replacement implementation
                    let rename = format!(
                        "{}{}{}{}",
                        target_function.as_str(),
                        DECORATION_OVERRIDE_PRE,
                        target_module.as_str(),
                        DECORATION_POST,
                    );

                    let replacement_str = format!(
                        "{}fn {}{}(",
                        " ".repeat(cap.get(1).unwrap().range().len() - 3),
                        rename,
                        " ".repeat(cap.get(4).unwrap().range().len()),
                    );

                    local_override_functions.insert(rename, base_name);

                    replacement_str
                })
                .to_string();

        #[cfg(not(feature = "override_any"))]
        if let Some(err) = override_error {
            return Err(err);
        }

        trace!("local overrides: {:?}", local_override_functions);
        // source length after the virtual/override rewrites
        trace!(
            "create composable module {}: source len {}",
            module_definition.name,
            source.len()
        );

        let IrBuildResult {
            module: mut source_ir,
            start_offset,
            mut override_functions,
        } = self.create_module_ir(
            &module_definition.name,
            source,
            module_definition.language,
            &imports,
            shader_defs,
        )?;

        // from here on errors need to be reported using the modified source with start_offset
        let wrap_err = |inner: ComposerErrorInner| -> ComposerError {
            ComposerError {
                inner,
                source: ErrSource::Module {
                    name: module_definition.name.to_owned(),
                    offset: start_offset,
                    defs: shader_defs.clone(),
                },
            }
        };

        // add our local override to the total set of overrides for the given function.
        // entries are keyed by the overridden function's decorated name; our implementation is
        // recorded under its rename plus this module's decoration.
        for (rename, base_name) in &local_override_functions {
            override_functions
                .entry(base_name.clone())
                .or_default()
                .push(format!("{rename}{module_decoration}"));
        }

        // rename and record owned items (except types which can't be mutably accessed).
        // a name that already contains DECORATION_PRE came from an import and is not owned.
        let mut owned_constants = IndexMap::new();
        for (h, c) in source_ir.constants.iter_mut() {
            if let Some(name) = c.name.as_mut() {
                if !name.contains(DECORATION_PRE) {
                    *name = format!("{name}{module_decoration}");
                    owned_constants.insert(name.clone(), h);
                }
            }
        }

        // These are naga/wgpu's pipeline override constants, not naga_oil's overrides
        let mut owned_pipeline_overrides = IndexMap::new();
        for (h, po) in source_ir.overrides.iter_mut() {
            if let Some(name) = po.name.as_mut() {
                if !name.contains(DECORATION_PRE) {
                    *name = format!("{name}{module_decoration}");
                    owned_pipeline_overrides.insert(name.clone(), h);
                }
            }
        }

        let mut owned_vars = IndexMap::new();
        for (h, gv) in source_ir.global_variables.iter_mut() {
            if let Some(name) = gv.name.as_mut() {
                if !name.contains(DECORATION_PRE) {
                    *name = format!("{name}{module_decoration}");

                    owned_vars.insert(name.clone(), h);
                }
            }
        }

        let mut owned_functions = IndexMap::new();
        for (h_f, f) in source_ir.functions.iter_mut() {
            if let Some(name) = f.name.as_mut() {
                if !name.contains(DECORATION_PRE) {
                    *name = format!("{name}{module_decoration}");

                    // create dummy header function (signature only: no locals, expressions
                    // or body)
                    let header_function = naga::Function {
                        name: Some(name.clone()),
                        arguments: f.arguments.to_vec(),
                        result: f.result.clone(),
                        local_variables: Default::default(),
                        expressions: Default::default(),
                        named_expressions: Default::default(),
                        body: Default::default(),
                    };

                    // record owned function
                    owned_functions.insert(name.clone(), (Some(h_f), header_function));
                }
            }
        }

        if demote_entrypoints {
            // make normal functions out of the source entry points
            // (an unnamed entry point is called "main" before decoration; the header stub has
            // its argument/result bindings removed and no source handle)
            for ep in &mut source_ir.entry_points {
                ep.function.name = Some(format!(
                    "{}{}",
                    ep.function.name.as_deref().unwrap_or("main"),
                    module_decoration,
                ));
                let header_function = naga::Function {
                    name: ep.function.name.clone(),
                    arguments: ep
                        .function
                        .arguments
                        .iter()
                        .cloned()
                        .map(|arg| naga::FunctionArgument {
                            name: arg.name,
                            ty: arg.ty,
                            binding: None,
                        })
                        .collect(),
                    result: ep.function.result.clone().map(|res| naga::FunctionResult {
                        ty: res.ty,
                        binding: None,
                    }),
                    local_variables: Default::default(),
                    expressions: Default::default(),
                    named_expressions: Default::default(),
                    body: Default::default(),
                };

                owned_functions.insert(ep.function.name.clone().unwrap(), (None, header_function));
            }
        };

        let mut module_builder = DerivedModule::default();
        let mut header_builder = DerivedModule::default();
        module_builder.set_shader_source(&source_ir, 0);
        header_builder.set_shader_source(&source_ir, 0);

        let mut owned_types = HashSet::new();
        for (h, ty) in source_ir.types.iter() {
            if let Some(name) = &ty.name {
                // we need to exclude autogenerated struct names, i.e. those that begin with "__"
                // "__" is a reserved prefix for naga so user variables cannot use it.
                if !name.contains(DECORATION_PRE) && !name.starts_with("__") {
                    let name = format!("{name}{module_decoration}");
                    owned_types.insert(name.clone());
                    // copy and rename types
                    module_builder.rename_type(&h, Some(name.clone()));
                    header_builder.rename_type(&h, Some(name));
                    continue;
                }
            }

            // copy all required types
            module_builder.import_type(&h);
        }

        // copy owned constants, pipeline overrides and globals into both header and module
        for h in owned_constants.values() {
            header_builder.import_const(h);
            module_builder.import_const(h);
        }

        for h in owned_pipeline_overrides.values() {
            header_builder.import_pipeline_override(h);
            module_builder.import_pipeline_override(h);
        }

        for h in owned_vars.values() {
            header_builder.import_global(h);
            module_builder.import_global(h);
        }

        // only stubs of owned functions into the header
        for (h_f, f) in owned_functions.values() {
            let span = h_f
                .map(|h_f| source_ir.functions.get_span(h_f))
                .unwrap_or(naga::Span::UNDEFINED);
            header_builder.import_function(f, span); // header stub function
        }
        // all functions into the module (note source_ir only contains stubs for imported functions)
        for (h_f, f) in source_ir.functions.iter() {
            let span = source_ir.functions.get_span(h_f);
            module_builder.import_function(f, span);
        }
        // including entry points as vanilla functions (bindings removed) if required
        if demote_entrypoints {
            for ep in &source_ir.entry_points {
                let mut f = ep.function.clone();
                f.arguments = f
                    .arguments
                    .into_iter()
                    .map(|arg| naga::FunctionArgument {
                        name: arg.name,
                        ty: arg.ty,
                        binding: None,
                    })
                    .collect();
                f.result = f.result.map(|res| naga::FunctionResult {
                    ty: res.ty,
                    binding: None,
                });

                module_builder.import_function(&f, naga::Span::UNDEFINED);
                // todo figure out how to get span info for entrypoints
            }
        }

        let module_ir = module_builder.into_module_with_entrypoints();
        let mut header_ir: naga::Module = header_builder.into();

        if self.validate && create_headers {
            // check that identifiers haven't been renamed: write `header_ir` out as source in
            // each supported language and have `validate_identifiers` confirm the owned items
            // are still present under their decorated names
            // (without the `glsl` feature the array below has a single element)
            #[allow(clippy::single_element_loop)]
            for language in [
                ShaderLanguage::Wgsl,
                #[cfg(feature = "glsl")]
                ShaderLanguage::Glsl,
            ] {
                let header = self
                    .naga_to_string(&mut header_ir, language, &module_definition.name)
                    .map_err(wrap_err)?;
                Self::validate_identifiers(
                    &source_ir,
                    language,
                    &header,
                    &module_decoration,
                    &owned_types,
                )
                .map_err(wrap_err)?;
            }
        }

        let composable_module = ComposableModule {
            decorated_name: module_decoration,
            imports,
            owned_types,
            owned_constants: owned_constants.into_keys().collect(),
            owned_vars: owned_vars.into_keys().collect(),
            owned_functions: owned_functions.into_keys().collect(),
            virtual_functions,
            override_functions,
            module_ir,
            header_ir,
            start_offset,
        };

        Ok(composable_module)
    }
1205
1206    // shunt all data owned by a composable into a derived module
1207    fn add_composable_data<'a>(
1208        derived: &mut DerivedModule<'a>,
1209        composable: &'a ComposableModule,
1210        items: Option<&Vec<String>>,
1211        span_offset: usize,
1212        header: bool,
1213    ) {
1214        let items: Option<HashSet<String>> = items.map(|items| {
1215            items
1216                .iter()
1217                .map(|item| format!("{}{}", item, composable.decorated_name))
1218                .collect()
1219        });
1220        let items = items.as_ref();
1221
1222        let source_ir = match header {
1223            true => &composable.header_ir,
1224            false => &composable.module_ir,
1225        };
1226
1227        derived.set_shader_source(source_ir, span_offset);
1228
1229        for (h, ty) in source_ir.types.iter() {
1230            if let Some(name) = &ty.name {
1231                if composable.owned_types.contains(name)
1232                    && items.map_or(true, |items| items.contains(name))
1233                {
1234                    derived.import_type(&h);
1235                }
1236            }
1237        }
1238
1239        for (h, c) in source_ir.constants.iter() {
1240            if let Some(name) = &c.name {
1241                if composable.owned_constants.contains(name)
1242                    && items.map_or(true, |items| items.contains(name))
1243                {
1244                    derived.import_const(&h);
1245                }
1246            }
1247        }
1248
1249        for (h, po) in source_ir.overrides.iter() {
1250            if let Some(name) = &po.name {
1251                if composable.owned_functions.contains(name)
1252                    && items.map_or(true, |items| items.contains(name))
1253                {
1254                    derived.import_pipeline_override(&h);
1255                }
1256            }
1257        }
1258
1259        for (h, v) in source_ir.global_variables.iter() {
1260            if let Some(name) = &v.name {
1261                if composable.owned_vars.contains(name)
1262                    && items.map_or(true, |items| items.contains(name))
1263                {
1264                    derived.import_global(&h);
1265                }
1266            }
1267        }
1268
1269        for (h_f, f) in source_ir.functions.iter() {
1270            if let Some(name) = &f.name {
1271                if composable.owned_functions.contains(name)
1272                    && (items.map_or(true, |items| items.contains(name))
1273                        || composable
1274                            .override_functions
1275                            .values()
1276                            .any(|v| v.contains(name)))
1277                {
1278                    let span = composable.module_ir.functions.get_span(h_f);
1279                    derived.import_function_if_new(f, span);
1280                }
1281            }
1282        }
1283
1284        derived.clear_shader_source();
1285    }
1286
1287    // add an import (and recursive imports) into a derived module
1288    fn add_import<'a>(
1289        &'a self,
1290        derived: &mut DerivedModule<'a>,
1291        import: &ImportDefinition,
1292        shader_defs: &HashMap<String, ShaderDefValue>,
1293        header: bool,
1294        already_added: &mut HashSet<String>,
1295    ) {
1296        if already_added.contains(&import.import) {
1297            trace!("skipping {}, already added", import.import);
1298            return;
1299        }
1300
1301        let import_module_set = self.module_sets.get(&import.import).unwrap();
1302        let module = import_module_set.get_module(shader_defs).unwrap();
1303
1304        for import in &module.imports {
1305            self.add_import(derived, import, shader_defs, header, already_added);
1306        }
1307
1308        Self::add_composable_data(
1309            derived,
1310            module,
1311            Some(&import.items),
1312            import_module_set.module_index << SPAN_SHIFT,
1313            header,
1314        );
1315    }
1316
1317    fn ensure_import(
1318        &mut self,
1319        module_set: &ComposableModuleDefinition,
1320        shader_defs: &HashMap<String, ShaderDefValue>,
1321    ) -> Result<ComposableModule, EnsureImportsError> {
1322        let PreprocessOutput {
1323            preprocessed_source,
1324            imports,
1325        } = self
1326            .preprocessor
1327            .preprocess(&module_set.sanitized_source, shader_defs)
1328            .map_err(|inner| {
1329                EnsureImportsError::from(ComposerError {
1330                    inner,
1331                    source: ErrSource::Module {
1332                        name: module_set.name.to_owned(),
1333                        offset: 0,
1334                        defs: shader_defs.clone(),
1335                    },
1336                })
1337            })?;
1338
1339        self.ensure_imports(imports.iter().map(|import| &import.definition), shader_defs)?;
1340        self.ensure_imports(&module_set.additional_imports, shader_defs)?;
1341
1342        self.create_composable_module(
1343            module_set,
1344            Self::decorate(&module_set.name),
1345            shader_defs,
1346            true,
1347            true,
1348            &preprocessed_source,
1349            imports,
1350        )
1351        .map_err(|err| err.into())
1352    }
1353
1354    // build required ComposableModules for a given set of shader_defs
1355    fn ensure_imports<'a>(
1356        &mut self,
1357        imports: impl IntoIterator<Item = &'a ImportDefinition>,
1358        shader_defs: &HashMap<String, ShaderDefValue>,
1359    ) -> Result<(), EnsureImportsError> {
1360        for ImportDefinition { import, .. } in imports.into_iter() {
1361            let Some(module_set) = self.module_sets.get(import) else {
1362                return Err(EnsureImportsError::MissingImport(import.to_owned()));
1363            };
1364            if module_set.get_module(shader_defs).is_some() {
1365                continue;
1366            }
1367
1368            // we need to build the module
1369            // take the set so we can recurse without borrowing
1370            let (set_key, mut module_set) = self.module_sets.remove_entry(import).unwrap();
1371
1372            match self.ensure_import(&module_set, shader_defs) {
1373                Ok(module) => {
1374                    module_set.insert_module(shader_defs, module);
1375                    self.module_sets.insert(set_key, module_set);
1376                }
1377                Err(e) => {
1378                    self.module_sets.insert(set_key, module_set);
1379                    return Err(e);
1380                }
1381            }
1382        }
1383
1384        Ok(())
1385    }
1386}
1387
/// Failure while making sure the modules required by an import list are built.
pub enum EnsureImportsError {
    /// an import named a module that has not been added to the composer
    MissingImport(String),
    /// preprocessing or building one of the required modules failed
    ComposerError(ComposerError),
}
1392
1393impl EnsureImportsError {
1394    fn into_composer_error(self, err_source: ErrSource) -> ComposerError {
1395        match self {
1396            EnsureImportsError::MissingImport(import) => ComposerError {
1397                inner: ComposerErrorInner::ImportNotFound(import.to_owned(), 0),
1398                source: err_source,
1399            },
1400            EnsureImportsError::ComposerError(err) => err,
1401        }
1402    }
1403}
1404
1405impl From<ComposerError> for EnsureImportsError {
1406    fn from(value: ComposerError) -> Self {
1407        EnsureImportsError::ComposerError(value)
1408    }
1409}
1410
/// Input to `Composer::add_composable_module`.
#[derive(Default)]
pub struct ComposableModuleDescriptor<'a> {
    /// the module's shader source; rejected if it contains a reserved decoration string
    pub source: &'a str,
    /// path reported in errors raised while the definition is constructed, and stored on it
    pub file_path: &'a str,
    /// language the source is written in
    pub language: ShaderLanguage,
    /// name to register the module under; when `None`, the name found in the source's
    /// preprocessor metadata (its `#define_import_path`) is used
    pub as_name: Option<String>,
    /// imports added to the module in addition to those declared in its source
    pub additional_imports: &'a [ImportDefinition],
    /// shader defs for this module; the defs of every imported module are merged in when the
    /// module is added (an imported module's value replaces one given here for the same key)
    pub shader_defs: HashMap<String, ShaderDefValue>,
}
1420
/// Descriptor, presumably for composing a top-level shader into a final naga module.
/// NOTE(review): its consumer is not in this part of the file; the field notes below are
/// assumptions drawn from the parallel `ComposableModuleDescriptor` - confirm.
#[derive(Default)]
pub struct NagaModuleDescriptor<'a> {
    /// the shader source
    pub source: &'a str,
    /// presumably the path reported in errors - confirm
    pub file_path: &'a str,
    /// presumably selects the source language/stage - confirm against `ShaderType`
    pub shader_type: ShaderType,
    /// presumably the shader defs applied while composing - confirm
    pub shader_defs: HashMap<String, ShaderDefValue>,
    /// presumably imports added in addition to those declared in the source - confirm
    pub additional_imports: &'a [ImportDefinition],
}
1429
1430// public api
1431impl Composer {
1432    /// create a non-validating composer.
1433    /// validation errors in the final shader will not be caught, and errors resulting from their
1434    /// use will have bad span data, so codespan reporting will fail.
1435    /// use default() to create a validating composer.
1436    pub fn non_validating() -> Self {
1437        Self {
1438            validate: false,
1439            ..Default::default()
1440        }
1441    }
1442
1443    /// specify capabilities to be used for naga module generation.
1444    /// purges any existing modules
1445    /// See https://github.com/gfx-rs/wgpu/blob/d9c054c645af0ea9ef81617c3e762fbf0f3fecda/wgpu-core/src/device/mod.rs#L515
1446    /// for how to set the subgroup_stages value.
1447    pub fn with_capabilities(self, capabilities: naga::valid::Capabilities) -> Self {
1448        Self {
1449            capabilities,
1450            validate: self.validate,
1451            ..Default::default()
1452        }
1453    }
1454
1455    /// check if a module with the given name has been added
1456    pub fn contains_module(&self, module_name: &str) -> bool {
1457        self.module_sets.contains_key(module_name)
1458    }
1459
1460    /// add a composable module to the composer.
1461    /// all modules imported by this module must already have been added
1462    pub fn add_composable_module(
1463        &mut self,
1464        desc: ComposableModuleDescriptor,
1465    ) -> Result<&ComposableModuleDefinition, ComposerError> {
1466        let ComposableModuleDescriptor {
1467            source,
1468            file_path,
1469            language,
1470            as_name,
1471            additional_imports,
1472            mut shader_defs,
1473        } = desc;
1474
1475        // reject a module containing the DECORATION strings
1476        if let Some(decor) = self.check_decoration_regex.find(source) {
1477            return Err(ComposerError {
1478                inner: ComposerErrorInner::DecorationInSource(decor.range()),
1479                source: ErrSource::Constructing {
1480                    path: file_path.to_owned(),
1481                    source: source.to_owned(),
1482                    offset: 0,
1483                },
1484            });
1485        }
1486
1487        let substituted_source = self.sanitize_and_set_auto_bindings(source);
1488
1489        let PreprocessorMetaData {
1490            name: module_name,
1491            mut imports,
1492            mut effective_defs,
1493            ..
1494        } = self
1495            .preprocessor
1496            .get_preprocessor_metadata(&substituted_source, false)
1497            .map_err(|inner| ComposerError {
1498                inner,
1499                source: ErrSource::Constructing {
1500                    path: file_path.to_owned(),
1501                    source: source.to_owned(),
1502                    offset: 0,
1503                },
1504            })?;
1505        let module_name = as_name.or(module_name);
1506        if module_name.is_none() {
1507            return Err(ComposerError {
1508                inner: ComposerErrorInner::NoModuleName,
1509                source: ErrSource::Constructing {
1510                    path: file_path.to_owned(),
1511                    source: source.to_owned(),
1512                    offset: 0,
1513                },
1514            });
1515        }
1516        let module_name = module_name.unwrap();
1517
1518        debug!(
1519            "adding module definition for {} with defs: {:?}",
1520            module_name, shader_defs
1521        );
1522
1523        // add custom imports
1524        let additional_imports = additional_imports.to_vec();
1525        imports.extend(
1526            additional_imports
1527                .iter()
1528                .cloned()
1529                .map(|def| ImportDefWithOffset {
1530                    definition: def,
1531                    offset: 0,
1532                }),
1533        );
1534
1535        for import in &imports {
1536            // we require modules already added so that we can capture the shader_defs that may impact us by impacting our dependencies
1537            let module_set = self
1538                .module_sets
1539                .get(&import.definition.import)
1540                .ok_or_else(|| ComposerError {
1541                    inner: ComposerErrorInner::ImportNotFound(
1542                        import.definition.import.clone(),
1543                        import.offset,
1544                    ),
1545                    source: ErrSource::Constructing {
1546                        path: file_path.to_owned(),
1547                        source: substituted_source.to_owned(),
1548                        offset: 0,
1549                    },
1550                })?;
1551            effective_defs.extend(module_set.effective_defs.iter().cloned());
1552            shader_defs.extend(
1553                module_set
1554                    .shader_defs
1555                    .iter()
1556                    .map(|def| (def.0.clone(), *def.1)),
1557            );
1558        }
1559
1560        // remove defs that are already specified through our imports
1561        effective_defs.retain(|name| !shader_defs.contains_key(name));
1562
1563        // can't gracefully report errors for more modules. perhaps this should be a warning
1564        assert!((self.module_sets.len() as u32) < u32::MAX >> SPAN_SHIFT);
1565        let module_index = self.module_sets.len() + 1;
1566
1567        let module_set = ComposableModuleDefinition {
1568            name: module_name.clone(),
1569            sanitized_source: substituted_source,
1570            file_path: file_path.to_owned(),
1571            language,
1572            effective_defs: effective_defs.into_iter().collect(),
1573            all_imports: imports.into_iter().map(|id| id.definition.import).collect(),
1574            additional_imports,
1575            shader_defs,
1576            module_index,
1577            modules: Default::default(),
1578        };
1579
1580        // invalidate dependent modules if this module already exists
1581        self.remove_composable_module(&module_name);
1582
1583        self.module_sets.insert(module_name.clone(), module_set);
1584        self.module_index.insert(module_index, module_name.clone());
1585        Ok(self.module_sets.get(&module_name).unwrap())
1586    }
1587
1588    /// remove a composable module. also removes modules that depend on this module, as we cannot be sure about
1589    /// the completeness of their effective shader defs any more...
1590    pub fn remove_composable_module(&mut self, module_name: &str) {
1591        // todo this could be improved by making effective defs an Option<HashSet> and populating on demand?
1592        let mut dependent_sets = Vec::new();
1593
1594        if self.module_sets.remove(module_name).is_some() {
1595            dependent_sets.extend(self.module_sets.iter().filter_map(|(dependent_name, set)| {
1596                if set.all_imports.contains(module_name) {
1597                    Some(dependent_name.clone())
1598                } else {
1599                    None
1600                }
1601            }));
1602        }
1603
1604        for dependent_set in dependent_sets {
1605            self.remove_composable_module(&dependent_set);
1606        }
1607    }
1608
1609    /// build a naga shader module
1610    pub fn make_naga_module(
1611        &mut self,
1612        desc: NagaModuleDescriptor,
1613    ) -> Result<naga::Module, ComposerError> {
1614        let NagaModuleDescriptor {
1615            source,
1616            file_path,
1617            shader_type,
1618            mut shader_defs,
1619            additional_imports,
1620        } = desc;
1621
1622        let sanitized_source = self.sanitize_and_set_auto_bindings(source);
1623
1624        let PreprocessorMetaData { name, defines, .. } = self
1625            .preprocessor
1626            .get_preprocessor_metadata(&sanitized_source, true)
1627            .map_err(|inner| ComposerError {
1628                inner,
1629                source: ErrSource::Constructing {
1630                    path: file_path.to_owned(),
1631                    source: sanitized_source.to_owned(),
1632                    offset: 0,
1633                },
1634            })?;
1635        shader_defs.extend(defines);
1636
1637        let name = name.unwrap_or_default();
1638
1639        let PreprocessOutput { imports, .. } = self
1640            .preprocessor
1641            .preprocess(&sanitized_source, &shader_defs)
1642            .map_err(|inner| ComposerError {
1643                inner,
1644                source: ErrSource::Constructing {
1645                    path: file_path.to_owned(),
1646                    source: sanitized_source.to_owned(),
1647                    offset: 0,
1648                },
1649            })?;
1650
1651        // make sure imports have been added
1652        // and gather additional defs specified at module level
1653        for (import_name, offset) in imports
1654            .iter()
1655            .map(|id| (&id.definition.import, id.offset))
1656            .chain(additional_imports.iter().map(|ai| (&ai.import, 0)))
1657        {
1658            if let Some(module_set) = self.module_sets.get(import_name) {
1659                for (def, value) in &module_set.shader_defs {
1660                    if let Some(prior_value) = shader_defs.insert(def.clone(), *value) {
1661                        if prior_value != *value {
1662                            return Err(ComposerError {
1663                                inner: ComposerErrorInner::InconsistentShaderDefValue {
1664                                    def: def.clone(),
1665                                },
1666                                source: ErrSource::Constructing {
1667                                    path: file_path.to_owned(),
1668                                    source: sanitized_source.to_owned(),
1669                                    offset: 0,
1670                                },
1671                            });
1672                        }
1673                    }
1674                }
1675            } else {
1676                return Err(ComposerError {
1677                    inner: ComposerErrorInner::ImportNotFound(import_name.clone(), offset),
1678                    source: ErrSource::Constructing {
1679                        path: file_path.to_owned(),
1680                        source: sanitized_source.to_owned(),
1681                        offset: 0,
1682                    },
1683                });
1684            }
1685        }
1686        self.ensure_imports(
1687            imports.iter().map(|import| &import.definition),
1688            &shader_defs,
1689        )
1690        .map_err(|err| {
1691            err.into_composer_error(ErrSource::Constructing {
1692                path: file_path.to_owned(),
1693                source: sanitized_source.to_owned(),
1694                offset: 0,
1695            })
1696        })?;
1697        self.ensure_imports(additional_imports, &shader_defs)
1698            .map_err(|err| {
1699                err.into_composer_error(ErrSource::Constructing {
1700                    path: file_path.to_owned(),
1701                    source: sanitized_source.to_owned(),
1702                    offset: 0,
1703                })
1704            })?;
1705
1706        let definition = ComposableModuleDefinition {
1707            name,
1708            sanitized_source: sanitized_source.clone(),
1709            language: shader_type.into(),
1710            file_path: file_path.to_owned(),
1711            module_index: 0,
1712            additional_imports: additional_imports.to_vec(),
1713            // we don't care about these for creating a top-level module
1714            effective_defs: Default::default(),
1715            all_imports: Default::default(),
1716            shader_defs: Default::default(),
1717            modules: Default::default(),
1718        };
1719
1720        let PreprocessOutput {
1721            preprocessed_source,
1722            imports,
1723        } = self
1724            .preprocessor
1725            .preprocess(&sanitized_source, &shader_defs)
1726            .map_err(|inner| ComposerError {
1727                inner,
1728                source: ErrSource::Constructing {
1729                    path: file_path.to_owned(),
1730                    source: sanitized_source,
1731                    offset: 0,
1732                },
1733            })?;
1734
1735        let composable = self
1736            .create_composable_module(
1737                &definition,
1738                String::from(""),
1739                &shader_defs,
1740                false,
1741                false,
1742                &preprocessed_source,
1743                imports,
1744            )
1745            .map_err(|e| ComposerError {
1746                inner: e.inner,
1747                source: ErrSource::Constructing {
1748                    path: definition.file_path.to_owned(),
1749                    source: preprocessed_source.clone(),
1750                    offset: e.source.offset(),
1751                },
1752            })?;
1753
1754        let mut derived = DerivedModule::default();
1755
1756        let mut already_added = Default::default();
1757        for import in &composable.imports {
1758            self.add_import(
1759                &mut derived,
1760                import,
1761                &shader_defs,
1762                false,
1763                &mut already_added,
1764            );
1765        }
1766
1767        Self::add_composable_data(&mut derived, &composable, None, 0, false);
1768
1769        let stage = match shader_type {
1770            #[cfg(feature = "glsl")]
1771            ShaderType::GlslVertex => Some(naga::ShaderStage::Vertex),
1772            #[cfg(feature = "glsl")]
1773            ShaderType::GlslFragment => Some(naga::ShaderStage::Fragment),
1774            _ => None,
1775        };
1776
1777        let mut entry_points = Vec::default();
1778        derived.set_shader_source(&composable.module_ir, 0);
1779        for ep in &composable.module_ir.entry_points {
1780            let mapped_func = derived.localize_function(&ep.function);
1781            entry_points.push(EntryPoint {
1782                name: ep.name.clone(),
1783                function: mapped_func,
1784                stage: stage.unwrap_or(ep.stage),
1785                early_depth_test: ep.early_depth_test,
1786                workgroup_size: ep.workgroup_size,
1787            });
1788        }
1789
1790        let mut naga_module = naga::Module {
1791            entry_points,
1792            ..derived.into()
1793        };
1794
1795        // apply overrides
1796        if !composable.override_functions.is_empty() {
1797            let mut redirect = Redirector::new(naga_module);
1798
1799            for (base_function, overrides) in composable.override_functions {
1800                let mut omit = HashSet::default();
1801
1802                let mut original = base_function;
1803                for replacement in overrides {
1804                    let (_h_orig, _h_replace) = redirect
1805                        .redirect_function(&original, &replacement, &omit)
1806                        .map_err(|e| ComposerError {
1807                            inner: e.into(),
1808                            source: ErrSource::Constructing {
1809                                path: file_path.to_owned(),
1810                                source: preprocessed_source.clone(),
1811                                offset: composable.start_offset,
1812                            },
1813                        })?;
1814                    omit.insert(replacement.clone());
1815                    original = replacement;
1816                }
1817            }
1818
1819            naga_module = redirect.into_module().map_err(|e| ComposerError {
1820                inner: e.into(),
1821                source: ErrSource::Constructing {
1822                    path: file_path.to_owned(),
1823                    source: preprocessed_source.clone(),
1824                    offset: composable.start_offset,
1825                },
1826            })?;
1827        }
1828
1829        // validation
1830        if self.validate {
1831            let info = self.create_validator().validate(&naga_module);
1832            match info {
1833                Ok(_) => Ok(naga_module),
1834                Err(e) => {
1835                    let original_span = e.spans().last();
1836                    let err_source = match original_span.and_then(|(span, _)| span.to_range()) {
1837                        Some(rng) => {
1838                            let module_index = rng.start >> SPAN_SHIFT;
1839                            match module_index {
1840                                0 => ErrSource::Constructing {
1841                                    path: file_path.to_owned(),
1842                                    source: preprocessed_source.clone(),
1843                                    offset: composable.start_offset,
1844                                },
1845                                _ => {
1846                                    let module_name =
1847                                        self.module_index.get(&module_index).unwrap().clone();
1848                                    let offset = self
1849                                        .module_sets
1850                                        .get(&module_name)
1851                                        .unwrap()
1852                                        .get_module(&shader_defs)
1853                                        .unwrap()
1854                                        .start_offset;
1855                                    ErrSource::Module {
1856                                        name: module_name,
1857                                        offset,
1858                                        defs: shader_defs.clone(),
1859                                    }
1860                                }
1861                            }
1862                        }
1863                        None => ErrSource::Constructing {
1864                            path: file_path.to_owned(),
1865                            source: preprocessed_source.clone(),
1866                            offset: composable.start_offset,
1867                        },
1868                    };
1869
1870                    Err(ComposerError {
1871                        inner: ComposerErrorInner::ShaderValidationError(e),
1872                        source: err_source,
1873                    })
1874                }
1875            }
1876        } else {
1877            Ok(naga_module)
1878        }
1879    }
1880}
1881
1882static PREPROCESSOR: once_cell::sync::Lazy<Preprocessor> =
1883    once_cell::sync::Lazy::new(Preprocessor::default);
1884
1885/// Get module name and all required imports (ignoring shader_defs) from a shader string
1886pub fn get_preprocessor_data(
1887    source: &str,
1888) -> (
1889    Option<String>,
1890    Vec<ImportDefinition>,
1891    HashMap<String, ShaderDefValue>,
1892) {
1893    if let Ok(PreprocessorMetaData {
1894        name,
1895        imports,
1896        defines,
1897        ..
1898    }) = PREPROCESSOR.get_preprocessor_metadata(source, true)
1899    {
1900        (
1901            name,
1902            imports
1903                .into_iter()
1904                .map(|import_with_offset| import_with_offset.definition)
1905                .collect(),
1906            defines,
1907        )
1908    } else {
1909        // if errors occur we return nothing; the actual error will be displayed when the caller attempts to use the shader
1910        Default::default()
1911    }
1912}