Files
mediaproc/core/rpc/server.py
2026-03-30 07:22:14 -03:00

276 lines
8.3 KiB
Python

"""
gRPC Server - Worker Service Implementation
Runs in the worker process to handle job submissions and progress streaming.
"""
import json
import logging
import os
import time
from concurrent import futures
from typing import Iterator
import grpc
# Configuration from environment
GRPC_PORT = int(os.environ.get("GRPC_PORT", "50051"))
GRPC_MAX_WORKERS = int(os.environ.get("GRPC_MAX_WORKERS", "10"))
# Generated stubs - run `python schema/generate.py --proto` if missing
from . import worker_pb2, worker_pb2_grpc
logger = logging.getLogger(__name__)
# Active jobs progress tracking (shared state for streaming)
_active_jobs: dict[str, dict] = {}
class WorkerServicer(worker_pb2_grpc.WorkerServiceServicer):
"""gRPC service implementation for worker operations."""
def __init__(self):
"""Initialize the servicer."""
pass
def SubmitJob(self, request, context):
"""Submit a transcode/trim job to the worker."""
job_id = request.job_id
logger.info(f"SubmitJob: {job_id}")
try:
# Parse preset
preset = json.loads(request.preset_json) if request.preset_json else None
# Initialize progress tracking
_active_jobs[job_id] = {
"status": "pending",
"progress": 0,
"current_frame": 0,
"current_time": 0.0,
"speed": 0.0,
"error": None,
}
# TODO: dispatch via executor (local/lambda/gcp/grpc)
return worker_pb2.JobResponse(
job_id=job_id,
accepted=True,
message="Job submitted",
)
except Exception as e:
logger.exception(f"SubmitJob failed: {e}")
return worker_pb2.JobResponse(
job_id=job_id,
accepted=False,
message=str(e),
)
def StreamProgress(self, request, context) -> Iterator[worker_pb2.ProgressUpdate]:
"""Stream progress updates for a job."""
job_id = request.job_id
logger.info(f"StreamProgress: {job_id}")
# Check if job exists
if job_id not in _active_jobs:
yield worker_pb2.ProgressUpdate(
job_id=job_id,
progress=0,
status="not_found",
error="Job not found",
)
return
# Stream updates until job completes
last_progress = -1
while True:
if context.cancelled():
logger.info(f"StreamProgress cancelled: {job_id}")
break
job_state = _active_jobs.get(job_id)
if not job_state:
break
# Only yield if progress changed
if job_state["progress"] != last_progress:
last_progress = job_state["progress"]
yield worker_pb2.ProgressUpdate(
job_id=job_id,
progress=job_state["progress"],
current_frame=job_state.get("current_frame", 0),
current_time=job_state.get("current_time", 0.0),
speed=job_state.get("speed", 0.0),
status=job_state["status"],
error=job_state.get("error"),
)
# Exit if job is done
if job_state["status"] in ("completed", "failed", "cancelled"):
break
# Small delay to avoid busy loop
time.sleep(0.1)
# Cleanup completed jobs
if job_id in _active_jobs:
status = _active_jobs[job_id].get("status")
if status in ("completed", "failed", "cancelled"):
_active_jobs.pop(job_id, None)
def CancelJob(self, request, context):
"""Cancel a running job."""
job_id = request.job_id
logger.info(f"CancelJob: {job_id}")
if job_id in _active_jobs:
_active_jobs[job_id]["status"] = "cancelled"
return worker_pb2.CancelResponse(
job_id=job_id,
cancelled=True,
message="Job cancelled",
)
return worker_pb2.CancelResponse(
job_id=job_id,
cancelled=False,
message="Job not found",
)
def StreamChunkPipeline(self, request, context) -> Iterator[worker_pb2.ChunkPipelineEvent]:
"""Stream chunk pipeline events for a job."""
from core.events import poll_events
job_id = request.job_id
logger.info(f"StreamChunkPipeline: {job_id}")
cursor = 0
timeout = time.monotonic() + 600 # 10 min max
while context.is_active() and time.monotonic() < timeout:
events, cursor = poll_events(job_id, cursor)
for data in events:
event_type = data.pop("event", "")
yield worker_pb2.ChunkPipelineEvent(
job_id=job_id,
event_type=event_type,
sequence=data.get("sequence", 0),
worker_id=data.get("worker_id", ""),
state=data.get("state", ""),
queue_size=data.get("queue_size", 0),
elapsed=data.get("elapsed", 0.0),
throughput_mbps=data.get("throughput_mbps", 0.0),
total_chunks=data.get("total_chunks", 0),
processed_chunks=data.get("processed_chunks", 0),
failed_chunks=data.get("failed_chunks", 0),
error=data.get("error", ""),
processing_time=data.get("processing_time", 0.0),
retries=data.get("retries", 0),
)
if event_type in ("pipeline_complete", "pipeline_error"):
return
time.sleep(0.05)
def GetWorkerStatus(self, request, context):
"""Get worker health and capabilities."""
return worker_pb2.WorkerStatus(
available=True,
active_jobs=len(_active_jobs),
supported_codecs=[],
gpu_available=False,
)
def update_job_progress(
job_id: str,
progress: int,
current_frame: int = 0,
current_time: float = 0.0,
speed: float = 0.0,
status: str = "processing",
error: str = None,
**extra,
) -> None:
"""
Update job progress (called from worker tasks).
Updates both the in-memory gRPC state and the Django database.
Extra kwargs are stored for chunker-specific fields (total_chunks,
processed_chunks, failed_chunks, throughput_mbps, etc.).
"""
if job_id in _active_jobs:
_active_jobs[job_id].update(
{
"progress": progress,
"current_frame": current_frame,
"current_time": current_time,
"speed": speed,
"status": status,
"error": error,
**extra,
}
)
# Update Django database
try:
from django.utils import timezone
from core.db import update_job_fields
updates = {
"progress": progress,
"current_frame": current_frame,
"current_time": current_time,
"speed": str(speed),
"status": status,
}
if error:
updates["error_message"] = error
if status == "processing":
updates["started_at"] = timezone.now()
elif status in ("completed", "failed"):
updates["completed_at"] = timezone.now()
update_job_fields(job_id, **updates)
except Exception as e:
logger.warning(f"Failed to update job {job_id} in DB: {e}")
def serve(port: int = None) -> grpc.Server:
"""
Start the gRPC server.
Args:
port: Port to listen on (defaults to GRPC_PORT env var)
Returns:
The running gRPC server
"""
if port is None:
port = GRPC_PORT
server = grpc.server(futures.ThreadPoolExecutor(max_workers=GRPC_MAX_WORKERS))
worker_pb2_grpc.add_WorkerServiceServicer_to_server(
WorkerServicer(),
server,
)
server.add_insecure_port(f"[::]:{port}")
server.start()
logger.info(f"gRPC server started on port {port}")
return server
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
server = serve()
server.wait_for_termination()