1use super::ShaderDefVal;
2use crate::define_atomic_id;
3use alloc::borrow::Cow;
4use bevy_asset::{io::Reader, Asset, AssetLoader, AssetPath, Handle, LoadContext};
5use bevy_reflect::TypePath;
6use core::marker::Copy;
7use thiserror::Error;
8
// Declares the `ShaderId` type via the crate's `define_atomic_id!` macro.
// NOTE(review): presumably an id type allocated from an atomic counter, as the macro name
// suggests — confirm against `crate::define_atomic_id`.
define_atomic_id!(ShaderId);
10
11#[derive(Error, Debug)]
12pub enum ShaderReflectError {
13 #[error(transparent)]
14 WgslParse(#[from] naga::front::wgsl::ParseError),
15 #[cfg(feature = "shader_format_glsl")]
16 #[error("GLSL Parse Error: {0:?}")]
17 GlslParse(Vec<naga::front::glsl::Error>),
18 #[cfg(feature = "shader_format_spirv")]
19 #[error(transparent)]
20 SpirVParse(#[from] naga::front::spv::Error),
21 #[error(transparent)]
22 Validation(#[from] naga::WithSpan<naga::valid::ValidationError>),
23}
24
/// Whether runtime validation is requested for a shader.
///
/// [`ValidateShader::Disabled`] is the [`Default`], and it is what every `Shader`
/// constructor in this module starts out with.
// NOTE(review): what `Enabled` actually switches on is decided by the code that reads
// `Shader::validate_shader`, which is not in this file — confirm there.
#[derive(Default, Debug, Clone)]
pub enum ValidateShader {
    /// Validation is not requested (the default).
    #[default]
    Disabled,
    /// Validation is requested.
    Enabled,
}
47
48#[derive(Asset, TypePath, Debug, Clone)]
51pub struct Shader {
52 pub path: String,
53 pub source: Source,
54 pub import_path: ShaderImport,
55 pub imports: Vec<ShaderImport>,
56 pub additional_imports: Vec<naga_oil::compose::ImportDefinition>,
58 pub shader_defs: Vec<ShaderDefVal>,
60 pub file_dependencies: Vec<Handle<Shader>>,
63 pub validate_shader: ValidateShader,
67}
68
69impl Shader {
70 fn preprocess(source: &str, path: &str) -> (ShaderImport, Vec<ShaderImport>) {
71 let (import_path, imports, _) = naga_oil::compose::get_preprocessor_data(source);
72
73 let import_path = import_path
74 .map(ShaderImport::Custom)
75 .unwrap_or_else(|| ShaderImport::AssetPath(path.to_owned()));
76
77 let imports = imports
78 .into_iter()
79 .map(|import| {
80 if import.import.starts_with('\"') {
81 let import = import
82 .import
83 .chars()
84 .skip(1)
85 .take_while(|c| *c != '\"')
86 .collect();
87 ShaderImport::AssetPath(import)
88 } else {
89 ShaderImport::Custom(import.import)
90 }
91 })
92 .collect();
93
94 (import_path, imports)
95 }
96
97 pub fn from_wgsl(source: impl Into<Cow<'static, str>>, path: impl Into<String>) -> Shader {
98 let source = source.into();
99 let path = path.into();
100 let (import_path, imports) = Shader::preprocess(&source, &path);
101 Shader {
102 path,
103 imports,
104 import_path,
105 source: Source::Wgsl(source),
106 additional_imports: Default::default(),
107 shader_defs: Default::default(),
108 file_dependencies: Default::default(),
109 validate_shader: ValidateShader::Disabled,
110 }
111 }
112
113 pub fn from_wgsl_with_defs(
114 source: impl Into<Cow<'static, str>>,
115 path: impl Into<String>,
116 shader_defs: Vec<ShaderDefVal>,
117 ) -> Shader {
118 Self {
119 shader_defs,
120 ..Self::from_wgsl(source, path)
121 }
122 }
123
124 pub fn from_glsl(
125 source: impl Into<Cow<'static, str>>,
126 stage: naga::ShaderStage,
127 path: impl Into<String>,
128 ) -> Shader {
129 let source = source.into();
130 let path = path.into();
131 let (import_path, imports) = Shader::preprocess(&source, &path);
132 Shader {
133 path,
134 imports,
135 import_path,
136 source: Source::Glsl(source, stage),
137 additional_imports: Default::default(),
138 shader_defs: Default::default(),
139 file_dependencies: Default::default(),
140 validate_shader: ValidateShader::Disabled,
141 }
142 }
143
144 pub fn from_spirv(source: impl Into<Cow<'static, [u8]>>, path: impl Into<String>) -> Shader {
145 let path = path.into();
146 Shader {
147 path: path.clone(),
148 imports: Vec::new(),
149 import_path: ShaderImport::AssetPath(path),
150 source: Source::SpirV(source.into()),
151 additional_imports: Default::default(),
152 shader_defs: Default::default(),
153 file_dependencies: Default::default(),
154 validate_shader: ValidateShader::Disabled,
155 }
156 }
157
158 #[cfg(feature = "shader_format_wesl")]
159 pub fn from_wesl(source: impl Into<Cow<'static, str>>, path: impl Into<String>) -> Shader {
160 let source = source.into();
161 let path = path.into();
162 let (import_path, imports) = Shader::preprocess(&source, &path);
163
164 match import_path {
165 ShaderImport::AssetPath(asset_path) => {
166 let shader_path = std::path::Path::new("/").join(&asset_path);
168
169 let import_path_str = shader_path
171 .with_extension("")
172 .to_string_lossy()
173 .replace('\\', "/");
174
175 let import_path = ShaderImport::AssetPath(import_path_str.to_string());
176
177 Shader {
178 path,
179 imports,
180 import_path,
181 source: Source::Wesl(source),
182 additional_imports: Default::default(),
183 shader_defs: Default::default(),
184 file_dependencies: Default::default(),
185 validate_shader: ValidateShader::Disabled,
186 }
187 }
188 ShaderImport::Custom(_) => {
189 panic!("Wesl shaders must be imported from an asset path");
190 }
191 }
192 }
193
194 pub fn set_import_path<P: Into<String>>(&mut self, import_path: P) {
195 self.import_path = ShaderImport::Custom(import_path.into());
196 }
197
198 #[must_use]
199 pub fn with_import_path<P: Into<String>>(mut self, import_path: P) -> Self {
200 self.set_import_path(import_path);
201 self
202 }
203
204 #[inline]
205 pub fn import_path(&self) -> &ShaderImport {
206 &self.import_path
207 }
208
209 pub fn imports(&self) -> impl ExactSizeIterator<Item = &ShaderImport> {
210 self.imports.iter()
211 }
212}
213
214impl<'a> From<&'a Shader> for naga_oil::compose::ComposableModuleDescriptor<'a> {
215 fn from(shader: &'a Shader) -> Self {
216 let shader_defs = shader
217 .shader_defs
218 .iter()
219 .map(|def| match def {
220 ShaderDefVal::Bool(name, b) => {
221 (name.clone(), naga_oil::compose::ShaderDefValue::Bool(*b))
222 }
223 ShaderDefVal::Int(name, i) => {
224 (name.clone(), naga_oil::compose::ShaderDefValue::Int(*i))
225 }
226 ShaderDefVal::UInt(name, i) => {
227 (name.clone(), naga_oil::compose::ShaderDefValue::UInt(*i))
228 }
229 })
230 .collect();
231
232 let as_name = match &shader.import_path {
233 ShaderImport::AssetPath(asset_path) => Some(format!("\"{asset_path}\"")),
234 ShaderImport::Custom(_) => None,
235 };
236
237 naga_oil::compose::ComposableModuleDescriptor {
238 source: shader.source.as_str(),
239 file_path: &shader.path,
240 language: (&shader.source).into(),
241 additional_imports: &shader.additional_imports,
242 shader_defs,
243 as_name,
244 }
245 }
246}
247
248impl<'a> From<&'a Shader> for naga_oil::compose::NagaModuleDescriptor<'a> {
249 fn from(shader: &'a Shader) -> Self {
250 naga_oil::compose::NagaModuleDescriptor {
251 source: shader.source.as_str(),
252 file_path: &shader.path,
253 shader_type: (&shader.source).into(),
254 ..Default::default()
255 }
256 }
257}
258
259#[derive(Debug, Clone)]
260pub enum Source {
261 Wgsl(Cow<'static, str>),
262 Wesl(Cow<'static, str>),
263 Glsl(Cow<'static, str>, naga::ShaderStage),
264 SpirV(Cow<'static, [u8]>),
265 }
269
270impl Source {
271 pub fn as_str(&self) -> &str {
272 match self {
273 Source::Wgsl(s) | Source::Wesl(s) | Source::Glsl(s, _) => s,
274 Source::SpirV(_) => panic!("spirv not yet implemented"),
275 }
276 }
277}
278
279impl From<&Source> for naga_oil::compose::ShaderLanguage {
280 fn from(value: &Source) -> Self {
281 match value {
282 Source::Wgsl(_) => naga_oil::compose::ShaderLanguage::Wgsl,
283 #[cfg(any(feature = "shader_format_glsl", target_arch = "wasm32"))]
284 Source::Glsl(_, _) => naga_oil::compose::ShaderLanguage::Glsl,
285 #[cfg(all(not(feature = "shader_format_glsl"), not(target_arch = "wasm32")))]
286 Source::Glsl(_, _) => panic!(
287 "GLSL is not supported in this configuration; use the feature `shader_format_glsl`"
288 ),
289 Source::SpirV(_) => panic!("spirv not yet implemented"),
290 Source::Wesl(_) => panic!("wesl not yet implemented"),
291 }
292 }
293}
294
295impl From<&Source> for naga_oil::compose::ShaderType {
296 fn from(value: &Source) -> Self {
297 match value {
298 Source::Wgsl(_) => naga_oil::compose::ShaderType::Wgsl,
299 #[cfg(any(feature = "shader_format_glsl", target_arch = "wasm32"))]
300 Source::Glsl(_, shader_stage) => match shader_stage {
301 naga::ShaderStage::Vertex => naga_oil::compose::ShaderType::GlslVertex,
302 naga::ShaderStage::Fragment => naga_oil::compose::ShaderType::GlslFragment,
303 naga::ShaderStage::Compute => panic!("glsl compute not yet implemented"),
304 },
305 #[cfg(all(not(feature = "shader_format_glsl"), not(target_arch = "wasm32")))]
306 Source::Glsl(_, _) => panic!(
307 "GLSL is not supported in this configuration; use the feature `shader_format_glsl`"
308 ),
309 Source::SpirV(_) => panic!("spirv not yet implemented"),
310 Source::Wesl(_) => panic!("wesl not yet implemented"),
311 }
312 }
313}
314
/// [`AssetLoader`] that builds [`Shader`] assets from shader files.
#[derive(Default)]
pub struct ShaderLoader;
317
318#[non_exhaustive]
319#[derive(Debug, Error)]
320pub enum ShaderLoaderError {
321 #[error("Could not load shader: {0}")]
322 Io(#[from] std::io::Error),
323 #[error("Could not parse shader: {0}")]
324 Parse(#[from] alloc::string::FromUtf8Error),
325}
326
327impl AssetLoader for ShaderLoader {
328 type Asset = Shader;
329 type Settings = ();
330 type Error = ShaderLoaderError;
331 async fn load(
332 &self,
333 reader: &mut dyn Reader,
334 _settings: &Self::Settings,
335 load_context: &mut LoadContext<'_>,
336 ) -> Result<Shader, Self::Error> {
337 let ext = load_context.path().extension().unwrap().to_str().unwrap();
338 let path = load_context.asset_path().to_string();
339 let path = path.replace(std::path::MAIN_SEPARATOR, "/");
342 let mut bytes = Vec::new();
343 reader.read_to_end(&mut bytes).await?;
344 let mut shader = match ext {
345 "spv" => Shader::from_spirv(bytes, load_context.path().to_string_lossy()),
346 "wgsl" => Shader::from_wgsl(String::from_utf8(bytes)?, path),
347 "vert" => Shader::from_glsl(String::from_utf8(bytes)?, naga::ShaderStage::Vertex, path),
348 "frag" => {
349 Shader::from_glsl(String::from_utf8(bytes)?, naga::ShaderStage::Fragment, path)
350 }
351 "comp" => {
352 Shader::from_glsl(String::from_utf8(bytes)?, naga::ShaderStage::Compute, path)
353 }
354 #[cfg(feature = "shader_format_wesl")]
355 "wesl" => Shader::from_wesl(String::from_utf8(bytes)?, path),
356 _ => panic!("unhandled extension: {ext}"),
357 };
358
359 for import in &shader.imports {
361 if let ShaderImport::AssetPath(asset_path) = import {
362 shader.file_dependencies.push(load_context.load(asset_path));
363 }
364 }
365 Ok(shader)
366 }
367
368 fn extensions(&self) -> &[&str] {
369 &["spv", "wgsl", "vert", "frag", "comp", "wesl"]
370 }
371}
372
/// How a shader module is named for imports: by asset path or by a custom module name.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum ShaderImport {
    /// Imported by its asset path.
    AssetPath(String),
    /// Imported by a custom module name.
    Custom(String),
}

impl ShaderImport {
    /// The module name used for composition.
    ///
    /// Asset paths are wrapped in double quotes, which allocates. Custom names are borrowed
    /// unchanged.
    pub fn module_name(&self) -> Cow<'_, String> {
        match self {
            Self::Custom(name) => Cow::Borrowed(name),
            Self::AssetPath(path) => {
                let mut quoted = String::with_capacity(path.len() + 2);
                quoted.push('"');
                quoted.push_str(path);
                quoted.push('"');
                Cow::Owned(quoted)
            }
        }
    }
}
387
388pub enum ShaderRef {
390 Default,
392 Handle(Handle<Shader>),
394 Path(AssetPath<'static>),
396}
397
398impl From<Handle<Shader>> for ShaderRef {
399 fn from(handle: Handle<Shader>) -> Self {
400 Self::Handle(handle)
401 }
402}
403
404impl From<AssetPath<'static>> for ShaderRef {
405 fn from(path: AssetPath<'static>) -> Self {
406 Self::Path(path)
407 }
408}
409
410impl From<&'static str> for ShaderRef {
411 fn from(path: &'static str) -> Self {
412 Self::Path(AssetPath::from(path))
413 }
414}