diff --git a/src/daglib/dag.py b/src/daglib/dag.py index 0227808..e3d7b72 100644 --- a/src/daglib/dag.py +++ b/src/daglib/dag.py @@ -11,6 +11,7 @@ type dictset = defaultdict[T, set[T]] T = TypeVar("T") + @dataclass(repr=False) class DAG(Generic[T], MutableMapping[T, DAGSetView[T]]): """ @@ -53,22 +54,25 @@ class DAG(Generic[T], MutableMapping[T, DAGSetView[T]]): dagset.on_remove = _on_remove return dagset - def __setitem__(self, u: T, vs: Iterable[T]) -> None: - match vs: - case set(): - self._succ[u] |= vs - for v in vs: - self._pred[v] |= {u} - case str(): - self._succ[u] |= {vs} - self._pred[vs] |= {u} - case Iterable(): - self._succ[u] |= set(vs) - for v in vs: - self._pred[v] |= {u} - case _: - self._succ[u] |= {vs} - self._pred[vs] |= {u} + def __setitem__(self, u: T, value: Iterable[T]) -> None: + match value: + case set() as vs: + self.__setmultiple__(u, vs) + case str() as v: + self.__setsingle__(u, v) + case Iterable() as vs: + self.__setmultiple__(u, vs) + case _ as v: + self.__setsingle__(u, v) + + def __setsingle__(self, u: T, v: T) -> None: + self._succ[u] |= {v} + self._pred[v] |= {u} + + def __setmultiple__(self, u: T, vs: Iterable[T]) -> None: + self._succ[u] |= set(vs) + for v in vs: + self._pred[v] |= {u} def __delitem__(self, u: T) -> None: self.discard_node(u)