phase 10

detect/checkpoint/__init__.py (new file)
@@ -0,0 +1,14 @@
"""
Stage checkpoint, replay, and retry.

detect/checkpoint/
    frames.py — frame image S3 upload/download
    serializer.py — state ↔ JSON conversion
    storage.py — checkpoint save/load/list (Postgres + S3)
    replay.py — replay_from, OverrideProfile
    tasks.py — retry_candidates Celery task
"""

from .storage import save_checkpoint, load_checkpoint, list_checkpoints
from .frames import save_frames, load_frames
from .replay import replay_from, OverrideProfile

detect/checkpoint/frames.py (new file)
@@ -0,0 +1,80 @@
"""Frame image storage — save/load to S3/MinIO as JPEGs."""

from __future__ import annotations

import logging
import os
import tempfile

import numpy as np
from PIL import Image

from detect.models import Frame

logger = logging.getLogger(__name__)

BUCKET = os.environ.get("S3_BUCKET_OUT", "mpr-media-out")
CHECKPOINT_PREFIX = "checkpoints"


def save_frames(job_id: str, frames: list[Frame]) -> dict[int, str]:
    """
    Save frame images to S3 as JPEGs.

    Returns manifest: {sequence: s3_key}
    """
    from core.storage.s3 import upload_file

    manifest = {}

    for frame in frames:
        key = f"{CHECKPOINT_PREFIX}/{job_id}/frames/{frame.sequence}.jpg"

        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
            img = Image.fromarray(frame.image)
            img.save(tmp, format="JPEG", quality=85)
            tmp_path = tmp.name

        try:
            upload_file(tmp_path, BUCKET, key)
        finally:
            os.unlink(tmp_path)

        manifest[frame.sequence] = key

    logger.info("Saved %d frames to s3://%s/%s/%s/frames/",
                len(frames), BUCKET, CHECKPOINT_PREFIX, job_id)
    return manifest


def load_frames(manifest: dict[int, str], frame_metadata: list[dict]) -> list[Frame]:
    """
    Load frame images from S3 and reconstitute Frame objects.

    frame_metadata: list of dicts with sequence, chunk_id, timestamp, perceptual_hash.
    """
    from core.storage.s3 import download_to_temp

    meta_map = {m["sequence"]: m for m in frame_metadata}
    frames = []

    for seq, key in manifest.items():
        tmp_path = download_to_temp(BUCKET, key)
        try:
            img = Image.open(tmp_path).convert("RGB")
            image_array = np.array(img)
        finally:
            os.unlink(tmp_path)

        meta = meta_map.get(seq, {})
        frame = Frame(
            sequence=seq,
            chunk_id=meta.get("chunk_id", 0),
            timestamp=meta.get("timestamp", 0.0),
            image=image_array,
            perceptual_hash=meta.get("perceptual_hash", ""),
        )
        frames.append(frame)

    frames.sort(key=lambda f: f.sequence)
    return frames
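
A round-trip sketch of how these two functions compose, not part of the commit. It assumes the Frame fields used above; the job id and synthetic image are illustrative, and since JPEG is lossy the reloaded pixel data is close to but not identical to the original:

    import numpy as np
    from detect.models import Frame
    from detect.checkpoint.frames import save_frames, load_frames

    frames = [
        Frame(sequence=i, chunk_id=0, timestamp=i / 2.0,
              image=np.zeros((720, 1280, 3), dtype=np.uint8),
              perceptual_hash="")
        for i in range(2)
    ]
    manifest = save_frames("job-123", frames)   # {0: "checkpoints/job-123/frames/0.jpg", ...}
    meta = [{"sequence": f.sequence, "chunk_id": f.chunk_id,
             "timestamp": f.timestamp, "perceptual_hash": f.perceptual_hash}
            for f in frames]
    restored = load_frames(manifest, meta)      # Frame objects again, sorted by sequence
    assert [f.sequence for f in restored] == [0, 1]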

detect/checkpoint/replay.py (new file)
@@ -0,0 +1,132 @@
"""
Pipeline replay — re-run from any stage with different config.

Loads a checkpoint, applies config overrides, builds a subgraph
starting from the target stage, and invokes it.
"""

from __future__ import annotations

import logging
import uuid

from detect import emit
from detect.checkpoint import load_checkpoint, list_checkpoints
from detect.graph import NODES, build_graph

logger = logging.getLogger(__name__)


class OverrideProfile:
    """
    Wraps a ContentTypeProfile and patches config methods with overrides.

    Override dict structure:
    {
        "frame_extraction": {"fps": 1.0},
        "scene_filter": {"hamming_threshold": 12},
        "detection": {"confidence_threshold": 0.5},
        "ocr": {"languages": ["en", "es"], "min_confidence": 0.3},
        "resolver": {"fuzzy_threshold": 60},
    }
    """

    def __init__(self, base, overrides: dict):
        self._base = base
        self._overrides = overrides

    def __getattr__(self, name):
        return getattr(self._base, name)

    def _patch(self, config, key: str):
        patches = self._overrides.get(key, {})
        for k, v in patches.items():
            if hasattr(config, k):
                setattr(config, k, v)
        return config

    def frame_extraction_config(self):
        return self._patch(self._base.frame_extraction_config(), "frame_extraction")

    def scene_filter_config(self):
        return self._patch(self._base.scene_filter_config(), "scene_filter")

    def detection_config(self):
        return self._patch(self._base.detection_config(), "detection")

    def ocr_config(self):
        return self._patch(self._base.ocr_config(), "ocr")

    def resolver_config(self):
        return self._patch(self._base.resolver_config(), "resolver")

    def vlm_prompt(self, crop_context):
        return self._base.vlm_prompt(crop_context)

    def aggregate(self, detections):
        return self._base.aggregate(detections)

    def auxiliary_detections(self, source):
        return self._base.auxiliary_detections(source)


def replay_from(
    job_id: str,
    start_stage: str,
    config_overrides: dict | None = None,
    checkpoint: bool = True,
) -> dict:
    """
    Replay the pipeline from a specific stage.

    Loads the checkpoint from the stage immediately before start_stage,
    applies config overrides, and runs the subgraph from start_stage onward.

    Returns the final state dict.
    """
    if start_stage not in NODES:
        raise ValueError(f"Unknown stage: {start_stage!r}. Options: {NODES}")

    start_idx = NODES.index(start_stage)

    # Load checkpoint from the stage before start_stage
    if start_idx == 0:
        raise ValueError("Cannot replay from the first stage — just run the full pipeline")

    previous_stage = NODES[start_idx - 1]

    available = list_checkpoints(job_id)
    if previous_stage not in available:
        raise ValueError(
            f"No checkpoint for stage {previous_stage!r} (job {job_id}). "
            f"Available: {available}"
        )

    logger.info("Replaying job %s from %s (loading checkpoint: %s)",
                job_id, start_stage, previous_stage)

    state = load_checkpoint(job_id, previous_stage)

    # Apply config overrides
    if config_overrides:
        state["config_overrides"] = config_overrides

    # Set run context for SSE events
    run_id = str(uuid.uuid4())[:8]
    emit.set_run_context(
        run_id=run_id,
        parent_job_id=job_id,
        run_type="replay",
    )

    # Build subgraph starting from start_stage
    graph = build_graph(checkpoint=checkpoint, start_from=start_stage)
    pipeline = graph.compile()

    try:
        result = pipeline.invoke(state)
    finally:
        emit.clear_run_context()

    return result
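
Usage sketch (not part of the commit): rerun OCR onward with looser thresholds against an existing job's checkpoints. The job id is illustrative; the override keys follow the OverrideProfile docstring above:

    from detect.checkpoint import replay_from

    result = replay_from(
        job_id="job-123",
        start_stage="run_ocr",                 # loads the detect_objects checkpoint
        config_overrides={
            "ocr": {"min_confidence": 0.3},
            "resolver": {"fuzzy_threshold": 60},
        },
    )
    print(len(result.get("detections", [])))

Note that _patch only sets attributes the config object already has (the hasattr guard), so misspelled override keys are silently ignored rather than raising.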

detect/checkpoint/serializer.py (new file)
@@ -0,0 +1,133 @@
"""State serialization — DetectState ↔ JSON-compatible dict."""

from __future__ import annotations

import dataclasses

from detect.models import (
    BoundingBox,
    BrandDetection,
    Frame,
    PipelineStats,
    TextCandidate,
)


# ---------------------------------------------------------------------------
# Serialize helpers
# ---------------------------------------------------------------------------

def serialize_frame_meta(frame: Frame) -> dict:
    meta = {
        "sequence": frame.sequence,
        "chunk_id": frame.chunk_id,
        "timestamp": frame.timestamp,
        "perceptual_hash": frame.perceptual_hash,
    }
    return meta


def serialize_text_candidate(tc: TextCandidate) -> dict:
    bbox_dict = dataclasses.asdict(tc.bbox)
    candidate = {
        "frame_sequence": tc.frame.sequence,
        "bbox": bbox_dict,
        "text": tc.text,
        "ocr_confidence": tc.ocr_confidence,
    }
    return candidate


def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
    """
    Serialize DetectState to a JSON-compatible dict.

    Frame images are replaced with S3 key references.
    TextCandidate.frame references become frame_sequence integers.
    """
    frames = state.get("frames", [])
    filtered = state.get("filtered_frames", [])

    manifest_strs = {str(k): v for k, v in frames_manifest.items()}
    frames_meta = [serialize_frame_meta(f) for f in frames]
    filtered_seqs = [f.sequence for f in filtered]

    boxes_serialized = {}
    for seq, boxes in state.get("boxes_by_frame", {}).items():
        boxes_serialized[str(seq)] = [dataclasses.asdict(b) for b in boxes]

    text_candidates = [serialize_text_candidate(tc) for tc in state.get("text_candidates", [])]
    unresolved = [serialize_text_candidate(tc) for tc in state.get("unresolved_candidates", [])]
    detections = [dataclasses.asdict(d) for d in state.get("detections", [])]
    stats = dataclasses.asdict(state.get("stats", PipelineStats()))

    checkpoint = {
        "job_id": state.get("job_id", ""),
        "video_path": state.get("video_path", ""),
        "profile_name": state.get("profile_name", ""),
        "config_overrides": state.get("config_overrides", {}),
        "frames_manifest": manifest_strs,
        "frames_meta": frames_meta,
        "filtered_frame_sequences": filtered_seqs,
        "boxes_by_frame": boxes_serialized,
        "text_candidates": text_candidates,
        "unresolved_candidates": unresolved,
        "detections": detections,
        "stats": stats,
    }
    return checkpoint


# ---------------------------------------------------------------------------
# Deserialize helpers
# ---------------------------------------------------------------------------

def deserialize_text_candidate(d: dict, frame_map: dict[int, Frame]) -> TextCandidate:
    frame = frame_map[d["frame_sequence"]]
    bbox = BoundingBox(**d["bbox"])
    candidate = TextCandidate(
        frame=frame,
        bbox=bbox,
        text=d["text"],
        ocr_confidence=d["ocr_confidence"],
    )
    return candidate


def deserialize_state(checkpoint: dict, frames: list[Frame]) -> dict:
    """Reconstitute DetectState from a checkpoint dict + loaded frames."""
    frame_map = {f.sequence: f for f in frames}

    filtered_seqs = set(checkpoint.get("filtered_frame_sequences", []))
    filtered_frames = [f for f in frames if f.sequence in filtered_seqs]

    boxes_by_frame = {}
    for seq_str, box_dicts in checkpoint.get("boxes_by_frame", {}).items():
        seq = int(seq_str)
        boxes_by_frame[seq] = [BoundingBox(**b) for b in box_dicts]

    text_candidates = [
        deserialize_text_candidate(d, frame_map)
        for d in checkpoint.get("text_candidates", [])
    ]
    unresolved_candidates = [
        deserialize_text_candidate(d, frame_map)
        for d in checkpoint.get("unresolved_candidates", [])
    ]
    detections = [BrandDetection(**d) for d in checkpoint.get("detections", [])]
    stats = PipelineStats(**checkpoint.get("stats", {}))

    state = {
        "job_id": checkpoint.get("job_id", ""),
        "video_path": checkpoint.get("video_path", ""),
        "profile_name": checkpoint.get("profile_name", ""),
        "config_overrides": checkpoint.get("config_overrides", {}),
        "frames": frames,
        "filtered_frames": filtered_frames,
        "boxes_by_frame": boxes_by_frame,
        "text_candidates": text_candidates,
        "unresolved_candidates": unresolved_candidates,
        "detections": detections,
        "stats": stats,
    }
    return state
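
The intended invariant, sketched under assumptions: `state` is an in-memory DetectState dict and `frames` were already reloaded via load_frames. Images round-trip through S3 separately, so only structured fields are compared here:

    from detect.checkpoint.serializer import serialize_state, deserialize_state

    ckpt = serialize_state(state, frames_manifest)   # JSON-safe: no ndarrays, int keys stringified
    restored = deserialize_state(ckpt, frames)
    assert [f.sequence for f in restored["filtered_frames"]] == \
           [f.sequence for f in state["filtered_frames"]]
    assert restored["boxes_by_frame"].keys() == state["boxes_by_frame"].keys()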

detect/checkpoint/storage.py (new file)
@@ -0,0 +1,215 @@
"""
Checkpoint storage — save/load stage state.

Binary data (frame images) → S3/MinIO via frames.py
Structured data (boxes, detections, stats, config) → Postgres via Django ORM

Until the Django model is generated by modelgen, checkpoint data is stored
as JSON in S3 as a fallback. Once DetectJob/StageCheckpoint models exist,
this module switches to Postgres.
"""

from __future__ import annotations

import json
import logging
import os
import tempfile
from pathlib import Path

from .frames import save_frames, load_frames, BUCKET, CHECKPOINT_PREFIX
from .serializer import serialize_state, deserialize_state

logger = logging.getLogger(__name__)


def _has_db() -> bool:
    """Check if the DB layer is available (Django + models generated by modelgen)."""
    try:
        from core.db.detect import get_stage_checkpoint as _
        # Quick check that the model exists (modelgen may not have run yet)
        from admin.mpr.media_assets.models import StageCheckpoint as _
        return True
    except Exception:
        return False


# ---------------------------------------------------------------------------
# Save
# ---------------------------------------------------------------------------

def save_checkpoint(
    job_id: str,
    stage: str,
    stage_index: int,
    state: dict,
    frames_manifest: dict[int, str] | None = None,
) -> str:
    """
    Save a stage checkpoint.

    Saves frame images to S3 (if not already saved), then persists
    structured state to Postgres (or S3 JSON fallback).

    Returns the checkpoint identifier (DB id or S3 key).
    """
    # Save frames to S3 if no manifest provided
    if frames_manifest is None:
        all_frames = state.get("frames", [])
        frames_manifest = save_frames(job_id, all_frames)

    checkpoint_data = serialize_state(state, frames_manifest)

    if _has_db():
        checkpoint_id = _save_to_db(job_id, stage, stage_index, checkpoint_data)
    else:
        checkpoint_id = _save_to_s3(job_id, stage, checkpoint_data)

    return checkpoint_id


def _save_to_db(job_id: str, stage: str, stage_index: int, data: dict) -> str:
    """Save checkpoint structured data to Postgres via core/db."""
    import uuid
    from core.db.detect import save_stage_checkpoint

    job_uuid = uuid.UUID(job_id) if isinstance(job_id, str) else job_id
    checkpoint_id = uuid.uuid4()
    frames_prefix = f"{CHECKPOINT_PREFIX}/{job_id}/frames/"

    checkpoint = save_stage_checkpoint(
        id=checkpoint_id,
        job_id=job_uuid,
        stage=stage,
        stage_index=stage_index,
        frames_prefix=frames_prefix,
        frames_manifest=data.get("frames_manifest", {}),
        frames_meta=data.get("frames_meta", []),
        filtered_frame_sequences=data.get("filtered_frame_sequences", []),
        boxes_by_frame=data.get("boxes_by_frame", {}),
        text_candidates=data.get("text_candidates", []),
        unresolved_candidates=data.get("unresolved_candidates", []),
        detections=data.get("detections", []),
        stats=data.get("stats", {}),
        config_snapshot=data.get("config_overrides", {}),
        config_overrides=data.get("config_overrides", {}),
        video_path=data.get("video_path", ""),
        profile_name=data.get("profile_name", ""),
    )

    logger.info("Checkpoint saved to DB: %s/%s (id=%s)", job_id, stage, checkpoint.id)
    return str(checkpoint.id)


def _save_to_s3(job_id: str, stage: str, data: dict) -> str:
    """Fallback: save checkpoint as JSON to S3 (before modelgen generates DB models)."""
    from core.storage.s3 import upload_file

    checkpoint_json = json.dumps(data, default=str)
    key = f"{CHECKPOINT_PREFIX}/{job_id}/stages/{stage}.json"

    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp:
        tmp.write(checkpoint_json)
        tmp_path = tmp.name

    try:
        upload_file(tmp_path, BUCKET, key)
    finally:
        os.unlink(tmp_path)

    logger.info("Checkpoint saved to S3: s3://%s/%s", BUCKET, key)
    return key


# ---------------------------------------------------------------------------
# Load
# ---------------------------------------------------------------------------

def load_checkpoint(job_id: str, stage: str) -> dict:
    """
    Load a stage checkpoint and reconstitute full DetectState.

    Tries Postgres first, falls back to S3 JSON.
    """
    if _has_db():
        data = _load_from_db(job_id, stage)
    else:
        data = _load_from_s3(job_id, stage)

    raw_manifest = data.get("frames_manifest", {})
    manifest = {int(k): v for k, v in raw_manifest.items()}
    frame_metadata = data.get("frames_meta", [])
    frames = load_frames(manifest, frame_metadata)

    state = deserialize_state(data, frames)

    logger.info("Checkpoint loaded: %s/%s (%d frames)", job_id, stage, len(frames))
    return state


def _load_from_db(job_id: str, stage: str) -> dict:
    """Load checkpoint data from Postgres via core/db."""
    from core.db.detect import get_stage_checkpoint

    checkpoint = get_stage_checkpoint(job_id, stage)

    data = {
        "job_id": str(checkpoint.job_id),
        "video_path": checkpoint.video_path,
        "profile_name": checkpoint.profile_name,
        "config_overrides": checkpoint.config_overrides,
        "frames_manifest": checkpoint.frames_manifest,
        "frames_meta": checkpoint.frames_meta,
        "filtered_frame_sequences": checkpoint.filtered_frame_sequences,
        "boxes_by_frame": checkpoint.boxes_by_frame,
        "text_candidates": checkpoint.text_candidates,
        "unresolved_candidates": checkpoint.unresolved_candidates,
        "detections": checkpoint.detections,
        "stats": checkpoint.stats,
    }
    return data


def _load_from_s3(job_id: str, stage: str) -> dict:
    """Fallback: load checkpoint JSON from S3."""
    from core.storage.s3 import download_to_temp

    key = f"{CHECKPOINT_PREFIX}/{job_id}/stages/{stage}.json"
    tmp_path = download_to_temp(BUCKET, key)
    try:
        with open(tmp_path) as f:
            data = json.load(f)
    finally:
        os.unlink(tmp_path)

    return data


# ---------------------------------------------------------------------------
# List
# ---------------------------------------------------------------------------

def list_checkpoints(job_id: str) -> list[str]:
    """List available checkpoint stages for a job."""
    if _has_db():
        return _list_from_db(job_id)
    return _list_from_s3(job_id)


def _list_from_db(job_id: str) -> list[str]:
    from core.db.detect import list_stage_checkpoints
    return list_stage_checkpoints(job_id)


def _list_from_s3(job_id: str) -> list[str]:
    from core.storage.s3 import list_objects

    prefix = f"{CHECKPOINT_PREFIX}/{job_id}/stages/"
    objects = list_objects(BUCKET, prefix)

    stages = []
    for obj in objects:
        name = Path(obj["key"]).stem
        stages.append(name)

    return stages
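
Manual save/load cycle as a sketch (the graph wrapper further down does this automatically when MPR_CHECKPOINT=1); the job id is illustrative, and stage_index 3 matches run_ocr's position in NODES:

    from detect.checkpoint import save_checkpoint, load_checkpoint, list_checkpoints

    ckpt_id = save_checkpoint("job-123", "run_ocr", stage_index=3, state=state)
    print(list_checkpoints("job-123"))       # e.g. ["extract_frames", ..., "run_ocr"]
    restored_state = load_checkpoint("job-123", "run_ocr")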

detect/checkpoint/tasks.py (new file)
@@ -0,0 +1,71 @@
"""
Celery tasks for detection pipeline async operations.

retry_candidates: re-run VLM/cloud escalation with different config.
"""

from __future__ import annotations

import logging
import uuid

from celery import shared_task

logger = logging.getLogger(__name__)


@shared_task(bind=True, max_retries=1, default_retry_delay=30)
def retry_candidates(
    self,
    job_id: str,
    config_overrides: dict | None = None,
    start_stage: str = "escalate_vlm",
):
    """
    Retry unresolved candidates with different config.

    Loads the checkpoint from the stage before start_stage,
    applies config overrides (e.g. different cloud provider),
    and runs from start_stage onward.
    """
    from detect.checkpoint.replay import replay_from

    run_id = str(uuid.uuid4())[:8]
    logger.info("Retry task %s: job=%s, from=%s, overrides=%s",
                run_id, job_id, start_stage, config_overrides)

    try:
        result = replay_from(
            job_id=job_id,
            start_stage=start_stage,
            config_overrides=config_overrides,
        )

        detections = result.get("detections", [])
        report = result.get("report")
        brands_found = len(report.brands) if report else 0

        logger.info("Retry %s complete: %d detections, %d brands",
                    run_id, len(detections), brands_found)

        return {
            "status": "completed",
            "run_id": run_id,
            "job_id": job_id,
            "detections": len(detections),
            "brands_found": brands_found,
        }

    except Exception as e:
        logger.exception("Retry %s failed: %s", run_id, e)

        if self.request.retries < self.max_retries:
            raise self.retry(exc=e)

        return {
            "status": "failed",
            "run_id": run_id,
            "job_id": job_id,
            "error": str(e),
        }
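
Dispatch sketch, e.g. from an API handler or a shell (job id illustrative, not part of the commit):

    from detect.checkpoint.tasks import retry_candidates

    async_result = retry_candidates.delay(
        job_id="job-123",
        config_overrides={"resolver": {"fuzzy_threshold": 60}},
        start_stage="escalate_cloud",
    )
    print(async_result.get(timeout=600))     # {"status": "completed", ...}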
detect/emit.py (modified)

@@ -3,6 +3,9 @@ Event emission helpers for detection pipeline stages.
 Single place that knows how to build event payloads.
 Stages call these instead of constructing dicts or dataclasses directly.
+
+Run context (run_id, parent_job_id) is set once at pipeline start via
+set_run_context() and automatically injected into all events.
 """

 from __future__ import annotations
@@ -13,9 +16,33 @@ from datetime import datetime, timezone
 from detect.events import push_detect_event
 from detect.models import PipelineStats

+# Module-level run context — set once per pipeline invocation
+_run_context: dict = {}
+
+
+def set_run_context(run_id: str = "", parent_job_id: str = "", run_type: str = "initial"):
+    """Set the run context for all subsequent events in this pipeline invocation."""
+    global _run_context
+    _run_context = {
+        "run_id": run_id,
+        "parent_job_id": parent_job_id,
+        "run_type": run_type,
+    }
+
+
+def clear_run_context():
+    global _run_context
+    _run_context = {}
+
+
+def _inject_context(payload: dict) -> dict:
+    """Add run context fields to an event payload."""
+    if _run_context:
+        payload.update(_run_context)
+    return payload
+
+
 def log(job_id: str | None, stage: str, level: str, msg: str) -> None:
     """Emit a log event."""
     if not job_id:
         return
     payload = {
@@ -24,15 +51,17 @@ def log(job_id: str | None, stage: str, level: str, msg: str) -> None:
         "msg": msg,
         "ts": datetime.now(timezone.utc).isoformat(),
     }
+    _inject_context(payload)
     push_detect_event(job_id, "log", payload)


 def stats(job_id: str | None, **kwargs) -> None:
     """Emit a stats_update event. Pass only the fields that changed."""
     if not job_id:
         return
     s = PipelineStats(**kwargs)
-    push_detect_event(job_id, "stats_update", dataclasses.asdict(s))
+    payload = dataclasses.asdict(s)
+    _inject_context(payload)
+    push_detect_event(job_id, "stats_update", payload)


 def frame_update(
@@ -42,7 +71,6 @@ def frame_update(
     jpeg_b64: str,
     boxes: list[dict],
 ) -> None:
     """Emit a frame_update event with the image and bounding boxes."""
     if not job_id:
         return
     payload = {
@@ -51,14 +79,15 @@ def frame_update(
         "jpeg_b64": jpeg_b64,
         "boxes": boxes,
     }
+    _inject_context(payload)
     push_detect_event(job_id, "frame_update", payload)


 def graph_update(job_id: str | None, nodes: list[dict]) -> None:
     """Emit a graph_update event with node states."""
     if not job_id:
         return
     payload = {"nodes": nodes}
+    _inject_context(payload)
     push_detect_event(job_id, "graph_update", payload)


@@ -72,7 +101,6 @@ def detection(
     content_type: str = "",
     frame_ref: int | None = None,
 ) -> None:
     """Emit a brand detection event."""
     if not job_id:
         return
     payload = {
@@ -84,12 +112,13 @@ def detection(
         "content_type": content_type,
         "frame_ref": frame_ref,
     }
+    _inject_context(payload)
     push_detect_event(job_id, "detection", payload)


 def job_complete(job_id: str | None, report: dict) -> None:
     """Emit a job_complete event with the final report."""
     if not job_id:
         return
     payload = {"job_id": job_id, "report": report}
+    _inject_context(payload)
     push_detect_event(job_id, "job_complete", payload)
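
Lifecycle sketch of the new run context (replay_from does exactly this; ids are illustrative):

    from detect import emit

    emit.set_run_context(run_id="a1b2c3d4", parent_job_id="job-123", run_type="replay")
    try:
        emit.log("job-123", "BrandResolver", "INFO", "resolving candidates")
        # every payload now carries run_id / parent_job_id / run_type
    finally:
        emit.clear_run_context()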

detect/graph.py (modified)
@@ -42,8 +42,16 @@ NODES = [
 def _get_profile(state: DetectState):
     name = state.get("profile_name", "soccer_broadcast")
     if name == "soccer_broadcast":
-        return SoccerBroadcastProfile()
-    raise ValueError(f"Unknown profile: {name}")
+        profile = SoccerBroadcastProfile()
+    else:
+        raise ValueError(f"Unknown profile: {name}")
+
+    overrides = state.get("config_overrides")
+    if overrides:
+        from detect.checkpoint.replay import OverrideProfile
+        profile = OverrideProfile(profile, overrides)
+
+    return profile


 # Track node states across the pipeline run
@@ -68,6 +76,18 @@
 # --- Node functions ---

 def node_extract_frames(state: DetectState) -> dict:
+    # Set run context for initial runs (replays set it in replay_from)
+    job_id = state.get("job_id", "")
+    if job_id and not emit._run_context:
+        emit.set_run_context(run_id=job_id, parent_job_id=job_id, run_type="initial")
+
+    # Load session brands from DB for this source
+    source_asset_id = state.get("source_asset_id")
+    if source_asset_id and not state.get("session_brands"):
+        from detect.stages.brand_resolver import build_session_dict
+        session_brands = build_session_dict(source_asset_id)
+        state["session_brands"] = session_brands
+
     _emit_transition(state, "extract_frames", "running")

     with trace_node(state, "extract_frames") as span:
@@ -142,13 +162,16 @@ def node_match_brands(state: DetectState) -> dict:

     with trace_node(state, "match_brands") as span:
         profile = _get_profile(state)
-        dictionary = profile.brand_dictionary()
         resolver_config = profile.resolver_config()
         candidates = state.get("text_candidates", [])
+        session_brands = state.get("session_brands", {})
         job_id = state.get("job_id")
+        source_asset_id = state.get("source_asset_id")

         matched, unresolved = resolve_brands(
-            candidates, dictionary, resolver_config,
+            candidates, resolver_config,
+            session_brands=session_brands,
+            source_asset_id=source_asset_id,
             content_type=profile.name, job_id=job_id,
         )
         span.set_output({"matched": len(matched), "unresolved": len(unresolved)})
@@ -170,6 +193,7 @@ def node_escalate_vlm(state: DetectState) -> dict:
         vlm_prompt_fn=profile.vlm_prompt,
         inference_url=INFERENCE_URL,
         content_type=profile.name,
+        source_asset_id=state.get("source_asset_id"),
         job_id=job_id,
     )
@@ -202,6 +226,7 @@ def node_escalate_cloud(state: DetectState) -> dict:
         vlm_prompt_fn=profile.vlm_prompt,
         stats=stats,
         content_type=profile.name,
+        source_asset_id=state.get("source_asset_id"),
         job_id=job_id,
     )
@@ -239,33 +264,87 @@ def node_compile_report(state: DetectState) -> dict:
     return {"report": report}


+# --- Checkpoint wrapper ---
+
+_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
+_frames_manifest: dict[str, dict[int, str]] = {}  # job_id → manifest (cached per job)
+
+
+def _checkpointing_node(node_name: str, node_fn):
+    """Wrap a node function to auto-checkpoint after completion."""
+    stage_index = NODES.index(node_name)
+
+    def wrapper(state: DetectState) -> dict:
+        result = node_fn(state)
+
+        job_id = state.get("job_id", "")
+        if not job_id:
+            return result
+
+        from detect.checkpoint import save_checkpoint, save_frames
+
+        merged = {**state, **result}
+
+        # Save frames once (first checkpoint), reuse manifest after
+        manifest = _frames_manifest.get(job_id)
+        if manifest is None and node_name == "extract_frames":
+            manifest = save_frames(job_id, merged.get("frames", []))
+            _frames_manifest[job_id] = manifest
+
+        save_checkpoint(job_id, node_name, stage_index, merged, frames_manifest=manifest)
+        return result

+    wrapper.__name__ = node_fn.__name__
+    return wrapper
+
+
 # --- Graph construction ---

-def build_graph() -> StateGraph:
+NODE_FUNCTIONS = [
+    ("extract_frames", node_extract_frames),
+    ("filter_scenes", node_filter_scenes),
+    ("detect_objects", node_detect_objects),
+    ("run_ocr", node_run_ocr),
+    ("match_brands", node_match_brands),
+    ("escalate_vlm", node_escalate_vlm),
+    ("escalate_cloud", node_escalate_cloud),
+    ("compile_report", node_compile_report),
+]
+
+
+def build_graph(checkpoint: bool | None = None, start_from: str | None = None) -> StateGraph:
+    """
+    Build the pipeline graph.
+
+    checkpoint: enable auto-checkpointing (default: MPR_CHECKPOINT env var)
+    start_from: skip nodes before this stage (for replay)
+    """
+    do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
+
     graph = StateGraph(DetectState)

-    graph.add_node("extract_frames", node_extract_frames)
-    graph.add_node("filter_scenes", node_filter_scenes)
-    graph.add_node("detect_objects", node_detect_objects)
-    graph.add_node("run_ocr", node_run_ocr)
-    graph.add_node("match_brands", node_match_brands)
-    graph.add_node("escalate_vlm", node_escalate_vlm)
-    graph.add_node("escalate_cloud", node_escalate_cloud)
-    graph.add_node("compile_report", node_compile_report)
-    graph.set_entry_point("extract_frames")
-    graph.add_edge("extract_frames", "filter_scenes")
-    graph.add_edge("filter_scenes", "detect_objects")
-    graph.add_edge("detect_objects", "run_ocr")
-    graph.add_edge("run_ocr", "match_brands")
-    graph.add_edge("match_brands", "escalate_vlm")
-    graph.add_edge("escalate_vlm", "escalate_cloud")
-    graph.add_edge("escalate_cloud", "compile_report")
-    graph.add_edge("compile_report", END)
+    # Filter to start_from if replaying
+    node_pairs = NODE_FUNCTIONS
+    if start_from:
+        start_idx = next(i for i, (name, _) in enumerate(NODE_FUNCTIONS) if name == start_from)
+        node_pairs = NODE_FUNCTIONS[start_idx:]
+
+    for name, fn in node_pairs:
+        wrapped = _checkpointing_node(name, fn) if do_checkpoint else fn
+        graph.add_node(name, wrapped)
+
+    # Wire edges
+    entry = node_pairs[0][0]
+    graph.set_entry_point(entry)
+
+    for i in range(len(node_pairs) - 1):
+        graph.add_edge(node_pairs[i][0], node_pairs[i + 1][0])
+
+    graph.add_edge(node_pairs[-1][0], END)

     return graph


-def get_pipeline():
+def get_pipeline(checkpoint: bool | None = None):
     """Return a compiled, runnable pipeline."""
-    return build_graph().compile()
+    return build_graph(checkpoint=checkpoint).compile()
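
Build sketch (not part of the commit): a full checkpointed run versus a replay subgraph whose entry point is match_brands. Path and ids are illustrative:

    from detect.graph import get_pipeline, build_graph

    pipeline = get_pipeline(checkpoint=True)     # every node checkpoints after it runs
    final_state = pipeline.invoke({
        "job_id": "job-123",
        "video_path": "match.mp4",
        "profile_name": "soccer_broadcast",
    })

    # Replay: nodes before match_brands are simply not added to the graph.
    subgraph = build_graph(checkpoint=False, start_from="match_brands").compile()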
detect/profiles/__init__.py (modified)

@@ -1,6 +1,5 @@
 from .base import (
     ContentTypeProfile,
-    BrandDictionary,
     CropContext,
     DetectionConfig,
     FrameExtractionConfig,
@@ -12,7 +11,6 @@ from .soccer import SoccerBroadcastProfile

 __all__ = [
     "ContentTypeProfile",
-    "BrandDictionary",
     "CropContext",
     "DetectionConfig",
     "FrameExtractionConfig",
detect/profiles/base.py (modified)

@@ -44,12 +44,6 @@ class ResolverConfig:
     fuzzy_threshold: int = 75


-@dataclass
-class BrandDictionary:
-    """Maps canonical brand name → list of known aliases/spellings."""
-    brands: dict[str, list[str]] = field(default_factory=dict)
-
-
 @dataclass
 class CropContext:
     image: bytes
@@ -64,7 +58,6 @@ class ContentTypeProfile(Protocol):
     def scene_filter_config(self) -> SceneFilterConfig: ...
     def detection_config(self) -> DetectionConfig: ...
     def ocr_config(self) -> OCRConfig: ...
-    def brand_dictionary(self) -> BrandDictionary: ...
     def resolver_config(self) -> ResolverConfig: ...
     def vlm_prompt(self, crop_context: CropContext) -> str: ...
     def aggregate(self, detections: list[BrandDetection]) -> DetectionReport: ...
detect/profiles/soccer.py (modified)

@@ -5,7 +5,6 @@ from __future__ import annotations
 from detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats

 from .base import (
-    BrandDictionary,
     CropContext,
     DetectionConfig,
     FrameExtractionConfig,
@@ -34,22 +33,6 @@ class SoccerBroadcastProfile:
     def ocr_config(self) -> OCRConfig:
         return OCRConfig(languages=["en", "es"], min_confidence=0.5)

-    def brand_dictionary(self) -> BrandDictionary:
-        return BrandDictionary(brands={
-            "Nike": ["nike", "NIKE", "swoosh"],
-            "Adidas": ["adidas", "ADIDAS", "adi"],
-            "Puma": ["puma", "PUMA"],
-            "Emirates": ["emirates", "fly emirates", "EMIRATES"],
-            "Coca-Cola": ["coca-cola", "coca cola", "coke", "COCA-COLA"],
-            "Pepsi": ["pepsi", "PEPSI"],
-            "Mastercard": ["mastercard", "MASTERCARD"],
-            "Heineken": ["heineken", "HEINEKEN"],
-            "Santander": ["santander", "SANTANDER"],
-            "Gazprom": ["gazprom", "GAZPROM"],
-            "Qatar Airways": ["qatar airways", "QATAR AIRWAYS"],
-            "Lay's": ["lays", "lay's", "LAYS", "LAY'S"],
-        })
-
     def resolver_config(self) -> ResolverConfig:
         return ResolverConfig(fuzzy_threshold=75)
@@ -5,7 +5,6 @@ from __future__ import annotations
 from detect.models import BrandDetection, DetectionReport

 from .base import (
-    BrandDictionary,
     CropContext,
     DetectionConfig,
     FrameExtractionConfig,
@@ -30,9 +29,6 @@ class NewsBroadcastProfile:
     def ocr_config(self) -> OCRConfig:
         raise NotImplementedError

-    def brand_dictionary(self) -> BrandDictionary:
-        raise NotImplementedError
-
     def resolver_config(self) -> ResolverConfig:
         raise NotImplementedError
@@ -61,9 +57,6 @@ class AdvertisingProfile:
     def ocr_config(self) -> OCRConfig:
         raise NotImplementedError

-    def brand_dictionary(self) -> BrandDictionary:
-        raise NotImplementedError
-
     def resolver_config(self) -> ResolverConfig:
         raise NotImplementedError
@@ -92,9 +85,6 @@ class TranscriptProfile:
     def ocr_config(self) -> OCRConfig:
         raise NotImplementedError

-    def brand_dictionary(self) -> BrandDictionary:
-        raise NotImplementedError
-
     def resolver_config(self) -> ResolverConfig:
         raise NotImplementedError
detect/stages/brand_resolver.py (modified)

@@ -1,9 +1,17 @@
 """
-Stage 5 — Brand Resolver
+Stage 5 — Brand Resolver (discovery mode)

-Matches OCR text against the profile's brand dictionary.
-Uses exact matching first, then fuzzy matching (rapidfuzz) as fallback.
-Emits detection events for confirmed brands.
+Discovery-first brand matching. No static dictionary — all brands live in the DB.
+
+Flow:
+1. Check session sightings first (brands already seen in this source)
+2. Check global known brands (accumulated across all runs)
+3. Unresolved candidates → escalate to VLM/cloud
+4. Confirmed brands get added to DB for future runs
+
+The resolver is an enricher, not a gatekeeper. Every OCR text candidate
+passes through — the question is whether we can resolve it cheaply (DB lookup)
+or need to escalate (VLM/cloud).
 """

 from __future__ import annotations
@@ -14,99 +22,199 @@ from rapidfuzz import fuzz

 from detect import emit
 from detect.models import BrandDetection, TextCandidate
-from detect.profiles.base import BrandDictionary, ResolverConfig
+from detect.profiles.base import ResolverConfig

 logger = logging.getLogger(__name__)


 def _normalize(text: str) -> str:
     """Normalize text for matching."""
     return text.strip().lower()


-def _exact_match(text: str, dictionary: BrandDictionary) -> str | None:
-    """Try exact match against all aliases."""
+def _has_db() -> bool:
+    try:
+        from core.db.detect import find_brand_by_text as _
+        from admin.mpr.media_assets.models import KnownBrand as _
+        return True
+    except Exception:
+        return False
+
+
+def _match_session(text: str, session_brands: dict[str, str]) -> str | None:
+    """
+    Check against session brands (already seen in this source).
+
+    session_brands: {normalized_name: canonical_name, ...}
+    Includes aliases.
+    """
     normalized = _normalize(text)
-    for canonical, aliases in dictionary.brands.items():
-        if normalized == _normalize(canonical):
-            return canonical
-        for alias in aliases:
-            if normalized == _normalize(alias):
-                return canonical
-    return None
+    return session_brands.get(normalized)


-def _fuzzy_match(text: str, dictionary: BrandDictionary, threshold: int) -> tuple[str | None, int]:
-    """Try fuzzy match, return (brand, score) or (None, 0)."""
+def _match_known(text: str, threshold: int) -> tuple[str | None, str | None]:
+    """
+    Check against global known brands in DB.
+
+    Returns (canonical_name, brand_id) or (None, None).
+    """
+    if not _has_db():
+        return None, None
+
+    from core.db.detect import find_brand_by_text
+    brand = find_brand_by_text(text)
+    if brand:
+        return brand.canonical_name, str(brand.id)
+
+    # Fuzzy match against all known brands
+    from core.db.detect import list_all_brands
+    all_brands = list_all_brands()
+
     normalized = _normalize(text)
     best_brand = None
     best_score = 0

-    for canonical, aliases in dictionary.brands.items():
-        all_names = [canonical] + aliases
-        for name in all_names:
+    for known in all_brands:
+        names = [known.canonical_name] + (known.aliases or [])
+        for name in names:
             score = fuzz.ratio(normalized, _normalize(name))
             if score > best_score and score >= threshold:
                 best_score = score
-                best_brand = canonical
+                best_brand = known

-    return best_brand, best_score
+    if best_brand:
+        return best_brand.canonical_name, str(best_brand.id)
+
+    return None, None
+
+
+def _register_brand(canonical_name: str, source: str) -> str | None:
+    """Register a newly discovered brand in the DB. Returns brand_id."""
+    if not _has_db():
+        return None
+
+    from core.db.detect import get_or_create_brand
+    brand, created = get_or_create_brand(canonical_name, source=source)
+    if created:
+        logger.info("New brand discovered: %s (source=%s)", canonical_name, source)
+    return str(brand.id)
+
+
+def _record_sighting(source_asset_id: str | None, brand_id: str,
+                     brand_name: str, timestamp: float,
+                     confidence: float, source: str):
+    """Record a brand sighting for this source."""
+    if not _has_db() or not source_asset_id:
+        return
+
+    from core.db.detect import record_sighting
+    import uuid
+    asset_id = uuid.UUID(source_asset_id) if isinstance(source_asset_id, str) else source_asset_id
+    brand_uuid = uuid.UUID(brand_id) if isinstance(brand_id, str) else brand_id
+    record_sighting(asset_id, brand_uuid, brand_name, timestamp, confidence, source)
+
+
+def build_session_dict(source_asset_id: str | None) -> dict[str, str]:
+    """
+    Load session brands from DB for this source.
+
+    Returns {normalized_name: canonical_name, ...} including aliases.
+    """
+    if not _has_db() or not source_asset_id:
+        return {}
+
+    from core.db.detect import get_source_sightings
+    import uuid
+
+    asset_id = uuid.UUID(source_asset_id) if isinstance(source_asset_id, str) else source_asset_id
+    sightings = get_source_sightings(asset_id)
+
+    session = {}
+    for s in sightings:
+        canonical = s.brand_name
+        session[_normalize(canonical)] = canonical
+
+    # Also load aliases from KnownBrand for each sighted brand
+    if _has_db():
+        from core.db.detect import list_all_brands
+        all_brands = list_all_brands()
+        sighted_names = {s.brand_name for s in sightings}
+        for brand in all_brands:
+            if brand.canonical_name in sighted_names:
+                for alias in (brand.aliases or []):
+                    session[_normalize(alias)] = brand.canonical_name
+
+    return session


 def resolve_brands(
     candidates: list[TextCandidate],
-    dictionary: BrandDictionary,
     config: ResolverConfig,
+    session_brands: dict[str, str] | None = None,
+    source_asset_id: str | None = None,
     content_type: str = "",
     job_id: str | None = None,
 ) -> tuple[list[BrandDetection], list[TextCandidate]]:
     """
-    Match text candidates against the brand dictionary.
+    Match text candidates against known brands (session → global → unresolved).

-    Returns:
-    - matched: list of BrandDetection for confirmed brands
-    - unresolved: list of TextCandidate that couldn't be matched
+    session_brands: pre-loaded session dict (from build_session_dict)
+    source_asset_id: for recording new sightings in DB
     """
+    if session_brands is None:
+        session_brands = {}
+
     emit.log(job_id, "BrandResolver", "INFO",
-             f"Matching {len(candidates)} candidates against "
-             f"{len(dictionary.brands)} brands (fuzzy_threshold={config.fuzzy_threshold})")
+             f"Resolving {len(candidates)} candidates "
+             f"(session={len(session_brands)} brands, fuzzy={config.fuzzy_threshold})")

     matched: list[BrandDetection] = []
     unresolved: list[TextCandidate] = []
-    exact_count = 0
-    fuzzy_count = 0
+    session_hits = 0
+    known_hits = 0

     for candidate in candidates:
-        # Try exact match first
-        brand = _exact_match(candidate.text, dictionary)
-        source = "ocr"
+        text = candidate.text
+        brand_name = None
+        brand_id = None
+        match_source = "ocr"

-        if brand:
-            exact_count += 1
+        # 1. Check session (cheapest — in-memory dict)
+        brand_name = _match_session(text, session_brands)
+        if brand_name:
+            session_hits += 1
         else:
-            # Try fuzzy match
-            brand, score = _fuzzy_match(candidate.text, dictionary, config.fuzzy_threshold)
-            if brand:
-                fuzzy_count += 1
+            # 2. Check global known brands (DB query + fuzzy)
+            brand_name, brand_id = _match_known(text, config.fuzzy_threshold)
+            if brand_name:
+                known_hits += 1
+                # Add to session for subsequent candidates in this run
+                session_brands[_normalize(brand_name)] = brand_name

-        if brand:
+        if brand_name:
             detection = BrandDetection(
-                brand=brand,
+                brand=brand_name,
                 timestamp=candidate.frame.timestamp,
                 duration=0.5,
                 confidence=candidate.ocr_confidence,
-                source=source,
+                source=match_source,
                 bbox=candidate.bbox,
                 frame_ref=candidate.frame.sequence,
                 content_type=content_type,
             )
             matched.append(detection)

+            # Record sighting in DB
+            if brand_id:
+                _record_sighting(
+                    source_asset_id, brand_id, brand_name,
+                    candidate.frame.timestamp, candidate.ocr_confidence, match_source,
+                )
+
             emit.detection(
                 job_id,
-                brand=brand,
+                brand=brand_name,
                 confidence=candidate.ocr_confidence,
-                source=source,
+                source=match_source,
                 timestamp=candidate.frame.timestamp,
                 content_type=content_type,
                 frame_ref=candidate.frame.sequence,
@@ -115,7 +223,7 @@ def resolve_brands(
             unresolved.append(candidate)

     emit.log(job_id, "BrandResolver", "INFO",
-             f"Exact: {exact_count}, Fuzzy: {fuzzy_count}, "
-             f"Unresolved: {len(unresolved)} → escalating to VLM")
+             f"Session: {session_hits}, Known: {known_hits}, "
+             f"Unresolved: {len(unresolved)} → escalating")

     return matched, unresolved
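
Resolution-order sketch (not part of the commit), assuming a session dict preloaded by build_session_dict and candidates produced by the upstream OCR stage:

    from detect.stages.brand_resolver import build_session_dict, resolve_brands

    session = build_session_dict(source_asset_id)   # e.g. {"heineken": "Heineken", ...}
    matched, unresolved = resolve_brands(
        candidates, resolver_config,
        session_brands=session,
        source_asset_id=source_asset_id,
        content_type="soccer_broadcast",
        job_id=job_id,
    )
    # matched → detections recorded as sightings; unresolved → VLM/cloud escalation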
@@ -27,6 +27,18 @@ logger = logging.getLogger(__name__)
 ESTIMATED_TOKENS_PER_CROP = 500


+def _register_discovered_brand(brand: str, source_asset_id: str | None,
+                               timestamp: float, confidence: float):
+    """Register a cloud-confirmed brand in the DB."""
+    try:
+        from detect.stages.brand_resolver import _register_brand, _record_sighting
+        brand_id = _register_brand(brand, "cloud_llm")
+        if brand_id and source_asset_id:
+            _record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, "cloud_llm")
+    except Exception as e:
+        logger.debug("Failed to register brand %s: %s", brand, e)
+
+
 def _encode_crop(crop: np.ndarray) -> str:
     img = Image.fromarray(crop)
     buf = io.BytesIO()
@@ -84,6 +96,7 @@ def escalate_cloud(
     stats: PipelineStats,
     min_confidence: float = 0.4,
     content_type: str = "",
+    source_asset_id: str | None = None,
     job_id: str | None = None,
 ) -> list[BrandDetection]:
     """
@@ -158,6 +171,10 @@ def escalate_cloud(
                 frame_ref=candidate.frame.sequence,
             )

+            # Register newly discovered brand in DB
+            _register_discovered_brand(brand, source_asset_id,
+                                       candidate.frame.timestamp, confidence)
+
     stats.estimated_cloud_cost_usd += total_cost
     stats.regions_escalated_to_cloud_llm = len(candidates)
@@ -19,6 +19,18 @@ from detect.profiles.base import CropContext
 logger = logging.getLogger(__name__)


+def _register_discovered_brand(brand: str, source_asset_id: str | None,
+                               timestamp: float, confidence: float, source: str):
+    """Register a VLM-confirmed brand in the DB."""
+    try:
+        from detect.stages.brand_resolver import _register_brand, _record_sighting
+        brand_id = _register_brand(brand, source)
+        if brand_id and source_asset_id:
+            _record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, source)
+    except Exception as e:
+        logger.debug("Failed to register brand %s: %s", brand, e)
+
+
 def _crop_image(candidate: TextCandidate) -> np.ndarray:
     frame = candidate.frame
     box = candidate.bbox
@@ -36,6 +48,7 @@ def escalate_vlm(
     inference_url: str | None = None,
     min_confidence: float = 0.5,
     content_type: str = "",
+    source_asset_id: str | None = None,
     job_id: str | None = None,
 ) -> tuple[list[BrandDetection], list[TextCandidate]]:
     """
@@ -107,6 +120,10 @@ def escalate_vlm(
                 frame_ref=candidate.frame.sequence,
             )

+            # Register newly discovered brand in DB
+            _register_discovered_brand(brand, source_asset_id,
+                                       candidate.frame.timestamp, confidence, "local_vlm")
+
             logger.debug("VLM matched: %s (%.2f) — %s", brand, confidence, reasoning)
         else:
             still_unresolved.append(candidate)
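
Both escalation stages close the discovery loop described in the resolver docstring: a confirmed hit is written back through the resolver's helpers, so later runs resolve the same text from the session or the KnownBrand table instead of escalating again. A sketch of the write-back for a single VLM hit, with illustrative values:

    brand, confidence = "Heineken", 0.82             # parsed from the VLM response
    _register_discovered_brand(brand, source_asset_id,
                               candidate.frame.timestamp, confidence, "local_vlm")
    # Next run: build_session_dict() includes "heineken" → "Heineken",
    # so the resolver matches it from the session dict without escalating.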
@@ -17,6 +17,7 @@ class DetectState(TypedDict, total=False):
     video_path: str
     job_id: str
     profile_name: str
+    source_asset_id: str  # UUID of the source MediaAsset

     # Stage outputs
     frames: list[Frame]
@@ -27,5 +28,11 @@ class DetectState(TypedDict, total=False):
     detections: list[BrandDetection]
     report: DetectionReport

+    # Session brands (accumulated during the run, persisted to DB)
+    session_brands: dict  # {normalized_name: canonical_name}
+
     # Running stats (updated by each stage)
     stats: PipelineStats
+
+    # Config overrides for replay (applied via OverrideProfile)
+    config_overrides: dict