"""
|
|
Runner bridge — checkpoint hook called by PipelineRunner after each stage.
|
|
|
|
Owns the per-job state (timeline, frame manifest, checkpoint chain) that
|
|
the runner shouldn't know about.
|
|
|
|
Timeline and Job are independent entities:
|
|
- One Timeline can serve multiple Jobs (re-run with different params)
|
|
- One Job operates on one Timeline (set after frame extraction)
|
|
- Checkpoints belong to Timeline, tagged with the Job that created them
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Per-job state
|
|
_timeline_id: dict[str, str] = {}
|
|
_frames_manifest: dict[str, dict[int, str]] = {}
|
|
_latest_checkpoint: dict[str, str] = {}
|
|
|
|
|
|
def reset_checkpoint_state(job_id: str):
    """Drop every piece of checkpoint state held for *job_id*.

    Invoked when the pipeline finishes so the module-level maps don't
    accumulate entries for completed jobs.
    """
    for table in (_timeline_id, _frames_manifest, _latest_checkpoint):
        table.pop(job_id, None)
|
|
|
|
|
|
def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: dict):
    """
    Persist a checkpoint for *job_id* once *stage_name* has completed.

    Called by the runner after every stage. Handles:
    - Timeline creation (once, on extract_frames)
    - Frame upload (via create_timeline)
    - Stage output serialization (via stage registry)
    - Checkpoint chain (parent → child)
    """
    if not job_id:
        return

    from core.detect.stages.base import _REGISTRY

    from .storage import create_timeline, save_stage_output

    snapshot = {**state, **result}

    # First checkpoint: extract_frames triggers Timeline creation (plus
    # frame upload and the root checkpoint) exactly once per job.
    if stage_name == "extract_frames" and job_id not in _timeline_id:
        timeline_id, root_checkpoint = create_timeline(
            source_video=snapshot.get("video_path", ""),
            profile_name=snapshot.get("profile_name", ""),
            frames=snapshot.get("frames", []),
        )
        _timeline_id[job_id] = timeline_id
        _latest_checkpoint[job_id] = root_checkpoint
        logger.info(
            "Job %s → Timeline %s (root checkpoint %s)",
            job_id,
            timeline_id,
            root_checkpoint,
        )

        # Surface the timeline_id over SSE so the UI can request
        # checkpoint loads for this run.
        from core.detect import emit
        emit.log(job_id, "Checkpoint", "INFO", f"timeline_id={timeline_id}")
        return

    # Subsequent stages append a checkpoint onto the job's timeline.
    timeline_id = _timeline_id.get(job_id)
    if not timeline_id:
        logger.warning("No timeline for job %s, skipping checkpoint", job_id)
        return

    # Let the stage describe its own serialization; fall back to an empty
    # payload when the stage (or its serialize_fn) is not registered.
    definition = getattr(_REGISTRY.get(stage_name), "definition", None)
    serialize = getattr(definition, "serialize_fn", None)
    payload = serialize(snapshot, job_id) if serialize else {}

    # _latest_checkpoint is read for the parent link before the assignment
    # below replaces it with the freshly created checkpoint id.
    _latest_checkpoint[job_id] = save_stage_output(
        timeline_id=timeline_id,
        parent_checkpoint_id=_latest_checkpoint.get(job_id),
        stage_name=stage_name,
        output_json=payload,
        job_id=job_id,
    )
|
|
|
|
|
|
def get_timeline_id(job_id: str) -> str | None:
    """Return the timeline_id recorded for a running job, or None.

    The UI calls this to resolve a job to its timeline when loading
    checkpoints.
    """
    try:
        return _timeline_id[job_id]
    except KeyError:
        return None
|