150 lines
4.8 KiB
Python
150 lines
4.8 KiB
Python
import asyncio
|
|
import logging
|
|
import os
|
|
import signal
|
|
import traceback
|
|
from contextlib import ExitStack
|
|
from dataclasses import dataclass, field
|
|
from logging import Logger, getLogger
|
|
from random import random
|
|
from threading import Event, RLock
|
|
from time import perf_counter
|
|
from typing import Coroutine
|
|
|
|
from appdaemon.models import AppDaemonConfig
|
|
from context_manager import AppDaemonRunContext
|
|
|
|
|
|
@dataclass
class ADSubsystem:
    """Base class for AppDaemon subsystems that are used as context managers.

    A subsystem keeps a reference to its owning ``AppDaemon`` (``AD``). It
    shares a stop event with it and schedules coroutines through
    ``AD.create_task``.
    """

    AD: "AppDaemon"
    # A threading event the subsystem checks in order to shut down gracefully
    stop: Event
    # A thread-safe re-entrant lock protecting internal data while it's being modified
    lock: RLock = field(default_factory=RLock)
    # Assigned in __post_init__: a logger named 'AppDaemon._<lowercased class name>'
    logger: Logger = field(init=False)
    # NOTE(review): nothing in this file appends to this list; presumably meant
    # to track tasks created by the subsystem -- confirm or remove.
    tasks: list[asyncio.Task] = field(default_factory=list)

    def __post_init__(self) -> None:
        name = f'_{self.__class__.__name__.lower()}'
        self.logger = getLogger(f'AppDaemon.{name}')
        # Borrow the owner's create_task so tasks get the owner's done-callbacks.
        self.create_task = self.AD.create_task

    def __enter__(self):
        """Log the start of the subsystem and return it for ``with ... as``."""
        self.logger.debug('Starting %s', self.__class__.__name__)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Log the exit of the subsystem; exceptions are never suppressed."""
        self.logger.debug('Exiting %s', self.__class__.__name__)

    @property
    def stopping(self) -> bool:
        """True once the shared stop event has been set."""
        return self.stop.is_set()
|
|
|
|
|
|
@dataclass
class Utility(ADSubsystem):
    """Subsystem running a periodic background loop until shutdown.

    Entering the context schedules :meth:`loop` as a critical task.
    """

    # Seconds to sleep between loop iterations
    loop_rate: float = 0.25

    def __enter__(self):
        super().__enter__()
        self.create_task(self.loop(), 'Utility loop', critical=True)
        return self

    async def loop(self):
        """Log and sleep ``loop_rate`` seconds per iteration.

        Runs until the stop event is set or the task is cancelled.
        """
        while not self.stopping:
            self.logger.debug('Looping...')
            try:
                await asyncio.sleep(self.loop_rate)
            except asyncio.CancelledError:
                # Leave the loop instead of swallowing the cancellation and
                # iterating again. Otherwise task.cancel() has no effect until
                # the stop event is also set. Returning normally (rather than
                # re-raising) still ends in the graceful-stop log below.
                self.logger.debug('Cancelled during sleep')
                break
        self.logger.debug('Stopped utility loop gracefully')
|
|
|
|
|
|
@dataclass
class Plugin(ADSubsystem):
    """Plugin subsystem that periodically increments ``state['update_count']``."""

    # Plugin state; 'update_count' is reset to 0 in __post_init__
    state: dict[str, int] = field(default_factory=dict)
    # Seconds to sleep between self-updates
    update_rate: float = 2.0

    def __post_init__(self) -> None:
        super().__post_init__()
        self.state['update_count'] = 0

    def __enter__(self):
        super().__enter__()
        self.create_task(
            self.periodic_self_udpate(),
            name='plugin periodic update',
            critical=True,
        )
        return self

    # The misspelled method name ("udpate") is kept because callers may refer to it.
    async def periodic_self_udpate(self):
        """Bump the update counter every ``update_rate`` seconds.

        Logs the count and the time elapsed since the previous update. Runs
        until the stop event is set or the task is cancelled.
        """
        loop_time = perf_counter()
        while not self.stopping:
            with self.lock:
                self.state['update_count'] += 1
                # Snapshot under the lock so the logged value matches this update.
                count = self.state['update_count']

            now = perf_counter()
            self.logger.debug('Plugin self update: %s %.3fs', count, now - loop_time)
            loop_time = now

            try:
                await asyncio.sleep(self.update_rate)
            except asyncio.CancelledError:
                # Leave the loop; swallowing the cancellation and iterating
                # again would make task.cancel() ineffective while the stop
                # event is unset.
                self.logger.debug('Cancelled during sleep')
                break
        self.logger.debug('Stopped plugin updates gracefully')
|
|
|
|
|
|
@dataclass
class AppDaemon:
    """Top-level object that owns the utility subsystem and the plugins.

    Used as a context manager: entering it enters every subsystem on an
    internal ExitStack, and exiting it unwinds them in reverse order.
    """

    cfg: AppDaemonConfig
    # Assumed to expose `loop` (an asyncio event loop) and `stop_event` -- TODO confirm
    context: AppDaemonRunContext
    _stack: ExitStack = field(default_factory=ExitStack)
    utility: Utility = field(init=False)
    plugins: dict[str, Plugin] = field(default_factory=dict)

    def __post_init__(self) -> None:
        self.logger = logging.getLogger('AppDaemon')
        self.utility = Utility(self, self.context.stop_event)
        self.plugins['dummy'] = Plugin(self, self.context.stop_event)

    def __enter__(self):
        self.logger.info('Starting AppDaemon')
        try:
            self._stack.enter_context(self.utility)
            for plugin in self.plugins.values():
                self._stack.enter_context(plugin)
        except BaseException:
            # __exit__ is not called when __enter__ raises, so unwind the
            # subsystems that were already entered before propagating.
            self._stack.close()
            raise
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Return the ExitStack's verdict so a subsystem that suppresses an
        # exception is honoured.
        return self._stack.__exit__(exc_type, exc_value, traceback)

    def create_task(self, coro: Coroutine, name: str | None = None, critical: bool = False):
        """Creates an async task and adds exception callbacks.

        Args:
            coro: Coroutine to schedule on ``self.context.loop``.
            name: Optional task name.
            critical: If True, an exception in the task triggers :meth:`shutdown`.

        Returns:
            The created ``asyncio.Task``.
        """
        task = self.context.loop.create_task(coro, name=name)
        task.add_done_callback(self.check_task_exception)
        if critical:
            task.add_done_callback(self.critical_exception)
        return task

    def check_task_exception(self, task: asyncio.Task):
        """Done-callback that logs the traceback of a task that raised.

        Cancelled tasks are skipped. ``Task.exception()`` raises
        ``CancelledError`` for them instead of returning it.
        """
        if task.cancelled():
            return
        if (exc := task.exception()) is not None:
            # format_exception lines already end in '\n'; join without a separator.
            self.logger.error(''.join(traceback.format_exception(exc)).rstrip())

    def critical_exception(self, task: asyncio.Task):
        """Done-callback for critical tasks: shut down if the task raised."""
        if task.cancelled():
            return
        if task.exception() is not None:
            self.logger.critical(
                'Critical error in %s, forcing shutdown',
                task.get_name(),
            )
            self.shutdown()

    def shutdown(self):
        """Send SIGTERM to the current process.

        NOTE(review): presumably a SIGTERM handler elsewhere sets the stop
        event -- confirm.
        """
        self.logger.debug('Shutting down by sending SIGTERM')
        os.kill(os.getpid(), signal.SIGTERM)
|