phase 4
This commit is contained in:
@@ -24,9 +24,9 @@ logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(m
|
||||
|
||||
sys.path.insert(0, ".")
|
||||
|
||||
from detect.profiles.soccer import SoccerBroadcastProfile
|
||||
from detect.stages.frame_extractor import extract_frames
|
||||
from detect.stages.scene_filter import scene_filter
|
||||
from core.detect.profile import get_profile
|
||||
from core.detect.stages.frame_extractor import extract_frames
|
||||
from core.detect.stages.scene_filter import scene_filter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -24,8 +24,8 @@ logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(m
|
||||
|
||||
sys.path.insert(0, ".")
|
||||
|
||||
from detect.graph import get_pipeline
|
||||
from detect.state import DetectState
|
||||
from core.detect.graph import get_pipeline
|
||||
from core.detect.state import DetectState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -39,13 +39,13 @@ sys.path.insert(0, ".")
|
||||
|
||||
from langgraph.graph import END, StateGraph
|
||||
|
||||
from detect import emit
|
||||
from detect.models import PipelineStats
|
||||
from detect.profiles.soccer import SoccerBroadcastProfile
|
||||
from detect.stages.frame_extractor import extract_frames
|
||||
from detect.stages.scene_filter import scene_filter
|
||||
from detect.stages.edge_detector import detect_edge_regions
|
||||
from detect.state import DetectState
|
||||
from core.detect import emit
|
||||
from core.detect.models import PipelineStats
|
||||
from core.detect.profile import get_profile
|
||||
from core.detect.stages.frame_extractor import extract_frames
|
||||
from core.detect.stages.scene_filter import scene_filter
|
||||
from core.detect.stages.edge_detector import detect_edge_regions
|
||||
from core.detect.state import DetectState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -166,7 +166,7 @@ def main():
|
||||
# --- Parameter sensitivity ---
|
||||
logger.info("=== Parameter sensitivity (local debug) ===")
|
||||
|
||||
from detect.stages.edge_detector import _load_cv_edges
|
||||
from core.detect.stages.edge_detector import _load_cv_edges
|
||||
edges_mod = _load_cv_edges()
|
||||
|
||||
filtered = result.get("filtered_frames", [])
|
||||
|
||||
@@ -58,7 +58,7 @@ def extract_frames_ffmpeg(video_path: str, fps: float, max_frames: int):
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from detect.models import Frame
|
||||
from core.detect.models import Frame
|
||||
|
||||
tmpdir = tempfile.mkdtemp(prefix="scenario_")
|
||||
pattern = os.path.join(tmpdir, "frame_%04d.jpg")
|
||||
@@ -111,7 +111,7 @@ def main():
|
||||
logger.info("Extracted %d frames", len(frames))
|
||||
|
||||
# Create timeline + branch + checkpoint
|
||||
from detect.checkpoint.storage import create_timeline, save_stage_output
|
||||
from core.detect.checkpoint.storage import create_timeline, save_stage_output
|
||||
|
||||
timeline_id, branch_id = create_timeline(
|
||||
source_video=video_path,
|
||||
|
||||
@@ -58,7 +58,7 @@ def make_brand_image(text: str, width: int = 300, height: int = 100) -> str:
|
||||
|
||||
|
||||
def main():
|
||||
from detect.providers import get_provider, has_api_key, PROVIDERS
|
||||
from core.detect.providers import get_provider, has_api_key, PROVIDERS
|
||||
|
||||
provider_name = os.environ.get("CLOUD_LLM_PROVIDER", "groq")
|
||||
logger.info("Provider: %s", provider_name)
|
||||
|
||||
@@ -13,8 +13,8 @@ import sys
|
||||
|
||||
sys.path.insert(0, ".")
|
||||
|
||||
from detect.profiles.soccer import SoccerBroadcastProfile
|
||||
from detect.stages.frame_extractor import extract_frames
|
||||
from core.detect.profile import get_profile
|
||||
from core.detect.stages.frame_extractor import extract_frames
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -86,9 +86,9 @@ def test_ocr_stage_remote(url: str):
|
||||
logger.info("--- OCR stage (remote mode) ---")
|
||||
|
||||
sys.path.insert(0, ".")
|
||||
from detect.models import BoundingBox, Frame
|
||||
from detect.profiles.base import OCRConfig
|
||||
from detect.stages.ocr_stage import run_ocr
|
||||
from core.detect.models import BoundingBox, Frame
|
||||
from core.detect.stages.models import OCRConfig
|
||||
from core.detect.stages.ocr_stage import run_ocr
|
||||
|
||||
# Create a frame with text baked in
|
||||
image = make_text_image("EMIRATES")
|
||||
|
||||
@@ -48,10 +48,10 @@ def main():
|
||||
# Override Redis to localhost (ctrl/.env has k8s hostname)
|
||||
os.environ["REDIS_URL"] = f"redis://localhost:{args.port}/0"
|
||||
|
||||
from detect.graph import get_pipeline, NODES
|
||||
from detect.checkpoint import list_checkpoints
|
||||
from detect.checkpoint import replay_from
|
||||
from detect.state import DetectState
|
||||
from core.detect.graph import get_pipeline, NODES
|
||||
from core.detect.checkpoint import list_checkpoints
|
||||
from core.detect.checkpoint import replay_from
|
||||
from core.detect.state import DetectState
|
||||
|
||||
VIDEO = "media/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e/chunk_0000.mp4"
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from detect.models import BoundingBox, BrandDetection, PipelineStats
|
||||
from detect.stages.aggregator import compile_report, _merge_contiguous
|
||||
from core.detect.models import BoundingBox, BrandDetection, PipelineStats
|
||||
from core.detect.stages.aggregator import compile_report, _merge_contiguous
|
||||
|
||||
|
||||
def _make_detection(brand: str, timestamp: float, duration: float = 0.5,
|
||||
@@ -43,7 +43,7 @@ def test_merge_empty():
|
||||
|
||||
def test_compile_report(monkeypatch):
|
||||
events = []
|
||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||
monkeypatch.setattr("core.detect.emit.push_detect_event",
|
||||
lambda job_id, etype, data: events.append((etype, data)))
|
||||
|
||||
dets = [
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from detect.models import BoundingBox, Frame, TextCandidate
|
||||
from detect.profiles.base import ResolverConfig
|
||||
from detect.stages.brand_resolver import resolve_brands, _normalize, _match_session
|
||||
from core.detect.models import BoundingBox, Frame, TextCandidate
|
||||
from core.detect.stages.models import ResolverConfig
|
||||
from core.detect.stages.brand_resolver import resolve_brands, _normalize, _match_session
|
||||
|
||||
|
||||
CONFIG = ResolverConfig(fuzzy_threshold=75)
|
||||
@@ -28,7 +28,7 @@ def test_session_match():
|
||||
|
||||
def test_resolve_with_session(monkeypatch):
|
||||
events = []
|
||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||
monkeypatch.setattr("core.detect.emit.push_detect_event",
|
||||
lambda job_id, etype, data: events.append((etype, data)))
|
||||
|
||||
session = {"nike": "Nike", "emirates": "Emirates"}
|
||||
@@ -46,7 +46,7 @@ def test_resolve_with_session(monkeypatch):
|
||||
|
||||
def test_resolve_unresolved_without_db(monkeypatch):
|
||||
events = []
|
||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||
monkeypatch.setattr("core.detect.emit.push_detect_event",
|
||||
lambda job_id, etype, data: events.append((etype, data)))
|
||||
|
||||
candidates = [_make_candidate("random garbage text")]
|
||||
@@ -61,7 +61,7 @@ def test_resolve_unresolved_without_db(monkeypatch):
|
||||
|
||||
def test_resolve_empty(monkeypatch):
|
||||
events = []
|
||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||
monkeypatch.setattr("core.detect.emit.push_detect_event",
|
||||
lambda job_id, etype, data: events.append((etype, data)))
|
||||
|
||||
matched, unresolved = resolve_brands([], CONFIG, session_brands={})
|
||||
@@ -73,7 +73,7 @@ def test_resolve_empty(monkeypatch):
|
||||
def test_resolve_builds_session_during_run(monkeypatch):
|
||||
"""Session brands accumulate during a single run — second candidate benefits."""
|
||||
events = []
|
||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||
monkeypatch.setattr("core.detect.emit.push_detect_event",
|
||||
lambda job_id, etype, data: events.append((etype, data)))
|
||||
|
||||
session = {"nike": "Nike"}
|
||||
@@ -93,7 +93,7 @@ def test_resolve_builds_session_during_run(monkeypatch):
|
||||
|
||||
def test_events_emitted(monkeypatch):
|
||||
events = []
|
||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||
monkeypatch.setattr("core.detect.emit.push_detect_event",
|
||||
lambda job_id, etype, data: events.append((etype, data)))
|
||||
|
||||
session = {"nike": "Nike"}
|
||||
|
||||
@@ -5,7 +5,7 @@ import json
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate
|
||||
from core.detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate
|
||||
from core.schema.serializers._common import safe_construct
|
||||
from core.schema.serializers.pipeline import (
|
||||
serialize_frame_meta,
|
||||
@@ -163,34 +163,39 @@ def test_all_serialized_is_json_compatible():
|
||||
assert roundtrip["frame_meta"]["sequence"] == frame.sequence
|
||||
|
||||
|
||||
# --- OverrideProfile ---
|
||||
# --- Config overrides (dict merging, replaces OverrideProfile) ---
|
||||
|
||||
def test_override_profile_region_analysis():
|
||||
"""OverrideProfile must patch region_analysis_config with overrides."""
|
||||
from detect.checkpoint.replay import OverrideProfile
|
||||
from detect.profiles.soccer import SoccerBroadcastProfile
|
||||
from detect.profiles.base import RegionAnalysisConfig
|
||||
def test_config_override_region_analysis():
|
||||
"""Config overrides must patch stage config values."""
|
||||
from core.detect.profile import get_profile, get_stage_config
|
||||
from core.detect.stages.models import RegionAnalysisConfig
|
||||
|
||||
base = SoccerBroadcastProfile()
|
||||
original = base.region_analysis_config()
|
||||
profile = get_profile("soccer_broadcast")
|
||||
original = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
|
||||
|
||||
overrides = {"region_analysis": {"edge_canny_low": 25, "edge_canny_high": 200}}
|
||||
wrapped = OverrideProfile(base, overrides)
|
||||
patched = wrapped.region_analysis_config()
|
||||
overrides = {"detect_edges": {"edge_canny_low": 25, "edge_canny_high": 200}}
|
||||
merged_configs = {**profile["configs"]}
|
||||
merged_configs["detect_edges"] = {**merged_configs["detect_edges"], **overrides["detect_edges"]}
|
||||
patched_profile = {**profile, "configs": merged_configs}
|
||||
|
||||
patched = RegionAnalysisConfig(**get_stage_config(patched_profile, "detect_edges"))
|
||||
|
||||
assert isinstance(patched, RegionAnalysisConfig)
|
||||
assert patched.edge_canny_low == 25
|
||||
assert patched.edge_canny_high == 200
|
||||
# Unmodified fields keep their defaults
|
||||
assert patched.edge_hough_threshold == original.edge_hough_threshold
|
||||
|
||||
|
||||
def test_override_profile_passthrough():
|
||||
"""OverrideProfile without region_analysis key passes through unchanged."""
|
||||
from detect.checkpoint.replay import OverrideProfile
|
||||
from detect.profiles.soccer import SoccerBroadcastProfile
|
||||
def test_config_override_passthrough():
|
||||
"""Overrides for other stages don't affect unrelated stages."""
|
||||
from core.detect.profile import get_profile, get_stage_config
|
||||
from core.detect.stages.models import RegionAnalysisConfig
|
||||
|
||||
base = SoccerBroadcastProfile()
|
||||
wrapped = OverrideProfile(base, {"ocr": {"min_confidence": 0.1}})
|
||||
config = wrapped.region_analysis_config()
|
||||
assert config.edge_canny_low == base.region_analysis_config().edge_canny_low
|
||||
profile = get_profile("soccer_broadcast")
|
||||
original = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
|
||||
|
||||
overrides = {"run_ocr": {"min_confidence": 0.1}}
|
||||
merged_configs = {**profile["configs"], **overrides}
|
||||
patched_profile = {**profile, "configs": merged_configs}
|
||||
|
||||
patched = RegionAnalysisConfig(**get_stage_config(patched_profile, "detect_edges"))
|
||||
assert patched.edge_canny_low == original.edge_canny_low
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Tests for the config endpoint and stage palette."""
|
||||
|
||||
from detect.stages import list_stages, get_palette
|
||||
from core.detect.stages import list_stages, get_palette
|
||||
|
||||
|
||||
def test_stage_palette_has_config_fields():
|
||||
|
||||
@@ -15,7 +15,7 @@ import pytest
|
||||
|
||||
# Load edges module directly
|
||||
_spec = importlib.util.spec_from_file_location(
|
||||
"cv_edges", Path("gpu/models/cv/edges.py"),
|
||||
"cv_edges", Path("core/gpu/models/cv/edges.py"),
|
||||
)
|
||||
_edges_mod = importlib.util.module_from_spec(_spec)
|
||||
_spec.loader.exec_module(_edges_mod)
|
||||
|
||||
@@ -5,8 +5,8 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from detect.profiles.base import FrameExtractionConfig
|
||||
from detect.stages.frame_extractor import extract_frames
|
||||
from core.detect.stages.models import FrameExtractionConfig
|
||||
from core.detect.stages.frame_extractor import extract_frames
|
||||
|
||||
SAMPLE_DIR = Path("media/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e")
|
||||
|
||||
@@ -61,7 +61,7 @@ def test_extract_frames_with_events(monkeypatch):
|
||||
def mock_push(job_id, event_type, data):
|
||||
events.append((job_id, event_type, data))
|
||||
|
||||
monkeypatch.setattr("detect.emit.push_detect_event", mock_push)
|
||||
monkeypatch.setattr("core.detect.emit.push_detect_event", mock_push)
|
||||
|
||||
video = _get_sample_video()
|
||||
config = FrameExtractionConfig(fps=1, max_frames=5)
|
||||
|
||||
@@ -4,9 +4,9 @@ import os
|
||||
|
||||
import pytest
|
||||
|
||||
from detect.graph import NODES, build_graph, get_pipeline
|
||||
from detect.models import PipelineStats
|
||||
from detect.state import DetectState
|
||||
from core.detect.graph import NODES, build_graph, get_pipeline
|
||||
from core.detect.models import PipelineStats
|
||||
from core.detect.state import DetectState
|
||||
|
||||
VIDEO = "media/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e/chunk_0000.mp4"
|
||||
|
||||
@@ -42,7 +42,7 @@ def test_graph_has_all_nodes():
|
||||
def test_graph_runs_end_to_end(monkeypatch):
|
||||
"""Run the full graph with mocked event emission."""
|
||||
events = []
|
||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||
monkeypatch.setattr("core.detect.emit.push_detect_event",
|
||||
lambda job_id, etype, data: events.append((etype, data)))
|
||||
|
||||
pipeline = get_pipeline()
|
||||
@@ -75,7 +75,7 @@ def test_graph_runs_end_to_end(monkeypatch):
|
||||
def test_graph_node_transitions(monkeypatch):
|
||||
"""Verify each node emits running → done transitions."""
|
||||
events = []
|
||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||
monkeypatch.setattr("core.detect.emit.push_detect_event",
|
||||
lambda job_id, etype, data: events.append((etype, data)))
|
||||
|
||||
pipeline = get_pipeline()
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from detect.models import BoundingBox, Frame
|
||||
from detect.profiles.base import OCRConfig
|
||||
from detect.stages.ocr_stage import _crop_region, _parse_ocr_raw, run_ocr
|
||||
from core.detect.models import BoundingBox, Frame
|
||||
from core.detect.stages.models import OCRConfig
|
||||
from core.detect.stages.ocr_stage import _crop_region, _parse_ocr_raw, run_ocr
|
||||
|
||||
|
||||
def _has_paddleocr() -> bool:
|
||||
@@ -80,7 +80,7 @@ def test_parse_empty():
|
||||
|
||||
def test_run_ocr_remote(monkeypatch):
|
||||
events = []
|
||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||
monkeypatch.setattr("core.detect.emit.push_detect_event",
|
||||
lambda job_id, etype, data: events.append((etype, data)))
|
||||
|
||||
class FakeResult:
|
||||
@@ -94,11 +94,11 @@ def test_run_ocr_remote(monkeypatch):
|
||||
def ocr(self, image, languages):
|
||||
return [FakeResult("NIKE", 0.92)]
|
||||
|
||||
monkeypatch.setattr("detect.stages.ocr_stage.InferenceClient", FakeClient,
|
||||
monkeypatch.setattr("core.detect.stages.ocr_stage.InferenceClient", FakeClient,
|
||||
raising=False)
|
||||
# Patch the import path used in the function
|
||||
import detect.stages.ocr_stage as mod
|
||||
monkeypatch.setattr("detect.inference.InferenceClient", FakeClient)
|
||||
import core.detect.stages.ocr_stage as mod
|
||||
monkeypatch.setattr("core.detect.inference.InferenceClient", FakeClient)
|
||||
|
||||
frame = _make_frame()
|
||||
box = _make_box()
|
||||
@@ -123,7 +123,7 @@ def test_run_ocr_remote(monkeypatch):
|
||||
)
|
||||
def test_run_ocr_skips_empty_crop(monkeypatch):
|
||||
events = []
|
||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||
monkeypatch.setattr("core.detect.emit.push_detect_event",
|
||||
lambda job_id, etype, data: events.append((etype, data)))
|
||||
|
||||
frame = _make_frame(w=10, h=10)
|
||||
|
||||
@@ -26,7 +26,7 @@ def _make_image(w: int = 200, h: int = 60) -> np.ndarray:
|
||||
|
||||
@requires_cv2
|
||||
def test_binarize():
|
||||
from gpu.models.preprocess import binarize
|
||||
from core.gpu.models.preprocess import binarize
|
||||
|
||||
img = _make_image()
|
||||
result = binarize(img)
|
||||
@@ -40,7 +40,7 @@ def test_binarize():
|
||||
|
||||
@requires_cv2
|
||||
def test_enhance_contrast():
|
||||
from gpu.models.preprocess import enhance_contrast
|
||||
from core.gpu.models.preprocess import enhance_contrast
|
||||
|
||||
img = _make_image()
|
||||
result = enhance_contrast(img)
|
||||
@@ -51,7 +51,7 @@ def test_enhance_contrast():
|
||||
|
||||
@requires_cv2
|
||||
def test_deskew_no_rotation():
|
||||
from gpu.models.preprocess import deskew
|
||||
from core.gpu.models.preprocess import deskew
|
||||
|
||||
img = _make_image()
|
||||
result = deskew(img)
|
||||
@@ -63,7 +63,7 @@ def test_deskew_no_rotation():
|
||||
|
||||
@requires_cv2
|
||||
def test_preprocess_pipeline():
|
||||
from gpu.models.preprocess import preprocess
|
||||
from core.gpu.models.preprocess import preprocess
|
||||
|
||||
img = _make_image()
|
||||
|
||||
@@ -76,7 +76,7 @@ def test_preprocess_pipeline():
|
||||
|
||||
@requires_cv2
|
||||
def test_preprocess_all_disabled():
|
||||
from gpu.models.preprocess import preprocess
|
||||
from core.gpu.models.preprocess import preprocess
|
||||
|
||||
img = _make_image()
|
||||
result = preprocess(img, do_binarize=False, do_deskew=False, do_contrast=False)
|
||||
|
||||
@@ -1,55 +1,70 @@
|
||||
"""Tests for ContentTypeProfile implementations."""
|
||||
"""Tests for profile data and helper functions."""
|
||||
|
||||
import pytest
|
||||
|
||||
from detect.models import BrandDetection
|
||||
from detect.profiles.base import ContentTypeProfile, CropContext
|
||||
from detect.profiles.soccer import SoccerBroadcastProfile
|
||||
from detect.profiles.stubs import AdvertisingProfile, NewsBroadcastProfile, TranscriptProfile
|
||||
from core.detect.models import BrandDetection, CropContext
|
||||
from core.detect.profile import get_profile, get_stage_config, build_vlm_prompt, aggregate_detections, pipeline_config_from_dict
|
||||
from core.detect.stages.models import FrameExtractionConfig, DetectionConfig, ResolverConfig
|
||||
|
||||
|
||||
def test_soccer_satisfies_protocol():
|
||||
profile: ContentTypeProfile = SoccerBroadcastProfile()
|
||||
assert profile.name == "soccer_broadcast"
|
||||
def test_soccer_profile_exists():
|
||||
profile = get_profile("soccer_broadcast")
|
||||
assert profile["name"] == "soccer_broadcast"
|
||||
|
||||
|
||||
def test_soccer_has_pipeline():
|
||||
profile = get_profile("soccer_broadcast")
|
||||
assert "stages" in profile["pipeline"]
|
||||
assert "edges" in profile["pipeline"]
|
||||
|
||||
|
||||
def test_soccer_has_configs():
|
||||
profile = get_profile("soccer_broadcast")
|
||||
configs = profile["configs"]
|
||||
assert "extract_frames" in configs
|
||||
assert "filter_scenes" in configs
|
||||
assert "detect_edges" in configs
|
||||
|
||||
|
||||
def test_soccer_frame_extraction_config():
|
||||
cfg = SoccerBroadcastProfile().frame_extraction_config()
|
||||
profile = get_profile("soccer_broadcast")
|
||||
cfg = FrameExtractionConfig(**get_stage_config(profile, "extract_frames"))
|
||||
assert cfg.fps > 0
|
||||
assert cfg.max_frames > 0
|
||||
|
||||
|
||||
def test_soccer_detection_config():
|
||||
cfg = SoccerBroadcastProfile().detection_config()
|
||||
profile = get_profile("soccer_broadcast")
|
||||
cfg = DetectionConfig(**get_stage_config(profile, "detect_objects"))
|
||||
assert 0 < cfg.confidence_threshold < 1
|
||||
assert isinstance(cfg.target_classes, list)
|
||||
|
||||
|
||||
def test_soccer_resolver_config():
|
||||
cfg = SoccerBroadcastProfile().resolver_config()
|
||||
profile = get_profile("soccer_broadcast")
|
||||
cfg = ResolverConfig(**get_stage_config(profile, "match_brands"))
|
||||
assert cfg.fuzzy_threshold > 0
|
||||
|
||||
|
||||
def test_soccer_vlm_prompt():
|
||||
def test_vlm_prompt():
|
||||
ctx = CropContext(image=b"fake", surrounding_text="Emirates", position_hint="top-center")
|
||||
prompt = SoccerBroadcastProfile().vlm_prompt(ctx)
|
||||
template = get_profile("soccer_broadcast")["configs"]["escalate_vlm"]["vlm_prompt_template"]
|
||||
prompt = build_vlm_prompt(ctx, template)
|
||||
assert "brand" in prompt.lower()
|
||||
assert "Emirates" in prompt
|
||||
|
||||
|
||||
def test_soccer_aggregate_empty():
|
||||
report = SoccerBroadcastProfile().aggregate([])
|
||||
def test_aggregate_empty():
|
||||
report = aggregate_detections([], "soccer_broadcast")
|
||||
assert len(report.brands) == 0
|
||||
assert len(report.timeline) == 0
|
||||
|
||||
|
||||
def test_soccer_aggregate_groups():
|
||||
def test_aggregate_groups():
|
||||
detections = [
|
||||
BrandDetection(brand="Nike", timestamp=1.0, duration=0.5, confidence=0.9, source="ocr"),
|
||||
BrandDetection(brand="Nike", timestamp=2.0, duration=0.5, confidence=0.8, source="ocr"),
|
||||
BrandDetection(brand="Adidas", timestamp=3.0, duration=0.5, confidence=0.7, source="logo_match"),
|
||||
]
|
||||
report = SoccerBroadcastProfile().aggregate(detections)
|
||||
report = aggregate_detections(detections, "soccer_broadcast")
|
||||
assert "Nike" in report.brands
|
||||
assert "Adidas" in report.brands
|
||||
assert report.brands["Nike"].total_appearances == 2
|
||||
@@ -57,15 +72,9 @@ def test_soccer_aggregate_groups():
|
||||
assert report.timeline == sorted(report.timeline, key=lambda d: d.timestamp)
|
||||
|
||||
|
||||
def test_soccer_auxiliary_returns_empty():
|
||||
assert SoccerBroadcastProfile().auxiliary_detections("test.mp4") == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize("stub_cls", [NewsBroadcastProfile, AdvertisingProfile, TranscriptProfile])
|
||||
def test_stubs_raise(stub_cls):
|
||||
stub = stub_cls()
|
||||
assert isinstance(stub.name, str)
|
||||
with pytest.raises(NotImplementedError):
|
||||
stub.frame_extraction_config()
|
||||
with pytest.raises(NotImplementedError):
|
||||
stub.resolver_config()
|
||||
def test_pipeline_config():
|
||||
profile = get_profile("soccer_broadcast")
|
||||
config = pipeline_config_from_dict(profile["pipeline"])
|
||||
assert config.name == "soccer_broadcast"
|
||||
assert len(config.stages) > 0
|
||||
assert len(config.edges) > 0
|
||||
|
||||
@@ -6,14 +6,14 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from detect.models import BoundingBox, Frame
|
||||
from detect.profiles.base import RegionAnalysisConfig
|
||||
from detect.profiles.soccer import SoccerBroadcastProfile
|
||||
from core.detect.models import BoundingBox, Frame
|
||||
from core.detect.stages.models import RegionAnalysisConfig
|
||||
from core.detect.profile import get_profile, get_stage_config
|
||||
|
||||
|
||||
# Load edges module directly — gpu/models/__init__.py has GPU-only imports
|
||||
_spec = importlib.util.spec_from_file_location(
|
||||
"cv_edges", Path("gpu/models/cv/edges.py"),
|
||||
"cv_edges", Path("core/gpu/models/cv/edges.py"),
|
||||
)
|
||||
_edges_mod = importlib.util.module_from_spec(_spec)
|
||||
_spec.loader.exec_module(_edges_mod)
|
||||
@@ -40,8 +40,8 @@ def _make_frame_with_lines(seq: int = 0) -> Frame:
|
||||
# --- Config ---
|
||||
|
||||
def test_soccer_profile_has_region_analysis_config():
|
||||
profile = SoccerBroadcastProfile()
|
||||
config = profile.region_analysis_config()
|
||||
profile = get_profile("soccer_broadcast")
|
||||
config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
|
||||
assert isinstance(config, RegionAnalysisConfig)
|
||||
assert config.enabled is True
|
||||
|
||||
@@ -133,9 +133,9 @@ def test_detect_edges_debug_blank_frame():
|
||||
|
||||
def test_stage_disabled(monkeypatch):
|
||||
"""When disabled, returns empty dict."""
|
||||
monkeypatch.setattr("detect.emit.push_detect_event", lambda *a, **kw: None)
|
||||
monkeypatch.setattr("core.detect.emit.push_detect_event", lambda *a, **kw: None)
|
||||
|
||||
from detect.stages.edge_detector import detect_edge_regions
|
||||
from core.detect.stages.edge_detector import detect_edge_regions
|
||||
|
||||
config = RegionAnalysisConfig(enabled=False)
|
||||
result = detect_edge_regions([_make_frame()], config, job_id="test")
|
||||
@@ -144,9 +144,9 @@ def test_stage_disabled(monkeypatch):
|
||||
|
||||
def test_stage_local_blank(monkeypatch):
|
||||
"""Local mode on blank frames returns empty boxes."""
|
||||
monkeypatch.setattr("detect.emit.push_detect_event", lambda *a, **kw: None)
|
||||
monkeypatch.setattr("core.detect.emit.push_detect_event", lambda *a, **kw: None)
|
||||
|
||||
from detect.stages.edge_detector import detect_edge_regions
|
||||
from core.detect.stages.edge_detector import detect_edge_regions
|
||||
|
||||
config = RegionAnalysisConfig()
|
||||
result = detect_edge_regions([_make_frame()], config, job_id="test")
|
||||
@@ -156,9 +156,9 @@ def test_stage_local_blank(monkeypatch):
|
||||
|
||||
def test_stage_local_with_lines(monkeypatch):
|
||||
"""Local mode on frame with lines should find regions."""
|
||||
monkeypatch.setattr("detect.emit.push_detect_event", lambda *a, **kw: None)
|
||||
monkeypatch.setattr("core.detect.emit.push_detect_event", lambda *a, **kw: None)
|
||||
|
||||
from detect.stages.edge_detector import detect_edge_regions
|
||||
from core.detect.stages.edge_detector import detect_edge_regions
|
||||
|
||||
config = RegionAnalysisConfig()
|
||||
frame = _make_frame_with_lines()
|
||||
@@ -174,22 +174,22 @@ def test_stage_local_with_lines(monkeypatch):
|
||||
|
||||
def test_detect_edges_in_nodes():
|
||||
"""detect_edges must be in the pipeline node list."""
|
||||
from detect.graph import NODES, NODE_FUNCTIONS
|
||||
from core.detect.graph import NODES, NODE_FUNCTIONS
|
||||
|
||||
assert "detect_edges" in NODES
|
||||
node_names = [name for name, _ in NODE_FUNCTIONS]
|
||||
assert "detect_edges" in node_names
|
||||
|
||||
# Must be after filter_scenes, before detect_objects
|
||||
# Must be after field_segmentation, before detect_objects
|
||||
idx = NODES.index("detect_edges")
|
||||
assert NODES[idx - 1] == "filter_scenes"
|
||||
assert NODES[idx - 1] == "field_segmentation"
|
||||
assert NODES[idx + 1] == "detect_objects"
|
||||
|
||||
|
||||
# --- State ---
|
||||
|
||||
def test_state_has_edge_regions_field():
|
||||
from detect.state import DetectState
|
||||
from core.detect.state import DetectState
|
||||
|
||||
hints = DetectState.__annotations__
|
||||
assert "edge_regions_by_frame" in hints
|
||||
|
||||
@@ -1,87 +1,67 @@
|
||||
"""Tests for replay and OverrideProfile."""
|
||||
"""Tests for config overrides and replay."""
|
||||
|
||||
import pytest
|
||||
|
||||
from detect.profiles.soccer import SoccerBroadcastProfile
|
||||
from detect.profiles.base import RegionAnalysisConfig
|
||||
from detect.checkpoint.replay import OverrideProfile, replay_single_stage
|
||||
from core.detect.profile import get_profile, get_stage_config
|
||||
from core.detect.stages.models import RegionAnalysisConfig, OCRConfig, ResolverConfig
|
||||
from core.detect.checkpoint.replay import replay_single_stage
|
||||
|
||||
|
||||
def test_override_profile_patches_ocr():
|
||||
base = SoccerBroadcastProfile()
|
||||
overrides = {"ocr": {"min_confidence": 0.3, "languages": ["en", "es", "pt"]}}
|
||||
profile = OverrideProfile(base, overrides)
|
||||
def _apply_overrides(profile, overrides):
|
||||
"""Apply config overrides to a profile dict (same logic as nodes._load_profile)."""
|
||||
merged_configs = dict(profile.get("configs", {}))
|
||||
for stage_name, stage_overrides in overrides.items():
|
||||
if stage_name in merged_configs:
|
||||
merged_configs[stage_name] = {**merged_configs[stage_name], **stage_overrides}
|
||||
else:
|
||||
merged_configs[stage_name] = stage_overrides
|
||||
return {**profile, "configs": merged_configs}
|
||||
|
||||
config = profile.ocr_config()
|
||||
|
||||
def test_override_patches_ocr():
|
||||
profile = get_profile("soccer_broadcast")
|
||||
overrides = {"run_ocr": {"min_confidence": 0.3, "languages": ["en", "es", "pt"]}}
|
||||
patched = _apply_overrides(profile, overrides)
|
||||
|
||||
config = OCRConfig(**get_stage_config(patched, "run_ocr"))
|
||||
|
||||
assert config.min_confidence == 0.3
|
||||
assert config.languages == ["en", "es", "pt"]
|
||||
|
||||
|
||||
def test_override_profile_patches_resolver():
|
||||
base = SoccerBroadcastProfile()
|
||||
overrides = {"resolver": {"fuzzy_threshold": 60}}
|
||||
profile = OverrideProfile(base, overrides)
|
||||
def test_override_patches_resolver():
|
||||
profile = get_profile("soccer_broadcast")
|
||||
overrides = {"match_brands": {"fuzzy_threshold": 60}}
|
||||
patched = _apply_overrides(profile, overrides)
|
||||
|
||||
config = profile.resolver_config()
|
||||
config = ResolverConfig(**get_stage_config(patched, "match_brands"))
|
||||
|
||||
assert config.fuzzy_threshold == 60
|
||||
|
||||
|
||||
def test_override_profile_patches_detection():
|
||||
base = SoccerBroadcastProfile()
|
||||
overrides = {"detection": {"confidence_threshold": 0.5}}
|
||||
profile = OverrideProfile(base, overrides)
|
||||
def test_override_no_overrides():
|
||||
profile = get_profile("soccer_broadcast")
|
||||
patched = _apply_overrides(profile, {})
|
||||
|
||||
config = profile.detection_config()
|
||||
|
||||
assert config.confidence_threshold == 0.5
|
||||
|
||||
|
||||
def test_override_profile_no_overrides():
|
||||
base = SoccerBroadcastProfile()
|
||||
profile = OverrideProfile(base, {})
|
||||
|
||||
ocr = profile.ocr_config()
|
||||
base_ocr = base.ocr_config()
|
||||
ocr = OCRConfig(**get_stage_config(patched, "run_ocr"))
|
||||
base_ocr = OCRConfig(**get_stage_config(profile, "run_ocr"))
|
||||
|
||||
assert ocr.min_confidence == base_ocr.min_confidence
|
||||
assert ocr.languages == base_ocr.languages
|
||||
|
||||
|
||||
def test_override_profile_delegates_non_config():
|
||||
base = SoccerBroadcastProfile()
|
||||
profile = OverrideProfile(base, {"ocr": {"min_confidence": 0.1}})
|
||||
def test_override_patches_region_analysis():
|
||||
profile = get_profile("soccer_broadcast")
|
||||
overrides = {"detect_edges": {"edge_canny_low": 25, "edge_canny_high": 200}}
|
||||
patched = _apply_overrides(profile, overrides)
|
||||
|
||||
assert profile.name == "soccer_broadcast"
|
||||
assert profile.resolver_config().fuzzy_threshold > 0
|
||||
config = RegionAnalysisConfig(**get_stage_config(patched, "detect_edges"))
|
||||
|
||||
|
||||
def test_override_profile_ignores_unknown_fields():
|
||||
base = SoccerBroadcastProfile()
|
||||
overrides = {"ocr": {"nonexistent_field": 42}}
|
||||
profile = OverrideProfile(base, overrides)
|
||||
|
||||
config = profile.ocr_config()
|
||||
|
||||
assert not hasattr(config, "nonexistent_field")
|
||||
assert config.min_confidence == base.ocr_config().min_confidence
|
||||
|
||||
|
||||
# --- OverrideProfile for region_analysis ---
|
||||
|
||||
def test_override_profile_patches_region_analysis():
|
||||
base = SoccerBroadcastProfile()
|
||||
overrides = {"region_analysis": {"edge_canny_low": 25, "edge_canny_high": 200}}
|
||||
profile = OverrideProfile(base, overrides)
|
||||
|
||||
config = profile.region_analysis_config()
|
||||
|
||||
assert isinstance(config, RegionAnalysisConfig)
|
||||
assert config.edge_canny_low == 25
|
||||
assert config.edge_canny_high == 200
|
||||
# Unchanged fields keep defaults
|
||||
assert config.edge_hough_threshold == base.region_analysis_config().edge_hough_threshold
|
||||
# Unchanged fields keep defaults from profile
|
||||
base_config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
|
||||
assert config.edge_hough_threshold == base_config.edge_hough_threshold
|
||||
|
||||
|
||||
# --- replay_single_stage ---
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from detect.models import Frame
|
||||
from detect.profiles.base import SceneFilterConfig
|
||||
from detect.stages.scene_filter import scene_filter
|
||||
from core.detect.models import Frame
|
||||
from core.detect.stages.models import SceneFilterConfig
|
||||
from core.detect.stages.scene_filter import scene_filter
|
||||
|
||||
|
||||
def _make_frame(seq: int, color: tuple[int, int, int] = (128, 128, 128)) -> Frame:
|
||||
@@ -72,7 +72,7 @@ def test_hashes_populated():
|
||||
|
||||
def test_events_emitted(monkeypatch):
|
||||
events = []
|
||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||
monkeypatch.setattr("core.detect.emit.push_detect_event",
|
||||
lambda job_id, etype, data: events.append((etype, data)))
|
||||
|
||||
frames = [_make_frame(i) for i in range(5)]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Round-trip serialization tests for SSE contract models."""
|
||||
|
||||
from detect.sse_contract import (
|
||||
from core.detect.sse import (
|
||||
BoundingBoxEvent,
|
||||
BrandSummary,
|
||||
Detection,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Tests for the stage registry."""
|
||||
|
||||
from detect.stages import list_stages, get_stage, get_palette
|
||||
from detect.stages.base import get_stage_class
|
||||
from core.detect.stages import list_stages, get_stage, get_palette
|
||||
from core.detect.stages.base import get_stage_class
|
||||
|
||||
|
||||
EXPECTED_STAGES = [
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from detect.tracing import trace_node, SpanContext, flush
|
||||
from core.detect.tracing import trace_node, SpanContext, flush
|
||||
|
||||
|
||||
def test_trace_node_noop():
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from detect.models import BoundingBox, Frame, PipelineStats, TextCandidate
|
||||
from detect.stages.vlm_cloud import escalate_cloud, _parse_response
|
||||
from core.detect.models import BoundingBox, Frame, PipelineStats, TextCandidate
|
||||
from core.detect.stages.vlm_cloud import escalate_cloud, _parse_response
|
||||
|
||||
|
||||
def _make_candidate(text: str = "unknown", confidence: float = 0.4) -> TextCandidate:
|
||||
@@ -30,14 +30,14 @@ def test_parse_response_no_confidence():
|
||||
|
||||
def test_escalate_skips_without_api_key(monkeypatch):
|
||||
events = []
|
||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||
monkeypatch.setattr("core.detect.emit.push_detect_event",
|
||||
lambda job_id, etype, data: events.append((etype, data)))
|
||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||
monkeypatch.delenv("GEMINI_API_KEY", raising=False)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
monkeypatch.setenv("CLOUD_LLM_PROVIDER", "groq")
|
||||
# Reset cached provider
|
||||
import detect.providers as prov
|
||||
import core.detect.providers as prov
|
||||
monkeypatch.setattr(prov, "_cached", None)
|
||||
|
||||
candidates = [_make_candidate()]
|
||||
@@ -54,7 +54,7 @@ def test_escalate_skips_without_api_key(monkeypatch):
|
||||
|
||||
def test_escalate_empty_candidates(monkeypatch):
|
||||
events = []
|
||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||
monkeypatch.setattr("core.detect.emit.push_detect_event",
|
||||
lambda job_id, etype, data: events.append((etype, data)))
|
||||
|
||||
stats = PipelineStats()
|
||||
@@ -66,18 +66,18 @@ def test_escalate_empty_candidates(monkeypatch):
|
||||
|
||||
def test_escalate_with_mock_api(monkeypatch):
|
||||
events = []
|
||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||
monkeypatch.setattr("core.detect.emit.push_detect_event",
|
||||
lambda job_id, etype, data: events.append((etype, data)))
|
||||
monkeypatch.setenv("GROQ_API_KEY", "test-key")
|
||||
monkeypatch.setenv("CLOUD_LLM_PROVIDER", "groq")
|
||||
# Reset cached provider
|
||||
import detect.providers as prov
|
||||
import core.detect.providers as prov
|
||||
monkeypatch.setattr(prov, "_cached", None)
|
||||
|
||||
def mock_call(image_b64, prompt):
|
||||
return {"brand": "Heineken", "confidence": 0.75, "reasoning": "green logo", "tokens": 300}
|
||||
|
||||
monkeypatch.setattr("detect.stages.vlm_cloud._call_cloud_api", mock_call)
|
||||
monkeypatch.setattr("core.detect.stages.vlm_cloud._call_cloud_api", mock_call)
|
||||
|
||||
candidates = [_make_candidate("unknown logo")]
|
||||
stats = PipelineStats()
|
||||
|
||||
Reference in New Issue
Block a user