ensure_set

This commit is contained in:
John Lancaster
2026-02-21 12:57:21 -06:00
parent f357ee1a5a
commit 8ec77b51fe
3 changed files with 45 additions and 57 deletions

View File

@@ -3,6 +3,8 @@ from __future__ import annotations
from collections.abc import Callable, Iterable, Iterator, MutableSet
from typing import Any, TypeVar
from .util import ensure_set
T = TypeVar("T")
@@ -15,18 +17,7 @@ class DAGSetView(MutableSet[T]):
on_remove: Callable[[Any], None] | None = None
def __init__(self, v: Iterable[T] | None = None) -> None:
match v:
case set():
pass
case str():
v = {v}
case Iterable():
v = set(v)
case None:
v = set()
case _:
v = {v}
self._data: set[T] = v
self._data: set[T] = ensure_set(v)
# --- required MutableSet methods ---
@@ -52,62 +43,25 @@ class DAGSetView(MutableSet[T]):
# --- in-place operator support ---
def __ior__(self, other: Iterable[T]) -> DAGSetView[T]:
# a |= b => union update
match other:
case str():
other = {other}
case Iterable():
other = set(other)
case _:
other = {other}
return super().__ior__(other)
def __iadd__(self, other: Iterable[T]) -> DAGSetView[T]:
# a += b => update/extend
match other:
case set():
self |= other
case str():
self |= {other}
case Iterable():
self |= set(other)
case _:
self |= {other}
self |= ensure_set(other)
return self
def __ior__(self, other: Iterable[T]) -> DAGSetView[T]:
return super().__ior__(ensure_set(other))
def __isub__(self, other: Iterable[T]) -> DAGSetView[T]:
# a -= b => remove those in other
match other:
case str():
other = {other}
case Iterable():
other = set(other)
case _:
other = {other}
return super().__isub__(other)
return super().__isub__(ensure_set(other))
def __iand__(self, other: Iterable[T]) -> DAGSetView[T]:
# a &= b => keep only those also in other
match other:
case str():
other = {other}
case Iterable():
other = set(other)
case _:
other = {other}
return super().__iand__(other)
return super().__iand__(ensure_set(other))
def __ixor__(self, other: Iterable[T]) -> DAGSetView[T]:
# a ^= b => symmetric difference update
match other:
case str():
other = {other}
case Iterable():
other = set(other)
case _:
other = {other}
return super().__ixor__(other)
return super().__ixor__(ensure_set(other))
# --- nice repr for debugging ---

17
src/daglib/util.py Normal file
View File

@@ -0,0 +1,17 @@
from collections.abc import Iterable
from typing import Any
def ensure_set(x: Any) -> set:
"""Convert x to a set if it isn't already."""
match x:
case set() as s:
return s
case str() as s:
return {s}
case Iterable() as it:
return set(it)
case None:
return set()
case _:
return {x}

View File

@@ -45,11 +45,19 @@ class TestDAGOps:
assert "bar" in g._succ
assert "baz" in g._succ["bar"]
assert set(g["bar"]) == {"baz"}
assert "bar" in g["foo"]
assert "baz" in g["foo"]
assert "baz" in g["bar"]
def test_dagsetview(self) -> None:
g = DAG[str]()
added = []
g.on_add = lambda u, v: added.append(v)
g["foo"] += {"bar", "baz"}
assert set(g["foo"]) == {"bar", "baz"}
assert "bar" in added
assert "baz" in added
def test_self_loop_raises_error(self) -> None:
g = DAG[str]()
with pytest.raises(ValueError, match="Self-loops are not allowed"):
@@ -80,6 +88,15 @@ class TestDAGOps:
with pytest.raises(KeyError):
g.remove_edge("foo", "bar", missing_ok=False)
def test_dagsetview(self) -> None:
g = DAG[str]()
g["foo"] += {"bar", "baz"}
removed = []
g.on_remove = lambda u, v: removed.append(v)
g["foo"] -= {"baz"}
assert set(g["foo"]) == {"bar"}
assert "baz" in removed
class TestDAGDiscardNode:
"""Test node removal."""