bevy_render/
gpu_readback.rs

1use crate::{
2    extract_component::ExtractComponentPlugin,
3    render_asset::RenderAssets,
4    render_resource::{
5        Buffer, BufferUsages, CommandEncoder, Extent3d, TexelCopyBufferLayout, Texture,
6        TextureFormat,
7    },
8    renderer::{render_system, RenderDevice},
9    storage::{GpuShaderStorageBuffer, ShaderStorageBuffer},
10    sync_world::MainEntity,
11    texture::GpuImage,
12    ExtractSchedule, MainWorld, Render, RenderApp, RenderSet,
13};
14use async_channel::{Receiver, Sender};
15use bevy_app::{App, Plugin};
16use bevy_asset::Handle;
17use bevy_derive::{Deref, DerefMut};
18use bevy_ecs::schedule::IntoScheduleConfigs;
19use bevy_ecs::{
20    change_detection::ResMut,
21    entity::Entity,
22    event::Event,
23    prelude::{Component, Resource, World},
24    system::{Query, Res},
25};
26use bevy_image::{Image, TextureFormatPixelInfo};
27use bevy_platform::collections::HashMap;
28use bevy_reflect::Reflect;
29use bevy_render_macros::ExtractComponent;
30use encase::internal::ReadFrom;
31use encase::private::Reader;
32use encase::ShaderType;
33use tracing::warn;
34
/// Plugin that makes it possible to copy gpu buffers and textures back to the cpu.
///
/// Readbacks are requested by adding a [`Readback`] component to an entity.
pub struct GpuReadbackPlugin {
    /// Number of frames a pooled readback buffer may stay unused before it is removed from the
    /// pool. Keeping buffers around for a while avoids unnecessary reallocations.
    max_unused_frames: usize,
}
41
42impl Default for GpuReadbackPlugin {
43    fn default() -> Self {
44        Self {
45            max_unused_frames: 10,
46        }
47    }
48}
49
50impl Plugin for GpuReadbackPlugin {
51    fn build(&self, app: &mut App) {
52        app.add_plugins(ExtractComponentPlugin::<Readback>::default());
53
54        if let Some(render_app) = app.get_sub_app_mut(RenderApp) {
55            render_app
56                .init_resource::<GpuReadbackBufferPool>()
57                .init_resource::<GpuReadbacks>()
58                .insert_resource(GpuReadbackMaxUnusedFrames(self.max_unused_frames))
59                .add_systems(ExtractSchedule, sync_readbacks.ambiguous_with_all())
60                .add_systems(
61                    Render,
62                    (
63                        prepare_buffers.in_set(RenderSet::PrepareResources),
64                        map_buffers.after(render_system).in_set(RenderSet::Render),
65                    ),
66                );
67        }
68    }
69}
70
71/// A component that registers the wrapped handle for gpu readback, either a texture or a buffer.
72///
73/// Data is read asynchronously and will be triggered on the entity via the [`ReadbackComplete`] event
74/// when complete. If this component is not removed, the readback will be attempted every frame
75#[derive(Component, ExtractComponent, Clone, Debug)]
76pub enum Readback {
77    Texture(Handle<Image>),
78    Buffer(Handle<ShaderStorageBuffer>),
79}
80
81impl Readback {
82    /// Create a readback component for a texture using the given handle.
83    pub fn texture(image: Handle<Image>) -> Self {
84        Self::Texture(image)
85    }
86
87    /// Create a readback component for a buffer using the given handle.
88    pub fn buffer(buffer: Handle<ShaderStorageBuffer>) -> Self {
89        Self::Buffer(buffer)
90    }
91}
92
93/// An event that is triggered when a gpu readback is complete.
94///
95/// The event contains the data as a `Vec<u8>`, which can be interpreted as the raw bytes of the
96/// requested buffer or texture.
97#[derive(Event, Deref, DerefMut, Reflect, Debug)]
98#[reflect(Debug)]
99pub struct ReadbackComplete(pub Vec<u8>);
100
101impl ReadbackComplete {
102    /// Convert the raw bytes of the event to a shader type.
103    pub fn to_shader_type<T: ShaderType + ReadFrom + Default>(&self) -> T {
104        let mut val = T::default();
105        let mut reader = Reader::new::<T>(&self.0, 0).expect("Failed to create Reader");
106        T::read_from(&mut val, &mut reader);
107        val
108    }
109}
110
111#[derive(Resource)]
112struct GpuReadbackMaxUnusedFrames(usize);
113
114struct GpuReadbackBuffer {
115    buffer: Buffer,
116    taken: bool,
117    frames_unused: usize,
118}
119
120#[derive(Resource, Default)]
121struct GpuReadbackBufferPool {
122    // Map of buffer size to list of buffers, with a flag for whether the buffer is taken and how
123    // many frames it has been unused for.
124    // TODO: We could ideally write all readback data to one big buffer per frame, the assumption
125    // here is that very few entities well actually be read back at once, and their size is
126    // unlikely to change.
127    buffers: HashMap<u64, Vec<GpuReadbackBuffer>>,
128}
129
130impl GpuReadbackBufferPool {
131    fn get(&mut self, render_device: &RenderDevice, size: u64) -> Buffer {
132        let buffers = self.buffers.entry(size).or_default();
133
134        // find an untaken buffer for this size
135        if let Some(buf) = buffers.iter_mut().find(|x| !x.taken) {
136            buf.taken = true;
137            buf.frames_unused = 0;
138            return buf.buffer.clone();
139        }
140
141        let buffer = render_device.create_buffer(&wgpu::BufferDescriptor {
142            label: Some("Readback Buffer"),
143            size,
144            usage: BufferUsages::COPY_DST | BufferUsages::MAP_READ,
145            mapped_at_creation: false,
146        });
147        buffers.push(GpuReadbackBuffer {
148            buffer: buffer.clone(),
149            taken: true,
150            frames_unused: 0,
151        });
152        buffer
153    }
154
155    // Returns the buffer to the pool so it can be used in a future frame
156    fn return_buffer(&mut self, buffer: &Buffer) {
157        let size = buffer.size();
158        let buffers = self
159            .buffers
160            .get_mut(&size)
161            .expect("Returned buffer of untracked size");
162        if let Some(buf) = buffers.iter_mut().find(|x| x.buffer.id() == buffer.id()) {
163            buf.taken = false;
164        } else {
165            warn!("Returned buffer that was not allocated");
166        }
167    }
168
169    fn update(&mut self, max_unused_frames: usize) {
170        for (_, buffers) in &mut self.buffers {
171            // Tick all the buffers
172            for buf in &mut *buffers {
173                if !buf.taken {
174                    buf.frames_unused += 1;
175                }
176            }
177
178            // Remove buffers that haven't been used for MAX_UNUSED_FRAMES
179            buffers.retain(|x| x.frames_unused < max_unused_frames);
180        }
181
182        // Remove empty buffer sizes
183        self.buffers.retain(|_, buffers| !buffers.is_empty());
184    }
185}
186
187enum ReadbackSource {
188    Texture {
189        texture: Texture,
190        layout: TexelCopyBufferLayout,
191        size: Extent3d,
192    },
193    Buffer {
194        src_start: u64,
195        dst_start: u64,
196        buffer: Buffer,
197    },
198}
199
200#[derive(Resource, Default)]
201struct GpuReadbacks {
202    requested: Vec<GpuReadback>,
203    mapped: Vec<GpuReadback>,
204}
205
206struct GpuReadback {
207    pub entity: Entity,
208    pub src: ReadbackSource,
209    pub buffer: Buffer,
210    pub rx: Receiver<(Entity, Buffer, Vec<u8>)>,
211    pub tx: Sender<(Entity, Buffer, Vec<u8>)>,
212}
213
214fn sync_readbacks(
215    mut main_world: ResMut<MainWorld>,
216    mut buffer_pool: ResMut<GpuReadbackBufferPool>,
217    mut readbacks: ResMut<GpuReadbacks>,
218    max_unused_frames: Res<GpuReadbackMaxUnusedFrames>,
219) {
220    readbacks.mapped.retain(|readback| {
221        if let Ok((entity, buffer, result)) = readback.rx.try_recv() {
222            main_world.trigger_targets(ReadbackComplete(result), entity);
223            buffer_pool.return_buffer(&buffer);
224            false
225        } else {
226            true
227        }
228    });
229
230    buffer_pool.update(max_unused_frames.0);
231}
232
233fn prepare_buffers(
234    render_device: Res<RenderDevice>,
235    mut readbacks: ResMut<GpuReadbacks>,
236    mut buffer_pool: ResMut<GpuReadbackBufferPool>,
237    gpu_images: Res<RenderAssets<GpuImage>>,
238    ssbos: Res<RenderAssets<GpuShaderStorageBuffer>>,
239    handles: Query<(&MainEntity, &Readback)>,
240) {
241    for (entity, readback) in handles.iter() {
242        match readback {
243            Readback::Texture(image) => {
244                if let Some(gpu_image) = gpu_images.get(image) {
245                    let layout = layout_data(gpu_image.size, gpu_image.texture_format);
246                    let buffer = buffer_pool.get(
247                        &render_device,
248                        get_aligned_size(
249                            gpu_image.size,
250                            gpu_image.texture_format.pixel_size() as u32,
251                        ) as u64,
252                    );
253                    let (tx, rx) = async_channel::bounded(1);
254                    readbacks.requested.push(GpuReadback {
255                        entity: entity.id(),
256                        src: ReadbackSource::Texture {
257                            texture: gpu_image.texture.clone(),
258                            layout,
259                            size: gpu_image.size,
260                        },
261                        buffer,
262                        rx,
263                        tx,
264                    });
265                }
266            }
267            Readback::Buffer(buffer) => {
268                if let Some(ssbo) = ssbos.get(buffer) {
269                    let size = ssbo.buffer.size();
270                    let buffer = buffer_pool.get(&render_device, size);
271                    let (tx, rx) = async_channel::bounded(1);
272                    readbacks.requested.push(GpuReadback {
273                        entity: entity.id(),
274                        src: ReadbackSource::Buffer {
275                            src_start: 0,
276                            dst_start: 0,
277                            buffer: ssbo.buffer.clone(),
278                        },
279                        buffer,
280                        rx,
281                        tx,
282                    });
283                }
284            }
285        }
286    }
287}
288
289pub(crate) fn submit_readback_commands(world: &World, command_encoder: &mut CommandEncoder) {
290    let readbacks = world.resource::<GpuReadbacks>();
291    for readback in &readbacks.requested {
292        match &readback.src {
293            ReadbackSource::Texture {
294                texture,
295                layout,
296                size,
297            } => {
298                command_encoder.copy_texture_to_buffer(
299                    texture.as_image_copy(),
300                    wgpu::TexelCopyBufferInfo {
301                        buffer: &readback.buffer,
302                        layout: *layout,
303                    },
304                    *size,
305                );
306            }
307            ReadbackSource::Buffer {
308                src_start,
309                dst_start,
310                buffer,
311            } => {
312                command_encoder.copy_buffer_to_buffer(
313                    buffer,
314                    *src_start,
315                    &readback.buffer,
316                    *dst_start,
317                    buffer.size(),
318                );
319            }
320        }
321    }
322}
323
324/// Move requested readbacks to mapped readbacks after commands have been submitted in render system
325fn map_buffers(mut readbacks: ResMut<GpuReadbacks>) {
326    let requested = readbacks.requested.drain(..).collect::<Vec<GpuReadback>>();
327    for readback in requested {
328        let slice = readback.buffer.slice(..);
329        let entity = readback.entity;
330        let buffer = readback.buffer.clone();
331        let tx = readback.tx.clone();
332        slice.map_async(wgpu::MapMode::Read, move |res| {
333            res.expect("Failed to map buffer");
334            let buffer_slice = buffer.slice(..);
335            let data = buffer_slice.get_mapped_range();
336            let result = Vec::from(&*data);
337            drop(data);
338            buffer.unmap();
339            if let Err(e) = tx.try_send((entity, buffer, result)) {
340                warn!("Failed to send readback result: {}", e);
341            }
342        });
343        readbacks.mapped.push(readback);
344    }
345}
346
347// Utils
348
349/// Round up a given value to be a multiple of [`wgpu::COPY_BYTES_PER_ROW_ALIGNMENT`].
350pub(crate) const fn align_byte_size(value: u32) -> u32 {
351    RenderDevice::align_copy_bytes_per_row(value as usize) as u32
352}
353
354/// Get the size of a image when the size of each row has been rounded up to [`wgpu::COPY_BYTES_PER_ROW_ALIGNMENT`].
355pub(crate) const fn get_aligned_size(extent: Extent3d, pixel_size: u32) -> u32 {
356    extent.height * align_byte_size(extent.width * pixel_size) * extent.depth_or_array_layers
357}
358
359/// Get a [`TexelCopyBufferLayout`] aligned such that the image can be copied into a buffer.
360pub(crate) fn layout_data(extent: Extent3d, format: TextureFormat) -> TexelCopyBufferLayout {
361    TexelCopyBufferLayout {
362        bytes_per_row: if extent.height > 1 || extent.depth_or_array_layers > 1 {
363            // 1 = 1 row
364            Some(get_aligned_size(
365                Extent3d {
366                    width: extent.width,
367                    height: 1,
368                    depth_or_array_layers: 1,
369                },
370                format.pixel_size() as u32,
371            ))
372        } else {
373            None
374        },
375        rows_per_image: if extent.depth_or_array_layers > 1 {
376            let (_, block_dimension_y) = format.block_dimensions();
377            Some(extent.height / block_dimension_y)
378        } else {
379            None
380        },
381        offset: 0,
382    }
383}