1use crate::{
2 render_graph::{
3 Edge, Node, NodeRunError, NodeState, RenderGraphContext, RenderGraphError, RenderLabel,
4 SlotInfo, SlotLabel,
5 },
6 renderer::RenderContext,
7};
8use bevy_ecs::{define_label, intern::Interned, prelude::World, resource::Resource};
9use bevy_platform::collections::HashMap;
10use core::fmt::Debug;
11
12use super::{EdgeExistence, InternedRenderLabel, IntoRenderNodeArray};
13
14pub use bevy_render_macros::RenderSubGraph;
15
define_label!(
    #[diagnostic::on_unimplemented(
        note = "consider annotating `{Self}` with `#[derive(RenderSubGraph)]`"
    )]
    /// A strongly-typed class of labels used to identify a sub-graph of a [`RenderGraph`].
    ///
    /// Sub-graphs are registered with [`RenderGraph::add_sub_graph`] and looked up with
    /// [`RenderGraph::get_sub_graph`] / [`RenderGraph::sub_graph`].
    RenderSubGraph,
    RENDER_SUB_GRAPH_INTERNER
);

/// An interned [`RenderSubGraph`]; this is the key type of a [`RenderGraph`]'s sub-graph map.
pub type InternedRenderSubGraph = Interned<dyn RenderSubGraph>;
27
/// A graph of render [`Node`]s keyed by [`RenderLabel`], plus any number of named sub-graphs.
///
/// Nodes are added with [`RenderGraph::add_node`] and connected either with slot edges
/// ([`RenderGraph::add_slot_edge`]) or with node edges ([`RenderGraph::add_node_edge`]), which
/// carry no slot data. Every edge is recorded on both of its endpoint [`NodeState`]s.
#[derive(Resource, Default)]
pub struct RenderGraph {
    // All nodes of this graph, keyed by their interned label.
    nodes: HashMap<InternedRenderLabel, NodeState>,
    // Nested graphs, keyed by their interned sub-graph label.
    sub_graphs: HashMap<InternedRenderSubGraph, RenderGraph>,
}

/// The [`RenderLabel`] under which [`RenderGraph::set_input`] registers the graph's input node.
#[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)]
pub struct GraphInput;
83
84impl RenderGraph {
85 pub fn update(&mut self, world: &mut World) {
87 for node in self.nodes.values_mut() {
88 node.node.update(world);
89 }
90
91 for sub_graph in self.sub_graphs.values_mut() {
92 sub_graph.update(world);
93 }
94 }
95
96 pub fn set_input(&mut self, inputs: Vec<SlotInfo>) {
98 assert!(
99 matches!(
100 self.get_node_state(GraphInput),
101 Err(RenderGraphError::InvalidNode(_))
102 ),
103 "Graph already has an input node"
104 );
105
106 self.add_node(GraphInput, GraphInputNode { inputs });
107 }
108
109 #[inline]
115 pub fn get_input_node(&self) -> Option<&NodeState> {
116 self.get_node_state(GraphInput).ok()
117 }
118
119 #[inline]
129 pub fn input_node(&self) -> &NodeState {
130 self.get_input_node().unwrap()
131 }
132
133 pub fn add_node<T>(&mut self, label: impl RenderLabel, node: T)
136 where
137 T: Node,
138 {
139 let label = label.intern();
140 let node_state = NodeState::new(label, node);
141 self.nodes.insert(label, node_state);
142 }
143
144 #[track_caller]
149 pub fn add_node_edges<const N: usize>(&mut self, edges: impl IntoRenderNodeArray<N>) {
150 for window in edges.into_array().windows(2) {
151 let [a, b] = window else {
152 break;
153 };
154 if let Err(err) = self.try_add_node_edge(*a, *b) {
155 match err {
156 RenderGraphError::EdgeAlreadyExists(_) => {}
159 _ => panic!("{err}"),
160 }
161 }
162 }
163 }
164
165 pub fn remove_node(&mut self, label: impl RenderLabel) -> Result<(), RenderGraphError> {
168 let label = label.intern();
169 if let Some(node_state) = self.nodes.remove(&label) {
170 for input_edge in node_state.edges.input_edges() {
173 match input_edge {
174 Edge::SlotEdge { output_node, .. }
175 | Edge::NodeEdge {
176 input_node: _,
177 output_node,
178 } => {
179 if let Ok(output_node) = self.get_node_state_mut(*output_node) {
180 output_node.edges.remove_output_edge(input_edge.clone())?;
181 }
182 }
183 }
184 }
185 for output_edge in node_state.edges.output_edges() {
188 match output_edge {
189 Edge::SlotEdge {
190 output_node: _,
191 output_index: _,
192 input_node,
193 input_index: _,
194 }
195 | Edge::NodeEdge {
196 output_node: _,
197 input_node,
198 } => {
199 if let Ok(input_node) = self.get_node_state_mut(*input_node) {
200 input_node.edges.remove_input_edge(output_edge.clone())?;
201 }
202 }
203 }
204 }
205 }
206
207 Ok(())
208 }
209
210 pub fn get_node_state(&self, label: impl RenderLabel) -> Result<&NodeState, RenderGraphError> {
212 let label = label.intern();
213 self.nodes
214 .get(&label)
215 .ok_or(RenderGraphError::InvalidNode(label))
216 }
217
218 pub fn get_node_state_mut(
220 &mut self,
221 label: impl RenderLabel,
222 ) -> Result<&mut NodeState, RenderGraphError> {
223 let label = label.intern();
224 self.nodes
225 .get_mut(&label)
226 .ok_or(RenderGraphError::InvalidNode(label))
227 }
228
229 pub fn get_node<T>(&self, label: impl RenderLabel) -> Result<&T, RenderGraphError>
231 where
232 T: Node,
233 {
234 self.get_node_state(label).and_then(|n| n.node())
235 }
236
237 pub fn get_node_mut<T>(&mut self, label: impl RenderLabel) -> Result<&mut T, RenderGraphError>
239 where
240 T: Node,
241 {
242 self.get_node_state_mut(label).and_then(|n| n.node_mut())
243 }
244
245 pub fn try_add_slot_edge(
254 &mut self,
255 output_node: impl RenderLabel,
256 output_slot: impl Into<SlotLabel>,
257 input_node: impl RenderLabel,
258 input_slot: impl Into<SlotLabel>,
259 ) -> Result<(), RenderGraphError> {
260 let output_slot = output_slot.into();
261 let input_slot = input_slot.into();
262
263 let output_node = output_node.intern();
264 let input_node = input_node.intern();
265
266 let output_index = self
267 .get_node_state(output_node)?
268 .output_slots
269 .get_slot_index(output_slot.clone())
270 .ok_or(RenderGraphError::InvalidOutputNodeSlot(output_slot))?;
271 let input_index = self
272 .get_node_state(input_node)?
273 .input_slots
274 .get_slot_index(input_slot.clone())
275 .ok_or(RenderGraphError::InvalidInputNodeSlot(input_slot))?;
276
277 let edge = Edge::SlotEdge {
278 output_node,
279 output_index,
280 input_node,
281 input_index,
282 };
283
284 self.validate_edge(&edge, EdgeExistence::DoesNotExist)?;
285
286 {
287 let output_node = self.get_node_state_mut(output_node)?;
288 output_node.edges.add_output_edge(edge.clone())?;
289 }
290 let input_node = self.get_node_state_mut(input_node)?;
291 input_node.edges.add_input_edge(edge)?;
292
293 Ok(())
294 }
295
296 pub fn add_slot_edge(
307 &mut self,
308 output_node: impl RenderLabel,
309 output_slot: impl Into<SlotLabel>,
310 input_node: impl RenderLabel,
311 input_slot: impl Into<SlotLabel>,
312 ) {
313 self.try_add_slot_edge(output_node, output_slot, input_node, input_slot)
314 .unwrap();
315 }
316
317 pub fn remove_slot_edge(
320 &mut self,
321 output_node: impl RenderLabel,
322 output_slot: impl Into<SlotLabel>,
323 input_node: impl RenderLabel,
324 input_slot: impl Into<SlotLabel>,
325 ) -> Result<(), RenderGraphError> {
326 let output_slot = output_slot.into();
327 let input_slot = input_slot.into();
328
329 let output_node = output_node.intern();
330 let input_node = input_node.intern();
331
332 let output_index = self
333 .get_node_state(output_node)?
334 .output_slots
335 .get_slot_index(output_slot.clone())
336 .ok_or(RenderGraphError::InvalidOutputNodeSlot(output_slot))?;
337 let input_index = self
338 .get_node_state(input_node)?
339 .input_slots
340 .get_slot_index(input_slot.clone())
341 .ok_or(RenderGraphError::InvalidInputNodeSlot(input_slot))?;
342
343 let edge = Edge::SlotEdge {
344 output_node,
345 output_index,
346 input_node,
347 input_index,
348 };
349
350 self.validate_edge(&edge, EdgeExistence::Exists)?;
351
352 {
353 let output_node = self.get_node_state_mut(output_node)?;
354 output_node.edges.remove_output_edge(edge.clone())?;
355 }
356 let input_node = self.get_node_state_mut(input_node)?;
357 input_node.edges.remove_input_edge(edge)?;
358
359 Ok(())
360 }
361
362 pub fn try_add_node_edge(
371 &mut self,
372 output_node: impl RenderLabel,
373 input_node: impl RenderLabel,
374 ) -> Result<(), RenderGraphError> {
375 let output_node = output_node.intern();
376 let input_node = input_node.intern();
377
378 let edge = Edge::NodeEdge {
379 output_node,
380 input_node,
381 };
382
383 self.validate_edge(&edge, EdgeExistence::DoesNotExist)?;
384
385 {
386 let output_node = self.get_node_state_mut(output_node)?;
387 output_node.edges.add_output_edge(edge.clone())?;
388 }
389 let input_node = self.get_node_state_mut(input_node)?;
390 input_node.edges.add_input_edge(edge)?;
391
392 Ok(())
393 }
394
395 pub fn add_node_edge(&mut self, output_node: impl RenderLabel, input_node: impl RenderLabel) {
406 self.try_add_node_edge(output_node, input_node).unwrap();
407 }
408
409 pub fn remove_node_edge(
412 &mut self,
413 output_node: impl RenderLabel,
414 input_node: impl RenderLabel,
415 ) -> Result<(), RenderGraphError> {
416 let output_node = output_node.intern();
417 let input_node = input_node.intern();
418
419 let edge = Edge::NodeEdge {
420 output_node,
421 input_node,
422 };
423
424 self.validate_edge(&edge, EdgeExistence::Exists)?;
425
426 {
427 let output_node = self.get_node_state_mut(output_node)?;
428 output_node.edges.remove_output_edge(edge.clone())?;
429 }
430 let input_node = self.get_node_state_mut(input_node)?;
431 input_node.edges.remove_input_edge(edge)?;
432
433 Ok(())
434 }
435
436 pub fn validate_edge(
439 &mut self,
440 edge: &Edge,
441 should_exist: EdgeExistence,
442 ) -> Result<(), RenderGraphError> {
443 if should_exist == EdgeExistence::Exists && !self.has_edge(edge) {
444 return Err(RenderGraphError::EdgeDoesNotExist(edge.clone()));
445 } else if should_exist == EdgeExistence::DoesNotExist && self.has_edge(edge) {
446 return Err(RenderGraphError::EdgeAlreadyExists(edge.clone()));
447 }
448
449 match *edge {
450 Edge::SlotEdge {
451 output_node,
452 output_index,
453 input_node,
454 input_index,
455 } => {
456 let output_node_state = self.get_node_state(output_node)?;
457 let input_node_state = self.get_node_state(input_node)?;
458
459 let output_slot = output_node_state
460 .output_slots
461 .get_slot(output_index)
462 .ok_or(RenderGraphError::InvalidOutputNodeSlot(SlotLabel::Index(
463 output_index,
464 )))?;
465 let input_slot = input_node_state.input_slots.get_slot(input_index).ok_or(
466 RenderGraphError::InvalidInputNodeSlot(SlotLabel::Index(input_index)),
467 )?;
468
469 if let Some(Edge::SlotEdge {
470 output_node: current_output_node,
471 ..
472 }) = input_node_state.edges.input_edges().iter().find(|e| {
473 if let Edge::SlotEdge {
474 input_index: current_input_index,
475 ..
476 } = e
477 {
478 input_index == *current_input_index
479 } else {
480 false
481 }
482 }) && should_exist == EdgeExistence::DoesNotExist
483 {
484 return Err(RenderGraphError::NodeInputSlotAlreadyOccupied {
485 node: input_node,
486 input_slot: input_index,
487 occupied_by_node: *current_output_node,
488 });
489 }
490
491 if output_slot.slot_type != input_slot.slot_type {
492 return Err(RenderGraphError::MismatchedNodeSlots {
493 output_node,
494 output_slot: output_index,
495 input_node,
496 input_slot: input_index,
497 });
498 }
499 }
500 Edge::NodeEdge { .. } => { }
501 }
502
503 Ok(())
504 }
505
506 pub fn has_edge(&self, edge: &Edge) -> bool {
508 let output_node_state = self.get_node_state(edge.get_output_node());
509 let input_node_state = self.get_node_state(edge.get_input_node());
510 if let Ok(output_node_state) = output_node_state
511 && output_node_state.edges.output_edges().contains(edge)
512 && let Ok(input_node_state) = input_node_state
513 && input_node_state.edges.input_edges().contains(edge)
514 {
515 return true;
516 }
517
518 false
519 }
520
521 pub fn iter_nodes(&self) -> impl Iterator<Item = &NodeState> {
523 self.nodes.values()
524 }
525
526 pub fn iter_nodes_mut(&mut self) -> impl Iterator<Item = &mut NodeState> {
528 self.nodes.values_mut()
529 }
530
531 pub fn iter_sub_graphs(&self) -> impl Iterator<Item = (InternedRenderSubGraph, &RenderGraph)> {
533 self.sub_graphs.iter().map(|(name, graph)| (*name, graph))
534 }
535
536 pub fn iter_sub_graphs_mut(
538 &mut self,
539 ) -> impl Iterator<Item = (InternedRenderSubGraph, &mut RenderGraph)> {
540 self.sub_graphs
541 .iter_mut()
542 .map(|(name, graph)| (*name, graph))
543 }
544
545 pub fn iter_node_inputs(
548 &self,
549 label: impl RenderLabel,
550 ) -> Result<impl Iterator<Item = (&Edge, &NodeState)>, RenderGraphError> {
551 let node = self.get_node_state(label)?;
552 Ok(node
553 .edges
554 .input_edges()
555 .iter()
556 .map(|edge| (edge, edge.get_output_node()))
557 .map(move |(edge, output_node)| (edge, self.get_node_state(output_node).unwrap())))
558 }
559
560 pub fn iter_node_outputs(
563 &self,
564 label: impl RenderLabel,
565 ) -> Result<impl Iterator<Item = (&Edge, &NodeState)>, RenderGraphError> {
566 let node = self.get_node_state(label)?;
567 Ok(node
568 .edges
569 .output_edges()
570 .iter()
571 .map(|edge| (edge, edge.get_input_node()))
572 .map(move |(edge, input_node)| (edge, self.get_node_state(input_node).unwrap())))
573 }
574
575 pub fn add_sub_graph(&mut self, label: impl RenderSubGraph, sub_graph: RenderGraph) {
578 self.sub_graphs.insert(label.intern(), sub_graph);
579 }
580
581 pub fn remove_sub_graph(&mut self, label: impl RenderSubGraph) {
584 self.sub_graphs.remove(&label.intern());
585 }
586
587 pub fn get_sub_graph(&self, label: impl RenderSubGraph) -> Option<&RenderGraph> {
589 self.sub_graphs.get(&label.intern())
590 }
591
592 pub fn get_sub_graph_mut(&mut self, label: impl RenderSubGraph) -> Option<&mut RenderGraph> {
594 self.sub_graphs.get_mut(&label.intern())
595 }
596
597 pub fn sub_graph(&self, label: impl RenderSubGraph) -> &RenderGraph {
607 let label = label.intern();
608 self.sub_graphs
609 .get(&label)
610 .unwrap_or_else(|| panic!("Subgraph {label:?} not found"))
611 }
612
613 pub fn sub_graph_mut(&mut self, label: impl RenderSubGraph) -> &mut RenderGraph {
623 let label = label.intern();
624 self.sub_graphs
625 .get_mut(&label)
626 .unwrap_or_else(|| panic!("Subgraph {label:?} not found"))
627 }
628}
629
630impl Debug for RenderGraph {
631 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
632 for node in self.iter_nodes() {
633 writeln!(f, "{:?}", node.label)?;
634 writeln!(f, " in: {:?}", node.input_slots)?;
635 writeln!(f, " out: {:?}", node.output_slots)?;
636 }
637
638 Ok(())
639 }
640}
641
642pub struct GraphInputNode {
645 inputs: Vec<SlotInfo>,
646}
647
648impl Node for GraphInputNode {
649 fn input(&self) -> Vec<SlotInfo> {
650 self.inputs.clone()
651 }
652
653 fn output(&self) -> Vec<SlotInfo> {
654 self.inputs.clone()
655 }
656
657 fn run(
658 &self,
659 graph: &mut RenderGraphContext,
660 _render_context: &mut RenderContext,
661 _world: &World,
662 ) -> Result<(), NodeRunError> {
663 for i in 0..graph.inputs().len() {
664 let input = graph.inputs()[i].clone();
665 graph.set_output(i, input)?;
666 }
667 Ok(())
668 }
669}
670
#[cfg(test)]
mod tests {
    use crate::{
        render_graph::{
            node::IntoRenderNodeArray, Edge, InternedRenderLabel, Node, NodeRunError, RenderGraph,
            RenderGraphContext, RenderGraphError, RenderLabel, SlotInfo, SlotType,
        },
        renderer::RenderContext,
    };
    use bevy_ecs::world::{FromWorld, World};
    use bevy_platform::collections::HashSet;

    #[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)]
    enum TestLabel {
        A,
        B,
        C,
        D,
    }

    /// A no-op node with `inputs` texture-view input slots named `in_{i}` and `outputs`
    /// texture-view output slots named `out_{i}`.
    #[derive(Debug)]
    struct TestNode {
        inputs: Vec<SlotInfo>,
        outputs: Vec<SlotInfo>,
    }

    impl TestNode {
        pub fn new(inputs: usize, outputs: usize) -> Self {
            TestNode {
                inputs: (0..inputs)
                    .map(|i| SlotInfo::new(format!("in_{i}"), SlotType::TextureView))
                    .collect(),
                outputs: (0..outputs)
                    .map(|i| SlotInfo::new(format!("out_{i}"), SlotType::TextureView))
                    .collect(),
            }
        }
    }

    impl Node for TestNode {
        fn input(&self) -> Vec<SlotInfo> {
            self.inputs.clone()
        }

        fn output(&self) -> Vec<SlotInfo> {
            self.outputs.clone()
        }

        fn run(
            &self,
            _: &mut RenderGraphContext,
            _: &mut RenderContext,
            _: &World,
        ) -> Result<(), NodeRunError> {
            Ok(())
        }
    }

    /// Labels of the nodes that feed into `label`.
    fn input_nodes(label: impl RenderLabel, graph: &RenderGraph) -> HashSet<InternedRenderLabel> {
        graph
            .iter_node_inputs(label)
            .unwrap()
            .map(|(_edge, node)| node.label)
            .collect::<HashSet<InternedRenderLabel>>()
    }

    /// Labels of the nodes that `label` feeds into.
    fn output_nodes(label: impl RenderLabel, graph: &RenderGraph) -> HashSet<InternedRenderLabel> {
        graph
            .iter_node_outputs(label)
            .unwrap()
            .map(|(_edge, node)| node.label)
            .collect::<HashSet<InternedRenderLabel>>()
    }

    #[test]
    fn test_graph_edges() {
        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(0, 1));
        graph.add_node(TestLabel::C, TestNode::new(1, 1));
        graph.add_node(TestLabel::D, TestNode::new(1, 0));

        graph.add_slot_edge(TestLabel::A, "out_0", TestLabel::C, "in_0");
        graph.add_node_edge(TestLabel::B, TestLabel::C);
        graph.add_slot_edge(TestLabel::C, 0, TestLabel::D, 0);

        assert!(
            input_nodes(TestLabel::A, &graph).is_empty(),
            "A has no inputs"
        );
        assert_eq!(
            output_nodes(TestLabel::A, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "A outputs to C"
        );

        assert!(
            input_nodes(TestLabel::B, &graph).is_empty(),
            "B has no inputs"
        );
        assert_eq!(
            output_nodes(TestLabel::B, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "B outputs to C"
        );

        assert_eq!(
            input_nodes(TestLabel::C, &graph),
            HashSet::from_iter((TestLabel::A, TestLabel::B).into_array()),
            "A and B input to C"
        );
        assert_eq!(
            output_nodes(TestLabel::C, &graph),
            HashSet::from_iter((TestLabel::D,).into_array()),
            "C outputs to D"
        );

        assert_eq!(
            input_nodes(TestLabel::D, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "C inputs to D"
        );
        assert!(
            output_nodes(TestLabel::D, &graph).is_empty(),
            "D has no outputs"
        );
    }

    #[test]
    fn test_get_node_typed() {
        struct MyNode {
            value: usize,
        }

        impl Node for MyNode {
            fn run(
                &self,
                _: &mut RenderGraphContext,
                _: &mut RenderContext,
                _: &World,
            ) -> Result<(), NodeRunError> {
                Ok(())
            }
        }

        let mut graph = RenderGraph::default();

        graph.add_node(TestLabel::A, MyNode { value: 42 });

        let node: &MyNode = graph.get_node(TestLabel::A).unwrap();
        assert_eq!(node.value, 42, "node value matches");

        let result: Result<&TestNode, RenderGraphError> = graph.get_node(TestLabel::A);
        assert_eq!(
            result.unwrap_err(),
            RenderGraphError::WrongNodeType,
            "expect a wrong node type error"
        );
    }

    #[test]
    fn test_slot_already_occupied() {
        let mut graph = RenderGraph::default();

        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(0, 1));
        graph.add_node(TestLabel::C, TestNode::new(1, 1));

        graph.add_slot_edge(TestLabel::A, 0, TestLabel::C, 0);
        assert_eq!(
            graph.try_add_slot_edge(TestLabel::B, 0, TestLabel::C, 0),
            Err(RenderGraphError::NodeInputSlotAlreadyOccupied {
                node: TestLabel::C.intern(),
                input_slot: 0,
                occupied_by_node: TestLabel::A.intern(),
            }),
            "Adding to a slot that is already occupied should return an error"
        );
    }

    #[test]
    fn test_edge_already_exists() {
        let mut graph = RenderGraph::default();

        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(1, 0));

        graph.add_slot_edge(TestLabel::A, 0, TestLabel::B, 0);
        assert_eq!(
            graph.try_add_slot_edge(TestLabel::A, 0, TestLabel::B, 0),
            Err(RenderGraphError::EdgeAlreadyExists(Edge::SlotEdge {
                output_node: TestLabel::A.intern(),
                output_index: 0,
                input_node: TestLabel::B.intern(),
                input_index: 0,
            })),
            "Adding to a duplicate edge should return an error"
        );
    }

    #[test]
    fn test_add_node_edges() {
        struct SimpleNode;
        impl Node for SimpleNode {
            fn run(
                &self,
                _graph: &mut RenderGraphContext,
                _render_context: &mut RenderContext,
                _world: &World,
            ) -> Result<(), NodeRunError> {
                Ok(())
            }
        }
        impl FromWorld for SimpleNode {
            fn from_world(_world: &mut World) -> Self {
                Self
            }
        }

        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, SimpleNode);
        graph.add_node(TestLabel::B, SimpleNode);
        graph.add_node(TestLabel::C, SimpleNode);

        graph.add_node_edges((TestLabel::A, TestLabel::B, TestLabel::C));

        assert_eq!(
            output_nodes(TestLabel::A, &graph),
            HashSet::from_iter((TestLabel::B,).into_array()),
            "A -> B"
        );
        assert_eq!(
            input_nodes(TestLabel::B, &graph),
            HashSet::from_iter((TestLabel::A,).into_array()),
            "A -> B"
        );
        assert_eq!(
            output_nodes(TestLabel::B, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "B -> C"
        );
        assert_eq!(
            input_nodes(TestLabel::C, &graph),
            HashSet::from_iter((TestLabel::B,).into_array()),
            "B -> C"
        );
    }

    #[test]
    fn test_remove_node_cleans_up_edges() {
        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(1, 1));
        graph.add_node(TestLabel::C, TestNode::new(1, 0));

        graph.add_slot_edge(TestLabel::A, 0, TestLabel::B, 0);
        graph.add_slot_edge(TestLabel::B, 0, TestLabel::C, 0);

        graph.remove_node(TestLabel::B).unwrap();

        assert!(
            graph.get_node_state(TestLabel::B).is_err(),
            "B was removed"
        );
        assert!(
            output_nodes(TestLabel::A, &graph).is_empty(),
            "A no longer outputs to the removed node"
        );
        assert!(
            input_nodes(TestLabel::C, &graph).is_empty(),
            "C no longer receives input from the removed node"
        );
        assert_eq!(
            graph.try_add_slot_edge(TestLabel::A, 0, TestLabel::C, 0),
            Ok(()),
            "C's input slot is free again after its producer was removed"
        );

        assert_eq!(
            graph.remove_node(TestLabel::D),
            Ok(()),
            "removing a missing node is a no-op"
        );
    }

    #[test]
    fn test_remove_edges() {
        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(1, 0));

        graph.add_slot_edge(TestLabel::A, 0, TestLabel::B, 0);
        graph.add_node_edge(TestLabel::A, TestLabel::B);

        graph
            .remove_slot_edge(TestLabel::A, 0, TestLabel::B, 0)
            .unwrap();
        assert_eq!(
            output_nodes(TestLabel::A, &graph),
            HashSet::from_iter((TestLabel::B,).into_array()),
            "the node edge A -> B is still present"
        );

        graph.remove_node_edge(TestLabel::A, TestLabel::B).unwrap();
        assert!(
            output_nodes(TestLabel::A, &graph).is_empty(),
            "A has no outputs left"
        );
        assert!(
            input_nodes(TestLabel::B, &graph).is_empty(),
            "B has no inputs left"
        );

        assert_eq!(
            graph.remove_node_edge(TestLabel::A, TestLabel::B),
            Err(RenderGraphError::EdgeDoesNotExist(Edge::NodeEdge {
                output_node: TestLabel::A.intern(),
                input_node: TestLabel::B.intern(),
            })),
            "Removing a missing edge should return an error"
        );
    }
}