phase 9
@@ -0,0 +1,11 @@
---
name: agent_sdk_future
description: Claude Agent SDK for general mpr tasks (not vision provider), uses OAuth not API keys
type: project
---

Claude Agent SDK (`claude-agent-sdk`) is for future general-purpose tasks in mpr, NOT for the cloud vision provider.

**Why:** The Agent SDK uses Claude Code CLI's OAuth (browser login, no API keys) and is designed for agentic tasks (file read/edit, bash, web search). The vision provider needs raw API calls with image payloads — use the `anthropic` SDK with `ANTHROPIC_API_KEY` for that.

**How to apply:** When adding Claude-powered automation to mpr (e.g., log analysis, config suggestions, code review on pipeline changes), use the Agent SDK. For the cloud LLM escalation stage (image crops → brand ID), keep using the `anthropic` SDK with API key auth.
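A minimal sketch of the Agent SDK side, assuming the `query()` entry point from the `claude-agent-sdk` docs (the helper below is hypothetical, not part of this commit, and needs the Claude Code CLI installed and logged in):

```python
import asyncio

from claude_agent_sdk import query  # assumed import path; verify against the SDK docs


async def review_pipeline_change(diff_text: str) -> None:
    # Agentic task: OAuth via the Claude Code CLI login, no ANTHROPIC_API_KEY.
    async for message in query(prompt=f"Review this pipeline change:\n{diff_text}"):
        print(message)


asyncio.run(review_pipeline_change("renamed escalate_vlm to escalate_local_vlm"))
```

The vision escalation path stays on the `anthropic` SDK: `detect/providers/claude.py` below builds an `Anthropic()` client keyed by `ANTHROPIC_API_KEY`.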
@@ -35,5 +35,29 @@ AWS_REGION=us-east-1
 AWS_ACCESS_KEY_ID=minioadmin
 AWS_SECRET_ACCESS_KEY=minioadmin
+
+# Inference
+INFERENCE_URL=http://mcrndeb:8000
+
+# Cloud LLM (detection pipeline escalation)
+# Set CLOUD_LLM_PROVIDER to: groq, gemini, claude, openai
+CLOUD_LLM_PROVIDER=groq
+
+# Groq (default, free tier)
+GROQ_API_KEY=
+GROQ_MODEL=meta-llama/llama-4-scout-17b-16e-instruct
+
+# Gemini
+#GEMINI_API_KEY=
+#GEMINI_MODEL=gemini-2.0-flash
+
+# Claude (uses anthropic SDK)
+#ANTHROPIC_API_KEY=
+#CLAUDE_MODEL=claude-sonnet-4-20250514
+
+# OpenAI-compatible
+#OPENAI_API_KEY=
+#OPENAI_MODEL=gpt-4o-mini
+#OPENAI_BASE_URL=https://api.openai.com/v1
+
 # Vite
 VITE_ALLOWED_HOSTS=your-domain.local
@@ -20,6 +20,9 @@ from detect.stages.scene_filter import scene_filter
 from detect.stages.yolo_detector import detect_objects
 from detect.stages.ocr_stage import run_ocr
 from detect.stages.brand_resolver import resolve_brands
+from detect.stages.vlm_local import escalate_vlm
+from detect.stages.vlm_cloud import escalate_cloud
+from detect.stages.aggregator import compile_report
 from detect.tracing import trace_node, flush as flush_traces

 INFERENCE_URL = os.environ.get("INFERENCE_URL")  # None = local mode
@@ -158,43 +161,77 @@ def node_escalate_vlm(state: DetectState) -> dict:
     _emit_transition(state, "escalate_vlm", "running")

     with trace_node(state, "escalate_vlm") as span:
+        profile = _get_profile(state)
+        candidates = state.get("unresolved_candidates", [])
         job_id = state.get("job_id")
-        emit.log(job_id, "VLMLocal", "INFO", "Stub: VLM escalation not yet implemented")
-        span.set_output({"stub": True})
+        vlm_matched, still_unresolved = escalate_vlm(
+            candidates,
+            vlm_prompt_fn=profile.vlm_prompt,
+            inference_url=INFERENCE_URL,
+            content_type=profile.name,
+            job_id=job_id,
+        )
+
+        stats = state.get("stats", PipelineStats())
+        stats.regions_escalated_to_local_vlm = len(candidates)
+        span.set_output({"candidates": len(candidates), "matched": len(vlm_matched),
+                         "still_unresolved": len(still_unresolved)})
+
+        existing = state.get("detections", [])
+
     _emit_transition(state, "escalate_vlm", "done")
-    return {}
+    return {
+        "detections": existing + vlm_matched,
+        "unresolved_candidates": still_unresolved,
+        "stats": stats,
+    }


 def node_escalate_cloud(state: DetectState) -> dict:
     _emit_transition(state, "escalate_cloud", "running")

     with trace_node(state, "escalate_cloud") as span:
+        profile = _get_profile(state)
+        candidates = state.get("unresolved_candidates", [])
         job_id = state.get("job_id")
-        emit.log(job_id, "CloudLLM", "INFO", "Stub: cloud LLM escalation not yet implemented")
-        span.set_output({"stub": True})
+        stats = state.get("stats", PipelineStats())
+
+        cloud_matched = escalate_cloud(
+            candidates,
+            vlm_prompt_fn=profile.vlm_prompt,
+            stats=stats,
+            content_type=profile.name,
+            job_id=job_id,
+        )
+
+        span.set_output({"candidates": len(candidates), "matched": len(cloud_matched),
+                         "cloud_calls": stats.cloud_llm_calls,
+                         "cost_usd": stats.estimated_cloud_cost_usd})
+
+        existing = state.get("detections", [])
+
     _emit_transition(state, "escalate_cloud", "done")
-    return {}
+    return {"detections": existing + cloud_matched, "stats": stats}


 def node_compile_report(state: DetectState) -> dict:
     _emit_transition(state, "compile_report", "running")

     with trace_node(state, "compile_report") as span:
-        job_id = state.get("job_id")
         profile = _get_profile(state)
         detections = state.get("detections", [])
-        report = profile.aggregate(detections)
-        report.video_source = state.get("video_path", "")
-
-        emit.log(job_id, "Aggregator", "INFO",
-                 f"Report: {len(report.brands)} brands, {len(report.timeline)} detections")
-        emit.job_complete(job_id, {
-            "video_source": report.video_source,
-            "content_type": report.content_type,
-            "brands": {k: {"total_appearances": v.total_appearances} for k, v in report.brands.items()},
-        })
+        stats = state.get("stats", PipelineStats())
+        job_id = state.get("job_id")
+
+        report = compile_report(
+            detections=detections,
+            stats=stats,
+            video_source=state.get("video_path", ""),
+            content_type=profile.name,
+            job_id=job_id,
+        )
+
         span.set_output({"brands": len(report.brands), "detections": len(report.timeline)})

     flush_traces()
detect/providers/__init__.py (new file, 58 lines)
@@ -0,0 +1,58 @@
"""
Cloud LLM provider registry.

Select provider via CLOUD_LLM_PROVIDER env var.
Each provider reads its own env vars for auth/config.

CLOUD_LLM_PROVIDER=groq   → GROQ_API_KEY, GROQ_MODEL, GROQ_BASE_URL
CLOUD_LLM_PROVIDER=gemini → GEMINI_API_KEY, GEMINI_MODEL
CLOUD_LLM_PROVIDER=openai → OPENAI_API_KEY, OPENAI_MODEL, OPENAI_BASE_URL
CLOUD_LLM_PROVIDER=claude → ANTHROPIC_API_KEY, CLAUDE_MODEL
"""

from __future__ import annotations

import os

from .base import CloudProvider, ProviderResponse
from .groq import GroqProvider
from .gemini import GeminiProvider
from .openai_compat import OpenAICompatProvider
from .claude import ClaudeProvider

PROVIDERS: dict[str, type] = {
    "groq": GroqProvider,
    "gemini": GeminiProvider,
    "openai": OpenAICompatProvider,
    "claude": ClaudeProvider,
}

_cached: CloudProvider | None = None


def get_provider() -> CloudProvider:
    """Get the configured cloud provider (cached after first call)."""
    global _cached
    if _cached is not None:
        return _cached

    name = os.environ.get("CLOUD_LLM_PROVIDER", "groq")
    cls = PROVIDERS.get(name)
    if cls is None:
        raise ValueError(f"Unknown provider: {name!r}. Options: {list(PROVIDERS)}")

    _cached = cls()
    return _cached


def has_api_key() -> bool:
    """Check if the configured provider has an API key set."""
    name = os.environ.get("CLOUD_LLM_PROVIDER", "groq")
    key_map = {
        "groq": "GROQ_API_KEY",
        "gemini": "GEMINI_API_KEY",
        "openai": "OPENAI_API_KEY",
        "claude": "ANTHROPIC_API_KEY",
    }
    env_var = key_map.get(name, "")
    return bool(os.environ.get(env_var, ""))
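For orientation, the registry's full surface as the escalation stage uses it (`image_b64` stands in for a real base64 JPEG crop):

```python
from detect.providers import get_provider, has_api_key

image_b64 = "<base64-encoded JPEG crop>"  # placeholder

if has_api_key():  # cheap pre-flight so the stage can skip when unconfigured
    provider = get_provider()  # e.g. GroqProvider; cached after the first call
    resp = provider.call(
        image_b64,
        "Identify the brand. Respond with: brand, confidence (0-1), reasoning.",
    )
    print(resp.answer, resp.total_tokens)
```

One caveat that follows from the code above: the cache is module-global, so changing CLOUD_LLM_PROVIDER mid-process has no effect after the first `get_provider()` call.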
detect/providers/base.py (new file, 36 lines)
@@ -0,0 +1,36 @@
"""Cloud LLM provider protocol and model metadata."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Protocol


@dataclass
class ModelInfo:
    """Metadata for a cloud LLM model."""
    id: str
    vision: bool = True
    cost_per_input_token: float = 0.0
    cost_per_output_token: float = 0.0
    max_output_tokens: int = 4096
    notes: str = ""


@dataclass
class ProviderResponse:
    answer: str
    total_tokens: int = 0


class CloudProvider(Protocol):
    """
    Interface for cloud LLM providers.

    Each provider handles its own auth, payload format, and response parsing.
    The pipeline only calls call() and reads the response.
    """
    name: str
    models: dict[str, ModelInfo]

    def call(self, image_b64: str, prompt: str) -> ProviderResponse: ...
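Because this is a structural Protocol, a new backend needs no inheritance: any class with `name`, `models`, and `call()` fits. A hedged sketch of a hypothetical extra provider (the endpoint, model id, and env var are illustrative, not part of this commit; the payload shape mirrors the OpenAI-compatible providers below):

```python
import os

import requests

from detect.providers.base import ModelInfo, ProviderResponse


class LocalCompatProvider:
    """Hypothetical provider for a local OpenAI-compatible server."""
    name = "local"
    models = {"local-vision": ModelInfo(id="local-vision", notes="illustrative only")}

    def __init__(self):
        base_url = os.environ.get("LOCAL_LLM_BASE_URL", "http://localhost:8080/v1")
        self.endpoint = f"{base_url.rstrip('/')}/chat/completions"
        self.model = "local-vision"

    def call(self, image_b64: str, prompt: str) -> ProviderResponse:
        payload = {
            "model": self.model,
            "messages": [{"role": "user", "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url",
                 "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}},
            ]}],
            "max_tokens": 150,
        }
        data = requests.post(self.endpoint, json=payload, timeout=30).json()
        return ProviderResponse(
            answer=data["choices"][0]["message"]["content"].strip(),
            total_tokens=data.get("usage", {}).get("total_tokens", 0),
        )
```

Wiring it up means adding it to `PROVIDERS` and, since `has_api_key()` returns False for names missing from its `key_map`, either adding a key entry there or special-casing keyless backends.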
detect/providers/claude.py (new file, 73 lines)
@@ -0,0 +1,73 @@
"""Anthropic Claude provider — uses the official SDK."""

from __future__ import annotations

import logging
import os

from .base import ModelInfo, ProviderResponse

logger = logging.getLogger(__name__)

# Claude-specific env vars
# ANTHROPIC_API_KEY is read by the SDK automatically
CLAUDE_MODEL = os.environ.get("CLAUDE_MODEL", "claude-sonnet-4-20250514")

MODELS = {
    "claude-sonnet-4-20250514": ModelInfo(
        id="claude-sonnet-4-20250514",
        vision=True,
        cost_per_input_token=0.000003,
        cost_per_output_token=0.000015,
        notes="Best balance of quality/cost with vision",
    ),
    "claude-haiku-4-5-20251001": ModelInfo(
        id="claude-haiku-4-5-20251001",
        vision=True,
        cost_per_input_token=0.0000008,
        cost_per_output_token=0.000004,
        notes="Fastest, cheapest, good for simple brand ID",
    ),
    "claude-opus-4-6": ModelInfo(
        id="claude-opus-4-6",
        vision=True,
        cost_per_input_token=0.000015,
        cost_per_output_token=0.000075,
        notes="Highest quality, use for ambiguous cases",
    ),
}


class ClaudeProvider:
    name = "claude"
    models = MODELS

    def __init__(self):
        from anthropic import Anthropic
        self.client = Anthropic()
        self.model = CLAUDE_MODEL

    def call(self, image_b64: str, prompt: str) -> ProviderResponse:
        message = self.client.messages.create(
            model=self.model,
            max_tokens=150,
            messages=[{
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": image_b64,
                        },
                    },
                    {"type": "text", "text": prompt},
                ],
            }],
        )

        answer = message.content[0].text.strip()
        total_tokens = message.usage.input_tokens + message.usage.output_tokens

        return ProviderResponse(answer=answer, total_tokens=total_tokens)
detect/providers/gemini.py (new file, 75 lines)
@@ -0,0 +1,75 @@
"""Google Gemini provider — native REST API, not OpenAI-compatible."""

from __future__ import annotations

import logging
import os

import requests

from .base import ModelInfo, ProviderResponse

logger = logging.getLogger(__name__)

# Gemini-specific env vars
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-2.0-flash")

MODELS = {
    "gemini-2.0-flash": ModelInfo(
        id="gemini-2.0-flash",
        vision=True,
        cost_per_input_token=0.0000001,
        cost_per_output_token=0.0000004,
        notes="Fast, cheap, good vision",
    ),
    "gemini-2.0-pro": ModelInfo(
        id="gemini-2.0-pro",
        vision=True,
        cost_per_input_token=0.00000125,
        cost_per_output_token=0.000005,
        notes="Higher quality, slower",
    ),
    "gemini-1.5-flash": ModelInfo(
        id="gemini-1.5-flash",
        vision=True,
        cost_per_input_token=0.000000075,
        cost_per_output_token=0.0000003,
        notes="Cheapest option",
    ),
}


class GeminiProvider:
    name = "gemini"
    models = MODELS

    def __init__(self):
        self.api_key = GEMINI_API_KEY
        self.model = GEMINI_MODEL
        self.endpoint = (
            f"https://generativelanguage.googleapis.com/v1beta/models/"
            f"{self.model}:generateContent"
        )

    def call(self, image_b64: str, prompt: str) -> ProviderResponse:
        payload = {
            "contents": [{
                "parts": [
                    {"text": prompt},
                    {"inline_data": {"mime_type": "image/jpeg", "data": image_b64}},
                ],
            }],
            "generationConfig": {"maxOutputTokens": 150},
        }

        url = f"{self.endpoint}?key={self.api_key}"
        resp = requests.post(url, json=payload, timeout=30)
        resp.raise_for_status()
        data = resp.json()

        answer = data["candidates"][0]["content"]["parts"][0]["text"].strip()
        usage = data.get("usageMetadata", {})
        total_tokens = usage.get("totalTokenCount", 0)

        return ProviderResponse(answer=answer, total_tokens=total_tokens)
detect/providers/groq.py (new file, 66 lines)
@@ -0,0 +1,66 @@
"""Groq cloud provider — OpenAI-compatible API with vision."""

from __future__ import annotations

import logging
import os

import requests

from .base import ModelInfo, ProviderResponse

logger = logging.getLogger(__name__)

# Groq-specific env vars
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
GROQ_BASE_URL = os.environ.get("GROQ_BASE_URL", "https://api.groq.com/openai/v1")
GROQ_MODEL = os.environ.get("GROQ_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct")

MODELS = {
    "meta-llama/llama-4-scout-17b-16e-instruct": ModelInfo(
        id="meta-llama/llama-4-scout-17b-16e-instruct",
        vision=True,
        cost_per_input_token=0.0,
        cost_per_output_token=0.0,
        notes="Llama 4 Scout, only vision model on Groq free tier",
    ),
}


class GroqProvider:
    name = "groq"
    models = MODELS

    def __init__(self):
        self.api_key = GROQ_API_KEY
        self.base_url = GROQ_BASE_URL
        self.model = GROQ_MODEL
        self.endpoint = f"{self.base_url.rstrip('/')}/chat/completions"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def call(self, image_b64: str, prompt: str) -> ProviderResponse:
        payload = {
            "model": self.model,
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {
                        "url": f"data:image/jpeg;base64,{image_b64}",
                    }},
                ],
            }],
            "max_tokens": 150,
        }

        resp = requests.post(self.endpoint, headers=self.headers, json=payload, timeout=30)
        resp.raise_for_status()
        data = resp.json()

        answer = data["choices"][0]["message"]["content"].strip()
        total_tokens = data.get("usage", {}).get("total_tokens", 0)

        return ProviderResponse(answer=answer, total_tokens=total_tokens)
detect/providers/openai_compat.py (new file, 73 lines)
@@ -0,0 +1,73 @@
"""Generic OpenAI-compatible provider (OpenAI, Together, etc.)."""

from __future__ import annotations

import logging
import os

import requests

from .base import ModelInfo, ProviderResponse

logger = logging.getLogger(__name__)

# OpenAI-compat specific env vars
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4o-mini")

MODELS = {
    "gpt-4o-mini": ModelInfo(
        id="gpt-4o-mini",
        vision=True,
        cost_per_input_token=0.00000015,
        cost_per_output_token=0.0000006,
        notes="Cheap, fast, decent vision",
    ),
    "gpt-4o": ModelInfo(
        id="gpt-4o",
        vision=True,
        cost_per_input_token=0.0000025,
        cost_per_output_token=0.00001,
        notes="Best OpenAI vision model",
    ),
}


class OpenAICompatProvider:
    name = "openai"
    models = MODELS

    def __init__(self):
        self.api_key = OPENAI_API_KEY
        self.base_url = OPENAI_BASE_URL
        self.model = OPENAI_MODEL
        self.endpoint = f"{self.base_url.rstrip('/')}/chat/completions"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def call(self, image_b64: str, prompt: str) -> ProviderResponse:
        payload = {
            "model": self.model,
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {
                        "url": f"data:image/jpeg;base64,{image_b64}",
                    }},
                ],
            }],
            "max_tokens": 150,
        }

        resp = requests.post(self.endpoint, headers=self.headers, json=payload, timeout=30)
        resp.raise_for_status()
        data = resp.json()

        answer = data["choices"][0]["message"]["content"].strip()
        total_tokens = data.get("usage", {}).get("total_tokens", 0)

        return ProviderResponse(answer=answer, total_tokens=total_tokens)
detect/stages/aggregator.py (new file, 116 lines)
@@ -0,0 +1,116 @@
"""
Stage 8 — Report compilation

Groups all detections by brand, merges contiguous appearances,
and builds the final DetectionReport.
"""

from __future__ import annotations

import logging

from detect import emit
from detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats

logger = logging.getLogger(__name__)


def _merge_contiguous(detections: list[BrandDetection], gap_threshold: float = 2.0) -> list[BrandDetection]:
    """
    Merge detections of the same brand that are close in time.

    If two detections of the same brand are within gap_threshold seconds,
    they're merged into one detection spanning the full range.
    """
    if not detections:
        return []

    sorted_dets = sorted(detections, key=lambda d: (d.brand, d.timestamp))
    merged: list[BrandDetection] = []
    current = sorted_dets[0]

    for det in sorted_dets[1:]:
        if (det.brand == current.brand
                and det.timestamp <= current.timestamp + current.duration + gap_threshold):
            end = max(current.timestamp + current.duration,
                      det.timestamp + det.duration)
            current = BrandDetection(
                brand=current.brand,
                timestamp=current.timestamp,
                duration=end - current.timestamp,
                confidence=max(current.confidence, det.confidence),
                source=current.source,
                bbox=current.bbox,
                frame_ref=current.frame_ref,
                content_type=current.content_type,
            )
        else:
            merged.append(current)
            current = det

    merged.append(current)
    return merged


def compile_report(
    detections: list[BrandDetection],
    stats: PipelineStats,
    video_source: str = "",
    content_type: str = "",
    duration_seconds: float = 0.0,
    job_id: str | None = None,
) -> DetectionReport:
    """
    Build the final detection report from all accumulated detections.

    Merges contiguous detections, computes per-brand stats,
    and emits the job_complete event.
    """
    merged = _merge_contiguous(detections)

    brands: dict[str, BrandStats] = {}
    for d in merged:
        if d.brand not in brands:
            brands[d.brand] = BrandStats()
        s = brands[d.brand]
        s.total_appearances += 1
        s.total_screen_time += d.duration
        s.avg_confidence = (
            (s.avg_confidence * (s.total_appearances - 1) + d.confidence)
            / s.total_appearances
        )
        if s.first_seen == 0.0 or d.timestamp < s.first_seen:
            s.first_seen = d.timestamp
        if d.timestamp > s.last_seen:
            s.last_seen = d.timestamp

    report = DetectionReport(
        video_source=video_source,
        content_type=content_type,
        duration_seconds=duration_seconds,
        brands=brands,
        timeline=sorted(merged, key=lambda d: d.timestamp),
        pipeline_stats=stats,
    )

    emit.log(job_id, "Aggregator", "INFO",
             f"Report: {len(brands)} brands, {len(merged)} detections "
             f"(merged from {len(detections)} raw)")

    emit.job_complete(job_id, {
        "video_source": report.video_source,
        "content_type": report.content_type,
        "duration_seconds": report.duration_seconds,
        "brands": {
            k: {
                "total_appearances": v.total_appearances,
                "total_screen_time": v.total_screen_time,
                "avg_confidence": round(v.avg_confidence, 3),
                "first_seen": v.first_seen,
                "last_seen": v.last_seen,
            }
            for k, v in brands.items()
        },
    })

    return report
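A quick worked instance of the merge rule, assuming BrandDetection's remaining fields (bbox, frame_ref, content_type) have defaults (adjust the constructor calls if they don't):

```python
# Two "Nike" hits: the second starts at 4.0s, within
# current end (2.5s) + gap_threshold (2.0s) = 4.5s, so they merge.
from detect.models import BrandDetection
from detect.stages.aggregator import _merge_contiguous

a = BrandDetection(brand="Nike", timestamp=2.0, duration=0.5, confidence=0.97, source="ocr")
b = BrandDetection(brand="Nike", timestamp=4.0, duration=0.5, confidence=0.95, source="ocr")

(merged,) = _merge_contiguous([a, b], gap_threshold=2.0)
assert merged.timestamp == 2.0       # span starts at the first hit
assert merged.duration == 2.5        # runs to the second hit's end (4.5s)
assert merged.confidence == 0.97     # keeps the max confidence
```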
detect/stages/vlm_cloud.py (new file, 168 lines)
@@ -0,0 +1,168 @@
"""
Stage 7 — Cloud LLM escalation

Last resort for crops the local VLM couldn't resolve.
Provider-agnostic — switch via CLOUD_LLM_PROVIDER env var.
Each provider has its own file under detect/providers/.

Tracks token usage and cost.
"""

from __future__ import annotations

import base64
import io
import logging

import numpy as np
from PIL import Image

from detect import emit
from detect.models import BrandDetection, PipelineStats, TextCandidate
from detect.profiles.base import CropContext
from detect.providers import get_provider, has_api_key

logger = logging.getLogger(__name__)

ESTIMATED_TOKENS_PER_CROP = 500


def _encode_crop(crop: np.ndarray) -> str:
    img = Image.fromarray(crop)
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=85)
    return base64.b64encode(buf.getvalue()).decode()


def _crop_image(candidate: TextCandidate) -> np.ndarray:
    frame = candidate.frame
    box = candidate.bbox
    h, w = frame.image.shape[:2]
    x1 = max(0, box.x)
    y1 = max(0, box.y)
    x2 = min(w, box.x + box.w)
    y2 = min(h, box.y + box.h)
    return frame.image[y1:y2, x1:x2]


def _parse_response(answer: str, total_tokens: int) -> dict:
    """Parse LLM free-text response into structured output."""
    parts = [p.strip() for p in answer.split(",", 2)]

    brand = parts[0] if parts else ""
    confidence = 0.5
    reasoning = answer

    if len(parts) >= 2:
        try:
            confidence = float(parts[1])
            confidence = max(0.0, min(1.0, confidence))
        except ValueError:
            pass

    if len(parts) >= 3:
        reasoning = parts[2]

    return {
        "brand": brand,
        "confidence": confidence,
        "reasoning": reasoning,
        "tokens": total_tokens or ESTIMATED_TOKENS_PER_CROP,
    }


def _call_cloud_api(image_b64: str, prompt: str) -> dict:
    """Route to the configured provider and parse the response."""
    provider = get_provider()
    result = provider.call(image_b64, prompt)
    return _parse_response(result.answer, result.total_tokens)


def escalate_cloud(
    candidates: list[TextCandidate],
    vlm_prompt_fn,
    stats: PipelineStats,
    min_confidence: float = 0.4,
    content_type: str = "",
    job_id: str | None = None,
) -> list[BrandDetection]:
    """
    Send remaining unresolved crops to cloud LLM.

    Provider is selected via CLOUD_LLM_PROVIDER env var (groq, gemini, claude, openai).
    Updates stats with call count and cost.
    """
    if not candidates:
        return []

    if not has_api_key():
        emit.log(job_id, "CloudLLM", "WARNING",
                 f"No API key set for cloud provider, skipping {len(candidates)} crops")
        return []

    provider = get_provider()
    emit.log(job_id, "CloudLLM", "INFO",
             f"Escalating {len(candidates)} crops to {provider.name}")

    matched: list[BrandDetection] = []
    total_cost = 0.0

    for candidate in candidates:
        crop = _crop_image(candidate)
        if crop.size == 0:
            continue

        crop_context = CropContext(
            image=b"",
            surrounding_text=candidate.text,
            position_hint=f"frame {candidate.frame.sequence}",
        )
        prompt = vlm_prompt_fn(crop_context)
        image_b64 = _encode_crop(crop)

        try:
            result = _call_cloud_api(image_b64, prompt)
        except Exception as e:
            logger.warning("Cloud LLM failed for '%s': %s", candidate.text, e)
            continue

        stats.cloud_llm_calls += 1
        model_info = provider.models.get(provider.model)
        cost_per_token = model_info.cost_per_input_token if model_info else 0.00001
        call_cost = result["tokens"] * cost_per_token
        total_cost += call_cost

        brand = result["brand"]
        confidence = result["confidence"]

        if brand and confidence >= min_confidence:
            detection = BrandDetection(
                brand=brand,
                timestamp=candidate.frame.timestamp,
                duration=0.5,
                confidence=confidence,
                source="cloud_llm",
                bbox=candidate.bbox,
                frame_ref=candidate.frame.sequence,
                content_type=content_type,
            )
            matched.append(detection)

            emit.detection(
                job_id,
                brand=brand,
                confidence=confidence,
                source="cloud_llm",
                timestamp=candidate.frame.timestamp,
                content_type=content_type,
                frame_ref=candidate.frame.sequence,
            )

    stats.estimated_cloud_cost_usd += total_cost
    stats.regions_escalated_to_cloud_llm = len(candidates)

    emit.log(job_id, "CloudLLM", "INFO",
             f"Cloud resolved {len(matched)}/{len(candidates)} — "
             f"cost ${total_cost:.4f} ({stats.cloud_llm_calls} calls total)")

    return matched
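The cost accounting above prices every token at the provider's input rate and falls back to ESTIMATED_TOKENS_PER_CROP when the provider reports no usage. Back-of-envelope with the Sonnet 4 price from detect/providers/claude.py:

```python
# $3 per million input tokens (cost_per_input_token=0.000003 in claude.py)
ESTIMATED_TOKENS_PER_CROP = 500
cost_per_token = 0.000003

call_cost = ESTIMATED_TOKENS_PER_CROP * cost_per_token
print(f"${call_cost:.4f} per crop")  # $0.0015, so ~$1.50 per 1,000 escalated crops
```

Since output tokens cost 5x more on that model, the estimate is a floor rather than an exact figure.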
detect/stages/vlm_local.py (new file, 124 lines)
@@ -0,0 +1,124 @@
"""
Stage 6 — Local VLM escalation (moondream2)

Processes unresolved text candidates by sending crop images + prompt
to the local VLM on the inference server. Produces BrandDetection
objects for crops the VLM can identify.
"""

from __future__ import annotations

import logging

import numpy as np

from detect import emit
from detect.models import BrandDetection, TextCandidate
from detect.profiles.base import CropContext

logger = logging.getLogger(__name__)


def _crop_image(candidate: TextCandidate) -> np.ndarray:
    frame = candidate.frame
    box = candidate.bbox
    h, w = frame.image.shape[:2]
    x1 = max(0, box.x)
    y1 = max(0, box.y)
    x2 = min(w, box.x + box.w)
    y2 = min(h, box.y + box.h)
    return frame.image[y1:y2, x1:x2]


def escalate_vlm(
    candidates: list[TextCandidate],
    vlm_prompt_fn,
    inference_url: str | None = None,
    min_confidence: float = 0.5,
    content_type: str = "",
    job_id: str | None = None,
) -> tuple[list[BrandDetection], list[TextCandidate]]:
    """
    Send unresolved crops to local VLM for brand identification.

    Returns:
    - matched: BrandDetections the VLM confirmed
    - still_unresolved: candidates the VLM couldn't resolve (→ cloud escalation)
    """
    if not candidates:
        return [], []

    emit.log(job_id, "VLMLocal", "INFO",
             f"Processing {len(candidates)} unresolved crops with moondream2")

    matched: list[BrandDetection] = []
    still_unresolved: list[TextCandidate] = []

    if inference_url:
        from detect.inference import InferenceClient
        client = InferenceClient(base_url=inference_url)

    for candidate in candidates:
        crop = _crop_image(candidate)
        if crop.size == 0:
            still_unresolved.append(candidate)
            continue

        crop_context = CropContext(
            image=b"",  # not used for prompt generation
            surrounding_text=candidate.text,
            position_hint=f"frame {candidate.frame.sequence}",
        )
        prompt = vlm_prompt_fn(crop_context)

        try:
            if inference_url:
                result = client.vlm(image=crop, prompt=prompt)
                brand = result.brand
                confidence = result.confidence
                reasoning = result.reasoning
            else:
                brand, confidence, reasoning = _vlm_local(crop, prompt)
        except Exception as e:
            logger.warning("VLM failed for candidate '%s': %s", candidate.text, e)
            still_unresolved.append(candidate)
            continue

        if brand and confidence >= min_confidence:
            detection = BrandDetection(
                brand=brand,
                timestamp=candidate.frame.timestamp,
                duration=0.5,
                confidence=confidence,
                source="local_vlm",
                bbox=candidate.bbox,
                frame_ref=candidate.frame.sequence,
                content_type=content_type,
            )
            matched.append(detection)

            emit.detection(
                job_id,
                brand=brand,
                confidence=confidence,
                source="local_vlm",
                timestamp=candidate.frame.timestamp,
                content_type=content_type,
                frame_ref=candidate.frame.sequence,
            )

            logger.debug("VLM matched: %s (%.2f) — %s", brand, confidence, reasoning)
        else:
            still_unresolved.append(candidate)

    emit.log(job_id, "VLMLocal", "INFO",
             f"VLM resolved {len(matched)}, unresolved {len(still_unresolved)} → cloud")

    return matched, still_unresolved


def _vlm_local(crop: np.ndarray, prompt: str) -> tuple[str, float, str]:
    """Run moondream2 in-process (single-box mode)."""
    from gpu.models.vlm import query
    result = query(crop, prompt)
    return result["brand"], result["confidence"], result["reasoning"]
gpu/models/vlm.py (new file, 100 lines)
@@ -0,0 +1,100 @@
"""moondream2 visual language model wrapper."""

from __future__ import annotations

import logging

from models import registry
from config import get_config

logger = logging.getLogger(__name__)

_MODEL_KEY = "vlm_moondream2"


def _load():
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = get_config().get("device", "auto")
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"

    logger.info("Loading moondream2 (device=%s)...", device)

    model_id = "vikhyatk/moondream2"
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    dtype = torch.float16 if "cuda" in device else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        dtype=dtype,
        device_map=device,
    )

    wrapper = {"model": model, "tokenizer": tokenizer}
    registry.put(_MODEL_KEY, wrapper)
    logger.info("moondream2 loaded")
    return wrapper


def _get():
    wrapper = registry.get(_MODEL_KEY)
    if wrapper is None:
        wrapper = _load()
    return wrapper


def query(image, prompt: str) -> dict:
    """
    Query moondream2 with an image crop and prompt.

    Returns {"brand": str, "confidence": float, "reasoning": str}
    """
    from PIL import Image as PILImage

    wrapper = _get()
    model = wrapper["model"]
    tokenizer = wrapper["tokenizer"]

    # Convert numpy array to PIL if needed
    if not isinstance(image, PILImage.Image):
        image = PILImage.fromarray(image)

    enc_image = model.encode_image(image)
    answer = model.answer_question(enc_image, prompt, tokenizer)

    # Parse response — moondream2 returns free text, extract brand + confidence
    result = _parse_vlm_response(answer)
    return result


def _parse_vlm_response(answer: str) -> dict:
    """
    Parse moondream2 free-text response into structured output.

    Expected format from prompt: "brand, confidence (0-1), reasoning"
    Falls back gracefully if format doesn't match.
    """
    answer = answer.strip()
    parts = [p.strip() for p in answer.split(",", 2)]

    brand = parts[0] if parts else ""
    confidence = 0.5
    reasoning = answer

    if len(parts) >= 2:
        try:
            confidence = float(parts[1])
            confidence = max(0.0, min(1.0, confidence))
        except ValueError:
            pass

    if len(parts) >= 3:
        reasoning = parts[2]

    return {
        "brand": brand,
        "confidence": confidence,
        "reasoning": reasoning,
    }
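The "brand, confidence, reasoning" contract in practice (outputs follow directly from the parser above):

```python
_parse_vlm_response("Nike, 0.9, white swoosh on dark background")
# -> {"brand": "Nike", "confidence": 0.9, "reasoning": "white swoosh on dark background"}

_parse_vlm_response("I cannot identify a brand")
# -> no comma, so the whole sentence lands in "brand" and confidence stays 0.5.
#    Note 0.5 meets vlm_local's default min_confidence, so callers may want to
#    treat refusal-looking answers specially.
```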
@@ -19,3 +19,9 @@ ultralytics>=8.0.0
 # Install with:
 # uv pip install paddlepaddle-gpu==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
 paddleocr>=3.0.0
+
+# VLM (moondream2) — uses torch (already installed above)
+# Pinned <5: transformers 5.x broke moondream2's custom model code
+# (all_tied_weights_keys API change). Also needs accelerate for device_map.
+transformers>=4.40.0,<5
+accelerate>=0.27.0
@@ -25,6 +25,7 @@ from config import get_config, get_device, update_config
 from models import registry
 from models.yolo import detect as yolo_detect
 from models.ocr import ocr as ocr_run
+from models.vlm import query as vlm_query

 logger = logging.getLogger(__name__)

@@ -72,6 +73,18 @@ class OCRResponse(BaseModel):
     results: list[OCRTextResult]


+class VLMRequest(BaseModel):
+    image: str
+    prompt: str
+    model: str | None = None
+
+
+class VLMResponse(BaseModel):
+    brand: str
+    confidence: float
+    reasoning: str
+
+
 class ConfigUpdate(BaseModel):
     device: str | None = None
     yolo_model: str | None = None

@@ -170,6 +183,21 @@ def ocr(req: OCRRequest):
     return OCRResponse(results=[OCRTextResult(**r) for r in results])


+@app.post("/vlm", response_model=VLMResponse)
+def vlm(req: VLMRequest):
+    try:
+        image = _decode_image(req.image)
+    except Exception as e:
+        raise HTTPException(status_code=400, detail=f"Bad image: {e}")
+
+    try:
+        result = vlm_query(image, req.prompt)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"VLM failed: {e}")
+
+    return VLMResponse(**result)
+
+
 if __name__ == "__main__":
     import uvicorn
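A minimal client call against the new endpoint (the host/port is the INFERENCE_URL from ctrl/.env; "crop.jpg" is a placeholder for any JPEG crop):

```python
import base64

import requests

with open("crop.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

resp = requests.post(
    "http://mcrndeb:8000/vlm",
    json={"image": image_b64,
          "prompt": "What brand is shown? Respond with: brand, confidence (0-1), reasoning."},
    timeout=60,
)
resp.raise_for_status()
print(resp.json())  # {"brand": ..., "confidence": ..., "reasoning": ...}
```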
@@ -29,6 +29,9 @@ strawberry-graphql[fastapi]>=0.311.0
 # Observability
 langfuse>=2.0.0

+# Cloud LLM providers (only needed for cloud escalation stage)
+anthropic>=0.40.0
+
 # Testing
 pytest>=7.4.0
 pytest-django>=4.7.0
tests/detect/manual/test_cloud_provider.py (new file, 107 lines)
@@ -0,0 +1,107 @@
#!/usr/bin/env python3
"""
Test cloud LLM provider with a real API call.

Sends a test image to the configured cloud provider and verifies
the response. Set your provider env vars before running.

Usage:
    # Groq (default)
    CLOUD_LLM_PROVIDER=groq GROQ_API_KEY=gsk_... python tests/detect/manual/test_cloud_provider.py

    # Gemini
    CLOUD_LLM_PROVIDER=gemini GEMINI_API_KEY=AIza... python tests/detect/manual/test_cloud_provider.py

    # Claude
    CLOUD_LLM_PROVIDER=claude ANTHROPIC_API_KEY=sk-ant-... python tests/detect/manual/test_cloud_provider.py

    # OpenAI-compatible
    CLOUD_LLM_PROVIDER=openai OPENAI_API_KEY=sk-... python tests/detect/manual/test_cloud_provider.py
"""

import base64
import io
import logging
import os
import sys
from pathlib import Path

# Load .env from ctrl/ (same as docker-compose uses)
env_file = Path(__file__).resolve().parents[3] / "ctrl" / ".env"
if env_file.exists():
    for line in env_file.read_text().splitlines():
        line = line.strip()
        if line and not line.startswith("#") and "=" in line:
            key, _, value = line.partition("=")
            os.environ.setdefault(key.strip(), value.strip())

import numpy as np
from PIL import Image, ImageDraw, ImageFont

sys.path.insert(0, ".")

logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s")
logger = logging.getLogger(__name__)


def make_brand_image(text: str, width: int = 300, height: int = 100) -> str:
    img = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 42)
    except (OSError, IOError):
        font = ImageFont.load_default()
    draw.text((10, 20), text, fill="black", font=font)
    buf = io.BytesIO()
    img.save(buf, "JPEG")
    return base64.b64encode(buf.getvalue()).decode()


def main():
    from detect.providers import get_provider, has_api_key, PROVIDERS

    provider_name = os.environ.get("CLOUD_LLM_PROVIDER", "groq")
    logger.info("Provider: %s", provider_name)
    logger.info("Available providers: %s", list(PROVIDERS.keys()))

    if not has_api_key():
        logger.error("No API key set for provider '%s'", provider_name)
        logger.error("Set the appropriate env var (see usage in docstring)")
        sys.exit(1)

    provider = get_provider()
    logger.info("Model: %s", provider.model)
    logger.info("Available models: %s", list(provider.models.keys()))
    input("\nPress Enter to start...")

    prompt = (
        "Identify the brand or sponsor visible in this image from a soccer broadcast. "
        "Respond with: brand, confidence (0-1), reasoning."
    )

    test_cases = ["NIKE", "EMIRATES", "Coca-Cola", "adidas"]

    for text in test_cases:
        logger.info("--- Testing: '%s' ---", text)
        image_b64 = make_brand_image(text)

        try:
            result = provider.call(image_b64, prompt)
            logger.info("  answer: %s", result.answer)
            logger.info("  tokens: %d", result.total_tokens)

            if text.lower() in result.answer.lower():
                logger.info("  PASS — found '%s' in response", text)
            else:
                logger.warning("  MISS — '%s' not in response (may be correct, check answer)", text)

        except Exception as e:
            logger.error("  FAIL — %s: %s", type(e).__name__, e)
            if hasattr(e, 'response') and e.response is not None:
                logger.error("  Response: %s", e.response.text[:500])

    logger.info("All provider tests complete.")


if __name__ == "__main__":
    main()
tests/detect/manual/test_escalation_e2e.py (new file, 230 lines)
@@ -0,0 +1,230 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Push a full pipeline simulation with escalation events.
|
||||||
|
|
||||||
|
Exercises all stages including VLM and cloud escalation, with progressive
|
||||||
|
stats showing cost accumulating. Tests all panels: pipeline graph, funnel,
|
||||||
|
timeline, cost stats, brand table, and log.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python tests/detect/manual/test_escalation_e2e.py [--job JOB_ID] [--port PORT] [--delay SECS]
|
||||||
|
|
||||||
|
Opens: http://mpr.local.ar/detection/?job=<JOB_ID>
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import redis
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s")
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
NODES = ["extract_frames", "filter_scenes", "detect_objects", "run_ocr",
|
||||||
|
"match_brands", "escalate_vlm", "escalate_cloud", "compile_report"]
|
||||||
|
|
||||||
|
|
||||||
|
def ts():
|
||||||
|
return datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
def push(r, key, event):
|
||||||
|
event["ts"] = event.get("ts", ts())
|
||||||
|
r.rpush(key, json.dumps(event))
|
||||||
|
return event
|
||||||
|
|
||||||
|
|
||||||
|
def push_graph(r, key, active_node, status, delay):
|
||||||
|
nodes = []
|
||||||
|
for n in NODES:
|
||||||
|
if n == active_node:
|
||||||
|
nodes.append({"id": n, "status": status})
|
||||||
|
elif NODES.index(n) < NODES.index(active_node):
|
||||||
|
nodes.append({"id": n, "status": "done"})
|
||||||
|
else:
|
||||||
|
nodes.append({"id": n, "status": "pending"})
|
||||||
|
push(r, key, {"event": "graph_update", "nodes": nodes})
|
||||||
|
time.sleep(delay)
|
||||||
|
|
||||||
|
|
||||||
|
def push_stats(r, key, **fields):
|
||||||
|
base = {
|
||||||
|
"event": "stats_update",
|
||||||
|
"frames_extracted": 0, "frames_after_scene_filter": 0,
|
||||||
|
"regions_detected": 0, "regions_resolved_by_ocr": 0,
|
||||||
|
"regions_escalated_to_local_vlm": 0, "regions_escalated_to_cloud_llm": 0,
|
||||||
|
"cloud_llm_calls": 0, "processing_time_seconds": 0, "estimated_cloud_cost_usd": 0,
|
||||||
|
}
|
||||||
|
base.update(fields)
|
||||||
|
push(r, key, base)
|
||||||
|
|
||||||
|
|
||||||
|
def push_detection(r, key, brand, conf, source, timestamp, frame_ref, delay):
|
||||||
|
push(r, key, {
|
||||||
|
"event": "detection",
|
||||||
|
"brand": brand, "confidence": conf, "source": source,
|
||||||
|
"timestamp": timestamp, "duration": 0.5,
|
||||||
|
"content_type": "soccer_broadcast", "frame_ref": frame_ref,
|
||||||
|
})
|
||||||
|
logger.info(" [%s] %s %.2f t=%.1fs", source, brand, conf, timestamp)
|
||||||
|
time.sleep(delay * 0.3)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--job", default="escalation-test")
|
||||||
|
parser.add_argument("--port", type=int, default=6382)
|
||||||
|
parser.add_argument("--delay", type=float, default=0.5)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
r = redis.Redis(port=args.port, decode_responses=True)
|
||||||
|
key = f"detect_events:{args.job}"
|
||||||
|
r.delete(key)
|
||||||
|
delay = args.delay
|
||||||
|
|
||||||
|
logger.info("Full escalation pipeline simulation → %s", key)
|
||||||
|
logger.info("Open: http://mpr.local.ar/detection/?job=%s", args.job)
|
||||||
|
input("\nPress Enter to start...")
|
||||||
|
|
||||||
|
# --- Extract frames ---
|
||||||
|
push_graph(r, key, "extract_frames", "running", delay)
|
||||||
|
push(r, key, {"event": "log", "level": "INFO", "stage": "FrameExtractor",
|
||||||
|
"msg": "Extracting frames: match_clip.mp4 (90.0s, 1920x1080, fps=2)"})
|
||||||
|
time.sleep(delay)
|
||||||
|
push_stats(r, key, frames_extracted=180, processing_time_seconds=4.5)
|
||||||
|
push_graph(r, key, "extract_frames", "done", delay)
|
||||||
|
|
||||||
|
# --- Scene filter ---
|
||||||
|
push_graph(r, key, "filter_scenes", "running", delay)
|
||||||
|
push_stats(r, key, frames_extracted=180, frames_after_scene_filter=52, processing_time_seconds=6.8)
|
||||||
|
push(r, key, {"event": "log", "level": "INFO", "stage": "SceneFilter",
|
||||||
|
"msg": "Kept 52 frames (71% reduction)"})
|
||||||
|
push_graph(r, key, "filter_scenes", "done", delay)
|
||||||
|
|
||||||
|
# --- YOLO detect ---
|
||||||
|
push_graph(r, key, "detect_objects", "running", delay)
|
||||||
|
push(r, key, {"event": "log", "level": "INFO", "stage": "YOLODetector",
|
||||||
|
"msg": "Running yolov8n on 52 frames"})
|
||||||
|
time.sleep(delay)
|
||||||
|
push_stats(r, key, frames_extracted=180, frames_after_scene_filter=52,
|
||||||
|
regions_detected=41, processing_time_seconds=14.2)
|
||||||
|
push_graph(r, key, "detect_objects", "done", delay)
|
||||||
|
|
||||||
|
# --- OCR ---
|
||||||
|
push_graph(r, key, "run_ocr", "running", delay)
|
||||||
|
push(r, key, {"event": "log", "level": "INFO", "stage": "OCRStage",
|
||||||
|
"msg": "Running OCR on 41 regions (mode=remote)"})
|
||||||
|
time.sleep(delay)
|
||||||
|
push_stats(r, key, frames_extracted=180, frames_after_scene_filter=52,
|
||||||
|
regions_detected=41, regions_resolved_by_ocr=30, processing_time_seconds=21.5)
|
||||||
|
push_graph(r, key, "run_ocr", "done", delay)
|
||||||
|
|
||||||
|
# --- Brand matching ---
|
||||||
|
push_graph(r, key, "match_brands", "running", delay)
|
||||||
|
push(r, key, {"event": "log", "level": "INFO", "stage": "BrandResolver",
|
||||||
|
"msg": "Matching 30 candidates against 12 brands (fuzzy_threshold=75)"})
|
||||||
|
time.sleep(delay)
|
||||||
|
|
||||||
|
# OCR detections
|
||||||
|
ocr_brands = [
|
||||||
|
("Nike", 0.97, 2.0, 4), ("Nike", 0.95, 5.5, 11), ("Emirates", 0.92, 8.0, 16),
|
||||||
|
("Adidas", 0.89, 12.0, 24), ("Coca-Cola", 0.85, 18.0, 36),
|
||||||
|
("Nike", 0.94, 22.0, 44), ("Emirates", 0.88, 28.0, 56),
|
||||||
|
("Adidas", 0.91, 32.0, 64), ("Nike", 0.96, 38.0, 76),
|
||||||
|
("Emirates", 0.90, 42.0, 84), ("Coca-Cola", 0.87, 48.0, 96),
|
||||||
|
("Nike", 0.93, 52.0, 104), ("Adidas", 0.90, 58.0, 116),
|
||||||
|
]
|
||||||
|
for brand, conf, ts_val, fref in ocr_brands:
|
||||||
|
push_detection(r, key, brand, conf, "ocr", ts_val, fref, delay)
|
||||||
|
|
||||||
|
push(r, key, {"event": "log", "level": "INFO", "stage": "BrandResolver",
|
||||||
|
"msg": "Exact: 10, Fuzzy: 3, Unresolved: 11 → VLM"})
|
||||||
|
push_graph(r, key, "match_brands", "done", delay)
|
||||||
|
|
||||||
|
# --- VLM escalation ---
|
||||||
|
push_graph(r, key, "escalate_vlm", "running", delay)
|
||||||
|
push(r, key, {"event": "log", "level": "INFO", "stage": "VLMLocal",
|
||||||
|
"msg": "Processing 11 unresolved crops with moondream2"})
|
||||||
|
time.sleep(delay)
|
||||||
|
|
||||||
|
vlm_brands = [
|
||||||
|
("Mastercard", 0.78, 15.0, 30), ("Santander", 0.74, 25.0, 50),
|
||||||
|
("Qatar Airways", 0.81, 35.0, 70), ("Heineken", 0.76, 45.0, 90),
|
||||||
|
("Lay's", 0.72, 55.0, 110),
|
||||||
|
]
|
||||||
|
for brand, conf, ts_val, fref in vlm_brands:
|
||||||
|
push_detection(r, key, brand, conf, "local_vlm", ts_val, fref, delay)
|
||||||
|
|
||||||
|
push_stats(r, key, frames_extracted=180, frames_after_scene_filter=52,
|
||||||
|
regions_detected=41, regions_resolved_by_ocr=30,
|
||||||
|
regions_escalated_to_local_vlm=11, processing_time_seconds=38.7,
|
||||||
|
estimated_cloud_cost_usd=0)
|
||||||
|
push(r, key, {"event": "log", "level": "INFO", "stage": "VLMLocal",
|
||||||
|
"msg": "VLM resolved 5, unresolved 6 → cloud"})
|
||||||
|
push_graph(r, key, "escalate_vlm", "done", delay)
|
||||||
|
|
||||||
|
# --- Cloud escalation ---
|
||||||
|
push_graph(r, key, "escalate_cloud", "running", delay)
|
||||||
|
push(r, key, {"event": "log", "level": "INFO", "stage": "CloudLLM",
|
||||||
|
"msg": "Escalating 6 crops to groq (llama-3.2-90b-vision)"})
|
||||||
|
time.sleep(delay)
|
||||||
|
|
||||||
|
cloud_brands = [
|
||||||
|
("Pepsi", 0.68, 10.0, 20),
|
||||||
|
("Gazprom", 0.65, 40.0, 80),
|
||||||
|
]
|
||||||
|
for brand, conf, ts_val, fref in cloud_brands:
|
||||||
|
push_detection(r, key, brand, conf, "cloud_llm", ts_val, fref, delay)
|
||||||
|
|
||||||
|
push_stats(r, key, frames_extracted=180, frames_after_scene_filter=52,
|
||||||
|
regions_detected=41, regions_resolved_by_ocr=30,
|
||||||
|
regions_escalated_to_local_vlm=11, regions_escalated_to_cloud_llm=6,
|
||||||
|
cloud_llm_calls=6, processing_time_seconds=45.2,
|
||||||
|
estimated_cloud_cost_usd=0.0) # groq free tier
|
||||||
|
|
||||||
|
push(r, key, {"event": "log", "level": "WARNING", "stage": "CloudLLM",
|
||||||
|
"msg": "4 crops unresolved after cloud — likely not brands"})
|
||||||
|
push(r, key, {"event": "log", "level": "INFO", "stage": "CloudLLM",
|
||||||
|
"msg": "Cloud resolved 2/6 — cost $0.0000 (groq free tier)"})
|
||||||
|
push_graph(r, key, "escalate_cloud", "done", delay)
|
||||||
|
|
||||||
|
# --- Compile report ---
|
||||||
|
push_graph(r, key, "compile_report", "running", delay)
|
||||||
|
|
||||||
|
total_brands = len(set(b[0] for b in ocr_brands + vlm_brands + cloud_brands))
|
||||||
|
total_dets = len(ocr_brands) + len(vlm_brands) + len(cloud_brands)
|
||||||
|
|
||||||
|
push(r, key, {"event": "log", "level": "INFO", "stage": "Aggregator",
|
||||||
|
"msg": f"Report: {total_brands} brands, {total_dets} detections (merged from {total_dets} raw)"})
|
||||||
|
|
||||||
|
push(r, key, {"event": "job_complete", "job_id": args.job, "report": {
|
||||||
|
"video_source": "match_clip.mp4",
|
||||||
|
"content_type": "soccer_broadcast",
|
||||||
|
"duration_seconds": 90.0,
|
||||||
|
"brands": {
|
||||||
|
"Nike": {"total_appearances": 5, "avg_confidence": 0.95},
|
||||||
|
"Emirates": {"total_appearances": 3, "avg_confidence": 0.90},
|
||||||
|
"Adidas": {"total_appearances": 3, "avg_confidence": 0.90},
|
||||||
|
"Coca-Cola": {"total_appearances": 2, "avg_confidence": 0.86},
|
||||||
|
"Mastercard": {"total_appearances": 1, "avg_confidence": 0.78},
|
||||||
|
"Santander": {"total_appearances": 1, "avg_confidence": 0.74},
|
||||||
|
"Qatar Airways": {"total_appearances": 1, "avg_confidence": 0.81},
|
||||||
|
"Heineken": {"total_appearances": 1, "avg_confidence": 0.76},
|
||||||
|
"Lay's": {"total_appearances": 1, "avg_confidence": 0.72},
|
||||||
|
"Pepsi": {"total_appearances": 1, "avg_confidence": 0.68},
|
||||||
|
"Gazprom": {"total_appearances": 1, "avg_confidence": 0.65},
|
||||||
|
},
|
||||||
|
}})
|
||||||
|
|
||||||
|
push_graph(r, key, "compile_report", "done", delay)
|
||||||
|
|
||||||
|
logger.info("Done. %d brands, %d detections across ocr/vlm/cloud.", total_brands, total_dets)
|
||||||
|
logger.info("Check: pipeline graph (all green), timeline (3 source colors),")
|
||||||
|
logger.info(" cost panel (escalation ratio), brand table (source column).")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
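The demo above only simulates the escalation chain; in the real pipeline, whatever one stage leaves unresolved flows into the next. A minimal sketch of that routing, with a hypothetical stage signature (each stage returns resolved detections plus the still-unresolved remainder); the names here are illustrative, not the actual detect.stages API:

def run_escalation(candidates, stages):
    """Run each stage over whatever the previous stage left unresolved (sketch)."""
    resolved = []
    pending = list(candidates)
    for stage in stages:  # e.g. [resolve_brands, escalate_vlm, escalate_cloud]
        if not pending:
            break
        hits, pending = stage(pending)
        resolved.extend(hits)
    return resolved, pending  # pending = crops that are likely not brands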
100  tests/detect/manual/test_vlm_e2e.py  Normal file
@@ -0,0 +1,100 @@
#!/usr/bin/env python3
"""
Test local VLM (moondream2) via the inference server.

Creates test images with brand text/logos, sends them to the /vlm endpoint,
verifies moondream2 can identify the brand.

Usage:
    python tests/detect/manual/test_vlm_e2e.py [--url URL]

Requires: inference server running with moondream2 loaded (gpu/server.py)
"""

import argparse
import base64
import io
import logging

import numpy as np
import requests
from PIL import Image, ImageDraw, ImageFont

logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s")
logger = logging.getLogger(__name__)


def make_brand_image(text: str, width: int = 300, height: int = 100) -> np.ndarray:
    img = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 42)
    except (OSError, IOError):
        font = ImageFont.load_default()
    draw.text((10, 20), text, fill="black", font=font)
    return np.array(img)


def image_to_b64(image: np.ndarray) -> str:
    img = Image.fromarray(image)
    buf = io.BytesIO()
    img.save(buf, "JPEG")
    return base64.b64encode(buf.getvalue()).decode()


def test_health(url: str):
    logger.info("--- Health check ---")
    resp = requests.get(f"{url}/health")
    resp.raise_for_status()
    data = resp.json()
    logger.info("Status: %s, device: %s, models: %s", data["status"], data["device"], data.get("loaded_models", []))


def test_vlm(url: str, text: str, prompt: str):
    logger.info("--- VLM: image='%s' ---", text)
    image = make_brand_image(text)
    b64 = image_to_b64(image)

    resp = requests.post(f"{url}/vlm", json={"image": b64, "prompt": prompt})
    resp.raise_for_status()
    data = resp.json()

    logger.info("  brand: %s", data["brand"])
    logger.info("  confidence: %.2f", data["confidence"])
    logger.info("  reasoning: %s", data["reasoning"])

    if text.lower() in data["brand"].lower():
        logger.info("  PASS — matched")
    else:
        logger.warning("  MISS — expected '%s' in response", text)

    return data


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", default="http://mcrndeb:8000")
    args = parser.parse_args()

    url = args.url.rstrip("/")
    logger.info("Inference server: %s", url)
    input("\nPress Enter to start...")

    test_health(url)

    prompt = (
        "Identify the brand or sponsor visible in this image from a soccer broadcast. "
        "Respond with: brand, confidence (0-1), reasoning."
    )

    test_vlm(url, "NIKE", prompt)
    test_vlm(url, "EMIRATES", prompt)
    test_vlm(url, "Coca-Cola", prompt)
    test_vlm(url, "adidas", prompt)

    logger.info("All VLM tests complete.")


if __name__ == "__main__":
    main()
79  tests/detect/test_aggregator.py  Normal file
@@ -0,0 +1,79 @@
"""Tests for the report aggregator stage."""

import pytest

from detect.models import BrandDetection, PipelineStats
from detect.stages.aggregator import compile_report, _merge_contiguous


def _make_detection(brand: str, timestamp: float, duration: float = 0.5,
                    source: str = "ocr", confidence: float = 0.9) -> BrandDetection:
    return BrandDetection(
        brand=brand, timestamp=timestamp, duration=duration,
        confidence=confidence, source=source, content_type="soccer_broadcast",
    )


def test_merge_contiguous_same_brand():
    dets = [
        _make_detection("Nike", 1.0, 0.5),
        _make_detection("Nike", 1.3, 0.5),  # within gap
        _make_detection("Nike", 5.0, 0.5),  # separate
    ]
    merged = _merge_contiguous(dets, gap_threshold=2.0)
    assert len(merged) == 2
    assert merged[0].brand == "Nike"
    assert merged[0].timestamp == 1.0
    assert merged[0].duration == pytest.approx(0.8)  # 1.0 to 1.8
    assert merged[1].timestamp == 5.0


def test_merge_different_brands():
    dets = [
        _make_detection("Nike", 1.0),
        _make_detection("Adidas", 1.5),
    ]
    merged = _merge_contiguous(dets)
    assert len(merged) == 2


def test_merge_empty():
    assert _merge_contiguous([]) == []


def test_compile_report(monkeypatch):
    events = []
    monkeypatch.setattr("detect.emit.push_detect_event",
                        lambda job_id, etype, data: events.append((etype, data)))

    dets = [
        _make_detection("Nike", 1.0, 0.5, confidence=0.95),
        _make_detection("Nike", 5.0, 1.0, confidence=0.90),
        _make_detection("Adidas", 3.0, 0.5, confidence=0.85),
        _make_detection("Heineken", 10.0, 0.5, source="cloud_llm", confidence=0.70),
    ]
    stats = PipelineStats(
        frames_extracted=120,
        regions_detected=32,
        cloud_llm_calls=1,
        estimated_cloud_cost_usd=0.003,
    )

    report = compile_report(
        detections=dets,
        stats=stats,
        video_source="test.mp4",
        content_type="soccer_broadcast",
        job_id="test-report",
    )

    assert len(report.brands) == 3
    assert report.brands["Nike"].total_appearances == 2
    assert report.brands["Adidas"].total_appearances == 1
    assert report.brands["Heineken"].total_appearances == 1
    assert report.pipeline_stats.cloud_llm_calls == 1
    assert report.video_source == "test.mp4"

    # job_complete event should have been emitted
    complete = [e for e in events if e[0] == "job_complete"]
    assert len(complete) == 1
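The tests above pin down the contract of _merge_contiguous: same-brand detections whose gap is within gap_threshold collapse into one span whose duration runs from the first start to the last end. A minimal sketch consistent with those assertions (sketch only; the shipped implementation in detect.stages.aggregator may differ, and real code would copy detections rather than mutate them in place):

def _merge_contiguous_sketch(dets, gap_threshold=2.0):
    """Collapse same-brand detections separated by <= gap_threshold seconds."""
    merged = []
    for det in sorted(dets, key=lambda d: (d.brand, d.timestamp)):
        prev = merged[-1] if merged else None
        if (prev and prev.brand == det.brand
                and det.timestamp - (prev.timestamp + prev.duration) <= gap_threshold):
            # Extend the previous span to cover this detection:
            # e.g. Nike@1.0 (0.5s) + Nike@1.3 (0.5s) -> duration 0.8 (1.0 to 1.8).
            prev.duration = (det.timestamp + det.duration) - prev.timestamp
        else:
            merged.append(det)
    return merged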
92  tests/detect/test_vlm_cloud.py  Normal file
@@ -0,0 +1,92 @@
"""Tests for cloud LLM escalation stage."""

import numpy as np

from detect.models import BoundingBox, Frame, PipelineStats, TextCandidate
from detect.stages.vlm_cloud import escalate_cloud, _parse_response


def _make_candidate(text: str = "unknown", confidence: float = 0.4) -> TextCandidate:
    frame = Frame(sequence=0, chunk_id=0, timestamp=1.0,
                  image=np.zeros((50, 100, 3), dtype=np.uint8))
    box = BoundingBox(x=0, y=0, w=100, h=50, confidence=0.5, label="text")
    return TextCandidate(frame=frame, bbox=box, text=text, ocr_confidence=confidence)


def test_parse_response_clean():
    result = _parse_response("Nike, 0.92, swoosh logo visible", 200)
    assert result["brand"] == "Nike"
    assert result["confidence"] == 0.92
    assert "swoosh" in result["reasoning"]
    assert result["tokens"] == 200


def test_parse_response_no_confidence():
    result = _parse_response("Adidas", 0)
    assert result["brand"] == "Adidas"
    assert result["confidence"] == 0.5  # default


def test_escalate_skips_without_api_key(monkeypatch):
    events = []
    monkeypatch.setattr("detect.emit.push_detect_event",
                        lambda job_id, etype, data: events.append((etype, data)))
    monkeypatch.delenv("GROQ_API_KEY", raising=False)
    monkeypatch.delenv("GEMINI_API_KEY", raising=False)
    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
    monkeypatch.setenv("CLOUD_LLM_PROVIDER", "groq")
    # Reset cached provider
    import detect.providers as prov
    monkeypatch.setattr(prov, "_cached", None)

    candidates = [_make_candidate()]
    stats = PipelineStats()
    prompt_fn = lambda ctx: "what brand?"

    matched = escalate_cloud(candidates, prompt_fn, stats, job_id="test")

    assert len(matched) == 0
    assert stats.cloud_llm_calls == 0
    log_events = [e for e in events if e[0] == "log"]
    assert any("No API key" in e[1].get("msg", "") for e in log_events)


def test_escalate_empty_candidates(monkeypatch):
    events = []
    monkeypatch.setattr("detect.emit.push_detect_event",
                        lambda job_id, etype, data: events.append((etype, data)))

    stats = PipelineStats()
    matched = escalate_cloud([], lambda ctx: "", stats, job_id="test")

    assert len(matched) == 0
    assert stats.cloud_llm_calls == 0


def test_escalate_with_mock_api(monkeypatch):
    events = []
    monkeypatch.setattr("detect.emit.push_detect_event",
                        lambda job_id, etype, data: events.append((etype, data)))
    monkeypatch.setenv("GROQ_API_KEY", "test-key")
    monkeypatch.setenv("CLOUD_LLM_PROVIDER", "groq")
    # Reset cached provider
    import detect.providers as prov
    monkeypatch.setattr(prov, "_cached", None)

    def mock_call(image_b64, prompt):
        return {"brand": "Heineken", "confidence": 0.75, "reasoning": "green logo", "tokens": 300}

    monkeypatch.setattr("detect.stages.vlm_cloud._call_cloud_api", mock_call)

    candidates = [_make_candidate("unknown logo")]
    stats = PipelineStats()
    prompt_fn = lambda ctx: "what brand?"

    matched = escalate_cloud(candidates, prompt_fn, stats, job_id="test")

    assert len(matched) == 1
    assert matched[0].brand == "Heineken"
    assert matched[0].source == "cloud_llm"
    assert stats.cloud_llm_calls == 1
    assert stats.estimated_cloud_cost_usd >= 0  # exact cost depends on provider model index
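The two _parse_response tests fix the expected model-output format: a comma-separated "brand, confidence, reasoning" string, with confidence defaulting to 0.5 when the model omits it. A sketch of a parser that satisfies both assertions (hypothetical; the shipped _parse_response may be more defensive):

def _parse_response_sketch(text, tokens):
    # Expected model output: "brand, confidence, reasoning..." — tolerate missing parts.
    parts = [p.strip() for p in text.split(",", 2)]
    brand = parts[0]
    try:
        confidence = float(parts[1]) if len(parts) > 1 else 0.5
    except ValueError:
        confidence = 0.5  # default when no numeric confidence is given
    reasoning = parts[2] if len(parts) > 2 else ""
    return {"brand": brand, "confidence": confidence,
            "reasoning": reasoning, "tokens": tokens}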