context manager work for startup/shutdown

This commit is contained in:
John Lancaster
2024-10-21 02:41:19 +00:00
parent 502c218c35
commit 0abbb9e546
3 changed files with 171 additions and 96 deletions

View File

@@ -5,41 +5,38 @@ import signal
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack, contextmanager from contextlib import ExitStack, contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from threading import Event, Lock from threading import Event
from typing import Any, Callable from typing import Any, Callable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def handler(signum, frame):
print('Signal handler called with signal', signum)
raise OSError("Couldn't open device!")
@dataclass @dataclass
class AppDaemonRunContext: class AppDaemonRunContext:
_stack: ExitStack = field(default_factory=ExitStack) _stack: ExitStack = field(default_factory=ExitStack)
stop_event: Event = field(default_factory=Event) stop_event: Event = field(default_factory=Event)
shutdown_lock: Lock = field(default_factory=Lock) shutdown_grace_period: float = 0.75
shutdown_grace_period: float = 1.0
loop: asyncio.AbstractEventLoop = field(init=False) loop: asyncio.AbstractEventLoop = field(init=False)
executor: ThreadPoolExecutor = field(init=False) executor: ThreadPoolExecutor = field(init=False)
def __enter__(self): def __enter__(self):
self.loop = self._stack.enter_context(self.asyncio_context()) self.loop = self._stack.enter_context(self.asyncio_context())
logger.debug("Entered asyncio context") logger.debug("Created event loop")
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT) signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
for s in signals: for s in signals:
self.loop.add_signal_handler(s, lambda s=s: self.loop.create_task(self.shutdown(s))) self.loop.add_signal_handler(s, lambda s=s: self.loop.create_task(self.shutdown(s)))
self.executor = self._stack.enter_context(self.thread_context()) self.executor = self._stack.enter_context(self.thread_context())
logger.debug("Entered threadpool context") logger.debug("Started thread pool")
return self return self
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
self._stack.__exit__(exc_type, exc_value, traceback) logger.debug(f'Closing context from {self.__class__.__name__}')
logger.debug("Exited context") self._stack.close()
def get_running_tasks(self, exclude_current: bool = True) -> list[asyncio.Task]: def get_running_tasks(self, exclude_current: bool = True) -> list[asyncio.Task]:
return [ return [
@@ -47,7 +44,7 @@ class AppDaemonRunContext:
if exclude_current and t is not asyncio.current_task() if exclude_current and t is not asyncio.current_task()
] ]
async def shutdown(self, signal): async def shutdown(self, signal=signal.SIGTERM):
"""Cleanup tasks tied to the service's shutdown. """Cleanup tasks tied to the service's shutdown.
https://www.roguelynn.com/words/asyncio-graceful-shutdowns/ https://www.roguelynn.com/words/asyncio-graceful-shutdowns/
@@ -68,9 +65,9 @@ class AppDaemonRunContext:
for task in tasks: for task in tasks:
if task.cancelled(): if task.cancelled():
logger.warning(f'Cancelled {task.get_coro().__qualname__}') logger.warning(f'Cancelled {task.get_name()}')
logger.debug("Stopping event loop in context shutdown") logger.info("Stopping asyncio event loop")
self.loop.stop() self.loop.stop()
else: else:
logger.warning('Already started shutdown') logger.warning('Already started shutdown')
@@ -83,14 +80,14 @@ class AppDaemonRunContext:
finally: finally:
loop.close() loop.close()
if loop.is_closed(): if loop.is_closed():
logger.debug("Closed the event loop.") logger.debug("Gracefully closed event loop.")
@contextmanager @contextmanager
def thread_context(self): def thread_context(self):
with ThreadPoolExecutor(max_workers=5) as executor: with ThreadPoolExecutor(max_workers=5) as executor:
yield executor yield executor
if executor._shutdown: if executor._shutdown:
logger.debug('Shut down the ThreadPoolExecutor') logger.debug('Gracefully shut down ThreadPoolExecutor')
async def run_in_executor( async def run_in_executor(
self, self,

77
appdaemon/main.py Normal file
View File

@@ -0,0 +1,77 @@
from contextlib import ExitStack
from dataclasses import dataclass, field
from appdaemon.models.ad_config import AppDaemonConfig
from context_manager import AppDaemonRunContext
from subsystem import AppDaemon
@dataclass
class ADMain:
"""Class to contain the mechanics to run AppDaemon as module
"""
config_file: str
cfg: AppDaemonConfig = field(init=False)
_stack: ExitStack = field(default_factory=ExitStack)
run_context: AppDaemonRunContext = field(init=False)
ad: AppDaemon = field(init=False)
def __post_init__(self) -> None:
raw_cfg = read_config_file(self.config_file)
self.cfg = AppDaemonConfig(
config_file=self.config_file,
**raw_cfg['appdaemon']
)
def __enter__(self):
"""Used to start the asyncio loop, thread pool, and AppDaemon"""
# Use the ExitStack from ADMain instead of creating a new one
self.run_context = AppDaemonRunContext(_stack=self._stack)
self._stack.enter_context(self.run_context)
# Start AppDaemon by entering it's context
self.ad = AppDaemon(self.cfg, self.run_context)
self._stack.enter_context(self.ad)
return self
def __exit__(self, exc_type, exc_value, traceback):
self.ad.logger.debug(f'Closing context from {self.__class__.__name__}')
self._stack.close()
self.ad.logger.info('Stopped main()')
def run(self):
if hasattr(self, 'ad'):
self.ad.logger.info('Running asyncio event loop indefinitely...')
self.run_context.loop.run_forever()
else:
logging.error('Running ADMain without context manager')
if __name__ == '__main__':
import logging.config
from appdaemon.utils import read_config_file
from rich.console import Console
from rich.highlighter import NullHighlighter
console = Console()
logging.config.dictConfig(
{
'version': 1,
'disable_existing_loggers': False,
'formatters': {'basic': {'style': '{', 'format': '[yellow]{name}[/] {message}'}},
'handlers': {
'rich': {
'()': 'rich.logging.RichHandler',
'formatter': 'basic',
'console': console,
'highlighter': NullHighlighter(),
'markup': True
}
},
'root': {'level': 'INFO', 'handlers': ['rich']},
}
)
with ADMain('/conf/ad-test/conf/appdaemon.yaml') as main:
main.run()

View File

@@ -1,9 +1,15 @@
import asyncio import asyncio
import logging import logging
import os
import signal
import traceback
from contextlib import ExitStack
from dataclasses import dataclass, field from dataclasses import dataclass, field
from logging import Logger, getLogger from logging import Logger, getLogger
from random import random
from threading import Event, RLock from threading import Event, RLock
from typing import Callable, Coroutine from time import perf_counter
from typing import Coroutine
from appdaemon.models import AppDaemonConfig from appdaemon.models import AppDaemonConfig
from context_manager import AppDaemonRunContext from context_manager import AppDaemonRunContext
@@ -13,14 +19,22 @@ from context_manager import AppDaemonRunContext
class ADSubsystem: class ADSubsystem:
AD: "AppDaemon" AD: "AppDaemon"
stop: Event stop: Event
"""An thread event for the subsystem to use to shutdown gracefully"""
lock: RLock = field(default_factory=RLock) lock: RLock = field(default_factory=RLock)
"""A threadsafe re-entrant lock to protect any internal data while it's being modified"""
logger: Logger = field(init=False) logger: Logger = field(init=False)
tasks: list[asyncio.Task] = field(default_factory=list)
def __post_init__(self) -> None: def __post_init__(self) -> None:
name = f'_{self.__class__.__name__.lower()}' name = f'_{self.__class__.__name__.lower()}'
self.logger = getLogger(f'AppDaemon.{name}') self.logger = getLogger(f'AppDaemon.{name}')
if start_func := getattr(self, 'start', False): self.create_task = self.AD.create_task
self.AD.starts.append(start_func)
def __enter__(self):
self.logger.debug(f'Starting {self.__class__.__name__}')
def __exit__(self, exc_type, exc_value, traceback):
self.logger.debug(f'Exiting {self.__class__.__name__}')
@property @property
def stopping(self) -> bool: def stopping(self) -> bool:
@@ -29,120 +43,107 @@ class ADSubsystem:
@dataclass @dataclass
class Utility(ADSubsystem): class Utility(ADSubsystem):
loop_rate: float = 0.5 loop_rate: float = 0.25
def start(self): def __enter__(self):
self.AD.create_task(self.loop(), 'Utility loop') super().__enter__()
self.create_task(self.loop(), 'Utility loop', critical=True)
return self
async def loop(self): async def loop(self):
while not self.stopping: while not self.stopping:
self.logger.debug('Looping...') self.logger.debug('Looping...')
await asyncio.sleep(self.loop_rate) try:
self.logger.debug('Stopped utility loop') await asyncio.sleep(self.loop_rate)
except asyncio.CancelledError:
self.logger.debug('Cancelled during sleep')
self.logger.debug('Stopped utility loop gracefully')
@dataclass @dataclass
class Plugin(ADSubsystem): class Plugin(ADSubsystem):
state: dict[str, int] = field(default_factory=dict) state: dict[str, int] = field(default_factory=dict)
update_rate: float = 5.0 update_rate: float = 2.0
def __post_init__(self) -> None: def __post_init__(self) -> None:
super().__post_init__() super().__post_init__()
self.state['update_count'] = 0 self.state['update_count'] = 0
def start(self): def __enter__(self):
self.AD.create_task(self.periodic_self_udpate(), super().__enter__()
'Periodic plugin update') self.create_task(
self.periodic_self_udpate(),
name='plugin periodic update',
critical=True
)
return self
async def periodic_self_udpate(self): async def periodic_self_udpate(self):
loop_time = perf_counter()
while not self.stopping: while not self.stopping:
with self.lock: with self.lock:
self.state['update_count'] += 1 self.state['update_count'] += 1
self.logger.info(f'Updated self: {self.state["update_count"]}') # self.logger.debug('Long plugin update...')
await asyncio.sleep(self.update_rate) # await asyncio.sleep(random())
self.logger.debug('Stopped plugin updates') self.logger.debug(
'Plugin self update: %s %s',
self.state["update_count"],
f'{perf_counter()-loop_time:.3f}s'
)
loop_time = perf_counter()
# if self.state['update_count'] == 2:
# raise ValueError('fake error')
try:
await asyncio.sleep(self.update_rate)
except asyncio.CancelledError:
self.logger.debug('Cancelled during sleep')
self.logger.debug('Stopped plugin updates gracefully')
@dataclass @dataclass
class AppDaemon: class AppDaemon:
cfg: AppDaemonConfig cfg: AppDaemonConfig
context: AppDaemonRunContext context: AppDaemonRunContext
_stack: ExitStack = field(default_factory=ExitStack)
utility: Utility = field(init=False) utility: Utility = field(init=False)
plugins: dict[str, Plugin] = field(default_factory=dict) plugins: dict[str, Plugin] = field(default_factory=dict)
starts: list[Callable] = field(default_factory=list)
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.logger = logging.getLogger('AppDaemon') self.logger = logging.getLogger('AppDaemon')
self.utility = Utility(self, self.context.stop_event) self.utility = Utility(self, self.context.stop_event)
self.plugins['dummy'] = Plugin(self, self.context.stop_event) self.plugins['dummy'] = Plugin(self, self.context.stop_event)
def create_task(self, coro: Coroutine, name: str | None = None): def __enter__(self):
return self.context.loop.create_task(coro, name=name) self.logger.info('Starting AppDaemon')
self._stack.enter_context(self.utility)
for plugin in self.plugins.values():
self._stack.enter_context(plugin)
return self
def start(self): def __exit__(self, exc_type, exc_value, traceback):
for start in self.starts: self._stack.__exit__(exc_type, exc_value, traceback)
subsystem = start.__qualname__.split('.')[0]
self.logger.debug(f'Starting {subsystem}')
start()
def create_task(self, coro: Coroutine, name: str | None = None, critical: bool = False):
"""Creates an async task and adds exception callbacks"""
task = self.context.loop.create_task(coro, name=name)
task.add_done_callback(self.check_task_exception)
if critical:
task.add_done_callback(self.critical_exception)
return task
@dataclass def check_task_exception(self, task: asyncio.Task):
class ADMain: if (exc := task.exception()) and not isinstance(exc, asyncio.CancelledError):
config_file: str self.logger.error('\n'.join(traceback.format_exception(exc)))
cfg: AppDaemonConfig = field(init=False)
def __post_init__(self) -> None: def critical_exception(self, task: asyncio.Task):
raw_cfg = read_config_file(self.config_file) if task.exception():
self.cfg = AppDaemonConfig( self.logger.critical(
config_file=self.config_file, 'Critical error in %s, forcing shutdown',
**raw_cfg['appdaemon'] task.get_name()
) )
self.shutdown()
def run(self): def shutdown(self):
with AppDaemonRunContext() as cm: self.logger.debug('Shutting down by sending SIGTERM')
ad = AppDaemon(self.cfg, cm) os.kill(os.getpid(), signal.SIGTERM)
ad.start()
cm.loop.run_forever()
if __name__ == '__main__':
import logging.config
from appdaemon.utils import read_config_file
from rich.console import Console
from rich.highlighter import NullHighlighter
console = Console()
logging.config.dictConfig(
{
'version': 1,
'disable_existing_loggers': False,
'formatters': {'basic': {'style': '{', 'format': '[yellow]{name}[/] {message}'}},
'handlers': {
'rich': {
'()': 'rich.logging.RichHandler',
'formatter': 'basic',
'console': console,
'highlighter': NullHighlighter(),
'markup': True
}
},
'root': {'level': 'DEBUG', 'handlers': ['rich']},
}
)
main = ADMain('/conf/ad-test/conf/appdaemon.yaml')
main.run()
# config_file = '/conf/ad-test/conf/appdaemon.yaml'
# raw_cfg = read_config_file(config_file)
# cfg = AppDaemonConfig(
# config_file=config_file,
# **raw_cfg['appdaemon']
# )
# with AppDaemonRunContext() as cm:
# ad = AppDaemon(cfg, cm)
# # ad.start()
# # cm.loop.run_forever()