143 lines
4.4 KiB
Python
143 lines
4.4 KiB
Python
"""Redis Pub/Sub implementation of event publishing/subscribing."""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from typing import Any, AsyncIterator
|
|
|
|
import redis.asyncio as redis
|
|
|
|
from .base import Event, EventPublisher, EventSubscriber
|
|
|
|
# Logger named after this module (``__name__``) so its output can be filtered
# and configured independently of other modules.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RedisPubSubPublisher(EventPublisher):
    """Redis Pub/Sub based event publisher.

    Each :meth:`publish` call wraps the payload in an :class:`Event`,
    serialises ``event.to_dict()`` with :func:`json.dumps` and sends it via
    Redis ``PUBLISH`` on the channel named by the topic.
    """

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379",
        source: str = "",
    ):
        """Store connection settings; no I/O happens until :meth:`connect`.

        Args:
            redis_url: URL handed to ``redis.asyncio.from_url``.
            source: Value stamped into the ``source`` field of every event
                this publisher emits.
        """
        self.redis_url = redis_url
        self.source = source
        self._client: redis.Redis | None = None

    async def connect(self) -> None:
        """Open a client and verify the server answers ``PING``.

        A client left over from an earlier :meth:`connect` is closed first so
        it does not leak. If ``PING`` fails, the new client is closed and the
        error propagates, leaving the publisher disconnected.
        """
        if self._client is not None:
            await self.disconnect()

        client = redis.from_url(self.redis_url, decode_responses=True)
        try:
            await client.ping()
        except BaseException:
            # Don't keep (or leak) a client we could not verify.
            await client.close()
            raise
        self._client = client
        logger.info("Connected to Redis at %s", self.redis_url)

    async def disconnect(self) -> None:
        """Close the client if one is open; a no-op when already disconnected."""
        # Detach first so the publisher reads as disconnected even if close() raises.
        client, self._client = self._client, None
        if client is None:
            return
        await client.close()
        logger.info("Disconnected from Redis")

    async def publish(self, topic: str, payload: dict[str, Any], **kwargs) -> str:
        """Publish ``payload`` on channel ``topic`` and return the event id.

        Args:
            topic: Redis channel to publish on.
            payload: Event body; it is serialised as part of
                ``json.dumps(event.to_dict())``.
            **kwargs: Only ``event_id`` is read. A missing or falsy value lets
                :class:`Event` generate its own id; other keys are ignored.

        Returns:
            The ``event_id`` of the published event.

        Raises:
            RuntimeError: If the publisher is not connected.
        """
        if self._client is None:
            raise RuntimeError("Publisher not connected")

        event_fields: dict[str, Any] = {
            "topic": topic,
            "payload": payload,
            "source": self.source,
        }
        event_id = kwargs.get("event_id")
        if event_id:
            # Pass an id only when the caller supplied one; otherwise rely on
            # Event's own default rather than building a throwaway Event just
            # to harvest its generated id.
            event_fields["event_id"] = event_id
        event = Event(**event_fields)

        await self._client.publish(topic, json.dumps(event.to_dict()))
        logger.debug("Published event %s to %s", event.event_id, topic)
        return event.event_id
|
|
|
|
|
|
class RedisPubSubSubscriber(EventSubscriber):
    """Redis Pub/Sub based event subscriber.

    Topics containing ``*`` are subscribed with ``PSUBSCRIBE`` (glob
    patterns); all others with ``SUBSCRIBE``. :meth:`consume` JSON-decodes
    incoming messages and turns them into :class:`Event` objects.
    """

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379",
        topics: list[str] | None = None,
    ):
        """Store settings; no I/O happens until :meth:`connect`.

        Args:
            redis_url: URL handed to ``redis.asyncio.from_url``.
            topics: Channels/patterns to subscribe to on :meth:`connect`.
                The list is copied, so the caller's list is never mutated.
        """
        self.redis_url = redis_url
        # Private copy: subscribe() appends to this list.
        self._topics: list[str] = list(topics) if topics else []
        self._client: redis.Redis | None = None
        self._pubsub: redis.client.PubSub | None = None
        self._running = False

    async def connect(self) -> None:
        """Open a client plus Pub/Sub handle and subscribe to tracked topics.

        Connections left over from an earlier :meth:`connect` are released
        first. If ``PING`` fails, the new client is closed and the error
        propagates.
        """
        await self._close_connections()

        client = redis.from_url(self.redis_url, decode_responses=True)
        try:
            await client.ping()
        except BaseException:
            # Don't keep (or leak) a client we could not verify.
            await client.close()
            raise
        self._client = client
        self._pubsub = client.pubsub()
        logger.info("Connected to Redis at %s", self.redis_url)

        if self._topics:
            # Pass a copy: subscribe() records topics into self._topics and
            # skips ones already tracked, so this never grows the list.
            await self.subscribe(list(self._topics))

    async def disconnect(self) -> None:
        """Stop :meth:`consume`, unsubscribe from everything and close connections."""
        self._running = False
        had_client = self._client is not None
        await self._close_connections()
        if had_client:
            logger.info("Disconnected from Redis")

    async def _close_connections(self) -> None:
        """Release the Pub/Sub handle and client, if any; ``_running`` is untouched."""
        pubsub, self._pubsub = self._pubsub, None
        client, self._client = self._client, None
        try:
            if pubsub is not None:
                try:
                    # Drop both plain-channel and pattern subscriptions.
                    await pubsub.unsubscribe()
                    await pubsub.punsubscribe()
                except Exception as e:
                    # Best effort: the handle and client are still closed below.
                    logger.warning("Failed to unsubscribe cleanly: %s", e)
                await pubsub.close()
        finally:
            # Close the client even if closing the Pub/Sub handle failed.
            if client is not None:
                await client.close()

    async def subscribe(self, topics: list[str]) -> None:
        """Subscribe to ``topics`` and track them for later :meth:`connect` calls.

        Topics containing ``*`` go through ``PSUBSCRIBE``; the rest through
        ``SUBSCRIBE``. A topic that is already tracked is not recorded twice.

        Raises:
            RuntimeError: If the subscriber is not connected.
        """
        if self._pubsub is None:
            raise RuntimeError("Subscriber not connected")

        # Separate pattern subscriptions from regular ones
        patterns = [t for t in topics if "*" in t]
        channels = [t for t in topics if "*" not in t]

        if channels:
            await self._pubsub.subscribe(*channels)
            logger.info("Subscribed to channels: %s", channels)

        if patterns:
            await self._pubsub.psubscribe(*patterns)
            logger.info("Subscribed to patterns: %s", patterns)

        # Record each topic once so connect() (which re-subscribes to
        # self._topics) never doubles the list or re-sends duplicates.
        for topic in topics:
            if topic not in self._topics:
                self._topics.append(topic)

    @staticmethod
    def _decode_event(raw: Any) -> Event | None:
        """Parse one Pub/Sub payload into an Event; log and return None if malformed."""
        try:
            return Event.from_dict(json.loads(raw))
        except (ValueError, KeyError, TypeError) as e:
            # ValueError includes json.JSONDecodeError; TypeError is raised by
            # json.loads for non-str/bytes input and by bad from_dict arguments.
            logger.warning("Failed to parse event: %s", e)
            return None

    async def consume(self) -> AsyncIterator[Event]:
        """Yield decoded events until :meth:`disconnect` is called.

        Malformed messages are logged and skipped. Other errors are logged and
        retried after a one-second back-off. Cancellation is NOT swallowed: it
        propagates to the awaiting task after ``_running`` is cleared.

        Raises:
            RuntimeError: If the subscriber is not connected.
        """
        if self._pubsub is None:
            raise RuntimeError("Subscriber not connected")

        self._running = True
        try:
            while self._running:
                pubsub = self._pubsub
                if pubsub is None:
                    # connect() is swapping connections; wait for the new handle.
                    await asyncio.sleep(0.1)
                    continue

                try:
                    message = await pubsub.get_message(
                        ignore_subscribe_messages=True,
                        timeout=1.0,
                    )
                    if message is None:
                        await asyncio.sleep(0.01)
                        continue
                    if message["type"] not in ("message", "pmessage"):
                        continue
                    event = self._decode_event(message["data"])
                except Exception as e:
                    logger.exception("Error consuming events: %s", e)
                    await asyncio.sleep(1.0)
                    continue

                # Yield outside the try so exceptions thrown into this
                # generator by the caller are not swallowed by the handler above.
                if event is not None:
                    yield event
        finally:
            # Runs on normal exit, cancellation and aclose(). CancelledError
            # (a BaseException) passes through so task cancellation and
            # timeouts keep working.
            self._running = False
|