1use super::ShaderDefVal;
2use alloc::borrow::Cow;
3use bevy_asset::{io::Reader, Asset, AssetLoader, AssetPath, Handle, LoadContext};
4use bevy_reflect::TypePath;
5use core::{marker::Copy, num::NonZero};
6use thiserror::Error;
7
/// A process-wide unique identifier for a shader, handed out by [`ShaderId::new`].
///
/// Identifiers are drawn from a global counter that starts at 1, so the wrapped
/// value is always non-zero.
#[derive(Copy, Clone, Hash, Eq, PartialEq, PartialOrd, Ord, Debug)]
pub struct ShaderId(NonZero<u32>);

impl ShaderId {
    /// Allocates a fresh `ShaderId` that has not been returned before.
    ///
    /// # Panics
    ///
    /// Panics when the global `u32` counter wraps around to zero, i.e. once
    /// `u32::MAX` identifiers have been handed out.
    #[expect(
        clippy::new_without_default,
        reason = "Implementing the `Default` trait on atomic IDs would imply that two `<AtomicIdType>::default()` equal each other. By only implementing `new()`, we indicate that each atomic ID created will be unique."
    )]
    pub fn new() -> Self {
        use core::sync::atomic::{AtomicU32, Ordering};

        // Only the uniqueness of each fetched value matters, not its ordering
        // relative to other memory operations, so `Relaxed` is enough.
        static NEXT_ID: AtomicU32 = AtomicU32::new(1);

        let raw = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        match NonZero::<u32>::new(raw) {
            Some(id) => Self(id),
            None => panic!("The system ran out of unique `{}`s.", stringify!(ShaderId)),
        }
    }
}

impl From<ShaderId> for NonZero<u32> {
    /// Unwraps the raw non-zero identifier.
    fn from(value: ShaderId) -> Self {
        let ShaderId(raw) = value;
        raw
    }
}

impl From<NonZero<u32>> for ShaderId {
    /// Wraps a raw non-zero identifier without touching the global counter.
    fn from(value: NonZero<u32>) -> Self {
        ShaderId(value)
    }
}
35
/// Errors from naga's shader front ends and validator.
///
/// Each variant wraps the corresponding naga error type; the GLSL and SPIR-V
/// variants only exist when their cargo features are enabled.
#[derive(Error, Debug)]
pub enum ShaderReflectError {
    /// Wraps a naga WGSL front-end `ParseError` (displayed transparently).
    #[error(transparent)]
    WgslParse(#[from] naga::front::wgsl::ParseError),
    /// All errors reported by naga's GLSL front end, printed with `Debug`.
    /// Only present with the `shader_format_glsl` feature.
    #[cfg(feature = "shader_format_glsl")]
    #[error("GLSL Parse Error: {0:?}")]
    GlslParse(Vec<naga::front::glsl::Error>),
    /// Wraps a naga SPIR-V front-end error (displayed transparently).
    /// Only present with the `shader_format_spirv` feature.
    #[cfg(feature = "shader_format_spirv")]
    #[error(transparent)]
    SpirVParse(#[from] naga::front::spv::Error),
    /// Wraps a naga validation error together with its source span.
    #[error(transparent)]
    Validation(#[from] naga::WithSpan<naga::valid::ValidationError>),
}
49
/// Requested validation mode for a [`Shader`].
///
/// Defaults to [`ValidateShader::Disabled`].
///
/// NOTE(review): what `Enabled` actually turns on is decided by the code that
/// reads `Shader::validate_shader`, which is not visible here; confirm there.
#[derive(Clone, Debug, Default)]
pub enum ValidateShader {
    /// No validation requested. This is the default.
    #[default]
    Disabled,
    /// Validation requested.
    Enabled,
}

/// A shader asset: its source code plus the metadata used to compose it with
/// other shader modules through naga_oil.
#[derive(Asset, TypePath, Debug, Clone)]
pub struct Shader {
    /// The path the shader was created with. It is also used as the import path
    /// when the source does not declare one.
    pub path: String,
    /// The shader source, tagged with its language.
    pub source: Source,
    /// The name under which this shader can be imported. It is either a custom
    /// name declared by the source, or an asset path.
    pub import_path: ShaderImport,
    /// Modules this shader imports, as discovered by preprocessing its source.
    pub imports: Vec<ShaderImport>,
    /// Extra imports forwarded to naga_oil when composing this module.
    pub additional_imports: Vec<naga_oil::compose::ImportDefinition>,
    /// Shader defs converted to naga_oil defs when composing this module.
    pub shader_defs: Vec<ShaderDefVal>,
    /// Handles to shaders imported by asset path. The shader loader fills this so
    /// that those files are loaded as dependencies of this one.
    pub file_dependencies: Vec<Handle<Shader>>,
    /// Requested validation mode. The `from_*` constructors set it to
    /// [`ValidateShader::Disabled`].
    pub validate_shader: ValidateShader,
}
92
93impl Shader {
94 fn preprocess(source: &str, path: &str) -> (ShaderImport, Vec<ShaderImport>) {
95 let (import_path, imports, _) = naga_oil::compose::get_preprocessor_data(source);
96
97 let import_path = import_path
98 .map(ShaderImport::Custom)
99 .unwrap_or_else(|| ShaderImport::AssetPath(path.to_owned()));
100
101 let imports = imports
102 .into_iter()
103 .map(|import| {
104 if import.import.starts_with('\"') {
105 let import = import
106 .import
107 .chars()
108 .skip(1)
109 .take_while(|c| *c != '\"')
110 .collect();
111 ShaderImport::AssetPath(import)
112 } else {
113 ShaderImport::Custom(import.import)
114 }
115 })
116 .collect();
117
118 (import_path, imports)
119 }
120
121 pub fn from_wgsl(source: impl Into<Cow<'static, str>>, path: impl Into<String>) -> Shader {
122 let source = source.into();
123 let path = path.into();
124 let (import_path, imports) = Shader::preprocess(&source, &path);
125 Shader {
126 path,
127 imports,
128 import_path,
129 source: Source::Wgsl(source),
130 additional_imports: Default::default(),
131 shader_defs: Default::default(),
132 file_dependencies: Default::default(),
133 validate_shader: ValidateShader::Disabled,
134 }
135 }
136
137 pub fn from_wgsl_with_defs(
138 source: impl Into<Cow<'static, str>>,
139 path: impl Into<String>,
140 shader_defs: Vec<ShaderDefVal>,
141 ) -> Shader {
142 Self {
143 shader_defs,
144 ..Self::from_wgsl(source, path)
145 }
146 }
147
148 pub fn from_glsl(
149 source: impl Into<Cow<'static, str>>,
150 stage: naga::ShaderStage,
151 path: impl Into<String>,
152 ) -> Shader {
153 let source = source.into();
154 let path = path.into();
155 let (import_path, imports) = Shader::preprocess(&source, &path);
156 Shader {
157 path,
158 imports,
159 import_path,
160 source: Source::Glsl(source, stage),
161 additional_imports: Default::default(),
162 shader_defs: Default::default(),
163 file_dependencies: Default::default(),
164 validate_shader: ValidateShader::Disabled,
165 }
166 }
167
168 pub fn from_spirv(source: impl Into<Cow<'static, [u8]>>, path: impl Into<String>) -> Shader {
169 let path = path.into();
170 Shader {
171 path: path.clone(),
172 imports: Vec::new(),
173 import_path: ShaderImport::AssetPath(path),
174 source: Source::SpirV(source.into()),
175 additional_imports: Default::default(),
176 shader_defs: Default::default(),
177 file_dependencies: Default::default(),
178 validate_shader: ValidateShader::Disabled,
179 }
180 }
181
182 #[cfg(feature = "shader_format_wesl")]
183 pub fn from_wesl(source: impl Into<Cow<'static, str>>, path: impl Into<String>) -> Shader {
184 let source = source.into();
185 let path = path.into();
186 let (import_path, imports) = Shader::preprocess(&source, &path);
187
188 match import_path {
189 ShaderImport::AssetPath(asset_path) => {
190 let shader_path = std::path::Path::new("/").join(&asset_path);
192
193 let import_path_str = shader_path
195 .with_extension("")
196 .to_string_lossy()
197 .replace('\\', "/");
198
199 let import_path = ShaderImport::AssetPath(import_path_str.to_string());
200
201 Shader {
202 path,
203 imports,
204 import_path,
205 source: Source::Wesl(source),
206 additional_imports: Default::default(),
207 shader_defs: Default::default(),
208 file_dependencies: Default::default(),
209 validate_shader: ValidateShader::Disabled,
210 }
211 }
212 ShaderImport::Custom(_) => {
213 panic!("Wesl shaders must be imported from an asset path");
214 }
215 }
216 }
217
218 pub fn set_import_path<P: Into<String>>(&mut self, import_path: P) {
219 self.import_path = ShaderImport::Custom(import_path.into());
220 }
221
222 #[must_use]
223 pub fn with_import_path<P: Into<String>>(mut self, import_path: P) -> Self {
224 self.set_import_path(import_path);
225 self
226 }
227
228 #[inline]
229 pub fn import_path(&self) -> &ShaderImport {
230 &self.import_path
231 }
232
233 pub fn imports(&self) -> impl ExactSizeIterator<Item = &ShaderImport> {
234 self.imports.iter()
235 }
236}
237
238impl<'a> From<&'a Shader> for naga_oil::compose::ComposableModuleDescriptor<'a> {
239 fn from(shader: &'a Shader) -> Self {
240 let shader_defs = shader
241 .shader_defs
242 .iter()
243 .map(|def| match def {
244 ShaderDefVal::Bool(name, b) => {
245 (name.clone(), naga_oil::compose::ShaderDefValue::Bool(*b))
246 }
247 ShaderDefVal::Int(name, i) => {
248 (name.clone(), naga_oil::compose::ShaderDefValue::Int(*i))
249 }
250 ShaderDefVal::UInt(name, i) => {
251 (name.clone(), naga_oil::compose::ShaderDefValue::UInt(*i))
252 }
253 })
254 .collect();
255
256 let as_name = match &shader.import_path {
257 ShaderImport::AssetPath(asset_path) => Some(format!("\"{asset_path}\"")),
258 ShaderImport::Custom(_) => None,
259 };
260
261 naga_oil::compose::ComposableModuleDescriptor {
262 source: shader.source.as_str(),
263 file_path: &shader.path,
264 language: (&shader.source).into(),
265 additional_imports: &shader.additional_imports,
266 shader_defs,
267 as_name,
268 }
269 }
270}
271
272impl<'a> From<&'a Shader> for naga_oil::compose::NagaModuleDescriptor<'a> {
273 fn from(shader: &'a Shader) -> Self {
274 naga_oil::compose::NagaModuleDescriptor {
275 source: shader.source.as_str(),
276 file_path: &shader.path,
277 shader_type: (&shader.source).into(),
278 ..Default::default()
279 }
280 }
281}
282
283#[derive(Debug, Clone)]
284pub enum Source {
285 Wgsl(Cow<'static, str>),
286 Wesl(Cow<'static, str>),
287 Glsl(Cow<'static, str>, naga::ShaderStage),
288 SpirV(Cow<'static, [u8]>),
289 }
293
294impl Source {
295 pub fn as_str(&self) -> &str {
296 match self {
297 Source::Wgsl(s) | Source::Wesl(s) | Source::Glsl(s, _) => s,
298 Source::SpirV(_) => panic!("spirv not yet implemented"),
299 }
300 }
301}
302
303impl From<&Source> for naga_oil::compose::ShaderLanguage {
304 fn from(value: &Source) -> Self {
305 match value {
306 Source::Wgsl(_) => naga_oil::compose::ShaderLanguage::Wgsl,
307 #[cfg(any(feature = "shader_format_glsl", target_arch = "wasm32"))]
308 Source::Glsl(_, _) => naga_oil::compose::ShaderLanguage::Glsl,
309 #[cfg(all(not(feature = "shader_format_glsl"), not(target_arch = "wasm32")))]
310 Source::Glsl(_, _) => panic!(
311 "GLSL is not supported in this configuration; use the feature `shader_format_glsl`"
312 ),
313 Source::SpirV(_) => panic!("spirv not yet implemented"),
314 Source::Wesl(_) => panic!("wesl not yet implemented"),
315 }
316 }
317}
318
319impl From<&Source> for naga_oil::compose::ShaderType {
320 fn from(value: &Source) -> Self {
321 match value {
322 Source::Wgsl(_) => naga_oil::compose::ShaderType::Wgsl,
323 #[cfg(any(feature = "shader_format_glsl", target_arch = "wasm32"))]
324 Source::Glsl(_, shader_stage) => match shader_stage {
325 naga::ShaderStage::Vertex => naga_oil::compose::ShaderType::GlslVertex,
326 naga::ShaderStage::Fragment => naga_oil::compose::ShaderType::GlslFragment,
327 naga::ShaderStage::Compute => panic!("glsl compute not yet implemented"),
328 naga::ShaderStage::Task => panic!("task shaders not yet implemented"),
329 naga::ShaderStage::Mesh => panic!("mesh shaders not yet implemented"),
330 },
331 #[cfg(all(not(feature = "shader_format_glsl"), not(target_arch = "wasm32")))]
332 Source::Glsl(_, _) => panic!(
333 "GLSL is not supported in this configuration; use the feature `shader_format_glsl`"
334 ),
335 Source::SpirV(_) => panic!("spirv not yet implemented"),
336 Source::Wesl(_) => panic!("wesl not yet implemented"),
337 }
338 }
339}
340
341#[derive(Default)]
342pub struct ShaderLoader;
343
344#[non_exhaustive]
345#[derive(Debug, Error)]
346pub enum ShaderLoaderError {
347 #[error("Could not load shader: {0}")]
348 Io(#[from] std::io::Error),
349 #[error("Could not parse shader: {0}")]
350 Parse(#[from] alloc::string::FromUtf8Error),
351}
352
353#[derive(serde::Serialize, serde::Deserialize, Debug, Default)]
355pub struct ShaderSettings {
356 pub shader_defs: Vec<ShaderDefVal>,
358}
359
360impl AssetLoader for ShaderLoader {
361 type Asset = Shader;
362 type Settings = ShaderSettings;
363 type Error = ShaderLoaderError;
364 async fn load(
365 &self,
366 reader: &mut dyn Reader,
367 settings: &Self::Settings,
368 load_context: &mut LoadContext<'_>,
369 ) -> Result<Shader, Self::Error> {
370 let ext = load_context.path().extension().unwrap().to_str().unwrap();
371 let path = load_context.asset_path().to_string();
372 let path = path.replace(std::path::MAIN_SEPARATOR, "/");
375 let mut bytes = Vec::new();
376 reader.read_to_end(&mut bytes).await?;
377 if ext != "wgsl" && !settings.shader_defs.is_empty() {
378 tracing::warn!(
379 "Tried to load a non-wgsl shader with shader defs, this isn't supported: \
380 The shader defs will be ignored."
381 );
382 }
383 let mut shader = match ext {
384 "spv" => Shader::from_spirv(bytes, load_context.path().to_string_lossy()),
385 "wgsl" => Shader::from_wgsl_with_defs(
386 String::from_utf8(bytes)?,
387 path,
388 settings.shader_defs.clone(),
389 ),
390 "vert" => Shader::from_glsl(String::from_utf8(bytes)?, naga::ShaderStage::Vertex, path),
391 "frag" => {
392 Shader::from_glsl(String::from_utf8(bytes)?, naga::ShaderStage::Fragment, path)
393 }
394 "comp" => {
395 Shader::from_glsl(String::from_utf8(bytes)?, naga::ShaderStage::Compute, path)
396 }
397 #[cfg(feature = "shader_format_wesl")]
398 "wesl" => Shader::from_wesl(String::from_utf8(bytes)?, path),
399 _ => panic!("unhandled extension: {ext}"),
400 };
401
402 for import in &shader.imports {
404 if let ShaderImport::AssetPath(asset_path) = import {
405 shader.file_dependencies.push(load_context.load(asset_path));
406 }
407 }
408 Ok(shader)
409 }
410
411 fn extensions(&self) -> &[&str] {
412 &["spv", "wgsl", "vert", "frag", "comp", "wesl"]
413 }
414}
415
/// How a shader is named for import purposes.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub enum ShaderImport {
    /// Identified by an asset path.
    AssetPath(String),
    /// Identified by a custom module name.
    Custom(String),
}

impl ShaderImport {
    /// Returns the module name for this import.
    ///
    /// An asset path is wrapped in double quotes, which requires a new allocation.
    /// A custom name is borrowed unchanged.
    pub fn module_name(&self) -> Cow<'_, String> {
        match self {
            ShaderImport::Custom(name) => Cow::Borrowed(name),
            ShaderImport::AssetPath(asset_path) => {
                let quoted = format!("\"{asset_path}\"");
                Cow::Owned(quoted)
            }
        }
    }
}
430
431#[derive(Default)]
433pub enum ShaderRef {
434 #[default]
436 Default,
437 Handle(Handle<Shader>),
439 Path(AssetPath<'static>),
441}
442
443impl From<Handle<Shader>> for ShaderRef {
444 fn from(handle: Handle<Shader>) -> Self {
445 Self::Handle(handle)
446 }
447}
448
449impl From<AssetPath<'static>> for ShaderRef {
450 fn from(path: AssetPath<'static>) -> Self {
451 Self::Path(path)
452 }
453}
454
455impl From<&'static str> for ShaderRef {
456 fn from(path: &'static str) -> Self {
457 Self::Path(AssetPath::from(path))
458 }
459}