"""
Pipeline runner — executes stages sequentially with checkpointing and cancellation.

Currently wraps LangGraph for execution. Will be replaced with a lean
custom runner in Phase 3, with an executor socket for distributed dispatch.
"""
|
|
|
|
from __future__ import annotations

import functools
import logging
import os
import threading
from typing import Callable

from langgraph.graph import END, StateGraph

from detect.state import DetectState

from .nodes import NODES, NODE_FUNCTIONS
|
|
|
|
logger = logging.getLogger(__name__)


# --- Checkpoint wrapper ---

# Auto-checkpointing is opt-in: set MPR_CHECKPOINT to exactly "1"
# (surrounding whitespace ignored).
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
# Per-job caches, keyed by job_id. Frames are saved once per job (after the
# "extract_frames" node) and the manifest is reused for later checkpoints.
_frames_manifest: dict[str, dict[int, str]] = {}  # job_id → manifest (cached per job)
# job_id → id of the most recent checkpoint; used as parent for the next one.
_latest_checkpoint: dict[str, str] = {}  # job_id → latest checkpoint_id
|
|
|
|
|
|
class PipelineCancelled(Exception):
    """Signal that a pipeline run was cancelled and should stop immediately."""
|
|
|
|
|
|
class PipelinePaused(Exception):
    """Internal flow-control signal: the pipeline has been paused."""
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Cancellation — checked before each node
# ---------------------------------------------------------------------------

# job_id → zero-argument predicate; a truthy return means "cancel this job".
# Polled before each node and while blocked in a pause (see _wait_if_paused).
_cancel_check: dict[str, Callable[[], bool]] = {}
|
|
|
|
|
|
def set_cancel_check(job_id: str, fn):
    """Register *fn* as the cancellation predicate for *job_id*."""
    _cancel_check[job_id] = fn
|
|
|
|
|
|
def clear_cancel_check(job_id: str):
    """Remove the cancellation predicate for *job_id*, if one is registered."""
    # pop with a default is atomic and tolerant of a missing key.
    _cancel_check.pop(job_id, None)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Pause / Resume / Step — checked after each node completes
#
# _pause_gate: threading.Event per job. When cleared, the runner blocks.
# When set, the runner proceeds to the next node.
# _pause_after_stage: if True, automatically clear the gate after each node
# (single-step mode — see _wait_if_paused).
# ---------------------------------------------------------------------------

_pause_gate: dict[str, threading.Event] = {}
_pause_after_stage: dict[str, bool] = {}
|
|
|
|
|
|
def init_pause(job_id: str, pause_after_stage: bool = False):
    """Initialize pause state for a job. Called when pipeline starts."""
    event = threading.Event()
    event.set()  # begin in the running (unpaused) state
    _pause_gate[job_id] = event
    _pause_after_stage[job_id] = pause_after_stage
|
|
|
|
|
|
def clear_pause(job_id: str):
    """Clean up pause state. Called when pipeline finishes."""
    for table in (_pause_gate, _pause_after_stage):
        table.pop(job_id, None)
|
|
|
|
|
|
def pause_pipeline(job_id: str):
    """Pause a running pipeline. It will block after the current stage completes."""
    event = _pause_gate.get(job_id)
    if event is None:
        return
    event.clear()
    logger.info("Pipeline %s paused", job_id)
|
|
|
|
|
|
def resume_pipeline(job_id: str):
    """Resume a paused pipeline."""
    event = _pause_gate.get(job_id)
    if event is None:
        return
    event.set()
    logger.info("Pipeline %s resumed", job_id)
|
|
|
|
|
|
def step_pipeline(job_id: str):
    """Run one stage then pause again.

    No-op if the job has no pause gate (finished or unknown job) — previously
    this still wrote `_pause_after_stage[job_id] = True`, leaking a stale
    entry for jobs that would never be cleaned up by clear_pause().
    """
    gate = _pause_gate.get(job_id)
    if gate is None:
        return
    # Re-arm single-step mode so _wait_if_paused re-pauses after the next stage.
    _pause_after_stage[job_id] = True
    gate.set()  # unblock for one stage, _pause_after_stage re-pauses after
    logger.info("Pipeline %s stepping", job_id)
|
|
|
|
|
|
def set_pause_after_stage(job_id: str, enabled: bool):
    """Toggle pause-after-each-stage mode."""
    _pause_after_stage[job_id] = enabled
    if enabled:
        return
    # Disabling single-step mode: also release the gate in case the
    # pipeline is currently blocked waiting on it.
    event = _pause_gate.get(job_id)
    if event is not None:
        event.set()
|
|
|
|
|
|
def is_paused(job_id: str) -> bool:
    """Check if a pipeline is currently paused."""
    event = _pause_gate.get(job_id)
    if event is None:
        # No gate registered — the job is not running, hence not paused.
        return False
    return not event.is_set()
|
|
|
|
|
|
def _wait_if_paused(job_id: str, node_name: str):
    """Block until resumed. Called after each node completes.

    In single-step mode (_pause_after_stage) the gate is cleared here so the
    runner pauses right after *node_name*. While blocked, the job's cancel
    predicate is polled every 0.5s so cancellation still works mid-pause.

    Raises:
        PipelineCancelled: if the job is cancelled while waiting.
    """
    gate = _pause_gate.get(job_id)
    if gate is None:
        return

    # If pause-after-stage is on, pause now
    if _pause_after_stage.get(job_id, False):
        gate.clear()
        from detect import emit
        emit.log(job_id, "Pipeline", "INFO", f"Paused after {node_name}")

    # Block until gate is set (resume/step) or cancelled
    while not gate.wait(timeout=0.5):
        check = _cancel_check.get(job_id)
        if check and check():
            # Was an f-string with no placeholders (ruff F541) — plain literal.
            raise PipelineCancelled("Cancelled while paused before next stage")
|
|
|
|
|
|
def _checkpointing_node(node_name: str, node_fn):
    """Wrap a node function to auto-checkpoint after completion.

    The wrapper, in order:
      1. raises PipelineCancelled if the job's cancel predicate fires,
      2. runs the node,
      3. persists frames once (after "extract_frames"; manifest cached per job)
         and the stage output as a checkpoint chained to the previous one,
      4. blocks if the pipeline is paused (respecting cancel while waiting).
    """

    @functools.wraps(node_fn)  # replaces manual __name__ copy; also keeps doc/qualname
    def wrapper(state: DetectState) -> dict:
        job_id = state.get("job_id", "")
        check = _cancel_check.get(job_id)
        if check and check():
            raise PipelineCancelled(f"Cancelled before {node_name}")

        result = node_fn(state)

        # Without a job id there is nothing to checkpoint against.
        # (The original re-fetched job_id from state here; it cannot have
        # changed, so the duplicate lookup was removed.)
        if not job_id:
            return result

        # Imported lazily — presumably to avoid an import cycle at module
        # load time; confirm before hoisting to the top of the file.
        from detect.checkpoint import save_stage_output, save_frames
        from detect.stages.base import _REGISTRY

        merged = {**state, **result}

        # Save frames once (first node), reuse manifest after
        manifest = _frames_manifest.get(job_id)
        if manifest is None and node_name == "extract_frames":
            manifest = save_frames(job_id, merged.get("frames", []))
            _frames_manifest[job_id] = manifest

        # Serialize stage output using the stage's serialize_fn if available
        stage_cls = _REGISTRY.get(node_name)
        serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
        output_json = serialize_fn(merged, job_id) if serialize_fn else {}

        parent_id = _latest_checkpoint.get(job_id)
        new_checkpoint_id = save_stage_output(
            timeline_id=job_id,
            parent_checkpoint_id=parent_id,
            stage_name=node_name,
            output_json=output_json,
        )
        _latest_checkpoint[job_id] = new_checkpoint_id

        # Pause check — blocks if paused, respects cancel while waiting
        _wait_if_paused(job_id, node_name)

        return result

    return wrapper
|
|
|
|
|
|
# --- Graph construction ---
|
|
|
|
def build_graph(checkpoint: bool | None = None, start_from: str | None = None) -> StateGraph:
    """
    Build the pipeline graph.

    checkpoint: enable auto-checkpointing (default: MPR_CHECKPOINT env var)
    start_from: skip nodes before this stage (for replay)

    Raises:
        ValueError: if start_from does not name a known stage.
    """
    do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED

    graph = StateGraph(DetectState)

    # Filter to start_from if replaying
    node_pairs = NODE_FUNCTIONS
    if start_from:
        names = [name for name, _ in NODE_FUNCTIONS]
        if start_from not in names:
            # Was a bare next(...) that raised an opaque StopIteration for an
            # unknown stage name; fail loudly with the valid choices instead.
            raise ValueError(f"Unknown stage {start_from!r}; expected one of {names}")
        node_pairs = NODE_FUNCTIONS[names.index(start_from):]

    for name, fn in node_pairs:
        wrapped = _checkpointing_node(name, fn) if do_checkpoint else fn
        graph.add_node(name, wrapped)

    # Wire edges: a linear chain from the first node through to END.
    entry = node_pairs[0][0]
    graph.set_entry_point(entry)

    for i in range(len(node_pairs) - 1):
        graph.add_edge(node_pairs[i][0], node_pairs[i + 1][0])

    graph.add_edge(node_pairs[-1][0], END)

    return graph
|
|
|
|
|
|
def get_pipeline(checkpoint: bool | None = None):
    """Return a compiled, runnable pipeline."""
    graph = build_graph(checkpoint=checkpoint)
    return graph.compile()
|