"""
gRPC Server - Worker Service Implementation

Runs in the worker process to handle job submissions and progress streaming.
"""

import json
import logging
import os
import time
from concurrent import futures
from typing import Iterator, Optional

import grpc

# Configuration from environment
GRPC_PORT = int(os.environ.get("GRPC_PORT", "50051"))
GRPC_MAX_WORKERS = int(os.environ.get("GRPC_MAX_WORKERS", "10"))

# Generated stubs - run `python schema/generate.py --proto` if missing
from . import worker_pb2, worker_pb2_grpc

logger = logging.getLogger(__name__)

# Active jobs progress tracking (shared state for streaming).
# Keyed by job_id; written by SubmitJob / CancelJob / update_job_progress and
# read by StreamProgress / GetWorkerStatus.
_active_jobs: dict[str, dict] = {}

# Statuses after which StreamProgress stops streaming and drops the job entry.
_TERMINAL_STATUSES = ("completed", "failed", "cancelled")


class WorkerServicer(worker_pb2_grpc.WorkerServiceServicer):
    """gRPC service implementation for worker operations."""

    def __init__(self):
        """Initialize the servicer."""

    def SubmitJob(self, request, context):
        """Submit a transcode/trim job to the worker.

        Registers a fresh progress entry for ``request.job_id`` in the
        in-memory job table and acknowledges the job. Any exception raised
        while handling the request is logged and reported back to the caller
        as ``accepted=False`` with the exception text as the message.
        """
        job_id = request.job_id
        logger.info(f"SubmitJob: {job_id}")

        try:
            # Parse preset. The result is not consumed yet (dispatch is still
            # a TODO), but parsing here rejects malformed JSON up front.
            preset = json.loads(request.preset_json) if request.preset_json else None

            # Initialize progress tracking
            _active_jobs[job_id] = {
                "status": "pending",
                "progress": 0,
                "current_frame": 0,
                "current_time": 0.0,
                "speed": 0.0,
                "error": None,
            }

            # TODO: dispatch via executor (local/lambda/gcp/grpc)

            return worker_pb2.JobResponse(
                job_id=job_id,
                accepted=True,
                message="Job submitted",
            )

        except Exception as e:
            logger.exception(f"SubmitJob failed: {e}")
            return worker_pb2.JobResponse(
                job_id=job_id,
                accepted=False,
                message=str(e),
            )

    def StreamProgress(self, request, context) -> Iterator[worker_pb2.ProgressUpdate]:
        """Stream progress updates for a job.

        Yields a single ``not_found`` update if the job is unknown. Otherwise
        polls the in-memory job table every 100 ms and yields an update each
        time the job's progress or status changes, until the job reaches a
        terminal status, disappears from the table, or the RPC is no longer
        active. A finished job is removed from the table when the stream ends.
        """
        job_id = request.job_id
        logger.info(f"StreamProgress: {job_id}")

        # Check if job exists
        if job_id not in _active_jobs:
            yield worker_pb2.ProgressUpdate(
                job_id=job_id,
                progress=0,
                status="not_found",
                error="Job not found",
            )
            return

        # Stream updates until job completes
        last_sent = None
        while True:
            # The synchronous ServicerContext exposes is_active(), not
            # cancelled(); this also matches StreamChunkPipeline.
            if not context.is_active():
                logger.info(f"StreamProgress cancelled: {job_id}")
                break

            job_state = _active_jobs.get(job_id)
            if not job_state:
                break

            # Read both fields once so the terminal check below uses the same
            # status value that was (or was not) just sent to the client.
            progress = job_state["progress"]
            status = job_state["status"]

            # Yield when progress OR status changed, so that a status-only
            # transition (e.g. CancelJob, or a failure at unchanged progress)
            # still reaches the client before the stream ends.
            if (progress, status) != last_sent:
                last_sent = (progress, status)
                yield worker_pb2.ProgressUpdate(
                    job_id=job_id,
                    progress=progress,
                    current_frame=job_state.get("current_frame", 0),
                    current_time=job_state.get("current_time", 0.0),
                    speed=job_state.get("speed", 0.0),
                    status=status,
                    error=job_state.get("error"),
                )

            # Exit if job is done
            if status in _TERMINAL_STATUSES:
                break

            # Small delay to avoid busy loop
            time.sleep(0.1)

        # Cleanup completed jobs
        if job_id in _active_jobs:
            status = _active_jobs[job_id].get("status")
            if status in _TERMINAL_STATUSES:
                _active_jobs.pop(job_id, None)

    def CancelJob(self, request, context):
        """Cancel a running job.

        Marks the job's in-memory status as ``cancelled`` if the job is
        tracked; otherwise reports that the job was not found.
        """
        job_id = request.job_id
        logger.info(f"CancelJob: {job_id}")

        if job_id in _active_jobs:
            _active_jobs[job_id]["status"] = "cancelled"
            return worker_pb2.CancelResponse(
                job_id=job_id,
                cancelled=True,
                message="Job cancelled",
            )

        return worker_pb2.CancelResponse(
            job_id=job_id,
            cancelled=False,
            message="Job not found",
        )

    def StreamChunkPipeline(self, request, context) -> Iterator[worker_pb2.ChunkPipelineEvent]:
        """Stream chunk pipeline events for a job.

        Polls ``core.events.poll_events`` every 50 ms and forwards each event
        as a ``ChunkPipelineEvent``. The stream ends after a
        ``pipeline_complete`` or ``pipeline_error`` event, when the RPC is no
        longer active, or after 10 minutes.
        """
        from core.events import poll_events

        job_id = request.job_id
        logger.info(f"StreamChunkPipeline: {job_id}")

        cursor = 0
        timeout = time.monotonic() + 600  # 10 min max

        while context.is_active() and time.monotonic() < timeout:
            events, cursor = poll_events(job_id, cursor)

            for data in events:
                # Read without popping so the event dict handed back by
                # poll_events is left untouched.
                event_type = data.get("event", "")

                yield worker_pb2.ChunkPipelineEvent(
                    job_id=job_id,
                    event_type=event_type,
                    sequence=data.get("sequence", 0),
                    worker_id=data.get("worker_id", ""),
                    state=data.get("state", ""),
                    queue_size=data.get("queue_size", 0),
                    elapsed=data.get("elapsed", 0.0),
                    throughput_mbps=data.get("throughput_mbps", 0.0),
                    total_chunks=data.get("total_chunks", 0),
                    processed_chunks=data.get("processed_chunks", 0),
                    failed_chunks=data.get("failed_chunks", 0),
                    error=data.get("error", ""),
                    processing_time=data.get("processing_time", 0.0),
                    retries=data.get("retries", 0),
                )

                if event_type in ("pipeline_complete", "pipeline_error"):
                    return

            time.sleep(0.05)

    def GetWorkerStatus(self, request, context):
        """Get worker health and capabilities.

        Reports the number of entries currently in the in-memory job table;
        the codec list and GPU flag are fixed placeholder values.
        """
        return worker_pb2.WorkerStatus(
            available=True,
            active_jobs=len(_active_jobs),
            supported_codecs=[],
            gpu_available=False,
        )


def update_job_progress(
    job_id: str,
    progress: int,
    current_frame: int = 0,
    current_time: float = 0.0,
    speed: float = 0.0,
    status: str = "processing",
    error: Optional[str] = None,
    **extra,
) -> None:
    """
    Update job progress (called from worker tasks).

    Updates both the in-memory gRPC state and the Django database.
    Extra kwargs are stored for chunker-specific fields (total_chunks,
    processed_chunks, failed_chunks, throughput_mbps, etc.); they are kept
    only in the in-memory state and are not written to the database.

    A failure to write to the database is logged as a warning and otherwise
    ignored, so callers are never interrupted by it.
    """
    if job_id in _active_jobs:
        _active_jobs[job_id].update(
            {
                "progress": progress,
                "current_frame": current_frame,
                "current_time": current_time,
                "speed": speed,
                "status": status,
                "error": error,
                **extra,
            }
        )

    # Update Django database
    try:
        from django.utils import timezone

        from core.db import update_job_fields

        updates = {
            "progress": progress,
            "current_frame": current_frame,
            "current_time": current_time,
            "speed": str(speed),
            "status": status,
        }

        if error:
            updates["error_message"] = error

        # NOTE(review): started_at is rewritten on every "processing" update,
        # not only the first one - confirm whether that is intended.
        # NOTE(review): "cancelled" does not set completed_at - confirm.
        if status == "processing":
            updates["started_at"] = timezone.now()
        elif status in ("completed", "failed"):
            updates["completed_at"] = timezone.now()

        update_job_fields(job_id, **updates)
    except Exception as e:
        logger.warning(f"Failed to update job {job_id} in DB: {e}")


def serve(port: Optional[int] = None) -> grpc.Server:
    """
    Start the gRPC server.

    Args:
        port: Port to listen on (defaults to GRPC_PORT env var)

    Returns:
        The running gRPC server
    """
    if port is None:
        port = GRPC_PORT

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=GRPC_MAX_WORKERS))
    worker_pb2_grpc.add_WorkerServiceServicer_to_server(
        WorkerServicer(),
        server,
    )
    server.add_insecure_port(f"[::]:{port}")
    server.start()

    logger.info(f"gRPC server started on port {port}")
    return server


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    server = serve()
    server.wait_for_termination()