first hooks

This commit is contained in:
John Lancaster
2026-02-20 23:28:21 -06:00
parent 1d36496d3e
commit 424ce9acab

View File

@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterable, Iterator, MutableSet from collections.abc import Callable, Iterable, Iterator, MutableSet
from typing import TypeVar from typing import Any, TypeVar
T = TypeVar("T") T = TypeVar("T")
@@ -18,46 +18,66 @@ class DAGSet(MutableSet[T]):
Also supports += as an alias for update for ergonomic batching. Also supports += as an alias for update for ergonomic batching.
""" """
on_add: Callable[[Any], None] | None = None
on_remove: Callable[[Any], None] | None = None
def __init__(self, v: Iterable[T] | None = None) -> None:
match v:
case set():
pass
case Iterable():
v = set(v)
case None:
v = set()
case _:
raise TypeError(f"Expected set or iterable, got {type(v).__name__}")
self._data: set[T] = v
# --- required MutableSet methods --- # --- required MutableSet methods ---
def __contains__(self, x: object) -> bool: def __contains__(self, x: object) -> bool:
return super().__contains__(x) return x in self._data
def __iter__(self) -> Iterator[T]: def __iter__(self) -> Iterator[T]:
return super().__iter__() return iter(self._data)
def __len__(self) -> int: def __len__(self) -> int:
return super().__len__() return len(self._data)
def add(self, v: T) -> None: def add(self, v: T) -> None:
super().add(v) self._data.add(v)
if self.on_add:
self.on_add(v)
def discard(self, v: T) -> None: def discard(self, v: T) -> None:
super().discard(v) self._data.discard(v)
if self.on_remove:
self.on_remove()
# --- in-place operator support --- # --- in-place operator support ---
def __ior__(self, other: Iterable[T]) -> DAGSet[T]: # def __ior__(self, other: Iterable[T]) -> DAGSet[T]:
# a |= b => add everything in other # # a |= b => add everything in other
return self # return self
def __iadd__(self, other: Iterable[T]) -> DAGSet[T]: # def __iadd__(self, other: Iterable[T]) -> DAGSet[T]:
# a += b => update/extend # # a += b => update/extend
return self # return self
def __isub__(self, other: Iterable[T]) -> DAGSet[T]: # def __isub__(self, other: Iterable[T]) -> DAGSet[T]:
# a -= b => remove those in other # # a -= b => remove those in other
return self # return self
def __iand__(self, other: Iterable[T]) -> DAGSet[T]: # def __iand__(self, other: Iterable[T]) -> DAGSet[T]:
# a &= b => keep only those also in other # # a &= b => keep only those also in other
return self # return self
def __ixor__(self, other: Iterable[T]) -> DAGSet[T]: # def __ixor__(self, other: Iterable[T]) -> DAGSet[T]:
# a ^= b => symmetric difference update # # a ^= b => symmetric difference update
return self # return self
# --- nice repr for debugging --- # --- nice repr for debugging ---
def __repr__(self) -> str: def __repr__(self) -> str:
return super().__repr__() item_str = ", ".join(repr(x) for x in self)
return f"{{{item_str}}}"