"""
|
|
Pipeline runner — executes stages sequentially with checkpointing and cancellation.
|
|
|
|
Currently wraps LangGraph for execution. Will be replaced with a lean
|
|
custom runner in Phase 3, with an executor socket for distributed dispatch.
|
|
"""
|
|
|
|
from __future__ import annotations

import os
from collections.abc import Callable

from langgraph.graph import END, StateGraph

from detect.state import DetectState
from .nodes import NODES, NODE_FUNCTIONS


# --- Checkpoint wrapper ---

_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
_frames_manifest: dict[str, dict[int, str]] = {}  # job_id → manifest (cached per job)
_latest_checkpoint: dict[str, str] = {}  # job_id → latest checkpoint_id


class PipelineCancelled(Exception):
    """Raised when a pipeline run is cancelled."""


# Cancellation hook — set by the run endpoint, checked before each node
_cancel_check: dict[str, Callable[[], bool]] = {}


def set_cancel_check(job_id: str, fn: Callable[[], bool]) -> None:
    _cancel_check[job_id] = fn


def clear_cancel_check(job_id: str) -> None:
    _cancel_check.pop(job_id, None)
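

# Example wiring, as a sketch: a run endpoint would register a cancel check
# keyed by job_id before invoking the pipeline, then clear it afterwards.
# `JOBS` and `run_job` are hypothetical names, not part of this module:
#
#     JOBS: dict[str, bool] = {}  # job_id → cancelled flag
#
#     def run_job(job_id: str, initial_state: DetectState):
#         set_cancel_check(job_id, lambda: JOBS.get(job_id, False))
#         try:
#             return get_pipeline().invoke(initial_state)
#         except PipelineCancelled:
#             return None
#         finally:
#             clear_cancel_check(job_id)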


def _checkpointing_node(node_name: str, node_fn):
    """Wrap a node function to check cancellation before running and
    auto-checkpoint after completion."""

    def wrapper(state: DetectState) -> dict:
        job_id = state.get("job_id", "")

        # Honor a pending cancellation before starting this node
        check = _cancel_check.get(job_id)
        if check and check():
            raise PipelineCancelled(f"Cancelled before {node_name}")

        result = node_fn(state)

        # No job_id: ad-hoc run, nothing to checkpoint
        if not job_id:
            return result

        # Local imports defer the checkpoint dependencies until actually needed
        from detect.checkpoint import save_stage_output, save_frames
        from detect.stages.base import _REGISTRY

        merged = {**state, **result}

        # Save frames once (first node), reuse manifest after
        manifest = _frames_manifest.get(job_id)
        if manifest is None and node_name == "extract_frames":
            manifest = save_frames(job_id, merged.get("frames", []))
            _frames_manifest[job_id] = manifest

        # Serialize stage output using the stage's serialize_fn if available
        stage_cls = _REGISTRY.get(node_name)
        serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
        output_json = serialize_fn(merged, job_id) if serialize_fn else {}

        # Chain this checkpoint onto the job's latest one
        parent_id = _latest_checkpoint.get(job_id)
        new_checkpoint_id = save_stage_output(
            timeline_id=job_id,
            parent_checkpoint_id=parent_id,
            stage_name=node_name,
            output_json=output_json,
        )
        _latest_checkpoint[job_id] = new_checkpoint_id
        return result

    wrapper.__name__ = node_fn.__name__
    return wrapper
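

# Checkpoints form a parent-linked chain per job. For an illustrative
# three-stage run ("analyze" and "report" are made-up stage names):
#
#     extract_frames → cp1 (parent=None)
#     analyze        → cp2 (parent=cp1)
#     report         → cp3 (parent=cp2)
#
# This chain is what lets build_graph(start_from=...) resume mid-pipeline.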


# --- Graph construction ---

def build_graph(checkpoint: bool | None = None, start_from: str | None = None) -> StateGraph:
    """
    Build the pipeline graph.

    checkpoint: enable auto-checkpointing (default: MPR_CHECKPOINT env var)
    start_from: skip nodes before this stage (for replay)
    """
    do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED

    graph = StateGraph(DetectState)

    # Filter to start_from if replaying
    node_pairs = NODE_FUNCTIONS
    if start_from:
        start_idx = next((i for i, (name, _) in enumerate(NODE_FUNCTIONS) if name == start_from), None)
        if start_idx is None:
            raise ValueError(f"Unknown stage for start_from: {start_from!r}")
        node_pairs = NODE_FUNCTIONS[start_idx:]

    # Wrap each node with checkpointing if enabled
    for name, fn in node_pairs:
        wrapped = _checkpointing_node(name, fn) if do_checkpoint else fn
        graph.add_node(name, wrapped)

    # Wire edges: strictly sequential, first node is the entry point
    entry = node_pairs[0][0]
    graph.set_entry_point(entry)

    for i in range(len(node_pairs) - 1):
        graph.add_edge(node_pairs[i][0], node_pairs[i + 1][0])

    graph.add_edge(node_pairs[-1][0], END)

    return graph


def get_pipeline(checkpoint: bool | None = None):
    """Return a compiled, runnable pipeline."""
    return build_graph(checkpoint=checkpoint).compile()
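

# Minimal usage sketch. The state keys below ("job_id", "video_path") and the
# stage name "analyze" are assumptions for illustration, not guaranteed by
# this repo; the real keys live in DetectState and NODE_FUNCTIONS:
#
#     pipeline = get_pipeline(checkpoint=True)
#     final_state = pipeline.invoke({"job_id": "job-123", "video_path": "clip.mp4"})
#
#     # Replay from a later stage against state restored from checkpoints:
#     replay = build_graph(start_from="analyze").compile()
#     final_state = replay.invoke(restored_state)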