phase 3
This commit is contained in:
@@ -4,12 +4,13 @@ Detection pipeline graph.
|
||||
detect/graph/
|
||||
nodes.py — node functions (one per stage)
|
||||
events.py — graph_update SSE emission
|
||||
runner.py — pipeline execution (LangGraph wrapper, checkpoint, cancel, pause)
|
||||
runner.py — PipelineRunner (config-driven, checkpoint, cancel, pause)
|
||||
"""
|
||||
|
||||
from .nodes import NODES, NODE_FUNCTIONS
|
||||
from .runner import (
|
||||
PipelineCancelled,
|
||||
PipelineRunner,
|
||||
build_graph,
|
||||
clear_cancel_check,
|
||||
clear_pause,
|
||||
@@ -28,6 +29,7 @@ __all__ = [
|
||||
"NODES",
|
||||
"NODE_FUNCTIONS",
|
||||
"PipelineCancelled",
|
||||
"PipelineRunner",
|
||||
"build_graph",
|
||||
"get_pipeline",
|
||||
"set_cancel_check",
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
"""
|
||||
Pipeline runner — executes stages sequentially with checkpointing and cancellation.
|
||||
Pipeline runner — executes stages sequentially with checkpointing,
|
||||
cancellation, and pause/resume.
|
||||
|
||||
Currently wraps LangGraph for execution. Will be replaced with a lean
|
||||
custom runner in Phase 3, with an executor socket for distributed dispatch.
|
||||
Reads PipelineConfig from the profile to determine what stages to run.
|
||||
Flattens the graph into a linear sequence for now (serial execution).
|
||||
Executor socket: all stages run via LocalExecutor (call function directly).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -11,19 +13,14 @@ import logging
|
||||
import os
|
||||
import threading
|
||||
|
||||
from langgraph.graph import END, StateGraph
|
||||
|
||||
from core.schema.models.pipeline_config import PipelineConfig
|
||||
from detect.state import DetectState
|
||||
from .nodes import NODES, NODE_FUNCTIONS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# --- Checkpoint wrapper ---
|
||||
|
||||
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
|
||||
_frames_manifest: dict[str, dict[int, str]] = {} # job_id → manifest (cached per job)
|
||||
_latest_checkpoint: dict[str, str] = {} # job_id → latest checkpoint_id
|
||||
|
||||
|
||||
class PipelineCancelled(Exception):
|
||||
@@ -53,10 +50,6 @@ def clear_cancel_check(job_id: str):
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pause / Resume / Step — checked after each node completes
|
||||
#
|
||||
# _pause_gate: threading.Event per job. When cleared, the runner blocks.
|
||||
# When set, the runner proceeds to the next node.
|
||||
# _pause_after_stage: if True, automatically clear the gate after each node.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_pause_gate: dict[str, threading.Event] = {}
|
||||
@@ -98,7 +91,7 @@ def step_pipeline(job_id: str):
|
||||
_pause_after_stage[job_id] = True
|
||||
gate = _pause_gate.get(job_id)
|
||||
if gate:
|
||||
gate.set() # unblock for one stage, _pause_after_stage re-pauses after
|
||||
gate.set()
|
||||
logger.info("Pipeline %s stepping", job_id)
|
||||
|
||||
|
||||
@@ -106,7 +99,6 @@ def set_pause_after_stage(job_id: str, enabled: bool):
|
||||
"""Toggle pause-after-each-stage mode."""
|
||||
_pause_after_stage[job_id] = enabled
|
||||
if not enabled:
|
||||
# If disabling, also resume in case we're currently paused
|
||||
gate = _pause_gate.get(job_id)
|
||||
if gate:
|
||||
gate.set()
|
||||
@@ -124,106 +116,159 @@ def _wait_if_paused(job_id: str, node_name: str):
|
||||
if gate is None:
|
||||
return
|
||||
|
||||
# If pause-after-stage is on, pause now
|
||||
if _pause_after_stage.get(job_id, False):
|
||||
gate.clear()
|
||||
from detect import emit
|
||||
emit.log(job_id, "Pipeline", "INFO", f"Paused after {node_name}")
|
||||
|
||||
# Block until gate is set (resume/step) or cancelled
|
||||
while not gate.wait(timeout=0.5):
|
||||
check = _cancel_check.get(job_id)
|
||||
if check and check():
|
||||
raise PipelineCancelled(f"Cancelled while paused before next stage")
|
||||
|
||||
|
||||
def _checkpointing_node(node_name: str, node_fn):
|
||||
"""Wrap a node function to auto-checkpoint after completion."""
|
||||
|
||||
def wrapper(state: DetectState) -> dict:
|
||||
job_id = state.get("job_id", "")
|
||||
check = _cancel_check.get(job_id)
|
||||
if check and check():
|
||||
raise PipelineCancelled(f"Cancelled before {node_name}")
|
||||
|
||||
result = node_fn(state)
|
||||
|
||||
job_id = state.get("job_id", "")
|
||||
if not job_id:
|
||||
return result
|
||||
|
||||
from detect.checkpoint import save_stage_output, save_frames
|
||||
from detect.stages.base import _REGISTRY
|
||||
|
||||
merged = {**state, **result}
|
||||
|
||||
# Save frames once (first node), reuse manifest after
|
||||
manifest = _frames_manifest.get(job_id)
|
||||
if manifest is None and node_name == "extract_frames":
|
||||
manifest = save_frames(job_id, merged.get("frames", []))
|
||||
_frames_manifest[job_id] = manifest
|
||||
|
||||
# Serialize stage output using the stage's serialize_fn if available
|
||||
stage_cls = _REGISTRY.get(node_name)
|
||||
serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
|
||||
if serialize_fn:
|
||||
output_json = serialize_fn(merged, job_id)
|
||||
else:
|
||||
output_json = {}
|
||||
|
||||
parent_id = _latest_checkpoint.get(job_id)
|
||||
new_checkpoint_id = save_stage_output(
|
||||
timeline_id=job_id,
|
||||
parent_checkpoint_id=parent_id,
|
||||
stage_name=node_name,
|
||||
output_json=output_json,
|
||||
)
|
||||
_latest_checkpoint[job_id] = new_checkpoint_id
|
||||
|
||||
# Pause check — blocks if paused, respects cancel while waiting
|
||||
_wait_if_paused(job_id, node_name)
|
||||
|
||||
return result
|
||||
|
||||
wrapper.__name__ = node_fn.__name__
|
||||
return wrapper
|
||||
|
||||
|
||||
# --- Graph construction ---
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pipeline Runner
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def build_graph(checkpoint: bool | None = None, start_from: str | None = None) -> StateGraph:
|
||||
# Node function lookup — maps stage name to callable
|
||||
_NODE_FN_MAP: dict[str, callable] = {name: fn for name, fn in NODE_FUNCTIONS}
|
||||
|
||||
|
||||
def _flatten_config(config: PipelineConfig, start_from: str | None = None) -> list[str]:
|
||||
"""
|
||||
Build the pipeline graph.
|
||||
Flatten a PipelineConfig into a linear stage sequence.
|
||||
|
||||
checkpoint: enable auto-checkpointing (default: MPR_CHECKPOINT env var)
|
||||
start_from: skip nodes before this stage (for replay)
|
||||
For now: topological sort via edges. Falls back to stage order if no edges.
|
||||
Respects start_from for replay (skip stages before it).
|
||||
"""
|
||||
do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
|
||||
if not config.edges:
|
||||
# No edges defined — use stage order as-is
|
||||
names = [s.name for s in config.stages]
|
||||
else:
|
||||
# Topological sort from edges
|
||||
graph: dict[str, list[str]] = {}
|
||||
in_degree: dict[str, int] = {}
|
||||
stage_names = {s.name for s in config.stages}
|
||||
|
||||
graph = StateGraph(DetectState)
|
||||
for name in stage_names:
|
||||
graph[name] = []
|
||||
in_degree[name] = 0
|
||||
|
||||
for edge in config.edges:
|
||||
if edge.source in stage_names and edge.target in stage_names:
|
||||
graph[edge.source].append(edge.target)
|
||||
in_degree[edge.target] = in_degree.get(edge.target, 0) + 1
|
||||
|
||||
# Kahn's algorithm
|
||||
queue = [n for n in stage_names if in_degree.get(n, 0) == 0]
|
||||
# Stable sort: prefer order from config.stages
|
||||
stage_order = {s.name: i for i, s in enumerate(config.stages)}
|
||||
queue.sort(key=lambda n: stage_order.get(n, 999))
|
||||
|
||||
names = []
|
||||
while queue:
|
||||
node = queue.pop(0)
|
||||
names.append(node)
|
||||
for neighbor in graph.get(node, []):
|
||||
in_degree[neighbor] -= 1
|
||||
if in_degree[neighbor] == 0:
|
||||
queue.append(neighbor)
|
||||
queue.sort(key=lambda n: stage_order.get(n, 999))
|
||||
|
||||
# Filter to start_from if replaying
|
||||
node_pairs = NODE_FUNCTIONS
|
||||
if start_from:
|
||||
start_idx = next(i for i, (name, _) in enumerate(NODE_FUNCTIONS) if name == start_from)
|
||||
node_pairs = NODE_FUNCTIONS[start_idx:]
|
||||
try:
|
||||
idx = names.index(start_from)
|
||||
names = names[idx:]
|
||||
except ValueError:
|
||||
raise ValueError(f"Stage {start_from!r} not in pipeline config")
|
||||
|
||||
for name, fn in node_pairs:
|
||||
wrapped = _checkpointing_node(name, fn) if do_checkpoint else fn
|
||||
graph.add_node(name, wrapped)
|
||||
|
||||
# Wire edges
|
||||
entry = node_pairs[0][0]
|
||||
graph.set_entry_point(entry)
|
||||
|
||||
for i in range(len(node_pairs) - 1):
|
||||
graph.add_edge(node_pairs[i][0], node_pairs[i + 1][0])
|
||||
|
||||
graph.add_edge(node_pairs[-1][0], END)
|
||||
|
||||
return graph
|
||||
return names
|
||||
|
||||
|
||||
def get_pipeline(checkpoint: bool | None = None):
|
||||
"""Return a compiled, runnable pipeline."""
|
||||
return build_graph(checkpoint=checkpoint).compile()
|
||||
class PipelineRunner:
|
||||
"""
|
||||
Executes a pipeline defined by PipelineConfig.
|
||||
|
||||
Runs stages sequentially (flattened). Each stage:
|
||||
1. Check cancel
|
||||
2. Run node function (via executor — local for now)
|
||||
3. Merge result into state
|
||||
4. Checkpoint (if enabled)
|
||||
5. Check pause
|
||||
|
||||
Executor socket: currently calls node functions directly.
|
||||
Future: dispatch to LocalExecutor / GrpcExecutor / LambdaExecutor
|
||||
based on StageRef.execution_target.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PipelineConfig,
|
||||
checkpoint: bool = False,
|
||||
start_from: str | None = None,
|
||||
):
|
||||
self.config = config
|
||||
self.do_checkpoint = checkpoint
|
||||
self.stage_sequence = _flatten_config(config, start_from)
|
||||
|
||||
def invoke(self, state: DetectState) -> DetectState:
|
||||
"""Run the pipeline on the given state. Returns final state."""
|
||||
for stage_name in self.stage_sequence:
|
||||
job_id = state.get("job_id", "")
|
||||
|
||||
# 1. Cancel check
|
||||
check = _cancel_check.get(job_id)
|
||||
if check and check():
|
||||
raise PipelineCancelled(f"Cancelled before {stage_name}")
|
||||
|
||||
# 2. Run node function
|
||||
node_fn = _NODE_FN_MAP.get(stage_name)
|
||||
if node_fn is None:
|
||||
logger.warning("No node function for stage %s, skipping", stage_name)
|
||||
continue
|
||||
|
||||
result = node_fn(state)
|
||||
|
||||
# 3. Merge result into state
|
||||
state.update(result)
|
||||
|
||||
# 4. Checkpoint
|
||||
if self.do_checkpoint:
|
||||
from detect.checkpoint import checkpoint_after_stage
|
||||
checkpoint_after_stage(job_id, stage_name, state, result)
|
||||
|
||||
# 5. Pause check
|
||||
_wait_if_paused(job_id, stage_name)
|
||||
|
||||
return state
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API — backwards compatible with old get_pipeline/build_graph
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def get_pipeline(
|
||||
checkpoint: bool | None = None,
|
||||
profile_name: str = "soccer_broadcast",
|
||||
start_from: str | None = None,
|
||||
) -> PipelineRunner:
|
||||
"""Return a PipelineRunner for the given profile."""
|
||||
from detect.profiles import get_profile
|
||||
|
||||
do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
|
||||
profile = get_profile(profile_name)
|
||||
config = profile.pipeline_config()
|
||||
|
||||
return PipelineRunner(
|
||||
config=config,
|
||||
checkpoint=do_checkpoint,
|
||||
start_from=start_from,
|
||||
)
|
||||
|
||||
|
||||
def build_graph(checkpoint: bool | None = None, start_from: str | None = None):
|
||||
"""Backwards-compatible wrapper. Returns a PipelineRunner."""
|
||||
return get_pipeline(checkpoint=checkpoint, start_from=start_from)
|
||||
|
||||
Reference in New Issue
Block a user