131 lines
4.0 KiB
Python
131 lines
4.0 KiB
Python
"""
|
|
DetectHandler — runs the detection pipeline as a Celery job.

Supports three modes, selected by the payload:

- Initial run: {"video_path": "...", "profile_name": "..."}
- Replay: {"replay_from": "run_ocr", "source_job_id": "...", "config_overrides": {...}}
- Retry: {"retry_from": "escalate_vlm", "source_job_id": "...", "config_overrides": {...}}
|
|
"""
|
|
|
|
import logging
|
|
import os
|
|
import uuid
|
|
from typing import Any, Callable, Dict, Optional
|
|
|
|
from .base import Handler
|
|
|
|
# Module-level logger, named after this module's import path (__name__).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class DetectHandler(Handler):
    """Job handler that runs the detection pipeline.

    ``process`` dispatches on the payload:

    - ``replay_from`` or ``retry_from`` together with ``source_job_id``:
      resume from that stage of the source job via
      ``detect.checkpoint.replay_from``.
    - otherwise: a fresh run over ``payload["video_path"]``.
    """

    def process(
        self,
        job_id: str,
        payload: Dict[str, Any],
        progress_callback: Optional[Callable[[int, Dict[str, Any]], None]] = None,
    ) -> Dict[str, Any]:
        """Execute one detection job and return a summary dict.

        Args:
            job_id: Identifier of the job being executed.
            payload: Job parameters. Initial runs need ``video_path``
                (optionally ``profile_name``, ``source_asset_id``,
                ``checkpoint``, ``parent_job_id``). Replays/retries need
                ``replay_from`` or ``retry_from`` plus ``source_job_id``
                (optionally ``config_overrides``).
            progress_callback: Optional ``callback(percent, info)``; called
                with 0 before work starts and 100 after it succeeds.

        Returns:
            Dict with ``status``, ``job_id``, ``detections`` (count) and
            ``brands_found`` (count); replays also include ``source_job_id``
            and ``replay_from``.

        Raises:
            ValueError: a replay/retry stage was given without
                ``source_job_id``.
            KeyError: an initial run payload has no ``video_path``.
        """
        # Retry payloads have the same shape as replay payloads (stage +
        # source job + overrides), so both resume from the source job.
        # NOTE(review): confirm retry needs nothing beyond a replay.
        start_stage = payload.get("replay_from") or payload.get("retry_from")
        source_job_id = payload.get("source_job_id")

        if start_stage:
            if not source_job_id:
                # Without a source job there is nothing to resume from;
                # starting a fresh run instead would silently drop the request.
                raise ValueError(
                    f"DetectHandler: resuming from stage {start_stage!r} "
                    "requires 'source_job_id' in the payload"
                )
            return self._run_replay(
                job_id, source_job_id, start_stage, payload, progress_callback
            )

        return self._run_initial(job_id, payload, progress_callback)

    @staticmethod
    def _summarize(result: Dict[str, Any]) -> Dict[str, int]:
        """Reduce a pipeline result to detection and brand counts."""
        # ``or []`` also tolerates an explicit ``"detections": None``.
        detections = result.get("detections") or []
        report = result.get("report")
        return {
            "detections": len(detections),
            "brands_found": len(report.brands) if report else 0,
        }

    def _run_initial(
        self,
        job_id: str,
        payload: Dict[str, Any],
        progress_callback: Optional[Callable],
    ) -> Dict[str, Any]:
        """Run the full pipeline on ``payload["video_path"]``.

        Checkpointing follows the payload's ``checkpoint`` flag when present,
        otherwise the ``MPR_CHECKPOINT`` environment variable (``"1"`` enables
        it). The emit run context is set for this job and always cleared.
        """
        # Deferred imports. NOTE(review): presumably to keep handler import
        # cheap or avoid import cycles — confirm.
        from detect import emit
        from detect.graph import get_pipeline
        from detect.state import DetectState

        video_path = payload["video_path"]
        profile_name = payload.get("profile_name", "soccer_broadcast")
        source_asset_id = payload.get("source_asset_id", "")
        checkpoint_enabled = payload.get(
            "checkpoint", os.environ.get("MPR_CHECKPOINT") == "1"
        )

        emit.set_run_context(
            run_id=job_id,
            parent_job_id=payload.get("parent_job_id", job_id),
            run_type="initial",
        )
        # Guard everything after set_run_context: the context must be cleared
        # even if the callback, pipeline build, or state construction raises,
        # not only when invoke() does.
        try:
            logger.info("DetectHandler: initial run job=%s video=%s profile=%s checkpoint=%s",
                        job_id, video_path, profile_name, checkpoint_enabled)

            if progress_callback:
                progress_callback(0, {"stage": "starting"})

            pipeline = get_pipeline(checkpoint=checkpoint_enabled)
            initial_state = DetectState(
                video_path=video_path,
                job_id=job_id,
                profile_name=profile_name,
                source_asset_id=source_asset_id,
            )
            result = pipeline.invoke(initial_state)
        finally:
            emit.clear_run_context()

        if progress_callback:
            progress_callback(100, {"stage": "completed"})

        return {
            "status": "completed",
            "job_id": job_id,
            **self._summarize(result),
        }

    def _run_replay(
        self,
        job_id: str,
        source_job_id: str,
        start_stage: str,
        payload: Dict[str, Any],
        progress_callback: Optional[Callable],
    ) -> Dict[str, Any]:
        """Re-run ``source_job_id``'s pipeline starting at ``start_stage``.

        ``payload["config_overrides"]`` (default ``{}``) is forwarded to
        ``detect.checkpoint.replay_from``.
        """
        from detect.checkpoint import replay_from

        # ``or {}`` also normalizes an explicit ``"config_overrides": None``.
        config_overrides = payload.get("config_overrides") or {}

        logger.info("DetectHandler: replay job=%s from=%s source=%s overrides=%s",
                    job_id, start_stage, source_job_id, config_overrides)

        if progress_callback:
            progress_callback(0, {"stage": f"replaying from {start_stage}"})

        # NOTE(review): unlike _run_initial, no emit run context is set here
        # and this job's own job_id is not passed to replay_from — confirm
        # replay_from attributes its events/results correctly.
        result = replay_from(
            job_id=source_job_id,
            start_stage=start_stage,
            config_overrides=config_overrides,
        )

        if progress_callback:
            progress_callback(100, {"stage": "completed"})

        return {
            "status": "completed",
            "job_id": job_id,
            "source_job_id": source_job_id,
            "replay_from": start_stage,
            **self._summarize(result),
        }
|