phase 10
This commit is contained in:
182
tests/detect/test_checkpoint.py
Normal file
182
tests/detect/test_checkpoint.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""Tests for checkpoint serialization — round-trip without S3."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate
|
||||
from detect.checkpoint.serializer import (
|
||||
serialize_state,
|
||||
deserialize_state,
|
||||
serialize_frame_meta,
|
||||
serialize_text_candidate,
|
||||
deserialize_text_candidate,
|
||||
)
|
||||
|
||||
|
||||
def _make_frame(seq: int = 0, w: int = 100, h: int = 80) -> Frame:
    """Build a ``Frame`` fixture carrying a random ``(h, w, 3)`` uint8 image.

    The timestamp is ``seq * 0.5``, ``chunk_id`` is always 0, and the
    perceptual hash is the placeholder string ``hash_<seq>``.
    """
    shape = (h, w, 3)
    # Pixel values are arbitrary filler; randint's upper bound (255) is exclusive.
    pixels = np.random.randint(0, 255, shape, dtype=np.uint8)
    frame_fields = {
        "sequence": seq,
        "chunk_id": 0,
        "timestamp": float(seq) * 0.5,
        "image": pixels,
        "perceptual_hash": f"hash_{seq}",
    }
    return Frame(**frame_fields)
|
||||
|
||||
|
||||
def _make_box(x=10, y=10, w=30, h=20) -> BoundingBox:
    """Return a ``BoundingBox`` with the given geometry, confidence 0.9 and label ``"text"``."""
    geometry = {"x": x, "y": y, "w": w, "h": h}
    return BoundingBox(**geometry, confidence=0.9, label="text")
|
||||
|
||||
|
||||
def _make_candidate(frame: Frame, text: str = "NIKE") -> TextCandidate:
    """Return a ``TextCandidate`` on ``frame`` with a default box and OCR confidence 0.85."""
    return TextCandidate(
        frame=frame,
        bbox=_make_box(),
        text=text,
        ocr_confidence=0.85,
    )
|
||||
|
||||
|
||||
def _make_detection(brand: str = "Nike", timestamp: float = 1.0) -> BrandDetection:
    """Return a ``BrandDetection`` fixture; only ``brand`` and ``timestamp`` vary."""
    # Everything except brand/timestamp is fixed so tests can focus on those two.
    fixed_fields = {
        "duration": 0.5,
        "confidence": 0.92,
        "source": "ocr",
        "content_type": "soccer_broadcast",
        "frame_ref": 0,
    }
    return BrandDetection(brand=brand, timestamp=timestamp, **fixed_fields)
|
||||
|
||||
|
||||
# --- Frame metadata ---
|
||||
|
||||
def test_serialize_frame_meta():
    """Serialized frame metadata keeps the scalar fields and has no ``image`` key."""
    meta = serialize_frame_meta(_make_frame(seq=5))

    # seq=5 yields timestamp 5 * 0.5 and hash "hash_5" via the fixture helper.
    expected_fields = (
        ("sequence", 5),
        ("timestamp", 2.5),
        ("perceptual_hash", "hash_5"),
    )
    for key, expected in expected_fields:
        assert meta[key] == expected
    assert "image" not in meta
|
||||
|
||||
|
||||
# --- TextCandidate ---
|
||||
|
||||
def test_serialize_text_candidate():
    """A serialized candidate exposes its frame sequence, text, confidence and a bbox entry."""
    data = serialize_text_candidate(_make_candidate(_make_frame(), text="EMIRATES"))

    expected_fields = (
        ("frame_sequence", 0),
        ("text", "EMIRATES"),
        ("ocr_confidence", 0.85),
    )
    for key, expected in expected_fields:
        assert data[key] == expected
    assert "bbox" in data
|
||||
|
||||
|
||||
def test_deserialize_text_candidate():
    """Deserializing a candidate restores its fields and re-links the original frame object."""
    source_frame = _make_frame()
    payload = serialize_text_candidate(_make_candidate(source_frame, text="ADIDAS"))

    # The deserializer resolves the frame through a sequence -> Frame lookup.
    lookup = {source_frame.sequence: source_frame}
    restored = deserialize_text_candidate(payload, lookup)

    assert restored.text == "ADIDAS"
    assert restored.ocr_confidence == 0.85
    assert restored.frame is source_frame  # same object reference
    assert restored.bbox.x == 10
|
||||
|
||||
|
||||
# --- Full state round-trip ---
|
||||
|
||||
def test_state_round_trip():
    """A populated pipeline state survives serialize_state -> deserialize_state.

    Also checks that the serialized form is strictly JSON-encodable, i.e. that
    no numpy arrays or ``Frame`` objects leak into it.
    """
    frames = [_make_frame(seq=i) for i in range(3)]
    filtered = frames[:2]

    box = _make_box()
    boxes_by_frame = {0: [box], 1: [box]}

    candidates = [_make_candidate(frames[0], "NIKE"), _make_candidate(frames[1], "EMIRATES")]
    unresolved = [_make_candidate(frames[2], "unknown")]
    detections = [_make_detection("Nike", 0.5), _make_detection("Emirates", 1.0)]

    stats = PipelineStats(
        frames_extracted=3,
        frames_after_scene_filter=2,
        regions_detected=2,
        regions_resolved_by_ocr=2,
        cloud_llm_calls=1,
        estimated_cloud_cost_usd=0.003,
    )

    state = {
        "job_id": "test-123",
        "video_path": "/tmp/test.mp4",
        "profile_name": "soccer_broadcast",
        "config_overrides": {"ocr": {"min_confidence": 0.3}},
        "frames": frames,
        "filtered_frames": filtered,
        "boxes_by_frame": boxes_by_frame,
        "text_candidates": candidates,
        "unresolved_candidates": unresolved,
        "detections": detections,
        "stats": stats,
    }

    # Fake S3 locations, one per frame sequence; nothing is uploaded.
    manifest = {f.sequence: f"s3://fake/frames/{f.sequence}.jpg" for f in frames}

    # Serialize
    serialized = serialize_state(state, manifest)

    # Verify JSON-compatible (no numpy, no Frame objects).
    # No ``default=`` fallback on purpose: with ``default=str`` any stray object
    # would be silently stringified and this check could never fail. Without
    # it, json.dumps raises TypeError on anything that is not plain JSON data.
    import json

    json_str = json.dumps(serialized)
    assert len(json_str) > 0

    # Deserialize with the original frames (simulating frame load from S3)
    restored = deserialize_state(serialized, frames)

    # Verify round-trip
    assert restored["job_id"] == "test-123"
    assert restored["video_path"] == "/tmp/test.mp4"
    assert restored["profile_name"] == "soccer_broadcast"
    assert restored["config_overrides"] == {"ocr": {"min_confidence": 0.3}}

    assert len(restored["frames"]) == 3
    assert len(restored["filtered_frames"]) == 2
    assert len(restored["boxes_by_frame"]) == 2
    assert len(restored["text_candidates"]) == 2
    assert len(restored["unresolved_candidates"]) == 1
    assert len(restored["detections"]) == 2

    restored_stats = restored["stats"]
    assert restored_stats.frames_extracted == 3
    assert restored_stats.cloud_llm_calls == 1
    assert restored_stats.estimated_cloud_cost_usd == 0.003

    # TextCandidate frame references should point to actual Frame objects
    tc = restored["text_candidates"][0]
    assert tc.frame is frames[0]
    assert tc.text == "NIKE"
|
||||
|
||||
|
||||
def test_state_round_trip_empty():
    """Empty state should serialize/deserialize cleanly."""
    # Every collection-valued key starts out empty; stats use PipelineStats defaults.
    empty_lists = {
        key: []
        for key in ("frames", "filtered_frames")
    }
    state = {
        "job_id": "empty-job",
        "video_path": "",
        "profile_name": "soccer_broadcast",
        **empty_lists,
        "boxes_by_frame": {},
        "text_candidates": [],
        "unresolved_candidates": [],
        "detections": [],
        "stats": PipelineStats(),
    }

    # No frames, hence an empty manifest and an empty frame list on restore.
    restored = deserialize_state(serialize_state(state, {}), [])

    assert restored["job_id"] == "empty-job"
    assert len(restored["frames"]) == 0
    assert len(restored["detections"]) == 0
    assert restored["stats"].frames_extracted == 0
|
||||
Reference in New Issue
Block a user