phase 1
This commit is contained in:
@@ -5,7 +5,6 @@ Checkpoint system — Timeline + Checkpoint tree.
|
||||
frames.py — frame image S3 upload/download
|
||||
storage.py — Timeline + Checkpoint (Postgres + MinIO)
|
||||
replay.py — replay (TODO: migrate to new model)
|
||||
tasks.py — retry_candidates Celery task
|
||||
"""
|
||||
|
||||
from .storage import (
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
"""
|
||||
Celery tasks for detection pipeline async operations.
|
||||
|
||||
retry_candidates: re-run VLM/cloud escalation with different config.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from celery import shared_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(bind=True, max_retries=1, default_retry_delay=30)
|
||||
def retry_candidates(
|
||||
self,
|
||||
job_id: str,
|
||||
config_overrides: dict | None = None,
|
||||
start_stage: str = "escalate_vlm",
|
||||
):
|
||||
"""
|
||||
Retry unresolved candidates with different config.
|
||||
|
||||
Loads the checkpoint from the stage before start_stage,
|
||||
applies config overrides (e.g. different cloud provider),
|
||||
and runs from start_stage onward.
|
||||
"""
|
||||
from detect.checkpoint.replay import replay_from
|
||||
|
||||
run_id = str(uuid.uuid4())[:8]
|
||||
logger.info("Retry task %s: job=%s, from=%s, overrides=%s",
|
||||
run_id, job_id, start_stage, config_overrides)
|
||||
|
||||
try:
|
||||
result = replay_from(
|
||||
job_id=job_id,
|
||||
start_stage=start_stage,
|
||||
config_overrides=config_overrides,
|
||||
)
|
||||
|
||||
detections = result.get("detections", [])
|
||||
report = result.get("report")
|
||||
brands_found = len(report.brands) if report else 0
|
||||
|
||||
logger.info("Retry %s complete: %d detections, %d brands",
|
||||
run_id, len(detections), brands_found)
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"run_id": run_id,
|
||||
"job_id": job_id,
|
||||
"detections": len(detections),
|
||||
"brands_found": brands_found,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Retry %s failed: %s", run_id, e)
|
||||
|
||||
if self.request.retries < self.max_retries:
|
||||
raise self.retry(exc=e)
|
||||
|
||||
return {
|
||||
"status": "failed",
|
||||
"run_id": run_id,
|
||||
"job_id": job_id,
|
||||
"error": str(e),
|
||||
}
|
||||
29
detect/graph/__init__.py
Normal file
29
detect/graph/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
Detection pipeline graph.
|
||||
|
||||
detect/graph/
|
||||
nodes.py — node functions (one per stage)
|
||||
events.py — graph_update SSE emission
|
||||
runner.py — pipeline execution (LangGraph wrapper, checkpoint, cancel)
|
||||
"""
|
||||
|
||||
from .nodes import NODES, NODE_FUNCTIONS
|
||||
from .runner import (
|
||||
PipelineCancelled,
|
||||
build_graph,
|
||||
clear_cancel_check,
|
||||
get_pipeline,
|
||||
set_cancel_check,
|
||||
)
|
||||
from .events import _node_states
|
||||
|
||||
__all__ = [
|
||||
"NODES",
|
||||
"NODE_FUNCTIONS",
|
||||
"PipelineCancelled",
|
||||
"build_graph",
|
||||
"get_pipeline",
|
||||
"set_cancel_check",
|
||||
"clear_cancel_check",
|
||||
"_node_states",
|
||||
]
|
||||
27
detect/graph/events.py
Normal file
27
detect/graph/events.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""
|
||||
Graph event emission — node state tracking + SSE graph_update events.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from detect import emit
|
||||
from detect.state import DetectState
|
||||
|
||||
|
||||
# Track node states across pipeline runs
|
||||
_node_states: dict[str, dict[str, str]] = {}
|
||||
|
||||
|
||||
def emit_transition(state: DetectState, node: str, status: str, node_list: list[str]):
|
||||
"""Update node status and emit graph_update SSE event."""
|
||||
job_id = state.get("job_id")
|
||||
if not job_id:
|
||||
return
|
||||
|
||||
if job_id not in _node_states:
|
||||
_node_states[job_id] = {n: "pending" for n in node_list}
|
||||
|
||||
_node_states[job_id][node] = status
|
||||
|
||||
nodes = [{"id": n, "status": _node_states[job_id][n]} for n in node_list]
|
||||
emit.graph_update(job_id, nodes)
|
||||
@@ -1,16 +1,13 @@
|
||||
"""
|
||||
LangGraph pipeline graph for brand detection.
|
||||
Pipeline node functions — one per stage.
|
||||
|
||||
Nodes execute real logic for extract+filter, stubs for the rest.
|
||||
Each node emits graph_update events so the UI can visualize transitions.
|
||||
Each node: reads state, runs stage logic, emits transitions, returns output dict.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from langgraph.graph import END, StateGraph
|
||||
|
||||
from detect import emit
|
||||
from detect.models import PipelineStats
|
||||
from detect.profiles import SoccerBroadcastProfile
|
||||
@@ -27,6 +24,8 @@ from detect.stages.vlm_cloud import escalate_cloud
|
||||
from detect.stages.aggregator import compile_report
|
||||
from detect.tracing import trace_node, flush as flush_traces
|
||||
|
||||
from .events import emit_transition
|
||||
|
||||
INFERENCE_URL = os.environ.get("INFERENCE_URL") # None = local mode
|
||||
|
||||
NODES = [
|
||||
@@ -58,41 +57,24 @@ def _get_profile(state: DetectState):
|
||||
return profile
|
||||
|
||||
|
||||
# Track node states across the pipeline run
|
||||
_node_states: dict[str, dict[str, str]] = {}
|
||||
|
||||
|
||||
def _emit_transition(state: DetectState, node: str, status: str):
|
||||
job_id = state.get("job_id")
|
||||
if not job_id:
|
||||
return
|
||||
|
||||
# Initialize state tracking for this job
|
||||
if job_id not in _node_states:
|
||||
_node_states[job_id] = {n: "pending" for n in NODES}
|
||||
|
||||
_node_states[job_id][node] = status
|
||||
|
||||
nodes = [{"id": n, "status": _node_states[job_id][n]} for n in NODES]
|
||||
emit.graph_update(job_id, nodes)
|
||||
def _emit(state, node, status):
|
||||
emit_transition(state, node, status, NODES)
|
||||
|
||||
|
||||
# --- Node functions ---
|
||||
|
||||
def node_extract_frames(state: DetectState) -> dict:
|
||||
# Set run context for initial runs (replays set it in replay_from)
|
||||
job_id = state.get("job_id", "")
|
||||
if job_id and not emit._run_context:
|
||||
emit.set_run_context(run_id=job_id, parent_job_id=job_id, run_type="initial")
|
||||
|
||||
# Load session brands from DB for this source
|
||||
source_asset_id = state.get("source_asset_id")
|
||||
if source_asset_id and not state.get("session_brands"):
|
||||
from detect.stages.brand_resolver import build_session_dict
|
||||
session_brands = build_session_dict(source_asset_id)
|
||||
state["session_brands"] = session_brands
|
||||
|
||||
_emit_transition(state, "extract_frames", "running")
|
||||
_emit(state, "extract_frames", "running")
|
||||
|
||||
with trace_node(state, "extract_frames") as span:
|
||||
profile = _get_profile(state)
|
||||
@@ -100,12 +82,12 @@ def node_extract_frames(state: DetectState) -> dict:
|
||||
frames = extract_frames(state["video_path"], config, job_id=state.get("job_id"))
|
||||
span.set_output({"frames_extracted": len(frames)})
|
||||
|
||||
_emit_transition(state, "extract_frames", "done")
|
||||
_emit(state, "extract_frames", "done")
|
||||
return {"frames": frames, "stats": PipelineStats(frames_extracted=len(frames))}
|
||||
|
||||
|
||||
def node_filter_scenes(state: DetectState) -> dict:
|
||||
_emit_transition(state, "filter_scenes", "running")
|
||||
_emit(state, "filter_scenes", "running")
|
||||
|
||||
with trace_node(state, "filter_scenes") as span:
|
||||
profile = _get_profile(state)
|
||||
@@ -117,12 +99,12 @@ def node_filter_scenes(state: DetectState) -> dict:
|
||||
stats = state.get("stats", PipelineStats())
|
||||
stats.frames_after_scene_filter = len(kept)
|
||||
|
||||
_emit_transition(state, "filter_scenes", "done")
|
||||
_emit(state, "filter_scenes", "done")
|
||||
return {"filtered_frames": kept, "stats": stats}
|
||||
|
||||
|
||||
def node_detect_edges(state: DetectState) -> dict:
|
||||
_emit_transition(state, "detect_edges", "running")
|
||||
_emit(state, "detect_edges", "running")
|
||||
|
||||
with trace_node(state, "detect_edges") as span:
|
||||
profile = _get_profile(state)
|
||||
@@ -139,12 +121,12 @@ def node_detect_edges(state: DetectState) -> dict:
|
||||
stats = state.get("stats", PipelineStats())
|
||||
stats.cv_regions_detected = total
|
||||
|
||||
_emit_transition(state, "detect_edges", "done")
|
||||
_emit(state, "detect_edges", "done")
|
||||
return {"edge_regions_by_frame": regions, "stats": stats}
|
||||
|
||||
|
||||
def node_detect_objects(state: DetectState) -> dict:
|
||||
_emit_transition(state, "detect_objects", "running")
|
||||
_emit(state, "detect_objects", "running")
|
||||
|
||||
with trace_node(state, "detect_objects") as span:
|
||||
profile = _get_profile(state)
|
||||
@@ -159,12 +141,12 @@ def node_detect_objects(state: DetectState) -> dict:
|
||||
stats = state.get("stats", PipelineStats())
|
||||
stats.regions_detected = total_regions
|
||||
|
||||
_emit_transition(state, "detect_objects", "done")
|
||||
_emit(state, "detect_objects", "done")
|
||||
return {"boxes_by_frame": all_boxes, "stats": stats}
|
||||
|
||||
|
||||
def node_preprocess(state: DetectState) -> dict:
|
||||
_emit_transition(state, "preprocess", "running")
|
||||
_emit(state, "preprocess", "running")
|
||||
|
||||
with trace_node(state, "preprocess") as span:
|
||||
profile = _get_profile(state)
|
||||
@@ -172,7 +154,6 @@ def node_preprocess(state: DetectState) -> dict:
|
||||
boxes = state.get("boxes_by_frame", {})
|
||||
job_id = state.get("job_id")
|
||||
|
||||
# Get preprocessing config from profile overrides or defaults
|
||||
overrides = state.get("config_overrides", {})
|
||||
prep_config = overrides.get("preprocessing", {})
|
||||
do_contrast = prep_config.get("contrast", True)
|
||||
@@ -189,12 +170,12 @@ def node_preprocess(state: DetectState) -> dict:
|
||||
)
|
||||
span.set_output({"regions_preprocessed": len(preprocessed)})
|
||||
|
||||
_emit_transition(state, "preprocess", "done")
|
||||
_emit(state, "preprocess", "done")
|
||||
return {"preprocessed_crops": preprocessed}
|
||||
|
||||
|
||||
def node_run_ocr(state: DetectState) -> dict:
|
||||
_emit_transition(state, "run_ocr", "running")
|
||||
_emit(state, "run_ocr", "running")
|
||||
|
||||
with trace_node(state, "run_ocr") as span:
|
||||
profile = _get_profile(state)
|
||||
@@ -209,12 +190,12 @@ def node_run_ocr(state: DetectState) -> dict:
|
||||
stats = state.get("stats", PipelineStats())
|
||||
stats.regions_resolved_by_ocr = len(candidates)
|
||||
|
||||
_emit_transition(state, "run_ocr", "done")
|
||||
_emit(state, "run_ocr", "done")
|
||||
return {"text_candidates": candidates, "stats": stats}
|
||||
|
||||
|
||||
def node_match_brands(state: DetectState) -> dict:
|
||||
_emit_transition(state, "match_brands", "running")
|
||||
_emit(state, "match_brands", "running")
|
||||
|
||||
with trace_node(state, "match_brands") as span:
|
||||
profile = _get_profile(state)
|
||||
@@ -232,12 +213,12 @@ def node_match_brands(state: DetectState) -> dict:
|
||||
)
|
||||
span.set_output({"matched": len(matched), "unresolved": len(unresolved)})
|
||||
|
||||
_emit_transition(state, "match_brands", "done")
|
||||
_emit(state, "match_brands", "done")
|
||||
return {"detections": matched, "unresolved_candidates": unresolved}
|
||||
|
||||
|
||||
def node_escalate_vlm(state: DetectState) -> dict:
|
||||
_emit_transition(state, "escalate_vlm", "running")
|
||||
_emit(state, "escalate_vlm", "running")
|
||||
|
||||
with trace_node(state, "escalate_vlm") as span:
|
||||
profile = _get_profile(state)
|
||||
@@ -261,7 +242,7 @@ def node_escalate_vlm(state: DetectState) -> dict:
|
||||
existing = state.get("detections", [])
|
||||
|
||||
vlm_skipped = os.environ.get("SKIP_VLM", "").strip() == "1"
|
||||
_emit_transition(state, "escalate_vlm", "skipped" if vlm_skipped else "done")
|
||||
_emit(state, "escalate_vlm", "skipped" if vlm_skipped else "done")
|
||||
return {
|
||||
"detections": existing + vlm_matched,
|
||||
"unresolved_candidates": still_unresolved,
|
||||
@@ -270,7 +251,7 @@ def node_escalate_vlm(state: DetectState) -> dict:
|
||||
|
||||
|
||||
def node_escalate_cloud(state: DetectState) -> dict:
|
||||
_emit_transition(state, "escalate_cloud", "running")
|
||||
_emit(state, "escalate_cloud", "running")
|
||||
|
||||
with trace_node(state, "escalate_cloud") as span:
|
||||
profile = _get_profile(state)
|
||||
@@ -294,12 +275,12 @@ def node_escalate_cloud(state: DetectState) -> dict:
|
||||
existing = state.get("detections", [])
|
||||
|
||||
cloud_skipped = os.environ.get("SKIP_CLOUD", "").strip() == "1"
|
||||
_emit_transition(state, "escalate_cloud", "skipped" if cloud_skipped else "done")
|
||||
_emit(state, "escalate_cloud", "skipped" if cloud_skipped else "done")
|
||||
return {"detections": existing + cloud_matched, "stats": stats}
|
||||
|
||||
|
||||
def node_compile_report(state: DetectState) -> dict:
|
||||
_emit_transition(state, "compile_report", "running")
|
||||
_emit(state, "compile_report", "running")
|
||||
|
||||
with trace_node(state, "compile_report") as span:
|
||||
profile = _get_profile(state)
|
||||
@@ -318,85 +299,10 @@ def node_compile_report(state: DetectState) -> dict:
|
||||
span.set_output({"brands": len(report.brands), "detections": len(report.timeline)})
|
||||
|
||||
flush_traces()
|
||||
_emit_transition(state, "compile_report", "done")
|
||||
_emit(state, "compile_report", "done")
|
||||
return {"report": report}
|
||||
|
||||
|
||||
# --- Checkpoint wrapper ---
|
||||
|
||||
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
|
||||
_frames_manifest: dict[str, dict[int, str]] = {} # job_id → manifest (cached per job)
|
||||
_latest_checkpoint: dict[str, str] = {} # job_id → latest checkpoint_id
|
||||
|
||||
|
||||
class PipelineCancelled(Exception):
|
||||
"""Raised when a pipeline run is cancelled."""
|
||||
pass
|
||||
|
||||
|
||||
# Cancellation hook — set by the run endpoint, checked before each node
|
||||
_cancel_check: dict[str, callable] = {}
|
||||
|
||||
|
||||
def set_cancel_check(job_id: str, fn):
|
||||
_cancel_check[job_id] = fn
|
||||
|
||||
|
||||
def clear_cancel_check(job_id: str):
|
||||
_cancel_check.pop(job_id, None)
|
||||
|
||||
|
||||
def _checkpointing_node(node_name: str, node_fn):
|
||||
"""Wrap a node function to auto-checkpoint after completion."""
|
||||
stage_index = NODES.index(node_name)
|
||||
|
||||
def wrapper(state: DetectState) -> dict:
|
||||
job_id = state.get("job_id", "")
|
||||
check = _cancel_check.get(job_id)
|
||||
if check and check():
|
||||
raise PipelineCancelled(f"Cancelled before {node_name}")
|
||||
|
||||
result = node_fn(state)
|
||||
|
||||
job_id = state.get("job_id", "")
|
||||
if not job_id:
|
||||
return result
|
||||
|
||||
from detect.checkpoint import save_stage_output, save_frames
|
||||
from detect.stages.base import _REGISTRY
|
||||
|
||||
merged = {**state, **result}
|
||||
|
||||
# Save frames once (first node), reuse manifest after
|
||||
manifest = _frames_manifest.get(job_id)
|
||||
if manifest is None and node_name == "extract_frames":
|
||||
manifest = save_frames(job_id, merged.get("frames", []))
|
||||
_frames_manifest[job_id] = manifest
|
||||
|
||||
# Serialize stage output using the stage's serialize_fn if available
|
||||
stage_cls = _REGISTRY.get(node_name)
|
||||
serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
|
||||
if serialize_fn:
|
||||
output_json = serialize_fn(merged, job_id)
|
||||
else:
|
||||
output_json = {}
|
||||
|
||||
parent_id = _latest_checkpoint.get(job_id)
|
||||
new_checkpoint_id = save_stage_output(
|
||||
timeline_id=job_id,
|
||||
parent_checkpoint_id=parent_id,
|
||||
stage_name=node_name,
|
||||
output_json=output_json,
|
||||
)
|
||||
_latest_checkpoint[job_id] = new_checkpoint_id
|
||||
return result
|
||||
|
||||
wrapper.__name__ = node_fn.__name__
|
||||
return wrapper
|
||||
|
||||
|
||||
# --- Graph construction ---
|
||||
|
||||
NODE_FUNCTIONS = [
|
||||
("extract_frames", node_extract_frames),
|
||||
("filter_scenes", node_filter_scenes),
|
||||
@@ -409,41 +315,3 @@ NODE_FUNCTIONS = [
|
||||
("escalate_cloud", node_escalate_cloud),
|
||||
("compile_report", node_compile_report),
|
||||
]
|
||||
|
||||
|
||||
def build_graph(checkpoint: bool | None = None, start_from: str | None = None) -> StateGraph:
|
||||
"""
|
||||
Build the pipeline graph.
|
||||
|
||||
checkpoint: enable auto-checkpointing (default: MPR_CHECKPOINT env var)
|
||||
start_from: skip nodes before this stage (for replay)
|
||||
"""
|
||||
do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
|
||||
|
||||
graph = StateGraph(DetectState)
|
||||
|
||||
# Filter to start_from if replaying
|
||||
node_pairs = NODE_FUNCTIONS
|
||||
if start_from:
|
||||
start_idx = next(i for i, (name, _) in enumerate(NODE_FUNCTIONS) if name == start_from)
|
||||
node_pairs = NODE_FUNCTIONS[start_idx:]
|
||||
|
||||
for name, fn in node_pairs:
|
||||
wrapped = _checkpointing_node(name, fn) if do_checkpoint else fn
|
||||
graph.add_node(name, wrapped)
|
||||
|
||||
# Wire edges
|
||||
entry = node_pairs[0][0]
|
||||
graph.set_entry_point(entry)
|
||||
|
||||
for i in range(len(node_pairs) - 1):
|
||||
graph.add_edge(node_pairs[i][0], node_pairs[i + 1][0])
|
||||
|
||||
graph.add_edge(node_pairs[-1][0], END)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
def get_pipeline(checkpoint: bool | None = None):
|
||||
"""Return a compiled, runnable pipeline."""
|
||||
return build_graph(checkpoint=checkpoint).compile()
|
||||
127
detect/graph/runner.py
Normal file
127
detect/graph/runner.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
Pipeline runner — executes stages sequentially with checkpointing and cancellation.
|
||||
|
||||
Currently wraps LangGraph for execution. Will be replaced with a lean
|
||||
custom runner in Phase 3, with an executor socket for distributed dispatch.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from langgraph.graph import END, StateGraph
|
||||
|
||||
from detect.state import DetectState
|
||||
from .nodes import NODES, NODE_FUNCTIONS
|
||||
|
||||
|
||||
# --- Checkpoint wrapper ---
|
||||
|
||||
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
|
||||
_frames_manifest: dict[str, dict[int, str]] = {} # job_id → manifest (cached per job)
|
||||
_latest_checkpoint: dict[str, str] = {} # job_id → latest checkpoint_id
|
||||
|
||||
|
||||
class PipelineCancelled(Exception):
|
||||
"""Raised when a pipeline run is cancelled."""
|
||||
pass
|
||||
|
||||
|
||||
# Cancellation hook — set by the run endpoint, checked before each node
|
||||
_cancel_check: dict[str, callable] = {}
|
||||
|
||||
|
||||
def set_cancel_check(job_id: str, fn):
|
||||
_cancel_check[job_id] = fn
|
||||
|
||||
|
||||
def clear_cancel_check(job_id: str):
|
||||
_cancel_check.pop(job_id, None)
|
||||
|
||||
|
||||
def _checkpointing_node(node_name: str, node_fn):
|
||||
"""Wrap a node function to auto-checkpoint after completion."""
|
||||
|
||||
def wrapper(state: DetectState) -> dict:
|
||||
job_id = state.get("job_id", "")
|
||||
check = _cancel_check.get(job_id)
|
||||
if check and check():
|
||||
raise PipelineCancelled(f"Cancelled before {node_name}")
|
||||
|
||||
result = node_fn(state)
|
||||
|
||||
job_id = state.get("job_id", "")
|
||||
if not job_id:
|
||||
return result
|
||||
|
||||
from detect.checkpoint import save_stage_output, save_frames
|
||||
from detect.stages.base import _REGISTRY
|
||||
|
||||
merged = {**state, **result}
|
||||
|
||||
# Save frames once (first node), reuse manifest after
|
||||
manifest = _frames_manifest.get(job_id)
|
||||
if manifest is None and node_name == "extract_frames":
|
||||
manifest = save_frames(job_id, merged.get("frames", []))
|
||||
_frames_manifest[job_id] = manifest
|
||||
|
||||
# Serialize stage output using the stage's serialize_fn if available
|
||||
stage_cls = _REGISTRY.get(node_name)
|
||||
serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
|
||||
if serialize_fn:
|
||||
output_json = serialize_fn(merged, job_id)
|
||||
else:
|
||||
output_json = {}
|
||||
|
||||
parent_id = _latest_checkpoint.get(job_id)
|
||||
new_checkpoint_id = save_stage_output(
|
||||
timeline_id=job_id,
|
||||
parent_checkpoint_id=parent_id,
|
||||
stage_name=node_name,
|
||||
output_json=output_json,
|
||||
)
|
||||
_latest_checkpoint[job_id] = new_checkpoint_id
|
||||
return result
|
||||
|
||||
wrapper.__name__ = node_fn.__name__
|
||||
return wrapper
|
||||
|
||||
|
||||
# --- Graph construction ---
|
||||
|
||||
def build_graph(checkpoint: bool | None = None, start_from: str | None = None) -> StateGraph:
|
||||
"""
|
||||
Build the pipeline graph.
|
||||
|
||||
checkpoint: enable auto-checkpointing (default: MPR_CHECKPOINT env var)
|
||||
start_from: skip nodes before this stage (for replay)
|
||||
"""
|
||||
do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
|
||||
|
||||
graph = StateGraph(DetectState)
|
||||
|
||||
# Filter to start_from if replaying
|
||||
node_pairs = NODE_FUNCTIONS
|
||||
if start_from:
|
||||
start_idx = next(i for i, (name, _) in enumerate(NODE_FUNCTIONS) if name == start_from)
|
||||
node_pairs = NODE_FUNCTIONS[start_idx:]
|
||||
|
||||
for name, fn in node_pairs:
|
||||
wrapped = _checkpointing_node(name, fn) if do_checkpoint else fn
|
||||
graph.add_node(name, wrapped)
|
||||
|
||||
# Wire edges
|
||||
entry = node_pairs[0][0]
|
||||
graph.set_entry_point(entry)
|
||||
|
||||
for i in range(len(node_pairs) - 1):
|
||||
graph.add_edge(node_pairs[i][0], node_pairs[i + 1][0])
|
||||
|
||||
graph.add_edge(node_pairs[-1][0], END)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
def get_pipeline(checkpoint: bool | None = None):
|
||||
"""Return a compiled, runnable pipeline."""
|
||||
return build_graph(checkpoint=checkpoint).compile()
|
||||
Reference in New Issue
Block a user