refactor stage 1
This commit is contained in:
@@ -9,7 +9,8 @@ from sqlmodel import select
|
|||||||
|
|
||||||
from .connection import get_session
|
from .connection import get_session
|
||||||
from .models import (
|
from .models import (
|
||||||
DetectJob, StageCheckpoint, KnownBrand, SourceBrandSighting,
|
DetectJob, Timeline, Checkpoint,
|
||||||
|
KnownBrand, SourceBrandSighting,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -55,72 +56,86 @@ def list_detect_jobs(
|
|||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# StageCheckpoint
|
# Timeline
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def save_stage_checkpoint(**fields) -> StageCheckpoint:
|
def create_timeline(**fields) -> Timeline:
|
||||||
|
timeline = Timeline(**fields)
|
||||||
with get_session() as session:
|
with get_session() as session:
|
||||||
# Upsert: replace if same job_id + stage
|
session.add(timeline)
|
||||||
job_id = fields.get("job_id")
|
session.commit()
|
||||||
stage = fields.get("stage")
|
session.refresh(timeline)
|
||||||
if job_id and stage:
|
return timeline
|
||||||
stmt = select(StageCheckpoint).where(
|
|
||||||
StageCheckpoint.job_id == job_id,
|
|
||||||
StageCheckpoint.stage == stage,
|
|
||||||
)
|
|
||||||
existing = session.exec(stmt).first()
|
|
||||||
if existing:
|
|
||||||
for k, v in fields.items():
|
|
||||||
setattr(existing, k, v)
|
|
||||||
session.commit()
|
|
||||||
session.refresh(existing)
|
|
||||||
return existing
|
|
||||||
|
|
||||||
checkpoint = StageCheckpoint(**fields)
|
|
||||||
|
def get_timeline(timeline_id: UUID) -> Timeline | None:
|
||||||
|
with get_session() as session:
|
||||||
|
return session.get(Timeline, timeline_id)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Checkpoint
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def save_checkpoint(**fields) -> Checkpoint:
|
||||||
|
checkpoint = Checkpoint(**fields)
|
||||||
|
with get_session() as session:
|
||||||
session.add(checkpoint)
|
session.add(checkpoint)
|
||||||
session.commit()
|
session.commit()
|
||||||
session.refresh(checkpoint)
|
session.refresh(checkpoint)
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
|
|
||||||
def get_stage_checkpoint(job_id: UUID, stage: str) -> StageCheckpoint | None:
|
def get_checkpoint(checkpoint_id: UUID) -> Checkpoint | None:
|
||||||
with get_session() as session:
|
with get_session() as session:
|
||||||
stmt = select(StageCheckpoint).where(
|
return session.get(Checkpoint, checkpoint_id)
|
||||||
StageCheckpoint.job_id == job_id,
|
|
||||||
StageCheckpoint.stage == stage,
|
|
||||||
|
def get_latest_checkpoint(timeline_id: UUID, parent_id: UUID | None = None) -> Checkpoint | None:
|
||||||
|
"""Get the most recent checkpoint for a timeline, optionally from a specific parent."""
|
||||||
|
with get_session() as session:
|
||||||
|
stmt = (
|
||||||
|
select(Checkpoint)
|
||||||
|
.where(Checkpoint.timeline_id == timeline_id)
|
||||||
|
)
|
||||||
|
if parent_id is not None:
|
||||||
|
stmt = stmt.where(Checkpoint.parent_id == parent_id)
|
||||||
|
stmt = stmt.order_by(Checkpoint.created_at.desc())
|
||||||
|
return session.exec(stmt).first()
|
||||||
|
|
||||||
|
|
||||||
|
def list_checkpoints(timeline_id: UUID) -> list[Checkpoint]:
|
||||||
|
"""List all checkpoints for a timeline."""
|
||||||
|
with get_session() as session:
|
||||||
|
stmt = (
|
||||||
|
select(Checkpoint)
|
||||||
|
.where(Checkpoint.timeline_id == timeline_id)
|
||||||
|
.order_by(Checkpoint.created_at)
|
||||||
|
)
|
||||||
|
return list(session.exec(stmt).all())
|
||||||
|
|
||||||
|
|
||||||
|
def get_root_checkpoint(timeline_id: UUID) -> Checkpoint | None:
|
||||||
|
"""Get the root checkpoint (no parent) for a timeline."""
|
||||||
|
with get_session() as session:
|
||||||
|
stmt = select(Checkpoint).where(
|
||||||
|
Checkpoint.timeline_id == timeline_id,
|
||||||
|
Checkpoint.parent_id == None,
|
||||||
)
|
)
|
||||||
return session.exec(stmt).first()
|
return session.exec(stmt).first()
|
||||||
|
|
||||||
|
|
||||||
def list_stage_checkpoints(job_id: UUID) -> list[str]:
|
def list_scenarios() -> list[Checkpoint]:
|
||||||
with get_session() as session:
|
|
||||||
stmt = (
|
|
||||||
select(StageCheckpoint.stage)
|
|
||||||
.where(StageCheckpoint.job_id == job_id)
|
|
||||||
.order_by(StageCheckpoint.stage_index)
|
|
||||||
)
|
|
||||||
return list(session.exec(stmt).all())
|
|
||||||
|
|
||||||
|
|
||||||
def list_scenarios() -> list[StageCheckpoint]:
|
|
||||||
"""List all checkpoints marked as scenarios."""
|
"""List all checkpoints marked as scenarios."""
|
||||||
with get_session() as session:
|
with get_session() as session:
|
||||||
stmt = (
|
stmt = (
|
||||||
select(StageCheckpoint)
|
select(Checkpoint)
|
||||||
.where(StageCheckpoint.is_scenario == True)
|
.where(Checkpoint.is_scenario == True)
|
||||||
.order_by(StageCheckpoint.created_at.desc())
|
.order_by(Checkpoint.created_at.desc())
|
||||||
)
|
)
|
||||||
return list(session.exec(stmt).all())
|
return list(session.exec(stmt).all())
|
||||||
|
|
||||||
|
|
||||||
def delete_stage_checkpoints(job_id: UUID) -> None:
|
|
||||||
with get_session() as session:
|
|
||||||
stmt = select(StageCheckpoint).where(StageCheckpoint.job_id == job_id)
|
|
||||||
for cp in session.exec(stmt).all():
|
|
||||||
session.delete(cp)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# KnownBrand
|
# KnownBrand
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -181,24 +181,30 @@ class DetectJob(SQLModel, table=True):
|
|||||||
started_at: Optional[datetime] = None
|
started_at: Optional[datetime] = None
|
||||||
completed_at: Optional[datetime] = None
|
completed_at: Optional[datetime] = None
|
||||||
|
|
||||||
class StageCheckpoint(SQLModel, table=True):
|
class Timeline(SQLModel, table=True):
|
||||||
"""A checkpoint saved after a pipeline stage completes."""
|
"""Frame sequence from a source video. Independent of stages."""
|
||||||
__tablename__ = "stage_checkpoints"
|
__tablename__ = "timelines"
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||||
job_id: UUID = Field(index=True)
|
source_asset_id: Optional[UUID] = Field(default=None, index=True)
|
||||||
stage: str
|
source_video: str = ""
|
||||||
stage_index: int
|
profile_name: str = ""
|
||||||
|
fps: float = 2.0
|
||||||
frames_prefix: str = ""
|
frames_prefix: str = ""
|
||||||
frames_manifest: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
frames_manifest: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||||
frames_meta: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
|
frames_meta: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
|
||||||
filtered_frame_sequences: List[int] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
|
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
||||||
stage_output_key: str = "" # s3 key: checkpoints/{job_id}/stages/{stage}.bson
|
|
||||||
stats: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
class Checkpoint(SQLModel, table=True):
|
||||||
config_snapshot: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
"""Snapshot of pipeline state. parent_id forms a tree."""
|
||||||
|
__tablename__ = "checkpoints"
|
||||||
|
|
||||||
|
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||||
|
timeline_id: UUID = Field(index=True)
|
||||||
|
parent_id: Optional[UUID] = Field(default=None, index=True)
|
||||||
|
stage_outputs: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||||
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||||
video_path: str = ""
|
stats: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||||
profile_name: str = ""
|
|
||||||
is_scenario: bool = False
|
is_scenario: bool = False
|
||||||
scenario_label: str = ""
|
scenario_label: str = ""
|
||||||
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
||||||
|
|||||||
@@ -27,9 +27,11 @@ from .grpc import (
|
|||||||
)
|
)
|
||||||
from .jobs import ChunkJob, ChunkJobStatus, JobStatus, TranscodeJob
|
from .jobs import ChunkJob, ChunkJobStatus, JobStatus, TranscodeJob
|
||||||
from .detect_jobs import (
|
from .detect_jobs import (
|
||||||
DetectJob, DetectJobStatus, RunType, StageCheckpoint,
|
DetectJob, DetectJobStatus, RunType,
|
||||||
|
Timeline, Checkpoint,
|
||||||
BrandSource, KnownBrand, SourceBrandSighting,
|
BrandSource, KnownBrand, SourceBrandSighting,
|
||||||
)
|
)
|
||||||
|
from .stages import StageConfigField, StageIO, StageDefinition, STAGE_VIEWS
|
||||||
from .media import AssetStatus, MediaAsset
|
from .media import AssetStatus, MediaAsset
|
||||||
from .presets import BUILTIN_PRESETS, TranscodePreset
|
from .presets import BUILTIN_PRESETS, TranscodePreset
|
||||||
from .detect import DETECT_VIEWS # noqa: F401 — discovered by modelgen generic loader
|
from .detect import DETECT_VIEWS # noqa: F401 — discovered by modelgen generic loader
|
||||||
@@ -40,7 +42,8 @@ from .sources import ChunkInfo, SourceJob, SourceType
|
|||||||
|
|
||||||
# Core domain models - generates Django, SQLModel, TypeScript
|
# Core domain models - generates Django, SQLModel, TypeScript
|
||||||
DATACLASSES = [MediaAsset, TranscodePreset, TranscodeJob, ChunkJob,
|
DATACLASSES = [MediaAsset, TranscodePreset, TranscodeJob, ChunkJob,
|
||||||
DetectJob, StageCheckpoint, KnownBrand, SourceBrandSighting]
|
DetectJob, Timeline, Checkpoint,
|
||||||
|
KnownBrand, SourceBrandSighting]
|
||||||
|
|
||||||
# API request/response models - generates TypeScript only (no Django)
|
# API request/response models - generates TypeScript only (no Django)
|
||||||
# WorkerStatus from grpc.py is reused here
|
# WorkerStatus from grpc.py is reused here
|
||||||
|
|||||||
@@ -72,49 +72,58 @@ class DetectJob:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class StageCheckpoint:
|
class Timeline:
|
||||||
"""
|
"""
|
||||||
A checkpoint saved after a pipeline stage completes.
|
The frame sequence from a source video.
|
||||||
|
|
||||||
Binary data (frame images, crops) goes to S3/MinIO.
|
Independent of stages — exists before any stage runs.
|
||||||
Everything else (structured state) lives here in Postgres.
|
Stages annotate the timeline, they don't own it.
|
||||||
|
Frames are stored in MinIO as JPEGs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: UUID
|
id: UUID
|
||||||
job_id: UUID
|
source_asset_id: Optional[UUID] = None
|
||||||
stage: str
|
source_video: str = ""
|
||||||
stage_index: int # position in NODES list (0-7)
|
profile_name: str = ""
|
||||||
|
fps: float = 2.0
|
||||||
|
|
||||||
# S3 reference for binary data only
|
# Frame metadata (images in MinIO, metadata here)
|
||||||
frames_prefix: str = "" # s3 prefix: checkpoints/{job_id}/frames/
|
frames_prefix: str = "" # s3: timelines/{id}/frames/
|
||||||
|
|
||||||
# Frame metadata (non-image fields)
|
|
||||||
frames_manifest: Dict[int, str] = field(default_factory=dict) # seq → s3 key
|
frames_manifest: Dict[int, str] = field(default_factory=dict) # seq → s3 key
|
||||||
frames_meta: List[Dict[str, Any]] = field(default_factory=list) # sequence, chunk_id, timestamp, hash
|
frames_meta: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
filtered_frame_sequences: List[int] = field(default_factory=list)
|
|
||||||
|
|
||||||
# Stage output — stored as blob in MinIO: checkpoints/{job_id}/stages/{stage}.bson
|
created_at: Optional[datetime] = None
|
||||||
# Each stage's serialize_fn/deserialize_fn owns the format.
|
|
||||||
# Postgres only stores the S3 key, not the data itself.
|
|
||||||
stage_output_key: str = "" # s3 key to the serialized stage output
|
|
||||||
|
|
||||||
# Pipeline state (small, stays in Postgres)
|
|
||||||
stats: Dict[str, Any] = field(default_factory=dict)
|
@dataclass
|
||||||
config_snapshot: Dict[str, Any] = field(default_factory=dict)
|
class Checkpoint:
|
||||||
|
"""
|
||||||
|
A snapshot of pipeline state on a timeline.
|
||||||
|
|
||||||
|
Stage outputs stored as JSONB — each stage serializes to JSON,
|
||||||
|
the checkpoint stores it without knowing the shape.
|
||||||
|
|
||||||
|
parent_id forms a tree: multiple children from the same parent
|
||||||
|
= different config tries from the same starting point.
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
timeline_id: UUID
|
||||||
|
parent_id: Optional[UUID] = None # null = root checkpoint
|
||||||
|
|
||||||
|
# Stage outputs — JSONB per stage, opaque to the checkpoint layer
|
||||||
|
stage_outputs: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# Config that produced this checkpoint
|
||||||
config_overrides: Dict[str, Any] = field(default_factory=dict)
|
config_overrides: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
# Input refs (for replay)
|
# Pipeline state
|
||||||
video_path: str = ""
|
stats: Dict[str, Any] = field(default_factory=dict)
|
||||||
profile_name: str = ""
|
|
||||||
|
|
||||||
# Scenario — a checkpoint bookmarked for the editor workflow.
|
# Scenario bookmark
|
||||||
# Created by seeders (manual scripts that populate state from real footage)
|
|
||||||
# or captured from a running pipeline. Loaded via URL:
|
|
||||||
# /detection/?job=<job_id>#/editor/<stage>
|
|
||||||
is_scenario: bool = False
|
is_scenario: bool = False
|
||||||
scenario_label: str = "" # human-readable name, e.g. "chelsea_edges_lowcanny"
|
scenario_label: str = ""
|
||||||
|
|
||||||
# Timestamps
|
|
||||||
created_at: Optional[datetime] = None
|
created_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
64
core/schema/models/stages.py
Normal file
64
core/schema/models/stages.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
"""
|
||||||
|
Stage Schema Definitions
|
||||||
|
|
||||||
|
Source of truth for pipeline stage metadata.
|
||||||
|
Generates: Pydantic, TypeScript via modelgen.
|
||||||
|
|
||||||
|
Each stage is defined by its config fields. The implementation
|
||||||
|
lives in detect/stages/<name>.py as a Stage subclass.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StageConfigField:
|
||||||
|
"""A single tunable config parameter for the editor UI."""
|
||||||
|
name: str
|
||||||
|
type: str # "float", "int", "str", "bool"
|
||||||
|
default: Any
|
||||||
|
description: str = ""
|
||||||
|
min: Optional[float] = None
|
||||||
|
max: Optional[float] = None
|
||||||
|
options: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StageIO:
|
||||||
|
"""Declares what a stage reads and writes."""
|
||||||
|
reads: List[str] = field(default_factory=list)
|
||||||
|
writes: List[str] = field(default_factory=list)
|
||||||
|
optional_reads: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StageDefinition:
|
||||||
|
"""
|
||||||
|
Complete metadata for a pipeline stage.
|
||||||
|
|
||||||
|
Lives in schema as the source of truth. Each stage implementation
|
||||||
|
references a StageDefinition. The editor, graph, and checkpoint
|
||||||
|
system all consume this.
|
||||||
|
"""
|
||||||
|
name: str
|
||||||
|
label: str
|
||||||
|
description: str
|
||||||
|
category: str = "detection"
|
||||||
|
io: StageIO = field(default_factory=StageIO)
|
||||||
|
config_fields: List[StageConfigField] = field(default_factory=list)
|
||||||
|
|
||||||
|
# Legacy fields — used by old registry pattern during migration.
|
||||||
|
# New stages use Stage subclass instead.
|
||||||
|
fn: Any = None
|
||||||
|
serialize_fn: Any = None
|
||||||
|
deserialize_fn: Any = None
|
||||||
|
|
||||||
|
|
||||||
|
# --- Export for modelgen ---
|
||||||
|
|
||||||
|
STAGE_VIEWS = [
|
||||||
|
StageConfigField,
|
||||||
|
StageIO,
|
||||||
|
StageDefinition,
|
||||||
|
]
|
||||||
@@ -1,14 +1,18 @@
|
|||||||
"""
|
"""
|
||||||
Stage checkpoint, replay, and retry.
|
Checkpoint system — Timeline + Checkpoint tree.
|
||||||
|
|
||||||
detect/checkpoint/
|
detect/checkpoint/
|
||||||
frames.py — frame image S3 upload/download
|
frames.py — frame image S3 upload/download
|
||||||
serializer.py — state ↔ JSON conversion
|
storage.py — Timeline + Checkpoint (Postgres + MinIO)
|
||||||
storage.py — checkpoint save/load/list (Postgres + S3)
|
replay.py — replay (TODO: migrate to new model)
|
||||||
replay.py — replay_from, OverrideProfile
|
|
||||||
tasks.py — retry_candidates Celery task
|
tasks.py — retry_candidates Celery task
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .storage import save_checkpoint, load_checkpoint, list_checkpoints
|
from .storage import (
|
||||||
|
create_timeline,
|
||||||
|
get_timeline_frames,
|
||||||
|
get_timeline_frames_b64,
|
||||||
|
save_stage_output,
|
||||||
|
load_stage_output,
|
||||||
|
)
|
||||||
from .frames import save_frames, load_frames
|
from .frames import save_frames, load_frames
|
||||||
from .replay import replay_from, OverrideProfile
|
|
||||||
|
|||||||
@@ -12,7 +12,13 @@ import logging
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from detect import emit
|
from detect import emit
|
||||||
from detect.checkpoint import load_checkpoint, list_checkpoints
|
# TODO: migrate to Timeline/Branch/Checkpoint model
|
||||||
|
# These old functions no longer exist — replay needs rework
|
||||||
|
def _not_migrated(*args, **kwargs):
|
||||||
|
raise NotImplementedError("Replay not yet migrated to Timeline/Branch/Checkpoint model")
|
||||||
|
|
||||||
|
load_checkpoint = _not_migrated
|
||||||
|
list_checkpoints = _not_migrated
|
||||||
from detect.graph import NODES, build_graph
|
from detect.graph import NODES, build_graph
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -1,116 +1,178 @@
|
|||||||
"""
|
"""
|
||||||
Checkpoint storage — save/load stage state.
|
Checkpoint storage — Timeline + Checkpoint (tree of snapshots).
|
||||||
|
|
||||||
Binary data (frame images) → S3/MinIO via frames.py
|
Timeline: frame sequence from source video (frames in MinIO)
|
||||||
Structured data (stage output, stats, config) → Postgres
|
Checkpoint: snapshot of pipeline state (stage outputs as JSONB in Postgres)
|
||||||
|
parent_id forms a tree — multiple children = different config tries
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from .frames import save_frames, load_frames, CHECKPOINT_PREFIX
|
from .frames import save_frames, load_frames, CHECKPOINT_PREFIX
|
||||||
from .serializer import serialize_state, deserialize_state
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Save
|
# Timeline
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def save_checkpoint(
|
def create_timeline(
|
||||||
job_id: str,
|
source_video: str,
|
||||||
stage: str,
|
profile_name: str,
|
||||||
stage_index: int,
|
frames: list,
|
||||||
state: dict,
|
fps: float = 2.0,
|
||||||
frames_manifest: dict[int, str] | None = None,
|
source_asset_id: UUID | None = None,
|
||||||
|
) -> tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Create a timeline from frames. Uploads frame images to MinIO,
|
||||||
|
creates Timeline + root Checkpoint in Postgres.
|
||||||
|
|
||||||
|
Returns (timeline_id, checkpoint_id).
|
||||||
|
"""
|
||||||
|
from core.db.detect import create_timeline as db_create_timeline
|
||||||
|
from core.db.detect import save_checkpoint
|
||||||
|
|
||||||
|
# Create timeline
|
||||||
|
timeline = db_create_timeline(
|
||||||
|
source_video=source_video,
|
||||||
|
profile_name=profile_name,
|
||||||
|
source_asset_id=source_asset_id,
|
||||||
|
fps=fps,
|
||||||
|
)
|
||||||
|
tid = str(timeline.id)
|
||||||
|
|
||||||
|
# Upload frames to MinIO
|
||||||
|
manifest = save_frames(tid, frames)
|
||||||
|
|
||||||
|
# Store frame metadata on the timeline
|
||||||
|
frames_meta = [
|
||||||
|
{
|
||||||
|
"sequence": f.sequence,
|
||||||
|
"chunk_id": getattr(f, "chunk_id", 0),
|
||||||
|
"timestamp": f.timestamp,
|
||||||
|
"perceptual_hash": getattr(f, "perceptual_hash", ""),
|
||||||
|
}
|
||||||
|
for f in frames
|
||||||
|
]
|
||||||
|
|
||||||
|
timeline.frames_prefix = f"{CHECKPOINT_PREFIX}/{tid}/frames/"
|
||||||
|
timeline.frames_manifest = {str(k): v for k, v in manifest.items()}
|
||||||
|
timeline.frames_meta = frames_meta
|
||||||
|
|
||||||
|
from core.db.connection import get_session
|
||||||
|
with get_session() as session:
|
||||||
|
session.add(timeline)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# Create root checkpoint (no parent, no stage outputs yet)
|
||||||
|
checkpoint = save_checkpoint(
|
||||||
|
timeline_id=timeline.id,
|
||||||
|
parent_id=None,
|
||||||
|
stage_outputs={},
|
||||||
|
stats={"frames_extracted": len(frames)},
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Timeline created: %s (%d frames, root checkpoint %s)",
|
||||||
|
tid, len(frames), checkpoint.id)
|
||||||
|
return tid, str(checkpoint.id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_timeline_frames(timeline_id: str) -> list:
|
||||||
|
"""Load frames from a timeline (from MinIO) as Frame objects."""
|
||||||
|
from core.db.detect import get_timeline
|
||||||
|
|
||||||
|
timeline = get_timeline(timeline_id)
|
||||||
|
if not timeline:
|
||||||
|
raise ValueError(f"Timeline not found: {timeline_id}")
|
||||||
|
|
||||||
|
raw_manifest = timeline.frames_manifest or {}
|
||||||
|
manifest = {int(k): v for k, v in raw_manifest.items()}
|
||||||
|
frame_metadata = timeline.frames_meta or []
|
||||||
|
|
||||||
|
return load_frames(manifest, frame_metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def get_timeline_frames_b64(timeline_id: str) -> list[dict]:
|
||||||
|
"""Load frames as base64 JPEG (lightweight, no numpy)."""
|
||||||
|
from core.db.detect import get_timeline
|
||||||
|
from .frames import load_frames_b64
|
||||||
|
|
||||||
|
timeline = get_timeline(timeline_id)
|
||||||
|
if not timeline:
|
||||||
|
raise ValueError(f"Timeline not found: {timeline_id}")
|
||||||
|
|
||||||
|
raw_manifest = timeline.frames_manifest or {}
|
||||||
|
manifest = {int(k): v for k, v in raw_manifest.items()}
|
||||||
|
frame_metadata = timeline.frames_meta or []
|
||||||
|
|
||||||
|
return load_frames_b64(manifest, frame_metadata)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Checkpoint
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def save_stage_output(
|
||||||
|
timeline_id: str,
|
||||||
|
parent_checkpoint_id: str | None,
|
||||||
|
stage_name: str,
|
||||||
|
output_json: dict,
|
||||||
|
config_overrides: dict | None = None,
|
||||||
|
stats: dict | None = None,
|
||||||
is_scenario: bool = False,
|
is_scenario: bool = False,
|
||||||
scenario_label: str = "",
|
scenario_label: str = "",
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Save a stage checkpoint.
|
Save a stage's output as a new checkpoint (child of parent).
|
||||||
|
|
||||||
Saves frame images to S3 (if not already saved), then persists
|
Carries forward stage outputs from parent + adds the new one.
|
||||||
structured state to Postgres.
|
Returns the new checkpoint ID.
|
||||||
|
|
||||||
Returns the checkpoint DB id.
|
|
||||||
"""
|
"""
|
||||||
from core.db.detect import save_stage_checkpoint
|
from core.db.detect import get_checkpoint, save_checkpoint
|
||||||
|
|
||||||
if frames_manifest is None:
|
# Carry forward from parent
|
||||||
all_frames = state.get("frames", [])
|
parent_outputs = {}
|
||||||
frames_manifest = save_frames(job_id, all_frames)
|
parent_stats = {}
|
||||||
|
parent_config = {}
|
||||||
|
if parent_checkpoint_id:
|
||||||
|
parent = get_checkpoint(parent_checkpoint_id)
|
||||||
|
if parent:
|
||||||
|
parent_outputs = dict(parent.stage_outputs or {})
|
||||||
|
parent_stats = dict(parent.stats or {})
|
||||||
|
parent_config = dict(parent.config_overrides or {})
|
||||||
|
|
||||||
checkpoint_data = serialize_state(state, frames_manifest)
|
# Add new stage output
|
||||||
frames_prefix = f"{CHECKPOINT_PREFIX}/{job_id}/frames/"
|
stage_outputs = {**parent_outputs, stage_name: output_json}
|
||||||
|
|
||||||
checkpoint = save_stage_checkpoint(
|
# Merge stats and config
|
||||||
job_id=job_id,
|
merged_stats = {**parent_stats, **(stats or {})}
|
||||||
stage=stage,
|
merged_config = {**parent_config, **(config_overrides or {})}
|
||||||
stage_index=stage_index,
|
|
||||||
frames_prefix=frames_prefix,
|
checkpoint = save_checkpoint(
|
||||||
frames_manifest=checkpoint_data.get("frames_manifest", {}),
|
timeline_id=timeline_id,
|
||||||
frames_meta=checkpoint_data.get("frames_meta", []),
|
parent_id=parent_checkpoint_id,
|
||||||
filtered_frame_sequences=checkpoint_data.get("filtered_frame_sequences", []),
|
stage_outputs=stage_outputs,
|
||||||
stage_output_key=checkpoint_data.get("stage_output_key", ""),
|
config_overrides=merged_config,
|
||||||
stats=checkpoint_data.get("stats", {}),
|
stats=merged_stats,
|
||||||
config_snapshot=checkpoint_data.get("config_overrides", {}),
|
|
||||||
config_overrides=checkpoint_data.get("config_overrides", {}),
|
|
||||||
video_path=checkpoint_data.get("video_path", ""),
|
|
||||||
profile_name=checkpoint_data.get("profile_name", ""),
|
|
||||||
is_scenario=is_scenario,
|
is_scenario=is_scenario,
|
||||||
scenario_label=scenario_label,
|
scenario_label=scenario_label,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Checkpoint saved: %s/%s (id=%s, scenario=%s)",
|
logger.info("Checkpoint saved: %s (timeline %s, stage %s, parent %s)",
|
||||||
job_id, stage, checkpoint.id, is_scenario)
|
checkpoint.id, timeline_id, stage_name, parent_checkpoint_id)
|
||||||
return str(checkpoint.id)
|
return str(checkpoint.id)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
def load_stage_output(checkpoint_id: str, stage_name: str) -> dict | None:
|
||||||
# Load
|
"""Load a stage's output from a checkpoint."""
|
||||||
# ---------------------------------------------------------------------------
|
from core.db.detect import get_checkpoint
|
||||||
|
|
||||||
def load_checkpoint(job_id: str, stage: str) -> dict:
|
checkpoint = get_checkpoint(checkpoint_id)
|
||||||
"""
|
|
||||||
Load a stage checkpoint and reconstitute full DetectState.
|
|
||||||
"""
|
|
||||||
from core.db.detect import get_stage_checkpoint
|
|
||||||
|
|
||||||
checkpoint = get_stage_checkpoint(job_id, stage)
|
|
||||||
if not checkpoint:
|
if not checkpoint:
|
||||||
raise ValueError(f"No checkpoint for {job_id}/{stage}")
|
return None
|
||||||
|
|
||||||
data = {
|
return (checkpoint.stage_outputs or {}).get(stage_name)
|
||||||
"job_id": str(checkpoint.job_id),
|
|
||||||
"video_path": checkpoint.video_path,
|
|
||||||
"profile_name": checkpoint.profile_name,
|
|
||||||
"config_overrides": checkpoint.config_overrides,
|
|
||||||
"frames_manifest": checkpoint.frames_manifest,
|
|
||||||
"frames_meta": checkpoint.frames_meta,
|
|
||||||
"filtered_frame_sequences": checkpoint.filtered_frame_sequences,
|
|
||||||
"stage_output_key": checkpoint.stage_output_key,
|
|
||||||
"stats": checkpoint.stats,
|
|
||||||
}
|
|
||||||
|
|
||||||
raw_manifest = data.get("frames_manifest", {})
|
|
||||||
manifest = {int(k): v for k, v in raw_manifest.items()}
|
|
||||||
frame_metadata = data.get("frames_meta", [])
|
|
||||||
frames = load_frames(manifest, frame_metadata)
|
|
||||||
|
|
||||||
state = deserialize_state(data, frames)
|
|
||||||
|
|
||||||
logger.info("Checkpoint loaded: %s/%s (%d frames, scenario=%s)",
|
|
||||||
job_id, stage, len(frames), checkpoint.is_scenario)
|
|
||||||
return state
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# List
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def list_checkpoints(job_id: str) -> list[str]:
|
|
||||||
"""List available checkpoint stages for a job."""
|
|
||||||
from core.db.detect import list_stage_checkpoints
|
|
||||||
return list_stage_checkpoints(job_id)
|
|
||||||
|
|||||||
@@ -1,21 +1,21 @@
|
|||||||
"""
|
"""
|
||||||
Pipeline stages.
|
Pipeline stages.
|
||||||
|
|
||||||
Each stage registers its StageDefinition on import,
|
Each stage is a file with a Stage subclass. Auto-discovered via
|
||||||
declaring IO (what it reads/writes from state),
|
__init_subclass__ — importing the file registers the stage.
|
||||||
config fields (what's tunable from the editor),
|
|
||||||
and serialization (how to checkpoint its outputs).
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .base import (
|
from .base import (
|
||||||
StageDefinition,
|
Stage,
|
||||||
StageIO,
|
|
||||||
StageConfigField,
|
|
||||||
register_stage,
|
|
||||||
get_stage,
|
get_stage,
|
||||||
|
get_stage_instance,
|
||||||
list_stages,
|
list_stages,
|
||||||
|
list_stage_classes,
|
||||||
get_palette,
|
get_palette,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Populate registry with built-in stages
|
# Import all stage files to trigger auto-registration
|
||||||
|
from . import edge_detector # noqa: F401
|
||||||
|
|
||||||
|
# Import registry for backward compat (other stages still use old pattern)
|
||||||
from . import registry # noqa: F401
|
from . import registry # noqa: F401
|
||||||
|
|||||||
@@ -1,101 +1,131 @@
|
|||||||
"""
|
"""
|
||||||
Stage protocol — common interface for all pipeline stages.
|
Stage base class — common interface for all pipeline stages.
|
||||||
|
|
||||||
Every stage declares:
|
Each stage is a file that subclasses Stage. Auto-discovered via
|
||||||
- IO: what it reads/writes from DetectState
|
__init_subclass__. No manual registration needed.
|
||||||
- Config: tunable parameters for the editor
|
|
||||||
- Serialization: how to persist/restore its own outputs
|
|
||||||
|
|
||||||
The checkpoint layer is a black box — it asks each stage to serialize its
|
A stage:
|
||||||
outputs and stores the result. Stages own their data format. Binary data
|
- Has a StageDefinition (from schema) with name, config, IO
|
||||||
(frames, crops) goes to S3 via the stage itself. The checkpoint just
|
- Implements run(frames, config) → output
|
||||||
stores the JSON envelope.
|
- Owns its output serialization (opaque blob)
|
||||||
|
- Optionally has a TypeScript port for browser-side execution
|
||||||
|
|
||||||
The graph builder uses StageIO to validate that a stage's inputs are
|
The checkpoint layer stores stage output as blobs without knowing
|
||||||
satisfied by previous stages' outputs.
|
the format. The stage that wrote it is the only one that can read it.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Callable
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
@dataclass
|
from core.schema.models.stages import StageConfigField, StageIO, StageDefinition
|
||||||
class StageIO:
|
|
||||||
"""Declares what a stage reads and writes from/to DetectState."""
|
|
||||||
reads: list[str]
|
|
||||||
writes: list[str]
|
|
||||||
optional_reads: list[str] = field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class StageConfigField:
|
|
||||||
"""A single tunable config parameter for the editor UI."""
|
|
||||||
name: str
|
|
||||||
type: str # "float", "int", "str", "bool", "list[str]"
|
|
||||||
default: Any
|
|
||||||
description: str = ""
|
|
||||||
min: float | None = None
|
|
||||||
max: float | None = None
|
|
||||||
options: list[str] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class StageDefinition:
|
|
||||||
"""
|
|
||||||
Complete metadata for a pipeline stage.
|
|
||||||
|
|
||||||
The profile editor uses this to build the palette, generate config
|
|
||||||
forms, and validate graph connections. The checkpoint uses serialize_fn
|
|
||||||
and deserialize_fn to persist stage outputs without knowing the internals.
|
|
||||||
"""
|
|
||||||
name: str
|
|
||||||
label: str
|
|
||||||
description: str
|
|
||||||
io: StageIO
|
|
||||||
config_fields: list[StageConfigField] = field(default_factory=list)
|
|
||||||
category: str = "detection"
|
|
||||||
|
|
||||||
# The actual graph node function: (DetectState) → dict
|
|
||||||
fn: Callable | None = None
|
|
||||||
|
|
||||||
# Stage-owned serialization for checkpointing.
|
|
||||||
# serialize_fn: (state: dict, job_id: str) → json-compatible dict
|
|
||||||
# Stage picks its writes from state, serializes them.
|
|
||||||
# Binary data (frames) → S3 via stage, returns refs.
|
|
||||||
# deserialize_fn: (data: dict, job_id: str) → state update dict
|
|
||||||
# Stage restores its writes from the persisted data.
|
|
||||||
serialize_fn: Callable | None = None
|
|
||||||
deserialize_fn: Callable | None = None
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Registry
|
# Registry — auto-populated by __init_subclass__ (new stages)
|
||||||
|
# + register_stage() (legacy stages during migration)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
_REGISTRY: dict[str, StageDefinition] = {}
|
_REGISTRY: dict[str, type['Stage']] = {}
|
||||||
|
_LEGACY_REGISTRY: dict[str, StageDefinition] = {}
|
||||||
|
|
||||||
|
|
||||||
def register_stage(definition: StageDefinition):
|
def register_stage(definition: StageDefinition):
|
||||||
_REGISTRY[definition.name] = definition
|
"""Legacy registration for stages not yet converted to Stage subclass."""
|
||||||
|
_LEGACY_REGISTRY[definition.name] = definition
|
||||||
|
|
||||||
|
|
||||||
|
class Stage:
|
||||||
|
"""
|
||||||
|
Base class for all pipeline stages.
|
||||||
|
|
||||||
|
Subclass this in detect/stages/<name>.py. Define `definition` as a
|
||||||
|
class attribute. Implement `run()`. Optionally override `serialize()`
|
||||||
|
and `deserialize()` for custom blob formats (default is JSON).
|
||||||
|
"""
|
||||||
|
|
||||||
|
definition: StageDefinition # set by each subclass
|
||||||
|
|
||||||
|
def __init_subclass__(cls, **kwargs):
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
|
if hasattr(cls, 'definition') and cls.definition is not None:
|
||||||
|
_REGISTRY[cls.definition.name] = cls
|
||||||
|
|
||||||
|
def run(self, frames: list, config: dict) -> Any:
|
||||||
|
"""
|
||||||
|
Run the stage on a list of frames with the given config.
|
||||||
|
|
||||||
|
Config is a dict of parameter values (from slider UI or profile).
|
||||||
|
Returns the stage output — whatever shape this stage produces.
|
||||||
|
Debug overlays are included when config has debug=True.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def serialize(self, output: Any) -> bytes:
|
||||||
|
"""Serialize stage output to bytes for checkpoint storage."""
|
||||||
|
import json
|
||||||
|
return json.dumps(output, default=str).encode()
|
||||||
|
|
||||||
|
def deserialize(self, data: bytes) -> Any:
|
||||||
|
"""Deserialize stage output from checkpoint blob."""
|
||||||
|
import json
|
||||||
|
return json.loads(data)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Discovery API
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _all_definitions() -> dict[str, StageDefinition]:
|
||||||
|
"""Merge new Stage subclass registry + legacy registry."""
|
||||||
|
merged = {}
|
||||||
|
# Legacy first, new overwrites (new takes precedence)
|
||||||
|
for name, defn in _LEGACY_REGISTRY.items():
|
||||||
|
merged[name] = defn
|
||||||
|
for name, cls in _REGISTRY.items():
|
||||||
|
merged[name] = cls.definition
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
def get_stage(name: str) -> StageDefinition:
|
def get_stage(name: str) -> StageDefinition:
|
||||||
if name not in _REGISTRY:
|
"""Get a stage definition by name (works for both new and legacy)."""
|
||||||
raise KeyError(f"Unknown stage: {name!r}. Registered: {list(_REGISTRY)}")
|
all_defs = _all_definitions()
|
||||||
return _REGISTRY[name]
|
if name not in all_defs:
|
||||||
|
raise KeyError(f"Unknown stage: {name!r}. Registered: {list(all_defs)}")
|
||||||
|
return all_defs[name]
|
||||||
|
|
||||||
|
|
||||||
|
def get_stage_class(name: str) -> type[Stage] | None:
|
||||||
|
"""Get a Stage subclass by name. Returns None for legacy stages."""
|
||||||
|
return _REGISTRY.get(name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_stage_instance(name: str) -> Stage:
|
||||||
|
"""Get an instantiated Stage by name. Only works for new-style stages."""
|
||||||
|
cls = _REGISTRY.get(name)
|
||||||
|
if cls is None:
|
||||||
|
raise KeyError(f"No Stage subclass for {name!r}. Legacy stages don't have instances.")
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
|
||||||
def list_stages() -> list[StageDefinition]:
|
def list_stages() -> list[StageDefinition]:
|
||||||
|
"""List all registered stage definitions (new + legacy)."""
|
||||||
|
return list(_all_definitions().values())
|
||||||
|
|
||||||
|
|
||||||
|
def list_stage_classes() -> list[type[Stage]]:
|
||||||
|
"""List all registered Stage subclasses (new-style only)."""
|
||||||
return list(_REGISTRY.values())
|
return list(_REGISTRY.values())
|
||||||
|
|
||||||
|
|
||||||
def get_palette() -> dict[str, list[StageDefinition]]:
|
def get_palette() -> dict[str, list[StageDefinition]]:
|
||||||
"""Group stages by category for the editor palette."""
|
"""Group stages by category for the editor palette."""
|
||||||
palette: dict[str, list[StageDefinition]] = {}
|
palette: dict[str, list[StageDefinition]] = {}
|
||||||
for stage in _REGISTRY.values():
|
for defn in _all_definitions().values():
|
||||||
if stage.category not in palette:
|
if defn.category not in palette:
|
||||||
palette[stage.category] = []
|
palette[defn.category] = []
|
||||||
palette[stage.category].append(stage)
|
palette[defn.category].append(defn)
|
||||||
return palette
|
return palette
|
||||||
|
|||||||
@@ -7,168 +7,227 @@ advertising hoardings. Pure OpenCV, no ML models.
|
|||||||
Two modes:
|
Two modes:
|
||||||
- Remote: calls GPU inference server over HTTP
|
- Remote: calls GPU inference server over HTTP
|
||||||
- Local: imports cv2 directly (OpenCV on same machine)
|
- Local: imports cv2 directly (OpenCV on same machine)
|
||||||
|
|
||||||
Emits frame_update events with bounding boxes for the frame viewer.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from detect import emit
|
from detect import emit
|
||||||
from detect.models import BoundingBox, Frame
|
from detect.models import BoundingBox, Frame
|
||||||
from detect.profiles.base import RegionAnalysisConfig
|
from detect.stages.base import Stage
|
||||||
|
from core.schema.models.stages import StageDefinition, StageConfigField, StageIO
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EdgeDetectionStage(Stage):
|
||||||
|
|
||||||
|
definition = StageDefinition(
|
||||||
|
name="detect_edges",
|
||||||
|
label="Edge Detection",
|
||||||
|
description="Canny + HoughLinesP — find horizontal line pairs (hoarding boundaries)",
|
||||||
|
category="cv_analysis",
|
||||||
|
io=StageIO(
|
||||||
|
reads=["filtered_frames"],
|
||||||
|
writes=["edge_regions_by_frame"],
|
||||||
|
),
|
||||||
|
config_fields=[
|
||||||
|
StageConfigField("enabled", "bool", True, "Enable edge detection"),
|
||||||
|
StageConfigField("edge_canny_low", "int", 50, "Canny low threshold", min=0, max=255),
|
||||||
|
StageConfigField("edge_canny_high", "int", 150, "Canny high threshold", min=0, max=255),
|
||||||
|
StageConfigField("edge_hough_threshold", "int", 80, "Hough accumulator threshold", min=1, max=500),
|
||||||
|
StageConfigField("edge_hough_min_length", "int", 100, "Min line length (px)", min=10, max=2000),
|
||||||
|
StageConfigField("edge_hough_max_gap", "int", 10, "Max line gap (px)", min=1, max=100),
|
||||||
|
StageConfigField("edge_pair_max_distance", "int", 200, "Max distance between line pair (px)", min=10, max=500),
|
||||||
|
StageConfigField("edge_pair_min_distance", "int", 15, "Min distance between line pair (px)", min=5, max=200),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def run(self, frames: list[Frame], config: dict) -> dict[int, list[BoundingBox]]:
|
||||||
|
"""
|
||||||
|
Run edge detection on all frames.
|
||||||
|
|
||||||
|
Config keys: enabled, edge_canny_low, edge_canny_high, edge_hough_threshold,
|
||||||
|
edge_hough_min_length, edge_hough_max_gap, edge_pair_max_distance, edge_pair_min_distance,
|
||||||
|
debug (bool), inference_url (str|None), job_id (str|None).
|
||||||
|
|
||||||
|
Returns dict mapping frame sequence → list of BoundingBox.
|
||||||
|
"""
|
||||||
|
enabled = config.get("enabled", True)
|
||||||
|
job_id = config.get("job_id")
|
||||||
|
inference_url = config.get("inference_url") or os.environ.get("INFERENCE_URL")
|
||||||
|
|
||||||
|
if not enabled:
|
||||||
|
emit.log(job_id, "EdgeDetection", "INFO", "Edge detection disabled, skipping")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
mode = "remote" if inference_url else "local"
|
||||||
|
emit.log(job_id, "EdgeDetection", "INFO",
|
||||||
|
f"Detecting edges in {len(frames)} frames (mode={mode})")
|
||||||
|
|
||||||
|
all_boxes: dict[int, list[BoundingBox]] = {}
|
||||||
|
total_regions = 0
|
||||||
|
|
||||||
|
for frame in frames:
|
||||||
|
t0 = time.monotonic()
|
||||||
|
if inference_url:
|
||||||
|
boxes = self._run_remote(frame, config, inference_url, job_id or "")
|
||||||
|
else:
|
||||||
|
boxes = self._run_local(frame, config)
|
||||||
|
ms = (time.monotonic() - t0) * 1000
|
||||||
|
|
||||||
|
all_boxes[frame.sequence] = boxes
|
||||||
|
total_regions += len(boxes)
|
||||||
|
|
||||||
|
emit.log(job_id, "EdgeDetection", "DEBUG",
|
||||||
|
f"Frame {frame.sequence}: {len(boxes)} regions in {ms:.0f}ms"
|
||||||
|
+ (f" [{', '.join(b.label for b in boxes)}]" if boxes else ""))
|
||||||
|
|
||||||
|
if boxes and job_id:
|
||||||
|
box_dicts = [
|
||||||
|
{"x": b.x, "y": b.y, "w": b.w, "h": b.h,
|
||||||
|
"confidence": b.confidence, "label": b.label,
|
||||||
|
"stage": "detect_edges"}
|
||||||
|
for b in boxes
|
||||||
|
]
|
||||||
|
emit.frame_update(
|
||||||
|
job_id,
|
||||||
|
frame_ref=frame.sequence,
|
||||||
|
timestamp=frame.timestamp,
|
||||||
|
jpeg_b64=_frame_to_b64(frame),
|
||||||
|
boxes=box_dicts,
|
||||||
|
)
|
||||||
|
|
||||||
|
emit.log(job_id, "EdgeDetection", "INFO",
|
||||||
|
f"Found {total_regions} edge regions across {len(frames)} frames")
|
||||||
|
emit.stats(job_id, cv_regions_detected=total_regions)
|
||||||
|
|
||||||
|
return all_boxes
|
||||||
|
|
||||||
|
def serialize(self, output: Any) -> bytes:
|
||||||
|
"""Serialize edge regions to JSON blob."""
|
||||||
|
serialized = {}
|
||||||
|
for seq, boxes in output.items():
|
||||||
|
serialized[str(seq)] = [
|
||||||
|
{"x": b.x, "y": b.y, "w": b.w, "h": b.h,
|
||||||
|
"confidence": b.confidence, "label": b.label}
|
||||||
|
for b in boxes
|
||||||
|
]
|
||||||
|
return json.dumps(serialized).encode()
|
||||||
|
|
||||||
|
def deserialize(self, data: bytes) -> dict[int, list[BoundingBox]]:
|
||||||
|
"""Deserialize edge regions from JSON blob."""
|
||||||
|
raw = json.loads(data)
|
||||||
|
result = {}
|
||||||
|
for seq_str, box_dicts in raw.items():
|
||||||
|
boxes = [
|
||||||
|
BoundingBox(x=b["x"], y=b["y"], w=b["w"], h=b["h"],
|
||||||
|
confidence=b["confidence"], label=b["label"])
|
||||||
|
for b in box_dicts
|
||||||
|
]
|
||||||
|
result[int(seq_str)] = boxes
|
||||||
|
return result
|
||||||
|
|
||||||
|
# --- Private helpers ---
|
||||||
|
|
||||||
|
def _run_remote(self, frame: Frame, config: dict,
|
||||||
|
inference_url: str, job_id: str) -> list[BoundingBox]:
|
||||||
|
from detect.inference import InferenceClient
|
||||||
|
from detect.emit import _run_log_level
|
||||||
|
|
||||||
|
client = InferenceClient(
|
||||||
|
base_url=inference_url, job_id=job_id, log_level=_run_log_level,
|
||||||
|
)
|
||||||
|
results = client.detect_edges(
|
||||||
|
image=frame.image,
|
||||||
|
edge_canny_low=config.get("edge_canny_low", 50),
|
||||||
|
edge_canny_high=config.get("edge_canny_high", 150),
|
||||||
|
edge_hough_threshold=config.get("edge_hough_threshold", 80),
|
||||||
|
edge_hough_min_length=config.get("edge_hough_min_length", 100),
|
||||||
|
edge_hough_max_gap=config.get("edge_hough_max_gap", 10),
|
||||||
|
edge_pair_max_distance=config.get("edge_pair_max_distance", 200),
|
||||||
|
edge_pair_min_distance=config.get("edge_pair_min_distance", 15),
|
||||||
|
)
|
||||||
|
boxes = []
|
||||||
|
for r in results:
|
||||||
|
box = BoundingBox(
|
||||||
|
x=r.x, y=r.y, w=r.w, h=r.h,
|
||||||
|
confidence=r.confidence, label=r.label,
|
||||||
|
)
|
||||||
|
boxes.append(box)
|
||||||
|
return boxes
|
||||||
|
|
||||||
|
def _run_local(self, frame: Frame, config: dict) -> list[BoundingBox]:
|
||||||
|
detect_edges_fn = _load_cv_edges().detect_edges
|
||||||
|
|
||||||
|
edge_results = detect_edges_fn(
|
||||||
|
frame.image,
|
||||||
|
canny_low=config.get("edge_canny_low", 50),
|
||||||
|
canny_high=config.get("edge_canny_high", 150),
|
||||||
|
hough_threshold=config.get("edge_hough_threshold", 80),
|
||||||
|
hough_min_length=config.get("edge_hough_min_length", 100),
|
||||||
|
hough_max_gap=config.get("edge_hough_max_gap", 10),
|
||||||
|
pair_max_distance=config.get("edge_pair_max_distance", 200),
|
||||||
|
pair_min_distance=config.get("edge_pair_min_distance", 15),
|
||||||
|
)
|
||||||
|
|
||||||
|
boxes = []
|
||||||
|
for r in edge_results:
|
||||||
|
box = BoundingBox(
|
||||||
|
x=r["x"], y=r["y"], w=r["w"], h=r["h"],
|
||||||
|
confidence=r["confidence"], label=r["label"],
|
||||||
|
)
|
||||||
|
boxes.append(box)
|
||||||
|
return boxes
|
||||||
|
|
||||||
|
|
||||||
|
# --- Module-level helpers ---
|
||||||
|
|
||||||
def _frame_to_b64(frame: Frame) -> str:
|
def _frame_to_b64(frame: Frame) -> str:
|
||||||
"""Encode frame as base64 JPEG for SSE frame_update events."""
|
|
||||||
img = Image.fromarray(frame.image)
|
img = Image.fromarray(frame.image)
|
||||||
buf = io.BytesIO()
|
buf = io.BytesIO()
|
||||||
img.save(buf, format="JPEG", quality=70)
|
img.save(buf, format="JPEG", quality=70)
|
||||||
return base64.b64encode(buf.getvalue()).decode()
|
return base64.b64encode(buf.getvalue()).decode()
|
||||||
|
|
||||||
|
|
||||||
def _detect_remote(
|
|
||||||
frame: Frame,
|
|
||||||
config: RegionAnalysisConfig,
|
|
||||||
inference_url: str,
|
|
||||||
job_id: str = "",
|
|
||||||
log_level: str = "INFO",
|
|
||||||
) -> list[BoundingBox]:
|
|
||||||
"""Call the inference server over HTTP."""
|
|
||||||
from detect.inference import InferenceClient
|
|
||||||
|
|
||||||
client = InferenceClient(
|
|
||||||
base_url=inference_url, job_id=job_id, log_level=log_level,
|
|
||||||
)
|
|
||||||
results = client.detect_edges(
|
|
||||||
image=frame.image,
|
|
||||||
edge_canny_low=config.edge_canny_low,
|
|
||||||
edge_canny_high=config.edge_canny_high,
|
|
||||||
edge_hough_threshold=config.edge_hough_threshold,
|
|
||||||
edge_hough_min_length=config.edge_hough_min_length,
|
|
||||||
edge_hough_max_gap=config.edge_hough_max_gap,
|
|
||||||
edge_pair_max_distance=config.edge_pair_max_distance,
|
|
||||||
edge_pair_min_distance=config.edge_pair_min_distance,
|
|
||||||
)
|
|
||||||
boxes = []
|
|
||||||
for r in results:
|
|
||||||
box = BoundingBox(
|
|
||||||
x=r.x, y=r.y, w=r.w, h=r.h,
|
|
||||||
confidence=r.confidence, label=r.label,
|
|
||||||
)
|
|
||||||
boxes.append(box)
|
|
||||||
return boxes
|
|
||||||
|
|
||||||
|
|
||||||
_cv_edges_mod = None
|
_cv_edges_mod = None
|
||||||
|
|
||||||
|
|
||||||
def _load_cv_edges():
|
def _load_cv_edges():
|
||||||
"""Load edges module directly — gpu/models/__init__.py has GPU-container-only imports."""
|
|
||||||
global _cv_edges_mod
|
global _cv_edges_mod
|
||||||
if _cv_edges_mod is None:
|
if _cv_edges_mod is None:
|
||||||
import importlib.util
|
import importlib.util
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
spec = importlib.util.spec_from_file_location("cv_edges", Path("gpu/models/cv/edges.py"))
|
spec = importlib.util.spec_from_file_location("cv_edges", Path("gpu/models/cv/edges.py"))
|
||||||
_cv_edges_mod = importlib.util.module_from_spec(spec)
|
_cv_edges_mod = importlib.util.module_from_spec(spec)
|
||||||
spec.loader.exec_module(_cv_edges_mod)
|
spec.loader.exec_module(_cv_edges_mod)
|
||||||
return _cv_edges_mod
|
return _cv_edges_mod
|
||||||
|
|
||||||
|
|
||||||
def _detect_local(frame: Frame, config: RegionAnalysisConfig) -> list[BoundingBox]:
|
# --- Backward compat: standalone function for graph.py ---
|
||||||
"""Run edge detection in-process (requires opencv-python)."""
|
|
||||||
detect_edges_fn = _load_cv_edges().detect_edges
|
|
||||||
|
|
||||||
edge_results = detect_edges_fn(
|
def detect_edge_regions(frames, config, inference_url=None, job_id=None):
|
||||||
frame.image,
|
"""Convenience wrapper — calls EdgeDetectionStage.run()."""
|
||||||
canny_low=config.edge_canny_low,
|
stage = EdgeDetectionStage()
|
||||||
canny_high=config.edge_canny_high,
|
cfg = {
|
||||||
hough_threshold=config.edge_hough_threshold,
|
"enabled": config.enabled,
|
||||||
hough_min_length=config.edge_hough_min_length,
|
"edge_canny_low": config.edge_canny_low,
|
||||||
hough_max_gap=config.edge_hough_max_gap,
|
"edge_canny_high": config.edge_canny_high,
|
||||||
pair_max_distance=config.edge_pair_max_distance,
|
"edge_hough_threshold": config.edge_hough_threshold,
|
||||||
pair_min_distance=config.edge_pair_min_distance,
|
"edge_hough_min_length": config.edge_hough_min_length,
|
||||||
)
|
"edge_hough_max_gap": config.edge_hough_max_gap,
|
||||||
|
"edge_pair_max_distance": config.edge_pair_max_distance,
|
||||||
boxes = []
|
"edge_pair_min_distance": config.edge_pair_min_distance,
|
||||||
for r in edge_results:
|
"inference_url": inference_url,
|
||||||
box = BoundingBox(
|
"job_id": job_id,
|
||||||
x=r["x"], y=r["y"], w=r["w"], h=r["h"],
|
}
|
||||||
confidence=r["confidence"], label=r["label"],
|
return stage.run(frames, cfg)
|
||||||
)
|
|
||||||
boxes.append(box)
|
|
||||||
return boxes
|
|
||||||
|
|
||||||
|
|
||||||
def detect_edge_regions(
|
|
||||||
frames: list[Frame],
|
|
||||||
config: RegionAnalysisConfig,
|
|
||||||
inference_url: str | None = None,
|
|
||||||
job_id: str | None = None,
|
|
||||||
) -> dict[int, list[BoundingBox]]:
|
|
||||||
"""
|
|
||||||
Run edge detection on all frames.
|
|
||||||
|
|
||||||
Returns a dict mapping frame sequence → list of bounding boxes.
|
|
||||||
"""
|
|
||||||
if not config.enabled:
|
|
||||||
emit.log(job_id, "EdgeDetection", "INFO", "Edge detection disabled, skipping")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
mode = "remote" if inference_url else "local"
|
|
||||||
emit.log(job_id, "EdgeDetection", "INFO",
|
|
||||||
f"Detecting edges in {len(frames)} frames (mode={mode})")
|
|
||||||
|
|
||||||
all_boxes: dict[int, list[BoundingBox]] = {}
|
|
||||||
total_regions = 0
|
|
||||||
|
|
||||||
for i, frame in enumerate(frames):
|
|
||||||
t0 = time.monotonic()
|
|
||||||
if inference_url:
|
|
||||||
from detect.emit import _run_log_level
|
|
||||||
boxes = _detect_remote(
|
|
||||||
frame, config, inference_url,
|
|
||||||
job_id=job_id or "", log_level=_run_log_level,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
boxes = _detect_local(frame, config)
|
|
||||||
analysis_ms = (time.monotonic() - t0) * 1000
|
|
||||||
|
|
||||||
all_boxes[frame.sequence] = boxes
|
|
||||||
total_regions += len(boxes)
|
|
||||||
|
|
||||||
emit.log(job_id, "EdgeDetection", "DEBUG",
|
|
||||||
f"Frame {frame.sequence}: {len(boxes)} regions in {analysis_ms:.0f}ms"
|
|
||||||
+ (f" [{', '.join(b.label for b in boxes)}]" if boxes else ""))
|
|
||||||
|
|
||||||
if boxes and job_id:
|
|
||||||
box_dicts = [
|
|
||||||
{
|
|
||||||
"x": b.x, "y": b.y, "w": b.w, "h": b.h,
|
|
||||||
"confidence": b.confidence, "label": b.label,
|
|
||||||
"stage": "detect_edges",
|
|
||||||
}
|
|
||||||
for b in boxes
|
|
||||||
]
|
|
||||||
emit.frame_update(
|
|
||||||
job_id,
|
|
||||||
frame_ref=frame.sequence,
|
|
||||||
timestamp=frame.timestamp,
|
|
||||||
jpeg_b64=_frame_to_b64(frame),
|
|
||||||
boxes=box_dicts,
|
|
||||||
)
|
|
||||||
|
|
||||||
emit.log(job_id, "EdgeDetection", "INFO",
|
|
||||||
f"Found {total_regions} edge regions across {len(frames)} frames")
|
|
||||||
emit.stats(job_id, cv_regions_detected=total_regions)
|
|
||||||
|
|
||||||
return all_boxes
|
|
||||||
|
|||||||
@@ -45,15 +45,15 @@ def main():
|
|||||||
return
|
return
|
||||||
|
|
||||||
logger.info("")
|
logger.info("")
|
||||||
logger.info("%3s %-35s %-12s %-18s %6s %s", "#", "Label", "Job ID", "Stage", "Frames", "Created")
|
logger.info("%3s %-35s %-12s %6s %s", "#", "Label", "Timeline", "Stages", "Created")
|
||||||
logger.info("─" * 100)
|
logger.info("─" * 80)
|
||||||
|
|
||||||
for i, s in enumerate(scenarios, 1):
|
for i, s in enumerate(scenarios, 1):
|
||||||
manifest = s.frames_manifest or {}
|
|
||||||
created = str(s.created_at)[:19] if s.created_at else "—"
|
created = str(s.created_at)[:19] if s.created_at else "—"
|
||||||
job_short = str(s.job_id)[:8]
|
tid_short = str(s.timeline_id)[:8]
|
||||||
logger.info("%3d %-35s %-12s %-18s %6d %s",
|
stage_count = len(s.stage_outputs or {})
|
||||||
i, s.scenario_label, job_short, s.stage, len(manifest), created)
|
logger.info("%3d %-35s %-12s %6d %s",
|
||||||
|
i, s.scenario_label, tid_short, stage_count, created)
|
||||||
|
|
||||||
logger.info("")
|
logger.info("")
|
||||||
|
|
||||||
@@ -73,7 +73,7 @@ def main():
|
|||||||
logger.error("Scenario not found: %s", args.open)
|
logger.error("Scenario not found: %s", args.open)
|
||||||
return
|
return
|
||||||
|
|
||||||
url = f"{args.base_url}?job={target.job_id}#/editor/detect_edges"
|
url = f"{args.base_url}?job={target.timeline_id}#/editor/detect_edges"
|
||||||
logger.info("Opening: %s", url)
|
logger.info("Opening: %s", url)
|
||||||
webbrowser.open(url)
|
webbrowser.open(url)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,26 +1,20 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""
|
"""
|
||||||
Seed a scenario checkpoint from a video chunk.
|
Seed a scenario from a video chunk.
|
||||||
|
|
||||||
Extracts frames via ffmpeg, uploads to MinIO, creates a StageCheckpoint
|
Creates a Timeline (frames in MinIO) + Branch + Checkpoint marked
|
||||||
in Postgres marked as a scenario. No pipeline, no Redis, no SSE.
|
as a scenario. No pipeline, no Redis, no SSE.
|
||||||
|
|
||||||
Prerequisites:
|
Prerequisites:
|
||||||
- Postgres reachable (port-forward or local)
|
- Postgres reachable (Kind NodePort or local)
|
||||||
- MinIO reachable (port-forward or local)
|
- MinIO reachable (Kind NodePort or local)
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
# With K8s port-forwards:
|
|
||||||
kubectl port-forward svc/postgres 5432:5432 &
|
|
||||||
kubectl port-forward svc/minio 9000:9000 &
|
|
||||||
|
|
||||||
python tests/detect/manual/seed_scenario.py
|
python tests/detect/manual/seed_scenario.py
|
||||||
|
|
||||||
# Custom video:
|
|
||||||
python tests/detect/manual/seed_scenario.py --video media/mpr/out/chunks/.../chunk_0001.mp4
|
python tests/detect/manual/seed_scenario.py --video media/mpr/out/chunks/.../chunk_0001.mp4
|
||||||
|
|
||||||
Then open:
|
Then open:
|
||||||
http://mpr.local.ar/detection/?job=<JOB_ID>#/editor/detect_edges
|
http://mpr.local.ar/detection/?job=<TIMELINE_ID>#/editor/detect_edges
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -31,7 +25,7 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Seed a scenario checkpoint")
|
parser = argparse.ArgumentParser(description="Seed a scenario")
|
||||||
parser.add_argument("--video",
|
parser.add_argument("--video",
|
||||||
default="media/mpr/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e/chunk_0000.mp4")
|
default="media/mpr/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e/chunk_0000.mp4")
|
||||||
parser.add_argument("--label", default="chelsea_edges_default",
|
parser.add_argument("--label", default="chelsea_edges_default",
|
||||||
@@ -44,7 +38,6 @@ parser.add_argument("--s3-url",
|
|||||||
default=os.environ.get("S3_ENDPOINT_URL", "http://localhost:9000"))
|
default=os.environ.get("S3_ENDPOINT_URL", "http://localhost:9000"))
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Set env before imports
|
|
||||||
os.environ["DATABASE_URL"] = args.db_url
|
os.environ["DATABASE_URL"] = args.db_url
|
||||||
os.environ["S3_ENDPOINT_URL"] = args.s3_url
|
os.environ["S3_ENDPOINT_URL"] = args.s3_url
|
||||||
os.environ.setdefault("AWS_ACCESS_KEY_ID", "minioadmin")
|
os.environ.setdefault("AWS_ACCESS_KEY_ID", "minioadmin")
|
||||||
@@ -57,7 +50,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def extract_frames_ffmpeg(video_path: str, fps: float, max_frames: int):
|
def extract_frames_ffmpeg(video_path: str, fps: float, max_frames: int):
|
||||||
"""Extract frames using ffmpeg subprocess — no pipeline dependencies."""
|
"""Extract frames using ffmpeg — no pipeline dependencies."""
|
||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -82,7 +75,7 @@ def extract_frames_ffmpeg(video_path: str, fps: float, max_frames: int):
|
|||||||
|
|
||||||
frames = []
|
frames = []
|
||||||
for jpg in sorted(Path(tmpdir).glob("frame_*.jpg")):
|
for jpg in sorted(Path(tmpdir).glob("frame_*.jpg")):
|
||||||
seq = int(jpg.stem.split("_")[1]) - 1 # 0-indexed
|
seq = int(jpg.stem.split("_")[1]) - 1
|
||||||
img = Image.open(jpg).convert("RGB")
|
img = Image.open(jpg).convert("RGB")
|
||||||
image_array = np.array(img)
|
image_array = np.array(img)
|
||||||
frame = Frame(
|
frame = Frame(
|
||||||
@@ -99,7 +92,6 @@ def extract_frames_ffmpeg(video_path: str, fps: float, max_frames: int):
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
job_id = str(uuid.uuid4())
|
|
||||||
video_path = args.video
|
video_path = args.video
|
||||||
|
|
||||||
if not os.path.exists(video_path):
|
if not os.path.exists(video_path):
|
||||||
@@ -107,7 +99,6 @@ def main():
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
logger.info("Video: %s", video_path)
|
logger.info("Video: %s", video_path)
|
||||||
logger.info("Job ID: %s", job_id)
|
|
||||||
logger.info("Label: %s", args.label)
|
logger.info("Label: %s", args.label)
|
||||||
|
|
||||||
# Ensure DB tables exist
|
# Ensure DB tables exist
|
||||||
@@ -119,57 +110,37 @@ def main():
|
|||||||
frames = extract_frames_ffmpeg(video_path, args.fps, args.max_frames)
|
frames = extract_frames_ffmpeg(video_path, args.fps, args.max_frames)
|
||||||
logger.info("Extracted %d frames", len(frames))
|
logger.info("Extracted %d frames", len(frames))
|
||||||
|
|
||||||
# Upload frames to MinIO
|
# Create timeline + branch + checkpoint
|
||||||
from detect.checkpoint.frames import save_frames
|
from detect.checkpoint.storage import create_timeline, save_stage_output
|
||||||
logger.info("Uploading frames to MinIO...")
|
|
||||||
manifest = save_frames(job_id, frames)
|
|
||||||
logger.info("Uploaded %d frames", len(manifest))
|
|
||||||
|
|
||||||
# Build frame metadata
|
timeline_id, branch_id = create_timeline(
|
||||||
frames_meta = [
|
source_video=video_path,
|
||||||
{
|
|
||||||
"sequence": f.sequence,
|
|
||||||
"chunk_id": f.chunk_id,
|
|
||||||
"timestamp": f.timestamp,
|
|
||||||
"perceptual_hash": "",
|
|
||||||
}
|
|
||||||
for f in frames
|
|
||||||
]
|
|
||||||
|
|
||||||
# All frames are "filtered" (no scene filter ran)
|
|
||||||
filtered_sequences = [f.sequence for f in frames]
|
|
||||||
|
|
||||||
# Save checkpoint as scenario
|
|
||||||
from core.db.detect import save_stage_checkpoint
|
|
||||||
from detect.checkpoint.frames import CHECKPOINT_PREFIX
|
|
||||||
|
|
||||||
checkpoint = save_stage_checkpoint(
|
|
||||||
job_id=job_id,
|
|
||||||
stage="filter_scenes",
|
|
||||||
stage_index=1,
|
|
||||||
frames_prefix=f"{CHECKPOINT_PREFIX}/{job_id}/frames/",
|
|
||||||
frames_manifest={str(k): v for k, v in manifest.items()},
|
|
||||||
frames_meta=frames_meta,
|
|
||||||
filtered_frame_sequences=filtered_sequences,
|
|
||||||
stage_output_key="",
|
|
||||||
stats={"frames_extracted": len(frames), "frames_after_scene_filter": len(frames)},
|
|
||||||
config_snapshot={},
|
|
||||||
config_overrides={},
|
|
||||||
video_path=video_path,
|
|
||||||
profile_name="soccer_broadcast",
|
profile_name="soccer_broadcast",
|
||||||
is_scenario=True,
|
frames=frames,
|
||||||
scenario_label=args.label,
|
fps=args.fps,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Mark as scenario
|
||||||
|
from core.db.detect import get_latest_checkpoint
|
||||||
|
from core.db.connection import get_session
|
||||||
|
|
||||||
|
checkpoint = get_latest_checkpoint(branch_id)
|
||||||
|
if checkpoint:
|
||||||
|
checkpoint.is_scenario = True
|
||||||
|
checkpoint.scenario_label = args.label
|
||||||
|
with get_session() as session:
|
||||||
|
session.add(checkpoint)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
logger.info("")
|
logger.info("")
|
||||||
logger.info("Scenario created:")
|
logger.info("Scenario created:")
|
||||||
logger.info(" ID: %s", checkpoint.id)
|
logger.info(" Timeline: %s", timeline_id)
|
||||||
logger.info(" Job: %s", job_id)
|
logger.info(" Branch: %s", branch_id)
|
||||||
logger.info(" Label: %s", args.label)
|
logger.info(" Label: %s", args.label)
|
||||||
logger.info(" Frames: %d", len(frames))
|
logger.info(" Frames: %d", len(frames))
|
||||||
logger.info("")
|
logger.info("")
|
||||||
logger.info("Open in editor:")
|
logger.info("Open in editor:")
|
||||||
logger.info(" http://mpr.local.ar/detection/?job=%s#/editor/detect_edges", job_id)
|
logger.info(" http://mpr.local.ar/detection/?job=%s#/editor/detect_edges", timeline_id)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Tests for the stage registry."""
|
"""Tests for the stage registry."""
|
||||||
|
|
||||||
from detect.stages import list_stages, get_stage, get_palette
|
from detect.stages import list_stages, get_stage, get_palette
|
||||||
|
from detect.stages.base import get_stage_class
|
||||||
|
|
||||||
|
|
||||||
EXPECTED_STAGES = [
|
EXPECTED_STAGES = [
|
||||||
@@ -26,9 +27,17 @@ def test_stage_has_io():
|
|||||||
|
|
||||||
def test_stage_has_serialization():
|
def test_stage_has_serialization():
|
||||||
for name in EXPECTED_STAGES:
|
for name in EXPECTED_STAGES:
|
||||||
stage = get_stage(name)
|
defn = get_stage(name)
|
||||||
assert stage.serialize_fn is not None, f"{name} has no serialize_fn"
|
stage_cls = get_stage_class(name)
|
||||||
assert stage.deserialize_fn is not None, f"{name} has no deserialize_fn"
|
if stage_cls is not None:
|
||||||
|
# New-style: serialization on the class
|
||||||
|
instance = stage_cls()
|
||||||
|
assert hasattr(instance, 'serialize'), f"{name} has no serialize method"
|
||||||
|
assert hasattr(instance, 'deserialize'), f"{name} has no deserialize method"
|
||||||
|
else:
|
||||||
|
# Legacy: serialization on the definition
|
||||||
|
assert defn.serialize_fn is not None, f"{name} has no serialize_fn"
|
||||||
|
assert defn.deserialize_fn is not None, f"{name} has no deserialize_fn"
|
||||||
|
|
||||||
|
|
||||||
def test_palette_groups():
|
def test_palette_groups():
|
||||||
|
|||||||
Reference in New Issue
Block a user