diff --git a/src/daglib/set.py b/src/daglib/set.py index 90450f4..182a427 100644 --- a/src/daglib/set.py +++ b/src/daglib/set.py @@ -3,6 +3,8 @@ from __future__ import annotations from collections.abc import Callable, Iterable, Iterator, MutableSet from typing import Any, TypeVar +from .util import ensure_set + T = TypeVar("T") @@ -15,18 +17,7 @@ class DAGSetView(MutableSet[T]): on_remove: Callable[[Any], None] | None = None def __init__(self, v: Iterable[T] | None = None) -> None: - match v: - case set(): - pass - case str(): - v = {v} - case Iterable(): - v = set(v) - case None: - v = set() - case _: - v = {v} - self._data: set[T] = v + self._data: set[T] = ensure_set(v) # --- required MutableSet methods --- @@ -52,62 +43,25 @@ class DAGSetView(MutableSet[T]): # --- in-place operator support --- - def __ior__(self, other: Iterable[T]) -> DAGSetView[T]: - # a |= b => union update - match other: - case str(): - other = {other} - case Iterable(): - other = set(other) - case _: - other = {other} - return super().__ior__(other) - def __iadd__(self, other: Iterable[T]) -> DAGSetView[T]: # a += b => update/extend - match other: - case set(): - self |= other - case str(): - self |= {other} - case Iterable(): - self |= set(other) - case _: - self |= {other} + self |= ensure_set(other) return self + def __ior__(self, other: Iterable[T]) -> DAGSetView[T]: + return super().__ior__(ensure_set(other)) + def __isub__(self, other: Iterable[T]) -> DAGSetView[T]: # a -= b => remove those in other - match other: - case str(): - other = {other} - case Iterable(): - other = set(other) - case _: - other = {other} - return super().__isub__(other) + return super().__isub__(ensure_set(other)) def __iand__(self, other: Iterable[T]) -> DAGSetView[T]: # a &= b => keep only those also in other - match other: - case str(): - other = {other} - case Iterable(): - other = set(other) - case _: - other = {other} - return super().__iand__(other) + return super().__iand__(ensure_set(other)) def __ixor__(self, other: Iterable[T]) -> DAGSetView[T]: # a ^= b => symmetric difference update - match other: - case str(): - other = {other} - case Iterable(): - other = set(other) - case _: - other = {other} - return super().__ixor__(other) + return super().__ixor__(ensure_set(other)) # --- nice repr for debugging --- diff --git a/src/daglib/util.py b/src/daglib/util.py new file mode 100644 index 0000000..d77888f --- /dev/null +++ b/src/daglib/util.py @@ -0,0 +1,17 @@ +from collections.abc import Iterable +from typing import Any + + +def ensure_set(x: Any) -> set: + """Convert x to a set if it isn't already.""" + match x: + case set() as s: + return s + case str() as s: + return {s} + case Iterable() as it: + return set(it) + case None: + return set() + case _: + return {x} diff --git a/tests/test_dag.py b/tests/test_dag.py index 3e77ced..5072ffa 100644 --- a/tests/test_dag.py +++ b/tests/test_dag.py @@ -45,11 +45,19 @@ class TestDAGOps: assert "bar" in g._succ assert "baz" in g._succ["bar"] assert set(g["bar"]) == {"baz"} - assert "bar" in g["foo"] assert "baz" in g["foo"] assert "baz" in g["bar"] + def test_dagsetview(self) -> None: + g = DAG[str]() + added = [] + g.on_add = lambda u, v: added.append(v) + g["foo"] += {"bar", "baz"} + assert set(g["foo"]) == {"bar", "baz"} + assert "bar" in added + assert "baz" in added + def test_self_loop_raises_error(self) -> None: g = DAG[str]() with pytest.raises(ValueError, match="Self-loops are not allowed"): @@ -80,6 +88,15 @@ class TestDAGOps: with pytest.raises(KeyError): g.remove_edge("foo", "bar", missing_ok=False) + def test_dagsetview(self) -> None: + g = DAG[str]() + g["foo"] += {"bar", "baz"} + removed = [] + g.on_remove = lambda u, v: removed.append(v) + g["foo"] -= {"baz"} + assert set(g["foo"]) == {"bar"} + assert "baz" in removed + class TestDAGDiscardNode: """Test node removal."""