claude final draft

This commit is contained in:
buenosairesam
2025-12-29 23:44:30 -03:00
parent 116d4032e2
commit e5aafd5097
22 changed files with 2815 additions and 32 deletions

View File

@@ -0,0 +1 @@
"""Aggregator service."""

361
services/aggregator/main.py Normal file
View File

@@ -0,0 +1,361 @@
"""Aggregator service - gRPC server that receives metrics and stores them."""
import asyncio
import signal
import sys
from pathlib import Path
import grpc
from grpc_health.v1 import health, health_pb2, health_pb2_grpc
# Add project root to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from services.aggregator.storage import RedisStorage, TimescaleStorage
from shared import metrics_pb2, metrics_pb2_grpc
from shared.config import get_aggregator_config
from shared.events import get_publisher
from shared.logging import setup_logging
class MetricsServicer(metrics_pb2_grpc.MetricsServiceServicer):
    """gRPC servicer for metrics ingestion.

    Collectors stream ``Metric`` messages; they are batched per machine and
    each batch is fanned out to Redis (current state), TimescaleDB
    (history, best-effort) and the ``metrics.raw`` event topic.
    """

    #: Flush to storage once a batch accumulates this many metrics.
    BATCH_SIZE = 20

    def __init__(
        self,
        redis_storage: RedisStorage,
        timescale_storage: TimescaleStorage,
        event_publisher,
        logger,
    ):
        self.redis = redis_storage
        self.timescale = timescale_storage
        self.publisher = event_publisher
        self.logger = logger

    @staticmethod
    def _parse_state_key(key: str) -> tuple[str, dict]:
        """Parse a flat state key (``TYPE`` or ``TYPE:k=v,k2=v2``) into
        ``(metric_type, labels)``.

        Splits only on the *first* ``:`` and the first ``=`` of each pair,
        so label values containing those characters no longer break parsing
        (the previous naive ``split("=")`` raised ValueError on values such
        as ``mount=C:/``). Label values containing ``,`` still cannot
        round-trip; see the key construction in ``_flush_batch``.
        """
        metric_type, sep, label_part = key.partition(":")
        labels: dict[str, str] = {}
        if sep:
            for pair in label_part.split(","):
                name, _, value = pair.partition("=")
                labels[name] = value
        return metric_type, labels

    async def StreamMetrics(self, request_iterator, context):
        """Receive streaming metrics from a collector.

        Returns a StreamAck (success=False instead of raising on error) so
        the collector always learns how many metrics were accepted.
        """
        metrics_received = 0
        current_machine = None
        current_batch: list[tuple[str, float, dict]] = []
        batch_timestamp = 0
        batch_hostname = ""
        try:
            async for metric in request_iterator:
                metrics_received += 1
                # Track current machine
                if current_machine != metric.machine_id:
                    # Flush previous batch if switching machines
                    if current_machine and current_batch:
                        await self._flush_batch(
                            current_machine,
                            batch_hostname,
                            batch_timestamp,
                            current_batch,
                        )
                        current_batch = []
                    current_machine = metric.machine_id
                    self.logger.info(
                        "collector_connected",
                        machine_id=metric.machine_id,
                        hostname=metric.hostname,
                    )
                # Enum value -> name (e.g. "CPU_PERCENT") used as storage key.
                metric_type = metrics_pb2.MetricType.Name(metric.type)
                current_batch.append(
                    (
                        metric_type,
                        metric.value,
                        dict(metric.labels),
                    )
                )
                batch_timestamp = metric.timestamp_ms
                batch_hostname = metric.hostname
                # Flush once enough metrics have accumulated.
                if len(current_batch) >= self.BATCH_SIZE:
                    await self._flush_batch(
                        current_machine, batch_hostname, batch_timestamp, current_batch
                    )
                    current_batch = []
            # Stream ended cleanly - flush the remainder.
            if current_machine and current_batch:
                await self._flush_batch(
                    current_machine, batch_hostname, batch_timestamp, current_batch
                )
            self.logger.info(
                "stream_completed",
                machine_id=current_machine,
                metrics_received=metrics_received,
            )
            return metrics_pb2.StreamAck(
                success=True,
                metrics_received=metrics_received,
                message="OK",
            )
        except Exception as e:
            self.logger.error(
                "stream_error",
                error=str(e),
                machine_id=current_machine,
                metrics_received=metrics_received,
            )
            return metrics_pb2.StreamAck(
                success=False,
                metrics_received=metrics_received,
                message=str(e),
            )

    async def _flush_batch(
        self,
        machine_id: str,
        hostname: str,
        timestamp_ms: int,
        batch: list[tuple[str, float, dict]],
    ) -> None:
        """Flush a batch of metrics to storage and events.

        Redis and the event bus are required (their errors propagate to the
        stream handler); TimescaleDB writes are best-effort and only logged.
        """
        # Flatten to {key: value} for the Redis state hash; labelled metrics
        # use a composite "TYPE:k=v,..." key (inverse: _parse_state_key).
        metrics_dict = {}
        for metric_type, value, labels in batch:
            key = metric_type
            if labels:
                key = f"{metric_type}:{','.join(f'{k}={v}' for k, v in labels.items())}"
            metrics_dict[key] = value
        # Update Redis (current state)
        await self.redis.update_machine_state(
            machine_id=machine_id,
            hostname=hostname,
            metrics=metrics_dict,
            timestamp_ms=timestamp_ms,
        )
        # Insert into TimescaleDB (historical) - best-effort.
        try:
            await self.timescale.insert_metrics(
                machine_id=machine_id,
                hostname=hostname,
                timestamp_ms=timestamp_ms,
                metrics=batch,
            )
        except Exception as e:
            self.logger.warning("timescale_insert_failed", error=str(e))
        # Update machine registry - best-effort.
        try:
            await self.timescale.update_machine_registry(
                machine_id=machine_id,
                hostname=hostname,
            )
        except Exception as e:
            self.logger.warning("machine_registry_update_failed", error=str(e))
        # Publish event for subscribers (alerts, gateway)
        await self.publisher.publish(
            topic="metrics.raw",
            payload={
                "machine_id": machine_id,
                "hostname": hostname,
                "timestamp_ms": timestamp_ms,
                "metrics": metrics_dict,
            },
        )
        self.logger.debug(
            "batch_flushed",
            machine_id=machine_id,
            count=len(batch),
        )

    async def GetCurrentState(self, request, context):
        """Get current state for a single machine (NOT_FOUND when unknown)."""
        state = await self.redis.get_machine_state(request.machine_id)
        if not state:
            context.set_code(grpc.StatusCode.NOT_FOUND)
            context.set_details(f"Machine {request.machine_id} not found")
            return metrics_pb2.MachineState()
        # Rebuild Metric protos from the flattened Redis state.
        metrics = []
        for key, value in state.get("metrics", {}).items():
            metric_type_str, labels = self._parse_state_key(key)
            # Unknown/renamed metric types fall back to enum value 0.
            metric_type = getattr(metrics_pb2, metric_type_str, 0)
            metrics.append(
                metrics_pb2.Metric(
                    machine_id=state["machine_id"],
                    hostname=state["hostname"],
                    timestamp_ms=state["last_seen_ms"],
                    type=metric_type,
                    value=value,
                    labels=labels,
                )
            )
        return metrics_pb2.MachineState(
            machine_id=state["machine_id"],
            hostname=state["hostname"],
            last_seen_ms=state["last_seen_ms"],
            current_metrics=metrics,
            health=metrics_pb2.HEALTHY,
        )

    async def GetAllStates(self, request, context):
        """Get current state for all machines.

        Labels are intentionally omitted here (matching the original
        behavior) to keep the all-machines payload small.
        """
        states = await self.redis.get_all_machines()
        machine_states = []
        for state in states:
            metrics = []
            for key, value in state.get("metrics", {}).items():
                metric_type_str, _labels = self._parse_state_key(key)
                metric_type = getattr(metrics_pb2, metric_type_str, 0)
                metrics.append(
                    metrics_pb2.Metric(
                        machine_id=state["machine_id"],
                        hostname=state["hostname"],
                        timestamp_ms=state["last_seen_ms"],
                        type=metric_type,
                        value=value,
                    )
                )
            machine_states.append(
                metrics_pb2.MachineState(
                    machine_id=state["machine_id"],
                    hostname=state["hostname"],
                    last_seen_ms=state["last_seen_ms"],
                    current_metrics=metrics,
                    health=metrics_pb2.HEALTHY,
                )
            )
        return metrics_pb2.AllMachinesState(machines=machine_states)
class AggregatorService:
    """Main aggregator service.

    Owns the gRPC server, the Redis/TimescaleDB storage backends and the
    event publisher, and wires a MetricsServicer into the server.
    """

    def __init__(self):
        self.config = get_aggregator_config()
        self.logger = setup_logging(
            service_name=self.config.service_name,
            log_level=self.config.log_level,
            log_format=self.config.log_format,
        )
        # Redis holds current state; TimescaleDB holds history.
        self.redis = RedisStorage(self.config.redis_url)
        self.timescale = TimescaleStorage(self.config.timescale_url)
        self.publisher = get_publisher(source="aggregator")
        self.server: grpc.aio.Server | None = None
        self.running = False

    async def start(self) -> None:
        """Start the gRPC server.

        Connects Redis (required - errors propagate), TimescaleDB
        (optional - failure only disables persistence) and the event
        publisher, then starts serving.
        """
        self.running = True
        # Connect to storage
        await self.redis.connect()
        try:
            await self.timescale.connect()
        except Exception as e:
            # Degrade gracefully: current-state serving still works without
            # the historical store.
            self.logger.warning(
                "timescale_connection_failed",
                error=str(e),
                message="Continuing without TimescaleDB - metrics won't be persisted",
            )
        # Connect to event publisher
        await self.publisher.connect()
        # Create gRPC server
        self.server = grpc.aio.server()
        # Add metrics servicer
        servicer = MetricsServicer(
            redis_storage=self.redis,
            timescale_storage=self.timescale,
            event_publisher=self.publisher,
            logger=self.logger,
        )
        metrics_pb2_grpc.add_MetricsServiceServicer_to_server(servicer, self.server)
        # Add health check servicer (standard gRPC health protocol, used by
        # container orchestration probes).
        health_servicer = health.HealthServicer()
        health_servicer.set("", health_pb2.HealthCheckResponse.SERVING)
        health_servicer.set("MetricsService", health_pb2.HealthCheckResponse.SERVING)
        health_pb2_grpc.add_HealthServicer_to_server(health_servicer, self.server)
        # Start server; "[::]" binds all interfaces.
        listen_addr = f"[::]:{self.config.grpc_port}"
        self.server.add_insecure_port(listen_addr)
        await self.server.start()
        self.logger.info(
            "aggregator_started",
            port=self.config.grpc_port,
            listen_addr=listen_addr,
        )

    async def stop(self) -> None:
        """Stop the gRPC server.

        Gives in-flight RPCs up to 5 seconds to drain, then tears down the
        publisher and storage connections.
        """
        self.running = False
        if self.server:
            await self.server.stop(grace=5)
            self.server = None
        await self.publisher.disconnect()
        await self.timescale.disconnect()
        await self.redis.disconnect()
        self.logger.info("aggregator_stopped")

    async def wait(self) -> None:
        """Wait for the server to terminate."""
        if self.server:
            await self.server.wait_for_termination()
async def main():
    """Main entry point: run the aggregator until SIGTERM/SIGINT."""
    service = AggregatorService()
    # Handle shutdown signals. get_running_loop() is the supported way to
    # obtain the loop from inside a coroutine (asyncio.get_event_loop() is
    # deprecated in this context since Python 3.10).
    loop = asyncio.get_running_loop()

    async def shutdown():
        service.logger.info("shutdown_signal_received")
        await service.stop()

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, lambda: asyncio.create_task(shutdown()))
    await service.start()
    await service.wait()


if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -0,0 +1,245 @@
"""Storage layer for metrics - Redis (current state) and TimescaleDB (historical)."""
import json
import time
from datetime import datetime
from typing import Any
import asyncpg
import redis.asyncio as redis
from shared.logging import get_logger
# Module-level structured logger shared by both storage backends below.
logger = get_logger("storage")
class RedisStorage:
    """Redis storage for current machine state.

    Layout: one hash per machine under ``machine:{id}`` (fields ``state``
    - JSON blob - and ``last_seen``), plus a ``machines:active`` set used
    as an index for listing.
    """

    def __init__(self, redis_url: str):
        self.redis_url = redis_url
        self._client: redis.Redis | None = None

    async def connect(self) -> None:
        """Open the connection and verify it with a PING."""
        self._client = redis.from_url(self.redis_url, decode_responses=True)
        await self._client.ping()
        logger.info("redis_connected", url=self.redis_url)

    async def disconnect(self) -> None:
        """Close the connection; no-op when not connected."""
        if self._client:
            await self._client.close()
            self._client = None
            logger.info("redis_disconnected")

    async def update_machine_state(
        self,
        machine_id: str,
        hostname: str,
        metrics: dict[str, float],
        timestamp_ms: int,
    ) -> None:
        """Update the current state for a machine.

        Raises:
            RuntimeError: if called before connect().
        """
        if not self._client:
            raise RuntimeError("Not connected to Redis")
        state = {
            "machine_id": machine_id,
            "hostname": hostname,
            "last_seen_ms": timestamp_ms,
            "metrics": metrics,
            # NOTE(review): datetime.utcnow() is deprecated since Python
            # 3.12; informational field only - consider datetime.now(UTC).
            "updated_at": datetime.utcnow().isoformat(),
        }
        # Store as hash for efficient partial reads
        key = f"machine:{machine_id}"
        await self._client.hset(
            key,
            mapping={
                "state": json.dumps(state),
                "last_seen": str(timestamp_ms),
            },
        )
        # Set expiry - if no updates for 5 minutes, consider stale
        await self._client.expire(key, 300)
        # Add to active machines set
        await self._client.sadd("machines:active", machine_id)

    async def get_machine_state(self, machine_id: str) -> dict[str, Any] | None:
        """Get current state for a machine, or None when unknown/expired.

        Raises:
            RuntimeError: if called before connect().
        """
        if not self._client:
            raise RuntimeError("Not connected to Redis")
        key = f"machine:{machine_id}"
        data = await self._client.hget(key, "state")
        if data:
            return json.loads(data)
        return None

    async def get_all_machines(self) -> list[dict[str, Any]]:
        """Get current state for all active machines.

        Machines whose per-machine hash has expired are lazily pruned from
        the ``machines:active`` index as a side effect.
        """
        if not self._client:
            raise RuntimeError("Not connected to Redis")
        machine_ids = await self._client.smembers("machines:active")
        states = []
        for machine_id in machine_ids:
            state = await self.get_machine_state(machine_id)
            if state:
                states.append(state)
            else:
                # Remove stale machine from active set
                await self._client.srem("machines:active", machine_id)
        return states
class TimescaleStorage:
    """TimescaleDB storage for historical metrics.

    Writes go to the ``metrics_raw`` hypertable and the ``machines``
    registry table; reads come from ``metrics_raw``.
    """

    def __init__(self, connection_url: str):
        self.connection_url = connection_url
        self._pool: asyncpg.Pool | None = None

    async def connect(self) -> None:
        """Create the connection pool (2-10 connections)."""
        self._pool = await asyncpg.create_pool(
            self.connection_url,
            min_size=2,
            max_size=10,
        )
        logger.info("timescaledb_connected")

    async def disconnect(self) -> None:
        """Close the pool; no-op when not connected."""
        if self._pool:
            await self._pool.close()
            self._pool = None
            logger.info("timescaledb_disconnected")

    async def insert_metrics(
        self,
        machine_id: str,
        hostname: str,
        timestamp_ms: int,
        metrics: list[tuple[str, float, dict[str, str]]],
    ) -> int:
        """
        Insert a batch of metrics.

        Args:
            machine_id: Machine identifier
            hostname: Machine hostname
            timestamp_ms: Timestamp in milliseconds (applied to every row)
            metrics: List of (metric_type, value, labels) tuples

        Returns:
            Number of rows inserted

        Raises:
            RuntimeError: if called before connect().
        """
        if not self._pool:
            raise RuntimeError("Not connected to TimescaleDB")
        # NOTE(review): utcfromtimestamp() is deprecated since Python 3.12
        # and yields a naive datetime - confirm the column is timezone-less.
        timestamp = datetime.utcfromtimestamp(timestamp_ms / 1000)
        # Prepare batch insert; labels are stored as a JSON string.
        rows = [
            (timestamp, machine_id, hostname, metric_type, value, json.dumps(labels))
            for metric_type, value, labels in metrics
        ]
        async with self._pool.acquire() as conn:
            await conn.executemany(
                """
                INSERT INTO metrics_raw (time, machine_id, hostname, metric_type, value, labels)
                VALUES ($1, $2, $3, $4, $5, $6)
                """,
                rows,
            )
        return len(rows)

    async def update_machine_registry(
        self,
        machine_id: str,
        hostname: str,
        health: str = "HEALTHY",
    ) -> None:
        """Update the machines registry with last seen time.

        Upserts on machine_id so repeated calls refresh hostname/health.

        Raises:
            RuntimeError: if called before connect().
        """
        if not self._pool:
            raise RuntimeError("Not connected to TimescaleDB")
        async with self._pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO machines (machine_id, hostname, last_seen, health)
                VALUES ($1, $2, NOW(), $3)
                ON CONFLICT (machine_id) DO UPDATE
                SET hostname = $2, last_seen = NOW(), health = $3
                """,
                machine_id,
                hostname,
                health,
            )

    async def get_metrics(
        self,
        machine_id: str | None = None,
        metric_type: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        limit: int = 1000,
    ) -> list[dict[str, Any]]:
        """Query historical metrics, newest first.

        Only the WHERE-clause *structure* is built dynamically; every value
        is passed as a bind parameter, so no SQL injection is possible.

        Raises:
            RuntimeError: if called before connect().
        """
        if not self._pool:
            raise RuntimeError("Not connected to TimescaleDB")
        conditions = []
        params = []
        param_idx = 1
        if machine_id:
            conditions.append(f"machine_id = ${param_idx}")
            params.append(machine_id)
            param_idx += 1
        if metric_type:
            conditions.append(f"metric_type = ${param_idx}")
            params.append(metric_type)
            param_idx += 1
        if start_time:
            conditions.append(f"time >= ${param_idx}")
            params.append(start_time)
            param_idx += 1
        if end_time:
            conditions.append(f"time <= ${param_idx}")
            params.append(end_time)
            param_idx += 1
        where_clause = " AND ".join(conditions) if conditions else "TRUE"
        query = f"""
            SELECT time, machine_id, hostname, metric_type, value, labels
            FROM metrics_raw
            WHERE {where_clause}
            ORDER BY time DESC
            LIMIT ${param_idx}
        """
        params.append(limit)
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(query, *params)
        return [
            {
                "time": row["time"].isoformat(),
                "machine_id": row["machine_id"],
                "hostname": row["hostname"],
                "metric_type": row["metric_type"],
                "value": row["value"],
                "labels": json.loads(row["labels"]) if row["labels"] else {},
            }
            for row in rows
        ]

View File

@@ -14,6 +14,12 @@ RUN pip install --no-cache-dir -r requirements.txt
COPY shared /app/shared
COPY proto /app/proto
RUN python -m grpc_tools.protoc \
-I/app/proto \
--python_out=/app/shared \
--grpc_python_out=/app/shared \
/app/proto/metrics.proto
COPY services/alerts /app/services/alerts
ENV PYTHONPATH=/app

View File

@@ -0,0 +1 @@
"""Alerts service."""

317
services/alerts/main.py Normal file
View File

@@ -0,0 +1,317 @@
"""Alerts service - subscribes to metrics events and evaluates thresholds."""
import asyncio
import signal
import sys
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any
import asyncpg
# Add project root to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from shared.config import get_alerts_config
from shared.events import get_publisher, get_subscriber
from shared.logging import setup_logging
@dataclass
class AlertRule:
    """An alert rule configuration (one threshold on one metric type)."""

    id: int  # database primary key (or synthetic id for defaults)
    name: str  # human-readable rule name, also used in alert de-dup keys
    metric_type: str  # metric type name, e.g. "CPU_PERCENT"
    operator: str  # gt, lt, gte, lte, eq
    threshold: float  # value compared against via `operator`
    severity: str  # warning, critical
    enabled: bool  # disabled rules are skipped by the evaluator
@dataclass
class Alert:
    """A triggered alert (a rule that fired for a specific machine)."""

    rule: AlertRule  # the rule that fired
    machine_id: str  # machine the offending metric came from
    value: float  # the observed metric value
    triggered_at: datetime  # when the threshold crossing was detected
class AlertEvaluator:
    """Evaluates metrics against alert rules.

    Fix: the original indexed rules as ``dict[metric_type] = rule``, so
    when several rules target the same metric type - exactly what the
    default rule set defines (warning at 80% and critical at 95% for CPU,
    memory and disk) - only the last one survived and the others could
    never fire. Rules are now grouped per metric type and all of them are
    evaluated.
    """

    OPERATORS = {
        "gt": lambda v, t: v > t,
        "lt": lambda v, t: v < t,
        "gte": lambda v, t: v >= t,
        "lte": lambda v, t: v <= t,
        "eq": lambda v, t: v == t,
    }

    def __init__(self, rules: list[AlertRule]):
        # metric_type -> list of enabled rules for that type.
        self.rules: dict[str, list[AlertRule]] = {}
        # Track active alerts to avoid duplicates
        self.active_alerts: dict[str, Alert] = {}  # key: f"{machine_id}:{rule_name}"
        self.update_rules(rules)

    def evaluate(self, machine_id: str, metrics: dict[str, float]) -> list[Alert]:
        """Evaluate metrics against all matching rules and return alerts
        that newly triggered on this call (already-active ones repeat
        neither in the return value nor in storage)."""
        new_alerts: list[Alert] = []
        for metric_type, value in metrics.items():
            for rule in self.rules.get(metric_type, ()):
                op_func = self.OPERATORS.get(rule.operator)
                if not op_func:
                    # Unknown operator: skip rather than crash.
                    continue
                alert_key = f"{machine_id}:{rule.name}"
                if op_func(value, rule.threshold):
                    # Threshold exceeded - raise only if not already active.
                    if alert_key not in self.active_alerts:
                        alert = Alert(
                            rule=rule,
                            machine_id=machine_id,
                            value=value,
                            triggered_at=datetime.utcnow(),
                        )
                        self.active_alerts[alert_key] = alert
                        new_alerts.append(alert)
                else:
                    # Threshold no longer exceeded - resolve so the alert
                    # can trigger again on the next crossing.
                    self.active_alerts.pop(alert_key, None)
        return new_alerts

    def update_rules(self, rules: list[AlertRule]) -> None:
        """Replace the active rule set (disabled rules are ignored)."""
        grouped: dict[str, list[AlertRule]] = {}
        for rule in rules:
            if rule.enabled:
                grouped.setdefault(rule.metric_type, []).append(rule)
        self.rules = grouped
class AlertsService:
    """Main alerts service.

    Consumes ``metrics.raw`` events, evaluates them with AlertEvaluator,
    stores triggered alerts in TimescaleDB (best-effort) and republishes
    them on ``alerts.<severity>`` topics.
    """

    def __init__(self):
        self.config = get_alerts_config()
        self.logger = setup_logging(
            service_name=self.config.service_name,
            log_level=self.config.log_level,
            log_format=self.config.log_format,
        )
        self.running = False
        # db_pool stays None when the database is unreachable; the service
        # then runs with built-in default rules and skips persistence.
        self.db_pool: asyncpg.Pool | None = None
        self.evaluator: AlertEvaluator | None = None
        self.subscriber = get_subscriber(topics=["metrics.raw"])
        self.publisher = get_publisher(source="alerts")

    async def connect_db(self) -> None:
        """Connect to TimescaleDB for rules and alert storage.

        Failure is non-fatal: evaluation continues with default rules,
        alerts simply are not persisted.
        """
        try:
            self.db_pool = await asyncpg.create_pool(
                self.config.timescale_url,
                min_size=1,
                max_size=5,
            )
            self.logger.info("database_connected")
        except Exception as e:
            self.logger.warning("database_connection_failed", error=str(e))
            self.db_pool = None

    async def load_rules(self) -> list[AlertRule]:
        """Load alert rules from database.

        Falls back to a hard-coded default rule set when no database
        connection is available.
        """
        if not self.db_pool:
            # Return default rules if no database
            return [
                AlertRule(
                    1, "High CPU Usage", "CPU_PERCENT", "gt", 80.0, "warning", True
                ),
                AlertRule(
                    2, "Critical CPU Usage", "CPU_PERCENT", "gt", 95.0, "critical", True
                ),
                AlertRule(
                    3,
                    "High Memory Usage",
                    "MEMORY_PERCENT",
                    "gt",
                    85.0,
                    "warning",
                    True,
                ),
                AlertRule(
                    4,
                    "Critical Memory Usage",
                    "MEMORY_PERCENT",
                    "gt",
                    95.0,
                    "critical",
                    True,
                ),
                AlertRule(
                    5, "High Disk Usage", "DISK_PERCENT", "gt", 80.0, "warning", True
                ),
                AlertRule(
                    6,
                    "Critical Disk Usage",
                    "DISK_PERCENT",
                    "gt",
                    90.0,
                    "critical",
                    True,
                ),
            ]
        async with self.db_pool.acquire() as conn:
            rows = await conn.fetch(
                "SELECT id, name, metric_type, operator, threshold, severity, enabled FROM alert_rules"
            )
        return [
            AlertRule(
                id=row["id"],
                name=row["name"],
                metric_type=row["metric_type"],
                operator=row["operator"],
                threshold=row["threshold"],
                severity=row["severity"],
                enabled=row["enabled"],
            )
            for row in rows
        ]

    async def store_alert(self, alert: Alert) -> None:
        """Store triggered alert in database (best-effort; silently skipped
        when there is no database connection)."""
        if not self.db_pool:
            return
        try:
            async with self.db_pool.acquire() as conn:
                await conn.execute(
                    """
                    INSERT INTO alerts (time, machine_id, rule_id, rule_name, metric_type, value, threshold, severity)
                    VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                    """,
                    alert.triggered_at,
                    alert.machine_id,
                    alert.rule.id,
                    alert.rule.name,
                    alert.rule.metric_type,
                    alert.value,
                    alert.rule.threshold,
                    alert.rule.severity,
                )
        except Exception as e:
            self.logger.warning("alert_storage_failed", error=str(e))

    async def publish_alert(self, alert: Alert) -> None:
        """Publish alert event for other services (e.g., notifications).

        Topic is severity-scoped (``alerts.warning`` / ``alerts.critical``)
        so subscribers can filter by severity.
        """
        await self.publisher.publish(
            topic=f"alerts.{alert.rule.severity}",
            payload={
                "rule_name": alert.rule.name,
                "machine_id": alert.machine_id,
                "metric_type": alert.rule.metric_type,
                "value": alert.value,
                "threshold": alert.rule.threshold,
                "severity": alert.rule.severity,
                "triggered_at": alert.triggered_at.isoformat(),
            },
        )

    async def process_metrics(self, event_data: dict[str, Any]) -> None:
        """Process one incoming metrics event: evaluate, store, publish."""
        if not self.evaluator:
            return
        machine_id = event_data.get("machine_id", "unknown")
        metrics = event_data.get("metrics", {})
        alerts = self.evaluator.evaluate(machine_id, metrics)
        for alert in alerts:
            self.logger.warning(
                "alert_triggered",
                rule=alert.rule.name,
                machine_id=alert.machine_id,
                value=alert.value,
                threshold=alert.rule.threshold,
                severity=alert.rule.severity,
            )
            await self.store_alert(alert)
            await self.publish_alert(alert)

    async def run(self) -> None:
        """Main service loop: connect, load rules, consume events until
        stopped or cancelled, then tear everything down."""
        self.running = True
        self.logger.info("alerts_service_starting")
        # Connect to database
        await self.connect_db()
        # Load rules
        rules = await self.load_rules()
        self.evaluator = AlertEvaluator(rules)
        self.logger.info("rules_loaded", count=len(rules))
        # Connect to event bus
        await self.subscriber.connect()
        await self.publisher.connect()
        self.logger.info("alerts_service_started")
        try:
            # Process events; a failure on one event is logged and must not
            # kill the consumer loop.
            async for event in self.subscriber.consume():
                if not self.running:
                    break
                try:
                    await self.process_metrics(event.payload)
                except Exception as e:
                    self.logger.error("event_processing_error", error=str(e))
        except asyncio.CancelledError:
            self.logger.info("alerts_service_cancelled")
        finally:
            await self.subscriber.disconnect()
            await self.publisher.disconnect()
            if self.db_pool:
                await self.db_pool.close()
            self.logger.info("alerts_service_stopped")

    def stop(self) -> None:
        """Signal the service to stop (checked between consumed events)."""
        self.running = False
async def main():
    """Main entry point: run the alerts service until SIGTERM/SIGINT."""
    service = AlertsService()
    # Handle shutdown signals. get_running_loop() is the supported way to
    # obtain the loop from inside a coroutine (asyncio.get_event_loop() is
    # deprecated in this context since Python 3.10).
    loop = asyncio.get_running_loop()

    def signal_handler():
        service.logger.info("shutdown_signal_received")
        service.stop()

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)
    await service.run()


if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -1,3 +1,5 @@
grpcio>=1.60.0
grpcio-tools>=1.60.0
redis>=5.0.0
asyncpg>=0.29.0
structlog>=23.2.0

View File

@@ -0,0 +1 @@
"""Collector service."""

209
services/collector/main.py Normal file
View File

@@ -0,0 +1,209 @@
"""Collector service - streams system metrics to the aggregator via gRPC."""
import asyncio
import signal
import sys
from pathlib import Path
import grpc
# Add project root to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from services.collector.metrics import MetricsCollector
from shared import metrics_pb2, metrics_pb2_grpc
from shared.config import get_collector_config
from shared.logging import setup_logging
class CollectorService:
    """Main collector service that streams metrics to the aggregator.

    Holds one long-lived gRPC client stream; on stream failure it
    reconnects with exponential backoff.
    """

    def __init__(self):
        self.config = get_collector_config()
        self.logger = setup_logging(
            service_name=self.config.service_name,
            log_level=self.config.log_level,
            log_format=self.config.log_format,
        )
        self.running = False
        self.channel: grpc.aio.Channel | None = None
        self.stub: metrics_pb2_grpc.MetricsServiceStub | None = None
        self.collector = MetricsCollector(
            machine_id=self.config.machine_id,
            collect_cpu=self.config.collect_cpu,
            collect_memory=self.config.collect_memory,
            collect_disk=self.config.collect_disk,
            collect_network=self.config.collect_network,
            collect_load=self.config.collect_load,
        )

    async def connect(self) -> None:
        """Establish connection to the aggregator.

        Raises:
            asyncio.TimeoutError: if the channel is not ready within 10s.
        """
        self.logger.info(
            "connecting_to_aggregator",
            aggregator_url=self.config.aggregator_url,
        )
        # Keepalives detect dead aggregators even while the stream is idle
        # between collection intervals.
        self.channel = grpc.aio.insecure_channel(
            self.config.aggregator_url,
            options=[
                ("grpc.keepalive_time_ms", 10000),
                ("grpc.keepalive_timeout_ms", 5000),
                ("grpc.keepalive_permit_without_calls", True),
            ],
        )
        self.stub = metrics_pb2_grpc.MetricsServiceStub(self.channel)
        # Wait for channel to be ready
        try:
            await asyncio.wait_for(
                self.channel.channel_ready(),
                timeout=10.0,
            )
            self.logger.info("connected_to_aggregator")
        except asyncio.TimeoutError:
            self.logger.error("connection_timeout")
            raise

    async def disconnect(self) -> None:
        """Close connection to the aggregator (no-op when not connected)."""
        if self.channel:
            await self.channel.close()
            self.channel = None
            self.stub = None
            self.logger.info("disconnected_from_aggregator")

    def _batch_to_proto(self, batch) -> list[metrics_pb2.Metric]:
        """Convert a MetricsBatch to protobuf messages.

        Each MetricPoint becomes one Metric message stamped with the
        batch-wide machine_id/hostname/timestamp.
        """
        protos = []
        for metric in batch.metrics:
            proto = metrics_pb2.Metric(
                machine_id=batch.machine_id,
                hostname=batch.hostname,
                timestamp_ms=batch.timestamp_ms,
                # Metric type name -> enum value; unknown names fall back to 0.
                type=getattr(metrics_pb2, metric.metric_type, 0),
                value=metric.value,
                labels=metric.labels,
            )
            protos.append(proto)
        return protos

    async def _metric_generator(self):
        """Async generator that yields metrics at the configured interval.

        Terminates when stop() clears self.running, which ends the client
        stream gracefully.
        """
        while self.running:
            batch = self.collector.collect()
            protos = self._batch_to_proto(batch)
            for proto in protos:
                yield proto
            self.logger.debug(
                "collected_metrics",
                count=len(protos),
                machine_id=batch.machine_id,
            )
            await asyncio.sleep(self.config.collection_interval)

    async def stream_metrics(self) -> None:
        """Stream metrics to the aggregator.

        Retries failed streams with exponential backoff (capped at 60s);
        the retry counter resets after any stream completes cleanly.

        Raises:
            RuntimeError: if called before connect().
            grpc.aio.AioRpcError: after max_retries consecutive failures.
        """
        if not self.stub:
            raise RuntimeError("Not connected to aggregator")
        retry_count = 0
        max_retries = 10
        base_delay = 1.0
        while self.running:
            try:
                self.logger.info("starting_metric_stream")
                # Blocks until the generator (and thus the stream) ends.
                response = await self.stub.StreamMetrics(self._metric_generator())
                self.logger.info(
                    "stream_completed",
                    success=response.success,
                    metrics_received=response.metrics_received,
                    message=response.message,
                )
                retry_count = 0
            except grpc.aio.AioRpcError as e:
                retry_count += 1
                delay = min(base_delay * (2**retry_count), 60.0)
                self.logger.warning(
                    "stream_error",
                    code=e.code().name,
                    details=e.details(),
                    retry_count=retry_count,
                    retry_delay=delay,
                )
                if retry_count >= max_retries:
                    self.logger.error("max_retries_exceeded")
                    raise
                await asyncio.sleep(delay)
                # Reconnect with a fresh channel; failure here is logged and
                # the loop tries again on the next iteration.
                try:
                    await self.disconnect()
                    await self.connect()
                except Exception as conn_err:
                    self.logger.error("reconnect_failed", error=str(conn_err))
            except asyncio.CancelledError:
                self.logger.info("stream_cancelled")
                break

    async def run(self) -> None:
        """Main entry point for the collector service."""
        self.running = True
        self.logger.info(
            "collector_starting",
            machine_id=self.config.machine_id,
            interval=self.config.collection_interval,
        )
        # Initial CPU percent call to initialize (first call always returns 0)
        import psutil

        psutil.cpu_percent()
        await self.connect()
        try:
            await self.stream_metrics()
        finally:
            await self.disconnect()
            self.logger.info("collector_stopped")

    def stop(self) -> None:
        """Signal the collector to stop (ends the metric generator)."""
        self.running = False
async def main():
    """Main entry point: run the collector until SIGTERM/SIGINT."""
    service = CollectorService()
    # Handle shutdown signals. get_running_loop() is the supported way to
    # obtain the loop from inside a coroutine (asyncio.get_event_loop() is
    # deprecated in this context since Python 3.10).
    loop = asyncio.get_running_loop()

    def signal_handler():
        service.logger.info("shutdown_signal_received")
        service.stop()

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)
    await service.run()


if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -0,0 +1,233 @@
"""System metrics collection using psutil."""
import socket
import time
from dataclasses import dataclass, field
import psutil
@dataclass
class MetricPoint:
    """A single metric data point."""

    metric_type: str  # metric type name, e.g. "CPU_PERCENT"
    value: float  # observed value
    labels: dict[str, str] = field(default_factory=dict)  # optional dimensions, e.g. {"core": "0"}
@dataclass
class MetricsBatch:
    """A batch of metrics from a single collection cycle."""

    machine_id: str  # configured machine identifier
    hostname: str  # hostname at collector start
    timestamp_ms: int  # collection time, epoch milliseconds
    metrics: list[MetricPoint]  # all points gathered this cycle
class MetricsCollector:
    """Collects system metrics using psutil.

    Each category (cpu/memory/disk/network/load) can be toggled off via the
    constructor flags. Network throughput is reported as a true rate from
    counter deltas between consecutive collect() calls.
    """

    def __init__(
        self,
        machine_id: str,
        collect_cpu: bool = True,
        collect_memory: bool = True,
        collect_disk: bool = True,
        collect_network: bool = True,
        collect_load: bool = True,
    ):
        self.machine_id = machine_id
        self.hostname = socket.gethostname()
        self.collect_cpu = collect_cpu
        self.collect_memory = collect_memory
        self.collect_disk = collect_disk
        self.collect_network = collect_network
        self.collect_load = collect_load
        # Track previous network counters for rate calculation
        # NOTE(review): psutil._common.snetio is a private psutil type used
        # here only as an annotation - may break on psutil upgrades.
        self._prev_net_io: psutil._common.snetio | None = None
        self._prev_net_time: float | None = None

    def collect(self) -> MetricsBatch:
        """Collect all enabled metrics and return as a batch."""
        metrics: list[MetricPoint] = []
        if self.collect_cpu:
            metrics.extend(self._collect_cpu())
        if self.collect_memory:
            metrics.extend(self._collect_memory())
        if self.collect_disk:
            metrics.extend(self._collect_disk())
        if self.collect_network:
            metrics.extend(self._collect_network())
        if self.collect_load:
            metrics.extend(self._collect_load())
        return MetricsBatch(
            machine_id=self.machine_id,
            hostname=self.hostname,
            timestamp_ms=int(time.time() * 1000),
            metrics=metrics,
        )

    def _collect_cpu(self) -> list[MetricPoint]:
        """Collect CPU metrics (overall and per-core utilisation %).

        interval=None is non-blocking: it reports utilisation since the
        previous call, so the very first call in a process returns 0 (the
        collector service primes this at startup).
        """
        metrics = []
        # Overall CPU percent
        cpu_percent = psutil.cpu_percent(interval=None)
        metrics.append(
            MetricPoint(
                metric_type="CPU_PERCENT",
                value=cpu_percent,
            )
        )
        # Per-core CPU percent
        per_cpu = psutil.cpu_percent(interval=None, percpu=True)
        for i, pct in enumerate(per_cpu):
            metrics.append(
                MetricPoint(
                    metric_type="CPU_PERCENT_PER_CORE",
                    value=pct,
                    labels={"core": str(i)},
                )
            )
        return metrics

    def _collect_memory(self) -> list[MetricPoint]:
        """Collect virtual-memory metrics (percent, used, available)."""
        mem = psutil.virtual_memory()
        return [
            MetricPoint(metric_type="MEMORY_PERCENT", value=mem.percent),
            MetricPoint(metric_type="MEMORY_USED_BYTES", value=float(mem.used)),
            MetricPoint(
                metric_type="MEMORY_AVAILABLE_BYTES", value=float(mem.available)
            ),
        ]

    def _collect_disk(self) -> list[MetricPoint]:
        """Collect disk metrics (root-partition usage and I/O counters)."""
        metrics = []
        # Disk usage for root partition
        try:
            disk = psutil.disk_usage("/")
            metrics.append(
                MetricPoint(
                    metric_type="DISK_PERCENT",
                    value=disk.percent,
                    labels={"mount": "/"},
                )
            )
            metrics.append(
                MetricPoint(
                    metric_type="DISK_USED_BYTES",
                    value=float(disk.used),
                    labels={"mount": "/"},
                )
            )
        except (PermissionError, FileNotFoundError):
            # Container environments may hide "/" - skip usage silently.
            pass
        # Disk I/O rates
        # NOTE(review): these values are *cumulative* byte counters, not
        # per-second rates, despite the _SEC metric names. The comment below
        # claims the aggregator converts them, but no such conversion is
        # visible in the aggregator - confirm, or compute deltas here the
        # way _collect_network does.
        try:
            io = psutil.disk_io_counters()
            if io:
                metrics.append(
                    MetricPoint(
                        metric_type="DISK_READ_BYTES_SEC",
                        value=float(
                            io.read_bytes
                        ),  # Will be converted to rate by aggregator
                    )
                )
                metrics.append(
                    MetricPoint(
                        metric_type="DISK_WRITE_BYTES_SEC",
                        value=float(io.write_bytes),
                    )
                )
        except (PermissionError, AttributeError):
            pass
        return metrics

    def _collect_network(self) -> list[MetricPoint]:
        """Collect network metrics with rate calculation.

        Throughput needs two samples, so the first call emits only the
        connection count; subsequent calls add bytes/sec from counter deltas.
        """
        metrics = []
        try:
            net_io = psutil.net_io_counters()
            current_time = time.time()
            if self._prev_net_io is not None and self._prev_net_time is not None:
                time_delta = current_time - self._prev_net_time
                if time_delta > 0:
                    bytes_sent_rate = (
                        net_io.bytes_sent - self._prev_net_io.bytes_sent
                    ) / time_delta
                    bytes_recv_rate = (
                        net_io.bytes_recv - self._prev_net_io.bytes_recv
                    ) / time_delta
                    metrics.append(
                        MetricPoint(
                            metric_type="NETWORK_SENT_BYTES_SEC",
                            value=bytes_sent_rate,
                        )
                    )
                    metrics.append(
                        MetricPoint(
                            metric_type="NETWORK_RECV_BYTES_SEC",
                            value=bytes_recv_rate,
                        )
                    )
            self._prev_net_io = net_io
            self._prev_net_time = current_time
            # Connection count
            connections = len(psutil.net_connections(kind="inet"))
            metrics.append(
                MetricPoint(
                    metric_type="NETWORK_CONNECTIONS",
                    value=float(connections),
                )
            )
        except (PermissionError, psutil.AccessDenied):
            # net_connections needs elevated privileges on some platforms.
            pass
        return metrics

    def _collect_load(self) -> list[MetricPoint]:
        """Collect load average metrics (Unix only) and process count."""
        metrics = []
        try:
            load1, load5, load15 = psutil.getloadavg()
            metrics.append(MetricPoint(metric_type="LOAD_AVG_1M", value=load1))
            metrics.append(MetricPoint(metric_type="LOAD_AVG_5M", value=load5))
            metrics.append(MetricPoint(metric_type="LOAD_AVG_15M", value=load15))
        except (AttributeError, OSError):
            # Windows doesn't have getloadavg
            pass
        # Process count
        metrics.append(
            MetricPoint(
                metric_type="PROCESS_COUNT",
                value=float(len(psutil.pids())),
            )
        )
        return metrics

View File

@@ -21,6 +21,8 @@ RUN python -m grpc_tools.protoc \
/app/proto/metrics.proto
COPY services/gateway /app/services/gateway
COPY services/aggregator/__init__.py /app/services/aggregator/__init__.py
COPY services/aggregator/storage.py /app/services/aggregator/storage.py
COPY web /app/web
ENV PYTHONPATH=/app

View File

@@ -0,0 +1 @@
"""Gateway service."""

393
services/gateway/main.py Normal file
View File

@@ -0,0 +1,393 @@
"""Gateway service - FastAPI with WebSocket for real-time dashboard."""
import asyncio
import json
import sys
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any
import grpc
from fastapi import FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect
from fastapi.requests import Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
# Add project root to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from services.aggregator.storage import TimescaleStorage
from shared import metrics_pb2, metrics_pb2_grpc
from shared.config import get_gateway_config
from shared.events import get_subscriber
from shared.logging import setup_logging
# Global state
# Service configuration loaded once at import time (ports, URLs, log settings).
config = get_gateway_config()
# Structured logger; settings (level/format) come from the gateway config.
logger = setup_logging(
    service_name=config.service_name,
    log_level=config.log_level,
    log_format=config.log_format,
)
# WebSocket connection manager
class ConnectionManager:
    """Manages WebSocket connections for real-time updates.

    Tracks every accepted dashboard socket and fans broadcast messages
    out to all of them, pruning sockets whose sends fail.
    """

    def __init__(self):
        # All currently-open dashboard sockets.
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket) -> None:
        """Accept *websocket* and start tracking it."""
        await websocket.accept()
        self.active_connections.append(websocket)
        logger.info("websocket_connected", total=len(self.active_connections))

    def disconnect(self, websocket: WebSocket) -> None:
        """Stop tracking *websocket*.

        Bug fix: safe to call more than once for the same socket.
        broadcast() may already have pruned a dead connection before the
        endpoint's finally-block calls disconnect(), so a missing entry
        previously raised ValueError here.
        """
        try:
            self.active_connections.remove(websocket)
        except ValueError:
            return  # already removed by broadcast() cleanup
        logger.info("websocket_disconnected", total=len(self.active_connections))

    async def broadcast(self, message: dict) -> None:
        """Broadcast *message* (JSON-encoded once) to all connected clients.

        Connections whose send fails are dropped from the active list.
        """
        if not self.active_connections:
            return
        data = json.dumps(message)
        dead: list[WebSocket] = []
        for connection in self.active_connections:
            try:
                await connection.send_text(data)
            except Exception:
                dead.append(connection)
        # Clean up disconnected sockets after iterating (never mutate
        # the list while looping over it).
        for conn in dead:
            try:
                self.active_connections.remove(conn)
            except ValueError:
                pass
manager = ConnectionManager()
# The three handles below are populated during lifespan() startup and remain
# None until then; endpoints that need them return 503 while unset.
timescale: TimescaleStorage | None = None
grpc_channel: grpc.aio.Channel | None = None
grpc_stub: metrics_pb2_grpc.MetricsServiceStub | None = None
async def event_listener():
    """Background task: relay raw metric events to every WebSocket client.

    Subscribes to the ``metrics.raw`` topic and forwards each event's
    payload to the connection manager; per-event failures are logged and
    skipped so the stream keeps flowing.
    """
    logger.info("event_listener_starting")
    async with get_subscriber(topics=["metrics.raw"]) as sub:
        async for evt in sub.consume():
            try:
                await manager.broadcast(
                    {
                        "type": "metrics",
                        "data": evt.payload,
                        "timestamp": evt.timestamp.isoformat(),
                    }
                )
            except Exception as exc:
                logger.warning("broadcast_error", error=str(exc))
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager.

    Startup: connect to TimescaleDB (optional — failure degrades to
    live-only mode), open the aggregator gRPC channel, and launch the
    event-listener background task. Shutdown: cancel the listener, then
    close the channel and the DB pool.
    """
    global timescale, grpc_channel, grpc_stub
    logger.info("gateway_starting", port=config.http_port)
    # Connect to TimescaleDB for historical queries
    timescale = TimescaleStorage(config.timescale_url)
    try:
        await timescale.connect()
    except Exception as e:
        # Historical queries become unavailable but the gateway still serves
        # live data; /api/metrics returns 503 while timescale is None.
        logger.warning("timescale_connection_failed", error=str(e))
        timescale = None
    # Connect to aggregator via gRPC (lazy channel — no I/O happens here).
    grpc_channel = grpc.aio.insecure_channel(config.aggregator_url)
    grpc_stub = metrics_pb2_grpc.MetricsServiceStub(grpc_channel)
    # Start event listener in background
    listener_task = asyncio.create_task(event_listener())
    logger.info("gateway_started")
    yield
    # Cleanup
    listener_task.cancel()
    try:
        await listener_task
    except asyncio.CancelledError:
        pass  # expected: we cancelled it ourselves
    if grpc_channel:
        await grpc_channel.close()
    if timescale:
        await timescale.disconnect()
    logger.info("gateway_stopped")
# Create FastAPI app
app = FastAPI(
    title="System Monitor Gateway",
    description="Real-time system monitoring dashboard",
    version="0.1.0",
    lifespan=lifespan,
)
# Mount static files (only when the directory was shipped, e.g. in the image).
static_path = Path(__file__).parent.parent.parent / "web" / "static"
if static_path.exists():
    app.mount("/static", StaticFiles(directory=str(static_path)), name="static")
# Templates are optional; dashboard() falls back to inline HTML when absent.
templates_path = Path(__file__).parent.parent.parent / "web" / "templates"
templates = (
    Jinja2Templates(directory=str(templates_path)) if templates_path.exists() else None
)
# ============================================================================
# Health endpoints
# ============================================================================
@app.get("/health")
async def health_check():
"""Health check endpoint."""
return {"status": "healthy", "service": "gateway"}
@app.get("/ready")
async def readiness_check():
"""Readiness check - verifies dependencies."""
checks = {"gateway": "ok"}
# Check gRPC connection
try:
if grpc_stub:
await grpc_stub.GetAllStates(metrics_pb2.Empty(), timeout=2.0)
checks["aggregator"] = "ok"
except Exception as e:
checks["aggregator"] = f"error: {str(e)}"
# Check TimescaleDB
if timescale and timescale._pool:
checks["timescaledb"] = "ok"
else:
checks["timescaledb"] = "not connected"
return {"status": "ready", "checks": checks}
# ============================================================================
# REST API endpoints
# ============================================================================
@app.get("/api/machines")
async def get_machines():
"""Get current state of all machines."""
if not grpc_stub:
raise HTTPException(status_code=503, detail="Aggregator not connected")
try:
response = await grpc_stub.GetAllStates(metrics_pb2.Empty(), timeout=5.0)
machines = []
for state in response.machines:
metrics = {}
for m in state.current_metrics:
metric_type = metrics_pb2.MetricType.Name(m.type)
metrics[metric_type] = m.value
machines.append(
{
"machine_id": state.machine_id,
"hostname": state.hostname,
"last_seen_ms": state.last_seen_ms,
"health": metrics_pb2.HealthStatus.Name(state.health),
"metrics": metrics,
}
)
return {"machines": machines}
except grpc.aio.AioRpcError as e:
raise HTTPException(status_code=503, detail=f"Aggregator error: {e.details()}")
@app.get("/api/machines/{machine_id}")
async def get_machine(machine_id: str):
"""Get current state of a specific machine."""
if not grpc_stub:
raise HTTPException(status_code=503, detail="Aggregator not connected")
try:
response = await grpc_stub.GetCurrentState(
metrics_pb2.StateRequest(machine_id=machine_id),
timeout=5.0,
)
if not response.machine_id:
raise HTTPException(status_code=404, detail="Machine not found")
metrics = {}
for m in response.current_metrics:
metric_type = metrics_pb2.MetricType.Name(m.type)
metrics[metric_type] = m.value
return {
"machine_id": response.machine_id,
"hostname": response.hostname,
"last_seen_ms": response.last_seen_ms,
"health": metrics_pb2.HealthStatus.Name(response.health),
"metrics": metrics,
}
except grpc.aio.AioRpcError as e:
if e.code() == grpc.StatusCode.NOT_FOUND:
raise HTTPException(status_code=404, detail="Machine not found")
raise HTTPException(status_code=503, detail=f"Aggregator error: {e.details()}")
@app.get("/api/metrics")
async def get_metrics(
machine_id: str | None = Query(None),
metric_type: str | None = Query(None),
minutes: int = Query(60, ge=1, le=1440),
limit: int = Query(1000, ge=1, le=10000),
):
"""Get historical metrics."""
if not timescale:
raise HTTPException(status_code=503, detail="TimescaleDB not connected")
end_time = datetime.utcnow()
start_time = end_time - timedelta(minutes=minutes)
try:
metrics = await timescale.get_metrics(
machine_id=machine_id,
metric_type=metric_type,
start_time=start_time,
end_time=end_time,
limit=limit,
)
return {"metrics": metrics, "count": len(metrics)}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# ============================================================================
# WebSocket endpoint
# ============================================================================
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
"""WebSocket endpoint for real-time metric updates."""
await manager.connect(websocket)
try:
# Send initial state
if grpc_stub:
try:
response = await grpc_stub.GetAllStates(
metrics_pb2.Empty(), timeout=5.0
)
for state in response.machines:
metrics = {}
for m in state.current_metrics:
metric_type = metrics_pb2.MetricType.Name(m.type)
metrics[metric_type] = m.value
await websocket.send_json(
{
"type": "initial",
"data": {
"machine_id": state.machine_id,
"hostname": state.hostname,
"metrics": metrics,
},
}
)
except Exception as e:
logger.warning("initial_state_error", error=str(e))
# Keep connection alive and handle incoming messages
while True:
try:
data = await websocket.receive_text()
# Handle ping/pong or commands from client
if data == "ping":
await websocket.send_text("pong")
except WebSocketDisconnect:
break
finally:
manager.disconnect(websocket)
# ============================================================================
# Dashboard (HTML)
# ============================================================================
@app.get("/", response_class=HTMLResponse)
async def dashboard(request: Request):
"""Serve the dashboard HTML."""
if templates:
return templates.TemplateResponse("dashboard.html", {"request": request})
# Fallback if templates not found
return HTMLResponse("""
<!DOCTYPE html>
<html>
<head>
<title>System Monitor</title>
<style>
body { font-family: system-ui; background: #1a1a2e; color: #eee; padding: 2rem; }
h1 { color: #e94560; }
pre { background: #16213e; padding: 1rem; border-radius: 8px; overflow: auto; }
</style>
</head>
<body>
<h1>System Monitor</h1>
<p>Dashboard template not found. API endpoints:</p>
<ul>
<li><a href="/api/machines">/api/machines</a> - Current state of all machines</li>
<li><a href="/api/metrics">/api/metrics</a> - Historical metrics</li>
<li><a href="/docs">/docs</a> - API documentation</li>
</ul>
<h2>Live Metrics</h2>
<pre id="output">Connecting...</pre>
<script>
const ws = new WebSocket(`ws://${location.host}/ws`);
const output = document.getElementById('output');
ws.onmessage = (e) => {
output.textContent = JSON.stringify(JSON.parse(e.data), null, 2);
};
ws.onclose = () => { output.textContent = 'Disconnected'; };
</script>
</body>
</html>
""")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=config.http_port)