"""Redis Pub/Sub implementation of event publishing/subscribing.""" import asyncio import json import logging from typing import Any, AsyncIterator import redis.asyncio as redis from .base import Event, EventPublisher, EventSubscriber logger = logging.getLogger(__name__) class RedisPubSubPublisher(EventPublisher): """Redis Pub/Sub based event publisher.""" def __init__( self, redis_url: str = "redis://localhost:6379", source: str = "", ): self.redis_url = redis_url self.source = source self._client: redis.Redis | None = None async def connect(self) -> None: self._client = redis.from_url(self.redis_url, decode_responses=True) await self._client.ping() logger.info(f"Connected to Redis at {self.redis_url}") async def disconnect(self) -> None: if self._client: await self._client.close() self._client = None logger.info("Disconnected from Redis") async def publish(self, topic: str, payload: dict[str, Any], **kwargs) -> str: if not self._client: raise RuntimeError("Publisher not connected") event = Event( topic=topic, payload=payload, event_id=kwargs.get("event_id", None) or Event(topic="", payload={}).event_id, source=self.source, ) message = json.dumps(event.to_dict()) await self._client.publish(topic, message) logger.debug(f"Published event {event.event_id} to {topic}") return event.event_id class RedisPubSubSubscriber(EventSubscriber): """Redis Pub/Sub based event subscriber.""" def __init__( self, redis_url: str = "redis://localhost:6379", topics: list[str] | None = None, ): self.redis_url = redis_url self._topics = topics or [] self._client: redis.Redis | None = None self._pubsub: redis.client.PubSub | None = None self._running = False async def connect(self) -> None: self._client = redis.from_url(self.redis_url, decode_responses=True) await self._client.ping() self._pubsub = self._client.pubsub() logger.info(f"Connected to Redis at {self.redis_url}") if self._topics: await self.subscribe(self._topics) async def disconnect(self) -> None: self._running = False if self._pubsub: await self._pubsub.unsubscribe() await self._pubsub.close() self._pubsub = None if self._client: await self._client.close() self._client = None logger.info("Disconnected from Redis") async def subscribe(self, topics: list[str]) -> None: if not self._pubsub: raise RuntimeError("Subscriber not connected") # Separate pattern subscriptions from regular ones patterns = [t for t in topics if "*" in t] channels = [t for t in topics if "*" not in t] if channels: await self._pubsub.subscribe(*channels) logger.info(f"Subscribed to channels: {channels}") if patterns: await self._pubsub.psubscribe(*patterns) logger.info(f"Subscribed to patterns: {patterns}") self._topics.extend(topics) async def _reconnect(self) -> None: """Attempt to reconnect to Redis.""" max_retries = 10 base_delay = 1.0 for attempt in range(max_retries): try: # Clean up old connections if self._pubsub: try: await self._pubsub.close() except Exception: pass if self._client: try: await self._client.close() except Exception: pass # Reconnect self._client = redis.from_url(self.redis_url, decode_responses=True) await self._client.ping() self._pubsub = self._client.pubsub() # Re-subscribe to topics if self._topics: patterns = [t for t in self._topics if "*" in t] channels = [t for t in self._topics if "*" not in t] if channels: await self._pubsub.subscribe(*channels) if patterns: await self._pubsub.psubscribe(*patterns) logger.info(f"Reconnected to Redis at {self.redis_url}") return except Exception as e: delay = min(base_delay * (2**attempt), 30.0) logger.warning( f"Reconnect attempt {attempt + 1} failed: {e}, retrying in {delay}s" ) await asyncio.sleep(delay) raise RuntimeError(f"Failed to reconnect to Redis after {max_retries} attempts") async def consume(self) -> AsyncIterator[Event]: if not self._pubsub: raise RuntimeError("Subscriber not connected") self._running = True while self._running: try: message = await self._pubsub.get_message( ignore_subscribe_messages=True, timeout=1.0, ) if message is None: await asyncio.sleep(0.01) continue if message["type"] not in ("message", "pmessage"): continue try: data = json.loads(message["data"]) event = Event.from_dict(data) yield event except (json.JSONDecodeError, KeyError) as e: logger.warning(f"Failed to parse event: {e}") continue except asyncio.CancelledError: self._running = False break except Exception as e: logger.error(f"Error consuming events: {e}, attempting reconnect...") try: await self._reconnect() except Exception as reconnect_err: logger.error(f"Reconnect failed: {reconnect_err}") await asyncio.sleep(5.0)