diff --git a/context_manager.py b/context_manager.py new file mode 100644 index 0000000..363691e --- /dev/null +++ b/context_manager.py @@ -0,0 +1,112 @@ +import asyncio +import functools +import logging +import signal +from concurrent.futures import ThreadPoolExecutor +from contextlib import ExitStack, contextmanager +from time import sleep +from typing import Any, Callable + +logger = logging.getLogger() + + +class CustomContextManager: + _stack: ExitStack + loop: asyncio.AbstractEventLoop + executor: ThreadPoolExecutor + + def __init__(self): + self._stack = ExitStack() + + def __enter__(self): + self.loop = self._stack.enter_context(self.asyncio_context()) + logger.debug("Entered asyncio context") + self.loop.add_signal_handler(signal.SIGINT, self.handle_signal) + self.loop.add_signal_handler(signal.SIGTERM, self.handle_signal) + # self.executor = self._stack.enter_context(ThreadPoolExecutor(max_workers=5)) + self.executor = self._stack.enter_context(self.thread_context()) + logger.debug("Entered threadpool context") + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._stack.__exit__(exc_type, exc_value, traceback) + logger.debug("Exited context") + + @contextmanager + def asyncio_context(self): + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + yield loop + finally: + loop.close() + logger.info("Closed the event loop.") + + @contextmanager + def thread_context(self): + with ThreadPoolExecutor(max_workers=5) as executor: + yield executor + logger.debug('Shut down the ThreadPoolExecutor') + + def handle_signal(self, signum=None, frame=None): + logger.info(f'Handle signal: {signum}, {frame}') + # match signum: + # case signal.SIGINT: + # pass + # case signal.SIGTERM: + # pass + + async def run_in_executor( + self, + func: Callable, + *args, + timeout: float | None = None, + **kwargs + ) -> Any: + """Run the sync function using the ThreadPoolExecutor and await the result""" + timeout = timeout or 10.0 + + coro = self.loop.run_in_executor( + self.executor, + functools.partial(func, **kwargs), + *args, + ) + + try: + return await asyncio.wait_for(coro, timeout) + except asyncio.TimeoutError: + print('Timed out') + + +if __name__ == "__main__": + import logging + import random + from uuid import uuid4 + + logging.basicConfig(level="DEBUG", format="{levelname:<8} {message}", style="{") + + def dummy_function(delay: float): + id_ = uuid4().hex[:4] + logger.info(f'{id_} sleeping for {delay:.1f}s') + sleep(delay) + logger.info(f'{id_} Done') + + async def async_dummy_function(delay: float): + id_ = uuid4().hex[:4] + logger.info(f'{id_} async sleeping for {delay:.1f}s') + await asyncio.sleep(delay) + logger.info(f'{id_} Done async') + + with CustomContextManager() as cm: + for _ in range(3): + logger.info('Submitting random dummy_functions') + cm.executor.submit(dummy_function, random.random() * 3.0) + cm.loop.create_task(async_dummy_function(random.random() * 5.0)) + + logger.info('Running until complete') + cm.loop.run_until_complete(asyncio.gather(*asyncio.all_tasks(cm.loop))) + + if cm.loop.is_closed(): + logger.debug('Loop is closed') + if cm.executor._shutdown: + logger.debug('Executor is shut down')