Files
home-automation/apps/api/main.py

430 lines
13 KiB
Python

"""API main entry point."""
import asyncio
import json
import logging
import os
from pathlib import Path
from typing import Any, AsyncGenerator
import redis.asyncio as aioredis
import yaml
from aiomqtt import Client
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ValidationError
from packages.home_capabilities import LIGHT_VERSION, THERMOSTAT_VERSION, LightState, ThermostatState
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Home Automation API",
    description="API for home automation system",
    version="0.1.0",
)

# Browser origins permitted to call this API: the frontend served on port 8002,
# reached via localhost, the loopback IP, or a LAN address.
_FRONTEND_ORIGINS = [
    "http://localhost:8002",
    "http://172.19.1.11:8002",
    "http://127.0.0.1:8002",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_FRONTEND_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/health")
async def health() -> dict[str, str]:
    """Liveness probe.

    Returns:
        dict: Always ``{"status": "ok"}`` while the process is serving requests.
    """
    body = {"status": "ok"}
    return body
@app.get("/spec")
async def spec() -> dict[str, dict[str, str]]:
    """Report which device capabilities this API supports.

    Returns:
        dict: ``{"capabilities": {<capability name>: <version>}}`` for the
            light and thermostat capabilities.
    """
    supported = {"light": LIGHT_VERSION, "thermostat": THERMOSTAT_VERSION}
    return {"capabilities": supported}
# Pydantic Models
class SetDeviceRequest(BaseModel):
    """Request body for ``POST /devices/{device_id}/set``.

    ``payload`` is validated against the state model matching ``type`` by the
    endpoint, not by this model.
    """

    type: str  # capability name; the endpoint accepts "light" or "thermostat"
    payload: dict[str, Any]  # requested state fields, forwarded as-is to MQTT
class DeviceInfo(BaseModel):
    """One device entry as returned by ``GET /devices``."""

    device_id: str  # normalized identifier from the devices config
    type: str  # capability type as written in the config
    name: str  # display name; the endpoint falls back to device_id
    features: dict[str, Any] = {}  # optional per-device feature flags from config
# Configuration helpers
def load_devices(config_path: "Path | str | None" = None) -> list[dict[str, Any]]:
    """Load device definitions from the YAML configuration file.

    Each entry may identify itself with either ``device_id`` or the legacy
    ``id`` key. The returned dicts always carry ``device_id``, which takes
    precedence when both are present. The legacy ``id`` key is removed.

    Args:
        config_path: YAML file to read. Defaults to ``config/devices.yaml``
            three directories above this module.

    Returns:
        list: Device configuration dicts with a normalized ``device_id`` key.
            The list is empty when the file is missing, empty, or has no
            ``devices`` entries. Entries with no identifier are skipped and
            a warning is logged.
    """
    if config_path is None:
        config_path = Path(__file__).parent.parent.parent / "config" / "devices.yaml"
    config_path = Path(config_path)
    if not config_path.exists():
        return []

    with open(config_path, "r", encoding="utf-8") as f:
        # An empty YAML document parses to None rather than {}.
        config = yaml.safe_load(f) or {}

    # A bare "devices:" key parses to None as well.
    devices = config.get("devices") or []

    normalized: list[dict[str, Any]] = []
    for device in devices:
        # Pop both keys so the legacy 'id' never lingers on the entry.
        legacy_id = device.pop("id", None)
        device_id = device.pop("device_id", None)
        if device_id is None:
            device_id = legacy_id
        if device_id is None:
            # Without an identifier the device can be neither listed nor addressed.
            logger.warning("Skipping device without 'device_id' or 'id': %r", device)
            continue
        device["device_id"] = device_id
        normalized.append(device)
    return normalized
def get_mqtt_settings() -> tuple[str, int]:
    """Resolve the MQTT broker address from the environment.

    ``MQTT_BROKER`` wins when set and non-empty. Otherwise ``MQTT_HOST`` is
    used (kept for compatibility), defaulting to ``172.16.2.16``. The port
    comes from ``MQTT_PORT`` (default ``1883``).

    Returns:
        tuple: (host, port)
    """
    env = os.environ
    broker = env.get("MQTT_BROKER")
    if not broker:
        broker = env.get("MQTT_HOST", "172.16.2.16")
    broker_port = int(env.get("MQTT_PORT", "1883"))
    return broker, broker_port
def get_redis_settings() -> tuple[str, str]:
    """Get Redis settings, preferring environment variables over the config file.

    Resolution order:
    - URL: if ``REDIS_HOST`` is set, build ``redis://host:port/db`` from
      ``REDIS_HOST``, ``REDIS_PORT`` (default 6379) and ``REDIS_DB``
      (default 0). Otherwise use ``redis.url`` from ``config/devices.yaml``.
      Otherwise use ``redis://localhost:6379/0``.
    - Channel: ``REDIS_CHANNEL`` if set and non-empty. Otherwise use
      ``redis.channel`` from the config file. Otherwise use ``ui:updates``.

    Returns:
        tuple: (url, channel)
    """
    default_url = "redis://localhost:6379/0"
    default_channel = "ui:updates"
    # Honoured independently of REDIS_HOST, so the channel can be overridden
    # even when the URL comes from the config file.
    env_channel = os.getenv("REDIS_CHANNEL")

    redis_host = os.getenv("REDIS_HOST")
    if redis_host:
        redis_port = os.getenv("REDIS_PORT", "6379")
        redis_db = os.getenv("REDIS_DB", "0")
        url = f"redis://{redis_host}:{redis_port}/{redis_db}"
        return url, env_channel or default_channel

    redis_config = {}
    config_path = Path(__file__).parent.parent.parent / "config" / "devices.yaml"
    if config_path.exists():
        with open(config_path, "r", encoding="utf-8") as f:
            # An empty file parses to None, not {}.
            config = yaml.safe_load(f) or {}
        # A bare "redis:" key also parses to None.
        redis_config = config.get("redis") or {}

    url = redis_config.get("url", default_url)
    channel = env_channel or redis_config.get("channel", default_channel)
    return url, channel
async def publish_mqtt(topic: str, payload: dict[str, Any]) -> None:
    """Publish a JSON message to the MQTT broker with QoS 1.

    A short-lived connection is opened per call.

    Args:
        topic: MQTT topic to publish to.
        payload: Message payload, serialized with ``json.dumps``.
    """
    import uuid

    host, port = get_mqtt_settings()
    message = json.dumps(payload)
    # A fixed client identifier lets concurrent publishes evict each other:
    # MQTT brokers disconnect an existing session when another client connects
    # with the same ClientId. Use a unique ID per connection instead. 7 + 16
    # characters = 23 bytes, the length every MQTT 3.1.1 broker must accept.
    identifier = f"ha-api-{uuid.uuid4().hex[:16]}"
    async with Client(hostname=host, port=port, identifier=identifier) as client:
        await client.publish(topic, message, qos=1)
@app.get("/devices")
async def get_devices() -> list[DeviceInfo]:
    """List every configured device.

    Returns:
        list: One ``DeviceInfo`` per configured device. ``name`` falls back to
            the device id, and ``features`` falls back to an empty dict.
    """
    listing: list[DeviceInfo] = []
    for entry in load_devices():
        ident = entry["device_id"]
        listing.append(
            DeviceInfo(
                device_id=ident,
                type=entry["type"],
                name=entry.get("name", ident),
                features=entry.get("features", {}),
            )
        )
    return listing
@app.get("/layout")
async def get_layout() -> dict[str, Any]:
    """Get the UI layout configuration.

    The layout models returned by ``load_layout`` are flattened into plain
    dicts. Any error while loading or converting is logged with its traceback,
    and an empty layout is returned so the UI still renders.

    Returns:
        dict: ``{"rooms": [{"name": ..., "devices": [{"device_id", "title",
            "icon", "rank"}, ...]}, ...]}``, or ``{"rooms": []}`` on error.
    """
    from packages.home_capabilities import load_layout

    try:
        layout = load_layout()
        rooms = [
            {
                "name": room.name,
                "devices": [
                    {
                        "device_id": tile.device_id,
                        "title": tile.title,
                        "icon": tile.icon,
                        "rank": tile.rank,
                    }
                    for tile in room.devices
                ],
            }
            for room in layout.rooms
        ]
    except Exception as e:
        # Deliberate best-effort: keep the traceback, since the error is swallowed.
        logger.exception("Error loading layout: %s", e)
        return {"rooms": []}
    return {"rooms": rooms}
def _validate_set_payload(device_type: str, payload: dict[str, Any]) -> None:
    """Raise HTTP 422 unless ``payload`` is a valid SET command for ``device_type``."""
    if device_type == "light":
        state_model = LightState
    elif device_type == "thermostat":
        # Thermostat SET commands may only change mode and target.
        allowed_set_fields = {"mode", "target"}
        invalid_fields = set(payload.keys()) - allowed_set_fields
        if invalid_fields:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail=f"Thermostat SET only allows {allowed_set_fields}, got invalid fields: {invalid_fields}",
            )
        state_model = ThermostatState
    else:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Unsupported device type: {device_type}",
        )

    try:
        state_model(**payload)
    except ValidationError as e:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Invalid payload for {device_type}: {e}",
        ) from e


@app.post("/devices/{device_id}/set", status_code=status.HTTP_202_ACCEPTED)
async def set_device(device_id: str, request: SetDeviceRequest) -> dict[str, str]:
    """Validate a device command and publish it to MQTT.

    The command goes to ``home/<type>/<device_id>/set`` as
    ``{"type": ..., "payload": ...}``.

    Args:
        device_id: Device identifier.
        request: Device state request.

    Returns:
        dict: Confirmation message.

    Raises:
        HTTPException:
            - 404 if the device is not configured.
            - 422 if the type is unsupported, the payload is invalid, or the
              type differs from the device's configured type.
            - 503 if the MQTT publish fails.
    """
    # aiomqtt is already a dependency of this module (Client import).
    from aiomqtt import MqttError

    device = next((d for d in load_devices() if d["device_id"] == device_id), None)
    if device is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Device {device_id} not found"
        )

    _validate_set_payload(request.type, request.payload)

    # Without this check a command for the wrong capability would be published
    # to a topic the device does not listen on, yet still return 202.
    configured_type = device.get("type")
    if configured_type is not None and configured_type != request.type:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Device {device_id} has type {configured_type!r}, cannot apply a {request.type!r} command",
        )

    topic = f"home/{request.type}/{device_id}/set"
    mqtt_payload = {
        "type": request.type,
        "payload": request.payload,
    }
    try:
        await publish_mqtt(topic, mqtt_payload)
    except MqttError as e:
        logger.error("MQTT publish to %s failed: %s", topic, e)
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=f"MQTT broker unavailable, command for {device_id} not sent",
        ) from e

    return {"message": f"Command sent to {device_id}"}
def _format_sse_message(data: str) -> str:
    """Frame ``data`` as an SSE ``message`` event with one ``data:`` field per line."""
    # SSE treats CR, LF and CRLF as line terminators. A raw newline inside a
    # single "data:" field would end the field early and corrupt the event.
    lines = data.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    return "event: message\n" + "".join(f"data: {line}\n" for line in lines) + "\n"


async def _close_redis_resources(pubsub: Any, redis_client: Any, channel: Any = None) -> None:
    """Tear down a pub/sub handle and its Redis client, logging (not raising) failures."""
    if pubsub is not None:
        if channel is not None:
            try:
                await pubsub.unsubscribe(channel)
            except Exception as e:
                logger.error("Error closing pubsub: %s", e)
        # Close in its own try so a failed unsubscribe still releases the handle.
        try:
            await pubsub.aclose()
        except Exception as e:
            logger.error("Error closing pubsub: %s", e)
    if redis_client is not None:
        try:
            await redis_client.aclose()
        except Exception as e:
            logger.error("Error closing redis: %s", e)


async def event_generator(request: Request) -> AsyncGenerator[str, None]:
    """Generate SSE events from Redis Pub/Sub with Safari compatibility.

    Safari-compatible features:
    - An immediate ``retry:`` hint when the connection opens.
    - A comment-only heartbeat (``: ping``) every 15s without other traffic.
    - Graceful disconnect handling.

    If Redis settings cannot be loaded, or connecting or subscribing fails,
    the stream keeps running in heartbeat-only mode. Multi-line payloads are
    sent as multiple ``data:`` fields so SSE framing stays intact.

    Args:
        request: FastAPI request object, used for disconnect detection.

    Yields:
        str: SSE-formatted event strings.
    """
    redis_client = None
    pubsub = None
    redis_channel = None  # set only once subscribe() has succeeded
    heartbeat_interval = 15  # Safari-friendly: shorter interval
    poll_interval = 0.1
    redis_error_backoff = 1.0  # avoids logging a failing Redis 10x per second
    try:
        # Send retry hint immediately for EventSource reconnect behavior
        yield "retry: 2500\n\n"

        try:
            redis_url, channel = get_redis_settings()
            redis_client = await aioredis.from_url(redis_url, decode_responses=True)
            pubsub = redis_client.pubsub()
            await pubsub.subscribe(channel)
            redis_channel = channel
            logger.info("SSE client connected, subscribed to %s", channel)
        except Exception as e:
            logger.warning("Redis unavailable, running in heartbeat-only mode: %s", e)
            # Release whatever was created before the failure instead of leaking it.
            await _close_redis_resources(pubsub, redis_client)
            redis_client = None
            pubsub = None

        loop = asyncio.get_running_loop()
        last_heartbeat = loop.time()
        while True:
            if await request.is_disconnected():
                logger.info("SSE client disconnected")
                break

            if pubsub is not None:
                message = None
                try:
                    message = await asyncio.wait_for(
                        pubsub.get_message(ignore_subscribe_messages=True),
                        timeout=poll_interval,
                    )
                except asyncio.TimeoutError:
                    pass  # No message, fall through to the heartbeat check
                except Exception as e:
                    # Keep streaming heartbeats even if Redis fails, but back off.
                    logger.error("Redis error: %s", e)
                    await asyncio.sleep(redis_error_backoff)

                if message and message["type"] == "message":
                    data = message["data"]
                    logger.debug("Sending SSE message: %s...", data[:100])
                    yield _format_sse_message(data)
                    last_heartbeat = loop.time()
                    continue  # Skip sleep, check for more messages immediately

            # Sleep briefly to avoid a busy loop
            await asyncio.sleep(poll_interval)

            now = loop.time()
            if now - last_heartbeat >= heartbeat_interval:
                # Comment-style ping (Safari-compatible, no event type)
                yield ": ping\n\n"
                last_heartbeat = now
    except asyncio.CancelledError:
        logger.info("SSE connection cancelled by client")
        raise
    except Exception as e:
        logger.exception("SSE error: %s", e)
        raise
    finally:
        await _close_redis_resources(pubsub, redis_client, redis_channel)
        logger.info("SSE connection closed")
@app.get("/realtime")
async def realtime_events(request: Request) -> StreamingResponse:
    """Stream real-time updates to the browser as Server-Sent Events.

    Events come from ``event_generator``: a 2.5s retry hint, Redis pub/sub
    messages, and ``: ping`` heartbeats. The headers disable caching and
    transformation and turn off nginx response buffering, so events reach
    the client as soon as they are yielded.

    Args:
        request: FastAPI request object, handed to the generator for
            disconnect detection.

    Returns:
        StreamingResponse: ``text/event-stream`` response.
    """
    sse_headers = {
        "Cache-Control": "no-cache, no-transform",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",  # tell nginx not to buffer this response
    }
    events = event_generator(request)
    return StreamingResponse(events, media_type="text/event-stream", headers=sse_headers)
def main() -> None:
    """Serve ``apps.api.main:app`` with uvicorn on 0.0.0.0:8001, with auto-reload enabled."""
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8001, "reload": True}
    uvicorn.run("apps.api.main:app", **server_options)


if __name__ == "__main__":
    main()