wgpu_core/command/
query.rs

1#[cfg(feature = "trace")]
2use crate::device::trace::Command as TraceCommand;
3use crate::{
4    command::{CommandBuffer, CommandEncoderError},
5    device::{DeviceError, MissingFeatures},
6    global::Global,
7    id,
8    init_tracker::MemoryInitKind,
9    resource::{
10        DestroyedResourceError, InvalidResourceError, MissingBufferUsageError, ParentDevice,
11        QuerySet, Trackable,
12    },
13    track::{StatelessTracker, TrackerIndex},
14    FastHashMap,
15};
16use std::{iter, sync::Arc};
17use thiserror::Error;
18use wgt::BufferAddress;
19
20#[derive(Debug)]
21pub(crate) struct QueryResetMap {
22    map: FastHashMap<TrackerIndex, (Vec<bool>, Arc<QuerySet>)>,
23}
24impl QueryResetMap {
25    pub fn new() -> Self {
26        Self {
27            map: FastHashMap::default(),
28        }
29    }
30
31    pub fn use_query_set(&mut self, query_set: &Arc<QuerySet>, query: u32) -> bool {
32        let vec_pair = self
33            .map
34            .entry(query_set.tracker_index())
35            .or_insert_with(|| {
36                (
37                    vec![false; query_set.desc.count as usize],
38                    query_set.clone(),
39                )
40            });
41
42        std::mem::replace(&mut vec_pair.0[query as usize], true)
43    }
44
45    pub fn reset_queries(&mut self, raw_encoder: &mut dyn hal::DynCommandEncoder) {
46        for (_, (state, query_set)) in self.map.drain() {
47            debug_assert_eq!(state.len(), query_set.desc.count as usize);
48
49            // Need to find all "runs" of values which need resets. If the state vector is:
50            // [false, true, true, false, true], we want to reset [1..3, 4..5]. This minimizes
51            // the amount of resets needed.
52            let mut run_start: Option<u32> = None;
53            for (idx, value) in state.into_iter().chain(iter::once(false)).enumerate() {
54                match (run_start, value) {
55                    // We're inside of a run, do nothing
56                    (Some(..), true) => {}
57                    // We've hit the end of a run, dispatch a reset
58                    (Some(start), false) => {
59                        run_start = None;
60                        unsafe { raw_encoder.reset_queries(query_set.raw(), start..idx as u32) };
61                    }
62                    // We're starting a run
63                    (None, true) => {
64                        run_start = Some(idx as u32);
65                    }
66                    // We're in a run of falses, do nothing.
67                    (None, false) => {}
68                }
69            }
70        }
71    }
72}
73
74#[derive(Debug, Copy, Clone, PartialEq, Eq)]
75pub enum SimplifiedQueryType {
76    Occlusion,
77    Timestamp,
78    PipelineStatistics,
79}
80impl From<wgt::QueryType> for SimplifiedQueryType {
81    fn from(q: wgt::QueryType) -> Self {
82        match q {
83            wgt::QueryType::Occlusion => SimplifiedQueryType::Occlusion,
84            wgt::QueryType::Timestamp => SimplifiedQueryType::Timestamp,
85            wgt::QueryType::PipelineStatistics(..) => SimplifiedQueryType::PipelineStatistics,
86        }
87    }
88}
89
90/// Error encountered when dealing with queries
91#[derive(Clone, Debug, Error)]
92#[non_exhaustive]
93pub enum QueryError {
94    #[error(transparent)]
95    Device(#[from] DeviceError),
96    #[error(transparent)]
97    Encoder(#[from] CommandEncoderError),
98    #[error(transparent)]
99    MissingFeature(#[from] MissingFeatures),
100    #[error("Error encountered while trying to use queries")]
101    Use(#[from] QueryUseError),
102    #[error("Error encountered while trying to resolve a query")]
103    Resolve(#[from] ResolveError),
104    #[error(transparent)]
105    DestroyedResource(#[from] DestroyedResourceError),
106    #[error(transparent)]
107    InvalidResource(#[from] InvalidResourceError),
108}
109
110/// Error encountered while trying to use queries
111#[derive(Clone, Debug, Error)]
112#[non_exhaustive]
113pub enum QueryUseError {
114    #[error(transparent)]
115    Device(#[from] DeviceError),
116    #[error("Query {query_index} is out of bounds for a query set of size {query_set_size}")]
117    OutOfBounds {
118        query_index: u32,
119        query_set_size: u32,
120    },
121    #[error("Query {query_index} has already been used within the same renderpass. Queries must only be used once per renderpass")]
122    UsedTwiceInsideRenderpass { query_index: u32 },
123    #[error("Query {new_query_index} was started while query {active_query_index} was already active. No more than one statistic or occlusion query may be active at once")]
124    AlreadyStarted {
125        active_query_index: u32,
126        new_query_index: u32,
127    },
128    #[error("Query was stopped while there was no active query")]
129    AlreadyStopped,
130    #[error("A query of type {query_type:?} was started using a query set of type {set_type:?}")]
131    IncompatibleType {
132        set_type: SimplifiedQueryType,
133        query_type: SimplifiedQueryType,
134    },
135}
136
137/// Error encountered while trying to resolve a query.
138#[derive(Clone, Debug, Error)]
139#[non_exhaustive]
140pub enum ResolveError {
141    #[error(transparent)]
142    MissingBufferUsage(#[from] MissingBufferUsageError),
143    #[error("Resolve buffer offset has to be aligned to `QUERY_RESOLVE_BUFFER_ALIGNMENT")]
144    BufferOffsetAlignment,
145    #[error("Resolving queries {start_query}..{end_query} would overrun the query set of size {query_set_size}")]
146    QueryOverrun {
147        start_query: u32,
148        end_query: u32,
149        query_set_size: u32,
150    },
151    #[error("Resolving queries {start_query}..{end_query} ({stride} byte queries) will end up overrunning the bounds of the destination buffer of size {buffer_size} using offsets {buffer_start_offset}..{buffer_end_offset}")]
152    BufferOverrun {
153        start_query: u32,
154        end_query: u32,
155        stride: u32,
156        buffer_size: BufferAddress,
157        buffer_start_offset: BufferAddress,
158        buffer_end_offset: BufferAddress,
159    },
160}
161
162impl QuerySet {
163    fn validate_query(
164        self: &Arc<Self>,
165        query_type: SimplifiedQueryType,
166        query_index: u32,
167        reset_state: Option<&mut QueryResetMap>,
168    ) -> Result<(), QueryUseError> {
169        // We need to defer our resets because we are in a renderpass,
170        // add the usage to the reset map.
171        if let Some(reset) = reset_state {
172            let used = reset.use_query_set(self, query_index);
173            if used {
174                return Err(QueryUseError::UsedTwiceInsideRenderpass { query_index });
175            }
176        }
177
178        let simple_set_type = SimplifiedQueryType::from(self.desc.ty);
179        if simple_set_type != query_type {
180            return Err(QueryUseError::IncompatibleType {
181                query_type,
182                set_type: simple_set_type,
183            });
184        }
185
186        if query_index >= self.desc.count {
187            return Err(QueryUseError::OutOfBounds {
188                query_index,
189                query_set_size: self.desc.count,
190            });
191        }
192
193        Ok(())
194    }
195
196    pub(super) fn validate_and_write_timestamp(
197        self: &Arc<Self>,
198        raw_encoder: &mut dyn hal::DynCommandEncoder,
199        query_index: u32,
200        reset_state: Option<&mut QueryResetMap>,
201    ) -> Result<(), QueryUseError> {
202        let needs_reset = reset_state.is_none();
203        self.validate_query(SimplifiedQueryType::Timestamp, query_index, reset_state)?;
204
205        unsafe {
206            // If we don't have a reset state tracker which can defer resets, we must reset now.
207            if needs_reset {
208                raw_encoder.reset_queries(self.raw(), query_index..(query_index + 1));
209            }
210            raw_encoder.write_timestamp(self.raw(), query_index);
211        }
212
213        Ok(())
214    }
215}
216
217pub(super) fn validate_and_begin_occlusion_query(
218    query_set: Arc<QuerySet>,
219    raw_encoder: &mut dyn hal::DynCommandEncoder,
220    tracker: &mut StatelessTracker<QuerySet>,
221    query_index: u32,
222    reset_state: Option<&mut QueryResetMap>,
223    active_query: &mut Option<(Arc<QuerySet>, u32)>,
224) -> Result<(), QueryUseError> {
225    let needs_reset = reset_state.is_none();
226    query_set.validate_query(SimplifiedQueryType::Occlusion, query_index, reset_state)?;
227
228    tracker.insert_single(query_set.clone());
229
230    if let Some((_old, old_idx)) = active_query.take() {
231        return Err(QueryUseError::AlreadyStarted {
232            active_query_index: old_idx,
233            new_query_index: query_index,
234        });
235    }
236    let (query_set, _) = &active_query.insert((query_set, query_index));
237
238    unsafe {
239        // If we don't have a reset state tracker which can defer resets, we must reset now.
240        if needs_reset {
241            raw_encoder.reset_queries(query_set.raw(), query_index..(query_index + 1));
242        }
243        raw_encoder.begin_query(query_set.raw(), query_index);
244    }
245
246    Ok(())
247}
248
249pub(super) fn end_occlusion_query(
250    raw_encoder: &mut dyn hal::DynCommandEncoder,
251    active_query: &mut Option<(Arc<QuerySet>, u32)>,
252) -> Result<(), QueryUseError> {
253    if let Some((query_set, query_index)) = active_query.take() {
254        unsafe { raw_encoder.end_query(query_set.raw(), query_index) };
255        Ok(())
256    } else {
257        Err(QueryUseError::AlreadyStopped)
258    }
259}
260
261pub(super) fn validate_and_begin_pipeline_statistics_query(
262    query_set: Arc<QuerySet>,
263    raw_encoder: &mut dyn hal::DynCommandEncoder,
264    tracker: &mut StatelessTracker<QuerySet>,
265    cmd_buf: &CommandBuffer,
266    query_index: u32,
267    reset_state: Option<&mut QueryResetMap>,
268    active_query: &mut Option<(Arc<QuerySet>, u32)>,
269) -> Result<(), QueryUseError> {
270    query_set.same_device_as(cmd_buf)?;
271
272    let needs_reset = reset_state.is_none();
273    query_set.validate_query(
274        SimplifiedQueryType::PipelineStatistics,
275        query_index,
276        reset_state,
277    )?;
278
279    tracker.insert_single(query_set.clone());
280
281    if let Some((_old, old_idx)) = active_query.take() {
282        return Err(QueryUseError::AlreadyStarted {
283            active_query_index: old_idx,
284            new_query_index: query_index,
285        });
286    }
287    let (query_set, _) = &active_query.insert((query_set, query_index));
288
289    unsafe {
290        // If we don't have a reset state tracker which can defer resets, we must reset now.
291        if needs_reset {
292            raw_encoder.reset_queries(query_set.raw(), query_index..(query_index + 1));
293        }
294        raw_encoder.begin_query(query_set.raw(), query_index);
295    }
296
297    Ok(())
298}
299
300pub(super) fn end_pipeline_statistics_query(
301    raw_encoder: &mut dyn hal::DynCommandEncoder,
302    active_query: &mut Option<(Arc<QuerySet>, u32)>,
303) -> Result<(), QueryUseError> {
304    if let Some((query_set, query_index)) = active_query.take() {
305        unsafe { raw_encoder.end_query(query_set.raw(), query_index) };
306        Ok(())
307    } else {
308        Err(QueryUseError::AlreadyStopped)
309    }
310}
311
312impl Global {
313    pub fn command_encoder_write_timestamp(
314        &self,
315        command_encoder_id: id::CommandEncoderId,
316        query_set_id: id::QuerySetId,
317        query_index: u32,
318    ) -> Result<(), QueryError> {
319        let hub = &self.hub;
320
321        let cmd_buf = hub
322            .command_buffers
323            .get(command_encoder_id.into_command_buffer_id());
324        let mut cmd_buf_data = cmd_buf.try_get()?;
325        cmd_buf_data.check_recording()?;
326
327        cmd_buf
328            .device
329            .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_ENCODERS)?;
330
331        #[cfg(feature = "trace")]
332        if let Some(ref mut list) = cmd_buf_data.commands {
333            list.push(TraceCommand::WriteTimestamp {
334                query_set_id,
335                query_index,
336            });
337        }
338
339        let raw_encoder = cmd_buf_data.encoder.open(&cmd_buf.device)?;
340
341        let query_set = hub.query_sets.get(query_set_id).get()?;
342
343        query_set.validate_and_write_timestamp(raw_encoder, query_index, None)?;
344
345        cmd_buf_data.trackers.query_sets.insert_single(query_set);
346
347        Ok(())
348    }
349
350    pub fn command_encoder_resolve_query_set(
351        &self,
352        command_encoder_id: id::CommandEncoderId,
353        query_set_id: id::QuerySetId,
354        start_query: u32,
355        query_count: u32,
356        destination: id::BufferId,
357        destination_offset: BufferAddress,
358    ) -> Result<(), QueryError> {
359        let hub = &self.hub;
360
361        let cmd_buf = hub
362            .command_buffers
363            .get(command_encoder_id.into_command_buffer_id());
364        let mut cmd_buf_data = cmd_buf.try_get()?;
365        cmd_buf_data.check_recording()?;
366
367        #[cfg(feature = "trace")]
368        if let Some(ref mut list) = cmd_buf_data.commands {
369            list.push(TraceCommand::ResolveQuerySet {
370                query_set_id,
371                start_query,
372                query_count,
373                destination,
374                destination_offset,
375            });
376        }
377
378        if destination_offset % wgt::QUERY_RESOLVE_BUFFER_ALIGNMENT != 0 {
379            return Err(QueryError::Resolve(ResolveError::BufferOffsetAlignment));
380        }
381
382        let query_set = hub.query_sets.get(query_set_id).get()?;
383
384        query_set.same_device_as(cmd_buf.as_ref())?;
385
386        let dst_buffer = hub.buffers.get(destination).get()?;
387
388        dst_buffer.same_device_as(cmd_buf.as_ref())?;
389
390        let dst_pending = cmd_buf_data
391            .trackers
392            .buffers
393            .set_single(&dst_buffer, hal::BufferUses::COPY_DST);
394
395        let snatch_guard = dst_buffer.device.snatchable_lock.read();
396
397        let dst_barrier = dst_pending.map(|pending| pending.into_hal(&dst_buffer, &snatch_guard));
398
399        dst_buffer
400            .check_usage(wgt::BufferUsages::QUERY_RESOLVE)
401            .map_err(ResolveError::MissingBufferUsage)?;
402
403        let end_query = start_query + query_count;
404        if end_query > query_set.desc.count {
405            return Err(ResolveError::QueryOverrun {
406                start_query,
407                end_query,
408                query_set_size: query_set.desc.count,
409            }
410            .into());
411        }
412
413        let elements_per_query = match query_set.desc.ty {
414            wgt::QueryType::Occlusion => 1,
415            wgt::QueryType::PipelineStatistics(ps) => ps.bits().count_ones(),
416            wgt::QueryType::Timestamp => 1,
417        };
418        let stride = elements_per_query * wgt::QUERY_SIZE;
419        let bytes_used = (stride * query_count) as BufferAddress;
420
421        let buffer_start_offset = destination_offset;
422        let buffer_end_offset = buffer_start_offset + bytes_used;
423
424        if buffer_end_offset > dst_buffer.size {
425            return Err(ResolveError::BufferOverrun {
426                start_query,
427                end_query,
428                stride,
429                buffer_size: dst_buffer.size,
430                buffer_start_offset,
431                buffer_end_offset,
432            }
433            .into());
434        }
435
436        // TODO(https://github.com/gfx-rs/wgpu/issues/3993): Need to track initialization state.
437        cmd_buf_data.buffer_memory_init_actions.extend(
438            dst_buffer.initialization_status.read().create_action(
439                &dst_buffer,
440                buffer_start_offset..buffer_end_offset,
441                MemoryInitKind::ImplicitlyInitialized,
442            ),
443        );
444
445        let raw_dst_buffer = dst_buffer.try_raw(&snatch_guard)?;
446        let raw_encoder = cmd_buf_data.encoder.open(&cmd_buf.device)?;
447        unsafe {
448            raw_encoder.transition_buffers(dst_barrier.as_slice());
449            raw_encoder.copy_query_results(
450                query_set.raw(),
451                start_query..end_query,
452                raw_dst_buffer,
453                destination_offset,
454                wgt::BufferSize::new_unchecked(stride as u64),
455            );
456        }
457
458        cmd_buf_data.trackers.query_sets.insert_single(query_set);
459
460        Ok(())
461    }
462}