subgraph test

This commit is contained in:
John Lancaster
2026-02-21 15:09:38 -06:00
parent b2b90555c2
commit 88a4064a71
2 changed files with 35 additions and 7 deletions

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator, MutableMapping from collections.abc import Callable, Iterable, Iterator, MutableMapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from graphlib import TopologicalSorter
from typing import Generic, TypeVar from typing import Generic, TypeVar
from .operations import transitive_closure from .operations import transitive_closure
@@ -30,8 +31,11 @@ class DAG(Generic[T], MutableMapping[T, DAGSetView[T]]):
on_remove: Callable[[T, T], None] | None = None on_remove: Callable[[T, T], None] | None = None
@property @property
def reverse(self) -> dictset[T]: def reverse(self) -> DAG[T]:
return self._pred return type(self)(
_succ=self._pred,
_pred=self._succ,
)
# --- DAG internals --- # --- DAG internals ---
@@ -116,12 +120,18 @@ class DAG(Generic[T], MutableMapping[T, DAGSetView[T]]):
self._succ.pop(u, None) self._succ.pop(u, None)
self._pred.pop(u, None) self._pred.pop(u, None)
def subgraph(self, sub: Iterable[T]) -> dictset[T]: def subgraph(self, sub: Iterable[T]) -> DAG[T]:
closure = transitive_closure(self, sub) closure = transitive_closure(self, sub)
subgraph = DAG[T]() subgraph = type(self)(
for n, deps in self._succ.items(): _succ=defaultdict(set, {k: self._succ[k] for k in set(self._succ.keys()) & closure}),
if n in closure: _pred=defaultdict(set, {k: self._pred[k] for k in set(self._pred.keys()) & closure}),
subgraph[n] += deps & closure )
subgraph.on_add = self.on_add subgraph.on_add = self.on_add
subgraph.on_remove = self.on_remove subgraph.on_remove = self.on_remove
return subgraph return subgraph
def topo_sort(self, *, reverse: bool = False) -> list[T]:
order = list(TopologicalSorter(self._pred).static_order())
if reverse:
order.reverse()
return order

View File

@@ -12,6 +12,10 @@ class TestDAGInit:
g = DAG() g = DAG()
assert len(g) == 0 assert len(g) == 0
def test_reverse(self) -> None:
g = DAG()
assert len(g.reverse) == 0
def test_default_callbacks_none(self) -> None: def test_default_callbacks_none(self) -> None:
g = DAG() g = DAG()
assert g.on_add is None assert g.on_add is None
@@ -310,3 +314,17 @@ class TestDAGComplexScenarios:
assert 2 in g[1] assert 2 in g[1]
assert 4 in g[2] assert 4 in g[2]
assert len(g) == 3 assert len(g) == 3
def test_subgraphs(self) -> None:
g = DAG[str]()
g["A"] += "B"
g["B"] += "C"
g["B"] += "D"
g["D"] += "E"
g["E"] += "F"
assert len(g) == 5
sub = g.subgraph("D")
assert len(sub) == 2
assert set(sub["D"]) == {"E"}
assert set(sub["E"]) == {"F"}