bevy_ecs/schedule/graph/
dag.rs

1use alloc::vec::Vec;
2use core::{
3    fmt::{self, Debug},
4    hash::{BuildHasher, Hash},
5    ops::{Deref, DerefMut},
6};
7
8use bevy_platform::{
9    collections::{HashMap, HashSet},
10    hash::FixedHasher,
11};
12use fixedbitset::FixedBitSet;
13use indexmap::IndexSet;
14use thiserror::Error;
15
16use crate::{
17    error::Result,
18    schedule::graph::{
19        index, row_col, DiGraph, DiGraphToposortError,
20        Direction::{Incoming, Outgoing},
21        GraphNodeId, UnGraph,
22    },
23};
24
25/// A directed acyclic graph structure.
26#[derive(Clone)]
27pub struct Dag<N: GraphNodeId, S: BuildHasher = FixedHasher> {
28    /// The underlying directed graph.
29    graph: DiGraph<N, S>,
30    /// A cached topological ordering of the graph. This is recomputed when the
31    /// graph is modified, and is not valid when `dirty` is true.
32    toposort: Vec<N>,
33    /// Whether the graph has been modified since the last topological sort.
34    dirty: bool,
35}
36
37impl<N: GraphNodeId, S: BuildHasher> Dag<N, S> {
38    /// Creates a new directed acyclic graph.
39    pub fn new() -> Self
40    where
41        S: Default,
42    {
43        Self::default()
44    }
45
46    /// Read-only access to the underlying directed graph.
47    #[must_use]
48    pub fn graph(&self) -> &DiGraph<N, S> {
49        &self.graph
50    }
51
52    /// Mutable access to the underlying directed graph. Marks the graph as dirty.
53    #[must_use = "This function marks the graph as dirty, so it should be used."]
54    pub fn graph_mut(&mut self) -> &mut DiGraph<N, S> {
55        self.dirty = true;
56        &mut self.graph
57    }
58
59    /// Returns whether the graph is dirty (i.e., has been modified since the
60    /// last topological sort).
61    #[must_use]
62    pub fn is_dirty(&self) -> bool {
63        self.dirty
64    }
65
66    /// Returns whether the graph is topologically sorted (i.e., not dirty).
67    #[must_use]
68    pub fn is_toposorted(&self) -> bool {
69        !self.dirty
70    }
71
72    /// Ensures the graph is topologically sorted, recomputing the toposort if
73    /// the graph is dirty.
74    ///
75    /// # Errors
76    ///
77    /// Returns [`DiGraphToposortError`] if the DAG is dirty and cannot be
78    /// topologically sorted.
79    pub fn ensure_toposorted(&mut self) -> Result<(), DiGraphToposortError<N>> {
80        if self.dirty {
81            // recompute the toposort, reusing the existing allocation
82            self.toposort = self.graph.toposort(core::mem::take(&mut self.toposort))?;
83            self.dirty = false;
84        }
85        Ok(())
86    }
87
88    /// Returns the cached toposort if the graph is not dirty, otherwise returns
89    /// `None`.
90    #[must_use = "This method only returns a cached value and does not compute anything."]
91    pub fn get_toposort(&self) -> Option<&[N]> {
92        if self.dirty {
93            None
94        } else {
95            Some(&self.toposort)
96        }
97    }
98
99    /// Returns a topological ordering of the graph, computing it if the graph
100    /// is dirty.
101    ///
102    /// # Errors
103    ///
104    /// Returns [`DiGraphToposortError`] if the DAG is dirty and cannot be
105    /// topologically sorted.
106    pub fn toposort(&mut self) -> Result<&[N], DiGraphToposortError<N>> {
107        self.ensure_toposorted()?;
108        Ok(&self.toposort)
109    }
110
111    /// Returns both the topological ordering and the underlying graph,
112    /// computing the toposort if the graph is dirty.
113    ///
114    /// This function is useful to avoid multiple borrow issues when both
115    /// the graph and the toposort are needed.
116    ///
117    /// # Errors
118    ///
119    /// Returns [`DiGraphToposortError`] if the DAG is dirty and cannot be
120    /// topologically sorted.
121    pub fn toposort_and_graph(
122        &mut self,
123    ) -> Result<(&[N], &DiGraph<N, S>), DiGraphToposortError<N>> {
124        self.ensure_toposorted()?;
125        Ok((&self.toposort, &self.graph))
126    }
127
128    /// Processes a DAG and computes various properties about it.
129    ///
130    /// See [`DagAnalysis::new`] for details on what is computed.
131    ///
132    /// # Note
133    ///
134    /// If the DAG is dirty, this method will first attempt to topologically sort it.
135    ///
136    /// # Errors
137    ///
138    /// Returns [`DiGraphToposortError`] if the DAG is dirty and cannot be
139    /// topologically sorted.
140    ///
141    pub fn analyze(&mut self) -> Result<DagAnalysis<N, S>, DiGraphToposortError<N>>
142    where
143        S: Default,
144    {
145        let (toposort, graph) = self.toposort_and_graph()?;
146        Ok(DagAnalysis::new(graph, toposort))
147    }
148
149    /// Replaces the current graph with its transitive reduction based on the
150    /// provided analysis.
151    ///
152    /// # Note
153    ///
154    /// The given [`DagAnalysis`] must have been generated from this DAG.
155    pub fn remove_redundant_edges(&mut self, analysis: &DagAnalysis<N, S>)
156    where
157        S: Clone,
158    {
159        // We don't need to mark the graph as dirty, since transitive reduction
160        // is guaranteed to have the same topological ordering as the original graph.
161        self.graph = analysis.transitive_reduction.clone();
162    }
163
164    /// Groups nodes in this DAG by a key type `K`, collecting value nodes `V`
165    /// under all of their ancestor key nodes. `num_groups` hints at the
166    /// expected number of groups, for memory allocation optimization.
167    ///
168    /// The node type `N` must be convertible into either a key type `K` or
169    /// a value type `V` via the [`TryInto`] trait.
170    ///
171    /// # Errors
172    ///
173    /// Returns [`DiGraphToposortError`] if the DAG is dirty and cannot be
174    /// topologically sorted.
175    pub fn group_by_key<K, V>(
176        &mut self,
177        num_groups: usize,
178    ) -> Result<DagGroups<K, V, S>, DiGraphToposortError<N>>
179    where
180        N: TryInto<K, Error = V>,
181        K: Eq + Hash,
182        V: Clone + Eq + Hash,
183        S: BuildHasher + Default,
184    {
185        let (toposort, graph) = self.toposort_and_graph()?;
186        Ok(DagGroups::with_capacity(num_groups, graph, toposort))
187    }
188
189    /// Converts from one [`GraphNodeId`] type to another. If the conversion fails,
190    /// it returns the error from the target type's [`TryFrom`] implementation.
191    ///
192    /// Nodes must uniquely convert from `N` to `T` (i.e. no two `N` can convert
193    /// to the same `T`). The resulting DAG must be re-topologically sorted.
194    ///
195    /// # Errors
196    ///
197    /// If the conversion fails, it returns an error of type `N::Error`.
198    pub fn try_convert<T>(self) -> Result<Dag<T, S>, N::Error>
199    where
200        N: TryInto<T>,
201        T: GraphNodeId,
202        S: Default,
203    {
204        Ok(Dag {
205            graph: self.graph.try_convert()?,
206            toposort: Vec::new(),
207            dirty: true,
208        })
209    }
210}
211
212impl<N: GraphNodeId, S: BuildHasher> Deref for Dag<N, S> {
213    type Target = DiGraph<N, S>;
214
215    fn deref(&self) -> &Self::Target {
216        self.graph()
217    }
218}
219
220impl<N: GraphNodeId, S: BuildHasher> DerefMut for Dag<N, S> {
221    fn deref_mut(&mut self) -> &mut Self::Target {
222        self.graph_mut()
223    }
224}
225
226impl<N: GraphNodeId, S: BuildHasher + Default> Default for Dag<N, S> {
227    fn default() -> Self {
228        Self {
229            graph: Default::default(),
230            toposort: Default::default(),
231            dirty: false,
232        }
233    }
234}
235
236impl<N: GraphNodeId, S: BuildHasher> Debug for Dag<N, S> {
237    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
238        if self.dirty {
239            f.debug_struct("Dag")
240                .field("graph", &self.graph)
241                .field("dirty", &self.dirty)
242                .finish()
243        } else {
244            f.debug_struct("Dag")
245                .field("graph", &self.graph)
246                .field("toposort", &self.toposort)
247                .finish()
248        }
249    }
250}
251
252/// Stores the results of a call to [`Dag::analyze`].
253pub struct DagAnalysis<N: GraphNodeId, S: BuildHasher = FixedHasher> {
254    /// Boolean reachability matrix for the graph.
255    reachable: FixedBitSet,
256    /// Pairs of nodes that have a path connecting them.
257    connected: HashSet<(N, N), S>,
258    /// Pairs of nodes that don't have a path connecting them.
259    disconnected: Vec<(N, N)>,
260    /// Edges that are redundant because a longer path exists.
261    transitive_edges: Vec<(N, N)>,
262    /// Variant of the graph with no transitive edges.
263    transitive_reduction: DiGraph<N, S>,
264    /// Variant of the graph with all possible transitive edges.
265    transitive_closure: DiGraph<N, S>,
266}
267
268impl<N: GraphNodeId, S: BuildHasher> DagAnalysis<N, S> {
269    /// Processes a DAG and computes its:
270    /// - transitive reduction (along with the set of removed edges)
271    /// - transitive closure
272    /// - reachability matrix (as a bitset)
273    /// - pairs of nodes connected by a path
274    /// - pairs of nodes not connected by a path
275    ///
276    /// The algorithm implemented comes from
277    /// ["On the calculation of transitive reduction-closure of orders"][1] by Habib, Morvan and Rampon.
278    ///
279    /// [1]: https://doi.org/10.1016/0012-365X(93)90164-O
280    pub fn new(graph: &DiGraph<N, S>, topological_order: &[N]) -> Self
281    where
282        S: Default,
283    {
284        if graph.node_count() == 0 {
285            return DagAnalysis::default();
286        }
287        let n = graph.node_count();
288
289        // build a copy of the graph where the nodes and edges appear in topsorted order
290        let mut map = <HashMap<_, _>>::with_capacity_and_hasher(n, Default::default());
291        let mut topsorted =
292            DiGraph::<N>::with_capacity(topological_order.len(), graph.edge_count());
293
294        // iterate nodes in topological order
295        for (i, &node) in topological_order.iter().enumerate() {
296            map.insert(node, i);
297            topsorted.add_node(node);
298            // insert nodes as successors to their predecessors
299            for pred in graph.neighbors_directed(node, Incoming) {
300                topsorted.add_edge(pred, node);
301            }
302        }
303
304        let mut reachable = FixedBitSet::with_capacity(n * n);
305        let mut connected = HashSet::default();
306        let mut disconnected = Vec::default();
307        let mut transitive_edges = Vec::default();
308        let mut transitive_reduction = DiGraph::with_capacity(topsorted.node_count(), 0);
309        let mut transitive_closure = DiGraph::with_capacity(topsorted.node_count(), 0);
310
311        let mut visited = FixedBitSet::with_capacity(n);
312
313        // iterate nodes in topological order
314        for node in topsorted.nodes() {
315            transitive_reduction.add_node(node);
316            transitive_closure.add_node(node);
317        }
318
319        // iterate nodes in reverse topological order
320        for a in topsorted.nodes().rev() {
321            let index_a = *map.get(&a).unwrap();
322            // iterate their successors in topological order
323            for b in topsorted.neighbors_directed(a, Outgoing) {
324                let index_b = *map.get(&b).unwrap();
325                debug_assert!(index_a < index_b);
326                if !visited[index_b] {
327                    // edge <a, b> is not redundant
328                    transitive_reduction.add_edge(a, b);
329                    transitive_closure.add_edge(a, b);
330                    reachable.insert(index(index_a, index_b, n));
331
332                    let successors = transitive_closure
333                        .neighbors_directed(b, Outgoing)
334                        .collect::<Vec<_>>();
335                    for c in successors {
336                        let index_c = *map.get(&c).unwrap();
337                        debug_assert!(index_b < index_c);
338                        if !visited[index_c] {
339                            visited.insert(index_c);
340                            transitive_closure.add_edge(a, c);
341                            reachable.insert(index(index_a, index_c, n));
342                        }
343                    }
344                } else {
345                    // edge <a, b> is redundant
346                    transitive_edges.push((a, b));
347                }
348            }
349
350            visited.clear();
351        }
352
353        // partition pairs of nodes into "connected by path" and "not connected by path"
354        for i in 0..(n - 1) {
355            // reachable is upper triangular because the nodes were topsorted
356            for index in index(i, i + 1, n)..=index(i, n - 1, n) {
357                let (a, b) = row_col(index, n);
358                let pair = (topological_order[a], topological_order[b]);
359                if reachable[index] {
360                    connected.insert(pair);
361                } else {
362                    disconnected.push(pair);
363                }
364            }
365        }
366
367        // fill diagonal (nodes reach themselves)
368        // for i in 0..n {
369        //     reachable.set(index(i, i, n), true);
370        // }
371
372        DagAnalysis {
373            reachable,
374            connected,
375            disconnected,
376            transitive_edges,
377            transitive_reduction,
378            transitive_closure,
379        }
380    }
381
382    /// Returns the reachability matrix.
383    pub fn reachable(&self) -> &FixedBitSet {
384        &self.reachable
385    }
386
387    /// Returns the set of node pairs that are connected by a path.
388    pub fn connected(&self) -> &HashSet<(N, N), S> {
389        &self.connected
390    }
391
392    /// Returns the list of node pairs that are not connected by a path.
393    pub fn disconnected(&self) -> &[(N, N)] {
394        &self.disconnected
395    }
396
397    /// Returns the list of redundant edges because a longer path exists.
398    pub fn transitive_edges(&self) -> &[(N, N)] {
399        &self.transitive_edges
400    }
401
402    /// Returns the transitive reduction of the graph.
403    pub fn transitive_reduction(&self) -> &DiGraph<N, S> {
404        &self.transitive_reduction
405    }
406
407    /// Returns the transitive closure of the graph.
408    pub fn transitive_closure(&self) -> &DiGraph<N, S> {
409        &self.transitive_closure
410    }
411
412    /// Checks if the graph has any redundant (transitive) edges.
413    ///
414    /// # Errors
415    ///
416    /// If there are redundant edges, returns a [`DagRedundancyError`]
417    /// containing the list of redundant edges.
418    pub fn check_for_redundant_edges(&self) -> Result<(), DagRedundancyError<N>>
419    where
420        S: Clone,
421    {
422        if self.transitive_edges.is_empty() {
423            Ok(())
424        } else {
425            Err(DagRedundancyError(self.transitive_edges.clone()))
426        }
427    }
428
429    /// Checks if there are any pairs of nodes that have a path in both this
430    /// graph and another graph.
431    ///
432    /// # Errors
433    ///
434    /// Returns [`DagCrossDependencyError`] if any node pair is connected in
435    /// both graphs.
436    pub fn check_for_cross_dependencies(
437        &self,
438        other: &Self,
439    ) -> Result<(), DagCrossDependencyError<N>> {
440        for &(a, b) in &self.connected {
441            if other.connected.contains(&(a, b)) || other.connected.contains(&(b, a)) {
442                return Err(DagCrossDependencyError(a, b));
443            }
444        }
445
446        Ok(())
447    }
448
449    /// Checks if any connected node pairs that are both keys have overlapping
450    /// groups.
451    ///
452    /// # Errors
453    ///
454    /// If there are overlapping groups, returns a [`DagOverlappingGroupError`]
455    /// containing the first pair of keys that have overlapping groups.
456    pub fn check_for_overlapping_groups<K, V>(
457        &self,
458        groups: &DagGroups<K, V>,
459    ) -> Result<(), DagOverlappingGroupError<K>>
460    where
461        N: TryInto<K>,
462        K: Eq + Hash,
463        V: Eq + Hash,
464    {
465        for &(a, b) in &self.connected {
466            let (Ok(a_key), Ok(b_key)) = (a.try_into(), b.try_into()) else {
467                continue;
468            };
469            let a_group = groups.get(&a_key).unwrap();
470            let b_group = groups.get(&b_key).unwrap();
471            if !a_group.is_disjoint(b_group) {
472                return Err(DagOverlappingGroupError(a_key, b_key));
473            }
474        }
475        Ok(())
476    }
477}
478
479impl<N: GraphNodeId, S: BuildHasher + Default> Default for DagAnalysis<N, S> {
480    fn default() -> Self {
481        Self {
482            reachable: Default::default(),
483            connected: Default::default(),
484            disconnected: Default::default(),
485            transitive_edges: Default::default(),
486            transitive_reduction: Default::default(),
487            transitive_closure: Default::default(),
488        }
489    }
490}
491
492impl<N: GraphNodeId, S: BuildHasher> Debug for DagAnalysis<N, S> {
493    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
494        f.debug_struct("DagAnalysis")
495            .field("reachable", &self.reachable)
496            .field("connected", &self.connected)
497            .field("disconnected", &self.disconnected)
498            .field("transitive_edges", &self.transitive_edges)
499            .field("transitive_reduction", &self.transitive_reduction)
500            .field("transitive_closure", &self.transitive_closure)
501            .finish()
502    }
503}
504
505/// A mapping of keys to groups of values in a [`Dag`].
506pub struct DagGroups<K, V, S = FixedHasher>(HashMap<K, IndexSet<V, S>, S>);
507
508impl<K: Eq + Hash, V: Clone + Eq + Hash, S: BuildHasher + Default> DagGroups<K, V, S> {
509    /// Groups nodes in this DAG by a key type `K`, collecting value nodes `V`
510    /// under all of their ancestor key nodes.
511    ///
512    /// The node type `N` must be convertible into either a key type `K` or
513    /// a value type `V` via the [`TryInto`] trait.
514    pub fn new<N>(graph: &DiGraph<N, S>, toposort: &[N]) -> Self
515    where
516        N: GraphNodeId + TryInto<K, Error = V>,
517    {
518        Self::with_capacity(0, graph, toposort)
519    }
520
521    /// Groups nodes in this DAG by a key type `K`, collecting value nodes `V`
522    /// under all of their ancestor key nodes. `capacity` hints at the
523    /// expected number of groups, for memory allocation optimization.
524    ///
525    /// The node type `N` must be convertible into either a key type `K` or
526    /// a value type `V` via the [`TryInto`] trait.
527    pub fn with_capacity<N>(capacity: usize, graph: &DiGraph<N, S>, toposort: &[N]) -> Self
528    where
529        N: GraphNodeId + TryInto<K, Error = V>,
530    {
531        let mut groups: HashMap<K, IndexSet<V, S>, S> =
532            HashMap::with_capacity_and_hasher(capacity, Default::default());
533
534        // Iterate in reverse topological order (bottom-up) so we hit children before parents.
535        for &id in toposort.iter().rev() {
536            let Ok(key) = id.try_into() else {
537                continue;
538            };
539
540            let mut children = IndexSet::default();
541
542            for node in graph.neighbors_directed(id, Outgoing) {
543                match node.try_into() {
544                    Ok(key) => {
545                        // If the child is a key, this key inherits all of its children.
546                        let key_children = groups.get(&key).unwrap();
547                        children.extend(key_children.iter().cloned());
548                    }
549                    Err(value) => {
550                        // If the child is a value, add it directly.
551                        children.insert(value);
552                    }
553                }
554            }
555
556            groups.insert(key, children);
557        }
558
559        Self(groups)
560    }
561}
562
563impl<K: GraphNodeId, V: GraphNodeId, S: BuildHasher> DagGroups<K, V, S> {
564    /// Converts the given [`Dag`] into a flattened version where key nodes
565    /// (`K`) are replaced by their associated value nodes (`V`). Edges to/from
566    /// key nodes are redirected to connect their value nodes instead.
567    ///
568    /// The `collapse_group` function is called for each key node to customize
569    /// how its group is collapsed.
570    ///
571    /// The resulting [`Dag`] will have only value nodes (`V`).
572    pub fn flatten<N>(
573        &self,
574        dag: Dag<N>,
575        mut collapse_group: impl FnMut(K, &IndexSet<V, S>, &Dag<N>, &mut Vec<(N, N)>),
576    ) -> Dag<V>
577    where
578        N: GraphNodeId + TryInto<V, Error = K> + From<K> + From<V>,
579    {
580        let mut flattening = dag;
581        let mut temp = Vec::new();
582
583        for (&key, values) in self.iter() {
584            // Call the user-provided function to handle collapsing the group.
585            collapse_group(key, values, &flattening, &mut temp);
586
587            if values.is_empty() {
588                // Replace connections to the key node with connections between its neighbors.
589                for a in flattening.neighbors_directed(N::from(key), Incoming) {
590                    for b in flattening.neighbors_directed(N::from(key), Outgoing) {
591                        temp.push((a, b));
592                    }
593                }
594            } else {
595                // Redirect edges to/from the key node to connect to its value nodes.
596                for a in flattening.neighbors_directed(N::from(key), Incoming) {
597                    for &value in values {
598                        temp.push((a, N::from(value)));
599                    }
600                }
601                for b in flattening.neighbors_directed(N::from(key), Outgoing) {
602                    for &value in values {
603                        temp.push((N::from(value), b));
604                    }
605                }
606            }
607
608            // Remove the key node from the graph.
609            flattening.remove_node(N::from(key));
610            // Add all previously collected edges.
611            flattening.reserve_edges(temp.len());
612            for (a, b) in temp.drain(..) {
613                flattening.add_edge(a, b);
614            }
615        }
616
617        // By this point, we should have removed all keys from the graph,
618        // so this conversion should never fail.
619        flattening
620            .try_convert::<V>()
621            .unwrap_or_else(|n| unreachable!("Flattened graph has a leftover key {n:?}"))
622    }
623
624    /// Converts an undirected graph by replacing key nodes (`K`) with their
625    /// associated value nodes (`V`). Edges connected to key nodes are
626    /// redirected to connect their value nodes instead.
627    ///
628    /// The resulting undirected graph will have only value nodes (`V`).
629    pub fn flatten_undirected<N>(&self, graph: &UnGraph<N>) -> UnGraph<V>
630    where
631        N: GraphNodeId + TryInto<V, Error = K>,
632    {
633        let mut flattened = UnGraph::default();
634
635        for (lhs, rhs) in graph.all_edges() {
636            match (lhs.try_into(), rhs.try_into()) {
637                (Ok(lhs), Ok(rhs)) => {
638                    // Normal edge between two value nodes
639                    flattened.add_edge(lhs, rhs);
640                }
641                (Err(lhs_key), Ok(rhs)) => {
642                    // Edge from a key node to a value node, expand to all values in the key's group
643                    let Some(lhs_group) = self.get(&lhs_key) else {
644                        continue;
645                    };
646                    flattened.reserve_edges(lhs_group.len());
647                    for &lhs in lhs_group {
648                        flattened.add_edge(lhs, rhs);
649                    }
650                }
651                (Ok(lhs), Err(rhs_key)) => {
652                    // Edge from a value node to a key node, expand to all values in the key's group
653                    let Some(rhs_group) = self.get(&rhs_key) else {
654                        continue;
655                    };
656                    flattened.reserve_edges(rhs_group.len());
657                    for &rhs in rhs_group {
658                        flattened.add_edge(lhs, rhs);
659                    }
660                }
661                (Err(lhs_key), Err(rhs_key)) => {
662                    // Edge between two key nodes, expand to all combinations of their value nodes
663                    let Some(lhs_group) = self.get(&lhs_key) else {
664                        continue;
665                    };
666                    let Some(rhs_group) = self.get(&rhs_key) else {
667                        continue;
668                    };
669                    flattened.reserve_edges(lhs_group.len() * rhs_group.len());
670                    for &lhs in lhs_group {
671                        for &rhs in rhs_group {
672                            flattened.add_edge(lhs, rhs);
673                        }
674                    }
675                }
676            }
677        }
678
679        flattened
680    }
681}
682
683impl<K, V, S> Deref for DagGroups<K, V, S> {
684    type Target = HashMap<K, IndexSet<V, S>, S>;
685
686    fn deref(&self) -> &Self::Target {
687        &self.0
688    }
689}
690
691impl<K, V, S> DerefMut for DagGroups<K, V, S> {
692    fn deref_mut(&mut self) -> &mut Self::Target {
693        &mut self.0
694    }
695}
696
697impl<K, V, S> Default for DagGroups<K, V, S>
698where
699    S: BuildHasher + Default,
700{
701    fn default() -> Self {
702        Self(Default::default())
703    }
704}
705
706impl<K: Debug, V: Debug, S> Debug for DagGroups<K, V, S> {
707    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
708        f.debug_tuple("DagGroups").field(&self.0).finish()
709    }
710}
711
712/// Error indicating that the graph has redundant edges.
713#[derive(Error, Debug)]
714#[error("DAG has redundant edges: {0:?}")]
715pub struct DagRedundancyError<N: GraphNodeId>(pub Vec<(N, N)>);
716
717/// Error indicating that two graphs both have a dependency between the same nodes.
718#[derive(Error, Debug)]
719#[error("DAG has a cross-dependency between nodes {0:?} and {1:?}")]
720pub struct DagCrossDependencyError<N>(pub N, pub N);
721
722/// Error indicating that the graph has overlapping groups between two keys.
723#[derive(Error, Debug)]
724#[error("DAG has overlapping groups between keys {0:?} and {1:?}")]
725pub struct DagOverlappingGroupError<K>(pub K, pub K);
726
727#[cfg(test)]
728mod tests {
729    use core::ops::DerefMut;
730
731    use crate::schedule::graph::{index, Dag, Direction, GraphNodeId, UnGraph};
732
733    #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
734    struct TestNode(u32);
735
736    impl GraphNodeId for TestNode {
737        type Adjacent = (TestNode, Direction);
738        type Edge = (TestNode, TestNode);
739
740        fn kind(&self) -> &'static str {
741            "test node"
742        }
743    }
744
745    #[test]
746    fn mark_dirty() {
747        {
748            let mut dag = Dag::<TestNode>::new();
749            dag.add_node(TestNode(1));
750            assert!(dag.is_dirty());
751        }
752        {
753            let mut dag = Dag::<TestNode>::new();
754            dag.add_edge(TestNode(1), TestNode(2));
755            assert!(dag.is_dirty());
756        }
757        {
758            let mut dag = Dag::<TestNode>::new();
759            dag.deref_mut();
760            assert!(dag.is_dirty());
761        }
762        {
763            let mut dag = Dag::<TestNode>::new();
764            let _ = dag.graph_mut();
765            assert!(dag.is_dirty());
766        }
767    }
768
769    #[test]
770    fn toposort() {
771        let mut dag = Dag::<TestNode>::new();
772        dag.add_edge(TestNode(1), TestNode(2));
773        dag.add_edge(TestNode(2), TestNode(3));
774        dag.add_edge(TestNode(1), TestNode(3));
775
776        assert_eq!(
777            dag.toposort().unwrap(),
778            &[TestNode(1), TestNode(2), TestNode(3)]
779        );
780        assert_eq!(
781            dag.get_toposort().unwrap(),
782            &[TestNode(1), TestNode(2), TestNode(3)]
783        );
784    }
785
786    #[test]
787    fn analyze() {
788        let mut dag1 = Dag::<TestNode>::new();
789        dag1.add_edge(TestNode(1), TestNode(2));
790        dag1.add_edge(TestNode(2), TestNode(3));
791        dag1.add_edge(TestNode(1), TestNode(3)); // redundant edge
792
793        let analysis1 = dag1.analyze().unwrap();
794
795        assert!(analysis1.reachable().contains(index(0, 1, 3)));
796        assert!(analysis1.reachable().contains(index(1, 2, 3)));
797        assert!(analysis1.reachable().contains(index(0, 2, 3)));
798
799        assert!(analysis1.connected().contains(&(TestNode(1), TestNode(2))));
800        assert!(analysis1.connected().contains(&(TestNode(2), TestNode(3))));
801        assert!(analysis1.connected().contains(&(TestNode(1), TestNode(3))));
802
803        assert!(!analysis1
804            .disconnected()
805            .contains(&(TestNode(2), TestNode(1))));
806        assert!(!analysis1
807            .disconnected()
808            .contains(&(TestNode(3), TestNode(2))));
809        assert!(!analysis1
810            .disconnected()
811            .contains(&(TestNode(3), TestNode(1))));
812
813        assert!(analysis1
814            .transitive_edges()
815            .contains(&(TestNode(1), TestNode(3))));
816
817        assert!(analysis1.check_for_redundant_edges().is_err());
818
819        let mut dag2 = Dag::<TestNode>::new();
820        dag2.add_edge(TestNode(3), TestNode(4));
821
822        let analysis2 = dag2.analyze().unwrap();
823
824        assert!(analysis2.check_for_redundant_edges().is_ok());
825        assert!(analysis1.check_for_cross_dependencies(&analysis2).is_ok());
826
827        let mut dag3 = Dag::<TestNode>::new();
828        dag3.add_edge(TestNode(1), TestNode(2));
829
830        let analysis3 = dag3.analyze().unwrap();
831
832        assert!(analysis1.check_for_cross_dependencies(&analysis3).is_err());
833
834        dag1.remove_redundant_edges(&analysis1);
835        let analysis1 = dag1.analyze().unwrap();
836        assert!(analysis1.check_for_redundant_edges().is_ok());
837    }
838
839    #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
840    enum Node {
841        Key(Key),
842        Value(Value),
843    }
844    #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
845    struct Key(u32);
846    #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
847    struct Value(u32);
848
849    impl GraphNodeId for Node {
850        type Adjacent = (Node, Direction);
851        type Edge = (Node, Node);
852
853        fn kind(&self) -> &'static str {
854            "node"
855        }
856    }
857
858    impl TryInto<Key> for Node {
859        type Error = Value;
860
861        fn try_into(self) -> Result<Key, Value> {
862            match self {
863                Node::Key(k) => Ok(k),
864                Node::Value(v) => Err(v),
865            }
866        }
867    }
868
869    impl TryInto<Value> for Node {
870        type Error = Key;
871
872        fn try_into(self) -> Result<Value, Key> {
873            match self {
874                Node::Value(v) => Ok(v),
875                Node::Key(k) => Err(k),
876            }
877        }
878    }
879
880    impl GraphNodeId for Key {
881        type Adjacent = (Key, Direction);
882        type Edge = (Key, Key);
883
884        fn kind(&self) -> &'static str {
885            "key"
886        }
887    }
888
889    impl GraphNodeId for Value {
890        type Adjacent = (Value, Direction);
891        type Edge = (Value, Value);
892
893        fn kind(&self) -> &'static str {
894            "value"
895        }
896    }
897
898    impl From<Key> for Node {
899        fn from(key: Key) -> Self {
900            Node::Key(key)
901        }
902    }
903
904    impl From<Value> for Node {
905        fn from(value: Value) -> Self {
906            Node::Value(value)
907        }
908    }
909
/// Builds a small key/value DAG and checks the groups produced by
/// `group_by_key`: two groups are expected, and because `Key(2)` has an edge
/// to `Key(1)`, the assertions expect `Key(2)`'s group to also contain
/// `Key(1)`'s values.
#[test]
fn group_by_key() {
    // Shorthand constructors for the two node flavours.
    let key = |id| Node::Key(Key(id));
    let value = |id| Node::Value(Value(id));

    let mut dag = Dag::<Node>::new();
    // Edges are inserted in the same order as a hand-written sequence would.
    for (from, to) in [
        (key(1), value(10)),
        (key(1), value(11)),
        (key(2), value(20)),
        (key(2), key(1)),
        (value(10), value(11)),
    ] {
        dag.add_edge(from, to);
    }

    // NOTE(review): the `2` presumably is the expected number of keys — confirm
    // against `group_by_key`'s signature.
    let groups = dag.group_by_key::<Key, Value>(2).unwrap();
    assert_eq!(groups.len(), 2);

    let members_of_key1 = groups.get(&Key(1)).unwrap();
    for id in [10, 11] {
        assert!(members_of_key1.contains(&Value(id)));
    }

    let members_of_key2 = groups.get(&Key(2)).unwrap();
    for id in [10, 11, 20] {
        assert!(members_of_key2.contains(&Value(id)));
    }
}
931
/// Groups a key/value DAG and flattens it with a no-op callback.
///
/// The assertions expect every value to be present in the flattened graph,
/// and the edges that touched `Key(1)` (`Value(30) -> Key(1)` and
/// `Key(1) -> Value(40)`) to appear as edges to/from each of `Key(1)`'s
/// values `10` and `11`.
#[test]
fn flatten() {
    // Shorthand constructors for the two node flavours.
    let key = |id| Node::Key(Key(id));
    let value = |id| Node::Value(Value(id));

    let mut dag = Dag::<Node>::new();
    for (from, to) in [
        (key(1), value(10)),
        (key(1), value(11)),
        (key(2), value(20)),
        (key(2), value(21)),
        (value(30), key(1)),
        (key(1), value(40)),
    ] {
        dag.add_edge(from, to);
    }

    let groups = dag.group_by_key::<Key, Value>(2).unwrap();
    // The per-key callback deliberately does nothing.
    let flattened = groups.flatten(dag, |_group_key, _members, _graph, _scratch| {});

    for id in [10, 11, 20, 21, 30, 40] {
        assert!(flattened.contains_node(Value(id)));
    }

    for (from, to) in [(30, 10), (30, 11), (10, 40), (11, 40)] {
        assert!(flattened.contains_edge(Value(from), Value(to)));
    }
}
957
/// Groups a key/value DAG, then flattens a separate undirected graph that
/// mixes keys and values.
///
/// The assertions expect: the value-value edge to be kept, each key-value
/// edge to become edges between the value and every value of that key, and
/// the key-key edge to become edges between every pair of values across the
/// two keys.
#[test]
fn flatten_undirected() {
    // Shorthand constructors for the two node flavours.
    let key = |id| Node::Key(Key(id));
    let value = |id| Node::Value(Value(id));

    let mut dag = Dag::<Node>::new();
    for (from, to) in [
        (key(1), value(10)),
        (key(1), value(11)),
        (key(2), value(20)),
        (key(2), value(21)),
    ] {
        dag.add_edge(from, to);
    }

    let groups = dag.group_by_key::<Key, Value>(2).unwrap();

    let mut ungraph = UnGraph::<Node>::default();
    for (a, b) in [
        (value(10), value(11)),
        (key(1), value(30)),
        (value(40), key(2)),
        (key(1), key(2)),
    ] {
        ungraph.add_edge(a, b);
    }

    let flattened = groups.flatten_undirected(&ungraph);

    for (a, b) in [
        (10, 11),
        (10, 30),
        (11, 30),
        (40, 20),
        (40, 21),
        (10, 20),
        (10, 21),
        (11, 20),
        (11, 21),
    ] {
        assert!(flattened.contains_edge(Value(a), Value(b)));
    }
}
986
/// `Value(11)` belongs to both `Key(1)` and `Key(2)`, and the two keys are
/// linked by an edge; `check_for_overlapping_groups` is expected to fail.
#[test]
fn overlapping_groups() {
    // Shorthand constructors for the two node flavours.
    let key = |id| Node::Key(Key(id));
    let value = |id| Node::Value(Value(id));

    let mut dag = Dag::<Node>::new();
    for (from, to) in [
        (key(1), value(10)),
        (key(1), value(11)),
        (key(2), value(11)), // overlap
        (key(2), value(20)),
        (key(1), key(2)),
    ] {
        dag.add_edge(from, to);
    }

    let groups = dag.group_by_key::<Key, Value>(2).unwrap();
    let analysis = dag.analyze().unwrap();

    assert!(analysis.check_for_overlapping_groups(&groups).is_err());
}
1002
/// The two keys share no values and have no edge between them;
/// `check_for_overlapping_groups` is expected to succeed.
#[test]
fn disjoint_groups() {
    // Shorthand constructors for the two node flavours.
    let key = |id| Node::Key(Key(id));
    let value = |id| Node::Value(Value(id));

    let mut dag = Dag::<Node>::new();
    for (from, to) in [
        (key(1), value(10)),
        (key(1), value(11)),
        (key(2), value(20)),
        (key(2), value(21)),
    ] {
        dag.add_edge(from, to);
    }

    let groups = dag.group_by_key::<Key, Value>(2).unwrap();
    let analysis = dag.analyze().unwrap();

    assert!(analysis.check_for_overlapping_groups(&groups).is_ok());
}
1017}