Files
mediaproc/tests/detect/test_checkpoint.py
2026-03-26 04:40:00 -03:00

183 lines
5.4 KiB
Python

"""Tests for checkpoint serialization — round-trip without S3."""
import json

import numpy as np
import pytest

from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate
from detect.checkpoint.serializer import (
    serialize_state,
    deserialize_state,
    serialize_frame_meta,
    serialize_text_candidate,
    deserialize_text_candidate,
)
def _make_frame(seq: int = 0, w: int = 100, h: int = 80) -> Frame:
    """Build a synthetic Frame carrying a random RGB image.

    The pixels come from a generator seeded by ``seq``, so the fixture is
    reproducible run-to-run (an unseeded global RNG makes failures hard to
    replay) and covers the full uint8 range 0..255 inclusive.

    Args:
        seq: Frame sequence number. Also determines the timestamp
            (``seq * 0.5``) and the perceptual hash string (``"hash_<seq>"``).
        w: Image width in pixels.
        h: Image height in pixels.

    Returns:
        A Frame in chunk 0 whose image is an (h, w, 3) uint8 array.
    """
    # abs(): numpy seed sequences reject negative integers.
    rng = np.random.default_rng(abs(seq))
    # endpoint=True makes 255 reachable; legacy randint(0, 255) excluded it.
    image = rng.integers(0, 255, size=(h, w, 3), dtype=np.uint8, endpoint=True)
    return Frame(
        sequence=seq,
        chunk_id=0,
        timestamp=float(seq) * 0.5,
        image=image,
        perceptual_hash=f"hash_{seq}",
    )
def _make_box(
    x: int = 10,
    y: int = 10,
    w: int = 30,
    h: int = 20,
    *,
    confidence: float = 0.9,
    label: str = "text",
) -> BoundingBox:
    """Build a BoundingBox fixture.

    Args:
        x: Left coordinate of the box.
        y: Top coordinate of the box (assumed top-left origin — confirm
            against BoundingBox).
        w: Box width.
        h: Box height.
        confidence: Detector confidence stored on the box (keyword-only).
        label: Detector class label stored on the box (keyword-only).

    Returns:
        A BoundingBox with the given geometry, confidence and label.
    """
    return BoundingBox(x=x, y=y, w=w, h=h, confidence=confidence, label=label)
def _make_candidate(frame: Frame, text: str = "NIKE") -> TextCandidate:
    """Wrap ``frame`` in a TextCandidate fixture.

    Uses the default box from ``_make_box`` and a fixed OCR confidence of
    0.85, so tests can assert on those exact values after a round trip.
    """
    return TextCandidate(
        frame=frame,
        bbox=_make_box(),
        text=text,
        ocr_confidence=0.85,
    )
def _make_detection(brand: str = "Nike", timestamp: float = 1.0) -> BrandDetection:
    """Build a BrandDetection fixture for ``brand`` at ``timestamp``.

    The remaining fields are fixed: duration 0.5, confidence 0.92,
    OCR source, soccer-broadcast content type, and frame_ref 0.
    """
    attrs = {
        "brand": brand,
        "timestamp": timestamp,
        "duration": 0.5,
        "confidence": 0.92,
        "source": "ocr",
        "content_type": "soccer_broadcast",
        "frame_ref": 0,
    }
    return BrandDetection(**attrs)
# --- Frame metadata ---
def test_serialize_frame_meta():
    """Frame metadata keeps sequence, timestamp and hash but never the pixels."""
    meta = serialize_frame_meta(_make_frame(seq=5))
    expected = {"sequence": 5, "timestamp": 2.5, "perceptual_hash": "hash_5"}
    for key, value in expected.items():
        assert meta[key] == value
    # The image array must not be embedded in the metadata dict.
    assert "image" not in meta
# --- TextCandidate ---
def test_serialize_text_candidate():
    """A serialized candidate records its frame's sequence, text, score and box."""
    data = serialize_text_candidate(_make_candidate(_make_frame(), text="EMIRATES"))
    observed = (data["frame_sequence"], data["text"], data["ocr_confidence"])
    assert observed == (0, "EMIRATES", 0.85)
    assert "bbox" in data
def test_deserialize_text_candidate():
    """Deserialization re-links a candidate to the Frame found by sequence."""
    frame = _make_frame()
    payload = serialize_text_candidate(_make_candidate(frame, text="ADIDAS"))
    restored = deserialize_text_candidate(payload, {frame.sequence: frame})
    assert restored.text == "ADIDAS"
    assert restored.ocr_confidence == 0.85
    # Identity, not equality: the restored candidate must share the Frame object.
    assert restored.frame is frame
    assert restored.bbox.x == 10
# --- Full state round-trip ---
def test_state_round_trip():
    """Serialize a populated pipeline state and restore it against the same frames.

    Verifies that the serialized form is JSON-encodable without any fallback
    hook, and that deserialization restores scalar fields, collection sizes,
    stats values, and Frame object identity on text candidates.
    """
    frames = [_make_frame(seq=i) for i in range(3)]
    filtered = frames[:2]
    box = _make_box()
    boxes_by_frame = {0: [box], 1: [box]}
    candidates = [_make_candidate(frames[0], "NIKE"), _make_candidate(frames[1], "EMIRATES")]
    unresolved = [_make_candidate(frames[2], "unknown")]
    detections = [_make_detection("Nike", 0.5), _make_detection("Emirates", 1.0)]
    stats = PipelineStats(
        frames_extracted=3,
        frames_after_scene_filter=2,
        regions_detected=2,
        regions_resolved_by_ocr=2,
        cloud_llm_calls=1,
        estimated_cloud_cost_usd=0.003,
    )
    state = {
        "job_id": "test-123",
        "video_path": "/tmp/test.mp4",
        "profile_name": "soccer_broadcast",
        "config_overrides": {"ocr": {"min_confidence": 0.3}},
        "frames": frames,
        "filtered_frames": filtered,
        "boxes_by_frame": boxes_by_frame,
        "text_candidates": candidates,
        "unresolved_candidates": unresolved,
        "detections": detections,
        "stats": stats,
    }
    manifest = {f.sequence: f"s3://fake/frames/{f.sequence}.jpg" for f in frames}

    serialized = serialize_state(state, manifest)

    # Must be JSON-compatible as-is (no numpy, no Frame objects). Deliberately
    # no ``default=`` hook: ``default=str`` would silently stringify any leaked
    # array or Frame, so the check could never fail. json.dumps raises
    # TypeError on such values instead.
    json_str = json.dumps(serialized)
    assert json_str

    # Deserialize with the original frames (simulating frame load from S3).
    restored = deserialize_state(serialized, frames)

    assert restored["job_id"] == "test-123"
    assert restored["video_path"] == "/tmp/test.mp4"
    assert restored["profile_name"] == "soccer_broadcast"
    assert restored["config_overrides"] == {"ocr": {"min_confidence": 0.3}}
    assert len(restored["frames"]) == 3
    assert len(restored["filtered_frames"]) == 2
    assert len(restored["boxes_by_frame"]) == 2
    assert len(restored["text_candidates"]) == 2
    assert len(restored["unresolved_candidates"]) == 1
    assert len(restored["detections"]) == 2

    restored_stats = restored["stats"]
    assert restored_stats.frames_extracted == 3
    assert restored_stats.cloud_llm_calls == 1
    assert restored_stats.estimated_cloud_cost_usd == 0.003

    # Candidate frame references must point at the actual Frame objects.
    tc = restored["text_candidates"][0]
    assert tc.frame is frames[0]
    assert tc.text == "NIKE"
    assert restored["unresolved_candidates"][0].frame is frames[2]
def test_state_round_trip_empty():
    """Empty state should serialize/deserialize cleanly.

    Every collection that goes in empty must come back empty, and a
    default-constructed PipelineStats must survive the round trip.
    """
    state = {
        "job_id": "empty-job",
        "video_path": "",
        "profile_name": "soccer_broadcast",
        "frames": [],
        "filtered_frames": [],
        "boxes_by_frame": {},
        "text_candidates": [],
        "unresolved_candidates": [],
        "detections": [],
        "stats": PipelineStats(),
    }
    serialized = serialize_state(state, {})
    restored = deserialize_state(serialized, [])
    assert restored["job_id"] == "empty-job"
    # Check every collection fed in, not just a sample of them.
    empty_keys = (
        "frames",
        "filtered_frames",
        "boxes_by_frame",
        "text_candidates",
        "unresolved_candidates",
        "detections",
    )
    for key in empty_keys:
        assert len(restored[key]) == 0, key
    assert restored["stats"].frames_extracted == 0