use crate::{
    binding_model::{
        BindError, BindGroup, LateMinBufferBindingSizeMismatch, PushConstantUploadError,
    },
    command::{
        bind::Binder,
        compute_command::ArcComputeCommand,
        end_pipeline_statistics_query,
        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
        validate_and_begin_pipeline_statistics_query, ArcPassTimestampWrites, BasePass,
        BindGroupStateChange, CommandBuffer, CommandEncoderError, CommandEncoderStatus, MapPassErr,
        PassErrorScope, PassTimestampWrites, QueryUseError, StateChange,
    },
    device::{Device, DeviceError, MissingDownlevelFlags, MissingFeatures},
    global::Global,
    hal_label, id,
    init_tracker::{BufferInitTrackerAction, MemoryInitKind},
    pipeline::ComputePipeline,
    resource::{
        self, Buffer, DestroyedResourceError, InvalidResourceError, Labeled,
        MissingBufferUsageError, ParentDevice,
    },
    snatch::SnatchGuard,
    track::{ResourceUsageCompatibilityError, Tracker, TrackerIndex, UsageScope},
    Label,
};

use thiserror::Error;
use wgt::{BufferAddress, DynamicOffset};

use std::sync::Arc;
use std::{fmt, mem::size_of, str};

use super::{bind::BinderError, memory_init::CommandBufferTextureMemoryActions};

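/// Recorded state of an in-progress compute pass.
///
/// Commands are buffered in `base` and only replayed onto the hal encoder
/// once the pass is ended. A rough sketch of the recording lifecycle, using
/// the `Global` entry points defined later in this file (ids assumed valid):
///
/// ```ignore
/// let (mut pass, err) = global.command_encoder_create_compute_pass(encoder_id, &desc);
/// assert!(err.is_none());
/// global.compute_pass_set_pipeline(&mut pass, pipeline_id)?;
/// global.compute_pass_set_bind_group(&mut pass, 0, Some(bind_group_id), &[])?;
/// global.compute_pass_dispatch_workgroups(&mut pass, 64, 1, 1)?;
/// global.compute_pass_end(&mut pass)?;
/// ```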
pub struct ComputePass {
    /// All recorded commands and auxiliary pass data.
    ///
    /// This is `None` once the pass has ended; any further attempt to record
    /// commands yields a `PassEnded` error.
    base: Option<BasePass<ArcComputeCommand>>,

    /// The command buffer this pass records into.
    ///
    /// If this is `None`, the pass is invalid and any operation on it will
    /// return an error.
    parent: Option<Arc<CommandBuffer>>,

    timestamp_writes: Option<ArcPassTimestampWrites>,

    // Resource binding dedupe state.
    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}

impl ComputePass {
    /// If the parent command buffer is invalid, the returned pass will be invalid.
    fn new(parent: Option<Arc<CommandBuffer>>, desc: ArcComputePassDescriptor) -> Self {
        let ArcComputePassDescriptor {
            label,
            timestamp_writes,
        } = desc;

        Self {
            base: Some(BasePass::new(label)),
            parent,
            timestamp_writes,

            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    #[inline]
    pub fn label(&self) -> Option<&str> {
        self.base.as_ref().and_then(|base| base.label.as_deref())
    }

    fn base_mut<'a>(
        &'a mut self,
        scope: PassErrorScope,
    ) -> Result<&'a mut BasePass<ArcComputeCommand>, ComputePassError> {
        self.base
            .as_mut()
            .ok_or(ComputePassErrorInner::PassEnded)
            .map_pass_err(scope)
    }
}

impl fmt::Debug for ComputePass {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.parent {
            Some(ref cmd_buf) => write!(f, "ComputePass {{ parent: {} }}", cmd_buf.error_ident()),
            None => write!(f, "ComputePass {{ parent: None }}"),
        }
    }
}

#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a> {
    pub label: Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<&'a PassTimestampWrites>,
}

struct ArcComputePassDescriptor<'a> {
    pub label: &'a Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<ArcPassTimestampWrites>,
}

#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline,
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less than or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}

/// Error encountered when performing a compute pass.
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    Encoder(#[from] CommandEncoderError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error("Bind group index {index} is greater than the device's requested `max_bind_groups` limit {max}")]
    BindGroupIndexOutOfRange { index: u32, max: u32 },
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error("Cannot pop debug group, because number of pushed debug groups is zero")]
    InvalidPopDebugGroup,
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    PushConstants(#[from] PushConstantUploadError),
    #[error("Push constant offset must be aligned to 4 bytes")]
    PushConstantOffsetAlignment,
    #[error("Push constant size must be aligned to 4 bytes")]
    PushConstantSizeAlignment,
    #[error("Ran out of push constant space. Don't set 4 GiB of push constants per ComputePass.")]
    PushConstantOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
}

#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    pub(super) inner: ComputePassErrorInner,
}

impl<T, E> MapPassErr<T, ComputePassError> for Result<T, E>
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> Result<T, ComputePassError> {
        self.map_err(|inner| ComputePassError {
            scope,
            inner: inner.into(),
        })
    }
}

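/// Transient state used while `compute_pass_end` replays the recorded
/// commands onto the hal encoder. It owns the pass's usage scope and barrier
/// trackers, and carries cursors (`dynamic_offset_count`, `string_offset`)
/// into the flattened side arrays of the `BasePass`.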
struct State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder> {
    binder: Binder,
    pipeline: Option<Arc<ComputePipeline>>,
    scope: UsageScope<'scope>,
    debug_scope_depth: u32,

    snatch_guard: SnatchGuard<'snatch_guard>,

    device: &'cmd_buf Arc<Device>,

    raw_encoder: &'raw_encoder mut dyn hal::DynCommandEncoder,

    tracker: &'cmd_buf mut Tracker,
    buffer_memory_init_actions: &'cmd_buf mut Vec<BufferInitTrackerAction>,
    texture_memory_actions: &'cmd_buf mut CommandBufferTextureMemoryActions,

    temp_offsets: Vec<u32>,
    dynamic_offset_count: usize,
    string_offset: usize,
    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    /// CPU-side mirror of the current push constant values, re-applied after
    /// internal dispatches (see `dispatch_indirect`).
    push_constants: Vec<u32>,

    intermediate_trackers: Tracker,

    /// Immediate texture inits required because of prior discards. Need to
    /// be inserted before texture reads.
    pending_discard_init_fixups: SurfacesInDiscardState,
}

impl<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder>
    State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder>
{
    /// Checks that a dispatch may be issued: a pipeline must be set and the
    /// currently bound groups must satisfy its layout.
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.binder.check_compatibility(pipeline.as_ref())?;
            self.binder.check_late_buffer_bindings()?;
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline)
        }
    }

    /// Flushes the pending usage-scope state into the intermediate trackers
    /// and records the resulting resource barriers, right before a dispatch.
    fn flush_states(
        &mut self,
        indirect_buffer: Option<TrackerIndex>,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        for bind_group in self.binder.list_active() {
            unsafe { self.scope.merge_bind_group(&bind_group.used)? };
        }

        for bind_group in self.binder.list_active() {
            unsafe {
                self.intermediate_trackers
                    .set_and_remove_from_usage_scope_sparse(&mut self.scope, &bind_group.used)
            }
        }

        // Add the state of the indirect buffer if it hasn't been hit before.
        unsafe {
            self.intermediate_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(&mut self.scope.buffers, indirect_buffer);
        }

        CommandBuffer::drain_barriers(
            self.raw_encoder,
            &mut self.intermediate_trackers,
            &self.snatch_guard,
        );
        Ok(())
    }
}

impl Global {
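    /// Creates a compute pass.
    ///
    /// If creation fails, an invalid pass is returned alongside the error;
    /// any operation recorded on it is ignored, and the error is reported
    /// when the pass is ended. This matches WebGPU, where pass creation
    /// itself never throws.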
    pub fn command_encoder_create_compute_pass(
        &self,
        encoder_id: id::CommandEncoderId,
        desc: &ComputePassDescriptor<'_>,
    ) -> (ComputePass, Option<CommandEncoderError>) {
        let hub = &self.hub;

        let mut arc_desc = ArcComputePassDescriptor {
            label: &desc.label,
            // Resolved only once the encoder has been validated below.
            timestamp_writes: None,
        };

        let make_err = |e, arc_desc| (ComputePass::new(None, arc_desc), Some(e));

        let cmd_buf = hub.command_buffers.get(encoder_id.into_command_buffer_id());

        match cmd_buf
            .try_get()
            .map_err(|e| e.into())
            .and_then(|mut cmd_buf_data| cmd_buf_data.lock_encoder())
        {
            Ok(_) => {}
            Err(e) => return make_err(e, arc_desc),
        };

        arc_desc.timestamp_writes = if let Some(tw) = desc.timestamp_writes {
            let query_set = match hub.query_sets.get(tw.query_set).get() {
                Ok(query_set) => query_set,
                Err(e) => return make_err(e.into(), arc_desc),
            };

            Some(ArcPassTimestampWrites {
                query_set,
                beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                end_of_pass_write_index: tw.end_of_pass_write_index,
            })
        } else {
            None
        };

        (ComputePass::new(Some(cmd_buf), arc_desc), None)
    }

    /// Replays a compute pass recorded as raw, id-based commands: resolves
    /// the ids into resource `Arc`s, then ends the pass. Used by the
    /// `serde`/`replay` machinery.
    #[doc(hidden)]
    #[cfg(any(feature = "serde", feature = "replay"))]
    pub fn compute_pass_end_with_unresolved_commands(
        &self,
        encoder_id: id::CommandEncoderId,
        base: BasePass<super::ComputeCommand>,
        timestamp_writes: Option<&PassTimestampWrites>,
    ) -> Result<(), ComputePassError> {
        let pass_scope = PassErrorScope::Pass;

        #[cfg(feature = "trace")]
        {
            let cmd_buf = self
                .hub
                .command_buffers
                .get(encoder_id.into_command_buffer_id());
            let mut cmd_buf_data = cmd_buf.try_get().map_pass_err(pass_scope)?;

            if let Some(ref mut list) = cmd_buf_data.commands {
                list.push(crate::device::trace::Command::RunComputePass {
                    base: BasePass {
                        label: base.label.clone(),
                        commands: base.commands.clone(),
                        dynamic_offsets: base.dynamic_offsets.clone(),
                        string_data: base.string_data.clone(),
                        push_constant_data: base.push_constant_data.clone(),
                    },
                    timestamp_writes: timestamp_writes.cloned(),
                });
            }
        }

        let BasePass {
            label,
            commands,
            dynamic_offsets,
            string_data,
            push_constant_data,
        } = base;

        let (mut compute_pass, encoder_error) = self.command_encoder_create_compute_pass(
            encoder_id,
            &ComputePassDescriptor {
                label: label.as_deref().map(std::borrow::Cow::Borrowed),
                timestamp_writes,
            },
        );
        if let Some(err) = encoder_error {
            return Err(ComputePassError {
                scope: pass_scope,
                inner: err.into(),
            });
        };

        compute_pass.base = Some(BasePass {
            label,
            commands: super::ComputeCommand::resolve_compute_command_ids(&self.hub, &commands)?,
            dynamic_offsets,
            string_data,
            push_constant_data,
        });

        self.compute_pass_end(&mut compute_pass)
    }

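    /// Ends the pass, replaying every buffered command onto the parent
    /// command encoder's hal encoder and emitting the required resource
    /// barriers. Consumes the pass's `base`, so the pass cannot be reused.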
    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), ComputePassError> {
        profiling::scope!("CommandEncoder::run_compute_pass");
        let pass_scope = PassErrorScope::Pass;

        let cmd_buf = pass
            .parent
            .as_ref()
            .ok_or(ComputePassErrorInner::InvalidParentEncoder)
            .map_pass_err(pass_scope)?;

        let base = pass
            .base
            .take()
            .ok_or(ComputePassErrorInner::PassEnded)
            .map_pass_err(pass_scope)?;

        let device = &cmd_buf.device;
        device.check_is_valid().map_pass_err(pass_scope)?;

        let mut cmd_buf_data = cmd_buf.try_get().map_pass_err(pass_scope)?;
        cmd_buf_data.unlock_encoder().map_pass_err(pass_scope)?;
        let cmd_buf_data = &mut *cmd_buf_data;

        let encoder = &mut cmd_buf_data.encoder;
        let status = &mut cmd_buf_data.status;

        // Close the previous raw command buffer: the barrier fixups recorded
        // at the end of this function must be inserted *before* the pass.
        encoder.close(&cmd_buf.device).map_pass_err(pass_scope)?;
        // Fail-safe: restored to `Recording` only if the pass replays cleanly.
        *status = CommandEncoderStatus::Error;
        let raw_encoder = encoder.open(&cmd_buf.device).map_pass_err(pass_scope)?;

        let mut state = State {
            binder: Binder::new(),
            pipeline: None,
            scope: device.new_usage_scope(),
            debug_scope_depth: 0,

            snatch_guard: device.snatchable_lock.read(),

            device,
            raw_encoder,
            tracker: &mut cmd_buf_data.trackers,
            buffer_memory_init_actions: &mut cmd_buf_data.buffer_memory_init_actions,
            texture_memory_actions: &mut cmd_buf_data.texture_memory_actions,

            temp_offsets: Vec::new(),
            dynamic_offset_count: 0,
            string_offset: 0,
            active_query: None,

            push_constants: Vec::new(),

            intermediate_trackers: Tracker::new(),

            pending_discard_init_fixups: SurfacesInDiscardState::new(),
        };

        let indices = &state.device.tracker_indices;
        state.tracker.buffers.set_size(indices.buffers.size());
        state.tracker.textures.set_size(indices.textures.size());

        let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
            if let Some(tw) = pass.timestamp_writes.take() {
                tw.query_set
                    .same_device_as(cmd_buf.as_ref())
                    .map_pass_err(pass_scope)?;

                let query_set = state.tracker.query_sets.insert_single(tw.query_set);

                // Compute the range of queries this pass will write.
                let range = if let (Some(index_a), Some(index_b)) =
                    (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
                {
                    Some(index_a.min(index_b)..index_a.max(index_b) + 1)
                } else {
                    tw.beginning_of_pass_write_index
                        .or(tw.end_of_pass_write_index)
                        .map(|i| i..i + 1)
                };
                // Reset those queries before the pass writes them, so results
                // never come from stale data.
                if let Some(range) = range {
                    unsafe {
                        state.raw_encoder.reset_queries(query_set.raw(), range);
                    }
                }

                Some(hal::PassTimestampWrites {
                    query_set: query_set.raw(),
                    beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                    end_of_pass_write_index: tw.end_of_pass_write_index,
                })
            } else {
                None
            };

        let hal_desc = hal::ComputePassDescriptor {
            label: hal_label(base.label.as_deref(), device.instance_flags),
            timestamp_writes,
        };

        unsafe {
            state.raw_encoder.begin_compute_pass(&hal_desc);
        }

        for command in base.commands {
            match command {
                ArcComputeCommand::SetBindGroup {
                    index,
                    num_dynamic_offsets,
                    bind_group,
                } => {
                    let scope = PassErrorScope::SetBindGroup;
                    set_bind_group(
                        &mut state,
                        cmd_buf,
                        &base.dynamic_offsets,
                        index,
                        num_dynamic_offsets,
                        bind_group,
                    )
                    .map_pass_err(scope)?;
                }
                ArcComputeCommand::SetPipeline(pipeline) => {
                    let scope = PassErrorScope::SetPipelineCompute;
                    set_pipeline(&mut state, cmd_buf, pipeline).map_pass_err(scope)?;
                }
                ArcComputeCommand::SetPushConstant {
                    offset,
                    size_bytes,
                    values_offset,
                } => {
                    let scope = PassErrorScope::SetPushConstant;
                    set_push_constant(
                        &mut state,
                        &base.push_constant_data,
                        offset,
                        size_bytes,
                        values_offset,
                    )
                    .map_pass_err(scope)?;
                }
                ArcComputeCommand::Dispatch(groups) => {
                    let scope = PassErrorScope::Dispatch { indirect: false };
                    dispatch(&mut state, groups).map_pass_err(scope)?;
                }
                ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                    let scope = PassErrorScope::Dispatch { indirect: true };
                    dispatch_indirect(&mut state, cmd_buf, buffer, offset).map_pass_err(scope)?;
                }
                ArcComputeCommand::PushDebugGroup { color: _, len } => {
                    push_debug_group(&mut state, &base.string_data, len);
                }
                ArcComputeCommand::PopDebugGroup => {
                    let scope = PassErrorScope::PopDebugGroup;
                    pop_debug_group(&mut state).map_pass_err(scope)?;
                }
                ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                    insert_debug_marker(&mut state, &base.string_data, len);
                }
                ArcComputeCommand::WriteTimestamp {
                    query_set,
                    query_index,
                } => {
                    let scope = PassErrorScope::WriteTimestamp;
                    write_timestamp(&mut state, cmd_buf, query_set, query_index)
                        .map_pass_err(scope)?;
                }
                ArcComputeCommand::BeginPipelineStatisticsQuery {
                    query_set,
                    query_index,
                } => {
                    let scope = PassErrorScope::BeginPipelineStatisticsQuery;
                    validate_and_begin_pipeline_statistics_query(
                        query_set,
                        state.raw_encoder,
                        &mut state.tracker.query_sets,
                        cmd_buf,
                        query_index,
                        None,
                        &mut state.active_query,
                    )
                    .map_pass_err(scope)?;
                }
                ArcComputeCommand::EndPipelineStatisticsQuery => {
                    let scope = PassErrorScope::EndPipelineStatisticsQuery;
                    end_pipeline_statistics_query(state.raw_encoder, &mut state.active_query)
                        .map_pass_err(scope)?;
                }
            }
        }

        unsafe {
            state.raw_encoder.end_compute_pass();
        }

        // The pass replayed cleanly; bring the encoder back out of the
        // fail-safe error state.
        *status = CommandEncoderStatus::Recording;

        let State {
            snatch_guard,
            tracker,
            intermediate_trackers,
            pending_discard_init_fixups,
            ..
        } = state;

        // Stop the current command buffer.
        encoder.close(&cmd_buf.device).map_pass_err(pass_scope)?;

        // Create a new command buffer, which we will insert _before_ the pass
        // we just recorded, to hold the texture init fixups and barriers.
        let transit = encoder.open(&cmd_buf.device).map_pass_err(pass_scope)?;
        fixup_discarded_surfaces(
            pending_discard_init_fixups.into_iter(),
            transit,
            &mut tracker.textures,
            device,
            &snatch_guard,
        );
        CommandBuffer::insert_barriers_from_tracker(
            transit,
            tracker,
            &intermediate_trackers,
            &snatch_guard,
        );
        // Close the transit command buffer and swap it in before the pass.
        encoder
            .close_and_swap(&cmd_buf.device)
            .map_pass_err(pass_scope)?;

        Ok(())
    }
}

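// The free functions below implement the individual pass commands during
// replay. `set_bind_group` consumes `num_dynamic_offsets` entries from the
// pass-wide `dynamic_offsets` array: with the cursor at
// `state.dynamic_offset_count == 2` and `num_dynamic_offsets == 3`
// (illustrative numbers), offsets `dynamic_offsets[2..5]` belong to this
// call and the cursor advances to 5.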
fn set_bind_group(
    state: &mut State,
    cmd_buf: &CommandBuffer,
    dynamic_offsets: &[DynamicOffset],
    index: u32,
    num_dynamic_offsets: usize,
    bind_group: Option<Arc<BindGroup>>,
) -> Result<(), ComputePassErrorInner> {
    let max_bind_groups = state.device.limits.max_bind_groups;
    if index >= max_bind_groups {
        return Err(ComputePassErrorInner::BindGroupIndexOutOfRange {
            index,
            max: max_bind_groups,
        });
    }

    state.temp_offsets.clear();
    state.temp_offsets.extend_from_slice(
        &dynamic_offsets
            [state.dynamic_offset_count..state.dynamic_offset_count + num_dynamic_offsets],
    );
    state.dynamic_offset_count += num_dynamic_offsets;

    if bind_group.is_none() {
        return Ok(());
    }

    let bind_group = bind_group.unwrap();
    let bind_group = state.tracker.bind_groups.insert_single(bind_group);

    bind_group.same_device_as(cmd_buf)?;

    bind_group.validate_dynamic_bindings(index, &state.temp_offsets)?;

    state
        .buffer_memory_init_actions
        .extend(bind_group.used_buffer_ranges.iter().filter_map(|action| {
            action
                .buffer
                .initialization_status
                .read()
                .check_action(action)
        }));

    for action in bind_group.used_texture_ranges.iter() {
        state
            .pending_discard_init_fixups
            .extend(state.texture_memory_actions.register_init_action(action));
    }

    let pipeline_layout = state.binder.pipeline_layout.clone();
    let entries = state
        .binder
        .assign_group(index as usize, bind_group, &state.temp_offsets);
    if !entries.is_empty() && pipeline_layout.is_some() {
        let pipeline_layout = pipeline_layout.as_ref().unwrap().raw();
        for (i, e) in entries.iter().enumerate() {
            if let Some(group) = e.group.as_ref() {
                let raw_bg = group.try_raw(&state.snatch_guard)?;
                unsafe {
                    state.raw_encoder.set_bind_group(
                        pipeline_layout,
                        index + i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }
    }
    Ok(())
}

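// Switching pipelines may invalidate previously bound groups if the new
// layout differs; any group whose slot changed is rebound below, and the
// push constant ranges of the new layout are zero-initialized so a pipeline
// never observes values written for an incompatible layout.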
fn set_pipeline(
    state: &mut State,
    cmd_buf: &CommandBuffer,
    pipeline: Arc<ComputePipeline>,
) -> Result<(), ComputePassErrorInner> {
    pipeline.same_device_as(cmd_buf)?;

    state.pipeline = Some(pipeline.clone());

    let pipeline = state.tracker.compute_pipelines.insert_single(pipeline);

    unsafe {
        state.raw_encoder.set_compute_pipeline(pipeline.raw());
    }

    // Rebind resources if the layout changed.
    if state.binder.pipeline_layout.is_none()
        || !state
            .binder
            .pipeline_layout
            .as_ref()
            .unwrap()
            .is_equal(&pipeline.layout)
    {
        let (start_index, entries) = state
            .binder
            .change_pipeline_layout(&pipeline.layout, &pipeline.late_sized_buffer_groups);
        if !entries.is_empty() {
            for (i, e) in entries.iter().enumerate() {
                if let Some(group) = e.group.as_ref() {
                    let raw_bg = group.try_raw(&state.snatch_guard)?;
                    unsafe {
                        state.raw_encoder.set_bind_group(
                            pipeline.layout.raw(),
                            start_index as u32 + i as u32,
                            Some(raw_bg),
                            &e.dynamic_offsets,
                        );
                    }
                }
            }
        }

        // Rebuild the CPU-side push constant staging copy for the new layout.
        state.push_constants.clear();
        // Note that there can be at most one compute push constant range.
        if let Some(push_constant_range) =
            pipeline.layout.push_constant_ranges.iter().find_map(|pcr| {
                pcr.stages
                    .contains(wgt::ShaderStages::COMPUTE)
                    .then_some(pcr.range.clone())
            })
        {
            let len = push_constant_range.len() / wgt::PUSH_CONSTANT_ALIGNMENT as usize;
            state.push_constants.extend(core::iter::repeat(0).take(len));
        }

        // Clear the push constant ranges on the raw encoder so stale values
        // from a previous layout cannot leak through.
        let non_overlapping =
            super::bind::compute_nonoverlapping_ranges(&pipeline.layout.push_constant_ranges);
        for range in non_overlapping {
            let offset = range.range.start;
            let size_bytes = range.range.end - offset;
            super::push_constant_clear(offset, size_bytes, |clear_offset, clear_data| unsafe {
                state.raw_encoder.set_push_constants(
                    pipeline.layout.raw(),
                    wgt::ShaderStages::COMPUTE,
                    clear_offset,
                    clear_data,
                );
            });
        }
    }
    Ok(())
}

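// `offset`, `size_bytes`, and `values_offset` arrive in different units:
// the first two are byte counts, while `values_offset` indexes the `u32`
// array `push_constant_data`. E.g. (illustrative numbers) `offset = 8`,
// `size_bytes = 16` writes words 2..6 of the staging copy.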
fn set_push_constant(
    state: &mut State,
    push_constant_data: &[u32],
    offset: u32,
    size_bytes: u32,
    values_offset: u32,
) -> Result<(), ComputePassErrorInner> {
    let end_offset_bytes = offset + size_bytes;
    let values_end_offset = (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
    let data_slice = &push_constant_data[(values_offset as usize)..values_end_offset];

    let pipeline_layout = state
        .binder
        .pipeline_layout
        .as_ref()
        // Push constants are validated against the bound pipeline's layout,
        // so a pipeline must be set first.
        .ok_or(ComputePassErrorInner::Dispatch(
            DispatchError::MissingPipeline,
        ))?;

    pipeline_layout.validate_push_constant_ranges(
        wgt::ShaderStages::COMPUTE,
        offset,
        end_offset_bytes,
    )?;

    let offset_in_elements = (offset / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
    let size_in_elements = (size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
    // Keep the CPU-side copy in sync; it is re-applied after the internal
    // dispatch performed by indirect validation.
    state.push_constants[offset_in_elements..][..size_in_elements].copy_from_slice(data_slice);

    unsafe {
        state.raw_encoder.set_push_constants(
            pipeline_layout.raw(),
            wgt::ShaderStages::COMPUTE,
            offset,
            data_slice,
        );
    }
    Ok(())
}

fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
    state.is_ready()?;

    state.flush_states(None)?;

    let groups_size_limit = state.device.limits.max_compute_workgroups_per_dimension;

    if groups[0] > groups_size_limit
        || groups[1] > groups_size_limit
        || groups[2] > groups_size_limit
    {
        return Err(ComputePassErrorInner::Dispatch(
            DispatchError::InvalidGroupSize {
                current: groups,
                limit: groups_size_limit,
            },
        ));
    }

    unsafe {
        state.raw_encoder.dispatch(groups);
    }
    Ok(())
}

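// Indirect dispatch reads its workgroup counts from a GPU buffer, so the
// limits check above cannot happen on the CPU. With the
// `indirect-validation` feature, an internal compute pass first validates
// the args against the device limits and the real dispatch reads from an
// internal destination buffer; without it, the user's buffer is used
// directly.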
fn dispatch_indirect(
    state: &mut State,
    cmd_buf: &CommandBuffer,
    buffer: Arc<Buffer>,
    offset: u64,
) -> Result<(), ComputePassErrorInner> {
    buffer.same_device_as(cmd_buf)?;

    state.is_ready()?;

    state
        .device
        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;

    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;

    if offset % 4 != 0 {
        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
    }

    let end_offset = offset + size_of::<wgt::DispatchIndirectArgs>() as u64;
    if end_offset > buffer.size {
        return Err(ComputePassErrorInner::IndirectBufferOverrun {
            offset,
            end_offset,
            buffer_size: buffer.size,
        });
    }

    let stride = 3 * 4; // three u32 words: x/y/z workgroup counts
    state
        .buffer_memory_init_actions
        .extend(buffer.initialization_status.read().create_action(
            &buffer,
            offset..(offset + stride),
            MemoryInitKind::NeedsInitializedMemory,
        ));

    #[cfg(feature = "indirect-validation")]
    {
        let params = state.device.indirect_validation.as_ref().unwrap().params(
            &state.device.limits,
            offset,
            buffer.size,
        );

        unsafe {
            state.raw_encoder.set_compute_pipeline(params.pipeline);
        }

        unsafe {
            state.raw_encoder.set_push_constants(
                params.pipeline_layout,
                wgt::ShaderStages::COMPUTE,
                0,
                &[params.offset_remainder as u32 / 4],
            );
        }

        unsafe {
            state.raw_encoder.set_bind_group(
                params.pipeline_layout,
                0,
                Some(params.dst_bind_group),
                &[],
            );
        }
        unsafe {
            state.raw_encoder.set_bind_group(
                params.pipeline_layout,
                1,
                Some(
                    buffer
                        .raw_indirect_validation_bind_group
                        .get(&state.snatch_guard)
                        .unwrap()
                        .as_ref(),
                ),
                &[params.aligned_offset as u32],
            );
        }

        let src_transition = state
            .intermediate_trackers
            .buffers
            .set_single(&buffer, hal::BufferUses::STORAGE_READ);
        let src_barrier =
            src_transition.map(|transition| transition.into_hal(&buffer, &state.snatch_guard));
        unsafe {
            state.raw_encoder.transition_buffers(src_barrier.as_slice());
        }

        unsafe {
            state.raw_encoder.transition_buffers(&[hal::BufferBarrier {
                buffer: params.dst_buffer,
                usage: hal::BufferUses::INDIRECT..hal::BufferUses::STORAGE_READ_WRITE,
            }]);
        }

        // Run the validation shader, which writes sanitized dispatch args
        // into `dst_buffer`.
        unsafe {
            state.raw_encoder.dispatch([1, 1, 1]);
        }

        // Restore the state the validation dispatch clobbered: the user's
        // pipeline, push constants, and bind groups.
        {
            let pipeline = state.pipeline.as_ref().unwrap();

            unsafe {
                state.raw_encoder.set_compute_pipeline(pipeline.raw());
            }

            if !state.push_constants.is_empty() {
                unsafe {
                    state.raw_encoder.set_push_constants(
                        pipeline.layout.raw(),
                        wgt::ShaderStages::COMPUTE,
                        0,
                        &state.push_constants,
                    );
                }
            }

            for (i, e) in state.binder.list_valid() {
                let group = e.group.as_ref().unwrap();
                let raw_bg = group.try_raw(&state.snatch_guard)?;
                unsafe {
                    state.raw_encoder.set_bind_group(
                        pipeline.layout.raw(),
                        i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }

        unsafe {
            state.raw_encoder.transition_buffers(&[hal::BufferBarrier {
                buffer: params.dst_buffer,
                usage: hal::BufferUses::STORAGE_READ_WRITE..hal::BufferUses::INDIRECT,
            }]);
        }

        state.flush_states(None)?;
        unsafe {
            state.raw_encoder.dispatch_indirect(params.dst_buffer, 0);
        }
    };
    #[cfg(not(feature = "indirect-validation"))]
    {
        state
            .scope
            .buffers
            .merge_single(&buffer, hal::BufferUses::INDIRECT)?;

        use crate::resource::Trackable;
        state.flush_states(Some(buffer.tracker_index()))?;

        let buf_raw = buffer.try_raw(&state.snatch_guard)?;
        unsafe {
            state.raw_encoder.dispatch_indirect(buf_raw, offset);
        }
    }

    Ok(())
}

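// Debug labels are stored back-to-back in the pass's `string_data` array;
// each command records only a byte length, and `state.string_offset` walks
// the array during replay.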
fn push_debug_group(state: &mut State, string_data: &[u8], len: usize) {
    state.debug_scope_depth += 1;
    if !state
        .device
        .instance_flags
        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
    {
        let label =
            str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
        unsafe {
            state.raw_encoder.begin_debug_marker(label);
        }
    }
    state.string_offset += len;
}

fn pop_debug_group(state: &mut State) -> Result<(), ComputePassErrorInner> {
    if state.debug_scope_depth == 0 {
        return Err(ComputePassErrorInner::InvalidPopDebugGroup);
    }
    state.debug_scope_depth -= 1;
    if !state
        .device
        .instance_flags
        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
    {
        unsafe {
            state.raw_encoder.end_debug_marker();
        }
    }
    Ok(())
}

fn insert_debug_marker(state: &mut State, string_data: &[u8], len: usize) {
    if !state
        .device
        .instance_flags
        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
    {
        let label =
            str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
        unsafe { state.raw_encoder.insert_debug_marker(label) }
    }
    state.string_offset += len;
}

fn write_timestamp(
    state: &mut State,
    cmd_buf: &CommandBuffer,
    query_set: Arc<resource::QuerySet>,
    query_index: u32,
) -> Result<(), ComputePassErrorInner> {
    query_set.same_device_as(cmd_buf)?;

    state
        .device
        .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)?;

    let query_set = state.tracker.query_sets.insert_single(query_set);

    query_set.validate_and_write_timestamp(state.raw_encoder, query_index, None)?;
    Ok(())
}

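// Recording entry points. These validate ids, resolve them to `Arc`s, and
// buffer an `ArcComputeCommand`; no hal work happens until
// `compute_pass_end` replays the buffer. A rough usage sketch (ids assumed
// valid):
//
//     global.compute_pass_set_pipeline(&mut pass, pipeline_id)?;
//     global.compute_pass_dispatch_workgroups(&mut pass, 16, 16, 1)?;
//     global.compute_pass_end(&mut pass)?;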
impl Global {
    pub fn compute_pass_set_bind_group(
        &self,
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: Option<id::BindGroupId>,
        offsets: &[DynamicOffset],
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::SetBindGroup;
        let base = pass
            .base
            .as_mut()
            .ok_or(ComputePassErrorInner::PassEnded)
            .map_pass_err(scope)?;

        let redundant = pass.current_bind_groups.set_and_check_redundant(
            bind_group_id,
            index,
            &mut base.dynamic_offsets,
            offsets,
        );

        if redundant {
            return Ok(());
        }

        let mut bind_group = None;
        if bind_group_id.is_some() {
            let bind_group_id = bind_group_id.unwrap();

            let hub = &self.hub;
            let bg = hub
                .bind_groups
                .get(bind_group_id)
                .get()
                .map_pass_err(scope)?;
            bind_group = Some(bg);
        }

        base.commands.push(ArcComputeCommand::SetBindGroup {
            index,
            num_dynamic_offsets: offsets.len(),
            bind_group,
        });

        Ok(())
    }

    pub fn compute_pass_set_pipeline(
        &self,
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) -> Result<(), ComputePassError> {
        let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);

        let scope = PassErrorScope::SetPipelineCompute;

        let base = pass.base_mut(scope)?;
        if redundant {
            return Ok(());
        }

        let hub = &self.hub;
        let pipeline = hub
            .compute_pipelines
            .get(pipeline_id)
            .get()
            .map_pass_err(scope)?;

        base.commands.push(ArcComputeCommand::SetPipeline(pipeline));

        Ok(())
    }

    pub fn compute_pass_set_push_constants(
        &self,
        pass: &mut ComputePass,
        offset: u32,
        data: &[u8],
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::SetPushConstant;
        let base = pass.base_mut(scope)?;

        if offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            return Err(ComputePassErrorInner::PushConstantOffsetAlignment).map_pass_err(scope);
        }

        if data.len() as u32 & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            return Err(ComputePassErrorInner::PushConstantSizeAlignment).map_pass_err(scope);
        }

        let value_offset = base
            .push_constant_data
            .len()
            .try_into()
            .map_err(|_| ComputePassErrorInner::PushConstantOutOfMemory)
            .map_pass_err(scope)?;

        base.push_constant_data.extend(
            data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        base.commands.push(ArcComputeCommand::SetPushConstant {
            offset,
            size_bytes: data.len() as u32,
            values_offset: value_offset,
        });

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups(
        &self,
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::Dispatch { indirect: false };

        let base = pass.base_mut(scope)?;
        base.commands
            .push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups_indirect(
        &self,
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) -> Result<(), ComputePassError> {
        let hub = &self.hub;
        let scope = PassErrorScope::Dispatch { indirect: true };
        let base = pass.base_mut(scope)?;

        let buffer = hub.buffers.get(buffer_id).get().map_pass_err(scope)?;

        base.commands
            .push(ArcComputeCommand::DispatchIndirect { buffer, offset });

        Ok(())
    }

    pub fn compute_pass_push_debug_group(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), ComputePassError> {
        let base = pass.base_mut(PassErrorScope::PushDebugGroup)?;

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_pop_debug_group(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), ComputePassError> {
        let base = pass.base_mut(PassErrorScope::PopDebugGroup)?;

        base.commands.push(ArcComputeCommand::PopDebugGroup);

        Ok(())
    }

    pub fn compute_pass_insert_debug_marker(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), ComputePassError> {
        let base = pass.base_mut(PassErrorScope::InsertDebugMarker)?;

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_write_timestamp(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::WriteTimestamp;
        let base = pass.base_mut(scope)?;

        let hub = &self.hub;
        let query_set = hub.query_sets.get(query_set_id).get().map_pass_err(scope)?;

        base.commands.push(ArcComputeCommand::WriteTimestamp {
            query_set,
            query_index,
        });

        Ok(())
    }

    pub fn compute_pass_begin_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
        let base = pass.base_mut(scope)?;

        let hub = &self.hub;
        let query_set = hub.query_sets.get(query_set_id).get().map_pass_err(scope)?;

        base.commands
            .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            });

        Ok(())
    }

    pub fn compute_pass_end_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::EndPipelineStatisticsQuery;
        let base = pass.base_mut(scope)?;
        base.commands
            .push(ArcComputeCommand::EndPipelineStatisticsQuery);

        Ok(())
    }
}