phase 10
This commit is contained in:
121
core/api/detect_replay.py
Normal file
121
core/api/detect_replay.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
API endpoints for checkpoint inspection, replay, and retry.
|
||||
|
||||
GET /detect/checkpoints/{job_id} — list available checkpoints
|
||||
POST /detect/replay — replay from a stage with config overrides
|
||||
POST /detect/retry — queue async retry with different provider
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/detect", tags=["detect"])
|
||||
|
||||
|
||||
# --- Request/Response models ---
|
||||
|
||||
class CheckpointInfo(BaseModel):
    """One available checkpoint stage for a detect job."""

    # Stage name as stored by the checkpoint layer.
    stage: str
|
||||
|
||||
|
||||
class ReplayRequest(BaseModel):
    """Request to replay the pipeline from a saved checkpoint stage."""

    job_id: str
    start_stage: str
    # Optional config values applied on top of the checkpointed config.
    config_overrides: dict | None = None
|
||||
|
||||
|
||||
class ReplayResponse(BaseModel):
    """Summary of a completed (synchronous) replay run."""

    status: str
    job_id: str
    start_stage: str
    detections: int = 0
    brands_found: int = 0
|
||||
|
||||
|
||||
class RetryRequest(BaseModel):
    """Request to queue an async retry with different configuration."""

    job_id: str
    config_overrides: dict | None = None
    start_stage: str = "escalate_vlm"
    schedule_seconds: float | None = None  # delay before execution (off-peak)
|
||||
|
||||
|
||||
class RetryResponse(BaseModel):
    """Acknowledgement that a retry task was queued."""

    status: str
    task_id: str  # Celery task id for tracking
    job_id: str
|
||||
|
||||
|
||||
# --- Endpoints ---
|
||||
|
||||
@router.get("/checkpoints/{job_id}")
def list_checkpoints(job_id: str) -> list[CheckpointInfo]:
    """List available checkpoint stages for a job.

    Raises a 404 when lookup fails for the given job id.
    """
    # Imported lazily so the API module loads without the detect package.
    from detect.checkpoint import list_checkpoints as _list

    try:
        stages = _list(job_id)
    except Exception as e:
        # Chain the cause (`from e`) so the original traceback is preserved
        # in server logs instead of being discarded (flake8 B904).
        raise HTTPException(
            status_code=404, detail=f"No checkpoints for job {job_id}: {e}"
        ) from e

    return [CheckpointInfo(stage=s) for s in stages]
|
||||
|
||||
|
||||
@router.post("/replay", response_model=ReplayResponse)
def replay(req: ReplayRequest):
    """Replay pipeline from a specific stage with optional config overrides.

    Runs synchronously and returns summary counts for the replayed run.
    Returns 400 for invalid arguments (ValueError from the replay layer)
    and 500 for any other failure.
    """
    # Lazy import: keeps the detect package off the app-startup path.
    from detect.checkpoint import replay_from

    try:
        result = replay_from(
            job_id=req.job_id,
            start_stage=req.start_stage,
            config_overrides=req.config_overrides,
        )
    except ValueError as e:
        # Chain the cause so the original traceback survives (flake8 B904).
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Replay failed: {e}") from e

    detections = result.get("detections", [])
    report = result.get("report")
    # Report may be absent when the replayed stages stop before reporting.
    brands_found = len(report.brands) if report else 0

    return ReplayResponse(
        status="completed",
        job_id=req.job_id,
        start_stage=req.start_stage,
        detections=len(detections),
        brands_found=brands_found,
    )
|
||||
|
||||
|
||||
@router.post("/retry", response_model=RetryResponse)
def retry(req: RetryRequest):
    """Queue an async retry of unresolved candidates with different config.

    When ``schedule_seconds`` is provided the task is delayed by that many
    seconds (e.g. to run off-peak); otherwise it is queued immediately.
    """
    # Lazy import: avoid pulling Celery task modules at app startup.
    from detect.checkpoint.tasks import retry_candidates

    kwargs = {
        "job_id": req.job_id,
        "config_overrides": req.config_overrides,
        "start_stage": req.start_stage,
    }

    # Explicit None check: 0.0 is falsy but is still a valid (immediate)
    # countdown, so a bare `if req.schedule_seconds:` would silently drop it.
    if req.schedule_seconds is not None:
        task = retry_candidates.apply_async(kwargs=kwargs, countdown=req.schedule_seconds)
    else:
        task = retry_candidates.delay(**kwargs)

    return RetryResponse(
        status="queued",
        task_id=task.id,
        job_id=req.job_id,
    )
|
||||
@@ -25,6 +25,7 @@ from strawberry.fastapi import GraphQLRouter
|
||||
|
||||
from core.api.chunker_sse import router as chunker_router
|
||||
from core.api.detect_sse import router as detect_router
|
||||
from core.api.detect_replay import router as detect_replay_router
|
||||
from core.api.graphql import schema as graphql_schema
|
||||
|
||||
CALLBACK_API_KEY = os.environ.get("CALLBACK_API_KEY", "")
|
||||
@@ -56,6 +57,9 @@ app.include_router(chunker_router)
|
||||
# Detection SSE
|
||||
app.include_router(detect_router)
|
||||
|
||||
# Detection replay/retry
|
||||
app.include_router(detect_replay_router)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
|
||||
175
core/db/detect.py
Normal file
175
core/db/detect.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""Database operations for DetectJob and StageCheckpoint."""
|
||||
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DetectJob
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def create_detect_job(**fields):
    """Create and persist a new DetectJob row from the given field values."""
    # Imported lazily so this module can load without Django configured.
    from admin.mpr.media_assets.models import DetectJob
    return DetectJob.objects.create(**fields)
|
||||
|
||||
|
||||
def get_detect_job(id: UUID):
    """Fetch a DetectJob by primary key (raises DetectJob.DoesNotExist if absent)."""
    from admin.mpr.media_assets.models import DetectJob
    return DetectJob.objects.get(id=id)
|
||||
|
||||
|
||||
def update_detect_job(job_id: UUID, **fields):
    """Bulk-update fields on a DetectJob by id.

    Uses QuerySet.update(): a no-op when the id does not exist, and it does
    not call Model.save() or fire model signals.
    """
    from admin.mpr.media_assets.models import DetectJob
    DetectJob.objects.filter(id=job_id).update(**fields)
|
||||
|
||||
|
||||
def list_detect_jobs(
    parent_job_id: Optional[UUID] = None,
    status: Optional[str] = None,
):
    """List DetectJobs, optionally filtered by parent job and/or status.

    Filters apply only when the argument is not None. The explicit
    `is not None` checks (rather than truthiness) keep falsy-but-meaningful
    values from being silently ignored.
    """
    from admin.mpr.media_assets.models import DetectJob

    qs = DetectJob.objects.all()
    if parent_job_id is not None:
        qs = qs.filter(parent_job_id=parent_job_id)
    if status is not None:
        qs = qs.filter(status=status)
    return list(qs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# StageCheckpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def save_stage_checkpoint(**fields):
    """Create and persist a StageCheckpoint row from the given field values."""
    from admin.mpr.media_assets.models import StageCheckpoint
    return StageCheckpoint.objects.create(**fields)
|
||||
|
||||
|
||||
def get_stage_checkpoint(job_id: UUID, stage: str):
    """Fetch the checkpoint for (job_id, stage); raises DoesNotExist if absent."""
    from admin.mpr.media_assets.models import StageCheckpoint
    return StageCheckpoint.objects.get(job_id=job_id, stage=stage)
|
||||
|
||||
|
||||
def list_stage_checkpoints(job_id: UUID) -> list[str]:
    """Return the checkpoint stage names for a job, ordered by stage index."""
    from admin.mpr.media_assets.models import StageCheckpoint

    job_checkpoints = StageCheckpoint.objects.filter(job_id=job_id)
    stage_names = job_checkpoints.order_by("stage_index").values_list("stage", flat=True)
    return list(stage_names)
|
||||
|
||||
|
||||
def delete_stage_checkpoints(job_id: UUID):
    """Delete all stage checkpoints belonging to a job."""
    from admin.mpr.media_assets.models import StageCheckpoint
    StageCheckpoint.objects.filter(job_id=job_id).delete()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# KnownBrand
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def get_or_create_brand(canonical_name: str, aliases: list[str] | None = None,
                        source: str = "ocr") -> tuple:
    """Get existing brand or create new one. Returns (brand, created)."""
    from admin.mpr.media_assets.models import KnownBrand
    import uuid

    # Only whitespace is stripped; case-insensitivity comes from __iexact below.
    normalized = canonical_name.strip()
    brand = KnownBrand.objects.filter(canonical_name__iexact=normalized).first()
    if brand:
        return brand, False

    # Check aliases of existing brands
    # NOTE(review): full-table scan with per-row Python matching; fine for a
    # small brand table, but consider a jsonb containment query at scale.
    for existing in KnownBrand.objects.all():
        existing_aliases = [a.lower() for a in (existing.aliases or [])]
        if normalized.lower() in existing_aliases:
            return existing, False

    # NOTE(review): check-then-create is not atomic — concurrent callers could
    # create duplicate brands; a unique constraint/transaction would close this.
    brand = KnownBrand.objects.create(
        id=uuid.uuid4(),
        canonical_name=normalized,
        aliases=aliases or [],
        first_source=source,
    )
    return brand, True
|
||||
|
||||
|
||||
def find_brand_by_text(text: str) -> Optional[object]:
    """Find a known brand by canonical name or alias (case-insensitive)."""
    from admin.mpr.media_assets.models import KnownBrand

    needle = text.strip().lower()

    # Exact canonical match
    canonical_hit = KnownBrand.objects.filter(canonical_name__iexact=needle).first()
    if canonical_hit:
        return canonical_hit

    # Search aliases (jsonb contains)
    return next(
        (
            candidate
            for candidate in KnownBrand.objects.all()
            if needle in [a.lower() for a in (candidate.aliases or [])]
        ),
        None,
    )
|
||||
|
||||
|
||||
def list_all_brands() -> list:
    """Return every KnownBrand, ordered by canonical name."""
    from admin.mpr.media_assets.models import KnownBrand
    return list(KnownBrand.objects.all().order_by("canonical_name"))
|
||||
|
||||
|
||||
def update_brand(brand_id: UUID, **fields):
    """Bulk-update fields on a KnownBrand by id (QuerySet.update(); no signals)."""
    from admin.mpr.media_assets.models import KnownBrand
    KnownBrand.objects.filter(id=brand_id).update(**fields)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SourceBrandSighting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def get_source_sightings(source_asset_id: UUID) -> list:
    """Get all brand sightings for a specific source video."""
    from admin.mpr.media_assets.models import SourceBrandSighting

    sightings = SourceBrandSighting.objects.filter(source_asset_id=source_asset_id)
    # Most frequently seen brands first.
    return list(sightings.order_by("-occurrences"))
|
||||
|
||||
|
||||
def record_sighting(source_asset_id: UUID, brand_id: UUID, brand_name: str,
                    timestamp: float, confidence: float, source: str = "ocr"):
    """Record or update a brand sighting for a source.

    If a sighting for (source_asset_id, brand_id) already exists, bump its
    occurrence count and fold the new confidence into a running mean;
    otherwise create a fresh row. Returns the saved sighting.
    """
    from admin.mpr.media_assets.models import SourceBrandSighting
    import uuid

    sighting = SourceBrandSighting.objects.filter(
        source_asset_id=source_asset_id,
        brand_id=brand_id,
    ).first()

    if sighting:
        # Incremental mean: occurrences is bumped first, so the previous mean
        # is re-weighted by (occurrences - 1) — the OLD count — before the new
        # sample is folded in. Order of these statements matters.
        # NOTE(review): read-modify-write is not atomic across workers; an
        # F()-expression or select_for_update would close the race — confirm.
        sighting.occurrences += 1
        sighting.last_seen_timestamp = timestamp
        total_conf = sighting.avg_confidence * (sighting.occurrences - 1) + confidence
        sighting.avg_confidence = total_conf / sighting.occurrences
        sighting.save()
        return sighting

    # First sighting of this brand for this source.
    sighting = SourceBrandSighting.objects.create(
        id=uuid.uuid4(),
        source_asset_id=source_asset_id,
        brand_id=brand_id,
        brand_name=brand_name,
        first_seen_timestamp=timestamp,
        last_seen_timestamp=timestamp,
        occurrences=1,
        detection_source=source,
        avg_confidence=confidence,
    )
    return sighting
|
||||
@@ -26,13 +26,18 @@ from .grpc import (
|
||||
WorkerStatus,
|
||||
)
|
||||
from .jobs import ChunkJob, ChunkJobStatus, JobStatus, TranscodeJob
|
||||
from .detect_jobs import (
|
||||
DetectJob, DetectJobStatus, RunType, StageCheckpoint,
|
||||
BrandSource, KnownBrand, SourceBrandSighting,
|
||||
)
|
||||
from .media import AssetStatus, MediaAsset
|
||||
from .presets import BUILTIN_PRESETS, TranscodePreset
|
||||
from .detect import DETECT_VIEWS # noqa: F401 — discovered by modelgen generic loader
|
||||
from .views import ChunkEvent, ChunkOutputFile, PipelineStats, WorkerEvent
|
||||
|
||||
# Core domain models - generates Django, Pydantic, TypeScript
|
||||
DATACLASSES = [MediaAsset, TranscodePreset, TranscodeJob, ChunkJob]
|
||||
DATACLASSES = [MediaAsset, TranscodePreset, TranscodeJob, ChunkJob,
|
||||
DetectJob, StageCheckpoint, KnownBrand, SourceBrandSighting]
|
||||
|
||||
# API request/response models - generates TypeScript only (no Django)
|
||||
# WorkerStatus from grpc.py is reused here
|
||||
@@ -46,7 +51,7 @@ API_MODELS = [
|
||||
]
|
||||
|
||||
# Status enums - included in generated code
|
||||
ENUMS = [AssetStatus, JobStatus, ChunkJobStatus]
|
||||
ENUMS = [AssetStatus, JobStatus, ChunkJobStatus, DetectJobStatus, RunType, BrandSource]
|
||||
|
||||
# View/event models - generates TypeScript for UI consumption
|
||||
VIEWS = [ChunkEvent, WorkerEvent, PipelineStats, ChunkOutputFile]
|
||||
|
||||
@@ -149,6 +149,64 @@ class JobComplete:
|
||||
report: Optional[DetectionReportSummary] = None
|
||||
|
||||
|
||||
@dataclass
class RunContext:
    """Run context injected into all SSE events for grouping."""

    run_id: str  # identifier of this specific run
    parent_job_id: str  # groups all runs for the same source job
    run_type: str = "initial"  # initial | replay | retry
|
||||
|
||||
|
||||
# --- Checkpoint API types ---
|
||||
|
||||
|
||||
@dataclass
class CheckpointInfo:
    """Available checkpoint for a stage."""

    stage: str  # pipeline stage name
|
||||
|
||||
|
||||
@dataclass
class ReplayRequest:
    """Request to replay pipeline from a specific stage."""

    job_id: str
    start_stage: str
    # Optional overrides applied on top of the checkpointed config.
    config_overrides: Optional[dict] = None
|
||||
|
||||
|
||||
@dataclass
class ReplayResponse:
    """Result of a replay invocation."""

    status: str
    job_id: str
    start_stage: str
    detections: int = 0
    brands_found: int = 0
|
||||
|
||||
|
||||
@dataclass
class RetryRequest:
    """Request to queue async retry with different config."""

    job_id: str
    config_overrides: Optional[dict] = None
    start_stage: str = "escalate_vlm"
    schedule_seconds: Optional[float] = None  # delay before execution (off-peak)
|
||||
|
||||
|
||||
@dataclass
class RetryResponse:
    """Result of queueing a retry task."""

    status: str
    task_id: str  # Celery task id for tracking
    job_id: str
|
||||
|
||||
|
||||
# --- Export lists for modelgen ---
|
||||
|
||||
DETECT_VIEWS = [
|
||||
@@ -163,4 +221,10 @@ DETECT_VIEWS = [
|
||||
LogEvent,
|
||||
DetectionReportSummary,
|
||||
JobComplete,
|
||||
RunContext,
|
||||
CheckpointInfo,
|
||||
ReplayRequest,
|
||||
ReplayResponse,
|
||||
RetryRequest,
|
||||
RetryResponse,
|
||||
]
|
||||
|
||||
162
core/schema/models/detect_jobs.py
Normal file
162
core/schema/models/detect_jobs.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
Detection Job and Checkpoint Schema Definitions
|
||||
|
||||
Source of truth for detection pipeline job tracking and stage checkpoints.
|
||||
Follows the TranscodeJob/ChunkJob pattern.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
class DetectJobStatus(str, Enum):
    """Lifecycle states of a DetectJob."""

    PENDING = "pending"
    RUNNING = "running"
    PAUSED = "paused"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class RunType(str, Enum):
    """Kind of pipeline invocation: initial run, replay, or retry."""

    INITIAL = "initial"
    REPLAY = "replay"
    RETRY = "retry"
|
||||
|
||||
|
||||
@dataclass
class DetectJob:
    """
    A detection pipeline job.

    Each invocation of the pipeline (initial run, replay, retry) creates a DetectJob.
    Jobs for the same source video are linked via parent_job_id.
    """

    id: UUID

    # Input
    source_asset_id: UUID
    video_path: str
    profile_name: str = "soccer_broadcast"

    # Run lineage
    parent_job_id: Optional[UUID] = None  # links all runs for the same source
    run_type: RunType = RunType.INITIAL
    replay_from_stage: Optional[str] = None  # null for initial runs
    config_overrides: Dict[str, Any] = field(default_factory=dict)

    # Status
    status: DetectJobStatus = DetectJobStatus.PENDING
    current_stage: Optional[str] = None
    progress: float = 0.0  # NOTE(review): presumably a 0–1 fraction — confirm with producers
    error_message: Optional[str] = None

    # Results summary
    total_detections: int = 0
    brands_found: int = 0
    cloud_llm_calls: int = 0
    estimated_cost_usd: float = 0.0

    # Worker tracking
    celery_task_id: Optional[str] = None
    priority: int = 0

    # Timestamps
    created_at: Optional[datetime] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
|
||||
|
||||
|
||||
@dataclass
class StageCheckpoint:
    """
    A checkpoint saved after a pipeline stage completes.

    Binary data (frame images, crops) goes to S3/MinIO.
    Everything else (structured state) lives here in Postgres.
    """

    id: UUID
    job_id: UUID
    stage: str
    stage_index: int  # position in NODES list (0-7)

    # S3 reference for binary data only
    frames_prefix: str = ""  # s3 prefix: checkpoints/{job_id}/frames/

    # Frame metadata (non-image fields)
    frames_manifest: Dict[int, str] = field(default_factory=dict)  # seq → s3 key
    frames_meta: List[Dict[str, Any]] = field(default_factory=list)  # sequence, chunk_id, timestamp, hash
    filtered_frame_sequences: List[int] = field(default_factory=list)

    # Detection state (full structured data, not just summaries)
    boxes_by_frame: Dict[str, List[Dict[str, Any]]] = field(default_factory=dict)
    text_candidates: List[Dict[str, Any]] = field(default_factory=list)
    unresolved_candidates: List[Dict[str, Any]] = field(default_factory=list)
    detections: List[Dict[str, Any]] = field(default_factory=list)

    # Pipeline state
    stats: Dict[str, Any] = field(default_factory=dict)
    config_snapshot: Dict[str, Any] = field(default_factory=dict)  # effective config at save time
    config_overrides: Dict[str, Any] = field(default_factory=dict)

    # Input refs (for replay)
    video_path: str = ""
    profile_name: str = ""

    # Timestamps
    created_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class BrandSource(str, Enum):
    """How a brand was first identified."""

    OCR = "ocr"
    VLM = "local_vlm"
    CLOUD = "cloud_llm"
    MANUAL = "manual"  # user-added via UI
|
||||
|
||||
|
||||
@dataclass
class KnownBrand:
    """
    A brand discovered or registered in the system.

    Global — not per-source. Accumulates across all pipeline runs.
    Aliases enable fuzzy matching without re-escalating to VLM.
    """

    id: UUID
    canonical_name: str  # normalized display name
    aliases: List[str] = field(default_factory=list)  # known spellings/variants
    first_source: BrandSource = BrandSource.OCR
    total_occurrences: int = 0  # running count across all runs
    confirmed: bool = False  # manually confirmed by user

    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
|
||||
|
||||
|
||||
@dataclass
class SourceBrandSighting:
    """
    A brand seen in a specific source (video/asset).

    Per-source session cache — avoids re-escalating the same brand
    on subsequent frames or re-runs of the same source.
    """

    id: UUID
    source_asset_id: UUID  # the video this sighting belongs to
    brand_id: UUID  # FK to KnownBrand
    brand_name: str  # denormalized for fast lookup
    first_seen_timestamp: float = 0.0
    last_seen_timestamp: float = 0.0
    occurrences: int = 0  # sightings of this brand in this source
    detection_source: BrandSource = BrandSource.OCR
    avg_confidence: float = 0.0  # running mean over all occurrences

    created_at: Optional[datetime] = None
|
||||
Reference in New Issue
Block a user