Files
mediaproc/detect/graph/runner.py
2026-03-28 09:40:07 -03:00

230 lines
7.0 KiB
Python

"""
Pipeline runner — executes stages sequentially with checkpointing, cancellation, and pause/resume/step control.
Currently wraps LangGraph for execution. Will be replaced with a lean
custom runner in Phase 3, with an executor socket for distributed dispatch.
"""
from __future__ import annotations

import logging
import os
import threading
from collections.abc import Callable

from langgraph.graph import END, StateGraph

from detect.state import DetectState
from .nodes import NODES, NODE_FUNCTIONS
logger = logging.getLogger(__name__)
# --- Checkpoint wrapper ---
# Default used by build_graph() when its `checkpoint` argument is None:
# checkpointing is on only when the MPR_CHECKPOINT environment variable is
# exactly "1" (surrounding whitespace ignored). Evaluated once, at import time.
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
# job_id → frames manifest as returned by save_frames (cached per job)
_frames_manifest: dict[str, dict[int, str]] = {}
# job_id → latest checkpoint_id; passed as the parent of the job's next checkpoint
_latest_checkpoint: dict[str, str] = {}
class PipelineCancelled(Exception):
    """Raised when a pipeline run is cancelled.

    Within this module it is raised when a job's registered cancel check
    returns a truthy value, either right before a node runs or while the
    runner is blocked waiting on a pause.
    """
class PipelinePaused(Exception):
    """Raised when a pipeline is paused (internally, for flow control).

    NOTE(review): nothing in the visible part of this module raises it —
    pausing is implemented by blocking on a per-job event instead. Confirm
    whether any caller still depends on this exception.
    """
# ---------------------------------------------------------------------------
# Cancellation — checked before each node and while blocked on a pause
# ---------------------------------------------------------------------------
# job_id → zero-argument callable; a truthy return value means "cancel this job".
# (Annotated with Callable: the builtin `callable` is a function, not a type.)
_cancel_check: dict[str, Callable[[], bool]] = {}
def set_cancel_check(job_id: str, fn):
    """Register ``fn`` as the cancellation probe for ``job_id``.

    The runner calls ``fn()`` with no arguments and treats a truthy result as
    a request to cancel the job. Registering again for the same ``job_id``
    replaces the previously registered probe.
    """
    _cancel_check[job_id] = fn
def clear_cancel_check(job_id: str):
    """Drop the cancellation probe registered for ``job_id``.

    Does nothing when no probe is registered for that job.
    """
    _cancel_check.pop(job_id, None)
# ---------------------------------------------------------------------------
# Pause / Resume / Step — checked after each node completes
#
# _pause_gate: one threading.Event per job. While the event is cleared, the
#   runner blocks after the node that just finished. Once it is set, the
#   runner proceeds to the next node.
# _pause_after_stage: per-job flag. If True, the gate is cleared
#   automatically after each node, so the job pauses after every stage.
# ---------------------------------------------------------------------------
# job_id → gate event (set = running, cleared = paused)
_pause_gate: dict[str, threading.Event] = {}
# job_id → whether to re-pause automatically after every stage
_pause_after_stage: dict[str, bool] = {}
def init_pause(job_id: str, pause_after_stage: bool = False):
    """Create fresh pause state for ``job_id``; call when the pipeline starts.

    The job begins unpaused. Any pause state previously stored for the same
    ``job_id`` is overwritten.

    Args:
        job_id: identifier of the job being started.
        pause_after_stage: when True, the job pauses after every stage.
    """
    run_gate = threading.Event()
    # A set event means "keep running"; a cleared one means "paused".
    run_gate.set()
    _pause_gate[job_id] = run_gate
    _pause_after_stage[job_id] = pause_after_stage
def clear_pause(job_id: str):
    """Forget all pause state for ``job_id``; call when the pipeline finishes.

    Safe to call for a job that has no pause state.
    """
    for per_job_state in (_pause_gate, _pause_after_stage):
        per_job_state.pop(job_id, None)
def pause_pipeline(job_id: str):
    """Request a pause for a running pipeline.

    The stage that is currently executing is not interrupted; the runner
    blocks once that stage has completed. Does nothing (and logs nothing)
    when the job has no pause state.
    """
    gate = _pause_gate.get(job_id)
    if gate is None:
        return
    gate.clear()
    logger.info("Pipeline %s paused", job_id)
def resume_pipeline(job_id: str):
    """Let a paused pipeline continue.

    Does nothing (and logs nothing) when the job has no pause state. The
    pause-after-stage flag is left untouched, so a job in that mode will
    pause again after its next stage.
    """
    gate = _pause_gate.get(job_id)
    if gate is None:
        return
    gate.set()
    logger.info("Pipeline %s resumed", job_id)
def step_pipeline(job_id: str):
    """Run one stage then pause again.

    Turns on pause-after-stage mode for the job and opens its gate, so the
    runner executes the next stage and then blocks again. The mode stays on
    until it is switched off with ``set_pause_after_stage(job_id, False)``.

    Does nothing when the job has no pause state (not started, or already
    finished).
    """
    gate = _pause_gate.get(job_id)
    if gate is None:
        # Unknown or finished job: do not record a flag here, because
        # clear_pause() has either already run or will never be called for
        # it, and the entry would linger in _pause_after_stage.
        return
    _pause_after_stage[job_id] = True
    gate.set()  # unblock for one stage, _pause_after_stage re-pauses after
    logger.info("Pipeline %s stepping", job_id)
def set_pause_after_stage(job_id: str, enabled: bool):
    """Toggle pause-after-each-stage mode for a job.

    Args:
        job_id: identifier of the job to change.
        enabled: True to pause after every stage; False to stop doing so.
            Disabling also resumes the job in case it is currently paused.

    Does nothing when the job has no pause state (not started, or already
    finished).
    """
    gate = _pause_gate.get(job_id)
    if gate is None:
        # Avoid leaving a flag behind for a job that clear_pause() will not
        # clean up; init_pause() sets the flag itself when a job starts.
        return
    _pause_after_stage[job_id] = enabled
    if not enabled:
        # If disabling, also resume in case we're currently paused
        gate.set()
def is_paused(job_id: str) -> bool:
    """Report whether the pipeline for ``job_id`` is currently paused.

    A job without pause state is reported as not paused.
    """
    gate = _pause_gate.get(job_id)
    if gate is None:
        return False
    # A cleared gate is what makes the runner block.
    return not gate.is_set()
def _wait_if_paused(job_id: str, node_name: str):
    """Block until the job's gate is open. Called after each node completes.

    Returns immediately when the job has no pause state. When
    pause-after-stage mode is on, the gate is closed first and a
    "Paused after <node>" message is emitted, so the job always stops here.

    Args:
        job_id: identifier of the running job.
        node_name: name of the node that just finished (used in the message).

    Raises:
        PipelineCancelled: if the job's cancel check returns a truthy value
            while the runner is blocked.
    """
    gate = _pause_gate.get(job_id)
    if gate is None:
        return

    # If pause-after-stage is on, pause now
    if _pause_after_stage.get(job_id, False):
        gate.clear()
        from detect import emit
        emit.log(job_id, "Pipeline", "INFO", f"Paused after {node_name}")

    # Wait in short slices so a cancel request is noticed while paused.
    while True:
        if gate.wait(timeout=0.5):
            return
        cancel_requested = _cancel_check.get(job_id)
        if cancel_requested and cancel_requested():
            raise PipelineCancelled("Cancelled while paused before next stage")
def _checkpointing_node(node_name: str, node_fn):
    """Wrap a node function with cancel, checkpoint and pause handling.

    The returned wrapper, when called with the pipeline state:

    1. raises ``PipelineCancelled`` before running the node if the job's
       registered cancel check returns a truthy value;
    2. runs ``node_fn(state)``;
    3. if the state carries a non-empty ``job_id``: saves the frames once
       (after the ``extract_frames`` node), saves the stage output as a new
       checkpoint whose parent is the job's previous checkpoint, then blocks
       while the job is paused.

    When the state has no ``job_id`` the node result is returned right after
    step 2, without checkpointing or pause handling.

    Args:
        node_name: stage name; used for the stage registry lookup, the
            checkpoint's ``stage_name`` and messages.
        node_fn: the node callable, invoked as ``node_fn(state)``.

    Returns:
        The wrapper. It returns ``node_fn``'s result unchanged and carries
        ``node_fn``'s ``__name__``.
    """
    def wrapper(state: DetectState) -> dict:
        # Read once; the same id is used before and after the node runs.
        job_id = state.get("job_id", "")

        check = _cancel_check.get(job_id)
        if check and check():
            raise PipelineCancelled(f"Cancelled before {node_name}")

        result = node_fn(state)
        if not job_id:
            return result

        # Function-scope imports, kept as in the original module
        # (presumably to avoid an import cycle — TODO confirm).
        from detect.checkpoint import save_stage_output, save_frames
        from detect.stages.base import _REGISTRY

        # A node may return None to signal "no state update"; treat that as
        # an empty update so checkpointing does not fail with a TypeError.
        merged = {**state, **(result or {})}

        # Save frames once (first node), reuse manifest after
        if node_name == "extract_frames" and _frames_manifest.get(job_id) is None:
            _frames_manifest[job_id] = save_frames(job_id, merged.get("frames", []))

        # Serialize stage output using the stage's serialize_fn if available
        stage_cls = _REGISTRY.get(node_name)
        definition = getattr(stage_cls, "definition", None)
        serialize_fn = getattr(definition, "serialize_fn", None)
        output_json = serialize_fn(merged, job_id) if serialize_fn else {}

        # Chain this checkpoint onto the job's previous one (None for the first).
        new_checkpoint_id = save_stage_output(
            timeline_id=job_id,
            parent_checkpoint_id=_latest_checkpoint.get(job_id),
            stage_name=node_name,
            output_json=output_json,
        )
        _latest_checkpoint[job_id] = new_checkpoint_id

        # Pause check — blocks if paused, respects cancel while waiting
        _wait_if_paused(job_id, node_name)
        return result

    # Only the name is copied from node_fn; other attributes are left alone.
    wrapper.__name__ = node_fn.__name__
    return wrapper
# --- Graph construction ---
def build_graph(checkpoint: bool | None = None, start_from: str | None = None) -> StateGraph:
    """
    Build the (uncompiled) pipeline graph as a linear chain of nodes.

    Args:
        checkpoint: enable auto-checkpointing by wrapping every node with
            the checkpointing wrapper. ``None`` (default) falls back to the
            MPR_CHECKPOINT env var.
        start_from: name of the first stage to include; nodes before it are
            skipped (for replay). ``None`` or ``""`` includes every stage.

    Returns:
        A StateGraph over DetectState with the selected nodes added in
        order, each connected to the next, and the last connected to END.

    Raises:
        ValueError: if ``start_from`` does not name a known stage.
    """
    do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED

    # Filter to start_from if replaying
    node_pairs = NODE_FUNCTIONS
    if start_from:
        stage_names = [name for name, _ in NODE_FUNCTIONS]
        if start_from not in stage_names:
            # Fail with a descriptive error instead of a bare StopIteration.
            raise ValueError(
                f"Unknown start_from stage {start_from!r}; expected one of {stage_names}"
            )
        node_pairs = NODE_FUNCTIONS[stage_names.index(start_from):]

    graph = StateGraph(DetectState)
    for name, fn in node_pairs:
        wrapped = _checkpointing_node(name, fn) if do_checkpoint else fn
        graph.add_node(name, wrapped)

    # Wire edges: entry point → each stage in order → END
    graph.set_entry_point(node_pairs[0][0])
    for (src, _), (dst, _) in zip(node_pairs, node_pairs[1:]):
        graph.add_edge(src, dst)
    graph.add_edge(node_pairs[-1][0], END)
    return graph
def get_pipeline(checkpoint: bool | None = None, start_from: str | None = None):
    """Return a compiled, runnable pipeline.

    Args:
        checkpoint: forwarded to ``build_graph``; ``None`` (default) uses the
            MPR_CHECKPOINT env var to decide whether to checkpoint.
        start_from: forwarded to ``build_graph``; name of the first stage to
            include (for replay). ``None`` (default) includes every stage,
            which is what this function always did before the parameter
            existed. ``build_graph`` raises if the name is not a known stage.

    Returns:
        The result of calling ``.compile()`` on the built graph.
    """
    return build_graph(checkpoint=checkpoint, start_from=start_from).compile()