a
This commit is contained in:
@@ -1,13 +1,9 @@
|
||||
"""
|
||||
Runner bridge — checkpoint hook called by PipelineRunner after each stage.
|
||||
|
||||
Owns the per-job state (timeline, frame manifest, checkpoint chain) that
|
||||
the runner shouldn't know about.
|
||||
|
||||
Timeline and Job are independent entities:
|
||||
- One Timeline can serve multiple Jobs (re-run with different params)
|
||||
- One Job operates on one Timeline (set after frame extraction)
|
||||
- Checkpoints belong to Timeline, tagged with the Job that created them
|
||||
Saves a checkpoint + stage output after each stage completes.
|
||||
Timeline and Job are independent: timeline_id and job_id come from
|
||||
the pipeline state (set at job creation time).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -16,63 +12,37 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)

# Per-job state. Every registry below is keyed by job_id and is cleared by
# reset_checkpoint_state().

# job_id -> timeline id assigned to that job
_timeline_id: dict[str, str] = {}

# job_id -> frame manifest (int -> str)
# NOTE(review): presumably frame number -> stored frame path; confirm against the writer.
_frames_manifest: dict[str, dict[int, str]] = {}

# Per-job state: tracks the latest checkpoint so we can chain parent → child
_latest_checkpoint: dict[str, str] = {}
|
||||
|
||||
|
||||
def reset_checkpoint_state(job_id: str):
    """Clean up per-job checkpoint state. Called when pipeline finishes.

    Removes the entries stored under ``job_id`` from the module-level
    per-job registries (timeline id, frame manifest, latest checkpoint).

    Args:
        job_id: Identifier of the job whose state should be dropped.
    """
    # pop(..., None) makes the cleanup a no-op for jobs that never recorded
    # any state, so calling this for an unknown or already-reset job is safe.
    _timeline_id.pop(job_id, None)
    _frames_manifest.pop(job_id, None)
    _latest_checkpoint.pop(job_id, None)
|
||||
|
||||
|
||||
def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: dict):
|
||||
"""
|
||||
Save a checkpoint after a stage completes.
|
||||
Save a checkpoint + stage output after a stage completes.
|
||||
|
||||
Called by the runner. Handles:
|
||||
- Timeline creation (once, on extract_frames)
|
||||
- Frame upload (via create_timeline)
|
||||
- Stage output serialization (via stage registry)
|
||||
- Checkpoint chain (parent → child)
|
||||
- Stage output as separate row in StageOutput table
|
||||
"""
|
||||
if not job_id:
|
||||
return
|
||||
|
||||
from .storage import create_timeline, save_stage_output
|
||||
timeline_id = state.get("timeline_id", "")
|
||||
if not timeline_id:
|
||||
logger.warning("No timeline_id in state for job %s, skipping checkpoint", job_id)
|
||||
return
|
||||
|
||||
from .storage import save_checkpoint, save_stage_output
|
||||
from core.detect.stages.base import _REGISTRY
|
||||
|
||||
merged = {**state, **result}
|
||||
|
||||
# On extract_frames: create Timeline + upload frames + root checkpoint
|
||||
if stage_name == "extract_frames" and job_id not in _timeline_id:
|
||||
frames = merged.get("frames", [])
|
||||
video_path = merged.get("video_path", "")
|
||||
profile_name = merged.get("profile_name", "")
|
||||
|
||||
tid, cid = create_timeline(
|
||||
source_video=video_path,
|
||||
profile_name=profile_name,
|
||||
frames=frames,
|
||||
)
|
||||
_timeline_id[job_id] = tid
|
||||
_latest_checkpoint[job_id] = cid
|
||||
logger.info("Job %s → Timeline %s (root checkpoint %s)", job_id, tid, cid)
|
||||
|
||||
# Emit timeline_id via SSE so the UI can use it for checkpoint loads
|
||||
from core.detect import emit
|
||||
emit.log(job_id, "Checkpoint", "INFO", f"timeline_id={tid}")
|
||||
return
|
||||
|
||||
# For subsequent stages: save checkpoint on the timeline
|
||||
tid = _timeline_id.get(job_id)
|
||||
if not tid:
|
||||
logger.warning("No timeline for job %s, skipping checkpoint", job_id)
|
||||
return
|
||||
|
||||
# Serialize stage output using the stage's serialize_fn if available
|
||||
stage_cls = _REGISTRY.get(stage_name)
|
||||
serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
|
||||
@@ -81,17 +51,41 @@ def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: di
|
||||
else:
|
||||
output_json = {}
|
||||
|
||||
# Convert stats dataclass to dict for JSONB storage
|
||||
import dataclasses
|
||||
raw_stats = state.get("stats", {})
|
||||
if dataclasses.is_dataclass(raw_stats):
|
||||
stats_dict = dataclasses.asdict(raw_stats)
|
||||
elif isinstance(raw_stats, dict):
|
||||
stats_dict = raw_stats
|
||||
else:
|
||||
stats_dict = {}
|
||||
|
||||
# Save checkpoint (lightweight tree node)
|
||||
parent_id = _latest_checkpoint.get(job_id)
|
||||
new_checkpoint_id = save_stage_output(
|
||||
timeline_id=tid,
|
||||
parent_checkpoint_id=parent_id,
|
||||
checkpoint_id = save_checkpoint(
|
||||
timeline_id=timeline_id,
|
||||
stage_name=stage_name,
|
||||
output_json=output_json,
|
||||
parent_checkpoint_id=parent_id,
|
||||
config_overrides=state.get("config_overrides"),
|
||||
stats=stats_dict,
|
||||
job_id=job_id,
|
||||
)
|
||||
_latest_checkpoint[job_id] = new_checkpoint_id
|
||||
_latest_checkpoint[job_id] = checkpoint_id
|
||||
|
||||
# Save stage output (separate row, upsert by job+stage)
|
||||
if output_json:
|
||||
save_stage_output(
|
||||
job_id=job_id,
|
||||
timeline_id=timeline_id,
|
||||
stage_name=stage_name,
|
||||
output=output_json,
|
||||
checkpoint_id=checkpoint_id,
|
||||
)
|
||||
|
||||
logger.info("Checkpoint %s + output for stage %s (job %s)", checkpoint_id, stage_name, job_id)
|
||||
|
||||
|
||||
def get_timeline_id(job_id: str) -> str | None:
    """Get the timeline_id for a running job. Used by the UI to load checkpoints.

    Args:
        job_id: Identifier of the job to look up.

    Returns:
        The timeline id stored for ``job_id``, or ``None`` when no timeline
        has been recorded for that job.
    """
    return _timeline_id.get(job_id)
|
||||
def get_latest_checkpoint(job_id: str) -> str | None:
    """Get the latest checkpoint_id for a running job.

    Args:
        job_id: Identifier of the job to look up.

    Returns:
        The most recently stored checkpoint id for ``job_id``, or ``None``
        when no checkpoint has been recorded for that job.
    """
    return _latest_checkpoint.get(job_id)
|
||||
|
||||
Reference in New Issue
Block a user