bevy_app/
task_pool_plugin.rs1use crate::{App, Plugin};
2
3use alloc::string::ToString;
4use bevy_platform::sync::Arc;
5use bevy_tasks::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool, TaskPoolBuilder};
6use core::{fmt::Debug, marker::PhantomData};
7use log::trace;
8
9cfg_if::cfg_if! {
10 if #[cfg(not(all(target_arch = "wasm32", feature = "web")))] {
11 use {crate::Last, bevy_ecs::prelude::NonSend, bevy_tasks::tick_global_task_pools_on_main_thread};
12
13 fn tick_global_task_pools(_main_thread_marker: Option<NonSend<NonSendMarker>>) {
18 tick_global_task_pools_on_main_thread();
19 }
20 }
21}
22
23#[derive(Default)]
25pub struct TaskPoolPlugin {
26 pub task_pool_options: TaskPoolOptions,
28}
29
30impl Plugin for TaskPoolPlugin {
31 fn build(&self, _app: &mut App) {
32 self.task_pool_options.create_default_pools();
34
35 #[cfg(not(all(target_arch = "wasm32", feature = "web")))]
36 _app.add_systems(Last, tick_global_task_pools);
37 }
38}
/// Zero-sized marker type that is neither `Send` nor `Sync`, because it holds a `PhantomData`
/// of a raw pointer.
///
/// On non-web builds, the task-pool ticking system requests it as a `NonSend` resource.
pub struct NonSendMarker(PhantomData<*mut ()>);
41
/// Describes how many threads a single task pool should receive out of the total thread budget.
#[derive(Clone)]
pub struct TaskPoolThreadAssignmentPolicy {
    /// Lower bound on the number of threads given to the pool.
    ///
    /// This bound is applied last. A pool may therefore be given more threads than remain
    /// unassigned, for example on devices with very few cores.
    pub min_threads: usize,
    /// Upper bound on the number of threads given to the pool.
    pub max_threads: usize,
    /// Target share of the *total* thread budget (e.g. `0.25` for a quarter).
    ///
    /// The share is rounded half-up to a whole number of threads before the remaining-thread
    /// cap and the `min_threads`/`max_threads` bounds are applied. Must be non-negative.
    pub percent: f32,
    /// Optional callback handed to the pool builder's `on_thread_spawn` when the pool is
    /// created. It is not wired up on `wasm32` builds with the `web` feature.
    pub on_thread_spawn: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
    /// Optional callback handed to the pool builder's `on_thread_destroy` when the pool is
    /// created. It is not wired up on `wasm32` builds with the `web` feature.
    pub on_thread_destroy: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
}

impl Debug for TaskPoolThreadAssignmentPolicy {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // The thread callbacks are `dyn Fn` and cannot be printed; `finish_non_exhaustive`
        // renders a trailing `..` so readers know fields were omitted.
        f.debug_struct("TaskPoolThreadAssignmentPolicy")
            .field("min_threads", &self.min_threads)
            .field("max_threads", &self.max_threads)
            .field("percent", &self.percent)
            .finish_non_exhaustive()
    }
}

impl TaskPoolThreadAssignmentPolicy {
    /// Computes how many threads this pool should receive.
    ///
    /// `total_threads` is the overall thread budget. `remaining_threads` is the part of that
    /// budget not yet assigned to previously sized pools.
    ///
    /// The result is computed in three steps:
    /// 1. Take `total_threads * percent`, rounded half-up.
    /// 2. Cap it at `remaining_threads`.
    /// 3. Clamp it to `min_threads..=max_threads`.
    ///
    /// Because the clamp comes last, the result can intentionally exceed `remaining_threads`
    /// (e.g. on a device with only one or two threads).
    ///
    /// # Panics
    ///
    /// Panics if `percent` is negative or NaN, or if `min_threads > max_threads`.
    fn get_number_of_threads(&self, remaining_threads: usize, total_threads: usize) -> usize {
        assert!(
            self.percent >= 0.0,
            "TaskPoolThreadAssignmentPolicy: percent must be non-negative, got {}",
            self.percent
        );
        // `usize::clamp` would panic on this too, but with an unhelpful "min <= max" message.
        assert!(
            self.min_threads <= self.max_threads,
            "TaskPoolThreadAssignmentPolicy: min_threads ({}) must not exceed max_threads ({})",
            self.min_threads,
            self.max_threads
        );

        let proportion = total_threads as f32 * self.percent;
        // Float-to-int `as` casts truncate toward zero and saturate at `usize::MAX`.
        let truncated = proportion as usize;

        // Round half-up by hand rather than via `f32::round`. Presumably this keeps the crate
        // `no_std`-friendly, since the file uses `core`/`alloc` paths.
        // `saturating_add` keeps an enormous proportion (e.g. an infinite `percent`) from
        // overflowing once the cast above has already saturated.
        let desired = if proportion - truncated as f32 >= 0.5 {
            truncated.saturating_add(1)
        } else {
            truncated
        };

        desired
            .min(remaining_threads)
            .clamp(self.min_threads, self.max_threads)
    }
}
93
94#[derive(Clone, Debug)]
97pub struct TaskPoolOptions {
98 pub min_total_threads: usize,
101 pub max_total_threads: usize,
104
105 pub io: TaskPoolThreadAssignmentPolicy,
107 pub async_compute: TaskPoolThreadAssignmentPolicy,
109 pub compute: TaskPoolThreadAssignmentPolicy,
111}
112
113impl Default for TaskPoolOptions {
114 fn default() -> Self {
115 TaskPoolOptions {
116 min_total_threads: 1,
118 max_total_threads: usize::MAX,
119
120 io: TaskPoolThreadAssignmentPolicy {
122 min_threads: 1,
123 max_threads: 4,
124 percent: 0.25,
125 on_thread_spawn: None,
126 on_thread_destroy: None,
127 },
128
129 async_compute: TaskPoolThreadAssignmentPolicy {
131 min_threads: 1,
132 max_threads: 4,
133 percent: 0.25,
134 on_thread_spawn: None,
135 on_thread_destroy: None,
136 },
137
138 compute: TaskPoolThreadAssignmentPolicy {
140 min_threads: 1,
141 max_threads: usize::MAX,
142 percent: 1.0, on_thread_spawn: None,
144 on_thread_destroy: None,
145 },
146 }
147 }
148}
149
150impl TaskPoolOptions {
151 pub fn with_num_threads(thread_count: usize) -> Self {
153 TaskPoolOptions {
154 min_total_threads: thread_count,
155 max_total_threads: thread_count,
156 ..Default::default()
157 }
158 }
159
160 pub fn create_default_pools(&self) {
162 let total_threads = bevy_tasks::available_parallelism()
163 .clamp(self.min_total_threads, self.max_total_threads);
164 trace!("Assigning {} cores to default task pools", total_threads);
165
166 let mut remaining_threads = total_threads;
167
168 {
169 let io_threads = self
171 .io
172 .get_number_of_threads(remaining_threads, total_threads);
173
174 trace!("IO Threads: {}", io_threads);
175 remaining_threads = remaining_threads.saturating_sub(io_threads);
176
177 IoTaskPool::get_or_init(|| {
178 let builder = TaskPoolBuilder::default()
179 .num_threads(io_threads)
180 .thread_name("IO Task Pool".to_string());
181
182 #[cfg(not(all(target_arch = "wasm32", feature = "web")))]
183 let builder = {
184 let mut builder = builder;
185 if let Some(f) = self.io.on_thread_spawn.clone() {
186 builder = builder.on_thread_spawn(move || f());
187 }
188 if let Some(f) = self.io.on_thread_destroy.clone() {
189 builder = builder.on_thread_destroy(move || f());
190 }
191 builder
192 };
193
194 builder.build()
195 });
196 }
197
198 {
199 let async_compute_threads = self
201 .async_compute
202 .get_number_of_threads(remaining_threads, total_threads);
203
204 trace!("Async Compute Threads: {}", async_compute_threads);
205 remaining_threads = remaining_threads.saturating_sub(async_compute_threads);
206
207 AsyncComputeTaskPool::get_or_init(|| {
208 let builder = TaskPoolBuilder::default()
209 .num_threads(async_compute_threads)
210 .thread_name("Async Compute Task Pool".to_string());
211
212 #[cfg(not(all(target_arch = "wasm32", feature = "web")))]
213 let builder = {
214 let mut builder = builder;
215 if let Some(f) = self.async_compute.on_thread_spawn.clone() {
216 builder = builder.on_thread_spawn(move || f());
217 }
218 if let Some(f) = self.async_compute.on_thread_destroy.clone() {
219 builder = builder.on_thread_destroy(move || f());
220 }
221 builder
222 };
223
224 builder.build()
225 });
226 }
227
228 {
229 let compute_threads = self
232 .compute
233 .get_number_of_threads(remaining_threads, total_threads);
234
235 trace!("Compute Threads: {}", compute_threads);
236
237 ComputeTaskPool::get_or_init(|| {
238 let builder = TaskPoolBuilder::default()
239 .num_threads(compute_threads)
240 .thread_name("Compute Task Pool".to_string());
241
242 #[cfg(not(all(target_arch = "wasm32", feature = "web")))]
243 let builder = {
244 let mut builder = builder;
245 if let Some(f) = self.compute.on_thread_spawn.clone() {
246 builder = builder.on_thread_spawn(move || f());
247 }
248 if let Some(f) = self.compute.on_thread_destroy.clone() {
249 builder = builder.on_thread_destroy(move || f());
250 }
251 builder
252 };
253
254 builder.build()
255 });
256 }
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use bevy_tasks::prelude::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool};

    /// Queues a detached `spawn_local` task on `pool` and returns the receiving end of a
    /// channel. The task sends `()` on that channel when it runs.
    fn signal_from_local_task(pool: &bevy_tasks::TaskPool) -> crossbeam_channel::Receiver<()> {
        let (sender, receiver) = crossbeam_channel::unbounded();
        pool.spawn_local(async move {
            sender.send(()).unwrap();
        })
        .detach();
        receiver
    }

    #[test]
    fn runs_spawn_local_tasks() {
        let mut app = App::new();
        app.add_plugins(TaskPoolPlugin::default());

        let async_compute_done = signal_from_local_task(AsyncComputeTaskPool::get());
        let compute_done = signal_from_local_task(ComputeTaskPool::get());
        let io_done = signal_from_local_task(IoTaskPool::get());

        // Running the app is expected to execute the plugin's ticking system at least once,
        // which should drive the local tasks queued above.
        app.run();

        async_compute_done.try_recv().unwrap();
        compute_done.try_recv().unwrap();
        io_done.try_recv().unwrap();
    }
}