From 0abbb9e546c43a9698e4be13501be8d09167354a Mon Sep 17 00:00:00 2001 From: John Lancaster <32917998+jsl12@users.noreply.github.com> Date: Mon, 21 Oct 2024 02:41:19 +0000 Subject: [PATCH] context manager work for startup/shutdown --- appdaemon/context_manager.py | 27 +++--- appdaemon/main.py | 77 +++++++++++++++++ appdaemon/subsystem.py | 163 ++++++++++++++++++----------------- 3 files changed, 171 insertions(+), 96 deletions(-) create mode 100644 appdaemon/main.py diff --git a/appdaemon/context_manager.py b/appdaemon/context_manager.py index aceb11b..46b8dd0 100644 --- a/appdaemon/context_manager.py +++ b/appdaemon/context_manager.py @@ -5,41 +5,38 @@ import signal from concurrent.futures import ThreadPoolExecutor from contextlib import ExitStack, contextmanager from dataclasses import dataclass, field -from threading import Event, Lock +from threading import Event from typing import Any, Callable logger = logging.getLogger(__name__) -def handler(signum, frame): - print('Signal handler called with signal', signum) - raise OSError("Couldn't open device!") + @dataclass class AppDaemonRunContext: _stack: ExitStack = field(default_factory=ExitStack) stop_event: Event = field(default_factory=Event) - shutdown_lock: Lock = field(default_factory=Lock) - shutdown_grace_period: float = 1.0 + shutdown_grace_period: float = 0.75 loop: asyncio.AbstractEventLoop = field(init=False) executor: ThreadPoolExecutor = field(init=False) def __enter__(self): self.loop = self._stack.enter_context(self.asyncio_context()) - logger.debug("Entered asyncio context") + logger.debug("Created event loop") signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT) for s in signals: self.loop.add_signal_handler(s, lambda s=s: self.loop.create_task(self.shutdown(s))) self.executor = self._stack.enter_context(self.thread_context()) - logger.debug("Entered threadpool context") + logger.debug("Started thread pool") return self def __exit__(self, exc_type, exc_value, traceback): - self._stack.__exit__(exc_type, exc_value, traceback) - logger.debug("Exited context") + logger.debug(f'Closing context from {self.__class__.__name__}') + self._stack.close() def get_running_tasks(self, exclude_current: bool = True) -> list[asyncio.Task]: return [ @@ -47,7 +44,7 @@ class AppDaemonRunContext: if exclude_current and t is not asyncio.current_task() ] - async def shutdown(self, signal): + async def shutdown(self, signal=signal.SIGTERM): """Cleanup tasks tied to the service's shutdown. https://www.roguelynn.com/words/asyncio-graceful-shutdowns/ @@ -68,9 +65,9 @@ class AppDaemonRunContext: for task in tasks: if task.cancelled(): - logger.warning(f'Cancelled {task.get_coro().__qualname__}') + logger.warning(f'Cancelled {task.get_name()}') - logger.debug("Stopping event loop in context shutdown") + logger.info("Stopping asyncio event loop") self.loop.stop() else: logger.warning('Already started shutdown') @@ -83,14 +80,14 @@ class AppDaemonRunContext: finally: loop.close() if loop.is_closed(): - logger.debug("Closed the event loop.") + logger.debug("Gracefully closed event loop.") @contextmanager def thread_context(self): with ThreadPoolExecutor(max_workers=5) as executor: yield executor if executor._shutdown: - logger.debug('Shut down the ThreadPoolExecutor') + logger.debug('Gracefully shut down ThreadPoolExecutor') async def run_in_executor( self, diff --git a/appdaemon/main.py b/appdaemon/main.py new file mode 100644 index 0000000..bc78c0c --- /dev/null +++ b/appdaemon/main.py @@ -0,0 +1,77 @@ +from contextlib import ExitStack +from dataclasses import dataclass, field + +from appdaemon.models.ad_config import AppDaemonConfig +from context_manager import AppDaemonRunContext +from subsystem import AppDaemon + + +@dataclass +class ADMain: + """Class to contain the mechanics to run AppDaemon as module + """ + config_file: str + cfg: AppDaemonConfig = field(init=False) + _stack: ExitStack = field(default_factory=ExitStack) + run_context: AppDaemonRunContext = field(init=False) + ad: AppDaemon = field(init=False) + + def __post_init__(self) -> None: + raw_cfg = read_config_file(self.config_file) + self.cfg = AppDaemonConfig( + config_file=self.config_file, + **raw_cfg['appdaemon'] + ) + + def __enter__(self): + """Used to start the asyncio loop, thread pool, and AppDaemon""" + # Use the ExitStack from ADMain instead of creating a new one + self.run_context = AppDaemonRunContext(_stack=self._stack) + self._stack.enter_context(self.run_context) + + # Start AppDaemon by entering it's context + self.ad = AppDaemon(self.cfg, self.run_context) + self._stack.enter_context(self.ad) + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.ad.logger.debug(f'Closing context from {self.__class__.__name__}') + self._stack.close() + self.ad.logger.info('Stopped main()') + + def run(self): + if hasattr(self, 'ad'): + self.ad.logger.info('Running asyncio event loop indefinitely...') + self.run_context.loop.run_forever() + else: + logging.error('Running ADMain without context manager') + + +if __name__ == '__main__': + import logging.config + + from appdaemon.utils import read_config_file + from rich.console import Console + from rich.highlighter import NullHighlighter + + console = Console() + logging.config.dictConfig( + { + 'version': 1, + 'disable_existing_loggers': False, + 'formatters': {'basic': {'style': '{', 'format': '[yellow]{name}[/] {message}'}}, + 'handlers': { + 'rich': { + '()': 'rich.logging.RichHandler', + 'formatter': 'basic', + 'console': console, + 'highlighter': NullHighlighter(), + 'markup': True + } + }, + 'root': {'level': 'INFO', 'handlers': ['rich']}, + } + ) + + with ADMain('/conf/ad-test/conf/appdaemon.yaml') as main: + main.run() diff --git a/appdaemon/subsystem.py b/appdaemon/subsystem.py index a45bea2..6a55e09 100644 --- a/appdaemon/subsystem.py +++ b/appdaemon/subsystem.py @@ -1,9 +1,15 @@ import asyncio import logging +import os +import signal +import traceback +from contextlib import ExitStack from dataclasses import dataclass, field from logging import Logger, getLogger +from random import random from threading import Event, RLock -from typing import Callable, Coroutine +from time import perf_counter +from typing import Coroutine from appdaemon.models import AppDaemonConfig from context_manager import AppDaemonRunContext @@ -13,14 +19,22 @@ from context_manager import AppDaemonRunContext class ADSubsystem: AD: "AppDaemon" stop: Event + """An thread event for the subsystem to use to shutdown gracefully""" lock: RLock = field(default_factory=RLock) + """A threadsafe re-entrant lock to protect any internal data while it's being modified""" logger: Logger = field(init=False) + tasks: list[asyncio.Task] = field(default_factory=list) def __post_init__(self) -> None: name = f'_{self.__class__.__name__.lower()}' self.logger = getLogger(f'AppDaemon.{name}') - if start_func := getattr(self, 'start', False): - self.AD.starts.append(start_func) + self.create_task = self.AD.create_task + + def __enter__(self): + self.logger.debug(f'Starting {self.__class__.__name__}') + + def __exit__(self, exc_type, exc_value, traceback): + self.logger.debug(f'Exiting {self.__class__.__name__}') @property def stopping(self) -> bool: @@ -29,120 +43,107 @@ class ADSubsystem: @dataclass class Utility(ADSubsystem): - loop_rate: float = 0.5 + loop_rate: float = 0.25 - def start(self): - self.AD.create_task(self.loop(), 'Utility loop') + def __enter__(self): + super().__enter__() + self.create_task(self.loop(), 'Utility loop', critical=True) + return self async def loop(self): while not self.stopping: self.logger.debug('Looping...') - await asyncio.sleep(self.loop_rate) - self.logger.debug('Stopped utility loop') + try: + await asyncio.sleep(self.loop_rate) + except asyncio.CancelledError: + self.logger.debug('Cancelled during sleep') + self.logger.debug('Stopped utility loop gracefully') @dataclass class Plugin(ADSubsystem): state: dict[str, int] = field(default_factory=dict) - update_rate: float = 5.0 + update_rate: float = 2.0 def __post_init__(self) -> None: super().__post_init__() self.state['update_count'] = 0 - def start(self): - self.AD.create_task(self.periodic_self_udpate(), - 'Periodic plugin update') + def __enter__(self): + super().__enter__() + self.create_task( + self.periodic_self_udpate(), + name='plugin periodic update', + critical=True + ) + return self async def periodic_self_udpate(self): + loop_time = perf_counter() while not self.stopping: with self.lock: self.state['update_count'] += 1 - self.logger.info(f'Updated self: {self.state["update_count"]}') - await asyncio.sleep(self.update_rate) - self.logger.debug('Stopped plugin updates') + # self.logger.debug('Long plugin update...') + # await asyncio.sleep(random()) + self.logger.debug( + 'Plugin self update: %s %s', + self.state["update_count"], + f'{perf_counter()-loop_time:.3f}s' + ) + loop_time = perf_counter() + + # if self.state['update_count'] == 2: + # raise ValueError('fake error') + try: + await asyncio.sleep(self.update_rate) + except asyncio.CancelledError: + self.logger.debug('Cancelled during sleep') + self.logger.debug('Stopped plugin updates gracefully') @dataclass class AppDaemon: cfg: AppDaemonConfig context: AppDaemonRunContext + _stack: ExitStack = field(default_factory=ExitStack) utility: Utility = field(init=False) plugins: dict[str, Plugin] = field(default_factory=dict) - starts: list[Callable] = field(default_factory=list) def __post_init__(self) -> None: self.logger = logging.getLogger('AppDaemon') self.utility = Utility(self, self.context.stop_event) self.plugins['dummy'] = Plugin(self, self.context.stop_event) - def create_task(self, coro: Coroutine, name: str | None = None): - return self.context.loop.create_task(coro, name=name) + def __enter__(self): + self.logger.info('Starting AppDaemon') + self._stack.enter_context(self.utility) + for plugin in self.plugins.values(): + self._stack.enter_context(plugin) + return self - def start(self): - for start in self.starts: - subsystem = start.__qualname__.split('.')[0] - self.logger.debug(f'Starting {subsystem}') - start() + def __exit__(self, exc_type, exc_value, traceback): + self._stack.__exit__(exc_type, exc_value, traceback) + def create_task(self, coro: Coroutine, name: str | None = None, critical: bool = False): + """Creates an async task and adds exception callbacks""" + task = self.context.loop.create_task(coro, name=name) + task.add_done_callback(self.check_task_exception) + if critical: + task.add_done_callback(self.critical_exception) + return task -@dataclass -class ADMain: - config_file: str - cfg: AppDaemonConfig = field(init=False) + def check_task_exception(self, task: asyncio.Task): + if (exc := task.exception()) and not isinstance(exc, asyncio.CancelledError): + self.logger.error('\n'.join(traceback.format_exception(exc))) - def __post_init__(self) -> None: - raw_cfg = read_config_file(self.config_file) - self.cfg = AppDaemonConfig( - config_file=self.config_file, - **raw_cfg['appdaemon'] - ) + def critical_exception(self, task: asyncio.Task): + if task.exception(): + self.logger.critical( + 'Critical error in %s, forcing shutdown', + task.get_name() + ) + self.shutdown() - def run(self): - with AppDaemonRunContext() as cm: - ad = AppDaemon(self.cfg, cm) - ad.start() - cm.loop.run_forever() - - -if __name__ == '__main__': - import logging.config - - from appdaemon.utils import read_config_file - from rich.console import Console - from rich.highlighter import NullHighlighter - - console = Console() - logging.config.dictConfig( - { - 'version': 1, - 'disable_existing_loggers': False, - 'formatters': {'basic': {'style': '{', 'format': '[yellow]{name}[/] {message}'}}, - 'handlers': { - 'rich': { - '()': 'rich.logging.RichHandler', - 'formatter': 'basic', - 'console': console, - 'highlighter': NullHighlighter(), - 'markup': True - } - }, - 'root': {'level': 'DEBUG', 'handlers': ['rich']}, - } - ) - - main = ADMain('/conf/ad-test/conf/appdaemon.yaml') - main.run() - - # config_file = '/conf/ad-test/conf/appdaemon.yaml' - # raw_cfg = read_config_file(config_file) - - # cfg = AppDaemonConfig( - # config_file=config_file, - # **raw_cfg['appdaemon'] - # ) - - # with AppDaemonRunContext() as cm: - # ad = AppDaemon(cfg, cm) - # # ad.start() - # # cm.loop.run_forever() + def shutdown(self): + self.logger.debug('Shutting down by sending SIGTERM') + os.kill(os.getpid(), signal.SIGTERM)