1use bevy_ecs::{prelude::Entity, world::World};
2#[cfg(feature = "trace")]
3use bevy_utils::tracing::info_span;
4use bevy_utils::HashMap;
5
6use alloc::{borrow::Cow, collections::VecDeque};
7use derive_more::derive::{Display, Error, From};
8use smallvec::{smallvec, SmallVec};
9
10use crate::{
11 diagnostic::internal::{DiagnosticsRecorder, RenderDiagnosticsMutex},
12 render_graph::{
13 Edge, InternedRenderLabel, InternedRenderSubGraph, NodeRunError, NodeState, RenderGraph,
14 RenderGraphContext, SlotLabel, SlotType, SlotValue,
15 },
16 renderer::{RenderContext, RenderDevice},
17};
18
/// Stateless entry point that executes a render graph (and, recursively, any sub-graphs its
/// nodes queue) and submits the recorded commands.
///
/// All functionality lives in associated functions; the type carries no data.
pub(crate) struct RenderGraphRunner;
31
32#[derive(Error, Display, Debug, From)]
33pub enum RenderGraphRunnerError {
34 NodeRunError(NodeRunError),
35 #[display("node output slot not set (index {slot_index}, name {slot_name})")]
36 EmptyNodeOutputSlot {
37 type_name: &'static str,
38 slot_index: usize,
39 slot_name: Cow<'static, str>,
40 },
41 #[display("graph '{sub_graph:?}' could not be run because slot '{slot_name}' at index {slot_index} has no value")]
42 MissingInput {
43 slot_index: usize,
44 slot_name: Cow<'static, str>,
45 sub_graph: Option<InternedRenderSubGraph>,
46 },
47 #[display("attempted to use the wrong type for input slot")]
48 MismatchedInputSlotType {
49 slot_index: usize,
50 label: SlotLabel,
51 expected: SlotType,
52 actual: SlotType,
53 },
54 #[display(
55 "node (name: '{node_name:?}') has {slot_count} input slots, but was provided {value_count} values"
56 )]
57 MismatchedInputCount {
58 node_name: InternedRenderLabel,
59 slot_count: usize,
60 value_count: usize,
61 },
62}
63
64impl RenderGraphRunner {
65 pub fn run(
66 graph: &RenderGraph,
67 render_device: RenderDevice,
68 mut diagnostics_recorder: Option<DiagnosticsRecorder>,
69 queue: &wgpu::Queue,
70 adapter: &wgpu::Adapter,
71 world: &World,
72 finalizer: impl FnOnce(&mut wgpu::CommandEncoder),
73 ) -> Result<Option<DiagnosticsRecorder>, RenderGraphRunnerError> {
74 if let Some(recorder) = &mut diagnostics_recorder {
75 recorder.begin_frame();
76 }
77
78 let mut render_context =
79 RenderContext::new(render_device, adapter.get_info(), diagnostics_recorder);
80 Self::run_graph(graph, None, &mut render_context, world, &[], None)?;
81 finalizer(render_context.command_encoder());
82
83 let (render_device, mut diagnostics_recorder) = {
84 #[cfg(feature = "trace")]
85 let _span = info_span!("submit_graph_commands").entered();
86
87 let (commands, render_device, diagnostics_recorder) = render_context.finish();
88 queue.submit(commands);
89
90 (render_device, diagnostics_recorder)
91 };
92
93 if let Some(recorder) = &mut diagnostics_recorder {
94 let render_diagnostics_mutex = world.resource::<RenderDiagnosticsMutex>().0.clone();
95 recorder.finish_frame(&render_device, move |diagnostics| {
96 *render_diagnostics_mutex.lock().expect("lock poisoned") = Some(diagnostics);
97 });
98 }
99
100 Ok(diagnostics_recorder)
101 }
102
103 fn run_graph<'w>(
106 graph: &RenderGraph,
107 sub_graph: Option<InternedRenderSubGraph>,
108 render_context: &mut RenderContext<'w>,
109 world: &'w World,
110 inputs: &[SlotValue],
111 view_entity: Option<Entity>,
112 ) -> Result<(), RenderGraphRunnerError> {
113 let mut node_outputs: HashMap<InternedRenderLabel, SmallVec<[SlotValue; 4]>> =
114 HashMap::default();
115 #[cfg(feature = "trace")]
116 let span = if let Some(label) = &sub_graph {
117 info_span!("run_graph", name = format!("{label:?}"))
118 } else {
119 info_span!("run_graph", name = "main_graph")
120 };
121 #[cfg(feature = "trace")]
122 let _guard = span.enter();
123
124 let mut node_queue: VecDeque<&NodeState> = graph
126 .iter_nodes()
127 .filter(|node| node.input_slots.is_empty())
128 .collect();
129
130 if let Some(input_node) = graph.get_input_node() {
132 let mut input_values: SmallVec<[SlotValue; 4]> = SmallVec::new();
133 for (i, input_slot) in input_node.input_slots.iter().enumerate() {
134 if let Some(input_value) = inputs.get(i) {
135 if input_slot.slot_type != input_value.slot_type() {
136 return Err(RenderGraphRunnerError::MismatchedInputSlotType {
137 slot_index: i,
138 actual: input_value.slot_type(),
139 expected: input_slot.slot_type,
140 label: input_slot.name.clone().into(),
141 });
142 }
143 input_values.push(input_value.clone());
144 } else {
145 return Err(RenderGraphRunnerError::MissingInput {
146 slot_index: i,
147 slot_name: input_slot.name.clone(),
148 sub_graph,
149 });
150 }
151 }
152
153 node_outputs.insert(input_node.label, input_values);
154
155 for (_, node_state) in graph
156 .iter_node_outputs(input_node.label)
157 .expect("node exists")
158 {
159 node_queue.push_front(node_state);
160 }
161 }
162
163 'handle_node: while let Some(node_state) = node_queue.pop_back() {
164 if node_outputs.contains_key(&node_state.label) {
166 continue;
167 }
168
169 let mut slot_indices_and_inputs: SmallVec<[(usize, SlotValue); 4]> = SmallVec::new();
170 for (edge, input_node) in graph
172 .iter_node_inputs(node_state.label)
173 .expect("node is in graph")
174 {
175 match edge {
176 Edge::SlotEdge {
177 output_index,
178 input_index,
179 ..
180 } => {
181 if let Some(outputs) = node_outputs.get(&input_node.label) {
182 slot_indices_and_inputs
183 .push((*input_index, outputs[*output_index].clone()));
184 } else {
185 node_queue.push_front(node_state);
186 continue 'handle_node;
187 }
188 }
189 Edge::NodeEdge { .. } => {
190 if !node_outputs.contains_key(&input_node.label) {
191 node_queue.push_front(node_state);
192 continue 'handle_node;
193 }
194 }
195 }
196 }
197
198 slot_indices_and_inputs.sort_by_key(|(index, _)| *index);
200 let inputs: SmallVec<[SlotValue; 4]> = slot_indices_and_inputs
201 .into_iter()
202 .map(|(_, value)| value)
203 .collect();
204
205 if inputs.len() != node_state.input_slots.len() {
206 return Err(RenderGraphRunnerError::MismatchedInputCount {
207 node_name: node_state.label,
208 slot_count: node_state.input_slots.len(),
209 value_count: inputs.len(),
210 });
211 }
212
213 let mut outputs: SmallVec<[Option<SlotValue>; 4]> =
214 smallvec![None; node_state.output_slots.len()];
215 {
216 let mut context = RenderGraphContext::new(graph, node_state, &inputs, &mut outputs);
217 if let Some(view_entity) = view_entity {
218 context.set_view_entity(view_entity);
219 }
220
221 {
222 #[cfg(feature = "trace")]
223 let _span = info_span!("node", name = node_state.type_name).entered();
224
225 node_state.node.run(&mut context, render_context, world)?;
226 }
227
228 for run_sub_graph in context.finish() {
229 let sub_graph = graph
230 .get_sub_graph(run_sub_graph.sub_graph)
231 .expect("sub graph exists because it was validated when queued.");
232 Self::run_graph(
233 sub_graph,
234 Some(run_sub_graph.sub_graph),
235 render_context,
236 world,
237 &run_sub_graph.inputs,
238 run_sub_graph.view_entity,
239 )?;
240 }
241 }
242
243 let mut values: SmallVec<[SlotValue; 4]> = SmallVec::new();
244 for (i, output) in outputs.into_iter().enumerate() {
245 if let Some(value) = output {
246 values.push(value);
247 } else {
248 let empty_slot = node_state.output_slots.get_slot(i).unwrap();
249 return Err(RenderGraphRunnerError::EmptyNodeOutputSlot {
250 type_name: node_state.type_name,
251 slot_index: i,
252 slot_name: empty_slot.name.clone(),
253 });
254 }
255 }
256 node_outputs.insert(node_state.label, values);
257
258 for (_, node_state) in graph
259 .iter_node_outputs(node_state.label)
260 .expect("node exists")
261 {
262 node_queue.push_front(node_state);
263 }
264 }
265
266 Ok(())
267 }
268}