naga_oil/compose/wgsl_directives.rs

1use std::collections::HashSet;
2
/// A WGSL `enable` directive: the extension names it lists, plus the source
/// location it was recorded at.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct EnableDirective {
    /// Extension names listed by the directive (e.g. `f16`, `subgroups`).
    pub extensions: Vec<String>,
    /// Where the directive was found. NOTE(review): presumably a byte offset into
    /// the originating shader source — confirm against the parser that fills it.
    pub source_location: usize,
}
8
/// A WGSL `requires` directive: the language-extension names it lists, plus the
/// source location it was recorded at.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct RequiresDirective {
    /// Names listed by the directive
    /// (e.g. `readonly_and_readwrite_storage_textures`).
    pub extensions: Vec<String>,
    /// Where the directive was found. NOTE(review): presumably a byte offset into
    /// the originating shader source — confirm against the parser that fills it.
    pub source_location: usize,
}
14
/// A WGSL global `diagnostic(severity, rule)` directive, plus the source
/// location it was recorded at.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct DiagnosticDirective {
    /// Severity text as written in the directive (e.g. `warn`).
    pub severity: String,
    /// Triggering rule name as written in the directive
    /// (e.g. `derivative_uniformity`).
    pub rule: String,
    /// Where the directive was found. NOTE(review): presumably a byte offset into
    /// the originating shader source — confirm against the parser that fills it.
    pub source_location: usize,
}
21
22#[derive(Debug, Clone, Default)]
23pub struct WgslDirectives {
24    pub enables: Vec<EnableDirective>,
25    pub requires: Vec<RequiresDirective>,
26    pub diagnostics: Vec<DiagnosticDirective>,
27}
28
29impl WgslDirectives {
30    pub fn to_wgsl_string(&self) -> String {
31        let mut result = String::new();
32
33        let mut all_enables = HashSet::new();
34        for enable in &self.enables {
35            all_enables.extend(enable.extensions.iter().cloned());
36        }
37        if !all_enables.is_empty() {
38            let mut enables: Vec<_> = all_enables.into_iter().collect();
39            enables.sort();
40            result.push_str(&format!("enable {};\n", enables.join(", ")));
41        }
42
43        let mut all_requires = HashSet::new();
44        for requires in &self.requires {
45            all_requires.extend(requires.extensions.iter().cloned());
46        }
47        if !all_requires.is_empty() {
48            let mut requires: Vec<_> = all_requires.into_iter().collect();
49            requires.sort();
50            result.push_str(&format!("requires {};\n", requires.join(", ")));
51        }
52
53        for diagnostic in &self.diagnostics {
54            result.push_str(&format!(
55                "diagnostic({}, {});\n",
56                diagnostic.severity, diagnostic.rule
57            ));
58        }
59
60        if !result.is_empty() {
61            result.push('\n'); // Add blank line after directives
62        }
63
64        result
65    }
66
67    pub fn is_empty(&self) -> bool {
68        self.enables.is_empty() && self.requires.is_empty() && self.diagnostics.is_empty()
69    }
70}
71
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned extension list from string literals.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_wgsl_directives_empty() {
        let directives = WgslDirectives::default();
        assert!(directives.is_empty());
        assert_eq!(directives.to_wgsl_string(), "");
    }

    #[test]
    fn test_wgsl_directives_to_string() {
        let mut directives = WgslDirectives::default();
        directives.enables.push(EnableDirective {
            extensions: owned(&["f16", "subgroups"]),
            source_location: 0,
        });
        directives.requires.push(RequiresDirective {
            extensions: owned(&["readonly_and_readwrite_storage_textures"]),
            source_location: 0,
        });
        directives.diagnostics.push(DiagnosticDirective {
            severity: "warn".to_string(),
            rule: "derivative_uniformity".to_string(),
            source_location: 0,
        });

        assert!(!directives.is_empty());
        // Compare the whole string so line order and the trailing blank line are
        // checked too, not just the presence of each directive.
        assert_eq!(
            directives.to_wgsl_string(),
            concat!(
                "enable f16, subgroups;\n",
                "requires readonly_and_readwrite_storage_textures;\n",
                "diagnostic(warn, derivative_uniformity);\n",
                "\n"
            )
        );
    }

    #[test]
    fn test_extensions_merged_sorted_and_deduplicated() {
        let mut directives = WgslDirectives::default();
        directives.enables.push(EnableDirective {
            extensions: owned(&["subgroups", "f16"]),
            source_location: 0,
        });
        directives.enables.push(EnableDirective {
            extensions: owned(&["f16", "clip_distances"]),
            source_location: 10,
        });
        assert_eq!(
            directives.to_wgsl_string(),
            "enable clip_distances, f16, subgroups;\n\n"
        );
    }

    #[test]
    fn test_directive_without_extensions_renders_nothing() {
        let mut directives = WgslDirectives::default();
        directives.requires.push(RequiresDirective {
            extensions: Vec::new(),
            source_location: 0,
        });
        assert_eq!(directives.to_wgsl_string(), "");
    }
}