commit 08b67f2bb7 (parent dfa3c12514) 2026-03-26 02:54:56 -03:00
21 changed files with 1622 additions and 16 deletions


@@ -0,0 +1,11 @@
---
name: agent_sdk_future
description: Claude Agent SDK for general mpr tasks (not vision provider), uses OAuth not API keys
type: project
---
Claude Agent SDK (`claude-agent-sdk`) is for future general-purpose tasks in mpr, NOT for the cloud vision provider.
**Why:** The Agent SDK uses Claude Code CLI's OAuth (browser login, no API keys) and is designed for agentic tasks (file read/edit, bash, web search). The vision provider needs raw API calls with image payloads — use the `anthropic` SDK with `ANTHROPIC_API_KEY` for that.
**How to apply:** When adding Claude-powered automation to mpr (e.g., log analysis, config suggestions, code review on pipeline changes), use the Agent SDK. For the cloud LLM escalation stage (image crops → brand ID), keep using the `anthropic` SDK with API key auth.
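A minimal sketch of the vision-provider path for contrast, assuming `ANTHROPIC_API_KEY` is set and `image_b64` holds a base64-encoded JPEG crop (both placeholders here); this is the raw-API style the Agent SDK deliberately does not cover:

```python
# Raw API call with an API key (the vision-provider path, NOT the Agent SDK).
from anthropic import Anthropic

client = Anthropic()  # reads ANTHROPIC_API_KEY from the environment
message = client.messages.create(
    model="claude-sonnet-4-20250514",
    max_tokens=150,
    messages=[{
        "role": "user",
        "content": [
            {"type": "image",
             "source": {"type": "base64", "media_type": "image/jpeg", "data": image_b64}},
            {"type": "text", "text": "Which brand is shown in this crop?"},
        ],
    }],
)
print(message.content[0].text)
```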


@@ -35,5 +35,29 @@ AWS_REGION=us-east-1
AWS_ACCESS_KEY_ID=minioadmin
AWS_SECRET_ACCESS_KEY=minioadmin
# Inference
INFERENCE_URL=http://mcrndeb:8000
# Cloud LLM (detection pipeline escalation)
# Set CLOUD_LLM_PROVIDER to: groq, gemini, claude, openai
CLOUD_LLM_PROVIDER=groq
# Groq (default, free tier)
GROQ_API_KEY=
GROQ_MODEL=meta-llama/llama-4-scout-17b-16e-instruct
# Gemini
#GEMINI_API_KEY=
#GEMINI_MODEL=gemini-2.0-flash
# Claude (uses anthropic SDK)
#ANTHROPIC_API_KEY=
#CLAUDE_MODEL=claude-sonnet-4-20250514
# OpenAI-compatible
#OPENAI_API_KEY=
#OPENAI_MODEL=gpt-4o-mini
#OPENAI_BASE_URL=https://api.openai.com/v1
# Vite
VITE_ALLOWED_HOSTS=your-domain.local


@@ -20,6 +20,9 @@ from detect.stages.scene_filter import scene_filter
from detect.stages.yolo_detector import detect_objects
from detect.stages.ocr_stage import run_ocr
from detect.stages.brand_resolver import resolve_brands
from detect.stages.vlm_local import escalate_vlm
from detect.stages.vlm_cloud import escalate_cloud
from detect.stages.aggregator import compile_report
from detect.tracing import trace_node, flush as flush_traces

INFERENCE_URL = os.environ.get("INFERENCE_URL")  # None = local mode
@@ -158,43 +161,77 @@ def node_escalate_vlm(state: DetectState) -> dict:
    _emit_transition(state, "escalate_vlm", "running")
    with trace_node(state, "escalate_vlm") as span:
        profile = _get_profile(state)
        candidates = state.get("unresolved_candidates", [])
        job_id = state.get("job_id")
        vlm_matched, still_unresolved = escalate_vlm(
            candidates,
            vlm_prompt_fn=profile.vlm_prompt,
            inference_url=INFERENCE_URL,
            content_type=profile.name,
            job_id=job_id,
        )
        stats = state.get("stats", PipelineStats())
        stats.regions_escalated_to_local_vlm = len(candidates)
        span.set_output({"candidates": len(candidates), "matched": len(vlm_matched),
                         "still_unresolved": len(still_unresolved)})
        existing = state.get("detections", [])
    _emit_transition(state, "escalate_vlm", "done")
    return {
        "detections": existing + vlm_matched,
        "unresolved_candidates": still_unresolved,
        "stats": stats,
    }


def node_escalate_cloud(state: DetectState) -> dict:
    _emit_transition(state, "escalate_cloud", "running")
    with trace_node(state, "escalate_cloud") as span:
        profile = _get_profile(state)
        candidates = state.get("unresolved_candidates", [])
        job_id = state.get("job_id")
        stats = state.get("stats", PipelineStats())
        cloud_matched = escalate_cloud(
            candidates,
            vlm_prompt_fn=profile.vlm_prompt,
            stats=stats,
            content_type=profile.name,
            job_id=job_id,
        )
        span.set_output({"candidates": len(candidates), "matched": len(cloud_matched),
                         "cloud_calls": stats.cloud_llm_calls,
                         "cost_usd": stats.estimated_cloud_cost_usd})
        existing = state.get("detections", [])
    _emit_transition(state, "escalate_cloud", "done")
    return {"detections": existing + cloud_matched, "stats": stats}


def node_compile_report(state: DetectState) -> dict:
    _emit_transition(state, "compile_report", "running")
    with trace_node(state, "compile_report") as span:
        profile = _get_profile(state)
        detections = state.get("detections", [])
        stats = state.get("stats", PipelineStats())
        job_id = state.get("job_id")
        report = compile_report(
            detections=detections,
            stats=stats,
            video_source=state.get("video_path", ""),
            content_type=profile.name,
            job_id=job_id,
        )
        span.set_output({"brands": len(report.brands), "detections": len(report.timeline)})
    flush_traces()

detect/providers/__init__.py (new file)

@@ -0,0 +1,58 @@
"""
Cloud LLM provider registry.
Select provider via CLOUD_LLM_PROVIDER env var.
Each provider reads its own env vars for auth/config.
CLOUD_LLM_PROVIDER=groq → GROQ_API_KEY, GROQ_MODEL, GROQ_BASE_URL
CLOUD_LLM_PROVIDER=gemini → GEMINI_API_KEY, GEMINI_MODEL
CLOUD_LLM_PROVIDER=openai → OPENAI_API_KEY, OPENAI_MODEL, OPENAI_BASE_URL
CLOUD_LLM_PROVIDER=claude → ANTHROPIC_API_KEY, CLAUDE_MODEL
"""
from __future__ import annotations
import os
from .base import CloudProvider, ProviderResponse
from .groq import GroqProvider
from .gemini import GeminiProvider
from .openai_compat import OpenAICompatProvider
from .claude import ClaudeProvider
PROVIDERS: dict[str, type] = {
"groq": GroqProvider,
"gemini": GeminiProvider,
"openai": OpenAICompatProvider,
"claude": ClaudeProvider,
}
_cached: CloudProvider | None = None
def get_provider() -> CloudProvider:
"""Get the configured cloud provider (cached after first call)."""
global _cached
if _cached is not None:
return _cached
name = os.environ.get("CLOUD_LLM_PROVIDER", "groq")
cls = PROVIDERS.get(name)
if cls is None:
raise ValueError(f"Unknown provider: {name!r}. Options: {list(PROVIDERS)}")
_cached = cls()
return _cached
def has_api_key() -> bool:
"""Check if the configured provider has an API key set."""
name = os.environ.get("CLOUD_LLM_PROVIDER", "groq")
key_map = {
"groq": "GROQ_API_KEY",
"gemini": "GEMINI_API_KEY",
"openai": "OPENAI_API_KEY",
"claude": "ANTHROPIC_API_KEY",
}
env_var = key_map.get(name, "")
return bool(os.environ.get(env_var, ""))
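As a usage sketch of the registry from the pipeline side (the `image_b64` value is a placeholder; everything else comes from this module):

```python
import os
from detect.providers import get_provider, has_api_key

os.environ.setdefault("CLOUD_LLM_PROVIDER", "groq")
image_b64 = "<base64-encoded JPEG crop>"  # placeholder

if has_api_key():
    provider = get_provider()  # instance is cached after the first call
    resp = provider.call(
        image_b64,
        "Identify the brand. Respond with: brand, confidence (0-1), reasoning.",
    )
    print(provider.name, resp.answer, resp.total_tokens)
else:
    print("No API key set; cloud escalation would be skipped.")
```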

detect/providers/base.py (new file)

@@ -0,0 +1,36 @@
"""Cloud LLM provider protocol and model metadata."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol
@dataclass
class ModelInfo:
"""Metadata for a cloud LLM model."""
id: str
vision: bool = True
cost_per_input_token: float = 0.0
cost_per_output_token: float = 0.0
max_output_tokens: int = 4096
notes: str = ""
@dataclass
class ProviderResponse:
answer: str
total_tokens: int = 0
class CloudProvider(Protocol):
"""
Interface for cloud LLM providers.
Each provider handles its own auth, payload format, and response parsing.
The pipeline only calls call() and reads the response.
"""
name: str
models: dict[str, ModelInfo]
def call(self, image_b64: str, prompt: str) -> ProviderResponse: ...
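Since `CloudProvider` is a structural `Protocol`, a new backend only needs matching attributes plus an entry in `PROVIDERS`. A toy sketch (the echo behavior is illustrative, not a real backend); note that `escalate_cloud` also reads `provider.model` for its cost lookup, even though the Protocol doesn't declare it:

```python
from detect.providers.base import ModelInfo, ProviderResponse

class EchoProvider:
    """Illustrative no-op provider that satisfies CloudProvider structurally."""
    name = "echo"
    models = {"echo-1": ModelInfo(id="echo-1", notes="toy model, zero cost")}

    def __init__(self):
        self.model = "echo-1"  # read by escalate_cloud for the cost lookup

    def call(self, image_b64: str, prompt: str) -> ProviderResponse:
        # A real backend would POST image_b64 + prompt to its API here.
        return ProviderResponse(answer="Unknown, 0.0, echo provider", total_tokens=0)
```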

detect/providers/claude.py (new file)

@@ -0,0 +1,73 @@
"""Anthropic Claude provider — uses the official SDK."""
from __future__ import annotations
import logging
import os
from .base import ModelInfo, ProviderResponse
logger = logging.getLogger(__name__)
# Claude-specific env vars
# ANTHROPIC_API_KEY is read by the SDK automatically
CLAUDE_MODEL = os.environ.get("CLAUDE_MODEL", "claude-sonnet-4-20250514")
MODELS = {
"claude-sonnet-4-20250514": ModelInfo(
id="claude-sonnet-4-20250514",
vision=True,
cost_per_input_token=0.000003,
cost_per_output_token=0.000015,
notes="Best balance of quality/cost with vision",
),
"claude-haiku-4-5-20251001": ModelInfo(
id="claude-haiku-4-5-20251001",
vision=True,
cost_per_input_token=0.0000008,
cost_per_output_token=0.000004,
notes="Fastest, cheapest, good for simple brand ID",
),
"claude-opus-4-6": ModelInfo(
id="claude-opus-4-6",
vision=True,
cost_per_input_token=0.000015,
cost_per_output_token=0.000075,
notes="Highest quality, use for ambiguous cases",
),
}
class ClaudeProvider:
name = "claude"
models = MODELS
def __init__(self):
from anthropic import Anthropic
self.client = Anthropic()
self.model = CLAUDE_MODEL
def call(self, image_b64: str, prompt: str) -> ProviderResponse:
message = self.client.messages.create(
model=self.model,
max_tokens=150,
messages=[{
"role": "user",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": image_b64,
},
},
{"type": "text", "text": prompt},
],
}],
)
answer = message.content[0].text.strip()
total_tokens = message.usage.input_tokens + message.usage.output_tokens
return ProviderResponse(answer=answer, total_tokens=total_tokens)

detect/providers/gemini.py (new file)

@@ -0,0 +1,75 @@
"""Google Gemini provider — native REST API, not OpenAI-compatible."""
from __future__ import annotations
import logging
import os
import requests
from .base import ModelInfo, ProviderResponse
logger = logging.getLogger(__name__)
# Gemini-specific env vars
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-2.0-flash")
MODELS = {
"gemini-2.0-flash": ModelInfo(
id="gemini-2.0-flash",
vision=True,
cost_per_input_token=0.0000001,
cost_per_output_token=0.0000004,
notes="Fast, cheap, good vision",
),
"gemini-2.0-pro": ModelInfo(
id="gemini-2.0-pro",
vision=True,
cost_per_input_token=0.00000125,
cost_per_output_token=0.000005,
notes="Higher quality, slower",
),
"gemini-1.5-flash": ModelInfo(
id="gemini-1.5-flash",
vision=True,
cost_per_input_token=0.000000075,
cost_per_output_token=0.0000003,
notes="Cheapest option",
),
}
class GeminiProvider:
name = "gemini"
models = MODELS
def __init__(self):
self.api_key = GEMINI_API_KEY
self.model = GEMINI_MODEL
self.endpoint = (
f"https://generativelanguage.googleapis.com/v1beta/models/"
f"{self.model}:generateContent"
)
def call(self, image_b64: str, prompt: str) -> ProviderResponse:
payload = {
"contents": [{
"parts": [
{"text": prompt},
{"inline_data": {"mime_type": "image/jpeg", "data": image_b64}},
],
}],
"generationConfig": {"maxOutputTokens": 150},
}
url = f"{self.endpoint}?key={self.api_key}"
resp = requests.post(url, json=payload, timeout=30)
resp.raise_for_status()
data = resp.json()
answer = data["candidates"][0]["content"]["parts"][0]["text"].strip()
usage = data.get("usageMetadata", {})
total_tokens = usage.get("totalTokenCount", 0)
return ProviderResponse(answer=answer, total_tokens=total_tokens)

detect/providers/groq.py (new file)

@@ -0,0 +1,66 @@
"""Groq cloud provider — OpenAI-compatible API with vision."""
from __future__ import annotations
import logging
import os
import requests
from .base import ModelInfo, ProviderResponse
logger = logging.getLogger(__name__)
# Groq-specific env vars
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
GROQ_BASE_URL = os.environ.get("GROQ_BASE_URL", "https://api.groq.com/openai/v1")
GROQ_MODEL = os.environ.get("GROQ_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct")
MODELS = {
"meta-llama/llama-4-scout-17b-16e-instruct": ModelInfo(
id="meta-llama/llama-4-scout-17b-16e-instruct",
vision=True,
cost_per_input_token=0.0,
cost_per_output_token=0.0,
notes="Llama 4 Scout, only vision model on Groq free tier",
),
}
class GroqProvider:
name = "groq"
models = MODELS
def __init__(self):
self.api_key = GROQ_API_KEY
self.base_url = GROQ_BASE_URL
self.model = GROQ_MODEL
self.endpoint = f"{self.base_url.rstrip('/')}/chat/completions"
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
def call(self, image_b64: str, prompt: str) -> ProviderResponse:
payload = {
"model": self.model,
"messages": [{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {
"url": f"data:image/jpeg;base64,{image_b64}",
}},
],
}],
"max_tokens": 150,
}
resp = requests.post(self.endpoint, headers=self.headers, json=payload, timeout=30)
resp.raise_for_status()
data = resp.json()
answer = data["choices"][0]["message"]["content"].strip()
total_tokens = data.get("usage", {}).get("total_tokens", 0)
return ProviderResponse(answer=answer, total_tokens=total_tokens)

detect/providers/openai_compat.py (new file)

@@ -0,0 +1,73 @@
"""Generic OpenAI-compatible provider (OpenAI, Together, etc.)."""
from __future__ import annotations
import logging
import os
import requests
from .base import ModelInfo, ProviderResponse
logger = logging.getLogger(__name__)
# OpenAI-compat specific env vars
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4o-mini")
MODELS = {
"gpt-4o-mini": ModelInfo(
id="gpt-4o-mini",
vision=True,
cost_per_input_token=0.00000015,
cost_per_output_token=0.0000006,
notes="Cheap, fast, decent vision",
),
"gpt-4o": ModelInfo(
id="gpt-4o",
vision=True,
cost_per_input_token=0.0000025,
cost_per_output_token=0.00001,
notes="Best OpenAI vision model",
),
}
class OpenAICompatProvider:
name = "openai"
models = MODELS
def __init__(self):
self.api_key = OPENAI_API_KEY
self.base_url = OPENAI_BASE_URL
self.model = OPENAI_MODEL
self.endpoint = f"{self.base_url.rstrip('/')}/chat/completions"
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
def call(self, image_b64: str, prompt: str) -> ProviderResponse:
payload = {
"model": self.model,
"messages": [{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {
"url": f"data:image/jpeg;base64,{image_b64}",
}},
],
}],
"max_tokens": 150,
}
resp = requests.post(self.endpoint, headers=self.headers, json=payload, timeout=30)
resp.raise_for_status()
data = resp.json()
answer = data["choices"][0]["message"]["content"].strip()
total_tokens = data.get("usage", {}).get("total_tokens", 0)
return ProviderResponse(answer=answer, total_tokens=total_tokens)

detect/stages/aggregator.py (new file)

@@ -0,0 +1,116 @@
"""
Stage 8 — Report compilation
Groups all detections by brand, merges contiguous appearances,
and builds the final DetectionReport.
"""
from __future__ import annotations
import logging
from detect import emit
from detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats
logger = logging.getLogger(__name__)
def _merge_contiguous(detections: list[BrandDetection], gap_threshold: float = 2.0) -> list[BrandDetection]:
"""
Merge detections of the same brand that are close in time.
If two detections of the same brand are within gap_threshold seconds,
they're merged into one detection spanning the full range.
"""
if not detections:
return []
sorted_dets = sorted(detections, key=lambda d: (d.brand, d.timestamp))
merged: list[BrandDetection] = []
current = sorted_dets[0]
for det in sorted_dets[1:]:
if (det.brand == current.brand
and det.timestamp <= current.timestamp + current.duration + gap_threshold):
end = max(current.timestamp + current.duration,
det.timestamp + det.duration)
current = BrandDetection(
brand=current.brand,
timestamp=current.timestamp,
duration=end - current.timestamp,
confidence=max(current.confidence, det.confidence),
source=current.source,
bbox=current.bbox,
frame_ref=current.frame_ref,
content_type=current.content_type,
)
else:
merged.append(current)
current = det
merged.append(current)
return merged
def compile_report(
detections: list[BrandDetection],
stats: PipelineStats,
video_source: str = "",
content_type: str = "",
duration_seconds: float = 0.0,
job_id: str | None = None,
) -> DetectionReport:
"""
Build the final detection report from all accumulated detections.
Merges contiguous detections, computes per-brand stats,
and emits the job_complete event.
"""
merged = _merge_contiguous(detections)
brands: dict[str, BrandStats] = {}
for d in merged:
if d.brand not in brands:
brands[d.brand] = BrandStats()
s = brands[d.brand]
s.total_appearances += 1
s.total_screen_time += d.duration
s.avg_confidence = (
(s.avg_confidence * (s.total_appearances - 1) + d.confidence)
/ s.total_appearances
)
if s.first_seen == 0.0 or d.timestamp < s.first_seen:
s.first_seen = d.timestamp
if d.timestamp > s.last_seen:
s.last_seen = d.timestamp
report = DetectionReport(
video_source=video_source,
content_type=content_type,
duration_seconds=duration_seconds,
brands=brands,
timeline=sorted(merged, key=lambda d: d.timestamp),
pipeline_stats=stats,
)
emit.log(job_id, "Aggregator", "INFO",
f"Report: {len(brands)} brands, {len(merged)} detections "
f"(merged from {len(detections)} raw)")
emit.job_complete(job_id, {
"video_source": report.video_source,
"content_type": report.content_type,
"duration_seconds": report.duration_seconds,
"brands": {
k: {
"total_appearances": v.total_appearances,
"total_screen_time": v.total_screen_time,
"avg_confidence": round(v.avg_confidence, 3),
"first_seen": v.first_seen,
"last_seen": v.last_seen,
}
for k, v in brands.items()
},
})
return report
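A quick worked example of the merge rule (values hypothetical, mirroring the unit test later in this commit): two Nike hits at t=1.0s and t=1.3s with 0.5s durations fall inside the 2.0s gap and collapse into one span from 1.0s to 1.8s, while a hit at t=5.0s stays separate:

```python
from detect.models import BrandDetection
from detect.stages.aggregator import _merge_contiguous

dets = [
    BrandDetection(brand="Nike", timestamp=1.0, duration=0.5,
                   confidence=0.95, source="ocr", content_type="soccer_broadcast"),
    BrandDetection(brand="Nike", timestamp=1.3, duration=0.5,
                   confidence=0.90, source="ocr", content_type="soccer_broadcast"),
    BrandDetection(brand="Nike", timestamp=5.0, duration=0.5,
                   confidence=0.92, source="ocr", content_type="soccer_broadcast"),
]
merged = _merge_contiguous(dets, gap_threshold=2.0)
assert len(merged) == 2
assert round(merged[0].duration, 6) == 0.8  # 1.0s start, 1.8s end
```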

detect/stages/vlm_cloud.py (new file)

@@ -0,0 +1,168 @@
"""
Stage 7 — Cloud LLM escalation
Last resort for crops the local VLM couldn't resolve.
Provider-agnostic — switch via CLOUD_LLM_PROVIDER env var.
Each provider has its own file under detect/providers/.
Tracks token usage and cost.
"""
from __future__ import annotations
import base64
import io
import logging
import numpy as np
from PIL import Image
from detect import emit
from detect.models import BrandDetection, PipelineStats, TextCandidate
from detect.profiles.base import CropContext
from detect.providers import get_provider, has_api_key
logger = logging.getLogger(__name__)
ESTIMATED_TOKENS_PER_CROP = 500
def _encode_crop(crop: np.ndarray) -> str:
img = Image.fromarray(crop)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=85)
return base64.b64encode(buf.getvalue()).decode()
def _crop_image(candidate: TextCandidate) -> np.ndarray:
frame = candidate.frame
box = candidate.bbox
h, w = frame.image.shape[:2]
x1 = max(0, box.x)
y1 = max(0, box.y)
x2 = min(w, box.x + box.w)
y2 = min(h, box.y + box.h)
return frame.image[y1:y2, x1:x2]
def _parse_response(answer: str, total_tokens: int) -> dict:
"""Parse LLM free-text response into structured output."""
parts = [p.strip() for p in answer.split(",", 2)]
brand = parts[0] if parts else ""
confidence = 0.5
reasoning = answer
if len(parts) >= 2:
try:
confidence = float(parts[1])
confidence = max(0.0, min(1.0, confidence))
except ValueError:
pass
if len(parts) >= 3:
reasoning = parts[2]
return {
"brand": brand,
"confidence": confidence,
"reasoning": reasoning,
"tokens": total_tokens or ESTIMATED_TOKENS_PER_CROP,
}
def _call_cloud_api(image_b64: str, prompt: str) -> dict:
"""Route to the configured provider and parse the response."""
provider = get_provider()
result = provider.call(image_b64, prompt)
return _parse_response(result.answer, result.total_tokens)
def escalate_cloud(
candidates: list[TextCandidate],
vlm_prompt_fn,
stats: PipelineStats,
min_confidence: float = 0.4,
content_type: str = "",
job_id: str | None = None,
) -> list[BrandDetection]:
"""
Send remaining unresolved crops to cloud LLM.
Provider is selected via CLOUD_LLM_PROVIDER env var (groq, gemini, openai).
Updates stats with call count and cost.
"""
if not candidates:
return []
if not has_api_key():
emit.log(job_id, "CloudLLM", "WARNING",
f"No API key set for cloud provider, skipping {len(candidates)} crops")
return []
provider = get_provider()
emit.log(job_id, "CloudLLM", "INFO",
f"Escalating {len(candidates)} crops to {provider.name}")
matched: list[BrandDetection] = []
total_cost = 0.0
for candidate in candidates:
crop = _crop_image(candidate)
if crop.size == 0:
continue
crop_context = CropContext(
image=b"",
surrounding_text=candidate.text,
position_hint=f"frame {candidate.frame.sequence}",
)
prompt = vlm_prompt_fn(crop_context)
image_b64 = _encode_crop(crop)
try:
result = _call_cloud_api(image_b64, prompt)
except Exception as e:
logger.warning("Cloud LLM failed for '%s': %s", candidate.text, e)
continue
stats.cloud_llm_calls += 1
model_info = provider.models.get(provider.model)
cost_per_token = model_info.cost_per_input_token if model_info else 0.00001
call_cost = result["tokens"] * cost_per_token
total_cost += call_cost
brand = result["brand"]
confidence = result["confidence"]
if brand and confidence >= min_confidence:
detection = BrandDetection(
brand=brand,
timestamp=candidate.frame.timestamp,
duration=0.5,
confidence=confidence,
source="cloud_llm",
bbox=candidate.bbox,
frame_ref=candidate.frame.sequence,
content_type=content_type,
)
matched.append(detection)
emit.detection(
job_id,
brand=brand,
confidence=confidence,
source="cloud_llm",
timestamp=candidate.frame.timestamp,
content_type=content_type,
frame_ref=candidate.frame.sequence,
)
stats.estimated_cloud_cost_usd += total_cost
stats.regions_escalated_to_cloud_llm = len(candidates)
emit.log(job_id, "CloudLLM", "INFO",
f"Cloud resolved {len(matched)}/{len(candidates)}"
f"cost ${total_cost:.4f} ({stats.cloud_llm_calls} calls total)")
return matched
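Both escalation stages rely on the same free-text contract from the prompt, `brand, confidence (0-1), reasoning`; a small sketch of how `_parse_response` degrades when the model is sloppy (token counts hypothetical):

```python
from detect.stages.vlm_cloud import _parse_response

clean = _parse_response("Heineken, 0.75, green star logo on the boards", 300)
assert clean["brand"] == "Heineken" and clean["confidence"] == 0.75

loose = _parse_response("Probably Adidas", 0)  # no confidence, no token count
assert loose["confidence"] == 0.5   # default confidence
assert loose["tokens"] == 500       # ESTIMATED_TOKENS_PER_CROP fallback
```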

detect/stages/vlm_local.py (new file)

@@ -0,0 +1,124 @@
"""
Stage 6 — Local VLM escalation (moondream2)
Processes unresolved text candidates by sending crop images + prompt
to the local VLM on the inference server. Produces BrandDetection
objects for crops the VLM can identify.
"""
from __future__ import annotations
import logging
import numpy as np
from detect import emit
from detect.models import BrandDetection, TextCandidate
from detect.profiles.base import CropContext
logger = logging.getLogger(__name__)
def _crop_image(candidate: TextCandidate) -> np.ndarray:
frame = candidate.frame
box = candidate.bbox
h, w = frame.image.shape[:2]
x1 = max(0, box.x)
y1 = max(0, box.y)
x2 = min(w, box.x + box.w)
y2 = min(h, box.y + box.h)
return frame.image[y1:y2, x1:x2]
def escalate_vlm(
candidates: list[TextCandidate],
vlm_prompt_fn,
inference_url: str | None = None,
min_confidence: float = 0.5,
content_type: str = "",
job_id: str | None = None,
) -> tuple[list[BrandDetection], list[TextCandidate]]:
"""
Send unresolved crops to local VLM for brand identification.
Returns:
- matched: BrandDetections the VLM confirmed
- still_unresolved: candidates the VLM couldn't resolve (→ cloud escalation)
"""
if not candidates:
return [], []
emit.log(job_id, "VLMLocal", "INFO",
f"Processing {len(candidates)} unresolved crops with moondream2")
matched: list[BrandDetection] = []
still_unresolved: list[TextCandidate] = []
if inference_url:
from detect.inference import InferenceClient
client = InferenceClient(base_url=inference_url)
for candidate in candidates:
crop = _crop_image(candidate)
if crop.size == 0:
still_unresolved.append(candidate)
continue
crop_context = CropContext(
image=b"", # not used for prompt generation
surrounding_text=candidate.text,
position_hint=f"frame {candidate.frame.sequence}",
)
prompt = vlm_prompt_fn(crop_context)
try:
if inference_url:
result = client.vlm(image=crop, prompt=prompt)
brand = result.brand
confidence = result.confidence
reasoning = result.reasoning
else:
brand, confidence, reasoning = _vlm_local(crop, prompt)
except Exception as e:
logger.warning("VLM failed for candidate '%s': %s", candidate.text, e)
still_unresolved.append(candidate)
continue
if brand and confidence >= min_confidence:
detection = BrandDetection(
brand=brand,
timestamp=candidate.frame.timestamp,
duration=0.5,
confidence=confidence,
source="local_vlm",
bbox=candidate.bbox,
frame_ref=candidate.frame.sequence,
content_type=content_type,
)
matched.append(detection)
emit.detection(
job_id,
brand=brand,
confidence=confidence,
source="local_vlm",
timestamp=candidate.frame.timestamp,
content_type=content_type,
frame_ref=candidate.frame.sequence,
)
logger.debug("VLM matched: %s (%.2f) — %s", brand, confidence, reasoning)
else:
still_unresolved.append(candidate)
emit.log(job_id, "VLMLocal", "INFO",
f"VLM resolved {len(matched)}, unresolved {len(still_unresolved)} → cloud")
return matched, still_unresolved
def _vlm_local(crop: np.ndarray, prompt: str) -> tuple[str, float, str]:
"""Run moondream2 in-process (single-box mode)."""
from gpu.models.vlm import query
result = query(crop, prompt)
return result["brand"], result["confidence"], result["reasoning"]

gpu/models/vlm.py (new file)

@@ -0,0 +1,100 @@
"""moondream2 visual language model wrapper."""
from __future__ import annotations
import logging
from models import registry
from config import get_config
logger = logging.getLogger(__name__)
_MODEL_KEY = "vlm_moondream2"
def _load():
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
device = get_config().get("device", "auto")
if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info("Loading moondream2 (device=%s)...", device)
model_id = "vikhyatk/moondream2"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
dtype = torch.float16 if "cuda" in device else torch.float32
model = AutoModelForCausalLM.from_pretrained(
model_id,
trust_remote_code=True,
dtype=dtype,
device_map=device,
)
wrapper = {"model": model, "tokenizer": tokenizer}
registry.put(_MODEL_KEY, wrapper)
logger.info("moondream2 loaded")
return wrapper
def _get():
wrapper = registry.get(_MODEL_KEY)
if wrapper is None:
wrapper = _load()
return wrapper
def query(image, prompt: str) -> dict:
"""
Query moondream2 with an image crop and prompt.
Returns {"brand": str, "confidence": float, "reasoning": str}
"""
from PIL import Image as PILImage
wrapper = _get()
model = wrapper["model"]
tokenizer = wrapper["tokenizer"]
# Convert numpy array to PIL if needed
if not isinstance(image, PILImage.Image):
image = PILImage.fromarray(image)
enc_image = model.encode_image(image)
answer = model.answer_question(enc_image, prompt, tokenizer)
# Parse response — moondream2 returns free text, extract brand + confidence
result = _parse_vlm_response(answer)
return result
def _parse_vlm_response(answer: str) -> dict:
"""
Parse moondream2 free-text response into structured output.
Expected format from prompt: "brand, confidence (0-1), reasoning"
Falls back gracefully if format doesn't match.
"""
answer = answer.strip()
parts = [p.strip() for p in answer.split(",", 2)]
brand = parts[0] if parts else ""
confidence = 0.5
reasoning = answer
if len(parts) >= 2:
try:
confidence = float(parts[1])
confidence = max(0.0, min(1.0, confidence))
except ValueError:
pass
if len(parts) >= 3:
reasoning = parts[2]
return {
"brand": brand,
"confidence": confidence,
"reasoning": reasoning,
}


@@ -19,3 +19,9 @@ ultralytics>=8.0.0
# Install with:
# uv pip install paddlepaddle-gpu==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
paddleocr>=3.0.0
# VLM (moondream2) — uses torch (already installed above)
# Pinned <5: transformers 5.x broke moondream2's custom model code
# (all_tied_weights_keys API change). Also needs accelerate for device_map.
transformers>=4.40.0,<5
accelerate>=0.27.0

gpu/server.py

@@ -25,6 +25,7 @@ from config import get_config, get_device, update_config
from models import registry
from models.yolo import detect as yolo_detect
from models.ocr import ocr as ocr_run
from models.vlm import query as vlm_query

logger = logging.getLogger(__name__)
@@ -72,6 +73,18 @@ class OCRResponse(BaseModel):
    results: list[OCRTextResult]


class VLMRequest(BaseModel):
    image: str
    prompt: str
    model: str | None = None


class VLMResponse(BaseModel):
    brand: str
    confidence: float
    reasoning: str


class ConfigUpdate(BaseModel):
    device: str | None = None
    yolo_model: str | None = None
@@ -170,6 +183,21 @@ def ocr(req: OCRRequest):
    return OCRResponse(results=[OCRTextResult(**r) for r in results])


@app.post("/vlm", response_model=VLMResponse)
def vlm(req: VLMRequest):
    try:
        image = _decode_image(req.image)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Bad image: {e}")
    try:
        result = vlm_query(image, req.prompt)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"VLM failed: {e}")
    return VLMResponse(**result)


if __name__ == "__main__":
    import uvicorn


@@ -29,6 +29,9 @@ strawberry-graphql[fastapi]>=0.311.0
# Observability
langfuse>=2.0.0

# Cloud LLM providers (only needed for cloud escalation stage)
anthropic>=0.40.0

# Testing
pytest>=7.4.0
pytest-django>=4.7.0

tests/detect/manual/test_cloud_provider.py (new file)

@@ -0,0 +1,107 @@
#!/usr/bin/env python3
"""
Test cloud LLM provider with a real API call.
Sends a test image to the configured cloud provider and verifies
the response. Set your provider env vars before running.
Usage:
# Groq (default)
CLOUD_LLM_PROVIDER=groq GROQ_API_KEY=gsk_... python tests/detect/manual/test_cloud_provider.py
# Gemini
CLOUD_LLM_PROVIDER=gemini GEMINI_API_KEY=AIza... python tests/detect/manual/test_cloud_provider.py
# Claude
CLOUD_LLM_PROVIDER=claude ANTHROPIC_API_KEY=sk-ant-... python tests/detect/manual/test_cloud_provider.py
# OpenAI-compatible
CLOUD_LLM_PROVIDER=openai OPENAI_API_KEY=sk-... python tests/detect/manual/test_cloud_provider.py
"""
import base64
import io
import logging
import os
import sys
from pathlib import Path
# Load .env from ctrl/ (same as docker-compose uses)
env_file = Path(__file__).resolve().parents[3] / "ctrl" / ".env"
if env_file.exists():
for line in env_file.read_text().splitlines():
line = line.strip()
if line and not line.startswith("#") and "=" in line:
key, _, value = line.partition("=")
os.environ.setdefault(key.strip(), value.strip())
import numpy as np
from PIL import Image, ImageDraw, ImageFont
sys.path.insert(0, ".")
logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s %(message)s")
logger = logging.getLogger(__name__)
def make_brand_image(text: str, width: int = 300, height: int = 100) -> str:
img = Image.new("RGB", (width, height), "white")
draw = ImageDraw.Draw(img)
try:
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 42)
except (OSError, IOError):
font = ImageFont.load_default()
draw.text((10, 20), text, fill="black", font=font)
buf = io.BytesIO()
img.save(buf, "JPEG")
return base64.b64encode(buf.getvalue()).decode()
def main():
from detect.providers import get_provider, has_api_key, PROVIDERS
provider_name = os.environ.get("CLOUD_LLM_PROVIDER", "groq")
logger.info("Provider: %s", provider_name)
logger.info("Available providers: %s", list(PROVIDERS.keys()))
if not has_api_key():
logger.error("No API key set for provider '%s'", provider_name)
logger.error("Set the appropriate env var (see usage in docstring)")
sys.exit(1)
provider = get_provider()
logger.info("Model: %s", provider.model)
logger.info("Available models: %s", list(provider.models.keys()))
input("\nPress Enter to start...")
prompt = (
"Identify the brand or sponsor visible in this image from a soccer broadcast. "
"Respond with: brand, confidence (0-1), reasoning."
)
test_cases = ["NIKE", "EMIRATES", "Coca-Cola", "adidas"]
for text in test_cases:
logger.info("--- Testing: '%s' ---", text)
image_b64 = make_brand_image(text)
try:
result = provider.call(image_b64, prompt)
logger.info(" answer: %s", result.answer)
logger.info(" tokens: %d", result.total_tokens)
if text.lower() in result.answer.lower():
logger.info(" PASS — found '%s' in response", text)
else:
logger.warning(" MISS — '%s' not in response (may be correct, check answer)", text)
except Exception as e:
logger.error(" FAIL — %s: %s", type(e).__name__, e)
if hasattr(e, 'response') and e.response is not None:
logger.error(" Response: %s", e.response.text[:500])
logger.info("All provider tests complete.")
if __name__ == "__main__":
main()

tests/detect/manual/test_escalation_e2e.py (new file)

@@ -0,0 +1,230 @@
#!/usr/bin/env python3
"""
Push a full pipeline simulation with escalation events.
Exercises all stages including VLM and cloud escalation, with progressive
stats showing cost accumulating. Tests all panels: pipeline graph, funnel,
timeline, cost stats, brand table, and log.
Usage:
python tests/detect/manual/test_escalation_e2e.py [--job JOB_ID] [--port PORT] [--delay SECS]
Opens: http://mpr.local.ar/detection/?job=<JOB_ID>
"""
import argparse
import json
import logging
import time
from datetime import datetime, timezone
import redis
logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s %(message)s")
logger = logging.getLogger(__name__)
NODES = ["extract_frames", "filter_scenes", "detect_objects", "run_ocr",
"match_brands", "escalate_vlm", "escalate_cloud", "compile_report"]
def ts():
return datetime.now(timezone.utc).isoformat()
def push(r, key, event):
event["ts"] = event.get("ts", ts())
r.rpush(key, json.dumps(event))
return event
def push_graph(r, key, active_node, status, delay):
nodes = []
for n in NODES:
if n == active_node:
nodes.append({"id": n, "status": status})
elif NODES.index(n) < NODES.index(active_node):
nodes.append({"id": n, "status": "done"})
else:
nodes.append({"id": n, "status": "pending"})
push(r, key, {"event": "graph_update", "nodes": nodes})
time.sleep(delay)
def push_stats(r, key, **fields):
base = {
"event": "stats_update",
"frames_extracted": 0, "frames_after_scene_filter": 0,
"regions_detected": 0, "regions_resolved_by_ocr": 0,
"regions_escalated_to_local_vlm": 0, "regions_escalated_to_cloud_llm": 0,
"cloud_llm_calls": 0, "processing_time_seconds": 0, "estimated_cloud_cost_usd": 0,
}
base.update(fields)
push(r, key, base)
def push_detection(r, key, brand, conf, source, timestamp, frame_ref, delay):
push(r, key, {
"event": "detection",
"brand": brand, "confidence": conf, "source": source,
"timestamp": timestamp, "duration": 0.5,
"content_type": "soccer_broadcast", "frame_ref": frame_ref,
})
logger.info(" [%s] %s %.2f t=%.1fs", source, brand, conf, timestamp)
time.sleep(delay * 0.3)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--job", default="escalation-test")
parser.add_argument("--port", type=int, default=6382)
parser.add_argument("--delay", type=float, default=0.5)
args = parser.parse_args()
r = redis.Redis(port=args.port, decode_responses=True)
key = f"detect_events:{args.job}"
r.delete(key)
delay = args.delay
logger.info("Full escalation pipeline simulation → %s", key)
logger.info("Open: http://mpr.local.ar/detection/?job=%s", args.job)
input("\nPress Enter to start...")
# --- Extract frames ---
push_graph(r, key, "extract_frames", "running", delay)
push(r, key, {"event": "log", "level": "INFO", "stage": "FrameExtractor",
"msg": "Extracting frames: match_clip.mp4 (90.0s, 1920x1080, fps=2)"})
time.sleep(delay)
push_stats(r, key, frames_extracted=180, processing_time_seconds=4.5)
push_graph(r, key, "extract_frames", "done", delay)
# --- Scene filter ---
push_graph(r, key, "filter_scenes", "running", delay)
push_stats(r, key, frames_extracted=180, frames_after_scene_filter=52, processing_time_seconds=6.8)
push(r, key, {"event": "log", "level": "INFO", "stage": "SceneFilter",
"msg": "Kept 52 frames (71% reduction)"})
push_graph(r, key, "filter_scenes", "done", delay)
# --- YOLO detect ---
push_graph(r, key, "detect_objects", "running", delay)
push(r, key, {"event": "log", "level": "INFO", "stage": "YOLODetector",
"msg": "Running yolov8n on 52 frames"})
time.sleep(delay)
push_stats(r, key, frames_extracted=180, frames_after_scene_filter=52,
regions_detected=41, processing_time_seconds=14.2)
push_graph(r, key, "detect_objects", "done", delay)
# --- OCR ---
push_graph(r, key, "run_ocr", "running", delay)
push(r, key, {"event": "log", "level": "INFO", "stage": "OCRStage",
"msg": "Running OCR on 41 regions (mode=remote)"})
time.sleep(delay)
push_stats(r, key, frames_extracted=180, frames_after_scene_filter=52,
regions_detected=41, regions_resolved_by_ocr=30, processing_time_seconds=21.5)
push_graph(r, key, "run_ocr", "done", delay)
# --- Brand matching ---
push_graph(r, key, "match_brands", "running", delay)
push(r, key, {"event": "log", "level": "INFO", "stage": "BrandResolver",
"msg": "Matching 30 candidates against 12 brands (fuzzy_threshold=75)"})
time.sleep(delay)
# OCR detections
ocr_brands = [
("Nike", 0.97, 2.0, 4), ("Nike", 0.95, 5.5, 11), ("Emirates", 0.92, 8.0, 16),
("Adidas", 0.89, 12.0, 24), ("Coca-Cola", 0.85, 18.0, 36),
("Nike", 0.94, 22.0, 44), ("Emirates", 0.88, 28.0, 56),
("Adidas", 0.91, 32.0, 64), ("Nike", 0.96, 38.0, 76),
("Emirates", 0.90, 42.0, 84), ("Coca-Cola", 0.87, 48.0, 96),
("Nike", 0.93, 52.0, 104), ("Adidas", 0.90, 58.0, 116),
]
for brand, conf, ts_val, fref in ocr_brands:
push_detection(r, key, brand, conf, "ocr", ts_val, fref, delay)
push(r, key, {"event": "log", "level": "INFO", "stage": "BrandResolver",
"msg": "Exact: 10, Fuzzy: 3, Unresolved: 11 → VLM"})
push_graph(r, key, "match_brands", "done", delay)
# --- VLM escalation ---
push_graph(r, key, "escalate_vlm", "running", delay)
push(r, key, {"event": "log", "level": "INFO", "stage": "VLMLocal",
"msg": "Processing 11 unresolved crops with moondream2"})
time.sleep(delay)
vlm_brands = [
("Mastercard", 0.78, 15.0, 30), ("Santander", 0.74, 25.0, 50),
("Qatar Airways", 0.81, 35.0, 70), ("Heineken", 0.76, 45.0, 90),
("Lay's", 0.72, 55.0, 110),
]
for brand, conf, ts_val, fref in vlm_brands:
push_detection(r, key, brand, conf, "local_vlm", ts_val, fref, delay)
push_stats(r, key, frames_extracted=180, frames_after_scene_filter=52,
regions_detected=41, regions_resolved_by_ocr=30,
regions_escalated_to_local_vlm=11, processing_time_seconds=38.7,
estimated_cloud_cost_usd=0)
push(r, key, {"event": "log", "level": "INFO", "stage": "VLMLocal",
"msg": "VLM resolved 5, unresolved 6 → cloud"})
push_graph(r, key, "escalate_vlm", "done", delay)
# --- Cloud escalation ---
push_graph(r, key, "escalate_cloud", "running", delay)
push(r, key, {"event": "log", "level": "INFO", "stage": "CloudLLM",
"msg": "Escalating 6 crops to groq (llama-3.2-90b-vision)"})
time.sleep(delay)
cloud_brands = [
("Pepsi", 0.68, 10.0, 20),
("Gazprom", 0.65, 40.0, 80),
]
for brand, conf, ts_val, fref in cloud_brands:
push_detection(r, key, brand, conf, "cloud_llm", ts_val, fref, delay)
push_stats(r, key, frames_extracted=180, frames_after_scene_filter=52,
regions_detected=41, regions_resolved_by_ocr=30,
regions_escalated_to_local_vlm=11, regions_escalated_to_cloud_llm=6,
cloud_llm_calls=6, processing_time_seconds=45.2,
estimated_cloud_cost_usd=0.0) # groq free tier
push(r, key, {"event": "log", "level": "WARNING", "stage": "CloudLLM",
"msg": "4 crops unresolved after cloud — likely not brands"})
push(r, key, {"event": "log", "level": "INFO", "stage": "CloudLLM",
"msg": "Cloud resolved 2/6 — cost $0.0000 (groq free tier)"})
push_graph(r, key, "escalate_cloud", "done", delay)
# --- Compile report ---
push_graph(r, key, "compile_report", "running", delay)
total_brands = len(set(b[0] for b in ocr_brands + vlm_brands + cloud_brands))
total_dets = len(ocr_brands) + len(vlm_brands) + len(cloud_brands)
push(r, key, {"event": "log", "level": "INFO", "stage": "Aggregator",
"msg": f"Report: {total_brands} brands, {total_dets} detections (merged from {total_dets} raw)"})
push(r, key, {"event": "job_complete", "job_id": args.job, "report": {
"video_source": "match_clip.mp4",
"content_type": "soccer_broadcast",
"duration_seconds": 90.0,
"brands": {
"Nike": {"total_appearances": 5, "avg_confidence": 0.95},
"Emirates": {"total_appearances": 3, "avg_confidence": 0.90},
"Adidas": {"total_appearances": 3, "avg_confidence": 0.90},
"Coca-Cola": {"total_appearances": 2, "avg_confidence": 0.86},
"Mastercard": {"total_appearances": 1, "avg_confidence": 0.78},
"Santander": {"total_appearances": 1, "avg_confidence": 0.74},
"Qatar Airways": {"total_appearances": 1, "avg_confidence": 0.81},
"Heineken": {"total_appearances": 1, "avg_confidence": 0.76},
"Lay's": {"total_appearances": 1, "avg_confidence": 0.72},
"Pepsi": {"total_appearances": 1, "avg_confidence": 0.68},
"Gazprom": {"total_appearances": 1, "avg_confidence": 0.65},
},
}})
push_graph(r, key, "compile_report", "done", delay)
logger.info("Done. %d brands, %d detections across ocr/vlm/cloud.", total_brands, total_dets)
logger.info("Check: pipeline graph (all green), timeline (3 source colors),")
logger.info(" cost panel (escalation ratio), brand table (source column).")
if __name__ == "__main__":
main()

tests/detect/manual/test_vlm_e2e.py (new file)

@@ -0,0 +1,100 @@
#!/usr/bin/env python3
"""
Test local VLM (moondream2) via the inference server.
Creates test images with brand text/logos, sends them to the /vlm endpoint,
verifies moondream2 can identify the brand.
Usage:
python tests/detect/manual/test_vlm_e2e.py [--url URL]
Requires: inference server running with moondream2 loaded (gpu/server.py)
"""
import argparse
import base64
import io
import logging
import sys
import numpy as np
import requests
from PIL import Image, ImageDraw, ImageFont
logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s %(message)s")
logger = logging.getLogger(__name__)
def make_brand_image(text: str, width: int = 300, height: int = 100) -> np.ndarray:
img = Image.new("RGB", (width, height), "white")
draw = ImageDraw.Draw(img)
try:
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 42)
except (OSError, IOError):
font = ImageFont.load_default()
draw.text((10, 20), text, fill="black", font=font)
return np.array(img)
def image_to_b64(image: np.ndarray) -> str:
img = Image.fromarray(image)
buf = io.BytesIO()
img.save(buf, "JPEG")
return base64.b64encode(buf.getvalue()).decode()
def test_health(url: str):
logger.info("--- Health check ---")
resp = requests.get(f"{url}/health")
resp.raise_for_status()
data = resp.json()
logger.info("Status: %s, device: %s, models: %s", data["status"], data["device"], data.get("loaded_models", []))
def test_vlm(url: str, text: str, prompt: str):
logger.info("--- VLM: image='%s' ---", text)
image = make_brand_image(text)
b64 = image_to_b64(image)
resp = requests.post(f"{url}/vlm", json={"image": b64, "prompt": prompt})
resp.raise_for_status()
data = resp.json()
logger.info(" brand: %s", data["brand"])
logger.info(" confidence: %.2f", data["confidence"])
logger.info(" reasoning: %s", data["reasoning"])
if text.lower() in data["brand"].lower():
logger.info(" PASS — matched")
else:
logger.warning(" MISS — expected '%s' in response", text)
return data
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--url", default="http://mcrndeb:8000")
args = parser.parse_args()
url = args.url.rstrip("/")
logger.info("Inference server: %s", url)
input("\nPress Enter to start...")
test_health(url)
prompt = (
"Identify the brand or sponsor visible in this image from a soccer broadcast. "
"Respond with: brand, confidence (0-1), reasoning."
)
test_vlm(url, "NIKE", prompt)
test_vlm(url, "EMIRATES", prompt)
test_vlm(url, "Coca-Cola", prompt)
test_vlm(url, "adidas", prompt)
logger.info("All VLM tests complete.")
if __name__ == "__main__":
main()


@@ -0,0 +1,79 @@
"""Tests for the report aggregator stage."""
import pytest
from detect.models import BoundingBox, BrandDetection, PipelineStats
from detect.stages.aggregator import compile_report, _merge_contiguous
def _make_detection(brand: str, timestamp: float, duration: float = 0.5,
source: str = "ocr", confidence: float = 0.9) -> BrandDetection:
return BrandDetection(
brand=brand, timestamp=timestamp, duration=duration,
confidence=confidence, source=source, content_type="soccer_broadcast",
)
def test_merge_contiguous_same_brand():
dets = [
_make_detection("Nike", 1.0, 0.5),
_make_detection("Nike", 1.3, 0.5), # within gap
_make_detection("Nike", 5.0, 0.5), # separate
]
merged = _merge_contiguous(dets, gap_threshold=2.0)
assert len(merged) == 2
assert merged[0].brand == "Nike"
assert merged[0].timestamp == 1.0
assert merged[0].duration == pytest.approx(0.8) # 1.0 to 1.8
assert merged[1].timestamp == 5.0
def test_merge_different_brands():
dets = [
_make_detection("Nike", 1.0),
_make_detection("Adidas", 1.5),
]
merged = _merge_contiguous(dets)
assert len(merged) == 2
def test_merge_empty():
assert _merge_contiguous([]) == []
def test_compile_report(monkeypatch):
events = []
monkeypatch.setattr("detect.emit.push_detect_event",
lambda job_id, etype, data: events.append((etype, data)))
dets = [
_make_detection("Nike", 1.0, 0.5, confidence=0.95),
_make_detection("Nike", 5.0, 1.0, confidence=0.90),
_make_detection("Adidas", 3.0, 0.5, confidence=0.85),
_make_detection("Heineken", 10.0, 0.5, source="cloud_llm", confidence=0.70),
]
stats = PipelineStats(
frames_extracted=120,
regions_detected=32,
cloud_llm_calls=1,
estimated_cloud_cost_usd=0.003,
)
report = compile_report(
detections=dets,
stats=stats,
video_source="test.mp4",
content_type="soccer_broadcast",
job_id="test-report",
)
assert len(report.brands) == 3
assert report.brands["Nike"].total_appearances == 2
assert report.brands["Adidas"].total_appearances == 1
assert report.brands["Heineken"].total_appearances == 1
assert report.pipeline_stats.cloud_llm_calls == 1
assert report.video_source == "test.mp4"
# job_complete event should have been emitted
complete = [e for e in events if e[0] == "job_complete"]
assert len(complete) == 1


@@ -0,0 +1,92 @@
"""Tests for cloud LLM escalation stage."""
import numpy as np
import pytest
from detect.models import BoundingBox, Frame, PipelineStats, TextCandidate
from detect.stages.vlm_cloud import escalate_cloud, _parse_response
def _make_candidate(text: str = "unknown", confidence: float = 0.4) -> TextCandidate:
frame = Frame(sequence=0, chunk_id=0, timestamp=1.0,
image=np.zeros((50, 100, 3), dtype=np.uint8))
box = BoundingBox(x=0, y=0, w=100, h=50, confidence=0.5, label="text")
return TextCandidate(frame=frame, bbox=box, text=text, ocr_confidence=confidence)
def test_parse_response_clean():
result = _parse_response("Nike, 0.92, swoosh logo visible", 200)
assert result["brand"] == "Nike"
assert result["confidence"] == 0.92
assert "swoosh" in result["reasoning"]
assert result["tokens"] == 200
def test_parse_response_no_confidence():
result = _parse_response("Adidas", 0)
assert result["brand"] == "Adidas"
assert result["confidence"] == 0.5 # default
def test_escalate_skips_without_api_key(monkeypatch):
events = []
monkeypatch.setattr("detect.emit.push_detect_event",
lambda job_id, etype, data: events.append((etype, data)))
monkeypatch.delenv("GROQ_API_KEY", raising=False)
monkeypatch.delenv("GEMINI_API_KEY", raising=False)
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
monkeypatch.setenv("CLOUD_LLM_PROVIDER", "groq")
# Reset cached provider
import detect.providers as prov
monkeypatch.setattr(prov, "_cached", None)
candidates = [_make_candidate()]
stats = PipelineStats()
prompt_fn = lambda ctx: "what brand?"
matched = escalate_cloud(candidates, prompt_fn, stats, job_id="test")
assert len(matched) == 0
assert stats.cloud_llm_calls == 0
log_events = [e for e in events if e[0] == "log"]
assert any("No API key" in e[1].get("msg", "") for e in log_events)
def test_escalate_empty_candidates(monkeypatch):
events = []
monkeypatch.setattr("detect.emit.push_detect_event",
lambda job_id, etype, data: events.append((etype, data)))
stats = PipelineStats()
matched = escalate_cloud([], lambda ctx: "", stats, job_id="test")
assert len(matched) == 0
assert stats.cloud_llm_calls == 0
def test_escalate_with_mock_api(monkeypatch):
events = []
monkeypatch.setattr("detect.emit.push_detect_event",
lambda job_id, etype, data: events.append((etype, data)))
monkeypatch.setenv("GROQ_API_KEY", "test-key")
monkeypatch.setenv("CLOUD_LLM_PROVIDER", "groq")
# Reset cached provider
import detect.providers as prov
monkeypatch.setattr(prov, "_cached", None)
def mock_call(image_b64, prompt):
return {"brand": "Heineken", "confidence": 0.75, "reasoning": "green logo", "tokens": 300}
monkeypatch.setattr("detect.stages.vlm_cloud._call_cloud_api", mock_call)
candidates = [_make_candidate("unknown logo")]
stats = PipelineStats()
prompt_fn = lambda ctx: "what brand?"
matched = escalate_cloud(candidates, prompt_fn, stats, job_id="test")
assert len(matched) == 1
assert matched[0].brand == "Heineken"
assert matched[0].source == "cloud_llm"
assert stats.cloud_llm_calls == 1
assert stats.estimated_cloud_cost_usd >= 0 # exact cost depends on provider model index