"""
|
|
Rules Engine
|
|
|
|
Loads rules configuration, subscribes to MQTT events, and dispatches events
|
|
to registered rule implementations.
|
|
"""
|
|
|
|
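
# Illustrative rules.yaml entry, for orientation only. The authoritative schema
# is whatever apps.rules.rules_config.load_rules_config accepts; the field names
# below are inferred from the RuleDescriptor attributes used in this module
# (id, type, enabled, targets.contacts, targets.thermostats), and all values
# are hypothetical:
#
#   rules:
#     - id: hallway_window_heating_off
#       type: contact_thermostat
#       enabled: true
#       targets:
#         contacts: ["hallway_window"]
#         thermostats: ["hallway_trv"]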

import asyncio
import logging
import os
import signal
import sys
from datetime import datetime
from typing import Any

from apps.rules.rules_config import load_rules_config
from apps.rules.rule_interface import (
    RuleDescriptor,
    RuleContext,
    MQTTClient,
    RedisState,
    load_rule
)

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


class RuleEngine:
    """
    Rule engine that loads rules, subscribes to MQTT events,
    and dispatches them to registered rule implementations.
    """

    def __init__(
        self,
        rules_config_path: str,
        mqtt_broker: str,
        mqtt_port: int,
        redis_url: str
    ):
        """
        Initialize rule engine.

        Args:
            rules_config_path: Path to rules.yaml
            mqtt_broker: MQTT broker hostname/IP
            mqtt_port: MQTT broker port
            redis_url: Redis connection URL
        """
        self.rules_config_path = rules_config_path
        self.mqtt_broker = mqtt_broker
        self.mqtt_port = mqtt_port
        self.redis_url = redis_url

        # Will be initialized in setup()
        self.rule_descriptors: list[RuleDescriptor] = []
        self.rules: dict[str, Any] = {}  # rule_id -> Rule instance
        self.mqtt_client: MQTTClient | None = None
        self.redis_state: RedisState | None = None
        self.context: RuleContext | None = None

        # For graceful shutdown
        self._shutdown_event = asyncio.Event()

    def setup(self) -> None:
        """
        Load configuration and instantiate rules.

        Raises:
            ImportError: If rule implementation not found
            ValueError: If configuration is invalid
        """
        logger.info(f"Loading rules configuration from {self.rules_config_path}")

        # Load rules configuration
        config = load_rules_config(self.rules_config_path)
        self.rule_descriptors = config.rules

        logger.info(f"Loaded {len(self.rule_descriptors)} rule(s) from configuration")

        # Instantiate each rule
        for desc in self.rule_descriptors:
            if not desc.enabled:
                logger.info(f" - {desc.id} (type: {desc.type}) [DISABLED]")
                continue

            try:
                rule_instance = load_rule(desc)
                self.rules[desc.id] = rule_instance
                logger.info(f" - {desc.id} (type: {desc.type})")
            except Exception as e:
                logger.error(f"Failed to load rule {desc.id} (type: {desc.type}): {e}")
                raise

        enabled_count = len(self.rules)
        total_count = len(self.rule_descriptors)
        disabled_count = total_count - enabled_count
        logger.info(f"Successfully loaded {enabled_count} rule implementation(s) ({disabled_count} disabled)")

        # Initialize MQTT client
        self.mqtt_client = MQTTClient(
            broker=self.mqtt_broker,
            port=self.mqtt_port,
            client_id="rule_engine"
        )
        self.mqtt_client.set_logger(logger)

        # Initialize Redis state
        self.redis_state = RedisState(self.redis_url)

        # Create MQTT publisher wrapper for RuleContext
        from apps.rules.rule_interface import MQTTPublisher
        mqtt_publisher = MQTTPublisher(mqtt_client=self.mqtt_client)

        # Create rule context
        self.context = RuleContext(
            logger=logger,
            mqtt_publisher=mqtt_publisher,
            redis_state=self.redis_state,
            now_fn=datetime.now
        )

    def _filter_rules_for_event(self, event: dict[str, Any]) -> list[tuple[str, RuleDescriptor]]:
        """
        Filter rules that should receive this event.

        A rule matches if:
        - For contact events: device_id is in targets.contacts
        - For thermostat events: device_id is in targets.thermostats
        - (Room-based filtering could be added here)

        Args:
            event: Normalized MQTT event

        Returns:
            List of (rule_id, descriptor) tuples that should process this event
        """
        matching_rules: list[tuple[str, RuleDescriptor]] = []
        device_id = event.get('device_id')
        cap = event.get('cap')

        if not device_id or not cap:
            return matching_rules

        logger.debug(f"Filtering for cap={cap}, device_id={device_id}")

        for desc in self.rule_descriptors:
            rule_id = desc.id

            # Disabled rules were never instantiated in setup(); skip them here
            # so the dispatcher does not warn about missing instances.
            if not desc.enabled:
                continue

            targets = desc.targets

            # Check if this device is in the rule's targets
            matched = False

            if cap == 'contact' and targets.contacts:
                logger.debug(f"Rule {rule_id}: checking contacts {targets.contacts}")
                if device_id in targets.contacts:
                    matched = True

            elif cap == 'thermostat' and targets.thermostats:
                logger.debug(f"Rule {rule_id}: checking thermostats {targets.thermostats}")
                if device_id in targets.thermostats:
                    matched = True

            # Could add room-based filtering here:
            # elif 'rooms' in targets:
            #     device_room = get_device_room(device_id)
            #     if device_room in targets['rooms']:
            #         matched = True

            if matched:
                matching_rules.append((rule_id, desc))

        return matching_rules
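
    # Illustrative normalized event, for orientation only. The filter above
    # relies on just two keys, 'cap' and 'device_id'; any further keys (state,
    # timestamp, ...) are assumptions about the payload produced upstream:
    #
    #   {"cap": "contact", "device_id": "hallway_window", "state": "open"}
    #
    # With the hypothetical rules.yaml entry sketched at the top of this file,
    # such an event would match the rule targeting contacts=["hallway_window"].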

    async def _dispatch_event(self, event: dict[str, Any]) -> None:
        """
        Dispatch event to matching rules.

        Calls rule.on_event() for each matching rule sequentially
        to preserve order and avoid race conditions.

        Args:
            event: Normalized MQTT event
        """
        # Debug logging
        logger.debug(f"Received event: {event}")

        matching_rules = self._filter_rules_for_event(event)

        if not matching_rules:
            # No rules interested in this event
            logger.debug(f"No matching rules for {event.get('cap')}/{event.get('device_id')}")
            return

        logger.info(
            f"Event {event['cap']}/{event['device_id']}: "
            f"{len(matching_rules)} matching rule(s)"
        )

        # Process rules sequentially to preserve order
        for rule_id, desc in matching_rules:
            rule = self.rules.get(rule_id)
            if not rule:
                logger.warning(f"Rule instance not found for {rule_id}")
                continue

            try:
                await rule.on_event(event, desc, self.context)
            except Exception as e:
                logger.error(
                    f"Error in rule {rule_id} processing event "
                    f"{event['cap']}/{event['device_id']}: {e}",
                    exc_info=True
                )
                # Continue with other rules
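
    # The dispatcher above only assumes that each loaded rule exposes an async
    # on_event(event, descriptor, context) coroutine. A minimal sketch of such
    # a rule, under that assumption (class and attribute names here are
    # illustrative, not part of the actual rule_interface contract):
    #
    #   class LogOnlyRule:
    #       async def on_event(self, event, descriptor, context):
    #           context.logger.info(f"{descriptor.id} saw {event['cap']} "
    #                               f"from {event['device_id']}")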

    async def run(self) -> None:
        """
        Main event loop - subscribe to MQTT and process events.

        Runs until shutdown signal received.
        """
        logger.info("Starting event processing loop")

        try:
            async for event in self.mqtt_client.connect():
                # Check for shutdown
                if self._shutdown_event.is_set():
                    logger.info("Shutdown signal received, stopping event loop")
                    break

                # Dispatch event to matching rules
                await self._dispatch_event(event)

        except asyncio.CancelledError:
            logger.info("Event loop cancelled")
            raise
        except Exception as e:
            logger.error(f"Fatal error in event loop: {e}", exc_info=True)
            raise

    async def shutdown(self) -> None:
        """Graceful shutdown - close connections."""
        logger.info("Shutting down rule engine...")
        self._shutdown_event.set()

        if self.redis_state:
            await self.redis_state.close()
            logger.info("Redis connection closed")

        logger.info("Shutdown complete")


async def main_async() -> None:
    """Async main function."""
    # Read configuration from environment
    rules_config = os.getenv('RULES_CONFIG', 'config/rules.yaml')
    mqtt_broker = os.getenv('MQTT_BROKER', '172.16.2.16')
    mqtt_port = int(os.getenv('MQTT_PORT', '1883'))
    redis_host = os.getenv('REDIS_HOST', '172.23.1.116')
    redis_port = int(os.getenv('REDIS_PORT', '6379'))
    redis_db = int(os.getenv('REDIS_DB', '8'))
    redis_url = f'redis://{redis_host}:{redis_port}/{redis_db}'

    logger.info("=" * 60)
    logger.info("Rules Engine Starting")
    logger.info("=" * 60)
    logger.info(f"Config: {rules_config}")
    logger.info(f"MQTT: {mqtt_broker}:{mqtt_port}")
    logger.info(f"Redis: {redis_url}")
    logger.info("=" * 60)

    # Initialize engine
    engine = RuleEngine(
        rules_config_path=rules_config,
        mqtt_broker=mqtt_broker,
        mqtt_port=mqtt_port,
        redis_url=redis_url
    )

    # Load rules
    try:
        engine.setup()
    except Exception as e:
        logger.error(f"Failed to setup engine: {e}", exc_info=True)
        sys.exit(1)

    # Setup signal handlers for graceful shutdown
    loop = asyncio.get_running_loop()

    # Keep references to shutdown tasks so they are not garbage-collected
    # before they finish (asyncio holds only weak references to tasks).
    shutdown_tasks: list[asyncio.Task] = []

    def signal_handler() -> None:
        logger.info("Received shutdown signal")
        shutdown_tasks.append(asyncio.create_task(engine.shutdown()))

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)

    # Run engine
    try:
        await engine.run()
    except asyncio.CancelledError:
        logger.info("Main task cancelled")
    finally:
        await engine.shutdown()
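

# Example environment for a local run. The values shown are simply the defaults
# read above; adjust broker/Redis settings for your deployment. The module path
# in the final command is illustrative, not confirmed by this file:
#
#   RULES_CONFIG=config/rules.yaml MQTT_BROKER=172.16.2.16 MQTT_PORT=1883 \
#   REDIS_HOST=172.23.1.116 REDIS_PORT=6379 REDIS_DB=8 \
#   python -m apps.rules.<engine_module>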


def main() -> None:
    """Entry point for rule engine."""
    try:
        asyncio.run(main_async())
    except KeyboardInterrupt:
        logger.info("Keyboard interrupt received")
    except Exception as e:
        logger.error(f"Fatal error: {e}", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    main()