This commit is contained in:
2026-03-26 04:24:32 -03:00
parent 08b67f2bb7
commit 08c58a6a9d
43 changed files with 2627 additions and 252 deletions

View File

@@ -0,0 +1,14 @@
"""
Stage checkpoint, replay, and retry.
detect/checkpoint/
frames.py — frame image S3 upload/download
serializer.py — state ↔ JSON conversion
storage.py — checkpoint save/load/list (Postgres + S3)
replay.py — replay_from, OverrideProfile
tasks.py — retry_candidates Celery task
"""
from .storage import save_checkpoint, load_checkpoint, list_checkpoints
from .frames import save_frames, load_frames
from .replay import replay_from, OverrideProfile

View File

@@ -0,0 +1,80 @@
"""Frame image storage — save/load to S3/MinIO as JPEGs."""
from __future__ import annotations
import logging
import os
import tempfile
import numpy as np
from PIL import Image
from detect.models import Frame
logger = logging.getLogger(__name__)
# Output bucket for pipeline artifacts; overridable via the S3_BUCKET_OUT env var.
BUCKET = os.environ.get("S3_BUCKET_OUT", "mpr-media-out")
# All checkpoint artifacts (frame JPEGs + stage state) are stored under this key prefix.
CHECKPOINT_PREFIX = "checkpoints"
def save_frames(job_id: str, frames: list[Frame]) -> dict[int, str]:
    """
    Save frame images to S3 as JPEGs.

    Each frame is encoded to a temporary JPEG on disk, uploaded under
    ``{CHECKPOINT_PREFIX}/{job_id}/frames/{sequence}.jpg``, and the temp
    file is removed.

    Returns the manifest mapping {sequence: s3_key}.
    """
    from core.storage.s3 import upload_file

    manifest: dict[int, str] = {}
    for fr in frames:
        s3_key = f"{CHECKPOINT_PREFIX}/{job_id}/frames/{fr.sequence}.jpg"
        # Write the JPEG to a temp file first; upload_file works on paths.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
            Image.fromarray(fr.image).save(tmp, format="JPEG", quality=85)
            local_path = tmp.name
        try:
            upload_file(local_path, BUCKET, s3_key)
        finally:
            # Always remove the temp file, even if the upload fails.
            os.unlink(local_path)
        manifest[fr.sequence] = s3_key
    logger.info("Saved %d frames to s3://%s/%s/%s/frames/",
                len(frames), BUCKET, CHECKPOINT_PREFIX, job_id)
    return manifest
def load_frames(manifest: dict[int, str], frame_metadata: list[dict]) -> list[Frame]:
    """
    Load frame images from S3 and reconstitute Frame objects.

    frame_metadata: list of dicts with sequence, chunk_id, timestamp,
    perceptual_hash. Metadata is matched to images by sequence number;
    missing entries fall back to zero/empty defaults.
    """
    from core.storage.s3 import download_to_temp

    by_sequence = {entry["sequence"]: entry for entry in frame_metadata}
    loaded: list[Frame] = []
    for seq, key in manifest.items():
        local_path = download_to_temp(BUCKET, key)
        try:
            pixels = np.array(Image.open(local_path).convert("RGB"))
        finally:
            # Remove the downloaded temp file even if decoding fails.
            os.unlink(local_path)
        info = by_sequence.get(seq, {})
        loaded.append(Frame(
            sequence=seq,
            chunk_id=info.get("chunk_id", 0),
            timestamp=info.get("timestamp", 0.0),
            image=pixels,
            perceptual_hash=info.get("perceptual_hash", ""),
        ))
    # Manifest iteration order is not guaranteed; return frames in sequence order.
    return sorted(loaded, key=lambda fr: fr.sequence)

132
detect/checkpoint/replay.py Normal file
View File

@@ -0,0 +1,132 @@
"""
Pipeline replay — re-run from any stage with different config.
Loads a checkpoint, applies config overrides, builds a subgraph
starting from the target stage, and invokes it.
"""
from __future__ import annotations
import logging
import uuid
from detect import emit
from detect.checkpoint import load_checkpoint, list_checkpoints
from detect.graph import NODES, build_graph
logger = logging.getLogger(__name__)
class OverrideProfile:
    """
    Wraps a ContentTypeProfile and patches config methods with overrides.

    Attributes and methods not listed below are delegated unchanged to the
    wrapped profile. Override dict structure:
    {
        "frame_extraction": {"fps": 1.0},
        "scene_filter": {"hamming_threshold": 12},
        "detection": {"confidence_threshold": 0.5},
        "ocr": {"languages": ["en", "es"], "min_confidence": 0.3},
        "resolver": {"fuzzy_threshold": 60},
    }
    """

    def __init__(self, base, overrides: dict):
        self._base = base
        self._overrides = overrides

    def __getattr__(self, name):
        # Fall through to the wrapped profile for anything we don't intercept.
        return getattr(self._base, name)

    def _patch(self, config, key: str):
        # Apply only overrides that name an attribute the config actually has;
        # unknown keys are silently ignored.
        for attr, value in self._overrides.get(key, {}).items():
            if hasattr(config, attr):
                setattr(config, attr, value)
        return config

    def frame_extraction_config(self):
        return self._patch(self._base.frame_extraction_config(), "frame_extraction")

    def scene_filter_config(self):
        return self._patch(self._base.scene_filter_config(), "scene_filter")

    def detection_config(self):
        return self._patch(self._base.detection_config(), "detection")

    def ocr_config(self):
        return self._patch(self._base.ocr_config(), "ocr")

    def resolver_config(self):
        return self._patch(self._base.resolver_config(), "resolver")

    def vlm_prompt(self, crop_context):
        return self._base.vlm_prompt(crop_context)

    def aggregate(self, detections):
        return self._base.aggregate(detections)

    def auxiliary_detections(self, source):
        return self._base.auxiliary_detections(source)
def replay_from(
    job_id: str,
    start_stage: str,
    config_overrides: dict | None = None,
    checkpoint: bool = True,
) -> dict:
    """
    Replay the pipeline from a specific stage.

    Loads the checkpoint written by the stage immediately before
    ``start_stage``, applies ``config_overrides`` to the state, and runs
    the subgraph from ``start_stage`` onward.

    Returns the final state dict produced by the subgraph.

    Raises ValueError when the stage is unknown, is the first stage, or the
    preceding stage has no saved checkpoint for this job.
    """
    if start_stage not in NODES:
        raise ValueError(f"Unknown stage: {start_stage!r}. Options: {NODES}")
    position = NODES.index(start_stage)
    if position == 0:
        # There is nothing before the first stage to restore from.
        raise ValueError("Cannot replay from the first stage — just run the full pipeline")
    previous_stage = NODES[position - 1]
    available = list_checkpoints(job_id)
    if previous_stage not in available:
        raise ValueError(
            f"No checkpoint for stage {previous_stage!r} (job {job_id}). "
            f"Available: {available}"
        )
    logger.info("Replaying job %s from %s (loading checkpoint: %s)",
                job_id, start_stage, previous_stage)
    state = load_checkpoint(job_id, previous_stage)
    if config_overrides:
        state["config_overrides"] = config_overrides
    # Tag SSE events with a short run id so this replay is distinguishable
    # from the original pipeline execution.
    run_id = str(uuid.uuid4())[:8]
    emit.set_run_context(
        run_id=run_id,
        parent_job_id=job_id,
        run_type="replay",
    )
    pipeline = build_graph(checkpoint=checkpoint, start_from=start_stage).compile()
    try:
        return pipeline.invoke(state)
    finally:
        # Always clear the run context, even if the subgraph fails.
        emit.clear_run_context()

View File

@@ -0,0 +1,133 @@
"""State serialization — DetectState ↔ JSON-compatible dict."""
from __future__ import annotations
import dataclasses
from detect.models import (
BoundingBox,
BrandDetection,
Frame,
PipelineStats,
TextCandidate,
)
# ---------------------------------------------------------------------------
# Serialize helpers
# ---------------------------------------------------------------------------
def serialize_frame_meta(frame: Frame) -> dict:
    """Extract a Frame's JSON-safe metadata (everything except the image)."""
    return {
        "sequence": frame.sequence,
        "chunk_id": frame.chunk_id,
        "timestamp": frame.timestamp,
        "perceptual_hash": frame.perceptual_hash,
    }
def serialize_text_candidate(tc: TextCandidate) -> dict:
    """Serialize a TextCandidate; its Frame reference collapses to the sequence number."""
    return {
        "frame_sequence": tc.frame.sequence,
        "bbox": dataclasses.asdict(tc.bbox),
        "text": tc.text,
        "ocr_confidence": tc.ocr_confidence,
    }
def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
    """
    Serialize DetectState to a JSON-compatible dict.

    Frame images are replaced with S3 key references (frames_manifest).
    TextCandidate.frame references become frame_sequence integers.
    Integer dict keys are stringified so the payload survives a JSON round trip.
    """
    frames = state.get("frames", [])
    filtered = state.get("filtered_frames", [])
    manifest_strs = {str(k): v for k, v in frames_manifest.items()}
    frames_meta = [serialize_frame_meta(f) for f in frames]
    filtered_seqs = [f.sequence for f in filtered]
    boxes_serialized = {
        str(seq): [dataclasses.asdict(b) for b in boxes]
        for seq, boxes in state.get("boxes_by_frame", {}).items()
    }
    text_candidates = [serialize_text_candidate(tc) for tc in state.get("text_candidates", [])]
    unresolved = [serialize_text_candidate(tc) for tc in state.get("unresolved_candidates", [])]
    detections = [dataclasses.asdict(d) for d in state.get("detections", [])]
    # Build the default lazily: dict.get evaluates its default eagerly, so the
    # original form constructed a throwaway PipelineStats on every call.
    stats_obj = state.get("stats")
    stats = dataclasses.asdict(stats_obj if stats_obj is not None else PipelineStats())
    return {
        "job_id": state.get("job_id", ""),
        "video_path": state.get("video_path", ""),
        "profile_name": state.get("profile_name", ""),
        "config_overrides": state.get("config_overrides", {}),
        "frames_manifest": manifest_strs,
        "frames_meta": frames_meta,
        "filtered_frame_sequences": filtered_seqs,
        "boxes_by_frame": boxes_serialized,
        "text_candidates": text_candidates,
        "unresolved_candidates": unresolved,
        "detections": detections,
        "stats": stats,
    }
# ---------------------------------------------------------------------------
# Deserialize helpers
# ---------------------------------------------------------------------------
def deserialize_text_candidate(d: dict, frame_map: dict[int, Frame]) -> TextCandidate:
    """Rebuild a TextCandidate; frame_sequence is resolved to a Frame via frame_map."""
    return TextCandidate(
        frame=frame_map[d["frame_sequence"]],
        bbox=BoundingBox(**d["bbox"]),
        text=d["text"],
        ocr_confidence=d["ocr_confidence"],
    )
def deserialize_state(checkpoint: dict, frames: list[Frame]) -> dict:
    """Reconstitute DetectState from a checkpoint dict + loaded frames.

    The inverse of serialize_state: string keys go back to ints, sequence
    references go back to Frame objects, and plain dicts go back to dataclasses.
    """
    frame_map = {f.sequence: f for f in frames}
    keep = set(checkpoint.get("filtered_frame_sequences", []))
    # JSON object keys are strings; restore the int sequence keys.
    boxes_by_frame = {
        int(seq_str): [BoundingBox(**b) for b in box_dicts]
        for seq_str, box_dicts in checkpoint.get("boxes_by_frame", {}).items()
    }
    return {
        "job_id": checkpoint.get("job_id", ""),
        "video_path": checkpoint.get("video_path", ""),
        "profile_name": checkpoint.get("profile_name", ""),
        "config_overrides": checkpoint.get("config_overrides", {}),
        "frames": frames,
        "filtered_frames": [f for f in frames if f.sequence in keep],
        "boxes_by_frame": boxes_by_frame,
        "text_candidates": [
            deserialize_text_candidate(d, frame_map)
            for d in checkpoint.get("text_candidates", [])
        ],
        "unresolved_candidates": [
            deserialize_text_candidate(d, frame_map)
            for d in checkpoint.get("unresolved_candidates", [])
        ],
        "detections": [BrandDetection(**d) for d in checkpoint.get("detections", [])],
        "stats": PipelineStats(**checkpoint.get("stats", {})),
    }

View File

@@ -0,0 +1,215 @@
"""
Checkpoint storage — save/load stage state.
Binary data (frame images) → S3/MinIO via frames.py
Structured data (boxes, detections, stats, config) → Postgres via Django ORM
Until the Django model is generated by modelgen, checkpoint data is stored
as JSON in S3 as a fallback. Once DetectJob/StageCheckpoint models exist,
this module switches to Postgres.
"""
from __future__ import annotations
import json
import logging
import os
import tempfile
from pathlib import Path
from .frames import save_frames, load_frames, BUCKET, CHECKPOINT_PREFIX
from .serializer import serialize_state, deserialize_state
logger = logging.getLogger(__name__)
def _has_db() -> bool:
"""Check if the DB layer is available (Django + models generated by modelgen)."""
try:
from core.db.detect import get_stage_checkpoint as _
# Quick check that the model exists (modelgen may not have run yet)
from admin.mpr.media_assets.models import StageCheckpoint as _
return True
except (ImportError, Exception):
return False
# ---------------------------------------------------------------------------
# Save
# ---------------------------------------------------------------------------
def save_checkpoint(
    job_id: str,
    stage: str,
    stage_index: int,
    state: dict,
    frames_manifest: dict[int, str] | None = None,
) -> str:
    """
    Save a stage checkpoint.

    Frame images go to S3 first (unless the caller already uploaded them and
    passed a manifest); the structured state then goes to Postgres, or to an
    S3 JSON document when the DB models are not available yet.

    Returns the checkpoint identifier (DB id or S3 key).
    """
    if frames_manifest is None:
        frames_manifest = save_frames(job_id, state.get("frames", []))
    payload = serialize_state(state, frames_manifest)
    if _has_db():
        return _save_to_db(job_id, stage, stage_index, payload)
    return _save_to_s3(job_id, stage, payload)
def _save_to_db(job_id: str, stage: str, stage_index: int, data: dict) -> str:
    """Save checkpoint structured data to Postgres via core/db.

    Returns the new checkpoint row's id as a string.
    """
    import uuid
    from core.db.detect import save_stage_checkpoint

    job_uuid = job_id if not isinstance(job_id, str) else uuid.UUID(job_id)
    record = save_stage_checkpoint(
        id=uuid.uuid4(),
        job_id=job_uuid,
        stage=stage,
        stage_index=stage_index,
        frames_prefix=f"{CHECKPOINT_PREFIX}/{job_id}/frames/",
        frames_manifest=data.get("frames_manifest", {}),
        frames_meta=data.get("frames_meta", []),
        filtered_frame_sequences=data.get("filtered_frame_sequences", []),
        boxes_by_frame=data.get("boxes_by_frame", {}),
        text_candidates=data.get("text_candidates", []),
        unresolved_candidates=data.get("unresolved_candidates", []),
        detections=data.get("detections", []),
        stats=data.get("stats", {}),
        config_snapshot=data.get("config_overrides", {}),
        config_overrides=data.get("config_overrides", {}),
        video_path=data.get("video_path", ""),
        profile_name=data.get("profile_name", ""),
    )
    logger.info("Checkpoint saved to DB: %s/%s (id=%s)", job_id, stage, record.id)
    return str(record.id)
def _save_to_s3(job_id: str, stage: str, data: dict) -> str:
    """Fallback: save checkpoint as JSON to S3 (before modelgen generates DB models).

    Returns the S3 object key the checkpoint was written to.
    """
    from core.storage.s3 import upload_file

    key = f"{CHECKPOINT_PREFIX}/{job_id}/stages/{stage}.json"
    # upload_file works on paths, so stage the JSON in a temp file first.
    # default=str keeps non-JSON-native values (UUIDs, datetimes) serializable.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp:
        tmp.write(json.dumps(data, default=str))
        tmp_path = tmp.name
    try:
        upload_file(tmp_path, BUCKET, key)
    finally:
        os.unlink(tmp_path)
    logger.info("Checkpoint saved to S3: s3://%s/%s", BUCKET, key)
    return key
# ---------------------------------------------------------------------------
# Load
# ---------------------------------------------------------------------------
def load_checkpoint(job_id: str, stage: str) -> dict:
    """
    Load a stage checkpoint and reconstitute full DetectState.
    Tries Postgres first, falls back to S3 JSON.
    """
    data = _load_from_db(job_id, stage) if _has_db() else _load_from_s3(job_id, stage)
    # Stored manifest keys are strings; the in-memory manifest is keyed by int sequence.
    manifest = {int(seq): key for seq, key in data.get("frames_manifest", {}).items()}
    frames = load_frames(manifest, data.get("frames_meta", []))
    state = deserialize_state(data, frames)
    logger.info("Checkpoint loaded: %s/%s (%d frames)", job_id, stage, len(frames))
    return state
def _load_from_db(job_id: str, stage: str) -> dict:
    """Load checkpoint data from Postgres via core/db.

    Returns a plain dict mirroring the JSON checkpoint layout.
    """
    from core.db.detect import get_stage_checkpoint

    cp = get_stage_checkpoint(job_id, stage)
    return {
        "job_id": str(cp.job_id),
        "video_path": cp.video_path,
        "profile_name": cp.profile_name,
        "config_overrides": cp.config_overrides,
        "frames_manifest": cp.frames_manifest,
        "frames_meta": cp.frames_meta,
        "filtered_frame_sequences": cp.filtered_frame_sequences,
        "boxes_by_frame": cp.boxes_by_frame,
        "text_candidates": cp.text_candidates,
        "unresolved_candidates": cp.unresolved_candidates,
        "detections": cp.detections,
        "stats": cp.stats,
    }
def _load_from_s3(job_id: str, stage: str) -> dict:
    """Fallback: load checkpoint JSON from S3."""
    from core.storage.s3 import download_to_temp

    key = f"{CHECKPOINT_PREFIX}/{job_id}/stages/{stage}.json"
    tmp_path = download_to_temp(BUCKET, key)
    try:
        with open(tmp_path) as fh:
            return json.load(fh)
    finally:
        # Remove the downloaded temp file even if parsing fails.
        os.unlink(tmp_path)
# ---------------------------------------------------------------------------
# List
# ---------------------------------------------------------------------------
def list_checkpoints(job_id: str) -> list[str]:
    """List available checkpoint stages for a job (Postgres when available, else S3)."""
    return _list_from_db(job_id) if _has_db() else _list_from_s3(job_id)
def _list_from_db(job_id: str) -> list[str]:
    """List checkpoint stage names for a job from Postgres via core/db."""
    from core.db.detect import list_stage_checkpoints
    return list_stage_checkpoints(job_id)
def _list_from_s3(job_id: str) -> list[str]:
    """Fallback: list checkpoint stage names by scanning S3 keys for the job.

    Each object key looks like ``{CHECKPOINT_PREFIX}/{job_id}/stages/<stage>.json``;
    the filename stem is the stage name.
    """
    from core.storage.s3 import list_objects

    prefix = f"{CHECKPOINT_PREFIX}/{job_id}/stages/"
    # Comprehension instead of the manual append loop (same result, less noise).
    return [Path(obj["key"]).stem for obj in list_objects(BUCKET, prefix)]

View File

@@ -0,0 +1,71 @@
"""
Celery tasks for detection pipeline async operations.
retry_candidates: re-run VLM/cloud escalation with different config.
"""
from __future__ import annotations
import logging
import uuid
from datetime import datetime, timezone
from celery import shared_task
logger = logging.getLogger(__name__)
@shared_task(bind=True, max_retries=1, default_retry_delay=30)
def retry_candidates(
    self,
    job_id: str,
    config_overrides: dict | None = None,
    start_stage: str = "escalate_vlm",
):
    """
    Retry unresolved candidates with different config.

    Loads the checkpoint from the stage before ``start_stage``, applies
    config overrides (e.g. different cloud provider), and runs from
    ``start_stage`` onward.

    Returns a summary dict: status/run_id/job_id plus either detection
    counts (completed) or the error message (failed). Retries once via
    Celery before reporting failure.
    """
    from detect.checkpoint.replay import replay_from

    # Short id to correlate this retry's log lines.
    run_id = str(uuid.uuid4())[:8]
    logger.info("Retry task %s: job=%s, from=%s, overrides=%s",
                run_id, job_id, start_stage, config_overrides)
    try:
        outcome = replay_from(
            job_id=job_id,
            start_stage=start_stage,
            config_overrides=config_overrides,
        )
        resolved = outcome.get("detections", [])
        report = outcome.get("report")
        brand_count = len(report.brands) if report else 0
        logger.info("Retry %s complete: %d detections, %d brands",
                    run_id, len(resolved), brand_count)
        return {
            "status": "completed",
            "run_id": run_id,
            "job_id": job_id,
            "detections": len(resolved),
            "brands_found": brand_count,
        }
    except Exception as exc:
        logger.exception("Retry %s failed: %s", run_id, exc)
        if self.request.retries < self.max_retries:
            # Re-queue through Celery's retry mechanism (honors default_retry_delay).
            raise self.retry(exc=exc)
        return {
            "status": "failed",
            "run_id": run_id,
            "job_id": job_id,
            "error": str(exc),
        }