"""
Rules Engine

Loads rules configuration, subscribes to MQTT events, and dispatches events
to registered rule implementations.
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import signal
|
|
import sys
|
|
from datetime import datetime
|
|
from typing import Any
|
|
|
|
from apps.rules.rules_config import load_rules_config
|
|
from apps.rules.rule_interface import (
|
|
RuleDescriptor,
|
|
RuleContext,
|
|
MQTTClient,
|
|
RedisState,
|
|
load_rule
|
|
)
|
|
|
|
# Configure logging for the whole process (root logger).
# NOTE(review): DEBUG level is very verbose for production — confirm intended.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
# Module-level logger shared by RuleEngine and the entry points below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RuleEngine:
    """
    Rule engine that loads rules, subscribes to MQTT events,
    and dispatches them to registered rule implementations.

    Lifecycle: construct, ``await setup()``, then ``await run()``;
    call ``shutdown()`` (or cancel ``run()``) to stop gracefully.
    """

    # Maps an event capability ("cap") to the key under which a rule
    # descriptor's `objects` mapping lists the device ids it watches.
    _CAP_TO_OBJECTS_KEY = {
        'contact': 'contacts',
        'thermostat': 'thermostats',
        'light': 'lights',
        'relay': 'relays',
    }

    def __init__(
        self,
        rules_config_path: str,
        mqtt_broker: str,
        mqtt_port: int,
        redis_url: str
    ):
        """
        Initialize rule engine.

        Args:
            rules_config_path: Path to rules.yaml
            mqtt_broker: MQTT broker hostname/IP
            mqtt_port: MQTT broker port
            redis_url: Redis connection URL
        """
        self.rules_config_path = rules_config_path
        self.mqtt_broker = mqtt_broker
        self.mqtt_port = mqtt_port
        self.redis_url = redis_url

        # Populated by setup()
        self.rule_descriptors: list[RuleDescriptor] = []
        self.rules: dict[str, Any] = {}  # rule_id -> Rule instance (enabled rules only)
        self.mqtt_client: MQTTClient | None = None
        self.redis_state: RedisState | None = None
        self.context: RuleContext | None = None
        self._mqtt_topics: list[str] = []  # Topics to subscribe to at connect time

        # rule_id -> descriptor; avoids a linear scan over rule_descriptors
        # on every lookup (previously done with next(...) per event per rule).
        self._desc_by_id: dict[str, RuleDescriptor] = {}

        # Set to request a graceful stop of the run() loop.
        self._shutdown_event = asyncio.Event()

    async def setup(self) -> None:
        """
        Load configuration, initialize MQTT/Redis, and instantiate rules.

        Raises:
            ImportError: If rule implementation not found
            ValueError: If configuration is invalid
        """
        logger.info(f"Loading rules configuration from {self.rules_config_path}")

        # Load rules configuration
        config = load_rules_config(self.rules_config_path)
        self.rule_descriptors = config.rules
        self._desc_by_id = {d.id: d for d in self.rule_descriptors}

        logger.info(f"Loaded {len(self.rule_descriptors)} rule(s) from configuration")

        # Instantiate each enabled rule; disabled rules are logged and skipped.
        for desc in self.rule_descriptors:
            if not desc.enabled:
                logger.info(f" - {desc.id} (type: {desc.type}) [DISABLED]")
                continue

            try:
                self.rules[desc.id] = load_rule(desc)
                logger.info(f" - {desc.id} (type: {desc.type})")
            except Exception as e:
                logger.error(f"Failed to load rule {desc.id} (type: {desc.type}): {e}")
                raise

        enabled_count = len(self.rules)
        disabled_count = len(self.rule_descriptors) - enabled_count
        logger.info(f"Successfully loaded {enabled_count} rule implementation(s) ({disabled_count} disabled)")

        # --- Initialize infrastructure BEFORE calling rule.setup(). ---
        # Bug fix: previously each rule's setup() received a throwaway
        # RuleContext whose mqtt_publisher and redis_state were still None
        # (MQTT and Redis were created only after the setup loop), and which
        # lacked now_fn. Rules now get the same fully-initialized context
        # used for event dispatch.

        # Unique client id to avoid broker-side conflicts between instances.
        import uuid
        client_suffix = os.environ.get("MQTT_CLIENT_ID_SUFFIX") or uuid.uuid4().hex[:6]
        unique_client_id = f"rule_engine-{client_suffix}"

        # Initialize MQTT client
        self.mqtt_client = MQTTClient(
            broker=self.mqtt_broker,
            port=self.mqtt_port,
            client_id=unique_client_id
        )
        self.mqtt_client.set_logger(logger)

        # Initialize Redis state
        self.redis_state = RedisState(self.redis_url)

        # Publisher wrapper that rules use to emit MQTT messages.
        from apps.rules.rule_interface import MQTTPublisher
        mqtt_publisher = MQTTPublisher(mqtt_client=self.mqtt_client)

        # Shared context handed to every rule for setup and dispatch.
        self.context = RuleContext(
            logger=logger,
            mqtt_publisher=mqtt_publisher,
            redis_state=self.redis_state,
            now_fn=datetime.now
        )

        # Let each rule validate its descriptor and prepare internal state.
        for rule_id, rule_instance in self.rules.items():
            desc = self._desc_by_id.get(rule_id)
            if desc is None:
                continue
            try:
                await rule_instance.setup(desc, self.context)
            except Exception as e:
                logger.error(f"Failed to setup rule {rule_id}: {e}")
                raise

        # Collect the union of MQTT topics required by all enabled rules.
        all_topics: set[str] = set()
        for rule_id, rule_instance in self.rules.items():
            desc = self._desc_by_id.get(rule_id)
            if desc is None:
                continue
            try:
                topics = rule_instance.get_subscriptions(desc)
                all_topics.update(topics)
                logger.debug(f"Rule {rule_id} subscribes to {len(topics)} topic(s)")
            except Exception as e:
                logger.error(f"Failed to get subscriptions for rule {rule_id}: {e}")
                raise

        logger.info(f"Total MQTT subscriptions needed: {len(all_topics)}")

        # Store topics for connection in run()
        self._mqtt_topics = list(all_topics)

    def _filter_rules_for_event(self, event: dict[str, Any]) -> list[tuple[str, RuleDescriptor]]:
        """
        Filter rules that should receive this event.

        A rule matches when the event's device_id appears in the rule's
        objects list for the event's capability ('cap').

        Args:
            event: Normalized MQTT event (expects 'device_id' and 'cap' keys)

        Returns:
            List of (rule_id, descriptor) tuples that should process this event
        """
        matching_rules: list[tuple[str, RuleDescriptor]] = []
        device_id = event.get('device_id')
        cap = event.get('cap')

        # Events without identity or capability can never match a rule.
        if not device_id or not cap:
            return matching_rules

        logger.debug(f"Filtering for cap={cap}, device_id={device_id}")

        # Translate the capability into the descriptor objects key;
        # unknown capabilities have no corresponding object list.
        objects_key = self._CAP_TO_OBJECTS_KEY.get(cap)
        if objects_key is None:
            return matching_rules

        # Only check enabled rules (rules in self.rules dict).
        for rule_id in self.rules:
            desc = self._desc_by_id.get(rule_id)
            if desc is None:
                continue

            candidates = desc.objects.get(objects_key)
            if not candidates:
                continue

            logger.debug(f"Rule {rule_id}: checking {objects_key} {candidates}")
            if device_id in candidates:
                matching_rules.append((rule_id, desc))

        return matching_rules

    async def _dispatch_event(self, event: dict[str, Any]) -> None:
        """
        Dispatch event to matching rules.

        Calls rule.on_event() for each matching rule sequentially
        to preserve order and avoid race conditions. A failing rule is
        logged and skipped; remaining rules still run.

        Args:
            event: Normalized MQTT event
        """
        logger.debug(f"Received event: {event}")

        matching_rules = self._filter_rules_for_event(event)

        if not matching_rules:
            # No rules interested in this event
            logger.debug(f"No matching rules for {event.get('cap')}/{event.get('device_id')}")
            return

        logger.info(
            f"Event {event['cap']}/{event['device_id']}: "
            f"{len(matching_rules)} matching rule(s)"
        )

        # Process rules sequentially to preserve order
        for rule_id, desc in matching_rules:
            rule = self.rules.get(rule_id)
            if not rule:
                logger.warning(f"Rule instance not found for {rule_id}")
                continue

            try:
                await rule.on_event(event, desc, self.context)
            except Exception as e:
                logger.error(
                    f"Error in rule {rule_id} processing event "
                    f"{event['cap']}/{event['device_id']}: {e}",
                    exc_info=True
                )
                # Continue with other rules

    async def run(self) -> None:
        """
        Main event loop - subscribe to MQTT and process events.

        Runs until the shutdown event is set or the task is cancelled.

        Raises:
            RuntimeError: If called before setup() has initialized MQTT.
        """
        # Guard against misuse: previously this surfaced as an opaque
        # AttributeError on None.
        if self.mqtt_client is None:
            raise RuntimeError("RuleEngine.run() called before setup()")

        logger.info("Starting event processing loop")

        try:
            async for event in self.mqtt_client.connect(topics=self._mqtt_topics):
                # Check for shutdown between events
                if self._shutdown_event.is_set():
                    logger.info("Shutdown signal received, stopping event loop")
                    break

                # Dispatch event to matching rules
                await self._dispatch_event(event)

        except asyncio.CancelledError:
            logger.info("Event loop cancelled")
            raise
        except Exception as e:
            logger.error(f"Fatal error in event loop: {e}", exc_info=True)
            raise

    async def shutdown(self) -> None:
        """Graceful shutdown - signal the run() loop and close connections."""
        logger.info("Shutting down rule engine...")
        self._shutdown_event.set()

        if self.redis_state:
            await self.redis_state.close()
            logger.info("Redis connection closed")

        logger.info("Shutdown complete")
|
|
|
|
|
|
async def main_async() -> None:
    """Async entry point: read config, build the engine, run until stopped."""
    # Configuration comes from the environment, with development defaults.
    rules_config = os.getenv('RULES_CONFIG', 'config/rules.yaml')
    mqtt_broker = os.getenv('MQTT_BROKER', '172.16.2.16')
    mqtt_port = int(os.getenv('MQTT_PORT', '1883'))
    redis_host = os.getenv('REDIS_HOST', '172.23.1.116')
    redis_port = int(os.getenv('REDIS_PORT', '6379'))
    redis_db = int(os.getenv('REDIS_DB', '8'))
    redis_url = f'redis://{redis_host}:{redis_port}/{redis_db}'

    # Startup banner with the effective configuration.
    banner = "=" * 60
    logger.info(banner)
    logger.info("Rules Engine Starting")
    logger.info(banner)
    logger.info(f"Config: {rules_config}")
    logger.info(f"MQTT: {mqtt_broker}:{mqtt_port}")
    logger.info(f"Redis: {redis_url}")
    logger.info(banner)

    engine = RuleEngine(
        rules_config_path=rules_config,
        mqtt_broker=mqtt_broker,
        mqtt_port=mqtt_port,
        redis_url=redis_url
    )

    # Configuration/loading failures are fatal at startup.
    try:
        await engine.setup()
    except Exception as e:
        logger.error(f"Failed to setup engine: {e}", exc_info=True)
        sys.exit(1)

    loop = asyncio.get_running_loop()
    main_task = None

    def _on_signal():
        # Flag the engine to stop and cancel the running task so the
        # event loop unwinds promptly.
        logger.info("Received shutdown signal")
        engine._shutdown_event.set()
        if main_task and not main_task.done():
            main_task.cancel()

    # SIGTERM (orchestrator) and SIGINT (Ctrl-C) both trigger shutdown.
    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, _on_signal)

    try:
        main_task = asyncio.create_task(engine.run())
        await main_task
    except asyncio.CancelledError:
        logger.info("Main task cancelled")
    finally:
        # Always release MQTT/Redis resources, even on error paths.
        await engine.shutdown()
|
|
|
|
|
|
def main() -> None:
    """Entry point for rule engine."""
    try:
        asyncio.run(main_async())
    except KeyboardInterrupt:
        # Ctrl-C during startup/teardown; normal interactive stop.
        logger.info("Keyboard interrupt received")
    except Exception as exc:
        # Anything else is unexpected: log with traceback and exit non-zero.
        logger.error(f"Fatal error: {exc}", exc_info=True)
        sys.exit(1)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|