This commit is contained in:
2026-03-30 07:22:14 -03:00
parent d0707333fd
commit 4220b0418e
182 changed files with 3668 additions and 5231 deletions

View File

@@ -24,9 +24,9 @@ logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(m
sys.path.insert(0, ".")
from detect.profiles.soccer import SoccerBroadcastProfile
from detect.stages.frame_extractor import extract_frames
from detect.stages.scene_filter import scene_filter
from core.detect.profile import get_profile
from core.detect.stages.frame_extractor import extract_frames
from core.detect.stages.scene_filter import scene_filter
logger = logging.getLogger(__name__)

View File

@@ -24,8 +24,8 @@ logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(m
sys.path.insert(0, ".")
from detect.graph import get_pipeline
from detect.state import DetectState
from core.detect.graph import get_pipeline
from core.detect.state import DetectState
logger = logging.getLogger(__name__)

View File

@@ -39,13 +39,13 @@ sys.path.insert(0, ".")
from langgraph.graph import END, StateGraph
from detect import emit
from detect.models import PipelineStats
from detect.profiles.soccer import SoccerBroadcastProfile
from detect.stages.frame_extractor import extract_frames
from detect.stages.scene_filter import scene_filter
from detect.stages.edge_detector import detect_edge_regions
from detect.state import DetectState
from core.detect import emit
from core.detect.models import PipelineStats
from core.detect.profile import get_profile
from core.detect.stages.frame_extractor import extract_frames
from core.detect.stages.scene_filter import scene_filter
from core.detect.stages.edge_detector import detect_edge_regions
from core.detect.state import DetectState
logger = logging.getLogger(__name__)
@@ -166,7 +166,7 @@ def main():
# --- Parameter sensitivity ---
logger.info("=== Parameter sensitivity (local debug) ===")
from detect.stages.edge_detector import _load_cv_edges
from core.detect.stages.edge_detector import _load_cv_edges
edges_mod = _load_cv_edges()
filtered = result.get("filtered_frames", [])

View File

@@ -58,7 +58,7 @@ def extract_frames_ffmpeg(video_path: str, fps: float, max_frames: int):
import numpy as np
from PIL import Image
from detect.models import Frame
from core.detect.models import Frame
tmpdir = tempfile.mkdtemp(prefix="scenario_")
pattern = os.path.join(tmpdir, "frame_%04d.jpg")
@@ -111,7 +111,7 @@ def main():
logger.info("Extracted %d frames", len(frames))
# Create timeline + branch + checkpoint
from detect.checkpoint.storage import create_timeline, save_stage_output
from core.detect.checkpoint.storage import create_timeline, save_stage_output
timeline_id, branch_id = create_timeline(
source_video=video_path,

View File

@@ -58,7 +58,7 @@ def make_brand_image(text: str, width: int = 300, height: int = 100) -> str:
def main():
from detect.providers import get_provider, has_api_key, PROVIDERS
from core.detect.providers import get_provider, has_api_key, PROVIDERS
provider_name = os.environ.get("CLOUD_LLM_PROVIDER", "groq")
logger.info("Provider: %s", provider_name)

View File

@@ -13,8 +13,8 @@ import sys
sys.path.insert(0, ".")
from detect.profiles.soccer import SoccerBroadcastProfile
from detect.stages.frame_extractor import extract_frames
from core.detect.profile import get_profile
from core.detect.stages.frame_extractor import extract_frames
logger = logging.getLogger(__name__)

View File

@@ -86,9 +86,9 @@ def test_ocr_stage_remote(url: str):
logger.info("--- OCR stage (remote mode) ---")
sys.path.insert(0, ".")
from detect.models import BoundingBox, Frame
from detect.profiles.base import OCRConfig
from detect.stages.ocr_stage import run_ocr
from core.detect.models import BoundingBox, Frame
from core.detect.stages.models import OCRConfig
from core.detect.stages.ocr_stage import run_ocr
# Create a frame with text baked in
image = make_text_image("EMIRATES")

View File

@@ -48,10 +48,10 @@ def main():
# Override Redis to localhost (ctrl/.env has k8s hostname)
os.environ["REDIS_URL"] = f"redis://localhost:{args.port}/0"
from detect.graph import get_pipeline, NODES
from detect.checkpoint import list_checkpoints
from detect.checkpoint import replay_from
from detect.state import DetectState
from core.detect.graph import get_pipeline, NODES
from core.detect.checkpoint import list_checkpoints
from core.detect.checkpoint import replay_from
from core.detect.state import DetectState
VIDEO = "media/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e/chunk_0000.mp4"

View File

@@ -2,8 +2,8 @@
import pytest
from detect.models import BoundingBox, BrandDetection, PipelineStats
from detect.stages.aggregator import compile_report, _merge_contiguous
from core.detect.models import BoundingBox, BrandDetection, PipelineStats
from core.detect.stages.aggregator import compile_report, _merge_contiguous
def _make_detection(brand: str, timestamp: float, duration: float = 0.5,
@@ -43,7 +43,7 @@ def test_merge_empty():
def test_compile_report(monkeypatch):
events = []
monkeypatch.setattr("detect.emit.push_detect_event",
monkeypatch.setattr("core.detect.emit.push_detect_event",
lambda job_id, etype, data: events.append((etype, data)))
dets = [

View File

@@ -3,9 +3,9 @@
import numpy as np
import pytest
from detect.models import BoundingBox, Frame, TextCandidate
from detect.profiles.base import ResolverConfig
from detect.stages.brand_resolver import resolve_brands, _normalize, _match_session
from core.detect.models import BoundingBox, Frame, TextCandidate
from core.detect.stages.models import ResolverConfig
from core.detect.stages.brand_resolver import resolve_brands, _normalize, _match_session
CONFIG = ResolverConfig(fuzzy_threshold=75)
@@ -28,7 +28,7 @@ def test_session_match():
def test_resolve_with_session(monkeypatch):
events = []
monkeypatch.setattr("detect.emit.push_detect_event",
monkeypatch.setattr("core.detect.emit.push_detect_event",
lambda job_id, etype, data: events.append((etype, data)))
session = {"nike": "Nike", "emirates": "Emirates"}
@@ -46,7 +46,7 @@ def test_resolve_with_session(monkeypatch):
def test_resolve_unresolved_without_db(monkeypatch):
events = []
monkeypatch.setattr("detect.emit.push_detect_event",
monkeypatch.setattr("core.detect.emit.push_detect_event",
lambda job_id, etype, data: events.append((etype, data)))
candidates = [_make_candidate("random garbage text")]
@@ -61,7 +61,7 @@ def test_resolve_unresolved_without_db(monkeypatch):
def test_resolve_empty(monkeypatch):
events = []
monkeypatch.setattr("detect.emit.push_detect_event",
monkeypatch.setattr("core.detect.emit.push_detect_event",
lambda job_id, etype, data: events.append((etype, data)))
matched, unresolved = resolve_brands([], CONFIG, session_brands={})
@@ -73,7 +73,7 @@ def test_resolve_empty(monkeypatch):
def test_resolve_builds_session_during_run(monkeypatch):
"""Session brands accumulate during a single run — second candidate benefits."""
events = []
monkeypatch.setattr("detect.emit.push_detect_event",
monkeypatch.setattr("core.detect.emit.push_detect_event",
lambda job_id, etype, data: events.append((etype, data)))
session = {"nike": "Nike"}
@@ -93,7 +93,7 @@ def test_resolve_builds_session_during_run(monkeypatch):
def test_events_emitted(monkeypatch):
events = []
monkeypatch.setattr("detect.emit.push_detect_event",
monkeypatch.setattr("core.detect.emit.push_detect_event",
lambda job_id, etype, data: events.append((etype, data)))
session = {"nike": "Nike"}

View File

@@ -5,7 +5,7 @@ import json
import numpy as np
import pytest
from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate
from core.detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate
from core.schema.serializers._common import safe_construct
from core.schema.serializers.pipeline import (
serialize_frame_meta,
@@ -163,34 +163,39 @@ def test_all_serialized_is_json_compatible():
assert roundtrip["frame_meta"]["sequence"] == frame.sequence
# --- OverrideProfile ---
# --- Config overrides (dict merging, replaces OverrideProfile) ---
def test_override_profile_region_analysis():
"""OverrideProfile must patch region_analysis_config with overrides."""
from detect.checkpoint.replay import OverrideProfile
from detect.profiles.soccer import SoccerBroadcastProfile
from detect.profiles.base import RegionAnalysisConfig
def test_config_override_region_analysis():
"""Config overrides must patch stage config values."""
from core.detect.profile import get_profile, get_stage_config
from core.detect.stages.models import RegionAnalysisConfig
base = SoccerBroadcastProfile()
original = base.region_analysis_config()
profile = get_profile("soccer_broadcast")
original = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
overrides = {"region_analysis": {"edge_canny_low": 25, "edge_canny_high": 200}}
wrapped = OverrideProfile(base, overrides)
patched = wrapped.region_analysis_config()
overrides = {"detect_edges": {"edge_canny_low": 25, "edge_canny_high": 200}}
merged_configs = {**profile["configs"]}
merged_configs["detect_edges"] = {**merged_configs["detect_edges"], **overrides["detect_edges"]}
patched_profile = {**profile, "configs": merged_configs}
patched = RegionAnalysisConfig(**get_stage_config(patched_profile, "detect_edges"))
assert isinstance(patched, RegionAnalysisConfig)
assert patched.edge_canny_low == 25
assert patched.edge_canny_high == 200
# Unmodified fields keep their defaults
assert patched.edge_hough_threshold == original.edge_hough_threshold
def test_override_profile_passthrough():
"""OverrideProfile without region_analysis key passes through unchanged."""
from detect.checkpoint.replay import OverrideProfile
from detect.profiles.soccer import SoccerBroadcastProfile
def test_config_override_passthrough():
"""Overrides for other stages don't affect unrelated stages."""
from core.detect.profile import get_profile, get_stage_config
from core.detect.stages.models import RegionAnalysisConfig
base = SoccerBroadcastProfile()
wrapped = OverrideProfile(base, {"ocr": {"min_confidence": 0.1}})
config = wrapped.region_analysis_config()
assert config.edge_canny_low == base.region_analysis_config().edge_canny_low
profile = get_profile("soccer_broadcast")
original = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
overrides = {"run_ocr": {"min_confidence": 0.1}}
merged_configs = {**profile["configs"], **overrides}
patched_profile = {**profile, "configs": merged_configs}
patched = RegionAnalysisConfig(**get_stage_config(patched_profile, "detect_edges"))
assert patched.edge_canny_low == original.edge_canny_low

View File

@@ -1,6 +1,6 @@
"""Tests for the config endpoint and stage palette."""
from detect.stages import list_stages, get_palette
from core.detect.stages import list_stages, get_palette
def test_stage_palette_has_config_fields():

View File

@@ -15,7 +15,7 @@ import pytest
# Load edges module directly
_spec = importlib.util.spec_from_file_location(
"cv_edges", Path("gpu/models/cv/edges.py"),
"cv_edges", Path("core/gpu/models/cv/edges.py"),
)
_edges_mod = importlib.util.module_from_spec(_spec)
_spec.loader.exec_module(_edges_mod)

View File

@@ -5,8 +5,8 @@ from pathlib import Path
import pytest
from detect.profiles.base import FrameExtractionConfig
from detect.stages.frame_extractor import extract_frames
from core.detect.stages.models import FrameExtractionConfig
from core.detect.stages.frame_extractor import extract_frames
SAMPLE_DIR = Path("media/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e")
@@ -61,7 +61,7 @@ def test_extract_frames_with_events(monkeypatch):
def mock_push(job_id, event_type, data):
events.append((job_id, event_type, data))
monkeypatch.setattr("detect.emit.push_detect_event", mock_push)
monkeypatch.setattr("core.detect.emit.push_detect_event", mock_push)
video = _get_sample_video()
config = FrameExtractionConfig(fps=1, max_frames=5)

View File

@@ -4,9 +4,9 @@ import os
import pytest
from detect.graph import NODES, build_graph, get_pipeline
from detect.models import PipelineStats
from detect.state import DetectState
from core.detect.graph import NODES, build_graph, get_pipeline
from core.detect.models import PipelineStats
from core.detect.state import DetectState
VIDEO = "media/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e/chunk_0000.mp4"
@@ -42,7 +42,7 @@ def test_graph_has_all_nodes():
def test_graph_runs_end_to_end(monkeypatch):
"""Run the full graph with mocked event emission."""
events = []
monkeypatch.setattr("detect.emit.push_detect_event",
monkeypatch.setattr("core.detect.emit.push_detect_event",
lambda job_id, etype, data: events.append((etype, data)))
pipeline = get_pipeline()
@@ -75,7 +75,7 @@ def test_graph_runs_end_to_end(monkeypatch):
def test_graph_node_transitions(monkeypatch):
"""Verify each node emits running → done transitions."""
events = []
monkeypatch.setattr("detect.emit.push_detect_event",
monkeypatch.setattr("core.detect.emit.push_detect_event",
lambda job_id, etype, data: events.append((etype, data)))
pipeline = get_pipeline()

View File

@@ -3,9 +3,9 @@
import numpy as np
import pytest
from detect.models import BoundingBox, Frame
from detect.profiles.base import OCRConfig
from detect.stages.ocr_stage import _crop_region, _parse_ocr_raw, run_ocr
from core.detect.models import BoundingBox, Frame
from core.detect.stages.models import OCRConfig
from core.detect.stages.ocr_stage import _crop_region, _parse_ocr_raw, run_ocr
def _has_paddleocr() -> bool:
@@ -80,7 +80,7 @@ def test_parse_empty():
def test_run_ocr_remote(monkeypatch):
events = []
monkeypatch.setattr("detect.emit.push_detect_event",
monkeypatch.setattr("core.detect.emit.push_detect_event",
lambda job_id, etype, data: events.append((etype, data)))
class FakeResult:
@@ -94,11 +94,11 @@ def test_run_ocr_remote(monkeypatch):
def ocr(self, image, languages):
return [FakeResult("NIKE", 0.92)]
monkeypatch.setattr("detect.stages.ocr_stage.InferenceClient", FakeClient,
monkeypatch.setattr("core.detect.stages.ocr_stage.InferenceClient", FakeClient,
raising=False)
# Patch the import path used in the function
import detect.stages.ocr_stage as mod
monkeypatch.setattr("detect.inference.InferenceClient", FakeClient)
import core.detect.stages.ocr_stage as mod
monkeypatch.setattr("core.detect.inference.InferenceClient", FakeClient)
frame = _make_frame()
box = _make_box()
@@ -123,7 +123,7 @@ def test_run_ocr_remote(monkeypatch):
)
def test_run_ocr_skips_empty_crop(monkeypatch):
events = []
monkeypatch.setattr("detect.emit.push_detect_event",
monkeypatch.setattr("core.detect.emit.push_detect_event",
lambda job_id, etype, data: events.append((etype, data)))
frame = _make_frame(w=10, h=10)

View File

@@ -26,7 +26,7 @@ def _make_image(w: int = 200, h: int = 60) -> np.ndarray:
@requires_cv2
def test_binarize():
from gpu.models.preprocess import binarize
from core.gpu.models.preprocess import binarize
img = _make_image()
result = binarize(img)
@@ -40,7 +40,7 @@ def test_binarize():
@requires_cv2
def test_enhance_contrast():
from gpu.models.preprocess import enhance_contrast
from core.gpu.models.preprocess import enhance_contrast
img = _make_image()
result = enhance_contrast(img)
@@ -51,7 +51,7 @@ def test_enhance_contrast():
@requires_cv2
def test_deskew_no_rotation():
from gpu.models.preprocess import deskew
from core.gpu.models.preprocess import deskew
img = _make_image()
result = deskew(img)
@@ -63,7 +63,7 @@ def test_deskew_no_rotation():
@requires_cv2
def test_preprocess_pipeline():
from gpu.models.preprocess import preprocess
from core.gpu.models.preprocess import preprocess
img = _make_image()
@@ -76,7 +76,7 @@ def test_preprocess_pipeline():
@requires_cv2
def test_preprocess_all_disabled():
from gpu.models.preprocess import preprocess
from core.gpu.models.preprocess import preprocess
img = _make_image()
result = preprocess(img, do_binarize=False, do_deskew=False, do_contrast=False)

View File

@@ -1,55 +1,70 @@
"""Tests for ContentTypeProfile implementations."""
"""Tests for profile data and helper functions."""
import pytest
from detect.models import BrandDetection
from detect.profiles.base import ContentTypeProfile, CropContext
from detect.profiles.soccer import SoccerBroadcastProfile
from detect.profiles.stubs import AdvertisingProfile, NewsBroadcastProfile, TranscriptProfile
from core.detect.models import BrandDetection, CropContext
from core.detect.profile import get_profile, get_stage_config, build_vlm_prompt, aggregate_detections, pipeline_config_from_dict
from core.detect.stages.models import FrameExtractionConfig, DetectionConfig, ResolverConfig
def test_soccer_satisfies_protocol():
profile: ContentTypeProfile = SoccerBroadcastProfile()
assert profile.name == "soccer_broadcast"
def test_soccer_profile_exists():
profile = get_profile("soccer_broadcast")
assert profile["name"] == "soccer_broadcast"
def test_soccer_has_pipeline():
profile = get_profile("soccer_broadcast")
assert "stages" in profile["pipeline"]
assert "edges" in profile["pipeline"]
def test_soccer_has_configs():
profile = get_profile("soccer_broadcast")
configs = profile["configs"]
assert "extract_frames" in configs
assert "filter_scenes" in configs
assert "detect_edges" in configs
def test_soccer_frame_extraction_config():
cfg = SoccerBroadcastProfile().frame_extraction_config()
profile = get_profile("soccer_broadcast")
cfg = FrameExtractionConfig(**get_stage_config(profile, "extract_frames"))
assert cfg.fps > 0
assert cfg.max_frames > 0
def test_soccer_detection_config():
cfg = SoccerBroadcastProfile().detection_config()
profile = get_profile("soccer_broadcast")
cfg = DetectionConfig(**get_stage_config(profile, "detect_objects"))
assert 0 < cfg.confidence_threshold < 1
assert isinstance(cfg.target_classes, list)
def test_soccer_resolver_config():
cfg = SoccerBroadcastProfile().resolver_config()
profile = get_profile("soccer_broadcast")
cfg = ResolverConfig(**get_stage_config(profile, "match_brands"))
assert cfg.fuzzy_threshold > 0
def test_soccer_vlm_prompt():
def test_vlm_prompt():
ctx = CropContext(image=b"fake", surrounding_text="Emirates", position_hint="top-center")
prompt = SoccerBroadcastProfile().vlm_prompt(ctx)
template = get_profile("soccer_broadcast")["configs"]["escalate_vlm"]["vlm_prompt_template"]
prompt = build_vlm_prompt(ctx, template)
assert "brand" in prompt.lower()
assert "Emirates" in prompt
def test_soccer_aggregate_empty():
report = SoccerBroadcastProfile().aggregate([])
def test_aggregate_empty():
report = aggregate_detections([], "soccer_broadcast")
assert len(report.brands) == 0
assert len(report.timeline) == 0
def test_soccer_aggregate_groups():
def test_aggregate_groups():
detections = [
BrandDetection(brand="Nike", timestamp=1.0, duration=0.5, confidence=0.9, source="ocr"),
BrandDetection(brand="Nike", timestamp=2.0, duration=0.5, confidence=0.8, source="ocr"),
BrandDetection(brand="Adidas", timestamp=3.0, duration=0.5, confidence=0.7, source="logo_match"),
]
report = SoccerBroadcastProfile().aggregate(detections)
report = aggregate_detections(detections, "soccer_broadcast")
assert "Nike" in report.brands
assert "Adidas" in report.brands
assert report.brands["Nike"].total_appearances == 2
@@ -57,15 +72,9 @@ def test_soccer_aggregate_groups():
assert report.timeline == sorted(report.timeline, key=lambda d: d.timestamp)
def test_soccer_auxiliary_returns_empty():
assert SoccerBroadcastProfile().auxiliary_detections("test.mp4") == []
@pytest.mark.parametrize("stub_cls", [NewsBroadcastProfile, AdvertisingProfile, TranscriptProfile])
def test_stubs_raise(stub_cls):
stub = stub_cls()
assert isinstance(stub.name, str)
with pytest.raises(NotImplementedError):
stub.frame_extraction_config()
with pytest.raises(NotImplementedError):
stub.resolver_config()
def test_pipeline_config():
profile = get_profile("soccer_broadcast")
config = pipeline_config_from_dict(profile["pipeline"])
assert config.name == "soccer_broadcast"
assert len(config.stages) > 0
assert len(config.edges) > 0

View File

@@ -6,14 +6,14 @@ from pathlib import Path
import numpy as np
import pytest
from detect.models import BoundingBox, Frame
from detect.profiles.base import RegionAnalysisConfig
from detect.profiles.soccer import SoccerBroadcastProfile
from core.detect.models import BoundingBox, Frame
from core.detect.stages.models import RegionAnalysisConfig
from core.detect.profile import get_profile, get_stage_config
# Load edges module directly — gpu/models/__init__.py has GPU-only imports
_spec = importlib.util.spec_from_file_location(
"cv_edges", Path("gpu/models/cv/edges.py"),
"cv_edges", Path("core/gpu/models/cv/edges.py"),
)
_edges_mod = importlib.util.module_from_spec(_spec)
_spec.loader.exec_module(_edges_mod)
@@ -40,8 +40,8 @@ def _make_frame_with_lines(seq: int = 0) -> Frame:
# --- Config ---
def test_soccer_profile_has_region_analysis_config():
profile = SoccerBroadcastProfile()
config = profile.region_analysis_config()
profile = get_profile("soccer_broadcast")
config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
assert isinstance(config, RegionAnalysisConfig)
assert config.enabled is True
@@ -133,9 +133,9 @@ def test_detect_edges_debug_blank_frame():
def test_stage_disabled(monkeypatch):
"""When disabled, returns empty dict."""
monkeypatch.setattr("detect.emit.push_detect_event", lambda *a, **kw: None)
monkeypatch.setattr("core.detect.emit.push_detect_event", lambda *a, **kw: None)
from detect.stages.edge_detector import detect_edge_regions
from core.detect.stages.edge_detector import detect_edge_regions
config = RegionAnalysisConfig(enabled=False)
result = detect_edge_regions([_make_frame()], config, job_id="test")
@@ -144,9 +144,9 @@ def test_stage_disabled(monkeypatch):
def test_stage_local_blank(monkeypatch):
"""Local mode on blank frames returns empty boxes."""
monkeypatch.setattr("detect.emit.push_detect_event", lambda *a, **kw: None)
monkeypatch.setattr("core.detect.emit.push_detect_event", lambda *a, **kw: None)
from detect.stages.edge_detector import detect_edge_regions
from core.detect.stages.edge_detector import detect_edge_regions
config = RegionAnalysisConfig()
result = detect_edge_regions([_make_frame()], config, job_id="test")
@@ -156,9 +156,9 @@ def test_stage_local_blank(monkeypatch):
def test_stage_local_with_lines(monkeypatch):
"""Local mode on frame with lines should find regions."""
monkeypatch.setattr("detect.emit.push_detect_event", lambda *a, **kw: None)
monkeypatch.setattr("core.detect.emit.push_detect_event", lambda *a, **kw: None)
from detect.stages.edge_detector import detect_edge_regions
from core.detect.stages.edge_detector import detect_edge_regions
config = RegionAnalysisConfig()
frame = _make_frame_with_lines()
@@ -174,22 +174,22 @@ def test_stage_local_with_lines(monkeypatch):
def test_detect_edges_in_nodes():
"""detect_edges must be in the pipeline node list."""
from detect.graph import NODES, NODE_FUNCTIONS
from core.detect.graph import NODES, NODE_FUNCTIONS
assert "detect_edges" in NODES
node_names = [name for name, _ in NODE_FUNCTIONS]
assert "detect_edges" in node_names
# Must be after filter_scenes, before detect_objects
# Must be after field_segmentation, before detect_objects
idx = NODES.index("detect_edges")
assert NODES[idx - 1] == "filter_scenes"
assert NODES[idx - 1] == "field_segmentation"
assert NODES[idx + 1] == "detect_objects"
# --- State ---
def test_state_has_edge_regions_field():
from detect.state import DetectState
from core.detect.state import DetectState
hints = DetectState.__annotations__
assert "edge_regions_by_frame" in hints

View File

@@ -1,87 +1,67 @@
"""Tests for replay and OverrideProfile."""
"""Tests for config overrides and replay."""
import pytest
from detect.profiles.soccer import SoccerBroadcastProfile
from detect.profiles.base import RegionAnalysisConfig
from detect.checkpoint.replay import OverrideProfile, replay_single_stage
from core.detect.profile import get_profile, get_stage_config
from core.detect.stages.models import RegionAnalysisConfig, OCRConfig, ResolverConfig
from core.detect.checkpoint.replay import replay_single_stage
def test_override_profile_patches_ocr():
base = SoccerBroadcastProfile()
overrides = {"ocr": {"min_confidence": 0.3, "languages": ["en", "es", "pt"]}}
profile = OverrideProfile(base, overrides)
def _apply_overrides(profile, overrides):
"""Apply config overrides to a profile dict (same logic as nodes._load_profile)."""
merged_configs = dict(profile.get("configs", {}))
for stage_name, stage_overrides in overrides.items():
if stage_name in merged_configs:
merged_configs[stage_name] = {**merged_configs[stage_name], **stage_overrides}
else:
merged_configs[stage_name] = stage_overrides
return {**profile, "configs": merged_configs}
config = profile.ocr_config()
def test_override_patches_ocr():
profile = get_profile("soccer_broadcast")
overrides = {"run_ocr": {"min_confidence": 0.3, "languages": ["en", "es", "pt"]}}
patched = _apply_overrides(profile, overrides)
config = OCRConfig(**get_stage_config(patched, "run_ocr"))
assert config.min_confidence == 0.3
assert config.languages == ["en", "es", "pt"]
def test_override_profile_patches_resolver():
base = SoccerBroadcastProfile()
overrides = {"resolver": {"fuzzy_threshold": 60}}
profile = OverrideProfile(base, overrides)
def test_override_patches_resolver():
profile = get_profile("soccer_broadcast")
overrides = {"match_brands": {"fuzzy_threshold": 60}}
patched = _apply_overrides(profile, overrides)
config = profile.resolver_config()
config = ResolverConfig(**get_stage_config(patched, "match_brands"))
assert config.fuzzy_threshold == 60
def test_override_profile_patches_detection():
base = SoccerBroadcastProfile()
overrides = {"detection": {"confidence_threshold": 0.5}}
profile = OverrideProfile(base, overrides)
def test_override_no_overrides():
profile = get_profile("soccer_broadcast")
patched = _apply_overrides(profile, {})
config = profile.detection_config()
assert config.confidence_threshold == 0.5
def test_override_profile_no_overrides():
base = SoccerBroadcastProfile()
profile = OverrideProfile(base, {})
ocr = profile.ocr_config()
base_ocr = base.ocr_config()
ocr = OCRConfig(**get_stage_config(patched, "run_ocr"))
base_ocr = OCRConfig(**get_stage_config(profile, "run_ocr"))
assert ocr.min_confidence == base_ocr.min_confidence
assert ocr.languages == base_ocr.languages
def test_override_profile_delegates_non_config():
base = SoccerBroadcastProfile()
profile = OverrideProfile(base, {"ocr": {"min_confidence": 0.1}})
def test_override_patches_region_analysis():
profile = get_profile("soccer_broadcast")
overrides = {"detect_edges": {"edge_canny_low": 25, "edge_canny_high": 200}}
patched = _apply_overrides(profile, overrides)
assert profile.name == "soccer_broadcast"
assert profile.resolver_config().fuzzy_threshold > 0
config = RegionAnalysisConfig(**get_stage_config(patched, "detect_edges"))
def test_override_profile_ignores_unknown_fields():
base = SoccerBroadcastProfile()
overrides = {"ocr": {"nonexistent_field": 42}}
profile = OverrideProfile(base, overrides)
config = profile.ocr_config()
assert not hasattr(config, "nonexistent_field")
assert config.min_confidence == base.ocr_config().min_confidence
# --- OverrideProfile for region_analysis ---
def test_override_profile_patches_region_analysis():
base = SoccerBroadcastProfile()
overrides = {"region_analysis": {"edge_canny_low": 25, "edge_canny_high": 200}}
profile = OverrideProfile(base, overrides)
config = profile.region_analysis_config()
assert isinstance(config, RegionAnalysisConfig)
assert config.edge_canny_low == 25
assert config.edge_canny_high == 200
# Unchanged fields keep defaults
assert config.edge_hough_threshold == base.region_analysis_config().edge_hough_threshold
# Unchanged fields keep defaults from profile
base_config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
assert config.edge_hough_threshold == base_config.edge_hough_threshold
# --- replay_single_stage ---

View File

@@ -3,9 +3,9 @@
import numpy as np
import pytest
from detect.models import Frame
from detect.profiles.base import SceneFilterConfig
from detect.stages.scene_filter import scene_filter
from core.detect.models import Frame
from core.detect.stages.models import SceneFilterConfig
from core.detect.stages.scene_filter import scene_filter
def _make_frame(seq: int, color: tuple[int, int, int] = (128, 128, 128)) -> Frame:
@@ -72,7 +72,7 @@ def test_hashes_populated():
def test_events_emitted(monkeypatch):
events = []
monkeypatch.setattr("detect.emit.push_detect_event",
monkeypatch.setattr("core.detect.emit.push_detect_event",
lambda job_id, etype, data: events.append((etype, data)))
frames = [_make_frame(i) for i in range(5)]

View File

@@ -1,6 +1,6 @@
"""Round-trip serialization tests for SSE contract models."""
from detect.sse_contract import (
from core.detect.sse import (
BoundingBoxEvent,
BrandSummary,
Detection,

View File

@@ -1,7 +1,7 @@
"""Tests for the stage registry."""
from detect.stages import list_stages, get_stage, get_palette
from detect.stages.base import get_stage_class
from core.detect.stages import list_stages, get_stage, get_palette
from core.detect.stages.base import get_stage_class
EXPECTED_STAGES = [

View File

@@ -2,7 +2,7 @@
import pytest
from detect.tracing import trace_node, SpanContext, flush
from core.detect.tracing import trace_node, SpanContext, flush
def test_trace_node_noop():

View File

@@ -3,8 +3,8 @@
import numpy as np
import pytest
from detect.models import BoundingBox, Frame, PipelineStats, TextCandidate
from detect.stages.vlm_cloud import escalate_cloud, _parse_response
from core.detect.models import BoundingBox, Frame, PipelineStats, TextCandidate
from core.detect.stages.vlm_cloud import escalate_cloud, _parse_response
def _make_candidate(text: str = "unknown", confidence: float = 0.4) -> TextCandidate:
@@ -30,14 +30,14 @@ def test_parse_response_no_confidence():
def test_escalate_skips_without_api_key(monkeypatch):
events = []
monkeypatch.setattr("detect.emit.push_detect_event",
monkeypatch.setattr("core.detect.emit.push_detect_event",
lambda job_id, etype, data: events.append((etype, data)))
monkeypatch.delenv("GROQ_API_KEY", raising=False)
monkeypatch.delenv("GEMINI_API_KEY", raising=False)
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
monkeypatch.setenv("CLOUD_LLM_PROVIDER", "groq")
# Reset cached provider
import detect.providers as prov
import core.detect.providers as prov
monkeypatch.setattr(prov, "_cached", None)
candidates = [_make_candidate()]
@@ -54,7 +54,7 @@ def test_escalate_skips_without_api_key(monkeypatch):
def test_escalate_empty_candidates(monkeypatch):
events = []
monkeypatch.setattr("detect.emit.push_detect_event",
monkeypatch.setattr("core.detect.emit.push_detect_event",
lambda job_id, etype, data: events.append((etype, data)))
stats = PipelineStats()
@@ -66,18 +66,18 @@ def test_escalate_empty_candidates(monkeypatch):
def test_escalate_with_mock_api(monkeypatch):
events = []
monkeypatch.setattr("detect.emit.push_detect_event",
monkeypatch.setattr("core.detect.emit.push_detect_event",
lambda job_id, etype, data: events.append((etype, data)))
monkeypatch.setenv("GROQ_API_KEY", "test-key")
monkeypatch.setenv("CLOUD_LLM_PROVIDER", "groq")
# Reset cached provider
import detect.providers as prov
import core.detect.providers as prov
monkeypatch.setattr(prov, "_cached", None)
def mock_call(image_b64, prompt):
return {"brand": "Heineken", "confidence": 0.75, "reasoning": "green logo", "tokens": 300}
monkeypatch.setattr("detect.stages.vlm_cloud._call_cloud_api", mock_call)
monkeypatch.setattr("core.detect.stages.vlm_cloud._call_cloud_api", mock_call)
candidates = [_make_candidate("unknown logo")]
stats = PipelineStats()