phase 9
This commit is contained in:
92
tests/detect/test_vlm_cloud.py
Normal file
92
tests/detect/test_vlm_cloud.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Tests for cloud LLM escalation stage."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from detect.models import BoundingBox, Frame, PipelineStats, TextCandidate
|
||||
from detect.stages.vlm_cloud import escalate_cloud, _parse_response
|
||||
|
||||
|
||||
def _make_candidate(text: str = "unknown", confidence: float = 0.4) -> TextCandidate:
    """Build a minimal TextCandidate over a blank 100x50 RGB frame."""
    blank_image = np.zeros((50, 100, 3), dtype=np.uint8)
    frame = Frame(sequence=0, chunk_id=0, timestamp=1.0, image=blank_image)
    region = BoundingBox(x=0, y=0, w=100, h=50, confidence=0.5, label="text")
    return TextCandidate(frame=frame, bbox=region, text=text, ocr_confidence=confidence)
|
||||
|
||||
|
||||
def test_parse_response_clean():
    """A well-formed 'brand, confidence, reasoning' reply is parsed field by field."""
    parsed = _parse_response("Nike, 0.92, swoosh logo visible", 200)
    assert parsed["brand"] == "Nike"
    assert parsed["confidence"] == 0.92
    assert "swoosh" in parsed["reasoning"]
    assert parsed["tokens"] == 200
|
||||
|
||||
|
||||
def test_parse_response_no_confidence():
    """A reply with only a brand name falls back to the default confidence."""
    parsed = _parse_response("Adidas", 0)
    assert parsed["brand"] == "Adidas"
    assert parsed["confidence"] == 0.5  # default
|
||||
|
||||
|
||||
def test_escalate_skips_without_api_key(monkeypatch):
    """With every provider key unset, escalation is skipped and a log event says so."""
    captured = []
    monkeypatch.setattr(
        "detect.emit.push_detect_event",
        lambda job_id, etype, data: captured.append((etype, data)),
    )
    for key in ("GROQ_API_KEY", "GEMINI_API_KEY", "OPENAI_API_KEY"):
        monkeypatch.delenv(key, raising=False)
    monkeypatch.setenv("CLOUD_LLM_PROVIDER", "groq")

    # Reset cached provider
    import detect.providers as prov
    monkeypatch.setattr(prov, "_cached", None)

    stats = PipelineStats()
    matched = escalate_cloud(
        [_make_candidate()], lambda ctx: "what brand?", stats, job_id="test"
    )

    assert len(matched) == 0
    assert stats.cloud_llm_calls == 0
    log_events = [evt for evt in captured if evt[0] == "log"]
    assert any("No API key" in evt[1].get("msg", "") for evt in log_events)
|
||||
|
||||
|
||||
def test_escalate_empty_candidates(monkeypatch):
    """An empty candidate list yields no matches and makes no cloud calls."""
    seen = []
    monkeypatch.setattr(
        "detect.emit.push_detect_event",
        lambda job_id, etype, data: seen.append((etype, data)),
    )

    stats = PipelineStats()
    result = escalate_cloud([], lambda ctx: "", stats, job_id="test")

    assert len(result) == 0
    assert stats.cloud_llm_calls == 0
|
||||
|
||||
|
||||
def test_escalate_with_mock_api(monkeypatch):
    """With a key present and the cloud API stubbed, one candidate produces one match."""
    captured = []
    monkeypatch.setattr(
        "detect.emit.push_detect_event",
        lambda job_id, etype, data: captured.append((etype, data)),
    )
    monkeypatch.setenv("GROQ_API_KEY", "test-key")
    monkeypatch.setenv("CLOUD_LLM_PROVIDER", "groq")

    # Reset cached provider
    import detect.providers as prov
    monkeypatch.setattr(prov, "_cached", None)

    def fake_call(image_b64, prompt):
        # Stubbed provider response — exercised in place of the real HTTP call.
        return {
            "brand": "Heineken",
            "confidence": 0.75,
            "reasoning": "green logo",
            "tokens": 300,
        }

    monkeypatch.setattr("detect.stages.vlm_cloud._call_cloud_api", fake_call)

    stats = PipelineStats()
    matched = escalate_cloud(
        [_make_candidate("unknown logo")], lambda ctx: "what brand?", stats, job_id="test"
    )

    assert len(matched) == 1
    assert matched[0].brand == "Heineken"
    assert matched[0].source == "cloud_llm"
    assert stats.cloud_llm_calls == 1
    assert stats.estimated_cloud_cost_usd >= 0  # exact cost depends on provider model index
|
||||
Reference in New Issue
Block a user