bevy_render/
gpu_readback.rs

1use crate::{
2    extract_component::ExtractComponentPlugin,
3    render_asset::RenderAssets,
4    render_resource::{
5        Buffer, BufferUsages, CommandEncoder, Extent3d, TexelCopyBufferLayout, Texture,
6        TextureFormat,
7    },
8    renderer::{render_system, RenderDevice},
9    storage::{GpuShaderStorageBuffer, ShaderStorageBuffer},
10    sync_world::MainEntity,
11    texture::GpuImage,
12    ExtractSchedule, MainWorld, Render, RenderApp, RenderSystems,
13};
14use async_channel::{Receiver, Sender};
15use bevy_app::{App, Plugin};
16use bevy_asset::Handle;
17use bevy_derive::{Deref, DerefMut};
18use bevy_ecs::schedule::IntoScheduleConfigs;
19use bevy_ecs::{
20    change_detection::ResMut,
21    entity::Entity,
22    event::EntityEvent,
23    prelude::{Component, Resource, World},
24    system::{Query, Res},
25};
26use bevy_image::{Image, TextureFormatPixelInfo};
27use bevy_platform::collections::HashMap;
28use bevy_reflect::Reflect;
29use bevy_render_macros::ExtractComponent;
30use encase::internal::ReadFrom;
31use encase::private::Reader;
32use encase::ShaderType;
33use tracing::warn;
34
/// Plugin that adds support for copying gpu buffers and textures back to the cpu.
pub struct GpuReadbackPlugin {
    /// How many consecutive frames a pooled readback buffer may sit unused before it is dropped
    /// from the pool. Keeping idle buffers around for a while avoids reallocating them when the
    /// same readback is requested again shortly afterwards.
    max_unused_frames: usize,
}
41
42impl Default for GpuReadbackPlugin {
43    fn default() -> Self {
44        Self {
45            max_unused_frames: 10,
46        }
47    }
48}
49
50impl Plugin for GpuReadbackPlugin {
51    fn build(&self, app: &mut App) {
52        app.add_plugins(ExtractComponentPlugin::<Readback>::default());
53
54        if let Some(render_app) = app.get_sub_app_mut(RenderApp) {
55            render_app
56                .init_resource::<GpuReadbackBufferPool>()
57                .init_resource::<GpuReadbacks>()
58                .insert_resource(GpuReadbackMaxUnusedFrames(self.max_unused_frames))
59                .add_systems(ExtractSchedule, sync_readbacks.ambiguous_with_all())
60                .add_systems(
61                    Render,
62                    (
63                        prepare_buffers.in_set(RenderSystems::PrepareResources),
64                        map_buffers
65                            .after(render_system)
66                            .in_set(RenderSystems::Render),
67                    ),
68                );
69        }
70    }
71}
72
73/// A component that registers the wrapped handle for gpu readback, either a texture or a buffer.
74///
75/// Data is read asynchronously and will be triggered on the entity via the [`ReadbackComplete`] event
76/// when complete. If this component is not removed, the readback will be attempted every frame
77#[derive(Component, ExtractComponent, Clone, Debug)]
78pub enum Readback {
79    Texture(Handle<Image>),
80    Buffer {
81        buffer: Handle<ShaderStorageBuffer>,
82        start_offset_and_size: Option<(u64, u64)>,
83    },
84}
85
86impl Readback {
87    /// Create a readback component for a texture using the given handle.
88    pub fn texture(image: Handle<Image>) -> Self {
89        Self::Texture(image)
90    }
91
92    /// Create a readback component for a full buffer using the given handle.
93    pub fn buffer(buffer: Handle<ShaderStorageBuffer>) -> Self {
94        Self::Buffer {
95            buffer,
96            start_offset_and_size: None,
97        }
98    }
99
100    /// Create a readback component for a buffer range using the given handle, a start offset in bytes
101    /// and a number of bytes to read.
102    pub fn buffer_range(buffer: Handle<ShaderStorageBuffer>, start_offset: u64, size: u64) -> Self {
103        Self::Buffer {
104            buffer,
105            start_offset_and_size: Some((start_offset, size)),
106        }
107    }
108}
109
110/// An event that is triggered when a gpu readback is complete.
111///
112/// The event contains the data as a `Vec<u8>`, which can be interpreted as the raw bytes of the
113/// requested buffer or texture.
114#[derive(EntityEvent, Deref, DerefMut, Reflect, Debug)]
115#[reflect(Debug)]
116pub struct ReadbackComplete {
117    pub entity: Entity,
118    #[deref]
119    pub data: Vec<u8>,
120}
121
122impl ReadbackComplete {
123    /// Convert the raw bytes of the event to a shader type.
124    pub fn to_shader_type<T: ShaderType + ReadFrom + Default>(&self) -> T {
125        let mut val = T::default();
126        let mut reader = Reader::new::<T>(&self.data, 0).expect("Failed to create Reader");
127        T::read_from(&mut val, &mut reader);
128        val
129    }
130}
131
132#[derive(Resource)]
133struct GpuReadbackMaxUnusedFrames(usize);
134
135struct GpuReadbackBuffer {
136    buffer: Buffer,
137    taken: bool,
138    frames_unused: usize,
139}
140
141#[derive(Resource, Default)]
142struct GpuReadbackBufferPool {
143    // Map of buffer size to list of buffers, with a flag for whether the buffer is taken and how
144    // many frames it has been unused for.
145    // TODO: We could ideally write all readback data to one big buffer per frame, the assumption
146    // here is that very few entities well actually be read back at once, and their size is
147    // unlikely to change.
148    buffers: HashMap<u64, Vec<GpuReadbackBuffer>>,
149}
150
151impl GpuReadbackBufferPool {
152    fn get(&mut self, render_device: &RenderDevice, size: u64) -> Buffer {
153        let buffers = self.buffers.entry(size).or_default();
154
155        // find an untaken buffer for this size
156        if let Some(buf) = buffers.iter_mut().find(|x| !x.taken) {
157            buf.taken = true;
158            buf.frames_unused = 0;
159            return buf.buffer.clone();
160        }
161
162        let buffer = render_device.create_buffer(&wgpu::BufferDescriptor {
163            label: Some("Readback Buffer"),
164            size,
165            usage: BufferUsages::COPY_DST | BufferUsages::MAP_READ,
166            mapped_at_creation: false,
167        });
168        buffers.push(GpuReadbackBuffer {
169            buffer: buffer.clone(),
170            taken: true,
171            frames_unused: 0,
172        });
173        buffer
174    }
175
176    // Returns the buffer to the pool so it can be used in a future frame
177    fn return_buffer(&mut self, buffer: &Buffer) {
178        let size = buffer.size();
179        let buffers = self
180            .buffers
181            .get_mut(&size)
182            .expect("Returned buffer of untracked size");
183        if let Some(buf) = buffers.iter_mut().find(|x| x.buffer.id() == buffer.id()) {
184            buf.taken = false;
185        } else {
186            warn!("Returned buffer that was not allocated");
187        }
188    }
189
190    fn update(&mut self, max_unused_frames: usize) {
191        for (_, buffers) in &mut self.buffers {
192            // Tick all the buffers
193            for buf in &mut *buffers {
194                if !buf.taken {
195                    buf.frames_unused += 1;
196                }
197            }
198
199            // Remove buffers that haven't been used for MAX_UNUSED_FRAMES
200            buffers.retain(|x| x.frames_unused < max_unused_frames);
201        }
202
203        // Remove empty buffer sizes
204        self.buffers.retain(|_, buffers| !buffers.is_empty());
205    }
206}
207
208enum ReadbackSource {
209    Texture {
210        texture: Texture,
211        layout: TexelCopyBufferLayout,
212        size: Extent3d,
213    },
214    Buffer {
215        buffer: Buffer,
216        start_offset_and_size: Option<(u64, u64)>,
217    },
218}
219
220#[derive(Resource, Default)]
221struct GpuReadbacks {
222    requested: Vec<GpuReadback>,
223    mapped: Vec<GpuReadback>,
224}
225
226struct GpuReadback {
227    pub entity: Entity,
228    pub src: ReadbackSource,
229    pub buffer: Buffer,
230    pub rx: Receiver<(Entity, Buffer, Vec<u8>)>,
231    pub tx: Sender<(Entity, Buffer, Vec<u8>)>,
232}
233
234fn sync_readbacks(
235    mut main_world: ResMut<MainWorld>,
236    mut buffer_pool: ResMut<GpuReadbackBufferPool>,
237    mut readbacks: ResMut<GpuReadbacks>,
238    max_unused_frames: Res<GpuReadbackMaxUnusedFrames>,
239) {
240    readbacks.mapped.retain(|readback| {
241        if let Ok((entity, buffer, data)) = readback.rx.try_recv() {
242            main_world.trigger(ReadbackComplete { data, entity });
243            buffer_pool.return_buffer(&buffer);
244            false
245        } else {
246            true
247        }
248    });
249
250    buffer_pool.update(max_unused_frames.0);
251}
252
253fn prepare_buffers(
254    render_device: Res<RenderDevice>,
255    mut readbacks: ResMut<GpuReadbacks>,
256    mut buffer_pool: ResMut<GpuReadbackBufferPool>,
257    gpu_images: Res<RenderAssets<GpuImage>>,
258    ssbos: Res<RenderAssets<GpuShaderStorageBuffer>>,
259    handles: Query<(&MainEntity, &Readback)>,
260) {
261    for (entity, readback) in handles.iter() {
262        match readback {
263            Readback::Texture(image) => {
264                if let Some(gpu_image) = gpu_images.get(image)
265                    && let Ok(pixel_size) = gpu_image.texture_format.pixel_size()
266                {
267                    let layout = layout_data(gpu_image.size, gpu_image.texture_format);
268                    let buffer = buffer_pool.get(
269                        &render_device,
270                        get_aligned_size(gpu_image.size, pixel_size as u32) as u64,
271                    );
272                    let (tx, rx) = async_channel::bounded(1);
273                    readbacks.requested.push(GpuReadback {
274                        entity: entity.id(),
275                        src: ReadbackSource::Texture {
276                            texture: gpu_image.texture.clone(),
277                            layout,
278                            size: gpu_image.size,
279                        },
280                        buffer,
281                        rx,
282                        tx,
283                    });
284                }
285            }
286            Readback::Buffer {
287                buffer,
288                start_offset_and_size,
289            } => {
290                if let Some(ssbo) = ssbos.get(buffer) {
291                    let full_size = ssbo.buffer.size();
292                    let size = start_offset_and_size
293                        .map(|(start, size)| {
294                            let end = start + size;
295                            if end > full_size {
296                                panic!(
297                                    "Tried to read past the end of the buffer (start: {start}, \
298                                    size: {size}, buffer size: {full_size})."
299                                );
300                            }
301                            size
302                        })
303                        .unwrap_or(full_size);
304                    let buffer = buffer_pool.get(&render_device, size);
305                    let (tx, rx) = async_channel::bounded(1);
306                    readbacks.requested.push(GpuReadback {
307                        entity: entity.id(),
308                        src: ReadbackSource::Buffer {
309                            start_offset_and_size: *start_offset_and_size,
310                            buffer: ssbo.buffer.clone(),
311                        },
312                        buffer,
313                        rx,
314                        tx,
315                    });
316                }
317            }
318        }
319    }
320}
321
322pub(crate) fn submit_readback_commands(world: &World, command_encoder: &mut CommandEncoder) {
323    let readbacks = world.resource::<GpuReadbacks>();
324    for readback in &readbacks.requested {
325        match &readback.src {
326            ReadbackSource::Texture {
327                texture,
328                layout,
329                size,
330            } => {
331                command_encoder.copy_texture_to_buffer(
332                    texture.as_image_copy(),
333                    wgpu::TexelCopyBufferInfo {
334                        buffer: &readback.buffer,
335                        layout: *layout,
336                    },
337                    *size,
338                );
339            }
340            ReadbackSource::Buffer {
341                buffer,
342                start_offset_and_size,
343            } => {
344                let (src_start, size) = start_offset_and_size.unwrap_or((0, buffer.size()));
345                command_encoder.copy_buffer_to_buffer(buffer, src_start, &readback.buffer, 0, size);
346            }
347        }
348    }
349}
350
351/// Move requested readbacks to mapped readbacks after commands have been submitted in render system
352fn map_buffers(mut readbacks: ResMut<GpuReadbacks>) {
353    let requested = readbacks.requested.drain(..).collect::<Vec<GpuReadback>>();
354    for readback in requested {
355        let slice = readback.buffer.slice(..);
356        let entity = readback.entity;
357        let buffer = readback.buffer.clone();
358        let tx = readback.tx.clone();
359        slice.map_async(wgpu::MapMode::Read, move |res| {
360            res.expect("Failed to map buffer");
361            let buffer_slice = buffer.slice(..);
362            let data = buffer_slice.get_mapped_range();
363            let result = Vec::from(&*data);
364            drop(data);
365            buffer.unmap();
366            if let Err(e) = tx.try_send((entity, buffer, result)) {
367                warn!("Failed to send readback result: {}", e);
368            }
369        });
370        readbacks.mapped.push(readback);
371    }
372}
373
374// Utils
375
376/// Round up a given value to be a multiple of [`wgpu::COPY_BYTES_PER_ROW_ALIGNMENT`].
377pub(crate) const fn align_byte_size(value: u32) -> u32 {
378    RenderDevice::align_copy_bytes_per_row(value as usize) as u32
379}
380
381/// Get the size of a image when the size of each row has been rounded up to [`wgpu::COPY_BYTES_PER_ROW_ALIGNMENT`].
382pub(crate) const fn get_aligned_size(extent: Extent3d, pixel_size: u32) -> u32 {
383    extent.height * align_byte_size(extent.width * pixel_size) * extent.depth_or_array_layers
384}
385
386/// Get a [`TexelCopyBufferLayout`] aligned such that the image can be copied into a buffer.
387pub(crate) fn layout_data(extent: Extent3d, format: TextureFormat) -> TexelCopyBufferLayout {
388    TexelCopyBufferLayout {
389        bytes_per_row: if extent.height > 1 || extent.depth_or_array_layers > 1 {
390            if let Ok(pixel_size) = format.pixel_size() {
391                // 1 = 1 row
392                Some(get_aligned_size(
393                    Extent3d {
394                        width: extent.width,
395                        height: 1,
396                        depth_or_array_layers: 1,
397                    },
398                    pixel_size as u32,
399                ))
400            } else {
401                None
402            }
403        } else {
404            None
405        },
406        rows_per_image: if extent.depth_or_array_layers > 1 {
407            let (_, block_dimension_y) = format.block_dimensions();
408            Some(extent.height / block_dimension_y)
409        } else {
410            None
411        },
412        offset: 0,
413    }
414}