122 lines
3.2 KiB
Python
122 lines
3.2 KiB
Python
"""
|
|
API endpoints for checkpoint inspection, replay, and retry.
|
|
|
|
GET /detect/checkpoints/{job_id} — list available checkpoints
|
|
POST /detect/replay — replay from a stage with config overrides
|
|
POST /detect/retry — queue async retry with different provider
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
|
|
from fastapi import APIRouter, HTTPException
|
|
from pydantic import BaseModel
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router for the endpoints defined below; every route path is prefixed with
# "/detect" and carries the "detect" tag.
router = APIRouter(prefix="/detect", tags=["detect"])
|
|
|
|
|
|
# --- Request/Response models ---
|
|
|
|
class CheckpointInfo(BaseModel):
    """One entry in the checkpoint listing returned for a job."""

    # Name of a stage that has a checkpoint available.
    stage: str
|
|
|
|
|
|
class ReplayRequest(BaseModel):
    """Request body for POST /detect/replay."""

    # Identifier of the job to replay.
    job_id: str
    # Stage the replay starts from.
    start_stage: str
    # Optional config overrides applied to the replay; None when omitted.
    config_overrides: dict | None = None
|
|
|
|
|
|
class ReplayResponse(BaseModel):
    """Response body for POST /detect/replay."""

    # Outcome label; the replay endpoint in this module sets "completed".
    status: str
    # Job and stage the replay was run for (echoed from the request).
    job_id: str
    start_stage: str
    # Count of detections in the replay result (defaults to 0).
    detections: int = 0
    # Count of brands on the replay report (defaults to 0).
    brands_found: int = 0
|
|
|
|
|
|
class RetryRequest(BaseModel):
    """Request body for POST /detect/retry."""

    # Identifier of the job to retry.
    job_id: str
    # Optional config overrides forwarded to the retry task; None when omitted.
    config_overrides: dict | None = None
    # Stage the retry starts from; defaults to "escalate_vlm".
    start_stage: str = "escalate_vlm"
    # Delay before execution (off-peak). The retry endpoint passes it as the
    # task countdown; when omitted (or 0) the task is queued without a delay.
    schedule_seconds: float | None = None
|
|
|
|
|
|
class RetryResponse(BaseModel):
    """Response body for POST /detect/retry."""

    # Outcome label; the retry endpoint in this module sets "queued".
    status: str
    # Identifier of the queued task.
    task_id: str
    # Job the retry was queued for (echoed from the request).
    job_id: str
|
|
|
|
|
|
# --- Endpoints ---
|
|
|
|
@router.get("/checkpoints/{job_id}")
def list_checkpoints(job_id: str) -> list[CheckpointInfo]:
    """List available checkpoint stages for a job.

    Args:
        job_id: Identifier of the job whose checkpoints are listed.

    Returns:
        One ``CheckpointInfo`` per stage returned by
        ``detect.checkpoint.list_checkpoints``, in the order it yields them.

    Raises:
        HTTPException: 404 if the checkpoint lookup raises any exception.
    """
    # Function-scope import: resolved when the endpoint is called, not when
    # this module is imported.
    from detect.checkpoint import list_checkpoints as _list

    try:
        stages = _list(job_id)
    except Exception as e:
        # Every failure is reported to the client as 404, so record the real
        # cause; otherwise a backend error is indistinguishable from a job
        # that simply has no checkpoints.
        logger.warning(
            "Checkpoint lookup failed for job %s", job_id, exc_info=True
        )
        raise HTTPException(
            status_code=404, detail=f"No checkpoints for job {job_id}: {e}"
        ) from e

    return [CheckpointInfo(stage=s) for s in stages]
|
|
|
|
|
|
@router.post("/replay", response_model=ReplayResponse)
def replay(req: ReplayRequest):
    """Replay pipeline from a specific stage with optional config overrides.

    Args:
        req: Job id, stage to start from, and optional config overrides.

    Returns:
        A ``ReplayResponse`` with status ``"completed"``, the request's job id
        and start stage, and the number of detections and brands found in the
        replay result.

    Raises:
        HTTPException: 400 if ``replay_from`` raises ``ValueError``; 500 if it
            raises any other exception.
    """
    # Function-scope import: resolved when the endpoint is called, not when
    # this module is imported.
    from detect.checkpoint import replay_from

    try:
        result = replay_from(
            job_id=req.job_id,
            start_stage=req.start_stage,
            config_overrides=req.config_overrides,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # Keep the traceback in the server log; the client only sees the
        # exception message.
        logger.exception(
            "Replay failed for job %s from stage %s", req.job_id, req.start_stage
        )
        raise HTTPException(status_code=500, detail=f"Replay failed: {e}") from e

    # A "detections" key that is present but None counts as zero instead of
    # crashing len(); a missing or falsy report likewise yields zero brands.
    detections = result.get("detections")
    detection_count = len(detections) if detections is not None else 0
    report = result.get("report")
    brands_found = len(report.brands) if report else 0

    return ReplayResponse(
        status="completed",
        job_id=req.job_id,
        start_stage=req.start_stage,
        detections=detection_count,
        brands_found=brands_found,
    )
|
|
|
|
|
|
@router.post("/retry", response_model=RetryResponse)
def retry(req: RetryRequest):
    """Queue an async retry of unresolved candidates with different config.

    The job id, config overrides and start stage from the request are handed
    to the ``retry_candidates`` task as keyword arguments. When
    ``schedule_seconds`` is set to a non-zero value it is used as the task
    countdown; otherwise the task is queued without a delay.

    Returns:
        A ``RetryResponse`` with status ``"queued"``, the id of the queued
        task, and the request's job id.
    """
    from detect.checkpoint.tasks import retry_candidates

    task_kwargs = dict(
        job_id=req.job_id,
        config_overrides=req.config_overrides,
        start_stage=req.start_stage,
    )

    countdown = req.schedule_seconds
    if not countdown:
        # No delay requested (None or 0): queue right away.
        queued = retry_candidates.delay(**task_kwargs)
    else:
        queued = retry_candidates.apply_async(kwargs=task_kwargs, countdown=countdown)

    return RetryResponse(status="queued", task_id=queued.id, job_id=req.job_id)
|