started context manager
This commit is contained in:
112
context_manager.py
Normal file
112
context_manager.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
import asyncio
|
||||||
|
import functools
|
||||||
|
import logging
|
||||||
|
import signal
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from contextlib import ExitStack, contextmanager
|
||||||
|
from time import sleep
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
class CustomContextManager:
|
||||||
|
_stack: ExitStack
|
||||||
|
loop: asyncio.AbstractEventLoop
|
||||||
|
executor: ThreadPoolExecutor
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._stack = ExitStack()
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.loop = self._stack.enter_context(self.asyncio_context())
|
||||||
|
logger.debug("Entered asyncio context")
|
||||||
|
self.loop.add_signal_handler(signal.SIGINT, self.handle_signal)
|
||||||
|
self.loop.add_signal_handler(signal.SIGTERM, self.handle_signal)
|
||||||
|
# self.executor = self._stack.enter_context(ThreadPoolExecutor(max_workers=5))
|
||||||
|
self.executor = self._stack.enter_context(self.thread_context())
|
||||||
|
logger.debug("Entered threadpool context")
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
self._stack.__exit__(exc_type, exc_value, traceback)
|
||||||
|
logger.debug("Exited context")
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def asyncio_context(self):
|
||||||
|
try:
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
yield loop
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
logger.info("Closed the event loop.")
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def thread_context(self):
|
||||||
|
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||||
|
yield executor
|
||||||
|
logger.debug('Shut down the ThreadPoolExecutor')
|
||||||
|
|
||||||
|
def handle_signal(self, signum=None, frame=None):
|
||||||
|
logger.info(f'Handle signal: {signum}, {frame}')
|
||||||
|
# match signum:
|
||||||
|
# case signal.SIGINT:
|
||||||
|
# pass
|
||||||
|
# case signal.SIGTERM:
|
||||||
|
# pass
|
||||||
|
|
||||||
|
async def run_in_executor(
|
||||||
|
self,
|
||||||
|
func: Callable,
|
||||||
|
*args,
|
||||||
|
timeout: float | None = None,
|
||||||
|
**kwargs
|
||||||
|
) -> Any:
|
||||||
|
"""Run the sync function using the ThreadPoolExecutor and await the result"""
|
||||||
|
timeout = timeout or 10.0
|
||||||
|
|
||||||
|
coro = self.loop.run_in_executor(
|
||||||
|
self.executor,
|
||||||
|
functools.partial(func, **kwargs),
|
||||||
|
*args,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await asyncio.wait_for(coro, timeout)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
print('Timed out')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
logging.basicConfig(level="DEBUG", format="{levelname:<8} {message}", style="{")
|
||||||
|
|
||||||
|
def dummy_function(delay: float):
|
||||||
|
id_ = uuid4().hex[:4]
|
||||||
|
logger.info(f'{id_} sleeping for {delay:.1f}s')
|
||||||
|
sleep(delay)
|
||||||
|
logger.info(f'{id_} Done')
|
||||||
|
|
||||||
|
async def async_dummy_function(delay: float):
|
||||||
|
id_ = uuid4().hex[:4]
|
||||||
|
logger.info(f'{id_} async sleeping for {delay:.1f}s')
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
logger.info(f'{id_} Done async')
|
||||||
|
|
||||||
|
with CustomContextManager() as cm:
|
||||||
|
for _ in range(3):
|
||||||
|
logger.info('Submitting random dummy_functions')
|
||||||
|
cm.executor.submit(dummy_function, random.random() * 3.0)
|
||||||
|
cm.loop.create_task(async_dummy_function(random.random() * 5.0))
|
||||||
|
|
||||||
|
logger.info('Running until complete')
|
||||||
|
cm.loop.run_until_complete(asyncio.gather(*asyncio.all_tasks(cm.loop)))
|
||||||
|
|
||||||
|
if cm.loop.is_closed():
|
||||||
|
logger.debug('Loop is closed')
|
||||||
|
if cm.executor._shutdown:
|
||||||
|
logger.debug('Executor is shut down')
|
||||||
Reference in New Issue
Block a user