diff --git a/src/daglib/dag.py b/src/daglib/dag.py index f71e8bb..76efa9a 100644 --- a/src/daglib/dag.py +++ b/src/daglib/dag.py @@ -33,46 +33,53 @@ class DAG(Generic[T], MutableMapping[T, DAGSetView[T]]): def reverse(self) -> dictset[T]: return self._pred + # --- DAG internals --- + + def __dagset__(self, u: T, v: T) -> None: + self._succ[u].add(v) + self._pred[v].add(u) + + def __dagsetmulti__(self, u: T, vs: Iterable[T]) -> None: + for v in vs: + self.__dagset__(u, v) + + def __dagdel__(self, u: T, v: T) -> None: + self._succ[u].discard(v) + self._pred[v].discard(u) + + def __dagdelmulti__(self, u: T, vs: Iterable[T]) -> None: + for v in vs: + self.__dagdel__(u, v) + # --- MutableMapping interface --- + def __setitem__(self, u: T, value: Iterable[T]) -> None: + match value: + case set() as vs: + self.__dagsetmulti__(u, vs) + case str() as v: + self.__dagset__(u, v) + case Iterable() as vs: + self.__dagsetmulti__(u, vs) + case _ as v: + self.__dagset__(u, v) + def __getitem__(self, u: T) -> DAGSetView[T]: def _on_add(v: T) -> None: - self._succ[u] |= {v} - self._pred[v] |= {u} + self.__dagset__(u, v) if add_hook := getattr(self, "on_add", None): add_hook(u, v) - def _on_remove(v: T) -> None: - self._succ[u] -= {v} - self._pred[v] -= {u} + def _on_remove(u: T) -> None: + self.__dagdel__(u) if remove_hook := getattr(self, "on_remove", None): - remove_hook(u, v) + remove_hook(u, u) dagset = DAGSetView(self._succ.get(u)) dagset.on_add = _on_add dagset.on_remove = _on_remove return dagset - def __setitem__(self, u: T, value: Iterable[T]) -> None: - match value: - case set() as vs: - self.__setmultiple__(u, vs) - case str() as v: - self.__setsingle__(u, v) - case Iterable() as vs: - self.__setmultiple__(u, vs) - case _ as v: - self.__setsingle__(u, v) - - def __setsingle__(self, u: T, v: T) -> None: - self._succ[u] = v - self._pred[v] = u - - def __setmultiple__(self, u: T, vs: Iterable[T]) -> None: - self._succ[u] |= set(vs) - for v in vs: - self._pred[v] |= {u} - def __delitem__(self, u: T) -> None: self.discard_node(u) diff --git a/src/daglib/set.py b/src/daglib/set.py index e7fad15..90450f4 100644 --- a/src/daglib/set.py +++ b/src/daglib/set.py @@ -18,6 +18,8 @@ class DAGSetView(MutableSet[T]): match v: case set(): pass + case str(): + v = {v} case Iterable(): v = set(v) case None: diff --git a/tests/test_dag.py b/tests/test_dag.py index a89ccbe..a81f581 100644 --- a/tests/test_dag.py +++ b/tests/test_dag.py @@ -24,76 +24,76 @@ class TestDAGOps: def test_add_single_edge(self) -> None: g = DAG[str]() - g.add_edge("a", "b") - assert "a" in g._succ - assert "b" in g._succ["a"] - assert "a" in g._pred["b"] + g.add_edge("foo", "bar") + assert "foo" in g._succ + assert "bar" in g._succ["foo"] + assert "foo" in g._pred["bar"] def test_add_multiple_edges(self) -> None: g = DAG[str]() - g.add_edge("a", "b") - g.add_edge("a", "c") - g.add_edge("b", "c") - assert "b" in g["a"] - assert "c" in g["a"] - assert "c" in g["b"] + g.add_edge("foo", "bar") + g.add_edge("foo", "baz") + g.add_edge("bar", "baz") + assert "bar" in g["foo"] + assert "baz" in g["foo"] + assert "baz" in g["bar"] def test_add_edge_creates_nodes(self) -> None: g = DAG[str]() - g.add_edge("a", "b") - assert "a" in g._succ - assert "b" in g._pred + g.add_edge("foo", "bar") + assert "foo" in g._succ + assert "bar" in g._pred def test_self_loop_raises_error(self) -> None: g = DAG[str]() with pytest.raises(ValueError, match="Self-loops are not allowed"): - g.add_edge("a", "a") + g.add_edge("foo", "foo") def test_add_duplicate_edge_idempotent(self) -> None: g = DAG[str]() - g.add_edge("a", "b") - g.add_edge("a", "b") - assert len(g["a"]) == 1 + g.add_edge("foo", "bar") + g.add_edge("foo", "bar") + assert len(g["foo"]) == 1 class TestDAGRemoveEdge: """Test removing edges.""" def test_remove_existing_edge(self) -> None: g = DAG[str]() - g.add_edge("a", "b") - g.remove_edge("a", "b") - assert "b" not in g["a"] - assert "a" not in g.reverse["b"] + g.add_edge("foo", "bar") + g.remove_edge("foo", "bar") + assert "bar" not in g["foo"] + assert "foo" not in g.reverse["bar"] def test_remove_nonexistent_edge_missing_ok(self) -> None: g = DAG[str]() - g.remove_edge("a", "b", missing_ok=True) # should not raise + g.remove_edge("foo", "bar", missing_ok=True) # should not raise def test_remove_nonexistent_edge_error(self) -> None: g = DAG[str]() with pytest.raises(KeyError): - g.remove_edge("a", "b", missing_ok=False) + g.remove_edge("foo", "bar", missing_ok=False) class TestDAGDiscardNode: """Test node removal.""" def test_discard_node_removes_outgoing_edges(self) -> None: g = DAG[str]() - g.add_edge("a", "b") - g.add_edge("a", "c") - g.discard_node("a") - assert "a" not in g._succ - assert "a" not in g.reverse["b"] - assert "a" not in g.reverse["c"] + g.add_edge("foo", "bar") + g.add_edge("foo", "baz") + g.discard_node("foo") + assert "foo" not in g._succ + assert "foo" not in g.reverse["bar"] + assert "foo" not in g.reverse["baz"] def test_discard_node_removes_incoming_edges(self) -> None: g = DAG[str]() - g.add_edge("a", "c") - g.add_edge("b", "c") - g.discard_node("c") - assert "c" not in g._pred - assert "c" not in g["a"] - assert "c" not in g["b"] + g.add_edge("foo", "baz") + g.add_edge("bar", "baz") + g.discard_node("baz") + assert "baz" not in g._pred + assert "baz" not in g["foo"] + assert "baz" not in g["bar"] def test_discard_nonexistent_node(self) -> None: g = DAG[str]() @@ -106,20 +106,22 @@ class TestDAGBasicOps: def test_with_set(self) -> None: g = DAG[str]() - g["a"] = {"b", "c"} - assert "b" in g["a"] - assert "c" in g["a"] + g["foo"] = "qux" + g["foo"] = {"bar", "baz"} + assert set(g["foo"]) == {"bar", "baz", "qux"} + assert "bar" in g["foo"] + assert "baz" in g["foo"] def test_with_list(self) -> None: g = DAG[str]() - g["a"] = ["b", "c"] - assert "b" in g["a"] - assert "c" in g["a"] + g["foo"] = ["bar", "baz"] + assert "bar" in g["foo"] + assert "baz" in g["foo"] def test_with_string(self) -> None: g = DAG[str]() - g["a"] = "b" - assert "b" in g["a"] + g["foo"] = "bar" + assert "bar" in g["foo"] def test_with_single_item(self) -> None: g = DAG[int]() @@ -128,18 +130,18 @@ class TestDAGBasicOps: def test_updates_reverse(self) -> None: g = DAG[str]() - g["a"] = {"b", "c"} - assert "a" in g.reverse["b"] - assert "a" in g.reverse["c"] + g["foo"] = {"bar", "baz"} + assert "foo" in g.reverse["bar"] + assert "foo" in g.reverse["baz"] class TestGetItem: """Test dictionary-style access.""" def test_getitem_returns_dagset(self) -> None: g = DAG[str]() - g.add_edge("a", "b") - dagset = g["a"] - assert "b" in dagset + g.add_edge("foo", "bar") + dagset = g["foo"] + assert "bar" in dagset def test_getitem_empty_node(self) -> None: g = DAG[str]() @@ -148,25 +150,25 @@ class TestDAGBasicOps: def test_getitem_mutation_updates_graph(self) -> None: g = DAG[str]() - g["a"].add("b") - assert "b" in g["a"] - assert "a" in g.reverse["b"] + g["foo"].add("bar") + assert "bar" in g["foo"] + assert "foo" in g.reverse["bar"] def test_getitem_mutation_triggers_callbacks(self) -> None: added: list[tuple[str, str]] = [] g = DAG[str]() g.on_add = lambda u, v: added.append((u, v)) - g["a"].add("b") - assert ("a", "b") in added + g["foo"].add("bar") + assert ("foo", "bar") in added class TestDAGDelItem: """Test del operation.""" def test_removes_node(self) -> None: g = DAG[str]() - g.add_edge("a", "b") - del g["a"] - assert "a" not in g._succ + g.add_edge("foo", "bar") + del g["foo"] + assert "foo" not in g._succ class TestDAGIterOps: @@ -179,9 +181,9 @@ class TestDAGIterOps: def test_returns_nodes(self) -> None: g = DAG[str]() - g.add_edge("a", "b") - g.add_edge("b", "c") - assert {"a", "b"} == set(g) + g.add_edge("foo", "bar") + g.add_edge("bar", "baz") + assert {"foo", "bar"} == set(g) class TestLen: """Test length (edge count).""" @@ -192,16 +194,16 @@ class TestDAGIterOps: def test_len_counts_edges(self) -> None: g = DAG[str]() - g.add_edge("a", "b") - g.add_edge("a", "c") - g.add_edge("b", "c") + g.add_edge("foo", "bar") + g.add_edge("foo", "baz") + g.add_edge("bar", "baz") assert len(g) == 3 def test_len_after_removal(self) -> None: g = DAG[str]() - g.add_edge("a", "b") - g.add_edge("a", "c") - g.remove_edge("a", "b") + g.add_edge("foo", "bar") + g.add_edge("foo", "baz") + g.remove_edge("foo", "bar") assert len(g) == 1 @@ -210,17 +212,17 @@ class TestDAGReverse: def test_reverse_property(self) -> None: g = DAG[str]() - g.add_edge("a", "b") - assert "a" in g.reverse["b"] - assert len(g.reverse["a"]) == 0 + g.add_edge("foo", "bar") + assert "foo" in g.reverse["bar"] + assert len(g.reverse["foo"]) == 0 def test_reverse_multiple_predecessors(self) -> None: g = DAG[str]() - g.add_edge("a", "c") - g.add_edge("b", "c") - preds = g.reverse["c"] - assert "a" in preds - assert "b" in preds + g.add_edge("foo", "baz") + g.add_edge("bar", "baz") + preds = g.reverse["baz"] + assert "foo" in preds + assert "bar" in preds class TestDAGCallbacks: @@ -230,23 +232,23 @@ class TestDAGCallbacks: added: list[tuple[str, str]] = [] g = DAG[str]() g.on_add = lambda u, v: added.append((u, v)) - g["a"] = {"b", "c"} + g["foo"] = {"bar", "baz"} # Note: setitem doesn't trigger callbacks in current implementation def test_on_add_callback_via_getitem_mutation(self) -> None: added: list[tuple[str, str]] = [] g = DAG[str]() g.on_add = lambda u, v: added.append((u, v)) - g["a"].add("b") - assert ("a", "b") in added + g["foo"].add("bar") + assert ("foo", "bar") in added def test_on_remove_callback(self) -> None: removed: list[tuple[str, str]] = [] g = DAG[str]() g.on_remove = lambda u, v: removed.append((u, v)) - g["a"].add("b") - g["a"].discard("b") - assert ("a", "b") in removed + g["foo"].add("bar") + g["foo"].discard("bar") + assert ("foo", "bar") in removed class TestDAGComplexScenarios: @@ -254,28 +256,28 @@ class TestDAGComplexScenarios: def test_chain_graph(self) -> None: g = DAG[str]() - g["a"] = {"b"} - g["b"] = {"c"} - g["c"] = {"d"} - assert "b" in g["a"] - assert "c" in g["b"] - assert "d" in g["c"] + g["foo"] = {"bar"} + g["bar"] = {"baz"} + g["baz"] = {"d"} + assert "bar" in g["foo"] + assert "baz" in g["bar"] + assert "d" in g["baz"] assert len(g) == 3 def test_diamond_graph(self) -> None: g = DAG[str]() - g["a"] = {"b", "c"} - g["b"] = {"d"} - g["c"] = {"d"} + g["foo"] = {"bar", "baz"} + g["bar"] = {"d"} + g["baz"] = {"d"} assert len(g.reverse["d"]) == 2 def test_batch_operations(self) -> None: g = DAG[str]() - g["a"] += {"b", "c", "d"} - assert len(g["a"]) == 3 - g["a"] -= {"b"} - assert len(g["a"]) == 2 - assert "b" not in g["a"] + g["foo"] += {"bar", "baz", "d"} + assert len(g["foo"]) == 3 + g["foo"] -= {"bar"} + assert len(g["foo"]) == 2 + assert "bar" not in g["foo"] def test_type_hints_with_ints(self) -> None: g = DAG[int]() diff --git a/tests/test_dagset.py b/tests/test_dagset.py index 11e7589..ca5f22c 100644 --- a/tests/test_dagset.py +++ b/tests/test_dagset.py @@ -20,6 +20,11 @@ class TestDAGSetInit: assert len(s) == 3 assert set(s) == {"foo", 2, 3} + def test_from_str(self) -> None: + s = DAGSetView("foo") + assert len(s) == 1 + assert set(s) == {"foo"} + def test_from_list(self) -> None: s = DAGSetView(["foo", 2, 3, 2]) assert len(s) == 3