417 lines
17 KiB
Python
417 lines
17 KiB
Python
"""Abstraction main entry point."""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import redis.asyncio as aioredis
|
|
import yaml
|
|
import socket
|
|
import uuid
|
|
from aiomqtt import Client
|
|
from pydantic import ValidationError
|
|
|
|
from packages.home_capabilities import LightState, ThermostatState, ContactState, TempHumidityState, RelayState
|
|
from apps.abstraction.transformation import (
|
|
transform_abstract_to_vendor,
|
|
transform_vendor_to_abstract
|
|
)
|
|
|
|
# Configure root logging once at import time: DEBUG level, timestamped lines.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.DEBUG)

# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def load_config(config_path: Path) -> dict[str, Any]:
    """Load configuration from a YAML file.

    If the file does not exist, a default configuration is returned. Its MQTT
    broker and port come from the ``MQTT_BROKER`` / ``MQTT_PORT`` environment
    variables (falling back to ``localhost`` / ``1883``), and its device list
    is empty.

    Device entries may identify themselves with either ``id`` or
    ``device_id``. Both are folded into ``device_id``, and ``id`` is removed;
    if both keys are present, ``device_id`` wins. Ids are stored as ``str``.

    Args:
        config_path: Path to the configuration file

    Returns:
        dict: Configuration dictionary. ``config["devices"]`` is always a list.

    Raises:
        ValueError: If the YAML document is not a mapping, ``devices`` is not
            a list, or a device entry is not a mapping.
    """
    if not config_path.exists():
        logger.warning(f"Config file not found: {config_path}, using defaults")
        return {
            "mqtt": {
                "broker": os.getenv("MQTT_BROKER", "localhost"),
                "port": int(os.getenv("MQTT_PORT", "1883")),
                "client_id": "home-automation-abstraction",
                "keepalive": 60,
            },
            "devices": [],
        }

    with open(config_path, "r", encoding="utf-8") as f:
        # safe_load returns None for an empty document; treat that as "no settings".
        config = yaml.safe_load(f) or {}

    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path} must contain a mapping at the top level, "
            f"got {type(config).__name__}"
        )

    # A bare "devices:" key parses as None; normalize so callers can always iterate.
    devices = config.get("devices") or []
    if not isinstance(devices, list):
        raise ValueError(f"'devices' in {config_path} must be a list, got {type(devices).__name__}")
    config["devices"] = devices

    # Normalize device entries: accept both 'id' and 'device_id', use 'device_id' internally.
    for device in devices:
        if not isinstance(device, dict):
            raise ValueError(f"Device entry must be a mapping: {device!r}")
        # Both pops run (the inner one is the default argument), so 'id' is always removed.
        device_id = device.pop("device_id", device.pop("id", None))
        # YAML parses bare numeric ids (e.g. `id: 12`) as int, but ids are matched
        # against MQTT topic segments, which are strings, so store them as str.
        device["device_id"] = None if device_id is None else str(device_id)

    logger.info(f"Loaded configuration from {config_path}")
    return config
|
|
|
|
|
|
def validate_devices(devices: list[dict[str, Any]]) -> None:
    """Validate device configuration.

    Each device must satisfy all of the following:

    - It is a mapping with a non-null ``device_id`` that is unique across the
      list. Ids are compared as strings.
    - It has the fields ``type``, ``cap_version`` and ``technology``.
    - It has a ``topics`` mapping with a non-empty ``state`` entry.

    ``topics.set`` is optional, because read-only devices such as contact
    sensors have none.

    Args:
        devices: List of device configurations

    Raises:
        ValueError: If device configuration is invalid
    """
    required_fields = ["device_id", "type", "cap_version", "technology"]
    seen_ids: set[str] = set()

    for device in devices:
        if not isinstance(device, dict):
            raise ValueError(f"Device entry must be a mapping: {device!r}")

        # Check for device_id
        if device.get("device_id") is None:
            raise ValueError(f"Device entry requires 'id' or 'device_id': {device}")

        device_id = device["device_id"]

        # Check required top-level fields
        for field in required_fields:
            if field not in device:
                raise ValueError(f"Device {device_id} missing '{field}'")

        # Check topics structure
        if "topics" not in device:
            raise ValueError(f"Device {device_id} missing 'topics'")
        topics = device["topics"]
        # A bare "topics:" key in YAML parses as None; reject non-mappings explicitly
        # instead of failing later with a TypeError.
        if not isinstance(topics, dict):
            raise ValueError(
                f"Device {device_id} 'topics' must be a mapping, got {type(topics).__name__}"
            )

        # 'state' topic is required for all devices; an empty/null value counts as missing.
        if not topics.get("state"):
            raise ValueError(f"Device {device_id} missing 'topics.state'")

        # 'set' topic is optional (read-only devices like contact sensors don't have it)

        # Devices are later keyed by id, so a duplicate would silently shadow an
        # earlier entry. Compare as str so that 1 and "1" count as the same id.
        id_key = str(device_id)
        if id_key in seen_ids:
            raise ValueError(f"Duplicate device_id '{device_id}'")
        seen_ids.add(id_key)

    # Log loaded devices (str() so non-string ids don't break the join)
    device_ids = [str(d["device_id"]) for d in devices]
    logger.info(f"Loaded {len(devices)} device(s): {', '.join(device_ids)}")
|
|
|
|
|
|
def _redact_url(url: str) -> str:
    """Return *url* with any password replaced by ``***`` so it is safe to log."""
    from urllib.parse import urlsplit, urlunsplit

    parts = urlsplit(url)
    if parts.password is None:
        return url
    host = parts.netloc.rpartition("@")[2]
    user = parts.username or ""
    return urlunsplit(parts._replace(netloc=f"{user}:***@{host}"))


async def _close_redis_quietly(client: aioredis.Redis) -> None:
    """Close *client* on a best-effort basis. Errors are logged at DEBUG and ignored."""
    # Newer redis-py releases name the coroutine ``aclose``; older ones only have ``close``.
    close = getattr(client, "aclose", None) or client.close
    try:
        await close()
    except Exception as exc:
        logger.debug(f"Ignoring error while closing Redis client: {exc}")


async def get_redis_client(redis_url: str, max_retries: int = 5) -> aioredis.Redis:
    """Connect to Redis with exponential backoff.

    Each attempt creates a client from *redis_url* and sends ``PING``. The delay
    between failed attempts starts at 1s and doubles, capped at 30s. A client
    whose attempt failed is closed before retrying, so its connection pool is
    not leaked. Any password in the URL is masked in log output.

    Args:
        redis_url: Redis connection URL
        max_retries: Maximum number of connection attempts (must be >= 1)

    Returns:
        Redis client instance (created with ``decode_responses=True``)

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        Exception: Whatever the final attempt raised, once all attempts fail.
    """
    if max_retries < 1:
        # Without this check the loop would never run and the function would return None.
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    safe_url = _redact_url(redis_url)
    retry_delay = 1
    for attempt in range(1, max_retries + 1):
        redis_client = None
        try:
            redis_client = await aioredis.from_url(redis_url, decode_responses=True)
            await redis_client.ping()
        except Exception as e:
            if redis_client is not None:
                await _close_redis_quietly(redis_client)
            if attempt == max_retries:
                logger.error(f"Failed to connect to Redis after {max_retries} attempts")
                raise
            logger.warning(f"Redis connection failed (attempt {attempt}/{max_retries}): {e}")
            await asyncio.sleep(retry_delay)
            retry_delay = min(retry_delay * 2, 30)  # Exponential backoff, max 30s
        else:
            logger.info(f"Connected to Redis: {safe_url}")
            return redis_client
|
|
|
|
|
|
async def handle_abstract_set(
    mqtt_client: Client,
    device_id: str,
    device_type: str,
    device_technology: str,
    vendor_topic: str,
    payload: dict[str, Any]
) -> None:
    """Handle an abstract SET message and publish it to the vendor topic.

    How the payload is checked depends on *device_type*:

    - light, relay: validated against the matching capability model.
    - thermostat: restricted to ``mode``/``target`` and then validated.
    - contact sensors: read-only, so their SET commands are dropped.
    - other types: forwarded without validation.

    Invalid payloads are logged and dropped.

    Args:
        mqtt_client: MQTT client instance
        device_id: Device identifier
        device_type: Device type (e.g., 'light', 'thermostat')
        device_technology: Technology identifier (e.g., 'zigbee2mqtt')
        vendor_topic: Vendor-specific SET topic
        payload: Decoded JSON message body. This is either the abstract payload
            itself or a wrapper carrying it under the ``"payload"`` key.
            Non-object JSON values are logged and ignored.
    """
    # The body comes from json.loads on an arbitrary MQTT message, so it can be any
    # JSON value. Reject non-objects here instead of raising AttributeError below.
    if not isinstance(payload, dict):
        logger.warning(f"SET {device_id} ignored: expected a JSON object, got {type(payload).__name__}")
        return

    # Extract actual payload (remove type wrapper if present)
    abstract_payload = payload.get("payload", payload)
    if not isinstance(abstract_payload, dict):
        logger.warning(
            f"SET {device_id} ignored: 'payload' must be a JSON object, "
            f"got {type(abstract_payload).__name__}"
        )
        return

    # Validate payload based on device type
    try:
        if device_type == "light":
            # Validate light SET payload (power and/or brightness)
            LightState.model_validate(abstract_payload)
        elif device_type == "relay":
            # Validate relay SET payload (power only)
            RelayState.model_validate(abstract_payload)
        elif device_type == "thermostat":
            # For thermostat SET: only allow mode and target fields
            allowed_set_fields = {"mode", "target"}
            invalid_fields = set(abstract_payload.keys()) - allowed_set_fields
            if invalid_fields:
                logger.warning(
                    f"Thermostat SET {device_id} contains invalid fields {invalid_fields}, "
                    f"only {allowed_set_fields} allowed"
                )
                return
            # Validate against ThermostatState (current/battery/window_open are optional)
            ThermostatState.model_validate(abstract_payload)
        elif device_type in {"contact", "contact_sensor"}:
            # Contact sensors are read-only - SET commands should not occur
            logger.warning(f"Contact sensor {device_id} received SET command - ignoring (read-only device)")
            return
    except ValidationError as e:
        logger.error(f"Validation failed for {device_type} SET {device_id}: {e}")
        return

    # Transform abstract payload to vendor-specific format
    vendor_payload = transform_abstract_to_vendor(device_type, device_technology, abstract_payload)

    # For MAX! thermostats and Shelly relays, vendor_payload is a plain string
    # For other devices, it's a dict that needs JSON encoding
    if (device_technology == "max" and device_type == "thermostat") or \
       (device_technology == "shelly" and device_type == "relay"):
        vendor_message = vendor_payload  # Already a string
    else:
        vendor_message = json.dumps(vendor_payload)

    logger.info(f"→ vendor SET {device_id}: {vendor_topic} ← {vendor_message}")
    await mqtt_client.publish(vendor_topic, vendor_message, qos=1)
|
|
|
|
|
|
async def handle_vendor_state(
    mqtt_client: Client,
    redis_client: aioredis.Redis,
    device_id: str,
    device_type: str,
    device_technology: str,
    payload: str,
    redis_channel: str = "ui:updates",
) -> None:
    """Translate a vendor STATE message and fan it out to MQTT and Redis.

    The steps are:

    1. Convert the raw vendor *payload* to the abstract format.
    2. Validate it against the capability model for *device_type*. Types
       without a model are forwarded unvalidated; on validation failure the
       error is logged and nothing is published.
    3. Publish it, retained, to ``home/<type>/<device_id>/state``.
    4. Push a timestamped ``{"type": "state", ...}`` UI update to Redis.

    Args:
        mqtt_client: MQTT client instance
        redis_client: Redis client instance
        device_id: Device identifier
        device_type: Device type (e.g., 'light', 'thermostat')
        device_technology: Technology identifier (e.g., 'zigbee2mqtt')
        payload: Raw vendor message payload (string)
        redis_channel: Redis channel for UI updates
    """
    abstract_payload = transform_vendor_to_abstract(device_type, device_technology, payload)

    # Capability model per device type; sensor aliases share a model.
    state_models = {
        "light": LightState,
        "relay": RelayState,
        "thermostat": ThermostatState,
        "contact": ContactState,
        "contact_sensor": ContactState,
        "temp_humidity": TempHumidityState,
        "temp_humidity_sensor": TempHumidityState,
    }
    model = state_models.get(device_type)
    if model is not None:
        try:
            model.model_validate(abstract_payload)
        except ValidationError as err:
            logger.error(f"Validation failed for {device_type} STATE {device_id}: {err}")
            return

    # "*_sensor" aliases publish under their short type name.
    topic_aliases = {"contact_sensor": "contact", "temp_humidity_sensor": "temp_humidity"}
    topic_type = topic_aliases.get(device_type, device_type)

    state_topic = f"home/{topic_type}/{device_id}/state"
    state_json = json.dumps(abstract_payload)
    logger.info(f"← abstract STATE {device_id}: {state_topic} → {state_json}")
    await mqtt_client.publish(state_topic, state_json, qos=1, retain=True)

    update_json = json.dumps({
        "type": "state",
        "device_id": device_id,
        "payload": abstract_payload,
        "ts": datetime.now(timezone.utc).isoformat(),
    })
    logger.info(f"← Redis PUBLISH {redis_channel} → {update_json}")
    await redis_client.publish(redis_channel, update_json)
|
|
|
|
|
|
async def _dispatch_message(
    client: Client,
    redis_client: aioredis.Redis,
    topic: str,
    payload_str: str,
    devices: dict[Any, dict[str, Any]],
    devices_by_state_topic: dict[str, dict[str, Any]],
    redis_channel: str,
) -> None:
    """Route one decoded MQTT message to the abstract SET or vendor STATE handler.

    Raises:
        json.JSONDecodeError: If an abstract SET message body is not valid JSON.
    """
    # Abstract SET messages arrive on home/<type>/<id>/set
    if topic.startswith("home/") and topic.endswith("/set"):
        payload = json.loads(payload_str)
        parts = topic.split("/")
        if len(parts) != 4:
            return
        device_type, device_id = parts[1], parts[2]
        device = devices.get(device_id)
        if device is None:
            return
        vendor_topic = device["topics"].get("set")
        if not vendor_topic:
            logger.debug(f"Ignoring SET for {device_id}: no vendor 'set' topic configured")
            return
        await handle_abstract_set(
            client, device_id, device_type, device.get("technology", "unknown"), vendor_topic, payload
        )
        return

    # Anything else is a vendor STATE message for the device owning this topic.
    device = devices_by_state_topic.get(topic)
    if device is not None:
        await handle_vendor_state(
            client, redis_client, device["device_id"], device["type"],
            device.get("technology", "unknown"), payload_str, redis_channel,
        )


async def mqtt_worker(config: dict[str, Any], redis_client: aioredis.Redis) -> None:
    """MQTT worker that bridges abstract and vendor topics, reconnecting forever.

    For every configured device it subscribes to:

    - the vendor ``topics.state`` topic;
    - ``home/<type>/<id>/set``, but only when the device has a vendor
      ``topics.set``.

    Messages are routed to ``handle_abstract_set`` / ``handle_vendor_state``.

    A failure while handling one message is logged and the loop continues.
    Only connection-level errors (``MqttError``) or errors outside message
    handling cause a reconnect, with exponential backoff from 1s up to 60s.

    Environment overrides:

    - ``MQTT_BROKER`` / ``MQTT_PORT`` take precedence over the config.
    - ``MQTT_CLIENT_ID_SUFFIX`` replaces the random client-id suffix.

    Args:
        config: Configuration dictionary containing MQTT settings
        redis_client: Redis client for UI updates
    """
    # aiomqtt is already a dependency of this module (Client); imported locally.
    from aiomqtt import MqttError

    mqtt_config = config.get("mqtt") or {}
    broker = os.getenv("MQTT_BROKER") or mqtt_config.get("broker", "localhost")
    port = int(os.getenv("MQTT_PORT", mqtt_config.get("port", 1883)))
    client_id = mqtt_config.get("client_id", "home-automation-abstraction")
    # Append a short suffix (ENV override possible) so multiple processes don't collide
    client_suffix = os.environ.get("MQTT_CLIENT_ID_SUFFIX") or uuid.uuid4().hex[:6]
    unique_client_id = f"{client_id}-{client_suffix}"
    keepalive = mqtt_config.get("keepalive", 60)

    redis_config = config.get("redis") or {}
    redis_channel = redis_config.get("channel", "ui:updates")

    devices = {d["device_id"]: d for d in config.get("devices") or []}
    # Map each vendor state topic to its device once; the first device wins if several share a topic.
    devices_by_state_topic: dict[str, dict[str, Any]] = {}
    for device in devices.values():
        devices_by_state_topic.setdefault(device["topics"]["state"], device)

    retry_delay = 1
    max_retry_delay = 60

    while True:
        try:
            logger.info(f"Connecting to MQTT broker: {broker}:{port}")

            async with Client(
                hostname=broker,
                port=port,
                identifier=unique_client_id,
                keepalive=keepalive,
                timeout=10.0,  # Add explicit timeout for operations
            ) as client:
                logger.info(f"Connected to MQTT broker as {unique_client_id}")

                # NOTE(review): `async with client.messages` relies on the installed aiomqtt
                # exposing `messages` as an async context manager; confirm against the pinned version.
                async with client.messages as messages:
                    # Subscribe to topics for all devices
                    for device in devices.values():
                        device_id = device["device_id"]
                        device_type = device["type"]

                        # Subscribe to abstract SET topic only if device has a SET topic (not read-only)
                        if device["topics"].get("set"):
                            abstract_set_topic = f"home/{device_type}/{device_id}/set"
                            await client.subscribe(abstract_set_topic)
                            logger.info(f"Subscribed to abstract SET: {abstract_set_topic}")
                        else:
                            logger.info(f"Skipping SET subscription for read-only device: {device_id}")

                        # Subscribe to vendor STATE topics (all devices have state)
                        vendor_state_topic = device["topics"]["state"]
                        await client.subscribe(vendor_state_topic)
                        logger.info(f"Subscribed to vendor STATE: {vendor_state_topic}")

                    # Reset retry delay on successful connection
                    retry_delay = 1

                    async for message in messages:
                        topic = str(message.topic)
                        try:
                            payload_str = message.payload.decode()
                            logger.debug(f"MQTT message received on {topic}: {payload_str}")
                            await _dispatch_message(
                                client, redis_client, topic, payload_str,
                                devices, devices_by_state_topic, redis_channel,
                            )
                        except json.JSONDecodeError:
                            logger.error(f"Failed to decode JSON payload on topic {topic}: {payload_str}")
                        except MqttError:
                            # Connection-level failure: let the outer loop reconnect.
                            raise
                        except Exception:
                            # One bad message (undecodable bytes, bad payload shape, transform or
                            # Redis error) must not tear down the MQTT connection.
                            logger.exception(f"Failed to handle MQTT message on topic {topic}")

        except asyncio.CancelledError:
            logger.info("MQTT worker cancelled")
            raise
        except Exception as e:
            logger.error(f"MQTT error: {e}")
            logger.debug("Traceback:", exc_info=True)
            logger.info(f"Reconnecting in {retry_delay}s...")
            await asyncio.sleep(retry_delay)
            retry_delay = min(retry_delay * 2, max_retry_delay)
|
|
|
|
|
|
async def async_main() -> None:
    """Async main function for the abstraction worker.

    The function does the following, in order:

    1. Loads ``config/devices.yaml``, located at
       ``Path(__file__).parent.parent.parent``.
    2. Validates the device list.
    3. Connects to Redis.
    4. Runs the MQTT worker.

    The Redis URL is taken from ``redis.url`` in the config, then the
    ``REDIS_URL`` environment variable, then ``redis://localhost:6379/0``.
    The Redis client is closed when the worker exits, including on
    cancellation.
    """
    # Determine config path
    config_path = Path(__file__).parent.parent.parent / "config" / "devices.yaml"

    # Load configuration
    config = load_config(config_path)

    # Validate devices. A null "devices" entry is normalized to [] in config as well,
    # because mqtt_worker builds its device map from config["devices"] and cannot
    # iterate None.
    devices = config.get("devices") or []
    config["devices"] = devices
    validate_devices(devices)
    logger.info(f"Loaded {len(devices)} device(s) from configuration")

    # Get Redis URL from config or environment variable or use default
    redis_config = config.get("redis") or {}
    redis_url = redis_config.get("url") or os.environ.get("REDIS_URL", "redis://localhost:6379/0")

    # Connect to Redis with retry
    redis_client = await get_redis_client(redis_url)

    logger.info("Abstraction worker started")

    try:
        # Start MQTT worker
        await mqtt_worker(config, redis_client)
    finally:
        # Release the Redis connection pool on any exit path.
        # Newer redis-py releases provide aclose(); older ones only close().
        close = getattr(redis_client, "aclose", None) or redis_client.close
        try:
            await close()
        except Exception as exc:
            logger.warning(f"Error while closing Redis client: {exc}")
|
|
|
|
|
|
def main() -> None:
    """Run the abstraction worker until it finishes or the user interrupts it.

    A KeyboardInterrupt is logged and treated as a clean shutdown. Any other
    exception is logged and then re-raised to the caller.
    """
    try:
        asyncio.run(async_main())
    except KeyboardInterrupt:
        logger.info("Abstraction worker stopped by user")
    except Exception as exc:
        logger.error(f"Fatal error: {exc}")
        raise
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script (not imported): start the worker.
    main()
|