183 lines
5.4 KiB
Python
183 lines
5.4 KiB
Python
"""Tests for checkpoint serialization — round-trip without S3."""
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate
|
|
from detect.checkpoint.serializer import (
|
|
serialize_state,
|
|
deserialize_state,
|
|
serialize_frame_meta,
|
|
serialize_text_candidate,
|
|
deserialize_text_candidate,
|
|
)
|
|
|
|
|
|
def _make_frame(seq: int = 0, w: int = 100, h: int = 80) -> Frame:
    """Build a synthetic ``Frame`` fixture with a reproducible random image.

    Args:
        seq: Sequence number of the frame; also drives the timestamp
            (``seq * 0.5``) and the perceptual hash (``"hash_<seq>"``).
        w: Image width in pixels.
        h: Image height in pixels.

    Returns:
        A ``Frame`` in chunk 0 carrying an ``(h, w, 3)`` ``uint8`` image.
    """
    # Seed a local generator from the sequence number so the fixture is
    # reproducible across runs and does not touch NumPy's global RNG state.
    # abs() keeps negative sequence numbers usable (seeds must be >= 0).
    rng = np.random.default_rng(abs(seq))
    # The upper bound is exclusive, so 256 covers the full uint8 range.
    image = rng.integers(0, 256, size=(h, w, 3), dtype=np.uint8)
    return Frame(
        sequence=seq,
        chunk_id=0,
        timestamp=float(seq) * 0.5,
        image=image,
        perceptual_hash=f"hash_{seq}",
    )
|
|
|
|
|
|
def _make_box(x=10, y=10, w=30, h=20) -> BoundingBox:
    """Build a ``BoundingBox`` fixture labelled ``"text"`` with confidence 0.9."""
    geometry = {"x": x, "y": y, "w": w, "h": h}
    return BoundingBox(**geometry, confidence=0.9, label="text")
|
|
|
|
|
|
def _make_candidate(frame: Frame, text: str = "NIKE") -> TextCandidate:
    """Build a ``TextCandidate`` on *frame* using the default fixture box.

    The candidate carries *text* and a fixed OCR confidence of 0.85.
    """
    return TextCandidate(
        frame=frame,
        bbox=_make_box(),
        text=text,
        ocr_confidence=0.85,
    )
|
|
|
|
|
|
def _make_detection(brand: str = "Nike", timestamp: float = 1.0) -> BrandDetection:
    """Build a ``BrandDetection`` fixture for *brand* at *timestamp*.

    Every other field is fixed: an OCR-sourced detection lasting 0.5 with
    confidence 0.92, tagged ``soccer_broadcast`` and referencing frame 0.
    """
    fixed_fields = {
        "duration": 0.5,
        "confidence": 0.92,
        "source": "ocr",
        "content_type": "soccer_broadcast",
        "frame_ref": 0,
    }
    return BrandDetection(brand=brand, timestamp=timestamp, **fixed_fields)
|
|
|
|
|
|
# --- Frame metadata ---
|
|
|
|
def test_serialize_frame_meta():
    """Frame metadata keeps the scalar fields and leaves out the image."""
    meta = serialize_frame_meta(_make_frame(seq=5))

    expected_fields = {
        "sequence": 5,
        "timestamp": 2.5,
        "perceptual_hash": "hash_5",
    }
    for key, expected in expected_fields.items():
        assert meta[key] == expected

    # The pixel data must not be part of the serialized metadata.
    assert "image" not in meta
|
|
|
|
|
|
# --- TextCandidate ---
|
|
|
|
def test_serialize_text_candidate():
    """A serialized candidate exposes its frame sequence, text, confidence and bbox."""
    data = serialize_text_candidate(_make_candidate(_make_frame(), text="EMIRATES"))

    expected_fields = {
        "frame_sequence": 0,
        "text": "EMIRATES",
        "ocr_confidence": 0.85,
    }
    for key, expected in expected_fields.items():
        assert data[key] == expected

    assert "bbox" in data
|
|
|
|
|
|
def test_deserialize_text_candidate():
    """Deserializing a candidate restores its fields and re-links the Frame."""
    source_frame = _make_frame()
    payload = serialize_text_candidate(_make_candidate(source_frame, text="ADIDAS"))

    # Deserialization looks frames up by sequence number.
    frames_by_sequence = {source_frame.sequence: source_frame}
    result = deserialize_text_candidate(payload, frames_by_sequence)

    assert result.text == "ADIDAS"
    assert result.ocr_confidence == 0.85
    # Identity, not equality: the candidate must reference the supplied Frame.
    assert result.frame is source_frame
    assert result.bbox.x == 10
|
|
|
|
|
|
# --- Full state round-trip ---
|
|
|
|
def test_state_round_trip():
    """Serialize a fully populated pipeline state and restore it again.

    Checks that the serialized form is strictly JSON-encodable, that scalar
    fields, collection sizes and stats survive the round trip, and that
    restored text candidates point at the supplied ``Frame`` objects.
    """
    import json

    frames = [_make_frame(seq=i) for i in range(3)]
    filtered = frames[:2]

    box = _make_box()
    boxes_by_frame = {0: [box], 1: [box]}

    candidates = [
        _make_candidate(frames[0], "NIKE"),
        _make_candidate(frames[1], "EMIRATES"),
    ]
    unresolved = [_make_candidate(frames[2], "unknown")]
    detections = [_make_detection("Nike", 0.5), _make_detection("Emirates", 1.0)]

    stats = PipelineStats(
        frames_extracted=3,
        frames_after_scene_filter=2,
        regions_detected=2,
        regions_resolved_by_ocr=2,
        cloud_llm_calls=1,
        estimated_cloud_cost_usd=0.003,
    )

    state = {
        "job_id": "test-123",
        "video_path": "/tmp/test.mp4",
        "profile_name": "soccer_broadcast",
        "config_overrides": {"ocr": {"min_confidence": 0.3}},
        "frames": frames,
        "filtered_frames": filtered,
        "boxes_by_frame": boxes_by_frame,
        "text_candidates": candidates,
        "unresolved_candidates": unresolved,
        "detections": detections,
        "stats": stats,
    }

    # Stand-in for the S3 manifest: frame sequence -> fake object URL.
    manifest = {f.sequence: f"s3://fake/frames/{f.sequence}.jpg" for f in frames}

    # Serialize
    serialized = serialize_state(state, manifest)

    # Verify JSON-compatible (no numpy, no Frame objects). No ``default=``
    # fallback is passed on purpose: ``default=str`` would stringify any
    # leftover non-serializable object and make this check unable to fail.
    json_str = json.dumps(serialized)
    assert len(json_str) > 0

    # Deserialize with the original frames (simulating frame load from S3)
    restored = deserialize_state(serialized, frames)

    # Verify round-trip of the scalar fields
    assert restored["job_id"] == "test-123"
    assert restored["video_path"] == "/tmp/test.mp4"
    assert restored["profile_name"] == "soccer_broadcast"
    assert restored["config_overrides"] == {"ocr": {"min_confidence": 0.3}}

    # Verify collection sizes
    assert len(restored["frames"]) == 3
    assert len(restored["filtered_frames"]) == 2
    assert len(restored["boxes_by_frame"]) == 2
    assert len(restored["text_candidates"]) == 2
    assert len(restored["unresolved_candidates"]) == 1
    assert len(restored["detections"]) == 2

    restored_stats = restored["stats"]
    assert restored_stats.frames_extracted == 3
    assert restored_stats.cloud_llm_calls == 1
    assert restored_stats.estimated_cloud_cost_usd == 0.003

    # TextCandidate frame references should point to actual Frame objects
    tc = restored["text_candidates"][0]
    assert tc.frame is frames[0]
    assert tc.text == "NIKE"
|
|
|
|
|
|
def test_state_round_trip_empty():
    """A state with no frames, candidates or detections round-trips cleanly."""
    empty_lists = {
        key: []
        for key in ("frames", "filtered_frames")
    }
    state = {
        "job_id": "empty-job",
        "video_path": "",
        "profile_name": "soccer_broadcast",
        **empty_lists,
        "boxes_by_frame": {},
        "text_candidates": [],
        "unresolved_candidates": [],
        "detections": [],
        "stats": PipelineStats(),
    }

    # Empty manifest on the way out, no frames supplied on the way back in.
    restored = deserialize_state(serialize_state(state, {}), [])

    assert restored["job_id"] == "empty-job"
    assert len(restored["frames"]) == 0
    assert len(restored["detections"]) == 0
    assert restored["stats"].frames_extracted == 0
|