diff --git a/src/daglib/dag.py b/src/daglib/dag.py index f98728e..62583a7 100644 --- a/src/daglib/dag.py +++ b/src/daglib/dag.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections import defaultdict from collections.abc import Callable, Iterable, Iterator, MutableMapping from dataclasses import dataclass, field +from graphlib import TopologicalSorter from typing import Generic, TypeVar from .operations import transitive_closure @@ -30,8 +31,11 @@ class DAG(Generic[T], MutableMapping[T, DAGSetView[T]]): on_remove: Callable[[T, T], None] | None = None @property - def reverse(self) -> dictset[T]: - return self._pred + def reverse(self) -> DAG[T]: + return type(self)( + _succ=self._pred, + _pred=self._succ, + ) # --- DAG internals --- @@ -116,12 +120,18 @@ class DAG(Generic[T], MutableMapping[T, DAGSetView[T]]): self._succ.pop(u, None) self._pred.pop(u, None) - def subgraph(self, sub: Iterable[T]) -> dictset[T]: + def subgraph(self, sub: Iterable[T]) -> DAG[T]: closure = transitive_closure(self, sub) - subgraph = DAG[T]() - for n, deps in self._succ.items(): - if n in closure: - subgraph[n] += deps & closure + subgraph = type(self)( + _succ=defaultdict(set, {k: self._succ[k] for k in set(self._succ.keys()) & closure}), + _pred=defaultdict(set, {k: self._pred[k] for k in set(self._pred.keys()) & closure}), + ) subgraph.on_add = self.on_add subgraph.on_remove = self.on_remove return subgraph + + def topo_sort(self, *, reverse: bool = False) -> list[T]: + order = list(TopologicalSorter(self._pred).static_order()) + if reverse: + order.reverse() + return order diff --git a/tests/test_dag.py b/tests/test_dag.py index b39296a..00c9c02 100644 --- a/tests/test_dag.py +++ b/tests/test_dag.py @@ -12,6 +12,10 @@ class TestDAGInit: g = DAG() assert len(g) == 0 + def test_reverse(self) -> None: + g = DAG() + assert len(g.reverse) == 0 + def test_default_callbacks_none(self) -> None: g = DAG() assert g.on_add is None @@ -310,3 +314,17 @@ class TestDAGComplexScenarios: assert 2 in g[1] assert 4 in g[2] assert len(g) == 3 + + def test_subgraphs(self) -> None: + g = DAG[str]() + g["A"] += "B" + g["B"] += "C" + g["B"] += "D" + g["D"] += "E" + g["E"] += "F" + assert len(g) == 5 + + sub = g.subgraph("D") + assert len(sub) == 2 + assert set(sub["D"]) == {"E"} + assert set(sub["E"]) == {"F"}