401 lines
12 KiB
Python
401 lines
12 KiB
Python
"""
|
|
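API endpoints for checkpoint inspection, replay, retry, and GPU proxy.

GET  /detect/checkpoints/{timeline_id}          — list available checkpoints
GET  /detect/checkpoints/{timeline_id}/{stage}  — load checkpoint frames + metadata
GET  /detect/scenarios                          — list scenarios (bookmarked checkpoints)
POST /detect/replay                             — replay from a stage with config overrides
POST /detect/retry                              — queue async retry with different provider
POST /detect/replay-stage                       — replay single stage (fast path)
POST /detect/gpu/detect_edges                   — proxy to GPU inference server
POST /detect/gpu/detect_edges/debug             — proxy with debug overlays
POST /detect/gpu/segment_field                  — proxy to GPU field segmentation
POST /detect/gpu/segment_field/debug            — field segmentation with debug overlay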
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import os
|
|
|
|
from fastapi import APIRouter, HTTPException, Request, Response
|
|
from pydantic import BaseModel
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
router = APIRouter(prefix="/detect", tags=["detect"])
|
|
|
|
|
|
# --- Request/Response models ---
|
|
|
|
class CheckpointInfo(BaseModel):
    """One entry of the checkpoint listing returned by ``GET /detect/checkpoints/{timeline_id}``."""

    # Stage name the checkpoint was recorded for.
    stage: str

    # Scenario marker and its label; both default to the "not a scenario" values.
    is_scenario: bool = False
    scenario_label: str = ""
|
|
|
|
|
|
class ScenarioInfo(BaseModel):
    """One scenario (bookmarked checkpoint) as returned by ``GET /detect/scenarios``."""

    # Identity of the scenario: owning timeline (stringified id), stage, and label.
    timeline_id: str
    stage: str
    scenario_label: str

    # Copied from the owning timeline record.
    profile_name: str
    video_path: str  # first entry of the timeline's chunk paths, or "" when there are none

    # Optional extras. NOTE(review): the /scenarios endpoint in this file never
    # sets frame_count, so it is always reported as 0 from there.
    frame_count: int = 0
    created_at: str = ""  # str() of the scenario's creation time, "" when unset
|
|
|
|
|
|
class ReplayRequest(BaseModel):
    """Request body for ``POST /detect/replay``."""

    # Timeline whose pipeline run should be replayed.
    timeline_id: str

    # Stage to restart the pipeline from.
    start_stage: str

    # Optional configuration overrides forwarded unchanged to the replay call.
    config_overrides: dict | None = None
|
|
|
|
|
|
class ReplayResponse(BaseModel):
    """Response body for ``POST /detect/replay``."""

    # Outcome marker; the replay endpoint reports "completed".
    status: str

    # Echo of the request parameters.
    timeline_id: str
    start_stage: str

    # Summary counters taken from the replay result.
    detections: int = 0    # number of detections in the result
    brands_found: int = 0  # number of brands in the result's report, 0 without a report
|
|
|
|
|
|
class RetryRequest(BaseModel):
    """Request body for ``POST /detect/retry``."""

    # Timeline whose unresolved candidates should be retried.
    timeline_id: str

    # Optional configuration overrides passed on to the retry task.
    config_overrides: dict | None = None

    # Stage the retry starts from.
    start_stage: str = "escalate_vlm"

    # Delay before execution (off-peak); handed to the task as its ``countdown``.
    # A missing or zero value queues the task immediately.
    schedule_seconds: float | None = None
|
|
|
|
|
|
class RetryResponse(BaseModel):
    """Response body for ``POST /detect/retry``."""

    # Outcome marker; the retry endpoint reports "queued".
    status: str

    # Id of the queued background task.
    task_id: str

    # Echo of the requested timeline.
    timeline_id: str
|
|
|
|
|
|
class ReplaySingleStageRequest(BaseModel):
    """Request body for ``POST /detect/replay-stage``; every field is forwarded to the replay call."""

    # Timeline and stage to replay.
    timeline_id: str
    stage: str

    # Optional subset of frames to run on.
    # NOTE(review): presumably frame sequence numbers, with None meaning "all frames" — confirm.
    frame_refs: list[int] | None = None

    # Optional configuration overrides for this run.
    config_overrides: dict | None = None

    # Request debug output alongside the regions.
    debug: bool = False
|
|
|
|
|
|
class ReplaySingleStageBox(BaseModel):
    """One detected region in a single-stage replay response."""

    # Box geometry, copied from the same-named attributes of the replay result's boxes.
    # NOTE(review): presumably top-left corner plus width/height in pixels — confirm.
    x: int
    y: int
    w: int
    h: int

    # Score and label reported for the box.
    confidence: float
    label: str
|
|
|
|
|
|
class FrameDebugOverlays(BaseModel):
    """Per-frame debug payload of a single-stage replay; every field falls back to an empty value."""

    # Overlay images. NOTE(review): presumably base64-encoded, going by the field names — confirm.
    edge_overlay_b64: str = ""
    lines_overlay_b64: str = ""

    # Counters reported with the overlays.
    horizontal_count: int = 0
    pair_count: int = 0
|
|
|
|
|
|
class ReplaySingleStageResponse(BaseModel):
    """Response body for ``POST /detect/replay-stage``."""

    # Outcome marker ("completed") and the stage that was replayed.
    status: str
    stage: str

    # Number of frames with region entries, and total regions across them.
    frame_count: int = 0
    region_count: int = 0

    # Regions grouped by frame; keys are the stringified frame sequence numbers.
    regions_by_frame: dict[str, list[ReplaySingleStageBox]] = {}

    # Debug overlays, keyed by frame seq.
    debug: dict[str, FrameDebugOverlays] = {}
|
|
|
|
|
|
# --- Endpoints ---
|
|
|
|
@router.get("/checkpoints/{timeline_id}")
def list_checkpoints(timeline_id: str) -> list[CheckpointInfo]:
    """List available checkpoint stages for a job.

    Args:
        timeline_id: Identifier of the job/timeline to inspect.

    Returns:
        One ``CheckpointInfo`` per stage reported by the checkpoint store.

    Raises:
        HTTPException: 404 when the lookup fails for any reason.
    """
    from core.detect.checkpoint import list_checkpoints as _list

    try:
        stages = _list(timeline_id)
    except Exception as e:
        # Every failure is reported as 404 (existing contract). Log it so that
        # genuine backend errors stay distinguishable from "nothing stored".
        logger.warning("Checkpoint listing failed for %s", timeline_id, exc_info=True)
        raise HTTPException(
            status_code=404,
            detail=f"No checkpoints for job {timeline_id}: {e}",
        ) from e

    return [CheckpointInfo(stage=s) for s in stages]
|
|
|
|
|
|
class CheckpointFrameInfo(BaseModel):
    """One cached frame delivered to the editor UI inside ``CheckpointData``."""

    # Frame sequence number and its timestamp.
    seq: int
    timestamp: float

    # Frame image. NOTE(review): presumably a base64-encoded JPEG, going by the field name — confirm.
    jpeg_b64: str
|
|
|
|
|
|
class CheckpointData(BaseModel):
    """Response body for ``GET /detect/checkpoints/{timeline_id}/{stage}``."""

    # Echo of the requested timeline and stage.
    timeline_id: str
    stage: str

    # Copied from the timeline record.
    profile_name: str
    video_path: str  # first chunk path of the timeline, or "" when there are none

    # Scenario marker and label of the selected checkpoint.
    is_scenario: bool
    scenario_label: str

    # Frames read from the timeline's frame cache.
    frames: list[CheckpointFrameInfo]

    # Checkpoint statistics and the config overrides stored with it; empty when absent.
    stats: dict = {}
    config_snapshot: dict = {}

    # Key under which the stage output is found; the endpoint sets it to the stage name.
    stage_output_key: str = ""
|
|
|
|
|
|
@router.get("/checkpoints/{timeline_id}/{stage}", response_model=CheckpointData)
def get_checkpoint_data(timeline_id: str, stage: str):
    """Load checkpoint frames + metadata for the editor UI.

    Reads from the timeline's frame cache (local filesystem).

    Args:
        timeline_id: Timeline identifier; must be a valid UUID string.
        stage: Stage whose checkpoint is wanted. When no checkpoint exists for
            this stage, the last checkpoint in the listing is used instead.

    Returns:
        ``CheckpointData`` combining timeline fields, the selected checkpoint's
        metadata, and the cached frames.

    Raises:
        HTTPException: 400 when ``timeline_id`` is not a valid UUID; 404 when
            the timeline does not exist or has no checkpoints.
    """
    from uuid import UUID
    # Checkpoint is not referenced below; the import is kept as in the original.
    from core.db.models import Timeline, Checkpoint  # noqa: F401
    from core.db.connection import get_session
    from core.db.checkpoint import list_checkpoints
    from core.detect.checkpoint.frames import load_cached_frames_b64

    # Parse once up front: a malformed id is a client error, not a server crash.
    try:
        timeline_uuid = UUID(timeline_id)
    except ValueError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid timeline id: {timeline_id}",
        ) from e

    # Everything that touches ORM objects happens while the session is open.
    with get_session() as session:
        timeline = session.get(Timeline, timeline_uuid)
        if not timeline:
            raise HTTPException(status_code=404, detail=f"Timeline not found: {timeline_id}")

        checkpoints = list_checkpoints(session, timeline_uuid)
        if not checkpoints:
            raise HTTPException(status_code=404, detail=f"No checkpoints for timeline {timeline_id}")

        # Prefer a checkpoint for this stage; fall back to latest
        checkpoint = next(
            (c for c in reversed(checkpoints) if c.stage_name == stage),
            checkpoints[-1],
        )

        # Read from timeline's frame cache
        frame_list = [
            CheckpointFrameInfo(seq=f["seq"], timestamp=f["timestamp"], jpeg_b64=f["jpeg_b64"])
            for f in load_cached_frames_b64(timeline_id)
        ]

        return CheckpointData(
            timeline_id=timeline_id,
            stage=stage,
            profile_name=timeline.profile_name,
            video_path=timeline.chunk_paths[0] if timeline.chunk_paths else "",
            is_scenario=checkpoint.is_scenario,
            scenario_label=checkpoint.scenario_label,
            frames=frame_list,
            stats=checkpoint.stats or {},
            config_snapshot=checkpoint.config_overrides or {},
            stage_output_key=stage,
        )
|
|
|
|
|
|
@router.get("/scenarios", response_model=list[ScenarioInfo])
def list_scenarios_endpoint():
    """Return every stored scenario (bookmarked checkpoint) with its timeline details.

    Scenarios whose timeline can no longer be found are left out of the result.
    """
    from core.db.models import Timeline
    from core.db.connection import get_session
    from core.db.checkpoint import list_scenarios

    with get_session() as session:
        infos = []
        for scenario in list_scenarios(session):
            owner = session.get(Timeline, scenario.timeline_id)
            if owner:
                infos.append(
                    ScenarioInfo(
                        timeline_id=str(scenario.timeline_id),
                        stage=scenario.stage_name,
                        scenario_label=scenario.scenario_label,
                        profile_name=owner.profile_name,
                        video_path=owner.chunk_paths[0] if owner.chunk_paths else "",
                        created_at=str(scenario.created_at) if scenario.created_at else "",
                    )
                )
        return infos
|
|
|
|
|
|
@router.post("/replay", response_model=ReplayResponse)
def replay(req: ReplayRequest):
    """Replay pipeline from a specific stage with optional config overrides.

    Args:
        req: Timeline, start stage, and optional config overrides.

    Returns:
        ``ReplayResponse`` with status "completed", the number of detections,
        and the number of brands in the report (0 when there is no report).

    Raises:
        HTTPException: 400 when the replay rejects its input with
            ``ValueError``; 500 for any other failure.
    """
    from core.detect.checkpoint import replay_from

    try:
        result = replay_from(
            timeline_id=req.timeline_id,
            start_stage=req.start_stage,
            config_overrides=req.config_overrides,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # Once converted to an HTTPException the traceback is gone, so log it here.
        logger.exception(
            "Replay failed for timeline %s from stage %s",
            req.timeline_id,
            req.start_stage,
        )
        raise HTTPException(status_code=500, detail=f"Replay failed: {e}") from e

    # Tolerate a result whose "detections" entry is missing or explicitly None.
    detections = result.get("detections")
    report = result.get("report")

    return ReplayResponse(
        status="completed",
        timeline_id=req.timeline_id,
        start_stage=req.start_stage,
        detections=len(detections) if detections is not None else 0,
        brands_found=len(report.brands) if report else 0,
    )
|
|
|
|
|
|
@router.post("/retry", response_model=RetryResponse)
def retry(req: RetryRequest):
    """Enqueue a background retry of unresolved candidates under a different config.

    A truthy ``schedule_seconds`` delays the task by that many seconds (passed
    as ``countdown``); otherwise the task is queued right away. The reply
    carries the queued task's id.
    """
    from core.detect.checkpoint.tasks import retry_candidates

    task_kwargs = dict(
        timeline_id=req.timeline_id,
        config_overrides=req.config_overrides,
        start_stage=req.start_stage,
    )

    delay_seconds = req.schedule_seconds
    queued = (
        retry_candidates.apply_async(kwargs=task_kwargs, countdown=delay_seconds)
        if delay_seconds
        else retry_candidates.delay(**task_kwargs)
    )

    return RetryResponse(
        status="queued",
        task_id=queued.id,
        timeline_id=req.timeline_id,
    )
|
|
|
|
|
|
@router.post("/replay-stage", response_model=ReplaySingleStageResponse)
def replay_single_stage(req: ReplaySingleStageRequest):
    """Replay a single stage on specific frames — fast path for interactive tuning.

    Args:
        req: Timeline, stage, optional frame subset, config overrides, and
            debug flag; all forwarded to the underlying replay call.

    Returns:
        ``ReplaySingleStageResponse`` with the regions grouped by stringified
        frame seq, the frame and region counts, and any debug overlays.

    Raises:
        HTTPException: 400 when the replay rejects its input with
            ``ValueError``; 500 for any other failure.
    """
    from core.detect.checkpoint.replay import replay_single_stage as _replay

    try:
        result = _replay(
            timeline_id=req.timeline_id,
            stage=req.stage,
            frame_refs=req.frame_refs,
            config_overrides=req.config_overrides,
            debug=req.debug,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # Once converted to an HTTPException the traceback is gone, so log it here.
        logger.exception(
            "Single-stage replay failed for timeline %s, stage %s",
            req.timeline_id,
            req.stage,
        )
        raise HTTPException(status_code=500, detail=f"Single-stage replay failed: {e}") from e

    # Convert result to response format. `or {}` also covers a key that is
    # present but None.
    regions_by_frame = result.get("edge_regions_by_frame") or {}
    total_regions = 0
    serialized = {}
    for seq, boxes in regions_by_frame.items():
        box_list = [
            ReplaySingleStageBox(
                x=b.x, y=b.y, w=b.w, h=b.h,
                confidence=b.confidence, label=b.label,
            )
            for b in boxes
        ]
        serialized[str(seq)] = box_list
        total_regions += len(box_list)

    # Serialize debug overlays if present; absent fields fall back to empty values.
    raw_debug = result.get("debug") or {}
    debug_out = {
        str(seq): FrameDebugOverlays(
            edge_overlay_b64=d.get("edge_overlay_b64", ""),
            lines_overlay_b64=d.get("lines_overlay_b64", ""),
            horizontal_count=d.get("horizontal_count", 0),
            pair_count=d.get("pair_count", 0),
        )
        for seq, d in raw_debug.items()
    }

    return ReplaySingleStageResponse(
        status="completed",
        stage=req.stage,
        frame_count=len(regions_by_frame),
        region_count=total_regions,
        regions_by_frame=serialized,
        debug=debug_out,
    )
|
|
|
|
|
|
# --- GPU proxy — thin passthrough to inference server for interactive editor ---
|
|
|
|
|
|
def _gpu_url() -> str:
|
|
url = os.environ.get("INFERENCE_URL", "http://localhost:8000")
|
|
return url.rstrip("/")
|
|
|
|
|
|
async def _proxy_to_gpu(request: Request, path: str, timeout: float = 30.0) -> Response:
    """Forward the raw request body to the GPU inference server and relay its reply.

    Args:
        request: Incoming request; its body is passed through unmodified.
        path: Path on the inference server, starting with ``/``.
        timeout: Timeout in seconds given to the HTTP client.

    Returns:
        A response with the upstream body and status code, typed as JSON.

    Raises:
        HTTPException: 502 when the upstream call fails for any reason.
    """
    import httpx

    body = await request.body()
    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            resp = await client.post(
                f"{_gpu_url()}{path}",
                content=body,
                headers={"Content-Type": "application/json"},
            )
    except Exception as e:
        # Record the underlying cause; the client only receives the 502 summary.
        logger.warning("GPU proxy call to %s failed", path, exc_info=True)
        raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}") from e

    return Response(
        content=resp.content,
        status_code=resp.status_code,
        media_type="application/json",
    )


@router.post("/gpu/detect_edges")
async def gpu_detect_edges(request: Request):
    """Proxy to GPU inference server — browser can't reach it directly."""
    return await _proxy_to_gpu(request, "/detect_edges")


@router.post("/gpu/detect_edges/debug")
async def gpu_detect_edges_debug(request: Request):
    """Proxy to GPU inference server debug endpoint."""
    return await _proxy_to_gpu(request, "/detect_edges/debug")


@router.post("/gpu/segment_field")
async def gpu_segment_field(request: Request):
    """Proxy to GPU inference server — field segmentation."""
    return await _proxy_to_gpu(request, "/segment_field")


@router.post("/gpu/segment_field/debug")
async def gpu_segment_field_debug(request: Request):
    """Proxy to GPU inference server — field segmentation with debug overlay."""
    return await _proxy_to_gpu(request, "/segment_field/debug")
|