refactor stage 1

2026-03-27 04:23:21 -03:00
parent df6bcb01e8
commit 291ac8dd40
14 changed files with 688 additions and 450 deletions

View File

@@ -9,7 +9,8 @@ from sqlmodel import select
from .connection import get_session
from .models import (
DetectJob, StageCheckpoint, KnownBrand, SourceBrandSighting,
DetectJob, Timeline, Checkpoint,
KnownBrand, SourceBrandSighting,
)
@@ -55,72 +56,86 @@ def list_detect_jobs(
# ---------------------------------------------------------------------------
# StageCheckpoint
# Timeline
# ---------------------------------------------------------------------------
def save_stage_checkpoint(**fields) -> StageCheckpoint:
def create_timeline(**fields) -> Timeline:
timeline = Timeline(**fields)
with get_session() as session:
# Upsert: replace if same job_id + stage
job_id = fields.get("job_id")
stage = fields.get("stage")
if job_id and stage:
stmt = select(StageCheckpoint).where(
StageCheckpoint.job_id == job_id,
StageCheckpoint.stage == stage,
)
existing = session.exec(stmt).first()
if existing:
for k, v in fields.items():
setattr(existing, k, v)
session.add(timeline)
session.commit()
session.refresh(existing)
return existing
session.refresh(timeline)
return timeline
checkpoint = StageCheckpoint(**fields)
def get_timeline(timeline_id: UUID) -> Timeline | None:
with get_session() as session:
return session.get(Timeline, timeline_id)
# ---------------------------------------------------------------------------
# Checkpoint
# ---------------------------------------------------------------------------
def save_checkpoint(**fields) -> Checkpoint:
checkpoint = Checkpoint(**fields)
with get_session() as session:
session.add(checkpoint)
session.commit()
session.refresh(checkpoint)
return checkpoint
def get_stage_checkpoint(job_id: UUID, stage: str) -> StageCheckpoint | None:
def get_checkpoint(checkpoint_id: UUID) -> Checkpoint | None:
with get_session() as session:
stmt = select(StageCheckpoint).where(
StageCheckpoint.job_id == job_id,
StageCheckpoint.stage == stage,
return session.get(Checkpoint, checkpoint_id)
def get_latest_checkpoint(timeline_id: UUID, parent_id: UUID | None = None) -> Checkpoint | None:
"""Get the most recent checkpoint for a timeline, optionally from a specific parent."""
with get_session() as session:
stmt = (
select(Checkpoint)
.where(Checkpoint.timeline_id == timeline_id)
)
if parent_id is not None:
stmt = stmt.where(Checkpoint.parent_id == parent_id)
stmt = stmt.order_by(Checkpoint.created_at.desc())
return session.exec(stmt).first()
def list_checkpoints(timeline_id: UUID) -> list[Checkpoint]:
"""List all checkpoints for a timeline."""
with get_session() as session:
stmt = (
select(Checkpoint)
.where(Checkpoint.timeline_id == timeline_id)
.order_by(Checkpoint.created_at)
)
return list(session.exec(stmt).all())
def get_root_checkpoint(timeline_id: UUID) -> Checkpoint | None:
"""Get the root checkpoint (no parent) for a timeline."""
with get_session() as session:
stmt = select(Checkpoint).where(
Checkpoint.timeline_id == timeline_id,
Checkpoint.parent_id == None,  # noqa: E711 (SQLAlchemy renders IS NULL)
)
return session.exec(stmt).first()
def list_stage_checkpoints(job_id: UUID) -> list[str]:
with get_session() as session:
stmt = (
select(StageCheckpoint.stage)
.where(StageCheckpoint.job_id == job_id)
.order_by(StageCheckpoint.stage_index)
)
return list(session.exec(stmt).all())
def list_scenarios() -> list[StageCheckpoint]:
def list_scenarios() -> list[Checkpoint]:
"""List all checkpoints marked as scenarios."""
with get_session() as session:
stmt = (
select(StageCheckpoint)
.where(StageCheckpoint.is_scenario == True)
.order_by(StageCheckpoint.created_at.desc())
select(Checkpoint)
.where(Checkpoint.is_scenario == True)  # noqa: E712
.order_by(Checkpoint.created_at.desc())
)
return list(session.exec(stmt).all())
def delete_stage_checkpoints(job_id: UUID) -> None:
with get_session() as session:
stmt = select(StageCheckpoint).where(StageCheckpoint.job_id == job_id)
for cp in session.exec(stmt).all():
session.delete(cp)
session.commit()
# ---------------------------------------------------------------------------
# KnownBrand
# ---------------------------------------------------------------------------
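Taken together, the helpers above support a branch-and-compare workflow. A minimal sketch of how they compose (illustrative values; assumes a reachable database and the imports shown in this file):

# Sketch: two checkpoints branched from the same root, then queried.
from core.db.detect import (
    create_timeline, save_checkpoint, get_latest_checkpoint, list_checkpoints,
)

timeline = create_timeline(source_video="chunk_0000.mp4", profile_name="soccer_broadcast")
root = save_checkpoint(timeline_id=timeline.id, stage_outputs={})

# Two children of the same parent = two config tries from one starting point.
low = save_checkpoint(timeline_id=timeline.id, parent_id=root.id,
                      config_overrides={"edge_canny_low": 30})
high = save_checkpoint(timeline_id=timeline.id, parent_id=root.id,
                       config_overrides={"edge_canny_low": 80})

tip = get_latest_checkpoint(timeline.id, parent_id=root.id)  # newest child of root
assert len(list_checkpoints(timeline.id)) == 3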

View File

@@ -181,24 +181,30 @@ class DetectJob(SQLModel, table=True):
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
class StageCheckpoint(SQLModel, table=True):
"""A checkpoint saved after a pipeline stage completes."""
__tablename__ = "stage_checkpoints"
class Timeline(SQLModel, table=True):
"""Frame sequence from a source video. Independent of stages."""
__tablename__ = "timelines"
id: UUID = Field(default_factory=uuid4, primary_key=True)
job_id: UUID = Field(index=True)
stage: str
stage_index: int
source_asset_id: Optional[UUID] = Field(default=None, index=True)
source_video: str = ""
profile_name: str = ""
fps: float = 2.0
frames_prefix: str = ""
frames_manifest: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
frames_meta: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
filtered_frame_sequences: List[int] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
stage_output_key: str = "" # s3 key: checkpoints/{job_id}/stages/{stage}.bson
stats: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
config_snapshot: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
class Checkpoint(SQLModel, table=True):
"""Snapshot of pipeline state. parent_id forms a tree."""
__tablename__ = "checkpoints"
id: UUID = Field(default_factory=uuid4, primary_key=True)
timeline_id: UUID = Field(index=True)
parent_id: Optional[UUID] = Field(default=None, index=True)
stage_outputs: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
video_path: str = ""
profile_name: str = ""
stats: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
is_scenario: bool = False
scenario_label: str = ""
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)

View File

@@ -27,9 +27,11 @@ from .grpc import (
)
from .jobs import ChunkJob, ChunkJobStatus, JobStatus, TranscodeJob
from .detect_jobs import (
DetectJob, DetectJobStatus, RunType, StageCheckpoint,
DetectJob, DetectJobStatus, RunType,
Timeline, Checkpoint,
BrandSource, KnownBrand, SourceBrandSighting,
)
from .stages import StageConfigField, StageIO, StageDefinition, STAGE_VIEWS
from .media import AssetStatus, MediaAsset
from .presets import BUILTIN_PRESETS, TranscodePreset
from .detect import DETECT_VIEWS # noqa: F401 — discovered by modelgen generic loader
@@ -40,7 +42,8 @@ from .sources import ChunkInfo, SourceJob, SourceType
# Core domain models - generates Django, SQLModel, TypeScript
DATACLASSES = [MediaAsset, TranscodePreset, TranscodeJob, ChunkJob,
DetectJob, StageCheckpoint, KnownBrand, SourceBrandSighting]
DetectJob, Timeline, Checkpoint,
KnownBrand, SourceBrandSighting]
# API request/response models - generates TypeScript only (no Django)
# WorkerStatus from grpc.py is reused here

View File

@@ -72,49 +72,58 @@ class DetectJob:
@dataclass
class StageCheckpoint:
class Timeline:
"""
A checkpoint saved after a pipeline stage completes.
The frame sequence from a source video.
Binary data (frame images, crops) goes to S3/MinIO.
Everything else (structured state) lives here in Postgres.
Independent of stages — exists before any stage runs.
Stages annotate the timeline, they don't own it.
Frames are stored in MinIO as JPEGs.
"""
id: UUID
job_id: UUID
stage: str
stage_index: int # position in NODES list (0-7)
source_asset_id: Optional[UUID] = None
source_video: str = ""
profile_name: str = ""
fps: float = 2.0
# S3 reference for binary data only
frames_prefix: str = "" # s3 prefix: checkpoints/{job_id}/frames/
# Frame metadata (non-image fields)
# Frame metadata (images in MinIO, metadata here)
frames_prefix: str = "" # s3: timelines/{id}/frames/
frames_manifest: Dict[int, str] = field(default_factory=dict) # seq → s3 key
frames_meta: List[Dict[str, Any]] = field(default_factory=list) # sequence, chunk_id, timestamp, hash
filtered_frame_sequences: List[int] = field(default_factory=list)
frames_meta: List[Dict[str, Any]] = field(default_factory=list)
# Stage output — stored as blob in MinIO: checkpoints/{job_id}/stages/{stage}.bson
# Each stage's serialize_fn/deserialize_fn owns the format.
# Postgres only stores the S3 key, not the data itself.
stage_output_key: str = "" # s3 key to the serialized stage output
created_at: Optional[datetime] = None
# Pipeline state (small, stays in Postgres)
stats: Dict[str, Any] = field(default_factory=dict)
config_snapshot: Dict[str, Any] = field(default_factory=dict)
@dataclass
class Checkpoint:
"""
A snapshot of pipeline state on a timeline.
Stage outputs stored as JSONB — each stage serializes to JSON,
the checkpoint stores it without knowing the shape.
parent_id forms a tree: multiple children from the same parent
= different config tries from the same starting point.
"""
id: UUID
timeline_id: UUID
parent_id: Optional[UUID] = None # null = root checkpoint
# Stage outputs — JSONB per stage, opaque to the checkpoint layer
stage_outputs: Dict[str, Any] = field(default_factory=dict)
# Config that produced this checkpoint
config_overrides: Dict[str, Any] = field(default_factory=dict)
# Input refs (for replay)
video_path: str = ""
profile_name: str = ""
# Pipeline state
stats: Dict[str, Any] = field(default_factory=dict)
# Scenario — a checkpoint bookmarked for the editor workflow.
# Created by seeders (manual scripts that populate state from real footage)
# or captured from a running pipeline. Loaded via URL:
# /detection/?job=<job_id>#/editor/<stage>
# Scenario bookmark
is_scenario: bool = False
scenario_label: str = "" # human-readable name, e.g. "chelsea_edges_lowcanny"
scenario_label: str = ""
# Timestamps
created_at: Optional[datetime] = None
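Because parent_id forms a tree, a checkpoint's full history is recovered by walking parents. A sketch (the by_id lookup is a hypothetical in-memory index, not part of this commit):

# Sketch: reconstruct a checkpoint's ancestry, root first.
from uuid import UUID

def lineage(checkpoint: "Checkpoint", by_id: dict[UUID, "Checkpoint"]) -> list["Checkpoint"]:
    chain = [checkpoint]
    while chain[-1].parent_id is not None:
        chain.append(by_id[chain[-1].parent_id])
    return list(reversed(chain))  # oldest (root) first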

View File

@@ -0,0 +1,64 @@
"""
Stage Schema Definitions
Source of truth for pipeline stage metadata.
Generates: Pydantic, TypeScript via modelgen.
Each stage is defined by its config fields. The implementation
lives in detect/stages/<name>.py as a Stage subclass.
"""
from dataclasses import dataclass, field
from typing import Any, List, Optional
@dataclass
class StageConfigField:
"""A single tunable config parameter for the editor UI."""
name: str
type: str # "float", "int", "str", "bool"
default: Any
description: str = ""
min: Optional[float] = None
max: Optional[float] = None
options: Optional[List[str]] = None
@dataclass
class StageIO:
"""Declares what a stage reads and writes."""
reads: List[str] = field(default_factory=list)
writes: List[str] = field(default_factory=list)
optional_reads: List[str] = field(default_factory=list)
@dataclass
class StageDefinition:
"""
Complete metadata for a pipeline stage.
Lives in schema as the source of truth. Each stage implementation
references a StageDefinition. The editor, graph, and checkpoint
system all consume this.
"""
name: str
label: str
description: str
category: str = "detection"
io: StageIO = field(default_factory=StageIO)
config_fields: List[StageConfigField] = field(default_factory=list)
# Legacy fields — used by old registry pattern during migration.
# New stages use Stage subclass instead.
fn: Any = None
serialize_fn: Any = None
deserialize_fn: Any = None
# --- Export for modelgen ---
STAGE_VIEWS = [
StageConfigField,
StageIO,
StageDefinition,
]
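These definitions carry enough metadata for the editor to validate user input before a stage runs. One possible consumer, sketched (hypothetical helper, not part of this commit):

# Sketch: fill defaults and clamp numeric values against a stage's declared fields.
def resolve_config(definition: StageDefinition, raw: dict) -> dict:
    resolved = dict(raw)
    for f in definition.config_fields:
        value = raw.get(f.name, f.default)
        if f.type in ("int", "float") and value is not None:
            if f.min is not None:
                value = max(f.min, value)
            if f.max is not None:
                value = min(f.max, value)
            value = int(value) if f.type == "int" else float(value)
        resolved[f.name] = value
    return resolved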

View File

@@ -1,14 +1,18 @@
"""
Stage checkpoint, replay, and retry.
Checkpoint system — Timeline + Checkpoint tree.
detect/checkpoint/
frames.py — frame image S3 upload/download
serializer.py — state ↔ JSON conversion
storage.py — checkpoint save/load/list (Postgres + S3)
replay.py — replay_from, OverrideProfile
storage.py — Timeline + Checkpoint (Postgres + MinIO)
replay.py — replay (TODO: migrate to new model)
tasks.py — retry_candidates Celery task
"""
from .storage import save_checkpoint, load_checkpoint, list_checkpoints
from .storage import (
create_timeline,
get_timeline_frames,
get_timeline_frames_b64,
save_stage_output,
load_stage_output,
)
from .frames import save_frames, load_frames
from .replay import replay_from, OverrideProfile

View File

@@ -12,7 +12,13 @@ import logging
import uuid
from detect import emit
from detect.checkpoint import load_checkpoint, list_checkpoints
# TODO: migrate to Timeline/Checkpoint model
# These old functions no longer exist — replay needs rework
def _not_migrated(*args, **kwargs):
raise NotImplementedError("Replay not yet migrated to Timeline/Checkpoint model")
load_checkpoint = _not_migrated
list_checkpoints = _not_migrated
from detect.graph import NODES, build_graph
logger = logging.getLogger(__name__)

View File

@@ -1,116 +1,178 @@
"""
Checkpoint storage — save/load stage state.
Checkpoint storage — Timeline + Checkpoint (tree of snapshots).
Binary data (frame images) → S3/MinIO via frames.py
Structured data (stage output, stats, config) → Postgres
Timeline: frame sequence from source video (frames in MinIO)
Checkpoint: snapshot of pipeline state (stage outputs as JSONB in Postgres)
parent_id forms a tree — multiple children = different config tries
"""
from __future__ import annotations
import logging
from uuid import UUID
from .frames import save_frames, load_frames, CHECKPOINT_PREFIX
from .serializer import serialize_state, deserialize_state
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Save
# Timeline
# ---------------------------------------------------------------------------
def save_checkpoint(
job_id: str,
stage: str,
stage_index: int,
state: dict,
frames_manifest: dict[int, str] | None = None,
def create_timeline(
source_video: str,
profile_name: str,
frames: list,
fps: float = 2.0,
source_asset_id: UUID | None = None,
) -> tuple[str, str]:
"""
Create a timeline from frames. Uploads frame images to MinIO,
creates Timeline + root Checkpoint in Postgres.
Returns (timeline_id, checkpoint_id).
"""
from core.db.detect import create_timeline as db_create_timeline
from core.db.detect import save_checkpoint
# Create timeline
timeline = db_create_timeline(
source_video=source_video,
profile_name=profile_name,
source_asset_id=source_asset_id,
fps=fps,
)
tid = str(timeline.id)
# Upload frames to MinIO
manifest = save_frames(tid, frames)
# Store frame metadata on the timeline
frames_meta = [
{
"sequence": f.sequence,
"chunk_id": getattr(f, "chunk_id", 0),
"timestamp": f.timestamp,
"perceptual_hash": getattr(f, "perceptual_hash", ""),
}
for f in frames
]
timeline.frames_prefix = f"{CHECKPOINT_PREFIX}/{tid}/frames/"
timeline.frames_manifest = {str(k): v for k, v in manifest.items()}
timeline.frames_meta = frames_meta
from core.db.connection import get_session
with get_session() as session:
session.add(timeline)
session.commit()
# Create root checkpoint (no parent, no stage outputs yet)
checkpoint = save_checkpoint(
timeline_id=timeline.id,
parent_id=None,
stage_outputs={},
stats={"frames_extracted": len(frames)},
)
logger.info("Timeline created: %s (%d frames, root checkpoint %s)",
tid, len(frames), checkpoint.id)
return tid, str(checkpoint.id)
def get_timeline_frames(timeline_id: str) -> list:
"""Load frames from a timeline (from MinIO) as Frame objects."""
from core.db.detect import get_timeline
timeline = get_timeline(UUID(timeline_id))
if not timeline:
raise ValueError(f"Timeline not found: {timeline_id}")
raw_manifest = timeline.frames_manifest or {}
manifest = {int(k): v for k, v in raw_manifest.items()}
frame_metadata = timeline.frames_meta or []
return load_frames(manifest, frame_metadata)
def get_timeline_frames_b64(timeline_id: str) -> list[dict]:
"""Load frames as base64 JPEG (lightweight, no numpy)."""
from core.db.detect import get_timeline
from .frames import load_frames_b64
timeline = get_timeline(UUID(timeline_id))
if not timeline:
raise ValueError(f"Timeline not found: {timeline_id}")
raw_manifest = timeline.frames_manifest or {}
manifest = {int(k): v for k, v in raw_manifest.items()}
frame_metadata = timeline.frames_meta or []
return load_frames_b64(manifest, frame_metadata)
# ---------------------------------------------------------------------------
# Checkpoint
# ---------------------------------------------------------------------------
def save_stage_output(
timeline_id: str,
parent_checkpoint_id: str | None,
stage_name: str,
output_json: dict,
config_overrides: dict | None = None,
stats: dict | None = None,
is_scenario: bool = False,
scenario_label: str = "",
) -> str:
"""
Save a stage checkpoint.
Save a stage's output as a new checkpoint (child of parent).
Saves frame images to S3 (if not already saved), then persists
structured state to Postgres.
Returns the checkpoint DB id.
Carries forward stage outputs from parent + adds the new one.
Returns the new checkpoint ID.
"""
from core.db.detect import save_stage_checkpoint
from core.db.detect import get_checkpoint, save_checkpoint
if frames_manifest is None:
all_frames = state.get("frames", [])
frames_manifest = save_frames(job_id, all_frames)
# Carry forward from parent
parent_outputs = {}
parent_stats = {}
parent_config = {}
if parent_checkpoint_id:
parent = get_checkpoint(UUID(parent_checkpoint_id))
if parent:
parent_outputs = dict(parent.stage_outputs or {})
parent_stats = dict(parent.stats or {})
parent_config = dict(parent.config_overrides or {})
checkpoint_data = serialize_state(state, frames_manifest)
frames_prefix = f"{CHECKPOINT_PREFIX}/{job_id}/frames/"
# Add new stage output
stage_outputs = {**parent_outputs, stage_name: output_json}
checkpoint = save_stage_checkpoint(
job_id=job_id,
stage=stage,
stage_index=stage_index,
frames_prefix=frames_prefix,
frames_manifest=checkpoint_data.get("frames_manifest", {}),
frames_meta=checkpoint_data.get("frames_meta", []),
filtered_frame_sequences=checkpoint_data.get("filtered_frame_sequences", []),
stage_output_key=checkpoint_data.get("stage_output_key", ""),
stats=checkpoint_data.get("stats", {}),
config_snapshot=checkpoint_data.get("config_overrides", {}),
config_overrides=checkpoint_data.get("config_overrides", {}),
video_path=checkpoint_data.get("video_path", ""),
profile_name=checkpoint_data.get("profile_name", ""),
# Merge stats and config
merged_stats = {**parent_stats, **(stats or {})}
merged_config = {**parent_config, **(config_overrides or {})}
checkpoint = save_checkpoint(
timeline_id=timeline_id,
parent_id=parent_checkpoint_id,
stage_outputs=stage_outputs,
config_overrides=merged_config,
stats=merged_stats,
is_scenario=is_scenario,
scenario_label=scenario_label,
)
logger.info("Checkpoint saved: %s/%s (id=%s, scenario=%s)",
job_id, stage, checkpoint.id, is_scenario)
logger.info("Checkpoint saved: %s (timeline %s, stage %s, parent %s)",
checkpoint.id, timeline_id, stage_name, parent_checkpoint_id)
return str(checkpoint.id)
# ---------------------------------------------------------------------------
# Load
# ---------------------------------------------------------------------------
def load_stage_output(checkpoint_id: str, stage_name: str) -> dict | None:
"""Load a stage's output from a checkpoint."""
from core.db.detect import get_checkpoint
def load_checkpoint(job_id: str, stage: str) -> dict:
"""
Load a stage checkpoint and reconstitute full DetectState.
"""
from core.db.detect import get_stage_checkpoint
checkpoint = get_stage_checkpoint(job_id, stage)
checkpoint = get_checkpoint(UUID(checkpoint_id))
if not checkpoint:
raise ValueError(f"No checkpoint for {job_id}/{stage}")
return None
data = {
"job_id": str(checkpoint.job_id),
"video_path": checkpoint.video_path,
"profile_name": checkpoint.profile_name,
"config_overrides": checkpoint.config_overrides,
"frames_manifest": checkpoint.frames_manifest,
"frames_meta": checkpoint.frames_meta,
"filtered_frame_sequences": checkpoint.filtered_frame_sequences,
"stage_output_key": checkpoint.stage_output_key,
"stats": checkpoint.stats,
}
raw_manifest = data.get("frames_manifest", {})
manifest = {int(k): v for k, v in raw_manifest.items()}
frame_metadata = data.get("frames_meta", [])
frames = load_frames(manifest, frame_metadata)
state = deserialize_state(data, frames)
logger.info("Checkpoint loaded: %s/%s (%d frames, scenario=%s)",
job_id, stage, len(frames), checkpoint.is_scenario)
return state
# ---------------------------------------------------------------------------
# List
# ---------------------------------------------------------------------------
def list_checkpoints(job_id: str) -> list[str]:
"""List available checkpoint stages for a job."""
from core.db.detect import list_stage_checkpoints
return list_stage_checkpoints(job_id)
return (checkpoint.stage_outputs or {}).get(stage_name)
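End to end, the new flow is: create the timeline once, then one checkpoint per stage run. A sketch (assumes Postgres/MinIO are reachable and frames is a list of Frame objects; the output payload is illustrative):

# Sketch: create a timeline, checkpoint one stage's output, read it back.
from detect.checkpoint.storage import (
    create_timeline, save_stage_output, load_stage_output,
)

timeline_id, root_id = create_timeline(
    source_video="chunk_0000.mp4",
    profile_name="soccer_broadcast",
    frames=frames,  # Frame objects, e.g. from ffmpeg extraction
)

cp_id = save_stage_output(
    timeline_id=timeline_id,
    parent_checkpoint_id=root_id,
    stage_name="detect_edges",
    output_json={"0": [{"x": 10, "y": 20, "w": 300, "h": 40}]},
    stats={"cv_regions_detected": 1},
)

# The child carries the parent's stage outputs forward plus the new one.
edges = load_stage_output(cp_id, "detect_edges")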

View File

@@ -1,21 +1,21 @@
"""
Pipeline stages.
Each stage registers its StageDefinition on import,
declaring IO (what it reads/writes from state),
config fields (what's tunable from the editor),
and serialization (how to checkpoint its outputs).
Each stage is a file with a Stage subclass. Auto-discovered via
__init_subclass__ — importing the file registers the stage.
"""
from .base import (
StageDefinition,
StageIO,
StageConfigField,
register_stage,
Stage,
get_stage,
get_stage_instance,
list_stages,
list_stage_classes,
get_palette,
)
# Populate registry with built-in stages
# Import all stage files to trigger auto-registration
from . import edge_detector # noqa: F401
# Import registry for backward compat (other stages still use old pattern)
from . import registry # noqa: F401

View File

@@ -1,101 +1,131 @@
"""
Stage protocol — common interface for all pipeline stages.
Stage base class — common interface for all pipeline stages.
Every stage declares:
- IO: what it reads/writes from DetectState
- Config: tunable parameters for the editor
- Serialization: how to persist/restore its own outputs
Each stage is a file that subclasses Stage. Auto-discovered via
__init_subclass__. No manual registration needed.
The checkpoint layer is a black box — it asks each stage to serialize its
outputs and stores the result. Stages own their data format. Binary data
(frames, crops) goes to S3 via the stage itself. The checkpoint just
stores the JSON envelope.
A stage:
- Has a StageDefinition (from schema) with name, config, IO
- Implements run(frames, config) → output
- Owns its output serialization (opaque blob)
- Optionally has a TypeScript port for browser-side execution
The graph builder uses StageIO to validate that a stage's inputs are
satisfied by previous stages' outputs.
The checkpoint layer stores stage output as blobs without knowing
the format. The stage that wrote it is the only one that can read it.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Callable
from typing import Any
import numpy as np
@dataclass
class StageIO:
"""Declares what a stage reads and writes from/to DetectState."""
reads: list[str]
writes: list[str]
optional_reads: list[str] = field(default_factory=list)
@dataclass
class StageConfigField:
"""A single tunable config parameter for the editor UI."""
name: str
type: str # "float", "int", "str", "bool", "list[str]"
default: Any
description: str = ""
min: float | None = None
max: float | None = None
options: list[str] | None = None
@dataclass
class StageDefinition:
"""
Complete metadata for a pipeline stage.
The profile editor uses this to build the palette, generate config
forms, and validate graph connections. The checkpoint uses serialize_fn
and deserialize_fn to persist stage outputs without knowing the internals.
"""
name: str
label: str
description: str
io: StageIO
config_fields: list[StageConfigField] = field(default_factory=list)
category: str = "detection"
# The actual graph node function: (DetectState) → dict
fn: Callable | None = None
# Stage-owned serialization for checkpointing.
# serialize_fn: (state: dict, job_id: str) → json-compatible dict
# Stage picks its writes from state, serializes them.
# Binary data (frames) → S3 via stage, returns refs.
# deserialize_fn: (data: dict, job_id: str) → state update dict
# Stage restores its writes from the persisted data.
serialize_fn: Callable | None = None
deserialize_fn: Callable | None = None
from core.schema.models.stages import StageConfigField, StageIO, StageDefinition
# ---------------------------------------------------------------------------
# Registry
# Registry — auto-populated by __init_subclass__ (new stages)
# + register_stage() (legacy stages during migration)
# ---------------------------------------------------------------------------
_REGISTRY: dict[str, StageDefinition] = {}
_REGISTRY: dict[str, type['Stage']] = {}
_LEGACY_REGISTRY: dict[str, StageDefinition] = {}
def register_stage(definition: StageDefinition):
_REGISTRY[definition.name] = definition
"""Legacy registration for stages not yet converted to Stage subclass."""
_LEGACY_REGISTRY[definition.name] = definition
class Stage:
"""
Base class for all pipeline stages.
Subclass this in detect/stages/<name>.py. Define `definition` as a
class attribute. Implement `run()`. Optionally override `serialize()`
and `deserialize()` for custom blob formats (default is JSON).
"""
definition: StageDefinition # set by each subclass
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if hasattr(cls, 'definition') and cls.definition is not None:
_REGISTRY[cls.definition.name] = cls
def run(self, frames: list, config: dict) -> Any:
"""
Run the stage on a list of frames with the given config.
Config is a dict of parameter values (from slider UI or profile).
Returns the stage output — whatever shape this stage produces.
Debug overlays are included when config has debug=True.
"""
raise NotImplementedError
def serialize(self, output: Any) -> bytes:
"""Serialize stage output to bytes for checkpoint storage."""
import json
return json.dumps(output, default=str).encode()
def deserialize(self, data: bytes) -> Any:
"""Deserialize stage output from checkpoint blob."""
import json
return json.loads(data)
# ---------------------------------------------------------------------------
# Discovery API
# ---------------------------------------------------------------------------
def _all_definitions() -> dict[str, StageDefinition]:
"""Merge new Stage subclass registry + legacy registry."""
merged = {}
# Legacy first, new overwrites (new takes precedence)
for name, defn in _LEGACY_REGISTRY.items():
merged[name] = defn
for name, cls in _REGISTRY.items():
merged[name] = cls.definition
return merged
def get_stage(name: str) -> StageDefinition:
if name not in _REGISTRY:
raise KeyError(f"Unknown stage: {name!r}. Registered: {list(_REGISTRY)}")
return _REGISTRY[name]
"""Get a stage definition by name (works for both new and legacy)."""
all_defs = _all_definitions()
if name not in all_defs:
raise KeyError(f"Unknown stage: {name!r}. Registered: {list(all_defs)}")
return all_defs[name]
def get_stage_class(name: str) -> type[Stage] | None:
"""Get a Stage subclass by name. Returns None for legacy stages."""
return _REGISTRY.get(name)
def get_stage_instance(name: str) -> Stage:
"""Get an instantiated Stage by name. Only works for new-style stages."""
cls = _REGISTRY.get(name)
if cls is None:
raise KeyError(f"No Stage subclass for {name!r}. Legacy stages don't have instances.")
return cls()
def list_stages() -> list[StageDefinition]:
"""List all registered stage definitions (new + legacy)."""
return list(_all_definitions().values())
def list_stage_classes() -> list[type[Stage]]:
"""List all registered Stage subclasses (new-style only)."""
return list(_REGISTRY.values())
def get_palette() -> dict[str, list[StageDefinition]]:
"""Group stages by category for the editor palette."""
palette: dict[str, list[StageDefinition]] = {}
for stage in _REGISTRY.values():
if stage.category not in palette:
palette[stage.category] = []
palette[stage.category].append(stage)
for defn in _all_definitions().values():
if defn.category not in palette:
palette[defn.category] = []
palette[defn.category].append(defn)
return palette
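With __init_subclass__, defining the class is the registration. A minimal sketch (hypothetical noop stage, for illustration only):

# Sketch: the class body alone registers the stage; no register_stage() call.
from core.schema.models.stages import StageDefinition, StageIO
from detect.stages.base import Stage, get_stage, get_stage_instance

class NoopStage(Stage):
    definition = StageDefinition(
        name="noop",
        label="No-op",
        description="Passes frames through unchanged",
        io=StageIO(reads=["frames"], writes=[]),
    )

    def run(self, frames: list, config: dict) -> dict:
        return {}

assert get_stage("noop").label == "No-op"
assert isinstance(get_stage_instance("noop"), NoopStage)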

View File

@@ -7,120 +7,67 @@ advertising hoardings. Pure OpenCV, no ML models.
Two modes:
- Remote: calls GPU inference server over HTTP
- Local: imports cv2 directly (OpenCV on same machine)
Emits frame_update events with bounding boxes for the frame viewer.
"""
from __future__ import annotations
import base64
import io
import json
import logging
import os
import time
from typing import Any
from PIL import Image
from detect import emit
from detect.models import BoundingBox, Frame
from detect.profiles.base import RegionAnalysisConfig
from detect.stages.base import Stage
from core.schema.models.stages import StageDefinition, StageConfigField, StageIO
logger = logging.getLogger(__name__)
def _frame_to_b64(frame: Frame) -> str:
"""Encode frame as base64 JPEG for SSE frame_update events."""
img = Image.fromarray(frame.image)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=70)
return base64.b64encode(buf.getvalue()).decode()
class EdgeDetectionStage(Stage):
def _detect_remote(
frame: Frame,
config: RegionAnalysisConfig,
inference_url: str,
job_id: str = "",
log_level: str = "INFO",
) -> list[BoundingBox]:
"""Call the inference server over HTTP."""
from detect.inference import InferenceClient
client = InferenceClient(
base_url=inference_url, job_id=job_id, log_level=log_level,
)
results = client.detect_edges(
image=frame.image,
edge_canny_low=config.edge_canny_low,
edge_canny_high=config.edge_canny_high,
edge_hough_threshold=config.edge_hough_threshold,
edge_hough_min_length=config.edge_hough_min_length,
edge_hough_max_gap=config.edge_hough_max_gap,
edge_pair_max_distance=config.edge_pair_max_distance,
edge_pair_min_distance=config.edge_pair_min_distance,
)
boxes = []
for r in results:
box = BoundingBox(
x=r.x, y=r.y, w=r.w, h=r.h,
confidence=r.confidence, label=r.label,
)
boxes.append(box)
return boxes
_cv_edges_mod = None
def _load_cv_edges():
"""Load edges module directly — gpu/models/__init__.py has GPU-container-only imports."""
global _cv_edges_mod
if _cv_edges_mod is None:
import importlib.util
from pathlib import Path
spec = importlib.util.spec_from_file_location("cv_edges", Path("gpu/models/cv/edges.py"))
_cv_edges_mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(_cv_edges_mod)
return _cv_edges_mod
def _detect_local(frame: Frame, config: RegionAnalysisConfig) -> list[BoundingBox]:
"""Run edge detection in-process (requires opencv-python)."""
detect_edges_fn = _load_cv_edges().detect_edges
edge_results = detect_edges_fn(
frame.image,
canny_low=config.edge_canny_low,
canny_high=config.edge_canny_high,
hough_threshold=config.edge_hough_threshold,
hough_min_length=config.edge_hough_min_length,
hough_max_gap=config.edge_hough_max_gap,
pair_max_distance=config.edge_pair_max_distance,
pair_min_distance=config.edge_pair_min_distance,
definition = StageDefinition(
name="detect_edges",
label="Edge Detection",
description="Canny + HoughLinesP — find horizontal line pairs (hoarding boundaries)",
category="cv_analysis",
io=StageIO(
reads=["filtered_frames"],
writes=["edge_regions_by_frame"],
),
config_fields=[
StageConfigField("enabled", "bool", True, "Enable edge detection"),
StageConfigField("edge_canny_low", "int", 50, "Canny low threshold", min=0, max=255),
StageConfigField("edge_canny_high", "int", 150, "Canny high threshold", min=0, max=255),
StageConfigField("edge_hough_threshold", "int", 80, "Hough accumulator threshold", min=1, max=500),
StageConfigField("edge_hough_min_length", "int", 100, "Min line length (px)", min=10, max=2000),
StageConfigField("edge_hough_max_gap", "int", 10, "Max line gap (px)", min=1, max=100),
StageConfigField("edge_pair_max_distance", "int", 200, "Max distance between line pair (px)", min=10, max=500),
StageConfigField("edge_pair_min_distance", "int", 15, "Min distance between line pair (px)", min=5, max=200),
],
)
boxes = []
for r in edge_results:
box = BoundingBox(
x=r["x"], y=r["y"], w=r["w"], h=r["h"],
confidence=r["confidence"], label=r["label"],
)
boxes.append(box)
return boxes
def detect_edge_regions(
frames: list[Frame],
config: RegionAnalysisConfig,
inference_url: str | None = None,
job_id: str | None = None,
) -> dict[int, list[BoundingBox]]:
def run(self, frames: list[Frame], config: dict) -> dict[int, list[BoundingBox]]:
"""
Run edge detection on all frames.
Returns a dict mapping frame sequence → list of bounding boxes.
Config keys: enabled, edge_canny_low, edge_canny_high, edge_hough_threshold,
edge_hough_min_length, edge_hough_max_gap, edge_pair_max_distance, edge_pair_min_distance,
debug (bool), inference_url (str|None), job_id (str|None).
Returns dict mapping frame sequence → list of BoundingBox.
"""
if not config.enabled:
enabled = config.get("enabled", True)
job_id = config.get("job_id")
inference_url = config.get("inference_url") or os.environ.get("INFERENCE_URL")
if not enabled:
emit.log(job_id, "EdgeDetection", "INFO", "Edge detection disabled, skipping")
return {}
@@ -131,32 +78,26 @@ def detect_edge_regions(
all_boxes: dict[int, list[BoundingBox]] = {}
total_regions = 0
for i, frame in enumerate(frames):
for frame in frames:
t0 = time.monotonic()
if inference_url:
from detect.emit import _run_log_level
boxes = _detect_remote(
frame, config, inference_url,
job_id=job_id or "", log_level=_run_log_level,
)
boxes = self._run_remote(frame, config, inference_url, job_id or "")
else:
boxes = _detect_local(frame, config)
analysis_ms = (time.monotonic() - t0) * 1000
boxes = self._run_local(frame, config)
ms = (time.monotonic() - t0) * 1000
all_boxes[frame.sequence] = boxes
total_regions += len(boxes)
emit.log(job_id, "EdgeDetection", "DEBUG",
f"Frame {frame.sequence}: {len(boxes)} regions in {analysis_ms:.0f}ms"
f"Frame {frame.sequence}: {len(boxes)} regions in {ms:.0f}ms"
+ (f" [{', '.join(b.label for b in boxes)}]" if boxes else ""))
if boxes and job_id:
box_dicts = [
{
"x": b.x, "y": b.y, "w": b.w, "h": b.h,
{"x": b.x, "y": b.y, "w": b.w, "h": b.h,
"confidence": b.confidence, "label": b.label,
"stage": "detect_edges",
}
"stage": "detect_edges"}
for b in boxes
]
emit.frame_update(
@@ -172,3 +113,121 @@ def detect_edge_regions(
emit.stats(job_id, cv_regions_detected=total_regions)
return all_boxes
def serialize(self, output: Any) -> bytes:
"""Serialize edge regions to JSON blob."""
serialized = {}
for seq, boxes in output.items():
serialized[str(seq)] = [
{"x": b.x, "y": b.y, "w": b.w, "h": b.h,
"confidence": b.confidence, "label": b.label}
for b in boxes
]
return json.dumps(serialized).encode()
def deserialize(self, data: bytes) -> dict[int, list[BoundingBox]]:
"""Deserialize edge regions from JSON blob."""
raw = json.loads(data)
result = {}
for seq_str, box_dicts in raw.items():
boxes = [
BoundingBox(x=b["x"], y=b["y"], w=b["w"], h=b["h"],
confidence=b["confidence"], label=b["label"])
for b in box_dicts
]
result[int(seq_str)] = boxes
return result
# --- Private helpers ---
def _run_remote(self, frame: Frame, config: dict,
inference_url: str, job_id: str) -> list[BoundingBox]:
from detect.inference import InferenceClient
from detect.emit import _run_log_level
client = InferenceClient(
base_url=inference_url, job_id=job_id, log_level=_run_log_level,
)
results = client.detect_edges(
image=frame.image,
edge_canny_low=config.get("edge_canny_low", 50),
edge_canny_high=config.get("edge_canny_high", 150),
edge_hough_threshold=config.get("edge_hough_threshold", 80),
edge_hough_min_length=config.get("edge_hough_min_length", 100),
edge_hough_max_gap=config.get("edge_hough_max_gap", 10),
edge_pair_max_distance=config.get("edge_pair_max_distance", 200),
edge_pair_min_distance=config.get("edge_pair_min_distance", 15),
)
boxes = []
for r in results:
box = BoundingBox(
x=r.x, y=r.y, w=r.w, h=r.h,
confidence=r.confidence, label=r.label,
)
boxes.append(box)
return boxes
def _run_local(self, frame: Frame, config: dict) -> list[BoundingBox]:
detect_edges_fn = _load_cv_edges().detect_edges
edge_results = detect_edges_fn(
frame.image,
canny_low=config.get("edge_canny_low", 50),
canny_high=config.get("edge_canny_high", 150),
hough_threshold=config.get("edge_hough_threshold", 80),
hough_min_length=config.get("edge_hough_min_length", 100),
hough_max_gap=config.get("edge_hough_max_gap", 10),
pair_max_distance=config.get("edge_pair_max_distance", 200),
pair_min_distance=config.get("edge_pair_min_distance", 15),
)
boxes = []
for r in edge_results:
box = BoundingBox(
x=r["x"], y=r["y"], w=r["w"], h=r["h"],
confidence=r["confidence"], label=r["label"],
)
boxes.append(box)
return boxes
# --- Module-level helpers ---
def _frame_to_b64(frame: Frame) -> str:
img = Image.fromarray(frame.image)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=70)
return base64.b64encode(buf.getvalue()).decode()
_cv_edges_mod = None
def _load_cv_edges():
global _cv_edges_mod
if _cv_edges_mod is None:
import importlib.util
from pathlib import Path
spec = importlib.util.spec_from_file_location("cv_edges", Path("gpu/models/cv/edges.py"))
_cv_edges_mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(_cv_edges_mod)
return _cv_edges_mod
# --- Backward compat: standalone function for graph.py ---
def detect_edge_regions(frames, config, inference_url=None, job_id=None):
"""Convenience wrapper — calls EdgeDetectionStage.run()."""
stage = EdgeDetectionStage()
cfg = {
"enabled": config.enabled,
"edge_canny_low": config.edge_canny_low,
"edge_canny_high": config.edge_canny_high,
"edge_hough_threshold": config.edge_hough_threshold,
"edge_hough_min_length": config.edge_hough_min_length,
"edge_hough_max_gap": config.edge_hough_max_gap,
"edge_pair_max_distance": config.edge_pair_max_distance,
"edge_pair_min_distance": config.edge_pair_min_distance,
"inference_url": inference_url,
"job_id": job_id,
}
return stage.run(frames, cfg)
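Exercising the class directly is now straightforward. A sketch (assumes frames is a list of Frame objects and OpenCV is importable, so local mode runs):

# Sketch: run locally, then round-trip the output through the blob format.
from detect.stages.edge_detector import EdgeDetectionStage

stage = EdgeDetectionStage()
boxes_by_seq = stage.run(frames, {"edge_canny_low": 30, "edge_canny_high": 120})

blob = stage.serialize(boxes_by_seq)    # JSON bytes, sequence keys as strings
restored = stage.deserialize(blob)      # dict[int, list[BoundingBox]] again
assert restored.keys() == boxes_by_seq.keys()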

View File

@@ -45,15 +45,15 @@ def main():
return
logger.info("")
logger.info("%3s %-35s %-12s %-18s %6s %s", "#", "Label", "Job ID", "Stage", "Frames", "Created")
logger.info("" * 100)
logger.info("%3s %-35s %-12s %6s %s", "#", "Label", "Timeline", "Stages", "Created")
logger.info("" * 80)
for i, s in enumerate(scenarios, 1):
manifest = s.frames_manifest or {}
created = str(s.created_at)[:19] if s.created_at else ""
job_short = str(s.job_id)[:8]
logger.info("%3d %-35s %-12s %-18s %6d %s",
i, s.scenario_label, job_short, s.stage, len(manifest), created)
tid_short = str(s.timeline_id)[:8]
stage_count = len(s.stage_outputs or {})
logger.info("%3d %-35s %-12s %6d %s",
i, s.scenario_label, tid_short, stage_count, created)
logger.info("")
@@ -73,7 +73,7 @@ def main():
logger.error("Scenario not found: %s", args.open)
return
url = f"{args.base_url}?job={target.job_id}#/editor/detect_edges"
url = f"{args.base_url}?job={target.timeline_id}#/editor/detect_edges"
logger.info("Opening: %s", url)
webbrowser.open(url)
else:

View File

@@ -1,26 +1,20 @@
#!/usr/bin/env python3
"""
Seed a scenario checkpoint from a video chunk.
Seed a scenario from a video chunk.
Extracts frames via ffmpeg, uploads to MinIO, creates a StageCheckpoint
in Postgres marked as a scenario. No pipeline, no Redis, no SSE.
Extracts frames via ffmpeg, creates a Timeline (frames in MinIO) + root
Checkpoint marked as a scenario. No pipeline, no Redis, no SSE.
Prerequisites:
- Postgres reachable (port-forward or local)
- MinIO reachable (port-forward or local)
- Postgres reachable (Kind NodePort or local)
- MinIO reachable (Kind NodePort or local)
Usage:
# With K8s port-forwards:
kubectl port-forward svc/postgres 5432:5432 &
kubectl port-forward svc/minio 9000:9000 &
python tests/detect/manual/seed_scenario.py
# Custom video:
python tests/detect/manual/seed_scenario.py --video media/mpr/out/chunks/.../chunk_0001.mp4
Then open:
http://mpr.local.ar/detection/?job=<JOB_ID>#/editor/detect_edges
http://mpr.local.ar/detection/?job=<TIMELINE_ID>#/editor/detect_edges
"""
from __future__ import annotations
@@ -31,7 +25,7 @@ import os
import sys
import uuid
parser = argparse.ArgumentParser(description="Seed a scenario checkpoint")
parser = argparse.ArgumentParser(description="Seed a scenario")
parser.add_argument("--video",
default="media/mpr/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e/chunk_0000.mp4")
parser.add_argument("--label", default="chelsea_edges_default",
@@ -44,7 +38,6 @@ parser.add_argument("--s3-url",
default=os.environ.get("S3_ENDPOINT_URL", "http://localhost:9000"))
args = parser.parse_args()
# Set env before imports
os.environ["DATABASE_URL"] = args.db_url
os.environ["S3_ENDPOINT_URL"] = args.s3_url
os.environ.setdefault("AWS_ACCESS_KEY_ID", "minioadmin")
@@ -57,7 +50,7 @@ logger = logging.getLogger(__name__)
def extract_frames_ffmpeg(video_path: str, fps: float, max_frames: int):
"""Extract frames using ffmpeg subprocess — no pipeline dependencies."""
"""Extract frames using ffmpeg — no pipeline dependencies."""
import subprocess
import tempfile
from pathlib import Path
@@ -82,7 +75,7 @@ def extract_frames_ffmpeg(video_path: str, fps: float, max_frames: int):
frames = []
for jpg in sorted(Path(tmpdir).glob("frame_*.jpg")):
seq = int(jpg.stem.split("_")[1]) - 1 # 0-indexed
seq = int(jpg.stem.split("_")[1]) - 1
img = Image.open(jpg).convert("RGB")
image_array = np.array(img)
frame = Frame(
@@ -99,7 +92,6 @@ def extract_frames_ffmpeg(video_path: str, fps: float, max_frames: int):
def main():
job_id = str(uuid.uuid4())
video_path = args.video
if not os.path.exists(video_path):
@@ -107,7 +99,6 @@ def main():
sys.exit(1)
logger.info("Video: %s", video_path)
logger.info("Job ID: %s", job_id)
logger.info("Label: %s", args.label)
# Ensure DB tables exist
@@ -119,57 +110,37 @@ def main():
frames = extract_frames_ffmpeg(video_path, args.fps, args.max_frames)
logger.info("Extracted %d frames", len(frames))
# Upload frames to MinIO
from detect.checkpoint.frames import save_frames
logger.info("Uploading frames to MinIO...")
manifest = save_frames(job_id, frames)
logger.info("Uploaded %d frames", len(manifest))
# Create timeline + root checkpoint
from detect.checkpoint.storage import create_timeline, save_stage_output
# Build frame metadata
frames_meta = [
{
"sequence": f.sequence,
"chunk_id": f.chunk_id,
"timestamp": f.timestamp,
"perceptual_hash": "",
}
for f in frames
]
# All frames are "filtered" (no scene filter ran)
filtered_sequences = [f.sequence for f in frames]
# Save checkpoint as scenario
from core.db.detect import save_stage_checkpoint
from detect.checkpoint.frames import CHECKPOINT_PREFIX
checkpoint = save_stage_checkpoint(
job_id=job_id,
stage="filter_scenes",
stage_index=1,
frames_prefix=f"{CHECKPOINT_PREFIX}/{job_id}/frames/",
frames_manifest={str(k): v for k, v in manifest.items()},
frames_meta=frames_meta,
filtered_frame_sequences=filtered_sequences,
stage_output_key="",
stats={"frames_extracted": len(frames), "frames_after_scene_filter": len(frames)},
config_snapshot={},
config_overrides={},
video_path=video_path,
timeline_id, checkpoint_id = create_timeline(
source_video=video_path,
profile_name="soccer_broadcast",
is_scenario=True,
scenario_label=args.label,
frames=frames,
fps=args.fps,
)
# Mark the root checkpoint as a scenario
from core.db.detect import get_checkpoint
from core.db.connection import get_session
checkpoint = get_checkpoint(uuid.UUID(checkpoint_id))
if checkpoint:
checkpoint.is_scenario = True
checkpoint.scenario_label = args.label
with get_session() as session:
session.add(checkpoint)
session.commit()
logger.info("")
logger.info("Scenario created:")
logger.info(" ID: %s", checkpoint.id)
logger.info(" Job: %s", job_id)
logger.info(" Timeline: %s", timeline_id)
logger.info(" Branch: %s", branch_id)
logger.info(" Label: %s", args.label)
logger.info(" Frames: %d", len(frames))
logger.info("")
logger.info("Open in editor:")
logger.info(" http://mpr.local.ar/detection/?job=%s#/editor/detect_edges", job_id)
logger.info(" http://mpr.local.ar/detection/?job=%s#/editor/detect_edges", timeline_id)
if __name__ == "__main__":

View File

@@ -1,6 +1,7 @@
"""Tests for the stage registry."""
from detect.stages import list_stages, get_stage, get_palette
from detect.stages.base import get_stage_class
EXPECTED_STAGES = [
@@ -26,9 +27,17 @@ def test_stage_has_io():
def test_stage_has_serialization():
for name in EXPECTED_STAGES:
stage = get_stage(name)
assert stage.serialize_fn is not None, f"{name} has no serialize_fn"
assert stage.deserialize_fn is not None, f"{name} has no deserialize_fn"
defn = get_stage(name)
stage_cls = get_stage_class(name)
if stage_cls is not None:
# New-style: serialization on the class
instance = stage_cls()
assert hasattr(instance, 'serialize'), f"{name} has no serialize method"
assert hasattr(instance, 'deserialize'), f"{name} has no deserialize method"
else:
# Legacy: serialization on the definition
assert defn.serialize_fn is not None, f"{name} has no serialize_fn"
assert defn.deserialize_fn is not None, f"{name} has no deserialize_fn"
def test_palette_groups():