1use alloc::{vec, vec::Vec};
8use core::{
9 fmt::{self, Debug},
10 hash::{BuildHasher, Hash},
11};
12use thiserror::Error;
13
14use bevy_platform::{
15 collections::{HashMap, HashSet},
16 hash::FixedHasher,
17};
18use indexmap::IndexMap;
19use smallvec::SmallVec;
20
21use Direction::{Incoming, Outgoing};
22
23pub trait GraphNodeId: Copy + Eq + Hash + Ord + Debug {
28 type Adjacent: Copy + Debug + From<(Self, Direction)> + Into<(Self, Direction)>;
31 type Edge: Copy + Eq + Hash + Debug + From<(Self, Self)> + Into<(Self, Self)>;
34
35 fn kind(&self) -> &'static str;
40}
41
42pub type UnGraph<N, S = FixedHasher> = Graph<false, N, S>;
47
48pub type DiGraph<N, S = FixedHasher> = Graph<true, N, S>;
53
54#[derive(Clone)]
72pub struct Graph<const DIRECTED: bool, N: GraphNodeId, S = FixedHasher>
73where
74 S: BuildHasher,
75{
76 nodes: IndexMap<N, Vec<N::Adjacent>, S>,
77 edges: HashSet<N::Edge, S>,
78}
79
80impl<const DIRECTED: bool, N: GraphNodeId, S: BuildHasher> Debug for Graph<DIRECTED, N, S> {
81 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
82 self.nodes.fmt(f)
83 }
84}
85
86impl<const DIRECTED: bool, N: GraphNodeId, S: BuildHasher> Graph<DIRECTED, N, S> {
87 pub fn with_capacity(nodes: usize, edges: usize) -> Self
89 where
90 S: Default,
91 {
92 Self {
93 nodes: IndexMap::with_capacity_and_hasher(nodes, S::default()),
94 edges: HashSet::with_capacity_and_hasher(edges, S::default()),
95 }
96 }
97
98 #[inline]
100 fn edge_key(a: N, b: N) -> N::Edge {
101 let (a, b) = if DIRECTED || a <= b { (a, b) } else { (b, a) };
102
103 N::Edge::from((a, b))
104 }
105
106 pub fn node_count(&self) -> usize {
108 self.nodes.len()
109 }
110
111 pub fn edge_count(&self) -> usize {
113 self.edges.len()
114 }
115
116 pub fn add_node(&mut self, n: N) {
118 self.nodes.entry(n).or_default();
119 }
120
121 pub fn remove_node(&mut self, n: N) {
125 let Some(links) = self.nodes.swap_remove(&n) else {
126 return;
127 };
128
129 let links = links.into_iter().map(N::Adjacent::into);
130
131 for (succ, dir) in links {
132 let edge = if dir == Outgoing {
133 Self::edge_key(n, succ)
134 } else {
135 Self::edge_key(succ, n)
136 };
137 self.remove_single_edge(succ, n, dir.opposite());
139 self.edges.remove(&edge);
141 }
142 }
143
144 pub fn contains_node(&self, n: N) -> bool {
146 self.nodes.contains_key(&n)
147 }
148
149 pub fn add_edge(&mut self, a: N, b: N) {
154 if self.edges.insert(Self::edge_key(a, b)) {
155 self.nodes
157 .entry(a)
158 .or_insert_with(|| Vec::with_capacity(1))
159 .push(N::Adjacent::from((b, Outgoing)));
160 if a != b {
161 self.nodes
163 .entry(b)
164 .or_insert_with(|| Vec::with_capacity(1))
165 .push(N::Adjacent::from((a, Incoming)));
166 }
167 }
168 }
169
170 fn remove_single_edge(&mut self, a: N, b: N, dir: Direction) -> bool {
174 let Some(sus) = self.nodes.get_mut(&a) else {
175 return false;
176 };
177
178 let Some(index) = sus
179 .iter()
180 .copied()
181 .map(N::Adjacent::into)
182 .position(|elt| (DIRECTED && elt == (b, dir)) || (!DIRECTED && elt.0 == b))
183 else {
184 return false;
185 };
186
187 sus.swap_remove(index);
188 true
189 }
190
191 pub fn remove_edge(&mut self, a: N, b: N) -> bool {
195 let exist1 = self.remove_single_edge(a, b, Outgoing);
196 let exist2 = if a != b {
197 self.remove_single_edge(b, a, Incoming)
198 } else {
199 exist1
200 };
201 let weight = self.edges.remove(&Self::edge_key(a, b));
202 debug_assert!(exist1 == exist2 && exist1 == weight);
203 weight
204 }
205
206 pub fn contains_edge(&self, a: N, b: N) -> bool {
208 self.edges.contains(&Self::edge_key(a, b))
209 }
210
211 pub fn reserve_nodes(&mut self, additional: usize) {
214 self.nodes.reserve(additional);
215 }
216
217 pub fn reserve_edges(&mut self, additional: usize) {
220 self.edges.reserve(additional);
221 }
222
223 pub fn nodes(&self) -> impl DoubleEndedIterator<Item = N> + ExactSizeIterator<Item = N> + '_ {
225 self.nodes.keys().copied()
226 }
227
228 pub fn neighbors(&self, a: N) -> impl DoubleEndedIterator<Item = N> + '_ {
230 let iter = match self.nodes.get(&a) {
231 Some(neigh) => neigh.iter(),
232 None => [].iter(),
233 };
234
235 iter.copied()
236 .map(N::Adjacent::into)
237 .filter_map(|(n, dir)| (!DIRECTED || dir == Outgoing).then_some(n))
238 }
239
240 pub fn neighbors_directed(
244 &self,
245 a: N,
246 dir: Direction,
247 ) -> impl DoubleEndedIterator<Item = N> + '_ {
248 let iter = match self.nodes.get(&a) {
249 Some(neigh) => neigh.iter(),
250 None => [].iter(),
251 };
252
253 iter.copied()
254 .map(N::Adjacent::into)
255 .filter_map(move |(n, d)| (!DIRECTED || d == dir || n == a).then_some(n))
256 }
257
258 pub fn edges(&self, a: N) -> impl DoubleEndedIterator<Item = (N, N)> + '_ {
261 self.neighbors(a)
262 .map(move |b| match self.edges.get(&Self::edge_key(a, b)) {
263 None => unreachable!(),
264 Some(_) => (a, b),
265 })
266 }
267
268 pub fn edges_directed(
271 &self,
272 a: N,
273 dir: Direction,
274 ) -> impl DoubleEndedIterator<Item = (N, N)> + '_ {
275 self.neighbors_directed(a, dir).map(move |b| {
276 let (a, b) = if dir == Incoming { (b, a) } else { (a, b) };
277
278 match self.edges.get(&Self::edge_key(a, b)) {
279 None => unreachable!(),
280 Some(_) => (a, b),
281 }
282 })
283 }
284
285 pub fn all_edges(&self) -> impl ExactSizeIterator<Item = (N, N)> + '_ {
287 self.edges.iter().copied().map(N::Edge::into)
288 }
289
290 pub(crate) fn to_index(&self, ix: N) -> usize {
291 self.nodes.get_index_of(&ix).unwrap()
292 }
293
294 pub fn try_convert<T>(self) -> Result<Graph<DIRECTED, T, S>, N::Error>
304 where
305 N: TryInto<T>,
306 T: GraphNodeId,
307 S: Default,
308 {
309 fn try_convert_node<N: GraphNodeId + TryInto<T>, T: GraphNodeId>(
311 (key, adj): (N, Vec<N::Adjacent>),
312 ) -> Result<(T, Vec<T::Adjacent>), N::Error> {
313 let key = key.try_into()?;
314 let adj = adj
315 .into_iter()
316 .map(|node| {
317 let (id, dir) = node.into();
318 Ok(T::Adjacent::from((id.try_into()?, dir)))
319 })
320 .collect::<Result<_, N::Error>>()?;
321 Ok((key, adj))
322 }
323 fn try_convert_edge<N: GraphNodeId + TryInto<T>, T: GraphNodeId>(
325 edge: N::Edge,
326 ) -> Result<T::Edge, N::Error> {
327 let (a, b) = edge.into();
328 Ok(T::Edge::from((a.try_into()?, b.try_into()?)))
329 }
330
331 let nodes = self
332 .nodes
333 .into_iter()
334 .map(try_convert_node::<N, T>)
335 .collect::<Result<_, N::Error>>()?;
336 let edges = self
337 .edges
338 .into_iter()
339 .map(try_convert_edge::<N, T>)
340 .collect::<Result<_, N::Error>>()?;
341 Ok(Graph { nodes, edges })
342 }
343}
344
345impl<const DIRECTED: bool, N, S> Default for Graph<DIRECTED, N, S>
347where
348 N: GraphNodeId,
349 S: BuildHasher + Default,
350{
351 fn default() -> Self {
352 Self::with_capacity(0, 0)
353 }
354}
355
356impl<N: GraphNodeId, S: BuildHasher> DiGraph<N, S> {
357 pub fn toposort(&self, mut scratch: Vec<N>) -> Result<Vec<N>, DiGraphToposortError<N>> {
369 if let Some((node, _)) = self.all_edges().find(|(left, right)| left == right) {
372 return Err(DiGraphToposortError::Loop(node));
373 }
374
375 scratch.clear();
377 scratch.reserve_exact(self.node_count().saturating_sub(scratch.capacity()));
378 let mut top_sorted_nodes = scratch;
379 let mut sccs_with_cycles = Vec::new();
380
381 for scc in self.iter_sccs() {
382 top_sorted_nodes.extend_from_slice(&scc);
386 if scc.len() > 1 {
387 sccs_with_cycles.push(scc);
388 }
389 }
390
391 if sccs_with_cycles.is_empty() {
392 top_sorted_nodes.reverse();
394 Ok(top_sorted_nodes)
395 } else {
396 let mut cycles = Vec::new();
397 for scc in &sccs_with_cycles {
398 cycles.append(&mut self.simple_cycles_in_component(scc));
399 }
400
401 Err(DiGraphToposortError::Cycle(cycles))
402 }
403 }
404
405 pub fn simple_cycles_in_component(&self, scc: &[N]) -> Vec<Vec<N>> {
412 let mut cycles = vec![];
413 let mut sccs = vec![SmallVec::from_slice(scc)];
414
415 while let Some(mut scc) = sccs.pop() {
416 let mut subgraph = DiGraph::<N>::with_capacity(scc.len(), 0);
418 for &node in &scc {
419 subgraph.add_node(node);
420 }
421
422 for &node in &scc {
423 for successor in self.neighbors(node) {
424 if subgraph.contains_node(successor) {
425 subgraph.add_edge(node, successor);
426 }
427 }
428 }
429
430 let mut path = Vec::with_capacity(subgraph.node_count());
432 let mut blocked: HashSet<_> =
434 HashSet::with_capacity_and_hasher(subgraph.node_count(), Default::default());
435 let mut unblock_together: HashMap<N, HashSet<N>> =
438 HashMap::with_capacity_and_hasher(subgraph.node_count(), Default::default());
439 let mut unblock_stack = Vec::with_capacity(subgraph.node_count());
441 let mut maybe_in_more_cycles: HashSet<N> =
443 HashSet::with_capacity_and_hasher(subgraph.node_count(), Default::default());
444 let mut stack = Vec::with_capacity(subgraph.node_count());
446
447 let root = scc.pop().unwrap();
449 path.clear();
451 path.push(root);
452 blocked.insert(root);
454
455 stack.clear();
457 stack.push((root, subgraph.neighbors(root)));
458 while !stack.is_empty() {
459 let &mut (ref node, ref mut successors) = stack.last_mut().unwrap();
460 if let Some(next) = successors.next() {
461 if next == root {
462 maybe_in_more_cycles.extend(path.iter());
464 cycles.push(path.clone());
465 } else if !blocked.contains(&next) {
466 maybe_in_more_cycles.remove(&next);
468 path.push(next);
469 blocked.insert(next);
470 stack.push((next, subgraph.neighbors(next)));
471 continue;
472 } else {
473 }
475 }
476
477 if successors.peekable().peek().is_none() {
478 if maybe_in_more_cycles.contains(node) {
479 unblock_stack.push(*node);
480 while let Some(n) = unblock_stack.pop() {
482 if blocked.remove(&n) {
483 let unblock_predecessors = unblock_together.entry(n).or_default();
484 unblock_stack.extend(unblock_predecessors.iter());
485 unblock_predecessors.clear();
486 }
487 }
488 } else {
489 for successor in subgraph.neighbors(*node) {
491 unblock_together.entry(successor).or_default().insert(*node);
492 }
493 }
494
495 path.pop();
497 stack.pop();
498 }
499 }
500
501 drop(stack);
502
503 subgraph.remove_node(root);
505
506 sccs.extend(subgraph.iter_sccs().filter(|scc| scc.len() > 1));
508 }
509
510 cycles
511 }
512
513 pub(crate) fn iter_sccs(&self) -> impl Iterator<Item = SmallVec<[N; 4]>> + '_ {
515 super::tarjan_scc::new_tarjan_scc(self)
516 }
517}
518
519#[derive(Error, Debug)]
521pub enum DiGraphToposortError<N: GraphNodeId> {
522 #[error("self-loop detected at node `{0:?}`")]
524 Loop(N),
525 #[error("cycles detected: {0:?}")]
527 Cycle(Vec<Vec<N>>),
528}
529
/// Orientation of an edge as seen from one of its endpoints.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum Direction {
    /// The edge starts at this node.
    Outgoing = 0,
    /// The edge ends at this node.
    Incoming = 1,
}

impl Direction {
    /// Returns the other direction: `Outgoing` becomes `Incoming`, and vice versa.
    #[inline]
    pub fn opposite(self) -> Self {
        match self {
            Direction::Incoming => Direction::Outgoing,
            Direction::Outgoing => Direction::Incoming,
        }
    }
}
550
#[cfg(test)]
mod tests {
    use crate::schedule::{NodeId, SystemKey};

    use super::*;
    use alloc::vec;
    use slotmap::SlotMap;

    /// Removing a node swaps the last node into its slot. The remaining order must reflect
    /// that after every removal.
    #[test]
    fn node_order_preservation() {
        use NodeId::System;

        let mut keys = SlotMap::<SystemKey, ()>::with_key();
        let mut graph = DiGraph::<NodeId>::default();

        let sys1 = keys.insert(());
        let sys2 = keys.insert(());
        let sys3 = keys.insert(());
        let sys4 = keys.insert(());

        for key in [sys1, sys2, sys3, sys4] {
            graph.add_node(System(key));
        }

        let node_list = |g: &DiGraph<NodeId>| g.nodes().collect::<Vec<_>>();

        assert_eq!(
            node_list(&graph),
            vec![System(sys1), System(sys2), System(sys3), System(sys4)]
        );

        graph.remove_node(System(sys1));
        assert_eq!(
            node_list(&graph),
            vec![System(sys4), System(sys2), System(sys3)]
        );

        graph.remove_node(System(sys4));
        assert_eq!(node_list(&graph), vec![System(sys3), System(sys2)]);

        graph.remove_node(System(sys2));
        assert_eq!(node_list(&graph), vec![System(sys3)]);

        graph.remove_node(System(sys3));
        assert_eq!(node_list(&graph), Vec::<NodeId>::new());
    }

    /// Components must come out in the expected order, with their members in the expected
    /// order.
    #[test]
    fn strongly_connected_components() {
        use NodeId::System;

        let mut keys = SlotMap::<SystemKey, ()>::with_key();
        let mut graph = DiGraph::<NodeId>::default();

        let sys1 = keys.insert(());
        let sys2 = keys.insert(());
        let sys3 = keys.insert(());
        let sys4 = keys.insert(());
        let sys5 = keys.insert(());
        let sys6 = keys.insert(());

        let edges = [
            // sys1 <-> sys2 <-> sys3 form one component.
            (sys1, sys2),
            (sys2, sys1),
            (sys2, sys3),
            (sys3, sys2),
            // sys4 <-> sys5 form another.
            (sys4, sys5),
            (sys5, sys4),
            // sys6 only points into the first component.
            (sys6, sys2),
        ];
        for (from, to) in edges {
            graph.add_edge(System(from), System(to));
        }

        let sccs: Vec<Vec<NodeId>> = graph.iter_sccs().map(|scc| scc.to_vec()).collect();

        let expected = vec![
            vec![System(sys3), System(sys2), System(sys1)],
            vec![System(sys5), System(sys4)],
            vec![System(sys6)],
        ];
        assert_eq!(sccs, expected);
    }
}