"""Tests for cloud LLM escalation stage."""

import numpy as np
import pytest

from detect.models import BoundingBox, Frame, PipelineStats, TextCandidate
from detect.stages.vlm_cloud import escalate_cloud, _parse_response


def _make_candidate(text: str = "unknown", confidence: float = 0.4) -> TextCandidate:
    """Build a TextCandidate backed by a blank 100x50 frame.

    Args:
        text: OCR text stored on the candidate.
        confidence: Value stored as the candidate's ``ocr_confidence``.

    Returns:
        A candidate whose bounding box spans the whole (all-zero) image.
    """
    blank_image = np.zeros((50, 100, 3), dtype=np.uint8)
    frame = Frame(sequence=0, chunk_id=0, timestamp=1.0, image=blank_image)
    box = BoundingBox(x=0, y=0, w=100, h=50, confidence=0.5, label="text")
    return TextCandidate(frame=frame, bbox=box, text=text, ocr_confidence=confidence)


def _capture_events(monkeypatch) -> list:
    """Replace ``detect.emit.push_detect_event`` with a recorder.

    Returns:
        The list that receives one ``(etype, data)`` tuple per emitted event.
    """
    events = []

    # Parameter names mirror the stub used throughout these tests so the
    # replacement accepts the same positional or keyword call.
    def _record(job_id, etype, data):
        events.append((etype, data))

    monkeypatch.setattr("detect.emit.push_detect_event", _record)
    return events


def _reset_provider_cache(monkeypatch) -> None:
    """Clear the provider cached on ``detect.providers`` for this test."""
    # Imported lazily, as the tests originally did, so the module is only
    # loaded by the tests that need to reset it.
    import detect.providers as prov

    monkeypatch.setattr(prov, "_cached", None)


def _brand_prompt(ctx) -> str:
    """Return a fixed prompt regardless of the context passed in."""
    return "what brand?"


def test_parse_response_clean():
    """A 'brand, confidence, reasoning' reply is split into its fields."""
    result = _parse_response("Nike, 0.92, swoosh logo visible", 200)

    assert result["brand"] == "Nike"
    # approx guards against float representation noise in the parsed value.
    assert result["confidence"] == pytest.approx(0.92)
    assert "swoosh" in result["reasoning"]
    assert result["tokens"] == 200


def test_parse_response_no_confidence():
    """A reply holding only a brand name gets the default confidence."""
    result = _parse_response("Adidas", 0)

    assert result["brand"] == "Adidas"
    assert result["confidence"] == pytest.approx(0.5)  # default


def test_escalate_skips_without_api_key(monkeypatch):
    """With no API key set, nothing is matched and a log event says so."""
    events = _capture_events(monkeypatch)
    for key_var in ("GROQ_API_KEY", "GEMINI_API_KEY", "OPENAI_API_KEY"):
        monkeypatch.delenv(key_var, raising=False)
    monkeypatch.setenv("CLOUD_LLM_PROVIDER", "groq")
    # Reset cached provider
    _reset_provider_cache(monkeypatch)

    candidates = [_make_candidate()]
    stats = PipelineStats()

    matched = escalate_cloud(candidates, _brand_prompt, stats, job_id="test")

    assert len(matched) == 0
    assert stats.cloud_llm_calls == 0
    log_payloads = [data for etype, data in events if etype == "log"]
    assert any("No API key" in data.get("msg", "") for data in log_payloads)


def test_escalate_empty_candidates(monkeypatch):
    """An empty candidate list yields no matches and no cloud calls."""
    _capture_events(monkeypatch)
    stats = PipelineStats()

    matched = escalate_cloud([], lambda ctx: "", stats, job_id="test")

    assert len(matched) == 0
    assert stats.cloud_llm_calls == 0


def test_escalate_with_mock_api(monkeypatch):
    """A stubbed cloud API reply is turned into one ``cloud_llm`` match."""
    _capture_events(monkeypatch)
    monkeypatch.setenv("GROQ_API_KEY", "test-key")
    monkeypatch.setenv("CLOUD_LLM_PROVIDER", "groq")
    # Reset cached provider
    _reset_provider_cache(monkeypatch)

    def mock_call(image_b64, prompt):
        return {
            "brand": "Heineken",
            "confidence": 0.75,
            "reasoning": "green logo",
            "tokens": 300,
        }

    monkeypatch.setattr("detect.stages.vlm_cloud._call_cloud_api", mock_call)

    candidates = [_make_candidate("unknown logo")]
    stats = PipelineStats()

    matched = escalate_cloud(candidates, _brand_prompt, stats, job_id="test")

    assert len(matched) == 1
    assert matched[0].brand == "Heineken"
    assert matched[0].source == "cloud_llm"
    assert stats.cloud_llm_calls == 1
    # exact cost depends on provider model index
    assert stats.estimated_cloud_cost_usd >= 0