"""Storage layer for metrics - Redis (current state) and TimescaleDB (historical)."""

import json
import time
from datetime import datetime, timezone
from typing import Any

import asyncpg
import redis.asyncio as redis

from shared.logging import get_logger

logger = get_logger("storage")


class RedisStorage:
    """Redis storage for current machine state.

    Each machine is stored as a hash at ``machine:<machine_id>`` with two
    fields: ``state`` (a JSON document) and ``last_seen`` (the timestamp in
    milliseconds as a string). The ids of machines that have reported are
    tracked in the ``machines:active`` set.
    """

    # Name of the Redis set holding the ids of machines that have reported.
    ACTIVE_SET_KEY = "machines:active"
    # Seconds without an update after which a machine's hash expires.
    STATE_TTL_SECONDS = 300

    def __init__(self, redis_url: str):
        """Remember the connection URL; no I/O happens until ``connect``."""
        self.redis_url = redis_url
        self._client: redis.Redis | None = None

    @staticmethod
    def _machine_key(machine_id: str) -> str:
        """Return the Redis key of the hash holding a machine's state."""
        return f"machine:{machine_id}"

    def _require_client(self) -> redis.Redis:
        """Return the connected client or raise ``RuntimeError``."""
        if not self._client:
            raise RuntimeError("Not connected to Redis")
        return self._client

    async def connect(self) -> None:
        """Create the client and verify the connection with a PING."""
        self._client = redis.from_url(self.redis_url, decode_responses=True)
        await self._client.ping()
        logger.info("redis_connected", url=self.redis_url)

    async def disconnect(self) -> None:
        """Close the client if one is open; a no-op otherwise."""
        if self._client:
            # Newer redis-py releases deprecate close() in favour of aclose();
            # fall back to close() for releases that do not provide aclose().
            closer = getattr(self._client, "aclose", None) or self._client.close
            await closer()
            self._client = None
            logger.info("redis_disconnected")

    async def update_machine_state(
        self,
        machine_id: str,
        hostname: str,
        metrics: dict[str, float],
        timestamp_ms: int,
    ) -> None:
        """Update the current state for a machine (merges metrics, doesn't replace).

        Args:
            machine_id: Machine identifier.
            hostname: Machine hostname.
            metrics: Metric name to value; these override stored values with
                the same name, while other stored metrics are kept.
            timestamp_ms: Timestamp in milliseconds, stored as ``last_seen_ms``.

        Raises:
            RuntimeError: If ``connect`` has not been called.
        """
        client = self._require_client()
        key = self._machine_key(machine_id)

        # Merge into whatever is already stored (new values override old).
        # The read and the write are separate steps, so concurrent writers
        # for the same machine can still overwrite each other's merge.
        merged: dict[str, float] = dict(metrics)
        existing_data = await client.hget(key, "state")
        if existing_data:
            existing_metrics = json.loads(existing_data).get("metrics") or {}
            merged = {**existing_metrics, **metrics}

        state = {
            "machine_id": machine_id,
            "hostname": hostname,
            "last_seen_ms": timestamp_ms,
            "metrics": merged,
            # Naive UTC ISO string, the same format datetime.utcnow() produced.
            "updated_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
        }

        # Write the hash, its expiry and the active-set membership in a single
        # MULTI/EXEC so the key can never be left behind without a TTL.
        async with client.pipeline(transaction=True) as pipe:
            pipe.hset(
                key,
                mapping={
                    "state": json.dumps(state),
                    "last_seen": str(timestamp_ms),
                },
            )
            pipe.expire(key, self.STATE_TTL_SECONDS)
            pipe.sadd(self.ACTIVE_SET_KEY, machine_id)
            await pipe.execute()

    async def get_machine_state(self, machine_id: str) -> dict[str, Any] | None:
        """Get current state for a machine.

        Returns:
            The decoded state document, or ``None`` when nothing is stored
            for ``machine_id``.

        Raises:
            RuntimeError: If ``connect`` has not been called.
        """
        client = self._require_client()
        data = await client.hget(self._machine_key(machine_id), "state")
        return json.loads(data) if data else None

    async def get_all_machines(self) -> list[dict[str, Any]]:
        """Get current state for all active machines.

        Ids in the active set whose state is no longer stored are removed
        from the set as a side effect.

        Raises:
            RuntimeError: If ``connect`` has not been called.
        """
        client = self._require_client()

        machine_ids = list(await client.smembers(self.ACTIVE_SET_KEY))
        if not machine_ids:
            return []

        # Fetch every state in one round trip instead of one HGET per machine.
        async with client.pipeline(transaction=False) as pipe:
            for machine_id in machine_ids:
                pipe.hget(self._machine_key(machine_id), "state")
            raw_states = await pipe.execute()

        states: list[dict[str, Any]] = []
        stale_ids: list[str] = []
        for machine_id, raw in zip(machine_ids, raw_states):
            if raw:
                states.append(json.loads(raw))
            else:
                stale_ids.append(machine_id)

        if stale_ids:
            # Remove stale machines from the active set in a single command.
            await client.srem(self.ACTIVE_SET_KEY, *stale_ids)

        return states


class TimescaleStorage:
    """TimescaleDB storage for historical metrics."""

    def __init__(self, connection_url: str):
        """Remember the connection URL; no I/O happens until ``connect``."""
        self.connection_url = connection_url
        self._pool: asyncpg.Pool | None = None

    def _require_pool(self) -> asyncpg.Pool:
        """Return the connection pool or raise ``RuntimeError``."""
        if not self._pool:
            raise RuntimeError("Not connected to TimescaleDB")
        return self._pool

    @staticmethod
    def _decode_labels(raw: Any) -> Any:
        """Decode a ``labels`` column value; empty values become ``{}``."""
        if not raw:
            return {}
        # The driver returns JSON columns as text unless a codec is registered
        # on the connection, in which case the value is already decoded.
        if isinstance(raw, (str, bytes, bytearray)):
            return json.loads(raw)
        return raw

    async def connect(self) -> None:
        """Create the connection pool (2 to 10 connections)."""
        self._pool = await asyncpg.create_pool(
            self.connection_url,
            min_size=2,
            max_size=10,
        )
        logger.info("timescaledb_connected")

    async def disconnect(self) -> None:
        """Close the pool if one is open; a no-op otherwise."""
        if self._pool:
            await self._pool.close()
            self._pool = None
            logger.info("timescaledb_disconnected")

    async def insert_metrics(
        self,
        machine_id: str,
        hostname: str,
        timestamp_ms: int,
        metrics: list[tuple[str, float, dict[str, str]]],
    ) -> int:
        """
        Insert a batch of metrics.

        Args:
            machine_id: Machine identifier
            hostname: Machine hostname
            timestamp_ms: Timestamp in milliseconds
            metrics: List of (metric_type, value, labels) tuples

        Returns:
            Number of rows inserted

        Raises:
            RuntimeError: If ``connect`` has not been called.
        """
        pool = self._require_pool()

        # Naive UTC datetime, the same value datetime.utcfromtimestamp() gave.
        timestamp = datetime.fromtimestamp(
            timestamp_ms / 1000, tz=timezone.utc
        ).replace(tzinfo=None)

        # Prepare batch insert
        rows = [
            (timestamp, machine_id, hostname, metric_type, value, json.dumps(labels))
            for metric_type, value, labels in metrics
        ]
        if not rows:
            # Nothing to write; don't take a connection from the pool.
            return 0

        async with pool.acquire() as conn:
            await conn.executemany(
                """
                INSERT INTO metrics_raw (time, machine_id, hostname, metric_type, value, labels)
                VALUES ($1, $2, $3, $4, $5, $6)
                """,
                rows,
            )

        return len(rows)

    async def update_machine_registry(
        self,
        machine_id: str,
        hostname: str,
        health: str = "HEALTHY",
    ) -> None:
        """Update the machines registry with last seen time.

        Inserts the machine or, when ``machine_id`` already exists, updates
        its hostname, health and ``last_seen`` (set to the database's NOW()).

        Raises:
            RuntimeError: If ``connect`` has not been called.
        """
        pool = self._require_pool()

        async with pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO machines (machine_id, hostname, last_seen, health)
                VALUES ($1, $2, NOW(), $3)
                ON CONFLICT (machine_id) DO UPDATE SET
                    hostname = $2,
                    last_seen = NOW(),
                    health = $3
                """,
                machine_id,
                hostname,
                health,
            )

    async def get_metrics(
        self,
        machine_id: str | None = None,
        metric_type: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        limit: int = 1000,
    ) -> list[dict[str, Any]]:
        """Query historical metrics.

        Args:
            machine_id: Only return rows for this machine; ``None`` or an
                empty string disables the filter.
            metric_type: Only return rows of this metric type; ``None`` or an
                empty string disables the filter.
            start_time: Inclusive lower bound on the row time.
            end_time: Inclusive upper bound on the row time.
            limit: Maximum number of rows to return.

        Returns:
            Rows ordered newest first, each a dict with ``time`` (ISO string),
            ``machine_id``, ``hostname``, ``metric_type``, ``value`` and
            ``labels``.

        Raises:
            RuntimeError: If ``connect`` has not been called.
        """
        pool = self._require_pool()

        conditions: list[str] = []
        params: list[Any] = []
        # Only fixed column/operator text is interpolated into the SQL; every
        # value travels as a numbered bind parameter.
        for column_op, value, active in (
            ("machine_id =", machine_id, bool(machine_id)),
            ("metric_type =", metric_type, bool(metric_type)),
            ("time >=", start_time, start_time is not None),
            ("time <=", end_time, end_time is not None),
        ):
            if active:
                params.append(value)
                conditions.append(f"{column_op} ${len(params)}")

        where_clause = " AND ".join(conditions) if conditions else "TRUE"

        params.append(limit)
        query = f"""
            SELECT time, machine_id, hostname, metric_type, value, labels
            FROM metrics_raw
            WHERE {where_clause}
            ORDER BY time DESC
            LIMIT ${len(params)}
        """

        async with pool.acquire() as conn:
            rows = await conn.fetch(query, *params)

        return [
            {
                "time": row["time"].isoformat(),
                "machine_id": row["machine_id"],
                "hostname": row["hostname"],
                "metric_type": row["metric_type"],
                "value": row["value"],
                "labels": self._decode_labels(row["labels"]),
            }
            for row in rows
        ]