bevy_render/render_graph/
graph.rs

1use crate::{
2    render_graph::{
3        Edge, Node, NodeRunError, NodeState, RenderGraphContext, RenderGraphError, RenderLabel,
4        SlotInfo, SlotLabel,
5    },
6    renderer::RenderContext,
7};
8use bevy_ecs::{define_label, intern::Interned, prelude::World, system::Resource};
9use bevy_utils::HashMap;
10use core::fmt::Debug;
11
12use super::{EdgeExistence, InternedRenderLabel, IntoRenderNodeArray};
13
14pub use bevy_render_macros::RenderSubGraph;
15
16define_label!(
17    /// A strongly-typed class of labels used to identify a [`SubGraph`] in a render graph.
18    RenderSubGraph,
19    RENDER_SUB_GRAPH_INTERNER
20);
21
22/// A shorthand for `Interned<dyn RenderSubGraph>`.
23pub type InternedRenderSubGraph = Interned<dyn RenderSubGraph>;
24
25/// The render graph configures the modular and re-usable render logic.
26///
27/// It is a retained and stateless (nodes themselves may have their own internal state) structure,
28/// which can not be modified while it is executed by the graph runner.
29///
30/// The render graph runner is responsible for executing the entire graph each frame.
31/// It will execute each node in the graph in the correct order, based on the edges between the nodes.
32///
33/// It consists of three main components: [`Nodes`](Node), [`Edges`](Edge)
34/// and [`Slots`](super::SlotType).
35///
36/// Nodes are responsible for generating draw calls and operating on input and output slots.
37/// Edges specify the order of execution for nodes and connect input and output slots together.
38/// Slots describe the render resources created or used by the nodes.
39///
40/// Additionally a render graph can contain multiple sub graphs, which are run by the
41/// corresponding nodes. Every render graph can have its own optional input node.
42///
43/// ## Example
44/// Here is a simple render graph example with two nodes connected by a node edge.
45/// ```ignore
46/// # TODO: Remove when #10645 is fixed
47/// # use bevy_app::prelude::*;
48/// # use bevy_ecs::prelude::World;
49/// # use bevy_render::render_graph::{RenderGraph, RenderLabel, Node, RenderGraphContext, NodeRunError};
50/// # use bevy_render::renderer::RenderContext;
51/// #
52/// #[derive(RenderLabel)]
53/// enum Labels {
54///     A,
55///     B,
56/// }
57///
58/// # struct MyNode;
59/// #
60/// # impl Node for MyNode {
61/// #     fn run(&self, graph: &mut RenderGraphContext, render_context: &mut RenderContext, world: &World) -> Result<(), NodeRunError> {
62/// #         unimplemented!()
63/// #     }
64/// # }
65/// #
66/// let mut graph = RenderGraph::default();
67/// graph.add_node(Labels::A, MyNode);
68/// graph.add_node(Labels::B, MyNode);
69/// graph.add_node_edge(Labels::B, Labels::A);
70/// ```
71#[derive(Resource, Default)]
72pub struct RenderGraph {
73    nodes: HashMap<InternedRenderLabel, NodeState>,
74    sub_graphs: HashMap<InternedRenderSubGraph, RenderGraph>,
75}
76
77/// The label for the input node of a graph. Used to connect other nodes to it.
78#[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)]
79pub struct GraphInput;
80
81impl RenderGraph {
82    /// Updates all nodes and sub graphs of the render graph. Should be called before executing it.
83    pub fn update(&mut self, world: &mut World) {
84        for node in self.nodes.values_mut() {
85            node.node.update(world);
86        }
87
88        for sub_graph in self.sub_graphs.values_mut() {
89            sub_graph.update(world);
90        }
91    }
92
93    /// Creates an [`GraphInputNode`] with the specified slots if not already present.
94    pub fn set_input(&mut self, inputs: Vec<SlotInfo>) {
95        assert!(
96            matches!(
97                self.get_node_state(GraphInput),
98                Err(RenderGraphError::InvalidNode(_))
99            ),
100            "Graph already has an input node"
101        );
102
103        self.add_node(GraphInput, GraphInputNode { inputs });
104    }
105
106    /// Returns the [`NodeState`] of the input node of this graph.
107    ///
108    /// # See also
109    ///
110    /// - [`input_node`](Self::input_node) for an unchecked version.
111    #[inline]
112    pub fn get_input_node(&self) -> Option<&NodeState> {
113        self.get_node_state(GraphInput).ok()
114    }
115
116    /// Returns the [`NodeState`] of the input node of this graph.
117    ///
118    /// # Panics
119    ///
120    /// Panics if there is no input node set.
121    ///
122    /// # See also
123    ///
124    /// - [`get_input_node`](Self::get_input_node) for a version which returns an [`Option`] instead.
125    #[inline]
126    pub fn input_node(&self) -> &NodeState {
127        self.get_input_node().unwrap()
128    }
129
130    /// Adds the `node` with the `label` to the graph.
131    /// If the label is already present replaces it instead.
132    pub fn add_node<T>(&mut self, label: impl RenderLabel, node: T)
133    where
134        T: Node,
135    {
136        let label = label.intern();
137        let node_state = NodeState::new(label, node);
138        self.nodes.insert(label, node_state);
139    }
140
141    /// Add `node_edge`s based on the order of the given `edges` array.
142    ///
143    /// Defining an edge that already exists is not considered an error with this api.
144    /// It simply won't create a new edge.
145    pub fn add_node_edges<const N: usize>(&mut self, edges: impl IntoRenderNodeArray<N>) {
146        for window in edges.into_array().windows(2) {
147            let [a, b] = window else {
148                break;
149            };
150            if let Err(err) = self.try_add_node_edge(*a, *b) {
151                match err {
152                    // Already existing edges are very easy to produce with this api
153                    // and shouldn't cause a panic
154                    RenderGraphError::EdgeAlreadyExists(_) => {}
155                    _ => panic!("{err:?}"),
156                }
157            }
158        }
159    }
160
161    /// Removes the `node` with the `label` from the graph.
162    /// If the label does not exist, nothing happens.
163    pub fn remove_node(&mut self, label: impl RenderLabel) -> Result<(), RenderGraphError> {
164        let label = label.intern();
165        if let Some(node_state) = self.nodes.remove(&label) {
166            // Remove all edges from other nodes to this one. Note that as we're removing this
167            // node, we don't need to remove its input edges
168            for input_edge in node_state.edges.input_edges() {
169                match input_edge {
170                    Edge::SlotEdge { output_node, .. }
171                    | Edge::NodeEdge {
172                        input_node: _,
173                        output_node,
174                    } => {
175                        if let Ok(output_node) = self.get_node_state_mut(*output_node) {
176                            output_node.edges.remove_output_edge(input_edge.clone())?;
177                        }
178                    }
179                }
180            }
181            // Remove all edges from this node to other nodes. Note that as we're removing this
182            // node, we don't need to remove its output edges
183            for output_edge in node_state.edges.output_edges() {
184                match output_edge {
185                    Edge::SlotEdge {
186                        output_node: _,
187                        output_index: _,
188                        input_node,
189                        input_index: _,
190                    }
191                    | Edge::NodeEdge {
192                        output_node: _,
193                        input_node,
194                    } => {
195                        if let Ok(input_node) = self.get_node_state_mut(*input_node) {
196                            input_node.edges.remove_input_edge(output_edge.clone())?;
197                        }
198                    }
199                }
200            }
201        }
202
203        Ok(())
204    }
205
206    /// Retrieves the [`NodeState`] referenced by the `label`.
207    pub fn get_node_state(&self, label: impl RenderLabel) -> Result<&NodeState, RenderGraphError> {
208        let label = label.intern();
209        self.nodes
210            .get(&label)
211            .ok_or(RenderGraphError::InvalidNode(label))
212    }
213
214    /// Retrieves the [`NodeState`] referenced by the `label` mutably.
215    pub fn get_node_state_mut(
216        &mut self,
217        label: impl RenderLabel,
218    ) -> Result<&mut NodeState, RenderGraphError> {
219        let label = label.intern();
220        self.nodes
221            .get_mut(&label)
222            .ok_or(RenderGraphError::InvalidNode(label))
223    }
224
225    /// Retrieves the [`Node`] referenced by the `label`.
226    pub fn get_node<T>(&self, label: impl RenderLabel) -> Result<&T, RenderGraphError>
227    where
228        T: Node,
229    {
230        self.get_node_state(label).and_then(|n| n.node())
231    }
232
233    /// Retrieves the [`Node`] referenced by the `label` mutably.
234    pub fn get_node_mut<T>(&mut self, label: impl RenderLabel) -> Result<&mut T, RenderGraphError>
235    where
236        T: Node,
237    {
238        self.get_node_state_mut(label).and_then(|n| n.node_mut())
239    }
240
241    /// Adds the [`Edge::SlotEdge`] to the graph. This guarantees that the `output_node`
242    /// is run before the `input_node` and also connects the `output_slot` to the `input_slot`.
243    ///
244    /// Fails if any invalid [`RenderLabel`]s or [`SlotLabel`]s are given.
245    ///
246    /// # See also
247    ///
248    /// - [`add_slot_edge`](Self::add_slot_edge) for an infallible version.
249    pub fn try_add_slot_edge(
250        &mut self,
251        output_node: impl RenderLabel,
252        output_slot: impl Into<SlotLabel>,
253        input_node: impl RenderLabel,
254        input_slot: impl Into<SlotLabel>,
255    ) -> Result<(), RenderGraphError> {
256        let output_slot = output_slot.into();
257        let input_slot = input_slot.into();
258
259        let output_node = output_node.intern();
260        let input_node = input_node.intern();
261
262        let output_index = self
263            .get_node_state(output_node)?
264            .output_slots
265            .get_slot_index(output_slot.clone())
266            .ok_or(RenderGraphError::InvalidOutputNodeSlot(output_slot))?;
267        let input_index = self
268            .get_node_state(input_node)?
269            .input_slots
270            .get_slot_index(input_slot.clone())
271            .ok_or(RenderGraphError::InvalidInputNodeSlot(input_slot))?;
272
273        let edge = Edge::SlotEdge {
274            output_node,
275            output_index,
276            input_node,
277            input_index,
278        };
279
280        self.validate_edge(&edge, EdgeExistence::DoesNotExist)?;
281
282        {
283            let output_node = self.get_node_state_mut(output_node)?;
284            output_node.edges.add_output_edge(edge.clone())?;
285        }
286        let input_node = self.get_node_state_mut(input_node)?;
287        input_node.edges.add_input_edge(edge)?;
288
289        Ok(())
290    }
291
292    /// Adds the [`Edge::SlotEdge`] to the graph. This guarantees that the `output_node`
293    /// is run before the `input_node` and also connects the `output_slot` to the `input_slot`.
294    ///
295    /// # Panics
296    ///
297    /// Any invalid [`RenderLabel`]s or [`SlotLabel`]s are given.
298    ///
299    /// # See also
300    ///
301    /// - [`try_add_slot_edge`](Self::try_add_slot_edge) for a fallible version.
302    pub fn add_slot_edge(
303        &mut self,
304        output_node: impl RenderLabel,
305        output_slot: impl Into<SlotLabel>,
306        input_node: impl RenderLabel,
307        input_slot: impl Into<SlotLabel>,
308    ) {
309        self.try_add_slot_edge(output_node, output_slot, input_node, input_slot)
310            .unwrap();
311    }
312
313    /// Removes the [`Edge::SlotEdge`] from the graph. If any nodes or slots do not exist then
314    /// nothing happens.
315    pub fn remove_slot_edge(
316        &mut self,
317        output_node: impl RenderLabel,
318        output_slot: impl Into<SlotLabel>,
319        input_node: impl RenderLabel,
320        input_slot: impl Into<SlotLabel>,
321    ) -> Result<(), RenderGraphError> {
322        let output_slot = output_slot.into();
323        let input_slot = input_slot.into();
324
325        let output_node = output_node.intern();
326        let input_node = input_node.intern();
327
328        let output_index = self
329            .get_node_state(output_node)?
330            .output_slots
331            .get_slot_index(output_slot.clone())
332            .ok_or(RenderGraphError::InvalidOutputNodeSlot(output_slot))?;
333        let input_index = self
334            .get_node_state(input_node)?
335            .input_slots
336            .get_slot_index(input_slot.clone())
337            .ok_or(RenderGraphError::InvalidInputNodeSlot(input_slot))?;
338
339        let edge = Edge::SlotEdge {
340            output_node,
341            output_index,
342            input_node,
343            input_index,
344        };
345
346        self.validate_edge(&edge, EdgeExistence::Exists)?;
347
348        {
349            let output_node = self.get_node_state_mut(output_node)?;
350            output_node.edges.remove_output_edge(edge.clone())?;
351        }
352        let input_node = self.get_node_state_mut(input_node)?;
353        input_node.edges.remove_input_edge(edge)?;
354
355        Ok(())
356    }
357
358    /// Adds the [`Edge::NodeEdge`] to the graph. This guarantees that the `output_node`
359    /// is run before the `input_node`.
360    ///
361    /// Fails if any invalid [`RenderLabel`] is given.
362    ///
363    /// # See also
364    ///
365    /// - [`add_node_edge`](Self::add_node_edge) for an infallible version.
366    pub fn try_add_node_edge(
367        &mut self,
368        output_node: impl RenderLabel,
369        input_node: impl RenderLabel,
370    ) -> Result<(), RenderGraphError> {
371        let output_node = output_node.intern();
372        let input_node = input_node.intern();
373
374        let edge = Edge::NodeEdge {
375            output_node,
376            input_node,
377        };
378
379        self.validate_edge(&edge, EdgeExistence::DoesNotExist)?;
380
381        {
382            let output_node = self.get_node_state_mut(output_node)?;
383            output_node.edges.add_output_edge(edge.clone())?;
384        }
385        let input_node = self.get_node_state_mut(input_node)?;
386        input_node.edges.add_input_edge(edge)?;
387
388        Ok(())
389    }
390
391    /// Adds the [`Edge::NodeEdge`] to the graph. This guarantees that the `output_node`
392    /// is run before the `input_node`.
393    ///
394    /// # Panics
395    ///
396    /// Panics if any invalid [`RenderLabel`] is given.
397    ///
398    /// # See also
399    ///
400    /// - [`try_add_node_edge`](Self::try_add_node_edge) for a fallible version.
401    pub fn add_node_edge(&mut self, output_node: impl RenderLabel, input_node: impl RenderLabel) {
402        self.try_add_node_edge(output_node, input_node).unwrap();
403    }
404
405    /// Removes the [`Edge::NodeEdge`] from the graph. If either node does not exist then nothing
406    /// happens.
407    pub fn remove_node_edge(
408        &mut self,
409        output_node: impl RenderLabel,
410        input_node: impl RenderLabel,
411    ) -> Result<(), RenderGraphError> {
412        let output_node = output_node.intern();
413        let input_node = input_node.intern();
414
415        let edge = Edge::NodeEdge {
416            output_node,
417            input_node,
418        };
419
420        self.validate_edge(&edge, EdgeExistence::Exists)?;
421
422        {
423            let output_node = self.get_node_state_mut(output_node)?;
424            output_node.edges.remove_output_edge(edge.clone())?;
425        }
426        let input_node = self.get_node_state_mut(input_node)?;
427        input_node.edges.remove_input_edge(edge)?;
428
429        Ok(())
430    }
431
432    /// Verifies that the edge existence is as expected and
433    /// checks that slot edges are connected correctly.
434    pub fn validate_edge(
435        &mut self,
436        edge: &Edge,
437        should_exist: EdgeExistence,
438    ) -> Result<(), RenderGraphError> {
439        if should_exist == EdgeExistence::Exists && !self.has_edge(edge) {
440            return Err(RenderGraphError::EdgeDoesNotExist(edge.clone()));
441        } else if should_exist == EdgeExistence::DoesNotExist && self.has_edge(edge) {
442            return Err(RenderGraphError::EdgeAlreadyExists(edge.clone()));
443        }
444
445        match *edge {
446            Edge::SlotEdge {
447                output_node,
448                output_index,
449                input_node,
450                input_index,
451            } => {
452                let output_node_state = self.get_node_state(output_node)?;
453                let input_node_state = self.get_node_state(input_node)?;
454
455                let output_slot = output_node_state
456                    .output_slots
457                    .get_slot(output_index)
458                    .ok_or(RenderGraphError::InvalidOutputNodeSlot(SlotLabel::Index(
459                        output_index,
460                    )))?;
461                let input_slot = input_node_state.input_slots.get_slot(input_index).ok_or(
462                    RenderGraphError::InvalidInputNodeSlot(SlotLabel::Index(input_index)),
463                )?;
464
465                if let Some(Edge::SlotEdge {
466                    output_node: current_output_node,
467                    ..
468                }) = input_node_state.edges.input_edges().iter().find(|e| {
469                    if let Edge::SlotEdge {
470                        input_index: current_input_index,
471                        ..
472                    } = e
473                    {
474                        input_index == *current_input_index
475                    } else {
476                        false
477                    }
478                }) {
479                    if should_exist == EdgeExistence::DoesNotExist {
480                        return Err(RenderGraphError::NodeInputSlotAlreadyOccupied {
481                            node: input_node,
482                            input_slot: input_index,
483                            occupied_by_node: *current_output_node,
484                        });
485                    }
486                }
487
488                if output_slot.slot_type != input_slot.slot_type {
489                    return Err(RenderGraphError::MismatchedNodeSlots {
490                        output_node,
491                        output_slot: output_index,
492                        input_node,
493                        input_slot: input_index,
494                    });
495                }
496            }
497            Edge::NodeEdge { .. } => { /* nothing to validate here */ }
498        }
499
500        Ok(())
501    }
502
503    /// Checks whether the `edge` already exists in the graph.
504    pub fn has_edge(&self, edge: &Edge) -> bool {
505        let output_node_state = self.get_node_state(edge.get_output_node());
506        let input_node_state = self.get_node_state(edge.get_input_node());
507        if let Ok(output_node_state) = output_node_state {
508            if output_node_state.edges.output_edges().contains(edge) {
509                if let Ok(input_node_state) = input_node_state {
510                    if input_node_state.edges.input_edges().contains(edge) {
511                        return true;
512                    }
513                }
514            }
515        }
516
517        false
518    }
519
520    /// Returns an iterator over the [`NodeStates`](NodeState).
521    pub fn iter_nodes(&self) -> impl Iterator<Item = &NodeState> {
522        self.nodes.values()
523    }
524
525    /// Returns an iterator over the [`NodeStates`](NodeState), that allows modifying each value.
526    pub fn iter_nodes_mut(&mut self) -> impl Iterator<Item = &mut NodeState> {
527        self.nodes.values_mut()
528    }
529
530    /// Returns an iterator over the sub graphs.
531    pub fn iter_sub_graphs(&self) -> impl Iterator<Item = (InternedRenderSubGraph, &RenderGraph)> {
532        self.sub_graphs.iter().map(|(name, graph)| (*name, graph))
533    }
534
535    /// Returns an iterator over the sub graphs, that allows modifying each value.
536    pub fn iter_sub_graphs_mut(
537        &mut self,
538    ) -> impl Iterator<Item = (InternedRenderSubGraph, &mut RenderGraph)> {
539        self.sub_graphs
540            .iter_mut()
541            .map(|(name, graph)| (*name, graph))
542    }
543
544    /// Returns an iterator over a tuple of the input edges and the corresponding output nodes
545    /// for the node referenced by the label.
546    pub fn iter_node_inputs(
547        &self,
548        label: impl RenderLabel,
549    ) -> Result<impl Iterator<Item = (&Edge, &NodeState)>, RenderGraphError> {
550        let node = self.get_node_state(label)?;
551        Ok(node
552            .edges
553            .input_edges()
554            .iter()
555            .map(|edge| (edge, edge.get_output_node()))
556            .map(move |(edge, output_node)| (edge, self.get_node_state(output_node).unwrap())))
557    }
558
559    /// Returns an iterator over a tuple of the output edges and the corresponding input nodes
560    /// for the node referenced by the label.
561    pub fn iter_node_outputs(
562        &self,
563        label: impl RenderLabel,
564    ) -> Result<impl Iterator<Item = (&Edge, &NodeState)>, RenderGraphError> {
565        let node = self.get_node_state(label)?;
566        Ok(node
567            .edges
568            .output_edges()
569            .iter()
570            .map(|edge| (edge, edge.get_input_node()))
571            .map(move |(edge, input_node)| (edge, self.get_node_state(input_node).unwrap())))
572    }
573
574    /// Adds the `sub_graph` with the `label` to the graph.
575    /// If the label is already present replaces it instead.
576    pub fn add_sub_graph(&mut self, label: impl RenderSubGraph, sub_graph: RenderGraph) {
577        self.sub_graphs.insert(label.intern(), sub_graph);
578    }
579
580    /// Removes the `sub_graph` with the `label` from the graph.
581    /// If the label does not exist then nothing happens.
582    pub fn remove_sub_graph(&mut self, label: impl RenderSubGraph) {
583        self.sub_graphs.remove(&label.intern());
584    }
585
586    /// Retrieves the sub graph corresponding to the `label`.
587    pub fn get_sub_graph(&self, label: impl RenderSubGraph) -> Option<&RenderGraph> {
588        self.sub_graphs.get(&label.intern())
589    }
590
591    /// Retrieves the sub graph corresponding to the `label` mutably.
592    pub fn get_sub_graph_mut(&mut self, label: impl RenderSubGraph) -> Option<&mut RenderGraph> {
593        self.sub_graphs.get_mut(&label.intern())
594    }
595
596    /// Retrieves the sub graph corresponding to the `label`.
597    ///
598    /// # Panics
599    ///
600    /// Panics if any invalid subgraph label is given.
601    ///
602    /// # See also
603    ///
604    /// - [`get_sub_graph`](Self::get_sub_graph) for a fallible version.
605    pub fn sub_graph(&self, label: impl RenderSubGraph) -> &RenderGraph {
606        let label = label.intern();
607        self.sub_graphs
608            .get(&label)
609            .unwrap_or_else(|| panic!("Subgraph {label:?} not found"))
610    }
611
612    /// Retrieves the sub graph corresponding to the `label` mutably.
613    ///
614    /// # Panics
615    ///
616    /// Panics if any invalid subgraph label is given.
617    ///
618    /// # See also
619    ///
620    /// - [`get_sub_graph_mut`](Self::get_sub_graph_mut) for a fallible version.
621    pub fn sub_graph_mut(&mut self, label: impl RenderSubGraph) -> &mut RenderGraph {
622        let label = label.intern();
623        self.sub_graphs
624            .get_mut(&label)
625            .unwrap_or_else(|| panic!("Subgraph {label:?} not found"))
626    }
627}
628
629impl Debug for RenderGraph {
630    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
631        for node in self.iter_nodes() {
632            writeln!(f, "{:?}", node.label)?;
633            writeln!(f, "  in: {:?}", node.input_slots)?;
634            writeln!(f, "  out: {:?}", node.output_slots)?;
635        }
636
637        Ok(())
638    }
639}
640
641/// A [`Node`] which acts as an entry point for a [`RenderGraph`] with custom inputs.
642/// It has the same input and output slots and simply copies them over when run.
643pub struct GraphInputNode {
644    inputs: Vec<SlotInfo>,
645}
646
647impl Node for GraphInputNode {
648    fn input(&self) -> Vec<SlotInfo> {
649        self.inputs.clone()
650    }
651
652    fn output(&self) -> Vec<SlotInfo> {
653        self.inputs.clone()
654    }
655
656    fn run(
657        &self,
658        graph: &mut RenderGraphContext,
659        _render_context: &mut RenderContext,
660        _world: &World,
661    ) -> Result<(), NodeRunError> {
662        for i in 0..graph.inputs().len() {
663            let input = graph.inputs()[i].clone();
664            graph.set_output(i, input)?;
665        }
666        Ok(())
667    }
668}
669
#[cfg(test)]
mod tests {
    use crate::{
        render_graph::{
            node::IntoRenderNodeArray, Edge, InternedRenderLabel, Node, NodeRunError, RenderGraph,
            RenderGraphContext, RenderGraphError, RenderLabel, SlotInfo, SlotType,
        },
        renderer::RenderContext,
    };
    use bevy_ecs::world::{FromWorld, World};
    use bevy_utils::HashSet;

    #[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)]
    enum TestLabel {
        A,
        B,
        C,
        D,
    }

    /// A do-nothing node with a configurable number of `TextureView` input and output slots.
    #[derive(Debug)]
    struct TestNode {
        inputs: Vec<SlotInfo>,
        outputs: Vec<SlotInfo>,
    }

    impl TestNode {
        pub fn new(inputs: usize, outputs: usize) -> Self {
            TestNode {
                inputs: (0..inputs)
                    .map(|i| SlotInfo::new(format!("in_{i}"), SlotType::TextureView))
                    .collect(),
                outputs: (0..outputs)
                    .map(|i| SlotInfo::new(format!("out_{i}"), SlotType::TextureView))
                    .collect(),
            }
        }
    }

    impl Node for TestNode {
        fn input(&self) -> Vec<SlotInfo> {
            self.inputs.clone()
        }

        fn output(&self) -> Vec<SlotInfo> {
            self.outputs.clone()
        }

        fn run(
            &self,
            _: &mut RenderGraphContext,
            _: &mut RenderContext,
            _: &World,
        ) -> Result<(), NodeRunError> {
            Ok(())
        }
    }

    /// Labels of all nodes that have an edge into `label`.
    fn input_nodes(label: impl RenderLabel, graph: &RenderGraph) -> HashSet<InternedRenderLabel> {
        graph
            .iter_node_inputs(label)
            .unwrap()
            .map(|(_edge, node)| node.label)
            .collect::<HashSet<InternedRenderLabel>>()
    }

    /// Labels of all nodes that `label` has an edge into.
    fn output_nodes(label: impl RenderLabel, graph: &RenderGraph) -> HashSet<InternedRenderLabel> {
        graph
            .iter_node_outputs(label)
            .unwrap()
            .map(|(_edge, node)| node.label)
            .collect::<HashSet<InternedRenderLabel>>()
    }

    #[test]
    fn test_graph_edges() {
        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(0, 1));
        graph.add_node(TestLabel::C, TestNode::new(1, 1));
        graph.add_node(TestLabel::D, TestNode::new(1, 0));

        graph.add_slot_edge(TestLabel::A, "out_0", TestLabel::C, "in_0");
        graph.add_node_edge(TestLabel::B, TestLabel::C);
        graph.add_slot_edge(TestLabel::C, 0, TestLabel::D, 0);

        assert!(
            input_nodes(TestLabel::A, &graph).is_empty(),
            "A has no inputs"
        );
        assert_eq!(
            output_nodes(TestLabel::A, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "A outputs to C"
        );

        assert!(
            input_nodes(TestLabel::B, &graph).is_empty(),
            "B has no inputs"
        );
        assert_eq!(
            output_nodes(TestLabel::B, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "B outputs to C"
        );

        assert_eq!(
            input_nodes(TestLabel::C, &graph),
            HashSet::from_iter((TestLabel::A, TestLabel::B).into_array()),
            "A and B input to C"
        );
        assert_eq!(
            output_nodes(TestLabel::C, &graph),
            HashSet::from_iter((TestLabel::D,).into_array()),
            "C outputs to D"
        );

        assert_eq!(
            input_nodes(TestLabel::D, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "C inputs to D"
        );
        assert!(
            output_nodes(TestLabel::D, &graph).is_empty(),
            "D has no outputs"
        );
    }

    #[test]
    fn test_get_node_typed() {
        struct MyNode {
            value: usize,
        }

        impl Node for MyNode {
            fn run(
                &self,
                _: &mut RenderGraphContext,
                _: &mut RenderContext,
                _: &World,
            ) -> Result<(), NodeRunError> {
                Ok(())
            }
        }

        let mut graph = RenderGraph::default();

        graph.add_node(TestLabel::A, MyNode { value: 42 });

        let node: &MyNode = graph.get_node(TestLabel::A).unwrap();
        assert_eq!(node.value, 42, "node value matches");

        let result: Result<&TestNode, RenderGraphError> = graph.get_node(TestLabel::A);
        assert_eq!(
            result.unwrap_err(),
            RenderGraphError::WrongNodeType,
            "expect a wrong node type error"
        );
    }

    #[test]
    fn test_slot_already_occupied() {
        let mut graph = RenderGraph::default();

        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(0, 1));
        graph.add_node(TestLabel::C, TestNode::new(1, 1));

        graph.add_slot_edge(TestLabel::A, 0, TestLabel::C, 0);
        assert_eq!(
            graph.try_add_slot_edge(TestLabel::B, 0, TestLabel::C, 0),
            Err(RenderGraphError::NodeInputSlotAlreadyOccupied {
                node: TestLabel::C.intern(),
                input_slot: 0,
                occupied_by_node: TestLabel::A.intern(),
            }),
            "Adding to a slot that is already occupied should return an error"
        );
    }

    #[test]
    fn test_edge_already_exists() {
        let mut graph = RenderGraph::default();

        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(1, 0));

        graph.add_slot_edge(TestLabel::A, 0, TestLabel::B, 0);
        assert_eq!(
            graph.try_add_slot_edge(TestLabel::A, 0, TestLabel::B, 0),
            Err(RenderGraphError::EdgeAlreadyExists(Edge::SlotEdge {
                output_node: TestLabel::A.intern(),
                output_index: 0,
                input_node: TestLabel::B.intern(),
                input_index: 0,
            })),
            "Adding to a duplicate edge should return an error"
        );
    }

    #[test]
    fn test_add_node_edges() {
        struct SimpleNode;
        impl Node for SimpleNode {
            fn run(
                &self,
                _graph: &mut RenderGraphContext,
                _render_context: &mut RenderContext,
                _world: &World,
            ) -> Result<(), NodeRunError> {
                Ok(())
            }
        }
        impl FromWorld for SimpleNode {
            fn from_world(_world: &mut World) -> Self {
                Self
            }
        }

        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, SimpleNode);
        graph.add_node(TestLabel::B, SimpleNode);
        graph.add_node(TestLabel::C, SimpleNode);

        graph.add_node_edges((TestLabel::A, TestLabel::B, TestLabel::C));

        assert_eq!(
            output_nodes(TestLabel::A, &graph),
            HashSet::from_iter((TestLabel::B,).into_array()),
            "A -> B"
        );
        assert_eq!(
            input_nodes(TestLabel::B, &graph),
            HashSet::from_iter((TestLabel::A,).into_array()),
            "A -> B"
        );
        assert_eq!(
            output_nodes(TestLabel::B, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "B -> C"
        );
        assert_eq!(
            input_nodes(TestLabel::C, &graph),
            HashSet::from_iter((TestLabel::B,).into_array()),
            "B -> C"
        );
    }

    #[test]
    fn test_remove_node_cleans_up_edges() {
        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(1, 1));
        graph.add_node(TestLabel::C, TestNode::new(1, 0));

        graph.add_slot_edge(TestLabel::A, 0, TestLabel::B, 0);
        graph.add_node_edge(TestLabel::B, TestLabel::C);

        graph.remove_node(TestLabel::B).unwrap();

        assert!(
            graph.get_node_state(TestLabel::B).is_err(),
            "B has been removed"
        );
        assert!(
            output_nodes(TestLabel::A, &graph).is_empty(),
            "A no longer outputs to the removed node"
        );
        assert!(
            input_nodes(TestLabel::C, &graph).is_empty(),
            "C no longer has the removed node as input"
        );

        // Removing a label that is not present is a no-op.
        assert_eq!(graph.remove_node(TestLabel::D), Ok(()));
    }

    #[test]
    fn test_remove_node_edge() {
        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, TestNode::new(0, 0));
        graph.add_node(TestLabel::B, TestNode::new(0, 0));
        graph.add_node_edge(TestLabel::A, TestLabel::B);

        graph.remove_node_edge(TestLabel::A, TestLabel::B).unwrap();
        assert!(output_nodes(TestLabel::A, &graph).is_empty(), "A -/-> B");
        assert!(input_nodes(TestLabel::B, &graph).is_empty(), "A -/-> B");

        assert_eq!(
            graph.remove_node_edge(TestLabel::A, TestLabel::B),
            Err(RenderGraphError::EdgeDoesNotExist(Edge::NodeEdge {
                output_node: TestLabel::A.intern(),
                input_node: TestLabel::B.intern(),
            })),
            "Removing an edge that is not present should return an error"
        );
    }

    #[test]
    fn test_remove_slot_edge_frees_input_slot() {
        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(0, 1));
        graph.add_node(TestLabel::C, TestNode::new(1, 0));

        graph.add_slot_edge(TestLabel::A, 0, TestLabel::C, 0);
        graph
            .remove_slot_edge(TestLabel::A, 0, TestLabel::C, 0)
            .unwrap();

        // Once freed, the input slot can be connected to a different node.
        graph.add_slot_edge(TestLabel::B, 0, TestLabel::C, 0);
        assert_eq!(
            input_nodes(TestLabel::C, &graph),
            HashSet::from_iter((TestLabel::B,).into_array()),
            "B -> C after A -> C was removed"
        );
        assert!(
            output_nodes(TestLabel::A, &graph).is_empty(),
            "A no longer outputs to C"
        );
    }
}