197 lines
5.9 KiB
Python
197 lines
5.9 KiB
Python
"""Tests for checkpoint serialization — unit tests without S3."""
|
|
|
|
import json
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate
|
|
from core.schema.serializers._common import safe_construct
|
|
from core.schema.serializers.detect_pipeline import (
|
|
serialize_frame_meta,
|
|
serialize_text_candidate,
|
|
serialize_text_candidates,
|
|
serialize_dataclass_list,
|
|
deserialize_text_candidate,
|
|
deserialize_text_candidates,
|
|
deserialize_bounding_box,
|
|
deserialize_brand_detection,
|
|
deserialize_pipeline_stats,
|
|
)
|
|
|
|
|
|
def _make_frame(seq: int = 0, w: int = 100, h: int = 80) -> Frame:
    """Build a synthetic ``Frame`` for serialization tests.

    The pixel data comes from a generator seeded with ``seq``. Every test run
    therefore sees the same image for a given sequence number, instead of
    depending on global ``np.random`` state.

    Args:
        seq: Frame sequence number. It also drives ``timestamp``
            (``seq * 0.5``) and ``perceptual_hash`` (``"hash_<seq>"``).
        w: Image width in pixels.
        h: Image height in pixels.

    Returns:
        A ``Frame`` with ``chunk_id=0`` and an ``(h, w, 3)`` uint8 image.
    """
    # default_rng rejects negative seeds; abs() keeps any int seq usable.
    rng = np.random.default_rng(abs(seq))
    # ``high`` is exclusive, so 256 lets the full uint8 range (0-255) appear.
    image = rng.integers(0, 256, size=(h, w, 3), dtype=np.uint8)
    return Frame(
        sequence=seq,
        chunk_id=0,
        timestamp=float(seq) * 0.5,
        image=image,
        perceptual_hash=f"hash_{seq}",
    )
|
|
|
|
|
|
def _make_box(x=10, y=10, w=30, h=20) -> BoundingBox:
    """Return a ``BoundingBox`` labelled ``"text"`` with confidence 0.9."""
    geometry = dict(x=x, y=y, w=w, h=h)
    return BoundingBox(**geometry, confidence=0.9, label="text")
|
|
|
|
|
|
def _make_candidate(frame: Frame, text: str = "NIKE") -> TextCandidate:
    """Wrap ``frame`` in a ``TextCandidate`` that uses the default test box.

    The OCR confidence is fixed at 0.85 so tests can assert on it.
    """
    return TextCandidate(
        frame=frame,
        bbox=_make_box(),
        text=text,
        ocr_confidence=0.85,
    )
|
|
|
|
|
|
def _make_detection(brand: str = "Nike", timestamp: float = 1.0) -> BrandDetection:
    """Return an OCR-sourced ``BrandDetection`` for ``brand`` at ``timestamp``.

    Duration (0.5), confidence (0.92), content type and ``frame_ref`` (0)
    are fixed.
    """
    fixed = {
        "duration": 0.5,
        "confidence": 0.92,
        "source": "ocr",
        "content_type": "soccer_broadcast",
        "frame_ref": 0,
    }
    return BrandDetection(brand=brand, timestamp=timestamp, **fixed)
|
|
|
|
|
|
# --- Frame metadata ---
|
|
|
|
def test_serialize_frame_meta():
    """Frame metadata carries sequence, timestamp and hash, but not the pixels."""
    meta = serialize_frame_meta(_make_frame(seq=5))

    expected = {"sequence": 5, "timestamp": 2.5, "perceptual_hash": "hash_5"}
    for key, value in expected.items():
        assert meta[key] == value
    # The raw image array must not appear in the metadata dict.
    assert "image" not in meta
|
|
|
|
|
|
# --- TextCandidate ---
|
|
|
|
def test_serialize_text_candidate():
    """A serialized candidate refers to its frame by sequence number."""
    data = serialize_text_candidate(_make_candidate(_make_frame(), text="EMIRATES"))

    observed = (data["frame_sequence"], data["text"], data["ocr_confidence"])
    assert observed == (0, "EMIRATES", 0.85)
    assert "bbox" in data
|
|
|
|
|
|
def test_deserialize_text_candidate():
    """Deserializing re-links the candidate to the live frame object."""
    source_frame = _make_frame()
    payload = serialize_text_candidate(_make_candidate(source_frame, text="ADIDAS"))

    restored = deserialize_text_candidate(payload, {source_frame.sequence: source_frame})

    assert restored.text == "ADIDAS"
    assert restored.ocr_confidence == 0.85
    assert restored.frame is source_frame  # identity, not an equal copy
    assert restored.bbox.x == 10
|
|
|
|
|
|
# --- BoundingBox ---
|
|
|
|
def test_bounding_box_round_trip():
    """Every BoundingBox field survives serialize -> deserialize."""
    box = _make_box(x=15, y=25, w=40, h=30)
    serialized = serialize_dataclass_list([box])[0]
    restored = deserialize_bounding_box(serialized)

    # Distinct x/y/w/h values, so a dropped or swapped field cannot pass by accident.
    assert (restored.x, restored.y, restored.w, restored.h) == (15, 25, 40, 30)
    assert restored.confidence == 0.9
    assert restored.label == "text"
|
|
|
|
|
|
# --- BrandDetection ---
|
|
|
|
def test_brand_detection_round_trip():
    """BrandDetection fields survive serialize -> deserialize."""
    det = _make_detection("Emirates", 3.5)
    serialized = serialize_dataclass_list([det])[0]
    restored = deserialize_brand_detection(serialized)

    assert restored.brand == "Emirates"
    assert restored.timestamp == 3.5
    assert restored.source == "ocr"
    # Values _make_detection fixes; they must round-trip too.
    assert restored.duration == 0.5
    assert restored.confidence == 0.92
    assert restored.content_type == "soccer_broadcast"
|
|
|
|
|
|
# --- PipelineStats ---
|
|
|
|
def test_pipeline_stats_round_trip():
    """Counters and the cost estimate survive serialize -> deserialize."""
    original = PipelineStats(
        frames_extracted=120,
        cloud_llm_calls=3,
        estimated_cloud_cost_usd=0.005,
    )
    restored = deserialize_pipeline_stats(serialize_dataclass_list([original])[0])

    for field_name, value in (
        ("frames_extracted", 120),
        ("cloud_llm_calls", 3),
        ("estimated_cloud_cost_usd", 0.005),
    ):
        assert getattr(restored, field_name) == value
|
|
|
|
|
|
# --- safe_construct tolerates schema changes ---
|
|
|
|
def test_safe_construct_ignores_unknown_fields():
    """Keys the dataclass does not declare are dropped instead of raising."""
    payload = dict(x=1, y=2, w=3, h=4, confidence=0.5, label="t", extra_field=99)

    box = safe_construct(BoundingBox, payload)

    assert box.x == 1
    assert not hasattr(box, "extra_field")
|
|
|
|
|
|
def test_safe_construct_uses_defaults():
    """Fields missing from the payload fall back to the dataclass defaults."""
    stats = safe_construct(PipelineStats, {"frames_extracted": 50})

    assert stats.frames_extracted == 50
    assert stats.cloud_llm_calls == 0  # not supplied, so the default applies
|
|
|
|
|
|
# --- JSON compatibility ---
|
|
|
|
def test_all_serialized_is_json_compatible():
    """Serializer output must be encodable by ``json`` with no fallback hook.

    ``json.dumps`` is deliberately called without ``default=``. A
    ``default=str`` hook would silently stringify numpy scalars, arrays or any
    other non-JSON value, and then this test could never fail.
    """
    frame = _make_frame()
    candidate = _make_candidate(frame)
    detection = _make_detection()
    stats = PipelineStats(frames_extracted=10)

    all_data = {
        "frame_meta": serialize_frame_meta(frame),
        "candidate": serialize_text_candidate(candidate),
        "detection": serialize_dataclass_list([detection])[0],
        "stats": serialize_dataclass_list([stats])[0],
    }

    # Raises TypeError if any value is not natively JSON-serializable.
    json_str = json.dumps(all_data)
    roundtrip = json.loads(json_str)

    # Spot-check one field per section to confirm the values decode back intact.
    assert roundtrip["frame_meta"]["sequence"] == frame.sequence
    assert roundtrip["candidate"]["text"] == candidate.text
    assert roundtrip["detection"]["brand"] == detection.brand
    assert roundtrip["stats"]["frames_extracted"] == stats.frames_extracted
|
|
|
|
|
|
# --- OverrideProfile ---
|
|
|
|
def test_override_profile_region_analysis():
    """OverrideProfile patches region_analysis_config with the given overrides."""
    from detect.checkpoint.replay import OverrideProfile
    from detect.profiles.soccer import SoccerBroadcastProfile
    from detect.profiles.base import RegionAnalysisConfig

    base_profile = SoccerBroadcastProfile()
    baseline = base_profile.region_analysis_config()

    patched = OverrideProfile(
        base_profile,
        {"region_analysis": {"edge_canny_low": 25, "edge_canny_high": 200}},
    ).region_analysis_config()

    assert isinstance(patched, RegionAnalysisConfig)
    assert (patched.edge_canny_low, patched.edge_canny_high) == (25, 200)
    # Fields not named in the override keep the base profile's value.
    assert patched.edge_hough_threshold == baseline.edge_hough_threshold
|
|
|
|
|
|
def test_override_profile_passthrough():
    """OverrideProfile without a region_analysis key leaves that config unchanged."""
    from detect.checkpoint.replay import OverrideProfile
    from detect.profiles.soccer import SoccerBroadcastProfile

    base = SoccerBroadcastProfile()
    # Overrides for an unrelated section ("ocr") must not leak into region analysis.
    wrapped = OverrideProfile(base, {"ocr": {"min_confidence": 0.1}})
    config = wrapped.region_analysis_config()
    expected = base.region_analysis_config()

    # "Unchanged" means every known field matches, not just one of them.
    for attr in ("edge_canny_low", "edge_canny_high", "edge_hough_threshold"):
        assert getattr(config, attr) == getattr(expected, attr), attr
|