From 08b67f2bb7cc9b85946638d1ed5d214f078b1a70 Mon Sep 17 00:00:00 2001 From: buenosairesam Date: Thu, 26 Mar 2026 02:54:56 -0300 Subject: [PATCH] phase 9 --- .../memory/project_agent_sdk.md | 11 + ctrl/.env.template | 24 ++ detect/graph.py | 69 ++++-- detect/providers/__init__.py | 58 +++++ detect/providers/base.py | 36 +++ detect/providers/claude.py | 73 ++++++ detect/providers/gemini.py | 75 ++++++ detect/providers/groq.py | 66 +++++ detect/providers/openai_compat.py | 73 ++++++ detect/stages/aggregator.py | 116 +++++++++ detect/stages/vlm_cloud.py | 168 +++++++++++++ detect/stages/vlm_local.py | 124 ++++++++++ gpu/models/vlm.py | 100 ++++++++ gpu/requirements.txt | 6 + gpu/server.py | 28 +++ requirements.txt | 3 + tests/detect/manual/test_cloud_provider.py | 107 ++++++++ tests/detect/manual/test_escalation_e2e.py | 230 ++++++++++++++++++ tests/detect/manual/test_vlm_e2e.py | 100 ++++++++ tests/detect/test_aggregator.py | 79 ++++++ tests/detect/test_vlm_cloud.py | 92 +++++++ 21 files changed, 1622 insertions(+), 16 deletions(-) create mode 100644 .claude/projects/-home-mariano-wdir-mpr/memory/project_agent_sdk.md create mode 100644 detect/providers/__init__.py create mode 100644 detect/providers/base.py create mode 100644 detect/providers/claude.py create mode 100644 detect/providers/gemini.py create mode 100644 detect/providers/groq.py create mode 100644 detect/providers/openai_compat.py create mode 100644 detect/stages/aggregator.py create mode 100644 detect/stages/vlm_cloud.py create mode 100644 detect/stages/vlm_local.py create mode 100644 gpu/models/vlm.py create mode 100644 tests/detect/manual/test_cloud_provider.py create mode 100644 tests/detect/manual/test_escalation_e2e.py create mode 100644 tests/detect/manual/test_vlm_e2e.py create mode 100644 tests/detect/test_aggregator.py create mode 100644 tests/detect/test_vlm_cloud.py diff --git a/.claude/projects/-home-mariano-wdir-mpr/memory/project_agent_sdk.md b/.claude/projects/-home-mariano-wdir-mpr/memory/project_agent_sdk.md new file mode 100644 index 0000000..5218be0 --- /dev/null +++ b/.claude/projects/-home-mariano-wdir-mpr/memory/project_agent_sdk.md @@ -0,0 +1,11 @@ +--- +name: agent_sdk_future +description: Claude Agent SDK for general mpr tasks (not vision provider), uses OAuth not API keys +type: project +--- + +Claude Agent SDK (`claude-agent-sdk`) is for future general-purpose tasks in mpr, NOT for the cloud vision provider. + +**Why:** The Agent SDK uses Claude Code CLI's OAuth (browser login, no API keys) and is designed for agentic tasks (file read/edit, bash, web search). The vision provider needs raw API calls with image payloads — use the `anthropic` SDK with `ANTHROPIC_API_KEY` for that. + +**How to apply:** When adding Claude-powered automation to mpr (e.g., log analysis, config suggestions, code review on pipeline changes), use the Agent SDK. For the cloud LLM escalation stage (image crops → brand ID), keep using the `anthropic` SDK with API key auth. 
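+
+A minimal sketch of the split, for future reference. The Agent SDK call below is an assumption based on `claude-agent-sdk`'s documented `query` entry point, so verify it before wiring anything in; the vision path is exactly `detect/providers/claude.py` in this patch.
+
+```python
+# Hypothetical future use: Claude-powered log analysis via the Agent SDK.
+# Auth comes from the Claude Code CLI's OAuth session, not an API key.
+import asyncio
+from claude_agent_sdk import query  # assumption: not yet a project dependency
+
+async def summarize_pipeline_logs(log_text: str) -> None:
+    async for message in query(prompt=f"Summarize these detection pipeline logs:\n{log_text}"):
+        print(message)  # stream of agent messages; the final one carries the result
+
+asyncio.run(summarize_pipeline_logs("escalate_cloud: 6 crops sent, 2 resolved"))
+```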
diff --git a/ctrl/.env.template b/ctrl/.env.template
index e7f860d..9066bef 100644
--- a/ctrl/.env.template
+++ b/ctrl/.env.template
@@ -35,5 +35,29 @@ AWS_REGION=us-east-1
 AWS_ACCESS_KEY_ID=minioadmin
 AWS_SECRET_ACCESS_KEY=minioadmin
+# Inference
+INFERENCE_URL=http://mcrndeb:8000
+
+# Cloud LLM (detection pipeline escalation)
+# Set CLOUD_LLM_PROVIDER to: groq, gemini, claude, openai
+CLOUD_LLM_PROVIDER=groq
+
+# Groq (default, free tier)
+GROQ_API_KEY=
+# Must match a model listed in detect/providers/groq.py
+GROQ_MODEL=meta-llama/llama-4-scout-17b-16e-instruct
+
+# Gemini
+#GEMINI_API_KEY=
+#GEMINI_MODEL=gemini-2.0-flash
+
+# Claude (uses anthropic SDK)
+#ANTHROPIC_API_KEY=
+#CLAUDE_MODEL=claude-sonnet-4-20250514
+
+# OpenAI-compatible
+#OPENAI_API_KEY=
+#OPENAI_MODEL=gpt-4o-mini
+#OPENAI_BASE_URL=https://api.openai.com/v1
+
 # Vite
 VITE_ALLOWED_HOSTS=your-domain.local
diff --git a/detect/graph.py b/detect/graph.py
index a60a91d..62c5647 100644
--- a/detect/graph.py
+++ b/detect/graph.py
@@ -20,6 +20,9 @@ from detect.stages.scene_filter import scene_filter
 from detect.stages.yolo_detector import detect_objects
 from detect.stages.ocr_stage import run_ocr
 from detect.stages.brand_resolver import resolve_brands
+from detect.stages.vlm_local import escalate_vlm
+from detect.stages.vlm_cloud import escalate_cloud
+from detect.stages.aggregator import compile_report
 from detect.tracing import trace_node, flush as flush_traces
 
 INFERENCE_URL = os.environ.get("INFERENCE_URL")  # None = local mode
@@ -158,43 +161,77 @@ def node_escalate_vlm(state: DetectState) -> dict:
     _emit_transition(state, "escalate_vlm", "running")
 
     with trace_node(state, "escalate_vlm") as span:
+        profile = _get_profile(state)
+        candidates = state.get("unresolved_candidates", [])
         job_id = state.get("job_id")
-        emit.log(job_id, "VLMLocal", "INFO", "Stub: VLM escalation not yet implemented")
-        span.set_output({"stub": True})
+
+        vlm_matched, still_unresolved = escalate_vlm(
+            candidates,
+            vlm_prompt_fn=profile.vlm_prompt,
+            inference_url=INFERENCE_URL,
+            content_type=profile.name,
+            job_id=job_id,
+        )
+
+        stats = state.get("stats", PipelineStats())
+        stats.regions_escalated_to_local_vlm = len(candidates)
+        span.set_output({"candidates": len(candidates), "matched": len(vlm_matched),
+                         "still_unresolved": len(still_unresolved)})
+
+        existing = state.get("detections", [])
 
     _emit_transition(state, "escalate_vlm", "done")
-    return {}
+    return {
+        "detections": existing + vlm_matched,
+        "unresolved_candidates": still_unresolved,
+        "stats": stats,
+    }
 
 
 def node_escalate_cloud(state: DetectState) -> dict:
     _emit_transition(state, "escalate_cloud", "running")
 
     with trace_node(state, "escalate_cloud") as span:
+        profile = _get_profile(state)
+        candidates = state.get("unresolved_candidates", [])
         job_id = state.get("job_id")
-        emit.log(job_id, "CloudLLM", "INFO", "Stub: cloud LLM escalation not yet implemented")
-        span.set_output({"stub": True})
+        stats = state.get("stats", PipelineStats())
+
+        cloud_matched = escalate_cloud(
+            candidates,
+            vlm_prompt_fn=profile.vlm_prompt,
+            stats=stats,
+            content_type=profile.name,
+            job_id=job_id,
+        )
+
+        span.set_output({"candidates": len(candidates), "matched": len(cloud_matched),
+                         "cloud_calls": stats.cloud_llm_calls,
+                         "cost_usd": stats.estimated_cloud_cost_usd})
+
+        existing = state.get("detections", [])
 
     _emit_transition(state, "escalate_cloud", "done")
-    return {}
+    return {"detections": existing + cloud_matched, "stats": stats}
 
 
 def node_compile_report(state: DetectState) -> dict:
     _emit_transition(state, "compile_report", "running")
 
     with trace_node(state, 
"compile_report") as span: - job_id = state.get("job_id") profile = _get_profile(state) detections = state.get("detections", []) - report = profile.aggregate(detections) - report.video_source = state.get("video_path", "") + stats = state.get("stats", PipelineStats()) + job_id = state.get("job_id") + + report = compile_report( + detections=detections, + stats=stats, + video_source=state.get("video_path", ""), + content_type=profile.name, + job_id=job_id, + ) - emit.log(job_id, "Aggregator", "INFO", - f"Report: {len(report.brands)} brands, {len(report.timeline)} detections") - emit.job_complete(job_id, { - "video_source": report.video_source, - "content_type": report.content_type, - "brands": {k: {"total_appearances": v.total_appearances} for k, v in report.brands.items()}, - }) span.set_output({"brands": len(report.brands), "detections": len(report.timeline)}) flush_traces() diff --git a/detect/providers/__init__.py b/detect/providers/__init__.py new file mode 100644 index 0000000..0380833 --- /dev/null +++ b/detect/providers/__init__.py @@ -0,0 +1,58 @@ +""" +Cloud LLM provider registry. + +Select provider via CLOUD_LLM_PROVIDER env var. +Each provider reads its own env vars for auth/config. + + CLOUD_LLM_PROVIDER=groq → GROQ_API_KEY, GROQ_MODEL, GROQ_BASE_URL + CLOUD_LLM_PROVIDER=gemini → GEMINI_API_KEY, GEMINI_MODEL + CLOUD_LLM_PROVIDER=openai → OPENAI_API_KEY, OPENAI_MODEL, OPENAI_BASE_URL + CLOUD_LLM_PROVIDER=claude → ANTHROPIC_API_KEY, CLAUDE_MODEL +""" + +from __future__ import annotations + +import os + +from .base import CloudProvider, ProviderResponse +from .groq import GroqProvider +from .gemini import GeminiProvider +from .openai_compat import OpenAICompatProvider +from .claude import ClaudeProvider + +PROVIDERS: dict[str, type] = { + "groq": GroqProvider, + "gemini": GeminiProvider, + "openai": OpenAICompatProvider, + "claude": ClaudeProvider, +} + +_cached: CloudProvider | None = None + + +def get_provider() -> CloudProvider: + """Get the configured cloud provider (cached after first call).""" + global _cached + if _cached is not None: + return _cached + + name = os.environ.get("CLOUD_LLM_PROVIDER", "groq") + cls = PROVIDERS.get(name) + if cls is None: + raise ValueError(f"Unknown provider: {name!r}. Options: {list(PROVIDERS)}") + + _cached = cls() + return _cached + + +def has_api_key() -> bool: + """Check if the configured provider has an API key set.""" + name = os.environ.get("CLOUD_LLM_PROVIDER", "groq") + key_map = { + "groq": "GROQ_API_KEY", + "gemini": "GEMINI_API_KEY", + "openai": "OPENAI_API_KEY", + "claude": "ANTHROPIC_API_KEY", + } + env_var = key_map.get(name, "") + return bool(os.environ.get(env_var, "")) diff --git a/detect/providers/base.py b/detect/providers/base.py new file mode 100644 index 0000000..c6e809d --- /dev/null +++ b/detect/providers/base.py @@ -0,0 +1,36 @@ +"""Cloud LLM provider protocol and model metadata.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Protocol + + +@dataclass +class ModelInfo: + """Metadata for a cloud LLM model.""" + id: str + vision: bool = True + cost_per_input_token: float = 0.0 + cost_per_output_token: float = 0.0 + max_output_tokens: int = 4096 + notes: str = "" + + +@dataclass +class ProviderResponse: + answer: str + total_tokens: int = 0 + + +class CloudProvider(Protocol): + """ + Interface for cloud LLM providers. + + Each provider handles its own auth, payload format, and response parsing. + The pipeline only calls call() and reads the response. 
+ """ + name: str + models: dict[str, ModelInfo] + + def call(self, image_b64: str, prompt: str) -> ProviderResponse: ... diff --git a/detect/providers/claude.py b/detect/providers/claude.py new file mode 100644 index 0000000..c88ee60 --- /dev/null +++ b/detect/providers/claude.py @@ -0,0 +1,73 @@ +"""Anthropic Claude provider — uses the official SDK.""" + +from __future__ import annotations + +import logging +import os + +from .base import ModelInfo, ProviderResponse + +logger = logging.getLogger(__name__) + +# Claude-specific env vars +# ANTHROPIC_API_KEY is read by the SDK automatically +CLAUDE_MODEL = os.environ.get("CLAUDE_MODEL", "claude-sonnet-4-20250514") + +MODELS = { + "claude-sonnet-4-20250514": ModelInfo( + id="claude-sonnet-4-20250514", + vision=True, + cost_per_input_token=0.000003, + cost_per_output_token=0.000015, + notes="Best balance of quality/cost with vision", + ), + "claude-haiku-4-5-20251001": ModelInfo( + id="claude-haiku-4-5-20251001", + vision=True, + cost_per_input_token=0.0000008, + cost_per_output_token=0.000004, + notes="Fastest, cheapest, good for simple brand ID", + ), + "claude-opus-4-6": ModelInfo( + id="claude-opus-4-6", + vision=True, + cost_per_input_token=0.000015, + cost_per_output_token=0.000075, + notes="Highest quality, use for ambiguous cases", + ), +} + + +class ClaudeProvider: + name = "claude" + models = MODELS + + def __init__(self): + from anthropic import Anthropic + self.client = Anthropic() + self.model = CLAUDE_MODEL + + def call(self, image_b64: str, prompt: str) -> ProviderResponse: + message = self.client.messages.create( + model=self.model, + max_tokens=150, + messages=[{ + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": image_b64, + }, + }, + {"type": "text", "text": prompt}, + ], + }], + ) + + answer = message.content[0].text.strip() + total_tokens = message.usage.input_tokens + message.usage.output_tokens + + return ProviderResponse(answer=answer, total_tokens=total_tokens) diff --git a/detect/providers/gemini.py b/detect/providers/gemini.py new file mode 100644 index 0000000..fc3e49b --- /dev/null +++ b/detect/providers/gemini.py @@ -0,0 +1,75 @@ +"""Google Gemini provider — native REST API, not OpenAI-compatible.""" + +from __future__ import annotations + +import logging +import os + +import requests + +from .base import ModelInfo, ProviderResponse + +logger = logging.getLogger(__name__) + +# Gemini-specific env vars +GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "") +GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-2.0-flash") + +MODELS = { + "gemini-2.0-flash": ModelInfo( + id="gemini-2.0-flash", + vision=True, + cost_per_input_token=0.0000001, + cost_per_output_token=0.0000004, + notes="Fast, cheap, good vision", + ), + "gemini-2.0-pro": ModelInfo( + id="gemini-2.0-pro", + vision=True, + cost_per_input_token=0.00000125, + cost_per_output_token=0.000005, + notes="Higher quality, slower", + ), + "gemini-1.5-flash": ModelInfo( + id="gemini-1.5-flash", + vision=True, + cost_per_input_token=0.000000075, + cost_per_output_token=0.0000003, + notes="Cheapest option", + ), +} + + +class GeminiProvider: + name = "gemini" + models = MODELS + + def __init__(self): + self.api_key = GEMINI_API_KEY + self.model = GEMINI_MODEL + self.endpoint = ( + f"https://generativelanguage.googleapis.com/v1beta/models/" + f"{self.model}:generateContent" + ) + + def call(self, image_b64: str, prompt: str) -> ProviderResponse: + payload = { + "contents": [{ + "parts": [ 
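+            # Gemini's native format: prompt text and base64 image travel as
+            # sibling "parts" entries (inline_data), unlike the OpenAI-style
+            # image_url payloads in groq.py and openai_compat.py.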
+ {"text": prompt}, + {"inline_data": {"mime_type": "image/jpeg", "data": image_b64}}, + ], + }], + "generationConfig": {"maxOutputTokens": 150}, + } + + url = f"{self.endpoint}?key={self.api_key}" + resp = requests.post(url, json=payload, timeout=30) + resp.raise_for_status() + data = resp.json() + + answer = data["candidates"][0]["content"]["parts"][0]["text"].strip() + usage = data.get("usageMetadata", {}) + total_tokens = usage.get("totalTokenCount", 0) + + return ProviderResponse(answer=answer, total_tokens=total_tokens) diff --git a/detect/providers/groq.py b/detect/providers/groq.py new file mode 100644 index 0000000..9223255 --- /dev/null +++ b/detect/providers/groq.py @@ -0,0 +1,66 @@ +"""Groq cloud provider — OpenAI-compatible API with vision.""" + +from __future__ import annotations + +import logging +import os + +import requests + +from .base import ModelInfo, ProviderResponse + +logger = logging.getLogger(__name__) + +# Groq-specific env vars +GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "") +GROQ_BASE_URL = os.environ.get("GROQ_BASE_URL", "https://api.groq.com/openai/v1") +GROQ_MODEL = os.environ.get("GROQ_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct") + +MODELS = { + "meta-llama/llama-4-scout-17b-16e-instruct": ModelInfo( + id="meta-llama/llama-4-scout-17b-16e-instruct", + vision=True, + cost_per_input_token=0.0, + cost_per_output_token=0.0, + notes="Llama 4 Scout, only vision model on Groq free tier", + ), +} + + +class GroqProvider: + name = "groq" + models = MODELS + + def __init__(self): + self.api_key = GROQ_API_KEY + self.base_url = GROQ_BASE_URL + self.model = GROQ_MODEL + self.endpoint = f"{self.base_url.rstrip('/')}/chat/completions" + self.headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + def call(self, image_b64: str, prompt: str) -> ProviderResponse: + payload = { + "model": self.model, + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": { + "url": f"data:image/jpeg;base64,{image_b64}", + }}, + ], + }], + "max_tokens": 150, + } + + resp = requests.post(self.endpoint, headers=self.headers, json=payload, timeout=30) + resp.raise_for_status() + data = resp.json() + + answer = data["choices"][0]["message"]["content"].strip() + total_tokens = data.get("usage", {}).get("total_tokens", 0) + + return ProviderResponse(answer=answer, total_tokens=total_tokens) diff --git a/detect/providers/openai_compat.py b/detect/providers/openai_compat.py new file mode 100644 index 0000000..4d11fa3 --- /dev/null +++ b/detect/providers/openai_compat.py @@ -0,0 +1,73 @@ +"""Generic OpenAI-compatible provider (OpenAI, Together, etc.).""" + +from __future__ import annotations + +import logging +import os + +import requests + +from .base import ModelInfo, ProviderResponse + +logger = logging.getLogger(__name__) + +# OpenAI-compat specific env vars +OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") +OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1") +OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4o-mini") + +MODELS = { + "gpt-4o-mini": ModelInfo( + id="gpt-4o-mini", + vision=True, + cost_per_input_token=0.00000015, + cost_per_output_token=0.0000006, + notes="Cheap, fast, decent vision", + ), + "gpt-4o": ModelInfo( + id="gpt-4o", + vision=True, + cost_per_input_token=0.0000025, + cost_per_output_token=0.00001, + notes="Best OpenAI vision model", + ), +} + + +class OpenAICompatProvider: + name = "openai" + models = MODELS + + 
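+    # Any endpoint that speaks OpenAI's /chat/completions vision format works
+    # here (OpenAI, Together, etc.): point OPENAI_BASE_URL at it and set
+    # OPENAI_MODEL; the request/response shape below stays the same.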
def __init__(self): + self.api_key = OPENAI_API_KEY + self.base_url = OPENAI_BASE_URL + self.model = OPENAI_MODEL + self.endpoint = f"{self.base_url.rstrip('/')}/chat/completions" + self.headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + def call(self, image_b64: str, prompt: str) -> ProviderResponse: + payload = { + "model": self.model, + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": { + "url": f"data:image/jpeg;base64,{image_b64}", + }}, + ], + }], + "max_tokens": 150, + } + + resp = requests.post(self.endpoint, headers=self.headers, json=payload, timeout=30) + resp.raise_for_status() + data = resp.json() + + answer = data["choices"][0]["message"]["content"].strip() + total_tokens = data.get("usage", {}).get("total_tokens", 0) + + return ProviderResponse(answer=answer, total_tokens=total_tokens) diff --git a/detect/stages/aggregator.py b/detect/stages/aggregator.py new file mode 100644 index 0000000..433f7f5 --- /dev/null +++ b/detect/stages/aggregator.py @@ -0,0 +1,116 @@ +""" +Stage 8 — Report compilation + +Groups all detections by brand, merges contiguous appearances, +and builds the final DetectionReport. +""" + +from __future__ import annotations + +import logging + +from detect import emit +from detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats + +logger = logging.getLogger(__name__) + + +def _merge_contiguous(detections: list[BrandDetection], gap_threshold: float = 2.0) -> list[BrandDetection]: + """ + Merge detections of the same brand that are close in time. + + If two detections of the same brand are within gap_threshold seconds, + they're merged into one detection spanning the full range. + """ + if not detections: + return [] + + sorted_dets = sorted(detections, key=lambda d: (d.brand, d.timestamp)) + merged: list[BrandDetection] = [] + current = sorted_dets[0] + + for det in sorted_dets[1:]: + if (det.brand == current.brand + and det.timestamp <= current.timestamp + current.duration + gap_threshold): + end = max(current.timestamp + current.duration, + det.timestamp + det.duration) + current = BrandDetection( + brand=current.brand, + timestamp=current.timestamp, + duration=end - current.timestamp, + confidence=max(current.confidence, det.confidence), + source=current.source, + bbox=current.bbox, + frame_ref=current.frame_ref, + content_type=current.content_type, + ) + else: + merged.append(current) + current = det + + merged.append(current) + return merged + + +def compile_report( + detections: list[BrandDetection], + stats: PipelineStats, + video_source: str = "", + content_type: str = "", + duration_seconds: float = 0.0, + job_id: str | None = None, +) -> DetectionReport: + """ + Build the final detection report from all accumulated detections. + + Merges contiguous detections, computes per-brand stats, + and emits the job_complete event. 
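+
+    Worked example of the merge (gap_threshold=2.0): Nike at t=1.0s for 0.5s
+    and Nike again at t=1.3s for 0.5s fall within the gap, so they collapse
+    into a single detection at t=1.0s with duration 0.8s.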
+ """ + merged = _merge_contiguous(detections) + + brands: dict[str, BrandStats] = {} + for d in merged: + if d.brand not in brands: + brands[d.brand] = BrandStats() + s = brands[d.brand] + s.total_appearances += 1 + s.total_screen_time += d.duration + s.avg_confidence = ( + (s.avg_confidence * (s.total_appearances - 1) + d.confidence) + / s.total_appearances + ) + if s.first_seen == 0.0 or d.timestamp < s.first_seen: + s.first_seen = d.timestamp + if d.timestamp > s.last_seen: + s.last_seen = d.timestamp + + report = DetectionReport( + video_source=video_source, + content_type=content_type, + duration_seconds=duration_seconds, + brands=brands, + timeline=sorted(merged, key=lambda d: d.timestamp), + pipeline_stats=stats, + ) + + emit.log(job_id, "Aggregator", "INFO", + f"Report: {len(brands)} brands, {len(merged)} detections " + f"(merged from {len(detections)} raw)") + + emit.job_complete(job_id, { + "video_source": report.video_source, + "content_type": report.content_type, + "duration_seconds": report.duration_seconds, + "brands": { + k: { + "total_appearances": v.total_appearances, + "total_screen_time": v.total_screen_time, + "avg_confidence": round(v.avg_confidence, 3), + "first_seen": v.first_seen, + "last_seen": v.last_seen, + } + for k, v in brands.items() + }, + }) + + return report diff --git a/detect/stages/vlm_cloud.py b/detect/stages/vlm_cloud.py new file mode 100644 index 0000000..c0e2ab6 --- /dev/null +++ b/detect/stages/vlm_cloud.py @@ -0,0 +1,168 @@ +""" +Stage 7 — Cloud LLM escalation + +Last resort for crops the local VLM couldn't resolve. +Provider-agnostic — switch via CLOUD_LLM_PROVIDER env var. +Each provider has its own file under detect/providers/. + +Tracks token usage and cost. +""" + +from __future__ import annotations + +import base64 +import io +import logging + +import numpy as np +from PIL import Image + +from detect import emit +from detect.models import BrandDetection, PipelineStats, TextCandidate +from detect.profiles.base import CropContext +from detect.providers import get_provider, has_api_key + +logger = logging.getLogger(__name__) + +ESTIMATED_TOKENS_PER_CROP = 500 + + +def _encode_crop(crop: np.ndarray) -> str: + img = Image.fromarray(crop) + buf = io.BytesIO() + img.save(buf, format="JPEG", quality=85) + return base64.b64encode(buf.getvalue()).decode() + + +def _crop_image(candidate: TextCandidate) -> np.ndarray: + frame = candidate.frame + box = candidate.bbox + h, w = frame.image.shape[:2] + x1 = max(0, box.x) + y1 = max(0, box.y) + x2 = min(w, box.x + box.w) + y2 = min(h, box.y + box.h) + return frame.image[y1:y2, x1:x2] + + +def _parse_response(answer: str, total_tokens: int) -> dict: + """Parse LLM free-text response into structured output.""" + parts = [p.strip() for p in answer.split(",", 2)] + + brand = parts[0] if parts else "" + confidence = 0.5 + reasoning = answer + + if len(parts) >= 2: + try: + confidence = float(parts[1]) + confidence = max(0.0, min(1.0, confidence)) + except ValueError: + pass + + if len(parts) >= 3: + reasoning = parts[2] + + return { + "brand": brand, + "confidence": confidence, + "reasoning": reasoning, + "tokens": total_tokens or ESTIMATED_TOKENS_PER_CROP, + } + + +def _call_cloud_api(image_b64: str, prompt: str) -> dict: + """Route to the configured provider and parse the response.""" + provider = get_provider() + result = provider.call(image_b64, prompt) + return _parse_response(result.answer, result.total_tokens) + + +def escalate_cloud( + candidates: list[TextCandidate], + vlm_prompt_fn, + stats: 
PipelineStats,
+    min_confidence: float = 0.4,
+    content_type: str = "",
+    job_id: str | None = None,
+) -> list[BrandDetection]:
+    """
+    Send remaining unresolved crops to cloud LLM.
+
+    Provider is selected via CLOUD_LLM_PROVIDER env var
+    (groq, gemini, claude, openai — see detect/providers/__init__.py).
+    Updates stats with call count and cost.
+    """
+    if not candidates:
+        return []
+
+    if not has_api_key():
+        emit.log(job_id, "CloudLLM", "WARNING",
+                 f"No API key set for cloud provider, skipping {len(candidates)} crops")
+        return []
+
+    provider = get_provider()
+    emit.log(job_id, "CloudLLM", "INFO",
+             f"Escalating {len(candidates)} crops to {provider.name}")
+
+    matched: list[BrandDetection] = []
+    total_cost = 0.0
+
+    for candidate in candidates:
+        crop = _crop_image(candidate)
+        if crop.size == 0:
+            continue
+
+        crop_context = CropContext(
+            image=b"",
+            surrounding_text=candidate.text,
+            position_hint=f"frame {candidate.frame.sequence}",
+        )
+        prompt = vlm_prompt_fn(crop_context)
+        image_b64 = _encode_crop(crop)
+
+        try:
+            result = _call_cloud_api(image_b64, prompt)
+        except Exception as e:
+            logger.warning("Cloud LLM failed for '%s': %s", candidate.text, e)
+            continue
+
+        stats.cloud_llm_calls += 1
+        model_info = provider.models.get(provider.model)
+        # Approximation: bill all tokens at the input rate (output is capped at 150).
+        cost_per_token = model_info.cost_per_input_token if model_info else 0.00001
+        call_cost = result["tokens"] * cost_per_token
+        total_cost += call_cost
+
+        brand = result["brand"]
+        confidence = result["confidence"]
+
+        if brand and confidence >= min_confidence:
+            detection = BrandDetection(
+                brand=brand,
+                timestamp=candidate.frame.timestamp,
+                duration=0.5,
+                confidence=confidence,
+                source="cloud_llm",
+                bbox=candidate.bbox,
+                frame_ref=candidate.frame.sequence,
+                content_type=content_type,
+            )
+            matched.append(detection)
+
+            emit.detection(
+                job_id,
+                brand=brand,
+                confidence=confidence,
+                source="cloud_llm",
+                timestamp=candidate.frame.timestamp,
+                content_type=content_type,
+                frame_ref=candidate.frame.sequence,
+            )
+
+    stats.estimated_cloud_cost_usd += total_cost
+    stats.regions_escalated_to_cloud_llm = len(candidates)
+
+    emit.log(job_id, "CloudLLM", "INFO",
+             f"Cloud resolved {len(matched)}/{len(candidates)} — "
+             f"cost ${total_cost:.4f} ({stats.cloud_llm_calls} calls total)")
+
+    return matched
diff --git a/detect/stages/vlm_local.py b/detect/stages/vlm_local.py
new file mode 100644
index 0000000..1f4987e
--- /dev/null
+++ b/detect/stages/vlm_local.py
@@ -0,0 +1,124 @@
+"""
+Stage 6 — Local VLM escalation (moondream2)
+
+Processes unresolved text candidates by sending crop images + prompt
+to the local VLM on the inference server. Produces BrandDetection
+objects for crops the VLM can identify.
+"""
+
+from __future__ import annotations
+
+import logging
+
+import numpy as np
+
+from detect import emit
+from detect.models import BrandDetection, TextCandidate
+from detect.profiles.base import CropContext
+
+logger = logging.getLogger(__name__)
+
+
+def _crop_image(candidate: TextCandidate) -> np.ndarray:
+    frame = candidate.frame
+    box = candidate.bbox
+    h, w = frame.image.shape[:2]
+    x1 = max(0, box.x)
+    y1 = max(0, box.y)
+    x2 = min(w, box.x + box.w)
+    y2 = min(h, box.y + box.h)
+    return frame.image[y1:y2, x1:x2]
+
+
+def escalate_vlm(
+    candidates: list[TextCandidate],
+    vlm_prompt_fn,
+    inference_url: str | None = None,
+    min_confidence: float = 0.5,
+    content_type: str = "",
+    job_id: str | None = None,
+) -> tuple[list[BrandDetection], list[TextCandidate]]:
+    """
+    Send unresolved crops to local VLM for brand identification.
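+
+    Uses the inference server's /vlm endpoint when inference_url is set;
+    otherwise runs moondream2 in-process via gpu.models.vlm.query.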
+ + Returns: + - matched: BrandDetections the VLM confirmed + - still_unresolved: candidates the VLM couldn't resolve (→ cloud escalation) + """ + if not candidates: + return [], [] + + emit.log(job_id, "VLMLocal", "INFO", + f"Processing {len(candidates)} unresolved crops with moondream2") + + matched: list[BrandDetection] = [] + still_unresolved: list[TextCandidate] = [] + + if inference_url: + from detect.inference import InferenceClient + client = InferenceClient(base_url=inference_url) + + for candidate in candidates: + crop = _crop_image(candidate) + if crop.size == 0: + still_unresolved.append(candidate) + continue + + crop_context = CropContext( + image=b"", # not used for prompt generation + surrounding_text=candidate.text, + position_hint=f"frame {candidate.frame.sequence}", + ) + prompt = vlm_prompt_fn(crop_context) + + try: + if inference_url: + result = client.vlm(image=crop, prompt=prompt) + brand = result.brand + confidence = result.confidence + reasoning = result.reasoning + else: + brand, confidence, reasoning = _vlm_local(crop, prompt) + except Exception as e: + logger.warning("VLM failed for candidate '%s': %s", candidate.text, e) + still_unresolved.append(candidate) + continue + + if brand and confidence >= min_confidence: + detection = BrandDetection( + brand=brand, + timestamp=candidate.frame.timestamp, + duration=0.5, + confidence=confidence, + source="local_vlm", + bbox=candidate.bbox, + frame_ref=candidate.frame.sequence, + content_type=content_type, + ) + matched.append(detection) + + emit.detection( + job_id, + brand=brand, + confidence=confidence, + source="local_vlm", + timestamp=candidate.frame.timestamp, + content_type=content_type, + frame_ref=candidate.frame.sequence, + ) + + logger.debug("VLM matched: %s (%.2f) — %s", brand, confidence, reasoning) + else: + still_unresolved.append(candidate) + + emit.log(job_id, "VLMLocal", "INFO", + f"VLM resolved {len(matched)}, unresolved {len(still_unresolved)} → cloud") + + return matched, still_unresolved + + +def _vlm_local(crop: np.ndarray, prompt: str) -> tuple[str, float, str]: + """Run moondream2 in-process (single-box mode).""" + from gpu.models.vlm import query + result = query(crop, prompt) + return result["brand"], result["confidence"], result["reasoning"] diff --git a/gpu/models/vlm.py b/gpu/models/vlm.py new file mode 100644 index 0000000..5e1e3c3 --- /dev/null +++ b/gpu/models/vlm.py @@ -0,0 +1,100 @@ +"""moondream2 visual language model wrapper.""" + +from __future__ import annotations + +import logging + +from models import registry +from config import get_config + +logger = logging.getLogger(__name__) + +_MODEL_KEY = "vlm_moondream2" + + +def _load(): + import torch + from transformers import AutoModelForCausalLM, AutoTokenizer + + device = get_config().get("device", "auto") + if device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + + logger.info("Loading moondream2 (device=%s)...", device) + + model_id = "vikhyatk/moondream2" + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) + dtype = torch.float16 if "cuda" in device else torch.float32 + model = AutoModelForCausalLM.from_pretrained( + model_id, + trust_remote_code=True, + dtype=dtype, + device_map=device, + ) + + wrapper = {"model": model, "tokenizer": tokenizer} + registry.put(_MODEL_KEY, wrapper) + logger.info("moondream2 loaded") + return wrapper + + +def _get(): + wrapper = registry.get(_MODEL_KEY) + if wrapper is None: + wrapper = _load() + return wrapper + + +def query(image, prompt: str) -> 
dict: + """ + Query moondream2 with an image crop and prompt. + + Returns {"brand": str, "confidence": float, "reasoning": str} + """ + from PIL import Image as PILImage + + wrapper = _get() + model = wrapper["model"] + tokenizer = wrapper["tokenizer"] + + # Convert numpy array to PIL if needed + if not isinstance(image, PILImage.Image): + image = PILImage.fromarray(image) + + enc_image = model.encode_image(image) + answer = model.answer_question(enc_image, prompt, tokenizer) + + # Parse response — moondream2 returns free text, extract brand + confidence + result = _parse_vlm_response(answer) + return result + + +def _parse_vlm_response(answer: str) -> dict: + """ + Parse moondream2 free-text response into structured output. + + Expected format from prompt: "brand, confidence (0-1), reasoning" + Falls back gracefully if format doesn't match. + """ + answer = answer.strip() + parts = [p.strip() for p in answer.split(",", 2)] + + brand = parts[0] if parts else "" + confidence = 0.5 + reasoning = answer + + if len(parts) >= 2: + try: + confidence = float(parts[1]) + confidence = max(0.0, min(1.0, confidence)) + except ValueError: + pass + + if len(parts) >= 3: + reasoning = parts[2] + + return { + "brand": brand, + "confidence": confidence, + "reasoning": reasoning, + } diff --git a/gpu/requirements.txt b/gpu/requirements.txt index 6b1e8a0..771aadb 100644 --- a/gpu/requirements.txt +++ b/gpu/requirements.txt @@ -19,3 +19,9 @@ ultralytics>=8.0.0 # Install with: # uv pip install paddlepaddle-gpu==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ paddleocr>=3.0.0 + +# VLM (moondream2) — uses torch (already installed above) +# Pinned <5: transformers 5.x broke moondream2's custom model code +# (all_tied_weights_keys API change). Also needs accelerate for device_map. 
+transformers>=4.40.0,<5 +accelerate>=0.27.0 diff --git a/gpu/server.py b/gpu/server.py index 0ba7f6e..e9a46fa 100644 --- a/gpu/server.py +++ b/gpu/server.py @@ -25,6 +25,7 @@ from config import get_config, get_device, update_config from models import registry from models.yolo import detect as yolo_detect from models.ocr import ocr as ocr_run +from models.vlm import query as vlm_query logger = logging.getLogger(__name__) @@ -72,6 +73,18 @@ class OCRResponse(BaseModel): results: list[OCRTextResult] +class VLMRequest(BaseModel): + image: str + prompt: str + model: str | None = None + + +class VLMResponse(BaseModel): + brand: str + confidence: float + reasoning: str + + class ConfigUpdate(BaseModel): device: str | None = None yolo_model: str | None = None @@ -170,6 +183,21 @@ def ocr(req: OCRRequest): return OCRResponse(results=[OCRTextResult(**r) for r in results]) +@app.post("/vlm", response_model=VLMResponse) +def vlm(req: VLMRequest): + try: + image = _decode_image(req.image) + except Exception as e: + raise HTTPException(status_code=400, detail=f"Bad image: {e}") + + try: + result = vlm_query(image, req.prompt) + except Exception as e: + raise HTTPException(status_code=500, detail=f"VLM failed: {e}") + + return VLMResponse(**result) + + if __name__ == "__main__": import uvicorn diff --git a/requirements.txt b/requirements.txt index f19800a..0acae65 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,6 +29,9 @@ strawberry-graphql[fastapi]>=0.311.0 # Observability langfuse>=2.0.0 +# Cloud LLM providers (only needed for cloud escalation stage) +anthropic>=0.40.0 + # Testing pytest>=7.4.0 pytest-django>=4.7.0 diff --git a/tests/detect/manual/test_cloud_provider.py b/tests/detect/manual/test_cloud_provider.py new file mode 100644 index 0000000..1a7936a --- /dev/null +++ b/tests/detect/manual/test_cloud_provider.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 +""" +Test cloud LLM provider with a real API call. + +Sends a test image to the configured cloud provider and verifies +the response. Set your provider env vars before running. + +Usage: + # Groq (default) + CLOUD_LLM_PROVIDER=groq GROQ_API_KEY=gsk_... python tests/detect/manual/test_cloud_provider.py + + # Gemini + CLOUD_LLM_PROVIDER=gemini GEMINI_API_KEY=AIza... python tests/detect/manual/test_cloud_provider.py + + # Claude + CLOUD_LLM_PROVIDER=claude ANTHROPIC_API_KEY=sk-ant-... python tests/detect/manual/test_cloud_provider.py + + # OpenAI-compatible + CLOUD_LLM_PROVIDER=openai OPENAI_API_KEY=sk-... 
python tests/detect/manual/test_cloud_provider.py +""" + +import base64 +import io +import logging +import os +import sys +from pathlib import Path + +# Load .env from ctrl/ (same as docker-compose uses) +env_file = Path(__file__).resolve().parents[3] / "ctrl" / ".env" +if env_file.exists(): + for line in env_file.read_text().splitlines(): + line = line.strip() + if line and not line.startswith("#") and "=" in line: + key, _, value = line.partition("=") + os.environ.setdefault(key.strip(), value.strip()) + +import numpy as np +from PIL import Image, ImageDraw, ImageFont + +sys.path.insert(0, ".") + +logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s") +logger = logging.getLogger(__name__) + + +def make_brand_image(text: str, width: int = 300, height: int = 100) -> str: + img = Image.new("RGB", (width, height), "white") + draw = ImageDraw.Draw(img) + try: + font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 42) + except (OSError, IOError): + font = ImageFont.load_default() + draw.text((10, 20), text, fill="black", font=font) + buf = io.BytesIO() + img.save(buf, "JPEG") + return base64.b64encode(buf.getvalue()).decode() + + +def main(): + from detect.providers import get_provider, has_api_key, PROVIDERS + + provider_name = os.environ.get("CLOUD_LLM_PROVIDER", "groq") + logger.info("Provider: %s", provider_name) + logger.info("Available providers: %s", list(PROVIDERS.keys())) + + if not has_api_key(): + logger.error("No API key set for provider '%s'", provider_name) + logger.error("Set the appropriate env var (see usage in docstring)") + sys.exit(1) + + provider = get_provider() + logger.info("Model: %s", provider.model) + logger.info("Available models: %s", list(provider.models.keys())) + input("\nPress Enter to start...") + + prompt = ( + "Identify the brand or sponsor visible in this image from a soccer broadcast. " + "Respond with: brand, confidence (0-1), reasoning." + ) + + test_cases = ["NIKE", "EMIRATES", "Coca-Cola", "adidas"] + + for text in test_cases: + logger.info("--- Testing: '%s' ---", text) + image_b64 = make_brand_image(text) + + try: + result = provider.call(image_b64, prompt) + logger.info(" answer: %s", result.answer) + logger.info(" tokens: %d", result.total_tokens) + + if text.lower() in result.answer.lower(): + logger.info(" PASS — found '%s' in response", text) + else: + logger.warning(" MISS — '%s' not in response (may be correct, check answer)", text) + + except Exception as e: + logger.error(" FAIL — %s: %s", type(e).__name__, e) + if hasattr(e, 'response') and e.response is not None: + logger.error(" Response: %s", e.response.text[:500]) + + logger.info("All provider tests complete.") + + +if __name__ == "__main__": + main() diff --git a/tests/detect/manual/test_escalation_e2e.py b/tests/detect/manual/test_escalation_e2e.py new file mode 100644 index 0000000..745eaaa --- /dev/null +++ b/tests/detect/manual/test_escalation_e2e.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python3 +""" +Push a full pipeline simulation with escalation events. + +Exercises all stages including VLM and cloud escalation, with progressive +stats showing cost accumulating. Tests all panels: pipeline graph, funnel, +timeline, cost stats, brand table, and log. 
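+
+No pipeline or GPU is needed: the script rpushes pre-baked events onto the
+Redis list the dashboard reads (detect_events:<job>), so it exercises the
+frontend end to end.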
+ +Usage: + python tests/detect/manual/test_escalation_e2e.py [--job JOB_ID] [--port PORT] [--delay SECS] + +Opens: http://mpr.local.ar/detection/?job= +""" + +import argparse +import json +import logging +import time +from datetime import datetime, timezone + +import redis + +logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s") +logger = logging.getLogger(__name__) + +NODES = ["extract_frames", "filter_scenes", "detect_objects", "run_ocr", + "match_brands", "escalate_vlm", "escalate_cloud", "compile_report"] + + +def ts(): + return datetime.now(timezone.utc).isoformat() + + +def push(r, key, event): + event["ts"] = event.get("ts", ts()) + r.rpush(key, json.dumps(event)) + return event + + +def push_graph(r, key, active_node, status, delay): + nodes = [] + for n in NODES: + if n == active_node: + nodes.append({"id": n, "status": status}) + elif NODES.index(n) < NODES.index(active_node): + nodes.append({"id": n, "status": "done"}) + else: + nodes.append({"id": n, "status": "pending"}) + push(r, key, {"event": "graph_update", "nodes": nodes}) + time.sleep(delay) + + +def push_stats(r, key, **fields): + base = { + "event": "stats_update", + "frames_extracted": 0, "frames_after_scene_filter": 0, + "regions_detected": 0, "regions_resolved_by_ocr": 0, + "regions_escalated_to_local_vlm": 0, "regions_escalated_to_cloud_llm": 0, + "cloud_llm_calls": 0, "processing_time_seconds": 0, "estimated_cloud_cost_usd": 0, + } + base.update(fields) + push(r, key, base) + + +def push_detection(r, key, brand, conf, source, timestamp, frame_ref, delay): + push(r, key, { + "event": "detection", + "brand": brand, "confidence": conf, "source": source, + "timestamp": timestamp, "duration": 0.5, + "content_type": "soccer_broadcast", "frame_ref": frame_ref, + }) + logger.info(" [%s] %s %.2f t=%.1fs", source, brand, conf, timestamp) + time.sleep(delay * 0.3) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--job", default="escalation-test") + parser.add_argument("--port", type=int, default=6382) + parser.add_argument("--delay", type=float, default=0.5) + args = parser.parse_args() + + r = redis.Redis(port=args.port, decode_responses=True) + key = f"detect_events:{args.job}" + r.delete(key) + delay = args.delay + + logger.info("Full escalation pipeline simulation → %s", key) + logger.info("Open: http://mpr.local.ar/detection/?job=%s", args.job) + input("\nPress Enter to start...") + + # --- Extract frames --- + push_graph(r, key, "extract_frames", "running", delay) + push(r, key, {"event": "log", "level": "INFO", "stage": "FrameExtractor", + "msg": "Extracting frames: match_clip.mp4 (90.0s, 1920x1080, fps=2)"}) + time.sleep(delay) + push_stats(r, key, frames_extracted=180, processing_time_seconds=4.5) + push_graph(r, key, "extract_frames", "done", delay) + + # --- Scene filter --- + push_graph(r, key, "filter_scenes", "running", delay) + push_stats(r, key, frames_extracted=180, frames_after_scene_filter=52, processing_time_seconds=6.8) + push(r, key, {"event": "log", "level": "INFO", "stage": "SceneFilter", + "msg": "Kept 52 frames (71% reduction)"}) + push_graph(r, key, "filter_scenes", "done", delay) + + # --- YOLO detect --- + push_graph(r, key, "detect_objects", "running", delay) + push(r, key, {"event": "log", "level": "INFO", "stage": "YOLODetector", + "msg": "Running yolov8n on 52 frames"}) + time.sleep(delay) + push_stats(r, key, frames_extracted=180, frames_after_scene_filter=52, + regions_detected=41, processing_time_seconds=14.2) + push_graph(r, 
key, "detect_objects", "done", delay) + + # --- OCR --- + push_graph(r, key, "run_ocr", "running", delay) + push(r, key, {"event": "log", "level": "INFO", "stage": "OCRStage", + "msg": "Running OCR on 41 regions (mode=remote)"}) + time.sleep(delay) + push_stats(r, key, frames_extracted=180, frames_after_scene_filter=52, + regions_detected=41, regions_resolved_by_ocr=30, processing_time_seconds=21.5) + push_graph(r, key, "run_ocr", "done", delay) + + # --- Brand matching --- + push_graph(r, key, "match_brands", "running", delay) + push(r, key, {"event": "log", "level": "INFO", "stage": "BrandResolver", + "msg": "Matching 30 candidates against 12 brands (fuzzy_threshold=75)"}) + time.sleep(delay) + + # OCR detections + ocr_brands = [ + ("Nike", 0.97, 2.0, 4), ("Nike", 0.95, 5.5, 11), ("Emirates", 0.92, 8.0, 16), + ("Adidas", 0.89, 12.0, 24), ("Coca-Cola", 0.85, 18.0, 36), + ("Nike", 0.94, 22.0, 44), ("Emirates", 0.88, 28.0, 56), + ("Adidas", 0.91, 32.0, 64), ("Nike", 0.96, 38.0, 76), + ("Emirates", 0.90, 42.0, 84), ("Coca-Cola", 0.87, 48.0, 96), + ("Nike", 0.93, 52.0, 104), ("Adidas", 0.90, 58.0, 116), + ] + for brand, conf, ts_val, fref in ocr_brands: + push_detection(r, key, brand, conf, "ocr", ts_val, fref, delay) + + push(r, key, {"event": "log", "level": "INFO", "stage": "BrandResolver", + "msg": "Exact: 10, Fuzzy: 3, Unresolved: 11 → VLM"}) + push_graph(r, key, "match_brands", "done", delay) + + # --- VLM escalation --- + push_graph(r, key, "escalate_vlm", "running", delay) + push(r, key, {"event": "log", "level": "INFO", "stage": "VLMLocal", + "msg": "Processing 11 unresolved crops with moondream2"}) + time.sleep(delay) + + vlm_brands = [ + ("Mastercard", 0.78, 15.0, 30), ("Santander", 0.74, 25.0, 50), + ("Qatar Airways", 0.81, 35.0, 70), ("Heineken", 0.76, 45.0, 90), + ("Lay's", 0.72, 55.0, 110), + ] + for brand, conf, ts_val, fref in vlm_brands: + push_detection(r, key, brand, conf, "local_vlm", ts_val, fref, delay) + + push_stats(r, key, frames_extracted=180, frames_after_scene_filter=52, + regions_detected=41, regions_resolved_by_ocr=30, + regions_escalated_to_local_vlm=11, processing_time_seconds=38.7, + estimated_cloud_cost_usd=0) + push(r, key, {"event": "log", "level": "INFO", "stage": "VLMLocal", + "msg": "VLM resolved 5, unresolved 6 → cloud"}) + push_graph(r, key, "escalate_vlm", "done", delay) + + # --- Cloud escalation --- + push_graph(r, key, "escalate_cloud", "running", delay) + push(r, key, {"event": "log", "level": "INFO", "stage": "CloudLLM", + "msg": "Escalating 6 crops to groq (llama-3.2-90b-vision)"}) + time.sleep(delay) + + cloud_brands = [ + ("Pepsi", 0.68, 10.0, 20), + ("Gazprom", 0.65, 40.0, 80), + ] + for brand, conf, ts_val, fref in cloud_brands: + push_detection(r, key, brand, conf, "cloud_llm", ts_val, fref, delay) + + push_stats(r, key, frames_extracted=180, frames_after_scene_filter=52, + regions_detected=41, regions_resolved_by_ocr=30, + regions_escalated_to_local_vlm=11, regions_escalated_to_cloud_llm=6, + cloud_llm_calls=6, processing_time_seconds=45.2, + estimated_cloud_cost_usd=0.0) # groq free tier + + push(r, key, {"event": "log", "level": "WARNING", "stage": "CloudLLM", + "msg": "4 crops unresolved after cloud — likely not brands"}) + push(r, key, {"event": "log", "level": "INFO", "stage": "CloudLLM", + "msg": "Cloud resolved 2/6 — cost $0.0000 (groq free tier)"}) + push_graph(r, key, "escalate_cloud", "done", delay) + + # --- Compile report --- + push_graph(r, key, "compile_report", "running", delay) + + total_brands = len(set(b[0] for b in 
ocr_brands + vlm_brands + cloud_brands)) + total_dets = len(ocr_brands) + len(vlm_brands) + len(cloud_brands) + + push(r, key, {"event": "log", "level": "INFO", "stage": "Aggregator", + "msg": f"Report: {total_brands} brands, {total_dets} detections (merged from {total_dets} raw)"}) + + push(r, key, {"event": "job_complete", "job_id": args.job, "report": { + "video_source": "match_clip.mp4", + "content_type": "soccer_broadcast", + "duration_seconds": 90.0, + "brands": { + "Nike": {"total_appearances": 5, "avg_confidence": 0.95}, + "Emirates": {"total_appearances": 3, "avg_confidence": 0.90}, + "Adidas": {"total_appearances": 3, "avg_confidence": 0.90}, + "Coca-Cola": {"total_appearances": 2, "avg_confidence": 0.86}, + "Mastercard": {"total_appearances": 1, "avg_confidence": 0.78}, + "Santander": {"total_appearances": 1, "avg_confidence": 0.74}, + "Qatar Airways": {"total_appearances": 1, "avg_confidence": 0.81}, + "Heineken": {"total_appearances": 1, "avg_confidence": 0.76}, + "Lay's": {"total_appearances": 1, "avg_confidence": 0.72}, + "Pepsi": {"total_appearances": 1, "avg_confidence": 0.68}, + "Gazprom": {"total_appearances": 1, "avg_confidence": 0.65}, + }, + }}) + + push_graph(r, key, "compile_report", "done", delay) + + logger.info("Done. %d brands, %d detections across ocr/vlm/cloud.", total_brands, total_dets) + logger.info("Check: pipeline graph (all green), timeline (3 source colors),") + logger.info(" cost panel (escalation ratio), brand table (source column).") + + +if __name__ == "__main__": + main() diff --git a/tests/detect/manual/test_vlm_e2e.py b/tests/detect/manual/test_vlm_e2e.py new file mode 100644 index 0000000..fde9b96 --- /dev/null +++ b/tests/detect/manual/test_vlm_e2e.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +""" +Test local VLM (moondream2) via the inference server. + +Creates test images with brand text/logos, sends them to the /vlm endpoint, +verifies moondream2 can identify the brand. 
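+
+Images are synthetic: the brand name rendered as black text on a white
+300x100 canvas (see make_brand_image), so this exercises the happy path only.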
+ +Usage: + python tests/detect/manual/test_vlm_e2e.py [--url URL] + +Requires: inference server running with moondream2 loaded (gpu/server.py) +""" + +import argparse +import base64 +import io +import logging +import sys + +import numpy as np +import requests +from PIL import Image, ImageDraw, ImageFont + +logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s") +logger = logging.getLogger(__name__) + + +def make_brand_image(text: str, width: int = 300, height: int = 100) -> np.ndarray: + img = Image.new("RGB", (width, height), "white") + draw = ImageDraw.Draw(img) + try: + font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 42) + except (OSError, IOError): + font = ImageFont.load_default() + draw.text((10, 20), text, fill="black", font=font) + return np.array(img) + + +def image_to_b64(image: np.ndarray) -> str: + img = Image.fromarray(image) + buf = io.BytesIO() + img.save(buf, "JPEG") + return base64.b64encode(buf.getvalue()).decode() + + +def test_health(url: str): + logger.info("--- Health check ---") + resp = requests.get(f"{url}/health") + resp.raise_for_status() + data = resp.json() + logger.info("Status: %s, device: %s, models: %s", data["status"], data["device"], data.get("loaded_models", [])) + + +def test_vlm(url: str, text: str, prompt: str): + logger.info("--- VLM: image='%s' ---", text) + image = make_brand_image(text) + b64 = image_to_b64(image) + + resp = requests.post(f"{url}/vlm", json={"image": b64, "prompt": prompt}) + resp.raise_for_status() + data = resp.json() + + logger.info(" brand: %s", data["brand"]) + logger.info(" confidence: %.2f", data["confidence"]) + logger.info(" reasoning: %s", data["reasoning"]) + + if text.lower() in data["brand"].lower(): + logger.info(" PASS — matched") + else: + logger.warning(" MISS — expected '%s' in response", text) + + return data + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--url", default="http://mcrndeb:8000") + args = parser.parse_args() + + url = args.url.rstrip("/") + logger.info("Inference server: %s", url) + input("\nPress Enter to start...") + + test_health(url) + + prompt = ( + "Identify the brand or sponsor visible in this image from a soccer broadcast. " + "Respond with: brand, confidence (0-1), reasoning." 
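+        # Keep the "brand, confidence, reasoning" phrasing: it is the comma
+        # format _parse_vlm_response in gpu/models/vlm.py splits on.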
+ ) + + test_vlm(url, "NIKE", prompt) + test_vlm(url, "EMIRATES", prompt) + test_vlm(url, "Coca-Cola", prompt) + test_vlm(url, "adidas", prompt) + + logger.info("All VLM tests complete.") + + +if __name__ == "__main__": + main() diff --git a/tests/detect/test_aggregator.py b/tests/detect/test_aggregator.py new file mode 100644 index 0000000..7d268fa --- /dev/null +++ b/tests/detect/test_aggregator.py @@ -0,0 +1,79 @@ +"""Tests for the report aggregator stage.""" + +import pytest + +from detect.models import BoundingBox, BrandDetection, PipelineStats +from detect.stages.aggregator import compile_report, _merge_contiguous + + +def _make_detection(brand: str, timestamp: float, duration: float = 0.5, + source: str = "ocr", confidence: float = 0.9) -> BrandDetection: + return BrandDetection( + brand=brand, timestamp=timestamp, duration=duration, + confidence=confidence, source=source, content_type="soccer_broadcast", + ) + + +def test_merge_contiguous_same_brand(): + dets = [ + _make_detection("Nike", 1.0, 0.5), + _make_detection("Nike", 1.3, 0.5), # within gap + _make_detection("Nike", 5.0, 0.5), # separate + ] + merged = _merge_contiguous(dets, gap_threshold=2.0) + assert len(merged) == 2 + assert merged[0].brand == "Nike" + assert merged[0].timestamp == 1.0 + assert merged[0].duration == pytest.approx(0.8) # 1.0 to 1.8 + assert merged[1].timestamp == 5.0 + + +def test_merge_different_brands(): + dets = [ + _make_detection("Nike", 1.0), + _make_detection("Adidas", 1.5), + ] + merged = _merge_contiguous(dets) + assert len(merged) == 2 + + +def test_merge_empty(): + assert _merge_contiguous([]) == [] + + +def test_compile_report(monkeypatch): + events = [] + monkeypatch.setattr("detect.emit.push_detect_event", + lambda job_id, etype, data: events.append((etype, data))) + + dets = [ + _make_detection("Nike", 1.0, 0.5, confidence=0.95), + _make_detection("Nike", 5.0, 1.0, confidence=0.90), + _make_detection("Adidas", 3.0, 0.5, confidence=0.85), + _make_detection("Heineken", 10.0, 0.5, source="cloud_llm", confidence=0.70), + ] + stats = PipelineStats( + frames_extracted=120, + regions_detected=32, + cloud_llm_calls=1, + estimated_cloud_cost_usd=0.003, + ) + + report = compile_report( + detections=dets, + stats=stats, + video_source="test.mp4", + content_type="soccer_broadcast", + job_id="test-report", + ) + + assert len(report.brands) == 3 + assert report.brands["Nike"].total_appearances == 2 + assert report.brands["Adidas"].total_appearances == 1 + assert report.brands["Heineken"].total_appearances == 1 + assert report.pipeline_stats.cloud_llm_calls == 1 + assert report.video_source == "test.mp4" + + # job_complete event should have been emitted + complete = [e for e in events if e[0] == "job_complete"] + assert len(complete) == 1 diff --git a/tests/detect/test_vlm_cloud.py b/tests/detect/test_vlm_cloud.py new file mode 100644 index 0000000..1538a0b --- /dev/null +++ b/tests/detect/test_vlm_cloud.py @@ -0,0 +1,92 @@ +"""Tests for cloud LLM escalation stage.""" + +import numpy as np +import pytest + +from detect.models import BoundingBox, Frame, PipelineStats, TextCandidate +from detect.stages.vlm_cloud import escalate_cloud, _parse_response + + +def _make_candidate(text: str = "unknown", confidence: float = 0.4) -> TextCandidate: + frame = Frame(sequence=0, chunk_id=0, timestamp=1.0, + image=np.zeros((50, 100, 3), dtype=np.uint8)) + box = BoundingBox(x=0, y=0, w=100, h=50, confidence=0.5, label="text") + return TextCandidate(frame=frame, bbox=box, text=text, ocr_confidence=confidence) + + +def 
test_parse_response_clean(): + result = _parse_response("Nike, 0.92, swoosh logo visible", 200) + assert result["brand"] == "Nike" + assert result["confidence"] == 0.92 + assert "swoosh" in result["reasoning"] + assert result["tokens"] == 200 + + +def test_parse_response_no_confidence(): + result = _parse_response("Adidas", 0) + assert result["brand"] == "Adidas" + assert result["confidence"] == 0.5 # default + + +def test_escalate_skips_without_api_key(monkeypatch): + events = [] + monkeypatch.setattr("detect.emit.push_detect_event", + lambda job_id, etype, data: events.append((etype, data))) + monkeypatch.delenv("GROQ_API_KEY", raising=False) + monkeypatch.delenv("GEMINI_API_KEY", raising=False) + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.setenv("CLOUD_LLM_PROVIDER", "groq") + # Reset cached provider + import detect.providers as prov + monkeypatch.setattr(prov, "_cached", None) + + candidates = [_make_candidate()] + stats = PipelineStats() + prompt_fn = lambda ctx: "what brand?" + + matched = escalate_cloud(candidates, prompt_fn, stats, job_id="test") + + assert len(matched) == 0 + assert stats.cloud_llm_calls == 0 + log_events = [e for e in events if e[0] == "log"] + assert any("No API key" in e[1].get("msg", "") for e in log_events) + + +def test_escalate_empty_candidates(monkeypatch): + events = [] + monkeypatch.setattr("detect.emit.push_detect_event", + lambda job_id, etype, data: events.append((etype, data))) + + stats = PipelineStats() + matched = escalate_cloud([], lambda ctx: "", stats, job_id="test") + + assert len(matched) == 0 + assert stats.cloud_llm_calls == 0 + + +def test_escalate_with_mock_api(monkeypatch): + events = [] + monkeypatch.setattr("detect.emit.push_detect_event", + lambda job_id, etype, data: events.append((etype, data))) + monkeypatch.setenv("GROQ_API_KEY", "test-key") + monkeypatch.setenv("CLOUD_LLM_PROVIDER", "groq") + # Reset cached provider + import detect.providers as prov + monkeypatch.setattr(prov, "_cached", None) + + def mock_call(image_b64, prompt): + return {"brand": "Heineken", "confidence": 0.75, "reasoning": "green logo", "tokens": 300} + + monkeypatch.setattr("detect.stages.vlm_cloud._call_cloud_api", mock_call) + + candidates = [_make_candidate("unknown logo")] + stats = PipelineStats() + prompt_fn = lambda ctx: "what brand?" + + matched = escalate_cloud(candidates, prompt_fn, stats, job_id="test") + + assert len(matched) == 1 + assert matched[0].brand == "Heineken" + assert matched[0].source == "cloud_llm" + assert stats.cloud_llm_calls == 1 + assert stats.estimated_cloud_cost_usd >= 0 # exact cost depends on provider model index
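
For reference, the escalation stages compose outside LangGraph exactly as graph.py wires them. A minimal sketch (illustration only; `candidates`, `profile`, and `detections` are assumed to come from the earlier resolver stages):

```python
# Sketch: the post-OCR escalation chain as plain function calls, mirroring
# node_escalate_vlm -> node_escalate_cloud -> node_compile_report in detect/graph.py.
from detect.models import PipelineStats
from detect.stages.vlm_local import escalate_vlm
from detect.stages.vlm_cloud import escalate_cloud
from detect.stages.aggregator import compile_report

def run_escalation(candidates, profile, detections, inference_url=None, job_id=None):
    stats = PipelineStats()

    # Stage 6: local moondream2 pass over unresolved OCR candidates
    vlm_matched, still_unresolved = escalate_vlm(
        candidates, vlm_prompt_fn=profile.vlm_prompt,
        inference_url=inference_url, content_type=profile.name, job_id=job_id,
    )
    stats.regions_escalated_to_local_vlm = len(candidates)

    # Stage 7: cloud LLM, only for crops the local VLM could not resolve
    cloud_matched = escalate_cloud(
        still_unresolved, vlm_prompt_fn=profile.vlm_prompt, stats=stats,
        content_type=profile.name, job_id=job_id,
    )

    # Stage 8: merge everything into the final report (emits job_complete)
    return compile_report(
        detections=detections + vlm_matched + cloud_matched,
        stats=stats, content_type=profile.name, job_id=job_id,
    )
```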