Files
mediaproc/core/detect/checkpoint/runner_bridge.py
2026-03-30 09:53:10 -03:00

92 lines
2.8 KiB
Python

"""
Runner bridge — checkpoint hook called by PipelineRunner after each stage.
Saves a checkpoint + stage output after each stage completes.
Timeline and Job are independent: timeline_id and job_id come from
the pipeline state (set at job creation time).
"""
from __future__ import annotations
import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Per-job state: maps job_id -> checkpoint_id of the most recent checkpoint
# saved for that job, so the next checkpoint can be chained parent → child.
# Lives in process memory only. Written by checkpoint_after_stage(), read by
# get_latest_checkpoint(), and cleared per job by reset_checkpoint_state().
_latest_checkpoint: dict[str, str] = {}
def reset_checkpoint_state(job_id: str):
    """Drop the tracked latest checkpoint for *job_id*.

    Called when the pipeline finishes. A job id with no tracked checkpoint
    is ignored, so calling this repeatedly is harmless.
    """
    try:
        del _latest_checkpoint[job_id]
    except KeyError:
        # Nothing was recorded for this job, so there is nothing to clean up.
        pass
def _stats_as_dict(raw_stats) -> dict:
    """Return *raw_stats* as a plain dict for JSONB storage ({} if unsupported)."""
    import dataclasses

    # is_dataclass() is also true for dataclass *classes*, on which asdict()
    # raises TypeError, so only real instances are converted.
    if dataclasses.is_dataclass(raw_stats) and not isinstance(raw_stats, type):
        return dataclasses.asdict(raw_stats)
    if isinstance(raw_stats, dict):
        return raw_stats
    return {}


def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: dict) -> None:
    """
    Save a checkpoint + stage output after a stage completes.

    Called by the runner. Handles:
    - Stage output serialization (via stage registry)
    - Checkpoint chain (parent → child)
    - Stage output as separate row in StageOutput table

    Args:
        job_id: Job the stage ran for. A falsy value makes this a no-op.
        stage_name: Key used to look the stage up in the stage registry.
        state: Pipeline state. "timeline_id" is required; "stats" and
            "config_overrides" are read when present.
        result: Output of the stage just run, overlaid on ``state`` before
            serialization. ``None`` is treated as an empty result.

    Returns:
        None. When "timeline_id" is missing from ``state`` a warning is
        logged and nothing is saved.
    """
    if not job_id:
        return

    timeline_id = state.get("timeline_id", "")
    if not timeline_id:
        logger.warning("No timeline_id in state for job %s, skipping checkpoint", job_id)
        return

    # Deferred imports: only loaded once a checkpoint is actually being written.
    from .storage import save_checkpoint, save_stage_output
    from core.detect.stages.base import _REGISTRY

    # Stage result wins over existing state keys; tolerate a None result.
    merged = {**state, **(result or {})}

    # Serialize stage output using the stage's serialize_fn if available.
    # Unknown stages, or stages without a definition/serialize_fn, yield {}.
    stage_cls = _REGISTRY.get(stage_name)
    serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
    output_json = serialize_fn(merged, job_id) if serialize_fn else {}

    # Convert stats (dataclass instance or dict) to a dict for JSONB storage.
    stats_dict = _stats_as_dict(state.get("stats", {}))

    # Save checkpoint (lightweight tree node), chained to the previous
    # checkpoint of this job; parent is None for the first stage.
    parent_id = _latest_checkpoint.get(job_id)
    checkpoint_id = save_checkpoint(
        timeline_id=timeline_id,
        stage_name=stage_name,
        parent_checkpoint_id=parent_id,
        config_overrides=state.get("config_overrides"),
        stats=stats_dict,
        job_id=job_id,
    )
    _latest_checkpoint[job_id] = checkpoint_id

    # Save stage output (separate row, upsert by job+stage). Skipped when the
    # stage produced nothing to serialize.
    if output_json:
        save_stage_output(
            job_id=job_id,
            timeline_id=timeline_id,
            stage_name=stage_name,
            output=output_json,
            checkpoint_id=checkpoint_id,
        )

    logger.info("Checkpoint %s + output for stage %s (job %s)", checkpoint_id, stage_name, job_id)
def get_latest_checkpoint(job_id: str) -> str | None:
    """Return the most recent checkpoint_id tracked for a running job.

    Yields ``None`` when no checkpoint is tracked for *job_id*.
    """
    try:
        return _latest_checkpoint[job_id]
    except KeyError:
        return None