From 51ce14a812f68bbe186fe191c994f1cca0b2ccf3 Mon Sep 17 00:00:00 2001 From: buenosairesam Date: Fri, 27 Mar 2026 06:02:58 -0300 Subject: [PATCH] major refactor --- core/api/detect_replay.py | 118 ++++++------ core/schema/models/__init__.py | 9 +- core/schema/models/brand.py | 38 ++++ core/schema/models/checkpoint.py | 38 ++++ core/schema/models/detect_jobs.py | 177 ------------------ core/schema/models/job.py | 97 +--------- .../{detect_pipeline.py => pipeline.py} | 17 +- core/schema/models/timeline.py | 29 +++ .../{detect_pipeline.py => pipeline.py} | 10 +- detect/checkpoint/serializer.py | 2 +- detect/checkpoint/storage.py | 163 ++++++++-------- detect/graph.py | 23 ++- detect/models.py | 9 +- detect/stages/brand_resolver.py | 121 ++++++------ detect/stages/registry/_serializers.py | 2 +- tests/detect/manual/list_scenarios.py | 6 +- tests/detect/manual/seed_scenario.py | 13 +- tests/detect/test_checkpoint.py | 2 +- 18 files changed, 351 insertions(+), 523 deletions(-) create mode 100644 core/schema/models/brand.py create mode 100644 core/schema/models/checkpoint.py delete mode 100644 core/schema/models/detect_jobs.py rename core/schema/models/{detect_pipeline.py => pipeline.py} (76%) create mode 100644 core/schema/models/timeline.py rename core/schema/serializers/{detect_pipeline.py => pipeline.py} (94%) diff --git a/core/api/detect_replay.py b/core/api/detect_replay.py index 28aea0a..2ab36f6 100644 --- a/core/api/detect_replay.py +++ b/core/api/detect_replay.py @@ -1,7 +1,7 @@ """ API endpoints for checkpoint inspection, replay, retry, and GPU proxy. -GET /detect/checkpoints/{job_id} — list available checkpoints +GET /detect/checkpoints/{timeline_id} — list available checkpoints POST /detect/replay — replay from a stage with config overrides POST /detect/retry — queue async retry with different provider POST /detect/replay-stage — replay single stage (fast path) @@ -31,7 +31,7 @@ class CheckpointInfo(BaseModel): class ScenarioInfo(BaseModel): - job_id: str + timeline_id: str stage: str scenario_label: str profile_name: str @@ -41,21 +41,21 @@ class ScenarioInfo(BaseModel): class ReplayRequest(BaseModel): - job_id: str + timeline_id: str start_stage: str config_overrides: dict | None = None class ReplayResponse(BaseModel): status: str - job_id: str + timeline_id: str start_stage: str detections: int = 0 brands_found: int = 0 class RetryRequest(BaseModel): - job_id: str + timeline_id: str config_overrides: dict | None = None start_stage: str = "escalate_vlm" schedule_seconds: float | None = None # delay before execution (off-peak) @@ -64,11 +64,11 @@ class RetryRequest(BaseModel): class RetryResponse(BaseModel): status: str task_id: str - job_id: str + timeline_id: str class ReplaySingleStageRequest(BaseModel): - job_id: str + timeline_id: str stage: str frame_refs: list[int] | None = None config_overrides: dict | None = None @@ -102,15 +102,15 @@ class ReplaySingleStageResponse(BaseModel): # --- Endpoints --- -@router.get("/checkpoints/{job_id}") -def list_checkpoints(job_id: str) -> list[CheckpointInfo]: +@router.get("/checkpoints/{timeline_id}") +def list_checkpoints(timeline_id: str) -> list[CheckpointInfo]: """List available checkpoint stages for a job.""" from detect.checkpoint import list_checkpoints as _list try: - stages = _list(job_id) + stages = _list(timeline_id) except Exception as e: - raise HTTPException(status_code=404, detail=f"No checkpoints for job {job_id}: {e}") + raise HTTPException(status_code=404, detail=f"No checkpoints for job {timeline_id}: {e}") result = [CheckpointInfo(stage=s) for s in stages] return result @@ -123,7 +123,7 @@ class CheckpointFrameInfo(BaseModel): class CheckpointData(BaseModel): - job_id: str + timeline_id: str stage: str profile_name: str video_path: str @@ -135,26 +135,32 @@ class CheckpointData(BaseModel): stage_output_key: str = "" -@router.get("/checkpoints/{job_id}/{stage}", response_model=CheckpointData) -def get_checkpoint_data(job_id: str, stage: str): +@router.get("/checkpoints/{timeline_id}/{stage}", response_model=CheckpointData) +def get_checkpoint_data(timeline_id: str, stage: str): """Load checkpoint frames + metadata for the editor UI.""" - from core.db.detect import get_stage_checkpoint + from uuid import UUID + from core.db.tables import Timeline, Checkpoint + from core.db.connection import get_session + from core.db.checkpoint import list_checkpoints from detect.checkpoint.frames import load_frames_b64 - checkpoint = get_stage_checkpoint(job_id, stage) - if not checkpoint: - raise HTTPException(status_code=404, detail=f"No checkpoint for {job_id}/{stage}") + with get_session() as session: + timeline = session.get(Timeline, UUID(timeline_id)) + if not timeline: + raise HTTPException(status_code=404, detail=f"Timeline not found: {timeline_id}") - raw_manifest = checkpoint.frames_manifest or {} - manifest = {int(k): v for k, v in raw_manifest.items()} - frame_metadata = checkpoint.frames_meta or [] + checkpoints = list_checkpoints(session, UUID(timeline_id)) + if not checkpoints: + raise HTTPException(status_code=404, detail=f"No checkpoints for timeline {timeline_id}") + # Prefer a checkpoint that has this stage's output; fall back to latest + checkpoint = next( + (c for c in reversed(checkpoints) if stage in (c.stage_outputs or {})), + checkpoints[-1], + ) - # Only load filtered frames if available, otherwise all - filtered = set(checkpoint.filtered_frame_sequences or []) - if filtered: - manifest = {k: v for k, v in manifest.items() if k in filtered} - - frames_b64 = load_frames_b64(manifest, frame_metadata) + raw_manifest = timeline.frames_manifest or {} + manifest = {int(k): v for k, v in raw_manifest.items()} + frames_b64 = load_frames_b64(manifest, timeline.frames_meta or []) frame_list = [ CheckpointFrameInfo(seq=f["seq"], timestamp=f["timestamp"], jpeg_b64=f["jpeg_b64"]) @@ -162,38 +168,44 @@ def get_checkpoint_data(job_id: str, stage: str): ] return CheckpointData( - job_id=str(checkpoint.job_id), - stage=checkpoint.stage, - profile_name=checkpoint.profile_name, - video_path=checkpoint.video_path, + timeline_id=timeline_id, + stage=stage, + profile_name=timeline.profile_name, + video_path=timeline.source_video, is_scenario=checkpoint.is_scenario, scenario_label=checkpoint.scenario_label, frames=frame_list, stats=checkpoint.stats or {}, - config_snapshot=checkpoint.config_snapshot or {}, - stage_output_key=checkpoint.stage_output_key or "", + config_snapshot=checkpoint.config_overrides or {}, + stage_output_key=stage, ) @router.get("/scenarios", response_model=list[ScenarioInfo]) def list_scenarios_endpoint(): """List all available scenarios (bookmarked checkpoints).""" - from core.db.detect import list_scenarios + from core.db.tables import Timeline + from core.db.connection import get_session + from core.db.checkpoint import list_scenarios - scenarios = list_scenarios() - result = [] - for s in scenarios: - manifest = s.frames_manifest or {} - info = ScenarioInfo( - job_id=str(s.job_id), - stage=s.stage, - scenario_label=s.scenario_label, - profile_name=s.profile_name, - video_path=s.video_path, - frame_count=len(manifest), - created_at=str(s.created_at) if s.created_at else "", - ) - result.append(info) + with get_session() as session: + scenarios = list_scenarios(session) + result = [] + for s in scenarios: + timeline = session.get(Timeline, s.timeline_id) + if not timeline: + continue + last_stage = next(reversed(s.stage_outputs), "") if s.stage_outputs else "" + info = ScenarioInfo( + timeline_id=str(s.timeline_id), + stage=last_stage, + scenario_label=s.scenario_label, + profile_name=timeline.profile_name, + video_path=timeline.source_video, + frame_count=len(timeline.frames_manifest or {}), + created_at=str(s.created_at) if s.created_at else "", + ) + result.append(info) return result @@ -204,7 +216,7 @@ def replay(req: ReplayRequest): try: result = replay_from( - job_id=req.job_id, + timeline_id=req.timeline_id, start_stage=req.start_stage, config_overrides=req.config_overrides, ) @@ -219,7 +231,7 @@ def replay(req: ReplayRequest): response = ReplayResponse( status="completed", - job_id=req.job_id, + timeline_id=req.timeline_id, start_stage=req.start_stage, detections=len(detections), brands_found=brands_found, @@ -233,7 +245,7 @@ def retry(req: RetryRequest): from detect.checkpoint.tasks import retry_candidates kwargs = { - "job_id": req.job_id, + "timeline_id": req.timeline_id, "config_overrides": req.config_overrides, "start_stage": req.start_stage, } @@ -246,7 +258,7 @@ def retry(req: RetryRequest): response = RetryResponse( status="queued", task_id=task.id, - job_id=req.job_id, + timeline_id=req.timeline_id, ) return response @@ -258,7 +270,7 @@ def replay_single_stage(req: ReplaySingleStageRequest): try: result = _replay( - job_id=req.job_id, + timeline_id=req.timeline_id, stage=req.stage, frame_refs=req.frame_refs, config_overrides=req.config_overrides, diff --git a/core/schema/models/__init__.py b/core/schema/models/__init__.py index 2d96d1f..9a076fb 100644 --- a/core/schema/models/__init__.py +++ b/core/schema/models/__init__.py @@ -25,11 +25,10 @@ from .grpc import ( ProgressUpdate, WorkerStatus, ) -from .job import ( - Job, JobStatus, RunType, - Timeline, Checkpoint, - BrandSource, Brand, -) +from .job import Job, JobStatus, RunType +from .timeline import Timeline +from .checkpoint import Checkpoint +from .brand import BrandSource, Brand from .media import AssetStatus, MediaAsset from .presets import BUILTIN_PRESETS, TranscodePreset from .detect import DETECT_VIEWS # noqa: F401 — discovered by modelgen generic loader diff --git a/core/schema/models/brand.py b/core/schema/models/brand.py new file mode 100644 index 0000000..bc5dbeb --- /dev/null +++ b/core/schema/models/brand.py @@ -0,0 +1,38 @@ +"""Brand schema — source of truth for brand discovery.""" + +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional +from uuid import UUID + + +class BrandSource(str, Enum): + OCR = "ocr" + VLM = "local_vlm" + CLOUD = "cloud_llm" + MANUAL = "manual" + + +@dataclass +class Brand: + """ + A brand discovered or registered in the system. + + Airings track where/when the brand appeared — each airing + references a timeline and a frame range. + """ + + id: UUID + canonical_name: str + aliases: List[str] = field(default_factory=list) + source: BrandSource = BrandSource.OCR # how first discovered + confirmed: bool = False + + # Airings — JSONB array of appearances + # [{timeline_id, frame_start, frame_end, confidence, source, timestamp}] + airings: List[Dict[str, Any]] = field(default_factory=list) + total_airings: int = 0 + + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None diff --git a/core/schema/models/checkpoint.py b/core/schema/models/checkpoint.py new file mode 100644 index 0000000..3e9b594 --- /dev/null +++ b/core/schema/models/checkpoint.py @@ -0,0 +1,38 @@ +"""Checkpoint schema — source of truth for pipeline state snapshots.""" + +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, Dict, Optional +from uuid import UUID + + +@dataclass +class Checkpoint: + """ + A snapshot of pipeline state on a timeline. + + Stage outputs stored as JSONB — each stage serializes to JSON, + the checkpoint stores it without knowing the shape. + + parent_id forms a tree: multiple children from the same parent + = different config tries from the same starting point. + """ + + id: UUID + timeline_id: UUID + parent_id: Optional[UUID] = None # null = root checkpoint + + # Stage outputs — JSONB per stage, opaque to the checkpoint layer + stage_outputs: Dict[str, Any] = field(default_factory=dict) + + # Config that produced this checkpoint + config_overrides: Dict[str, Any] = field(default_factory=dict) + + # Pipeline state + stats: Dict[str, Any] = field(default_factory=dict) + + # Scenario bookmark + is_scenario: bool = False + scenario_label: str = "" + + created_at: Optional[datetime] = None diff --git a/core/schema/models/detect_jobs.py b/core/schema/models/detect_jobs.py deleted file mode 100644 index e14e730..0000000 --- a/core/schema/models/detect_jobs.py +++ /dev/null @@ -1,177 +0,0 @@ -""" -Detection Job and Checkpoint Schema Definitions - -Source of truth for detection pipeline job tracking and stage checkpoints. -Follows the TranscodeJob/ChunkJob pattern. -""" - -from dataclasses import dataclass, field -from datetime import datetime -from enum import Enum -from typing import Any, Dict, List, Optional -from uuid import UUID - - -class DetectJobStatus(str, Enum): - PENDING = "pending" - RUNNING = "running" - PAUSED = "paused" - COMPLETED = "completed" - FAILED = "failed" - CANCELLED = "cancelled" - - -class RunType(str, Enum): - INITIAL = "initial" - REPLAY = "replay" - RETRY = "retry" - - -@dataclass -class DetectJob: - """ - A detection pipeline job. - - Each invocation of the pipeline (initial run, replay, retry) creates a DetectJob. - Jobs for the same source video are linked via parent_job_id. - """ - - id: UUID - - # Input - source_asset_id: UUID - video_path: str - profile_name: str = "soccer_broadcast" - - # Run lineage - parent_job_id: Optional[UUID] = None # links all runs for the same source - run_type: RunType = RunType.INITIAL - replay_from_stage: Optional[str] = None # null for initial runs - config_overrides: Dict[str, Any] = field(default_factory=dict) - - # Status - status: DetectJobStatus = DetectJobStatus.PENDING - current_stage: Optional[str] = None - progress: float = 0.0 - error_message: Optional[str] = None - - # Results summary - total_detections: int = 0 - brands_found: int = 0 - cloud_llm_calls: int = 0 - estimated_cost_usd: float = 0.0 - - # Worker tracking - celery_task_id: Optional[str] = None - priority: int = 0 - - # Timestamps - created_at: Optional[datetime] = None - started_at: Optional[datetime] = None - completed_at: Optional[datetime] = None - - -@dataclass -class Timeline: - """ - The frame sequence from a source video. - - Independent of stages — exists before any stage runs. - Stages annotate the timeline, they don't own it. - Frames are stored in MinIO as JPEGs. - """ - - id: UUID - source_asset_id: Optional[UUID] = None - source_video: str = "" - profile_name: str = "" - fps: float = 2.0 - - # Frame metadata (images in MinIO, metadata here) - frames_prefix: str = "" # s3: timelines/{id}/frames/ - frames_manifest: Dict[int, str] = field(default_factory=dict) # seq → s3 key - frames_meta: List[Dict[str, Any]] = field(default_factory=list) - - created_at: Optional[datetime] = None - - -@dataclass -class Checkpoint: - """ - A snapshot of pipeline state on a timeline. - - Stage outputs stored as JSONB — each stage serializes to JSON, - the checkpoint stores it without knowing the shape. - - parent_id forms a tree: multiple children from the same parent - = different config tries from the same starting point. - """ - - id: UUID - timeline_id: UUID - parent_id: Optional[UUID] = None # null = root checkpoint - - # Stage outputs — JSONB per stage, opaque to the checkpoint layer - stage_outputs: Dict[str, Any] = field(default_factory=dict) - - # Config that produced this checkpoint - config_overrides: Dict[str, Any] = field(default_factory=dict) - - # Pipeline state - stats: Dict[str, Any] = field(default_factory=dict) - - # Scenario bookmark - is_scenario: bool = False - scenario_label: str = "" - - created_at: Optional[datetime] = None - - -class BrandSource(str, Enum): - """How a brand was first identified.""" - OCR = "ocr" - VLM = "local_vlm" - CLOUD = "cloud_llm" - MANUAL = "manual" # user-added via UI - - -@dataclass -class KnownBrand: - """ - A brand discovered or registered in the system. - - Global — not per-source. Accumulates across all pipeline runs. - Aliases enable fuzzy matching without re-escalating to VLM. - """ - - id: UUID - canonical_name: str # normalized display name - aliases: List[str] = field(default_factory=list) # known spellings/variants - first_source: BrandSource = BrandSource.OCR - total_occurrences: int = 0 - confirmed: bool = False # manually confirmed by user - - created_at: Optional[datetime] = None - updated_at: Optional[datetime] = None - - -@dataclass -class SourceBrandSighting: - """ - A brand seen in a specific source (video/asset). - - Per-source session cache — avoids re-escalating the same brand - on subsequent frames or re-runs of the same source. - """ - - id: UUID - source_asset_id: UUID # the video this sighting belongs to - brand_id: UUID # FK to KnownBrand - brand_name: str # denormalized for fast lookup - first_seen_timestamp: float = 0.0 - last_seen_timestamp: float = 0.0 - occurrences: int = 0 - detection_source: BrandSource = BrandSource.OCR - avg_confidence: float = 0.0 - - created_at: Optional[datetime] = None diff --git a/core/schema/models/job.py b/core/schema/models/job.py index a3b5391..013fa56 100644 --- a/core/schema/models/job.py +++ b/core/schema/models/job.py @@ -1,14 +1,9 @@ -""" -Job, Timeline, and Checkpoint Schema Definitions - -Source of truth for pipeline jobs, timelines, and checkpoints. -Generates: SQLModel (core/db/models.py), TypeScript via modelgen. -""" +"""Job schema — source of truth for pipeline jobs.""" from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional from uuid import UUID @@ -68,91 +63,3 @@ class Job: created_at: Optional[datetime] = None started_at: Optional[datetime] = None completed_at: Optional[datetime] = None - - -@dataclass -class Timeline: - """ - The frame sequence from a source video. - - Independent of stages — exists before any stage runs. - Frames stored in MinIO as JPEGs, metadata here. - One timeline per job. - """ - - id: UUID - source_asset_id: Optional[UUID] = None - source_video: str = "" - profile_name: str = "" - fps: float = 2.0 - - frames_prefix: str = "" # s3: timeline/{id}/frames/ - frames_manifest: Dict[int, str] = field(default_factory=dict) # seq → s3 key - frames_meta: List[Dict[str, Any]] = field(default_factory=list) - - created_at: Optional[datetime] = None - - -@dataclass -class Checkpoint: - """ - A snapshot of pipeline state on a timeline. - - Stage outputs stored as JSONB — each stage serializes to JSON, - the checkpoint stores it without knowing the shape. - - parent_id forms a tree: multiple children from the same parent - = different config tries from the same starting point. - """ - - id: UUID - timeline_id: UUID - parent_id: Optional[UUID] = None # null = root checkpoint - - # Stage outputs — JSONB per stage, opaque to the checkpoint layer - stage_outputs: Dict[str, Any] = field(default_factory=dict) - - # Config that produced this checkpoint - config_overrides: Dict[str, Any] = field(default_factory=dict) - - # Pipeline state - stats: Dict[str, Any] = field(default_factory=dict) - - # Scenario bookmark - is_scenario: bool = False - scenario_label: str = "" - - created_at: Optional[datetime] = None - - -# --- Brands --- - -class BrandSource(str, Enum): - OCR = "ocr" - VLM = "local_vlm" - CLOUD = "cloud_llm" - MANUAL = "manual" - - -@dataclass -class Brand: - """ - A brand discovered or registered in the system. - - Airings track where/when the brand appeared — each airing - references a timeline and a frame range. - """ - - id: UUID - canonical_name: str - aliases: List[str] = field(default_factory=list) - source: BrandSource = BrandSource.OCR # how first discovered - confirmed: bool = False - - # Airings — JSONB array of appearances - # [{timeline_id, frame_start, frame_end, confidence, source, timestamp}] - airings: List[Dict[str, Any]] = field(default_factory=list) - total_airings: int = 0 - - created_at: Optional[datetime] = None - updated_at: Optional[datetime] = None diff --git a/core/schema/models/detect_pipeline.py b/core/schema/models/pipeline.py similarity index 76% rename from core/schema/models/detect_pipeline.py rename to core/schema/models/pipeline.py index 84cdd57..8a8c763 100644 --- a/core/schema/models/detect_pipeline.py +++ b/core/schema/models/pipeline.py @@ -1,13 +1,9 @@ """ Detection pipeline runtime models. -These are the data structures that flow between LangGraph nodes. -They contain runtime types (np.ndarray) so they are NOT generated -by modelgen — they live here for the schema to be the complete -map of the application, but modelgen skips them. - -Wire-format models (SSE events) are in detect.py. -DB models (jobs, checkpoints) are in detect_jobs.py. +These are the data structures that flow between pipeline stages. +They contain runtime types (np.ndarray) so modelgen skips them — +not generated to SQLModel or TypeScript. """ from __future__ import annotations @@ -89,10 +85,3 @@ class DetectionReport: brands: dict[str, BrandStats] = field(default_factory=dict) timeline: list[BrandDetection] = field(default_factory=list) pipeline_stats: PipelineStats = field(default_factory=PipelineStats) - - -# Not in DATACLASSES — modelgen skips these (they contain np.ndarray) -RUNTIME_MODELS = [ - Frame, BoundingBox, TextCandidate, BrandDetection, - BrandStats, PipelineStats, DetectionReport, -] diff --git a/core/schema/models/timeline.py b/core/schema/models/timeline.py new file mode 100644 index 0000000..dcca8b3 --- /dev/null +++ b/core/schema/models/timeline.py @@ -0,0 +1,29 @@ +"""Timeline schema — source of truth for frame sequences.""" + +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, Dict, List, Optional +from uuid import UUID + + +@dataclass +class Timeline: + """ + The frame sequence from a source video. + + Independent of stages — exists before any stage runs. + Frames stored in MinIO as JPEGs, metadata here. + One timeline per job. + """ + + id: UUID + source_asset_id: Optional[UUID] = None + source_video: str = "" + profile_name: str = "" + fps: float = 2.0 + + frames_prefix: str = "" # s3: timeline/{id}/frames/ + frames_manifest: Dict[int, str] = field(default_factory=dict) # seq → s3 key + frames_meta: List[Dict[str, Any]] = field(default_factory=list) + + created_at: Optional[datetime] = None diff --git a/core/schema/serializers/detect_pipeline.py b/core/schema/serializers/pipeline.py similarity index 94% rename from core/schema/serializers/detect_pipeline.py rename to core/schema/serializers/pipeline.py index 9738cb4..e8b1440 100644 --- a/core/schema/serializers/detect_pipeline.py +++ b/core/schema/serializers/pipeline.py @@ -1,8 +1,6 @@ """ Serializers for detection pipeline runtime models. -Mirrors core/schema/models/detect_pipeline.py. - Special handling: - Frame.image (np.ndarray → S3, excluded from JSON) - TextCandidate.frame (object ref → frame_sequence integer) @@ -13,7 +11,7 @@ from __future__ import annotations import dataclasses -from core.schema.models.detect_pipeline import ( +from core.schema.models.pipeline import ( BoundingBox, BrandDetection, BrandStats, @@ -59,13 +57,12 @@ def deserialize_frames_with_download(meta: list[dict], manifest: dict, job_id: s def serialize_text_candidate(tc: TextCandidate) -> dict: bbox_dict = dataclasses.asdict(tc.bbox) - result = { + return { "frame_sequence": tc.frame.sequence, "bbox": bbox_dict, "text": tc.text, "ocr_confidence": tc.ocr_confidence, } - return result def serialize_text_candidates(candidates: list[TextCandidate]) -> list[dict]: @@ -75,13 +72,12 @@ def serialize_text_candidates(candidates: list[TextCandidate]) -> list[dict]: def deserialize_text_candidate(data: dict, frame_map: dict[int, Frame]) -> TextCandidate: frame = frame_map[data["frame_sequence"]] bbox = safe_construct(BoundingBox, data["bbox"]) - candidate = TextCandidate( + return TextCandidate( frame=frame, bbox=bbox, text=data["text"], ocr_confidence=data["ocr_confidence"], ) - return candidate def deserialize_text_candidates(data: list[dict], frame_map: dict[int, Frame]) -> list[TextCandidate]: diff --git a/detect/checkpoint/serializer.py b/detect/checkpoint/serializer.py index 087ca9a..466b1eb 100644 --- a/detect/checkpoint/serializer.py +++ b/detect/checkpoint/serializer.py @@ -11,7 +11,7 @@ that don't belong to any stage. from __future__ import annotations from core.schema.serializers._common import serialize_dataclass -from core.schema.serializers.detect_pipeline import ( +from core.schema.serializers.pipeline import ( deserialize_pipeline_stats, deserialize_text_candidates, ) diff --git a/detect/checkpoint/storage.py b/detect/checkpoint/storage.py index f098480..f0b6aad 100644 --- a/detect/checkpoint/storage.py +++ b/detect/checkpoint/storage.py @@ -33,83 +33,81 @@ def create_timeline( Returns (timeline_id, checkpoint_id). """ - from core.db.detect import create_timeline as db_create_timeline - from core.db.detect import save_checkpoint - - # Create timeline - timeline = db_create_timeline( - source_video=source_video, - profile_name=profile_name, - source_asset_id=source_asset_id, - fps=fps, - ) - tid = str(timeline.id) - - # Upload frames to MinIO - manifest = save_frames(tid, frames) - - # Store frame metadata on the timeline - frames_meta = [ - { - "sequence": f.sequence, - "chunk_id": getattr(f, "chunk_id", 0), - "timestamp": f.timestamp, - "perceptual_hash": getattr(f, "perceptual_hash", ""), - } - for f in frames - ] - - timeline.frames_prefix = f"{CHECKPOINT_PREFIX}/{tid}/frames/" - timeline.frames_manifest = {str(k): v for k, v in manifest.items()} - timeline.frames_meta = frames_meta - + from core.db.tables import Timeline, Checkpoint from core.db.connection import get_session + with get_session() as session: + timeline = Timeline( + source_video=source_video, + profile_name=profile_name, + source_asset_id=source_asset_id, + fps=fps, + ) session.add(timeline) + session.flush() + tid = str(timeline.id) + + # Upload frames to MinIO + manifest = save_frames(tid, frames) + + frames_meta = [ + { + "sequence": f.sequence, + "chunk_id": getattr(f, "chunk_id", 0), + "timestamp": f.timestamp, + "perceptual_hash": getattr(f, "perceptual_hash", ""), + } + for f in frames + ] + + timeline.frames_prefix = f"{CHECKPOINT_PREFIX}/{tid}/frames/" + timeline.frames_manifest = {str(k): v for k, v in manifest.items()} + timeline.frames_meta = frames_meta + + checkpoint = Checkpoint( + timeline_id=timeline.id, + parent_id=None, + stage_outputs={}, + stats={"frames_extracted": len(frames)}, + ) + session.add(checkpoint) session.commit() + session.refresh(checkpoint) + cid = str(checkpoint.id) - # Create root checkpoint (no parent, no stage outputs yet) - checkpoint = save_checkpoint( - timeline_id=timeline.id, - parent_id=None, - stage_outputs={}, - stats={"frames_extracted": len(frames)}, - ) - - logger.info("Timeline created: %s (%d frames, root checkpoint %s)", - tid, len(frames), checkpoint.id) - return tid, str(checkpoint.id) + logger.info("Timeline created: %s (%d frames, root checkpoint %s)", tid, len(frames), cid) + return tid, cid def get_timeline_frames(timeline_id: str) -> list: """Load frames from a timeline (from MinIO) as Frame objects.""" - from core.db.detect import get_timeline + from core.db.tables import Timeline + from core.db.connection import get_session - timeline = get_timeline(timeline_id) + with get_session() as session: + timeline = session.get(Timeline, UUID(timeline_id)) if not timeline: raise ValueError(f"Timeline not found: {timeline_id}") raw_manifest = timeline.frames_manifest or {} manifest = {int(k): v for k, v in raw_manifest.items()} - frame_metadata = timeline.frames_meta or [] - - return load_frames(manifest, frame_metadata) + return load_frames(manifest, timeline.frames_meta or []) def get_timeline_frames_b64(timeline_id: str) -> list[dict]: """Load frames as base64 JPEG (lightweight, no numpy).""" - from core.db.detect import get_timeline + from core.db.tables import Timeline + from core.db.connection import get_session from .frames import load_frames_b64 - timeline = get_timeline(timeline_id) + with get_session() as session: + timeline = session.get(Timeline, UUID(timeline_id)) if not timeline: raise ValueError(f"Timeline not found: {timeline_id}") raw_manifest = timeline.frames_manifest or {} manifest = {int(k): v for k, v in raw_manifest.items()} - frame_metadata = timeline.frames_meta or [] - - return load_frames_b64(manifest, frame_metadata) + return load_frames_b64(manifest, timeline.frames_meta or []) # --------------------------------------------------------------------------- @@ -132,47 +130,46 @@ def save_stage_output( Carries forward stage outputs from parent + adds the new one. Returns the new checkpoint ID. """ - from core.db.detect import get_checkpoint, save_checkpoint + from core.db.tables import Checkpoint + from core.db.connection import get_session - # Carry forward from parent - parent_outputs = {} - parent_stats = {} - parent_config = {} - if parent_checkpoint_id: - parent = get_checkpoint(parent_checkpoint_id) - if parent: - parent_outputs = dict(parent.stage_outputs or {}) - parent_stats = dict(parent.stats or {}) - parent_config = dict(parent.config_overrides or {}) + with get_session() as session: + parent_outputs = {} + parent_stats = {} + parent_config = {} + if parent_checkpoint_id: + parent = session.get(Checkpoint, UUID(parent_checkpoint_id)) + if parent: + parent_outputs = dict(parent.stage_outputs or {}) + parent_stats = dict(parent.stats or {}) + parent_config = dict(parent.config_overrides or {}) - # Add new stage output - stage_outputs = {**parent_outputs, stage_name: output_json} - - # Merge stats and config - merged_stats = {**parent_stats, **(stats or {})} - merged_config = {**parent_config, **(config_overrides or {})} - - checkpoint = save_checkpoint( - timeline_id=timeline_id, - parent_id=parent_checkpoint_id, - stage_outputs=stage_outputs, - config_overrides=merged_config, - stats=merged_stats, - is_scenario=is_scenario, - scenario_label=scenario_label, - ) + checkpoint = Checkpoint( + timeline_id=UUID(timeline_id), + parent_id=UUID(parent_checkpoint_id) if parent_checkpoint_id else None, + stage_outputs={**parent_outputs, stage_name: output_json}, + config_overrides={**parent_config, **(config_overrides or {})}, + stats={**parent_stats, **(stats or {})}, + is_scenario=is_scenario, + scenario_label=scenario_label, + ) + session.add(checkpoint) + session.commit() + session.refresh(checkpoint) + cid = str(checkpoint.id) logger.info("Checkpoint saved: %s (timeline %s, stage %s, parent %s)", - checkpoint.id, timeline_id, stage_name, parent_checkpoint_id) - return str(checkpoint.id) + cid, timeline_id, stage_name, parent_checkpoint_id) + return cid def load_stage_output(checkpoint_id: str, stage_name: str) -> dict | None: """Load a stage's output from a checkpoint.""" - from core.db.detect import get_checkpoint + from core.db.tables import Checkpoint + from core.db.connection import get_session - checkpoint = get_checkpoint(checkpoint_id) + with get_session() as session: + checkpoint = session.get(Checkpoint, UUID(checkpoint_id)) if not checkpoint: return None - return (checkpoint.stage_outputs or {}).get(stage_name) diff --git a/detect/graph.py b/detect/graph.py index 5bfb603..4f0434d 100644 --- a/detect/graph.py +++ b/detect/graph.py @@ -326,6 +326,7 @@ def node_compile_report(state: DetectState) -> dict: _CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1" _frames_manifest: dict[str, dict[int, str]] = {} # job_id → manifest (cached per job) +_latest_checkpoint: dict[str, str] = {} # job_id → latest checkpoint_id class PipelineCancelled(Exception): @@ -361,17 +362,33 @@ def _checkpointing_node(node_name: str, node_fn): if not job_id: return result - from detect.checkpoint import save_checkpoint, save_frames + from detect.checkpoint import save_stage_output, save_frames + from detect.stages.base import _REGISTRY merged = {**state, **result} - # Save frames once (first checkpoint), reuse manifest after + # Save frames once (first node), reuse manifest after manifest = _frames_manifest.get(job_id) if manifest is None and node_name == "extract_frames": manifest = save_frames(job_id, merged.get("frames", [])) _frames_manifest[job_id] = manifest - save_checkpoint(job_id, node_name, stage_index, merged, frames_manifest=manifest) + # Serialize stage output using the stage's serialize_fn if available + stage_cls = _REGISTRY.get(node_name) + serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None) + if serialize_fn: + output_json = serialize_fn(merged, job_id) + else: + output_json = {} + + parent_id = _latest_checkpoint.get(job_id) + new_checkpoint_id = save_stage_output( + timeline_id=job_id, + parent_checkpoint_id=parent_id, + stage_name=node_name, + output_json=output_json, + ) + _latest_checkpoint[job_id] = new_checkpoint_id return result wrapper.__name__ = node_fn.__name__ diff --git a/detect/models.py b/detect/models.py index c1b6889..2e2aeca 100644 --- a/detect/models.py +++ b/detect/models.py @@ -1,11 +1,6 @@ -""" -Re-export pipeline runtime models from core/schema/models/detect_pipeline.py. +"""Re-export pipeline runtime models from core/schema/models/pipeline.py.""" -All models are defined in core/schema/ — this module exists for backward -compatibility so existing imports (from detect.models import Frame) keep working. -""" - -from core.schema.models.detect_pipeline import ( +from core.schema.models.pipeline import ( BoundingBox, BrandDetection, BrandStats, diff --git a/detect/stages/brand_resolver.py b/detect/stages/brand_resolver.py index a9c1e7b..7bc4590 100644 --- a/detect/stages/brand_resolver.py +++ b/detect/stages/brand_resolver.py @@ -4,14 +4,10 @@ Stage 5 — Brand Resolver (discovery mode) Discovery-first brand matching. No static dictionary — all brands live in the DB. Flow: - 1. Check session sightings first (brands already seen in this source) + 1. Check session brands first (brands already seen in this run, in-memory) 2. Check global known brands (accumulated across all runs) 3. Unresolved candidates → escalate to VLM/cloud 4. Confirmed brands get added to DB for future runs - -The resolver is an enricher, not a gatekeeper. Every OCR text candidate -passes through — the question is whether we can resolve it cheaply (DB lookup) -or need to escalate (VLM/cloud). """ from __future__ import annotations @@ -33,41 +29,30 @@ def _normalize(text: str) -> str: def _has_db() -> bool: try: - from core.db.detect import find_brand_by_text as _ - from admin.mpr.media_assets.models import KnownBrand as _ + from core.db import find_brand_by_text as _ return True except (ImportError, Exception): return False def _match_session(text: str, session_brands: dict[str, str]) -> str | None: - """ - Check against session brands (already seen in this source). - - session_brands: {normalized_name: canonical_name, ...} - Includes aliases. - """ - normalized = _normalize(text) - return session_brands.get(normalized) + return session_brands.get(_normalize(text)) def _match_known(text: str, threshold: int) -> tuple[str | None, str | None]: - """ - Check against global known brands in DB. - - Returns (canonical_name, brand_id) or (None, None). - """ + """Check against global known brands in DB. Returns (canonical_name, brand_id) or (None, None).""" if not _has_db(): return None, None - from core.db.detect import find_brand_by_text - brand = find_brand_by_text(text) - if brand: - return brand.canonical_name, str(brand.id) + from core.db import find_brand_by_text, list_brands + from core.db.connection import get_session - # Fuzzy match against all known brands - from core.db.detect import list_all_brands - all_brands = list_all_brands() + with get_session() as session: + brand = find_brand_by_text(session, text) + if brand: + return brand.canonical_name, str(brand.id) + + all_brands = list_brands(session) normalized = _normalize(text) best_brand = None @@ -92,58 +77,62 @@ def _register_brand(canonical_name: str, source: str) -> str | None: if not _has_db(): return None - from core.db.detect import get_or_create_brand - brand, created = get_or_create_brand(canonical_name, source=source) + from core.db import get_or_create_brand + from core.db.connection import get_session + + with get_session() as session: + brand, created = get_or_create_brand(session, canonical_name, source=source) + session.commit() if created: logger.info("New brand discovered: %s (source=%s)", canonical_name, source) return str(brand.id) -def _record_sighting(source_asset_id: str | None, brand_id: str, - brand_name: str, timestamp: float, - confidence: float, source: str): - """Record a brand sighting for this source.""" - if not _has_db() or not source_asset_id: +def _record_airing(timeline_id: str | None, brand_id: str, + frame_seq: int, confidence: float, source: str): + """Record a brand airing on a timeline.""" + if not _has_db() or not timeline_id: return - from core.db.detect import record_sighting - import uuid - asset_id = uuid.UUID(source_asset_id) if isinstance(source_asset_id, str) else source_asset_id - brand_uuid = uuid.UUID(brand_id) if isinstance(brand_id, str) else brand_id - record_sighting(asset_id, brand_uuid, brand_name, timestamp, confidence, source) + from core.db import record_airing + from core.db.connection import get_session + from uuid import UUID + + with get_session() as session: + record_airing( + session, + brand_id=UUID(brand_id), + timeline_id=UUID(timeline_id), + frame_start=frame_seq, + frame_end=frame_seq, + confidence=confidence, + source=source, + ) + session.commit() -def build_session_dict(source_asset_id: str | None) -> dict[str, str]: +def build_session_dict(source_asset_id: str | None = None) -> dict[str, str]: """ - Load session brands from DB for this source. + Load known brands from DB as a session lookup dict. Returns {normalized_name: canonical_name, ...} including aliases. """ - if not _has_db() or not source_asset_id: + if not _has_db(): return {} - from core.db.detect import get_source_sightings - import uuid + from core.db import list_brands + from core.db.connection import get_session - asset_id = uuid.UUID(source_asset_id) if isinstance(source_asset_id, str) else source_asset_id - sightings = get_source_sightings(asset_id) + with get_session() as session: + all_brands = list_brands(session) - session = {} - for s in sightings: - canonical = s.brand_name - session[_normalize(canonical)] = canonical + session_dict = {} + for brand in all_brands: + session_dict[_normalize(brand.canonical_name)] = brand.canonical_name + for alias in (brand.aliases or []): + session_dict[_normalize(alias)] = brand.canonical_name - # Also load aliases from KnownBrand for each sighted brand - if _has_db(): - from core.db.detect import list_all_brands - all_brands = list_all_brands() - sighted_names = {s.brand_name for s in sightings} - for brand in all_brands: - if brand.canonical_name in sighted_names: - for alias in (brand.aliases or []): - session[_normalize(alias)] = brand.canonical_name - - return session + return session_dict def resolve_brands( @@ -158,7 +147,7 @@ def resolve_brands( Match text candidates against known brands (session → global → unresolved). session_brands: pre-loaded session dict (from build_session_dict) - source_asset_id: for recording new sightings in DB + job_id: timeline_id — used to record airings """ if session_brands is None: session_brands = {} @@ -187,7 +176,6 @@ def resolve_brands( brand_name, brand_id = _match_known(text, config.fuzzy_threshold) if brand_name: known_hits += 1 - # Add to session for subsequent candidates in this run session_brands[_normalize(brand_name)] = brand_name if brand_name: @@ -203,11 +191,10 @@ def resolve_brands( ) matched.append(detection) - # Record sighting in DB if brand_id: - _record_sighting( - source_asset_id, brand_id, brand_name, - candidate.frame.timestamp, candidate.ocr_confidence, match_source, + _record_airing( + job_id, brand_id, + candidate.frame.sequence, candidate.ocr_confidence, match_source, ) emit.detection( diff --git a/detect/stages/registry/_serializers.py b/detect/stages/registry/_serializers.py index 1375e6c..c4bf038 100644 --- a/detect/stages/registry/_serializers.py +++ b/detect/stages/registry/_serializers.py @@ -10,7 +10,7 @@ from core.schema.serializers._common import ( serialize_dataclass, serialize_dataclass_list, ) -from core.schema.serializers.detect_pipeline import ( +from core.schema.serializers.pipeline import ( serialize_frame_meta, serialize_frames_with_upload as serialize_frames, deserialize_frames_with_download as deserialize_frames, diff --git a/tests/detect/manual/list_scenarios.py b/tests/detect/manual/list_scenarios.py index e7d81e6..929e8b4 100644 --- a/tests/detect/manual/list_scenarios.py +++ b/tests/detect/manual/list_scenarios.py @@ -35,9 +35,11 @@ logger = logging.getLogger(__name__) def main(): - from core.db.detect import list_scenarios + from core.db import list_scenarios + from core.db.connection import get_session - scenarios = list_scenarios() + with get_session() as session: + scenarios = list_scenarios(session) if not scenarios: logger.info("No scenarios found. Create one with:") diff --git a/tests/detect/manual/seed_scenario.py b/tests/detect/manual/seed_scenario.py index 95abd14..81bf80d 100644 --- a/tests/detect/manual/seed_scenario.py +++ b/tests/detect/manual/seed_scenario.py @@ -121,15 +121,14 @@ def main(): ) # Mark as scenario - from core.db.detect import get_latest_checkpoint + from core.db import get_latest_checkpoint from core.db.connection import get_session - checkpoint = get_latest_checkpoint(branch_id) - if checkpoint: - checkpoint.is_scenario = True - checkpoint.scenario_label = args.label - with get_session() as session: - session.add(checkpoint) + with get_session() as session: + checkpoint = get_latest_checkpoint(session, branch_id) + if checkpoint: + checkpoint.is_scenario = True + checkpoint.scenario_label = args.label session.commit() logger.info("") diff --git a/tests/detect/test_checkpoint.py b/tests/detect/test_checkpoint.py index 620f10d..c530cb3 100644 --- a/tests/detect/test_checkpoint.py +++ b/tests/detect/test_checkpoint.py @@ -7,7 +7,7 @@ import pytest from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate from core.schema.serializers._common import safe_construct -from core.schema.serializers.detect_pipeline import ( +from core.schema.serializers.pipeline import ( serialize_frame_meta, serialize_text_candidate, serialize_text_candidates,