use {
    alloc::{collections::VecDeque, vec::Vec},
    core::{
        convert::TryFrom as _,
        fmt::{self, Debug, Display},
    },
    gpu_descriptor_types::{
        CreatePoolError, DescriptorDevice, DescriptorPoolCreateFlags, DescriptorTotalCount,
        DeviceAllocationError,
    },
    hashbrown::HashMap,
};

bitflags::bitflags! {
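    /// Flags to augment descriptor set layout creation.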
    #[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
    pub struct DescriptorSetLayoutCreateFlags: u32 {
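        /// Descriptor sets with this layout must be allocated from pools
        /// created with `DescriptorPoolCreateFlags::UPDATE_AFTER_BIND`.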
        const UPDATE_AFTER_BIND = 0x2;
    }
}

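/// Descriptor set allocated from a [`DescriptorAllocator`].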
#[derive(Debug)]
pub struct DescriptorSet<S> {
    raw: S,
    pool_id: u64,
    size: DescriptorTotalCount,
    update_after_bind: bool,
}

impl<S> DescriptorSet<S> {
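    /// Returns a reference to the raw descriptor set handle.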
    pub fn raw(&self) -> &S {
        &self.raw
    }

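    /// Returns a mutable reference to the raw descriptor set handle.
    ///
    /// # Safety
    ///
    /// The underlying handle must not be replaced.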
    pub unsafe fn raw_mut(&mut self) -> &mut S {
        &mut self.raw
    }
}

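/// Errors that may occur during descriptor set allocation.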
#[derive(Debug)]
pub enum AllocationError {
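    /// Backend reported that device memory was exhausted.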
    OutOfDeviceMemory,

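    /// Backend reported that host memory was exhausted.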
    OutOfHostMemory,

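    /// Allocation failed due to fragmentation.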
    Fragmentation,
}

impl Display for AllocationError {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            AllocationError::OutOfDeviceMemory => fmt.write_str("Device memory exhausted"),
            AllocationError::OutOfHostMemory => fmt.write_str("Host memory exhausted"),
            AllocationError::Fragmentation => fmt.write_str("Fragmentation"),
        }
    }
}

#[cfg(feature = "std")]
impl std::error::Error for AllocationError {}

impl From<CreatePoolError> for AllocationError {
    fn from(err: CreatePoolError) -> Self {
        match err {
            CreatePoolError::OutOfDeviceMemory => AllocationError::OutOfDeviceMemory,
            CreatePoolError::OutOfHostMemory => AllocationError::OutOfHostMemory,
            CreatePoolError::Fragmentation => AllocationError::Fragmentation,
        }
    }
}

/// Lower bound for the number of sets in a newly created pool.
const MIN_SETS: u32 = 64;

/// Cap on the history-based component when sizing new pools.
const MAX_SETS: u32 = 512;

#[derive(Debug)]
struct DescriptorPool<P> {
    raw: P,

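    /// Number of sets currently allocated from this pool.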
    allocated: u32,

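    /// Number of sets that can still be allocated from this pool.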
    available: u32,
}

#[derive(Debug)]
struct DescriptorBucket<P> {
    /// Pool id of the first pool in `pools`; grows as old pools are destroyed.
    offset: u64,
    pools: VecDeque<DescriptorPool<P>>,
    /// Total number of sets currently allocated from this bucket.
    total: u32,
    update_after_bind: bool,
    size: DescriptorTotalCount,
}

impl<P> Drop for DescriptorBucket<P> {
    #[cfg(feature = "tracing")]
    fn drop(&mut self) {
        #[cfg(feature = "std")]
        {
            if std::thread::panicking() {
                return;
            }
        }
        if self.total > 0 {
            tracing::error!("Descriptor sets were not deallocated");
        }
    }

    #[cfg(all(not(feature = "tracing"), feature = "std"))]
    fn drop(&mut self) {
        if std::thread::panicking() {
            return;
        }
        if self.total > 0 {
            eprintln!("Descriptor sets were not deallocated")
        }
    }

    #[cfg(all(not(feature = "tracing"), not(feature = "std")))]
    fn drop(&mut self) {
        if self.total > 0 {
            panic!("Descriptor sets were not deallocated")
        }
    }
}

impl<P> DescriptorBucket<P> {
    fn new(update_after_bind: bool, size: DescriptorTotalCount) -> Self {
        DescriptorBucket {
            offset: 0,
            pools: VecDeque::new(),
            total: 0,
            update_after_bind,
            size,
        }
    }

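    /// Computes the size of the next pool: grows with allocation history,
    /// clamped so per-pool descriptor totals cannot overflow `u32`.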
    fn new_pool_size(&self, minimal_set_count: u32) -> (DescriptorTotalCount, u32) {
        let mut max_sets = MIN_SETS // at least `MIN_SETS`
            .max(minimal_set_count) // at least enough for this allocation
            .max(self.total.min(MAX_SETS)) // at least as many as allocated so far, capped at `MAX_SETS`
            .checked_next_power_of_two() // rounded up to the next power of two
            .unwrap_or(i32::MAX as u32);

        // Clamp `max_sets` so that multiplying it by any per-set descriptor count cannot overflow.
        max_sets = (u32::MAX / self.size.sampler.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.combined_image_sampler.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.sampled_image.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.storage_image.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.uniform_texel_buffer.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.storage_texel_buffer.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.uniform_buffer.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.storage_buffer.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.uniform_buffer_dynamic.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.storage_buffer_dynamic.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.input_attachment.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.acceleration_structure.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.inline_uniform_block_bytes.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.inline_uniform_block_bindings.max(1)).min(max_sets);

        let mut pool_size = DescriptorTotalCount {
            sampler: self.size.sampler * max_sets,
            combined_image_sampler: self.size.combined_image_sampler * max_sets,
            sampled_image: self.size.sampled_image * max_sets,
            storage_image: self.size.storage_image * max_sets,
            uniform_texel_buffer: self.size.uniform_texel_buffer * max_sets,
            storage_texel_buffer: self.size.storage_texel_buffer * max_sets,
            uniform_buffer: self.size.uniform_buffer * max_sets,
            storage_buffer: self.size.storage_buffer * max_sets,
            uniform_buffer_dynamic: self.size.uniform_buffer_dynamic * max_sets,
            storage_buffer_dynamic: self.size.storage_buffer_dynamic * max_sets,
            input_attachment: self.size.input_attachment * max_sets,
            acceleration_structure: self.size.acceleration_structure * max_sets,
            inline_uniform_block_bytes: self.size.inline_uniform_block_bytes * max_sets,
            inline_uniform_block_bindings: self.size.inline_uniform_block_bindings * max_sets,
        };

        if pool_size == Default::default() {
            // Pools with no descriptors are invalid; request at least one sampler.
            pool_size.sampler = 1;
        }

        (pool_size, max_sets)
    }

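    /// Allocates `count` sets with the given `layout` into `allocated_sets`,
    /// drawing from existing pools before creating new ones.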
    unsafe fn allocate<L, S>(
        &mut self,
        device: &impl DescriptorDevice<L, P, S>,
        layout: &L,
        mut count: u32,
        allocated_sets: &mut Vec<DescriptorSet<S>>,
    ) -> Result<(), AllocationError> {
        debug_assert!(usize::try_from(count).is_ok(), "Must be ensured by caller");

        if count == 0 {
            return Ok(());
        }

        // Visit pools newest-first; fresh pools are most likely to have capacity.
        for (index, pool) in self.pools.iter_mut().enumerate().rev() {
            if pool.available == 0 {
                continue;
            }

            let allocate = pool.available.min(count);

            #[cfg(feature = "tracing")]
            tracing::trace!("Allocate `{}` sets from existing pool", allocate);

            let result = device.alloc_descriptor_sets(
                &mut pool.raw,
                (0..allocate).map(|_| layout),
                &mut Allocation {
                    size: self.size,
                    update_after_bind: self.update_after_bind,
                    pool_id: index as u64 + self.offset,
                    sets: allocated_sets,
                },
            );

            match result {
                Ok(()) => {}
                Err(DeviceAllocationError::OutOfDeviceMemory) => {
                    return Err(AllocationError::OutOfDeviceMemory)
                }
                Err(DeviceAllocationError::OutOfHostMemory) => {
                    return Err(AllocationError::OutOfHostMemory)
                }
                Err(DeviceAllocationError::FragmentedPool) => {
                    #[cfg(feature = "tracing")]
                    tracing::error!(
                        "Unexpectedly failed to allocate descriptor sets due to pool fragmentation"
                    );
                    // Treat the fragmented pool as exhausted and move on.
                    pool.available = 0;
                    continue;
                }
                Err(DeviceAllocationError::OutOfPoolMemory) => {
                    pool.available = 0;
                    continue;
                }
            }

            count -= allocate;
            pool.available -= allocate;
            pool.allocated += allocate;
            self.total += allocate;

            if count == 0 {
                return Ok(());
            }
        }

        // Existing pools are exhausted; create new pools until the request is satisfied.
        while count > 0 {
            let (pool_size, max_sets) = self.new_pool_size(count);
            #[cfg(feature = "tracing")]
            tracing::trace!(
                "Create new pool with {} sets and {:?} descriptors",
                max_sets,
                pool_size,
            );

            let mut raw = device.create_descriptor_pool(
                &pool_size,
                max_sets,
                if self.update_after_bind {
                    DescriptorPoolCreateFlags::FREE_DESCRIPTOR_SET
                        | DescriptorPoolCreateFlags::UPDATE_AFTER_BIND
                } else {
                    DescriptorPoolCreateFlags::FREE_DESCRIPTOR_SET
                },
            )?;

            let pool_id = self.pools.len() as u64 + self.offset;

            let allocate = max_sets.min(count);
            let result = device.alloc_descriptor_sets(
                &mut raw,
                (0..allocate).map(|_| layout),
                &mut Allocation {
                    pool_id,
                    size: self.size,
                    update_after_bind: self.update_after_bind,
                    sets: allocated_sets,
                },
            );

            match result {
                Ok(()) => {}
                Err(err) => {
                    device.destroy_descriptor_pool(raw);
                    match err {
                        DeviceAllocationError::OutOfDeviceMemory => {
                            return Err(AllocationError::OutOfDeviceMemory)
                        }
                        DeviceAllocationError::OutOfHostMemory => {
                            return Err(AllocationError::OutOfHostMemory)
                        }
                        DeviceAllocationError::FragmentedPool => {
                            #[cfg(feature = "tracing")]
                            tracing::error!(
                                "Unexpectedly failed to allocate descriptor sets due to pool fragmentation"
                            );
                        }
                        DeviceAllocationError::OutOfPoolMemory => {}
                    }
                    // A freshly created pool sized for this request must not fail.
                    panic!("Failed to allocate descriptor sets from fresh pool");
                }
            }

            count -= allocate;
            self.pools.push_back(DescriptorPool {
                raw,
                allocated: allocate,
                available: max_sets - allocate,
            });
            self.total += allocate;
        }

        Ok(())
    }

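    /// Returns raw sets to the pool identified by `pool_id`, then destroys
    /// leading pools that no longer have live allocations.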
    unsafe fn free<L, S>(
        &mut self,
        device: &impl DescriptorDevice<L, P, S>,
        raw_sets: impl IntoIterator<Item = S>,
        pool_id: u64,
    ) {
        let pool = usize::try_from(pool_id - self.offset)
            .ok()
            .and_then(|index| self.pools.get_mut(index))
            .expect("Invalid pool id");

        let mut raw_sets = raw_sets.into_iter();
        let mut count = 0;
        device.dealloc_descriptor_sets(&mut pool.raw, raw_sets.by_ref().inspect(|_| count += 1));

        debug_assert!(
            raw_sets.next().is_none(),
            "Device must deallocate all sets from the iterator"
        );

        pool.available += count;
        pool.allocated -= count;
        self.total -= count;
        #[cfg(feature = "tracing")]
        tracing::trace!("Freed {} sets from descriptor bucket", count);

        // Destroy empty pools from the front, keeping at least one pool alive.
        while let Some(pool) = self.pools.pop_front() {
            if self.pools.is_empty() || pool.allocated != 0 {
                self.pools.push_front(pool);
                break;
            }

            #[cfg(feature = "tracing")]
            tracing::trace!("Destroying old descriptor pool");

            device.destroy_descriptor_pool(pool.raw);
            self.offset += 1;
        }
    }

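    /// Destroys all leading pools that have no live allocations.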
    unsafe fn cleanup<L, S>(&mut self, device: &impl DescriptorDevice<L, P, S>) {
        while let Some(pool) = self.pools.pop_front() {
            if pool.allocated != 0 {
                self.pools.push_front(pool);
                break;
            }

            #[cfg(feature = "tracing")]
            tracing::trace!("Destroying old descriptor pool");

            device.destroy_descriptor_pool(pool.raw);
            self.offset += 1;
        }
    }
}

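/// Descriptor allocator that can allocate descriptor sets for any layout,
/// growing and recycling descriptor pools behind the scenes.
///
/// A minimal usage sketch; `device`, `layout`, and `layout_descriptor_count`
/// are assumed to come from the backend and are illustrative only:
///
/// ```ignore
/// let mut allocator = DescriptorAllocator::new(max_update_after_bind_descriptors);
///
/// let sets = unsafe {
///     allocator.allocate(
///         &device,
///         &layout,
///         DescriptorSetLayoutCreateFlags::empty(),
///         &layout_descriptor_count,
///         16,
///     )?
/// };
///
/// // ... record and submit work that uses `sets` ...
///
/// unsafe {
///     allocator.free(&device, sets);
///     allocator.cleanup(&device);
/// }
/// ```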
#[derive(Debug)]
pub struct DescriptorAllocator<P, S> {
    /// Buckets keyed by per-set descriptor counts and the update-after-bind flag.
    buckets: HashMap<(DescriptorTotalCount, bool), DescriptorBucket<P>>,
    sets_cache: Vec<DescriptorSet<S>>,
    raw_sets_cache: Vec<S>,
    max_update_after_bind_descriptors_in_all_pools: u32,
    current_update_after_bind_descriptors_in_all_pools: u32,
    total: u32,
}

impl<P, S> Drop for DescriptorAllocator<P, S> {
    fn drop(&mut self) {
        if self.buckets.drain().any(|(_, bucket)| bucket.total != 0) {
            #[cfg(feature = "tracing")]
            tracing::error!(
                "`DescriptorAllocator` is dropped while some descriptor sets were not deallocated"
            );
        }
    }
}

impl<P, S> DescriptorAllocator<P, S> {
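    /// Creates a new allocator instance.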
    pub fn new(max_update_after_bind_descriptors_in_all_pools: u32) -> Self {
        DescriptorAllocator {
            buckets: HashMap::default(),
            total: 0,
            sets_cache: Vec::new(),
            raw_sets_cache: Vec::new(),
            max_update_after_bind_descriptors_in_all_pools,
            current_update_after_bind_descriptors_in_all_pools: 0,
        }
    }

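    /// Allocates `count` descriptor sets with the specified layout.
    ///
    /// `layout_descriptor_count` must match the descriptor counts of `layout`,
    /// and `flags` must match the flags the layout was created with.
    ///
    /// # Safety
    ///
    /// The same `device` must be used for all allocations and deallocations
    /// performed through this allocator.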
    pub unsafe fn allocate<L, D>(
        &mut self,
        device: &D,
        layout: &L,
        flags: DescriptorSetLayoutCreateFlags,
        layout_descriptor_count: &DescriptorTotalCount,
        count: u32,
    ) -> Result<Vec<DescriptorSet<S>>, AllocationError>
    where
        S: Debug,
        L: Debug,
        D: DescriptorDevice<L, P, S>,
    {
        if count == 0 {
            return Ok(Vec::new());
        }

        let descriptor_count = count * layout_descriptor_count.total();

        let update_after_bind = flags.contains(DescriptorSetLayoutCreateFlags::UPDATE_AFTER_BIND);

        // Reject the request if it would exceed the update-after-bind descriptor budget.
        if update_after_bind
            && self.max_update_after_bind_descriptors_in_all_pools
                - self.current_update_after_bind_descriptors_in_all_pools
                < descriptor_count
        {
            return Err(AllocationError::Fragmentation);
        }

        #[cfg(feature = "tracing")]
        tracing::trace!(
            "Allocating {} sets with layout {:?} @ {:?}",
            count,
            layout,
            layout_descriptor_count
        );

        let bucket = self
            .buckets
            .entry((*layout_descriptor_count, update_after_bind))
            .or_insert_with(|| DescriptorBucket::new(update_after_bind, *layout_descriptor_count));
        match bucket.allocate(device, layout, count, &mut self.sets_cache) {
            Ok(()) => {
                self.total += descriptor_count;
                if update_after_bind {
                    self.current_update_after_bind_descriptors_in_all_pools += descriptor_count;
                }

                Ok(core::mem::take(&mut self.sets_cache))
            }
            Err(err) => {
                debug_assert!(self.raw_sets_cache.is_empty());

                // Return any sets that were allocated before the failure,
                // batching consecutive sets from the same pool.
                let mut last = None;

                for set in self.sets_cache.drain(..) {
                    if Some(set.pool_id) != last {
                        if let Some(last_id) = last {
                            bucket.free(device, self.raw_sets_cache.drain(..), last_id);
                        }
                    }
                    last = Some(set.pool_id);
                    self.raw_sets_cache.push(set.raw);
                }

                if let Some(last_id) = last {
                    bucket.free(device, self.raw_sets_cache.drain(..), last_id);
                }

                Err(err)
            }
        }
    }

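    /// Frees descriptor sets previously allocated from this allocator.
    ///
    /// # Safety
    ///
    /// The sets must have been allocated from this allocator with the same
    /// `device`, and must not be referenced by any pending command buffers.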
    pub unsafe fn free<L, D, I>(&mut self, device: &D, sets: I)
    where
        D: DescriptorDevice<L, P, S>,
        I: IntoIterator<Item = DescriptorSet<S>>,
    {
        debug_assert!(self.raw_sets_cache.is_empty());

        let mut last_key = (EMPTY_COUNT, false);
        let mut last_pool_id = None;

        // Batch consecutive sets that share a bucket and pool into one free call.
        for set in sets {
            if last_key != (set.size, set.update_after_bind) || last_pool_id != Some(set.pool_id) {
                if let Some(pool_id) = last_pool_id {
                    let bucket = self
                        .buckets
                        .get_mut(&last_key)
                        .expect("Set must be allocated from this allocator");

                    debug_assert!(u32::try_from(self.raw_sets_cache.len())
                        .ok()
                        .map_or(false, |count| count <= bucket.total));

                    bucket.free(device, self.raw_sets_cache.drain(..), pool_id);
                }
                last_key = (set.size, set.update_after_bind);
                last_pool_id = Some(set.pool_id);
            }
            self.raw_sets_cache.push(set.raw);
        }

        if let Some(pool_id) = last_pool_id {
            let bucket = self
                .buckets
                .get_mut(&last_key)
                .expect("Set must be allocated from this allocator");

            debug_assert!(u32::try_from(self.raw_sets_cache.len())
                .ok()
                .map_or(false, |count| count <= bucket.total));

            bucket.free(device, self.raw_sets_cache.drain(..), pool_id);
        }
    }

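    /// Destroys unused descriptor pools so the backend can reclaim them.
    ///
    /// # Safety
    ///
    /// The same `device` must be used for all allocations and deallocations
    /// performed through this allocator.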
    pub unsafe fn cleanup<L>(&mut self, device: &impl DescriptorDevice<L, P, S>) {
        for bucket in self.buckets.values_mut() {
            bucket.cleanup(device)
        }
        self.buckets.retain(|_, bucket| !bucket.pools.is_empty());
    }
}

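/// Sentinel key used by [`DescriptorAllocator::free`] before the first set is seen.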
const EMPTY_COUNT: DescriptorTotalCount = DescriptorTotalCount {
    sampler: 0,
    combined_image_sampler: 0,
    sampled_image: 0,
    storage_image: 0,
    uniform_texel_buffer: 0,
    storage_texel_buffer: 0,
    uniform_buffer: 0,
    storage_buffer: 0,
    uniform_buffer_dynamic: 0,
    storage_buffer_dynamic: 0,
    input_attachment: 0,
    acceleration_structure: 0,
    inline_uniform_block_bytes: 0,
    inline_uniform_block_bindings: 0,
};

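/// `Extend` adapter passed to the device so that the raw sets it writes are
/// wrapped into `DescriptorSet` records tagged with their origin pool.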
struct Allocation<'a, S> {
    update_after_bind: bool,
    size: DescriptorTotalCount,
    pool_id: u64,
    sets: &'a mut Vec<DescriptorSet<S>>,
}

impl<S> Extend<S> for Allocation<'_, S> {
    fn extend<T: IntoIterator<Item = S>>(&mut self, iter: T) {
        let update_after_bind = self.update_after_bind;
        let size = self.size;
        let pool_id = self.pool_id;
        self.sets.extend(iter.into_iter().map(|raw| DescriptorSet {
            raw,
            pool_id,
            update_after_bind,
            size,
        }))
    }
}