115 lines
3.7 KiB
Python
115 lines
3.7 KiB
Python
import asyncio
|
|
import functools
|
|
import logging
|
|
import signal
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from contextlib import ExitStack, contextmanager
|
|
from dataclasses import dataclass, field
|
|
from threading import Event, Lock
|
|
from typing import Any, Callable
|
|
|
|
# Module-level logger, named after this module, used by the run context below.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
def handler(signum, frame):
    """Print the received signal number to stdout, then raise ``OSError``.

    Takes the ``(signum, frame)`` arguments of a ``signal.signal`` handler;
    ``frame`` is ignored.

    NOTE(review): nothing in this file installs this handler, and the error
    message appears to be copied from the ``signal`` module docs example —
    confirm it is still needed.

    Raises:
        OSError: always, with the message "Couldn't open device!".
    """
    failure = OSError("Couldn't open device!")
    print(*('Signal handler called with signal', signum))
    raise failure
|
|
|
|
@dataclass
|
|
class AppDaemonRunContext:
|
|
_stack: ExitStack = field(default_factory=ExitStack)
|
|
stop_event: Event = field(default_factory=Event)
|
|
shutdown_lock: Lock = field(default_factory=Lock)
|
|
shutdown_grace_period: float = 1.0
|
|
|
|
loop: asyncio.AbstractEventLoop = field(init=False)
|
|
executor: ThreadPoolExecutor = field(init=False)
|
|
|
|
def __enter__(self):
|
|
self.loop = self._stack.enter_context(self.asyncio_context())
|
|
logger.debug("Entered asyncio context")
|
|
|
|
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
|
|
for s in signals:
|
|
self.loop.add_signal_handler(s, lambda s=s: self.loop.create_task(self.shutdown(s)))
|
|
|
|
self.executor = self._stack.enter_context(self.thread_context())
|
|
logger.debug("Entered threadpool context")
|
|
return self
|
|
|
|
def __exit__(self, exc_type, exc_value, traceback):
|
|
self._stack.__exit__(exc_type, exc_value, traceback)
|
|
logger.debug("Exited context")
|
|
|
|
def get_running_tasks(self, exclude_current: bool = True) -> list[asyncio.Task]:
|
|
return [
|
|
t for t in asyncio.all_tasks(self.loop)
|
|
if exclude_current and t is not asyncio.current_task()
|
|
]
|
|
|
|
async def shutdown(self, signal):
|
|
"""Cleanup tasks tied to the service's shutdown.
|
|
|
|
https://www.roguelynn.com/words/asyncio-graceful-shutdowns/
|
|
"""
|
|
logger.info(f"Received exit signal {signal.name}...")
|
|
if not self.stop_event.is_set():
|
|
logger.debug('Setting stop event')
|
|
self.stop_event.set()
|
|
|
|
tasks = self.get_running_tasks()
|
|
|
|
graceful = (
|
|
asyncio.wait_for(t, timeout=self.shutdown_grace_period)
|
|
for t in tasks
|
|
)
|
|
logger.debug(f'Allowing graceful shutdown from stop event for {self.shutdown_grace_period}s')
|
|
await asyncio.gather(*graceful, return_exceptions=True)
|
|
|
|
for task in tasks:
|
|
if task.cancelled():
|
|
logger.warning(f'Cancelled {task.get_coro().__qualname__}')
|
|
|
|
logger.debug("Stopping event loop in context shutdown")
|
|
self.loop.stop()
|
|
else:
|
|
logger.warning('Already started shutdown')
|
|
|
|
@contextmanager
|
|
def asyncio_context(self):
|
|
try:
|
|
loop = asyncio.get_event_loop()
|
|
yield loop
|
|
finally:
|
|
loop.close()
|
|
if loop.is_closed():
|
|
logger.debug("Closed the event loop.")
|
|
|
|
@contextmanager
|
|
def thread_context(self):
|
|
with ThreadPoolExecutor(max_workers=5) as executor:
|
|
yield executor
|
|
if executor._shutdown:
|
|
logger.debug('Shut down the ThreadPoolExecutor')
|
|
|
|
async def run_in_executor(
|
|
self,
|
|
func: Callable,
|
|
*args,
|
|
timeout: float | None = None,
|
|
**kwargs
|
|
) -> Any:
|
|
"""Run the sync function using the ThreadPoolExecutor and await the result"""
|
|
timeout = timeout or 10.0
|
|
|
|
coro = self.loop.run_in_executor(
|
|
self.executor,
|
|
functools.partial(func, **kwargs),
|
|
*args,
|
|
)
|
|
|
|
try:
|
|
return await asyncio.wait_for(coro, timeout)
|
|
except asyncio.TimeoutError:
|
|
print('Timed out')
|