"""Storage layer for metrics - Redis (current state) and TimescaleDB (historical)."""
|
|
|
|
import json
import time
from datetime import datetime, timezone
from typing import Any

import asyncpg
import redis.asyncio as redis

from shared.logging import get_logger
|
|
|
|
logger = get_logger("storage")
|
|
|
|
|
|
class RedisStorage:
|
|
"""Redis storage for current machine state."""
|
|
|
|
def __init__(self, redis_url: str):
|
|
self.redis_url = redis_url
|
|
self._client: redis.Redis | None = None
|
|
|
|
async def connect(self) -> None:
|
|
self._client = redis.from_url(self.redis_url, decode_responses=True)
|
|
await self._client.ping()
|
|
logger.info("redis_connected", url=self.redis_url)
|
|
|
|
async def disconnect(self) -> None:
|
|
if self._client:
|
|
await self._client.close()
|
|
self._client = None
|
|
logger.info("redis_disconnected")
|
|
|
|
async def update_machine_state(
|
|
self,
|
|
machine_id: str,
|
|
hostname: str,
|
|
metrics: dict[str, float],
|
|
timestamp_ms: int,
|
|
) -> None:
|
|
"""Update the current state for a machine."""
|
|
if not self._client:
|
|
raise RuntimeError("Not connected to Redis")
|
|
|
|
state = {
|
|
"machine_id": machine_id,
|
|
"hostname": hostname,
|
|
"last_seen_ms": timestamp_ms,
|
|
"metrics": metrics,
|
|
"updated_at": datetime.utcnow().isoformat(),
|
|
}
|
|
|
|
# Store as hash for efficient partial reads
|
|
key = f"machine:{machine_id}"
|
|
await self._client.hset(
|
|
key,
|
|
mapping={
|
|
"state": json.dumps(state),
|
|
"last_seen": str(timestamp_ms),
|
|
},
|
|
)
|
|
|
|
# Set expiry - if no updates for 5 minutes, consider stale
|
|
await self._client.expire(key, 300)
|
|
|
|
# Add to active machines set
|
|
await self._client.sadd("machines:active", machine_id)
|
|
|
|
async def get_machine_state(self, machine_id: str) -> dict[str, Any] | None:
|
|
"""Get current state for a machine."""
|
|
if not self._client:
|
|
raise RuntimeError("Not connected to Redis")
|
|
|
|
key = f"machine:{machine_id}"
|
|
data = await self._client.hget(key, "state")
|
|
|
|
if data:
|
|
return json.loads(data)
|
|
return None
|
|
|
|
async def get_all_machines(self) -> list[dict[str, Any]]:
|
|
"""Get current state for all active machines."""
|
|
if not self._client:
|
|
raise RuntimeError("Not connected to Redis")
|
|
|
|
machine_ids = await self._client.smembers("machines:active")
|
|
states = []
|
|
|
|
for machine_id in machine_ids:
|
|
state = await self.get_machine_state(machine_id)
|
|
if state:
|
|
states.append(state)
|
|
else:
|
|
# Remove stale machine from active set
|
|
await self._client.srem("machines:active", machine_id)
|
|
|
|
return states
|
|
|
|
|
|
class TimescaleStorage:
|
|
"""TimescaleDB storage for historical metrics."""
|
|
|
|
def __init__(self, connection_url: str):
|
|
self.connection_url = connection_url
|
|
self._pool: asyncpg.Pool | None = None
|
|
|
|
async def connect(self) -> None:
|
|
self._pool = await asyncpg.create_pool(
|
|
self.connection_url,
|
|
min_size=2,
|
|
max_size=10,
|
|
)
|
|
logger.info("timescaledb_connected")
|
|
|
|
async def disconnect(self) -> None:
|
|
if self._pool:
|
|
await self._pool.close()
|
|
self._pool = None
|
|
logger.info("timescaledb_disconnected")
|
|
|
|
async def insert_metrics(
|
|
self,
|
|
machine_id: str,
|
|
hostname: str,
|
|
timestamp_ms: int,
|
|
metrics: list[tuple[str, float, dict[str, str]]],
|
|
) -> int:
|
|
"""
|
|
Insert a batch of metrics.
|
|
|
|
Args:
|
|
machine_id: Machine identifier
|
|
hostname: Machine hostname
|
|
timestamp_ms: Timestamp in milliseconds
|
|
metrics: List of (metric_type, value, labels) tuples
|
|
|
|
Returns:
|
|
Number of rows inserted
|
|
"""
|
|
if not self._pool:
|
|
raise RuntimeError("Not connected to TimescaleDB")
|
|
|
|
timestamp = datetime.utcfromtimestamp(timestamp_ms / 1000)
|
|
|
|
# Prepare batch insert
|
|
rows = [
|
|
(timestamp, machine_id, hostname, metric_type, value, json.dumps(labels))
|
|
for metric_type, value, labels in metrics
|
|
]
|
|
|
|
async with self._pool.acquire() as conn:
|
|
await conn.executemany(
|
|
"""
|
|
INSERT INTO metrics_raw (time, machine_id, hostname, metric_type, value, labels)
|
|
VALUES ($1, $2, $3, $4, $5, $6)
|
|
""",
|
|
rows,
|
|
)
|
|
|
|
return len(rows)
|
|
|
|
async def update_machine_registry(
|
|
self,
|
|
machine_id: str,
|
|
hostname: str,
|
|
health: str = "HEALTHY",
|
|
) -> None:
|
|
"""Update the machines registry with last seen time."""
|
|
if not self._pool:
|
|
raise RuntimeError("Not connected to TimescaleDB")
|
|
|
|
async with self._pool.acquire() as conn:
|
|
await conn.execute(
|
|
"""
|
|
INSERT INTO machines (machine_id, hostname, last_seen, health)
|
|
VALUES ($1, $2, NOW(), $3)
|
|
ON CONFLICT (machine_id) DO UPDATE
|
|
SET hostname = $2, last_seen = NOW(), health = $3
|
|
""",
|
|
machine_id,
|
|
hostname,
|
|
health,
|
|
)
|
|
|
|
async def get_metrics(
|
|
self,
|
|
machine_id: str | None = None,
|
|
metric_type: str | None = None,
|
|
start_time: datetime | None = None,
|
|
end_time: datetime | None = None,
|
|
limit: int = 1000,
|
|
) -> list[dict[str, Any]]:
|
|
"""Query historical metrics."""
|
|
if not self._pool:
|
|
raise RuntimeError("Not connected to TimescaleDB")
|
|
|
|
conditions = []
|
|
params = []
|
|
param_idx = 1
|
|
|
|
if machine_id:
|
|
conditions.append(f"machine_id = ${param_idx}")
|
|
params.append(machine_id)
|
|
param_idx += 1
|
|
|
|
if metric_type:
|
|
conditions.append(f"metric_type = ${param_idx}")
|
|
params.append(metric_type)
|
|
param_idx += 1
|
|
|
|
if start_time:
|
|
conditions.append(f"time >= ${param_idx}")
|
|
params.append(start_time)
|
|
param_idx += 1
|
|
|
|
if end_time:
|
|
conditions.append(f"time <= ${param_idx}")
|
|
params.append(end_time)
|
|
param_idx += 1
|
|
|
|
where_clause = " AND ".join(conditions) if conditions else "TRUE"
|
|
|
|
query = f"""
|
|
SELECT time, machine_id, hostname, metric_type, value, labels
|
|
FROM metrics_raw
|
|
WHERE {where_clause}
|
|
ORDER BY time DESC
|
|
LIMIT ${param_idx}
|
|
"""
|
|
params.append(limit)
|
|
|
|
async with self._pool.acquire() as conn:
|
|
rows = await conn.fetch(query, *params)
|
|
|
|
return [
|
|
{
|
|
"time": row["time"].isoformat(),
|
|
"machine_id": row["machine_id"],
|
|
"hostname": row["hostname"],
|
|
"metric_type": row["metric_type"],
|
|
"value": row["value"],
|
|
"labels": json.loads(row["labels"]) if row["labels"] else {},
|
|
}
|
|
for row in rows
|
|
]
|