"""Tests for checkpoint serialization — unit tests without S3."""

import json

import numpy as np
import pytest

from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate
from core.schema.serializers._common import safe_construct
from core.schema.serializers.detect_pipeline import (
    serialize_frame_meta,
    serialize_text_candidate,
    serialize_text_candidates,
    serialize_dataclass_list,
    deserialize_text_candidate,
    deserialize_text_candidates,
    deserialize_bounding_box,
    deserialize_brand_detection,
    deserialize_pipeline_stats,
)


def _make_frame(seq: int = 0, w: int = 100, h: int = 80) -> Frame:
    """Build a Frame whose image is an (h, w, 3) uint8 array.

    The pixel generator is seeded with ``seq`` so fixtures are reproducible
    across runs (the previous version used the unseeded global RNG).
    """
    rng = np.random.default_rng(seq)
    # ``high`` is exclusive: 256 yields the full uint8 range 0..255.
    image = rng.integers(0, 256, size=(h, w, 3), dtype=np.uint8)
    return Frame(
        sequence=seq,
        chunk_id=0,
        timestamp=float(seq) * 0.5,
        image=image,
        perceptual_hash=f"hash_{seq}",
    )


def _make_box(x: int = 10, y: int = 10, w: int = 30, h: int = 20) -> BoundingBox:
    """Build a BoundingBox with confidence 0.9 and label "text"."""
    return BoundingBox(x=x, y=y, w=w, h=h, confidence=0.9, label="text")


def _make_candidate(frame: Frame, text: str = "NIKE") -> TextCandidate:
    """Build a TextCandidate on ``frame`` using the default box from ``_make_box``."""
    box = _make_box()
    return TextCandidate(frame=frame, bbox=box, text=text, ocr_confidence=0.85)


def _make_detection(brand: str = "Nike", timestamp: float = 1.0) -> BrandDetection:
    """Build an OCR-sourced BrandDetection with fixed duration/confidence."""
    return BrandDetection(
        brand=brand,
        timestamp=timestamp,
        duration=0.5,
        confidence=0.92,
        source="ocr",
        content_type="soccer_broadcast",
        frame_ref=0,
    )


# --- Frame metadata ---


def test_serialize_frame_meta():
    frame = _make_frame(seq=5)
    meta = serialize_frame_meta(frame)
    assert meta["sequence"] == 5
    assert meta["timestamp"] == 2.5
    assert meta["perceptual_hash"] == "hash_5"
    # The raw pixel array must never be written into checkpoint metadata.
    assert "image" not in meta


# --- TextCandidate ---


def test_serialize_text_candidate():
    frame = _make_frame()
    candidate = _make_candidate(frame, text="EMIRATES")
    data = serialize_text_candidate(candidate)
    assert data["frame_sequence"] == 0
    assert data["text"] == "EMIRATES"
    assert data["ocr_confidence"] == 0.85
    assert "bbox" in data


def test_deserialize_text_candidate():
    frame = _make_frame()
    candidate = _make_candidate(frame, text="ADIDAS")
    serialized = serialize_text_candidate(candidate)
    frame_map = {frame.sequence: frame}
    restored = deserialize_text_candidate(serialized, frame_map)
    assert restored.text == "ADIDAS"
    assert restored.ocr_confidence == 0.85
    # The frame is re-linked by sequence number, not copied.
    assert restored.frame is frame
    assert restored.bbox.x == 10


# --- BoundingBox ---


def test_bounding_box_round_trip():
    box = _make_box(x=15, y=25, w=40, h=30)
    serialized = serialize_dataclass_list([box])[0]
    restored = deserialize_bounding_box(serialized)
    assert restored.x == 15
    assert restored.w == 40
    assert restored.confidence == 0.9


# --- BrandDetection ---


def test_brand_detection_round_trip():
    det = _make_detection("Emirates", 3.5)
    serialized = serialize_dataclass_list([det])[0]
    restored = deserialize_brand_detection(serialized)
    assert restored.brand == "Emirates"
    assert restored.timestamp == 3.5
    assert restored.source == "ocr"


# --- PipelineStats ---


def test_pipeline_stats_round_trip():
    stats = PipelineStats(
        frames_extracted=120,
        cloud_llm_calls=3,
        estimated_cloud_cost_usd=0.005,
    )
    serialized = serialize_dataclass_list([stats])[0]
    restored = deserialize_pipeline_stats(serialized)
    assert restored.frames_extracted == 120
    assert restored.cloud_llm_calls == 3
    assert restored.estimated_cloud_cost_usd == 0.005


# --- safe_construct tolerates schema changes ---


def test_safe_construct_ignores_unknown_fields():
    data = {"x": 1, "y": 2, "w": 3, "h": 4, "confidence": 0.5, "label": "t", "extra_field": 99}
    box = safe_construct(BoundingBox, data)
    assert box.x == 1
    assert not hasattr(box, "extra_field")


def test_safe_construct_uses_defaults():
    data = {"frames_extracted": 50}
    stats = safe_construct(PipelineStats, data)
    assert stats.frames_extracted == 50
    assert stats.cloud_llm_calls == 0  # default


# --- JSON compatibility ---


def test_all_serialized_is_json_compatible():
    """Every serializer's output must be encodable by the stock JSON encoder.

    ``json.dumps`` is deliberately called without a ``default=`` hook: a
    fallback such as ``default=str`` would silently stringify any non-JSON
    value (numpy scalars/arrays, enums, ...) and make this test unable to fail.
    """
    frame = _make_frame()
    candidate = _make_candidate(frame)
    detection = _make_detection()
    stats = PipelineStats(frames_extracted=10)
    all_data = {
        "frame_meta": serialize_frame_meta(frame),
        "candidate": serialize_text_candidate(candidate),
        "detection": serialize_dataclass_list([detection])[0],
        "stats": serialize_dataclass_list([stats])[0],
    }
    # Raises TypeError if any serializer leaks a non-JSON-native value.
    json_str = json.dumps(all_data)
    roundtrip = json.loads(json_str)

    assert roundtrip["frame_meta"]["sequence"] == frame.sequence
    assert roundtrip["candidate"]["text"] == candidate.text
    # Payloads that went through JSON must still deserialize correctly.
    assert deserialize_brand_detection(roundtrip["detection"]).brand == detection.brand
    assert deserialize_pipeline_stats(roundtrip["stats"]).frames_extracted == 10