This commit is contained in:
2026-03-30 09:53:10 -03:00
parent 4220b0418e
commit aac27b8504
32 changed files with 1068 additions and 329 deletions

View File

@@ -1,19 +1,31 @@
"""
Checkpoint system — Timeline + Checkpoint tree.
Checkpoint system — Timeline + Checkpoint tree + StageOutput.
detect/checkpoint/
frames.py — frame image S3 upload/download
storage.py — Timeline + Checkpoint (Postgres + MinIO)
replay.py — replay (TODO: migrate to new model)
frames.py — per-timeline frame cache (local filesystem)
storage.py — Timeline, Checkpoint, StageOutput persistence
replay.py — replay from checkpoint (TODO: rework in 5d)
runner_bridge.py — checkpoint hook for PipelineRunner
"""
from .storage import (
create_timeline,
get_timeline_frames,
get_timeline_frames_b64,
get_timeline,
update_timeline_status,
save_checkpoint,
get_checkpoints_for_job,
get_checkpoints_for_timeline,
save_stage_output,
load_stage_output,
load_stage_outputs_for_job,
load_stage_outputs_for_timeline,
)
from .frames import save_frames, load_frames
from .runner_bridge import checkpoint_after_stage, reset_checkpoint_state, get_timeline_id
from .frames import (
cache_exists,
cache_frames,
load_cached_frames,
load_cached_frames_b64,
clear_cache,
frames_to_b64,
)
from .runner_bridge import checkpoint_after_stage, reset_checkpoint_state, get_latest_checkpoint

View File

@@ -1,7 +1,19 @@
"""Frame image storage — save/load to S3/MinIO as JPEGs."""
"""
Frame cache — per-timeline frame storage in blob storage (S3/MinIO).
Frames are extracted from chunks once, cached as JPEGs at
cache/timelines/{timeline_id}/frames/{seq}.jpg in the app's
blob storage. Any job on the timeline reads from the cache.
Cache is clearable and rebuildable from chunks.
Uses the same storage backend as the rest of the app, so it
works across lambdas, GPU boxes, and local dev.
"""
from __future__ import annotations
import base64
import io
import logging
import os
import tempfile
@@ -14,25 +26,39 @@ from core.detect.models import Frame
logger = logging.getLogger(__name__)
BUCKET = os.environ.get("S3_BUCKET", "mpr")
CHECKPOINT_PREFIX = "checkpoints"
CACHE_PREFIX = "cache/timelines"
def save_frames(job_id: str, frames: list[Frame]) -> dict[int, str]:
def _frame_key(timeline_id: str, seq: int) -> str:
return f"{CACHE_PREFIX}/{timeline_id}/frames/{seq}.jpg"
def _list_prefix(timeline_id: str) -> str:
return f"{CACHE_PREFIX}/{timeline_id}/frames/"
def cache_exists(timeline_id: str) -> bool:
"""Check if frame cache exists for a timeline."""
from core.storage.s3 import list_objects
objects = list_objects(BUCKET, _list_prefix(timeline_id))
return len(objects) > 0
def cache_frames(timeline_id: str, frames: list[Frame], quality: int = 85) -> int:
"""
Save frame images to S3 as JPEGs.
Write frames to blob storage as JPEGs.
Returns manifest: {sequence: s3_key}
Returns number of frames cached.
"""
from core.storage.s3 import upload_file
manifest = {}
for frame in frames:
key = f"{CHECKPOINT_PREFIX}/{job_id}/frames/{frame.sequence}.jpg"
key = _frame_key(timeline_id, frame.sequence)
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
img = Image.fromarray(frame.image)
img.save(tmp, format="JPEG", quality=85)
img.save(tmp, format="JPEG", quality=quality)
tmp_path = tmp.name
try:
@@ -40,25 +66,30 @@ def save_frames(job_id: str, frames: list[Frame]) -> dict[int, str]:
finally:
os.unlink(tmp_path)
manifest[frame.sequence] = key
logger.info("Saved %d frames to s3://%s/%s/%s/frames/",
len(frames), BUCKET, CHECKPOINT_PREFIX, job_id)
return manifest
logger.info("Cached %d frames for timeline %s", len(frames), timeline_id)
return len(frames)
def load_frames(manifest: dict[int, str], frame_metadata: list[dict]) -> list[Frame]:
def load_cached_frames(timeline_id: str) -> list[Frame]:
"""
Load frame images from S3 and reconstitute Frame objects.
Load all cached frames as Frame objects with numpy arrays.
frame_metadata: list of dicts with sequence, chunk_id, timestamp, perceptual_hash.
Returns empty list if cache doesn't exist.
"""
from core.storage.s3 import download_to_temp
from core.storage.s3 import list_objects, download_to_temp
objects = list_objects(BUCKET, _list_prefix(timeline_id))
if not objects:
return []
meta_map = {m["sequence"]: m for m in frame_metadata}
frames = []
for obj in objects:
key = obj["key"]
filename = key.rsplit("/", 1)[-1]
if not filename.endswith(".jpg"):
continue
seq = int(filename.replace(".jpg", ""))
for seq, key in manifest.items():
tmp_path = download_to_temp(BUCKET, key)
try:
img = Image.open(tmp_path).convert("RGB")
@@ -66,13 +97,12 @@ def load_frames(manifest: dict[int, str], frame_metadata: list[dict]) -> list[Fr
finally:
os.unlink(tmp_path)
meta = meta_map.get(seq, {})
frame = Frame(
sequence=seq,
chunk_id=meta.get("chunk_id", 0),
timestamp=meta.get("timestamp", 0.0),
chunk_id=0,
timestamp=0.0,
image=image_array,
perceptual_hash=meta.get("perceptual_hash", ""),
perceptual_hash="",
)
frames.append(frame)
@@ -80,32 +110,70 @@ def load_frames(manifest: dict[int, str], frame_metadata: list[dict]) -> list[Fr
return frames
def load_frames_b64(manifest: dict[int, str], frame_metadata: list[dict]) -> list[dict]:
def load_cached_frames_b64(timeline_id: str) -> list[dict]:
"""
Load frame images from S3 as base64 JPEG — lightweight, no numpy.
Load cached frames as base64 JPEGs for the UI.
Returns list of dicts: {seq, timestamp, jpeg_b64}
Returns list of {seq, timestamp, jpeg_b64}.
"""
import base64
from core.storage.s3 import download_to_temp
from core.storage.s3 import list_objects, download_to_temp
meta_map = {m["sequence"]: m for m in frame_metadata}
frames = []
objects = list_objects(BUCKET, _list_prefix(timeline_id))
if not objects:
return []
result = []
for obj in objects:
key = obj["key"]
filename = key.rsplit("/", 1)[-1]
if not filename.endswith(".jpg"):
continue
seq = int(filename.replace(".jpg", ""))
for seq, key in manifest.items():
tmp_path = download_to_temp(BUCKET, key)
try:
with open(tmp_path, "rb") as f:
jpeg_bytes = f.read()
jpeg_b64 = base64.b64encode(f.read()).decode()
finally:
os.unlink(tmp_path)
meta = meta_map.get(seq, {})
frames.append({
result.append({
"seq": seq,
"timestamp": meta.get("timestamp", 0.0),
"jpeg_b64": base64.b64encode(jpeg_bytes).decode(),
"timestamp": 0.0,
"jpeg_b64": jpeg_b64,
})
frames.sort(key=lambda f: f["seq"])
return frames
result.sort(key=lambda f: f["seq"])
return result
def clear_cache(timeline_id: str):
"""Delete the frame cache for a timeline."""
from core.storage.s3 import delete_objects
prefix = _list_prefix(timeline_id)
delete_objects(BUCKET, prefix)
logger.info("Cleared frame cache for timeline %s", timeline_id)
def frames_to_b64(frames: list[Frame], quality: int = 75) -> list[dict]:
"""
Convert in-memory Frame objects to base64 JPEG dicts.
For API responses when frames are already in memory.
"""
result = []
for frame in frames:
buf = io.BytesIO()
img = Image.fromarray(frame.image)
img.save(buf, format="JPEG", quality=quality)
jpeg_b64 = base64.b64encode(buf.getvalue()).decode()
result.append({
"seq": frame.sequence,
"timestamp": frame.timestamp,
"jpeg_b64": jpeg_b64,
})
result.sort(key=lambda f: f["seq"])
return result

View File

@@ -1,13 +1,9 @@
"""
Runner bridge — checkpoint hook called by PipelineRunner after each stage.
Owns the per-job state (timeline, frame manifest, checkpoint chain) that
the runner shouldn't know about.
Timeline and Job are independent entities:
- One Timeline can serve multiple Jobs (re-run with different params)
- One Job operates on one Timeline (set after frame extraction)
- Checkpoints belong to Timeline, tagged with the Job that created them
Saves a checkpoint + stage output after each stage completes.
Timeline and Job are independent: timeline_id and job_id come from
the pipeline state (set at job creation time).
"""
from __future__ import annotations
@@ -16,63 +12,37 @@ import logging
logger = logging.getLogger(__name__)
# Per-job state
_timeline_id: dict[str, str] = {}
_frames_manifest: dict[str, dict[int, str]] = {}
# Per-job state: tracks the latest checkpoint so we can chain parent → child
_latest_checkpoint: dict[str, str] = {}
def reset_checkpoint_state(job_id: str):
"""Clean up per-job checkpoint state. Called when pipeline finishes."""
_timeline_id.pop(job_id, None)
_frames_manifest.pop(job_id, None)
_latest_checkpoint.pop(job_id, None)
def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: dict):
"""
Save a checkpoint after a stage completes.
Save a checkpoint + stage output after a stage completes.
Called by the runner. Handles:
- Timeline creation (once, on extract_frames)
- Frame upload (via create_timeline)
- Stage output serialization (via stage registry)
- Checkpoint chain (parent → child)
- Stage output as separate row in StageOutput table
"""
if not job_id:
return
from .storage import create_timeline, save_stage_output
timeline_id = state.get("timeline_id", "")
if not timeline_id:
logger.warning("No timeline_id in state for job %s, skipping checkpoint", job_id)
return
from .storage import save_checkpoint, save_stage_output
from core.detect.stages.base import _REGISTRY
merged = {**state, **result}
# On extract_frames: create Timeline + upload frames + root checkpoint
if stage_name == "extract_frames" and job_id not in _timeline_id:
frames = merged.get("frames", [])
video_path = merged.get("video_path", "")
profile_name = merged.get("profile_name", "")
tid, cid = create_timeline(
source_video=video_path,
profile_name=profile_name,
frames=frames,
)
_timeline_id[job_id] = tid
_latest_checkpoint[job_id] = cid
logger.info("Job %s → Timeline %s (root checkpoint %s)", job_id, tid, cid)
# Emit timeline_id via SSE so the UI can use it for checkpoint loads
from core.detect import emit
emit.log(job_id, "Checkpoint", "INFO", f"timeline_id={tid}")
return
# For subsequent stages: save checkpoint on the timeline
tid = _timeline_id.get(job_id)
if not tid:
logger.warning("No timeline for job %s, skipping checkpoint", job_id)
return
# Serialize stage output using the stage's serialize_fn if available
stage_cls = _REGISTRY.get(stage_name)
serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
@@ -81,17 +51,41 @@ def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: di
else:
output_json = {}
# Convert stats dataclass to dict for JSONB storage
import dataclasses
raw_stats = state.get("stats", {})
if dataclasses.is_dataclass(raw_stats):
stats_dict = dataclasses.asdict(raw_stats)
elif isinstance(raw_stats, dict):
stats_dict = raw_stats
else:
stats_dict = {}
# Save checkpoint (lightweight tree node)
parent_id = _latest_checkpoint.get(job_id)
new_checkpoint_id = save_stage_output(
timeline_id=tid,
parent_checkpoint_id=parent_id,
checkpoint_id = save_checkpoint(
timeline_id=timeline_id,
stage_name=stage_name,
output_json=output_json,
parent_checkpoint_id=parent_id,
config_overrides=state.get("config_overrides"),
stats=stats_dict,
job_id=job_id,
)
_latest_checkpoint[job_id] = new_checkpoint_id
_latest_checkpoint[job_id] = checkpoint_id
# Save stage output (separate row, upsert by job+stage)
if output_json:
save_stage_output(
job_id=job_id,
timeline_id=timeline_id,
stage_name=stage_name,
output=output_json,
checkpoint_id=checkpoint_id,
)
logger.info("Checkpoint %s + output for stage %s (job %s)", checkpoint_id, stage_name, job_id)
def get_timeline_id(job_id: str) -> str | None:
"""Get the timeline_id for a running job. Used by the UI to load checkpoints."""
return _timeline_id.get(job_id)
def get_latest_checkpoint(job_id: str) -> str | None:
"""Get the latest checkpoint_id for a running job."""
return _latest_checkpoint.get(job_id)

View File

@@ -6,6 +6,9 @@ This file has no model-specific knowledge — stages own their data format.
The only things serialized here are the "envelope" fields (job_id, video_path, etc.)
that don't belong to any stage.
Frames are ephemeral (in-memory during a run). Serialization stores
metadata only; frames are re-extracted from chunks when needed.
"""
from __future__ import annotations
@@ -18,10 +21,10 @@ from core.schema.serializers.pipeline import (
# Envelope fields — not owned by any stage, always present
ENVELOPE_KEYS = ["job_id", "video_path", "profile_name", "config_overrides"]
ENVELOPE_KEYS = ["job_id", "video_path", "profile_name", "timeline_id", "config_overrides"]
def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
def serialize_state(state: dict) -> dict:
"""
Serialize DetectState to a JSON-compatible dict.
@@ -37,9 +40,6 @@ def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
default = {} if key == "config_overrides" else ""
checkpoint[key] = state.get(key, default)
# Frames manifest (needed by frame-loading stages)
checkpoint["frames_manifest"] = {str(k): v for k, v in frames_manifest.items()}
# Stats (shared across stages, not owned by one)
stats = state.get("stats")
if stats is not None:
@@ -60,8 +60,9 @@ def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
def deserialize_state(checkpoint: dict, frames: list) -> dict:
"""
Reconstitute DetectState from a checkpoint dict + loaded frames.
Reconstitute DetectState from a checkpoint dict + frames.
Frames are provided by the caller (re-extracted from chunks).
Calls each stage's deserialize_fn to restore stage-owned data.
"""
from core.detect.stages.base import _REGISTRY
@@ -75,7 +76,7 @@ def deserialize_state(checkpoint: dict, frames: list) -> dict:
default = {} if key == "config_overrides" else ""
state[key] = checkpoint.get(key, default)
# Frames (always present, loaded externally)
# Frames (provided externally, ephemeral)
state["frames"] = frames
# Stats

View File

@@ -1,9 +1,9 @@
"""
Checkpoint storage — Timeline + Checkpoint (tree of snapshots).
Checkpoint storage — Timeline, Checkpoint, StageOutput persistence.
Timeline: frame sequence from source video (frames in MinIO)
Checkpoint: snapshot of pipeline state (stage outputs as JSONB in Postgres)
parent_id forms a tree — multiple children = different config tries
Timeline: user-created source selection (chunk paths)
Checkpoint: lightweight tree node (parent_id, stage_name, config, stats)
StageOutput: per-stage result (flat table, one row per job+stage)
"""
from __future__ import annotations
@@ -11,8 +11,6 @@ from __future__ import annotations
import logging
from uuid import UUID
from .frames import save_frames, load_frames, CHECKPOINT_PREFIX
logger = logging.getLogger(__name__)
@@ -21,66 +19,41 @@ logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
def create_timeline(
source_video: str,
profile_name: str,
frames: list,
fps: float = 2.0,
chunk_paths: list[str],
profile_name: str = "",
name: str = "",
source_asset_id: UUID | None = None,
) -> tuple[str, str]:
fps: float = 2.0,
) -> str:
"""
Create a timeline from frames. Uploads frame images to MinIO,
creates Timeline + root Checkpoint in Postgres.
Create a timeline from a chunk selection.
Returns (timeline_id, checkpoint_id).
Called by the user (via API) before any pipeline runs.
Returns timeline_id.
"""
from core.db.models import Timeline, Checkpoint
from core.db.models import Timeline
from core.db.connection import get_session
with get_session() as session:
timeline = Timeline(
source_video=source_video,
name=name,
chunk_paths=chunk_paths,
profile_name=profile_name,
source_asset_id=source_asset_id,
fps=fps,
status="created",
)
session.add(timeline)
session.flush()
session.commit()
session.refresh(timeline)
tid = str(timeline.id)
# Upload frames to MinIO
manifest = save_frames(tid, frames)
frames_meta = [
{
"sequence": f.sequence,
"chunk_id": getattr(f, "chunk_id", 0),
"timestamp": f.timestamp,
"perceptual_hash": getattr(f, "perceptual_hash", ""),
}
for f in frames
]
timeline.frames_prefix = f"{CHECKPOINT_PREFIX}/{tid}/frames/"
timeline.frames_manifest = {str(k): v for k, v in manifest.items()}
timeline.frames_meta = frames_meta
checkpoint = Checkpoint(
timeline_id=timeline.id,
parent_id=None,
stage_outputs={},
stats={"frames_extracted": len(frames)},
)
session.add(checkpoint)
session.commit()
session.refresh(checkpoint)
cid = str(checkpoint.id)
logger.info("Timeline created: %s (%d frames, root checkpoint %s)", tid, len(frames), cid)
return tid, cid
logger.info("Timeline created: %s (%d chunks)", tid, len(chunk_paths))
return tid
def get_timeline_frames(timeline_id: str) -> list:
"""Load frames from a timeline (from MinIO) as Frame objects."""
def get_timeline(timeline_id: str) -> dict:
"""Load a timeline as a dict."""
from core.db.models import Timeline
from core.db.connection import get_session
@@ -89,36 +62,40 @@ def get_timeline_frames(timeline_id: str) -> list:
if not timeline:
raise ValueError(f"Timeline not found: {timeline_id}")
raw_manifest = timeline.frames_manifest or {}
manifest = {int(k): v for k, v in raw_manifest.items()}
return load_frames(manifest, timeline.frames_meta or [])
return {
"id": str(timeline.id),
"name": timeline.name,
"chunk_paths": timeline.chunk_paths,
"profile_name": timeline.profile_name,
"status": timeline.status,
"fps": timeline.fps,
"source_asset_id": str(timeline.source_asset_id) if timeline.source_asset_id else None,
"created_at": str(timeline.created_at) if timeline.created_at else None,
}
def get_timeline_frames_b64(timeline_id: str) -> list[dict]:
"""Load frames as base64 JPEG (lightweight, no numpy)."""
def update_timeline_status(timeline_id: str, status: str, frame_count: int | None = None):
"""Update timeline status and optionally frame count."""
from core.db.models import Timeline
from core.db.connection import get_session
from .frames import load_frames_b64
with get_session() as session:
timeline = session.get(Timeline, UUID(timeline_id))
if not timeline:
raise ValueError(f"Timeline not found: {timeline_id}")
raw_manifest = timeline.frames_manifest or {}
manifest = {int(k): v for k, v in raw_manifest.items()}
return load_frames_b64(manifest, timeline.frames_meta or [])
if timeline:
timeline.status = status
if frame_count is not None:
timeline.frame_count = frame_count
session.commit()
# ---------------------------------------------------------------------------
# Checkpoint
# ---------------------------------------------------------------------------
def save_stage_output(
def save_checkpoint(
timeline_id: str,
parent_checkpoint_id: str | None,
stage_name: str,
output_json: dict,
parent_checkpoint_id: str | None = None,
config_overrides: dict | None = None,
stats: dict | None = None,
is_scenario: bool = False,
@@ -126,32 +103,22 @@ def save_stage_output(
job_id: str | None = None,
) -> str:
"""
Save a stage's output as a new checkpoint (child of parent).
Save a checkpoint (lightweight tree node).
Carries forward stage outputs from parent + adds the new one.
No stage outputs — those go in StageOutput table separately.
Returns the new checkpoint ID.
"""
from core.db.models import Checkpoint
from core.db.connection import get_session
with get_session() as session:
parent_outputs = {}
parent_stats = {}
parent_config = {}
if parent_checkpoint_id:
parent = session.get(Checkpoint, UUID(parent_checkpoint_id))
if parent:
parent_outputs = dict(parent.stage_outputs or {})
parent_stats = dict(parent.stats or {})
parent_config = dict(parent.config_overrides or {})
checkpoint = Checkpoint(
timeline_id=UUID(timeline_id),
job_id=UUID(job_id) if job_id else None,
parent_id=UUID(parent_checkpoint_id) if parent_checkpoint_id else None,
stage_outputs={**parent_outputs, stage_name: output_json},
config_overrides={**parent_config, **(config_overrides or {})},
stats={**parent_stats, **(stats or {})},
stage_name=stage_name,
config_overrides=config_overrides or {},
stats=stats or {},
is_scenario=is_scenario,
scenario_label=scenario_label,
)
@@ -165,13 +132,172 @@ def save_stage_output(
return cid
def load_stage_output(checkpoint_id: str, stage_name: str) -> dict | None:
"""Load a stage's output from a checkpoint."""
def get_checkpoints_for_job(job_id: str) -> list[dict]:
"""List checkpoints for a job, ordered by creation time."""
from sqlmodel import select
from core.db.models import Checkpoint
from core.db.connection import get_session
with get_session() as session:
checkpoint = session.get(Checkpoint, UUID(checkpoint_id))
if not checkpoint:
stmt = (
select(Checkpoint)
.where(Checkpoint.job_id == UUID(job_id))
.order_by(Checkpoint.created_at)
)
checkpoints = session.exec(stmt).all()
return [
{
"id": str(c.id),
"timeline_id": str(c.timeline_id),
"job_id": str(c.job_id) if c.job_id else None,
"parent_id": str(c.parent_id) if c.parent_id else None,
"stage_name": c.stage_name,
"config_overrides": c.config_overrides or {},
"stats": c.stats or {},
"is_scenario": c.is_scenario,
"scenario_label": c.scenario_label,
"created_at": str(c.created_at) if c.created_at else None,
}
for c in checkpoints
]
def get_checkpoints_for_timeline(timeline_id: str) -> list[dict]:
"""List all checkpoints on a timeline, ordered by creation time."""
from sqlmodel import select
from core.db.models import Checkpoint
from core.db.connection import get_session
with get_session() as session:
stmt = (
select(Checkpoint)
.where(Checkpoint.timeline_id == UUID(timeline_id))
.order_by(Checkpoint.created_at)
)
checkpoints = session.exec(stmt).all()
return [
{
"id": str(c.id),
"timeline_id": str(c.timeline_id),
"job_id": str(c.job_id) if c.job_id else None,
"parent_id": str(c.parent_id) if c.parent_id else None,
"stage_name": c.stage_name,
"config_overrides": c.config_overrides or {},
"stats": c.stats or {},
"is_scenario": c.is_scenario,
"scenario_label": c.scenario_label,
"created_at": str(c.created_at) if c.created_at else None,
}
for c in checkpoints
]
# ---------------------------------------------------------------------------
# StageOutput
# ---------------------------------------------------------------------------
def save_stage_output(
job_id: str,
timeline_id: str,
stage_name: str,
output: dict,
checkpoint_id: str | None = None,
) -> str:
"""
Save (upsert) a stage output. One row per (job_id, stage_name).
Returns the stage_output ID.
"""
from sqlmodel import select
from core.db.models import StageOutput
from core.db.connection import get_session
with get_session() as session:
# Upsert: check if exists
stmt = (
select(StageOutput)
.where(StageOutput.job_id == UUID(job_id))
.where(StageOutput.stage_name == stage_name)
)
existing = session.exec(stmt).first()
if existing:
existing.output = output
existing.checkpoint_id = UUID(checkpoint_id) if checkpoint_id else None
session.commit()
session.refresh(existing)
return str(existing.id)
stage_output = StageOutput(
job_id=UUID(job_id),
timeline_id=UUID(timeline_id),
stage_name=stage_name,
checkpoint_id=UUID(checkpoint_id) if checkpoint_id else None,
output=output,
)
session.add(stage_output)
session.commit()
session.refresh(stage_output)
return str(stage_output.id)
def load_stage_output(job_id: str, stage_name: str) -> dict | None:
"""Load a stage's output by job + stage name."""
from sqlmodel import select
from core.db.models import StageOutput
from core.db.connection import get_session
with get_session() as session:
stmt = (
select(StageOutput)
.where(StageOutput.job_id == UUID(job_id))
.where(StageOutput.stage_name == stage_name)
)
row = session.exec(stmt).first()
if not row:
return None
return (checkpoint.stage_outputs or {}).get(stage_name)
return row.output
def load_stage_outputs_for_job(job_id: str) -> dict[str, dict]:
"""Load all stage outputs for a job. Returns {stage_name: output}."""
from sqlmodel import select
from core.db.models import StageOutput
from core.db.connection import get_session
with get_session() as session:
stmt = (
select(StageOutput)
.where(StageOutput.job_id == UUID(job_id))
)
rows = session.exec(stmt).all()
return {row.stage_name: row.output for row in rows}
def load_stage_outputs_for_timeline(timeline_id: str, stage_name: str | None = None) -> list[dict]:
"""Load stage outputs for a timeline, optionally filtered by stage."""
from sqlmodel import select
from core.db.models import StageOutput
from core.db.connection import get_session
with get_session() as session:
stmt = select(StageOutput).where(StageOutput.timeline_id == UUID(timeline_id))
if stage_name:
stmt = stmt.where(StageOutput.stage_name == stage_name)
rows = session.exec(stmt).all()
return [
{
"id": str(r.id),
"job_id": str(r.job_id),
"stage_name": r.stage_name,
"checkpoint_id": str(r.checkpoint_id) if r.checkpoint_id else None,
"output": r.output,
"created_at": str(r.created_at) if r.created_at else None,
}
for r in rows
]