diff --git a/src/daglib/set.py b/src/daglib/set.py index 954c931..3c13180 100644 --- a/src/daglib/set.py +++ b/src/daglib/set.py @@ -50,9 +50,10 @@ class DAGSet(MutableSet[T]): self.on_add(v) def discard(self, v: T) -> None: - self._data.discard(v) - if self.on_remove: - self.on_remove(v) + if v in self._data: + self._data.discard(v) + if self.on_remove: + self.on_remove(v) # --- in-place operator support --- @@ -69,22 +70,38 @@ class DAGSet(MutableSet[T]): self |= {other} return self - # def __isub__(self, other: Iterable[T]) -> DAGSet[T]: - # # a -= b => remove those in other - # match other: - # case set(): - # self -= other - # case _: - # self -= set(other) - # return self + def __isub__(self, other: Iterable[T]) -> DAGSet[T]: + # a -= b => remove those in other + match other: + case str(): + other = {other} + case Iterable(): + other = set(other) + case _: + other = {other} + return super().__isub__(other) - # def __iand__(self, other: Iterable[T]) -> DAGSet[T]: - # # a &= b => keep only those also in other - # return self + def __iand__(self, other: Iterable[T]) -> DAGSet[T]: + # a &= b => keep only those also in other + match other: + case str(): + other = {other} + case Iterable(): + other = set(other) + case _: + other = {other} + return super().__iand__(other) - # def __ixor__(self, other: Iterable[T]) -> DAGSet[T]: - # # a ^= b => symmetric difference update - # return self + def __ixor__(self, other: Iterable[T]) -> DAGSet[T]: + # a ^= b => symmetric difference update + match other: + case str(): + other = {other} + case Iterable(): + other = set(other) + case _: + other = {other} + return super().__ixor__(other) # --- nice repr for debugging ---