rules 2
This commit is contained in:
@@ -24,7 +24,7 @@ from apps.rules.rule_interface import (
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -63,11 +63,12 @@ class RuleEngine:
|
||||
self.mqtt_client: MQTTClient | None = None
|
||||
self.redis_state: RedisState | None = None
|
||||
self.context: RuleContext | None = None
|
||||
self._mqtt_topics: list[str] = [] # Topics to subscribe to
|
||||
|
||||
# For graceful shutdown
|
||||
self._shutdown_event = asyncio.Event()
|
||||
|
||||
def setup(self) -> None:
|
||||
async def setup(self) -> None:
|
||||
"""
|
||||
Load configuration and instantiate rules.
|
||||
|
||||
@@ -102,14 +103,55 @@ class RuleEngine:
|
||||
disabled_count = total_count - enabled_count
|
||||
logger.info(f"Successfully loaded {enabled_count} rule implementation(s) ({disabled_count} disabled)")
|
||||
|
||||
# Call setup on each rule for validation
|
||||
for rule_id, rule_instance in self.rules.items():
|
||||
desc = next((d for d in self.rule_descriptors if d.id == rule_id), None)
|
||||
if desc:
|
||||
try:
|
||||
ctx = RuleContext(
|
||||
logger=logger,
|
||||
mqtt_publisher=self.mqtt_client,
|
||||
redis_state=self.redis_state
|
||||
)
|
||||
await rule_instance.setup(desc, ctx)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup rule {rule_id}: {e}")
|
||||
raise
|
||||
|
||||
# Collect MQTT subscriptions from all enabled rules
|
||||
all_topics = set()
|
||||
for rule_id, rule_instance in self.rules.items():
|
||||
desc = next((d for d in self.rule_descriptors if d.id == rule_id), None)
|
||||
if desc:
|
||||
try:
|
||||
topics = rule_instance.get_subscriptions(desc)
|
||||
all_topics.update(topics)
|
||||
logger.debug(f"Rule {rule_id} subscribes to {len(topics)} topic(s)")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get subscriptions for rule {rule_id}: {e}")
|
||||
raise
|
||||
|
||||
logger.info(f"Total MQTT subscriptions needed: {len(all_topics)}")
|
||||
|
||||
# Create unique client ID to avoid conflicts
|
||||
import uuid
|
||||
import os
|
||||
|
||||
client_id_base = "rule_engine"
|
||||
client_suffix = os.environ.get("MQTT_CLIENT_ID_SUFFIX") or uuid.uuid4().hex[:6]
|
||||
unique_client_id = f"{client_id_base}-{client_suffix}"
|
||||
|
||||
# Initialize MQTT client
|
||||
self.mqtt_client = MQTTClient(
|
||||
broker=self.mqtt_broker,
|
||||
port=self.mqtt_port,
|
||||
client_id="rule_engine"
|
||||
client_id=unique_client_id
|
||||
)
|
||||
self.mqtt_client.set_logger(logger)
|
||||
|
||||
# Store topics for connection
|
||||
self._mqtt_topics = list(all_topics)
|
||||
|
||||
# Initialize Redis state
|
||||
self.redis_state = RedisState(self.redis_url)
|
||||
|
||||
@@ -129,10 +171,7 @@ class RuleEngine:
|
||||
"""
|
||||
Filter rules that should receive this event.
|
||||
|
||||
Rules match if:
|
||||
- For contact events: device_id in targets.contacts
|
||||
- For thermostat events: device_id in targets.thermostats
|
||||
- (Room-based filtering could be added here)
|
||||
Rules match if the event's device_id is in the rule's objects.
|
||||
|
||||
Args:
|
||||
event: Normalized MQTT event
|
||||
@@ -149,27 +188,36 @@ class RuleEngine:
|
||||
|
||||
logger.debug(f"Filtering for cap={cap}, device_id={device_id}")
|
||||
|
||||
for rule_id, desc in [(r.id, r) for r in self.rule_descriptors]:
|
||||
targets = desc.targets
|
||||
# Only check enabled rules (rules in self.rules dict)
|
||||
for rule_id, rule_instance in self.rules.items():
|
||||
desc = next((d for d in self.rule_descriptors if d.id == rule_id), None)
|
||||
if not desc:
|
||||
continue
|
||||
|
||||
objects = desc.objects
|
||||
|
||||
# Check if this device is in the rule's targets
|
||||
# Check if this device is in the rule's objects
|
||||
matched = False
|
||||
|
||||
if cap == 'contact' and targets.contacts:
|
||||
logger.debug(f"Rule {rule_id}: checking contacts {targets.contacts}")
|
||||
if device_id in targets.contacts:
|
||||
if cap == 'contact' and objects.get('contacts'):
|
||||
logger.debug(f"Rule {rule_id}: checking contacts {objects.get('contacts')}")
|
||||
if device_id in objects.get('contacts', []):
|
||||
matched = True
|
||||
|
||||
elif cap == 'thermostat' and targets.thermostats:
|
||||
logger.debug(f"Rule {rule_id}: checking thermostats {targets.thermostats}")
|
||||
if device_id in targets.thermostats:
|
||||
elif cap == 'thermostat' and objects.get('thermostats'):
|
||||
logger.debug(f"Rule {rule_id}: checking thermostats {objects.get('thermostats')}")
|
||||
if device_id in objects.get('thermostats', []):
|
||||
matched = True
|
||||
|
||||
# Could add room-based filtering here:
|
||||
# elif 'rooms' in targets:
|
||||
# device_room = get_device_room(device_id)
|
||||
# if device_room in targets['rooms']:
|
||||
# matched = True
|
||||
elif cap == 'light' and objects.get('lights'):
|
||||
logger.debug(f"Rule {rule_id}: checking lights {objects.get('lights')}")
|
||||
if device_id in objects.get('lights', []):
|
||||
matched = True
|
||||
|
||||
elif cap == 'relay' and objects.get('relays'):
|
||||
logger.debug(f"Rule {rule_id}: checking relays {objects.get('relays')}")
|
||||
if device_id in objects.get('relays', []):
|
||||
matched = True
|
||||
|
||||
if matched:
|
||||
matching_rules.append((rule_id, desc))
|
||||
@@ -227,7 +275,7 @@ class RuleEngine:
|
||||
logger.info("Starting event processing loop")
|
||||
|
||||
try:
|
||||
async for event in self.mqtt_client.connect():
|
||||
async for event in self.mqtt_client.connect(topics=self._mqtt_topics):
|
||||
# Check for shutdown
|
||||
if self._shutdown_event.is_set():
|
||||
logger.info("Shutdown signal received, stopping event loop")
|
||||
@@ -284,24 +332,28 @@ async def main_async() -> None:
|
||||
|
||||
# Load rules
|
||||
try:
|
||||
engine.setup()
|
||||
await engine.setup()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup engine: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
# Setup signal handlers for graceful shutdown
|
||||
loop = asyncio.get_running_loop()
|
||||
main_task = None
|
||||
|
||||
def signal_handler():
|
||||
logger.info("Received shutdown signal")
|
||||
asyncio.create_task(engine.shutdown())
|
||||
engine._shutdown_event.set()
|
||||
if main_task and not main_task.done():
|
||||
main_task.cancel()
|
||||
|
||||
for sig in (signal.SIGTERM, signal.SIGINT):
|
||||
loop.add_signal_handler(sig, signal_handler)
|
||||
|
||||
# Run engine
|
||||
try:
|
||||
await engine.run()
|
||||
main_task = asyncio.create_task(engine.run())
|
||||
await main_task
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Main task cancelled")
|
||||
finally:
|
||||
|
||||
Reference in New Issue
Block a user