bevy_tasks/iter/mod.rs
1use crate::TaskPool;
2
3mod adapters;
4pub use adapters::*;
5
6/// [`ParallelIterator`] closely emulates the `std::iter::Iterator`
7/// interface. However, it uses `bevy_task` to compute batches in parallel.
8///
9/// Note that the overhead of [`ParallelIterator`] is high relative to some
10/// workloads. In particular, if the batch size is too small or task being
11/// run in parallel is inexpensive, *a [`ParallelIterator`] could take longer
12/// than a normal [`Iterator`]*. Therefore, you should profile your code before
13/// using [`ParallelIterator`].
14pub trait ParallelIterator<BatchIter>
15where
16 BatchIter: Iterator + Send,
17 Self: Sized + Send,
18{
19 /// Returns the next batch of items for processing.
20 ///
21 /// Each batch is an iterator with items of the same type as the
22 /// [`ParallelIterator`]. Returns `None` when there are no batches left.
23 fn next_batch(&mut self) -> Option<BatchIter>;
24
25 /// Returns the bounds on the remaining number of items in the
26 /// parallel iterator.
27 ///
28 /// See [`Iterator::size_hint()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.size_hint)
29 fn size_hint(&self) -> (usize, Option<usize>) {
30 (0, None)
31 }
32
33 /// Consumes the parallel iterator and returns the number of items.
34 ///
35 /// See [`Iterator::count()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.count)
36 fn count(mut self, pool: &TaskPool) -> usize {
37 pool.scope(|s| {
38 while let Some(batch) = self.next_batch() {
39 s.spawn(async move { batch.count() });
40 }
41 })
42 .iter()
43 .sum()
44 }
45
46 /// Consumes the parallel iterator and returns the last item.
47 ///
48 /// See [`Iterator::last()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.last)
49 fn last(mut self, _pool: &TaskPool) -> Option<BatchIter::Item> {
50 let mut last_item = None;
51 while let Some(batch) = self.next_batch() {
52 last_item = batch.last();
53 }
54 last_item
55 }
56
57 /// Consumes the parallel iterator and returns the nth item.
58 ///
59 /// See [`Iterator::nth()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.nth)
60 // TODO: Optimize with size_hint on each batch
61 fn nth(mut self, _pool: &TaskPool, n: usize) -> Option<BatchIter::Item> {
62 let mut i = 0;
63 while let Some(batch) = self.next_batch() {
64 for item in batch {
65 if i == n {
66 return Some(item);
67 }
68 i += 1;
69 }
70 }
71 None
72 }
73
74 /// Takes two parallel iterators and returns a parallel iterators over
75 /// both in sequence.
76 ///
77 /// See [`Iterator::chain()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.chain)
78 // TODO: Use IntoParallelIterator for U
79 fn chain<U>(self, other: U) -> Chain<Self, U>
80 where
81 U: ParallelIterator<BatchIter>,
82 {
83 Chain {
84 left: self,
85 right: other,
86 left_in_progress: true,
87 }
88 }
89
90 /// Takes a closure and creates a parallel iterator which calls that
91 /// closure on each item.
92 ///
93 /// See [`Iterator::map()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.map)
94 fn map<T, F>(self, f: F) -> Map<Self, F>
95 where
96 F: FnMut(BatchIter::Item) -> T + Send + Clone,
97 {
98 Map { iter: self, f }
99 }
100
101 /// Calls a closure on each item of a parallel iterator.
102 ///
103 /// See [`Iterator::for_each()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.for_each)
104 fn for_each<F>(mut self, pool: &TaskPool, f: F)
105 where
106 F: FnMut(BatchIter::Item) + Send + Clone + Sync,
107 {
108 pool.scope(|s| {
109 while let Some(batch) = self.next_batch() {
110 let newf = f.clone();
111 s.spawn(async move {
112 batch.for_each(newf);
113 });
114 }
115 });
116 }
117
118 /// Creates a parallel iterator which uses a closure to determine
119 /// if an element should be yielded.
120 ///
121 /// See [`Iterator::filter()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.filter)
122 fn filter<F>(self, predicate: F) -> Filter<Self, F>
123 where
124 F: FnMut(&BatchIter::Item) -> bool,
125 {
126 Filter {
127 iter: self,
128 predicate,
129 }
130 }
131
132 /// Creates a parallel iterator that both filters and maps.
133 ///
134 /// See [`Iterator::filter_map()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.filter_map)
135 fn filter_map<R, F>(self, f: F) -> FilterMap<Self, F>
136 where
137 F: FnMut(BatchIter::Item) -> Option<R>,
138 {
139 FilterMap { iter: self, f }
140 }
141
142 /// Creates a parallel iterator that works like map, but flattens
143 /// nested structure.
144 ///
145 /// See [`Iterator::flat_map()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.flat_map)
146 fn flat_map<U, F>(self, f: F) -> FlatMap<Self, F>
147 where
148 F: FnMut(BatchIter::Item) -> U,
149 U: IntoIterator,
150 {
151 FlatMap { iter: self, f }
152 }
153
154 /// Creates a parallel iterator that flattens nested structure.
155 ///
156 /// See [`Iterator::flatten()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.flatten)
157 fn flatten(self) -> Flatten<Self>
158 where
159 BatchIter::Item: IntoIterator,
160 {
161 Flatten { iter: self }
162 }
163
164 /// Creates a parallel iterator which ends after the first None.
165 ///
166 /// See [`Iterator::fuse()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.fuse)
167 fn fuse(self) -> Fuse<Self> {
168 Fuse { iter: Some(self) }
169 }
170
171 /// Does something with each item of a parallel iterator, passing
172 /// the value on.
173 ///
174 /// See [`Iterator::inspect()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.inspect)
175 fn inspect<F>(self, f: F) -> Inspect<Self, F>
176 where
177 F: FnMut(&BatchIter::Item),
178 {
179 Inspect { iter: self, f }
180 }
181
182 /// Borrows a parallel iterator, rather than consuming it.
183 ///
184 /// See [`Iterator::by_ref()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.by_ref)
185 fn by_ref(&mut self) -> &mut Self {
186 self
187 }
188
189 /// Transforms a parallel iterator into a collection.
190 ///
191 /// See [`Iterator::collect()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.collect)
192 // TODO: Investigate optimizations for less copying
193 fn collect<C>(mut self, pool: &TaskPool) -> C
194 where
195 C: FromIterator<BatchIter::Item>,
196 BatchIter::Item: Send + 'static,
197 {
198 pool.scope(|s| {
199 while let Some(batch) = self.next_batch() {
200 s.spawn(async move { batch.collect::<Vec<_>>() });
201 }
202 })
203 .into_iter()
204 .flatten()
205 .collect()
206 }
207
208 /// Consumes a parallel iterator, creating two collections from it.
209 ///
210 /// See [`Iterator::partition()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.partition)
211 // TODO: Investigate optimizations for less copying
212 fn partition<C, F>(mut self, pool: &TaskPool, f: F) -> (C, C)
213 where
214 C: Default + Extend<BatchIter::Item> + Send,
215 F: FnMut(&BatchIter::Item) -> bool + Send + Sync + Clone,
216 BatchIter::Item: Send + 'static,
217 {
218 let (mut a, mut b) = <(C, C)>::default();
219 pool.scope(|s| {
220 while let Some(batch) = self.next_batch() {
221 let newf = f.clone();
222 s.spawn(async move { batch.partition::<Vec<_>, F>(newf) });
223 }
224 })
225 .into_iter()
226 .for_each(|(c, d)| {
227 a.extend(c);
228 b.extend(d);
229 });
230 (a, b)
231 }
232
233 /// Repeatedly applies a function to items of each batch of a parallel
234 /// iterator, producing a Vec of final values.
235 ///
236 /// *Note that this folds each batch independently and returns a Vec of
237 /// results (in batch order).*
238 ///
239 /// See [`Iterator::fold()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.fold)
240 fn fold<C, F, D>(mut self, pool: &TaskPool, init: C, f: F) -> Vec<C>
241 where
242 F: FnMut(C, BatchIter::Item) -> C + Send + Sync + Clone,
243 C: Clone + Send + Sync + 'static,
244 {
245 pool.scope(|s| {
246 while let Some(batch) = self.next_batch() {
247 let newf = f.clone();
248 let newi = init.clone();
249 s.spawn(async move { batch.fold(newi, newf) });
250 }
251 })
252 }
253
254 /// Tests if every element of the parallel iterator matches a predicate.
255 ///
256 /// *Note that all is **not** short circuiting.*
257 ///
258 /// See [`Iterator::all()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.all)
259 fn all<F>(mut self, pool: &TaskPool, f: F) -> bool
260 where
261 F: FnMut(BatchIter::Item) -> bool + Send + Sync + Clone,
262 {
263 pool.scope(|s| {
264 while let Some(mut batch) = self.next_batch() {
265 let newf = f.clone();
266 s.spawn(async move { batch.all(newf) });
267 }
268 })
269 .into_iter()
270 .all(core::convert::identity)
271 }
272
273 /// Tests if any element of the parallel iterator matches a predicate.
274 ///
275 /// *Note that any is **not** short circuiting.*
276 ///
277 /// See [`Iterator::any()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.any)
278 fn any<F>(mut self, pool: &TaskPool, f: F) -> bool
279 where
280 F: FnMut(BatchIter::Item) -> bool + Send + Sync + Clone,
281 {
282 pool.scope(|s| {
283 while let Some(mut batch) = self.next_batch() {
284 let newf = f.clone();
285 s.spawn(async move { batch.any(newf) });
286 }
287 })
288 .into_iter()
289 .any(core::convert::identity)
290 }
291
292 /// Searches for an element in a parallel iterator, returning its index.
293 ///
294 /// *Note that position consumes the whole iterator.*
295 ///
296 /// See [`Iterator::position()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.position)
297 // TODO: Investigate optimizations for less copying
298 fn position<F>(mut self, pool: &TaskPool, f: F) -> Option<usize>
299 where
300 F: FnMut(BatchIter::Item) -> bool + Send + Sync + Clone,
301 {
302 let poses = pool.scope(|s| {
303 while let Some(batch) = self.next_batch() {
304 let mut newf = f.clone();
305 s.spawn(async move {
306 let mut len = 0;
307 let mut pos = None;
308 for item in batch {
309 if pos.is_none() && newf(item) {
310 pos = Some(len);
311 }
312 len += 1;
313 }
314 (len, pos)
315 });
316 }
317 });
318 let mut start = 0;
319 for (len, pos) in poses {
320 if let Some(pos) = pos {
321 return Some(start + pos);
322 }
323 start += len;
324 }
325 None
326 }
327
328 /// Returns the maximum item of a parallel iterator.
329 ///
330 /// See [`Iterator::max()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.max)
331 fn max(mut self, pool: &TaskPool) -> Option<BatchIter::Item>
332 where
333 BatchIter::Item: Ord + Send + 'static,
334 {
335 pool.scope(|s| {
336 while let Some(batch) = self.next_batch() {
337 s.spawn(async move { batch.max() });
338 }
339 })
340 .into_iter()
341 .flatten()
342 .max()
343 }
344
345 /// Returns the minimum item of a parallel iterator.
346 ///
347 /// See [`Iterator::min()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.min)
348 fn min(mut self, pool: &TaskPool) -> Option<BatchIter::Item>
349 where
350 BatchIter::Item: Ord + Send + 'static,
351 {
352 pool.scope(|s| {
353 while let Some(batch) = self.next_batch() {
354 s.spawn(async move { batch.min() });
355 }
356 })
357 .into_iter()
358 .flatten()
359 .min()
360 }
361
362 /// Returns the item that gives the maximum value from the specified function.
363 ///
364 /// See [`Iterator::max_by_key()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.max_by_key)
365 fn max_by_key<R, F>(mut self, pool: &TaskPool, f: F) -> Option<BatchIter::Item>
366 where
367 R: Ord,
368 F: FnMut(&BatchIter::Item) -> R + Send + Sync + Clone,
369 BatchIter::Item: Send + 'static,
370 {
371 pool.scope(|s| {
372 while let Some(batch) = self.next_batch() {
373 let newf = f.clone();
374 s.spawn(async move { batch.max_by_key(newf) });
375 }
376 })
377 .into_iter()
378 .flatten()
379 .max_by_key(f)
380 }
381
382 /// Returns the item that gives the maximum value with respect to the specified comparison
383 /// function.
384 ///
385 /// See [`Iterator::max_by()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.max_by)
386 fn max_by<F>(mut self, pool: &TaskPool, f: F) -> Option<BatchIter::Item>
387 where
388 F: FnMut(&BatchIter::Item, &BatchIter::Item) -> core::cmp::Ordering + Send + Sync + Clone,
389 BatchIter::Item: Send + 'static,
390 {
391 pool.scope(|s| {
392 while let Some(batch) = self.next_batch() {
393 let newf = f.clone();
394 s.spawn(async move { batch.max_by(newf) });
395 }
396 })
397 .into_iter()
398 .flatten()
399 .max_by(f)
400 }
401
402 /// Returns the item that gives the minimum value from the specified function.
403 ///
404 /// See [`Iterator::min_by_key()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.min_by_key)
405 fn min_by_key<R, F>(mut self, pool: &TaskPool, f: F) -> Option<BatchIter::Item>
406 where
407 R: Ord,
408 F: FnMut(&BatchIter::Item) -> R + Send + Sync + Clone,
409 BatchIter::Item: Send + 'static,
410 {
411 pool.scope(|s| {
412 while let Some(batch) = self.next_batch() {
413 let newf = f.clone();
414 s.spawn(async move { batch.min_by_key(newf) });
415 }
416 })
417 .into_iter()
418 .flatten()
419 .min_by_key(f)
420 }
421
422 /// Returns the item that gives the minimum value with respect to the specified comparison
423 /// function.
424 ///
425 /// See [`Iterator::min_by()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.min_by)
426 fn min_by<F>(mut self, pool: &TaskPool, f: F) -> Option<BatchIter::Item>
427 where
428 F: FnMut(&BatchIter::Item, &BatchIter::Item) -> core::cmp::Ordering + Send + Sync + Clone,
429 BatchIter::Item: Send + 'static,
430 {
431 pool.scope(|s| {
432 while let Some(batch) = self.next_batch() {
433 let newf = f.clone();
434 s.spawn(async move { batch.min_by(newf) });
435 }
436 })
437 .into_iter()
438 .flatten()
439 .min_by(f)
440 }
441
442 /// Creates a parallel iterator which copies all of its items.
443 ///
444 /// See [`Iterator::copied()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.copied)
445 fn copied<'a, T>(self) -> Copied<Self>
446 where
447 Self: ParallelIterator<BatchIter>,
448 T: 'a + Copy,
449 {
450 Copied { iter: self }
451 }
452
453 /// Creates a parallel iterator which clones all of its items.
454 ///
455 /// See [`Iterator::cloned()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.cloned)
456 fn cloned<'a, T>(self) -> Cloned<Self>
457 where
458 Self: ParallelIterator<BatchIter>,
459 T: 'a + Copy,
460 {
461 Cloned { iter: self }
462 }
463
464 /// Repeats a parallel iterator endlessly.
465 ///
466 /// See [`Iterator::cycle()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.cycle)
467 fn cycle(self) -> Cycle<Self>
468 where
469 Self: Clone,
470 {
471 Cycle {
472 iter: self,
473 curr: None,
474 }
475 }
476
477 /// Sums the items of a parallel iterator.
478 ///
479 /// See [`Iterator::sum()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.sum)
480 fn sum<S, R>(mut self, pool: &TaskPool) -> R
481 where
482 S: core::iter::Sum<BatchIter::Item> + Send + 'static,
483 R: core::iter::Sum<S>,
484 {
485 pool.scope(|s| {
486 while let Some(batch) = self.next_batch() {
487 s.spawn(async move { batch.sum() });
488 }
489 })
490 .into_iter()
491 .sum()
492 }
493
494 /// Multiplies all the items of a parallel iterator.
495 ///
496 /// See [`Iterator::product()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.product)
497 fn product<S, R>(mut self, pool: &TaskPool) -> R
498 where
499 S: core::iter::Product<BatchIter::Item> + Send + 'static,
500 R: core::iter::Product<S>,
501 {
502 pool.scope(|s| {
503 while let Some(batch) = self.next_batch() {
504 s.spawn(async move { batch.product() });
505 }
506 })
507 .into_iter()
508 .product()
509 }
510}