80 lines
2.5 KiB
Python
80 lines
2.5 KiB
Python
"""Tests for the report aggregator stage."""
|
|
|
|
import pytest
|
|
|
|
from detect.models import BoundingBox, BrandDetection, PipelineStats
|
|
from detect.stages.aggregator import compile_report, _merge_contiguous
|
|
|
|
|
|
def _make_detection(brand: str, timestamp: float, duration: float = 0.5,
                    source: str = "ocr", confidence: float = 0.9) -> BrandDetection:
    """Build a ``BrandDetection`` with test-friendly defaults.

    Every detection produced here is tagged with the ``"soccer_broadcast"``
    content type; only brand, timing, source and confidence vary per call.

    Args:
        brand: Brand name to attach to the detection.
        timestamp: Start time of the detection.
        duration: Length of the detection window (default 0.5).
        source: Which detector produced it (default ``"ocr"``).
        confidence: Detector confidence (default 0.9).

    Returns:
        A freshly constructed ``BrandDetection``.
    """
    fields = {
        "brand": brand,
        "timestamp": timestamp,
        "duration": duration,
        "confidence": confidence,
        "source": source,
        "content_type": "soccer_broadcast",
    }
    return BrandDetection(**fields)
|
|
|
|
|
|
def test_merge_contiguous_same_brand():
    """Close same-brand hits collapse into one span; a distant hit stays separate."""
    opening = _make_detection("Nike", 1.0, 0.5)
    overlapping = _make_detection("Nike", 1.3, 0.5)  # starts inside the 1.0-1.5 window
    distant = _make_detection("Nike", 5.0, 0.5)      # > 2.0s after the merged span ends

    merged = _merge_contiguous([opening, overlapping, distant], gap_threshold=2.0)

    assert len(merged) == 2
    combined, standalone = merged[0], merged[1]
    assert combined.brand == "Nike"
    assert combined.timestamp == 1.0
    # The combined span runs from 1.0 to 1.8 (end of the overlapping hit).
    assert combined.duration == pytest.approx(0.8)
    assert standalone.timestamp == 5.0
|
|
|
|
|
|
def test_merge_different_brands():
    """Detections of different brands are never merged, even when back-to-back."""
    dets = [
        _make_detection("Nike", 1.0),    # default duration 0.5 -> ends at 1.5
        _make_detection("Adidas", 1.5),  # starts exactly where the Nike hit ends
    ]

    merged = _merge_contiguous(dets)

    assert len(merged) == 2
    # A length check alone would also pass if one brand were duplicated or
    # relabeled; make sure both distinct brands survive the merge.
    assert sorted(m.brand for m in merged) == ["Adidas", "Nike"]
|
|
|
|
|
|
def test_merge_empty():
    """Merging an empty list of detections yields an empty list."""
    no_detections = []
    result = _merge_contiguous(no_detections)
    assert result == []
|
|
|
|
|
|
def test_compile_report(monkeypatch):
    """compile_report tallies per-brand appearances and emits one job_complete event."""
    emitted = []

    def _capture(job_id, etype, data):
        # Record (event_type, payload) instead of pushing the event; job_id is unused.
        emitted.append((etype, data))

    # NOTE(review): this replaces the attribute on ``detect.emit``. It only
    # intercepts the aggregator's calls if the aggregator resolves
    # ``push_detect_event`` through that module at call time — confirm.
    monkeypatch.setattr("detect.emit.push_detect_event", _capture)

    detections = [
        _make_detection("Nike", 1.0, 0.5, confidence=0.95),
        _make_detection("Nike", 5.0, 1.0, confidence=0.90),
        _make_detection("Adidas", 3.0, 0.5, confidence=0.85),
        _make_detection("Heineken", 10.0, 0.5, source="cloud_llm", confidence=0.70),
    ]
    stats = PipelineStats(
        frames_extracted=120,
        regions_detected=32,
        cloud_llm_calls=1,
        estimated_cloud_cost_usd=0.003,
    )

    report = compile_report(
        detections=detections,
        stats=stats,
        video_source="test.mp4",
        content_type="soccer_broadcast",
        job_id="test-report",
    )

    expected_appearances = {"Nike": 2, "Adidas": 1, "Heineken": 1}
    assert len(report.brands) == len(expected_appearances)
    for brand, count in expected_appearances.items():
        assert report.brands[brand].total_appearances == count
    assert report.pipeline_stats.cloud_llm_calls == 1
    assert report.video_source == "test.mp4"

    # Exactly one job_complete event should have been emitted.
    job_complete_count = sum(1 for etype, _ in emitted if etype == "job_complete")
    assert job_complete_count == 1
|