1use crate::{
2 extract_component::ExtractComponentPlugin,
3 render_asset::RenderAssets,
4 render_resource::{Buffer, BufferUsages, Extent3d, ImageDataLayout, Texture, TextureFormat},
5 renderer::{render_system, RenderDevice},
6 storage::{GpuShaderStorageBuffer, ShaderStorageBuffer},
7 sync_world::MainEntity,
8 texture::GpuImage,
9 ExtractSchedule, MainWorld, Render, RenderApp, RenderSet,
10};
11use async_channel::{Receiver, Sender};
12use bevy_app::{App, Plugin};
13use bevy_asset::Handle;
14use bevy_derive::{Deref, DerefMut};
15use bevy_ecs::schedule::IntoSystemConfigs;
16use bevy_ecs::{
17 change_detection::ResMut,
18 entity::Entity,
19 event::Event,
20 prelude::{Component, Resource, World},
21 system::{Query, Res},
22};
23use bevy_image::{Image, TextureFormatPixelInfo};
24use bevy_reflect::Reflect;
25use bevy_render_macros::ExtractComponent;
26use bevy_utils::{default, tracing::warn, HashMap};
27use encase::internal::ReadFrom;
28use encase::private::Reader;
29use encase::ShaderType;
30use wgpu::CommandEncoder;
31
/// Plugin that sets up reading GPU textures and storage buffers back to the CPU.
///
/// The plugin itself only carries the eviction threshold for the pool of
/// staging buffers used by the readback systems.
pub struct GpuReadbackPlugin {
    /// Number of frames a pooled staging buffer may sit unused before it is
    /// dropped from the pool.
    max_unused_frames: usize,
}

impl Default for GpuReadbackPlugin {
    /// Builds the plugin with an eviction threshold of 10 unused frames.
    fn default() -> Self {
        let max_unused_frames = 10;
        Self { max_unused_frames }
    }
}
46
47impl Plugin for GpuReadbackPlugin {
48 fn build(&self, app: &mut App) {
49 app.add_plugins(ExtractComponentPlugin::<Readback>::default());
50
51 if let Some(render_app) = app.get_sub_app_mut(RenderApp) {
52 render_app
53 .init_resource::<GpuReadbackBufferPool>()
54 .init_resource::<GpuReadbacks>()
55 .insert_resource(GpuReadbackMaxUnusedFrames(self.max_unused_frames))
56 .add_systems(ExtractSchedule, sync_readbacks.ambiguous_with_all())
57 .add_systems(
58 Render,
59 (
60 prepare_buffers.in_set(RenderSet::PrepareResources),
61 map_buffers.after(render_system).in_set(RenderSet::Render),
62 ),
63 );
64 }
65 }
66}
67
/// Component that requests a GPU-to-CPU readback of the referenced asset.
///
/// While an entity carries this component, the render world queues a copy of
/// the referenced texture or storage buffer each time the asset is available
/// on the GPU. The resulting bytes are delivered by triggering
/// [`ReadbackComplete`] on the main-world entity.
#[derive(Component, ExtractComponent, Clone, Debug)]
pub enum Readback {
    /// Read back the contents of an [`Image`].
    Texture(Handle<Image>),
    /// Read back the contents of a [`ShaderStorageBuffer`].
    Buffer(Handle<ShaderStorageBuffer>),
}
77
78impl Readback {
79 pub fn texture(image: Handle<Image>) -> Self {
81 Self::Texture(image)
82 }
83
84 pub fn buffer(buffer: Handle<ShaderStorageBuffer>) -> Self {
86 Self::Buffer(buffer)
87 }
88}
89
/// Event carrying the raw bytes copied out of a mapped readback buffer.
///
/// It is triggered on the main-world entity whose [`Readback`] request
/// produced the data. The payload is the full content of the staging buffer;
/// for textures that buffer is sized with an aligned number of bytes per row.
#[derive(Event, Deref, DerefMut, Reflect, Debug)]
#[reflect(Debug)]
pub struct ReadbackComplete(pub Vec<u8>);
97
98impl ReadbackComplete {
99 pub fn to_shader_type<T: ShaderType + ReadFrom + Default>(&self) -> T {
101 let mut val = T::default();
102 let mut reader = Reader::new::<T>(&self.0, 0).expect("Failed to create Reader");
103 T::read_from(&mut val, &mut reader);
104 val
105 }
106}
107
/// Render-world resource holding the eviction threshold, in frames, that is
/// passed to the staging buffer pool on every update.
#[derive(Resource)]
struct GpuReadbackMaxUnusedFrames(usize);
110
/// One staging buffer tracked by the pool.
struct GpuReadbackBuffer {
    /// The pooled buffer itself.
    buffer: Buffer,
    /// `true` while the buffer is handed out to a readback.
    taken: bool,
    /// Number of pool updates during which the buffer was not taken.
    frames_unused: usize,
}
116
/// Pool of staging buffers that are reused between readbacks.
#[derive(Resource, Default)]
struct GpuReadbackBufferPool {
    /// Pooled buffers, grouped by their size in bytes. A buffer is only ever
    /// reused for a request of exactly the same size.
    buffers: HashMap<u64, Vec<GpuReadbackBuffer>>,
}
126
127impl GpuReadbackBufferPool {
128 fn get(&mut self, render_device: &RenderDevice, size: u64) -> Buffer {
129 let buffers = self.buffers.entry(size).or_default();
130
131 if let Some(buf) = buffers.iter_mut().find(|x| !x.taken) {
133 buf.taken = true;
134 buf.frames_unused = 0;
135 return buf.buffer.clone();
136 }
137
138 let buffer = render_device.create_buffer(&wgpu::BufferDescriptor {
139 label: Some("Readback Buffer"),
140 size,
141 usage: BufferUsages::COPY_DST | BufferUsages::MAP_READ,
142 mapped_at_creation: false,
143 });
144 buffers.push(GpuReadbackBuffer {
145 buffer: buffer.clone(),
146 taken: true,
147 frames_unused: 0,
148 });
149 buffer
150 }
151
152 fn return_buffer(&mut self, buffer: &Buffer) {
154 let size = buffer.size();
155 let buffers = self
156 .buffers
157 .get_mut(&size)
158 .expect("Returned buffer of untracked size");
159 if let Some(buf) = buffers.iter_mut().find(|x| x.buffer.id() == buffer.id()) {
160 buf.taken = false;
161 } else {
162 warn!("Returned buffer that was not allocated");
163 }
164 }
165
166 fn update(&mut self, max_unused_frames: usize) {
167 for (_, buffers) in &mut self.buffers {
168 for buf in &mut *buffers {
170 if !buf.taken {
171 buf.frames_unused += 1;
172 }
173 }
174
175 buffers.retain(|x| x.frames_unused < max_unused_frames);
177 }
178
179 self.buffers.retain(|_, buffers| !buffers.is_empty());
181 }
182}
183
/// The GPU resource a readback copies from, together with the copy parameters.
enum ReadbackSource {
    /// Copy from a texture.
    Texture {
        /// Texture to copy from.
        texture: Texture,
        /// Data layout used for the destination buffer of the copy.
        layout: ImageDataLayout,
        /// Extent of the region to copy.
        size: Extent3d,
    },
    /// Copy from a buffer.
    Buffer {
        /// Byte offset in the source buffer at which the copy starts.
        src_start: u64,
        /// Byte offset in the destination buffer at which the copy starts.
        dst_start: u64,
        /// Buffer to copy from.
        buffer: Buffer,
    },
}
196
/// Render-world resource tracking readbacks through their two stages.
#[derive(Resource, Default)]
struct GpuReadbacks {
    /// Readbacks queued this frame whose copy commands still have to be
    /// recorded and whose staging buffers still have to be mapped.
    requested: Vec<GpuReadback>,
    /// Readbacks whose staging buffer mapping has been requested and whose
    /// result has not been received yet.
    mapped: Vec<GpuReadback>,
}
202
/// A single in-flight readback.
struct GpuReadback {
    /// Main-world entity the result is reported to.
    pub entity: Entity,
    /// GPU resource to copy from.
    pub src: ReadbackSource,
    /// Staging buffer the data is copied into and later mapped for reading.
    pub buffer: Buffer,
    /// Receiving end polled for the finished result.
    pub rx: Receiver<(Entity, Buffer, Vec<u8>)>,
    /// Sending end used by the map callback to hand over the result.
    pub tx: Sender<(Entity, Buffer, Vec<u8>)>,
}
210
211fn sync_readbacks(
212 mut main_world: ResMut<MainWorld>,
213 mut buffer_pool: ResMut<GpuReadbackBufferPool>,
214 mut readbacks: ResMut<GpuReadbacks>,
215 max_unused_frames: Res<GpuReadbackMaxUnusedFrames>,
216) {
217 readbacks.mapped.retain(|readback| {
218 if let Ok((entity, buffer, result)) = readback.rx.try_recv() {
219 main_world.trigger_targets(ReadbackComplete(result), entity);
220 buffer_pool.return_buffer(&buffer);
221 false
222 } else {
223 true
224 }
225 });
226
227 buffer_pool.update(max_unused_frames.0);
228}
229
230fn prepare_buffers(
231 render_device: Res<RenderDevice>,
232 mut readbacks: ResMut<GpuReadbacks>,
233 mut buffer_pool: ResMut<GpuReadbackBufferPool>,
234 gpu_images: Res<RenderAssets<GpuImage>>,
235 ssbos: Res<RenderAssets<GpuShaderStorageBuffer>>,
236 handles: Query<(&MainEntity, &Readback)>,
237) {
238 for (entity, readback) in handles.iter() {
239 match readback {
240 Readback::Texture(image) => {
241 if let Some(gpu_image) = gpu_images.get(image) {
242 let size = Extent3d {
243 width: gpu_image.size.x,
244 height: gpu_image.size.y,
245 ..default()
246 };
247 let layout = layout_data(size.width, size.height, gpu_image.texture_format);
248 let buffer = buffer_pool.get(
249 &render_device,
250 get_aligned_size(
251 size.width,
252 size.height,
253 gpu_image.texture_format.pixel_size() as u32,
254 ) as u64,
255 );
256 let (tx, rx) = async_channel::bounded(1);
257 readbacks.requested.push(GpuReadback {
258 entity: entity.id(),
259 src: ReadbackSource::Texture {
260 texture: gpu_image.texture.clone(),
261 layout,
262 size,
263 },
264 buffer,
265 rx,
266 tx,
267 });
268 }
269 }
270 Readback::Buffer(buffer) => {
271 if let Some(ssbo) = ssbos.get(buffer) {
272 let size = ssbo.buffer.size();
273 let buffer = buffer_pool.get(&render_device, size);
274 let (tx, rx) = async_channel::bounded(1);
275 readbacks.requested.push(GpuReadback {
276 entity: entity.id(),
277 src: ReadbackSource::Buffer {
278 src_start: 0,
279 dst_start: 0,
280 buffer: ssbo.buffer.clone(),
281 },
282 buffer,
283 rx,
284 tx,
285 });
286 }
287 }
288 }
289 }
290}
291
292pub(crate) fn submit_readback_commands(world: &World, command_encoder: &mut CommandEncoder) {
293 let readbacks = world.resource::<GpuReadbacks>();
294 for readback in &readbacks.requested {
295 match &readback.src {
296 ReadbackSource::Texture {
297 texture,
298 layout,
299 size,
300 } => {
301 command_encoder.copy_texture_to_buffer(
302 texture.as_image_copy(),
303 wgpu::ImageCopyBuffer {
304 buffer: &readback.buffer,
305 layout: *layout,
306 },
307 *size,
308 );
309 }
310 ReadbackSource::Buffer {
311 src_start,
312 dst_start,
313 buffer,
314 } => {
315 command_encoder.copy_buffer_to_buffer(
316 buffer,
317 *src_start,
318 &readback.buffer,
319 *dst_start,
320 buffer.size(),
321 );
322 }
323 }
324 }
325}
326
327fn map_buffers(mut readbacks: ResMut<GpuReadbacks>) {
329 let requested = readbacks.requested.drain(..).collect::<Vec<GpuReadback>>();
330 for readback in requested {
331 let slice = readback.buffer.slice(..);
332 let entity = readback.entity;
333 let buffer = readback.buffer.clone();
334 let tx = readback.tx.clone();
335 slice.map_async(wgpu::MapMode::Read, move |res| {
336 res.expect("Failed to map buffer");
337 let buffer_slice = buffer.slice(..);
338 let data = buffer_slice.get_mapped_range();
339 let result = Vec::from(&*data);
340 drop(data);
341 buffer.unmap();
342 if let Err(e) = tx.try_send((entity, buffer, result)) {
343 warn!("Failed to send readback result: {:?}", e);
344 }
345 });
346 readbacks.mapped.push(readback);
347 }
348}
349
350pub(crate) const fn align_byte_size(value: u32) -> u32 {
354 RenderDevice::align_copy_bytes_per_row(value as usize) as u32
355}
356
357pub(crate) const fn get_aligned_size(width: u32, height: u32, pixel_size: u32) -> u32 {
359 height * align_byte_size(width * pixel_size)
360}
361
362pub(crate) fn layout_data(width: u32, height: u32, format: TextureFormat) -> ImageDataLayout {
364 ImageDataLayout {
365 bytes_per_row: if height > 1 {
366 Some(get_aligned_size(width, 1, format.pixel_size() as u32))
368 } else {
369 None
370 },
371 rows_per_image: None,
372 ..Default::default()
373 }
374}