import asyncio
import logging
import os
import signal
import traceback
from contextlib import ExitStack
from dataclasses import dataclass, field
from logging import Logger, getLogger
from random import random
from threading import Event, RLock
from time import perf_counter
from typing import Coroutine

from appdaemon.models import AppDaemonConfig
from context_manager import AppDaemonRunContext


@dataclass
class ADSubsystem:
    """Base class for AppDaemon subsystems used as context managers.

    Subclasses schedule their background tasks in ``__enter__`` (via
    ``self.create_task``, which delegates to ``AppDaemon.create_task``) and
    poll ``stopping`` in their loops to exit gracefully.
    """

    AD: "AppDaemon"
    stop: Event
    """A threading event the subsystem uses to shut down gracefully"""
    lock: RLock = field(default_factory=RLock)
    """A thread-safe re-entrant lock protecting internal data while it's being modified"""
    logger: Logger = field(init=False)
    tasks: list[asyncio.Task] = field(default_factory=list)

    def __post_init__(self) -> None:
        name = f'_{self.__class__.__name__.lower()}'
        self.logger = getLogger(f'AppDaemon.{name}')
        # Delegate to the parent so every task gets its exception callbacks.
        self.create_task = self.AD.create_task

    def __enter__(self):
        self.logger.debug(f'Starting {self.__class__.__name__}')
        # Return self like the subclasses do, so `with subsystem as s:` works.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Nothing to release here; loops end when the shared stop event is set.
        pass

    @property
    def stopping(self) -> bool:
        """True once the shared stop event has been set."""
        return self.stop.is_set()

    async def sleep(self, delay: float):
        """Sleep for ``delay`` seconds unless the subsystem is already stopping.

        Wrapper for ``asyncio.sleep`` that suppresses and logs a task
        cancellation instead of propagating it.
        """
        # NOTE(review): because CancelledError is swallowed, task.cancel() does
        # not by itself end a loop whose stop event is still unset.
        try:
            if not self.stopping:
                await asyncio.sleep(delay)
            else:
                self.logger.debug('Skipping sleep due to stop event')
        except asyncio.CancelledError:
            self.logger.debug('Cancelled during sleep')


@dataclass
class Utility(ADSubsystem):
    """Subsystem that runs a periodic utility loop until the stop event is set."""

    loop_rate: float = 1.0

    def __enter__(self):
        super().__enter__()
        self.create_task(self.loop(), 'utility loop', critical=True)
        return self

    async def loop(self):
        """Log a debug message every ``loop_rate`` seconds until stopping."""
        while not self.stopping:
            self.logger.debug('Looping...')
            await self.sleep(self.loop_rate)
        task_name = asyncio.current_task().get_name()
        self.logger.debug(f'Gracefully stopped {task_name} task')


@dataclass
class Plugin(ADSubsystem):
    """Subsystem that periodically increments ``state['update_count']``."""

    state: dict[str, int] = field(default_factory=dict)
    update_rate: float = 5.0

    def __post_init__(self) -> None:
        super().__post_init__()
        self.state['update_count'] = 0

    def __enter__(self):
        super().__enter__()
        self.create_task(
            self.periodic_self_udpate(), name='plugin periodic update', critical=True
        )
        return self

    async def periodic_self_udpate(self):
        """Increment ``state['update_count']`` every ``update_rate`` seconds.

        The (misspelled) method name is kept so existing callers and overrides
        continue to work.
        """
        loop_time = perf_counter()
        while not self.stopping:
            with self.lock:
                self.state['update_count'] += 1
                # Snapshot under the lock so the logged value is consistent.
                count = self.state['update_count']
            self.logger.debug(
                'Plugin self update: %s %s',
                count,
                f'{perf_counter() - loop_time:.3f}s',
            )
            loop_time = perf_counter()
            await self.sleep(self.update_rate)
        task_name = asyncio.current_task().get_name()
        self.logger.debug(f'Gracefully stopped {task_name} task')


@dataclass
class AppDaemon:
    """Top-level object that owns and enters/exits all subsystems.

    Subsystems share ``context.stop_event`` and schedule their tasks on
    ``context.loop`` through :meth:`create_task`.
    """

    cfg: AppDaemonConfig
    context: AppDaemonRunContext
    _stack: ExitStack = field(default_factory=ExitStack)
    utility: Utility = field(init=False)
    plugins: dict[str, Plugin] = field(default_factory=dict)

    def __post_init__(self) -> None:
        self.logger = logging.getLogger('AppDaemon')
        self.utility = Utility(self, self.context.stop_event)
        self.plugins['dummy'] = Plugin(self, self.context.stop_event)

    def __enter__(self):
        self.logger.info('Starting AppDaemon')
        try:
            self._stack.enter_context(self.utility)
            for plugin in self.plugins.values():
                self._stack.enter_context(plugin)
        except BaseException:
            # `with` never calls __exit__ when __enter__ raises, so unwind the
            # subsystems that were already entered before propagating.
            self._stack.close()
            raise
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Propagate ExitStack's result so a subsystem can suppress an exception.
        return self._stack.__exit__(exc_type, exc_value, traceback)

    def create_task(self, coro: Coroutine, name: str | None = None, critical: bool = False):
        """Create a task on ``context.loop`` and attach exception callbacks.

        Args:
            coro: Coroutine to schedule.
            name: Optional task name, used in log messages.
            critical: If True, the task raising an exception triggers
                :meth:`shutdown`.

        Returns:
            The created ``asyncio.Task``.
        """
        task = self.context.loop.create_task(coro, name=name)
        task.add_done_callback(self.check_task_exception)
        if critical:
            task.add_done_callback(self.critical_exception)
        return task

    def check_task_exception(self, task: asyncio.Task):
        """Done-callback that logs the full traceback of a task that raised."""
        # Task.exception() *raises* CancelledError on a cancelled task, so
        # cancellation must be checked before asking for the exception.
        if task.cancelled():
            return
        if (exc := task.exception()) is not None:
            # format_exception lines already end with newlines.
            self.logger.error(''.join(traceback.format_exception(exc)))

    def critical_exception(self, task: asyncio.Task):
        """Done-callback for critical tasks: force a shutdown if the task raised."""
        if task.cancelled():
            # Cancellation (e.g. during shutdown) is not treated as a failure.
            return
        if task.exception() is not None:
            self.logger.critical(
                'Critical error in %s, forcing shutdown', task.get_name()
            )
            self.shutdown()

    def shutdown(self):
        """Request process shutdown by sending SIGTERM to this process."""
        self.logger.debug('Shutting down by sending SIGTERM')
        os.kill(os.getpid(), signal.SIGTERM)