93 lines
3.4 KiB
Python
93 lines
3.4 KiB
Python
"""Tests for cloud LLM escalation stage."""
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from core.detect.models import BoundingBox, Frame, PipelineStats, TextCandidate
|
|
from core.detect.stages.vlm_cloud import escalate_cloud, _parse_response
|
|
|
|
|
|
def _make_candidate(text: str = "unknown", confidence: float = 0.4) -> TextCandidate:
    """Build a minimal TextCandidate backed by a blank 100x50 frame.

    The frame image is an all-zero uint8 array; *text* and *confidence*
    feed straight into the candidate's OCR fields.
    """
    blank = np.zeros((50, 100, 3), dtype=np.uint8)
    frm = Frame(
        sequence=0,
        chunk_id=0,
        timestamp=1.0,
        image=blank,
    )
    region = BoundingBox(x=0, y=0, w=100, h=50, confidence=0.5, label="text")
    return TextCandidate(frame=frm, bbox=region, text=text, ocr_confidence=confidence)
|
|
|
|
|
|
def test_parse_response_clean():
    """A well-formed 'brand, confidence, reasoning' line parses field-by-field."""
    parsed = _parse_response("Nike, 0.92, swoosh logo visible", 200)

    assert parsed["brand"] == "Nike"
    assert parsed["confidence"] == 0.92
    assert "swoosh" in parsed["reasoning"]
    assert parsed["tokens"] == 200
|
|
|
|
|
|
def test_parse_response_no_confidence():
    """A bare brand name with no score falls back to the 0.5 default confidence."""
    parsed = _parse_response("Adidas", 0)

    assert parsed["brand"] == "Adidas"
    assert parsed["confidence"] == 0.5  # default
|
|
|
|
|
|
def test_escalate_skips_without_api_key(monkeypatch):
    """With every provider key unset, escalation makes no calls and logs the reason."""
    captured = []
    monkeypatch.setattr(
        "core.detect.emit.push_detect_event",
        lambda job_id, etype, data: captured.append((etype, data)),
    )
    # Clear all known provider keys so no fallback can succeed.
    for var in ("GROQ_API_KEY", "GEMINI_API_KEY", "OPENAI_API_KEY"):
        monkeypatch.delenv(var, raising=False)
    monkeypatch.setenv("CLOUD_LLM_PROVIDER", "groq")

    # Reset cached provider so the env changes above actually take effect.
    import core.detect.providers as prov
    monkeypatch.setattr(prov, "_cached", None)

    stats = PipelineStats()

    def prompt_fn(ctx):
        return "what brand?"

    matched = escalate_cloud([_make_candidate()], prompt_fn, stats, job_id="test")

    assert len(matched) == 0
    assert stats.cloud_llm_calls == 0
    log_payloads = [data for etype, data in captured if etype == "log"]
    assert any("No API key" in data.get("msg", "") for data in log_payloads)
|
|
|
|
|
|
def test_escalate_empty_candidates(monkeypatch):
    """An empty candidate list short-circuits: nothing matched, no API calls."""
    captured = []
    monkeypatch.setattr(
        "core.detect.emit.push_detect_event",
        lambda job_id, etype, data: captured.append((etype, data)),
    )

    stats = PipelineStats()

    def prompt_fn(ctx):
        return ""

    matched = escalate_cloud([], prompt_fn, stats, job_id="test")

    assert len(matched) == 0
    assert stats.cloud_llm_calls == 0
|
|
|
|
|
|
def test_escalate_with_mock_api(monkeypatch):
    """With a key set and the API call mocked, one candidate yields one cloud match."""
    captured = []
    monkeypatch.setattr(
        "core.detect.emit.push_detect_event",
        lambda job_id, etype, data: captured.append((etype, data)),
    )
    monkeypatch.setenv("GROQ_API_KEY", "test-key")
    monkeypatch.setenv("CLOUD_LLM_PROVIDER", "groq")

    # Reset cached provider so the env changes above actually take effect.
    import core.detect.providers as prov
    monkeypatch.setattr(prov, "_cached", None)

    # Stub the HTTP layer: a fixed parsed-response dict, no network traffic.
    def fake_api(image_b64, prompt):
        return {
            "brand": "Heineken",
            "confidence": 0.75,
            "reasoning": "green logo",
            "tokens": 300,
        }

    monkeypatch.setattr("core.detect.stages.vlm_cloud._call_cloud_api", fake_api)

    stats = PipelineStats()

    def prompt_fn(ctx):
        return "what brand?"

    matched = escalate_cloud(
        [_make_candidate("unknown logo")], prompt_fn, stats, job_id="test"
    )

    assert len(matched) == 1
    assert matched[0].brand == "Heineken"
    assert matched[0].source == "cloud_llm"
    assert stats.cloud_llm_calls == 1
    assert stats.estimated_cloud_cost_usd >= 0  # exact cost depends on provider model index
|