schema clean up and refactor
This commit is contained in:
@@ -1,133 +1,108 @@
|
||||
"""State serialization — DetectState ↔ JSON-compatible dict."""
|
||||
"""
|
||||
State serialization — DetectState ↔ JSON-compatible dict.
|
||||
|
||||
Delegates to each stage's serialize_fn/deserialize_fn via the registry.
|
||||
This file has no model-specific knowledge — stages own their data format.
|
||||
|
||||
The only things serialized here are the "envelope" fields (job_id, video_path, etc.)
|
||||
that don't belong to any stage.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
|
||||
from detect.models import (
|
||||
BoundingBox,
|
||||
BrandDetection,
|
||||
Frame,
|
||||
PipelineStats,
|
||||
TextCandidate,
|
||||
from core.schema.serializers._common import serialize_dataclass
|
||||
from core.schema.serializers.detect_pipeline import (
|
||||
deserialize_pipeline_stats,
|
||||
deserialize_text_candidates,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Serialize helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def serialize_frame_meta(frame: Frame) -> dict:
|
||||
meta = {
|
||||
"sequence": frame.sequence,
|
||||
"chunk_id": frame.chunk_id,
|
||||
"timestamp": frame.timestamp,
|
||||
"perceptual_hash": frame.perceptual_hash,
|
||||
}
|
||||
return meta
|
||||
|
||||
|
||||
def serialize_text_candidate(tc: TextCandidate) -> dict:
|
||||
bbox_dict = dataclasses.asdict(tc.bbox)
|
||||
candidate = {
|
||||
"frame_sequence": tc.frame.sequence,
|
||||
"bbox": bbox_dict,
|
||||
"text": tc.text,
|
||||
"ocr_confidence": tc.ocr_confidence,
|
||||
}
|
||||
return candidate
|
||||
# Envelope fields — not owned by any stage, always present
|
||||
ENVELOPE_KEYS = ["job_id", "video_path", "profile_name", "config_overrides"]
|
||||
|
||||
|
||||
def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
|
||||
"""
|
||||
Serialize DetectState to a JSON-compatible dict.
|
||||
|
||||
Frame images are replaced with S3 key references.
|
||||
TextCandidate.frame references become frame_sequence integers.
|
||||
Calls each registered stage's serialize_fn for stage-owned data.
|
||||
Envelope fields (job_id, etc.) are copied directly.
|
||||
"""
|
||||
frames = state.get("frames", [])
|
||||
filtered = state.get("filtered_frames", [])
|
||||
from detect.stages.base import _REGISTRY
|
||||
|
||||
manifest_strs = {str(k): v for k, v in frames_manifest.items()}
|
||||
frames_meta = [serialize_frame_meta(f) for f in frames]
|
||||
filtered_seqs = [f.sequence for f in filtered]
|
||||
checkpoint = {}
|
||||
|
||||
boxes_serialized = {}
|
||||
for seq, boxes in state.get("boxes_by_frame", {}).items():
|
||||
boxes_serialized[str(seq)] = [dataclasses.asdict(b) for b in boxes]
|
||||
# Envelope
|
||||
for key in ENVELOPE_KEYS:
|
||||
default = {} if key == "config_overrides" else ""
|
||||
checkpoint[key] = state.get(key, default)
|
||||
|
||||
text_candidates = [serialize_text_candidate(tc) for tc in state.get("text_candidates", [])]
|
||||
unresolved = [serialize_text_candidate(tc) for tc in state.get("unresolved_candidates", [])]
|
||||
detections = [dataclasses.asdict(d) for d in state.get("detections", [])]
|
||||
stats = dataclasses.asdict(state.get("stats", PipelineStats()))
|
||||
# Frames manifest (needed by frame-loading stages)
|
||||
checkpoint["frames_manifest"] = {str(k): v for k, v in frames_manifest.items()}
|
||||
|
||||
# Stats (shared across stages, not owned by one)
|
||||
stats = state.get("stats")
|
||||
if stats is not None:
|
||||
checkpoint["stats"] = serialize_dataclass(stats)
|
||||
else:
|
||||
checkpoint["stats"] = {}
|
||||
|
||||
# Per-stage data
|
||||
for name, stage_def in _REGISTRY.items():
|
||||
if stage_def.serialize_fn is None:
|
||||
continue
|
||||
job_id = state.get("job_id", "")
|
||||
stage_data = stage_def.serialize_fn(state, job_id)
|
||||
checkpoint[f"stage_{name}"] = stage_data
|
||||
|
||||
checkpoint = {
|
||||
"job_id": state.get("job_id", ""),
|
||||
"video_path": state.get("video_path", ""),
|
||||
"profile_name": state.get("profile_name", ""),
|
||||
"config_overrides": state.get("config_overrides", {}),
|
||||
"frames_manifest": manifest_strs,
|
||||
"frames_meta": frames_meta,
|
||||
"filtered_frame_sequences": filtered_seqs,
|
||||
"boxes_by_frame": boxes_serialized,
|
||||
"text_candidates": text_candidates,
|
||||
"unresolved_candidates": unresolved,
|
||||
"detections": detections,
|
||||
"stats": stats,
|
||||
}
|
||||
return checkpoint
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deserialize helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
def deserialize_state(checkpoint: dict, frames: list) -> dict:
|
||||
"""
|
||||
Reconstitute DetectState from a checkpoint dict + loaded frames.
|
||||
|
||||
def deserialize_text_candidate(d: dict, frame_map: dict[int, Frame]) -> TextCandidate:
|
||||
frame = frame_map[d["frame_sequence"]]
|
||||
bbox = BoundingBox(**d["bbox"])
|
||||
candidate = TextCandidate(
|
||||
frame=frame,
|
||||
bbox=bbox,
|
||||
text=d["text"],
|
||||
ocr_confidence=d["ocr_confidence"],
|
||||
)
|
||||
return candidate
|
||||
Calls each stage's deserialize_fn to restore stage-owned data.
|
||||
"""
|
||||
from detect.stages.base import _REGISTRY
|
||||
|
||||
|
||||
def deserialize_state(checkpoint: dict, frames: list[Frame]) -> dict:
|
||||
"""Reconstitute DetectState from a checkpoint dict + loaded frames."""
|
||||
frame_map = {f.sequence: f for f in frames}
|
||||
|
||||
filtered_seqs = set(checkpoint.get("filtered_frame_sequences", []))
|
||||
filtered_frames = [f for f in frames if f.sequence in filtered_seqs]
|
||||
state = {}
|
||||
|
||||
boxes_by_frame = {}
|
||||
for seq_str, box_dicts in checkpoint.get("boxes_by_frame", {}).items():
|
||||
seq = int(seq_str)
|
||||
boxes_by_frame[seq] = [BoundingBox(**b) for b in box_dicts]
|
||||
# Envelope
|
||||
for key in ENVELOPE_KEYS:
|
||||
default = {} if key == "config_overrides" else ""
|
||||
state[key] = checkpoint.get(key, default)
|
||||
|
||||
text_candidates = [
|
||||
deserialize_text_candidate(d, frame_map)
|
||||
for d in checkpoint.get("text_candidates", [])
|
||||
]
|
||||
unresolved_candidates = [
|
||||
deserialize_text_candidate(d, frame_map)
|
||||
for d in checkpoint.get("unresolved_candidates", [])
|
||||
]
|
||||
detections = [BrandDetection(**d) for d in checkpoint.get("detections", [])]
|
||||
stats = PipelineStats(**checkpoint.get("stats", {}))
|
||||
# Frames (always present, loaded externally)
|
||||
state["frames"] = frames
|
||||
|
||||
# Stats
|
||||
state["stats"] = deserialize_pipeline_stats(checkpoint.get("stats", {}))
|
||||
|
||||
# Per-stage data
|
||||
for name, stage_def in _REGISTRY.items():
|
||||
if stage_def.deserialize_fn is None:
|
||||
continue
|
||||
|
||||
stage_key = f"stage_{name}"
|
||||
if stage_key not in checkpoint:
|
||||
continue
|
||||
|
||||
job_id = state.get("job_id", "")
|
||||
stage_data = stage_def.deserialize_fn(checkpoint[stage_key], job_id)
|
||||
|
||||
for k, v in stage_data.items():
|
||||
if k == "_filtered_sequences":
|
||||
# Reconnect filtered frames from sequence list
|
||||
seq_set = set(v)
|
||||
state["filtered_frames"] = [f for f in frames if f.sequence in seq_set]
|
||||
elif k.endswith("_raw"):
|
||||
# Raw text candidates need frame reference reconnection
|
||||
real_key = k.removeprefix("_").removesuffix("_raw")
|
||||
state[real_key] = deserialize_text_candidates(v, frame_map)
|
||||
else:
|
||||
state[k] = v
|
||||
|
||||
state = {
|
||||
"job_id": checkpoint.get("job_id", ""),
|
||||
"video_path": checkpoint.get("video_path", ""),
|
||||
"profile_name": checkpoint.get("profile_name", ""),
|
||||
"config_overrides": checkpoint.get("config_overrides", {}),
|
||||
"frames": frames,
|
||||
"filtered_frames": filtered_frames,
|
||||
"boxes_by_frame": boxes_by_frame,
|
||||
"text_candidates": text_candidates,
|
||||
"unresolved_candidates": unresolved_candidates,
|
||||
"detections": detections,
|
||||
"stats": stats,
|
||||
}
|
||||
return state
|
||||
|
||||
Reference in New Issue
Block a user