This commit is contained in:
2026-03-26 04:24:32 -03:00
parent 08b67f2bb7
commit 08c58a6a9d
43 changed files with 2627 additions and 252 deletions

121
core/api/detect_replay.py Normal file
View File

@@ -0,0 +1,121 @@
"""
API endpoints for checkpoint inspection, replay, and retry.
GET /detect/checkpoints/{job_id} — list available checkpoints
POST /detect/replay — replay from a stage with config overrides
POST /detect/retry — queue async retry with different provider
"""
from __future__ import annotations
import logging
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/detect", tags=["detect"])
# --- Request/Response models ---
class CheckpointInfo(BaseModel):
    """A single checkpoint stage available for replay."""
    stage: str


class ReplayRequest(BaseModel):
    """Synchronous replay: re-run the pipeline from ``start_stage``."""
    job_id: str
    start_stage: str
    config_overrides: dict | None = None  # per-stage config patches (see OverrideProfile)


class ReplayResponse(BaseModel):
    """Summary of a completed replay run."""
    status: str
    job_id: str
    start_stage: str
    detections: int = 0
    brands_found: int = 0


class RetryRequest(BaseModel):
    """Asynchronous retry: queue a Celery task to re-run from ``start_stage``."""
    job_id: str
    config_overrides: dict | None = None
    start_stage: str = "escalate_vlm"
    schedule_seconds: float | None = None  # delay before execution (off-peak)


class RetryResponse(BaseModel):
    """Acknowledgement that a retry task was queued."""
    status: str
    task_id: str
    job_id: str
# --- Endpoints ---
@router.get("/checkpoints/{job_id}")
def list_checkpoints(job_id: str) -> list[CheckpointInfo]:
    """List available checkpoint stages for a job."""
    # Import deferred so the router can load before the detect package.
    from detect.checkpoint import list_checkpoints as _list
    try:
        stage_names = _list(job_id)
    except Exception as e:
        raise HTTPException(status_code=404, detail=f"No checkpoints for job {job_id}: {e}")
    return [CheckpointInfo(stage=name) for name in stage_names]
@router.post("/replay", response_model=ReplayResponse)
def replay(req: ReplayRequest):
    """Replay pipeline from a specific stage with optional config overrides."""
    from detect.checkpoint import replay_from
    try:
        result = replay_from(
            job_id=req.job_id,
            start_stage=req.start_stage,
            config_overrides=req.config_overrides,
        )
    except ValueError as e:
        # Bad stage name / missing checkpoint → client error.
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Replay failed: {e}")
    report = result.get("report")
    return ReplayResponse(
        status="completed",
        job_id=req.job_id,
        start_stage=req.start_stage,
        detections=len(result.get("detections", [])),
        brands_found=len(report.brands) if report else 0,
    )
@router.post("/retry", response_model=RetryResponse)
def retry(req: RetryRequest):
    """Queue an async retry of unresolved candidates with different config."""
    from detect.checkpoint.tasks import retry_candidates
    task_kwargs = {
        "job_id": req.job_id,
        "config_overrides": req.config_overrides,
        "start_stage": req.start_stage,
    }
    if req.schedule_seconds:
        # Off-peak scheduling: delay execution by a countdown.
        task = retry_candidates.apply_async(kwargs=task_kwargs, countdown=req.schedule_seconds)
    else:
        task = retry_candidates.delay(**task_kwargs)
    return RetryResponse(status="queued", task_id=task.id, job_id=req.job_id)

View File

@@ -25,6 +25,7 @@ from strawberry.fastapi import GraphQLRouter
from core.api.chunker_sse import router as chunker_router from core.api.chunker_sse import router as chunker_router
from core.api.detect_sse import router as detect_router from core.api.detect_sse import router as detect_router
from core.api.detect_replay import router as detect_replay_router
from core.api.graphql import schema as graphql_schema from core.api.graphql import schema as graphql_schema
CALLBACK_API_KEY = os.environ.get("CALLBACK_API_KEY", "") CALLBACK_API_KEY = os.environ.get("CALLBACK_API_KEY", "")
@@ -56,6 +57,9 @@ app.include_router(chunker_router)
# Detection SSE # Detection SSE
app.include_router(detect_router) app.include_router(detect_router)
# Detection replay/retry
app.include_router(detect_replay_router)
@app.get("/health") @app.get("/health")
def health(): def health():

175
core/db/detect.py Normal file
View File

@@ -0,0 +1,175 @@
"""Database operations for DetectJob and StageCheckpoint."""
from typing import Optional
from uuid import UUID
# ---------------------------------------------------------------------------
# DetectJob
# ---------------------------------------------------------------------------
def create_detect_job(**fields):
    """Insert a new DetectJob row; *fields* map directly to model columns."""
    from admin.mpr.media_assets.models import DetectJob
    return DetectJob.objects.create(**fields)


def get_detect_job(id: UUID):
    """Fetch one DetectJob by primary key (param name shadows builtin id — kept for API compatibility)."""
    from admin.mpr.media_assets.models import DetectJob
    return DetectJob.objects.get(id=id)


def update_detect_job(job_id: UUID, **fields):
    """Bulk-update columns on a DetectJob without loading the instance."""
    from admin.mpr.media_assets.models import DetectJob
    DetectJob.objects.filter(id=job_id).update(**fields)


def list_detect_jobs(
    parent_job_id: Optional[UUID] = None,
    status: Optional[str] = None,
):
    """List DetectJobs, optionally filtered by parent run and/or status."""
    from admin.mpr.media_assets.models import DetectJob
    qs = DetectJob.objects.all()
    # Truthiness check: a UUID instance is always truthy, so this behaves
    # like `is not None` for the parent filter.
    if parent_job_id:
        qs = qs.filter(parent_job_id=parent_job_id)
    if status:
        qs = qs.filter(status=status)
    return list(qs)
# ---------------------------------------------------------------------------
# StageCheckpoint
# ---------------------------------------------------------------------------
def save_stage_checkpoint(**fields):
    """Insert a StageCheckpoint row; *fields* map directly to model columns."""
    from admin.mpr.media_assets.models import StageCheckpoint
    return StageCheckpoint.objects.create(**fields)


def get_stage_checkpoint(job_id: UUID, stage: str):
    """Fetch the checkpoint for (job, stage); missing rows raise via .get()."""
    from admin.mpr.media_assets.models import StageCheckpoint
    return StageCheckpoint.objects.get(job_id=job_id, stage=stage)


def list_stage_checkpoints(job_id: UUID) -> list[str]:
    """Return checkpoint stage names for a job, in pipeline order."""
    from admin.mpr.media_assets.models import StageCheckpoint
    stages = (
        StageCheckpoint.objects
        .filter(job_id=job_id)
        .order_by("stage_index")  # pipeline position, not alphabetical
        .values_list("stage", flat=True)
    )
    return list(stages)


def delete_stage_checkpoints(job_id: UUID):
    """Delete all checkpoint rows for a job (frame JPEGs in S3 are not touched here)."""
    from admin.mpr.media_assets.models import StageCheckpoint
    StageCheckpoint.objects.filter(job_id=job_id).delete()
# ---------------------------------------------------------------------------
# KnownBrand
# ---------------------------------------------------------------------------
def get_or_create_brand(canonical_name: str, aliases: list[str] | None = None,
                        source: str = "ocr") -> tuple:
    """Get existing brand or create new one. Returns (brand, created).

    Matching is case-insensitive: first against canonical_name, then against
    the alias lists of all existing brands.

    NOTE(review): the alias scan loads every KnownBrand row into Python —
    fine for a small table, but should become a jsonb containment query if
    the brand table grows. Also not concurrency-safe: two simultaneous calls
    can both create the brand (no unique constraint visible here).
    """
    from admin.mpr.media_assets.models import KnownBrand
    import uuid
    normalized = canonical_name.strip()
    normalized_lower = normalized.lower()  # hoisted: loop-invariant
    brand = KnownBrand.objects.filter(canonical_name__iexact=normalized).first()
    if brand:
        return brand, False
    # Check aliases of existing brands
    for existing in KnownBrand.objects.all():
        existing_aliases = [a.lower() for a in (existing.aliases or [])]
        if normalized_lower in existing_aliases:
            return existing, False
    brand = KnownBrand.objects.create(
        id=uuid.uuid4(),
        canonical_name=normalized,
        aliases=aliases or [],
        first_source=source,
    )
    return brand, True
def find_brand_by_text(text: str) -> Optional[object]:
    """Find a known brand by canonical name or alias (case-insensitive)."""
    from admin.mpr.media_assets.models import KnownBrand
    normalized = text.strip().lower()
    # Exact canonical match
    match = KnownBrand.objects.filter(canonical_name__iexact=normalized).first()
    if match is not None:
        return match
    # Search aliases (jsonb contains)
    for candidate in KnownBrand.objects.all():
        if normalized in (a.lower() for a in (candidate.aliases or [])):
            return candidate
    return None
def list_all_brands() -> list:
    """Return every KnownBrand, alphabetical by canonical name."""
    from admin.mpr.media_assets.models import KnownBrand
    return list(KnownBrand.objects.all().order_by("canonical_name"))


def update_brand(brand_id: UUID, **fields):
    """Bulk-update columns on one KnownBrand without loading the instance."""
    from admin.mpr.media_assets.models import KnownBrand
    KnownBrand.objects.filter(id=brand_id).update(**fields)
# ---------------------------------------------------------------------------
# SourceBrandSighting
# ---------------------------------------------------------------------------
def get_source_sightings(source_asset_id: UUID) -> list:
    """Get all brand sightings for a specific source video."""
    from admin.mpr.media_assets.models import SourceBrandSighting
    return list(
        SourceBrandSighting.objects
        .filter(source_asset_id=source_asset_id)
        .order_by("-occurrences")  # most-seen brands first
    )
def record_sighting(source_asset_id: UUID, brand_id: UUID, brand_name: str,
                    timestamp: float, confidence: float, source: str = "ocr"):
    """Record or update a brand sighting for a source.

    On repeat sightings: increments the count, advances last_seen_timestamp,
    and folds *confidence* into a running average.
    NOTE(review): the read-modify-write is not atomic — concurrent calls for
    the same (source, brand) pair can lose updates; confirm callers are
    single-threaded per source.
    """
    from admin.mpr.media_assets.models import SourceBrandSighting
    import uuid
    sighting = SourceBrandSighting.objects.filter(
        source_asset_id=source_asset_id,
        brand_id=brand_id,
    ).first()
    if sighting:
        sighting.occurrences += 1
        sighting.last_seen_timestamp = timestamp
        # Incremental mean: recover the previous sum (avg * old_count),
        # add the new sample, divide by the new count.
        total_conf = sighting.avg_confidence * (sighting.occurrences - 1) + confidence
        sighting.avg_confidence = total_conf / sighting.occurrences
        sighting.save()
        return sighting
    sighting = SourceBrandSighting.objects.create(
        id=uuid.uuid4(),
        source_asset_id=source_asset_id,
        brand_id=brand_id,
        brand_name=brand_name,
        first_seen_timestamp=timestamp,
        last_seen_timestamp=timestamp,
        occurrences=1,
        detection_source=source,
        avg_confidence=confidence,
    )
    return sighting

View File

@@ -26,13 +26,18 @@ from .grpc import (
WorkerStatus, WorkerStatus,
) )
from .jobs import ChunkJob, ChunkJobStatus, JobStatus, TranscodeJob from .jobs import ChunkJob, ChunkJobStatus, JobStatus, TranscodeJob
from .detect_jobs import (
DetectJob, DetectJobStatus, RunType, StageCheckpoint,
BrandSource, KnownBrand, SourceBrandSighting,
)
from .media import AssetStatus, MediaAsset from .media import AssetStatus, MediaAsset
from .presets import BUILTIN_PRESETS, TranscodePreset from .presets import BUILTIN_PRESETS, TranscodePreset
from .detect import DETECT_VIEWS # noqa: F401 — discovered by modelgen generic loader from .detect import DETECT_VIEWS # noqa: F401 — discovered by modelgen generic loader
from .views import ChunkEvent, ChunkOutputFile, PipelineStats, WorkerEvent from .views import ChunkEvent, ChunkOutputFile, PipelineStats, WorkerEvent
# Core domain models - generates Django, Pydantic, TypeScript # Core domain models - generates Django, Pydantic, TypeScript
DATACLASSES = [MediaAsset, TranscodePreset, TranscodeJob, ChunkJob] DATACLASSES = [MediaAsset, TranscodePreset, TranscodeJob, ChunkJob,
DetectJob, StageCheckpoint, KnownBrand, SourceBrandSighting]
# API request/response models - generates TypeScript only (no Django) # API request/response models - generates TypeScript only (no Django)
# WorkerStatus from grpc.py is reused here # WorkerStatus from grpc.py is reused here
@@ -46,7 +51,7 @@ API_MODELS = [
] ]
# Status enums - included in generated code # Status enums - included in generated code
ENUMS = [AssetStatus, JobStatus, ChunkJobStatus] ENUMS = [AssetStatus, JobStatus, ChunkJobStatus, DetectJobStatus, RunType, BrandSource]
# View/event models - generates TypeScript for UI consumption # View/event models - generates TypeScript for UI consumption
VIEWS = [ChunkEvent, WorkerEvent, PipelineStats, ChunkOutputFile] VIEWS = [ChunkEvent, WorkerEvent, PipelineStats, ChunkOutputFile]

View File

@@ -149,6 +149,64 @@ class JobComplete:
report: Optional[DetectionReportSummary] = None report: Optional[DetectionReportSummary] = None
@dataclass
class RunContext:
    """Run context injected into all SSE events for grouping."""
    run_id: str            # short id for this specific run
    parent_job_id: str     # groups all runs (initial/replay/retry) of a source
    run_type: str = "initial"  # initial | replay | retry


# --- Checkpoint API types ---
@dataclass
class CheckpointInfo:
    """Available checkpoint for a stage."""
    stage: str


@dataclass
class ReplayRequest:
    """Request to replay pipeline from a specific stage."""
    job_id: str
    start_stage: str
    config_overrides: Optional[dict] = None  # per-stage config patches


@dataclass
class ReplayResponse:
    """Result of a replay invocation."""
    status: str
    job_id: str
    start_stage: str
    detections: int = 0
    brands_found: int = 0


@dataclass
class RetryRequest:
    """Request to queue async retry with different config."""
    job_id: str
    config_overrides: Optional[dict] = None
    start_stage: str = "escalate_vlm"
    schedule_seconds: Optional[float] = None  # delay before execution (off-peak)


@dataclass
class RetryResponse:
    """Result of queueing a retry task."""
    status: str
    task_id: str
    job_id: str
# --- Export lists for modelgen --- # --- Export lists for modelgen ---
DETECT_VIEWS = [ DETECT_VIEWS = [
@@ -163,4 +221,10 @@ DETECT_VIEWS = [
LogEvent, LogEvent,
DetectionReportSummary, DetectionReportSummary,
JobComplete, JobComplete,
RunContext,
CheckpointInfo,
ReplayRequest,
ReplayResponse,
RetryRequest,
RetryResponse,
] ]

View File

@@ -0,0 +1,162 @@
"""
Detection Job and Checkpoint Schema Definitions
Source of truth for detection pipeline job tracking and stage checkpoints.
Follows the TranscodeJob/ChunkJob pattern.
"""
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID
class DetectJobStatus(str, Enum):
    """Lifecycle states of a DetectJob."""
    PENDING = "pending"
    RUNNING = "running"
    PAUSED = "paused"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class RunType(str, Enum):
    """How a run was started (see DetectJob.run_type)."""
    INITIAL = "initial"  # first run for a source
    REPLAY = "replay"    # re-run from a checkpoint
    RETRY = "retry"      # queued async re-run
@dataclass
class DetectJob:
    """
    A detection pipeline job.
    Each invocation of the pipeline (initial run, replay, retry) creates a DetectJob.
    Jobs for the same source video are linked via parent_job_id.
    """
    id: UUID
    # Input
    source_asset_id: UUID
    video_path: str
    profile_name: str = "soccer_broadcast"
    # Run lineage
    parent_job_id: Optional[UUID] = None  # links all runs for the same source
    run_type: RunType = RunType.INITIAL
    replay_from_stage: Optional[str] = None  # null for initial runs
    config_overrides: Dict[str, Any] = field(default_factory=dict)
    # Status
    status: DetectJobStatus = DetectJobStatus.PENDING
    current_stage: Optional[str] = None
    progress: float = 0.0  # presumably a 0.0–1.0 fraction — confirm with producer
    error_message: Optional[str] = None
    # Results summary
    total_detections: int = 0
    brands_found: int = 0
    cloud_llm_calls: int = 0
    estimated_cost_usd: float = 0.0
    # Worker tracking
    celery_task_id: Optional[str] = None
    priority: int = 0
    # Timestamps
    created_at: Optional[datetime] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
@dataclass
class StageCheckpoint:
    """
    A checkpoint saved after a pipeline stage completes.
    Binary data (frame images, crops) goes to S3/MinIO.
    Everything else (structured state) lives here in Postgres.
    """
    id: UUID
    job_id: UUID
    stage: str
    stage_index: int  # position in NODES list (0-7)
    # S3 reference for binary data only
    frames_prefix: str = ""  # s3 prefix: checkpoints/{job_id}/frames/
    # Frame metadata (non-image fields)
    # NOTE(review): the serializer writes the manifest with *str* keys (JSON
    # object keys) — loaders must coerce back to int.
    frames_manifest: Dict[int, str] = field(default_factory=dict)  # seq → s3 key
    frames_meta: List[Dict[str, Any]] = field(default_factory=list)  # sequence, chunk_id, timestamp, hash
    filtered_frame_sequences: List[int] = field(default_factory=list)
    # Detection state (full structured data, not just summaries)
    boxes_by_frame: Dict[str, List[Dict[str, Any]]] = field(default_factory=dict)
    text_candidates: List[Dict[str, Any]] = field(default_factory=list)
    unresolved_candidates: List[Dict[str, Any]] = field(default_factory=list)
    detections: List[Dict[str, Any]] = field(default_factory=list)
    # Pipeline state
    stats: Dict[str, Any] = field(default_factory=dict)
    config_snapshot: Dict[str, Any] = field(default_factory=dict)
    config_overrides: Dict[str, Any] = field(default_factory=dict)
    # Input refs (for replay)
    video_path: str = ""
    profile_name: str = ""
    # Timestamps
    created_at: Optional[datetime] = None
class BrandSource(str, Enum):
    """How a brand was first identified."""
    OCR = "ocr"          # recognized from frame text
    VLM = "local_vlm"    # local vision-language model
    CLOUD = "cloud_llm"  # cloud LLM escalation
    MANUAL = "manual"    # user-added via UI
@dataclass
class KnownBrand:
    """
    A brand discovered or registered in the system.
    Global — not per-source. Accumulates across all pipeline runs.
    Aliases enable fuzzy matching without re-escalating to VLM.
    """
    id: UUID
    canonical_name: str  # normalized display name
    aliases: List[str] = field(default_factory=list)  # known spellings/variants
    first_source: BrandSource = BrandSource.OCR  # how it was first identified
    total_occurrences: int = 0
    confirmed: bool = False  # manually confirmed by user
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
@dataclass
class SourceBrandSighting:
    """
    A brand seen in a specific source (video/asset).
    Per-source session cache — avoids re-escalating the same brand
    on subsequent frames or re-runs of the same source.
    """
    id: UUID
    source_asset_id: UUID  # the video this sighting belongs to
    brand_id: UUID  # FK to KnownBrand
    brand_name: str  # denormalized for fast lookup
    first_seen_timestamp: float = 0.0  # seconds into the video
    last_seen_timestamp: float = 0.0
    occurrences: int = 0
    detection_source: BrandSource = BrandSource.OCR
    avg_confidence: float = 0.0  # running mean, maintained by record_sighting
    created_at: Optional[datetime] = None

View File

@@ -28,7 +28,10 @@ GRPC_PORT=50051
GRPC_MAX_WORKERS=10 GRPC_MAX_WORKERS=10
# S3 Storage (MinIO locally, real S3 on AWS) # S3 Storage (MinIO locally, real S3 on AWS)
S3_ENDPOINT_URL=http://minio:9000 # In k8s/docker: http://minio:9000
# On dev machine (port-forward): http://localhost:9000
# On AWS: omit S3_ENDPOINT_URL entirely
S3_ENDPOINT_URL=http://localhost:9000
S3_BUCKET_IN=mpr-media-in S3_BUCKET_IN=mpr-media-in
S3_BUCKET_OUT=mpr-media-out S3_BUCKET_OUT=mpr-media-out
AWS_REGION=us-east-1 AWS_REGION=us-east-1
@@ -44,7 +47,7 @@ CLOUD_LLM_PROVIDER=groq
# Groq (default, free tier) # Groq (default, free tier)
GROQ_API_KEY= GROQ_API_KEY=
GROQ_MODEL=llama-3.2-90b-vision-preview GROQ_MODEL=meta-llama/llama-4-scout-17b-16e-instruct
# Gemini # Gemini
#GEMINI_API_KEY= #GEMINI_API_KEY=

View File

@@ -35,7 +35,15 @@ docker_build(
# --- Resources --- # --- Resources ---
k8s_resource('redis') k8s_resource('redis')
k8s_resource('fastapi', resource_deps=['redis']) k8s_resource('minio', port_forwards=['9000:9000', '9001:9001'])
k8s_resource('fastapi', resource_deps=['redis', 'minio'])
k8s_resource('detection-ui') k8s_resource('detection-ui')
k8s_resource('gateway', resource_deps=['fastapi', 'detection-ui'], k8s_resource('gateway', resource_deps=['fastapi', 'detection-ui'],
port_forwards=['8080:8080']) port_forwards=['8080:8080'])
# Group uncategorized resources (configmaps, namespace) under infra
k8s_resource(
objects=['mpr:namespace', 'mpr-config:configmap', 'minio-config:configmap',
'envoy-gateway-config:configmap'],
new_name='infra',
)

View File

@@ -22,6 +22,8 @@ spec:
envFrom: envFrom:
- configMapRef: - configMapRef:
name: mpr-config name: mpr-config
- configMapRef:
name: minio-config
readinessProbe: readinessProbe:
httpGet: httpGet:
path: /health path: /health

View File

@@ -7,6 +7,7 @@ resources:
- namespace.yaml - namespace.yaml
- configmap.yaml - configmap.yaml
- redis.yaml - redis.yaml
- minio.yaml
- fastapi.yaml - fastapi.yaml
- detection-ui.yaml - detection-ui.yaml
- gateway.yaml - gateway.yaml

87
ctrl/k8s/base/minio.yaml Normal file
View File

@@ -0,0 +1,87 @@
# Local-dev MinIO: S3-compatible object store backing checkpoint storage.
apiVersion: v1
kind: ConfigMap
metadata:
  name: minio-config
  namespace: mpr
data:
  S3_ENDPOINT_URL: http://minio:9000
  S3_BUCKET_IN: mpr-media-in
  S3_BUCKET_OUT: mpr-media-out
  # NOTE(review): default minioadmin credentials in a ConfigMap — fine for
  # local dev only; move to a Secret before any shared environment.
  AWS_ACCESS_KEY_ID: minioadmin
  AWS_SECRET_ACCESS_KEY: minioadmin
  AWS_REGION: us-east-1
---
apiVersion: apps/v1
kind: Deployment
metadata:
  name: minio
  namespace: mpr
spec:
  replicas: 1
  selector:
    matchLabels:
      app: minio
  template:
    metadata:
      labels:
        app: minio
    spec:
      containers:
        - name: minio
          # NOTE(review): :latest tag — pin a release for reproducible envs.
          image: minio/minio:latest
          args: ["server", "/data", "--console-address", ":9001"]
          ports:
            - containerPort: 9000
              name: api
            - containerPort: 9001
              name: console
          env:
            - name: MINIO_ROOT_USER
              value: minioadmin
            - name: MINIO_ROOT_PASSWORD
              value: minioadmin
          readinessProbe:
            httpGet:
              path: /minio/health/ready
              port: 9000
            initialDelaySeconds: 5
            periodSeconds: 10
          lifecycle:
            # Pre-create buckets: MinIO exposes top-level dirs under /data
            # as buckets, so mkdir is enough for dev bootstrap.
            postStart:
              exec:
                command:
                  - /bin/sh
                  - -c
                  - |
                    sleep 3
                    for bucket in mpr-media-in mpr-media-out; do
                      mkdir -p /data/$bucket
                    done
          volumeMounts:
            - name: data
              mountPath: /data
          resources:
            requests:
              memory: 128Mi
              cpu: 100m
            limits:
              memory: 512Mi
      volumes:
        - name: data
          # emptyDir: stored objects are lost on pod restart (dev only).
          emptyDir: {}
---
apiVersion: v1
kind: Service
metadata:
  name: minio
  namespace: mpr
spec:
  selector:
    app: minio
  ports:
    - port: 9000
      targetPort: 9000
      name: api
    - port: 9001
      targetPort: 9001
      name: console

View File

@@ -0,0 +1,14 @@
"""
Stage checkpoint, replay, and retry.
detect/checkpoint/
frames.py — frame image S3 upload/download
serializer.py — state ↔ JSON conversion
storage.py — checkpoint save/load/list (Postgres + S3)
replay.py — replay_from, OverrideProfile
tasks.py — retry_candidates Celery task
"""
from .storage import save_checkpoint, load_checkpoint, list_checkpoints
from .frames import save_frames, load_frames
from .replay import replay_from, OverrideProfile

View File

@@ -0,0 +1,80 @@
"""Frame image storage — save/load to S3/MinIO as JPEGs."""
from __future__ import annotations
import logging
import os
import tempfile
import numpy as np
from PIL import Image
from detect.models import Frame
logger = logging.getLogger(__name__)
BUCKET = os.environ.get("S3_BUCKET_OUT", "mpr-media-out")
CHECKPOINT_PREFIX = "checkpoints"
def save_frames(job_id: str, frames: list[Frame]) -> dict[int, str]:
    """
    Save frame images to S3 as JPEGs.
    Returns manifest: {sequence: s3_key}
    """
    from core.storage.s3 import upload_file
    keys_by_sequence: dict[int, str] = {}
    for frame in frames:
        key = f"{CHECKPOINT_PREFIX}/{job_id}/frames/{frame.sequence}.jpg"
        # Encode to a temp JPEG on disk, upload, then always clean up.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
            Image.fromarray(frame.image).save(tmp, format="JPEG", quality=85)
            tmp_path = tmp.name
        try:
            upload_file(tmp_path, BUCKET, key)
        finally:
            os.unlink(tmp_path)
        keys_by_sequence[frame.sequence] = key
    logger.info("Saved %d frames to s3://%s/%s/%s/frames/",
                len(frames), BUCKET, CHECKPOINT_PREFIX, job_id)
    return keys_by_sequence
def load_frames(manifest: dict[int, str], frame_metadata: list[dict]) -> list[Frame]:
    """
    Load frame images from S3 and reconstitute Frame objects.
    frame_metadata: list of dicts with sequence, chunk_id, timestamp, perceptual_hash.

    Fix: manifest keys are coerced to int. The serializer stores the manifest
    with string keys (JSON object keys), so after a round-trip the sequences
    arrive as str — which previously missed every meta_map lookup and set
    str-typed Frame.sequence values.
    """
    from core.storage.s3 import download_to_temp
    meta_map = {m["sequence"]: m for m in frame_metadata}
    frames = []
    for seq, key in manifest.items():
        seq = int(seq)  # tolerate str keys from JSON-serialized checkpoints
        tmp_path = download_to_temp(BUCKET, key)
        try:
            img = Image.open(tmp_path).convert("RGB")
            image_array = np.array(img)
        finally:
            os.unlink(tmp_path)
        meta = meta_map.get(seq, {})
        frame = Frame(
            sequence=seq,
            chunk_id=meta.get("chunk_id", 0),
            timestamp=meta.get("timestamp", 0.0),
            image=image_array,
            perceptual_hash=meta.get("perceptual_hash", ""),
        )
        frames.append(frame)
    frames.sort(key=lambda f: f.sequence)
    return frames

132
detect/checkpoint/replay.py Normal file
View File

@@ -0,0 +1,132 @@
"""
Pipeline replay — re-run from any stage with different config.
Loads a checkpoint, applies config overrides, builds a subgraph
starting from the target stage, and invokes it.
"""
from __future__ import annotations
import logging
import uuid
from detect import emit
from detect.checkpoint import load_checkpoint, list_checkpoints
from detect.graph import NODES, build_graph
logger = logging.getLogger(__name__)
class OverrideProfile:
    """
    Wraps a ContentTypeProfile and patches config methods with overrides.
    Override dict structure:
    {
        "frame_extraction": {"fps": 1.0},
        "scene_filter": {"hamming_threshold": 12},
        "detection": {"confidence_threshold": 0.5},
        "ocr": {"languages": ["en", "es"], "min_confidence": 0.3},
        "resolver": {"fuzzy_threshold": 60},
    }
    """
    def __init__(self, base, overrides: dict):
        # base: the wrapped profile; overrides: {section: {attr: value}}
        self._base = base
        self._overrides = overrides

    def __getattr__(self, name):
        # Only invoked when normal lookup fails, so the explicit methods
        # below always take precedence over this delegation.
        return getattr(self._base, name)

    def _patch(self, config, key: str):
        # Apply the override section `key` onto a freshly built config.
        # Unknown attribute names are silently skipped (hasattr guard).
        patches = self._overrides.get(key, {})
        for k, v in patches.items():
            if hasattr(config, k):
                setattr(config, k, v)
        return config

    def frame_extraction_config(self):
        return self._patch(self._base.frame_extraction_config(), "frame_extraction")

    def scene_filter_config(self):
        return self._patch(self._base.scene_filter_config(), "scene_filter")

    def detection_config(self):
        return self._patch(self._base.detection_config(), "detection")

    def ocr_config(self):
        return self._patch(self._base.ocr_config(), "ocr")

    def resolver_config(self):
        return self._patch(self._base.resolver_config(), "resolver")

    # Explicit pass-throughs for methods that take arguments (clearer than
    # relying on __getattr__ delegation).
    def vlm_prompt(self, crop_context):
        return self._base.vlm_prompt(crop_context)

    def aggregate(self, detections):
        return self._base.aggregate(detections)

    def auxiliary_detections(self, source):
        return self._base.auxiliary_detections(source)
def replay_from(
    job_id: str,
    start_stage: str,
    config_overrides: dict | None = None,
    checkpoint: bool = True,
) -> dict:
    """
    Replay the pipeline from a specific stage.
    Loads the checkpoint from the stage immediately before start_stage,
    applies config overrides, and runs the subgraph from start_stage onward.
    Returns the final state dict.

    Raises:
        ValueError: unknown stage, first stage, or missing prior checkpoint.
    """
    if start_stage not in NODES:
        raise ValueError(f"Unknown stage: {start_stage!r}. Options: {NODES}")
    start_idx = NODES.index(start_stage)
    # Load checkpoint from the stage before start_stage
    if start_idx == 0:
        raise ValueError("Cannot replay from the first stage — just run the full pipeline")
    previous_stage = NODES[start_idx - 1]
    available = list_checkpoints(job_id)
    if previous_stage not in available:
        raise ValueError(
            f"No checkpoint for stage {previous_stage!r} (job {job_id}). "
            f"Available: {available}"
        )
    logger.info("Replaying job %s from %s (loading checkpoint: %s)",
                job_id, start_stage, previous_stage)
    state = load_checkpoint(job_id, previous_stage)
    # Apply config overrides
    if config_overrides:
        state["config_overrides"] = config_overrides
    # Set run context for SSE events (short 8-char run id)
    run_id = str(uuid.uuid4())[:8]
    emit.set_run_context(
        run_id=run_id,
        parent_job_id=job_id,
        run_type="replay",
    )
    # Build subgraph starting from start_stage
    graph = build_graph(checkpoint=checkpoint, start_from=start_stage)
    pipeline = graph.compile()
    try:
        result = pipeline.invoke(state)
    finally:
        # Always reset so later runs don't inherit this replay's context.
        emit.clear_run_context()
    return result

View File

@@ -0,0 +1,133 @@
"""State serialization — DetectState ↔ JSON-compatible dict."""
from __future__ import annotations
import dataclasses
from detect.models import (
BoundingBox,
BrandDetection,
Frame,
PipelineStats,
TextCandidate,
)
# ---------------------------------------------------------------------------
# Serialize helpers
# ---------------------------------------------------------------------------
def serialize_frame_meta(frame: Frame) -> dict:
    """Extract the JSON-safe metadata fields of a Frame (image excluded)."""
    return {
        "sequence": frame.sequence,
        "chunk_id": frame.chunk_id,
        "timestamp": frame.timestamp,
        "perceptual_hash": frame.perceptual_hash,
    }
def serialize_text_candidate(tc: TextCandidate) -> dict:
    """Flatten a TextCandidate: the Frame reference becomes its sequence number."""
    return {
        "frame_sequence": tc.frame.sequence,
        "bbox": dataclasses.asdict(tc.bbox),
        "text": tc.text,
        "ocr_confidence": tc.ocr_confidence,
    }
def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
    """
    Serialize DetectState to a JSON-compatible dict.
    Frame images are replaced with S3 key references.
    TextCandidate.frame references become frame_sequence integers.
    """
    frames = state.get("frames", [])
    filtered = state.get("filtered_frames", [])
    # JSON object keys must be strings — the loader is responsible for
    # coercing these back to int.
    manifest_strs = {str(k): v for k, v in frames_manifest.items()}
    frames_meta = [serialize_frame_meta(f) for f in frames]
    filtered_seqs = [f.sequence for f in filtered]
    boxes_serialized = {}
    for seq, boxes in state.get("boxes_by_frame", {}).items():
        boxes_serialized[str(seq)] = [dataclasses.asdict(b) for b in boxes]
    text_candidates = [serialize_text_candidate(tc) for tc in state.get("text_candidates", [])]
    unresolved = [serialize_text_candidate(tc) for tc in state.get("unresolved_candidates", [])]
    detections = [dataclasses.asdict(d) for d in state.get("detections", [])]
    # Default PipelineStats() serializes to all-zero stats when absent.
    stats = dataclasses.asdict(state.get("stats", PipelineStats()))
    checkpoint = {
        "job_id": state.get("job_id", ""),
        "video_path": state.get("video_path", ""),
        "profile_name": state.get("profile_name", ""),
        "config_overrides": state.get("config_overrides", {}),
        "frames_manifest": manifest_strs,
        "frames_meta": frames_meta,
        "filtered_frame_sequences": filtered_seqs,
        "boxes_by_frame": boxes_serialized,
        "text_candidates": text_candidates,
        "unresolved_candidates": unresolved,
        "detections": detections,
        "stats": stats,
    }
    return checkpoint
# ---------------------------------------------------------------------------
# Deserialize helpers
# ---------------------------------------------------------------------------
def deserialize_text_candidate(d: dict, frame_map: dict[int, Frame]) -> TextCandidate:
    """Rebuild a TextCandidate, resolving frame_sequence through frame_map."""
    return TextCandidate(
        frame=frame_map[d["frame_sequence"]],
        bbox=BoundingBox(**d["bbox"]),
        text=d["text"],
        ocr_confidence=d["ocr_confidence"],
    )
def deserialize_state(checkpoint: dict, frames: list[Frame]) -> dict:
    """Reconstitute DetectState from a checkpoint dict + loaded frames.

    NOTE(review): assumes every frame_sequence referenced by a candidate is
    present in *frames* — a missing frame raises KeyError inside
    deserialize_text_candidate.
    """
    frame_map = {f.sequence: f for f in frames}
    filtered_seqs = set(checkpoint.get("filtered_frame_sequences", []))
    filtered_frames = [f for f in frames if f.sequence in filtered_seqs]
    boxes_by_frame = {}
    # Keys were stringified for JSON — restore the int sequence keys.
    for seq_str, box_dicts in checkpoint.get("boxes_by_frame", {}).items():
        seq = int(seq_str)
        boxes_by_frame[seq] = [BoundingBox(**b) for b in box_dicts]
    text_candidates = [
        deserialize_text_candidate(d, frame_map)
        for d in checkpoint.get("text_candidates", [])
    ]
    unresolved_candidates = [
        deserialize_text_candidate(d, frame_map)
        for d in checkpoint.get("unresolved_candidates", [])
    ]
    detections = [BrandDetection(**d) for d in checkpoint.get("detections", [])]
    stats = PipelineStats(**checkpoint.get("stats", {}))
    state = {
        "job_id": checkpoint.get("job_id", ""),
        "video_path": checkpoint.get("video_path", ""),
        "profile_name": checkpoint.get("profile_name", ""),
        "config_overrides": checkpoint.get("config_overrides", {}),
        "frames": frames,
        "filtered_frames": filtered_frames,
        "boxes_by_frame": boxes_by_frame,
        "text_candidates": text_candidates,
        "unresolved_candidates": unresolved_candidates,
        "detections": detections,
        "stats": stats,
    }
    return state

View File

@@ -0,0 +1,215 @@
"""
Checkpoint storage — save/load stage state.
Binary data (frame images) → S3/MinIO via frames.py
Structured data (boxes, detections, stats, config) → Postgres via Django ORM
Until the Django model is generated by modelgen, checkpoint data is stored
as JSON in S3 as a fallback. Once DetectJob/StageCheckpoint models exist,
this module switches to Postgres.
"""
from __future__ import annotations
import json
import logging
import os
import tempfile
from pathlib import Path
from .frames import save_frames, load_frames, BUCKET, CHECKPOINT_PREFIX
from .serializer import serialize_state, deserialize_state
logger = logging.getLogger(__name__)
def _has_db() -> bool:
"""Check if the DB layer is available (Django + models generated by modelgen)."""
try:
from core.db.detect import get_stage_checkpoint as _
# Quick check that the model exists (modelgen may not have run yet)
from admin.mpr.media_assets.models import StageCheckpoint as _
return True
except (ImportError, Exception):
return False
# ---------------------------------------------------------------------------
# Save
# ---------------------------------------------------------------------------
def save_checkpoint(
    job_id: str,
    stage: str,
    stage_index: int,
    state: dict,
    frames_manifest: dict[int, str] | None = None,
) -> str:
    """
    Save a stage checkpoint.
    Saves frame images to S3 (if not already saved), then persists
    structured state to Postgres (or S3 JSON fallback).
    Returns the checkpoint identifier (DB id or S3 key).

    frames_manifest: pass the manifest from an earlier save to skip
    re-uploading the frame JPEGs.
    """
    # Save frames to S3 if no manifest provided
    if frames_manifest is None:
        all_frames = state.get("frames", [])
        frames_manifest = save_frames(job_id, all_frames)
    checkpoint_data = serialize_state(state, frames_manifest)
    # Postgres once modelgen has produced the models, S3 JSON otherwise.
    if _has_db():
        checkpoint_id = _save_to_db(job_id, stage, stage_index, checkpoint_data)
    else:
        checkpoint_id = _save_to_s3(job_id, stage, checkpoint_data)
    return checkpoint_id
def _save_to_db(job_id: str, stage: str, stage_index: int, data: dict) -> str:
    """Save checkpoint structured data to Postgres via core/db.

    Returns the new StageCheckpoint row id as a string.
    """
    import uuid
    from core.db.detect import save_stage_checkpoint
    # Accept either str or UUID job ids.
    job_uuid = uuid.UUID(job_id) if isinstance(job_id, str) else job_id
    checkpoint_id = uuid.uuid4()
    frames_prefix = f"{CHECKPOINT_PREFIX}/{job_id}/frames/"
    checkpoint = save_stage_checkpoint(
        id=checkpoint_id,
        job_id=job_uuid,
        stage=stage,
        stage_index=stage_index,
        frames_prefix=frames_prefix,
        frames_manifest=data.get("frames_manifest", {}),
        frames_meta=data.get("frames_meta", []),
        filtered_frame_sequences=data.get("filtered_frame_sequences", []),
        boxes_by_frame=data.get("boxes_by_frame", {}),
        text_candidates=data.get("text_candidates", []),
        unresolved_candidates=data.get("unresolved_candidates", []),
        detections=data.get("detections", []),
        stats=data.get("stats", {}),
        # NOTE(review): config_snapshot is filled with the overrides dict, not
        # a full resolved-config snapshot — confirm whether that is intended.
        config_snapshot=data.get("config_overrides", {}),
        config_overrides=data.get("config_overrides", {}),
        video_path=data.get("video_path", ""),
        profile_name=data.get("profile_name", ""),
    )
    logger.info("Checkpoint saved to DB: %s/%s (id=%s)", job_id, stage, checkpoint.id)
    return str(checkpoint.id)
def _save_to_s3(job_id: str, stage: str, data: dict) -> str:
    """Fallback: persist the checkpoint as a JSON object in S3.

    Used before modelgen generates the DB models. The JSON is staged
    through a temp file because the uploader takes a filesystem path;
    the temp file is always removed afterwards.
    """
    from core.storage.s3 import upload_file

    key = f"{CHECKPOINT_PREFIX}/{job_id}/stages/{stage}.json"
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as staging:
        staging.write(json.dumps(data, default=str))
        staging_path = staging.name
    try:
        upload_file(staging_path, BUCKET, key)
    finally:
        os.unlink(staging_path)
    logger.info("Checkpoint saved to S3: s3://%s/%s", BUCKET, key)
    return key
# ---------------------------------------------------------------------------
# Load
# ---------------------------------------------------------------------------
def load_checkpoint(job_id: str, stage: str) -> dict:
    """
    Load a stage checkpoint and reconstitute the full DetectState.

    Structured data comes from Postgres when the DB layer is available,
    otherwise from the S3 JSON fallback; frame images are always
    rehydrated from S3.
    """
    data = _load_from_db(job_id, stage) if _has_db() else _load_from_s3(job_id, stage)

    # JSON round-trips turn int frame-sequence keys into strings — restore them.
    manifest = {int(seq): ref for seq, ref in data.get("frames_manifest", {}).items()}
    frames = load_frames(manifest, data.get("frames_meta", []))

    state = deserialize_state(data, frames)
    logger.info("Checkpoint loaded: %s/%s (%d frames)", job_id, stage, len(frames))
    return state
def _load_from_db(job_id: str, stage: str) -> dict:
    """Load checkpoint structured data from Postgres via core/db."""
    from core.db.detect import get_stage_checkpoint

    checkpoint = get_stage_checkpoint(job_id, stage)

    # job_id is a UUID column — normalize to string; the remaining
    # columns are copied into the state dict verbatim, in this order.
    data: dict = {"job_id": str(checkpoint.job_id)}
    for column in (
        "video_path",
        "profile_name",
        "config_overrides",
        "frames_manifest",
        "frames_meta",
        "filtered_frame_sequences",
        "boxes_by_frame",
        "text_candidates",
        "unresolved_candidates",
        "detections",
        "stats",
    ):
        data[column] = getattr(checkpoint, column)
    return data
def _load_from_s3(job_id: str, stage: str) -> dict:
    """Fallback: fetch and parse the checkpoint JSON from S3.

    The downloaded temp file is always removed, even on parse failure.
    """
    from core.storage.s3 import download_to_temp

    local_path = download_to_temp(
        BUCKET, f"{CHECKPOINT_PREFIX}/{job_id}/stages/{stage}.json"
    )
    try:
        with open(local_path) as fh:
            return json.load(fh)
    finally:
        os.unlink(local_path)
# ---------------------------------------------------------------------------
# List
# ---------------------------------------------------------------------------
def list_checkpoints(job_id: str) -> list[str]:
    """List available checkpoint stages for a job (DB first, S3 fallback)."""
    return _list_from_db(job_id) if _has_db() else _list_from_s3(job_id)
def _list_from_db(job_id: str) -> list[str]:
    """Return the checkpoint stage names recorded in Postgres for this job."""
    from core.db.detect import list_stage_checkpoints as _list_stages

    return _list_stages(job_id)
def _list_from_s3(job_id: str) -> list[str]:
    """Return the checkpoint stage names stored as JSON objects in S3.

    Stage names are recovered from the object keys:
    ``{CHECKPOINT_PREFIX}/{job_id}/stages/<stage>.json`` → ``<stage>``.
    """
    from core.storage.s3 import list_objects

    prefix = f"{CHECKPOINT_PREFIX}/{job_id}/stages/"
    # Comprehension instead of the manual append loop (same result).
    return [Path(obj["key"]).stem for obj in list_objects(BUCKET, prefix)]

View File

@@ -0,0 +1,71 @@
"""
Celery tasks for detection pipeline async operations.
retry_candidates: re-run VLM/cloud escalation with different config.
"""
from __future__ import annotations
import logging
import uuid
from datetime import datetime, timezone
from celery import shared_task
logger = logging.getLogger(__name__)
def _retry_summary(run_id: str, job_id: str, outcome: dict) -> dict:
    """Log and build the success payload for a completed retry run."""
    found = outcome.get("detections", [])
    report = outcome.get("report")
    brands_found = len(report.brands) if report else 0
    logger.info("Retry %s complete: %d detections, %d brands",
                run_id, len(found), brands_found)
    return {
        "status": "completed",
        "run_id": run_id,
        "job_id": job_id,
        "detections": len(found),
        "brands_found": brands_found,
    }


@shared_task(bind=True, max_retries=1, default_retry_delay=30)
def retry_candidates(
    self,
    job_id: str,
    config_overrides: dict | None = None,
    start_stage: str = "escalate_vlm",
):
    """
    Retry unresolved candidates with different config.

    Loads the checkpoint from the stage before *start_stage*, applies
    config overrides (e.g. different cloud provider), and runs from
    *start_stage* onward. On error the task is retried once via Celery;
    after that it returns a "failed" payload instead of raising.
    """
    from detect.checkpoint.replay import replay_from

    run_id = str(uuid.uuid4())[:8]
    logger.info("Retry task %s: job=%s, from=%s, overrides=%s",
                run_id, job_id, start_stage, config_overrides)
    try:
        outcome = replay_from(
            job_id=job_id,
            start_stage=start_stage,
            config_overrides=config_overrides,
        )
        return _retry_summary(run_id, job_id, outcome)
    except Exception as exc:
        logger.exception("Retry %s failed: %s", run_id, exc)
        if self.request.retries < self.max_retries:
            # Hand back to Celery for one delayed retry.
            raise self.retry(exc=exc)
        return {
            "status": "failed",
            "run_id": run_id,
            "job_id": job_id,
            "error": str(exc),
        }

View File

@@ -3,6 +3,9 @@ Event emission helpers for detection pipeline stages.
Single place that knows how to build event payloads. Single place that knows how to build event payloads.
Stages call these instead of constructing dicts or dataclasses directly. Stages call these instead of constructing dicts or dataclasses directly.
Run context (run_id, parent_job_id) is set once at pipeline start via
set_run_context() and automatically injected into all events.
""" """
from __future__ import annotations from __future__ import annotations
@@ -13,9 +16,33 @@ from datetime import datetime, timezone
from detect.events import push_detect_event from detect.events import push_detect_event
from detect.models import PipelineStats from detect.models import PipelineStats
# Module-level run context — set once per pipeline invocation
_run_context: dict = {}
def set_run_context(run_id: str = "", parent_job_id: str = "", run_type: str = "initial"):
"""Set the run context for all subsequent events in this pipeline invocation."""
global _run_context
_run_context = {
"run_id": run_id,
"parent_job_id": parent_job_id,
"run_type": run_type,
}
def clear_run_context():
global _run_context
_run_context = {}
def _inject_context(payload: dict) -> dict:
"""Add run context fields to an event payload."""
if _run_context:
payload.update(_run_context)
return payload
def log(job_id: str | None, stage: str, level: str, msg: str) -> None: def log(job_id: str | None, stage: str, level: str, msg: str) -> None:
"""Emit a log event."""
if not job_id: if not job_id:
return return
payload = { payload = {
@@ -24,15 +51,17 @@ def log(job_id: str | None, stage: str, level: str, msg: str) -> None:
"msg": msg, "msg": msg,
"ts": datetime.now(timezone.utc).isoformat(), "ts": datetime.now(timezone.utc).isoformat(),
} }
_inject_context(payload)
push_detect_event(job_id, "log", payload) push_detect_event(job_id, "log", payload)
def stats(job_id: str | None, **kwargs) -> None: def stats(job_id: str | None, **kwargs) -> None:
"""Emit a stats_update event. Pass only the fields that changed."""
if not job_id: if not job_id:
return return
s = PipelineStats(**kwargs) s = PipelineStats(**kwargs)
push_detect_event(job_id, "stats_update", dataclasses.asdict(s)) payload = dataclasses.asdict(s)
_inject_context(payload)
push_detect_event(job_id, "stats_update", payload)
def frame_update( def frame_update(
@@ -42,7 +71,6 @@ def frame_update(
jpeg_b64: str, jpeg_b64: str,
boxes: list[dict], boxes: list[dict],
) -> None: ) -> None:
"""Emit a frame_update event with the image and bounding boxes."""
if not job_id: if not job_id:
return return
payload = { payload = {
@@ -51,14 +79,15 @@ def frame_update(
"jpeg_b64": jpeg_b64, "jpeg_b64": jpeg_b64,
"boxes": boxes, "boxes": boxes,
} }
_inject_context(payload)
push_detect_event(job_id, "frame_update", payload) push_detect_event(job_id, "frame_update", payload)
def graph_update(job_id: str | None, nodes: list[dict]) -> None: def graph_update(job_id: str | None, nodes: list[dict]) -> None:
"""Emit a graph_update event with node states."""
if not job_id: if not job_id:
return return
payload = {"nodes": nodes} payload = {"nodes": nodes}
_inject_context(payload)
push_detect_event(job_id, "graph_update", payload) push_detect_event(job_id, "graph_update", payload)
@@ -72,7 +101,6 @@ def detection(
content_type: str = "", content_type: str = "",
frame_ref: int | None = None, frame_ref: int | None = None,
) -> None: ) -> None:
"""Emit a brand detection event."""
if not job_id: if not job_id:
return return
payload = { payload = {
@@ -84,12 +112,13 @@ def detection(
"content_type": content_type, "content_type": content_type,
"frame_ref": frame_ref, "frame_ref": frame_ref,
} }
_inject_context(payload)
push_detect_event(job_id, "detection", payload) push_detect_event(job_id, "detection", payload)
def job_complete(job_id: str | None, report: dict) -> None: def job_complete(job_id: str | None, report: dict) -> None:
"""Emit a job_complete event with the final report."""
if not job_id: if not job_id:
return return
payload = {"job_id": job_id, "report": report} payload = {"job_id": job_id, "report": report}
_inject_context(payload)
push_detect_event(job_id, "job_complete", payload) push_detect_event(job_id, "job_complete", payload)

View File

@@ -42,8 +42,16 @@ NODES = [
def _get_profile(state: DetectState): def _get_profile(state: DetectState):
name = state.get("profile_name", "soccer_broadcast") name = state.get("profile_name", "soccer_broadcast")
if name == "soccer_broadcast": if name == "soccer_broadcast":
return SoccerBroadcastProfile() profile = SoccerBroadcastProfile()
raise ValueError(f"Unknown profile: {name}") else:
raise ValueError(f"Unknown profile: {name}")
overrides = state.get("config_overrides")
if overrides:
from detect.checkpoint.replay import OverrideProfile
profile = OverrideProfile(profile, overrides)
return profile
# Track node states across the pipeline run # Track node states across the pipeline run
@@ -68,6 +76,18 @@ def _emit_transition(state: DetectState, node: str, status: str):
# --- Node functions --- # --- Node functions ---
def node_extract_frames(state: DetectState) -> dict: def node_extract_frames(state: DetectState) -> dict:
# Set run context for initial runs (replays set it in replay_from)
job_id = state.get("job_id", "")
if job_id and not emit._run_context:
emit.set_run_context(run_id=job_id, parent_job_id=job_id, run_type="initial")
# Load session brands from DB for this source
source_asset_id = state.get("source_asset_id")
if source_asset_id and not state.get("session_brands"):
from detect.stages.brand_resolver import build_session_dict
session_brands = build_session_dict(source_asset_id)
state["session_brands"] = session_brands
_emit_transition(state, "extract_frames", "running") _emit_transition(state, "extract_frames", "running")
with trace_node(state, "extract_frames") as span: with trace_node(state, "extract_frames") as span:
@@ -142,13 +162,16 @@ def node_match_brands(state: DetectState) -> dict:
with trace_node(state, "match_brands") as span: with trace_node(state, "match_brands") as span:
profile = _get_profile(state) profile = _get_profile(state)
dictionary = profile.brand_dictionary()
resolver_config = profile.resolver_config() resolver_config = profile.resolver_config()
candidates = state.get("text_candidates", []) candidates = state.get("text_candidates", [])
session_brands = state.get("session_brands", {})
job_id = state.get("job_id") job_id = state.get("job_id")
source_asset_id = state.get("source_asset_id")
matched, unresolved = resolve_brands( matched, unresolved = resolve_brands(
candidates, dictionary, resolver_config, candidates, resolver_config,
session_brands=session_brands,
source_asset_id=source_asset_id,
content_type=profile.name, job_id=job_id, content_type=profile.name, job_id=job_id,
) )
span.set_output({"matched": len(matched), "unresolved": len(unresolved)}) span.set_output({"matched": len(matched), "unresolved": len(unresolved)})
@@ -170,6 +193,7 @@ def node_escalate_vlm(state: DetectState) -> dict:
vlm_prompt_fn=profile.vlm_prompt, vlm_prompt_fn=profile.vlm_prompt,
inference_url=INFERENCE_URL, inference_url=INFERENCE_URL,
content_type=profile.name, content_type=profile.name,
source_asset_id=state.get("source_asset_id"),
job_id=job_id, job_id=job_id,
) )
@@ -202,6 +226,7 @@ def node_escalate_cloud(state: DetectState) -> dict:
vlm_prompt_fn=profile.vlm_prompt, vlm_prompt_fn=profile.vlm_prompt,
stats=stats, stats=stats,
content_type=profile.name, content_type=profile.name,
source_asset_id=state.get("source_asset_id"),
job_id=job_id, job_id=job_id,
) )
@@ -239,33 +264,87 @@ def node_compile_report(state: DetectState) -> dict:
return {"report": report} return {"report": report}
# --- Checkpoint wrapper ---
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
_frames_manifest: dict[str, dict[int, str]] = {} # job_id → manifest (cached per job)
def _checkpointing_node(node_name: str, node_fn):
"""Wrap a node function to auto-checkpoint after completion."""
stage_index = NODES.index(node_name)
def wrapper(state: DetectState) -> dict:
result = node_fn(state)
job_id = state.get("job_id", "")
if not job_id:
return result
from detect.checkpoint import save_checkpoint, save_frames
merged = {**state, **result}
# Save frames once (first checkpoint), reuse manifest after
manifest = _frames_manifest.get(job_id)
if manifest is None and node_name == "extract_frames":
manifest = save_frames(job_id, merged.get("frames", []))
_frames_manifest[job_id] = manifest
save_checkpoint(job_id, node_name, stage_index, merged, frames_manifest=manifest)
return result
wrapper.__name__ = node_fn.__name__
return wrapper
# --- Graph construction --- # --- Graph construction ---
def build_graph() -> StateGraph: NODE_FUNCTIONS = [
("extract_frames", node_extract_frames),
("filter_scenes", node_filter_scenes),
("detect_objects", node_detect_objects),
("run_ocr", node_run_ocr),
("match_brands", node_match_brands),
("escalate_vlm", node_escalate_vlm),
("escalate_cloud", node_escalate_cloud),
("compile_report", node_compile_report),
]
def build_graph(checkpoint: bool | None = None, start_from: str | None = None) -> StateGraph:
"""
Build the pipeline graph.
checkpoint: enable auto-checkpointing (default: MPR_CHECKPOINT env var)
start_from: skip nodes before this stage (for replay)
"""
do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
graph = StateGraph(DetectState) graph = StateGraph(DetectState)
graph.add_node("extract_frames", node_extract_frames) # Filter to start_from if replaying
graph.add_node("filter_scenes", node_filter_scenes) node_pairs = NODE_FUNCTIONS
graph.add_node("detect_objects", node_detect_objects) if start_from:
graph.add_node("run_ocr", node_run_ocr) start_idx = next(i for i, (name, _) in enumerate(NODE_FUNCTIONS) if name == start_from)
graph.add_node("match_brands", node_match_brands) node_pairs = NODE_FUNCTIONS[start_idx:]
graph.add_node("escalate_vlm", node_escalate_vlm)
graph.add_node("escalate_cloud", node_escalate_cloud)
graph.add_node("compile_report", node_compile_report)
graph.set_entry_point("extract_frames") for name, fn in node_pairs:
graph.add_edge("extract_frames", "filter_scenes") wrapped = _checkpointing_node(name, fn) if do_checkpoint else fn
graph.add_edge("filter_scenes", "detect_objects") graph.add_node(name, wrapped)
graph.add_edge("detect_objects", "run_ocr")
graph.add_edge("run_ocr", "match_brands") # Wire edges
graph.add_edge("match_brands", "escalate_vlm") entry = node_pairs[0][0]
graph.add_edge("escalate_vlm", "escalate_cloud") graph.set_entry_point(entry)
graph.add_edge("escalate_cloud", "compile_report")
graph.add_edge("compile_report", END) for i in range(len(node_pairs) - 1):
graph.add_edge(node_pairs[i][0], node_pairs[i + 1][0])
graph.add_edge(node_pairs[-1][0], END)
return graph return graph
def get_pipeline(): def get_pipeline(checkpoint: bool | None = None):
"""Return a compiled, runnable pipeline.""" """Return a compiled, runnable pipeline."""
return build_graph().compile() return build_graph(checkpoint=checkpoint).compile()

View File

@@ -1,6 +1,5 @@
from .base import ( from .base import (
ContentTypeProfile, ContentTypeProfile,
BrandDictionary,
CropContext, CropContext,
DetectionConfig, DetectionConfig,
FrameExtractionConfig, FrameExtractionConfig,
@@ -12,7 +11,6 @@ from .soccer import SoccerBroadcastProfile
__all__ = [ __all__ = [
"ContentTypeProfile", "ContentTypeProfile",
"BrandDictionary",
"CropContext", "CropContext",
"DetectionConfig", "DetectionConfig",
"FrameExtractionConfig", "FrameExtractionConfig",

View File

@@ -44,12 +44,6 @@ class ResolverConfig:
fuzzy_threshold: int = 75 fuzzy_threshold: int = 75
@dataclass
class BrandDictionary:
"""Maps canonical brand name → list of known aliases/spellings."""
brands: dict[str, list[str]] = field(default_factory=dict)
@dataclass @dataclass
class CropContext: class CropContext:
image: bytes image: bytes
@@ -64,7 +58,6 @@ class ContentTypeProfile(Protocol):
def scene_filter_config(self) -> SceneFilterConfig: ... def scene_filter_config(self) -> SceneFilterConfig: ...
def detection_config(self) -> DetectionConfig: ... def detection_config(self) -> DetectionConfig: ...
def ocr_config(self) -> OCRConfig: ... def ocr_config(self) -> OCRConfig: ...
def brand_dictionary(self) -> BrandDictionary: ...
def resolver_config(self) -> ResolverConfig: ... def resolver_config(self) -> ResolverConfig: ...
def vlm_prompt(self, crop_context: CropContext) -> str: ... def vlm_prompt(self, crop_context: CropContext) -> str: ...
def aggregate(self, detections: list[BrandDetection]) -> DetectionReport: ... def aggregate(self, detections: list[BrandDetection]) -> DetectionReport: ...

View File

@@ -5,7 +5,6 @@ from __future__ import annotations
from detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats from detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats
from .base import ( from .base import (
BrandDictionary,
CropContext, CropContext,
DetectionConfig, DetectionConfig,
FrameExtractionConfig, FrameExtractionConfig,
@@ -34,22 +33,6 @@ class SoccerBroadcastProfile:
def ocr_config(self) -> OCRConfig: def ocr_config(self) -> OCRConfig:
return OCRConfig(languages=["en", "es"], min_confidence=0.5) return OCRConfig(languages=["en", "es"], min_confidence=0.5)
def brand_dictionary(self) -> BrandDictionary:
return BrandDictionary(brands={
"Nike": ["nike", "NIKE", "swoosh"],
"Adidas": ["adidas", "ADIDAS", "adi"],
"Puma": ["puma", "PUMA"],
"Emirates": ["emirates", "fly emirates", "EMIRATES"],
"Coca-Cola": ["coca-cola", "coca cola", "coke", "COCA-COLA"],
"Pepsi": ["pepsi", "PEPSI"],
"Mastercard": ["mastercard", "MASTERCARD"],
"Heineken": ["heineken", "HEINEKEN"],
"Santander": ["santander", "SANTANDER"],
"Gazprom": ["gazprom", "GAZPROM"],
"Qatar Airways": ["qatar airways", "QATAR AIRWAYS"],
"Lay's": ["lays", "lay's", "LAYS", "LAY'S"],
})
def resolver_config(self) -> ResolverConfig: def resolver_config(self) -> ResolverConfig:
return ResolverConfig(fuzzy_threshold=75) return ResolverConfig(fuzzy_threshold=75)

View File

@@ -5,7 +5,6 @@ from __future__ import annotations
from detect.models import BrandDetection, DetectionReport from detect.models import BrandDetection, DetectionReport
from .base import ( from .base import (
BrandDictionary,
CropContext, CropContext,
DetectionConfig, DetectionConfig,
FrameExtractionConfig, FrameExtractionConfig,
@@ -30,9 +29,6 @@ class NewsBroadcastProfile:
def ocr_config(self) -> OCRConfig: def ocr_config(self) -> OCRConfig:
raise NotImplementedError raise NotImplementedError
def brand_dictionary(self) -> BrandDictionary:
raise NotImplementedError
def resolver_config(self) -> ResolverConfig: def resolver_config(self) -> ResolverConfig:
raise NotImplementedError raise NotImplementedError
@@ -61,9 +57,6 @@ class AdvertisingProfile:
def ocr_config(self) -> OCRConfig: def ocr_config(self) -> OCRConfig:
raise NotImplementedError raise NotImplementedError
def brand_dictionary(self) -> BrandDictionary:
raise NotImplementedError
def resolver_config(self) -> ResolverConfig: def resolver_config(self) -> ResolverConfig:
raise NotImplementedError raise NotImplementedError
@@ -92,9 +85,6 @@ class TranscriptProfile:
def ocr_config(self) -> OCRConfig: def ocr_config(self) -> OCRConfig:
raise NotImplementedError raise NotImplementedError
def brand_dictionary(self) -> BrandDictionary:
raise NotImplementedError
def resolver_config(self) -> ResolverConfig: def resolver_config(self) -> ResolverConfig:
raise NotImplementedError raise NotImplementedError

View File

@@ -1,9 +1,17 @@
""" """
Stage 5 — Brand Resolver Stage 5 — Brand Resolver (discovery mode)
Matches OCR text against the profile's brand dictionary. Discovery-first brand matching. No static dictionary — all brands live in the DB.
Uses exact matching first, then fuzzy matching (rapidfuzz) as fallback.
Emits detection events for confirmed brands. Flow:
1. Check session sightings first (brands already seen in this source)
2. Check global known brands (accumulated across all runs)
3. Unresolved candidates → escalate to VLM/cloud
4. Confirmed brands get added to DB for future runs
The resolver is an enricher, not a gatekeeper. Every OCR text candidate
passes through — the question is whether we can resolve it cheaply (DB lookup)
or need to escalate (VLM/cloud).
""" """
from __future__ import annotations from __future__ import annotations
@@ -14,99 +22,199 @@ from rapidfuzz import fuzz
from detect import emit from detect import emit
from detect.models import BrandDetection, TextCandidate from detect.models import BrandDetection, TextCandidate
from detect.profiles.base import BrandDictionary, ResolverConfig from detect.profiles.base import ResolverConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _normalize(text: str) -> str: def _normalize(text: str) -> str:
"""Normalize text for matching."""
return text.strip().lower() return text.strip().lower()
def _exact_match(text: str, dictionary: BrandDictionary) -> str | None: def _has_db() -> bool:
"""Try exact match against all aliases.""" try:
from core.db.detect import find_brand_by_text as _
from admin.mpr.media_assets.models import KnownBrand as _
return True
except (ImportError, Exception):
return False
def _match_session(text: str, session_brands: dict[str, str]) -> str | None:
"""
Check against session brands (already seen in this source).
session_brands: {normalized_name: canonical_name, ...}
Includes aliases.
"""
normalized = _normalize(text) normalized = _normalize(text)
for canonical, aliases in dictionary.brands.items(): return session_brands.get(normalized)
if normalized == _normalize(canonical):
return canonical
for alias in aliases:
if normalized == _normalize(alias):
return canonical
return None
def _fuzzy_match(text: str, dictionary: BrandDictionary, threshold: int) -> tuple[str | None, int]: def _match_known(text: str, threshold: int) -> tuple[str | None, str | None]:
"""Try fuzzy match, return (brand, score) or (None, 0).""" """
Check against global known brands in DB.
Returns (canonical_name, brand_id) or (None, None).
"""
if not _has_db():
return None, None
from core.db.detect import find_brand_by_text
brand = find_brand_by_text(text)
if brand:
return brand.canonical_name, str(brand.id)
# Fuzzy match against all known brands
from core.db.detect import list_all_brands
all_brands = list_all_brands()
normalized = _normalize(text) normalized = _normalize(text)
best_brand = None best_brand = None
best_score = 0 best_score = 0
for canonical, aliases in dictionary.brands.items(): for known in all_brands:
all_names = [canonical] + aliases names = [known.canonical_name] + (known.aliases or [])
for name in all_names: for name in names:
score = fuzz.ratio(normalized, _normalize(name)) score = fuzz.ratio(normalized, _normalize(name))
if score > best_score and score >= threshold: if score > best_score and score >= threshold:
best_score = score best_score = score
best_brand = canonical best_brand = known
return best_brand, best_score if best_brand:
return best_brand.canonical_name, str(best_brand.id)
return None, None
def _register_brand(canonical_name: str, source: str) -> str | None:
"""Register a newly discovered brand in the DB. Returns brand_id."""
if not _has_db():
return None
from core.db.detect import get_or_create_brand
brand, created = get_or_create_brand(canonical_name, source=source)
if created:
logger.info("New brand discovered: %s (source=%s)", canonical_name, source)
return str(brand.id)
def _record_sighting(source_asset_id: str | None, brand_id: str,
brand_name: str, timestamp: float,
confidence: float, source: str):
"""Record a brand sighting for this source."""
if not _has_db() or not source_asset_id:
return
from core.db.detect import record_sighting
import uuid
asset_id = uuid.UUID(source_asset_id) if isinstance(source_asset_id, str) else source_asset_id
brand_uuid = uuid.UUID(brand_id) if isinstance(brand_id, str) else brand_id
record_sighting(asset_id, brand_uuid, brand_name, timestamp, confidence, source)
def build_session_dict(source_asset_id: str | None) -> dict[str, str]:
"""
Load session brands from DB for this source.
Returns {normalized_name: canonical_name, ...} including aliases.
"""
if not _has_db() or not source_asset_id:
return {}
from core.db.detect import get_source_sightings
import uuid
asset_id = uuid.UUID(source_asset_id) if isinstance(source_asset_id, str) else source_asset_id
sightings = get_source_sightings(asset_id)
session = {}
for s in sightings:
canonical = s.brand_name
session[_normalize(canonical)] = canonical
# Also load aliases from KnownBrand for each sighted brand
if _has_db():
from core.db.detect import list_all_brands
all_brands = list_all_brands()
sighted_names = {s.brand_name for s in sightings}
for brand in all_brands:
if brand.canonical_name in sighted_names:
for alias in (brand.aliases or []):
session[_normalize(alias)] = brand.canonical_name
return session
def resolve_brands( def resolve_brands(
candidates: list[TextCandidate], candidates: list[TextCandidate],
dictionary: BrandDictionary,
config: ResolverConfig, config: ResolverConfig,
session_brands: dict[str, str] | None = None,
source_asset_id: str | None = None,
content_type: str = "", content_type: str = "",
job_id: str | None = None, job_id: str | None = None,
) -> tuple[list[BrandDetection], list[TextCandidate]]: ) -> tuple[list[BrandDetection], list[TextCandidate]]:
""" """
Match text candidates against the brand dictionary. Match text candidates against known brands (session → global → unresolved).
Returns: session_brands: pre-loaded session dict (from build_session_dict)
- matched: list of BrandDetection for confirmed brands source_asset_id: for recording new sightings in DB
- unresolved: list of TextCandidate that couldn't be matched
""" """
if session_brands is None:
session_brands = {}
emit.log(job_id, "BrandResolver", "INFO", emit.log(job_id, "BrandResolver", "INFO",
f"Matching {len(candidates)} candidates against " f"Resolving {len(candidates)} candidates "
f"{len(dictionary.brands)} brands (fuzzy_threshold={config.fuzzy_threshold})") f"(session={len(session_brands)} brands, fuzzy={config.fuzzy_threshold})")
matched: list[BrandDetection] = [] matched: list[BrandDetection] = []
unresolved: list[TextCandidate] = [] unresolved: list[TextCandidate] = []
exact_count = 0 session_hits = 0
fuzzy_count = 0 known_hits = 0
for candidate in candidates: for candidate in candidates:
# Try exact match first text = candidate.text
brand = _exact_match(candidate.text, dictionary) brand_name = None
source = "ocr" brand_id = None
match_source = "ocr"
if brand: # 1. Check session (cheapest — in-memory dict)
exact_count += 1 brand_name = _match_session(text, session_brands)
if brand_name:
session_hits += 1
else: else:
# Try fuzzy match # 2. Check global known brands (DB query + fuzzy)
brand, score = _fuzzy_match(candidate.text, dictionary, config.fuzzy_threshold) brand_name, brand_id = _match_known(text, config.fuzzy_threshold)
if brand: if brand_name:
fuzzy_count += 1 known_hits += 1
# Add to session for subsequent candidates in this run
session_brands[_normalize(brand_name)] = brand_name
if brand: if brand_name:
detection = BrandDetection( detection = BrandDetection(
brand=brand, brand=brand_name,
timestamp=candidate.frame.timestamp, timestamp=candidate.frame.timestamp,
duration=0.5, duration=0.5,
confidence=candidate.ocr_confidence, confidence=candidate.ocr_confidence,
source=source, source=match_source,
bbox=candidate.bbox, bbox=candidate.bbox,
frame_ref=candidate.frame.sequence, frame_ref=candidate.frame.sequence,
content_type=content_type, content_type=content_type,
) )
matched.append(detection) matched.append(detection)
# Record sighting in DB
if brand_id:
_record_sighting(
source_asset_id, brand_id, brand_name,
candidate.frame.timestamp, candidate.ocr_confidence, match_source,
)
emit.detection( emit.detection(
job_id, job_id,
brand=brand, brand=brand_name,
confidence=candidate.ocr_confidence, confidence=candidate.ocr_confidence,
source=source, source=match_source,
timestamp=candidate.frame.timestamp, timestamp=candidate.frame.timestamp,
content_type=content_type, content_type=content_type,
frame_ref=candidate.frame.sequence, frame_ref=candidate.frame.sequence,
@@ -115,7 +223,7 @@ def resolve_brands(
unresolved.append(candidate) unresolved.append(candidate)
emit.log(job_id, "BrandResolver", "INFO", emit.log(job_id, "BrandResolver", "INFO",
f"Exact: {exact_count}, Fuzzy: {fuzzy_count}, " f"Session: {session_hits}, Known: {known_hits}, "
f"Unresolved: {len(unresolved)} → escalating to VLM") f"Unresolved: {len(unresolved)} → escalating")
return matched, unresolved return matched, unresolved

View File

@@ -27,6 +27,18 @@ logger = logging.getLogger(__name__)
ESTIMATED_TOKENS_PER_CROP = 500 ESTIMATED_TOKENS_PER_CROP = 500
def _register_discovered_brand(brand: str, source_asset_id: str | None,
timestamp: float, confidence: float):
"""Register a cloud-confirmed brand in the DB."""
try:
from detect.stages.brand_resolver import _register_brand, _record_sighting
brand_id = _register_brand(brand, "cloud_llm")
if brand_id and source_asset_id:
_record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, "cloud_llm")
except Exception as e:
logger.debug("Failed to register brand %s: %s", brand, e)
def _encode_crop(crop: np.ndarray) -> str: def _encode_crop(crop: np.ndarray) -> str:
img = Image.fromarray(crop) img = Image.fromarray(crop)
buf = io.BytesIO() buf = io.BytesIO()
@@ -84,6 +96,7 @@ def escalate_cloud(
stats: PipelineStats, stats: PipelineStats,
min_confidence: float = 0.4, min_confidence: float = 0.4,
content_type: str = "", content_type: str = "",
source_asset_id: str | None = None,
job_id: str | None = None, job_id: str | None = None,
) -> list[BrandDetection]: ) -> list[BrandDetection]:
""" """
@@ -158,6 +171,10 @@ def escalate_cloud(
frame_ref=candidate.frame.sequence, frame_ref=candidate.frame.sequence,
) )
# Register newly discovered brand in DB
_register_discovered_brand(brand, source_asset_id,
candidate.frame.timestamp, confidence)
stats.estimated_cloud_cost_usd += total_cost stats.estimated_cloud_cost_usd += total_cost
stats.regions_escalated_to_cloud_llm = len(candidates) stats.regions_escalated_to_cloud_llm = len(candidates)

View File

@@ -19,6 +19,18 @@ from detect.profiles.base import CropContext
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _register_discovered_brand(brand: str, source_asset_id: str | None,
                               timestamp: float, confidence: float, source: str):
    """Best-effort DB registration of a VLM-confirmed brand.

    Registers the brand under the given provenance *source* and, when a
    source asset is known, records the sighting too. Any failure is logged
    at debug level and never propagates to the pipeline.
    """
    try:
        # Local import keeps the brand_resolver dependency on this path only.
        from detect.stages.brand_resolver import _record_sighting, _register_brand

        new_brand_id = _register_brand(brand, source)
        if new_brand_id and source_asset_id:
            _record_sighting(source_asset_id, new_brand_id, brand,
                             timestamp, confidence, source)
    except Exception as e:
        logger.debug("Failed to register brand %s: %s", brand, e)
def _crop_image(candidate: TextCandidate) -> np.ndarray: def _crop_image(candidate: TextCandidate) -> np.ndarray:
frame = candidate.frame frame = candidate.frame
box = candidate.bbox box = candidate.bbox
@@ -36,6 +48,7 @@ def escalate_vlm(
inference_url: str | None = None, inference_url: str | None = None,
min_confidence: float = 0.5, min_confidence: float = 0.5,
content_type: str = "", content_type: str = "",
source_asset_id: str | None = None,
job_id: str | None = None, job_id: str | None = None,
) -> tuple[list[BrandDetection], list[TextCandidate]]: ) -> tuple[list[BrandDetection], list[TextCandidate]]:
""" """
@@ -107,6 +120,10 @@ def escalate_vlm(
frame_ref=candidate.frame.sequence, frame_ref=candidate.frame.sequence,
) )
# Register newly discovered brand in DB
_register_discovered_brand(brand, source_asset_id,
candidate.frame.timestamp, confidence, "local_vlm")
logger.debug("VLM matched: %s (%.2f) — %s", brand, confidence, reasoning) logger.debug("VLM matched: %s (%.2f) — %s", brand, confidence, reasoning)
else: else:
still_unresolved.append(candidate) still_unresolved.append(candidate)

View File

@@ -17,6 +17,7 @@ class DetectState(TypedDict, total=False):
video_path: str video_path: str
job_id: str job_id: str
profile_name: str profile_name: str
source_asset_id: str # UUID of the source MediaAsset
# Stage outputs # Stage outputs
frames: list[Frame] frames: list[Frame]
@@ -27,5 +28,11 @@ class DetectState(TypedDict, total=False):
detections: list[BrandDetection] detections: list[BrandDetection]
report: DetectionReport report: DetectionReport
# Session brands (accumulated during the run, persisted to DB)
session_brands: dict # {normalized_name: canonical_name}
# Running stats (updated by each stage) # Running stats (updated by each stage)
stats: PipelineStats stats: PipelineStats
# Config overrides for replay (applied via OverrideProfile)
config_overrides: dict

View File

@@ -31,8 +31,16 @@ def ts():
return datetime.now(timezone.utc).isoformat() return datetime.now(timezone.utc).isoformat()
# Per-process run context; push() merges these keys into every emitted event.
RUN_CONTEXT = {}


def set_run_context(run_id: str, parent_job_id: str, run_type: str = "initial"):
    """Record the current run's identity so emitted events are tagged with it."""
    RUN_CONTEXT["run_id"] = run_id
    RUN_CONTEXT["parent_job_id"] = parent_job_id
    RUN_CONTEXT["run_type"] = run_type
def push(r, key, event): def push(r, key, event):
event["ts"] = event.get("ts", ts()) event["ts"] = event.get("ts", ts())
event.update(RUN_CONTEXT)
r.rpush(key, json.dumps(event)) r.rpush(key, json.dumps(event))
return event return event
@@ -85,7 +93,11 @@ def main():
r.delete(key) r.delete(key)
delay = args.delay delay = args.delay
run_id = f"{args.job[:8]}-r1"
set_run_context(run_id=run_id, parent_job_id=args.job, run_type="initial")
logger.info("Full escalation pipeline simulation → %s", key) logger.info("Full escalation pipeline simulation → %s", key)
logger.info("Run: %s (parent: %s)", run_id, args.job)
logger.info("Open: http://mpr.local.ar/detection/?job=%s", args.job) logger.info("Open: http://mpr.local.ar/detection/?job=%s", args.job)
input("\nPress Enter to start...") input("\nPress Enter to start...")

View File

@@ -0,0 +1,123 @@
#!/usr/bin/env python3
"""
Test checkpoint + replay flow end-to-end.
1. Runs the pipeline with checkpointing enabled on a test video
2. Lists available checkpoints
3. Replays from run_ocr with different config
4. Compares detection counts
Usage:
MPR_CHECKPOINT=1 INFERENCE_URL=http://mcrndeb:8000 python tests/detect/manual/test_replay.py [--job JOB_ID]
Requires: inference server running, MinIO/S3 running, test video available
"""
import argparse
import logging
import os
import sys
from pathlib import Path

# Load ctrl/.env so the script sees the same endpoints as the deployed app.
# setdefault() keeps any variables already exported in the caller's shell.
env_file = Path(__file__).resolve().parents[3] / "ctrl" / ".env"
if env_file.exists():
    for line in env_file.read_text().splitlines():
        line = line.strip()
        # Only KEY=VALUE lines are consumed; blanks and '#' comments skipped.
        if line and not line.startswith("#") and "=" in line:
            key, _, value = line.partition("=")
            os.environ.setdefault(key.strip(), value.strip())

# Make repo-local packages (detect.*) importable when run from the repo root.
sys.path.insert(0, ".")

logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s%(message)s")
logger = logging.getLogger(__name__)

# Force checkpointing on
os.environ["MPR_CHECKPOINT"] = "1"
def main():
    """Run the pipeline once with checkpointing, then replay run_ocr and compare.

    Interactive: waits for Enter before each phase so the live dashboard can
    be watched. Requires the inference server, MinIO/S3 and a local Redis.
    """
    parser = argparse.ArgumentParser()
    import time
    # Quasi-unique default job id so repeated runs don't collide in Redis.
    default_job = f"replay-{int(time.time()) % 100000}"
    parser.add_argument("--job", default=default_job)
    parser.add_argument("--port", type=int, default=6382)
    args = parser.parse_args()
    # Override Redis to localhost (ctrl/.env has k8s hostname)
    os.environ["REDIS_URL"] = f"redis://localhost:{args.port}/0"
    # NOTE(review): detect.* imports are deliberately placed after the
    # REDIS_URL override — presumably so config read at import time sees
    # the localhost value. Do not hoist without confirming.
    from detect.graph import get_pipeline, NODES
    from detect.checkpoint import list_checkpoints
    from detect.checkpoint import replay_from
    from detect.state import DetectState

    VIDEO = "media/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e/chunk_0000.mp4"

    logger.info("Job: %s", args.job)
    logger.info("Checkpoint: enabled")
    logger.info("Video: %s", VIDEO)
    logger.info("Open: http://mpr.local.ar/detection/?job=%s", args.job)
    input("\nPress Enter to run initial pipeline...")

    # --- Initial run ---
    pipeline = get_pipeline(checkpoint=True)
    initial_state = DetectState(
        video_path=VIDEO,
        job_id=args.job,
        profile_name="soccer_broadcast",
    )
    logger.info("Running initial pipeline...")
    result = pipeline.invoke(initial_state)
    detections = result.get("detections", [])
    report = result.get("report")
    logger.info("Initial run: %d detections, %d brands",
                len(detections), len(report.brands) if report else 0)

    # --- List checkpoints ---
    stages = list_checkpoints(args.job)
    logger.info("Available checkpoints: %s", stages)
    if "detect_objects" not in stages:
        # Presumably run_ocr resumes from the detect_objects checkpoint,
        # so the replay below would fail without it — TODO confirm.
        logger.error("Expected checkpoint for detect_objects — aborting replay test")
        return

    input("\nPress Enter to replay from run_ocr with different config...")

    # --- Replay with different OCR config ---
    overrides = {"ocr": {"min_confidence": 0.3, "languages": ["en", "es"]}}
    logger.info("Replaying from run_ocr with overrides: %s", overrides)
    replay_result = replay_from(
        job_id=args.job,
        start_stage="run_ocr",
        config_overrides=overrides,
    )
    replay_detections = replay_result.get("detections", [])
    replay_report = replay_result.get("report")
    logger.info("Replay run: %d detections, %d brands",
                len(replay_detections),
                len(replay_report.brands) if replay_report else 0)

    # --- Compare ---
    logger.info("--- Comparison ---")
    logger.info("Initial: %d detections", len(detections))
    logger.info("Replay: %d detections (min_confidence 0.5 → 0.3)", len(replay_detections))
    diff = len(replay_detections) - len(detections)
    if diff > 0:
        logger.info("Replay found %d more detections with lower threshold", diff)
    elif diff == 0:
        logger.info("Same count — threshold change didn't affect this video")
    else:
        logger.warning("Replay found fewer detections — unexpected")

    logger.info("Done.")


if __name__ == "__main__":
    main()

View File

@@ -1,20 +1,13 @@
"""Tests for BrandResolver stage.""" """Tests for BrandResolver stage (discovery mode)."""
import numpy as np import numpy as np
import pytest import pytest
from detect.models import BoundingBox, Frame, TextCandidate from detect.models import BoundingBox, Frame, TextCandidate
from detect.profiles.base import BrandDictionary, ResolverConfig from detect.profiles.base import ResolverConfig
from detect.stages.brand_resolver import resolve_brands, _exact_match, _fuzzy_match from detect.stages.brand_resolver import resolve_brands, _normalize, _match_session
DICTIONARY = BrandDictionary(brands={
"Nike": ["nike", "NIKE", "swoosh"],
"Adidas": ["adidas", "ADIDAS"],
"Coca-Cola": ["coca-cola", "coca cola", "coke", "COCA-COLA"],
"Emirates": ["emirates", "fly emirates", "EMIRATES"],
})
CONFIG = ResolverConfig(fuzzy_threshold=75) CONFIG = ResolverConfig(fuzzy_threshold=75)
@@ -25,57 +18,76 @@ def _make_candidate(text: str, confidence: float = 0.9) -> TextCandidate:
return TextCandidate(frame=dummy_frame, bbox=dummy_box, text=text, ocr_confidence=confidence) return TextCandidate(frame=dummy_frame, bbox=dummy_box, text=text, ocr_confidence=confidence)
def test_exact_match(): def test_session_match():
assert _exact_match("Nike", DICTIONARY) == "Nike" session = {"nike": "Nike", "fly emirates": "Emirates"}
assert _exact_match("nike", DICTIONARY) == "Nike" assert _match_session("Nike", session) == "Nike"
assert _exact_match("COCA-COLA", DICTIONARY) == "Coca-Cola" assert _match_session("nike", session) == "Nike"
assert _exact_match("fly emirates", DICTIONARY) == "Emirates" assert _match_session("FLY EMIRATES", session) == "Emirates"
assert _exact_match("unknown brand", DICTIONARY) is None assert _match_session("unknown", session) is None
def test_fuzzy_match(): def test_resolve_with_session(monkeypatch):
brand, score = _fuzzy_match("Nik3", DICTIONARY, threshold=75) events = []
assert brand == "Nike" monkeypatch.setattr("detect.emit.push_detect_event",
assert score >= 75 lambda job_id, etype, data: events.append((etype, data)))
brand, score = _fuzzy_match("adldas", DICTIONARY, threshold=75) session = {"nike": "Nike", "emirates": "Emirates"}
assert brand == "Adidas"
brand, score = _fuzzy_match("xyzxyzxyz", DICTIONARY, threshold=75)
assert brand is None
def test_resolve_exact():
candidates = [_make_candidate("Nike"), _make_candidate("EMIRATES")] candidates = [_make_candidate("Nike"), _make_candidate("EMIRATES")]
matched, unresolved = resolve_brands(candidates, DICTIONARY, CONFIG)
matched, unresolved = resolve_brands(
candidates, CONFIG, session_brands=session,
)
assert len(matched) == 2 assert len(matched) == 2
assert len(unresolved) == 0 assert len(unresolved) == 0
assert matched[0].brand == "Nike" assert matched[0].brand == "Nike"
assert matched[1].brand == "Emirates" assert matched[1].brand == "Emirates"
def test_resolve_fuzzy(): def test_resolve_unresolved_without_db(monkeypatch):
candidates = [_make_candidate("coca coIa")] # OCR misread events = []
matched, unresolved = resolve_brands(candidates, DICTIONARY, CONFIG) monkeypatch.setattr("detect.emit.push_detect_event",
assert len(matched) == 1 lambda job_id, etype, data: events.append((etype, data)))
assert matched[0].brand == "Coca-Cola"
def test_resolve_unresolved():
candidates = [_make_candidate("random garbage text")] candidates = [_make_candidate("random garbage text")]
matched, unresolved = resolve_brands(candidates, DICTIONARY, CONFIG)
matched, unresolved = resolve_brands(
candidates, CONFIG, session_brands={},
)
assert len(matched) == 0 assert len(matched) == 0
assert len(unresolved) == 1 assert len(unresolved) == 1
def test_resolve_mixed(): def test_resolve_empty(monkeypatch):
events = []
monkeypatch.setattr("detect.emit.push_detect_event",
lambda job_id, etype, data: events.append((etype, data)))
matched, unresolved = resolve_brands([], CONFIG, session_brands={})
assert len(matched) == 0
assert len(unresolved) == 0
def test_resolve_builds_session_during_run(monkeypatch):
"""Session brands accumulate during a single run — second candidate benefits."""
events = []
monkeypatch.setattr("detect.emit.push_detect_event",
lambda job_id, etype, data: events.append((etype, data)))
session = {"nike": "Nike"}
candidates = [ candidates = [
_make_candidate("Nike"), _make_candidate("Nike"), # hits session
_make_candidate("unknown"), _make_candidate("unknown"), # misses everything
_make_candidate("adldas"),
] ]
matched, unresolved = resolve_brands(candidates, DICTIONARY, CONFIG)
assert len(matched) == 2 # Nike exact + Adidas fuzzy matched, unresolved = resolve_brands(
candidates, CONFIG, session_brands=session,
)
assert len(matched) == 1
assert matched[0].brand == "Nike"
assert len(unresolved) == 1 assert len(unresolved) == 1
@@ -84,8 +96,10 @@ def test_events_emitted(monkeypatch):
monkeypatch.setattr("detect.emit.push_detect_event", monkeypatch.setattr("detect.emit.push_detect_event",
lambda job_id, etype, data: events.append((etype, data))) lambda job_id, etype, data: events.append((etype, data)))
session = {"nike": "Nike"}
candidates = [_make_candidate("Nike")] candidates = [_make_candidate("Nike")]
resolve_brands(candidates, DICTIONARY, CONFIG, job_id="test-job")
resolve_brands(candidates, CONFIG, session_brands=session, job_id="test-job")
event_types = [e[0] for e in events] event_types = [e[0] for e in events]
assert "log" in event_types assert "log" in event_types

View File

@@ -0,0 +1,182 @@
"""Tests for checkpoint serialization — round-trip without S3."""
import numpy as np
import pytest
from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate
from detect.checkpoint.serializer import (
serialize_state,
deserialize_state,
serialize_frame_meta,
serialize_text_candidate,
deserialize_text_candidate,
)
def _make_frame(seq: int = 0, w: int = 100, h: int = 80) -> Frame:
    """Frame fixture with random RGB pixels; timestamp is seq * 0.5 seconds."""
    pixels = np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
    return Frame(
        sequence=seq,
        chunk_id=0,
        timestamp=float(seq) * 0.5,
        image=pixels,
        perceptual_hash=f"hash_{seq}",
    )
def _make_box(x=10, y=10, w=30, h=20) -> BoundingBox:
    """Text-labelled bounding box fixture with fixed 0.9 confidence."""
    box = BoundingBox(x=x, y=y, w=w, h=h, confidence=0.9, label="text")
    return box
def _make_candidate(frame: Frame, text: str = "NIKE") -> TextCandidate:
    """Wrap *text* in a TextCandidate anchored to *frame* with a default box."""
    return TextCandidate(frame=frame, bbox=_make_box(), text=text, ocr_confidence=0.85)
def _make_detection(brand: str = "Nike", timestamp: float = 1.0) -> BrandDetection:
    """OCR-sourced BrandDetection fixture for the soccer profile."""
    attrs = dict(
        brand=brand,
        timestamp=timestamp,
        duration=0.5,
        confidence=0.92,
        source="ocr",
        content_type="soccer_broadcast",
        frame_ref=0,
    )
    return BrandDetection(**attrs)
# --- Frame metadata ---
def test_serialize_frame_meta():
    """Metadata keeps sequence/timestamp/hash but drops the image payload."""
    meta = serialize_frame_meta(_make_frame(seq=5))
    expected = {"sequence": 5, "timestamp": 2.5, "perceptual_hash": "hash_5"}
    for key, value in expected.items():
        assert meta[key] == value
    assert "image" not in meta
# --- TextCandidate ---
def test_serialize_text_candidate():
    """Serialized candidate records frame ref, text, confidence and bbox."""
    payload = serialize_text_candidate(_make_candidate(_make_frame(), text="EMIRATES"))
    assert payload["frame_sequence"] == 0
    assert payload["text"] == "EMIRATES"
    assert payload["ocr_confidence"] == 0.85
    assert "bbox" in payload
def test_deserialize_text_candidate():
    """Round-trip restores all fields and re-links the original Frame object."""
    frame = _make_frame()
    payload = serialize_text_candidate(_make_candidate(frame, text="ADIDAS"))
    restored = deserialize_text_candidate(payload, {frame.sequence: frame})
    assert restored.frame is frame  # same object reference, not a copy
    assert restored.text == "ADIDAS"
    assert restored.ocr_confidence == 0.85
    assert restored.bbox.x == 10
# --- Full state round-trip ---
def test_state_round_trip():
    """Full pipeline state survives serialize → JSON → deserialize.

    Builds a state covering every stage-output key plus stats, serializes it
    against an S3-style frame manifest, checks the result is JSON-encodable,
    then deserializes with the original frames (as if reloaded from S3) and
    verifies counts, scalar fields, and re-linked object references.
    """
    frames = [_make_frame(seq=i) for i in range(3)]
    filtered = frames[:2]
    box = _make_box()
    boxes_by_frame = {0: [box], 1: [box]}
    candidates = [_make_candidate(frames[0], "NIKE"), _make_candidate(frames[1], "EMIRATES")]
    unresolved = [_make_candidate(frames[2], "unknown")]
    detections = [_make_detection("Nike", 0.5), _make_detection("Emirates", 1.0)]
    stats = PipelineStats(
        frames_extracted=3,
        frames_after_scene_filter=2,
        regions_detected=2,
        regions_resolved_by_ocr=2,
        cloud_llm_calls=1,
        estimated_cloud_cost_usd=0.003,
    )
    state = {
        "job_id": "test-123",
        "video_path": "/tmp/test.mp4",
        "profile_name": "soccer_broadcast",
        "config_overrides": {"ocr": {"min_confidence": 0.3}},
        "frames": frames,
        "filtered_frames": filtered,
        "boxes_by_frame": boxes_by_frame,
        "text_candidates": candidates,
        "unresolved_candidates": unresolved,
        "detections": detections,
        "stats": stats,
    }
    # Fake manifest mapping frame sequence → stored-image URI.
    manifest = {f.sequence: f"s3://fake/frames/{f.sequence}.jpg" for f in frames}

    # Serialize
    serialized = serialize_state(state, manifest)

    # Verify JSON-compatible (no numpy, no Frame objects)
    import json
    json_str = json.dumps(serialized, default=str)
    assert len(json_str) > 0

    # Deserialize with the original frames (simulating frame load from S3)
    restored = deserialize_state(serialized, frames)

    # Verify round-trip
    assert restored["job_id"] == "test-123"
    assert restored["video_path"] == "/tmp/test.mp4"
    assert restored["profile_name"] == "soccer_broadcast"
    assert restored["config_overrides"] == {"ocr": {"min_confidence": 0.3}}
    assert len(restored["frames"]) == 3
    assert len(restored["filtered_frames"]) == 2
    assert len(restored["boxes_by_frame"]) == 2
    assert len(restored["text_candidates"]) == 2
    assert len(restored["unresolved_candidates"]) == 1
    assert len(restored["detections"]) == 2

    restored_stats = restored["stats"]
    assert restored_stats.frames_extracted == 3
    assert restored_stats.cloud_llm_calls == 1
    assert restored_stats.estimated_cloud_cost_usd == 0.003

    # TextCandidate frame references should point to actual Frame objects
    tc = restored["text_candidates"][0]
    assert tc.frame is frames[0]
    assert tc.text == "NIKE"
def test_state_round_trip_empty():
    """An empty state should serialize and deserialize without errors."""
    empty_state = {
        "job_id": "empty-job",
        "video_path": "",
        "profile_name": "soccer_broadcast",
        "frames": [],
        "filtered_frames": [],
        "boxes_by_frame": {},
        "text_candidates": [],
        "unresolved_candidates": [],
        "detections": [],
        "stats": PipelineStats(),
    }
    restored = deserialize_state(serialize_state(empty_state, {}), [])
    assert restored["job_id"] == "empty-job"
    assert not restored["frames"]
    assert not restored["detections"]
    assert restored["stats"].frames_extracted == 0

View File

@@ -25,11 +25,9 @@ def test_soccer_detection_config():
assert isinstance(cfg.target_classes, list) assert isinstance(cfg.target_classes, list)
def test_soccer_brand_dictionary_non_empty(): def test_soccer_resolver_config():
bd = SoccerBroadcastProfile().brand_dictionary() cfg = SoccerBroadcastProfile().resolver_config()
assert len(bd.brands) > 0 assert cfg.fuzzy_threshold > 0
for canonical, aliases in bd.brands.items():
assert len(aliases) > 0
def test_soccer_vlm_prompt(): def test_soccer_vlm_prompt():
@@ -70,4 +68,4 @@ def test_stubs_raise(stub_cls):
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
stub.frame_extraction_config() stub.frame_extraction_config()
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
stub.brand_dictionary() stub.resolver_config()

View File

@@ -0,0 +1,67 @@
"""Tests for replay and OverrideProfile."""
import pytest
from detect.profiles.soccer import SoccerBroadcastProfile
from detect.checkpoint.replay import OverrideProfile
def test_override_profile_patches_ocr():
    """OCR overrides replace the listed fields on the base OCR config."""
    patched = OverrideProfile(
        SoccerBroadcastProfile(),
        {"ocr": {"min_confidence": 0.3, "languages": ["en", "es", "pt"]}},
    ).ocr_config()
    assert patched.min_confidence == 0.3
    assert patched.languages == ["en", "es", "pt"]
def test_override_profile_patches_resolver():
    """Resolver overrides replace the fuzzy threshold."""
    patched = OverrideProfile(
        SoccerBroadcastProfile(), {"resolver": {"fuzzy_threshold": 60}}
    ).resolver_config()
    assert patched.fuzzy_threshold == 60
def test_override_profile_patches_detection():
    """Detection overrides replace the confidence threshold."""
    patched = OverrideProfile(
        SoccerBroadcastProfile(), {"detection": {"confidence_threshold": 0.5}}
    ).detection_config()
    assert patched.confidence_threshold == 0.5
def test_override_profile_no_overrides():
    """With an empty override map the config matches the base profile's."""
    base = SoccerBroadcastProfile()
    patched = OverrideProfile(base, {}).ocr_config()
    reference = base.ocr_config()
    assert patched.min_confidence == reference.min_confidence
    assert patched.languages == reference.languages
def test_override_profile_delegates_non_config():
    """Attributes outside the override map pass through to the base profile."""
    wrapped = OverrideProfile(SoccerBroadcastProfile(), {"ocr": {"min_confidence": 0.1}})
    assert wrapped.name == "soccer_broadcast"
    assert wrapped.resolver_config().fuzzy_threshold > 0
def test_override_profile_ignores_unknown_fields():
    """Override keys that don't exist on the config are silently dropped."""
    base = SoccerBroadcastProfile()
    patched = OverrideProfile(base, {"ocr": {"nonexistent_field": 42}}).ocr_config()
    assert not hasattr(patched, "nonexistent_field")
    assert patched.min_confidence == base.ocr_config().min_confidence

View File

@@ -1,6 +1,6 @@
<script setup lang="ts"> <script setup lang="ts">
import { ref } from 'vue' import { ref } from 'vue'
import { SSEDataSource, Panel, LayoutGrid } from 'mpr-ui-framework' import { SSEDataSource, Panel, ResizeHandle } from 'mpr-ui-framework'
import 'mpr-ui-framework/src/tokens.css' import 'mpr-ui-framework/src/tokens.css'
import LogPanel from './panels/LogPanel.vue' import LogPanel from './panels/LogPanel.vue'
import FunnelPanel from './panels/FunnelPanel.vue' import FunnelPanel from './panels/FunnelPanel.vue'
@@ -9,10 +9,11 @@ import FramePanel from './panels/FramePanel.vue'
import BrandTablePanel from './panels/BrandTablePanel.vue' import BrandTablePanel from './panels/BrandTablePanel.vue'
import TimelinePanel from './panels/TimelinePanel.vue' import TimelinePanel from './panels/TimelinePanel.vue'
import CostStatsPanel from './panels/CostStatsPanel.vue' import CostStatsPanel from './panels/CostStatsPanel.vue'
import type { StatsUpdate } from './types/sse-contract' import type { StatsUpdate, RunContext } from './types/sse-contract'
const jobId = ref(new URLSearchParams(window.location.search).get('job') || 'test-job') const jobId = ref(new URLSearchParams(window.location.search).get('job') || 'test-job')
const stats = ref<StatsUpdate | null>(null) const stats = ref<StatsUpdate | null>(null)
const runContext = ref<RunContext | null>(null)
const status = ref<'idle' | 'live' | 'processing' | 'error'>('idle') const status = ref<'idle' | 'live' | 'processing' | 'error'>('idle')
const source = new SSEDataSource({ const source = new SSEDataSource({
@@ -23,8 +24,41 @@ const source = new SSEDataSource({
source.on<StatsUpdate>('stats_update', (e) => { source.on<StatsUpdate>('stats_update', (e) => {
stats.value = e stats.value = e
// Capture run context from first event that carries it
if (!runContext.value && e.run_id) {
runContext.value = {
run_id: (e as any).run_id,
parent_job_id: (e as any).parent_job_id,
run_type: (e as any).run_type ?? 'initial',
}
}
}) })
// Resizable splits
const pipelineWidth = ref(320)
const detectionsFlex = ref(3) // ratio for detections vs stats
const viewerHeight = ref(240)
const timelineFlex = ref(1)
const tableFlex = ref(1)
function onPipelineResize(delta: number) {
pipelineWidth.value = Math.max(200, Math.min(500, pipelineWidth.value + delta))
}
function onViewerResize(delta: number) {
viewerHeight.value = Math.max(120, Math.min(400, viewerHeight.value + delta))
}
function onDetectionsResize(delta: number) {
detectionsFlex.value = Math.max(1, Math.min(6, detectionsFlex.value + delta * 0.01))
}
function onTimelineResize(delta: number) {
const shift = delta * 0.02
timelineFlex.value = Math.max(0.3, Math.min(3, timelineFlex.value + shift))
tableFlex.value = Math.max(0.3, Math.min(3, tableFlex.value - shift))
}
const statusMap: Record<string, 'idle' | 'live' | 'processing' | 'error'> = { const statusMap: Record<string, 'idle' | 'live' | 'processing' | 'error'> = {
idle: 'idle', idle: 'idle',
connecting: 'processing', connecting: 'processing',
@@ -42,41 +76,70 @@ source.connect()
<header> <header>
<h1>Detection Pipeline</h1> <h1>Detection Pipeline</h1>
<span class="status-badge" :class="status">{{ status }}</span> <span class="status-badge" :class="status">{{ status }}</span>
<span v-if="runContext" class="run-info">
{{ runContext.run_type }} · run: {{ runContext.run_id }}
</span>
<span class="job-id">job: {{ jobId }}</span> <span class="job-id">job: {{ jobId }}</span>
</header> </header>
<LayoutGrid :columns="3" :rows="3" gap="var(--space-2)"> <div class="main-layout">
<Panel title="Stats" :status="status"> <!-- Left column: Pipeline control (full height) -->
<div class="stats" v-if="stats"> <div class="pipeline-col" :style="{ width: pipelineWidth + 'px' }">
<div class="stat" v-for="s in [ <PipelineGraphPanel :source="source" :status="status" />
{ label: 'Frames', value: stats.frames_extracted }, </div>
{ label: 'After filter', value: stats.frames_after_scene_filter }, <ResizeHandle direction="horizontal" @resize="onPipelineResize" />
{ label: 'Regions', value: stats.regions_detected },
{ label: 'OCR resolved', value: stats.regions_resolved_by_ocr }, <!-- Right area: interactive panels -->
{ label: 'Cloud calls', value: stats.cloud_llm_calls }, <div class="content-col">
{ label: 'Cost', value: `$${stats.estimated_cloud_cost_usd.toFixed(4)}` }, <!-- Row 1: Frame viewer + Funnel -->
]" :key="s.label"> <div class="viewer-row" :style="{ height: viewerHeight + 'px' }">
<span class="label">{{ s.label }}</span> <FramePanel :source="source" :status="status" />
<span class="value">{{ s.value }}</span> <FunnelPanel :source="source" :status="status" />
</div>
<ResizeHandle direction="vertical" @resize="onViewerResize" />
<!-- Row 2: Detections + Stats side by side -->
<div class="detections-stats-row">
<div class="detections-col" :style="{ flex: detectionsFlex }">
<Panel title="Detections" :status="status">
<div class="detections-stack">
<div class="timeline-section" :style="{ flex: timelineFlex }">
<TimelinePanel :source="source" :status="status" :embedded="true" />
</div>
<ResizeHandle direction="vertical" @resize="onTimelineResize" />
<div class="table-section" :style="{ flex: tableFlex }">
<BrandTablePanel :source="source" :status="status" :embedded="true" />
</div>
</div>
</Panel>
</div>
<ResizeHandle direction="horizontal" @resize="onDetectionsResize" />
<div class="stats-col">
<Panel title="Pipeline" :status="status">
<div class="pipeline-stats">
<div class="stat" v-for="s in [
{ label: 'Frames', value: stats?.frames_extracted ?? '—' },
{ label: 'After filter', value: stats?.frames_after_scene_filter ?? '—' },
{ label: 'Regions', value: stats?.regions_detected ?? '—' },
{ label: 'OCR resolved', value: stats?.regions_resolved_by_ocr ?? '—' },
{ label: 'VLM escalated', value: stats?.regions_escalated_to_local_vlm ?? '—' },
{ label: 'Cloud escalated', value: stats?.regions_escalated_to_cloud_llm ?? '—' },
]" :key="s.label">
<span class="label">{{ s.label }}</span>
<span class="value">{{ s.value }}</span>
</div>
</div>
</Panel>
<CostStatsPanel :source="source" :status="status" />
</div> </div>
</div> </div>
<div v-else class="empty">Waiting for stats...</div> </div>
</Panel> </div>
<FunnelPanel :source="source" :status="status" />
<FramePanel :source="source" :status="status" />
<PipelineGraphPanel :source="source" :status="status" />
<BrandTablePanel :source="source" :status="status" />
<TimelinePanel :source="source" :status="status" />
<CostStatsPanel :source="source" :status="status" />
<!-- Bottom: Log (full width) -->
<div class="log-row">
<LogPanel :source="source" :status="status" /> <LogPanel :source="source" :status="status" />
</LayoutGrid> </div>
</div> </div>
</template> </template>
@@ -92,10 +155,11 @@ body {
.app { .app {
height: 100vh; height: 100vh;
display: flex; display: grid;
flex-direction: column; grid-template-rows: auto 1fr auto;
padding: var(--space-4); padding: var(--space-4);
gap: var(--space-2); gap: var(--space-2);
overflow: hidden;
} }
header { header {
@@ -120,21 +184,110 @@ header h1 { font-size: var(--font-size-lg); font-weight: 600; }
.status-badge.live { background: var(--status-live); color: #000; } .status-badge.live { background: var(--status-live); color: #000; }
.status-badge.error { background: var(--status-error); color: #000; } .status-badge.error { background: var(--status-error); color: #000; }
.run-info {
color: var(--text-secondary);
font-size: var(--font-size-sm);
}
.job-id { color: var(--text-dim); font-size: var(--font-size-sm); margin-left: auto; } .job-id { color: var(--text-dim); font-size: var(--font-size-sm); margin-left: auto; }
.stats { /* Main layout: pipeline left, content right — both same height */
display: grid; .main-layout {
grid-template-columns: repeat(3, 1fr); display: flex;
gap: var(--space-2); gap: var(--space-2);
min-height: 0;
overflow: hidden;
}
.pipeline-col {
flex-shrink: 0;
display: flex;
overflow: hidden;
}
.pipeline-col > * { flex: 1; }
.content-col {
flex: 1;
display: flex;
flex-direction: column;
gap: var(--space-2);
min-height: 0;
overflow: hidden;
}
/* Content rows */
.viewer-row {
display: flex;
gap: var(--space-2);
flex-shrink: 0;
}
.viewer-row > * { flex: 1; overflow: hidden; }
/* Detections (75%) + Stats (25%) side by side, bottom-aligned with pipeline */
.detections-stats-row {
display: flex;
gap: var(--space-2);
flex: 1;
min-height: 0;
}
.detections-col {
flex: 3;
min-width: 0;
display: flex;
}
.detections-col > * { flex: 1; display: flex; flex-direction: column; }
.stats-col {
flex: 1;
display: flex;
flex-direction: column;
gap: var(--space-2);
min-width: 200px;
}
.stats-col > * { flex: 1; }
.detections-stack {
display: flex;
flex-direction: column;
height: 100%;
overflow: hidden;
}
.timeline-section {
min-height: 60px;
overflow: auto;
}
.table-section {
min-height: 60px;
overflow-y: auto;
overflow-x: hidden;
}
/* Pipeline stats list */
.pipeline-stats {
display: flex;
flex-direction: column;
gap: var(--space-1);
} }
.stat { .stat {
background: var(--surface-2); display: flex;
border-radius: var(--panel-radius); justify-content: space-between;
padding: var(--space-3); padding: var(--space-1) 0;
}
.stat .label { color: var(--text-dim); font-size: var(--font-size-sm); }
.stat .value { font-weight: 600; }
/* Log: full width bottom, fixed height */
.log-row {
flex-shrink: 0;
height: 160px;
overflow: hidden;
} }
.stat .label { display: block; color: var(--text-dim); font-size: var(--font-size-sm); margin-bottom: var(--space-1); }
.stat .value { font-size: 20px; font-weight: 600; }
.empty { color: var(--text-dim); padding: var(--space-6); text-align: center; } .empty { color: var(--text-dim); padding: var(--space-6); text-align: center; }
</style> </style>

View File

@@ -8,6 +8,7 @@ import type { DataSource } from 'mpr-ui-framework'
const props = defineProps<{ const props = defineProps<{
source: DataSource source: DataSource
status?: 'idle' | 'live' | 'processing' | 'error' status?: 'idle' | 'live' | 'processing' | 'error'
embedded?: boolean
}>() }>()
const columns: TableColumn[] = [ const columns: TableColumn[] = [
@@ -45,7 +46,7 @@ function onSort(key: string) {
</script> </script>
<template> <template>
<Panel title="Detections" :status="status"> <Panel v-if="!embedded" title="Detections" :status="status">
<TableRenderer <TableRenderer
:columns="columns" :columns="columns"
:rows="rows" :rows="rows"
@@ -54,4 +55,12 @@ function onSort(key: string) {
@sort="onSort" @sort="onSort"
/> />
</Panel> </Panel>
<TableRenderer
v-else
:columns="columns"
:rows="rows"
:sort-key="sortKey"
:sort-dir="sortDir"
@sort="onSort"
/>
</template> </template>

View File

@@ -7,6 +7,7 @@ import type { StatsUpdate, Detection } from '../types/sse-contract'
const props = defineProps<{ const props = defineProps<{
source: DataSource source: DataSource
status?: 'idle' | 'live' | 'processing' | 'error' status?: 'idle' | 'live' | 'processing' | 'error'
embedded?: boolean
}>() }>()
const stats = ref<StatsUpdate | null>(null) const stats = ref<StatsUpdate | null>(null)
@@ -71,7 +72,7 @@ const metrics = computed<Metric[]>(() => {
</script> </script>
<template> <template>
<Panel title="Cost & Stats" :status="status"> <Panel v-if="!embedded" title="Cost & Stats" :status="status">
<div class="cost-stats" v-if="stats"> <div class="cost-stats" v-if="stats">
<div class="metric" v-for="m in metrics" :key="m.label"> <div class="metric" v-for="m in metrics" :key="m.label">
<span class="label">{{ m.label }}</span> <span class="label">{{ m.label }}</span>
@@ -81,38 +82,59 @@ const metrics = computed<Metric[]>(() => {
</div> </div>
<div v-else class="empty">Waiting for stats...</div> <div v-else class="empty">Waiting for stats...</div>
</Panel> </Panel>
<div v-else class="cost-stats" v-show="stats">
<div class="metric" v-for="m in metrics" :key="m.label">
<span class="label">{{ m.label }}</span>
<span class="value" :style="m.color ? { color: m.color } : {}">{{ m.value }}</span>
<span class="sub" v-if="m.sub">{{ m.sub }}</span>
</div>
</div>
</template> </template>
<style scoped> <style scoped>
.cost-stats { .cost-stats {
display: grid; display: grid;
grid-template-columns: 1fr 1fr; grid-template-columns: 1fr 1fr;
gap: var(--space-3); gap: var(--space-2);
padding: var(--space-3); padding: var(--space-2);
overflow: hidden;
box-sizing: border-box;
width: 100%;
} }
.metric { .metric {
background: var(--surface-2); background: var(--surface-2);
border-radius: var(--panel-radius); border-radius: var(--panel-radius);
padding: var(--space-3); padding: var(--space-2);
display: flex; display: flex;
flex-direction: column; flex-direction: column;
gap: var(--space-1); gap: 2px;
overflow: hidden;
min-width: 0;
} }
.label { .label {
font-size: var(--font-size-sm); font-size: var(--font-size-sm);
color: var(--text-dim); color: var(--text-dim);
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
} }
.value { .value {
font-size: 22px; font-size: 18px;
font-weight: 600; font-weight: 600;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
} }
.sub { .sub {
font-size: 11px; font-size: 10px;
color: var(--text-dim); color: var(--text-dim);
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
} }
.empty { .empty {

View File

@@ -7,6 +7,7 @@ import type { Detection } from '../types/sse-contract'
const props = defineProps<{ const props = defineProps<{
source: DataSource source: DataSource
status?: 'idle' | 'live' | 'processing' | 'error' status?: 'idle' | 'live' | 'processing' | 'error'
embedded?: boolean
}>() }>()
interface TimelineEntry { interface TimelineEntry {
@@ -45,6 +46,20 @@ const maxTime = computed(() => {
return Math.max(...entries.value.map((e) => e.timestamp + e.duration)) * 1.1 return Math.max(...entries.value.map((e) => e.timestamp + e.duration)) * 1.1
}) })
// Ruler ticks — nice intervals based on duration
const rulerTicks = computed(() => {
const max = maxTime.value
const intervals = [1, 2, 5, 10, 15, 30, 60, 120, 300]
const targetTicks = 8
const rawStep = max / targetTicks
const step = intervals.find((i) => i >= rawStep) || Math.ceil(rawStep / 60) * 60
const ticks: number[] = []
for (let t = 0; t <= max; t += step) {
ticks.push(t)
}
return ticks
})
const sourceColor: Record<string, string> = { const sourceColor: Record<string, string> = {
ocr: '#3ecf8e', ocr: '#3ecf8e',
local_vlm: '#f5a623', local_vlm: '#f5a623',
@@ -67,8 +82,53 @@ function barStyle(entry: TimelineEntry) {
</script> </script>
<template> <template>
<Panel title="Detection Timeline" :status="status"> <Panel v-if="!embedded" title="Detection Timeline" :status="status">
<div class="timeline" v-if="brands.length > 0"> <div class="timeline" v-if="brands.length > 0">
<div class="ruler">
<span class="ruler-label" />
<div class="ruler-track">
<span
v-for="t in rulerTicks" :key="t"
class="ruler-tick"
:style="{ left: (t / maxTime) * 100 + '%' }"
>{{ t }}s</span>
</div>
</div>
<div class="rows">
<div class="row" v-for="brand in brands" :key="brand">
<span class="brand-label">{{ brand }}</span>
<div class="track">
<div
v-for="(entry, i) in entries.filter((e) => e.brand === brand)"
:key="i"
class="bar"
:style="barStyle(entry)"
:title="`${entry.brand} — ${entry.source} (${(entry.confidence * 100).toFixed(0)}%) @ ${entry.timestamp.toFixed(1)}s`"
/>
</div>
</div>
</div>
<div class="legend">
<span v-for="(color, source) in sourceColor" :key="source" class="legend-item">
<span class="legend-dot" :style="{ background: color }" />
{{ source }}
</span>
</div>
</div>
<div v-else class="empty">Waiting for detections...</div>
</Panel>
<div v-else class="timeline" v-show="brands.length > 0">
<div class="ruler">
<span class="ruler-label" />
<div class="ruler-track">
<span
v-for="t in rulerTicks" :key="t"
class="ruler-tick"
:style="{ left: (t / maxTime) * 100 + '%' }"
>{{ t }}s</span>
</div>
</div>
<div class="rows">
<div class="row" v-for="brand in brands" :key="brand"> <div class="row" v-for="brand in brands" :key="brand">
<span class="brand-label">{{ brand }}</span> <span class="brand-label">{{ brand }}</span>
<div class="track"> <div class="track">
@@ -81,20 +141,8 @@ function barStyle(entry: TimelineEntry) {
/> />
</div> </div>
</div> </div>
<div class="time-axis">
<span>0s</span>
<span>{{ (maxTime / 2).toFixed(0) }}s</span>
<span>{{ maxTime.toFixed(0) }}s</span>
</div>
<div class="legend">
<span v-for="(color, source) in sourceColor" :key="source" class="legend-item">
<span class="legend-dot" :style="{ background: color }" />
{{ source }}
</span>
</div>
</div> </div>
<div v-else class="empty">Waiting for detections...</div> </div>
</Panel>
</template> </template>
<style scoped> <style scoped>
@@ -102,9 +150,54 @@ function barStyle(entry: TimelineEntry) {
padding: var(--space-2); padding: var(--space-2);
display: flex; display: flex;
flex-direction: column; flex-direction: column;
gap: var(--space-1);
height: 100%; height: 100%;
overflow-y: auto; overflow: auto;
}
.ruler {
display: flex;
align-items: flex-end;
gap: var(--space-2);
flex-shrink: 0;
height: 20px;
margin-bottom: var(--space-1);
}
.ruler-label {
width: 100px;
flex-shrink: 0;
}
.ruler-track {
flex: 1;
position: relative;
height: 100%;
border-bottom: 1px solid var(--surface-3);
}
.ruler-tick {
position: absolute;
bottom: 0;
transform: translateX(-50%);
font-size: 9px;
color: var(--text-dim);
padding-bottom: 2px;
}
.ruler-tick::after {
content: '';
position: absolute;
bottom: -1px;
left: 50%;
width: 1px;
height: 6px;
background: var(--surface-3);
}
.rows {
display: flex;
flex-direction: column;
gap: var(--space-1);
} }
.row { .row {
@@ -141,15 +234,6 @@ function barStyle(entry: TimelineEntry) {
min-width: 4px; min-width: 4px;
} }
.time-axis {
display: flex;
justify-content: space-between;
padding-left: 108px;
font-size: 10px;
color: var(--text-dim);
margin-top: var(--space-1);
}
.legend { .legend {
display: flex; display: flex;
gap: var(--space-3); gap: var(--space-3);

View File

@@ -95,3 +95,44 @@ export interface JobComplete {
job_id: string; job_id: string;
report: DetectionReportSummary | null; report: DetectionReportSummary | null;
} }
// --- Run context (injected into all SSE events) ---
export interface RunContext {
run_id: string;
parent_job_id: string;
run_type: 'initial' | 'replay' | 'retry';
}
// --- Checkpoint API types ---
export interface CheckpointInfo {
stage: string;
}
export interface ReplayRequest {
job_id: string;
start_stage: string;
config_overrides?: Record<string, unknown>;
}
export interface ReplayResponse {
status: string;
job_id: string;
start_stage: string;
detections: number;
brands_found: number;
}
export interface RetryRequest {
job_id: string;
config_overrides?: Record<string, unknown>;
start_stage?: string;
schedule_seconds?: number;
}
export interface RetryResponse {
status: string;
task_id: string;
job_id: string;
}

View File

@@ -64,8 +64,9 @@ defineProps<{
.panel-body { .panel-body {
flex: 1; flex: 1;
overflow: auto; overflow: hidden;
padding: var(--space-2); padding: var(--space-2);
min-height: 0;
} }
.panel-overlay { .panel-overlay {

View File

@@ -0,0 +1,70 @@
<script setup lang="ts">
import { ref } from 'vue'
const props = defineProps<{
direction: 'horizontal' | 'vertical'
}>()
const emit = defineEmits<{
resize: [delta: number]
}>()
const dragging = ref(false)
let startPos = 0
function onPointerDown(e: PointerEvent) {
dragging.value = true
startPos = props.direction === 'horizontal' ? e.clientX : e.clientY
const el = e.target as HTMLElement
el.setPointerCapture(e.pointerId)
}
function onPointerMove(e: PointerEvent) {
if (!dragging.value) return
const currentPos = props.direction === 'horizontal' ? e.clientX : e.clientY
const delta = currentPos - startPos
startPos = currentPos
emit('resize', delta)
}
function onPointerUp() {
dragging.value = false
}
</script>
<template>
<div
class="resize-handle"
:class="[direction, { dragging }]"
@pointerdown="onPointerDown"
@pointermove="onPointerMove"
@pointerup="onPointerUp"
/>
</template>
<style scoped>
.resize-handle {
flex-shrink: 0;
background: transparent;
transition: background 0.15s;
touch-action: none;
z-index: 10;
}
.resize-handle:hover,
.resize-handle.dragging {
background: var(--text-dim);
}
.resize-handle.horizontal {
width: 4px;
cursor: col-resize;
margin: 0 -2px;
}
.resize-handle.vertical {
height: 4px;
cursor: row-resize;
margin: -2px 0;
}
</style>

View File

@@ -7,6 +7,7 @@ export { useDataSource } from './composables/useDataSource'
// Components // Components
export { default as Panel } from './components/Panel.vue' export { default as Panel } from './components/Panel.vue'
export { default as LayoutGrid } from './components/LayoutGrid.vue' export { default as LayoutGrid } from './components/LayoutGrid.vue'
export { default as ResizeHandle } from './components/ResizeHandle.vue'
// Renderers // Renderers
export { default as LogRenderer } from './renderers/LogRenderer.vue' export { default as LogRenderer } from './renderers/LogRenderer.vue'

View File

@@ -76,6 +76,7 @@ const sorted = computed(() => {
table { table {
width: 100%; width: 100%;
border-collapse: collapse; border-collapse: collapse;
table-layout: fixed;
} }
th { th {
@@ -89,7 +90,6 @@ th {
border-bottom: var(--panel-border); border-bottom: var(--panel-border);
cursor: pointer; cursor: pointer;
user-select: none; user-select: none;
white-space: nowrap;
} }
th:hover { th:hover {
@@ -104,7 +104,10 @@ th:hover {
td { td {
padding: var(--space-1) var(--space-3); padding: var(--space-1) var(--space-3);
border-bottom: 1px solid var(--surface-3); border-bottom: 1px solid var(--surface-3);
white-space: nowrap; white-space: normal;
word-break: break-word;
overflow: hidden;
text-overflow: ellipsis;
} }
tr:hover td { tr:hover td {

View File

@@ -22,6 +22,7 @@ const props = withDefaults(defineProps<{
}) })
const container = ref<HTMLElement | null>(null) const container = ref<HTMLElement | null>(null)
const zoomed = ref(false)
let chart: uPlot | null = null let chart: uPlot | null = null
function buildOpts(): uPlot.Options { function buildOpts(): uPlot.Options {
@@ -40,25 +41,66 @@ function buildOpts(): uPlot.Options {
height: container.value?.clientHeight ?? 200, height: container.value?.clientHeight ?? 200,
series: seriesOpts, series: seriesOpts,
axes: [ axes: [
{ stroke: '#555568', grid: { stroke: '#2e2e3822' } }, {
{ stroke: '#555568', grid: { stroke: '#2e2e3822' } }, stroke: '#555568',
grid: { stroke: '#2e2e3822' },
size: 40,
font: '10px monospace',
ticks: { size: 3 },
},
{
stroke: '#555568',
grid: { stroke: '#2e2e3822' },
size: 35,
font: '10px monospace',
ticks: { size: 3 },
},
], ],
cursor: { show: true }, cursor: { show: true },
legend: { show: true }, legend: { show: true, live: false },
padding: [8, 8, 0, 0],
hooks: {
setScale: [(_self: uPlot, scaleKey: string) => {
if (scaleKey === 'x') zoomed.value = true
}],
},
} }
} }
function resetZoom() {
if (!chart) return
const data = chart.data
if (data && data[0] && data[0].length > 0) {
const min = data[0][0]
const max = data[0][data[0].length - 1]
chart.setScale('x', { min, max })
}
zoomed.value = false
}
function getLegendHeight(): number {
if (!container.value) return 0
const legend = container.value.querySelector('.u-legend') as HTMLElement | null
return legend ? legend.offsetHeight : 0
}
function createChart() { function createChart() {
if (!container.value) return if (!container.value) return
if (chart) chart.destroy() if (chart) chart.destroy()
chart = new uPlot(buildOpts(), props.data, container.value) chart = new uPlot(buildOpts(), props.data, container.value)
// Refit after legend renders
nextTick(() => resize())
} }
function resize() { function resize() {
if (!chart || !container.value) return if (!chart || !container.value) return
const legendH = getLegendHeight()
const availableH = container.value.clientHeight
// uPlot height = canvas height (chart sets total = canvas + legend)
const chartH = Math.max(60, availableH - legendH)
chart.setSize({ chart.setSize({
width: container.value.clientWidth, width: container.value.clientWidth,
height: container.value.clientHeight, height: chartH,
}) })
} }
@@ -83,19 +125,74 @@ onMounted(() => {
</script> </script>
<template> <template>
<div ref="container" class="timeseries-renderer" /> <div class="timeseries-wrapper">
<button v-if="zoomed" class="reset-zoom" @click="resetZoom" title="Reset zoom"></button>
<div ref="container" class="timeseries-renderer" />
</div>
</template> </template>
<style scoped> <style scoped>
.timeseries-wrapper {
width: 100%;
height: 100%;
position: relative;
}
.reset-zoom {
position: absolute;
top: 4px;
right: 4px;
z-index: 20;
background: var(--surface-2);
border: 1px solid var(--surface-3);
border-radius: 4px;
color: var(--text-secondary);
font-size: 14px;
width: 24px;
height: 24px;
cursor: pointer;
display: flex;
align-items: center;
justify-content: center;
opacity: 0.7;
transition: opacity 0.15s;
}
.reset-zoom:hover {
opacity: 1;
color: var(--text-primary);
}
.timeseries-renderer { .timeseries-renderer {
width: 100%; width: 100%;
height: 100%; height: 100%;
min-height: 150px; display: flex;
flex-direction: column;
overflow: hidden;
}
/* uPlot creates a .u-wrap for canvas + a .u-legend below it */
.timeseries-renderer :deep(.u-wrap) {
flex: 1;
min-height: 0;
}
.timeseries-renderer :deep(.u-legend) {
flex-shrink: 0;
} }
.timeseries-renderer :deep(.u-legend) { .timeseries-renderer :deep(.u-legend) {
font-family: var(--font-mono); font-family: var(--font-mono);
font-size: var(--font-size-sm); font-size: 10px;
color: var(--text-secondary); color: var(--text-secondary);
padding: 2px 0;
display: flex;
flex-wrap: wrap;
gap: 0 8px;
}
.timeseries-renderer :deep(.u-legend .u-series) {
display: inline-flex;
padding: 0;
} }
</style> </style>