"""State serialization — DetectState ↔ JSON-compatible dict.

``serialize_state`` converts an in-memory DetectState mapping into a
JSON-compatible checkpoint dict, and ``deserialize_state`` rebuilds the state
mapping from such a checkpoint plus the separately loaded ``Frame`` objects.
"""

from __future__ import annotations

import dataclasses

from detect.models import (
    BoundingBox,
    BrandDetection,
    Frame,
    PipelineStats,
    TextCandidate,
)


# ---------------------------------------------------------------------------
# Serialize helpers
# ---------------------------------------------------------------------------


def serialize_frame_meta(frame: Frame) -> dict:
    """Return the frame attributes that are written into a checkpoint.

    Args:
        frame: The frame to describe.

    Returns:
        A dict with the keys ``sequence``, ``chunk_id``, ``timestamp`` and
        ``perceptual_hash``, copied verbatim from the frame.
    """
    return dict(
        sequence=frame.sequence,
        chunk_id=frame.chunk_id,
        timestamp=frame.timestamp,
        perceptual_hash=frame.perceptual_hash,
    )


def serialize_text_candidate(tc: TextCandidate) -> dict:
    """Flatten a TextCandidate into a plain dict.

    The candidate's ``frame`` object is not embedded; only its ``sequence``
    number is stored (under ``frame_sequence``) so that
    ``deserialize_text_candidate`` can re-link it to a loaded Frame.

    Args:
        tc: The candidate to serialize.

    Returns:
        A dict with ``frame_sequence``, ``bbox`` (the bounding box's fields as
        produced by ``dataclasses.asdict``), ``text`` and ``ocr_confidence``.
    """
    box_fields = dataclasses.asdict(tc.bbox)
    return {
        "frame_sequence": tc.frame.sequence,
        "bbox": box_fields,
        "text": tc.text,
        "ocr_confidence": tc.ocr_confidence,
    }


def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
    """Serialize DetectState to a JSON-compatible dict.

    Frame images are not embedded. Each frame contributes only its metadata
    (``frames_meta``), and ``frames_manifest`` carries the references to the
    stored images (S3 keys). ``TextCandidate.frame`` references become
    ``frame_sequence`` integers, and filtered frames are recorded by sequence
    number only.

    Args:
        state: DetectState mapping. Missing keys fall back to ``""``, ``{}``,
            ``[]`` or a fresh ``PipelineStats()`` as appropriate.
        frames_manifest: Frame sequence number -> stored image reference.

    Returns:
        The checkpoint dict with keys ``job_id``, ``video_path``,
        ``profile_name``, ``config_overrides``, ``frames_manifest``,
        ``frames_meta``, ``filtered_frame_sequences``, ``boxes_by_frame``,
        ``text_candidates``, ``unresolved_candidates``, ``detections`` and
        ``stats`` (in that order).
    """
    all_frames = state.get("frames", [])
    kept_frames = state.get("filtered_frames", [])

    # JSON object keys must be strings, so integer sequence keys are
    # stringified here (deserialize_state converts boxes_by_frame keys back).
    manifest_by_key = {str(seq): ref for seq, ref in frames_manifest.items()}

    meta_rows = list(map(serialize_frame_meta, all_frames))
    kept_sequences = [frame.sequence for frame in kept_frames]

    boxes_json = {
        str(seq): [dataclasses.asdict(box) for box in frame_boxes]
        for seq, frame_boxes in state.get("boxes_by_frame", {}).items()
    }

    candidates_json = list(
        map(serialize_text_candidate, state.get("text_candidates", []))
    )
    unresolved_json = list(
        map(serialize_text_candidate, state.get("unresolved_candidates", []))
    )

    # NOTE(review): dataclasses.asdict recurses into nested dataclasses; if
    # BrandDetection / PipelineStats contain any, the matching
    # `Cls(**d)` calls in deserialize_state will receive plain dicts for
    # those fields — confirm against detect.models.
    detections_json = [dataclasses.asdict(det) for det in state.get("detections", [])]

    # A "stats" key that is present but None is passed to asdict as-is and
    # would raise TypeError; only a missing key gets the default.
    stats_obj = state.get("stats", PipelineStats())
    stats_json = dataclasses.asdict(stats_obj)

    return {
        "job_id": state.get("job_id", ""),
        "video_path": state.get("video_path", ""),
        "profile_name": state.get("profile_name", ""),
        # Passed through by reference, not copied.
        "config_overrides": state.get("config_overrides", {}),
        "frames_manifest": manifest_by_key,
        "frames_meta": meta_rows,
        "filtered_frame_sequences": kept_sequences,
        "boxes_by_frame": boxes_json,
        "text_candidates": candidates_json,
        "unresolved_candidates": unresolved_json,
        "detections": detections_json,
        "stats": stats_json,
    }


# ---------------------------------------------------------------------------
# Deserialize helpers
# ---------------------------------------------------------------------------


def deserialize_text_candidate(d: dict, frame_map: dict[int, Frame]) -> TextCandidate:
    """Rebuild a TextCandidate from the dict made by ``serialize_text_candidate``.

    Args:
        d: Serialized candidate with ``frame_sequence``, ``bbox``, ``text``
            and ``ocr_confidence``.
        frame_map: Frame sequence number -> loaded Frame.

    Returns:
        A TextCandidate whose ``frame`` is the Frame object from *frame_map*.

    Raises:
        KeyError: if ``d["frame_sequence"]`` is not in *frame_map*, or if *d*
            lacks one of the expected keys.
    """
    source_frame = frame_map[d["frame_sequence"]]
    box = BoundingBox(**d["bbox"])
    return TextCandidate(
        frame=source_frame,
        bbox=box,
        text=d["text"],
        ocr_confidence=d["ocr_confidence"],
    )


def deserialize_state(checkpoint: dict, frames: list[Frame]) -> dict:
    """Reconstitute DetectState from a checkpoint dict + loaded frames.

    The Frame objects in *frames* are reused as-is for ``frames``,
    ``filtered_frames`` and every ``TextCandidate.frame``.

    Args:
        checkpoint: A dict as produced by ``serialize_state``. Missing keys fall
            back to empty values or a default ``PipelineStats()``.
        frames: The loaded Frame objects.

    Returns:
        The state mapping. The checkpoint's ``frames_manifest`` and
        ``frames_meta`` entries are not carried into it.

    Raises:
        KeyError: if a serialized text candidate references a frame sequence
            that is not present in *frames*.
    """
    frames_by_seq = {frame.sequence: frame for frame in frames}

    # filtered_frames follows the order of `frames`, not the order stored in
    # the checkpoint; sequences with no matching loaded frame are dropped
    # silently.
    wanted_sequences = set(checkpoint.get("filtered_frame_sequences", []))
    kept_frames = [frame for frame in frames if frame.sequence in wanted_sequences]

    # Keys were stringified for JSON by serialize_state; restore them as ints.
    restored_boxes = {
        int(seq_key): [BoundingBox(**fields) for fields in raw_boxes]
        for seq_key, raw_boxes in checkpoint.get("boxes_by_frame", {}).items()
    }

    restored_candidates = [
        deserialize_text_candidate(raw, frames_by_seq)
        for raw in checkpoint.get("text_candidates", [])
    ]
    restored_unresolved = [
        deserialize_text_candidate(raw, frames_by_seq)
        for raw in checkpoint.get("unresolved_candidates", [])
    ]

    # NOTE(review): flat keyword reconstruction — nested dataclass fields (if
    # any) stay as dicts; see the matching note in serialize_state.
    restored_detections = [BrandDetection(**raw) for raw in checkpoint.get("detections", [])]
    restored_stats = PipelineStats(**checkpoint.get("stats", {}))

    return {
        "job_id": checkpoint.get("job_id", ""),
        "video_path": checkpoint.get("video_path", ""),
        "profile_name": checkpoint.get("profile_name", ""),
        # Shared with the checkpoint dict, not copied.
        "config_overrides": checkpoint.get("config_overrides", {}),
        "frames": frames,
        "filtered_frames": kept_frames,
        "boxes_by_frame": restored_boxes,
        "text_candidates": restored_candidates,
        "unresolved_candidates": restored_unresolved,
        "detections": restored_detections,
        "stats": restored_stats,
    }