1use crate::{
2 render_graph::{
3 Edge, Node, NodeRunError, NodeState, RenderGraphContext, RenderGraphError, RenderLabel,
4 SlotInfo, SlotLabel,
5 },
6 renderer::RenderContext,
7};
8use bevy_ecs::{define_label, intern::Interned, prelude::World, system::Resource};
9use bevy_utils::HashMap;
10use core::fmt::Debug;
11
12use super::{EdgeExistence, InternedRenderLabel, IntoRenderNodeArray};
13
14pub use bevy_render_macros::RenderSubGraph;
15
define_label!(
    /// A strongly-typed class of labels used to identify a sub graph of a [`RenderGraph`].
    RenderSubGraph,
    RENDER_SUB_GRAPH_INTERNER
);
21
/// Interned form of a [`RenderSubGraph`] label; used as the key of a graph's sub graph map.
pub type InternedRenderSubGraph = Interned<dyn RenderSubGraph>;
24
/// A graph of [`Node`]s connected by [`Edge`]s, plus any number of named sub graphs.
///
/// Nodes are stored by their interned [`RenderLabel`]; each node's [`NodeState`] records
/// the edges connecting it to other nodes. Sub graphs are stored by their interned
/// [`RenderSubGraph`] label and are themselves complete [`RenderGraph`]s.
#[derive(Resource, Default)]
pub struct RenderGraph {
    // Every node of this graph, keyed by its interned label.
    nodes: HashMap<InternedRenderLabel, NodeState>,
    // Nested graphs, keyed by their interned sub graph label.
    sub_graphs: HashMap<InternedRenderSubGraph, RenderGraph>,
}
76
/// Label of the input node that [`RenderGraph::set_input`] adds to a graph.
#[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)]
pub struct GraphInput;
80
81impl RenderGraph {
82 pub fn update(&mut self, world: &mut World) {
84 for node in self.nodes.values_mut() {
85 node.node.update(world);
86 }
87
88 for sub_graph in self.sub_graphs.values_mut() {
89 sub_graph.update(world);
90 }
91 }
92
93 pub fn set_input(&mut self, inputs: Vec<SlotInfo>) {
95 assert!(
96 matches!(
97 self.get_node_state(GraphInput),
98 Err(RenderGraphError::InvalidNode(_))
99 ),
100 "Graph already has an input node"
101 );
102
103 self.add_node(GraphInput, GraphInputNode { inputs });
104 }
105
106 #[inline]
112 pub fn get_input_node(&self) -> Option<&NodeState> {
113 self.get_node_state(GraphInput).ok()
114 }
115
116 #[inline]
126 pub fn input_node(&self) -> &NodeState {
127 self.get_input_node().unwrap()
128 }
129
130 pub fn add_node<T>(&mut self, label: impl RenderLabel, node: T)
133 where
134 T: Node,
135 {
136 let label = label.intern();
137 let node_state = NodeState::new(label, node);
138 self.nodes.insert(label, node_state);
139 }
140
141 pub fn add_node_edges<const N: usize>(&mut self, edges: impl IntoRenderNodeArray<N>) {
146 for window in edges.into_array().windows(2) {
147 let [a, b] = window else {
148 break;
149 };
150 if let Err(err) = self.try_add_node_edge(*a, *b) {
151 match err {
152 RenderGraphError::EdgeAlreadyExists(_) => {}
155 _ => panic!("{err:?}"),
156 }
157 }
158 }
159 }
160
161 pub fn remove_node(&mut self, label: impl RenderLabel) -> Result<(), RenderGraphError> {
164 let label = label.intern();
165 if let Some(node_state) = self.nodes.remove(&label) {
166 for input_edge in node_state.edges.input_edges() {
169 match input_edge {
170 Edge::SlotEdge { output_node, .. }
171 | Edge::NodeEdge {
172 input_node: _,
173 output_node,
174 } => {
175 if let Ok(output_node) = self.get_node_state_mut(*output_node) {
176 output_node.edges.remove_output_edge(input_edge.clone())?;
177 }
178 }
179 }
180 }
181 for output_edge in node_state.edges.output_edges() {
184 match output_edge {
185 Edge::SlotEdge {
186 output_node: _,
187 output_index: _,
188 input_node,
189 input_index: _,
190 }
191 | Edge::NodeEdge {
192 output_node: _,
193 input_node,
194 } => {
195 if let Ok(input_node) = self.get_node_state_mut(*input_node) {
196 input_node.edges.remove_input_edge(output_edge.clone())?;
197 }
198 }
199 }
200 }
201 }
202
203 Ok(())
204 }
205
206 pub fn get_node_state(&self, label: impl RenderLabel) -> Result<&NodeState, RenderGraphError> {
208 let label = label.intern();
209 self.nodes
210 .get(&label)
211 .ok_or(RenderGraphError::InvalidNode(label))
212 }
213
214 pub fn get_node_state_mut(
216 &mut self,
217 label: impl RenderLabel,
218 ) -> Result<&mut NodeState, RenderGraphError> {
219 let label = label.intern();
220 self.nodes
221 .get_mut(&label)
222 .ok_or(RenderGraphError::InvalidNode(label))
223 }
224
225 pub fn get_node<T>(&self, label: impl RenderLabel) -> Result<&T, RenderGraphError>
227 where
228 T: Node,
229 {
230 self.get_node_state(label).and_then(|n| n.node())
231 }
232
233 pub fn get_node_mut<T>(&mut self, label: impl RenderLabel) -> Result<&mut T, RenderGraphError>
235 where
236 T: Node,
237 {
238 self.get_node_state_mut(label).and_then(|n| n.node_mut())
239 }
240
241 pub fn try_add_slot_edge(
250 &mut self,
251 output_node: impl RenderLabel,
252 output_slot: impl Into<SlotLabel>,
253 input_node: impl RenderLabel,
254 input_slot: impl Into<SlotLabel>,
255 ) -> Result<(), RenderGraphError> {
256 let output_slot = output_slot.into();
257 let input_slot = input_slot.into();
258
259 let output_node = output_node.intern();
260 let input_node = input_node.intern();
261
262 let output_index = self
263 .get_node_state(output_node)?
264 .output_slots
265 .get_slot_index(output_slot.clone())
266 .ok_or(RenderGraphError::InvalidOutputNodeSlot(output_slot))?;
267 let input_index = self
268 .get_node_state(input_node)?
269 .input_slots
270 .get_slot_index(input_slot.clone())
271 .ok_or(RenderGraphError::InvalidInputNodeSlot(input_slot))?;
272
273 let edge = Edge::SlotEdge {
274 output_node,
275 output_index,
276 input_node,
277 input_index,
278 };
279
280 self.validate_edge(&edge, EdgeExistence::DoesNotExist)?;
281
282 {
283 let output_node = self.get_node_state_mut(output_node)?;
284 output_node.edges.add_output_edge(edge.clone())?;
285 }
286 let input_node = self.get_node_state_mut(input_node)?;
287 input_node.edges.add_input_edge(edge)?;
288
289 Ok(())
290 }
291
292 pub fn add_slot_edge(
303 &mut self,
304 output_node: impl RenderLabel,
305 output_slot: impl Into<SlotLabel>,
306 input_node: impl RenderLabel,
307 input_slot: impl Into<SlotLabel>,
308 ) {
309 self.try_add_slot_edge(output_node, output_slot, input_node, input_slot)
310 .unwrap();
311 }
312
313 pub fn remove_slot_edge(
316 &mut self,
317 output_node: impl RenderLabel,
318 output_slot: impl Into<SlotLabel>,
319 input_node: impl RenderLabel,
320 input_slot: impl Into<SlotLabel>,
321 ) -> Result<(), RenderGraphError> {
322 let output_slot = output_slot.into();
323 let input_slot = input_slot.into();
324
325 let output_node = output_node.intern();
326 let input_node = input_node.intern();
327
328 let output_index = self
329 .get_node_state(output_node)?
330 .output_slots
331 .get_slot_index(output_slot.clone())
332 .ok_or(RenderGraphError::InvalidOutputNodeSlot(output_slot))?;
333 let input_index = self
334 .get_node_state(input_node)?
335 .input_slots
336 .get_slot_index(input_slot.clone())
337 .ok_or(RenderGraphError::InvalidInputNodeSlot(input_slot))?;
338
339 let edge = Edge::SlotEdge {
340 output_node,
341 output_index,
342 input_node,
343 input_index,
344 };
345
346 self.validate_edge(&edge, EdgeExistence::Exists)?;
347
348 {
349 let output_node = self.get_node_state_mut(output_node)?;
350 output_node.edges.remove_output_edge(edge.clone())?;
351 }
352 let input_node = self.get_node_state_mut(input_node)?;
353 input_node.edges.remove_input_edge(edge)?;
354
355 Ok(())
356 }
357
358 pub fn try_add_node_edge(
367 &mut self,
368 output_node: impl RenderLabel,
369 input_node: impl RenderLabel,
370 ) -> Result<(), RenderGraphError> {
371 let output_node = output_node.intern();
372 let input_node = input_node.intern();
373
374 let edge = Edge::NodeEdge {
375 output_node,
376 input_node,
377 };
378
379 self.validate_edge(&edge, EdgeExistence::DoesNotExist)?;
380
381 {
382 let output_node = self.get_node_state_mut(output_node)?;
383 output_node.edges.add_output_edge(edge.clone())?;
384 }
385 let input_node = self.get_node_state_mut(input_node)?;
386 input_node.edges.add_input_edge(edge)?;
387
388 Ok(())
389 }
390
391 pub fn add_node_edge(&mut self, output_node: impl RenderLabel, input_node: impl RenderLabel) {
402 self.try_add_node_edge(output_node, input_node).unwrap();
403 }
404
405 pub fn remove_node_edge(
408 &mut self,
409 output_node: impl RenderLabel,
410 input_node: impl RenderLabel,
411 ) -> Result<(), RenderGraphError> {
412 let output_node = output_node.intern();
413 let input_node = input_node.intern();
414
415 let edge = Edge::NodeEdge {
416 output_node,
417 input_node,
418 };
419
420 self.validate_edge(&edge, EdgeExistence::Exists)?;
421
422 {
423 let output_node = self.get_node_state_mut(output_node)?;
424 output_node.edges.remove_output_edge(edge.clone())?;
425 }
426 let input_node = self.get_node_state_mut(input_node)?;
427 input_node.edges.remove_input_edge(edge)?;
428
429 Ok(())
430 }
431
432 pub fn validate_edge(
435 &mut self,
436 edge: &Edge,
437 should_exist: EdgeExistence,
438 ) -> Result<(), RenderGraphError> {
439 if should_exist == EdgeExistence::Exists && !self.has_edge(edge) {
440 return Err(RenderGraphError::EdgeDoesNotExist(edge.clone()));
441 } else if should_exist == EdgeExistence::DoesNotExist && self.has_edge(edge) {
442 return Err(RenderGraphError::EdgeAlreadyExists(edge.clone()));
443 }
444
445 match *edge {
446 Edge::SlotEdge {
447 output_node,
448 output_index,
449 input_node,
450 input_index,
451 } => {
452 let output_node_state = self.get_node_state(output_node)?;
453 let input_node_state = self.get_node_state(input_node)?;
454
455 let output_slot = output_node_state
456 .output_slots
457 .get_slot(output_index)
458 .ok_or(RenderGraphError::InvalidOutputNodeSlot(SlotLabel::Index(
459 output_index,
460 )))?;
461 let input_slot = input_node_state.input_slots.get_slot(input_index).ok_or(
462 RenderGraphError::InvalidInputNodeSlot(SlotLabel::Index(input_index)),
463 )?;
464
465 if let Some(Edge::SlotEdge {
466 output_node: current_output_node,
467 ..
468 }) = input_node_state.edges.input_edges().iter().find(|e| {
469 if let Edge::SlotEdge {
470 input_index: current_input_index,
471 ..
472 } = e
473 {
474 input_index == *current_input_index
475 } else {
476 false
477 }
478 }) {
479 if should_exist == EdgeExistence::DoesNotExist {
480 return Err(RenderGraphError::NodeInputSlotAlreadyOccupied {
481 node: input_node,
482 input_slot: input_index,
483 occupied_by_node: *current_output_node,
484 });
485 }
486 }
487
488 if output_slot.slot_type != input_slot.slot_type {
489 return Err(RenderGraphError::MismatchedNodeSlots {
490 output_node,
491 output_slot: output_index,
492 input_node,
493 input_slot: input_index,
494 });
495 }
496 }
497 Edge::NodeEdge { .. } => { }
498 }
499
500 Ok(())
501 }
502
503 pub fn has_edge(&self, edge: &Edge) -> bool {
505 let output_node_state = self.get_node_state(edge.get_output_node());
506 let input_node_state = self.get_node_state(edge.get_input_node());
507 if let Ok(output_node_state) = output_node_state {
508 if output_node_state.edges.output_edges().contains(edge) {
509 if let Ok(input_node_state) = input_node_state {
510 if input_node_state.edges.input_edges().contains(edge) {
511 return true;
512 }
513 }
514 }
515 }
516
517 false
518 }
519
520 pub fn iter_nodes(&self) -> impl Iterator<Item = &NodeState> {
522 self.nodes.values()
523 }
524
525 pub fn iter_nodes_mut(&mut self) -> impl Iterator<Item = &mut NodeState> {
527 self.nodes.values_mut()
528 }
529
530 pub fn iter_sub_graphs(&self) -> impl Iterator<Item = (InternedRenderSubGraph, &RenderGraph)> {
532 self.sub_graphs.iter().map(|(name, graph)| (*name, graph))
533 }
534
535 pub fn iter_sub_graphs_mut(
537 &mut self,
538 ) -> impl Iterator<Item = (InternedRenderSubGraph, &mut RenderGraph)> {
539 self.sub_graphs
540 .iter_mut()
541 .map(|(name, graph)| (*name, graph))
542 }
543
544 pub fn iter_node_inputs(
547 &self,
548 label: impl RenderLabel,
549 ) -> Result<impl Iterator<Item = (&Edge, &NodeState)>, RenderGraphError> {
550 let node = self.get_node_state(label)?;
551 Ok(node
552 .edges
553 .input_edges()
554 .iter()
555 .map(|edge| (edge, edge.get_output_node()))
556 .map(move |(edge, output_node)| (edge, self.get_node_state(output_node).unwrap())))
557 }
558
559 pub fn iter_node_outputs(
562 &self,
563 label: impl RenderLabel,
564 ) -> Result<impl Iterator<Item = (&Edge, &NodeState)>, RenderGraphError> {
565 let node = self.get_node_state(label)?;
566 Ok(node
567 .edges
568 .output_edges()
569 .iter()
570 .map(|edge| (edge, edge.get_input_node()))
571 .map(move |(edge, input_node)| (edge, self.get_node_state(input_node).unwrap())))
572 }
573
574 pub fn add_sub_graph(&mut self, label: impl RenderSubGraph, sub_graph: RenderGraph) {
577 self.sub_graphs.insert(label.intern(), sub_graph);
578 }
579
580 pub fn remove_sub_graph(&mut self, label: impl RenderSubGraph) {
583 self.sub_graphs.remove(&label.intern());
584 }
585
586 pub fn get_sub_graph(&self, label: impl RenderSubGraph) -> Option<&RenderGraph> {
588 self.sub_graphs.get(&label.intern())
589 }
590
591 pub fn get_sub_graph_mut(&mut self, label: impl RenderSubGraph) -> Option<&mut RenderGraph> {
593 self.sub_graphs.get_mut(&label.intern())
594 }
595
596 pub fn sub_graph(&self, label: impl RenderSubGraph) -> &RenderGraph {
606 let label = label.intern();
607 self.sub_graphs
608 .get(&label)
609 .unwrap_or_else(|| panic!("Subgraph {label:?} not found"))
610 }
611
612 pub fn sub_graph_mut(&mut self, label: impl RenderSubGraph) -> &mut RenderGraph {
622 let label = label.intern();
623 self.sub_graphs
624 .get_mut(&label)
625 .unwrap_or_else(|| panic!("Subgraph {label:?} not found"))
626 }
627}
628
629impl Debug for RenderGraph {
630 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
631 for node in self.iter_nodes() {
632 writeln!(f, "{:?}", node.label)?;
633 writeln!(f, " in: {:?}", node.input_slots)?;
634 writeln!(f, " out: {:?}", node.output_slots)?;
635 }
636
637 Ok(())
638 }
639}
640
/// The [`Node`] that [`RenderGraph::set_input`] adds under the [`GraphInput`] label.
///
/// It exposes the same slots as both inputs and outputs and copies each input value to
/// the output slot with the same index when run.
pub struct GraphInputNode {
    // Slots reported by both `input()` and `output()`.
    inputs: Vec<SlotInfo>,
}
646
647impl Node for GraphInputNode {
648 fn input(&self) -> Vec<SlotInfo> {
649 self.inputs.clone()
650 }
651
652 fn output(&self) -> Vec<SlotInfo> {
653 self.inputs.clone()
654 }
655
656 fn run(
657 &self,
658 graph: &mut RenderGraphContext,
659 _render_context: &mut RenderContext,
660 _world: &World,
661 ) -> Result<(), NodeRunError> {
662 for i in 0..graph.inputs().len() {
663 let input = graph.inputs()[i].clone();
664 graph.set_output(i, input)?;
665 }
666 Ok(())
667 }
668}
669
#[cfg(test)]
mod tests {
    use crate::{
        render_graph::{
            node::IntoRenderNodeArray, Edge, InternedRenderLabel, Node, NodeRunError, RenderGraph,
            RenderGraphContext, RenderGraphError, RenderLabel, SlotInfo, SlotType,
        },
        renderer::RenderContext,
    };
    use bevy_ecs::world::{FromWorld, World};
    use bevy_utils::HashSet;

    #[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)]
    enum TestLabel {
        A,
        B,
        C,
        D,
    }

    /// A node with `inputs` texture-view input slots and `outputs` texture-view output slots.
    #[derive(Debug)]
    struct TestNode {
        inputs: Vec<SlotInfo>,
        outputs: Vec<SlotInfo>,
    }

    impl TestNode {
        pub fn new(inputs: usize, outputs: usize) -> Self {
            TestNode {
                inputs: (0..inputs)
                    .map(|i| SlotInfo::new(format!("in_{i}"), SlotType::TextureView))
                    .collect(),
                outputs: (0..outputs)
                    .map(|i| SlotInfo::new(format!("out_{i}"), SlotType::TextureView))
                    .collect(),
            }
        }
    }

    impl Node for TestNode {
        fn input(&self) -> Vec<SlotInfo> {
            self.inputs.clone()
        }

        fn output(&self) -> Vec<SlotInfo> {
            self.outputs.clone()
        }

        fn run(
            &self,
            _: &mut RenderGraphContext,
            _: &mut RenderContext,
            _: &World,
        ) -> Result<(), NodeRunError> {
            Ok(())
        }
    }

    /// Labels of the nodes feeding into `label`.
    fn input_nodes(label: impl RenderLabel, graph: &RenderGraph) -> HashSet<InternedRenderLabel> {
        graph
            .iter_node_inputs(label)
            .unwrap()
            .map(|(_edge, node)| node.label)
            .collect::<HashSet<InternedRenderLabel>>()
    }

    /// Labels of the nodes that `label` feeds into.
    fn output_nodes(label: impl RenderLabel, graph: &RenderGraph) -> HashSet<InternedRenderLabel> {
        graph
            .iter_node_outputs(label)
            .unwrap()
            .map(|(_edge, node)| node.label)
            .collect::<HashSet<InternedRenderLabel>>()
    }

    #[test]
    fn test_graph_edges() {
        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(0, 1));
        graph.add_node(TestLabel::C, TestNode::new(1, 1));
        graph.add_node(TestLabel::D, TestNode::new(1, 0));

        graph.add_slot_edge(TestLabel::A, "out_0", TestLabel::C, "in_0");
        graph.add_node_edge(TestLabel::B, TestLabel::C);
        graph.add_slot_edge(TestLabel::C, 0, TestLabel::D, 0);

        assert!(
            input_nodes(TestLabel::A, &graph).is_empty(),
            "A has no inputs"
        );
        assert_eq!(
            output_nodes(TestLabel::A, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "A outputs to C"
        );

        assert!(
            input_nodes(TestLabel::B, &graph).is_empty(),
            "B has no inputs"
        );
        assert_eq!(
            output_nodes(TestLabel::B, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "B outputs to C"
        );

        assert_eq!(
            input_nodes(TestLabel::C, &graph),
            HashSet::from_iter((TestLabel::A, TestLabel::B).into_array()),
            "A and B input to C"
        );
        assert_eq!(
            output_nodes(TestLabel::C, &graph),
            HashSet::from_iter((TestLabel::D,).into_array()),
            "C outputs to D"
        );

        assert_eq!(
            input_nodes(TestLabel::D, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "C inputs to D"
        );
        assert!(
            output_nodes(TestLabel::D, &graph).is_empty(),
            "D has no outputs"
        );
    }

    #[test]
    fn test_get_node_typed() {
        struct MyNode {
            value: usize,
        }

        impl Node for MyNode {
            fn run(
                &self,
                _: &mut RenderGraphContext,
                _: &mut RenderContext,
                _: &World,
            ) -> Result<(), NodeRunError> {
                Ok(())
            }
        }

        let mut graph = RenderGraph::default();

        graph.add_node(TestLabel::A, MyNode { value: 42 });

        let node: &MyNode = graph.get_node(TestLabel::A).unwrap();
        assert_eq!(node.value, 42, "node value matches");

        let result: Result<&TestNode, RenderGraphError> = graph.get_node(TestLabel::A);
        assert_eq!(
            result.unwrap_err(),
            RenderGraphError::WrongNodeType,
            "expect a wrong node type error"
        );
    }

    #[test]
    fn test_slot_already_occupied() {
        let mut graph = RenderGraph::default();

        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(0, 1));
        graph.add_node(TestLabel::C, TestNode::new(1, 1));

        graph.add_slot_edge(TestLabel::A, 0, TestLabel::C, 0);
        assert_eq!(
            graph.try_add_slot_edge(TestLabel::B, 0, TestLabel::C, 0),
            Err(RenderGraphError::NodeInputSlotAlreadyOccupied {
                node: TestLabel::C.intern(),
                input_slot: 0,
                occupied_by_node: TestLabel::A.intern(),
            }),
            "Adding to a slot that is already occupied should return an error"
        );
    }

    #[test]
    fn test_edge_already_exists() {
        let mut graph = RenderGraph::default();

        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(1, 0));

        graph.add_slot_edge(TestLabel::A, 0, TestLabel::B, 0);
        assert_eq!(
            graph.try_add_slot_edge(TestLabel::A, 0, TestLabel::B, 0),
            Err(RenderGraphError::EdgeAlreadyExists(Edge::SlotEdge {
                output_node: TestLabel::A.intern(),
                output_index: 0,
                input_node: TestLabel::B.intern(),
                input_index: 0,
            })),
            "Adding to a duplicate edge should return an error"
        );
    }

    #[test]
    fn test_add_node_edges() {
        struct SimpleNode;
        impl Node for SimpleNode {
            fn run(
                &self,
                _graph: &mut RenderGraphContext,
                _render_context: &mut RenderContext,
                _world: &World,
            ) -> Result<(), NodeRunError> {
                Ok(())
            }
        }
        impl FromWorld for SimpleNode {
            fn from_world(_world: &mut World) -> Self {
                Self
            }
        }

        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, SimpleNode);
        graph.add_node(TestLabel::B, SimpleNode);
        graph.add_node(TestLabel::C, SimpleNode);

        graph.add_node_edges((TestLabel::A, TestLabel::B, TestLabel::C));

        assert_eq!(
            output_nodes(TestLabel::A, &graph),
            HashSet::from_iter((TestLabel::B,).into_array()),
            "A -> B"
        );
        assert_eq!(
            input_nodes(TestLabel::B, &graph),
            HashSet::from_iter((TestLabel::A,).into_array()),
            "A -> B"
        );
        assert_eq!(
            output_nodes(TestLabel::B, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "B -> C"
        );
        assert_eq!(
            input_nodes(TestLabel::C, &graph),
            HashSet::from_iter((TestLabel::B,).into_array()),
            "B -> C"
        );
    }

    /// Removing a node must also remove the edges its neighbours hold towards it.
    #[test]
    fn test_remove_node_unlinks_neighbors() {
        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(1, 1));
        graph.add_node(TestLabel::C, TestNode::new(1, 0));

        graph.add_slot_edge(TestLabel::A, 0, TestLabel::B, 0);
        graph.add_slot_edge(TestLabel::B, 0, TestLabel::C, 0);

        assert_eq!(graph.remove_node(TestLabel::B), Ok(()), "B is removed");
        assert!(
            graph.get_node_state(TestLabel::B).is_err(),
            "B is no longer in the graph"
        );
        assert!(
            output_nodes(TestLabel::A, &graph).is_empty(),
            "A no longer outputs to B"
        );
        assert!(
            input_nodes(TestLabel::C, &graph).is_empty(),
            "C no longer receives from B"
        );

        assert_eq!(
            graph.remove_node(TestLabel::D),
            Ok(()),
            "removing an unknown node is a no-op"
        );
    }

    #[test]
    fn test_remove_node_edge() {
        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, TestNode::new(0, 0));
        graph.add_node(TestLabel::B, TestNode::new(0, 0));
        graph.add_node_edge(TestLabel::A, TestLabel::B);

        assert_eq!(graph.remove_node_edge(TestLabel::A, TestLabel::B), Ok(()));
        assert!(output_nodes(TestLabel::A, &graph).is_empty(), "A -/-> B");
        assert!(input_nodes(TestLabel::B, &graph).is_empty(), "A -/-> B");

        assert_eq!(
            graph.remove_node_edge(TestLabel::A, TestLabel::B),
            Err(RenderGraphError::EdgeDoesNotExist(Edge::NodeEdge {
                output_node: TestLabel::A.intern(),
                input_node: TestLabel::B.intern(),
            })),
            "Removing a missing edge should return an error"
        );
    }

    #[test]
    fn test_set_input() {
        let mut graph = RenderGraph::default();
        assert!(
            graph.get_input_node().is_none(),
            "a new graph has no input node"
        );

        graph.set_input(vec![SlotInfo::new("view", SlotType::TextureView)]);
        assert!(
            graph.get_input_node().is_some(),
            "set_input adds an input node"
        );
    }

    #[test]
    #[should_panic(expected = "Graph already has an input node")]
    fn test_set_input_twice_panics() {
        let mut graph = RenderGraph::default();
        graph.set_input(Vec::new());
        graph.set_input(Vec::new());
    }
}