1use alloc::vec::Vec;
8use bevy_platform::{collections::HashSet, hash::FixedHasher};
9use core::{
10 fmt,
11 hash::{BuildHasher, Hash},
12};
13use indexmap::IndexMap;
14use smallvec::SmallVec;
15
16use super::NodeId;
17
18use Direction::{Incoming, Outgoing};
19
20pub type UnGraph<S = FixedHasher> = Graph<false, S>;
25
26pub type DiGraph<S = FixedHasher> = Graph<true, S>;
31
32#[derive(Clone)]
49pub struct Graph<const DIRECTED: bool, S = FixedHasher>
50where
51 S: BuildHasher,
52{
53 nodes: IndexMap<NodeId, Vec<CompactNodeIdAndDirection>, S>,
54 edges: HashSet<CompactNodeIdPair, S>,
55}
56
57impl<const DIRECTED: bool, S: BuildHasher> fmt::Debug for Graph<DIRECTED, S> {
58 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
59 self.nodes.fmt(f)
60 }
61}
62
63impl<const DIRECTED: bool, S> Graph<DIRECTED, S>
64where
65 S: BuildHasher,
66{
67 pub fn with_capacity(nodes: usize, edges: usize) -> Self
69 where
70 S: Default,
71 {
72 Self {
73 nodes: IndexMap::with_capacity_and_hasher(nodes, S::default()),
74 edges: HashSet::with_capacity_and_hasher(edges, S::default()),
75 }
76 }
77
78 #[inline]
80 fn edge_key(a: NodeId, b: NodeId) -> CompactNodeIdPair {
81 let (a, b) = if DIRECTED || a <= b { (a, b) } else { (b, a) };
82
83 CompactNodeIdPair::store(a, b)
84 }
85
86 pub fn node_count(&self) -> usize {
88 self.nodes.len()
89 }
90
91 pub fn add_node(&mut self, n: NodeId) {
93 self.nodes.entry(n).or_default();
94 }
95
96 pub fn remove_node(&mut self, n: NodeId) {
100 let Some(links) = self.nodes.swap_remove(&n) else {
101 return;
102 };
103
104 let links = links.into_iter().map(CompactNodeIdAndDirection::load);
105
106 for (succ, dir) in links {
107 let edge = if dir == Outgoing {
108 Self::edge_key(n, succ)
109 } else {
110 Self::edge_key(succ, n)
111 };
112 self.remove_single_edge(succ, n, dir.opposite());
114 self.edges.remove(&edge);
116 }
117 }
118
119 pub fn contains_node(&self, n: NodeId) -> bool {
121 self.nodes.contains_key(&n)
122 }
123
124 pub fn add_edge(&mut self, a: NodeId, b: NodeId) {
129 if self.edges.insert(Self::edge_key(a, b)) {
130 self.nodes
132 .entry(a)
133 .or_insert_with(|| Vec::with_capacity(1))
134 .push(CompactNodeIdAndDirection::store(b, Outgoing));
135 if a != b {
136 self.nodes
138 .entry(b)
139 .or_insert_with(|| Vec::with_capacity(1))
140 .push(CompactNodeIdAndDirection::store(a, Incoming));
141 }
142 }
143 }
144
145 fn remove_single_edge(&mut self, a: NodeId, b: NodeId, dir: Direction) -> bool {
149 let Some(sus) = self.nodes.get_mut(&a) else {
150 return false;
151 };
152
153 let Some(index) = sus
154 .iter()
155 .copied()
156 .map(CompactNodeIdAndDirection::load)
157 .position(|elt| (DIRECTED && elt == (b, dir)) || (!DIRECTED && elt.0 == b))
158 else {
159 return false;
160 };
161
162 sus.swap_remove(index);
163 true
164 }
165
166 pub fn remove_edge(&mut self, a: NodeId, b: NodeId) -> bool {
170 let exist1 = self.remove_single_edge(a, b, Outgoing);
171 let exist2 = if a != b {
172 self.remove_single_edge(b, a, Incoming)
173 } else {
174 exist1
175 };
176 let weight = self.edges.remove(&Self::edge_key(a, b));
177 debug_assert!(exist1 == exist2 && exist1 == weight);
178 weight
179 }
180
181 pub fn contains_edge(&self, a: NodeId, b: NodeId) -> bool {
183 self.edges.contains(&Self::edge_key(a, b))
184 }
185
186 pub fn nodes(
188 &self,
189 ) -> impl DoubleEndedIterator<Item = NodeId> + ExactSizeIterator<Item = NodeId> + '_ {
190 self.nodes.keys().copied()
191 }
192
193 pub fn neighbors(&self, a: NodeId) -> impl DoubleEndedIterator<Item = NodeId> + '_ {
195 let iter = match self.nodes.get(&a) {
196 Some(neigh) => neigh.iter(),
197 None => [].iter(),
198 };
199
200 iter.copied()
201 .map(CompactNodeIdAndDirection::load)
202 .filter_map(|(n, dir)| (!DIRECTED || dir == Outgoing).then_some(n))
203 }
204
205 pub fn neighbors_directed(
209 &self,
210 a: NodeId,
211 dir: Direction,
212 ) -> impl DoubleEndedIterator<Item = NodeId> + '_ {
213 let iter = match self.nodes.get(&a) {
214 Some(neigh) => neigh.iter(),
215 None => [].iter(),
216 };
217
218 iter.copied()
219 .map(CompactNodeIdAndDirection::load)
220 .filter_map(move |(n, d)| (!DIRECTED || d == dir || n == a).then_some(n))
221 }
222
223 pub fn edges(&self, a: NodeId) -> impl DoubleEndedIterator<Item = (NodeId, NodeId)> + '_ {
226 self.neighbors(a)
227 .map(move |b| match self.edges.get(&Self::edge_key(a, b)) {
228 None => unreachable!(),
229 Some(_) => (a, b),
230 })
231 }
232
233 pub fn edges_directed(
236 &self,
237 a: NodeId,
238 dir: Direction,
239 ) -> impl DoubleEndedIterator<Item = (NodeId, NodeId)> + '_ {
240 self.neighbors_directed(a, dir).map(move |b| {
241 let (a, b) = if dir == Incoming { (b, a) } else { (a, b) };
242
243 match self.edges.get(&Self::edge_key(a, b)) {
244 None => unreachable!(),
245 Some(_) => (a, b),
246 }
247 })
248 }
249
250 pub fn all_edges(&self) -> impl ExactSizeIterator<Item = (NodeId, NodeId)> + '_ {
252 self.edges.iter().copied().map(CompactNodeIdPair::load)
253 }
254
255 pub(crate) fn to_index(&self, ix: NodeId) -> usize {
256 self.nodes.get_index_of(&ix).unwrap()
257 }
258}
259
260impl<const DIRECTED: bool, S> Default for Graph<DIRECTED, S>
262where
263 S: BuildHasher + Default,
264{
265 fn default() -> Self {
266 Self::with_capacity(0, 0)
267 }
268}
269
270impl<S: BuildHasher> DiGraph<S> {
271 pub(crate) fn iter_sccs(&self) -> impl Iterator<Item = SmallVec<[NodeId; 4]>> + '_ {
273 super::tarjan_scc::new_tarjan_scc(self)
274 }
275}
276
/// The direction of an edge relative to one of its endpoints.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd, Ord, Eq, Hash)]
#[repr(u8)]
pub enum Direction {
    /// An edge leaving the current node (the node is the edge's source).
    Outgoing = 0,
    /// An edge arriving at the current node (the node is the edge's target).
    Incoming = 1,
}

impl Direction {
    /// Returns the other direction: `Outgoing` becomes `Incoming` and vice versa.
    ///
    /// This is a `const fn`, so it can also be evaluated in constant contexts.
    #[inline]
    pub const fn opposite(self) -> Self {
        match self {
            Self::Outgoing => Self::Incoming,
            Self::Incoming => Self::Outgoing,
        }
    }
}
297
298#[derive(Clone, Copy)]
300struct CompactNodeIdAndDirection {
301 index: usize,
302 is_system: bool,
303 direction: Direction,
304}
305
306impl fmt::Debug for CompactNodeIdAndDirection {
307 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
308 self.load().fmt(f)
309 }
310}
311
312impl CompactNodeIdAndDirection {
313 const fn store(node: NodeId, direction: Direction) -> Self {
314 let index = node.index();
315 let is_system = node.is_system();
316
317 Self {
318 index,
319 is_system,
320 direction,
321 }
322 }
323
324 const fn load(self) -> (NodeId, Direction) {
325 let Self {
326 index,
327 is_system,
328 direction,
329 } = self;
330
331 let node = match is_system {
332 true => NodeId::System(index),
333 false => NodeId::Set(index),
334 };
335
336 (node, direction)
337 }
338}
339
340#[derive(Clone, Copy, Hash, PartialEq, Eq)]
342struct CompactNodeIdPair {
343 index_a: usize,
344 index_b: usize,
345 is_system_a: bool,
346 is_system_b: bool,
347}
348
349impl fmt::Debug for CompactNodeIdPair {
350 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
351 self.load().fmt(f)
352 }
353}
354
355impl CompactNodeIdPair {
356 const fn store(a: NodeId, b: NodeId) -> Self {
357 let index_a = a.index();
358 let is_system_a = a.is_system();
359
360 let index_b = b.index();
361 let is_system_b = b.is_system();
362
363 Self {
364 index_a,
365 index_b,
366 is_system_a,
367 is_system_b,
368 }
369 }
370
371 const fn load(self) -> (NodeId, NodeId) {
372 let Self {
373 index_a,
374 index_b,
375 is_system_a,
376 is_system_b,
377 } = self;
378
379 let a = match is_system_a {
380 true => NodeId::System(index_a),
381 false => NodeId::Set(index_a),
382 };
383
384 let b = match is_system_b {
385 true => NodeId::System(index_b),
386 false => NodeId::Set(index_b),
387 };
388
389 (a, b)
390 }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Nodes keep their insertion order, except that removing a node swaps the last node
    /// into the removed node's position.
    #[test]
    fn node_order_preservation() {
        use NodeId::System;

        let mut graph = <DiGraph>::default();

        graph.add_node(System(1));
        graph.add_node(System(2));
        graph.add_node(System(3));
        graph.add_node(System(4));

        assert_eq!(
            graph.nodes().collect::<Vec<_>>(),
            vec![System(1), System(2), System(3), System(4)]
        );

        graph.remove_node(System(1));

        assert_eq!(
            graph.nodes().collect::<Vec<_>>(),
            vec![System(4), System(2), System(3)]
        );

        graph.remove_node(System(4));

        assert_eq!(
            graph.nodes().collect::<Vec<_>>(),
            vec![System(3), System(2)]
        );

        graph.remove_node(System(2));

        assert_eq!(graph.nodes().collect::<Vec<_>>(), vec![System(3)]);

        graph.remove_node(System(3));

        assert_eq!(graph.nodes().collect::<Vec<_>>(), vec![]);
    }

    /// Nodes linked by edges in both directions form strongly connected components;
    /// a node with no such partner forms a component on its own.
    #[test]
    fn strongly_connected_components() {
        use NodeId::System;

        let mut graph = <DiGraph>::default();

        graph.add_edge(System(1), System(2));
        graph.add_edge(System(2), System(1));

        graph.add_edge(System(2), System(3));
        graph.add_edge(System(3), System(2));

        graph.add_edge(System(4), System(5));
        graph.add_edge(System(5), System(4));

        graph.add_edge(System(6), System(2));

        let sccs = graph
            .iter_sccs()
            .map(|scc| scc.to_vec())
            .collect::<Vec<_>>();

        assert_eq!(
            sccs,
            vec![
                vec![System(3), System(2), System(1)],
                vec![System(5), System(4)],
                vec![System(6)]
            ]
        );
    }

    /// Covers duplicate edges, self-loops, directed neighbor and edge queries, and edge
    /// and node removal on a directed graph.
    #[test]
    fn directed_edges_and_self_loops() {
        use NodeId::System;

        let mut graph = <DiGraph>::default();

        graph.add_edge(System(1), System(2));
        // Re-adding an existing edge is a no-op.
        graph.add_edge(System(1), System(2));
        graph.add_edge(System(2), System(2));

        assert_eq!(graph.node_count(), 2);
        assert_eq!(graph.all_edges().len(), 2);
        assert!(graph.contains_edge(System(1), System(2)));
        assert!(!graph.contains_edge(System(2), System(1)));
        assert!(graph.contains_edge(System(2), System(2)));

        assert_eq!(
            graph.neighbors(System(1)).collect::<Vec<_>>(),
            vec![System(2)]
        );
        assert_eq!(
            graph.neighbors(System(2)).collect::<Vec<_>>(),
            vec![System(2)]
        );
        assert_eq!(
            graph
                .neighbors_directed(System(2), Direction::Incoming)
                .collect::<Vec<_>>(),
            vec![System(1), System(2)]
        );
        assert_eq!(
            graph
                .edges_directed(System(2), Direction::Incoming)
                .collect::<Vec<_>>(),
            vec![(System(1), System(2)), (System(2), System(2))]
        );

        assert!(graph.remove_edge(System(1), System(2)));
        assert!(!graph.remove_edge(System(1), System(2)));
        assert!(!graph.contains_edge(System(1), System(2)));
        assert_eq!(graph.neighbors(System(1)).count(), 0);

        graph.remove_node(System(2));
        assert_eq!(graph.nodes().collect::<Vec<_>>(), vec![System(1)]);
        assert_eq!(graph.all_edges().len(), 0);
    }

    /// In an undirected graph, `(a, b)` and `(b, a)` name the same edge.
    #[test]
    fn undirected_edges_are_symmetric() {
        use NodeId::System;

        let mut graph = <UnGraph>::default();

        graph.add_edge(System(1), System(2));
        // The reversed pair is the same undirected edge, so this is a no-op.
        graph.add_edge(System(2), System(1));

        assert_eq!(graph.all_edges().len(), 1);
        assert!(graph.contains_edge(System(1), System(2)));
        assert!(graph.contains_edge(System(2), System(1)));
        assert_eq!(
            graph.neighbors(System(1)).collect::<Vec<_>>(),
            vec![System(2)]
        );
        assert_eq!(
            graph.neighbors(System(2)).collect::<Vec<_>>(),
            vec![System(1)]
        );

        assert!(graph.remove_edge(System(2), System(1)));
        assert!(!graph.contains_edge(System(1), System(2)));
        assert_eq!(graph.all_edges().len(), 0);
    }
}