bevy_post_process/bloom/
mod.rs

1mod downsampling_pipeline;
2mod settings;
3mod upsampling_pipeline;
4
5use bevy_image::ToExtents;
6pub use settings::{Bloom, BloomCompositeMode, BloomPrefilter};
7
8use crate::bloom::{
9    downsampling_pipeline::init_bloom_downsampling_pipeline,
10    upsampling_pipeline::init_bloom_upscaling_pipeline,
11};
12use bevy_app::{App, Plugin};
13use bevy_asset::embedded_asset;
14use bevy_color::{Gray, LinearRgba};
15use bevy_core_pipeline::{
16    core_2d::graph::{Core2d, Node2d},
17    core_3d::graph::{Core3d, Node3d},
18};
19use bevy_ecs::{prelude::*, query::QueryItem};
20use bevy_math::{ops, UVec2};
21use bevy_render::{
22    camera::ExtractedCamera,
23    diagnostic::RecordDiagnostics,
24    extract_component::{
25        ComponentUniforms, DynamicUniformIndex, ExtractComponentPlugin, UniformComponentPlugin,
26    },
27    render_graph::{NodeRunError, RenderGraphContext, RenderGraphExt, ViewNode, ViewNodeRunner},
28    render_resource::*,
29    renderer::{RenderContext, RenderDevice},
30    texture::{CachedTexture, TextureCache},
31    view::ViewTarget,
32    Render, RenderApp, RenderStartup, RenderSystems,
33};
34use downsampling_pipeline::{
35    prepare_downsampling_pipeline, BloomDownsamplingPipeline, BloomDownsamplingPipelineIds,
36    BloomUniforms,
37};
38#[cfg(feature = "trace")]
39use tracing::info_span;
40use upsampling_pipeline::{
41    prepare_upsampling_pipeline, BloomUpsamplingPipeline, UpsamplingPipelineIds,
42};
43
44const BLOOM_TEXTURE_FORMAT: TextureFormat = TextureFormat::Rg11b10Ufloat;
45
46#[derive(Default)]
47pub struct BloomPlugin;
48
49impl Plugin for BloomPlugin {
50    fn build(&self, app: &mut App) {
51        embedded_asset!(app, "bloom.wgsl");
52
53        app.add_plugins((
54            ExtractComponentPlugin::<Bloom>::default(),
55            UniformComponentPlugin::<BloomUniforms>::default(),
56        ));
57
58        let Some(render_app) = app.get_sub_app_mut(RenderApp) else {
59            return;
60        };
61        render_app
62            .init_resource::<SpecializedRenderPipelines<BloomDownsamplingPipeline>>()
63            .init_resource::<SpecializedRenderPipelines<BloomUpsamplingPipeline>>()
64            .add_systems(
65                RenderStartup,
66                (
67                    init_bloom_downsampling_pipeline,
68                    init_bloom_upscaling_pipeline,
69                ),
70            )
71            .add_systems(
72                Render,
73                (
74                    prepare_downsampling_pipeline.in_set(RenderSystems::Prepare),
75                    prepare_upsampling_pipeline.in_set(RenderSystems::Prepare),
76                    prepare_bloom_textures.in_set(RenderSystems::PrepareResources),
77                    prepare_bloom_bind_groups.in_set(RenderSystems::PrepareBindGroups),
78                ),
79            )
80            // Add bloom to the 3d render graph
81            .add_render_graph_node::<ViewNodeRunner<BloomNode>>(Core3d, Node3d::Bloom)
82            .add_render_graph_edges(
83                Core3d,
84                (
85                    Node3d::StartMainPassPostProcessing,
86                    Node3d::Bloom,
87                    Node3d::Tonemapping,
88                ),
89            )
90            // Add bloom to the 2d render graph
91            .add_render_graph_node::<ViewNodeRunner<BloomNode>>(Core2d, Node2d::Bloom)
92            .add_render_graph_edges(
93                Core2d,
94                (
95                    Node2d::StartMainPassPostProcessing,
96                    Node2d::Bloom,
97                    Node2d::Tonemapping,
98                ),
99            );
100    }
101}
102
103#[derive(Default)]
104struct BloomNode;
105impl ViewNode for BloomNode {
106    type ViewQuery = (
107        &'static ExtractedCamera,
108        &'static ViewTarget,
109        &'static BloomTexture,
110        &'static BloomBindGroups,
111        &'static DynamicUniformIndex<BloomUniforms>,
112        &'static Bloom,
113        &'static UpsamplingPipelineIds,
114        &'static BloomDownsamplingPipelineIds,
115    );
116
117    // Atypically for a post-processing effect, we do not need to
118    // use a secondary texture normally provided by view_target.post_process_write(),
119    // instead we write into our own bloom texture and then directly back onto main.
120    fn run<'w>(
121        &self,
122        _graph: &mut RenderGraphContext,
123        render_context: &mut RenderContext<'w>,
124        (
125            camera,
126            view_target,
127            bloom_texture,
128            bind_groups,
129            uniform_index,
130            bloom_settings,
131            upsampling_pipeline_ids,
132            downsampling_pipeline_ids,
133        ): QueryItem<'w, '_, Self::ViewQuery>,
134        world: &'w World,
135    ) -> Result<(), NodeRunError> {
136        if bloom_settings.intensity == 0.0 {
137            return Ok(());
138        }
139
140        let downsampling_pipeline_res = world.resource::<BloomDownsamplingPipeline>();
141        let pipeline_cache = world.resource::<PipelineCache>();
142        let uniforms = world.resource::<ComponentUniforms<BloomUniforms>>();
143
144        let (
145            Some(uniforms),
146            Some(downsampling_first_pipeline),
147            Some(downsampling_pipeline),
148            Some(upsampling_pipeline),
149            Some(upsampling_final_pipeline),
150        ) = (
151            uniforms.binding(),
152            pipeline_cache.get_render_pipeline(downsampling_pipeline_ids.first),
153            pipeline_cache.get_render_pipeline(downsampling_pipeline_ids.main),
154            pipeline_cache.get_render_pipeline(upsampling_pipeline_ids.id_main),
155            pipeline_cache.get_render_pipeline(upsampling_pipeline_ids.id_final),
156        )
157        else {
158            return Ok(());
159        };
160
161        let view_texture = view_target.main_texture_view();
162        let view_texture_unsampled = view_target.get_unsampled_color_attachment();
163        let diagnostics = render_context.diagnostic_recorder();
164
165        render_context.add_command_buffer_generation_task(move |render_device| {
166            #[cfg(feature = "trace")]
167            let _bloom_span = info_span!("bloom").entered();
168
169            let mut command_encoder =
170                render_device.create_command_encoder(&CommandEncoderDescriptor {
171                    label: Some("bloom_command_encoder"),
172                });
173            command_encoder.push_debug_group("bloom");
174            let time_span = diagnostics.time_span(&mut command_encoder, "bloom");
175
176            // First downsample pass
177            {
178                let downsampling_first_bind_group = render_device.create_bind_group(
179                    "bloom_downsampling_first_bind_group",
180                    &downsampling_pipeline_res.bind_group_layout,
181                    &BindGroupEntries::sequential((
182                        // Read from main texture directly
183                        view_texture,
184                        &bind_groups.sampler,
185                        uniforms.clone(),
186                    )),
187                );
188
189                let view = &bloom_texture.view(0);
190                let mut downsampling_first_pass =
191                    command_encoder.begin_render_pass(&RenderPassDescriptor {
192                        label: Some("bloom_downsampling_first_pass"),
193                        color_attachments: &[Some(RenderPassColorAttachment {
194                            view,
195                            depth_slice: None,
196                            resolve_target: None,
197                            ops: Operations::default(),
198                        })],
199                        depth_stencil_attachment: None,
200                        timestamp_writes: None,
201                        occlusion_query_set: None,
202                    });
203                downsampling_first_pass.set_pipeline(downsampling_first_pipeline);
204                downsampling_first_pass.set_bind_group(
205                    0,
206                    &downsampling_first_bind_group,
207                    &[uniform_index.index()],
208                );
209                downsampling_first_pass.draw(0..3, 0..1);
210            }
211
212            // Other downsample passes
213            for mip in 1..bloom_texture.mip_count {
214                let view = &bloom_texture.view(mip);
215                let mut downsampling_pass =
216                    command_encoder.begin_render_pass(&RenderPassDescriptor {
217                        label: Some("bloom_downsampling_pass"),
218                        color_attachments: &[Some(RenderPassColorAttachment {
219                            view,
220                            depth_slice: None,
221                            resolve_target: None,
222                            ops: Operations::default(),
223                        })],
224                        depth_stencil_attachment: None,
225                        timestamp_writes: None,
226                        occlusion_query_set: None,
227                    });
228                downsampling_pass.set_pipeline(downsampling_pipeline);
229                downsampling_pass.set_bind_group(
230                    0,
231                    &bind_groups.downsampling_bind_groups[mip as usize - 1],
232                    &[uniform_index.index()],
233                );
234                downsampling_pass.draw(0..3, 0..1);
235            }
236
237            // Upsample passes except the final one
238            for mip in (1..bloom_texture.mip_count).rev() {
239                let view = &bloom_texture.view(mip - 1);
240                let mut upsampling_pass =
241                    command_encoder.begin_render_pass(&RenderPassDescriptor {
242                        label: Some("bloom_upsampling_pass"),
243                        color_attachments: &[Some(RenderPassColorAttachment {
244                            view,
245                            depth_slice: None,
246                            resolve_target: None,
247                            ops: Operations {
248                                load: LoadOp::Load,
249                                store: StoreOp::Store,
250                            },
251                        })],
252                        depth_stencil_attachment: None,
253                        timestamp_writes: None,
254                        occlusion_query_set: None,
255                    });
256                upsampling_pass.set_pipeline(upsampling_pipeline);
257                upsampling_pass.set_bind_group(
258                    0,
259                    &bind_groups.upsampling_bind_groups
260                        [(bloom_texture.mip_count - mip - 1) as usize],
261                    &[uniform_index.index()],
262                );
263                let blend = compute_blend_factor(
264                    bloom_settings,
265                    mip as f32,
266                    (bloom_texture.mip_count - 1) as f32,
267                );
268                upsampling_pass.set_blend_constant(LinearRgba::gray(blend).into());
269                upsampling_pass.draw(0..3, 0..1);
270            }
271
272            // Final upsample pass
273            // This is very similar to the above upsampling passes with the only difference
274            // being the pipeline (which itself is barely different) and the color attachment
275            {
276                let mut upsampling_final_pass =
277                    command_encoder.begin_render_pass(&RenderPassDescriptor {
278                        label: Some("bloom_upsampling_final_pass"),
279                        color_attachments: &[Some(view_texture_unsampled)],
280                        depth_stencil_attachment: None,
281                        timestamp_writes: None,
282                        occlusion_query_set: None,
283                    });
284                upsampling_final_pass.set_pipeline(upsampling_final_pipeline);
285                upsampling_final_pass.set_bind_group(
286                    0,
287                    &bind_groups.upsampling_bind_groups[(bloom_texture.mip_count - 1) as usize],
288                    &[uniform_index.index()],
289                );
290                if let Some(viewport) = camera.viewport.as_ref() {
291                    upsampling_final_pass.set_viewport(
292                        viewport.physical_position.x as f32,
293                        viewport.physical_position.y as f32,
294                        viewport.physical_size.x as f32,
295                        viewport.physical_size.y as f32,
296                        viewport.depth.start,
297                        viewport.depth.end,
298                    );
299                }
300                let blend =
301                    compute_blend_factor(bloom_settings, 0.0, (bloom_texture.mip_count - 1) as f32);
302                upsampling_final_pass.set_blend_constant(LinearRgba::gray(blend).into());
303                upsampling_final_pass.draw(0..3, 0..1);
304            }
305
306            time_span.end(&mut command_encoder);
307            command_encoder.pop_debug_group();
308            command_encoder.finish()
309        });
310
311        Ok(())
312    }
313}
314
315#[derive(Component)]
316struct BloomTexture {
317    // First mip is half the screen resolution, successive mips are half the previous
318    #[cfg(any(
319        not(feature = "webgl"),
320        not(target_arch = "wasm32"),
321        feature = "webgpu"
322    ))]
323    texture: CachedTexture,
324    // WebGL does not support binding specific mip levels for sampling, fallback to separate textures instead
325    #[cfg(all(feature = "webgl", target_arch = "wasm32", not(feature = "webgpu")))]
326    texture: Vec<CachedTexture>,
327    mip_count: u32,
328}
329
330impl BloomTexture {
331    #[cfg(any(
332        not(feature = "webgl"),
333        not(target_arch = "wasm32"),
334        feature = "webgpu"
335    ))]
336    fn view(&self, base_mip_level: u32) -> TextureView {
337        self.texture.texture.create_view(&TextureViewDescriptor {
338            base_mip_level,
339            mip_level_count: Some(1u32),
340            ..Default::default()
341        })
342    }
343    #[cfg(all(feature = "webgl", target_arch = "wasm32", not(feature = "webgpu")))]
344    fn view(&self, base_mip_level: u32) -> TextureView {
345        self.texture[base_mip_level as usize]
346            .texture
347            .create_view(&TextureViewDescriptor {
348                base_mip_level: 0,
349                mip_level_count: Some(1u32),
350                ..Default::default()
351            })
352    }
353}
354
355fn prepare_bloom_textures(
356    mut commands: Commands,
357    mut texture_cache: ResMut<TextureCache>,
358    render_device: Res<RenderDevice>,
359    views: Query<(Entity, &ExtractedCamera, &Bloom)>,
360) {
361    for (entity, camera, bloom) in &views {
362        if let Some(viewport) = camera.physical_viewport_size {
363            // How many times we can halve the resolution minus one so we don't go unnecessarily low
364            let mip_count = bloom.max_mip_dimension.ilog2().max(2) - 1;
365            let mip_height_ratio = if viewport.y != 0 {
366                bloom.max_mip_dimension as f32 / viewport.y as f32
367            } else {
368                0.
369            };
370
371            let texture_descriptor = TextureDescriptor {
372                label: Some("bloom_texture"),
373                size: (viewport.as_vec2() * mip_height_ratio)
374                    .round()
375                    .as_uvec2()
376                    .max(UVec2::ONE)
377                    .to_extents(),
378                mip_level_count: mip_count,
379                sample_count: 1,
380                dimension: TextureDimension::D2,
381                format: BLOOM_TEXTURE_FORMAT,
382                usage: TextureUsages::RENDER_ATTACHMENT | TextureUsages::TEXTURE_BINDING,
383                view_formats: &[],
384            };
385
386            #[cfg(any(
387                not(feature = "webgl"),
388                not(target_arch = "wasm32"),
389                feature = "webgpu"
390            ))]
391            let texture = texture_cache.get(&render_device, texture_descriptor);
392            #[cfg(all(feature = "webgl", target_arch = "wasm32", not(feature = "webgpu")))]
393            let texture: Vec<CachedTexture> = (0..mip_count)
394                .map(|mip| {
395                    texture_cache.get(
396                        &render_device,
397                        TextureDescriptor {
398                            size: Extent3d {
399                                width: (texture_descriptor.size.width >> mip).max(1),
400                                height: (texture_descriptor.size.height >> mip).max(1),
401                                depth_or_array_layers: 1,
402                            },
403                            mip_level_count: 1,
404                            ..texture_descriptor.clone()
405                        },
406                    )
407                })
408                .collect();
409
410            commands
411                .entity(entity)
412                .insert(BloomTexture { texture, mip_count });
413        }
414    }
415}
416
417#[derive(Component)]
418struct BloomBindGroups {
419    downsampling_bind_groups: Box<[BindGroup]>,
420    upsampling_bind_groups: Box<[BindGroup]>,
421    sampler: Sampler,
422}
423
424fn prepare_bloom_bind_groups(
425    mut commands: Commands,
426    render_device: Res<RenderDevice>,
427    downsampling_pipeline: Res<BloomDownsamplingPipeline>,
428    upsampling_pipeline: Res<BloomUpsamplingPipeline>,
429    views: Query<(Entity, &BloomTexture)>,
430    uniforms: Res<ComponentUniforms<BloomUniforms>>,
431) {
432    let sampler = &downsampling_pipeline.sampler;
433
434    for (entity, bloom_texture) in &views {
435        let bind_group_count = bloom_texture.mip_count as usize - 1;
436
437        let mut downsampling_bind_groups = Vec::with_capacity(bind_group_count);
438        for mip in 1..bloom_texture.mip_count {
439            downsampling_bind_groups.push(render_device.create_bind_group(
440                "bloom_downsampling_bind_group",
441                &downsampling_pipeline.bind_group_layout,
442                &BindGroupEntries::sequential((
443                    &bloom_texture.view(mip - 1),
444                    sampler,
445                    uniforms.binding().unwrap(),
446                )),
447            ));
448        }
449
450        let mut upsampling_bind_groups = Vec::with_capacity(bind_group_count);
451        for mip in (0..bloom_texture.mip_count).rev() {
452            upsampling_bind_groups.push(render_device.create_bind_group(
453                "bloom_upsampling_bind_group",
454                &upsampling_pipeline.bind_group_layout,
455                &BindGroupEntries::sequential((
456                    &bloom_texture.view(mip),
457                    sampler,
458                    uniforms.binding().unwrap(),
459                )),
460            ));
461        }
462
463        commands.entity(entity).insert(BloomBindGroups {
464            downsampling_bind_groups: downsampling_bind_groups.into_boxed_slice(),
465            upsampling_bind_groups: upsampling_bind_groups.into_boxed_slice(),
466            sampler: sampler.clone(),
467        });
468    }
469}
470
471/// Calculates blend intensities of blur pyramid levels
472/// during the upsampling + compositing stage.
473///
474/// The function assumes all pyramid levels are upsampled and
475/// blended into higher frequency ones using this function to
476/// calculate blend levels every time. The final (highest frequency)
477/// pyramid level in not blended into anything therefore this function
478/// is not applied to it. As a result, the *mip* parameter of 0 indicates
479/// the second-highest frequency pyramid level (in our case that is the
480/// 0th mip of the bloom texture with the original image being the
481/// actual highest frequency level).
482///
483/// Parameters:
484/// * `mip` - the index of the lower frequency pyramid level (0 - `max_mip`, where 0 indicates highest frequency mip but not the highest frequency image).
485/// * `max_mip` - the index of the lowest frequency pyramid level.
486///
487/// This function can be visually previewed for all values of *mip* (normalized) with tweakable
488/// [`Bloom`] parameters on [Desmos graphing calculator](https://www.desmos.com/calculator/ncc8xbhzzl).
489fn compute_blend_factor(bloom: &Bloom, mip: f32, max_mip: f32) -> f32 {
490    let mut lf_boost =
491        (1.0 - ops::powf(
492            1.0 - (mip / max_mip),
493            1.0 / (1.0 - bloom.low_frequency_boost_curvature),
494        )) * bloom.low_frequency_boost;
495    let high_pass_lq = 1.0
496        - (((mip / max_mip) - bloom.high_pass_frequency) / bloom.high_pass_frequency)
497            .clamp(0.0, 1.0);
498    lf_boost *= match bloom.composite_mode {
499        BloomCompositeMode::EnergyConserving => 1.0 - bloom.intensity,
500        BloomCompositeMode::Additive => 1.0,
501    };
502
503    (bloom.intensity + lf_boost) * high_pass_lq
504}