cleared types

This commit is contained in:
John Lancaster
2026-02-20 23:11:24 -06:00
parent 9d1a0e50e6
commit 1d36496d3e

View File

@@ -1,11 +1,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterable, Iterator, MutableSet, Set from collections.abc import Iterable, Iterator, MutableSet
from typing import TYPE_CHECKING, TypeVar from typing import TypeVar
if TYPE_CHECKING:
from .dag import DAG
T = TypeVar("T") T = TypeVar("T")
@@ -22,88 +18,46 @@ class DAGSet(MutableSet[T]):
Also supports += as an alias for update for ergonomic batching. Also supports += as an alias for update for ergonomic batching.
""" """
__slots__ = ("_g", "_u")
def __init__(self, g: DAG[T], u: T) -> None:
self._g = g
self._u = u
def _raw(self) -> Set[T]:
return self._g._succ[self._u]
# --- required MutableSet methods --- # --- required MutableSet methods ---
def __contains__(self, x: object) -> bool: def __contains__(self, x: object) -> bool:
return x in self._raw() return super().__contains__(x)
def __iter__(self) -> Iterator[T]: def __iter__(self) -> Iterator[T]:
return iter(self._raw()) return super().__iter__()
def __len__(self) -> int: def __len__(self) -> int:
return len(self._raw()) return super().__len__()
def add(self, v: T) -> None: def add(self, v: T) -> None:
self._g.add_edge(self._u, v) super().add(v)
def discard(self, v: T) -> None: def discard(self, v: T) -> None:
self._g.remove_edge(self._u, v, missing_ok=True) super().discard(v)
# --- convenience / correctness helpers ---
def remove(self, v: T) -> None:
self._g.remove_edge(self._u, v, missing_ok=False)
def clear(self) -> None:
for v in list(self._raw()):
self._g.remove_edge(self._u, v, missing_ok=True)
def update(self, it: Iterable[T]) -> None:
for v in it:
self._g.add_edge(self._u, v)
# --- in-place operator support --- # --- in-place operator support ---
def __ior__(self, other: Iterable[T]) -> DAGSet[T]: def __ior__(self, other: Iterable[T]) -> DAGSet[T]:
# a |= b => add everything in other # a |= b => add everything in other
self.update(other)
return self return self
def __iand__(self, other: Iterable[T]) -> DAGSet[T]: def __iadd__(self, other: Iterable[T]) -> DAGSet[T]:
# a &= b => keep only those also in other # a += b => update/extend
other_set = set(other)
for v in list(self._raw()):
if v not in other_set:
self._g.remove_edge(self._u, v, missing_ok=True)
return self return self
def __isub__(self, other: Iterable[T]) -> DAGSet[T]: def __isub__(self, other: Iterable[T]) -> DAGSet[T]:
# a -= b => remove those in other # a -= b => remove those in other
for v in other: return self
self._g.remove_edge(self._u, v, missing_ok=True)
def __iand__(self, other: Iterable[T]) -> DAGSet[T]:
# a &= b => keep only those also in other
return self return self
def __ixor__(self, other: Iterable[T]) -> DAGSet[T]: def __ixor__(self, other: Iterable[T]) -> DAGSet[T]:
# a ^= b => symmetric difference update # a ^= b => symmetric difference update
other_set = set(other)
raw = self._raw()
to_remove = raw & other_set
to_add = other_set - raw
for v in to_remove:
self._g.remove_edge(self._u, v, missing_ok=True)
for v in to_add:
self._g.add_edge(self._u, v)
return self
def __iadd__(self, other: Iterable[T]) -> DAGSet[T]:
"""
Built-in set doesn't define +=, but many folks expect it.
Treat it as 'update' (extend).
"""
self.update(other)
return self return self
# --- nice repr for debugging --- # --- nice repr for debugging ---
def __repr__(self) -> str: def __repr__(self) -> str:
return repr(self._raw()) return super().__repr__()