phase 9
This commit is contained in:
79
tests/detect/test_aggregator.py
Normal file
79
tests/detect/test_aggregator.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Tests for the report aggregator stage."""
|
||||
|
||||
import pytest
|
||||
|
||||
from detect.models import BoundingBox, BrandDetection, PipelineStats
|
||||
from detect.stages.aggregator import compile_report, _merge_contiguous
|
||||
|
||||
|
||||
def _make_detection(brand: str, timestamp: float, duration: float = 0.5,
|
||||
source: str = "ocr", confidence: float = 0.9) -> BrandDetection:
|
||||
return BrandDetection(
|
||||
brand=brand, timestamp=timestamp, duration=duration,
|
||||
confidence=confidence, source=source, content_type="soccer_broadcast",
|
||||
)
|
||||
|
||||
|
||||
def test_merge_contiguous_same_brand():
|
||||
dets = [
|
||||
_make_detection("Nike", 1.0, 0.5),
|
||||
_make_detection("Nike", 1.3, 0.5), # within gap
|
||||
_make_detection("Nike", 5.0, 0.5), # separate
|
||||
]
|
||||
merged = _merge_contiguous(dets, gap_threshold=2.0)
|
||||
assert len(merged) == 2
|
||||
assert merged[0].brand == "Nike"
|
||||
assert merged[0].timestamp == 1.0
|
||||
assert merged[0].duration == pytest.approx(0.8) # 1.0 to 1.8
|
||||
assert merged[1].timestamp == 5.0
|
||||
|
||||
|
||||
def test_merge_different_brands():
|
||||
dets = [
|
||||
_make_detection("Nike", 1.0),
|
||||
_make_detection("Adidas", 1.5),
|
||||
]
|
||||
merged = _merge_contiguous(dets)
|
||||
assert len(merged) == 2
|
||||
|
||||
|
||||
def test_merge_empty():
|
||||
assert _merge_contiguous([]) == []
|
||||
|
||||
|
||||
def test_compile_report(monkeypatch):
|
||||
events = []
|
||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||
lambda job_id, etype, data: events.append((etype, data)))
|
||||
|
||||
dets = [
|
||||
_make_detection("Nike", 1.0, 0.5, confidence=0.95),
|
||||
_make_detection("Nike", 5.0, 1.0, confidence=0.90),
|
||||
_make_detection("Adidas", 3.0, 0.5, confidence=0.85),
|
||||
_make_detection("Heineken", 10.0, 0.5, source="cloud_llm", confidence=0.70),
|
||||
]
|
||||
stats = PipelineStats(
|
||||
frames_extracted=120,
|
||||
regions_detected=32,
|
||||
cloud_llm_calls=1,
|
||||
estimated_cloud_cost_usd=0.003,
|
||||
)
|
||||
|
||||
report = compile_report(
|
||||
detections=dets,
|
||||
stats=stats,
|
||||
video_source="test.mp4",
|
||||
content_type="soccer_broadcast",
|
||||
job_id="test-report",
|
||||
)
|
||||
|
||||
assert len(report.brands) == 3
|
||||
assert report.brands["Nike"].total_appearances == 2
|
||||
assert report.brands["Adidas"].total_appearances == 1
|
||||
assert report.brands["Heineken"].total_appearances == 1
|
||||
assert report.pipeline_stats.cloud_llm_calls == 1
|
||||
assert report.video_source == "test.mp4"
|
||||
|
||||
# job_complete event should have been emitted
|
||||
complete = [e for e in events if e[0] == "job_complete"]
|
||||
assert len(complete) == 1
|
||||
Reference in New Issue
Block a user