phase 10
This commit is contained in:
121
core/api/detect_replay.py
Normal file
121
core/api/detect_replay.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
"""
|
||||||
|
API endpoints for checkpoint inspection, replay, and retry.
|
||||||
|
|
||||||
|
GET /detect/checkpoints/{job_id} — list available checkpoints
|
||||||
|
POST /detect/replay — replay from a stage with config overrides
|
||||||
|
POST /detect/retry — queue async retry with different provider
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/detect", tags=["detect"])
|
||||||
|
|
||||||
|
|
||||||
|
# --- Request/Response models ---


class CheckpointInfo(BaseModel):
    """One available checkpoint, identified by its pipeline stage name."""

    stage: str
||||||
|
class ReplayRequest(BaseModel):
    """Parameters for replaying the pipeline from a saved checkpoint."""

    job_id: str
    start_stage: str
    # Optional per-stage config patches applied on top of the saved config.
    config_overrides: dict | None = None
|
class ReplayResponse(BaseModel):
    """Summary returned after a synchronous replay run finishes."""

    status: str
    job_id: str
    start_stage: str
    detections: int = 0
    brands_found: int = 0
|
class RetryRequest(BaseModel):
    """Parameters for queueing an asynchronous retry of a job."""

    job_id: str
    config_overrides: dict | None = None
    start_stage: str = "escalate_vlm"
    schedule_seconds: float | None = None  # delay before execution (off-peak)
|
class RetryResponse(BaseModel):
    """Acknowledgement that a retry task was queued."""

    status: str
    task_id: str
    job_id: str
|
# --- Endpoints ---


@router.get("/checkpoints/{job_id}")
def list_checkpoints(job_id: str) -> list[CheckpointInfo]:
    """List available checkpoint stages for a job.

    Raises:
        HTTPException: 404 when the job has no checkpoints (or lookup fails).
    """
    # Imported lazily so the module can be imported without the checkpoint
    # backend (DB/S3) being configured.
    from detect.checkpoint import list_checkpoints as _list

    try:
        stages = _list(job_id)
    except Exception as e:
        # NOTE(review): every failure — including infrastructure errors — is
        # reported as 404 here; confirm that backend outages should not be 5xx.
        raise HTTPException(
            status_code=404, detail=f"No checkpoints for job {job_id}: {e}"
        ) from e

    return [CheckpointInfo(stage=s) for s in stages]
||||||
|
@router.post("/replay", response_model=ReplayResponse)
|
||||||
|
def replay(req: ReplayRequest):
|
||||||
|
"""Replay pipeline from a specific stage with optional config overrides."""
|
||||||
|
from detect.checkpoint import replay_from
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = replay_from(
|
||||||
|
job_id=req.job_id,
|
||||||
|
start_stage=req.start_stage,
|
||||||
|
config_overrides=req.config_overrides,
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"Replay failed: {e}")
|
||||||
|
|
||||||
|
detections = result.get("detections", [])
|
||||||
|
report = result.get("report")
|
||||||
|
brands_found = len(report.brands) if report else 0
|
||||||
|
|
||||||
|
response = ReplayResponse(
|
||||||
|
status="completed",
|
||||||
|
job_id=req.job_id,
|
||||||
|
start_stage=req.start_stage,
|
||||||
|
detections=len(detections),
|
||||||
|
brands_found=brands_found,
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/retry", response_model=RetryResponse)
|
||||||
|
def retry(req: RetryRequest):
|
||||||
|
"""Queue an async retry of unresolved candidates with different config."""
|
||||||
|
from detect.checkpoint.tasks import retry_candidates
|
||||||
|
|
||||||
|
kwargs = {
|
||||||
|
"job_id": req.job_id,
|
||||||
|
"config_overrides": req.config_overrides,
|
||||||
|
"start_stage": req.start_stage,
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.schedule_seconds:
|
||||||
|
task = retry_candidates.apply_async(kwargs=kwargs, countdown=req.schedule_seconds)
|
||||||
|
else:
|
||||||
|
task = retry_candidates.delay(**kwargs)
|
||||||
|
|
||||||
|
response = RetryResponse(
|
||||||
|
status="queued",
|
||||||
|
task_id=task.id,
|
||||||
|
job_id=req.job_id,
|
||||||
|
)
|
||||||
|
return response
|
||||||
@@ -25,6 +25,7 @@ from strawberry.fastapi import GraphQLRouter
|
|||||||
|
|
||||||
from core.api.chunker_sse import router as chunker_router
|
from core.api.chunker_sse import router as chunker_router
|
||||||
from core.api.detect_sse import router as detect_router
|
from core.api.detect_sse import router as detect_router
|
||||||
|
from core.api.detect_replay import router as detect_replay_router
|
||||||
from core.api.graphql import schema as graphql_schema
|
from core.api.graphql import schema as graphql_schema
|
||||||
|
|
||||||
CALLBACK_API_KEY = os.environ.get("CALLBACK_API_KEY", "")
|
CALLBACK_API_KEY = os.environ.get("CALLBACK_API_KEY", "")
|
||||||
@@ -56,6 +57,9 @@ app.include_router(chunker_router)
|
|||||||
# Detection SSE
|
# Detection SSE
|
||||||
app.include_router(detect_router)
|
app.include_router(detect_router)
|
||||||
|
|
||||||
|
# Detection replay/retry
|
||||||
|
app.include_router(detect_replay_router)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
def health():
|
def health():
|
||||||
|
|||||||
175
core/db/detect.py
Normal file
175
core/db/detect.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
"""Database operations for DetectJob and StageCheckpoint."""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
# DetectJob
# ---------------------------------------------------------------------------


def create_detect_job(**fields):
    """Insert and return a new DetectJob row built from keyword fields."""
    # Lazy import: Django models are only importable after app registry setup.
    from admin.mpr.media_assets.models import DetectJob

    return DetectJob.objects.create(**fields)
|
||||||
|
|
||||||
|
|
||||||
|
def get_detect_job(id: UUID):
    """Fetch one DetectJob by primary key; raises DoesNotExist when absent."""
    # NOTE: the parameter shadows builtin `id`; renamed would break keyword
    # callers, so it is kept as-is.
    from admin.mpr.media_assets.models import DetectJob

    return DetectJob.objects.get(id=id)
|
||||||
|
|
||||||
|
|
||||||
|
def update_detect_job(job_id: UUID, **fields):
    """Apply a partial update to a DetectJob row (no-op if the id is unknown)."""
    from admin.mpr.media_assets.models import DetectJob

    DetectJob.objects.filter(id=job_id).update(**fields)
|
||||||
|
|
||||||
|
|
||||||
|
def list_detect_jobs(
    parent_job_id: Optional[UUID] = None,
    status: Optional[str] = None,
):
    """Return DetectJobs, optionally narrowed by parent job and/or status."""
    from admin.mpr.media_assets.models import DetectJob

    queryset = DetectJob.objects.all()
    if parent_job_id:
        queryset = queryset.filter(parent_job_id=parent_job_id)
    if status:
        queryset = queryset.filter(status=status)
    return list(queryset)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
# StageCheckpoint
# ---------------------------------------------------------------------------


def save_stage_checkpoint(**fields):
    """Insert and return a new StageCheckpoint row from keyword fields."""
    from admin.mpr.media_assets.models import StageCheckpoint

    return StageCheckpoint.objects.create(**fields)
|
||||||
|
|
||||||
|
|
||||||
|
def get_stage_checkpoint(job_id: UUID, stage: str):
    """Fetch the checkpoint for (job, stage); raises DoesNotExist when absent."""
    from admin.mpr.media_assets.models import StageCheckpoint

    return StageCheckpoint.objects.get(job_id=job_id, stage=stage)
|
||||||
|
|
||||||
|
|
||||||
|
def list_stage_checkpoints(job_id: UUID) -> list[str]:
    """Return checkpointed stage names for a job, in pipeline order."""
    from admin.mpr.media_assets.models import StageCheckpoint

    query = (
        StageCheckpoint.objects
        .filter(job_id=job_id)
        .order_by("stage_index")
        .values_list("stage", flat=True)
    )
    return list(query)
|
||||||
|
|
||||||
|
|
||||||
|
def delete_stage_checkpoints(job_id: UUID):
    """Remove every stored checkpoint row belonging to the given job."""
    from admin.mpr.media_assets.models import StageCheckpoint

    StageCheckpoint.objects.filter(job_id=job_id).delete()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
# KnownBrand
# ---------------------------------------------------------------------------


def get_or_create_brand(canonical_name: str, aliases: list[str] | None = None,
                        source: str = "ocr") -> tuple:
    """Get existing brand or create new one. Returns (brand, created)."""
    import uuid

    from admin.mpr.media_assets.models import KnownBrand

    normalized = canonical_name.strip()

    # Fast path: case-insensitive canonical-name match.
    match = KnownBrand.objects.filter(canonical_name__iexact=normalized).first()
    if match:
        return match, False

    # Fall back to alias matching. NOTE(review): this scans all brand rows in
    # Python because alias comparison must be case-insensitive over jsonb;
    # acceptable for a small brand table, revisit if it grows.
    needle = normalized.lower()
    for candidate in KnownBrand.objects.all():
        if needle in [a.lower() for a in (candidate.aliases or [])]:
            return candidate, False

    created = KnownBrand.objects.create(
        id=uuid.uuid4(),
        canonical_name=normalized,
        aliases=aliases or [],
        first_source=source,
    )
    return created, True
|
||||||
|
|
||||||
|
|
||||||
|
def find_brand_by_text(text: str) -> Optional[object]:
    """Find a known brand by canonical name or alias (case-insensitive)."""
    from admin.mpr.media_assets.models import KnownBrand

    needle = text.strip().lower()

    # Fast path: canonical-name match (iexact handles the case-folding).
    exact = KnownBrand.objects.filter(canonical_name__iexact=needle).first()
    if exact:
        return exact

    # Slow path: alias lists live in jsonb, so matching happens in Python.
    for candidate in KnownBrand.objects.all():
        if needle in (a.lower() for a in (candidate.aliases or [])):
            return candidate

    return None
|
||||||
|
|
||||||
|
|
||||||
|
def list_all_brands() -> list:
    """Return every known brand, alphabetized by canonical name."""
    from admin.mpr.media_assets.models import KnownBrand

    return list(KnownBrand.objects.all().order_by("canonical_name"))
|
||||||
|
|
||||||
|
|
||||||
|
def update_brand(brand_id: UUID, **fields):
    """Apply a partial update to a KnownBrand row (no-op if the id is unknown)."""
    from admin.mpr.media_assets.models import KnownBrand

    KnownBrand.objects.filter(id=brand_id).update(**fields)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
# SourceBrandSighting
# ---------------------------------------------------------------------------


def get_source_sightings(source_asset_id: UUID) -> list:
    """Get all brand sightings for a source video, most frequent first."""
    from admin.mpr.media_assets.models import SourceBrandSighting

    sightings = SourceBrandSighting.objects.filter(
        source_asset_id=source_asset_id
    ).order_by("-occurrences")
    return list(sightings)
|
||||||
|
|
||||||
|
|
||||||
|
def record_sighting(source_asset_id: UUID, brand_id: UUID, brand_name: str,
                    timestamp: float, confidence: float, source: str = "ocr"):
    """Record or update a brand sighting for a source.

    When the (source, brand) pair already exists, bumps the occurrence count,
    advances last_seen_timestamp, and folds the new confidence into the
    running mean. Otherwise inserts a fresh sighting row.
    """
    import uuid

    from admin.mpr.media_assets.models import SourceBrandSighting

    existing = SourceBrandSighting.objects.filter(
        source_asset_id=source_asset_id,
        brand_id=brand_id,
    ).first()

    if existing:
        existing.occurrences += 1
        existing.last_seen_timestamp = timestamp
        # Incremental mean: (old_sum + new_value) / new_count.
        prior_sum = existing.avg_confidence * (existing.occurrences - 1)
        existing.avg_confidence = (prior_sum + confidence) / existing.occurrences
        existing.save()
        return existing

    return SourceBrandSighting.objects.create(
        id=uuid.uuid4(),
        source_asset_id=source_asset_id,
        brand_id=brand_id,
        brand_name=brand_name,
        first_seen_timestamp=timestamp,
        last_seen_timestamp=timestamp,
        occurrences=1,
        detection_source=source,
        avg_confidence=confidence,
    )
|
||||||
@@ -26,13 +26,18 @@ from .grpc import (
|
|||||||
WorkerStatus,
|
WorkerStatus,
|
||||||
)
|
)
|
||||||
from .jobs import ChunkJob, ChunkJobStatus, JobStatus, TranscodeJob
|
from .jobs import ChunkJob, ChunkJobStatus, JobStatus, TranscodeJob
|
||||||
|
from .detect_jobs import (
|
||||||
|
DetectJob, DetectJobStatus, RunType, StageCheckpoint,
|
||||||
|
BrandSource, KnownBrand, SourceBrandSighting,
|
||||||
|
)
|
||||||
from .media import AssetStatus, MediaAsset
|
from .media import AssetStatus, MediaAsset
|
||||||
from .presets import BUILTIN_PRESETS, TranscodePreset
|
from .presets import BUILTIN_PRESETS, TranscodePreset
|
||||||
from .detect import DETECT_VIEWS # noqa: F401 — discovered by modelgen generic loader
|
from .detect import DETECT_VIEWS # noqa: F401 — discovered by modelgen generic loader
|
||||||
from .views import ChunkEvent, ChunkOutputFile, PipelineStats, WorkerEvent
|
from .views import ChunkEvent, ChunkOutputFile, PipelineStats, WorkerEvent
|
||||||
|
|
||||||
# Core domain models - generates Django, Pydantic, TypeScript
|
# Core domain models - generates Django, Pydantic, TypeScript
|
||||||
DATACLASSES = [MediaAsset, TranscodePreset, TranscodeJob, ChunkJob]
|
DATACLASSES = [MediaAsset, TranscodePreset, TranscodeJob, ChunkJob,
|
||||||
|
DetectJob, StageCheckpoint, KnownBrand, SourceBrandSighting]
|
||||||
|
|
||||||
# API request/response models - generates TypeScript only (no Django)
|
# API request/response models - generates TypeScript only (no Django)
|
||||||
# WorkerStatus from grpc.py is reused here
|
# WorkerStatus from grpc.py is reused here
|
||||||
@@ -46,7 +51,7 @@ API_MODELS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
# Status enums - included in generated code
|
# Status enums - included in generated code
|
||||||
ENUMS = [AssetStatus, JobStatus, ChunkJobStatus]
|
ENUMS = [AssetStatus, JobStatus, ChunkJobStatus, DetectJobStatus, RunType, BrandSource]
|
||||||
|
|
||||||
# View/event models - generates TypeScript for UI consumption
|
# View/event models - generates TypeScript for UI consumption
|
||||||
VIEWS = [ChunkEvent, WorkerEvent, PipelineStats, ChunkOutputFile]
|
VIEWS = [ChunkEvent, WorkerEvent, PipelineStats, ChunkOutputFile]
|
||||||
|
|||||||
@@ -149,6 +149,64 @@ class JobComplete:
|
|||||||
report: Optional[DetectionReportSummary] = None
|
report: Optional[DetectionReportSummary] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RunContext:
    """Run context injected into all SSE events for grouping.

    ``run_type`` is one of: "initial" | "replay" | "retry".
    """

    run_id: str
    parent_job_id: str
    run_type: str = "initial"
|
||||||
|
|
||||||
|
|
||||||
|
# --- Checkpoint API types ---


@dataclass
class CheckpointInfo:
    """Available checkpoint for a stage."""

    stage: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ReplayRequest:
    """Request to replay pipeline from a specific stage."""

    job_id: str
    start_stage: str
    config_overrides: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ReplayResponse:
    """Result of a replay invocation."""

    status: str
    job_id: str
    start_stage: str
    detections: int = 0
    brands_found: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RetryRequest:
    """Request to queue async retry with different config."""

    job_id: str
    config_overrides: Optional[dict] = None
    start_stage: str = "escalate_vlm"
    schedule_seconds: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RetryResponse:
    """Result of queueing a retry task."""

    status: str
    task_id: str
    job_id: str
|
||||||
|
|
||||||
|
|
||||||
# --- Export lists for modelgen ---
|
# --- Export lists for modelgen ---
|
||||||
|
|
||||||
DETECT_VIEWS = [
|
DETECT_VIEWS = [
|
||||||
@@ -163,4 +221,10 @@ DETECT_VIEWS = [
|
|||||||
LogEvent,
|
LogEvent,
|
||||||
DetectionReportSummary,
|
DetectionReportSummary,
|
||||||
JobComplete,
|
JobComplete,
|
||||||
|
RunContext,
|
||||||
|
CheckpointInfo,
|
||||||
|
ReplayRequest,
|
||||||
|
ReplayResponse,
|
||||||
|
RetryRequest,
|
||||||
|
RetryResponse,
|
||||||
]
|
]
|
||||||
|
|||||||
162
core/schema/models/detect_jobs.py
Normal file
162
core/schema/models/detect_jobs.py
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
"""
|
||||||
|
Detection Job and Checkpoint Schema Definitions
|
||||||
|
|
||||||
|
Source of truth for detection pipeline job tracking and stage checkpoints.
|
||||||
|
Follows the TranscodeJob/ChunkJob pattern.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
|
||||||
|
class DetectJobStatus(str, Enum):
    """Lifecycle states of a detection pipeline job."""

    PENDING = "pending"
    RUNNING = "running"
    PAUSED = "paused"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
||||||
|
|
||||||
|
|
||||||
|
class RunType(str, Enum):
    """How a DetectJob was started: first run, replay, or retry."""

    INITIAL = "initial"
    REPLAY = "replay"
    RETRY = "retry"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class DetectJob:
    """
    A detection pipeline job.

    Each invocation of the pipeline (initial run, replay, retry) creates a DetectJob.
    Jobs for the same source video are linked via parent_job_id.
    """

    id: UUID

    # Input
    source_asset_id: UUID
    video_path: str
    profile_name: str = "soccer_broadcast"

    # Run lineage
    parent_job_id: Optional[UUID] = None  # links all runs for the same source
    run_type: RunType = RunType.INITIAL
    replay_from_stage: Optional[str] = None  # null for initial runs
    config_overrides: Dict[str, Any] = field(default_factory=dict)

    # Status
    status: DetectJobStatus = DetectJobStatus.PENDING
    current_stage: Optional[str] = None
    progress: float = 0.0  # presumably a 0.0–1.0 fraction — TODO confirm against emitters
    error_message: Optional[str] = None

    # Results summary
    total_detections: int = 0
    brands_found: int = 0
    cloud_llm_calls: int = 0
    estimated_cost_usd: float = 0.0

    # Worker tracking
    celery_task_id: Optional[str] = None
    priority: int = 0

    # Timestamps (set by the persistence layer; None until populated)
    created_at: Optional[datetime] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class StageCheckpoint:
    """
    A checkpoint saved after a pipeline stage completes.

    Binary data (frame images, crops) goes to S3/MinIO.
    Everything else (structured state) lives here in Postgres.
    """

    id: UUID
    job_id: UUID
    stage: str
    stage_index: int  # position in NODES list (0-7)

    # S3 reference for binary data only
    frames_prefix: str = ""  # s3 prefix: checkpoints/{job_id}/frames/

    # Frame metadata (non-image fields)
    frames_manifest: Dict[int, str] = field(default_factory=dict)  # seq → s3 key
    frames_meta: List[Dict[str, Any]] = field(default_factory=list)  # sequence, chunk_id, timestamp, hash
    filtered_frame_sequences: List[int] = field(default_factory=list)

    # Detection state (full structured data, not just summaries)
    boxes_by_frame: Dict[str, List[Dict[str, Any]]] = field(default_factory=dict)
    text_candidates: List[Dict[str, Any]] = field(default_factory=list)
    unresolved_candidates: List[Dict[str, Any]] = field(default_factory=list)
    detections: List[Dict[str, Any]] = field(default_factory=list)

    # Pipeline state
    stats: Dict[str, Any] = field(default_factory=dict)
    config_snapshot: Dict[str, Any] = field(default_factory=dict)
    config_overrides: Dict[str, Any] = field(default_factory=dict)

    # Input refs (for replay)
    video_path: str = ""
    profile_name: str = ""

    # Timestamps
    created_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
|
||||||
|
class BrandSource(str, Enum):
    """How a brand was first identified."""

    OCR = "ocr"
    VLM = "local_vlm"
    CLOUD = "cloud_llm"
    MANUAL = "manual"  # user-added via UI
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class KnownBrand:
    """
    A brand discovered or registered in the system.

    Global — not per-source. Accumulates across all pipeline runs.
    Aliases enable fuzzy matching without re-escalating to VLM.
    """

    id: UUID
    canonical_name: str  # normalized display name
    aliases: List[str] = field(default_factory=list)  # known spellings/variants
    first_source: BrandSource = BrandSource.OCR
    total_occurrences: int = 0
    confirmed: bool = False  # manually confirmed by user

    # Timestamps (set by the persistence layer; None until populated)
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SourceBrandSighting:
    """
    A brand seen in a specific source (video/asset).

    Per-source session cache — avoids re-escalating the same brand
    on subsequent frames or re-runs of the same source.
    """

    id: UUID
    source_asset_id: UUID  # the video this sighting belongs to
    brand_id: UUID  # FK to KnownBrand
    brand_name: str  # denormalized for fast lookup
    first_seen_timestamp: float = 0.0
    last_seen_timestamp: float = 0.0
    occurrences: int = 0
    detection_source: BrandSource = BrandSource.OCR
    avg_confidence: float = 0.0  # running mean over all occurrences

    # Timestamp (set by the persistence layer; None until populated)
    created_at: Optional[datetime] = None
|
||||||
@@ -28,7 +28,10 @@ GRPC_PORT=50051
|
|||||||
GRPC_MAX_WORKERS=10
|
GRPC_MAX_WORKERS=10
|
||||||
|
|
||||||
# S3 Storage (MinIO locally, real S3 on AWS)
|
# S3 Storage (MinIO locally, real S3 on AWS)
|
||||||
S3_ENDPOINT_URL=http://minio:9000
|
# In k8s/docker: http://minio:9000
|
||||||
|
# On dev machine (port-forward): http://localhost:9000
|
||||||
|
# On AWS: omit S3_ENDPOINT_URL entirely
|
||||||
|
S3_ENDPOINT_URL=http://localhost:9000
|
||||||
S3_BUCKET_IN=mpr-media-in
|
S3_BUCKET_IN=mpr-media-in
|
||||||
S3_BUCKET_OUT=mpr-media-out
|
S3_BUCKET_OUT=mpr-media-out
|
||||||
AWS_REGION=us-east-1
|
AWS_REGION=us-east-1
|
||||||
@@ -44,7 +47,7 @@ CLOUD_LLM_PROVIDER=groq
|
|||||||
|
|
||||||
# Groq (default, free tier)
|
# Groq (default, free tier)
|
||||||
GROQ_API_KEY=
|
GROQ_API_KEY=
|
||||||
GROQ_MODEL=llama-3.2-90b-vision-preview
|
GROQ_MODEL=meta-llama/llama-4-scout-17b-16e-instruct
|
||||||
|
|
||||||
# Gemini
|
# Gemini
|
||||||
#GEMINI_API_KEY=
|
#GEMINI_API_KEY=
|
||||||
|
|||||||
@@ -35,7 +35,15 @@ docker_build(
|
|||||||
# --- Resources ---
|
# --- Resources ---
|
||||||
|
|
||||||
k8s_resource('redis')
|
k8s_resource('redis')
|
||||||
k8s_resource('fastapi', resource_deps=['redis'])
|
k8s_resource('minio', port_forwards=['9000:9000', '9001:9001'])
|
||||||
|
k8s_resource('fastapi', resource_deps=['redis', 'minio'])
|
||||||
k8s_resource('detection-ui')
|
k8s_resource('detection-ui')
|
||||||
k8s_resource('gateway', resource_deps=['fastapi', 'detection-ui'],
|
k8s_resource('gateway', resource_deps=['fastapi', 'detection-ui'],
|
||||||
port_forwards=['8080:8080'])
|
port_forwards=['8080:8080'])
|
||||||
|
|
||||||
|
# Group uncategorized resources (configmaps, namespace) under infra
|
||||||
|
k8s_resource(
|
||||||
|
objects=['mpr:namespace', 'mpr-config:configmap', 'minio-config:configmap',
|
||||||
|
'envoy-gateway-config:configmap'],
|
||||||
|
new_name='infra',
|
||||||
|
)
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ spec:
|
|||||||
envFrom:
|
envFrom:
|
||||||
- configMapRef:
|
- configMapRef:
|
||||||
name: mpr-config
|
name: mpr-config
|
||||||
|
- configMapRef:
|
||||||
|
name: minio-config
|
||||||
readinessProbe:
|
readinessProbe:
|
||||||
httpGet:
|
httpGet:
|
||||||
path: /health
|
path: /health
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ resources:
|
|||||||
- namespace.yaml
|
- namespace.yaml
|
||||||
- configmap.yaml
|
- configmap.yaml
|
||||||
- redis.yaml
|
- redis.yaml
|
||||||
|
- minio.yaml
|
||||||
- fastapi.yaml
|
- fastapi.yaml
|
||||||
- detection-ui.yaml
|
- detection-ui.yaml
|
||||||
- gateway.yaml
|
- gateway.yaml
|
||||||
|
|||||||
87
ctrl/k8s/base/minio.yaml
Normal file
87
ctrl/k8s/base/minio.yaml
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
# Local S3-compatible object store (MinIO) for the dev/test cluster.
# Credentials are the MinIO defaults — dev only, never for production.
apiVersion: v1
kind: ConfigMap
metadata:
  name: minio-config
  namespace: mpr
data:
  S3_ENDPOINT_URL: http://minio:9000
  S3_BUCKET_IN: mpr-media-in
  S3_BUCKET_OUT: mpr-media-out
  AWS_ACCESS_KEY_ID: minioadmin
  AWS_SECRET_ACCESS_KEY: minioadmin
  AWS_REGION: us-east-1
---
apiVersion: apps/v1
kind: Deployment
metadata:
  name: minio
  namespace: mpr
spec:
  replicas: 1
  selector:
    matchLabels:
      app: minio
  template:
    metadata:
      labels:
        app: minio
    spec:
      containers:
      - name: minio
        image: minio/minio:latest
        args: ["server", "/data", "--console-address", ":9001"]
        ports:
        - containerPort: 9000
          name: api
        - containerPort: 9001
          name: console
        env:
        - name: MINIO_ROOT_USER
          value: minioadmin
        - name: MINIO_ROOT_PASSWORD
          value: minioadmin
        readinessProbe:
          httpGet:
            path: /minio/health/ready
            port: 9000
          initialDelaySeconds: 5
          periodSeconds: 10
        lifecycle:
          # Pre-create the buckets by creating their backing directories
          # after the server starts (MinIO maps buckets to directories).
          postStart:
            exec:
              command:
              - /bin/sh
              - -c
              - |
                sleep 3
                for bucket in mpr-media-in mpr-media-out; do
                  mkdir -p /data/$bucket
                done
        volumeMounts:
        - name: data
          mountPath: /data
        resources:
          requests:
            memory: 128Mi
            cpu: 100m
          limits:
            memory: 512Mi
      volumes:
      # emptyDir: bucket contents are lost when the pod is rescheduled.
      - name: data
        emptyDir: {}
---
apiVersion: v1
kind: Service
metadata:
  name: minio
  namespace: mpr
spec:
  selector:
    app: minio
  ports:
  - port: 9000
    targetPort: 9000
    name: api
  - port: 9001
    targetPort: 9001
    name: console
||||||
14
detect/checkpoint/__init__.py
Normal file
14
detect/checkpoint/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""
|
||||||
|
Stage checkpoint, replay, and retry.
|
||||||
|
|
||||||
|
detect/checkpoint/
|
||||||
|
frames.py — frame image S3 upload/download
|
||||||
|
serializer.py — state ↔ JSON conversion
|
||||||
|
storage.py — checkpoint save/load/list (Postgres + S3)
|
||||||
|
replay.py — replay_from, OverrideProfile
|
||||||
|
tasks.py — retry_candidates Celery task
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .storage import save_checkpoint, load_checkpoint, list_checkpoints
|
||||||
|
from .frames import save_frames, load_frames
|
||||||
|
from .replay import replay_from, OverrideProfile
|
||||||
80
detect/checkpoint/frames.py
Normal file
80
detect/checkpoint/frames.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
"""Frame image storage — save/load to S3/MinIO as JPEGs."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from detect.models import Frame
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
BUCKET = os.environ.get("S3_BUCKET_OUT", "mpr-media-out")
|
||||||
|
CHECKPOINT_PREFIX = "checkpoints"
|
||||||
|
|
||||||
|
|
||||||
|
def save_frames(job_id: str, frames: list[Frame]) -> dict[int, str]:
    """
    Save frame images to S3 as JPEGs.

    Returns manifest: {sequence: s3_key}
    """
    from core.storage.s3 import upload_file

    manifest: dict[int, str] = {}

    for frame in frames:
        key = f"{CHECKPOINT_PREFIX}/{job_id}/frames/{frame.sequence}.jpg"

        # Encode to a temp file first, then upload it and clean up.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
            Image.fromarray(frame.image).save(tmp, format="JPEG", quality=85)
            tmp_path = tmp.name

        try:
            upload_file(tmp_path, BUCKET, key)
        finally:
            os.unlink(tmp_path)

        manifest[frame.sequence] = key

    logger.info("Saved %d frames to s3://%s/%s/%s/frames/",
                len(frames), BUCKET, CHECKPOINT_PREFIX, job_id)
    return manifest
|
||||||
|
|
||||||
|
|
||||||
|
def load_frames(manifest: dict[int, str], frame_metadata: list[dict]) -> list[Frame]:
    """
    Load frame images from S3 and reconstitute Frame objects.

    frame_metadata: list of dicts with sequence, chunk_id, timestamp, perceptual_hash.
    Returns frames sorted by sequence number.
    """
    from core.storage.s3 import download_to_temp

    meta_by_seq = {m["sequence"]: m for m in frame_metadata}
    loaded: list[Frame] = []

    for seq, key in manifest.items():
        tmp_path = download_to_temp(BUCKET, key)
        try:
            pixels = np.array(Image.open(tmp_path).convert("RGB"))
        finally:
            os.unlink(tmp_path)

        # Metadata may be missing for a sequence; fall back to zero values.
        meta = meta_by_seq.get(seq, {})
        loaded.append(Frame(
            sequence=seq,
            chunk_id=meta.get("chunk_id", 0),
            timestamp=meta.get("timestamp", 0.0),
            image=pixels,
            perceptual_hash=meta.get("perceptual_hash", ""),
        ))

    loaded.sort(key=lambda f: f.sequence)
    return loaded
|
||||||
132
detect/checkpoint/replay.py
Normal file
132
detect/checkpoint/replay.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
"""
|
||||||
|
Pipeline replay — re-run from any stage with different config.
|
||||||
|
|
||||||
|
Loads a checkpoint, applies config overrides, builds a subgraph
|
||||||
|
starting from the target stage, and invokes it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from detect import emit
|
||||||
|
from detect.checkpoint import load_checkpoint, list_checkpoints
|
||||||
|
from detect.graph import NODES, build_graph
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OverrideProfile:
    """
    Proxy around a ContentTypeProfile that applies per-stage config overrides.

    Override dict structure:
    {
        "frame_extraction": {"fps": 1.0},
        "scene_filter": {"hamming_threshold": 12},
        "detection": {"confidence_threshold": 0.5},
        "ocr": {"languages": ["en", "es"], "min_confidence": 0.3},
        "resolver": {"fuzzy_threshold": 60},
    }
    """

    def __init__(self, base, overrides: dict):
        self._base = base
        self._overrides = overrides

    def __getattr__(self, name):
        # Anything not explicitly wrapped below is delegated to the base profile.
        return getattr(self._base, name)

    def _patch(self, config, key: str):
        # Apply only overrides whose attribute already exists on the config
        # object; unknown keys are silently ignored.
        for attr, value in self._overrides.get(key, {}).items():
            if hasattr(config, attr):
                setattr(config, attr, value)
        return config

    def frame_extraction_config(self):
        return self._patch(self._base.frame_extraction_config(), "frame_extraction")

    def scene_filter_config(self):
        return self._patch(self._base.scene_filter_config(), "scene_filter")

    def detection_config(self):
        return self._patch(self._base.detection_config(), "detection")

    def ocr_config(self):
        return self._patch(self._base.ocr_config(), "ocr")

    def resolver_config(self):
        return self._patch(self._base.resolver_config(), "resolver")

    def vlm_prompt(self, crop_context):
        # No overridable knobs here — straight pass-through.
        return self._base.vlm_prompt(crop_context)

    def aggregate(self, detections):
        return self._base.aggregate(detections)

    def auxiliary_detections(self, source):
        return self._base.auxiliary_detections(source)
|
||||||
|
|
||||||
|
|
||||||
|
def replay_from(
    job_id: str,
    start_stage: str,
    config_overrides: dict | None = None,
    checkpoint: bool = True,
) -> dict:
    """
    Re-run the pipeline from a chosen stage using its saved checkpoint.

    The checkpoint written by the stage immediately before ``start_stage``
    is loaded, any config overrides are attached to the state, and the
    subgraph from ``start_stage`` onward is compiled and invoked.

    Returns the final state dict.
    """
    if start_stage not in NODES:
        raise ValueError(f"Unknown stage: {start_stage!r}. Options: {NODES}")

    position = NODES.index(start_stage)

    # The very first stage has no predecessor checkpoint to resume from.
    if position == 0:
        raise ValueError("Cannot replay from the first stage — just run the full pipeline")

    previous_stage = NODES[position - 1]

    available = list_checkpoints(job_id)
    if previous_stage not in available:
        raise ValueError(
            f"No checkpoint for stage {previous_stage!r} (job {job_id}). "
            f"Available: {available}"
        )

    logger.info("Replaying job %s from %s (loading checkpoint: %s)",
                job_id, start_stage, previous_stage)

    state = load_checkpoint(job_id, previous_stage)

    if config_overrides:
        state["config_overrides"] = config_overrides

    # Tag all SSE events from this run as a replay of the original job.
    emit.set_run_context(
        run_id=str(uuid.uuid4())[:8],
        parent_job_id=job_id,
        run_type="replay",
    )

    pipeline = build_graph(checkpoint=checkpoint, start_from=start_stage).compile()

    try:
        return pipeline.invoke(state)
    finally:
        emit.clear_run_context()
|
||||||
133
detect/checkpoint/serializer.py
Normal file
133
detect/checkpoint/serializer.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
"""State serialization — DetectState ↔ JSON-compatible dict."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
|
from detect.models import (
|
||||||
|
BoundingBox,
|
||||||
|
BrandDetection,
|
||||||
|
Frame,
|
||||||
|
PipelineStats,
|
||||||
|
TextCandidate,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Serialize helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def serialize_frame_meta(frame: Frame) -> dict:
    """Extract the JSON-safe metadata of a frame (everything but the image)."""
    return {
        "sequence": frame.sequence,
        "chunk_id": frame.chunk_id,
        "timestamp": frame.timestamp,
        "perceptual_hash": frame.perceptual_hash,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_text_candidate(tc: TextCandidate) -> dict:
    """Flatten a TextCandidate: the frame reference becomes its sequence number."""
    return {
        "frame_sequence": tc.frame.sequence,
        "bbox": dataclasses.asdict(tc.bbox),
        "text": tc.text,
        "ocr_confidence": tc.ocr_confidence,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
    """
    Convert DetectState into a JSON-compatible checkpoint dict.

    Frame images are replaced with S3 key references (the manifest);
    TextCandidate.frame references collapse to frame_sequence integers.
    """
    all_frames = state.get("frames", [])
    kept_frames = state.get("filtered_frames", [])

    serialized_boxes = {
        str(seq): [dataclasses.asdict(box) for box in boxes]
        for seq, boxes in state.get("boxes_by_frame", {}).items()
    }

    return {
        "job_id": state.get("job_id", ""),
        "video_path": state.get("video_path", ""),
        "profile_name": state.get("profile_name", ""),
        "config_overrides": state.get("config_overrides", {}),
        # JSON object keys must be strings.
        "frames_manifest": {str(seq): key for seq, key in frames_manifest.items()},
        "frames_meta": [serialize_frame_meta(f) for f in all_frames],
        "filtered_frame_sequences": [f.sequence for f in kept_frames],
        "boxes_by_frame": serialized_boxes,
        "text_candidates": [
            serialize_text_candidate(tc) for tc in state.get("text_candidates", [])
        ],
        "unresolved_candidates": [
            serialize_text_candidate(tc) for tc in state.get("unresolved_candidates", [])
        ],
        "detections": [dataclasses.asdict(d) for d in state.get("detections", [])],
        "stats": dataclasses.asdict(state.get("stats", PipelineStats())),
    }
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Deserialize helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def deserialize_text_candidate(d: dict, frame_map: dict[int, Frame]) -> TextCandidate:
    """Rebuild a TextCandidate, resolving frame_sequence back to a Frame object."""
    return TextCandidate(
        frame=frame_map[d["frame_sequence"]],
        bbox=BoundingBox(**d["bbox"]),
        text=d["text"],
        ocr_confidence=d["ocr_confidence"],
    )
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_state(checkpoint: dict, frames: list[Frame]) -> dict:
    """Reconstitute DetectState from a checkpoint dict plus already-loaded frames."""
    frame_map = {frame.sequence: frame for frame in frames}

    kept = set(checkpoint.get("filtered_frame_sequences", []))

    # JSON stringified the sequence keys; restore them to ints.
    boxes_by_frame = {
        int(seq_str): [BoundingBox(**b) for b in box_dicts]
        for seq_str, box_dicts in checkpoint.get("boxes_by_frame", {}).items()
    }

    return {
        "job_id": checkpoint.get("job_id", ""),
        "video_path": checkpoint.get("video_path", ""),
        "profile_name": checkpoint.get("profile_name", ""),
        "config_overrides": checkpoint.get("config_overrides", {}),
        "frames": frames,
        "filtered_frames": [f for f in frames if f.sequence in kept],
        "boxes_by_frame": boxes_by_frame,
        "text_candidates": [
            deserialize_text_candidate(d, frame_map)
            for d in checkpoint.get("text_candidates", [])
        ],
        "unresolved_candidates": [
            deserialize_text_candidate(d, frame_map)
            for d in checkpoint.get("unresolved_candidates", [])
        ],
        "detections": [BrandDetection(**d) for d in checkpoint.get("detections", [])],
        "stats": PipelineStats(**checkpoint.get("stats", {})),
    }
|
||||||
215
detect/checkpoint/storage.py
Normal file
215
detect/checkpoint/storage.py
Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
"""
|
||||||
|
Checkpoint storage — save/load stage state.
|
||||||
|
|
||||||
|
Binary data (frame images) → S3/MinIO via frames.py
|
||||||
|
Structured data (boxes, detections, stats, config) → Postgres via Django ORM
|
||||||
|
|
||||||
|
Until the Django model is generated by modelgen, checkpoint data is stored
|
||||||
|
as JSON in S3 as a fallback. Once DetectJob/StageCheckpoint models exist,
|
||||||
|
this module switches to Postgres.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from .frames import save_frames, load_frames, BUCKET, CHECKPOINT_PREFIX
|
||||||
|
from .serializer import serialize_state, deserialize_state
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _has_db() -> bool:
|
||||||
|
"""Check if the DB layer is available (Django + models generated by modelgen)."""
|
||||||
|
try:
|
||||||
|
from core.db.detect import get_stage_checkpoint as _
|
||||||
|
# Quick check that the model exists (modelgen may not have run yet)
|
||||||
|
from admin.mpr.media_assets.models import StageCheckpoint as _
|
||||||
|
return True
|
||||||
|
except (ImportError, Exception):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Save
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def save_checkpoint(
    job_id: str,
    stage: str,
    stage_index: int,
    state: dict,
    frames_manifest: dict[int, str] | None = None,
) -> str:
    """
    Persist a stage checkpoint.

    Frame images go to S3 (unless a manifest for already-uploaded frames
    is supplied); structured state goes to Postgres, or to S3 JSON as a
    fallback while the DB models do not exist yet.

    Returns the checkpoint identifier (DB id or S3 key).
    """
    if frames_manifest is None:
        frames_manifest = save_frames(job_id, state.get("frames", []))

    payload = serialize_state(state, frames_manifest)

    if _has_db():
        return _save_to_db(job_id, stage, stage_index, payload)
    return _save_to_s3(job_id, stage, payload)
|
||||||
|
|
||||||
|
|
||||||
|
def _save_to_db(job_id: str, stage: str, stage_index: int, data: dict) -> str:
    """Persist checkpoint structured data to Postgres via core/db."""
    import uuid
    from core.db.detect import save_stage_checkpoint

    # Accept either a string or an already-parsed UUID for the job id.
    job_uuid = uuid.UUID(job_id) if isinstance(job_id, str) else job_id

    checkpoint = save_stage_checkpoint(
        id=uuid.uuid4(),
        job_id=job_uuid,
        stage=stage,
        stage_index=stage_index,
        frames_prefix=f"{CHECKPOINT_PREFIX}/{job_id}/frames/",
        frames_manifest=data.get("frames_manifest", {}),
        frames_meta=data.get("frames_meta", []),
        filtered_frame_sequences=data.get("filtered_frame_sequences", []),
        boxes_by_frame=data.get("boxes_by_frame", {}),
        text_candidates=data.get("text_candidates", []),
        unresolved_candidates=data.get("unresolved_candidates", []),
        detections=data.get("detections", []),
        stats=data.get("stats", {}),
        # NOTE(review): snapshot and overrides currently store the same
        # value — confirm whether config_snapshot should capture the full
        # effective config instead.
        config_snapshot=data.get("config_overrides", {}),
        config_overrides=data.get("config_overrides", {}),
        video_path=data.get("video_path", ""),
        profile_name=data.get("profile_name", ""),
    )

    logger.info("Checkpoint saved to DB: %s/%s (id=%s)", job_id, stage, checkpoint.id)
    return str(checkpoint.id)
|
||||||
|
|
||||||
|
|
||||||
|
def _save_to_s3(job_id: str, stage: str, data: dict) -> str:
    """Fallback: persist the checkpoint as a JSON object in S3 (pre-modelgen)."""
    from core.storage.s3 import upload_file

    key = f"{CHECKPOINT_PREFIX}/{job_id}/stages/{stage}.json"

    # upload_file wants a filesystem path, so stage the JSON via a temp file.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp:
        tmp.write(json.dumps(data, default=str))
        tmp_path = tmp.name

    try:
        upload_file(tmp_path, BUCKET, key)
    finally:
        os.unlink(tmp_path)

    logger.info("Checkpoint saved to S3: s3://%s/%s", BUCKET, key)
    return key
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Load
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def load_checkpoint(job_id: str, stage: str) -> dict:
    """
    Load a stage checkpoint and reconstitute the full DetectState.

    Uses Postgres when available, otherwise the S3 JSON fallback.
    """
    data = _load_from_db(job_id, stage) if _has_db() else _load_from_s3(job_id, stage)

    # JSON round-tripping stringified the manifest keys; restore int sequences.
    manifest = {int(seq): key for seq, key in data.get("frames_manifest", {}).items()}
    frames = load_frames(manifest, data.get("frames_meta", []))

    state = deserialize_state(data, frames)

    logger.info("Checkpoint loaded: %s/%s (%d frames)", job_id, stage, len(frames))
    return state
|
||||||
|
|
||||||
|
|
||||||
|
def _load_from_db(job_id: str, stage: str) -> dict:
    """Fetch checkpoint data from Postgres via core/db."""
    from core.db.detect import get_stage_checkpoint

    cp = get_stage_checkpoint(job_id, stage)

    # Mirror the shape produced by serialize_state / the S3 JSON fallback.
    return {
        "job_id": str(cp.job_id),
        "video_path": cp.video_path,
        "profile_name": cp.profile_name,
        "config_overrides": cp.config_overrides,
        "frames_manifest": cp.frames_manifest,
        "frames_meta": cp.frames_meta,
        "filtered_frame_sequences": cp.filtered_frame_sequences,
        "boxes_by_frame": cp.boxes_by_frame,
        "text_candidates": cp.text_candidates,
        "unresolved_candidates": cp.unresolved_candidates,
        "detections": cp.detections,
        "stats": cp.stats,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _load_from_s3(job_id: str, stage: str) -> dict:
    """Fallback: fetch the checkpoint JSON from S3."""
    from core.storage.s3 import download_to_temp

    key = f"{CHECKPOINT_PREFIX}/{job_id}/stages/{stage}.json"
    tmp_path = download_to_temp(BUCKET, key)
    try:
        with open(tmp_path) as fh:
            return json.load(fh)
    finally:
        # Always reclaim the temp file, even if the JSON is malformed.
        os.unlink(tmp_path)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# List
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def list_checkpoints(job_id: str) -> list[str]:
    """Return the checkpoint stage names available for a job."""
    return _list_from_db(job_id) if _has_db() else _list_from_s3(job_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _list_from_db(job_id: str) -> list[str]:
    """DB-backed listing of checkpoint stages."""
    from core.db.detect import list_stage_checkpoints

    return list_stage_checkpoints(job_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _list_from_s3(job_id: str) -> list[str]:
    """
    S3-backed listing of checkpoint stages.

    Stage names are the basenames (without .json) of the objects under
    the job's stages/ prefix.
    """
    from core.storage.s3 import list_objects

    prefix = f"{CHECKPOINT_PREFIX}/{job_id}/stages/"
    # Comprehension replaces the manual append loop (ruff PERF401).
    return [Path(obj["key"]).stem for obj in list_objects(BUCKET, prefix)]
|
||||||
71
detect/checkpoint/tasks.py
Normal file
71
detect/checkpoint/tasks.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
"""
|
||||||
|
Celery tasks for detection pipeline async operations.
|
||||||
|
|
||||||
|
retry_candidates: re-run VLM/cloud escalation with different config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from celery import shared_task
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@shared_task(bind=True, max_retries=1, default_retry_delay=30)
def retry_candidates(
    self,
    job_id: str,
    config_overrides: dict | None = None,
    start_stage: str = "escalate_vlm",
):
    """
    Retry unresolved candidates with different config.

    Loads the checkpoint from the stage before start_stage,
    applies config overrides (e.g. different cloud provider),
    and runs from start_stage onward.

    Returns a summary dict: status "completed" with detection counts,
    or — once the single Celery retry is exhausted — status "failed"
    with the error string.
    """
    # Imported lazily so the Celery worker can load this module without
    # pulling in the full pipeline at import time.
    from detect.checkpoint.replay import replay_from

    # Short id to correlate the log lines of this retry attempt.
    run_id = str(uuid.uuid4())[:8]
    logger.info("Retry task %s: job=%s, from=%s, overrides=%s",
                run_id, job_id, start_stage, config_overrides)

    try:
        result = replay_from(
            job_id=job_id,
            start_stage=start_stage,
            config_overrides=config_overrides,
        )

        detections = result.get("detections", [])
        report = result.get("report")
        # report may be absent if the replay stopped before compile_report.
        brands_found = len(report.brands) if report else 0

        logger.info("Retry %s complete: %d detections, %d brands",
                    run_id, len(detections), brands_found)

        return {
            "status": "completed",
            "run_id": run_id,
            "job_id": job_id,
            "detections": len(detections),
            "brands_found": brands_found,
        }

    except Exception as e:
        logger.exception("Retry %s failed: %s", run_id, e)

        # Re-queue once (max_retries=1, 30s delay) before reporting failure.
        if self.request.retries < self.max_retries:
            raise self.retry(exc=e)

        return {
            "status": "failed",
            "run_id": run_id,
            "job_id": job_id,
            "error": str(e),
        }
|
||||||
@@ -3,6 +3,9 @@ Event emission helpers for detection pipeline stages.
|
|||||||
|
|
||||||
Single place that knows how to build event payloads.
|
Single place that knows how to build event payloads.
|
||||||
Stages call these instead of constructing dicts or dataclasses directly.
|
Stages call these instead of constructing dicts or dataclasses directly.
|
||||||
|
|
||||||
|
Run context (run_id, parent_job_id) is set once at pipeline start via
|
||||||
|
set_run_context() and automatically injected into all events.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -13,9 +16,33 @@ from datetime import datetime, timezone
|
|||||||
from detect.events import push_detect_event
|
from detect.events import push_detect_event
|
||||||
from detect.models import PipelineStats
|
from detect.models import PipelineStats
|
||||||
|
|
||||||
|
# Module-level run context — set once per pipeline invocation
|
||||||
|
_run_context: dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
def set_run_context(run_id: str = "", parent_job_id: str = "", run_type: str = "initial"):
    """
    Install the run context injected into every subsequent event payload.

    NOTE(review): the context lives in a module-level dict, so concurrent
    pipeline runs in one process would clobber each other — confirm the
    single-run-per-process assumption.
    """
    global _run_context
    _run_context = dict(run_id=run_id, parent_job_id=parent_job_id, run_type=run_type)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_run_context():
    """Reset the run context so later events carry no run/replay fields."""
    global _run_context
    _run_context = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _inject_context(payload: dict) -> dict:
    """Merge the current run context into *payload* (mutates and returns it)."""
    # update() with an empty context is a no-op, so no emptiness guard is needed.
    payload.update(_run_context)
    return payload
|
||||||
|
|
||||||
|
|
||||||
def log(job_id: str | None, stage: str, level: str, msg: str) -> None:
|
def log(job_id: str | None, stage: str, level: str, msg: str) -> None:
|
||||||
"""Emit a log event."""
|
|
||||||
if not job_id:
|
if not job_id:
|
||||||
return
|
return
|
||||||
payload = {
|
payload = {
|
||||||
@@ -24,15 +51,17 @@ def log(job_id: str | None, stage: str, level: str, msg: str) -> None:
|
|||||||
"msg": msg,
|
"msg": msg,
|
||||||
"ts": datetime.now(timezone.utc).isoformat(),
|
"ts": datetime.now(timezone.utc).isoformat(),
|
||||||
}
|
}
|
||||||
|
_inject_context(payload)
|
||||||
push_detect_event(job_id, "log", payload)
|
push_detect_event(job_id, "log", payload)
|
||||||
|
|
||||||
|
|
||||||
def stats(job_id: str | None, **kwargs) -> None:
|
def stats(job_id: str | None, **kwargs) -> None:
|
||||||
"""Emit a stats_update event. Pass only the fields that changed."""
|
|
||||||
if not job_id:
|
if not job_id:
|
||||||
return
|
return
|
||||||
s = PipelineStats(**kwargs)
|
s = PipelineStats(**kwargs)
|
||||||
push_detect_event(job_id, "stats_update", dataclasses.asdict(s))
|
payload = dataclasses.asdict(s)
|
||||||
|
_inject_context(payload)
|
||||||
|
push_detect_event(job_id, "stats_update", payload)
|
||||||
|
|
||||||
|
|
||||||
def frame_update(
|
def frame_update(
|
||||||
@@ -42,7 +71,6 @@ def frame_update(
|
|||||||
jpeg_b64: str,
|
jpeg_b64: str,
|
||||||
boxes: list[dict],
|
boxes: list[dict],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Emit a frame_update event with the image and bounding boxes."""
|
|
||||||
if not job_id:
|
if not job_id:
|
||||||
return
|
return
|
||||||
payload = {
|
payload = {
|
||||||
@@ -51,14 +79,15 @@ def frame_update(
|
|||||||
"jpeg_b64": jpeg_b64,
|
"jpeg_b64": jpeg_b64,
|
||||||
"boxes": boxes,
|
"boxes": boxes,
|
||||||
}
|
}
|
||||||
|
_inject_context(payload)
|
||||||
push_detect_event(job_id, "frame_update", payload)
|
push_detect_event(job_id, "frame_update", payload)
|
||||||
|
|
||||||
|
|
||||||
def graph_update(job_id: str | None, nodes: list[dict]) -> None:
|
def graph_update(job_id: str | None, nodes: list[dict]) -> None:
|
||||||
"""Emit a graph_update event with node states."""
|
|
||||||
if not job_id:
|
if not job_id:
|
||||||
return
|
return
|
||||||
payload = {"nodes": nodes}
|
payload = {"nodes": nodes}
|
||||||
|
_inject_context(payload)
|
||||||
push_detect_event(job_id, "graph_update", payload)
|
push_detect_event(job_id, "graph_update", payload)
|
||||||
|
|
||||||
|
|
||||||
@@ -72,7 +101,6 @@ def detection(
|
|||||||
content_type: str = "",
|
content_type: str = "",
|
||||||
frame_ref: int | None = None,
|
frame_ref: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Emit a brand detection event."""
|
|
||||||
if not job_id:
|
if not job_id:
|
||||||
return
|
return
|
||||||
payload = {
|
payload = {
|
||||||
@@ -84,12 +112,13 @@ def detection(
|
|||||||
"content_type": content_type,
|
"content_type": content_type,
|
||||||
"frame_ref": frame_ref,
|
"frame_ref": frame_ref,
|
||||||
}
|
}
|
||||||
|
_inject_context(payload)
|
||||||
push_detect_event(job_id, "detection", payload)
|
push_detect_event(job_id, "detection", payload)
|
||||||
|
|
||||||
|
|
||||||
def job_complete(job_id: str | None, report: dict) -> None:
|
def job_complete(job_id: str | None, report: dict) -> None:
|
||||||
"""Emit a job_complete event with the final report."""
|
|
||||||
if not job_id:
|
if not job_id:
|
||||||
return
|
return
|
||||||
payload = {"job_id": job_id, "report": report}
|
payload = {"job_id": job_id, "report": report}
|
||||||
|
_inject_context(payload)
|
||||||
push_detect_event(job_id, "job_complete", payload)
|
push_detect_event(job_id, "job_complete", payload)
|
||||||
|
|||||||
125
detect/graph.py
125
detect/graph.py
@@ -42,9 +42,17 @@ NODES = [
|
|||||||
def _get_profile(state: DetectState):
    """Resolve the content profile for this run, applying any config overrides."""
    name = state.get("profile_name", "soccer_broadcast")
    if name != "soccer_broadcast":
        raise ValueError(f"Unknown profile: {name}")
    profile = SoccerBroadcastProfile()

    overrides = state.get("config_overrides")
    if overrides:
        # Imported lazily: replay.py imports this module, so a top-level
        # import here would be circular.
        from detect.checkpoint.replay import OverrideProfile
        profile = OverrideProfile(profile, overrides)

    return profile
||||||
|
|
||||||
|
|
||||||
# Track node states across the pipeline run
|
# Track node states across the pipeline run
|
||||||
_node_states: dict[str, dict[str, str]] = {}
|
_node_states: dict[str, dict[str, str]] = {}
|
||||||
@@ -68,6 +76,18 @@ def _emit_transition(state: DetectState, node: str, status: str):
|
|||||||
# --- Node functions ---
|
# --- Node functions ---
|
||||||
|
|
||||||
def node_extract_frames(state: DetectState) -> dict:
|
def node_extract_frames(state: DetectState) -> dict:
|
||||||
|
# Set run context for initial runs (replays set it in replay_from)
|
||||||
|
job_id = state.get("job_id", "")
|
||||||
|
if job_id and not emit._run_context:
|
||||||
|
emit.set_run_context(run_id=job_id, parent_job_id=job_id, run_type="initial")
|
||||||
|
|
||||||
|
# Load session brands from DB for this source
|
||||||
|
source_asset_id = state.get("source_asset_id")
|
||||||
|
if source_asset_id and not state.get("session_brands"):
|
||||||
|
from detect.stages.brand_resolver import build_session_dict
|
||||||
|
session_brands = build_session_dict(source_asset_id)
|
||||||
|
state["session_brands"] = session_brands
|
||||||
|
|
||||||
_emit_transition(state, "extract_frames", "running")
|
_emit_transition(state, "extract_frames", "running")
|
||||||
|
|
||||||
with trace_node(state, "extract_frames") as span:
|
with trace_node(state, "extract_frames") as span:
|
||||||
@@ -142,13 +162,16 @@ def node_match_brands(state: DetectState) -> dict:
|
|||||||
|
|
||||||
with trace_node(state, "match_brands") as span:
|
with trace_node(state, "match_brands") as span:
|
||||||
profile = _get_profile(state)
|
profile = _get_profile(state)
|
||||||
dictionary = profile.brand_dictionary()
|
|
||||||
resolver_config = profile.resolver_config()
|
resolver_config = profile.resolver_config()
|
||||||
candidates = state.get("text_candidates", [])
|
candidates = state.get("text_candidates", [])
|
||||||
|
session_brands = state.get("session_brands", {})
|
||||||
job_id = state.get("job_id")
|
job_id = state.get("job_id")
|
||||||
|
source_asset_id = state.get("source_asset_id")
|
||||||
|
|
||||||
matched, unresolved = resolve_brands(
|
matched, unresolved = resolve_brands(
|
||||||
candidates, dictionary, resolver_config,
|
candidates, resolver_config,
|
||||||
|
session_brands=session_brands,
|
||||||
|
source_asset_id=source_asset_id,
|
||||||
content_type=profile.name, job_id=job_id,
|
content_type=profile.name, job_id=job_id,
|
||||||
)
|
)
|
||||||
span.set_output({"matched": len(matched), "unresolved": len(unresolved)})
|
span.set_output({"matched": len(matched), "unresolved": len(unresolved)})
|
||||||
@@ -170,6 +193,7 @@ def node_escalate_vlm(state: DetectState) -> dict:
|
|||||||
vlm_prompt_fn=profile.vlm_prompt,
|
vlm_prompt_fn=profile.vlm_prompt,
|
||||||
inference_url=INFERENCE_URL,
|
inference_url=INFERENCE_URL,
|
||||||
content_type=profile.name,
|
content_type=profile.name,
|
||||||
|
source_asset_id=state.get("source_asset_id"),
|
||||||
job_id=job_id,
|
job_id=job_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -202,6 +226,7 @@ def node_escalate_cloud(state: DetectState) -> dict:
|
|||||||
vlm_prompt_fn=profile.vlm_prompt,
|
vlm_prompt_fn=profile.vlm_prompt,
|
||||||
stats=stats,
|
stats=stats,
|
||||||
content_type=profile.name,
|
content_type=profile.name,
|
||||||
|
source_asset_id=state.get("source_asset_id"),
|
||||||
job_id=job_id,
|
job_id=job_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -239,33 +264,87 @@ def node_compile_report(state: DetectState) -> dict:
|
|||||||
return {"report": report}
|
return {"report": report}
|
||||||
|
|
||||||
|
|
||||||
|
# --- Checkpoint wrapper ---
|
||||||
|
|
||||||
|
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
|
||||||
|
_frames_manifest: dict[str, dict[int, str]] = {} # job_id → manifest (cached per job)
|
||||||
|
|
||||||
|
|
||||||
|
def _checkpointing_node(node_name: str, node_fn):
    """
    Wrap a pipeline node so its merged state is checkpointed after it runs.

    Frames are uploaded once, at the extract_frames stage; later stages
    reuse the cached manifest for the same job.
    NOTE(review): _frames_manifest is never pruned, so a long-lived worker
    accumulates one entry per job — confirm whether cleanup is needed.
    """
    stage_index = NODES.index(node_name)

    def wrapper(state: DetectState) -> dict:
        result = node_fn(state)

        job_id = state.get("job_id", "")
        if not job_id:
            # Nothing to key the checkpoint by — skip persistence.
            return result

        from detect.checkpoint import save_checkpoint, save_frames

        # Checkpoint the state as the graph runtime would see it next.
        merged = {**state, **result}

        manifest = _frames_manifest.get(job_id)
        if manifest is None and node_name == "extract_frames":
            manifest = save_frames(job_id, merged.get("frames", []))
            _frames_manifest[job_id] = manifest

        save_checkpoint(job_id, node_name, stage_index, merged, frames_manifest=manifest)
        return result

    # Preserve the wrapped node's name for graph/node introspection.
    wrapper.__name__ = node_fn.__name__
    return wrapper
|
||||||
|
|
||||||
|
|
||||||
# --- Graph construction ---
|
# --- Graph construction ---
|
||||||
|
|
||||||
def build_graph() -> StateGraph:
|
NODE_FUNCTIONS = [
|
||||||
|
("extract_frames", node_extract_frames),
|
||||||
|
("filter_scenes", node_filter_scenes),
|
||||||
|
("detect_objects", node_detect_objects),
|
||||||
|
("run_ocr", node_run_ocr),
|
||||||
|
("match_brands", node_match_brands),
|
||||||
|
("escalate_vlm", node_escalate_vlm),
|
||||||
|
("escalate_cloud", node_escalate_cloud),
|
||||||
|
("compile_report", node_compile_report),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_graph(checkpoint: bool | None = None, start_from: str | None = None) -> StateGraph:
|
||||||
|
"""
|
||||||
|
Build the pipeline graph.
|
||||||
|
|
||||||
|
checkpoint: enable auto-checkpointing (default: MPR_CHECKPOINT env var)
|
||||||
|
start_from: skip nodes before this stage (for replay)
|
||||||
|
"""
|
||||||
|
do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
|
||||||
|
|
||||||
graph = StateGraph(DetectState)
|
graph = StateGraph(DetectState)
|
||||||
|
|
||||||
graph.add_node("extract_frames", node_extract_frames)
|
# Filter to start_from if replaying
|
||||||
graph.add_node("filter_scenes", node_filter_scenes)
|
node_pairs = NODE_FUNCTIONS
|
||||||
graph.add_node("detect_objects", node_detect_objects)
|
if start_from:
|
||||||
graph.add_node("run_ocr", node_run_ocr)
|
start_idx = next(i for i, (name, _) in enumerate(NODE_FUNCTIONS) if name == start_from)
|
||||||
graph.add_node("match_brands", node_match_brands)
|
node_pairs = NODE_FUNCTIONS[start_idx:]
|
||||||
graph.add_node("escalate_vlm", node_escalate_vlm)
|
|
||||||
graph.add_node("escalate_cloud", node_escalate_cloud)
|
|
||||||
graph.add_node("compile_report", node_compile_report)
|
|
||||||
|
|
||||||
graph.set_entry_point("extract_frames")
|
for name, fn in node_pairs:
|
||||||
graph.add_edge("extract_frames", "filter_scenes")
|
wrapped = _checkpointing_node(name, fn) if do_checkpoint else fn
|
||||||
graph.add_edge("filter_scenes", "detect_objects")
|
graph.add_node(name, wrapped)
|
||||||
graph.add_edge("detect_objects", "run_ocr")
|
|
||||||
graph.add_edge("run_ocr", "match_brands")
|
# Wire edges
|
||||||
graph.add_edge("match_brands", "escalate_vlm")
|
entry = node_pairs[0][0]
|
||||||
graph.add_edge("escalate_vlm", "escalate_cloud")
|
graph.set_entry_point(entry)
|
||||||
graph.add_edge("escalate_cloud", "compile_report")
|
|
||||||
graph.add_edge("compile_report", END)
|
for i in range(len(node_pairs) - 1):
|
||||||
|
graph.add_edge(node_pairs[i][0], node_pairs[i + 1][0])
|
||||||
|
|
||||||
|
graph.add_edge(node_pairs[-1][0], END)
|
||||||
|
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
|
|
||||||
def get_pipeline():
|
def get_pipeline(checkpoint: bool | None = None):
|
||||||
"""Return a compiled, runnable pipeline."""
|
"""Return a compiled, runnable pipeline."""
|
||||||
return build_graph().compile()
|
return build_graph(checkpoint=checkpoint).compile()
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from .base import (
|
from .base import (
|
||||||
ContentTypeProfile,
|
ContentTypeProfile,
|
||||||
BrandDictionary,
|
|
||||||
CropContext,
|
CropContext,
|
||||||
DetectionConfig,
|
DetectionConfig,
|
||||||
FrameExtractionConfig,
|
FrameExtractionConfig,
|
||||||
@@ -12,7 +11,6 @@ from .soccer import SoccerBroadcastProfile
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ContentTypeProfile",
|
"ContentTypeProfile",
|
||||||
"BrandDictionary",
|
|
||||||
"CropContext",
|
"CropContext",
|
||||||
"DetectionConfig",
|
"DetectionConfig",
|
||||||
"FrameExtractionConfig",
|
"FrameExtractionConfig",
|
||||||
|
|||||||
@@ -44,12 +44,6 @@ class ResolverConfig:
|
|||||||
fuzzy_threshold: int = 75
|
fuzzy_threshold: int = 75
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BrandDictionary:
|
|
||||||
"""Maps canonical brand name → list of known aliases/spellings."""
|
|
||||||
brands: dict[str, list[str]] = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CropContext:
|
class CropContext:
|
||||||
image: bytes
|
image: bytes
|
||||||
@@ -64,7 +58,6 @@ class ContentTypeProfile(Protocol):
|
|||||||
def scene_filter_config(self) -> SceneFilterConfig: ...
|
def scene_filter_config(self) -> SceneFilterConfig: ...
|
||||||
def detection_config(self) -> DetectionConfig: ...
|
def detection_config(self) -> DetectionConfig: ...
|
||||||
def ocr_config(self) -> OCRConfig: ...
|
def ocr_config(self) -> OCRConfig: ...
|
||||||
def brand_dictionary(self) -> BrandDictionary: ...
|
|
||||||
def resolver_config(self) -> ResolverConfig: ...
|
def resolver_config(self) -> ResolverConfig: ...
|
||||||
def vlm_prompt(self, crop_context: CropContext) -> str: ...
|
def vlm_prompt(self, crop_context: CropContext) -> str: ...
|
||||||
def aggregate(self, detections: list[BrandDetection]) -> DetectionReport: ...
|
def aggregate(self, detections: list[BrandDetection]) -> DetectionReport: ...
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from __future__ import annotations
|
|||||||
from detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats
|
from detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats
|
||||||
|
|
||||||
from .base import (
|
from .base import (
|
||||||
BrandDictionary,
|
|
||||||
CropContext,
|
CropContext,
|
||||||
DetectionConfig,
|
DetectionConfig,
|
||||||
FrameExtractionConfig,
|
FrameExtractionConfig,
|
||||||
@@ -34,22 +33,6 @@ class SoccerBroadcastProfile:
|
|||||||
def ocr_config(self) -> OCRConfig:
|
def ocr_config(self) -> OCRConfig:
|
||||||
return OCRConfig(languages=["en", "es"], min_confidence=0.5)
|
return OCRConfig(languages=["en", "es"], min_confidence=0.5)
|
||||||
|
|
||||||
def brand_dictionary(self) -> BrandDictionary:
|
|
||||||
return BrandDictionary(brands={
|
|
||||||
"Nike": ["nike", "NIKE", "swoosh"],
|
|
||||||
"Adidas": ["adidas", "ADIDAS", "adi"],
|
|
||||||
"Puma": ["puma", "PUMA"],
|
|
||||||
"Emirates": ["emirates", "fly emirates", "EMIRATES"],
|
|
||||||
"Coca-Cola": ["coca-cola", "coca cola", "coke", "COCA-COLA"],
|
|
||||||
"Pepsi": ["pepsi", "PEPSI"],
|
|
||||||
"Mastercard": ["mastercard", "MASTERCARD"],
|
|
||||||
"Heineken": ["heineken", "HEINEKEN"],
|
|
||||||
"Santander": ["santander", "SANTANDER"],
|
|
||||||
"Gazprom": ["gazprom", "GAZPROM"],
|
|
||||||
"Qatar Airways": ["qatar airways", "QATAR AIRWAYS"],
|
|
||||||
"Lay's": ["lays", "lay's", "LAYS", "LAY'S"],
|
|
||||||
})
|
|
||||||
|
|
||||||
def resolver_config(self) -> ResolverConfig:
|
def resolver_config(self) -> ResolverConfig:
|
||||||
return ResolverConfig(fuzzy_threshold=75)
|
return ResolverConfig(fuzzy_threshold=75)
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from __future__ import annotations
|
|||||||
from detect.models import BrandDetection, DetectionReport
|
from detect.models import BrandDetection, DetectionReport
|
||||||
|
|
||||||
from .base import (
|
from .base import (
|
||||||
BrandDictionary,
|
|
||||||
CropContext,
|
CropContext,
|
||||||
DetectionConfig,
|
DetectionConfig,
|
||||||
FrameExtractionConfig,
|
FrameExtractionConfig,
|
||||||
@@ -30,9 +29,6 @@ class NewsBroadcastProfile:
|
|||||||
def ocr_config(self) -> OCRConfig:
|
def ocr_config(self) -> OCRConfig:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def brand_dictionary(self) -> BrandDictionary:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def resolver_config(self) -> ResolverConfig:
|
def resolver_config(self) -> ResolverConfig:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -61,9 +57,6 @@ class AdvertisingProfile:
|
|||||||
def ocr_config(self) -> OCRConfig:
|
def ocr_config(self) -> OCRConfig:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def brand_dictionary(self) -> BrandDictionary:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def resolver_config(self) -> ResolverConfig:
|
def resolver_config(self) -> ResolverConfig:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -92,9 +85,6 @@ class TranscriptProfile:
|
|||||||
def ocr_config(self) -> OCRConfig:
|
def ocr_config(self) -> OCRConfig:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def brand_dictionary(self) -> BrandDictionary:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def resolver_config(self) -> ResolverConfig:
|
def resolver_config(self) -> ResolverConfig:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,17 @@
|
|||||||
"""
|
"""
|
||||||
Stage 5 — Brand Resolver
|
Stage 5 — Brand Resolver (discovery mode)
|
||||||
|
|
||||||
Matches OCR text against the profile's brand dictionary.
|
Discovery-first brand matching. No static dictionary — all brands live in the DB.
|
||||||
Uses exact matching first, then fuzzy matching (rapidfuzz) as fallback.
|
|
||||||
Emits detection events for confirmed brands.
|
Flow:
|
||||||
|
1. Check session sightings first (brands already seen in this source)
|
||||||
|
2. Check global known brands (accumulated across all runs)
|
||||||
|
3. Unresolved candidates → escalate to VLM/cloud
|
||||||
|
4. Confirmed brands get added to DB for future runs
|
||||||
|
|
||||||
|
The resolver is an enricher, not a gatekeeper. Every OCR text candidate
|
||||||
|
passes through — the question is whether we can resolve it cheaply (DB lookup)
|
||||||
|
or need to escalate (VLM/cloud).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -14,99 +22,199 @@ from rapidfuzz import fuzz
|
|||||||
|
|
||||||
from detect import emit
|
from detect import emit
|
||||||
from detect.models import BrandDetection, TextCandidate
|
from detect.models import BrandDetection, TextCandidate
|
||||||
from detect.profiles.base import BrandDictionary, ResolverConfig
|
from detect.profiles.base import ResolverConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _normalize(text: str) -> str:
|
def _normalize(text: str) -> str:
|
||||||
"""Normalize text for matching."""
|
|
||||||
return text.strip().lower()
|
return text.strip().lower()
|
||||||
|
|
||||||
|
|
||||||
def _exact_match(text: str, dictionary: BrandDictionary) -> str | None:
|
def _has_db() -> bool:
|
||||||
"""Try exact match against all aliases."""
|
try:
|
||||||
|
from core.db.detect import find_brand_by_text as _
|
||||||
|
from admin.mpr.media_assets.models import KnownBrand as _
|
||||||
|
return True
|
||||||
|
except (ImportError, Exception):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _match_session(text: str, session_brands: dict[str, str]) -> str | None:
|
||||||
|
"""
|
||||||
|
Check against session brands (already seen in this source).
|
||||||
|
|
||||||
|
session_brands: {normalized_name: canonical_name, ...}
|
||||||
|
Includes aliases.
|
||||||
|
"""
|
||||||
normalized = _normalize(text)
|
normalized = _normalize(text)
|
||||||
for canonical, aliases in dictionary.brands.items():
|
return session_brands.get(normalized)
|
||||||
if normalized == _normalize(canonical):
|
|
||||||
return canonical
|
|
||||||
for alias in aliases:
|
|
||||||
if normalized == _normalize(alias):
|
|
||||||
return canonical
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _fuzzy_match(text: str, dictionary: BrandDictionary, threshold: int) -> tuple[str | None, int]:
|
def _match_known(text: str, threshold: int) -> tuple[str | None, str | None]:
|
||||||
"""Try fuzzy match, return (brand, score) or (None, 0)."""
|
"""
|
||||||
|
Check against global known brands in DB.
|
||||||
|
|
||||||
|
Returns (canonical_name, brand_id) or (None, None).
|
||||||
|
"""
|
||||||
|
if not _has_db():
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
from core.db.detect import find_brand_by_text
|
||||||
|
brand = find_brand_by_text(text)
|
||||||
|
if brand:
|
||||||
|
return brand.canonical_name, str(brand.id)
|
||||||
|
|
||||||
|
# Fuzzy match against all known brands
|
||||||
|
from core.db.detect import list_all_brands
|
||||||
|
all_brands = list_all_brands()
|
||||||
|
|
||||||
normalized = _normalize(text)
|
normalized = _normalize(text)
|
||||||
best_brand = None
|
best_brand = None
|
||||||
best_score = 0
|
best_score = 0
|
||||||
|
|
||||||
for canonical, aliases in dictionary.brands.items():
|
for known in all_brands:
|
||||||
all_names = [canonical] + aliases
|
names = [known.canonical_name] + (known.aliases or [])
|
||||||
for name in all_names:
|
for name in names:
|
||||||
score = fuzz.ratio(normalized, _normalize(name))
|
score = fuzz.ratio(normalized, _normalize(name))
|
||||||
if score > best_score and score >= threshold:
|
if score > best_score and score >= threshold:
|
||||||
best_score = score
|
best_score = score
|
||||||
best_brand = canonical
|
best_brand = known
|
||||||
|
|
||||||
return best_brand, best_score
|
if best_brand:
|
||||||
|
return best_brand.canonical_name, str(best_brand.id)
|
||||||
|
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
def _register_brand(canonical_name: str, source: str) -> str | None:
|
||||||
|
"""Register a newly discovered brand in the DB. Returns brand_id."""
|
||||||
|
if not _has_db():
|
||||||
|
return None
|
||||||
|
|
||||||
|
from core.db.detect import get_or_create_brand
|
||||||
|
brand, created = get_or_create_brand(canonical_name, source=source)
|
||||||
|
if created:
|
||||||
|
logger.info("New brand discovered: %s (source=%s)", canonical_name, source)
|
||||||
|
return str(brand.id)
|
||||||
|
|
||||||
|
|
||||||
|
def _record_sighting(source_asset_id: str | None, brand_id: str,
|
||||||
|
brand_name: str, timestamp: float,
|
||||||
|
confidence: float, source: str):
|
||||||
|
"""Record a brand sighting for this source."""
|
||||||
|
if not _has_db() or not source_asset_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
from core.db.detect import record_sighting
|
||||||
|
import uuid
|
||||||
|
asset_id = uuid.UUID(source_asset_id) if isinstance(source_asset_id, str) else source_asset_id
|
||||||
|
brand_uuid = uuid.UUID(brand_id) if isinstance(brand_id, str) else brand_id
|
||||||
|
record_sighting(asset_id, brand_uuid, brand_name, timestamp, confidence, source)
|
||||||
|
|
||||||
|
|
||||||
|
def build_session_dict(source_asset_id: str | None) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
Load session brands from DB for this source.
|
||||||
|
|
||||||
|
Returns {normalized_name: canonical_name, ...} including aliases.
|
||||||
|
"""
|
||||||
|
if not _has_db() or not source_asset_id:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
from core.db.detect import get_source_sightings
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
asset_id = uuid.UUID(source_asset_id) if isinstance(source_asset_id, str) else source_asset_id
|
||||||
|
sightings = get_source_sightings(asset_id)
|
||||||
|
|
||||||
|
session = {}
|
||||||
|
for s in sightings:
|
||||||
|
canonical = s.brand_name
|
||||||
|
session[_normalize(canonical)] = canonical
|
||||||
|
|
||||||
|
# Also load aliases from KnownBrand for each sighted brand
|
||||||
|
if _has_db():
|
||||||
|
from core.db.detect import list_all_brands
|
||||||
|
all_brands = list_all_brands()
|
||||||
|
sighted_names = {s.brand_name for s in sightings}
|
||||||
|
for brand in all_brands:
|
||||||
|
if brand.canonical_name in sighted_names:
|
||||||
|
for alias in (brand.aliases or []):
|
||||||
|
session[_normalize(alias)] = brand.canonical_name
|
||||||
|
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
def resolve_brands(
|
def resolve_brands(
|
||||||
candidates: list[TextCandidate],
|
candidates: list[TextCandidate],
|
||||||
dictionary: BrandDictionary,
|
|
||||||
config: ResolverConfig,
|
config: ResolverConfig,
|
||||||
|
session_brands: dict[str, str] | None = None,
|
||||||
|
source_asset_id: str | None = None,
|
||||||
content_type: str = "",
|
content_type: str = "",
|
||||||
job_id: str | None = None,
|
job_id: str | None = None,
|
||||||
) -> tuple[list[BrandDetection], list[TextCandidate]]:
|
) -> tuple[list[BrandDetection], list[TextCandidate]]:
|
||||||
"""
|
"""
|
||||||
Match text candidates against the brand dictionary.
|
Match text candidates against known brands (session → global → unresolved).
|
||||||
|
|
||||||
Returns:
|
session_brands: pre-loaded session dict (from build_session_dict)
|
||||||
- matched: list of BrandDetection for confirmed brands
|
source_asset_id: for recording new sightings in DB
|
||||||
- unresolved: list of TextCandidate that couldn't be matched
|
|
||||||
"""
|
"""
|
||||||
|
if session_brands is None:
|
||||||
|
session_brands = {}
|
||||||
|
|
||||||
emit.log(job_id, "BrandResolver", "INFO",
|
emit.log(job_id, "BrandResolver", "INFO",
|
||||||
f"Matching {len(candidates)} candidates against "
|
f"Resolving {len(candidates)} candidates "
|
||||||
f"{len(dictionary.brands)} brands (fuzzy_threshold={config.fuzzy_threshold})")
|
f"(session={len(session_brands)} brands, fuzzy={config.fuzzy_threshold})")
|
||||||
|
|
||||||
matched: list[BrandDetection] = []
|
matched: list[BrandDetection] = []
|
||||||
unresolved: list[TextCandidate] = []
|
unresolved: list[TextCandidate] = []
|
||||||
exact_count = 0
|
session_hits = 0
|
||||||
fuzzy_count = 0
|
known_hits = 0
|
||||||
|
|
||||||
for candidate in candidates:
|
for candidate in candidates:
|
||||||
# Try exact match first
|
text = candidate.text
|
||||||
brand = _exact_match(candidate.text, dictionary)
|
brand_name = None
|
||||||
source = "ocr"
|
brand_id = None
|
||||||
|
match_source = "ocr"
|
||||||
|
|
||||||
if brand:
|
# 1. Check session (cheapest — in-memory dict)
|
||||||
exact_count += 1
|
brand_name = _match_session(text, session_brands)
|
||||||
|
if brand_name:
|
||||||
|
session_hits += 1
|
||||||
else:
|
else:
|
||||||
# Try fuzzy match
|
# 2. Check global known brands (DB query + fuzzy)
|
||||||
brand, score = _fuzzy_match(candidate.text, dictionary, config.fuzzy_threshold)
|
brand_name, brand_id = _match_known(text, config.fuzzy_threshold)
|
||||||
if brand:
|
if brand_name:
|
||||||
fuzzy_count += 1
|
known_hits += 1
|
||||||
|
# Add to session for subsequent candidates in this run
|
||||||
|
session_brands[_normalize(brand_name)] = brand_name
|
||||||
|
|
||||||
if brand:
|
if brand_name:
|
||||||
detection = BrandDetection(
|
detection = BrandDetection(
|
||||||
brand=brand,
|
brand=brand_name,
|
||||||
timestamp=candidate.frame.timestamp,
|
timestamp=candidate.frame.timestamp,
|
||||||
duration=0.5,
|
duration=0.5,
|
||||||
confidence=candidate.ocr_confidence,
|
confidence=candidate.ocr_confidence,
|
||||||
source=source,
|
source=match_source,
|
||||||
bbox=candidate.bbox,
|
bbox=candidate.bbox,
|
||||||
frame_ref=candidate.frame.sequence,
|
frame_ref=candidate.frame.sequence,
|
||||||
content_type=content_type,
|
content_type=content_type,
|
||||||
)
|
)
|
||||||
matched.append(detection)
|
matched.append(detection)
|
||||||
|
|
||||||
|
# Record sighting in DB
|
||||||
|
if brand_id:
|
||||||
|
_record_sighting(
|
||||||
|
source_asset_id, brand_id, brand_name,
|
||||||
|
candidate.frame.timestamp, candidate.ocr_confidence, match_source,
|
||||||
|
)
|
||||||
|
|
||||||
emit.detection(
|
emit.detection(
|
||||||
job_id,
|
job_id,
|
||||||
brand=brand,
|
brand=brand_name,
|
||||||
confidence=candidate.ocr_confidence,
|
confidence=candidate.ocr_confidence,
|
||||||
source=source,
|
source=match_source,
|
||||||
timestamp=candidate.frame.timestamp,
|
timestamp=candidate.frame.timestamp,
|
||||||
content_type=content_type,
|
content_type=content_type,
|
||||||
frame_ref=candidate.frame.sequence,
|
frame_ref=candidate.frame.sequence,
|
||||||
@@ -115,7 +223,7 @@ def resolve_brands(
|
|||||||
unresolved.append(candidate)
|
unresolved.append(candidate)
|
||||||
|
|
||||||
emit.log(job_id, "BrandResolver", "INFO",
|
emit.log(job_id, "BrandResolver", "INFO",
|
||||||
f"Exact: {exact_count}, Fuzzy: {fuzzy_count}, "
|
f"Session: {session_hits}, Known: {known_hits}, "
|
||||||
f"Unresolved: {len(unresolved)} → escalating to VLM")
|
f"Unresolved: {len(unresolved)} → escalating")
|
||||||
|
|
||||||
return matched, unresolved
|
return matched, unresolved
|
||||||
|
|||||||
@@ -27,6 +27,18 @@ logger = logging.getLogger(__name__)
|
|||||||
ESTIMATED_TOKENS_PER_CROP = 500
|
ESTIMATED_TOKENS_PER_CROP = 500
|
||||||
|
|
||||||
|
|
||||||
|
def _register_discovered_brand(brand: str, source_asset_id: str | None,
|
||||||
|
timestamp: float, confidence: float):
|
||||||
|
"""Register a cloud-confirmed brand in the DB."""
|
||||||
|
try:
|
||||||
|
from detect.stages.brand_resolver import _register_brand, _record_sighting
|
||||||
|
brand_id = _register_brand(brand, "cloud_llm")
|
||||||
|
if brand_id and source_asset_id:
|
||||||
|
_record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, "cloud_llm")
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Failed to register brand %s: %s", brand, e)
|
||||||
|
|
||||||
|
|
||||||
def _encode_crop(crop: np.ndarray) -> str:
|
def _encode_crop(crop: np.ndarray) -> str:
|
||||||
img = Image.fromarray(crop)
|
img = Image.fromarray(crop)
|
||||||
buf = io.BytesIO()
|
buf = io.BytesIO()
|
||||||
@@ -84,6 +96,7 @@ def escalate_cloud(
|
|||||||
stats: PipelineStats,
|
stats: PipelineStats,
|
||||||
min_confidence: float = 0.4,
|
min_confidence: float = 0.4,
|
||||||
content_type: str = "",
|
content_type: str = "",
|
||||||
|
source_asset_id: str | None = None,
|
||||||
job_id: str | None = None,
|
job_id: str | None = None,
|
||||||
) -> list[BrandDetection]:
|
) -> list[BrandDetection]:
|
||||||
"""
|
"""
|
||||||
@@ -158,6 +171,10 @@ def escalate_cloud(
|
|||||||
frame_ref=candidate.frame.sequence,
|
frame_ref=candidate.frame.sequence,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Register newly discovered brand in DB
|
||||||
|
_register_discovered_brand(brand, source_asset_id,
|
||||||
|
candidate.frame.timestamp, confidence)
|
||||||
|
|
||||||
stats.estimated_cloud_cost_usd += total_cost
|
stats.estimated_cloud_cost_usd += total_cost
|
||||||
stats.regions_escalated_to_cloud_llm = len(candidates)
|
stats.regions_escalated_to_cloud_llm = len(candidates)
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,18 @@ from detect.profiles.base import CropContext
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _register_discovered_brand(brand: str, source_asset_id: str | None,
|
||||||
|
timestamp: float, confidence: float, source: str):
|
||||||
|
"""Register a VLM-confirmed brand in the DB."""
|
||||||
|
try:
|
||||||
|
from detect.stages.brand_resolver import _register_brand, _record_sighting
|
||||||
|
brand_id = _register_brand(brand, source)
|
||||||
|
if brand_id and source_asset_id:
|
||||||
|
_record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, source)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Failed to register brand %s: %s", brand, e)
|
||||||
|
|
||||||
|
|
||||||
def _crop_image(candidate: TextCandidate) -> np.ndarray:
|
def _crop_image(candidate: TextCandidate) -> np.ndarray:
|
||||||
frame = candidate.frame
|
frame = candidate.frame
|
||||||
box = candidate.bbox
|
box = candidate.bbox
|
||||||
@@ -36,6 +48,7 @@ def escalate_vlm(
|
|||||||
inference_url: str | None = None,
|
inference_url: str | None = None,
|
||||||
min_confidence: float = 0.5,
|
min_confidence: float = 0.5,
|
||||||
content_type: str = "",
|
content_type: str = "",
|
||||||
|
source_asset_id: str | None = None,
|
||||||
job_id: str | None = None,
|
job_id: str | None = None,
|
||||||
) -> tuple[list[BrandDetection], list[TextCandidate]]:
|
) -> tuple[list[BrandDetection], list[TextCandidate]]:
|
||||||
"""
|
"""
|
||||||
@@ -107,6 +120,10 @@ def escalate_vlm(
|
|||||||
frame_ref=candidate.frame.sequence,
|
frame_ref=candidate.frame.sequence,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Register newly discovered brand in DB
|
||||||
|
_register_discovered_brand(brand, source_asset_id,
|
||||||
|
candidate.frame.timestamp, confidence, "local_vlm")
|
||||||
|
|
||||||
logger.debug("VLM matched: %s (%.2f) — %s", brand, confidence, reasoning)
|
logger.debug("VLM matched: %s (%.2f) — %s", brand, confidence, reasoning)
|
||||||
else:
|
else:
|
||||||
still_unresolved.append(candidate)
|
still_unresolved.append(candidate)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ class DetectState(TypedDict, total=False):
|
|||||||
video_path: str
|
video_path: str
|
||||||
job_id: str
|
job_id: str
|
||||||
profile_name: str
|
profile_name: str
|
||||||
|
source_asset_id: str # UUID of the source MediaAsset
|
||||||
|
|
||||||
# Stage outputs
|
# Stage outputs
|
||||||
frames: list[Frame]
|
frames: list[Frame]
|
||||||
@@ -27,5 +28,11 @@ class DetectState(TypedDict, total=False):
|
|||||||
detections: list[BrandDetection]
|
detections: list[BrandDetection]
|
||||||
report: DetectionReport
|
report: DetectionReport
|
||||||
|
|
||||||
|
# Session brands (accumulated during the run, persisted to DB)
|
||||||
|
session_brands: dict # {normalized_name: canonical_name}
|
||||||
|
|
||||||
# Running stats (updated by each stage)
|
# Running stats (updated by each stage)
|
||||||
stats: PipelineStats
|
stats: PipelineStats
|
||||||
|
|
||||||
|
# Config overrides for replay (applied via OverrideProfile)
|
||||||
|
config_overrides: dict
|
||||||
|
|||||||
@@ -31,8 +31,16 @@ def ts():
|
|||||||
return datetime.now(timezone.utc).isoformat()
|
return datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
RUN_CONTEXT = {}
|
||||||
|
|
||||||
|
|
||||||
|
def set_run_context(run_id: str, parent_job_id: str, run_type: str = "initial"):
|
||||||
|
RUN_CONTEXT.update({"run_id": run_id, "parent_job_id": parent_job_id, "run_type": run_type})
|
||||||
|
|
||||||
|
|
||||||
def push(r, key, event):
|
def push(r, key, event):
|
||||||
event["ts"] = event.get("ts", ts())
|
event["ts"] = event.get("ts", ts())
|
||||||
|
event.update(RUN_CONTEXT)
|
||||||
r.rpush(key, json.dumps(event))
|
r.rpush(key, json.dumps(event))
|
||||||
return event
|
return event
|
||||||
|
|
||||||
@@ -85,7 +93,11 @@ def main():
|
|||||||
r.delete(key)
|
r.delete(key)
|
||||||
delay = args.delay
|
delay = args.delay
|
||||||
|
|
||||||
|
run_id = f"{args.job[:8]}-r1"
|
||||||
|
set_run_context(run_id=run_id, parent_job_id=args.job, run_type="initial")
|
||||||
|
|
||||||
logger.info("Full escalation pipeline simulation → %s", key)
|
logger.info("Full escalation pipeline simulation → %s", key)
|
||||||
|
logger.info("Run: %s (parent: %s)", run_id, args.job)
|
||||||
logger.info("Open: http://mpr.local.ar/detection/?job=%s", args.job)
|
logger.info("Open: http://mpr.local.ar/detection/?job=%s", args.job)
|
||||||
input("\nPress Enter to start...")
|
input("\nPress Enter to start...")
|
||||||
|
|
||||||
|
|||||||
123
tests/detect/manual/test_replay.py
Normal file
123
tests/detect/manual/test_replay.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test checkpoint + replay flow end-to-end.
|
||||||
|
|
||||||
|
1. Runs the pipeline with checkpointing enabled on a test video
|
||||||
|
2. Lists available checkpoints
|
||||||
|
3. Replays from run_ocr with different config
|
||||||
|
4. Compares detection counts
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
MPR_CHECKPOINT=1 INFERENCE_URL=http://mcrndeb:8000 python tests/detect/manual/test_replay.py [--job JOB_ID]
|
||||||
|
|
||||||
|
Requires: inference server running, MinIO/S3 running, test video available
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Load ctrl/.env
|
||||||
|
env_file = Path(__file__).resolve().parents[3] / "ctrl" / ".env"
|
||||||
|
if env_file.exists():
|
||||||
|
for line in env_file.read_text().splitlines():
|
||||||
|
line = line.strip()
|
||||||
|
if line and not line.startswith("#") and "=" in line:
|
||||||
|
key, _, value = line.partition("=")
|
||||||
|
os.environ.setdefault(key.strip(), value.strip())
|
||||||
|
|
||||||
|
sys.path.insert(0, ".")
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s")
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Force checkpointing on
|
||||||
|
os.environ["MPR_CHECKPOINT"] = "1"
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
import time
|
||||||
|
default_job = f"replay-{int(time.time()) % 100000}"
|
||||||
|
parser.add_argument("--job", default=default_job)
|
||||||
|
parser.add_argument("--port", type=int, default=6382)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Override Redis to localhost (ctrl/.env has k8s hostname)
|
||||||
|
os.environ["REDIS_URL"] = f"redis://localhost:{args.port}/0"
|
||||||
|
|
||||||
|
from detect.graph import get_pipeline, NODES
|
||||||
|
from detect.checkpoint import list_checkpoints
|
||||||
|
from detect.checkpoint import replay_from
|
||||||
|
from detect.state import DetectState
|
||||||
|
|
||||||
|
VIDEO = "media/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e/chunk_0000.mp4"
|
||||||
|
|
||||||
|
logger.info("Job: %s", args.job)
|
||||||
|
logger.info("Checkpoint: enabled")
|
||||||
|
logger.info("Video: %s", VIDEO)
|
||||||
|
logger.info("Open: http://mpr.local.ar/detection/?job=%s", args.job)
|
||||||
|
input("\nPress Enter to run initial pipeline...")
|
||||||
|
|
||||||
|
# --- Initial run ---
|
||||||
|
pipeline = get_pipeline(checkpoint=True)
|
||||||
|
initial_state = DetectState(
|
||||||
|
video_path=VIDEO,
|
||||||
|
job_id=args.job,
|
||||||
|
profile_name="soccer_broadcast",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Running initial pipeline...")
|
||||||
|
result = pipeline.invoke(initial_state)
|
||||||
|
|
||||||
|
detections = result.get("detections", [])
|
||||||
|
report = result.get("report")
|
||||||
|
logger.info("Initial run: %d detections, %d brands",
|
||||||
|
len(detections), len(report.brands) if report else 0)
|
||||||
|
|
||||||
|
# --- List checkpoints ---
|
||||||
|
stages = list_checkpoints(args.job)
|
||||||
|
logger.info("Available checkpoints: %s", stages)
|
||||||
|
|
||||||
|
if "detect_objects" not in stages:
|
||||||
|
logger.error("Expected checkpoint for detect_objects — aborting replay test")
|
||||||
|
return
|
||||||
|
|
||||||
|
input("\nPress Enter to replay from run_ocr with different config...")
|
||||||
|
|
||||||
|
# --- Replay with different OCR config ---
|
||||||
|
overrides = {"ocr": {"min_confidence": 0.3, "languages": ["en", "es"]}}
|
||||||
|
logger.info("Replaying from run_ocr with overrides: %s", overrides)
|
||||||
|
|
||||||
|
replay_result = replay_from(
|
||||||
|
job_id=args.job,
|
||||||
|
start_stage="run_ocr",
|
||||||
|
config_overrides=overrides,
|
||||||
|
)
|
||||||
|
|
||||||
|
replay_detections = replay_result.get("detections", [])
|
||||||
|
replay_report = replay_result.get("report")
|
||||||
|
logger.info("Replay run: %d detections, %d brands",
|
||||||
|
len(replay_detections),
|
||||||
|
len(replay_report.brands) if replay_report else 0)
|
||||||
|
|
||||||
|
# --- Compare ---
|
||||||
|
logger.info("--- Comparison ---")
|
||||||
|
logger.info("Initial: %d detections", len(detections))
|
||||||
|
logger.info("Replay: %d detections (min_confidence 0.5 → 0.3)", len(replay_detections))
|
||||||
|
|
||||||
|
diff = len(replay_detections) - len(detections)
|
||||||
|
if diff > 0:
|
||||||
|
logger.info("Replay found %d more detections with lower threshold", diff)
|
||||||
|
elif diff == 0:
|
||||||
|
logger.info("Same count — threshold change didn't affect this video")
|
||||||
|
else:
|
||||||
|
logger.warning("Replay found fewer detections — unexpected")
|
||||||
|
|
||||||
|
logger.info("Done.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,20 +1,13 @@
|
|||||||
"""Tests for BrandResolver stage."""
|
"""Tests for BrandResolver stage (discovery mode)."""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from detect.models import BoundingBox, Frame, TextCandidate
|
from detect.models import BoundingBox, Frame, TextCandidate
|
||||||
from detect.profiles.base import BrandDictionary, ResolverConfig
|
from detect.profiles.base import ResolverConfig
|
||||||
from detect.stages.brand_resolver import resolve_brands, _exact_match, _fuzzy_match
|
from detect.stages.brand_resolver import resolve_brands, _normalize, _match_session
|
||||||
|
|
||||||
|
|
||||||
DICTIONARY = BrandDictionary(brands={
|
|
||||||
"Nike": ["nike", "NIKE", "swoosh"],
|
|
||||||
"Adidas": ["adidas", "ADIDAS"],
|
|
||||||
"Coca-Cola": ["coca-cola", "coca cola", "coke", "COCA-COLA"],
|
|
||||||
"Emirates": ["emirates", "fly emirates", "EMIRATES"],
|
|
||||||
})
|
|
||||||
|
|
||||||
CONFIG = ResolverConfig(fuzzy_threshold=75)
|
CONFIG = ResolverConfig(fuzzy_threshold=75)
|
||||||
|
|
||||||
|
|
||||||
@@ -25,57 +18,76 @@ def _make_candidate(text: str, confidence: float = 0.9) -> TextCandidate:
|
|||||||
return TextCandidate(frame=dummy_frame, bbox=dummy_box, text=text, ocr_confidence=confidence)
|
return TextCandidate(frame=dummy_frame, bbox=dummy_box, text=text, ocr_confidence=confidence)
|
||||||
|
|
||||||
|
|
||||||
def test_exact_match():
|
def test_session_match():
|
||||||
assert _exact_match("Nike", DICTIONARY) == "Nike"
|
session = {"nike": "Nike", "fly emirates": "Emirates"}
|
||||||
assert _exact_match("nike", DICTIONARY) == "Nike"
|
assert _match_session("Nike", session) == "Nike"
|
||||||
assert _exact_match("COCA-COLA", DICTIONARY) == "Coca-Cola"
|
assert _match_session("nike", session) == "Nike"
|
||||||
assert _exact_match("fly emirates", DICTIONARY) == "Emirates"
|
assert _match_session("FLY EMIRATES", session) == "Emirates"
|
||||||
assert _exact_match("unknown brand", DICTIONARY) is None
|
assert _match_session("unknown", session) is None
|
||||||
|
|
||||||
|
|
||||||
def test_fuzzy_match():
|
def test_resolve_with_session(monkeypatch):
|
||||||
brand, score = _fuzzy_match("Nik3", DICTIONARY, threshold=75)
|
events = []
|
||||||
assert brand == "Nike"
|
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||||
assert score >= 75
|
lambda job_id, etype, data: events.append((etype, data)))
|
||||||
|
|
||||||
brand, score = _fuzzy_match("adldas", DICTIONARY, threshold=75)
|
session = {"nike": "Nike", "emirates": "Emirates"}
|
||||||
assert brand == "Adidas"
|
|
||||||
|
|
||||||
brand, score = _fuzzy_match("xyzxyzxyz", DICTIONARY, threshold=75)
|
|
||||||
assert brand is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_exact():
|
|
||||||
candidates = [_make_candidate("Nike"), _make_candidate("EMIRATES")]
|
candidates = [_make_candidate("Nike"), _make_candidate("EMIRATES")]
|
||||||
matched, unresolved = resolve_brands(candidates, DICTIONARY, CONFIG)
|
|
||||||
|
matched, unresolved = resolve_brands(
|
||||||
|
candidates, CONFIG, session_brands=session,
|
||||||
|
)
|
||||||
|
|
||||||
assert len(matched) == 2
|
assert len(matched) == 2
|
||||||
assert len(unresolved) == 0
|
assert len(unresolved) == 0
|
||||||
assert matched[0].brand == "Nike"
|
assert matched[0].brand == "Nike"
|
||||||
assert matched[1].brand == "Emirates"
|
assert matched[1].brand == "Emirates"
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_fuzzy():
|
def test_resolve_unresolved_without_db(monkeypatch):
|
||||||
candidates = [_make_candidate("coca coIa")] # OCR misread
|
events = []
|
||||||
matched, unresolved = resolve_brands(candidates, DICTIONARY, CONFIG)
|
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||||
assert len(matched) == 1
|
lambda job_id, etype, data: events.append((etype, data)))
|
||||||
assert matched[0].brand == "Coca-Cola"
|
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_unresolved():
|
|
||||||
candidates = [_make_candidate("random garbage text")]
|
candidates = [_make_candidate("random garbage text")]
|
||||||
matched, unresolved = resolve_brands(candidates, DICTIONARY, CONFIG)
|
|
||||||
|
matched, unresolved = resolve_brands(
|
||||||
|
candidates, CONFIG, session_brands={},
|
||||||
|
)
|
||||||
|
|
||||||
assert len(matched) == 0
|
assert len(matched) == 0
|
||||||
assert len(unresolved) == 1
|
assert len(unresolved) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_mixed():
|
def test_resolve_empty(monkeypatch):
|
||||||
|
events = []
|
||||||
|
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||||
|
lambda job_id, etype, data: events.append((etype, data)))
|
||||||
|
|
||||||
|
matched, unresolved = resolve_brands([], CONFIG, session_brands={})
|
||||||
|
|
||||||
|
assert len(matched) == 0
|
||||||
|
assert len(unresolved) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_builds_session_during_run(monkeypatch):
|
||||||
|
"""Session brands accumulate during a single run — second candidate benefits."""
|
||||||
|
events = []
|
||||||
|
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||||
|
lambda job_id, etype, data: events.append((etype, data)))
|
||||||
|
|
||||||
|
session = {"nike": "Nike"}
|
||||||
candidates = [
|
candidates = [
|
||||||
_make_candidate("Nike"),
|
_make_candidate("Nike"), # hits session
|
||||||
_make_candidate("unknown"),
|
_make_candidate("unknown"), # misses everything
|
||||||
_make_candidate("adldas"),
|
|
||||||
]
|
]
|
||||||
matched, unresolved = resolve_brands(candidates, DICTIONARY, CONFIG)
|
|
||||||
assert len(matched) == 2 # Nike exact + Adidas fuzzy
|
matched, unresolved = resolve_brands(
|
||||||
|
candidates, CONFIG, session_brands=session,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(matched) == 1
|
||||||
|
assert matched[0].brand == "Nike"
|
||||||
assert len(unresolved) == 1
|
assert len(unresolved) == 1
|
||||||
|
|
||||||
|
|
||||||
@@ -84,8 +96,10 @@ def test_events_emitted(monkeypatch):
|
|||||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||||
lambda job_id, etype, data: events.append((etype, data)))
|
lambda job_id, etype, data: events.append((etype, data)))
|
||||||
|
|
||||||
|
session = {"nike": "Nike"}
|
||||||
candidates = [_make_candidate("Nike")]
|
candidates = [_make_candidate("Nike")]
|
||||||
resolve_brands(candidates, DICTIONARY, CONFIG, job_id="test-job")
|
|
||||||
|
resolve_brands(candidates, CONFIG, session_brands=session, job_id="test-job")
|
||||||
|
|
||||||
event_types = [e[0] for e in events]
|
event_types = [e[0] for e in events]
|
||||||
assert "log" in event_types
|
assert "log" in event_types
|
||||||
|
|||||||
182
tests/detect/test_checkpoint.py
Normal file
182
tests/detect/test_checkpoint.py
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
"""Tests for checkpoint serialization — round-trip without S3."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate
|
||||||
|
from detect.checkpoint.serializer import (
|
||||||
|
serialize_state,
|
||||||
|
deserialize_state,
|
||||||
|
serialize_frame_meta,
|
||||||
|
serialize_text_candidate,
|
||||||
|
deserialize_text_candidate,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_frame(seq: int = 0, w: int = 100, h: int = 80) -> Frame:
|
||||||
|
image = np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
|
||||||
|
return Frame(
|
||||||
|
sequence=seq,
|
||||||
|
chunk_id=0,
|
||||||
|
timestamp=float(seq) * 0.5,
|
||||||
|
image=image,
|
||||||
|
perceptual_hash=f"hash_{seq}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_box(x=10, y=10, w=30, h=20) -> BoundingBox:
|
||||||
|
return BoundingBox(x=x, y=y, w=w, h=h, confidence=0.9, label="text")
|
||||||
|
|
||||||
|
|
||||||
|
def _make_candidate(frame: Frame, text: str = "NIKE") -> TextCandidate:
|
||||||
|
box = _make_box()
|
||||||
|
return TextCandidate(frame=frame, bbox=box, text=text, ocr_confidence=0.85)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_detection(brand: str = "Nike", timestamp: float = 1.0) -> BrandDetection:
|
||||||
|
return BrandDetection(
|
||||||
|
brand=brand,
|
||||||
|
timestamp=timestamp,
|
||||||
|
duration=0.5,
|
||||||
|
confidence=0.92,
|
||||||
|
source="ocr",
|
||||||
|
content_type="soccer_broadcast",
|
||||||
|
frame_ref=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Frame metadata ---
|
||||||
|
|
||||||
|
def test_serialize_frame_meta():
|
||||||
|
frame = _make_frame(seq=5)
|
||||||
|
meta = serialize_frame_meta(frame)
|
||||||
|
|
||||||
|
assert meta["sequence"] == 5
|
||||||
|
assert meta["timestamp"] == 2.5
|
||||||
|
assert meta["perceptual_hash"] == "hash_5"
|
||||||
|
assert "image" not in meta
|
||||||
|
|
||||||
|
|
||||||
|
# --- TextCandidate ---
|
||||||
|
|
||||||
|
def test_serialize_text_candidate():
|
||||||
|
frame = _make_frame()
|
||||||
|
candidate = _make_candidate(frame, text="EMIRATES")
|
||||||
|
data = serialize_text_candidate(candidate)
|
||||||
|
|
||||||
|
assert data["frame_sequence"] == 0
|
||||||
|
assert data["text"] == "EMIRATES"
|
||||||
|
assert data["ocr_confidence"] == 0.85
|
||||||
|
assert "bbox" in data
|
||||||
|
|
||||||
|
|
||||||
|
def test_deserialize_text_candidate():
|
||||||
|
frame = _make_frame()
|
||||||
|
candidate = _make_candidate(frame, text="ADIDAS")
|
||||||
|
|
||||||
|
serialized = serialize_text_candidate(candidate)
|
||||||
|
frame_map = {frame.sequence: frame}
|
||||||
|
|
||||||
|
restored = deserialize_text_candidate(serialized, frame_map)
|
||||||
|
|
||||||
|
assert restored.text == "ADIDAS"
|
||||||
|
assert restored.ocr_confidence == 0.85
|
||||||
|
assert restored.frame is frame # same object reference
|
||||||
|
assert restored.bbox.x == 10
|
||||||
|
|
||||||
|
|
||||||
|
# --- Full state round-trip ---
|
||||||
|
|
||||||
|
def test_state_round_trip():
|
||||||
|
frames = [_make_frame(seq=i) for i in range(3)]
|
||||||
|
filtered = frames[:2]
|
||||||
|
|
||||||
|
box = _make_box()
|
||||||
|
boxes_by_frame = {0: [box], 1: [box]}
|
||||||
|
|
||||||
|
candidates = [_make_candidate(frames[0], "NIKE"), _make_candidate(frames[1], "EMIRATES")]
|
||||||
|
unresolved = [_make_candidate(frames[2], "unknown")]
|
||||||
|
detections = [_make_detection("Nike", 0.5), _make_detection("Emirates", 1.0)]
|
||||||
|
|
||||||
|
stats = PipelineStats(
|
||||||
|
frames_extracted=3,
|
||||||
|
frames_after_scene_filter=2,
|
||||||
|
regions_detected=2,
|
||||||
|
regions_resolved_by_ocr=2,
|
||||||
|
cloud_llm_calls=1,
|
||||||
|
estimated_cloud_cost_usd=0.003,
|
||||||
|
)
|
||||||
|
|
||||||
|
state = {
|
||||||
|
"job_id": "test-123",
|
||||||
|
"video_path": "/tmp/test.mp4",
|
||||||
|
"profile_name": "soccer_broadcast",
|
||||||
|
"config_overrides": {"ocr": {"min_confidence": 0.3}},
|
||||||
|
"frames": frames,
|
||||||
|
"filtered_frames": filtered,
|
||||||
|
"boxes_by_frame": boxes_by_frame,
|
||||||
|
"text_candidates": candidates,
|
||||||
|
"unresolved_candidates": unresolved,
|
||||||
|
"detections": detections,
|
||||||
|
"stats": stats,
|
||||||
|
}
|
||||||
|
|
||||||
|
manifest = {f.sequence: f"s3://fake/frames/{f.sequence}.jpg" for f in frames}
|
||||||
|
|
||||||
|
# Serialize
|
||||||
|
serialized = serialize_state(state, manifest)
|
||||||
|
|
||||||
|
# Verify JSON-compatible (no numpy, no Frame objects)
|
||||||
|
import json
|
||||||
|
json_str = json.dumps(serialized, default=str)
|
||||||
|
assert len(json_str) > 0
|
||||||
|
|
||||||
|
# Deserialize with the original frames (simulating frame load from S3)
|
||||||
|
restored = deserialize_state(serialized, frames)
|
||||||
|
|
||||||
|
# Verify round-trip
|
||||||
|
assert restored["job_id"] == "test-123"
|
||||||
|
assert restored["video_path"] == "/tmp/test.mp4"
|
||||||
|
assert restored["profile_name"] == "soccer_broadcast"
|
||||||
|
assert restored["config_overrides"] == {"ocr": {"min_confidence": 0.3}}
|
||||||
|
|
||||||
|
assert len(restored["frames"]) == 3
|
||||||
|
assert len(restored["filtered_frames"]) == 2
|
||||||
|
assert len(restored["boxes_by_frame"]) == 2
|
||||||
|
assert len(restored["text_candidates"]) == 2
|
||||||
|
assert len(restored["unresolved_candidates"]) == 1
|
||||||
|
assert len(restored["detections"]) == 2
|
||||||
|
|
||||||
|
restored_stats = restored["stats"]
|
||||||
|
assert restored_stats.frames_extracted == 3
|
||||||
|
assert restored_stats.cloud_llm_calls == 1
|
||||||
|
assert restored_stats.estimated_cloud_cost_usd == 0.003
|
||||||
|
|
||||||
|
# TextCandidate frame references should point to actual Frame objects
|
||||||
|
tc = restored["text_candidates"][0]
|
||||||
|
assert tc.frame is frames[0]
|
||||||
|
assert tc.text == "NIKE"
|
||||||
|
|
||||||
|
|
||||||
|
def test_state_round_trip_empty():
|
||||||
|
"""Empty state should serialize/deserialize cleanly."""
|
||||||
|
state = {
|
||||||
|
"job_id": "empty-job",
|
||||||
|
"video_path": "",
|
||||||
|
"profile_name": "soccer_broadcast",
|
||||||
|
"frames": [],
|
||||||
|
"filtered_frames": [],
|
||||||
|
"boxes_by_frame": {},
|
||||||
|
"text_candidates": [],
|
||||||
|
"unresolved_candidates": [],
|
||||||
|
"detections": [],
|
||||||
|
"stats": PipelineStats(),
|
||||||
|
}
|
||||||
|
|
||||||
|
serialized = serialize_state(state, {})
|
||||||
|
restored = deserialize_state(serialized, [])
|
||||||
|
|
||||||
|
assert restored["job_id"] == "empty-job"
|
||||||
|
assert len(restored["frames"]) == 0
|
||||||
|
assert len(restored["detections"]) == 0
|
||||||
|
assert restored["stats"].frames_extracted == 0
|
||||||
@@ -25,11 +25,9 @@ def test_soccer_detection_config():
|
|||||||
assert isinstance(cfg.target_classes, list)
|
assert isinstance(cfg.target_classes, list)
|
||||||
|
|
||||||
|
|
||||||
def test_soccer_brand_dictionary_non_empty():
|
def test_soccer_resolver_config():
|
||||||
bd = SoccerBroadcastProfile().brand_dictionary()
|
cfg = SoccerBroadcastProfile().resolver_config()
|
||||||
assert len(bd.brands) > 0
|
assert cfg.fuzzy_threshold > 0
|
||||||
for canonical, aliases in bd.brands.items():
|
|
||||||
assert len(aliases) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_soccer_vlm_prompt():
|
def test_soccer_vlm_prompt():
|
||||||
@@ -70,4 +68,4 @@ def test_stubs_raise(stub_cls):
|
|||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(NotImplementedError):
|
||||||
stub.frame_extraction_config()
|
stub.frame_extraction_config()
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(NotImplementedError):
|
||||||
stub.brand_dictionary()
|
stub.resolver_config()
|
||||||
|
|||||||
67
tests/detect/test_replay.py
Normal file
67
tests/detect/test_replay.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
"""Tests for replay and OverrideProfile."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from detect.profiles.soccer import SoccerBroadcastProfile
|
||||||
|
from detect.checkpoint.replay import OverrideProfile
|
||||||
|
|
||||||
|
|
||||||
|
def test_override_profile_patches_ocr():
|
||||||
|
base = SoccerBroadcastProfile()
|
||||||
|
overrides = {"ocr": {"min_confidence": 0.3, "languages": ["en", "es", "pt"]}}
|
||||||
|
profile = OverrideProfile(base, overrides)
|
||||||
|
|
||||||
|
config = profile.ocr_config()
|
||||||
|
|
||||||
|
assert config.min_confidence == 0.3
|
||||||
|
assert config.languages == ["en", "es", "pt"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_override_profile_patches_resolver():
|
||||||
|
base = SoccerBroadcastProfile()
|
||||||
|
overrides = {"resolver": {"fuzzy_threshold": 60}}
|
||||||
|
profile = OverrideProfile(base, overrides)
|
||||||
|
|
||||||
|
config = profile.resolver_config()
|
||||||
|
|
||||||
|
assert config.fuzzy_threshold == 60
|
||||||
|
|
||||||
|
|
||||||
|
def test_override_profile_patches_detection():
|
||||||
|
base = SoccerBroadcastProfile()
|
||||||
|
overrides = {"detection": {"confidence_threshold": 0.5}}
|
||||||
|
profile = OverrideProfile(base, overrides)
|
||||||
|
|
||||||
|
config = profile.detection_config()
|
||||||
|
|
||||||
|
assert config.confidence_threshold == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
def test_override_profile_no_overrides():
|
||||||
|
base = SoccerBroadcastProfile()
|
||||||
|
profile = OverrideProfile(base, {})
|
||||||
|
|
||||||
|
ocr = profile.ocr_config()
|
||||||
|
base_ocr = base.ocr_config()
|
||||||
|
|
||||||
|
assert ocr.min_confidence == base_ocr.min_confidence
|
||||||
|
assert ocr.languages == base_ocr.languages
|
||||||
|
|
||||||
|
|
||||||
|
def test_override_profile_delegates_non_config():
|
||||||
|
base = SoccerBroadcastProfile()
|
||||||
|
profile = OverrideProfile(base, {"ocr": {"min_confidence": 0.1}})
|
||||||
|
|
||||||
|
assert profile.name == "soccer_broadcast"
|
||||||
|
assert profile.resolver_config().fuzzy_threshold > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_override_profile_ignores_unknown_fields():
|
||||||
|
base = SoccerBroadcastProfile()
|
||||||
|
overrides = {"ocr": {"nonexistent_field": 42}}
|
||||||
|
profile = OverrideProfile(base, overrides)
|
||||||
|
|
||||||
|
config = profile.ocr_config()
|
||||||
|
|
||||||
|
assert not hasattr(config, "nonexistent_field")
|
||||||
|
assert config.min_confidence == base.ocr_config().min_confidence
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref } from 'vue'
|
import { ref } from 'vue'
|
||||||
import { SSEDataSource, Panel, LayoutGrid } from 'mpr-ui-framework'
|
import { SSEDataSource, Panel, ResizeHandle } from 'mpr-ui-framework'
|
||||||
import 'mpr-ui-framework/src/tokens.css'
|
import 'mpr-ui-framework/src/tokens.css'
|
||||||
import LogPanel from './panels/LogPanel.vue'
|
import LogPanel from './panels/LogPanel.vue'
|
||||||
import FunnelPanel from './panels/FunnelPanel.vue'
|
import FunnelPanel from './panels/FunnelPanel.vue'
|
||||||
@@ -9,10 +9,11 @@ import FramePanel from './panels/FramePanel.vue'
|
|||||||
import BrandTablePanel from './panels/BrandTablePanel.vue'
|
import BrandTablePanel from './panels/BrandTablePanel.vue'
|
||||||
import TimelinePanel from './panels/TimelinePanel.vue'
|
import TimelinePanel from './panels/TimelinePanel.vue'
|
||||||
import CostStatsPanel from './panels/CostStatsPanel.vue'
|
import CostStatsPanel from './panels/CostStatsPanel.vue'
|
||||||
import type { StatsUpdate } from './types/sse-contract'
|
import type { StatsUpdate, RunContext } from './types/sse-contract'
|
||||||
|
|
||||||
const jobId = ref(new URLSearchParams(window.location.search).get('job') || 'test-job')
|
const jobId = ref(new URLSearchParams(window.location.search).get('job') || 'test-job')
|
||||||
const stats = ref<StatsUpdate | null>(null)
|
const stats = ref<StatsUpdate | null>(null)
|
||||||
|
const runContext = ref<RunContext | null>(null)
|
||||||
const status = ref<'idle' | 'live' | 'processing' | 'error'>('idle')
|
const status = ref<'idle' | 'live' | 'processing' | 'error'>('idle')
|
||||||
|
|
||||||
const source = new SSEDataSource({
|
const source = new SSEDataSource({
|
||||||
@@ -23,8 +24,41 @@ const source = new SSEDataSource({
|
|||||||
|
|
||||||
source.on<StatsUpdate>('stats_update', (e) => {
|
source.on<StatsUpdate>('stats_update', (e) => {
|
||||||
stats.value = e
|
stats.value = e
|
||||||
|
// Capture run context from first event that carries it
|
||||||
|
if (!runContext.value && e.run_id) {
|
||||||
|
runContext.value = {
|
||||||
|
run_id: (e as any).run_id,
|
||||||
|
parent_job_id: (e as any).parent_job_id,
|
||||||
|
run_type: (e as any).run_type ?? 'initial',
|
||||||
|
}
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Resizable splits
|
||||||
|
const pipelineWidth = ref(320)
|
||||||
|
const detectionsFlex = ref(3) // ratio for detections vs stats
|
||||||
|
const viewerHeight = ref(240)
|
||||||
|
const timelineFlex = ref(1)
|
||||||
|
const tableFlex = ref(1)
|
||||||
|
|
||||||
|
function onPipelineResize(delta: number) {
|
||||||
|
pipelineWidth.value = Math.max(200, Math.min(500, pipelineWidth.value + delta))
|
||||||
|
}
|
||||||
|
|
||||||
|
function onViewerResize(delta: number) {
|
||||||
|
viewerHeight.value = Math.max(120, Math.min(400, viewerHeight.value + delta))
|
||||||
|
}
|
||||||
|
|
||||||
|
function onDetectionsResize(delta: number) {
|
||||||
|
detectionsFlex.value = Math.max(1, Math.min(6, detectionsFlex.value + delta * 0.01))
|
||||||
|
}
|
||||||
|
|
||||||
|
function onTimelineResize(delta: number) {
|
||||||
|
const shift = delta * 0.02
|
||||||
|
timelineFlex.value = Math.max(0.3, Math.min(3, timelineFlex.value + shift))
|
||||||
|
tableFlex.value = Math.max(0.3, Math.min(3, tableFlex.value - shift))
|
||||||
|
}
|
||||||
|
|
||||||
const statusMap: Record<string, 'idle' | 'live' | 'processing' | 'error'> = {
|
const statusMap: Record<string, 'idle' | 'live' | 'processing' | 'error'> = {
|
||||||
idle: 'idle',
|
idle: 'idle',
|
||||||
connecting: 'processing',
|
connecting: 'processing',
|
||||||
@@ -42,41 +76,70 @@ source.connect()
|
|||||||
<header>
|
<header>
|
||||||
<h1>Detection Pipeline</h1>
|
<h1>Detection Pipeline</h1>
|
||||||
<span class="status-badge" :class="status">{{ status }}</span>
|
<span class="status-badge" :class="status">{{ status }}</span>
|
||||||
|
<span v-if="runContext" class="run-info">
|
||||||
|
{{ runContext.run_type }} · run: {{ runContext.run_id }}
|
||||||
|
</span>
|
||||||
<span class="job-id">job: {{ jobId }}</span>
|
<span class="job-id">job: {{ jobId }}</span>
|
||||||
</header>
|
</header>
|
||||||
|
|
||||||
<LayoutGrid :columns="3" :rows="3" gap="var(--space-2)">
|
<div class="main-layout">
|
||||||
<Panel title="Stats" :status="status">
|
<!-- Left column: Pipeline control (full height) -->
|
||||||
<div class="stats" v-if="stats">
|
<div class="pipeline-col" :style="{ width: pipelineWidth + 'px' }">
|
||||||
|
<PipelineGraphPanel :source="source" :status="status" />
|
||||||
|
</div>
|
||||||
|
<ResizeHandle direction="horizontal" @resize="onPipelineResize" />
|
||||||
|
|
||||||
|
<!-- Right area: interactive panels -->
|
||||||
|
<div class="content-col">
|
||||||
|
<!-- Row 1: Frame viewer + Funnel -->
|
||||||
|
<div class="viewer-row" :style="{ height: viewerHeight + 'px' }">
|
||||||
|
<FramePanel :source="source" :status="status" />
|
||||||
|
<FunnelPanel :source="source" :status="status" />
|
||||||
|
</div>
|
||||||
|
<ResizeHandle direction="vertical" @resize="onViewerResize" />
|
||||||
|
|
||||||
|
<!-- Row 2: Detections + Stats side by side -->
|
||||||
|
<div class="detections-stats-row">
|
||||||
|
<div class="detections-col" :style="{ flex: detectionsFlex }">
|
||||||
|
<Panel title="Detections" :status="status">
|
||||||
|
<div class="detections-stack">
|
||||||
|
<div class="timeline-section" :style="{ flex: timelineFlex }">
|
||||||
|
<TimelinePanel :source="source" :status="status" :embedded="true" />
|
||||||
|
</div>
|
||||||
|
<ResizeHandle direction="vertical" @resize="onTimelineResize" />
|
||||||
|
<div class="table-section" :style="{ flex: tableFlex }">
|
||||||
|
<BrandTablePanel :source="source" :status="status" :embedded="true" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</Panel>
|
||||||
|
</div>
|
||||||
|
<ResizeHandle direction="horizontal" @resize="onDetectionsResize" />
|
||||||
|
<div class="stats-col">
|
||||||
|
<Panel title="Pipeline" :status="status">
|
||||||
|
<div class="pipeline-stats">
|
||||||
<div class="stat" v-for="s in [
|
<div class="stat" v-for="s in [
|
||||||
{ label: 'Frames', value: stats.frames_extracted },
|
{ label: 'Frames', value: stats?.frames_extracted ?? '—' },
|
||||||
{ label: 'After filter', value: stats.frames_after_scene_filter },
|
{ label: 'After filter', value: stats?.frames_after_scene_filter ?? '—' },
|
||||||
{ label: 'Regions', value: stats.regions_detected },
|
{ label: 'Regions', value: stats?.regions_detected ?? '—' },
|
||||||
{ label: 'OCR resolved', value: stats.regions_resolved_by_ocr },
|
{ label: 'OCR resolved', value: stats?.regions_resolved_by_ocr ?? '—' },
|
||||||
{ label: 'Cloud calls', value: stats.cloud_llm_calls },
|
{ label: 'VLM escalated', value: stats?.regions_escalated_to_local_vlm ?? '—' },
|
||||||
{ label: 'Cost', value: `$${stats.estimated_cloud_cost_usd.toFixed(4)}` },
|
{ label: 'Cloud escalated', value: stats?.regions_escalated_to_cloud_llm ?? '—' },
|
||||||
]" :key="s.label">
|
]" :key="s.label">
|
||||||
<span class="label">{{ s.label }}</span>
|
<span class="label">{{ s.label }}</span>
|
||||||
<span class="value">{{ s.value }}</span>
|
<span class="value">{{ s.value }}</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div v-else class="empty">Waiting for stats...</div>
|
|
||||||
</Panel>
|
</Panel>
|
||||||
|
|
||||||
<FunnelPanel :source="source" :status="status" />
|
|
||||||
|
|
||||||
<FramePanel :source="source" :status="status" />
|
|
||||||
|
|
||||||
<PipelineGraphPanel :source="source" :status="status" />
|
|
||||||
|
|
||||||
<BrandTablePanel :source="source" :status="status" />
|
|
||||||
|
|
||||||
<TimelinePanel :source="source" :status="status" />
|
|
||||||
|
|
||||||
<CostStatsPanel :source="source" :status="status" />
|
<CostStatsPanel :source="source" :status="status" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Bottom: Log (full width) -->
|
||||||
|
<div class="log-row">
|
||||||
<LogPanel :source="source" :status="status" />
|
<LogPanel :source="source" :status="status" />
|
||||||
</LayoutGrid>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
@@ -92,10 +155,11 @@ body {
|
|||||||
|
|
||||||
.app {
|
.app {
|
||||||
height: 100vh;
|
height: 100vh;
|
||||||
display: flex;
|
display: grid;
|
||||||
flex-direction: column;
|
grid-template-rows: auto 1fr auto;
|
||||||
padding: var(--space-4);
|
padding: var(--space-4);
|
||||||
gap: var(--space-2);
|
gap: var(--space-2);
|
||||||
|
overflow: hidden;
|
||||||
}
|
}
|
||||||
|
|
||||||
header {
|
header {
|
||||||
@@ -120,21 +184,110 @@ header h1 { font-size: var(--font-size-lg); font-weight: 600; }
|
|||||||
.status-badge.live { background: var(--status-live); color: #000; }
|
.status-badge.live { background: var(--status-live); color: #000; }
|
||||||
.status-badge.error { background: var(--status-error); color: #000; }
|
.status-badge.error { background: var(--status-error); color: #000; }
|
||||||
|
|
||||||
|
.run-info {
|
||||||
|
color: var(--text-secondary);
|
||||||
|
font-size: var(--font-size-sm);
|
||||||
|
}
|
||||||
|
|
||||||
.job-id { color: var(--text-dim); font-size: var(--font-size-sm); margin-left: auto; }
|
.job-id { color: var(--text-dim); font-size: var(--font-size-sm); margin-left: auto; }
|
||||||
|
|
||||||
.stats {
|
/* Main layout: pipeline left, content right — both same height */
|
||||||
display: grid;
|
.main-layout {
|
||||||
grid-template-columns: repeat(3, 1fr);
|
display: flex;
|
||||||
gap: var(--space-2);
|
gap: var(--space-2);
|
||||||
|
min-height: 0;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
.pipeline-col {
|
||||||
|
flex-shrink: 0;
|
||||||
|
display: flex;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
.pipeline-col > * { flex: 1; }
|
||||||
|
|
||||||
|
.content-col {
|
||||||
|
flex: 1;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: var(--space-2);
|
||||||
|
min-height: 0;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Content rows */
|
||||||
|
.viewer-row {
|
||||||
|
display: flex;
|
||||||
|
gap: var(--space-2);
|
||||||
|
flex-shrink: 0;
|
||||||
|
}
|
||||||
|
.viewer-row > * { flex: 1; overflow: hidden; }
|
||||||
|
|
||||||
|
/* Detections (75%) + Stats (25%) side by side, bottom-aligned with pipeline */
|
||||||
|
.detections-stats-row {
|
||||||
|
display: flex;
|
||||||
|
gap: var(--space-2);
|
||||||
|
flex: 1;
|
||||||
|
min-height: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.detections-col {
|
||||||
|
flex: 3;
|
||||||
|
min-width: 0;
|
||||||
|
display: flex;
|
||||||
|
}
|
||||||
|
.detections-col > * { flex: 1; display: flex; flex-direction: column; }
|
||||||
|
|
||||||
|
.stats-col {
|
||||||
|
flex: 1;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: var(--space-2);
|
||||||
|
min-width: 200px;
|
||||||
|
}
|
||||||
|
.stats-col > * { flex: 1; }
|
||||||
|
|
||||||
|
.detections-stack {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
height: 100%;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
.timeline-section {
|
||||||
|
min-height: 60px;
|
||||||
|
overflow: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
.table-section {
|
||||||
|
min-height: 60px;
|
||||||
|
overflow-y: auto;
|
||||||
|
overflow-x: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Pipeline stats list */
|
||||||
|
.pipeline-stats {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: var(--space-1);
|
||||||
}
|
}
|
||||||
|
|
||||||
.stat {
|
.stat {
|
||||||
background: var(--surface-2);
|
display: flex;
|
||||||
border-radius: var(--panel-radius);
|
justify-content: space-between;
|
||||||
padding: var(--space-3);
|
padding: var(--space-1) 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.stat .label { color: var(--text-dim); font-size: var(--font-size-sm); }
|
||||||
|
.stat .value { font-weight: 600; }
|
||||||
|
|
||||||
|
/* Log: full width bottom, fixed height */
|
||||||
|
.log-row {
|
||||||
|
flex-shrink: 0;
|
||||||
|
height: 160px;
|
||||||
|
overflow: hidden;
|
||||||
}
|
}
|
||||||
.stat .label { display: block; color: var(--text-dim); font-size: var(--font-size-sm); margin-bottom: var(--space-1); }
|
|
||||||
.stat .value { font-size: 20px; font-weight: 600; }
|
|
||||||
|
|
||||||
.empty { color: var(--text-dim); padding: var(--space-6); text-align: center; }
|
.empty { color: var(--text-dim); padding: var(--space-6); text-align: center; }
|
||||||
</style>
|
</style>
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import type { DataSource } from 'mpr-ui-framework'
|
|||||||
const props = defineProps<{
|
const props = defineProps<{
|
||||||
source: DataSource
|
source: DataSource
|
||||||
status?: 'idle' | 'live' | 'processing' | 'error'
|
status?: 'idle' | 'live' | 'processing' | 'error'
|
||||||
|
embedded?: boolean
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
const columns: TableColumn[] = [
|
const columns: TableColumn[] = [
|
||||||
@@ -45,7 +46,7 @@ function onSort(key: string) {
|
|||||||
</script>
|
</script>
|
||||||
|
|
||||||
<template>
|
<template>
|
||||||
<Panel title="Detections" :status="status">
|
<Panel v-if="!embedded" title="Detections" :status="status">
|
||||||
<TableRenderer
|
<TableRenderer
|
||||||
:columns="columns"
|
:columns="columns"
|
||||||
:rows="rows"
|
:rows="rows"
|
||||||
@@ -54,4 +55,12 @@ function onSort(key: string) {
|
|||||||
@sort="onSort"
|
@sort="onSort"
|
||||||
/>
|
/>
|
||||||
</Panel>
|
</Panel>
|
||||||
|
<TableRenderer
|
||||||
|
v-else
|
||||||
|
:columns="columns"
|
||||||
|
:rows="rows"
|
||||||
|
:sort-key="sortKey"
|
||||||
|
:sort-dir="sortDir"
|
||||||
|
@sort="onSort"
|
||||||
|
/>
|
||||||
</template>
|
</template>
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import type { StatsUpdate, Detection } from '../types/sse-contract'
|
|||||||
const props = defineProps<{
|
const props = defineProps<{
|
||||||
source: DataSource
|
source: DataSource
|
||||||
status?: 'idle' | 'live' | 'processing' | 'error'
|
status?: 'idle' | 'live' | 'processing' | 'error'
|
||||||
|
embedded?: boolean
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
const stats = ref<StatsUpdate | null>(null)
|
const stats = ref<StatsUpdate | null>(null)
|
||||||
@@ -71,7 +72,7 @@ const metrics = computed<Metric[]>(() => {
|
|||||||
</script>
|
</script>
|
||||||
|
|
||||||
<template>
|
<template>
|
||||||
<Panel title="Cost & Stats" :status="status">
|
<Panel v-if="!embedded" title="Cost & Stats" :status="status">
|
||||||
<div class="cost-stats" v-if="stats">
|
<div class="cost-stats" v-if="stats">
|
||||||
<div class="metric" v-for="m in metrics" :key="m.label">
|
<div class="metric" v-for="m in metrics" :key="m.label">
|
||||||
<span class="label">{{ m.label }}</span>
|
<span class="label">{{ m.label }}</span>
|
||||||
@@ -81,38 +82,59 @@ const metrics = computed<Metric[]>(() => {
|
|||||||
</div>
|
</div>
|
||||||
<div v-else class="empty">Waiting for stats...</div>
|
<div v-else class="empty">Waiting for stats...</div>
|
||||||
</Panel>
|
</Panel>
|
||||||
|
<div v-else class="cost-stats" v-show="stats">
|
||||||
|
<div class="metric" v-for="m in metrics" :key="m.label">
|
||||||
|
<span class="label">{{ m.label }}</span>
|
||||||
|
<span class="value" :style="m.color ? { color: m.color } : {}">{{ m.value }}</span>
|
||||||
|
<span class="sub" v-if="m.sub">{{ m.sub }}</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<style scoped>
|
<style scoped>
|
||||||
.cost-stats {
|
.cost-stats {
|
||||||
display: grid;
|
display: grid;
|
||||||
grid-template-columns: 1fr 1fr;
|
grid-template-columns: 1fr 1fr;
|
||||||
gap: var(--space-3);
|
gap: var(--space-2);
|
||||||
padding: var(--space-3);
|
padding: var(--space-2);
|
||||||
|
overflow: hidden;
|
||||||
|
box-sizing: border-box;
|
||||||
|
width: 100%;
|
||||||
}
|
}
|
||||||
|
|
||||||
.metric {
|
.metric {
|
||||||
background: var(--surface-2);
|
background: var(--surface-2);
|
||||||
border-radius: var(--panel-radius);
|
border-radius: var(--panel-radius);
|
||||||
padding: var(--space-3);
|
padding: var(--space-2);
|
||||||
display: flex;
|
display: flex;
|
||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
gap: var(--space-1);
|
gap: 2px;
|
||||||
|
overflow: hidden;
|
||||||
|
min-width: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
.label {
|
.label {
|
||||||
font-size: var(--font-size-sm);
|
font-size: var(--font-size-sm);
|
||||||
color: var(--text-dim);
|
color: var(--text-dim);
|
||||||
|
white-space: nowrap;
|
||||||
|
overflow: hidden;
|
||||||
|
text-overflow: ellipsis;
|
||||||
}
|
}
|
||||||
|
|
||||||
.value {
|
.value {
|
||||||
font-size: 22px;
|
font-size: 18px;
|
||||||
font-weight: 600;
|
font-weight: 600;
|
||||||
|
white-space: nowrap;
|
||||||
|
overflow: hidden;
|
||||||
|
text-overflow: ellipsis;
|
||||||
}
|
}
|
||||||
|
|
||||||
.sub {
|
.sub {
|
||||||
font-size: 11px;
|
font-size: 10px;
|
||||||
color: var(--text-dim);
|
color: var(--text-dim);
|
||||||
|
white-space: nowrap;
|
||||||
|
overflow: hidden;
|
||||||
|
text-overflow: ellipsis;
|
||||||
}
|
}
|
||||||
|
|
||||||
.empty {
|
.empty {
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import type { Detection } from '../types/sse-contract'
|
|||||||
const props = defineProps<{
|
const props = defineProps<{
|
||||||
source: DataSource
|
source: DataSource
|
||||||
status?: 'idle' | 'live' | 'processing' | 'error'
|
status?: 'idle' | 'live' | 'processing' | 'error'
|
||||||
|
embedded?: boolean
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
interface TimelineEntry {
|
interface TimelineEntry {
|
||||||
@@ -45,6 +46,20 @@ const maxTime = computed(() => {
|
|||||||
return Math.max(...entries.value.map((e) => e.timestamp + e.duration)) * 1.1
|
return Math.max(...entries.value.map((e) => e.timestamp + e.duration)) * 1.1
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Ruler ticks — nice intervals based on duration
|
||||||
|
const rulerTicks = computed(() => {
|
||||||
|
const max = maxTime.value
|
||||||
|
const intervals = [1, 2, 5, 10, 15, 30, 60, 120, 300]
|
||||||
|
const targetTicks = 8
|
||||||
|
const rawStep = max / targetTicks
|
||||||
|
const step = intervals.find((i) => i >= rawStep) || Math.ceil(rawStep / 60) * 60
|
||||||
|
const ticks: number[] = []
|
||||||
|
for (let t = 0; t <= max; t += step) {
|
||||||
|
ticks.push(t)
|
||||||
|
}
|
||||||
|
return ticks
|
||||||
|
})
|
||||||
|
|
||||||
const sourceColor: Record<string, string> = {
|
const sourceColor: Record<string, string> = {
|
||||||
ocr: '#3ecf8e',
|
ocr: '#3ecf8e',
|
||||||
local_vlm: '#f5a623',
|
local_vlm: '#f5a623',
|
||||||
@@ -67,8 +82,19 @@ function barStyle(entry: TimelineEntry) {
|
|||||||
</script>
|
</script>
|
||||||
|
|
||||||
<template>
|
<template>
|
||||||
<Panel title="Detection Timeline" :status="status">
|
<Panel v-if="!embedded" title="Detection Timeline" :status="status">
|
||||||
<div class="timeline" v-if="brands.length > 0">
|
<div class="timeline" v-if="brands.length > 0">
|
||||||
|
<div class="ruler">
|
||||||
|
<span class="ruler-label" />
|
||||||
|
<div class="ruler-track">
|
||||||
|
<span
|
||||||
|
v-for="t in rulerTicks" :key="t"
|
||||||
|
class="ruler-tick"
|
||||||
|
:style="{ left: (t / maxTime) * 100 + '%' }"
|
||||||
|
>{{ t }}s</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="rows">
|
||||||
<div class="row" v-for="brand in brands" :key="brand">
|
<div class="row" v-for="brand in brands" :key="brand">
|
||||||
<span class="brand-label">{{ brand }}</span>
|
<span class="brand-label">{{ brand }}</span>
|
||||||
<div class="track">
|
<div class="track">
|
||||||
@@ -81,10 +107,6 @@ function barStyle(entry: TimelineEntry) {
|
|||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class="time-axis">
|
|
||||||
<span>0s</span>
|
|
||||||
<span>{{ (maxTime / 2).toFixed(0) }}s</span>
|
|
||||||
<span>{{ maxTime.toFixed(0) }}s</span>
|
|
||||||
</div>
|
</div>
|
||||||
<div class="legend">
|
<div class="legend">
|
||||||
<span v-for="(color, source) in sourceColor" :key="source" class="legend-item">
|
<span v-for="(color, source) in sourceColor" :key="source" class="legend-item">
|
||||||
@@ -95,6 +117,32 @@ function barStyle(entry: TimelineEntry) {
|
|||||||
</div>
|
</div>
|
||||||
<div v-else class="empty">Waiting for detections...</div>
|
<div v-else class="empty">Waiting for detections...</div>
|
||||||
</Panel>
|
</Panel>
|
||||||
|
<div v-else class="timeline" v-show="brands.length > 0">
|
||||||
|
<div class="ruler">
|
||||||
|
<span class="ruler-label" />
|
||||||
|
<div class="ruler-track">
|
||||||
|
<span
|
||||||
|
v-for="t in rulerTicks" :key="t"
|
||||||
|
class="ruler-tick"
|
||||||
|
:style="{ left: (t / maxTime) * 100 + '%' }"
|
||||||
|
>{{ t }}s</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="rows">
|
||||||
|
<div class="row" v-for="brand in brands" :key="brand">
|
||||||
|
<span class="brand-label">{{ brand }}</span>
|
||||||
|
<div class="track">
|
||||||
|
<div
|
||||||
|
v-for="(entry, i) in entries.filter((e) => e.brand === brand)"
|
||||||
|
:key="i"
|
||||||
|
class="bar"
|
||||||
|
:style="barStyle(entry)"
|
||||||
|
:title="`${entry.brand} — ${entry.source} (${(entry.confidence * 100).toFixed(0)}%) @ ${entry.timestamp.toFixed(1)}s`"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<style scoped>
|
<style scoped>
|
||||||
@@ -102,9 +150,54 @@ function barStyle(entry: TimelineEntry) {
|
|||||||
padding: var(--space-2);
|
padding: var(--space-2);
|
||||||
display: flex;
|
display: flex;
|
||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
gap: var(--space-1);
|
|
||||||
height: 100%;
|
height: 100%;
|
||||||
overflow-y: auto;
|
overflow: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
.ruler {
|
||||||
|
display: flex;
|
||||||
|
align-items: flex-end;
|
||||||
|
gap: var(--space-2);
|
||||||
|
flex-shrink: 0;
|
||||||
|
height: 20px;
|
||||||
|
margin-bottom: var(--space-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
.ruler-label {
|
||||||
|
width: 100px;
|
||||||
|
flex-shrink: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.ruler-track {
|
||||||
|
flex: 1;
|
||||||
|
position: relative;
|
||||||
|
height: 100%;
|
||||||
|
border-bottom: 1px solid var(--surface-3);
|
||||||
|
}
|
||||||
|
|
||||||
|
.ruler-tick {
|
||||||
|
position: absolute;
|
||||||
|
bottom: 0;
|
||||||
|
transform: translateX(-50%);
|
||||||
|
font-size: 9px;
|
||||||
|
color: var(--text-dim);
|
||||||
|
padding-bottom: 2px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.ruler-tick::after {
|
||||||
|
content: '';
|
||||||
|
position: absolute;
|
||||||
|
bottom: -1px;
|
||||||
|
left: 50%;
|
||||||
|
width: 1px;
|
||||||
|
height: 6px;
|
||||||
|
background: var(--surface-3);
|
||||||
|
}
|
||||||
|
|
||||||
|
.rows {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: var(--space-1);
|
||||||
}
|
}
|
||||||
|
|
||||||
.row {
|
.row {
|
||||||
@@ -141,15 +234,6 @@ function barStyle(entry: TimelineEntry) {
|
|||||||
min-width: 4px;
|
min-width: 4px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.time-axis {
|
|
||||||
display: flex;
|
|
||||||
justify-content: space-between;
|
|
||||||
padding-left: 108px;
|
|
||||||
font-size: 10px;
|
|
||||||
color: var(--text-dim);
|
|
||||||
margin-top: var(--space-1);
|
|
||||||
}
|
|
||||||
|
|
||||||
.legend {
|
.legend {
|
||||||
display: flex;
|
display: flex;
|
||||||
gap: var(--space-3);
|
gap: var(--space-3);
|
||||||
|
|||||||
@@ -95,3 +95,44 @@ export interface JobComplete {
|
|||||||
job_id: string;
|
job_id: string;
|
||||||
report: DetectionReportSummary | null;
|
report: DetectionReportSummary | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- Run context (injected into all SSE events) ---
|
||||||
|
|
||||||
|
export interface RunContext {
|
||||||
|
run_id: string;
|
||||||
|
parent_job_id: string;
|
||||||
|
run_type: 'initial' | 'replay' | 'retry';
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Checkpoint API types ---
|
||||||
|
|
||||||
|
export interface CheckpointInfo {
|
||||||
|
stage: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ReplayRequest {
|
||||||
|
job_id: string;
|
||||||
|
start_stage: string;
|
||||||
|
config_overrides?: Record<string, unknown>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ReplayResponse {
|
||||||
|
status: string;
|
||||||
|
job_id: string;
|
||||||
|
start_stage: string;
|
||||||
|
detections: number;
|
||||||
|
brands_found: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface RetryRequest {
|
||||||
|
job_id: string;
|
||||||
|
config_overrides?: Record<string, unknown>;
|
||||||
|
start_stage?: string;
|
||||||
|
schedule_seconds?: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface RetryResponse {
|
||||||
|
status: string;
|
||||||
|
task_id: string;
|
||||||
|
job_id: string;
|
||||||
|
}
|
||||||
|
|||||||
@@ -64,8 +64,9 @@ defineProps<{
|
|||||||
|
|
||||||
.panel-body {
|
.panel-body {
|
||||||
flex: 1;
|
flex: 1;
|
||||||
overflow: auto;
|
overflow: hidden;
|
||||||
padding: var(--space-2);
|
padding: var(--space-2);
|
||||||
|
min-height: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
.panel-overlay {
|
.panel-overlay {
|
||||||
|
|||||||
70
ui/framework/src/components/ResizeHandle.vue
Normal file
70
ui/framework/src/components/ResizeHandle.vue
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
<script setup lang="ts">
|
||||||
|
import { ref } from 'vue'
|
||||||
|
|
||||||
|
const props = defineProps<{
|
||||||
|
direction: 'horizontal' | 'vertical'
|
||||||
|
}>()
|
||||||
|
|
||||||
|
const emit = defineEmits<{
|
||||||
|
resize: [delta: number]
|
||||||
|
}>()
|
||||||
|
|
||||||
|
const dragging = ref(false)
|
||||||
|
let startPos = 0
|
||||||
|
|
||||||
|
function onPointerDown(e: PointerEvent) {
|
||||||
|
dragging.value = true
|
||||||
|
startPos = props.direction === 'horizontal' ? e.clientX : e.clientY
|
||||||
|
const el = e.target as HTMLElement
|
||||||
|
el.setPointerCapture(e.pointerId)
|
||||||
|
}
|
||||||
|
|
||||||
|
function onPointerMove(e: PointerEvent) {
|
||||||
|
if (!dragging.value) return
|
||||||
|
const currentPos = props.direction === 'horizontal' ? e.clientX : e.clientY
|
||||||
|
const delta = currentPos - startPos
|
||||||
|
startPos = currentPos
|
||||||
|
emit('resize', delta)
|
||||||
|
}
|
||||||
|
|
||||||
|
function onPointerUp() {
|
||||||
|
dragging.value = false
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<template>
|
||||||
|
<div
|
||||||
|
class="resize-handle"
|
||||||
|
:class="[direction, { dragging }]"
|
||||||
|
@pointerdown="onPointerDown"
|
||||||
|
@pointermove="onPointerMove"
|
||||||
|
@pointerup="onPointerUp"
|
||||||
|
/>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<style scoped>
|
||||||
|
.resize-handle {
|
||||||
|
flex-shrink: 0;
|
||||||
|
background: transparent;
|
||||||
|
transition: background 0.15s;
|
||||||
|
touch-action: none;
|
||||||
|
z-index: 10;
|
||||||
|
}
|
||||||
|
|
||||||
|
.resize-handle:hover,
|
||||||
|
.resize-handle.dragging {
|
||||||
|
background: var(--text-dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
.resize-handle.horizontal {
|
||||||
|
width: 4px;
|
||||||
|
cursor: col-resize;
|
||||||
|
margin: 0 -2px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.resize-handle.vertical {
|
||||||
|
height: 4px;
|
||||||
|
cursor: row-resize;
|
||||||
|
margin: -2px 0;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
@@ -7,6 +7,7 @@ export { useDataSource } from './composables/useDataSource'
|
|||||||
// Components
|
// Components
|
||||||
export { default as Panel } from './components/Panel.vue'
|
export { default as Panel } from './components/Panel.vue'
|
||||||
export { default as LayoutGrid } from './components/LayoutGrid.vue'
|
export { default as LayoutGrid } from './components/LayoutGrid.vue'
|
||||||
|
export { default as ResizeHandle } from './components/ResizeHandle.vue'
|
||||||
|
|
||||||
// Renderers
|
// Renderers
|
||||||
export { default as LogRenderer } from './renderers/LogRenderer.vue'
|
export { default as LogRenderer } from './renderers/LogRenderer.vue'
|
||||||
|
|||||||
@@ -76,6 +76,7 @@ const sorted = computed(() => {
|
|||||||
table {
|
table {
|
||||||
width: 100%;
|
width: 100%;
|
||||||
border-collapse: collapse;
|
border-collapse: collapse;
|
||||||
|
table-layout: fixed;
|
||||||
}
|
}
|
||||||
|
|
||||||
th {
|
th {
|
||||||
@@ -89,7 +90,6 @@ th {
|
|||||||
border-bottom: var(--panel-border);
|
border-bottom: var(--panel-border);
|
||||||
cursor: pointer;
|
cursor: pointer;
|
||||||
user-select: none;
|
user-select: none;
|
||||||
white-space: nowrap;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
th:hover {
|
th:hover {
|
||||||
@@ -104,7 +104,10 @@ th:hover {
|
|||||||
td {
|
td {
|
||||||
padding: var(--space-1) var(--space-3);
|
padding: var(--space-1) var(--space-3);
|
||||||
border-bottom: 1px solid var(--surface-3);
|
border-bottom: 1px solid var(--surface-3);
|
||||||
white-space: nowrap;
|
white-space: normal;
|
||||||
|
word-break: break-word;
|
||||||
|
overflow: hidden;
|
||||||
|
text-overflow: ellipsis;
|
||||||
}
|
}
|
||||||
|
|
||||||
tr:hover td {
|
tr:hover td {
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ const props = withDefaults(defineProps<{
|
|||||||
})
|
})
|
||||||
|
|
||||||
const container = ref<HTMLElement | null>(null)
|
const container = ref<HTMLElement | null>(null)
|
||||||
|
const zoomed = ref(false)
|
||||||
let chart: uPlot | null = null
|
let chart: uPlot | null = null
|
||||||
|
|
||||||
function buildOpts(): uPlot.Options {
|
function buildOpts(): uPlot.Options {
|
||||||
@@ -40,25 +41,66 @@ function buildOpts(): uPlot.Options {
|
|||||||
height: container.value?.clientHeight ?? 200,
|
height: container.value?.clientHeight ?? 200,
|
||||||
series: seriesOpts,
|
series: seriesOpts,
|
||||||
axes: [
|
axes: [
|
||||||
{ stroke: '#555568', grid: { stroke: '#2e2e3822' } },
|
{
|
||||||
{ stroke: '#555568', grid: { stroke: '#2e2e3822' } },
|
stroke: '#555568',
|
||||||
|
grid: { stroke: '#2e2e3822' },
|
||||||
|
size: 40,
|
||||||
|
font: '10px monospace',
|
||||||
|
ticks: { size: 3 },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
stroke: '#555568',
|
||||||
|
grid: { stroke: '#2e2e3822' },
|
||||||
|
size: 35,
|
||||||
|
font: '10px monospace',
|
||||||
|
ticks: { size: 3 },
|
||||||
|
},
|
||||||
],
|
],
|
||||||
cursor: { show: true },
|
cursor: { show: true },
|
||||||
legend: { show: true },
|
legend: { show: true, live: false },
|
||||||
|
padding: [8, 8, 0, 0],
|
||||||
|
hooks: {
|
||||||
|
setScale: [(_self: uPlot, scaleKey: string) => {
|
||||||
|
if (scaleKey === 'x') zoomed.value = true
|
||||||
|
}],
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function resetZoom() {
|
||||||
|
if (!chart) return
|
||||||
|
const data = chart.data
|
||||||
|
if (data && data[0] && data[0].length > 0) {
|
||||||
|
const min = data[0][0]
|
||||||
|
const max = data[0][data[0].length - 1]
|
||||||
|
chart.setScale('x', { min, max })
|
||||||
|
}
|
||||||
|
zoomed.value = false
|
||||||
|
}
|
||||||
|
|
||||||
|
function getLegendHeight(): number {
|
||||||
|
if (!container.value) return 0
|
||||||
|
const legend = container.value.querySelector('.u-legend') as HTMLElement | null
|
||||||
|
return legend ? legend.offsetHeight : 0
|
||||||
|
}
|
||||||
|
|
||||||
function createChart() {
|
function createChart() {
|
||||||
if (!container.value) return
|
if (!container.value) return
|
||||||
if (chart) chart.destroy()
|
if (chart) chart.destroy()
|
||||||
chart = new uPlot(buildOpts(), props.data, container.value)
|
chart = new uPlot(buildOpts(), props.data, container.value)
|
||||||
|
// Refit after legend renders
|
||||||
|
nextTick(() => resize())
|
||||||
}
|
}
|
||||||
|
|
||||||
function resize() {
|
function resize() {
|
||||||
if (!chart || !container.value) return
|
if (!chart || !container.value) return
|
||||||
|
const legendH = getLegendHeight()
|
||||||
|
const availableH = container.value.clientHeight
|
||||||
|
// uPlot height = canvas height (chart sets total = canvas + legend)
|
||||||
|
const chartH = Math.max(60, availableH - legendH)
|
||||||
chart.setSize({
|
chart.setSize({
|
||||||
width: container.value.clientWidth,
|
width: container.value.clientWidth,
|
||||||
height: container.value.clientHeight,
|
height: chartH,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,19 +125,74 @@ onMounted(() => {
|
|||||||
</script>
|
</script>
|
||||||
|
|
||||||
<template>
|
<template>
|
||||||
|
<div class="timeseries-wrapper">
|
||||||
|
<button v-if="zoomed" class="reset-zoom" @click="resetZoom" title="Reset zoom">⟲</button>
|
||||||
<div ref="container" class="timeseries-renderer" />
|
<div ref="container" class="timeseries-renderer" />
|
||||||
|
</div>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<style scoped>
|
<style scoped>
|
||||||
|
.timeseries-wrapper {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
|
||||||
|
.reset-zoom {
|
||||||
|
position: absolute;
|
||||||
|
top: 4px;
|
||||||
|
right: 4px;
|
||||||
|
z-index: 20;
|
||||||
|
background: var(--surface-2);
|
||||||
|
border: 1px solid var(--surface-3);
|
||||||
|
border-radius: 4px;
|
||||||
|
color: var(--text-secondary);
|
||||||
|
font-size: 14px;
|
||||||
|
width: 24px;
|
||||||
|
height: 24px;
|
||||||
|
cursor: pointer;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
opacity: 0.7;
|
||||||
|
transition: opacity 0.15s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.reset-zoom:hover {
|
||||||
|
opacity: 1;
|
||||||
|
color: var(--text-primary);
|
||||||
|
}
|
||||||
|
|
||||||
.timeseries-renderer {
|
.timeseries-renderer {
|
||||||
width: 100%;
|
width: 100%;
|
||||||
height: 100%;
|
height: 100%;
|
||||||
min-height: 150px;
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* uPlot creates a .u-wrap for canvas + a .u-legend below it */
|
||||||
|
.timeseries-renderer :deep(.u-wrap) {
|
||||||
|
flex: 1;
|
||||||
|
min-height: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.timeseries-renderer :deep(.u-legend) {
|
||||||
|
flex-shrink: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
.timeseries-renderer :deep(.u-legend) {
|
.timeseries-renderer :deep(.u-legend) {
|
||||||
font-family: var(--font-mono);
|
font-family: var(--font-mono);
|
||||||
font-size: var(--font-size-sm);
|
font-size: 10px;
|
||||||
color: var(--text-secondary);
|
color: var(--text-secondary);
|
||||||
|
padding: 2px 0;
|
||||||
|
display: flex;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
gap: 0 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.timeseries-renderer :deep(.u-legend .u-series) {
|
||||||
|
display: inline-flex;
|
||||||
|
padding: 0;
|
||||||
}
|
}
|
||||||
</style>
|
</style>
|
||||||
|
|||||||
Reference in New Issue
Block a user