bevy_tasks/iter/
mod.rs

1use crate::TaskPool;
2use alloc::vec::Vec;
3
4mod adapters;
5pub use adapters::*;
6
7/// [`ParallelIterator`] closely emulates the `std::iter::Iterator`
8/// interface. However, it uses `bevy_task` to compute batches in parallel.
9///
10/// Note that the overhead of [`ParallelIterator`] is high relative to some
11/// workloads. In particular, if the batch size is too small or task being
12/// run in parallel is inexpensive, *a [`ParallelIterator`] could take longer
13/// than a normal [`Iterator`]*. Therefore, you should profile your code before
14/// using [`ParallelIterator`].
15pub trait ParallelIterator<BatchIter>
16where
17    BatchIter: Iterator + Send,
18    Self: Sized + Send,
19{
20    /// Returns the next batch of items for processing.
21    ///
22    /// Each batch is an iterator with items of the same type as the
23    /// [`ParallelIterator`]. Returns `None` when there are no batches left.
24    fn next_batch(&mut self) -> Option<BatchIter>;
25
26    /// Returns the bounds on the remaining number of items in the
27    /// parallel iterator.
28    ///
29    /// See [`Iterator::size_hint()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.size_hint)
30    fn size_hint(&self) -> (usize, Option<usize>) {
31        (0, None)
32    }
33
34    /// Consumes the parallel iterator and returns the number of items.
35    ///
36    /// See [`Iterator::count()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.count)
37    fn count(mut self, pool: &TaskPool) -> usize {
38        pool.scope(|s| {
39            while let Some(batch) = self.next_batch() {
40                s.spawn(async move { batch.count() });
41            }
42        })
43        .iter()
44        .sum()
45    }
46
47    /// Consumes the parallel iterator and returns the last item.
48    ///
49    /// See [`Iterator::last()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.last)
50    fn last(mut self, _pool: &TaskPool) -> Option<BatchIter::Item> {
51        let mut last_item = None;
52        while let Some(batch) = self.next_batch() {
53            last_item = batch.last();
54        }
55        last_item
56    }
57
58    /// Consumes the parallel iterator and returns the nth item.
59    ///
60    /// See [`Iterator::nth()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.nth)
61    // TODO: Optimize with size_hint on each batch
62    fn nth(mut self, _pool: &TaskPool, n: usize) -> Option<BatchIter::Item> {
63        let mut i = 0;
64        while let Some(batch) = self.next_batch() {
65            for item in batch {
66                if i == n {
67                    return Some(item);
68                }
69                i += 1;
70            }
71        }
72        None
73    }
74
75    /// Takes two parallel iterators and returns a parallel iterators over
76    /// both in sequence.
77    ///
78    /// See [`Iterator::chain()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.chain)
79    // TODO: Use IntoParallelIterator for U
80    fn chain<U>(self, other: U) -> Chain<Self, U>
81    where
82        U: ParallelIterator<BatchIter>,
83    {
84        Chain {
85            left: self,
86            right: other,
87            left_in_progress: true,
88        }
89    }
90
91    /// Takes a closure and creates a parallel iterator which calls that
92    /// closure on each item.
93    ///
94    /// See [`Iterator::map()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.map)
95    fn map<T, F>(self, f: F) -> Map<Self, F>
96    where
97        F: FnMut(BatchIter::Item) -> T + Send + Clone,
98    {
99        Map { iter: self, f }
100    }
101
102    /// Calls a closure on each item of a parallel iterator.
103    ///
104    /// See [`Iterator::for_each()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.for_each)
105    fn for_each<F>(mut self, pool: &TaskPool, f: F)
106    where
107        F: FnMut(BatchIter::Item) + Send + Clone + Sync,
108    {
109        pool.scope(|s| {
110            while let Some(batch) = self.next_batch() {
111                let newf = f.clone();
112                s.spawn(async move {
113                    batch.for_each(newf);
114                });
115            }
116        });
117    }
118
119    /// Creates a parallel iterator which uses a closure to determine
120    /// if an element should be yielded.
121    ///
122    /// See [`Iterator::filter()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.filter)
123    fn filter<F>(self, predicate: F) -> Filter<Self, F>
124    where
125        F: FnMut(&BatchIter::Item) -> bool,
126    {
127        Filter {
128            iter: self,
129            predicate,
130        }
131    }
132
133    /// Creates a parallel iterator that both filters and maps.
134    ///
135    /// See [`Iterator::filter_map()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.filter_map)
136    fn filter_map<R, F>(self, f: F) -> FilterMap<Self, F>
137    where
138        F: FnMut(BatchIter::Item) -> Option<R>,
139    {
140        FilterMap { iter: self, f }
141    }
142
143    /// Creates a parallel iterator that works like map, but flattens
144    /// nested structure.
145    ///
146    /// See [`Iterator::flat_map()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.flat_map)
147    fn flat_map<U, F>(self, f: F) -> FlatMap<Self, F>
148    where
149        F: FnMut(BatchIter::Item) -> U,
150        U: IntoIterator,
151    {
152        FlatMap { iter: self, f }
153    }
154
155    /// Creates a parallel iterator that flattens nested structure.
156    ///
157    /// See [`Iterator::flatten()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.flatten)
158    fn flatten(self) -> Flatten<Self>
159    where
160        BatchIter::Item: IntoIterator,
161    {
162        Flatten { iter: self }
163    }
164
165    /// Creates a parallel iterator which ends after the first None.
166    ///
167    /// See [`Iterator::fuse()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.fuse)
168    fn fuse(self) -> Fuse<Self> {
169        Fuse { iter: Some(self) }
170    }
171
172    /// Does something with each item of a parallel iterator, passing
173    /// the value on.
174    ///
175    /// See [`Iterator::inspect()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.inspect)
176    fn inspect<F>(self, f: F) -> Inspect<Self, F>
177    where
178        F: FnMut(&BatchIter::Item),
179    {
180        Inspect { iter: self, f }
181    }
182
183    /// Borrows a parallel iterator, rather than consuming it.
184    ///
185    /// See [`Iterator::by_ref()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.by_ref)
186    fn by_ref(&mut self) -> &mut Self {
187        self
188    }
189
190    /// Transforms a parallel iterator into a collection.
191    ///
192    /// See [`Iterator::collect()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.collect)
193    // TODO: Investigate optimizations for less copying
194    fn collect<C>(mut self, pool: &TaskPool) -> C
195    where
196        C: FromIterator<BatchIter::Item>,
197        BatchIter::Item: Send + 'static,
198    {
199        pool.scope(|s| {
200            while let Some(batch) = self.next_batch() {
201                s.spawn(async move { batch.collect::<Vec<_>>() });
202            }
203        })
204        .into_iter()
205        .flatten()
206        .collect()
207    }
208
209    /// Consumes a parallel iterator, creating two collections from it.
210    ///
211    /// See [`Iterator::partition()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.partition)
212    // TODO: Investigate optimizations for less copying
213    fn partition<C, F>(mut self, pool: &TaskPool, f: F) -> (C, C)
214    where
215        C: Default + Extend<BatchIter::Item> + Send,
216        F: FnMut(&BatchIter::Item) -> bool + Send + Sync + Clone,
217        BatchIter::Item: Send + 'static,
218    {
219        let (mut a, mut b) = <(C, C)>::default();
220        pool.scope(|s| {
221            while let Some(batch) = self.next_batch() {
222                let newf = f.clone();
223                s.spawn(async move { batch.partition::<Vec<_>, F>(newf) });
224            }
225        })
226        .into_iter()
227        .for_each(|(c, d)| {
228            a.extend(c);
229            b.extend(d);
230        });
231        (a, b)
232    }
233
234    /// Repeatedly applies a function to items of each batch of a parallel
235    /// iterator, producing a Vec of final values.
236    ///
237    /// *Note that this folds each batch independently and returns a Vec of
238    /// results (in batch order).*
239    ///
240    /// See [`Iterator::fold()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.fold)
241    fn fold<C, F, D>(mut self, pool: &TaskPool, init: C, f: F) -> Vec<C>
242    where
243        F: FnMut(C, BatchIter::Item) -> C + Send + Sync + Clone,
244        C: Clone + Send + Sync + 'static,
245    {
246        pool.scope(|s| {
247            while let Some(batch) = self.next_batch() {
248                let newf = f.clone();
249                let newi = init.clone();
250                s.spawn(async move { batch.fold(newi, newf) });
251            }
252        })
253    }
254
255    /// Tests if every element of the parallel iterator matches a predicate.
256    ///
257    /// *Note that all is **not** short circuiting.*
258    ///
259    /// See [`Iterator::all()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.all)
260    fn all<F>(mut self, pool: &TaskPool, f: F) -> bool
261    where
262        F: FnMut(BatchIter::Item) -> bool + Send + Sync + Clone,
263    {
264        pool.scope(|s| {
265            while let Some(mut batch) = self.next_batch() {
266                let newf = f.clone();
267                s.spawn(async move { batch.all(newf) });
268            }
269        })
270        .into_iter()
271        .all(core::convert::identity)
272    }
273
274    /// Tests if any element of the parallel iterator matches a predicate.
275    ///
276    /// *Note that any is **not** short circuiting.*
277    ///
278    /// See [`Iterator::any()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.any)
279    fn any<F>(mut self, pool: &TaskPool, f: F) -> bool
280    where
281        F: FnMut(BatchIter::Item) -> bool + Send + Sync + Clone,
282    {
283        pool.scope(|s| {
284            while let Some(mut batch) = self.next_batch() {
285                let newf = f.clone();
286                s.spawn(async move { batch.any(newf) });
287            }
288        })
289        .into_iter()
290        .any(core::convert::identity)
291    }
292
293    /// Searches for an element in a parallel iterator, returning its index.
294    ///
295    /// *Note that position consumes the whole iterator.*
296    ///
297    /// See [`Iterator::position()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.position)
298    // TODO: Investigate optimizations for less copying
299    fn position<F>(mut self, pool: &TaskPool, f: F) -> Option<usize>
300    where
301        F: FnMut(BatchIter::Item) -> bool + Send + Sync + Clone,
302    {
303        let poses = pool.scope(|s| {
304            while let Some(batch) = self.next_batch() {
305                let mut newf = f.clone();
306                s.spawn(async move {
307                    let mut len = 0;
308                    let mut pos = None;
309                    for item in batch {
310                        if pos.is_none() && newf(item) {
311                            pos = Some(len);
312                        }
313                        len += 1;
314                    }
315                    (len, pos)
316                });
317            }
318        });
319        let mut start = 0;
320        for (len, pos) in poses {
321            if let Some(pos) = pos {
322                return Some(start + pos);
323            }
324            start += len;
325        }
326        None
327    }
328
329    /// Returns the maximum item of a parallel iterator.
330    ///
331    /// See [`Iterator::max()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.max)
332    fn max(mut self, pool: &TaskPool) -> Option<BatchIter::Item>
333    where
334        BatchIter::Item: Ord + Send + 'static,
335    {
336        pool.scope(|s| {
337            while let Some(batch) = self.next_batch() {
338                s.spawn(async move { batch.max() });
339            }
340        })
341        .into_iter()
342        .flatten()
343        .max()
344    }
345
346    /// Returns the minimum item of a parallel iterator.
347    ///
348    /// See [`Iterator::min()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.min)
349    fn min(mut self, pool: &TaskPool) -> Option<BatchIter::Item>
350    where
351        BatchIter::Item: Ord + Send + 'static,
352    {
353        pool.scope(|s| {
354            while let Some(batch) = self.next_batch() {
355                s.spawn(async move { batch.min() });
356            }
357        })
358        .into_iter()
359        .flatten()
360        .min()
361    }
362
363    /// Returns the item that gives the maximum value from the specified function.
364    ///
365    /// See [`Iterator::max_by_key()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.max_by_key)
366    fn max_by_key<R, F>(mut self, pool: &TaskPool, f: F) -> Option<BatchIter::Item>
367    where
368        R: Ord,
369        F: FnMut(&BatchIter::Item) -> R + Send + Sync + Clone,
370        BatchIter::Item: Send + 'static,
371    {
372        pool.scope(|s| {
373            while let Some(batch) = self.next_batch() {
374                let newf = f.clone();
375                s.spawn(async move { batch.max_by_key(newf) });
376            }
377        })
378        .into_iter()
379        .flatten()
380        .max_by_key(f)
381    }
382
383    /// Returns the item that gives the maximum value with respect to the specified comparison
384    /// function.
385    ///
386    /// See [`Iterator::max_by()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.max_by)
387    fn max_by<F>(mut self, pool: &TaskPool, f: F) -> Option<BatchIter::Item>
388    where
389        F: FnMut(&BatchIter::Item, &BatchIter::Item) -> core::cmp::Ordering + Send + Sync + Clone,
390        BatchIter::Item: Send + 'static,
391    {
392        pool.scope(|s| {
393            while let Some(batch) = self.next_batch() {
394                let newf = f.clone();
395                s.spawn(async move { batch.max_by(newf) });
396            }
397        })
398        .into_iter()
399        .flatten()
400        .max_by(f)
401    }
402
403    /// Returns the item that gives the minimum value from the specified function.
404    ///
405    /// See [`Iterator::min_by_key()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.min_by_key)
406    fn min_by_key<R, F>(mut self, pool: &TaskPool, f: F) -> Option<BatchIter::Item>
407    where
408        R: Ord,
409        F: FnMut(&BatchIter::Item) -> R + Send + Sync + Clone,
410        BatchIter::Item: Send + 'static,
411    {
412        pool.scope(|s| {
413            while let Some(batch) = self.next_batch() {
414                let newf = f.clone();
415                s.spawn(async move { batch.min_by_key(newf) });
416            }
417        })
418        .into_iter()
419        .flatten()
420        .min_by_key(f)
421    }
422
423    /// Returns the item that gives the minimum value with respect to the specified comparison
424    /// function.
425    ///
426    /// See [`Iterator::min_by()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.min_by)
427    fn min_by<F>(mut self, pool: &TaskPool, f: F) -> Option<BatchIter::Item>
428    where
429        F: FnMut(&BatchIter::Item, &BatchIter::Item) -> core::cmp::Ordering + Send + Sync + Clone,
430        BatchIter::Item: Send + 'static,
431    {
432        pool.scope(|s| {
433            while let Some(batch) = self.next_batch() {
434                let newf = f.clone();
435                s.spawn(async move { batch.min_by(newf) });
436            }
437        })
438        .into_iter()
439        .flatten()
440        .min_by(f)
441    }
442
443    /// Creates a parallel iterator which copies all of its items.
444    ///
445    /// See [`Iterator::copied()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.copied)
446    fn copied<'a, T>(self) -> Copied<Self>
447    where
448        Self: ParallelIterator<BatchIter>,
449        T: 'a + Copy,
450    {
451        Copied { iter: self }
452    }
453
454    /// Creates a parallel iterator which clones all of its items.
455    ///
456    /// See [`Iterator::cloned()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.cloned)
457    fn cloned<'a, T>(self) -> Cloned<Self>
458    where
459        Self: ParallelIterator<BatchIter>,
460        T: 'a + Copy,
461    {
462        Cloned { iter: self }
463    }
464
465    /// Repeats a parallel iterator endlessly.
466    ///
467    /// See [`Iterator::cycle()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.cycle)
468    fn cycle(self) -> Cycle<Self>
469    where
470        Self: Clone,
471    {
472        Cycle {
473            iter: self,
474            curr: None,
475        }
476    }
477
478    /// Sums the items of a parallel iterator.
479    ///
480    /// See [`Iterator::sum()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.sum)
481    fn sum<S, R>(mut self, pool: &TaskPool) -> R
482    where
483        S: core::iter::Sum<BatchIter::Item> + Send + 'static,
484        R: core::iter::Sum<S>,
485    {
486        pool.scope(|s| {
487            while let Some(batch) = self.next_batch() {
488                s.spawn(async move { batch.sum() });
489            }
490        })
491        .into_iter()
492        .sum()
493    }
494
495    /// Multiplies all the items of a parallel iterator.
496    ///
497    /// See [`Iterator::product()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.product)
498    fn product<S, R>(mut self, pool: &TaskPool) -> R
499    where
500        S: core::iter::Product<BatchIter::Item> + Send + 'static,
501        R: core::iter::Product<S>,
502    {
503        pool.scope(|s| {
504            while let Some(batch) = self.next_batch() {
505                s.spawn(async move { batch.product() });
506            }
507        })
508        .into_iter()
509        .product()
510    }
511}