bevy_tasks/
task_pool.rs

1use alloc::{boxed::Box, format, string::String, vec::Vec};
2use core::{future::Future, marker::PhantomData, mem, panic::AssertUnwindSafe};
3use std::{
4    thread::{self, JoinHandle},
5    thread_local,
6};
7
8use crate::executor::FallibleTask;
9use bevy_platform::sync::Arc;
10use concurrent_queue::ConcurrentQueue;
11use futures_lite::FutureExt;
12
13use crate::{
14    block_on,
15    thread_executor::{ThreadExecutor, ThreadExecutorTicker},
16    Task,
17};
18
/// Guard that runs an optional callback when it is dropped.
///
/// Task pool worker threads hold one of these for their `on_thread_destroy` callback, so the
/// callback fires once the worker's run loop has finished.
struct CallOnDrop(Option<Arc<dyn Fn() + Send + Sync + 'static>>);

impl Drop for CallOnDrop {
    fn drop(&mut self) {
        // Nothing to do when no callback was registered.
        let Some(callback) = &self.0 else {
            return;
        };
        callback();
    }
}
28
/// Type-erased, shareable callback that runs on a task pool's worker threads.
type ThreadCallback = Arc<dyn Fn() + Send + Sync + 'static>;

/// Used to create a [`TaskPool`].
///
/// Every option is unset by default; unset options fall back to the pool's defaults.
#[derive(Default)]
#[must_use]
pub struct TaskPoolBuilder {
    /// If set, the pool spawns exactly this many worker threads. Otherwise the logical core
    /// count of the system is used.
    num_threads: Option<usize>,
    /// If set, worker threads use this stack size instead of the system default.
    stack_size: Option<usize>,
    /// Customizes worker thread names, which helps when debugging. If set, threads are named
    /// `<thread_name> (<thread_index>)`, e.g. `"MyThreadPool (2)"`.
    thread_name: Option<String>,

    /// Invoked on each worker thread as it starts.
    on_thread_spawn: Option<ThreadCallback>,
    /// Invoked on each worker thread as it terminates.
    on_thread_destroy: Option<ThreadCallback>,
}
45
46impl TaskPoolBuilder {
47    /// Creates a new [`TaskPoolBuilder`] instance
48    pub fn new() -> Self {
49        Self::default()
50    }
51
52    /// Override the number of threads created for the pool. If unset, we default to the number
53    /// of logical cores of the system
54    pub fn num_threads(mut self, num_threads: usize) -> Self {
55        self.num_threads = Some(num_threads);
56        self
57    }
58
59    /// Override the stack size of the threads created for the pool
60    pub fn stack_size(mut self, stack_size: usize) -> Self {
61        self.stack_size = Some(stack_size);
62        self
63    }
64
65    /// Override the name of the threads created for the pool. If set, threads will
66    /// be named `<thread_name> (<thread_index>)`, i.e. `MyThreadPool (2)`
67    pub fn thread_name(mut self, thread_name: String) -> Self {
68        self.thread_name = Some(thread_name);
69        self
70    }
71
72    /// Sets a callback that is invoked once for every created thread as it starts.
73    ///
74    /// This is called on the thread itself and has access to all thread-local storage.
75    /// This will block running async tasks on the thread until the callback completes.
76    pub fn on_thread_spawn(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
77        let arc = Arc::new(f);
78
79        #[cfg(not(target_has_atomic = "ptr"))]
80        #[expect(
81            unsafe_code,
82            reason = "unsized coercion is an unstable feature for non-std types"
83        )]
84        // SAFETY:
85        // - Coercion from `impl Fn` to `dyn Fn` is valid
86        // - `Arc::from_raw` receives a valid pointer from a previous call to `Arc::into_raw`
87        let arc = unsafe {
88            Arc::from_raw(Arc::into_raw(arc) as *const (dyn Fn() + Send + Sync + 'static))
89        };
90
91        self.on_thread_spawn = Some(arc);
92        self
93    }
94
95    /// Sets a callback that is invoked once for every created thread as it terminates.
96    ///
97    /// This is called on the thread itself and has access to all thread-local storage.
98    /// This will block thread termination until the callback completes.
99    pub fn on_thread_destroy(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
100        let arc = Arc::new(f);
101
102        #[cfg(not(target_has_atomic = "ptr"))]
103        #[expect(
104            unsafe_code,
105            reason = "unsized coercion is an unstable feature for non-std types"
106        )]
107        // SAFETY:
108        // - Coercion from `impl Fn` to `dyn Fn` is valid
109        // - `Arc::from_raw` receives a valid pointer from a previous call to `Arc::into_raw`
110        let arc = unsafe {
111            Arc::from_raw(Arc::into_raw(arc) as *const (dyn Fn() + Send + Sync + 'static))
112        };
113
114        self.on_thread_destroy = Some(arc);
115        self
116    }
117
118    /// Creates a new [`TaskPool`] based on the current options.
119    pub fn build(self) -> TaskPool {
120        TaskPool::new_internal(self)
121    }
122}
123
124/// A thread pool for executing tasks.
125///
126/// While futures usually need to be polled to be executed, Bevy tasks are being
127/// automatically driven by the pool on threads owned by the pool. The [`Task`]
128/// future only needs to be polled in order to receive the result. (For that
129/// purpose, it is often stored in a component or resource, see the
130/// `async_compute` example.)
131///
132/// If the result is not required, one may also use [`Task::detach`] and the pool
133/// will still execute a task, even if it is dropped.
134#[derive(Debug)]
135pub struct TaskPool {
136    /// The executor for the pool.
137    executor: Arc<crate::executor::Executor<'static>>,
138
139    // The inner state of the pool.
140    threads: Vec<JoinHandle<()>>,
141    shutdown_tx: async_channel::Sender<()>,
142}
143
144impl TaskPool {
145    thread_local! {
146        static LOCAL_EXECUTOR: crate::executor::LocalExecutor<'static> = const { crate::executor::LocalExecutor::new() };
147        static THREAD_EXECUTOR: Arc<ThreadExecutor<'static>> = Arc::new(ThreadExecutor::new());
148    }
149
150    /// Each thread should only create one `ThreadExecutor`, otherwise, there are good chances they will deadlock
151    pub fn get_thread_executor() -> Arc<ThreadExecutor<'static>> {
152        Self::THREAD_EXECUTOR.with(Clone::clone)
153    }
154
155    /// Create a `TaskPool` with the default configuration.
156    pub fn new() -> Self {
157        TaskPoolBuilder::new().build()
158    }
159
160    fn new_internal(builder: TaskPoolBuilder) -> Self {
161        let (shutdown_tx, shutdown_rx) = async_channel::unbounded::<()>();
162
163        let executor = Arc::new(crate::executor::Executor::new());
164
165        let num_threads = builder
166            .num_threads
167            .unwrap_or_else(crate::available_parallelism);
168
169        let threads = (0..num_threads)
170            .map(|i| {
171                let ex = Arc::clone(&executor);
172                let shutdown_rx = shutdown_rx.clone();
173
174                let thread_name = if let Some(thread_name) = builder.thread_name.as_deref() {
175                    format!("{thread_name} ({i})")
176                } else {
177                    format!("TaskPool ({i})")
178                };
179                let mut thread_builder = thread::Builder::new().name(thread_name);
180
181                if let Some(stack_size) = builder.stack_size {
182                    thread_builder = thread_builder.stack_size(stack_size);
183                }
184
185                let on_thread_spawn = builder.on_thread_spawn.clone();
186                let on_thread_destroy = builder.on_thread_destroy.clone();
187
188                thread_builder
189                    .spawn(move || {
190                        TaskPool::LOCAL_EXECUTOR.with(|local_executor| {
191                            if let Some(on_thread_spawn) = on_thread_spawn {
192                                on_thread_spawn();
193                                drop(on_thread_spawn);
194                            }
195                            let _destructor = CallOnDrop(on_thread_destroy);
196                            loop {
197                                let res = std::panic::catch_unwind(|| {
198                                    let tick_forever = async move {
199                                        loop {
200                                            local_executor.tick().await;
201                                        }
202                                    };
203                                    block_on(ex.run(tick_forever.or(shutdown_rx.recv())))
204                                });
205                                if let Ok(value) = res {
206                                    // Use unwrap_err because we expect a Closed error
207                                    value.unwrap_err();
208                                    break;
209                                }
210                            }
211                        });
212                    })
213                    .expect("Failed to spawn thread.")
214            })
215            .collect();
216
217        Self {
218            executor,
219            threads,
220            shutdown_tx,
221        }
222    }
223
224    /// Return the number of threads owned by the task pool
225    pub fn thread_num(&self) -> usize {
226        self.threads.len()
227    }
228
229    /// Allows spawning non-`'static` futures on the thread pool. The function takes a callback,
230    /// passing a scope object into it. The scope object provided to the callback can be used
231    /// to spawn tasks. This function will await the completion of all tasks before returning.
232    ///
233    /// This is similar to [`thread::scope`] and `rayon::scope`.
234    ///
235    /// # Example
236    ///
237    /// ```
238    /// use bevy_tasks::TaskPool;
239    ///
240    /// let pool = TaskPool::new();
241    /// let mut x = 0;
242    /// let results = pool.scope(|s| {
243    ///     s.spawn(async {
244    ///         // you can borrow the spawner inside a task and spawn tasks from within the task
245    ///         s.spawn(async {
246    ///             // borrow x and mutate it.
247    ///             x = 2;
248    ///             // return a value from the task
249    ///             1
250    ///         });
251    ///         // return some other value from the first task
252    ///         0
253    ///     });
254    /// });
255    ///
256    /// // The ordering of results is non-deterministic if you spawn from within tasks as above.
257    /// // If you're doing this, you'll have to write your code to not depend on the ordering.
258    /// assert!(results.contains(&0));
259    /// assert!(results.contains(&1));
260    ///
261    /// // The ordering is deterministic if you only spawn directly from the closure function.
262    /// let results = pool.scope(|s| {
263    ///     s.spawn(async { 0 });
264    ///     s.spawn(async { 1 });
265    /// });
266    /// assert_eq!(&results[..], &[0, 1]);
267    ///
268    /// // You can access x after scope runs, since it was only temporarily borrowed in the scope.
269    /// assert_eq!(x, 2);
270    /// ```
271    ///
272    /// # Lifetimes
273    ///
274    /// The [`Scope`] object takes two lifetimes: `'scope` and `'env`.
275    ///
276    /// The `'scope` lifetime represents the lifetime of the scope. That is the time during
277    /// which the provided closure and tasks that are spawned into the scope are run.
278    ///
279    /// The `'env` lifetime represents the lifetime of whatever is borrowed by the scope.
280    /// Thus this lifetime must outlive `'scope`.
281    ///
282    /// ```compile_fail
283    /// use bevy_tasks::TaskPool;
284    /// fn scope_escapes_closure() {
285    ///     let pool = TaskPool::new();
286    ///     let foo = Box::new(42);
287    ///     pool.scope(|scope| {
288    ///         std::thread::spawn(move || {
289    ///             // UB. This could spawn on the scope after `.scope` returns and the internal Scope is dropped.
290    ///             scope.spawn(async move {
291    ///                 assert_eq!(*foo, 42);
292    ///             });
293    ///         });
294    ///     });
295    /// }
296    /// ```
297    ///
298    /// ```compile_fail
299    /// use bevy_tasks::TaskPool;
300    /// fn cannot_borrow_from_closure() {
301    ///     let pool = TaskPool::new();
302    ///     pool.scope(|scope| {
303    ///         let x = 1;
304    ///         let y = &x;
305    ///         scope.spawn(async move {
306    ///             assert_eq!(*y, 1);
307    ///         });
308    ///     });
309    /// }
310    pub fn scope<'env, F, T>(&self, f: F) -> Vec<T>
311    where
312        F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, T>),
313        T: Send + 'static,
314    {
315        Self::THREAD_EXECUTOR.with(|scope_executor| {
316            self.scope_with_executor_inner(true, scope_executor, scope_executor, f)
317        })
318    }
319
320    /// This allows passing an external executor to spawn tasks on. When you pass an external executor
321    /// [`Scope::spawn_on_scope`] spawns is then run on the thread that [`ThreadExecutor`] is being ticked on.
322    /// If [`None`] is passed the scope will use a [`ThreadExecutor`] that is ticked on the current thread.
323    ///
324    /// When `tick_task_pool_executor` is set to `true`, the multithreaded task stealing executor is ticked on the scope
325    /// thread. Disabling this can be useful when finishing the scope is latency sensitive. Pulling tasks from
326    /// global executor can run tasks unrelated to the scope and delay when the scope returns.
327    ///
328    /// See [`Self::scope`] for more details in general about how scopes work.
329    pub fn scope_with_executor<'env, F, T>(
330        &self,
331        tick_task_pool_executor: bool,
332        external_executor: Option<&ThreadExecutor>,
333        f: F,
334    ) -> Vec<T>
335    where
336        F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, T>),
337        T: Send + 'static,
338    {
339        Self::THREAD_EXECUTOR.with(|scope_executor| {
340            // If an `external_executor` is passed, use that. Otherwise, get the executor stored
341            // in the `THREAD_EXECUTOR` thread local.
342            if let Some(external_executor) = external_executor {
343                self.scope_with_executor_inner(
344                    tick_task_pool_executor,
345                    external_executor,
346                    scope_executor,
347                    f,
348                )
349            } else {
350                self.scope_with_executor_inner(
351                    tick_task_pool_executor,
352                    scope_executor,
353                    scope_executor,
354                    f,
355                )
356            }
357        })
358    }
359
360    #[expect(unsafe_code, reason = "Required to transmute lifetimes.")]
361    fn scope_with_executor_inner<'env, F, T>(
362        &self,
363        tick_task_pool_executor: bool,
364        external_executor: &ThreadExecutor,
365        scope_executor: &ThreadExecutor,
366        f: F,
367    ) -> Vec<T>
368    where
369        F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, T>),
370        T: Send + 'static,
371    {
372        // SAFETY: This safety comment applies to all references transmuted to 'env.
373        // Any futures spawned with these references need to return before this function completes.
374        // This is guaranteed because we drive all the futures spawned onto the Scope
375        // to completion in this function. However, rust has no way of knowing this so we
376        // transmute the lifetimes to 'env here to appease the compiler as it is unable to validate safety.
377        // Any usages of the references passed into `Scope` must be accessed through
378        // the transmuted reference for the rest of this function.
379        let executor: &crate::executor::Executor = &self.executor;
380        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
381        let executor: &'env crate::executor::Executor = unsafe { mem::transmute(executor) };
382        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
383        let external_executor: &'env ThreadExecutor<'env> =
384            unsafe { mem::transmute(external_executor) };
385        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
386        let scope_executor: &'env ThreadExecutor<'env> = unsafe { mem::transmute(scope_executor) };
387        let spawned: ConcurrentQueue<FallibleTask<Result<T, Box<(dyn core::any::Any + Send)>>>> =
388            ConcurrentQueue::unbounded();
389        // shadow the variable so that the owned value cannot be used for the rest of the function
390        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
391        let spawned: &'env ConcurrentQueue<
392            FallibleTask<Result<T, Box<(dyn core::any::Any + Send)>>>,
393        > = unsafe { mem::transmute(&spawned) };
394
395        let scope = Scope {
396            executor,
397            external_executor,
398            scope_executor,
399            spawned,
400            scope: PhantomData,
401            env: PhantomData,
402        };
403
404        // shadow the variable so that the owned value cannot be used for the rest of the function
405        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
406        let scope: &'env Scope<'_, 'env, T> = unsafe { mem::transmute(&scope) };
407
408        f(scope);
409
410        if spawned.is_empty() {
411            Vec::new()
412        } else {
413            block_on(async move {
414                let get_results = async {
415                    let mut results = Vec::with_capacity(spawned.len());
416                    while let Ok(task) = spawned.pop() {
417                        if let Some(res) = task.await {
418                            match res {
419                                Ok(res) => results.push(res),
420                                Err(payload) => std::panic::resume_unwind(payload),
421                            }
422                        } else {
423                            panic!("Failed to catch panic!");
424                        }
425                    }
426                    results
427                };
428
429                let tick_task_pool_executor = tick_task_pool_executor || self.threads.is_empty();
430
431                // we get this from a thread local so we should always be on the scope executors thread.
432                // note: it is possible `scope_executor` and `external_executor` is the same executor,
433                // in that case, we should only tick one of them, otherwise, it may cause deadlock.
434                let scope_ticker = scope_executor.ticker().unwrap();
435                let external_ticker = if !external_executor.is_same(scope_executor) {
436                    external_executor.ticker()
437                } else {
438                    None
439                };
440
441                match (external_ticker, tick_task_pool_executor) {
442                    (Some(external_ticker), true) => {
443                        Self::execute_global_external_scope(
444                            executor,
445                            external_ticker,
446                            scope_ticker,
447                            get_results,
448                        )
449                        .await
450                    }
451                    (Some(external_ticker), false) => {
452                        Self::execute_external_scope(external_ticker, scope_ticker, get_results)
453                            .await
454                    }
455                    // either external_executor is none or it is same as scope_executor
456                    (None, true) => {
457                        Self::execute_global_scope(executor, scope_ticker, get_results).await
458                    }
459                    (None, false) => Self::execute_scope(scope_ticker, get_results).await,
460                }
461            })
462        }
463    }
464
465    #[inline]
466    async fn execute_global_external_scope<'scope, 'ticker, T>(
467        executor: &'scope crate::executor::Executor<'scope>,
468        external_ticker: ThreadExecutorTicker<'scope, 'ticker>,
469        scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
470        get_results: impl Future<Output = Vec<T>>,
471    ) -> Vec<T> {
472        // we restart the executors if a task errors. if a scoped
473        // task errors it will panic the scope on the call to get_results
474        let execute_forever = async move {
475            loop {
476                let tick_forever = async {
477                    loop {
478                        external_ticker.tick().or(scope_ticker.tick()).await;
479                    }
480                };
481                // we don't care if it errors. If a scoped task errors it will propagate
482                // to get_results
483                let _result = AssertUnwindSafe(executor.run(tick_forever))
484                    .catch_unwind()
485                    .await
486                    .is_ok();
487            }
488        };
489        get_results.or(execute_forever).await
490    }
491
492    #[inline]
493    async fn execute_external_scope<'scope, 'ticker, T>(
494        external_ticker: ThreadExecutorTicker<'scope, 'ticker>,
495        scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
496        get_results: impl Future<Output = Vec<T>>,
497    ) -> Vec<T> {
498        let execute_forever = async {
499            loop {
500                let tick_forever = async {
501                    loop {
502                        external_ticker.tick().or(scope_ticker.tick()).await;
503                    }
504                };
505                let _result = AssertUnwindSafe(tick_forever).catch_unwind().await.is_ok();
506            }
507        };
508        get_results.or(execute_forever).await
509    }
510
511    #[inline]
512    async fn execute_global_scope<'scope, 'ticker, T>(
513        executor: &'scope crate::executor::Executor<'scope>,
514        scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
515        get_results: impl Future<Output = Vec<T>>,
516    ) -> Vec<T> {
517        let execute_forever = async {
518            loop {
519                let tick_forever = async {
520                    loop {
521                        scope_ticker.tick().await;
522                    }
523                };
524                let _result = AssertUnwindSafe(executor.run(tick_forever))
525                    .catch_unwind()
526                    .await
527                    .is_ok();
528            }
529        };
530        get_results.or(execute_forever).await
531    }
532
533    #[inline]
534    async fn execute_scope<'scope, 'ticker, T>(
535        scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
536        get_results: impl Future<Output = Vec<T>>,
537    ) -> Vec<T> {
538        let execute_forever = async {
539            loop {
540                let tick_forever = async {
541                    loop {
542                        scope_ticker.tick().await;
543                    }
544                };
545                let _result = AssertUnwindSafe(tick_forever).catch_unwind().await.is_ok();
546            }
547        };
548        get_results.or(execute_forever).await
549    }
550
551    /// Spawns a static future onto the thread pool. The returned [`Task`] is a
552    /// future that can be polled for the result. It can also be canceled and
553    /// "detached", allowing the task to continue running even if dropped. In
554    /// any case, the pool will execute the task even without polling by the
555    /// end-user.
556    ///
557    /// If the provided future is non-`Send`, [`TaskPool::spawn_local`] should
558    /// be used instead.
559    pub fn spawn<T>(&self, future: impl Future<Output = T> + Send + 'static) -> Task<T>
560    where
561        T: Send + 'static,
562    {
563        Task::new(self.executor.spawn(future))
564    }
565
566    /// Spawns a static future on the thread-local async executor for the
567    /// current thread. The task will run entirely on the thread the task was
568    /// spawned on.
569    ///
570    /// The returned [`Task`] is a future that can be polled for the
571    /// result. It can also be canceled and "detached", allowing the task to
572    /// continue running even if dropped. In any case, the pool will execute the
573    /// task even without polling by the end-user.
574    ///
575    /// Users should generally prefer to use [`TaskPool::spawn`] instead,
576    /// unless the provided future is not `Send`.
577    pub fn spawn_local<T>(&self, future: impl Future<Output = T> + 'static) -> Task<T>
578    where
579        T: 'static,
580    {
581        Task::new(TaskPool::LOCAL_EXECUTOR.with(|executor| executor.spawn(future)))
582    }
583
584    /// Runs a function with the local executor. Typically used to tick
585    /// the local executor on the main thread as it needs to share time with
586    /// other things.
587    ///
588    /// ```
589    /// use bevy_tasks::TaskPool;
590    ///
591    /// TaskPool::new().with_local_executor(|local_executor| {
592    ///     local_executor.try_tick();
593    /// });
594    /// ```
595    pub fn with_local_executor<F, R>(&self, f: F) -> R
596    where
597        F: FnOnce(&crate::executor::LocalExecutor) -> R,
598    {
599        Self::LOCAL_EXECUTOR.with(f)
600    }
601}
602
603impl Default for TaskPool {
604    fn default() -> Self {
605        Self::new()
606    }
607}
608
609impl Drop for TaskPool {
610    fn drop(&mut self) {
611        self.shutdown_tx.close();
612
613        let panicking = thread::panicking();
614        for join_handle in self.threads.drain(..) {
615            let res = join_handle.join();
616            if !panicking {
617                res.expect("Task thread panicked while executing.");
618            }
619        }
620    }
621}
622
623/// A [`TaskPool`] scope for running one or more non-`'static` futures.
624///
625/// For more information, see [`TaskPool::scope`].
626#[derive(Debug)]
627pub struct Scope<'scope, 'env: 'scope, T> {
628    executor: &'scope crate::executor::Executor<'scope>,
629    external_executor: &'scope ThreadExecutor<'scope>,
630    scope_executor: &'scope ThreadExecutor<'scope>,
631    spawned: &'scope ConcurrentQueue<FallibleTask<Result<T, Box<(dyn core::any::Any + Send)>>>>,
632    // make `Scope` invariant over 'scope and 'env
633    scope: PhantomData<&'scope mut &'scope ()>,
634    env: PhantomData<&'env mut &'env ()>,
635}
636
637impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> {
638    /// Spawns a scoped future onto the thread pool. The scope *must* outlive
639    /// the provided future. The results of the future will be returned as a part of
640    /// [`TaskPool::scope`]'s return value.
641    ///
642    /// For futures that should run on the thread `scope` is called on [`Scope::spawn_on_scope`] should be used
643    /// instead.
644    ///
645    /// For more information, see [`TaskPool::scope`].
646    pub fn spawn<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
647        let task = self
648            .executor
649            .spawn(AssertUnwindSafe(f).catch_unwind())
650            .fallible();
651        // ConcurrentQueue only errors when closed or full, but we never
652        // close and use an unbounded queue, so it is safe to unwrap
653        self.spawned.push(task).unwrap();
654    }
655
656    /// Spawns a scoped future onto the thread the scope is run on. The scope *must* outlive
657    /// the provided future. The results of the future will be returned as a part of
658    /// [`TaskPool::scope`]'s return value.  Users should generally prefer to use
659    /// [`Scope::spawn`] instead, unless the provided future needs to run on the scope's thread.
660    ///
661    /// For more information, see [`TaskPool::scope`].
662    pub fn spawn_on_scope<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
663        let task = self
664            .scope_executor
665            .spawn(AssertUnwindSafe(f).catch_unwind())
666            .fallible();
667        // ConcurrentQueue only errors when closed or full, but we never
668        // close and use an unbounded queue, so it is safe to unwrap
669        self.spawned.push(task).unwrap();
670    }
671
672    /// Spawns a scoped future onto the thread of the external thread executor.
673    /// This is typically the main thread. The scope *must* outlive
674    /// the provided future. The results of the future will be returned as a part of
675    /// [`TaskPool::scope`]'s return value.  Users should generally prefer to use
676    /// [`Scope::spawn`] instead, unless the provided future needs to run on the external thread.
677    ///
678    /// For more information, see [`TaskPool::scope`].
679    pub fn spawn_on_external<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
680        let task = self
681            .external_executor
682            .spawn(AssertUnwindSafe(f).catch_unwind())
683            .fallible();
684        // ConcurrentQueue only errors when closed or full, but we never
685        // close and use an unbounded queue, so it is safe to unwrap
686        self.spawned.push(task).unwrap();
687    }
688}
689
690impl<'scope, 'env, T> Drop for Scope<'scope, 'env, T>
691where
692    T: 'scope,
693{
694    fn drop(&mut self) {
695        block_on(async {
696            while let Ok(task) = self.spawned.pop() {
697                task.cancel().await;
698            }
699        });
700    }
701}
702
703#[cfg(test)]
704mod tests {
705    use super::*;
706    use core::sync::atomic::{AtomicBool, AtomicI32, Ordering};
707    use std::sync::Barrier;
708
709    #[test]
710    fn test_spawn() {
711        let pool = TaskPool::new();
712
713        let foo = Box::new(42);
714        let foo = &*foo;
715
716        let count = Arc::new(AtomicI32::new(0));
717
718        let outputs = pool.scope(|scope| {
719            for _ in 0..100 {
720                let count_clone = count.clone();
721                scope.spawn(async move {
722                    if *foo != 42 {
723                        panic!("not 42!?!?")
724                    } else {
725                        count_clone.fetch_add(1, Ordering::Relaxed);
726                        *foo
727                    }
728                });
729            }
730        });
731
732        for output in &outputs {
733            assert_eq!(*output, 42);
734        }
735
736        assert_eq!(outputs.len(), 100);
737        assert_eq!(count.load(Ordering::Relaxed), 100);
738    }
739
740    #[test]
741    fn test_thread_callbacks() {
742        let counter = Arc::new(AtomicI32::new(0));
743        let start_counter = counter.clone();
744        {
745            let barrier = Arc::new(Barrier::new(11));
746            let last_barrier = barrier.clone();
747            // Build and immediately drop to terminate
748            let _pool = TaskPoolBuilder::new()
749                .num_threads(10)
750                .on_thread_spawn(move || {
751                    start_counter.fetch_add(1, Ordering::Relaxed);
752                    barrier.clone().wait();
753                })
754                .build();
755            last_barrier.wait();
756            assert_eq!(10, counter.load(Ordering::Relaxed));
757        }
758        assert_eq!(10, counter.load(Ordering::Relaxed));
759        let end_counter = counter.clone();
760        {
761            let _pool = TaskPoolBuilder::new()
762                .num_threads(20)
763                .on_thread_destroy(move || {
764                    end_counter.fetch_sub(1, Ordering::Relaxed);
765                })
766                .build();
767            assert_eq!(10, counter.load(Ordering::Relaxed));
768        }
769        assert_eq!(-10, counter.load(Ordering::Relaxed));
770        let start_counter = counter.clone();
771        let end_counter = counter.clone();
772        {
773            let barrier = Arc::new(Barrier::new(6));
774            let last_barrier = barrier.clone();
775            let _pool = TaskPoolBuilder::new()
776                .num_threads(5)
777                .on_thread_spawn(move || {
778                    start_counter.fetch_add(1, Ordering::Relaxed);
779                    barrier.wait();
780                })
781                .on_thread_destroy(move || {
782                    end_counter.fetch_sub(1, Ordering::Relaxed);
783                })
784                .build();
785            last_barrier.wait();
786            assert_eq!(-5, counter.load(Ordering::Relaxed));
787        }
788        assert_eq!(-10, counter.load(Ordering::Relaxed));
789    }
790
791    #[test]
792    fn test_mixed_spawn_on_scope_and_spawn() {
793        let pool = TaskPool::new();
794
795        let foo = Box::new(42);
796        let foo = &*foo;
797
798        let local_count = Arc::new(AtomicI32::new(0));
799        let non_local_count = Arc::new(AtomicI32::new(0));
800
801        let outputs = pool.scope(|scope| {
802            for i in 0..100 {
803                if i % 2 == 0 {
804                    let count_clone = non_local_count.clone();
805                    scope.spawn(async move {
806                        if *foo != 42 {
807                            panic!("not 42!?!?")
808                        } else {
809                            count_clone.fetch_add(1, Ordering::Relaxed);
810                            *foo
811                        }
812                    });
813                } else {
814                    let count_clone = local_count.clone();
815                    scope.spawn_on_scope(async move {
816                        if *foo != 42 {
817                            panic!("not 42!?!?")
818                        } else {
819                            count_clone.fetch_add(1, Ordering::Relaxed);
820                            *foo
821                        }
822                    });
823                }
824            }
825        });
826
827        for output in &outputs {
828            assert_eq!(*output, 42);
829        }
830
831        assert_eq!(outputs.len(), 100);
832        assert_eq!(local_count.load(Ordering::Relaxed), 50);
833        assert_eq!(non_local_count.load(Ordering::Relaxed), 50);
834    }
835
836    #[test]
837    fn test_thread_locality() {
838        let pool = Arc::new(TaskPool::new());
839        let count = Arc::new(AtomicI32::new(0));
840        let barrier = Arc::new(Barrier::new(101));
841        let thread_check_failed = Arc::new(AtomicBool::new(false));
842
843        for _ in 0..100 {
844            let inner_barrier = barrier.clone();
845            let count_clone = count.clone();
846            let inner_pool = pool.clone();
847            let inner_thread_check_failed = thread_check_failed.clone();
848            thread::spawn(move || {
849                inner_pool.scope(|scope| {
850                    let inner_count_clone = count_clone.clone();
851                    scope.spawn(async move {
852                        inner_count_clone.fetch_add(1, Ordering::Release);
853                    });
854                    let spawner = thread::current().id();
855                    let inner_count_clone = count_clone.clone();
856                    scope.spawn_on_scope(async move {
857                        inner_count_clone.fetch_add(1, Ordering::Release);
858                        if thread::current().id() != spawner {
859                            // NOTE: This check is using an atomic rather than simply panicking the
860                            // thread to avoid deadlocking the barrier on failure
861                            inner_thread_check_failed.store(true, Ordering::Release);
862                        }
863                    });
864                });
865                inner_barrier.wait();
866            });
867        }
868        barrier.wait();
869        assert!(!thread_check_failed.load(Ordering::Acquire));
870        assert_eq!(count.load(Ordering::Acquire), 200);
871    }
872
873    #[test]
874    fn test_nested_spawn() {
875        let pool = TaskPool::new();
876
877        let foo = Box::new(42);
878        let foo = &*foo;
879
880        let count = Arc::new(AtomicI32::new(0));
881
882        let outputs: Vec<i32> = pool.scope(|scope| {
883            for _ in 0..10 {
884                let count_clone = count.clone();
885                scope.spawn(async move {
886                    for _ in 0..10 {
887                        let count_clone_clone = count_clone.clone();
888                        scope.spawn(async move {
889                            if *foo != 42 {
890                                panic!("not 42!?!?")
891                            } else {
892                                count_clone_clone.fetch_add(1, Ordering::Relaxed);
893                                *foo
894                            }
895                        });
896                    }
897                    *foo
898                });
899            }
900        });
901
902        for output in &outputs {
903            assert_eq!(*output, 42);
904        }
905
906        // the inner loop runs 100 times and the outer one runs 10. 100 + 10
907        assert_eq!(outputs.len(), 110);
908        assert_eq!(count.load(Ordering::Relaxed), 100);
909    }
910
911    #[test]
912    fn test_nested_locality() {
913        let pool = Arc::new(TaskPool::new());
914        let count = Arc::new(AtomicI32::new(0));
915        let barrier = Arc::new(Barrier::new(101));
916        let thread_check_failed = Arc::new(AtomicBool::new(false));
917
918        for _ in 0..100 {
919            let inner_barrier = barrier.clone();
920            let count_clone = count.clone();
921            let inner_pool = pool.clone();
922            let inner_thread_check_failed = thread_check_failed.clone();
923            thread::spawn(move || {
924                inner_pool.scope(|scope| {
925                    let spawner = thread::current().id();
926                    let inner_count_clone = count_clone.clone();
927                    scope.spawn(async move {
928                        inner_count_clone.fetch_add(1, Ordering::Release);
929
930                        // spawning on the scope from another thread runs the futures on the scope's thread
931                        scope.spawn_on_scope(async move {
932                            inner_count_clone.fetch_add(1, Ordering::Release);
933                            if thread::current().id() != spawner {
934                                // NOTE: This check is using an atomic rather than simply panicking the
935                                // thread to avoid deadlocking the barrier on failure
936                                inner_thread_check_failed.store(true, Ordering::Release);
937                            }
938                        });
939                    });
940                });
941                inner_barrier.wait();
942            });
943        }
944        barrier.wait();
945        assert!(!thread_check_failed.load(Ordering::Acquire));
946        assert_eq!(count.load(Ordering::Acquire), 200);
947    }
948
949    // This test will often freeze on other executors.
950    #[test]
951    fn test_nested_scopes() {
952        let pool = TaskPool::new();
953        let count = Arc::new(AtomicI32::new(0));
954
955        pool.scope(|scope| {
956            scope.spawn(async {
957                pool.scope(|scope| {
958                    scope.spawn(async {
959                        count.fetch_add(1, Ordering::Relaxed);
960                    });
961                });
962            });
963        });
964
965        assert_eq!(count.load(Ordering::Acquire), 1);
966    }
967}