97 lines
2.7 KiB
Python
97 lines
2.7 KiB
Python
import ssl
|
|
import sys
|
|
from ipaddress import IPv4Address
|
|
from pathlib import Path
|
|
from typing import Any, List, Literal, Optional, Union
|
|
|
|
from pydantic import BaseModel, SecretStr, validator
|
|
|
|
|
|
class PluginConfig(BaseModel):
    """Common base for every plugin configuration model.

    Concrete plugin configs subclass this and narrow ``type`` to a
    ``Literal`` naming their plugin.
    """

    # Plugin type name; subclasses pin it to a fixed literal such as 'hass' or 'mqtt'.
    type: str
|
|
|
|
|
|
class HASSPluginConfig(PluginConfig):
    """Configuration model for the ``hass`` plugin.

    ``token`` and ``ha_url`` have no defaults and must be supplied.
    """

    # Fixed discriminator for this plugin type.
    type: Literal['hass'] = 'hass'

    # Access token, held as a pydantic SecretStr rather than a plain str.
    token: SecretStr

    # Server URL (Home Assistant, judging by the name); kept as a plain str
    # with no further validation in this model.
    ha_url: str
|
|
|
|
|
|
class MQTTPluginConfig(PluginConfig):
    """Configuration model for the ``mqtt`` plugin.

    Several fields are derived during validation when left unset:

    * ``client_id`` defaults to ``"appdaemon_<name>_client"``.
    * ``status_topic`` defaults to ``"<client_id>/status"``.
    * ``tls_version`` is converted from ``"auto"`` / ``"1.0"`` / ``"1.1"`` /
      ``"1.2"`` into the matching ``ssl`` protocol constant.

    After validation, ``model_post_init`` replaces ``client_topics == 'NONE'``
    with an empty list and defaults ``birth_topic`` / ``will_topic`` to
    ``status_topic``.
    """

    version: str = '1.0'
    type: Literal['mqtt'] = 'mqtt'
    name: Optional[str] = None
    namespace: str = 'default'

    # Client connection settings.
    client_host: Union[str, IPv4Address] = '127.0.0.1'
    client_port: int = 1883
    client_timeout: int = 60
    client_transport: Literal['tcp', 'websockets'] = 'tcp'
    client_clean_session: bool = True
    client_id: Optional[str] = None
    client_user: Optional[str] = None
    client_password: Optional[SecretStr] = None

    # Certificate / TLS settings.
    client_cert: Optional[Path] = None
    client_key: Optional[Path] = None
    verify_cert: bool = True
    # NOTE(review): ssl._SSLMethod is a private stdlib enum; convert_tls_version
    # turns the configured string into one of its members.
    tls_version: ssl._SSLMethod = 'auto'
    ca_cert: Optional[str] = None

    event_name: str = 'MQTT_MESSAGE'
    # Either a list of topics or the literal 'NONE' (mapped to [] in
    # model_post_init). Pydantic copies this mutable default per instance.
    client_topics: Union[List[str], Literal['NONE']] = ['#']
    client_qos: Any = 0

    # None here is replaced by set_status_topic during validation.
    status_topic: Optional[str] = None
    # birth_topic / will_topic fall back to status_topic in model_post_init.
    birth_topic: Optional[str] = None
    birth_payload: str = 'online'
    birth_retain: bool = True
    will_topic: Optional[str] = None
    will_payload: str = 'offline'
    will_retain: bool = True
    shutdown_payload: Optional[str] = None
    force_start: bool = False

    def model_post_init(self, context: Any) -> None:
        """Fill in fields that depend on other, already-validated fields.

        ``context`` is the pydantic post-init context; it is not used here.
        """
        # 'NONE' is the explicit "no topics" value; store it as an empty list.
        if self.client_topics == 'NONE':
            self.client_topics = []

        # Birth and will topics default to the status topic.
        if self.will_topic is None:
            self.will_topic = self.status_topic

        if self.birth_topic is None:
            self.birth_topic = self.status_topic

    @validator('status_topic', pre=True, always=True)
    @classmethod
    def set_status_topic(cls, v, values):
        """Default ``status_topic`` to ``"<client_id>/status"`` when unset.

        ``values`` holds the fields declared before ``status_topic`` that
        validated successfully (``client_id`` and ``name`` among them).
        ``.get`` is used so that a failed ``client_id``/``name`` validation is
        reported as that field's error rather than a ``KeyError`` from here.
        """
        if v is not None:
            return v
        # client_id is normally already defaulted by validate_client_id; the
        # '-client' fallback only applies if it is empty or missing.
        client_id = values.get('client_id') or f'{values.get("name")}-client'
        return f'{client_id}/status'

    @validator('tls_version', pre=True, always=True)
    @classmethod
    def convert_tls_version(cls, v, values):
        """Convert the configured TLS version into an ``ssl`` protocol constant.

        Accepts ``'auto'`` (case-insensitive) or ``'1.0'`` / ``'1.1'`` /
        ``'1.2'``. Non-string input is stringified first, so a YAML float such
        as ``1.2`` is accepted. Values that are already ``ssl`` protocol enum
        members are passed through unchanged.

        Raises:
            ValueError: for any other value, so pydantic reports it as a
                validation error instead of a bare ``KeyError``.
        """
        if isinstance(v, ssl._SSLMethod):
            return v

        key = str(v).strip().lower()
        if key == 'auto':
            # model_post_init is pydantic-v2 API, and pydantic v2 needs
            # Python >= 3.7, so the old pre-3.6 PROTOCOL_TLSv1 fallback was
            # unreachable.
            return ssl.PROTOCOL_TLS

        val_map = {
            '1.0': ssl.PROTOCOL_TLSv1,
            '1.1': ssl.PROTOCOL_TLSv1_1,
            '1.2': ssl.PROTOCOL_TLSv1_2,
        }
        try:
            return val_map[key]
        except KeyError:
            raise ValueError(
                f"unsupported tls_version {v!r}; "
                "expected 'auto', '1.0', '1.1' or '1.2'"
            ) from None

    @validator('client_id', pre=True, always=True)
    @classmethod
    def validate_client_id(cls, v, values):
        """Default ``client_id`` to ``"appdaemon_<name>_client"`` when unset.

        ``.get`` avoids a ``KeyError`` when ``name`` itself failed validation;
        that field's own error is still reported by pydantic.
        """
        if v is None:
            return f"appdaemon_{values.get('name')}_client"
        return v