naga_oil/
redirect.rs

1use std::collections::{HashMap, HashSet};
2
3use naga::{Block, Expression, Function, Handle, Module, Statement};
4use thiserror::Error;
5
6use crate::derive::DerivedModule;
7
8#[derive(Debug, Error)]
9pub enum RedirectError {
10    #[error("can't find function {0} for redirection")]
11    FunctionNotFound(String),
12    #[error("{0} cannot override {1} due to argument mismatch")]
13    ArgumentMismatch(String, String),
14    #[error("{0} cannot override {1} due to return type mismatch")]
15    ReturnTypeMismatch(String, String),
16    #[error("circular reference; can't find an order for : {0}")]
17    CircularReference(String),
18}
19
20pub struct Redirector {
21    module: Module,
22}
23
24impl Redirector {
25    pub fn new(module: Module) -> Self {
26        Self { module }
27    }
28
29    fn redirect_block(block: &mut Block, original: Handle<Function>, new: Handle<Function>) {
30        for stmt in block.iter_mut() {
31            match stmt {
32                Statement::Call {
33                    ref mut function, ..
34                } => {
35                    if *function == original {
36                        *function = new;
37                    }
38                }
39                Statement::Block(b) => Self::redirect_block(b, original, new),
40                Statement::If {
41                    condition: _,
42                    accept,
43                    reject,
44                } => {
45                    Self::redirect_block(accept, original, new);
46                    Self::redirect_block(reject, original, new);
47                }
48                Statement::Switch { selector: _, cases } => {
49                    for case in cases.iter_mut() {
50                        Self::redirect_block(&mut case.body, original, new);
51                    }
52                }
53                Statement::Loop {
54                    body,
55                    continuing,
56                    break_if: _,
57                } => {
58                    Self::redirect_block(body, original, new);
59                    Self::redirect_block(continuing, original, new);
60                }
61                Statement::Emit(_)
62                | Statement::Break
63                | Statement::Continue
64                | Statement::Return { .. }
65                | Statement::WorkGroupUniformLoad { .. }
66                | Statement::Kill
67                | Statement::Barrier(_)
68                | Statement::Store { .. }
69                | Statement::ImageStore { .. }
70                | Statement::Atomic { .. }
71                | Statement::RayQuery { .. }
72                | Statement::SubgroupBallot { .. }
73                | Statement::SubgroupGather { .. }
74                | Statement::SubgroupCollectiveOperation { .. }
75                | Statement::ImageAtomic { .. } => (),
76            }
77        }
78    }
79
80    fn redirect_expr(expr: &mut Expression, original: Handle<Function>, new: Handle<Function>) {
81        if let Expression::CallResult(f) = expr {
82            if f == &original {
83                *expr = Expression::CallResult(new);
84            }
85        }
86    }
87
88    fn redirect_fn(func: &mut Function, original: Handle<Function>, new: Handle<Function>) {
89        Self::redirect_block(&mut func.body, original, new);
90        for (_, expr) in func.expressions.iter_mut() {
91            Self::redirect_expr(expr, original, new);
92        }
93    }
94
95    /// redirect all calls to the function named `original` with references to the function named `replacement`, except within the replacement function
96    /// or in any function contained in the `omit` set.
97    /// returns handles to the original and replacement functions.
98    /// NB: requires the replacement to be defined in the arena before any calls to the original, or validation will fail.
99    pub fn redirect_function(
100        &mut self,
101        original: &str,
102        replacement: &str,
103        omit: &HashSet<String>,
104    ) -> Result<(Handle<Function>, Handle<Function>), RedirectError> {
105        let (h_original, f_original) = self
106            .module
107            .functions
108            .iter()
109            .find(|(_, f)| f.name.as_deref() == Some(original))
110            .ok_or_else(|| RedirectError::FunctionNotFound(original.to_owned()))?;
111        let (h_replacement, f_replacement) = self
112            .module
113            .functions
114            .iter()
115            .find(|(_, f)| f.name.as_deref() == Some(replacement))
116            .ok_or_else(|| RedirectError::FunctionNotFound(replacement.to_owned()))?;
117
118        for (arg1, arg2) in f_original
119            .arguments
120            .iter()
121            .zip(f_replacement.arguments.iter())
122        {
123            if arg1.ty != arg2.ty {
124                return Err(RedirectError::ArgumentMismatch(
125                    original.to_owned(),
126                    replacement.to_owned(),
127                ));
128            }
129        }
130
131        if f_original.result.as_ref().map(|r| r.ty) != f_replacement.result.as_ref().map(|r| r.ty) {
132            return Err(RedirectError::ReturnTypeMismatch(
133                original.to_owned(),
134                replacement.to_owned(),
135            ));
136        }
137
138        for (h_f, f) in self.module.functions.iter_mut() {
139            if h_f != h_replacement && !omit.contains(f.name.as_ref().unwrap()) {
140                Self::redirect_fn(f, h_original, h_replacement);
141            }
142        }
143
144        for ep in &mut self.module.entry_points {
145            Self::redirect_fn(&mut ep.function, h_original, h_replacement);
146        }
147
148        Ok((h_original, h_replacement))
149    }
150
151    fn gather_requirements(block: &Block) -> HashSet<Handle<Function>> {
152        let mut requirements = HashSet::default();
153
154        for stmt in block.iter() {
155            match stmt {
156                Statement::Block(b) => requirements.extend(Self::gather_requirements(b)),
157                Statement::If { accept, reject, .. } => {
158                    requirements.extend(Self::gather_requirements(accept));
159                    requirements.extend(Self::gather_requirements(reject));
160                }
161                Statement::Switch { cases, .. } => {
162                    for case in cases {
163                        requirements.extend(Self::gather_requirements(&case.body));
164                    }
165                }
166                Statement::Loop {
167                    body, continuing, ..
168                } => {
169                    requirements.extend(Self::gather_requirements(body));
170                    requirements.extend(Self::gather_requirements(continuing));
171                }
172                Statement::Call { function, .. } => {
173                    requirements.insert(*function);
174                }
175                _ => (),
176            }
177        }
178
179        requirements
180    }
181
182    pub fn into_module(self) -> Result<naga::Module, RedirectError> {
183        // reorder functions so that dependents come first
184        let mut requirements: HashMap<_, _> = self
185            .module
186            .functions
187            .iter()
188            .map(|(h_f, f)| (h_f, Self::gather_requirements(&f.body)))
189            .collect();
190
191        let mut derived = DerivedModule::default();
192        derived.set_shader_source(&self.module, 0);
193
194        while !requirements.is_empty() {
195            let start_len = requirements.len();
196
197            let mut added: HashSet<Handle<Function>> = HashSet::new();
198
199            // add anything that has all requirements satisfied
200            requirements.retain(|h_f, reqs| {
201                if reqs.is_empty() {
202                    let func = self.module.functions.try_get(*h_f).unwrap();
203                    let span = self.module.functions.get_span(*h_f);
204                    derived.import_function(func, span);
205                    added.insert(*h_f);
206                    false
207                } else {
208                    true
209                }
210            });
211
212            // remove things we added from requirements
213            for reqs in requirements.values_mut() {
214                reqs.retain(|req| !added.contains(req));
215            }
216
217            if requirements.len() == start_len {
218                return Err(RedirectError::CircularReference(format!(
219                    "{:#?}",
220                    requirements.keys()
221                )));
222            }
223        }
224
225        Ok(derived.into_module_with_entrypoints())
226    }
227}
228
229impl TryFrom<Redirector> for naga::Module {
230    type Error = RedirectError;
231
232    fn try_from(redirector: Redirector) -> Result<Self, Self::Error> {
233        redirector.into_module()
234    }
235}