This commit is contained in:
2026-03-30 07:22:14 -03:00
parent d0707333fd
commit 4220b0418e
182 changed files with 3668 additions and 5231 deletions

View File

@@ -0,0 +1,45 @@
"""
Detection pipeline graph.
detect/graph/
nodes.py — node functions (one per stage)
events.py — graph_update SSE emission
runner.py — PipelineRunner (config-driven, checkpoint, cancel, pause)
"""
from .nodes import NODES, NODE_FUNCTIONS
from .runner import (
PipelineCancelled,
PipelineRunner,
build_graph,
clear_cancel_check,
clear_pause,
get_pipeline,
init_pause,
is_paused,
pause_pipeline,
resume_pipeline,
set_cancel_check,
set_pause_after_stage,
step_pipeline,
)
from .events import _node_states
# Names exported by a star-import of this package.
__all__ = [
    # from .nodes
    "NODES",
    "NODE_FUNCTIONS",
    # from .runner
    "PipelineCancelled",
    "PipelineRunner",
    "build_graph",
    "get_pipeline",
    "set_cancel_check",
    "clear_cancel_check",
    "init_pause",
    "clear_pause",
    "pause_pipeline",
    "resume_pipeline",
    "step_pipeline",
    "set_pause_after_stage",
    "is_paused",
    # from .events (underscore-prefixed, but listed here so it is exported)
    "_node_states",
]

View File

@@ -0,0 +1,27 @@
"""
Graph event emission — node state tracking + SSE graph_update events.
"""
from __future__ import annotations
from core.detect import emit
from core.detect.state import DetectState
# Track node states across pipeline runs.
# Shape: job_id -> {node name -> status string}. Entries are created lazily by
# emit_transition; nothing in this module removes them.
_node_states: dict[str, dict[str, str]] = {}
def emit_transition(state: DetectState, node: str, status: str, node_list: list[str]):
    """Record ``status`` for ``node`` and emit a ``graph_update`` SSE event.

    Does nothing when ``state`` has no truthy ``job_id``. On the first call
    for a job every name in ``node_list`` starts as ``"pending"``. The event
    payload lists ``node_list`` in order as ``{"id", "status"}`` dicts.

    A name in ``node_list`` that was not part of the list seen on the job's
    first call is reported as ``"pending"`` instead of raising ``KeyError``.
    """
    job_id = state.get("job_id")
    if not job_id:
        return

    statuses = _node_states.get(job_id)
    if statuses is None:
        statuses = _node_states[job_id] = {n: "pending" for n in node_list}
    statuses[node] = status

    # .get() keeps the event emittable even if node_list changed between calls.
    nodes = [{"id": n, "status": statuses.get(n, "pending")} for n in node_list]
    emit.graph_update(job_id, nodes)

366
core/detect/graph/nodes.py Normal file
View File

@@ -0,0 +1,366 @@
"""
Pipeline node functions — one per stage.
Each node: reads state, gets config from profile dict, runs stage logic,
emits transitions, returns output dict.
"""
from __future__ import annotations
import os
from core.detect import emit
from core.detect.models import CropContext, PipelineStats
from core.detect.profile import get_profile, get_stage_config, build_vlm_prompt, aggregate_detections
from core.detect.stages.models import (
DetectionConfig,
FieldSegmentationConfig,
FrameExtractionConfig,
OCRConfig,
RegionAnalysisConfig,
ResolverConfig,
SceneFilterConfig,
)
from core.detect.state import DetectState
from core.detect.stages.frame_extractor import extract_frames
from core.detect.stages.scene_filter import scene_filter
from core.detect.stages.field_segmentation import run_field_segmentation
from core.detect.stages.edge_detector import detect_edge_regions
from core.detect.stages.yolo_detector import detect_objects
from core.detect.stages.preprocess import preprocess_regions
from core.detect.stages.ocr_stage import run_ocr
from core.detect.stages.brand_resolver import resolve_brands
from core.detect.stages.vlm_local import escalate_vlm
from core.detect.stages.vlm_cloud import escalate_cloud
from core.detect.stages.aggregator import compile_report
from core.detect.tracing import trace_node, flush as flush_traces
from .events import emit_transition
# Read once at import time and passed as ``inference_url`` to the stages below.
INFERENCE_URL = os.environ.get("INFERENCE_URL")  # None = local mode
# Node names used as the node list for graph_update events (see _emit).
# The order matches NODE_FUNCTIONS at the bottom of this module.
NODES = [
    "extract_frames",
    "filter_scenes",
    "field_segmentation",
    "detect_edges",
    "detect_objects",
    "preprocess",
    "run_ocr",
    "match_brands",
    "escalate_vlm",
    "escalate_cloud",
    "compile_report",
]
def _load_profile(state: DetectState) -> dict:
    """Return the run's profile with per-stage config overrides applied.

    The profile name comes from ``state["profile_name"]`` (default
    ``"soccer_broadcast"``). When ``state["config_overrides"]`` is truthy,
    each stage's override dict is layered over that stage's existing config
    (override keys win); a stage with no existing config takes the override
    dict as-is. The merge is done on copies, so the dict returned by
    ``get_profile`` is left untouched.
    """
    base = get_profile(state.get("profile_name", "soccer_broadcast"))
    overrides = state.get("config_overrides")
    if not overrides:
        return base

    configs = dict(base.get("configs", {}))
    for stage, patch in overrides.items():
        configs[stage] = {**configs[stage], **patch} if stage in configs else patch
    return {**base, "configs": configs}
def _emit(state, node, status):
    """Emit a status transition for ``node`` against the module's NODES list."""
    emit_transition(state, node, status, node_list=NODES)
# --- Node functions ---
def node_extract_frames(state: DetectState) -> dict:
    """Extract frames from ``state["video_path"]``.

    Before the stage runs this node also:
      * sets the emit run context to an "initial" run keyed by the job id,
        when there is a job id and no run context yet;
      * stores a session brand dictionary built from the source asset in
        ``state["session_brands"]``, when there is a source asset and no
        session brands yet.

    Returns ``frames`` plus a fresh ``PipelineStats`` carrying the frame count.
    """
    job_id = state.get("job_id", "")
    if job_id and not emit._run_context:
        emit.set_run_context(run_id=job_id, parent_job_id=job_id, run_type="initial")

    asset_id = state.get("source_asset_id")
    if asset_id and not state.get("session_brands"):
        from core.detect.stages.brand_resolver import build_session_dict

        state["session_brands"] = build_session_dict(asset_id)

    _emit(state, "extract_frames", "running")
    with trace_node(state, "extract_frames") as span:
        stage_cfg = get_stage_config(_load_profile(state), "extract_frames")
        config = FrameExtractionConfig(**stage_cfg)
        extracted = extract_frames(state["video_path"], config, job_id=job_id)
        span.set_output({"frames_extracted": len(extracted)})
    _emit(state, "extract_frames", "done")

    return {
        "frames": extracted,
        "stats": PipelineStats(frames_extracted=len(extracted)),
    }
def node_filter_scenes(state: DetectState) -> dict:
    """Run the scene filter over ``state["frames"]``.

    Returns the kept frames as ``filtered_frames`` and the run's stats with
    ``frames_after_scene_filter`` set to the number of frames kept.
    """
    _emit(state, "filter_scenes", "running")
    with trace_node(state, "filter_scenes") as span:
        stage_cfg = get_stage_config(_load_profile(state), "filter_scenes")
        config = SceneFilterConfig(**stage_cfg)
        incoming = state.get("frames", [])
        survivors = scene_filter(incoming, config, job_id=state.get("job_id"))
        span.set_output({"frames_in": len(incoming), "frames_kept": len(survivors)})

    stats = state.get("stats", PipelineStats())
    stats.frames_after_scene_filter = len(survivors)
    _emit(state, "filter_scenes", "done")
    return {"filtered_frames": survivors, "stats": stats}
def node_field_segmentation(state: DetectState) -> dict:
    """Run field segmentation over ``state["filtered_frames"]``.

    Returns the ``field_masks``, ``field_boundaries`` and ``field_coverage``
    entries of the stage result. The trace span records the frame count and
    the mean of the coverage values (0 when there are none).
    """
    _emit(state, "field_segmentation", "running")
    with trace_node(state, "field_segmentation") as span:
        stage_cfg = get_stage_config(_load_profile(state), "field_segmentation")
        config = FieldSegmentationConfig(**stage_cfg)
        frames = state.get("filtered_frames", [])
        outcome = run_field_segmentation(
            frames,
            config,
            inference_url=INFERENCE_URL,
            job_id=state.get("job_id"),
        )
        coverage = outcome["field_coverage"]
        # max(..., 1) guards the division when no coverage values exist.
        mean_coverage = sum(coverage.values()) / max(len(coverage), 1)
        span.set_output({"frames": len(frames), "avg_coverage": mean_coverage})
    _emit(state, "field_segmentation", "done")

    return {
        key: outcome[key]
        for key in ("field_masks", "field_boundaries", "field_coverage")
    }
def node_detect_edges(state: DetectState) -> dict:
    """Detect edge regions in ``state["filtered_frames"]``.

    ``state["field_masks"]`` is forwarded to the detector. Returns the
    per-frame regions as ``edge_regions_by_frame`` and the run's stats with
    ``cv_regions_detected`` set to the total region count.
    """
    _emit(state, "detect_edges", "running")
    with trace_node(state, "detect_edges") as span:
        stage_cfg = get_stage_config(_load_profile(state), "detect_edges")
        config = RegionAnalysisConfig(**stage_cfg)
        frames = state.get("filtered_frames", [])
        regions_by_frame = detect_edge_regions(
            frames,
            config,
            inference_url=INFERENCE_URL,
            job_id=state.get("job_id"),
            field_masks=state.get("field_masks", {}),
        )
        region_count = sum(map(len, regions_by_frame.values()))
        span.set_output({"frames": len(frames), "edge_regions": region_count})

    stats = state.get("stats", PipelineStats())
    stats.cv_regions_detected = region_count
    _emit(state, "detect_edges", "done")
    return {"edge_regions_by_frame": regions_by_frame, "stats": stats}
def node_detect_objects(state: DetectState) -> dict:
    """Run object detection over ``state["filtered_frames"]``.

    Returns the per-frame boxes as ``boxes_by_frame`` and the run's stats
    with ``regions_detected`` set to the total box count.
    """
    _emit(state, "detect_objects", "running")
    with trace_node(state, "detect_objects") as span:
        stage_cfg = get_stage_config(_load_profile(state), "detect_objects")
        config = DetectionConfig(**stage_cfg)
        frames = state.get("filtered_frames", [])
        boxes_by_frame = detect_objects(
            frames,
            config,
            inference_url=INFERENCE_URL,
            job_id=state.get("job_id"),
        )
        box_count = sum(map(len, boxes_by_frame.values()))
        span.set_output({"frames": len(frames), "regions_detected": box_count})

    stats = state.get("stats", PipelineStats())
    stats.regions_detected = box_count
    _emit(state, "detect_objects", "done")
    return {"boxes_by_frame": boxes_by_frame, "stats": stats}
def node_preprocess(state: DetectState) -> dict:
    """Preprocess the detected regions of ``state["filtered_frames"]``.

    The stage config is used as a plain dict: ``contrast`` defaults to True,
    ``deskew`` and ``binarize`` default to False. Returns the stage output
    as ``preprocessed_crops``.
    """
    _emit(state, "preprocess", "running")
    with trace_node(state, "preprocess") as span:
        options = get_stage_config(_load_profile(state), "preprocess")
        crops = preprocess_regions(
            state.get("filtered_frames", []),
            state.get("boxes_by_frame", {}),
            do_contrast=options.get("contrast", True),
            do_deskew=options.get("deskew", False),
            do_binarize=options.get("binarize", False),
            inference_url=INFERENCE_URL,
            job_id=state.get("job_id"),
        )
        span.set_output({"regions_preprocessed": len(crops)})
    _emit(state, "preprocess", "done")
    return {"preprocessed_crops": crops}
def node_run_ocr(state: DetectState) -> dict:
    """Run OCR on the boxes in ``state["boxes_by_frame"]``.

    Returns the OCR output as ``text_candidates`` and the run's stats with
    ``regions_resolved_by_ocr`` set to the number of candidates.
    """
    _emit(state, "run_ocr", "running")
    with trace_node(state, "run_ocr") as span:
        stage_cfg = get_stage_config(_load_profile(state), "run_ocr")
        config = OCRConfig(**stage_cfg)
        boxes_by_frame = state.get("boxes_by_frame", {})
        found = run_ocr(
            state.get("filtered_frames", []),
            boxes_by_frame,
            config,
            inference_url=INFERENCE_URL,
            job_id=state.get("job_id"),
        )
        span.set_output({
            "regions_in": sum(map(len, boxes_by_frame.values())),
            "text_candidates": len(found),
        })

    stats = state.get("stats", PipelineStats())
    stats.regions_resolved_by_ocr = len(found)
    _emit(state, "run_ocr", "done")
    return {"text_candidates": found, "stats": stats}
def node_match_brands(state: DetectState) -> dict:
    """Resolve ``state["text_candidates"]`` against known brands.

    The profile's ``name`` is passed to the resolver as ``content_type``.
    Returns the matches as ``detections`` and the remainder as
    ``unresolved_candidates``.
    """
    _emit(state, "match_brands", "running")
    with trace_node(state, "match_brands") as span:
        profile = _load_profile(state)
        config = ResolverConfig(**get_stage_config(profile, "match_brands"))
        resolved, leftover = resolve_brands(
            state.get("text_candidates", []),
            config,
            session_brands=state.get("session_brands", {}),
            source_asset_id=state.get("source_asset_id"),
            content_type=profile["name"],
            job_id=state.get("job_id"),
        )
        span.set_output({"matched": len(resolved), "unresolved": len(leftover)})
    _emit(state, "match_brands", "done")
    return {"detections": resolved, "unresolved_candidates": leftover}
def node_escalate_vlm(state: DetectState) -> dict:
    """Send ``state["unresolved_candidates"]`` to the local VLM stage.

    The prompt is built from the stage's ``vlm_prompt_template`` (with a
    generic fallback). New matches are appended to the existing detections.
    ``regions_escalated_to_local_vlm`` is set to the number of candidates
    handed to the stage.

    The final node status is "skipped" when the ``SKIP_VLM`` environment
    variable is "1", otherwise "done"; the stage function is called either way.
    """
    _emit(state, "escalate_vlm", "running")
    with trace_node(state, "escalate_vlm") as span:
        profile = _load_profile(state)
        template = get_stage_config(profile, "escalate_vlm").get(
            "vlm_prompt_template", "Identify the brand in this image."
        )
        pending = state.get("unresolved_candidates", [])

        def vlm_prompt_fn(ctx):
            return build_vlm_prompt(ctx, template)

        resolved, remaining = escalate_vlm(
            pending,
            vlm_prompt_fn=vlm_prompt_fn,
            inference_url=INFERENCE_URL,
            content_type=profile["name"],
            source_asset_id=state.get("source_asset_id"),
            job_id=state.get("job_id"),
        )
        stats = state.get("stats", PipelineStats())
        stats.regions_escalated_to_local_vlm = len(pending)
        span.set_output({
            "candidates": len(pending),
            "matched": len(resolved),
            "still_unresolved": len(remaining),
        })

    previous = state.get("detections", [])
    if os.environ.get("SKIP_VLM", "").strip() == "1":
        final_status = "skipped"
    else:
        final_status = "done"
    _emit(state, "escalate_vlm", final_status)

    return {
        "detections": previous + resolved,
        "unresolved_candidates": remaining,
        "stats": stats,
    }
def node_escalate_cloud(state: DetectState) -> dict:
    """Send ``state["unresolved_candidates"]`` to the cloud VLM stage.

    The run's stats object is handed to the stage; afterwards its
    ``cloud_llm_calls`` and ``estimated_cloud_cost_usd`` are reported on the
    trace span. New matches are appended to the existing detections.

    The final node status is "skipped" when the ``SKIP_CLOUD`` environment
    variable is "1", otherwise "done"; the stage function is called either way.
    """
    _emit(state, "escalate_cloud", "running")
    with trace_node(state, "escalate_cloud") as span:
        profile = _load_profile(state)
        # NOTE(review): the prompt template is read from the "escalate_vlm"
        # stage config, not "escalate_cloud" — confirm the sharing is intended.
        template = get_stage_config(profile, "escalate_vlm").get(
            "vlm_prompt_template", "Identify the brand in this image."
        )
        pending = state.get("unresolved_candidates", [])
        stats = state.get("stats", PipelineStats())

        def vlm_prompt_fn(ctx):
            return build_vlm_prompt(ctx, template)

        resolved = escalate_cloud(
            pending,
            vlm_prompt_fn=vlm_prompt_fn,
            stats=stats,
            content_type=profile["name"],
            source_asset_id=state.get("source_asset_id"),
            job_id=state.get("job_id"),
        )
        span.set_output({
            "candidates": len(pending),
            "matched": len(resolved),
            "cloud_calls": stats.cloud_llm_calls,
            "cost_usd": stats.estimated_cloud_cost_usd,
        })

    previous = state.get("detections", [])
    if os.environ.get("SKIP_CLOUD", "").strip() == "1":
        final_status = "skipped"
    else:
        final_status = "done"
    _emit(state, "escalate_cloud", final_status)

    return {"detections": previous + resolved, "stats": stats}
def node_compile_report(state: DetectState) -> dict:
    """Compile the final report from the accumulated detections and stats.

    Traces are flushed after the span closes and before the node is marked
    "done". Returns the report under ``report``.
    """
    _emit(state, "compile_report", "running")
    with trace_node(state, "compile_report") as span:
        profile = _load_profile(state)
        report = compile_report(
            detections=state.get("detections", []),
            stats=state.get("stats", PipelineStats()),
            video_source=state.get("video_path", ""),
            content_type=profile["name"],
            job_id=state.get("job_id"),
        )
        span.set_output({
            "brands": len(report.brands),
            "detections": len(report.timeline),
        })

    flush_traces()
    _emit(state, "compile_report", "done")
    return {"report": report}
# (stage name, node function) pairs, in the same order as NODES.
NODE_FUNCTIONS = [
    ("extract_frames", node_extract_frames),
    ("filter_scenes", node_filter_scenes),
    ("field_segmentation", node_field_segmentation),
    ("detect_edges", node_detect_edges),
    ("detect_objects", node_detect_objects),
    ("preprocess", node_preprocess),
    ("run_ocr", node_run_ocr),
    ("match_brands", node_match_brands),
    ("escalate_vlm", node_escalate_vlm),
    ("escalate_cloud", node_escalate_cloud),
    ("compile_report", node_compile_report),
]

274
core/detect/graph/runner.py Normal file
View File

@@ -0,0 +1,274 @@
"""
Pipeline runner — executes stages sequentially with checkpointing,
cancellation, and pause/resume.
Reads PipelineConfig from the profile to determine what stages to run.
Flattens the graph into a linear sequence for now (serial execution).
Executor socket: all stages run via LocalExecutor (call function directly).
"""
from __future__ import annotations
import logging
import os
import threading
from core.detect.stages.models import PipelineConfig
from core.detect.state import DetectState
from .nodes import NODES, NODE_FUNCTIONS
logger = logging.getLogger(__name__)
# Checkpointing default used by get_pipeline() when its ``checkpoint`` argument
# is None: on only when MPR_CHECKPOINT is exactly "1" (surrounding whitespace
# ignored). Read once at import time.
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
class PipelineCancelled(Exception):
    """Signals that a pipeline run was cancelled."""
class PipelinePaused(Exception):
    """Signals that a pipeline is paused (internal flow control)."""
# ---------------------------------------------------------------------------
# Cancellation — checked before each node
# ---------------------------------------------------------------------------
# job_id -> zero-argument callable; a truthy return value means "cancel this
# job". Consulted before each stage and while a job is paused.
_cancel_check: dict[str, callable] = {}
def set_cancel_check(job_id: str, fn):
    """Register ``fn`` as the cancel check for ``job_id``, replacing any previous one."""
    _cancel_check.update({job_id: fn})
def clear_cancel_check(job_id: str):
    """Remove the cancel check for ``job_id``; a no-op when none is registered."""
    if job_id in _cancel_check:
        del _cancel_check[job_id]
# ---------------------------------------------------------------------------
# Pause / Resume / Step — checked after each node completes
# ---------------------------------------------------------------------------
# job_id -> gate event: set means "run", cleared means "paused".
_pause_gate: dict[str, threading.Event] = {}
# job_id -> when True, the gate is cleared again after each stage completes.
_pause_after_stage: dict[str, bool] = {}
def init_pause(job_id: str, pause_after_stage: bool = False):
    """Create the pause state for a job; call when the pipeline starts.

    The job starts unpaused. ``pause_after_stage`` selects whether it should
    pause after every stage.
    """
    open_gate = threading.Event()
    open_gate.set()  # set == running
    _pause_gate[job_id] = open_gate
    _pause_after_stage[job_id] = pause_after_stage
def clear_pause(job_id: str):
    """Drop all pause state for a job; call when the pipeline finishes."""
    for registry in (_pause_gate, _pause_after_stage):
        registry.pop(job_id, None)
def pause_pipeline(job_id: str):
    """Request a pause; the run blocks once its current stage has completed.

    Does nothing for a job without pause state.
    """
    gate = _pause_gate.get(job_id)
    if not gate:
        return
    gate.clear()
    logger.info("Pipeline %s paused", job_id)
def resume_pipeline(job_id: str):
    """Release a paused pipeline. Does nothing for a job without pause state."""
    gate = _pause_gate.get(job_id)
    if not gate:
        return
    gate.set()
    logger.info("Pipeline %s resumed", job_id)
def step_pipeline(job_id: str):
    """Let the pipeline run one more stage, then pause again.

    Pause-after-stage mode is switched on for the job before its gate is
    released. The mode flag is recorded even when the job has no gate.
    """
    _pause_after_stage[job_id] = True
    gate = _pause_gate.get(job_id)
    if not gate:
        return
    gate.set()
    logger.info("Pipeline %s stepping", job_id)
def set_pause_after_stage(job_id: str, enabled: bool):
    """Switch pause-after-each-stage mode on or off for a job.

    Switching the mode off also releases the job's gate, if it has one.
    """
    _pause_after_stage[job_id] = enabled
    if enabled:
        return
    gate = _pause_gate.get(job_id)
    if gate:
        gate.set()
def is_paused(job_id: str) -> bool:
    """Return True when the job has a pause gate and that gate is cleared."""
    gate = _pause_gate.get(job_id)
    if gate is None:
        return False
    return not gate.is_set()
def _wait_if_paused(job_id: str, node_name: str):
    """Block while the job's pause gate is cleared; called after each node.

    Returns immediately for a job without pause state. In pause-after-stage
    mode the gate is cleared here and a "Paused after <node>" log event is
    emitted before waiting.

    Raises:
        PipelineCancelled: if the job's cancel check returns truthy while
            the pipeline is waiting.
    """
    gate = _pause_gate.get(job_id)
    if gate is None:
        return

    if _pause_after_stage.get(job_id, False):
        gate.clear()
        # Function-scope import kept from the original (presumably to avoid
        # an import cycle — TODO confirm).
        from core.detect import emit

        emit.log(job_id, "Pipeline", "INFO", f"Paused after {node_name}")

    # Wake up every 0.5 s so a cancel request is noticed while paused.
    while not gate.wait(timeout=0.5):
        check = _cancel_check.get(job_id)
        if check and check():
            raise PipelineCancelled("Cancelled while paused before next stage")
# ---------------------------------------------------------------------------
# Pipeline Runner
# ---------------------------------------------------------------------------
# Stage name -> node callable, built from the (name, function) pairs.
_NODE_FN_MAP: dict[str, callable] = dict(NODE_FUNCTIONS)
def _flatten_config(config: PipelineConfig, start_from: str | None = None) -> list[str]:
"""
Flatten a PipelineConfig into a linear stage sequence.
For now: topological sort via edges. Falls back to stage order if no edges.
Respects start_from for replay (skip stages before it).
"""
if not config.edges:
# No edges defined — use stage order as-is
names = [s.name for s in config.stages]
else:
# Topological sort from edges
graph: dict[str, list[str]] = {}
in_degree: dict[str, int] = {}
stage_names = {s.name for s in config.stages}
for name in stage_names:
graph[name] = []
in_degree[name] = 0
for edge in config.edges:
if edge.source in stage_names and edge.target in stage_names:
graph[edge.source].append(edge.target)
in_degree[edge.target] = in_degree.get(edge.target, 0) + 1
# Kahn's algorithm
queue = [n for n in stage_names if in_degree.get(n, 0) == 0]
# Stable sort: prefer order from config.stages
stage_order = {s.name: i for i, s in enumerate(config.stages)}
queue.sort(key=lambda n: stage_order.get(n, 999))
names = []
while queue:
node = queue.pop(0)
names.append(node)
for neighbor in graph.get(node, []):
in_degree[neighbor] -= 1
if in_degree[neighbor] == 0:
queue.append(neighbor)
queue.sort(key=lambda n: stage_order.get(n, 999))
if start_from:
try:
idx = names.index(start_from)
names = names[idx:]
except ValueError:
raise ValueError(f"Stage {start_from!r} not in pipeline config")
return names
class PipelineRunner:
    """Runs the stages of a PipelineConfig one after another.

    The config is flattened into a linear sequence at construction time.
    For every stage, ``invoke``:
      1. raises PipelineCancelled if the job's cancel check fires,
      2. calls the stage's node function (directly, in-process),
      3. merges the returned dict into the state,
      4. checkpoints the stage when checkpointing is enabled,
      5. blocks if the job is paused.

    Executor socket: node functions are currently called directly.
    Future: dispatch to LocalExecutor / GrpcExecutor / LambdaExecutor
    based on StageRef.execution_target.
    """

    def __init__(
        self,
        config: PipelineConfig,
        checkpoint: bool = False,
        start_from: str | None = None,
    ):
        self.config = config
        self.do_checkpoint = checkpoint
        self.stage_sequence = _flatten_config(config, start_from)

    def invoke(self, state: DetectState) -> DetectState:
        """Run every stage against ``state`` (updated in place) and return it.

        A stage with no registered node function is logged and skipped
        entirely, including its checkpoint and pause check.
        """
        for stage in self.stage_sequence:
            job_id = state.get("job_id", "")

            cancel_requested = _cancel_check.get(job_id)
            if cancel_requested and cancel_requested():
                raise PipelineCancelled(f"Cancelled before {stage}")

            run_stage = _NODE_FN_MAP.get(stage)
            if run_stage is None:
                logger.warning("No node function for stage %s, skipping", stage)
                continue

            output = run_stage(state)
            state.update(output)

            if self.do_checkpoint:
                from core.detect.checkpoint import checkpoint_after_stage

                checkpoint_after_stage(job_id, stage, state, output)

            _wait_if_paused(job_id, stage)

        return state
# ---------------------------------------------------------------------------
# Public API — backwards compatible with old get_pipeline/build_graph
# ---------------------------------------------------------------------------
def get_pipeline(
    checkpoint: bool | None = None,
    profile_name: str = "soccer_broadcast",
    start_from: str | None = None,
) -> PipelineRunner:
    """Build a PipelineRunner from the named profile's ``pipeline`` section.

    ``checkpoint=None`` falls back to the MPR_CHECKPOINT-derived module
    default; ``start_from`` is forwarded to the runner unchanged.
    """
    from core.detect.profile import get_profile, pipeline_config_from_dict

    if checkpoint is None:
        use_checkpoint = _CHECKPOINT_ENABLED
    else:
        use_checkpoint = checkpoint

    pipeline_spec = get_profile(profile_name)["pipeline"]
    return PipelineRunner(
        config=pipeline_config_from_dict(pipeline_spec),
        checkpoint=use_checkpoint,
        start_from=start_from,
    )
def build_graph(checkpoint: bool | None = None, start_from: str | None = None):
    """Legacy entry point: same as get_pipeline() with the default profile."""
    runner = get_pipeline(checkpoint=checkpoint, start_from=start_from)
    return runner