This commit is contained in:
2026-03-26 04:24:32 -03:00
parent 08b67f2bb7
commit 08c58a6a9d
43 changed files with 2627 additions and 252 deletions

View File

@@ -0,0 +1,14 @@
"""
Stage checkpoint, replay, and retry.
detect/checkpoint/
frames.py — frame image S3 upload/download
serializer.py — state ↔ JSON conversion
storage.py — checkpoint save/load/list (Postgres + S3)
replay.py — replay_from, OverrideProfile
tasks.py — retry_candidates Celery task
"""
from .storage import save_checkpoint, load_checkpoint, list_checkpoints
from .frames import save_frames, load_frames
from .replay import replay_from, OverrideProfile

View File

@@ -0,0 +1,80 @@
"""Frame image storage — save/load to S3/MinIO as JPEGs."""
from __future__ import annotations
import logging
import os
import tempfile
import numpy as np
from PIL import Image
from detect.models import Frame
logger = logging.getLogger(__name__)
BUCKET = os.environ.get("S3_BUCKET_OUT", "mpr-media-out")
CHECKPOINT_PREFIX = "checkpoints"
def save_frames(job_id: str, frames: list[Frame]) -> dict[int, str]:
    """
    Persist frame images to S3 as JPEG files.

    Each frame is encoded as a quality-85 JPEG and uploaded under
    ``{CHECKPOINT_PREFIX}/{job_id}/frames/{sequence}.jpg``.

    Returns a manifest mapping frame sequence number -> S3 key.
    """
    from core.storage.s3 import upload_file

    manifest: dict[int, str] = {}
    for fr in frames:
        s3_key = f"{CHECKPOINT_PREFIX}/{job_id}/frames/{fr.sequence}.jpg"
        # Encode to a local temp file first — upload_file takes a path.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
            Image.fromarray(fr.image).save(tmp, format="JPEG", quality=85)
            local_path = tmp.name
        try:
            upload_file(local_path, BUCKET, s3_key)
        finally:
            # Always remove the temp file, even if the upload fails.
            os.unlink(local_path)
        manifest[fr.sequence] = s3_key
    logger.info("Saved %d frames to s3://%s/%s/%s/frames/",
                len(frames), BUCKET, CHECKPOINT_PREFIX, job_id)
    return manifest
def load_frames(manifest: dict[int, str], frame_metadata: list[dict]) -> list[Frame]:
    """
    Download frame JPEGs from S3 and rebuild Frame objects.

    frame_metadata: list of dicts with sequence, chunk_id, timestamp,
    perceptual_hash — joined to the images in ``manifest`` by sequence.
    Returns the frames sorted by sequence number.
    """
    from core.storage.s3 import download_to_temp

    meta_by_seq = {entry["sequence"]: entry for entry in frame_metadata}
    loaded: list[Frame] = []
    for sequence, s3_key in manifest.items():
        local_path = download_to_temp(BUCKET, s3_key)
        try:
            pixels = np.array(Image.open(local_path).convert("RGB"))
        finally:
            os.unlink(local_path)
        # Missing metadata degrades to neutral defaults rather than failing.
        meta = meta_by_seq.get(sequence, {})
        loaded.append(Frame(
            sequence=sequence,
            chunk_id=meta.get("chunk_id", 0),
            timestamp=meta.get("timestamp", 0.0),
            image=pixels,
            perceptual_hash=meta.get("perceptual_hash", ""),
        ))
    return sorted(loaded, key=lambda fr: fr.sequence)

132
detect/checkpoint/replay.py Normal file
View File

@@ -0,0 +1,132 @@
"""
Pipeline replay — re-run from any stage with different config.
Loads a checkpoint, applies config overrides, builds a subgraph
starting from the target stage, and invokes it.
"""
from __future__ import annotations
import logging
import uuid
from detect import emit
from detect.checkpoint import load_checkpoint, list_checkpoints
from detect.graph import NODES, build_graph
logger = logging.getLogger(__name__)
class OverrideProfile:
    """
    Proxy around a ContentTypeProfile that applies per-stage config overrides.

    Override dict structure:
    {
        "frame_extraction": {"fps": 1.0},
        "scene_filter": {"hamming_threshold": 12},
        "detection": {"confidence_threshold": 0.5},
        "ocr": {"languages": ["en", "es"], "min_confidence": 0.3},
        "resolver": {"fuzzy_threshold": 60},
    }

    Only the config methods below are intercepted; every other attribute
    access is delegated untouched to the wrapped profile.
    """

    def __init__(self, base, overrides: dict):
        self._base = base
        self._overrides = overrides

    def __getattr__(self, name):
        # Called only for attributes not found on the proxy itself —
        # falls through to the wrapped profile.
        return getattr(self._base, name)

    def _patch(self, config, key: str):
        # Apply only overrides whose attribute already exists on the config
        # object, so unknown keys are silently ignored rather than creating
        # bogus attributes.
        for attr, value in self._overrides.get(key, {}).items():
            if hasattr(config, attr):
                setattr(config, attr, value)
        return config

    def frame_extraction_config(self):
        return self._patch(self._base.frame_extraction_config(), "frame_extraction")

    def scene_filter_config(self):
        return self._patch(self._base.scene_filter_config(), "scene_filter")

    def detection_config(self):
        return self._patch(self._base.detection_config(), "detection")

    def ocr_config(self):
        return self._patch(self._base.ocr_config(), "ocr")

    def resolver_config(self):
        return self._patch(self._base.resolver_config(), "resolver")

    def vlm_prompt(self, crop_context):
        return self._base.vlm_prompt(crop_context)

    def aggregate(self, detections):
        return self._base.aggregate(detections)

    def auxiliary_detections(self, source):
        return self._base.auxiliary_detections(source)
def replay_from(
    job_id: str,
    start_stage: str,
    config_overrides: dict | None = None,
    checkpoint: bool = True,
) -> dict:
    """
    Re-run the pipeline starting at ``start_stage``.

    Restores state from the checkpoint written by the stage immediately
    preceding ``start_stage``, optionally stores config overrides into the
    state, then compiles and invokes the subgraph from ``start_stage``
    onward. Returns the final state dict.

    Raises ValueError for an unknown stage, the first stage, or a missing
    checkpoint.
    """
    if start_stage not in NODES:
        raise ValueError(f"Unknown stage: {start_stage!r}. Options: {NODES}")
    stage_pos = NODES.index(start_stage)
    # The checkpoint we restore is the one written by the previous stage.
    if stage_pos == 0:
        raise ValueError("Cannot replay from the first stage — just run the full pipeline")
    prior = NODES[stage_pos - 1]
    saved = list_checkpoints(job_id)
    if prior not in saved:
        raise ValueError(
            f"No checkpoint for stage {prior!r} (job {job_id}). "
            f"Available: {saved}"
        )
    logger.info("Replaying job %s from %s (loading checkpoint: %s)",
                job_id, start_stage, prior)
    state = load_checkpoint(job_id, prior)
    if config_overrides:
        # Consumed downstream when the stage nodes build their profile.
        state["config_overrides"] = config_overrides
    # Short run id so replay events are distinguishable in the SSE stream.
    emit.set_run_context(
        run_id=str(uuid.uuid4())[:8],
        parent_job_id=job_id,
        run_type="replay",
    )
    pipeline = build_graph(checkpoint=checkpoint, start_from=start_stage).compile()
    try:
        return pipeline.invoke(state)
    finally:
        emit.clear_run_context()

View File

@@ -0,0 +1,133 @@
"""State serialization — DetectState ↔ JSON-compatible dict."""
from __future__ import annotations
import dataclasses
from detect.models import (
BoundingBox,
BrandDetection,
Frame,
PipelineStats,
TextCandidate,
)
# ---------------------------------------------------------------------------
# Serialize helpers
# ---------------------------------------------------------------------------
def serialize_frame_meta(frame: Frame) -> dict:
    """Extract the JSON-safe metadata fields of a frame (no image data)."""
    return {
        "sequence": frame.sequence,
        "chunk_id": frame.chunk_id,
        "timestamp": frame.timestamp,
        "perceptual_hash": frame.perceptual_hash,
    }
def serialize_text_candidate(tc: TextCandidate) -> dict:
    """Serialize a TextCandidate; the Frame reference collapses to its sequence."""
    return {
        "frame_sequence": tc.frame.sequence,
        "bbox": dataclasses.asdict(tc.bbox),
        "text": tc.text,
        "ocr_confidence": tc.ocr_confidence,
    }
def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
    """
    Convert DetectState into a JSON-compatible checkpoint dict.

    Frame images are not embedded — they are referenced via the S3 keys in
    ``frames_manifest``. TextCandidate.frame references are collapsed to
    frame sequence integers so the payload stays JSON-serializable.
    """
    all_frames = state.get("frames", [])
    kept_frames = state.get("filtered_frames", [])
    # JSON object keys must be strings; sequences are stringified here and
    # restored to ints on deserialization.
    boxes_json = {
        str(seq): [dataclasses.asdict(b) for b in boxes]
        for seq, boxes in state.get("boxes_by_frame", {}).items()
    }
    return {
        "job_id": state.get("job_id", ""),
        "video_path": state.get("video_path", ""),
        "profile_name": state.get("profile_name", ""),
        "config_overrides": state.get("config_overrides", {}),
        "frames_manifest": {str(seq): key for seq, key in frames_manifest.items()},
        "frames_meta": [serialize_frame_meta(f) for f in all_frames],
        "filtered_frame_sequences": [f.sequence for f in kept_frames],
        "boxes_by_frame": boxes_json,
        "text_candidates": [serialize_text_candidate(tc)
                            for tc in state.get("text_candidates", [])],
        "unresolved_candidates": [serialize_text_candidate(tc)
                                  for tc in state.get("unresolved_candidates", [])],
        "detections": [dataclasses.asdict(d) for d in state.get("detections", [])],
        "stats": dataclasses.asdict(state.get("stats", PipelineStats())),
    }
# ---------------------------------------------------------------------------
# Deserialize helpers
# ---------------------------------------------------------------------------
def deserialize_text_candidate(d: dict, frame_map: dict[int, Frame]) -> TextCandidate:
    """Rebuild a TextCandidate, resolving frame_sequence back to a Frame.

    Raises KeyError if the referenced frame is missing from frame_map.
    """
    return TextCandidate(
        frame=frame_map[d["frame_sequence"]],
        bbox=BoundingBox(**d["bbox"]),
        text=d["text"],
        ocr_confidence=d["ocr_confidence"],
    )
def deserialize_state(checkpoint: dict, frames: list[Frame]) -> dict:
    """Rebuild DetectState from a checkpoint dict plus already-loaded frames."""
    frame_map = {f.sequence: f for f in frames}
    kept = set(checkpoint.get("filtered_frame_sequences", []))
    # JSON stringified the sequence keys — restore them to ints.
    boxes_by_frame = {
        int(seq): [BoundingBox(**b) for b in box_dicts]
        for seq, box_dicts in checkpoint.get("boxes_by_frame", {}).items()
    }
    return {
        "job_id": checkpoint.get("job_id", ""),
        "video_path": checkpoint.get("video_path", ""),
        "profile_name": checkpoint.get("profile_name", ""),
        "config_overrides": checkpoint.get("config_overrides", {}),
        "frames": frames,
        "filtered_frames": [f for f in frames if f.sequence in kept],
        "boxes_by_frame": boxes_by_frame,
        "text_candidates": [deserialize_text_candidate(d, frame_map)
                            for d in checkpoint.get("text_candidates", [])],
        "unresolved_candidates": [deserialize_text_candidate(d, frame_map)
                                  for d in checkpoint.get("unresolved_candidates", [])],
        "detections": [BrandDetection(**d) for d in checkpoint.get("detections", [])],
        "stats": PipelineStats(**checkpoint.get("stats", {})),
    }

View File

@@ -0,0 +1,215 @@
"""
Checkpoint storage — save/load stage state.
Binary data (frame images) → S3/MinIO via frames.py
Structured data (boxes, detections, stats, config) → Postgres via Django ORM
Until the Django model is generated by modelgen, checkpoint data is stored
as JSON in S3 as a fallback. Once DetectJob/StageCheckpoint models exist,
this module switches to Postgres.
"""
from __future__ import annotations
import json
import logging
import os
import tempfile
from pathlib import Path
from .frames import save_frames, load_frames, BUCKET, CHECKPOINT_PREFIX
from .serializer import serialize_state, deserialize_state
logger = logging.getLogger(__name__)
def _has_db() -> bool:
"""Check if the DB layer is available (Django + models generated by modelgen)."""
try:
from core.db.detect import get_stage_checkpoint as _
# Quick check that the model exists (modelgen may not have run yet)
from admin.mpr.media_assets.models import StageCheckpoint as _
return True
except (ImportError, Exception):
return False
# ---------------------------------------------------------------------------
# Save
# ---------------------------------------------------------------------------
def save_checkpoint(
    job_id: str,
    stage: str,
    stage_index: int,
    state: dict,
    frames_manifest: dict[int, str] | None = None,
) -> str:
    """
    Persist a stage checkpoint and return its identifier.

    Frame images go to S3 (uploaded here only when no manifest was passed
    in); the structured state goes to Postgres when the DB layer is
    available, otherwise to a JSON document in S3. The return value is the
    DB row id or the S3 key, respectively.
    """
    if frames_manifest is None:
        # First checkpoint for this job: upload the frame images now.
        frames_manifest = save_frames(job_id, state.get("frames", []))
    payload = serialize_state(state, frames_manifest)
    if _has_db():
        return _save_to_db(job_id, stage, stage_index, payload)
    return _save_to_s3(job_id, stage, payload)
def _save_to_db(job_id: str, stage: str, stage_index: int, data: dict) -> str:
    """Persist checkpoint structured data to Postgres via core/db helpers."""
    import uuid
    from core.db.detect import save_stage_checkpoint

    job_uuid = uuid.UUID(job_id) if isinstance(job_id, str) else job_id
    row = save_stage_checkpoint(
        id=uuid.uuid4(),
        job_id=job_uuid,
        stage=stage,
        stage_index=stage_index,
        frames_prefix=f"{CHECKPOINT_PREFIX}/{job_id}/frames/",
        frames_manifest=data.get("frames_manifest", {}),
        frames_meta=data.get("frames_meta", []),
        filtered_frame_sequences=data.get("filtered_frame_sequences", []),
        boxes_by_frame=data.get("boxes_by_frame", {}),
        text_candidates=data.get("text_candidates", []),
        unresolved_candidates=data.get("unresolved_candidates", []),
        detections=data.get("detections", []),
        stats=data.get("stats", {}),
        # NOTE(review): snapshot and overrides currently store the same
        # value — presumably config_snapshot should hold the full effective
        # config; confirm against the StageCheckpoint model's intent.
        config_snapshot=data.get("config_overrides", {}),
        config_overrides=data.get("config_overrides", {}),
        video_path=data.get("video_path", ""),
        profile_name=data.get("profile_name", ""),
    )
    logger.info("Checkpoint saved to DB: %s/%s (id=%s)", job_id, stage, row.id)
    return str(row.id)
def _save_to_s3(job_id: str, stage: str, data: dict) -> str:
    """
    Fallback: save the checkpoint as JSON in S3 (before modelgen generates
    the DB models). Returns the S3 key.
    """
    from core.storage.s3 import upload_file

    # default=str keeps serialization best-effort for stray non-JSON types
    # (UUIDs, datetimes) — anything unexpected becomes its str() form.
    checkpoint_json = json.dumps(data, default=str)
    key = f"{CHECKPOINT_PREFIX}/{job_id}/stages/{stage}.json"
    # Fix: pin the temp file to UTF-8. A text-mode NamedTemporaryFile
    # otherwise uses the locale's preferred encoding, which can corrupt
    # non-ASCII text (e.g. OCR'd brand names) on some hosts.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False, encoding="utf-8"
    ) as tmp:
        tmp.write(checkpoint_json)
        tmp_path = tmp.name
    try:
        upload_file(tmp_path, BUCKET, key)
    finally:
        os.unlink(tmp_path)
    logger.info("Checkpoint saved to S3: s3://%s/%s", BUCKET, key)
    return key
# ---------------------------------------------------------------------------
# Load
# ---------------------------------------------------------------------------
def load_checkpoint(job_id: str, stage: str) -> dict:
    """
    Load a stage checkpoint and rebuild the full DetectState.

    Structured data comes from Postgres when the DB layer is available,
    otherwise from the S3 JSON fallback; frame images are then fetched
    from S3 via the stored manifest.
    """
    data = _load_from_db(job_id, stage) if _has_db() else _load_from_s3(job_id, stage)
    # Manifest keys were stringified for JSON — restore integer sequences.
    manifest = {int(seq): key for seq, key in data.get("frames_manifest", {}).items()}
    frames = load_frames(manifest, data.get("frames_meta", []))
    state = deserialize_state(data, frames)
    logger.info("Checkpoint loaded: %s/%s (%d frames)", job_id, stage, len(frames))
    return state
def _load_from_db(job_id: str, stage: str) -> dict:
    """Fetch checkpoint structured data from Postgres via core/db."""
    from core.db.detect import get_stage_checkpoint

    row = get_stage_checkpoint(job_id, stage)
    return {
        "job_id": str(row.job_id),
        "video_path": row.video_path,
        "profile_name": row.profile_name,
        "config_overrides": row.config_overrides,
        "frames_manifest": row.frames_manifest,
        "frames_meta": row.frames_meta,
        "filtered_frame_sequences": row.filtered_frame_sequences,
        "boxes_by_frame": row.boxes_by_frame,
        "text_candidates": row.text_candidates,
        "unresolved_candidates": row.unresolved_candidates,
        "detections": row.detections,
        "stats": row.stats,
    }
def _load_from_s3(job_id: str, stage: str) -> dict:
    """Fallback: read the checkpoint JSON document from S3."""
    from core.storage.s3 import download_to_temp

    local_path = download_to_temp(
        BUCKET, f"{CHECKPOINT_PREFIX}/{job_id}/stages/{stage}.json"
    )
    try:
        with open(local_path) as fh:
            return json.load(fh)
    finally:
        # Remove the downloaded temp file regardless of parse success.
        os.unlink(local_path)
# ---------------------------------------------------------------------------
# List
# ---------------------------------------------------------------------------
def list_checkpoints(job_id: str) -> list[str]:
    """Return the stages that have a saved checkpoint for this job."""
    return _list_from_db(job_id) if _has_db() else _list_from_s3(job_id)
def _list_from_db(job_id: str) -> list[str]:
    """DB-backed stage listing."""
    from core.db.detect import list_stage_checkpoints

    return list_stage_checkpoints(job_id)
def _list_from_s3(job_id: str) -> list[str]:
    """S3 fallback: derive stage names from checkpoint JSON object keys."""
    from core.storage.s3 import list_objects

    prefix = f"{CHECKPOINT_PREFIX}/{job_id}/stages/"
    # The stage name is the object's filename without its .json suffix.
    return [Path(entry["key"]).stem for entry in list_objects(BUCKET, prefix)]

View File

@@ -0,0 +1,71 @@
"""
Celery tasks for detection pipeline async operations.
retry_candidates: re-run VLM/cloud escalation with different config.
"""
from __future__ import annotations
import logging
import uuid
from datetime import datetime, timezone
from celery import shared_task
logger = logging.getLogger(__name__)
@shared_task(bind=True, max_retries=1, default_retry_delay=30)
def retry_candidates(
    self,
    job_id: str,
    config_overrides: dict | None = None,
    start_stage: str = "escalate_vlm",
):
    """
    Retry unresolved candidates with different config.

    Loads the checkpoint from the stage before start_stage,
    applies config overrides (e.g. different cloud provider),
    and runs from start_stage onward.

    Returns a status dict: on success, detection/brand counts; after the
    single Celery retry is exhausted, status="failed" with the error string
    (the task result is always consumable — it never re-raises at the end).
    """
    # Imported lazily so importing this tasks module stays cheap.
    from detect.checkpoint.replay import replay_from
    # Short correlation id tying together the log lines of this attempt.
    run_id = str(uuid.uuid4())[:8]
    logger.info("Retry task %s: job=%s, from=%s, overrides=%s",
                run_id, job_id, start_stage, config_overrides)
    try:
        result = replay_from(
            job_id=job_id,
            start_stage=start_stage,
            config_overrides=config_overrides,
        )
        detections = result.get("detections", [])
        # report may be absent if the pipeline stopped before the final stage.
        report = result.get("report")
        brands_found = len(report.brands) if report else 0
        logger.info("Retry %s complete: %d detections, %d brands",
                    run_id, len(detections), brands_found)
        return {
            "status": "completed",
            "run_id": run_id,
            "job_id": job_id,
            "detections": len(detections),
            "brands_found": brands_found,
        }
    except Exception as e:
        logger.exception("Retry %s failed: %s", run_id, e)
        # First failure: hand back to Celery (max_retries=1, 30s delay).
        if self.request.retries < self.max_retries:
            raise self.retry(exc=e)
        # Retries exhausted — report the failure in the result payload.
        return {
            "status": "failed",
            "run_id": run_id,
            "job_id": job_id,
            "error": str(e),
        }

View File

@@ -3,6 +3,9 @@ Event emission helpers for detection pipeline stages.
Single place that knows how to build event payloads.
Stages call these instead of constructing dicts or dataclasses directly.
Run context (run_id, parent_job_id) is set once at pipeline start via
set_run_context() and automatically injected into all events.
"""
from __future__ import annotations
@@ -13,9 +16,33 @@ from datetime import datetime, timezone
from detect.events import push_detect_event
from detect.models import PipelineStats
# Module-level run context — set once per pipeline invocation
_run_context: dict = {}
def set_run_context(run_id: str = "", parent_job_id: str = "", run_type: str = "initial"):
"""Set the run context for all subsequent events in this pipeline invocation."""
global _run_context
_run_context = {
"run_id": run_id,
"parent_job_id": parent_job_id,
"run_type": run_type,
}
def clear_run_context():
global _run_context
_run_context = {}
def _inject_context(payload: dict) -> dict:
"""Add run context fields to an event payload."""
if _run_context:
payload.update(_run_context)
return payload
def log(job_id: str | None, stage: str, level: str, msg: str) -> None:
"""Emit a log event."""
if not job_id:
return
payload = {
@@ -24,15 +51,17 @@ def log(job_id: str | None, stage: str, level: str, msg: str) -> None:
"msg": msg,
"ts": datetime.now(timezone.utc).isoformat(),
}
_inject_context(payload)
push_detect_event(job_id, "log", payload)
def stats(job_id: str | None, **kwargs) -> None:
"""Emit a stats_update event. Pass only the fields that changed."""
if not job_id:
return
s = PipelineStats(**kwargs)
push_detect_event(job_id, "stats_update", dataclasses.asdict(s))
payload = dataclasses.asdict(s)
_inject_context(payload)
push_detect_event(job_id, "stats_update", payload)
def frame_update(
@@ -42,7 +71,6 @@ def frame_update(
jpeg_b64: str,
boxes: list[dict],
) -> None:
"""Emit a frame_update event with the image and bounding boxes."""
if not job_id:
return
payload = {
@@ -51,14 +79,15 @@ def frame_update(
"jpeg_b64": jpeg_b64,
"boxes": boxes,
}
_inject_context(payload)
push_detect_event(job_id, "frame_update", payload)
def graph_update(job_id: str | None, nodes: list[dict]) -> None:
"""Emit a graph_update event with node states."""
if not job_id:
return
payload = {"nodes": nodes}
_inject_context(payload)
push_detect_event(job_id, "graph_update", payload)
@@ -72,7 +101,6 @@ def detection(
content_type: str = "",
frame_ref: int | None = None,
) -> None:
"""Emit a brand detection event."""
if not job_id:
return
payload = {
@@ -84,12 +112,13 @@ def detection(
"content_type": content_type,
"frame_ref": frame_ref,
}
_inject_context(payload)
push_detect_event(job_id, "detection", payload)
def job_complete(job_id: str | None, report: dict) -> None:
"""Emit a job_complete event with the final report."""
if not job_id:
return
payload = {"job_id": job_id, "report": report}
_inject_context(payload)
push_detect_event(job_id, "job_complete", payload)

View File

@@ -42,8 +42,16 @@ NODES = [
def _get_profile(state: DetectState):
name = state.get("profile_name", "soccer_broadcast")
if name == "soccer_broadcast":
return SoccerBroadcastProfile()
raise ValueError(f"Unknown profile: {name}")
profile = SoccerBroadcastProfile()
else:
raise ValueError(f"Unknown profile: {name}")
overrides = state.get("config_overrides")
if overrides:
from detect.checkpoint.replay import OverrideProfile
profile = OverrideProfile(profile, overrides)
return profile
# Track node states across the pipeline run
@@ -68,6 +76,18 @@ def _emit_transition(state: DetectState, node: str, status: str):
# --- Node functions ---
def node_extract_frames(state: DetectState) -> dict:
# Set run context for initial runs (replays set it in replay_from)
job_id = state.get("job_id", "")
if job_id and not emit._run_context:
emit.set_run_context(run_id=job_id, parent_job_id=job_id, run_type="initial")
# Load session brands from DB for this source
source_asset_id = state.get("source_asset_id")
if source_asset_id and not state.get("session_brands"):
from detect.stages.brand_resolver import build_session_dict
session_brands = build_session_dict(source_asset_id)
state["session_brands"] = session_brands
_emit_transition(state, "extract_frames", "running")
with trace_node(state, "extract_frames") as span:
@@ -142,13 +162,16 @@ def node_match_brands(state: DetectState) -> dict:
with trace_node(state, "match_brands") as span:
profile = _get_profile(state)
dictionary = profile.brand_dictionary()
resolver_config = profile.resolver_config()
candidates = state.get("text_candidates", [])
session_brands = state.get("session_brands", {})
job_id = state.get("job_id")
source_asset_id = state.get("source_asset_id")
matched, unresolved = resolve_brands(
candidates, dictionary, resolver_config,
candidates, resolver_config,
session_brands=session_brands,
source_asset_id=source_asset_id,
content_type=profile.name, job_id=job_id,
)
span.set_output({"matched": len(matched), "unresolved": len(unresolved)})
@@ -170,6 +193,7 @@ def node_escalate_vlm(state: DetectState) -> dict:
vlm_prompt_fn=profile.vlm_prompt,
inference_url=INFERENCE_URL,
content_type=profile.name,
source_asset_id=state.get("source_asset_id"),
job_id=job_id,
)
@@ -202,6 +226,7 @@ def node_escalate_cloud(state: DetectState) -> dict:
vlm_prompt_fn=profile.vlm_prompt,
stats=stats,
content_type=profile.name,
source_asset_id=state.get("source_asset_id"),
job_id=job_id,
)
@@ -239,33 +264,87 @@ def node_compile_report(state: DetectState) -> dict:
return {"report": report}
# --- Checkpoint wrapper ---
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
_frames_manifest: dict[str, dict[int, str]] = {} # job_id → manifest (cached per job)
def _checkpointing_node(node_name: str, node_fn):
"""Wrap a node function to auto-checkpoint after completion."""
stage_index = NODES.index(node_name)
def wrapper(state: DetectState) -> dict:
result = node_fn(state)
job_id = state.get("job_id", "")
if not job_id:
return result
from detect.checkpoint import save_checkpoint, save_frames
merged = {**state, **result}
# Save frames once (first checkpoint), reuse manifest after
manifest = _frames_manifest.get(job_id)
if manifest is None and node_name == "extract_frames":
manifest = save_frames(job_id, merged.get("frames", []))
_frames_manifest[job_id] = manifest
save_checkpoint(job_id, node_name, stage_index, merged, frames_manifest=manifest)
return result
wrapper.__name__ = node_fn.__name__
return wrapper
# --- Graph construction ---
def build_graph() -> StateGraph:
NODE_FUNCTIONS = [
("extract_frames", node_extract_frames),
("filter_scenes", node_filter_scenes),
("detect_objects", node_detect_objects),
("run_ocr", node_run_ocr),
("match_brands", node_match_brands),
("escalate_vlm", node_escalate_vlm),
("escalate_cloud", node_escalate_cloud),
("compile_report", node_compile_report),
]
def build_graph(checkpoint: bool | None = None, start_from: str | None = None) -> StateGraph:
"""
Build the pipeline graph.
checkpoint: enable auto-checkpointing (default: MPR_CHECKPOINT env var)
start_from: skip nodes before this stage (for replay)
"""
do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
graph = StateGraph(DetectState)
graph.add_node("extract_frames", node_extract_frames)
graph.add_node("filter_scenes", node_filter_scenes)
graph.add_node("detect_objects", node_detect_objects)
graph.add_node("run_ocr", node_run_ocr)
graph.add_node("match_brands", node_match_brands)
graph.add_node("escalate_vlm", node_escalate_vlm)
graph.add_node("escalate_cloud", node_escalate_cloud)
graph.add_node("compile_report", node_compile_report)
# Filter to start_from if replaying
node_pairs = NODE_FUNCTIONS
if start_from:
start_idx = next(i for i, (name, _) in enumerate(NODE_FUNCTIONS) if name == start_from)
node_pairs = NODE_FUNCTIONS[start_idx:]
graph.set_entry_point("extract_frames")
graph.add_edge("extract_frames", "filter_scenes")
graph.add_edge("filter_scenes", "detect_objects")
graph.add_edge("detect_objects", "run_ocr")
graph.add_edge("run_ocr", "match_brands")
graph.add_edge("match_brands", "escalate_vlm")
graph.add_edge("escalate_vlm", "escalate_cloud")
graph.add_edge("escalate_cloud", "compile_report")
graph.add_edge("compile_report", END)
for name, fn in node_pairs:
wrapped = _checkpointing_node(name, fn) if do_checkpoint else fn
graph.add_node(name, wrapped)
# Wire edges
entry = node_pairs[0][0]
graph.set_entry_point(entry)
for i in range(len(node_pairs) - 1):
graph.add_edge(node_pairs[i][0], node_pairs[i + 1][0])
graph.add_edge(node_pairs[-1][0], END)
return graph
def get_pipeline():
def get_pipeline(checkpoint: bool | None = None):
"""Return a compiled, runnable pipeline."""
return build_graph().compile()
return build_graph(checkpoint=checkpoint).compile()

View File

@@ -1,6 +1,5 @@
from .base import (
ContentTypeProfile,
BrandDictionary,
CropContext,
DetectionConfig,
FrameExtractionConfig,
@@ -12,7 +11,6 @@ from .soccer import SoccerBroadcastProfile
__all__ = [
"ContentTypeProfile",
"BrandDictionary",
"CropContext",
"DetectionConfig",
"FrameExtractionConfig",

View File

@@ -44,12 +44,6 @@ class ResolverConfig:
fuzzy_threshold: int = 75
@dataclass
class BrandDictionary:
"""Maps canonical brand name → list of known aliases/spellings."""
brands: dict[str, list[str]] = field(default_factory=dict)
@dataclass
class CropContext:
image: bytes
@@ -64,7 +58,6 @@ class ContentTypeProfile(Protocol):
def scene_filter_config(self) -> SceneFilterConfig: ...
def detection_config(self) -> DetectionConfig: ...
def ocr_config(self) -> OCRConfig: ...
def brand_dictionary(self) -> BrandDictionary: ...
def resolver_config(self) -> ResolverConfig: ...
def vlm_prompt(self, crop_context: CropContext) -> str: ...
def aggregate(self, detections: list[BrandDetection]) -> DetectionReport: ...

View File

@@ -5,7 +5,6 @@ from __future__ import annotations
from detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats
from .base import (
BrandDictionary,
CropContext,
DetectionConfig,
FrameExtractionConfig,
@@ -34,22 +33,6 @@ class SoccerBroadcastProfile:
def ocr_config(self) -> OCRConfig:
return OCRConfig(languages=["en", "es"], min_confidence=0.5)
def brand_dictionary(self) -> BrandDictionary:
return BrandDictionary(brands={
"Nike": ["nike", "NIKE", "swoosh"],
"Adidas": ["adidas", "ADIDAS", "adi"],
"Puma": ["puma", "PUMA"],
"Emirates": ["emirates", "fly emirates", "EMIRATES"],
"Coca-Cola": ["coca-cola", "coca cola", "coke", "COCA-COLA"],
"Pepsi": ["pepsi", "PEPSI"],
"Mastercard": ["mastercard", "MASTERCARD"],
"Heineken": ["heineken", "HEINEKEN"],
"Santander": ["santander", "SANTANDER"],
"Gazprom": ["gazprom", "GAZPROM"],
"Qatar Airways": ["qatar airways", "QATAR AIRWAYS"],
"Lay's": ["lays", "lay's", "LAYS", "LAY'S"],
})
def resolver_config(self) -> ResolverConfig:
return ResolverConfig(fuzzy_threshold=75)

View File

@@ -5,7 +5,6 @@ from __future__ import annotations
from detect.models import BrandDetection, DetectionReport
from .base import (
BrandDictionary,
CropContext,
DetectionConfig,
FrameExtractionConfig,
@@ -30,9 +29,6 @@ class NewsBroadcastProfile:
def ocr_config(self) -> OCRConfig:
raise NotImplementedError
def brand_dictionary(self) -> BrandDictionary:
raise NotImplementedError
def resolver_config(self) -> ResolverConfig:
raise NotImplementedError
@@ -61,9 +57,6 @@ class AdvertisingProfile:
def ocr_config(self) -> OCRConfig:
raise NotImplementedError
def brand_dictionary(self) -> BrandDictionary:
raise NotImplementedError
def resolver_config(self) -> ResolverConfig:
raise NotImplementedError
@@ -92,9 +85,6 @@ class TranscriptProfile:
def ocr_config(self) -> OCRConfig:
raise NotImplementedError
def brand_dictionary(self) -> BrandDictionary:
raise NotImplementedError
def resolver_config(self) -> ResolverConfig:
raise NotImplementedError

View File

@@ -1,9 +1,17 @@
"""
Stage 5 — Brand Resolver
Stage 5 — Brand Resolver (discovery mode)
Matches OCR text against the profile's brand dictionary.
Uses exact matching first, then fuzzy matching (rapidfuzz) as fallback.
Emits detection events for confirmed brands.
Discovery-first brand matching. No static dictionary — all brands live in the DB.
Flow:
1. Check session sightings first (brands already seen in this source)
2. Check global known brands (accumulated across all runs)
3. Unresolved candidates → escalate to VLM/cloud
4. Confirmed brands get added to DB for future runs
The resolver is an enricher, not a gatekeeper. Every OCR text candidate
passes through — the question is whether we can resolve it cheaply (DB lookup)
or need to escalate (VLM/cloud).
"""
from __future__ import annotations
@@ -14,99 +22,199 @@ from rapidfuzz import fuzz
from detect import emit
from detect.models import BrandDetection, TextCandidate
from detect.profiles.base import BrandDictionary, ResolverConfig
from detect.profiles.base import ResolverConfig
logger = logging.getLogger(__name__)
def _normalize(text: str) -> str:
"""Normalize text for matching."""
return text.strip().lower()
def _exact_match(text: str, dictionary: BrandDictionary) -> str | None:
"""Try exact match against all aliases."""
def _has_db() -> bool:
try:
from core.db.detect import find_brand_by_text as _
from admin.mpr.media_assets.models import KnownBrand as _
return True
except (ImportError, Exception):
return False
def _match_session(text: str, session_brands: dict[str, str]) -> str | None:
"""
Check against session brands (already seen in this source).
session_brands: {normalized_name: canonical_name, ...}
Includes aliases.
"""
normalized = _normalize(text)
for canonical, aliases in dictionary.brands.items():
if normalized == _normalize(canonical):
return canonical
for alias in aliases:
if normalized == _normalize(alias):
return canonical
return None
return session_brands.get(normalized)
def _match_known(text: str, threshold: int) -> tuple[str | None, str | None]:
    """
    Check against global known brands in DB.

    Strategy: exact lookup first (single query via find_brand_by_text),
    then a fuzzy scan over every known brand and its aliases using
    rapidfuzz. A fuzzy hit must score >= threshold to count.

    Returns (canonical_name, brand_id) or (None, None). No-op without DB.
    """
    if not _has_db():
        return None, None
    from core.db.detect import find_brand_by_text, list_all_brands

    # 1. Exact lookup — cheap, handles the common repeat-brand case.
    brand = find_brand_by_text(text)
    if brand:
        return brand.canonical_name, str(brand.id)

    # 2. Fuzzy match against all known brands and their aliases.
    #    (The removed _fuzzy_match loop over `dictionary.brands` was dead
    #    diff residue — only the KnownBrand scan below is live.)
    normalized = _normalize(text)
    best_brand = None
    best_score = 0
    for known in list_all_brands():
        for name in [known.canonical_name] + (known.aliases or []):
            score = fuzz.ratio(normalized, _normalize(name))
            if score > best_score and score >= threshold:
                best_score = score
                best_brand = known
    if best_brand:
        return best_brand.canonical_name, str(best_brand.id)
    return None, None
def _register_brand(canonical_name: str, source: str) -> str | None:
    """Register a newly discovered brand in the DB. Returns brand_id."""
    if not _has_db():
        return None
    from core.db.detect import get_or_create_brand

    brand, created = get_or_create_brand(canonical_name, source=source)
    if created:
        # Only log on first discovery; repeat sightings stay quiet.
        logger.info("New brand discovered: %s (source=%s)", canonical_name, source)
    # get_or_create semantics: an id exists whether or not it was just created.
    return str(brand.id)
def _record_sighting(source_asset_id: str | None, brand_id: str,
                     brand_name: str, timestamp: float,
                     confidence: float, source: str):
    """Record a brand sighting for this source. No-op without DB or asset id."""
    if not _has_db() or not source_asset_id:
        return
    import uuid

    from core.db.detect import record_sighting

    # Callers may hand us str or UUID; coerce both ids to UUID.
    if isinstance(source_asset_id, str):
        asset_id = uuid.UUID(source_asset_id)
    else:
        asset_id = source_asset_id
    if isinstance(brand_id, str):
        brand_uuid = uuid.UUID(brand_id)
    else:
        brand_uuid = brand_id
    record_sighting(asset_id, brand_uuid, brand_name, timestamp, confidence, source)
def build_session_dict(source_asset_id: str | None) -> dict[str, str]:
    """
    Load session brands from DB for this source.

    Returns {normalized_name: canonical_name, ...} including aliases, so
    _match_session can resolve alias spellings with one dict lookup.
    Returns {} when the DB layer is unavailable or no asset id is given.
    """
    if not _has_db() or not source_asset_id:
        return {}
    import uuid

    from core.db.detect import get_source_sightings, list_all_brands

    asset_id = uuid.UUID(source_asset_id) if isinstance(source_asset_id, str) else source_asset_id
    sightings = get_source_sightings(asset_id)

    session: dict[str, str] = {}
    for s in sightings:
        session[_normalize(s.brand_name)] = s.brand_name

    # Also load aliases from KnownBrand for each sighted brand. The early
    # return above already guarantees the DB is present, so the previous
    # redundant `if _has_db():` re-check is dropped.
    sighted_names = {s.brand_name for s in sightings}
    for brand in list_all_brands():
        if brand.canonical_name in sighted_names:
            for alias in (brand.aliases or []):
                session[_normalize(alias)] = brand.canonical_name
    return session
def resolve_brands(
candidates: list[TextCandidate],
dictionary: BrandDictionary,
config: ResolverConfig,
session_brands: dict[str, str] | None = None,
source_asset_id: str | None = None,
content_type: str = "",
job_id: str | None = None,
) -> tuple[list[BrandDetection], list[TextCandidate]]:
"""
Match text candidates against the brand dictionary.
Match text candidates against known brands (session → global → unresolved).
Returns:
- matched: list of BrandDetection for confirmed brands
- unresolved: list of TextCandidate that couldn't be matched
session_brands: pre-loaded session dict (from build_session_dict)
source_asset_id: for recording new sightings in DB
"""
if session_brands is None:
session_brands = {}
emit.log(job_id, "BrandResolver", "INFO",
f"Matching {len(candidates)} candidates against "
f"{len(dictionary.brands)} brands (fuzzy_threshold={config.fuzzy_threshold})")
f"Resolving {len(candidates)} candidates "
f"(session={len(session_brands)} brands, fuzzy={config.fuzzy_threshold})")
matched: list[BrandDetection] = []
unresolved: list[TextCandidate] = []
exact_count = 0
fuzzy_count = 0
session_hits = 0
known_hits = 0
for candidate in candidates:
# Try exact match first
brand = _exact_match(candidate.text, dictionary)
source = "ocr"
text = candidate.text
brand_name = None
brand_id = None
match_source = "ocr"
if brand:
exact_count += 1
# 1. Check session (cheapest — in-memory dict)
brand_name = _match_session(text, session_brands)
if brand_name:
session_hits += 1
else:
# Try fuzzy match
brand, score = _fuzzy_match(candidate.text, dictionary, config.fuzzy_threshold)
if brand:
fuzzy_count += 1
# 2. Check global known brands (DB query + fuzzy)
brand_name, brand_id = _match_known(text, config.fuzzy_threshold)
if brand_name:
known_hits += 1
# Add to session for subsequent candidates in this run
session_brands[_normalize(brand_name)] = brand_name
if brand:
if brand_name:
detection = BrandDetection(
brand=brand,
brand=brand_name,
timestamp=candidate.frame.timestamp,
duration=0.5,
confidence=candidate.ocr_confidence,
source=source,
source=match_source,
bbox=candidate.bbox,
frame_ref=candidate.frame.sequence,
content_type=content_type,
)
matched.append(detection)
# Record sighting in DB
if brand_id:
_record_sighting(
source_asset_id, brand_id, brand_name,
candidate.frame.timestamp, candidate.ocr_confidence, match_source,
)
emit.detection(
job_id,
brand=brand,
brand=brand_name,
confidence=candidate.ocr_confidence,
source=source,
source=match_source,
timestamp=candidate.frame.timestamp,
content_type=content_type,
frame_ref=candidate.frame.sequence,
@@ -115,7 +223,7 @@ def resolve_brands(
unresolved.append(candidate)
emit.log(job_id, "BrandResolver", "INFO",
f"Exact: {exact_count}, Fuzzy: {fuzzy_count}, "
f"Unresolved: {len(unresolved)} → escalating to VLM")
f"Session: {session_hits}, Known: {known_hits}, "
f"Unresolved: {len(unresolved)} → escalating")
return matched, unresolved

View File

@@ -27,6 +27,18 @@ logger = logging.getLogger(__name__)
ESTIMATED_TOKENS_PER_CROP = 500
def _register_discovered_brand(brand: str, source_asset_id: str | None,
                               timestamp: float, confidence: float):
    """Register a cloud-confirmed brand in the DB."""
    # Best-effort: registration must never break the escalation path,
    # so every failure is swallowed and logged at debug level.
    try:
        from detect.stages.brand_resolver import _record_sighting, _register_brand

        new_brand_id = _register_brand(brand, "cloud_llm")
        if new_brand_id and source_asset_id:
            _record_sighting(source_asset_id, new_brand_id, brand,
                             timestamp, confidence, "cloud_llm")
    except Exception as exc:
        logger.debug("Failed to register brand %s: %s", brand, exc)
def _encode_crop(crop: np.ndarray) -> str:
img = Image.fromarray(crop)
buf = io.BytesIO()
@@ -84,6 +96,7 @@ def escalate_cloud(
stats: PipelineStats,
min_confidence: float = 0.4,
content_type: str = "",
source_asset_id: str | None = None,
job_id: str | None = None,
) -> list[BrandDetection]:
"""
@@ -158,6 +171,10 @@ def escalate_cloud(
frame_ref=candidate.frame.sequence,
)
# Register newly discovered brand in DB
_register_discovered_brand(brand, source_asset_id,
candidate.frame.timestamp, confidence)
stats.estimated_cloud_cost_usd += total_cost
stats.regions_escalated_to_cloud_llm = len(candidates)

View File

@@ -19,6 +19,18 @@ from detect.profiles.base import CropContext
logger = logging.getLogger(__name__)
def _register_discovered_brand(brand: str, source_asset_id: str | None,
                               timestamp: float, confidence: float, source: str):
    """Register a VLM-confirmed brand in the DB."""
    # Best-effort: a DB hiccup must not fail the VLM escalation, so all
    # errors are caught and reported at debug level only.
    try:
        from detect.stages.brand_resolver import _record_sighting, _register_brand

        new_brand_id = _register_brand(brand, source)
        if new_brand_id and source_asset_id:
            _record_sighting(source_asset_id, new_brand_id, brand,
                             timestamp, confidence, source)
    except Exception as exc:
        logger.debug("Failed to register brand %s: %s", brand, exc)
def _crop_image(candidate: TextCandidate) -> np.ndarray:
frame = candidate.frame
box = candidate.bbox
@@ -36,6 +48,7 @@ def escalate_vlm(
inference_url: str | None = None,
min_confidence: float = 0.5,
content_type: str = "",
source_asset_id: str | None = None,
job_id: str | None = None,
) -> tuple[list[BrandDetection], list[TextCandidate]]:
"""
@@ -107,6 +120,10 @@ def escalate_vlm(
frame_ref=candidate.frame.sequence,
)
# Register newly discovered brand in DB
_register_discovered_brand(brand, source_asset_id,
candidate.frame.timestamp, confidence, "local_vlm")
logger.debug("VLM matched: %s (%.2f) — %s", brand, confidence, reasoning)
else:
still_unresolved.append(candidate)

View File

@@ -17,6 +17,7 @@ class DetectState(TypedDict, total=False):
video_path: str
job_id: str
profile_name: str
source_asset_id: str # UUID of the source MediaAsset
# Stage outputs
frames: list[Frame]
@@ -27,5 +28,11 @@ class DetectState(TypedDict, total=False):
detections: list[BrandDetection]
report: DetectionReport
# Session brands (accumulated during the run, persisted to DB)
session_brands: dict # {normalized_name: canonical_name}
# Running stats (updated by each stage)
stats: PipelineStats
# Config overrides for replay (applied via OverrideProfile)
config_overrides: dict