phase 10
This commit is contained in:
121
core/api/detect_replay.py
Normal file
121
core/api/detect_replay.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
API endpoints for checkpoint inspection, replay, and retry.
|
||||
|
||||
GET /detect/checkpoints/{job_id} — list available checkpoints
|
||||
POST /detect/replay — replay from a stage with config overrides
|
||||
POST /detect/retry — queue async retry with different provider
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Every route registered on this router is served under the /detect prefix
# and grouped under the "detect" tag.
router = APIRouter(
    prefix="/detect",
    tags=["detect"],
)
|
||||
|
||||
|
||||
# --- Request/Response models ---
|
||||
|
||||
class CheckpointInfo(BaseModel):
    # One entry of the checkpoint listing returned by GET /detect/checkpoints.
    # (Documented with comments rather than a docstring so the generated
    # schema description stays unchanged.)

    # Name of a stage for which a checkpoint is available.
    stage: str
|
||||
|
||||
|
||||
class ReplayRequest(BaseModel):
    # Request body for POST /detect/replay.

    # Job whose pipeline run is replayed.
    job_id: str

    # Stage the replay starts from.
    start_stage: str

    # Optional config overrides forwarded to the replay; None when omitted.
    config_overrides: dict | None = None
|
||||
|
||||
|
||||
class ReplayResponse(BaseModel):
    # Response body for POST /detect/replay.

    # Outcome label; the replay endpoint sets this to "completed".
    status: str

    # Echo of the job id and start stage from the request.
    job_id: str
    start_stage: str

    # Number of detections in the replay result (0 if none were returned).
    detections: int = 0

    # Number of brands in the replay report (0 if there is no report).
    brands_found: int = 0
|
||||
|
||||
|
||||
class RetryRequest(BaseModel):
    # Request body for POST /detect/retry.

    # Job whose unresolved candidates are retried.
    job_id: str

    # Optional config overrides forwarded to the retry task; None when omitted.
    config_overrides: dict | None = None

    # Stage the retry starts from; defaults to the "escalate_vlm" stage.
    start_stage: str = "escalate_vlm"

    # Delay before execution (off-peak); the retry endpoint passes it to the
    # task queue as the countdown. None (or 0) means queue immediately.
    schedule_seconds: float | None = None
|
||||
|
||||
|
||||
class RetryResponse(BaseModel):
    # Response body for POST /detect/retry.

    # Outcome label; the retry endpoint sets this to "queued".
    status: str

    # Id of the queued background task, as reported by the task object.
    task_id: str

    # Echo of the job id from the request.
    job_id: str
|
||||
|
||||
|
||||
# --- Endpoints ---
|
||||
|
||||
@router.get("/checkpoints/{job_id}")
def list_checkpoints(job_id: str) -> list[CheckpointInfo]:
    """List available checkpoint stages for a job."""
    # NOTE: the docstring above is kept to one line on purpose; FastAPI
    # publishes it as the operation description in the OpenAPI schema.
    #
    # job_id comes from the URL path. The result holds one CheckpointInfo per
    # stage returned by detect.checkpoint.list_checkpoints, in that order.
    # Any failure of the lookup is reported to the client as HTTP 404.

    # Function-scope import: the detection package is only loaded when the
    # endpoint is actually called.
    from detect.checkpoint import list_checkpoints as _list

    try:
        stages = _list(job_id)
    except Exception as e:
        # Chain the cause explicitly so the original traceback stays attached
        # to the HTTP error instead of showing up as an implicit context.
        raise HTTPException(
            status_code=404,
            detail=f"No checkpoints for job {job_id}: {e}",
        ) from e

    return [CheckpointInfo(stage=s) for s in stages]
|
||||
|
||||
|
||||
@router.post("/replay", response_model=ReplayResponse)
def replay(req: ReplayRequest):
    """Replay pipeline from a specific stage with optional config overrides."""
    # NOTE: the docstring above is kept to one line on purpose; FastAPI
    # publishes it as the operation description in the OpenAPI schema.
    #
    # Runs detect.checkpoint.replay_from synchronously and summarises its
    # result as counts. Error mapping:
    #   ValueError      -> HTTP 400, detail is the exception message
    #   other Exception -> HTTP 500, detail is "Replay failed: <message>"

    # Function-scope import: the detection package is only loaded when the
    # endpoint is actually called.
    from detect.checkpoint import replay_from

    try:
        result = replay_from(
            job_id=req.job_id,
            start_stage=req.start_stage,
            config_overrides=req.config_overrides,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # Unexpected failure: record the full traceback server-side, since the
        # client only receives the exception message.
        logger.exception(
            "Replay failed for job %s from stage %s",
            req.job_id,
            req.start_stage,
        )
        raise HTTPException(status_code=500, detail=f"Replay failed: {e}") from e

    detections = result.get("detections", [])
    report = result.get("report")
    # A missing (or falsy) report counts as zero brands.
    brands_found = len(report.brands) if report else 0

    return ReplayResponse(
        status="completed",
        job_id=req.job_id,
        start_stage=req.start_stage,
        detections=len(detections),
        brands_found=brands_found,
    )
|
||||
|
||||
|
||||
@router.post("/retry", response_model=RetryResponse)
def retry(req: RetryRequest):
    """Queue an async retry of unresolved candidates with different config."""
    # Function-scope import: the task module is only loaded when the endpoint
    # is actually called.
    from detect.checkpoint.tasks import retry_candidates

    # Keyword arguments handed to the background task.
    task_kwargs = {
        "job_id": req.job_id,
        "config_overrides": req.config_overrides,
        "start_stage": req.start_stage,
    }

    delay_seconds = req.schedule_seconds
    if not delay_seconds:
        # No delay requested (None or 0): queue for immediate execution.
        queued = retry_candidates.delay(**task_kwargs)
    else:
        # Defer execution by the requested number of seconds.
        queued = retry_candidates.apply_async(
            kwargs=task_kwargs,
            countdown=delay_seconds,
        )

    return RetryResponse(
        status="queued",
        task_id=queued.id,
        job_id=req.job_id,
    )
|
||||
@@ -25,6 +25,7 @@ from strawberry.fastapi import GraphQLRouter
|
||||
|
||||
from core.api.chunker_sse import router as chunker_router
|
||||
from core.api.detect_sse import router as detect_router
|
||||
from core.api.detect_replay import router as detect_replay_router
|
||||
from core.api.graphql import schema as graphql_schema
|
||||
|
||||
CALLBACK_API_KEY = os.environ.get("CALLBACK_API_KEY", "")
|
||||
@@ -56,6 +57,9 @@ app.include_router(chunker_router)
|
||||
# Detection SSE
|
||||
app.include_router(detect_router)
|
||||
|
||||
# Detection replay/retry
|
||||
app.include_router(detect_replay_router)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
|
||||
Reference in New Issue
Block a user