1use std::collections::{HashMap, HashSet};
2
3use naga::{Block, Expression, Function, Handle, Module, Statement};
4use thiserror::Error;
5
6use crate::derive::DerivedModule;
7
8#[derive(Debug, Error)]
9pub enum RedirectError {
10 #[error("can't find function {0} for redirection")]
11 FunctionNotFound(String),
12 #[error("{0} cannot override {1} due to argument mismatch")]
13 ArgumentMismatch(String, String),
14 #[error("{0} cannot override {1} due to return type mismatch")]
15 ReturnTypeMismatch(String, String),
16 #[error("circular reference; can't find an order for : {0}")]
17 CircularReference(String),
18}
19
20pub struct Redirector {
21 module: Module,
22}
23
24impl Redirector {
25 pub fn new(module: Module) -> Self {
26 Self { module }
27 }
28
29 fn redirect_block(block: &mut Block, original: Handle<Function>, new: Handle<Function>) {
30 for stmt in block.iter_mut() {
31 match stmt {
32 Statement::Call {
33 ref mut function, ..
34 } => {
35 if *function == original {
36 *function = new;
37 }
38 }
39 Statement::Block(b) => Self::redirect_block(b, original, new),
40 Statement::If {
41 condition: _,
42 accept,
43 reject,
44 } => {
45 Self::redirect_block(accept, original, new);
46 Self::redirect_block(reject, original, new);
47 }
48 Statement::Switch { selector: _, cases } => {
49 for case in cases.iter_mut() {
50 Self::redirect_block(&mut case.body, original, new);
51 }
52 }
53 Statement::Loop {
54 body,
55 continuing,
56 break_if: _,
57 } => {
58 Self::redirect_block(body, original, new);
59 Self::redirect_block(continuing, original, new);
60 }
61 Statement::Emit(_)
62 | Statement::Break
63 | Statement::Continue
64 | Statement::Return { .. }
65 | Statement::WorkGroupUniformLoad { .. }
66 | Statement::Kill
67 | Statement::Barrier(_)
68 | Statement::Store { .. }
69 | Statement::ImageStore { .. }
70 | Statement::Atomic { .. }
71 | Statement::RayQuery { .. }
72 | Statement::SubgroupBallot { .. }
73 | Statement::SubgroupGather { .. }
74 | Statement::SubgroupCollectiveOperation { .. }
75 | Statement::ImageAtomic { .. } => (),
76 }
77 }
78 }
79
80 fn redirect_expr(expr: &mut Expression, original: Handle<Function>, new: Handle<Function>) {
81 if let Expression::CallResult(f) = expr {
82 if f == &original {
83 *expr = Expression::CallResult(new);
84 }
85 }
86 }
87
88 fn redirect_fn(func: &mut Function, original: Handle<Function>, new: Handle<Function>) {
89 Self::redirect_block(&mut func.body, original, new);
90 for (_, expr) in func.expressions.iter_mut() {
91 Self::redirect_expr(expr, original, new);
92 }
93 }
94
95 pub fn redirect_function(
100 &mut self,
101 original: &str,
102 replacement: &str,
103 omit: &HashSet<String>,
104 ) -> Result<(Handle<Function>, Handle<Function>), RedirectError> {
105 let (h_original, f_original) = self
106 .module
107 .functions
108 .iter()
109 .find(|(_, f)| f.name.as_deref() == Some(original))
110 .ok_or_else(|| RedirectError::FunctionNotFound(original.to_owned()))?;
111 let (h_replacement, f_replacement) = self
112 .module
113 .functions
114 .iter()
115 .find(|(_, f)| f.name.as_deref() == Some(replacement))
116 .ok_or_else(|| RedirectError::FunctionNotFound(replacement.to_owned()))?;
117
118 for (arg1, arg2) in f_original
119 .arguments
120 .iter()
121 .zip(f_replacement.arguments.iter())
122 {
123 if arg1.ty != arg2.ty {
124 return Err(RedirectError::ArgumentMismatch(
125 original.to_owned(),
126 replacement.to_owned(),
127 ));
128 }
129 }
130
131 if f_original.result.as_ref().map(|r| r.ty) != f_replacement.result.as_ref().map(|r| r.ty) {
132 return Err(RedirectError::ReturnTypeMismatch(
133 original.to_owned(),
134 replacement.to_owned(),
135 ));
136 }
137
138 for (h_f, f) in self.module.functions.iter_mut() {
139 if h_f != h_replacement && !omit.contains(f.name.as_ref().unwrap()) {
140 Self::redirect_fn(f, h_original, h_replacement);
141 }
142 }
143
144 for ep in &mut self.module.entry_points {
145 Self::redirect_fn(&mut ep.function, h_original, h_replacement);
146 }
147
148 Ok((h_original, h_replacement))
149 }
150
151 fn gather_requirements(block: &Block) -> HashSet<Handle<Function>> {
152 let mut requirements = HashSet::default();
153
154 for stmt in block.iter() {
155 match stmt {
156 Statement::Block(b) => requirements.extend(Self::gather_requirements(b)),
157 Statement::If { accept, reject, .. } => {
158 requirements.extend(Self::gather_requirements(accept));
159 requirements.extend(Self::gather_requirements(reject));
160 }
161 Statement::Switch { cases, .. } => {
162 for case in cases {
163 requirements.extend(Self::gather_requirements(&case.body));
164 }
165 }
166 Statement::Loop {
167 body, continuing, ..
168 } => {
169 requirements.extend(Self::gather_requirements(body));
170 requirements.extend(Self::gather_requirements(continuing));
171 }
172 Statement::Call { function, .. } => {
173 requirements.insert(*function);
174 }
175 _ => (),
176 }
177 }
178
179 requirements
180 }
181
182 pub fn into_module(self) -> Result<naga::Module, RedirectError> {
183 let mut requirements: HashMap<_, _> = self
185 .module
186 .functions
187 .iter()
188 .map(|(h_f, f)| (h_f, Self::gather_requirements(&f.body)))
189 .collect();
190
191 let mut derived = DerivedModule::default();
192 derived.set_shader_source(&self.module, 0);
193
194 while !requirements.is_empty() {
195 let start_len = requirements.len();
196
197 let mut added: HashSet<Handle<Function>> = HashSet::new();
198
199 requirements.retain(|h_f, reqs| {
201 if reqs.is_empty() {
202 let func = self.module.functions.try_get(*h_f).unwrap();
203 let span = self.module.functions.get_span(*h_f);
204 derived.import_function(func, span);
205 added.insert(*h_f);
206 false
207 } else {
208 true
209 }
210 });
211
212 for reqs in requirements.values_mut() {
214 reqs.retain(|req| !added.contains(req));
215 }
216
217 if requirements.len() == start_len {
218 return Err(RedirectError::CircularReference(format!(
219 "{:#?}",
220 requirements.keys()
221 )));
222 }
223 }
224
225 Ok(derived.into_module_with_entrypoints())
226 }
227}
228
229impl TryFrom<Redirector> for naga::Module {
230 type Error = RedirectError;
231
232 fn try_from(redirector: Redirector) -> Result<Self, Self::Error> {
233 redirector.into_module()
234 }
235}