bevy_render/render_graph/
graph.rs

1use crate::{
2    render_graph::{
3        Edge, Node, NodeRunError, NodeState, RenderGraphContext, RenderGraphError, RenderLabel,
4        SlotInfo, SlotLabel,
5    },
6    renderer::RenderContext,
7};
8use bevy_ecs::{define_label, intern::Interned, prelude::World, resource::Resource};
9use bevy_platform::collections::HashMap;
10use core::fmt::Debug;
11
12use super::{EdgeExistence, InternedRenderLabel, IntoRenderNodeArray};
13
14pub use bevy_render_macros::RenderSubGraph;
15
define_label!(
    #[diagnostic::on_unimplemented(
        note = "consider annotating `{Self}` with `#[derive(RenderSubGraph)]`"
    )]
    /// A strongly-typed class of labels used to identify a sub graph in a render graph.
    ///
    /// Sub graphs are registered with [`RenderGraph::add_sub_graph`] and looked up with
    /// [`RenderGraph::get_sub_graph`] / [`RenderGraph::sub_graph`].
    RenderSubGraph,
    RENDER_SUB_GRAPH_INTERNER
);

/// A shorthand for `Interned<dyn RenderSubGraph>`.
///
/// This is the key type of the sub graph map stored inside a [`RenderGraph`].
pub type InternedRenderSubGraph = Interned<dyn RenderSubGraph>;
27
/// The render graph configures the modular and re-usable render logic.
///
/// It is a retained and stateless (nodes themselves may have their own internal state) structure,
/// which cannot be modified while it is executed by the graph runner.
///
/// The render graph runner is responsible for executing the entire graph each frame.
/// It will execute each node in the graph in the correct order, based on the edges between the nodes.
///
/// It consists of three main components: [`Nodes`](Node), [`Edges`](Edge)
/// and [`Slots`](super::SlotType).
///
/// Nodes are responsible for generating draw calls and operating on input and output slots.
/// Edges specify the order of execution for nodes and connect input and output slots together.
/// Slots describe the render resources created or used by the nodes.
///
/// Additionally a render graph can contain multiple sub graphs, which are run by the
/// corresponding nodes. Every render graph can have its own optional input node
/// (see [`RenderGraph::set_input`]).
///
/// ## Example
/// Here is a simple render graph example with two nodes connected by a node edge.
/// ```ignore
/// # TODO: Remove when #10645 is fixed
/// # use bevy_app::prelude::*;
/// # use bevy_ecs::prelude::World;
/// # use bevy_render::render_graph::{RenderGraph, RenderLabel, Node, RenderGraphContext, NodeRunError};
/// # use bevy_render::renderer::RenderContext;
/// #
/// #[derive(RenderLabel)]
/// enum Labels {
///     A,
///     B,
/// }
///
/// # struct MyNode;
/// #
/// # impl Node for MyNode {
/// #     fn run(&self, graph: &mut RenderGraphContext, render_context: &mut RenderContext, world: &World) -> Result<(), NodeRunError> {
/// #         unimplemented!()
/// #     }
/// # }
/// #
/// let mut graph = RenderGraph::default();
/// graph.add_node(Labels::A, MyNode);
/// graph.add_node(Labels::B, MyNode);
/// graph.add_node_edge(Labels::B, Labels::A);
/// ```
#[derive(Resource, Default)]
pub struct RenderGraph {
    /// Every node of this graph, keyed by the interned label it was added under.
    nodes: HashMap<InternedRenderLabel, NodeState>,
    /// Nested graphs, keyed by the interned sub graph label they were added under.
    sub_graphs: HashMap<InternedRenderSubGraph, RenderGraph>,
}
79
/// The label for the input node of a graph. Used to connect other nodes to it.
///
/// [`RenderGraph::set_input`] registers a [`GraphInputNode`] under this label, and
/// [`RenderGraph::get_input_node`] / [`RenderGraph::input_node`] look it up.
#[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)]
pub struct GraphInput;
83
84impl RenderGraph {
85    /// Updates all nodes and sub graphs of the render graph. Should be called before executing it.
86    pub fn update(&mut self, world: &mut World) {
87        for node in self.nodes.values_mut() {
88            node.node.update(world);
89        }
90
91        for sub_graph in self.sub_graphs.values_mut() {
92            sub_graph.update(world);
93        }
94    }
95
96    /// Creates an [`GraphInputNode`] with the specified slots if not already present.
97    pub fn set_input(&mut self, inputs: Vec<SlotInfo>) {
98        assert!(
99            matches!(
100                self.get_node_state(GraphInput),
101                Err(RenderGraphError::InvalidNode(_))
102            ),
103            "Graph already has an input node"
104        );
105
106        self.add_node(GraphInput, GraphInputNode { inputs });
107    }
108
109    /// Returns the [`NodeState`] of the input node of this graph.
110    ///
111    /// # See also
112    ///
113    /// - [`input_node`](Self::input_node) for an unchecked version.
114    #[inline]
115    pub fn get_input_node(&self) -> Option<&NodeState> {
116        self.get_node_state(GraphInput).ok()
117    }
118
119    /// Returns the [`NodeState`] of the input node of this graph.
120    ///
121    /// # Panics
122    ///
123    /// Panics if there is no input node set.
124    ///
125    /// # See also
126    ///
127    /// - [`get_input_node`](Self::get_input_node) for a version which returns an [`Option`] instead.
128    #[inline]
129    pub fn input_node(&self) -> &NodeState {
130        self.get_input_node().unwrap()
131    }
132
133    /// Adds the `node` with the `label` to the graph.
134    /// If the label is already present replaces it instead.
135    pub fn add_node<T>(&mut self, label: impl RenderLabel, node: T)
136    where
137        T: Node,
138    {
139        let label = label.intern();
140        let node_state = NodeState::new(label, node);
141        self.nodes.insert(label, node_state);
142    }
143
144    /// Add `node_edge`s based on the order of the given `edges` array.
145    ///
146    /// Defining an edge that already exists is not considered an error with this api.
147    /// It simply won't create a new edge.
148    pub fn add_node_edges<const N: usize>(&mut self, edges: impl IntoRenderNodeArray<N>) {
149        for window in edges.into_array().windows(2) {
150            let [a, b] = window else {
151                break;
152            };
153            if let Err(err) = self.try_add_node_edge(*a, *b) {
154                match err {
155                    // Already existing edges are very easy to produce with this api
156                    // and shouldn't cause a panic
157                    RenderGraphError::EdgeAlreadyExists(_) => {}
158                    _ => panic!("{err:?}"),
159                }
160            }
161        }
162    }
163
164    /// Removes the `node` with the `label` from the graph.
165    /// If the label does not exist, nothing happens.
166    pub fn remove_node(&mut self, label: impl RenderLabel) -> Result<(), RenderGraphError> {
167        let label = label.intern();
168        if let Some(node_state) = self.nodes.remove(&label) {
169            // Remove all edges from other nodes to this one. Note that as we're removing this
170            // node, we don't need to remove its input edges
171            for input_edge in node_state.edges.input_edges() {
172                match input_edge {
173                    Edge::SlotEdge { output_node, .. }
174                    | Edge::NodeEdge {
175                        input_node: _,
176                        output_node,
177                    } => {
178                        if let Ok(output_node) = self.get_node_state_mut(*output_node) {
179                            output_node.edges.remove_output_edge(input_edge.clone())?;
180                        }
181                    }
182                }
183            }
184            // Remove all edges from this node to other nodes. Note that as we're removing this
185            // node, we don't need to remove its output edges
186            for output_edge in node_state.edges.output_edges() {
187                match output_edge {
188                    Edge::SlotEdge {
189                        output_node: _,
190                        output_index: _,
191                        input_node,
192                        input_index: _,
193                    }
194                    | Edge::NodeEdge {
195                        output_node: _,
196                        input_node,
197                    } => {
198                        if let Ok(input_node) = self.get_node_state_mut(*input_node) {
199                            input_node.edges.remove_input_edge(output_edge.clone())?;
200                        }
201                    }
202                }
203            }
204        }
205
206        Ok(())
207    }
208
209    /// Retrieves the [`NodeState`] referenced by the `label`.
210    pub fn get_node_state(&self, label: impl RenderLabel) -> Result<&NodeState, RenderGraphError> {
211        let label = label.intern();
212        self.nodes
213            .get(&label)
214            .ok_or(RenderGraphError::InvalidNode(label))
215    }
216
217    /// Retrieves the [`NodeState`] referenced by the `label` mutably.
218    pub fn get_node_state_mut(
219        &mut self,
220        label: impl RenderLabel,
221    ) -> Result<&mut NodeState, RenderGraphError> {
222        let label = label.intern();
223        self.nodes
224            .get_mut(&label)
225            .ok_or(RenderGraphError::InvalidNode(label))
226    }
227
228    /// Retrieves the [`Node`] referenced by the `label`.
229    pub fn get_node<T>(&self, label: impl RenderLabel) -> Result<&T, RenderGraphError>
230    where
231        T: Node,
232    {
233        self.get_node_state(label).and_then(|n| n.node())
234    }
235
236    /// Retrieves the [`Node`] referenced by the `label` mutably.
237    pub fn get_node_mut<T>(&mut self, label: impl RenderLabel) -> Result<&mut T, RenderGraphError>
238    where
239        T: Node,
240    {
241        self.get_node_state_mut(label).and_then(|n| n.node_mut())
242    }
243
244    /// Adds the [`Edge::SlotEdge`] to the graph. This guarantees that the `output_node`
245    /// is run before the `input_node` and also connects the `output_slot` to the `input_slot`.
246    ///
247    /// Fails if any invalid [`RenderLabel`]s or [`SlotLabel`]s are given.
248    ///
249    /// # See also
250    ///
251    /// - [`add_slot_edge`](Self::add_slot_edge) for an infallible version.
252    pub fn try_add_slot_edge(
253        &mut self,
254        output_node: impl RenderLabel,
255        output_slot: impl Into<SlotLabel>,
256        input_node: impl RenderLabel,
257        input_slot: impl Into<SlotLabel>,
258    ) -> Result<(), RenderGraphError> {
259        let output_slot = output_slot.into();
260        let input_slot = input_slot.into();
261
262        let output_node = output_node.intern();
263        let input_node = input_node.intern();
264
265        let output_index = self
266            .get_node_state(output_node)?
267            .output_slots
268            .get_slot_index(output_slot.clone())
269            .ok_or(RenderGraphError::InvalidOutputNodeSlot(output_slot))?;
270        let input_index = self
271            .get_node_state(input_node)?
272            .input_slots
273            .get_slot_index(input_slot.clone())
274            .ok_or(RenderGraphError::InvalidInputNodeSlot(input_slot))?;
275
276        let edge = Edge::SlotEdge {
277            output_node,
278            output_index,
279            input_node,
280            input_index,
281        };
282
283        self.validate_edge(&edge, EdgeExistence::DoesNotExist)?;
284
285        {
286            let output_node = self.get_node_state_mut(output_node)?;
287            output_node.edges.add_output_edge(edge.clone())?;
288        }
289        let input_node = self.get_node_state_mut(input_node)?;
290        input_node.edges.add_input_edge(edge)?;
291
292        Ok(())
293    }
294
295    /// Adds the [`Edge::SlotEdge`] to the graph. This guarantees that the `output_node`
296    /// is run before the `input_node` and also connects the `output_slot` to the `input_slot`.
297    ///
298    /// # Panics
299    ///
300    /// Any invalid [`RenderLabel`]s or [`SlotLabel`]s are given.
301    ///
302    /// # See also
303    ///
304    /// - [`try_add_slot_edge`](Self::try_add_slot_edge) for a fallible version.
305    pub fn add_slot_edge(
306        &mut self,
307        output_node: impl RenderLabel,
308        output_slot: impl Into<SlotLabel>,
309        input_node: impl RenderLabel,
310        input_slot: impl Into<SlotLabel>,
311    ) {
312        self.try_add_slot_edge(output_node, output_slot, input_node, input_slot)
313            .unwrap();
314    }
315
316    /// Removes the [`Edge::SlotEdge`] from the graph. If any nodes or slots do not exist then
317    /// nothing happens.
318    pub fn remove_slot_edge(
319        &mut self,
320        output_node: impl RenderLabel,
321        output_slot: impl Into<SlotLabel>,
322        input_node: impl RenderLabel,
323        input_slot: impl Into<SlotLabel>,
324    ) -> Result<(), RenderGraphError> {
325        let output_slot = output_slot.into();
326        let input_slot = input_slot.into();
327
328        let output_node = output_node.intern();
329        let input_node = input_node.intern();
330
331        let output_index = self
332            .get_node_state(output_node)?
333            .output_slots
334            .get_slot_index(output_slot.clone())
335            .ok_or(RenderGraphError::InvalidOutputNodeSlot(output_slot))?;
336        let input_index = self
337            .get_node_state(input_node)?
338            .input_slots
339            .get_slot_index(input_slot.clone())
340            .ok_or(RenderGraphError::InvalidInputNodeSlot(input_slot))?;
341
342        let edge = Edge::SlotEdge {
343            output_node,
344            output_index,
345            input_node,
346            input_index,
347        };
348
349        self.validate_edge(&edge, EdgeExistence::Exists)?;
350
351        {
352            let output_node = self.get_node_state_mut(output_node)?;
353            output_node.edges.remove_output_edge(edge.clone())?;
354        }
355        let input_node = self.get_node_state_mut(input_node)?;
356        input_node.edges.remove_input_edge(edge)?;
357
358        Ok(())
359    }
360
361    /// Adds the [`Edge::NodeEdge`] to the graph. This guarantees that the `output_node`
362    /// is run before the `input_node`.
363    ///
364    /// Fails if any invalid [`RenderLabel`] is given.
365    ///
366    /// # See also
367    ///
368    /// - [`add_node_edge`](Self::add_node_edge) for an infallible version.
369    pub fn try_add_node_edge(
370        &mut self,
371        output_node: impl RenderLabel,
372        input_node: impl RenderLabel,
373    ) -> Result<(), RenderGraphError> {
374        let output_node = output_node.intern();
375        let input_node = input_node.intern();
376
377        let edge = Edge::NodeEdge {
378            output_node,
379            input_node,
380        };
381
382        self.validate_edge(&edge, EdgeExistence::DoesNotExist)?;
383
384        {
385            let output_node = self.get_node_state_mut(output_node)?;
386            output_node.edges.add_output_edge(edge.clone())?;
387        }
388        let input_node = self.get_node_state_mut(input_node)?;
389        input_node.edges.add_input_edge(edge)?;
390
391        Ok(())
392    }
393
394    /// Adds the [`Edge::NodeEdge`] to the graph. This guarantees that the `output_node`
395    /// is run before the `input_node`.
396    ///
397    /// # Panics
398    ///
399    /// Panics if any invalid [`RenderLabel`] is given.
400    ///
401    /// # See also
402    ///
403    /// - [`try_add_node_edge`](Self::try_add_node_edge) for a fallible version.
404    pub fn add_node_edge(&mut self, output_node: impl RenderLabel, input_node: impl RenderLabel) {
405        self.try_add_node_edge(output_node, input_node).unwrap();
406    }
407
408    /// Removes the [`Edge::NodeEdge`] from the graph. If either node does not exist then nothing
409    /// happens.
410    pub fn remove_node_edge(
411        &mut self,
412        output_node: impl RenderLabel,
413        input_node: impl RenderLabel,
414    ) -> Result<(), RenderGraphError> {
415        let output_node = output_node.intern();
416        let input_node = input_node.intern();
417
418        let edge = Edge::NodeEdge {
419            output_node,
420            input_node,
421        };
422
423        self.validate_edge(&edge, EdgeExistence::Exists)?;
424
425        {
426            let output_node = self.get_node_state_mut(output_node)?;
427            output_node.edges.remove_output_edge(edge.clone())?;
428        }
429        let input_node = self.get_node_state_mut(input_node)?;
430        input_node.edges.remove_input_edge(edge)?;
431
432        Ok(())
433    }
434
435    /// Verifies that the edge existence is as expected and
436    /// checks that slot edges are connected correctly.
437    pub fn validate_edge(
438        &mut self,
439        edge: &Edge,
440        should_exist: EdgeExistence,
441    ) -> Result<(), RenderGraphError> {
442        if should_exist == EdgeExistence::Exists && !self.has_edge(edge) {
443            return Err(RenderGraphError::EdgeDoesNotExist(edge.clone()));
444        } else if should_exist == EdgeExistence::DoesNotExist && self.has_edge(edge) {
445            return Err(RenderGraphError::EdgeAlreadyExists(edge.clone()));
446        }
447
448        match *edge {
449            Edge::SlotEdge {
450                output_node,
451                output_index,
452                input_node,
453                input_index,
454            } => {
455                let output_node_state = self.get_node_state(output_node)?;
456                let input_node_state = self.get_node_state(input_node)?;
457
458                let output_slot = output_node_state
459                    .output_slots
460                    .get_slot(output_index)
461                    .ok_or(RenderGraphError::InvalidOutputNodeSlot(SlotLabel::Index(
462                        output_index,
463                    )))?;
464                let input_slot = input_node_state.input_slots.get_slot(input_index).ok_or(
465                    RenderGraphError::InvalidInputNodeSlot(SlotLabel::Index(input_index)),
466                )?;
467
468                if let Some(Edge::SlotEdge {
469                    output_node: current_output_node,
470                    ..
471                }) = input_node_state.edges.input_edges().iter().find(|e| {
472                    if let Edge::SlotEdge {
473                        input_index: current_input_index,
474                        ..
475                    } = e
476                    {
477                        input_index == *current_input_index
478                    } else {
479                        false
480                    }
481                }) {
482                    if should_exist == EdgeExistence::DoesNotExist {
483                        return Err(RenderGraphError::NodeInputSlotAlreadyOccupied {
484                            node: input_node,
485                            input_slot: input_index,
486                            occupied_by_node: *current_output_node,
487                        });
488                    }
489                }
490
491                if output_slot.slot_type != input_slot.slot_type {
492                    return Err(RenderGraphError::MismatchedNodeSlots {
493                        output_node,
494                        output_slot: output_index,
495                        input_node,
496                        input_slot: input_index,
497                    });
498                }
499            }
500            Edge::NodeEdge { .. } => { /* nothing to validate here */ }
501        }
502
503        Ok(())
504    }
505
506    /// Checks whether the `edge` already exists in the graph.
507    pub fn has_edge(&self, edge: &Edge) -> bool {
508        let output_node_state = self.get_node_state(edge.get_output_node());
509        let input_node_state = self.get_node_state(edge.get_input_node());
510        if let Ok(output_node_state) = output_node_state {
511            if output_node_state.edges.output_edges().contains(edge) {
512                if let Ok(input_node_state) = input_node_state {
513                    if input_node_state.edges.input_edges().contains(edge) {
514                        return true;
515                    }
516                }
517            }
518        }
519
520        false
521    }
522
523    /// Returns an iterator over the [`NodeStates`](NodeState).
524    pub fn iter_nodes(&self) -> impl Iterator<Item = &NodeState> {
525        self.nodes.values()
526    }
527
528    /// Returns an iterator over the [`NodeStates`](NodeState), that allows modifying each value.
529    pub fn iter_nodes_mut(&mut self) -> impl Iterator<Item = &mut NodeState> {
530        self.nodes.values_mut()
531    }
532
533    /// Returns an iterator over the sub graphs.
534    pub fn iter_sub_graphs(&self) -> impl Iterator<Item = (InternedRenderSubGraph, &RenderGraph)> {
535        self.sub_graphs.iter().map(|(name, graph)| (*name, graph))
536    }
537
538    /// Returns an iterator over the sub graphs, that allows modifying each value.
539    pub fn iter_sub_graphs_mut(
540        &mut self,
541    ) -> impl Iterator<Item = (InternedRenderSubGraph, &mut RenderGraph)> {
542        self.sub_graphs
543            .iter_mut()
544            .map(|(name, graph)| (*name, graph))
545    }
546
547    /// Returns an iterator over a tuple of the input edges and the corresponding output nodes
548    /// for the node referenced by the label.
549    pub fn iter_node_inputs(
550        &self,
551        label: impl RenderLabel,
552    ) -> Result<impl Iterator<Item = (&Edge, &NodeState)>, RenderGraphError> {
553        let node = self.get_node_state(label)?;
554        Ok(node
555            .edges
556            .input_edges()
557            .iter()
558            .map(|edge| (edge, edge.get_output_node()))
559            .map(move |(edge, output_node)| (edge, self.get_node_state(output_node).unwrap())))
560    }
561
562    /// Returns an iterator over a tuple of the output edges and the corresponding input nodes
563    /// for the node referenced by the label.
564    pub fn iter_node_outputs(
565        &self,
566        label: impl RenderLabel,
567    ) -> Result<impl Iterator<Item = (&Edge, &NodeState)>, RenderGraphError> {
568        let node = self.get_node_state(label)?;
569        Ok(node
570            .edges
571            .output_edges()
572            .iter()
573            .map(|edge| (edge, edge.get_input_node()))
574            .map(move |(edge, input_node)| (edge, self.get_node_state(input_node).unwrap())))
575    }
576
577    /// Adds the `sub_graph` with the `label` to the graph.
578    /// If the label is already present replaces it instead.
579    pub fn add_sub_graph(&mut self, label: impl RenderSubGraph, sub_graph: RenderGraph) {
580        self.sub_graphs.insert(label.intern(), sub_graph);
581    }
582
583    /// Removes the `sub_graph` with the `label` from the graph.
584    /// If the label does not exist then nothing happens.
585    pub fn remove_sub_graph(&mut self, label: impl RenderSubGraph) {
586        self.sub_graphs.remove(&label.intern());
587    }
588
589    /// Retrieves the sub graph corresponding to the `label`.
590    pub fn get_sub_graph(&self, label: impl RenderSubGraph) -> Option<&RenderGraph> {
591        self.sub_graphs.get(&label.intern())
592    }
593
594    /// Retrieves the sub graph corresponding to the `label` mutably.
595    pub fn get_sub_graph_mut(&mut self, label: impl RenderSubGraph) -> Option<&mut RenderGraph> {
596        self.sub_graphs.get_mut(&label.intern())
597    }
598
599    /// Retrieves the sub graph corresponding to the `label`.
600    ///
601    /// # Panics
602    ///
603    /// Panics if any invalid subgraph label is given.
604    ///
605    /// # See also
606    ///
607    /// - [`get_sub_graph`](Self::get_sub_graph) for a fallible version.
608    pub fn sub_graph(&self, label: impl RenderSubGraph) -> &RenderGraph {
609        let label = label.intern();
610        self.sub_graphs
611            .get(&label)
612            .unwrap_or_else(|| panic!("Subgraph {label:?} not found"))
613    }
614
615    /// Retrieves the sub graph corresponding to the `label` mutably.
616    ///
617    /// # Panics
618    ///
619    /// Panics if any invalid subgraph label is given.
620    ///
621    /// # See also
622    ///
623    /// - [`get_sub_graph_mut`](Self::get_sub_graph_mut) for a fallible version.
624    pub fn sub_graph_mut(&mut self, label: impl RenderSubGraph) -> &mut RenderGraph {
625        let label = label.intern();
626        self.sub_graphs
627            .get_mut(&label)
628            .unwrap_or_else(|| panic!("Subgraph {label:?} not found"))
629    }
630}
631
632impl Debug for RenderGraph {
633    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
634        for node in self.iter_nodes() {
635            writeln!(f, "{:?}", node.label)?;
636            writeln!(f, "  in: {:?}", node.input_slots)?;
637            writeln!(f, "  out: {:?}", node.output_slots)?;
638        }
639
640        Ok(())
641    }
642}
643
/// A [`Node`] which acts as an entry point for a [`RenderGraph`] with custom inputs.
/// It has the same input and output slots and simply copies them over when run.
///
/// Instances are created and registered by [`RenderGraph::set_input`].
pub struct GraphInputNode {
    /// Slot descriptions used for both the input and the output slots of this node.
    inputs: Vec<SlotInfo>,
}
649
650impl Node for GraphInputNode {
651    fn input(&self) -> Vec<SlotInfo> {
652        self.inputs.clone()
653    }
654
655    fn output(&self) -> Vec<SlotInfo> {
656        self.inputs.clone()
657    }
658
659    fn run(
660        &self,
661        graph: &mut RenderGraphContext,
662        _render_context: &mut RenderContext,
663        _world: &World,
664    ) -> Result<(), NodeRunError> {
665        for i in 0..graph.inputs().len() {
666            let input = graph.inputs()[i].clone();
667            graph.set_output(i, input)?;
668        }
669        Ok(())
670    }
671}
672
#[cfg(test)]
mod tests {
    use super::GraphInput;
    use crate::{
        render_graph::{
            node::IntoRenderNodeArray, Edge, InternedRenderLabel, Node, NodeRunError, RenderGraph,
            RenderGraphContext, RenderGraphError, RenderLabel, SlotInfo, SlotType,
        },
        renderer::RenderContext,
    };
    use bevy_ecs::world::{FromWorld, World};
    use bevy_platform::collections::HashSet;

    #[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)]
    enum TestLabel {
        A,
        B,
        C,
        D,
    }

    /// A no-op node with `inputs` input slots and `outputs` output slots, all of type
    /// `TextureView`, named `in_{i}` / `out_{i}`.
    #[derive(Debug)]
    struct TestNode {
        inputs: Vec<SlotInfo>,
        outputs: Vec<SlotInfo>,
    }

    impl TestNode {
        pub fn new(inputs: usize, outputs: usize) -> Self {
            TestNode {
                inputs: (0..inputs)
                    .map(|i| SlotInfo::new(format!("in_{i}"), SlotType::TextureView))
                    .collect(),
                outputs: (0..outputs)
                    .map(|i| SlotInfo::new(format!("out_{i}"), SlotType::TextureView))
                    .collect(),
            }
        }
    }

    impl Node for TestNode {
        fn input(&self) -> Vec<SlotInfo> {
            self.inputs.clone()
        }

        fn output(&self) -> Vec<SlotInfo> {
            self.outputs.clone()
        }

        fn run(
            &self,
            _: &mut RenderGraphContext,
            _: &mut RenderContext,
            _: &World,
        ) -> Result<(), NodeRunError> {
            Ok(())
        }
    }

    /// Labels of all nodes that feed into `label`.
    fn input_nodes(label: impl RenderLabel, graph: &RenderGraph) -> HashSet<InternedRenderLabel> {
        graph
            .iter_node_inputs(label)
            .unwrap()
            .map(|(_edge, node)| node.label)
            .collect::<HashSet<InternedRenderLabel>>()
    }

    /// Labels of all nodes that `label` feeds into.
    fn output_nodes(label: impl RenderLabel, graph: &RenderGraph) -> HashSet<InternedRenderLabel> {
        graph
            .iter_node_outputs(label)
            .unwrap()
            .map(|(_edge, node)| node.label)
            .collect::<HashSet<InternedRenderLabel>>()
    }

    #[test]
    fn test_graph_edges() {
        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(0, 1));
        graph.add_node(TestLabel::C, TestNode::new(1, 1));
        graph.add_node(TestLabel::D, TestNode::new(1, 0));

        graph.add_slot_edge(TestLabel::A, "out_0", TestLabel::C, "in_0");
        graph.add_node_edge(TestLabel::B, TestLabel::C);
        graph.add_slot_edge(TestLabel::C, 0, TestLabel::D, 0);

        assert!(
            input_nodes(TestLabel::A, &graph).is_empty(),
            "A has no inputs"
        );
        assert_eq!(
            output_nodes(TestLabel::A, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "A outputs to C"
        );

        assert!(
            input_nodes(TestLabel::B, &graph).is_empty(),
            "B has no inputs"
        );
        assert_eq!(
            output_nodes(TestLabel::B, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "B outputs to C"
        );

        assert_eq!(
            input_nodes(TestLabel::C, &graph),
            HashSet::from_iter((TestLabel::A, TestLabel::B).into_array()),
            "A and B input to C"
        );
        assert_eq!(
            output_nodes(TestLabel::C, &graph),
            HashSet::from_iter((TestLabel::D,).into_array()),
            "C outputs to D"
        );

        assert_eq!(
            input_nodes(TestLabel::D, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "C inputs to D"
        );
        assert!(
            output_nodes(TestLabel::D, &graph).is_empty(),
            "D has no outputs"
        );
    }

    #[test]
    fn test_get_node_typed() {
        struct MyNode {
            value: usize,
        }

        impl Node for MyNode {
            fn run(
                &self,
                _: &mut RenderGraphContext,
                _: &mut RenderContext,
                _: &World,
            ) -> Result<(), NodeRunError> {
                Ok(())
            }
        }

        let mut graph = RenderGraph::default();

        graph.add_node(TestLabel::A, MyNode { value: 42 });

        let node: &MyNode = graph.get_node(TestLabel::A).unwrap();
        assert_eq!(node.value, 42, "node value matches");

        let result: Result<&TestNode, RenderGraphError> = graph.get_node(TestLabel::A);
        assert_eq!(
            result.unwrap_err(),
            RenderGraphError::WrongNodeType,
            "expect a wrong node type error"
        );
    }

    #[test]
    fn test_slot_already_occupied() {
        let mut graph = RenderGraph::default();

        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(0, 1));
        graph.add_node(TestLabel::C, TestNode::new(1, 1));

        graph.add_slot_edge(TestLabel::A, 0, TestLabel::C, 0);
        assert_eq!(
            graph.try_add_slot_edge(TestLabel::B, 0, TestLabel::C, 0),
            Err(RenderGraphError::NodeInputSlotAlreadyOccupied {
                node: TestLabel::C.intern(),
                input_slot: 0,
                occupied_by_node: TestLabel::A.intern(),
            }),
            "Adding to a slot that is already occupied should return an error"
        );
    }

    #[test]
    fn test_edge_already_exists() {
        let mut graph = RenderGraph::default();

        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(1, 0));

        graph.add_slot_edge(TestLabel::A, 0, TestLabel::B, 0);
        assert_eq!(
            graph.try_add_slot_edge(TestLabel::A, 0, TestLabel::B, 0),
            Err(RenderGraphError::EdgeAlreadyExists(Edge::SlotEdge {
                output_node: TestLabel::A.intern(),
                output_index: 0,
                input_node: TestLabel::B.intern(),
                input_index: 0,
            })),
            "Adding to a duplicate edge should return an error"
        );
    }

    #[test]
    fn test_add_node_edges() {
        struct SimpleNode;
        impl Node for SimpleNode {
            fn run(
                &self,
                _graph: &mut RenderGraphContext,
                _render_context: &mut RenderContext,
                _world: &World,
            ) -> Result<(), NodeRunError> {
                Ok(())
            }
        }
        impl FromWorld for SimpleNode {
            fn from_world(_world: &mut World) -> Self {
                Self
            }
        }

        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, SimpleNode);
        graph.add_node(TestLabel::B, SimpleNode);
        graph.add_node(TestLabel::C, SimpleNode);

        graph.add_node_edges((TestLabel::A, TestLabel::B, TestLabel::C));

        assert_eq!(
            output_nodes(TestLabel::A, &graph),
            HashSet::from_iter((TestLabel::B,).into_array()),
            "A -> B"
        );
        assert_eq!(
            input_nodes(TestLabel::B, &graph),
            HashSet::from_iter((TestLabel::A,).into_array()),
            "A -> B"
        );
        assert_eq!(
            output_nodes(TestLabel::B, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "B -> C"
        );
        assert_eq!(
            input_nodes(TestLabel::C, &graph),
            HashSet::from_iter((TestLabel::B,).into_array()),
            "B -> C"
        );
    }

    #[test]
    fn test_add_node_edges_ignores_duplicates() {
        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, TestNode::new(0, 0));
        graph.add_node(TestLabel::B, TestNode::new(0, 0));

        graph.add_node_edges((TestLabel::A, TestLabel::B));
        // Re-adding the same chain must not panic nor create a second edge.
        graph.add_node_edges((TestLabel::A, TestLabel::B));

        assert_eq!(
            graph.iter_node_outputs(TestLabel::A).unwrap().count(),
            1,
            "duplicate edge was not recorded twice"
        );
    }

    #[test]
    fn test_remove_node() {
        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(1, 1));
        graph.add_node(TestLabel::C, TestNode::new(1, 0));

        graph.add_slot_edge(TestLabel::A, 0, TestLabel::B, 0);
        graph.add_slot_edge(TestLabel::B, 0, TestLabel::C, 0);
        graph.add_node_edge(TestLabel::A, TestLabel::C);

        assert_eq!(graph.remove_node(TestLabel::B), Ok(()));

        assert!(
            graph.get_node_state(TestLabel::B).is_err(),
            "B was removed"
        );
        assert_eq!(
            output_nodes(TestLabel::A, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "only the A -> C node edge remains on A"
        );
        assert_eq!(
            input_nodes(TestLabel::C, &graph),
            HashSet::from_iter((TestLabel::A,).into_array()),
            "only the A -> C node edge remains on C"
        );
        assert_eq!(
            graph.try_add_slot_edge(TestLabel::A, 0, TestLabel::C, 0),
            Ok(()),
            "C's input slot is free again after B was removed"
        );

        assert_eq!(
            graph.remove_node(TestLabel::D),
            Ok(()),
            "removing an unknown node is a no-op"
        );
    }

    #[test]
    fn test_remove_node_edge() {
        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, TestNode::new(0, 0));
        graph.add_node(TestLabel::B, TestNode::new(0, 0));
        graph.add_node_edge(TestLabel::A, TestLabel::B);

        assert_eq!(graph.remove_node_edge(TestLabel::A, TestLabel::B), Ok(()));
        assert!(output_nodes(TestLabel::A, &graph).is_empty(), "A -> B removed");
        assert!(input_nodes(TestLabel::B, &graph).is_empty(), "A -> B removed");

        assert_eq!(
            graph.remove_node_edge(TestLabel::A, TestLabel::B),
            Err(RenderGraphError::EdgeDoesNotExist(Edge::NodeEdge {
                output_node: TestLabel::A.intern(),
                input_node: TestLabel::B.intern(),
            })),
            "removing a missing edge should return an error"
        );
    }

    #[test]
    fn test_remove_slot_edge() {
        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(1, 0));
        graph.add_slot_edge(TestLabel::A, 0, TestLabel::B, 0);

        assert_eq!(graph.remove_slot_edge(TestLabel::A, 0, TestLabel::B, 0), Ok(()));
        assert!(output_nodes(TestLabel::A, &graph).is_empty(), "A -> B removed");
        assert!(input_nodes(TestLabel::B, &graph).is_empty(), "A -> B removed");

        assert_eq!(
            graph.try_add_slot_edge(TestLabel::A, "out_0", TestLabel::B, "in_0"),
            Ok(()),
            "the input slot can be reconnected after removal"
        );
    }

    #[test]
    fn test_set_input() {
        let mut graph = RenderGraph::default();
        assert!(graph.get_input_node().is_none(), "no input node by default");

        graph.set_input(vec![SlotInfo::new("view", SlotType::TextureView)]);

        assert_eq!(graph.input_node().label, GraphInput.intern());
    }

    #[test]
    #[should_panic(expected = "Graph already has an input node")]
    fn test_set_input_twice_panics() {
        let mut graph = RenderGraph::default();
        graph.set_input(Vec::new());
        graph.set_input(Vec::new());
    }
}