claude final draft
This commit is contained in:
1
services/aggregator/__init__.py
Normal file
1
services/aggregator/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Aggregator service."""
|
||||
361
services/aggregator/main.py
Normal file
361
services/aggregator/main.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""Aggregator service - gRPC server that receives metrics and stores them."""
|
||||
|
||||
import asyncio
|
||||
import signal
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import grpc
|
||||
from grpc_health.v1 import health, health_pb2, health_pb2_grpc
|
||||
|
||||
# Add project root to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from services.aggregator.storage import RedisStorage, TimescaleStorage
|
||||
from shared import metrics_pb2, metrics_pb2_grpc
|
||||
from shared.config import get_aggregator_config
|
||||
from shared.events import get_publisher
|
||||
from shared.logging import setup_logging
|
||||
|
||||
|
||||
class MetricsServicer(metrics_pb2_grpc.MetricsServiceServicer):
    """gRPC servicer for metrics ingestion.

    Buffers incoming metrics per machine and flushes them in batches to
    Redis (current state), TimescaleDB (history) and the event bus
    (subscribers such as alerts and the gateway).
    """

    # Flush the buffered batch to storage after this many metrics.
    BATCH_SIZE = 20

    def __init__(
        self,
        redis_storage: RedisStorage,
        timescale_storage: TimescaleStorage,
        event_publisher,
        logger,
    ):
        self.redis = redis_storage
        self.timescale = timescale_storage
        self.publisher = event_publisher
        self.logger = logger

    async def StreamMetrics(self, request_iterator, context):
        """Receive streaming metrics from a collector.

        Returns a StreamAck carrying success/failure and the number of
        metrics consumed from the stream. Errors are reported in the ack
        rather than raised, so the collector sees a clean response.
        """
        metrics_received = 0
        current_machine = None
        current_batch: list[tuple[str, float, dict]] = []
        batch_timestamp = 0
        batch_hostname = ""

        try:
            async for metric in request_iterator:
                metrics_received += 1

                # A stream normally carries a single machine, but flush the
                # buffered batch if the machine id changes mid-stream so
                # batches never mix machines.
                if current_machine != metric.machine_id:
                    if current_machine and current_batch:
                        await self._flush_batch(
                            current_machine,
                            batch_hostname,
                            batch_timestamp,
                            current_batch,
                        )
                        current_batch = []

                    current_machine = metric.machine_id
                    self.logger.info(
                        "collector_connected",
                        machine_id=metric.machine_id,
                        hostname=metric.hostname,
                    )

                # Enum value -> symbolic name (e.g. "CPU_PERCENT").
                metric_type = metrics_pb2.MetricType.Name(metric.type)

                current_batch.append(
                    (metric_type, metric.value, dict(metric.labels))
                )
                batch_timestamp = metric.timestamp_ms
                batch_hostname = metric.hostname

                if len(current_batch) >= self.BATCH_SIZE:
                    await self._flush_batch(
                        current_machine, batch_hostname, batch_timestamp, current_batch
                    )
                    current_batch = []

            # Flush whatever is left when the stream ends cleanly.
            if current_machine and current_batch:
                await self._flush_batch(
                    current_machine, batch_hostname, batch_timestamp, current_batch
                )

            self.logger.info(
                "stream_completed",
                machine_id=current_machine,
                metrics_received=metrics_received,
            )

            return metrics_pb2.StreamAck(
                success=True,
                metrics_received=metrics_received,
                message="OK",
            )

        except Exception as e:
            self.logger.error(
                "stream_error",
                error=str(e),
                machine_id=current_machine,
                metrics_received=metrics_received,
            )
            return metrics_pb2.StreamAck(
                success=False,
                metrics_received=metrics_received,
                message=str(e),
            )

    async def _flush_batch(
        self,
        machine_id: str,
        hostname: str,
        timestamp_ms: int,
        batch: list[tuple[str, float, dict]],
    ) -> None:
        """Flush a batch of metrics to storage and publish an event.

        Redis and the event bus are treated as required (exceptions
        propagate to the caller); TimescaleDB writes are best-effort and
        only logged on failure.
        """
        # Aggregate metrics for the Redis state snapshot. Labelled metrics
        # are keyed "TYPE:k=v,...". NOTE(review): a label value containing
        # "," cannot be fully re-parsed by _metrics_from_state.
        metrics_dict = {}
        for metric_type, value, labels in batch:
            key = metric_type
            if labels:
                key = f"{metric_type}:{','.join(f'{k}={v}' for k, v in labels.items())}"
            metrics_dict[key] = value

        # Update Redis (current state).
        await self.redis.update_machine_state(
            machine_id=machine_id,
            hostname=hostname,
            metrics=metrics_dict,
            timestamp_ms=timestamp_ms,
        )

        # Insert into TimescaleDB (historical) - best effort.
        try:
            await self.timescale.insert_metrics(
                machine_id=machine_id,
                hostname=hostname,
                timestamp_ms=timestamp_ms,
                metrics=batch,
            )
        except Exception as e:
            self.logger.warning("timescale_insert_failed", error=str(e))

        # Update machine registry - best effort.
        try:
            await self.timescale.update_machine_registry(
                machine_id=machine_id,
                hostname=hostname,
            )
        except Exception as e:
            self.logger.warning("machine_registry_update_failed", error=str(e))

        # Publish event for subscribers (alerts, gateway).
        await self.publisher.publish(
            topic="metrics.raw",
            payload={
                "machine_id": machine_id,
                "hostname": hostname,
                "timestamp_ms": timestamp_ms,
                "metrics": metrics_dict,
            },
        )

        self.logger.debug(
            "batch_flushed",
            machine_id=machine_id,
            count=len(batch),
        )

    def _metrics_from_state(self, state, include_labels: bool):
        """Convert a Redis state dict into a list of Metric protos.

        Keys are "TYPE" or "TYPE:k=v,..." as written by _flush_batch.
        include_labels controls whether the label suffix is parsed back
        into the proto's labels map (GetAllStates skips it, matching the
        lighter all-machines payload).
        """
        metrics = []
        for key, value in state.get("metrics", {}).items():
            # partition() keeps everything after the first ":" so label
            # values containing ":" are not truncated.
            type_name, _, label_part = key.partition(":")
            labels = {}
            if include_labels and label_part:
                for pair in label_part.split(","):
                    # partition() instead of split("=") so a "=" inside a
                    # label value cannot raise ValueError.
                    k, _, v = pair.partition("=")
                    labels[k] = v

            # Unknown type names fall back to 0 (the proto enum default).
            metric_type = getattr(metrics_pb2, type_name, 0)
            metrics.append(
                metrics_pb2.Metric(
                    machine_id=state["machine_id"],
                    hostname=state["hostname"],
                    timestamp_ms=state["last_seen_ms"],
                    type=metric_type,
                    value=value,
                    labels=labels,
                )
            )
        return metrics

    async def GetCurrentState(self, request, context):
        """Get current state for a single machine.

        Returns an empty MachineState with NOT_FOUND status if the machine
        is unknown or its state has expired.
        """
        state = await self.redis.get_machine_state(request.machine_id)

        if not state:
            context.set_code(grpc.StatusCode.NOT_FOUND)
            context.set_details(f"Machine {request.machine_id} not found")
            return metrics_pb2.MachineState()

        return metrics_pb2.MachineState(
            machine_id=state["machine_id"],
            hostname=state["hostname"],
            last_seen_ms=state["last_seen_ms"],
            current_metrics=self._metrics_from_state(state, include_labels=True),
            health=metrics_pb2.HEALTHY,
        )

    async def GetAllStates(self, request, context):
        """Get current state for all machines (labels omitted per metric)."""
        states = await self.redis.get_all_machines()

        machine_states = [
            metrics_pb2.MachineState(
                machine_id=state["machine_id"],
                hostname=state["hostname"],
                last_seen_ms=state["last_seen_ms"],
                current_metrics=self._metrics_from_state(state, include_labels=False),
                health=metrics_pb2.HEALTHY,
            )
            for state in states
        ]

        return metrics_pb2.AllMachinesState(machines=machine_states)
|
||||
|
||||
|
||||
class AggregatorService:
    """Main aggregator service: owns storage, the event publisher and the gRPC server."""

    def __init__(self):
        self.config = get_aggregator_config()
        self.logger = setup_logging(
            service_name=self.config.service_name,
            log_level=self.config.log_level,
            log_format=self.config.log_format,
        )

        self.redis = RedisStorage(self.config.redis_url)
        self.timescale = TimescaleStorage(self.config.timescale_url)
        self.publisher = get_publisher(source="aggregator")

        self.server: grpc.aio.Server | None = None
        self.running = False

    async def start(self) -> None:
        """Connect to all backends and bring up the gRPC server."""
        self.running = True

        # Redis is required; TimescaleDB is optional (history only), so a
        # connection failure there is logged and the service keeps running.
        await self.redis.connect()
        try:
            await self.timescale.connect()
        except Exception as e:
            self.logger.warning(
                "timescale_connection_failed",
                error=str(e),
                message="Continuing without TimescaleDB - metrics won't be persisted",
            )

        await self.publisher.connect()

        self.server = grpc.aio.server()

        # Wire the metrics servicer into the server.
        metrics_servicer = MetricsServicer(
            redis_storage=self.redis,
            timescale_storage=self.timescale,
            event_publisher=self.publisher,
            logger=self.logger,
        )
        metrics_pb2_grpc.add_MetricsServiceServicer_to_server(metrics_servicer, self.server)

        # Standard gRPC health checks: overall ("") and per-service.
        health_svc = health.HealthServicer()
        health_svc.set("", health_pb2.HealthCheckResponse.SERVING)
        health_svc.set("MetricsService", health_pb2.HealthCheckResponse.SERVING)
        health_pb2_grpc.add_HealthServicer_to_server(health_svc, self.server)

        # Bind on all interfaces (IPv6 + IPv4) and start serving.
        bind_addr = f"[::]:{self.config.grpc_port}"
        self.server.add_insecure_port(bind_addr)
        await self.server.start()

        self.logger.info(
            "aggregator_started",
            port=self.config.grpc_port,
            listen_addr=bind_addr,
        )

    async def stop(self) -> None:
        """Gracefully stop the server, then disconnect from every backend."""
        self.running = False

        if self.server:
            await self.server.stop(grace=5)
            self.server = None

        await self.publisher.disconnect()
        await self.timescale.disconnect()
        await self.redis.disconnect()

        self.logger.info("aggregator_stopped")

    async def wait(self) -> None:
        """Block until the gRPC server terminates."""
        if self.server:
            await self.server.wait_for_termination()
|
||||
|
||||
|
||||
async def main():
    """Main entry point: run the aggregator until SIGTERM/SIGINT."""
    service = AggregatorService()

    # asyncio.get_event_loop() is deprecated inside a coroutine since
    # Python 3.10; use the running loop explicitly.
    loop = asyncio.get_running_loop()

    # Keep references to shutdown tasks so they are not garbage-collected
    # before completing.
    shutdown_tasks: list = []

    async def shutdown():
        service.logger.info("shutdown_signal_received")
        await service.stop()

    def _on_signal() -> None:
        shutdown_tasks.append(asyncio.create_task(shutdown()))

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, _on_signal)

    await service.start()
    await service.wait()


if __name__ == "__main__":
    asyncio.run(main())
|
||||
245
services/aggregator/storage.py
Normal file
245
services/aggregator/storage.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""Storage layer for metrics - Redis (current state) and TimescaleDB (historical)."""
|
||||
|
||||
import json
import time
from datetime import datetime, timezone
from typing import Any

import asyncpg
import redis.asyncio as redis

from shared.logging import get_logger
|
||||
|
||||
logger = get_logger("storage")
|
||||
|
||||
|
||||
class RedisStorage:
|
||||
"""Redis storage for current machine state."""
|
||||
|
||||
def __init__(self, redis_url: str):
|
||||
self.redis_url = redis_url
|
||||
self._client: redis.Redis | None = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
self._client = redis.from_url(self.redis_url, decode_responses=True)
|
||||
await self._client.ping()
|
||||
logger.info("redis_connected", url=self.redis_url)
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
if self._client:
|
||||
await self._client.close()
|
||||
self._client = None
|
||||
logger.info("redis_disconnected")
|
||||
|
||||
async def update_machine_state(
|
||||
self,
|
||||
machine_id: str,
|
||||
hostname: str,
|
||||
metrics: dict[str, float],
|
||||
timestamp_ms: int,
|
||||
) -> None:
|
||||
"""Update the current state for a machine."""
|
||||
if not self._client:
|
||||
raise RuntimeError("Not connected to Redis")
|
||||
|
||||
state = {
|
||||
"machine_id": machine_id,
|
||||
"hostname": hostname,
|
||||
"last_seen_ms": timestamp_ms,
|
||||
"metrics": metrics,
|
||||
"updated_at": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
# Store as hash for efficient partial reads
|
||||
key = f"machine:{machine_id}"
|
||||
await self._client.hset(
|
||||
key,
|
||||
mapping={
|
||||
"state": json.dumps(state),
|
||||
"last_seen": str(timestamp_ms),
|
||||
},
|
||||
)
|
||||
|
||||
# Set expiry - if no updates for 5 minutes, consider stale
|
||||
await self._client.expire(key, 300)
|
||||
|
||||
# Add to active machines set
|
||||
await self._client.sadd("machines:active", machine_id)
|
||||
|
||||
async def get_machine_state(self, machine_id: str) -> dict[str, Any] | None:
|
||||
"""Get current state for a machine."""
|
||||
if not self._client:
|
||||
raise RuntimeError("Not connected to Redis")
|
||||
|
||||
key = f"machine:{machine_id}"
|
||||
data = await self._client.hget(key, "state")
|
||||
|
||||
if data:
|
||||
return json.loads(data)
|
||||
return None
|
||||
|
||||
async def get_all_machines(self) -> list[dict[str, Any]]:
|
||||
"""Get current state for all active machines."""
|
||||
if not self._client:
|
||||
raise RuntimeError("Not connected to Redis")
|
||||
|
||||
machine_ids = await self._client.smembers("machines:active")
|
||||
states = []
|
||||
|
||||
for machine_id in machine_ids:
|
||||
state = await self.get_machine_state(machine_id)
|
||||
if state:
|
||||
states.append(state)
|
||||
else:
|
||||
# Remove stale machine from active set
|
||||
await self._client.srem("machines:active", machine_id)
|
||||
|
||||
return states
|
||||
|
||||
|
||||
class TimescaleStorage:
|
||||
"""TimescaleDB storage for historical metrics."""
|
||||
|
||||
def __init__(self, connection_url: str):
|
||||
self.connection_url = connection_url
|
||||
self._pool: asyncpg.Pool | None = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
self._pool = await asyncpg.create_pool(
|
||||
self.connection_url,
|
||||
min_size=2,
|
||||
max_size=10,
|
||||
)
|
||||
logger.info("timescaledb_connected")
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
if self._pool:
|
||||
await self._pool.close()
|
||||
self._pool = None
|
||||
logger.info("timescaledb_disconnected")
|
||||
|
||||
async def insert_metrics(
|
||||
self,
|
||||
machine_id: str,
|
||||
hostname: str,
|
||||
timestamp_ms: int,
|
||||
metrics: list[tuple[str, float, dict[str, str]]],
|
||||
) -> int:
|
||||
"""
|
||||
Insert a batch of metrics.
|
||||
|
||||
Args:
|
||||
machine_id: Machine identifier
|
||||
hostname: Machine hostname
|
||||
timestamp_ms: Timestamp in milliseconds
|
||||
metrics: List of (metric_type, value, labels) tuples
|
||||
|
||||
Returns:
|
||||
Number of rows inserted
|
||||
"""
|
||||
if not self._pool:
|
||||
raise RuntimeError("Not connected to TimescaleDB")
|
||||
|
||||
timestamp = datetime.utcfromtimestamp(timestamp_ms / 1000)
|
||||
|
||||
# Prepare batch insert
|
||||
rows = [
|
||||
(timestamp, machine_id, hostname, metric_type, value, json.dumps(labels))
|
||||
for metric_type, value, labels in metrics
|
||||
]
|
||||
|
||||
async with self._pool.acquire() as conn:
|
||||
await conn.executemany(
|
||||
"""
|
||||
INSERT INTO metrics_raw (time, machine_id, hostname, metric_type, value, labels)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
""",
|
||||
rows,
|
||||
)
|
||||
|
||||
return len(rows)
|
||||
|
||||
async def update_machine_registry(
|
||||
self,
|
||||
machine_id: str,
|
||||
hostname: str,
|
||||
health: str = "HEALTHY",
|
||||
) -> None:
|
||||
"""Update the machines registry with last seen time."""
|
||||
if not self._pool:
|
||||
raise RuntimeError("Not connected to TimescaleDB")
|
||||
|
||||
async with self._pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO machines (machine_id, hostname, last_seen, health)
|
||||
VALUES ($1, $2, NOW(), $3)
|
||||
ON CONFLICT (machine_id) DO UPDATE
|
||||
SET hostname = $2, last_seen = NOW(), health = $3
|
||||
""",
|
||||
machine_id,
|
||||
hostname,
|
||||
health,
|
||||
)
|
||||
|
||||
async def get_metrics(
|
||||
self,
|
||||
machine_id: str | None = None,
|
||||
metric_type: str | None = None,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
limit: int = 1000,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Query historical metrics."""
|
||||
if not self._pool:
|
||||
raise RuntimeError("Not connected to TimescaleDB")
|
||||
|
||||
conditions = []
|
||||
params = []
|
||||
param_idx = 1
|
||||
|
||||
if machine_id:
|
||||
conditions.append(f"machine_id = ${param_idx}")
|
||||
params.append(machine_id)
|
||||
param_idx += 1
|
||||
|
||||
if metric_type:
|
||||
conditions.append(f"metric_type = ${param_idx}")
|
||||
params.append(metric_type)
|
||||
param_idx += 1
|
||||
|
||||
if start_time:
|
||||
conditions.append(f"time >= ${param_idx}")
|
||||
params.append(start_time)
|
||||
param_idx += 1
|
||||
|
||||
if end_time:
|
||||
conditions.append(f"time <= ${param_idx}")
|
||||
params.append(end_time)
|
||||
param_idx += 1
|
||||
|
||||
where_clause = " AND ".join(conditions) if conditions else "TRUE"
|
||||
|
||||
query = f"""
|
||||
SELECT time, machine_id, hostname, metric_type, value, labels
|
||||
FROM metrics_raw
|
||||
WHERE {where_clause}
|
||||
ORDER BY time DESC
|
||||
LIMIT ${param_idx}
|
||||
"""
|
||||
params.append(limit)
|
||||
|
||||
async with self._pool.acquire() as conn:
|
||||
rows = await conn.fetch(query, *params)
|
||||
|
||||
return [
|
||||
{
|
||||
"time": row["time"].isoformat(),
|
||||
"machine_id": row["machine_id"],
|
||||
"hostname": row["hostname"],
|
||||
"metric_type": row["metric_type"],
|
||||
"value": row["value"],
|
||||
"labels": json.loads(row["labels"]) if row["labels"] else {},
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
Reference in New Issue
Block a user