Files
mediaproc/tests/detect/test_checkpoint.py
2026-03-27 06:14:02 -03:00

197 lines
5.9 KiB
Python

"""Tests for checkpoint serialization — unit tests without S3."""
import json
import numpy as np
import pytest
from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate
from core.schema.serializers._common import safe_construct
from core.schema.serializers.pipeline import (
serialize_frame_meta,
serialize_text_candidate,
serialize_text_candidates,
serialize_dataclass_list,
deserialize_text_candidate,
deserialize_text_candidates,
deserialize_bounding_box,
deserialize_brand_detection,
deserialize_pipeline_stats,
)
def _make_frame(seq: int = 0, w: int = 100, h: int = 80) -> Frame:
    """Build a synthetic ``Frame`` for serialization tests.

    The pixel data is random but seeded from ``seq``, so every run (and every
    call with the same ``seq``) yields the same image and failures are
    reproducible.

    Args:
        seq: Frame sequence number (used as the RNG seed, so it must be
            non-negative). Also drives ``timestamp`` (``seq * 0.5``) and
            ``perceptual_hash`` (``"hash_<seq>"``).
        w: Image width in pixels.
        h: Image height in pixels.

    Returns:
        A ``Frame`` in chunk 0 holding an ``(h, w, 3)`` uint8 image.
    """
    rng = np.random.default_rng(seq)
    # The upper bound of integers() is exclusive; 256 covers the full uint8 range.
    image = rng.integers(0, 256, size=(h, w, 3), dtype=np.uint8)
    return Frame(
        sequence=seq,
        chunk_id=0,
        timestamp=float(seq) * 0.5,
        image=image,
        perceptual_hash=f"hash_{seq}",
    )
def _make_box(x=10, y=10, w=30, h=20) -> BoundingBox:
    """Return a ``BoundingBox`` at the given geometry with confidence 0.9 and label "text"."""
    geometry = {"x": x, "y": y, "w": w, "h": h}
    return BoundingBox(**geometry, confidence=0.9, label="text")
def _make_candidate(frame: Frame, text: str = "NIKE") -> TextCandidate:
    """Wrap ``frame`` in a ``TextCandidate`` with the default test box and OCR confidence 0.85."""
    return TextCandidate(
        frame=frame,
        bbox=_make_box(),
        text=text,
        ocr_confidence=0.85,
    )
def _make_detection(brand: str = "Nike", timestamp: float = 1.0) -> BrandDetection:
    """Build an OCR-sourced ``BrandDetection`` (soccer broadcast, frame_ref 0, 0.5s long)."""
    fixed_fields = dict(
        duration=0.5,
        confidence=0.92,
        source="ocr",
        content_type="soccer_broadcast",
        frame_ref=0,
    )
    return BrandDetection(brand=brand, timestamp=timestamp, **fixed_fields)
# --- Frame metadata ---
def test_serialize_frame_meta():
    """Frame metadata keeps sequence, timestamp and hash, but not the pixel array."""
    meta = serialize_frame_meta(_make_frame(seq=5))
    expected = {"sequence": 5, "timestamp": 2.5, "perceptual_hash": "hash_5"}
    for key, value in expected.items():
        assert meta[key] == value, key
    # The raw image must not be copied into the metadata dict.
    assert "image" not in meta
# --- TextCandidate ---
def test_serialize_text_candidate():
    """A serialized candidate references its frame by sequence and carries the OCR output."""
    data = serialize_text_candidate(_make_candidate(_make_frame(), text="EMIRATES"))
    assert data["frame_sequence"] == 0
    assert data["text"] == "EMIRATES"
    assert data["ocr_confidence"] == 0.85
    assert "bbox" in data
def test_deserialize_text_candidate():
    """Deserialization re-links the candidate to the live Frame looked up in frame_map."""
    frame = _make_frame()
    payload = serialize_text_candidate(_make_candidate(frame, text="ADIDAS"))
    restored = deserialize_text_candidate(payload, {frame.sequence: frame})
    assert (restored.text, restored.ocr_confidence) == ("ADIDAS", 0.85)
    # Identity, not equality: the restored candidate must point at the same Frame object.
    assert restored.frame is frame
    assert restored.bbox.x == 10
# --- BoundingBox ---
def test_bounding_box_round_trip():
    """Every BoundingBox field survives serialize_dataclass_list -> deserialize_bounding_box."""
    box = _make_box(x=15, y=25, w=40, h=30)
    serialized = serialize_dataclass_list([box])[0]
    restored = deserialize_bounding_box(serialized)
    # Check all constructor fields so a dropped or swapped field (x/y, w/h) is caught.
    assert restored.x == 15
    assert restored.y == 25
    assert restored.w == 40
    assert restored.h == 30
    assert restored.confidence == 0.9
    assert restored.label == "text"
# --- BrandDetection ---
def test_brand_detection_round_trip():
    """Every BrandDetection field survives serialize_dataclass_list -> deserialize_brand_detection."""
    det = _make_detection("Emirates", 3.5)
    serialized = serialize_dataclass_list([det])[0]
    restored = deserialize_brand_detection(serialized)
    assert restored.brand == "Emirates"
    assert restored.timestamp == 3.5
    assert restored.source == "ocr"
    # The remaining constructor fields set by _make_detection must also round-trip.
    assert restored.duration == 0.5
    assert restored.confidence == 0.92
    assert restored.content_type == "soccer_broadcast"
    assert restored.frame_ref == 0
# --- PipelineStats ---
def test_pipeline_stats_round_trip():
    """PipelineStats counters and the cost estimate survive a serialization round trip."""
    original = PipelineStats(
        frames_extracted=120,
        cloud_llm_calls=3,
        estimated_cloud_cost_usd=0.005,
    )
    restored = deserialize_pipeline_stats(serialize_dataclass_list([original])[0])
    assert restored.frames_extracted == 120
    assert restored.cloud_llm_calls == 3
    assert restored.estimated_cloud_cost_usd == 0.005
# --- safe_construct tolerates schema changes ---
def test_safe_construct_ignores_unknown_fields():
    """Keys the target class does not declare are ignored by safe_construct."""
    declared = {"x": 1, "y": 2, "w": 3, "h": 4, "confidence": 0.5, "label": "t"}
    box = safe_construct(BoundingBox, {**declared, "extra_field": 99})
    assert box.x == 1
    assert not hasattr(box, "extra_field")
def test_safe_construct_uses_defaults():
    """Fields absent from the input dict fall back to the class's own defaults."""
    stats = safe_construct(PipelineStats, {"frames_extracted": 50})
    assert stats.frames_extracted == 50
    # cloud_llm_calls was not supplied, so it must come back as its default (0).
    assert stats.cloud_llm_calls == 0
# --- JSON compatibility ---
def test_all_serialized_is_json_compatible():
    """Serializer output must be encodable by plain ``json.dumps``.

    No ``default=`` hook is passed on purpose: ``default=str`` would silently
    stringify any non-JSON value (numpy scalars, datetimes, nested objects)
    and make this test unable to fail.
    """
    frame = _make_frame()
    candidate = _make_candidate(frame)
    detection = _make_detection()
    stats = PipelineStats(frames_extracted=10)
    all_data = {
        "frame_meta": serialize_frame_meta(frame),
        "candidate": serialize_text_candidate(candidate),
        "detection": serialize_dataclass_list([detection])[0],
        "stats": serialize_dataclass_list([stats])[0],
    }
    # json.dumps raises TypeError if any serializer leaks a non-JSON-native value.
    json_str = json.dumps(all_data)
    roundtrip = json.loads(json_str)
    # Spot-check one value per section to confirm the content survived encoding.
    assert roundtrip["frame_meta"]["sequence"] == frame.sequence
    assert roundtrip["candidate"]["text"] == candidate.text
    assert roundtrip["detection"]["brand"] == detection.brand
    assert roundtrip["stats"]["frames_extracted"] == 10
# --- OverrideProfile ---
def test_override_profile_region_analysis():
    """Overrides under the ``region_analysis`` key are applied on top of the wrapped profile's config."""
    from detect.checkpoint.replay import OverrideProfile
    from detect.profiles.soccer import SoccerBroadcastProfile
    from detect.profiles.base import RegionAnalysisConfig

    base = SoccerBroadcastProfile()
    original = base.region_analysis_config()
    canny_overrides = {"edge_canny_low": 25, "edge_canny_high": 200}
    patched = OverrideProfile(
        base, {"region_analysis": canny_overrides}
    ).region_analysis_config()
    assert isinstance(patched, RegionAnalysisConfig)
    assert (patched.edge_canny_low, patched.edge_canny_high) == (25, 200)
    # Fields not named in the override keep the base profile's value.
    assert patched.edge_hough_threshold == original.edge_hough_threshold
def test_override_profile_passthrough():
    """Overrides for other sections (here ``ocr``) leave region_analysis_config as the base returns it."""
    from detect.checkpoint.replay import OverrideProfile
    from detect.profiles.soccer import SoccerBroadcastProfile

    base = SoccerBroadcastProfile()
    config = OverrideProfile(base, {"ocr": {"min_confidence": 0.1}}).region_analysis_config()
    expected_low = base.region_analysis_config().edge_canny_low
    assert config.edge_canny_low == expected_low