diff --git a/src/daglib/dag.py b/src/daglib/dag.py index e3d7b72..f71e8bb 100644 --- a/src/daglib/dag.py +++ b/src/daglib/dag.py @@ -36,20 +36,19 @@ class DAG(Generic[T], MutableMapping[T, DAGSetView[T]]): # --- MutableMapping interface --- def __getitem__(self, u: T) -> DAGSetView[T]: - dagset = DAGSetView(self._succ.get(u)) - def _on_add(v: T) -> None: self._succ[u] |= {v} self._pred[v] |= {u} - if self.on_add is not None: - self.on_add(u, v) + if add_hook := getattr(self, "on_add", None): + add_hook(u, v) def _on_remove(v: T) -> None: self._succ[u] -= {v} self._pred[v] -= {u} - if self.on_remove is not None: - self.on_remove(u, v) + if remove_hook := getattr(self, "on_remove", None): + remove_hook(u, v) + dagset = DAGSetView(self._succ.get(u)) dagset.on_add = _on_add dagset.on_remove = _on_remove return dagset @@ -66,8 +65,8 @@ class DAG(Generic[T], MutableMapping[T, DAGSetView[T]]): self.__setsingle__(u, v) def __setsingle__(self, u: T, v: T) -> None: - self._succ[u] |= {v} - self._pred[v] |= {u} + self._succ[u] = v + self._pred[v] = u def __setmultiple__(self, u: T, vs: Iterable[T]) -> None: self._succ[u] |= set(vs) diff --git a/src/daglib/set.py b/src/daglib/set.py index 17a957e..e7fad15 100644 --- a/src/daglib/set.py +++ b/src/daglib/set.py @@ -8,6 +8,7 @@ T = TypeVar("T") class DAGSetView(MutableSet[T]): """A mutable set-like view onto DAG._succ[u], with reverse-index maintenance.""" + _data: set[T] on_add: Callable[[Any], None] | None = None @@ -22,7 +23,7 @@ class DAGSetView(MutableSet[T]): case None: v = set() case _: - raise TypeError(f"Expected set or iterable, got {type(v).__name__}") + v = {v} self._data: set[T] = v # --- required MutableSet methods --- diff --git a/tests/test_dag.py b/tests/test_dag.py index 606cad6..a89ccbe 100644 --- a/tests/test_dag.py +++ b/tests/test_dag.py @@ -25,8 +25,9 @@ class TestDAGOps: def test_add_single_edge(self) -> None: g = DAG[str]() g.add_edge("a", "b") - assert "b" in g["a"] - assert "a" in g.reverse["b"] + assert "a" in g._succ + assert "b" in g._succ["a"] + assert "a" in g._pred["b"] def test_add_multiple_edges(self) -> None: g = DAG[str]() diff --git a/tests/test_dagset.py b/tests/test_dagset.py index f11c2e4..11e7589 100644 --- a/tests/test_dagset.py +++ b/tests/test_dagset.py @@ -34,10 +34,6 @@ class TestDAGSetInit: s = DAGSetView(None) assert len(s) == 0 - def test_from_invalid_type(self) -> None: - with pytest.raises(TypeError): - DAGSetView(42) # type: ignore - class TestDAGSetBasicOps: """Test basic set operations."""