"""Aggregator service - gRPC server that receives metrics and stores them."""

import asyncio
import signal
import sys
from pathlib import Path

import grpc
from grpc_health.v1 import health, health_pb2, health_pb2_grpc

# Add project root to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from services.aggregator.storage import RedisStorage, TimescaleStorage
from shared import metrics_pb2, metrics_pb2_grpc
from shared.config import get_aggregator_config
from shared.events import get_publisher
from shared.logging import setup_logging


class MetricsServicer(metrics_pb2_grpc.MetricsServiceServicer):
    """gRPC servicer for metrics ingestion and current-state queries.

    Streamed metrics are buffered per machine and flushed in batches to Redis
    (current state), TimescaleDB (history) and the event publisher.
    """

    # Maximum number of metrics buffered before a batch is flushed.
    MAX_BATCH_SIZE = 20

    def __init__(
        self,
        redis_storage: RedisStorage,
        timescale_storage: TimescaleStorage,
        event_publisher,
        logger,
    ):
        self.redis = redis_storage
        self.timescale = timescale_storage
        self.publisher = event_publisher
        self.logger = logger

    async def StreamMetrics(self, request_iterator, context):
        """Receive streaming metrics from a collector.

        The pending batch is flushed when the sending machine changes, when
        the sample timestamp changes, when it reaches ``MAX_BATCH_SIZE``
        metrics, and once more when the stream ends.

        Returns:
            A ``StreamAck`` with ``success=True`` on normal completion. If
            anything raises while processing, the error is logged and a
            ``StreamAck`` with ``success=False`` and the error text is
            returned; metrics still buffered at that point are not flushed.
        """
        metrics_received = 0
        current_machine: str | None = None
        current_batch: list[tuple[str, float, dict]] = []
        batch_timestamp = 0
        batch_hostname = ""

        async def flush() -> None:
            # Persist/publish the pending batch (if any) and start a new one.
            # `is not None` (not truthiness) so an empty machine_id is still
            # flushed instead of leaking into the next machine's batch.
            nonlocal current_batch
            if current_machine is not None and current_batch:
                await self._flush_batch(
                    current_machine, batch_hostname, batch_timestamp, current_batch
                )
                current_batch = []

        try:
            async for metric in request_iterator:
                metrics_received += 1

                if current_machine != metric.machine_id:
                    # Flush the previous machine's metrics before switching.
                    await flush()
                    current_machine = metric.machine_id
                    self.logger.info(
                        "collector_connected",
                        machine_id=metric.machine_id,
                        hostname=metric.hostname,
                    )
                elif current_batch and metric.timestamp_ms != batch_timestamp:
                    # _flush_batch records the whole batch under a single
                    # timestamp, so samples from different times must not mix.
                    await flush()

                current_batch.append(
                    (
                        metrics_pb2.MetricType.Name(metric.type),
                        metric.value,
                        dict(metric.labels),
                    )
                )
                batch_timestamp = metric.timestamp_ms
                batch_hostname = metric.hostname

                if len(current_batch) >= self.MAX_BATCH_SIZE:
                    await flush()

            # Flush whatever is left once the collector closes the stream.
            await flush()

            self.logger.info(
                "stream_completed",
                machine_id=current_machine,
                metrics_received=metrics_received,
            )

            return metrics_pb2.StreamAck(
                success=True,
                metrics_received=metrics_received,
                message="OK",
            )

        except Exception as e:
            self.logger.error(
                "stream_error",
                error=str(e),
                machine_id=current_machine,
                metrics_received=metrics_received,
            )
            return metrics_pb2.StreamAck(
                success=False,
                metrics_received=metrics_received,
                message=str(e),
            )

    @staticmethod
    def _metric_key(metric_type: str, labels: dict) -> str:
        """Build the state key ``TYPE`` or ``TYPE:k1=v1,k2=v2``.

        Labels are sorted by name so a given label set always maps to the same
        key, independent of map iteration order.
        """
        if not labels:
            return metric_type
        label_str = ",".join(f"{k}={v}" for k, v in sorted(labels.items()))
        return f"{metric_type}:{label_str}"

    @staticmethod
    def _parse_metric_key(key: str) -> tuple[str, dict[str, str]]:
        """Split a key built by ``_metric_key`` into (type name, labels).

        Everything after the first ``:`` is the label section and each pair is
        split on its first ``=``; pairs without ``=`` are skipped instead of
        failing the whole request. NOTE: label values containing ``,`` cannot
        be round-tripped by this key format.
        """
        metric_type_str, _, label_str = key.partition(":")
        labels: dict[str, str] = {}
        if label_str:
            for pair in label_str.split(","):
                name, sep, value = pair.partition("=")
                if sep:
                    labels[name] = value
        return metric_type_str, labels

    @staticmethod
    def _metric_type_value(name: str) -> int:
        """Map a ``MetricType`` name back to its enum value (0 if unknown)."""
        try:
            return metrics_pb2.MetricType.Value(name)
        except ValueError:
            return 0

    @classmethod
    def _state_to_proto(cls, state: dict) -> metrics_pb2.MachineState:
        """Convert a stored machine-state dict into a ``MachineState`` proto.

        Reads the ``machine_id``, ``hostname``, ``last_seen_ms`` and optional
        ``metrics`` entries of ``state``. Every metric is stamped with
        ``last_seen_ms``; health is always reported as ``HEALTHY`` here.
        """
        machine_id = state["machine_id"]
        hostname = state["hostname"]
        last_seen_ms = state["last_seen_ms"]

        metrics = []
        for key, value in state.get("metrics", {}).items():
            metric_type_str, labels = cls._parse_metric_key(key)
            metrics.append(
                metrics_pb2.Metric(
                    machine_id=machine_id,
                    hostname=hostname,
                    timestamp_ms=last_seen_ms,
                    type=cls._metric_type_value(metric_type_str),
                    value=value,
                    labels=labels,
                )
            )

        return metrics_pb2.MachineState(
            machine_id=machine_id,
            hostname=hostname,
            last_seen_ms=last_seen_ms,
            current_metrics=metrics,
            health=metrics_pb2.HEALTHY,
        )

    async def _flush_batch(
        self,
        machine_id: str,
        hostname: str,
        timestamp_ms: int,
        batch: list[tuple[str, float, dict]],
    ) -> None:
        """Flush a batch of metrics to storage and events.

        Redis and the event publisher are required: their errors propagate to
        the caller. TimescaleDB writes are best-effort and only logged.
        """
        # Collapse the batch into {key: latest value} for the current state;
        # later entries with the same key overwrite earlier ones.
        metrics_dict = {
            self._metric_key(metric_type, labels): value
            for metric_type, value, labels in batch
        }

        # Update Redis (current state)
        await self.redis.update_machine_state(
            machine_id=machine_id,
            hostname=hostname,
            metrics=metrics_dict,
            timestamp_ms=timestamp_ms,
        )

        # Insert into TimescaleDB (historical)
        try:
            await self.timescale.insert_metrics(
                machine_id=machine_id,
                hostname=hostname,
                timestamp_ms=timestamp_ms,
                metrics=batch,
            )
        except Exception as e:
            self.logger.warning("timescale_insert_failed", error=str(e))

        # Update machine registry
        try:
            await self.timescale.update_machine_registry(
                machine_id=machine_id,
                hostname=hostname,
            )
        except Exception as e:
            self.logger.warning("machine_registry_update_failed", error=str(e))

        # Publish event for subscribers (alerts, gateway)
        await self.publisher.publish(
            topic="metrics.raw",
            payload={
                "machine_id": machine_id,
                "hostname": hostname,
                "timestamp_ms": timestamp_ms,
                "metrics": metrics_dict,
            },
        )

        self.logger.debug(
            "batch_flushed",
            machine_id=machine_id,
            count=len(batch),
        )

    async def GetCurrentState(self, request, context):
        """Get current state for a single machine (NOT_FOUND if unknown)."""
        state = await self.redis.get_machine_state(request.machine_id)

        if not state:
            context.set_code(grpc.StatusCode.NOT_FOUND)
            context.set_details(f"Machine {request.machine_id} not found")
            return metrics_pb2.MachineState()

        return self._state_to_proto(state)

    async def GetAllStates(self, request, context):
        """Get current state for all machines, including metric labels."""
        states = await self.redis.get_all_machines()
        return metrics_pb2.AllMachinesState(
            machines=[self._state_to_proto(state) for state in states]
        )


class AggregatorService:
    """Main aggregator service: wires storage, events and the gRPC server."""

    def __init__(self):
        self.config = get_aggregator_config()
        self.logger = setup_logging(
            service_name=self.config.service_name,
            log_level=self.config.log_level,
            log_format=self.config.log_format,
        )

        self.redis = RedisStorage(self.config.redis_url)
        self.timescale = TimescaleStorage(self.config.timescale_url)
        self.publisher = get_publisher(source="aggregator")

        self.server: grpc.aio.Server | None = None
        self.running = False

    async def start(self) -> None:
        """Connect backends and start the gRPC server.

        A Redis or publisher connection failure propagates; a TimescaleDB
        connection failure is logged and the service continues without it.
        """
        self.running = True

        # Connect to storage
        await self.redis.connect()
        try:
            await self.timescale.connect()
        except Exception as e:
            self.logger.warning(
                "timescale_connection_failed",
                error=str(e),
                message="Continuing without TimescaleDB - metrics won't be persisted",
            )

        # Connect to event publisher
        await self.publisher.connect()

        # Create gRPC server
        self.server = grpc.aio.server()

        # Add metrics servicer
        servicer = MetricsServicer(
            redis_storage=self.redis,
            timescale_storage=self.timescale,
            event_publisher=self.publisher,
            logger=self.logger,
        )
        metrics_pb2_grpc.add_MetricsServiceServicer_to_server(servicer, self.server)

        # Add health check servicer.
        # NOTE(review): grpc_health also provides an asyncio servicer
        # (`health.aio.HealthServicer`, whose `set` is a coroutine); consider
        # it for this grpc.aio server.
        health_servicer = health.HealthServicer()
        health_servicer.set("", health_pb2.HealthCheckResponse.SERVING)
        health_servicer.set("MetricsService", health_pb2.HealthCheckResponse.SERVING)
        health_pb2_grpc.add_HealthServicer_to_server(health_servicer, self.server)

        # Start server
        listen_addr = f"[::]:{self.config.grpc_port}"
        self.server.add_insecure_port(listen_addr)
        await self.server.start()

        self.logger.info(
            "aggregator_started",
            port=self.config.grpc_port,
            listen_addr=listen_addr,
        )

    async def stop(self) -> None:
        """Stop the gRPC server and disconnect all backends.

        Safe to call more than once: only the first call after ``start`` does
        any work. Each disconnect is attempted even if an earlier one fails.
        """
        if not self.running:
            return
        self.running = False

        if self.server is not None:
            await self.server.stop(grace=5)
            self.server = None

        for resource_name, resource in (
            ("publisher", self.publisher),
            ("timescale", self.timescale),
            ("redis", self.redis),
        ):
            try:
                await resource.disconnect()
            except Exception as e:
                self.logger.warning(
                    "disconnect_failed", resource=resource_name, error=str(e)
                )

        self.logger.info("aggregator_stopped")

    async def wait(self) -> None:
        """Wait for the server to terminate."""
        if self.server is not None:
            await self.server.wait_for_termination()


async def main():
    """Main entry point.

    Runs until SIGTERM/SIGINT (or server termination), then awaits a full
    ``stop()`` inside this coroutine so shutdown is not cancelled by
    ``asyncio.run`` tearing down the loop.
    """
    service = AggregatorService()

    # Handle shutdown signals
    loop = asyncio.get_running_loop()
    shutdown_requested = asyncio.Event()

    def request_shutdown() -> None:
        service.logger.info("shutdown_signal_received")
        shutdown_requested.set()

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, request_shutdown)

    await service.start()
    try:
        waiters = {
            asyncio.create_task(service.wait()),
            asyncio.create_task(shutdown_requested.wait()),
        }
        try:
            await asyncio.wait(waiters, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for waiter in waiters:
                waiter.cancel()
            await asyncio.gather(*waiters, return_exceptions=True)
    finally:
        await service.stop()


if __name__ == "__main__":
    asyncio.run(main())