77 lines
2.7 KiB
Python
77 lines
2.7 KiB
Python
"""Tests for config overrides and replay."""
|
|
|
|
import pytest
|
|
|
|
from core.detect.profile import get_profile, get_stage_config
|
|
from core.detect.stages.models import RegionAnalysisConfig, OCRConfig, ResolverConfig
|
|
from core.detect.checkpoint.replay import replay_single_stage
|
|
|
|
|
|
def _apply_overrides(profile, overrides):
|
|
"""Apply config overrides to a profile dict (same logic as nodes._load_profile)."""
|
|
merged_configs = dict(profile.get("configs", {}))
|
|
for stage_name, stage_overrides in overrides.items():
|
|
if stage_name in merged_configs:
|
|
merged_configs[stage_name] = {**merged_configs[stage_name], **stage_overrides}
|
|
else:
|
|
merged_configs[stage_name] = stage_overrides
|
|
return {**profile, "configs": merged_configs}
|
|
|
|
|
|
def test_override_patches_ocr():
    """A ``run_ocr`` override replaces ``min_confidence`` and ``languages``."""
    patched_profile = _apply_overrides(
        get_profile("soccer_broadcast"),
        {"run_ocr": {"min_confidence": 0.3, "languages": ["en", "es", "pt"]}},
    )

    ocr_config = OCRConfig(**get_stage_config(patched_profile, "run_ocr"))

    assert ocr_config.min_confidence == 0.3
    assert ocr_config.languages == ["en", "es", "pt"]
|
|
|
|
|
|
def test_override_patches_resolver():
    """A ``match_brands`` override sets the resolver's ``fuzzy_threshold``."""
    stage = "match_brands"
    patched_profile = _apply_overrides(
        get_profile("soccer_broadcast"), {stage: {"fuzzy_threshold": 60}}
    )

    resolver_config = ResolverConfig(**get_stage_config(patched_profile, stage))

    assert resolver_config.fuzzy_threshold == 60
|
|
|
|
|
|
def test_override_no_overrides():
    """With an empty override map, the OCR config matches the base profile's."""
    base_profile = get_profile("soccer_broadcast")
    unchanged_profile = _apply_overrides(base_profile, {})

    actual = OCRConfig(**get_stage_config(unchanged_profile, "run_ocr"))
    expected = OCRConfig(**get_stage_config(base_profile, "run_ocr"))

    assert actual.min_confidence == expected.min_confidence
    assert actual.languages == expected.languages
|
|
|
|
|
|
def test_override_patches_region_analysis():
    """Edge-detection overrides apply; untouched fields keep the profile values."""
    base_profile = get_profile("soccer_broadcast")
    edge_patch = {"edge_canny_low": 25, "edge_canny_high": 200}
    patched_profile = _apply_overrides(base_profile, {"detect_edges": edge_patch})

    patched_cfg = RegionAnalysisConfig(
        **get_stage_config(patched_profile, "detect_edges")
    )
    assert patched_cfg.edge_canny_low == 25
    assert patched_cfg.edge_canny_high == 200

    # A field not named in the override must still match the base profile.
    base_cfg = RegionAnalysisConfig(**get_stage_config(base_profile, "detect_edges"))
    assert patched_cfg.edge_hough_threshold == base_cfg.edge_hough_threshold
|
|
|
|
|
|
# --- replay_single_stage ---
|
|
|
|
def test_replay_single_stage_unknown_stage():
    """Replaying a stage name the pipeline does not know raises ``ValueError``."""
    job_id, stage = "fake-job", "nonexistent_stage"
    with pytest.raises(ValueError, match="Unknown stage"):
        replay_single_stage(job_id, stage)
|
|
|
|
|
|
def test_replay_single_stage_first_stage():
    """Replaying ``extract_frames`` is rejected because it is the first stage."""
    job_id, stage = "fake-job", "extract_frames"
    with pytest.raises(ValueError, match="Cannot replay the first stage"):
        replay_single_stage(job_id, stage)
|