117 lines
3.9 KiB
Python
117 lines
3.9 KiB
Python
"""
Checkpoint storage — save/load stage state.

Binary data (frame images) → S3/MinIO via frames.py
Structured data (stage output, stats, config) → Postgres
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
|
|
from .frames import save_frames, load_frames, CHECKPOINT_PREFIX
|
|
from .serializer import serialize_state, deserialize_state
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Save
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def save_checkpoint(
    job_id: str,
    stage: str,
    stage_index: int,
    state: dict,
    frames_manifest: dict[int, str] | None = None,
    is_scenario: bool = False,
    scenario_label: str = "",
) -> str:
    """
    Save a stage checkpoint.

    Saves frame images via ``save_frames`` (skipped when the caller already
    supplies ``frames_manifest``), then persists the structured state to
    Postgres via ``core.db.detect.save_stage_checkpoint``.

    Args:
        job_id: Job the checkpoint belongs to; also used to build the
            frames prefix ``{CHECKPOINT_PREFIX}/{job_id}/frames/``.
        stage: Stage name the checkpoint is stored under.
        stage_index: Stage position, stored alongside the checkpoint.
        state: Stage state. Its ``"frames"`` entry is uploaded when no
            manifest is given, and the whole dict is passed to
            ``serialize_state``.
        frames_manifest: ``dict[int, str]`` manifest as produced by
            ``save_frames``. ``None`` means "upload the frames now".
        is_scenario: Stored on the checkpoint row and included in the log.
        scenario_label: Stored on the checkpoint row.

    Returns:
        The checkpoint DB id, as a string.
    """
    # Function-scope import — presumably to keep the DB layer out of module
    # import time (or to avoid an import cycle); confirm before hoisting.
    from core.db.detect import save_stage_checkpoint

    if frames_manifest is None:
        all_frames = state.get("frames", [])
        frames_manifest = save_frames(job_id, all_frames)

    checkpoint_data = serialize_state(state, frames_manifest)
    frames_prefix = f"{CHECKPOINT_PREFIX}/{job_id}/frames/"

    config_overrides = checkpoint_data.get("config_overrides", {})
    # The snapshot used to be read from the "config_overrides" key as well
    # (an apparent copy-paste). Prefer a dedicated "config_snapshot" entry
    # when the serializer provides one; otherwise keep the previous behaviour.
    config_snapshot = checkpoint_data.get("config_snapshot", config_overrides)

    checkpoint = save_stage_checkpoint(
        job_id=job_id,
        stage=stage,
        stage_index=stage_index,
        frames_prefix=frames_prefix,
        frames_manifest=checkpoint_data.get("frames_manifest", {}),
        frames_meta=checkpoint_data.get("frames_meta", []),
        filtered_frame_sequences=checkpoint_data.get("filtered_frame_sequences", []),
        stage_output_key=checkpoint_data.get("stage_output_key", ""),
        stats=checkpoint_data.get("stats", {}),
        config_snapshot=config_snapshot,
        config_overrides=config_overrides,
        video_path=checkpoint_data.get("video_path", ""),
        profile_name=checkpoint_data.get("profile_name", ""),
        is_scenario=is_scenario,
        scenario_label=scenario_label,
    )

    logger.info("Checkpoint saved: %s/%s (id=%s, scenario=%s)",
                job_id, stage, checkpoint.id, is_scenario)
    return str(checkpoint.id)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Load
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def load_checkpoint(job_id: str, stage: str) -> dict:
    """
    Load a stage checkpoint and reconstitute full DetectState.

    Reads the checkpoint row via ``core.db.detect.get_stage_checkpoint``,
    fetches the frames referenced by its manifest via ``load_frames``, and
    rebuilds the state with ``deserialize_state``.

    Args:
        job_id: Job whose checkpoint to load.
        stage: Stage name the checkpoint was saved under.

    Returns:
        The state dict produced by ``deserialize_state``.

    Raises:
        ValueError: If no checkpoint exists for ``job_id``/``stage``.
    """
    # Function-scope import, mirroring save_checkpoint.
    from core.db.detect import get_stage_checkpoint

    checkpoint = get_stage_checkpoint(job_id, stage)
    if not checkpoint:
        raise ValueError(f"No checkpoint for {job_id}/{stage}")

    data = {
        "job_id": str(checkpoint.job_id),
        "video_path": checkpoint.video_path,
        "profile_name": checkpoint.profile_name,
        "config_overrides": checkpoint.config_overrides,
        "frames_manifest": checkpoint.frames_manifest,
        "frames_meta": checkpoint.frames_meta,
        "filtered_frame_sequences": checkpoint.filtered_frame_sequences,
        "stage_output_key": checkpoint.stage_output_key,
        "stats": checkpoint.stats,
    }

    # These keys are always present in ``data``, so a ``dict.get`` default
    # would never apply; coalesce None/empty column values explicitly
    # (assumes the columns may be NULL — confirm against the DB model).
    raw_manifest = data["frames_manifest"] or {}
    # Manifest keys come back as strings (presumably JSON round-trip);
    # restore the integer keys the frames layer was given on save.
    manifest = {int(k): v for k, v in raw_manifest.items()}
    frame_metadata = data["frames_meta"] or []
    frames = load_frames(manifest, frame_metadata)

    state = deserialize_state(data, frames)

    logger.info("Checkpoint loaded: %s/%s (%d frames, scenario=%s)",
                job_id, stage, len(frames), checkpoint.is_scenario)
    return state
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# List
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def list_checkpoints(job_id: str) -> list[str]:
    """
    Return the stages that have a saved checkpoint for ``job_id``.

    Delegates directly to ``core.db.detect.list_stage_checkpoints``.
    """
    from core.db.detect import list_stage_checkpoints as _list_stage_checkpoints

    stages = _list_stage_checkpoints(job_id)
    return stages
|