bevy_core_pipeline/tonemapping/node.rs

1use std::sync::Mutex;
2
3use crate::tonemapping::{TonemappingLuts, TonemappingPipeline, ViewTonemappingPipeline};
4
5use bevy_ecs::{prelude::*, query::QueryItem};
6use bevy_render::{
7    render_asset::RenderAssets,
8    render_graph::{NodeRunError, RenderGraphContext, ViewNode},
9    render_resource::{
10        BindGroup, BindGroupEntries, BufferId, LoadOp, Operations, PipelineCache,
11        RenderPassColorAttachment, RenderPassDescriptor, StoreOp, TextureViewId,
12    },
13    renderer::RenderContext,
14    texture::{FallbackImage, GpuImage},
15    view::{ViewTarget, ViewUniformOffset, ViewUniforms},
16};
17
18use super::{get_lut_bindings, Tonemapping};
19
20#[derive(Default)]
21pub struct TonemappingNode {
22    cached_bind_group: Mutex<Option<(BufferId, TextureViewId, TextureViewId, BindGroup)>>,
23    last_tonemapping: Mutex<Option<Tonemapping>>,
24}
25
26impl ViewNode for TonemappingNode {
27    type ViewQuery = (
28        &'static ViewUniformOffset,
29        &'static ViewTarget,
30        &'static ViewTonemappingPipeline,
31        &'static Tonemapping,
32    );
33
34    fn run(
35        &self,
36        _graph: &mut RenderGraphContext,
37        render_context: &mut RenderContext,
38        (view_uniform_offset, target, view_tonemapping_pipeline, tonemapping): QueryItem<
39            Self::ViewQuery,
40        >,
41        world: &World,
42    ) -> Result<(), NodeRunError> {
43        let pipeline_cache = world.resource::<PipelineCache>();
44        let tonemapping_pipeline = world.resource::<TonemappingPipeline>();
45        let gpu_images = world.get_resource::<RenderAssets<GpuImage>>().unwrap();
46        let fallback_image = world.resource::<FallbackImage>();
47        let view_uniforms_resource = world.resource::<ViewUniforms>();
48        let view_uniforms = &view_uniforms_resource.uniforms;
49        let view_uniforms_id = view_uniforms.buffer().unwrap().id();
50
51        if *tonemapping == Tonemapping::None {
52            return Ok(());
53        }
54
55        if !target.is_hdr() {
56            return Ok(());
57        }
58
59        let Some(pipeline) = pipeline_cache.get_render_pipeline(view_tonemapping_pipeline.0) else {
60            return Ok(());
61        };
62
63        let post_process = target.post_process_write();
64        let source = post_process.source;
65        let destination = post_process.destination;
66
67        let mut last_tonemapping = self.last_tonemapping.lock().unwrap();
68
69        let tonemapping_changed = if let Some(last_tonemapping) = &*last_tonemapping {
70            tonemapping != last_tonemapping
71        } else {
72            true
73        };
74        if tonemapping_changed {
75            *last_tonemapping = Some(*tonemapping);
76        }
77
78        let mut cached_bind_group = self.cached_bind_group.lock().unwrap();
79        let bind_group = match &mut *cached_bind_group {
80            Some((buffer_id, texture_id, lut_id, bind_group))
81                if view_uniforms_id == *buffer_id
82                    && source.id() == *texture_id
83                    && *lut_id != fallback_image.d3.texture_view.id()
84                    && !tonemapping_changed =>
85            {
86                bind_group
87            }
88            cached_bind_group => {
89                let tonemapping_luts = world.resource::<TonemappingLuts>();
90
91                let lut_bindings =
92                    get_lut_bindings(gpu_images, tonemapping_luts, tonemapping, fallback_image);
93
94                let bind_group = render_context.render_device().create_bind_group(
95                    None,
96                    &tonemapping_pipeline.texture_bind_group,
97                    &BindGroupEntries::sequential((
98                        view_uniforms,
99                        source,
100                        &tonemapping_pipeline.sampler,
101                        lut_bindings.0,
102                        lut_bindings.1,
103                    )),
104                );
105
106                let (_, _, _, bind_group) = cached_bind_group.insert((
107                    view_uniforms_id,
108                    source.id(),
109                    lut_bindings.0.id(),
110                    bind_group,
111                ));
112                bind_group
113            }
114        };
115
116        let pass_descriptor = RenderPassDescriptor {
117            label: Some("tonemapping_pass"),
118            color_attachments: &[Some(RenderPassColorAttachment {
119                view: destination,
120                resolve_target: None,
121                ops: Operations {
122                    load: LoadOp::Clear(Default::default()), // TODO shouldn't need to be cleared
123                    store: StoreOp::Store,
124                },
125            })],
126            depth_stencil_attachment: None,
127            timestamp_writes: None,
128            occlusion_query_set: None,
129        };
130
131        let mut render_pass = render_context
132            .command_encoder()
133            .begin_render_pass(&pass_descriptor);
134
135        render_pass.set_pipeline(pipeline);
136        render_pass.set_bind_group(0, bind_group, &[view_uniform_offset.offset]);
137        render_pass.draw(0..3, 0..1);
138
139        Ok(())
140    }
141}