""" gRPC Server - Worker Service Implementation Runs in the worker process to handle job submissions and progress streaming. """ import json import logging import os import time from concurrent import futures from typing import Iterator import grpc # Configuration from environment GRPC_PORT = int(os.environ.get("GRPC_PORT", "50051")) GRPC_MAX_WORKERS = int(os.environ.get("GRPC_MAX_WORKERS", "10")) # Generated stubs - run `python schema/generate.py --proto` if missing from . import worker_pb2, worker_pb2_grpc logger = logging.getLogger(__name__) # Active jobs progress tracking (shared state for streaming) _active_jobs: dict[str, dict] = {} class WorkerServicer(worker_pb2_grpc.WorkerServiceServicer): """gRPC service implementation for worker operations.""" def __init__(self, celery_app=None): """ Initialize the servicer. Args: celery_app: Optional Celery app for task dispatch """ self.celery_app = celery_app def SubmitJob(self, request, context): """Submit a transcode/trim job to the worker.""" job_id = request.job_id logger.info(f"SubmitJob: {job_id}") try: # Parse preset preset = json.loads(request.preset_json) if request.preset_json else None # Initialize progress tracking _active_jobs[job_id] = { "status": "pending", "progress": 0, "current_frame": 0, "current_time": 0.0, "speed": 0.0, "error": None, } # Dispatch to Celery if available if self.celery_app: from core.task.tasks import run_transcode_job task = run_transcode_job.delay( job_id=job_id, source_path=request.source_path, output_path=request.output_path, preset=preset, trim_start=request.trim_start if request.HasField("trim_start") else None, trim_end=request.trim_end if request.HasField("trim_end") else None, ) _active_jobs[job_id]["celery_task_id"] = task.id return worker_pb2.JobResponse( job_id=job_id, accepted=True, message="Job submitted", ) except Exception as e: logger.exception(f"SubmitJob failed: {e}") return worker_pb2.JobResponse( job_id=job_id, accepted=False, message=str(e), ) def StreamProgress(self, request, context) -> Iterator[worker_pb2.ProgressUpdate]: """Stream progress updates for a job.""" job_id = request.job_id logger.info(f"StreamProgress: {job_id}") # Check if job exists if job_id not in _active_jobs: yield worker_pb2.ProgressUpdate( job_id=job_id, progress=0, status="not_found", error="Job not found", ) return # Stream updates until job completes last_progress = -1 while True: if context.cancelled(): logger.info(f"StreamProgress cancelled: {job_id}") break job_state = _active_jobs.get(job_id) if not job_state: break # Only yield if progress changed if job_state["progress"] != last_progress: last_progress = job_state["progress"] yield worker_pb2.ProgressUpdate( job_id=job_id, progress=job_state["progress"], current_frame=job_state.get("current_frame", 0), current_time=job_state.get("current_time", 0.0), speed=job_state.get("speed", 0.0), status=job_state["status"], error=job_state.get("error"), ) # Exit if job is done if job_state["status"] in ("completed", "failed", "cancelled"): break # Small delay to avoid busy loop time.sleep(0.1) # Cleanup completed jobs if job_id in _active_jobs: status = _active_jobs[job_id].get("status") if status in ("completed", "failed", "cancelled"): _active_jobs.pop(job_id, None) def CancelJob(self, request, context): """Cancel a running job.""" job_id = request.job_id logger.info(f"CancelJob: {job_id}") if job_id in _active_jobs: _active_jobs[job_id]["status"] = "cancelled" # Revoke Celery task if available if self.celery_app: task_id = _active_jobs[job_id].get("celery_task_id") if task_id: self.celery_app.control.revoke(task_id, terminate=True) return worker_pb2.CancelResponse( job_id=job_id, cancelled=True, message="Job cancelled", ) return worker_pb2.CancelResponse( job_id=job_id, cancelled=False, message="Job not found", ) def GetWorkerStatus(self, request, context): """Get worker health and capabilities.""" try: from core.ffmpeg import get_encoders encoders = get_encoders() codec_names = [e["name"] for e in encoders.get("video", [])] except Exception: codec_names = [] # Check for GPU encoders gpu_available = any( "nvenc" in name or "vaapi" in name or "qsv" in name for name in codec_names ) return worker_pb2.WorkerStatus( available=True, active_jobs=len(_active_jobs), supported_codecs=codec_names[:20], # Limit to 20 gpu_available=gpu_available, ) def update_job_progress( job_id: str, progress: int, current_frame: int = 0, current_time: float = 0.0, speed: float = 0.0, status: str = "processing", error: str = None, ) -> None: """ Update job progress (called from worker tasks). Updates both the in-memory gRPC state and the Django database. """ if job_id in _active_jobs: _active_jobs[job_id].update( { "progress": progress, "current_frame": current_frame, "current_time": current_time, "speed": speed, "status": status, "error": error, } ) # Update Django database try: from django.utils import timezone from core.db import update_job_fields updates = { "progress": progress, "current_frame": current_frame, "current_time": current_time, "speed": str(speed), "status": status, } if error: updates["error_message"] = error if status == "processing": updates["started_at"] = timezone.now() elif status in ("completed", "failed"): updates["completed_at"] = timezone.now() update_job_fields(job_id, **updates) except Exception as e: logger.warning(f"Failed to update job {job_id} in DB: {e}") def serve(port: int = None, celery_app=None) -> grpc.Server: """ Start the gRPC server. Args: port: Port to listen on (defaults to GRPC_PORT env var) celery_app: Optional Celery app for task dispatch Returns: The running gRPC server """ if port is None: port = GRPC_PORT server = grpc.server(futures.ThreadPoolExecutor(max_workers=GRPC_MAX_WORKERS)) worker_pb2_grpc.add_WorkerServiceServicer_to_server( WorkerServicer(celery_app=celery_app), server, ) server.add_insecure_port(f"[::]:{port}") server.start() logger.info(f"gRPC server started on port {port}") return server if __name__ == "__main__": logging.basicConfig(level=logging.INFO) server = serve() server.wait_for_termination()