This commit is contained in:
2025-11-11 19:58:06 +01:00
parent d3d96ed3e9
commit b6b441c0ca
5 changed files with 245 additions and 89 deletions

View File

@@ -24,7 +24,7 @@ from apps.rules.rule_interface import (
# Configure logging
logging.basicConfig(
level=logging.INFO,
level=logging.DEBUG,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
@@ -63,11 +63,12 @@ class RuleEngine:
self.mqtt_client: MQTTClient | None = None
self.redis_state: RedisState | None = None
self.context: RuleContext | None = None
self._mqtt_topics: list[str] = [] # Topics to subscribe to
# For graceful shutdown
self._shutdown_event = asyncio.Event()
def setup(self) -> None:
async def setup(self) -> None:
"""
Load configuration and instantiate rules.
@@ -102,14 +103,55 @@ class RuleEngine:
disabled_count = total_count - enabled_count
logger.info(f"Successfully loaded {enabled_count} rule implementation(s) ({disabled_count} disabled)")
# Call setup on each rule for validation
for rule_id, rule_instance in self.rules.items():
desc = next((d for d in self.rule_descriptors if d.id == rule_id), None)
if desc:
try:
ctx = RuleContext(
logger=logger,
mqtt_publisher=self.mqtt_client,
redis_state=self.redis_state
)
await rule_instance.setup(desc, ctx)
except Exception as e:
logger.error(f"Failed to setup rule {rule_id}: {e}")
raise
# Collect MQTT subscriptions from all enabled rules
all_topics = set()
for rule_id, rule_instance in self.rules.items():
desc = next((d for d in self.rule_descriptors if d.id == rule_id), None)
if desc:
try:
topics = rule_instance.get_subscriptions(desc)
all_topics.update(topics)
logger.debug(f"Rule {rule_id} subscribes to {len(topics)} topic(s)")
except Exception as e:
logger.error(f"Failed to get subscriptions for rule {rule_id}: {e}")
raise
logger.info(f"Total MQTT subscriptions needed: {len(all_topics)}")
# Create unique client ID to avoid conflicts
import uuid
import os
client_id_base = "rule_engine"
client_suffix = os.environ.get("MQTT_CLIENT_ID_SUFFIX") or uuid.uuid4().hex[:6]
unique_client_id = f"{client_id_base}-{client_suffix}"
# Initialize MQTT client
self.mqtt_client = MQTTClient(
broker=self.mqtt_broker,
port=self.mqtt_port,
client_id="rule_engine"
client_id=unique_client_id
)
self.mqtt_client.set_logger(logger)
# Store topics for connection
self._mqtt_topics = list(all_topics)
# Initialize Redis state
self.redis_state = RedisState(self.redis_url)
@@ -129,10 +171,7 @@ class RuleEngine:
"""
Filter rules that should receive this event.
Rules match if:
- For contact events: device_id in targets.contacts
- For thermostat events: device_id in targets.thermostats
- (Room-based filtering could be added here)
Rules match if the event's device_id is in the rule's objects.
Args:
event: Normalized MQTT event
@@ -149,27 +188,36 @@ class RuleEngine:
logger.debug(f"Filtering for cap={cap}, device_id={device_id}")
for rule_id, desc in [(r.id, r) for r in self.rule_descriptors]:
targets = desc.targets
# Only check enabled rules (rules in self.rules dict)
for rule_id, rule_instance in self.rules.items():
desc = next((d for d in self.rule_descriptors if d.id == rule_id), None)
if not desc:
continue
objects = desc.objects
# Check if this device is in the rule's targets
# Check if this device is in the rule's objects
matched = False
if cap == 'contact' and targets.contacts:
logger.debug(f"Rule {rule_id}: checking contacts {targets.contacts}")
if device_id in targets.contacts:
if cap == 'contact' and objects.get('contacts'):
logger.debug(f"Rule {rule_id}: checking contacts {objects.get('contacts')}")
if device_id in objects.get('contacts', []):
matched = True
elif cap == 'thermostat' and targets.thermostats:
logger.debug(f"Rule {rule_id}: checking thermostats {targets.thermostats}")
if device_id in targets.thermostats:
elif cap == 'thermostat' and objects.get('thermostats'):
logger.debug(f"Rule {rule_id}: checking thermostats {objects.get('thermostats')}")
if device_id in objects.get('thermostats', []):
matched = True
# Could add room-based filtering here:
# elif 'rooms' in targets:
# device_room = get_device_room(device_id)
# if device_room in targets['rooms']:
# matched = True
elif cap == 'light' and objects.get('lights'):
logger.debug(f"Rule {rule_id}: checking lights {objects.get('lights')}")
if device_id in objects.get('lights', []):
matched = True
elif cap == 'relay' and objects.get('relays'):
logger.debug(f"Rule {rule_id}: checking relays {objects.get('relays')}")
if device_id in objects.get('relays', []):
matched = True
if matched:
matching_rules.append((rule_id, desc))
@@ -227,7 +275,7 @@ class RuleEngine:
logger.info("Starting event processing loop")
try:
async for event in self.mqtt_client.connect():
async for event in self.mqtt_client.connect(topics=self._mqtt_topics):
# Check for shutdown
if self._shutdown_event.is_set():
logger.info("Shutdown signal received, stopping event loop")
@@ -284,24 +332,28 @@ async def main_async() -> None:
# Load rules
try:
engine.setup()
await engine.setup()
except Exception as e:
logger.error(f"Failed to setup engine: {e}", exc_info=True)
sys.exit(1)
# Setup signal handlers for graceful shutdown
loop = asyncio.get_running_loop()
main_task = None
def signal_handler():
logger.info("Received shutdown signal")
asyncio.create_task(engine.shutdown())
engine._shutdown_event.set()
if main_task and not main_task.done():
main_task.cancel()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
# Run engine
try:
await engine.run()
main_task = asyncio.create_task(engine.run())
await main_task
except asyncio.CancelledError:
logger.info("Main task cancelled")
finally: