"""
State serialization — DetectState ↔ JSON-compatible dict.

Stage-owned data is converted by each stage's own ``serialize_fn`` /
``deserialize_fn``, looked up through the stage registry, so this module
carries no model-specific knowledge.

What is handled directly here is the "envelope" (job_id, video_path, and the
other keys in ``ENVELOPE_KEYS``), the frames manifest, and the shared
pipeline stats — none of which belong to a single stage.
"""

from __future__ import annotations

from core.schema.serializers._common import serialize_dataclass
from core.schema.serializers.pipeline import (
    deserialize_pipeline_stats,
    deserialize_text_candidates,
)

# Envelope fields — not owned by any stage, always present
ENVELOPE_KEYS = ["job_id", "video_path", "profile_name", "config_overrides"]


def _envelope_default(key: str) -> dict | str:
    """Return the fallback for a missing envelope field.

    ``config_overrides`` falls back to a new empty dict on every call (so
    callers never share one mutable default); every other field falls back
    to the empty string.
    """
    if key == "config_overrides":
        return {}
    return ""


def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
    """
    Build a JSON-compatible checkpoint dict from a DetectState.

    The checkpoint is filled in this order: envelope fields, the frames
    manifest (keys converted to strings), the shared stats, and finally one
    ``stage_<name>`` entry for every registered stage that provides a
    ``serialize_fn``.

    Args:
        state: The pipeline state mapping to snapshot.
        frames_manifest: Mapping stored under ``"frames_manifest"`` with its
            keys stringified.

    Returns:
        The checkpoint dict.
    """
    # Imported at call time, as in the rest of this module, rather than at
    # module import.
    from detect.stages.base import _REGISTRY

    # Envelope: copied straight from the state, with per-key fallbacks.
    checkpoint = {
        field: state.get(field, _envelope_default(field))
        for field in ENVELOPE_KEYS
    }

    # Frames manifest (needed by frame-loading stages); keys become strings.
    checkpoint["frames_manifest"] = {
        str(frame_key): entry for frame_key, entry in frames_manifest.items()
    }

    # Stats are shared across stages rather than owned by one of them.
    pipeline_stats = state.get("stats")
    checkpoint["stats"] = (
        serialize_dataclass(pipeline_stats) if pipeline_stats is not None else {}
    )

    # Per-stage data: each stage decides its own on-disk shape.
    for stage_name, definition in _REGISTRY.items():
        dump = definition.serialize_fn
        if dump is None:
            continue
        # job_id is re-read on every iteration so a stage that touches the
        # state is observed by the stages after it.
        checkpoint[f"stage_{stage_name}"] = dump(state, state.get("job_id", ""))

    return checkpoint


def deserialize_state(checkpoint: dict, frames: list) -> dict:
    """
    Rebuild a DetectState from a checkpoint dict plus already-loaded frames.

    Envelope fields, frames and stats are restored first; then every
    registered stage that has a ``deserialize_fn`` and a ``stage_<name>``
    entry in the checkpoint contributes its own keys. Two kinds of key
    returned by a stage get special treatment:

    * ``_filtered_sequences`` — turned into ``state["filtered_frames"]`` by
      selecting the frames whose ``sequence`` is listed.
    * any key ending in ``_raw`` — passed through
      ``deserialize_text_candidates`` together with a sequence → frame map,
      and stored under the key with one leading ``_`` and the ``_raw``
      suffix stripped.

    All other keys are copied into the state unchanged. The checkpoint's
    ``frames_manifest`` entry is not read here; frames arrive via ``frames``.

    Args:
        checkpoint: Dict previously produced by ``serialize_state``.
        frames: Loaded frame objects; each must expose a ``sequence``
            attribute.

    Returns:
        The reconstructed state dict.
    """
    from detect.stages.base import _REGISTRY

    # Lookup used to re-attach frame references to raw text candidates.
    frame_map = {frame.sequence: frame for frame in frames}

    # Envelope
    state = {
        field: checkpoint.get(field, _envelope_default(field))
        for field in ENVELOPE_KEYS
    }

    # Frames (always present, loaded externally)
    state["frames"] = frames

    # Stats
    state["stats"] = deserialize_pipeline_stats(checkpoint.get("stats", {}))

    # Per-stage data
    for stage_name, definition in _REGISTRY.items():
        restore = definition.deserialize_fn
        payload_key = f"stage_{stage_name}"
        if restore is None or payload_key not in checkpoint:
            continue

        restored = restore(checkpoint[payload_key], state.get("job_id", ""))
        for field, value in restored.items():
            if field == "_filtered_sequences":
                # Reconnect filtered frames from the stored sequence list.
                wanted = set(value)
                state["filtered_frames"] = [
                    frame for frame in frames if frame.sequence in wanted
                ]
            elif field.endswith("_raw"):
                # Raw text candidates need their frame references restored.
                target = field.removeprefix("_").removesuffix("_raw")
                state[target] = deserialize_text_candidates(value, frame_map)
            else:
                state[field] = value

    return state