213 lines
6.5 KiB
Python
213 lines
6.5 KiB
Python
"""
|
|
Checkpoint storage — save/load stage state.

Binary data (frame images) → S3/MinIO via frames.py
Structured data (boxes, detections, stats, config) → Postgres via Django ORM

Until the Django model is generated by modelgen, checkpoint data is stored
as JSON in S3 as a fallback. Once DetectJob/StageCheckpoint models exist,
this module switches to Postgres.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import tempfile
|
|
from pathlib import Path
|
|
|
|
from .frames import save_frames, load_frames, BUCKET, CHECKPOINT_PREFIX
|
|
from .serializer import serialize_state, deserialize_state
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _has_db() -> bool:
|
|
"""Check if Postgres is reachable."""
|
|
try:
|
|
from core.db.connection import get_session
|
|
from sqlmodel import text
|
|
with get_session() as session:
|
|
session.exec(text("SELECT 1"))
|
|
return True
|
|
except Exception:
|
|
return False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Save
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def save_checkpoint(
    job_id: str,
    stage: str,
    stage_index: int,
    state: dict,
    frames_manifest: dict[int, str] | None = None,
) -> str:
    """
    Persist one stage checkpoint for a job.

    When no ``frames_manifest`` is passed, the frames found under
    ``state["frames"]`` are first stored through ``save_frames`` and the
    manifest it returns is used. The state is then serialized and written
    to Postgres if it is reachable, otherwise to S3 as JSON.

    Args:
        job_id: Identifier of the job the checkpoint belongs to.
        stage: Name of the stage being checkpointed.
        stage_index: Index of the stage; only used by the Postgres path.
        state: Stage state to serialize.
        frames_manifest: Existing frame manifest; ``None`` means the frames
            still have to be saved.

    Returns:
        The checkpoint identifier: the DB id (as a string) or the S3 key.
    """
    manifest = frames_manifest
    if manifest is None:
        # Frames have not been stored yet for this job: do it now.
        manifest = save_frames(job_id, state.get("frames", []))

    payload = serialize_state(state, manifest)

    if not _has_db():
        return _save_to_s3(job_id, stage, payload)
    return _save_to_db(job_id, stage, stage_index, payload)
|
|
|
|
|
|
def _save_to_db(job_id: str, stage: str, stage_index: int, data: dict) -> str:
    """Save checkpoint structured data to Postgres.

    Forwards the serialized checkpoint fields to
    ``core.db.detect.save_stage_checkpoint``. Keys missing from ``data``
    are replaced by an empty container or empty string.

    Args:
        job_id: Identifier of the job the checkpoint belongs to.
        stage: Name of the stage being checkpointed.
        stage_index: Index of the stage.
        data: Serialized checkpoint payload.

    Returns:
        The id of the stored checkpoint record, as a string.
    """
    from core.db.detect import save_stage_checkpoint

    frames_prefix = f"{CHECKPOINT_PREFIX}/{job_id}/frames/"

    config_overrides = data.get("config_overrides", {})
    # Previously the snapshot was always a copy of "config_overrides", so a
    # dedicated "config_snapshot" entry in the payload was silently dropped.
    # Prefer the dedicated key and keep the old value as the fallback.
    # NOTE(review): confirm which key the serializer actually emits.
    config_snapshot = data.get("config_snapshot", config_overrides)

    checkpoint = save_stage_checkpoint(
        job_id=job_id,
        stage=stage,
        stage_index=stage_index,
        frames_prefix=frames_prefix,
        frames_manifest=data.get("frames_manifest", {}),
        frames_meta=data.get("frames_meta", []),
        filtered_frame_sequences=data.get("filtered_frame_sequences", []),
        boxes_by_frame=data.get("boxes_by_frame", {}),
        text_candidates=data.get("text_candidates", []),
        unresolved_candidates=data.get("unresolved_candidates", []),
        detections=data.get("detections", []),
        stats=data.get("stats", {}),
        config_snapshot=config_snapshot,
        config_overrides=config_overrides,
        video_path=data.get("video_path", ""),
        profile_name=data.get("profile_name", ""),
    )

    logger.info("Checkpoint saved to DB: %s/%s (id=%s)", job_id, stage, checkpoint.id)
    return str(checkpoint.id)
|
|
|
|
|
|
def _save_to_s3(job_id: str, stage: str, data: dict) -> str:
    """Fallback: save checkpoint as JSON to S3 (before modelgen generates DB models).

    The payload is written to a local temporary file, uploaded with
    ``core.storage.s3.upload_file`` and the temporary file is removed
    whether or not the write or the upload succeeded.

    Args:
        job_id: Identifier of the job the checkpoint belongs to.
        stage: Name of the stage; becomes the object's file name.
        data: Serialized checkpoint payload.

    Returns:
        The S3 key the checkpoint was uploaded to.
    """
    from core.storage.s3 import upload_file

    # NOTE: default=str stringifies any value json cannot encode natively,
    # so such values come back as plain strings when the checkpoint is loaded.
    checkpoint_json = json.dumps(data, default=str)
    key = f"{CHECKPOINT_PREFIX}/{job_id}/stages/{stage}.json"

    tmp = tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False, encoding="utf-8"
    )
    try:
        # Close the file before uploading so the content is flushed to disk.
        with tmp:
            tmp.write(checkpoint_json)
        upload_file(tmp.name, BUCKET, key)
    finally:
        # delete=False leaves cleanup to us; doing it here also covers a
        # failed write, which previously left the temp file behind.
        os.unlink(tmp.name)

    logger.info("Checkpoint saved to S3: s3://%s/%s", BUCKET, key)
    return key
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Load
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def load_checkpoint(job_id: str, stage: str) -> dict:
    """
    Load a stage checkpoint and reconstitute full DetectState.

    Reads the structured data from Postgres when it is reachable, otherwise
    from the S3 JSON fallback, reloads the frames listed in the manifest and
    hands both to ``deserialize_state``.

    Args:
        job_id: Identifier of the job the checkpoint belongs to.
        stage: Name of the stage to load.

    Returns:
        The state produced by ``deserialize_state``.
    """
    if _has_db():
        data = _load_from_db(job_id, stage)
    else:
        data = _load_from_s3(job_id, stage)

    # The DB loader always sets these keys, so a ``.get`` default would never
    # apply there; ``or`` also covers a stored null instead of crashing on
    # ``None.items()``.
    raw_manifest = data.get("frames_manifest") or {}
    # A JSON round-trip turns the integer manifest keys into strings.
    manifest = {int(k): v for k, v in raw_manifest.items()}
    frame_metadata = data.get("frames_meta") or []
    frames = load_frames(manifest, frame_metadata)

    state = deserialize_state(data, frames)

    logger.info("Checkpoint loaded: %s/%s (%d frames)", job_id, stage, len(frames))
    return state
|
|
|
|
|
|
def _load_from_db(job_id: str, stage: str) -> dict:
    """Fetch a checkpoint record from Postgres and flatten it into a dict.

    The record comes from ``core.db.detect.get_stage_checkpoint``. Its
    ``job_id`` is converted to ``str``; every other field is copied as-is
    under the same name.

    Args:
        job_id: Identifier of the job the checkpoint belongs to.
        stage: Name of the stage to load.

    Returns:
        A dict with the checkpoint fields.
    """
    from core.db.detect import get_stage_checkpoint

    # NOTE(review): behaviour when no checkpoint exists depends on
    # get_stage_checkpoint (raise vs. return None) -- confirm.
    record = get_stage_checkpoint(job_id, stage)

    # Fields copied verbatim, in the order they appear in the result.
    copied_fields = (
        "video_path",
        "profile_name",
        "config_overrides",
        "frames_manifest",
        "frames_meta",
        "filtered_frame_sequences",
        "boxes_by_frame",
        "text_candidates",
        "unresolved_candidates",
        "detections",
        "stats",
    )

    payload = {"job_id": str(record.job_id)}
    for field_name in copied_fields:
        payload[field_name] = getattr(record, field_name)
    return payload
|
|
|
|
|
|
def _load_from_s3(job_id: str, stage: str) -> dict:
    """Fallback loader: read a checkpoint stored as JSON in S3.

    Downloads the object to a local temporary path with
    ``core.storage.s3.download_to_temp``, parses it as JSON and removes the
    local copy afterwards, even if parsing fails.

    Args:
        job_id: Identifier of the job the checkpoint belongs to.
        stage: Name of the stage to load.

    Returns:
        The parsed JSON document.
    """
    from core.storage.s3 import download_to_temp

    object_key = f"{CHECKPOINT_PREFIX}/{job_id}/stages/{stage}.json"
    local_copy = download_to_temp(BUCKET, object_key)
    try:
        with open(local_copy) as handle:
            return json.load(handle)
    finally:
        # The downloaded copy is only needed for parsing.
        os.unlink(local_copy)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# List
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def list_checkpoints(job_id: str) -> list[str]:
    """Return the checkpoint stages available for a job.

    Uses Postgres when it is reachable, otherwise the S3 listing.

    Args:
        job_id: Identifier of the job to inspect.

    Returns:
        The stage names reported by the selected backend.
    """
    lister = _list_from_db if _has_db() else _list_from_s3
    return lister(job_id)
|
|
|
|
|
|
def _list_from_db(job_id: str) -> list[str]:
    """Return the result of ``core.db.detect.list_stage_checkpoints`` for the job."""
    from core.db.detect import list_stage_checkpoints

    stage_names = list_stage_checkpoints(job_id)
    return stage_names
|
|
|
|
|
|
def _list_from_s3(job_id: str) -> list[str]:
    """List checkpoint stages for a job from the S3 fallback location.

    Every object found under the job's ``stages/`` prefix contributes one
    entry: the file name of its key without the extension.

    Args:
        job_id: Identifier of the job to inspect.

    Returns:
        One stage name per listed object, in listing order.
    """
    from core.storage.s3 import list_objects

    stage_prefix = f"{CHECKPOINT_PREFIX}/{job_id}/stages/"
    return [
        Path(entry["key"]).stem
        for entry in list_objects(BUCKET, stage_prefix)
    ]
|