bevy_ecs/query/
par_iter.rs

1use crate::{
2    batching::BatchingStrategy, component::Tick, world::unsafe_world_cell::UnsafeWorldCell,
3};
4
5use super::{QueryData, QueryFilter, QueryItem, QueryState};
6
7/// A parallel iterator over query results of a [`Query`](crate::system::Query).
8///
9/// This struct is created by the [`Query::par_iter`](crate::system::Query::par_iter) and
10/// [`Query::par_iter_mut`](crate::system::Query::par_iter_mut) methods.
11pub struct QueryParIter<'w, 's, D: QueryData, F: QueryFilter> {
12    pub(crate) world: UnsafeWorldCell<'w>,
13    pub(crate) state: &'s QueryState<D, F>,
14    pub(crate) last_run: Tick,
15    pub(crate) this_run: Tick,
16    pub(crate) batching_strategy: BatchingStrategy,
17}
18
19impl<'w, 's, D: QueryData, F: QueryFilter> QueryParIter<'w, 's, D, F> {
20    /// Changes the batching strategy used when iterating.
21    ///
22    /// For more information on how this affects the resultant iteration, see
23    /// [`BatchingStrategy`].
24    pub fn batching_strategy(mut self, strategy: BatchingStrategy) -> Self {
25        self.batching_strategy = strategy;
26        self
27    }
28
29    /// Runs `func` on each query result in parallel.
30    ///
31    /// # Panics
32    /// If the [`ComputeTaskPool`] is not initialized. If using this from a query that is being
33    /// initialized and run from the ECS scheduler, this should never panic.
34    ///
35    /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool
36    #[inline]
37    pub fn for_each<FN: Fn(QueryItem<'w, D>) + Send + Sync + Clone>(self, func: FN) {
38        self.for_each_init(|| {}, |_, item| func(item));
39    }
40
41    /// Runs `func` on each query result in parallel on a value returned by `init`.
42    ///
43    /// `init` may be called multiple times per thread, and the values returned may be discarded between tasks on any given thread.
44    /// Callers should avoid using this function as if it were a parallel version
45    /// of [`Iterator::fold`].
46    ///
47    /// # Example
48    ///
49    /// ```
50    /// use bevy_utils::Parallel;
51    /// use crate::{bevy_ecs::prelude::Component, bevy_ecs::system::Query};
52    /// #[derive(Component)]
53    /// struct T;
54    /// fn system(query: Query<&T>){
55    ///     let mut queue: Parallel<usize> = Parallel::default();
56    ///     // queue.borrow_local_mut() will get or create a thread_local queue for each task/thread;
57    ///     query.par_iter().for_each_init(|| queue.borrow_local_mut(),|local_queue,item| {
58    ///         **local_queue += 1;
59    ///      });
60    ///     
61    ///     // collect value from every thread
62    ///     let entity_count: usize = queue.iter_mut().map(|v| *v).sum();
63    /// }
64    /// ```
65    ///
66    /// # Panics
67    /// If the [`ComputeTaskPool`] is not initialized. If using this from a query that is being
68    /// initialized and run from the ECS scheduler, this should never panic.
69    ///
70    /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool
71    #[inline]
72    pub fn for_each_init<FN, INIT, T>(self, init: INIT, func: FN)
73    where
74        FN: Fn(&mut T, QueryItem<'w, D>) + Send + Sync + Clone,
75        INIT: Fn() -> T + Sync + Send + Clone,
76    {
77        let func = |mut init, item| {
78            func(&mut init, item);
79            init
80        };
81        #[cfg(any(target_arch = "wasm32", not(feature = "multi_threaded")))]
82        {
83            let init = init();
84            // SAFETY:
85            // This method can only be called once per instance of QueryParIter,
86            // which ensures that mutable queries cannot be executed multiple times at once.
87            // Mutable instances of QueryParIter can only be created via an exclusive borrow of a
88            // Query or a World, which ensures that multiple aliasing QueryParIters cannot exist
89            // at the same time.
90            unsafe {
91                self.state
92                    .iter_unchecked_manual(self.world, self.last_run, self.this_run)
93                    .fold(init, func);
94            }
95        }
96        #[cfg(all(not(target_arch = "wasm32"), feature = "multi_threaded"))]
97        {
98            let thread_count = bevy_tasks::ComputeTaskPool::get().thread_num();
99            if thread_count <= 1 {
100                let init = init();
101                // SAFETY: See the safety comment above.
102                unsafe {
103                    self.state
104                        .iter_unchecked_manual(self.world, self.last_run, self.this_run)
105                        .fold(init, func);
106                }
107            } else {
108                // Need a batch size of at least 1.
109                let batch_size = self.get_batch_size(thread_count).max(1);
110                // SAFETY: See the safety comment above.
111                unsafe {
112                    self.state.par_fold_init_unchecked_manual(
113                        init,
114                        self.world,
115                        batch_size,
116                        func,
117                        self.last_run,
118                        self.this_run,
119                    );
120                }
121            }
122        }
123    }
124
125    #[cfg(all(not(target_arch = "wasm32"), feature = "multi_threaded"))]
126    fn get_batch_size(&self, thread_count: usize) -> usize {
127        let max_items = || {
128            let id_iter = self.state.matched_storage_ids.iter();
129            if self.state.is_dense {
130                // SAFETY: We only access table metadata.
131                let tables = unsafe { &self.world.world_metadata().storages().tables };
132                id_iter
133                    // SAFETY: The if check ensures that matched_storage_ids stores TableIds
134                    .map(|id| unsafe { tables[id.table_id].entity_count() })
135                    .max()
136            } else {
137                let archetypes = &self.world.archetypes();
138                id_iter
139                    // SAFETY: The if check ensures that matched_storage_ids stores ArchetypeIds
140                    .map(|id| unsafe { archetypes[id.archetype_id].len() })
141                    .max()
142            }
143            .unwrap_or(0)
144        };
145        self.batching_strategy
146            .calc_batch_size(max_items, thread_count)
147    }
148}