1use std::collections::{HashMap, HashSet};
2
3use naga::{Block, Expression, Function, Handle, Module, Statement};
4use thiserror::Error;
5
6use crate::derive::DerivedModule;
7
8#[derive(Debug, Error)]
9pub enum RedirectError {
10 #[error("can't find function {0} for redirection")]
11 FunctionNotFound(String),
12 #[error("{0} cannot override {1} due to argument mismatch")]
13 ArgumentMismatch(String, String),
14 #[error("{0} cannot override {1} due to return type mismatch")]
15 ReturnTypeMismatch(String, String),
16 #[error("circular reference; can't find an order for : {0}")]
17 CircularReference(String),
18}
19
20pub struct Redirector {
21 module: Module,
22}
23
24impl Redirector {
25 pub fn new(module: Module) -> Self {
26 Self { module }
27 }
28
29 fn redirect_block(block: &mut Block, original: Handle<Function>, new: Handle<Function>) {
30 for stmt in block.iter_mut() {
31 match stmt {
32 Statement::Call {
33 ref mut function, ..
34 } => {
35 if *function == original {
36 *function = new;
37 }
38 }
39 Statement::Block(b) => Self::redirect_block(b, original, new),
40 Statement::If {
41 condition: _,
42 accept,
43 reject,
44 } => {
45 Self::redirect_block(accept, original, new);
46 Self::redirect_block(reject, original, new);
47 }
48 Statement::Switch { selector: _, cases } => {
49 for case in cases.iter_mut() {
50 Self::redirect_block(&mut case.body, original, new);
51 }
52 }
53 Statement::Loop {
54 body,
55 continuing,
56 break_if: _,
57 } => {
58 Self::redirect_block(body, original, new);
59 Self::redirect_block(continuing, original, new);
60 }
61 Statement::Emit(_)
62 | Statement::Break
63 | Statement::Continue
64 | Statement::Return { .. }
65 | Statement::WorkGroupUniformLoad { .. }
66 | Statement::Kill
67 | Statement::Barrier(_)
68 | Statement::Store { .. }
69 | Statement::ImageStore { .. }
70 | Statement::Atomic { .. }
71 | Statement::RayQuery { .. }
72 | Statement::SubgroupBallot { .. }
73 | Statement::SubgroupGather { .. }
74 | Statement::SubgroupCollectiveOperation { .. } => (),
75 }
76 }
77 }
78
79 fn redirect_expr(expr: &mut Expression, original: Handle<Function>, new: Handle<Function>) {
80 if let Expression::CallResult(f) = expr {
81 if f == &original {
82 *expr = Expression::CallResult(new);
83 }
84 }
85 }
86
87 fn redirect_fn(func: &mut Function, original: Handle<Function>, new: Handle<Function>) {
88 Self::redirect_block(&mut func.body, original, new);
89 for (_, expr) in func.expressions.iter_mut() {
90 Self::redirect_expr(expr, original, new);
91 }
92 }
93
94 pub fn redirect_function(
99 &mut self,
100 original: &str,
101 replacement: &str,
102 omit: &HashSet<String>,
103 ) -> Result<(Handle<Function>, Handle<Function>), RedirectError> {
104 let (h_original, f_original) = self
105 .module
106 .functions
107 .iter()
108 .find(|(_, f)| f.name.as_deref() == Some(original))
109 .ok_or_else(|| RedirectError::FunctionNotFound(original.to_owned()))?;
110 let (h_replacement, f_replacement) = self
111 .module
112 .functions
113 .iter()
114 .find(|(_, f)| f.name.as_deref() == Some(replacement))
115 .ok_or_else(|| RedirectError::FunctionNotFound(replacement.to_owned()))?;
116
117 for (arg1, arg2) in f_original
118 .arguments
119 .iter()
120 .zip(f_replacement.arguments.iter())
121 {
122 if arg1.ty != arg2.ty {
123 return Err(RedirectError::ArgumentMismatch(
124 original.to_owned(),
125 replacement.to_owned(),
126 ));
127 }
128 }
129
130 if f_original.result.as_ref().map(|r| r.ty) != f_replacement.result.as_ref().map(|r| r.ty) {
131 return Err(RedirectError::ReturnTypeMismatch(
132 original.to_owned(),
133 replacement.to_owned(),
134 ));
135 }
136
137 for (h_f, f) in self.module.functions.iter_mut() {
138 if h_f != h_replacement && !omit.contains(f.name.as_ref().unwrap()) {
139 Self::redirect_fn(f, h_original, h_replacement);
140 }
141 }
142
143 for ep in &mut self.module.entry_points {
144 Self::redirect_fn(&mut ep.function, h_original, h_replacement);
145 }
146
147 Ok((h_original, h_replacement))
148 }
149
150 fn gather_requirements(block: &Block) -> HashSet<Handle<Function>> {
151 let mut requirements = HashSet::default();
152
153 for stmt in block.iter() {
154 match stmt {
155 Statement::Block(b) => requirements.extend(Self::gather_requirements(b)),
156 Statement::If { accept, reject, .. } => {
157 requirements.extend(Self::gather_requirements(accept));
158 requirements.extend(Self::gather_requirements(reject));
159 }
160 Statement::Switch { cases, .. } => {
161 for case in cases {
162 requirements.extend(Self::gather_requirements(&case.body));
163 }
164 }
165 Statement::Loop {
166 body, continuing, ..
167 } => {
168 requirements.extend(Self::gather_requirements(body));
169 requirements.extend(Self::gather_requirements(continuing));
170 }
171 Statement::Call { function, .. } => {
172 requirements.insert(*function);
173 }
174 _ => (),
175 }
176 }
177
178 requirements
179 }
180
181 pub fn into_module(self) -> Result<naga::Module, RedirectError> {
182 let mut requirements: HashMap<_, _> = self
184 .module
185 .functions
186 .iter()
187 .map(|(h_f, f)| (h_f, Self::gather_requirements(&f.body)))
188 .collect();
189
190 let mut derived = DerivedModule::default();
191 derived.set_shader_source(&self.module, 0);
192
193 while !requirements.is_empty() {
194 let start_len = requirements.len();
195
196 let mut added: HashSet<Handle<Function>> = HashSet::new();
197
198 requirements.retain(|h_f, reqs| {
200 if reqs.is_empty() {
201 let func = self.module.functions.try_get(*h_f).unwrap();
202 let span = self.module.functions.get_span(*h_f);
203 derived.import_function(func, span);
204 added.insert(*h_f);
205 false
206 } else {
207 true
208 }
209 });
210
211 for reqs in requirements.values_mut() {
213 reqs.retain(|req| !added.contains(req));
214 }
215
216 if requirements.len() == start_len {
217 return Err(RedirectError::CircularReference(format!(
218 "{:#?}",
219 requirements.keys()
220 )));
221 }
222 }
223
224 Ok(derived.into_module_with_entrypoints())
225 }
226}
227
228impl TryFrom<Redirector> for naga::Module {
229 type Error = RedirectError;
230
231 fn try_from(redirector: Redirector) -> Result<Self, Self::Error> {
232 redirector.into_module()
233 }
234}