phase 9
@@ -20,6 +20,9 @@ from detect.stages.scene_filter import scene_filter
from detect.stages.yolo_detector import detect_objects
from detect.stages.ocr_stage import run_ocr
from detect.stages.brand_resolver import resolve_brands
from detect.stages.vlm_local import escalate_vlm
from detect.stages.vlm_cloud import escalate_cloud
from detect.stages.aggregator import compile_report
from detect.tracing import trace_node, flush as flush_traces

INFERENCE_URL = os.environ.get("INFERENCE_URL")  # None = local mode
@@ -158,43 +161,77 @@ def node_escalate_vlm(state: DetectState) -> dict:
    _emit_transition(state, "escalate_vlm", "running")

    with trace_node(state, "escalate_vlm") as span:
        profile = _get_profile(state)
        candidates = state.get("unresolved_candidates", [])
        job_id = state.get("job_id")
        emit.log(job_id, "VLMLocal", "INFO", "Stub: VLM escalation not yet implemented")
        span.set_output({"stub": True})

        vlm_matched, still_unresolved = escalate_vlm(
            candidates,
            vlm_prompt_fn=profile.vlm_prompt,
            inference_url=INFERENCE_URL,
            content_type=profile.name,
            job_id=job_id,
        )

        stats = state.get("stats", PipelineStats())
        stats.regions_escalated_to_local_vlm = len(candidates)
        span.set_output({"candidates": len(candidates), "matched": len(vlm_matched),
                         "still_unresolved": len(still_unresolved)})

        existing = state.get("detections", [])

    _emit_transition(state, "escalate_vlm", "done")
    return {}
    return {
        "detections": existing + vlm_matched,
        "unresolved_candidates": still_unresolved,
        "stats": stats,
    }


def node_escalate_cloud(state: DetectState) -> dict:
    _emit_transition(state, "escalate_cloud", "running")

    with trace_node(state, "escalate_cloud") as span:
        profile = _get_profile(state)
        candidates = state.get("unresolved_candidates", [])
        job_id = state.get("job_id")
        emit.log(job_id, "CloudLLM", "INFO", "Stub: cloud LLM escalation not yet implemented")
        span.set_output({"stub": True})
        stats = state.get("stats", PipelineStats())

        cloud_matched = escalate_cloud(
            candidates,
            vlm_prompt_fn=profile.vlm_prompt,
            stats=stats,
            content_type=profile.name,
            job_id=job_id,
        )

        span.set_output({"candidates": len(candidates), "matched": len(cloud_matched),
                         "cloud_calls": stats.cloud_llm_calls,
                         "cost_usd": stats.estimated_cloud_cost_usd})

        existing = state.get("detections", [])

    _emit_transition(state, "escalate_cloud", "done")
    return {}
    return {"detections": existing + cloud_matched, "stats": stats}


def node_compile_report(state: DetectState) -> dict:
    _emit_transition(state, "compile_report", "running")

    with trace_node(state, "compile_report") as span:
        job_id = state.get("job_id")
        profile = _get_profile(state)
        detections = state.get("detections", [])
        report = profile.aggregate(detections)
        report.video_source = state.get("video_path", "")
        stats = state.get("stats", PipelineStats())
        job_id = state.get("job_id")

        report = compile_report(
            detections=detections,
            stats=stats,
            video_source=state.get("video_path", ""),
            content_type=profile.name,
            job_id=job_id,
        )

        emit.log(job_id, "Aggregator", "INFO",
                 f"Report: {len(report.brands)} brands, {len(report.timeline)} detections")
        emit.job_complete(job_id, {
            "video_source": report.video_source,
            "content_type": report.content_type,
            "brands": {k: {"total_appearances": v.total_appearances} for k, v in report.brands.items()},
        })
        span.set_output({"brands": len(report.brands), "detections": len(report.timeline)})

    flush_traces()

detect/providers/__init__.py (new file)
@@ -0,0 +1,58 @@
"""
Cloud LLM provider registry.

Select provider via CLOUD_LLM_PROVIDER env var.
Each provider reads its own env vars for auth/config.

CLOUD_LLM_PROVIDER=groq   → GROQ_API_KEY, GROQ_MODEL, GROQ_BASE_URL
CLOUD_LLM_PROVIDER=gemini → GEMINI_API_KEY, GEMINI_MODEL
CLOUD_LLM_PROVIDER=openai → OPENAI_API_KEY, OPENAI_MODEL, OPENAI_BASE_URL
CLOUD_LLM_PROVIDER=claude → ANTHROPIC_API_KEY, CLAUDE_MODEL
"""

from __future__ import annotations

import os

from .base import CloudProvider, ProviderResponse
from .groq import GroqProvider
from .gemini import GeminiProvider
from .openai_compat import OpenAICompatProvider
from .claude import ClaudeProvider

PROVIDERS: dict[str, type] = {
    "groq": GroqProvider,
    "gemini": GeminiProvider,
    "openai": OpenAICompatProvider,
    "claude": ClaudeProvider,
}

_cached: CloudProvider | None = None


def get_provider() -> CloudProvider:
    """Get the configured cloud provider (cached after first call)."""
    global _cached
    if _cached is not None:
        return _cached

    name = os.environ.get("CLOUD_LLM_PROVIDER", "groq")
    cls = PROVIDERS.get(name)
    if cls is None:
        raise ValueError(f"Unknown provider: {name!r}. Options: {list(PROVIDERS)}")

    _cached = cls()
    return _cached


def has_api_key() -> bool:
    """Check if the configured provider has an API key set."""
    name = os.environ.get("CLOUD_LLM_PROVIDER", "groq")
    key_map = {
        "groq": "GROQ_API_KEY",
        "gemini": "GEMINI_API_KEY",
        "openai": "OPENAI_API_KEY",
        "claude": "ANTHROPIC_API_KEY",
    }
    env_var = key_map.get(name, "")
    return bool(os.environ.get(env_var, ""))
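
A minimal usage sketch of the registry (illustrative, not part of this commit; the provider choice and placeholder key are assumptions):

    import os

    # Set env vars before importing detect.providers; the provider modules
    # read their config at import time.
    os.environ["CLOUD_LLM_PROVIDER"] = "gemini"
    os.environ["GEMINI_API_KEY"] = "..."

    from detect.providers import get_provider, has_api_key

    if has_api_key():
        provider = get_provider()             # constructed once, then cached
        print(provider.name, provider.model)  # e.g. "gemini", "gemini-2.0-flash"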

detect/providers/base.py (new file)
@@ -0,0 +1,36 @@
"""Cloud LLM provider protocol and model metadata."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Protocol


@dataclass
class ModelInfo:
    """Metadata for a cloud LLM model."""
    id: str
    vision: bool = True
    cost_per_input_token: float = 0.0
    cost_per_output_token: float = 0.0
    max_output_tokens: int = 4096
    notes: str = ""


@dataclass
class ProviderResponse:
    answer: str
    total_tokens: int = 0


class CloudProvider(Protocol):
    """
    Interface for cloud LLM providers.

    Each provider handles its own auth, payload format, and response parsing.
    The pipeline only calls call() and reads the response.
    """
    name: str
    models: dict[str, ModelInfo]

    def call(self, image_b64: str, prompt: str) -> ProviderResponse: ...
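
A hedged sketch of how this metadata feeds cost tracking; the model id and token count below are invented, and the single-rate approximation mirrors the one used later in detect/stages/vlm_cloud.py:

    from detect.providers.base import ModelInfo, ProviderResponse

    info = ModelInfo(id="example-model",
                     cost_per_input_token=3e-06,
                     cost_per_output_token=1.5e-05)
    resp = ProviderResponse(answer="Nike, 0.9, swoosh on the jersey", total_tokens=480)

    # vlm_cloud.py charges every token at the input rate as a rough estimate:
    approx_cost = resp.total_tokens * info.cost_per_input_token  # 480 * 3e-06 = $0.00144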

detect/providers/claude.py (new file)
@@ -0,0 +1,73 @@
"""Anthropic Claude provider — uses the official SDK."""

from __future__ import annotations

import logging
import os

from .base import ModelInfo, ProviderResponse

logger = logging.getLogger(__name__)

# Claude-specific env vars
# ANTHROPIC_API_KEY is read by the SDK automatically
CLAUDE_MODEL = os.environ.get("CLAUDE_MODEL", "claude-sonnet-4-20250514")

MODELS = {
    "claude-sonnet-4-20250514": ModelInfo(
        id="claude-sonnet-4-20250514",
        vision=True,
        cost_per_input_token=0.000003,
        cost_per_output_token=0.000015,
        notes="Best balance of quality/cost with vision",
    ),
    "claude-haiku-4-5-20251001": ModelInfo(
        id="claude-haiku-4-5-20251001",
        vision=True,
        cost_per_input_token=0.0000008,
        cost_per_output_token=0.000004,
        notes="Fastest, cheapest, good for simple brand ID",
    ),
    "claude-opus-4-6": ModelInfo(
        id="claude-opus-4-6",
        vision=True,
        cost_per_input_token=0.000015,
        cost_per_output_token=0.000075,
        notes="Highest quality, use for ambiguous cases",
    ),
}


class ClaudeProvider:
    name = "claude"
    models = MODELS

    def __init__(self):
        from anthropic import Anthropic
        self.client = Anthropic()
        self.model = CLAUDE_MODEL

    def call(self, image_b64: str, prompt: str) -> ProviderResponse:
        message = self.client.messages.create(
            model=self.model,
            max_tokens=150,
            messages=[{
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": image_b64,
                        },
                    },
                    {"type": "text", "text": prompt},
                ],
            }],
        )

        answer = message.content[0].text.strip()
        total_tokens = message.usage.input_tokens + message.usage.output_tokens

        return ProviderResponse(answer=answer, total_tokens=total_tokens)

detect/providers/gemini.py (new file)
@@ -0,0 +1,75 @@
"""Google Gemini provider — native REST API, not OpenAI-compatible."""

from __future__ import annotations

import logging
import os

import requests

from .base import ModelInfo, ProviderResponse

logger = logging.getLogger(__name__)

# Gemini-specific env vars
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-2.0-flash")

MODELS = {
    "gemini-2.0-flash": ModelInfo(
        id="gemini-2.0-flash",
        vision=True,
        cost_per_input_token=0.0000001,
        cost_per_output_token=0.0000004,
        notes="Fast, cheap, good vision",
    ),
    "gemini-2.0-pro": ModelInfo(
        id="gemini-2.0-pro",
        vision=True,
        cost_per_input_token=0.00000125,
        cost_per_output_token=0.000005,
        notes="Higher quality, slower",
    ),
    "gemini-1.5-flash": ModelInfo(
        id="gemini-1.5-flash",
        vision=True,
        cost_per_input_token=0.000000075,
        cost_per_output_token=0.0000003,
        notes="Cheapest option",
    ),
}


class GeminiProvider:
    name = "gemini"
    models = MODELS

    def __init__(self):
        self.api_key = GEMINI_API_KEY
        self.model = GEMINI_MODEL
        self.endpoint = (
            f"https://generativelanguage.googleapis.com/v1beta/models/"
            f"{self.model}:generateContent"
        )

    def call(self, image_b64: str, prompt: str) -> ProviderResponse:
        payload = {
            "contents": [{
                "parts": [
                    {"text": prompt},
                    {"inline_data": {"mime_type": "image/jpeg", "data": image_b64}},
                ],
            }],
            "generationConfig": {"maxOutputTokens": 150},
        }

        url = f"{self.endpoint}?key={self.api_key}"
        resp = requests.post(url, json=payload, timeout=30)
        resp.raise_for_status()
        data = resp.json()

        answer = data["candidates"][0]["content"]["parts"][0]["text"].strip()
        usage = data.get("usageMetadata", {})
        total_tokens = usage.get("totalTokenCount", 0)

        return ProviderResponse(answer=answer, total_tokens=total_tokens)

detect/providers/groq.py (new file)
@@ -0,0 +1,66 @@
"""Groq cloud provider — OpenAI-compatible API with vision."""

from __future__ import annotations

import logging
import os

import requests

from .base import ModelInfo, ProviderResponse

logger = logging.getLogger(__name__)

# Groq-specific env vars
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
GROQ_BASE_URL = os.environ.get("GROQ_BASE_URL", "https://api.groq.com/openai/v1")
GROQ_MODEL = os.environ.get("GROQ_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct")

MODELS = {
    "meta-llama/llama-4-scout-17b-16e-instruct": ModelInfo(
        id="meta-llama/llama-4-scout-17b-16e-instruct",
        vision=True,
        cost_per_input_token=0.0,
        cost_per_output_token=0.0,
        notes="Llama 4 Scout, only vision model on Groq free tier",
    ),
}


class GroqProvider:
    name = "groq"
    models = MODELS

    def __init__(self):
        self.api_key = GROQ_API_KEY
        self.base_url = GROQ_BASE_URL
        self.model = GROQ_MODEL
        self.endpoint = f"{self.base_url.rstrip('/')}/chat/completions"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def call(self, image_b64: str, prompt: str) -> ProviderResponse:
        payload = {
            "model": self.model,
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {
                        "url": f"data:image/jpeg;base64,{image_b64}",
                    }},
                ],
            }],
            "max_tokens": 150,
        }

        resp = requests.post(self.endpoint, headers=self.headers, json=payload, timeout=30)
        resp.raise_for_status()
        data = resp.json()

        answer = data["choices"][0]["message"]["content"].strip()
        total_tokens = data.get("usage", {}).get("total_tokens", 0)

        return ProviderResponse(answer=answer, total_tokens=total_tokens)

detect/providers/openai_compat.py (new file)
@@ -0,0 +1,73 @@
"""Generic OpenAI-compatible provider (OpenAI, Together, etc.)."""

from __future__ import annotations

import logging
import os

import requests

from .base import ModelInfo, ProviderResponse

logger = logging.getLogger(__name__)

# OpenAI-compat specific env vars
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4o-mini")

MODELS = {
    "gpt-4o-mini": ModelInfo(
        id="gpt-4o-mini",
        vision=True,
        cost_per_input_token=0.00000015,
        cost_per_output_token=0.0000006,
        notes="Cheap, fast, decent vision",
    ),
    "gpt-4o": ModelInfo(
        id="gpt-4o",
        vision=True,
        cost_per_input_token=0.0000025,
        cost_per_output_token=0.00001,
        notes="Best OpenAI vision model",
    ),
}


class OpenAICompatProvider:
    name = "openai"
    models = MODELS

    def __init__(self):
        self.api_key = OPENAI_API_KEY
        self.base_url = OPENAI_BASE_URL
        self.model = OPENAI_MODEL
        self.endpoint = f"{self.base_url.rstrip('/')}/chat/completions"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def call(self, image_b64: str, prompt: str) -> ProviderResponse:
        payload = {
            "model": self.model,
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {
                        "url": f"data:image/jpeg;base64,{image_b64}",
                    }},
                ],
            }],
            "max_tokens": 150,
        }

        resp = requests.post(self.endpoint, headers=self.headers, json=payload, timeout=30)
        resp.raise_for_status()
        data = resp.json()

        answer = data["choices"][0]["message"]["content"].strip()
        total_tokens = data.get("usage", {}).get("total_tokens", 0)

        return ProviderResponse(answer=answer, total_tokens=total_tokens)
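
Because this provider only assumes an OpenAI-compatible /chat/completions endpoint, OPENAI_BASE_URL can point at other hosts. A hedged configuration sketch (the host and model names are illustrative, not project defaults); note the provider modules read their env vars at import time, so set them before importing detect.providers:

    import os

    os.environ["CLOUD_LLM_PROVIDER"] = "openai"
    os.environ["OPENAI_BASE_URL"] = "https://api.together.xyz/v1"  # any OpenAI-compatible host
    os.environ["OPENAI_MODEL"] = "Qwen/Qwen2-VL-72B-Instruct"      # must support vision input
    os.environ["OPENAI_API_KEY"] = "..."

    from detect.providers import get_provider
    print(get_provider().endpoint)  # https://api.together.xyz/v1/chat/completions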

detect/stages/aggregator.py (new file)
@@ -0,0 +1,116 @@
"""
Stage 8 — Report compilation

Groups all detections by brand, merges contiguous appearances,
and builds the final DetectionReport.
"""

from __future__ import annotations

import logging

from detect import emit
from detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats

logger = logging.getLogger(__name__)


def _merge_contiguous(detections: list[BrandDetection], gap_threshold: float = 2.0) -> list[BrandDetection]:
    """
    Merge detections of the same brand that are close in time.

    If two detections of the same brand are within gap_threshold seconds,
    they're merged into one detection spanning the full range.
    """
    if not detections:
        return []

    sorted_dets = sorted(detections, key=lambda d: (d.brand, d.timestamp))
    merged: list[BrandDetection] = []
    current = sorted_dets[0]

    for det in sorted_dets[1:]:
        if (det.brand == current.brand
                and det.timestamp <= current.timestamp + current.duration + gap_threshold):
            end = max(current.timestamp + current.duration,
                      det.timestamp + det.duration)
            current = BrandDetection(
                brand=current.brand,
                timestamp=current.timestamp,
                duration=end - current.timestamp,
                confidence=max(current.confidence, det.confidence),
                source=current.source,
                bbox=current.bbox,
                frame_ref=current.frame_ref,
                content_type=current.content_type,
            )
        else:
            merged.append(current)
            current = det

    merged.append(current)
    return merged


def compile_report(
    detections: list[BrandDetection],
    stats: PipelineStats,
    video_source: str = "",
    content_type: str = "",
    duration_seconds: float = 0.0,
    job_id: str | None = None,
) -> DetectionReport:
    """
    Build the final detection report from all accumulated detections.

    Merges contiguous detections, computes per-brand stats,
    and emits the job_complete event.
    """
    merged = _merge_contiguous(detections)

    brands: dict[str, BrandStats] = {}
    for d in merged:
        if d.brand not in brands:
            brands[d.brand] = BrandStats()
        s = brands[d.brand]
        s.total_appearances += 1
        s.total_screen_time += d.duration
        s.avg_confidence = (
            (s.avg_confidence * (s.total_appearances - 1) + d.confidence)
            / s.total_appearances
        )
        if s.first_seen == 0.0 or d.timestamp < s.first_seen:
            s.first_seen = d.timestamp
        if d.timestamp > s.last_seen:
            s.last_seen = d.timestamp

    report = DetectionReport(
        video_source=video_source,
        content_type=content_type,
        duration_seconds=duration_seconds,
        brands=brands,
        timeline=sorted(merged, key=lambda d: d.timestamp),
        pipeline_stats=stats,
    )

    emit.log(job_id, "Aggregator", "INFO",
             f"Report: {len(brands)} brands, {len(merged)} detections "
             f"(merged from {len(detections)} raw)")

    emit.job_complete(job_id, {
        "video_source": report.video_source,
        "content_type": report.content_type,
        "duration_seconds": report.duration_seconds,
        "brands": {
            k: {
                "total_appearances": v.total_appearances,
                "total_screen_time": v.total_screen_time,
                "avg_confidence": round(v.avg_confidence, 3),
                "first_seen": v.first_seen,
                "last_seen": v.last_seen,
            }
            for k, v in brands.items()
        },
    })

    return report
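
A worked example of the merge rule in _merge_contiguous with the default gap_threshold of 2.0 s; the timestamps and brand name are illustrative:

    # Detection A: brand "acme" at t=10.0 s for 0.5 s.
    # Detection B: same brand at t=11.5 s for 0.5 s.
    prev_start, prev_dur = 10.0, 0.5
    nxt_start, nxt_dur = 11.5, 0.5
    gap_threshold = 2.0

    # B starts before prev end (10.5 s) + 2.0 s, so the two are merged:
    assert nxt_start <= prev_start + prev_dur + gap_threshold

    end = max(prev_start + prev_dur, nxt_start + nxt_dur)  # 12.0
    merged_duration = end - prev_start                      # 2.0 s, spanning 10.0–12.0 s
    # A third "acme" hit at t=20.0 s would start a new appearance, so
    # compile_report would report total_appearances == 2 for the brand.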

detect/stages/vlm_cloud.py (new file)
@@ -0,0 +1,168 @@
"""
Stage 7 — Cloud LLM escalation

Last resort for crops the local VLM couldn't resolve.
Provider-agnostic — switch via CLOUD_LLM_PROVIDER env var.
Each provider has its own file under detect/providers/.

Tracks token usage and cost.
"""

from __future__ import annotations

import base64
import io
import logging

import numpy as np
from PIL import Image

from detect import emit
from detect.models import BrandDetection, PipelineStats, TextCandidate
from detect.profiles.base import CropContext
from detect.providers import get_provider, has_api_key

logger = logging.getLogger(__name__)

ESTIMATED_TOKENS_PER_CROP = 500


def _encode_crop(crop: np.ndarray) -> str:
    img = Image.fromarray(crop)
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=85)
    return base64.b64encode(buf.getvalue()).decode()


def _crop_image(candidate: TextCandidate) -> np.ndarray:
    frame = candidate.frame
    box = candidate.bbox
    h, w = frame.image.shape[:2]
    x1 = max(0, box.x)
    y1 = max(0, box.y)
    x2 = min(w, box.x + box.w)
    y2 = min(h, box.y + box.h)
    return frame.image[y1:y2, x1:x2]


def _parse_response(answer: str, total_tokens: int) -> dict:
    """Parse LLM free-text response into structured output."""
    parts = [p.strip() for p in answer.split(",", 2)]

    brand = parts[0] if parts else ""
    confidence = 0.5
    reasoning = answer

    if len(parts) >= 2:
        try:
            confidence = float(parts[1])
            confidence = max(0.0, min(1.0, confidence))
        except ValueError:
            pass

    if len(parts) >= 3:
        reasoning = parts[2]

    return {
        "brand": brand,
        "confidence": confidence,
        "reasoning": reasoning,
        "tokens": total_tokens or ESTIMATED_TOKENS_PER_CROP,
    }


def _call_cloud_api(image_b64: str, prompt: str) -> dict:
    """Route to the configured provider and parse the response."""
    provider = get_provider()
    result = provider.call(image_b64, prompt)
    return _parse_response(result.answer, result.total_tokens)


def escalate_cloud(
    candidates: list[TextCandidate],
    vlm_prompt_fn,
    stats: PipelineStats,
    min_confidence: float = 0.4,
    content_type: str = "",
    job_id: str | None = None,
) -> list[BrandDetection]:
    """
    Send remaining unresolved crops to cloud LLM.

    Provider is selected via CLOUD_LLM_PROVIDER env var (groq, gemini, openai, claude).
    Updates stats with call count and cost.
    """
    if not candidates:
        return []

    if not has_api_key():
        emit.log(job_id, "CloudLLM", "WARNING",
                 f"No API key set for cloud provider, skipping {len(candidates)} crops")
        return []

    provider = get_provider()
    emit.log(job_id, "CloudLLM", "INFO",
             f"Escalating {len(candidates)} crops to {provider.name}")

    matched: list[BrandDetection] = []
    total_cost = 0.0

    for candidate in candidates:
        crop = _crop_image(candidate)
        if crop.size == 0:
            continue

        crop_context = CropContext(
            image=b"",
            surrounding_text=candidate.text,
            position_hint=f"frame {candidate.frame.sequence}",
        )
        prompt = vlm_prompt_fn(crop_context)
        image_b64 = _encode_crop(crop)

        try:
            result = _call_cloud_api(image_b64, prompt)
        except Exception as e:
            logger.warning("Cloud LLM failed for '%s': %s", candidate.text, e)
            continue

        stats.cloud_llm_calls += 1
        model_info = provider.models.get(provider.model)
        cost_per_token = model_info.cost_per_input_token if model_info else 0.00001
        call_cost = result["tokens"] * cost_per_token
        total_cost += call_cost

        brand = result["brand"]
        confidence = result["confidence"]

        if brand and confidence >= min_confidence:
            detection = BrandDetection(
                brand=brand,
                timestamp=candidate.frame.timestamp,
                duration=0.5,
                confidence=confidence,
                source="cloud_llm",
                bbox=candidate.bbox,
                frame_ref=candidate.frame.sequence,
                content_type=content_type,
            )
            matched.append(detection)

            emit.detection(
                job_id,
                brand=brand,
                confidence=confidence,
                source="cloud_llm",
                timestamp=candidate.frame.timestamp,
                content_type=content_type,
                frame_ref=candidate.frame.sequence,
            )

    stats.estimated_cloud_cost_usd += total_cost
    stats.regions_escalated_to_cloud_llm = len(candidates)

    emit.log(job_id, "CloudLLM", "INFO",
             f"Cloud resolved {len(matched)}/{len(candidates)} — "
             f"cost ${total_cost:.4f} ({stats.cloud_llm_calls} calls total)")

    return matched
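
A hedged illustration of the "brand, confidence, reasoning" reply convention that _parse_response expects; the answer strings here are invented:

    from detect.stages.vlm_cloud import _parse_response

    _parse_response("Nike, 0.92, swoosh visible on the jersey", total_tokens=430)
    # -> {"brand": "Nike", "confidence": 0.92,
    #     "reasoning": "swoosh visible on the jersey", "tokens": 430}

    _parse_response("unknown", total_tokens=0)
    # -> confidence falls back to 0.5, reasoning is the whole answer, and
    #    tokens falls back to ESTIMATED_TOKENS_PER_CROP (500).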

detect/stages/vlm_local.py (new file)
@@ -0,0 +1,124 @@
"""
Stage 6 — Local VLM escalation (moondream2)

Processes unresolved text candidates by sending crop images + prompt
to the local VLM on the inference server. Produces BrandDetection
objects for crops the VLM can identify.
"""

from __future__ import annotations

import logging

import numpy as np

from detect import emit
from detect.models import BrandDetection, TextCandidate
from detect.profiles.base import CropContext

logger = logging.getLogger(__name__)


def _crop_image(candidate: TextCandidate) -> np.ndarray:
    frame = candidate.frame
    box = candidate.bbox
    h, w = frame.image.shape[:2]
    x1 = max(0, box.x)
    y1 = max(0, box.y)
    x2 = min(w, box.x + box.w)
    y2 = min(h, box.y + box.h)
    return frame.image[y1:y2, x1:x2]


def escalate_vlm(
    candidates: list[TextCandidate],
    vlm_prompt_fn,
    inference_url: str | None = None,
    min_confidence: float = 0.5,
    content_type: str = "",
    job_id: str | None = None,
) -> tuple[list[BrandDetection], list[TextCandidate]]:
    """
    Send unresolved crops to local VLM for brand identification.

    Returns:
        - matched: BrandDetections the VLM confirmed
        - still_unresolved: candidates the VLM couldn't resolve (→ cloud escalation)
    """
    if not candidates:
        return [], []

    emit.log(job_id, "VLMLocal", "INFO",
             f"Processing {len(candidates)} unresolved crops with moondream2")

    matched: list[BrandDetection] = []
    still_unresolved: list[TextCandidate] = []

    if inference_url:
        from detect.inference import InferenceClient
        client = InferenceClient(base_url=inference_url)

    for candidate in candidates:
        crop = _crop_image(candidate)
        if crop.size == 0:
            still_unresolved.append(candidate)
            continue

        crop_context = CropContext(
            image=b"",  # not used for prompt generation
            surrounding_text=candidate.text,
            position_hint=f"frame {candidate.frame.sequence}",
        )
        prompt = vlm_prompt_fn(crop_context)

        try:
            if inference_url:
                result = client.vlm(image=crop, prompt=prompt)
                brand = result.brand
                confidence = result.confidence
                reasoning = result.reasoning
            else:
                brand, confidence, reasoning = _vlm_local(crop, prompt)
        except Exception as e:
            logger.warning("VLM failed for candidate '%s': %s", candidate.text, e)
            still_unresolved.append(candidate)
            continue

        if brand and confidence >= min_confidence:
            detection = BrandDetection(
                brand=brand,
                timestamp=candidate.frame.timestamp,
                duration=0.5,
                confidence=confidence,
                source="local_vlm",
                bbox=candidate.bbox,
                frame_ref=candidate.frame.sequence,
                content_type=content_type,
            )
            matched.append(detection)

            emit.detection(
                job_id,
                brand=brand,
                confidence=confidence,
                source="local_vlm",
                timestamp=candidate.frame.timestamp,
                content_type=content_type,
                frame_ref=candidate.frame.sequence,
            )

            logger.debug("VLM matched: %s (%.2f) — %s", brand, confidence, reasoning)
        else:
            still_unresolved.append(candidate)

    emit.log(job_id, "VLMLocal", "INFO",
             f"VLM resolved {len(matched)}, unresolved {len(still_unresolved)} → cloud")

    return matched, still_unresolved


def _vlm_local(crop: np.ndarray, prompt: str) -> tuple[str, float, str]:
    """Run moondream2 in-process (single-box mode)."""
    from gpu.models.vlm import query
    result = query(crop, prompt)
    return result["brand"], result["confidence"], result["reasoning"]
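
A worked example of the bounding-box clamping done by the _crop_image helpers above; the frame size and box values are illustrative:

    h, w = 720, 1280                         # frame.image.shape[:2]
    x, y, bw, bh = 1200, 100, 150, 60        # bbox partly past the right edge
    x1, y1 = max(0, x), max(0, y)            # 1200, 100
    x2, y2 = min(w, x + bw), min(h, y + bh)  # 1280, 160
    # frame.image[y1:y2, x1:x2] -> a (60, 80) crop; a box fully outside the
    # frame would give crop.size == 0, and that candidate is skipped or left
    # unresolved for the next escalation tier.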