ensure_set

This commit is contained in:
John Lancaster
2026-02-21 12:57:21 -06:00
parent f357ee1a5a
commit 8ec77b51fe
3 changed files with 45 additions and 57 deletions

View File

@@ -3,6 +3,8 @@ from __future__ import annotations
from collections.abc import Callable, Iterable, Iterator, MutableSet from collections.abc import Callable, Iterable, Iterator, MutableSet
from typing import Any, TypeVar from typing import Any, TypeVar
from .util import ensure_set
T = TypeVar("T") T = TypeVar("T")
@@ -15,18 +17,7 @@ class DAGSetView(MutableSet[T]):
on_remove: Callable[[Any], None] | None = None on_remove: Callable[[Any], None] | None = None
def __init__(self, v: Iterable[T] | None = None) -> None: def __init__(self, v: Iterable[T] | None = None) -> None:
match v: self._data: set[T] = ensure_set(v)
case set():
pass
case str():
v = {v}
case Iterable():
v = set(v)
case None:
v = set()
case _:
v = {v}
self._data: set[T] = v
# --- required MutableSet methods --- # --- required MutableSet methods ---
@@ -52,62 +43,25 @@ class DAGSetView(MutableSet[T]):
# --- in-place operator support --- # --- in-place operator support ---
def __ior__(self, other: Iterable[T]) -> DAGSetView[T]:
# a |= b => union update
match other:
case str():
other = {other}
case Iterable():
other = set(other)
case _:
other = {other}
return super().__ior__(other)
def __iadd__(self, other: Iterable[T]) -> DAGSetView[T]: def __iadd__(self, other: Iterable[T]) -> DAGSetView[T]:
# a += b => update/extend # a += b => update/extend
match other: self |= ensure_set(other)
case set():
self |= other
case str():
self |= {other}
case Iterable():
self |= set(other)
case _:
self |= {other}
return self return self
def __ior__(self, other: Iterable[T]) -> DAGSetView[T]:
return super().__ior__(ensure_set(other))
def __isub__(self, other: Iterable[T]) -> DAGSetView[T]: def __isub__(self, other: Iterable[T]) -> DAGSetView[T]:
# a -= b => remove those in other # a -= b => remove those in other
match other: return super().__isub__(ensure_set(other))
case str():
other = {other}
case Iterable():
other = set(other)
case _:
other = {other}
return super().__isub__(other)
def __iand__(self, other: Iterable[T]) -> DAGSetView[T]: def __iand__(self, other: Iterable[T]) -> DAGSetView[T]:
# a &= b => keep only those also in other # a &= b => keep only those also in other
match other: return super().__iand__(ensure_set(other))
case str():
other = {other}
case Iterable():
other = set(other)
case _:
other = {other}
return super().__iand__(other)
def __ixor__(self, other: Iterable[T]) -> DAGSetView[T]: def __ixor__(self, other: Iterable[T]) -> DAGSetView[T]:
# a ^= b => symmetric difference update # a ^= b => symmetric difference update
match other: return super().__ixor__(ensure_set(other))
case str():
other = {other}
case Iterable():
other = set(other)
case _:
other = {other}
return super().__ixor__(other)
# --- nice repr for debugging --- # --- nice repr for debugging ---

17
src/daglib/util.py Normal file
View File

@@ -0,0 +1,17 @@
from collections.abc import Iterable
from typing import Any
def ensure_set(x: Any) -> set:
"""Convert x to a set if it isn't already."""
match x:
case set() as s:
return s
case str() as s:
return {s}
case Iterable() as it:
return set(it)
case None:
return set()
case _:
return {x}

View File

@@ -45,11 +45,19 @@ class TestDAGOps:
assert "bar" in g._succ assert "bar" in g._succ
assert "baz" in g._succ["bar"] assert "baz" in g._succ["bar"]
assert set(g["bar"]) == {"baz"} assert set(g["bar"]) == {"baz"}
assert "bar" in g["foo"] assert "bar" in g["foo"]
assert "baz" in g["foo"] assert "baz" in g["foo"]
assert "baz" in g["bar"] assert "baz" in g["bar"]
def test_dagsetview(self) -> None:
g = DAG[str]()
added = []
g.on_add = lambda u, v: added.append(v)
g["foo"] += {"bar", "baz"}
assert set(g["foo"]) == {"bar", "baz"}
assert "bar" in added
assert "baz" in added
def test_self_loop_raises_error(self) -> None: def test_self_loop_raises_error(self) -> None:
g = DAG[str]() g = DAG[str]()
with pytest.raises(ValueError, match="Self-loops are not allowed"): with pytest.raises(ValueError, match="Self-loops are not allowed"):
@@ -80,6 +88,15 @@ class TestDAGOps:
with pytest.raises(KeyError): with pytest.raises(KeyError):
g.remove_edge("foo", "bar", missing_ok=False) g.remove_edge("foo", "bar", missing_ok=False)
def test_dagsetview(self) -> None:
g = DAG[str]()
g["foo"] += {"bar", "baz"}
removed = []
g.on_remove = lambda u, v: removed.append(v)
g["foo"] -= {"baz"}
assert set(g["foo"]) == {"bar"}
assert "baz" in removed
class TestDAGDiscardNode: class TestDAGDiscardNode:
"""Test node removal.""" """Test node removal."""