bevy_render/
gpu_readback.rs

1use crate::{
2    extract_component::ExtractComponentPlugin,
3    render_asset::RenderAssets,
4    render_resource::{Buffer, BufferUsages, Extent3d, ImageDataLayout, Texture, TextureFormat},
5    renderer::{render_system, RenderDevice},
6    storage::{GpuShaderStorageBuffer, ShaderStorageBuffer},
7    sync_world::MainEntity,
8    texture::GpuImage,
9    ExtractSchedule, MainWorld, Render, RenderApp, RenderSet,
10};
11use async_channel::{Receiver, Sender};
12use bevy_app::{App, Plugin};
13use bevy_asset::Handle;
14use bevy_derive::{Deref, DerefMut};
15use bevy_ecs::schedule::IntoSystemConfigs;
16use bevy_ecs::{
17    change_detection::ResMut,
18    entity::Entity,
19    event::Event,
20    prelude::{Component, Resource, World},
21    system::{Query, Res},
22};
23use bevy_image::{Image, TextureFormatPixelInfo};
24use bevy_reflect::Reflect;
25use bevy_render_macros::ExtractComponent;
26use bevy_utils::{default, tracing::warn, HashMap};
27use encase::internal::ReadFrom;
28use encase::private::Reader;
29use encase::ShaderType;
30use wgpu::CommandEncoder;
31
/// A plugin that enables reading back gpu buffers and textures to the cpu.
pub struct GpuReadbackPlugin {
    /// Describes the number of frames a buffer can be unused before it is removed from the pool in
    /// order to avoid unnecessary reallocations.
    max_unused_frames: usize,
}

impl GpuReadbackPlugin {
    /// Creates a readback plugin whose buffer pool drops a readback buffer once it has gone
    /// unused for `max_unused_frames` frames.
    ///
    /// A value of zero is raised to one: with a limit of zero, the pool's eviction check
    /// (`frames_unused < max_unused_frames`) would also discard buffers that are still the
    /// destination of an in-flight readback.
    #[must_use]
    pub const fn new(max_unused_frames: usize) -> Self {
        let max_unused_frames = if max_unused_frames == 0 {
            1
        } else {
            max_unused_frames
        };
        Self { max_unused_frames }
    }
}
38
39impl Default for GpuReadbackPlugin {
40    fn default() -> Self {
41        Self {
42            max_unused_frames: 10,
43        }
44    }
45}
46
47impl Plugin for GpuReadbackPlugin {
48    fn build(&self, app: &mut App) {
49        app.add_plugins(ExtractComponentPlugin::<Readback>::default());
50
51        if let Some(render_app) = app.get_sub_app_mut(RenderApp) {
52            render_app
53                .init_resource::<GpuReadbackBufferPool>()
54                .init_resource::<GpuReadbacks>()
55                .insert_resource(GpuReadbackMaxUnusedFrames(self.max_unused_frames))
56                .add_systems(ExtractSchedule, sync_readbacks.ambiguous_with_all())
57                .add_systems(
58                    Render,
59                    (
60                        prepare_buffers.in_set(RenderSet::PrepareResources),
61                        map_buffers.after(render_system).in_set(RenderSet::Render),
62                    ),
63                );
64        }
65    }
66}
67
68/// A component that registers the wrapped handle for gpu readback, either a texture or a buffer.
69///
70/// Data is read asynchronously and will be triggered on the entity via the [`ReadbackComplete`] event
71/// when complete. If this component is not removed, the readback will be attempted every frame
72#[derive(Component, ExtractComponent, Clone, Debug)]
73pub enum Readback {
74    Texture(Handle<Image>),
75    Buffer(Handle<ShaderStorageBuffer>),
76}
77
78impl Readback {
79    /// Create a readback component for a texture using the given handle.
80    pub fn texture(image: Handle<Image>) -> Self {
81        Self::Texture(image)
82    }
83
84    /// Create a readback component for a buffer using the given handle.
85    pub fn buffer(buffer: Handle<ShaderStorageBuffer>) -> Self {
86        Self::Buffer(buffer)
87    }
88}
89
90/// An event that is triggered when a gpu readback is complete.
91///
92/// The event contains the data as a `Vec<u8>`, which can be interpreted as the raw bytes of the
93/// requested buffer or texture.
94#[derive(Event, Deref, DerefMut, Reflect, Debug)]
95#[reflect(Debug)]
96pub struct ReadbackComplete(pub Vec<u8>);
97
98impl ReadbackComplete {
99    /// Convert the raw bytes of the event to a shader type.
100    pub fn to_shader_type<T: ShaderType + ReadFrom + Default>(&self) -> T {
101        let mut val = T::default();
102        let mut reader = Reader::new::<T>(&self.0, 0).expect("Failed to create Reader");
103        T::read_from(&mut val, &mut reader);
104        val
105    }
106}
107
108#[derive(Resource)]
109struct GpuReadbackMaxUnusedFrames(usize);
110
111struct GpuReadbackBuffer {
112    buffer: Buffer,
113    taken: bool,
114    frames_unused: usize,
115}
116
117#[derive(Resource, Default)]
118struct GpuReadbackBufferPool {
119    // Map of buffer size to list of buffers, with a flag for whether the buffer is taken and how
120    // many frames it has been unused for.
121    // TODO: We could ideally write all readback data to one big buffer per frame, the assumption
122    // here is that very few entities well actually be read back at once, and their size is
123    // unlikely to change.
124    buffers: HashMap<u64, Vec<GpuReadbackBuffer>>,
125}
126
127impl GpuReadbackBufferPool {
128    fn get(&mut self, render_device: &RenderDevice, size: u64) -> Buffer {
129        let buffers = self.buffers.entry(size).or_default();
130
131        // find an untaken buffer for this size
132        if let Some(buf) = buffers.iter_mut().find(|x| !x.taken) {
133            buf.taken = true;
134            buf.frames_unused = 0;
135            return buf.buffer.clone();
136        }
137
138        let buffer = render_device.create_buffer(&wgpu::BufferDescriptor {
139            label: Some("Readback Buffer"),
140            size,
141            usage: BufferUsages::COPY_DST | BufferUsages::MAP_READ,
142            mapped_at_creation: false,
143        });
144        buffers.push(GpuReadbackBuffer {
145            buffer: buffer.clone(),
146            taken: true,
147            frames_unused: 0,
148        });
149        buffer
150    }
151
152    // Returns the buffer to the pool so it can be used in a future frame
153    fn return_buffer(&mut self, buffer: &Buffer) {
154        let size = buffer.size();
155        let buffers = self
156            .buffers
157            .get_mut(&size)
158            .expect("Returned buffer of untracked size");
159        if let Some(buf) = buffers.iter_mut().find(|x| x.buffer.id() == buffer.id()) {
160            buf.taken = false;
161        } else {
162            warn!("Returned buffer that was not allocated");
163        }
164    }
165
166    fn update(&mut self, max_unused_frames: usize) {
167        for (_, buffers) in &mut self.buffers {
168            // Tick all the buffers
169            for buf in &mut *buffers {
170                if !buf.taken {
171                    buf.frames_unused += 1;
172                }
173            }
174
175            // Remove buffers that haven't been used for MAX_UNUSED_FRAMES
176            buffers.retain(|x| x.frames_unused < max_unused_frames);
177        }
178
179        // Remove empty buffer sizes
180        self.buffers.retain(|_, buffers| !buffers.is_empty());
181    }
182}
183
184enum ReadbackSource {
185    Texture {
186        texture: Texture,
187        layout: ImageDataLayout,
188        size: Extent3d,
189    },
190    Buffer {
191        src_start: u64,
192        dst_start: u64,
193        buffer: Buffer,
194    },
195}
196
197#[derive(Resource, Default)]
198struct GpuReadbacks {
199    requested: Vec<GpuReadback>,
200    mapped: Vec<GpuReadback>,
201}
202
203struct GpuReadback {
204    pub entity: Entity,
205    pub src: ReadbackSource,
206    pub buffer: Buffer,
207    pub rx: Receiver<(Entity, Buffer, Vec<u8>)>,
208    pub tx: Sender<(Entity, Buffer, Vec<u8>)>,
209}
210
211fn sync_readbacks(
212    mut main_world: ResMut<MainWorld>,
213    mut buffer_pool: ResMut<GpuReadbackBufferPool>,
214    mut readbacks: ResMut<GpuReadbacks>,
215    max_unused_frames: Res<GpuReadbackMaxUnusedFrames>,
216) {
217    readbacks.mapped.retain(|readback| {
218        if let Ok((entity, buffer, result)) = readback.rx.try_recv() {
219            main_world.trigger_targets(ReadbackComplete(result), entity);
220            buffer_pool.return_buffer(&buffer);
221            false
222        } else {
223            true
224        }
225    });
226
227    buffer_pool.update(max_unused_frames.0);
228}
229
230fn prepare_buffers(
231    render_device: Res<RenderDevice>,
232    mut readbacks: ResMut<GpuReadbacks>,
233    mut buffer_pool: ResMut<GpuReadbackBufferPool>,
234    gpu_images: Res<RenderAssets<GpuImage>>,
235    ssbos: Res<RenderAssets<GpuShaderStorageBuffer>>,
236    handles: Query<(&MainEntity, &Readback)>,
237) {
238    for (entity, readback) in handles.iter() {
239        match readback {
240            Readback::Texture(image) => {
241                if let Some(gpu_image) = gpu_images.get(image) {
242                    let size = Extent3d {
243                        width: gpu_image.size.x,
244                        height: gpu_image.size.y,
245                        ..default()
246                    };
247                    let layout = layout_data(size.width, size.height, gpu_image.texture_format);
248                    let buffer = buffer_pool.get(
249                        &render_device,
250                        get_aligned_size(
251                            size.width,
252                            size.height,
253                            gpu_image.texture_format.pixel_size() as u32,
254                        ) as u64,
255                    );
256                    let (tx, rx) = async_channel::bounded(1);
257                    readbacks.requested.push(GpuReadback {
258                        entity: entity.id(),
259                        src: ReadbackSource::Texture {
260                            texture: gpu_image.texture.clone(),
261                            layout,
262                            size,
263                        },
264                        buffer,
265                        rx,
266                        tx,
267                    });
268                }
269            }
270            Readback::Buffer(buffer) => {
271                if let Some(ssbo) = ssbos.get(buffer) {
272                    let size = ssbo.buffer.size();
273                    let buffer = buffer_pool.get(&render_device, size);
274                    let (tx, rx) = async_channel::bounded(1);
275                    readbacks.requested.push(GpuReadback {
276                        entity: entity.id(),
277                        src: ReadbackSource::Buffer {
278                            src_start: 0,
279                            dst_start: 0,
280                            buffer: ssbo.buffer.clone(),
281                        },
282                        buffer,
283                        rx,
284                        tx,
285                    });
286                }
287            }
288        }
289    }
290}
291
292pub(crate) fn submit_readback_commands(world: &World, command_encoder: &mut CommandEncoder) {
293    let readbacks = world.resource::<GpuReadbacks>();
294    for readback in &readbacks.requested {
295        match &readback.src {
296            ReadbackSource::Texture {
297                texture,
298                layout,
299                size,
300            } => {
301                command_encoder.copy_texture_to_buffer(
302                    texture.as_image_copy(),
303                    wgpu::ImageCopyBuffer {
304                        buffer: &readback.buffer,
305                        layout: *layout,
306                    },
307                    *size,
308                );
309            }
310            ReadbackSource::Buffer {
311                src_start,
312                dst_start,
313                buffer,
314            } => {
315                command_encoder.copy_buffer_to_buffer(
316                    buffer,
317                    *src_start,
318                    &readback.buffer,
319                    *dst_start,
320                    buffer.size(),
321                );
322            }
323        }
324    }
325}
326
327/// Move requested readbacks to mapped readbacks after commands have been submitted in render system
328fn map_buffers(mut readbacks: ResMut<GpuReadbacks>) {
329    let requested = readbacks.requested.drain(..).collect::<Vec<GpuReadback>>();
330    for readback in requested {
331        let slice = readback.buffer.slice(..);
332        let entity = readback.entity;
333        let buffer = readback.buffer.clone();
334        let tx = readback.tx.clone();
335        slice.map_async(wgpu::MapMode::Read, move |res| {
336            res.expect("Failed to map buffer");
337            let buffer_slice = buffer.slice(..);
338            let data = buffer_slice.get_mapped_range();
339            let result = Vec::from(&*data);
340            drop(data);
341            buffer.unmap();
342            if let Err(e) = tx.try_send((entity, buffer, result)) {
343                warn!("Failed to send readback result: {:?}", e);
344            }
345        });
346        readbacks.mapped.push(readback);
347    }
348}
349
350// Utils
351
352/// Round up a given value to be a multiple of [`wgpu::COPY_BYTES_PER_ROW_ALIGNMENT`].
353pub(crate) const fn align_byte_size(value: u32) -> u32 {
354    RenderDevice::align_copy_bytes_per_row(value as usize) as u32
355}
356
357/// Get the size of a image when the size of each row has been rounded up to [`wgpu::COPY_BYTES_PER_ROW_ALIGNMENT`].
358pub(crate) const fn get_aligned_size(width: u32, height: u32, pixel_size: u32) -> u32 {
359    height * align_byte_size(width * pixel_size)
360}
361
362/// Get a [`ImageDataLayout`] aligned such that the image can be copied into a buffer.
363pub(crate) fn layout_data(width: u32, height: u32, format: TextureFormat) -> ImageDataLayout {
364    ImageDataLayout {
365        bytes_per_row: if height > 1 {
366            // 1 = 1 row
367            Some(get_aligned_size(width, 1, format.pixel_size() as u32))
368        } else {
369            None
370        },
371        rows_per_image: None,
372        ..Default::default()
373    }
374}