// bevy_core_pipeline/tonemapping/node.rs

1use std::sync::Mutex;
2
3use crate::tonemapping::{TonemappingLuts, TonemappingPipeline, ViewTonemappingPipeline};
4
5use bevy_ecs::{prelude::*, query::QueryItem};
6use bevy_render::{
7    diagnostic::RecordDiagnostics,
8    render_asset::RenderAssets,
9    render_graph::{NodeRunError, RenderGraphContext, ViewNode},
10    render_resource::{
11        BindGroup, BindGroupEntries, BufferId, LoadOp, Operations, PipelineCache,
12        RenderPassColorAttachment, RenderPassDescriptor, StoreOp, TextureViewId,
13    },
14    renderer::RenderContext,
15    texture::{FallbackImage, GpuImage},
16    view::{ViewTarget, ViewUniformOffset, ViewUniforms},
17};
18
19use super::{get_lut_bindings, Tonemapping};
20
21#[derive(Default)]
22pub struct TonemappingNode {
23    cached_bind_group: Mutex<Option<(BufferId, TextureViewId, TextureViewId, BindGroup)>>,
24    last_tonemapping: Mutex<Option<Tonemapping>>,
25}
26
27impl ViewNode for TonemappingNode {
28    type ViewQuery = (
29        &'static ViewUniformOffset,
30        &'static ViewTarget,
31        &'static ViewTonemappingPipeline,
32        &'static Tonemapping,
33    );
34
35    fn run(
36        &self,
37        _graph: &mut RenderGraphContext,
38        render_context: &mut RenderContext,
39        (view_uniform_offset, target, view_tonemapping_pipeline, tonemapping): QueryItem<
40            Self::ViewQuery,
41        >,
42        world: &World,
43    ) -> Result<(), NodeRunError> {
44        let pipeline_cache = world.resource::<PipelineCache>();
45        let tonemapping_pipeline = world.resource::<TonemappingPipeline>();
46        let gpu_images = world.get_resource::<RenderAssets<GpuImage>>().unwrap();
47        let fallback_image = world.resource::<FallbackImage>();
48        let view_uniforms_resource = world.resource::<ViewUniforms>();
49        let view_uniforms = &view_uniforms_resource.uniforms;
50        let view_uniforms_id = view_uniforms.buffer().unwrap().id();
51
52        if *tonemapping == Tonemapping::None {
53            return Ok(());
54        }
55
56        if !target.is_hdr() {
57            return Ok(());
58        }
59
60        let Some(pipeline) = pipeline_cache.get_render_pipeline(view_tonemapping_pipeline.0) else {
61            return Ok(());
62        };
63
64        let diagnostics = render_context.diagnostic_recorder();
65
66        let post_process = target.post_process_write();
67        let source = post_process.source;
68        let destination = post_process.destination;
69
70        let mut last_tonemapping = self.last_tonemapping.lock().unwrap();
71
72        let tonemapping_changed = if let Some(last_tonemapping) = &*last_tonemapping {
73            tonemapping != last_tonemapping
74        } else {
75            true
76        };
77        if tonemapping_changed {
78            *last_tonemapping = Some(*tonemapping);
79        }
80
81        let mut cached_bind_group = self.cached_bind_group.lock().unwrap();
82        let bind_group = match &mut *cached_bind_group {
83            Some((buffer_id, texture_id, lut_id, bind_group))
84                if view_uniforms_id == *buffer_id
85                    && source.id() == *texture_id
86                    && *lut_id != fallback_image.d3.texture_view.id()
87                    && !tonemapping_changed =>
88            {
89                bind_group
90            }
91            cached_bind_group => {
92                let tonemapping_luts = world.resource::<TonemappingLuts>();
93
94                let lut_bindings =
95                    get_lut_bindings(gpu_images, tonemapping_luts, tonemapping, fallback_image);
96
97                let bind_group = render_context.render_device().create_bind_group(
98                    None,
99                    &tonemapping_pipeline.texture_bind_group,
100                    &BindGroupEntries::sequential((
101                        view_uniforms,
102                        source,
103                        &tonemapping_pipeline.sampler,
104                        lut_bindings.0,
105                        lut_bindings.1,
106                    )),
107                );
108
109                let (_, _, _, bind_group) = cached_bind_group.insert((
110                    view_uniforms_id,
111                    source.id(),
112                    lut_bindings.0.id(),
113                    bind_group,
114                ));
115                bind_group
116            }
117        };
118
119        let pass_descriptor = RenderPassDescriptor {
120            label: Some("tonemapping"),
121            color_attachments: &[Some(RenderPassColorAttachment {
122                view: destination,
123                depth_slice: None,
124                resolve_target: None,
125                ops: Operations {
126                    load: LoadOp::Clear(Default::default()), // TODO shouldn't need to be cleared
127                    store: StoreOp::Store,
128                },
129            })],
130            depth_stencil_attachment: None,
131            timestamp_writes: None,
132            occlusion_query_set: None,
133        };
134
135        let mut render_pass = render_context
136            .command_encoder()
137            .begin_render_pass(&pass_descriptor);
138        let pass_span = diagnostics.pass_span(&mut render_pass, "tonemapping");
139
140        render_pass.set_pipeline(pipeline);
141        render_pass.set_bind_group(0, bind_group, &[view_uniform_offset.offset]);
142        render_pass.draw(0..3, 0..1);
143
144        pass_span.end(&mut render_pass);
145
146        Ok(())
147    }
148}