started operations
This commit is contained in:
@@ -54,7 +54,7 @@ select = [
|
||||
"SIM", # flake8-simplify
|
||||
]
|
||||
extend-fixable = ["ALL"]
|
||||
ignore = ["UP046"]
|
||||
ignore = ["UP046", "UP047"]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["daglib"]
|
||||
|
||||
@@ -5,9 +5,9 @@ from collections.abc import Callable, Iterable, Iterator, MutableMapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from .operations import transitive_closure
|
||||
from .set import DAGSetView
|
||||
|
||||
type dictset = defaultdict[T, set[T]]
|
||||
from .typing import dictset
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -115,3 +115,13 @@ class DAG(Generic[T], MutableMapping[T, DAGSetView[T]]):
|
||||
|
||||
self._succ.pop(u, None)
|
||||
self._pred.pop(u, None)
|
||||
|
||||
def subgraph(self, sub: Iterable[T]) -> dictset[T]:
|
||||
closure = transitive_closure(self, sub)
|
||||
subgraph = DAG[T]()
|
||||
for n, deps in self._succ.items():
|
||||
if n in closure:
|
||||
subgraph[n] += deps & closure
|
||||
subgraph.on_add = self.on_add
|
||||
subgraph.on_remove = self.on_remove
|
||||
return subgraph
|
||||
|
||||
43
src/daglib/operations.py
Normal file
43
src/daglib/operations.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
from .util import ensure_set
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .dag import DAG
|
||||
from .typing import dictset
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def all_nodes(graph: DAG) -> set:
|
||||
return set(graph.keys()) | set(graph.reverse.keys())
|
||||
|
||||
|
||||
def transitive_closure(graph: DAG[T], sub: T | Iterable[T], *, include_self: bool = False):
|
||||
sub = ensure_set(sub)
|
||||
stack = list(sub)
|
||||
seen = set()
|
||||
|
||||
while stack:
|
||||
n = stack.pop()
|
||||
if n in seen:
|
||||
continue
|
||||
seen.add(n)
|
||||
stack.extend(graph.get(n, []))
|
||||
|
||||
if include_self:
|
||||
seen.update(sub)
|
||||
|
||||
return seen
|
||||
|
||||
|
||||
def slice_subgraph(graph: DAG, sub: Any) -> dictset:
|
||||
closure = transitive_closure(graph, sub)
|
||||
return {n: d & closure for n in closure if (d := set(graph[n]))}
|
||||
|
||||
|
||||
def topological_sort(graph: DAG, *, reverse: bool = False) -> list:
|
||||
return list(graph.keys())
|
||||
10
src/daglib/typing.py
Normal file
10
src/daglib/typing.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import MutableMapping, MutableSet
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
type dictset = defaultdict[T, set[T]]
|
||||
type DAGType = MutableMapping[T, MutableSet[T]]
|
||||
@@ -301,6 +301,7 @@ class TestDAGComplexScenarios:
|
||||
g["foo"] -= {"bar"}
|
||||
assert len(g["foo"]) == 2
|
||||
assert "bar" not in g["foo"]
|
||||
assert set(g["foo"]) == {"baz", "d"}
|
||||
|
||||
def test_type_hints_with_ints(self) -> None:
|
||||
g = DAG[int]()
|
||||
|
||||
Reference in New Issue
Block a user