This commit is contained in:
2026-03-30 09:53:10 -03:00
parent 4220b0418e
commit aac27b8504
32 changed files with 1068 additions and 329 deletions

View File

@@ -1,13 +1,9 @@
"""
Runner bridge — checkpoint hook called by PipelineRunner after each stage.
Owns the per-job state (timeline, frame manifest, checkpoint chain) that
the runner shouldn't know about.
Timeline and Job are independent entities:
- One Timeline can serve multiple Jobs (re-run with different params)
- One Job operates on one Timeline (set after frame extraction)
- Checkpoints belong to Timeline, tagged with the Job that created them
Saves a checkpoint + stage output after each stage completes.
Timeline and Job are independent: timeline_id and job_id come from
the pipeline state (set at job creation time).
"""
from __future__ import annotations
@@ -16,63 +12,37 @@ import logging
logger = logging.getLogger(__name__)
# Per-job state
_timeline_id: dict[str, str] = {}
_frames_manifest: dict[str, dict[int, str]] = {}
# Per-job state: tracks the latest checkpoint so we can chain parent → child
_latest_checkpoint: dict[str, str] = {}
def reset_checkpoint_state(job_id: str):
"""Clean up per-job checkpoint state. Called when pipeline finishes."""
_timeline_id.pop(job_id, None)
_frames_manifest.pop(job_id, None)
_latest_checkpoint.pop(job_id, None)
def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: dict):
"""
Save a checkpoint after a stage completes.
Save a checkpoint + stage output after a stage completes.
Called by the runner. Handles:
- Timeline creation (once, on extract_frames)
- Frame upload (via create_timeline)
- Stage output serialization (via stage registry)
- Checkpoint chain (parent → child)
- Stage output as separate row in StageOutput table
"""
if not job_id:
return
from .storage import create_timeline, save_stage_output
timeline_id = state.get("timeline_id", "")
if not timeline_id:
logger.warning("No timeline_id in state for job %s, skipping checkpoint", job_id)
return
from .storage import save_checkpoint, save_stage_output
from core.detect.stages.base import _REGISTRY
merged = {**state, **result}
# On extract_frames: create Timeline + upload frames + root checkpoint
if stage_name == "extract_frames" and job_id not in _timeline_id:
frames = merged.get("frames", [])
video_path = merged.get("video_path", "")
profile_name = merged.get("profile_name", "")
tid, cid = create_timeline(
source_video=video_path,
profile_name=profile_name,
frames=frames,
)
_timeline_id[job_id] = tid
_latest_checkpoint[job_id] = cid
logger.info("Job %s → Timeline %s (root checkpoint %s)", job_id, tid, cid)
# Emit timeline_id via SSE so the UI can use it for checkpoint loads
from core.detect import emit
emit.log(job_id, "Checkpoint", "INFO", f"timeline_id={tid}")
return
# For subsequent stages: save checkpoint on the timeline
tid = _timeline_id.get(job_id)
if not tid:
logger.warning("No timeline for job %s, skipping checkpoint", job_id)
return
# Serialize stage output using the stage's serialize_fn if available
stage_cls = _REGISTRY.get(stage_name)
serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
@@ -81,17 +51,41 @@ def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: di
else:
output_json = {}
# Convert stats dataclass to dict for JSONB storage
import dataclasses
raw_stats = state.get("stats", {})
if dataclasses.is_dataclass(raw_stats):
stats_dict = dataclasses.asdict(raw_stats)
elif isinstance(raw_stats, dict):
stats_dict = raw_stats
else:
stats_dict = {}
# Save checkpoint (lightweight tree node)
parent_id = _latest_checkpoint.get(job_id)
new_checkpoint_id = save_stage_output(
timeline_id=tid,
parent_checkpoint_id=parent_id,
checkpoint_id = save_checkpoint(
timeline_id=timeline_id,
stage_name=stage_name,
output_json=output_json,
parent_checkpoint_id=parent_id,
config_overrides=state.get("config_overrides"),
stats=stats_dict,
job_id=job_id,
)
_latest_checkpoint[job_id] = new_checkpoint_id
_latest_checkpoint[job_id] = checkpoint_id
# Save stage output (separate row, upsert by job+stage)
if output_json:
save_stage_output(
job_id=job_id,
timeline_id=timeline_id,
stage_name=stage_name,
output=output_json,
checkpoint_id=checkpoint_id,
)
logger.info("Checkpoint %s + output for stage %s (job %s)", checkpoint_id, stage_name, job_id)
def get_timeline_id(job_id: str) -> str | None:
"""Get the timeline_id for a running job. Used by the UI to load checkpoints."""
return _timeline_id.get(job_id)
def get_latest_checkpoint(job_id: str) -> str | None:
"""Get the latest checkpoint_id for a running job."""
return _latest_checkpoint.get(job_id)