refactor stage 1
This commit is contained in:
@@ -1,14 +1,18 @@
|
||||
"""
|
||||
Stage checkpoint, replay, and retry.
|
||||
Checkpoint system — Timeline + Checkpoint tree.
|
||||
|
||||
detect/checkpoint/
|
||||
frames.py — frame image S3 upload/download
|
||||
serializer.py — state ↔ JSON conversion
|
||||
storage.py — checkpoint save/load/list (Postgres + S3)
|
||||
replay.py — replay_from, OverrideProfile
|
||||
storage.py — Timeline + Checkpoint (Postgres + MinIO)
|
||||
replay.py — replay (TODO: migrate to new model)
|
||||
tasks.py — retry_candidates Celery task
|
||||
"""
|
||||
|
||||
from .storage import save_checkpoint, load_checkpoint, list_checkpoints
|
||||
from .storage import (
|
||||
create_timeline,
|
||||
get_timeline_frames,
|
||||
get_timeline_frames_b64,
|
||||
save_stage_output,
|
||||
load_stage_output,
|
||||
)
|
||||
from .frames import save_frames, load_frames
|
||||
from .replay import replay_from, OverrideProfile
|
||||
|
||||
@@ -12,7 +12,13 @@ import logging
|
||||
import uuid
|
||||
|
||||
from detect import emit
|
||||
from detect.checkpoint import load_checkpoint, list_checkpoints
|
||||
# TODO: migrate to Timeline/Branch/Checkpoint model
|
||||
# These old functions no longer exist — replay needs rework
|
||||
def _not_migrated(*args, **kwargs):
|
||||
raise NotImplementedError("Replay not yet migrated to Timeline/Branch/Checkpoint model")
|
||||
|
||||
load_checkpoint = _not_migrated
|
||||
list_checkpoints = _not_migrated
|
||||
from detect.graph import NODES, build_graph
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,116 +1,178 @@
|
||||
"""
|
||||
Checkpoint storage — save/load stage state.
|
||||
Checkpoint storage — Timeline + Checkpoint (tree of snapshots).
|
||||
|
||||
Binary data (frame images) → S3/MinIO via frames.py
|
||||
Structured data (stage output, stats, config) → Postgres
|
||||
Timeline: frame sequence from source video (frames in MinIO)
|
||||
Checkpoint: snapshot of pipeline state (stage outputs as JSONB in Postgres)
|
||||
parent_id forms a tree — multiple children = different config tries
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from .frames import save_frames, load_frames, CHECKPOINT_PREFIX
|
||||
from .serializer import serialize_state, deserialize_state
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Save
|
||||
# Timeline
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def save_checkpoint(
|
||||
job_id: str,
|
||||
stage: str,
|
||||
stage_index: int,
|
||||
state: dict,
|
||||
frames_manifest: dict[int, str] | None = None,
|
||||
def create_timeline(
    source_video: str,
    profile_name: str,
    frames: list,
    fps: float = 2.0,
    source_asset_id: UUID | None = None,
) -> tuple[str, str]:
    """
    Create a timeline from frames. Uploads frame images to MinIO,
    creates Timeline + root Checkpoint in Postgres.

    Args:
        source_video: forwarded to the DB-layer ``create_timeline``.
        profile_name: forwarded to the DB-layer ``create_timeline``.
        frames: frame objects. Each must expose ``sequence`` and
            ``timestamp``; ``chunk_id`` and ``perceptual_hash`` are optional
            and fall back to ``0`` and ``""``.
        fps: forwarded to the DB-layer ``create_timeline``.
        source_asset_id: forwarded to the DB-layer ``create_timeline``.

    Returns (timeline_id, checkpoint_id), both as strings.
    """
    # DB-layer imports stay function-local, as elsewhere in this module.
    from core.db.connection import get_session
    from core.db.detect import create_timeline as db_create_timeline
    from core.db.detect import save_checkpoint

    # Create timeline
    timeline = db_create_timeline(
        source_video=source_video,
        profile_name=profile_name,
        source_asset_id=source_asset_id,
        fps=fps,
    )
    tid = str(timeline.id)

    # Upload frames to MinIO
    manifest = save_frames(tid, frames)

    # Store frame metadata on the timeline
    frames_meta = [
        {
            "sequence": f.sequence,
            "chunk_id": getattr(f, "chunk_id", 0),
            "timestamp": f.timestamp,
            "perceptual_hash": getattr(f, "perceptual_hash", ""),
        }
        for f in frames
    ]

    timeline.frames_prefix = f"{CHECKPOINT_PREFIX}/{tid}/frames/"
    # Manifest keys are stringified before being stored on the timeline.
    timeline.frames_manifest = {str(k): v for k, v in manifest.items()}
    timeline.frames_meta = frames_meta

    with get_session() as session:
        session.add(timeline)
        session.commit()

    # Create root checkpoint (no parent, no stage outputs yet)
    # NOTE(review): timeline.id is read again after the session above has
    # committed and closed — presumably the instance is still usable at that
    # point; confirm against how get_session() configures its sessions.
    checkpoint = save_checkpoint(
        timeline_id=timeline.id,
        parent_id=None,
        stage_outputs={},
        stats={"frames_extracted": len(frames)},
    )

    logger.info("Timeline created: %s (%d frames, root checkpoint %s)",
                tid, len(frames), checkpoint.id)
    return tid, str(checkpoint.id)
|
||||
|
||||
|
||||
def get_timeline_frames(timeline_id: str) -> list:
    """Load frames from a timeline (from MinIO) as Frame objects.

    Args:
        timeline_id: id passed to the DB-layer ``get_timeline`` lookup.

    Returns:
        Whatever ``load_frames`` produces for the timeline's manifest and
        frame metadata.

    Raises:
        ValueError: if the lookup returns nothing for ``timeline_id``.
    """
    from core.db.detect import get_timeline

    timeline = get_timeline(timeline_id)
    if not timeline:
        raise ValueError(f"Timeline not found: {timeline_id}")

    # Stored manifest keys are converted to int before being handed to
    # load_frames; a missing manifest is treated as empty.
    raw_manifest = timeline.frames_manifest or {}
    manifest = {int(k): v for k, v in raw_manifest.items()}
    frame_metadata = timeline.frames_meta or []

    return load_frames(manifest, frame_metadata)
|
||||
|
||||
|
||||
def get_timeline_frames_b64(timeline_id: str) -> list[dict]:
    """Load frames as base64 JPEG (lightweight, no numpy).

    Args:
        timeline_id: id passed to the DB-layer ``get_timeline`` lookup.

    Returns:
        Whatever ``load_frames_b64`` produces for the timeline's manifest and
        frame metadata.

    Raises:
        ValueError: if the lookup returns nothing for ``timeline_id``.
    """
    from core.db.detect import get_timeline
    from .frames import load_frames_b64

    timeline = get_timeline(timeline_id)
    if not timeline:
        raise ValueError(f"Timeline not found: {timeline_id}")

    # Stored manifest keys are converted to int before being handed to
    # load_frames_b64; a missing manifest is treated as empty.
    raw_manifest = timeline.frames_manifest or {}
    manifest = {int(k): v for k, v in raw_manifest.items()}
    frame_metadata = timeline.frames_meta or []

    return load_frames_b64(manifest, frame_metadata)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Checkpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def save_stage_output(
|
||||
timeline_id: str,
|
||||
parent_checkpoint_id: str | None,
|
||||
stage_name: str,
|
||||
output_json: dict,
|
||||
config_overrides: dict | None = None,
|
||||
stats: dict | None = None,
|
||||
is_scenario: bool = False,
|
||||
scenario_label: str = "",
|
||||
) -> str:
|
||||
"""
|
||||
Save a stage checkpoint.
|
||||
Save a stage's output as a new checkpoint (child of parent).
|
||||
|
||||
Saves frame images to S3 (if not already saved), then persists
|
||||
structured state to Postgres.
|
||||
|
||||
Returns the checkpoint DB id.
|
||||
Carries forward stage outputs from parent + adds the new one.
|
||||
Returns the new checkpoint ID.
|
||||
"""
|
||||
from core.db.detect import save_stage_checkpoint
|
||||
from core.db.detect import get_checkpoint, save_checkpoint
|
||||
|
||||
if frames_manifest is None:
|
||||
all_frames = state.get("frames", [])
|
||||
frames_manifest = save_frames(job_id, all_frames)
|
||||
# Carry forward from parent
|
||||
parent_outputs = {}
|
||||
parent_stats = {}
|
||||
parent_config = {}
|
||||
if parent_checkpoint_id:
|
||||
parent = get_checkpoint(parent_checkpoint_id)
|
||||
if parent:
|
||||
parent_outputs = dict(parent.stage_outputs or {})
|
||||
parent_stats = dict(parent.stats or {})
|
||||
parent_config = dict(parent.config_overrides or {})
|
||||
|
||||
checkpoint_data = serialize_state(state, frames_manifest)
|
||||
frames_prefix = f"{CHECKPOINT_PREFIX}/{job_id}/frames/"
|
||||
# Add new stage output
|
||||
stage_outputs = {**parent_outputs, stage_name: output_json}
|
||||
|
||||
checkpoint = save_stage_checkpoint(
|
||||
job_id=job_id,
|
||||
stage=stage,
|
||||
stage_index=stage_index,
|
||||
frames_prefix=frames_prefix,
|
||||
frames_manifest=checkpoint_data.get("frames_manifest", {}),
|
||||
frames_meta=checkpoint_data.get("frames_meta", []),
|
||||
filtered_frame_sequences=checkpoint_data.get("filtered_frame_sequences", []),
|
||||
stage_output_key=checkpoint_data.get("stage_output_key", ""),
|
||||
stats=checkpoint_data.get("stats", {}),
|
||||
config_snapshot=checkpoint_data.get("config_overrides", {}),
|
||||
config_overrides=checkpoint_data.get("config_overrides", {}),
|
||||
video_path=checkpoint_data.get("video_path", ""),
|
||||
profile_name=checkpoint_data.get("profile_name", ""),
|
||||
# Merge stats and config
|
||||
merged_stats = {**parent_stats, **(stats or {})}
|
||||
merged_config = {**parent_config, **(config_overrides or {})}
|
||||
|
||||
checkpoint = save_checkpoint(
|
||||
timeline_id=timeline_id,
|
||||
parent_id=parent_checkpoint_id,
|
||||
stage_outputs=stage_outputs,
|
||||
config_overrides=merged_config,
|
||||
stats=merged_stats,
|
||||
is_scenario=is_scenario,
|
||||
scenario_label=scenario_label,
|
||||
)
|
||||
|
||||
logger.info("Checkpoint saved: %s/%s (id=%s, scenario=%s)",
|
||||
job_id, stage, checkpoint.id, is_scenario)
|
||||
logger.info("Checkpoint saved: %s (timeline %s, stage %s, parent %s)",
|
||||
checkpoint.id, timeline_id, stage_name, parent_checkpoint_id)
|
||||
return str(checkpoint.id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Load
|
||||
# ---------------------------------------------------------------------------
|
||||
def load_stage_output(checkpoint_id: str, stage_name: str) -> dict | None:
|
||||
"""Load a stage's output from a checkpoint."""
|
||||
from core.db.detect import get_checkpoint
|
||||
|
||||
def load_checkpoint(job_id: str, stage: str) -> dict:
|
||||
"""
|
||||
Load a stage checkpoint and reconstitute full DetectState.
|
||||
"""
|
||||
from core.db.detect import get_stage_checkpoint
|
||||
|
||||
checkpoint = get_stage_checkpoint(job_id, stage)
|
||||
checkpoint = get_checkpoint(checkpoint_id)
|
||||
if not checkpoint:
|
||||
raise ValueError(f"No checkpoint for {job_id}/{stage}")
|
||||
return None
|
||||
|
||||
data = {
|
||||
"job_id": str(checkpoint.job_id),
|
||||
"video_path": checkpoint.video_path,
|
||||
"profile_name": checkpoint.profile_name,
|
||||
"config_overrides": checkpoint.config_overrides,
|
||||
"frames_manifest": checkpoint.frames_manifest,
|
||||
"frames_meta": checkpoint.frames_meta,
|
||||
"filtered_frame_sequences": checkpoint.filtered_frame_sequences,
|
||||
"stage_output_key": checkpoint.stage_output_key,
|
||||
"stats": checkpoint.stats,
|
||||
}
|
||||
|
||||
raw_manifest = data.get("frames_manifest", {})
|
||||
manifest = {int(k): v for k, v in raw_manifest.items()}
|
||||
frame_metadata = data.get("frames_meta", [])
|
||||
frames = load_frames(manifest, frame_metadata)
|
||||
|
||||
state = deserialize_state(data, frames)
|
||||
|
||||
logger.info("Checkpoint loaded: %s/%s (%d frames, scenario=%s)",
|
||||
job_id, stage, len(frames), checkpoint.is_scenario)
|
||||
return state
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# List
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def list_checkpoints(job_id: str) -> list[str]:
|
||||
"""List available checkpoint stages for a job."""
|
||||
from core.db.detect import list_stage_checkpoints
|
||||
return list_stage_checkpoints(job_id)
|
||||
return (checkpoint.stage_outputs or {}).get(stage_name)
|
||||
|
||||
Reference in New Issue
Block a user