From 424ce9acaba1aeb2b975523870bd742d8c40ec32 Mon Sep 17 00:00:00 2001 From: John Lancaster <32917998+jsl12@users.noreply.github.com> Date: Fri, 20 Feb 2026 23:28:21 -0600 Subject: [PATCH] first hooks --- src/daglib/set.py | 66 ++++++++++++++++++++++++++++++----------------- 1 file changed, 43 insertions(+), 23 deletions(-) diff --git a/src/daglib/set.py b/src/daglib/set.py index 79fef90..95a3b62 100644 --- a/src/daglib/set.py +++ b/src/daglib/set.py @@ -1,7 +1,7 @@ from __future__ import annotations -from collections.abc import Iterable, Iterator, MutableSet -from typing import TypeVar +from collections.abc import Callable, Iterable, Iterator, MutableSet +from typing import Any, TypeVar T = TypeVar("T") @@ -18,46 +18,66 @@ class DAGSet(MutableSet[T]): Also supports += as an alias for update for ergonomic batching. """ + on_add: Callable[[Any], None] | None = None + on_remove: Callable[[Any], None] | None = None + + def __init__(self, v: Iterable[T] | None = None) -> None: + match v: + case set(): + pass + case Iterable(): + v = set(v) + case None: + v = set() + case _: + raise TypeError(f"Expected set or iterable, got {type(v).__name__}") + self._data: set[T] = v + # --- required MutableSet methods --- def __contains__(self, x: object) -> bool: - return super().__contains__(x) + return x in self._data def __iter__(self) -> Iterator[T]: - return super().__iter__() + return iter(self._data) def __len__(self) -> int: - return super().__len__() + return len(self._data) def add(self, v: T) -> None: - super().add(v) + self._data.add(v) + if self.on_add: + self.on_add(v) def discard(self, v: T) -> None: - super().discard(v) + self._data.discard(v) + if self.on_remove: + self.on_remove() # --- in-place operator support --- - def __ior__(self, other: Iterable[T]) -> DAGSet[T]: - # a |= b => add everything in other - return self + # def __ior__(self, other: Iterable[T]) -> DAGSet[T]: + # # a |= b => add everything in other + # return self - def __iadd__(self, other: Iterable[T]) -> DAGSet[T]: - # a += b => update/extend - return self + # def __iadd__(self, other: Iterable[T]) -> DAGSet[T]: + # # a += b => update/extend + # return self - def __isub__(self, other: Iterable[T]) -> DAGSet[T]: - # a -= b => remove those in other - return self + # def __isub__(self, other: Iterable[T]) -> DAGSet[T]: + # # a -= b => remove those in other + # return self - def __iand__(self, other: Iterable[T]) -> DAGSet[T]: - # a &= b => keep only those also in other - return self + # def __iand__(self, other: Iterable[T]) -> DAGSet[T]: + # # a &= b => keep only those also in other + # return self - def __ixor__(self, other: Iterable[T]) -> DAGSet[T]: - # a ^= b => symmetric difference update - return self + # def __ixor__(self, other: Iterable[T]) -> DAGSet[T]: + # # a ^= b => symmetric difference update + # return self # --- nice repr for debugging --- def __repr__(self) -> str: - return super().__repr__() + item_str = ", ".join(repr(x) for x in self) + return f"{{{item_str}}}"