major refactor

This commit is contained in:
2026-03-27 06:02:58 -03:00
parent bcf6f3dc71
commit 51ce14a812
18 changed files with 351 additions and 523 deletions

View File

@@ -1,7 +1,7 @@
""" """
API endpoints for checkpoint inspection, replay, retry, and GPU proxy. API endpoints for checkpoint inspection, replay, retry, and GPU proxy.
GET /detect/checkpoints/{job_id} — list available checkpoints GET /detect/checkpoints/{timeline_id} — list available checkpoints
POST /detect/replay — replay from a stage with config overrides POST /detect/replay — replay from a stage with config overrides
POST /detect/retry — queue async retry with different provider POST /detect/retry — queue async retry with different provider
POST /detect/replay-stage — replay single stage (fast path) POST /detect/replay-stage — replay single stage (fast path)
@@ -31,7 +31,7 @@ class CheckpointInfo(BaseModel):
class ScenarioInfo(BaseModel): class ScenarioInfo(BaseModel):
job_id: str timeline_id: str
stage: str stage: str
scenario_label: str scenario_label: str
profile_name: str profile_name: str
@@ -41,21 +41,21 @@ class ScenarioInfo(BaseModel):
class ReplayRequest(BaseModel): class ReplayRequest(BaseModel):
job_id: str timeline_id: str
start_stage: str start_stage: str
config_overrides: dict | None = None config_overrides: dict | None = None
class ReplayResponse(BaseModel): class ReplayResponse(BaseModel):
status: str status: str
job_id: str timeline_id: str
start_stage: str start_stage: str
detections: int = 0 detections: int = 0
brands_found: int = 0 brands_found: int = 0
class RetryRequest(BaseModel): class RetryRequest(BaseModel):
job_id: str timeline_id: str
config_overrides: dict | None = None config_overrides: dict | None = None
start_stage: str = "escalate_vlm" start_stage: str = "escalate_vlm"
schedule_seconds: float | None = None # delay before execution (off-peak) schedule_seconds: float | None = None # delay before execution (off-peak)
@@ -64,11 +64,11 @@ class RetryRequest(BaseModel):
class RetryResponse(BaseModel): class RetryResponse(BaseModel):
status: str status: str
task_id: str task_id: str
job_id: str timeline_id: str
class ReplaySingleStageRequest(BaseModel): class ReplaySingleStageRequest(BaseModel):
job_id: str timeline_id: str
stage: str stage: str
frame_refs: list[int] | None = None frame_refs: list[int] | None = None
config_overrides: dict | None = None config_overrides: dict | None = None
@@ -102,15 +102,15 @@ class ReplaySingleStageResponse(BaseModel):
# --- Endpoints --- # --- Endpoints ---
@router.get("/checkpoints/{job_id}") @router.get("/checkpoints/{timeline_id}")
def list_checkpoints(job_id: str) -> list[CheckpointInfo]: def list_checkpoints(timeline_id: str) -> list[CheckpointInfo]:
"""List available checkpoint stages for a job.""" """List available checkpoint stages for a job."""
from detect.checkpoint import list_checkpoints as _list from detect.checkpoint import list_checkpoints as _list
try: try:
stages = _list(job_id) stages = _list(timeline_id)
except Exception as e: except Exception as e:
raise HTTPException(status_code=404, detail=f"No checkpoints for job {job_id}: {e}") raise HTTPException(status_code=404, detail=f"No checkpoints for job {timeline_id}: {e}")
result = [CheckpointInfo(stage=s) for s in stages] result = [CheckpointInfo(stage=s) for s in stages]
return result return result
@@ -123,7 +123,7 @@ class CheckpointFrameInfo(BaseModel):
class CheckpointData(BaseModel): class CheckpointData(BaseModel):
job_id: str timeline_id: str
stage: str stage: str
profile_name: str profile_name: str
video_path: str video_path: str
@@ -135,26 +135,32 @@ class CheckpointData(BaseModel):
stage_output_key: str = "" stage_output_key: str = ""
@router.get("/checkpoints/{job_id}/{stage}", response_model=CheckpointData) @router.get("/checkpoints/{timeline_id}/{stage}", response_model=CheckpointData)
def get_checkpoint_data(job_id: str, stage: str): def get_checkpoint_data(timeline_id: str, stage: str):
"""Load checkpoint frames + metadata for the editor UI.""" """Load checkpoint frames + metadata for the editor UI."""
from core.db.detect import get_stage_checkpoint from uuid import UUID
from core.db.tables import Timeline, Checkpoint
from core.db.connection import get_session
from core.db.checkpoint import list_checkpoints
from detect.checkpoint.frames import load_frames_b64 from detect.checkpoint.frames import load_frames_b64
checkpoint = get_stage_checkpoint(job_id, stage) with get_session() as session:
if not checkpoint: timeline = session.get(Timeline, UUID(timeline_id))
raise HTTPException(status_code=404, detail=f"No checkpoint for {job_id}/{stage}") if not timeline:
raise HTTPException(status_code=404, detail=f"Timeline not found: {timeline_id}")
raw_manifest = checkpoint.frames_manifest or {} checkpoints = list_checkpoints(session, UUID(timeline_id))
if not checkpoints:
raise HTTPException(status_code=404, detail=f"No checkpoints for timeline {timeline_id}")
# Prefer a checkpoint that has this stage's output; fall back to latest
checkpoint = next(
(c for c in reversed(checkpoints) if stage in (c.stage_outputs or {})),
checkpoints[-1],
)
raw_manifest = timeline.frames_manifest or {}
manifest = {int(k): v for k, v in raw_manifest.items()} manifest = {int(k): v for k, v in raw_manifest.items()}
frame_metadata = checkpoint.frames_meta or [] frames_b64 = load_frames_b64(manifest, timeline.frames_meta or [])
# Only load filtered frames if available, otherwise all
filtered = set(checkpoint.filtered_frame_sequences or [])
if filtered:
manifest = {k: v for k, v in manifest.items() if k in filtered}
frames_b64 = load_frames_b64(manifest, frame_metadata)
frame_list = [ frame_list = [
CheckpointFrameInfo(seq=f["seq"], timestamp=f["timestamp"], jpeg_b64=f["jpeg_b64"]) CheckpointFrameInfo(seq=f["seq"], timestamp=f["timestamp"], jpeg_b64=f["jpeg_b64"])
@@ -162,35 +168,41 @@ def get_checkpoint_data(job_id: str, stage: str):
] ]
return CheckpointData( return CheckpointData(
job_id=str(checkpoint.job_id), timeline_id=timeline_id,
stage=checkpoint.stage, stage=stage,
profile_name=checkpoint.profile_name, profile_name=timeline.profile_name,
video_path=checkpoint.video_path, video_path=timeline.source_video,
is_scenario=checkpoint.is_scenario, is_scenario=checkpoint.is_scenario,
scenario_label=checkpoint.scenario_label, scenario_label=checkpoint.scenario_label,
frames=frame_list, frames=frame_list,
stats=checkpoint.stats or {}, stats=checkpoint.stats or {},
config_snapshot=checkpoint.config_snapshot or {}, config_snapshot=checkpoint.config_overrides or {},
stage_output_key=checkpoint.stage_output_key or "", stage_output_key=stage,
) )
@router.get("/scenarios", response_model=list[ScenarioInfo]) @router.get("/scenarios", response_model=list[ScenarioInfo])
def list_scenarios_endpoint(): def list_scenarios_endpoint():
"""List all available scenarios (bookmarked checkpoints).""" """List all available scenarios (bookmarked checkpoints)."""
from core.db.detect import list_scenarios from core.db.tables import Timeline
from core.db.connection import get_session
from core.db.checkpoint import list_scenarios
scenarios = list_scenarios() with get_session() as session:
scenarios = list_scenarios(session)
result = [] result = []
for s in scenarios: for s in scenarios:
manifest = s.frames_manifest or {} timeline = session.get(Timeline, s.timeline_id)
if not timeline:
continue
last_stage = next(reversed(s.stage_outputs), "") if s.stage_outputs else ""
info = ScenarioInfo( info = ScenarioInfo(
job_id=str(s.job_id), timeline_id=str(s.timeline_id),
stage=s.stage, stage=last_stage,
scenario_label=s.scenario_label, scenario_label=s.scenario_label,
profile_name=s.profile_name, profile_name=timeline.profile_name,
video_path=s.video_path, video_path=timeline.source_video,
frame_count=len(manifest), frame_count=len(timeline.frames_manifest or {}),
created_at=str(s.created_at) if s.created_at else "", created_at=str(s.created_at) if s.created_at else "",
) )
result.append(info) result.append(info)
@@ -204,7 +216,7 @@ def replay(req: ReplayRequest):
try: try:
result = replay_from( result = replay_from(
job_id=req.job_id, timeline_id=req.timeline_id,
start_stage=req.start_stage, start_stage=req.start_stage,
config_overrides=req.config_overrides, config_overrides=req.config_overrides,
) )
@@ -219,7 +231,7 @@ def replay(req: ReplayRequest):
response = ReplayResponse( response = ReplayResponse(
status="completed", status="completed",
job_id=req.job_id, timeline_id=req.timeline_id,
start_stage=req.start_stage, start_stage=req.start_stage,
detections=len(detections), detections=len(detections),
brands_found=brands_found, brands_found=brands_found,
@@ -233,7 +245,7 @@ def retry(req: RetryRequest):
from detect.checkpoint.tasks import retry_candidates from detect.checkpoint.tasks import retry_candidates
kwargs = { kwargs = {
"job_id": req.job_id, "timeline_id": req.timeline_id,
"config_overrides": req.config_overrides, "config_overrides": req.config_overrides,
"start_stage": req.start_stage, "start_stage": req.start_stage,
} }
@@ -246,7 +258,7 @@ def retry(req: RetryRequest):
response = RetryResponse( response = RetryResponse(
status="queued", status="queued",
task_id=task.id, task_id=task.id,
job_id=req.job_id, timeline_id=req.timeline_id,
) )
return response return response
@@ -258,7 +270,7 @@ def replay_single_stage(req: ReplaySingleStageRequest):
try: try:
result = _replay( result = _replay(
job_id=req.job_id, timeline_id=req.timeline_id,
stage=req.stage, stage=req.stage,
frame_refs=req.frame_refs, frame_refs=req.frame_refs,
config_overrides=req.config_overrides, config_overrides=req.config_overrides,

View File

@@ -25,11 +25,10 @@ from .grpc import (
ProgressUpdate, ProgressUpdate,
WorkerStatus, WorkerStatus,
) )
from .job import ( from .job import Job, JobStatus, RunType
Job, JobStatus, RunType, from .timeline import Timeline
Timeline, Checkpoint, from .checkpoint import Checkpoint
BrandSource, Brand, from .brand import BrandSource, Brand
)
from .media import AssetStatus, MediaAsset from .media import AssetStatus, MediaAsset
from .presets import BUILTIN_PRESETS, TranscodePreset from .presets import BUILTIN_PRESETS, TranscodePreset
from .detect import DETECT_VIEWS # noqa: F401 — discovered by modelgen generic loader from .detect import DETECT_VIEWS # noqa: F401 — discovered by modelgen generic loader

View File

@@ -0,0 +1,38 @@
"""Brand schema — source of truth for brand discovery."""
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID
class BrandSource(str, Enum):
    """How a brand was first discovered.

    The ``str`` mixin makes every member compare equal to its raw value
    (``BrandSource.OCR == "ocr"``).
    """

    OCR = "ocr"
    VLM = "local_vlm"
    CLOUD = "cloud_llm"
    MANUAL = "manual"
@dataclass
class Brand:
    """
    A brand discovered or registered in the system.

    Airings track where/when the brand appeared — each airing
    references a timeline and a frame range.
    """

    id: UUID
    canonical_name: str
    aliases: List[str] = field(default_factory=list)
    source: BrandSource = BrandSource.OCR  # how first discovered
    confirmed: bool = False

    # Airings — JSONB array of appearances
    # [{timeline_id, frame_start, frame_end, confidence, source, timestamp}]
    # default_factory gives every instance its own list.
    airings: List[Dict[str, Any]] = field(default_factory=list)
    total_airings: int = 0

    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

View File

@@ -0,0 +1,38 @@
"""Checkpoint schema — source of truth for pipeline state snapshots."""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, Optional
from uuid import UUID
@dataclass
class Checkpoint:
    """
    A snapshot of pipeline state on a timeline.

    Stage outputs stored as JSONB — each stage serializes to JSON,
    the checkpoint stores it without knowing the shape.

    parent_id forms a tree: multiple children from the same parent
    = different config tries from the same starting point.
    """

    id: UUID
    timeline_id: UUID
    parent_id: Optional[UUID] = None  # null = root checkpoint

    # Stage outputs — JSONB per stage, opaque to the checkpoint layer
    stage_outputs: Dict[str, Any] = field(default_factory=dict)

    # Config that produced this checkpoint
    config_overrides: Dict[str, Any] = field(default_factory=dict)

    # Pipeline state
    stats: Dict[str, Any] = field(default_factory=dict)

    # Scenario bookmark
    is_scenario: bool = False
    scenario_label: str = ""

    created_at: Optional[datetime] = None

View File

@@ -1,177 +0,0 @@
"""
Detection Job and Checkpoint Schema Definitions
Source of truth for detection pipeline job tracking and stage checkpoints.
Follows the TranscodeJob/ChunkJob pattern.
"""
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID
class DetectJobStatus(str, Enum):
    """Status values a detection job can hold.

    The ``str`` mixin makes every member compare equal to its raw value.
    """

    PENDING = "pending"
    RUNNING = "running"
    PAUSED = "paused"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
class RunType(str, Enum):
    """Kind of pipeline invocation: initial run, replay, or retry."""

    INITIAL = "initial"
    REPLAY = "replay"
    RETRY = "retry"
@dataclass
class DetectJob:
    """
    A detection pipeline job.

    Each invocation of the pipeline (initial run, replay, retry) creates a
    DetectJob. Jobs for the same source video are linked via parent_job_id.
    """

    id: UUID

    # Input
    source_asset_id: UUID
    video_path: str
    profile_name: str = "soccer_broadcast"

    # Run lineage
    parent_job_id: Optional[UUID] = None  # links all runs for the same source
    run_type: RunType = RunType.INITIAL
    replay_from_stage: Optional[str] = None  # null for initial runs
    config_overrides: Dict[str, Any] = field(default_factory=dict)

    # Status
    status: DetectJobStatus = DetectJobStatus.PENDING
    current_stage: Optional[str] = None
    progress: float = 0.0
    error_message: Optional[str] = None

    # Results summary
    total_detections: int = 0
    brands_found: int = 0
    cloud_llm_calls: int = 0
    estimated_cost_usd: float = 0.0

    # Worker tracking
    celery_task_id: Optional[str] = None
    priority: int = 0

    # Timestamps
    created_at: Optional[datetime] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
@dataclass
class Timeline:
    """
    The frame sequence from a source video.

    Independent of stages — exists before any stage runs.
    Stages annotate the timeline, they don't own it.
    Frames are stored in MinIO as JPEGs.
    """

    id: UUID
    source_asset_id: Optional[UUID] = None
    source_video: str = ""
    profile_name: str = ""
    fps: float = 2.0

    # Frame metadata (images in MinIO, metadata here)
    frames_prefix: str = ""  # s3: timelines/{id}/frames/
    frames_manifest: Dict[int, str] = field(default_factory=dict)  # seq → s3 key
    frames_meta: List[Dict[str, Any]] = field(default_factory=list)

    created_at: Optional[datetime] = None
@dataclass
class Checkpoint:
    """
    A snapshot of pipeline state on a timeline.

    Stage outputs stored as JSONB — each stage serializes to JSON,
    the checkpoint stores it without knowing the shape.

    parent_id forms a tree: multiple children from the same parent
    = different config tries from the same starting point.
    """

    id: UUID
    timeline_id: UUID
    parent_id: Optional[UUID] = None  # null = root checkpoint

    # Stage outputs — JSONB per stage, opaque to the checkpoint layer
    stage_outputs: Dict[str, Any] = field(default_factory=dict)

    # Config that produced this checkpoint
    config_overrides: Dict[str, Any] = field(default_factory=dict)

    # Pipeline state
    stats: Dict[str, Any] = field(default_factory=dict)

    # Scenario bookmark
    is_scenario: bool = False
    scenario_label: str = ""

    created_at: Optional[datetime] = None
class BrandSource(str, Enum):
    """How a brand was first identified."""

    OCR = "ocr"
    VLM = "local_vlm"
    CLOUD = "cloud_llm"
    MANUAL = "manual"  # user-added via UI
@dataclass
class KnownBrand:
    """
    A brand discovered or registered in the system.

    Global — not per-source. Accumulates across all pipeline runs.
    Aliases enable fuzzy matching without re-escalating to VLM.
    """

    id: UUID
    canonical_name: str  # normalized display name
    aliases: List[str] = field(default_factory=list)  # known spellings/variants
    first_source: BrandSource = BrandSource.OCR
    total_occurrences: int = 0
    confirmed: bool = False  # manually confirmed by user

    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
@dataclass
class SourceBrandSighting:
    """
    A brand seen in a specific source (video/asset).

    Per-source session cache — avoids re-escalating the same brand
    on subsequent frames or re-runs of the same source.
    """

    id: UUID
    source_asset_id: UUID  # the video this sighting belongs to
    brand_id: UUID  # FK to KnownBrand
    brand_name: str  # denormalized for fast lookup

    first_seen_timestamp: float = 0.0
    last_seen_timestamp: float = 0.0
    occurrences: int = 0
    detection_source: BrandSource = BrandSource.OCR
    avg_confidence: float = 0.0

    created_at: Optional[datetime] = None

View File

@@ -1,14 +1,9 @@
""" """Job schema — source of truth for pipeline jobs."""
Job, Timeline, and Checkpoint Schema Definitions
Source of truth for pipeline jobs, timelines, and checkpoints.
Generates: SQLModel (core/db/models.py), TypeScript via modelgen.
"""
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional from typing import Any, Dict, Optional
from uuid import UUID from uuid import UUID
@@ -68,91 +63,3 @@ class Job:
created_at: Optional[datetime] = None created_at: Optional[datetime] = None
started_at: Optional[datetime] = None started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None completed_at: Optional[datetime] = None
@dataclass
class Timeline:
    """
    The frame sequence from a source video.

    Independent of stages — exists before any stage runs.
    Frames stored in MinIO as JPEGs, metadata here.
    One timeline per job.
    """

    id: UUID
    source_asset_id: Optional[UUID] = None
    source_video: str = ""
    profile_name: str = ""
    fps: float = 2.0

    frames_prefix: str = ""  # s3: timeline/{id}/frames/
    frames_manifest: Dict[int, str] = field(default_factory=dict)  # seq → s3 key
    frames_meta: List[Dict[str, Any]] = field(default_factory=list)

    created_at: Optional[datetime] = None
@dataclass
class Checkpoint:
    """
    A snapshot of pipeline state on a timeline.

    Stage outputs stored as JSONB — each stage serializes to JSON,
    the checkpoint stores it without knowing the shape.

    parent_id forms a tree: multiple children from the same parent
    = different config tries from the same starting point.
    """

    id: UUID
    timeline_id: UUID
    parent_id: Optional[UUID] = None  # null = root checkpoint

    # Stage outputs — JSONB per stage, opaque to the checkpoint layer
    stage_outputs: Dict[str, Any] = field(default_factory=dict)

    # Config that produced this checkpoint
    config_overrides: Dict[str, Any] = field(default_factory=dict)

    # Pipeline state
    stats: Dict[str, Any] = field(default_factory=dict)

    # Scenario bookmark
    is_scenario: bool = False
    scenario_label: str = ""

    created_at: Optional[datetime] = None
# --- Brands ---
class BrandSource(str, Enum):
    """How a brand was first discovered.

    The ``str`` mixin makes every member compare equal to its raw value.
    """

    OCR = "ocr"
    VLM = "local_vlm"
    CLOUD = "cloud_llm"
    MANUAL = "manual"
@dataclass
class Brand:
    """
    A brand discovered or registered in the system.

    Airings track where/when the brand appeared — each airing
    references a timeline and a frame range.
    """

    id: UUID
    canonical_name: str
    aliases: List[str] = field(default_factory=list)
    source: BrandSource = BrandSource.OCR  # how first discovered
    confirmed: bool = False

    # Airings — JSONB array of appearances
    # [{timeline_id, frame_start, frame_end, confidence, source, timestamp}]
    airings: List[Dict[str, Any]] = field(default_factory=list)
    total_airings: int = 0

    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

View File

@@ -1,13 +1,9 @@
""" """
Detection pipeline runtime models. Detection pipeline runtime models.
These are the data structures that flow between LangGraph nodes. These are the data structures that flow between pipeline stages.
They contain runtime types (np.ndarray) so they are NOT generated They contain runtime types (np.ndarray) so modelgen skips them
by modelgen they live here for the schema to be the complete not generated to SQLModel or TypeScript.
map of the application, but modelgen skips them.
Wire-format models (SSE events) are in detect.py.
DB models (jobs, checkpoints) are in detect_jobs.py.
""" """
from __future__ import annotations from __future__ import annotations
@@ -89,10 +85,3 @@ class DetectionReport:
brands: dict[str, BrandStats] = field(default_factory=dict) brands: dict[str, BrandStats] = field(default_factory=dict)
timeline: list[BrandDetection] = field(default_factory=list) timeline: list[BrandDetection] = field(default_factory=list)
pipeline_stats: PipelineStats = field(default_factory=PipelineStats) pipeline_stats: PipelineStats = field(default_factory=PipelineStats)
# Not in DATACLASSES — modelgen skips these (they contain np.ndarray)
RUNTIME_MODELS = [
Frame, BoundingBox, TextCandidate, BrandDetection,
BrandStats, PipelineStats, DetectionReport,
]

View File

@@ -0,0 +1,29 @@
"""Timeline schema — source of truth for frame sequences."""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional
from uuid import UUID
@dataclass
class Timeline:
    """
    The frame sequence from a source video.

    Independent of stages — exists before any stage runs.
    Frames stored in MinIO as JPEGs, metadata here.
    One timeline per job.
    """

    id: UUID
    source_asset_id: Optional[UUID] = None
    source_video: str = ""
    profile_name: str = ""
    fps: float = 2.0

    frames_prefix: str = ""  # s3: timeline/{id}/frames/
    # seq → s3 key
    # NOTE(review): create_timeline writes this with str(seq) keys and readers
    # convert back with int(k); confirm the Dict[int, str] annotation is intended.
    frames_manifest: Dict[int, str] = field(default_factory=dict)
    frames_meta: List[Dict[str, Any]] = field(default_factory=list)

    created_at: Optional[datetime] = None

View File

@@ -1,8 +1,6 @@
""" """
Serializers for detection pipeline runtime models. Serializers for detection pipeline runtime models.
Mirrors core/schema/models/detect_pipeline.py.
Special handling: Special handling:
- Frame.image (np.ndarray S3, excluded from JSON) - Frame.image (np.ndarray S3, excluded from JSON)
- TextCandidate.frame (object ref frame_sequence integer) - TextCandidate.frame (object ref frame_sequence integer)
@@ -13,7 +11,7 @@ from __future__ import annotations
import dataclasses import dataclasses
from core.schema.models.detect_pipeline import ( from core.schema.models.pipeline import (
BoundingBox, BoundingBox,
BrandDetection, BrandDetection,
BrandStats, BrandStats,
@@ -59,13 +57,12 @@ def deserialize_frames_with_download(meta: list[dict], manifest: dict, job_id: s
def serialize_text_candidate(tc: TextCandidate) -> dict: def serialize_text_candidate(tc: TextCandidate) -> dict:
bbox_dict = dataclasses.asdict(tc.bbox) bbox_dict = dataclasses.asdict(tc.bbox)
result = { return {
"frame_sequence": tc.frame.sequence, "frame_sequence": tc.frame.sequence,
"bbox": bbox_dict, "bbox": bbox_dict,
"text": tc.text, "text": tc.text,
"ocr_confidence": tc.ocr_confidence, "ocr_confidence": tc.ocr_confidence,
} }
return result
def serialize_text_candidates(candidates: list[TextCandidate]) -> list[dict]: def serialize_text_candidates(candidates: list[TextCandidate]) -> list[dict]:
@@ -75,13 +72,12 @@ def serialize_text_candidates(candidates: list[TextCandidate]) -> list[dict]:
def deserialize_text_candidate(data: dict, frame_map: dict[int, Frame]) -> TextCandidate: def deserialize_text_candidate(data: dict, frame_map: dict[int, Frame]) -> TextCandidate:
frame = frame_map[data["frame_sequence"]] frame = frame_map[data["frame_sequence"]]
bbox = safe_construct(BoundingBox, data["bbox"]) bbox = safe_construct(BoundingBox, data["bbox"])
candidate = TextCandidate( return TextCandidate(
frame=frame, frame=frame,
bbox=bbox, bbox=bbox,
text=data["text"], text=data["text"],
ocr_confidence=data["ocr_confidence"], ocr_confidence=data["ocr_confidence"],
) )
return candidate
def deserialize_text_candidates(data: list[dict], frame_map: dict[int, Frame]) -> list[TextCandidate]: def deserialize_text_candidates(data: list[dict], frame_map: dict[int, Frame]) -> list[TextCandidate]:

View File

@@ -11,7 +11,7 @@ that don't belong to any stage.
from __future__ import annotations from __future__ import annotations
from core.schema.serializers._common import serialize_dataclass from core.schema.serializers._common import serialize_dataclass
from core.schema.serializers.detect_pipeline import ( from core.schema.serializers.pipeline import (
deserialize_pipeline_stats, deserialize_pipeline_stats,
deserialize_text_candidates, deserialize_text_candidates,
) )

View File

@@ -33,22 +33,23 @@ def create_timeline(
Returns (timeline_id, checkpoint_id). Returns (timeline_id, checkpoint_id).
""" """
from core.db.detect import create_timeline as db_create_timeline from core.db.tables import Timeline, Checkpoint
from core.db.detect import save_checkpoint from core.db.connection import get_session
# Create timeline with get_session() as session:
timeline = db_create_timeline( timeline = Timeline(
source_video=source_video, source_video=source_video,
profile_name=profile_name, profile_name=profile_name,
source_asset_id=source_asset_id, source_asset_id=source_asset_id,
fps=fps, fps=fps,
) )
session.add(timeline)
session.flush()
tid = str(timeline.id) tid = str(timeline.id)
# Upload frames to MinIO # Upload frames to MinIO
manifest = save_frames(tid, frames) manifest = save_frames(tid, frames)
# Store frame metadata on the timeline
frames_meta = [ frames_meta = [
{ {
"sequence": f.sequence, "sequence": f.sequence,
@@ -63,53 +64,50 @@ def create_timeline(
timeline.frames_manifest = {str(k): v for k, v in manifest.items()} timeline.frames_manifest = {str(k): v for k, v in manifest.items()}
timeline.frames_meta = frames_meta timeline.frames_meta = frames_meta
from core.db.connection import get_session checkpoint = Checkpoint(
with get_session() as session:
session.add(timeline)
session.commit()
# Create root checkpoint (no parent, no stage outputs yet)
checkpoint = save_checkpoint(
timeline_id=timeline.id, timeline_id=timeline.id,
parent_id=None, parent_id=None,
stage_outputs={}, stage_outputs={},
stats={"frames_extracted": len(frames)}, stats={"frames_extracted": len(frames)},
) )
session.add(checkpoint)
session.commit()
session.refresh(checkpoint)
cid = str(checkpoint.id)
logger.info("Timeline created: %s (%d frames, root checkpoint %s)", logger.info("Timeline created: %s (%d frames, root checkpoint %s)", tid, len(frames), cid)
tid, len(frames), checkpoint.id) return tid, cid
return tid, str(checkpoint.id)
def get_timeline_frames(timeline_id: str) -> list: def get_timeline_frames(timeline_id: str) -> list:
"""Load frames from a timeline (from MinIO) as Frame objects.""" """Load frames from a timeline (from MinIO) as Frame objects."""
from core.db.detect import get_timeline from core.db.tables import Timeline
from core.db.connection import get_session
timeline = get_timeline(timeline_id) with get_session() as session:
timeline = session.get(Timeline, UUID(timeline_id))
if not timeline: if not timeline:
raise ValueError(f"Timeline not found: {timeline_id}") raise ValueError(f"Timeline not found: {timeline_id}")
raw_manifest = timeline.frames_manifest or {} raw_manifest = timeline.frames_manifest or {}
manifest = {int(k): v for k, v in raw_manifest.items()} manifest = {int(k): v for k, v in raw_manifest.items()}
frame_metadata = timeline.frames_meta or [] return load_frames(manifest, timeline.frames_meta or [])
return load_frames(manifest, frame_metadata)
def get_timeline_frames_b64(timeline_id: str) -> list[dict]: def get_timeline_frames_b64(timeline_id: str) -> list[dict]:
"""Load frames as base64 JPEG (lightweight, no numpy).""" """Load frames as base64 JPEG (lightweight, no numpy)."""
from core.db.detect import get_timeline from core.db.tables import Timeline
from core.db.connection import get_session
from .frames import load_frames_b64 from .frames import load_frames_b64
timeline = get_timeline(timeline_id) with get_session() as session:
timeline = session.get(Timeline, UUID(timeline_id))
if not timeline: if not timeline:
raise ValueError(f"Timeline not found: {timeline_id}") raise ValueError(f"Timeline not found: {timeline_id}")
raw_manifest = timeline.frames_manifest or {} raw_manifest = timeline.frames_manifest or {}
manifest = {int(k): v for k, v in raw_manifest.items()} manifest = {int(k): v for k, v in raw_manifest.items()}
frame_metadata = timeline.frames_meta or [] return load_frames_b64(manifest, timeline.frames_meta or [])
return load_frames_b64(manifest, frame_metadata)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -132,47 +130,46 @@ def save_stage_output(
Carries forward stage outputs from parent + adds the new one. Carries forward stage outputs from parent + adds the new one.
Returns the new checkpoint ID. Returns the new checkpoint ID.
""" """
from core.db.detect import get_checkpoint, save_checkpoint from core.db.tables import Checkpoint
from core.db.connection import get_session
# Carry forward from parent with get_session() as session:
parent_outputs = {} parent_outputs = {}
parent_stats = {} parent_stats = {}
parent_config = {} parent_config = {}
if parent_checkpoint_id: if parent_checkpoint_id:
parent = get_checkpoint(parent_checkpoint_id) parent = session.get(Checkpoint, UUID(parent_checkpoint_id))
if parent: if parent:
parent_outputs = dict(parent.stage_outputs or {}) parent_outputs = dict(parent.stage_outputs or {})
parent_stats = dict(parent.stats or {}) parent_stats = dict(parent.stats or {})
parent_config = dict(parent.config_overrides or {}) parent_config = dict(parent.config_overrides or {})
# Add new stage output checkpoint = Checkpoint(
stage_outputs = {**parent_outputs, stage_name: output_json} timeline_id=UUID(timeline_id),
parent_id=UUID(parent_checkpoint_id) if parent_checkpoint_id else None,
# Merge stats and config stage_outputs={**parent_outputs, stage_name: output_json},
merged_stats = {**parent_stats, **(stats or {})} config_overrides={**parent_config, **(config_overrides or {})},
merged_config = {**parent_config, **(config_overrides or {})} stats={**parent_stats, **(stats or {})},
checkpoint = save_checkpoint(
timeline_id=timeline_id,
parent_id=parent_checkpoint_id,
stage_outputs=stage_outputs,
config_overrides=merged_config,
stats=merged_stats,
is_scenario=is_scenario, is_scenario=is_scenario,
scenario_label=scenario_label, scenario_label=scenario_label,
) )
session.add(checkpoint)
session.commit()
session.refresh(checkpoint)
cid = str(checkpoint.id)
logger.info("Checkpoint saved: %s (timeline %s, stage %s, parent %s)", logger.info("Checkpoint saved: %s (timeline %s, stage %s, parent %s)",
checkpoint.id, timeline_id, stage_name, parent_checkpoint_id) cid, timeline_id, stage_name, parent_checkpoint_id)
return str(checkpoint.id) return cid
def load_stage_output(checkpoint_id: str, stage_name: str) -> dict | None: def load_stage_output(checkpoint_id: str, stage_name: str) -> dict | None:
"""Load a stage's output from a checkpoint.""" """Load a stage's output from a checkpoint."""
from core.db.detect import get_checkpoint from core.db.tables import Checkpoint
from core.db.connection import get_session
checkpoint = get_checkpoint(checkpoint_id) with get_session() as session:
checkpoint = session.get(Checkpoint, UUID(checkpoint_id))
if not checkpoint: if not checkpoint:
return None return None
return (checkpoint.stage_outputs or {}).get(stage_name) return (checkpoint.stage_outputs or {}).get(stage_name)

View File

@@ -326,6 +326,7 @@ def node_compile_report(state: DetectState) -> dict:
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1" _CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
_frames_manifest: dict[str, dict[int, str]] = {} # job_id → manifest (cached per job) _frames_manifest: dict[str, dict[int, str]] = {} # job_id → manifest (cached per job)
_latest_checkpoint: dict[str, str] = {} # job_id → latest checkpoint_id
class PipelineCancelled(Exception): class PipelineCancelled(Exception):
@@ -361,17 +362,33 @@ def _checkpointing_node(node_name: str, node_fn):
if not job_id: if not job_id:
return result return result
from detect.checkpoint import save_checkpoint, save_frames from detect.checkpoint import save_stage_output, save_frames
from detect.stages.base import _REGISTRY
merged = {**state, **result} merged = {**state, **result}
# Save frames once (first checkpoint), reuse manifest after # Save frames once (first node), reuse manifest after
manifest = _frames_manifest.get(job_id) manifest = _frames_manifest.get(job_id)
if manifest is None and node_name == "extract_frames": if manifest is None and node_name == "extract_frames":
manifest = save_frames(job_id, merged.get("frames", [])) manifest = save_frames(job_id, merged.get("frames", []))
_frames_manifest[job_id] = manifest _frames_manifest[job_id] = manifest
save_checkpoint(job_id, node_name, stage_index, merged, frames_manifest=manifest) # Serialize stage output using the stage's serialize_fn if available
stage_cls = _REGISTRY.get(node_name)
serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
if serialize_fn:
output_json = serialize_fn(merged, job_id)
else:
output_json = {}
parent_id = _latest_checkpoint.get(job_id)
new_checkpoint_id = save_stage_output(
timeline_id=job_id,
parent_checkpoint_id=parent_id,
stage_name=node_name,
output_json=output_json,
)
_latest_checkpoint[job_id] = new_checkpoint_id
return result return result
wrapper.__name__ = node_fn.__name__ wrapper.__name__ = node_fn.__name__

View File

@@ -1,11 +1,6 @@
""" """Re-export pipeline runtime models from core/schema/models/pipeline.py."""
Re-export pipeline runtime models from core/schema/models/detect_pipeline.py.
All models are defined in core/schema/ — this module exists for backward from core.schema.models.pipeline import (
compatibility so existing imports (from detect.models import Frame) keep working.
"""
from core.schema.models.detect_pipeline import (
BoundingBox, BoundingBox,
BrandDetection, BrandDetection,
BrandStats, BrandStats,

View File

@@ -4,14 +4,10 @@ Stage 5 — Brand Resolver (discovery mode)
Discovery-first brand matching. No static dictionary — all brands live in the DB. Discovery-first brand matching. No static dictionary — all brands live in the DB.
Flow: Flow:
1. Check session sightings first (brands already seen in this source) 1. Check session brands first (brands already seen in this run, in-memory)
2. Check global known brands (accumulated across all runs) 2. Check global known brands (accumulated across all runs)
3. Unresolved candidates → escalate to VLM/cloud 3. Unresolved candidates → escalate to VLM/cloud
4. Confirmed brands get added to DB for future runs 4. Confirmed brands get added to DB for future runs
The resolver is an enricher, not a gatekeeper. Every OCR text candidate
passes through — the question is whether we can resolve it cheaply (DB lookup)
or need to escalate (VLM/cloud).
""" """
from __future__ import annotations from __future__ import annotations
@@ -33,41 +29,30 @@ def _normalize(text: str) -> str:
def _has_db() -> bool: def _has_db() -> bool:
try: try:
from core.db.detect import find_brand_by_text as _ from core.db import find_brand_by_text as _
from admin.mpr.media_assets.models import KnownBrand as _
return True return True
except (ImportError, Exception): except (ImportError, Exception):
return False return False
def _match_session(text: str, session_brands: dict[str, str]) -> str | None:
    """
    Look up text among the session brands.

    session_brands: {normalized_name: canonical_name, ...}
    Returns the canonical name, or None when the normalized text is not a key.
    """
    key = _normalize(text)
    return session_brands.get(key)
def _match_known(text: str, threshold: int) -> tuple[str | None, str | None]: def _match_known(text: str, threshold: int) -> tuple[str | None, str | None]:
""" """Check against global known brands in DB. Returns (canonical_name, brand_id) or (None, None)."""
Check against global known brands in DB.
Returns (canonical_name, brand_id) or (None, None).
"""
if not _has_db(): if not _has_db():
return None, None return None, None
from core.db.detect import find_brand_by_text from core.db import find_brand_by_text, list_brands
brand = find_brand_by_text(text) from core.db.connection import get_session
with get_session() as session:
brand = find_brand_by_text(session, text)
if brand: if brand:
return brand.canonical_name, str(brand.id) return brand.canonical_name, str(brand.id)
# Fuzzy match against all known brands all_brands = list_brands(session)
from core.db.detect import list_all_brands
all_brands = list_all_brands()
normalized = _normalize(text) normalized = _normalize(text)
best_brand = None best_brand = None
@@ -92,58 +77,62 @@ def _register_brand(canonical_name: str, source: str) -> str | None:
if not _has_db(): if not _has_db():
return None return None
from core.db.detect import get_or_create_brand from core.db import get_or_create_brand
brand, created = get_or_create_brand(canonical_name, source=source) from core.db.connection import get_session
with get_session() as session:
brand, created = get_or_create_brand(session, canonical_name, source=source)
session.commit()
if created: if created:
logger.info("New brand discovered: %s (source=%s)", canonical_name, source) logger.info("New brand discovered: %s (source=%s)", canonical_name, source)
return str(brand.id) return str(brand.id)
def _record_airing(timeline_id: str | None, brand_id: str,
                   frame_seq: int, confidence: float, source: str) -> None:
    """
    Record a brand airing on a timeline.

    Does nothing when the DB layer is unavailable or timeline_id is empty.
    The airing covers a single frame: frame_seq is passed as both
    frame_start and frame_end.

    timeline_id and brand_id are normally UUID strings; values that are
    already uuid.UUID instances are accepted as-is.

    Raises ValueError if an ID string is not a well-formed UUID.
    """
    if not _has_db() or not timeline_id:
        return

    from uuid import UUID

    from core.db import record_airing
    from core.db.connection import get_session

    def _as_uuid(value):
        # UUID() only parses strings; calling it on a UUID instance raises
        # AttributeError, so pass existing UUIDs through untouched.
        return value if isinstance(value, UUID) else UUID(value)

    # Convert before opening a session so a malformed ID fails early.
    brand_uuid = _as_uuid(brand_id)
    timeline_uuid = _as_uuid(timeline_id)

    with get_session() as session:
        record_airing(
            session,
            brand_id=brand_uuid,
            timeline_id=timeline_uuid,
            frame_start=frame_seq,
            frame_end=frame_seq,
            confidence=confidence,
            source=source,
        )
        session.commit()
def build_session_dict(source_asset_id: str | None = None) -> dict[str, str]:
    """
    Build the session lookup dict from the known brands in the DB.

    Returns {normalized_name: canonical_name, ...}. Each brand contributes
    its canonical name and every alias. Returns an empty dict when the DB
    layer is unavailable.

    source_asset_id is not used by this implementation.
    """
    if not _has_db():
        return {}

    from core.db import list_brands
    from core.db.connection import get_session

    lookup: dict[str, str] = {}
    with get_session() as session:
        for brand in list_brands(session):
            canonical = brand.canonical_name
            lookup[_normalize(canonical)] = canonical
            # aliases may be unset on a brand; treat that as no aliases.
            for alias in brand.aliases or ():
                lookup[_normalize(alias)] = canonical
    return lookup
def resolve_brands( def resolve_brands(
@@ -158,7 +147,7 @@ def resolve_brands(
Match text candidates against known brands (session → global → unresolved). Match text candidates against known brands (session → global → unresolved).
session_brands: pre-loaded session dict (from build_session_dict) session_brands: pre-loaded session dict (from build_session_dict)
source_asset_id: for recording new sightings in DB job_id: timeline_id — used to record airings
""" """
if session_brands is None: if session_brands is None:
session_brands = {} session_brands = {}
@@ -187,7 +176,6 @@ def resolve_brands(
brand_name, brand_id = _match_known(text, config.fuzzy_threshold) brand_name, brand_id = _match_known(text, config.fuzzy_threshold)
if brand_name: if brand_name:
known_hits += 1 known_hits += 1
# Add to session for subsequent candidates in this run
session_brands[_normalize(brand_name)] = brand_name session_brands[_normalize(brand_name)] = brand_name
if brand_name: if brand_name:
@@ -203,11 +191,10 @@ def resolve_brands(
) )
matched.append(detection) matched.append(detection)
# Record sighting in DB
if brand_id: if brand_id:
_record_sighting( _record_airing(
source_asset_id, brand_id, brand_name, job_id, brand_id,
candidate.frame.timestamp, candidate.ocr_confidence, match_source, candidate.frame.sequence, candidate.ocr_confidence, match_source,
) )
emit.detection( emit.detection(

View File

@@ -10,7 +10,7 @@ from core.schema.serializers._common import (
serialize_dataclass, serialize_dataclass,
serialize_dataclass_list, serialize_dataclass_list,
) )
from core.schema.serializers.detect_pipeline import ( from core.schema.serializers.pipeline import (
serialize_frame_meta, serialize_frame_meta,
serialize_frames_with_upload as serialize_frames, serialize_frames_with_upload as serialize_frames,
deserialize_frames_with_download as deserialize_frames, deserialize_frames_with_download as deserialize_frames,

View File

@@ -35,9 +35,11 @@ logger = logging.getLogger(__name__)
def main(): def main():
from core.db.detect import list_scenarios from core.db import list_scenarios
from core.db.connection import get_session
scenarios = list_scenarios() with get_session() as session:
scenarios = list_scenarios(session)
if not scenarios: if not scenarios:
logger.info("No scenarios found. Create one with:") logger.info("No scenarios found. Create one with:")

View File

@@ -121,15 +121,14 @@ def main():
) )
# Mark as scenario # Mark as scenario
from core.db.detect import get_latest_checkpoint from core.db import get_latest_checkpoint
from core.db.connection import get_session from core.db.connection import get_session
checkpoint = get_latest_checkpoint(branch_id) with get_session() as session:
checkpoint = get_latest_checkpoint(session, branch_id)
if checkpoint: if checkpoint:
checkpoint.is_scenario = True checkpoint.is_scenario = True
checkpoint.scenario_label = args.label checkpoint.scenario_label = args.label
with get_session() as session:
session.add(checkpoint)
session.commit() session.commit()
logger.info("") logger.info("")

View File

@@ -7,7 +7,7 @@ import pytest
from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate
from core.schema.serializers._common import safe_construct from core.schema.serializers._common import safe_construct
from core.schema.serializers.detect_pipeline import ( from core.schema.serializers.pipeline import (
serialize_frame_meta, serialize_frame_meta,
serialize_text_candidate, serialize_text_candidate,
serialize_text_candidates, serialize_text_candidates,