a
This commit is contained in:
@@ -1,19 +1,31 @@
|
||||
"""
|
||||
Checkpoint system — Timeline + Checkpoint tree.
|
||||
Checkpoint system — Timeline + Checkpoint tree + StageOutput.
|
||||
|
||||
detect/checkpoint/
|
||||
frames.py — frame image S3 upload/download
|
||||
storage.py — Timeline + Checkpoint (Postgres + MinIO)
|
||||
replay.py — replay (TODO: migrate to new model)
|
||||
frames.py — per-timeline frame cache (local filesystem)
|
||||
storage.py — Timeline, Checkpoint, StageOutput persistence
|
||||
replay.py — replay from checkpoint (TODO: rework in 5d)
|
||||
runner_bridge.py — checkpoint hook for PipelineRunner
|
||||
"""
|
||||
|
||||
from .storage import (
|
||||
create_timeline,
|
||||
get_timeline_frames,
|
||||
get_timeline_frames_b64,
|
||||
get_timeline,
|
||||
update_timeline_status,
|
||||
save_checkpoint,
|
||||
get_checkpoints_for_job,
|
||||
get_checkpoints_for_timeline,
|
||||
save_stage_output,
|
||||
load_stage_output,
|
||||
load_stage_outputs_for_job,
|
||||
load_stage_outputs_for_timeline,
|
||||
)
|
||||
from .frames import save_frames, load_frames
|
||||
from .runner_bridge import checkpoint_after_stage, reset_checkpoint_state, get_timeline_id
|
||||
from .frames import (
|
||||
cache_exists,
|
||||
cache_frames,
|
||||
load_cached_frames,
|
||||
load_cached_frames_b64,
|
||||
clear_cache,
|
||||
frames_to_b64,
|
||||
)
|
||||
from .runner_bridge import checkpoint_after_stage, reset_checkpoint_state, get_latest_checkpoint
|
||||
|
||||
@@ -1,7 +1,19 @@
|
||||
"""Frame image storage — save/load to S3/MinIO as JPEGs."""
|
||||
"""
|
||||
Frame cache — per-timeline frame storage in blob storage (S3/MinIO).
|
||||
|
||||
Frames are extracted from chunks once, cached as JPEGs at
|
||||
cache/timelines/{timeline_id}/frames/{seq}.jpg in the app's
|
||||
blob storage. Any job on the timeline reads from the cache.
|
||||
Cache is clearable and rebuildable from chunks.
|
||||
|
||||
Uses the same storage backend as the rest of the app, so it
|
||||
works across lambdas, GPU boxes, and local dev.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
@@ -14,25 +26,39 @@ from core.detect.models import Frame
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BUCKET = os.environ.get("S3_BUCKET", "mpr")
|
||||
CHECKPOINT_PREFIX = "checkpoints"
|
||||
CACHE_PREFIX = "cache/timelines"
|
||||
|
||||
|
||||
def save_frames(job_id: str, frames: list[Frame]) -> dict[int, str]:
|
||||
def _frame_key(timeline_id: str, seq: int) -> str:
|
||||
return f"{CACHE_PREFIX}/{timeline_id}/frames/{seq}.jpg"
|
||||
|
||||
|
||||
def _list_prefix(timeline_id: str) -> str:
|
||||
return f"{CACHE_PREFIX}/{timeline_id}/frames/"
|
||||
|
||||
|
||||
def cache_exists(timeline_id: str) -> bool:
|
||||
"""Check if frame cache exists for a timeline."""
|
||||
from core.storage.s3 import list_objects
|
||||
|
||||
objects = list_objects(BUCKET, _list_prefix(timeline_id))
|
||||
return len(objects) > 0
|
||||
|
||||
|
||||
def cache_frames(timeline_id: str, frames: list[Frame], quality: int = 85) -> int:
|
||||
"""
|
||||
Save frame images to S3 as JPEGs.
|
||||
Write frames to blob storage as JPEGs.
|
||||
|
||||
Returns manifest: {sequence: s3_key}
|
||||
Returns number of frames cached.
|
||||
"""
|
||||
from core.storage.s3 import upload_file
|
||||
|
||||
manifest = {}
|
||||
|
||||
for frame in frames:
|
||||
key = f"{CHECKPOINT_PREFIX}/{job_id}/frames/{frame.sequence}.jpg"
|
||||
key = _frame_key(timeline_id, frame.sequence)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
|
||||
img = Image.fromarray(frame.image)
|
||||
img.save(tmp, format="JPEG", quality=85)
|
||||
img.save(tmp, format="JPEG", quality=quality)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
@@ -40,25 +66,30 @@ def save_frames(job_id: str, frames: list[Frame]) -> dict[int, str]:
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
manifest[frame.sequence] = key
|
||||
|
||||
logger.info("Saved %d frames to s3://%s/%s/%s/frames/",
|
||||
len(frames), BUCKET, CHECKPOINT_PREFIX, job_id)
|
||||
return manifest
|
||||
logger.info("Cached %d frames for timeline %s", len(frames), timeline_id)
|
||||
return len(frames)
|
||||
|
||||
|
||||
def load_frames(manifest: dict[int, str], frame_metadata: list[dict]) -> list[Frame]:
|
||||
def load_cached_frames(timeline_id: str) -> list[Frame]:
|
||||
"""
|
||||
Load frame images from S3 and reconstitute Frame objects.
|
||||
Load all cached frames as Frame objects with numpy arrays.
|
||||
|
||||
frame_metadata: list of dicts with sequence, chunk_id, timestamp, perceptual_hash.
|
||||
Returns empty list if cache doesn't exist.
|
||||
"""
|
||||
from core.storage.s3 import download_to_temp
|
||||
from core.storage.s3 import list_objects, download_to_temp
|
||||
|
||||
objects = list_objects(BUCKET, _list_prefix(timeline_id))
|
||||
if not objects:
|
||||
return []
|
||||
|
||||
meta_map = {m["sequence"]: m for m in frame_metadata}
|
||||
frames = []
|
||||
for obj in objects:
|
||||
key = obj["key"]
|
||||
filename = key.rsplit("/", 1)[-1]
|
||||
if not filename.endswith(".jpg"):
|
||||
continue
|
||||
seq = int(filename.replace(".jpg", ""))
|
||||
|
||||
for seq, key in manifest.items():
|
||||
tmp_path = download_to_temp(BUCKET, key)
|
||||
try:
|
||||
img = Image.open(tmp_path).convert("RGB")
|
||||
@@ -66,13 +97,12 @@ def load_frames(manifest: dict[int, str], frame_metadata: list[dict]) -> list[Fr
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
meta = meta_map.get(seq, {})
|
||||
frame = Frame(
|
||||
sequence=seq,
|
||||
chunk_id=meta.get("chunk_id", 0),
|
||||
timestamp=meta.get("timestamp", 0.0),
|
||||
chunk_id=0,
|
||||
timestamp=0.0,
|
||||
image=image_array,
|
||||
perceptual_hash=meta.get("perceptual_hash", ""),
|
||||
perceptual_hash="",
|
||||
)
|
||||
frames.append(frame)
|
||||
|
||||
@@ -80,32 +110,70 @@ def load_frames(manifest: dict[int, str], frame_metadata: list[dict]) -> list[Fr
|
||||
return frames
|
||||
|
||||
|
||||
def load_frames_b64(manifest: dict[int, str], frame_metadata: list[dict]) -> list[dict]:
|
||||
def load_cached_frames_b64(timeline_id: str) -> list[dict]:
|
||||
"""
|
||||
Load frame images from S3 as base64 JPEG — lightweight, no numpy.
|
||||
Load cached frames as base64 JPEGs for the UI.
|
||||
|
||||
Returns list of dicts: {seq, timestamp, jpeg_b64}
|
||||
Returns list of {seq, timestamp, jpeg_b64}.
|
||||
"""
|
||||
import base64
|
||||
from core.storage.s3 import download_to_temp
|
||||
from core.storage.s3 import list_objects, download_to_temp
|
||||
|
||||
meta_map = {m["sequence"]: m for m in frame_metadata}
|
||||
frames = []
|
||||
objects = list_objects(BUCKET, _list_prefix(timeline_id))
|
||||
if not objects:
|
||||
return []
|
||||
|
||||
result = []
|
||||
for obj in objects:
|
||||
key = obj["key"]
|
||||
filename = key.rsplit("/", 1)[-1]
|
||||
if not filename.endswith(".jpg"):
|
||||
continue
|
||||
seq = int(filename.replace(".jpg", ""))
|
||||
|
||||
for seq, key in manifest.items():
|
||||
tmp_path = download_to_temp(BUCKET, key)
|
||||
try:
|
||||
with open(tmp_path, "rb") as f:
|
||||
jpeg_bytes = f.read()
|
||||
jpeg_b64 = base64.b64encode(f.read()).decode()
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
meta = meta_map.get(seq, {})
|
||||
frames.append({
|
||||
result.append({
|
||||
"seq": seq,
|
||||
"timestamp": meta.get("timestamp", 0.0),
|
||||
"jpeg_b64": base64.b64encode(jpeg_bytes).decode(),
|
||||
"timestamp": 0.0,
|
||||
"jpeg_b64": jpeg_b64,
|
||||
})
|
||||
|
||||
frames.sort(key=lambda f: f["seq"])
|
||||
return frames
|
||||
result.sort(key=lambda f: f["seq"])
|
||||
return result
|
||||
|
||||
|
||||
def clear_cache(timeline_id: str):
|
||||
"""Delete the frame cache for a timeline."""
|
||||
from core.storage.s3 import delete_objects
|
||||
|
||||
prefix = _list_prefix(timeline_id)
|
||||
delete_objects(BUCKET, prefix)
|
||||
logger.info("Cleared frame cache for timeline %s", timeline_id)
|
||||
|
||||
|
||||
def frames_to_b64(frames: list[Frame], quality: int = 75) -> list[dict]:
|
||||
"""
|
||||
Convert in-memory Frame objects to base64 JPEG dicts.
|
||||
|
||||
For API responses when frames are already in memory.
|
||||
"""
|
||||
result = []
|
||||
for frame in frames:
|
||||
buf = io.BytesIO()
|
||||
img = Image.fromarray(frame.image)
|
||||
img.save(buf, format="JPEG", quality=quality)
|
||||
jpeg_b64 = base64.b64encode(buf.getvalue()).decode()
|
||||
|
||||
result.append({
|
||||
"seq": frame.sequence,
|
||||
"timestamp": frame.timestamp,
|
||||
"jpeg_b64": jpeg_b64,
|
||||
})
|
||||
|
||||
result.sort(key=lambda f: f["seq"])
|
||||
return result
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
"""
|
||||
Runner bridge — checkpoint hook called by PipelineRunner after each stage.
|
||||
|
||||
Owns the per-job state (timeline, frame manifest, checkpoint chain) that
|
||||
the runner shouldn't know about.
|
||||
|
||||
Timeline and Job are independent entities:
|
||||
- One Timeline can serve multiple Jobs (re-run with different params)
|
||||
- One Job operates on one Timeline (set after frame extraction)
|
||||
- Checkpoints belong to Timeline, tagged with the Job that created them
|
||||
Saves a checkpoint + stage output after each stage completes.
|
||||
Timeline and Job are independent: timeline_id and job_id come from
|
||||
the pipeline state (set at job creation time).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -16,63 +12,37 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Per-job state
|
||||
_timeline_id: dict[str, str] = {}
|
||||
_frames_manifest: dict[str, dict[int, str]] = {}
|
||||
# Per-job state: tracks the latest checkpoint so we can chain parent → child
|
||||
_latest_checkpoint: dict[str, str] = {}
|
||||
|
||||
|
||||
def reset_checkpoint_state(job_id: str):
|
||||
"""Clean up per-job checkpoint state. Called when pipeline finishes."""
|
||||
_timeline_id.pop(job_id, None)
|
||||
_frames_manifest.pop(job_id, None)
|
||||
_latest_checkpoint.pop(job_id, None)
|
||||
|
||||
|
||||
def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: dict):
|
||||
"""
|
||||
Save a checkpoint after a stage completes.
|
||||
Save a checkpoint + stage output after a stage completes.
|
||||
|
||||
Called by the runner. Handles:
|
||||
- Timeline creation (once, on extract_frames)
|
||||
- Frame upload (via create_timeline)
|
||||
- Stage output serialization (via stage registry)
|
||||
- Checkpoint chain (parent → child)
|
||||
- Stage output as separate row in StageOutput table
|
||||
"""
|
||||
if not job_id:
|
||||
return
|
||||
|
||||
from .storage import create_timeline, save_stage_output
|
||||
timeline_id = state.get("timeline_id", "")
|
||||
if not timeline_id:
|
||||
logger.warning("No timeline_id in state for job %s, skipping checkpoint", job_id)
|
||||
return
|
||||
|
||||
from .storage import save_checkpoint, save_stage_output
|
||||
from core.detect.stages.base import _REGISTRY
|
||||
|
||||
merged = {**state, **result}
|
||||
|
||||
# On extract_frames: create Timeline + upload frames + root checkpoint
|
||||
if stage_name == "extract_frames" and job_id not in _timeline_id:
|
||||
frames = merged.get("frames", [])
|
||||
video_path = merged.get("video_path", "")
|
||||
profile_name = merged.get("profile_name", "")
|
||||
|
||||
tid, cid = create_timeline(
|
||||
source_video=video_path,
|
||||
profile_name=profile_name,
|
||||
frames=frames,
|
||||
)
|
||||
_timeline_id[job_id] = tid
|
||||
_latest_checkpoint[job_id] = cid
|
||||
logger.info("Job %s → Timeline %s (root checkpoint %s)", job_id, tid, cid)
|
||||
|
||||
# Emit timeline_id via SSE so the UI can use it for checkpoint loads
|
||||
from core.detect import emit
|
||||
emit.log(job_id, "Checkpoint", "INFO", f"timeline_id={tid}")
|
||||
return
|
||||
|
||||
# For subsequent stages: save checkpoint on the timeline
|
||||
tid = _timeline_id.get(job_id)
|
||||
if not tid:
|
||||
logger.warning("No timeline for job %s, skipping checkpoint", job_id)
|
||||
return
|
||||
|
||||
# Serialize stage output using the stage's serialize_fn if available
|
||||
stage_cls = _REGISTRY.get(stage_name)
|
||||
serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
|
||||
@@ -81,17 +51,41 @@ def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: di
|
||||
else:
|
||||
output_json = {}
|
||||
|
||||
# Convert stats dataclass to dict for JSONB storage
|
||||
import dataclasses
|
||||
raw_stats = state.get("stats", {})
|
||||
if dataclasses.is_dataclass(raw_stats):
|
||||
stats_dict = dataclasses.asdict(raw_stats)
|
||||
elif isinstance(raw_stats, dict):
|
||||
stats_dict = raw_stats
|
||||
else:
|
||||
stats_dict = {}
|
||||
|
||||
# Save checkpoint (lightweight tree node)
|
||||
parent_id = _latest_checkpoint.get(job_id)
|
||||
new_checkpoint_id = save_stage_output(
|
||||
timeline_id=tid,
|
||||
parent_checkpoint_id=parent_id,
|
||||
checkpoint_id = save_checkpoint(
|
||||
timeline_id=timeline_id,
|
||||
stage_name=stage_name,
|
||||
output_json=output_json,
|
||||
parent_checkpoint_id=parent_id,
|
||||
config_overrides=state.get("config_overrides"),
|
||||
stats=stats_dict,
|
||||
job_id=job_id,
|
||||
)
|
||||
_latest_checkpoint[job_id] = new_checkpoint_id
|
||||
_latest_checkpoint[job_id] = checkpoint_id
|
||||
|
||||
# Save stage output (separate row, upsert by job+stage)
|
||||
if output_json:
|
||||
save_stage_output(
|
||||
job_id=job_id,
|
||||
timeline_id=timeline_id,
|
||||
stage_name=stage_name,
|
||||
output=output_json,
|
||||
checkpoint_id=checkpoint_id,
|
||||
)
|
||||
|
||||
logger.info("Checkpoint %s + output for stage %s (job %s)", checkpoint_id, stage_name, job_id)
|
||||
|
||||
|
||||
def get_timeline_id(job_id: str) -> str | None:
|
||||
"""Get the timeline_id for a running job. Used by the UI to load checkpoints."""
|
||||
return _timeline_id.get(job_id)
|
||||
def get_latest_checkpoint(job_id: str) -> str | None:
|
||||
"""Get the latest checkpoint_id for a running job."""
|
||||
return _latest_checkpoint.get(job_id)
|
||||
|
||||
@@ -6,6 +6,9 @@ This file has no model-specific knowledge — stages own their data format.
|
||||
|
||||
The only things serialized here are the "envelope" fields (job_id, video_path, etc.)
|
||||
that don't belong to any stage.
|
||||
|
||||
Frames are ephemeral (in-memory during a run). Serialization stores
|
||||
metadata only; frames are re-extracted from chunks when needed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -18,10 +21,10 @@ from core.schema.serializers.pipeline import (
|
||||
|
||||
|
||||
# Envelope fields — not owned by any stage, always present
|
||||
ENVELOPE_KEYS = ["job_id", "video_path", "profile_name", "config_overrides"]
|
||||
ENVELOPE_KEYS = ["job_id", "video_path", "profile_name", "timeline_id", "config_overrides"]
|
||||
|
||||
|
||||
def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
|
||||
def serialize_state(state: dict) -> dict:
|
||||
"""
|
||||
Serialize DetectState to a JSON-compatible dict.
|
||||
|
||||
@@ -37,9 +40,6 @@ def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
|
||||
default = {} if key == "config_overrides" else ""
|
||||
checkpoint[key] = state.get(key, default)
|
||||
|
||||
# Frames manifest (needed by frame-loading stages)
|
||||
checkpoint["frames_manifest"] = {str(k): v for k, v in frames_manifest.items()}
|
||||
|
||||
# Stats (shared across stages, not owned by one)
|
||||
stats = state.get("stats")
|
||||
if stats is not None:
|
||||
@@ -60,8 +60,9 @@ def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
|
||||
|
||||
def deserialize_state(checkpoint: dict, frames: list) -> dict:
|
||||
"""
|
||||
Reconstitute DetectState from a checkpoint dict + loaded frames.
|
||||
Reconstitute DetectState from a checkpoint dict + frames.
|
||||
|
||||
Frames are provided by the caller (re-extracted from chunks).
|
||||
Calls each stage's deserialize_fn to restore stage-owned data.
|
||||
"""
|
||||
from core.detect.stages.base import _REGISTRY
|
||||
@@ -75,7 +76,7 @@ def deserialize_state(checkpoint: dict, frames: list) -> dict:
|
||||
default = {} if key == "config_overrides" else ""
|
||||
state[key] = checkpoint.get(key, default)
|
||||
|
||||
# Frames (always present, loaded externally)
|
||||
# Frames (provided externally, ephemeral)
|
||||
state["frames"] = frames
|
||||
|
||||
# Stats
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""
|
||||
Checkpoint storage — Timeline + Checkpoint (tree of snapshots).
|
||||
Checkpoint storage — Timeline, Checkpoint, StageOutput persistence.
|
||||
|
||||
Timeline: frame sequence from source video (frames in MinIO)
|
||||
Checkpoint: snapshot of pipeline state (stage outputs as JSONB in Postgres)
|
||||
parent_id forms a tree — multiple children = different config tries
|
||||
Timeline: user-created source selection (chunk paths)
|
||||
Checkpoint: lightweight tree node (parent_id, stage_name, config, stats)
|
||||
StageOutput: per-stage result (flat table, one row per job+stage)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -11,8 +11,6 @@ from __future__ import annotations
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from .frames import save_frames, load_frames, CHECKPOINT_PREFIX
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -21,66 +19,41 @@ logger = logging.getLogger(__name__)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def create_timeline(
|
||||
source_video: str,
|
||||
profile_name: str,
|
||||
frames: list,
|
||||
fps: float = 2.0,
|
||||
chunk_paths: list[str],
|
||||
profile_name: str = "",
|
||||
name: str = "",
|
||||
source_asset_id: UUID | None = None,
|
||||
) -> tuple[str, str]:
|
||||
fps: float = 2.0,
|
||||
) -> str:
|
||||
"""
|
||||
Create a timeline from frames. Uploads frame images to MinIO,
|
||||
creates Timeline + root Checkpoint in Postgres.
|
||||
Create a timeline from a chunk selection.
|
||||
|
||||
Returns (timeline_id, checkpoint_id).
|
||||
Called by the user (via API) before any pipeline runs.
|
||||
Returns timeline_id.
|
||||
"""
|
||||
from core.db.models import Timeline, Checkpoint
|
||||
from core.db.models import Timeline
|
||||
from core.db.connection import get_session
|
||||
|
||||
with get_session() as session:
|
||||
timeline = Timeline(
|
||||
source_video=source_video,
|
||||
name=name,
|
||||
chunk_paths=chunk_paths,
|
||||
profile_name=profile_name,
|
||||
source_asset_id=source_asset_id,
|
||||
fps=fps,
|
||||
status="created",
|
||||
)
|
||||
session.add(timeline)
|
||||
session.flush()
|
||||
session.commit()
|
||||
session.refresh(timeline)
|
||||
tid = str(timeline.id)
|
||||
|
||||
# Upload frames to MinIO
|
||||
manifest = save_frames(tid, frames)
|
||||
|
||||
frames_meta = [
|
||||
{
|
||||
"sequence": f.sequence,
|
||||
"chunk_id": getattr(f, "chunk_id", 0),
|
||||
"timestamp": f.timestamp,
|
||||
"perceptual_hash": getattr(f, "perceptual_hash", ""),
|
||||
}
|
||||
for f in frames
|
||||
]
|
||||
|
||||
timeline.frames_prefix = f"{CHECKPOINT_PREFIX}/{tid}/frames/"
|
||||
timeline.frames_manifest = {str(k): v for k, v in manifest.items()}
|
||||
timeline.frames_meta = frames_meta
|
||||
|
||||
checkpoint = Checkpoint(
|
||||
timeline_id=timeline.id,
|
||||
parent_id=None,
|
||||
stage_outputs={},
|
||||
stats={"frames_extracted": len(frames)},
|
||||
)
|
||||
session.add(checkpoint)
|
||||
session.commit()
|
||||
session.refresh(checkpoint)
|
||||
cid = str(checkpoint.id)
|
||||
|
||||
logger.info("Timeline created: %s (%d frames, root checkpoint %s)", tid, len(frames), cid)
|
||||
return tid, cid
|
||||
logger.info("Timeline created: %s (%d chunks)", tid, len(chunk_paths))
|
||||
return tid
|
||||
|
||||
|
||||
def get_timeline_frames(timeline_id: str) -> list:
|
||||
"""Load frames from a timeline (from MinIO) as Frame objects."""
|
||||
def get_timeline(timeline_id: str) -> dict:
|
||||
"""Load a timeline as a dict."""
|
||||
from core.db.models import Timeline
|
||||
from core.db.connection import get_session
|
||||
|
||||
@@ -89,36 +62,40 @@ def get_timeline_frames(timeline_id: str) -> list:
|
||||
if not timeline:
|
||||
raise ValueError(f"Timeline not found: {timeline_id}")
|
||||
|
||||
raw_manifest = timeline.frames_manifest or {}
|
||||
manifest = {int(k): v for k, v in raw_manifest.items()}
|
||||
return load_frames(manifest, timeline.frames_meta or [])
|
||||
return {
|
||||
"id": str(timeline.id),
|
||||
"name": timeline.name,
|
||||
"chunk_paths": timeline.chunk_paths,
|
||||
"profile_name": timeline.profile_name,
|
||||
"status": timeline.status,
|
||||
"fps": timeline.fps,
|
||||
"source_asset_id": str(timeline.source_asset_id) if timeline.source_asset_id else None,
|
||||
"created_at": str(timeline.created_at) if timeline.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
def get_timeline_frames_b64(timeline_id: str) -> list[dict]:
|
||||
"""Load frames as base64 JPEG (lightweight, no numpy)."""
|
||||
def update_timeline_status(timeline_id: str, status: str, frame_count: int | None = None):
|
||||
"""Update timeline status and optionally frame count."""
|
||||
from core.db.models import Timeline
|
||||
from core.db.connection import get_session
|
||||
from .frames import load_frames_b64
|
||||
|
||||
with get_session() as session:
|
||||
timeline = session.get(Timeline, UUID(timeline_id))
|
||||
if not timeline:
|
||||
raise ValueError(f"Timeline not found: {timeline_id}")
|
||||
|
||||
raw_manifest = timeline.frames_manifest or {}
|
||||
manifest = {int(k): v for k, v in raw_manifest.items()}
|
||||
return load_frames_b64(manifest, timeline.frames_meta or [])
|
||||
if timeline:
|
||||
timeline.status = status
|
||||
if frame_count is not None:
|
||||
timeline.frame_count = frame_count
|
||||
session.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Checkpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def save_stage_output(
|
||||
def save_checkpoint(
|
||||
timeline_id: str,
|
||||
parent_checkpoint_id: str | None,
|
||||
stage_name: str,
|
||||
output_json: dict,
|
||||
parent_checkpoint_id: str | None = None,
|
||||
config_overrides: dict | None = None,
|
||||
stats: dict | None = None,
|
||||
is_scenario: bool = False,
|
||||
@@ -126,32 +103,22 @@ def save_stage_output(
|
||||
job_id: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Save a stage's output as a new checkpoint (child of parent).
|
||||
Save a checkpoint (lightweight tree node).
|
||||
|
||||
Carries forward stage outputs from parent + adds the new one.
|
||||
No stage outputs — those go in StageOutput table separately.
|
||||
Returns the new checkpoint ID.
|
||||
"""
|
||||
from core.db.models import Checkpoint
|
||||
from core.db.connection import get_session
|
||||
|
||||
with get_session() as session:
|
||||
parent_outputs = {}
|
||||
parent_stats = {}
|
||||
parent_config = {}
|
||||
if parent_checkpoint_id:
|
||||
parent = session.get(Checkpoint, UUID(parent_checkpoint_id))
|
||||
if parent:
|
||||
parent_outputs = dict(parent.stage_outputs or {})
|
||||
parent_stats = dict(parent.stats or {})
|
||||
parent_config = dict(parent.config_overrides or {})
|
||||
|
||||
checkpoint = Checkpoint(
|
||||
timeline_id=UUID(timeline_id),
|
||||
job_id=UUID(job_id) if job_id else None,
|
||||
parent_id=UUID(parent_checkpoint_id) if parent_checkpoint_id else None,
|
||||
stage_outputs={**parent_outputs, stage_name: output_json},
|
||||
config_overrides={**parent_config, **(config_overrides or {})},
|
||||
stats={**parent_stats, **(stats or {})},
|
||||
stage_name=stage_name,
|
||||
config_overrides=config_overrides or {},
|
||||
stats=stats or {},
|
||||
is_scenario=is_scenario,
|
||||
scenario_label=scenario_label,
|
||||
)
|
||||
@@ -165,13 +132,172 @@ def save_stage_output(
|
||||
return cid
|
||||
|
||||
|
||||
def load_stage_output(checkpoint_id: str, stage_name: str) -> dict | None:
|
||||
"""Load a stage's output from a checkpoint."""
|
||||
def get_checkpoints_for_job(job_id: str) -> list[dict]:
|
||||
"""List checkpoints for a job, ordered by creation time."""
|
||||
from sqlmodel import select
|
||||
from core.db.models import Checkpoint
|
||||
from core.db.connection import get_session
|
||||
|
||||
with get_session() as session:
|
||||
checkpoint = session.get(Checkpoint, UUID(checkpoint_id))
|
||||
if not checkpoint:
|
||||
stmt = (
|
||||
select(Checkpoint)
|
||||
.where(Checkpoint.job_id == UUID(job_id))
|
||||
.order_by(Checkpoint.created_at)
|
||||
)
|
||||
checkpoints = session.exec(stmt).all()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(c.id),
|
||||
"timeline_id": str(c.timeline_id),
|
||||
"job_id": str(c.job_id) if c.job_id else None,
|
||||
"parent_id": str(c.parent_id) if c.parent_id else None,
|
||||
"stage_name": c.stage_name,
|
||||
"config_overrides": c.config_overrides or {},
|
||||
"stats": c.stats or {},
|
||||
"is_scenario": c.is_scenario,
|
||||
"scenario_label": c.scenario_label,
|
||||
"created_at": str(c.created_at) if c.created_at else None,
|
||||
}
|
||||
for c in checkpoints
|
||||
]
|
||||
|
||||
|
||||
def get_checkpoints_for_timeline(timeline_id: str) -> list[dict]:
|
||||
"""List all checkpoints on a timeline, ordered by creation time."""
|
||||
from sqlmodel import select
|
||||
from core.db.models import Checkpoint
|
||||
from core.db.connection import get_session
|
||||
|
||||
with get_session() as session:
|
||||
stmt = (
|
||||
select(Checkpoint)
|
||||
.where(Checkpoint.timeline_id == UUID(timeline_id))
|
||||
.order_by(Checkpoint.created_at)
|
||||
)
|
||||
checkpoints = session.exec(stmt).all()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(c.id),
|
||||
"timeline_id": str(c.timeline_id),
|
||||
"job_id": str(c.job_id) if c.job_id else None,
|
||||
"parent_id": str(c.parent_id) if c.parent_id else None,
|
||||
"stage_name": c.stage_name,
|
||||
"config_overrides": c.config_overrides or {},
|
||||
"stats": c.stats or {},
|
||||
"is_scenario": c.is_scenario,
|
||||
"scenario_label": c.scenario_label,
|
||||
"created_at": str(c.created_at) if c.created_at else None,
|
||||
}
|
||||
for c in checkpoints
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# StageOutput
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def save_stage_output(
|
||||
job_id: str,
|
||||
timeline_id: str,
|
||||
stage_name: str,
|
||||
output: dict,
|
||||
checkpoint_id: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Save (upsert) a stage output. One row per (job_id, stage_name).
|
||||
|
||||
Returns the stage_output ID.
|
||||
"""
|
||||
from sqlmodel import select
|
||||
from core.db.models import StageOutput
|
||||
from core.db.connection import get_session
|
||||
|
||||
with get_session() as session:
|
||||
# Upsert: check if exists
|
||||
stmt = (
|
||||
select(StageOutput)
|
||||
.where(StageOutput.job_id == UUID(job_id))
|
||||
.where(StageOutput.stage_name == stage_name)
|
||||
)
|
||||
existing = session.exec(stmt).first()
|
||||
|
||||
if existing:
|
||||
existing.output = output
|
||||
existing.checkpoint_id = UUID(checkpoint_id) if checkpoint_id else None
|
||||
session.commit()
|
||||
session.refresh(existing)
|
||||
return str(existing.id)
|
||||
|
||||
stage_output = StageOutput(
|
||||
job_id=UUID(job_id),
|
||||
timeline_id=UUID(timeline_id),
|
||||
stage_name=stage_name,
|
||||
checkpoint_id=UUID(checkpoint_id) if checkpoint_id else None,
|
||||
output=output,
|
||||
)
|
||||
session.add(stage_output)
|
||||
session.commit()
|
||||
session.refresh(stage_output)
|
||||
return str(stage_output.id)
|
||||
|
||||
|
||||
def load_stage_output(job_id: str, stage_name: str) -> dict | None:
|
||||
"""Load a stage's output by job + stage name."""
|
||||
from sqlmodel import select
|
||||
from core.db.models import StageOutput
|
||||
from core.db.connection import get_session
|
||||
|
||||
with get_session() as session:
|
||||
stmt = (
|
||||
select(StageOutput)
|
||||
.where(StageOutput.job_id == UUID(job_id))
|
||||
.where(StageOutput.stage_name == stage_name)
|
||||
)
|
||||
row = session.exec(stmt).first()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
return (checkpoint.stage_outputs or {}).get(stage_name)
|
||||
return row.output
|
||||
|
||||
|
||||
def load_stage_outputs_for_job(job_id: str) -> dict[str, dict]:
|
||||
"""Load all stage outputs for a job. Returns {stage_name: output}."""
|
||||
from sqlmodel import select
|
||||
from core.db.models import StageOutput
|
||||
from core.db.connection import get_session
|
||||
|
||||
with get_session() as session:
|
||||
stmt = (
|
||||
select(StageOutput)
|
||||
.where(StageOutput.job_id == UUID(job_id))
|
||||
)
|
||||
rows = session.exec(stmt).all()
|
||||
|
||||
return {row.stage_name: row.output for row in rows}
|
||||
|
||||
|
||||
def load_stage_outputs_for_timeline(timeline_id: str, stage_name: str | None = None) -> list[dict]:
|
||||
"""Load stage outputs for a timeline, optionally filtered by stage."""
|
||||
from sqlmodel import select
|
||||
from core.db.models import StageOutput
|
||||
from core.db.connection import get_session
|
||||
|
||||
with get_session() as session:
|
||||
stmt = select(StageOutput).where(StageOutput.timeline_id == UUID(timeline_id))
|
||||
if stage_name:
|
||||
stmt = stmt.where(StageOutput.stage_name == stage_name)
|
||||
rows = session.exec(stmt).all()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(r.id),
|
||||
"job_id": str(r.job_id),
|
||||
"stage_name": r.stage_name,
|
||||
"checkpoint_id": str(r.checkpoint_id) if r.checkpoint_id else None,
|
||||
"output": r.output,
|
||||
"created_at": str(r.created_at) if r.created_at else None,
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
|
||||
@@ -98,6 +98,15 @@ def node_extract_frames(state: DetectState) -> dict:
|
||||
frames = extract_frames(state["video_path"], config, job_id=job_id)
|
||||
span.set_output({"frames_extracted": len(frames)})
|
||||
|
||||
# Cache frames on the timeline for reuse across jobs and UI
|
||||
timeline_id = state.get("timeline_id")
|
||||
if timeline_id:
|
||||
from core.detect.checkpoint.frames import cache_frames, cache_exists
|
||||
if not cache_exists(timeline_id):
|
||||
cache_frames(timeline_id, frames)
|
||||
from core.detect.checkpoint.storage import update_timeline_status
|
||||
update_timeline_status(timeline_id, "cached", frame_count=len(frames))
|
||||
|
||||
_emit(state, "extract_frames", "done")
|
||||
return {"frames": frames, "stats": PipelineStats(frames_extracted=len(frames))}
|
||||
|
||||
|
||||
@@ -12,8 +12,7 @@ from core.schema.serializers._common import (
|
||||
)
|
||||
from core.schema.serializers.pipeline import (
|
||||
serialize_frame_meta,
|
||||
serialize_frames_with_upload as serialize_frames,
|
||||
deserialize_frames_with_download as deserialize_frames,
|
||||
serialize_frames_meta,
|
||||
serialize_text_candidate,
|
||||
serialize_text_candidates,
|
||||
deserialize_text_candidate,
|
||||
|
||||
@@ -2,18 +2,19 @@
|
||||
|
||||
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
|
||||
from core.detect.stages.base import register_stage
|
||||
from ._serializers import serialize_frames, deserialize_frames
|
||||
from ._serializers import serialize_frame_meta
|
||||
|
||||
|
||||
def _ser_extract(state: dict, job_id: str) -> dict:
|
||||
frames = state.get("frames", [])
|
||||
meta, manifest = serialize_frames(frames, job_id)
|
||||
return {"frames_meta": meta, "frames_manifest": manifest}
|
||||
meta = [serialize_frame_meta(f) for f in frames]
|
||||
return {"frames_meta": meta, "frame_count": len(frames)}
|
||||
|
||||
|
||||
def _deser_extract(data: dict, job_id: str) -> dict:
|
||||
frames = deserialize_frames(data["frames_meta"], data["frames_manifest"], job_id)
|
||||
return {"frames": frames}
|
||||
# Frames are ephemeral — re-extract from chunks on demand.
|
||||
# Store metadata so we know what was extracted.
|
||||
return {"_frames_meta": data.get("frames_meta", [])}
|
||||
|
||||
|
||||
def _ser_filter(state: dict, job_id: str) -> dict:
|
||||
|
||||
@@ -16,6 +16,7 @@ class DetectState(TypedDict, total=False):
|
||||
# Input
|
||||
video_path: str
|
||||
job_id: str
|
||||
timeline_id: str
|
||||
profile_name: str
|
||||
source_asset_id: str # UUID of the source MediaAsset
|
||||
|
||||
|
||||
Reference in New Issue
Block a user