phase 2
This commit is contained in:
73
tests/detect/test_profiles.py
Normal file
73
tests/detect/test_profiles.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Tests for ContentTypeProfile implementations."""
|
||||
|
||||
import pytest
|
||||
|
||||
from detect.models import BrandDetection
|
||||
from detect.profiles.base import ContentTypeProfile, CropContext
|
||||
from detect.profiles.soccer import SoccerBroadcastProfile
|
||||
from detect.profiles.stubs import AdvertisingProfile, NewsBroadcastProfile, TranscriptProfile
|
||||
|
||||
|
||||
def test_soccer_satisfies_protocol():
|
||||
profile: ContentTypeProfile = SoccerBroadcastProfile()
|
||||
assert profile.name == "soccer_broadcast"
|
||||
|
||||
|
||||
def test_soccer_frame_extraction_config():
|
||||
cfg = SoccerBroadcastProfile().frame_extraction_config()
|
||||
assert cfg.fps > 0
|
||||
assert cfg.max_frames > 0
|
||||
|
||||
|
||||
def test_soccer_detection_config():
|
||||
cfg = SoccerBroadcastProfile().detection_config()
|
||||
assert 0 < cfg.confidence_threshold < 1
|
||||
assert len(cfg.target_classes) > 0
|
||||
|
||||
|
||||
def test_soccer_brand_dictionary_non_empty():
|
||||
bd = SoccerBroadcastProfile().brand_dictionary()
|
||||
assert len(bd.brands) > 0
|
||||
for canonical, aliases in bd.brands.items():
|
||||
assert len(aliases) > 0
|
||||
|
||||
|
||||
def test_soccer_vlm_prompt():
|
||||
ctx = CropContext(image=b"fake", surrounding_text="Emirates", position_hint="top-center")
|
||||
prompt = SoccerBroadcastProfile().vlm_prompt(ctx)
|
||||
assert "brand" in prompt.lower()
|
||||
assert "Emirates" in prompt
|
||||
|
||||
|
||||
def test_soccer_aggregate_empty():
|
||||
report = SoccerBroadcastProfile().aggregate([])
|
||||
assert len(report.brands) == 0
|
||||
assert len(report.timeline) == 0
|
||||
|
||||
|
||||
def test_soccer_aggregate_groups():
|
||||
detections = [
|
||||
BrandDetection(brand="Nike", timestamp=1.0, duration=0.5, confidence=0.9, source="ocr"),
|
||||
BrandDetection(brand="Nike", timestamp=2.0, duration=0.5, confidence=0.8, source="ocr"),
|
||||
BrandDetection(brand="Adidas", timestamp=3.0, duration=0.5, confidence=0.7, source="logo_match"),
|
||||
]
|
||||
report = SoccerBroadcastProfile().aggregate(detections)
|
||||
assert "Nike" in report.brands
|
||||
assert "Adidas" in report.brands
|
||||
assert report.brands["Nike"].total_appearances == 2
|
||||
assert report.brands["Adidas"].total_appearances == 1
|
||||
assert report.timeline == sorted(report.timeline, key=lambda d: d.timestamp)
|
||||
|
||||
|
||||
def test_soccer_auxiliary_returns_empty():
|
||||
assert SoccerBroadcastProfile().auxiliary_detections("test.mp4") == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize("stub_cls", [NewsBroadcastProfile, AdvertisingProfile, TranscriptProfile])
|
||||
def test_stubs_raise(stub_cls):
|
||||
stub = stub_cls()
|
||||
assert isinstance(stub.name, str)
|
||||
with pytest.raises(NotImplementedError):
|
||||
stub.frame_extraction_config()
|
||||
with pytest.raises(NotImplementedError):
|
||||
stub.brand_dictionary()
|
||||
Reference in New Issue
Block a user