// wgpu_core/command/compute.rs

use crate::{
    binding_model::{
        BindError, BindGroup, LateMinBufferBindingSizeMismatch, PushConstantUploadError,
    },
    command::{
        bind::Binder,
        compute_command::ArcComputeCommand,
        end_pipeline_statistics_query,
        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
        validate_and_begin_pipeline_statistics_query, ArcPassTimestampWrites, BasePass,
        BindGroupStateChange, CommandBuffer, CommandEncoderError, CommandEncoderStatus, MapPassErr,
        PassErrorScope, PassTimestampWrites, QueryUseError, StateChange,
    },
    device::{Device, DeviceError, MissingDownlevelFlags, MissingFeatures},
    global::Global,
    hal_label, id,
    init_tracker::{BufferInitTrackerAction, MemoryInitKind},
    pipeline::ComputePipeline,
    resource::{
        self, Buffer, DestroyedResourceError, InvalidResourceError, Labeled,
        MissingBufferUsageError, ParentDevice,
    },
    snatch::SnatchGuard,
    track::{ResourceUsageCompatibilityError, Tracker, TrackerIndex, UsageScope},
    Label,
};

use thiserror::Error;
use wgt::{BufferAddress, DynamicOffset};

use std::sync::Arc;
use std::{fmt, mem::size_of, str};

use super::{bind::BinderError, memory_init::CommandBufferTextureMemoryActions};

pub struct ComputePass {
    /// All pass data & recorded commands are stored here.
    ///
    /// If this is `None`, the pass is in the 'ended' state and can no longer be used.
    /// Any attempt to record more commands will result in a validation error.
    base: Option<BasePass<ArcComputeCommand>>,

    /// Parent command buffer that this pass records commands into.
    ///
    /// If this is `None`, the pass is invalid and any operation on it will return an error.
    parent: Option<Arc<CommandBuffer>>,

    timestamp_writes: Option<ArcPassTimestampWrites>,

    // Resource binding dedupe state.
    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}

impl ComputePass {
    /// If the parent command buffer is invalid, the returned pass will be invalid.
    fn new(parent: Option<Arc<CommandBuffer>>, desc: ArcComputePassDescriptor) -> Self {
        let ArcComputePassDescriptor {
            label,
            timestamp_writes,
        } = desc;

        Self {
            base: Some(BasePass::new(label)),
            parent,
            timestamp_writes,

            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    #[inline]
    pub fn label(&self) -> Option<&str> {
        self.base.as_ref().and_then(|base| base.label.as_deref())
    }

    fn base_mut<'a>(
        &'a mut self,
        scope: PassErrorScope,
    ) -> Result<&'a mut BasePass<ArcComputeCommand>, ComputePassError> {
        self.base
            .as_mut()
            .ok_or(ComputePassErrorInner::PassEnded)
            .map_pass_err(scope)
    }
}

impl fmt::Debug for ComputePass {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.parent {
            Some(ref cmd_buf) => write!(f, "ComputePass {{ parent: {} }}", cmd_buf.error_ident()),
            None => write!(f, "ComputePass {{ parent: None }}"),
        }
    }
}
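
// A minimal end-to-end usage sketch (hedged; `global`, `encoder_id`,
// `pipeline_id`, and `bind_group_id` are assumed to have been created
// elsewhere): create the pass, record commands, then end it. Once ended (or if
// creation failed), any further recording returns a validation error.
//
//     let (mut pass, err) = global.command_encoder_create_compute_pass(
//         encoder_id,
//         &ComputePassDescriptor { label: None, timestamp_writes: None },
//     );
//     assert!(err.is_none());
//     global.compute_pass_set_pipeline(&mut pass, pipeline_id)?;
//     global.compute_pass_set_bind_group(&mut pass, 0, Some(bind_group_id), &[])?;
//     global.compute_pass_dispatch_workgroups(&mut pass, 64, 1, 1)?;
//     global.compute_pass_end(&mut pass)?;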

#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a> {
    pub label: Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<&'a PassTimestampWrites>,
}

struct ArcComputePassDescriptor<'a> {
    pub label: &'a Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<ArcPassTimestampWrites>,
}

#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline,
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less than or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}

/// Error encountered when performing a compute pass.
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    Encoder(#[from] CommandEncoderError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error("Bind group index {index} is greater than the device's requested `max_bind_groups` limit {max}")]
    BindGroupIndexOutOfRange { index: u32, max: u32 },
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error("Cannot pop debug group, because the number of pushed debug groups is zero")]
    InvalidPopDebugGroup,
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    PushConstants(#[from] PushConstantUploadError),
    #[error("Push constant offset must be aligned to 4 bytes")]
    PushConstantOffsetAlignment,
    #[error("Push constant size must be aligned to 4 bytes")]
    PushConstantSizeAlignment,
    #[error("Ran out of push constant space. Don't set 4GiB of push constants per ComputePass.")]
    PushConstantOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
}

/// Error encountered when performing a compute pass.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    pub(super) inner: ComputePassErrorInner,
}

impl<T, E> MapPassErr<T, ComputePassError> for Result<T, E>
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> Result<T, ComputePassError> {
        self.map_err(|inner| ComputePassError {
            scope,
            inner: inner.into(),
        })
    }
}
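
// Usage sketch: any error convertible into `ComputePassErrorInner` can be
// tagged with the scope in which it occurred, as the command loop below does:
//
//     dispatch(&mut state, groups).map_pass_err(PassErrorScope::Dispatch { indirect: false })?;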

struct State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder> {
    binder: Binder,
    pipeline: Option<Arc<ComputePipeline>>,
    scope: UsageScope<'scope>,
    debug_scope_depth: u32,

    snatch_guard: SnatchGuard<'snatch_guard>,

    device: &'cmd_buf Arc<Device>,

    raw_encoder: &'raw_encoder mut dyn hal::DynCommandEncoder,

    tracker: &'cmd_buf mut Tracker,
    buffer_memory_init_actions: &'cmd_buf mut Vec<BufferInitTrackerAction>,
    texture_memory_actions: &'cmd_buf mut CommandBufferTextureMemoryActions,

    temp_offsets: Vec<u32>,
    dynamic_offset_count: usize,
    string_offset: usize,
    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    push_constants: Vec<u32>,

    intermediate_trackers: Tracker,

    /// Immediate texture inits required because of prior discards. These need
    /// to be inserted before texture reads.
    pending_discard_init_fixups: SurfacesInDiscardState,
}

impl<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder>
    State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder>
{
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.binder.check_compatibility(pipeline.as_ref())?;
            self.binder.check_late_buffer_bindings()?;
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline)
        }
    }

    // `indirect_buffer` represents the indirect buffer that is also part of
    // the usage scope.
    fn flush_states(
        &mut self,
        indirect_buffer: Option<TrackerIndex>,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        for bind_group in self.binder.list_active() {
            unsafe { self.scope.merge_bind_group(&bind_group.used)? };
            // Note: stateless trackers are not merged: the lifetime reference
            // is held to the bind group itself.
        }

        for bind_group in self.binder.list_active() {
            unsafe {
                self.intermediate_trackers
                    .set_and_remove_from_usage_scope_sparse(&mut self.scope, &bind_group.used)
            }
        }

        // Add the state of the indirect buffer if it hasn't been hit before.
        unsafe {
            self.intermediate_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(&mut self.scope.buffers, indirect_buffer);
        }

        CommandBuffer::drain_barriers(
            self.raw_encoder,
            &mut self.intermediate_trackers,
            &self.snatch_guard,
        );
        Ok(())
    }
}
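
// Typical per-dispatch flow over this state, mirroring `dispatch` below:
// validate readiness, flush usage scopes into barriers, then encode.
//
//     state.is_ready()?;          // pipeline set and bind groups compatible?
//     state.flush_states(None)?;  // emit barriers for newly used resources
//     unsafe { state.raw_encoder.dispatch(groups) };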

// Running the compute pass.

impl Global {
    /// Creates a compute pass.
    ///
    /// If creation fails, an invalid pass is returned.
    /// Any operation on an invalid pass will return an error.
    ///
    /// If successful, puts the encoder into the [`CommandEncoderStatus::Locked`] state.
    pub fn command_encoder_create_compute_pass(
        &self,
        encoder_id: id::CommandEncoderId,
        desc: &ComputePassDescriptor<'_>,
    ) -> (ComputePass, Option<CommandEncoderError>) {
        let hub = &self.hub;

        let mut arc_desc = ArcComputePassDescriptor {
            label: &desc.label,
            timestamp_writes: None, // Handle this only once we've resolved the encoder.
        };

        let make_err = |e, arc_desc| (ComputePass::new(None, arc_desc), Some(e));

        let cmd_buf = hub.command_buffers.get(encoder_id.into_command_buffer_id());

        match cmd_buf
            .try_get()
            .map_err(|e| e.into())
            .and_then(|mut cmd_buf_data| cmd_buf_data.lock_encoder())
        {
            Ok(_) => {}
            Err(e) => return make_err(e, arc_desc),
        };

        arc_desc.timestamp_writes = if let Some(tw) = desc.timestamp_writes {
            let query_set = match hub.query_sets.get(tw.query_set).get() {
                Ok(query_set) => query_set,
                Err(e) => return make_err(e.into(), arc_desc),
            };

            Some(ArcPassTimestampWrites {
                query_set,
                beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                end_of_pass_write_index: tw.end_of_pass_write_index,
            })
        } else {
            None
        };

        (ComputePass::new(Some(cmd_buf), arc_desc), None)
    }
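
    // Hedged sketch of a descriptor with timestamp writes (`query_set_id` is
    // assumed to name a timestamp query set created elsewhere): one timestamp
    // is recorded when the pass begins and one when it ends.
    //
    //     let desc = ComputePassDescriptor {
    //         label: Some("timed pass".into()),
    //         timestamp_writes: Some(&PassTimestampWrites {
    //             query_set: query_set_id,
    //             beginning_of_pass_write_index: Some(0),
    //             end_of_pass_write_index: Some(1),
    //         }),
    //     };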

    /// Note that this differs from [`Self::compute_pass_end`]: it will
    /// create a new pass, replay the commands, and end the pass.
    #[doc(hidden)]
    #[cfg(any(feature = "serde", feature = "replay"))]
    pub fn compute_pass_end_with_unresolved_commands(
        &self,
        encoder_id: id::CommandEncoderId,
        base: BasePass<super::ComputeCommand>,
        timestamp_writes: Option<&PassTimestampWrites>,
    ) -> Result<(), ComputePassError> {
        let pass_scope = PassErrorScope::Pass;

        #[cfg(feature = "trace")]
        {
            let cmd_buf = self
                .hub
                .command_buffers
                .get(encoder_id.into_command_buffer_id());
            let mut cmd_buf_data = cmd_buf.try_get().map_pass_err(pass_scope)?;

            if let Some(ref mut list) = cmd_buf_data.commands {
                list.push(crate::device::trace::Command::RunComputePass {
                    base: BasePass {
                        label: base.label.clone(),
                        commands: base.commands.clone(),
                        dynamic_offsets: base.dynamic_offsets.clone(),
                        string_data: base.string_data.clone(),
                        push_constant_data: base.push_constant_data.clone(),
                    },
                    timestamp_writes: timestamp_writes.cloned(),
                });
            }
        }

        let BasePass {
            label,
            commands,
            dynamic_offsets,
            string_data,
            push_constant_data,
        } = base;

        let (mut compute_pass, encoder_error) = self.command_encoder_create_compute_pass(
            encoder_id,
            &ComputePassDescriptor {
                label: label.as_deref().map(std::borrow::Cow::Borrowed),
                timestamp_writes,
            },
        );
        if let Some(err) = encoder_error {
            return Err(ComputePassError {
                scope: pass_scope,
                inner: err.into(),
            });
        };

        compute_pass.base = Some(BasePass {
            label,
            commands: super::ComputeCommand::resolve_compute_command_ids(&self.hub, &commands)?,
            dynamic_offsets,
            string_data,
            push_constant_data,
        });

        self.compute_pass_end(&mut compute_pass)
    }

    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), ComputePassError> {
        profiling::scope!("CommandEncoder::run_compute_pass");
        let pass_scope = PassErrorScope::Pass;

        let cmd_buf = pass
            .parent
            .as_ref()
            .ok_or(ComputePassErrorInner::InvalidParentEncoder)
            .map_pass_err(pass_scope)?;

        let base = pass
            .base
            .take()
            .ok_or(ComputePassErrorInner::PassEnded)
            .map_pass_err(pass_scope)?;

        let device = &cmd_buf.device;
        device.check_is_valid().map_pass_err(pass_scope)?;

        let mut cmd_buf_data = cmd_buf.try_get().map_pass_err(pass_scope)?;
        cmd_buf_data.unlock_encoder().map_pass_err(pass_scope)?;
        let cmd_buf_data = &mut *cmd_buf_data;

        let encoder = &mut cmd_buf_data.encoder;
        let status = &mut cmd_buf_data.status;

        // We automatically keep extending command buffers over time, and because
        // we want to insert a command buffer _before_ what we're about to record,
        // we need to make sure to close the previous one.
        encoder.close(&cmd_buf.device).map_pass_err(pass_scope)?;
        // Will be restored to `Recording` if the pass is recorded without errors.
        *status = CommandEncoderStatus::Error;
        let raw_encoder = encoder.open(&cmd_buf.device).map_pass_err(pass_scope)?;

        let mut state = State {
            binder: Binder::new(),
            pipeline: None,
            scope: device.new_usage_scope(),
            debug_scope_depth: 0,

            snatch_guard: device.snatchable_lock.read(),

            device,
            raw_encoder,
            tracker: &mut cmd_buf_data.trackers,
            buffer_memory_init_actions: &mut cmd_buf_data.buffer_memory_init_actions,
            texture_memory_actions: &mut cmd_buf_data.texture_memory_actions,

            temp_offsets: Vec::new(),
            dynamic_offset_count: 0,
            string_offset: 0,
            active_query: None,

            push_constants: Vec::new(),

            intermediate_trackers: Tracker::new(),

            pending_discard_init_fixups: SurfacesInDiscardState::new(),
        };

        let indices = &state.device.tracker_indices;
        state.tracker.buffers.set_size(indices.buffers.size());
        state.tracker.textures.set_size(indices.textures.size());

        let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
            if let Some(tw) = pass.timestamp_writes.take() {
                tw.query_set
                    .same_device_as(cmd_buf.as_ref())
                    .map_pass_err(pass_scope)?;

                let query_set = state.tracker.query_sets.insert_single(tw.query_set);

                // Unlike in render passes, we can't delay resetting the query sets
                // since there is no auxiliary pass.
                let range = if let (Some(index_a), Some(index_b)) =
                    (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
                {
                    Some(index_a.min(index_b)..index_a.max(index_b) + 1)
                } else {
                    tw.beginning_of_pass_write_index
                        .or(tw.end_of_pass_write_index)
                        .map(|i| i..i + 1)
                };
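                // Worked example: begin = Some(2), end = Some(0) resets queries
                // 0..3; begin = Some(5), end = None resets just 5..6.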
                // The range should always be `Some`: both indices being `None`
                // should already have produced a validation error upstream.
                // But no point in erroring over that nuance here!
                if let Some(range) = range {
                    unsafe {
                        state.raw_encoder.reset_queries(query_set.raw(), range);
                    }
                }

                Some(hal::PassTimestampWrites {
                    query_set: query_set.raw(),
                    beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                    end_of_pass_write_index: tw.end_of_pass_write_index,
                })
            } else {
                None
            };

        let hal_desc = hal::ComputePassDescriptor {
            label: hal_label(base.label.as_deref(), device.instance_flags),
            timestamp_writes,
        };

        unsafe {
            state.raw_encoder.begin_compute_pass(&hal_desc);
        }

        for command in base.commands {
            match command {
                ArcComputeCommand::SetBindGroup {
                    index,
                    num_dynamic_offsets,
                    bind_group,
                } => {
                    let scope = PassErrorScope::SetBindGroup;
                    set_bind_group(
                        &mut state,
                        cmd_buf,
                        &base.dynamic_offsets,
                        index,
                        num_dynamic_offsets,
                        bind_group,
                    )
                    .map_pass_err(scope)?;
                }
                ArcComputeCommand::SetPipeline(pipeline) => {
                    let scope = PassErrorScope::SetPipelineCompute;
                    set_pipeline(&mut state, cmd_buf, pipeline).map_pass_err(scope)?;
                }
                ArcComputeCommand::SetPushConstant {
                    offset,
                    size_bytes,
                    values_offset,
                } => {
                    let scope = PassErrorScope::SetPushConstant;
                    set_push_constant(
                        &mut state,
                        &base.push_constant_data,
                        offset,
                        size_bytes,
                        values_offset,
                    )
                    .map_pass_err(scope)?;
                }
                ArcComputeCommand::Dispatch(groups) => {
                    let scope = PassErrorScope::Dispatch { indirect: false };
                    dispatch(&mut state, groups).map_pass_err(scope)?;
                }
                ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                    let scope = PassErrorScope::Dispatch { indirect: true };
                    dispatch_indirect(&mut state, cmd_buf, buffer, offset).map_pass_err(scope)?;
                }
                ArcComputeCommand::PushDebugGroup { color: _, len } => {
                    push_debug_group(&mut state, &base.string_data, len);
                }
                ArcComputeCommand::PopDebugGroup => {
                    let scope = PassErrorScope::PopDebugGroup;
                    pop_debug_group(&mut state).map_pass_err(scope)?;
                }
                ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                    insert_debug_marker(&mut state, &base.string_data, len);
                }
                ArcComputeCommand::WriteTimestamp {
                    query_set,
                    query_index,
                } => {
                    let scope = PassErrorScope::WriteTimestamp;
                    write_timestamp(&mut state, cmd_buf, query_set, query_index)
                        .map_pass_err(scope)?;
                }
                ArcComputeCommand::BeginPipelineStatisticsQuery {
                    query_set,
                    query_index,
                } => {
                    let scope = PassErrorScope::BeginPipelineStatisticsQuery;
                    validate_and_begin_pipeline_statistics_query(
                        query_set,
                        state.raw_encoder,
                        &mut state.tracker.query_sets,
                        cmd_buf,
                        query_index,
                        None,
                        &mut state.active_query,
                    )
                    .map_pass_err(scope)?;
                }
                ArcComputeCommand::EndPipelineStatisticsQuery => {
                    let scope = PassErrorScope::EndPipelineStatisticsQuery;
                    end_pipeline_statistics_query(state.raw_encoder, &mut state.active_query)
                        .map_pass_err(scope)?;
                }
            }
        }

        unsafe {
            state.raw_encoder.end_compute_pass();
        }

        // We've successfully recorded the compute pass; bring the
        // command buffer out of the error state.
        *status = CommandEncoderStatus::Recording;

        let State {
            snatch_guard,
            tracker,
            intermediate_trackers,
            pending_discard_init_fixups,
            ..
        } = state;

        // Stop the current command buffer.
        encoder.close(&cmd_buf.device).map_pass_err(pass_scope)?;

        // Create a new command buffer, which we will insert _before_ the body of the compute pass.
        //
        // Use that buffer to insert barriers and clear discarded images.
        let transit = encoder.open(&cmd_buf.device).map_pass_err(pass_scope)?;
        fixup_discarded_surfaces(
            pending_discard_init_fixups.into_iter(),
            transit,
            &mut tracker.textures,
            device,
            &snatch_guard,
        );
        CommandBuffer::insert_barriers_from_tracker(
            transit,
            tracker,
            &intermediate_trackers,
            &snatch_guard,
        );
        // Close the command buffer, and swap it with the previous one.
        encoder
            .close_and_swap(&cmd_buf.device)
            .map_pass_err(pass_scope)?;

        Ok(())
    }
}

fn set_bind_group(
    state: &mut State,
    cmd_buf: &CommandBuffer,
    dynamic_offsets: &[DynamicOffset],
    index: u32,
    num_dynamic_offsets: usize,
    bind_group: Option<Arc<BindGroup>>,
) -> Result<(), ComputePassErrorInner> {
    let max_bind_groups = state.device.limits.max_bind_groups;
    if index >= max_bind_groups {
        return Err(ComputePassErrorInner::BindGroupIndexOutOfRange {
            index,
            max: max_bind_groups,
        });
    }

    state.temp_offsets.clear();
    state.temp_offsets.extend_from_slice(
        &dynamic_offsets
            [state.dynamic_offset_count..state.dynamic_offset_count + num_dynamic_offsets],
    );
    state.dynamic_offset_count += num_dynamic_offsets;

    if bind_group.is_none() {
        // TODO: Handle bind_group None.
        return Ok(());
    }

    let bind_group = bind_group.unwrap();
    let bind_group = state.tracker.bind_groups.insert_single(bind_group);

    bind_group.same_device_as(cmd_buf)?;

    bind_group.validate_dynamic_bindings(index, &state.temp_offsets)?;

    state
        .buffer_memory_init_actions
        .extend(bind_group.used_buffer_ranges.iter().filter_map(|action| {
            action
                .buffer
                .initialization_status
                .read()
                .check_action(action)
        }));

    for action in bind_group.used_texture_ranges.iter() {
        state
            .pending_discard_init_fixups
            .extend(state.texture_memory_actions.register_init_action(action));
    }

    let pipeline_layout = state.binder.pipeline_layout.clone();
    let entries = state
        .binder
        .assign_group(index as usize, bind_group, &state.temp_offsets);
    if !entries.is_empty() && pipeline_layout.is_some() {
        let pipeline_layout = pipeline_layout.as_ref().unwrap().raw();
        for (i, e) in entries.iter().enumerate() {
            if let Some(group) = e.group.as_ref() {
                let raw_bg = group.try_raw(&state.snatch_guard)?;
                unsafe {
                    state.raw_encoder.set_bind_group(
                        pipeline_layout,
                        index + i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }
    }
    Ok(())
}

fn set_pipeline(
    state: &mut State,
    cmd_buf: &CommandBuffer,
    pipeline: Arc<ComputePipeline>,
) -> Result<(), ComputePassErrorInner> {
    pipeline.same_device_as(cmd_buf)?;

    state.pipeline = Some(pipeline.clone());

    let pipeline = state.tracker.compute_pipelines.insert_single(pipeline);

    unsafe {
        state.raw_encoder.set_compute_pipeline(pipeline.raw());
    }

    // Rebind resources
    if state.binder.pipeline_layout.is_none()
        || !state
            .binder
            .pipeline_layout
            .as_ref()
            .unwrap()
            .is_equal(&pipeline.layout)
    {
        let (start_index, entries) = state
            .binder
            .change_pipeline_layout(&pipeline.layout, &pipeline.late_sized_buffer_groups);
        if !entries.is_empty() {
            for (i, e) in entries.iter().enumerate() {
                if let Some(group) = e.group.as_ref() {
                    let raw_bg = group.try_raw(&state.snatch_guard)?;
                    unsafe {
                        state.raw_encoder.set_bind_group(
                            pipeline.layout.raw(),
                            start_index as u32 + i as u32,
                            Some(raw_bg),
                            &e.dynamic_offsets,
                        );
                    }
                }
            }
        }

        // TODO: integrate this in the code below once we simplify push constants.
        state.push_constants.clear();
        // Note that there can only be one range for each stage. See the
        // `MoreThanOnePushConstantRangePerStage` error.
        if let Some(push_constant_range) =
            pipeline.layout.push_constant_ranges.iter().find_map(|pcr| {
                pcr.stages
                    .contains(wgt::ShaderStages::COMPUTE)
                    .then_some(pcr.range.clone())
            })
        {
            // Note that a non-zero range start doesn't work anyway: https://github.com/gfx-rs/wgpu/issues/4502
            let len = push_constant_range.len() / wgt::PUSH_CONSTANT_ALIGNMENT as usize;
            state.push_constants.extend(core::iter::repeat(0).take(len));
        }

        // Clear push constant ranges
        let non_overlapping =
            super::bind::compute_nonoverlapping_ranges(&pipeline.layout.push_constant_ranges);
        for range in non_overlapping {
            let offset = range.range.start;
            let size_bytes = range.range.end - offset;
            super::push_constant_clear(offset, size_bytes, |clear_offset, clear_data| unsafe {
                state.raw_encoder.set_push_constants(
                    pipeline.layout.raw(),
                    wgt::ShaderStages::COMPUTE,
                    clear_offset,
                    clear_data,
                );
            });
        }
    }
    Ok(())
}

fn set_push_constant(
    state: &mut State,
    push_constant_data: &[u32],
    offset: u32,
    size_bytes: u32,
    values_offset: u32,
) -> Result<(), ComputePassErrorInner> {
    let end_offset_bytes = offset + size_bytes;
    let values_end_offset = (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
    let data_slice = &push_constant_data[(values_offset as usize)..values_end_offset];

    let pipeline_layout = state
        .binder
        .pipeline_layout
        .as_ref()
        // TODO: don't error here; lazily update the push constants using `state.push_constants`.
        .ok_or(ComputePassErrorInner::Dispatch(
            DispatchError::MissingPipeline,
        ))?;

    pipeline_layout.validate_push_constant_ranges(
        wgt::ShaderStages::COMPUTE,
        offset,
        end_offset_bytes,
    )?;

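    // Worked example: with `PUSH_CONSTANT_ALIGNMENT` = 4, `offset` = 8 and
    // `size_bytes` = 16 overwrite elements 2..6 of `state.push_constants`.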
    let offset_in_elements = (offset / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
    let size_in_elements = (size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
    state.push_constants[offset_in_elements..][..size_in_elements].copy_from_slice(data_slice);

    unsafe {
        state.raw_encoder.set_push_constants(
            pipeline_layout.raw(),
            wgt::ShaderStages::COMPUTE,
            offset,
            data_slice,
        );
    }
    Ok(())
}

fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
    state.is_ready()?;

    state.flush_states(None)?;

    let groups_size_limit = state.device.limits.max_compute_workgroups_per_dimension;

    if groups[0] > groups_size_limit
        || groups[1] > groups_size_limit
        || groups[2] > groups_size_limit
    {
        return Err(ComputePassErrorInner::Dispatch(
            DispatchError::InvalidGroupSize {
                current: groups,
                limit: groups_size_limit,
            },
        ));
    }

    unsafe {
        state.raw_encoder.dispatch(groups);
    }
    Ok(())
}

fn dispatch_indirect(
    state: &mut State,
    cmd_buf: &CommandBuffer,
    buffer: Arc<Buffer>,
    offset: u64,
) -> Result<(), ComputePassErrorInner> {
    buffer.same_device_as(cmd_buf)?;

    state.is_ready()?;

    state
        .device
        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;

    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;

    if offset % 4 != 0 {
        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
    }

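    // `wgt::DispatchIndirectArgs` is three u32 workgroup counts (x, y, z), so
    // the dispatch reads 12 bytes starting at `offset`.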
    let end_offset = offset + size_of::<wgt::DispatchIndirectArgs>() as u64;
    if end_offset > buffer.size {
        return Err(ComputePassErrorInner::IndirectBufferOverrun {
            offset,
            end_offset,
            buffer_size: buffer.size,
        });
    }

    let stride = 3 * 4; // 3 integers, x/y/z group size
    state
        .buffer_memory_init_actions
        .extend(buffer.initialization_status.read().create_action(
            &buffer,
            offset..(offset + stride),
            MemoryInitKind::NeedsInitializedMemory,
        ));

    #[cfg(feature = "indirect-validation")]
    {
        let params = state.device.indirect_validation.as_ref().unwrap().params(
            &state.device.limits,
            offset,
            buffer.size,
        );

        unsafe {
            state.raw_encoder.set_compute_pipeline(params.pipeline);
        }

        unsafe {
            state.raw_encoder.set_push_constants(
                params.pipeline_layout,
                wgt::ShaderStages::COMPUTE,
                0,
                &[params.offset_remainder as u32 / 4],
            );
        }

        unsafe {
            state.raw_encoder.set_bind_group(
                params.pipeline_layout,
                0,
                Some(params.dst_bind_group),
                &[],
            );
        }
        unsafe {
            state.raw_encoder.set_bind_group(
                params.pipeline_layout,
                1,
                Some(
                    buffer
                        .raw_indirect_validation_bind_group
                        .get(&state.snatch_guard)
                        .unwrap()
                        .as_ref(),
                ),
                &[params.aligned_offset as u32],
            );
        }

        let src_transition = state
            .intermediate_trackers
            .buffers
            .set_single(&buffer, hal::BufferUses::STORAGE_READ);
        let src_barrier =
            src_transition.map(|transition| transition.into_hal(&buffer, &state.snatch_guard));
        unsafe {
            state.raw_encoder.transition_buffers(src_barrier.as_slice());
        }

        unsafe {
            state.raw_encoder.transition_buffers(&[hal::BufferBarrier {
                buffer: params.dst_buffer,
                usage: hal::BufferUses::INDIRECT..hal::BufferUses::STORAGE_READ_WRITE,
            }]);
        }

        unsafe {
            state.raw_encoder.dispatch([1, 1, 1]);
        }

        // Reset the state that the validation dispatch clobbered above.
        {
            let pipeline = state.pipeline.as_ref().unwrap();

            unsafe {
                state.raw_encoder.set_compute_pipeline(pipeline.raw());
            }

            if !state.push_constants.is_empty() {
                unsafe {
                    state.raw_encoder.set_push_constants(
                        pipeline.layout.raw(),
                        wgt::ShaderStages::COMPUTE,
                        0,
                        &state.push_constants,
                    );
                }
            }

            for (i, e) in state.binder.list_valid() {
                let group = e.group.as_ref().unwrap();
                let raw_bg = group.try_raw(&state.snatch_guard)?;
                unsafe {
                    state.raw_encoder.set_bind_group(
                        pipeline.layout.raw(),
                        i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }

        unsafe {
            state.raw_encoder.transition_buffers(&[hal::BufferBarrier {
                buffer: params.dst_buffer,
                usage: hal::BufferUses::STORAGE_READ_WRITE..hal::BufferUses::INDIRECT,
            }]);
        }

        state.flush_states(None)?;
        unsafe {
            state.raw_encoder.dispatch_indirect(params.dst_buffer, 0);
        }
    };
    #[cfg(not(feature = "indirect-validation"))]
    {
        state
            .scope
            .buffers
            .merge_single(&buffer, hal::BufferUses::INDIRECT)?;

        use crate::resource::Trackable;
        state.flush_states(Some(buffer.tracker_index()))?;

        let buf_raw = buffer.try_raw(&state.snatch_guard)?;
        unsafe {
            state.raw_encoder.dispatch_indirect(buf_raw, offset);
        }
    }

    Ok(())
}
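
// For reference, the 12-byte payload read at `offset` matches
// `wgt::DispatchIndirectArgs { x, y, z }`; e.g., writing the words
// `[16u32, 8, 1]` at a 4-byte-aligned offset dispatches a 16x8x1 grid
// of workgroups.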

fn push_debug_group(state: &mut State, string_data: &[u8], len: usize) {
    state.debug_scope_depth += 1;
    if !state
        .device
        .instance_flags
        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
    {
        let label =
            str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
        unsafe {
            state.raw_encoder.begin_debug_marker(label);
        }
    }
    state.string_offset += len;
}

fn pop_debug_group(state: &mut State) -> Result<(), ComputePassErrorInner> {
    if state.debug_scope_depth == 0 {
        return Err(ComputePassErrorInner::InvalidPopDebugGroup);
    }
    state.debug_scope_depth -= 1;
    if !state
        .device
        .instance_flags
        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
    {
        unsafe {
            state.raw_encoder.end_debug_marker();
        }
    }
    Ok(())
}

fn insert_debug_marker(state: &mut State, string_data: &[u8], len: usize) {
    if !state
        .device
        .instance_flags
        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
    {
        let label =
            str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
        unsafe { state.raw_encoder.insert_debug_marker(label) }
    }
    state.string_offset += len;
}

fn write_timestamp(
    state: &mut State,
    cmd_buf: &CommandBuffer,
    query_set: Arc<resource::QuerySet>,
    query_index: u32,
) -> Result<(), ComputePassErrorInner> {
    query_set.same_device_as(cmd_buf)?;

    state
        .device
        .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)?;

    let query_set = state.tracker.query_sets.insert_single(query_set);

    query_set.validate_and_write_timestamp(state.raw_encoder, query_index, None)?;
    Ok(())
}

// Recording a compute pass.
impl Global {
    pub fn compute_pass_set_bind_group(
        &self,
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: Option<id::BindGroupId>,
        offsets: &[DynamicOffset],
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::SetBindGroup;
        let base = pass
            .base
            .as_mut()
            .ok_or(ComputePassErrorInner::PassEnded)
            .map_pass_err(scope)?; // Can't use the base_mut() utility here because of the borrow checker.

        let redundant = pass.current_bind_groups.set_and_check_redundant(
            bind_group_id,
            index,
            &mut base.dynamic_offsets,
            offsets,
        );

        if redundant {
            return Ok(());
        }

        let mut bind_group = None;
        if bind_group_id.is_some() {
            let bind_group_id = bind_group_id.unwrap();

            let hub = &self.hub;
            let bg = hub
                .bind_groups
                .get(bind_group_id)
                .get()
                .map_pass_err(scope)?;
            bind_group = Some(bg);
        }

        base.commands.push(ArcComputeCommand::SetBindGroup {
            index,
            num_dynamic_offsets: offsets.len(),
            bind_group,
        });

        Ok(())
    }
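
    // Hedged usage sketch (`pass` and `bind_group_id` assumed from earlier
    // calls): bind group 0 with one dynamic offset. Redundant rebinds are
    // deduplicated by `set_and_check_redundant` above.
    //
    //     global.compute_pass_set_bind_group(&mut pass, 0, Some(bind_group_id), &[256])?;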

    pub fn compute_pass_set_pipeline(
        &self,
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) -> Result<(), ComputePassError> {
        let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);

        let scope = PassErrorScope::SetPipelineCompute;

        let base = pass.base_mut(scope)?;
        if redundant {
            // Do the redundant early-out **after** checking whether the pass is ended or not.
            return Ok(());
        }

        let hub = &self.hub;
        let pipeline = hub
            .compute_pipelines
            .get(pipeline_id)
            .get()
            .map_pass_err(scope)?;

        base.commands.push(ArcComputeCommand::SetPipeline(pipeline));

        Ok(())
    }

    pub fn compute_pass_set_push_constants(
        &self,
        pass: &mut ComputePass,
        offset: u32,
        data: &[u8],
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::SetPushConstant;
        let base = pass.base_mut(scope)?;

        if offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            return Err(ComputePassErrorInner::PushConstantOffsetAlignment).map_pass_err(scope);
        }

        if data.len() as u32 & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            return Err(ComputePassErrorInner::PushConstantSizeAlignment).map_pass_err(scope);
        }
        let value_offset = base
            .push_constant_data
            .len()
            .try_into()
            .map_err(|_| ComputePassErrorInner::PushConstantOutOfMemory)
            .map_pass_err(scope)?;

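        // Pack the byte payload into u32 words using native endianness; the
        // alignment checks above guarantee `data.len()` is a multiple of 4.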
        base.push_constant_data.extend(
            data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        base.commands.push(ArcComputeCommand::SetPushConstant {
            offset,
            size_bytes: data.len() as u32,
            values_offset: value_offset,
        });

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups(
        &self,
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::Dispatch { indirect: false };

        let base = pass.base_mut(scope)?;
        base.commands
            .push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups_indirect(
        &self,
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) -> Result<(), ComputePassError> {
        let hub = &self.hub;
        let scope = PassErrorScope::Dispatch { indirect: true };
        let base = pass.base_mut(scope)?;

        let buffer = hub.buffers.get(buffer_id).get().map_pass_err(scope)?;

        base.commands
            .push(ArcComputeCommand::DispatchIndirect { buffer, offset });

        Ok(())
    }

    pub fn compute_pass_push_debug_group(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), ComputePassError> {
        let base = pass.base_mut(PassErrorScope::PushDebugGroup)?;

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_pop_debug_group(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), ComputePassError> {
        let base = pass.base_mut(PassErrorScope::PopDebugGroup)?;

        base.commands.push(ArcComputeCommand::PopDebugGroup);

        Ok(())
    }

    pub fn compute_pass_insert_debug_marker(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), ComputePassError> {
        let base = pass.base_mut(PassErrorScope::InsertDebugMarker)?;

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_write_timestamp(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::WriteTimestamp;
        let base = pass.base_mut(scope)?;

        let hub = &self.hub;
        let query_set = hub.query_sets.get(query_set_id).get().map_pass_err(scope)?;

        base.commands.push(ArcComputeCommand::WriteTimestamp {
            query_set,
            query_index,
        });

        Ok(())
    }

    pub fn compute_pass_begin_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
        let base = pass.base_mut(scope)?;

        let hub = &self.hub;
        let query_set = hub.query_sets.get(query_set_id).get().map_pass_err(scope)?;

        base.commands
            .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            });

        Ok(())
    }

    pub fn compute_pass_end_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::EndPipelineStatisticsQuery;
        let base = pass.base_mut(scope)?;
        base.commands
            .push(ArcComputeCommand::EndPipelineStatisticsQuery);

        Ok(())
    }
}