schema clean up and refactor
This commit is contained in:
@@ -1,26 +1,30 @@
|
||||
"""Tests for checkpoint serialization — round-trip without S3."""
|
||||
"""Tests for checkpoint serialization — unit tests without S3."""
|
||||
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate
|
||||
from detect.checkpoint.serializer import (
|
||||
serialize_state,
|
||||
deserialize_state,
|
||||
from core.schema.serializers._common import safe_construct
|
||||
from core.schema.serializers.detect_pipeline import (
|
||||
serialize_frame_meta,
|
||||
serialize_text_candidate,
|
||||
serialize_text_candidates,
|
||||
serialize_dataclass_list,
|
||||
deserialize_text_candidate,
|
||||
deserialize_text_candidates,
|
||||
deserialize_bounding_box,
|
||||
deserialize_brand_detection,
|
||||
deserialize_pipeline_stats,
|
||||
)
|
||||
|
||||
|
||||
def _make_frame(seq: int = 0, w: int = 100, h: int = 80) -> Frame:
|
||||
image = np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
|
||||
return Frame(
|
||||
sequence=seq,
|
||||
chunk_id=0,
|
||||
timestamp=float(seq) * 0.5,
|
||||
image=image,
|
||||
perceptual_hash=f"hash_{seq}",
|
||||
sequence=seq, chunk_id=0, timestamp=float(seq) * 0.5,
|
||||
image=image, perceptual_hash=f"hash_{seq}",
|
||||
)
|
||||
|
||||
|
||||
@@ -35,13 +39,8 @@ def _make_candidate(frame: Frame, text: str = "NIKE") -> TextCandidate:
|
||||
|
||||
def _make_detection(brand: str = "Nike", timestamp: float = 1.0) -> BrandDetection:
|
||||
return BrandDetection(
|
||||
brand=brand,
|
||||
timestamp=timestamp,
|
||||
duration=0.5,
|
||||
confidence=0.92,
|
||||
source="ocr",
|
||||
content_type="soccer_broadcast",
|
||||
frame_ref=0,
|
||||
brand=brand, timestamp=timestamp, duration=0.5,
|
||||
confidence=0.92, source="ocr", content_type="soccer_broadcast", frame_ref=0,
|
||||
)
|
||||
|
||||
|
||||
@@ -81,102 +80,84 @@ def test_deserialize_text_candidate():
|
||||
|
||||
assert restored.text == "ADIDAS"
|
||||
assert restored.ocr_confidence == 0.85
|
||||
assert restored.frame is frame # same object reference
|
||||
assert restored.frame is frame
|
||||
assert restored.bbox.x == 10
|
||||
|
||||
|
||||
# --- Full state round-trip ---
|
||||
# --- BoundingBox ---
|
||||
|
||||
def test_state_round_trip():
|
||||
frames = [_make_frame(seq=i) for i in range(3)]
|
||||
filtered = frames[:2]
|
||||
def test_bounding_box_round_trip():
|
||||
box = _make_box(x=15, y=25, w=40, h=30)
|
||||
serialized = serialize_dataclass_list([box])[0]
|
||||
restored = deserialize_bounding_box(serialized)
|
||||
|
||||
box = _make_box()
|
||||
boxes_by_frame = {0: [box], 1: [box]}
|
||||
assert restored.x == 15
|
||||
assert restored.w == 40
|
||||
assert restored.confidence == 0.9
|
||||
|
||||
candidates = [_make_candidate(frames[0], "NIKE"), _make_candidate(frames[1], "EMIRATES")]
|
||||
unresolved = [_make_candidate(frames[2], "unknown")]
|
||||
detections = [_make_detection("Nike", 0.5), _make_detection("Emirates", 1.0)]
|
||||
|
||||
# --- BrandDetection ---
|
||||
|
||||
def test_brand_detection_round_trip():
|
||||
det = _make_detection("Emirates", 3.5)
|
||||
serialized = serialize_dataclass_list([det])[0]
|
||||
restored = deserialize_brand_detection(serialized)
|
||||
|
||||
assert restored.brand == "Emirates"
|
||||
assert restored.timestamp == 3.5
|
||||
assert restored.source == "ocr"
|
||||
|
||||
|
||||
# --- PipelineStats ---
|
||||
|
||||
def test_pipeline_stats_round_trip():
|
||||
stats = PipelineStats(
|
||||
frames_extracted=3,
|
||||
frames_after_scene_filter=2,
|
||||
regions_detected=2,
|
||||
regions_resolved_by_ocr=2,
|
||||
cloud_llm_calls=1,
|
||||
estimated_cloud_cost_usd=0.003,
|
||||
frames_extracted=120, cloud_llm_calls=3,
|
||||
estimated_cloud_cost_usd=0.005,
|
||||
)
|
||||
serialized = serialize_dataclass_list([stats])[0]
|
||||
restored = deserialize_pipeline_stats(serialized)
|
||||
|
||||
state = {
|
||||
"job_id": "test-123",
|
||||
"video_path": "/tmp/test.mp4",
|
||||
"profile_name": "soccer_broadcast",
|
||||
"config_overrides": {"ocr": {"min_confidence": 0.3}},
|
||||
"frames": frames,
|
||||
"filtered_frames": filtered,
|
||||
"boxes_by_frame": boxes_by_frame,
|
||||
"text_candidates": candidates,
|
||||
"unresolved_candidates": unresolved,
|
||||
"detections": detections,
|
||||
"stats": stats,
|
||||
assert restored.frames_extracted == 120
|
||||
assert restored.cloud_llm_calls == 3
|
||||
assert restored.estimated_cloud_cost_usd == 0.005
|
||||
|
||||
|
||||
# --- safe_construct tolerates schema changes ---
|
||||
|
||||
def test_safe_construct_ignores_unknown_fields():
|
||||
data = {"x": 1, "y": 2, "w": 3, "h": 4, "confidence": 0.5, "label": "t", "extra_field": 99}
|
||||
box = safe_construct(BoundingBox, data)
|
||||
|
||||
assert box.x == 1
|
||||
assert not hasattr(box, "extra_field")
|
||||
|
||||
|
||||
def test_safe_construct_uses_defaults():
|
||||
data = {"frames_extracted": 50}
|
||||
stats = safe_construct(PipelineStats, data)
|
||||
|
||||
assert stats.frames_extracted == 50
|
||||
assert stats.cloud_llm_calls == 0 # default
|
||||
|
||||
|
||||
# --- JSON compatibility ---
|
||||
|
||||
def test_all_serialized_is_json_compatible():
|
||||
frame = _make_frame()
|
||||
candidate = _make_candidate(frame)
|
||||
detection = _make_detection()
|
||||
stats = PipelineStats(frames_extracted=10)
|
||||
|
||||
all_data = {
|
||||
"frame_meta": serialize_frame_meta(frame),
|
||||
"candidate": serialize_text_candidate(candidate),
|
||||
"detection": serialize_dataclass_list([detection])[0],
|
||||
"stats": serialize_dataclass_list([stats])[0],
|
||||
}
|
||||
|
||||
manifest = {f.sequence: f"s3://fake/frames/{f.sequence}.jpg" for f in frames}
|
||||
|
||||
# Serialize
|
||||
serialized = serialize_state(state, manifest)
|
||||
|
||||
# Verify JSON-compatible (no numpy, no Frame objects)
|
||||
import json
|
||||
json_str = json.dumps(serialized, default=str)
|
||||
json_str = json.dumps(all_data, default=str)
|
||||
assert len(json_str) > 0
|
||||
|
||||
# Deserialize with the original frames (simulating frame load from S3)
|
||||
restored = deserialize_state(serialized, frames)
|
||||
|
||||
# Verify round-trip
|
||||
assert restored["job_id"] == "test-123"
|
||||
assert restored["video_path"] == "/tmp/test.mp4"
|
||||
assert restored["profile_name"] == "soccer_broadcast"
|
||||
assert restored["config_overrides"] == {"ocr": {"min_confidence": 0.3}}
|
||||
|
||||
assert len(restored["frames"]) == 3
|
||||
assert len(restored["filtered_frames"]) == 2
|
||||
assert len(restored["boxes_by_frame"]) == 2
|
||||
assert len(restored["text_candidates"]) == 2
|
||||
assert len(restored["unresolved_candidates"]) == 1
|
||||
assert len(restored["detections"]) == 2
|
||||
|
||||
restored_stats = restored["stats"]
|
||||
assert restored_stats.frames_extracted == 3
|
||||
assert restored_stats.cloud_llm_calls == 1
|
||||
assert restored_stats.estimated_cloud_cost_usd == 0.003
|
||||
|
||||
# TextCandidate frame references should point to actual Frame objects
|
||||
tc = restored["text_candidates"][0]
|
||||
assert tc.frame is frames[0]
|
||||
assert tc.text == "NIKE"
|
||||
|
||||
|
||||
def test_state_round_trip_empty():
|
||||
"""Empty state should serialize/deserialize cleanly."""
|
||||
state = {
|
||||
"job_id": "empty-job",
|
||||
"video_path": "",
|
||||
"profile_name": "soccer_broadcast",
|
||||
"frames": [],
|
||||
"filtered_frames": [],
|
||||
"boxes_by_frame": {},
|
||||
"text_candidates": [],
|
||||
"unresolved_candidates": [],
|
||||
"detections": [],
|
||||
"stats": PipelineStats(),
|
||||
}
|
||||
|
||||
serialized = serialize_state(state, {})
|
||||
restored = deserialize_state(serialized, [])
|
||||
|
||||
assert restored["job_id"] == "empty-job"
|
||||
assert len(restored["frames"]) == 0
|
||||
assert len(restored["detections"]) == 0
|
||||
assert restored["stats"].frames_extracted == 0
|
||||
roundtrip = json.loads(json_str)
|
||||
assert roundtrip["frame_meta"]["sequence"] == frame.sequence
|
||||
|
||||
Reference in New Issue
Block a user