major refactor
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
API endpoints for checkpoint inspection, replay, retry, and GPU proxy.
|
API endpoints for checkpoint inspection, replay, retry, and GPU proxy.
|
||||||
|
|
||||||
GET /detect/checkpoints/{job_id} — list available checkpoints
|
GET /detect/checkpoints/{timeline_id} — list available checkpoints
|
||||||
POST /detect/replay — replay from a stage with config overrides
|
POST /detect/replay — replay from a stage with config overrides
|
||||||
POST /detect/retry — queue async retry with different provider
|
POST /detect/retry — queue async retry with different provider
|
||||||
POST /detect/replay-stage — replay single stage (fast path)
|
POST /detect/replay-stage — replay single stage (fast path)
|
||||||
@@ -31,7 +31,7 @@ class CheckpointInfo(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ScenarioInfo(BaseModel):
|
class ScenarioInfo(BaseModel):
|
||||||
job_id: str
|
timeline_id: str
|
||||||
stage: str
|
stage: str
|
||||||
scenario_label: str
|
scenario_label: str
|
||||||
profile_name: str
|
profile_name: str
|
||||||
@@ -41,21 +41,21 @@ class ScenarioInfo(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ReplayRequest(BaseModel):
|
class ReplayRequest(BaseModel):
|
||||||
job_id: str
|
timeline_id: str
|
||||||
start_stage: str
|
start_stage: str
|
||||||
config_overrides: dict | None = None
|
config_overrides: dict | None = None
|
||||||
|
|
||||||
|
|
||||||
class ReplayResponse(BaseModel):
|
class ReplayResponse(BaseModel):
|
||||||
status: str
|
status: str
|
||||||
job_id: str
|
timeline_id: str
|
||||||
start_stage: str
|
start_stage: str
|
||||||
detections: int = 0
|
detections: int = 0
|
||||||
brands_found: int = 0
|
brands_found: int = 0
|
||||||
|
|
||||||
|
|
||||||
class RetryRequest(BaseModel):
|
class RetryRequest(BaseModel):
|
||||||
job_id: str
|
timeline_id: str
|
||||||
config_overrides: dict | None = None
|
config_overrides: dict | None = None
|
||||||
start_stage: str = "escalate_vlm"
|
start_stage: str = "escalate_vlm"
|
||||||
schedule_seconds: float | None = None # delay before execution (off-peak)
|
schedule_seconds: float | None = None # delay before execution (off-peak)
|
||||||
@@ -64,11 +64,11 @@ class RetryRequest(BaseModel):
|
|||||||
class RetryResponse(BaseModel):
|
class RetryResponse(BaseModel):
|
||||||
status: str
|
status: str
|
||||||
task_id: str
|
task_id: str
|
||||||
job_id: str
|
timeline_id: str
|
||||||
|
|
||||||
|
|
||||||
class ReplaySingleStageRequest(BaseModel):
|
class ReplaySingleStageRequest(BaseModel):
|
||||||
job_id: str
|
timeline_id: str
|
||||||
stage: str
|
stage: str
|
||||||
frame_refs: list[int] | None = None
|
frame_refs: list[int] | None = None
|
||||||
config_overrides: dict | None = None
|
config_overrides: dict | None = None
|
||||||
@@ -102,15 +102,15 @@ class ReplaySingleStageResponse(BaseModel):
|
|||||||
|
|
||||||
# --- Endpoints ---
|
# --- Endpoints ---
|
||||||
|
|
||||||
@router.get("/checkpoints/{job_id}")
|
@router.get("/checkpoints/{timeline_id}")
|
||||||
def list_checkpoints(job_id: str) -> list[CheckpointInfo]:
|
def list_checkpoints(timeline_id: str) -> list[CheckpointInfo]:
|
||||||
"""List available checkpoint stages for a job."""
|
"""List available checkpoint stages for a job."""
|
||||||
from detect.checkpoint import list_checkpoints as _list
|
from detect.checkpoint import list_checkpoints as _list
|
||||||
|
|
||||||
try:
|
try:
|
||||||
stages = _list(job_id)
|
stages = _list(timeline_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=404, detail=f"No checkpoints for job {job_id}: {e}")
|
raise HTTPException(status_code=404, detail=f"No checkpoints for job {timeline_id}: {e}")
|
||||||
|
|
||||||
result = [CheckpointInfo(stage=s) for s in stages]
|
result = [CheckpointInfo(stage=s) for s in stages]
|
||||||
return result
|
return result
|
||||||
@@ -123,7 +123,7 @@ class CheckpointFrameInfo(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class CheckpointData(BaseModel):
|
class CheckpointData(BaseModel):
|
||||||
job_id: str
|
timeline_id: str
|
||||||
stage: str
|
stage: str
|
||||||
profile_name: str
|
profile_name: str
|
||||||
video_path: str
|
video_path: str
|
||||||
@@ -135,26 +135,32 @@ class CheckpointData(BaseModel):
|
|||||||
stage_output_key: str = ""
|
stage_output_key: str = ""
|
||||||
|
|
||||||
|
|
||||||
@router.get("/checkpoints/{job_id}/{stage}", response_model=CheckpointData)
|
@router.get("/checkpoints/{timeline_id}/{stage}", response_model=CheckpointData)
|
||||||
def get_checkpoint_data(job_id: str, stage: str):
|
def get_checkpoint_data(timeline_id: str, stage: str):
|
||||||
"""Load checkpoint frames + metadata for the editor UI."""
|
"""Load checkpoint frames + metadata for the editor UI."""
|
||||||
from core.db.detect import get_stage_checkpoint
|
from uuid import UUID
|
||||||
|
from core.db.tables import Timeline, Checkpoint
|
||||||
|
from core.db.connection import get_session
|
||||||
|
from core.db.checkpoint import list_checkpoints
|
||||||
from detect.checkpoint.frames import load_frames_b64
|
from detect.checkpoint.frames import load_frames_b64
|
||||||
|
|
||||||
checkpoint = get_stage_checkpoint(job_id, stage)
|
with get_session() as session:
|
||||||
if not checkpoint:
|
timeline = session.get(Timeline, UUID(timeline_id))
|
||||||
raise HTTPException(status_code=404, detail=f"No checkpoint for {job_id}/{stage}")
|
if not timeline:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Timeline not found: {timeline_id}")
|
||||||
|
|
||||||
raw_manifest = checkpoint.frames_manifest or {}
|
checkpoints = list_checkpoints(session, UUID(timeline_id))
|
||||||
manifest = {int(k): v for k, v in raw_manifest.items()}
|
if not checkpoints:
|
||||||
frame_metadata = checkpoint.frames_meta or []
|
raise HTTPException(status_code=404, detail=f"No checkpoints for timeline {timeline_id}")
|
||||||
|
# Prefer a checkpoint that has this stage's output; fall back to latest
|
||||||
|
checkpoint = next(
|
||||||
|
(c for c in reversed(checkpoints) if stage in (c.stage_outputs or {})),
|
||||||
|
checkpoints[-1],
|
||||||
|
)
|
||||||
|
|
||||||
# Only load filtered frames if available, otherwise all
|
raw_manifest = timeline.frames_manifest or {}
|
||||||
filtered = set(checkpoint.filtered_frame_sequences or [])
|
manifest = {int(k): v for k, v in raw_manifest.items()}
|
||||||
if filtered:
|
frames_b64 = load_frames_b64(manifest, timeline.frames_meta or [])
|
||||||
manifest = {k: v for k, v in manifest.items() if k in filtered}
|
|
||||||
|
|
||||||
frames_b64 = load_frames_b64(manifest, frame_metadata)
|
|
||||||
|
|
||||||
frame_list = [
|
frame_list = [
|
||||||
CheckpointFrameInfo(seq=f["seq"], timestamp=f["timestamp"], jpeg_b64=f["jpeg_b64"])
|
CheckpointFrameInfo(seq=f["seq"], timestamp=f["timestamp"], jpeg_b64=f["jpeg_b64"])
|
||||||
@@ -162,38 +168,44 @@ def get_checkpoint_data(job_id: str, stage: str):
|
|||||||
]
|
]
|
||||||
|
|
||||||
return CheckpointData(
|
return CheckpointData(
|
||||||
job_id=str(checkpoint.job_id),
|
timeline_id=timeline_id,
|
||||||
stage=checkpoint.stage,
|
stage=stage,
|
||||||
profile_name=checkpoint.profile_name,
|
profile_name=timeline.profile_name,
|
||||||
video_path=checkpoint.video_path,
|
video_path=timeline.source_video,
|
||||||
is_scenario=checkpoint.is_scenario,
|
is_scenario=checkpoint.is_scenario,
|
||||||
scenario_label=checkpoint.scenario_label,
|
scenario_label=checkpoint.scenario_label,
|
||||||
frames=frame_list,
|
frames=frame_list,
|
||||||
stats=checkpoint.stats or {},
|
stats=checkpoint.stats or {},
|
||||||
config_snapshot=checkpoint.config_snapshot or {},
|
config_snapshot=checkpoint.config_overrides or {},
|
||||||
stage_output_key=checkpoint.stage_output_key or "",
|
stage_output_key=stage,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/scenarios", response_model=list[ScenarioInfo])
|
@router.get("/scenarios", response_model=list[ScenarioInfo])
|
||||||
def list_scenarios_endpoint():
|
def list_scenarios_endpoint():
|
||||||
"""List all available scenarios (bookmarked checkpoints)."""
|
"""List all available scenarios (bookmarked checkpoints)."""
|
||||||
from core.db.detect import list_scenarios
|
from core.db.tables import Timeline
|
||||||
|
from core.db.connection import get_session
|
||||||
|
from core.db.checkpoint import list_scenarios
|
||||||
|
|
||||||
scenarios = list_scenarios()
|
with get_session() as session:
|
||||||
result = []
|
scenarios = list_scenarios(session)
|
||||||
for s in scenarios:
|
result = []
|
||||||
manifest = s.frames_manifest or {}
|
for s in scenarios:
|
||||||
info = ScenarioInfo(
|
timeline = session.get(Timeline, s.timeline_id)
|
||||||
job_id=str(s.job_id),
|
if not timeline:
|
||||||
stage=s.stage,
|
continue
|
||||||
scenario_label=s.scenario_label,
|
last_stage = next(reversed(s.stage_outputs), "") if s.stage_outputs else ""
|
||||||
profile_name=s.profile_name,
|
info = ScenarioInfo(
|
||||||
video_path=s.video_path,
|
timeline_id=str(s.timeline_id),
|
||||||
frame_count=len(manifest),
|
stage=last_stage,
|
||||||
created_at=str(s.created_at) if s.created_at else "",
|
scenario_label=s.scenario_label,
|
||||||
)
|
profile_name=timeline.profile_name,
|
||||||
result.append(info)
|
video_path=timeline.source_video,
|
||||||
|
frame_count=len(timeline.frames_manifest or {}),
|
||||||
|
created_at=str(s.created_at) if s.created_at else "",
|
||||||
|
)
|
||||||
|
result.append(info)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@@ -204,7 +216,7 @@ def replay(req: ReplayRequest):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
result = replay_from(
|
result = replay_from(
|
||||||
job_id=req.job_id,
|
timeline_id=req.timeline_id,
|
||||||
start_stage=req.start_stage,
|
start_stage=req.start_stage,
|
||||||
config_overrides=req.config_overrides,
|
config_overrides=req.config_overrides,
|
||||||
)
|
)
|
||||||
@@ -219,7 +231,7 @@ def replay(req: ReplayRequest):
|
|||||||
|
|
||||||
response = ReplayResponse(
|
response = ReplayResponse(
|
||||||
status="completed",
|
status="completed",
|
||||||
job_id=req.job_id,
|
timeline_id=req.timeline_id,
|
||||||
start_stage=req.start_stage,
|
start_stage=req.start_stage,
|
||||||
detections=len(detections),
|
detections=len(detections),
|
||||||
brands_found=brands_found,
|
brands_found=brands_found,
|
||||||
@@ -233,7 +245,7 @@ def retry(req: RetryRequest):
|
|||||||
from detect.checkpoint.tasks import retry_candidates
|
from detect.checkpoint.tasks import retry_candidates
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"job_id": req.job_id,
|
"timeline_id": req.timeline_id,
|
||||||
"config_overrides": req.config_overrides,
|
"config_overrides": req.config_overrides,
|
||||||
"start_stage": req.start_stage,
|
"start_stage": req.start_stage,
|
||||||
}
|
}
|
||||||
@@ -246,7 +258,7 @@ def retry(req: RetryRequest):
|
|||||||
response = RetryResponse(
|
response = RetryResponse(
|
||||||
status="queued",
|
status="queued",
|
||||||
task_id=task.id,
|
task_id=task.id,
|
||||||
job_id=req.job_id,
|
timeline_id=req.timeline_id,
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@@ -258,7 +270,7 @@ def replay_single_stage(req: ReplaySingleStageRequest):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
result = _replay(
|
result = _replay(
|
||||||
job_id=req.job_id,
|
timeline_id=req.timeline_id,
|
||||||
stage=req.stage,
|
stage=req.stage,
|
||||||
frame_refs=req.frame_refs,
|
frame_refs=req.frame_refs,
|
||||||
config_overrides=req.config_overrides,
|
config_overrides=req.config_overrides,
|
||||||
|
|||||||
@@ -25,11 +25,10 @@ from .grpc import (
|
|||||||
ProgressUpdate,
|
ProgressUpdate,
|
||||||
WorkerStatus,
|
WorkerStatus,
|
||||||
)
|
)
|
||||||
from .job import (
|
from .job import Job, JobStatus, RunType
|
||||||
Job, JobStatus, RunType,
|
from .timeline import Timeline
|
||||||
Timeline, Checkpoint,
|
from .checkpoint import Checkpoint
|
||||||
BrandSource, Brand,
|
from .brand import BrandSource, Brand
|
||||||
)
|
|
||||||
from .media import AssetStatus, MediaAsset
|
from .media import AssetStatus, MediaAsset
|
||||||
from .presets import BUILTIN_PRESETS, TranscodePreset
|
from .presets import BUILTIN_PRESETS, TranscodePreset
|
||||||
from .detect import DETECT_VIEWS # noqa: F401 — discovered by modelgen generic loader
|
from .detect import DETECT_VIEWS # noqa: F401 — discovered by modelgen generic loader
|
||||||
|
|||||||
38
core/schema/models/brand.py
Normal file
38
core/schema/models/brand.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""Brand schema — source of truth for brand discovery."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
|
||||||
|
class BrandSource(str, Enum):
|
||||||
|
OCR = "ocr"
|
||||||
|
VLM = "local_vlm"
|
||||||
|
CLOUD = "cloud_llm"
|
||||||
|
MANUAL = "manual"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Brand:
|
||||||
|
"""
|
||||||
|
A brand discovered or registered in the system.
|
||||||
|
|
||||||
|
Airings track where/when the brand appeared — each airing
|
||||||
|
references a timeline and a frame range.
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
canonical_name: str
|
||||||
|
aliases: List[str] = field(default_factory=list)
|
||||||
|
source: BrandSource = BrandSource.OCR # how first discovered
|
||||||
|
confirmed: bool = False
|
||||||
|
|
||||||
|
# Airings — JSONB array of appearances
|
||||||
|
# [{timeline_id, frame_start, frame_end, confidence, source, timestamp}]
|
||||||
|
airings: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
|
total_airings: int = 0
|
||||||
|
|
||||||
|
created_at: Optional[datetime] = None
|
||||||
|
updated_at: Optional[datetime] = None
|
||||||
38
core/schema/models/checkpoint.py
Normal file
38
core/schema/models/checkpoint.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""Checkpoint schema — source of truth for pipeline state snapshots."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Checkpoint:
|
||||||
|
"""
|
||||||
|
A snapshot of pipeline state on a timeline.
|
||||||
|
|
||||||
|
Stage outputs stored as JSONB — each stage serializes to JSON,
|
||||||
|
the checkpoint stores it without knowing the shape.
|
||||||
|
|
||||||
|
parent_id forms a tree: multiple children from the same parent
|
||||||
|
= different config tries from the same starting point.
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
timeline_id: UUID
|
||||||
|
parent_id: Optional[UUID] = None # null = root checkpoint
|
||||||
|
|
||||||
|
# Stage outputs — JSONB per stage, opaque to the checkpoint layer
|
||||||
|
stage_outputs: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# Config that produced this checkpoint
|
||||||
|
config_overrides: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# Pipeline state
|
||||||
|
stats: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# Scenario bookmark
|
||||||
|
is_scenario: bool = False
|
||||||
|
scenario_label: str = ""
|
||||||
|
|
||||||
|
created_at: Optional[datetime] = None
|
||||||
@@ -1,177 +0,0 @@
|
|||||||
"""
|
|
||||||
Detection Job and Checkpoint Schema Definitions
|
|
||||||
|
|
||||||
Source of truth for detection pipeline job tracking and stage checkpoints.
|
|
||||||
Follows the TranscodeJob/ChunkJob pattern.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from datetime import datetime
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
|
|
||||||
class DetectJobStatus(str, Enum):
|
|
||||||
PENDING = "pending"
|
|
||||||
RUNNING = "running"
|
|
||||||
PAUSED = "paused"
|
|
||||||
COMPLETED = "completed"
|
|
||||||
FAILED = "failed"
|
|
||||||
CANCELLED = "cancelled"
|
|
||||||
|
|
||||||
|
|
||||||
class RunType(str, Enum):
|
|
||||||
INITIAL = "initial"
|
|
||||||
REPLAY = "replay"
|
|
||||||
RETRY = "retry"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class DetectJob:
|
|
||||||
"""
|
|
||||||
A detection pipeline job.
|
|
||||||
|
|
||||||
Each invocation of the pipeline (initial run, replay, retry) creates a DetectJob.
|
|
||||||
Jobs for the same source video are linked via parent_job_id.
|
|
||||||
"""
|
|
||||||
|
|
||||||
id: UUID
|
|
||||||
|
|
||||||
# Input
|
|
||||||
source_asset_id: UUID
|
|
||||||
video_path: str
|
|
||||||
profile_name: str = "soccer_broadcast"
|
|
||||||
|
|
||||||
# Run lineage
|
|
||||||
parent_job_id: Optional[UUID] = None # links all runs for the same source
|
|
||||||
run_type: RunType = RunType.INITIAL
|
|
||||||
replay_from_stage: Optional[str] = None # null for initial runs
|
|
||||||
config_overrides: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
# Status
|
|
||||||
status: DetectJobStatus = DetectJobStatus.PENDING
|
|
||||||
current_stage: Optional[str] = None
|
|
||||||
progress: float = 0.0
|
|
||||||
error_message: Optional[str] = None
|
|
||||||
|
|
||||||
# Results summary
|
|
||||||
total_detections: int = 0
|
|
||||||
brands_found: int = 0
|
|
||||||
cloud_llm_calls: int = 0
|
|
||||||
estimated_cost_usd: float = 0.0
|
|
||||||
|
|
||||||
# Worker tracking
|
|
||||||
celery_task_id: Optional[str] = None
|
|
||||||
priority: int = 0
|
|
||||||
|
|
||||||
# Timestamps
|
|
||||||
created_at: Optional[datetime] = None
|
|
||||||
started_at: Optional[datetime] = None
|
|
||||||
completed_at: Optional[datetime] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Timeline:
|
|
||||||
"""
|
|
||||||
The frame sequence from a source video.
|
|
||||||
|
|
||||||
Independent of stages — exists before any stage runs.
|
|
||||||
Stages annotate the timeline, they don't own it.
|
|
||||||
Frames are stored in MinIO as JPEGs.
|
|
||||||
"""
|
|
||||||
|
|
||||||
id: UUID
|
|
||||||
source_asset_id: Optional[UUID] = None
|
|
||||||
source_video: str = ""
|
|
||||||
profile_name: str = ""
|
|
||||||
fps: float = 2.0
|
|
||||||
|
|
||||||
# Frame metadata (images in MinIO, metadata here)
|
|
||||||
frames_prefix: str = "" # s3: timelines/{id}/frames/
|
|
||||||
frames_manifest: Dict[int, str] = field(default_factory=dict) # seq → s3 key
|
|
||||||
frames_meta: List[Dict[str, Any]] = field(default_factory=list)
|
|
||||||
|
|
||||||
created_at: Optional[datetime] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Checkpoint:
|
|
||||||
"""
|
|
||||||
A snapshot of pipeline state on a timeline.
|
|
||||||
|
|
||||||
Stage outputs stored as JSONB — each stage serializes to JSON,
|
|
||||||
the checkpoint stores it without knowing the shape.
|
|
||||||
|
|
||||||
parent_id forms a tree: multiple children from the same parent
|
|
||||||
= different config tries from the same starting point.
|
|
||||||
"""
|
|
||||||
|
|
||||||
id: UUID
|
|
||||||
timeline_id: UUID
|
|
||||||
parent_id: Optional[UUID] = None # null = root checkpoint
|
|
||||||
|
|
||||||
# Stage outputs — JSONB per stage, opaque to the checkpoint layer
|
|
||||||
stage_outputs: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
# Config that produced this checkpoint
|
|
||||||
config_overrides: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
# Pipeline state
|
|
||||||
stats: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
# Scenario bookmark
|
|
||||||
is_scenario: bool = False
|
|
||||||
scenario_label: str = ""
|
|
||||||
|
|
||||||
created_at: Optional[datetime] = None
|
|
||||||
|
|
||||||
|
|
||||||
class BrandSource(str, Enum):
|
|
||||||
"""How a brand was first identified."""
|
|
||||||
OCR = "ocr"
|
|
||||||
VLM = "local_vlm"
|
|
||||||
CLOUD = "cloud_llm"
|
|
||||||
MANUAL = "manual" # user-added via UI
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class KnownBrand:
|
|
||||||
"""
|
|
||||||
A brand discovered or registered in the system.
|
|
||||||
|
|
||||||
Global — not per-source. Accumulates across all pipeline runs.
|
|
||||||
Aliases enable fuzzy matching without re-escalating to VLM.
|
|
||||||
"""
|
|
||||||
|
|
||||||
id: UUID
|
|
||||||
canonical_name: str # normalized display name
|
|
||||||
aliases: List[str] = field(default_factory=list) # known spellings/variants
|
|
||||||
first_source: BrandSource = BrandSource.OCR
|
|
||||||
total_occurrences: int = 0
|
|
||||||
confirmed: bool = False # manually confirmed by user
|
|
||||||
|
|
||||||
created_at: Optional[datetime] = None
|
|
||||||
updated_at: Optional[datetime] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SourceBrandSighting:
|
|
||||||
"""
|
|
||||||
A brand seen in a specific source (video/asset).
|
|
||||||
|
|
||||||
Per-source session cache — avoids re-escalating the same brand
|
|
||||||
on subsequent frames or re-runs of the same source.
|
|
||||||
"""
|
|
||||||
|
|
||||||
id: UUID
|
|
||||||
source_asset_id: UUID # the video this sighting belongs to
|
|
||||||
brand_id: UUID # FK to KnownBrand
|
|
||||||
brand_name: str # denormalized for fast lookup
|
|
||||||
first_seen_timestamp: float = 0.0
|
|
||||||
last_seen_timestamp: float = 0.0
|
|
||||||
occurrences: int = 0
|
|
||||||
detection_source: BrandSource = BrandSource.OCR
|
|
||||||
avg_confidence: float = 0.0
|
|
||||||
|
|
||||||
created_at: Optional[datetime] = None
|
|
||||||
@@ -1,14 +1,9 @@
|
|||||||
"""
|
"""Job schema — source of truth for pipeline jobs."""
|
||||||
Job, Timeline, and Checkpoint Schema Definitions
|
|
||||||
|
|
||||||
Source of truth for pipeline jobs, timelines, and checkpoints.
|
|
||||||
Generates: SQLModel (core/db/models.py), TypeScript via modelgen.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
|
||||||
@@ -68,91 +63,3 @@ class Job:
|
|||||||
created_at: Optional[datetime] = None
|
created_at: Optional[datetime] = None
|
||||||
started_at: Optional[datetime] = None
|
started_at: Optional[datetime] = None
|
||||||
completed_at: Optional[datetime] = None
|
completed_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Timeline:
|
|
||||||
"""
|
|
||||||
The frame sequence from a source video.
|
|
||||||
|
|
||||||
Independent of stages — exists before any stage runs.
|
|
||||||
Frames stored in MinIO as JPEGs, metadata here.
|
|
||||||
One timeline per job.
|
|
||||||
"""
|
|
||||||
|
|
||||||
id: UUID
|
|
||||||
source_asset_id: Optional[UUID] = None
|
|
||||||
source_video: str = ""
|
|
||||||
profile_name: str = ""
|
|
||||||
fps: float = 2.0
|
|
||||||
|
|
||||||
frames_prefix: str = "" # s3: timeline/{id}/frames/
|
|
||||||
frames_manifest: Dict[int, str] = field(default_factory=dict) # seq → s3 key
|
|
||||||
frames_meta: List[Dict[str, Any]] = field(default_factory=list)
|
|
||||||
|
|
||||||
created_at: Optional[datetime] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Checkpoint:
|
|
||||||
"""
|
|
||||||
A snapshot of pipeline state on a timeline.
|
|
||||||
|
|
||||||
Stage outputs stored as JSONB — each stage serializes to JSON,
|
|
||||||
the checkpoint stores it without knowing the shape.
|
|
||||||
|
|
||||||
parent_id forms a tree: multiple children from the same parent
|
|
||||||
= different config tries from the same starting point.
|
|
||||||
"""
|
|
||||||
|
|
||||||
id: UUID
|
|
||||||
timeline_id: UUID
|
|
||||||
parent_id: Optional[UUID] = None # null = root checkpoint
|
|
||||||
|
|
||||||
# Stage outputs — JSONB per stage, opaque to the checkpoint layer
|
|
||||||
stage_outputs: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
# Config that produced this checkpoint
|
|
||||||
config_overrides: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
# Pipeline state
|
|
||||||
stats: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
# Scenario bookmark
|
|
||||||
is_scenario: bool = False
|
|
||||||
scenario_label: str = ""
|
|
||||||
|
|
||||||
created_at: Optional[datetime] = None
|
|
||||||
|
|
||||||
|
|
||||||
# --- Brands ---
|
|
||||||
|
|
||||||
class BrandSource(str, Enum):
|
|
||||||
OCR = "ocr"
|
|
||||||
VLM = "local_vlm"
|
|
||||||
CLOUD = "cloud_llm"
|
|
||||||
MANUAL = "manual"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Brand:
|
|
||||||
"""
|
|
||||||
A brand discovered or registered in the system.
|
|
||||||
|
|
||||||
Airings track where/when the brand appeared — each airing
|
|
||||||
references a timeline and a frame range.
|
|
||||||
"""
|
|
||||||
|
|
||||||
id: UUID
|
|
||||||
canonical_name: str
|
|
||||||
aliases: List[str] = field(default_factory=list)
|
|
||||||
source: BrandSource = BrandSource.OCR # how first discovered
|
|
||||||
confirmed: bool = False
|
|
||||||
|
|
||||||
# Airings — JSONB array of appearances
|
|
||||||
# [{timeline_id, frame_start, frame_end, confidence, source, timestamp}]
|
|
||||||
airings: List[Dict[str, Any]] = field(default_factory=list)
|
|
||||||
total_airings: int = 0
|
|
||||||
|
|
||||||
created_at: Optional[datetime] = None
|
|
||||||
updated_at: Optional[datetime] = None
|
|
||||||
|
|||||||
@@ -1,13 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
Detection pipeline runtime models.
|
Detection pipeline runtime models.
|
||||||
|
|
||||||
These are the data structures that flow between LangGraph nodes.
|
These are the data structures that flow between pipeline stages.
|
||||||
They contain runtime types (np.ndarray) so they are NOT generated
|
They contain runtime types (np.ndarray) so modelgen skips them —
|
||||||
by modelgen — they live here for the schema to be the complete
|
not generated to SQLModel or TypeScript.
|
||||||
map of the application, but modelgen skips them.
|
|
||||||
|
|
||||||
Wire-format models (SSE events) are in detect.py.
|
|
||||||
DB models (jobs, checkpoints) are in detect_jobs.py.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -89,10 +85,3 @@ class DetectionReport:
|
|||||||
brands: dict[str, BrandStats] = field(default_factory=dict)
|
brands: dict[str, BrandStats] = field(default_factory=dict)
|
||||||
timeline: list[BrandDetection] = field(default_factory=list)
|
timeline: list[BrandDetection] = field(default_factory=list)
|
||||||
pipeline_stats: PipelineStats = field(default_factory=PipelineStats)
|
pipeline_stats: PipelineStats = field(default_factory=PipelineStats)
|
||||||
|
|
||||||
|
|
||||||
# Not in DATACLASSES — modelgen skips these (they contain np.ndarray)
|
|
||||||
RUNTIME_MODELS = [
|
|
||||||
Frame, BoundingBox, TextCandidate, BrandDetection,
|
|
||||||
BrandStats, PipelineStats, DetectionReport,
|
|
||||||
]
|
|
||||||
29
core/schema/models/timeline.py
Normal file
29
core/schema/models/timeline.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
"""Timeline schema — source of truth for frame sequences."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Timeline:
|
||||||
|
"""
|
||||||
|
The frame sequence from a source video.
|
||||||
|
|
||||||
|
Independent of stages — exists before any stage runs.
|
||||||
|
Frames stored in MinIO as JPEGs, metadata here.
|
||||||
|
One timeline per job.
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
source_asset_id: Optional[UUID] = None
|
||||||
|
source_video: str = ""
|
||||||
|
profile_name: str = ""
|
||||||
|
fps: float = 2.0
|
||||||
|
|
||||||
|
frames_prefix: str = "" # s3: timeline/{id}/frames/
|
||||||
|
frames_manifest: Dict[int, str] = field(default_factory=dict) # seq → s3 key
|
||||||
|
frames_meta: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
|
created_at: Optional[datetime] = None
|
||||||
@@ -1,8 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
Serializers for detection pipeline runtime models.
|
Serializers for detection pipeline runtime models.
|
||||||
|
|
||||||
Mirrors core/schema/models/detect_pipeline.py.
|
|
||||||
|
|
||||||
Special handling:
|
Special handling:
|
||||||
- Frame.image (np.ndarray → S3, excluded from JSON)
|
- Frame.image (np.ndarray → S3, excluded from JSON)
|
||||||
- TextCandidate.frame (object ref → frame_sequence integer)
|
- TextCandidate.frame (object ref → frame_sequence integer)
|
||||||
@@ -13,7 +11,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
|
||||||
from core.schema.models.detect_pipeline import (
|
from core.schema.models.pipeline import (
|
||||||
BoundingBox,
|
BoundingBox,
|
||||||
BrandDetection,
|
BrandDetection,
|
||||||
BrandStats,
|
BrandStats,
|
||||||
@@ -59,13 +57,12 @@ def deserialize_frames_with_download(meta: list[dict], manifest: dict, job_id: s
|
|||||||
|
|
||||||
def serialize_text_candidate(tc: TextCandidate) -> dict:
|
def serialize_text_candidate(tc: TextCandidate) -> dict:
|
||||||
bbox_dict = dataclasses.asdict(tc.bbox)
|
bbox_dict = dataclasses.asdict(tc.bbox)
|
||||||
result = {
|
return {
|
||||||
"frame_sequence": tc.frame.sequence,
|
"frame_sequence": tc.frame.sequence,
|
||||||
"bbox": bbox_dict,
|
"bbox": bbox_dict,
|
||||||
"text": tc.text,
|
"text": tc.text,
|
||||||
"ocr_confidence": tc.ocr_confidence,
|
"ocr_confidence": tc.ocr_confidence,
|
||||||
}
|
}
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_text_candidates(candidates: list[TextCandidate]) -> list[dict]:
|
def serialize_text_candidates(candidates: list[TextCandidate]) -> list[dict]:
|
||||||
@@ -75,13 +72,12 @@ def serialize_text_candidates(candidates: list[TextCandidate]) -> list[dict]:
|
|||||||
def deserialize_text_candidate(data: dict, frame_map: dict[int, Frame]) -> TextCandidate:
|
def deserialize_text_candidate(data: dict, frame_map: dict[int, Frame]) -> TextCandidate:
|
||||||
frame = frame_map[data["frame_sequence"]]
|
frame = frame_map[data["frame_sequence"]]
|
||||||
bbox = safe_construct(BoundingBox, data["bbox"])
|
bbox = safe_construct(BoundingBox, data["bbox"])
|
||||||
candidate = TextCandidate(
|
return TextCandidate(
|
||||||
frame=frame,
|
frame=frame,
|
||||||
bbox=bbox,
|
bbox=bbox,
|
||||||
text=data["text"],
|
text=data["text"],
|
||||||
ocr_confidence=data["ocr_confidence"],
|
ocr_confidence=data["ocr_confidence"],
|
||||||
)
|
)
|
||||||
return candidate
|
|
||||||
|
|
||||||
|
|
||||||
def deserialize_text_candidates(data: list[dict], frame_map: dict[int, Frame]) -> list[TextCandidate]:
|
def deserialize_text_candidates(data: list[dict], frame_map: dict[int, Frame]) -> list[TextCandidate]:
|
||||||
@@ -11,7 +11,7 @@ that don't belong to any stage.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from core.schema.serializers._common import serialize_dataclass
|
from core.schema.serializers._common import serialize_dataclass
|
||||||
from core.schema.serializers.detect_pipeline import (
|
from core.schema.serializers.pipeline import (
|
||||||
deserialize_pipeline_stats,
|
deserialize_pipeline_stats,
|
||||||
deserialize_text_candidates,
|
deserialize_text_candidates,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -33,83 +33,81 @@ def create_timeline(
|
|||||||
|
|
||||||
Returns (timeline_id, checkpoint_id).
|
Returns (timeline_id, checkpoint_id).
|
||||||
"""
|
"""
|
||||||
from core.db.detect import create_timeline as db_create_timeline
|
from core.db.tables import Timeline, Checkpoint
|
||||||
from core.db.detect import save_checkpoint
|
|
||||||
|
|
||||||
# Create timeline
|
|
||||||
timeline = db_create_timeline(
|
|
||||||
source_video=source_video,
|
|
||||||
profile_name=profile_name,
|
|
||||||
source_asset_id=source_asset_id,
|
|
||||||
fps=fps,
|
|
||||||
)
|
|
||||||
tid = str(timeline.id)
|
|
||||||
|
|
||||||
# Upload frames to MinIO
|
|
||||||
manifest = save_frames(tid, frames)
|
|
||||||
|
|
||||||
# Store frame metadata on the timeline
|
|
||||||
frames_meta = [
|
|
||||||
{
|
|
||||||
"sequence": f.sequence,
|
|
||||||
"chunk_id": getattr(f, "chunk_id", 0),
|
|
||||||
"timestamp": f.timestamp,
|
|
||||||
"perceptual_hash": getattr(f, "perceptual_hash", ""),
|
|
||||||
}
|
|
||||||
for f in frames
|
|
||||||
]
|
|
||||||
|
|
||||||
timeline.frames_prefix = f"{CHECKPOINT_PREFIX}/{tid}/frames/"
|
|
||||||
timeline.frames_manifest = {str(k): v for k, v in manifest.items()}
|
|
||||||
timeline.frames_meta = frames_meta
|
|
||||||
|
|
||||||
from core.db.connection import get_session
|
from core.db.connection import get_session
|
||||||
|
|
||||||
with get_session() as session:
|
with get_session() as session:
|
||||||
|
timeline = Timeline(
|
||||||
|
source_video=source_video,
|
||||||
|
profile_name=profile_name,
|
||||||
|
source_asset_id=source_asset_id,
|
||||||
|
fps=fps,
|
||||||
|
)
|
||||||
session.add(timeline)
|
session.add(timeline)
|
||||||
|
session.flush()
|
||||||
|
tid = str(timeline.id)
|
||||||
|
|
||||||
|
# Upload frames to MinIO
|
||||||
|
manifest = save_frames(tid, frames)
|
||||||
|
|
||||||
|
frames_meta = [
|
||||||
|
{
|
||||||
|
"sequence": f.sequence,
|
||||||
|
"chunk_id": getattr(f, "chunk_id", 0),
|
||||||
|
"timestamp": f.timestamp,
|
||||||
|
"perceptual_hash": getattr(f, "perceptual_hash", ""),
|
||||||
|
}
|
||||||
|
for f in frames
|
||||||
|
]
|
||||||
|
|
||||||
|
timeline.frames_prefix = f"{CHECKPOINT_PREFIX}/{tid}/frames/"
|
||||||
|
timeline.frames_manifest = {str(k): v for k, v in manifest.items()}
|
||||||
|
timeline.frames_meta = frames_meta
|
||||||
|
|
||||||
|
checkpoint = Checkpoint(
|
||||||
|
timeline_id=timeline.id,
|
||||||
|
parent_id=None,
|
||||||
|
stage_outputs={},
|
||||||
|
stats={"frames_extracted": len(frames)},
|
||||||
|
)
|
||||||
|
session.add(checkpoint)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
session.refresh(checkpoint)
|
||||||
|
cid = str(checkpoint.id)
|
||||||
|
|
||||||
# Create root checkpoint (no parent, no stage outputs yet)
|
logger.info("Timeline created: %s (%d frames, root checkpoint %s)", tid, len(frames), cid)
|
||||||
checkpoint = save_checkpoint(
|
return tid, cid
|
||||||
timeline_id=timeline.id,
|
|
||||||
parent_id=None,
|
|
||||||
stage_outputs={},
|
|
||||||
stats={"frames_extracted": len(frames)},
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("Timeline created: %s (%d frames, root checkpoint %s)",
|
|
||||||
tid, len(frames), checkpoint.id)
|
|
||||||
return tid, str(checkpoint.id)
|
|
||||||
|
|
||||||
|
|
||||||
def get_timeline_frames(timeline_id: str) -> list:
|
def get_timeline_frames(timeline_id: str) -> list:
|
||||||
"""Load frames from a timeline (from MinIO) as Frame objects."""
|
"""Load frames from a timeline (from MinIO) as Frame objects."""
|
||||||
from core.db.detect import get_timeline
|
from core.db.tables import Timeline
|
||||||
|
from core.db.connection import get_session
|
||||||
|
|
||||||
timeline = get_timeline(timeline_id)
|
with get_session() as session:
|
||||||
|
timeline = session.get(Timeline, UUID(timeline_id))
|
||||||
if not timeline:
|
if not timeline:
|
||||||
raise ValueError(f"Timeline not found: {timeline_id}")
|
raise ValueError(f"Timeline not found: {timeline_id}")
|
||||||
|
|
||||||
raw_manifest = timeline.frames_manifest or {}
|
raw_manifest = timeline.frames_manifest or {}
|
||||||
manifest = {int(k): v for k, v in raw_manifest.items()}
|
manifest = {int(k): v for k, v in raw_manifest.items()}
|
||||||
frame_metadata = timeline.frames_meta or []
|
return load_frames(manifest, timeline.frames_meta or [])
|
||||||
|
|
||||||
return load_frames(manifest, frame_metadata)
|
|
||||||
|
|
||||||
|
|
||||||
def get_timeline_frames_b64(timeline_id: str) -> list[dict]:
|
def get_timeline_frames_b64(timeline_id: str) -> list[dict]:
|
||||||
"""Load frames as base64 JPEG (lightweight, no numpy)."""
|
"""Load frames as base64 JPEG (lightweight, no numpy)."""
|
||||||
from core.db.detect import get_timeline
|
from core.db.tables import Timeline
|
||||||
|
from core.db.connection import get_session
|
||||||
from .frames import load_frames_b64
|
from .frames import load_frames_b64
|
||||||
|
|
||||||
timeline = get_timeline(timeline_id)
|
with get_session() as session:
|
||||||
|
timeline = session.get(Timeline, UUID(timeline_id))
|
||||||
if not timeline:
|
if not timeline:
|
||||||
raise ValueError(f"Timeline not found: {timeline_id}")
|
raise ValueError(f"Timeline not found: {timeline_id}")
|
||||||
|
|
||||||
raw_manifest = timeline.frames_manifest or {}
|
raw_manifest = timeline.frames_manifest or {}
|
||||||
manifest = {int(k): v for k, v in raw_manifest.items()}
|
manifest = {int(k): v for k, v in raw_manifest.items()}
|
||||||
frame_metadata = timeline.frames_meta or []
|
return load_frames_b64(manifest, timeline.frames_meta or [])
|
||||||
|
|
||||||
return load_frames_b64(manifest, frame_metadata)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -132,47 +130,46 @@ def save_stage_output(
|
|||||||
Carries forward stage outputs from parent + adds the new one.
|
Carries forward stage outputs from parent + adds the new one.
|
||||||
Returns the new checkpoint ID.
|
Returns the new checkpoint ID.
|
||||||
"""
|
"""
|
||||||
from core.db.detect import get_checkpoint, save_checkpoint
|
from core.db.tables import Checkpoint
|
||||||
|
from core.db.connection import get_session
|
||||||
|
|
||||||
# Carry forward from parent
|
with get_session() as session:
|
||||||
parent_outputs = {}
|
parent_outputs = {}
|
||||||
parent_stats = {}
|
parent_stats = {}
|
||||||
parent_config = {}
|
parent_config = {}
|
||||||
if parent_checkpoint_id:
|
if parent_checkpoint_id:
|
||||||
parent = get_checkpoint(parent_checkpoint_id)
|
parent = session.get(Checkpoint, UUID(parent_checkpoint_id))
|
||||||
if parent:
|
if parent:
|
||||||
parent_outputs = dict(parent.stage_outputs or {})
|
parent_outputs = dict(parent.stage_outputs or {})
|
||||||
parent_stats = dict(parent.stats or {})
|
parent_stats = dict(parent.stats or {})
|
||||||
parent_config = dict(parent.config_overrides or {})
|
parent_config = dict(parent.config_overrides or {})
|
||||||
|
|
||||||
# Add new stage output
|
checkpoint = Checkpoint(
|
||||||
stage_outputs = {**parent_outputs, stage_name: output_json}
|
timeline_id=UUID(timeline_id),
|
||||||
|
parent_id=UUID(parent_checkpoint_id) if parent_checkpoint_id else None,
|
||||||
# Merge stats and config
|
stage_outputs={**parent_outputs, stage_name: output_json},
|
||||||
merged_stats = {**parent_stats, **(stats or {})}
|
config_overrides={**parent_config, **(config_overrides or {})},
|
||||||
merged_config = {**parent_config, **(config_overrides or {})}
|
stats={**parent_stats, **(stats or {})},
|
||||||
|
is_scenario=is_scenario,
|
||||||
checkpoint = save_checkpoint(
|
scenario_label=scenario_label,
|
||||||
timeline_id=timeline_id,
|
)
|
||||||
parent_id=parent_checkpoint_id,
|
session.add(checkpoint)
|
||||||
stage_outputs=stage_outputs,
|
session.commit()
|
||||||
config_overrides=merged_config,
|
session.refresh(checkpoint)
|
||||||
stats=merged_stats,
|
cid = str(checkpoint.id)
|
||||||
is_scenario=is_scenario,
|
|
||||||
scenario_label=scenario_label,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("Checkpoint saved: %s (timeline %s, stage %s, parent %s)",
|
logger.info("Checkpoint saved: %s (timeline %s, stage %s, parent %s)",
|
||||||
checkpoint.id, timeline_id, stage_name, parent_checkpoint_id)
|
cid, timeline_id, stage_name, parent_checkpoint_id)
|
||||||
return str(checkpoint.id)
|
return cid
|
||||||
|
|
||||||
|
|
||||||
def load_stage_output(checkpoint_id: str, stage_name: str) -> dict | None:
|
def load_stage_output(checkpoint_id: str, stage_name: str) -> dict | None:
|
||||||
"""Load a stage's output from a checkpoint."""
|
"""Load a stage's output from a checkpoint."""
|
||||||
from core.db.detect import get_checkpoint
|
from core.db.tables import Checkpoint
|
||||||
|
from core.db.connection import get_session
|
||||||
|
|
||||||
checkpoint = get_checkpoint(checkpoint_id)
|
with get_session() as session:
|
||||||
|
checkpoint = session.get(Checkpoint, UUID(checkpoint_id))
|
||||||
if not checkpoint:
|
if not checkpoint:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return (checkpoint.stage_outputs or {}).get(stage_name)
|
return (checkpoint.stage_outputs or {}).get(stage_name)
|
||||||
|
|||||||
@@ -326,6 +326,7 @@ def node_compile_report(state: DetectState) -> dict:
|
|||||||
|
|
||||||
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
|
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
|
||||||
_frames_manifest: dict[str, dict[int, str]] = {} # job_id → manifest (cached per job)
|
_frames_manifest: dict[str, dict[int, str]] = {} # job_id → manifest (cached per job)
|
||||||
|
_latest_checkpoint: dict[str, str] = {} # job_id → latest checkpoint_id
|
||||||
|
|
||||||
|
|
||||||
class PipelineCancelled(Exception):
|
class PipelineCancelled(Exception):
|
||||||
@@ -361,17 +362,33 @@ def _checkpointing_node(node_name: str, node_fn):
|
|||||||
if not job_id:
|
if not job_id:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
from detect.checkpoint import save_checkpoint, save_frames
|
from detect.checkpoint import save_stage_output, save_frames
|
||||||
|
from detect.stages.base import _REGISTRY
|
||||||
|
|
||||||
merged = {**state, **result}
|
merged = {**state, **result}
|
||||||
|
|
||||||
# Save frames once (first checkpoint), reuse manifest after
|
# Save frames once (first node), reuse manifest after
|
||||||
manifest = _frames_manifest.get(job_id)
|
manifest = _frames_manifest.get(job_id)
|
||||||
if manifest is None and node_name == "extract_frames":
|
if manifest is None and node_name == "extract_frames":
|
||||||
manifest = save_frames(job_id, merged.get("frames", []))
|
manifest = save_frames(job_id, merged.get("frames", []))
|
||||||
_frames_manifest[job_id] = manifest
|
_frames_manifest[job_id] = manifest
|
||||||
|
|
||||||
save_checkpoint(job_id, node_name, stage_index, merged, frames_manifest=manifest)
|
# Serialize stage output using the stage's serialize_fn if available
|
||||||
|
stage_cls = _REGISTRY.get(node_name)
|
||||||
|
serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
|
||||||
|
if serialize_fn:
|
||||||
|
output_json = serialize_fn(merged, job_id)
|
||||||
|
else:
|
||||||
|
output_json = {}
|
||||||
|
|
||||||
|
parent_id = _latest_checkpoint.get(job_id)
|
||||||
|
new_checkpoint_id = save_stage_output(
|
||||||
|
timeline_id=job_id,
|
||||||
|
parent_checkpoint_id=parent_id,
|
||||||
|
stage_name=node_name,
|
||||||
|
output_json=output_json,
|
||||||
|
)
|
||||||
|
_latest_checkpoint[job_id] = new_checkpoint_id
|
||||||
return result
|
return result
|
||||||
|
|
||||||
wrapper.__name__ = node_fn.__name__
|
wrapper.__name__ = node_fn.__name__
|
||||||
|
|||||||
@@ -1,11 +1,6 @@
|
|||||||
"""
|
"""Re-export pipeline runtime models from core/schema/models/pipeline.py."""
|
||||||
Re-export pipeline runtime models from core/schema/models/detect_pipeline.py.
|
|
||||||
|
|
||||||
All models are defined in core/schema/ — this module exists for backward
|
from core.schema.models.pipeline import (
|
||||||
compatibility so existing imports (from detect.models import Frame) keep working.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from core.schema.models.detect_pipeline import (
|
|
||||||
BoundingBox,
|
BoundingBox,
|
||||||
BrandDetection,
|
BrandDetection,
|
||||||
BrandStats,
|
BrandStats,
|
||||||
|
|||||||
@@ -4,14 +4,10 @@ Stage 5 — Brand Resolver (discovery mode)
|
|||||||
Discovery-first brand matching. No static dictionary — all brands live in the DB.
|
Discovery-first brand matching. No static dictionary — all brands live in the DB.
|
||||||
|
|
||||||
Flow:
|
Flow:
|
||||||
1. Check session sightings first (brands already seen in this source)
|
1. Check session brands first (brands already seen in this run, in-memory)
|
||||||
2. Check global known brands (accumulated across all runs)
|
2. Check global known brands (accumulated across all runs)
|
||||||
3. Unresolved candidates → escalate to VLM/cloud
|
3. Unresolved candidates → escalate to VLM/cloud
|
||||||
4. Confirmed brands get added to DB for future runs
|
4. Confirmed brands get added to DB for future runs
|
||||||
|
|
||||||
The resolver is an enricher, not a gatekeeper. Every OCR text candidate
|
|
||||||
passes through — the question is whether we can resolve it cheaply (DB lookup)
|
|
||||||
or need to escalate (VLM/cloud).
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -33,41 +29,30 @@ def _normalize(text: str) -> str:
|
|||||||
|
|
||||||
def _has_db() -> bool:
|
def _has_db() -> bool:
|
||||||
try:
|
try:
|
||||||
from core.db.detect import find_brand_by_text as _
|
from core.db import find_brand_by_text as _
|
||||||
from admin.mpr.media_assets.models import KnownBrand as _
|
|
||||||
return True
|
return True
|
||||||
except (ImportError, Exception):
|
except (ImportError, Exception):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _match_session(text: str, session_brands: dict[str, str]) -> str | None:
|
def _match_session(text: str, session_brands: dict[str, str]) -> str | None:
|
||||||
"""
|
return session_brands.get(_normalize(text))
|
||||||
Check against session brands (already seen in this source).
|
|
||||||
|
|
||||||
session_brands: {normalized_name: canonical_name, ...}
|
|
||||||
Includes aliases.
|
|
||||||
"""
|
|
||||||
normalized = _normalize(text)
|
|
||||||
return session_brands.get(normalized)
|
|
||||||
|
|
||||||
|
|
||||||
def _match_known(text: str, threshold: int) -> tuple[str | None, str | None]:
|
def _match_known(text: str, threshold: int) -> tuple[str | None, str | None]:
|
||||||
"""
|
"""Check against global known brands in DB. Returns (canonical_name, brand_id) or (None, None)."""
|
||||||
Check against global known brands in DB.
|
|
||||||
|
|
||||||
Returns (canonical_name, brand_id) or (None, None).
|
|
||||||
"""
|
|
||||||
if not _has_db():
|
if not _has_db():
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
from core.db.detect import find_brand_by_text
|
from core.db import find_brand_by_text, list_brands
|
||||||
brand = find_brand_by_text(text)
|
from core.db.connection import get_session
|
||||||
if brand:
|
|
||||||
return brand.canonical_name, str(brand.id)
|
|
||||||
|
|
||||||
# Fuzzy match against all known brands
|
with get_session() as session:
|
||||||
from core.db.detect import list_all_brands
|
brand = find_brand_by_text(session, text)
|
||||||
all_brands = list_all_brands()
|
if brand:
|
||||||
|
return brand.canonical_name, str(brand.id)
|
||||||
|
|
||||||
|
all_brands = list_brands(session)
|
||||||
|
|
||||||
normalized = _normalize(text)
|
normalized = _normalize(text)
|
||||||
best_brand = None
|
best_brand = None
|
||||||
@@ -92,58 +77,62 @@ def _register_brand(canonical_name: str, source: str) -> str | None:
|
|||||||
if not _has_db():
|
if not _has_db():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
from core.db.detect import get_or_create_brand
|
from core.db import get_or_create_brand
|
||||||
brand, created = get_or_create_brand(canonical_name, source=source)
|
from core.db.connection import get_session
|
||||||
|
|
||||||
|
with get_session() as session:
|
||||||
|
brand, created = get_or_create_brand(session, canonical_name, source=source)
|
||||||
|
session.commit()
|
||||||
if created:
|
if created:
|
||||||
logger.info("New brand discovered: %s (source=%s)", canonical_name, source)
|
logger.info("New brand discovered: %s (source=%s)", canonical_name, source)
|
||||||
return str(brand.id)
|
return str(brand.id)
|
||||||
|
|
||||||
|
|
||||||
def _record_sighting(source_asset_id: str | None, brand_id: str,
|
def _record_airing(timeline_id: str | None, brand_id: str,
|
||||||
brand_name: str, timestamp: float,
|
frame_seq: int, confidence: float, source: str):
|
||||||
confidence: float, source: str):
|
"""Record a brand airing on a timeline."""
|
||||||
"""Record a brand sighting for this source."""
|
if not _has_db() or not timeline_id:
|
||||||
if not _has_db() or not source_asset_id:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
from core.db.detect import record_sighting
|
from core.db import record_airing
|
||||||
import uuid
|
from core.db.connection import get_session
|
||||||
asset_id = uuid.UUID(source_asset_id) if isinstance(source_asset_id, str) else source_asset_id
|
from uuid import UUID
|
||||||
brand_uuid = uuid.UUID(brand_id) if isinstance(brand_id, str) else brand_id
|
|
||||||
record_sighting(asset_id, brand_uuid, brand_name, timestamp, confidence, source)
|
with get_session() as session:
|
||||||
|
record_airing(
|
||||||
|
session,
|
||||||
|
brand_id=UUID(brand_id),
|
||||||
|
timeline_id=UUID(timeline_id),
|
||||||
|
frame_start=frame_seq,
|
||||||
|
frame_end=frame_seq,
|
||||||
|
confidence=confidence,
|
||||||
|
source=source,
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
def build_session_dict(source_asset_id: str | None) -> dict[str, str]:
|
def build_session_dict(source_asset_id: str | None = None) -> dict[str, str]:
|
||||||
"""
|
"""
|
||||||
Load session brands from DB for this source.
|
Load known brands from DB as a session lookup dict.
|
||||||
|
|
||||||
Returns {normalized_name: canonical_name, ...} including aliases.
|
Returns {normalized_name: canonical_name, ...} including aliases.
|
||||||
"""
|
"""
|
||||||
if not _has_db() or not source_asset_id:
|
if not _has_db():
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
from core.db.detect import get_source_sightings
|
from core.db import list_brands
|
||||||
import uuid
|
from core.db.connection import get_session
|
||||||
|
|
||||||
asset_id = uuid.UUID(source_asset_id) if isinstance(source_asset_id, str) else source_asset_id
|
with get_session() as session:
|
||||||
sightings = get_source_sightings(asset_id)
|
all_brands = list_brands(session)
|
||||||
|
|
||||||
session = {}
|
session_dict = {}
|
||||||
for s in sightings:
|
for brand in all_brands:
|
||||||
canonical = s.brand_name
|
session_dict[_normalize(brand.canonical_name)] = brand.canonical_name
|
||||||
session[_normalize(canonical)] = canonical
|
for alias in (brand.aliases or []):
|
||||||
|
session_dict[_normalize(alias)] = brand.canonical_name
|
||||||
|
|
||||||
# Also load aliases from KnownBrand for each sighted brand
|
return session_dict
|
||||||
if _has_db():
|
|
||||||
from core.db.detect import list_all_brands
|
|
||||||
all_brands = list_all_brands()
|
|
||||||
sighted_names = {s.brand_name for s in sightings}
|
|
||||||
for brand in all_brands:
|
|
||||||
if brand.canonical_name in sighted_names:
|
|
||||||
for alias in (brand.aliases or []):
|
|
||||||
session[_normalize(alias)] = brand.canonical_name
|
|
||||||
|
|
||||||
return session
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_brands(
|
def resolve_brands(
|
||||||
@@ -158,7 +147,7 @@ def resolve_brands(
|
|||||||
Match text candidates against known brands (session → global → unresolved).
|
Match text candidates against known brands (session → global → unresolved).
|
||||||
|
|
||||||
session_brands: pre-loaded session dict (from build_session_dict)
|
session_brands: pre-loaded session dict (from build_session_dict)
|
||||||
source_asset_id: for recording new sightings in DB
|
job_id: timeline_id — used to record airings
|
||||||
"""
|
"""
|
||||||
if session_brands is None:
|
if session_brands is None:
|
||||||
session_brands = {}
|
session_brands = {}
|
||||||
@@ -187,7 +176,6 @@ def resolve_brands(
|
|||||||
brand_name, brand_id = _match_known(text, config.fuzzy_threshold)
|
brand_name, brand_id = _match_known(text, config.fuzzy_threshold)
|
||||||
if brand_name:
|
if brand_name:
|
||||||
known_hits += 1
|
known_hits += 1
|
||||||
# Add to session for subsequent candidates in this run
|
|
||||||
session_brands[_normalize(brand_name)] = brand_name
|
session_brands[_normalize(brand_name)] = brand_name
|
||||||
|
|
||||||
if brand_name:
|
if brand_name:
|
||||||
@@ -203,11 +191,10 @@ def resolve_brands(
|
|||||||
)
|
)
|
||||||
matched.append(detection)
|
matched.append(detection)
|
||||||
|
|
||||||
# Record sighting in DB
|
|
||||||
if brand_id:
|
if brand_id:
|
||||||
_record_sighting(
|
_record_airing(
|
||||||
source_asset_id, brand_id, brand_name,
|
job_id, brand_id,
|
||||||
candidate.frame.timestamp, candidate.ocr_confidence, match_source,
|
candidate.frame.sequence, candidate.ocr_confidence, match_source,
|
||||||
)
|
)
|
||||||
|
|
||||||
emit.detection(
|
emit.detection(
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from core.schema.serializers._common import (
|
|||||||
serialize_dataclass,
|
serialize_dataclass,
|
||||||
serialize_dataclass_list,
|
serialize_dataclass_list,
|
||||||
)
|
)
|
||||||
from core.schema.serializers.detect_pipeline import (
|
from core.schema.serializers.pipeline import (
|
||||||
serialize_frame_meta,
|
serialize_frame_meta,
|
||||||
serialize_frames_with_upload as serialize_frames,
|
serialize_frames_with_upload as serialize_frames,
|
||||||
deserialize_frames_with_download as deserialize_frames,
|
deserialize_frames_with_download as deserialize_frames,
|
||||||
|
|||||||
@@ -35,9 +35,11 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
from core.db.detect import list_scenarios
|
from core.db import list_scenarios
|
||||||
|
from core.db.connection import get_session
|
||||||
|
|
||||||
scenarios = list_scenarios()
|
with get_session() as session:
|
||||||
|
scenarios = list_scenarios(session)
|
||||||
|
|
||||||
if not scenarios:
|
if not scenarios:
|
||||||
logger.info("No scenarios found. Create one with:")
|
logger.info("No scenarios found. Create one with:")
|
||||||
|
|||||||
@@ -121,15 +121,14 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Mark as scenario
|
# Mark as scenario
|
||||||
from core.db.detect import get_latest_checkpoint
|
from core.db import get_latest_checkpoint
|
||||||
from core.db.connection import get_session
|
from core.db.connection import get_session
|
||||||
|
|
||||||
checkpoint = get_latest_checkpoint(branch_id)
|
with get_session() as session:
|
||||||
if checkpoint:
|
checkpoint = get_latest_checkpoint(session, branch_id)
|
||||||
checkpoint.is_scenario = True
|
if checkpoint:
|
||||||
checkpoint.scenario_label = args.label
|
checkpoint.is_scenario = True
|
||||||
with get_session() as session:
|
checkpoint.scenario_label = args.label
|
||||||
session.add(checkpoint)
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
logger.info("")
|
logger.info("")
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import pytest
|
|||||||
|
|
||||||
from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate
|
from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate
|
||||||
from core.schema.serializers._common import safe_construct
|
from core.schema.serializers._common import safe_construct
|
||||||
from core.schema.serializers.detect_pipeline import (
|
from core.schema.serializers.pipeline import (
|
||||||
serialize_frame_meta,
|
serialize_frame_meta,
|
||||||
serialize_text_candidate,
|
serialize_text_candidate,
|
||||||
serialize_text_candidates,
|
serialize_text_candidates,
|
||||||
|
|||||||
Reference in New Issue
Block a user