1use std::collections::{HashMap, HashSet};
2
3use naga::{Block, Expression, Function, Handle, Module, Statement};
4use thiserror::Error;
5
6use crate::derive::DerivedModule;
7
8#[derive(Debug, Error)]
9pub enum RedirectError {
10 #[error("can't find function {0} for redirection")]
11 FunctionNotFound(String),
12 #[error("{0} cannot override {1} due to argument mismatch")]
13 ArgumentMismatch(String, String),
14 #[error("{0} cannot override {1} due to return type mismatch")]
15 ReturnTypeMismatch(String, String),
16 #[error("circular reference; can't find an order for : {0}")]
17 CircularReference(String),
18}
19
20pub struct Redirector {
21 module: Module,
22}
23
24impl Redirector {
25 pub fn new(module: Module) -> Self {
26 Self { module }
27 }
28
29 fn redirect_block(block: &mut Block, original: Handle<Function>, new: Handle<Function>) {
30 for stmt in block.iter_mut() {
31 match stmt {
32 Statement::Call {
33 ref mut function, ..
34 } => {
35 if *function == original {
36 *function = new;
37 }
38 }
39 Statement::Block(b) => Self::redirect_block(b, original, new),
40 Statement::If {
41 condition: _,
42 accept,
43 reject,
44 } => {
45 Self::redirect_block(accept, original, new);
46 Self::redirect_block(reject, original, new);
47 }
48 Statement::Switch { selector: _, cases } => {
49 for case in cases.iter_mut() {
50 Self::redirect_block(&mut case.body, original, new);
51 }
52 }
53 Statement::Loop {
54 body,
55 continuing,
56 break_if: _,
57 } => {
58 Self::redirect_block(body, original, new);
59 Self::redirect_block(continuing, original, new);
60 }
61 Statement::Emit(_)
62 | Statement::Break
63 | Statement::Continue
64 | Statement::Return { .. }
65 | Statement::WorkGroupUniformLoad { .. }
66 | Statement::Kill
67 | Statement::MemoryBarrier(_)
68 | Statement::ControlBarrier(_)
69 | Statement::Store { .. }
70 | Statement::ImageStore { .. }
71 | Statement::Atomic { .. }
72 | Statement::RayQuery { .. }
73 | Statement::SubgroupBallot { .. }
74 | Statement::SubgroupGather { .. }
75 | Statement::SubgroupCollectiveOperation { .. }
76 | Statement::ImageAtomic { .. } => (),
77 }
78 }
79 }
80
81 fn redirect_expr(expr: &mut Expression, original: Handle<Function>, new: Handle<Function>) {
82 if let Expression::CallResult(f) = expr {
83 if f == &original {
84 *expr = Expression::CallResult(new);
85 }
86 }
87 }
88
89 fn redirect_fn(func: &mut Function, original: Handle<Function>, new: Handle<Function>) {
90 Self::redirect_block(&mut func.body, original, new);
91 for (_, expr) in func.expressions.iter_mut() {
92 Self::redirect_expr(expr, original, new);
93 }
94 }
95
96 pub fn redirect_function(
101 &mut self,
102 original: &str,
103 replacement: &str,
104 omit: &HashSet<String>,
105 ) -> Result<(Handle<Function>, Handle<Function>), RedirectError> {
106 let (h_original, f_original) = self
107 .module
108 .functions
109 .iter()
110 .find(|(_, f)| f.name.as_deref() == Some(original))
111 .ok_or_else(|| RedirectError::FunctionNotFound(original.to_owned()))?;
112 let (h_replacement, f_replacement) = self
113 .module
114 .functions
115 .iter()
116 .find(|(_, f)| f.name.as_deref() == Some(replacement))
117 .ok_or_else(|| RedirectError::FunctionNotFound(replacement.to_owned()))?;
118
119 for (arg1, arg2) in f_original
120 .arguments
121 .iter()
122 .zip(f_replacement.arguments.iter())
123 {
124 if arg1.ty != arg2.ty {
125 return Err(RedirectError::ArgumentMismatch(
126 original.to_owned(),
127 replacement.to_owned(),
128 ));
129 }
130 }
131
132 if f_original.result.as_ref().map(|r| r.ty) != f_replacement.result.as_ref().map(|r| r.ty) {
133 return Err(RedirectError::ReturnTypeMismatch(
134 original.to_owned(),
135 replacement.to_owned(),
136 ));
137 }
138
139 for (h_f, f) in self.module.functions.iter_mut() {
140 if h_f != h_replacement && !omit.contains(f.name.as_ref().unwrap()) {
141 Self::redirect_fn(f, h_original, h_replacement);
142 }
143 }
144
145 for ep in &mut self.module.entry_points {
146 Self::redirect_fn(&mut ep.function, h_original, h_replacement);
147 }
148
149 Ok((h_original, h_replacement))
150 }
151
152 fn gather_requirements(block: &Block) -> HashSet<Handle<Function>> {
153 let mut requirements = HashSet::default();
154
155 for stmt in block.iter() {
156 match stmt {
157 Statement::Block(b) => requirements.extend(Self::gather_requirements(b)),
158 Statement::If { accept, reject, .. } => {
159 requirements.extend(Self::gather_requirements(accept));
160 requirements.extend(Self::gather_requirements(reject));
161 }
162 Statement::Switch { cases, .. } => {
163 for case in cases {
164 requirements.extend(Self::gather_requirements(&case.body));
165 }
166 }
167 Statement::Loop {
168 body, continuing, ..
169 } => {
170 requirements.extend(Self::gather_requirements(body));
171 requirements.extend(Self::gather_requirements(continuing));
172 }
173 Statement::Call { function, .. } => {
174 requirements.insert(*function);
175 }
176 _ => (),
177 }
178 }
179
180 requirements
181 }
182
183 pub fn into_module(self) -> Result<naga::Module, RedirectError> {
184 let mut requirements: HashMap<_, _> = self
186 .module
187 .functions
188 .iter()
189 .map(|(h_f, f)| (h_f, Self::gather_requirements(&f.body)))
190 .collect();
191
192 let mut derived = DerivedModule::default();
193 derived.set_shader_source(&self.module, 0);
194
195 while !requirements.is_empty() {
196 let start_len = requirements.len();
197
198 let mut added: HashSet<Handle<Function>> = HashSet::new();
199
200 requirements.retain(|h_f, reqs| {
202 if reqs.is_empty() {
203 let func = self.module.functions.try_get(*h_f).unwrap();
204 let span = self.module.functions.get_span(*h_f);
205 derived.import_function(func, span);
206 added.insert(*h_f);
207 false
208 } else {
209 true
210 }
211 });
212
213 for reqs in requirements.values_mut() {
215 reqs.retain(|req| !added.contains(req));
216 }
217
218 if requirements.len() == start_len {
219 return Err(RedirectError::CircularReference(format!(
220 "{:#?}",
221 requirements.keys()
222 )));
223 }
224 }
225
226 Ok(derived.into_module_with_entrypoints())
227 }
228}
229
230impl TryFrom<Redirector> for naga::Module {
231 type Error = RedirectError;
232
233 fn try_from(redirector: Redirector) -> Result<Self, Self::Error> {
234 redirector.into_module()
235 }
236}