test updates

This commit is contained in:
John Lancaster
2026-02-21 11:38:06 -06:00
parent d71d40c4d3
commit 81793a7bfe
4 changed files with 12 additions and 15 deletions

View File

@@ -36,20 +36,19 @@ class DAG(Generic[T], MutableMapping[T, DAGSetView[T]]):
# --- MutableMapping interface --- # --- MutableMapping interface ---
def __getitem__(self, u: T) -> DAGSetView[T]: def __getitem__(self, u: T) -> DAGSetView[T]:
dagset = DAGSetView(self._succ.get(u))
def _on_add(v: T) -> None: def _on_add(v: T) -> None:
self._succ[u] |= {v} self._succ[u] |= {v}
self._pred[v] |= {u} self._pred[v] |= {u}
if self.on_add is not None: if add_hook := getattr(self, "on_add", None):
self.on_add(u, v) add_hook(u, v)
def _on_remove(v: T) -> None: def _on_remove(v: T) -> None:
self._succ[u] -= {v} self._succ[u] -= {v}
self._pred[v] -= {u} self._pred[v] -= {u}
if self.on_remove is not None: if remove_hook := getattr(self, "on_remove", None):
self.on_remove(u, v) remove_hook(u, v)
dagset = DAGSetView(self._succ.get(u))
dagset.on_add = _on_add dagset.on_add = _on_add
dagset.on_remove = _on_remove dagset.on_remove = _on_remove
return dagset return dagset
@@ -66,8 +65,8 @@ class DAG(Generic[T], MutableMapping[T, DAGSetView[T]]):
self.__setsingle__(u, v) self.__setsingle__(u, v)
def __setsingle__(self, u: T, v: T) -> None: def __setsingle__(self, u: T, v: T) -> None:
self._succ[u] |= {v} self._succ[u] = v
self._pred[v] |= {u} self._pred[v] = u
def __setmultiple__(self, u: T, vs: Iterable[T]) -> None: def __setmultiple__(self, u: T, vs: Iterable[T]) -> None:
self._succ[u] |= set(vs) self._succ[u] |= set(vs)

View File

@@ -8,6 +8,7 @@ T = TypeVar("T")
class DAGSetView(MutableSet[T]): class DAGSetView(MutableSet[T]):
"""A mutable set-like view onto DAG._succ[u], with reverse-index maintenance.""" """A mutable set-like view onto DAG._succ[u], with reverse-index maintenance."""
_data: set[T] _data: set[T]
on_add: Callable[[Any], None] | None = None on_add: Callable[[Any], None] | None = None
@@ -22,7 +23,7 @@ class DAGSetView(MutableSet[T]):
case None: case None:
v = set() v = set()
case _: case _:
raise TypeError(f"Expected set or iterable, got {type(v).__name__}") v = {v}
self._data: set[T] = v self._data: set[T] = v
# --- required MutableSet methods --- # --- required MutableSet methods ---

View File

@@ -25,8 +25,9 @@ class TestDAGOps:
def test_add_single_edge(self) -> None: def test_add_single_edge(self) -> None:
g = DAG[str]() g = DAG[str]()
g.add_edge("a", "b") g.add_edge("a", "b")
assert "b" in g["a"] assert "a" in g._succ
assert "a" in g.reverse["b"] assert "b" in g._succ["a"]
assert "a" in g._pred["b"]
def test_add_multiple_edges(self) -> None: def test_add_multiple_edges(self) -> None:
g = DAG[str]() g = DAG[str]()

View File

@@ -34,10 +34,6 @@ class TestDAGSetInit:
s = DAGSetView(None) s = DAGSetView(None)
assert len(s) == 0 assert len(s) == 0
def test_from_invalid_type(self) -> None:
with pytest.raises(TypeError):
DAGSetView(42) # type: ignore
class TestDAGSetBasicOps: class TestDAGSetBasicOps:
"""Test basic set operations.""" """Test basic set operations."""