178 lines
5.8 KiB
Python
178 lines
5.8 KiB
Python
"""
Checkpoint storage — Timeline + Checkpoint (tree of snapshots).

Timeline: frame sequence from source video (frames in MinIO)
Checkpoint: snapshot of pipeline state (stage outputs as JSONB in Postgres)
    parent_id forms a tree — multiple children = different config tries
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from uuid import UUID
|
|
|
|
from .frames import save_frames, load_frames, CHECKPOINT_PREFIX
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Timeline
# ---------------------------------------------------------------------------
|
|
|
|
def create_timeline(
    source_video: str,
    profile_name: str,
    frames: list,
    fps: float = 2.0,
    source_asset_id: UUID | None = None,
) -> tuple[str, str]:
    """
    Persist a new timeline together with its root checkpoint.

    A Timeline row is added and flushed, the frames are handed to
    ``save_frames`` (upload to MinIO), and the resulting manifest plus
    per-frame metadata are stored on the row. A parentless Checkpoint with
    empty stage outputs is then created for the timeline.

    Returns the pair (timeline_id, checkpoint_id), both as strings.
    """
    from core.db.models import Timeline, Checkpoint
    from core.db.connection import get_session

    with get_session() as db:
        tl = Timeline(
            source_video=source_video,
            profile_name=profile_name,
            source_asset_id=source_asset_id,
            fps=fps,
        )
        db.add(tl)
        # Flush before reading tl.id: the id is needed for the storage prefix.
        db.flush()
        tid = str(tl.id)

        # Upload frames to MinIO
        uploaded = save_frames(tid, frames)

        # One metadata record per frame; chunk_id and perceptual_hash are
        # optional attributes and fall back to 0 / "".
        meta_rows = []
        for frame in frames:
            meta_rows.append(
                {
                    "sequence": frame.sequence,
                    "chunk_id": getattr(frame, "chunk_id", 0),
                    "timestamp": frame.timestamp,
                    "perceptual_hash": getattr(frame, "perceptual_hash", ""),
                }
            )

        # Manifest keys are stored as strings.
        manifest_by_str_key = {}
        for seq, ref in uploaded.items():
            manifest_by_str_key[str(seq)] = ref

        tl.frames_prefix = f"{CHECKPOINT_PREFIX}/{tid}/frames/"
        tl.frames_manifest = manifest_by_str_key
        tl.frames_meta = meta_rows

        frame_count = len(frames)
        root = Checkpoint(
            timeline_id=tl.id,
            parent_id=None,
            stage_outputs={},
            stats={"frames_extracted": frame_count},
        )
        db.add(root)
        db.commit()
        db.refresh(root)
        cid = str(root.id)

    logger.info("Timeline created: %s (%d frames, root checkpoint %s)", tid, frame_count, cid)
    return tid, cid
|
|
|
|
|
|
def get_timeline_frames(timeline_id: str) -> list:
    """
    Load the frames stored for a timeline (from MinIO) as Frame objects.

    Looks up the Timeline row, converts its manifest keys back to integers
    and delegates to ``load_frames``.

    Raises ValueError when no timeline exists for ``timeline_id``.
    """
    from core.db.models import Timeline
    from core.db.connection import get_session

    with get_session() as db:
        row = db.get(Timeline, UUID(timeline_id))
        if not row:
            raise ValueError(f"Timeline not found: {timeline_id}")

        stored = row.frames_manifest or {}
        # Manifest keys are persisted as strings; load_frames gets int keys.
        by_sequence = {}
        for key, ref in stored.items():
            by_sequence[int(key)] = ref

        meta = row.frames_meta or []
        return load_frames(by_sequence, meta)
|
|
|
|
|
|
def get_timeline_frames_b64(timeline_id: str) -> list[dict]:
    """
    Load a timeline's frames as base64 JPEG (lightweight, no numpy).

    Same lookup as the Frame-object loader, but the manifest and metadata
    are passed to ``load_frames_b64`` instead.

    Raises ValueError when no timeline exists for ``timeline_id``.
    """
    from core.db.models import Timeline
    from core.db.connection import get_session
    from .frames import load_frames_b64

    with get_session() as db:
        record = db.get(Timeline, UUID(timeline_id))
        if not record:
            raise ValueError(f"Timeline not found: {timeline_id}")

        # Stored manifest keys are strings; rebuild the mapping with int keys.
        int_keyed = dict(
            (int(seq), ref) for seq, ref in (record.frames_manifest or {}).items()
        )
        frame_meta = record.frames_meta or []
        return load_frames_b64(int_keyed, frame_meta)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Checkpoint
# ---------------------------------------------------------------------------
|
|
|
|
def save_stage_output(
    timeline_id: str,
    parent_checkpoint_id: str | None,
    stage_name: str,
    output_json: dict,
    config_overrides: dict | None = None,
    stats: dict | None = None,
    is_scenario: bool = False,
    scenario_label: str = "",
    job_id: str | None = None,
) -> str:
    """
    Save a stage's output as a new checkpoint (child of parent).

    Carries forward the parent's stage outputs, stats and config overrides,
    then layers the values supplied here on top (supplied keys win on
    collision). If ``parent_checkpoint_id`` is given but no such checkpoint
    is found, nothing is inherited and a warning is logged; the new
    checkpoint is still saved with ``parent_id`` set to the given id.

    Args:
        timeline_id: UUID string of the owning timeline.
        parent_checkpoint_id: UUID string of the parent checkpoint, or None
            for a checkpoint without a parent.
        stage_name: Key under which ``output_json`` is stored in
            ``stage_outputs``.
        output_json: The stage's output.
        config_overrides: Extra config entries merged over the parent's.
        stats: Extra stats entries merged over the parent's.
        is_scenario: Stored as-is on the checkpoint.
        scenario_label: Stored as-is on the checkpoint.
        job_id: Optional UUID string stored on the checkpoint.

    Returns:
        The new checkpoint ID as a string.

    Raises:
        ValueError: If any supplied id is not a valid UUID string.
    """
    from core.db.models import Checkpoint
    from core.db.connection import get_session

    # Parse the parent id once; it is used for both lookup and parent_id.
    parent_uuid = UUID(parent_checkpoint_id) if parent_checkpoint_id else None

    with get_session() as session:
        parent_outputs: dict = {}
        parent_stats: dict = {}
        parent_config: dict = {}
        if parent_uuid is not None:
            parent = session.get(Checkpoint, parent_uuid)
            if parent:
                # Copy so the parent's stored dicts are never mutated.
                parent_outputs = dict(parent.stage_outputs or {})
                parent_stats = dict(parent.stats or {})
                parent_config = dict(parent.config_overrides or {})
            else:
                # Previously silent: the child lost all inherited state with
                # no trace. Stay best-effort, but make it visible.
                logger.warning(
                    "Parent checkpoint %s not found; saving stage %s without inherited state",
                    parent_checkpoint_id, stage_name,
                )

        checkpoint = Checkpoint(
            timeline_id=UUID(timeline_id),
            job_id=UUID(job_id) if job_id else None,
            parent_id=parent_uuid,
            stage_outputs={**parent_outputs, stage_name: output_json},
            config_overrides={**parent_config, **(config_overrides or {})},
            stats={**parent_stats, **(stats or {})},
            is_scenario=is_scenario,
            scenario_label=scenario_label,
        )
        session.add(checkpoint)
        session.commit()
        session.refresh(checkpoint)
        cid = str(checkpoint.id)

    logger.info("Checkpoint saved: %s (timeline %s, stage %s, parent %s)",
                cid, timeline_id, stage_name, parent_checkpoint_id)
    return cid
|
|
|
|
|
|
def load_stage_output(checkpoint_id: str, stage_name: str) -> dict | None:
    """
    Return the output stored for ``stage_name`` on a checkpoint.

    The result is None when the checkpoint does not exist or when it holds
    no entry for the requested stage.
    """
    from core.db.models import Checkpoint
    from core.db.connection import get_session

    with get_session() as db:
        row = db.get(Checkpoint, UUID(checkpoint_id))
        if row:
            outputs = row.stage_outputs or {}
            return outputs.get(stage_name)
        return None
|