Files
mediaproc/detect/checkpoint/runner_bridge.py
2026-03-28 10:05:59 -03:00

65 lines
1.9 KiB
Python

"""
Runner bridge — checkpoint hook called by PipelineRunner after each stage.
Owns the per-job state (frame manifest cache, checkpoint chain) that
the runner shouldn't know about.
"""
from __future__ import annotations
import logging
logger = logging.getLogger(__name__)
# Per-job state, keyed by job_id. Both maps are filled in by
# checkpoint_after_stage and emptied by reset_checkpoint_state.
#
# job_id -> manifest returned by save_frames (annotated as int -> str;
# presumably frame index -> stored location — TODO confirm).
_frames_manifest: dict[str, dict[int, str]] = {}
# job_id -> id of the most recent checkpoint saved for that job; it is
# passed as the parent when the next checkpoint is written.
_latest_checkpoint: dict[str, str] = {}
def reset_checkpoint_state(job_id: str):
    """Drop every piece of cached checkpoint state held for *job_id*.

    Intended to be called when the pipeline finishes. A job id that has
    no cached state is ignored rather than treated as an error.
    """
    for per_job_store in (_frames_manifest, _latest_checkpoint):
        # pop() with a default never raises for a missing key.
        per_job_store.pop(job_id, None)
def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: dict):
    """
    Persist a checkpoint for a stage that has just completed.

    Invoked by the runner. In order, this:
    - uploads the frames when the "extract_frames" stage completes and no
      manifest is cached for the job yet,
    - serializes the stage output through the stage's ``serialize_fn``
      (an empty dict is stored when the stage has none),
    - saves the checkpoint as a child of the job's previous checkpoint and
      records it as the new head of the chain.

    Does nothing when *job_id* is falsy.

    Args:
        job_id: Job identifier; also passed to storage as the timeline id.
        stage_name: Name of the completed stage, used for the registry lookup.
        state: Pipeline state before the stage ran.
        result: Output of the stage; its keys override those in *state*.
    """
    if not job_id:
        return

    # Imported lazily, only once a checkpoint is actually being written.
    from .storage import save_stage_output
    from .frames import save_frames
    from detect.stages.base import _REGISTRY

    # The stage result takes precedence over the pre-existing state.
    combined = dict(state)
    combined.update(result)

    # Frames are saved only while no manifest is cached for this job.
    if stage_name == "extract_frames" and _frames_manifest.get(job_id) is None:
        _frames_manifest[job_id] = save_frames(job_id, combined.get("frames", []))

    # Look up the stage's serializer; any missing link in the chain
    # (unregistered stage, no definition, no serialize_fn) yields None.
    registered_stage = _REGISTRY.get(stage_name)
    definition = getattr(registered_stage, "definition", None)
    serializer = getattr(definition, "serialize_fn", None)
    payload = serializer(combined, job_id) if serializer else {}

    # Chain the new checkpoint onto the previous one (None for the first).
    previous_id = _latest_checkpoint.get(job_id)
    _latest_checkpoint[job_id] = save_stage_output(
        timeline_id=job_id,
        parent_checkpoint_id=previous_id,
        stage_name=stage_name,
        output_json=payload,
    )