From 08c58a6a9ddb7ea4999093ce45d2d18a57e7b19b Mon Sep 17 00:00:00 2001 From: buenosairesam Date: Thu, 26 Mar 2026 04:24:32 -0300 Subject: [PATCH] phase 10 --- core/api/detect_replay.py | 121 +++++++++ core/api/main.py | 4 + core/db/detect.py | 175 +++++++++++++ core/schema/models/__init__.py | 9 +- core/schema/models/detect.py | 64 +++++ core/schema/models/detect_jobs.py | 162 ++++++++++++ ctrl/.env.template | 7 +- ctrl/Tiltfile | 10 +- ctrl/k8s/base/fastapi.yaml | 2 + ctrl/k8s/base/kustomization.yaml | 1 + ctrl/k8s/base/minio.yaml | 87 +++++++ detect/checkpoint/__init__.py | 14 ++ detect/checkpoint/frames.py | 80 ++++++ detect/checkpoint/replay.py | 132 ++++++++++ detect/checkpoint/serializer.py | 133 ++++++++++ detect/checkpoint/storage.py | 215 ++++++++++++++++ detect/checkpoint/tasks.py | 71 ++++++ detect/emit.py | 43 +++- detect/graph.py | 127 ++++++++-- detect/profiles/__init__.py | 2 - detect/profiles/base.py | 7 - detect/profiles/soccer.py | 17 -- detect/profiles/stubs.py | 10 - detect/stages/brand_resolver.py | 202 +++++++++++---- detect/stages/vlm_cloud.py | 17 ++ detect/stages/vlm_local.py | 17 ++ detect/state.py | 7 + tests/detect/manual/test_escalation_e2e.py | 12 + tests/detect/manual/test_replay.py | 123 +++++++++ tests/detect/test_brand_resolver.py | 102 ++++---- tests/detect/test_checkpoint.py | 182 ++++++++++++++ tests/detect/test_profiles.py | 10 +- tests/detect/test_replay.py | 67 +++++ ui/detection-app/src/App.vue | 233 +++++++++++++++--- .../src/panels/BrandTablePanel.vue | 11 +- .../src/panels/CostStatsPanel.vue | 36 ++- ui/detection-app/src/panels/TimelinePanel.vue | 134 ++++++++-- ui/detection-app/src/types/sse-contract.ts | 41 +++ ui/framework/src/components/Panel.vue | 3 +- ui/framework/src/components/ResizeHandle.vue | 70 ++++++ ui/framework/src/index.ts | 1 + ui/framework/src/renderers/TableRenderer.vue | 7 +- .../src/renderers/TimeSeriesRenderer.vue | 111 ++++++++- 43 files changed, 2627 insertions(+), 252 deletions(-) create mode 100644 core/api/detect_replay.py create mode 100644 core/db/detect.py create mode 100644 core/schema/models/detect_jobs.py create mode 100644 ctrl/k8s/base/minio.yaml create mode 100644 detect/checkpoint/__init__.py create mode 100644 detect/checkpoint/frames.py create mode 100644 detect/checkpoint/replay.py create mode 100644 detect/checkpoint/serializer.py create mode 100644 detect/checkpoint/storage.py create mode 100644 detect/checkpoint/tasks.py create mode 100644 tests/detect/manual/test_replay.py create mode 100644 tests/detect/test_checkpoint.py create mode 100644 tests/detect/test_replay.py create mode 100644 ui/framework/src/components/ResizeHandle.vue diff --git a/core/api/detect_replay.py b/core/api/detect_replay.py new file mode 100644 index 0000000..89f0cb8 --- /dev/null +++ b/core/api/detect_replay.py @@ -0,0 +1,121 @@ +""" +API endpoints for checkpoint inspection, replay, and retry. + +GET /detect/checkpoints/{job_id} — list available checkpoints +POST /detect/replay — replay from a stage with config overrides +POST /detect/retry — queue async retry with different provider +""" + +from __future__ import annotations + +import logging + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/detect", tags=["detect"]) + + +# --- Request/Response models --- + +class CheckpointInfo(BaseModel): + stage: str + + +class ReplayRequest(BaseModel): + job_id: str + start_stage: str + config_overrides: dict | None = None + + +class ReplayResponse(BaseModel): + status: str + job_id: str + start_stage: str + detections: int = 0 + brands_found: int = 0 + + +class RetryRequest(BaseModel): + job_id: str + config_overrides: dict | None = None + start_stage: str = "escalate_vlm" + schedule_seconds: float | None = None # delay before execution (off-peak) + + +class RetryResponse(BaseModel): + status: str + task_id: str + job_id: str + + +# --- Endpoints --- + +@router.get("/checkpoints/{job_id}") +def list_checkpoints(job_id: str) -> list[CheckpointInfo]: + """List available checkpoint stages for a job.""" + from detect.checkpoint import list_checkpoints as _list + + try: + stages = _list(job_id) + except Exception as e: + raise HTTPException(status_code=404, detail=f"No checkpoints for job {job_id}: {e}") + + result = [CheckpointInfo(stage=s) for s in stages] + return result + + +@router.post("/replay", response_model=ReplayResponse) +def replay(req: ReplayRequest): + """Replay pipeline from a specific stage with optional config overrides.""" + from detect.checkpoint import replay_from + + try: + result = replay_from( + job_id=req.job_id, + start_stage=req.start_stage, + config_overrides=req.config_overrides, + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Replay failed: {e}") + + detections = result.get("detections", []) + report = result.get("report") + brands_found = len(report.brands) if report else 0 + + response = ReplayResponse( + status="completed", + job_id=req.job_id, + start_stage=req.start_stage, + detections=len(detections), + brands_found=brands_found, + ) + return response + + +@router.post("/retry", response_model=RetryResponse) +def retry(req: RetryRequest): + """Queue an async retry of unresolved candidates with different config.""" + from detect.checkpoint.tasks import retry_candidates + + kwargs = { + "job_id": req.job_id, + "config_overrides": req.config_overrides, + "start_stage": req.start_stage, + } + + if req.schedule_seconds: + task = retry_candidates.apply_async(kwargs=kwargs, countdown=req.schedule_seconds) + else: + task = retry_candidates.delay(**kwargs) + + response = RetryResponse( + status="queued", + task_id=task.id, + job_id=req.job_id, + ) + return response diff --git a/core/api/main.py b/core/api/main.py index 4177145..9f02e55 100644 --- a/core/api/main.py +++ b/core/api/main.py @@ -25,6 +25,7 @@ from strawberry.fastapi import GraphQLRouter from core.api.chunker_sse import router as chunker_router from core.api.detect_sse import router as detect_router +from core.api.detect_replay import router as detect_replay_router from core.api.graphql import schema as graphql_schema CALLBACK_API_KEY = os.environ.get("CALLBACK_API_KEY", "") @@ -56,6 +57,9 @@ app.include_router(chunker_router) # Detection SSE app.include_router(detect_router) +# Detection replay/retry +app.include_router(detect_replay_router) + @app.get("/health") def health(): diff --git a/core/db/detect.py b/core/db/detect.py new file mode 100644 index 0000000..93872a1 --- /dev/null +++ b/core/db/detect.py @@ -0,0 +1,175 @@ +"""Database operations for DetectJob and StageCheckpoint.""" + +from typing import Optional +from uuid import UUID + + +# --------------------------------------------------------------------------- +# DetectJob +# --------------------------------------------------------------------------- + +def create_detect_job(**fields): + from admin.mpr.media_assets.models import DetectJob + return DetectJob.objects.create(**fields) + + +def get_detect_job(id: UUID): + from admin.mpr.media_assets.models import DetectJob + return DetectJob.objects.get(id=id) + + +def update_detect_job(job_id: UUID, **fields): + from admin.mpr.media_assets.models import DetectJob + DetectJob.objects.filter(id=job_id).update(**fields) + + +def list_detect_jobs( + parent_job_id: Optional[UUID] = None, + status: Optional[str] = None, +): + from admin.mpr.media_assets.models import DetectJob + + qs = DetectJob.objects.all() + if parent_job_id: + qs = qs.filter(parent_job_id=parent_job_id) + if status: + qs = qs.filter(status=status) + return list(qs) + + +# --------------------------------------------------------------------------- +# StageCheckpoint +# --------------------------------------------------------------------------- + +def save_stage_checkpoint(**fields): + from admin.mpr.media_assets.models import StageCheckpoint + return StageCheckpoint.objects.create(**fields) + + +def get_stage_checkpoint(job_id: UUID, stage: str): + from admin.mpr.media_assets.models import StageCheckpoint + return StageCheckpoint.objects.get(job_id=job_id, stage=stage) + + +def list_stage_checkpoints(job_id: UUID) -> list[str]: + from admin.mpr.media_assets.models import StageCheckpoint + + stages = ( + StageCheckpoint.objects + .filter(job_id=job_id) + .order_by("stage_index") + .values_list("stage", flat=True) + ) + return list(stages) + + +def delete_stage_checkpoints(job_id: UUID): + from admin.mpr.media_assets.models import StageCheckpoint + StageCheckpoint.objects.filter(job_id=job_id).delete() + + +# --------------------------------------------------------------------------- +# KnownBrand +# --------------------------------------------------------------------------- + +def get_or_create_brand(canonical_name: str, aliases: list[str] | None = None, + source: str = "ocr") -> tuple: + """Get existing brand or create new one. Returns (brand, created).""" + from admin.mpr.media_assets.models import KnownBrand + import uuid + + normalized = canonical_name.strip() + brand = KnownBrand.objects.filter(canonical_name__iexact=normalized).first() + if brand: + return brand, False + + # Check aliases of existing brands + for existing in KnownBrand.objects.all(): + existing_aliases = [a.lower() for a in (existing.aliases or [])] + if normalized.lower() in existing_aliases: + return existing, False + + brand = KnownBrand.objects.create( + id=uuid.uuid4(), + canonical_name=normalized, + aliases=aliases or [], + first_source=source, + ) + return brand, True + + +def find_brand_by_text(text: str) -> Optional[object]: + """Find a known brand by canonical name or alias (case-insensitive).""" + from admin.mpr.media_assets.models import KnownBrand + + normalized = text.strip().lower() + + # Exact canonical match + brand = KnownBrand.objects.filter(canonical_name__iexact=normalized).first() + if brand: + return brand + + # Search aliases (jsonb contains) + for brand in KnownBrand.objects.all(): + brand_aliases = [a.lower() for a in (brand.aliases or [])] + if normalized in brand_aliases: + return brand + + return None + + +def list_all_brands() -> list: + from admin.mpr.media_assets.models import KnownBrand + return list(KnownBrand.objects.all().order_by("canonical_name")) + + +def update_brand(brand_id: UUID, **fields): + from admin.mpr.media_assets.models import KnownBrand + KnownBrand.objects.filter(id=brand_id).update(**fields) + + +# --------------------------------------------------------------------------- +# SourceBrandSighting +# --------------------------------------------------------------------------- + +def get_source_sightings(source_asset_id: UUID) -> list: + """Get all brand sightings for a specific source video.""" + from admin.mpr.media_assets.models import SourceBrandSighting + return list( + SourceBrandSighting.objects + .filter(source_asset_id=source_asset_id) + .order_by("-occurrences") + ) + + +def record_sighting(source_asset_id: UUID, brand_id: UUID, brand_name: str, + timestamp: float, confidence: float, source: str = "ocr"): + """Record or update a brand sighting for a source.""" + from admin.mpr.media_assets.models import SourceBrandSighting + import uuid + + sighting = SourceBrandSighting.objects.filter( + source_asset_id=source_asset_id, + brand_id=brand_id, + ).first() + + if sighting: + sighting.occurrences += 1 + sighting.last_seen_timestamp = timestamp + total_conf = sighting.avg_confidence * (sighting.occurrences - 1) + confidence + sighting.avg_confidence = total_conf / sighting.occurrences + sighting.save() + return sighting + + sighting = SourceBrandSighting.objects.create( + id=uuid.uuid4(), + source_asset_id=source_asset_id, + brand_id=brand_id, + brand_name=brand_name, + first_seen_timestamp=timestamp, + last_seen_timestamp=timestamp, + occurrences=1, + detection_source=source, + avg_confidence=confidence, + ) + return sighting diff --git a/core/schema/models/__init__.py b/core/schema/models/__init__.py index 8b6266d..d66d450 100644 --- a/core/schema/models/__init__.py +++ b/core/schema/models/__init__.py @@ -26,13 +26,18 @@ from .grpc import ( WorkerStatus, ) from .jobs import ChunkJob, ChunkJobStatus, JobStatus, TranscodeJob +from .detect_jobs import ( + DetectJob, DetectJobStatus, RunType, StageCheckpoint, + BrandSource, KnownBrand, SourceBrandSighting, +) from .media import AssetStatus, MediaAsset from .presets import BUILTIN_PRESETS, TranscodePreset from .detect import DETECT_VIEWS # noqa: F401 — discovered by modelgen generic loader from .views import ChunkEvent, ChunkOutputFile, PipelineStats, WorkerEvent # Core domain models - generates Django, Pydantic, TypeScript -DATACLASSES = [MediaAsset, TranscodePreset, TranscodeJob, ChunkJob] +DATACLASSES = [MediaAsset, TranscodePreset, TranscodeJob, ChunkJob, + DetectJob, StageCheckpoint, KnownBrand, SourceBrandSighting] # API request/response models - generates TypeScript only (no Django) # WorkerStatus from grpc.py is reused here @@ -46,7 +51,7 @@ API_MODELS = [ ] # Status enums - included in generated code -ENUMS = [AssetStatus, JobStatus, ChunkJobStatus] +ENUMS = [AssetStatus, JobStatus, ChunkJobStatus, DetectJobStatus, RunType, BrandSource] # View/event models - generates TypeScript for UI consumption VIEWS = [ChunkEvent, WorkerEvent, PipelineStats, ChunkOutputFile] diff --git a/core/schema/models/detect.py b/core/schema/models/detect.py index 4e373c6..33c5325 100644 --- a/core/schema/models/detect.py +++ b/core/schema/models/detect.py @@ -149,6 +149,64 @@ class JobComplete: report: Optional[DetectionReportSummary] = None +@dataclass +class RunContext: + """Run context injected into all SSE events for grouping.""" + + run_id: str + parent_job_id: str + run_type: str = "initial" # initial | replay | retry + + +# --- Checkpoint API types --- + + +@dataclass +class CheckpointInfo: + """Available checkpoint for a stage.""" + + stage: str + + +@dataclass +class ReplayRequest: + """Request to replay pipeline from a specific stage.""" + + job_id: str + start_stage: str + config_overrides: Optional[dict] = None + + +@dataclass +class ReplayResponse: + """Result of a replay invocation.""" + + status: str + job_id: str + start_stage: str + detections: int = 0 + brands_found: int = 0 + + +@dataclass +class RetryRequest: + """Request to queue async retry with different config.""" + + job_id: str + config_overrides: Optional[dict] = None + start_stage: str = "escalate_vlm" + schedule_seconds: Optional[float] = None + + +@dataclass +class RetryResponse: + """Result of queueing a retry task.""" + + status: str + task_id: str + job_id: str + + # --- Export lists for modelgen --- DETECT_VIEWS = [ @@ -163,4 +221,10 @@ DETECT_VIEWS = [ LogEvent, DetectionReportSummary, JobComplete, + RunContext, + CheckpointInfo, + ReplayRequest, + ReplayResponse, + RetryRequest, + RetryResponse, ] diff --git a/core/schema/models/detect_jobs.py b/core/schema/models/detect_jobs.py new file mode 100644 index 0000000..72a8232 --- /dev/null +++ b/core/schema/models/detect_jobs.py @@ -0,0 +1,162 @@ +""" +Detection Job and Checkpoint Schema Definitions + +Source of truth for detection pipeline job tracking and stage checkpoints. +Follows the TranscodeJob/ChunkJob pattern. +""" + +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional +from uuid import UUID + + +class DetectJobStatus(str, Enum): + PENDING = "pending" + RUNNING = "running" + PAUSED = "paused" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +class RunType(str, Enum): + INITIAL = "initial" + REPLAY = "replay" + RETRY = "retry" + + +@dataclass +class DetectJob: + """ + A detection pipeline job. + + Each invocation of the pipeline (initial run, replay, retry) creates a DetectJob. + Jobs for the same source video are linked via parent_job_id. + """ + + id: UUID + + # Input + source_asset_id: UUID + video_path: str + profile_name: str = "soccer_broadcast" + + # Run lineage + parent_job_id: Optional[UUID] = None # links all runs for the same source + run_type: RunType = RunType.INITIAL + replay_from_stage: Optional[str] = None # null for initial runs + config_overrides: Dict[str, Any] = field(default_factory=dict) + + # Status + status: DetectJobStatus = DetectJobStatus.PENDING + current_stage: Optional[str] = None + progress: float = 0.0 + error_message: Optional[str] = None + + # Results summary + total_detections: int = 0 + brands_found: int = 0 + cloud_llm_calls: int = 0 + estimated_cost_usd: float = 0.0 + + # Worker tracking + celery_task_id: Optional[str] = None + priority: int = 0 + + # Timestamps + created_at: Optional[datetime] = None + started_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + + +@dataclass +class StageCheckpoint: + """ + A checkpoint saved after a pipeline stage completes. + + Binary data (frame images, crops) goes to S3/MinIO. + Everything else (structured state) lives here in Postgres. + """ + + id: UUID + job_id: UUID + stage: str + stage_index: int # position in NODES list (0-7) + + # S3 reference for binary data only + frames_prefix: str = "" # s3 prefix: checkpoints/{job_id}/frames/ + + # Frame metadata (non-image fields) + frames_manifest: Dict[int, str] = field(default_factory=dict) # seq → s3 key + frames_meta: List[Dict[str, Any]] = field(default_factory=list) # sequence, chunk_id, timestamp, hash + filtered_frame_sequences: List[int] = field(default_factory=list) + + # Detection state (full structured data, not just summaries) + boxes_by_frame: Dict[str, List[Dict[str, Any]]] = field(default_factory=dict) + text_candidates: List[Dict[str, Any]] = field(default_factory=list) + unresolved_candidates: List[Dict[str, Any]] = field(default_factory=list) + detections: List[Dict[str, Any]] = field(default_factory=list) + + # Pipeline state + stats: Dict[str, Any] = field(default_factory=dict) + config_snapshot: Dict[str, Any] = field(default_factory=dict) + config_overrides: Dict[str, Any] = field(default_factory=dict) + + # Input refs (for replay) + video_path: str = "" + profile_name: str = "" + + # Timestamps + created_at: Optional[datetime] = None + + +class BrandSource(str, Enum): + """How a brand was first identified.""" + OCR = "ocr" + VLM = "local_vlm" + CLOUD = "cloud_llm" + MANUAL = "manual" # user-added via UI + + +@dataclass +class KnownBrand: + """ + A brand discovered or registered in the system. + + Global — not per-source. Accumulates across all pipeline runs. + Aliases enable fuzzy matching without re-escalating to VLM. + """ + + id: UUID + canonical_name: str # normalized display name + aliases: List[str] = field(default_factory=list) # known spellings/variants + first_source: BrandSource = BrandSource.OCR + total_occurrences: int = 0 + confirmed: bool = False # manually confirmed by user + + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + + +@dataclass +class SourceBrandSighting: + """ + A brand seen in a specific source (video/asset). + + Per-source session cache — avoids re-escalating the same brand + on subsequent frames or re-runs of the same source. + """ + + id: UUID + source_asset_id: UUID # the video this sighting belongs to + brand_id: UUID # FK to KnownBrand + brand_name: str # denormalized for fast lookup + first_seen_timestamp: float = 0.0 + last_seen_timestamp: float = 0.0 + occurrences: int = 0 + detection_source: BrandSource = BrandSource.OCR + avg_confidence: float = 0.0 + + created_at: Optional[datetime] = None diff --git a/ctrl/.env.template b/ctrl/.env.template index 9066bef..8492437 100644 --- a/ctrl/.env.template +++ b/ctrl/.env.template @@ -28,7 +28,10 @@ GRPC_PORT=50051 GRPC_MAX_WORKERS=10 # S3 Storage (MinIO locally, real S3 on AWS) -S3_ENDPOINT_URL=http://minio:9000 +# In k8s/docker: http://minio:9000 +# On dev machine (port-forward): http://localhost:9000 +# On AWS: omit S3_ENDPOINT_URL entirely +S3_ENDPOINT_URL=http://localhost:9000 S3_BUCKET_IN=mpr-media-in S3_BUCKET_OUT=mpr-media-out AWS_REGION=us-east-1 @@ -44,7 +47,7 @@ CLOUD_LLM_PROVIDER=groq # Groq (default, free tier) GROQ_API_KEY= -GROQ_MODEL=llama-3.2-90b-vision-preview +GROQ_MODEL=meta-llama/llama-4-scout-17b-16e-instruct # Gemini #GEMINI_API_KEY= diff --git a/ctrl/Tiltfile b/ctrl/Tiltfile index aa2e6e3..9aef23d 100644 --- a/ctrl/Tiltfile +++ b/ctrl/Tiltfile @@ -35,7 +35,15 @@ docker_build( # --- Resources --- k8s_resource('redis') -k8s_resource('fastapi', resource_deps=['redis']) +k8s_resource('minio', port_forwards=['9000:9000', '9001:9001']) +k8s_resource('fastapi', resource_deps=['redis', 'minio']) k8s_resource('detection-ui') k8s_resource('gateway', resource_deps=['fastapi', 'detection-ui'], port_forwards=['8080:8080']) + +# Group uncategorized resources (configmaps, namespace) under infra +k8s_resource( + objects=['mpr:namespace', 'mpr-config:configmap', 'minio-config:configmap', + 'envoy-gateway-config:configmap'], + new_name='infra', +) diff --git a/ctrl/k8s/base/fastapi.yaml b/ctrl/k8s/base/fastapi.yaml index 90d1833..ad42020 100644 --- a/ctrl/k8s/base/fastapi.yaml +++ b/ctrl/k8s/base/fastapi.yaml @@ -22,6 +22,8 @@ spec: envFrom: - configMapRef: name: mpr-config + - configMapRef: + name: minio-config readinessProbe: httpGet: path: /health diff --git a/ctrl/k8s/base/kustomization.yaml b/ctrl/k8s/base/kustomization.yaml index b9f453b..982abb3 100644 --- a/ctrl/k8s/base/kustomization.yaml +++ b/ctrl/k8s/base/kustomization.yaml @@ -7,6 +7,7 @@ resources: - namespace.yaml - configmap.yaml - redis.yaml + - minio.yaml - fastapi.yaml - detection-ui.yaml - gateway.yaml diff --git a/ctrl/k8s/base/minio.yaml b/ctrl/k8s/base/minio.yaml new file mode 100644 index 0000000..157ac85 --- /dev/null +++ b/ctrl/k8s/base/minio.yaml @@ -0,0 +1,87 @@ +apiVersion: v1 +kind: ConfigMap +metadata: + name: minio-config + namespace: mpr +data: + S3_ENDPOINT_URL: http://minio:9000 + S3_BUCKET_IN: mpr-media-in + S3_BUCKET_OUT: mpr-media-out + AWS_ACCESS_KEY_ID: minioadmin + AWS_SECRET_ACCESS_KEY: minioadmin + AWS_REGION: us-east-1 +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: minio + namespace: mpr +spec: + replicas: 1 + selector: + matchLabels: + app: minio + template: + metadata: + labels: + app: minio + spec: + containers: + - name: minio + image: minio/minio:latest + args: ["server", "/data", "--console-address", ":9001"] + ports: + - containerPort: 9000 + name: api + - containerPort: 9001 + name: console + env: + - name: MINIO_ROOT_USER + value: minioadmin + - name: MINIO_ROOT_PASSWORD + value: minioadmin + readinessProbe: + httpGet: + path: /minio/health/ready + port: 9000 + initialDelaySeconds: 5 + periodSeconds: 10 + lifecycle: + postStart: + exec: + command: + - /bin/sh + - -c + - | + sleep 3 + for bucket in mpr-media-in mpr-media-out; do + mkdir -p /data/$bucket + done + volumeMounts: + - name: data + mountPath: /data + resources: + requests: + memory: 128Mi + cpu: 100m + limits: + memory: 512Mi + volumes: + - name: data + emptyDir: {} +--- +apiVersion: v1 +kind: Service +metadata: + name: minio + namespace: mpr +spec: + selector: + app: minio + ports: + - port: 9000 + targetPort: 9000 + name: api + - port: 9001 + targetPort: 9001 + name: console diff --git a/detect/checkpoint/__init__.py b/detect/checkpoint/__init__.py new file mode 100644 index 0000000..062b5bd --- /dev/null +++ b/detect/checkpoint/__init__.py @@ -0,0 +1,14 @@ +""" +Stage checkpoint, replay, and retry. + + detect/checkpoint/ + frames.py — frame image S3 upload/download + serializer.py — state ↔ JSON conversion + storage.py — checkpoint save/load/list (Postgres + S3) + replay.py — replay_from, OverrideProfile + tasks.py — retry_candidates Celery task +""" + +from .storage import save_checkpoint, load_checkpoint, list_checkpoints +from .frames import save_frames, load_frames +from .replay import replay_from, OverrideProfile diff --git a/detect/checkpoint/frames.py b/detect/checkpoint/frames.py new file mode 100644 index 0000000..08e6425 --- /dev/null +++ b/detect/checkpoint/frames.py @@ -0,0 +1,80 @@ +"""Frame image storage — save/load to S3/MinIO as JPEGs.""" + +from __future__ import annotations + +import logging +import os +import tempfile + +import numpy as np +from PIL import Image + +from detect.models import Frame + +logger = logging.getLogger(__name__) + +BUCKET = os.environ.get("S3_BUCKET_OUT", "mpr-media-out") +CHECKPOINT_PREFIX = "checkpoints" + + +def save_frames(job_id: str, frames: list[Frame]) -> dict[int, str]: + """ + Save frame images to S3 as JPEGs. + + Returns manifest: {sequence: s3_key} + """ + from core.storage.s3 import upload_file + + manifest = {} + + for frame in frames: + key = f"{CHECKPOINT_PREFIX}/{job_id}/frames/{frame.sequence}.jpg" + + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp: + img = Image.fromarray(frame.image) + img.save(tmp, format="JPEG", quality=85) + tmp_path = tmp.name + + try: + upload_file(tmp_path, BUCKET, key) + finally: + os.unlink(tmp_path) + + manifest[frame.sequence] = key + + logger.info("Saved %d frames to s3://%s/%s/%s/frames/", + len(frames), BUCKET, CHECKPOINT_PREFIX, job_id) + return manifest + + +def load_frames(manifest: dict[int, str], frame_metadata: list[dict]) -> list[Frame]: + """ + Load frame images from S3 and reconstitute Frame objects. + + frame_metadata: list of dicts with sequence, chunk_id, timestamp, perceptual_hash. + """ + from core.storage.s3 import download_to_temp + + meta_map = {m["sequence"]: m for m in frame_metadata} + frames = [] + + for seq, key in manifest.items(): + tmp_path = download_to_temp(BUCKET, key) + try: + img = Image.open(tmp_path).convert("RGB") + image_array = np.array(img) + finally: + os.unlink(tmp_path) + + meta = meta_map.get(seq, {}) + frame = Frame( + sequence=seq, + chunk_id=meta.get("chunk_id", 0), + timestamp=meta.get("timestamp", 0.0), + image=image_array, + perceptual_hash=meta.get("perceptual_hash", ""), + ) + frames.append(frame) + + frames.sort(key=lambda f: f.sequence) + return frames diff --git a/detect/checkpoint/replay.py b/detect/checkpoint/replay.py new file mode 100644 index 0000000..c40ab04 --- /dev/null +++ b/detect/checkpoint/replay.py @@ -0,0 +1,132 @@ +""" +Pipeline replay — re-run from any stage with different config. + +Loads a checkpoint, applies config overrides, builds a subgraph +starting from the target stage, and invokes it. +""" + +from __future__ import annotations + +import logging + +import uuid + +from detect import emit +from detect.checkpoint import load_checkpoint, list_checkpoints +from detect.graph import NODES, build_graph + +logger = logging.getLogger(__name__) + + +class OverrideProfile: + """ + Wraps a ContentTypeProfile and patches config methods with overrides. + + Override dict structure: + { + "frame_extraction": {"fps": 1.0}, + "scene_filter": {"hamming_threshold": 12}, + "detection": {"confidence_threshold": 0.5}, + "ocr": {"languages": ["en", "es"], "min_confidence": 0.3}, + "resolver": {"fuzzy_threshold": 60}, + } + """ + + def __init__(self, base, overrides: dict): + self._base = base + self._overrides = overrides + + def __getattr__(self, name): + return getattr(self._base, name) + + def _patch(self, config, key: str): + patches = self._overrides.get(key, {}) + for k, v in patches.items(): + if hasattr(config, k): + setattr(config, k, v) + return config + + def frame_extraction_config(self): + return self._patch(self._base.frame_extraction_config(), "frame_extraction") + + def scene_filter_config(self): + return self._patch(self._base.scene_filter_config(), "scene_filter") + + def detection_config(self): + return self._patch(self._base.detection_config(), "detection") + + def ocr_config(self): + return self._patch(self._base.ocr_config(), "ocr") + + def resolver_config(self): + return self._patch(self._base.resolver_config(), "resolver") + + def vlm_prompt(self, crop_context): + return self._base.vlm_prompt(crop_context) + + def aggregate(self, detections): + return self._base.aggregate(detections) + + def auxiliary_detections(self, source): + return self._base.auxiliary_detections(source) + + +def replay_from( + job_id: str, + start_stage: str, + config_overrides: dict | None = None, + checkpoint: bool = True, +) -> dict: + """ + Replay the pipeline from a specific stage. + + Loads the checkpoint from the stage immediately before start_stage, + applies config overrides, and runs the subgraph from start_stage onward. + + Returns the final state dict. + """ + if start_stage not in NODES: + raise ValueError(f"Unknown stage: {start_stage!r}. Options: {NODES}") + + start_idx = NODES.index(start_stage) + + # Load checkpoint from the stage before start_stage + if start_idx == 0: + raise ValueError("Cannot replay from the first stage — just run the full pipeline") + + previous_stage = NODES[start_idx - 1] + + available = list_checkpoints(job_id) + if previous_stage not in available: + raise ValueError( + f"No checkpoint for stage {previous_stage!r} (job {job_id}). " + f"Available: {available}" + ) + + logger.info("Replaying job %s from %s (loading checkpoint: %s)", + job_id, start_stage, previous_stage) + + state = load_checkpoint(job_id, previous_stage) + + # Apply config overrides + if config_overrides: + state["config_overrides"] = config_overrides + + # Set run context for SSE events + run_id = str(uuid.uuid4())[:8] + emit.set_run_context( + run_id=run_id, + parent_job_id=job_id, + run_type="replay", + ) + + # Build subgraph starting from start_stage + graph = build_graph(checkpoint=checkpoint, start_from=start_stage) + pipeline = graph.compile() + + try: + result = pipeline.invoke(state) + finally: + emit.clear_run_context() + + return result diff --git a/detect/checkpoint/serializer.py b/detect/checkpoint/serializer.py new file mode 100644 index 0000000..35a9874 --- /dev/null +++ b/detect/checkpoint/serializer.py @@ -0,0 +1,133 @@ +"""State serialization — DetectState ↔ JSON-compatible dict.""" + +from __future__ import annotations + +import dataclasses + +from detect.models import ( + BoundingBox, + BrandDetection, + Frame, + PipelineStats, + TextCandidate, +) + + +# --------------------------------------------------------------------------- +# Serialize helpers +# --------------------------------------------------------------------------- + +def serialize_frame_meta(frame: Frame) -> dict: + meta = { + "sequence": frame.sequence, + "chunk_id": frame.chunk_id, + "timestamp": frame.timestamp, + "perceptual_hash": frame.perceptual_hash, + } + return meta + + +def serialize_text_candidate(tc: TextCandidate) -> dict: + bbox_dict = dataclasses.asdict(tc.bbox) + candidate = { + "frame_sequence": tc.frame.sequence, + "bbox": bbox_dict, + "text": tc.text, + "ocr_confidence": tc.ocr_confidence, + } + return candidate + + +def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict: + """ + Serialize DetectState to a JSON-compatible dict. + + Frame images are replaced with S3 key references. + TextCandidate.frame references become frame_sequence integers. + """ + frames = state.get("frames", []) + filtered = state.get("filtered_frames", []) + + manifest_strs = {str(k): v for k, v in frames_manifest.items()} + frames_meta = [serialize_frame_meta(f) for f in frames] + filtered_seqs = [f.sequence for f in filtered] + + boxes_serialized = {} + for seq, boxes in state.get("boxes_by_frame", {}).items(): + boxes_serialized[str(seq)] = [dataclasses.asdict(b) for b in boxes] + + text_candidates = [serialize_text_candidate(tc) for tc in state.get("text_candidates", [])] + unresolved = [serialize_text_candidate(tc) for tc in state.get("unresolved_candidates", [])] + detections = [dataclasses.asdict(d) for d in state.get("detections", [])] + stats = dataclasses.asdict(state.get("stats", PipelineStats())) + + checkpoint = { + "job_id": state.get("job_id", ""), + "video_path": state.get("video_path", ""), + "profile_name": state.get("profile_name", ""), + "config_overrides": state.get("config_overrides", {}), + "frames_manifest": manifest_strs, + "frames_meta": frames_meta, + "filtered_frame_sequences": filtered_seqs, + "boxes_by_frame": boxes_serialized, + "text_candidates": text_candidates, + "unresolved_candidates": unresolved, + "detections": detections, + "stats": stats, + } + return checkpoint + + +# --------------------------------------------------------------------------- +# Deserialize helpers +# --------------------------------------------------------------------------- + +def deserialize_text_candidate(d: dict, frame_map: dict[int, Frame]) -> TextCandidate: + frame = frame_map[d["frame_sequence"]] + bbox = BoundingBox(**d["bbox"]) + candidate = TextCandidate( + frame=frame, + bbox=bbox, + text=d["text"], + ocr_confidence=d["ocr_confidence"], + ) + return candidate + + +def deserialize_state(checkpoint: dict, frames: list[Frame]) -> dict: + """Reconstitute DetectState from a checkpoint dict + loaded frames.""" + frame_map = {f.sequence: f for f in frames} + + filtered_seqs = set(checkpoint.get("filtered_frame_sequences", [])) + filtered_frames = [f for f in frames if f.sequence in filtered_seqs] + + boxes_by_frame = {} + for seq_str, box_dicts in checkpoint.get("boxes_by_frame", {}).items(): + seq = int(seq_str) + boxes_by_frame[seq] = [BoundingBox(**b) for b in box_dicts] + + text_candidates = [ + deserialize_text_candidate(d, frame_map) + for d in checkpoint.get("text_candidates", []) + ] + unresolved_candidates = [ + deserialize_text_candidate(d, frame_map) + for d in checkpoint.get("unresolved_candidates", []) + ] + detections = [BrandDetection(**d) for d in checkpoint.get("detections", [])] + stats = PipelineStats(**checkpoint.get("stats", {})) + + state = { + "job_id": checkpoint.get("job_id", ""), + "video_path": checkpoint.get("video_path", ""), + "profile_name": checkpoint.get("profile_name", ""), + "config_overrides": checkpoint.get("config_overrides", {}), + "frames": frames, + "filtered_frames": filtered_frames, + "boxes_by_frame": boxes_by_frame, + "text_candidates": text_candidates, + "unresolved_candidates": unresolved_candidates, + "detections": detections, + "stats": stats, + } + return state diff --git a/detect/checkpoint/storage.py b/detect/checkpoint/storage.py new file mode 100644 index 0000000..155cede --- /dev/null +++ b/detect/checkpoint/storage.py @@ -0,0 +1,215 @@ +""" +Checkpoint storage — save/load stage state. + +Binary data (frame images) → S3/MinIO via frames.py +Structured data (boxes, detections, stats, config) → Postgres via Django ORM + +Until the Django model is generated by modelgen, checkpoint data is stored +as JSON in S3 as a fallback. Once DetectJob/StageCheckpoint models exist, +this module switches to Postgres. +""" + +from __future__ import annotations + +import json +import logging +import os +import tempfile +from pathlib import Path + +from .frames import save_frames, load_frames, BUCKET, CHECKPOINT_PREFIX +from .serializer import serialize_state, deserialize_state + +logger = logging.getLogger(__name__) + + +def _has_db() -> bool: + """Check if the DB layer is available (Django + models generated by modelgen).""" + try: + from core.db.detect import get_stage_checkpoint as _ + # Quick check that the model exists (modelgen may not have run yet) + from admin.mpr.media_assets.models import StageCheckpoint as _ + return True + except (ImportError, Exception): + return False + + +# --------------------------------------------------------------------------- +# Save +# --------------------------------------------------------------------------- + +def save_checkpoint( + job_id: str, + stage: str, + stage_index: int, + state: dict, + frames_manifest: dict[int, str] | None = None, +) -> str: + """ + Save a stage checkpoint. + + Saves frame images to S3 (if not already saved), then persists + structured state to Postgres (or S3 JSON fallback). + + Returns the checkpoint identifier (DB id or S3 key). + """ + # Save frames to S3 if no manifest provided + if frames_manifest is None: + all_frames = state.get("frames", []) + frames_manifest = save_frames(job_id, all_frames) + + checkpoint_data = serialize_state(state, frames_manifest) + + if _has_db(): + checkpoint_id = _save_to_db(job_id, stage, stage_index, checkpoint_data) + else: + checkpoint_id = _save_to_s3(job_id, stage, checkpoint_data) + + return checkpoint_id + + +def _save_to_db(job_id: str, stage: str, stage_index: int, data: dict) -> str: + """Save checkpoint structured data to Postgres via core/db.""" + import uuid + from core.db.detect import save_stage_checkpoint + + job_uuid = uuid.UUID(job_id) if isinstance(job_id, str) else job_id + checkpoint_id = uuid.uuid4() + frames_prefix = f"{CHECKPOINT_PREFIX}/{job_id}/frames/" + + checkpoint = save_stage_checkpoint( + id=checkpoint_id, + job_id=job_uuid, + stage=stage, + stage_index=stage_index, + frames_prefix=frames_prefix, + frames_manifest=data.get("frames_manifest", {}), + frames_meta=data.get("frames_meta", []), + filtered_frame_sequences=data.get("filtered_frame_sequences", []), + boxes_by_frame=data.get("boxes_by_frame", {}), + text_candidates=data.get("text_candidates", []), + unresolved_candidates=data.get("unresolved_candidates", []), + detections=data.get("detections", []), + stats=data.get("stats", {}), + config_snapshot=data.get("config_overrides", {}), + config_overrides=data.get("config_overrides", {}), + video_path=data.get("video_path", ""), + profile_name=data.get("profile_name", ""), + ) + + logger.info("Checkpoint saved to DB: %s/%s (id=%s)", job_id, stage, checkpoint.id) + return str(checkpoint.id) + + +def _save_to_s3(job_id: str, stage: str, data: dict) -> str: + """Fallback: save checkpoint as JSON to S3 (before modelgen generates DB models).""" + from core.storage.s3 import upload_file + + checkpoint_json = json.dumps(data, default=str) + key = f"{CHECKPOINT_PREFIX}/{job_id}/stages/{stage}.json" + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp: + tmp.write(checkpoint_json) + tmp_path = tmp.name + + try: + upload_file(tmp_path, BUCKET, key) + finally: + os.unlink(tmp_path) + + logger.info("Checkpoint saved to S3: s3://%s/%s", BUCKET, key) + return key + + +# --------------------------------------------------------------------------- +# Load +# --------------------------------------------------------------------------- + +def load_checkpoint(job_id: str, stage: str) -> dict: + """ + Load a stage checkpoint and reconstitute full DetectState. + + Tries Postgres first, falls back to S3 JSON. + """ + if _has_db(): + data = _load_from_db(job_id, stage) + else: + data = _load_from_s3(job_id, stage) + + raw_manifest = data.get("frames_manifest", {}) + manifest = {int(k): v for k, v in raw_manifest.items()} + frame_metadata = data.get("frames_meta", []) + frames = load_frames(manifest, frame_metadata) + + state = deserialize_state(data, frames) + + logger.info("Checkpoint loaded: %s/%s (%d frames)", job_id, stage, len(frames)) + return state + + +def _load_from_db(job_id: str, stage: str) -> dict: + """Load checkpoint data from Postgres via core/db.""" + from core.db.detect import get_stage_checkpoint + + checkpoint = get_stage_checkpoint(job_id, stage) + + data = { + "job_id": str(checkpoint.job_id), + "video_path": checkpoint.video_path, + "profile_name": checkpoint.profile_name, + "config_overrides": checkpoint.config_overrides, + "frames_manifest": checkpoint.frames_manifest, + "frames_meta": checkpoint.frames_meta, + "filtered_frame_sequences": checkpoint.filtered_frame_sequences, + "boxes_by_frame": checkpoint.boxes_by_frame, + "text_candidates": checkpoint.text_candidates, + "unresolved_candidates": checkpoint.unresolved_candidates, + "detections": checkpoint.detections, + "stats": checkpoint.stats, + } + return data + + +def _load_from_s3(job_id: str, stage: str) -> dict: + """Fallback: load checkpoint JSON from S3.""" + from core.storage.s3 import download_to_temp + + key = f"{CHECKPOINT_PREFIX}/{job_id}/stages/{stage}.json" + tmp_path = download_to_temp(BUCKET, key) + try: + with open(tmp_path) as f: + data = json.load(f) + finally: + os.unlink(tmp_path) + + return data + + +# --------------------------------------------------------------------------- +# List +# --------------------------------------------------------------------------- + +def list_checkpoints(job_id: str) -> list[str]: + """List available checkpoint stages for a job.""" + if _has_db(): + return _list_from_db(job_id) + return _list_from_s3(job_id) + + +def _list_from_db(job_id: str) -> list[str]: + from core.db.detect import list_stage_checkpoints + return list_stage_checkpoints(job_id) + + +def _list_from_s3(job_id: str) -> list[str]: + from core.storage.s3 import list_objects + + prefix = f"{CHECKPOINT_PREFIX}/{job_id}/stages/" + objects = list_objects(BUCKET, prefix) + + stages = [] + for obj in objects: + name = Path(obj["key"]).stem + stages.append(name) + + return stages diff --git a/detect/checkpoint/tasks.py b/detect/checkpoint/tasks.py new file mode 100644 index 0000000..ea9e6a6 --- /dev/null +++ b/detect/checkpoint/tasks.py @@ -0,0 +1,71 @@ +""" +Celery tasks for detection pipeline async operations. + +retry_candidates: re-run VLM/cloud escalation with different config. +""" + +from __future__ import annotations + +import logging +import uuid +from datetime import datetime, timezone + +from celery import shared_task + +logger = logging.getLogger(__name__) + + +@shared_task(bind=True, max_retries=1, default_retry_delay=30) +def retry_candidates( + self, + job_id: str, + config_overrides: dict | None = None, + start_stage: str = "escalate_vlm", +): + """ + Retry unresolved candidates with different config. + + Loads the checkpoint from the stage before start_stage, + applies config overrides (e.g. different cloud provider), + and runs from start_stage onward. + """ + from detect.checkpoint.replay import replay_from + + run_id = str(uuid.uuid4())[:8] + logger.info("Retry task %s: job=%s, from=%s, overrides=%s", + run_id, job_id, start_stage, config_overrides) + + try: + result = replay_from( + job_id=job_id, + start_stage=start_stage, + config_overrides=config_overrides, + ) + + detections = result.get("detections", []) + report = result.get("report") + brands_found = len(report.brands) if report else 0 + + logger.info("Retry %s complete: %d detections, %d brands", + run_id, len(detections), brands_found) + + return { + "status": "completed", + "run_id": run_id, + "job_id": job_id, + "detections": len(detections), + "brands_found": brands_found, + } + + except Exception as e: + logger.exception("Retry %s failed: %s", run_id, e) + + if self.request.retries < self.max_retries: + raise self.retry(exc=e) + + return { + "status": "failed", + "run_id": run_id, + "job_id": job_id, + "error": str(e), + } diff --git a/detect/emit.py b/detect/emit.py index 0262a38..8008085 100644 --- a/detect/emit.py +++ b/detect/emit.py @@ -3,6 +3,9 @@ Event emission helpers for detection pipeline stages. Single place that knows how to build event payloads. Stages call these instead of constructing dicts or dataclasses directly. + +Run context (run_id, parent_job_id) is set once at pipeline start via +set_run_context() and automatically injected into all events. """ from __future__ import annotations @@ -13,9 +16,33 @@ from datetime import datetime, timezone from detect.events import push_detect_event from detect.models import PipelineStats +# Module-level run context — set once per pipeline invocation +_run_context: dict = {} + + +def set_run_context(run_id: str = "", parent_job_id: str = "", run_type: str = "initial"): + """Set the run context for all subsequent events in this pipeline invocation.""" + global _run_context + _run_context = { + "run_id": run_id, + "parent_job_id": parent_job_id, + "run_type": run_type, + } + + +def clear_run_context(): + global _run_context + _run_context = {} + + +def _inject_context(payload: dict) -> dict: + """Add run context fields to an event payload.""" + if _run_context: + payload.update(_run_context) + return payload + def log(job_id: str | None, stage: str, level: str, msg: str) -> None: - """Emit a log event.""" if not job_id: return payload = { @@ -24,15 +51,17 @@ def log(job_id: str | None, stage: str, level: str, msg: str) -> None: "msg": msg, "ts": datetime.now(timezone.utc).isoformat(), } + _inject_context(payload) push_detect_event(job_id, "log", payload) def stats(job_id: str | None, **kwargs) -> None: - """Emit a stats_update event. Pass only the fields that changed.""" if not job_id: return s = PipelineStats(**kwargs) - push_detect_event(job_id, "stats_update", dataclasses.asdict(s)) + payload = dataclasses.asdict(s) + _inject_context(payload) + push_detect_event(job_id, "stats_update", payload) def frame_update( @@ -42,7 +71,6 @@ def frame_update( jpeg_b64: str, boxes: list[dict], ) -> None: - """Emit a frame_update event with the image and bounding boxes.""" if not job_id: return payload = { @@ -51,14 +79,15 @@ def frame_update( "jpeg_b64": jpeg_b64, "boxes": boxes, } + _inject_context(payload) push_detect_event(job_id, "frame_update", payload) def graph_update(job_id: str | None, nodes: list[dict]) -> None: - """Emit a graph_update event with node states.""" if not job_id: return payload = {"nodes": nodes} + _inject_context(payload) push_detect_event(job_id, "graph_update", payload) @@ -72,7 +101,6 @@ def detection( content_type: str = "", frame_ref: int | None = None, ) -> None: - """Emit a brand detection event.""" if not job_id: return payload = { @@ -84,12 +112,13 @@ def detection( "content_type": content_type, "frame_ref": frame_ref, } + _inject_context(payload) push_detect_event(job_id, "detection", payload) def job_complete(job_id: str | None, report: dict) -> None: - """Emit a job_complete event with the final report.""" if not job_id: return payload = {"job_id": job_id, "report": report} + _inject_context(payload) push_detect_event(job_id, "job_complete", payload) diff --git a/detect/graph.py b/detect/graph.py index 62c5647..b945f3b 100644 --- a/detect/graph.py +++ b/detect/graph.py @@ -42,8 +42,16 @@ NODES = [ def _get_profile(state: DetectState): name = state.get("profile_name", "soccer_broadcast") if name == "soccer_broadcast": - return SoccerBroadcastProfile() - raise ValueError(f"Unknown profile: {name}") + profile = SoccerBroadcastProfile() + else: + raise ValueError(f"Unknown profile: {name}") + + overrides = state.get("config_overrides") + if overrides: + from detect.checkpoint.replay import OverrideProfile + profile = OverrideProfile(profile, overrides) + + return profile # Track node states across the pipeline run @@ -68,6 +76,18 @@ def _emit_transition(state: DetectState, node: str, status: str): # --- Node functions --- def node_extract_frames(state: DetectState) -> dict: + # Set run context for initial runs (replays set it in replay_from) + job_id = state.get("job_id", "") + if job_id and not emit._run_context: + emit.set_run_context(run_id=job_id, parent_job_id=job_id, run_type="initial") + + # Load session brands from DB for this source + source_asset_id = state.get("source_asset_id") + if source_asset_id and not state.get("session_brands"): + from detect.stages.brand_resolver import build_session_dict + session_brands = build_session_dict(source_asset_id) + state["session_brands"] = session_brands + _emit_transition(state, "extract_frames", "running") with trace_node(state, "extract_frames") as span: @@ -142,13 +162,16 @@ def node_match_brands(state: DetectState) -> dict: with trace_node(state, "match_brands") as span: profile = _get_profile(state) - dictionary = profile.brand_dictionary() resolver_config = profile.resolver_config() candidates = state.get("text_candidates", []) + session_brands = state.get("session_brands", {}) job_id = state.get("job_id") + source_asset_id = state.get("source_asset_id") matched, unresolved = resolve_brands( - candidates, dictionary, resolver_config, + candidates, resolver_config, + session_brands=session_brands, + source_asset_id=source_asset_id, content_type=profile.name, job_id=job_id, ) span.set_output({"matched": len(matched), "unresolved": len(unresolved)}) @@ -170,6 +193,7 @@ def node_escalate_vlm(state: DetectState) -> dict: vlm_prompt_fn=profile.vlm_prompt, inference_url=INFERENCE_URL, content_type=profile.name, + source_asset_id=state.get("source_asset_id"), job_id=job_id, ) @@ -202,6 +226,7 @@ def node_escalate_cloud(state: DetectState) -> dict: vlm_prompt_fn=profile.vlm_prompt, stats=stats, content_type=profile.name, + source_asset_id=state.get("source_asset_id"), job_id=job_id, ) @@ -239,33 +264,87 @@ def node_compile_report(state: DetectState) -> dict: return {"report": report} +# --- Checkpoint wrapper --- + +_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1" +_frames_manifest: dict[str, dict[int, str]] = {} # job_id → manifest (cached per job) + + +def _checkpointing_node(node_name: str, node_fn): + """Wrap a node function to auto-checkpoint after completion.""" + stage_index = NODES.index(node_name) + + def wrapper(state: DetectState) -> dict: + result = node_fn(state) + + job_id = state.get("job_id", "") + if not job_id: + return result + + from detect.checkpoint import save_checkpoint, save_frames + + merged = {**state, **result} + + # Save frames once (first checkpoint), reuse manifest after + manifest = _frames_manifest.get(job_id) + if manifest is None and node_name == "extract_frames": + manifest = save_frames(job_id, merged.get("frames", [])) + _frames_manifest[job_id] = manifest + + save_checkpoint(job_id, node_name, stage_index, merged, frames_manifest=manifest) + return result + + wrapper.__name__ = node_fn.__name__ + return wrapper + + # --- Graph construction --- -def build_graph() -> StateGraph: +NODE_FUNCTIONS = [ + ("extract_frames", node_extract_frames), + ("filter_scenes", node_filter_scenes), + ("detect_objects", node_detect_objects), + ("run_ocr", node_run_ocr), + ("match_brands", node_match_brands), + ("escalate_vlm", node_escalate_vlm), + ("escalate_cloud", node_escalate_cloud), + ("compile_report", node_compile_report), +] + + +def build_graph(checkpoint: bool | None = None, start_from: str | None = None) -> StateGraph: + """ + Build the pipeline graph. + + checkpoint: enable auto-checkpointing (default: MPR_CHECKPOINT env var) + start_from: skip nodes before this stage (for replay) + """ + do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED + graph = StateGraph(DetectState) - graph.add_node("extract_frames", node_extract_frames) - graph.add_node("filter_scenes", node_filter_scenes) - graph.add_node("detect_objects", node_detect_objects) - graph.add_node("run_ocr", node_run_ocr) - graph.add_node("match_brands", node_match_brands) - graph.add_node("escalate_vlm", node_escalate_vlm) - graph.add_node("escalate_cloud", node_escalate_cloud) - graph.add_node("compile_report", node_compile_report) + # Filter to start_from if replaying + node_pairs = NODE_FUNCTIONS + if start_from: + start_idx = next(i for i, (name, _) in enumerate(NODE_FUNCTIONS) if name == start_from) + node_pairs = NODE_FUNCTIONS[start_idx:] - graph.set_entry_point("extract_frames") - graph.add_edge("extract_frames", "filter_scenes") - graph.add_edge("filter_scenes", "detect_objects") - graph.add_edge("detect_objects", "run_ocr") - graph.add_edge("run_ocr", "match_brands") - graph.add_edge("match_brands", "escalate_vlm") - graph.add_edge("escalate_vlm", "escalate_cloud") - graph.add_edge("escalate_cloud", "compile_report") - graph.add_edge("compile_report", END) + for name, fn in node_pairs: + wrapped = _checkpointing_node(name, fn) if do_checkpoint else fn + graph.add_node(name, wrapped) + + # Wire edges + entry = node_pairs[0][0] + graph.set_entry_point(entry) + + for i in range(len(node_pairs) - 1): + graph.add_edge(node_pairs[i][0], node_pairs[i + 1][0]) + + graph.add_edge(node_pairs[-1][0], END) return graph -def get_pipeline(): +def get_pipeline(checkpoint: bool | None = None): """Return a compiled, runnable pipeline.""" - return build_graph().compile() + return build_graph(checkpoint=checkpoint).compile() diff --git a/detect/profiles/__init__.py b/detect/profiles/__init__.py index e4664dd..c77ed2e 100644 --- a/detect/profiles/__init__.py +++ b/detect/profiles/__init__.py @@ -1,6 +1,5 @@ from .base import ( ContentTypeProfile, - BrandDictionary, CropContext, DetectionConfig, FrameExtractionConfig, @@ -12,7 +11,6 @@ from .soccer import SoccerBroadcastProfile __all__ = [ "ContentTypeProfile", - "BrandDictionary", "CropContext", "DetectionConfig", "FrameExtractionConfig", diff --git a/detect/profiles/base.py b/detect/profiles/base.py index 81c386d..00b6419 100644 --- a/detect/profiles/base.py +++ b/detect/profiles/base.py @@ -44,12 +44,6 @@ class ResolverConfig: fuzzy_threshold: int = 75 -@dataclass -class BrandDictionary: - """Maps canonical brand name → list of known aliases/spellings.""" - brands: dict[str, list[str]] = field(default_factory=dict) - - @dataclass class CropContext: image: bytes @@ -64,7 +58,6 @@ class ContentTypeProfile(Protocol): def scene_filter_config(self) -> SceneFilterConfig: ... def detection_config(self) -> DetectionConfig: ... def ocr_config(self) -> OCRConfig: ... - def brand_dictionary(self) -> BrandDictionary: ... def resolver_config(self) -> ResolverConfig: ... def vlm_prompt(self, crop_context: CropContext) -> str: ... def aggregate(self, detections: list[BrandDetection]) -> DetectionReport: ... diff --git a/detect/profiles/soccer.py b/detect/profiles/soccer.py index 882d72a..916b651 100644 --- a/detect/profiles/soccer.py +++ b/detect/profiles/soccer.py @@ -5,7 +5,6 @@ from __future__ import annotations from detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats from .base import ( - BrandDictionary, CropContext, DetectionConfig, FrameExtractionConfig, @@ -34,22 +33,6 @@ class SoccerBroadcastProfile: def ocr_config(self) -> OCRConfig: return OCRConfig(languages=["en", "es"], min_confidence=0.5) - def brand_dictionary(self) -> BrandDictionary: - return BrandDictionary(brands={ - "Nike": ["nike", "NIKE", "swoosh"], - "Adidas": ["adidas", "ADIDAS", "adi"], - "Puma": ["puma", "PUMA"], - "Emirates": ["emirates", "fly emirates", "EMIRATES"], - "Coca-Cola": ["coca-cola", "coca cola", "coke", "COCA-COLA"], - "Pepsi": ["pepsi", "PEPSI"], - "Mastercard": ["mastercard", "MASTERCARD"], - "Heineken": ["heineken", "HEINEKEN"], - "Santander": ["santander", "SANTANDER"], - "Gazprom": ["gazprom", "GAZPROM"], - "Qatar Airways": ["qatar airways", "QATAR AIRWAYS"], - "Lay's": ["lays", "lay's", "LAYS", "LAY'S"], - }) - def resolver_config(self) -> ResolverConfig: return ResolverConfig(fuzzy_threshold=75) diff --git a/detect/profiles/stubs.py b/detect/profiles/stubs.py index 9ccd489..ba43163 100644 --- a/detect/profiles/stubs.py +++ b/detect/profiles/stubs.py @@ -5,7 +5,6 @@ from __future__ import annotations from detect.models import BrandDetection, DetectionReport from .base import ( - BrandDictionary, CropContext, DetectionConfig, FrameExtractionConfig, @@ -30,9 +29,6 @@ class NewsBroadcastProfile: def ocr_config(self) -> OCRConfig: raise NotImplementedError - def brand_dictionary(self) -> BrandDictionary: - raise NotImplementedError - def resolver_config(self) -> ResolverConfig: raise NotImplementedError @@ -61,9 +57,6 @@ class AdvertisingProfile: def ocr_config(self) -> OCRConfig: raise NotImplementedError - def brand_dictionary(self) -> BrandDictionary: - raise NotImplementedError - def resolver_config(self) -> ResolverConfig: raise NotImplementedError @@ -92,9 +85,6 @@ class TranscriptProfile: def ocr_config(self) -> OCRConfig: raise NotImplementedError - def brand_dictionary(self) -> BrandDictionary: - raise NotImplementedError - def resolver_config(self) -> ResolverConfig: raise NotImplementedError diff --git a/detect/stages/brand_resolver.py b/detect/stages/brand_resolver.py index d5647c2..a9c1e7b 100644 --- a/detect/stages/brand_resolver.py +++ b/detect/stages/brand_resolver.py @@ -1,9 +1,17 @@ """ -Stage 5 — Brand Resolver +Stage 5 — Brand Resolver (discovery mode) -Matches OCR text against the profile's brand dictionary. -Uses exact matching first, then fuzzy matching (rapidfuzz) as fallback. -Emits detection events for confirmed brands. +Discovery-first brand matching. No static dictionary — all brands live in the DB. + +Flow: + 1. Check session sightings first (brands already seen in this source) + 2. Check global known brands (accumulated across all runs) + 3. Unresolved candidates → escalate to VLM/cloud + 4. Confirmed brands get added to DB for future runs + +The resolver is an enricher, not a gatekeeper. Every OCR text candidate +passes through — the question is whether we can resolve it cheaply (DB lookup) +or need to escalate (VLM/cloud). """ from __future__ import annotations @@ -14,99 +22,199 @@ from rapidfuzz import fuzz from detect import emit from detect.models import BrandDetection, TextCandidate -from detect.profiles.base import BrandDictionary, ResolverConfig +from detect.profiles.base import ResolverConfig logger = logging.getLogger(__name__) def _normalize(text: str) -> str: - """Normalize text for matching.""" return text.strip().lower() -def _exact_match(text: str, dictionary: BrandDictionary) -> str | None: - """Try exact match against all aliases.""" +def _has_db() -> bool: + try: + from core.db.detect import find_brand_by_text as _ + from admin.mpr.media_assets.models import KnownBrand as _ + return True + except (ImportError, Exception): + return False + + +def _match_session(text: str, session_brands: dict[str, str]) -> str | None: + """ + Check against session brands (already seen in this source). + + session_brands: {normalized_name: canonical_name, ...} + Includes aliases. + """ normalized = _normalize(text) - for canonical, aliases in dictionary.brands.items(): - if normalized == _normalize(canonical): - return canonical - for alias in aliases: - if normalized == _normalize(alias): - return canonical - return None + return session_brands.get(normalized) -def _fuzzy_match(text: str, dictionary: BrandDictionary, threshold: int) -> tuple[str | None, int]: - """Try fuzzy match, return (brand, score) or (None, 0).""" +def _match_known(text: str, threshold: int) -> tuple[str | None, str | None]: + """ + Check against global known brands in DB. + + Returns (canonical_name, brand_id) or (None, None). + """ + if not _has_db(): + return None, None + + from core.db.detect import find_brand_by_text + brand = find_brand_by_text(text) + if brand: + return brand.canonical_name, str(brand.id) + + # Fuzzy match against all known brands + from core.db.detect import list_all_brands + all_brands = list_all_brands() + normalized = _normalize(text) best_brand = None best_score = 0 - for canonical, aliases in dictionary.brands.items(): - all_names = [canonical] + aliases - for name in all_names: + for known in all_brands: + names = [known.canonical_name] + (known.aliases or []) + for name in names: score = fuzz.ratio(normalized, _normalize(name)) if score > best_score and score >= threshold: best_score = score - best_brand = canonical + best_brand = known - return best_brand, best_score + if best_brand: + return best_brand.canonical_name, str(best_brand.id) + + return None, None + + +def _register_brand(canonical_name: str, source: str) -> str | None: + """Register a newly discovered brand in the DB. Returns brand_id.""" + if not _has_db(): + return None + + from core.db.detect import get_or_create_brand + brand, created = get_or_create_brand(canonical_name, source=source) + if created: + logger.info("New brand discovered: %s (source=%s)", canonical_name, source) + return str(brand.id) + + +def _record_sighting(source_asset_id: str | None, brand_id: str, + brand_name: str, timestamp: float, + confidence: float, source: str): + """Record a brand sighting for this source.""" + if not _has_db() or not source_asset_id: + return + + from core.db.detect import record_sighting + import uuid + asset_id = uuid.UUID(source_asset_id) if isinstance(source_asset_id, str) else source_asset_id + brand_uuid = uuid.UUID(brand_id) if isinstance(brand_id, str) else brand_id + record_sighting(asset_id, brand_uuid, brand_name, timestamp, confidence, source) + + +def build_session_dict(source_asset_id: str | None) -> dict[str, str]: + """ + Load session brands from DB for this source. + + Returns {normalized_name: canonical_name, ...} including aliases. + """ + if not _has_db() or not source_asset_id: + return {} + + from core.db.detect import get_source_sightings + import uuid + + asset_id = uuid.UUID(source_asset_id) if isinstance(source_asset_id, str) else source_asset_id + sightings = get_source_sightings(asset_id) + + session = {} + for s in sightings: + canonical = s.brand_name + session[_normalize(canonical)] = canonical + + # Also load aliases from KnownBrand for each sighted brand + if _has_db(): + from core.db.detect import list_all_brands + all_brands = list_all_brands() + sighted_names = {s.brand_name for s in sightings} + for brand in all_brands: + if brand.canonical_name in sighted_names: + for alias in (brand.aliases or []): + session[_normalize(alias)] = brand.canonical_name + + return session def resolve_brands( candidates: list[TextCandidate], - dictionary: BrandDictionary, config: ResolverConfig, + session_brands: dict[str, str] | None = None, + source_asset_id: str | None = None, content_type: str = "", job_id: str | None = None, ) -> tuple[list[BrandDetection], list[TextCandidate]]: """ - Match text candidates against the brand dictionary. + Match text candidates against known brands (session → global → unresolved). - Returns: - - matched: list of BrandDetection for confirmed brands - - unresolved: list of TextCandidate that couldn't be matched + session_brands: pre-loaded session dict (from build_session_dict) + source_asset_id: for recording new sightings in DB """ + if session_brands is None: + session_brands = {} + emit.log(job_id, "BrandResolver", "INFO", - f"Matching {len(candidates)} candidates against " - f"{len(dictionary.brands)} brands (fuzzy_threshold={config.fuzzy_threshold})") + f"Resolving {len(candidates)} candidates " + f"(session={len(session_brands)} brands, fuzzy={config.fuzzy_threshold})") matched: list[BrandDetection] = [] unresolved: list[TextCandidate] = [] - exact_count = 0 - fuzzy_count = 0 + session_hits = 0 + known_hits = 0 for candidate in candidates: - # Try exact match first - brand = _exact_match(candidate.text, dictionary) - source = "ocr" + text = candidate.text + brand_name = None + brand_id = None + match_source = "ocr" - if brand: - exact_count += 1 + # 1. Check session (cheapest — in-memory dict) + brand_name = _match_session(text, session_brands) + if brand_name: + session_hits += 1 else: - # Try fuzzy match - brand, score = _fuzzy_match(candidate.text, dictionary, config.fuzzy_threshold) - if brand: - fuzzy_count += 1 + # 2. Check global known brands (DB query + fuzzy) + brand_name, brand_id = _match_known(text, config.fuzzy_threshold) + if brand_name: + known_hits += 1 + # Add to session for subsequent candidates in this run + session_brands[_normalize(brand_name)] = brand_name - if brand: + if brand_name: detection = BrandDetection( - brand=brand, + brand=brand_name, timestamp=candidate.frame.timestamp, duration=0.5, confidence=candidate.ocr_confidence, - source=source, + source=match_source, bbox=candidate.bbox, frame_ref=candidate.frame.sequence, content_type=content_type, ) matched.append(detection) + # Record sighting in DB + if brand_id: + _record_sighting( + source_asset_id, brand_id, brand_name, + candidate.frame.timestamp, candidate.ocr_confidence, match_source, + ) + emit.detection( job_id, - brand=brand, + brand=brand_name, confidence=candidate.ocr_confidence, - source=source, + source=match_source, timestamp=candidate.frame.timestamp, content_type=content_type, frame_ref=candidate.frame.sequence, @@ -115,7 +223,7 @@ def resolve_brands( unresolved.append(candidate) emit.log(job_id, "BrandResolver", "INFO", - f"Exact: {exact_count}, Fuzzy: {fuzzy_count}, " - f"Unresolved: {len(unresolved)} → escalating to VLM") + f"Session: {session_hits}, Known: {known_hits}, " + f"Unresolved: {len(unresolved)} → escalating") return matched, unresolved diff --git a/detect/stages/vlm_cloud.py b/detect/stages/vlm_cloud.py index c0e2ab6..8fc912b 100644 --- a/detect/stages/vlm_cloud.py +++ b/detect/stages/vlm_cloud.py @@ -27,6 +27,18 @@ logger = logging.getLogger(__name__) ESTIMATED_TOKENS_PER_CROP = 500 +def _register_discovered_brand(brand: str, source_asset_id: str | None, + timestamp: float, confidence: float): + """Register a cloud-confirmed brand in the DB.""" + try: + from detect.stages.brand_resolver import _register_brand, _record_sighting + brand_id = _register_brand(brand, "cloud_llm") + if brand_id and source_asset_id: + _record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, "cloud_llm") + except Exception as e: + logger.debug("Failed to register brand %s: %s", brand, e) + + def _encode_crop(crop: np.ndarray) -> str: img = Image.fromarray(crop) buf = io.BytesIO() @@ -84,6 +96,7 @@ def escalate_cloud( stats: PipelineStats, min_confidence: float = 0.4, content_type: str = "", + source_asset_id: str | None = None, job_id: str | None = None, ) -> list[BrandDetection]: """ @@ -158,6 +171,10 @@ def escalate_cloud( frame_ref=candidate.frame.sequence, ) + # Register newly discovered brand in DB + _register_discovered_brand(brand, source_asset_id, + candidate.frame.timestamp, confidence) + stats.estimated_cloud_cost_usd += total_cost stats.regions_escalated_to_cloud_llm = len(candidates) diff --git a/detect/stages/vlm_local.py b/detect/stages/vlm_local.py index 1f4987e..af06107 100644 --- a/detect/stages/vlm_local.py +++ b/detect/stages/vlm_local.py @@ -19,6 +19,18 @@ from detect.profiles.base import CropContext logger = logging.getLogger(__name__) +def _register_discovered_brand(brand: str, source_asset_id: str | None, + timestamp: float, confidence: float, source: str): + """Register a VLM-confirmed brand in the DB.""" + try: + from detect.stages.brand_resolver import _register_brand, _record_sighting + brand_id = _register_brand(brand, source) + if brand_id and source_asset_id: + _record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, source) + except Exception as e: + logger.debug("Failed to register brand %s: %s", brand, e) + + def _crop_image(candidate: TextCandidate) -> np.ndarray: frame = candidate.frame box = candidate.bbox @@ -36,6 +48,7 @@ def escalate_vlm( inference_url: str | None = None, min_confidence: float = 0.5, content_type: str = "", + source_asset_id: str | None = None, job_id: str | None = None, ) -> tuple[list[BrandDetection], list[TextCandidate]]: """ @@ -107,6 +120,10 @@ def escalate_vlm( frame_ref=candidate.frame.sequence, ) + # Register newly discovered brand in DB + _register_discovered_brand(brand, source_asset_id, + candidate.frame.timestamp, confidence, "local_vlm") + logger.debug("VLM matched: %s (%.2f) — %s", brand, confidence, reasoning) else: still_unresolved.append(candidate) diff --git a/detect/state.py b/detect/state.py index ee1540c..56b4d27 100644 --- a/detect/state.py +++ b/detect/state.py @@ -17,6 +17,7 @@ class DetectState(TypedDict, total=False): video_path: str job_id: str profile_name: str + source_asset_id: str # UUID of the source MediaAsset # Stage outputs frames: list[Frame] @@ -27,5 +28,11 @@ class DetectState(TypedDict, total=False): detections: list[BrandDetection] report: DetectionReport + # Session brands (accumulated during the run, persisted to DB) + session_brands: dict # {normalized_name: canonical_name} + # Running stats (updated by each stage) stats: PipelineStats + + # Config overrides for replay (applied via OverrideProfile) + config_overrides: dict diff --git a/tests/detect/manual/test_escalation_e2e.py b/tests/detect/manual/test_escalation_e2e.py index 745eaaa..b763c58 100644 --- a/tests/detect/manual/test_escalation_e2e.py +++ b/tests/detect/manual/test_escalation_e2e.py @@ -31,8 +31,16 @@ def ts(): return datetime.now(timezone.utc).isoformat() +RUN_CONTEXT = {} + + +def set_run_context(run_id: str, parent_job_id: str, run_type: str = "initial"): + RUN_CONTEXT.update({"run_id": run_id, "parent_job_id": parent_job_id, "run_type": run_type}) + + def push(r, key, event): event["ts"] = event.get("ts", ts()) + event.update(RUN_CONTEXT) r.rpush(key, json.dumps(event)) return event @@ -85,7 +93,11 @@ def main(): r.delete(key) delay = args.delay + run_id = f"{args.job[:8]}-r1" + set_run_context(run_id=run_id, parent_job_id=args.job, run_type="initial") + logger.info("Full escalation pipeline simulation → %s", key) + logger.info("Run: %s (parent: %s)", run_id, args.job) logger.info("Open: http://mpr.local.ar/detection/?job=%s", args.job) input("\nPress Enter to start...") diff --git a/tests/detect/manual/test_replay.py b/tests/detect/manual/test_replay.py new file mode 100644 index 0000000..e319882 --- /dev/null +++ b/tests/detect/manual/test_replay.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 +""" +Test checkpoint + replay flow end-to-end. + +1. Runs the pipeline with checkpointing enabled on a test video +2. Lists available checkpoints +3. Replays from run_ocr with different config +4. Compares detection counts + +Usage: + MPR_CHECKPOINT=1 INFERENCE_URL=http://mcrndeb:8000 python tests/detect/manual/test_replay.py [--job JOB_ID] + +Requires: inference server running, MinIO/S3 running, test video available +""" + +import argparse +import logging +import os +import sys +from pathlib import Path + +# Load ctrl/.env +env_file = Path(__file__).resolve().parents[3] / "ctrl" / ".env" +if env_file.exists(): + for line in env_file.read_text().splitlines(): + line = line.strip() + if line and not line.startswith("#") and "=" in line: + key, _, value = line.partition("=") + os.environ.setdefault(key.strip(), value.strip()) + +sys.path.insert(0, ".") + +logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s") +logger = logging.getLogger(__name__) + +# Force checkpointing on +os.environ["MPR_CHECKPOINT"] = "1" + + +def main(): + parser = argparse.ArgumentParser() + import time + default_job = f"replay-{int(time.time()) % 100000}" + parser.add_argument("--job", default=default_job) + parser.add_argument("--port", type=int, default=6382) + args = parser.parse_args() + + # Override Redis to localhost (ctrl/.env has k8s hostname) + os.environ["REDIS_URL"] = f"redis://localhost:{args.port}/0" + + from detect.graph import get_pipeline, NODES + from detect.checkpoint import list_checkpoints + from detect.checkpoint import replay_from + from detect.state import DetectState + + VIDEO = "media/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e/chunk_0000.mp4" + + logger.info("Job: %s", args.job) + logger.info("Checkpoint: enabled") + logger.info("Video: %s", VIDEO) + logger.info("Open: http://mpr.local.ar/detection/?job=%s", args.job) + input("\nPress Enter to run initial pipeline...") + + # --- Initial run --- + pipeline = get_pipeline(checkpoint=True) + initial_state = DetectState( + video_path=VIDEO, + job_id=args.job, + profile_name="soccer_broadcast", + ) + + logger.info("Running initial pipeline...") + result = pipeline.invoke(initial_state) + + detections = result.get("detections", []) + report = result.get("report") + logger.info("Initial run: %d detections, %d brands", + len(detections), len(report.brands) if report else 0) + + # --- List checkpoints --- + stages = list_checkpoints(args.job) + logger.info("Available checkpoints: %s", stages) + + if "detect_objects" not in stages: + logger.error("Expected checkpoint for detect_objects — aborting replay test") + return + + input("\nPress Enter to replay from run_ocr with different config...") + + # --- Replay with different OCR config --- + overrides = {"ocr": {"min_confidence": 0.3, "languages": ["en", "es"]}} + logger.info("Replaying from run_ocr with overrides: %s", overrides) + + replay_result = replay_from( + job_id=args.job, + start_stage="run_ocr", + config_overrides=overrides, + ) + + replay_detections = replay_result.get("detections", []) + replay_report = replay_result.get("report") + logger.info("Replay run: %d detections, %d brands", + len(replay_detections), + len(replay_report.brands) if replay_report else 0) + + # --- Compare --- + logger.info("--- Comparison ---") + logger.info("Initial: %d detections", len(detections)) + logger.info("Replay: %d detections (min_confidence 0.5 → 0.3)", len(replay_detections)) + + diff = len(replay_detections) - len(detections) + if diff > 0: + logger.info("Replay found %d more detections with lower threshold", diff) + elif diff == 0: + logger.info("Same count — threshold change didn't affect this video") + else: + logger.warning("Replay found fewer detections — unexpected") + + logger.info("Done.") + + +if __name__ == "__main__": + main() diff --git a/tests/detect/test_brand_resolver.py b/tests/detect/test_brand_resolver.py index 2077a38..72add92 100644 --- a/tests/detect/test_brand_resolver.py +++ b/tests/detect/test_brand_resolver.py @@ -1,20 +1,13 @@ -"""Tests for BrandResolver stage.""" +"""Tests for BrandResolver stage (discovery mode).""" import numpy as np import pytest from detect.models import BoundingBox, Frame, TextCandidate -from detect.profiles.base import BrandDictionary, ResolverConfig -from detect.stages.brand_resolver import resolve_brands, _exact_match, _fuzzy_match +from detect.profiles.base import ResolverConfig +from detect.stages.brand_resolver import resolve_brands, _normalize, _match_session -DICTIONARY = BrandDictionary(brands={ - "Nike": ["nike", "NIKE", "swoosh"], - "Adidas": ["adidas", "ADIDAS"], - "Coca-Cola": ["coca-cola", "coca cola", "coke", "COCA-COLA"], - "Emirates": ["emirates", "fly emirates", "EMIRATES"], -}) - CONFIG = ResolverConfig(fuzzy_threshold=75) @@ -25,57 +18,76 @@ def _make_candidate(text: str, confidence: float = 0.9) -> TextCandidate: return TextCandidate(frame=dummy_frame, bbox=dummy_box, text=text, ocr_confidence=confidence) -def test_exact_match(): - assert _exact_match("Nike", DICTIONARY) == "Nike" - assert _exact_match("nike", DICTIONARY) == "Nike" - assert _exact_match("COCA-COLA", DICTIONARY) == "Coca-Cola" - assert _exact_match("fly emirates", DICTIONARY) == "Emirates" - assert _exact_match("unknown brand", DICTIONARY) is None +def test_session_match(): + session = {"nike": "Nike", "fly emirates": "Emirates"} + assert _match_session("Nike", session) == "Nike" + assert _match_session("nike", session) == "Nike" + assert _match_session("FLY EMIRATES", session) == "Emirates" + assert _match_session("unknown", session) is None -def test_fuzzy_match(): - brand, score = _fuzzy_match("Nik3", DICTIONARY, threshold=75) - assert brand == "Nike" - assert score >= 75 +def test_resolve_with_session(monkeypatch): + events = [] + monkeypatch.setattr("detect.emit.push_detect_event", + lambda job_id, etype, data: events.append((etype, data))) - brand, score = _fuzzy_match("adldas", DICTIONARY, threshold=75) - assert brand == "Adidas" - - brand, score = _fuzzy_match("xyzxyzxyz", DICTIONARY, threshold=75) - assert brand is None - - -def test_resolve_exact(): + session = {"nike": "Nike", "emirates": "Emirates"} candidates = [_make_candidate("Nike"), _make_candidate("EMIRATES")] - matched, unresolved = resolve_brands(candidates, DICTIONARY, CONFIG) + + matched, unresolved = resolve_brands( + candidates, CONFIG, session_brands=session, + ) + assert len(matched) == 2 assert len(unresolved) == 0 assert matched[0].brand == "Nike" assert matched[1].brand == "Emirates" -def test_resolve_fuzzy(): - candidates = [_make_candidate("coca coIa")] # OCR misread - matched, unresolved = resolve_brands(candidates, DICTIONARY, CONFIG) - assert len(matched) == 1 - assert matched[0].brand == "Coca-Cola" +def test_resolve_unresolved_without_db(monkeypatch): + events = [] + monkeypatch.setattr("detect.emit.push_detect_event", + lambda job_id, etype, data: events.append((etype, data))) - -def test_resolve_unresolved(): candidates = [_make_candidate("random garbage text")] - matched, unresolved = resolve_brands(candidates, DICTIONARY, CONFIG) + + matched, unresolved = resolve_brands( + candidates, CONFIG, session_brands={}, + ) + assert len(matched) == 0 assert len(unresolved) == 1 -def test_resolve_mixed(): +def test_resolve_empty(monkeypatch): + events = [] + monkeypatch.setattr("detect.emit.push_detect_event", + lambda job_id, etype, data: events.append((etype, data))) + + matched, unresolved = resolve_brands([], CONFIG, session_brands={}) + + assert len(matched) == 0 + assert len(unresolved) == 0 + + +def test_resolve_builds_session_during_run(monkeypatch): + """Session brands accumulate during a single run — second candidate benefits.""" + events = [] + monkeypatch.setattr("detect.emit.push_detect_event", + lambda job_id, etype, data: events.append((etype, data))) + + session = {"nike": "Nike"} candidates = [ - _make_candidate("Nike"), - _make_candidate("unknown"), - _make_candidate("adldas"), + _make_candidate("Nike"), # hits session + _make_candidate("unknown"), # misses everything ] - matched, unresolved = resolve_brands(candidates, DICTIONARY, CONFIG) - assert len(matched) == 2 # Nike exact + Adidas fuzzy + + matched, unresolved = resolve_brands( + candidates, CONFIG, session_brands=session, + ) + + assert len(matched) == 1 + assert matched[0].brand == "Nike" assert len(unresolved) == 1 @@ -84,8 +96,10 @@ def test_events_emitted(monkeypatch): monkeypatch.setattr("detect.emit.push_detect_event", lambda job_id, etype, data: events.append((etype, data))) + session = {"nike": "Nike"} candidates = [_make_candidate("Nike")] - resolve_brands(candidates, DICTIONARY, CONFIG, job_id="test-job") + + resolve_brands(candidates, CONFIG, session_brands=session, job_id="test-job") event_types = [e[0] for e in events] assert "log" in event_types diff --git a/tests/detect/test_checkpoint.py b/tests/detect/test_checkpoint.py new file mode 100644 index 0000000..a147a5e --- /dev/null +++ b/tests/detect/test_checkpoint.py @@ -0,0 +1,182 @@ +"""Tests for checkpoint serialization — round-trip without S3.""" + +import numpy as np +import pytest + +from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate +from detect.checkpoint.serializer import ( + serialize_state, + deserialize_state, + serialize_frame_meta, + serialize_text_candidate, + deserialize_text_candidate, +) + + +def _make_frame(seq: int = 0, w: int = 100, h: int = 80) -> Frame: + image = np.random.randint(0, 255, (h, w, 3), dtype=np.uint8) + return Frame( + sequence=seq, + chunk_id=0, + timestamp=float(seq) * 0.5, + image=image, + perceptual_hash=f"hash_{seq}", + ) + + +def _make_box(x=10, y=10, w=30, h=20) -> BoundingBox: + return BoundingBox(x=x, y=y, w=w, h=h, confidence=0.9, label="text") + + +def _make_candidate(frame: Frame, text: str = "NIKE") -> TextCandidate: + box = _make_box() + return TextCandidate(frame=frame, bbox=box, text=text, ocr_confidence=0.85) + + +def _make_detection(brand: str = "Nike", timestamp: float = 1.0) -> BrandDetection: + return BrandDetection( + brand=brand, + timestamp=timestamp, + duration=0.5, + confidence=0.92, + source="ocr", + content_type="soccer_broadcast", + frame_ref=0, + ) + + +# --- Frame metadata --- + +def test_serialize_frame_meta(): + frame = _make_frame(seq=5) + meta = serialize_frame_meta(frame) + + assert meta["sequence"] == 5 + assert meta["timestamp"] == 2.5 + assert meta["perceptual_hash"] == "hash_5" + assert "image" not in meta + + +# --- TextCandidate --- + +def test_serialize_text_candidate(): + frame = _make_frame() + candidate = _make_candidate(frame, text="EMIRATES") + data = serialize_text_candidate(candidate) + + assert data["frame_sequence"] == 0 + assert data["text"] == "EMIRATES" + assert data["ocr_confidence"] == 0.85 + assert "bbox" in data + + +def test_deserialize_text_candidate(): + frame = _make_frame() + candidate = _make_candidate(frame, text="ADIDAS") + + serialized = serialize_text_candidate(candidate) + frame_map = {frame.sequence: frame} + + restored = deserialize_text_candidate(serialized, frame_map) + + assert restored.text == "ADIDAS" + assert restored.ocr_confidence == 0.85 + assert restored.frame is frame # same object reference + assert restored.bbox.x == 10 + + +# --- Full state round-trip --- + +def test_state_round_trip(): + frames = [_make_frame(seq=i) for i in range(3)] + filtered = frames[:2] + + box = _make_box() + boxes_by_frame = {0: [box], 1: [box]} + + candidates = [_make_candidate(frames[0], "NIKE"), _make_candidate(frames[1], "EMIRATES")] + unresolved = [_make_candidate(frames[2], "unknown")] + detections = [_make_detection("Nike", 0.5), _make_detection("Emirates", 1.0)] + + stats = PipelineStats( + frames_extracted=3, + frames_after_scene_filter=2, + regions_detected=2, + regions_resolved_by_ocr=2, + cloud_llm_calls=1, + estimated_cloud_cost_usd=0.003, + ) + + state = { + "job_id": "test-123", + "video_path": "/tmp/test.mp4", + "profile_name": "soccer_broadcast", + "config_overrides": {"ocr": {"min_confidence": 0.3}}, + "frames": frames, + "filtered_frames": filtered, + "boxes_by_frame": boxes_by_frame, + "text_candidates": candidates, + "unresolved_candidates": unresolved, + "detections": detections, + "stats": stats, + } + + manifest = {f.sequence: f"s3://fake/frames/{f.sequence}.jpg" for f in frames} + + # Serialize + serialized = serialize_state(state, manifest) + + # Verify JSON-compatible (no numpy, no Frame objects) + import json + json_str = json.dumps(serialized, default=str) + assert len(json_str) > 0 + + # Deserialize with the original frames (simulating frame load from S3) + restored = deserialize_state(serialized, frames) + + # Verify round-trip + assert restored["job_id"] == "test-123" + assert restored["video_path"] == "/tmp/test.mp4" + assert restored["profile_name"] == "soccer_broadcast" + assert restored["config_overrides"] == {"ocr": {"min_confidence": 0.3}} + + assert len(restored["frames"]) == 3 + assert len(restored["filtered_frames"]) == 2 + assert len(restored["boxes_by_frame"]) == 2 + assert len(restored["text_candidates"]) == 2 + assert len(restored["unresolved_candidates"]) == 1 + assert len(restored["detections"]) == 2 + + restored_stats = restored["stats"] + assert restored_stats.frames_extracted == 3 + assert restored_stats.cloud_llm_calls == 1 + assert restored_stats.estimated_cloud_cost_usd == 0.003 + + # TextCandidate frame references should point to actual Frame objects + tc = restored["text_candidates"][0] + assert tc.frame is frames[0] + assert tc.text == "NIKE" + + +def test_state_round_trip_empty(): + """Empty state should serialize/deserialize cleanly.""" + state = { + "job_id": "empty-job", + "video_path": "", + "profile_name": "soccer_broadcast", + "frames": [], + "filtered_frames": [], + "boxes_by_frame": {}, + "text_candidates": [], + "unresolved_candidates": [], + "detections": [], + "stats": PipelineStats(), + } + + serialized = serialize_state(state, {}) + restored = deserialize_state(serialized, []) + + assert restored["job_id"] == "empty-job" + assert len(restored["frames"]) == 0 + assert len(restored["detections"]) == 0 + assert restored["stats"].frames_extracted == 0 diff --git a/tests/detect/test_profiles.py b/tests/detect/test_profiles.py index f261eb7..8e95d5b 100644 --- a/tests/detect/test_profiles.py +++ b/tests/detect/test_profiles.py @@ -25,11 +25,9 @@ def test_soccer_detection_config(): assert isinstance(cfg.target_classes, list) -def test_soccer_brand_dictionary_non_empty(): - bd = SoccerBroadcastProfile().brand_dictionary() - assert len(bd.brands) > 0 - for canonical, aliases in bd.brands.items(): - assert len(aliases) > 0 +def test_soccer_resolver_config(): + cfg = SoccerBroadcastProfile().resolver_config() + assert cfg.fuzzy_threshold > 0 def test_soccer_vlm_prompt(): @@ -70,4 +68,4 @@ def test_stubs_raise(stub_cls): with pytest.raises(NotImplementedError): stub.frame_extraction_config() with pytest.raises(NotImplementedError): - stub.brand_dictionary() + stub.resolver_config() diff --git a/tests/detect/test_replay.py b/tests/detect/test_replay.py new file mode 100644 index 0000000..dc3cc64 --- /dev/null +++ b/tests/detect/test_replay.py @@ -0,0 +1,67 @@ +"""Tests for replay and OverrideProfile.""" + +import pytest + +from detect.profiles.soccer import SoccerBroadcastProfile +from detect.checkpoint.replay import OverrideProfile + + +def test_override_profile_patches_ocr(): + base = SoccerBroadcastProfile() + overrides = {"ocr": {"min_confidence": 0.3, "languages": ["en", "es", "pt"]}} + profile = OverrideProfile(base, overrides) + + config = profile.ocr_config() + + assert config.min_confidence == 0.3 + assert config.languages == ["en", "es", "pt"] + + +def test_override_profile_patches_resolver(): + base = SoccerBroadcastProfile() + overrides = {"resolver": {"fuzzy_threshold": 60}} + profile = OverrideProfile(base, overrides) + + config = profile.resolver_config() + + assert config.fuzzy_threshold == 60 + + +def test_override_profile_patches_detection(): + base = SoccerBroadcastProfile() + overrides = {"detection": {"confidence_threshold": 0.5}} + profile = OverrideProfile(base, overrides) + + config = profile.detection_config() + + assert config.confidence_threshold == 0.5 + + +def test_override_profile_no_overrides(): + base = SoccerBroadcastProfile() + profile = OverrideProfile(base, {}) + + ocr = profile.ocr_config() + base_ocr = base.ocr_config() + + assert ocr.min_confidence == base_ocr.min_confidence + assert ocr.languages == base_ocr.languages + + +def test_override_profile_delegates_non_config(): + base = SoccerBroadcastProfile() + profile = OverrideProfile(base, {"ocr": {"min_confidence": 0.1}}) + + assert profile.name == "soccer_broadcast" + assert profile.resolver_config().fuzzy_threshold > 0 + + +def test_override_profile_ignores_unknown_fields(): + base = SoccerBroadcastProfile() + overrides = {"ocr": {"nonexistent_field": 42}} + profile = OverrideProfile(base, overrides) + + config = profile.ocr_config() + + assert not hasattr(config, "nonexistent_field") + assert config.min_confidence == base.ocr_config().min_confidence diff --git a/ui/detection-app/src/App.vue b/ui/detection-app/src/App.vue index a0713ba..5ce2b96 100644 --- a/ui/detection-app/src/App.vue +++ b/ui/detection-app/src/App.vue @@ -1,6 +1,6 @@ diff --git a/ui/detection-app/src/panels/CostStatsPanel.vue b/ui/detection-app/src/panels/CostStatsPanel.vue index ecabc46..57aa0aa 100644 --- a/ui/detection-app/src/panels/CostStatsPanel.vue +++ b/ui/detection-app/src/panels/CostStatsPanel.vue @@ -7,6 +7,7 @@ import type { StatsUpdate, Detection } from '../types/sse-contract' const props = defineProps<{ source: DataSource status?: 'idle' | 'live' | 'processing' | 'error' + embedded?: boolean }>() const stats = ref(null) @@ -71,7 +72,7 @@ const metrics = computed(() => { diff --git a/ui/framework/src/index.ts b/ui/framework/src/index.ts index 3d9274d..a9742c8 100644 --- a/ui/framework/src/index.ts +++ b/ui/framework/src/index.ts @@ -7,6 +7,7 @@ export { useDataSource } from './composables/useDataSource' // Components export { default as Panel } from './components/Panel.vue' export { default as LayoutGrid } from './components/LayoutGrid.vue' +export { default as ResizeHandle } from './components/ResizeHandle.vue' // Renderers export { default as LogRenderer } from './renderers/LogRenderer.vue' diff --git a/ui/framework/src/renderers/TableRenderer.vue b/ui/framework/src/renderers/TableRenderer.vue index 0feb3b2..d4c3d69 100644 --- a/ui/framework/src/renderers/TableRenderer.vue +++ b/ui/framework/src/renderers/TableRenderer.vue @@ -76,6 +76,7 @@ const sorted = computed(() => { table { width: 100%; border-collapse: collapse; + table-layout: fixed; } th { @@ -89,7 +90,6 @@ th { border-bottom: var(--panel-border); cursor: pointer; user-select: none; - white-space: nowrap; } th:hover { @@ -104,7 +104,10 @@ th:hover { td { padding: var(--space-1) var(--space-3); border-bottom: 1px solid var(--surface-3); - white-space: nowrap; + white-space: normal; + word-break: break-word; + overflow: hidden; + text-overflow: ellipsis; } tr:hover td { diff --git a/ui/framework/src/renderers/TimeSeriesRenderer.vue b/ui/framework/src/renderers/TimeSeriesRenderer.vue index b0baf75..c664ffb 100644 --- a/ui/framework/src/renderers/TimeSeriesRenderer.vue +++ b/ui/framework/src/renderers/TimeSeriesRenderer.vue @@ -22,6 +22,7 @@ const props = withDefaults(defineProps<{ }) const container = ref(null) +const zoomed = ref(false) let chart: uPlot | null = null function buildOpts(): uPlot.Options { @@ -40,25 +41,66 @@ function buildOpts(): uPlot.Options { height: container.value?.clientHeight ?? 200, series: seriesOpts, axes: [ - { stroke: '#555568', grid: { stroke: '#2e2e3822' } }, - { stroke: '#555568', grid: { stroke: '#2e2e3822' } }, + { + stroke: '#555568', + grid: { stroke: '#2e2e3822' }, + size: 40, + font: '10px monospace', + ticks: { size: 3 }, + }, + { + stroke: '#555568', + grid: { stroke: '#2e2e3822' }, + size: 35, + font: '10px monospace', + ticks: { size: 3 }, + }, ], cursor: { show: true }, - legend: { show: true }, + legend: { show: true, live: false }, + padding: [8, 8, 0, 0], + hooks: { + setScale: [(_self: uPlot, scaleKey: string) => { + if (scaleKey === 'x') zoomed.value = true + }], + }, } } +function resetZoom() { + if (!chart) return + const data = chart.data + if (data && data[0] && data[0].length > 0) { + const min = data[0][0] + const max = data[0][data[0].length - 1] + chart.setScale('x', { min, max }) + } + zoomed.value = false +} + +function getLegendHeight(): number { + if (!container.value) return 0 + const legend = container.value.querySelector('.u-legend') as HTMLElement | null + return legend ? legend.offsetHeight : 0 +} + function createChart() { if (!container.value) return if (chart) chart.destroy() chart = new uPlot(buildOpts(), props.data, container.value) + // Refit after legend renders + nextTick(() => resize()) } function resize() { if (!chart || !container.value) return + const legendH = getLegendHeight() + const availableH = container.value.clientHeight + // uPlot height = canvas height (chart sets total = canvas + legend) + const chartH = Math.max(60, availableH - legendH) chart.setSize({ width: container.value.clientWidth, - height: container.value.clientHeight, + height: chartH, }) } @@ -83,19 +125,74 @@ onMounted(() => {