naga_oil/
redirect.rs

1use std::collections::{HashMap, HashSet};
2
3use naga::{Block, Expression, Function, Handle, Module, Statement};
4use thiserror::Error;
5
6use crate::derive::DerivedModule;
7
8#[derive(Debug, Error)]
9pub enum RedirectError {
10    #[error("can't find function {0} for redirection")]
11    FunctionNotFound(String),
12    #[error("{0} cannot override {1} due to argument mismatch")]
13    ArgumentMismatch(String, String),
14    #[error("{0} cannot override {1} due to return type mismatch")]
15    ReturnTypeMismatch(String, String),
16    #[error("circular reference; can't find an order for : {0}")]
17    CircularReference(String),
18}
19
20pub struct Redirector {
21    module: Module,
22}
23
24impl Redirector {
25    pub fn new(module: Module) -> Self {
26        Self { module }
27    }
28
29    fn redirect_block(block: &mut Block, original: Handle<Function>, new: Handle<Function>) {
30        for stmt in block.iter_mut() {
31            match stmt {
32                Statement::Call {
33                    ref mut function, ..
34                } => {
35                    if *function == original {
36                        *function = new;
37                    }
38                }
39                Statement::Block(b) => Self::redirect_block(b, original, new),
40                Statement::If {
41                    condition: _,
42                    accept,
43                    reject,
44                } => {
45                    Self::redirect_block(accept, original, new);
46                    Self::redirect_block(reject, original, new);
47                }
48                Statement::Switch { selector: _, cases } => {
49                    for case in cases.iter_mut() {
50                        Self::redirect_block(&mut case.body, original, new);
51                    }
52                }
53                Statement::Loop {
54                    body,
55                    continuing,
56                    break_if: _,
57                } => {
58                    Self::redirect_block(body, original, new);
59                    Self::redirect_block(continuing, original, new);
60                }
61                Statement::Emit(_)
62                | Statement::Break
63                | Statement::Continue
64                | Statement::Return { .. }
65                | Statement::WorkGroupUniformLoad { .. }
66                | Statement::Kill
67                | Statement::Barrier(_)
68                | Statement::Store { .. }
69                | Statement::ImageStore { .. }
70                | Statement::Atomic { .. }
71                | Statement::RayQuery { .. }
72                | Statement::SubgroupBallot { .. }
73                | Statement::SubgroupGather { .. }
74                | Statement::SubgroupCollectiveOperation { .. } => (),
75            }
76        }
77    }
78
79    fn redirect_expr(expr: &mut Expression, original: Handle<Function>, new: Handle<Function>) {
80        if let Expression::CallResult(f) = expr {
81            if f == &original {
82                *expr = Expression::CallResult(new);
83            }
84        }
85    }
86
87    fn redirect_fn(func: &mut Function, original: Handle<Function>, new: Handle<Function>) {
88        Self::redirect_block(&mut func.body, original, new);
89        for (_, expr) in func.expressions.iter_mut() {
90            Self::redirect_expr(expr, original, new);
91        }
92    }
93
94    /// redirect all calls to the function named `original` with references to the function named `replacement`, except within the replacement function
95    /// or in any function contained in the `omit` set.
96    /// returns handles to the original and replacement functions.
97    /// NB: requires the replacement to be defined in the arena before any calls to the original, or validation will fail.
98    pub fn redirect_function(
99        &mut self,
100        original: &str,
101        replacement: &str,
102        omit: &HashSet<String>,
103    ) -> Result<(Handle<Function>, Handle<Function>), RedirectError> {
104        let (h_original, f_original) = self
105            .module
106            .functions
107            .iter()
108            .find(|(_, f)| f.name.as_deref() == Some(original))
109            .ok_or_else(|| RedirectError::FunctionNotFound(original.to_owned()))?;
110        let (h_replacement, f_replacement) = self
111            .module
112            .functions
113            .iter()
114            .find(|(_, f)| f.name.as_deref() == Some(replacement))
115            .ok_or_else(|| RedirectError::FunctionNotFound(replacement.to_owned()))?;
116
117        for (arg1, arg2) in f_original
118            .arguments
119            .iter()
120            .zip(f_replacement.arguments.iter())
121        {
122            if arg1.ty != arg2.ty {
123                return Err(RedirectError::ArgumentMismatch(
124                    original.to_owned(),
125                    replacement.to_owned(),
126                ));
127            }
128        }
129
130        if f_original.result.as_ref().map(|r| r.ty) != f_replacement.result.as_ref().map(|r| r.ty) {
131            return Err(RedirectError::ReturnTypeMismatch(
132                original.to_owned(),
133                replacement.to_owned(),
134            ));
135        }
136
137        for (h_f, f) in self.module.functions.iter_mut() {
138            if h_f != h_replacement && !omit.contains(f.name.as_ref().unwrap()) {
139                Self::redirect_fn(f, h_original, h_replacement);
140            }
141        }
142
143        for ep in &mut self.module.entry_points {
144            Self::redirect_fn(&mut ep.function, h_original, h_replacement);
145        }
146
147        Ok((h_original, h_replacement))
148    }
149
150    fn gather_requirements(block: &Block) -> HashSet<Handle<Function>> {
151        let mut requirements = HashSet::default();
152
153        for stmt in block.iter() {
154            match stmt {
155                Statement::Block(b) => requirements.extend(Self::gather_requirements(b)),
156                Statement::If { accept, reject, .. } => {
157                    requirements.extend(Self::gather_requirements(accept));
158                    requirements.extend(Self::gather_requirements(reject));
159                }
160                Statement::Switch { cases, .. } => {
161                    for case in cases {
162                        requirements.extend(Self::gather_requirements(&case.body));
163                    }
164                }
165                Statement::Loop {
166                    body, continuing, ..
167                } => {
168                    requirements.extend(Self::gather_requirements(body));
169                    requirements.extend(Self::gather_requirements(continuing));
170                }
171                Statement::Call { function, .. } => {
172                    requirements.insert(*function);
173                }
174                _ => (),
175            }
176        }
177
178        requirements
179    }
180
181    pub fn into_module(self) -> Result<naga::Module, RedirectError> {
182        // reorder functions so that dependents come first
183        let mut requirements: HashMap<_, _> = self
184            .module
185            .functions
186            .iter()
187            .map(|(h_f, f)| (h_f, Self::gather_requirements(&f.body)))
188            .collect();
189
190        let mut derived = DerivedModule::default();
191        derived.set_shader_source(&self.module, 0);
192
193        while !requirements.is_empty() {
194            let start_len = requirements.len();
195
196            let mut added: HashSet<Handle<Function>> = HashSet::new();
197
198            // add anything that has all requirements satisfied
199            requirements.retain(|h_f, reqs| {
200                if reqs.is_empty() {
201                    let func = self.module.functions.try_get(*h_f).unwrap();
202                    let span = self.module.functions.get_span(*h_f);
203                    derived.import_function(func, span);
204                    added.insert(*h_f);
205                    false
206                } else {
207                    true
208                }
209            });
210
211            // remove things we added from requirements
212            for reqs in requirements.values_mut() {
213                reqs.retain(|req| !added.contains(req));
214            }
215
216            if requirements.len() == start_len {
217                return Err(RedirectError::CircularReference(format!(
218                    "{:#?}",
219                    requirements.keys()
220                )));
221            }
222        }
223
224        Ok(derived.into_module_with_entrypoints())
225    }
226}
227
228impl TryFrom<Redirector> for naga::Module {
229    type Error = RedirectError;
230
231    fn try_from(redirector: Redirector) -> Result<Self, Self::Error> {
232        redirector.into_module()
233    }
234}