Files
mediaproc/tests/detect/test_vlm_cloud.py
2026-03-26 02:54:56 -03:00

93 lines
3.3 KiB
Python

"""Tests for cloud LLM escalation stage."""
import numpy as np
import pytest
from detect.models import BoundingBox, Frame, PipelineStats, TextCandidate
from detect.stages.vlm_cloud import escalate_cloud, _parse_response
def _make_candidate(text: str = "unknown", confidence: float = 0.4) -> TextCandidate:
    """Build a TextCandidate around a blank 100x50 frame for use in tests."""
    blank = np.zeros((50, 100, 3), dtype=np.uint8)
    frame = Frame(sequence=0, chunk_id=0, timestamp=1.0, image=blank)
    region = BoundingBox(x=0, y=0, w=100, h=50, confidence=0.5, label="text")
    return TextCandidate(frame=frame, bbox=region, text=text, ocr_confidence=confidence)
def test_parse_response_clean():
    """A well-formed 'brand, confidence, reasoning' reply parses into all four fields."""
    parsed = _parse_response("Nike, 0.92, swoosh logo visible", 200)
    assert parsed["brand"] == "Nike"
    assert parsed["confidence"] == 0.92
    assert "swoosh" in parsed["reasoning"]
    assert parsed["tokens"] == 200
def test_parse_response_no_confidence():
    """A bare brand name with no score falls back to the default confidence."""
    parsed = _parse_response("Adidas", 0)
    assert parsed["brand"] == "Adidas"
    assert parsed["confidence"] == 0.5  # default
def test_escalate_skips_without_api_key(monkeypatch):
    """Without any provider API key, escalation is a no-op that logs a warning."""
    captured = []
    monkeypatch.setattr(
        "detect.emit.push_detect_event",
        lambda job_id, etype, data: captured.append((etype, data)),
    )
    # Strip every provider credential so selection cannot succeed.
    for var in ("GROQ_API_KEY", "GEMINI_API_KEY", "OPENAI_API_KEY"):
        monkeypatch.delenv(var, raising=False)
    monkeypatch.setenv("CLOUD_LLM_PROVIDER", "groq")
    # Reset cached provider
    import detect.providers as prov
    monkeypatch.setattr(prov, "_cached", None)

    stats = PipelineStats()
    matched = escalate_cloud(
        [_make_candidate()], lambda ctx: "what brand?", stats, job_id="test"
    )

    assert len(matched) == 0
    assert stats.cloud_llm_calls == 0
    logs = [event for event in captured if event[0] == "log"]
    assert any("No API key" in event[1].get("msg", "") for event in logs)
def test_escalate_empty_candidates(monkeypatch):
    """An empty candidate list short-circuits before any cloud call is made."""
    captured = []
    monkeypatch.setattr(
        "detect.emit.push_detect_event",
        lambda job_id, etype, data: captured.append((etype, data)),
    )
    stats = PipelineStats()
    matched = escalate_cloud([], lambda ctx: "", stats, job_id="test")
    assert len(matched) == 0
    assert stats.cloud_llm_calls == 0
def test_escalate_with_mock_api(monkeypatch):
    """With a key present and the API call mocked, a candidate comes back branded."""
    captured = []
    monkeypatch.setattr(
        "detect.emit.push_detect_event",
        lambda job_id, etype, data: captured.append((etype, data)),
    )
    monkeypatch.setenv("GROQ_API_KEY", "test-key")
    monkeypatch.setenv("CLOUD_LLM_PROVIDER", "groq")
    # Reset cached provider
    import detect.providers as prov
    monkeypatch.setattr(prov, "_cached", None)

    # Replace the network call with a canned successful response.
    monkeypatch.setattr(
        "detect.stages.vlm_cloud._call_cloud_api",
        lambda image_b64, prompt: {
            "brand": "Heineken",
            "confidence": 0.75,
            "reasoning": "green logo",
            "tokens": 300,
        },
    )

    stats = PipelineStats()
    matched = escalate_cloud(
        [_make_candidate("unknown logo")], lambda ctx: "what brand?", stats, job_id="test"
    )

    assert len(matched) == 1
    assert matched[0].brand == "Heineken"
    assert matched[0].source == "cloud_llm"
    assert stats.cloud_llm_calls == 1
    assert stats.estimated_cloud_cost_usd >= 0  # exact cost depends on provider model index