""" Pipeline runner — executes stages sequentially with checkpointing, cancellation, and pause/resume. Reads PipelineConfig from the profile to determine what stages to run. Flattens the graph into a linear sequence for now (serial execution). Executor socket: all stages run via LocalExecutor (call function directly). """ from __future__ import annotations import logging import os import threading from core.schema.models.pipeline_config import PipelineConfig from detect.state import DetectState from .nodes import NODES, NODE_FUNCTIONS logger = logging.getLogger(__name__) _CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1" class PipelineCancelled(Exception): """Raised when a pipeline run is cancelled.""" pass class PipelinePaused(Exception): """Raised when a pipeline is paused (internally, for flow control).""" pass # --------------------------------------------------------------------------- # Cancellation — checked before each node # --------------------------------------------------------------------------- _cancel_check: dict[str, callable] = {} def set_cancel_check(job_id: str, fn): _cancel_check[job_id] = fn def clear_cancel_check(job_id: str): _cancel_check.pop(job_id, None) # --------------------------------------------------------------------------- # Pause / Resume / Step — checked after each node completes # --------------------------------------------------------------------------- _pause_gate: dict[str, threading.Event] = {} _pause_after_stage: dict[str, bool] = {} def init_pause(job_id: str, pause_after_stage: bool = False): """Initialize pause state for a job. Called when pipeline starts.""" gate = threading.Event() gate.set() # start unpaused _pause_gate[job_id] = gate _pause_after_stage[job_id] = pause_after_stage def clear_pause(job_id: str): """Clean up pause state. Called when pipeline finishes.""" _pause_gate.pop(job_id, None) _pause_after_stage.pop(job_id, None) def pause_pipeline(job_id: str): """Pause a running pipeline. It will block after the current stage completes.""" gate = _pause_gate.get(job_id) if gate: gate.clear() logger.info("Pipeline %s paused", job_id) def resume_pipeline(job_id: str): """Resume a paused pipeline.""" gate = _pause_gate.get(job_id) if gate: gate.set() logger.info("Pipeline %s resumed", job_id) def step_pipeline(job_id: str): """Run one stage then pause again.""" _pause_after_stage[job_id] = True gate = _pause_gate.get(job_id) if gate: gate.set() logger.info("Pipeline %s stepping", job_id) def set_pause_after_stage(job_id: str, enabled: bool): """Toggle pause-after-each-stage mode.""" _pause_after_stage[job_id] = enabled if not enabled: gate = _pause_gate.get(job_id) if gate: gate.set() def is_paused(job_id: str) -> bool: """Check if a pipeline is currently paused.""" gate = _pause_gate.get(job_id) return gate is not None and not gate.is_set() def _wait_if_paused(job_id: str, node_name: str): """Block until resumed. Called after each node completes.""" gate = _pause_gate.get(job_id) if gate is None: return if _pause_after_stage.get(job_id, False): gate.clear() from detect import emit emit.log(job_id, "Pipeline", "INFO", f"Paused after {node_name}") while not gate.wait(timeout=0.5): check = _cancel_check.get(job_id) if check and check(): raise PipelineCancelled(f"Cancelled while paused before next stage") # --------------------------------------------------------------------------- # Pipeline Runner # --------------------------------------------------------------------------- # Node function lookup — maps stage name to callable _NODE_FN_MAP: dict[str, callable] = {name: fn for name, fn in NODE_FUNCTIONS} def _flatten_config(config: PipelineConfig, start_from: str | None = None) -> list[str]: """ Flatten a PipelineConfig into a linear stage sequence. For now: topological sort via edges. Falls back to stage order if no edges. Respects start_from for replay (skip stages before it). """ if not config.edges: # No edges defined — use stage order as-is names = [s.name for s in config.stages] else: # Topological sort from edges graph: dict[str, list[str]] = {} in_degree: dict[str, int] = {} stage_names = {s.name for s in config.stages} for name in stage_names: graph[name] = [] in_degree[name] = 0 for edge in config.edges: if edge.source in stage_names and edge.target in stage_names: graph[edge.source].append(edge.target) in_degree[edge.target] = in_degree.get(edge.target, 0) + 1 # Kahn's algorithm queue = [n for n in stage_names if in_degree.get(n, 0) == 0] # Stable sort: prefer order from config.stages stage_order = {s.name: i for i, s in enumerate(config.stages)} queue.sort(key=lambda n: stage_order.get(n, 999)) names = [] while queue: node = queue.pop(0) names.append(node) for neighbor in graph.get(node, []): in_degree[neighbor] -= 1 if in_degree[neighbor] == 0: queue.append(neighbor) queue.sort(key=lambda n: stage_order.get(n, 999)) if start_from: try: idx = names.index(start_from) names = names[idx:] except ValueError: raise ValueError(f"Stage {start_from!r} not in pipeline config") return names class PipelineRunner: """ Executes a pipeline defined by PipelineConfig. Runs stages sequentially (flattened). Each stage: 1. Check cancel 2. Run node function (via executor — local for now) 3. Merge result into state 4. Checkpoint (if enabled) 5. Check pause Executor socket: currently calls node functions directly. Future: dispatch to LocalExecutor / GrpcExecutor / LambdaExecutor based on StageRef.execution_target. """ def __init__( self, config: PipelineConfig, checkpoint: bool = False, start_from: str | None = None, ): self.config = config self.do_checkpoint = checkpoint self.stage_sequence = _flatten_config(config, start_from) def invoke(self, state: DetectState) -> DetectState: """Run the pipeline on the given state. Returns final state.""" for stage_name in self.stage_sequence: job_id = state.get("job_id", "") # 1. Cancel check check = _cancel_check.get(job_id) if check and check(): raise PipelineCancelled(f"Cancelled before {stage_name}") # 2. Run node function node_fn = _NODE_FN_MAP.get(stage_name) if node_fn is None: logger.warning("No node function for stage %s, skipping", stage_name) continue result = node_fn(state) # 3. Merge result into state state.update(result) # 4. Checkpoint if self.do_checkpoint: from detect.checkpoint import checkpoint_after_stage checkpoint_after_stage(job_id, stage_name, state, result) # 5. Pause check _wait_if_paused(job_id, stage_name) return state # --------------------------------------------------------------------------- # Public API — backwards compatible with old get_pipeline/build_graph # --------------------------------------------------------------------------- def get_pipeline( checkpoint: bool | None = None, profile_name: str = "soccer_broadcast", start_from: str | None = None, ) -> PipelineRunner: """Return a PipelineRunner for the given profile.""" from detect.profiles import get_profile do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED profile = get_profile(profile_name) config = profile.pipeline_config() return PipelineRunner( config=config, checkpoint=do_checkpoint, start_from=start_from, ) def build_graph(checkpoint: bool | None = None, start_from: str | None = None): """Backwards-compatible wrapper. Returns a PipelineRunner.""" return get_pipeline(checkpoint=checkpoint, start_from=start_from)