// wgpu_core/device/mod.rs
1use crate::{
2    binding_model,
3    hub::Hub,
4    id::{BindGroupLayoutId, PipelineLayoutId},
5    resource::{
6        Buffer, BufferAccessError, BufferAccessResult, BufferMapOperation, Labeled,
7        ResourceErrorIdent,
8    },
9    snatch::SnatchGuard,
10    Label, DOWNLEVEL_ERROR_MESSAGE,
11};
12
13use arrayvec::ArrayVec;
14use smallvec::SmallVec;
15use std::os::raw::c_char;
16use thiserror::Error;
17use wgt::{BufferAddress, DeviceLostReason, TextureFormat};
18
19use std::num::NonZeroU32;
20
// Submodules of the device implementation.
pub(crate) mod bgl;
pub mod global;
mod life;
pub mod queue;
pub mod resource;
// Command tracing/replay support is only compiled in when requested.
#[cfg(any(feature = "trace", feature = "replay"))]
pub mod trace;
// Re-export the items most callers need directly from this module.
pub use {life::WaitIdleError, resource::Device};
29
/// Number of shader stages that may be bound concurrently, as defined by the hal.
pub const SHADER_STAGE_COUNT: usize = hal::MAX_CONCURRENT_SHADER_STAGES;
// Should be large enough for the largest possible texture row. This
// value is enough for a 16k texture with float4 format.
pub(crate) const ZERO_BUFFER_SIZE: BufferAddress = 512 << 10;

// If a submission is not completed within this time, we go off into UB land.
// See https://github.com/gfx-rs/wgpu/issues/4589. 60s to reduce the chances of this.
const CLEANUP_WAIT_MS: u32 = 60000;

/// Error message used when a pipeline references an entry point that cannot be resolved.
pub(crate) const ENTRYPOINT_FAILURE_ERROR: &str = "The given EntryPoint is Invalid";

/// Device descriptor specialized to wgpu-core's label type.
pub type DeviceDescriptor<'a> = wgt::DeviceDescriptor<Label<'a>>;
42
/// Direction of a host-side buffer mapping.
#[repr(C)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum HostMap {
    /// The host will read through the mapping.
    Read,
    /// The host will write through the mapping.
    Write,
}
50
/// Per-attachment data for a render pass, generic over the attachment payload
/// (e.g. `TextureFormat` in [`RenderPassContext`]).
#[derive(Clone, Debug, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub(crate) struct AttachmentData<T> {
    // `None` entries represent unused color attachment slots.
    pub colors: ArrayVec<Option<T>, { hal::MAX_COLOR_ATTACHMENTS }>,
    pub resolves: ArrayVec<T, { hal::MAX_COLOR_ATTACHMENTS }>,
    pub depth_stencil: Option<T>,
}
// NOTE(review): `Eq` is asserted with only a `PartialEq` bound on `T`; this is
// sound only if `T`'s `PartialEq` is a total equivalence — confirm for all
// instantiations (e.g. `TextureFormat`).
impl<T: PartialEq> Eq for AttachmentData<T> {}
59
/// The set of properties a render pipeline and a render pass must agree on to
/// be used together: attachment formats, sample count, and multiview setting.
#[derive(Clone, Debug, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub(crate) struct RenderPassContext {
    pub attachments: AttachmentData<TextureFormat>,
    pub sample_count: u32,
    // `None` means multiview is disabled; `Some(n)` is the view count.
    pub multiview: Option<NonZeroU32>,
}
67#[derive(Clone, Debug, Error)]
68#[non_exhaustive]
69pub enum RenderPassCompatibilityError {
70    #[error(
71        "Incompatible color attachments at indices {indices:?}: the RenderPass uses textures with formats {expected:?} but the {res} uses attachments with formats {actual:?}",
72    )]
73    IncompatibleColorAttachment {
74        indices: Vec<usize>,
75        expected: Vec<Option<TextureFormat>>,
76        actual: Vec<Option<TextureFormat>>,
77        res: ResourceErrorIdent,
78    },
79    #[error(
80        "Incompatible depth-stencil attachment format: the RenderPass uses a texture with format {expected:?} but the {res} uses an attachment with format {actual:?}",
81    )]
82    IncompatibleDepthStencilAttachment {
83        expected: Option<TextureFormat>,
84        actual: Option<TextureFormat>,
85        res: ResourceErrorIdent,
86    },
87    #[error(
88        "Incompatible sample count: the RenderPass uses textures with sample count {expected:?} but the {res} uses attachments with format {actual:?}",
89    )]
90    IncompatibleSampleCount {
91        expected: u32,
92        actual: u32,
93        res: ResourceErrorIdent,
94    },
95    #[error("Incompatible multiview setting: the RenderPass uses setting {expected:?} but the {res} uses setting {actual:?}")]
96    IncompatibleMultiview {
97        expected: Option<NonZeroU32>,
98        actual: Option<NonZeroU32>,
99        res: ResourceErrorIdent,
100    },
101}
102
103impl RenderPassContext {
104    // Assumes the renderpass only contains one subpass
105    pub(crate) fn check_compatible<T: Labeled>(
106        &self,
107        other: &Self,
108        res: &T,
109    ) -> Result<(), RenderPassCompatibilityError> {
110        if self.attachments.colors != other.attachments.colors {
111            let indices = self
112                .attachments
113                .colors
114                .iter()
115                .zip(&other.attachments.colors)
116                .enumerate()
117                .filter_map(|(idx, (left, right))| (left != right).then_some(idx))
118                .collect();
119            return Err(RenderPassCompatibilityError::IncompatibleColorAttachment {
120                indices,
121                expected: self.attachments.colors.iter().cloned().collect(),
122                actual: other.attachments.colors.iter().cloned().collect(),
123                res: res.error_ident(),
124            });
125        }
126        if self.attachments.depth_stencil != other.attachments.depth_stencil {
127            return Err(
128                RenderPassCompatibilityError::IncompatibleDepthStencilAttachment {
129                    expected: self.attachments.depth_stencil,
130                    actual: other.attachments.depth_stencil,
131                    res: res.error_ident(),
132                },
133            );
134        }
135        if self.sample_count != other.sample_count {
136            return Err(RenderPassCompatibilityError::IncompatibleSampleCount {
137                expected: self.sample_count,
138                actual: other.sample_count,
139                res: res.error_ident(),
140            });
141        }
142        if self.multiview != other.multiview {
143            return Err(RenderPassCompatibilityError::IncompatibleMultiview {
144                expected: self.multiview,
145                actual: other.multiview,
146                res: res.error_ident(),
147            });
148        }
149        Ok(())
150    }
151}
152
/// A buffer-map callback paired with the result it should be invoked with.
pub type BufferMapPendingClosure = (BufferMapOperation, BufferAccessResult);

/// User callbacks accumulated during device maintenance, to be fired after
/// all internal locks are released.
#[derive(Default)]
pub struct UserClosures {
    pub mappings: Vec<BufferMapPendingClosure>,
    pub submissions: SmallVec<[queue::SubmittedWorkDoneClosure; 1]>,
    pub device_lost_invocations: SmallVec<[DeviceLostInvocation; 1]>,
}
161
162impl UserClosures {
163    fn extend(&mut self, other: Self) {
164        self.mappings.extend(other.mappings);
165        self.submissions.extend(other.submissions);
166        self.device_lost_invocations
167            .extend(other.device_lost_invocations);
168    }
169
170    fn fire(self) {
171        // Note: this logic is specifically moved out of `handle_mapping()` in order to
172        // have nothing locked by the time we execute users callback code.
173
174        // Mappings _must_ be fired before submissions, as the spec requires all mapping callbacks that are registered before
175        // a on_submitted_work_done callback to be fired before the on_submitted_work_done callback.
176        for (mut operation, status) in self.mappings {
177            if let Some(callback) = operation.callback.take() {
178                callback.call(status);
179            }
180        }
181        for closure in self.submissions {
182            closure.call();
183        }
184        for invocation in self.device_lost_invocations {
185            invocation
186                .closure
187                .call(invocation.reason, invocation.message);
188        }
189    }
190}
191
/// Rust-side device-lost callback; `Send` is required only when wgpu-core is
/// built with cross-thread support.
#[cfg(send_sync)]
pub type DeviceLostCallback = Box<dyn Fn(DeviceLostReason, String) + Send + 'static>;
#[cfg(not(send_sync))]
pub type DeviceLostCallback = Box<dyn Fn(DeviceLostReason, String) + 'static>;

/// Wrapper around a Rust device-lost callback that tracks whether it has been
/// invoked, so dropping an un-fired closure can be detected (see `Drop` impl).
pub struct DeviceLostClosureRust {
    pub callback: DeviceLostCallback,
    // Set to true once the callback has been called (or ownership transferred).
    consumed: bool,
}
201
202impl Drop for DeviceLostClosureRust {
203    fn drop(&mut self) {
204        if !self.consumed {
205            panic!("DeviceLostClosureRust must be consumed before it is dropped.");
206        }
207    }
208}
209
/// C-ABI device-lost callback plus its user-data pointer, with the same
/// consumed-flag bookkeeping as [`DeviceLostClosureRust`].
#[repr(C)]
pub struct DeviceLostClosureC {
    // Called with (user_data, reason as u8, null-terminated message).
    pub callback: unsafe extern "C" fn(user_data: *mut u8, reason: u8, message: *const c_char),
    pub user_data: *mut u8,
    // Set to true once the callback has been called (or ownership transferred).
    consumed: bool,
}

// SAFETY(review): this asserts the raw pointers may cross threads; that relies
// on the `from_c` contract (valid, 'static data) — confirm callers uphold it.
#[cfg(send_sync)]
unsafe impl Send for DeviceLostClosureC {}

impl Drop for DeviceLostClosureC {
    // Dropping an un-consumed closure indicates a bookkeeping bug.
    fn drop(&mut self) {
        if !self.consumed {
            panic!("DeviceLostClosureC must be consumed before it is dropped.");
        }
    }
}
227
/// A device-lost callback of either Rust or C origin.
pub struct DeviceLostClosure {
    // We wrap this so creating the enum in the C variant can be unsafe,
    // allowing our call function to be safe.
    inner: DeviceLostClosureInner,
}

/// A device-lost callback together with the arguments it should be fired with.
pub struct DeviceLostInvocation {
    closure: DeviceLostClosure,
    reason: DeviceLostReason,
    message: String,
}

// Private discriminated storage for the two callback flavors.
enum DeviceLostClosureInner {
    Rust { inner: DeviceLostClosureRust },
    C { inner: DeviceLostClosureC },
}
244
impl DeviceLostClosure {
    /// Wrap a Rust callback. The closure starts un-consumed; it must be fired
    /// via [`Self::call`] before being dropped.
    pub fn from_rust(callback: DeviceLostCallback) -> Self {
        let inner = DeviceLostClosureRust {
            callback,
            consumed: false,
        };
        Self {
            inner: DeviceLostClosureInner::Rust { inner },
        }
    }

    /// # Safety
    ///
    /// - The callback pointer must be valid to call with the provided `user_data`
    ///   pointer.
    ///
    /// - Both pointers must point to `'static` data, as the callback may happen at
    ///   an unspecified time.
    pub unsafe fn from_c(mut closure: DeviceLostClosureC) -> Self {
        // Build an inner with the values from closure, ensuring that
        // inner.consumed is false.
        let inner = DeviceLostClosureC {
            callback: closure.callback,
            user_data: closure.user_data,
            consumed: false,
        };

        // Mark the original closure as consumed, so we can safely drop it
        // (its Drop impl panics on un-consumed closures).
        closure.consumed = true;

        Self {
            inner: DeviceLostClosureInner::C { inner },
        }
    }

    /// Fire the callback with `reason` and `message`, consuming `self`.
    pub(crate) fn call(self, reason: DeviceLostReason, message: String) {
        match self.inner {
            DeviceLostClosureInner::Rust { mut inner } => {
                // Mark consumed before calling so Drop doesn't panic afterwards.
                inner.consumed = true;

                (inner.callback)(reason, message)
            }
            // SAFETY: the contract of the call to from_c says that this unsafe is sound.
            DeviceLostClosureInner::C { mut inner } => unsafe {
                inner.consumed = true;

                // Ensure message is structured as a null-terminated C string. It only
                // needs to live as long as the callback invocation.
                let message = std::ffi::CString::new(message).unwrap();
                (inner.callback)(inner.user_data, reason as u8, message.as_ptr())
            },
        }
    }
}
299
/// Map `size` bytes of `buffer` starting at `offset` into host memory, and
/// zero any parts of that range the GPU has not yet initialized.
///
/// `offset` and `size` must be multiples of `wgt::COPY_BUFFER_ALIGNMENT`
/// (asserted below). Returns the raw hal mapping on success; fails if the
/// buffer has been destroyed or the hal mapping itself fails.
fn map_buffer(
    raw: &dyn hal::DynDevice,
    buffer: &Buffer,
    offset: BufferAddress,
    size: BufferAddress,
    kind: HostMap,
    snatch_guard: &SnatchGuard,
) -> Result<hal::BufferMapping, BufferAccessError> {
    let raw_buffer = buffer.try_raw(snatch_guard)?;
    let mapping = unsafe {
        raw.map_buffer(raw_buffer, offset..offset + size)
            .map_err(|e| buffer.device.handle_hal_error(e))?
    };

    // On non-coherent memory, make GPU writes visible to the host before the
    // caller reads through the mapping.
    if !mapping.is_coherent && kind == HostMap::Read {
        #[allow(clippy::single_range_in_vec_init)]
        unsafe {
            raw.invalidate_mapped_ranges(raw_buffer, &[offset..offset + size]);
        }
    }

    assert_eq!(offset % wgt::COPY_BUFFER_ALIGNMENT, 0);
    assert_eq!(size % wgt::COPY_BUFFER_ALIGNMENT, 0);
    // Zero out uninitialized parts of the mapping. (Spec dictates all resources
    // behave as if they were initialized with zero)
    //
    // If this is a read mapping, ideally we would use a `clear_buffer` command
    // before reading the data from GPU (i.e. `invalidate_range`). However, this
    // would require us to kick off and wait for a command buffer or piggy back
    // on an existing one (the later is likely the only worthwhile option). As
    // reading uninitialized memory isn't a particular important path to
    // support, we instead just initialize the memory here and make sure it is
    // GPU visible, so this happens at max only once for every buffer region.
    //
    // If this is a write mapping zeroing out the memory here is the only
    // reasonable way as all data is pushed to GPU anyways.

    // SAFETY-relevant: the mapped slice is valid for exactly `size` bytes
    // starting at the (already offset) mapping pointer.
    let mapped = unsafe { std::slice::from_raw_parts_mut(mapping.ptr.as_ptr(), size as usize) };

    // We can't call flush_mapped_ranges in this case, so we can't drain the uninitialized ranges either
    if !mapping.is_coherent
        && kind == HostMap::Read
        && !buffer.usage.contains(wgt::BufferUsages::MAP_WRITE)
    {
        // Zero the uninitialized ranges but do NOT mark them initialized
        // (`uninitialized` only peeks; `drain` below would consume them).
        for uninitialized in buffer
            .initialization_status
            .write()
            .uninitialized(offset..(size + offset))
        {
            // The mapping's pointer is already offset, however we track the
            // uninitialized range relative to the buffer's start.
            let fill_range =
                (uninitialized.start - offset) as usize..(uninitialized.end - offset) as usize;
            mapped[fill_range].fill(0);
        }
    } else {
        // Zero the uninitialized ranges and mark them as initialized.
        for uninitialized in buffer
            .initialization_status
            .write()
            .drain(offset..(size + offset))
        {
            // The mapping's pointer is already offset, however we track the
            // uninitialized range relative to the buffer's start.
            let fill_range =
                (uninitialized.start - offset) as usize..(uninitialized.end - offset) as usize;
            mapped[fill_range].fill(0);

            // NOTE: This is only possible when MAPPABLE_PRIMARY_BUFFERS is enabled.
            // Make the zeroes visible to the GPU on non-coherent memory.
            if !mapping.is_coherent
                && kind == HostMap::Read
                && buffer.usage.contains(wgt::BufferUsages::MAP_WRITE)
            {
                unsafe { raw.flush_mapped_ranges(raw_buffer, &[uninitialized]) };
            }
        }
    }

    Ok(mapping)
}
379
/// Details for an error where a resource was used with a device other than
/// the one it was created from.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DeviceMismatch {
    pub(super) res: ResourceErrorIdent,
    pub(super) res_device: ResourceErrorIdent,
    // The resource the mismatching one was used with, if known.
    pub(super) target: Option<ResourceErrorIdent>,
    pub(super) target_device: ResourceErrorIdent,
}
388
389impl std::fmt::Display for DeviceMismatch {
390    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
391        write!(
392            f,
393            "{} of {} doesn't match {}",
394            self.res_device, self.res, self.target_device
395        )?;
396        if let Some(target) = self.target.as_ref() {
397            write!(f, " of {target}")?;
398        }
399        Ok(())
400    }
401}
402
403impl std::error::Error for DeviceMismatch {}
404
/// Top-level device errors surfaced to users of wgpu-core.
#[derive(Clone, Debug, Error)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum DeviceError {
    /// The referenced resource is invalid (e.g. destroyed or failed creation).
    #[error("{0} is invalid.")]
    Invalid(ResourceErrorIdent),
    /// The device was lost (or an unexpected hal error was mapped here).
    #[error("Parent device is lost")]
    Lost,
    #[error("Not enough memory left.")]
    OutOfMemory,
    #[error("Creation of a resource failed for a reason other than running out of memory.")]
    ResourceCreationFailed,
    /// A resource was used with a device it does not belong to.
    #[error(transparent)]
    DeviceMismatch(#[from] Box<DeviceMismatch>),
}
420
421impl DeviceError {
422    /// Only use this function in contexts where there is no `Device`.
423    ///
424    /// Use [`Device::handle_hal_error`] otherwise.
425    pub fn from_hal(error: hal::DeviceError) -> Self {
426        match error {
427            hal::DeviceError::Lost => Self::Lost,
428            hal::DeviceError::OutOfMemory => Self::OutOfMemory,
429            hal::DeviceError::ResourceCreationFailed => Self::ResourceCreationFailed,
430            hal::DeviceError::Unexpected => Self::Lost,
431        }
432    }
433}
434
/// Error: an operation required device features that were not enabled.
#[derive(Clone, Debug, Error)]
#[error("Features {0:?} are required but not enabled on the device")]
pub struct MissingFeatures(pub wgt::Features);

/// Error: an operation required downlevel capabilities the adapter lacks.
#[derive(Clone, Debug, Error)]
#[error(
    "Downlevel flags {0:?} are required but not supported on the device.\n{}",
    DOWNLEVEL_ERROR_MESSAGE
)]
pub struct MissingDownlevelFlags(pub wgt::DownlevelFlags);
445
/// Registered ids for a pipeline's implicitly created layout objects.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ImplicitPipelineContext {
    pub root_id: PipelineLayoutId,
    pub group_ids: ArrayVec<BindGroupLayoutId, { hal::MAX_BIND_GROUPS }>,
}

/// Caller-supplied ids to use for a pipeline's implicit layout objects.
pub struct ImplicitPipelineIds<'a> {
    pub root_id: PipelineLayoutId,
    pub group_ids: &'a [BindGroupLayoutId],
}
457
458impl ImplicitPipelineIds<'_> {
459    fn prepare(self, hub: &Hub) -> ImplicitPipelineContext {
460        ImplicitPipelineContext {
461            root_id: hub.pipeline_layouts.prepare(Some(self.root_id)).id(),
462            group_ids: self
463                .group_ids
464                .iter()
465                .map(|id_in| hub.bind_group_layouts.prepare(Some(*id_in)).id())
466                .collect(),
467        }
468    }
469}
470
471/// Create a validator with the given validation flags.
472pub fn create_validator(
473    features: wgt::Features,
474    downlevel: wgt::DownlevelFlags,
475    flags: naga::valid::ValidationFlags,
476) -> naga::valid::Validator {
477    use naga::valid::Capabilities as Caps;
478    let mut caps = Caps::empty();
479    caps.set(
480        Caps::PUSH_CONSTANT,
481        features.contains(wgt::Features::PUSH_CONSTANTS),
482    );
483    caps.set(Caps::FLOAT64, features.contains(wgt::Features::SHADER_F64));
484    caps.set(
485        Caps::PRIMITIVE_INDEX,
486        features.contains(wgt::Features::SHADER_PRIMITIVE_INDEX),
487    );
488    caps.set(
489        Caps::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
490        features
491            .contains(wgt::Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
492    );
493    caps.set(
494        Caps::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING,
495        features
496            .contains(wgt::Features::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING),
497    );
498    // TODO: This needs a proper wgpu feature
499    caps.set(
500        Caps::SAMPLER_NON_UNIFORM_INDEXING,
501        features
502            .contains(wgt::Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
503    );
504    caps.set(
505        Caps::STORAGE_TEXTURE_16BIT_NORM_FORMATS,
506        features.contains(wgt::Features::TEXTURE_FORMAT_16BIT_NORM),
507    );
508    caps.set(Caps::MULTIVIEW, features.contains(wgt::Features::MULTIVIEW));
509    caps.set(
510        Caps::EARLY_DEPTH_TEST,
511        features.contains(wgt::Features::SHADER_EARLY_DEPTH_TEST),
512    );
513    caps.set(
514        Caps::SHADER_INT64,
515        features.contains(wgt::Features::SHADER_INT64),
516    );
517    caps.set(
518        Caps::SHADER_INT64_ATOMIC_MIN_MAX,
519        features.intersects(
520            wgt::Features::SHADER_INT64_ATOMIC_MIN_MAX | wgt::Features::SHADER_INT64_ATOMIC_ALL_OPS,
521        ),
522    );
523    caps.set(
524        Caps::SHADER_INT64_ATOMIC_ALL_OPS,
525        features.contains(wgt::Features::SHADER_INT64_ATOMIC_ALL_OPS),
526    );
527    caps.set(
528        Caps::MULTISAMPLED_SHADING,
529        downlevel.contains(wgt::DownlevelFlags::MULTISAMPLED_SHADING),
530    );
531    caps.set(
532        Caps::DUAL_SOURCE_BLENDING,
533        features.contains(wgt::Features::DUAL_SOURCE_BLENDING),
534    );
535    caps.set(
536        Caps::CUBE_ARRAY_TEXTURES,
537        downlevel.contains(wgt::DownlevelFlags::CUBE_ARRAY_TEXTURES),
538    );
539    caps.set(
540        Caps::SUBGROUP,
541        features.intersects(wgt::Features::SUBGROUP | wgt::Features::SUBGROUP_VERTEX),
542    );
543    caps.set(
544        Caps::SUBGROUP_BARRIER,
545        features.intersects(wgt::Features::SUBGROUP_BARRIER),
546    );
547    caps.set(
548        Caps::SUBGROUP_VERTEX_STAGE,
549        features.contains(wgt::Features::SUBGROUP_VERTEX),
550    );
551
552    naga::valid::Validator::new(flags, caps)
553}