1mod expressions;
2mod functions;
3mod handle_set_map;
4mod statements;
5mod types;
6
7use crate::arena::HandleSet;
8use crate::{arena, compact::functions::FunctionTracer};
9use handle_set_map::HandleMap;
10
11pub fn compact(module: &mut crate::Module) {
33 let mut module_tracer = ModuleTracer::new(module);
34
35 log::trace!("tracing global variables");
37 {
38 for (_, global) in module.global_variables.iter() {
39 log::trace!("tracing global {:?}", global.name);
40 module_tracer.types_used.insert(global.ty);
41 if let Some(init) = global.init {
42 module_tracer.global_expressions_used.insert(init);
43 }
44 }
45 }
46
47 module_tracer.trace_special_types(&module.special_types);
49
50 for (handle, constant) in module.constants.iter() {
52 if constant.name.is_some() {
53 module_tracer.constants_used.insert(handle);
54 module_tracer.global_expressions_used.insert(constant.init);
55 }
56 }
57
58 for (_, override_) in module.overrides.iter() {
60 module_tracer.types_used.insert(override_.ty);
61 if let Some(init) = override_.init {
62 module_tracer.global_expressions_used.insert(init);
63 }
64 }
65
66 for (_, ty) in module.types.iter() {
67 if let crate::TypeInner::Array {
68 size: crate::ArraySize::Pending(crate::PendingArraySize::Expression(size_expr)),
69 ..
70 } = ty.inner
71 {
72 module_tracer.global_expressions_used.insert(size_expr);
73 }
74 }
75
76 log::trace!("tracing functions");
83 let function_maps: Vec<FunctionMap> = module
84 .functions
85 .iter()
86 .map(|(_, f)| {
87 log::trace!("tracing function {:?}", f.name);
88 let mut function_tracer = module_tracer.as_function(f);
89 function_tracer.trace();
90 FunctionMap::from(function_tracer)
91 })
92 .collect();
93
94 log::trace!("tracing entry points");
96 let entry_point_maps: Vec<FunctionMap> = module
97 .entry_points
98 .iter()
99 .map(|e| {
100 log::trace!("tracing entry point {:?}", e.function.name);
101
102 if let Some(sizes) = e.workgroup_size_overrides {
103 for size in sizes.iter().filter_map(|x| *x) {
104 module_tracer.global_expressions_used.insert(size);
105 }
106 }
107
108 let mut used = module_tracer.as_function(&e.function);
109 used.trace();
110 FunctionMap::from(used)
111 })
112 .collect();
113
114 module_tracer.as_const_expression().trace_expressions();
119
120 for (handle, constant) in module.constants.iter() {
124 if module_tracer.constants_used.contains(handle) {
125 module_tracer.types_used.insert(constant.ty);
126 }
127 }
128
129 for (handle, ty) in module.types.iter() {
131 log::trace!("tracing type {:?}, name {:?}", handle, ty.name);
132 if ty.name.is_some() {
133 module_tracer.types_used.insert(handle);
134 }
135 }
136
137 module_tracer.as_type().trace_types();
139
140 let module_map = ModuleMap::from(module_tracer);
145
146 log::trace!("compacting types");
152 let mut new_types = arena::UniqueArena::new();
153 for (old_handle, mut ty, span) in module.types.drain_all() {
154 if let Some(expected_new_handle) = module_map.types.try_adjust(old_handle) {
155 module_map.adjust_type(&mut ty);
156 let actual_new_handle = new_types.insert(ty, span);
157 assert_eq!(actual_new_handle, expected_new_handle);
158 }
159 }
160 module.types = new_types;
161 log::trace!("adjusting special types");
162 module_map.adjust_special_types(&mut module.special_types);
163
164 log::trace!("adjusting constant expressions");
166 module.global_expressions.retain_mut(|handle, expr| {
167 if module_map.global_expressions.used(handle) {
168 module_map.adjust_expression(expr, &module_map.global_expressions);
169 true
170 } else {
171 false
172 }
173 });
174
175 log::trace!("adjusting constants");
177 module.constants.retain_mut(|handle, constant| {
178 if module_map.constants.used(handle) {
179 module_map.types.adjust(&mut constant.ty);
180 module_map.global_expressions.adjust(&mut constant.init);
181 true
182 } else {
183 false
184 }
185 });
186
187 log::trace!("adjusting overrides");
189 for (_, override_) in module.overrides.iter_mut() {
190 module_map.types.adjust(&mut override_.ty);
191 if let Some(init) = override_.init.as_mut() {
192 module_map.global_expressions.adjust(init);
193 }
194 }
195
196 log::trace!("adjusting workgroup_size_overrides");
198 for e in module.entry_points.iter_mut() {
199 if let Some(sizes) = e.workgroup_size_overrides.as_mut() {
200 for size in sizes.iter_mut() {
201 if let Some(expr) = size.as_mut() {
202 module_map.global_expressions.adjust(expr);
203 }
204 }
205 }
206 }
207
208 log::trace!("adjusting global variables");
210 for (_, global) in module.global_variables.iter_mut() {
211 log::trace!("adjusting global {:?}", global.name);
212 module_map.types.adjust(&mut global.ty);
213 if let Some(ref mut init) = global.init {
214 module_map.global_expressions.adjust(init);
215 }
216 }
217
218 let mut reused_named_expressions = crate::NamedExpressions::default();
221
222 for ((_, function), map) in module.functions.iter_mut().zip(function_maps.iter()) {
224 log::trace!("compacting function {:?}", function.name);
225 map.compact(function, &module_map, &mut reused_named_expressions);
226 }
227
228 for (entry, map) in module.entry_points.iter_mut().zip(entry_point_maps.iter()) {
230 log::trace!("compacting entry point {:?}", entry.function.name);
231 map.compact(
232 &mut entry.function,
233 &module_map,
234 &mut reused_named_expressions,
235 );
236 }
237}
238
239struct ModuleTracer<'module> {
240 module: &'module crate::Module,
241 types_used: HandleSet<crate::Type>,
242 constants_used: HandleSet<crate::Constant>,
243 global_expressions_used: HandleSet<crate::Expression>,
244}
245
246impl<'module> ModuleTracer<'module> {
247 fn new(module: &'module crate::Module) -> Self {
248 Self {
249 module,
250 types_used: HandleSet::for_arena(&module.types),
251 constants_used: HandleSet::for_arena(&module.constants),
252 global_expressions_used: HandleSet::for_arena(&module.global_expressions),
253 }
254 }
255
256 fn trace_special_types(&mut self, special_types: &crate::SpecialTypes) {
257 let crate::SpecialTypes {
258 ref ray_desc,
259 ref ray_intersection,
260 ref predeclared_types,
261 } = *special_types;
262
263 if let Some(ray_desc) = *ray_desc {
264 self.types_used.insert(ray_desc);
265 }
266 if let Some(ray_intersection) = *ray_intersection {
267 self.types_used.insert(ray_intersection);
268 }
269 for (_, &handle) in predeclared_types {
270 self.types_used.insert(handle);
271 }
272 }
273
274 fn as_type(&mut self) -> types::TypeTracer {
275 types::TypeTracer {
276 types: &self.module.types,
277 types_used: &mut self.types_used,
278 }
279 }
280
281 fn as_const_expression(&mut self) -> expressions::ExpressionTracer {
282 expressions::ExpressionTracer {
283 expressions: &self.module.global_expressions,
284 constants: &self.module.constants,
285 types_used: &mut self.types_used,
286 constants_used: &mut self.constants_used,
287 expressions_used: &mut self.global_expressions_used,
288 global_expressions_used: None,
289 }
290 }
291
292 pub fn as_function<'tracer>(
293 &'tracer mut self,
294 function: &'tracer crate::Function,
295 ) -> FunctionTracer<'tracer> {
296 FunctionTracer {
297 function,
298 constants: &self.module.constants,
299 types_used: &mut self.types_used,
300 constants_used: &mut self.constants_used,
301 global_expressions_used: &mut self.global_expressions_used,
302 expressions_used: HandleSet::for_arena(&function.expressions),
303 }
304 }
305}
306
307struct ModuleMap {
308 types: HandleMap<crate::Type>,
309 constants: HandleMap<crate::Constant>,
310 global_expressions: HandleMap<crate::Expression>,
311}
312
313impl From<ModuleTracer<'_>> for ModuleMap {
314 fn from(used: ModuleTracer) -> Self {
315 ModuleMap {
316 types: HandleMap::from_set(used.types_used),
317 constants: HandleMap::from_set(used.constants_used),
318 global_expressions: HandleMap::from_set(used.global_expressions_used),
319 }
320 }
321}
322
323impl ModuleMap {
324 fn adjust_special_types(&self, special: &mut crate::SpecialTypes) {
325 let crate::SpecialTypes {
326 ref mut ray_desc,
327 ref mut ray_intersection,
328 ref mut predeclared_types,
329 } = *special;
330
331 if let Some(ref mut ray_desc) = *ray_desc {
332 self.types.adjust(ray_desc);
333 }
334 if let Some(ref mut ray_intersection) = *ray_intersection {
335 self.types.adjust(ray_intersection);
336 }
337
338 for handle in predeclared_types.values_mut() {
339 self.types.adjust(handle);
340 }
341 }
342}
343
344struct FunctionMap {
345 expressions: HandleMap<crate::Expression>,
346}
347
348impl From<FunctionTracer<'_>> for FunctionMap {
349 fn from(used: FunctionTracer) -> Self {
350 FunctionMap {
351 expressions: HandleMap::from_set(used.expressions_used),
352 }
353 }
354}