phase 9
This commit is contained in:
107
tests/detect/manual/test_cloud_provider.py
Normal file
107
tests/detect/manual/test_cloud_provider.py
Normal file
@@ -0,0 +1,107 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test cloud LLM provider with a real API call.
|
||||
|
||||
Sends a test image to the configured cloud provider and verifies
|
||||
the response. Set your provider env vars before running.
|
||||
|
||||
Usage:
|
||||
# Groq (default)
|
||||
CLOUD_LLM_PROVIDER=groq GROQ_API_KEY=gsk_... python tests/detect/manual/test_cloud_provider.py
|
||||
|
||||
# Gemini
|
||||
CLOUD_LLM_PROVIDER=gemini GEMINI_API_KEY=AIza... python tests/detect/manual/test_cloud_provider.py
|
||||
|
||||
# Claude
|
||||
CLOUD_LLM_PROVIDER=claude ANTHROPIC_API_KEY=sk-ant-... python tests/detect/manual/test_cloud_provider.py
|
||||
|
||||
# OpenAI-compatible
|
||||
CLOUD_LLM_PROVIDER=openai OPENAI_API_KEY=sk-... python tests/detect/manual/test_cloud_provider.py
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Load .env from ctrl/ (same as docker-compose uses)
|
||||
env_file = Path(__file__).resolve().parents[3] / "ctrl" / ".env"
|
||||
if env_file.exists():
|
||||
for line in env_file.read_text().splitlines():
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#") and "=" in line:
|
||||
key, _, value = line.partition("=")
|
||||
os.environ.setdefault(key.strip(), value.strip())
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
sys.path.insert(0, ".")
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def make_brand_image(text: str, width: int = 300, height: int = 100) -> str:
|
||||
img = Image.new("RGB", (width, height), "white")
|
||||
draw = ImageDraw.Draw(img)
|
||||
try:
|
||||
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 42)
|
||||
except (OSError, IOError):
|
||||
font = ImageFont.load_default()
|
||||
draw.text((10, 20), text, fill="black", font=font)
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, "JPEG")
|
||||
return base64.b64encode(buf.getvalue()).decode()
|
||||
|
||||
|
||||
def main():
|
||||
from detect.providers import get_provider, has_api_key, PROVIDERS
|
||||
|
||||
provider_name = os.environ.get("CLOUD_LLM_PROVIDER", "groq")
|
||||
logger.info("Provider: %s", provider_name)
|
||||
logger.info("Available providers: %s", list(PROVIDERS.keys()))
|
||||
|
||||
if not has_api_key():
|
||||
logger.error("No API key set for provider '%s'", provider_name)
|
||||
logger.error("Set the appropriate env var (see usage in docstring)")
|
||||
sys.exit(1)
|
||||
|
||||
provider = get_provider()
|
||||
logger.info("Model: %s", provider.model)
|
||||
logger.info("Available models: %s", list(provider.models.keys()))
|
||||
input("\nPress Enter to start...")
|
||||
|
||||
prompt = (
|
||||
"Identify the brand or sponsor visible in this image from a soccer broadcast. "
|
||||
"Respond with: brand, confidence (0-1), reasoning."
|
||||
)
|
||||
|
||||
test_cases = ["NIKE", "EMIRATES", "Coca-Cola", "adidas"]
|
||||
|
||||
for text in test_cases:
|
||||
logger.info("--- Testing: '%s' ---", text)
|
||||
image_b64 = make_brand_image(text)
|
||||
|
||||
try:
|
||||
result = provider.call(image_b64, prompt)
|
||||
logger.info(" answer: %s", result.answer)
|
||||
logger.info(" tokens: %d", result.total_tokens)
|
||||
|
||||
if text.lower() in result.answer.lower():
|
||||
logger.info(" PASS — found '%s' in response", text)
|
||||
else:
|
||||
logger.warning(" MISS — '%s' not in response (may be correct, check answer)", text)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(" FAIL — %s: %s", type(e).__name__, e)
|
||||
if hasattr(e, 'response') and e.response is not None:
|
||||
logger.error(" Response: %s", e.response.text[:500])
|
||||
|
||||
logger.info("All provider tests complete.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
230
tests/detect/manual/test_escalation_e2e.py
Normal file
230
tests/detect/manual/test_escalation_e2e.py
Normal file
@@ -0,0 +1,230 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Push a full pipeline simulation with escalation events.
|
||||
|
||||
Exercises all stages including VLM and cloud escalation, with progressive
|
||||
stats showing cost accumulating. Tests all panels: pipeline graph, funnel,
|
||||
timeline, cost stats, brand table, and log.
|
||||
|
||||
Usage:
|
||||
python tests/detect/manual/test_escalation_e2e.py [--job JOB_ID] [--port PORT] [--delay SECS]
|
||||
|
||||
Opens: http://mpr.local.ar/detection/?job=<JOB_ID>
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import redis
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NODES = ["extract_frames", "filter_scenes", "detect_objects", "run_ocr",
|
||||
"match_brands", "escalate_vlm", "escalate_cloud", "compile_report"]
|
||||
|
||||
|
||||
def ts():
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def push(r, key, event):
|
||||
event["ts"] = event.get("ts", ts())
|
||||
r.rpush(key, json.dumps(event))
|
||||
return event
|
||||
|
||||
|
||||
def push_graph(r, key, active_node, status, delay):
|
||||
nodes = []
|
||||
for n in NODES:
|
||||
if n == active_node:
|
||||
nodes.append({"id": n, "status": status})
|
||||
elif NODES.index(n) < NODES.index(active_node):
|
||||
nodes.append({"id": n, "status": "done"})
|
||||
else:
|
||||
nodes.append({"id": n, "status": "pending"})
|
||||
push(r, key, {"event": "graph_update", "nodes": nodes})
|
||||
time.sleep(delay)
|
||||
|
||||
|
||||
def push_stats(r, key, **fields):
|
||||
base = {
|
||||
"event": "stats_update",
|
||||
"frames_extracted": 0, "frames_after_scene_filter": 0,
|
||||
"regions_detected": 0, "regions_resolved_by_ocr": 0,
|
||||
"regions_escalated_to_local_vlm": 0, "regions_escalated_to_cloud_llm": 0,
|
||||
"cloud_llm_calls": 0, "processing_time_seconds": 0, "estimated_cloud_cost_usd": 0,
|
||||
}
|
||||
base.update(fields)
|
||||
push(r, key, base)
|
||||
|
||||
|
||||
def push_detection(r, key, brand, conf, source, timestamp, frame_ref, delay):
|
||||
push(r, key, {
|
||||
"event": "detection",
|
||||
"brand": brand, "confidence": conf, "source": source,
|
||||
"timestamp": timestamp, "duration": 0.5,
|
||||
"content_type": "soccer_broadcast", "frame_ref": frame_ref,
|
||||
})
|
||||
logger.info(" [%s] %s %.2f t=%.1fs", source, brand, conf, timestamp)
|
||||
time.sleep(delay * 0.3)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--job", default="escalation-test")
|
||||
parser.add_argument("--port", type=int, default=6382)
|
||||
parser.add_argument("--delay", type=float, default=0.5)
|
||||
args = parser.parse_args()
|
||||
|
||||
r = redis.Redis(port=args.port, decode_responses=True)
|
||||
key = f"detect_events:{args.job}"
|
||||
r.delete(key)
|
||||
delay = args.delay
|
||||
|
||||
logger.info("Full escalation pipeline simulation → %s", key)
|
||||
logger.info("Open: http://mpr.local.ar/detection/?job=%s", args.job)
|
||||
input("\nPress Enter to start...")
|
||||
|
||||
# --- Extract frames ---
|
||||
push_graph(r, key, "extract_frames", "running", delay)
|
||||
push(r, key, {"event": "log", "level": "INFO", "stage": "FrameExtractor",
|
||||
"msg": "Extracting frames: match_clip.mp4 (90.0s, 1920x1080, fps=2)"})
|
||||
time.sleep(delay)
|
||||
push_stats(r, key, frames_extracted=180, processing_time_seconds=4.5)
|
||||
push_graph(r, key, "extract_frames", "done", delay)
|
||||
|
||||
# --- Scene filter ---
|
||||
push_graph(r, key, "filter_scenes", "running", delay)
|
||||
push_stats(r, key, frames_extracted=180, frames_after_scene_filter=52, processing_time_seconds=6.8)
|
||||
push(r, key, {"event": "log", "level": "INFO", "stage": "SceneFilter",
|
||||
"msg": "Kept 52 frames (71% reduction)"})
|
||||
push_graph(r, key, "filter_scenes", "done", delay)
|
||||
|
||||
# --- YOLO detect ---
|
||||
push_graph(r, key, "detect_objects", "running", delay)
|
||||
push(r, key, {"event": "log", "level": "INFO", "stage": "YOLODetector",
|
||||
"msg": "Running yolov8n on 52 frames"})
|
||||
time.sleep(delay)
|
||||
push_stats(r, key, frames_extracted=180, frames_after_scene_filter=52,
|
||||
regions_detected=41, processing_time_seconds=14.2)
|
||||
push_graph(r, key, "detect_objects", "done", delay)
|
||||
|
||||
# --- OCR ---
|
||||
push_graph(r, key, "run_ocr", "running", delay)
|
||||
push(r, key, {"event": "log", "level": "INFO", "stage": "OCRStage",
|
||||
"msg": "Running OCR on 41 regions (mode=remote)"})
|
||||
time.sleep(delay)
|
||||
push_stats(r, key, frames_extracted=180, frames_after_scene_filter=52,
|
||||
regions_detected=41, regions_resolved_by_ocr=30, processing_time_seconds=21.5)
|
||||
push_graph(r, key, "run_ocr", "done", delay)
|
||||
|
||||
# --- Brand matching ---
|
||||
push_graph(r, key, "match_brands", "running", delay)
|
||||
push(r, key, {"event": "log", "level": "INFO", "stage": "BrandResolver",
|
||||
"msg": "Matching 30 candidates against 12 brands (fuzzy_threshold=75)"})
|
||||
time.sleep(delay)
|
||||
|
||||
# OCR detections
|
||||
ocr_brands = [
|
||||
("Nike", 0.97, 2.0, 4), ("Nike", 0.95, 5.5, 11), ("Emirates", 0.92, 8.0, 16),
|
||||
("Adidas", 0.89, 12.0, 24), ("Coca-Cola", 0.85, 18.0, 36),
|
||||
("Nike", 0.94, 22.0, 44), ("Emirates", 0.88, 28.0, 56),
|
||||
("Adidas", 0.91, 32.0, 64), ("Nike", 0.96, 38.0, 76),
|
||||
("Emirates", 0.90, 42.0, 84), ("Coca-Cola", 0.87, 48.0, 96),
|
||||
("Nike", 0.93, 52.0, 104), ("Adidas", 0.90, 58.0, 116),
|
||||
]
|
||||
for brand, conf, ts_val, fref in ocr_brands:
|
||||
push_detection(r, key, brand, conf, "ocr", ts_val, fref, delay)
|
||||
|
||||
push(r, key, {"event": "log", "level": "INFO", "stage": "BrandResolver",
|
||||
"msg": "Exact: 10, Fuzzy: 3, Unresolved: 11 → VLM"})
|
||||
push_graph(r, key, "match_brands", "done", delay)
|
||||
|
||||
# --- VLM escalation ---
|
||||
push_graph(r, key, "escalate_vlm", "running", delay)
|
||||
push(r, key, {"event": "log", "level": "INFO", "stage": "VLMLocal",
|
||||
"msg": "Processing 11 unresolved crops with moondream2"})
|
||||
time.sleep(delay)
|
||||
|
||||
vlm_brands = [
|
||||
("Mastercard", 0.78, 15.0, 30), ("Santander", 0.74, 25.0, 50),
|
||||
("Qatar Airways", 0.81, 35.0, 70), ("Heineken", 0.76, 45.0, 90),
|
||||
("Lay's", 0.72, 55.0, 110),
|
||||
]
|
||||
for brand, conf, ts_val, fref in vlm_brands:
|
||||
push_detection(r, key, brand, conf, "local_vlm", ts_val, fref, delay)
|
||||
|
||||
push_stats(r, key, frames_extracted=180, frames_after_scene_filter=52,
|
||||
regions_detected=41, regions_resolved_by_ocr=30,
|
||||
regions_escalated_to_local_vlm=11, processing_time_seconds=38.7,
|
||||
estimated_cloud_cost_usd=0)
|
||||
push(r, key, {"event": "log", "level": "INFO", "stage": "VLMLocal",
|
||||
"msg": "VLM resolved 5, unresolved 6 → cloud"})
|
||||
push_graph(r, key, "escalate_vlm", "done", delay)
|
||||
|
||||
# --- Cloud escalation ---
|
||||
push_graph(r, key, "escalate_cloud", "running", delay)
|
||||
push(r, key, {"event": "log", "level": "INFO", "stage": "CloudLLM",
|
||||
"msg": "Escalating 6 crops to groq (llama-3.2-90b-vision)"})
|
||||
time.sleep(delay)
|
||||
|
||||
cloud_brands = [
|
||||
("Pepsi", 0.68, 10.0, 20),
|
||||
("Gazprom", 0.65, 40.0, 80),
|
||||
]
|
||||
for brand, conf, ts_val, fref in cloud_brands:
|
||||
push_detection(r, key, brand, conf, "cloud_llm", ts_val, fref, delay)
|
||||
|
||||
push_stats(r, key, frames_extracted=180, frames_after_scene_filter=52,
|
||||
regions_detected=41, regions_resolved_by_ocr=30,
|
||||
regions_escalated_to_local_vlm=11, regions_escalated_to_cloud_llm=6,
|
||||
cloud_llm_calls=6, processing_time_seconds=45.2,
|
||||
estimated_cloud_cost_usd=0.0) # groq free tier
|
||||
|
||||
push(r, key, {"event": "log", "level": "WARNING", "stage": "CloudLLM",
|
||||
"msg": "4 crops unresolved after cloud — likely not brands"})
|
||||
push(r, key, {"event": "log", "level": "INFO", "stage": "CloudLLM",
|
||||
"msg": "Cloud resolved 2/6 — cost $0.0000 (groq free tier)"})
|
||||
push_graph(r, key, "escalate_cloud", "done", delay)
|
||||
|
||||
# --- Compile report ---
|
||||
push_graph(r, key, "compile_report", "running", delay)
|
||||
|
||||
total_brands = len(set(b[0] for b in ocr_brands + vlm_brands + cloud_brands))
|
||||
total_dets = len(ocr_brands) + len(vlm_brands) + len(cloud_brands)
|
||||
|
||||
push(r, key, {"event": "log", "level": "INFO", "stage": "Aggregator",
|
||||
"msg": f"Report: {total_brands} brands, {total_dets} detections (merged from {total_dets} raw)"})
|
||||
|
||||
push(r, key, {"event": "job_complete", "job_id": args.job, "report": {
|
||||
"video_source": "match_clip.mp4",
|
||||
"content_type": "soccer_broadcast",
|
||||
"duration_seconds": 90.0,
|
||||
"brands": {
|
||||
"Nike": {"total_appearances": 5, "avg_confidence": 0.95},
|
||||
"Emirates": {"total_appearances": 3, "avg_confidence": 0.90},
|
||||
"Adidas": {"total_appearances": 3, "avg_confidence": 0.90},
|
||||
"Coca-Cola": {"total_appearances": 2, "avg_confidence": 0.86},
|
||||
"Mastercard": {"total_appearances": 1, "avg_confidence": 0.78},
|
||||
"Santander": {"total_appearances": 1, "avg_confidence": 0.74},
|
||||
"Qatar Airways": {"total_appearances": 1, "avg_confidence": 0.81},
|
||||
"Heineken": {"total_appearances": 1, "avg_confidence": 0.76},
|
||||
"Lay's": {"total_appearances": 1, "avg_confidence": 0.72},
|
||||
"Pepsi": {"total_appearances": 1, "avg_confidence": 0.68},
|
||||
"Gazprom": {"total_appearances": 1, "avg_confidence": 0.65},
|
||||
},
|
||||
}})
|
||||
|
||||
push_graph(r, key, "compile_report", "done", delay)
|
||||
|
||||
logger.info("Done. %d brands, %d detections across ocr/vlm/cloud.", total_brands, total_dets)
|
||||
logger.info("Check: pipeline graph (all green), timeline (3 source colors),")
|
||||
logger.info(" cost panel (escalation ratio), brand table (source column).")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
100
tests/detect/manual/test_vlm_e2e.py
Normal file
100
tests/detect/manual/test_vlm_e2e.py
Normal file
@@ -0,0 +1,100 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test local VLM (moondream2) via the inference server.
|
||||
|
||||
Creates test images with brand text/logos, sends them to the /vlm endpoint,
|
||||
verifies moondream2 can identify the brand.
|
||||
|
||||
Usage:
|
||||
python tests/detect/manual/test_vlm_e2e.py [--url URL]
|
||||
|
||||
Requires: inference server running with moondream2 loaded (gpu/server.py)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def make_brand_image(text: str, width: int = 300, height: int = 100) -> np.ndarray:
|
||||
img = Image.new("RGB", (width, height), "white")
|
||||
draw = ImageDraw.Draw(img)
|
||||
try:
|
||||
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 42)
|
||||
except (OSError, IOError):
|
||||
font = ImageFont.load_default()
|
||||
draw.text((10, 20), text, fill="black", font=font)
|
||||
return np.array(img)
|
||||
|
||||
|
||||
def image_to_b64(image: np.ndarray) -> str:
|
||||
img = Image.fromarray(image)
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, "JPEG")
|
||||
return base64.b64encode(buf.getvalue()).decode()
|
||||
|
||||
|
||||
def test_health(url: str):
|
||||
logger.info("--- Health check ---")
|
||||
resp = requests.get(f"{url}/health")
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
logger.info("Status: %s, device: %s, models: %s", data["status"], data["device"], data.get("loaded_models", []))
|
||||
|
||||
|
||||
def test_vlm(url: str, text: str, prompt: str):
|
||||
logger.info("--- VLM: image='%s' ---", text)
|
||||
image = make_brand_image(text)
|
||||
b64 = image_to_b64(image)
|
||||
|
||||
resp = requests.post(f"{url}/vlm", json={"image": b64, "prompt": prompt})
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
logger.info(" brand: %s", data["brand"])
|
||||
logger.info(" confidence: %.2f", data["confidence"])
|
||||
logger.info(" reasoning: %s", data["reasoning"])
|
||||
|
||||
if text.lower() in data["brand"].lower():
|
||||
logger.info(" PASS — matched")
|
||||
else:
|
||||
logger.warning(" MISS — expected '%s' in response", text)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--url", default="http://mcrndeb:8000")
|
||||
args = parser.parse_args()
|
||||
|
||||
url = args.url.rstrip("/")
|
||||
logger.info("Inference server: %s", url)
|
||||
input("\nPress Enter to start...")
|
||||
|
||||
test_health(url)
|
||||
|
||||
prompt = (
|
||||
"Identify the brand or sponsor visible in this image from a soccer broadcast. "
|
||||
"Respond with: brand, confidence (0-1), reasoning."
|
||||
)
|
||||
|
||||
test_vlm(url, "NIKE", prompt)
|
||||
test_vlm(url, "EMIRATES", prompt)
|
||||
test_vlm(url, "Coca-Cola", prompt)
|
||||
test_vlm(url, "adidas", prompt)
|
||||
|
||||
logger.info("All VLM tests complete.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
79
tests/detect/test_aggregator.py
Normal file
79
tests/detect/test_aggregator.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Tests for the report aggregator stage."""
|
||||
|
||||
import pytest
|
||||
|
||||
from detect.models import BoundingBox, BrandDetection, PipelineStats
|
||||
from detect.stages.aggregator import compile_report, _merge_contiguous
|
||||
|
||||
|
||||
def _make_detection(brand: str, timestamp: float, duration: float = 0.5,
|
||||
source: str = "ocr", confidence: float = 0.9) -> BrandDetection:
|
||||
return BrandDetection(
|
||||
brand=brand, timestamp=timestamp, duration=duration,
|
||||
confidence=confidence, source=source, content_type="soccer_broadcast",
|
||||
)
|
||||
|
||||
|
||||
def test_merge_contiguous_same_brand():
|
||||
dets = [
|
||||
_make_detection("Nike", 1.0, 0.5),
|
||||
_make_detection("Nike", 1.3, 0.5), # within gap
|
||||
_make_detection("Nike", 5.0, 0.5), # separate
|
||||
]
|
||||
merged = _merge_contiguous(dets, gap_threshold=2.0)
|
||||
assert len(merged) == 2
|
||||
assert merged[0].brand == "Nike"
|
||||
assert merged[0].timestamp == 1.0
|
||||
assert merged[0].duration == pytest.approx(0.8) # 1.0 to 1.8
|
||||
assert merged[1].timestamp == 5.0
|
||||
|
||||
|
||||
def test_merge_different_brands():
|
||||
dets = [
|
||||
_make_detection("Nike", 1.0),
|
||||
_make_detection("Adidas", 1.5),
|
||||
]
|
||||
merged = _merge_contiguous(dets)
|
||||
assert len(merged) == 2
|
||||
|
||||
|
||||
def test_merge_empty():
|
||||
assert _merge_contiguous([]) == []
|
||||
|
||||
|
||||
def test_compile_report(monkeypatch):
|
||||
events = []
|
||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||
lambda job_id, etype, data: events.append((etype, data)))
|
||||
|
||||
dets = [
|
||||
_make_detection("Nike", 1.0, 0.5, confidence=0.95),
|
||||
_make_detection("Nike", 5.0, 1.0, confidence=0.90),
|
||||
_make_detection("Adidas", 3.0, 0.5, confidence=0.85),
|
||||
_make_detection("Heineken", 10.0, 0.5, source="cloud_llm", confidence=0.70),
|
||||
]
|
||||
stats = PipelineStats(
|
||||
frames_extracted=120,
|
||||
regions_detected=32,
|
||||
cloud_llm_calls=1,
|
||||
estimated_cloud_cost_usd=0.003,
|
||||
)
|
||||
|
||||
report = compile_report(
|
||||
detections=dets,
|
||||
stats=stats,
|
||||
video_source="test.mp4",
|
||||
content_type="soccer_broadcast",
|
||||
job_id="test-report",
|
||||
)
|
||||
|
||||
assert len(report.brands) == 3
|
||||
assert report.brands["Nike"].total_appearances == 2
|
||||
assert report.brands["Adidas"].total_appearances == 1
|
||||
assert report.brands["Heineken"].total_appearances == 1
|
||||
assert report.pipeline_stats.cloud_llm_calls == 1
|
||||
assert report.video_source == "test.mp4"
|
||||
|
||||
# job_complete event should have been emitted
|
||||
complete = [e for e in events if e[0] == "job_complete"]
|
||||
assert len(complete) == 1
|
||||
92
tests/detect/test_vlm_cloud.py
Normal file
92
tests/detect/test_vlm_cloud.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Tests for cloud LLM escalation stage."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from detect.models import BoundingBox, Frame, PipelineStats, TextCandidate
|
||||
from detect.stages.vlm_cloud import escalate_cloud, _parse_response
|
||||
|
||||
|
||||
def _make_candidate(text: str = "unknown", confidence: float = 0.4) -> TextCandidate:
|
||||
frame = Frame(sequence=0, chunk_id=0, timestamp=1.0,
|
||||
image=np.zeros((50, 100, 3), dtype=np.uint8))
|
||||
box = BoundingBox(x=0, y=0, w=100, h=50, confidence=0.5, label="text")
|
||||
return TextCandidate(frame=frame, bbox=box, text=text, ocr_confidence=confidence)
|
||||
|
||||
|
||||
def test_parse_response_clean():
|
||||
result = _parse_response("Nike, 0.92, swoosh logo visible", 200)
|
||||
assert result["brand"] == "Nike"
|
||||
assert result["confidence"] == 0.92
|
||||
assert "swoosh" in result["reasoning"]
|
||||
assert result["tokens"] == 200
|
||||
|
||||
|
||||
def test_parse_response_no_confidence():
|
||||
result = _parse_response("Adidas", 0)
|
||||
assert result["brand"] == "Adidas"
|
||||
assert result["confidence"] == 0.5 # default
|
||||
|
||||
|
||||
def test_escalate_skips_without_api_key(monkeypatch):
|
||||
events = []
|
||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||
lambda job_id, etype, data: events.append((etype, data)))
|
||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||
monkeypatch.delenv("GEMINI_API_KEY", raising=False)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
monkeypatch.setenv("CLOUD_LLM_PROVIDER", "groq")
|
||||
# Reset cached provider
|
||||
import detect.providers as prov
|
||||
monkeypatch.setattr(prov, "_cached", None)
|
||||
|
||||
candidates = [_make_candidate()]
|
||||
stats = PipelineStats()
|
||||
prompt_fn = lambda ctx: "what brand?"
|
||||
|
||||
matched = escalate_cloud(candidates, prompt_fn, stats, job_id="test")
|
||||
|
||||
assert len(matched) == 0
|
||||
assert stats.cloud_llm_calls == 0
|
||||
log_events = [e for e in events if e[0] == "log"]
|
||||
assert any("No API key" in e[1].get("msg", "") for e in log_events)
|
||||
|
||||
|
||||
def test_escalate_empty_candidates(monkeypatch):
|
||||
events = []
|
||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||
lambda job_id, etype, data: events.append((etype, data)))
|
||||
|
||||
stats = PipelineStats()
|
||||
matched = escalate_cloud([], lambda ctx: "", stats, job_id="test")
|
||||
|
||||
assert len(matched) == 0
|
||||
assert stats.cloud_llm_calls == 0
|
||||
|
||||
|
||||
def test_escalate_with_mock_api(monkeypatch):
|
||||
events = []
|
||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||
lambda job_id, etype, data: events.append((etype, data)))
|
||||
monkeypatch.setenv("GROQ_API_KEY", "test-key")
|
||||
monkeypatch.setenv("CLOUD_LLM_PROVIDER", "groq")
|
||||
# Reset cached provider
|
||||
import detect.providers as prov
|
||||
monkeypatch.setattr(prov, "_cached", None)
|
||||
|
||||
def mock_call(image_b64, prompt):
|
||||
return {"brand": "Heineken", "confidence": 0.75, "reasoning": "green logo", "tokens": 300}
|
||||
|
||||
monkeypatch.setattr("detect.stages.vlm_cloud._call_cloud_api", mock_call)
|
||||
|
||||
candidates = [_make_candidate("unknown logo")]
|
||||
stats = PipelineStats()
|
||||
prompt_fn = lambda ctx: "what brand?"
|
||||
|
||||
matched = escalate_cloud(candidates, prompt_fn, stats, job_id="test")
|
||||
|
||||
assert len(matched) == 1
|
||||
assert matched[0].brand == "Heineken"
|
||||
assert matched[0].source == "cloud_llm"
|
||||
assert stats.cloud_llm_calls == 1
|
||||
assert stats.estimated_cloud_cost_usd >= 0 # exact cost depends on provider model index
|
||||
Reference in New Issue
Block a user