diff --git a/src/daglib/dag.py b/src/daglib/dag.py index 76efa9a..760774f 100644 --- a/src/daglib/dag.py +++ b/src/daglib/dag.py @@ -70,10 +70,10 @@ class DAG(Generic[T], MutableMapping[T, DAGSetView[T]]): if add_hook := getattr(self, "on_add", None): add_hook(u, v) - def _on_remove(u: T) -> None: - self.__dagdel__(u) + def _on_remove(v: T) -> None: + self.__dagdel__(u, v) if remove_hook := getattr(self, "on_remove", None): - remove_hook(u, u) + remove_hook(u, v) dagset = DAGSetView(self._succ.get(u)) dagset.on_add = _on_add @@ -97,20 +97,11 @@ class DAG(Generic[T], MutableMapping[T, DAGSetView[T]]): def add_edge(self, u: T, v: T) -> None: if u == v: raise ValueError("Self-loops are not allowed in a DAG") - - if u not in self._succ: - self._succ[u] = set() - if v not in self._pred: - self._pred[v] = set() - - if v not in self._succ[u]: - self._succ[u].add(v) - self._pred[v].add(u) + self[u] += v def remove_edge(self, u: T, v: T, *, missing_ok: bool = True) -> None: if v in self._succ.get(u, ()): - self._succ[u].remove(v) - self._pred[v].discard(u) + self[u] -= v elif not missing_ok: raise KeyError((u, v)) diff --git a/tests/test_dag.py b/tests/test_dag.py index a81f581..3e77ced 100644 --- a/tests/test_dag.py +++ b/tests/test_dag.py @@ -27,23 +27,29 @@ class TestDAGOps: g.add_edge("foo", "bar") assert "foo" in g._succ assert "bar" in g._succ["foo"] + assert "bar" in g._pred assert "foo" in g._pred["bar"] + assert set(g["foo"]) == {"bar"} def test_add_multiple_edges(self) -> None: g = DAG[str]() g.add_edge("foo", "bar") g.add_edge("foo", "baz") g.add_edge("bar", "baz") + + assert "foo" in g._succ + assert "bar" in g._succ["foo"] + assert "baz" in g._succ["foo"] + assert set(g["foo"]) == {"bar", "baz"} + + assert "bar" in g._succ + assert "baz" in g._succ["bar"] + assert set(g["bar"]) == {"baz"} + assert "bar" in g["foo"] assert "baz" in g["foo"] assert "baz" in g["bar"] - def test_add_edge_creates_nodes(self) -> None: - g = DAG[str]() - g.add_edge("foo", "bar") - assert "foo" in g._succ - assert "bar" in g._pred - def test_self_loop_raises_error(self) -> None: g = DAG[str]() with pytest.raises(ValueError, match="Self-loops are not allowed"):