"""
Runner bridge — checkpoint hook called by PipelineRunner after each stage.

Owns the per-job state (frame manifest cache, checkpoint chain) so that the
runner does not need to know about it.
"""

from __future__ import annotations
|
|
|
|
import logging
|
|
|
|
logger = logging.getLogger(name=__name__)

# --- Per-job state -----------------------------------------------------------
# Both maps are module-global and keyed by job_id, so entries outlive any
# single call; reset_checkpoint_state() is what drops a job's entries.

# job_id -> frame manifest returned by save_frames() for that job. Inner
# key/value types come from the annotation (presumably frame index -> stored
# location — TODO confirm against .frames.save_frames).
_frames_manifest: dict[str, dict[int, str]] = dict()

# job_id -> id of the most recent checkpoint saved for that job; it becomes
# the parent of the next checkpoint in the chain.
_latest_checkpoint: dict[str, str] = dict()


def reset_checkpoint_state(job_id: str):
    """Forget everything this module has cached for ``job_id``.

    Intended to be called once the job's pipeline has finished. Removes the
    job's frame-manifest entry and then its latest-checkpoint pointer; a job
    with no cached state is silently ignored.
    """
    for per_job_store in (_frames_manifest, _latest_checkpoint):
        # pop() with a default so an unknown or already-reset job is a no-op.
        per_job_store.pop(job_id, None)


def checkpoint_after_stage(
    job_id: str, stage_name: str, state: dict, result: dict
) -> None:
    """Save a checkpoint after a stage completes.

    Called by the runner after each stage. Handles:

    - Frame upload: done once per job, when the ``extract_frames`` stage
      completes. The manifest returned by ``save_frames`` is cached in
      ``_frames_manifest`` and its presence prevents a second upload.
    - Stage output serialization via the stage's registered
      ``definition.serialize_fn``; when the stage has none, an empty dict is
      passed to ``save_stage_output``.
    - Checkpoint chaining: each new checkpoint's parent is the previous
      checkpoint saved for the same job (``None`` for the first one).

    Args:
        job_id: Job / timeline identifier. A falsy value disables
            checkpointing entirely (the call is a no-op).
        stage_name: Registry name of the stage that just finished.
        state: Pipeline state dict supplied by the runner.
        result: The stage's output dict; its keys override ``state`` on
            collision. ``None`` is treated as "no output".
    """
    if not job_id:
        return

    # Imported lazily — presumably to avoid import cycles with the storage /
    # stage modules (TODO confirm).
    from .storage import save_stage_output
    from .frames import save_frames
    from detect.stages.base import _REGISTRY

    # Stage output wins over the incoming state on key collisions; tolerate a
    # stage that returned None instead of a dict.
    merged = {**state, **(result or {})}

    # Upload frames exactly once per job, on the extract_frames stage. The
    # manifest cache doubles as the "already uploaded" flag.
    if stage_name == "extract_frames" and _frames_manifest.get(job_id) is None:
        frames = merged.get("frames")
        if frames is None:
            # Explicit `is None` rather than `or []`: frames may be an
            # array-like whose truth value is ambiguous.
            frames = []
        _frames_manifest[job_id] = save_frames(job_id, frames)

    # Serialize stage output using the stage's serialize_fn if available.
    stage_cls = _REGISTRY.get(stage_name)
    definition = getattr(stage_cls, "definition", None)
    serialize_fn = getattr(definition, "serialize_fn", None)
    output_json = serialize_fn(merged, job_id) if serialize_fn else {}

    # Chain onto the previous checkpoint for this job, then advance the head.
    parent_id = _latest_checkpoint.get(job_id)
    new_checkpoint_id = save_stage_output(
        timeline_id=job_id,
        parent_checkpoint_id=parent_id,
        stage_name=stage_name,
        output_json=output_json,
    )
    _latest_checkpoint[job_id] = new_checkpoint_id
    logger.debug(
        "checkpoint %s saved for job %s after stage %s (parent=%s)",
        new_checkpoint_id,
        job_id,
        stage_name,
        parent_id,
    )