first claude draft
This commit is contained in:
142
shared/events/redis_pubsub.py
Normal file
142
shared/events/redis_pubsub.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""Redis Pub/Sub implementation of event publishing/subscribing."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, AsyncIterator
|
||||
|
||||
import redis.asyncio as redis
|
||||
|
||||
from .base import Event, EventPublisher, EventSubscriber
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RedisPubSubPublisher(EventPublisher):
|
||||
"""Redis Pub/Sub based event publisher."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_url: str = "redis://localhost:6379",
|
||||
source: str = "",
|
||||
):
|
||||
self.redis_url = redis_url
|
||||
self.source = source
|
||||
self._client: redis.Redis | None = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
self._client = redis.from_url(self.redis_url, decode_responses=True)
|
||||
await self._client.ping()
|
||||
logger.info(f"Connected to Redis at {self.redis_url}")
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
if self._client:
|
||||
await self._client.close()
|
||||
self._client = None
|
||||
logger.info("Disconnected from Redis")
|
||||
|
||||
async def publish(self, topic: str, payload: dict[str, Any], **kwargs) -> str:
|
||||
if not self._client:
|
||||
raise RuntimeError("Publisher not connected")
|
||||
|
||||
event = Event(
|
||||
topic=topic,
|
||||
payload=payload,
|
||||
event_id=kwargs.get("event_id", None)
|
||||
or Event(topic="", payload={}).event_id,
|
||||
source=self.source,
|
||||
)
|
||||
|
||||
message = json.dumps(event.to_dict())
|
||||
await self._client.publish(topic, message)
|
||||
|
||||
logger.debug(f"Published event {event.event_id} to {topic}")
|
||||
return event.event_id
|
||||
|
||||
|
||||
class RedisPubSubSubscriber(EventSubscriber):
|
||||
"""Redis Pub/Sub based event subscriber."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_url: str = "redis://localhost:6379",
|
||||
topics: list[str] | None = None,
|
||||
):
|
||||
self.redis_url = redis_url
|
||||
self._topics = topics or []
|
||||
self._client: redis.Redis | None = None
|
||||
self._pubsub: redis.client.PubSub | None = None
|
||||
self._running = False
|
||||
|
||||
async def connect(self) -> None:
|
||||
self._client = redis.from_url(self.redis_url, decode_responses=True)
|
||||
await self._client.ping()
|
||||
self._pubsub = self._client.pubsub()
|
||||
logger.info(f"Connected to Redis at {self.redis_url}")
|
||||
|
||||
if self._topics:
|
||||
await self.subscribe(self._topics)
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._running = False
|
||||
if self._pubsub:
|
||||
await self._pubsub.unsubscribe()
|
||||
await self._pubsub.close()
|
||||
self._pubsub = None
|
||||
if self._client:
|
||||
await self._client.close()
|
||||
self._client = None
|
||||
logger.info("Disconnected from Redis")
|
||||
|
||||
async def subscribe(self, topics: list[str]) -> None:
|
||||
if not self._pubsub:
|
||||
raise RuntimeError("Subscriber not connected")
|
||||
|
||||
# Separate pattern subscriptions from regular ones
|
||||
patterns = [t for t in topics if "*" in t]
|
||||
channels = [t for t in topics if "*" not in t]
|
||||
|
||||
if channels:
|
||||
await self._pubsub.subscribe(*channels)
|
||||
logger.info(f"Subscribed to channels: {channels}")
|
||||
|
||||
if patterns:
|
||||
await self._pubsub.psubscribe(*patterns)
|
||||
logger.info(f"Subscribed to patterns: {patterns}")
|
||||
|
||||
self._topics.extend(topics)
|
||||
|
||||
async def consume(self) -> AsyncIterator[Event]:
|
||||
if not self._pubsub:
|
||||
raise RuntimeError("Subscriber not connected")
|
||||
|
||||
self._running = True
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
message = await self._pubsub.get_message(
|
||||
ignore_subscribe_messages=True,
|
||||
timeout=1.0,
|
||||
)
|
||||
|
||||
if message is None:
|
||||
await asyncio.sleep(0.01)
|
||||
continue
|
||||
|
||||
if message["type"] not in ("message", "pmessage"):
|
||||
continue
|
||||
|
||||
try:
|
||||
data = json.loads(message["data"])
|
||||
event = Event.from_dict(data)
|
||||
yield event
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
logger.warning(f"Failed to parse event: {e}")
|
||||
continue
|
||||
|
||||
except asyncio.CancelledError:
|
||||
self._running = False
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error consuming events: {e}")
|
||||
await asyncio.sleep(1.0)
|
||||
Reference in New Issue
Block a user