import asyncio import functools import logging import signal from concurrent.futures import ThreadPoolExecutor from contextlib import ExitStack, contextmanager from time import sleep from typing import Any, Callable logger = logging.getLogger() class CustomContextManager: _stack: ExitStack loop: asyncio.AbstractEventLoop executor: ThreadPoolExecutor def __init__(self): self._stack = ExitStack() def __enter__(self): self.loop = self._stack.enter_context(self.asyncio_context()) logger.debug("Entered asyncio context") self.loop.add_signal_handler(signal.SIGINT, self.handle_signal) self.loop.add_signal_handler(signal.SIGTERM, self.handle_signal) # self.executor = self._stack.enter_context(ThreadPoolExecutor(max_workers=5)) self.executor = self._stack.enter_context(self.thread_context()) logger.debug("Entered threadpool context") return self def __exit__(self, exc_type, exc_value, traceback): self._stack.__exit__(exc_type, exc_value, traceback) logger.debug("Exited context") @contextmanager def asyncio_context(self): try: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) yield loop finally: loop.close() logger.info("Closed the event loop.") @contextmanager def thread_context(self): with ThreadPoolExecutor(max_workers=5) as executor: yield executor logger.debug('Shut down the ThreadPoolExecutor') def handle_signal(self, signum=None, frame=None): logger.info(f'Handle signal: {signum}, {frame}') # match signum: # case signal.SIGINT: # pass # case signal.SIGTERM: # pass async def run_in_executor( self, func: Callable, *args, timeout: float | None = None, **kwargs ) -> Any: """Run the sync function using the ThreadPoolExecutor and await the result""" timeout = timeout or 10.0 coro = self.loop.run_in_executor( self.executor, functools.partial(func, **kwargs), *args, ) try: return await asyncio.wait_for(coro, timeout) except asyncio.TimeoutError: print('Timed out') if __name__ == "__main__": import logging import random from uuid import uuid4 logging.basicConfig(level="DEBUG", format="{levelname:<8} {message}", style="{") def dummy_function(delay: float): id_ = uuid4().hex[:4] logger.info(f'{id_} sleeping for {delay:.1f}s') sleep(delay) logger.info(f'{id_} Done') async def async_dummy_function(delay: float): id_ = uuid4().hex[:4] logger.info(f'{id_} async sleeping for {delay:.1f}s') await asyncio.sleep(delay) logger.info(f'{id_} Done async') with CustomContextManager() as cm: for _ in range(3): logger.info('Submitting random dummy_functions') cm.executor.submit(dummy_function, random.random() * 3.0) cm.loop.create_task(async_dummy_function(random.random() * 5.0)) logger.info('Running until complete') cm.loop.run_until_complete(asyncio.gather(*asyncio.all_tasks(cm.loop))) if cm.loop.is_closed(): logger.debug('Loop is closed') if cm.executor._shutdown: logger.debug('Executor is shut down')