""" Checkpoint storage — Timeline, Checkpoint, StageOutput persistence. Timeline: user-created source selection (chunk paths) Checkpoint: lightweight tree node (parent_id, stage_name, config, stats) StageOutput: per-stage result (flat table, one row per job+stage) """ from __future__ import annotations import logging from uuid import UUID logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Timeline # --------------------------------------------------------------------------- def create_timeline( chunk_paths: list[str], profile_name: str = "", name: str = "", source_asset_id: UUID | None = None, fps: float = 2.0, ) -> str: """ Create a timeline from a chunk selection. Called by the user (via API) before any pipeline runs. Returns timeline_id. """ from core.db.models import Timeline from core.db.connection import get_session with get_session() as session: timeline = Timeline( name=name, chunk_paths=chunk_paths, profile_name=profile_name, source_asset_id=source_asset_id, fps=fps, status="created", ) session.add(timeline) session.commit() session.refresh(timeline) tid = str(timeline.id) logger.info("Timeline created: %s (%d chunks)", tid, len(chunk_paths)) return tid def get_timeline(timeline_id: str) -> dict: """Load a timeline as a dict.""" from core.db.models import Timeline from core.db.connection import get_session with get_session() as session: timeline = session.get(Timeline, UUID(timeline_id)) if not timeline: raise ValueError(f"Timeline not found: {timeline_id}") return { "id": str(timeline.id), "name": timeline.name, "chunk_paths": timeline.chunk_paths, "profile_name": timeline.profile_name, "status": timeline.status, "fps": timeline.fps, "source_asset_id": str(timeline.source_asset_id) if timeline.source_asset_id else None, "created_at": str(timeline.created_at) if timeline.created_at else None, } def update_timeline_status(timeline_id: str, status: str, frame_count: int | None = None): """Update timeline status and optionally frame count.""" from core.db.models import Timeline from core.db.connection import get_session with get_session() as session: timeline = session.get(Timeline, UUID(timeline_id)) if timeline: timeline.status = status if frame_count is not None: timeline.frame_count = frame_count session.commit() # --------------------------------------------------------------------------- # Checkpoint # --------------------------------------------------------------------------- def save_checkpoint( timeline_id: str, stage_name: str, parent_checkpoint_id: str | None = None, config_overrides: dict | None = None, stats: dict | None = None, is_scenario: bool = False, scenario_label: str = "", job_id: str | None = None, ) -> str: """ Save a checkpoint (lightweight tree node). No stage outputs — those go in StageOutput table separately. Returns the new checkpoint ID. """ from core.db.models import Checkpoint from core.db.connection import get_session with get_session() as session: checkpoint = Checkpoint( timeline_id=UUID(timeline_id), job_id=UUID(job_id) if job_id else None, parent_id=UUID(parent_checkpoint_id) if parent_checkpoint_id else None, stage_name=stage_name, config_overrides=config_overrides or {}, stats=stats or {}, is_scenario=is_scenario, scenario_label=scenario_label, ) session.add(checkpoint) session.commit() session.refresh(checkpoint) cid = str(checkpoint.id) logger.info("Checkpoint saved: %s (timeline %s, stage %s, parent %s)", cid, timeline_id, stage_name, parent_checkpoint_id) return cid def get_checkpoints_for_job(job_id: str) -> list[dict]: """List checkpoints for a job, ordered by creation time.""" from sqlmodel import select from core.db.models import Checkpoint from core.db.connection import get_session with get_session() as session: stmt = ( select(Checkpoint) .where(Checkpoint.job_id == UUID(job_id)) .order_by(Checkpoint.created_at) ) checkpoints = session.exec(stmt).all() return [ { "id": str(c.id), "timeline_id": str(c.timeline_id), "job_id": str(c.job_id) if c.job_id else None, "parent_id": str(c.parent_id) if c.parent_id else None, "stage_name": c.stage_name, "config_overrides": c.config_overrides or {}, "stats": c.stats or {}, "is_scenario": c.is_scenario, "scenario_label": c.scenario_label, "created_at": str(c.created_at) if c.created_at else None, } for c in checkpoints ] def get_checkpoints_for_timeline(timeline_id: str) -> list[dict]: """List all checkpoints on a timeline, ordered by creation time.""" from sqlmodel import select from core.db.models import Checkpoint from core.db.connection import get_session with get_session() as session: stmt = ( select(Checkpoint) .where(Checkpoint.timeline_id == UUID(timeline_id)) .order_by(Checkpoint.created_at) ) checkpoints = session.exec(stmt).all() return [ { "id": str(c.id), "timeline_id": str(c.timeline_id), "job_id": str(c.job_id) if c.job_id else None, "parent_id": str(c.parent_id) if c.parent_id else None, "stage_name": c.stage_name, "config_overrides": c.config_overrides or {}, "stats": c.stats or {}, "is_scenario": c.is_scenario, "scenario_label": c.scenario_label, "created_at": str(c.created_at) if c.created_at else None, } for c in checkpoints ] # --------------------------------------------------------------------------- # StageOutput # --------------------------------------------------------------------------- def save_stage_output( job_id: str, timeline_id: str, stage_name: str, output: dict, checkpoint_id: str | None = None, ) -> str: """ Save (upsert) a stage output. One row per (job_id, stage_name). Returns the stage_output ID. """ from sqlmodel import select from core.db.models import StageOutput from core.db.connection import get_session with get_session() as session: # Upsert: check if exists stmt = ( select(StageOutput) .where(StageOutput.job_id == UUID(job_id)) .where(StageOutput.stage_name == stage_name) ) existing = session.exec(stmt).first() if existing: existing.output = output existing.checkpoint_id = UUID(checkpoint_id) if checkpoint_id else None session.commit() session.refresh(existing) return str(existing.id) stage_output = StageOutput( job_id=UUID(job_id), timeline_id=UUID(timeline_id), stage_name=stage_name, checkpoint_id=UUID(checkpoint_id) if checkpoint_id else None, output=output, ) session.add(stage_output) session.commit() session.refresh(stage_output) return str(stage_output.id) def load_stage_output(job_id: str, stage_name: str) -> dict | None: """Load a stage's output by job + stage name.""" from sqlmodel import select from core.db.models import StageOutput from core.db.connection import get_session with get_session() as session: stmt = ( select(StageOutput) .where(StageOutput.job_id == UUID(job_id)) .where(StageOutput.stage_name == stage_name) ) row = session.exec(stmt).first() if not row: return None return row.output def load_stage_outputs_for_job(job_id: str) -> dict[str, dict]: """Load all stage outputs for a job. Returns {stage_name: output}.""" from sqlmodel import select from core.db.models import StageOutput from core.db.connection import get_session with get_session() as session: stmt = ( select(StageOutput) .where(StageOutput.job_id == UUID(job_id)) ) rows = session.exec(stmt).all() return {row.stage_name: row.output for row in rows} def load_stage_outputs_for_timeline(timeline_id: str, stage_name: str | None = None) -> list[dict]: """Load stage outputs for a timeline, optionally filtered by stage.""" from sqlmodel import select from core.db.models import StageOutput from core.db.connection import get_session with get_session() as session: stmt = select(StageOutput).where(StageOutput.timeline_id == UUID(timeline_id)) if stage_name: stmt = stmt.where(StageOutput.stage_name == stage_name) rows = session.exec(stmt).all() return [ { "id": str(r.id), "job_id": str(r.job_id), "stage_name": r.stage_name, "checkpoint_id": str(r.checkpoint_id) if r.checkpoint_id else None, "output": r.output, "created_at": str(r.created_at) if r.created_at else None, } for r in rows ]