1use alloc::vec::Vec;
2use core::{
3 fmt::{self, Debug},
4 hash::{BuildHasher, Hash},
5 ops::{Deref, DerefMut},
6};
7
8use bevy_platform::{
9 collections::{HashMap, HashSet},
10 hash::FixedHasher,
11};
12use fixedbitset::FixedBitSet;
13use indexmap::IndexSet;
14use thiserror::Error;
15
16use crate::{
17 error::Result,
18 schedule::graph::{
19 index, row_col, DiGraph, DiGraphToposortError,
20 Direction::{Incoming, Outgoing},
21 GraphNodeId, UnGraph,
22 },
23};
24
25#[derive(Clone)]
27pub struct Dag<N: GraphNodeId, S: BuildHasher = FixedHasher> {
28 graph: DiGraph<N, S>,
30 toposort: Vec<N>,
33 dirty: bool,
35}
36
37impl<N: GraphNodeId, S: BuildHasher> Dag<N, S> {
38 pub fn new() -> Self
40 where
41 S: Default,
42 {
43 Self::default()
44 }
45
46 #[must_use]
48 pub fn graph(&self) -> &DiGraph<N, S> {
49 &self.graph
50 }
51
52 #[must_use = "This function marks the graph as dirty, so it should be used."]
54 pub fn graph_mut(&mut self) -> &mut DiGraph<N, S> {
55 self.dirty = true;
56 &mut self.graph
57 }
58
59 #[must_use]
62 pub fn is_dirty(&self) -> bool {
63 self.dirty
64 }
65
66 #[must_use]
68 pub fn is_toposorted(&self) -> bool {
69 !self.dirty
70 }
71
72 pub fn ensure_toposorted(&mut self) -> Result<(), DiGraphToposortError<N>> {
80 if self.dirty {
81 self.toposort = self.graph.toposort(core::mem::take(&mut self.toposort))?;
83 self.dirty = false;
84 }
85 Ok(())
86 }
87
88 #[must_use = "This method only returns a cached value and does not compute anything."]
91 pub fn get_toposort(&self) -> Option<&[N]> {
92 if self.dirty {
93 None
94 } else {
95 Some(&self.toposort)
96 }
97 }
98
99 pub fn toposort(&mut self) -> Result<&[N], DiGraphToposortError<N>> {
107 self.ensure_toposorted()?;
108 Ok(&self.toposort)
109 }
110
111 pub fn toposort_and_graph(
122 &mut self,
123 ) -> Result<(&[N], &DiGraph<N, S>), DiGraphToposortError<N>> {
124 self.ensure_toposorted()?;
125 Ok((&self.toposort, &self.graph))
126 }
127
128 pub fn analyze(&mut self) -> Result<DagAnalysis<N, S>, DiGraphToposortError<N>>
142 where
143 S: Default,
144 {
145 let (toposort, graph) = self.toposort_and_graph()?;
146 Ok(DagAnalysis::new(graph, toposort))
147 }
148
149 pub fn remove_redundant_edges(&mut self, analysis: &DagAnalysis<N, S>)
156 where
157 S: Clone,
158 {
159 self.graph = analysis.transitive_reduction.clone();
162 }
163
164 pub fn group_by_key<K, V>(
176 &mut self,
177 num_groups: usize,
178 ) -> Result<DagGroups<K, V, S>, DiGraphToposortError<N>>
179 where
180 N: TryInto<K, Error = V>,
181 K: Eq + Hash,
182 V: Clone + Eq + Hash,
183 S: BuildHasher + Default,
184 {
185 let (toposort, graph) = self.toposort_and_graph()?;
186 Ok(DagGroups::with_capacity(num_groups, graph, toposort))
187 }
188
189 pub fn try_convert<T>(self) -> Result<Dag<T, S>, N::Error>
199 where
200 N: TryInto<T>,
201 T: GraphNodeId,
202 S: Default,
203 {
204 Ok(Dag {
205 graph: self.graph.try_convert()?,
206 toposort: Vec::new(),
207 dirty: true,
208 })
209 }
210}
211
212impl<N: GraphNodeId, S: BuildHasher> Deref for Dag<N, S> {
213 type Target = DiGraph<N, S>;
214
215 fn deref(&self) -> &Self::Target {
216 self.graph()
217 }
218}
219
220impl<N: GraphNodeId, S: BuildHasher> DerefMut for Dag<N, S> {
221 fn deref_mut(&mut self) -> &mut Self::Target {
222 self.graph_mut()
223 }
224}
225
226impl<N: GraphNodeId, S: BuildHasher + Default> Default for Dag<N, S> {
227 fn default() -> Self {
228 Self {
229 graph: Default::default(),
230 toposort: Default::default(),
231 dirty: false,
232 }
233 }
234}
235
236impl<N: GraphNodeId, S: BuildHasher> Debug for Dag<N, S> {
237 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
238 if self.dirty {
239 f.debug_struct("Dag")
240 .field("graph", &self.graph)
241 .field("dirty", &self.dirty)
242 .finish()
243 } else {
244 f.debug_struct("Dag")
245 .field("graph", &self.graph)
246 .field("toposort", &self.toposort)
247 .finish()
248 }
249 }
250}
251
252pub struct DagAnalysis<N: GraphNodeId, S: BuildHasher = FixedHasher> {
254 reachable: FixedBitSet,
256 connected: HashSet<(N, N), S>,
258 disconnected: Vec<(N, N)>,
260 transitive_edges: Vec<(N, N)>,
262 transitive_reduction: DiGraph<N, S>,
264 transitive_closure: DiGraph<N, S>,
266}
267
268impl<N: GraphNodeId, S: BuildHasher> DagAnalysis<N, S> {
269 pub fn new(graph: &DiGraph<N, S>, topological_order: &[N]) -> Self
281 where
282 S: Default,
283 {
284 if graph.node_count() == 0 {
285 return DagAnalysis::default();
286 }
287 let n = graph.node_count();
288
289 let mut map = <HashMap<_, _>>::with_capacity_and_hasher(n, Default::default());
291 let mut topsorted =
292 DiGraph::<N>::with_capacity(topological_order.len(), graph.edge_count());
293
294 for (i, &node) in topological_order.iter().enumerate() {
296 map.insert(node, i);
297 topsorted.add_node(node);
298 for pred in graph.neighbors_directed(node, Incoming) {
300 topsorted.add_edge(pred, node);
301 }
302 }
303
304 let mut reachable = FixedBitSet::with_capacity(n * n);
305 let mut connected = HashSet::default();
306 let mut disconnected = Vec::default();
307 let mut transitive_edges = Vec::default();
308 let mut transitive_reduction = DiGraph::with_capacity(topsorted.node_count(), 0);
309 let mut transitive_closure = DiGraph::with_capacity(topsorted.node_count(), 0);
310
311 let mut visited = FixedBitSet::with_capacity(n);
312
313 for node in topsorted.nodes() {
315 transitive_reduction.add_node(node);
316 transitive_closure.add_node(node);
317 }
318
319 for a in topsorted.nodes().rev() {
321 let index_a = *map.get(&a).unwrap();
322 for b in topsorted.neighbors_directed(a, Outgoing) {
324 let index_b = *map.get(&b).unwrap();
325 debug_assert!(index_a < index_b);
326 if !visited[index_b] {
327 transitive_reduction.add_edge(a, b);
329 transitive_closure.add_edge(a, b);
330 reachable.insert(index(index_a, index_b, n));
331
332 let successors = transitive_closure
333 .neighbors_directed(b, Outgoing)
334 .collect::<Vec<_>>();
335 for c in successors {
336 let index_c = *map.get(&c).unwrap();
337 debug_assert!(index_b < index_c);
338 if !visited[index_c] {
339 visited.insert(index_c);
340 transitive_closure.add_edge(a, c);
341 reachable.insert(index(index_a, index_c, n));
342 }
343 }
344 } else {
345 transitive_edges.push((a, b));
347 }
348 }
349
350 visited.clear();
351 }
352
353 for i in 0..(n - 1) {
355 for index in index(i, i + 1, n)..=index(i, n - 1, n) {
357 let (a, b) = row_col(index, n);
358 let pair = (topological_order[a], topological_order[b]);
359 if reachable[index] {
360 connected.insert(pair);
361 } else {
362 disconnected.push(pair);
363 }
364 }
365 }
366
367 DagAnalysis {
373 reachable,
374 connected,
375 disconnected,
376 transitive_edges,
377 transitive_reduction,
378 transitive_closure,
379 }
380 }
381
382 pub fn reachable(&self) -> &FixedBitSet {
384 &self.reachable
385 }
386
387 pub fn connected(&self) -> &HashSet<(N, N), S> {
389 &self.connected
390 }
391
392 pub fn disconnected(&self) -> &[(N, N)] {
394 &self.disconnected
395 }
396
397 pub fn transitive_edges(&self) -> &[(N, N)] {
399 &self.transitive_edges
400 }
401
402 pub fn transitive_reduction(&self) -> &DiGraph<N, S> {
404 &self.transitive_reduction
405 }
406
407 pub fn transitive_closure(&self) -> &DiGraph<N, S> {
409 &self.transitive_closure
410 }
411
412 pub fn check_for_redundant_edges(&self) -> Result<(), DagRedundancyError<N>>
419 where
420 S: Clone,
421 {
422 if self.transitive_edges.is_empty() {
423 Ok(())
424 } else {
425 Err(DagRedundancyError(self.transitive_edges.clone()))
426 }
427 }
428
429 pub fn check_for_cross_dependencies(
437 &self,
438 other: &Self,
439 ) -> Result<(), DagCrossDependencyError<N>> {
440 for &(a, b) in &self.connected {
441 if other.connected.contains(&(a, b)) || other.connected.contains(&(b, a)) {
442 return Err(DagCrossDependencyError(a, b));
443 }
444 }
445
446 Ok(())
447 }
448
449 pub fn check_for_overlapping_groups<K, V>(
457 &self,
458 groups: &DagGroups<K, V>,
459 ) -> Result<(), DagOverlappingGroupError<K>>
460 where
461 N: TryInto<K>,
462 K: Eq + Hash,
463 V: Eq + Hash,
464 {
465 for &(a, b) in &self.connected {
466 let (Ok(a_key), Ok(b_key)) = (a.try_into(), b.try_into()) else {
467 continue;
468 };
469 let a_group = groups.get(&a_key).unwrap();
470 let b_group = groups.get(&b_key).unwrap();
471 if !a_group.is_disjoint(b_group) {
472 return Err(DagOverlappingGroupError(a_key, b_key));
473 }
474 }
475 Ok(())
476 }
477}
478
479impl<N: GraphNodeId, S: BuildHasher + Default> Default for DagAnalysis<N, S> {
480 fn default() -> Self {
481 Self {
482 reachable: Default::default(),
483 connected: Default::default(),
484 disconnected: Default::default(),
485 transitive_edges: Default::default(),
486 transitive_reduction: Default::default(),
487 transitive_closure: Default::default(),
488 }
489 }
490}
491
492impl<N: GraphNodeId, S: BuildHasher> Debug for DagAnalysis<N, S> {
493 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
494 f.debug_struct("DagAnalysis")
495 .field("reachable", &self.reachable)
496 .field("connected", &self.connected)
497 .field("disconnected", &self.disconnected)
498 .field("transitive_edges", &self.transitive_edges)
499 .field("transitive_reduction", &self.transitive_reduction)
500 .field("transitive_closure", &self.transitive_closure)
501 .finish()
502 }
503}
504
505pub struct DagGroups<K, V, S = FixedHasher>(HashMap<K, IndexSet<V, S>, S>);
507
508impl<K: Eq + Hash, V: Clone + Eq + Hash, S: BuildHasher + Default> DagGroups<K, V, S> {
509 pub fn new<N>(graph: &DiGraph<N, S>, toposort: &[N]) -> Self
515 where
516 N: GraphNodeId + TryInto<K, Error = V>,
517 {
518 Self::with_capacity(0, graph, toposort)
519 }
520
521 pub fn with_capacity<N>(capacity: usize, graph: &DiGraph<N, S>, toposort: &[N]) -> Self
528 where
529 N: GraphNodeId + TryInto<K, Error = V>,
530 {
531 let mut groups: HashMap<K, IndexSet<V, S>, S> =
532 HashMap::with_capacity_and_hasher(capacity, Default::default());
533
534 for &id in toposort.iter().rev() {
536 let Ok(key) = id.try_into() else {
537 continue;
538 };
539
540 let mut children = IndexSet::default();
541
542 for node in graph.neighbors_directed(id, Outgoing) {
543 match node.try_into() {
544 Ok(key) => {
545 let key_children = groups.get(&key).unwrap();
547 children.extend(key_children.iter().cloned());
548 }
549 Err(value) => {
550 children.insert(value);
552 }
553 }
554 }
555
556 groups.insert(key, children);
557 }
558
559 Self(groups)
560 }
561}
562
563impl<K: GraphNodeId, V: GraphNodeId, S: BuildHasher> DagGroups<K, V, S> {
564 pub fn flatten<N>(
573 &self,
574 dag: Dag<N>,
575 mut collapse_group: impl FnMut(K, &IndexSet<V, S>, &Dag<N>, &mut Vec<(N, N)>),
576 ) -> Dag<V>
577 where
578 N: GraphNodeId + TryInto<V, Error = K> + From<K> + From<V>,
579 {
580 let mut flattening = dag;
581 let mut temp = Vec::new();
582
583 for (&key, values) in self.iter() {
584 collapse_group(key, values, &flattening, &mut temp);
586
587 if values.is_empty() {
588 for a in flattening.neighbors_directed(N::from(key), Incoming) {
590 for b in flattening.neighbors_directed(N::from(key), Outgoing) {
591 temp.push((a, b));
592 }
593 }
594 } else {
595 for a in flattening.neighbors_directed(N::from(key), Incoming) {
597 for &value in values {
598 temp.push((a, N::from(value)));
599 }
600 }
601 for b in flattening.neighbors_directed(N::from(key), Outgoing) {
602 for &value in values {
603 temp.push((N::from(value), b));
604 }
605 }
606 }
607
608 flattening.remove_node(N::from(key));
610 flattening.reserve_edges(temp.len());
612 for (a, b) in temp.drain(..) {
613 flattening.add_edge(a, b);
614 }
615 }
616
617 flattening
620 .try_convert::<V>()
621 .unwrap_or_else(|n| unreachable!("Flattened graph has a leftover key {n:?}"))
622 }
623
624 pub fn flatten_undirected<N>(&self, graph: &UnGraph<N>) -> UnGraph<V>
630 where
631 N: GraphNodeId + TryInto<V, Error = K>,
632 {
633 let mut flattened = UnGraph::default();
634
635 for (lhs, rhs) in graph.all_edges() {
636 match (lhs.try_into(), rhs.try_into()) {
637 (Ok(lhs), Ok(rhs)) => {
638 flattened.add_edge(lhs, rhs);
640 }
641 (Err(lhs_key), Ok(rhs)) => {
642 let Some(lhs_group) = self.get(&lhs_key) else {
644 continue;
645 };
646 flattened.reserve_edges(lhs_group.len());
647 for &lhs in lhs_group {
648 flattened.add_edge(lhs, rhs);
649 }
650 }
651 (Ok(lhs), Err(rhs_key)) => {
652 let Some(rhs_group) = self.get(&rhs_key) else {
654 continue;
655 };
656 flattened.reserve_edges(rhs_group.len());
657 for &rhs in rhs_group {
658 flattened.add_edge(lhs, rhs);
659 }
660 }
661 (Err(lhs_key), Err(rhs_key)) => {
662 let Some(lhs_group) = self.get(&lhs_key) else {
664 continue;
665 };
666 let Some(rhs_group) = self.get(&rhs_key) else {
667 continue;
668 };
669 flattened.reserve_edges(lhs_group.len() * rhs_group.len());
670 for &lhs in lhs_group {
671 for &rhs in rhs_group {
672 flattened.add_edge(lhs, rhs);
673 }
674 }
675 }
676 }
677 }
678
679 flattened
680 }
681}
682
683impl<K, V, S> Deref for DagGroups<K, V, S> {
684 type Target = HashMap<K, IndexSet<V, S>, S>;
685
686 fn deref(&self) -> &Self::Target {
687 &self.0
688 }
689}
690
691impl<K, V, S> DerefMut for DagGroups<K, V, S> {
692 fn deref_mut(&mut self) -> &mut Self::Target {
693 &mut self.0
694 }
695}
696
697impl<K, V, S> Default for DagGroups<K, V, S>
698where
699 S: BuildHasher + Default,
700{
701 fn default() -> Self {
702 Self(Default::default())
703 }
704}
705
706impl<K: Debug, V: Debug, S> Debug for DagGroups<K, V, S> {
707 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
708 f.debug_tuple("DagGroups").field(&self.0).finish()
709 }
710}
711
712#[derive(Error, Debug)]
714#[error("DAG has redundant edges: {0:?}")]
715pub struct DagRedundancyError<N: GraphNodeId>(pub Vec<(N, N)>);
716
717#[derive(Error, Debug)]
719#[error("DAG has a cross-dependency between nodes {0:?} and {1:?}")]
720pub struct DagCrossDependencyError<N>(pub N, pub N);
721
722#[derive(Error, Debug)]
724#[error("DAG has overlapping groups between keys {0:?} and {1:?}")]
725pub struct DagOverlappingGroupError<K>(pub K, pub K);
726
#[cfg(test)]
mod tests {
    use core::ops::DerefMut;

    use crate::schedule::graph::{index, Dag, Direction, GraphNodeId, UnGraph};

    /// Plain node type for tests that do not need key/value grouping.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
    struct TestNode(u32);

    impl GraphNodeId for TestNode {
        type Adjacent = (TestNode, Direction);
        type Edge = (TestNode, TestNode);

        fn kind(&self) -> &'static str {
            "test node"
        }
    }

    #[test]
    fn mark_dirty() {
        {
            let mut dag = Dag::<TestNode>::new();
            dag.add_node(TestNode(1));
            assert!(dag.is_dirty());
        }
        {
            let mut dag = Dag::<TestNode>::new();
            dag.add_edge(TestNode(1), TestNode(2));
            assert!(dag.is_dirty());
        }
        {
            let mut dag = Dag::<TestNode>::new();
            dag.deref_mut();
            assert!(dag.is_dirty());
        }
        {
            let mut dag = Dag::<TestNode>::new();
            let _ = dag.graph_mut();
            assert!(dag.is_dirty());
        }
    }

    #[test]
    fn toposort() {
        let mut dag = Dag::<TestNode>::new();
        dag.add_edge(TestNode(1), TestNode(2));
        dag.add_edge(TestNode(2), TestNode(3));
        dag.add_edge(TestNode(1), TestNode(3));

        assert_eq!(
            dag.toposort().unwrap(),
            &[TestNode(1), TestNode(2), TestNode(3)]
        );
        assert_eq!(
            dag.get_toposort().unwrap(),
            &[TestNode(1), TestNode(2), TestNode(3)]
        );
    }

    #[test]
    fn analyze() {
        let mut dag1 = Dag::<TestNode>::new();
        dag1.add_edge(TestNode(1), TestNode(2));
        dag1.add_edge(TestNode(2), TestNode(3));
        dag1.add_edge(TestNode(1), TestNode(3));

        let analysis1 = dag1.analyze().unwrap();

        assert!(analysis1.reachable().contains(index(0, 1, 3)));
        assert!(analysis1.reachable().contains(index(1, 2, 3)));
        assert!(analysis1.reachable().contains(index(0, 2, 3)));

        assert!(analysis1.connected().contains(&(TestNode(1), TestNode(2))));
        assert!(analysis1.connected().contains(&(TestNode(2), TestNode(3))));
        assert!(analysis1.connected().contains(&(TestNode(1), TestNode(3))));

        // Every pair is connected, so nothing may be reported as disconnected.
        // Checking for reversed pairs such as `(2, 1)` would pass vacuously,
        // because pairs are always listed in topological order.
        assert!(analysis1.disconnected().is_empty());

        // `1 -> 3` is implied by `1 -> 2 -> 3` and is the only redundant edge.
        assert_eq!(analysis1.transitive_edges().len(), 1);
        assert!(analysis1
            .transitive_edges()
            .contains(&(TestNode(1), TestNode(3))));

        assert!(analysis1.check_for_redundant_edges().is_err());

        let mut dag2 = Dag::<TestNode>::new();
        dag2.add_edge(TestNode(3), TestNode(4));

        let analysis2 = dag2.analyze().unwrap();

        assert!(analysis2.check_for_redundant_edges().is_ok());
        assert!(analysis1.check_for_cross_dependencies(&analysis2).is_ok());

        let mut dag3 = Dag::<TestNode>::new();
        dag3.add_edge(TestNode(1), TestNode(2));

        let analysis3 = dag3.analyze().unwrap();

        assert!(analysis1.check_for_cross_dependencies(&analysis3).is_err());

        dag1.remove_redundant_edges(&analysis1);
        let analysis1 = dag1.analyze().unwrap();
        assert!(analysis1.check_for_redundant_edges().is_ok());
    }

    #[test]
    fn analyze_disconnected() {
        let mut dag = Dag::<TestNode>::new();
        dag.add_edge(TestNode(1), TestNode(2));
        dag.add_node(TestNode(3));

        let analysis = dag.analyze().unwrap();

        // Only `1 -> 2` is connected. Both pairs involving the isolated node 3
        // are disconnected, wherever 3 lands in the topological order.
        assert_eq!(analysis.connected().len(), 1);
        assert!(analysis.connected().contains(&(TestNode(1), TestNode(2))));
        assert_eq!(analysis.disconnected().len(), 2);
        assert!(analysis
            .disconnected()
            .iter()
            .all(|&(a, b)| a == TestNode(3) || b == TestNode(3)));
        assert!(analysis.transitive_edges().is_empty());
        assert!(analysis.check_for_redundant_edges().is_ok());
    }

    #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
    enum Node {
        Key(Key),
        Value(Value),
    }
    #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
    struct Key(u32);
    #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
    struct Value(u32);

    impl GraphNodeId for Node {
        type Adjacent = (Node, Direction);
        type Edge = (Node, Node);

        fn kind(&self) -> &'static str {
            "node"
        }
    }

    impl TryInto<Key> for Node {
        type Error = Value;

        fn try_into(self) -> Result<Key, Value> {
            match self {
                Node::Key(k) => Ok(k),
                Node::Value(v) => Err(v),
            }
        }
    }

    impl TryInto<Value> for Node {
        type Error = Key;

        fn try_into(self) -> Result<Value, Key> {
            match self {
                Node::Value(v) => Ok(v),
                Node::Key(k) => Err(k),
            }
        }
    }

    impl GraphNodeId for Key {
        type Adjacent = (Key, Direction);
        type Edge = (Key, Key);

        fn kind(&self) -> &'static str {
            "key"
        }
    }

    impl GraphNodeId for Value {
        type Adjacent = (Value, Direction);
        type Edge = (Value, Value);

        fn kind(&self) -> &'static str {
            "value"
        }
    }

    impl From<Key> for Node {
        fn from(key: Key) -> Self {
            Node::Key(key)
        }
    }

    impl From<Value> for Node {
        fn from(value: Value) -> Self {
            Node::Value(value)
        }
    }

    #[test]
    fn group_by_key() {
        let mut dag = Dag::<Node>::new();
        dag.add_edge(Node::Key(Key(1)), Node::Value(Value(10)));
        dag.add_edge(Node::Key(Key(1)), Node::Value(Value(11)));
        dag.add_edge(Node::Key(Key(2)), Node::Value(Value(20)));
        dag.add_edge(Node::Key(Key(2)), Node::Key(Key(1)));
        dag.add_edge(Node::Value(Value(10)), Node::Value(Value(11)));

        let groups = dag.group_by_key::<Key, Value>(2).unwrap();
        assert_eq!(groups.len(), 2);

        let group_key1 = groups.get(&Key(1)).unwrap();
        assert!(group_key1.contains(&Value(10)));
        assert!(group_key1.contains(&Value(11)));

        let group_key2 = groups.get(&Key(2)).unwrap();
        assert!(group_key2.contains(&Value(10)));
        assert!(group_key2.contains(&Value(11)));
        assert!(group_key2.contains(&Value(20)));
    }

    #[test]
    fn flatten() {
        let mut dag = Dag::<Node>::new();
        dag.add_edge(Node::Key(Key(1)), Node::Value(Value(10)));
        dag.add_edge(Node::Key(Key(1)), Node::Value(Value(11)));
        dag.add_edge(Node::Key(Key(2)), Node::Value(Value(20)));
        dag.add_edge(Node::Key(Key(2)), Node::Value(Value(21)));
        dag.add_edge(Node::Value(Value(30)), Node::Key(Key(1)));
        dag.add_edge(Node::Key(Key(1)), Node::Value(Value(40)));

        let groups = dag.group_by_key::<Key, Value>(2).unwrap();
        let flattened = groups.flatten(dag, |_key, _values, _dag, _temp| {});

        assert!(flattened.contains_node(Value(10)));
        assert!(flattened.contains_node(Value(11)));
        assert!(flattened.contains_node(Value(20)));
        assert!(flattened.contains_node(Value(21)));
        assert!(flattened.contains_node(Value(30)));
        assert!(flattened.contains_node(Value(40)));

        assert!(flattened.contains_edge(Value(30), Value(10)));
        assert!(flattened.contains_edge(Value(30), Value(11)));
        assert!(flattened.contains_edge(Value(10), Value(40)));
        assert!(flattened.contains_edge(Value(11), Value(40)));
    }

    #[test]
    fn flatten_undirected() {
        let mut dag = Dag::<Node>::new();
        dag.add_edge(Node::Key(Key(1)), Node::Value(Value(10)));
        dag.add_edge(Node::Key(Key(1)), Node::Value(Value(11)));
        dag.add_edge(Node::Key(Key(2)), Node::Value(Value(20)));
        dag.add_edge(Node::Key(Key(2)), Node::Value(Value(21)));

        let groups = dag.group_by_key::<Key, Value>(2).unwrap();

        let mut ungraph = UnGraph::<Node>::default();
        ungraph.add_edge(Node::Value(Value(10)), Node::Value(Value(11)));
        ungraph.add_edge(Node::Key(Key(1)), Node::Value(Value(30)));
        ungraph.add_edge(Node::Value(Value(40)), Node::Key(Key(2)));
        ungraph.add_edge(Node::Key(Key(1)), Node::Key(Key(2)));

        let flattened = groups.flatten_undirected(&ungraph);

        assert!(flattened.contains_edge(Value(10), Value(11)));
        assert!(flattened.contains_edge(Value(10), Value(30)));
        assert!(flattened.contains_edge(Value(11), Value(30)));
        assert!(flattened.contains_edge(Value(40), Value(20)));
        assert!(flattened.contains_edge(Value(40), Value(21)));
        assert!(flattened.contains_edge(Value(10), Value(20)));
        assert!(flattened.contains_edge(Value(10), Value(21)));
        assert!(flattened.contains_edge(Value(11), Value(20)));
        assert!(flattened.contains_edge(Value(11), Value(21)));
    }

    #[test]
    fn overlapping_groups() {
        let mut dag = Dag::<Node>::new();
        dag.add_edge(Node::Key(Key(1)), Node::Value(Value(10)));
        dag.add_edge(Node::Key(Key(1)), Node::Value(Value(11)));
        dag.add_edge(Node::Key(Key(2)), Node::Value(Value(11)));
        dag.add_edge(Node::Key(Key(2)), Node::Value(Value(20)));
        dag.add_edge(Node::Key(Key(1)), Node::Key(Key(2)));

        let groups = dag.group_by_key::<Key, Value>(2).unwrap();
        let analysis = dag.analyze().unwrap();

        let result = analysis.check_for_overlapping_groups(&groups);
        assert!(result.is_err());
    }

    #[test]
    fn disjoint_groups() {
        let mut dag = Dag::<Node>::new();
        dag.add_edge(Node::Key(Key(1)), Node::Value(Value(10)));
        dag.add_edge(Node::Key(Key(1)), Node::Value(Value(11)));
        dag.add_edge(Node::Key(Key(2)), Node::Value(Value(20)));
        dag.add_edge(Node::Key(Key(2)), Node::Value(Value(21)));

        let groups = dag.group_by_key::<Key, Value>(2).unwrap();
        let analysis = dag.analyze().unwrap();

        let result = analysis.check_for_overlapping_groups(&groups);
        assert!(result.is_ok());
    }
}