phase 4
core/detect/checkpoint/__init__.py (Normal file, 19 lines)
@@ -0,0 +1,19 @@
"""
Checkpoint system — Timeline + Checkpoint tree.

detect/checkpoint/
    frames.py        — frame image S3 upload/download
    storage.py       — Timeline + Checkpoint (Postgres + MinIO)
    replay.py        — replay (TODO: migrate to new model)
    runner_bridge.py — checkpoint hook for PipelineRunner
"""

from .storage import (
    create_timeline,
    get_timeline_frames,
    get_timeline_frames_b64,
    save_stage_output,
    load_stage_output,
)
from .frames import save_frames, load_frames
from .runner_bridge import checkpoint_after_stage, reset_checkpoint_state, get_timeline_id
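A quick orientation to the public surface. This is a hypothetical usage sketch only: the IDs are placeholders, and it assumes Postgres and MinIO are configured the way storage.py and frames.py expect.

from core.detect.checkpoint import get_timeline_frames, load_stage_output

# Frame objects, images pulled from MinIO (placeholder timeline ID)
frames = get_timeline_frames("<timeline-id>")
# That stage's stored JSONB payload, or None (placeholder checkpoint ID)
edges = load_stage_output("<checkpoint-id>", "detect_edges")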
core/detect/checkpoint/frames.py (Normal file, 111 lines)
@@ -0,0 +1,111 @@
"""Frame image storage — save/load to S3/MinIO as JPEGs."""

from __future__ import annotations

import logging
import os
import tempfile

import numpy as np
from PIL import Image

from core.detect.models import Frame

logger = logging.getLogger(__name__)

BUCKET = os.environ.get("S3_BUCKET", "mpr")
CHECKPOINT_PREFIX = "checkpoints"


def save_frames(job_id: str, frames: list[Frame]) -> dict[int, str]:
    """
    Save frame images to S3 as JPEGs.

    Returns manifest: {sequence: s3_key}
    """
    from core.storage.s3 import upload_file

    manifest = {}

    for frame in frames:
        key = f"{CHECKPOINT_PREFIX}/{job_id}/frames/{frame.sequence}.jpg"

        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
            img = Image.fromarray(frame.image)
            img.save(tmp, format="JPEG", quality=85)
            tmp_path = tmp.name

        try:
            upload_file(tmp_path, BUCKET, key)
        finally:
            os.unlink(tmp_path)

        manifest[frame.sequence] = key

    logger.info("Saved %d frames to s3://%s/%s/%s/frames/",
                len(frames), BUCKET, CHECKPOINT_PREFIX, job_id)
    return manifest


def load_frames(manifest: dict[int, str], frame_metadata: list[dict]) -> list[Frame]:
    """
    Load frame images from S3 and reconstitute Frame objects.

    frame_metadata: list of dicts with sequence, chunk_id, timestamp, perceptual_hash.
    """
    from core.storage.s3 import download_to_temp

    meta_map = {m["sequence"]: m for m in frame_metadata}
    frames = []

    for seq, key in manifest.items():
        tmp_path = download_to_temp(BUCKET, key)
        try:
            img = Image.open(tmp_path).convert("RGB")
            image_array = np.array(img)
        finally:
            os.unlink(tmp_path)

        meta = meta_map.get(seq, {})
        frame = Frame(
            sequence=seq,
            chunk_id=meta.get("chunk_id", 0),
            timestamp=meta.get("timestamp", 0.0),
            image=image_array,
            perceptual_hash=meta.get("perceptual_hash", ""),
        )
        frames.append(frame)

    frames.sort(key=lambda f: f.sequence)
    return frames


def load_frames_b64(manifest: dict[int, str], frame_metadata: list[dict]) -> list[dict]:
    """
    Load frame images from S3 as base64 JPEG — lightweight, no numpy.

    Returns list of dicts: {seq, timestamp, jpeg_b64}
    """
    import base64
    from core.storage.s3 import download_to_temp

    meta_map = {m["sequence"]: m for m in frame_metadata}
    frames = []

    for seq, key in manifest.items():
        tmp_path = download_to_temp(BUCKET, key)
        try:
            with open(tmp_path, "rb") as f:
                jpeg_bytes = f.read()
        finally:
            os.unlink(tmp_path)

        meta = meta_map.get(seq, {})
        frames.append({
            "seq": seq,
            "timestamp": meta.get("timestamp", 0.0),
            "jpeg_b64": base64.b64encode(jpeg_bytes).decode(),
        })

    frames.sort(key=lambda f: f["seq"])
    return frames
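A round-trip sketch, assuming Frame exposes exactly the fields used above (sequence, chunk_id, timestamp, image, perceptual_hash) and that S3_BUCKET points at a reachable MinIO bucket. The job ID and frame contents are made up.

import numpy as np
from core.detect.models import Frame
from core.detect.checkpoint.frames import save_frames, load_frames

frames = [Frame(sequence=i, chunk_id=0, timestamp=i / 2.0,
                image=np.zeros((720, 1280, 3), dtype=np.uint8),
                perceptual_hash="") for i in range(3)]

manifest = save_frames("job-123", frames)   # {0: "checkpoints/job-123/frames/0.jpg", ...}
meta = [{"sequence": f.sequence, "chunk_id": f.chunk_id,
         "timestamp": f.timestamp, "perceptual_hash": f.perceptual_hash}
        for f in frames]
restored = load_frames(manifest, meta)      # note: JPEG quality=85, so pixels are lossy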
core/detect/checkpoint/replay.py (Normal file, 232 lines)
@@ -0,0 +1,232 @@
"""
Pipeline replay — re-run from any stage with different config.

Loads a checkpoint, applies config overrides, builds a subgraph
starting from the target stage, and invokes it.
"""

from __future__ import annotations

import logging
import uuid

from core.detect import emit

# TODO: migrate to Timeline/Branch/Checkpoint model
# These old functions no longer exist — replay needs rework
def _not_migrated(*args, **kwargs):
    raise NotImplementedError("Replay not yet migrated to Timeline/Branch/Checkpoint model")

load_checkpoint = _not_migrated
list_checkpoints = _not_migrated

from core.detect.graph import NODES, build_graph

logger = logging.getLogger(__name__)


# OverrideProfile removed — config overrides are now handled by dict merging
# in _load_profile() (nodes.py) and replay_single_stage (below).


def replay_from(
    job_id: str,
    start_stage: str,
    config_overrides: dict | None = None,
    checkpoint: bool = True,
) -> dict:
    """
    Replay the pipeline from a specific stage.

    Loads the checkpoint from the stage immediately before start_stage,
    applies config overrides, and runs the subgraph from start_stage onward.

    Returns the final state dict.
    """
    if start_stage not in NODES:
        raise ValueError(f"Unknown stage: {start_stage!r}. Options: {NODES}")

    start_idx = NODES.index(start_stage)

    # Load checkpoint from the stage before start_stage
    if start_idx == 0:
        raise ValueError("Cannot replay from the first stage — just run the full pipeline")

    previous_stage = NODES[start_idx - 1]

    available = list_checkpoints(job_id)
    if previous_stage not in available:
        raise ValueError(
            f"No checkpoint for stage {previous_stage!r} (job {job_id}). "
            f"Available: {available}"
        )

    logger.info("Replaying job %s from %s (loading checkpoint: %s)",
                job_id, start_stage, previous_stage)

    state = load_checkpoint(job_id, previous_stage)

    # Apply config overrides
    if config_overrides:
        state["config_overrides"] = config_overrides

    # Set run context for SSE events
    run_id = str(uuid.uuid4())[:8]
    emit.set_run_context(
        run_id=run_id,
        parent_job_id=job_id,
        run_type="replay",
    )

    # Build subgraph starting from start_stage
    graph = build_graph(checkpoint=checkpoint, start_from=start_stage)
    pipeline = graph.compile()

    try:
        result = pipeline.invoke(state)
    finally:
        emit.clear_run_context()

    return result


def replay_single_stage(
    job_id: str,
    stage: str,
    frame_refs: list[int] | None = None,
    config_overrides: dict | None = None,
    debug: bool = False,
) -> dict:
    """
    Replay a single stage on specific frames (or all frames from checkpoint).

    Fast path for interactive parameter tuning — runs only the target stage
    function, not the full pipeline tail. Returns the stage output directly.

    When debug=True and stage is detect_edges, returns additional overlay
    data (Canny edges, Hough lines) for visual feedback in the editor.

    For detect_edges: returns {"edge_regions_by_frame": {seq: [box, ...]}}
    With debug=True, also returns {"debug": {seq: {edge_overlay_b64, lines_overlay_b64, ...}}}
    """
    if stage not in NODES:
        raise ValueError(f"Unknown stage: {stage!r}. Options: {NODES}")

    stage_idx = NODES.index(stage)
    if stage_idx == 0:
        raise ValueError("Cannot replay the first stage — just run the full pipeline")

    previous_stage = NODES[stage_idx - 1]

    available = list_checkpoints(job_id)
    if previous_stage not in available:
        raise ValueError(
            f"No checkpoint for stage {previous_stage!r} (job {job_id}). "
            f"Available: {available}"
        )

    logger.info("Single-stage replay: job %s, stage %s (loading checkpoint: %s, debug=%s)",
                job_id, stage, previous_stage, debug)

    state = load_checkpoint(job_id, previous_stage)

    # Build profile with overrides
    from core.detect.profile import get_profile, get_stage_config
    profile = get_profile(state.get("profile_name", "soccer_broadcast"))
    if config_overrides:
        merged_configs = dict(profile.get("configs", {}))
        for sname, soverrides in config_overrides.items():
            if sname in merged_configs:
                merged_configs[sname] = {**merged_configs[sname], **soverrides}
            else:
                merged_configs[sname] = soverrides
        profile = {**profile, "configs": merged_configs}

    # Run the stage function directly (not through the graph)
    if stage == "detect_edges":
        return _replay_detect_edges(state, profile, frame_refs, job_id, debug)
    else:
        raise ValueError(
            f"Single-stage replay not yet implemented for {stage!r}. "
            f"Use replay_from() for full pipeline replay."
        )


def _replay_detect_edges(
    state: dict,
    profile,
    frame_refs: list[int] | None,
    job_id: str,
    debug: bool,
) -> dict:
    """Run edge detection on checkpoint frames, optionally with debug overlays."""
    import os
    from core.detect.stages.edge_detector import detect_edge_regions

    from core.detect.profile import get_stage_config
    from core.detect.stages.models import RegionAnalysisConfig
    config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
    frames = state.get("filtered_frames", [])

    if frame_refs:
        ref_set = set(frame_refs)
        frames = [f for f in frames if f.sequence in ref_set]

    inference_url = os.environ.get("INFERENCE_URL")

    # Normal run — always needed for the boxes
    result = detect_edge_regions(
        frames=frames,
        config=config,
        inference_url=inference_url,
        job_id=job_id,
    )
    output = {"edge_regions_by_frame": result}

    # Debug overlays — call debug endpoint (remote) or local debug function
    if debug and frames:
        debug_data = {}
        if inference_url:
            from core.detect.inference import InferenceClient
            client = InferenceClient(base_url=inference_url, job_id=job_id)
            for frame in frames:
                dr = client.detect_edges_debug(
                    image=frame.image,
                    edge_canny_low=config.edge_canny_low,
                    edge_canny_high=config.edge_canny_high,
                    edge_hough_threshold=config.edge_hough_threshold,
                    edge_hough_min_length=config.edge_hough_min_length,
                    edge_hough_max_gap=config.edge_hough_max_gap,
                    edge_pair_max_distance=config.edge_pair_max_distance,
                    edge_pair_min_distance=config.edge_pair_min_distance,
                )
                debug_data[frame.sequence] = {
                    "edge_overlay_b64": dr.edge_overlay_b64,
                    "lines_overlay_b64": dr.lines_overlay_b64,
                    "horizontal_count": dr.horizontal_count,
                    "pair_count": dr.pair_count,
                }
        else:
            # Local mode — import GPU module directly
            from core.detect.stages.edge_detector import _load_cv_edges
            edges_mod = _load_cv_edges()
            for frame in frames:
                dr = edges_mod.detect_edges_debug(
                    frame.image,
                    canny_low=config.edge_canny_low,
                    canny_high=config.edge_canny_high,
                    hough_threshold=config.edge_hough_threshold,
                    hough_min_length=config.edge_hough_min_length,
                    hough_max_gap=config.edge_hough_max_gap,
                    pair_max_distance=config.edge_pair_max_distance,
                    pair_min_distance=config.edge_pair_min_distance,
                )
                debug_data[frame.sequence] = {
                    "edge_overlay_b64": dr["edge_overlay_b64"],
                    "lines_overlay_b64": dr["lines_overlay_b64"],
                    "horizontal_count": dr["horizontal_count"],
                    "pair_count": dr["pair_count"],
                }
        output["debug"] = debug_data

    return output
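An interactive-tuning sketch of the intended call shape. Note that load_checkpoint/list_checkpoints are stubbed to raise NotImplementedError above, so this illustrates the API rather than a working call today; the job ID and threshold values are made up.

from core.detect.checkpoint.replay import replay_single_stage

out = replay_single_stage(
    job_id="job-123",
    stage="detect_edges",
    frame_refs=[4, 17],                 # limit to two frames for speed
    config_overrides={"detect_edges": {"edge_canny_low": 40, "edge_canny_high": 120}},
    debug=True,                         # also return Canny/Hough overlays
)
boxes = out["edge_regions_by_frame"]    # {seq: [box, ...]}
overlays = out.get("debug", {})         # {seq: {"edge_overlay_b64": ..., ...}}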
core/detect/checkpoint/runner_bridge.py (Normal file, 97 lines)
@@ -0,0 +1,97 @@
"""
Runner bridge — checkpoint hook called by PipelineRunner after each stage.

Owns the per-job state (timeline, frame manifest, checkpoint chain) that
the runner shouldn't know about.

Timeline and Job are independent entities:
- One Timeline can serve multiple Jobs (re-run with different params)
- One Job operates on one Timeline (set after frame extraction)
- Checkpoints belong to Timeline, tagged with the Job that created them
"""

from __future__ import annotations

import logging

logger = logging.getLogger(__name__)

# Per-job state
_timeline_id: dict[str, str] = {}
_frames_manifest: dict[str, dict[int, str]] = {}
_latest_checkpoint: dict[str, str] = {}


def reset_checkpoint_state(job_id: str):
    """Clean up per-job checkpoint state. Called when pipeline finishes."""
    _timeline_id.pop(job_id, None)
    _frames_manifest.pop(job_id, None)
    _latest_checkpoint.pop(job_id, None)


def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: dict):
    """
    Save a checkpoint after a stage completes.

    Called by the runner. Handles:
    - Timeline creation (once, on extract_frames)
    - Frame upload (via create_timeline)
    - Stage output serialization (via stage registry)
    - Checkpoint chain (parent → child)
    """
    if not job_id:
        return

    from .storage import create_timeline, save_stage_output
    from core.detect.stages.base import _REGISTRY

    merged = {**state, **result}

    # On extract_frames: create Timeline + upload frames + root checkpoint
    if stage_name == "extract_frames" and job_id not in _timeline_id:
        frames = merged.get("frames", [])
        video_path = merged.get("video_path", "")
        profile_name = merged.get("profile_name", "")

        tid, cid = create_timeline(
            source_video=video_path,
            profile_name=profile_name,
            frames=frames,
        )
        _timeline_id[job_id] = tid
        _latest_checkpoint[job_id] = cid
        logger.info("Job %s → Timeline %s (root checkpoint %s)", job_id, tid, cid)

        # Emit timeline_id via SSE so the UI can use it for checkpoint loads
        from core.detect import emit
        emit.log(job_id, "Checkpoint", "INFO", f"timeline_id={tid}")
        return

    # For subsequent stages: save checkpoint on the timeline
    tid = _timeline_id.get(job_id)
    if not tid:
        logger.warning("No timeline for job %s, skipping checkpoint", job_id)
        return

    # Serialize stage output using the stage's serialize_fn if available
    stage_cls = _REGISTRY.get(stage_name)
    serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
    if serialize_fn:
        output_json = serialize_fn(merged, job_id)
    else:
        output_json = {}

    parent_id = _latest_checkpoint.get(job_id)
    new_checkpoint_id = save_stage_output(
        timeline_id=tid,
        parent_checkpoint_id=parent_id,
        stage_name=stage_name,
        output_json=output_json,
        job_id=job_id,
    )
    _latest_checkpoint[job_id] = new_checkpoint_id


def get_timeline_id(job_id: str) -> str | None:
    """Get the timeline_id for a running job. Used by the UI to load checkpoints."""
    return _timeline_id.get(job_id)
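For context, a minimal sketch of the runner side of this contract. PipelineRunner itself is not part of this commit, so the loop below is an assumption about how the hook is driven, not its actual implementation; run_stages and the (name, fn) stage list are hypothetical.

from core.detect.checkpoint import checkpoint_after_stage, reset_checkpoint_state

def run_stages(job_id: str, stages, state: dict) -> dict:
    try:
        for name, fn in stages:             # e.g. [("extract_frames", fn), ...]
            result = fn(state)
            checkpoint_after_stage(job_id, name, state, result)
            state = {**state, **result}     # same merge the bridge performs
        return state
    finally:
        reset_checkpoint_state(job_id)      # drop per-job module state either way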
core/detect/checkpoint/serializer.py (Normal file, 108 lines)
@@ -0,0 +1,108 @@
"""
State serialization — DetectState ↔ JSON-compatible dict.

Delegates to each stage's serialize_fn/deserialize_fn via the registry.
This file has no model-specific knowledge — stages own their data format.

The only things serialized here are the "envelope" fields (job_id, video_path, etc.)
that don't belong to any stage.
"""

from __future__ import annotations

from core.schema.serializers._common import serialize_dataclass
from core.schema.serializers.pipeline import (
    deserialize_pipeline_stats,
    deserialize_text_candidates,
)


# Envelope fields — not owned by any stage, always present
ENVELOPE_KEYS = ["job_id", "video_path", "profile_name", "config_overrides"]


def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
    """
    Serialize DetectState to a JSON-compatible dict.

    Calls each registered stage's serialize_fn for stage-owned data.
    Envelope fields (job_id, etc.) are copied directly.
    """
    from core.detect.stages.base import _REGISTRY

    checkpoint = {}

    # Envelope
    for key in ENVELOPE_KEYS:
        default = {} if key == "config_overrides" else ""
        checkpoint[key] = state.get(key, default)

    # Frames manifest (needed by frame-loading stages)
    checkpoint["frames_manifest"] = {str(k): v for k, v in frames_manifest.items()}

    # Stats (shared across stages, not owned by one)
    stats = state.get("stats")
    if stats is not None:
        checkpoint["stats"] = serialize_dataclass(stats)
    else:
        checkpoint["stats"] = {}

    # Per-stage data
    for name, stage_def in _REGISTRY.items():
        if stage_def.serialize_fn is None:
            continue
        job_id = state.get("job_id", "")
        stage_data = stage_def.serialize_fn(state, job_id)
        checkpoint[f"stage_{name}"] = stage_data

    return checkpoint


def deserialize_state(checkpoint: dict, frames: list) -> dict:
    """
    Reconstitute DetectState from a checkpoint dict + loaded frames.

    Calls each stage's deserialize_fn to restore stage-owned data.
    """
    from core.detect.stages.base import _REGISTRY

    frame_map = {f.sequence: f for f in frames}

    state = {}

    # Envelope
    for key in ENVELOPE_KEYS:
        default = {} if key == "config_overrides" else ""
        state[key] = checkpoint.get(key, default)

    # Frames (always present, loaded externally)
    state["frames"] = frames

    # Stats
    state["stats"] = deserialize_pipeline_stats(checkpoint.get("stats", {}))

    # Per-stage data
    for name, stage_def in _REGISTRY.items():
        if stage_def.deserialize_fn is None:
            continue

        stage_key = f"stage_{name}"
        if stage_key not in checkpoint:
            continue

        job_id = state.get("job_id", "")
        stage_data = stage_def.deserialize_fn(checkpoint[stage_key], job_id)

        for k, v in stage_data.items():
            if k == "_filtered_sequences":
                # Reconnect filtered frames from sequence list
                seq_set = set(v)
                state["filtered_frames"] = [f for f in frames if f.sequence in seq_set]
            elif k.endswith("_raw"):
                # Raw text candidates need frame reference reconnection
                real_key = k.removeprefix("_").removesuffix("_raw")
                state[real_key] = deserialize_text_candidates(v, frame_map)
            else:
                state[k] = v

    return state
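A sketch of the stage-side contract this module relies on: each registry entry exposes serialize_fn(state, job_id) -> dict and deserialize_fn(data, job_id) -> dict (the signatures visible here and in runner_bridge.py). The functions below are illustrative, not from this commit; they show why the _filtered_sequences key is special-cased above.

def _serialize(state: dict, job_id: str) -> dict:
    # Persist only sequence numbers; frame images themselves live in MinIO
    return {"_filtered_sequences": [f.sequence for f in state.get("filtered_frames", [])]}

def _deserialize(data: dict, job_id: str) -> dict:
    # deserialize_state() reconnects these sequences to loaded Frame objects
    return {"_filtered_sequences": data.get("_filtered_sequences", [])}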
core/detect/checkpoint/storage.py (Normal file, 177 lines)
@@ -0,0 +1,177 @@
"""
Checkpoint storage — Timeline + Checkpoint (tree of snapshots).

Timeline: frame sequence from source video (frames in MinIO)
Checkpoint: snapshot of pipeline state (stage outputs as JSONB in Postgres);
    parent_id forms a tree — multiple children = different config tries
"""

from __future__ import annotations

import logging
from uuid import UUID

from .frames import save_frames, load_frames, CHECKPOINT_PREFIX

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Timeline
# ---------------------------------------------------------------------------

def create_timeline(
    source_video: str,
    profile_name: str,
    frames: list,
    fps: float = 2.0,
    source_asset_id: UUID | None = None,
) -> tuple[str, str]:
    """
    Create a timeline from frames. Uploads frame images to MinIO,
    creates Timeline + root Checkpoint in Postgres.

    Returns (timeline_id, checkpoint_id).
    """
    from core.db.models import Timeline, Checkpoint
    from core.db.connection import get_session

    with get_session() as session:
        timeline = Timeline(
            source_video=source_video,
            profile_name=profile_name,
            source_asset_id=source_asset_id,
            fps=fps,
        )
        session.add(timeline)
        session.flush()
        tid = str(timeline.id)

        # Upload frames to MinIO
        manifest = save_frames(tid, frames)

        frames_meta = [
            {
                "sequence": f.sequence,
                "chunk_id": getattr(f, "chunk_id", 0),
                "timestamp": f.timestamp,
                "perceptual_hash": getattr(f, "perceptual_hash", ""),
            }
            for f in frames
        ]

        timeline.frames_prefix = f"{CHECKPOINT_PREFIX}/{tid}/frames/"
        timeline.frames_manifest = {str(k): v for k, v in manifest.items()}
        timeline.frames_meta = frames_meta

        checkpoint = Checkpoint(
            timeline_id=timeline.id,
            parent_id=None,
            stage_outputs={},
            stats={"frames_extracted": len(frames)},
        )
        session.add(checkpoint)
        session.commit()
        session.refresh(checkpoint)
        cid = str(checkpoint.id)

    logger.info("Timeline created: %s (%d frames, root checkpoint %s)", tid, len(frames), cid)
    return tid, cid


def get_timeline_frames(timeline_id: str) -> list:
    """Load frames from a timeline (from MinIO) as Frame objects."""
    from core.db.models import Timeline
    from core.db.connection import get_session

    with get_session() as session:
        timeline = session.get(Timeline, UUID(timeline_id))
        if not timeline:
            raise ValueError(f"Timeline not found: {timeline_id}")

        raw_manifest = timeline.frames_manifest or {}
        manifest = {int(k): v for k, v in raw_manifest.items()}
        return load_frames(manifest, timeline.frames_meta or [])


def get_timeline_frames_b64(timeline_id: str) -> list[dict]:
    """Load frames as base64 JPEG (lightweight, no numpy)."""
    from core.db.models import Timeline
    from core.db.connection import get_session
    from .frames import load_frames_b64

    with get_session() as session:
        timeline = session.get(Timeline, UUID(timeline_id))
        if not timeline:
            raise ValueError(f"Timeline not found: {timeline_id}")

        raw_manifest = timeline.frames_manifest or {}
        manifest = {int(k): v for k, v in raw_manifest.items()}
        return load_frames_b64(manifest, timeline.frames_meta or [])


# ---------------------------------------------------------------------------
# Checkpoint
# ---------------------------------------------------------------------------

def save_stage_output(
    timeline_id: str,
    parent_checkpoint_id: str | None,
    stage_name: str,
    output_json: dict,
    config_overrides: dict | None = None,
    stats: dict | None = None,
    is_scenario: bool = False,
    scenario_label: str = "",
    job_id: str | None = None,
) -> str:
    """
    Save a stage's output as a new checkpoint (child of parent).

    Carries forward stage outputs from parent + adds the new one.
    Returns the new checkpoint ID.
    """
    from core.db.models import Checkpoint
    from core.db.connection import get_session

    with get_session() as session:
        parent_outputs = {}
        parent_stats = {}
        parent_config = {}
        if parent_checkpoint_id:
            parent = session.get(Checkpoint, UUID(parent_checkpoint_id))
            if parent:
                parent_outputs = dict(parent.stage_outputs or {})
                parent_stats = dict(parent.stats or {})
                parent_config = dict(parent.config_overrides or {})

        checkpoint = Checkpoint(
            timeline_id=UUID(timeline_id),
            job_id=UUID(job_id) if job_id else None,
            parent_id=UUID(parent_checkpoint_id) if parent_checkpoint_id else None,
            stage_outputs={**parent_outputs, stage_name: output_json},
            config_overrides={**parent_config, **(config_overrides or {})},
            stats={**parent_stats, **(stats or {})},
            is_scenario=is_scenario,
            scenario_label=scenario_label,
        )
        session.add(checkpoint)
        session.commit()
        session.refresh(checkpoint)
        cid = str(checkpoint.id)

    logger.info("Checkpoint saved: %s (timeline %s, stage %s, parent %s)",
                cid, timeline_id, stage_name, parent_checkpoint_id)
    return cid


def load_stage_output(checkpoint_id: str, stage_name: str) -> dict | None:
    """Load a stage's output from a checkpoint."""
    from core.db.models import Checkpoint
    from core.db.connection import get_session

    with get_session() as session:
        checkpoint = session.get(Checkpoint, UUID(checkpoint_id))
        if not checkpoint:
            return None
        return (checkpoint.stage_outputs or {}).get(stage_name)
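A branching sketch: two children of the same root checkpoint model two config tries, and each child carries the parent's stage_outputs forward plus its own. The video name, override values, and output payloads are made up; frames is a list of Frame objects as in the frames.py sketch above.

from core.detect.checkpoint import create_timeline, save_stage_output, load_stage_output

tid, root = create_timeline("match.mp4", "soccer_broadcast", frames)

a = save_stage_output(tid, root, "detect_edges", {"boxes": 12},
                      config_overrides={"detect_edges": {"edge_canny_low": 40}},
                      is_scenario=True, scenario_label="low threshold")
b = save_stage_output(tid, root, "detect_edges", {"boxes": 7},
                      config_overrides={"detect_edges": {"edge_canny_low": 80}},
                      is_scenario=True, scenario_label="high threshold")

load_stage_output(a, "detect_edges")   # -> {"boxes": 12}; b holds {"boxes": 7}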