phase 4
This commit is contained in:
45
core/detect/graph/__init__.py
Normal file
45
core/detect/graph/__init__.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
Detection pipeline graph.
|
||||
|
||||
detect/graph/
|
||||
nodes.py — node functions (one per stage)
|
||||
events.py — graph_update SSE emission
|
||||
runner.py — PipelineRunner (config-driven, checkpoint, cancel, pause)
|
||||
"""
|
||||
|
||||
from .nodes import NODES, NODE_FUNCTIONS
|
||||
from .runner import (
|
||||
PipelineCancelled,
|
||||
PipelineRunner,
|
||||
build_graph,
|
||||
clear_cancel_check,
|
||||
clear_pause,
|
||||
get_pipeline,
|
||||
init_pause,
|
||||
is_paused,
|
||||
pause_pipeline,
|
||||
resume_pipeline,
|
||||
set_cancel_check,
|
||||
set_pause_after_stage,
|
||||
step_pipeline,
|
||||
)
|
||||
from .events import _node_states
|
||||
|
||||
__all__ = [
|
||||
"NODES",
|
||||
"NODE_FUNCTIONS",
|
||||
"PipelineCancelled",
|
||||
"PipelineRunner",
|
||||
"build_graph",
|
||||
"get_pipeline",
|
||||
"set_cancel_check",
|
||||
"clear_cancel_check",
|
||||
"init_pause",
|
||||
"clear_pause",
|
||||
"pause_pipeline",
|
||||
"resume_pipeline",
|
||||
"step_pipeline",
|
||||
"set_pause_after_stage",
|
||||
"is_paused",
|
||||
"_node_states",
|
||||
]
|
||||
27
core/detect/graph/events.py
Normal file
27
core/detect/graph/events.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""
|
||||
Graph event emission — node state tracking + SSE graph_update events.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from core.detect import emit
|
||||
from core.detect.state import DetectState
|
||||
|
||||
|
||||
# Track node states across pipeline runs
_node_states: dict[str, dict[str, str]] = {}


def emit_transition(state: DetectState, node: str, status: str, node_list: list[str]):
    """Record *node*'s new *status* for this job and emit a graph_update SSE event.

    No-op when the state carries no job_id. The per-job status map is created
    lazily with every node in node_list marked "pending".
    """
    job_id = state.get("job_id")
    if not job_id:
        return

    # Lazily initialise the per-job map, then record the transition.
    statuses = _node_states.setdefault(job_id, {n: "pending" for n in node_list})
    statuses[node] = status

    # Emit the full node list (not just the changed node) so clients can
    # render the whole graph from any single event.
    emit.graph_update(job_id, [{"id": n, "status": statuses[n]} for n in node_list])
|
||||
366
core/detect/graph/nodes.py
Normal file
366
core/detect/graph/nodes.py
Normal file
@@ -0,0 +1,366 @@
|
||||
"""
|
||||
Pipeline node functions — one per stage.
|
||||
|
||||
Each node: reads state, gets config from profile dict, runs stage logic,
|
||||
emits transitions, returns output dict.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from core.detect import emit
|
||||
from core.detect.models import CropContext, PipelineStats
|
||||
from core.detect.profile import get_profile, get_stage_config, build_vlm_prompt, aggregate_detections
|
||||
from core.detect.stages.models import (
|
||||
DetectionConfig,
|
||||
FieldSegmentationConfig,
|
||||
FrameExtractionConfig,
|
||||
OCRConfig,
|
||||
RegionAnalysisConfig,
|
||||
ResolverConfig,
|
||||
SceneFilterConfig,
|
||||
)
|
||||
from core.detect.state import DetectState
|
||||
from core.detect.stages.frame_extractor import extract_frames
|
||||
from core.detect.stages.scene_filter import scene_filter
|
||||
from core.detect.stages.field_segmentation import run_field_segmentation
|
||||
from core.detect.stages.edge_detector import detect_edge_regions
|
||||
from core.detect.stages.yolo_detector import detect_objects
|
||||
from core.detect.stages.preprocess import preprocess_regions
|
||||
from core.detect.stages.ocr_stage import run_ocr
|
||||
from core.detect.stages.brand_resolver import resolve_brands
|
||||
from core.detect.stages.vlm_local import escalate_vlm
|
||||
from core.detect.stages.vlm_cloud import escalate_cloud
|
||||
from core.detect.stages.aggregator import compile_report
|
||||
from core.detect.tracing import trace_node, flush as flush_traces
|
||||
|
||||
from .events import emit_transition
|
||||
|
||||
INFERENCE_URL = os.environ.get("INFERENCE_URL") # None = local mode
|
||||
|
||||
NODES = [
|
||||
"extract_frames",
|
||||
"filter_scenes",
|
||||
"field_segmentation",
|
||||
"detect_edges",
|
||||
"detect_objects",
|
||||
"preprocess",
|
||||
"run_ocr",
|
||||
"match_brands",
|
||||
"escalate_vlm",
|
||||
"escalate_cloud",
|
||||
"compile_report",
|
||||
]
|
||||
|
||||
|
||||
def _load_profile(state: DetectState) -> dict:
    """Load the profile dict for this run, applying per-stage config overrides.

    Overrides (state["config_overrides"]) are merged shallowly into a copy of
    the profile's "configs" mapping; the cached profile is never mutated.
    """
    profile = get_profile(state.get("profile_name", "soccer_broadcast"))

    overrides = state.get("config_overrides")
    if not overrides:
        return profile

    # Merge overrides into a copy of the profile configs.
    merged = dict(profile.get("configs", {}))
    for stage_name, stage_overrides in overrides.items():
        if stage_name in merged:
            # Override keys win over profile keys for this stage.
            merged[stage_name] = {**merged[stage_name], **stage_overrides}
        else:
            merged[stage_name] = stage_overrides

    return {**profile, "configs": merged}
|
||||
|
||||
|
||||
def _emit(state, node, status):
    """Shorthand: emit a node transition against the canonical NODES list."""
    emit_transition(state, node, status, NODES)
|
||||
|
||||
|
||||
# --- Node functions ---
|
||||
|
||||
def node_extract_frames(state: DetectState) -> dict:
    """First pipeline stage: extract frames from the source video.

    Also performs one-time per-run setup before the stage itself:
    binds the emit run context (first node only), and builds the
    session brand dictionary when a source asset is known.

    Returns:
        dict with "frames" (extracted frames) and a fresh "stats"
        PipelineStats seeded with frames_extracted.
    """
    job_id = state.get("job_id", "")
    # Bind the SSE run context once per run; emit._run_context is falsy
    # until set, so this only fires on the first node of a fresh run.
    if job_id and not emit._run_context:
        emit.set_run_context(run_id=job_id, parent_job_id=job_id, run_type="initial")

    source_asset_id = state.get("source_asset_id")
    if source_asset_id and not state.get("session_brands"):
        # Local import — presumably avoids a module-level import cycle
        # with the brand resolver stage; confirm before hoisting.
        from core.detect.stages.brand_resolver import build_session_dict
        session_brands = build_session_dict(source_asset_id)
        state["session_brands"] = session_brands

    _emit(state, "extract_frames", "running")

    with trace_node(state, "extract_frames") as span:
        profile = _load_profile(state)
        config = FrameExtractionConfig(**get_stage_config(profile, "extract_frames"))
        # Note: state["video_path"] is a hard requirement here (KeyError if absent).
        frames = extract_frames(state["video_path"], config, job_id=job_id)
        span.set_output({"frames_extracted": len(frames)})

    _emit(state, "extract_frames", "done")
    return {"frames": frames, "stats": PipelineStats(frames_extracted=len(frames))}
|
||||
|
||||
|
||||
def node_filter_scenes(state: DetectState) -> dict:
    """Run the scene filter over extracted frames and record how many survive."""
    _emit(state, "filter_scenes", "running")

    with trace_node(state, "filter_scenes") as span:
        cfg = SceneFilterConfig(**get_stage_config(_load_profile(state), "filter_scenes"))
        incoming = state.get("frames", [])
        surviving = scene_filter(incoming, cfg, job_id=state.get("job_id"))
        span.set_output({"frames_in": len(incoming), "frames_kept": len(surviving)})

    stats = state.get("stats", PipelineStats())
    stats.frames_after_scene_filter = len(surviving)

    _emit(state, "filter_scenes", "done")
    return {"filtered_frames": surviving, "stats": stats}
|
||||
|
||||
|
||||
def node_field_segmentation(state: DetectState) -> dict:
    """Segment the playing field in each filtered frame.

    Returns the per-frame field masks, boundaries, and coverage mapping.
    """
    _emit(state, "field_segmentation", "running")

    with trace_node(state, "field_segmentation") as span:
        cfg = FieldSegmentationConfig(
            **get_stage_config(_load_profile(state), "field_segmentation")
        )
        frames = state.get("filtered_frames", [])

        seg = run_field_segmentation(
            frames, cfg, inference_url=INFERENCE_URL, job_id=state.get("job_id")
        )
        coverage = seg["field_coverage"]
        span.set_output({
            "frames": len(frames),
            # max(..., 1) guards the division when no frames were processed.
            "avg_coverage": sum(coverage.values()) / max(len(coverage), 1),
        })

    _emit(state, "field_segmentation", "done")
    return {
        "field_masks": seg["field_masks"],
        "field_boundaries": seg["field_boundaries"],
        "field_coverage": coverage,
    }
|
||||
|
||||
|
||||
def node_detect_edges(state: DetectState) -> dict:
    """Detect edge-based candidate regions per frame; updates cv_regions_detected."""
    _emit(state, "detect_edges", "running")

    with trace_node(state, "detect_edges") as span:
        cfg = RegionAnalysisConfig(**get_stage_config(_load_profile(state), "detect_edges"))
        frames = state.get("filtered_frames", [])

        regions = detect_edge_regions(
            frames,
            cfg,
            inference_url=INFERENCE_URL,
            job_id=state.get("job_id"),
            # Field masks (if the segmentation stage ran) constrain the search.
            field_masks=state.get("field_masks", {}),
        )
        region_count = sum(len(per_frame) for per_frame in regions.values())
        span.set_output({"frames": len(frames), "edge_regions": region_count})

    stats = state.get("stats", PipelineStats())
    stats.cv_regions_detected = region_count

    _emit(state, "detect_edges", "done")
    return {"edge_regions_by_frame": regions, "stats": stats}
|
||||
|
||||
|
||||
def node_detect_objects(state: DetectState) -> dict:
    """Run YOLO object detection over filtered frames; records regions_detected."""
    _emit(state, "detect_objects", "running")

    with trace_node(state, "detect_objects") as span:
        cfg = DetectionConfig(**get_stage_config(_load_profile(state), "detect_objects"))
        frames = state.get("filtered_frames", [])

        boxes_by_frame = detect_objects(
            frames, cfg, inference_url=INFERENCE_URL, job_id=state.get("job_id")
        )
        n_regions = sum(len(boxes) for boxes in boxes_by_frame.values())
        span.set_output({"frames": len(frames), "regions_detected": n_regions})

    stats = state.get("stats", PipelineStats())
    stats.regions_detected = n_regions

    _emit(state, "detect_objects", "done")
    return {"boxes_by_frame": boxes_by_frame, "stats": stats}
|
||||
|
||||
|
||||
def node_preprocess(state: DetectState) -> dict:
    """Apply configured image preprocessing (contrast/deskew/binarize) to regions."""
    _emit(state, "preprocess", "running")

    with trace_node(state, "preprocess") as span:
        # This stage's config is a plain dict of toggles, not a typed model.
        opts = get_stage_config(_load_profile(state), "preprocess")

        crops = preprocess_regions(
            state.get("filtered_frames", []),
            state.get("boxes_by_frame", {}),
            do_contrast=opts.get("contrast", True),
            do_deskew=opts.get("deskew", False),
            do_binarize=opts.get("binarize", False),
            inference_url=INFERENCE_URL,
            job_id=state.get("job_id"),
        )
        span.set_output({"regions_preprocessed": len(crops)})

    _emit(state, "preprocess", "done")
    return {"preprocessed_crops": crops}
|
||||
|
||||
|
||||
def node_run_ocr(state: DetectState) -> dict:
    """OCR stage: read text candidates out of the detected regions.

    Note: stats.regions_resolved_by_ocr is set to the number of text
    candidates produced, not the number of input regions that matched.
    """
    _emit(state, "run_ocr", "running")

    with trace_node(state, "run_ocr") as span:
        profile = _load_profile(state)
        config = OCRConfig(**get_stage_config(profile, "run_ocr"))
        frames = state.get("filtered_frames", [])
        boxes = state.get("boxes_by_frame", {})
        job_id = state.get("job_id")

        candidates = run_ocr(frames, boxes, config, inference_url=INFERENCE_URL, job_id=job_id)
        span.set_output({"regions_in": sum(len(b) for b in boxes.values()), "text_candidates": len(candidates)})

    stats = state.get("stats", PipelineStats())
    stats.regions_resolved_by_ocr = len(candidates)

    _emit(state, "run_ocr", "done")
    return {"text_candidates": candidates, "stats": stats}
|
||||
|
||||
|
||||
def node_match_brands(state: DetectState) -> dict:
    """Resolve OCR text candidates against known brands.

    Splits candidates into matched detections and unresolved leftovers;
    the leftovers feed the VLM escalation stages downstream.
    """
    _emit(state, "match_brands", "running")

    with trace_node(state, "match_brands") as span:
        profile = _load_profile(state)
        config = ResolverConfig(**get_stage_config(profile, "match_brands"))
        candidates = state.get("text_candidates", [])
        session_brands = state.get("session_brands", {})
        job_id = state.get("job_id")
        source_asset_id = state.get("source_asset_id")

        matched, unresolved = resolve_brands(
            candidates, config,
            session_brands=session_brands,
            source_asset_id=source_asset_id,
            # The profile name doubles as the content-type tag for the resolver.
            content_type=profile["name"], job_id=job_id,
        )
        span.set_output({"matched": len(matched), "unresolved": len(unresolved)})

    _emit(state, "match_brands", "done")
    return {"detections": matched, "unresolved_candidates": unresolved}
|
||||
|
||||
|
||||
def node_escalate_vlm(state: DetectState) -> dict:
    """Escalate unresolved candidates to the local VLM.

    Appends VLM matches to the existing detections and carries the
    still-unresolved candidates forward to the cloud stage. Emits
    "skipped" instead of "done" when SKIP_VLM=1 (the escalation call is
    presumably a no-op in that mode — confirm in escalate_vlm).

    Returns:
        dict with merged "detections", remaining "unresolved_candidates",
        and updated "stats" (regions_escalated_to_local_vlm).
    """
    _emit(state, "escalate_vlm", "running")

    with trace_node(state, "escalate_vlm") as span:
        profile = _load_profile(state)
        vlm_config = get_stage_config(profile, "escalate_vlm")
        vlm_template = vlm_config.get("vlm_prompt_template", "Identify the brand in this image.")
        candidates = state.get("unresolved_candidates", [])
        job_id = state.get("job_id")

        # Named function instead of an assigned lambda (PEP 8 / flake8 E731).
        def vlm_prompt_fn(ctx):
            return build_vlm_prompt(ctx, vlm_template)

        vlm_matched, still_unresolved = escalate_vlm(
            candidates,
            vlm_prompt_fn=vlm_prompt_fn,
            inference_url=INFERENCE_URL,
            content_type=profile["name"],
            source_asset_id=state.get("source_asset_id"),
            job_id=job_id,
        )

        stats = state.get("stats", PipelineStats())
        stats.regions_escalated_to_local_vlm = len(candidates)
        span.set_output({"candidates": len(candidates), "matched": len(vlm_matched),
                         "still_unresolved": len(still_unresolved)})

    existing = state.get("detections", [])

    # Reflect SKIP_VLM=1 in the graph status so the UI shows the stage as skipped.
    vlm_skipped = os.environ.get("SKIP_VLM", "").strip() == "1"
    _emit(state, "escalate_vlm", "skipped" if vlm_skipped else "done")
    return {
        "detections": existing + vlm_matched,
        "unresolved_candidates": still_unresolved,
        "stats": stats,
    }
|
||||
|
||||
|
||||
def node_escalate_cloud(state: DetectState) -> dict:
    """Escalate still-unresolved candidates to the cloud VLM.

    Appends cloud matches to the existing detections. stats is passed in so
    escalate_cloud can accumulate call counts and estimated cost, which are
    read back for the trace span output. Emits "skipped" instead of "done"
    when SKIP_CLOUD=1 (escalation presumably no-ops in that mode — confirm).
    """
    _emit(state, "escalate_cloud", "running")

    with trace_node(state, "escalate_cloud") as span:
        profile = _load_profile(state)
        # Cloud escalation reuses the local-VLM prompt template config.
        vlm_config = get_stage_config(profile, "escalate_vlm")
        vlm_template = vlm_config.get("vlm_prompt_template", "Identify the brand in this image.")
        candidates = state.get("unresolved_candidates", [])
        job_id = state.get("job_id")
        stats = state.get("stats", PipelineStats())

        # Named function instead of an assigned lambda (PEP 8 / flake8 E731).
        def vlm_prompt_fn(ctx):
            return build_vlm_prompt(ctx, vlm_template)

        cloud_matched = escalate_cloud(
            candidates,
            vlm_prompt_fn=vlm_prompt_fn,
            stats=stats,
            content_type=profile["name"],
            source_asset_id=state.get("source_asset_id"),
            job_id=job_id,
        )

        span.set_output({"candidates": len(candidates), "matched": len(cloud_matched),
                         "cloud_calls": stats.cloud_llm_calls,
                         "cost_usd": stats.estimated_cloud_cost_usd})

    existing = state.get("detections", [])

    # Reflect SKIP_CLOUD=1 in the graph status so the UI shows the stage as skipped.
    cloud_skipped = os.environ.get("SKIP_CLOUD", "").strip() == "1"
    _emit(state, "escalate_cloud", "skipped" if cloud_skipped else "done")
    return {"detections": existing + cloud_matched, "stats": stats}
|
||||
|
||||
|
||||
def node_compile_report(state: DetectState) -> dict:
    """Final stage: aggregate detections and stats into the run report."""
    _emit(state, "compile_report", "running")

    with trace_node(state, "compile_report") as span:
        profile = _load_profile(state)
        report = compile_report(
            detections=state.get("detections", []),
            stats=state.get("stats", PipelineStats()),
            video_source=state.get("video_path", ""),
            content_type=profile["name"],
            job_id=state.get("job_id"),
        )
        span.set_output({"brands": len(report.brands), "detections": len(report.timeline)})

    # Last node of the run — push any buffered trace spans out.
    flush_traces()
    _emit(state, "compile_report", "done")
    return {"report": report}
|
||||
|
||||
|
||||
# (stage name, node function) pairs in default execution order.
# Names mirror the NODES list above; the runner builds its dispatch
# table from this sequence.
NODE_FUNCTIONS = [
    ("extract_frames", node_extract_frames),
    ("filter_scenes", node_filter_scenes),
    ("field_segmentation", node_field_segmentation),
    ("detect_edges", node_detect_edges),
    ("detect_objects", node_detect_objects),
    ("preprocess", node_preprocess),
    ("run_ocr", node_run_ocr),
    ("match_brands", node_match_brands),
    ("escalate_vlm", node_escalate_vlm),
    ("escalate_cloud", node_escalate_cloud),
    ("compile_report", node_compile_report),
]
|
||||
274
core/detect/graph/runner.py
Normal file
274
core/detect/graph/runner.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""
|
||||
Pipeline runner — executes stages sequentially with checkpointing,
|
||||
cancellation, and pause/resume.
|
||||
|
||||
Reads PipelineConfig from the profile to determine what stages to run.
|
||||
Flattens the graph into a linear sequence for now (serial execution).
|
||||
Executor socket: all stages run via LocalExecutor (call function directly).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
|
||||
from core.detect.stages.models import PipelineConfig
|
||||
from core.detect.state import DetectState
|
||||
from .nodes import NODES, NODE_FUNCTIONS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
|
||||
|
||||
|
||||
class PipelineCancelled(Exception):
    """Raised when a pipeline run is cancelled."""
    # The docstring is the class body; the former trailing `pass` was redundant.
|
||||
|
||||
|
||||
class PipelinePaused(Exception):
    """Raised when a pipeline is paused (internally, for flow control)."""
    # The docstring is the class body; the former trailing `pass` was redundant.
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Cancellation — checked before each node
# ---------------------------------------------------------------------------

# job_id -> zero-arg callable; truthy return means "abort this job".
# NOTE(review): `callable` here is the builtin, not typing.Callable — it works
# at runtime but is imprecise for type checkers.
_cancel_check: dict[str, callable] = {}


def set_cancel_check(job_id: str, fn):
    """Register the cancel probe for *job_id* (polled by the runner loop)."""
    _cancel_check[job_id] = fn


def clear_cancel_check(job_id: str):
    """Remove the cancel probe for *job_id*; no-op if none is registered."""
    _cancel_check.pop(job_id, None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Pause / Resume / Step — checked after each node completes
# ---------------------------------------------------------------------------

# Per-job gate: event set = running, cleared = paused.
_pause_gate: dict[str, threading.Event] = {}
# Per-job flag: when True, the runner pauses again after every stage.
_pause_after_stage: dict[str, bool] = {}


def init_pause(job_id: str, pause_after_stage: bool = False):
    """Initialize pause state for a job. Called when pipeline starts."""
    gate = threading.Event()
    gate.set()  # start unpaused
    # Set the gate before publishing it so a concurrent reader never sees
    # a cleared (paused-looking) gate for a freshly started job.
    _pause_gate[job_id] = gate
    _pause_after_stage[job_id] = pause_after_stage
|
||||
|
||||
|
||||
def clear_pause(job_id: str):
    """Clean up pause state. Called when pipeline finishes."""
    for table in (_pause_gate, _pause_after_stage):
        table.pop(job_id, None)
|
||||
|
||||
|
||||
def pause_pipeline(job_id: str):
    """Pause a running pipeline. It will block after the current stage completes."""
    gate = _pause_gate.get(job_id)
    if gate is None:
        return
    gate.clear()
    logger.info("Pipeline %s paused", job_id)
|
||||
|
||||
|
||||
def resume_pipeline(job_id: str):
    """Resume a paused pipeline."""
    gate = _pause_gate.get(job_id)
    if gate is None:
        return
    gate.set()
    logger.info("Pipeline %s resumed", job_id)
|
||||
|
||||
|
||||
def step_pipeline(job_id: str):
    """Run one stage then pause again."""
    # Arm pause-after-stage first so the runner re-pauses as soon as the
    # next stage completes, then release the gate to let it run.
    _pause_after_stage[job_id] = True
    gate = _pause_gate.get(job_id)
    if gate is not None:
        gate.set()
    logger.info("Pipeline %s stepping", job_id)
|
||||
|
||||
|
||||
def set_pause_after_stage(job_id: str, enabled: bool):
    """Toggle pause-after-each-stage mode."""
    _pause_after_stage[job_id] = enabled
    if enabled:
        return
    # Disabling the mode also releases a currently paused pipeline.
    gate = _pause_gate.get(job_id)
    if gate is not None:
        gate.set()
|
||||
|
||||
|
||||
def is_paused(job_id: str) -> bool:
    """Check if a pipeline is currently paused (cleared gate = paused)."""
    gate = _pause_gate.get(job_id)
    if gate is None:
        return False
    return not gate.is_set()
|
||||
|
||||
|
||||
def _wait_if_paused(job_id: str, node_name: str):
    """Block until resumed. Called after each node completes.

    When pause-after-stage mode is armed, clears the gate (pausing the run)
    and logs the pause. While the gate is cleared, polls the job's cancel
    check every 0.5s so a cancel request is honoured even mid-pause.

    Raises:
        PipelineCancelled: if the job is cancelled while paused.
    """
    gate = _pause_gate.get(job_id)
    if gate is None:
        return

    if _pause_after_stage.get(job_id, False):
        gate.clear()
        from core.detect import emit
        emit.log(job_id, "Pipeline", "INFO", f"Paused after {node_name}")

    while not gate.wait(timeout=0.5):
        check = _cancel_check.get(job_id)
        if check and check():
            # Name the stage we paused after (the old message was a
            # placeholder-less f-string and didn't say where the run stopped).
            raise PipelineCancelled(f"Cancelled while paused after {node_name}")
|
||||
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Pipeline Runner
# ---------------------------------------------------------------------------

# Node function lookup — maps stage name to callable.
# Built once at import time from NODE_FUNCTIONS (nodes.py).
_NODE_FN_MAP: dict[str, callable] = {name: fn for name, fn in NODE_FUNCTIONS}
|
||||
|
||||
|
||||
def _flatten_config(config: PipelineConfig, start_from: str | None = None) -> list[str]:
    """
    Flatten a PipelineConfig into a linear stage sequence.

    For now: topological sort via edges (Kahn's algorithm). Falls back to
    stage order if no edges. Respects start_from for replay (skip stages
    before it).

    Raises:
        ValueError: if start_from is not a configured stage, or the edge
            graph contains a cycle (previously stages in a cycle were
            silently dropped from the sequence).
    """
    if not config.edges:
        # No edges defined — use stage order as-is
        names = [s.name for s in config.stages]
    else:
        # Topological sort from edges
        graph: dict[str, list[str]] = {}
        in_degree: dict[str, int] = {}
        stage_names = {s.name for s in config.stages}

        for name in stage_names:
            graph[name] = []
            in_degree[name] = 0

        # Edges referencing stages outside the config are ignored.
        for edge in config.edges:
            if edge.source in stage_names and edge.target in stage_names:
                graph[edge.source].append(edge.target)
                in_degree[edge.target] = in_degree.get(edge.target, 0) + 1

        # Kahn's algorithm. Ready nodes are kept sorted by their position in
        # config.stages, with the name as a deterministic tie-break (the old
        # bare 999 fallback left unknown-position ties in set-iteration order).
        stage_order = {s.name: i for i, s in enumerate(config.stages)}

        queue = [n for n in stage_names if in_degree.get(n, 0) == 0]
        queue.sort(key=lambda n: (stage_order.get(n, 999), n))

        names = []
        while queue:
            node = queue.pop(0)
            names.append(node)
            for neighbor in graph.get(node, []):
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)
            queue.sort(key=lambda n: (stage_order.get(n, 999), n))

        # A cycle leaves its members with nonzero in-degree and out of `names`.
        if len(names) != len(stage_names):
            missing = sorted(stage_names - set(names))
            raise ValueError(f"Pipeline config has a cycle involving stages: {missing}")

    if start_from:
        try:
            idx = names.index(start_from)
        except ValueError:
            raise ValueError(f"Stage {start_from!r} not in pipeline config") from None
        names = names[idx:]

    return names
|
||||
|
||||
|
||||
class PipelineRunner:
    """
    Executes a pipeline defined by PipelineConfig.

    Runs stages sequentially (flattened). Each stage:
    1. Check cancel
    2. Run node function (via executor — local for now)
    3. Merge result into state
    4. Checkpoint (if enabled)
    5. Check pause

    Executor socket: currently calls node functions directly.
    Future: dispatch to LocalExecutor / GrpcExecutor / LambdaExecutor
    based on StageRef.execution_target.
    """

    def __init__(
        self,
        config: PipelineConfig,
        checkpoint: bool = False,
        start_from: str | None = None,
    ):
        # _flatten_config may raise ValueError for an unknown start_from.
        self.config = config
        self.do_checkpoint = checkpoint
        self.stage_sequence = _flatten_config(config, start_from)

    def invoke(self, state: DetectState) -> DetectState:
        """Run the pipeline on the given state. Returns final state.

        Raises:
            PipelineCancelled: when the job's cancel check fires before a
                stage, or while paused between stages.
        """
        for stage_name in self.stage_sequence:
            # Re-read job_id each iteration: state.update(result) below may
            # change it between stages.
            job_id = state.get("job_id", "")

            # 1. Cancel check
            check = _cancel_check.get(job_id)
            if check and check():
                raise PipelineCancelled(f"Cancelled before {stage_name}")

            # 2. Run node function
            node_fn = _NODE_FN_MAP.get(stage_name)
            if node_fn is None:
                # Stage configured but has no implementation — skip, don't fail.
                logger.warning("No node function for stage %s, skipping", stage_name)
                continue

            result = node_fn(state)

            # 3. Merge result into state
            state.update(result)

            # 4. Checkpoint
            if self.do_checkpoint:
                # Imported lazily — presumably to keep checkpointing optional;
                # confirm before hoisting to module level.
                from core.detect.checkpoint import checkpoint_after_stage
                checkpoint_after_stage(job_id, stage_name, state, result)

            # 5. Pause check — may block until resumed; raises if cancelled
            # while paused.
            _wait_if_paused(job_id, stage_name)

        return state
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Public API — backwards compatible with old get_pipeline/build_graph
# ---------------------------------------------------------------------------

def get_pipeline(
    checkpoint: bool | None = None,
    profile_name: str = "soccer_broadcast",
    start_from: str | None = None,
) -> PipelineRunner:
    """Return a PipelineRunner for the given profile.

    checkpoint=None defers to the MPR_CHECKPOINT environment default.
    """
    from core.detect.profile import get_profile, pipeline_config_from_dict

    profile = get_profile(profile_name)
    return PipelineRunner(
        config=pipeline_config_from_dict(profile["pipeline"]),
        checkpoint=_CHECKPOINT_ENABLED if checkpoint is None else checkpoint,
        start_from=start_from,
    )
|
||||
|
||||
|
||||
def build_graph(checkpoint: bool | None = None, start_from: str | None = None):
    """Backwards-compatible wrapper. Returns a PipelineRunner.

    Kept for callers of the old graph API; equivalent to get_pipeline()
    with the default profile.
    """
    return get_pipeline(checkpoint=checkpoint, start_from=start_from)
|
||||
Reference in New Issue
Block a user