bevy_tasks/iter/mod.rs
1use crate::TaskPool;
2use alloc::vec::Vec;
3
4mod adapters;
5pub use adapters::*;
6
7/// [`ParallelIterator`] closely emulates the `std::iter::Iterator`
8/// interface. However, it uses `bevy_task` to compute batches in parallel.
9///
10/// Note that the overhead of [`ParallelIterator`] is high relative to some
11/// workloads. In particular, if the batch size is too small or task being
12/// run in parallel is inexpensive, *a [`ParallelIterator`] could take longer
13/// than a normal [`Iterator`]*. Therefore, you should profile your code before
14/// using [`ParallelIterator`].
15pub trait ParallelIterator<BatchIter>
16where
17 BatchIter: Iterator + Send,
18 Self: Sized + Send,
19{
20 /// Returns the next batch of items for processing.
21 ///
22 /// Each batch is an iterator with items of the same type as the
23 /// [`ParallelIterator`]. Returns `None` when there are no batches left.
24 fn next_batch(&mut self) -> Option<BatchIter>;
25
26 /// Returns the bounds on the remaining number of items in the
27 /// parallel iterator.
28 ///
29 /// See [`Iterator::size_hint()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.size_hint)
30 fn size_hint(&self) -> (usize, Option<usize>) {
31 (0, None)
32 }
33
34 /// Consumes the parallel iterator and returns the number of items.
35 ///
36 /// See [`Iterator::count()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.count)
37 fn count(mut self, pool: &TaskPool) -> usize {
38 pool.scope(|s| {
39 while let Some(batch) = self.next_batch() {
40 s.spawn(async move { batch.count() });
41 }
42 })
43 .iter()
44 .sum()
45 }
46
47 /// Consumes the parallel iterator and returns the last item.
48 ///
49 /// See [`Iterator::last()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.last)
50 fn last(mut self, _pool: &TaskPool) -> Option<BatchIter::Item> {
51 let mut last_item = None;
52 while let Some(batch) = self.next_batch() {
53 last_item = batch.last();
54 }
55 last_item
56 }
57
58 /// Consumes the parallel iterator and returns the nth item.
59 ///
60 /// See [`Iterator::nth()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.nth)
61 // TODO: Optimize with size_hint on each batch
62 fn nth(mut self, _pool: &TaskPool, n: usize) -> Option<BatchIter::Item> {
63 let mut i = 0;
64 while let Some(batch) = self.next_batch() {
65 for item in batch {
66 if i == n {
67 return Some(item);
68 }
69 i += 1;
70 }
71 }
72 None
73 }
74
75 /// Takes two parallel iterators and returns a parallel iterators over
76 /// both in sequence.
77 ///
78 /// See [`Iterator::chain()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.chain)
79 // TODO: Use IntoParallelIterator for U
80 fn chain<U>(self, other: U) -> Chain<Self, U>
81 where
82 U: ParallelIterator<BatchIter>,
83 {
84 Chain {
85 left: self,
86 right: other,
87 left_in_progress: true,
88 }
89 }
90
91 /// Takes a closure and creates a parallel iterator which calls that
92 /// closure on each item.
93 ///
94 /// See [`Iterator::map()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.map)
95 fn map<T, F>(self, f: F) -> Map<Self, F>
96 where
97 F: FnMut(BatchIter::Item) -> T + Send + Clone,
98 {
99 Map { iter: self, f }
100 }
101
102 /// Calls a closure on each item of a parallel iterator.
103 ///
104 /// See [`Iterator::for_each()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.for_each)
105 fn for_each<F>(mut self, pool: &TaskPool, f: F)
106 where
107 F: FnMut(BatchIter::Item) + Send + Clone + Sync,
108 {
109 pool.scope(|s| {
110 while let Some(batch) = self.next_batch() {
111 let newf = f.clone();
112 s.spawn(async move {
113 batch.for_each(newf);
114 });
115 }
116 });
117 }
118
119 /// Creates a parallel iterator which uses a closure to determine
120 /// if an element should be yielded.
121 ///
122 /// See [`Iterator::filter()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.filter)
123 fn filter<F>(self, predicate: F) -> Filter<Self, F>
124 where
125 F: FnMut(&BatchIter::Item) -> bool,
126 {
127 Filter {
128 iter: self,
129 predicate,
130 }
131 }
132
133 /// Creates a parallel iterator that both filters and maps.
134 ///
135 /// See [`Iterator::filter_map()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.filter_map)
136 fn filter_map<R, F>(self, f: F) -> FilterMap<Self, F>
137 where
138 F: FnMut(BatchIter::Item) -> Option<R>,
139 {
140 FilterMap { iter: self, f }
141 }
142
143 /// Creates a parallel iterator that works like map, but flattens
144 /// nested structure.
145 ///
146 /// See [`Iterator::flat_map()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.flat_map)
147 fn flat_map<U, F>(self, f: F) -> FlatMap<Self, F>
148 where
149 F: FnMut(BatchIter::Item) -> U,
150 U: IntoIterator,
151 {
152 FlatMap { iter: self, f }
153 }
154
155 /// Creates a parallel iterator that flattens nested structure.
156 ///
157 /// See [`Iterator::flatten()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.flatten)
158 fn flatten(self) -> Flatten<Self>
159 where
160 BatchIter::Item: IntoIterator,
161 {
162 Flatten { iter: self }
163 }
164
165 /// Creates a parallel iterator which ends after the first None.
166 ///
167 /// See [`Iterator::fuse()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.fuse)
168 fn fuse(self) -> Fuse<Self> {
169 Fuse { iter: Some(self) }
170 }
171
172 /// Does something with each item of a parallel iterator, passing
173 /// the value on.
174 ///
175 /// See [`Iterator::inspect()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.inspect)
176 fn inspect<F>(self, f: F) -> Inspect<Self, F>
177 where
178 F: FnMut(&BatchIter::Item),
179 {
180 Inspect { iter: self, f }
181 }
182
183 /// Borrows a parallel iterator, rather than consuming it.
184 ///
185 /// See [`Iterator::by_ref()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.by_ref)
186 fn by_ref(&mut self) -> &mut Self {
187 self
188 }
189
190 /// Transforms a parallel iterator into a collection.
191 ///
192 /// See [`Iterator::collect()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.collect)
193 // TODO: Investigate optimizations for less copying
194 fn collect<C>(mut self, pool: &TaskPool) -> C
195 where
196 C: FromIterator<BatchIter::Item>,
197 BatchIter::Item: Send + 'static,
198 {
199 pool.scope(|s| {
200 while let Some(batch) = self.next_batch() {
201 s.spawn(async move { batch.collect::<Vec<_>>() });
202 }
203 })
204 .into_iter()
205 .flatten()
206 .collect()
207 }
208
209 /// Consumes a parallel iterator, creating two collections from it.
210 ///
211 /// See [`Iterator::partition()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.partition)
212 // TODO: Investigate optimizations for less copying
213 fn partition<C, F>(mut self, pool: &TaskPool, f: F) -> (C, C)
214 where
215 C: Default + Extend<BatchIter::Item> + Send,
216 F: FnMut(&BatchIter::Item) -> bool + Send + Sync + Clone,
217 BatchIter::Item: Send + 'static,
218 {
219 let (mut a, mut b) = <(C, C)>::default();
220 pool.scope(|s| {
221 while let Some(batch) = self.next_batch() {
222 let newf = f.clone();
223 s.spawn(async move { batch.partition::<Vec<_>, F>(newf) });
224 }
225 })
226 .into_iter()
227 .for_each(|(c, d)| {
228 a.extend(c);
229 b.extend(d);
230 });
231 (a, b)
232 }
233
234 /// Repeatedly applies a function to items of each batch of a parallel
235 /// iterator, producing a Vec of final values.
236 ///
237 /// *Note that this folds each batch independently and returns a Vec of
238 /// results (in batch order).*
239 ///
240 /// See [`Iterator::fold()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.fold)
241 fn fold<C, F, D>(mut self, pool: &TaskPool, init: C, f: F) -> Vec<C>
242 where
243 F: FnMut(C, BatchIter::Item) -> C + Send + Sync + Clone,
244 C: Clone + Send + Sync + 'static,
245 {
246 pool.scope(|s| {
247 while let Some(batch) = self.next_batch() {
248 let newf = f.clone();
249 let newi = init.clone();
250 s.spawn(async move { batch.fold(newi, newf) });
251 }
252 })
253 }
254
255 /// Tests if every element of the parallel iterator matches a predicate.
256 ///
257 /// *Note that all is **not** short circuiting.*
258 ///
259 /// See [`Iterator::all()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.all)
260 fn all<F>(mut self, pool: &TaskPool, f: F) -> bool
261 where
262 F: FnMut(BatchIter::Item) -> bool + Send + Sync + Clone,
263 {
264 pool.scope(|s| {
265 while let Some(mut batch) = self.next_batch() {
266 let newf = f.clone();
267 s.spawn(async move { batch.all(newf) });
268 }
269 })
270 .into_iter()
271 .all(core::convert::identity)
272 }
273
274 /// Tests if any element of the parallel iterator matches a predicate.
275 ///
276 /// *Note that any is **not** short circuiting.*
277 ///
278 /// See [`Iterator::any()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.any)
279 fn any<F>(mut self, pool: &TaskPool, f: F) -> bool
280 where
281 F: FnMut(BatchIter::Item) -> bool + Send + Sync + Clone,
282 {
283 pool.scope(|s| {
284 while let Some(mut batch) = self.next_batch() {
285 let newf = f.clone();
286 s.spawn(async move { batch.any(newf) });
287 }
288 })
289 .into_iter()
290 .any(core::convert::identity)
291 }
292
293 /// Searches for an element in a parallel iterator, returning its index.
294 ///
295 /// *Note that position consumes the whole iterator.*
296 ///
297 /// See [`Iterator::position()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.position)
298 // TODO: Investigate optimizations for less copying
299 fn position<F>(mut self, pool: &TaskPool, f: F) -> Option<usize>
300 where
301 F: FnMut(BatchIter::Item) -> bool + Send + Sync + Clone,
302 {
303 let poses = pool.scope(|s| {
304 while let Some(batch) = self.next_batch() {
305 let mut newf = f.clone();
306 s.spawn(async move {
307 let mut len = 0;
308 let mut pos = None;
309 for item in batch {
310 if pos.is_none() && newf(item) {
311 pos = Some(len);
312 }
313 len += 1;
314 }
315 (len, pos)
316 });
317 }
318 });
319 let mut start = 0;
320 for (len, pos) in poses {
321 if let Some(pos) = pos {
322 return Some(start + pos);
323 }
324 start += len;
325 }
326 None
327 }
328
329 /// Returns the maximum item of a parallel iterator.
330 ///
331 /// See [`Iterator::max()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.max)
332 fn max(mut self, pool: &TaskPool) -> Option<BatchIter::Item>
333 where
334 BatchIter::Item: Ord + Send + 'static,
335 {
336 pool.scope(|s| {
337 while let Some(batch) = self.next_batch() {
338 s.spawn(async move { batch.max() });
339 }
340 })
341 .into_iter()
342 .flatten()
343 .max()
344 }
345
346 /// Returns the minimum item of a parallel iterator.
347 ///
348 /// See [`Iterator::min()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.min)
349 fn min(mut self, pool: &TaskPool) -> Option<BatchIter::Item>
350 where
351 BatchIter::Item: Ord + Send + 'static,
352 {
353 pool.scope(|s| {
354 while let Some(batch) = self.next_batch() {
355 s.spawn(async move { batch.min() });
356 }
357 })
358 .into_iter()
359 .flatten()
360 .min()
361 }
362
363 /// Returns the item that gives the maximum value from the specified function.
364 ///
365 /// See [`Iterator::max_by_key()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.max_by_key)
366 fn max_by_key<R, F>(mut self, pool: &TaskPool, f: F) -> Option<BatchIter::Item>
367 where
368 R: Ord,
369 F: FnMut(&BatchIter::Item) -> R + Send + Sync + Clone,
370 BatchIter::Item: Send + 'static,
371 {
372 pool.scope(|s| {
373 while let Some(batch) = self.next_batch() {
374 let newf = f.clone();
375 s.spawn(async move { batch.max_by_key(newf) });
376 }
377 })
378 .into_iter()
379 .flatten()
380 .max_by_key(f)
381 }
382
383 /// Returns the item that gives the maximum value with respect to the specified comparison
384 /// function.
385 ///
386 /// See [`Iterator::max_by()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.max_by)
387 fn max_by<F>(mut self, pool: &TaskPool, f: F) -> Option<BatchIter::Item>
388 where
389 F: FnMut(&BatchIter::Item, &BatchIter::Item) -> core::cmp::Ordering + Send + Sync + Clone,
390 BatchIter::Item: Send + 'static,
391 {
392 pool.scope(|s| {
393 while let Some(batch) = self.next_batch() {
394 let newf = f.clone();
395 s.spawn(async move { batch.max_by(newf) });
396 }
397 })
398 .into_iter()
399 .flatten()
400 .max_by(f)
401 }
402
403 /// Returns the item that gives the minimum value from the specified function.
404 ///
405 /// See [`Iterator::min_by_key()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.min_by_key)
406 fn min_by_key<R, F>(mut self, pool: &TaskPool, f: F) -> Option<BatchIter::Item>
407 where
408 R: Ord,
409 F: FnMut(&BatchIter::Item) -> R + Send + Sync + Clone,
410 BatchIter::Item: Send + 'static,
411 {
412 pool.scope(|s| {
413 while let Some(batch) = self.next_batch() {
414 let newf = f.clone();
415 s.spawn(async move { batch.min_by_key(newf) });
416 }
417 })
418 .into_iter()
419 .flatten()
420 .min_by_key(f)
421 }
422
423 /// Returns the item that gives the minimum value with respect to the specified comparison
424 /// function.
425 ///
426 /// See [`Iterator::min_by()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.min_by)
427 fn min_by<F>(mut self, pool: &TaskPool, f: F) -> Option<BatchIter::Item>
428 where
429 F: FnMut(&BatchIter::Item, &BatchIter::Item) -> core::cmp::Ordering + Send + Sync + Clone,
430 BatchIter::Item: Send + 'static,
431 {
432 pool.scope(|s| {
433 while let Some(batch) = self.next_batch() {
434 let newf = f.clone();
435 s.spawn(async move { batch.min_by(newf) });
436 }
437 })
438 .into_iter()
439 .flatten()
440 .min_by(f)
441 }
442
443 /// Creates a parallel iterator which copies all of its items.
444 ///
445 /// See [`Iterator::copied()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.copied)
446 fn copied<'a, T>(self) -> Copied<Self>
447 where
448 Self: ParallelIterator<BatchIter>,
449 T: 'a + Copy,
450 {
451 Copied { iter: self }
452 }
453
454 /// Creates a parallel iterator which clones all of its items.
455 ///
456 /// See [`Iterator::cloned()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.cloned)
457 fn cloned<'a, T>(self) -> Cloned<Self>
458 where
459 Self: ParallelIterator<BatchIter>,
460 T: 'a + Copy,
461 {
462 Cloned { iter: self }
463 }
464
465 /// Repeats a parallel iterator endlessly.
466 ///
467 /// See [`Iterator::cycle()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.cycle)
468 fn cycle(self) -> Cycle<Self>
469 where
470 Self: Clone,
471 {
472 Cycle {
473 iter: self,
474 curr: None,
475 }
476 }
477
478 /// Sums the items of a parallel iterator.
479 ///
480 /// See [`Iterator::sum()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.sum)
481 fn sum<S, R>(mut self, pool: &TaskPool) -> R
482 where
483 S: core::iter::Sum<BatchIter::Item> + Send + 'static,
484 R: core::iter::Sum<S>,
485 {
486 pool.scope(|s| {
487 while let Some(batch) = self.next_batch() {
488 s.spawn(async move { batch.sum() });
489 }
490 })
491 .into_iter()
492 .sum()
493 }
494
495 /// Multiplies all the items of a parallel iterator.
496 ///
497 /// See [`Iterator::product()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.product)
498 fn product<S, R>(mut self, pool: &TaskPool) -> R
499 where
500 S: core::iter::Product<BatchIter::Item> + Send + 'static,
501 R: core::iter::Product<S>,
502 {
503 pool.scope(|s| {
504 while let Some(batch) = self.next_batch() {
505 s.spawn(async move { batch.product() });
506 }
507 })
508 .into_iter()
509 .product()
510 }
511}