diff --git a/src/daglib/dag.py b/src/daglib/dag.py index ccb01bc..4601b72 100644 --- a/src/daglib/dag.py +++ b/src/daglib/dag.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections import defaultdict from collections.abc import Iterable, Iterator, MutableMapping, MutableSet -from typing import Generic, Set, TypeVar +from typing import Generic, TypeVar from .set import DAGSet @@ -20,8 +20,12 @@ class DAG(Generic[T], MutableMapping[T, MutableSet[T]]): """ def __init__(self) -> None: - self._succ: defaultdict[T, Set[T]] = defaultdict(set) - self.reverse: defaultdict[T, Set[T]] = defaultdict(set) + self._succ: defaultdict[T, set[T]] = defaultdict(set) + self._pred: defaultdict[T, set[T]] = defaultdict(set) + + @property + def reverse(self) -> defaultdict[T, set[T]]: + return self._pred # --- MutableMapping interface --- @@ -30,12 +34,12 @@ class DAG(Generic[T], MutableMapping[T, MutableSet[T]]): def on_add(v: T) -> None: self._succ[u].add(v) - self.reverse[v].add(u) + self._pred[v].add(u) print(f"Adding edge {u} -> {v}") def on_remove(v: T) -> None: self._succ[u].remove(v) - self.reverse[v].remove(u) + self._pred[v].remove(u) print(f"Removing edge {u} -> {v}") dagset.on_add = on_add @@ -67,21 +71,21 @@ class DAG(Generic[T], MutableMapping[T, MutableSet[T]]): if u not in self._succ: self._succ[u] = set() - if v not in self.reverse: - self.reverse[v] = set() + if v not in self._pred: + self._pred[v] = set() if v not in self._succ[u]: self._succ[u].add(v) - self.reverse[v].add(u) + self._pred[v].add(u) # Touch keys so nodes appear if accessed later (optional but nice) _ = self._succ[v] - _ = self.reverse[u] + _ = self._pred[u] def remove_edge(self, u: T, v: T, *, missing_ok: bool = True) -> None: if v in self._succ.get(u, ()): self._succ[u].remove(v) - self.reverse[v].discard(u) + self._pred[v].discard(u) elif not missing_ok: raise KeyError((u, v)) @@ -90,8 +94,8 @@ class DAG(Generic[T], MutableMapping[T, MutableSet[T]]): for v in list(self._succ.get(u, ())): self.remove_edge(u, v, missing_ok=True) # remove incoming - for p in list(self.reverse.get(u, ())): + for p in list(self._pred.get(u, ())): self.remove_edge(p, u, missing_ok=True) self._succ.pop(u, None) - self.reverse.pop(u, None) + self._pred.pop(u, None)