diff --git a/pyproject.toml b/pyproject.toml index 71509da..9d0576d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ select = [ "SIM", # flake8-simplify ] extend-fixable = ["ALL"] -ignore = ["UP046"] +ignore = ["UP046", "UP047"] [tool.ruff.lint.isort] known-first-party = ["daglib"] diff --git a/src/daglib/dag.py b/src/daglib/dag.py index 760774f..f98728e 100644 --- a/src/daglib/dag.py +++ b/src/daglib/dag.py @@ -5,9 +5,9 @@ from collections.abc import Callable, Iterable, Iterator, MutableMapping from dataclasses import dataclass, field from typing import Generic, TypeVar +from .operations import transitive_closure from .set import DAGSetView - -type dictset = defaultdict[T, set[T]] +from .typing import dictset T = TypeVar("T") @@ -115,3 +115,13 @@ class DAG(Generic[T], MutableMapping[T, DAGSetView[T]]): self._succ.pop(u, None) self._pred.pop(u, None) + + def subgraph(self, sub: Iterable[T]) -> dictset[T]: + closure = transitive_closure(self, sub) + subgraph = DAG[T]() + for n, deps in self._succ.items(): + if n in closure: + subgraph[n] += deps & closure + subgraph.on_add = self.on_add + subgraph.on_remove = self.on_remove + return subgraph diff --git a/src/daglib/operations.py b/src/daglib/operations.py new file mode 100644 index 0000000..0d642e5 --- /dev/null +++ b/src/daglib/operations.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, TypeVar + +from .util import ensure_set + +if TYPE_CHECKING: + from .dag import DAG + from .typing import dictset + +T = TypeVar("T") + + +def all_nodes(graph: DAG) -> set: + return set(graph.keys()) | set(graph.reverse.keys()) + + +def transitive_closure(graph: DAG[T], sub: T | Iterable[T], *, include_self: bool = False): + sub = ensure_set(sub) + stack = list(sub) + seen = set() + + while stack: + n = stack.pop() + if n in seen: + continue + seen.add(n) + stack.extend(graph.get(n, [])) + + if include_self: + seen.update(sub) + + return seen + + +def slice_subgraph(graph: DAG, sub: Any) -> dictset: + closure = transitive_closure(graph, sub) + return {n: d & closure for n in closure if (d := set(graph[n]))} + + +def topological_sort(graph: DAG, *, reverse: bool = False) -> list: + return list(graph.keys()) diff --git a/src/daglib/typing.py b/src/daglib/typing.py new file mode 100644 index 0000000..642cb4a --- /dev/null +++ b/src/daglib/typing.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from collections import defaultdict +from collections.abc import MutableMapping, MutableSet +from typing import TypeVar + +T = TypeVar("T") + +type dictset = defaultdict[T, set[T]] +type DAGType = MutableMapping[T, MutableSet[T]] diff --git a/tests/test_dag.py b/tests/test_dag.py index 5072ffa..b39296a 100644 --- a/tests/test_dag.py +++ b/tests/test_dag.py @@ -301,6 +301,7 @@ class TestDAGComplexScenarios: g["foo"] -= {"bar"} assert len(g["foo"]) == 2 assert "bar" not in g["foo"] + assert set(g["foo"]) == {"baz", "d"} def test_type_hints_with_ints(self) -> None: g = DAG[int]()