phase 10
This commit is contained in:
14
detect/checkpoint/__init__.py
Normal file
14
detect/checkpoint/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Stage checkpoint, replay, and retry.
|
||||
|
||||
detect/checkpoint/
|
||||
frames.py — frame image S3 upload/download
|
||||
serializer.py — state ↔ JSON conversion
|
||||
storage.py — checkpoint save/load/list (Postgres + S3)
|
||||
replay.py — replay_from, OverrideProfile
|
||||
tasks.py — retry_candidates Celery task
|
||||
"""
|
||||
|
||||
from .storage import save_checkpoint, load_checkpoint, list_checkpoints
|
||||
from .frames import save_frames, load_frames
|
||||
from .replay import replay_from, OverrideProfile
|
||||
80
detect/checkpoint/frames.py
Normal file
80
detect/checkpoint/frames.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Frame image storage — save/load to S3/MinIO as JPEGs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from detect.models import Frame
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BUCKET = os.environ.get("S3_BUCKET_OUT", "mpr-media-out")
|
||||
CHECKPOINT_PREFIX = "checkpoints"
|
||||
|
||||
|
||||
def save_frames(job_id: str, frames: list[Frame]) -> dict[int, str]:
    """
    Persist each frame image to S3 as a JPEG under the job's checkpoint prefix.

    Each frame is encoded via Pillow (quality 85), written to a temp file,
    uploaded, and the temp file removed afterwards.

    Returns manifest: {sequence: s3_key}
    """
    from core.storage.s3 import upload_file

    manifest: dict[int, str] = {}

    for fr in frames:
        s3_key = f"{CHECKPOINT_PREFIX}/{job_id}/frames/{fr.sequence}.jpg"

        # Encode to a temp file first; upload_file takes a path on disk.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as handle:
            Image.fromarray(fr.image).save(handle, format="JPEG", quality=85)
            local_path = handle.name

        try:
            upload_file(local_path, BUCKET, s3_key)
        finally:
            # Always remove the temp file, even when the upload fails.
            os.unlink(local_path)

        manifest[fr.sequence] = s3_key

    logger.info("Saved %d frames to s3://%s/%s/%s/frames/",
                len(frames), BUCKET, CHECKPOINT_PREFIX, job_id)
    return manifest
|
||||
|
||||
|
||||
def load_frames(manifest: dict[int, str], frame_metadata: list[dict]) -> list[Frame]:
    """
    Download frame JPEGs from S3 and rebuild Frame objects.

    frame_metadata: list of dicts with sequence, chunk_id, timestamp, perceptual_hash.
    Frames are returned sorted by sequence number.
    """
    from core.storage.s3 import download_to_temp

    meta_by_seq = {entry["sequence"]: entry for entry in frame_metadata}
    result: list[Frame] = []

    for sequence, s3_key in manifest.items():
        local_path = download_to_temp(BUCKET, s3_key)
        try:
            # Normalize to RGB so the array layout is consistent regardless of
            # how the JPEG was stored.
            pixels = np.array(Image.open(local_path).convert("RGB"))
        finally:
            os.unlink(local_path)

        meta = meta_by_seq.get(sequence, {})
        result.append(
            Frame(
                sequence=sequence,
                chunk_id=meta.get("chunk_id", 0),
                timestamp=meta.get("timestamp", 0.0),
                image=pixels,
                perceptual_hash=meta.get("perceptual_hash", ""),
            )
        )

    result.sort(key=lambda fr: fr.sequence)
    return result
|
||||
132
detect/checkpoint/replay.py
Normal file
132
detect/checkpoint/replay.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""
|
||||
Pipeline replay — re-run from any stage with different config.
|
||||
|
||||
Loads a checkpoint, applies config overrides, builds a subgraph
|
||||
starting from the target stage, and invokes it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import uuid
|
||||
|
||||
from detect import emit
|
||||
from detect.checkpoint import load_checkpoint, list_checkpoints
|
||||
from detect.graph import NODES, build_graph
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OverrideProfile:
    """
    Proxy around a ContentTypeProfile that applies per-stage config overrides.

    Attributes not defined on this proxy fall through to the wrapped profile;
    the ``*_config()`` methods return the base config object with any matching
    override values applied in place.

    Override dict structure:
        {
            "frame_extraction": {"fps": 1.0},
            "scene_filter": {"hamming_threshold": 12},
            "detection": {"confidence_threshold": 0.5},
            "ocr": {"languages": ["en", "es"], "min_confidence": 0.3},
            "resolver": {"fuzzy_threshold": 60},
        }
    """

    def __init__(self, base, overrides: dict):
        self._base = base
        self._overrides = overrides

    def __getattr__(self, name):
        # Only invoked for names missing on the proxy — delegate to the base.
        return getattr(self._base, name)

    def _patch(self, config, key: str):
        # Apply this stage's overrides; keys the config object does not
        # already define are silently skipped.
        for attr, value in self._overrides.get(key, {}).items():
            if hasattr(config, attr):
                setattr(config, attr, value)
        return config

    def frame_extraction_config(self):
        return self._patch(self._base.frame_extraction_config(), "frame_extraction")

    def scene_filter_config(self):
        return self._patch(self._base.scene_filter_config(), "scene_filter")

    def detection_config(self):
        return self._patch(self._base.detection_config(), "detection")

    def ocr_config(self):
        return self._patch(self._base.ocr_config(), "ocr")

    def resolver_config(self):
        return self._patch(self._base.resolver_config(), "resolver")

    def vlm_prompt(self, crop_context):
        return self._base.vlm_prompt(crop_context)

    def aggregate(self, detections):
        return self._base.aggregate(detections)

    def auxiliary_detections(self, source):
        return self._base.auxiliary_detections(source)
|
||||
|
||||
|
||||
def replay_from(
    job_id: str,
    start_stage: str,
    config_overrides: dict | None = None,
    checkpoint: bool = True,
) -> dict:
    """
    Replay the pipeline from a specific stage.

    Loads the checkpoint written by the stage immediately preceding
    start_stage, applies config overrides, then runs the subgraph from
    start_stage to the end.

    Returns the final state dict.
    """
    if start_stage not in NODES:
        raise ValueError(f"Unknown stage: {start_stage!r}. Options: {NODES}")

    idx = NODES.index(start_stage)

    # The first stage has no predecessor checkpoint to resume from.
    if idx == 0:
        raise ValueError("Cannot replay from the first stage — just run the full pipeline")

    prior = NODES[idx - 1]

    stored = list_checkpoints(job_id)
    if prior not in stored:
        raise ValueError(
            f"No checkpoint for stage {prior!r} (job {job_id}). "
            f"Available: {stored}"
        )

    logger.info("Replaying job %s from %s (loading checkpoint: %s)",
                job_id, start_stage, prior)

    state = load_checkpoint(job_id, prior)

    # Stash overrides in state so downstream nodes can pick them up.
    if config_overrides:
        state["config_overrides"] = config_overrides

    # Tag SSE events with a fresh short run id linked back to the parent job.
    run_id = str(uuid.uuid4())[:8]
    emit.set_run_context(
        run_id=run_id,
        parent_job_id=job_id,
        run_type="replay",
    )

    # Compile only the tail of the graph, beginning at start_stage.
    pipeline = build_graph(checkpoint=checkpoint, start_from=start_stage).compile()

    try:
        return pipeline.invoke(state)
    finally:
        emit.clear_run_context()
|
||||
133
detect/checkpoint/serializer.py
Normal file
133
detect/checkpoint/serializer.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""State serialization — DetectState ↔ JSON-compatible dict."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
|
||||
from detect.models import (
|
||||
BoundingBox,
|
||||
BrandDetection,
|
||||
Frame,
|
||||
PipelineStats,
|
||||
TextCandidate,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Serialize helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def serialize_frame_meta(frame: Frame) -> dict:
    """Extract a Frame's JSON-safe metadata fields (the image is excluded)."""
    return {
        "sequence": frame.sequence,
        "chunk_id": frame.chunk_id,
        "timestamp": frame.timestamp,
        "perceptual_hash": frame.perceptual_hash,
    }
|
||||
|
||||
|
||||
def serialize_text_candidate(tc: TextCandidate) -> dict:
    """
    Flatten a TextCandidate to JSON-safe form.

    The heavyweight Frame reference is reduced to its sequence number;
    the bbox dataclass becomes a plain dict.
    """
    return {
        "frame_sequence": tc.frame.sequence,
        "bbox": dataclasses.asdict(tc.bbox),
        "text": tc.text,
        "ocr_confidence": tc.ocr_confidence,
    }
|
||||
|
||||
|
||||
def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
    """
    Serialize DetectState to a JSON-compatible dict.

    Frame images are replaced with S3 key references.
    TextCandidate.frame references become frame_sequence integers.
    Dict keys that must survive a JSON round-trip (manifest sequences,
    boxes_by_frame sequences) are stringified here and re-int'ed on load.
    """
    frames = state.get("frames", [])
    filtered = state.get("filtered_frames", [])

    # JSON object keys must be strings; load_checkpoint converts them back.
    manifest_strs = {str(k): v for k, v in frames_manifest.items()}
    frames_meta = [serialize_frame_meta(f) for f in frames]
    filtered_seqs = [f.sequence for f in filtered]

    boxes_serialized = {
        str(seq): [dataclasses.asdict(b) for b in boxes]
        for seq, boxes in state.get("boxes_by_frame", {}).items()
    }

    text_candidates = [serialize_text_candidate(tc) for tc in state.get("text_candidates", [])]
    unresolved = [serialize_text_candidate(tc) for tc in state.get("unresolved_candidates", [])]
    detections = [dataclasses.asdict(d) for d in state.get("detections", [])]

    # Build the default stats lazily: the previous form,
    # `state.get("stats", PipelineStats())`, constructed (and discarded) a
    # PipelineStats on every call and raised on an explicit stats=None.
    stats_obj = state.get("stats")
    if stats_obj is None:
        stats_obj = PipelineStats()
    stats = dataclasses.asdict(stats_obj)

    return {
        "job_id": state.get("job_id", ""),
        "video_path": state.get("video_path", ""),
        "profile_name": state.get("profile_name", ""),
        "config_overrides": state.get("config_overrides", {}),
        "frames_manifest": manifest_strs,
        "frames_meta": frames_meta,
        "filtered_frame_sequences": filtered_seqs,
        "boxes_by_frame": boxes_serialized,
        "text_candidates": text_candidates,
        "unresolved_candidates": unresolved,
        "detections": detections,
        "stats": stats,
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deserialize helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def deserialize_text_candidate(d: dict, frame_map: dict[int, Frame]) -> TextCandidate:
    """
    Rebuild a TextCandidate from its serialized dict.

    The stored frame_sequence is resolved back to a live Frame via frame_map;
    raises KeyError if the referenced frame is missing.
    """
    return TextCandidate(
        frame=frame_map[d["frame_sequence"]],
        bbox=BoundingBox(**d["bbox"]),
        text=d["text"],
        ocr_confidence=d["ocr_confidence"],
    )
|
||||
|
||||
|
||||
def deserialize_state(checkpoint: dict, frames: list[Frame]) -> dict:
    """Reconstitute DetectState from a checkpoint dict + loaded frames."""
    frame_map = {fr.sequence: fr for fr in frames}

    keep = set(checkpoint.get("filtered_frame_sequences", []))
    filtered_frames = [fr for fr in frames if fr.sequence in keep]

    # The JSON round-trip stringified the sequence keys; restore ints.
    boxes_by_frame = {
        int(seq_str): [BoundingBox(**b) for b in box_dicts]
        for seq_str, box_dicts in checkpoint.get("boxes_by_frame", {}).items()
    }

    return {
        "job_id": checkpoint.get("job_id", ""),
        "video_path": checkpoint.get("video_path", ""),
        "profile_name": checkpoint.get("profile_name", ""),
        "config_overrides": checkpoint.get("config_overrides", {}),
        "frames": frames,
        "filtered_frames": filtered_frames,
        "boxes_by_frame": boxes_by_frame,
        "text_candidates": [
            deserialize_text_candidate(d, frame_map)
            for d in checkpoint.get("text_candidates", [])
        ],
        "unresolved_candidates": [
            deserialize_text_candidate(d, frame_map)
            for d in checkpoint.get("unresolved_candidates", [])
        ],
        "detections": [BrandDetection(**d) for d in checkpoint.get("detections", [])],
        "stats": PipelineStats(**checkpoint.get("stats", {})),
    }
|
||||
215
detect/checkpoint/storage.py
Normal file
215
detect/checkpoint/storage.py
Normal file
@@ -0,0 +1,215 @@
|
||||
"""
|
||||
Checkpoint storage — save/load stage state.
|
||||
|
||||
Binary data (frame images) → S3/MinIO via frames.py
|
||||
Structured data (boxes, detections, stats, config) → Postgres via Django ORM
|
||||
|
||||
Until the Django model is generated by modelgen, checkpoint data is stored
|
||||
as JSON in S3 as a fallback. Once DetectJob/StageCheckpoint models exist,
|
||||
this module switches to Postgres.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from .frames import save_frames, load_frames, BUCKET, CHECKPOINT_PREFIX
|
||||
from .serializer import serialize_state, deserialize_state
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _has_db() -> bool:
|
||||
"""Check if the DB layer is available (Django + models generated by modelgen)."""
|
||||
try:
|
||||
from core.db.detect import get_stage_checkpoint as _
|
||||
# Quick check that the model exists (modelgen may not have run yet)
|
||||
from admin.mpr.media_assets.models import StageCheckpoint as _
|
||||
return True
|
||||
except (ImportError, Exception):
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Save
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def save_checkpoint(
    job_id: str,
    stage: str,
    stage_index: int,
    state: dict,
    frames_manifest: dict[int, str] | None = None,
) -> str:
    """
    Save a stage checkpoint.

    Frame images go to S3 first (skipped when the caller supplies a manifest
    from an earlier save), then the structured state goes to Postgres — or to
    S3 as JSON while the DB models are unavailable.

    Returns the checkpoint identifier (DB id or S3 key).
    """
    # Upload frame images only when the caller hasn't done so already.
    if frames_manifest is None:
        frames_manifest = save_frames(job_id, state.get("frames", []))

    payload = serialize_state(state, frames_manifest)

    if _has_db():
        return _save_to_db(job_id, stage, stage_index, payload)
    return _save_to_s3(job_id, stage, payload)
|
||||
|
||||
|
||||
def _save_to_db(job_id: str, stage: str, stage_index: int, data: dict) -> str:
    """
    Save checkpoint structured data to Postgres via core/db.

    Generates a fresh checkpoint UUID, maps the serialized state dict onto
    the StageCheckpoint columns, and returns the new row id as a string.
    """
    import uuid
    from core.db.detect import save_stage_checkpoint

    # Accept either a UUID string or an actual UUID object for the job id.
    job_uuid = uuid.UUID(job_id) if isinstance(job_id, str) else job_id
    checkpoint_id = uuid.uuid4()
    frames_prefix = f"{CHECKPOINT_PREFIX}/{job_id}/frames/"

    checkpoint = save_stage_checkpoint(
        id=checkpoint_id,
        job_id=job_uuid,
        stage=stage,
        stage_index=stage_index,
        frames_prefix=frames_prefix,
        frames_manifest=data.get("frames_manifest", {}),
        frames_meta=data.get("frames_meta", []),
        filtered_frame_sequences=data.get("filtered_frame_sequences", []),
        boxes_by_frame=data.get("boxes_by_frame", {}),
        text_candidates=data.get("text_candidates", []),
        unresolved_candidates=data.get("unresolved_candidates", []),
        detections=data.get("detections", []),
        stats=data.get("stats", {}),
        # NOTE(review): config_snapshot and config_overrides are both filled
        # from "config_overrides" — presumably the snapshot should capture the
        # full effective config, not just the overrides. Confirm intent.
        config_snapshot=data.get("config_overrides", {}),
        config_overrides=data.get("config_overrides", {}),
        video_path=data.get("video_path", ""),
        profile_name=data.get("profile_name", ""),
    )

    logger.info("Checkpoint saved to DB: %s/%s (id=%s)", job_id, stage, checkpoint.id)
    return str(checkpoint.id)
|
||||
|
||||
|
||||
def _save_to_s3(job_id: str, stage: str, data: dict) -> str:
    """Fallback: save checkpoint as JSON to S3 (before modelgen generates DB models)."""
    from core.storage.s3 import upload_file

    # default=str stringifies any stray non-JSON values (UUIDs, datetimes).
    body = json.dumps(data, default=str)
    key = f"{CHECKPOINT_PREFIX}/{job_id}/stages/{stage}.json"

    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as handle:
        handle.write(body)
        local_path = handle.name

    try:
        upload_file(local_path, BUCKET, key)
    finally:
        # Remove the temp file even when the upload fails.
        os.unlink(local_path)

    logger.info("Checkpoint saved to S3: s3://%s/%s", BUCKET, key)
    return key
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Load
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def load_checkpoint(job_id: str, stage: str) -> dict:
    """
    Load a stage checkpoint and reconstitute full DetectState.

    Tries Postgres first, falls back to S3 JSON.
    """
    data = _load_from_db(job_id, stage) if _has_db() else _load_from_s3(job_id, stage)

    # Manifest keys were stringified for JSON; frame loading wants ints.
    manifest = {int(seq): key for seq, key in data.get("frames_manifest", {}).items()}
    frames = load_frames(manifest, data.get("frames_meta", []))

    state = deserialize_state(data, frames)

    logger.info("Checkpoint loaded: %s/%s (%d frames)", job_id, stage, len(frames))
    return state
|
||||
|
||||
|
||||
def _load_from_db(job_id: str, stage: str) -> dict:
    """Load checkpoint data from Postgres via core/db."""
    from core.db.detect import get_stage_checkpoint

    cp = get_stage_checkpoint(job_id, stage)

    # job_id is a UUID column; everything else maps one-to-one by name.
    data = {"job_id": str(cp.job_id)}
    for field in (
        "video_path",
        "profile_name",
        "config_overrides",
        "frames_manifest",
        "frames_meta",
        "filtered_frame_sequences",
        "boxes_by_frame",
        "text_candidates",
        "unresolved_candidates",
        "detections",
        "stats",
    ):
        data[field] = getattr(cp, field)
    return data
|
||||
|
||||
|
||||
def _load_from_s3(job_id: str, stage: str) -> dict:
    """Fallback: load checkpoint JSON from S3."""
    from core.storage.s3 import download_to_temp

    key = f"{CHECKPOINT_PREFIX}/{job_id}/stages/{stage}.json"
    local_path = download_to_temp(BUCKET, key)
    try:
        with open(local_path) as handle:
            return json.load(handle)
    finally:
        # Remove the downloaded temp file regardless of parse success.
        os.unlink(local_path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# List
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def list_checkpoints(job_id: str) -> list[str]:
    """List available checkpoint stages for a job."""
    return _list_from_db(job_id) if _has_db() else _list_from_s3(job_id)
|
||||
|
||||
|
||||
def _list_from_db(job_id: str) -> list[str]:
    """Stage names with a checkpoint row for this job in Postgres."""
    from core.db.detect import list_stage_checkpoints

    return list_stage_checkpoints(job_id)
|
||||
|
||||
|
||||
def _list_from_s3(job_id: str) -> list[str]:
    """Stage names derived from the checkpoint JSON objects under the job prefix."""
    from core.storage.s3 import list_objects

    prefix = f"{CHECKPOINT_PREFIX}/{job_id}/stages/"
    # Each object is <stage>.json; the filename stem is the stage name.
    return [Path(entry["key"]).stem for entry in list_objects(BUCKET, prefix)]
|
||||
71
detect/checkpoint/tasks.py
Normal file
71
detect/checkpoint/tasks.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
Celery tasks for detection pipeline async operations.
|
||||
|
||||
retry_candidates: re-run VLM/cloud escalation with different config.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from celery import shared_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(bind=True, max_retries=1, default_retry_delay=30)
def retry_candidates(
    self,
    job_id: str,
    config_overrides: dict | None = None,
    start_stage: str = "escalate_vlm",
):
    """
    Retry unresolved candidates with different config.

    Loads the checkpoint from the stage before start_stage,
    applies config overrides (e.g. different cloud provider),
    and runs from start_stage onward.
    """
    from detect.checkpoint.replay import replay_from

    run_id = str(uuid.uuid4())[:8]
    logger.info("Retry task %s: job=%s, from=%s, overrides=%s",
                run_id, job_id, start_stage, config_overrides)

    try:
        result = replay_from(
            job_id=job_id,
            start_stage=start_stage,
            config_overrides=config_overrides,
        )

        detections = result.get("detections", [])
        report = result.get("report")
        brands_found = len(report.brands) if report else 0

        logger.info("Retry %s complete: %d detections, %d brands",
                    run_id, len(detections), brands_found)

        return {
            "status": "completed",
            "run_id": run_id,
            "job_id": job_id,
            "detections": len(detections),
            "brands_found": brands_found,
        }

    except Exception as e:
        logger.exception("Retry %s failed: %s", run_id, e)

        # Let Celery retry once, then give up and report the failure in-band.
        if self.request.retries < self.max_retries:
            raise self.retry(exc=e)

        return {
            "status": "failed",
            "run_id": run_id,
            "job_id": job_id,
            "error": str(e),
        }
|
||||
Reference in New Issue
Block a user