// bevy_tasks/task_pool.rs
1use alloc::{boxed::Box, format, string::String, vec::Vec};
2use core::{future::Future, marker::PhantomData, mem, panic::AssertUnwindSafe};
3use std::{
4 thread::{self, JoinHandle},
5 thread_local,
6};
7
8use crate::executor::FallibleTask;
9use bevy_platform::sync::Arc;
10use concurrent_queue::ConcurrentQueue;
11use futures_lite::FutureExt;
12
13use crate::{
14 block_on,
15 thread_executor::{ThreadExecutor, ThreadExecutorTicker},
16 Task,
17};
18
/// Guard that invokes an optional callback when it is dropped.
///
/// Worker threads keep one of these alive for the duration of their run loop so the
/// `on_thread_destroy` hook fires when the loop ends. Because the callback runs from
/// `Drop`, it also fires if the owning scope is left by unwinding.
struct CallOnDrop(Option<Arc<dyn Fn() + Send + Sync + 'static>>);

impl Drop for CallOnDrop {
    fn drop(&mut self) {
        // Nothing to do when no callback was registered.
        let Some(callback) = &self.0 else {
            return;
        };
        callback();
    }
}
28
/// Used to create a [`TaskPool`].
#[derive(Default)]
#[must_use]
pub struct TaskPoolBuilder {
    /// If set, the pool spawns exactly `num_threads` worker threads.
    /// Otherwise it uses the logical core count of the system.
    num_threads: Option<usize>,
    /// If set, we'll use the given stack size rather than the system default
    stack_size: Option<usize>,
    /// Allows customizing the name of the threads - helpful for debugging. If set, threads will
    /// be named `<thread_name> (<thread_index>)`, e.g. `"MyThreadPool (2)"`.
    thread_name: Option<String>,

    /// Callback run on each worker thread when it starts, before it begins executing tasks.
    on_thread_spawn: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
    /// Callback run on each worker thread when its run loop exits.
    on_thread_destroy: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
}
45
46impl TaskPoolBuilder {
47 /// Creates a new [`TaskPoolBuilder`] instance
48 pub fn new() -> Self {
49 Self::default()
50 }
51
52 /// Override the number of threads created for the pool. If unset, we default to the number
53 /// of logical cores of the system
54 pub fn num_threads(mut self, num_threads: usize) -> Self {
55 self.num_threads = Some(num_threads);
56 self
57 }
58
59 /// Override the stack size of the threads created for the pool
60 pub fn stack_size(mut self, stack_size: usize) -> Self {
61 self.stack_size = Some(stack_size);
62 self
63 }
64
65 /// Override the name of the threads created for the pool. If set, threads will
66 /// be named `<thread_name> (<thread_index>)`, i.e. `MyThreadPool (2)`
67 pub fn thread_name(mut self, thread_name: String) -> Self {
68 self.thread_name = Some(thread_name);
69 self
70 }
71
72 /// Sets a callback that is invoked once for every created thread as it starts.
73 ///
74 /// This is called on the thread itself and has access to all thread-local storage.
75 /// This will block running async tasks on the thread until the callback completes.
76 pub fn on_thread_spawn(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
77 let arc = Arc::new(f);
78
79 #[cfg(not(target_has_atomic = "ptr"))]
80 #[expect(
81 unsafe_code,
82 reason = "unsized coercion is an unstable feature for non-std types"
83 )]
84 // SAFETY:
85 // - Coercion from `impl Fn` to `dyn Fn` is valid
86 // - `Arc::from_raw` receives a valid pointer from a previous call to `Arc::into_raw`
87 let arc = unsafe {
88 Arc::from_raw(Arc::into_raw(arc) as *const (dyn Fn() + Send + Sync + 'static))
89 };
90
91 self.on_thread_spawn = Some(arc);
92 self
93 }
94
95 /// Sets a callback that is invoked once for every created thread as it terminates.
96 ///
97 /// This is called on the thread itself and has access to all thread-local storage.
98 /// This will block thread termination until the callback completes.
99 pub fn on_thread_destroy(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
100 let arc = Arc::new(f);
101
102 #[cfg(not(target_has_atomic = "ptr"))]
103 #[expect(
104 unsafe_code,
105 reason = "unsized coercion is an unstable feature for non-std types"
106 )]
107 // SAFETY:
108 // - Coercion from `impl Fn` to `dyn Fn` is valid
109 // - `Arc::from_raw` receives a valid pointer from a previous call to `Arc::into_raw`
110 let arc = unsafe {
111 Arc::from_raw(Arc::into_raw(arc) as *const (dyn Fn() + Send + Sync + 'static))
112 };
113
114 self.on_thread_destroy = Some(arc);
115 self
116 }
117
118 /// Creates a new [`TaskPool`] based on the current options.
119 pub fn build(self) -> TaskPool {
120 TaskPool::new_internal(self)
121 }
122}
123
124/// A thread pool for executing tasks.
125///
126/// While futures usually need to be polled to be executed, Bevy tasks are being
127/// automatically driven by the pool on threads owned by the pool. The [`Task`]
128/// future only needs to be polled in order to receive the result. (For that
129/// purpose, it is often stored in a component or resource, see the
130/// `async_compute` example.)
131///
132/// If the result is not required, one may also use [`Task::detach`] and the pool
133/// will still execute a task, even if it is dropped.
134#[derive(Debug)]
135pub struct TaskPool {
136 /// The executor for the pool.
137 executor: Arc<crate::executor::Executor<'static>>,
138
139 // The inner state of the pool.
140 threads: Vec<JoinHandle<()>>,
141 shutdown_tx: async_channel::Sender<()>,
142}
143
144impl TaskPool {
145 thread_local! {
146 static LOCAL_EXECUTOR: crate::executor::LocalExecutor<'static> = const { crate::executor::LocalExecutor::new() };
147 static THREAD_EXECUTOR: Arc<ThreadExecutor<'static>> = Arc::new(ThreadExecutor::new());
148 }
149
150 /// Each thread should only create one `ThreadExecutor`, otherwise, there are good chances they will deadlock
151 pub fn get_thread_executor() -> Arc<ThreadExecutor<'static>> {
152 Self::THREAD_EXECUTOR.with(Clone::clone)
153 }
154
155 /// Create a `TaskPool` with the default configuration.
156 pub fn new() -> Self {
157 TaskPoolBuilder::new().build()
158 }
159
160 fn new_internal(builder: TaskPoolBuilder) -> Self {
161 let (shutdown_tx, shutdown_rx) = async_channel::unbounded::<()>();
162
163 let executor = Arc::new(crate::executor::Executor::new());
164
165 let num_threads = builder
166 .num_threads
167 .unwrap_or_else(crate::available_parallelism);
168
169 let threads = (0..num_threads)
170 .map(|i| {
171 let ex = Arc::clone(&executor);
172 let shutdown_rx = shutdown_rx.clone();
173
174 let thread_name = if let Some(thread_name) = builder.thread_name.as_deref() {
175 format!("{thread_name} ({i})")
176 } else {
177 format!("TaskPool ({i})")
178 };
179 let mut thread_builder = thread::Builder::new().name(thread_name);
180
181 if let Some(stack_size) = builder.stack_size {
182 thread_builder = thread_builder.stack_size(stack_size);
183 }
184
185 let on_thread_spawn = builder.on_thread_spawn.clone();
186 let on_thread_destroy = builder.on_thread_destroy.clone();
187
188 thread_builder
189 .spawn(move || {
190 TaskPool::LOCAL_EXECUTOR.with(|local_executor| {
191 if let Some(on_thread_spawn) = on_thread_spawn {
192 on_thread_spawn();
193 drop(on_thread_spawn);
194 }
195 let _destructor = CallOnDrop(on_thread_destroy);
196 loop {
197 let res = std::panic::catch_unwind(|| {
198 let tick_forever = async move {
199 loop {
200 local_executor.tick().await;
201 }
202 };
203 block_on(ex.run(tick_forever.or(shutdown_rx.recv())))
204 });
205 if let Ok(value) = res {
206 // Use unwrap_err because we expect a Closed error
207 value.unwrap_err();
208 break;
209 }
210 }
211 });
212 })
213 .expect("Failed to spawn thread.")
214 })
215 .collect();
216
217 Self {
218 executor,
219 threads,
220 shutdown_tx,
221 }
222 }
223
224 /// Return the number of threads owned by the task pool
225 pub fn thread_num(&self) -> usize {
226 self.threads.len()
227 }
228
229 /// Allows spawning non-`'static` futures on the thread pool. The function takes a callback,
230 /// passing a scope object into it. The scope object provided to the callback can be used
231 /// to spawn tasks. This function will await the completion of all tasks before returning.
232 ///
233 /// This is similar to [`thread::scope`] and `rayon::scope`.
234 ///
235 /// # Example
236 ///
237 /// ```
238 /// use bevy_tasks::TaskPool;
239 ///
240 /// let pool = TaskPool::new();
241 /// let mut x = 0;
242 /// let results = pool.scope(|s| {
243 /// s.spawn(async {
244 /// // you can borrow the spawner inside a task and spawn tasks from within the task
245 /// s.spawn(async {
246 /// // borrow x and mutate it.
247 /// x = 2;
248 /// // return a value from the task
249 /// 1
250 /// });
251 /// // return some other value from the first task
252 /// 0
253 /// });
254 /// });
255 ///
256 /// // The ordering of results is non-deterministic if you spawn from within tasks as above.
257 /// // If you're doing this, you'll have to write your code to not depend on the ordering.
258 /// assert!(results.contains(&0));
259 /// assert!(results.contains(&1));
260 ///
261 /// // The ordering is deterministic if you only spawn directly from the closure function.
262 /// let results = pool.scope(|s| {
263 /// s.spawn(async { 0 });
264 /// s.spawn(async { 1 });
265 /// });
266 /// assert_eq!(&results[..], &[0, 1]);
267 ///
268 /// // You can access x after scope runs, since it was only temporarily borrowed in the scope.
269 /// assert_eq!(x, 2);
270 /// ```
271 ///
272 /// # Lifetimes
273 ///
274 /// The [`Scope`] object takes two lifetimes: `'scope` and `'env`.
275 ///
276 /// The `'scope` lifetime represents the lifetime of the scope. That is the time during
277 /// which the provided closure and tasks that are spawned into the scope are run.
278 ///
279 /// The `'env` lifetime represents the lifetime of whatever is borrowed by the scope.
280 /// Thus this lifetime must outlive `'scope`.
281 ///
282 /// ```compile_fail
283 /// use bevy_tasks::TaskPool;
284 /// fn scope_escapes_closure() {
285 /// let pool = TaskPool::new();
286 /// let foo = Box::new(42);
287 /// pool.scope(|scope| {
288 /// std::thread::spawn(move || {
289 /// // UB. This could spawn on the scope after `.scope` returns and the internal Scope is dropped.
290 /// scope.spawn(async move {
291 /// assert_eq!(*foo, 42);
292 /// });
293 /// });
294 /// });
295 /// }
296 /// ```
297 ///
298 /// ```compile_fail
299 /// use bevy_tasks::TaskPool;
300 /// fn cannot_borrow_from_closure() {
301 /// let pool = TaskPool::new();
302 /// pool.scope(|scope| {
303 /// let x = 1;
304 /// let y = &x;
305 /// scope.spawn(async move {
306 /// assert_eq!(*y, 1);
307 /// });
308 /// });
309 /// }
310 pub fn scope<'env, F, T>(&self, f: F) -> Vec<T>
311 where
312 F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, T>),
313 T: Send + 'static,
314 {
315 Self::THREAD_EXECUTOR.with(|scope_executor| {
316 self.scope_with_executor_inner(true, scope_executor, scope_executor, f)
317 })
318 }
319
320 /// This allows passing an external executor to spawn tasks on. When you pass an external executor
321 /// [`Scope::spawn_on_scope`] spawns is then run on the thread that [`ThreadExecutor`] is being ticked on.
322 /// If [`None`] is passed the scope will use a [`ThreadExecutor`] that is ticked on the current thread.
323 ///
324 /// When `tick_task_pool_executor` is set to `true`, the multithreaded task stealing executor is ticked on the scope
325 /// thread. Disabling this can be useful when finishing the scope is latency sensitive. Pulling tasks from
326 /// global executor can run tasks unrelated to the scope and delay when the scope returns.
327 ///
328 /// See [`Self::scope`] for more details in general about how scopes work.
329 pub fn scope_with_executor<'env, F, T>(
330 &self,
331 tick_task_pool_executor: bool,
332 external_executor: Option<&ThreadExecutor>,
333 f: F,
334 ) -> Vec<T>
335 where
336 F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, T>),
337 T: Send + 'static,
338 {
339 Self::THREAD_EXECUTOR.with(|scope_executor| {
340 // If an `external_executor` is passed, use that. Otherwise, get the executor stored
341 // in the `THREAD_EXECUTOR` thread local.
342 if let Some(external_executor) = external_executor {
343 self.scope_with_executor_inner(
344 tick_task_pool_executor,
345 external_executor,
346 scope_executor,
347 f,
348 )
349 } else {
350 self.scope_with_executor_inner(
351 tick_task_pool_executor,
352 scope_executor,
353 scope_executor,
354 f,
355 )
356 }
357 })
358 }
359
360 #[expect(unsafe_code, reason = "Required to transmute lifetimes.")]
361 fn scope_with_executor_inner<'env, F, T>(
362 &self,
363 tick_task_pool_executor: bool,
364 external_executor: &ThreadExecutor,
365 scope_executor: &ThreadExecutor,
366 f: F,
367 ) -> Vec<T>
368 where
369 F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, T>),
370 T: Send + 'static,
371 {
372 // SAFETY: This safety comment applies to all references transmuted to 'env.
373 // Any futures spawned with these references need to return before this function completes.
374 // This is guaranteed because we drive all the futures spawned onto the Scope
375 // to completion in this function. However, rust has no way of knowing this so we
376 // transmute the lifetimes to 'env here to appease the compiler as it is unable to validate safety.
377 // Any usages of the references passed into `Scope` must be accessed through
378 // the transmuted reference for the rest of this function.
379 let executor: &crate::executor::Executor = &self.executor;
380 // SAFETY: As above, all futures must complete in this function so we can change the lifetime
381 let executor: &'env crate::executor::Executor = unsafe { mem::transmute(executor) };
382 // SAFETY: As above, all futures must complete in this function so we can change the lifetime
383 let external_executor: &'env ThreadExecutor<'env> =
384 unsafe { mem::transmute(external_executor) };
385 // SAFETY: As above, all futures must complete in this function so we can change the lifetime
386 let scope_executor: &'env ThreadExecutor<'env> = unsafe { mem::transmute(scope_executor) };
387 let spawned: ConcurrentQueue<FallibleTask<Result<T, Box<(dyn core::any::Any + Send)>>>> =
388 ConcurrentQueue::unbounded();
389 // shadow the variable so that the owned value cannot be used for the rest of the function
390 // SAFETY: As above, all futures must complete in this function so we can change the lifetime
391 let spawned: &'env ConcurrentQueue<
392 FallibleTask<Result<T, Box<(dyn core::any::Any + Send)>>>,
393 > = unsafe { mem::transmute(&spawned) };
394
395 let scope = Scope {
396 executor,
397 external_executor,
398 scope_executor,
399 spawned,
400 scope: PhantomData,
401 env: PhantomData,
402 };
403
404 // shadow the variable so that the owned value cannot be used for the rest of the function
405 // SAFETY: As above, all futures must complete in this function so we can change the lifetime
406 let scope: &'env Scope<'_, 'env, T> = unsafe { mem::transmute(&scope) };
407
408 f(scope);
409
410 if spawned.is_empty() {
411 Vec::new()
412 } else {
413 block_on(async move {
414 let get_results = async {
415 let mut results = Vec::with_capacity(spawned.len());
416 while let Ok(task) = spawned.pop() {
417 if let Some(res) = task.await {
418 match res {
419 Ok(res) => results.push(res),
420 Err(payload) => std::panic::resume_unwind(payload),
421 }
422 } else {
423 panic!("Failed to catch panic!");
424 }
425 }
426 results
427 };
428
429 let tick_task_pool_executor = tick_task_pool_executor || self.threads.is_empty();
430
431 // we get this from a thread local so we should always be on the scope executors thread.
432 // note: it is possible `scope_executor` and `external_executor` is the same executor,
433 // in that case, we should only tick one of them, otherwise, it may cause deadlock.
434 let scope_ticker = scope_executor.ticker().unwrap();
435 let external_ticker = if !external_executor.is_same(scope_executor) {
436 external_executor.ticker()
437 } else {
438 None
439 };
440
441 match (external_ticker, tick_task_pool_executor) {
442 (Some(external_ticker), true) => {
443 Self::execute_global_external_scope(
444 executor,
445 external_ticker,
446 scope_ticker,
447 get_results,
448 )
449 .await
450 }
451 (Some(external_ticker), false) => {
452 Self::execute_external_scope(external_ticker, scope_ticker, get_results)
453 .await
454 }
455 // either external_executor is none or it is same as scope_executor
456 (None, true) => {
457 Self::execute_global_scope(executor, scope_ticker, get_results).await
458 }
459 (None, false) => Self::execute_scope(scope_ticker, get_results).await,
460 }
461 })
462 }
463 }
464
465 #[inline]
466 async fn execute_global_external_scope<'scope, 'ticker, T>(
467 executor: &'scope crate::executor::Executor<'scope>,
468 external_ticker: ThreadExecutorTicker<'scope, 'ticker>,
469 scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
470 get_results: impl Future<Output = Vec<T>>,
471 ) -> Vec<T> {
472 // we restart the executors if a task errors. if a scoped
473 // task errors it will panic the scope on the call to get_results
474 let execute_forever = async move {
475 loop {
476 let tick_forever = async {
477 loop {
478 external_ticker.tick().or(scope_ticker.tick()).await;
479 }
480 };
481 // we don't care if it errors. If a scoped task errors it will propagate
482 // to get_results
483 let _result = AssertUnwindSafe(executor.run(tick_forever))
484 .catch_unwind()
485 .await
486 .is_ok();
487 }
488 };
489 get_results.or(execute_forever).await
490 }
491
492 #[inline]
493 async fn execute_external_scope<'scope, 'ticker, T>(
494 external_ticker: ThreadExecutorTicker<'scope, 'ticker>,
495 scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
496 get_results: impl Future<Output = Vec<T>>,
497 ) -> Vec<T> {
498 let execute_forever = async {
499 loop {
500 let tick_forever = async {
501 loop {
502 external_ticker.tick().or(scope_ticker.tick()).await;
503 }
504 };
505 let _result = AssertUnwindSafe(tick_forever).catch_unwind().await.is_ok();
506 }
507 };
508 get_results.or(execute_forever).await
509 }
510
511 #[inline]
512 async fn execute_global_scope<'scope, 'ticker, T>(
513 executor: &'scope crate::executor::Executor<'scope>,
514 scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
515 get_results: impl Future<Output = Vec<T>>,
516 ) -> Vec<T> {
517 let execute_forever = async {
518 loop {
519 let tick_forever = async {
520 loop {
521 scope_ticker.tick().await;
522 }
523 };
524 let _result = AssertUnwindSafe(executor.run(tick_forever))
525 .catch_unwind()
526 .await
527 .is_ok();
528 }
529 };
530 get_results.or(execute_forever).await
531 }
532
533 #[inline]
534 async fn execute_scope<'scope, 'ticker, T>(
535 scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
536 get_results: impl Future<Output = Vec<T>>,
537 ) -> Vec<T> {
538 let execute_forever = async {
539 loop {
540 let tick_forever = async {
541 loop {
542 scope_ticker.tick().await;
543 }
544 };
545 let _result = AssertUnwindSafe(tick_forever).catch_unwind().await.is_ok();
546 }
547 };
548 get_results.or(execute_forever).await
549 }
550
551 /// Spawns a static future onto the thread pool. The returned [`Task`] is a
552 /// future that can be polled for the result. It can also be canceled and
553 /// "detached", allowing the task to continue running even if dropped. In
554 /// any case, the pool will execute the task even without polling by the
555 /// end-user.
556 ///
557 /// If the provided future is non-`Send`, [`TaskPool::spawn_local`] should
558 /// be used instead.
559 pub fn spawn<T>(&self, future: impl Future<Output = T> + Send + 'static) -> Task<T>
560 where
561 T: Send + 'static,
562 {
563 Task::new(self.executor.spawn(future))
564 }
565
566 /// Spawns a static future on the thread-local async executor for the
567 /// current thread. The task will run entirely on the thread the task was
568 /// spawned on.
569 ///
570 /// The returned [`Task`] is a future that can be polled for the
571 /// result. It can also be canceled and "detached", allowing the task to
572 /// continue running even if dropped. In any case, the pool will execute the
573 /// task even without polling by the end-user.
574 ///
575 /// Users should generally prefer to use [`TaskPool::spawn`] instead,
576 /// unless the provided future is not `Send`.
577 pub fn spawn_local<T>(&self, future: impl Future<Output = T> + 'static) -> Task<T>
578 where
579 T: 'static,
580 {
581 Task::new(TaskPool::LOCAL_EXECUTOR.with(|executor| executor.spawn(future)))
582 }
583
584 /// Runs a function with the local executor. Typically used to tick
585 /// the local executor on the main thread as it needs to share time with
586 /// other things.
587 ///
588 /// ```
589 /// use bevy_tasks::TaskPool;
590 ///
591 /// TaskPool::new().with_local_executor(|local_executor| {
592 /// local_executor.try_tick();
593 /// });
594 /// ```
595 pub fn with_local_executor<F, R>(&self, f: F) -> R
596 where
597 F: FnOnce(&crate::executor::LocalExecutor) -> R,
598 {
599 Self::LOCAL_EXECUTOR.with(f)
600 }
601}
602
603impl Default for TaskPool {
604 fn default() -> Self {
605 Self::new()
606 }
607}
608
609impl Drop for TaskPool {
610 fn drop(&mut self) {
611 self.shutdown_tx.close();
612
613 let panicking = thread::panicking();
614 for join_handle in self.threads.drain(..) {
615 let res = join_handle.join();
616 if !panicking {
617 res.expect("Task thread panicked while executing.");
618 }
619 }
620 }
621}
622
623/// A [`TaskPool`] scope for running one or more non-`'static` futures.
624///
625/// For more information, see [`TaskPool::scope`].
626#[derive(Debug)]
627pub struct Scope<'scope, 'env: 'scope, T> {
628 executor: &'scope crate::executor::Executor<'scope>,
629 external_executor: &'scope ThreadExecutor<'scope>,
630 scope_executor: &'scope ThreadExecutor<'scope>,
631 spawned: &'scope ConcurrentQueue<FallibleTask<Result<T, Box<(dyn core::any::Any + Send)>>>>,
632 // make `Scope` invariant over 'scope and 'env
633 scope: PhantomData<&'scope mut &'scope ()>,
634 env: PhantomData<&'env mut &'env ()>,
635}
636
637impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> {
638 /// Spawns a scoped future onto the thread pool. The scope *must* outlive
639 /// the provided future. The results of the future will be returned as a part of
640 /// [`TaskPool::scope`]'s return value.
641 ///
642 /// For futures that should run on the thread `scope` is called on [`Scope::spawn_on_scope`] should be used
643 /// instead.
644 ///
645 /// For more information, see [`TaskPool::scope`].
646 pub fn spawn<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
647 let task = self
648 .executor
649 .spawn(AssertUnwindSafe(f).catch_unwind())
650 .fallible();
651 // ConcurrentQueue only errors when closed or full, but we never
652 // close and use an unbounded queue, so it is safe to unwrap
653 self.spawned.push(task).unwrap();
654 }
655
656 /// Spawns a scoped future onto the thread the scope is run on. The scope *must* outlive
657 /// the provided future. The results of the future will be returned as a part of
658 /// [`TaskPool::scope`]'s return value. Users should generally prefer to use
659 /// [`Scope::spawn`] instead, unless the provided future needs to run on the scope's thread.
660 ///
661 /// For more information, see [`TaskPool::scope`].
662 pub fn spawn_on_scope<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
663 let task = self
664 .scope_executor
665 .spawn(AssertUnwindSafe(f).catch_unwind())
666 .fallible();
667 // ConcurrentQueue only errors when closed or full, but we never
668 // close and use an unbounded queue, so it is safe to unwrap
669 self.spawned.push(task).unwrap();
670 }
671
672 /// Spawns a scoped future onto the thread of the external thread executor.
673 /// This is typically the main thread. The scope *must* outlive
674 /// the provided future. The results of the future will be returned as a part of
675 /// [`TaskPool::scope`]'s return value. Users should generally prefer to use
676 /// [`Scope::spawn`] instead, unless the provided future needs to run on the external thread.
677 ///
678 /// For more information, see [`TaskPool::scope`].
679 pub fn spawn_on_external<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
680 let task = self
681 .external_executor
682 .spawn(AssertUnwindSafe(f).catch_unwind())
683 .fallible();
684 // ConcurrentQueue only errors when closed or full, but we never
685 // close and use an unbounded queue, so it is safe to unwrap
686 self.spawned.push(task).unwrap();
687 }
688}
689
690impl<'scope, 'env, T> Drop for Scope<'scope, 'env, T>
691where
692 T: 'scope,
693{
694 fn drop(&mut self) {
695 block_on(async {
696 while let Ok(task) = self.spawned.pop() {
697 task.cancel().await;
698 }
699 });
700 }
701}
702
#[cfg(test)]
mod tests {
    use super::*;
    use core::sync::atomic::{AtomicBool, AtomicI32, Ordering};
    use std::sync::Barrier;

    /// Every task spawned with `Scope::spawn` runs and contributes its output.
    #[test]
    fn test_spawn() {
        let pool = TaskPool::new();

        let foo = Box::new(42);
        let foo = &*foo;

        let count = Arc::new(AtomicI32::new(0));

        let outputs = pool.scope(|scope| {
            for _ in 0..100 {
                let count_clone = count.clone();
                scope.spawn(async move {
                    if *foo != 42 {
                        panic!("not 42!?!?")
                    } else {
                        count_clone.fetch_add(1, Ordering::Relaxed);
                        *foo
                    }
                });
            }
        });

        for output in &outputs {
            assert_eq!(*output, 42);
        }

        assert_eq!(outputs.len(), 100);
        assert_eq!(count.load(Ordering::Relaxed), 100);
    }

    /// `on_thread_spawn` / `on_thread_destroy` fire once per worker thread.
    #[test]
    fn test_thread_callbacks() {
        let counter = Arc::new(AtomicI32::new(0));
        let start_counter = counter.clone();
        {
            let barrier = Arc::new(Barrier::new(11));
            let last_barrier = barrier.clone();
            // Build and immediately drop to terminate
            let _pool = TaskPoolBuilder::new()
                .num_threads(10)
                .on_thread_spawn(move || {
                    start_counter.fetch_add(1, Ordering::Relaxed);
                    barrier.wait();
                })
                .build();
            last_barrier.wait();
            assert_eq!(10, counter.load(Ordering::Relaxed));
        }
        assert_eq!(10, counter.load(Ordering::Relaxed));
        let end_counter = counter.clone();
        {
            let _pool = TaskPoolBuilder::new()
                .num_threads(20)
                .on_thread_destroy(move || {
                    end_counter.fetch_sub(1, Ordering::Relaxed);
                })
                .build();
            assert_eq!(10, counter.load(Ordering::Relaxed));
        }
        assert_eq!(-10, counter.load(Ordering::Relaxed));
        let start_counter = counter.clone();
        let end_counter = counter.clone();
        {
            let barrier = Arc::new(Barrier::new(6));
            let last_barrier = barrier.clone();
            let _pool = TaskPoolBuilder::new()
                .num_threads(5)
                .on_thread_spawn(move || {
                    start_counter.fetch_add(1, Ordering::Relaxed);
                    barrier.wait();
                })
                .on_thread_destroy(move || {
                    end_counter.fetch_sub(1, Ordering::Relaxed);
                })
                .build();
            last_barrier.wait();
            assert_eq!(-5, counter.load(Ordering::Relaxed));
        }
        assert_eq!(-10, counter.load(Ordering::Relaxed));
    }

    /// Mixing `spawn` and `spawn_on_scope` collects results from both.
    #[test]
    fn test_mixed_spawn_on_scope_and_spawn() {
        let pool = TaskPool::new();

        let foo = Box::new(42);
        let foo = &*foo;

        let local_count = Arc::new(AtomicI32::new(0));
        let non_local_count = Arc::new(AtomicI32::new(0));

        let outputs = pool.scope(|scope| {
            for i in 0..100 {
                if i % 2 == 0 {
                    let count_clone = non_local_count.clone();
                    scope.spawn(async move {
                        if *foo != 42 {
                            panic!("not 42!?!?")
                        } else {
                            count_clone.fetch_add(1, Ordering::Relaxed);
                            *foo
                        }
                    });
                } else {
                    let count_clone = local_count.clone();
                    scope.spawn_on_scope(async move {
                        if *foo != 42 {
                            panic!("not 42!?!?")
                        } else {
                            count_clone.fetch_add(1, Ordering::Relaxed);
                            *foo
                        }
                    });
                }
            }
        });

        for output in &outputs {
            assert_eq!(*output, 42);
        }

        assert_eq!(outputs.len(), 100);
        assert_eq!(local_count.load(Ordering::Relaxed), 50);
        assert_eq!(non_local_count.load(Ordering::Relaxed), 50);
    }

    /// Tasks spawned with `spawn_on_scope` run on the thread that opened the scope.
    #[test]
    fn test_thread_locality() {
        let pool = Arc::new(TaskPool::new());
        let count = Arc::new(AtomicI32::new(0));
        let barrier = Arc::new(Barrier::new(101));
        let thread_check_failed = Arc::new(AtomicBool::new(false));

        for _ in 0..100 {
            let inner_barrier = barrier.clone();
            let count_clone = count.clone();
            let inner_pool = pool.clone();
            let inner_thread_check_failed = thread_check_failed.clone();
            thread::spawn(move || {
                inner_pool.scope(|scope| {
                    let inner_count_clone = count_clone.clone();
                    scope.spawn(async move {
                        inner_count_clone.fetch_add(1, Ordering::Release);
                    });
                    let spawner = thread::current().id();
                    let inner_count_clone = count_clone.clone();
                    scope.spawn_on_scope(async move {
                        inner_count_clone.fetch_add(1, Ordering::Release);
                        if thread::current().id() != spawner {
                            // NOTE: This check is using an atomic rather than simply panicking the
                            // thread to avoid deadlocking the barrier on failure
                            inner_thread_check_failed.store(true, Ordering::Release);
                        }
                    });
                });
                inner_barrier.wait();
            });
        }
        barrier.wait();
        assert!(!thread_check_failed.load(Ordering::Acquire));
        assert_eq!(count.load(Ordering::Acquire), 200);
    }

    /// Tasks may spawn further tasks into the same scope.
    #[test]
    fn test_nested_spawn() {
        let pool = TaskPool::new();

        let foo = Box::new(42);
        let foo = &*foo;

        let count = Arc::new(AtomicI32::new(0));

        let outputs: Vec<i32> = pool.scope(|scope| {
            for _ in 0..10 {
                let count_clone = count.clone();
                scope.spawn(async move {
                    for _ in 0..10 {
                        let count_clone_clone = count_clone.clone();
                        scope.spawn(async move {
                            if *foo != 42 {
                                panic!("not 42!?!?")
                            } else {
                                count_clone_clone.fetch_add(1, Ordering::Relaxed);
                                *foo
                            }
                        });
                    }
                    *foo
                });
            }
        });

        for output in &outputs {
            assert_eq!(*output, 42);
        }

        // the inner loop runs 100 times and the outer one runs 10. 100 + 10
        assert_eq!(outputs.len(), 110);
        assert_eq!(count.load(Ordering::Relaxed), 100);
    }

    /// `spawn_on_scope` called from a pool thread still runs on the scope's thread.
    #[test]
    fn test_nested_locality() {
        let pool = Arc::new(TaskPool::new());
        let count = Arc::new(AtomicI32::new(0));
        let barrier = Arc::new(Barrier::new(101));
        let thread_check_failed = Arc::new(AtomicBool::new(false));

        for _ in 0..100 {
            let inner_barrier = barrier.clone();
            let count_clone = count.clone();
            let inner_pool = pool.clone();
            let inner_thread_check_failed = thread_check_failed.clone();
            thread::spawn(move || {
                inner_pool.scope(|scope| {
                    let spawner = thread::current().id();
                    let inner_count_clone = count_clone.clone();
                    scope.spawn(async move {
                        inner_count_clone.fetch_add(1, Ordering::Release);

                        // spawning on the scope from another thread runs the futures on the scope's thread
                        scope.spawn_on_scope(async move {
                            inner_count_clone.fetch_add(1, Ordering::Release);
                            if thread::current().id() != spawner {
                                // NOTE: This check is using an atomic rather than simply panicking the
                                // thread to avoid deadlocking the barrier on failure
                                inner_thread_check_failed.store(true, Ordering::Release);
                            }
                        });
                    });
                });
                inner_barrier.wait();
            });
        }
        barrier.wait();
        assert!(!thread_check_failed.load(Ordering::Acquire));
        assert_eq!(count.load(Ordering::Acquire), 200);
    }

    // This test will often freeze on other executors.
    #[test]
    fn test_nested_scopes() {
        let pool = TaskPool::new();
        let count = Arc::new(AtomicI32::new(0));

        pool.scope(|scope| {
            scope.spawn(async {
                pool.scope(|scope| {
                    scope.spawn(async {
                        count.fetch_add(1, Ordering::Relaxed);
                    });
                });
            });
        });

        assert_eq!(count.load(Ordering::Acquire), 1);
    }
}