304 lines
9.5 KiB
Python
304 lines
9.5 KiB
Python
"""
|
|
Checkpoint storage — Timeline, Checkpoint, StageOutput persistence.
|
|
|
|
Timeline: user-created source selection (chunk paths)
|
|
Checkpoint: lightweight tree node (parent_id, stage_name, config, stats)
|
|
StageOutput: per-stage result (flat table, one row per job+stage)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from uuid import UUID
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Timeline
# ---------------------------------------------------------------------------
|
|
|
|
def create_timeline(
    chunk_paths: list[str],
    profile_name: str = "",
    name: str = "",
    source_asset_id: UUID | None = None,
    fps: float = 2.0,
) -> str:
    """
    Persist a new Timeline row describing a chunk selection.

    Called by the user (via API) before any pipeline runs. The row is
    stored with status "created".

    Returns the id of the new timeline as a string.
    """
    from core.db.models import Timeline
    from core.db.connection import get_session

    # Collect the column values first so the constructor call stays short.
    columns = {
        "name": name,
        "chunk_paths": chunk_paths,
        "profile_name": profile_name,
        "source_asset_id": source_asset_id,
        "fps": fps,
        "status": "created",
    }

    with get_session() as db:
        row = Timeline(**columns)
        db.add(row)
        db.commit()
        # Reload the row after the commit before reading its id.
        db.refresh(row)
        new_id = str(row.id)

    logger.info("Timeline created: %s (%d chunks)", new_id, len(chunk_paths))
    return new_id
|
|
|
|
|
|
def get_timeline(timeline_id: str) -> dict:
    """
    Fetch a single timeline by id and return its fields as a plain dict.

    The id, source asset id and creation time are rendered as strings;
    the latter two are None when unset on the row.

    Raises ValueError when no timeline with that id exists.
    """
    from core.db.models import Timeline
    from core.db.connection import get_session

    with get_session() as db:
        row = db.get(Timeline, UUID(timeline_id))
        if not row:
            raise ValueError(f"Timeline not found: {timeline_id}")

        asset_id = row.source_asset_id
        created = row.created_at
        return dict(
            id=str(row.id),
            name=row.name,
            chunk_paths=row.chunk_paths,
            profile_name=row.profile_name,
            status=row.status,
            fps=row.fps,
            source_asset_id=str(asset_id) if asset_id else None,
            created_at=str(created) if created else None,
        )
|
|
|
|
|
|
def update_timeline_status(timeline_id: str, status: str, frame_count: int | None = None) -> None:
    """
    Update a timeline's status and, optionally, its frame count.

    Args:
        timeline_id: Timeline id as a UUID string.
        status: New status value to store on the timeline.
        frame_count: When not None, also stored on the timeline; None
            leaves the existing frame count untouched.

    A missing timeline is not treated as an error: nothing is written,
    but a warning is logged so the skipped update is visible.

    Raises:
        ValueError: if ``timeline_id`` is not a well-formed UUID string.
    """
    from core.db.models import Timeline
    from core.db.connection import get_session

    with get_session() as session:
        timeline = session.get(Timeline, UUID(timeline_id))
        if not timeline:
            # Stay best-effort (callers do not expect an exception here),
            # but do not let the skipped update go unnoticed.
            logger.warning(
                "Timeline not found, status not updated: %s (status=%s)",
                timeline_id,
                status,
            )
            return

        timeline.status = status
        if frame_count is not None:
            timeline.frame_count = frame_count
        session.commit()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Checkpoint
# ---------------------------------------------------------------------------
|
|
|
|
def save_checkpoint(
    timeline_id: str,
    stage_name: str,
    parent_checkpoint_id: str | None = None,
    config_overrides: dict | None = None,
    stats: dict | None = None,
    is_scenario: bool = False,
    scenario_label: str = "",
    job_id: str | None = None,
) -> str:
    """
    Store a checkpoint, a lightweight node in the checkpoint tree.

    Stage outputs are not stored here; they go in the StageOutput table
    separately. Missing ``config_overrides`` / ``stats`` are stored as
    empty dicts, and missing job / parent ids as None.

    Returns the id of the new checkpoint as a string.
    """
    from core.db.models import Checkpoint
    from core.db.connection import get_session

    with get_session() as db:
        # Convert the string ids up front; the optional ones stay None
        # when not supplied.
        timeline_uuid = UUID(timeline_id)
        job_uuid = UUID(job_id) if job_id else None
        parent_uuid = UUID(parent_checkpoint_id) if parent_checkpoint_id else None

        node = Checkpoint(
            timeline_id=timeline_uuid,
            job_id=job_uuid,
            parent_id=parent_uuid,
            stage_name=stage_name,
            config_overrides=config_overrides if config_overrides else {},
            stats=stats if stats else {},
            is_scenario=is_scenario,
            scenario_label=scenario_label,
        )
        db.add(node)
        db.commit()
        db.refresh(node)
        new_id = str(node.id)

    logger.info(
        "Checkpoint saved: %s (timeline %s, stage %s, parent %s)",
        new_id,
        timeline_id,
        stage_name,
        parent_checkpoint_id,
    )
    return new_id
|
|
|
|
|
|
def _checkpoint_to_dict(checkpoint) -> dict:
    """Render one Checkpoint row as a plain dict (ids and timestamp as strings)."""
    return {
        "id": str(checkpoint.id),
        "timeline_id": str(checkpoint.timeline_id),
        "job_id": str(checkpoint.job_id) if checkpoint.job_id else None,
        "parent_id": str(checkpoint.parent_id) if checkpoint.parent_id else None,
        "stage_name": checkpoint.stage_name,
        "config_overrides": checkpoint.config_overrides or {},
        "stats": checkpoint.stats or {},
        "is_scenario": checkpoint.is_scenario,
        "scenario_label": checkpoint.scenario_label,
        "created_at": str(checkpoint.created_at) if checkpoint.created_at else None,
    }


def get_checkpoints_for_job(job_id: str) -> list[dict]:
    """
    List checkpoints for a job, ordered by creation time.

    Args:
        job_id: Job id as a UUID string.

    Returns:
        One dict per checkpoint, as produced by ``_checkpoint_to_dict``;
        an empty list when the job has no checkpoints.

    Raises:
        ValueError: if ``job_id`` is not a well-formed UUID string.
    """
    from sqlmodel import select
    from core.db.models import Checkpoint
    from core.db.connection import get_session

    with get_session() as session:
        stmt = (
            select(Checkpoint)
            .where(Checkpoint.job_id == UUID(job_id))
            .order_by(Checkpoint.created_at)
        )
        # Serialize while the session is still open so attribute access
        # never hits a detached instance.
        return [_checkpoint_to_dict(c) for c in session.exec(stmt).all()]


def get_checkpoints_for_timeline(timeline_id: str) -> list[dict]:
    """
    List all checkpoints on a timeline, ordered by creation time.

    Args:
        timeline_id: Timeline id as a UUID string.

    Returns:
        One dict per checkpoint, as produced by ``_checkpoint_to_dict``;
        an empty list when the timeline has no checkpoints.

    Raises:
        ValueError: if ``timeline_id`` is not a well-formed UUID string.
    """
    from sqlmodel import select
    from core.db.models import Checkpoint
    from core.db.connection import get_session

    with get_session() as session:
        stmt = (
            select(Checkpoint)
            .where(Checkpoint.timeline_id == UUID(timeline_id))
            .order_by(Checkpoint.created_at)
        )
        # Serialize while the session is still open so attribute access
        # never hits a detached instance.
        return [_checkpoint_to_dict(c) for c in session.exec(stmt).all()]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# StageOutput
# ---------------------------------------------------------------------------
|
|
|
|
def save_stage_output(
    job_id: str,
    timeline_id: str,
    stage_name: str,
    output: dict,
    checkpoint_id: str | None = None,
) -> str:
    """
    Insert or update the output of one stage of a job.

    There is one row per (job_id, stage_name). When a row already exists
    its output and checkpoint id are overwritten; otherwise a new row is
    created.

    Returns the stage_output ID as a string.
    """
    from sqlmodel import select
    from core.db.models import StageOutput
    from core.db.connection import get_session

    with get_session() as db:
        # Look for an existing row for this job + stage.
        lookup = (
            select(StageOutput)
            .where(StageOutput.job_id == UUID(job_id))
            .where(StageOutput.stage_name == stage_name)
        )
        row = db.exec(lookup).first()

        if not row:
            # Nothing stored yet for this job/stage: insert a fresh row.
            row = StageOutput(
                job_id=UUID(job_id),
                timeline_id=UUID(timeline_id),
                stage_name=stage_name,
                checkpoint_id=UUID(checkpoint_id) if checkpoint_id else None,
                output=output,
            )
            db.add(row)
        else:
            # Row exists: overwrite its payload and checkpoint link.
            row.output = output
            row.checkpoint_id = UUID(checkpoint_id) if checkpoint_id else None

        db.commit()
        db.refresh(row)
        return str(row.id)
|
|
|
|
|
|
def load_stage_output(job_id: str, stage_name: str) -> dict | None:
    """
    Return the stored output for one stage of a job.

    Returns None when no row matches the job id and stage name.
    """
    from sqlmodel import select
    from core.db.models import StageOutput
    from core.db.connection import get_session

    with get_session() as db:
        query = select(StageOutput).where(
            StageOutput.job_id == UUID(job_id),
            StageOutput.stage_name == stage_name,
        )
        match = db.exec(query).first()
        return match.output if match else None
|
|
|
|
|
|
def load_stage_outputs_for_job(job_id: str) -> dict[str, dict]:
    """
    Collect every stage output stored for a job.

    Returns a mapping of {stage_name: output}; empty when the job has no
    stored outputs.
    """
    from sqlmodel import select
    from core.db.models import StageOutput
    from core.db.connection import get_session

    outputs_by_stage: dict[str, dict] = {}
    with get_session() as db:
        query = select(StageOutput).where(StageOutput.job_id == UUID(job_id))
        for record in db.exec(query).all():
            outputs_by_stage[record.stage_name] = record.output
    return outputs_by_stage
|
|
|
|
|
|
def load_stage_outputs_for_timeline(timeline_id: str, stage_name: str | None = None) -> list[dict]:
    """
    Load stage outputs for a timeline, optionally filtered by stage.

    Args:
        timeline_id: Timeline id as a UUID string.
        stage_name: When truthy, only outputs of this stage are returned.
            None or an empty string means no stage filter.

    Returns:
        One dict per stage output with keys ``id``, ``job_id``,
        ``stage_name``, ``checkpoint_id``, ``output`` and ``created_at``,
        ordered by creation time. Empty list when nothing matches.

    Raises:
        ValueError: if ``timeline_id`` is not a well-formed UUID string.
    """
    from sqlmodel import select
    from core.db.models import StageOutput
    from core.db.connection import get_session

    with get_session() as session:
        stmt = select(StageOutput).where(StageOutput.timeline_id == UUID(timeline_id))
        if stage_name:
            stmt = stmt.where(StageOutput.stage_name == stage_name)
        # Without an ORDER BY the row order is unspecified; order by
        # creation time, as the checkpoint listings in this module do.
        stmt = stmt.order_by(StageOutput.created_at)

        # Serialize while the session is still open.
        return [
            {
                "id": str(r.id),
                "job_id": str(r.job_id),
                "stage_name": r.stage_name,
                "checkpoint_id": str(r.checkpoint_id) if r.checkpoint_id else None,
                "output": r.output,
                "created_at": str(r.created_at) if r.created_at else None,
            }
            for r in session.exec(stmt).all()
        ]
|