Files
mediaproc/detect/checkpoint/storage.py
2026-03-26 22:22:35 -03:00

117 lines
3.9 KiB
Python

"""
Checkpoint storage — save/load stage state.
Binary data (frame images) → S3/MinIO via frames.py
Structured data (stage output, stats, config) → Postgres
"""
from __future__ import annotations
import logging
from .frames import save_frames, load_frames, CHECKPOINT_PREFIX
from .serializer import serialize_state, deserialize_state
# Module-level logger; checkpoint save/load events below are logged at INFO.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Save
# ---------------------------------------------------------------------------
def save_checkpoint(
    job_id: str,
    stage: str,
    stage_index: int,
    state: dict,
    frames_manifest: dict[int, str] | None = None,
    is_scenario: bool = False,
    scenario_label: str = "",
) -> str:
    """
    Persist a checkpoint for one pipeline stage.

    Frame images are written through ``save_frames`` unless the caller
    already supplies a manifest for them; the rest of the state is run
    through ``serialize_state`` and stored via ``save_stage_checkpoint``.

    Args:
        job_id: Job the checkpoint belongs to.
        stage: Stage name.
        stage_index: Stage position, stored with the checkpoint.
        state: Stage state dict. Its ``"frames"`` entry (default ``[]``) is
            handed to ``save_frames`` when no manifest is given.
        frames_manifest: Pre-computed frame manifest. ``None`` means "save
            the frames first and use the resulting manifest"; an empty dict
            is accepted as-is.
        is_scenario: Flag stored with the checkpoint.
        scenario_label: Label stored with the checkpoint.

    Returns:
        The stored checkpoint's id, converted to ``str``.
    """
    from core.db.detect import save_stage_checkpoint

    manifest = frames_manifest
    if manifest is None:
        manifest = save_frames(job_id, state.get("frames", []))

    serialized = serialize_state(state, manifest)

    # (save_stage_checkpoint keyword, serialized key, factory for the default)
    # NOTE(review): config_snapshot is filled from "config_overrides", i.e.
    # the same value as config_overrides — confirm this is intended.
    column_sources = (
        ("frames_manifest", "frames_manifest", dict),
        ("frames_meta", "frames_meta", list),
        ("filtered_frame_sequences", "filtered_frame_sequences", list),
        ("stage_output_key", "stage_output_key", str),
        ("stats", "stats", dict),
        ("config_snapshot", "config_overrides", dict),
        ("config_overrides", "config_overrides", dict),
        ("video_path", "video_path", str),
        ("profile_name", "profile_name", str),
    )
    columns = {
        kwarg: serialized.get(source_key, make_default())
        for kwarg, source_key, make_default in column_sources
    }

    record = save_stage_checkpoint(
        job_id=job_id,
        stage=stage,
        stage_index=stage_index,
        frames_prefix=f"{CHECKPOINT_PREFIX}/{job_id}/frames/",
        is_scenario=is_scenario,
        scenario_label=scenario_label,
        **columns,
    )

    logger.info(
        "Checkpoint saved: %s/%s (id=%s, scenario=%s)",
        job_id, stage, record.id, is_scenario,
    )
    return str(record.id)
# ---------------------------------------------------------------------------
# Load
# ---------------------------------------------------------------------------
def load_checkpoint(job_id: str, stage: str) -> dict:
    """
    Load a stage checkpoint and reconstitute the full DetectState.

    The checkpoint row is fetched with ``get_stage_checkpoint``, its frames
    are loaded with ``load_frames``, and ``deserialize_state`` rebuilds the
    state from the stored columns plus those frames.

    Args:
        job_id: Job whose checkpoint should be loaded.
        stage: Stage name of the checkpoint.

    Returns:
        The state dict returned by ``deserialize_state``.

    Raises:
        ValueError: If no checkpoint exists for ``job_id``/``stage``.
    """
    from core.db.detect import get_stage_checkpoint

    checkpoint = get_stage_checkpoint(job_id, stage)
    if not checkpoint:
        raise ValueError(f"No checkpoint for {job_id}/{stage}")

    data = {
        "job_id": str(checkpoint.job_id),
        "video_path": checkpoint.video_path,
        "profile_name": checkpoint.profile_name,
        "config_overrides": checkpoint.config_overrides,
        "frames_manifest": checkpoint.frames_manifest,
        "frames_meta": checkpoint.frames_meta,
        "filtered_frame_sequences": checkpoint.filtered_frame_sequences,
        "stage_output_key": checkpoint.stage_output_key,
        "stats": checkpoint.stats,
    }

    # ``data`` always contains these keys, so a ``.get(..., default)`` would
    # never fall back. The stored values may however be None (e.g. a NULL
    # column); treat that as "no frames" instead of crashing on None.items().
    raw_manifest = data["frames_manifest"] or {}
    # Keys are presumably stored as strings (e.g. JSON object keys);
    # convert them back to int frame indices.
    manifest = {int(k): v for k, v in raw_manifest.items()}
    frame_metadata = data["frames_meta"] or []
    frames = load_frames(manifest, frame_metadata)

    state = deserialize_state(data, frames)
    logger.info(
        "Checkpoint loaded: %s/%s (%d frames, scenario=%s)",
        job_id, stage, len(frames), checkpoint.is_scenario,
    )
    return state
# ---------------------------------------------------------------------------
# List
# ---------------------------------------------------------------------------
def list_checkpoints(job_id: str) -> list[str]:
    """
    Return the checkpoint stages available for ``job_id``.

    Delegates directly to ``core.db.detect.list_stage_checkpoints``.
    """
    from core.db.detect import list_stage_checkpoints as _list_stages

    stages = _list_stages(job_id)
    return stages