This commit is contained in:
2026-03-26 04:24:32 -03:00
parent 08b67f2bb7
commit 08c58a6a9d
43 changed files with 2627 additions and 252 deletions

121
core/api/detect_replay.py Normal file
View File

@@ -0,0 +1,121 @@
"""
API endpoints for checkpoint inspection, replay, and retry.
GET /detect/checkpoints/{job_id} — list available checkpoints
POST /detect/replay — replay from a stage with config overrides
POST /detect/retry — queue async retry with different provider
"""
from __future__ import annotations
import logging
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/detect", tags=["detect"])
# --- Request/Response models ---
class CheckpointInfo(BaseModel):
    """A single available checkpoint stage for a job."""
    stage: str  # pipeline stage name the checkpoint was saved after
class ReplayRequest(BaseModel):
    """Request body for POST /detect/replay."""
    job_id: str  # job whose checkpoints seed the replay
    start_stage: str  # stage to resume from (must have a checkpoint)
    config_overrides: dict | None = None  # per-run config values to override
class ReplayResponse(BaseModel):
    """Synchronous result of a replay invocation."""
    status: str  # "completed" on success
    job_id: str
    start_stage: str
    detections: int = 0  # number of detections produced by the replay
    brands_found: int = 0  # number of brands in the resulting report
class RetryRequest(BaseModel):
    """Request body for POST /detect/retry (queued async replay)."""
    job_id: str
    config_overrides: dict | None = None
    start_stage: str = "escalate_vlm"  # default: re-run only the VLM escalation stage
    schedule_seconds: float | None = None  # delay before execution (off-peak)
class RetryResponse(BaseModel):
    """Acknowledgement that a retry task was queued."""
    status: str  # "queued"
    task_id: str  # Celery task id for tracking
    job_id: str
# --- Endpoints ---
@router.get("/checkpoints/{job_id}")
def list_checkpoints(job_id: str) -> list[CheckpointInfo]:
    """List available checkpoint stages for a job.

    Args:
        job_id: Identifier of the detection job whose checkpoints to list.

    Returns:
        One ``CheckpointInfo`` per stage, in the order the store returns them.

    Raises:
        HTTPException: 404 when the checkpoint store has no data for the job.
    """
    # Deferred import keeps the detect package off the API import path.
    from detect.checkpoint import list_checkpoints as _list

    try:
        stages = _list(job_id)
    except Exception as e:  # any lookup failure is reported as "not found"
        raise HTTPException(
            status_code=404,
            detail=f"No checkpoints for job {job_id}: {e}",
        ) from e  # chain the cause for server-side debugging
    return [CheckpointInfo(stage=s) for s in stages]
@router.post("/replay", response_model=ReplayResponse)
def replay(req: ReplayRequest):
    """Replay pipeline from a specific stage with optional config overrides.

    Runs synchronously and summarizes the result.

    Raises:
        HTTPException: 400 when the pipeline rejects the parameters
            (``ValueError``), 500 on any other replay failure.
    """
    # Deferred import keeps the detect package off the API import path.
    from detect.checkpoint import replay_from

    try:
        result = replay_from(
            job_id=req.job_id,
            start_stage=req.start_stage,
            config_overrides=req.config_overrides,
        )
    except ValueError as e:
        # Invalid stage/job/override → client error; chain the cause.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Replay failed: {e}") from e

    detections = result.get("detections", [])
    report = result.get("report")
    # NOTE(review): report appears to expose a .brands sequence — confirm
    # against detect.checkpoint's report type.
    brands_found = len(report.brands) if report else 0
    return ReplayResponse(
        status="completed",
        job_id=req.job_id,
        start_stage=req.start_stage,
        detections=len(detections),
        brands_found=brands_found,
    )
@router.post("/retry", response_model=RetryResponse)
def retry(req: RetryRequest):
    """Queue an async retry of unresolved candidates with different config.

    Returns:
        RetryResponse carrying the Celery task id. Execution is deferred by
        ``schedule_seconds`` when provided (e.g. to run off-peak).
    """
    # Deferred import keeps Celery task modules off the API import path.
    from detect.checkpoint.tasks import retry_candidates

    kwargs = {
        "job_id": req.job_id,
        "config_overrides": req.config_overrides,
        "start_stage": req.start_stage,
    }
    # Explicit None check: a 0.0 delay is a valid (immediate) countdown and
    # must not silently fall through to the no-countdown path.
    if req.schedule_seconds is not None:
        task = retry_candidates.apply_async(kwargs=kwargs, countdown=req.schedule_seconds)
    else:
        task = retry_candidates.delay(**kwargs)
    return RetryResponse(
        status="queued",
        task_id=task.id,
        job_id=req.job_id,
    )

View File

@@ -25,6 +25,7 @@ from strawberry.fastapi import GraphQLRouter
from core.api.chunker_sse import router as chunker_router
from core.api.detect_sse import router as detect_router
from core.api.detect_replay import router as detect_replay_router
from core.api.graphql import schema as graphql_schema
CALLBACK_API_KEY = os.environ.get("CALLBACK_API_KEY", "")
@@ -56,6 +57,9 @@ app.include_router(chunker_router)
# Detection SSE
app.include_router(detect_router)
# Detection replay/retry
app.include_router(detect_replay_router)
@app.get("/health")
def health():

175
core/db/detect.py Normal file
View File

@@ -0,0 +1,175 @@
"""Database operations for DetectJob and StageCheckpoint."""
from typing import Optional
from uuid import UUID
# ---------------------------------------------------------------------------
# DetectJob
# ---------------------------------------------------------------------------
def create_detect_job(**fields):
    """Insert and return a new DetectJob row built from *fields*."""
    # Lazy model import — presumably to avoid app-loading-order issues; TODO confirm.
    from admin.mpr.media_assets.models import DetectJob
    return DetectJob.objects.create(**fields)
def get_detect_job(id: UUID):
    """Return the DetectJob with primary key *id* (raises DoesNotExist if missing)."""
    from admin.mpr.media_assets.models import DetectJob
    return DetectJob.objects.get(id=id)
def update_detect_job(job_id: UUID, **fields) -> None:
    """Bulk-update *fields* on the DetectJob row with id *job_id*.

    Uses queryset ``update()`` — no model signals fire and missing ids are a no-op.
    """
    from admin.mpr.media_assets.models import DetectJob
    DetectJob.objects.filter(id=job_id).update(**fields)
def list_detect_jobs(
    parent_job_id: Optional[UUID] = None,
    status: Optional[str] = None,
):
    """Return DetectJob rows, optionally narrowed by parent job and/or status."""
    from admin.mpr.media_assets.models import DetectJob

    # Collect the requested filters, then apply them in one call.
    criteria = {}
    if parent_job_id:
        criteria["parent_job_id"] = parent_job_id
    if status:
        criteria["status"] = status
    return list(DetectJob.objects.filter(**criteria))
# ---------------------------------------------------------------------------
# StageCheckpoint
# ---------------------------------------------------------------------------
def save_stage_checkpoint(**fields):
    """Insert and return a new StageCheckpoint row built from *fields*."""
    from admin.mpr.media_assets.models import StageCheckpoint
    return StageCheckpoint.objects.create(**fields)
def get_stage_checkpoint(job_id: UUID, stage: str):
    """Return the checkpoint for (*job_id*, *stage*) — raises DoesNotExist if absent."""
    from admin.mpr.media_assets.models import StageCheckpoint
    return StageCheckpoint.objects.get(job_id=job_id, stage=stage)
def list_stage_checkpoints(job_id: UUID) -> list[str]:
    """Return the checkpointed stage names for *job_id*, ordered by pipeline position."""
    from admin.mpr.media_assets.models import StageCheckpoint

    rows = StageCheckpoint.objects.filter(job_id=job_id).order_by("stage_index")
    return list(rows.values_list("stage", flat=True))
def delete_stage_checkpoints(job_id: UUID) -> None:
    """Delete every checkpoint row belonging to *job_id*."""
    from admin.mpr.media_assets.models import StageCheckpoint
    StageCheckpoint.objects.filter(job_id=job_id).delete()
# ---------------------------------------------------------------------------
# KnownBrand
# ---------------------------------------------------------------------------
def get_or_create_brand(canonical_name: str, aliases: list[str] | None = None,
                        source: str = "ocr") -> tuple:
    """Get existing brand or create new one. Returns (brand, created).

    Matching is case-insensitive against both canonical names and aliases.

    Args:
        canonical_name: Display name to look up or register (whitespace stripped).
        aliases: Known spellings/variants stored on a newly created brand.
        source: How the brand was identified (stored as ``first_source``).
    """
    from admin.mpr.media_assets.models import KnownBrand
    import uuid

    normalized = canonical_name.strip()
    brand = KnownBrand.objects.filter(canonical_name__iexact=normalized).first()
    if brand:
        return brand, False
    # Check aliases of existing brands.
    # NOTE(review): full-table scan in Python; fine while the brand table is
    # small, but consider a jsonb containment query if it grows.
    needle = normalized.lower()  # hoisted: was recomputed every iteration
    for existing in KnownBrand.objects.all():
        existing_aliases = [a.lower() for a in (existing.aliases or [])]
        if needle in existing_aliases:
            return existing, False
    brand = KnownBrand.objects.create(
        id=uuid.uuid4(),
        canonical_name=normalized,
        aliases=aliases or [],
        first_source=source,
    )
    return brand, True
def find_brand_by_text(text: str) -> Optional[object]:
    """Find a known brand by canonical name or alias (case-insensitive)."""
    from admin.mpr.media_assets.models import KnownBrand

    needle = text.strip().lower()
    # Fast path: canonical name match (iexact handles the casing).
    match = KnownBrand.objects.filter(canonical_name__iexact=needle).first()
    if match is not None:
        return match
    # Slow path: scan each brand's alias list for a case-insensitive hit.
    for candidate in KnownBrand.objects.all():
        if needle in [a.lower() for a in (candidate.aliases or [])]:
            return candidate
    return None
def list_all_brands() -> list:
    """Return every KnownBrand, sorted alphabetically by canonical name."""
    from admin.mpr.media_assets.models import KnownBrand
    return list(KnownBrand.objects.all().order_by("canonical_name"))
def update_brand(brand_id: UUID, **fields) -> None:
    """Bulk-update *fields* on the KnownBrand with id *brand_id* (no signals fire)."""
    from admin.mpr.media_assets.models import KnownBrand
    KnownBrand.objects.filter(id=brand_id).update(**fields)
# ---------------------------------------------------------------------------
# SourceBrandSighting
# ---------------------------------------------------------------------------
def get_source_sightings(source_asset_id: UUID) -> list:
    """Get all brand sightings for a specific source video.

    Returns sightings ordered by occurrence count, most frequent first.
    """
    from admin.mpr.media_assets.models import SourceBrandSighting
    return list(
        SourceBrandSighting.objects
        .filter(source_asset_id=source_asset_id)
        .order_by("-occurrences")
    )
def record_sighting(source_asset_id: UUID, brand_id: UUID, brand_name: str,
                    timestamp: float, confidence: float, source: str = "ocr"):
    """Record or update a brand sighting for a source.

    If a sighting for (source, brand) exists, bumps its occurrence count,
    updates ``last_seen_timestamp``, and folds *confidence* into the running
    mean; otherwise creates a fresh row. Returns the sighting either way.
    """
    from admin.mpr.media_assets.models import SourceBrandSighting
    import uuid
    sighting = SourceBrandSighting.objects.filter(
        source_asset_id=source_asset_id,
        brand_id=brand_id,
    ).first()
    if sighting:
        sighting.occurrences += 1
        sighting.last_seen_timestamp = timestamp
        # Incremental mean: after the increment, the previous total equals
        # avg * (occurrences - 1); add the new sample and re-divide.
        total_conf = sighting.avg_confidence * (sighting.occurrences - 1) + confidence
        sighting.avg_confidence = total_conf / sighting.occurrences
        # NOTE(review): read-modify-write is not atomic under concurrent
        # writers — consider select_for_update / F() if that matters here.
        sighting.save()
        return sighting
    sighting = SourceBrandSighting.objects.create(
        id=uuid.uuid4(),
        source_asset_id=source_asset_id,
        brand_id=brand_id,
        brand_name=brand_name,
        first_seen_timestamp=timestamp,
        last_seen_timestamp=timestamp,
        occurrences=1,
        detection_source=source,
        avg_confidence=confidence,
    )
    return sighting

View File

@@ -26,13 +26,18 @@ from .grpc import (
WorkerStatus,
)
from .jobs import ChunkJob, ChunkJobStatus, JobStatus, TranscodeJob
from .detect_jobs import (
DetectJob, DetectJobStatus, RunType, StageCheckpoint,
BrandSource, KnownBrand, SourceBrandSighting,
)
from .media import AssetStatus, MediaAsset
from .presets import BUILTIN_PRESETS, TranscodePreset
from .detect import DETECT_VIEWS # noqa: F401 — discovered by modelgen generic loader
from .views import ChunkEvent, ChunkOutputFile, PipelineStats, WorkerEvent
# Core domain models - generates Django, Pydantic, TypeScript
DATACLASSES = [MediaAsset, TranscodePreset, TranscodeJob, ChunkJob]
DATACLASSES = [MediaAsset, TranscodePreset, TranscodeJob, ChunkJob,
DetectJob, StageCheckpoint, KnownBrand, SourceBrandSighting]
# API request/response models - generates TypeScript only (no Django)
# WorkerStatus from grpc.py is reused here
@@ -46,7 +51,7 @@ API_MODELS = [
]
# Status enums - included in generated code
ENUMS = [AssetStatus, JobStatus, ChunkJobStatus]
ENUMS = [AssetStatus, JobStatus, ChunkJobStatus, DetectJobStatus, RunType, BrandSource]
# View/event models - generates TypeScript for UI consumption
VIEWS = [ChunkEvent, WorkerEvent, PipelineStats, ChunkOutputFile]

View File

@@ -149,6 +149,64 @@ class JobComplete:
report: Optional[DetectionReportSummary] = None
@dataclass
class RunContext:
    """Run context injected into all SSE events for grouping."""
    run_id: str  # unique id of this particular run
    parent_job_id: str  # groups all runs for the same source
    run_type: str = "initial"  # initial | replay | retry
# --- Checkpoint API types ---
@dataclass
class CheckpointInfo:
    """Available checkpoint for a stage."""
    stage: str  # pipeline stage name the checkpoint was saved after
@dataclass
class ReplayRequest:
    """Request to replay pipeline from a specific stage."""
    job_id: str  # job whose checkpoints seed the replay
    start_stage: str  # stage to resume from
    config_overrides: Optional[dict] = None  # per-run config overrides
@dataclass
class ReplayResponse:
    """Result of a replay invocation."""
    status: str  # "completed" on success
    job_id: str
    start_stage: str
    detections: int = 0  # detections produced by the replay
    brands_found: int = 0  # brands in the resulting report
@dataclass
class RetryRequest:
    """Request to queue async retry with different config."""
    job_id: str
    config_overrides: Optional[dict] = None
    start_stage: str = "escalate_vlm"  # default: re-run only VLM escalation
    schedule_seconds: Optional[float] = None  # delay before execution (off-peak)
@dataclass
class RetryResponse:
    """Result of queueing a retry task."""
    status: str  # "queued"
    task_id: str  # async task id for tracking
    job_id: str
# --- Export lists for modelgen ---
DETECT_VIEWS = [
@@ -163,4 +221,10 @@ DETECT_VIEWS = [
LogEvent,
DetectionReportSummary,
JobComplete,
RunContext,
CheckpointInfo,
ReplayRequest,
ReplayResponse,
RetryRequest,
RetryResponse,
]

View File

@@ -0,0 +1,162 @@
"""
Detection Job and Checkpoint Schema Definitions
Source of truth for detection pipeline job tracking and stage checkpoints.
Follows the TranscodeJob/ChunkJob pattern.
"""
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID
class DetectJobStatus(str, Enum):
    """Lifecycle states of a detection pipeline job."""
    PENDING = "pending"
    RUNNING = "running"
    PAUSED = "paused"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
class RunType(str, Enum):
    """How a DetectJob run was started."""
    INITIAL = "initial"  # first run for a source
    REPLAY = "replay"    # synchronous replay from a checkpoint
    RETRY = "retry"      # queued async retry
@dataclass
class DetectJob:
    """
    A detection pipeline job.
    Each invocation of the pipeline (initial run, replay, retry) creates a DetectJob.
    Jobs for the same source video are linked via parent_job_id.
    """
    id: UUID
    # Input
    source_asset_id: UUID  # the video asset being analyzed
    video_path: str
    profile_name: str = "soccer_broadcast"
    # Run lineage
    parent_job_id: Optional[UUID] = None  # links all runs for the same source
    run_type: RunType = RunType.INITIAL
    replay_from_stage: Optional[str] = None  # null for initial runs
    config_overrides: Dict[str, Any] = field(default_factory=dict)
    # Status
    status: DetectJobStatus = DetectJobStatus.PENDING
    current_stage: Optional[str] = None
    progress: float = 0.0  # fraction complete — presumably 0.0–1.0; TODO confirm
    error_message: Optional[str] = None
    # Results summary
    total_detections: int = 0
    brands_found: int = 0
    cloud_llm_calls: int = 0
    estimated_cost_usd: float = 0.0
    # Worker tracking
    celery_task_id: Optional[str] = None
    priority: int = 0
    # Timestamps
    created_at: Optional[datetime] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
@dataclass
class StageCheckpoint:
    """
    A checkpoint saved after a pipeline stage completes.
    Binary data (frame images, crops) goes to S3/MinIO.
    Everything else (structured state) lives here in Postgres.
    """
    id: UUID
    job_id: UUID  # DetectJob this checkpoint belongs to
    stage: str
    stage_index: int  # position in NODES list (0-7)
    # S3 reference for binary data only
    frames_prefix: str = ""  # s3 prefix: checkpoints/{job_id}/frames/
    # Frame metadata (non-image fields)
    frames_manifest: Dict[int, str] = field(default_factory=dict)  # seq → s3 key
    frames_meta: List[Dict[str, Any]] = field(default_factory=list)  # sequence, chunk_id, timestamp, hash
    filtered_frame_sequences: List[int] = field(default_factory=list)
    # Detection state (full structured data, not just summaries)
    boxes_by_frame: Dict[str, List[Dict[str, Any]]] = field(default_factory=dict)
    text_candidates: List[Dict[str, Any]] = field(default_factory=list)
    unresolved_candidates: List[Dict[str, Any]] = field(default_factory=list)
    detections: List[Dict[str, Any]] = field(default_factory=list)
    # Pipeline state
    stats: Dict[str, Any] = field(default_factory=dict)
    config_snapshot: Dict[str, Any] = field(default_factory=dict)
    config_overrides: Dict[str, Any] = field(default_factory=dict)
    # Input refs (for replay)
    video_path: str = ""
    profile_name: str = ""
    # Timestamps
    created_at: Optional[datetime] = None
class BrandSource(str, Enum):
    """How a brand was first identified."""
    OCR = "ocr"
    VLM = "local_vlm"
    CLOUD = "cloud_llm"
    MANUAL = "manual"  # user-added via UI
@dataclass
class KnownBrand:
    """
    A brand discovered or registered in the system.
    Global — not per-source. Accumulates across all pipeline runs.
    Aliases enable fuzzy matching without re-escalating to VLM.
    """
    id: UUID
    canonical_name: str  # normalized display name
    aliases: List[str] = field(default_factory=list)  # known spellings/variants
    first_source: BrandSource = BrandSource.OCR  # where this brand was first seen
    total_occurrences: int = 0
    confirmed: bool = False  # manually confirmed by user
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
@dataclass
class SourceBrandSighting:
    """
    A brand seen in a specific source (video/asset).
    Per-source session cache — avoids re-escalating the same brand
    on subsequent frames or re-runs of the same source.
    """
    id: UUID
    source_asset_id: UUID  # the video this sighting belongs to
    brand_id: UUID  # FK to KnownBrand
    brand_name: str  # denormalized for fast lookup
    first_seen_timestamp: float = 0.0  # seconds into the video — TODO confirm unit
    last_seen_timestamp: float = 0.0
    occurrences: int = 0
    detection_source: BrandSource = BrandSource.OCR
    avg_confidence: float = 0.0  # running mean over all sightings
    created_at: Optional[datetime] = None