bevy_tasks/iter/
mod.rs

1use crate::TaskPool;
2
3mod adapters;
4pub use adapters::*;
5
6/// [`ParallelIterator`] closely emulates the `std::iter::Iterator`
7/// interface. However, it uses `bevy_task` to compute batches in parallel.
8///
9/// Note that the overhead of [`ParallelIterator`] is high relative to some
10/// workloads. In particular, if the batch size is too small or task being
11/// run in parallel is inexpensive, *a [`ParallelIterator`] could take longer
12/// than a normal [`Iterator`]*. Therefore, you should profile your code before
13/// using [`ParallelIterator`].
14pub trait ParallelIterator<BatchIter>
15where
16    BatchIter: Iterator + Send,
17    Self: Sized + Send,
18{
19    /// Returns the next batch of items for processing.
20    ///
21    /// Each batch is an iterator with items of the same type as the
22    /// [`ParallelIterator`]. Returns `None` when there are no batches left.
23    fn next_batch(&mut self) -> Option<BatchIter>;
24
25    /// Returns the bounds on the remaining number of items in the
26    /// parallel iterator.
27    ///
28    /// See [`Iterator::size_hint()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.size_hint)
29    fn size_hint(&self) -> (usize, Option<usize>) {
30        (0, None)
31    }
32
33    /// Consumes the parallel iterator and returns the number of items.
34    ///
35    /// See [`Iterator::count()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.count)
36    fn count(mut self, pool: &TaskPool) -> usize {
37        pool.scope(|s| {
38            while let Some(batch) = self.next_batch() {
39                s.spawn(async move { batch.count() });
40            }
41        })
42        .iter()
43        .sum()
44    }
45
46    /// Consumes the parallel iterator and returns the last item.
47    ///
48    /// See [`Iterator::last()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.last)
49    fn last(mut self, _pool: &TaskPool) -> Option<BatchIter::Item> {
50        let mut last_item = None;
51        while let Some(batch) = self.next_batch() {
52            last_item = batch.last();
53        }
54        last_item
55    }
56
57    /// Consumes the parallel iterator and returns the nth item.
58    ///
59    /// See [`Iterator::nth()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.nth)
60    // TODO: Optimize with size_hint on each batch
61    fn nth(mut self, _pool: &TaskPool, n: usize) -> Option<BatchIter::Item> {
62        let mut i = 0;
63        while let Some(batch) = self.next_batch() {
64            for item in batch {
65                if i == n {
66                    return Some(item);
67                }
68                i += 1;
69            }
70        }
71        None
72    }
73
74    /// Takes two parallel iterators and returns a parallel iterators over
75    /// both in sequence.
76    ///
77    /// See [`Iterator::chain()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.chain)
78    // TODO: Use IntoParallelIterator for U
79    fn chain<U>(self, other: U) -> Chain<Self, U>
80    where
81        U: ParallelIterator<BatchIter>,
82    {
83        Chain {
84            left: self,
85            right: other,
86            left_in_progress: true,
87        }
88    }
89
90    /// Takes a closure and creates a parallel iterator which calls that
91    /// closure on each item.
92    ///
93    /// See [`Iterator::map()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.map)
94    fn map<T, F>(self, f: F) -> Map<Self, F>
95    where
96        F: FnMut(BatchIter::Item) -> T + Send + Clone,
97    {
98        Map { iter: self, f }
99    }
100
101    /// Calls a closure on each item of a parallel iterator.
102    ///
103    /// See [`Iterator::for_each()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.for_each)
104    fn for_each<F>(mut self, pool: &TaskPool, f: F)
105    where
106        F: FnMut(BatchIter::Item) + Send + Clone + Sync,
107    {
108        pool.scope(|s| {
109            while let Some(batch) = self.next_batch() {
110                let newf = f.clone();
111                s.spawn(async move {
112                    batch.for_each(newf);
113                });
114            }
115        });
116    }
117
118    /// Creates a parallel iterator which uses a closure to determine
119    /// if an element should be yielded.
120    ///
121    /// See [`Iterator::filter()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.filter)
122    fn filter<F>(self, predicate: F) -> Filter<Self, F>
123    where
124        F: FnMut(&BatchIter::Item) -> bool,
125    {
126        Filter {
127            iter: self,
128            predicate,
129        }
130    }
131
132    /// Creates a parallel iterator that both filters and maps.
133    ///
134    /// See [`Iterator::filter_map()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.filter_map)
135    fn filter_map<R, F>(self, f: F) -> FilterMap<Self, F>
136    where
137        F: FnMut(BatchIter::Item) -> Option<R>,
138    {
139        FilterMap { iter: self, f }
140    }
141
142    /// Creates a parallel iterator that works like map, but flattens
143    /// nested structure.
144    ///
145    /// See [`Iterator::flat_map()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.flat_map)
146    fn flat_map<U, F>(self, f: F) -> FlatMap<Self, F>
147    where
148        F: FnMut(BatchIter::Item) -> U,
149        U: IntoIterator,
150    {
151        FlatMap { iter: self, f }
152    }
153
154    /// Creates a parallel iterator that flattens nested structure.
155    ///
156    /// See [`Iterator::flatten()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.flatten)
157    fn flatten(self) -> Flatten<Self>
158    where
159        BatchIter::Item: IntoIterator,
160    {
161        Flatten { iter: self }
162    }
163
164    /// Creates a parallel iterator which ends after the first None.
165    ///
166    /// See [`Iterator::fuse()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.fuse)
167    fn fuse(self) -> Fuse<Self> {
168        Fuse { iter: Some(self) }
169    }
170
171    /// Does something with each item of a parallel iterator, passing
172    /// the value on.
173    ///
174    /// See [`Iterator::inspect()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.inspect)
175    fn inspect<F>(self, f: F) -> Inspect<Self, F>
176    where
177        F: FnMut(&BatchIter::Item),
178    {
179        Inspect { iter: self, f }
180    }
181
182    /// Borrows a parallel iterator, rather than consuming it.
183    ///
184    /// See [`Iterator::by_ref()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.by_ref)
185    fn by_ref(&mut self) -> &mut Self {
186        self
187    }
188
189    /// Transforms a parallel iterator into a collection.
190    ///
191    /// See [`Iterator::collect()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.collect)
192    // TODO: Investigate optimizations for less copying
193    fn collect<C>(mut self, pool: &TaskPool) -> C
194    where
195        C: FromIterator<BatchIter::Item>,
196        BatchIter::Item: Send + 'static,
197    {
198        pool.scope(|s| {
199            while let Some(batch) = self.next_batch() {
200                s.spawn(async move { batch.collect::<Vec<_>>() });
201            }
202        })
203        .into_iter()
204        .flatten()
205        .collect()
206    }
207
208    /// Consumes a parallel iterator, creating two collections from it.
209    ///
210    /// See [`Iterator::partition()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.partition)
211    // TODO: Investigate optimizations for less copying
212    fn partition<C, F>(mut self, pool: &TaskPool, f: F) -> (C, C)
213    where
214        C: Default + Extend<BatchIter::Item> + Send,
215        F: FnMut(&BatchIter::Item) -> bool + Send + Sync + Clone,
216        BatchIter::Item: Send + 'static,
217    {
218        let (mut a, mut b) = <(C, C)>::default();
219        pool.scope(|s| {
220            while let Some(batch) = self.next_batch() {
221                let newf = f.clone();
222                s.spawn(async move { batch.partition::<Vec<_>, F>(newf) });
223            }
224        })
225        .into_iter()
226        .for_each(|(c, d)| {
227            a.extend(c);
228            b.extend(d);
229        });
230        (a, b)
231    }
232
233    /// Repeatedly applies a function to items of each batch of a parallel
234    /// iterator, producing a Vec of final values.
235    ///
236    /// *Note that this folds each batch independently and returns a Vec of
237    /// results (in batch order).*
238    ///
239    /// See [`Iterator::fold()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.fold)
240    fn fold<C, F, D>(mut self, pool: &TaskPool, init: C, f: F) -> Vec<C>
241    where
242        F: FnMut(C, BatchIter::Item) -> C + Send + Sync + Clone,
243        C: Clone + Send + Sync + 'static,
244    {
245        pool.scope(|s| {
246            while let Some(batch) = self.next_batch() {
247                let newf = f.clone();
248                let newi = init.clone();
249                s.spawn(async move { batch.fold(newi, newf) });
250            }
251        })
252    }
253
254    /// Tests if every element of the parallel iterator matches a predicate.
255    ///
256    /// *Note that all is **not** short circuiting.*
257    ///
258    /// See [`Iterator::all()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.all)
259    fn all<F>(mut self, pool: &TaskPool, f: F) -> bool
260    where
261        F: FnMut(BatchIter::Item) -> bool + Send + Sync + Clone,
262    {
263        pool.scope(|s| {
264            while let Some(mut batch) = self.next_batch() {
265                let newf = f.clone();
266                s.spawn(async move { batch.all(newf) });
267            }
268        })
269        .into_iter()
270        .all(core::convert::identity)
271    }
272
273    /// Tests if any element of the parallel iterator matches a predicate.
274    ///
275    /// *Note that any is **not** short circuiting.*
276    ///
277    /// See [`Iterator::any()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.any)
278    fn any<F>(mut self, pool: &TaskPool, f: F) -> bool
279    where
280        F: FnMut(BatchIter::Item) -> bool + Send + Sync + Clone,
281    {
282        pool.scope(|s| {
283            while let Some(mut batch) = self.next_batch() {
284                let newf = f.clone();
285                s.spawn(async move { batch.any(newf) });
286            }
287        })
288        .into_iter()
289        .any(core::convert::identity)
290    }
291
292    /// Searches for an element in a parallel iterator, returning its index.
293    ///
294    /// *Note that position consumes the whole iterator.*
295    ///
296    /// See [`Iterator::position()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.position)
297    // TODO: Investigate optimizations for less copying
298    fn position<F>(mut self, pool: &TaskPool, f: F) -> Option<usize>
299    where
300        F: FnMut(BatchIter::Item) -> bool + Send + Sync + Clone,
301    {
302        let poses = pool.scope(|s| {
303            while let Some(batch) = self.next_batch() {
304                let mut newf = f.clone();
305                s.spawn(async move {
306                    let mut len = 0;
307                    let mut pos = None;
308                    for item in batch {
309                        if pos.is_none() && newf(item) {
310                            pos = Some(len);
311                        }
312                        len += 1;
313                    }
314                    (len, pos)
315                });
316            }
317        });
318        let mut start = 0;
319        for (len, pos) in poses {
320            if let Some(pos) = pos {
321                return Some(start + pos);
322            }
323            start += len;
324        }
325        None
326    }
327
328    /// Returns the maximum item of a parallel iterator.
329    ///
330    /// See [`Iterator::max()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.max)
331    fn max(mut self, pool: &TaskPool) -> Option<BatchIter::Item>
332    where
333        BatchIter::Item: Ord + Send + 'static,
334    {
335        pool.scope(|s| {
336            while let Some(batch) = self.next_batch() {
337                s.spawn(async move { batch.max() });
338            }
339        })
340        .into_iter()
341        .flatten()
342        .max()
343    }
344
345    /// Returns the minimum item of a parallel iterator.
346    ///
347    /// See [`Iterator::min()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.min)
348    fn min(mut self, pool: &TaskPool) -> Option<BatchIter::Item>
349    where
350        BatchIter::Item: Ord + Send + 'static,
351    {
352        pool.scope(|s| {
353            while let Some(batch) = self.next_batch() {
354                s.spawn(async move { batch.min() });
355            }
356        })
357        .into_iter()
358        .flatten()
359        .min()
360    }
361
362    /// Returns the item that gives the maximum value from the specified function.
363    ///
364    /// See [`Iterator::max_by_key()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.max_by_key)
365    fn max_by_key<R, F>(mut self, pool: &TaskPool, f: F) -> Option<BatchIter::Item>
366    where
367        R: Ord,
368        F: FnMut(&BatchIter::Item) -> R + Send + Sync + Clone,
369        BatchIter::Item: Send + 'static,
370    {
371        pool.scope(|s| {
372            while let Some(batch) = self.next_batch() {
373                let newf = f.clone();
374                s.spawn(async move { batch.max_by_key(newf) });
375            }
376        })
377        .into_iter()
378        .flatten()
379        .max_by_key(f)
380    }
381
382    /// Returns the item that gives the maximum value with respect to the specified comparison
383    /// function.
384    ///
385    /// See [`Iterator::max_by()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.max_by)
386    fn max_by<F>(mut self, pool: &TaskPool, f: F) -> Option<BatchIter::Item>
387    where
388        F: FnMut(&BatchIter::Item, &BatchIter::Item) -> core::cmp::Ordering + Send + Sync + Clone,
389        BatchIter::Item: Send + 'static,
390    {
391        pool.scope(|s| {
392            while let Some(batch) = self.next_batch() {
393                let newf = f.clone();
394                s.spawn(async move { batch.max_by(newf) });
395            }
396        })
397        .into_iter()
398        .flatten()
399        .max_by(f)
400    }
401
402    /// Returns the item that gives the minimum value from the specified function.
403    ///
404    /// See [`Iterator::min_by_key()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.min_by_key)
405    fn min_by_key<R, F>(mut self, pool: &TaskPool, f: F) -> Option<BatchIter::Item>
406    where
407        R: Ord,
408        F: FnMut(&BatchIter::Item) -> R + Send + Sync + Clone,
409        BatchIter::Item: Send + 'static,
410    {
411        pool.scope(|s| {
412            while let Some(batch) = self.next_batch() {
413                let newf = f.clone();
414                s.spawn(async move { batch.min_by_key(newf) });
415            }
416        })
417        .into_iter()
418        .flatten()
419        .min_by_key(f)
420    }
421
422    /// Returns the item that gives the minimum value with respect to the specified comparison
423    /// function.
424    ///
425    /// See [`Iterator::min_by()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.min_by)
426    fn min_by<F>(mut self, pool: &TaskPool, f: F) -> Option<BatchIter::Item>
427    where
428        F: FnMut(&BatchIter::Item, &BatchIter::Item) -> core::cmp::Ordering + Send + Sync + Clone,
429        BatchIter::Item: Send + 'static,
430    {
431        pool.scope(|s| {
432            while let Some(batch) = self.next_batch() {
433                let newf = f.clone();
434                s.spawn(async move { batch.min_by(newf) });
435            }
436        })
437        .into_iter()
438        .flatten()
439        .min_by(f)
440    }
441
442    /// Creates a parallel iterator which copies all of its items.
443    ///
444    /// See [`Iterator::copied()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.copied)
445    fn copied<'a, T>(self) -> Copied<Self>
446    where
447        Self: ParallelIterator<BatchIter>,
448        T: 'a + Copy,
449    {
450        Copied { iter: self }
451    }
452
453    /// Creates a parallel iterator which clones all of its items.
454    ///
455    /// See [`Iterator::cloned()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.cloned)
456    fn cloned<'a, T>(self) -> Cloned<Self>
457    where
458        Self: ParallelIterator<BatchIter>,
459        T: 'a + Copy,
460    {
461        Cloned { iter: self }
462    }
463
464    /// Repeats a parallel iterator endlessly.
465    ///
466    /// See [`Iterator::cycle()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.cycle)
467    fn cycle(self) -> Cycle<Self>
468    where
469        Self: Clone,
470    {
471        Cycle {
472            iter: self,
473            curr: None,
474        }
475    }
476
477    /// Sums the items of a parallel iterator.
478    ///
479    /// See [`Iterator::sum()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.sum)
480    fn sum<S, R>(mut self, pool: &TaskPool) -> R
481    where
482        S: core::iter::Sum<BatchIter::Item> + Send + 'static,
483        R: core::iter::Sum<S>,
484    {
485        pool.scope(|s| {
486            while let Some(batch) = self.next_batch() {
487                s.spawn(async move { batch.sum() });
488            }
489        })
490        .into_iter()
491        .sum()
492    }
493
494    /// Multiplies all the items of a parallel iterator.
495    ///
496    /// See [`Iterator::product()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.product)
497    fn product<S, R>(mut self, pool: &TaskPool) -> R
498    where
499        S: core::iter::Product<BatchIter::Item> + Send + 'static,
500        R: core::iter::Product<S>,
501    {
502        pool.scope(|s| {
503            while let Some(batch) = self.next_batch() {
504                s.spawn(async move { batch.product() });
505            }
506        })
507        .into_iter()
508        .product()
509    }
510}