bevy_core_pipeline/tonemapping/
node.rs1use std::sync::Mutex;
2
3use crate::tonemapping::{TonemappingLuts, TonemappingPipeline, ViewTonemappingPipeline};
4
5use bevy_ecs::{prelude::*, query::QueryItem};
6use bevy_render::{
7 render_asset::RenderAssets,
8 render_graph::{NodeRunError, RenderGraphContext, ViewNode},
9 render_resource::{
10 BindGroup, BindGroupEntries, BufferId, LoadOp, Operations, PipelineCache,
11 RenderPassColorAttachment, RenderPassDescriptor, StoreOp, TextureViewId,
12 },
13 renderer::RenderContext,
14 texture::{FallbackImage, GpuImage},
15 view::{ViewTarget, ViewUniformOffset, ViewUniforms},
16};
17
18use super::{get_lut_bindings, Tonemapping};
19
20#[derive(Default)]
21pub struct TonemappingNode {
22 cached_bind_group: Mutex<Option<(BufferId, TextureViewId, TextureViewId, BindGroup)>>,
23 last_tonemapping: Mutex<Option<Tonemapping>>,
24}
25
26impl ViewNode for TonemappingNode {
27 type ViewQuery = (
28 &'static ViewUniformOffset,
29 &'static ViewTarget,
30 &'static ViewTonemappingPipeline,
31 &'static Tonemapping,
32 );
33
34 fn run(
35 &self,
36 _graph: &mut RenderGraphContext,
37 render_context: &mut RenderContext,
38 (view_uniform_offset, target, view_tonemapping_pipeline, tonemapping): QueryItem<
39 Self::ViewQuery,
40 >,
41 world: &World,
42 ) -> Result<(), NodeRunError> {
43 let pipeline_cache = world.resource::<PipelineCache>();
44 let tonemapping_pipeline = world.resource::<TonemappingPipeline>();
45 let gpu_images = world.get_resource::<RenderAssets<GpuImage>>().unwrap();
46 let fallback_image = world.resource::<FallbackImage>();
47 let view_uniforms_resource = world.resource::<ViewUniforms>();
48 let view_uniforms = &view_uniforms_resource.uniforms;
49 let view_uniforms_id = view_uniforms.buffer().unwrap().id();
50
51 if *tonemapping == Tonemapping::None {
52 return Ok(());
53 }
54
55 if !target.is_hdr() {
56 return Ok(());
57 }
58
59 let Some(pipeline) = pipeline_cache.get_render_pipeline(view_tonemapping_pipeline.0) else {
60 return Ok(());
61 };
62
63 let post_process = target.post_process_write();
64 let source = post_process.source;
65 let destination = post_process.destination;
66
67 let mut last_tonemapping = self.last_tonemapping.lock().unwrap();
68
69 let tonemapping_changed = if let Some(last_tonemapping) = &*last_tonemapping {
70 tonemapping != last_tonemapping
71 } else {
72 true
73 };
74 if tonemapping_changed {
75 *last_tonemapping = Some(*tonemapping);
76 }
77
78 let mut cached_bind_group = self.cached_bind_group.lock().unwrap();
79 let bind_group = match &mut *cached_bind_group {
80 Some((buffer_id, texture_id, lut_id, bind_group))
81 if view_uniforms_id == *buffer_id
82 && source.id() == *texture_id
83 && *lut_id != fallback_image.d3.texture_view.id()
84 && !tonemapping_changed =>
85 {
86 bind_group
87 }
88 cached_bind_group => {
89 let tonemapping_luts = world.resource::<TonemappingLuts>();
90
91 let lut_bindings =
92 get_lut_bindings(gpu_images, tonemapping_luts, tonemapping, fallback_image);
93
94 let bind_group = render_context.render_device().create_bind_group(
95 None,
96 &tonemapping_pipeline.texture_bind_group,
97 &BindGroupEntries::sequential((
98 view_uniforms,
99 source,
100 &tonemapping_pipeline.sampler,
101 lut_bindings.0,
102 lut_bindings.1,
103 )),
104 );
105
106 let (_, _, _, bind_group) = cached_bind_group.insert((
107 view_uniforms_id,
108 source.id(),
109 lut_bindings.0.id(),
110 bind_group,
111 ));
112 bind_group
113 }
114 };
115
116 let pass_descriptor = RenderPassDescriptor {
117 label: Some("tonemapping_pass"),
118 color_attachments: &[Some(RenderPassColorAttachment {
119 view: destination,
120 resolve_target: None,
121 ops: Operations {
122 load: LoadOp::Clear(Default::default()), store: StoreOp::Store,
124 },
125 })],
126 depth_stencil_attachment: None,
127 timestamp_writes: None,
128 occlusion_query_set: None,
129 };
130
131 let mut render_pass = render_context
132 .command_encoder()
133 .begin_render_pass(&pass_descriptor);
134
135 render_pass.set_pipeline(pipeline);
136 render_pass.set_bind_group(0, bind_group, &[view_uniform_offset.offset]);
137 render_pass.draw(0..3, 0..1);
138
139 Ok(())
140 }
141}