Files
home-automation/apps/rules/main.py
2025-11-11 19:58:06 +01:00

376 lines
13 KiB
Python

"""
Rules Engine
Loads rules configuration, subscribes to MQTT events, and dispatches events
to registered rule implementations.
"""
import asyncio
import logging
import os
import signal
import sys
from datetime import datetime
from typing import Any
from apps.rules.rules_config import load_rules_config
from apps.rules.rule_interface import (
RuleDescriptor,
RuleContext,
MQTTClient,
RedisState,
load_rule
)
# Root logging setup: DEBUG verbosity; each record shows timestamp, logger
# name, level and message.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.DEBUG)

# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
class RuleEngine:
    """
    Rule engine that loads rules, subscribes to MQTT events,
    and dispatches them to registered rule implementations.

    Lifecycle: construct, ``await setup()``, ``await run()``, and finally
    ``await shutdown()``.
    """

    # Event capability -> key in a descriptor's ``objects`` mapping that lists
    # the device ids a rule wants to receive events for.
    _CAP_OBJECT_KEYS = {
        'contact': 'contacts',
        'thermostat': 'thermostats',
        'light': 'lights',
        'relay': 'relays',
    }

    def __init__(
        self,
        rules_config_path: str,
        mqtt_broker: str,
        mqtt_port: int,
        redis_url: str
    ):
        """
        Initialize rule engine.

        Only stores settings; nothing is loaded or connected until setup().

        Args:
            rules_config_path: Path to rules.yaml
            mqtt_broker: MQTT broker hostname/IP
            mqtt_port: MQTT broker port
            redis_url: Redis connection URL
        """
        self.rules_config_path = rules_config_path
        self.mqtt_broker = mqtt_broker
        self.mqtt_port = mqtt_port
        self.redis_url = redis_url

        # Will be initialized in setup()
        self.rule_descriptors: list[RuleDescriptor] = []
        self.rules: dict[str, Any] = {}  # rule_id -> Rule instance (enabled rules only)
        self.mqtt_client: MQTTClient | None = None
        self.redis_state: RedisState | None = None
        self.context: RuleContext | None = None
        self._mqtt_topics: list[str] = []  # Topics to subscribe to

        # For graceful shutdown
        self._shutdown_event = asyncio.Event()

    async def setup(self) -> None:
        """
        Load configuration, instantiate rules and prepare the shared context.

        Steps, in order: load the rules configuration, instantiate every
        enabled rule, create the MQTT client / Redis state / shared
        RuleContext, call ``setup()`` on each rule with that shared context,
        and finally collect the MQTT topics the rules subscribe to.

        The shared context is created *before* the rules' ``setup()`` is
        called so that rules never receive a context whose publisher and
        Redis state are ``None``. The MQTT connection itself is only opened
        in run().

        Raises:
            ImportError: If rule implementation not found
            ValueError: If configuration is invalid
            Exception: Any error raised while loading or setting up a rule
                is logged and re-raised unchanged.
        """
        logger.info("Loading rules configuration from %s", self.rules_config_path)
        config = load_rules_config(self.rules_config_path)
        self.rule_descriptors = config.rules
        logger.info("Loaded %d rule(s) from configuration", len(self.rule_descriptors))

        self._load_rules()
        self._init_connections()
        await self._setup_rules()
        self._mqtt_topics = self._collect_subscriptions()

    def _load_rules(self) -> None:
        """Instantiate every enabled rule descriptor into ``self.rules``."""
        for desc in self.rule_descriptors:
            if not desc.enabled:
                logger.info(" - %s (type: %s) [DISABLED]", desc.id, desc.type)
                continue
            try:
                self.rules[desc.id] = load_rule(desc)
            except Exception as e:
                logger.error("Failed to load rule %s (type: %s): %s", desc.id, desc.type, e)
                raise
            logger.info(" - %s (type: %s)", desc.id, desc.type)

        enabled_count = len(self.rules)
        disabled_count = len(self.rule_descriptors) - enabled_count
        logger.info(
            "Successfully loaded %d rule implementation(s) (%d disabled)",
            enabled_count,
            disabled_count,
        )

    def _init_connections(self) -> None:
        """Create the MQTT client, the Redis state and the shared RuleContext."""
        import uuid
        from apps.rules.rule_interface import MQTTPublisher

        # Unique client ID to avoid conflicts between engine instances; an
        # explicit suffix from the environment wins over the random one.
        client_suffix = os.environ.get("MQTT_CLIENT_ID_SUFFIX") or uuid.uuid4().hex[:6]
        self.mqtt_client = MQTTClient(
            broker=self.mqtt_broker,
            port=self.mqtt_port,
            client_id=f"rule_engine-{client_suffix}"
        )
        self.mqtt_client.set_logger(logger)

        self.redis_state = RedisState(self.redis_url)

        # Rules publish through the MQTTPublisher wrapper, not the raw client.
        self.context = RuleContext(
            logger=logger,
            mqtt_publisher=MQTTPublisher(mqtt_client=self.mqtt_client),
            redis_state=self.redis_state,
            now_fn=datetime.now
        )

    async def _setup_rules(self) -> None:
        """Call ``setup()`` on each enabled rule, passing the shared context."""
        for rule_id, rule_instance, desc in self._enabled_rules():
            try:
                await rule_instance.setup(desc, self.context)
            except Exception as e:
                logger.error("Failed to setup rule %s: %s", rule_id, e)
                raise

    def _collect_subscriptions(self) -> list[str]:
        """Return the de-duplicated MQTT topics requested by all enabled rules."""
        all_topics = set()
        for rule_id, rule_instance, desc in self._enabled_rules():
            try:
                topics = rule_instance.get_subscriptions(desc)
                all_topics.update(topics)
            except Exception as e:
                logger.error("Failed to get subscriptions for rule %s: %s", rule_id, e)
                raise
            logger.debug("Rule %s subscribes to %d topic(s)", rule_id, len(topics))
        logger.info("Total MQTT subscriptions needed: %d", len(all_topics))
        return list(all_topics)

    def _enabled_rules(self) -> list:
        """
        Pair every loaded rule with its descriptor.

        Returns:
            List of (rule_id, rule_instance, descriptor) tuples, in
            ``self.rules`` order. Rules without a descriptor are skipped.
        """
        # Index once instead of scanning the descriptor list per rule.
        # setdefault keeps the first descriptor when ids are duplicated.
        by_id: dict[str, Any] = {}
        for desc in self.rule_descriptors:
            by_id.setdefault(desc.id, desc)
        return [
            (rule_id, rule_instance, by_id[rule_id])
            for rule_id, rule_instance in self.rules.items()
            if rule_id in by_id
        ]

    def _filter_rules_for_event(self, event: dict[str, Any]) -> "list[tuple[str, RuleDescriptor]]":
        """
        Filter rules that should receive this event.

        A rule matches when the event's device_id is listed in the rule's
        objects under the key belonging to the event's capability
        (e.g. cap 'contact' -> objects['contacts']). Only enabled rules
        (those present in self.rules) are considered.

        Args:
            event: Normalized MQTT event

        Returns:
            List of (rule_id, descriptor) tuples that should process this
            event; empty when 'device_id' or 'cap' is missing or the
            capability is not one of the known ones.
        """
        device_id = event.get('device_id')
        cap = event.get('cap')
        if not device_id or not cap:
            return []

        logger.debug("Filtering for cap=%s, device_id=%s", cap, device_id)

        objects_key = self._CAP_OBJECT_KEYS.get(cap)
        if objects_key is None:
            # Unknown capability: no rule can match it.
            return []

        matching_rules = []
        for rule_id, _rule_instance, desc in self._enabled_rules():
            candidates = desc.objects.get(objects_key)
            if not candidates:
                continue
            logger.debug("Rule %s: checking %s %s", rule_id, objects_key, candidates)
            if device_id in candidates:
                matching_rules.append((rule_id, desc))
        return matching_rules

    async def _dispatch_event(self, event: dict[str, Any]) -> None:
        """
        Dispatch event to matching rules.

        Calls rule.on_event() for each matching rule sequentially
        to preserve order and avoid race conditions. An exception raised by
        one rule is logged with its traceback and does not stop the others.

        Args:
            event: Normalized MQTT event
        """
        logger.debug("Received event: %s", event)

        matching_rules = self._filter_rules_for_event(event)
        if not matching_rules:
            # No rules interested in this event
            logger.debug(
                "No matching rules for %s/%s", event.get('cap'), event.get('device_id')
            )
            return

        # A match implies both keys are present, so direct indexing is safe.
        cap = event['cap']
        device_id = event['device_id']
        logger.info("Event %s/%s: %d matching rule(s)", cap, device_id, len(matching_rules))

        # Process rules sequentially to preserve order
        for rule_id, desc in matching_rules:
            rule = self.rules.get(rule_id)
            if rule is None:
                logger.warning("Rule instance not found for %s", rule_id)
                continue
            try:
                await rule.on_event(event, desc, self.context)
            except Exception as e:
                # Log with traceback and continue with the other rules.
                logger.exception(
                    "Error in rule %s processing event %s/%s: %s",
                    rule_id, cap, device_id, e
                )

    async def run(self) -> None:
        """
        Main event loop - subscribe to MQTT and process events.

        Runs until shutdown signal received. The shutdown flag is only
        checked when an event arrives; cancelling the task stops the loop
        immediately.

        Raises:
            RuntimeError: If setup() has not been called yet.
            asyncio.CancelledError: Re-raised when the task is cancelled.
            Exception: Any other error from the MQTT stream or dispatch is
                logged and re-raised.
        """
        if self.mqtt_client is None:
            raise RuntimeError("RuleEngine.setup() must be called before run()")

        logger.info("Starting event processing loop")
        try:
            async for event in self.mqtt_client.connect(topics=self._mqtt_topics):
                # Check for shutdown
                if self._shutdown_event.is_set():
                    logger.info("Shutdown signal received, stopping event loop")
                    break
                # Dispatch event to matching rules
                await self._dispatch_event(event)
        except asyncio.CancelledError:
            logger.info("Event loop cancelled")
            raise
        except Exception as e:
            logger.error("Fatal error in event loop: %s", e, exc_info=True)
            raise

    async def shutdown(self) -> None:
        """Graceful shutdown - set the shutdown flag and close Redis."""
        logger.info("Shutting down rule engine...")
        self._shutdown_event.set()
        # NOTE(review): the MQTT client is not closed explicitly here;
        # presumably it is released when the connect() stream in run() ends
        # - confirm against MQTTClient.
        if self.redis_state:
            await self.redis_state.close()
            logger.info("Redis connection closed")
        logger.info("Shutdown complete")
async def main_async() -> None:
    """
    Build a RuleEngine from environment settings and run it until stopped.

    Environment variables (with the fallbacks used when unset):
    RULES_CONFIG ('config/rules.yaml'), MQTT_BROKER ('172.16.2.16'),
    MQTT_PORT ('1883'), REDIS_HOST ('172.23.1.116'), REDIS_PORT ('6379'),
    REDIS_DB ('8').

    Exits the process with status 1 if engine setup fails. SIGTERM and
    SIGINT set the engine's shutdown flag and cancel the running engine
    task; engine.shutdown() is always awaited on the way out.
    """
    # Settings from the environment
    config_path = os.getenv('RULES_CONFIG', 'config/rules.yaml')
    broker_host = os.getenv('MQTT_BROKER', '172.16.2.16')
    broker_port = int(os.getenv('MQTT_PORT', '1883'))
    redis_host = os.getenv('REDIS_HOST', '172.23.1.116')
    redis_port = int(os.getenv('REDIS_PORT', '6379'))
    redis_db = int(os.getenv('REDIS_DB', '8'))
    redis_url = f'redis://{redis_host}:{redis_port}/{redis_db}'

    # Startup banner
    separator = "=" * 60
    banner_lines = (
        separator,
        "Rules Engine Starting",
        separator,
        f"Config: {config_path}",
        f"MQTT: {broker_host}:{broker_port}",
        f"Redis: {redis_url}",
        separator,
    )
    for banner_line in banner_lines:
        logger.info(banner_line)

    engine = RuleEngine(
        rules_config_path=config_path,
        mqtt_broker=broker_host,
        mqtt_port=broker_port,
        redis_url=redis_url
    )

    # Load rules; a broken configuration is fatal.
    try:
        await engine.setup()
    except Exception as exc:
        logger.error(f"Failed to setup engine: {exc}", exc_info=True)
        sys.exit(1)

    # Rebound below once the engine task exists; the handler reads it lazily
    # through the closure.
    engine_task = None

    def _on_signal():
        logger.info("Received shutdown signal")
        engine._shutdown_event.set()
        if engine_task is not None and not engine_task.done():
            engine_task.cancel()

    running_loop = asyncio.get_running_loop()
    for signum in (signal.SIGTERM, signal.SIGINT):
        running_loop.add_signal_handler(signum, _on_signal)

    # Run the engine until it finishes or is cancelled by a signal.
    try:
        engine_task = asyncio.create_task(engine.run())
        await engine_task
    except asyncio.CancelledError:
        logger.info("Main task cancelled")
    finally:
        await engine.shutdown()
def main() -> None:
    """
    Synchronous entry point: run main_async() on a fresh event loop.

    A keyboard interrupt is logged and treated as a normal stop; any other
    exception is logged with its traceback and the process exits with
    status 1.
    """
    try:
        asyncio.run(main_async())
    except KeyboardInterrupt:
        logger.info("Keyboard interrupt received")
    except Exception as exc:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception(f"Fatal error: {exc}")
        sys.exit(1)


# Run the engine only when executed as a script, not when imported.
if __name__ == "__main__":
    main()