diff --git a/src/daglib/dag.py b/src/daglib/dag.py index 5503844..0273280 100644 --- a/src/daglib/dag.py +++ b/src/daglib/dag.py @@ -38,13 +38,13 @@ class DAG(Generic[T], MutableMapping[T, MutableSet[T]]): dagset = DAGSet(self._succ.get(u)) def on_add(v: T) -> None: - self._succ[u].add(v) - self._pred[v].add(u) + self._succ[u] |= {v} + self._pred[v] |= {u} print(f"Adding edge {u} -> {v}") def on_remove(v: T) -> None: - self._succ[u].remove(v) - self._pred[v].remove(u) + self._succ[u] -= {v} + self._pred[v] -= {u} print(f"Removing edge {u} -> {v}") dagset.on_add = on_add @@ -52,10 +52,21 @@ class DAG(Generic[T], MutableMapping[T, MutableSet[T]]): return dagset def __setitem__(self, u: T, vs: Iterable[T]) -> None: - view = DAGSet(self, u) - view.clear() - view |= vs # uses our in-place operator - + match vs: + case set(): + self._succ[u] |= vs + for v in vs: + self._pred[v] |= {u} + case str(): + self._succ[u] |= {vs} + self._pred[vs] |= {u} + case Iterable(): + self._succ[u] |= set(vs) + for v in vs: + self._pred[v] |= {u} + case _: + raise TypeError(f"Expected set or iterable, got {type(vs).__name__}") + def __delitem__(self, u: T) -> None: self.discard_node(u)