diff --git a/src/daglib/set.py b/src/daglib/set.py index f7d44d6..79fef90 100644 --- a/src/daglib/set.py +++ b/src/daglib/set.py @@ -1,11 +1,7 @@ from __future__ import annotations -from collections.abc import Iterable, Iterator, MutableSet, Set -from typing import TYPE_CHECKING, TypeVar - -if TYPE_CHECKING: - from .dag import DAG - +from collections.abc import Iterable, Iterator, MutableSet +from typing import TypeVar T = TypeVar("T") @@ -22,88 +18,46 @@ class DAGSet(MutableSet[T]): Also supports += as an alias for update for ergonomic batching. """ - __slots__ = ("_g", "_u") - - def __init__(self, g: DAG[T], u: T) -> None: - self._g = g - self._u = u - - def _raw(self) -> Set[T]: - return self._g._succ[self._u] - # --- required MutableSet methods --- def __contains__(self, x: object) -> bool: - return x in self._raw() - + return super().__contains__(x) + def __iter__(self) -> Iterator[T]: - return iter(self._raw()) + return super().__iter__() def __len__(self) -> int: - return len(self._raw()) + return super().__len__() def add(self, v: T) -> None: - self._g.add_edge(self._u, v) + super().add(v) def discard(self, v: T) -> None: - self._g.remove_edge(self._u, v, missing_ok=True) - - # --- convenience / correctness helpers --- - - def remove(self, v: T) -> None: - self._g.remove_edge(self._u, v, missing_ok=False) - - def clear(self) -> None: - for v in list(self._raw()): - self._g.remove_edge(self._u, v, missing_ok=True) - - def update(self, it: Iterable[T]) -> None: - for v in it: - self._g.add_edge(self._u, v) + super().discard(v) # --- in-place operator support --- def __ior__(self, other: Iterable[T]) -> DAGSet[T]: # a |= b => add everything in other - self.update(other) return self - def __iand__(self, other: Iterable[T]) -> DAGSet[T]: - # a &= b => keep only those also in other - other_set = set(other) - for v in list(self._raw()): - if v not in other_set: - self._g.remove_edge(self._u, v, missing_ok=True) + def __iadd__(self, other: Iterable[T]) -> DAGSet[T]: + # a += b => update/extend return self def __isub__(self, other: Iterable[T]) -> DAGSet[T]: # a -= b => remove those in other - for v in other: - self._g.remove_edge(self._u, v, missing_ok=True) + return self + + def __iand__(self, other: Iterable[T]) -> DAGSet[T]: + # a &= b => keep only those also in other return self def __ixor__(self, other: Iterable[T]) -> DAGSet[T]: # a ^= b => symmetric difference update - other_set = set(other) - raw = self._raw() - to_remove = raw & other_set - to_add = other_set - raw - - for v in to_remove: - self._g.remove_edge(self._u, v, missing_ok=True) - for v in to_add: - self._g.add_edge(self._u, v) - return self - - def __iadd__(self, other: Iterable[T]) -> DAGSet[T]: - """ - Built-in set doesn't define +=, but many folks expect it. - Treat it as 'update' (extend). - """ - self.update(other) return self # --- nice repr for debugging --- def __repr__(self) -> str: - return repr(self._raw()) + return super().__repr__()