bevy_render/render_graph/graph.rs

1use crate::{
2    render_graph::{
3        Edge, Node, NodeRunError, NodeState, RenderGraphContext, RenderGraphError, RenderLabel,
4        SlotInfo, SlotLabel,
5    },
6    renderer::RenderContext,
7};
8use bevy_ecs::{define_label, intern::Interned, prelude::World, resource::Resource};
9use bevy_platform::collections::HashMap;
10use core::fmt::Debug;
11
12use super::{EdgeExistence, InternedRenderLabel, IntoRenderNodeArray};
13
14pub use bevy_render_macros::RenderSubGraph;
15
// Declares the `RenderSubGraph` label trait (used below as `dyn RenderSubGraph`) together with
// `RENDER_SUB_GRAPH_INTERNER`, presumably the interner backing `Interned<dyn RenderSubGraph>`.
define_label!(
    #[diagnostic::on_unimplemented(
        note = "consider annotating `{Self}` with `#[derive(RenderSubGraph)]`"
    )]
    /// A strongly-typed class of labels used to identify a sub graph of a [`RenderGraph`].
    RenderSubGraph,
    RENDER_SUB_GRAPH_INTERNER
);
24
/// A shorthand for `Interned<dyn RenderSubGraph>`.
///
/// This is the key type under which a [`RenderGraph`] stores its sub graphs.
pub type InternedRenderSubGraph = Interned<dyn RenderSubGraph>;
27
/// The render graph configures the modular and re-usable render logic.
///
/// It is a retained and stateless (nodes themselves may have their own internal state) structure,
/// which cannot be modified while it is executed by the graph runner.
///
/// The render graph runner is responsible for executing the entire graph each frame.
/// It will execute each node in the graph in the correct order, based on the edges between the nodes.
///
/// It consists of three main components: [`Nodes`](Node), [`Edges`](Edge)
/// and [`Slots`](super::SlotType).
///
/// Nodes are responsible for generating draw calls and operating on input and output slots.
/// Edges specify the order of execution for nodes and connect input and output slots together.
/// Slots describe the render resources created or used by the nodes.
///
/// Additionally a render graph can contain multiple sub graphs, which are run by the
/// corresponding nodes. Every render graph can have its own optional input node.
///
/// ## Example
/// Here is a simple render graph example with two nodes connected by a node edge.
/// ```ignore
/// # TODO: Remove when #10645 is fixed
/// # use bevy_app::prelude::*;
/// # use bevy_ecs::prelude::World;
/// # use bevy_render::render_graph::{RenderGraph, RenderLabel, Node, RenderGraphContext, NodeRunError};
/// # use bevy_render::renderer::RenderContext;
/// #
/// #[derive(RenderLabel)]
/// enum Labels {
///     A,
///     B,
/// }
///
/// # struct MyNode;
/// #
/// # impl Node for MyNode {
/// #     fn run(&self, graph: &mut RenderGraphContext, render_context: &mut RenderContext, world: &World) -> Result<(), NodeRunError> {
/// #         unimplemented!()
/// #     }
/// # }
/// #
/// let mut graph = RenderGraph::default();
/// graph.add_node(Labels::A, MyNode);
/// graph.add_node(Labels::B, MyNode);
/// // The first label is the output node: `B` is run before `A`.
/// graph.add_node_edge(Labels::B, Labels::A);
/// ```
#[derive(Resource, Default)]
pub struct RenderGraph {
    // Every node of this graph, keyed by its interned render label.
    nodes: HashMap<InternedRenderLabel, NodeState>,
    // Nested graphs, keyed by their interned sub graph label.
    sub_graphs: HashMap<InternedRenderSubGraph, RenderGraph>,
}
79
/// The label for the input node of a graph. Used to connect other nodes to it.
///
/// [`RenderGraph::set_input`] registers a [`GraphInputNode`] under this label, and
/// [`RenderGraph::get_input_node`] looks it up by this label.
#[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)]
pub struct GraphInput;
83
84impl RenderGraph {
85    /// Updates all nodes and sub graphs of the render graph. Should be called before executing it.
86    pub fn update(&mut self, world: &mut World) {
87        for node in self.nodes.values_mut() {
88            node.node.update(world);
89        }
90
91        for sub_graph in self.sub_graphs.values_mut() {
92            sub_graph.update(world);
93        }
94    }
95
96    /// Creates an [`GraphInputNode`] with the specified slots if not already present.
97    pub fn set_input(&mut self, inputs: Vec<SlotInfo>) {
98        assert!(
99            matches!(
100                self.get_node_state(GraphInput),
101                Err(RenderGraphError::InvalidNode(_))
102            ),
103            "Graph already has an input node"
104        );
105
106        self.add_node(GraphInput, GraphInputNode { inputs });
107    }
108
109    /// Returns the [`NodeState`] of the input node of this graph.
110    ///
111    /// # See also
112    ///
113    /// - [`input_node`](Self::input_node) for an unchecked version.
114    #[inline]
115    pub fn get_input_node(&self) -> Option<&NodeState> {
116        self.get_node_state(GraphInput).ok()
117    }
118
119    /// Returns the [`NodeState`] of the input node of this graph.
120    ///
121    /// # Panics
122    ///
123    /// Panics if there is no input node set.
124    ///
125    /// # See also
126    ///
127    /// - [`get_input_node`](Self::get_input_node) for a version which returns an [`Option`] instead.
128    #[inline]
129    pub fn input_node(&self) -> &NodeState {
130        self.get_input_node().unwrap()
131    }
132
133    /// Adds the `node` with the `label` to the graph.
134    /// If the label is already present replaces it instead.
135    pub fn add_node<T>(&mut self, label: impl RenderLabel, node: T)
136    where
137        T: Node,
138    {
139        let label = label.intern();
140        let node_state = NodeState::new(label, node);
141        self.nodes.insert(label, node_state);
142    }
143
144    /// Add `node_edge`s based on the order of the given `edges` array.
145    ///
146    /// Defining an edge that already exists is not considered an error with this api.
147    /// It simply won't create a new edge.
148    #[track_caller]
149    pub fn add_node_edges<const N: usize>(&mut self, edges: impl IntoRenderNodeArray<N>) {
150        for window in edges.into_array().windows(2) {
151            let [a, b] = window else {
152                break;
153            };
154            if let Err(err) = self.try_add_node_edge(*a, *b) {
155                match err {
156                    // Already existing edges are very easy to produce with this api
157                    // and shouldn't cause a panic
158                    RenderGraphError::EdgeAlreadyExists(_) => {}
159                    _ => panic!("{err}"),
160                }
161            }
162        }
163    }
164
165    /// Removes the `node` with the `label` from the graph.
166    /// If the label does not exist, nothing happens.
167    pub fn remove_node(&mut self, label: impl RenderLabel) -> Result<(), RenderGraphError> {
168        let label = label.intern();
169        if let Some(node_state) = self.nodes.remove(&label) {
170            // Remove all edges from other nodes to this one. Note that as we're removing this
171            // node, we don't need to remove its input edges
172            for input_edge in node_state.edges.input_edges() {
173                match input_edge {
174                    Edge::SlotEdge { output_node, .. }
175                    | Edge::NodeEdge {
176                        input_node: _,
177                        output_node,
178                    } => {
179                        if let Ok(output_node) = self.get_node_state_mut(*output_node) {
180                            output_node.edges.remove_output_edge(input_edge.clone())?;
181                        }
182                    }
183                }
184            }
185            // Remove all edges from this node to other nodes. Note that as we're removing this
186            // node, we don't need to remove its output edges
187            for output_edge in node_state.edges.output_edges() {
188                match output_edge {
189                    Edge::SlotEdge {
190                        output_node: _,
191                        output_index: _,
192                        input_node,
193                        input_index: _,
194                    }
195                    | Edge::NodeEdge {
196                        output_node: _,
197                        input_node,
198                    } => {
199                        if let Ok(input_node) = self.get_node_state_mut(*input_node) {
200                            input_node.edges.remove_input_edge(output_edge.clone())?;
201                        }
202                    }
203                }
204            }
205        }
206
207        Ok(())
208    }
209
210    /// Retrieves the [`NodeState`] referenced by the `label`.
211    pub fn get_node_state(&self, label: impl RenderLabel) -> Result<&NodeState, RenderGraphError> {
212        let label = label.intern();
213        self.nodes
214            .get(&label)
215            .ok_or(RenderGraphError::InvalidNode(label))
216    }
217
218    /// Retrieves the [`NodeState`] referenced by the `label` mutably.
219    pub fn get_node_state_mut(
220        &mut self,
221        label: impl RenderLabel,
222    ) -> Result<&mut NodeState, RenderGraphError> {
223        let label = label.intern();
224        self.nodes
225            .get_mut(&label)
226            .ok_or(RenderGraphError::InvalidNode(label))
227    }
228
229    /// Retrieves the [`Node`] referenced by the `label`.
230    pub fn get_node<T>(&self, label: impl RenderLabel) -> Result<&T, RenderGraphError>
231    where
232        T: Node,
233    {
234        self.get_node_state(label).and_then(|n| n.node())
235    }
236
237    /// Retrieves the [`Node`] referenced by the `label` mutably.
238    pub fn get_node_mut<T>(&mut self, label: impl RenderLabel) -> Result<&mut T, RenderGraphError>
239    where
240        T: Node,
241    {
242        self.get_node_state_mut(label).and_then(|n| n.node_mut())
243    }
244
245    /// Adds the [`Edge::SlotEdge`] to the graph. This guarantees that the `output_node`
246    /// is run before the `input_node` and also connects the `output_slot` to the `input_slot`.
247    ///
248    /// Fails if any invalid [`RenderLabel`]s or [`SlotLabel`]s are given.
249    ///
250    /// # See also
251    ///
252    /// - [`add_slot_edge`](Self::add_slot_edge) for an infallible version.
253    pub fn try_add_slot_edge(
254        &mut self,
255        output_node: impl RenderLabel,
256        output_slot: impl Into<SlotLabel>,
257        input_node: impl RenderLabel,
258        input_slot: impl Into<SlotLabel>,
259    ) -> Result<(), RenderGraphError> {
260        let output_slot = output_slot.into();
261        let input_slot = input_slot.into();
262
263        let output_node = output_node.intern();
264        let input_node = input_node.intern();
265
266        let output_index = self
267            .get_node_state(output_node)?
268            .output_slots
269            .get_slot_index(output_slot.clone())
270            .ok_or(RenderGraphError::InvalidOutputNodeSlot(output_slot))?;
271        let input_index = self
272            .get_node_state(input_node)?
273            .input_slots
274            .get_slot_index(input_slot.clone())
275            .ok_or(RenderGraphError::InvalidInputNodeSlot(input_slot))?;
276
277        let edge = Edge::SlotEdge {
278            output_node,
279            output_index,
280            input_node,
281            input_index,
282        };
283
284        self.validate_edge(&edge, EdgeExistence::DoesNotExist)?;
285
286        {
287            let output_node = self.get_node_state_mut(output_node)?;
288            output_node.edges.add_output_edge(edge.clone())?;
289        }
290        let input_node = self.get_node_state_mut(input_node)?;
291        input_node.edges.add_input_edge(edge)?;
292
293        Ok(())
294    }
295
296    /// Adds the [`Edge::SlotEdge`] to the graph. This guarantees that the `output_node`
297    /// is run before the `input_node` and also connects the `output_slot` to the `input_slot`.
298    ///
299    /// # Panics
300    ///
301    /// Any invalid [`RenderLabel`]s or [`SlotLabel`]s are given.
302    ///
303    /// # See also
304    ///
305    /// - [`try_add_slot_edge`](Self::try_add_slot_edge) for a fallible version.
306    pub fn add_slot_edge(
307        &mut self,
308        output_node: impl RenderLabel,
309        output_slot: impl Into<SlotLabel>,
310        input_node: impl RenderLabel,
311        input_slot: impl Into<SlotLabel>,
312    ) {
313        self.try_add_slot_edge(output_node, output_slot, input_node, input_slot)
314            .unwrap();
315    }
316
317    /// Removes the [`Edge::SlotEdge`] from the graph. If any nodes or slots do not exist then
318    /// nothing happens.
319    pub fn remove_slot_edge(
320        &mut self,
321        output_node: impl RenderLabel,
322        output_slot: impl Into<SlotLabel>,
323        input_node: impl RenderLabel,
324        input_slot: impl Into<SlotLabel>,
325    ) -> Result<(), RenderGraphError> {
326        let output_slot = output_slot.into();
327        let input_slot = input_slot.into();
328
329        let output_node = output_node.intern();
330        let input_node = input_node.intern();
331
332        let output_index = self
333            .get_node_state(output_node)?
334            .output_slots
335            .get_slot_index(output_slot.clone())
336            .ok_or(RenderGraphError::InvalidOutputNodeSlot(output_slot))?;
337        let input_index = self
338            .get_node_state(input_node)?
339            .input_slots
340            .get_slot_index(input_slot.clone())
341            .ok_or(RenderGraphError::InvalidInputNodeSlot(input_slot))?;
342
343        let edge = Edge::SlotEdge {
344            output_node,
345            output_index,
346            input_node,
347            input_index,
348        };
349
350        self.validate_edge(&edge, EdgeExistence::Exists)?;
351
352        {
353            let output_node = self.get_node_state_mut(output_node)?;
354            output_node.edges.remove_output_edge(edge.clone())?;
355        }
356        let input_node = self.get_node_state_mut(input_node)?;
357        input_node.edges.remove_input_edge(edge)?;
358
359        Ok(())
360    }
361
362    /// Adds the [`Edge::NodeEdge`] to the graph. This guarantees that the `output_node`
363    /// is run before the `input_node`.
364    ///
365    /// Fails if any invalid [`RenderLabel`] is given.
366    ///
367    /// # See also
368    ///
369    /// - [`add_node_edge`](Self::add_node_edge) for an infallible version.
370    pub fn try_add_node_edge(
371        &mut self,
372        output_node: impl RenderLabel,
373        input_node: impl RenderLabel,
374    ) -> Result<(), RenderGraphError> {
375        let output_node = output_node.intern();
376        let input_node = input_node.intern();
377
378        let edge = Edge::NodeEdge {
379            output_node,
380            input_node,
381        };
382
383        self.validate_edge(&edge, EdgeExistence::DoesNotExist)?;
384
385        {
386            let output_node = self.get_node_state_mut(output_node)?;
387            output_node.edges.add_output_edge(edge.clone())?;
388        }
389        let input_node = self.get_node_state_mut(input_node)?;
390        input_node.edges.add_input_edge(edge)?;
391
392        Ok(())
393    }
394
395    /// Adds the [`Edge::NodeEdge`] to the graph. This guarantees that the `output_node`
396    /// is run before the `input_node`.
397    ///
398    /// # Panics
399    ///
400    /// Panics if any invalid [`RenderLabel`] is given.
401    ///
402    /// # See also
403    ///
404    /// - [`try_add_node_edge`](Self::try_add_node_edge) for a fallible version.
405    pub fn add_node_edge(&mut self, output_node: impl RenderLabel, input_node: impl RenderLabel) {
406        self.try_add_node_edge(output_node, input_node).unwrap();
407    }
408
409    /// Removes the [`Edge::NodeEdge`] from the graph. If either node does not exist then nothing
410    /// happens.
411    pub fn remove_node_edge(
412        &mut self,
413        output_node: impl RenderLabel,
414        input_node: impl RenderLabel,
415    ) -> Result<(), RenderGraphError> {
416        let output_node = output_node.intern();
417        let input_node = input_node.intern();
418
419        let edge = Edge::NodeEdge {
420            output_node,
421            input_node,
422        };
423
424        self.validate_edge(&edge, EdgeExistence::Exists)?;
425
426        {
427            let output_node = self.get_node_state_mut(output_node)?;
428            output_node.edges.remove_output_edge(edge.clone())?;
429        }
430        let input_node = self.get_node_state_mut(input_node)?;
431        input_node.edges.remove_input_edge(edge)?;
432
433        Ok(())
434    }
435
436    /// Verifies that the edge existence is as expected and
437    /// checks that slot edges are connected correctly.
438    pub fn validate_edge(
439        &mut self,
440        edge: &Edge,
441        should_exist: EdgeExistence,
442    ) -> Result<(), RenderGraphError> {
443        if should_exist == EdgeExistence::Exists && !self.has_edge(edge) {
444            return Err(RenderGraphError::EdgeDoesNotExist(edge.clone()));
445        } else if should_exist == EdgeExistence::DoesNotExist && self.has_edge(edge) {
446            return Err(RenderGraphError::EdgeAlreadyExists(edge.clone()));
447        }
448
449        match *edge {
450            Edge::SlotEdge {
451                output_node,
452                output_index,
453                input_node,
454                input_index,
455            } => {
456                let output_node_state = self.get_node_state(output_node)?;
457                let input_node_state = self.get_node_state(input_node)?;
458
459                let output_slot = output_node_state
460                    .output_slots
461                    .get_slot(output_index)
462                    .ok_or(RenderGraphError::InvalidOutputNodeSlot(SlotLabel::Index(
463                        output_index,
464                    )))?;
465                let input_slot = input_node_state.input_slots.get_slot(input_index).ok_or(
466                    RenderGraphError::InvalidInputNodeSlot(SlotLabel::Index(input_index)),
467                )?;
468
469                if let Some(Edge::SlotEdge {
470                    output_node: current_output_node,
471                    ..
472                }) = input_node_state.edges.input_edges().iter().find(|e| {
473                    if let Edge::SlotEdge {
474                        input_index: current_input_index,
475                        ..
476                    } = e
477                    {
478                        input_index == *current_input_index
479                    } else {
480                        false
481                    }
482                }) && should_exist == EdgeExistence::DoesNotExist
483                {
484                    return Err(RenderGraphError::NodeInputSlotAlreadyOccupied {
485                        node: input_node,
486                        input_slot: input_index,
487                        occupied_by_node: *current_output_node,
488                    });
489                }
490
491                if output_slot.slot_type != input_slot.slot_type {
492                    return Err(RenderGraphError::MismatchedNodeSlots {
493                        output_node,
494                        output_slot: output_index,
495                        input_node,
496                        input_slot: input_index,
497                    });
498                }
499            }
500            Edge::NodeEdge { .. } => { /* nothing to validate here */ }
501        }
502
503        Ok(())
504    }
505
506    /// Checks whether the `edge` already exists in the graph.
507    pub fn has_edge(&self, edge: &Edge) -> bool {
508        let output_node_state = self.get_node_state(edge.get_output_node());
509        let input_node_state = self.get_node_state(edge.get_input_node());
510        if let Ok(output_node_state) = output_node_state
511            && output_node_state.edges.output_edges().contains(edge)
512            && let Ok(input_node_state) = input_node_state
513            && input_node_state.edges.input_edges().contains(edge)
514        {
515            return true;
516        }
517
518        false
519    }
520
521    /// Returns an iterator over the [`NodeStates`](NodeState).
522    pub fn iter_nodes(&self) -> impl Iterator<Item = &NodeState> {
523        self.nodes.values()
524    }
525
526    /// Returns an iterator over the [`NodeStates`](NodeState), that allows modifying each value.
527    pub fn iter_nodes_mut(&mut self) -> impl Iterator<Item = &mut NodeState> {
528        self.nodes.values_mut()
529    }
530
531    /// Returns an iterator over the sub graphs.
532    pub fn iter_sub_graphs(&self) -> impl Iterator<Item = (InternedRenderSubGraph, &RenderGraph)> {
533        self.sub_graphs.iter().map(|(name, graph)| (*name, graph))
534    }
535
536    /// Returns an iterator over the sub graphs, that allows modifying each value.
537    pub fn iter_sub_graphs_mut(
538        &mut self,
539    ) -> impl Iterator<Item = (InternedRenderSubGraph, &mut RenderGraph)> {
540        self.sub_graphs
541            .iter_mut()
542            .map(|(name, graph)| (*name, graph))
543    }
544
545    /// Returns an iterator over a tuple of the input edges and the corresponding output nodes
546    /// for the node referenced by the label.
547    pub fn iter_node_inputs(
548        &self,
549        label: impl RenderLabel,
550    ) -> Result<impl Iterator<Item = (&Edge, &NodeState)>, RenderGraphError> {
551        let node = self.get_node_state(label)?;
552        Ok(node
553            .edges
554            .input_edges()
555            .iter()
556            .map(|edge| (edge, edge.get_output_node()))
557            .map(move |(edge, output_node)| (edge, self.get_node_state(output_node).unwrap())))
558    }
559
560    /// Returns an iterator over a tuple of the output edges and the corresponding input nodes
561    /// for the node referenced by the label.
562    pub fn iter_node_outputs(
563        &self,
564        label: impl RenderLabel,
565    ) -> Result<impl Iterator<Item = (&Edge, &NodeState)>, RenderGraphError> {
566        let node = self.get_node_state(label)?;
567        Ok(node
568            .edges
569            .output_edges()
570            .iter()
571            .map(|edge| (edge, edge.get_input_node()))
572            .map(move |(edge, input_node)| (edge, self.get_node_state(input_node).unwrap())))
573    }
574
575    /// Adds the `sub_graph` with the `label` to the graph.
576    /// If the label is already present replaces it instead.
577    pub fn add_sub_graph(&mut self, label: impl RenderSubGraph, sub_graph: RenderGraph) {
578        self.sub_graphs.insert(label.intern(), sub_graph);
579    }
580
581    /// Removes the `sub_graph` with the `label` from the graph.
582    /// If the label does not exist then nothing happens.
583    pub fn remove_sub_graph(&mut self, label: impl RenderSubGraph) {
584        self.sub_graphs.remove(&label.intern());
585    }
586
587    /// Retrieves the sub graph corresponding to the `label`.
588    pub fn get_sub_graph(&self, label: impl RenderSubGraph) -> Option<&RenderGraph> {
589        self.sub_graphs.get(&label.intern())
590    }
591
592    /// Retrieves the sub graph corresponding to the `label` mutably.
593    pub fn get_sub_graph_mut(&mut self, label: impl RenderSubGraph) -> Option<&mut RenderGraph> {
594        self.sub_graphs.get_mut(&label.intern())
595    }
596
597    /// Retrieves the sub graph corresponding to the `label`.
598    ///
599    /// # Panics
600    ///
601    /// Panics if any invalid subgraph label is given.
602    ///
603    /// # See also
604    ///
605    /// - [`get_sub_graph`](Self::get_sub_graph) for a fallible version.
606    pub fn sub_graph(&self, label: impl RenderSubGraph) -> &RenderGraph {
607        let label = label.intern();
608        self.sub_graphs
609            .get(&label)
610            .unwrap_or_else(|| panic!("Subgraph {label:?} not found"))
611    }
612
613    /// Retrieves the sub graph corresponding to the `label` mutably.
614    ///
615    /// # Panics
616    ///
617    /// Panics if any invalid subgraph label is given.
618    ///
619    /// # See also
620    ///
621    /// - [`get_sub_graph_mut`](Self::get_sub_graph_mut) for a fallible version.
622    pub fn sub_graph_mut(&mut self, label: impl RenderSubGraph) -> &mut RenderGraph {
623        let label = label.intern();
624        self.sub_graphs
625            .get_mut(&label)
626            .unwrap_or_else(|| panic!("Subgraph {label:?} not found"))
627    }
628}
629
630impl Debug for RenderGraph {
631    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
632        for node in self.iter_nodes() {
633            writeln!(f, "{:?}", node.label)?;
634            writeln!(f, "  in: {:?}", node.input_slots)?;
635            writeln!(f, "  out: {:?}", node.output_slots)?;
636        }
637
638        Ok(())
639    }
640}
641
/// A [`Node`] which acts as an entry point for a [`RenderGraph`] with custom inputs.
/// It has the same input and output slots and simply copies them over when run.
///
/// Created by [`RenderGraph::set_input`] and registered under the [`GraphInput`] label.
pub struct GraphInputNode {
    // Slot descriptions reported as both this node's inputs and its outputs.
    inputs: Vec<SlotInfo>,
}
647
648impl Node for GraphInputNode {
649    fn input(&self) -> Vec<SlotInfo> {
650        self.inputs.clone()
651    }
652
653    fn output(&self) -> Vec<SlotInfo> {
654        self.inputs.clone()
655    }
656
657    fn run(
658        &self,
659        graph: &mut RenderGraphContext,
660        _render_context: &mut RenderContext,
661        _world: &World,
662    ) -> Result<(), NodeRunError> {
663        for i in 0..graph.inputs().len() {
664            let input = graph.inputs()[i].clone();
665            graph.set_output(i, input)?;
666        }
667        Ok(())
668    }
669}
670
#[cfg(test)]
mod tests {
    use crate::{
        render_graph::{
            node::IntoRenderNodeArray, Edge, InternedRenderLabel, Node, NodeRunError, RenderGraph,
            RenderGraphContext, RenderGraphError, RenderLabel, SlotInfo, SlotType,
        },
        renderer::RenderContext,
    };
    use bevy_ecs::world::{FromWorld, World};
    use bevy_platform::collections::HashSet;

    #[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)]
    enum TestLabel {
        A,
        B,
        C,
        D,
    }

    /// A node with `inputs` input slots and `outputs` output slots, all of type `TextureView`.
    #[derive(Debug)]
    struct TestNode {
        inputs: Vec<SlotInfo>,
        outputs: Vec<SlotInfo>,
    }

    impl TestNode {
        pub fn new(inputs: usize, outputs: usize) -> Self {
            TestNode {
                inputs: (0..inputs)
                    .map(|i| SlotInfo::new(format!("in_{i}"), SlotType::TextureView))
                    .collect(),
                outputs: (0..outputs)
                    .map(|i| SlotInfo::new(format!("out_{i}"), SlotType::TextureView))
                    .collect(),
            }
        }
    }

    impl Node for TestNode {
        fn input(&self) -> Vec<SlotInfo> {
            self.inputs.clone()
        }

        fn output(&self) -> Vec<SlotInfo> {
            self.outputs.clone()
        }

        fn run(
            &self,
            _: &mut RenderGraphContext,
            _: &mut RenderContext,
            _: &World,
        ) -> Result<(), NodeRunError> {
            Ok(())
        }
    }

    /// Labels of all nodes that have an edge into the node with `label`.
    fn input_nodes(label: impl RenderLabel, graph: &RenderGraph) -> HashSet<InternedRenderLabel> {
        graph
            .iter_node_inputs(label)
            .unwrap()
            .map(|(_edge, node)| node.label)
            .collect::<HashSet<InternedRenderLabel>>()
    }

    /// Labels of all nodes that the node with `label` has an edge into.
    fn output_nodes(label: impl RenderLabel, graph: &RenderGraph) -> HashSet<InternedRenderLabel> {
        graph
            .iter_node_outputs(label)
            .unwrap()
            .map(|(_edge, node)| node.label)
            .collect::<HashSet<InternedRenderLabel>>()
    }

    #[test]
    fn test_graph_edges() {
        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(0, 1));
        graph.add_node(TestLabel::C, TestNode::new(1, 1));
        graph.add_node(TestLabel::D, TestNode::new(1, 0));

        graph.add_slot_edge(TestLabel::A, "out_0", TestLabel::C, "in_0");
        graph.add_node_edge(TestLabel::B, TestLabel::C);
        graph.add_slot_edge(TestLabel::C, 0, TestLabel::D, 0);

        assert!(
            input_nodes(TestLabel::A, &graph).is_empty(),
            "A has no inputs"
        );
        assert_eq!(
            output_nodes(TestLabel::A, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "A outputs to C"
        );

        assert!(
            input_nodes(TestLabel::B, &graph).is_empty(),
            "B has no inputs"
        );
        assert_eq!(
            output_nodes(TestLabel::B, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "B outputs to C"
        );

        assert_eq!(
            input_nodes(TestLabel::C, &graph),
            HashSet::from_iter((TestLabel::A, TestLabel::B).into_array()),
            "A and B input to C"
        );
        assert_eq!(
            output_nodes(TestLabel::C, &graph),
            HashSet::from_iter((TestLabel::D,).into_array()),
            "C outputs to D"
        );

        assert_eq!(
            input_nodes(TestLabel::D, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "C inputs to D"
        );
        assert!(
            output_nodes(TestLabel::D, &graph).is_empty(),
            "D has no outputs"
        );
    }

    #[test]
    fn test_get_node_typed() {
        struct MyNode {
            value: usize,
        }

        impl Node for MyNode {
            fn run(
                &self,
                _: &mut RenderGraphContext,
                _: &mut RenderContext,
                _: &World,
            ) -> Result<(), NodeRunError> {
                Ok(())
            }
        }

        let mut graph = RenderGraph::default();

        graph.add_node(TestLabel::A, MyNode { value: 42 });

        let node: &MyNode = graph.get_node(TestLabel::A).unwrap();
        assert_eq!(node.value, 42, "node value matches");

        let result: Result<&TestNode, RenderGraphError> = graph.get_node(TestLabel::A);
        assert_eq!(
            result.unwrap_err(),
            RenderGraphError::WrongNodeType,
            "expect a wrong node type error"
        );
    }

    #[test]
    fn test_slot_already_occupied() {
        let mut graph = RenderGraph::default();

        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(0, 1));
        graph.add_node(TestLabel::C, TestNode::new(1, 1));

        graph.add_slot_edge(TestLabel::A, 0, TestLabel::C, 0);
        assert_eq!(
            graph.try_add_slot_edge(TestLabel::B, 0, TestLabel::C, 0),
            Err(RenderGraphError::NodeInputSlotAlreadyOccupied {
                node: TestLabel::C.intern(),
                input_slot: 0,
                occupied_by_node: TestLabel::A.intern(),
            }),
            "Adding to a slot that is already occupied should return an error"
        );
    }

    #[test]
    fn test_edge_already_exists() {
        let mut graph = RenderGraph::default();

        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(1, 0));

        graph.add_slot_edge(TestLabel::A, 0, TestLabel::B, 0);
        assert_eq!(
            graph.try_add_slot_edge(TestLabel::A, 0, TestLabel::B, 0),
            Err(RenderGraphError::EdgeAlreadyExists(Edge::SlotEdge {
                output_node: TestLabel::A.intern(),
                output_index: 0,
                input_node: TestLabel::B.intern(),
                input_index: 0,
            })),
            "Adding to a duplicate edge should return an error"
        );
    }

    #[test]
    fn test_add_node_edges() {
        struct SimpleNode;
        impl Node for SimpleNode {
            fn run(
                &self,
                _graph: &mut RenderGraphContext,
                _render_context: &mut RenderContext,
                _world: &World,
            ) -> Result<(), NodeRunError> {
                Ok(())
            }
        }
        impl FromWorld for SimpleNode {
            fn from_world(_world: &mut World) -> Self {
                Self
            }
        }

        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, SimpleNode);
        graph.add_node(TestLabel::B, SimpleNode);
        graph.add_node(TestLabel::C, SimpleNode);

        graph.add_node_edges((TestLabel::A, TestLabel::B, TestLabel::C));

        assert_eq!(
            output_nodes(TestLabel::A, &graph),
            HashSet::from_iter((TestLabel::B,).into_array()),
            "A -> B"
        );
        assert_eq!(
            input_nodes(TestLabel::B, &graph),
            HashSet::from_iter((TestLabel::A,).into_array()),
            "A -> B"
        );
        assert_eq!(
            output_nodes(TestLabel::B, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "B -> C"
        );
        assert_eq!(
            input_nodes(TestLabel::C, &graph),
            HashSet::from_iter((TestLabel::B,).into_array()),
            "B -> C"
        );
    }

    #[test]
    fn test_remove_node_edge() {
        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, TestNode::new(0, 0));
        graph.add_node(TestLabel::B, TestNode::new(0, 0));

        graph.add_node_edge(TestLabel::A, TestLabel::B);
        graph
            .remove_node_edge(TestLabel::A, TestLabel::B)
            .expect("the edge was just added");

        assert!(
            output_nodes(TestLabel::A, &graph).is_empty(),
            "A no longer outputs to B"
        );
        assert!(
            input_nodes(TestLabel::B, &graph).is_empty(),
            "B no longer has A as input"
        );
        assert_eq!(
            graph.remove_node_edge(TestLabel::A, TestLabel::B),
            Err(RenderGraphError::EdgeDoesNotExist(Edge::NodeEdge {
                output_node: TestLabel::A.intern(),
                input_node: TestLabel::B.intern(),
            })),
            "Removing a missing edge should return an error"
        );
    }

    #[test]
    fn test_remove_slot_edge() {
        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(1, 0));

        graph.add_slot_edge(TestLabel::A, 0, TestLabel::B, 0);
        graph
            .remove_slot_edge(TestLabel::A, 0, TestLabel::B, 0)
            .expect("the edge was just added");

        assert!(
            output_nodes(TestLabel::A, &graph).is_empty(),
            "A no longer outputs to B"
        );
        assert!(
            input_nodes(TestLabel::B, &graph).is_empty(),
            "B no longer has A as input"
        );
        assert_eq!(
            graph.remove_slot_edge(TestLabel::A, 0, TestLabel::B, 0),
            Err(RenderGraphError::EdgeDoesNotExist(Edge::SlotEdge {
                output_node: TestLabel::A.intern(),
                output_index: 0,
                input_node: TestLabel::B.intern(),
                input_index: 0,
            })),
            "Removing a missing edge should return an error"
        );

        // The input slot is free again, so the edge can be re-added.
        graph.add_slot_edge(TestLabel::A, 0, TestLabel::B, 0);
        assert_eq!(
            input_nodes(TestLabel::B, &graph),
            HashSet::from_iter((TestLabel::A,).into_array()),
            "A -> B again"
        );
    }

    #[test]
    fn test_remove_node() {
        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(1, 1));
        graph.add_node(TestLabel::C, TestNode::new(1, 0));

        graph.add_slot_edge(TestLabel::A, 0, TestLabel::B, 0);
        graph.add_node_edge(TestLabel::B, TestLabel::C);

        graph.remove_node(TestLabel::B).expect("B exists");

        assert!(
            graph.get_node_state(TestLabel::B).is_err(),
            "B was removed"
        );
        assert!(
            output_nodes(TestLabel::A, &graph).is_empty(),
            "A no longer outputs to B"
        );
        assert!(
            input_nodes(TestLabel::C, &graph).is_empty(),
            "C no longer has B as input"
        );
        assert!(
            graph.remove_node(TestLabel::B).is_ok(),
            "Removing a missing node is a no-op"
        );
    }

    #[test]
    fn test_input_node() {
        let mut graph = RenderGraph::default();
        assert!(graph.get_input_node().is_none(), "no input node by default");

        graph.set_input(Vec::new());
        assert!(graph.get_input_node().is_some(), "input node was set");
    }

    #[test]
    #[should_panic(expected = "Graph already has an input node")]
    fn test_set_input_twice_panics() {
        let mut graph = RenderGraph::default();
        graph.set_input(Vec::new());
        graph.set_input(Vec::new());
    }
}