bevy_ecs/schedule/graph/
graph_map.rs

//! `Graph<DIRECTED>` is a graph data structure whose node values are used as map
//! keys.
//! Based on the `GraphMap` data structure from [`petgraph`].
//!
//! [`petgraph`]: https://docs.rs/petgraph/0.6.5/petgraph/
6
7use alloc::vec::Vec;
8use bevy_platform::{collections::HashSet, hash::FixedHasher};
9use core::{
10    fmt,
11    hash::{BuildHasher, Hash},
12};
13use indexmap::IndexMap;
14use smallvec::SmallVec;
15
16use super::NodeId;
17
18use Direction::{Incoming, Outgoing};
19
20/// A `Graph` with undirected edges.
21///
22/// For example, an edge between *1* and *2* is equivalent to an edge between
23/// *2* and *1*.
24pub type UnGraph<S = FixedHasher> = Graph<false, S>;
25
26/// A `Graph` with directed edges.
27///
28/// For example, an edge from *1* to *2* is distinct from an edge from *2* to
29/// *1*.
30pub type DiGraph<S = FixedHasher> = Graph<true, S>;
31
32/// `Graph<DIRECTED>` is a graph datastructure using an associative array
33/// of its node weights `NodeId`.
34///
35/// It uses a combined adjacency list and sparse adjacency matrix
36/// representation, using **O(|N| + |E|)** space, and allows testing for edge
37/// existence in constant time.
38///
39/// `Graph` is parameterized over:
40///
41/// - Constant generic bool `DIRECTED` determines whether the graph edges are directed or
42///   undirected.
43/// - The `BuildHasher` `S`.
44///
45/// You can use the type aliases `UnGraph` and `DiGraph` for convenience.
46///
47/// `Graph` does not allow parallel edges, but self loops are allowed.
48#[derive(Clone)]
49pub struct Graph<const DIRECTED: bool, S = FixedHasher>
50where
51    S: BuildHasher,
52{
53    nodes: IndexMap<NodeId, Vec<CompactNodeIdAndDirection>, S>,
54    edges: HashSet<CompactNodeIdPair, S>,
55}
56
57impl<const DIRECTED: bool, S: BuildHasher> fmt::Debug for Graph<DIRECTED, S> {
58    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
59        self.nodes.fmt(f)
60    }
61}
62
63impl<const DIRECTED: bool, S> Graph<DIRECTED, S>
64where
65    S: BuildHasher,
66{
67    /// Create a new `Graph` with estimated capacity.
68    pub fn with_capacity(nodes: usize, edges: usize) -> Self
69    where
70        S: Default,
71    {
72        Self {
73            nodes: IndexMap::with_capacity_and_hasher(nodes, S::default()),
74            edges: HashSet::with_capacity_and_hasher(edges, S::default()),
75        }
76    }
77
78    /// Use their natural order to map the node pair (a, b) to a canonical edge id.
79    #[inline]
80    fn edge_key(a: NodeId, b: NodeId) -> CompactNodeIdPair {
81        let (a, b) = if DIRECTED || a <= b { (a, b) } else { (b, a) };
82
83        CompactNodeIdPair::store(a, b)
84    }
85
86    /// Return the number of nodes in the graph.
87    pub fn node_count(&self) -> usize {
88        self.nodes.len()
89    }
90
91    /// Add node `n` to the graph.
92    pub fn add_node(&mut self, n: NodeId) {
93        self.nodes.entry(n).or_default();
94    }
95
96    /// Remove a node `n` from the graph.
97    ///
98    /// Computes in **O(N)** time, due to the removal of edges with other nodes.
99    pub fn remove_node(&mut self, n: NodeId) {
100        let Some(links) = self.nodes.swap_remove(&n) else {
101            return;
102        };
103
104        let links = links.into_iter().map(CompactNodeIdAndDirection::load);
105
106        for (succ, dir) in links {
107            let edge = if dir == Outgoing {
108                Self::edge_key(n, succ)
109            } else {
110                Self::edge_key(succ, n)
111            };
112            // remove all successor links
113            self.remove_single_edge(succ, n, dir.opposite());
114            // Remove all edge values
115            self.edges.remove(&edge);
116        }
117    }
118
119    /// Return `true` if the node is contained in the graph.
120    pub fn contains_node(&self, n: NodeId) -> bool {
121        self.nodes.contains_key(&n)
122    }
123
124    /// Add an edge connecting `a` and `b` to the graph.
125    /// For a directed graph, the edge is directed from `a` to `b`.
126    ///
127    /// Inserts nodes `a` and/or `b` if they aren't already part of the graph.
128    pub fn add_edge(&mut self, a: NodeId, b: NodeId) {
129        if self.edges.insert(Self::edge_key(a, b)) {
130            // insert in the adjacency list if it's a new edge
131            self.nodes
132                .entry(a)
133                .or_insert_with(|| Vec::with_capacity(1))
134                .push(CompactNodeIdAndDirection::store(b, Outgoing));
135            if a != b {
136                // self loops don't have the Incoming entry
137                self.nodes
138                    .entry(b)
139                    .or_insert_with(|| Vec::with_capacity(1))
140                    .push(CompactNodeIdAndDirection::store(a, Incoming));
141            }
142        }
143    }
144
145    /// Remove edge relation from a to b
146    ///
147    /// Return `true` if it did exist.
148    fn remove_single_edge(&mut self, a: NodeId, b: NodeId, dir: Direction) -> bool {
149        let Some(sus) = self.nodes.get_mut(&a) else {
150            return false;
151        };
152
153        let Some(index) = sus
154            .iter()
155            .copied()
156            .map(CompactNodeIdAndDirection::load)
157            .position(|elt| (DIRECTED && elt == (b, dir)) || (!DIRECTED && elt.0 == b))
158        else {
159            return false;
160        };
161
162        sus.swap_remove(index);
163        true
164    }
165
166    /// Remove edge from `a` to `b` from the graph.
167    ///
168    /// Return `false` if the edge didn't exist.
169    pub fn remove_edge(&mut self, a: NodeId, b: NodeId) -> bool {
170        let exist1 = self.remove_single_edge(a, b, Outgoing);
171        let exist2 = if a != b {
172            self.remove_single_edge(b, a, Incoming)
173        } else {
174            exist1
175        };
176        let weight = self.edges.remove(&Self::edge_key(a, b));
177        debug_assert!(exist1 == exist2 && exist1 == weight);
178        weight
179    }
180
181    /// Return `true` if the edge connecting `a` with `b` is contained in the graph.
182    pub fn contains_edge(&self, a: NodeId, b: NodeId) -> bool {
183        self.edges.contains(&Self::edge_key(a, b))
184    }
185
186    /// Return an iterator over the nodes of the graph.
187    pub fn nodes(
188        &self,
189    ) -> impl DoubleEndedIterator<Item = NodeId> + ExactSizeIterator<Item = NodeId> + '_ {
190        self.nodes.keys().copied()
191    }
192
193    /// Return an iterator of all nodes with an edge starting from `a`.
194    pub fn neighbors(&self, a: NodeId) -> impl DoubleEndedIterator<Item = NodeId> + '_ {
195        let iter = match self.nodes.get(&a) {
196            Some(neigh) => neigh.iter(),
197            None => [].iter(),
198        };
199
200        iter.copied()
201            .map(CompactNodeIdAndDirection::load)
202            .filter_map(|(n, dir)| (!DIRECTED || dir == Outgoing).then_some(n))
203    }
204
205    /// Return an iterator of all neighbors that have an edge between them and
206    /// `a`, in the specified direction.
207    /// If the graph's edges are undirected, this is equivalent to *.neighbors(a)*.
208    pub fn neighbors_directed(
209        &self,
210        a: NodeId,
211        dir: Direction,
212    ) -> impl DoubleEndedIterator<Item = NodeId> + '_ {
213        let iter = match self.nodes.get(&a) {
214            Some(neigh) => neigh.iter(),
215            None => [].iter(),
216        };
217
218        iter.copied()
219            .map(CompactNodeIdAndDirection::load)
220            .filter_map(move |(n, d)| (!DIRECTED || d == dir || n == a).then_some(n))
221    }
222
223    /// Return an iterator of target nodes with an edge starting from `a`,
224    /// paired with their respective edge weights.
225    pub fn edges(&self, a: NodeId) -> impl DoubleEndedIterator<Item = (NodeId, NodeId)> + '_ {
226        self.neighbors(a)
227            .map(move |b| match self.edges.get(&Self::edge_key(a, b)) {
228                None => unreachable!(),
229                Some(_) => (a, b),
230            })
231    }
232
233    /// Return an iterator of target nodes with an edge starting from `a`,
234    /// paired with their respective edge weights.
235    pub fn edges_directed(
236        &self,
237        a: NodeId,
238        dir: Direction,
239    ) -> impl DoubleEndedIterator<Item = (NodeId, NodeId)> + '_ {
240        self.neighbors_directed(a, dir).map(move |b| {
241            let (a, b) = if dir == Incoming { (b, a) } else { (a, b) };
242
243            match self.edges.get(&Self::edge_key(a, b)) {
244                None => unreachable!(),
245                Some(_) => (a, b),
246            }
247        })
248    }
249
250    /// Return an iterator over all edges of the graph with their weight in arbitrary order.
251    pub fn all_edges(&self) -> impl ExactSizeIterator<Item = (NodeId, NodeId)> + '_ {
252        self.edges.iter().copied().map(CompactNodeIdPair::load)
253    }
254
255    pub(crate) fn to_index(&self, ix: NodeId) -> usize {
256        self.nodes.get_index_of(&ix).unwrap()
257    }
258}
259
260/// Create a new empty `Graph`.
261impl<const DIRECTED: bool, S> Default for Graph<DIRECTED, S>
262where
263    S: BuildHasher + Default,
264{
265    fn default() -> Self {
266        Self::with_capacity(0, 0)
267    }
268}
269
270impl<S: BuildHasher> DiGraph<S> {
271    /// Iterate over all *Strongly Connected Components* in this graph.
272    pub(crate) fn iter_sccs(&self) -> impl Iterator<Item = SmallVec<[NodeId; 4]>> + '_ {
273        super::tarjan_scc::new_tarjan_scc(self)
274    }
275}
276
/// Edge direction.
///
/// The discriminants are fixed (`Outgoing = 0`, `Incoming = 1`), so
/// `Outgoing` orders before `Incoming`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum Direction {
    /// An `Outgoing` edge is an outward edge *from* the current node.
    Outgoing = 0,
    /// An `Incoming` edge is an inbound edge *to* the current node.
    Incoming = 1,
}

impl Direction {
    /// Return the opposite `Direction`: `Outgoing` becomes `Incoming` and
    /// vice versa.
    #[inline]
    pub fn opposite(self) -> Self {
        if self == Direction::Outgoing {
            Direction::Incoming
        } else {
            Direction::Outgoing
        }
    }
}
297
298/// Compact storage of a [`NodeId`] and a [`Direction`].
299#[derive(Clone, Copy)]
300struct CompactNodeIdAndDirection {
301    index: usize,
302    is_system: bool,
303    direction: Direction,
304}
305
306impl fmt::Debug for CompactNodeIdAndDirection {
307    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
308        self.load().fmt(f)
309    }
310}
311
312impl CompactNodeIdAndDirection {
313    const fn store(node: NodeId, direction: Direction) -> Self {
314        let index = node.index();
315        let is_system = node.is_system();
316
317        Self {
318            index,
319            is_system,
320            direction,
321        }
322    }
323
324    const fn load(self) -> (NodeId, Direction) {
325        let Self {
326            index,
327            is_system,
328            direction,
329        } = self;
330
331        let node = match is_system {
332            true => NodeId::System(index),
333            false => NodeId::Set(index),
334        };
335
336        (node, direction)
337    }
338}
339
340/// Compact storage of a [`NodeId`] pair.
341#[derive(Clone, Copy, Hash, PartialEq, Eq)]
342struct CompactNodeIdPair {
343    index_a: usize,
344    index_b: usize,
345    is_system_a: bool,
346    is_system_b: bool,
347}
348
349impl fmt::Debug for CompactNodeIdPair {
350    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
351        self.load().fmt(f)
352    }
353}
354
355impl CompactNodeIdPair {
356    const fn store(a: NodeId, b: NodeId) -> Self {
357        let index_a = a.index();
358        let is_system_a = a.is_system();
359
360        let index_b = b.index();
361        let is_system_b = b.is_system();
362
363        Self {
364            index_a,
365            index_b,
366            is_system_a,
367            is_system_b,
368        }
369    }
370
371    const fn load(self) -> (NodeId, NodeId) {
372        let Self {
373            index_a,
374            index_b,
375            is_system_a,
376            is_system_b,
377        } = self;
378
379        let a = match is_system_a {
380            true => NodeId::System(index_a),
381            false => NodeId::Set(index_a),
382        };
383
384        let b = match is_system_b {
385            true => NodeId::System(index_b),
386            false => NodeId::Set(index_b),
387        };
388
389        (a, b)
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Snapshot of the graph's nodes in iteration order.
    fn node_list(graph: &DiGraph) -> Vec<NodeId> {
        graph.nodes().collect()
    }

    /// Insertion order must survive as long as nothing is removed. A removal
    /// may move the most recently inserted node into the vacated slot.
    #[test]
    fn node_order_preservation() {
        use NodeId::System;

        let mut graph: DiGraph = Default::default();

        for i in 1..=4 {
            graph.add_node(System(i));
        }
        assert_eq!(node_list(&graph), vec![System(1), System(2), System(3), System(4)]);

        graph.remove_node(System(1));
        assert_eq!(node_list(&graph), vec![System(4), System(2), System(3)]);

        graph.remove_node(System(4));
        assert_eq!(node_list(&graph), vec![System(3), System(2)]);

        graph.remove_node(System(2));
        assert_eq!(node_list(&graph), vec![System(3)]);

        graph.remove_node(System(3));
        assert!(node_list(&graph).is_empty());
    }

    /// A strongly connected component is a set of nodes in which every node
    /// can reach every other node. Pairs of nodes joined in both directions
    /// therefore end up in the same component, while a node with only a
    /// one-way edge forms a component of its own.
    #[test]
    fn strongly_connected_components() {
        use NodeId::System;

        let mut graph: DiGraph = Default::default();

        let edges = [(1, 2), (2, 1), (2, 3), (3, 2), (4, 5), (5, 4), (6, 2)];
        for (from, to) in edges {
            graph.add_edge(System(from), System(to));
        }

        let sccs: Vec<Vec<NodeId>> = graph.iter_sccs().map(|scc| scc.to_vec()).collect();

        assert_eq!(
            sccs,
            vec![
                vec![System(3), System(2), System(1)],
                vec![System(5), System(4)],
                vec![System(6)],
            ]
        );
    }
}