"""
Pipeline runner — executes stages sequentially with checkpointing and cancellation.

Currently wraps LangGraph for execution. Will be replaced with a lean
custom runner in Phase 3, with an executor socket for distributed dispatch.
"""

from __future__ import annotations

import logging
import os
import threading
from collections.abc import Callable

from langgraph.graph import END, StateGraph

from detect.state import DetectState
from .nodes import NODES, NODE_FUNCTIONS

logger = logging.getLogger(__name__)


# --- Checkpoint wrapper ---

# Checkpointing is on by default only when MPR_CHECKPOINT is exactly "1".
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"

# NOTE(review): nothing in this module removes entries from the two caches
# below, so they grow by one entry per job — confirm cleanup happens elsewhere.
_frames_manifest: dict[str, dict[int, str]] = {}  # job_id → manifest (cached per job)
_latest_checkpoint: dict[str, str] = {}  # job_id → latest checkpoint_id


class PipelineCancelled(Exception):
    """Raised when a pipeline run is cancelled."""


class PipelinePaused(Exception):
    """Raised when a pipeline is paused (internally, for flow control)."""


# ---------------------------------------------------------------------------
# Cancellation — checked before each node
# ---------------------------------------------------------------------------

# job_id → zero-argument callable; a truthy return value means "cancel".
_cancel_check: dict[str, Callable[[], bool]] = {}


def set_cancel_check(job_id: str, fn: Callable[[], bool]) -> None:
    """Register ``fn`` as the cancellation probe for ``job_id``.

    ``fn`` is called with no arguments before each node runs and while the
    pipeline is blocked on a pause; a truthy result aborts the run with
    ``PipelineCancelled``.
    """
    _cancel_check[job_id] = fn


def clear_cancel_check(job_id: str) -> None:
    """Remove the cancellation probe for ``job_id`` (no-op if none is set)."""
    _cancel_check.pop(job_id, None)


# ---------------------------------------------------------------------------
# Pause / Resume / Step — checked after each node completes
#
# _pause_gate: threading.Event per job. When cleared, the runner blocks.
#   When set, the runner proceeds to the next node.
# _pause_after_stage: if True, automatically clear the gate after each node.
# ---------------------------------------------------------------------------

_pause_gate: dict[str, threading.Event] = {}
_pause_after_stage: dict[str, bool] = {}


def init_pause(job_id: str, pause_after_stage: bool = False) -> None:
    """Initialize pause state for a job. Called when pipeline starts."""
    gate = threading.Event()
    gate.set()  # start unpaused
    _pause_gate[job_id] = gate
    _pause_after_stage[job_id] = pause_after_stage


def clear_pause(job_id: str) -> None:
    """Clean up pause state. Called when pipeline finishes."""
    _pause_gate.pop(job_id, None)
    _pause_after_stage.pop(job_id, None)


def pause_pipeline(job_id: str) -> None:
    """Pause a running pipeline. It will block after the current stage completes."""
    gate = _pause_gate.get(job_id)
    if gate:
        gate.clear()
        logger.info("Pipeline %s paused", job_id)


def resume_pipeline(job_id: str) -> None:
    """Resume a paused pipeline."""
    gate = _pause_gate.get(job_id)
    if gate:
        gate.set()
        logger.info("Pipeline %s resumed", job_id)


def step_pipeline(job_id: str) -> None:
    """Run one stage then pause again.

    Turns pause-after-stage mode on for the job; it stays on until
    ``set_pause_after_stage(job_id, False)`` is called.
    """
    _pause_after_stage[job_id] = True
    gate = _pause_gate.get(job_id)
    if gate:
        gate.set()  # unblock for one stage, _pause_after_stage re-pauses after
        logger.info("Pipeline %s stepping", job_id)


def set_pause_after_stage(job_id: str, enabled: bool) -> None:
    """Toggle pause-after-each-stage mode."""
    _pause_after_stage[job_id] = enabled
    if not enabled:
        # If disabling, also resume in case we're currently paused
        gate = _pause_gate.get(job_id)
        if gate:
            gate.set()


def is_paused(job_id: str) -> bool:
    """Check if a pipeline is currently paused.

    Returns False for jobs that have no pause state registered.
    """
    gate = _pause_gate.get(job_id)
    return gate is not None and not gate.is_set()


def _wait_if_paused(job_id: str, node_name: str) -> None:
    """Block until resumed. Called after each node completes.

    Returns immediately when the job has no pause gate. While blocked, the
    job's cancel check is polled every 0.5 s.

    Raises:
        PipelineCancelled: the cancel check returned truthy while waiting.
    """
    gate = _pause_gate.get(job_id)
    if gate is None:
        return

    # If pause-after-stage is on, pause now
    if _pause_after_stage.get(job_id, False):
        gate.clear()
        from detect import emit

        emit.log(job_id, "Pipeline", "INFO", f"Paused after {node_name}")

    # Block until gate is set (resume/step) or cancelled
    while not gate.wait(timeout=0.5):
        check = _cancel_check.get(job_id)
        if check and check():
            raise PipelineCancelled("Cancelled while paused before next stage")


def _checkpointing_node(node_name: str, node_fn):
    """Wrap a node function to auto-checkpoint after completion.

    The returned wrapper:
      1. raises ``PipelineCancelled`` if the job's cancel check is truthy,
      2. runs ``node_fn(state)``,
      3. if the state carries a non-empty ``job_id``: saves frames (once, on
         the ``extract_frames`` node), saves the stage output as a new
         checkpoint chained to the previous one, then blocks while paused,
      4. returns the node's result unchanged.
    """

    def wrapper(state: DetectState) -> dict:
        job_id = state.get("job_id", "")

        check = _cancel_check.get(job_id)
        if check and check():
            raise PipelineCancelled(f"Cancelled before {node_name}")

        result = node_fn(state)

        # Without a job id there is nothing to key a checkpoint or pause on.
        if not job_id:
            return result

        from detect.checkpoint import save_stage_output, save_frames
        from detect.stages.base import _REGISTRY

        merged = {**state, **result}

        # Save frames once (first node), reuse manifest after
        manifest = _frames_manifest.get(job_id)
        if manifest is None and node_name == "extract_frames":
            manifest = save_frames(job_id, merged.get("frames", []))
            _frames_manifest[job_id] = manifest

        # Serialize stage output using the stage's serialize_fn if available
        stage_cls = _REGISTRY.get(node_name)
        definition = getattr(stage_cls, "definition", None)
        serialize_fn = getattr(definition, "serialize_fn", None)
        output_json = serialize_fn(merged, job_id) if serialize_fn else {}

        new_checkpoint_id = save_stage_output(
            timeline_id=job_id,
            parent_checkpoint_id=_latest_checkpoint.get(job_id),
            stage_name=node_name,
            output_json=output_json,
        )
        _latest_checkpoint[job_id] = new_checkpoint_id

        # Pause check — blocks if paused, respects cancel while waiting
        _wait_if_paused(job_id, node_name)

        return result

    # Only the name is copied. functools.wraps is deliberately not used: it
    # would set __wrapped__, which changes what inspect.signature() reports
    # for the node.
    wrapper.__name__ = node_fn.__name__
    return wrapper


# --- Graph construction ---


def build_graph(
    checkpoint: bool | None = None,
    start_from: str | None = None,
) -> StateGraph:
    """
    Build the pipeline graph.

    checkpoint: enable auto-checkpointing (default: MPR_CHECKPOINT env var)
    start_from: skip nodes before this stage (for replay)

    Raises:
        ValueError: ``start_from`` names no stage in NODE_FUNCTIONS, or there
            are no stages to wire.
    """
    do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED

    # Filter to start_from if replaying
    node_pairs = NODE_FUNCTIONS
    if start_from:
        start_idx = next(
            (i for i, (name, _) in enumerate(NODE_FUNCTIONS) if name == start_from),
            None,
        )
        if start_idx is None:
            known = ", ".join(name for name, _ in NODE_FUNCTIONS)
            raise ValueError(
                f"Unknown start_from stage {start_from!r}; expected one of: {known}"
            )
        node_pairs = NODE_FUNCTIONS[start_idx:]

    if not node_pairs:
        raise ValueError("Pipeline has no stages to run")

    graph = StateGraph(DetectState)

    for name, fn in node_pairs:
        wrapped = _checkpointing_node(name, fn) if do_checkpoint else fn
        graph.add_node(name, wrapped)

    # Wire edges: a straight chain entry → ... → last → END
    names = [name for name, _ in node_pairs]
    graph.set_entry_point(names[0])
    for src, dst in zip(names, names[1:]):
        graph.add_edge(src, dst)
    graph.add_edge(names[-1], END)

    return graph


def get_pipeline(checkpoint: bool | None = None, start_from: str | None = None):
    """Return a compiled, runnable pipeline.

    checkpoint: forwarded to build_graph
    start_from: forwarded to build_graph (skip nodes before this stage)
    """
    return build_graph(checkpoint=checkpoint, start_from=start_from).compile()