"""
Pipeline runner — executes stages sequentially with checkpointing and cancellation.

Currently wraps LangGraph for execution. Will be replaced with a lean custom
runner in Phase 3, with an executor socket for distributed dispatch.
"""

from __future__ import annotations

import os
from typing import Callable

from langgraph.graph import END, StateGraph

from detect.state import DetectState
from .nodes import NODES, NODE_FUNCTIONS


# --- Checkpoint wrapper ---

# Checkpointing is on by default only when MPR_CHECKPOINT is exactly "1"
# (surrounding whitespace ignored); build_graph() can override per call.
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"

# NOTE(review): nothing in this section evicts entries from these two caches,
# so they grow by one entry per job for the lifetime of the process.
_frames_manifest: dict[str, dict[int, str]] = {}  # job_id → manifest (cached per job)
_latest_checkpoint: dict[str, str] = {}  # job_id → latest checkpoint_id


class PipelineCancelled(Exception):
    """Raised when a pipeline run is cancelled."""


# Cancellation hook — set by the run endpoint, checked before each node
_cancel_check: dict[str, Callable[[], bool]] = {}


def set_cancel_check(job_id: str, fn: Callable[[], bool]) -> None:
    """Register the cancellation callback for ``job_id``.

    Wrapped nodes call ``fn()`` with no arguments before running; a truthy
    result makes them raise :class:`PipelineCancelled`.
    """
    _cancel_check[job_id] = fn


def clear_cancel_check(job_id: str) -> None:
    """Remove the cancellation callback for ``job_id`` (no-op if none is set)."""
    _cancel_check.pop(job_id, None)


def _checkpointing_node(node_name: str, node_fn):
    """Wrap a node function to auto-checkpoint after completion.

    The returned wrapper:

    1. raises :class:`PipelineCancelled` before running the node if the job's
       registered cancel check returns a truthy value;
    2. runs ``node_fn(state)``;
    3. if the state carries a non-empty ``job_id``, saves the stage output as
       a new checkpoint chained to the job's previous one.

    The node's own result is always returned unchanged.
    """

    def wrapper(state: DetectState) -> dict:
        job_id = state.get("job_id", "")

        check = _cancel_check.get(job_id)
        if check and check():
            raise PipelineCancelled(f"Cancelled before {node_name}")

        result = node_fn(state)

        # Without a job id there is nothing to key the checkpoint on.
        if not job_id:
            return result

        # Imported at call time, as in the original — presumably to avoid an
        # import cycle or import cost when checkpointing is unused; confirm
        # before hoisting to module level.
        from detect.checkpoint import save_stage_output, save_frames
        from detect.stages.base import _REGISTRY

        # View of the state as it will look once the node's update is applied.
        merged = {**state, **result}

        # Save frames once (first node), reuse manifest after
        manifest = _frames_manifest.get(job_id)
        if manifest is None and node_name == "extract_frames":
            manifest = save_frames(job_id, merged.get("frames", []))
            _frames_manifest[job_id] = manifest

        # Serialize stage output using the stage's serialize_fn if available
        stage_cls = _REGISTRY.get(node_name)
        definition = getattr(stage_cls, "definition", None)
        serialize_fn = getattr(definition, "serialize_fn", None)
        output_json = serialize_fn(merged, job_id) if serialize_fn else {}

        # Chain this checkpoint onto the job's previous one (None for the first).
        new_checkpoint_id = save_stage_output(
            timeline_id=job_id,
            parent_checkpoint_id=_latest_checkpoint.get(job_id),
            stage_name=node_name,
            output_json=output_json,
        )
        _latest_checkpoint[job_id] = new_checkpoint_id

        return result

    # Expose the wrapped node's name on the wrapper. Only __name__ is copied
    # (not functools.wraps) so the wrapper's own signature stays visible.
    wrapper.__name__ = node_fn.__name__
    return wrapper


# --- Graph construction ---


def build_graph(checkpoint: bool | None = None, start_from: str | None = None) -> StateGraph:
    """
    Build the pipeline graph.

    checkpoint: enable auto-checkpointing (default: MPR_CHECKPOINT env var)
    start_from: skip nodes before this stage (for replay)

    Raises ValueError if ``start_from`` does not name a stage in NODE_FUNCTIONS.

    NOTE(review): the cancellation check lives inside the checkpoint wrapper,
    so nodes are only cancellable when checkpointing is enabled.
    """
    do_checkpoint = _CHECKPOINT_ENABLED if checkpoint is None else checkpoint

    graph = StateGraph(DetectState)

    # Filter to start_from if replaying
    node_pairs = NODE_FUNCTIONS
    if start_from:
        start_idx = next(
            (i for i, (name, _) in enumerate(NODE_FUNCTIONS) if name == start_from),
            None,
        )
        if start_idx is None:
            # A bare next() would leak StopIteration here; fail with a
            # message that names the valid stages instead.
            known = ", ".join(name for name, _ in NODE_FUNCTIONS)
            raise ValueError(
                f"Unknown start_from stage {start_from!r}; expected one of: {known}"
            )
        node_pairs = NODE_FUNCTIONS[start_idx:]

    for name, fn in node_pairs:
        wrapped = _checkpointing_node(name, fn) if do_checkpoint else fn
        graph.add_node(name, wrapped)

    # Wire edges: entry → each stage in order → END
    graph.set_entry_point(node_pairs[0][0])
    for (src, _), (dst, _) in zip(node_pairs, node_pairs[1:]):
        graph.add_edge(src, dst)
    graph.add_edge(node_pairs[-1][0], END)

    return graph


def get_pipeline(checkpoint: bool | None = None, start_from: str | None = None):
    """Return a compiled, runnable pipeline.

    checkpoint: forwarded to build_graph (default: MPR_CHECKPOINT env var)
    start_from: forwarded to build_graph; skip nodes before this stage
    """
    return build_graph(checkpoint=checkpoint, start_from=start_from).compile()