322 lines
9.9 KiB
Python
322 lines
9.9 KiB
Python
"""
|
|
gRPC Server - Worker Service Implementation
|
|
|
|
Runs in the worker process to handle job submissions and progress streaming.
|
|
"""
|
|
|
|
import json
import logging
import os
import time
from concurrent import futures
from typing import Iterator, Optional

import grpc
|
|
|
|
# Configuration from environment
# GRPC_PORT: TCP port the server binds to (default 50051).
GRPC_PORT = int(os.environ.get("GRPC_PORT", "50051"))
# GRPC_MAX_WORKERS: thread-pool size for concurrent RPC handlers (default 10).
GRPC_MAX_WORKERS = int(os.environ.get("GRPC_MAX_WORKERS", "10"))

# Generated stubs - run `python schema/generate.py --proto` if missing
from . import worker_pb2, worker_pb2_grpc

logger = logging.getLogger(__name__)

# Active jobs progress tracking (shared state for streaming)
# Maps job_id -> mutable state dict ("status", "progress", "current_frame",
# "current_time", "speed", "error", optionally "celery_task_id").
# Written by update_job_progress() from worker tasks; read by the streaming
# RPCs. NOTE(review): plain dict with no lock — relies on per-item dict
# operations being atomic under the GIL across the server thread pool; confirm.
_active_jobs: dict[str, dict] = {}
|
|
|
|
|
|
class WorkerServicer(worker_pb2_grpc.WorkerServiceServicer):
    """gRPC service implementation for worker operations."""

    def __init__(self, celery_app=None):
        """
        Initialize the servicer.

        Args:
            celery_app: Optional Celery app for task dispatch
        """
        self.celery_app = celery_app

    def SubmitJob(self, request, context):
        """Submit a transcode/trim job to the worker.

        Parses the optional preset JSON, registers the job in the shared
        progress-tracking dict, and dispatches a Celery ``run_job`` task
        when a Celery app is configured.

        Returns:
            worker_pb2.JobResponse with accepted=True on success, or
            accepted=False with the error message on failure.
        """
        job_id = request.job_id
        logger.info(f"SubmitJob: {job_id}")

        try:
            # Parse preset (empty/absent preset_json means "no preset")
            preset = json.loads(request.preset_json) if request.preset_json else None

            # Initialize progress tracking up front so StreamProgress can
            # find the job before the worker task starts reporting.
            _active_jobs[job_id] = {
                "status": "pending",
                "progress": 0,
                "current_frame": 0,
                "current_time": 0.0,
                "speed": 0.0,
                "error": None,
            }

            # Dispatch to Celery if available
            if self.celery_app:
                from core.jobs.task import run_job

                payload = {
                    "source_key": request.source_path,
                    "output_key": request.output_path,
                    "preset": preset,
                    # trim_start/trim_end are optional proto fields; map
                    # "unset" to None for the task payload.
                    "trim_start": request.trim_start
                    if request.HasField("trim_start")
                    else None,
                    "trim_end": request.trim_end
                    if request.HasField("trim_end")
                    else None,
                }

                task = run_job.delay(
                    job_type="transcode",
                    job_id=job_id,
                    payload=payload,
                )
                _active_jobs[job_id]["celery_task_id"] = task.id
            # NOTE(review): without a celery_app the job is registered but
            # never dispatched, yet still reported accepted — presumably
            # intentional for test/standalone mode; confirm.

            return worker_pb2.JobResponse(
                job_id=job_id,
                accepted=True,
                message="Job submitted",
            )

        except Exception as e:
            logger.exception(f"SubmitJob failed: {e}")
            return worker_pb2.JobResponse(
                job_id=job_id,
                accepted=False,
                message=str(e),
            )

    def StreamProgress(self, request, context) -> Iterator[worker_pb2.ProgressUpdate]:
        """Stream progress updates for a job.

        Yields an update whenever the tracked progress OR status changes,
        then terminates once the job reaches a terminal state, the state
        entry disappears, or the client cancels the stream. Terminal jobs
        are removed from the tracking dict on exit.
        """
        job_id = request.job_id
        logger.info(f"StreamProgress: {job_id}")

        # Check if job exists
        if job_id not in _active_jobs:
            yield worker_pb2.ProgressUpdate(
                job_id=job_id,
                progress=0,
                status="not_found",
                error="Job not found",
            )
            return

        # Stream updates until job completes.
        # Fix: key the "changed?" test on (progress, status), not progress
        # alone — otherwise a terminal status arriving at an unchanged
        # progress value was never yielded before the loop broke.
        last_seen = (None, None)
        while True:
            if context.cancelled():
                logger.info(f"StreamProgress cancelled: {job_id}")
                break

            job_state = _active_jobs.get(job_id)
            if not job_state:
                break

            # Only yield if progress or status changed
            current = (job_state["progress"], job_state["status"])
            if current != last_seen:
                last_seen = current

                yield worker_pb2.ProgressUpdate(
                    job_id=job_id,
                    progress=job_state["progress"],
                    current_frame=job_state.get("current_frame", 0),
                    current_time=job_state.get("current_time", 0.0),
                    speed=job_state.get("speed", 0.0),
                    status=job_state["status"],
                    # Fix: proto3 scalar string fields raise TypeError on
                    # None; the tracking dict initializes "error" to None.
                    error=job_state.get("error") or "",
                )

            # Exit if job is done
            if job_state["status"] in ("completed", "failed", "cancelled"):
                break

            # Small delay to avoid busy loop
            time.sleep(0.1)

        # Cleanup completed jobs
        if job_id in _active_jobs:
            status = _active_jobs[job_id].get("status")
            if status in ("completed", "failed", "cancelled"):
                _active_jobs.pop(job_id, None)

    def CancelJob(self, request, context):
        """Cancel a running job.

        Marks the tracked job as cancelled and, when a Celery app is
        configured, revokes (terminates) the associated Celery task.

        Returns:
            worker_pb2.CancelResponse; cancelled=False if the job is
            not currently tracked.
        """
        job_id = request.job_id
        logger.info(f"CancelJob: {job_id}")

        if job_id in _active_jobs:
            _active_jobs[job_id]["status"] = "cancelled"

            # Revoke Celery task if available
            if self.celery_app:
                task_id = _active_jobs[job_id].get("celery_task_id")
                if task_id:
                    self.celery_app.control.revoke(task_id, terminate=True)

            return worker_pb2.CancelResponse(
                job_id=job_id,
                cancelled=True,
                message="Job cancelled",
            )

        return worker_pb2.CancelResponse(
            job_id=job_id,
            cancelled=False,
            message="Job not found",
        )

    def StreamChunkPipeline(self, request, context) -> Iterator[worker_pb2.ChunkPipelineEvent]:
        """Stream chunk pipeline events for a job.

        Polls core.events for new events past a cursor and forwards each
        as a ChunkPipelineEvent. Ends on a terminal pipeline event, client
        disconnect, or a 10-minute overall timeout.
        """
        from core.events import poll_events

        job_id = request.job_id
        logger.info(f"StreamChunkPipeline: {job_id}")

        cursor = 0
        timeout = time.monotonic() + 600  # 10 min max

        while context.is_active() and time.monotonic() < timeout:
            events, cursor = poll_events(job_id, cursor)

            for data in events:
                event_type = data.pop("event", "")
                yield worker_pb2.ChunkPipelineEvent(
                    job_id=job_id,
                    event_type=event_type,
                    sequence=data.get("sequence", 0),
                    worker_id=data.get("worker_id", ""),
                    state=data.get("state", ""),
                    queue_size=data.get("queue_size", 0),
                    elapsed=data.get("elapsed", 0.0),
                    throughput_mbps=data.get("throughput_mbps", 0.0),
                    total_chunks=data.get("total_chunks", 0),
                    processed_chunks=data.get("processed_chunks", 0),
                    failed_chunks=data.get("failed_chunks", 0),
                    error=data.get("error", ""),
                    processing_time=data.get("processing_time", 0.0),
                    retries=data.get("retries", 0),
                )

                # Terminal events end the stream immediately.
                if event_type in ("pipeline_complete", "pipeline_error"):
                    return

            # Small delay to avoid busy-polling the event source
            time.sleep(0.05)

    def GetWorkerStatus(self, request, context):
        """Get worker health and capabilities.

        Queries available ffmpeg encoders (best-effort; failures yield an
        empty codec list) and flags GPU availability when any hardware
        encoder name (nvenc/vaapi/qsv) is present.

        Returns:
            worker_pb2.WorkerStatus with at most 20 codec names.
        """
        try:
            from core.ffmpeg import get_encoders

            encoders = get_encoders()
            codec_names = [e["name"] for e in encoders.get("video", [])]
        except Exception:
            # Best-effort: the worker is still usable without codec info.
            codec_names = []

        # Check for GPU encoders
        gpu_available = any(
            "nvenc" in name or "vaapi" in name or "qsv" in name for name in codec_names
        )

        return worker_pb2.WorkerStatus(
            available=True,
            active_jobs=len(_active_jobs),
            supported_codecs=codec_names[:20],  # Limit to 20
            gpu_available=gpu_available,
        )
|
|
|
|
|
|
def update_job_progress(
    job_id: str,
    progress: int,
    current_frame: int = 0,
    current_time: float = 0.0,
    speed: float = 0.0,
    status: str = "processing",
    error: Optional[str] = None,  # fix: was annotated `str` with None default
    **extra,
) -> None:
    """
    Update job progress (called from worker tasks).

    Updates both the in-memory gRPC state and the Django database.
    Extra kwargs are stored for chunker-specific fields (total_chunks,
    processed_chunks, failed_chunks, throughput_mbps, etc.).
    DB failures are logged and swallowed so progress reporting never
    crashes the worker task.
    """
    # In-memory state read by the StreamProgress RPC; only updated when
    # the job is currently tracked (i.e. went through SubmitJob).
    if job_id in _active_jobs:
        _active_jobs[job_id].update(
            {
                "progress": progress,
                "current_frame": current_frame,
                "current_time": current_time,
                "speed": speed,
                "status": status,
                "error": error,
                **extra,
            }
        )

    # Update Django database (best-effort)
    try:
        from django.utils import timezone

        from core.db import update_job_fields

        updates = {
            "progress": progress,
            "current_frame": current_frame,
            "current_time": current_time,
            "speed": str(speed),  # DB field appears to be text; confirm
            "status": status,
        }

        if error:
            updates["error_message"] = error

        # NOTE(review): started_at is overwritten on EVERY "processing"
        # update, so it ends up holding the last update time rather than
        # the actual start — confirm whether update_job_fields ignores
        # already-set fields, otherwise this loses the true start time.
        if status == "processing":
            updates["started_at"] = timezone.now()
        elif status in ("completed", "failed"):
            updates["completed_at"] = timezone.now()

        update_job_fields(job_id, **updates)
    except Exception as e:
        logger.warning(f"Failed to update job {job_id} in DB: {e}")
|
|
|
|
|
|
def serve(port: Optional[int] = None, celery_app=None) -> grpc.Server:
    """
    Start the gRPC server.

    Creates a thread-pool server (GRPC_MAX_WORKERS handlers), registers
    the WorkerServicer, and begins listening on an insecure port. The
    caller is responsible for keeping the process alive (e.g. via
    ``server.wait_for_termination()``).

    Args:
        port: Port to listen on (defaults to GRPC_PORT env var)
        celery_app: Optional Celery app for task dispatch

    Returns:
        The running gRPC server
    """
    if port is None:
        port = GRPC_PORT

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=GRPC_MAX_WORKERS))
    worker_pb2_grpc.add_WorkerServiceServicer_to_server(
        WorkerServicer(celery_app=celery_app),
        server,
    )
    # [::] binds both IPv6 and (on dual-stack hosts) IPv4; no TLS here —
    # assumed to be reachable only inside the worker's private network.
    server.add_insecure_port(f"[::]:{port}")
    server.start()
    logger.info(f"gRPC server started on port {port}")
    return server
|
|
|
|
|
|
if __name__ == "__main__":
    # Standalone entry point: configure root logging, start the server
    # with defaults from the environment, and block until termination.
    logging.basicConfig(level=logging.INFO)
    server = serve()
    server.wait_for_termination()
|