"""
|
|
Pipeline run endpoints.
|
|
|
|
POST /detect/run — launch pipeline on a timeline
|
|
POST /detect/stop/{job_id} — cancel a running pipeline
|
|
POST /detect/pause/{job_id} — pause after current stage
|
|
POST /detect/resume/{job_id} — resume a paused pipeline
|
|
POST /detect/step/{job_id} — run one stage then pause
|
|
POST /detect/clear/{job_id} — clear events from Redis
|
|
GET /detect/status/{job_id} — pipeline run status
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import os
|
|
import threading
|
|
import uuid
|
|
|
|
from fastapi import APIRouter, HTTPException
|
|
from pydantic import BaseModel
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
router = APIRouter(prefix="/detect", tags=["detect"])
|
|
|
|
# In-process pipeline tracking. These are plain module globals, so they only
# see runs started by this process — NOTE(review): assumes a single worker
# process; confirm the deployment does not run multiple workers.
_running_jobs: dict[str, threading.Thread] = {}  # job_id -> worker thread (removed in the thread's finally block)
_cancelled_jobs: set[str] = set()  # job_ids with a pending stop request; polled by the graph's cancel check
|
|
|
|
|
|
class RunRequest(BaseModel):
    """Request body for POST /detect/run."""

    # Timeline whose chunks will be processed (looked up via get_timeline).
    timeline_id: str
    # Detection profile forwarded to get_pipeline and the job record.
    profile_name: str = "soccer_broadcast"
    # Build the pipeline with checkpointing enabled.
    checkpoint: bool = True
    # Sets/clears the process-wide SKIP_VLM env var for this run.
    skip_vlm: bool = False
    # Sets/clears the process-wide SKIP_CLOUD env var for this run.
    skip_cloud: bool = False
    log_level: str = "INFO"  # INFO | DEBUG
    # Start the run in pause-after-each-stage (stepping) mode.
    pause_after_stage: bool = False
    # Optional config overrides stored on the job and passed into DetectState.
    config_overrides: dict | None = None
|
|
|
|
|
|
class RunResponse(BaseModel):
    """Response body for POST /detect/run."""

    status: str  # "started" when the worker thread was launched
    job_id: str  # DB job id; key for the other /detect/{action}/{job_id} endpoints
    timeline_id: str  # echo of the requested timeline id
|
|
|
|
|
|
def _resolve_video_path(video_path: str) -> str:
    """Download a chunk from blob storage to a local temp file.

    Args:
        video_path: Blob-store path of the chunk (as stored on the timeline).

    Returns:
        Local filesystem path of the downloaded temp file.

    Raises:
        HTTPException: 400 when the download fails for any reason.
    """
    from core.storage.blob import get_store

    store = get_store("out")
    try:
        return store.download_to_temp(video_path)
    except Exception as e:
        # Chain the original exception so logs show the real cause instead of
        # "During handling of the above exception, another exception occurred".
        raise HTTPException(status_code=400, detail=f"Failed to download chunk: {e}") from e
|
|
|
|
|
|
@router.post("/run", response_model=RunResponse)
|
|
def run_pipeline(req: RunRequest):
|
|
"""Launch a detection pipeline run on a timeline."""
|
|
from core.detect import emit
|
|
from core.detect.graph import get_pipeline
|
|
from core.detect.state import DetectState
|
|
from core.detect.checkpoint.storage import get_timeline
|
|
from core.db.connection import get_session
|
|
from core.db.job import create_job, update_job_status
|
|
|
|
# Load timeline
|
|
try:
|
|
timeline = get_timeline(req.timeline_id)
|
|
except ValueError:
|
|
raise HTTPException(status_code=404, detail=f"Timeline not found: {req.timeline_id}")
|
|
|
|
chunk_paths = timeline["chunk_paths"]
|
|
if not chunk_paths:
|
|
raise HTTPException(status_code=400, detail="Timeline has no chunk paths")
|
|
|
|
# Resolve first chunk to local path for the pipeline
|
|
local_path = _resolve_video_path(chunk_paths[0])
|
|
|
|
# Create job in DB
|
|
source_asset_id_str = timeline.get("source_asset_id", "")
|
|
with get_session() as session:
|
|
from uuid import UUID as _UUID
|
|
source_asset_id = _UUID(source_asset_id_str) if source_asset_id_str else uuid.uuid4()
|
|
job = create_job(
|
|
session,
|
|
source_asset_id=source_asset_id,
|
|
video_path=chunk_paths[0],
|
|
timeline_id=_UUID(req.timeline_id),
|
|
profile_name=req.profile_name,
|
|
config_overrides=req.config_overrides,
|
|
)
|
|
job_id = str(job.id)
|
|
|
|
if req.skip_vlm:
|
|
os.environ["SKIP_VLM"] = "1"
|
|
elif "SKIP_VLM" in os.environ:
|
|
del os.environ["SKIP_VLM"]
|
|
|
|
if req.skip_cloud:
|
|
os.environ["SKIP_CLOUD"] = "1"
|
|
elif "SKIP_CLOUD" in os.environ:
|
|
del os.environ["SKIP_CLOUD"]
|
|
|
|
# Clear any stale events
|
|
from core.events import _get_redis
|
|
from core.detect.events import DETECT_EVENTS_PREFIX
|
|
r = _get_redis()
|
|
r.delete(f"{DETECT_EVENTS_PREFIX}:{job_id}")
|
|
|
|
emit.set_run_context(
|
|
run_id=job_id, parent_job_id=job_id, run_type="initial",
|
|
log_level=req.log_level,
|
|
)
|
|
|
|
pipeline = get_pipeline(checkpoint=req.checkpoint, profile_name=req.profile_name)
|
|
|
|
initial_state = DetectState(
|
|
video_path=local_path,
|
|
job_id=job_id,
|
|
profile_name=req.profile_name,
|
|
source_asset_id=source_asset_id_str or str(source_asset_id),
|
|
timeline_id=req.timeline_id,
|
|
config_overrides=req.config_overrides or {},
|
|
)
|
|
|
|
from core.detect.graph import (
|
|
PipelineCancelled, set_cancel_check, clear_cancel_check,
|
|
init_pause, clear_pause,
|
|
)
|
|
|
|
set_cancel_check(job_id, lambda: job_id in _cancelled_jobs)
|
|
init_pause(job_id, pause_after_stage=req.pause_after_stage)
|
|
|
|
def _update_job(status, stage=None, error=None):
|
|
from core.db.connection import get_session
|
|
from core.db.job import update_job_status
|
|
with get_session() as session:
|
|
update_job_status(session, _UUID(job_id), status,
|
|
current_stage=stage, error_message=error)
|
|
|
|
def _run():
|
|
try:
|
|
_update_job("running")
|
|
emit.log(job_id, "Pipeline", "INFO",
|
|
f"Starting pipeline: {chunk_paths[0]} (profile={req.profile_name})")
|
|
pipeline.invoke(initial_state)
|
|
_update_job("completed")
|
|
emit.log(job_id, "Pipeline", "INFO", "Pipeline completed successfully")
|
|
emit.job_complete(job_id, {"status": "completed"})
|
|
except PipelineCancelled:
|
|
_update_job("cancelled")
|
|
emit.log(job_id, "Pipeline", "INFO", "Pipeline cancelled")
|
|
emit.job_complete(job_id, {"status": "cancelled"})
|
|
except Exception as e:
|
|
logger.exception("Pipeline run %s failed: %s", job_id, e)
|
|
_update_job("failed", error=str(e))
|
|
from core.detect.graph import _node_states, NODES
|
|
if job_id in _node_states:
|
|
states = _node_states[job_id]
|
|
for node in reversed(NODES):
|
|
if states.get(node) in ("running", "done"):
|
|
states[node] = "error"
|
|
break
|
|
nodes = [{"id": n, "status": states[n]} for n in NODES]
|
|
emit.graph_update(job_id, nodes)
|
|
emit.log(job_id, "Pipeline", "ERROR", str(e))
|
|
emit.job_complete(job_id, {"status": "failed", "error": str(e)})
|
|
finally:
|
|
_running_jobs.pop(job_id, None)
|
|
_cancelled_jobs.discard(job_id)
|
|
clear_cancel_check(job_id)
|
|
clear_pause(job_id)
|
|
emit.clear_run_context()
|
|
from core.detect.checkpoint.runner_bridge import reset_checkpoint_state
|
|
reset_checkpoint_state(job_id)
|
|
|
|
thread = threading.Thread(target=_run, daemon=True, name=f"pipeline-{job_id}")
|
|
_running_jobs[job_id] = thread
|
|
thread.start()
|
|
|
|
return RunResponse(status="started", job_id=job_id, timeline_id=req.timeline_id)
|
|
|
|
|
|
@router.post("/stop/{job_id}")
|
|
def stop_pipeline(job_id: str):
|
|
"""Stop a running pipeline. Signals cancellation; the thread checks on next stage."""
|
|
from core.detect import emit
|
|
|
|
if job_id not in _running_jobs:
|
|
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
|
|
|
|
_cancelled_jobs.add(job_id)
|
|
emit.log(job_id, "Pipeline", "INFO", "Stop requested — cancelling after current stage")
|
|
return {"status": "stopping", "job_id": job_id}
|
|
|
|
|
|
@router.post("/pause/{job_id}")
|
|
def pause(job_id: str):
|
|
"""Pause a running pipeline after the current stage completes."""
|
|
from core.detect.graph import pause_pipeline
|
|
|
|
if job_id not in _running_jobs:
|
|
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
|
|
|
|
pause_pipeline(job_id)
|
|
return {"status": "pausing", "job_id": job_id}
|
|
|
|
|
|
@router.post("/resume/{job_id}")
|
|
def resume(job_id: str):
|
|
"""Resume a paused pipeline."""
|
|
from core.detect.graph import resume_pipeline
|
|
|
|
if job_id not in _running_jobs:
|
|
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
|
|
|
|
resume_pipeline(job_id)
|
|
return {"status": "running", "job_id": job_id}
|
|
|
|
|
|
@router.post("/step/{job_id}")
|
|
def step(job_id: str):
|
|
"""Run one stage then pause again."""
|
|
from core.detect.graph import step_pipeline
|
|
|
|
if job_id not in _running_jobs:
|
|
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
|
|
|
|
step_pipeline(job_id)
|
|
return {"status": "stepping", "job_id": job_id}
|
|
|
|
|
|
@router.post("/pause-after-stage/{job_id}")
|
|
def toggle_pause_after_stage(job_id: str, enabled: bool = True):
|
|
"""Toggle pause-after-each-stage mode."""
|
|
from core.detect.graph import set_pause_after_stage
|
|
|
|
if job_id not in _running_jobs:
|
|
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
|
|
|
|
set_pause_after_stage(job_id, enabled)
|
|
return {"status": "ok", "pause_after_stage": enabled, "job_id": job_id}
|
|
|
|
|
|
@router.get("/status/{job_id}")
|
|
def pipeline_status(job_id: str):
|
|
"""Get pipeline run status."""
|
|
from core.detect.graph import is_paused
|
|
|
|
running = job_id in _running_jobs
|
|
paused = is_paused(job_id)
|
|
cancelling = job_id in _cancelled_jobs
|
|
|
|
if cancelling:
|
|
status = "cancelling"
|
|
elif paused:
|
|
status = "paused"
|
|
elif running:
|
|
status = "running"
|
|
else:
|
|
status = "idle"
|
|
|
|
return {"status": status, "job_id": job_id}
|
|
|
|
|
|
@router.post("/clear/{job_id}")
|
|
def clear_pipeline(job_id: str):
|
|
"""Clear events for a job from Redis."""
|
|
from core.events import _get_redis
|
|
from core.detect.events import DETECT_EVENTS_PREFIX
|
|
|
|
r = _get_redis()
|
|
r.delete(f"{DETECT_EVENTS_PREFIX}:{job_id}")
|
|
return {"status": "cleared", "job_id": job_id}
|