109 lines
3.3 KiB
Python
109 lines
3.3 KiB
Python
"""
|
|
State serialization — DetectState ↔ JSON-compatible dict.
|
|
|
|
Delegates to each stage's serialize_fn/deserialize_fn via the registry.
|
|
This file has no model-specific knowledge — stages own their data format.
|
|
|
|
The only things serialized here are the "envelope" fields (job_id, video_path, etc.)
|
|
that don't belong to any stage.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from core.schema.serializers._common import serialize_dataclass
|
|
from core.schema.serializers.pipeline import (
|
|
deserialize_pipeline_stats,
|
|
deserialize_text_candidates,
|
|
)
|
|
|
|
|
|
# Top-level checkpoint fields that describe the job itself rather than the
# output of any single stage. Both serialize_state and deserialize_state
# write every one of these keys, falling back to a default when missing.
ENVELOPE_KEYS = [
    "job_id",
    "video_path",
    "profile_name",
    "config_overrides",
]
|
|
|
|
|
|
def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
    """
    Build a JSON-compatible checkpoint dict from a DetectState.

    The checkpoint contains, in insertion order:
      - every key in ENVELOPE_KEYS, copied from ``state``;
      - ``frames_manifest``, with its keys converted to strings;
      - ``stats``, run through serialize_dataclass (or ``{}`` when absent);
      - one ``stage_<name>`` entry per registered stage that defines a
        serialize_fn, holding whatever that function returns.
    """
    from detect.stages.base import _REGISTRY

    # Envelope: a missing key falls back to {} for config_overrides and to
    # "" for every other field.
    checkpoint = {
        field: state.get(field, {} if field == "config_overrides" else "")
        for field in ENVELOPE_KEYS
    }

    # Frames manifest (needed by frame-loading stages); keys are stringified.
    checkpoint["frames_manifest"] = {
        str(sequence): entry for sequence, entry in frames_manifest.items()
    }

    # Stats are shared across stages, so they are stored at the top level.
    pipeline_stats = state.get("stats")
    checkpoint["stats"] = (
        {} if pipeline_stats is None else serialize_dataclass(pipeline_stats)
    )

    # Stage-owned data: each stage decides its own format.
    for stage_name, stage in _REGISTRY.items():
        if stage.serialize_fn is not None:
            checkpoint[f"stage_{stage_name}"] = stage.serialize_fn(
                state, state.get("job_id", "")
            )

    return checkpoint
|
|
|
|
|
|
def deserialize_state(checkpoint: dict, frames: list) -> dict:
    """
    Rebuild a DetectState dict from a checkpoint and externally loaded frames.

    Envelope fields and stats are read from ``checkpoint``, ``frames`` is
    attached unchanged, and every registered stage that has both a
    deserialize_fn and a ``stage_<name>`` entry in the checkpoint contributes
    the keys its deserialize_fn returns.

    Two kinds of stage keys are rewritten rather than copied verbatim:
      - ``_filtered_sequences`` becomes ``filtered_frames``: the frames whose
        ``sequence`` appears in the stored collection, in ``frames`` order.
      - keys ending in ``_raw`` are passed to deserialize_text_candidates
        along with a sequence → frame map, and stored under the key with one
        leading ``_`` and the ``_raw`` suffix removed.
    """
    from detect.stages.base import _REGISTRY

    # Lookup used to re-attach frame references to raw text candidates.
    frames_by_sequence = {frame.sequence: frame for frame in frames}

    # Envelope: a missing key falls back to {} for config_overrides and to
    # "" for every other field.
    state = {
        field: checkpoint.get(field, {} if field == "config_overrides" else "")
        for field in ENVELOPE_KEYS
    }

    # Frames are always present; they are loaded by the caller.
    state["frames"] = frames

    state["stats"] = deserialize_pipeline_stats(checkpoint.get("stats", {}))

    # Stage-owned data
    for stage_name, stage in _REGISTRY.items():
        entry_key = f"stage_{stage_name}"
        if stage.deserialize_fn is None or entry_key not in checkpoint:
            continue

        # job_id is read from state on every pass, because keys restored by
        # earlier stages are merged into state below.
        restored = stage.deserialize_fn(
            checkpoint[entry_key], state.get("job_id", "")
        )

        for key, value in restored.items():
            if key == "_filtered_sequences":
                # Only sequence numbers were stored; pick the matching frames.
                wanted = set(value)
                state["filtered_frames"] = [
                    frame for frame in frames if frame.sequence in wanted
                ]
                continue

            if key.endswith("_raw"):
                # Raw text candidates need their frame references reconnected.
                target_key = key.removeprefix("_").removesuffix("_raw")
                state[target_key] = deserialize_text_candidates(
                    value, frames_by_sequence
                )
                continue

            state[key] = value

    return state
|