This commit is contained in:
John Lancaster
2026-02-21 09:45:53 -06:00
parent b8d7478f2b
commit 88f1ba4426
2 changed files with 12 additions and 11 deletions

View File

@@ -5,14 +5,14 @@ from collections.abc import Callable, Iterable, Iterator, MutableMapping
from dataclasses import dataclass, field
from typing import Generic, TypeVar
from .set import DAGSet
from .set import DAGSetView
type dictset = defaultdict[T, set[T]]
T = TypeVar("T")
@dataclass(repr=False)
class DAG(Generic[T], MutableMapping[T, DAGSet[T]]):
class DAG(Generic[T], MutableMapping[T, DAGSetView[T]]):
"""
DAG adjacency map:
- g[u] -> mutable set-like view of successors of u
@@ -34,8 +34,8 @@ class DAG(Generic[T], MutableMapping[T, DAGSet[T]]):
# --- MutableMapping interface ---
def __getitem__(self, u: T) -> DAGSet[T]:
dagset = DAGSet(self._succ.get(u))
def __getitem__(self, u: T) -> DAGSetView[T]:
dagset = DAGSetView(self._succ.get(u))
def _on_add(v: T) -> None:
self._succ[u] |= {v}

View File

@@ -6,7 +6,7 @@ from typing import Any, TypeVar
T = TypeVar("T")
class DAGSet(MutableSet[T]):
class DAGSetView(MutableSet[T]):
"""
A mutable set-like view onto DAG._succ[u], with reverse-index maintenance.
@@ -17,6 +17,7 @@ class DAGSet(MutableSet[T]):
- ^= (symmetric_difference_update)
Also supports += as an alias for update for ergonomic batching.
"""
_data: set[T]
on_add: Callable[[Any], None] | None = None
on_remove: Callable[[Any], None] | None = None
@@ -57,7 +58,7 @@ class DAGSet(MutableSet[T]):
# --- in-place operator support ---
def __ior__(self, other: Iterable[T]) -> DAGSet[T]:
def __ior__(self, other: Iterable[T]) -> DAGSetView[T]:
# a |= b => union update
match other:
case str():
@@ -68,7 +69,7 @@ class DAGSet(MutableSet[T]):
other = {other}
return super().__ior__(other)
def __iadd__(self, other: Iterable[T]) -> DAGSet[T]:
def __iadd__(self, other: Iterable[T]) -> DAGSetView[T]:
# a += b => update/extend
match other:
case set():
@@ -81,7 +82,7 @@ class DAGSet(MutableSet[T]):
self |= {other}
return self
def __isub__(self, other: Iterable[T]) -> DAGSet[T]:
def __isub__(self, other: Iterable[T]) -> DAGSetView[T]:
# a -= b => remove those in other
match other:
case str():
@@ -92,7 +93,7 @@ class DAGSet(MutableSet[T]):
other = {other}
return super().__isub__(other)
def __iand__(self, other: Iterable[T]) -> DAGSet[T]:
def __iand__(self, other: Iterable[T]) -> DAGSetView[T]:
# a &= b => keep only those also in other
match other:
case str():
@@ -103,7 +104,7 @@ class DAGSet(MutableSet[T]):
other = {other}
return super().__iand__(other)
def __ixor__(self, other: Iterable[T]) -> DAGSet[T]:
def __ixor__(self, other: Iterable[T]) -> DAGSetView[T]:
# a ^= b => symmetric difference update
match other:
case str():
@@ -118,4 +119,4 @@ class DAGSet(MutableSet[T]):
def __repr__(self) -> str:
item_str = ", ".join(repr(x) for x in self)
return f"{{{item_str}}}"
return f"{self.__class__.__name__}{{{item_str}}}"