1use crate::{
2 binding_model,
3 hub::Hub,
4 id::{BindGroupLayoutId, PipelineLayoutId},
5 resource::{
6 Buffer, BufferAccessError, BufferAccessResult, BufferMapOperation, Labeled,
7 ResourceErrorIdent,
8 },
9 snatch::SnatchGuard,
10 Label, DOWNLEVEL_ERROR_MESSAGE,
11};
12
13use arrayvec::ArrayVec;
14use smallvec::SmallVec;
15use std::os::raw::c_char;
16use thiserror::Error;
17use wgt::{BufferAddress, DeviceLostReason, TextureFormat};
18
19use std::num::NonZeroU32;
20
21pub(crate) mod bgl;
22pub mod global;
23mod life;
24pub mod queue;
25pub mod resource;
26#[cfg(any(feature = "trace", feature = "replay"))]
27pub mod trace;
28pub use {life::WaitIdleError, resource::Device};
29
30pub const SHADER_STAGE_COUNT: usize = hal::MAX_CONCURRENT_SHADER_STAGES;
31pub(crate) const ZERO_BUFFER_SIZE: BufferAddress = 512 << 10;
34
35const CLEANUP_WAIT_MS: u32 = 60000;
38
39pub(crate) const ENTRYPOINT_FAILURE_ERROR: &str = "The given EntryPoint is Invalid";
40
41pub type DeviceDescriptor<'a> = wgt::DeviceDescriptor<Label<'a>>;
42
/// Direction of host (CPU) access requested when mapping a buffer.
#[repr(C)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum HostMap {
    /// The host will read the mapped memory.
    Read,
    /// The host will write the mapped memory.
    Write,
}
50
51#[derive(Clone, Debug, Hash, PartialEq)]
52#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
53pub(crate) struct AttachmentData<T> {
54 pub colors: ArrayVec<Option<T>, { hal::MAX_COLOR_ATTACHMENTS }>,
55 pub resolves: ArrayVec<T, { hal::MAX_COLOR_ATTACHMENTS }>,
56 pub depth_stencil: Option<T>,
57}
58impl<T: PartialEq> Eq for AttachmentData<T> {}
59
60#[derive(Clone, Debug, Hash, PartialEq)]
61#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
62pub(crate) struct RenderPassContext {
63 pub attachments: AttachmentData<TextureFormat>,
64 pub sample_count: u32,
65 pub multiview: Option<NonZeroU32>,
66}
67#[derive(Clone, Debug, Error)]
68#[non_exhaustive]
69pub enum RenderPassCompatibilityError {
70 #[error(
71 "Incompatible color attachments at indices {indices:?}: the RenderPass uses textures with formats {expected:?} but the {res} uses attachments with formats {actual:?}",
72 )]
73 IncompatibleColorAttachment {
74 indices: Vec<usize>,
75 expected: Vec<Option<TextureFormat>>,
76 actual: Vec<Option<TextureFormat>>,
77 res: ResourceErrorIdent,
78 },
79 #[error(
80 "Incompatible depth-stencil attachment format: the RenderPass uses a texture with format {expected:?} but the {res} uses an attachment with format {actual:?}",
81 )]
82 IncompatibleDepthStencilAttachment {
83 expected: Option<TextureFormat>,
84 actual: Option<TextureFormat>,
85 res: ResourceErrorIdent,
86 },
87 #[error(
88 "Incompatible sample count: the RenderPass uses textures with sample count {expected:?} but the {res} uses attachments with format {actual:?}",
89 )]
90 IncompatibleSampleCount {
91 expected: u32,
92 actual: u32,
93 res: ResourceErrorIdent,
94 },
95 #[error("Incompatible multiview setting: the RenderPass uses setting {expected:?} but the {res} uses setting {actual:?}")]
96 IncompatibleMultiview {
97 expected: Option<NonZeroU32>,
98 actual: Option<NonZeroU32>,
99 res: ResourceErrorIdent,
100 },
101}
102
103impl RenderPassContext {
104 pub(crate) fn check_compatible<T: Labeled>(
106 &self,
107 other: &Self,
108 res: &T,
109 ) -> Result<(), RenderPassCompatibilityError> {
110 if self.attachments.colors != other.attachments.colors {
111 let indices = self
112 .attachments
113 .colors
114 .iter()
115 .zip(&other.attachments.colors)
116 .enumerate()
117 .filter_map(|(idx, (left, right))| (left != right).then_some(idx))
118 .collect();
119 return Err(RenderPassCompatibilityError::IncompatibleColorAttachment {
120 indices,
121 expected: self.attachments.colors.iter().cloned().collect(),
122 actual: other.attachments.colors.iter().cloned().collect(),
123 res: res.error_ident(),
124 });
125 }
126 if self.attachments.depth_stencil != other.attachments.depth_stencil {
127 return Err(
128 RenderPassCompatibilityError::IncompatibleDepthStencilAttachment {
129 expected: self.attachments.depth_stencil,
130 actual: other.attachments.depth_stencil,
131 res: res.error_ident(),
132 },
133 );
134 }
135 if self.sample_count != other.sample_count {
136 return Err(RenderPassCompatibilityError::IncompatibleSampleCount {
137 expected: self.sample_count,
138 actual: other.sample_count,
139 res: res.error_ident(),
140 });
141 }
142 if self.multiview != other.multiview {
143 return Err(RenderPassCompatibilityError::IncompatibleMultiview {
144 expected: self.multiview,
145 actual: other.multiview,
146 res: res.error_ident(),
147 });
148 }
149 Ok(())
150 }
151}
152
153pub type BufferMapPendingClosure = (BufferMapOperation, BufferAccessResult);
154
155#[derive(Default)]
156pub struct UserClosures {
157 pub mappings: Vec<BufferMapPendingClosure>,
158 pub submissions: SmallVec<[queue::SubmittedWorkDoneClosure; 1]>,
159 pub device_lost_invocations: SmallVec<[DeviceLostInvocation; 1]>,
160}
161
162impl UserClosures {
163 fn extend(&mut self, other: Self) {
164 self.mappings.extend(other.mappings);
165 self.submissions.extend(other.submissions);
166 self.device_lost_invocations
167 .extend(other.device_lost_invocations);
168 }
169
170 fn fire(self) {
171 for (mut operation, status) in self.mappings {
177 if let Some(callback) = operation.callback.take() {
178 callback.call(status);
179 }
180 }
181 for closure in self.submissions {
182 closure.call();
183 }
184 for invocation in self.device_lost_invocations {
185 invocation
186 .closure
187 .call(invocation.reason, invocation.message);
188 }
189 }
190}
191
192#[cfg(send_sync)]
193pub type DeviceLostCallback = Box<dyn Fn(DeviceLostReason, String) + Send + 'static>;
194#[cfg(not(send_sync))]
195pub type DeviceLostCallback = Box<dyn Fn(DeviceLostReason, String) + 'static>;
196
197pub struct DeviceLostClosureRust {
198 pub callback: DeviceLostCallback,
199 consumed: bool,
200}
201
202impl Drop for DeviceLostClosureRust {
203 fn drop(&mut self) {
204 if !self.consumed {
205 panic!("DeviceLostClosureRust must be consumed before it is dropped.");
206 }
207 }
208}
209
210#[repr(C)]
211pub struct DeviceLostClosureC {
212 pub callback: unsafe extern "C" fn(user_data: *mut u8, reason: u8, message: *const c_char),
213 pub user_data: *mut u8,
214 consumed: bool,
215}
216
217#[cfg(send_sync)]
218unsafe impl Send for DeviceLostClosureC {}
219
220impl Drop for DeviceLostClosureC {
221 fn drop(&mut self) {
222 if !self.consumed {
223 panic!("DeviceLostClosureC must be consumed before it is dropped.");
224 }
225 }
226}
227
228pub struct DeviceLostClosure {
229 inner: DeviceLostClosureInner,
232}
233
234pub struct DeviceLostInvocation {
235 closure: DeviceLostClosure,
236 reason: DeviceLostReason,
237 message: String,
238}
239
240enum DeviceLostClosureInner {
241 Rust { inner: DeviceLostClosureRust },
242 C { inner: DeviceLostClosureC },
243}
244
245impl DeviceLostClosure {
246 pub fn from_rust(callback: DeviceLostCallback) -> Self {
247 let inner = DeviceLostClosureRust {
248 callback,
249 consumed: false,
250 };
251 Self {
252 inner: DeviceLostClosureInner::Rust { inner },
253 }
254 }
255
256 pub unsafe fn from_c(mut closure: DeviceLostClosureC) -> Self {
264 let inner = DeviceLostClosureC {
267 callback: closure.callback,
268 user_data: closure.user_data,
269 consumed: false,
270 };
271
272 closure.consumed = true;
274
275 Self {
276 inner: DeviceLostClosureInner::C { inner },
277 }
278 }
279
280 pub(crate) fn call(self, reason: DeviceLostReason, message: String) {
281 match self.inner {
282 DeviceLostClosureInner::Rust { mut inner } => {
283 inner.consumed = true;
284
285 (inner.callback)(reason, message)
286 }
287 DeviceLostClosureInner::C { mut inner } => unsafe {
289 inner.consumed = true;
290
291 let message = std::ffi::CString::new(message).unwrap();
294 (inner.callback)(inner.user_data, reason as u8, message.as_ptr())
295 },
296 }
297 }
298}
299
300fn map_buffer(
301 raw: &dyn hal::DynDevice,
302 buffer: &Buffer,
303 offset: BufferAddress,
304 size: BufferAddress,
305 kind: HostMap,
306 snatch_guard: &SnatchGuard,
307) -> Result<hal::BufferMapping, BufferAccessError> {
308 let raw_buffer = buffer.try_raw(snatch_guard)?;
309 let mapping = unsafe {
310 raw.map_buffer(raw_buffer, offset..offset + size)
311 .map_err(|e| buffer.device.handle_hal_error(e))?
312 };
313
314 if !mapping.is_coherent && kind == HostMap::Read {
315 #[allow(clippy::single_range_in_vec_init)]
316 unsafe {
317 raw.invalidate_mapped_ranges(raw_buffer, &[offset..offset + size]);
318 }
319 }
320
321 assert_eq!(offset % wgt::COPY_BUFFER_ALIGNMENT, 0);
322 assert_eq!(size % wgt::COPY_BUFFER_ALIGNMENT, 0);
323 let mapped = unsafe { std::slice::from_raw_parts_mut(mapping.ptr.as_ptr(), size as usize) };
338
339 if !mapping.is_coherent
341 && kind == HostMap::Read
342 && !buffer.usage.contains(wgt::BufferUsages::MAP_WRITE)
343 {
344 for uninitialized in buffer
345 .initialization_status
346 .write()
347 .uninitialized(offset..(size + offset))
348 {
349 let fill_range =
352 (uninitialized.start - offset) as usize..(uninitialized.end - offset) as usize;
353 mapped[fill_range].fill(0);
354 }
355 } else {
356 for uninitialized in buffer
357 .initialization_status
358 .write()
359 .drain(offset..(size + offset))
360 {
361 let fill_range =
364 (uninitialized.start - offset) as usize..(uninitialized.end - offset) as usize;
365 mapped[fill_range].fill(0);
366
367 if !mapping.is_coherent
369 && kind == HostMap::Read
370 && buffer.usage.contains(wgt::BufferUsages::MAP_WRITE)
371 {
372 unsafe { raw.flush_mapped_ranges(raw_buffer, &[uninitialized]) };
373 }
374 }
375 }
376
377 Ok(mapping)
378}
379
380#[derive(Clone, Debug)]
381#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
382pub struct DeviceMismatch {
383 pub(super) res: ResourceErrorIdent,
384 pub(super) res_device: ResourceErrorIdent,
385 pub(super) target: Option<ResourceErrorIdent>,
386 pub(super) target_device: ResourceErrorIdent,
387}
388
389impl std::fmt::Display for DeviceMismatch {
390 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
391 write!(
392 f,
393 "{} of {} doesn't match {}",
394 self.res_device, self.res, self.target_device
395 )?;
396 if let Some(target) = self.target.as_ref() {
397 write!(f, " of {target}")?;
398 }
399 Ok(())
400 }
401}
402
403impl std::error::Error for DeviceMismatch {}
404
405#[derive(Clone, Debug, Error)]
406#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
407#[non_exhaustive]
408pub enum DeviceError {
409 #[error("{0} is invalid.")]
410 Invalid(ResourceErrorIdent),
411 #[error("Parent device is lost")]
412 Lost,
413 #[error("Not enough memory left.")]
414 OutOfMemory,
415 #[error("Creation of a resource failed for a reason other than running out of memory.")]
416 ResourceCreationFailed,
417 #[error(transparent)]
418 DeviceMismatch(#[from] Box<DeviceMismatch>),
419}
420
421impl DeviceError {
422 pub fn from_hal(error: hal::DeviceError) -> Self {
426 match error {
427 hal::DeviceError::Lost => Self::Lost,
428 hal::DeviceError::OutOfMemory => Self::OutOfMemory,
429 hal::DeviceError::ResourceCreationFailed => Self::ResourceCreationFailed,
430 hal::DeviceError::Unexpected => Self::Lost,
431 }
432 }
433}
434
435#[derive(Clone, Debug, Error)]
436#[error("Features {0:?} are required but not enabled on the device")]
437pub struct MissingFeatures(pub wgt::Features);
438
439#[derive(Clone, Debug, Error)]
440#[error(
441 "Downlevel flags {0:?} are required but not supported on the device.\n{}",
442 DOWNLEVEL_ERROR_MESSAGE
443)]
444pub struct MissingDownlevelFlags(pub wgt::DownlevelFlags);
445
446#[derive(Clone, Debug)]
447#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
448pub struct ImplicitPipelineContext {
449 pub root_id: PipelineLayoutId,
450 pub group_ids: ArrayVec<BindGroupLayoutId, { hal::MAX_BIND_GROUPS }>,
451}
452
453pub struct ImplicitPipelineIds<'a> {
454 pub root_id: PipelineLayoutId,
455 pub group_ids: &'a [BindGroupLayoutId],
456}
457
458impl ImplicitPipelineIds<'_> {
459 fn prepare(self, hub: &Hub) -> ImplicitPipelineContext {
460 ImplicitPipelineContext {
461 root_id: hub.pipeline_layouts.prepare(Some(self.root_id)).id(),
462 group_ids: self
463 .group_ids
464 .iter()
465 .map(|id_in| hub.bind_group_layouts.prepare(Some(*id_in)).id())
466 .collect(),
467 }
468 }
469}
470
471pub fn create_validator(
473 features: wgt::Features,
474 downlevel: wgt::DownlevelFlags,
475 flags: naga::valid::ValidationFlags,
476) -> naga::valid::Validator {
477 use naga::valid::Capabilities as Caps;
478 let mut caps = Caps::empty();
479 caps.set(
480 Caps::PUSH_CONSTANT,
481 features.contains(wgt::Features::PUSH_CONSTANTS),
482 );
483 caps.set(Caps::FLOAT64, features.contains(wgt::Features::SHADER_F64));
484 caps.set(
485 Caps::PRIMITIVE_INDEX,
486 features.contains(wgt::Features::SHADER_PRIMITIVE_INDEX),
487 );
488 caps.set(
489 Caps::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
490 features
491 .contains(wgt::Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
492 );
493 caps.set(
494 Caps::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING,
495 features
496 .contains(wgt::Features::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING),
497 );
498 caps.set(
500 Caps::SAMPLER_NON_UNIFORM_INDEXING,
501 features
502 .contains(wgt::Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
503 );
504 caps.set(
505 Caps::STORAGE_TEXTURE_16BIT_NORM_FORMATS,
506 features.contains(wgt::Features::TEXTURE_FORMAT_16BIT_NORM),
507 );
508 caps.set(Caps::MULTIVIEW, features.contains(wgt::Features::MULTIVIEW));
509 caps.set(
510 Caps::EARLY_DEPTH_TEST,
511 features.contains(wgt::Features::SHADER_EARLY_DEPTH_TEST),
512 );
513 caps.set(
514 Caps::SHADER_INT64,
515 features.contains(wgt::Features::SHADER_INT64),
516 );
517 caps.set(
518 Caps::SHADER_INT64_ATOMIC_MIN_MAX,
519 features.intersects(
520 wgt::Features::SHADER_INT64_ATOMIC_MIN_MAX | wgt::Features::SHADER_INT64_ATOMIC_ALL_OPS,
521 ),
522 );
523 caps.set(
524 Caps::SHADER_INT64_ATOMIC_ALL_OPS,
525 features.contains(wgt::Features::SHADER_INT64_ATOMIC_ALL_OPS),
526 );
527 caps.set(
528 Caps::MULTISAMPLED_SHADING,
529 downlevel.contains(wgt::DownlevelFlags::MULTISAMPLED_SHADING),
530 );
531 caps.set(
532 Caps::DUAL_SOURCE_BLENDING,
533 features.contains(wgt::Features::DUAL_SOURCE_BLENDING),
534 );
535 caps.set(
536 Caps::CUBE_ARRAY_TEXTURES,
537 downlevel.contains(wgt::DownlevelFlags::CUBE_ARRAY_TEXTURES),
538 );
539 caps.set(
540 Caps::SUBGROUP,
541 features.intersects(wgt::Features::SUBGROUP | wgt::Features::SUBGROUP_VERTEX),
542 );
543 caps.set(
544 Caps::SUBGROUP_BARRIER,
545 features.intersects(wgt::Features::SUBGROUP_BARRIER),
546 );
547 caps.set(
548 Caps::SUBGROUP_VERTEX_STAGE,
549 features.contains(wgt::Features::SUBGROUP_VERTEX),
550 );
551
552 naga::valid::Validator::new(flags, caps)
553}