# File: appdaemon_snippets/appdaemon/context_manager.py
# Timestamp: 2024-10-21 02:41:19 +00:00
# Size: 112 lines, 3.5 KiB (Python)
import asyncio
import functools
import logging
import signal
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack, contextmanager
from dataclasses import dataclass, field
from threading import Event
from typing import Any, Callable
logger = logging.getLogger(__name__)
@dataclass
class AppDaemonRunContext:
_stack: ExitStack = field(default_factory=ExitStack)
stop_event: Event = field(default_factory=Event)
shutdown_grace_period: float = 0.75
loop: asyncio.AbstractEventLoop = field(init=False)
executor: ThreadPoolExecutor = field(init=False)
def __enter__(self):
self.loop = self._stack.enter_context(self.asyncio_context())
logger.debug("Created event loop")
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
for s in signals:
self.loop.add_signal_handler(s, lambda s=s: self.loop.create_task(self.shutdown(s)))
self.executor = self._stack.enter_context(self.thread_context())
logger.debug("Started thread pool")
return self
def __exit__(self, exc_type, exc_value, traceback):
logger.debug(f'Closing context from {self.__class__.__name__}')
self._stack.close()
def get_running_tasks(self, exclude_current: bool = True) -> list[asyncio.Task]:
return [
t for t in asyncio.all_tasks(self.loop)
if exclude_current and t is not asyncio.current_task()
]
async def shutdown(self, signal=signal.SIGTERM):
"""Cleanup tasks tied to the service's shutdown.
https://www.roguelynn.com/words/asyncio-graceful-shutdowns/
"""
logger.info(f"Received exit signal {signal.name}...")
if not self.stop_event.is_set():
logger.debug('Setting stop event')
self.stop_event.set()
tasks = self.get_running_tasks()
graceful = (
asyncio.wait_for(t, timeout=self.shutdown_grace_period)
for t in tasks
)
logger.debug(f'Allowing graceful shutdown from stop event for {self.shutdown_grace_period}s')
await asyncio.gather(*graceful, return_exceptions=True)
for task in tasks:
if task.cancelled():
logger.warning(f'Cancelled {task.get_name()}')
logger.info("Stopping asyncio event loop")
self.loop.stop()
else:
logger.warning('Already started shutdown')
@contextmanager
def asyncio_context(self):
try:
loop = asyncio.get_event_loop()
yield loop
finally:
loop.close()
if loop.is_closed():
logger.debug("Gracefully closed event loop.")
@contextmanager
def thread_context(self):
with ThreadPoolExecutor(max_workers=5) as executor:
yield executor
if executor._shutdown:
logger.debug('Gracefully shut down ThreadPoolExecutor')
async def run_in_executor(
self,
func: Callable,
*args,
timeout: float | None = None,
**kwargs
) -> Any:
"""Run the sync function using the ThreadPoolExecutor and await the result"""
timeout = timeout or 10.0
coro = self.loop.run_in_executor(
self.executor,
functools.partial(func, **kwargs),
*args,
)
try:
return await asyncio.wait_for(coro, timeout)
except asyncio.TimeoutError:
print('Timed out')