bevy_core_pipeline/contrast_adaptive_sharpening/node.rs

1use std::sync::Mutex;
2
3use crate::contrast_adaptive_sharpening::ViewCasPipeline;
4use bevy_ecs::prelude::*;
5use bevy_render::{
6    extract_component::{ComponentUniforms, DynamicUniformIndex},
7    render_graph::{Node, NodeRunError, RenderGraphContext},
8    render_resource::{
9        BindGroup, BindGroupEntries, BufferId, Operations, PipelineCache,
10        RenderPassColorAttachment, RenderPassDescriptor, TextureViewId,
11    },
12    renderer::RenderContext,
13    view::{ExtractedView, ViewTarget},
14};
15
16use super::{CasPipeline, CasUniform};
17
18pub struct CasNode {
19    query: QueryState<
20        (
21            &'static ViewTarget,
22            &'static ViewCasPipeline,
23            &'static DynamicUniformIndex<CasUniform>,
24        ),
25        With<ExtractedView>,
26    >,
27    cached_bind_group: Mutex<Option<(BufferId, TextureViewId, BindGroup)>>,
28}
29
30impl FromWorld for CasNode {
31    fn from_world(world: &mut World) -> Self {
32        Self {
33            query: QueryState::new(world),
34            cached_bind_group: Mutex::new(None),
35        }
36    }
37}
38
39impl Node for CasNode {
40    fn update(&mut self, world: &mut World) {
41        self.query.update_archetypes(world);
42    }
43
44    fn run(
45        &self,
46        graph: &mut RenderGraphContext,
47        render_context: &mut RenderContext,
48        world: &World,
49    ) -> Result<(), NodeRunError> {
50        let view_entity = graph.view_entity();
51        let pipeline_cache = world.resource::<PipelineCache>();
52        let sharpening_pipeline = world.resource::<CasPipeline>();
53        let uniforms = world.resource::<ComponentUniforms<CasUniform>>();
54
55        let Ok((target, pipeline, uniform_index)) = self.query.get_manual(world, view_entity)
56        else {
57            return Ok(());
58        };
59
60        let uniforms_id = uniforms.buffer().unwrap().id();
61        let Some(uniforms) = uniforms.binding() else {
62            return Ok(());
63        };
64
65        let Some(pipeline) = pipeline_cache.get_render_pipeline(pipeline.0) else {
66            return Ok(());
67        };
68
69        let view_target = target.post_process_write();
70        let source = view_target.source;
71        let destination = view_target.destination;
72
73        let mut cached_bind_group = self.cached_bind_group.lock().unwrap();
74        let bind_group = match &mut *cached_bind_group {
75            Some((buffer_id, texture_id, bind_group))
76                if source.id() == *texture_id && uniforms_id == *buffer_id =>
77            {
78                bind_group
79            }
80            cached_bind_group => {
81                let bind_group = render_context.render_device().create_bind_group(
82                    "cas_bind_group",
83                    &sharpening_pipeline.texture_bind_group,
84                    &BindGroupEntries::sequential((
85                        view_target.source,
86                        &sharpening_pipeline.sampler,
87                        uniforms,
88                    )),
89                );
90
91                let (_, _, bind_group) =
92                    cached_bind_group.insert((uniforms_id, source.id(), bind_group));
93                bind_group
94            }
95        };
96
97        let pass_descriptor = RenderPassDescriptor {
98            label: Some("contrast_adaptive_sharpening"),
99            color_attachments: &[Some(RenderPassColorAttachment {
100                view: destination,
101                resolve_target: None,
102                ops: Operations::default(),
103            })],
104            depth_stencil_attachment: None,
105            timestamp_writes: None,
106            occlusion_query_set: None,
107        };
108
109        let mut render_pass = render_context
110            .command_encoder()
111            .begin_render_pass(&pass_descriptor);
112
113        render_pass.set_pipeline(pipeline);
114        render_pass.set_bind_group(0, bind_group, &[uniform_index.index()]);
115        render_pass.draw(0..3, 0..1);
116
117        Ok(())
118    }
119}