Files
appdaemon_snippets/context_manager.py
2024-10-16 03:50:25 +00:00

113 lines
3.4 KiB
Python

import asyncio
import functools
import logging
import signal
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack, contextmanager
from time import sleep
from typing import Any, Callable
logger = logging.getLogger()
class CustomContextManager:
_stack: ExitStack
loop: asyncio.AbstractEventLoop
executor: ThreadPoolExecutor
def __init__(self):
self._stack = ExitStack()
def __enter__(self):
self.loop = self._stack.enter_context(self.asyncio_context())
logger.debug("Entered asyncio context")
self.loop.add_signal_handler(signal.SIGINT, self.handle_signal)
self.loop.add_signal_handler(signal.SIGTERM, self.handle_signal)
# self.executor = self._stack.enter_context(ThreadPoolExecutor(max_workers=5))
self.executor = self._stack.enter_context(self.thread_context())
logger.debug("Entered threadpool context")
return self
def __exit__(self, exc_type, exc_value, traceback):
self._stack.__exit__(exc_type, exc_value, traceback)
logger.debug("Exited context")
@contextmanager
def asyncio_context(self):
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
yield loop
finally:
loop.close()
logger.info("Closed the event loop.")
@contextmanager
def thread_context(self):
with ThreadPoolExecutor(max_workers=5) as executor:
yield executor
logger.debug('Shut down the ThreadPoolExecutor')
def handle_signal(self, signum=None, frame=None):
logger.info(f'Handle signal: {signum}, {frame}')
# match signum:
# case signal.SIGINT:
# pass
# case signal.SIGTERM:
# pass
async def run_in_executor(
self,
func: Callable,
*args,
timeout: float | None = None,
**kwargs
) -> Any:
"""Run the sync function using the ThreadPoolExecutor and await the result"""
timeout = timeout or 10.0
coro = self.loop.run_in_executor(
self.executor,
functools.partial(func, **kwargs),
*args,
)
try:
return await asyncio.wait_for(coro, timeout)
except asyncio.TimeoutError:
print('Timed out')
if __name__ == "__main__":
    # Demo: run blocking jobs in the pool and coroutines on the loop together.
    # ``logging`` is already imported at module level; only demo-only modules here.
    import random
    from uuid import uuid4

    logging.basicConfig(level="DEBUG", format="{levelname:<8} {message}", style="{")

    def dummy_function(delay: float):
        """Blocking job for the thread pool: log, sleep ``delay`` seconds, log."""
        id_ = uuid4().hex[:4]
        logger.info(f'{id_} sleeping for {delay:.1f}s')
        sleep(delay)
        logger.info(f'{id_} Done')

    async def async_dummy_function(delay: float):
        """Coroutine counterpart of dummy_function, using asyncio.sleep."""
        id_ = uuid4().hex[:4]
        logger.info(f'{id_} async sleeping for {delay:.1f}s')
        await asyncio.sleep(delay)
        logger.info(f'{id_} Done async')

    with CustomContextManager() as cm:
        futures = []
        for _ in range(3):
            logger.info('Submitting random dummy_functions')
            futures.append(cm.executor.submit(dummy_function, random.random() * 3.0))
            cm.loop.create_task(async_dummy_function(random.random() * 5.0))
        logger.info('Running until complete')
        cm.loop.run_until_complete(asyncio.gather(*asyncio.all_tasks(cm.loop)))
        # Pool futures are not part of the gather above. result() waits for each
        # one and re-raises its exception instead of silently dropping it.
        for future in futures:
            future.result()

    # After the with block the context has released both resources.
    if cm.loop.is_closed():
        logger.debug('Loop is closed')
    # _shutdown is a private ThreadPoolExecutor attribute, used only as a demo check.
    if cm.executor._shutdown:
        logger.debug('Executor is shut down')