Schema cleanup and refactor

This commit is contained in:
2026-03-26 05:14:33 -03:00
parent 08c58a6a9d
commit d58a90157a
17 changed files with 930 additions and 287 deletions

View File

@@ -1,26 +1,30 @@
"""Tests for checkpoint serialization — round-trip without S3."""
"""Tests for checkpoint serialization — unit tests without S3."""
import json
import numpy as np
import pytest
from detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate
from detect.checkpoint.serializer import (
serialize_state,
deserialize_state,
from core.schema.serializers._common import safe_construct
from core.schema.serializers.detect_pipeline import (
serialize_frame_meta,
serialize_text_candidate,
serialize_text_candidates,
serialize_dataclass_list,
deserialize_text_candidate,
deserialize_text_candidates,
deserialize_bounding_box,
deserialize_brand_detection,
deserialize_pipeline_stats,
)
def _make_frame(seq: int = 0, w: int = 100, h: int = 80) -> Frame:
image = np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
return Frame(
sequence=seq,
chunk_id=0,
timestamp=float(seq) * 0.5,
image=image,
perceptual_hash=f"hash_{seq}",
sequence=seq, chunk_id=0, timestamp=float(seq) * 0.5,
image=image, perceptual_hash=f"hash_{seq}",
)
@@ -35,13 +39,8 @@ def _make_candidate(frame: Frame, text: str = "NIKE") -> TextCandidate:
def _make_detection(brand: str = "Nike", timestamp: float = 1.0) -> BrandDetection:
return BrandDetection(
brand=brand,
timestamp=timestamp,
duration=0.5,
confidence=0.92,
source="ocr",
content_type="soccer_broadcast",
frame_ref=0,
brand=brand, timestamp=timestamp, duration=0.5,
confidence=0.92, source="ocr", content_type="soccer_broadcast", frame_ref=0,
)
@@ -81,102 +80,84 @@ def test_deserialize_text_candidate():
assert restored.text == "ADIDAS"
assert restored.ocr_confidence == 0.85
assert restored.frame is frame # same object reference
assert restored.frame is frame
assert restored.bbox.x == 10
# --- Full state round-trip ---
# --- BoundingBox ---
def test_state_round_trip():
frames = [_make_frame(seq=i) for i in range(3)]
filtered = frames[:2]
def test_bounding_box_round_trip():
box = _make_box(x=15, y=25, w=40, h=30)
serialized = serialize_dataclass_list([box])[0]
restored = deserialize_bounding_box(serialized)
box = _make_box()
boxes_by_frame = {0: [box], 1: [box]}
assert restored.x == 15
assert restored.w == 40
assert restored.confidence == 0.9
candidates = [_make_candidate(frames[0], "NIKE"), _make_candidate(frames[1], "EMIRATES")]
unresolved = [_make_candidate(frames[2], "unknown")]
detections = [_make_detection("Nike", 0.5), _make_detection("Emirates", 1.0)]
# --- BrandDetection ---
def test_brand_detection_round_trip():
    """A BrandDetection survives serialize -> deserialize with its fields intact."""
    original = _make_detection("Emirates", 3.5)
    payload = serialize_dataclass_list([original])
    rebuilt = deserialize_brand_detection(payload[0])
    assert rebuilt.brand == "Emirates"
    assert rebuilt.timestamp == 3.5
    assert rebuilt.source == "ocr"
# --- PipelineStats ---
def test_pipeline_stats_round_trip():
stats = PipelineStats(
frames_extracted=3,
frames_after_scene_filter=2,
regions_detected=2,
regions_resolved_by_ocr=2,
cloud_llm_calls=1,
estimated_cloud_cost_usd=0.003,
frames_extracted=120, cloud_llm_calls=3,
estimated_cloud_cost_usd=0.005,
)
serialized = serialize_dataclass_list([stats])[0]
restored = deserialize_pipeline_stats(serialized)
state = {
"job_id": "test-123",
"video_path": "/tmp/test.mp4",
"profile_name": "soccer_broadcast",
"config_overrides": {"ocr": {"min_confidence": 0.3}},
"frames": frames,
"filtered_frames": filtered,
"boxes_by_frame": boxes_by_frame,
"text_candidates": candidates,
"unresolved_candidates": unresolved,
"detections": detections,
"stats": stats,
assert restored.frames_extracted == 120
assert restored.cloud_llm_calls == 3
assert restored.estimated_cloud_cost_usd == 0.005
# --- safe_construct tolerates schema changes ---
def test_safe_construct_ignores_unknown_fields():
    """safe_construct drops payload keys that are not dataclass fields."""
    payload = dict(x=1, y=2, w=3, h=4, confidence=0.5, label="t", extra_field=99)
    result = safe_construct(BoundingBox, payload)
    # Known field is applied; the unknown key is silently discarded.
    assert result.x == 1
    assert not hasattr(result, "extra_field")
def test_safe_construct_uses_defaults():
    """Keys missing from the payload fall back to the dataclass field defaults."""
    stats = safe_construct(PipelineStats, {"frames_extracted": 50})
    assert stats.frames_extracted == 50
    # cloud_llm_calls was absent from the payload, so the default (0) applies.
    assert stats.cloud_llm_calls == 0
# --- JSON compatibility ---
def test_all_serialized_is_json_compatible():
frame = _make_frame()
candidate = _make_candidate(frame)
detection = _make_detection()
stats = PipelineStats(frames_extracted=10)
all_data = {
"frame_meta": serialize_frame_meta(frame),
"candidate": serialize_text_candidate(candidate),
"detection": serialize_dataclass_list([detection])[0],
"stats": serialize_dataclass_list([stats])[0],
}
manifest = {f.sequence: f"s3://fake/frames/{f.sequence}.jpg" for f in frames}
# Serialize
serialized = serialize_state(state, manifest)
# Verify JSON-compatible (no numpy, no Frame objects)
import json
json_str = json.dumps(serialized, default=str)
json_str = json.dumps(all_data, default=str)
assert len(json_str) > 0
# Deserialize with the original frames (simulating frame load from S3)
restored = deserialize_state(serialized, frames)
# Verify round-trip
assert restored["job_id"] == "test-123"
assert restored["video_path"] == "/tmp/test.mp4"
assert restored["profile_name"] == "soccer_broadcast"
assert restored["config_overrides"] == {"ocr": {"min_confidence": 0.3}}
assert len(restored["frames"]) == 3
assert len(restored["filtered_frames"]) == 2
assert len(restored["boxes_by_frame"]) == 2
assert len(restored["text_candidates"]) == 2
assert len(restored["unresolved_candidates"]) == 1
assert len(restored["detections"]) == 2
restored_stats = restored["stats"]
assert restored_stats.frames_extracted == 3
assert restored_stats.cloud_llm_calls == 1
assert restored_stats.estimated_cloud_cost_usd == 0.003
# TextCandidate frame references should point to actual Frame objects
tc = restored["text_candidates"][0]
assert tc.frame is frames[0]
assert tc.text == "NIKE"
def test_state_round_trip_empty():
"""Empty state should serialize/deserialize cleanly."""
state = {
"job_id": "empty-job",
"video_path": "",
"profile_name": "soccer_broadcast",
"frames": [],
"filtered_frames": [],
"boxes_by_frame": {},
"text_candidates": [],
"unresolved_candidates": [],
"detections": [],
"stats": PipelineStats(),
}
serialized = serialize_state(state, {})
restored = deserialize_state(serialized, [])
assert restored["job_id"] == "empty-job"
assert len(restored["frames"]) == 0
assert len(restored["detections"]) == 0
assert restored["stats"].frames_extracted == 0
roundtrip = json.loads(json_str)
assert roundtrip["frame_meta"]["sequence"] == frame.sequence

View File

@@ -0,0 +1,58 @@
"""Tests for the stage registry."""
from detect.stages import list_stages, get_stage, get_palette
# Canonical stage names the registry must expose.
# NOTE(review): the order appears to mirror pipeline flow — confirm against detect.stages.
EXPECTED_STAGES = [
"extract_frames", "filter_scenes", "detect_objects", "run_ocr",
"match_brands", "escalate_vlm", "escalate_cloud", "compile_report",
]
def test_all_stages_registered():
    """Every expected stage name appears in the registry listing."""
    registered = {stage.name for stage in list_stages()}
    for name in EXPECTED_STAGES:
        assert name in registered, f"Missing stage: {name}"
def test_stage_has_io():
    """Each registered stage declares at least one write, a label, and a description."""
    for stage_name in EXPECTED_STAGES:
        stage = get_stage(stage_name)
        assert len(stage.io.writes) > 0, f"{stage_name} has no writes"
        assert stage.label, f"{stage_name} has no label"
        assert stage.description, f"{stage_name} has no description"
def test_stage_has_serialization():
    """Every stage registers both a serialize and a deserialize hook."""
    for stage_name in EXPECTED_STAGES:
        stage = get_stage(stage_name)
        # Rendered messages match the originals: "<name> has no serialize_fn" etc.
        for attr in ("serialize_fn", "deserialize_fn"):
            assert getattr(stage, attr) is not None, f"{stage_name} has no {attr}"
def test_palette_groups():
    """The palette groups all expected stages under string-named categories."""
    palette = get_palette()
    assert len(palette) > 0
    collected = []
    for category in palette:
        assert isinstance(category, str)
        collected.extend(palette[category])
    # Total across all groups equals the full stage list (counts only;
    # NOTE(review): duplicates across groups would still pass — confirm intent).
    assert len(collected) == len(EXPECTED_STAGES)
def test_io_chain_valid():
    """Each stage's reads should be satisfied by previous stages' writes."""
    # State keys that exist before the first stage runs.
    available_keys = {
        "video_path", "job_id", "profile_name", "source_asset_id",
        "session_brands", "stats", "config_overrides",
    }
    for name in EXPECTED_STAGES:
        stage = get_stage(name)
        for read_key in stage.io.reads:
            assert read_key in available_keys, (
                f"Stage {stage.name} reads '{read_key}' but no previous stage writes it. "
                f"Available: {available_keys}"
            )
        # This stage's outputs become available to every stage after it.
        available_keys.update(stage.io.writes)