Files
mediaproc/tests/detect/test_replay.py
2026-03-30 07:22:14 -03:00

77 lines
2.7 KiB
Python

"""Tests for config overrides and replay."""
import pytest
from core.detect.profile import get_profile, get_stage_config
from core.detect.stages.models import RegionAnalysisConfig, OCRConfig, ResolverConfig
from core.detect.checkpoint.replay import replay_single_stage
def _apply_overrides(profile, overrides):
"""Apply config overrides to a profile dict (same logic as nodes._load_profile)."""
merged_configs = dict(profile.get("configs", {}))
for stage_name, stage_overrides in overrides.items():
if stage_name in merged_configs:
merged_configs[stage_name] = {**merged_configs[stage_name], **stage_overrides}
else:
merged_configs[stage_name] = stage_overrides
return {**profile, "configs": merged_configs}
def test_override_patches_ocr():
    """A ``run_ocr`` override replaces the OCR confidence floor and language list."""
    patched_profile = _apply_overrides(
        get_profile("soccer_broadcast"),
        {"run_ocr": {"min_confidence": 0.3, "languages": ["en", "es", "pt"]}},
    )
    ocr_cfg = OCRConfig(**get_stage_config(patched_profile, "run_ocr"))
    assert ocr_cfg.min_confidence == 0.3
    assert ocr_cfg.languages == ["en", "es", "pt"]
def test_override_patches_resolver():
    """A ``match_brands`` override changes the resolver's fuzzy-match threshold."""
    base_profile = get_profile("soccer_broadcast")
    resolver_stage = get_stage_config(
        _apply_overrides(base_profile, {"match_brands": {"fuzzy_threshold": 60}}),
        "match_brands",
    )
    assert ResolverConfig(**resolver_stage).fuzzy_threshold == 60
def test_override_no_overrides():
    """An empty override mapping leaves every stage config of the profile unchanged."""
    profile = get_profile("soccer_broadcast")
    patched = _apply_overrides(profile, {})

    # Compare the raw stage mapping first, so that a change to *any* OCR
    # setting is caught, not only the two fields checked below.
    assert get_stage_config(patched, "run_ocr") == get_stage_config(profile, "run_ocr")

    ocr = OCRConfig(**get_stage_config(patched, "run_ocr"))
    base_ocr = OCRConfig(**get_stage_config(profile, "run_ocr"))
    assert ocr.min_confidence == base_ocr.min_confidence
    assert ocr.languages == base_ocr.languages

    # No stage (OCR or otherwise) should be added, dropped, or altered.
    assert patched["configs"] == dict(profile.get("configs", {}))
def test_override_patches_region_analysis():
    """Canny thresholds can be overridden without disturbing other edge settings."""
    base_profile = get_profile("soccer_broadcast")
    canny_patch = {"edge_canny_low": 25, "edge_canny_high": 200}
    patched_stage = get_stage_config(
        _apply_overrides(base_profile, {"detect_edges": canny_patch}),
        "detect_edges",
    )
    patched_cfg = RegionAnalysisConfig(**patched_stage)
    assert patched_cfg.edge_canny_low == 25
    assert patched_cfg.edge_canny_high == 200

    # Fields that were not overridden must still match the unpatched profile.
    untouched_cfg = RegionAnalysisConfig(**get_stage_config(base_profile, "detect_edges"))
    assert patched_cfg.edge_hough_threshold == untouched_cfg.edge_hough_threshold
# --- replay_single_stage ---
def test_replay_single_stage_unknown_stage():
    """Replaying a stage name the pipeline does not define raises ``ValueError``."""
    with pytest.raises(ValueError) as excinfo:
        replay_single_stage("fake-job", "nonexistent_stage")
    # ExceptionInfo.match performs the same re.search as pytest.raises(match=...).
    excinfo.match("Unknown stage")
def test_replay_single_stage_first_stage():
    """``extract_frames`` is rejected as a replay target because it is the first stage."""
    job_id, first_stage = "fake-job", "extract_frames"
    with pytest.raises(ValueError, match=r"Cannot replay the first stage"):
        replay_single_stage(job_id, first_stage)