major refactor

This commit is contained in:
2026-03-27 06:02:58 -03:00
parent bcf6f3dc71
commit 51ce14a812
18 changed files with 351 additions and 523 deletions

View File

@@ -1,7 +1,7 @@
"""
API endpoints for checkpoint inspection, replay, retry, and GPU proxy.
GET /detect/checkpoints/{job_id} — list available checkpoints
GET /detect/checkpoints/{timeline_id} — list available checkpoints
POST /detect/replay — replay from a stage with config overrides
POST /detect/retry — queue async retry with different provider
POST /detect/replay-stage — replay single stage (fast path)
@@ -31,7 +31,7 @@ class CheckpointInfo(BaseModel):
class ScenarioInfo(BaseModel):
job_id: str
timeline_id: str
stage: str
scenario_label: str
profile_name: str
@@ -41,21 +41,21 @@ class ScenarioInfo(BaseModel):
class ReplayRequest(BaseModel):
job_id: str
timeline_id: str
start_stage: str
config_overrides: dict | None = None
class ReplayResponse(BaseModel):
status: str
job_id: str
timeline_id: str
start_stage: str
detections: int = 0
brands_found: int = 0
class RetryRequest(BaseModel):
job_id: str
timeline_id: str
config_overrides: dict | None = None
start_stage: str = "escalate_vlm"
schedule_seconds: float | None = None # delay before execution (off-peak)
@@ -64,11 +64,11 @@ class RetryRequest(BaseModel):
class RetryResponse(BaseModel):
status: str
task_id: str
job_id: str
timeline_id: str
class ReplaySingleStageRequest(BaseModel):
job_id: str
timeline_id: str
stage: str
frame_refs: list[int] | None = None
config_overrides: dict | None = None
@@ -102,15 +102,15 @@ class ReplaySingleStageResponse(BaseModel):
# --- Endpoints ---
@router.get("/checkpoints/{job_id}")
def list_checkpoints(job_id: str) -> list[CheckpointInfo]:
@router.get("/checkpoints/{timeline_id}")
def list_checkpoints(timeline_id: str) -> list[CheckpointInfo]:
"""List available checkpoint stages for a job."""
from detect.checkpoint import list_checkpoints as _list
try:
stages = _list(job_id)
stages = _list(timeline_id)
except Exception as e:
raise HTTPException(status_code=404, detail=f"No checkpoints for job {job_id}: {e}")
raise HTTPException(status_code=404, detail=f"No checkpoints for job {timeline_id}: {e}")
result = [CheckpointInfo(stage=s) for s in stages]
return result
@@ -123,7 +123,7 @@ class CheckpointFrameInfo(BaseModel):
class CheckpointData(BaseModel):
job_id: str
timeline_id: str
stage: str
profile_name: str
video_path: str
@@ -135,26 +135,32 @@ class CheckpointData(BaseModel):
stage_output_key: str = ""
@router.get("/checkpoints/{job_id}/{stage}", response_model=CheckpointData)
def get_checkpoint_data(job_id: str, stage: str):
@router.get("/checkpoints/{timeline_id}/{stage}", response_model=CheckpointData)
def get_checkpoint_data(timeline_id: str, stage: str):
"""Load checkpoint frames + metadata for the editor UI."""
from core.db.detect import get_stage_checkpoint
from uuid import UUID
from core.db.tables import Timeline, Checkpoint
from core.db.connection import get_session
from core.db.checkpoint import list_checkpoints
from detect.checkpoint.frames import load_frames_b64
checkpoint = get_stage_checkpoint(job_id, stage)
if not checkpoint:
raise HTTPException(status_code=404, detail=f"No checkpoint for {job_id}/{stage}")
with get_session() as session:
timeline = session.get(Timeline, UUID(timeline_id))
if not timeline:
raise HTTPException(status_code=404, detail=f"Timeline not found: {timeline_id}")
raw_manifest = checkpoint.frames_manifest or {}
manifest = {int(k): v for k, v in raw_manifest.items()}
frame_metadata = checkpoint.frames_meta or []
checkpoints = list_checkpoints(session, UUID(timeline_id))
if not checkpoints:
raise HTTPException(status_code=404, detail=f"No checkpoints for timeline {timeline_id}")
# Prefer a checkpoint that has this stage's output; fall back to latest
checkpoint = next(
(c for c in reversed(checkpoints) if stage in (c.stage_outputs or {})),
checkpoints[-1],
)
# Only load filtered frames if available, otherwise all
filtered = set(checkpoint.filtered_frame_sequences or [])
if filtered:
manifest = {k: v for k, v in manifest.items() if k in filtered}
frames_b64 = load_frames_b64(manifest, frame_metadata)
raw_manifest = timeline.frames_manifest or {}
manifest = {int(k): v for k, v in raw_manifest.items()}
frames_b64 = load_frames_b64(manifest, timeline.frames_meta or [])
frame_list = [
CheckpointFrameInfo(seq=f["seq"], timestamp=f["timestamp"], jpeg_b64=f["jpeg_b64"])
@@ -162,38 +168,44 @@ def get_checkpoint_data(job_id: str, stage: str):
]
return CheckpointData(
job_id=str(checkpoint.job_id),
stage=checkpoint.stage,
profile_name=checkpoint.profile_name,
video_path=checkpoint.video_path,
timeline_id=timeline_id,
stage=stage,
profile_name=timeline.profile_name,
video_path=timeline.source_video,
is_scenario=checkpoint.is_scenario,
scenario_label=checkpoint.scenario_label,
frames=frame_list,
stats=checkpoint.stats or {},
config_snapshot=checkpoint.config_snapshot or {},
stage_output_key=checkpoint.stage_output_key or "",
config_snapshot=checkpoint.config_overrides or {},
stage_output_key=stage,
)
@router.get("/scenarios", response_model=list[ScenarioInfo])
def list_scenarios_endpoint():
"""List all available scenarios (bookmarked checkpoints)."""
from core.db.detect import list_scenarios
from core.db.tables import Timeline
from core.db.connection import get_session
from core.db.checkpoint import list_scenarios
scenarios = list_scenarios()
result = []
for s in scenarios:
manifest = s.frames_manifest or {}
info = ScenarioInfo(
job_id=str(s.job_id),
stage=s.stage,
scenario_label=s.scenario_label,
profile_name=s.profile_name,
video_path=s.video_path,
frame_count=len(manifest),
created_at=str(s.created_at) if s.created_at else "",
)
result.append(info)
with get_session() as session:
scenarios = list_scenarios(session)
result = []
for s in scenarios:
timeline = session.get(Timeline, s.timeline_id)
if not timeline:
continue
last_stage = next(reversed(s.stage_outputs), "") if s.stage_outputs else ""
info = ScenarioInfo(
timeline_id=str(s.timeline_id),
stage=last_stage,
scenario_label=s.scenario_label,
profile_name=timeline.profile_name,
video_path=timeline.source_video,
frame_count=len(timeline.frames_manifest or {}),
created_at=str(s.created_at) if s.created_at else "",
)
result.append(info)
return result
@@ -204,7 +216,7 @@ def replay(req: ReplayRequest):
try:
result = replay_from(
job_id=req.job_id,
timeline_id=req.timeline_id,
start_stage=req.start_stage,
config_overrides=req.config_overrides,
)
@@ -219,7 +231,7 @@ def replay(req: ReplayRequest):
response = ReplayResponse(
status="completed",
job_id=req.job_id,
timeline_id=req.timeline_id,
start_stage=req.start_stage,
detections=len(detections),
brands_found=brands_found,
@@ -233,7 +245,7 @@ def retry(req: RetryRequest):
from detect.checkpoint.tasks import retry_candidates
kwargs = {
"job_id": req.job_id,
"timeline_id": req.timeline_id,
"config_overrides": req.config_overrides,
"start_stage": req.start_stage,
}
@@ -246,7 +258,7 @@ def retry(req: RetryRequest):
response = RetryResponse(
status="queued",
task_id=task.id,
job_id=req.job_id,
timeline_id=req.timeline_id,
)
return response
@@ -258,7 +270,7 @@ def replay_single_stage(req: ReplaySingleStageRequest):
try:
result = _replay(
job_id=req.job_id,
timeline_id=req.timeline_id,
stage=req.stage,
frame_refs=req.frame_refs,
config_overrides=req.config_overrides,

View File

@@ -25,11 +25,10 @@ from .grpc import (
ProgressUpdate,
WorkerStatus,
)
from .job import (
Job, JobStatus, RunType,
Timeline, Checkpoint,
BrandSource, Brand,
)
from .job import Job, JobStatus, RunType
from .timeline import Timeline
from .checkpoint import Checkpoint
from .brand import BrandSource, Brand
from .media import AssetStatus, MediaAsset
from .presets import BUILTIN_PRESETS, TranscodePreset
from .detect import DETECT_VIEWS # noqa: F401 — discovered by modelgen generic loader

View File

@@ -0,0 +1,38 @@
"""Brand schema — source of truth for brand discovery."""
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID
class BrandSource(str, Enum):
    """How a brand was identified.

    Mixes in ``str`` so each member compares equal to its raw string value.
    """

    OCR = "ocr"
    VLM = "local_vlm"
    CLOUD = "cloud_llm"
    MANUAL = "manual"
@dataclass
class Brand:
    """
    A brand discovered or registered in the system.

    Airings track where/when the brand appeared — each airing
    references a timeline and a frame range.
    """

    id: UUID
    canonical_name: str
    # Alternative names that map to canonical_name.
    aliases: List[str] = field(default_factory=list)
    source: BrandSource = BrandSource.OCR  # how first discovered
    confirmed: bool = False

    # Airings — JSONB array of appearances
    # [{timeline_id, frame_start, frame_end, confidence, source, timestamp}]
    airings: List[Dict[str, Any]] = field(default_factory=list)
    # Stored separately from `airings`; nothing in this class keeps the two
    # in sync — NOTE(review): confirm the DB layer updates both together.
    total_airings: int = 0

    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

View File

@@ -0,0 +1,38 @@
"""Checkpoint schema — source of truth for pipeline state snapshots."""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, Optional
from uuid import UUID
@dataclass
class Checkpoint:
    """
    A snapshot of pipeline state on a timeline.

    Stage outputs stored as JSONB — each stage serializes to JSON,
    the checkpoint stores it without knowing the shape.

    parent_id forms a tree: multiple children from the same parent
    = different config tries from the same starting point.
    """

    id: UUID
    timeline_id: UUID
    parent_id: Optional[UUID] = None  # null = root checkpoint

    # Stage outputs — JSONB per stage, opaque to the checkpoint layer
    stage_outputs: Dict[str, Any] = field(default_factory=dict)

    # Config that produced this checkpoint
    config_overrides: Dict[str, Any] = field(default_factory=dict)

    # Pipeline state
    stats: Dict[str, Any] = field(default_factory=dict)

    # Scenario bookmark
    is_scenario: bool = False
    scenario_label: str = ""

    created_at: Optional[datetime] = None

View File

@@ -1,177 +0,0 @@
"""
Detection Job and Checkpoint Schema Definitions
Source of truth for detection pipeline job tracking and stage checkpoints.
Follows the TranscodeJob/ChunkJob pattern.
"""
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID
class DetectJobStatus(str, Enum):
PENDING = "pending"
RUNNING = "running"
PAUSED = "paused"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class RunType(str, Enum):
INITIAL = "initial"
REPLAY = "replay"
RETRY = "retry"
@dataclass
class DetectJob:
"""
A detection pipeline job.
Each invocation of the pipeline (initial run, replay, retry) creates a DetectJob.
Jobs for the same source video are linked via parent_job_id.
"""
id: UUID
# Input
source_asset_id: UUID
video_path: str
profile_name: str = "soccer_broadcast"
# Run lineage
parent_job_id: Optional[UUID] = None # links all runs for the same source
run_type: RunType = RunType.INITIAL
replay_from_stage: Optional[str] = None # null for initial runs
config_overrides: Dict[str, Any] = field(default_factory=dict)
# Status
status: DetectJobStatus = DetectJobStatus.PENDING
current_stage: Optional[str] = None
progress: float = 0.0
error_message: Optional[str] = None
# Results summary
total_detections: int = 0
brands_found: int = 0
cloud_llm_calls: int = 0
estimated_cost_usd: float = 0.0
# Worker tracking
celery_task_id: Optional[str] = None
priority: int = 0
# Timestamps
created_at: Optional[datetime] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
@dataclass
class Timeline:
"""
The frame sequence from a source video.
Independent of stages — exists before any stage runs.
Stages annotate the timeline, they don't own it.
Frames are stored in MinIO as JPEGs.
"""
id: UUID
source_asset_id: Optional[UUID] = None
source_video: str = ""
profile_name: str = ""
fps: float = 2.0
# Frame metadata (images in MinIO, metadata here)
frames_prefix: str = "" # s3: timelines/{id}/frames/
frames_manifest: Dict[int, str] = field(default_factory=dict) # seq → s3 key
frames_meta: List[Dict[str, Any]] = field(default_factory=list)
created_at: Optional[datetime] = None
@dataclass
class Checkpoint:
"""
A snapshot of pipeline state on a timeline.
Stage outputs stored as JSONB — each stage serializes to JSON,
the checkpoint stores it without knowing the shape.
parent_id forms a tree: multiple children from the same parent
= different config tries from the same starting point.
"""
id: UUID
timeline_id: UUID
parent_id: Optional[UUID] = None # null = root checkpoint
# Stage outputs — JSONB per stage, opaque to the checkpoint layer
stage_outputs: Dict[str, Any] = field(default_factory=dict)
# Config that produced this checkpoint
config_overrides: Dict[str, Any] = field(default_factory=dict)
# Pipeline state
stats: Dict[str, Any] = field(default_factory=dict)
# Scenario bookmark
is_scenario: bool = False
scenario_label: str = ""
created_at: Optional[datetime] = None
class BrandSource(str, Enum):
"""How a brand was first identified."""
OCR = "ocr"
VLM = "local_vlm"
CLOUD = "cloud_llm"
MANUAL = "manual" # user-added via UI
@dataclass
class KnownBrand:
"""
A brand discovered or registered in the system.
Global — not per-source. Accumulates across all pipeline runs.
Aliases enable fuzzy matching without re-escalating to VLM.
"""
id: UUID
canonical_name: str # normalized display name
aliases: List[str] = field(default_factory=list) # known spellings/variants
first_source: BrandSource = BrandSource.OCR
total_occurrences: int = 0
confirmed: bool = False # manually confirmed by user
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
@dataclass
class SourceBrandSighting:
"""
A brand seen in a specific source (video/asset).
Per-source session cache — avoids re-escalating the same brand
on subsequent frames or re-runs of the same source.
"""
id: UUID
source_asset_id: UUID # the video this sighting belongs to
brand_id: UUID # FK to KnownBrand
brand_name: str # denormalized for fast lookup
first_seen_timestamp: float = 0.0
last_seen_timestamp: float = 0.0
occurrences: int = 0
detection_source: BrandSource = BrandSource.OCR
avg_confidence: float = 0.0
created_at: Optional[datetime] = None

View File

@@ -1,14 +1,9 @@
"""
Job, Timeline, and Checkpoint Schema Definitions
Source of truth for pipeline jobs, timelines, and checkpoints.
Generates: SQLModel (core/db/models.py), TypeScript via modelgen.
"""
"""Job schema — source of truth for pipeline jobs."""
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
from uuid import UUID
@@ -68,91 +63,3 @@ class Job:
created_at: Optional[datetime] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
@dataclass
class Timeline:
"""
The frame sequence from a source video.
Independent of stages — exists before any stage runs.
Frames stored in MinIO as JPEGs, metadata here.
One timeline per job.
"""
id: UUID
source_asset_id: Optional[UUID] = None
source_video: str = ""
profile_name: str = ""
fps: float = 2.0
frames_prefix: str = "" # s3: timeline/{id}/frames/
frames_manifest: Dict[int, str] = field(default_factory=dict) # seq → s3 key
frames_meta: List[Dict[str, Any]] = field(default_factory=list)
created_at: Optional[datetime] = None
@dataclass
class Checkpoint:
"""
A snapshot of pipeline state on a timeline.
Stage outputs stored as JSONB — each stage serializes to JSON,
the checkpoint stores it without knowing the shape.
parent_id forms a tree: multiple children from the same parent
= different config tries from the same starting point.
"""
id: UUID
timeline_id: UUID
parent_id: Optional[UUID] = None # null = root checkpoint
# Stage outputs — JSONB per stage, opaque to the checkpoint layer
stage_outputs: Dict[str, Any] = field(default_factory=dict)
# Config that produced this checkpoint
config_overrides: Dict[str, Any] = field(default_factory=dict)
# Pipeline state
stats: Dict[str, Any] = field(default_factory=dict)
# Scenario bookmark
is_scenario: bool = False
scenario_label: str = ""
created_at: Optional[datetime] = None
# --- Brands ---
class BrandSource(str, Enum):
OCR = "ocr"
VLM = "local_vlm"
CLOUD = "cloud_llm"
MANUAL = "manual"
@dataclass
class Brand:
"""
A brand discovered or registered in the system.
Airings track where/when the brand appeared — each airing
references a timeline and a frame range.
"""
id: UUID
canonical_name: str
aliases: List[str] = field(default_factory=list)
source: BrandSource = BrandSource.OCR # how first discovered
confirmed: bool = False
# Airings — JSONB array of appearances
# [{timeline_id, frame_start, frame_end, confidence, source, timestamp}]
airings: List[Dict[str, Any]] = field(default_factory=list)
total_airings: int = 0
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None

View File

@@ -1,13 +1,9 @@
"""
Detection pipeline runtime models.
These are the data structures that flow between LangGraph nodes.
They contain runtime types (np.ndarray) so they are NOT generated
by modelgen they live here for the schema to be the complete
map of the application, but modelgen skips them.
Wire-format models (SSE events) are in detect.py.
DB models (jobs, checkpoints) are in detect_jobs.py.
These are the data structures that flow between pipeline stages.
They contain runtime types (np.ndarray), so modelgen skips them —
they are not generated to SQLModel or TypeScript.
"""
from __future__ import annotations
@@ -89,10 +85,3 @@ class DetectionReport:
brands: dict[str, BrandStats] = field(default_factory=dict)
timeline: list[BrandDetection] = field(default_factory=list)
pipeline_stats: PipelineStats = field(default_factory=PipelineStats)
# Not in DATACLASSES — modelgen skips these (they contain np.ndarray)
RUNTIME_MODELS = [
Frame, BoundingBox, TextCandidate, BrandDetection,
BrandStats, PipelineStats, DetectionReport,
]

View File

@@ -0,0 +1,29 @@
"""Timeline schema — source of truth for frame sequences."""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional
from uuid import UUID
@dataclass
class Timeline:
    """
    The frame sequence from a source video.

    Independent of stages — exists before any stage runs.
    Frames stored in MinIO as JPEGs, metadata here.
    One timeline per job.
    """

    id: UUID
    source_asset_id: Optional[UUID] = None
    source_video: str = ""
    profile_name: str = ""
    fps: float = 2.0

    frames_prefix: str = ""  # s3: timeline/{id}/frames/
    # seq → s3 key. The storage code writes this mapping with string keys and
    # readers convert them back with int(), so persisted rows may not match
    # the declared Dict[int, str] key type.
    frames_manifest: Dict[int, str] = field(default_factory=dict)
    # One metadata dict per frame.
    frames_meta: List[Dict[str, Any]] = field(default_factory=list)

    created_at: Optional[datetime] = None

View File

@@ -1,8 +1,6 @@
"""
Serializers for detection pipeline runtime models.
Mirrors core/schema/models/detect_pipeline.py.
Special handling:
- Frame.image (np.ndarray → S3, excluded from JSON)
- TextCandidate.frame (object ref → frame_sequence integer)
@@ -13,7 +11,7 @@ from __future__ import annotations
import dataclasses
from core.schema.models.detect_pipeline import (
from core.schema.models.pipeline import (
BoundingBox,
BrandDetection,
BrandStats,
@@ -59,13 +57,12 @@ def deserialize_frames_with_download(meta: list[dict], manifest: dict, job_id: s
def serialize_text_candidate(tc: TextCandidate) -> dict:
bbox_dict = dataclasses.asdict(tc.bbox)
result = {
return {
"frame_sequence": tc.frame.sequence,
"bbox": bbox_dict,
"text": tc.text,
"ocr_confidence": tc.ocr_confidence,
}
return result
def serialize_text_candidates(candidates: list[TextCandidate]) -> list[dict]:
@@ -75,13 +72,12 @@ def serialize_text_candidates(candidates: list[TextCandidate]) -> list[dict]:
def deserialize_text_candidate(data: dict, frame_map: dict[int, Frame]) -> TextCandidate:
frame = frame_map[data["frame_sequence"]]
bbox = safe_construct(BoundingBox, data["bbox"])
candidate = TextCandidate(
return TextCandidate(
frame=frame,
bbox=bbox,
text=data["text"],
ocr_confidence=data["ocr_confidence"],
)
return candidate
def deserialize_text_candidates(data: list[dict], frame_map: dict[int, Frame]) -> list[TextCandidate]:

View File

@@ -11,7 +11,7 @@ that don't belong to any stage.
from __future__ import annotations
from core.schema.serializers._common import serialize_dataclass
from core.schema.serializers.detect_pipeline import (
from core.schema.serializers.pipeline import (
deserialize_pipeline_stats,
deserialize_text_candidates,
)

View File

@@ -33,83 +33,81 @@ def create_timeline(
Returns (timeline_id, checkpoint_id).
"""
from core.db.detect import create_timeline as db_create_timeline
from core.db.detect import save_checkpoint
# Create timeline
timeline = db_create_timeline(
source_video=source_video,
profile_name=profile_name,
source_asset_id=source_asset_id,
fps=fps,
)
tid = str(timeline.id)
# Upload frames to MinIO
manifest = save_frames(tid, frames)
# Store frame metadata on the timeline
frames_meta = [
{
"sequence": f.sequence,
"chunk_id": getattr(f, "chunk_id", 0),
"timestamp": f.timestamp,
"perceptual_hash": getattr(f, "perceptual_hash", ""),
}
for f in frames
]
timeline.frames_prefix = f"{CHECKPOINT_PREFIX}/{tid}/frames/"
timeline.frames_manifest = {str(k): v for k, v in manifest.items()}
timeline.frames_meta = frames_meta
from core.db.tables import Timeline, Checkpoint
from core.db.connection import get_session
with get_session() as session:
timeline = Timeline(
source_video=source_video,
profile_name=profile_name,
source_asset_id=source_asset_id,
fps=fps,
)
session.add(timeline)
session.flush()
tid = str(timeline.id)
# Upload frames to MinIO
manifest = save_frames(tid, frames)
frames_meta = [
{
"sequence": f.sequence,
"chunk_id": getattr(f, "chunk_id", 0),
"timestamp": f.timestamp,
"perceptual_hash": getattr(f, "perceptual_hash", ""),
}
for f in frames
]
timeline.frames_prefix = f"{CHECKPOINT_PREFIX}/{tid}/frames/"
timeline.frames_manifest = {str(k): v for k, v in manifest.items()}
timeline.frames_meta = frames_meta
checkpoint = Checkpoint(
timeline_id=timeline.id,
parent_id=None,
stage_outputs={},
stats={"frames_extracted": len(frames)},
)
session.add(checkpoint)
session.commit()
session.refresh(checkpoint)
cid = str(checkpoint.id)
# Create root checkpoint (no parent, no stage outputs yet)
checkpoint = save_checkpoint(
timeline_id=timeline.id,
parent_id=None,
stage_outputs={},
stats={"frames_extracted": len(frames)},
)
logger.info("Timeline created: %s (%d frames, root checkpoint %s)",
tid, len(frames), checkpoint.id)
return tid, str(checkpoint.id)
logger.info("Timeline created: %s (%d frames, root checkpoint %s)", tid, len(frames), cid)
return tid, cid
def get_timeline_frames(timeline_id: str) -> list:
"""Load frames from a timeline (from MinIO) as Frame objects."""
from core.db.detect import get_timeline
from core.db.tables import Timeline
from core.db.connection import get_session
timeline = get_timeline(timeline_id)
with get_session() as session:
timeline = session.get(Timeline, UUID(timeline_id))
if not timeline:
raise ValueError(f"Timeline not found: {timeline_id}")
raw_manifest = timeline.frames_manifest or {}
manifest = {int(k): v for k, v in raw_manifest.items()}
frame_metadata = timeline.frames_meta or []
return load_frames(manifest, frame_metadata)
return load_frames(manifest, timeline.frames_meta or [])
def get_timeline_frames_b64(timeline_id: str) -> list[dict]:
"""Load frames as base64 JPEG (lightweight, no numpy)."""
from core.db.detect import get_timeline
from core.db.tables import Timeline
from core.db.connection import get_session
from .frames import load_frames_b64
timeline = get_timeline(timeline_id)
with get_session() as session:
timeline = session.get(Timeline, UUID(timeline_id))
if not timeline:
raise ValueError(f"Timeline not found: {timeline_id}")
raw_manifest = timeline.frames_manifest or {}
manifest = {int(k): v for k, v in raw_manifest.items()}
frame_metadata = timeline.frames_meta or []
return load_frames_b64(manifest, frame_metadata)
return load_frames_b64(manifest, timeline.frames_meta or [])
# ---------------------------------------------------------------------------
@@ -132,47 +130,46 @@ def save_stage_output(
Carries forward stage outputs from parent + adds the new one.
Returns the new checkpoint ID.
"""
from core.db.detect import get_checkpoint, save_checkpoint
from core.db.tables import Checkpoint
from core.db.connection import get_session
# Carry forward from parent
parent_outputs = {}
parent_stats = {}
parent_config = {}
if parent_checkpoint_id:
parent = get_checkpoint(parent_checkpoint_id)
if parent:
parent_outputs = dict(parent.stage_outputs or {})
parent_stats = dict(parent.stats or {})
parent_config = dict(parent.config_overrides or {})
with get_session() as session:
parent_outputs = {}
parent_stats = {}
parent_config = {}
if parent_checkpoint_id:
parent = session.get(Checkpoint, UUID(parent_checkpoint_id))
if parent:
parent_outputs = dict(parent.stage_outputs or {})
parent_stats = dict(parent.stats or {})
parent_config = dict(parent.config_overrides or {})
# Add new stage output
stage_outputs = {**parent_outputs, stage_name: output_json}
# Merge stats and config
merged_stats = {**parent_stats, **(stats or {})}
merged_config = {**parent_config, **(config_overrides or {})}
checkpoint = save_checkpoint(
timeline_id=timeline_id,
parent_id=parent_checkpoint_id,
stage_outputs=stage_outputs,
config_overrides=merged_config,
stats=merged_stats,
is_scenario=is_scenario,
scenario_label=scenario_label,
)
checkpoint = Checkpoint(
timeline_id=UUID(timeline_id),
parent_id=UUID(parent_checkpoint_id) if parent_checkpoint_id else None,
stage_outputs={**parent_outputs, stage_name: output_json},
config_overrides={**parent_config, **(config_overrides or {})},
stats={**parent_stats, **(stats or {})},
is_scenario=is_scenario,
scenario_label=scenario_label,
)
session.add(checkpoint)
session.commit()
session.refresh(checkpoint)
cid = str(checkpoint.id)
logger.info("Checkpoint saved: %s (timeline %s, stage %s, parent %s)",
checkpoint.id, timeline_id, stage_name, parent_checkpoint_id)
return str(checkpoint.id)
cid, timeline_id, stage_name, parent_checkpoint_id)
return cid
def load_stage_output(checkpoint_id: str, stage_name: str) -> dict | None:
"""Load a stage's output from a checkpoint."""
from core.db.detect import get_checkpoint
from core.db.tables import Checkpoint
from core.db.connection import get_session
checkpoint = get_checkpoint(checkpoint_id)
with get_session() as session:
checkpoint = session.get(Checkpoint, UUID(checkpoint_id))
if not checkpoint:
return None
return (checkpoint.stage_outputs or {}).get(stage_name)

View File

@@ -326,6 +326,7 @@ def node_compile_report(state: DetectState) -> dict:
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
_frames_manifest: dict[str, dict[int, str]] = {} # job_id → manifest (cached per job)
_latest_checkpoint: dict[str, str] = {} # job_id → latest checkpoint_id
class PipelineCancelled(Exception):
@@ -361,17 +362,33 @@ def _checkpointing_node(node_name: str, node_fn):
if not job_id:
return result
from detect.checkpoint import save_checkpoint, save_frames
from detect.checkpoint import save_stage_output, save_frames
from detect.stages.base import _REGISTRY
merged = {**state, **result}
# Save frames once (first checkpoint), reuse manifest after
# Save frames once (first node), reuse manifest after
manifest = _frames_manifest.get(job_id)
if manifest is None and node_name == "extract_frames":
manifest = save_frames(job_id, merged.get("frames", []))
_frames_manifest[job_id] = manifest
save_checkpoint(job_id, node_name, stage_index, merged, frames_manifest=manifest)
# Serialize stage output using the stage's serialize_fn if available
stage_cls = _REGISTRY.get(node_name)
serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
if serialize_fn:
output_json = serialize_fn(merged, job_id)
else:
output_json = {}
parent_id = _latest_checkpoint.get(job_id)
new_checkpoint_id = save_stage_output(
timeline_id=job_id,
parent_checkpoint_id=parent_id,
stage_name=node_name,
output_json=output_json,
)
_latest_checkpoint[job_id] = new_checkpoint_id
return result
wrapper.__name__ = node_fn.__name__

View File

@@ -1,11 +1,6 @@
"""
Re-export pipeline runtime models from core/schema/models/detect_pipeline.py.
"""Re-export pipeline runtime models from core/schema/models/pipeline.py."""
All models are defined in core/schema/ — this module exists for backward
compatibility so existing imports (from detect.models import Frame) keep working.
"""
from core.schema.models.detect_pipeline import (
from core.schema.models.pipeline import (
BoundingBox,
BrandDetection,
BrandStats,

View File

@@ -4,14 +4,10 @@ Stage 5 — Brand Resolver (discovery mode)
Discovery-first brand matching. No static dictionary — all brands live in the DB.
Flow:
1. Check session sightings first (brands already seen in this source)
1. Check session brands first (brands already seen in this run, in-memory)
2. Check global known brands (accumulated across all runs)
3. Unresolved candidates → escalate to VLM/cloud
4. Confirmed brands get added to DB for future runs
The resolver is an enricher, not a gatekeeper. Every OCR text candidate
passes through — the question is whether we can resolve it cheaply (DB lookup)
or need to escalate (VLM/cloud).
"""
from __future__ import annotations
@@ -33,41 +29,30 @@ def _normalize(text: str) -> str:
def _has_db() -> bool:
try:
from core.db.detect import find_brand_by_text as _
from admin.mpr.media_assets.models import KnownBrand as _
from core.db import find_brand_by_text as _
return True
except (ImportError, Exception):
return False
def _match_session(text: str, session_brands: dict[str, str]) -> str | None:
"""
Check against session brands (already seen in this source).
session_brands: {normalized_name: canonical_name, ...}
Includes aliases.
"""
normalized = _normalize(text)
return session_brands.get(normalized)
return session_brands.get(_normalize(text))
def _match_known(text: str, threshold: int) -> tuple[str | None, str | None]:
"""
Check against global known brands in DB.
Returns (canonical_name, brand_id) or (None, None).
"""
"""Check against global known brands in DB. Returns (canonical_name, brand_id) or (None, None)."""
if not _has_db():
return None, None
from core.db.detect import find_brand_by_text
brand = find_brand_by_text(text)
if brand:
return brand.canonical_name, str(brand.id)
from core.db import find_brand_by_text, list_brands
from core.db.connection import get_session
# Fuzzy match against all known brands
from core.db.detect import list_all_brands
all_brands = list_all_brands()
with get_session() as session:
brand = find_brand_by_text(session, text)
if brand:
return brand.canonical_name, str(brand.id)
all_brands = list_brands(session)
normalized = _normalize(text)
best_brand = None
@@ -92,58 +77,62 @@ def _register_brand(canonical_name: str, source: str) -> str | None:
if not _has_db():
return None
from core.db.detect import get_or_create_brand
brand, created = get_or_create_brand(canonical_name, source=source)
from core.db import get_or_create_brand
from core.db.connection import get_session
with get_session() as session:
brand, created = get_or_create_brand(session, canonical_name, source=source)
session.commit()
if created:
logger.info("New brand discovered: %s (source=%s)", canonical_name, source)
return str(brand.id)
def _record_airing(timeline_id: str | None, brand_id: str,
                   frame_seq: int, confidence: float, source: str):
    """
    Record a brand airing on a timeline.

    Does nothing when the DB is unavailable (``_has_db()`` is falsy) or when
    ``timeline_id`` is None/empty.

    The airing is stored as a single-frame span: ``frame_seq`` is passed as
    both ``frame_start`` and ``frame_end``.

    Args:
        timeline_id: timeline identifier as a UUID string, or None.
        brand_id: brand identifier as a UUID string.
        frame_seq: sequence number of the frame the brand was seen on.
        confidence: confidence value forwarded unchanged to ``record_airing``.
        source: match source label forwarded unchanged to ``record_airing``.

    Raises:
        ValueError: if either id is not a valid UUID string (from ``UUID()``).
    """
    if not _has_db() or not timeline_id:
        return

    from uuid import UUID

    from core.db import record_airing
    from core.db.connection import get_session

    # Ids are annotated as strings here; record_airing is given UUID objects.
    with get_session() as session:
        record_airing(
            session,
            brand_id=UUID(brand_id),
            timeline_id=UUID(timeline_id),
            frame_start=frame_seq,
            frame_end=frame_seq,
            confidence=confidence,
            source=source,
        )
        # Explicit commit: the write is persisted before the session closes.
        session.commit()
def build_session_dict(source_asset_id: str | None = None) -> dict[str, str]:
    """
    Load known brands from DB as a session lookup dict.

    Returns {normalized_name: canonical_name, ...} including aliases.
    Returns an empty dict when the DB is unavailable (``_has_db()`` is falsy).

    Args:
        source_asset_id: not read by this function; presumably retained so
            callers that still pass it keep working — confirm before removing.
    """
    if not _has_db():
        return {}

    from core.db import list_brands
    from core.db.connection import get_session

    session_dict: dict[str, str] = {}
    with get_session() as session:
        # Brand attributes are read while the session is still open.
        for brand in list_brands(session):
            session_dict[_normalize(brand.canonical_name)] = brand.canonical_name
            # aliases may be None/empty; each alias maps to the canonical name.
            for alias in (brand.aliases or []):
                session_dict[_normalize(alias)] = brand.canonical_name

    # On normalized-key collisions the later brand/alias wins.
    return session_dict
def resolve_brands(
@@ -158,7 +147,7 @@ def resolve_brands(
Match text candidates against known brands (session → global → unresolved).
session_brands: pre-loaded session dict (from build_session_dict)
source_asset_id: for recording new sightings in DB
job_id: timeline_id — used to record airings
"""
if session_brands is None:
session_brands = {}
@@ -187,7 +176,6 @@ def resolve_brands(
brand_name, brand_id = _match_known(text, config.fuzzy_threshold)
if brand_name:
known_hits += 1
# Add to session for subsequent candidates in this run
session_brands[_normalize(brand_name)] = brand_name
if brand_name:
@@ -203,11 +191,10 @@ def resolve_brands(
)
matched.append(detection)
# Record sighting in DB
if brand_id:
_record_sighting(
source_asset_id, brand_id, brand_name,
candidate.frame.timestamp, candidate.ocr_confidence, match_source,
_record_airing(
job_id, brand_id,
candidate.frame.sequence, candidate.ocr_confidence, match_source,
)
emit.detection(

View File

@@ -10,7 +10,7 @@ from core.schema.serializers._common import (
serialize_dataclass,
serialize_dataclass_list,
)
from core.schema.serializers.detect_pipeline import (
from core.schema.serializers.pipeline import (
serialize_frame_meta,
serialize_frames_with_upload as serialize_frames,
deserialize_frames_with_download as deserialize_frames,

View File

@@ -35,9 +35,11 @@ logger = logging.getLogger(__name__)
def main():
from core.db.detect import list_scenarios
from core.db import list_scenarios
from core.db.connection import get_session
scenarios = list_scenarios()
with get_session() as session:
scenarios = list_scenarios(session)
if not scenarios:
logger.info("No scenarios found. Create one with:")

View File

@@ -121,15 +121,14 @@ def main():
)
# Mark as scenario
from core.db.detect import get_latest_checkpoint
from core.db import get_latest_checkpoint
from core.db.connection import get_session
checkpoint = get_latest_checkpoint(branch_id)
if checkpoint:
checkpoint.is_scenario = True
checkpoint.scenario_label = args.label
with get_session() as session:
session.add(checkpoint)
with get_session() as session:
checkpoint = get_latest_checkpoint(session, branch_id)
if checkpoint:
checkpoint.is_scenario = True
checkpoint.scenario_label = args.label
session.commit()
logger.info("")

View File

@@ -7,7 +7,7 @@ import pytest
from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate
from core.schema.serializers._common import safe_construct
from core.schema.serializers.detect_pipeline import (
from core.schema.serializers.pipeline import (
serialize_frame_meta,
serialize_text_candidate,
serialize_text_candidates,