phase cv 0
This commit is contained in:
@@ -64,25 +64,42 @@ def list_stage_configs():
|
||||
|
||||
result = []
|
||||
for stage in list_stages():
|
||||
info = StageConfigInfo(
|
||||
name=stage.name,
|
||||
label=stage.label,
|
||||
description=stage.description,
|
||||
category=stage.category,
|
||||
config_fields=[
|
||||
{
|
||||
"name": f.name,
|
||||
"type": f.type,
|
||||
"default": f.default,
|
||||
"description": f.description,
|
||||
"min": f.min,
|
||||
"max": f.max,
|
||||
"options": f.options,
|
||||
}
|
||||
for f in stage.config_fields
|
||||
],
|
||||
reads=stage.io.reads,
|
||||
writes=stage.io.writes,
|
||||
)
|
||||
info = _stage_to_info(stage)
|
||||
result.append(info)
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/config/stages/{stage_name}", response_model=StageConfigInfo)
|
||||
def get_stage_config(stage_name: str):
|
||||
"""Return config field metadata for a single stage."""
|
||||
from detect.stages import get_stage
|
||||
|
||||
try:
|
||||
stage = get_stage(stage_name)
|
||||
except KeyError:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=404, detail=f"Unknown stage: {stage_name}")
|
||||
return _stage_to_info(stage)
|
||||
|
||||
|
||||
def _stage_to_info(stage) -> StageConfigInfo:
|
||||
return StageConfigInfo(
|
||||
name=stage.name,
|
||||
label=stage.label,
|
||||
description=stage.description,
|
||||
category=stage.category,
|
||||
config_fields=[
|
||||
{
|
||||
"name": f.name,
|
||||
"type": f.type,
|
||||
"default": f.default,
|
||||
"description": f.description,
|
||||
"min": f.min,
|
||||
"max": f.max,
|
||||
"options": f.options,
|
||||
}
|
||||
for f in stage.config_fields
|
||||
],
|
||||
reads=stage.io.reads,
|
||||
writes=stage.io.writes,
|
||||
)
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
"""
|
||||
API endpoints for checkpoint inspection, replay, and retry.
|
||||
API endpoints for checkpoint inspection, replay, retry, and GPU proxy.
|
||||
|
||||
GET /detect/checkpoints/{job_id} — list available checkpoints
|
||||
POST /detect/replay — replay from a stage with config overrides
|
||||
POST /detect/retry — queue async retry with different provider
|
||||
POST /detect/replay-stage — replay single stage (fast path)
|
||||
POST /detect/gpu/detect_edges — proxy to GPU inference server
|
||||
POST /detect/gpu/detect_edges/debug — proxy with debug overlays
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, HTTPException, Request, Response
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -22,6 +26,18 @@ router = APIRouter(prefix="/detect", tags=["detect"])
|
||||
|
||||
class CheckpointInfo(BaseModel):
|
||||
stage: str
|
||||
is_scenario: bool = False
|
||||
scenario_label: str = ""
|
||||
|
||||
|
||||
class ScenarioInfo(BaseModel):
|
||||
job_id: str
|
||||
stage: str
|
||||
scenario_label: str
|
||||
profile_name: str
|
||||
video_path: str
|
||||
frame_count: int = 0
|
||||
created_at: str = ""
|
||||
|
||||
|
||||
class ReplayRequest(BaseModel):
|
||||
@@ -51,6 +67,39 @@ class RetryResponse(BaseModel):
|
||||
job_id: str
|
||||
|
||||
|
||||
class ReplaySingleStageRequest(BaseModel):
|
||||
job_id: str
|
||||
stage: str
|
||||
frame_refs: list[int] | None = None
|
||||
config_overrides: dict | None = None
|
||||
debug: bool = False
|
||||
|
||||
|
||||
class ReplaySingleStageBox(BaseModel):
|
||||
x: int
|
||||
y: int
|
||||
w: int
|
||||
h: int
|
||||
confidence: float
|
||||
label: str
|
||||
|
||||
|
||||
class FrameDebugOverlays(BaseModel):
|
||||
edge_overlay_b64: str = ""
|
||||
lines_overlay_b64: str = ""
|
||||
horizontal_count: int = 0
|
||||
pair_count: int = 0
|
||||
|
||||
|
||||
class ReplaySingleStageResponse(BaseModel):
|
||||
status: str
|
||||
stage: str
|
||||
frame_count: int = 0
|
||||
region_count: int = 0
|
||||
regions_by_frame: dict[str, list[ReplaySingleStageBox]] = {}
|
||||
debug: dict[str, FrameDebugOverlays] = {} # keyed by frame seq
|
||||
|
||||
|
||||
# --- Endpoints ---
|
||||
|
||||
@router.get("/checkpoints/{job_id}")
|
||||
@@ -67,6 +116,28 @@ def list_checkpoints(job_id: str) -> list[CheckpointInfo]:
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/scenarios", response_model=list[ScenarioInfo])
|
||||
def list_scenarios_endpoint():
|
||||
"""List all available scenarios (bookmarked checkpoints)."""
|
||||
from core.db.detect import list_scenarios
|
||||
|
||||
scenarios = list_scenarios()
|
||||
result = []
|
||||
for s in scenarios:
|
||||
manifest = s.frames_manifest or {}
|
||||
info = ScenarioInfo(
|
||||
job_id=str(s.job_id),
|
||||
stage=s.stage,
|
||||
scenario_label=s.scenario_label,
|
||||
profile_name=s.profile_name,
|
||||
video_path=s.video_path,
|
||||
frame_count=len(manifest),
|
||||
created_at=str(s.created_at) if s.created_at else "",
|
||||
)
|
||||
result.append(info)
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/replay", response_model=ReplayResponse)
|
||||
def replay(req: ReplayRequest):
|
||||
"""Replay pipeline from a specific stage with optional config overrides."""
|
||||
@@ -119,3 +190,103 @@ def retry(req: RetryRequest):
|
||||
job_id=req.job_id,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/replay-stage", response_model=ReplaySingleStageResponse)
|
||||
def replay_single_stage(req: ReplaySingleStageRequest):
|
||||
"""Replay a single stage on specific frames — fast path for interactive tuning."""
|
||||
from detect.checkpoint.replay import replay_single_stage as _replay
|
||||
|
||||
try:
|
||||
result = _replay(
|
||||
job_id=req.job_id,
|
||||
stage=req.stage,
|
||||
frame_refs=req.frame_refs,
|
||||
config_overrides=req.config_overrides,
|
||||
debug=req.debug,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Single-stage replay failed: {e}")
|
||||
|
||||
# Convert result to response format
|
||||
regions_by_frame = result.get("edge_regions_by_frame", {})
|
||||
total_regions = 0
|
||||
serialized = {}
|
||||
for seq, boxes in regions_by_frame.items():
|
||||
box_list = []
|
||||
for b in boxes:
|
||||
box = ReplaySingleStageBox(
|
||||
x=b.x, y=b.y, w=b.w, h=b.h,
|
||||
confidence=b.confidence, label=b.label,
|
||||
)
|
||||
box_list.append(box)
|
||||
serialized[str(seq)] = box_list
|
||||
total_regions += len(box_list)
|
||||
|
||||
# Serialize debug overlays if present
|
||||
debug_out = {}
|
||||
raw_debug = result.get("debug", {})
|
||||
for seq, d in raw_debug.items():
|
||||
debug_out[str(seq)] = FrameDebugOverlays(
|
||||
edge_overlay_b64=d.get("edge_overlay_b64", ""),
|
||||
lines_overlay_b64=d.get("lines_overlay_b64", ""),
|
||||
horizontal_count=d.get("horizontal_count", 0),
|
||||
pair_count=d.get("pair_count", 0),
|
||||
)
|
||||
|
||||
return ReplaySingleStageResponse(
|
||||
status="completed",
|
||||
stage=req.stage,
|
||||
frame_count=len(regions_by_frame),
|
||||
region_count=total_regions,
|
||||
regions_by_frame=serialized,
|
||||
debug=debug_out,
|
||||
)
|
||||
|
||||
|
||||
# --- GPU proxy — thin passthrough to inference server for interactive editor ---
|
||||
|
||||
|
||||
def _gpu_url() -> str:
|
||||
url = os.environ.get("INFERENCE_URL", "http://localhost:8000")
|
||||
return url.rstrip("/")
|
||||
|
||||
|
||||
@router.post("/gpu/detect_edges")
|
||||
async def gpu_detect_edges(request: Request):
|
||||
"""Proxy to GPU inference server — browser can't reach it directly."""
|
||||
import httpx
|
||||
|
||||
body = await request.body()
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.post(
|
||||
f"{_gpu_url()}/detect_edges",
|
||||
content=body,
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
return Response(content=resp.content, status_code=resp.status_code,
|
||||
media_type="application/json")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}")
|
||||
|
||||
|
||||
@router.post("/gpu/detect_edges/debug")
|
||||
async def gpu_detect_edges_debug(request: Request):
|
||||
"""Proxy to GPU inference server debug endpoint."""
|
||||
import httpx
|
||||
|
||||
body = await request.body()
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.post(
|
||||
f"{_gpu_url()}/detect_edges/debug",
|
||||
content=body,
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
return Response(content=resp.content, status_code=resp.status_code,
|
||||
media_type="application/json")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}")
|
||||
|
||||
@@ -102,6 +102,17 @@ def list_stage_checkpoints(job_id: UUID) -> list[str]:
|
||||
return list(session.exec(stmt).all())
|
||||
|
||||
|
||||
def list_scenarios() -> list[StageCheckpoint]:
|
||||
"""List all checkpoints marked as scenarios."""
|
||||
with get_session() as session:
|
||||
stmt = (
|
||||
select(StageCheckpoint)
|
||||
.where(StageCheckpoint.is_scenario == True)
|
||||
.order_by(StageCheckpoint.created_at.desc())
|
||||
)
|
||||
return list(session.exec(stmt).all())
|
||||
|
||||
|
||||
def delete_stage_checkpoints(job_id: UUID) -> None:
|
||||
with get_session() as session:
|
||||
stmt = select(StageCheckpoint).where(StageCheckpoint.job_id == job_id)
|
||||
|
||||
@@ -193,15 +193,14 @@ class StageCheckpoint(SQLModel, table=True):
|
||||
frames_manifest: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||
frames_meta: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
|
||||
filtered_frame_sequences: List[int] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
|
||||
boxes_by_frame: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||
text_candidates: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
|
||||
unresolved_candidates: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
|
||||
detections: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
|
||||
stage_output_key: str = "" # s3 key: checkpoints/{job_id}/stages/{stage}.bson
|
||||
stats: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||
config_snapshot: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||
video_path: str = ""
|
||||
profile_name: str = ""
|
||||
is_scenario: bool = False
|
||||
scenario_label: str = ""
|
||||
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
||||
|
||||
class KnownBrand(SQLModel, table=True):
|
||||
|
||||
@@ -40,6 +40,11 @@
|
||||
"target": "typescript",
|
||||
"output": "ui/detection-app/src/types/store-state.ts",
|
||||
"include": ["ui_state_views"]
|
||||
},
|
||||
{
|
||||
"target": "pydantic",
|
||||
"output": "gpu/models/inference_contract.py",
|
||||
"include": ["inference_views"]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ from .detect_jobs import (
|
||||
from .media import AssetStatus, MediaAsset
|
||||
from .presets import BUILTIN_PRESETS, TranscodePreset
|
||||
from .detect import DETECT_VIEWS # noqa: F401 — discovered by modelgen generic loader
|
||||
from .inference import INFERENCE_VIEWS # noqa: F401 — GPU inference server API types
|
||||
from .ui_state import UI_STATE_VIEWS # noqa: F401 — UI store state types
|
||||
from .views import ChunkEvent, ChunkOutputFile, PipelineStats, WorkerEvent
|
||||
from .sources import ChunkInfo, SourceJob, SourceType
|
||||
|
||||
@@ -53,6 +53,7 @@ class BoundingBoxEvent:
|
||||
label: str
|
||||
resolved_brand: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
stage: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -85,6 +86,7 @@ class StatsUpdate:
|
||||
|
||||
frames_extracted: int = 0
|
||||
frames_after_scene_filter: int = 0
|
||||
cv_regions_detected: int = 0
|
||||
regions_detected: int = 0
|
||||
regions_resolved_by_ocr: int = 0
|
||||
regions_escalated_to_local_vlm: int = 0
|
||||
@@ -166,6 +168,8 @@ class CheckpointInfo:
|
||||
"""Available checkpoint for a stage."""
|
||||
|
||||
stage: str
|
||||
is_scenario: bool = False
|
||||
scenario_label: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -93,13 +93,12 @@ class StageCheckpoint:
|
||||
frames_meta: List[Dict[str, Any]] = field(default_factory=list) # sequence, chunk_id, timestamp, hash
|
||||
filtered_frame_sequences: List[int] = field(default_factory=list)
|
||||
|
||||
# Detection state (full structured data, not just summaries)
|
||||
boxes_by_frame: Dict[str, List[Dict[str, Any]]] = field(default_factory=dict)
|
||||
text_candidates: List[Dict[str, Any]] = field(default_factory=list)
|
||||
unresolved_candidates: List[Dict[str, Any]] = field(default_factory=list)
|
||||
detections: List[Dict[str, Any]] = field(default_factory=list)
|
||||
# Stage output — stored as blob in MinIO: checkpoints/{job_id}/stages/{stage}.bson
|
||||
# Each stage's serialize_fn/deserialize_fn owns the format.
|
||||
# Postgres only stores the S3 key, not the data itself.
|
||||
stage_output_key: str = "" # s3 key to the serialized stage output
|
||||
|
||||
# Pipeline state
|
||||
# Pipeline state (small, stays in Postgres)
|
||||
stats: Dict[str, Any] = field(default_factory=dict)
|
||||
config_snapshot: Dict[str, Any] = field(default_factory=dict)
|
||||
config_overrides: Dict[str, Any] = field(default_factory=dict)
|
||||
@@ -108,6 +107,13 @@ class StageCheckpoint:
|
||||
video_path: str = ""
|
||||
profile_name: str = ""
|
||||
|
||||
# Scenario — a checkpoint bookmarked for the editor workflow.
|
||||
# Created by seeders (manual scripts that populate state from real footage)
|
||||
# or captured from a running pipeline. Loaded via URL:
|
||||
# /detection/?job=<job_id>&stage=<stage>&editor=true
|
||||
is_scenario: bool = False
|
||||
scenario_label: str = "" # human-readable name, e.g. "chelsea_edges_lowcanny"
|
||||
|
||||
# Timestamps
|
||||
created_at: Optional[datetime] = None
|
||||
|
||||
|
||||
@@ -70,6 +70,7 @@ class BrandStats:
|
||||
class PipelineStats:
|
||||
frames_extracted: int = 0
|
||||
frames_after_scene_filter: int = 0
|
||||
cv_regions_detected: int = 0
|
||||
regions_detected: int = 0
|
||||
regions_resolved_by_ocr: int = 0
|
||||
regions_escalated_to_local_vlm: int = 0
|
||||
|
||||
197
core/schema/models/inference.py
Normal file
197
core/schema/models/inference.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
Inference Server API Schema Definitions
|
||||
|
||||
Source of truth for GPU inference server request/response types.
|
||||
Generates: Pydantic (gpu/models/inference_contract.py)
|
||||
|
||||
These are the wire-format types for the HTTP API between the
|
||||
pipeline (detect/) and the inference server (gpu/).
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
# --- Object Detection (YOLO) ---
|
||||
|
||||
|
||||
@dataclass
|
||||
class DetectRequest:
|
||||
"""Request body for object detection."""
|
||||
|
||||
image: str # base64 JPEG
|
||||
model: Optional[str] = None
|
||||
confidence: Optional[float] = None
|
||||
target_classes: Optional[List[str]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class BBox:
|
||||
"""A detected bounding box."""
|
||||
|
||||
x: int
|
||||
y: int
|
||||
w: int
|
||||
h: int
|
||||
confidence: float
|
||||
label: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class DetectResponse:
|
||||
"""Response from object detection."""
|
||||
|
||||
detections: List[BBox] = field(default_factory=list)
|
||||
|
||||
|
||||
# --- OCR ---
|
||||
|
||||
|
||||
@dataclass
|
||||
class OCRRequest:
|
||||
"""Request body for OCR."""
|
||||
|
||||
image: str # base64 JPEG
|
||||
languages: Optional[List[str]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class OCRTextResult:
|
||||
"""A single OCR text extraction result."""
|
||||
|
||||
text: str
|
||||
confidence: float
|
||||
bbox: List[int] = field(default_factory=list) # [x, y, w, h]
|
||||
|
||||
|
||||
@dataclass
|
||||
class OCRResponse:
|
||||
"""Response from OCR."""
|
||||
|
||||
results: List[OCRTextResult] = field(default_factory=list)
|
||||
|
||||
|
||||
# --- Preprocessing ---
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreprocessRequest:
|
||||
"""Request body for image preprocessing."""
|
||||
|
||||
image: str # base64 JPEG
|
||||
binarize: bool = False
|
||||
deskew: bool = False
|
||||
contrast: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreprocessResponse:
|
||||
"""Response from preprocessing."""
|
||||
|
||||
image: str # base64 JPEG of processed image
|
||||
|
||||
|
||||
# --- VLM ---
|
||||
|
||||
|
||||
@dataclass
|
||||
class VLMRequest:
|
||||
"""Request body for visual language model query."""
|
||||
|
||||
image: str # base64 JPEG
|
||||
prompt: str
|
||||
model: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class VLMResponse:
|
||||
"""Response from VLM."""
|
||||
|
||||
brand: str
|
||||
confidence: float
|
||||
reasoning: str
|
||||
|
||||
|
||||
# --- CV Region Analysis ---
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnalyzeRegionsRequest:
|
||||
"""Request body for CV region analysis."""
|
||||
|
||||
image: str # base64 JPEG
|
||||
# Edge detection (Canny + HoughLinesP)
|
||||
edge_canny_low: int = 50
|
||||
edge_canny_high: int = 150
|
||||
edge_hough_threshold: int = 80
|
||||
edge_hough_min_length: int = 100
|
||||
edge_hough_max_gap: int = 10
|
||||
edge_pair_max_distance: int = 200
|
||||
edge_pair_min_distance: int = 15
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegionBox:
|
||||
"""A candidate region from CV analysis."""
|
||||
|
||||
x: int
|
||||
y: int
|
||||
w: int
|
||||
h: int
|
||||
confidence: float
|
||||
label: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnalyzeRegionsResponse:
|
||||
"""Response from CV region analysis."""
|
||||
|
||||
regions: List[RegionBox] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnalyzeRegionsDebugResponse:
|
||||
"""Response from CV region analysis with debug overlays."""
|
||||
|
||||
regions: List[RegionBox] = field(default_factory=list)
|
||||
edge_overlay_b64: str = "" # Canny edge image as base64 JPEG
|
||||
lines_overlay_b64: str = "" # frame with Hough lines drawn
|
||||
horizontal_count: int = 0
|
||||
pair_count: int = 0
|
||||
|
||||
|
||||
# --- Server Config ---
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfigUpdate:
|
||||
"""Request body for updating server configuration."""
|
||||
|
||||
device: Optional[str] = None
|
||||
yolo_model: Optional[str] = None
|
||||
yolo_confidence: Optional[float] = None
|
||||
vram_budget_mb: Optional[int] = None
|
||||
strategy: Optional[str] = None
|
||||
ocr_languages: Optional[List[str]] = None
|
||||
ocr_min_confidence: Optional[float] = None
|
||||
|
||||
|
||||
# --- Export list for modelgen ---
|
||||
|
||||
INFERENCE_VIEWS = [
|
||||
DetectRequest,
|
||||
BBox,
|
||||
DetectResponse,
|
||||
OCRRequest,
|
||||
OCRTextResult,
|
||||
OCRResponse,
|
||||
PreprocessRequest,
|
||||
PreprocessResponse,
|
||||
VLMRequest,
|
||||
VLMResponse,
|
||||
AnalyzeRegionsRequest,
|
||||
RegionBox,
|
||||
AnalyzeRegionsResponse,
|
||||
AnalyzeRegionsDebugResponse,
|
||||
ConfigUpdate,
|
||||
]
|
||||
Reference in New Issue
Block a user