started context manager

This commit is contained in:
John Lancaster
2024-10-16 03:50:25 +00:00
parent 6076329aea
commit 45e435554a

112
context_manager.py Normal file
View File

@@ -0,0 +1,112 @@
import asyncio
import functools
import logging
import signal
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack, contextmanager
from time import sleep
from typing import Any, Callable
logger = logging.getLogger()
class CustomContextManager:
_stack: ExitStack
loop: asyncio.AbstractEventLoop
executor: ThreadPoolExecutor
def __init__(self):
self._stack = ExitStack()
def __enter__(self):
self.loop = self._stack.enter_context(self.asyncio_context())
logger.debug("Entered asyncio context")
self.loop.add_signal_handler(signal.SIGINT, self.handle_signal)
self.loop.add_signal_handler(signal.SIGTERM, self.handle_signal)
# self.executor = self._stack.enter_context(ThreadPoolExecutor(max_workers=5))
self.executor = self._stack.enter_context(self.thread_context())
logger.debug("Entered threadpool context")
return self
def __exit__(self, exc_type, exc_value, traceback):
self._stack.__exit__(exc_type, exc_value, traceback)
logger.debug("Exited context")
@contextmanager
def asyncio_context(self):
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
yield loop
finally:
loop.close()
logger.info("Closed the event loop.")
@contextmanager
def thread_context(self):
with ThreadPoolExecutor(max_workers=5) as executor:
yield executor
logger.debug('Shut down the ThreadPoolExecutor')
def handle_signal(self, signum=None, frame=None):
logger.info(f'Handle signal: {signum}, {frame}')
# match signum:
# case signal.SIGINT:
# pass
# case signal.SIGTERM:
# pass
async def run_in_executor(
self,
func: Callable,
*args,
timeout: float | None = None,
**kwargs
) -> Any:
"""Run the sync function using the ThreadPoolExecutor and await the result"""
timeout = timeout or 10.0
coro = self.loop.run_in_executor(
self.executor,
functools.partial(func, **kwargs),
*args,
)
try:
return await asyncio.wait_for(coro, timeout)
except asyncio.TimeoutError:
print('Timed out')
if __name__ == "__main__":
import logging
import random
from uuid import uuid4
logging.basicConfig(level="DEBUG", format="{levelname:<8} {message}", style="{")
def dummy_function(delay: float):
id_ = uuid4().hex[:4]
logger.info(f'{id_} sleeping for {delay:.1f}s')
sleep(delay)
logger.info(f'{id_} Done')
async def async_dummy_function(delay: float):
id_ = uuid4().hex[:4]
logger.info(f'{id_} async sleeping for {delay:.1f}s')
await asyncio.sleep(delay)
logger.info(f'{id_} Done async')
with CustomContextManager() as cm:
for _ in range(3):
logger.info('Submitting random dummy_functions')
cm.executor.submit(dummy_function, random.random() * 3.0)
cm.loop.create_task(async_dummy_function(random.random() * 5.0))
logger.info('Running until complete')
cm.loop.run_until_complete(asyncio.gather(*asyncio.all_tasks(cm.loop)))
if cm.loop.is_closed():
logger.debug('Loop is closed')
if cm.executor._shutdown:
logger.debug('Executor is shut down')