Compare commits

..

5 Commits

Author SHA1 Message Date
John Lancaster
5e458aca41 reorg 2024-10-16 03:56:07 +00:00
John Lancaster
5af940f077 added some threading stuff to appdaemon object 2024-10-16 03:55:32 +00:00
John Lancaster
4dddc50c82 reformat 2024-10-16 03:51:16 +00:00
John Lancaster
9ccbad58cf added get_state to state 2024-10-16 03:50:57 +00:00
John Lancaster
45e435554a started context manager 2024-10-16 03:50:25 +00:00
4 changed files with 252 additions and 17 deletions

55
appdaemon/appdaemon.py Normal file
View File

@@ -0,0 +1,55 @@
import asyncio
import concurrent
import concurrent.futures
import functools
import inspect
from typing import Any, Callable, Coroutine
class AppDaemon:
async def run_async_sync_func(self, method, *args, timeout: float | None = None, **kwargs):
if inspect.iscoroutinefunction(method):
result = await method(*args, timeout=timeout, **kwargs)
else:
result = await self.run_in_executor(self, method, *args, timeout=timeout, **kwargs)
return result
async def run_in_executor(
self,
func: Callable,
*args,
timeout: float | None = None,
**kwargs
) -> Any:
"""Run the sync function using the ThreadPoolExecutor and await the result"""
timeout = timeout or self.AD.internal_function_timeout
coro = self.AD.loop.run_in_executor(
self.AD.executor,
functools.partial(func, **kwargs),
*args,
)
try:
return await asyncio.wait_for(coro, timeout)
except asyncio.TimeoutError:
self.logger.warning(
"Function (%s) took too long (%s seconds), cancelling the task...",
func.__name__, timeout,
)
def run_coroutine_threadsafe(self, coro: Coroutine, timeout: float | None = None) -> Any:
timeout = timeout or self.AD.internal_function_timeout
if self.AD.loop.is_running():
try:
future = asyncio.run_coroutine_threadsafe(coro, self.AD.loop)
return future.result(timeout)
except (asyncio.TimeoutError, concurrent.futures.TimeoutError):
self.logger.warning(
"Coroutine (%s) took too long (%s seconds), cancelling the task...",
coro, timeout,
)
future.cancel()
else:
self.logger.warning("LOOP NOT RUNNING. Returning NONE.")

View File

@@ -1,11 +1,8 @@
import asyncio import asyncio
from cgitb import handler
from copy import deepcopy from copy import deepcopy
from logging import Logger from logging import Logger
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Any, Literal
import appdaemon.utils as utils
if TYPE_CHECKING: if TYPE_CHECKING:
from appdaemon.appdaemon import AppDaemon from appdaemon.appdaemon import AppDaemon
@@ -50,12 +47,14 @@ class Callbacks:
await self.AD.state.remove_entity("admin", f"{callback['type']}_callback.{handle}") await self.AD.state.remove_entity("admin", f"{callback['type']}_callback.{handle}")
return True return True
elif not silent: elif not silent:
self.logger.warning(f"Invalid callback handle '{handle}' in cancel_callback()") self.logger.warning(f"Invalid callback handle '{
handle}' in cancel_callback()")
async def cancel_all_callbacks(self, name: str, silent: bool = False): async def cancel_all_callbacks(self, name: str, silent: bool = False):
async with self.callbacks_lock: async with self.callbacks_lock:
if callbacks := self.callbacks.pop(name, False): if callbacks := self.callbacks.pop(name, False):
self.logger.debug("Clearing %s callbacks for %s", len(callbacks), name) self.logger.debug(
"Clearing %s callbacks for %s", len(callbacks), name)
for handle, cb_info in callbacks.items(): for handle, cb_info in callbacks.items():
cb_type: Literal["event", "state", "log"] = cb_info['type'] cb_type: Literal["event", "state", "log"] = cb_info['type']
await self.AD.state.remove_entity("admin", f"{cb_type}_callback.{handle}") await self.AD.state.remove_entity("admin", f"{cb_type}_callback.{handle}")
@@ -82,16 +81,36 @@ class Callbacks:
for app_name, app_callbacks in self.callbacks.items() for app_name, app_callbacks in self.callbacks.items()
} }
async def get_callback_handles(self, app: str = 'all', type: str = 'all', entity_id: str = 'all'): async def get_callbacks(
self,
namespace: str = 'all',
app: str = 'all',
type: str = 'all',
entity_id: str = 'all',
copy: bool = True,
) -> dict[str, dict[str, Any]]:
async with self.callbacks_lock: async with self.callbacks_lock:
handles = set( return {
handle handle: deepcopy(cb_info) if copy else cb_info
for app_name, app_callbacks in self.callbacks.items() for app_name, app_callbacks in self.callbacks.items()
if app == 'all' or app == app_name if app == 'all' or app == app_name
for handle, cb_info in app_callbacks.items() for handle, cb_info in app_callbacks.items()
if (type == 'all' or type == cb_info["type"]) if (type == 'all' or type == cb_info["type"])
and (entity_id == 'all' or entity_id == cb_info["entity"]) and (entity_id == 'all' or entity_id == cb_info["entity"])
and (
namespace == 'all'
or namespace == 'global'
or cb_info["namespace"] == 'global'
or namespace == cb_info["namespace"]
) )
self.logger.debug(f"Got {len(handles)} callbacks for app={app}, type={type}, entity_id={entity_id}") }
return handles
async def get_callback_handles(
self,
namespace: str = 'all',
app: str = 'all',
type: str = 'all',
entity_id: str = 'all'
) -> set[str]:
callbacks = await self.get_callbacks(namespace, app, type, entity_id, copy=False)
return set(callbacks.keys())

View File

@@ -1,4 +1,4 @@
from copy import deepcopy
from logging import Logger from logging import Logger
from typing import Any from typing import Any
@@ -56,3 +56,52 @@ class States:
async for remove in _send_dispatches(): async for remove in _send_dispatches():
await self.cancel_state_callback(**remove) await self.cancel_state_callback(**remove)
async def get_state(
self,
name: str,
namespace: str,
entity_id: str | None = None,
attribute: str | None = None,
default: Any = None,
copy: bool = True,
):
self.logger.debug("get_state: %s.%s %s %s", entity_id, attribute, default, copy)
result = default
if ns := self.state.get(namespace):
# Process entity_id input
if entity_id is None:
result = ns
# TODO: filter by attribute?
elif "." not in entity_id:
domain = entity_id
result = {
eid: state
for eid, state in ns.items()
if eid.startswith(domain)
}
elif full_state := ns.get(entity_id):
result = full_state
else:
self.logger.warning(f"Entity {entity_id} does not exist in namespace {namespace}")
return
# Process attribute input
if attribute == "all":
result = result
elif attr := full_state.get(attribute):
result = attr
elif attr := full_state.get('attributes', {}).get(attribute):
result = attr
elif state := full_state.get("state"):
result = state
return deepcopy(result) if copy else result
else:
self.logger.warning(f"Namespace does not exist: {namespace}")
async def cancel_state_callback(self, handle: str, name: str, silent: bool = False) -> bool:
return await self.AD.callbacks.cancel_callback(handle, name, silent)

112
context_manager.py Normal file
View File

@@ -0,0 +1,112 @@
import asyncio
import functools
import logging
import signal
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack, contextmanager
from time import sleep
from typing import Any, Callable
logger = logging.getLogger()
class CustomContextManager:
_stack: ExitStack
loop: asyncio.AbstractEventLoop
executor: ThreadPoolExecutor
def __init__(self):
self._stack = ExitStack()
def __enter__(self):
self.loop = self._stack.enter_context(self.asyncio_context())
logger.debug("Entered asyncio context")
self.loop.add_signal_handler(signal.SIGINT, self.handle_signal)
self.loop.add_signal_handler(signal.SIGTERM, self.handle_signal)
# self.executor = self._stack.enter_context(ThreadPoolExecutor(max_workers=5))
self.executor = self._stack.enter_context(self.thread_context())
logger.debug("Entered threadpool context")
return self
def __exit__(self, exc_type, exc_value, traceback):
self._stack.__exit__(exc_type, exc_value, traceback)
logger.debug("Exited context")
@contextmanager
def asyncio_context(self):
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
yield loop
finally:
loop.close()
logger.info("Closed the event loop.")
@contextmanager
def thread_context(self):
with ThreadPoolExecutor(max_workers=5) as executor:
yield executor
logger.debug('Shut down the ThreadPoolExecutor')
def handle_signal(self, signum=None, frame=None):
logger.info(f'Handle signal: {signum}, {frame}')
# match signum:
# case signal.SIGINT:
# pass
# case signal.SIGTERM:
# pass
async def run_in_executor(
self,
func: Callable,
*args,
timeout: float | None = None,
**kwargs
) -> Any:
"""Run the sync function using the ThreadPoolExecutor and await the result"""
timeout = timeout or 10.0
coro = self.loop.run_in_executor(
self.executor,
functools.partial(func, **kwargs),
*args,
)
try:
return await asyncio.wait_for(coro, timeout)
except asyncio.TimeoutError:
print('Timed out')
if __name__ == "__main__":
import logging
import random
from uuid import uuid4
logging.basicConfig(level="DEBUG", format="{levelname:<8} {message}", style="{")
def dummy_function(delay: float):
id_ = uuid4().hex[:4]
logger.info(f'{id_} sleeping for {delay:.1f}s')
sleep(delay)
logger.info(f'{id_} Done')
async def async_dummy_function(delay: float):
id_ = uuid4().hex[:4]
logger.info(f'{id_} async sleeping for {delay:.1f}s')
await asyncio.sleep(delay)
logger.info(f'{id_} Done async')
with CustomContextManager() as cm:
for _ in range(3):
logger.info('Submitting random dummy_functions')
cm.executor.submit(dummy_function, random.random() * 3.0)
cm.loop.create_task(async_dummy_function(random.random() * 5.0))
logger.info('Running until complete')
cm.loop.run_until_complete(asyncio.gather(*asyncio.all_tasks(cm.loop)))
if cm.loop.is_closed():
logger.debug('Loop is closed')
if cm.executor._shutdown:
logger.debug('Executor is shut down')