Files
home-automation/apps/rules/rule_interface.py

760 lines
26 KiB
Python

"""
Rule Interface and Context Objects
Provides the core abstractions for implementing automation rules:
- RuleDescriptor: Configuration data for a rule instance
- RedisState: State persistence interface
- RuleContext: Runtime context provided to rules
- Rule: Abstract base class for all rule implementations
"""
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Awaitable, Optional
from pydantic import BaseModel, Field
class RuleDescriptor(BaseModel):
    """
    Validated configuration for a single rule instance.

    Every rule entry in rules.yaml is parsed into one of these objects; the
    engine loads them and hands each one to the matching rule implementation.

    ``objects`` is intentionally an untyped mapping so that each rule type is
    free to define its own layout for the devices it watches or drives.
    """

    # Required: unique identifier of this rule instance.
    id: str = Field(
        default=...,
        description="Unique identifier for this rule instance",
    )
    # Optional display name; None when not configured.
    name: Optional[str] = Field(
        default=None,
        description="Optional human-readable name",
    )
    # Required: implementation selector in 'name@version' form.
    type: str = Field(
        default=...,
        description="Rule type with version (e.g., 'window_setback@1.0')",
    )
    # Rules are active unless explicitly disabled.
    enabled: bool = Field(
        default=True,
        description="Whether this rule is enabled",
    )
    # Per-rule-type device layout; a fresh empty dict per instance.
    objects: dict[str, Any] = Field(
        default_factory=dict,
        description="Objects this rule monitors or controls (structure varies by rule type)",
    )
    # Free-form tuning parameters for the rule implementation.
    params: dict[str, Any] = Field(
        default_factory=dict,
        description="Rule-specific parameters",
    )
class RedisState:
    """
    Async Redis-backed state persistence for rules with automatic reconnection.

    Provides a simple key-value and hash storage interface for rules to persist
    state across restarts. All operations are asynchronous and are retried with
    exponential backoff to ride out temporary Redis outages.

    Key Convention:
    - Callers should use keys like: f"rules:{rule_id}:contact:{device_id}"
    - This class does NOT enforce key prefixes - caller controls the full key
    """

    def __init__(self, url: str, max_retries: int = 3, retry_delay: float = 0.5):
        """
        Initialize RedisState with connection URL.

        Args:
            url: Redis connection URL (e.g., 'redis://172.23.1.116:6379/8')
            max_retries: Maximum number of attempts per operation (default: 3).
                Values below 1 are treated as 1, so every operation is
                attempted at least once.
            retry_delay: Initial delay between retries in seconds; doubled after
                each failed attempt (default: 0.5)

        Note:
            Connection is lazy - the client is only created on first operation.
            After a failed attempt the client is discarded and recreated.
        """
        self._url = url
        self._max_retries = max_retries
        self._retry_delay = retry_delay
        self._redis: Optional[Any] = None  # redis.asyncio.Redis instance, created lazily

    async def _get_client(self):
        """
        Get or create the Redis client.

        Lazy initialization ensures we don't connect until first use.
        Uses decode_responses=True so values come back as str, not bytes.
        """
        if self._redis is None:
            import redis.asyncio as aioredis
            self._redis = await aioredis.from_url(
                self._url,
                decode_responses=True,  # Automatic UTF-8 decode
                encoding='utf-8',
                max_connections=10,  # Connection pool size
                socket_connect_timeout=5,
                socket_keepalive=True,
                health_check_interval=30  # Auto-check connection health
            )
        return self._redis

    @staticmethod
    async def _close_client(client) -> None:
        """
        Close a client, preferring ``aclose()`` when the client provides it.

        redis-py 5.0.1+ deprecates ``close()`` in favour of ``aclose()``; older
        versions only have ``close()``.
        """
        closer = getattr(client, "aclose", None)
        if closer is None:
            closer = client.close
        await closer()

    async def _reset_client(self) -> None:
        """
        Drop the cached client so the next operation creates a new one.

        Closing is best-effort: the connection is presumed broken already, so
        ordinary errors from closing are ignored. BaseExceptions such as
        asyncio.CancelledError are deliberately NOT swallowed.
        """
        client, self._redis = self._redis, None
        if client is not None:
            try:
                await self._close_client(client)
            except Exception:
                pass

    async def _execute_with_retry(self, operation, *args, **kwargs):
        """
        Execute a Redis operation with exponential backoff retry.

        On failure, sleeps ``retry_delay * 2**attempt`` seconds, discards the
        client (forcing a reconnect), and tries again. Every Exception counts
        as a retryable failure.

        Args:
            operation: Async callable taking (client, *args, **kwargs)
            *args, **kwargs: Arguments to pass to operation

        Returns:
            Result of the operation

        Raises:
            Exception: The exception raised by the final attempt, once all
                attempts are exhausted.
        """
        import asyncio

        # Guard against max_retries <= 0, which previously ran zero attempts
        # and ended in `raise None` (a TypeError).
        attempts = max(1, self._max_retries)
        for attempt in range(attempts):
            try:
                client = await self._get_client()
                return await operation(client, *args, **kwargs)
            except Exception:
                if attempt == attempts - 1:
                    # All attempts exhausted: surface the last failure.
                    raise
                # Exponential backoff: 0.5s, 1s, 2s, ... (for retry_delay=0.5)
                await asyncio.sleep(self._retry_delay * (2 ** attempt))
                await self._reset_client()

    # JSON helpers for complex data structures
    def _dumps(self, obj: Any) -> str:
        """Serialize a Python object to a JSON string (non-ASCII kept as-is)."""
        import json
        return json.dumps(obj, ensure_ascii=False)

    def _loads(self, s: str) -> Any:
        """Deserialize a JSON string to a Python object."""
        import json
        return json.loads(s)

    async def get(self, key: str) -> Optional[str]:
        """
        Get a string value by key.

        Args:
            key: Redis key (e.g., "rules:my_rule:contact:sensor_1")

        Returns:
            String value or None if key doesn't exist

        Example:
            >>> state = RedisState("redis://localhost:6379/0")
            >>> await state.set("rules:r1:temp", "22.5")
            >>> temp = await state.get("rules:r1:temp")
            >>> print(temp)  # "22.5"
        """
        async def _get(client, k):
            return await client.get(k)
        return await self._execute_with_retry(_get, key)

    async def set(self, key: str, value: str, ttl_secs: Optional[int] = None) -> None:
        """
        Set a string value with optional TTL.

        Args:
            key: Redis key
            value: String value to store
            ttl_secs: Optional time-to-live in seconds. If None, key persists indefinitely.

        Example:
            >>> state = RedisState("redis://localhost:6379/0")
            >>> # Store with 1 hour TTL
            >>> await state.set("rules:r1:previous_temp", "20.0", ttl_secs=3600)
        """
        async def _set(client, k, v, ttl):
            if ttl is not None:
                await client.setex(k, ttl, v)
            else:
                await client.set(k, v)
        await self._execute_with_retry(_set, key, value, ttl_secs)

    async def hget(self, key: str, field: str) -> Optional[str]:
        """
        Get a hash field value.

        Args:
            key: Redis hash key
            field: Field name within the hash

        Returns:
            String value or None if field doesn't exist

        Example:
            >>> state = RedisState("redis://localhost:6379/0")
            >>> await state.hset("rules:r1:device_states", "sensor_1", "open")
            >>> value = await state.hget("rules:r1:device_states", "sensor_1")
            >>> print(value)  # "open"
        """
        async def _hget(client, k, f):
            return await client.hget(k, f)
        return await self._execute_with_retry(_hget, key, field)

    async def hset(self, key: str, field: str, value: str) -> None:
        """
        Set a hash field value.

        Args:
            key: Redis hash key
            field: Field name within the hash
            value: String value to store

        Example:
            >>> state = RedisState("redis://localhost:6379/0")
            >>> await state.hset("rules:r1:sensors", "bedroom", "open")
            >>> await state.hset("rules:r1:sensors", "kitchen", "closed")
        """
        async def _hset(client, k, f, v):
            await client.hset(k, f, v)
        await self._execute_with_retry(_hset, key, field, value)

    async def expire(self, key: str, ttl_secs: int) -> None:
        """
        Set or update TTL on an existing key.

        Args:
            key: Redis key
            ttl_secs: Time-to-live in seconds

        Example:
            >>> state = RedisState("redis://localhost:6379/0")
            >>> await state.set("rules:r1:temp", "22.5")
            >>> await state.expire("rules:r1:temp", 3600)  # Expire in 1 hour
        """
        async def _expire(client, k, ttl):
            await client.expire(k, ttl)
        await self._execute_with_retry(_expire, key, ttl_secs)

    async def delete(self, key: str) -> None:
        """
        Delete a key from Redis.

        Args:
            key: Redis key to delete

        Example:
            >>> state = RedisState("redis://localhost:6379/0")
            >>> await state.set("rules:r1:temp", "22.5")
            >>> await state.delete("rules:r1:temp")
        """
        async def _delete(client, k):
            await client.delete(k)
        await self._execute_with_retry(_delete, key)

    async def close(self) -> None:
        """
        Close the Redis connection and release resources.

        Should be called when shutting down the application. The cached client
        is dropped even if closing it raises, so a later operation starts from a
        fresh client instead of reusing a half-closed one. Errors from closing
        still propagate to the caller.
        """
        client, self._redis = self._redis, None
        if client is not None:
            await self._close_client(client)
class MQTTClient:
    """
    Async MQTT client for rule engine with event normalization and publishing.

    Subscribes to device state topics, normalizes events to a consistent format,
    and provides high-level publishing methods for device commands.

    Event Normalization:
        All incoming MQTT messages are parsed into a normalized event structure:
        {
            "topic": "home/contact/sensor_1/state",
            "type": "state",
            "cap": "contact",  # Capability type (contact, thermostat, light, etc.)
            "device_id": "sensor_1",
            "payload": {"contact": "open"},
            "ts": "2025-11-11T10:30:45.123456"
        }
    """

    def __init__(
        self,
        broker: str,
        port: int = 1883,
        client_id: str = "rule_engine",
        reconnect_interval: int = 5,
        max_reconnect_delay: int = 300
    ):
        """
        Initialize MQTT client.

        Args:
            broker: MQTT broker hostname or IP
            port: MQTT broker port (default: 1883)
            client_id: Unique client ID for this connection
            reconnect_interval: Initial reconnect delay in seconds (default: 5)
            max_reconnect_delay: Maximum reconnect delay in seconds (default: 300)
        """
        self._broker = broker
        self._port = port
        self._client_id = client_id
        self._reconnect_interval = reconnect_interval
        self._max_reconnect_delay = max_reconnect_delay
        # Live aiomqtt client while a session is open inside connect(); None otherwise.
        self._client = None
        self._logger = None  # Set externally via set_logger()

    def set_logger(self, logger):
        """Set logger instance for connection status messages."""
        self._logger = logger

    def _log(self, level: str, msg: str):
        """Log via the configured logger's `level` method, or print if none is set."""
        if self._logger:
            getattr(self._logger, level)(msg)
        else:
            print(f"[{level.upper()}] {msg}")

    async def connect(self, topics: Optional[list[str]] = None):
        """
        Connect to the MQTT broker and yield normalized events, reconnecting forever.

        This is an async generator: consume it with ``async for event in
        client.connect(topics)``. Each yielded item is the dict produced by
        ``_normalize_event``.

        On ``aiomqtt.MqttError`` the error is logged and a new connection is
        attempted after ``reconnect_delay`` seconds. The delay doubles after each
        failure (capped at ``max_reconnect_delay``) and resets to
        ``reconnect_interval`` once topics are subscribed.

        While a session is open, ``publish_*`` methods use it. When the session
        ends (error, or the generator being closed) the client reference is
        cleared, so publishing raises RuntimeError instead of using a dead session.

        Args:
            topics: List of MQTT topics to subscribe to. If None, subscribes to nothing.
        """
        import asyncio
        import aiomqtt

        if topics is None:
            topics = []
        reconnect_delay = self._reconnect_interval
        while True:
            try:
                self._log("info", f"Connecting to MQTT broker {self._broker}:{self._port} (client_id={self._client_id})")
                async with aiomqtt.Client(
                    hostname=self._broker,
                    port=self._port,
                    identifier=self._client_id,
                ) as client:
                    self._client = client
                    try:
                        self._log("info", f"Connected to MQTT broker {self._broker}:{self._port}")
                        # Subscribe to provided topics
                        if topics:
                            for topic in topics:
                                await client.subscribe(topic)
                            self._log("info", f"Subscribed to {len(topics)} topic(s): {', '.join(topics[:5])}{'...' if len(topics) > 5 else ''}")
                        # Reset reconnect delay on successful connection
                        reconnect_delay = self._reconnect_interval
                        async for message in client.messages:
                            yield self._normalize_event(message)
                    finally:
                        # Session is over (error or generator closed): never hand
                        # a disconnected client to publishers.
                        self._client = None
            except aiomqtt.MqttError as e:
                self._log("error", f"MQTT connection error: {e}")
                self._log("info", f"Reconnecting in {reconnect_delay} seconds...")
                await asyncio.sleep(reconnect_delay)
                # Exponential backoff
                reconnect_delay = min(reconnect_delay * 2, self._max_reconnect_delay)

    def _normalize_event(self, message) -> dict[str, Any]:
        """
        Normalize an MQTT message to the standard event format.

        Parses the topic to extract capability type and device_id, adds a local
        (naive) timestamp, and decodes the payload as JSON. Payloads that are not
        valid UTF-8 JSON are wrapped as {"raw": <decoded text>}.

        Args:
            message: aiomqtt.Message instance (topic convertible with str(),
                payload as bytes)

        Returns:
            Normalized event dictionary

        Example:
            Topic: home/contact/sensor_bedroom/state
            Payload: {"contact": "open"}
            Returns:
                {
                    "topic": "home/contact/sensor_bedroom/state",
                    "type": "state",
                    "cap": "contact",
                    "device_id": "sensor_bedroom",
                    "payload": {"contact": "open"},
                    "ts": "2025-11-11T10:30:45.123456"
                }
        """
        from datetime import datetime
        import json

        topic = str(message.topic)
        topic_parts = topic.split('/')
        # Parse topic: home/{capability}/{device_id}/state
        if len(topic_parts) >= 4 and topic_parts[0] == 'home' and topic_parts[3] == 'state':
            cap = topic_parts[1]  # contact, thermostat, light, etc.
            device_id = topic_parts[2]
        else:
            # Fallback for unexpected topic format
            cap = "unknown"
            device_id = topic_parts[-2] if len(topic_parts) >= 2 else "unknown"
        # Parse payload
        try:
            payload = json.loads(message.payload.decode('utf-8'))
        except (json.JSONDecodeError, UnicodeDecodeError):
            payload = {"raw": message.payload.decode('utf-8', errors='replace')}
        return {
            "topic": topic,
            "type": "state",
            "cap": cap,
            "device_id": device_id,
            "payload": payload,
            "ts": datetime.now().isoformat()
        }

    async def publish_set_thermostat(self, device_id: str, target: float) -> None:
        """
        Publish thermostat target temperature command.

        Publishes to: home/thermostat/{device_id}/set
        QoS: 1 (at least once delivery)

        Args:
            device_id: Thermostat device identifier
            target: Target temperature in degrees Celsius

        Raises:
            RuntimeError: If no MQTT session is currently open.

        Example:
            >>> mqtt = MQTTClient("172.16.2.16", 1883)
            >>> await mqtt.publish_set_thermostat("thermostat_wohnzimmer", 22.5)
            Published to: home/thermostat/thermostat_wohnzimmer/set
            Payload: {"type": "thermostat", "payload": {"target": 22.5}}
        """
        import json

        if self._client is None:
            raise RuntimeError("MQTT client not connected. Call connect() first.")
        topic = f"home/thermostat/{device_id}/set"
        payload = {
            "type": "thermostat",
            "payload": {
                "target": target
            }
        }
        payload_str = json.dumps(payload)
        await self._client.publish(
            topic,
            payload=payload_str.encode('utf-8'),
            qos=1  # At least once delivery
        )
        self._log("debug", f"Published SET to {topic}: {payload_str}")
# Legacy alias for backward compatibility
class MQTTPublisher:
    """
    Deprecated publishing facade - prefer MQTTClient in new code.

    Kept only so that existing documentation and callers keep working; every
    call is forwarded unchanged to the wrapped client.
    """

    def __init__(self, mqtt_client):
        """
        Wrap an existing client.

        Args:
            mqtt_client: MQTTClient instance that performs the actual publishing
        """
        self._mqtt = mqtt_client

    async def publish_set_thermostat(self, device_id: str, target: float) -> None:
        """
        Forward a thermostat target temperature command to the wrapped client.

        Args:
            device_id: Thermostat device identifier
            target: Target temperature in degrees Celsius
        """
        forward = self._mqtt.publish_set_thermostat
        await forward(device_id, target)
class RuleContext:
    """
    Bundle of runtime dependencies handed to a rule while it processes events.

    Exposes:
    - ``logger``: diagnostics
    - ``mqtt``: publisher used to send device commands
    - ``redis``: state persistence
    - ``now()``: the current time, via an injectable clock
    """

    def __init__(
        self,
        logger,
        mqtt_publisher: MQTTPublisher,
        redis_state: RedisState,
        now_fn=None
    ):
        """
        Build a context from its collaborators.

        Args:
            logger: Logger instance (e.g., logging.Logger)
            mqtt_publisher: Object exposing ``publish_set_thermostat``
                (an MQTTPublisher; MQTTClient offers the same coroutine)
            redis_state: RedisState instance for persistence
            now_fn: Optional zero-argument callable returning a datetime; any
                falsy value falls back to datetime.now
        """
        self.logger = logger
        self.mqtt = mqtt_publisher
        self.redis = redis_state
        self._now_fn = now_fn if now_fn else datetime.now

    def now(self) -> datetime:
        """
        Return the current time from the configured clock.

        The result is timezone-aware only if the supplied ``now_fn`` returns an
        aware datetime; the datetime.now default returns a naive local time.
        """
        clock = self._now_fn
        return clock()
class Rule(ABC):
    """
    Abstract base class for all automation rule implementations.

    Rules implement event-driven automation logic. The engine calls on_event()
    for each relevant device state change, passing the event data, rule
    configuration, and runtime context.

    Implementations must be idempotent - processing the same event multiple
    times should produce the same result.

    Example implementation (note that ``desc.objects`` is a plain dict, so its
    entries are read with ``.get()`` rather than attribute access):

        class WindowSetbackRule(Rule):
            def get_subscriptions(self, desc: RuleDescriptor) -> list[str]:
                # Subscribe to contact sensor state topics
                return [
                    f"home/contact/{contact_id}/state"
                    for contact_id in desc.objects.get('contacts') or []
                ]

            async def on_event(self, evt: dict, desc: RuleDescriptor, ctx: RuleContext) -> None:
                if evt['cap'] != 'contact':
                    return
                if evt['payload'].get('contact') == 'open':
                    # Window opened - set thermostats to eco
                    eco_temp = desc.params.get('eco_target', 16.0)
                    for thermo_id in desc.objects.get('thermostats') or []:
                        await ctx.mqtt.publish_set_thermostat(thermo_id, eco_temp)
    """

    @abstractmethod
    def get_subscriptions(self, desc: RuleDescriptor) -> list[str]:
        """
        Return list of MQTT topics this rule needs to subscribe to.

        Called once during rule engine setup. The rule examines its
        configuration (desc.objects) and returns the specific state topics it
        needs to monitor.

        Args:
            desc: Rule configuration from rules.yaml

        Returns:
            List of MQTT topic patterns/strings to subscribe to

        Example:
            For a window setback rule monitoring 2 contacts:
            ['home/contact/sensor_bedroom/state', 'home/contact/sensor_kitchen/state']
        """
        pass

    @abstractmethod
    async def on_event(
        self,
        evt: dict[str, Any],
        desc: RuleDescriptor,
        ctx: RuleContext
    ) -> None:
        """
        Process a device state change event.

        This method is called by the rule engine whenever a device state
        changes that is relevant to this rule. The implementation should
        examine the event and take appropriate actions (e.g., publish MQTT
        commands, update state).

        MUST be idempotent: Processing the same event multiple times should be safe.

        Args:
            evt: Event dictionary with the following structure:
                {
                    "topic": "home/contact/device_id/state",  # MQTT topic
                    "type": "state",                          # Message type
                    "cap": "contact",                         # Capability type
                    "device_id": "kontakt_wohnzimmer",        # Device identifier
                    "payload": {"contact": "open"},           # Capability-specific payload
                    "ts": "2025-11-11T10:30:45.123456"        # ISO timestamp
                }
            desc: Rule configuration from rules.yaml
            ctx: Runtime context with logger, MQTT, Redis, and timestamp utilities

        Returns:
            None

        Raises:
            Exception: Implementation may raise exceptions for errors.
                The engine will log them but continue processing.
        """
        pass
# ============================================================================
# Dynamic Rule Loading
# ============================================================================
import importlib
import re
from typing import Type
# Cache of resolved rule classes (per process), keyed by the full type string
# ('name@version'). Only classes are cached; each call returns a new instance.
_RULE_CLASS_CACHE: dict[str, Type[Rule]] = {}


def load_rule(desc: RuleDescriptor) -> Rule:
    """
    Dynamically load and instantiate a rule based on its type descriptor.

    Convention:
        - Rule type format: 'name@version' (e.g., 'window_setback@1.0')
        - Module path: apps.rules.impl.{name}
        - Class name: PascalCase version of name + 'Rule'
          Example: 'window_setback' -> 'WindowSetbackRule'

    The version part must be non-empty but does not influence which module or
    class is loaded; it only distinguishes cache entries.

    Args:
        desc: Rule descriptor from rules.yaml

    Returns:
        Newly instantiated Rule object

    Raises:
        ValueError: If the type format or rule name is invalid
        ImportError: If the rule module cannot be found or fails to import
        AttributeError: If the rule class cannot be found in the module
        TypeError: If the resolved attribute is not a subclass of Rule

    Examples:
        >>> desc = RuleDescriptor(
        ...     id="test_rule",
        ...     type="window_setback@1.0",
        ...     objects={},
        ...     params={}
        ... )
        >>> rule = load_rule(desc)
        >>> isinstance(rule, Rule)
        True
    """
    rule_type = desc.type

    # Fast path: class already resolved for this exact type string.
    cached_class = _RULE_CLASS_CACHE.get(rule_type)
    if cached_class is not None:
        return cached_class()

    # Parse type: 'name@version' (split on the first '@' only).
    name, sep, version = rule_type.partition('@')
    if not sep or not version:
        raise ValueError(
            f"Invalid rule type '{rule_type}': must be in format 'name@version' "
            f"(e.g., 'window_setback@1.0')"
        )

    # fullmatch instead of match(r'^...$'): '$' also matches just before a
    # trailing newline, which would let a name like 'window_setback\n' through.
    if not re.fullmatch(r'[a-z][a-z0-9_]*', name):
        raise ValueError(
            f"Invalid rule name '{name}': must start with lowercase letter "
            f"and contain only lowercase letters, numbers, and underscores"
        )

    # Convert snake_case to PascalCase for class name
    # Example: 'window_setback' -> 'WindowSetbackRule'
    class_name = ''.join(word.capitalize() for word in name.split('_')) + 'Rule'
    module_path = f'apps.rules.impl.{name}'

    try:
        module = importlib.import_module(module_path)
    except ImportError as e:
        missing = getattr(e, 'name', None)
        # Only claim "not found" when the rule module itself (or one of its
        # parent packages) is missing; an ImportError raised *inside* an
        # existing rule module is a different problem.
        rule_module_missing = (
            isinstance(e, ModuleNotFoundError)
            and missing is not None
            and (module_path == missing or module_path.startswith(missing + '.'))
        )
        if rule_module_missing:
            raise ImportError(
                f"Cannot load rule type '{rule_type}': module '{module_path}' not found.\n"
                f"Hint: Create file 'apps/rules/impl/{name}.py' with class '{class_name}'.\n"
                f"Original error: {e}"
            ) from e
        raise ImportError(
            f"Cannot load rule type '{rule_type}': importing module '{module_path}' failed.\n"
            f"Original error: {e}"
        ) from e

    try:
        rule_class = getattr(module, class_name)
    except AttributeError as e:
        public_names = [attr for attr in dir(module) if not attr.startswith('_')]
        raise AttributeError(
            f"Cannot load rule type '{rule_type}': class '{class_name}' not found in module '{module_path}'.\n"
            f"Hint: Define 'class {class_name}(Rule):' in 'apps/rules/impl/{name}.py'.\n"
            f"Available classes in module: {public_names}"
        ) from e

    # Check it is a class first: issubclass() on a non-class raises its own,
    # much less helpful, TypeError.
    if not (isinstance(rule_class, type) and issubclass(rule_class, Rule)):
        raise TypeError(
            f"Class '{class_name}' in '{module_path}' is not a subclass of Rule. "
            f"Ensure it inherits from apps.rules.rule_interface.Rule"
        )

    _RULE_CLASS_CACHE[rule_type] = rule_class
    return rule_class()