"""
Runner bridge — checkpoint hook called by PipelineRunner after each stage.

Saves a checkpoint + stage output after each stage completes.
Timeline and Job are independent: timeline_id and job_id come from
the pipeline state (set at job creation time).
"""

from __future__ import annotations

import dataclasses
import logging

logger = logging.getLogger(__name__)

# Per-job state: tracks the latest checkpoint so we can chain parent → child
_latest_checkpoint: dict[str, str] = {}


def reset_checkpoint_state(job_id: str) -> None:
    """Clean up per-job checkpoint state. Called when pipeline finishes.

    Unknown job ids are ignored, so calling this more than once is harmless.
    """
    _latest_checkpoint.pop(job_id, None)


def _resolve_serialize_fn(stage_name: str):
    """Return the serialize_fn registered for ``stage_name``, or None.

    Checks the new-style registry first, then the legacy one (some stages
    are in both).
    """
    # Imported lazily, as in the original code — presumably to avoid an
    # import cycle with the stage package; confirm before hoisting.
    from core.detect.stages.base import _REGISTRY, _LEGACY_REGISTRY

    serialize_fn = None

    stage_cls = _REGISTRY.get(stage_name)
    if stage_cls:
        definition = getattr(stage_cls, "definition", None)
        serialize_fn = getattr(definition, "serialize_fn", None)

    if not serialize_fn:
        legacy = _LEGACY_REGISTRY.get(stage_name)
        if legacy:
            # getattr keeps this lookup as tolerant as the new-style one:
            # a legacy entry without serialize_fn means "no serializer",
            # not an AttributeError in the middle of a checkpoint.
            serialize_fn = getattr(legacy, "serialize_fn", None)

    return serialize_fn


def _stats_as_dict(raw_stats) -> dict:
    """Convert the pipeline ``stats`` value to a plain dict for JSONB storage.

    Dataclass instances are converted with ``dataclasses.asdict``; dicts are
    passed through unchanged; anything else (including None) becomes ``{}``.
    """
    # is_dataclass() is also True for a dataclass *class*, on which asdict()
    # raises TypeError — only convert actual instances.
    if dataclasses.is_dataclass(raw_stats) and not isinstance(raw_stats, type):
        return dataclasses.asdict(raw_stats)
    if isinstance(raw_stats, dict):
        return raw_stats
    return {}


def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: dict) -> None:
    """
    Save a checkpoint + stage output after a stage completes.

    Called by the runner. Handles:
    - Stage output serialization (via stage registry)
    - Checkpoint chain (parent → child)
    - Stage output as separate row in StageOutput table

    Args:
        job_id: Id of the running job. A falsy value makes this a no-op.
        stage_name: Name used to look up the stage's serialize_fn.
        state: Pipeline state before the stage; must carry "timeline_id",
            otherwise a warning is logged and nothing is saved.
        result: The stage's output; overlaid on ``state`` before being
            handed to the stage's serialize_fn.
    """
    if not job_id:
        return

    timeline_id = state.get("timeline_id", "")
    if not timeline_id:
        logger.warning("No timeline_id in state for job %s, skipping checkpoint", job_id)
        return

    # Lazy project import, kept function-scoped as in the original.
    from .storage import save_checkpoint, save_stage_output

    # Keys in result win over keys in state.
    merged = {**state, **result}

    # Serialize stage output using the stage's serialize_fn if available
    serialize_fn = _resolve_serialize_fn(stage_name)
    output_json = serialize_fn(merged, job_id) if serialize_fn else {}

    # NOTE(review): stats and config_overrides are read from ``state``, not
    # from ``merged`` — values returned by the stage in ``result`` are not
    # reflected here. Presumably intentional; confirm.
    stats_dict = _stats_as_dict(state.get("stats", {}))

    # Save checkpoint (lightweight tree node), chained to the previous
    # checkpoint recorded for this job (None for the first stage).
    parent_id = _latest_checkpoint.get(job_id)
    checkpoint_id = save_checkpoint(
        timeline_id=timeline_id,
        stage_name=stage_name,
        parent_checkpoint_id=parent_id,
        config_overrides=state.get("config_overrides"),
        stats=stats_dict,
        job_id=job_id,
    )
    _latest_checkpoint[job_id] = checkpoint_id

    # Save stage output (separate row, upsert by job+stage); skipped when
    # the stage has no serializer or serialized to something empty.
    if output_json:
        save_stage_output(
            job_id=job_id,
            timeline_id=timeline_id,
            stage_name=stage_name,
            output=output_json,
            checkpoint_id=checkpoint_id,
        )

    logger.info("Checkpoint %s + output for stage %s (job %s)", checkpoint_id, stage_name, job_id)


def get_latest_checkpoint(job_id: str) -> str | None:
    """Get the latest checkpoint_id for a running job.

    Returns None when no checkpoint has been recorded for ``job_id`` (or
    its state has already been reset).
    """
    return _latest_checkpoint.get(job_id)