ensure_set
This commit is contained in:
@@ -3,6 +3,8 @@ from __future__ import annotations
|
||||
from collections.abc import Callable, Iterable, Iterator, MutableSet
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from .util import ensure_set
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@@ -15,18 +17,7 @@ class DAGSetView(MutableSet[T]):
|
||||
on_remove: Callable[[Any], None] | None = None
|
||||
|
||||
def __init__(self, v: Iterable[T] | None = None) -> None:
|
||||
match v:
|
||||
case set():
|
||||
pass
|
||||
case str():
|
||||
v = {v}
|
||||
case Iterable():
|
||||
v = set(v)
|
||||
case None:
|
||||
v = set()
|
||||
case _:
|
||||
v = {v}
|
||||
self._data: set[T] = v
|
||||
self._data: set[T] = ensure_set(v)
|
||||
|
||||
# --- required MutableSet methods ---
|
||||
|
||||
@@ -52,62 +43,25 @@ class DAGSetView(MutableSet[T]):
|
||||
|
||||
# --- in-place operator support ---
|
||||
|
||||
def __ior__(self, other: Iterable[T]) -> DAGSetView[T]:
|
||||
# a |= b => union update
|
||||
match other:
|
||||
case str():
|
||||
other = {other}
|
||||
case Iterable():
|
||||
other = set(other)
|
||||
case _:
|
||||
other = {other}
|
||||
return super().__ior__(other)
|
||||
|
||||
def __iadd__(self, other: Iterable[T]) -> DAGSetView[T]:
|
||||
# a += b => update/extend
|
||||
match other:
|
||||
case set():
|
||||
self |= other
|
||||
case str():
|
||||
self |= {other}
|
||||
case Iterable():
|
||||
self |= set(other)
|
||||
case _:
|
||||
self |= {other}
|
||||
self |= ensure_set(other)
|
||||
return self
|
||||
|
||||
def __ior__(self, other: Iterable[T]) -> DAGSetView[T]:
|
||||
return super().__ior__(ensure_set(other))
|
||||
|
||||
def __isub__(self, other: Iterable[T]) -> DAGSetView[T]:
|
||||
# a -= b => remove those in other
|
||||
match other:
|
||||
case str():
|
||||
other = {other}
|
||||
case Iterable():
|
||||
other = set(other)
|
||||
case _:
|
||||
other = {other}
|
||||
return super().__isub__(other)
|
||||
return super().__isub__(ensure_set(other))
|
||||
|
||||
def __iand__(self, other: Iterable[T]) -> DAGSetView[T]:
|
||||
# a &= b => keep only those also in other
|
||||
match other:
|
||||
case str():
|
||||
other = {other}
|
||||
case Iterable():
|
||||
other = set(other)
|
||||
case _:
|
||||
other = {other}
|
||||
return super().__iand__(other)
|
||||
return super().__iand__(ensure_set(other))
|
||||
|
||||
def __ixor__(self, other: Iterable[T]) -> DAGSetView[T]:
|
||||
# a ^= b => symmetric difference update
|
||||
match other:
|
||||
case str():
|
||||
other = {other}
|
||||
case Iterable():
|
||||
other = set(other)
|
||||
case _:
|
||||
other = {other}
|
||||
return super().__ixor__(other)
|
||||
return super().__ixor__(ensure_set(other))
|
||||
|
||||
# --- nice repr for debugging ---
|
||||
|
||||
|
||||
17
src/daglib/util.py
Normal file
17
src/daglib/util.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
|
||||
def ensure_set(x: Any) -> set:
|
||||
"""Convert x to a set if it isn't already."""
|
||||
match x:
|
||||
case set() as s:
|
||||
return s
|
||||
case str() as s:
|
||||
return {s}
|
||||
case Iterable() as it:
|
||||
return set(it)
|
||||
case None:
|
||||
return set()
|
||||
case _:
|
||||
return {x}
|
||||
@@ -45,11 +45,19 @@ class TestDAGOps:
|
||||
assert "bar" in g._succ
|
||||
assert "baz" in g._succ["bar"]
|
||||
assert set(g["bar"]) == {"baz"}
|
||||
|
||||
assert "bar" in g["foo"]
|
||||
assert "baz" in g["foo"]
|
||||
assert "baz" in g["bar"]
|
||||
|
||||
def test_dagsetview(self) -> None:
|
||||
g = DAG[str]()
|
||||
added = []
|
||||
g.on_add = lambda u, v: added.append(v)
|
||||
g["foo"] += {"bar", "baz"}
|
||||
assert set(g["foo"]) == {"bar", "baz"}
|
||||
assert "bar" in added
|
||||
assert "baz" in added
|
||||
|
||||
def test_self_loop_raises_error(self) -> None:
|
||||
g = DAG[str]()
|
||||
with pytest.raises(ValueError, match="Self-loops are not allowed"):
|
||||
@@ -80,6 +88,15 @@ class TestDAGOps:
|
||||
with pytest.raises(KeyError):
|
||||
g.remove_edge("foo", "bar", missing_ok=False)
|
||||
|
||||
def test_dagsetview(self) -> None:
|
||||
g = DAG[str]()
|
||||
g["foo"] += {"bar", "baz"}
|
||||
removed = []
|
||||
g.on_remove = lambda u, v: removed.append(v)
|
||||
g["foo"] -= {"baz"}
|
||||
assert set(g["foo"]) == {"bar"}
|
||||
assert "baz" in removed
|
||||
|
||||
class TestDAGDiscardNode:
|
||||
"""Test node removal."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user