naga_oil/redirect.rs

1use std::collections::{HashMap, HashSet};
2
3use naga::{Block, Expression, Function, Handle, Module, Statement};
4use thiserror::Error;
5
6use crate::derive::DerivedModule;
7
8#[derive(Debug, Error)]
9pub enum RedirectError {
10    #[error("can't find function {0} for redirection")]
11    FunctionNotFound(String),
12    #[error("{0} cannot override {1} due to argument mismatch")]
13    ArgumentMismatch(String, String),
14    #[error("{0} cannot override {1} due to return type mismatch")]
15    ReturnTypeMismatch(String, String),
16    #[error("circular reference; can't find an order for : {0}")]
17    CircularReference(String),
18}
19
20pub struct Redirector {
21    module: Module,
22}
23
24impl Redirector {
25    pub fn new(module: Module) -> Self {
26        Self { module }
27    }
28
29    fn redirect_block(block: &mut Block, original: Handle<Function>, new: Handle<Function>) {
30        for stmt in block.iter_mut() {
31            match stmt {
32                Statement::Call {
33                    ref mut function, ..
34                } => {
35                    if *function == original {
36                        *function = new;
37                    }
38                }
39                Statement::Block(b) => Self::redirect_block(b, original, new),
40                Statement::If {
41                    condition: _,
42                    accept,
43                    reject,
44                } => {
45                    Self::redirect_block(accept, original, new);
46                    Self::redirect_block(reject, original, new);
47                }
48                Statement::Switch { selector: _, cases } => {
49                    for case in cases.iter_mut() {
50                        Self::redirect_block(&mut case.body, original, new);
51                    }
52                }
53                Statement::Loop {
54                    body,
55                    continuing,
56                    break_if: _,
57                } => {
58                    Self::redirect_block(body, original, new);
59                    Self::redirect_block(continuing, original, new);
60                }
61                Statement::Emit(_)
62                | Statement::Break
63                | Statement::Continue
64                | Statement::Return { .. }
65                | Statement::WorkGroupUniformLoad { .. }
66                | Statement::Kill
67                | Statement::MemoryBarrier(_)
68                | Statement::ControlBarrier(_)
69                | Statement::Store { .. }
70                | Statement::ImageStore { .. }
71                | Statement::Atomic { .. }
72                | Statement::RayQuery { .. }
73                | Statement::SubgroupBallot { .. }
74                | Statement::SubgroupGather { .. }
75                | Statement::SubgroupCollectiveOperation { .. }
76                | Statement::ImageAtomic { .. } => (),
77            }
78        }
79    }
80
81    fn redirect_expr(expr: &mut Expression, original: Handle<Function>, new: Handle<Function>) {
82        if let Expression::CallResult(f) = expr {
83            if f == &original {
84                *expr = Expression::CallResult(new);
85            }
86        }
87    }
88
89    fn redirect_fn(func: &mut Function, original: Handle<Function>, new: Handle<Function>) {
90        Self::redirect_block(&mut func.body, original, new);
91        for (_, expr) in func.expressions.iter_mut() {
92            Self::redirect_expr(expr, original, new);
93        }
94    }
95
96    /// redirect all calls to the function named `original` with references to the function named `replacement`, except within the replacement function
97    /// or in any function contained in the `omit` set.
98    /// returns handles to the original and replacement functions.
99    /// NB: requires the replacement to be defined in the arena before any calls to the original, or validation will fail.
100    pub fn redirect_function(
101        &mut self,
102        original: &str,
103        replacement: &str,
104        omit: &HashSet<String>,
105    ) -> Result<(Handle<Function>, Handle<Function>), RedirectError> {
106        let (h_original, f_original) = self
107            .module
108            .functions
109            .iter()
110            .find(|(_, f)| f.name.as_deref() == Some(original))
111            .ok_or_else(|| RedirectError::FunctionNotFound(original.to_owned()))?;
112        let (h_replacement, f_replacement) = self
113            .module
114            .functions
115            .iter()
116            .find(|(_, f)| f.name.as_deref() == Some(replacement))
117            .ok_or_else(|| RedirectError::FunctionNotFound(replacement.to_owned()))?;
118
119        for (arg1, arg2) in f_original
120            .arguments
121            .iter()
122            .zip(f_replacement.arguments.iter())
123        {
124            if arg1.ty != arg2.ty {
125                return Err(RedirectError::ArgumentMismatch(
126                    original.to_owned(),
127                    replacement.to_owned(),
128                ));
129            }
130        }
131
132        if f_original.result.as_ref().map(|r| r.ty) != f_replacement.result.as_ref().map(|r| r.ty) {
133            return Err(RedirectError::ReturnTypeMismatch(
134                original.to_owned(),
135                replacement.to_owned(),
136            ));
137        }
138
139        for (h_f, f) in self.module.functions.iter_mut() {
140            if h_f != h_replacement && !omit.contains(f.name.as_ref().unwrap()) {
141                Self::redirect_fn(f, h_original, h_replacement);
142            }
143        }
144
145        for ep in &mut self.module.entry_points {
146            Self::redirect_fn(&mut ep.function, h_original, h_replacement);
147        }
148
149        Ok((h_original, h_replacement))
150    }
151
152    fn gather_requirements(block: &Block) -> HashSet<Handle<Function>> {
153        let mut requirements = HashSet::default();
154
155        for stmt in block.iter() {
156            match stmt {
157                Statement::Block(b) => requirements.extend(Self::gather_requirements(b)),
158                Statement::If { accept, reject, .. } => {
159                    requirements.extend(Self::gather_requirements(accept));
160                    requirements.extend(Self::gather_requirements(reject));
161                }
162                Statement::Switch { cases, .. } => {
163                    for case in cases {
164                        requirements.extend(Self::gather_requirements(&case.body));
165                    }
166                }
167                Statement::Loop {
168                    body, continuing, ..
169                } => {
170                    requirements.extend(Self::gather_requirements(body));
171                    requirements.extend(Self::gather_requirements(continuing));
172                }
173                Statement::Call { function, .. } => {
174                    requirements.insert(*function);
175                }
176                _ => (),
177            }
178        }
179
180        requirements
181    }
182
183    pub fn into_module(self) -> Result<naga::Module, RedirectError> {
184        // reorder functions so that dependents come first
185        let mut requirements: HashMap<_, _> = self
186            .module
187            .functions
188            .iter()
189            .map(|(h_f, f)| (h_f, Self::gather_requirements(&f.body)))
190            .collect();
191
192        let mut derived = DerivedModule::default();
193        derived.set_shader_source(&self.module, 0);
194
195        while !requirements.is_empty() {
196            let start_len = requirements.len();
197
198            let mut added: HashSet<Handle<Function>> = HashSet::new();
199
200            // add anything that has all requirements satisfied
201            requirements.retain(|h_f, reqs| {
202                if reqs.is_empty() {
203                    let func = self.module.functions.try_get(*h_f).unwrap();
204                    let span = self.module.functions.get_span(*h_f);
205                    derived.import_function(func, span);
206                    added.insert(*h_f);
207                    false
208                } else {
209                    true
210                }
211            });
212
213            // remove things we added from requirements
214            for reqs in requirements.values_mut() {
215                reqs.retain(|req| !added.contains(req));
216            }
217
218            if requirements.len() == start_len {
219                return Err(RedirectError::CircularReference(format!(
220                    "{:#?}",
221                    requirements.keys()
222                )));
223            }
224        }
225
226        Ok(derived.into_module_with_entrypoints())
227    }
228}
229
230impl TryFrom<Redirector> for naga::Module {
231    type Error = RedirectError;
232
233    fn try_from(redirector: Redirector) -> Result<Self, Self::Error> {
234        redirector.into_module()
235    }
236}