subgraph test

This commit is contained in:
John Lancaster
2026-02-21 15:09:38 -06:00
parent b2b90555c2
commit 88a4064a71
2 changed files with 35 additions and 7 deletions

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator, MutableMapping
from dataclasses import dataclass, field
from graphlib import TopologicalSorter
from typing import Generic, TypeVar
from .operations import transitive_closure
@@ -30,8 +31,11 @@ class DAG(Generic[T], MutableMapping[T, DAGSetView[T]]):
on_remove: Callable[[T, T], None] | None = None
@property
def reverse(self) -> dictset[T]:
return self._pred
def reverse(self) -> DAG[T]:
return type(self)(
_succ=self._pred,
_pred=self._succ,
)
# --- DAG internals ---
@@ -116,12 +120,18 @@ class DAG(Generic[T], MutableMapping[T, DAGSetView[T]]):
self._succ.pop(u, None)
self._pred.pop(u, None)
def subgraph(self, sub: Iterable[T]) -> dictset[T]:
def subgraph(self, sub: Iterable[T]) -> DAG[T]:
closure = transitive_closure(self, sub)
subgraph = DAG[T]()
for n, deps in self._succ.items():
if n in closure:
subgraph[n] += deps & closure
subgraph = type(self)(
_succ=defaultdict(set, {k: self._succ[k] for k in set(self._succ.keys()) & closure}),
_pred=defaultdict(set, {k: self._pred[k] for k in set(self._pred.keys()) & closure}),
)
subgraph.on_add = self.on_add
subgraph.on_remove = self.on_remove
return subgraph
def topo_sort(self, *, reverse: bool = False) -> list[T]:
order = list(TopologicalSorter(self._pred).static_order())
if reverse:
order.reverse()
return order

View File

@@ -12,6 +12,10 @@ class TestDAGInit:
g = DAG()
assert len(g) == 0
def test_reverse(self) -> None:
g = DAG()
assert len(g.reverse) == 0
def test_default_callbacks_none(self) -> None:
g = DAG()
assert g.on_add is None
@@ -310,3 +314,17 @@ class TestDAGComplexScenarios:
assert 2 in g[1]
assert 4 in g[2]
assert len(g) == 3
def test_subgraphs(self) -> None:
g = DAG[str]()
g["A"] += "B"
g["B"] += "C"
g["B"] += "D"
g["D"] += "E"
g["E"] += "F"
assert len(g) == 5
sub = g.subgraph("D")
assert len(sub) == 2
assert set(sub["D"]) == {"E"}
assert set(sub["E"]) == {"F"}