Files
mediaproc/detect/checkpoint/serializer.py
2026-03-26 04:40:00 -03:00

134 lines
4.5 KiB
Python

"""State serialization — DetectState ↔ JSON-compatible dict."""
from __future__ import annotations
import dataclasses
from detect.models import (
BoundingBox,
BrandDetection,
Frame,
PipelineStats,
TextCandidate,
)
# ---------------------------------------------------------------------------
# Serialize helpers
# ---------------------------------------------------------------------------
def serialize_frame_meta(frame: Frame) -> dict:
    """Return the checkpointed metadata of *frame* as a plain dict.

    Only the ``sequence``, ``chunk_id``, ``timestamp`` and
    ``perceptual_hash`` attributes are copied; nothing else on the frame
    is included in the result.
    """
    # Attribute names double as the output keys, in this exact order.
    persisted_fields = ("sequence", "chunk_id", "timestamp", "perceptual_hash")
    return {name: getattr(frame, name) for name in persisted_fields}
def serialize_text_candidate(tc: TextCandidate) -> dict:
    """Flatten a text candidate into a plain dict.

    The candidate's frame is stored only as its ``sequence`` number under
    ``"frame_sequence"``, and the bounding box is expanded with
    ``dataclasses.asdict``. ``text`` and ``ocr_confidence`` are copied as-is.
    """
    # Expand the box first so a non-dataclass bbox fails before anything else.
    box_fields = dataclasses.asdict(tc.bbox)
    return {
        "frame_sequence": tc.frame.sequence,
        "bbox": box_fields,
        "text": tc.text,
        "ocr_confidence": tc.ocr_confidence,
    }
def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
    """
    Serialize DetectState to a JSON-compatible dict.

    Frame objects are reduced to their metadata (``frames_meta``); the
    caller-supplied ``frames_manifest`` (frame sequence -> storage key) is
    stored alongside with its keys converted to strings.
    TextCandidate.frame references become frame_sequence integers.

    Args:
        state: Pipeline state mapping. Every key is optional; missing
            collections serialize as empty and missing strings as ``""``.
        frames_manifest: Mapping of frame sequence number to the key under
            which that frame's image is stored.

    Returns:
        A dict whose mapping keys are all strings and whose dataclass
        values have been expanded with ``dataclasses.asdict``.
    """
    # Build the fallback stats object only when the state carries none;
    # passing PipelineStats() as the .get() default would construct (and
    # discard) one on every call.
    if "stats" in state:
        stats_obj = state["stats"]
    else:
        stats_obj = PipelineStats()

    # JSON object keys must be strings, so integer sequence keys are
    # stringified here and converted back on load.
    manifest_strs = {str(seq): key for seq, key in frames_manifest.items()}
    boxes_serialized = {
        str(seq): [dataclasses.asdict(box) for box in boxes]
        for seq, boxes in state.get("boxes_by_frame", {}).items()
    }

    return {
        "job_id": state.get("job_id", ""),
        "video_path": state.get("video_path", ""),
        "profile_name": state.get("profile_name", ""),
        "config_overrides": state.get("config_overrides", {}),
        "frames_manifest": manifest_strs,
        "frames_meta": [
            serialize_frame_meta(frame) for frame in state.get("frames", [])
        ],
        # Filtered frames are stored by sequence number only; the frame
        # metadata itself already lives in "frames_meta".
        "filtered_frame_sequences": [
            frame.sequence for frame in state.get("filtered_frames", [])
        ],
        "boxes_by_frame": boxes_serialized,
        "text_candidates": [
            serialize_text_candidate(tc)
            for tc in state.get("text_candidates", [])
        ],
        "unresolved_candidates": [
            serialize_text_candidate(tc)
            for tc in state.get("unresolved_candidates", [])
        ],
        "detections": [
            dataclasses.asdict(det) for det in state.get("detections", [])
        ],
        "stats": dataclasses.asdict(stats_obj),
    }
# ---------------------------------------------------------------------------
# Deserialize helpers
# ---------------------------------------------------------------------------
def deserialize_text_candidate(d: dict, frame_map: dict[int, Frame]) -> TextCandidate:
    """Rebuild a TextCandidate from its serialized dict.

    Args:
        d: Dict with ``"frame_sequence"``, ``"bbox"``, ``"text"`` and
            ``"ocr_confidence"`` entries.
        frame_map: Loaded frames keyed by sequence number, used to turn the
            stored ``frame_sequence`` back into a frame reference.

    Raises:
        KeyError: If a required entry is missing from ``d`` or the
            referenced sequence is not present in ``frame_map``.
    """
    # Resolve the frame first so a dangling reference fails before any
    # object is constructed.
    owning_frame = frame_map[d["frame_sequence"]]
    return TextCandidate(
        frame=owning_frame,
        bbox=BoundingBox(**d["bbox"]),
        text=d["text"],
        ocr_confidence=d["ocr_confidence"],
    )
def deserialize_state(checkpoint: dict, frames: list[Frame]) -> dict:
    """Reconstitute DetectState from a checkpoint dict + loaded frames.

    Args:
        checkpoint: Checkpoint dict; every key is optional and falls back to
            an empty collection or empty string.
        frames: Already-loaded frames. They are returned unchanged under
            ``"frames"`` and used to resolve stored sequence numbers.

    Returns:
        A state dict with frames, filtered frames, boxes, text candidates,
        detections and stats rebuilt as model objects.

    Raises:
        KeyError: If a stored text candidate references a frame sequence
            that is not among ``frames``.
    """
    by_sequence = {frame.sequence: frame for frame in frames}

    # Filtered frames keep the ordering of ``frames``; sequences listed in
    # the checkpoint but absent from ``frames`` are silently dropped.
    kept_sequences = set(checkpoint.get("filtered_frame_sequences", []))
    kept_frames = [frame for frame in frames if frame.sequence in kept_sequences]

    # Box keys were stringified for JSON; convert them back to int.
    restored_boxes = {
        int(key): [BoundingBox(**fields) for fields in entries]
        for key, entries in checkpoint.get("boxes_by_frame", {}).items()
    }

    def restore_candidates(key: str) -> list:
        """Rebuild the candidate list stored under *key* (empty if absent)."""
        return [
            deserialize_text_candidate(item, by_sequence)
            for item in checkpoint.get(key, [])
        ]

    resolved = restore_candidates("text_candidates")
    pending = restore_candidates("unresolved_candidates")

    # NOTE(review): assumes BrandDetection accepts its asdict() output as
    # keyword arguments (nested values stay plain dicts) — confirm.
    restored_detections = [
        BrandDetection(**fields) for fields in checkpoint.get("detections", [])
    ]
    restored_stats = PipelineStats(**checkpoint.get("stats", {}))

    return {
        "job_id": checkpoint.get("job_id", ""),
        "video_path": checkpoint.get("video_path", ""),
        "profile_name": checkpoint.get("profile_name", ""),
        "config_overrides": checkpoint.get("config_overrides", {}),
        "frames": frames,
        "filtered_frames": kept_frames,
        "boxes_by_frame": restored_boxes,
        "text_candidates": resolved,
        "unresolved_candidates": pending,
        "detections": restored_detections,
        "stats": restored_stats,
    }