81 lines
2.9 KiB
Python
81 lines
2.9 KiB
Python
"""Tests for profile data and helper functions."""
|
|
|
|
from core.detect.models import BrandDetection, CropContext
|
|
from core.detect.profile import get_profile, get_stage_config, build_vlm_prompt, aggregate_detections, pipeline_config_from_dict
|
|
from core.detect.stages.models import FrameExtractionConfig, DetectionConfig, ResolverConfig
|
|
|
|
|
|
def test_soccer_profile_exists():
    """The soccer_broadcast profile can be looked up and reports its own name."""
    expected_name = "soccer_broadcast"
    assert get_profile(expected_name)["name"] == expected_name
|
|
|
|
|
|
def test_soccer_has_pipeline():
    """The profile's pipeline section declares both its stages and its edges."""
    pipeline = get_profile("soccer_broadcast")["pipeline"]
    for section in ("stages", "edges"):
        assert section in pipeline, f"pipeline is missing {section!r}"
|
|
|
|
|
|
def test_soccer_has_configs():
    """The profile provides a config entry for each expected stage.

    ``escalate_vlm`` is included because the VLM prompt template is read
    straight out of ``configs["escalate_vlm"]`` (no fallback), so a missing
    entry should be reported here with a clear message.
    """
    configs = get_profile("soccer_broadcast")["configs"]
    expected_stages = ("extract_frames", "filter_scenes", "detect_edges", "escalate_vlm")
    for stage in expected_stages:
        assert stage in configs, f"missing config for stage {stage!r}"
|
|
|
|
|
|
def test_soccer_frame_extraction_config():
    """The profile's extract_frames settings build a FrameExtractionConfig with positive limits."""
    raw_settings = get_stage_config(get_profile("soccer_broadcast"), "extract_frames")
    extraction = FrameExtractionConfig(**raw_settings)
    assert extraction.fps > 0
    assert extraction.max_frames > 0
|
|
|
|
|
|
def test_soccer_detection_config():
    """The profile's detect_objects settings give a threshold strictly in (0, 1) and a class list."""
    raw_settings = get_stage_config(get_profile("soccer_broadcast"), "detect_objects")
    detection = DetectionConfig(**raw_settings)
    threshold = detection.confidence_threshold
    assert 0 < threshold < 1
    assert isinstance(detection.target_classes, list)
|
|
|
|
|
|
def test_soccer_resolver_config():
    """The profile's match_brands settings build a ResolverConfig with a positive fuzzy threshold."""
    raw_settings = get_stage_config(get_profile("soccer_broadcast"), "match_brands")
    resolver = ResolverConfig(**raw_settings)
    assert resolver.fuzzy_threshold > 0
|
|
|
|
|
|
def test_vlm_prompt():
    """The profile's VLM prompt mentions brands and embeds the crop's surrounding text."""
    profile = get_profile("soccer_broadcast")
    template = profile["configs"]["escalate_vlm"]["vlm_prompt_template"]
    crop = CropContext(
        image=b"fake",
        surrounding_text="Emirates",
        position_hint="top-center",
    )
    prompt = build_vlm_prompt(crop, template)
    assert "brand" in prompt.lower()
    assert "Emirates" in prompt
|
|
|
|
|
|
def test_aggregate_empty():
    """Aggregating no detections yields a report with no brands and an empty timeline."""
    report = aggregate_detections([], "soccer_broadcast")
    for collection in (report.brands, report.timeline):
        assert len(collection) == 0
|
|
|
|
|
|
def test_aggregate_groups():
    """Detections are grouped per brand and the report timeline is ordered by timestamp.

    The input is deliberately NOT in chronological order. With pre-sorted
    input the sortedness check would pass even if ``aggregate_detections``
    simply preserved input order, so it would never catch a missing sort.
    """
    detections = [
        BrandDetection(brand="Adidas", timestamp=3.0, duration=0.5, confidence=0.7, source="logo_match"),
        BrandDetection(brand="Nike", timestamp=2.0, duration=0.5, confidence=0.8, source="ocr"),
        BrandDetection(brand="Nike", timestamp=1.0, duration=0.5, confidence=0.9, source="ocr"),
    ]
    report = aggregate_detections(detections, "soccer_broadcast")

    # Exactly the two input brands, with per-brand appearance counts.
    assert "Nike" in report.brands
    assert "Adidas" in report.brands
    assert len(report.brands) == 2
    assert report.brands["Nike"].total_appearances == 2
    assert report.brands["Adidas"].total_appearances == 1

    # Compare timestamps directly: non-decreasing order is what "sorted" means here.
    timestamps = [entry.timestamp for entry in report.timeline]
    assert timestamps == sorted(timestamps)
|
|
|
|
|
|
def test_pipeline_config():
    """The profile's pipeline dict converts into a named config with at least one stage and edge."""
    pipeline_dict = get_profile("soccer_broadcast")["pipeline"]
    config = pipeline_config_from_dict(pipeline_dict)
    assert config.name == "soccer_broadcast"
    assert len(config.stages) >= 1
    assert len(config.edges) >= 1
|