209 lines
6.0 KiB
Python
209 lines
6.0 KiB
Python
"""
gRPC client used by the FastAPI app to communicate with workers.
"""
import json
import logging
import os
from typing import Callable, Iterator, Optional

import grpc

# Generated stubs - run `python schema/generate.py --proto` if missing.
# The package-relative import is tried first; if it raises ImportError (for
# example because this module was not imported as part of a package), the
# stubs are imported as top-level modules instead.
try:
    from . import worker_pb2, worker_pb2_grpc
except ImportError:
    import worker_pb2
    import worker_pb2_grpc
logger = logging.getLogger(__name__)

# Worker gRPC endpoint, resolved once at import time from the environment.
# Unset variables fall back to host "grpc" and port 50051; a GRPC_PORT value
# that is not an integer makes int() raise ValueError during import.
GRPC_HOST: str = os.getenv("GRPC_HOST", "grpc")
GRPC_PORT: int = int(os.getenv("GRPC_PORT", "50051"))
class WorkerClient:
    """gRPC client for worker communication.

    The channel is created with ``grpc.insecure_channel`` (no TLS). It is
    opened lazily on the first RPC, or when entering a ``with`` block, and
    released by :meth:`close`. ``grpc.RpcError`` failures are logged and
    reported through return values / a final yielded update, not raised.
    """

    # Progress statuses that end iteration in stream_progress().
    _TERMINAL_STATUSES = ("completed", "failed", "cancelled")

    def __init__(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        *,
        timeout: Optional[float] = None,
    ):
        """
        Initialize the client. No channel is created until first use.

        Args:
            host: gRPC server host (defaults to GRPC_HOST env var). Any falsy
                value (None, "") falls back to the default.
            port: gRPC server port (defaults to GRPC_PORT env var). Any falsy
                value (None, 0) falls back to the default.
            timeout: Optional deadline in seconds for the unary RPCs
                (SubmitJob, CancelJob, GetWorkerStatus). None (the default)
                means no deadline. It is intentionally not applied to
                StreamProgress, where a deadline would bound the whole stream.
        """
        self.host = host or GRPC_HOST
        self.port = port or GRPC_PORT
        self.address = f"{self.host}:{self.port}"
        self.timeout = timeout
        self._channel: Optional[grpc.Channel] = None
        self._stub: Optional[worker_pb2_grpc.WorkerServiceStub] = None

    def _ensure_connected(self) -> worker_pb2_grpc.WorkerServiceStub:
        """Create the channel and stub on first call and return the stub."""
        # NOTE: no locking here - concurrent first calls from several threads
        # could each create a channel.
        if self._channel is None:
            self._channel = grpc.insecure_channel(self.address)
            self._stub = worker_pb2_grpc.WorkerServiceStub(self._channel)
        return self._stub

    def _call_options(self) -> dict:
        """Return extra keyword arguments for unary stub calls.

        ``timeout`` is only included when a deadline was configured, so
        without one the stub is called with the request alone.
        """
        if self.timeout is None:
            return {}
        return {"timeout": self.timeout}

    def close(self) -> None:
        """Close the channel if one is open; calling it again is a no-op.

        The client stays usable: the next RPC opens a new channel.
        """
        if self._channel:
            self._channel.close()
            self._channel = None
            self._stub = None

    def __enter__(self):
        """Create the channel up front and return this client."""
        self._ensure_connected()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the channel; exceptions from the ``with`` body propagate."""
        self.close()

    def submit_job(
        self,
        job_id: str,
        source_path: str,
        output_path: str,
        preset: Optional[dict] = None,
        trim_start: Optional[float] = None,
        trim_end: Optional[float] = None,
    ) -> tuple[bool, str]:
        """
        Submit a job to the worker.

        Args:
            job_id: Unique job identifier
            source_path: Path to source file
            output_path: Path for output file
            preset: Transcode preset dict (optional); serialized to JSON.
            trim_start: Trim start time in seconds (optional)
            trim_end: Trim end time in seconds (optional)

        Returns:
            Tuple of (accepted: bool, message: str). If the RPC fails this is
            (False, str(error)) instead of raising.
        """
        stub = self._ensure_connected()

        request = worker_pb2.JobRequest(
            job_id=job_id,
            source_path=source_path,
            output_path=output_path,
            # None and an empty dict are both sent as "" (not "null"/"{}").
            preset_json=json.dumps(preset) if preset else "",
        )
        # Trim fields are only assigned when given, so 0.0 is still an
        # explicit value. Presumably the proto declares them `optional` so
        # the worker can tell "unset" from 0.0 - confirm against worker.proto.
        if trim_start is not None:
            request.trim_start = trim_start
        if trim_end is not None:
            request.trim_end = trim_end

        try:
            response = stub.SubmitJob(request, **self._call_options())
        except grpc.RpcError as e:
            logger.error("SubmitJob RPC failed: %s", e)
            return False, str(e)
        return response.accepted, response.message

    def stream_progress(
        self,
        job_id: str,
        callback: Optional[Callable[[dict], None]] = None,
    ) -> Iterator[dict]:
        """
        Stream progress updates for a job.

        This is a generator: the channel and the streaming RPC are only
        started on first iteration. Iteration ends after an update whose
        status is "completed", "failed" or "cancelled". The streaming call
        is cancelled whenever the generator finishes or is closed early
        (e.g. the consumer stops iterating), so it is not left open.

        Args:
            job_id: Job identifier
            callback: Optional callback invoked with each update dict,
                including the final error update, before it is yielded.

        Yields:
            Progress update dicts with keys job_id, progress, current_frame,
            current_time, speed, status and error. If the RPC fails, one
            final dict is yielded with status "error", progress 0, the error
            text in "error" and None for current_frame/current_time/speed.
        """
        stub = self._ensure_connected()

        request = worker_pb2.ProgressRequest(job_id=job_id)

        call = None
        try:
            call = stub.StreamProgress(request)
            for update in call:
                progress = {
                    "job_id": update.job_id,
                    "progress": update.progress,
                    "current_frame": update.current_frame,
                    "current_time": update.current_time,
                    "speed": update.speed,
                    "status": update.status,
                    "error": update.error if update.HasField("error") else None,
                }

                if callback is not None:
                    callback(progress)

                yield progress

                if update.status in self._TERMINAL_STATUSES:
                    break

        except grpc.RpcError as e:
            logger.error("StreamProgress RPC failed: %s", e)
            # Same keys as a regular update so consumers can index uniformly.
            failure = {
                "job_id": job_id,
                "progress": 0,
                "current_frame": None,
                "current_time": None,
                "speed": None,
                "status": "error",
                "error": str(e),
            }
            if callback is not None:
                callback(failure)
            yield failure

        finally:
            # cancel() is a no-op on an already-terminated call; otherwise it
            # tears down the stream when the consumer abandons the generator.
            if call is not None:
                call.cancel()

    def cancel_job(self, job_id: str) -> tuple[bool, str]:
        """
        Cancel a running job.

        Args:
            job_id: Job identifier

        Returns:
            Tuple of (cancelled: bool, message: str). If the RPC fails this
            is (False, str(error)) instead of raising.
        """
        stub = self._ensure_connected()

        request = worker_pb2.CancelRequest(job_id=job_id)

        try:
            response = stub.CancelJob(request, **self._call_options())
        except grpc.RpcError as e:
            logger.error("CancelJob RPC failed: %s", e)
            return False, str(e)
        return response.cancelled, response.message

    def get_worker_status(self) -> Optional[dict]:
        """
        Get worker status and capabilities.

        Returns:
            Dict with keys available, active_jobs, supported_codecs (list)
            and gpu_available, or None if the RPC fails.
        """
        stub = self._ensure_connected()

        try:
            response = stub.GetWorkerStatus(worker_pb2.Empty(), **self._call_options())
        except grpc.RpcError as e:
            logger.error("GetWorkerStatus RPC failed: %s", e)
            return None
        return {
            "available": response.available,
            "active_jobs": response.active_jobs,
            "supported_codecs": list(response.supported_codecs),
            "gpu_available": response.gpu_available,
        }
# Process-wide client, created on the first get_client() call.
_client: Optional[WorkerClient] = None


def get_client() -> WorkerClient:
    """Return the shared WorkerClient, constructing it on first use.

    The client is built with no arguments, so it uses the GRPC_HOST and
    GRPC_PORT module settings (read from the environment at import time).
    """
    global _client
    if _client is not None:
        return _client
    _client = WorkerClient()
    return _client