bevy_core_pipeline/bloom/
mod.rs

1mod downsampling_pipeline;
2mod settings;
3mod upsampling_pipeline;
4
5pub use settings::{Bloom, BloomCompositeMode, BloomPrefilter};
6
7use crate::{
8    core_2d::graph::{Core2d, Node2d},
9    core_3d::graph::{Core3d, Node3d},
10};
11use bevy_app::{App, Plugin};
12use bevy_asset::{load_internal_asset, weak_handle, Handle};
13use bevy_color::{Gray, LinearRgba};
14use bevy_ecs::{prelude::*, query::QueryItem};
15use bevy_math::{ops, UVec2};
16use bevy_render::{
17    camera::ExtractedCamera,
18    diagnostic::RecordDiagnostics,
19    extract_component::{
20        ComponentUniforms, DynamicUniformIndex, ExtractComponentPlugin, UniformComponentPlugin,
21    },
22    render_graph::{NodeRunError, RenderGraphApp, RenderGraphContext, ViewNode, ViewNodeRunner},
23    render_resource::*,
24    renderer::{RenderContext, RenderDevice},
25    texture::{CachedTexture, TextureCache},
26    view::ViewTarget,
27    Render, RenderApp, RenderSet,
28};
29use downsampling_pipeline::{
30    prepare_downsampling_pipeline, BloomDownsamplingPipeline, BloomDownsamplingPipelineIds,
31    BloomUniforms,
32};
33#[cfg(feature = "trace")]
34use tracing::info_span;
35use upsampling_pipeline::{
36    prepare_upsampling_pipeline, BloomUpsamplingPipeline, UpsamplingPipelineIds,
37};
38
39const BLOOM_SHADER_HANDLE: Handle<Shader> = weak_handle!("c9190ddc-573b-4472-8b21-573cab502b73");
40
41const BLOOM_TEXTURE_FORMAT: TextureFormat = TextureFormat::Rg11b10Ufloat;
42
43pub struct BloomPlugin;
44
45impl Plugin for BloomPlugin {
46    fn build(&self, app: &mut App) {
47        load_internal_asset!(app, BLOOM_SHADER_HANDLE, "bloom.wgsl", Shader::from_wgsl);
48
49        app.register_type::<Bloom>();
50        app.register_type::<BloomPrefilter>();
51        app.register_type::<BloomCompositeMode>();
52        app.add_plugins((
53            ExtractComponentPlugin::<Bloom>::default(),
54            UniformComponentPlugin::<BloomUniforms>::default(),
55        ));
56
57        let Some(render_app) = app.get_sub_app_mut(RenderApp) else {
58            return;
59        };
60        render_app
61            .init_resource::<SpecializedRenderPipelines<BloomDownsamplingPipeline>>()
62            .init_resource::<SpecializedRenderPipelines<BloomUpsamplingPipeline>>()
63            .add_systems(
64                Render,
65                (
66                    prepare_downsampling_pipeline.in_set(RenderSet::Prepare),
67                    prepare_upsampling_pipeline.in_set(RenderSet::Prepare),
68                    prepare_bloom_textures.in_set(RenderSet::PrepareResources),
69                    prepare_bloom_bind_groups.in_set(RenderSet::PrepareBindGroups),
70                ),
71            )
72            // Add bloom to the 3d render graph
73            .add_render_graph_node::<ViewNodeRunner<BloomNode>>(Core3d, Node3d::Bloom)
74            .add_render_graph_edges(
75                Core3d,
76                (Node3d::EndMainPass, Node3d::Bloom, Node3d::Tonemapping),
77            )
78            // Add bloom to the 2d render graph
79            .add_render_graph_node::<ViewNodeRunner<BloomNode>>(Core2d, Node2d::Bloom)
80            .add_render_graph_edges(
81                Core2d,
82                (Node2d::EndMainPass, Node2d::Bloom, Node2d::Tonemapping),
83            );
84    }
85
86    fn finish(&self, app: &mut App) {
87        let Some(render_app) = app.get_sub_app_mut(RenderApp) else {
88            return;
89        };
90        render_app
91            .init_resource::<BloomDownsamplingPipeline>()
92            .init_resource::<BloomUpsamplingPipeline>();
93    }
94}
95
/// Render-graph view node that records the bloom downsample/upsample passes for one view.
///
/// Added to both core graphs between `EndMainPass` and `Tonemapping` by `BloomPlugin::build`.
#[derive(Default)]
struct BloomNode;
impl ViewNode for BloomNode {
    // The node only runs for view entities that have all of these components.
    // `BloomTexture` and `BloomBindGroups` are inserted in this file by
    // `prepare_bloom_textures` and `prepare_bloom_bind_groups` respectively.
    type ViewQuery = (
        &'static ExtractedCamera,
        &'static ViewTarget,
        &'static BloomTexture,
        &'static BloomBindGroups,
        &'static DynamicUniformIndex<BloomUniforms>,
        &'static Bloom,
        &'static UpsamplingPipelineIds,
        &'static BloomDownsamplingPipelineIds,
    );

    // Atypically for a post-processing effect, we do not need to
    // use a secondary texture normally provided by view_target.post_process_write(),
    // instead we write into our own bloom texture and then directly back onto main.
    //
    // Pass sequence (all recorded into a single command encoder):
    // 1. First downsample: view's main texture -> bloom mip 0.
    // 2. Downsample: bloom mip `n - 1` -> mip `n`, for n in 1..mip_count.
    // 3. Upsample: bloom mip `n` -> mip `n - 1`, for n from mip_count - 1 down to 1, loading
    //    (not clearing) the destination and setting a per-mip blend constant.
    // 4. Final upsample: bloom mip 0 -> the view's unsampled main color attachment.
    //
    // Returns `Ok(())` without recording anything when intensity is zero or when the uniform
    // buffer or any of the four pipelines is not available yet.
    fn run<'w>(
        &self,
        _graph: &mut RenderGraphContext,
        render_context: &mut RenderContext<'w>,
        (
            camera,
            view_target,
            bloom_texture,
            bind_groups,
            uniform_index,
            bloom_settings,
            upsampling_pipeline_ids,
            downsampling_pipeline_ids,
        ): QueryItem<'w, Self::ViewQuery>,
        world: &'w World,
    ) -> Result<(), NodeRunError> {
        // Bloom disabled: skip all passes.
        if bloom_settings.intensity == 0.0 {
            return Ok(());
        }

        let downsampling_pipeline_res = world.resource::<BloomDownsamplingPipeline>();
        let pipeline_cache = world.resource::<PipelineCache>();
        let uniforms = world.resource::<ComponentUniforms<BloomUniforms>>();

        // Bail out until the uniform buffer exists and every pipeline has been compiled.
        let (
            Some(uniforms),
            Some(downsampling_first_pipeline),
            Some(downsampling_pipeline),
            Some(upsampling_pipeline),
            Some(upsampling_final_pipeline),
        ) = (
            uniforms.binding(),
            pipeline_cache.get_render_pipeline(downsampling_pipeline_ids.first),
            pipeline_cache.get_render_pipeline(downsampling_pipeline_ids.main),
            pipeline_cache.get_render_pipeline(upsampling_pipeline_ids.id_main),
            pipeline_cache.get_render_pipeline(upsampling_pipeline_ids.id_final),
        )
        else {
            return Ok(());
        };

        let view_texture = view_target.main_texture_view();
        let view_texture_unsampled = view_target.get_unsampled_color_attachment();
        let diagnostics = render_context.diagnostic_recorder();

        // Command recording is deferred to a task that receives the render device.
        render_context.add_command_buffer_generation_task(move |render_device| {
            #[cfg(feature = "trace")]
            let _bloom_span = info_span!("bloom").entered();

            let mut command_encoder =
                render_device.create_command_encoder(&CommandEncoderDescriptor {
                    label: Some("bloom_command_encoder"),
                });
            command_encoder.push_debug_group("bloom");
            let time_span = diagnostics.time_span(&mut command_encoder, "bloom");

            // First downsample pass
            // Its bind group is built here every frame because it samples the view's main
            // texture rather than a bloom mip.
            {
                let downsampling_first_bind_group = render_device.create_bind_group(
                    "bloom_downsampling_first_bind_group",
                    &downsampling_pipeline_res.bind_group_layout,
                    &BindGroupEntries::sequential((
                        // Read from main texture directly
                        view_texture,
                        &bind_groups.sampler,
                        uniforms.clone(),
                    )),
                );

                let view = &bloom_texture.view(0);
                let mut downsampling_first_pass =
                    command_encoder.begin_render_pass(&RenderPassDescriptor {
                        label: Some("bloom_downsampling_first_pass"),
                        color_attachments: &[Some(RenderPassColorAttachment {
                            view,
                            resolve_target: None,
                            ops: Operations::default(),
                        })],
                        depth_stencil_attachment: None,
                        timestamp_writes: None,
                        occlusion_query_set: None,
                    });
                downsampling_first_pass.set_pipeline(downsampling_first_pipeline);
                downsampling_first_pass.set_bind_group(
                    0,
                    &downsampling_first_bind_group,
                    &[uniform_index.index()],
                );
                // Single 3-vertex draw (no vertex buffer bound).
                downsampling_first_pass.draw(0..3, 0..1);
            }

            // Other downsample passes
            // `downsampling_bind_groups[mip - 1]` samples bloom mip `mip - 1`.
            for mip in 1..bloom_texture.mip_count {
                let view = &bloom_texture.view(mip);
                let mut downsampling_pass =
                    command_encoder.begin_render_pass(&RenderPassDescriptor {
                        label: Some("bloom_downsampling_pass"),
                        color_attachments: &[Some(RenderPassColorAttachment {
                            view,
                            resolve_target: None,
                            ops: Operations::default(),
                        })],
                        depth_stencil_attachment: None,
                        timestamp_writes: None,
                        occlusion_query_set: None,
                    });
                downsampling_pass.set_pipeline(downsampling_pipeline);
                downsampling_pass.set_bind_group(
                    0,
                    &bind_groups.downsampling_bind_groups[mip as usize - 1],
                    &[uniform_index.index()],
                );
                downsampling_pass.draw(0..3, 0..1);
            }

            // Upsample passes except the final one
            // `upsampling_bind_groups[mip_count - mip - 1]` samples bloom mip `mip`; the result
            // is written into mip `mip - 1` on top of its existing (downsampled) contents.
            for mip in (1..bloom_texture.mip_count).rev() {
                let view = &bloom_texture.view(mip - 1);
                let mut upsampling_pass =
                    command_encoder.begin_render_pass(&RenderPassDescriptor {
                        label: Some("bloom_upsampling_pass"),
                        color_attachments: &[Some(RenderPassColorAttachment {
                            view,
                            resolve_target: None,
                            ops: Operations {
                                load: LoadOp::Load,
                                store: StoreOp::Store,
                            },
                        })],
                        depth_stencil_attachment: None,
                        timestamp_writes: None,
                        occlusion_query_set: None,
                    });
                upsampling_pass.set_pipeline(upsampling_pipeline);
                upsampling_pass.set_bind_group(
                    0,
                    &bind_groups.upsampling_bind_groups
                        [(bloom_texture.mip_count - mip - 1) as usize],
                    &[uniform_index.index()],
                );
                // NOTE(review): the upsampling pipeline presumably blends using this constant
                // (defined in upsampling_pipeline) — confirm there.
                let blend = compute_blend_factor(
                    bloom_settings,
                    mip as f32,
                    (bloom_texture.mip_count - 1) as f32,
                );
                upsampling_pass.set_blend_constant(LinearRgba::gray(blend).into());
                upsampling_pass.draw(0..3, 0..1);
            }

            // Final upsample pass
            // This is very similar to the above upsampling passes with the only difference
            // being the pipeline (which itself is barely different) and the color attachment
            // The last upsampling bind group samples bloom mip 0, and the output goes to the
            // view's main color attachment, restricted to the camera viewport if one is set.
            {
                let mut upsampling_final_pass =
                    command_encoder.begin_render_pass(&RenderPassDescriptor {
                        label: Some("bloom_upsampling_final_pass"),
                        color_attachments: &[Some(view_texture_unsampled)],
                        depth_stencil_attachment: None,
                        timestamp_writes: None,
                        occlusion_query_set: None,
                    });
                upsampling_final_pass.set_pipeline(upsampling_final_pipeline);
                upsampling_final_pass.set_bind_group(
                    0,
                    &bind_groups.upsampling_bind_groups[(bloom_texture.mip_count - 1) as usize],
                    &[uniform_index.index()],
                );
                if let Some(viewport) = camera.viewport.as_ref() {
                    upsampling_final_pass.set_viewport(
                        viewport.physical_position.x as f32,
                        viewport.physical_position.y as f32,
                        viewport.physical_size.x as f32,
                        viewport.physical_size.y as f32,
                        viewport.depth.start,
                        viewport.depth.end,
                    );
                }
                let blend =
                    compute_blend_factor(bloom_settings, 0.0, (bloom_texture.mip_count - 1) as f32);
                upsampling_final_pass.set_blend_constant(LinearRgba::gray(blend).into());
                upsampling_final_pass.draw(0..3, 0..1);
            }

            time_span.end(&mut command_encoder);
            command_encoder.pop_debug_group();
            command_encoder.finish()
        });

        Ok(())
    }
}
304
305#[derive(Component)]
306struct BloomTexture {
307    // First mip is half the screen resolution, successive mips are half the previous
308    #[cfg(any(
309        not(feature = "webgl"),
310        not(target_arch = "wasm32"),
311        feature = "webgpu"
312    ))]
313    texture: CachedTexture,
314    // WebGL does not support binding specific mip levels for sampling, fallback to separate textures instead
315    #[cfg(all(feature = "webgl", target_arch = "wasm32", not(feature = "webgpu")))]
316    texture: Vec<CachedTexture>,
317    mip_count: u32,
318}
319
320impl BloomTexture {
321    #[cfg(any(
322        not(feature = "webgl"),
323        not(target_arch = "wasm32"),
324        feature = "webgpu"
325    ))]
326    fn view(&self, base_mip_level: u32) -> TextureView {
327        self.texture.texture.create_view(&TextureViewDescriptor {
328            base_mip_level,
329            mip_level_count: Some(1u32),
330            ..Default::default()
331        })
332    }
333    #[cfg(all(feature = "webgl", target_arch = "wasm32", not(feature = "webgpu")))]
334    fn view(&self, base_mip_level: u32) -> TextureView {
335        self.texture[base_mip_level as usize]
336            .texture
337            .create_view(&TextureViewDescriptor {
338                base_mip_level: 0,
339                mip_level_count: Some(1u32),
340                ..Default::default()
341            })
342    }
343}
344
345fn prepare_bloom_textures(
346    mut commands: Commands,
347    mut texture_cache: ResMut<TextureCache>,
348    render_device: Res<RenderDevice>,
349    views: Query<(Entity, &ExtractedCamera, &Bloom)>,
350) {
351    for (entity, camera, bloom) in &views {
352        if let Some(UVec2 {
353            x: width,
354            y: height,
355        }) = camera.physical_viewport_size
356        {
357            // How many times we can halve the resolution minus one so we don't go unnecessarily low
358            let mip_count = bloom.max_mip_dimension.ilog2().max(2) - 1;
359            let mip_height_ratio = if height != 0 {
360                bloom.max_mip_dimension as f32 / height as f32
361            } else {
362                0.
363            };
364
365            let texture_descriptor = TextureDescriptor {
366                label: Some("bloom_texture"),
367                size: Extent3d {
368                    width: ((width as f32 * mip_height_ratio).round() as u32).max(1),
369                    height: ((height as f32 * mip_height_ratio).round() as u32).max(1),
370                    depth_or_array_layers: 1,
371                },
372                mip_level_count: mip_count,
373                sample_count: 1,
374                dimension: TextureDimension::D2,
375                format: BLOOM_TEXTURE_FORMAT,
376                usage: TextureUsages::RENDER_ATTACHMENT | TextureUsages::TEXTURE_BINDING,
377                view_formats: &[],
378            };
379
380            #[cfg(any(
381                not(feature = "webgl"),
382                not(target_arch = "wasm32"),
383                feature = "webgpu"
384            ))]
385            let texture = texture_cache.get(&render_device, texture_descriptor);
386            #[cfg(all(feature = "webgl", target_arch = "wasm32", not(feature = "webgpu")))]
387            let texture: Vec<CachedTexture> = (0..mip_count)
388                .map(|mip| {
389                    texture_cache.get(
390                        &render_device,
391                        TextureDescriptor {
392                            size: Extent3d {
393                                width: (texture_descriptor.size.width >> mip).max(1),
394                                height: (texture_descriptor.size.height >> mip).max(1),
395                                depth_or_array_layers: 1,
396                            },
397                            mip_level_count: 1,
398                            ..texture_descriptor.clone()
399                        },
400                    )
401                })
402                .collect();
403
404            commands
405                .entity(entity)
406                .insert(BloomTexture { texture, mip_count });
407        }
408    }
409}
410
411#[derive(Component)]
412struct BloomBindGroups {
413    downsampling_bind_groups: Box<[BindGroup]>,
414    upsampling_bind_groups: Box<[BindGroup]>,
415    sampler: Sampler,
416}
417
418fn prepare_bloom_bind_groups(
419    mut commands: Commands,
420    render_device: Res<RenderDevice>,
421    downsampling_pipeline: Res<BloomDownsamplingPipeline>,
422    upsampling_pipeline: Res<BloomUpsamplingPipeline>,
423    views: Query<(Entity, &BloomTexture)>,
424    uniforms: Res<ComponentUniforms<BloomUniforms>>,
425) {
426    let sampler = &downsampling_pipeline.sampler;
427
428    for (entity, bloom_texture) in &views {
429        let bind_group_count = bloom_texture.mip_count as usize - 1;
430
431        let mut downsampling_bind_groups = Vec::with_capacity(bind_group_count);
432        for mip in 1..bloom_texture.mip_count {
433            downsampling_bind_groups.push(render_device.create_bind_group(
434                "bloom_downsampling_bind_group",
435                &downsampling_pipeline.bind_group_layout,
436                &BindGroupEntries::sequential((
437                    &bloom_texture.view(mip - 1),
438                    sampler,
439                    uniforms.binding().unwrap(),
440                )),
441            ));
442        }
443
444        let mut upsampling_bind_groups = Vec::with_capacity(bind_group_count);
445        for mip in (0..bloom_texture.mip_count).rev() {
446            upsampling_bind_groups.push(render_device.create_bind_group(
447                "bloom_upsampling_bind_group",
448                &upsampling_pipeline.bind_group_layout,
449                &BindGroupEntries::sequential((
450                    &bloom_texture.view(mip),
451                    sampler,
452                    uniforms.binding().unwrap(),
453                )),
454            ));
455        }
456
457        commands.entity(entity).insert(BloomBindGroups {
458            downsampling_bind_groups: downsampling_bind_groups.into_boxed_slice(),
459            upsampling_bind_groups: upsampling_bind_groups.into_boxed_slice(),
460            sampler: sampler.clone(),
461        });
462    }
463}
464
465/// Calculates blend intensities of blur pyramid levels
466/// during the upsampling + compositing stage.
467///
468/// The function assumes all pyramid levels are upsampled and
469/// blended into higher frequency ones using this function to
470/// calculate blend levels every time. The final (highest frequency)
471/// pyramid level in not blended into anything therefore this function
472/// is not applied to it. As a result, the *mip* parameter of 0 indicates
473/// the second-highest frequency pyramid level (in our case that is the
474/// 0th mip of the bloom texture with the original image being the
475/// actual highest frequency level).
476///
477/// Parameters:
478/// * `mip` - the index of the lower frequency pyramid level (0 - `max_mip`, where 0 indicates highest frequency mip but not the highest frequency image).
479/// * `max_mip` - the index of the lowest frequency pyramid level.
480///
481/// This function can be visually previewed for all values of *mip* (normalized) with tweakable
482/// [`Bloom`] parameters on [Desmos graphing calculator](https://www.desmos.com/calculator/ncc8xbhzzl).
483fn compute_blend_factor(bloom: &Bloom, mip: f32, max_mip: f32) -> f32 {
484    let mut lf_boost =
485        (1.0 - ops::powf(
486            1.0 - (mip / max_mip),
487            1.0 / (1.0 - bloom.low_frequency_boost_curvature),
488        )) * bloom.low_frequency_boost;
489    let high_pass_lq = 1.0
490        - (((mip / max_mip) - bloom.high_pass_frequency) / bloom.high_pass_frequency)
491            .clamp(0.0, 1.0);
492    lf_boost *= match bloom.composite_mode {
493        BloomCompositeMode::EnergyConserving => 1.0 - bloom.intensity,
494        BloomCompositeMode::Additive => 1.0,
495    };
496
497    (bloom.intensity + lf_boost) * high_pass_lq
498}