bevy_ecs/schedule/graph/
graph_map.rs

1//! `Graph<DIRECTED>` is a graph datastructure where node values are mapping
2//! keys.
3//! Based on the `GraphMap` datastructure from [`petgraph`].
4//!
5//! [`petgraph`]: https://docs.rs/petgraph/0.6.5/petgraph/
6
7use alloc::{vec, vec::Vec};
8use core::{
9    fmt::{self, Debug},
10    hash::{BuildHasher, Hash},
11};
12use thiserror::Error;
13
14use bevy_platform::{
15    collections::{HashMap, HashSet},
16    hash::FixedHasher,
17};
18use indexmap::IndexMap;
19use smallvec::SmallVec;
20
21use Direction::{Incoming, Outgoing};
22
/// Types that can be used as node identifiers in a [`DiGraph`]/[`UnGraph`].
///
/// Identifiers are copied freely, hashed (they are the keys of the graph's
/// node map) and compared with [`Ord`] (undirected graphs order the two
/// endpoints of an edge to obtain a canonical edge id).
///
/// [`DiGraph`]: crate::schedule::graph::DiGraph
/// [`UnGraph`]: crate::schedule::graph::UnGraph
pub trait GraphNodeId: Copy + Eq + Hash + Ord + Debug {
    /// The type that packs and unpacks this [`GraphNodeId`] with a [`Direction`].
    /// This is used to save space in the graph's adjacency list.
    ///
    /// Converting to the packed form and back is expected to yield the original
    /// `(node, direction)` pair.
    type Adjacent: Copy + Debug + From<(Self, Direction)> + Into<(Self, Direction)>;
    /// The type that packs and unpacks this [`GraphNodeId`] with another
    /// [`GraphNodeId`]. This is used to save space in the graph's edge list.
    ///
    /// The packed form is stored in a hash set, hence the `Eq + Hash` bounds.
    type Edge: Copy + Eq + Hash + Debug + From<(Self, Self)> + Into<(Self, Self)>;

    /// Name of the kind of this node id.
    ///
    /// For structs, this should return a human-readable name of the struct.
    /// For enums, this should return a human-readable name of the enum variant.
    fn kind(&self) -> &'static str;
}
41
/// A `Graph` with undirected edges of some [`GraphNodeId`] `N`.
///
/// For example, an edge between *1* and *2* is equivalent to an edge between
/// *2* and *1*.
///
/// `S` is the [`BuildHasher`] used by the graph's internal maps; it defaults
/// to [`FixedHasher`].
pub type UnGraph<N, S = FixedHasher> = Graph<false, N, S>;

/// A `Graph` with directed edges of some [`GraphNodeId`] `N`.
///
/// For example, an edge from *1* to *2* is distinct from an edge from *2* to
/// *1*.
///
/// `S` is the [`BuildHasher`] used by the graph's internal maps; it defaults
/// to [`FixedHasher`].
pub type DiGraph<N, S = FixedHasher> = Graph<true, N, S>;
53
/// `Graph<DIRECTED>` is a graph datastructure using an associative array
/// of its node weights of some [`GraphNodeId`].
///
/// It uses a combined adjacency list and sparse adjacency matrix
/// representation, using **O(|N| + |E|)** space, and allows testing for edge
/// existence in constant time.
///
/// `Graph` is parameterized over:
///
/// - Constant generic bool `DIRECTED` determines whether the graph edges are directed or
///   undirected.
/// - The `GraphNodeId` type `N`, which is used as the node weight.
/// - The `BuildHasher` `S`.
///
/// You can use the type aliases `UnGraph` and `DiGraph` for convenience.
///
/// `Graph` does not allow parallel edges, but self loops are allowed.
#[derive(Clone)]
pub struct Graph<const DIRECTED: bool, N: GraphNodeId, S = FixedHasher>
where
    S: BuildHasher,
{
    // Adjacency list: every node maps to the packed `(neighbor, direction)`
    // entries of the edges touching it. An edge `a -> b` is recorded as an
    // `Outgoing` entry on `a` and, unless `a == b`, an `Incoming` entry on `b`.
    nodes: IndexMap<N, Vec<N::Adjacent>, S>,
    // One packed node pair per edge; this is what edge-existence checks and
    // the edge count consult.
    edges: HashSet<N::Edge, S>,
}
79
80impl<const DIRECTED: bool, N: GraphNodeId, S: BuildHasher> Debug for Graph<DIRECTED, N, S> {
81    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
82        self.nodes.fmt(f)
83    }
84}
85
86impl<const DIRECTED: bool, N: GraphNodeId, S: BuildHasher> Graph<DIRECTED, N, S> {
87    /// Create a new `Graph` with estimated capacity.
88    pub fn with_capacity(nodes: usize, edges: usize) -> Self
89    where
90        S: Default,
91    {
92        Self {
93            nodes: IndexMap::with_capacity_and_hasher(nodes, S::default()),
94            edges: HashSet::with_capacity_and_hasher(edges, S::default()),
95        }
96    }
97
98    /// Use their natural order to map the node pair (a, b) to a canonical edge id.
99    #[inline]
100    fn edge_key(a: N, b: N) -> N::Edge {
101        let (a, b) = if DIRECTED || a <= b { (a, b) } else { (b, a) };
102
103        N::Edge::from((a, b))
104    }
105
106    /// Return the number of nodes in the graph.
107    pub fn node_count(&self) -> usize {
108        self.nodes.len()
109    }
110
111    /// Return the number of edges in the graph.
112    pub fn edge_count(&self) -> usize {
113        self.edges.len()
114    }
115
116    /// Add node `n` to the graph.
117    pub fn add_node(&mut self, n: N) {
118        self.nodes.entry(n).or_default();
119    }
120
121    /// Remove a node `n` from the graph.
122    ///
123    /// Computes in **O(N)** time, due to the removal of edges with other nodes.
124    pub fn remove_node(&mut self, n: N) {
125        let Some(links) = self.nodes.swap_remove(&n) else {
126            return;
127        };
128
129        let links = links.into_iter().map(N::Adjacent::into);
130
131        for (succ, dir) in links {
132            let edge = if dir == Outgoing {
133                Self::edge_key(n, succ)
134            } else {
135                Self::edge_key(succ, n)
136            };
137            // remove all successor links
138            self.remove_single_edge(succ, n, dir.opposite());
139            // Remove all edge values
140            self.edges.remove(&edge);
141        }
142    }
143
144    /// Return `true` if the node is contained in the graph.
145    pub fn contains_node(&self, n: N) -> bool {
146        self.nodes.contains_key(&n)
147    }
148
149    /// Add an edge connecting `a` and `b` to the graph.
150    /// For a directed graph, the edge is directed from `a` to `b`.
151    ///
152    /// Inserts nodes `a` and/or `b` if they aren't already part of the graph.
153    pub fn add_edge(&mut self, a: N, b: N) {
154        if self.edges.insert(Self::edge_key(a, b)) {
155            // insert in the adjacency list if it's a new edge
156            self.nodes
157                .entry(a)
158                .or_insert_with(|| Vec::with_capacity(1))
159                .push(N::Adjacent::from((b, Outgoing)));
160            if a != b {
161                // self loops don't have the Incoming entry
162                self.nodes
163                    .entry(b)
164                    .or_insert_with(|| Vec::with_capacity(1))
165                    .push(N::Adjacent::from((a, Incoming)));
166            }
167        }
168    }
169
170    /// Remove edge relation from a to b.
171    ///
172    /// Return `true` if it did exist.
173    fn remove_single_edge(&mut self, a: N, b: N, dir: Direction) -> bool {
174        let Some(sus) = self.nodes.get_mut(&a) else {
175            return false;
176        };
177
178        let Some(index) = sus
179            .iter()
180            .copied()
181            .map(N::Adjacent::into)
182            .position(|elt| (DIRECTED && elt == (b, dir)) || (!DIRECTED && elt.0 == b))
183        else {
184            return false;
185        };
186
187        sus.swap_remove(index);
188        true
189    }
190
191    /// Remove edge from `a` to `b` from the graph.
192    ///
193    /// Return `false` if the edge didn't exist.
194    pub fn remove_edge(&mut self, a: N, b: N) -> bool {
195        let exist1 = self.remove_single_edge(a, b, Outgoing);
196        let exist2 = if a != b {
197            self.remove_single_edge(b, a, Incoming)
198        } else {
199            exist1
200        };
201        let weight = self.edges.remove(&Self::edge_key(a, b));
202        debug_assert!(exist1 == exist2 && exist1 == weight);
203        weight
204    }
205
206    /// Return `true` if the edge connecting `a` with `b` is contained in the graph.
207    pub fn contains_edge(&self, a: N, b: N) -> bool {
208        self.edges.contains(&Self::edge_key(a, b))
209    }
210
211    /// Reserve capacity for at least `additional` more nodes to be inserted
212    /// in the graph.
213    pub fn reserve_nodes(&mut self, additional: usize) {
214        self.nodes.reserve(additional);
215    }
216
217    /// Reserve capacity for at least `additional` more edges to be inserted
218    /// in the graph.
219    pub fn reserve_edges(&mut self, additional: usize) {
220        self.edges.reserve(additional);
221    }
222
223    /// Return an iterator over the nodes of the graph.
224    pub fn nodes(&self) -> impl DoubleEndedIterator<Item = N> + ExactSizeIterator<Item = N> + '_ {
225        self.nodes.keys().copied()
226    }
227
228    /// Return an iterator of all nodes with an edge starting from `a`.
229    pub fn neighbors(&self, a: N) -> impl DoubleEndedIterator<Item = N> + '_ {
230        let iter = match self.nodes.get(&a) {
231            Some(neigh) => neigh.iter(),
232            None => [].iter(),
233        };
234
235        iter.copied()
236            .map(N::Adjacent::into)
237            .filter_map(|(n, dir)| (!DIRECTED || dir == Outgoing).then_some(n))
238    }
239
240    /// Return an iterator of all neighbors that have an edge between them and
241    /// `a`, in the specified direction.
242    /// If the graph's edges are undirected, this is equivalent to *.neighbors(a)*.
243    pub fn neighbors_directed(
244        &self,
245        a: N,
246        dir: Direction,
247    ) -> impl DoubleEndedIterator<Item = N> + '_ {
248        let iter = match self.nodes.get(&a) {
249            Some(neigh) => neigh.iter(),
250            None => [].iter(),
251        };
252
253        iter.copied()
254            .map(N::Adjacent::into)
255            .filter_map(move |(n, d)| (!DIRECTED || d == dir || n == a).then_some(n))
256    }
257
258    /// Return an iterator of target nodes with an edge starting from `a`,
259    /// paired with their respective edge weights.
260    pub fn edges(&self, a: N) -> impl DoubleEndedIterator<Item = (N, N)> + '_ {
261        self.neighbors(a)
262            .map(move |b| match self.edges.get(&Self::edge_key(a, b)) {
263                None => unreachable!(),
264                Some(_) => (a, b),
265            })
266    }
267
268    /// Return an iterator of target nodes with an edge starting from `a`,
269    /// paired with their respective edge weights.
270    pub fn edges_directed(
271        &self,
272        a: N,
273        dir: Direction,
274    ) -> impl DoubleEndedIterator<Item = (N, N)> + '_ {
275        self.neighbors_directed(a, dir).map(move |b| {
276            let (a, b) = if dir == Incoming { (b, a) } else { (a, b) };
277
278            match self.edges.get(&Self::edge_key(a, b)) {
279                None => unreachable!(),
280                Some(_) => (a, b),
281            }
282        })
283    }
284
285    /// Return an iterator over all edges of the graph with their weight in arbitrary order.
286    pub fn all_edges(&self) -> impl ExactSizeIterator<Item = (N, N)> + '_ {
287        self.edges.iter().copied().map(N::Edge::into)
288    }
289
290    pub(crate) fn to_index(&self, ix: N) -> usize {
291        self.nodes.get_index_of(&ix).unwrap()
292    }
293
294    /// Converts from one [`GraphNodeId`] type to another. If the conversion fails,
295    /// it returns the error from the target type's [`TryFrom`] implementation.
296    ///
297    /// Nodes must uniquely convert from `N` to `T` (i.e. no two `N` can convert
298    /// to the same `T`).
299    ///
300    /// # Errors
301    ///
302    /// If the conversion fails, it returns an error of type `N::Error`.
303    pub fn try_convert<T>(self) -> Result<Graph<DIRECTED, T, S>, N::Error>
304    where
305        N: TryInto<T>,
306        T: GraphNodeId,
307        S: Default,
308    {
309        // Converts the node key and every adjacency list entry from `N` to `T`.
310        fn try_convert_node<N: GraphNodeId + TryInto<T>, T: GraphNodeId>(
311            (key, adj): (N, Vec<N::Adjacent>),
312        ) -> Result<(T, Vec<T::Adjacent>), N::Error> {
313            let key = key.try_into()?;
314            let adj = adj
315                .into_iter()
316                .map(|node| {
317                    let (id, dir) = node.into();
318                    Ok(T::Adjacent::from((id.try_into()?, dir)))
319                })
320                .collect::<Result<_, N::Error>>()?;
321            Ok((key, adj))
322        }
323        // Unpacks the edge pair, converts the nodes from `N` to `T`, and repacks them.
324        fn try_convert_edge<N: GraphNodeId + TryInto<T>, T: GraphNodeId>(
325            edge: N::Edge,
326        ) -> Result<T::Edge, N::Error> {
327            let (a, b) = edge.into();
328            Ok(T::Edge::from((a.try_into()?, b.try_into()?)))
329        }
330
331        let nodes = self
332            .nodes
333            .into_iter()
334            .map(try_convert_node::<N, T>)
335            .collect::<Result<_, N::Error>>()?;
336        let edges = self
337            .edges
338            .into_iter()
339            .map(try_convert_edge::<N, T>)
340            .collect::<Result<_, N::Error>>()?;
341        Ok(Graph { nodes, edges })
342    }
343}
344
345/// Create a new empty `Graph`.
346impl<const DIRECTED: bool, N, S> Default for Graph<DIRECTED, N, S>
347where
348    N: GraphNodeId,
349    S: BuildHasher + Default,
350{
351    fn default() -> Self {
352        Self::with_capacity(0, 0)
353    }
354}
355
356impl<N: GraphNodeId, S: BuildHasher> DiGraph<N, S> {
357    /// Tries to topologically sort this directed graph.
358    ///
359    /// If the graph is acyclic, returns [`Ok`] with the list of [`GraphNodeId`]s
360    /// in a valid topological order. If the graph contains cycles, returns [`Err`]
361    /// with the list of strongly-connected components that contain cycles
362    /// (also in a valid topological order).
363    ///
364    /// # Errors
365    ///
366    /// - If the graph contains a self-loop, returns [`DiGraphToposortError::Loop`].
367    /// - If the graph contains cycles, returns [`DiGraphToposortError::Cycle`].
368    pub fn toposort(&self, mut scratch: Vec<N>) -> Result<Vec<N>, DiGraphToposortError<N>> {
369        // Check explicitly for self-edges.
370        // `iter_sccs` won't report them as cycles because they still form components of one node.
371        if let Some((node, _)) = self.all_edges().find(|(left, right)| left == right) {
372            return Err(DiGraphToposortError::Loop(node));
373        }
374
375        // Tarjan's SCC algorithm returns elements in *reverse* topological order.
376        scratch.clear();
377        scratch.reserve_exact(self.node_count().saturating_sub(scratch.capacity()));
378        let mut top_sorted_nodes = scratch;
379        let mut sccs_with_cycles = Vec::new();
380
381        for scc in self.iter_sccs() {
382            // A strongly-connected component is a group of nodes who can all reach each other
383            // through one or more paths. If an SCC contains more than one node, there must be
384            // at least one cycle within them.
385            top_sorted_nodes.extend_from_slice(&scc);
386            if scc.len() > 1 {
387                sccs_with_cycles.push(scc);
388            }
389        }
390
391        if sccs_with_cycles.is_empty() {
392            // reverse to get topological order
393            top_sorted_nodes.reverse();
394            Ok(top_sorted_nodes)
395        } else {
396            let mut cycles = Vec::new();
397            for scc in &sccs_with_cycles {
398                cycles.append(&mut self.simple_cycles_in_component(scc));
399            }
400
401            Err(DiGraphToposortError::Cycle(cycles))
402        }
403    }
404
405    /// Returns the simple cycles in a strongly-connected component of a directed graph.
406    ///
407    /// The algorithm implemented comes from
408    /// ["Finding all the elementary circuits of a directed graph"][1] by D. B. Johnson.
409    ///
410    /// [1]: https://doi.org/10.1137/0204007
411    pub fn simple_cycles_in_component(&self, scc: &[N]) -> Vec<Vec<N>> {
412        let mut cycles = vec![];
413        let mut sccs = vec![SmallVec::from_slice(scc)];
414
415        while let Some(mut scc) = sccs.pop() {
416            // only look at nodes and edges in this strongly-connected component
417            let mut subgraph = DiGraph::<N>::with_capacity(scc.len(), 0);
418            for &node in &scc {
419                subgraph.add_node(node);
420            }
421
422            for &node in &scc {
423                for successor in self.neighbors(node) {
424                    if subgraph.contains_node(successor) {
425                        subgraph.add_edge(node, successor);
426                    }
427                }
428            }
429
430            // path of nodes that may form a cycle
431            let mut path = Vec::with_capacity(subgraph.node_count());
432            // we mark nodes as "blocked" to avoid finding permutations of the same cycles
433            let mut blocked: HashSet<_> =
434                HashSet::with_capacity_and_hasher(subgraph.node_count(), Default::default());
435            // connects nodes along path segments that can't be part of a cycle (given current root)
436            // those nodes can be unblocked at the same time
437            let mut unblock_together: HashMap<N, HashSet<N>> =
438                HashMap::with_capacity_and_hasher(subgraph.node_count(), Default::default());
439            // stack for unblocking nodes
440            let mut unblock_stack = Vec::with_capacity(subgraph.node_count());
441            // nodes can be involved in multiple cycles
442            let mut maybe_in_more_cycles: HashSet<N> =
443                HashSet::with_capacity_and_hasher(subgraph.node_count(), Default::default());
444            // stack for DFS
445            let mut stack = Vec::with_capacity(subgraph.node_count());
446
447            // we're going to look for all cycles that begin and end at this node
448            let root = scc.pop().unwrap();
449            // start a path at the root
450            path.clear();
451            path.push(root);
452            // mark this node as blocked
453            blocked.insert(root);
454
455            // DFS
456            stack.clear();
457            stack.push((root, subgraph.neighbors(root)));
458            while !stack.is_empty() {
459                let &mut (ref node, ref mut successors) = stack.last_mut().unwrap();
460                if let Some(next) = successors.next() {
461                    if next == root {
462                        // found a cycle
463                        maybe_in_more_cycles.extend(path.iter());
464                        cycles.push(path.clone());
465                    } else if !blocked.contains(&next) {
466                        // first time seeing `next` on this path
467                        maybe_in_more_cycles.remove(&next);
468                        path.push(next);
469                        blocked.insert(next);
470                        stack.push((next, subgraph.neighbors(next)));
471                        continue;
472                    } else {
473                        // not first time seeing `next` on this path
474                    }
475                }
476
477                if successors.peekable().peek().is_none() {
478                    if maybe_in_more_cycles.contains(node) {
479                        unblock_stack.push(*node);
480                        // unblock this node's ancestors
481                        while let Some(n) = unblock_stack.pop() {
482                            if blocked.remove(&n) {
483                                let unblock_predecessors = unblock_together.entry(n).or_default();
484                                unblock_stack.extend(unblock_predecessors.iter());
485                                unblock_predecessors.clear();
486                            }
487                        }
488                    } else {
489                        // if its descendants can be unblocked later, this node will be too
490                        for successor in subgraph.neighbors(*node) {
491                            unblock_together.entry(successor).or_default().insert(*node);
492                        }
493                    }
494
495                    // remove node from path and DFS stack
496                    path.pop();
497                    stack.pop();
498                }
499            }
500
501            drop(stack);
502
503            // remove node from subgraph
504            subgraph.remove_node(root);
505
506            // divide remainder into smaller SCCs
507            sccs.extend(subgraph.iter_sccs().filter(|scc| scc.len() > 1));
508        }
509
510        cycles
511    }
512
513    /// Iterate over all *Strongly Connected Components* in this graph.
514    pub(crate) fn iter_sccs(&self) -> impl Iterator<Item = SmallVec<[N; 4]>> + '_ {
515        super::tarjan_scc::new_tarjan_scc(self)
516    }
517}
518
/// Error returned when topologically sorting a directed graph fails.
#[derive(Error, Debug)]
pub enum DiGraphToposortError<N: GraphNodeId> {
    /// A self-loop was detected.
    ///
    /// Carries the node that has an edge to itself.
    #[error("self-loop detected at node `{0:?}`")]
    Loop(N),
    /// Cycles were detected.
    ///
    /// Carries the cycles that were found; each inner `Vec` lists the nodes
    /// along one cycle.
    #[error("cycles detected: {0:?}")]
    Cycle(Vec<Vec<N>>),
}
529
/// Direction of an edge, as seen from one of its endpoints.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd, Ord, Eq, Hash)]
#[repr(u8)]
pub enum Direction {
    /// The edge leaves the current node (the node is the edge's source).
    Outgoing = 0,
    /// The edge arrives at the current node (the node is the edge's target).
    Incoming = 1,
}

impl Direction {
    /// Flips the direction: `Outgoing` becomes `Incoming` and vice versa.
    #[inline]
    pub fn opposite(self) -> Self {
        if self == Self::Outgoing {
            Self::Incoming
        } else {
            Self::Outgoing
        }
    }
}
550
551#[cfg(test)]
552mod tests {
553    use crate::schedule::{NodeId, SystemKey};
554
555    use super::*;
556    use alloc::vec;
557    use slotmap::SlotMap;
558
559    /// The `Graph` type _must_ preserve the order that nodes are inserted in if
560    /// no removals occur. Removals are permitted to swap the latest node into the
561    /// location of the removed node.
562    #[test]
563    fn node_order_preservation() {
564        use NodeId::System;
565
566        let mut slotmap = SlotMap::<SystemKey, ()>::with_key();
567        let mut graph = DiGraph::<NodeId>::default();
568
569        let sys1 = slotmap.insert(());
570        let sys2 = slotmap.insert(());
571        let sys3 = slotmap.insert(());
572        let sys4 = slotmap.insert(());
573
574        graph.add_node(System(sys1));
575        graph.add_node(System(sys2));
576        graph.add_node(System(sys3));
577        graph.add_node(System(sys4));
578
579        assert_eq!(
580            graph.nodes().collect::<Vec<_>>(),
581            vec![System(sys1), System(sys2), System(sys3), System(sys4)]
582        );
583
584        graph.remove_node(System(sys1));
585
586        assert_eq!(
587            graph.nodes().collect::<Vec<_>>(),
588            vec![System(sys4), System(sys2), System(sys3)]
589        );
590
591        graph.remove_node(System(sys4));
592
593        assert_eq!(
594            graph.nodes().collect::<Vec<_>>(),
595            vec![System(sys3), System(sys2)]
596        );
597
598        graph.remove_node(System(sys2));
599
600        assert_eq!(graph.nodes().collect::<Vec<_>>(), vec![System(sys3)]);
601
602        graph.remove_node(System(sys3));
603
604        assert_eq!(graph.nodes().collect::<Vec<_>>(), vec![]);
605    }
606
607    /// Nodes that have bidirectional edges (or any edge in the case of undirected graphs) are
608    /// considered strongly connected. A strongly connected component is a collection of
609    /// nodes where there exists a path from any node to any other node in the collection.
610    #[test]
611    fn strongly_connected_components() {
612        use NodeId::System;
613
614        let mut slotmap = SlotMap::<SystemKey, ()>::with_key();
615        let mut graph = DiGraph::<NodeId>::default();
616
617        let sys1 = slotmap.insert(());
618        let sys2 = slotmap.insert(());
619        let sys3 = slotmap.insert(());
620        let sys4 = slotmap.insert(());
621        let sys5 = slotmap.insert(());
622        let sys6 = slotmap.insert(());
623
624        graph.add_edge(System(sys1), System(sys2));
625        graph.add_edge(System(sys2), System(sys1));
626
627        graph.add_edge(System(sys2), System(sys3));
628        graph.add_edge(System(sys3), System(sys2));
629
630        graph.add_edge(System(sys4), System(sys5));
631        graph.add_edge(System(sys5), System(sys4));
632
633        graph.add_edge(System(sys6), System(sys2));
634
635        let sccs = graph
636            .iter_sccs()
637            .map(|scc| scc.to_vec())
638            .collect::<Vec<_>>();
639
640        assert_eq!(
641            sccs,
642            vec![
643                vec![System(sys3), System(sys2), System(sys1)],
644                vec![System(sys5), System(sys4)],
645                vec![System(sys6)]
646            ]
647        );
648    }
649}