From 291ac8dd403175fb4c2a7efbceb627ecc0d3582d Mon Sep 17 00:00:00 2001 From: buenosairesam Date: Fri, 27 Mar 2026 04:23:21 -0300 Subject: [PATCH] refactor stage 1 --- core/db/detect.py | 105 +++++---- core/db/models.py | 30 ++- core/schema/models/__init__.py | 7 +- core/schema/models/detect_jobs.py | 67 +++--- core/schema/models/stages.py | 64 ++++++ detect/checkpoint/__init__.py | 16 +- detect/checkpoint/replay.py | 8 +- detect/checkpoint/storage.py | 224 +++++++++++------- detect/stages/__init__.py | 18 +- detect/stages/base.py | 168 ++++++++------ detect/stages/edge_detector.py | 313 +++++++++++++++----------- tests/detect/manual/list_scenarios.py | 14 +- tests/detect/manual/seed_scenario.py | 89 +++----- tests/detect/test_stage_registry.py | 15 +- 14 files changed, 688 insertions(+), 450 deletions(-) create mode 100644 core/schema/models/stages.py diff --git a/core/db/detect.py b/core/db/detect.py index c3a320f..58f152c 100644 --- a/core/db/detect.py +++ b/core/db/detect.py @@ -9,7 +9,8 @@ from sqlmodel import select from .connection import get_session from .models import ( - DetectJob, StageCheckpoint, KnownBrand, SourceBrandSighting, + DetectJob, Timeline, Checkpoint, + KnownBrand, SourceBrandSighting, ) @@ -55,72 +56,86 @@ def list_detect_jobs( # --------------------------------------------------------------------------- -# StageCheckpoint +# Timeline # --------------------------------------------------------------------------- -def save_stage_checkpoint(**fields) -> StageCheckpoint: +def create_timeline(**fields) -> Timeline: + timeline = Timeline(**fields) with get_session() as session: - # Upsert: replace if same job_id + stage - job_id = fields.get("job_id") - stage = fields.get("stage") - if job_id and stage: - stmt = select(StageCheckpoint).where( - StageCheckpoint.job_id == job_id, - StageCheckpoint.stage == stage, - ) - existing = session.exec(stmt).first() - if existing: - for k, v in fields.items(): - setattr(existing, k, v) - session.commit() - session.refresh(existing) - return existing + session.add(timeline) + session.commit() + session.refresh(timeline) + return timeline - checkpoint = StageCheckpoint(**fields) + +def get_timeline(timeline_id: UUID) -> Timeline | None: + with get_session() as session: + return session.get(Timeline, timeline_id) + + +# --------------------------------------------------------------------------- +# Checkpoint +# --------------------------------------------------------------------------- + +def save_checkpoint(**fields) -> Checkpoint: + checkpoint = Checkpoint(**fields) + with get_session() as session: session.add(checkpoint) session.commit() session.refresh(checkpoint) return checkpoint -def get_stage_checkpoint(job_id: UUID, stage: str) -> StageCheckpoint | None: +def get_checkpoint(checkpoint_id: UUID) -> Checkpoint | None: with get_session() as session: - stmt = select(StageCheckpoint).where( - StageCheckpoint.job_id == job_id, - StageCheckpoint.stage == stage, + return session.get(Checkpoint, checkpoint_id) + + +def get_latest_checkpoint(timeline_id: UUID, parent_id: UUID | None = None) -> Checkpoint | None: + """Get the most recent checkpoint for a timeline, optionally from a specific parent.""" + with get_session() as session: + stmt = ( + select(Checkpoint) + .where(Checkpoint.timeline_id == timeline_id) + ) + if parent_id is not None: + stmt = stmt.where(Checkpoint.parent_id == parent_id) + stmt = stmt.order_by(Checkpoint.created_at.desc()) + return session.exec(stmt).first() + + +def list_checkpoints(timeline_id: UUID) -> 
list[Checkpoint]: + """List all checkpoints for a timeline.""" + with get_session() as session: + stmt = ( + select(Checkpoint) + .where(Checkpoint.timeline_id == timeline_id) + .order_by(Checkpoint.created_at) + ) + return list(session.exec(stmt).all()) + + +def get_root_checkpoint(timeline_id: UUID) -> Checkpoint | None: + """Get the root checkpoint (no parent) for a timeline.""" + with get_session() as session: + stmt = select(Checkpoint).where( + Checkpoint.timeline_id == timeline_id, + Checkpoint.parent_id == None, ) return session.exec(stmt).first() -def list_stage_checkpoints(job_id: UUID) -> list[str]: - with get_session() as session: - stmt = ( - select(StageCheckpoint.stage) - .where(StageCheckpoint.job_id == job_id) - .order_by(StageCheckpoint.stage_index) - ) - return list(session.exec(stmt).all()) - - -def list_scenarios() -> list[StageCheckpoint]: +def list_scenarios() -> list[Checkpoint]: """List all checkpoints marked as scenarios.""" with get_session() as session: stmt = ( - select(StageCheckpoint) - .where(StageCheckpoint.is_scenario == True) - .order_by(StageCheckpoint.created_at.desc()) + select(Checkpoint) + .where(Checkpoint.is_scenario == True) + .order_by(Checkpoint.created_at.desc()) ) return list(session.exec(stmt).all()) -def delete_stage_checkpoints(job_id: UUID) -> None: - with get_session() as session: - stmt = select(StageCheckpoint).where(StageCheckpoint.job_id == job_id) - for cp in session.exec(stmt).all(): - session.delete(cp) - session.commit() - - # --------------------------------------------------------------------------- # KnownBrand # --------------------------------------------------------------------------- diff --git a/core/db/models.py b/core/db/models.py index cfacda6..452bbfd 100644 --- a/core/db/models.py +++ b/core/db/models.py @@ -181,24 +181,30 @@ class DetectJob(SQLModel, table=True): started_at: Optional[datetime] = None completed_at: Optional[datetime] = None -class StageCheckpoint(SQLModel, table=True): - """A checkpoint saved after a pipeline stage completes.""" - __tablename__ = "stage_checkpoints" +class Timeline(SQLModel, table=True): + """Frame sequence from a source video. Independent of stages.""" + __tablename__ = "timelines" id: UUID = Field(default_factory=uuid4, primary_key=True) - job_id: UUID = Field(index=True) - stage: str - stage_index: int + source_asset_id: Optional[UUID] = Field(default=None, index=True) + source_video: str = "" + profile_name: str = "" + fps: float = 2.0 frames_prefix: str = "" frames_manifest: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}')) frames_meta: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]')) - filtered_frame_sequences: List[int] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]')) - stage_output_key: str = "" # s3 key: checkpoints/{job_id}/stages/{stage}.bson - stats: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}')) - config_snapshot: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}')) + created_at: Optional[datetime] = Field(default_factory=datetime.utcnow) + +class Checkpoint(SQLModel, table=True): + """Snapshot of pipeline state. 
parent_id forms a tree.""" + __tablename__ = "checkpoints" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + timeline_id: UUID = Field(index=True) + parent_id: Optional[UUID] = Field(default=None, index=True) + stage_outputs: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}')) config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}')) - video_path: str = "" - profile_name: str = "" + stats: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}')) is_scenario: bool = False scenario_label: str = "" created_at: Optional[datetime] = Field(default_factory=datetime.utcnow) diff --git a/core/schema/models/__init__.py b/core/schema/models/__init__.py index b8bdd48..4e873eb 100644 --- a/core/schema/models/__init__.py +++ b/core/schema/models/__init__.py @@ -27,9 +27,11 @@ from .grpc import ( ) from .jobs import ChunkJob, ChunkJobStatus, JobStatus, TranscodeJob from .detect_jobs import ( - DetectJob, DetectJobStatus, RunType, StageCheckpoint, + DetectJob, DetectJobStatus, RunType, + Timeline, Checkpoint, BrandSource, KnownBrand, SourceBrandSighting, ) +from .stages import StageConfigField, StageIO, StageDefinition, STAGE_VIEWS from .media import AssetStatus, MediaAsset from .presets import BUILTIN_PRESETS, TranscodePreset from .detect import DETECT_VIEWS # noqa: F401 — discovered by modelgen generic loader @@ -40,7 +42,8 @@ from .sources import ChunkInfo, SourceJob, SourceType # Core domain models - generates Django, SQLModel, TypeScript DATACLASSES = [MediaAsset, TranscodePreset, TranscodeJob, ChunkJob, - DetectJob, StageCheckpoint, KnownBrand, SourceBrandSighting] + DetectJob, Timeline, Checkpoint, + KnownBrand, SourceBrandSighting] # API request/response models - generates TypeScript only (no Django) # WorkerStatus from grpc.py is reused here diff --git a/core/schema/models/detect_jobs.py b/core/schema/models/detect_jobs.py index f1cf459..e14e730 100644 --- a/core/schema/models/detect_jobs.py +++ b/core/schema/models/detect_jobs.py @@ -72,49 +72,58 @@ class DetectJob: @dataclass -class StageCheckpoint: +class Timeline: """ - A checkpoint saved after a pipeline stage completes. + The frame sequence from a source video. - Binary data (frame images, crops) goes to S3/MinIO. - Everything else (structured state) lives here in Postgres. + Independent of stages — exists before any stage runs. + Stages annotate the timeline, they don't own it. + Frames are stored in MinIO as JPEGs. """ id: UUID - job_id: UUID - stage: str - stage_index: int # position in NODES list (0-7) + source_asset_id: Optional[UUID] = None + source_video: str = "" + profile_name: str = "" + fps: float = 2.0 - # S3 reference for binary data only - frames_prefix: str = "" # s3 prefix: checkpoints/{job_id}/frames/ - - # Frame metadata (non-image fields) + # Frame metadata (images in MinIO, metadata here) + frames_prefix: str = "" # s3: timelines/{id}/frames/ frames_manifest: Dict[int, str] = field(default_factory=dict) # seq → s3 key - frames_meta: List[Dict[str, Any]] = field(default_factory=list) # sequence, chunk_id, timestamp, hash - filtered_frame_sequences: List[int] = field(default_factory=list) + frames_meta: List[Dict[str, Any]] = field(default_factory=list) - # Stage output — stored as blob in MinIO: checkpoints/{job_id}/stages/{stage}.bson - # Each stage's serialize_fn/deserialize_fn owns the format. 
-    # Postgres only stores the S3 key, not the data itself.
-    stage_output_key: str = ""  # s3 key to the serialized stage output
+    created_at: Optional[datetime] = None
 
-    # Pipeline state (small, stays in Postgres)
-    stats: Dict[str, Any] = field(default_factory=dict)
-    config_snapshot: Dict[str, Any] = field(default_factory=dict)
+
+@dataclass
+class Checkpoint:
+    """
+    A snapshot of pipeline state on a timeline.
+
+    Stage outputs stored as JSONB — each stage serializes to JSON,
+    the checkpoint stores it without knowing the shape.
+
+    parent_id forms a tree: multiple children from the same parent
+    = different config tries from the same starting point.
+    """
+
+    id: UUID
+    timeline_id: UUID
+    parent_id: Optional[UUID] = None  # null = root checkpoint
+
+    # Stage outputs — JSONB per stage, opaque to the checkpoint layer
+    stage_outputs: Dict[str, Any] = field(default_factory=dict)
+
+    # Config that produced this checkpoint
     config_overrides: Dict[str, Any] = field(default_factory=dict)
 
-    # Input refs (for replay)
-    video_path: str = ""
-    profile_name: str = ""
+    # Pipeline state
+    stats: Dict[str, Any] = field(default_factory=dict)
 
-    # Scenario — a checkpoint bookmarked for the editor workflow.
-    # Created by seeders (manual scripts that populate state from real footage)
-    # or captured from a running pipeline. Loaded via URL:
-    #   /detection/?job=<job_id>#/editor/<stage>
+    # Scenario bookmark
     is_scenario: bool = False
-    scenario_label: str = ""  # human-readable name, e.g. "chelsea_edges_lowcanny"
+    scenario_label: str = ""
 
-    # Timestamps
     created_at: Optional[datetime] = None
 
 
diff --git a/core/schema/models/stages.py b/core/schema/models/stages.py
new file mode 100644
index 0000000..d4bb7d9
--- /dev/null
+++ b/core/schema/models/stages.py
@@ -0,0 +1,64 @@
+"""
+Stage Schema Definitions
+
+Source of truth for pipeline stage metadata.
+Generates: Pydantic, TypeScript via modelgen.
+
+Each stage is defined by its config fields. The implementation
+lives in detect/stages/<name>.py as a Stage subclass.
+"""
+
+from dataclasses import dataclass, field
+from typing import Any, List, Optional
+
+
+@dataclass
+class StageConfigField:
+    """A single tunable config parameter for the editor UI."""
+    name: str
+    type: str  # "float", "int", "str", "bool"
+    default: Any
+    description: str = ""
+    min: Optional[float] = None
+    max: Optional[float] = None
+    options: Optional[List[str]] = None
+
+
+@dataclass
+class StageIO:
+    """Declares what a stage reads and writes."""
+    reads: List[str] = field(default_factory=list)
+    writes: List[str] = field(default_factory=list)
+    optional_reads: List[str] = field(default_factory=list)
+
+
+@dataclass
+class StageDefinition:
+    """
+    Complete metadata for a pipeline stage.
+
+    Lives in schema as the source of truth. Each stage implementation
+    references a StageDefinition. The editor, graph, and checkpoint
+    system all consume this.
+    """
+    name: str
+    label: str
+    description: str
+    category: str = "detection"
+    io: StageIO = field(default_factory=StageIO)
+    config_fields: List[StageConfigField] = field(default_factory=list)
+
+    # Legacy fields — used by old registry pattern during migration.
+    # New stages use Stage subclass instead.
+    fn: Any = None
+    serialize_fn: Any = None
+    deserialize_fn: Any = None
+
+
+# --- Export for modelgen ---
+
+STAGE_VIEWS = [
+    StageConfigField,
+    StageIO,
+    StageDefinition,
+]
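+
+# Illustrative only: a minimal sketch of how a stage would describe itself.
+# The stage name, fields, and defaults below are hypothetical; real
+# definitions live next to their Stage subclass in detect/stages/.
+#
+#   BLUR_STAGE = StageDefinition(
+#       name="blur_filter",
+#       label="Blur Filter",
+#       description="Drop frames below a sharpness threshold",
+#       category="preprocessing",
+#       io=StageIO(reads=["frames"], writes=["sharp_frame_sequences"]),
+#       config_fields=[
+#           StageConfigField("min_sharpness", "float", 50.0,
+#                            "Laplacian variance cutoff", min=0, max=500),
+#       ],
+#   )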
diff --git a/detect/checkpoint/__init__.py b/detect/checkpoint/__init__.py
index 062b5bd..7f30b39 100644
--- a/detect/checkpoint/__init__.py
+++ b/detect/checkpoint/__init__.py
@@ -1,14 +1,18 @@
 """
-Stage checkpoint, replay, and retry.
+Checkpoint system — Timeline + Checkpoint tree.
 
 detect/checkpoint/
     frames.py     — frame image S3 upload/download
-    serializer.py — state ↔ JSON conversion
-    storage.py    — checkpoint save/load/list (Postgres + S3)
-    replay.py     — replay_from, OverrideProfile
+    storage.py    — Timeline + Checkpoint (Postgres + MinIO)
+    replay.py     — replay (TODO: migrate to new model)
    tasks.py      — retry_candidates Celery task
 """
 
-from .storage import save_checkpoint, load_checkpoint, list_checkpoints
+from .storage import (
+    create_timeline,
+    get_timeline_frames,
+    get_timeline_frames_b64,
+    save_stage_output,
+    load_stage_output,
+)
 from .frames import save_frames, load_frames
-from .replay import replay_from, OverrideProfile
diff --git a/detect/checkpoint/replay.py b/detect/checkpoint/replay.py
index 1805cbe..785d422 100644
--- a/detect/checkpoint/replay.py
+++ b/detect/checkpoint/replay.py
@@ -12,7 +12,13 @@ import logging
 import uuid
 
 from detect import emit
-from detect.checkpoint import load_checkpoint, list_checkpoints
+# TODO: migrate to the Timeline/Checkpoint model
+# These old functions no longer exist — replay needs rework
+def _not_migrated(*args, **kwargs):
+    raise NotImplementedError("Replay not yet migrated to the Timeline/Checkpoint model")
+
+load_checkpoint = _not_migrated
+list_checkpoints = _not_migrated
 from detect.graph import NODES, build_graph
 
 logger = logging.getLogger(__name__)
diff --git a/detect/checkpoint/storage.py b/detect/checkpoint/storage.py
index 1157e05..f098480 100644
--- a/detect/checkpoint/storage.py
+++ b/detect/checkpoint/storage.py
@@ -1,116 +1,178 @@
 """
-Checkpoint storage — save/load stage state.
+Checkpoint storage — Timeline + Checkpoint (tree of snapshots).
 
-Binary data (frame images) → S3/MinIO via frames.py
-Structured data (stage output, stats, config) → Postgres
+Timeline: frame sequence from source video (frames in MinIO)
+Checkpoint: snapshot of pipeline state (stage outputs as JSONB in Postgres)
+            parent_id forms a tree — multiple children = different config tries
 """
 
 from __future__ import annotations
 
 import logging
+from uuid import UUID
 
 from .frames import save_frames, load_frames, CHECKPOINT_PREFIX
-from .serializer import serialize_state, deserialize_state
 
 logger = logging.getLogger(__name__)
 
 
 # ---------------------------------------------------------------------------
-# Save
+# Timeline
 # ---------------------------------------------------------------------------
 
-def save_checkpoint(
-    job_id: str,
-    stage: str,
-    stage_index: int,
-    state: dict,
-    frames_manifest: dict[int, str] | None = None,
+def create_timeline(
+    source_video: str,
+    profile_name: str,
+    frames: list,
+    fps: float = 2.0,
+    source_asset_id: UUID | None = None,
+) -> tuple[str, str]:
+    """
+    Create a timeline from frames. Uploads frame images to MinIO,
+    creates Timeline + root Checkpoint in Postgres.
+
+    Returns (timeline_id, checkpoint_id).
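+
+    Illustrative call (a sketch, not executed here; the path and profile
+    name are arbitrary, and `frames` is any list of Frame objects):
+
+        timeline_id, root_id = create_timeline(
+            source_video="chunks/chunk_0000.mp4",
+            profile_name="soccer_broadcast",
+            frames=frames,
+            fps=2.0,
+        )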
+ """ + from core.db.detect import create_timeline as db_create_timeline + from core.db.detect import save_checkpoint + + # Create timeline + timeline = db_create_timeline( + source_video=source_video, + profile_name=profile_name, + source_asset_id=source_asset_id, + fps=fps, + ) + tid = str(timeline.id) + + # Upload frames to MinIO + manifest = save_frames(tid, frames) + + # Store frame metadata on the timeline + frames_meta = [ + { + "sequence": f.sequence, + "chunk_id": getattr(f, "chunk_id", 0), + "timestamp": f.timestamp, + "perceptual_hash": getattr(f, "perceptual_hash", ""), + } + for f in frames + ] + + timeline.frames_prefix = f"{CHECKPOINT_PREFIX}/{tid}/frames/" + timeline.frames_manifest = {str(k): v for k, v in manifest.items()} + timeline.frames_meta = frames_meta + + from core.db.connection import get_session + with get_session() as session: + session.add(timeline) + session.commit() + + # Create root checkpoint (no parent, no stage outputs yet) + checkpoint = save_checkpoint( + timeline_id=timeline.id, + parent_id=None, + stage_outputs={}, + stats={"frames_extracted": len(frames)}, + ) + + logger.info("Timeline created: %s (%d frames, root checkpoint %s)", + tid, len(frames), checkpoint.id) + return tid, str(checkpoint.id) + + +def get_timeline_frames(timeline_id: str) -> list: + """Load frames from a timeline (from MinIO) as Frame objects.""" + from core.db.detect import get_timeline + + timeline = get_timeline(timeline_id) + if not timeline: + raise ValueError(f"Timeline not found: {timeline_id}") + + raw_manifest = timeline.frames_manifest or {} + manifest = {int(k): v for k, v in raw_manifest.items()} + frame_metadata = timeline.frames_meta or [] + + return load_frames(manifest, frame_metadata) + + +def get_timeline_frames_b64(timeline_id: str) -> list[dict]: + """Load frames as base64 JPEG (lightweight, no numpy).""" + from core.db.detect import get_timeline + from .frames import load_frames_b64 + + timeline = get_timeline(timeline_id) + if not timeline: + raise ValueError(f"Timeline not found: {timeline_id}") + + raw_manifest = timeline.frames_manifest or {} + manifest = {int(k): v for k, v in raw_manifest.items()} + frame_metadata = timeline.frames_meta or [] + + return load_frames_b64(manifest, frame_metadata) + + +# --------------------------------------------------------------------------- +# Checkpoint +# --------------------------------------------------------------------------- + +def save_stage_output( + timeline_id: str, + parent_checkpoint_id: str | None, + stage_name: str, + output_json: dict, + config_overrides: dict | None = None, + stats: dict | None = None, is_scenario: bool = False, scenario_label: str = "", ) -> str: """ - Save a stage checkpoint. + Save a stage's output as a new checkpoint (child of parent). - Saves frame images to S3 (if not already saved), then persists - structured state to Postgres. - - Returns the checkpoint DB id. + Carries forward stage outputs from parent + adds the new one. + Returns the new checkpoint ID. 
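+
+    Illustrative branching (a sketch; the stage name is real, the outputs
+    and overrides are made up). Two calls with the same parent produce
+    sibling checkpoints, i.e. two config tries from one starting point:
+
+        a = save_stage_output(tid, root_id, "detect_edges", out_a,
+                              config_overrides={"edge_canny_low": 30})
+        b = save_stage_output(tid, root_id, "detect_edges", out_b,
+                              config_overrides={"edge_canny_low": 80})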
""" - from core.db.detect import save_stage_checkpoint + from core.db.detect import get_checkpoint, save_checkpoint - if frames_manifest is None: - all_frames = state.get("frames", []) - frames_manifest = save_frames(job_id, all_frames) + # Carry forward from parent + parent_outputs = {} + parent_stats = {} + parent_config = {} + if parent_checkpoint_id: + parent = get_checkpoint(parent_checkpoint_id) + if parent: + parent_outputs = dict(parent.stage_outputs or {}) + parent_stats = dict(parent.stats or {}) + parent_config = dict(parent.config_overrides or {}) - checkpoint_data = serialize_state(state, frames_manifest) - frames_prefix = f"{CHECKPOINT_PREFIX}/{job_id}/frames/" + # Add new stage output + stage_outputs = {**parent_outputs, stage_name: output_json} - checkpoint = save_stage_checkpoint( - job_id=job_id, - stage=stage, - stage_index=stage_index, - frames_prefix=frames_prefix, - frames_manifest=checkpoint_data.get("frames_manifest", {}), - frames_meta=checkpoint_data.get("frames_meta", []), - filtered_frame_sequences=checkpoint_data.get("filtered_frame_sequences", []), - stage_output_key=checkpoint_data.get("stage_output_key", ""), - stats=checkpoint_data.get("stats", {}), - config_snapshot=checkpoint_data.get("config_overrides", {}), - config_overrides=checkpoint_data.get("config_overrides", {}), - video_path=checkpoint_data.get("video_path", ""), - profile_name=checkpoint_data.get("profile_name", ""), + # Merge stats and config + merged_stats = {**parent_stats, **(stats or {})} + merged_config = {**parent_config, **(config_overrides or {})} + + checkpoint = save_checkpoint( + timeline_id=timeline_id, + parent_id=parent_checkpoint_id, + stage_outputs=stage_outputs, + config_overrides=merged_config, + stats=merged_stats, is_scenario=is_scenario, scenario_label=scenario_label, ) - logger.info("Checkpoint saved: %s/%s (id=%s, scenario=%s)", - job_id, stage, checkpoint.id, is_scenario) + logger.info("Checkpoint saved: %s (timeline %s, stage %s, parent %s)", + checkpoint.id, timeline_id, stage_name, parent_checkpoint_id) return str(checkpoint.id) -# --------------------------------------------------------------------------- -# Load -# --------------------------------------------------------------------------- +def load_stage_output(checkpoint_id: str, stage_name: str) -> dict | None: + """Load a stage's output from a checkpoint.""" + from core.db.detect import get_checkpoint -def load_checkpoint(job_id: str, stage: str) -> dict: - """ - Load a stage checkpoint and reconstitute full DetectState. 
- """ - from core.db.detect import get_stage_checkpoint - - checkpoint = get_stage_checkpoint(job_id, stage) + checkpoint = get_checkpoint(checkpoint_id) if not checkpoint: - raise ValueError(f"No checkpoint for {job_id}/{stage}") + return None - data = { - "job_id": str(checkpoint.job_id), - "video_path": checkpoint.video_path, - "profile_name": checkpoint.profile_name, - "config_overrides": checkpoint.config_overrides, - "frames_manifest": checkpoint.frames_manifest, - "frames_meta": checkpoint.frames_meta, - "filtered_frame_sequences": checkpoint.filtered_frame_sequences, - "stage_output_key": checkpoint.stage_output_key, - "stats": checkpoint.stats, - } - - raw_manifest = data.get("frames_manifest", {}) - manifest = {int(k): v for k, v in raw_manifest.items()} - frame_metadata = data.get("frames_meta", []) - frames = load_frames(manifest, frame_metadata) - - state = deserialize_state(data, frames) - - logger.info("Checkpoint loaded: %s/%s (%d frames, scenario=%s)", - job_id, stage, len(frames), checkpoint.is_scenario) - return state - - -# --------------------------------------------------------------------------- -# List -# --------------------------------------------------------------------------- - -def list_checkpoints(job_id: str) -> list[str]: - """List available checkpoint stages for a job.""" - from core.db.detect import list_stage_checkpoints - return list_stage_checkpoints(job_id) + return (checkpoint.stage_outputs or {}).get(stage_name) diff --git a/detect/stages/__init__.py b/detect/stages/__init__.py index 7c66a29..13817aa 100644 --- a/detect/stages/__init__.py +++ b/detect/stages/__init__.py @@ -1,21 +1,21 @@ """ Pipeline stages. -Each stage registers its StageDefinition on import, -declaring IO (what it reads/writes from state), -config fields (what's tunable from the editor), -and serialization (how to checkpoint its outputs). +Each stage is a file with a Stage subclass. Auto-discovered via +__init_subclass__ — importing the file registers the stage. """ from .base import ( - StageDefinition, - StageIO, - StageConfigField, - register_stage, + Stage, get_stage, + get_stage_instance, list_stages, + list_stage_classes, get_palette, ) -# Populate registry with built-in stages +# Import all stage files to trigger auto-registration +from . import edge_detector # noqa: F401 + +# Import registry for backward compat (other stages still use old pattern) from . import registry # noqa: F401 diff --git a/detect/stages/base.py b/detect/stages/base.py index 13e109a..03f7679 100644 --- a/detect/stages/base.py +++ b/detect/stages/base.py @@ -1,101 +1,131 @@ """ -Stage protocol — common interface for all pipeline stages. +Stage base class — common interface for all pipeline stages. -Every stage declares: - - IO: what it reads/writes from DetectState - - Config: tunable parameters for the editor - - Serialization: how to persist/restore its own outputs +Each stage is a file that subclasses Stage. Auto-discovered via +__init_subclass__. No manual registration needed. -The checkpoint layer is a black box — it asks each stage to serialize its -outputs and stores the result. Stages own their data format. Binary data -(frames, crops) goes to S3 via the stage itself. The checkpoint just -stores the JSON envelope. 
+A stage:
+  - Has a StageDefinition (from schema) with name, config, IO
+  - Implements run(frames, config) → output
+  - Owns its output serialization (opaque blob)
+  - Optionally has a TypeScript port for browser-side execution
 
-The graph builder uses StageIO to validate that a stage's inputs are
-satisfied by previous stages' outputs.
+The checkpoint layer stores stage output as blobs without knowing
+the format. The stage that wrote it is the only one that can read it.
 """
 
 from __future__ import annotations
 
-from dataclasses import dataclass, field
-from typing import Any, Callable
+from typing import Any
 
-
-@dataclass
-class StageIO:
-    """Declares what a stage reads and writes from/to DetectState."""
-    reads: list[str]
-    writes: list[str]
-    optional_reads: list[str] = field(default_factory=list)
-
-
-@dataclass
-class StageConfigField:
-    """A single tunable config parameter for the editor UI."""
-    name: str
-    type: str  # "float", "int", "str", "bool", "list[str]"
-    default: Any
-    description: str = ""
-    min: float | None = None
-    max: float | None = None
-    options: list[str] | None = None
-
-
-@dataclass
-class StageDefinition:
-    """
-    Complete metadata for a pipeline stage.
-
-    The profile editor uses this to build the palette, generate config
-    forms, and validate graph connections. The checkpoint uses serialize_fn
-    and deserialize_fn to persist stage outputs without knowing the internals.
-    """
-    name: str
-    label: str
-    description: str
-    io: StageIO
-    config_fields: list[StageConfigField] = field(default_factory=list)
-    category: str = "detection"
-
-    # The actual graph node function: (DetectState) → dict
-    fn: Callable | None = None
-
-    # Stage-owned serialization for checkpointing.
-    # serialize_fn: (state: dict, job_id: str) → json-compatible dict
-    #   Stage picks its writes from state, serializes them.
-    #   Binary data (frames) → S3 via stage, returns refs.
-    # deserialize_fn: (data: dict, job_id: str) → state update dict
-    #   Stage restores its writes from the persisted data.
-    serialize_fn: Callable | None = None
-    deserialize_fn: Callable | None = None
+from core.schema.models.stages import StageConfigField, StageIO, StageDefinition
 
 
 # ---------------------------------------------------------------------------
-# Registry
+# Registry — auto-populated by __init_subclass__ (new stages)
+#            + register_stage() (legacy stages during migration)
 # ---------------------------------------------------------------------------
 
-_REGISTRY: dict[str, StageDefinition] = {}
+_REGISTRY: dict[str, type['Stage']] = {}
+_LEGACY_REGISTRY: dict[str, StageDefinition] = {}
 
 
 def register_stage(definition: StageDefinition):
-    _REGISTRY[definition.name] = definition
+    """Legacy registration for stages not yet converted to Stage subclass."""
+    _LEGACY_REGISTRY[definition.name] = definition
+
+
+class Stage:
+    """
+    Base class for all pipeline stages.
+
+    Subclass this in detect/stages/<name>.py. Define `definition` as a
+    class attribute. Implement `run()`. Optionally override `serialize()`
+    and `deserialize()` for custom blob formats (default is JSON).
+    """
+
+    definition: StageDefinition  # set by each subclass
+
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        if hasattr(cls, 'definition') and cls.definition is not None:
+            _REGISTRY[cls.definition.name] = cls
+
+    def run(self, frames: list, config: dict) -> Any:
+        """
+        Run the stage on a list of frames with the given config.
+
+        Config is a dict of parameter values (from slider UI or profile).
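+
+        Illustrative call (a sketch; the override value is arbitrary, and
+        get_stage_instance() resolves new-style stages only):
+
+            stage = get_stage_instance("detect_edges")
+            output = stage.run(frames, {"edge_canny_low": 30})
+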
+ Returns the stage output — whatever shape this stage produces. + Debug overlays are included when config has debug=True. + """ + raise NotImplementedError + + def serialize(self, output: Any) -> bytes: + """Serialize stage output to bytes for checkpoint storage.""" + import json + return json.dumps(output, default=str).encode() + + def deserialize(self, data: bytes) -> Any: + """Deserialize stage output from checkpoint blob.""" + import json + return json.loads(data) + + +# --------------------------------------------------------------------------- +# Discovery API +# --------------------------------------------------------------------------- + +def _all_definitions() -> dict[str, StageDefinition]: + """Merge new Stage subclass registry + legacy registry.""" + merged = {} + # Legacy first, new overwrites (new takes precedence) + for name, defn in _LEGACY_REGISTRY.items(): + merged[name] = defn + for name, cls in _REGISTRY.items(): + merged[name] = cls.definition + return merged def get_stage(name: str) -> StageDefinition: - if name not in _REGISTRY: - raise KeyError(f"Unknown stage: {name!r}. Registered: {list(_REGISTRY)}") - return _REGISTRY[name] + """Get a stage definition by name (works for both new and legacy).""" + all_defs = _all_definitions() + if name not in all_defs: + raise KeyError(f"Unknown stage: {name!r}. Registered: {list(all_defs)}") + return all_defs[name] + + +def get_stage_class(name: str) -> type[Stage] | None: + """Get a Stage subclass by name. Returns None for legacy stages.""" + return _REGISTRY.get(name) + + +def get_stage_instance(name: str) -> Stage: + """Get an instantiated Stage by name. Only works for new-style stages.""" + cls = _REGISTRY.get(name) + if cls is None: + raise KeyError(f"No Stage subclass for {name!r}. Legacy stages don't have instances.") + return cls() def list_stages() -> list[StageDefinition]: + """List all registered stage definitions (new + legacy).""" + return list(_all_definitions().values()) + + +def list_stage_classes() -> list[type[Stage]]: + """List all registered Stage subclasses (new-style only).""" return list(_REGISTRY.values()) def get_palette() -> dict[str, list[StageDefinition]]: """Group stages by category for the editor palette.""" palette: dict[str, list[StageDefinition]] = {} - for stage in _REGISTRY.values(): - if stage.category not in palette: - palette[stage.category] = [] - palette[stage.category].append(stage) + for defn in _all_definitions().values(): + if defn.category not in palette: + palette[defn.category] = [] + palette[defn.category].append(defn) return palette diff --git a/detect/stages/edge_detector.py b/detect/stages/edge_detector.py index 405f6cb..063bb1c 100644 --- a/detect/stages/edge_detector.py +++ b/detect/stages/edge_detector.py @@ -7,168 +7,227 @@ advertising hoardings. Pure OpenCV, no ML models. Two modes: - Remote: calls GPU inference server over HTTP - Local: imports cv2 directly (OpenCV on same machine) - -Emits frame_update events with bounding boxes for the frame viewer. 
""" from __future__ import annotations import base64 import io +import json import logging +import os import time +from typing import Any from PIL import Image from detect import emit from detect.models import BoundingBox, Frame -from detect.profiles.base import RegionAnalysisConfig +from detect.stages.base import Stage +from core.schema.models.stages import StageDefinition, StageConfigField, StageIO + logger = logging.getLogger(__name__) +class EdgeDetectionStage(Stage): + + definition = StageDefinition( + name="detect_edges", + label="Edge Detection", + description="Canny + HoughLinesP — find horizontal line pairs (hoarding boundaries)", + category="cv_analysis", + io=StageIO( + reads=["filtered_frames"], + writes=["edge_regions_by_frame"], + ), + config_fields=[ + StageConfigField("enabled", "bool", True, "Enable edge detection"), + StageConfigField("edge_canny_low", "int", 50, "Canny low threshold", min=0, max=255), + StageConfigField("edge_canny_high", "int", 150, "Canny high threshold", min=0, max=255), + StageConfigField("edge_hough_threshold", "int", 80, "Hough accumulator threshold", min=1, max=500), + StageConfigField("edge_hough_min_length", "int", 100, "Min line length (px)", min=10, max=2000), + StageConfigField("edge_hough_max_gap", "int", 10, "Max line gap (px)", min=1, max=100), + StageConfigField("edge_pair_max_distance", "int", 200, "Max distance between line pair (px)", min=10, max=500), + StageConfigField("edge_pair_min_distance", "int", 15, "Min distance between line pair (px)", min=5, max=200), + ], + ) + + def run(self, frames: list[Frame], config: dict) -> dict[int, list[BoundingBox]]: + """ + Run edge detection on all frames. + + Config keys: enabled, edge_canny_low, edge_canny_high, edge_hough_threshold, + edge_hough_min_length, edge_hough_max_gap, edge_pair_max_distance, edge_pair_min_distance, + debug (bool), inference_url (str|None), job_id (str|None). + + Returns dict mapping frame sequence → list of BoundingBox. 
+ """ + enabled = config.get("enabled", True) + job_id = config.get("job_id") + inference_url = config.get("inference_url") or os.environ.get("INFERENCE_URL") + + if not enabled: + emit.log(job_id, "EdgeDetection", "INFO", "Edge detection disabled, skipping") + return {} + + mode = "remote" if inference_url else "local" + emit.log(job_id, "EdgeDetection", "INFO", + f"Detecting edges in {len(frames)} frames (mode={mode})") + + all_boxes: dict[int, list[BoundingBox]] = {} + total_regions = 0 + + for frame in frames: + t0 = time.monotonic() + if inference_url: + boxes = self._run_remote(frame, config, inference_url, job_id or "") + else: + boxes = self._run_local(frame, config) + ms = (time.monotonic() - t0) * 1000 + + all_boxes[frame.sequence] = boxes + total_regions += len(boxes) + + emit.log(job_id, "EdgeDetection", "DEBUG", + f"Frame {frame.sequence}: {len(boxes)} regions in {ms:.0f}ms" + + (f" [{', '.join(b.label for b in boxes)}]" if boxes else "")) + + if boxes and job_id: + box_dicts = [ + {"x": b.x, "y": b.y, "w": b.w, "h": b.h, + "confidence": b.confidence, "label": b.label, + "stage": "detect_edges"} + for b in boxes + ] + emit.frame_update( + job_id, + frame_ref=frame.sequence, + timestamp=frame.timestamp, + jpeg_b64=_frame_to_b64(frame), + boxes=box_dicts, + ) + + emit.log(job_id, "EdgeDetection", "INFO", + f"Found {total_regions} edge regions across {len(frames)} frames") + emit.stats(job_id, cv_regions_detected=total_regions) + + return all_boxes + + def serialize(self, output: Any) -> bytes: + """Serialize edge regions to JSON blob.""" + serialized = {} + for seq, boxes in output.items(): + serialized[str(seq)] = [ + {"x": b.x, "y": b.y, "w": b.w, "h": b.h, + "confidence": b.confidence, "label": b.label} + for b in boxes + ] + return json.dumps(serialized).encode() + + def deserialize(self, data: bytes) -> dict[int, list[BoundingBox]]: + """Deserialize edge regions from JSON blob.""" + raw = json.loads(data) + result = {} + for seq_str, box_dicts in raw.items(): + boxes = [ + BoundingBox(x=b["x"], y=b["y"], w=b["w"], h=b["h"], + confidence=b["confidence"], label=b["label"]) + for b in box_dicts + ] + result[int(seq_str)] = boxes + return result + + # --- Private helpers --- + + def _run_remote(self, frame: Frame, config: dict, + inference_url: str, job_id: str) -> list[BoundingBox]: + from detect.inference import InferenceClient + from detect.emit import _run_log_level + + client = InferenceClient( + base_url=inference_url, job_id=job_id, log_level=_run_log_level, + ) + results = client.detect_edges( + image=frame.image, + edge_canny_low=config.get("edge_canny_low", 50), + edge_canny_high=config.get("edge_canny_high", 150), + edge_hough_threshold=config.get("edge_hough_threshold", 80), + edge_hough_min_length=config.get("edge_hough_min_length", 100), + edge_hough_max_gap=config.get("edge_hough_max_gap", 10), + edge_pair_max_distance=config.get("edge_pair_max_distance", 200), + edge_pair_min_distance=config.get("edge_pair_min_distance", 15), + ) + boxes = [] + for r in results: + box = BoundingBox( + x=r.x, y=r.y, w=r.w, h=r.h, + confidence=r.confidence, label=r.label, + ) + boxes.append(box) + return boxes + + def _run_local(self, frame: Frame, config: dict) -> list[BoundingBox]: + detect_edges_fn = _load_cv_edges().detect_edges + + edge_results = detect_edges_fn( + frame.image, + canny_low=config.get("edge_canny_low", 50), + canny_high=config.get("edge_canny_high", 150), + hough_threshold=config.get("edge_hough_threshold", 80), + 
hough_min_length=config.get("edge_hough_min_length", 100), + hough_max_gap=config.get("edge_hough_max_gap", 10), + pair_max_distance=config.get("edge_pair_max_distance", 200), + pair_min_distance=config.get("edge_pair_min_distance", 15), + ) + + boxes = [] + for r in edge_results: + box = BoundingBox( + x=r["x"], y=r["y"], w=r["w"], h=r["h"], + confidence=r["confidence"], label=r["label"], + ) + boxes.append(box) + return boxes + + +# --- Module-level helpers --- + def _frame_to_b64(frame: Frame) -> str: - """Encode frame as base64 JPEG for SSE frame_update events.""" img = Image.fromarray(frame.image) buf = io.BytesIO() img.save(buf, format="JPEG", quality=70) return base64.b64encode(buf.getvalue()).decode() -def _detect_remote( - frame: Frame, - config: RegionAnalysisConfig, - inference_url: str, - job_id: str = "", - log_level: str = "INFO", -) -> list[BoundingBox]: - """Call the inference server over HTTP.""" - from detect.inference import InferenceClient - - client = InferenceClient( - base_url=inference_url, job_id=job_id, log_level=log_level, - ) - results = client.detect_edges( - image=frame.image, - edge_canny_low=config.edge_canny_low, - edge_canny_high=config.edge_canny_high, - edge_hough_threshold=config.edge_hough_threshold, - edge_hough_min_length=config.edge_hough_min_length, - edge_hough_max_gap=config.edge_hough_max_gap, - edge_pair_max_distance=config.edge_pair_max_distance, - edge_pair_min_distance=config.edge_pair_min_distance, - ) - boxes = [] - for r in results: - box = BoundingBox( - x=r.x, y=r.y, w=r.w, h=r.h, - confidence=r.confidence, label=r.label, - ) - boxes.append(box) - return boxes - - _cv_edges_mod = None - def _load_cv_edges(): - """Load edges module directly — gpu/models/__init__.py has GPU-container-only imports.""" global _cv_edges_mod if _cv_edges_mod is None: import importlib.util from pathlib import Path - spec = importlib.util.spec_from_file_location("cv_edges", Path("gpu/models/cv/edges.py")) _cv_edges_mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(_cv_edges_mod) return _cv_edges_mod -def _detect_local(frame: Frame, config: RegionAnalysisConfig) -> list[BoundingBox]: - """Run edge detection in-process (requires opencv-python).""" - detect_edges_fn = _load_cv_edges().detect_edges +# --- Backward compat: standalone function for graph.py --- - edge_results = detect_edges_fn( - frame.image, - canny_low=config.edge_canny_low, - canny_high=config.edge_canny_high, - hough_threshold=config.edge_hough_threshold, - hough_min_length=config.edge_hough_min_length, - hough_max_gap=config.edge_hough_max_gap, - pair_max_distance=config.edge_pair_max_distance, - pair_min_distance=config.edge_pair_min_distance, - ) - - boxes = [] - for r in edge_results: - box = BoundingBox( - x=r["x"], y=r["y"], w=r["w"], h=r["h"], - confidence=r["confidence"], label=r["label"], - ) - boxes.append(box) - return boxes - - -def detect_edge_regions( - frames: list[Frame], - config: RegionAnalysisConfig, - inference_url: str | None = None, - job_id: str | None = None, -) -> dict[int, list[BoundingBox]]: - """ - Run edge detection on all frames. - - Returns a dict mapping frame sequence → list of bounding boxes. 
- """ - if not config.enabled: - emit.log(job_id, "EdgeDetection", "INFO", "Edge detection disabled, skipping") - return {} - - mode = "remote" if inference_url else "local" - emit.log(job_id, "EdgeDetection", "INFO", - f"Detecting edges in {len(frames)} frames (mode={mode})") - - all_boxes: dict[int, list[BoundingBox]] = {} - total_regions = 0 - - for i, frame in enumerate(frames): - t0 = time.monotonic() - if inference_url: - from detect.emit import _run_log_level - boxes = _detect_remote( - frame, config, inference_url, - job_id=job_id or "", log_level=_run_log_level, - ) - else: - boxes = _detect_local(frame, config) - analysis_ms = (time.monotonic() - t0) * 1000 - - all_boxes[frame.sequence] = boxes - total_regions += len(boxes) - - emit.log(job_id, "EdgeDetection", "DEBUG", - f"Frame {frame.sequence}: {len(boxes)} regions in {analysis_ms:.0f}ms" - + (f" [{', '.join(b.label for b in boxes)}]" if boxes else "")) - - if boxes and job_id: - box_dicts = [ - { - "x": b.x, "y": b.y, "w": b.w, "h": b.h, - "confidence": b.confidence, "label": b.label, - "stage": "detect_edges", - } - for b in boxes - ] - emit.frame_update( - job_id, - frame_ref=frame.sequence, - timestamp=frame.timestamp, - jpeg_b64=_frame_to_b64(frame), - boxes=box_dicts, - ) - - emit.log(job_id, "EdgeDetection", "INFO", - f"Found {total_regions} edge regions across {len(frames)} frames") - emit.stats(job_id, cv_regions_detected=total_regions) - - return all_boxes +def detect_edge_regions(frames, config, inference_url=None, job_id=None): + """Convenience wrapper — calls EdgeDetectionStage.run().""" + stage = EdgeDetectionStage() + cfg = { + "enabled": config.enabled, + "edge_canny_low": config.edge_canny_low, + "edge_canny_high": config.edge_canny_high, + "edge_hough_threshold": config.edge_hough_threshold, + "edge_hough_min_length": config.edge_hough_min_length, + "edge_hough_max_gap": config.edge_hough_max_gap, + "edge_pair_max_distance": config.edge_pair_max_distance, + "edge_pair_min_distance": config.edge_pair_min_distance, + "inference_url": inference_url, + "job_id": job_id, + } + return stage.run(frames, cfg) diff --git a/tests/detect/manual/list_scenarios.py b/tests/detect/manual/list_scenarios.py index cd126c3..e7d81e6 100644 --- a/tests/detect/manual/list_scenarios.py +++ b/tests/detect/manual/list_scenarios.py @@ -45,15 +45,15 @@ def main(): return logger.info("") - logger.info("%3s %-35s %-12s %-18s %6s %s", "#", "Label", "Job ID", "Stage", "Frames", "Created") - logger.info("─" * 100) + logger.info("%3s %-35s %-12s %6s %s", "#", "Label", "Timeline", "Stages", "Created") + logger.info("─" * 80) for i, s in enumerate(scenarios, 1): - manifest = s.frames_manifest or {} created = str(s.created_at)[:19] if s.created_at else "—" - job_short = str(s.job_id)[:8] - logger.info("%3d %-35s %-12s %-18s %6d %s", - i, s.scenario_label, job_short, s.stage, len(manifest), created) + tid_short = str(s.timeline_id)[:8] + stage_count = len(s.stage_outputs or {}) + logger.info("%3d %-35s %-12s %6d %s", + i, s.scenario_label, tid_short, stage_count, created) logger.info("") @@ -73,7 +73,7 @@ def main(): logger.error("Scenario not found: %s", args.open) return - url = f"{args.base_url}?job={target.job_id}#/editor/detect_edges" + url = f"{args.base_url}?job={target.timeline_id}#/editor/detect_edges" logger.info("Opening: %s", url) webbrowser.open(url) else: diff --git a/tests/detect/manual/seed_scenario.py b/tests/detect/manual/seed_scenario.py index 3bef4b1..95abd14 100644 --- a/tests/detect/manual/seed_scenario.py +++ 
b/tests/detect/manual/seed_scenario.py
@@ -1,26 +1,20 @@
 #!/usr/bin/env python3
 """
-Seed a scenario checkpoint from a video chunk.
+Seed a scenario from a video chunk.
 
-Extracts frames via ffmpeg, uploads to MinIO, creates a StageCheckpoint
-in Postgres marked as a scenario. No pipeline, no Redis, no SSE.
+Creates a Timeline (frames in MinIO) + root Checkpoint in Postgres,
+marked as a scenario. No pipeline, no Redis, no SSE.
 
 Prerequisites:
-  - Postgres reachable (port-forward or local)
-  - MinIO reachable (port-forward or local)
+  - Postgres reachable (Kind NodePort or local)
+  - MinIO reachable (Kind NodePort or local)
 
 Usage:
-    # With K8s port-forwards:
-    kubectl port-forward svc/postgres 5432:5432 &
-    kubectl port-forward svc/minio 9000:9000 &
-    python tests/detect/manual/seed_scenario.py
-
-    # Custom video:
     python tests/detect/manual/seed_scenario.py --video media/mpr/out/chunks/.../chunk_0001.mp4
 
 Then open:
-    http://mpr.local.ar/detection/?job=<job_id>#/editor/detect_edges
+    http://mpr.local.ar/detection/?job=<timeline_id>#/editor/detect_edges
 """
 
 from __future__ import annotations
 
@@ -31,7 +25,7 @@ import os
 import sys
 import uuid
 
-parser = argparse.ArgumentParser(description="Seed a scenario checkpoint")
+parser = argparse.ArgumentParser(description="Seed a scenario")
 parser.add_argument("--video",
     default="media/mpr/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e/chunk_0000.mp4")
 parser.add_argument("--label", default="chelsea_edges_default",
@@ -44,7 +38,6 @@ parser.add_argument("--s3-url",
     default=os.environ.get("S3_ENDPOINT_URL", "http://localhost:9000"))
 args = parser.parse_args()
 
-# Set env before imports
 os.environ["DATABASE_URL"] = args.db_url
 os.environ["S3_ENDPOINT_URL"] = args.s3_url
 os.environ.setdefault("AWS_ACCESS_KEY_ID", "minioadmin")
@@ -57,7 +50,7 @@ logger = logging.getLogger(__name__)
 
 
 def extract_frames_ffmpeg(video_path: str, fps: float, max_frames: int):
-    """Extract frames using ffmpeg subprocess — no pipeline dependencies."""
+    """Extract frames using ffmpeg — no pipeline dependencies."""
     import subprocess
     import tempfile
     from pathlib import Path
@@ -82,7 +75,7 @@ def extract_frames_ffmpeg(video_path: str, fps: float, max_frames: int):
 
     frames = []
     for jpg in sorted(Path(tmpdir).glob("frame_*.jpg")):
-        seq = int(jpg.stem.split("_")[1]) - 1  # 0-indexed
+        seq = int(jpg.stem.split("_")[1]) - 1
         img = Image.open(jpg).convert("RGB")
         image_array = np.array(img)
         frame = Frame(
@@ -99,7 +92,6 @@ def extract_frames_ffmpeg(video_path: str, fps: float, max_frames: int):
 
 
 def main():
-    job_id = str(uuid.uuid4())
     video_path = args.video
 
     if not os.path.exists(video_path):
@@ -107,7 +99,6 @@ def main():
         sys.exit(1)
 
     logger.info("Video: %s", video_path)
-    logger.info("Job ID: %s", job_id)
     logger.info("Label: %s", args.label)
 
     # Ensure DB tables exist
@@ -119,57 +110,37 @@ def main():
     frames = extract_frames_ffmpeg(video_path, args.fps, args.max_frames)
     logger.info("Extracted %d frames", len(frames))
 
-    # Upload frames to MinIO
-    from detect.checkpoint.frames import save_frames
-    logger.info("Uploading frames to MinIO...")
-    manifest = save_frames(job_id, frames)
-    logger.info("Uploaded %d frames", len(manifest))
+    # Create timeline + root checkpoint
+    from detect.checkpoint.storage import create_timeline
 
-    # Build frame metadata
-    frames_meta = [
-        {
-            "sequence": f.sequence,
-            "chunk_id": f.chunk_id,
-            "timestamp": f.timestamp,
-            "perceptual_hash": "",
-        }
-        for f in frames
-    ]
-
-    # All frames are "filtered" (no scene filter ran)
-    filtered_sequences = [f.sequence for f in frames]
-
-    # Save checkpoint as scenario
-    from core.db.detect import save_stage_checkpoint
-    from detect.checkpoint.frames import CHECKPOINT_PREFIX
-
-    checkpoint = save_stage_checkpoint(
-        job_id=job_id,
-        stage="filter_scenes",
-        stage_index=1,
-        frames_prefix=f"{CHECKPOINT_PREFIX}/{job_id}/frames/",
-        frames_manifest={str(k): v for k, v in manifest.items()},
-        frames_meta=frames_meta,
-        filtered_frame_sequences=filtered_sequences,
-        stage_output_key="",
-        stats={"frames_extracted": len(frames), "frames_after_scene_filter": len(frames)},
-        config_snapshot={},
-        config_overrides={},
-        video_path=video_path,
+    timeline_id, checkpoint_id = create_timeline(
+        source_video=video_path,
         profile_name="soccer_broadcast",
-        is_scenario=True,
-        scenario_label=args.label,
+        frames=frames,
+        fps=args.fps,
     )
 
+    # Mark the root checkpoint as a scenario
+    from core.db.detect import get_checkpoint
+    from core.db.connection import get_session
+
+    checkpoint = get_checkpoint(uuid.UUID(checkpoint_id))
+    if checkpoint:
+        checkpoint.is_scenario = True
+        checkpoint.scenario_label = args.label
+        with get_session() as session:
+            session.add(checkpoint)
+            session.commit()
+
     logger.info("")
     logger.info("Scenario created:")
-    logger.info("  ID:     %s", checkpoint.id)
-    logger.info("  Job:    %s", job_id)
+    logger.info("  Timeline:   %s", timeline_id)
+    logger.info("  Checkpoint: %s", checkpoint_id)
     logger.info("  Label:  %s", args.label)
     logger.info("  Frames: %d", len(frames))
    logger.info("")
     logger.info("Open in editor:")
-    logger.info("  http://mpr.local.ar/detection/?job=%s#/editor/detect_edges", job_id)
+    logger.info("  http://mpr.local.ar/detection/?job=%s#/editor/detect_edges", timeline_id)
 
 
 if __name__ == "__main__":
diff --git a/tests/detect/test_stage_registry.py b/tests/detect/test_stage_registry.py
index 6419989..8a94bee 100644
--- a/tests/detect/test_stage_registry.py
+++ b/tests/detect/test_stage_registry.py
@@ -1,6 +1,7 @@
 """Tests for the stage registry."""
 
 from detect.stages import list_stages, get_stage, get_palette
+from detect.stages.base import get_stage_class
 
 
 EXPECTED_STAGES = [
@@ -26,9 +27,17 @@ def test_stage_has_io():
 
 def test_stage_has_serialization():
     for name in EXPECTED_STAGES:
-        stage = get_stage(name)
-        assert stage.serialize_fn is not None, f"{name} has no serialize_fn"
-        assert stage.deserialize_fn is not None, f"{name} has no deserialize_fn"
+        defn = get_stage(name)
+        stage_cls = get_stage_class(name)
+        if stage_cls is not None:
+            # New-style: serialization on the class
+            instance = stage_cls()
+            assert hasattr(instance, 'serialize'), f"{name} has no serialize method"
+            assert hasattr(instance, 'deserialize'), f"{name} has no deserialize method"
+        else:
+            # Legacy: serialization on the definition
+            assert defn.serialize_fn is not None, f"{name} has no serialize_fn"
+            assert defn.deserialize_fn is not None, f"{name} has no deserialize_fn"
 
 
 def test_palette_groups():