// wgpu_hal/dynamic/command.rs

1use std::ops::Range;
2
3use crate::{
4    AccelerationStructureBarrier, Api, Attachment, BufferBarrier, BufferBinding, BufferCopy,
5    BufferTextureCopy, BuildAccelerationStructureDescriptor, ColorAttachment, CommandEncoder,
6    ComputePassDescriptor, DepthStencilAttachment, DeviceError, Label, MemoryRange,
7    PassTimestampWrites, Rect, RenderPassDescriptor, TextureBarrier, TextureCopy, TextureUses,
8};
9
10use super::{
11    DynAccelerationStructure, DynBindGroup, DynBuffer, DynCommandBuffer, DynComputePipeline,
12    DynPipelineLayout, DynQuerySet, DynRenderPipeline, DynResource, DynResourceExt as _,
13    DynTexture, DynTextureView,
14};
15
16pub trait DynCommandEncoder: DynResource + std::fmt::Debug {
17    unsafe fn begin_encoding(&mut self, label: Label) -> Result<(), DeviceError>;
18
19    unsafe fn discard_encoding(&mut self);
20
21    unsafe fn end_encoding(&mut self) -> Result<Box<dyn DynCommandBuffer>, DeviceError>;
22
23    unsafe fn reset_all(&mut self, command_buffers: Vec<Box<dyn DynCommandBuffer>>);
24
25    unsafe fn transition_buffers(&mut self, barriers: &[BufferBarrier<'_, dyn DynBuffer>]);
26    unsafe fn transition_textures(&mut self, barriers: &[TextureBarrier<'_, dyn DynTexture>]);
27
28    unsafe fn clear_buffer(&mut self, buffer: &dyn DynBuffer, range: MemoryRange);
29
30    unsafe fn copy_buffer_to_buffer(
31        &mut self,
32        src: &dyn DynBuffer,
33        dst: &dyn DynBuffer,
34        regions: &[BufferCopy],
35    );
36
37    unsafe fn copy_texture_to_texture(
38        &mut self,
39        src: &dyn DynTexture,
40        src_usage: TextureUses,
41        dst: &dyn DynTexture,
42        regions: &[TextureCopy],
43    );
44
45    unsafe fn copy_buffer_to_texture(
46        &mut self,
47        src: &dyn DynBuffer,
48        dst: &dyn DynTexture,
49        regions: &[BufferTextureCopy],
50    );
51
52    unsafe fn copy_texture_to_buffer(
53        &mut self,
54        src: &dyn DynTexture,
55        src_usage: TextureUses,
56        dst: &dyn DynBuffer,
57        regions: &[BufferTextureCopy],
58    );
59
60    unsafe fn set_bind_group(
61        &mut self,
62        layout: &dyn DynPipelineLayout,
63        index: u32,
64        group: Option<&dyn DynBindGroup>,
65        dynamic_offsets: &[wgt::DynamicOffset],
66    );
67
68    unsafe fn set_push_constants(
69        &mut self,
70        layout: &dyn DynPipelineLayout,
71        stages: wgt::ShaderStages,
72        offset_bytes: u32,
73        data: &[u32],
74    );
75
76    unsafe fn insert_debug_marker(&mut self, label: &str);
77    unsafe fn begin_debug_marker(&mut self, group_label: &str);
78    unsafe fn end_debug_marker(&mut self);
79
80    unsafe fn begin_query(&mut self, set: &dyn DynQuerySet, index: u32);
81    unsafe fn end_query(&mut self, set: &dyn DynQuerySet, index: u32);
82    unsafe fn write_timestamp(&mut self, set: &dyn DynQuerySet, index: u32);
83    unsafe fn reset_queries(&mut self, set: &dyn DynQuerySet, range: Range<u32>);
84    unsafe fn copy_query_results(
85        &mut self,
86        set: &dyn DynQuerySet,
87        range: Range<u32>,
88        buffer: &dyn DynBuffer,
89        offset: wgt::BufferAddress,
90        stride: wgt::BufferSize,
91    );
92
93    unsafe fn begin_render_pass(
94        &mut self,
95        desc: &RenderPassDescriptor<dyn DynQuerySet, dyn DynTextureView>,
96    );
97    unsafe fn end_render_pass(&mut self);
98
99    unsafe fn set_render_pipeline(&mut self, pipeline: &dyn DynRenderPipeline);
100
101    unsafe fn set_index_buffer<'a>(
102        &mut self,
103        binding: BufferBinding<'a, dyn DynBuffer>,
104        format: wgt::IndexFormat,
105    );
106
107    unsafe fn set_vertex_buffer<'a>(
108        &mut self,
109        index: u32,
110        binding: BufferBinding<'a, dyn DynBuffer>,
111    );
112    unsafe fn set_viewport(&mut self, rect: &Rect<f32>, depth_range: Range<f32>);
113    unsafe fn set_scissor_rect(&mut self, rect: &Rect<u32>);
114    unsafe fn set_stencil_reference(&mut self, value: u32);
115    unsafe fn set_blend_constants(&mut self, color: &[f32; 4]);
116
117    unsafe fn draw(
118        &mut self,
119        first_vertex: u32,
120        vertex_count: u32,
121        first_instance: u32,
122        instance_count: u32,
123    );
124    unsafe fn draw_indexed(
125        &mut self,
126        first_index: u32,
127        index_count: u32,
128        base_vertex: i32,
129        first_instance: u32,
130        instance_count: u32,
131    );
132    unsafe fn draw_indirect(
133        &mut self,
134        buffer: &dyn DynBuffer,
135        offset: wgt::BufferAddress,
136        draw_count: u32,
137    );
138    unsafe fn draw_indexed_indirect(
139        &mut self,
140        buffer: &dyn DynBuffer,
141        offset: wgt::BufferAddress,
142        draw_count: u32,
143    );
144    unsafe fn draw_indirect_count(
145        &mut self,
146        buffer: &dyn DynBuffer,
147        offset: wgt::BufferAddress,
148        count_buffer: &dyn DynBuffer,
149        count_offset: wgt::BufferAddress,
150        max_count: u32,
151    );
152    unsafe fn draw_indexed_indirect_count(
153        &mut self,
154        buffer: &dyn DynBuffer,
155        offset: wgt::BufferAddress,
156        count_buffer: &dyn DynBuffer,
157        count_offset: wgt::BufferAddress,
158        max_count: u32,
159    );
160
161    unsafe fn begin_compute_pass(&mut self, desc: &ComputePassDescriptor<dyn DynQuerySet>);
162    unsafe fn end_compute_pass(&mut self);
163
164    unsafe fn set_compute_pipeline(&mut self, pipeline: &dyn DynComputePipeline);
165
166    unsafe fn dispatch(&mut self, count: [u32; 3]);
167    unsafe fn dispatch_indirect(&mut self, buffer: &dyn DynBuffer, offset: wgt::BufferAddress);
168
169    unsafe fn build_acceleration_structures<'a>(
170        &mut self,
171        descriptors: &'a [BuildAccelerationStructureDescriptor<
172            'a,
173            dyn DynBuffer,
174            dyn DynAccelerationStructure,
175        >],
176    );
177
178    unsafe fn place_acceleration_structure_barrier(
179        &mut self,
180        barrier: AccelerationStructureBarrier,
181    );
182}
183
184impl<C: CommandEncoder + DynResource> DynCommandEncoder for C {
185    unsafe fn begin_encoding(&mut self, label: Label) -> Result<(), DeviceError> {
186        unsafe { C::begin_encoding(self, label) }
187    }
188
189    unsafe fn discard_encoding(&mut self) {
190        unsafe { C::discard_encoding(self) }
191    }
192
193    unsafe fn end_encoding(&mut self) -> Result<Box<dyn DynCommandBuffer>, DeviceError> {
194        unsafe { C::end_encoding(self) }.map(|cb| {
195            let boxed_command_buffer: Box<<C::A as Api>::CommandBuffer> = Box::new(cb);
196            let boxed_command_buffer: Box<dyn DynCommandBuffer> = boxed_command_buffer;
197            boxed_command_buffer
198        })
199    }
200
201    unsafe fn reset_all(&mut self, command_buffers: Vec<Box<dyn DynCommandBuffer>>) {
202        unsafe { C::reset_all(self, command_buffers.into_iter().map(|cb| cb.unbox())) }
203    }
204
205    unsafe fn transition_buffers(&mut self, barriers: &[BufferBarrier<'_, dyn DynBuffer>]) {
206        let barriers = barriers.iter().map(|barrier| BufferBarrier {
207            buffer: barrier.buffer.expect_downcast_ref(),
208            usage: barrier.usage.clone(),
209        });
210        unsafe { self.transition_buffers(barriers) };
211    }
212
213    unsafe fn transition_textures(&mut self, barriers: &[TextureBarrier<'_, dyn DynTexture>]) {
214        let barriers = barriers.iter().map(|barrier| TextureBarrier {
215            texture: barrier.texture.expect_downcast_ref(),
216            usage: barrier.usage.clone(),
217            range: barrier.range,
218        });
219        unsafe { self.transition_textures(barriers) };
220    }
221
222    unsafe fn clear_buffer(&mut self, buffer: &dyn DynBuffer, range: MemoryRange) {
223        let buffer = buffer.expect_downcast_ref();
224        unsafe { C::clear_buffer(self, buffer, range) };
225    }
226
227    unsafe fn copy_buffer_to_buffer(
228        &mut self,
229        src: &dyn DynBuffer,
230        dst: &dyn DynBuffer,
231        regions: &[BufferCopy],
232    ) {
233        let src = src.expect_downcast_ref();
234        let dst = dst.expect_downcast_ref();
235        unsafe {
236            C::copy_buffer_to_buffer(self, src, dst, regions.iter().copied());
237        }
238    }
239
240    unsafe fn copy_texture_to_texture(
241        &mut self,
242        src: &dyn DynTexture,
243        src_usage: TextureUses,
244        dst: &dyn DynTexture,
245        regions: &[TextureCopy],
246    ) {
247        let src = src.expect_downcast_ref();
248        let dst = dst.expect_downcast_ref();
249        unsafe {
250            C::copy_texture_to_texture(self, src, src_usage, dst, regions.iter().cloned());
251        }
252    }
253
254    unsafe fn copy_buffer_to_texture(
255        &mut self,
256        src: &dyn DynBuffer,
257        dst: &dyn DynTexture,
258        regions: &[BufferTextureCopy],
259    ) {
260        let src = src.expect_downcast_ref();
261        let dst = dst.expect_downcast_ref();
262        unsafe {
263            C::copy_buffer_to_texture(self, src, dst, regions.iter().cloned());
264        }
265    }
266
267    unsafe fn copy_texture_to_buffer(
268        &mut self,
269        src: &dyn DynTexture,
270        src_usage: TextureUses,
271        dst: &dyn DynBuffer,
272        regions: &[BufferTextureCopy],
273    ) {
274        let src = src.expect_downcast_ref();
275        let dst = dst.expect_downcast_ref();
276        unsafe {
277            C::copy_texture_to_buffer(self, src, src_usage, dst, regions.iter().cloned());
278        }
279    }
280
281    unsafe fn set_bind_group(
282        &mut self,
283        layout: &dyn DynPipelineLayout,
284        index: u32,
285        group: Option<&dyn DynBindGroup>,
286        dynamic_offsets: &[wgt::DynamicOffset],
287    ) {
288        if group.is_none() {
289            // TODO: Handle group None correctly.
290            return;
291        }
292        let group = group.unwrap();
293
294        let layout = layout.expect_downcast_ref();
295        let group = group.expect_downcast_ref();
296        unsafe { C::set_bind_group(self, layout, index, group, dynamic_offsets) };
297    }
298
299    unsafe fn set_push_constants(
300        &mut self,
301        layout: &dyn DynPipelineLayout,
302        stages: wgt::ShaderStages,
303        offset_bytes: u32,
304        data: &[u32],
305    ) {
306        let layout = layout.expect_downcast_ref();
307        unsafe { C::set_push_constants(self, layout, stages, offset_bytes, data) };
308    }
309
310    unsafe fn insert_debug_marker(&mut self, label: &str) {
311        unsafe {
312            C::insert_debug_marker(self, label);
313        }
314    }
315
316    unsafe fn begin_debug_marker(&mut self, group_label: &str) {
317        unsafe {
318            C::begin_debug_marker(self, group_label);
319        }
320    }
321
322    unsafe fn end_debug_marker(&mut self) {
323        unsafe {
324            C::end_debug_marker(self);
325        }
326    }
327
328    unsafe fn begin_query(&mut self, set: &dyn DynQuerySet, index: u32) {
329        let set = set.expect_downcast_ref();
330        unsafe { C::begin_query(self, set, index) };
331    }
332
333    unsafe fn end_query(&mut self, set: &dyn DynQuerySet, index: u32) {
334        let set = set.expect_downcast_ref();
335        unsafe { C::end_query(self, set, index) };
336    }
337
338    unsafe fn write_timestamp(&mut self, set: &dyn DynQuerySet, index: u32) {
339        let set = set.expect_downcast_ref();
340        unsafe { C::write_timestamp(self, set, index) };
341    }
342
343    unsafe fn reset_queries(&mut self, set: &dyn DynQuerySet, range: Range<u32>) {
344        let set = set.expect_downcast_ref();
345        unsafe { C::reset_queries(self, set, range) };
346    }
347
348    unsafe fn copy_query_results(
349        &mut self,
350        set: &dyn DynQuerySet,
351        range: Range<u32>,
352        buffer: &dyn DynBuffer,
353        offset: wgt::BufferAddress,
354        stride: wgt::BufferSize,
355    ) {
356        let set = set.expect_downcast_ref();
357        let buffer = buffer.expect_downcast_ref();
358        unsafe { C::copy_query_results(self, set, range, buffer, offset, stride) };
359    }
360
361    unsafe fn begin_render_pass(
362        &mut self,
363        desc: &RenderPassDescriptor<dyn DynQuerySet, dyn DynTextureView>,
364    ) {
365        let color_attachments = desc
366            .color_attachments
367            .iter()
368            .map(|attachment| {
369                attachment
370                    .as_ref()
371                    .map(|attachment| attachment.expect_downcast())
372            })
373            .collect::<Vec<_>>();
374
375        let desc: RenderPassDescriptor<<C::A as Api>::QuerySet, <C::A as Api>::TextureView> =
376            RenderPassDescriptor {
377                label: desc.label,
378                extent: desc.extent,
379                sample_count: desc.sample_count,
380                color_attachments: &color_attachments,
381                depth_stencil_attachment: desc
382                    .depth_stencil_attachment
383                    .as_ref()
384                    .map(|ds| ds.expect_downcast()),
385                multiview: desc.multiview,
386                timestamp_writes: desc
387                    .timestamp_writes
388                    .as_ref()
389                    .map(|writes| writes.expect_downcast()),
390                occlusion_query_set: desc
391                    .occlusion_query_set
392                    .map(|set| set.expect_downcast_ref()),
393            };
394        unsafe { C::begin_render_pass(self, &desc) };
395    }
396
397    unsafe fn end_render_pass(&mut self) {
398        unsafe {
399            C::end_render_pass(self);
400        }
401    }
402
403    unsafe fn set_viewport(&mut self, rect: &Rect<f32>, depth_range: Range<f32>) {
404        unsafe {
405            C::set_viewport(self, rect, depth_range);
406        }
407    }
408
409    unsafe fn set_scissor_rect(&mut self, rect: &Rect<u32>) {
410        unsafe {
411            C::set_scissor_rect(self, rect);
412        }
413    }
414
415    unsafe fn set_stencil_reference(&mut self, value: u32) {
416        unsafe {
417            C::set_stencil_reference(self, value);
418        }
419    }
420
421    unsafe fn set_blend_constants(&mut self, color: &[f32; 4]) {
422        unsafe { C::set_blend_constants(self, color) };
423    }
424
425    unsafe fn draw(
426        &mut self,
427        first_vertex: u32,
428        vertex_count: u32,
429        first_instance: u32,
430        instance_count: u32,
431    ) {
432        unsafe {
433            C::draw(
434                self,
435                first_vertex,
436                vertex_count,
437                first_instance,
438                instance_count,
439            )
440        };
441    }
442
443    unsafe fn draw_indexed(
444        &mut self,
445        first_index: u32,
446        index_count: u32,
447        base_vertex: i32,
448        first_instance: u32,
449        instance_count: u32,
450    ) {
451        unsafe {
452            C::draw_indexed(
453                self,
454                first_index,
455                index_count,
456                base_vertex,
457                first_instance,
458                instance_count,
459            )
460        };
461    }
462
463    unsafe fn draw_indirect(
464        &mut self,
465        buffer: &dyn DynBuffer,
466        offset: wgt::BufferAddress,
467        draw_count: u32,
468    ) {
469        let buffer = buffer.expect_downcast_ref();
470        unsafe { C::draw_indirect(self, buffer, offset, draw_count) };
471    }
472
473    unsafe fn draw_indexed_indirect(
474        &mut self,
475        buffer: &dyn DynBuffer,
476        offset: wgt::BufferAddress,
477        draw_count: u32,
478    ) {
479        let buffer = buffer.expect_downcast_ref();
480        unsafe { C::draw_indexed_indirect(self, buffer, offset, draw_count) };
481    }
482
483    unsafe fn draw_indirect_count(
484        &mut self,
485        buffer: &dyn DynBuffer,
486        offset: wgt::BufferAddress,
487        count_buffer: &dyn DynBuffer,
488        count_offset: wgt::BufferAddress,
489        max_count: u32,
490    ) {
491        let buffer = buffer.expect_downcast_ref();
492        let count_buffer = count_buffer.expect_downcast_ref();
493        unsafe {
494            C::draw_indirect_count(self, buffer, offset, count_buffer, count_offset, max_count)
495        };
496    }
497
498    unsafe fn draw_indexed_indirect_count(
499        &mut self,
500        buffer: &dyn DynBuffer,
501        offset: wgt::BufferAddress,
502        count_buffer: &dyn DynBuffer,
503        count_offset: wgt::BufferAddress,
504        max_count: u32,
505    ) {
506        let buffer = buffer.expect_downcast_ref();
507        let count_buffer = count_buffer.expect_downcast_ref();
508        unsafe {
509            C::draw_indexed_indirect_count(
510                self,
511                buffer,
512                offset,
513                count_buffer,
514                count_offset,
515                max_count,
516            )
517        };
518    }
519
520    unsafe fn begin_compute_pass(&mut self, desc: &ComputePassDescriptor<dyn DynQuerySet>) {
521        let desc = ComputePassDescriptor {
522            label: desc.label,
523            timestamp_writes: desc
524                .timestamp_writes
525                .as_ref()
526                .map(|writes| writes.expect_downcast()),
527        };
528        unsafe { C::begin_compute_pass(self, &desc) };
529    }
530
531    unsafe fn end_compute_pass(&mut self) {
532        unsafe { C::end_compute_pass(self) };
533    }
534
535    unsafe fn set_compute_pipeline(&mut self, pipeline: &dyn DynComputePipeline) {
536        let pipeline = pipeline.expect_downcast_ref();
537        unsafe { C::set_compute_pipeline(self, pipeline) };
538    }
539
540    unsafe fn dispatch(&mut self, count: [u32; 3]) {
541        unsafe { C::dispatch(self, count) };
542    }
543
544    unsafe fn dispatch_indirect(&mut self, buffer: &dyn DynBuffer, offset: wgt::BufferAddress) {
545        let buffer = buffer.expect_downcast_ref();
546        unsafe { C::dispatch_indirect(self, buffer, offset) };
547    }
548
549    unsafe fn set_render_pipeline(&mut self, pipeline: &dyn DynRenderPipeline) {
550        let pipeline = pipeline.expect_downcast_ref();
551        unsafe { C::set_render_pipeline(self, pipeline) };
552    }
553
554    unsafe fn set_index_buffer<'a>(
555        &mut self,
556        binding: BufferBinding<'a, dyn DynBuffer>,
557        format: wgt::IndexFormat,
558    ) {
559        let binding = binding.expect_downcast();
560        unsafe { self.set_index_buffer(binding, format) };
561    }
562
563    unsafe fn set_vertex_buffer<'a>(
564        &mut self,
565        index: u32,
566        binding: BufferBinding<'a, dyn DynBuffer>,
567    ) {
568        let binding = binding.expect_downcast();
569        unsafe { self.set_vertex_buffer(index, binding) };
570    }
571
572    unsafe fn build_acceleration_structures<'a>(
573        &mut self,
574        descriptors: &'a [BuildAccelerationStructureDescriptor<
575            'a,
576            dyn DynBuffer,
577            dyn DynAccelerationStructure,
578        >],
579    ) {
580        // Need to collect entries here so we can reference them in the descriptor.
581        // TODO: API should be redesigned to avoid this and other descriptor copies that happen due to the dyn api.
582        let descriptor_entries = descriptors
583            .iter()
584            .map(|d| d.entries.expect_downcast())
585            .collect::<Vec<_>>();
586        let descriptors = descriptors
587            .iter()
588            .zip(descriptor_entries.iter())
589            .map(|(d, entries)| BuildAccelerationStructureDescriptor::<
590                <C::A as Api>::Buffer,
591                <C::A as Api>::AccelerationStructure,
592            > {
593                entries,
594                mode: d.mode,
595                flags: d.flags,
596                source_acceleration_structure: d
597                    .source_acceleration_structure
598                    .map(|a| a.expect_downcast_ref()),
599                destination_acceleration_structure: d
600                    .destination_acceleration_structure
601                    .expect_downcast_ref(),
602                scratch_buffer: d.scratch_buffer.expect_downcast_ref(),
603                scratch_buffer_offset: d.scratch_buffer_offset,
604            });
605        unsafe { C::build_acceleration_structures(self, descriptors.len() as _, descriptors) };
606    }
607
608    unsafe fn place_acceleration_structure_barrier(
609        &mut self,
610        barrier: AccelerationStructureBarrier,
611    ) {
612        unsafe { C::place_acceleration_structure_barrier(self, barrier) };
613    }
614}
615
616impl<'a> PassTimestampWrites<'a, dyn DynQuerySet> {
617    pub fn expect_downcast<B: DynQuerySet>(&self) -> PassTimestampWrites<'a, B> {
618        PassTimestampWrites {
619            query_set: self.query_set.expect_downcast_ref(),
620            beginning_of_pass_write_index: self.beginning_of_pass_write_index,
621            end_of_pass_write_index: self.end_of_pass_write_index,
622        }
623    }
624}
625
626impl<'a> Attachment<'a, dyn DynTextureView> {
627    pub fn expect_downcast<B: DynTextureView>(&self) -> Attachment<'a, B> {
628        Attachment {
629            view: self.view.expect_downcast_ref(),
630            usage: self.usage,
631        }
632    }
633}
634
635impl<'a> ColorAttachment<'a, dyn DynTextureView> {
636    pub fn expect_downcast<B: DynTextureView>(&self) -> ColorAttachment<'a, B> {
637        ColorAttachment {
638            target: self.target.expect_downcast(),
639            resolve_target: self.resolve_target.as_ref().map(|rt| rt.expect_downcast()),
640            ops: self.ops,
641            clear_value: self.clear_value,
642        }
643    }
644}
645
646impl<'a> DepthStencilAttachment<'a, dyn DynTextureView> {
647    pub fn expect_downcast<B: DynTextureView>(&self) -> DepthStencilAttachment<'a, B> {
648        DepthStencilAttachment {
649            target: self.target.expect_downcast(),
650            depth_ops: self.depth_ops,
651            stencil_ops: self.stencil_ops,
652            clear_value: self.clear_value,
653        }
654    }
655}