"""Tests for config overrides and replay."""

import copy

import pytest

from core.detect.profile import get_profile, get_stage_config
from core.detect.stages.models import RegionAnalysisConfig, OCRConfig, ResolverConfig
from core.detect.checkpoint.replay import replay_single_stage


def _apply_overrides(profile, overrides):
    """Apply config overrides to a profile dict (same logic as nodes._load_profile).

    Each stage's overrides are shallow-merged on top of that stage's existing
    config (override keys win); stages not yet present in the profile are added.

    The returned profile and its ``"configs"`` dict are new objects, and every
    overridden stage gets a freshly built dict, so the input profile is never
    mutated and the caller's per-stage override dicts are not aliased into the
    result. Stage configs that are not overridden are shared with the input
    (shallow copy).

    Args:
        profile: Profile dict whose optional ``"configs"`` entry maps
            stage name -> config dict.
        overrides: Mapping of stage name -> partial config dict.

    Returns:
        A new profile dict with the merged ``"configs"``.
    """
    merged_configs = dict(profile.get("configs", {}))
    for stage_name, stage_overrides in overrides.items():
        # Always build a new dict, including for stages absent from the profile,
        # so the caller's override dict is never shared with the patched profile.
        merged_configs[stage_name] = {
            **merged_configs.get(stage_name, {}),
            **stage_overrides,
        }
    return {**profile, "configs": merged_configs}


def test_override_patches_ocr():
    """Overrides replace individual OCR fields, including list-valued ones."""
    profile = get_profile("soccer_broadcast")
    overrides = {"run_ocr": {"min_confidence": 0.3, "languages": ["en", "es", "pt"]}}
    patched = _apply_overrides(profile, overrides)

    config = OCRConfig(**get_stage_config(patched, "run_ocr"))
    assert config.min_confidence == 0.3
    assert config.languages == ["en", "es", "pt"]


def test_override_patches_resolver():
    """Overrides reach the brand resolver stage config."""
    profile = get_profile("soccer_broadcast")
    overrides = {"match_brands": {"fuzzy_threshold": 60}}
    patched = _apply_overrides(profile, overrides)

    config = ResolverConfig(**get_stage_config(patched, "match_brands"))
    assert config.fuzzy_threshold == 60


def test_override_no_overrides():
    """An empty override mapping leaves the stage config unchanged."""
    profile = get_profile("soccer_broadcast")
    patched = _apply_overrides(profile, {})

    ocr = OCRConfig(**get_stage_config(patched, "run_ocr"))
    base_ocr = OCRConfig(**get_stage_config(profile, "run_ocr"))
    assert ocr.min_confidence == base_ocr.min_confidence
    assert ocr.languages == base_ocr.languages


def test_override_patches_region_analysis():
    """Partial overrides change only the named fields of a stage config."""
    profile = get_profile("soccer_broadcast")
    overrides = {"detect_edges": {"edge_canny_low": 25, "edge_canny_high": 200}}
    patched = _apply_overrides(profile, overrides)

    config = RegionAnalysisConfig(**get_stage_config(patched, "detect_edges"))
    assert config.edge_canny_low == 25
    assert config.edge_canny_high == 200

    # Fields that were not overridden keep the profile's values.
    base_config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
    assert config.edge_hough_threshold == base_config.edge_hough_threshold


def test_override_adds_new_stage_without_aliasing():
    """A stage absent from the profile is added as a copy of its overrides."""
    profile = get_profile("soccer_broadcast")
    stage_overrides = {"some_option": 1}
    overrides = {"__test_only_stage__": stage_overrides}
    patched = _apply_overrides(profile, overrides)

    assert patched["configs"]["__test_only_stage__"] == {"some_option": 1}
    # The patched profile must not share the caller's override dict.
    assert patched["configs"]["__test_only_stage__"] is not stage_overrides


def test_override_does_not_mutate_profile():
    """Applying overrides leaves the source profile's configs untouched."""
    profile = get_profile("soccer_broadcast")
    configs_before = copy.deepcopy(profile.get("configs", {}))

    _apply_overrides(
        profile,
        {"run_ocr": {"min_confidence": 0.01}, "__test_only_stage__": {"x": 1}},
    )

    assert profile.get("configs", {}) == configs_before


# --- replay_single_stage ---


def test_replay_single_stage_unknown_stage():
    """Replaying a stage name the pipeline does not know raises ValueError."""
    with pytest.raises(ValueError, match="Unknown stage"):
        replay_single_stage("fake-job", "nonexistent_stage")


def test_replay_single_stage_first_stage():
    """The first pipeline stage cannot be replayed on its own."""
    with pytest.raises(ValueError, match="Cannot replay the first stage"):
        replay_single_stage("fake-job", "extract_frames")