"""
Checkpoint storage — save/load stage state.

Binary data (frame images)                        -> S3/MinIO via frames.py
Structured data (boxes, detections, stats, config) -> Postgres via Django ORM

Until the Django model is generated by modelgen, checkpoint data is stored as
JSON in S3 as a fallback. Once the DetectJob/StageCheckpoint models exist,
this module switches to Postgres.
"""

from __future__ import annotations

import json
import logging
import os
import tempfile
from pathlib import Path

from .frames import save_frames, load_frames, BUCKET, CHECKPOINT_PREFIX
from .serializer import serialize_state, deserialize_state

logger = logging.getLogger(__name__)


def _has_db() -> bool:
    """Return True if the DB layer is available (core.db helpers + generated models).

    Re-evaluated on every call (not cached), so the DB backend is picked up as
    soon as the imports below start succeeding.
    """
    try:
        from core.db.detect import get_stage_checkpoint  # noqa: F401
        # Quick check that the model exists (modelgen may not have run yet).
        from admin.mpr.media_assets.models import StageCheckpoint  # noqa: F401
    except Exception:  # noqa: BLE001 - any import failure means "no DB"
        # Deliberately broad: callers fall back to the S3 JSON backend on any
        # failure. Log at debug level so the reason is still discoverable.
        logger.debug("Checkpoint DB layer unavailable; using S3 fallback", exc_info=True)
        return False
    return True


# ---------------------------------------------------------------------------
# Save
# ---------------------------------------------------------------------------

def save_checkpoint(
    job_id: str,
    stage: str,
    stage_index: int,
    state: dict,
    frames_manifest: dict[int, str] | None = None,
) -> str:
    """
    Save a stage checkpoint.

    Frame images are saved to S3 first (skipped when ``frames_manifest`` is
    supplied), then the structured state is persisted to Postgres when the DB
    layer is available, or as JSON in S3 otherwise.

    Args:
        job_id: Job identifier. With the DB backend a string id is parsed with
            ``uuid.UUID``, so it must be a valid UUID string there.
        stage: Stage name; the key the checkpoint is stored/loaded under.
        stage_index: Stage position; only the DB backend stores it.
        state: Pipeline state to checkpoint. Its ``"frames"`` entry is passed
            to ``save_frames`` when no manifest is given.
        frames_manifest: Manifest returned by a previous ``save_frames`` call
            (int-keyed), to avoid re-uploading frames.

    Returns:
        The checkpoint identifier: the DB row id, or the S3 key of the JSON.
    """
    if frames_manifest is None:
        frames_manifest = save_frames(job_id, state.get("frames", []))

    checkpoint_data = serialize_state(state, frames_manifest)

    if _has_db():
        return _save_to_db(job_id, stage, stage_index, checkpoint_data)
    return _save_to_s3(job_id, stage, checkpoint_data)


def _save_to_db(job_id: str, stage: str, stage_index: int, data: dict) -> str:
    """Persist serialized checkpoint data to Postgres via ``core.db.detect``.

    Returns the new checkpoint's id as a string.
    """
    import uuid

    from core.db.detect import save_stage_checkpoint

    job_uuid = uuid.UUID(job_id) if isinstance(job_id, str) else job_id
    frames_prefix = f"{CHECKPOINT_PREFIX}/{job_id}/frames/"

    checkpoint = save_stage_checkpoint(
        id=uuid.uuid4(),
        job_id=job_uuid,
        stage=stage,
        stage_index=stage_index,
        frames_prefix=frames_prefix,
        frames_manifest=data.get("frames_manifest", {}),
        frames_meta=data.get("frames_meta", []),
        filtered_frame_sequences=data.get("filtered_frame_sequences", []),
        boxes_by_frame=data.get("boxes_by_frame", {}),
        text_candidates=data.get("text_candidates", []),
        unresolved_candidates=data.get("unresolved_candidates", []),
        detections=data.get("detections", []),
        stats=data.get("stats", {}),
        # NOTE(review): config_snapshot is filled from config_overrides, which
        # looks like a copy-paste. Confirm whether the serializer produces a
        # separate snapshot that should be stored here instead.
        config_snapshot=data.get("config_overrides", {}),
        config_overrides=data.get("config_overrides", {}),
        video_path=data.get("video_path", ""),
        profile_name=data.get("profile_name", ""),
    )
    logger.info("Checkpoint saved to DB: %s/%s (id=%s)", job_id, stage, checkpoint.id)
    return str(checkpoint.id)


def _save_to_s3(job_id: str, stage: str, data: dict) -> str:
    """Fallback: save checkpoint as JSON to S3 (before modelgen generates DB models).

    Returns the S3 key the JSON was uploaded to.
    """
    from core.storage.s3 import upload_file

    key = f"{CHECKPOINT_PREFIX}/{job_id}/stages/{stage}.json"
    # default=str stringifies any value json cannot encode natively.
    checkpoint_json = json.dumps(data, default=str)

    # upload_file takes a local path, so stage the JSON in a temp file. The
    # try/finally covers the write as well as the upload, so the temp file is
    # removed even if writing fails.
    fd, tmp_path = tempfile.mkstemp(suffix=".json")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as tmp:
            tmp.write(checkpoint_json)
        upload_file(tmp_path, BUCKET, key)
    finally:
        os.unlink(tmp_path)

    logger.info("Checkpoint saved to S3: s3://%s/%s", BUCKET, key)
    return key


# ---------------------------------------------------------------------------
# Load
# ---------------------------------------------------------------------------

def load_checkpoint(job_id: str, stage: str) -> dict:
    """
    Load a stage checkpoint and reconstitute the full DetectState.

    Reads from Postgres when the DB layer is available, otherwise from the S3
    JSON fallback. The backends are exclusive: there is no cross-backend
    fallback, so a checkpoint written to S3 before the DB models existed is
    not found once the DB layer is active.
    """
    if _has_db():
        data = _load_from_db(job_id, stage)
    else:
        data = _load_from_s3(job_id, stage)

    # JSON object keys are always strings; restore the int frame indices.
    # `or` guards against a stored null as well as a missing key.
    raw_manifest = data.get("frames_manifest") or {}
    manifest = {int(k): v for k, v in raw_manifest.items()}
    frame_metadata = data.get("frames_meta") or []
    frames = load_frames(manifest, frame_metadata)

    state = deserialize_state(data, frames)
    logger.info("Checkpoint loaded: %s/%s (%d frames)", job_id, stage, len(frames))
    return state


def _load_from_db(job_id: str, stage: str) -> dict:
    """Load checkpoint data from Postgres via ``core.db.detect``.

    Returns a dict with the same keys the serializer output uses, so
    ``load_checkpoint`` can treat both backends identically.
    """
    from core.db.detect import get_stage_checkpoint

    # NOTE(review): behaviour for a missing checkpoint depends on
    # get_stage_checkpoint (raise vs. return None) — confirm its contract.
    checkpoint = get_stage_checkpoint(job_id, stage)
    return {
        "job_id": str(checkpoint.job_id),
        "video_path": checkpoint.video_path,
        "profile_name": checkpoint.profile_name,
        "config_overrides": checkpoint.config_overrides,
        "frames_manifest": checkpoint.frames_manifest,
        "frames_meta": checkpoint.frames_meta,
        "filtered_frame_sequences": checkpoint.filtered_frame_sequences,
        "boxes_by_frame": checkpoint.boxes_by_frame,
        "text_candidates": checkpoint.text_candidates,
        "unresolved_candidates": checkpoint.unresolved_candidates,
        "detections": checkpoint.detections,
        "stats": checkpoint.stats,
    }


def _load_from_s3(job_id: str, stage: str) -> dict:
    """Fallback: load checkpoint JSON from S3 (the format written by ``_save_to_s3``)."""
    from core.storage.s3 import download_to_temp

    key = f"{CHECKPOINT_PREFIX}/{job_id}/stages/{stage}.json"
    tmp_path = download_to_temp(BUCKET, key)
    try:
        with open(tmp_path, encoding="utf-8") as f:
            data = json.load(f)
    finally:
        os.unlink(tmp_path)
    return data
# ---------------------------------------------------------------------------
# List
# ---------------------------------------------------------------------------

def list_checkpoints(job_id: str) -> list[str]:
    """List the stage names that have a saved checkpoint for ``job_id``.

    Uses the DB layer when available, otherwise scans the S3 JSON fallback.
    """
    if _has_db():
        return _list_from_db(job_id)
    return _list_from_s3(job_id)


def _list_from_db(job_id: str) -> list[str]:
    """Return checkpointed stage names from Postgres via ``core.db.detect``."""
    from core.db.detect import list_stage_checkpoints

    return list_stage_checkpoints(job_id)


def _stage_from_key(key: str) -> str | None:
    """Return the stage name encoded in an S3 checkpoint key, or None.

    Stage checkpoints are stored as ``.../stages/{stage}.json``. Any other
    object under the prefix (non-JSON files, directory markers) is not a
    loadable checkpoint and yields None.
    """
    path = Path(key)
    if path.suffix != ".json":
        return None
    return path.stem


def _list_from_s3(job_id: str) -> list[str]:
    """Return checkpointed stage names by listing the job's ``stages/`` prefix in S3."""
    from core.storage.s3 import list_objects

    prefix = f"{CHECKPOINT_PREFIX}/{job_id}/stages/"
    stages = []
    for obj in list_objects(BUCKET, prefix):
        stage = _stage_from_key(obj["key"])
        if stage is not None:
            stages.append(stage)
    return stages