bevy_app/task_pool_plugin.rs

use crate::{App, Plugin};

use alloc::string::ToString;
use bevy_platform::sync::Arc;
use bevy_tasks::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool, TaskPoolBuilder};
use core::{fmt::Debug, marker::PhantomData};
use log::trace;

cfg_if::cfg_if! {
    if #[cfg(not(all(target_arch = "wasm32", feature = "web")))] {
        use {crate::Last, bevy_ecs::prelude::NonSend, bevy_tasks::tick_global_task_pools_on_main_thread};

        /// A system used to check and advance our task pools.
        ///
        /// Calls [`tick_global_task_pools_on_main_thread`],
        /// and uses [`NonSendMarker`] to ensure that this system runs on the main thread.
        fn tick_global_task_pools(_main_thread_marker: Option<NonSend<NonSendMarker>>) {
            tick_global_task_pools_on_main_thread();
        }
    }
}

/// Setup of default task pools: [`AsyncComputeTaskPool`], [`ComputeTaskPool`], [`IoTaskPool`].
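///
/// A minimal usage sketch, assuming the plugin is registered through the usual
/// [`App::add_plugins`] flow (the thread count below is just an illustration):
///
/// ```
/// # use bevy_app::{App, TaskPoolPlugin, TaskPoolOptions};
/// App::new().add_plugins(TaskPoolPlugin {
///     // Illustrative choice: force exactly 4 threads across the default pools.
///     task_pool_options: TaskPoolOptions::with_num_threads(4),
/// });
/// ```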
#[derive(Default)]
pub struct TaskPoolPlugin {
    /// Options for the [`TaskPool`](bevy_tasks::TaskPool) created at application start.
    pub task_pool_options: TaskPoolOptions,
}

impl Plugin for TaskPoolPlugin {
    fn build(&self, _app: &mut App) {
        // Setup the default bevy task pools
        self.task_pool_options.create_default_pools();

        #[cfg(not(all(target_arch = "wasm32", feature = "web")))]
        _app.add_systems(Last, tick_global_task_pools);
    }
}

/// A dummy type that is [`!Send`](Send), to force systems to run on the main thread.
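///
/// A sketch of how a system can opt into main-thread execution with this marker
/// (the system name is illustrative; paths assume these items are re-exported at the crate root):
///
/// ```
/// # use bevy_app::NonSendMarker;
/// # use bevy_ecs::prelude::NonSend;
/// // Taking `Option<NonSend<NonSendMarker>>` forces the scheduler to run this
/// // system on the main thread, even though the marker is never inserted.
/// fn main_thread_only_system(_marker: Option<NonSend<NonSendMarker>>) {
///     // ... work that must happen on the main thread ...
/// }
/// ```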
pub struct NonSendMarker(PhantomData<*mut ()>);

/// Defines a simple way to determine how many threads to use given the number of remaining cores
/// and number of total cores.
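///
/// A minimal configuration sketch (the values are illustrative, mirroring the defaults used for
/// the IO pool below):
///
/// ```
/// # use bevy_app::TaskPoolThreadAssignmentPolicy;
/// // Ask for 25% of the total cores, but never fewer than 1 or more than 4 threads.
/// let policy = TaskPoolThreadAssignmentPolicy {
///     min_threads: 1,
///     max_threads: 4,
///     percent: 0.25,
///     on_thread_spawn: None,
///     on_thread_destroy: None,
/// };
/// ```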
#[derive(Clone)]
pub struct TaskPoolThreadAssignmentPolicy {
    /// Force using at least this many threads
    pub min_threads: usize,
    /// Under no circumstance use more than this many threads for this pool
    pub max_threads: usize,
    /// Target using this percentage of total cores, clamped by `min_threads` and `max_threads`. It is
    /// permitted to use 1.0 to try to use all remaining threads.
    pub percent: f32,
    /// Callback that is invoked once for every created thread as it starts.
    /// This configuration is ignored on the wasm platform.
    pub on_thread_spawn: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
    /// Callback that is invoked once for every created thread as it terminates.
    /// This configuration is ignored on the wasm platform.
    pub on_thread_destroy: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
}

impl Debug for TaskPoolThreadAssignmentPolicy {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("TaskPoolThreadAssignmentPolicy")
            .field("min_threads", &self.min_threads)
            .field("max_threads", &self.max_threads)
            .field("percent", &self.percent)
            .finish()
    }
}

impl TaskPoolThreadAssignmentPolicy {
    /// Determine the number of threads to use for this task pool.
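    ///
    /// A worked example with the default IO policy (`percent: 0.25`, `min_threads: 1`,
    /// `max_threads: 4`): on a 16-core machine the proportion is `16 * 0.25 = 4.0`, which
    /// rounds to 4, is capped by `remaining_threads`, and is then clamped into `[1, 4]`,
    /// yielding 4 IO threads (assuming at least 4 threads remain unassigned).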
    fn get_number_of_threads(&self, remaining_threads: usize, total_threads: usize) -> usize {
        assert!(self.percent >= 0.0);
        let proportion = total_threads as f32 * self.percent;
        let mut desired = proportion as usize;

        // Equivalent to round() for positive floats without libm requirement for
        // no_std compatibility
        if proportion - desired as f32 >= 0.5 {
            desired += 1;
        }

        // Limit ourselves to the number of cores available
        desired = desired.min(remaining_threads);

        // Clamp by min_threads, max_threads. (This may result in us using more threads than are
        // available; this is intended. An example case where this might happen is a device with
        // <= 2 threads.)
        desired.clamp(self.min_threads, self.max_threads)
    }
}

/// Helper for configuring and creating the default task pools. For end-users who want full control,
/// set up [`TaskPoolPlugin`].
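///
/// A configuration sketch using struct update syntax (the cap of 8 threads is just an example):
///
/// ```
/// # use bevy_app::TaskPoolOptions;
/// // Keep the defaults, but never hand out more than 8 threads in total.
/// let options = TaskPoolOptions {
///     max_total_threads: 8,
///     ..Default::default()
/// };
/// ```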
#[derive(Clone, Debug)]
pub struct TaskPoolOptions {
    /// If the number of physical cores is less than `min_total_threads`, force using
    /// `min_total_threads`
    pub min_total_threads: usize,
    /// If the number of physical cores is greater than `max_total_threads`, force using
    /// `max_total_threads`
    pub max_total_threads: usize,

    /// Used to determine number of IO threads to allocate
    pub io: TaskPoolThreadAssignmentPolicy,
    /// Used to determine number of async compute threads to allocate
    pub async_compute: TaskPoolThreadAssignmentPolicy,
    /// Used to determine number of compute threads to allocate
    pub compute: TaskPoolThreadAssignmentPolicy,
}

impl Default for TaskPoolOptions {
    fn default() -> Self {
        TaskPoolOptions {
            // By default, use however many cores are available on the system
            min_total_threads: 1,
            max_total_threads: usize::MAX,

            // Use 25% of cores for IO, at least 1, no more than 4
            io: TaskPoolThreadAssignmentPolicy {
                min_threads: 1,
                max_threads: 4,
                percent: 0.25,
                on_thread_spawn: None,
                on_thread_destroy: None,
            },

            // Use 25% of cores for async compute, at least 1, no more than 4
            async_compute: TaskPoolThreadAssignmentPolicy {
                min_threads: 1,
                max_threads: 4,
                percent: 0.25,
                on_thread_spawn: None,
                on_thread_destroy: None,
            },

            // Use all remaining cores for compute (at least 1)
            compute: TaskPoolThreadAssignmentPolicy {
                min_threads: 1,
                max_threads: usize::MAX,
                percent: 1.0, // This 1.0 here means "whatever is left over"
                on_thread_spawn: None,
                on_thread_destroy: None,
            },
        }
    }
}

impl TaskPoolOptions {
    /// Create a configuration that forces using the given number of threads.
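    ///
    /// A usage sketch (the thread count of 4 is arbitrary):
    ///
    /// ```
    /// # use bevy_app::TaskPoolOptions;
    /// let options = TaskPoolOptions::with_num_threads(4);
    /// assert_eq!(options.min_total_threads, 4);
    /// assert_eq!(options.max_total_threads, 4);
    /// ```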
    pub fn with_num_threads(thread_count: usize) -> Self {
        TaskPoolOptions {
            min_total_threads: thread_count,
            max_total_threads: thread_count,
            ..Default::default()
        }
    }

    /// Initializes the default task pools ([`IoTaskPool`], [`AsyncComputeTaskPool`],
    /// [`ComputeTaskPool`]) based on the configured values.
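    ///
    /// This is normally called for you by [`TaskPoolPlugin`]. A direct-call sketch for setups
    /// that configure the pools without going through the plugin:
    ///
    /// ```
    /// # use bevy_app::TaskPoolOptions;
    /// TaskPoolOptions::default().create_default_pools();
    /// ```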
    pub fn create_default_pools(&self) {
        let total_threads = bevy_tasks::available_parallelism()
            .clamp(self.min_total_threads, self.max_total_threads);
        trace!("Assigning {} cores to default task pools", total_threads);

        let mut remaining_threads = total_threads;

        {
            // Determine the number of IO threads we will use
            let io_threads = self
                .io
                .get_number_of_threads(remaining_threads, total_threads);

            trace!("IO Threads: {}", io_threads);
            remaining_threads = remaining_threads.saturating_sub(io_threads);

            IoTaskPool::get_or_init(|| {
                let builder = TaskPoolBuilder::default()
                    .num_threads(io_threads)
                    .thread_name("IO Task Pool".to_string());

                #[cfg(not(all(target_arch = "wasm32", feature = "web")))]
                let builder = {
                    let mut builder = builder;
                    if let Some(f) = self.io.on_thread_spawn.clone() {
                        builder = builder.on_thread_spawn(move || f());
                    }
                    if let Some(f) = self.io.on_thread_destroy.clone() {
                        builder = builder.on_thread_destroy(move || f());
                    }
                    builder
                };

                builder.build()
            });
        }

        {
            // Determine the number of async compute threads we will use
            let async_compute_threads = self
                .async_compute
                .get_number_of_threads(remaining_threads, total_threads);

            trace!("Async Compute Threads: {}", async_compute_threads);
            remaining_threads = remaining_threads.saturating_sub(async_compute_threads);

            AsyncComputeTaskPool::get_or_init(|| {
                let builder = TaskPoolBuilder::default()
                    .num_threads(async_compute_threads)
                    .thread_name("Async Compute Task Pool".to_string());

                #[cfg(not(all(target_arch = "wasm32", feature = "web")))]
                let builder = {
                    let mut builder = builder;
                    if let Some(f) = self.async_compute.on_thread_spawn.clone() {
                        builder = builder.on_thread_spawn(move || f());
                    }
                    if let Some(f) = self.async_compute.on_thread_destroy.clone() {
                        builder = builder.on_thread_destroy(move || f());
                    }
                    builder
                };

                builder.build()
            });
        }

        {
            // Determine the number of compute threads we will use
            // This is intentionally last so that an end user can specify 1.0 as the percent
            let compute_threads = self
                .compute
                .get_number_of_threads(remaining_threads, total_threads);

            trace!("Compute Threads: {}", compute_threads);

            ComputeTaskPool::get_or_init(|| {
                let builder = TaskPoolBuilder::default()
                    .num_threads(compute_threads)
                    .thread_name("Compute Task Pool".to_string());

                #[cfg(not(all(target_arch = "wasm32", feature = "web")))]
                let builder = {
                    let mut builder = builder;
                    if let Some(f) = self.compute.on_thread_spawn.clone() {
                        builder = builder.on_thread_spawn(move || f());
                    }
                    if let Some(f) = self.compute.on_thread_destroy.clone() {
                        builder = builder.on_thread_destroy(move || f());
                    }
                    builder
                };

                builder.build()
            });
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use bevy_tasks::prelude::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool};

    #[test]
    fn runs_spawn_local_tasks() {
        let mut app = App::new();
        app.add_plugins(TaskPoolPlugin::default());

        let (async_tx, async_rx) = crossbeam_channel::unbounded();
        AsyncComputeTaskPool::get()
            .spawn_local(async move {
                async_tx.send(()).unwrap();
            })
            .detach();

        let (compute_tx, compute_rx) = crossbeam_channel::unbounded();
        ComputeTaskPool::get()
            .spawn_local(async move {
                compute_tx.send(()).unwrap();
            })
            .detach();

        let (io_tx, io_rx) = crossbeam_channel::unbounded();
        IoTaskPool::get()
            .spawn_local(async move {
                io_tx.send(()).unwrap();
            })
            .detach();

        app.run();

        async_rx.try_recv().unwrap();
        compute_rx.try_recv().unwrap();
        io_rx.try_recv().unwrap();
    }
}