"""
|
|
Runner bridge — checkpoint hook called by PipelineRunner after each stage.
|
|
|
|
Saves a checkpoint + stage output after each stage completes.
|
|
Timeline and Job are independent: timeline_id and job_id come from
|
|
the pipeline state (set at job creation time).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Per-job state: tracks the latest checkpoint so we can chain parent → child
_latest_checkpoint: dict[str, str] = {}
|
|
|
|
|
|
def reset_checkpoint_state(job_id: str):
    """Discard per-job checkpoint bookkeeping. Called when pipeline finishes."""
    if job_id in _latest_checkpoint:
        del _latest_checkpoint[job_id]
|
|
|
|
|
|
def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: dict):
    """
    Persist a checkpoint + the stage's output once a stage completes.

    Called by the runner. Handles:
    - Stage output serialization (via stage registry)
    - Checkpoint chain (parent → child)
    - Stage output as separate row in StageOutput table
    """
    # Guard clauses: nothing to persist without a job and a timeline.
    if not job_id:
        return

    timeline_id = state.get("timeline_id", "")
    if not timeline_id:
        logger.warning("No timeline_id in state for job %s, skipping checkpoint", job_id)
        return

    # Deferred imports keep module import light and avoid cycles.
    from .storage import save_checkpoint, save_stage_output
    from core.detect.stages.base import _REGISTRY

    merged = {**state, **result}

    # Serialize stage output using the stage's serialize_fn if available.
    stage_cls = _REGISTRY.get(stage_name)
    definition = getattr(stage_cls, "definition", None)
    serializer = getattr(definition, "serialize_fn", None)
    output_json = serializer(merged, job_id) if serializer else {}

    # stats may arrive as a dataclass; JSONB storage needs a plain dict.
    import dataclasses

    raw_stats = state.get("stats", {})
    if dataclasses.is_dataclass(raw_stats):
        stats_dict = dataclasses.asdict(raw_stats)
    else:
        stats_dict = raw_stats if isinstance(raw_stats, dict) else {}

    # Save checkpoint (lightweight tree node), chained to this job's previous one.
    checkpoint_id = save_checkpoint(
        timeline_id=timeline_id,
        stage_name=stage_name,
        parent_checkpoint_id=_latest_checkpoint.get(job_id),
        config_overrides=state.get("config_overrides"),
        stats=stats_dict,
        job_id=job_id,
    )
    _latest_checkpoint[job_id] = checkpoint_id

    # Stage output lives in its own row, upserted by (job, stage).
    if output_json:
        save_stage_output(
            job_id=job_id,
            timeline_id=timeline_id,
            stage_name=stage_name,
            output=output_json,
            checkpoint_id=checkpoint_id,
        )

    logger.info("Checkpoint %s + output for stage %s (job %s)", checkpoint_id, stage_name, job_id)
|
|
|
|
|
|
def get_latest_checkpoint(job_id: str) -> str | None:
    """Get the latest checkpoint_id for a running job, or None if none recorded."""
    try:
        return _latest_checkpoint[job_id]
    except KeyError:
        return None
|