275 lines
8.4 KiB
Python
275 lines
8.4 KiB
Python
"""
Pipeline runner — executes stages sequentially with checkpointing,
cancellation, and pause/resume.

Reads the PipelineConfig from the profile to determine which stages to run,
and flattens the stage graph into a linear sequence (serial execution for now).

Executor socket: every stage currently runs locally — the node function is
called directly, in-process; dispatch to other executors is future work.
"""
|
|
|
|
from __future__ import annotations

import heapq
import logging
import os
import threading
from collections.abc import Callable

from core.schema.models.pipeline_config import PipelineConfig
from detect.state import DetectState

from .nodes import NODES, NODE_FUNCTIONS
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Default for get_pipeline(checkpoint=None): checkpointing is on only when the
# MPR_CHECKPOINT environment variable is exactly "1" once surrounding
# whitespace is stripped. Evaluated once, at import time.
_CHECKPOINT_ENABLED = os.getenv("MPR_CHECKPOINT", "").strip() == "1"
|
|
|
|
|
|
class PipelineCancelled(Exception):
    """Signal that a pipeline run was cancelled.

    Raised when the job's registered cancel check returns a truthy value,
    either before a stage starts or while the run waits at a pause gate.
    """
|
|
|
|
|
|
class PipelinePaused(Exception):
    """Signal that a pipeline was paused (internal flow control).

    NOTE(review): nothing in this module raises it — pausing is implemented
    by blocking on a threading.Event instead. Presumably kept for callers
    elsewhere; confirm before removing.
    """
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Cancellation — checked before each node
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# job_id -> zero-argument predicate; a truthy return value cancels that run.
# (The builtin ``callable`` is a function, not a type — use Callable.)
_cancel_check: dict[str, Callable[[], bool]] = {}
|
|
|
|
|
|
def set_cancel_check(job_id: str, fn: Callable[[], bool]) -> None:
    """Register *fn* as the cancellation predicate for *job_id*.

    The runner calls ``fn()`` before every stage and every 0.5 s while the
    job is paused; a truthy result aborts the run with PipelineCancelled.
    Registering again for the same job replaces the previous predicate.
    """
    _cancel_check[job_id] = fn
|
|
|
|
|
|
def clear_cancel_check(job_id: str):
    """Forget the cancellation predicate for *job_id* (no-op if none is registered)."""
    _cancel_check.pop(job_id, None)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Pause / Resume / Step — checked after each node completes
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# job_id -> gate event: set means "free to run", cleared means "paused".
_pause_gate: dict[str, threading.Event] = dict()

# job_id -> whether the runner re-closes the gate after every completed stage.
_pause_after_stage: dict[str, bool] = dict()
|
|
|
|
|
|
def init_pause(job_id: str, pause_after_stage: bool = False):
    """Create fresh pause state for *job_id*; call when a pipeline starts.

    The gate starts open (unpaused). Any earlier state for the same job is
    overwritten.
    """
    event = threading.Event()
    event.set()  # open gate == running
    _pause_gate[job_id] = event
    _pause_after_stage[job_id] = pause_after_stage
|
|
|
|
|
|
def clear_pause(job_id: str):
    """Drop all pause state for *job_id*; call when the pipeline finishes."""
    for registry in (_pause_gate, _pause_after_stage):
        registry.pop(job_id, None)
|
|
|
|
|
|
def pause_pipeline(job_id: str):
    """Request a pause for *job_id*'s running pipeline.

    The current stage finishes; the runner then blocks at the stage
    boundary. Job ids without pause state are ignored.
    """
    gate = _pause_gate.get(job_id)
    if gate is None:
        return
    gate.clear()
    logger.info("Pipeline %s paused", job_id)
|
|
|
|
|
|
def resume_pipeline(job_id: str):
    """Reopen the gate of a paused pipeline; unknown job ids are ignored."""
    gate = _pause_gate.get(job_id)
    if gate is None:
        return
    gate.set()
    logger.info("Pipeline %s resumed", job_id)
|
|
|
|
|
|
def step_pipeline(job_id: str):
    """Let a paused pipeline run one more stage, then pause again.

    Works by switching on pause-after-stage mode and opening the gate.
    NOTE(review): the flag stays on afterwards, so a later resume_pipeline()
    still stops after the next stage unless set_pause_after_stage(job_id,
    False) is called — confirm this is intended.
    """
    _pause_after_stage[job_id] = True
    gate = _pause_gate.get(job_id)
    if gate is None:
        return
    gate.set()
    logger.info("Pipeline %s stepping", job_id)
|
|
|
|
|
|
def set_pause_after_stage(job_id: str, enabled: bool):
    """Turn pause-after-each-stage mode on or off for *job_id*.

    Turning it off also opens the gate, so a pipeline currently held at a
    stage boundary continues right away.
    """
    _pause_after_stage[job_id] = enabled
    if enabled:
        return
    gate = _pause_gate.get(job_id)
    if gate is not None:
        gate.set()
|
|
|
|
|
|
def is_paused(job_id: str) -> bool:
    """Return True only if *job_id* has a pause gate and that gate is closed."""
    gate = _pause_gate.get(job_id)
    if gate is None:
        return False
    return not gate.is_set()
|
|
|
|
|
|
def _wait_if_paused(job_id: str, node_name: str):
    """Block the runner at a stage boundary while *job_id* is paused.

    Called after each stage completes. If pause-after-stage mode is on, this
    closes the gate itself (emitting a "Paused after <stage>" log event)
    before waiting. While blocked, the job's cancel check is polled every
    0.5 seconds.

    Args:
        job_id: job whose pause gate is consulted.
        node_name: stage that just finished, used in the log message.

    Raises:
        PipelineCancelled: the cancel check returned truthy while paused.
    """
    gate = _pause_gate.get(job_id)
    if gate is None:
        # No pause state registered for this job (init_pause never called).
        return

    if _pause_after_stage.get(job_id, False):
        gate.clear()
        from detect import emit

        emit.log(job_id, "Pipeline", "INFO", f"Paused after {node_name}")

    # Wake up periodically so cancellation can interrupt an open-ended pause.
    while not gate.wait(timeout=0.5):
        check = _cancel_check.get(job_id)
        if check and check():
            raise PipelineCancelled("Cancelled while paused before next stage")
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Pipeline Runner
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Node function lookup — maps stage name to the callable that runs it.
# NODE_FUNCTIONS is iterated as (name, fn) pairs.
_NODE_FN_MAP: dict[str, Callable] = dict(NODE_FUNCTIONS)
|
|
|
|
|
|
def _flatten_config(config: PipelineConfig, start_from: str | None = None) -> list[str]:
|
|
"""
|
|
Flatten a PipelineConfig into a linear stage sequence.
|
|
|
|
For now: topological sort via edges. Falls back to stage order if no edges.
|
|
Respects start_from for replay (skip stages before it).
|
|
"""
|
|
if not config.edges:
|
|
# No edges defined — use stage order as-is
|
|
names = [s.name for s in config.stages]
|
|
else:
|
|
# Topological sort from edges
|
|
graph: dict[str, list[str]] = {}
|
|
in_degree: dict[str, int] = {}
|
|
stage_names = {s.name for s in config.stages}
|
|
|
|
for name in stage_names:
|
|
graph[name] = []
|
|
in_degree[name] = 0
|
|
|
|
for edge in config.edges:
|
|
if edge.source in stage_names and edge.target in stage_names:
|
|
graph[edge.source].append(edge.target)
|
|
in_degree[edge.target] = in_degree.get(edge.target, 0) + 1
|
|
|
|
# Kahn's algorithm
|
|
queue = [n for n in stage_names if in_degree.get(n, 0) == 0]
|
|
# Stable sort: prefer order from config.stages
|
|
stage_order = {s.name: i for i, s in enumerate(config.stages)}
|
|
queue.sort(key=lambda n: stage_order.get(n, 999))
|
|
|
|
names = []
|
|
while queue:
|
|
node = queue.pop(0)
|
|
names.append(node)
|
|
for neighbor in graph.get(node, []):
|
|
in_degree[neighbor] -= 1
|
|
if in_degree[neighbor] == 0:
|
|
queue.append(neighbor)
|
|
queue.sort(key=lambda n: stage_order.get(n, 999))
|
|
|
|
if start_from:
|
|
try:
|
|
idx = names.index(start_from)
|
|
names = names[idx:]
|
|
except ValueError:
|
|
raise ValueError(f"Stage {start_from!r} not in pipeline config")
|
|
|
|
return names
|
|
|
|
|
|
class PipelineRunner:
    """
    Execute a pipeline described by a PipelineConfig, one stage at a time.

    The config is flattened into a linear stage sequence at construction.
    For each stage, invoke():
      1. raises PipelineCancelled if the job's cancel check fires;
      2. looks up the stage's node function (stages without one are logged
         and skipped) and calls it with the state;
      3. merges the returned mapping into the state;
      4. writes a checkpoint when checkpointing is enabled;
      5. blocks while the job is paused.

    Executor socket: node functions are currently called directly.
    Future: dispatch to LocalExecutor / GrpcExecutor / LambdaExecutor
    based on StageRef.execution_target.
    """

    def __init__(
        self,
        config: PipelineConfig,
        checkpoint: bool = False,
        start_from: str | None = None,
    ):
        self.config = config
        self.do_checkpoint = checkpoint
        # _flatten_config raises ValueError, e.g. for an unknown start_from.
        self.stage_sequence = _flatten_config(config, start_from)

    def _execute_stage(self, job_id, stage, node_fn, state):
        """Run one node, merge its output into *state*, checkpoint, then honour pause."""
        delta = node_fn(state)
        state.update(delta)
        if self.do_checkpoint:
            from detect.checkpoint import checkpoint_after_stage

            checkpoint_after_stage(job_id, stage, state, delta)
        _wait_if_paused(job_id, stage)

    def invoke(self, state: DetectState) -> DetectState:
        """Run every stage on *state* (updated in place) and return it.

        Raises:
            PipelineCancelled: the cancel check fired before a stage or
                while paused between stages.
        """
        for stage in self.stage_sequence:
            # Re-read each iteration: a node's output is merged into state.
            job_id = state.get("job_id", "")

            should_cancel = _cancel_check.get(job_id)
            if should_cancel and should_cancel():
                raise PipelineCancelled(f"Cancelled before {stage}")

            node_fn = _NODE_FN_MAP.get(stage)
            if node_fn is not None:
                self._execute_stage(job_id, stage, node_fn, state)
            else:
                logger.warning("No node function for stage %s, skipping", stage)

        return state
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Public API — backwards compatible with old get_pipeline/build_graph
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def get_pipeline(
    checkpoint: bool | None = None,
    profile_name: str = "soccer_broadcast",
    start_from: str | None = None,
) -> PipelineRunner:
    """Build a PipelineRunner from the named profile's pipeline config.

    Args:
        checkpoint: force checkpointing on or off; None falls back to the
            module's MPR_CHECKPOINT flag (read once at import time).
        profile_name: profile whose ``pipeline_config()`` is executed.
        start_from: optional stage to replay from; earlier stages are skipped.
    """
    from detect.profiles import get_profile

    if checkpoint is None:
        checkpoint = _CHECKPOINT_ENABLED
    pipeline_config = get_profile(profile_name).pipeline_config()
    return PipelineRunner(
        config=pipeline_config,
        checkpoint=checkpoint,
        start_from=start_from,
    )
|
|
|
|
|
|
def build_graph(checkpoint: bool | None = None, start_from: str | None = None):
    """Legacy entry point kept for backwards compatibility.

    Same as get_pipeline() with the default profile; returns a PipelineRunner.
    """
    return get_pipeline(checkpoint, start_from=start_from)
|