more work

This commit is contained in:
John Lancaster
2024-10-20 21:48:19 +00:00
parent 050fe75e71
commit 502c218c35
2 changed files with 34 additions and 61 deletions

View File

@@ -4,29 +4,26 @@ import logging
import signal import signal
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack, contextmanager from contextlib import ExitStack, contextmanager
from dataclasses import dataclass, field
from threading import Event, Lock from threading import Event, Lock
from time import sleep
from typing import Any, Callable from typing import Any, Callable
logger = logging.getLogger() logger = logging.getLogger(__name__)
def handler(signum, frame): def handler(signum, frame):
print('Signal handler called with signal', signum) print('Signal handler called with signal', signum)
raise OSError("Couldn't open device!") raise OSError("Couldn't open device!")
@dataclass
class AppDaemonRunContext: class AppDaemonRunContext:
_stack: ExitStack _stack: ExitStack = field(default_factory=ExitStack)
loop: asyncio.AbstractEventLoop stop_event: Event = field(default_factory=Event)
executor: ThreadPoolExecutor shutdown_lock: Lock = field(default_factory=Lock)
stop_event: Event shutdown_grace_period: float = 1.0
shutdown_lock: Lock
def __init__(self): loop: asyncio.AbstractEventLoop = field(init=False)
self._stack = ExitStack() executor: ThreadPoolExecutor = field(init=False)
self.stop_event = Event()
self.shutdown_lock = Lock()
def __enter__(self): def __enter__(self):
self.loop = self._stack.enter_context(self.asyncio_context()) self.loop = self._stack.enter_context(self.asyncio_context())
@@ -44,6 +41,12 @@ class AppDaemonRunContext:
self._stack.__exit__(exc_type, exc_value, traceback) self._stack.__exit__(exc_type, exc_value, traceback)
logger.debug("Exited context") logger.debug("Exited context")
def get_running_tasks(self, exclude_current: bool = True) -> list[asyncio.Task]:
return [
t for t in asyncio.all_tasks(self.loop)
if exclude_current and t is not asyncio.current_task()
]
async def shutdown(self, signal): async def shutdown(self, signal):
"""Cleanup tasks tied to the service's shutdown. """Cleanup tasks tied to the service's shutdown.
@@ -54,14 +57,19 @@ class AppDaemonRunContext:
logger.debug('Setting stop event') logger.debug('Setting stop event')
self.stop_event.set() self.stop_event.set()
tasks = self.get_running_tasks()
graceful = ( graceful = (
asyncio.wait_for(t, timeout=2.0) asyncio.wait_for(t, timeout=self.shutdown_grace_period)
for t in asyncio.all_tasks() for t in tasks
if t is not asyncio.current_task()
) )
logger.debug('Allowing graceful shutdown from stop event') logger.debug(f'Allowing graceful shutdown from stop event for {self.shutdown_grace_period}s')
await asyncio.gather(*graceful, return_exceptions=True) await asyncio.gather(*graceful, return_exceptions=True)
for task in tasks:
if task.cancelled():
logger.warning(f'Cancelled {task.get_coro().__qualname__}')
logger.debug("Stopping event loop in context shutdown") logger.debug("Stopping event loop in context shutdown")
self.loop.stop() self.loop.stop()
else: else:
@@ -104,40 +112,3 @@ class AppDaemonRunContext:
return await asyncio.wait_for(coro, timeout) return await asyncio.wait_for(coro, timeout)
except asyncio.TimeoutError: except asyncio.TimeoutError:
print('Timed out') print('Timed out')
if __name__ == "__main__":
import logging
import random
from uuid import uuid4
logging.basicConfig(level="DEBUG", format="{levelname:<8} {message}", style="{")
def dummy_function(delay: float):
id_ = uuid4().hex[:4]
logger.info(f'{id_} sleeping for {delay:.1f}s')
sleep(delay)
logger.info(f'{id_} Done')
async def async_dummy_function(delay: float):
id_ = uuid4().hex[:4]
logger.info(f'{id_} async sleeping for {delay:.1f}s')
await asyncio.sleep(delay)
logger.info(f'{id_} Done async')
with AppDaemonRunContext() as cm:
for _ in range(3):
logger.info('Submitting random dummy_functions')
cm.executor.submit(dummy_function, random.random() * 10.0)
cm.loop.create_task(async_dummy_function(random.random() * 5.0))
try:
logger.info('Running until complete')
cm.loop.run_until_complete(asyncio.gather(*asyncio.all_tasks(cm.loop)))
except asyncio.CancelledError:
logger.error('Cancelled')
if cm.loop.is_closed():
logger.debug('Loop is closed')
if cm.executor._shutdown:
logger.debug('Executor is shut down')

View File

@@ -19,26 +19,24 @@ class ADSubsystem:
def __post_init__(self) -> None: def __post_init__(self) -> None:
name = f'_{self.__class__.__name__.lower()}' name = f'_{self.__class__.__name__.lower()}'
self.logger = getLogger(f'AppDaemon.{name}') self.logger = getLogger(f'AppDaemon.{name}')
self.AD.starts.append(self.start) if start_func := getattr(self, 'start', False):
self.AD.starts.append(start_func)
@property @property
def stopping(self) -> bool: def stopping(self) -> bool:
return self.stop.is_set() return self.stop.is_set()
def start(self):
raise NotImplementedError('Need to implement start for subsystem')
@dataclass @dataclass
class Utility(ADSubsystem): class Utility(ADSubsystem):
loop_rate: float = 1.0 loop_rate: float = 0.5
def start(self): def start(self):
self.AD.create_task(self.loop(), 'Utility loop') self.AD.create_task(self.loop(), 'Utility loop')
async def loop(self): async def loop(self):
while not self.stopping: while not self.stopping:
self.logger.info('Looping...') self.logger.debug('Looping...')
await asyncio.sleep(self.loop_rate) await asyncio.sleep(self.loop_rate)
self.logger.debug('Stopped utility loop') self.logger.debug('Stopped utility loop')
@@ -46,14 +44,15 @@ class Utility(ADSubsystem):
@dataclass @dataclass
class Plugin(ADSubsystem): class Plugin(ADSubsystem):
state: dict[str, int] = field(default_factory=dict) state: dict[str, int] = field(default_factory=dict)
update_rate: float = 30.0 update_rate: float = 5.0
def __post_init__(self) -> None: def __post_init__(self) -> None:
super().__post_init__() super().__post_init__()
self.state['update_count'] = 0 self.state['update_count'] = 0
def start(self): def start(self):
self.AD.create_task(self.periodic_self_udpate(), 'Periodic plugin update') self.AD.create_task(self.periodic_self_udpate(),
'Periodic plugin update')
async def periodic_self_udpate(self): async def periodic_self_udpate(self):
while not self.stopping: while not self.stopping:
@@ -73,6 +72,7 @@ class AppDaemon:
starts: list[Callable] = field(default_factory=list) starts: list[Callable] = field(default_factory=list)
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.logger = logging.getLogger('AppDaemon')
self.utility = Utility(self, self.context.stop_event) self.utility = Utility(self, self.context.stop_event)
self.plugins['dummy'] = Plugin(self, self.context.stop_event) self.plugins['dummy'] = Plugin(self, self.context.stop_event)
@@ -81,6 +81,8 @@ class AppDaemon:
def start(self): def start(self):
for start in self.starts: for start in self.starts:
subsystem = start.__qualname__.split('.')[0]
self.logger.debug(f'Starting {subsystem}')
start() start()