This commit is contained in:
2026-03-28 08:46:06 -03:00
parent acc99e691d
commit 0bd3888155
30 changed files with 390 additions and 1044 deletions

29
detect/graph/__init__.py Normal file
View File

@@ -0,0 +1,29 @@
"""
Detection pipeline graph.
detect/graph/
nodes.py — node functions (one per stage)
events.py — graph_update SSE emission
runner.py — pipeline execution (LangGraph wrapper, checkpoint, cancel)
"""
from .nodes import NODES, NODE_FUNCTIONS
from .runner import (
PipelineCancelled,
build_graph,
clear_cancel_check,
get_pipeline,
set_cancel_check,
)
from .events import _node_states
__all__ = [
"NODES",
"NODE_FUNCTIONS",
"PipelineCancelled",
"build_graph",
"get_pipeline",
"set_cancel_check",
"clear_cancel_check",
"_node_states",
]

27
detect/graph/events.py Normal file
View File

@@ -0,0 +1,27 @@
"""
Graph event emission — node state tracking + SSE graph_update events.
"""
from __future__ import annotations
from detect import emit
from detect.state import DetectState
# Track node states across pipeline runs
_node_states: dict[str, dict[str, str]] = {}
def emit_transition(state: DetectState, node: str, status: str, node_list: list[str]):
"""Update node status and emit graph_update SSE event."""
job_id = state.get("job_id")
if not job_id:
return
if job_id not in _node_states:
_node_states[job_id] = {n: "pending" for n in node_list}
_node_states[job_id][node] = status
nodes = [{"id": n, "status": _node_states[job_id][n]} for n in node_list]
emit.graph_update(job_id, nodes)

317
detect/graph/nodes.py Normal file
View File

@@ -0,0 +1,317 @@
"""
Pipeline node functions — one per stage.
Each node: reads state, runs stage logic, emits transitions, returns output dict.
"""
from __future__ import annotations
import os
from detect import emit
from detect.models import PipelineStats
from detect.profiles import SoccerBroadcastProfile
from detect.state import DetectState
from detect.stages.frame_extractor import extract_frames
from detect.stages.scene_filter import scene_filter
from detect.stages.edge_detector import detect_edge_regions
from detect.stages.yolo_detector import detect_objects
from detect.stages.preprocess import preprocess_regions
from detect.stages.ocr_stage import run_ocr
from detect.stages.brand_resolver import resolve_brands
from detect.stages.vlm_local import escalate_vlm
from detect.stages.vlm_cloud import escalate_cloud
from detect.stages.aggregator import compile_report
from detect.tracing import trace_node, flush as flush_traces
from .events import emit_transition
INFERENCE_URL = os.environ.get("INFERENCE_URL") # None = local mode
NODES = [
"extract_frames",
"filter_scenes",
"detect_edges",
"detect_objects",
"preprocess",
"run_ocr",
"match_brands",
"escalate_vlm",
"escalate_cloud",
"compile_report",
]
def _get_profile(state: DetectState):
name = state.get("profile_name", "soccer_broadcast")
if name == "soccer_broadcast":
profile = SoccerBroadcastProfile()
else:
raise ValueError(f"Unknown profile: {name}")
overrides = state.get("config_overrides")
if overrides:
from detect.checkpoint.replay import OverrideProfile
profile = OverrideProfile(profile, overrides)
return profile
def _emit(state, node, status):
emit_transition(state, node, status, NODES)
# --- Node functions ---
def node_extract_frames(state: DetectState) -> dict:
job_id = state.get("job_id", "")
if job_id and not emit._run_context:
emit.set_run_context(run_id=job_id, parent_job_id=job_id, run_type="initial")
source_asset_id = state.get("source_asset_id")
if source_asset_id and not state.get("session_brands"):
from detect.stages.brand_resolver import build_session_dict
session_brands = build_session_dict(source_asset_id)
state["session_brands"] = session_brands
_emit(state, "extract_frames", "running")
with trace_node(state, "extract_frames") as span:
profile = _get_profile(state)
config = profile.frame_extraction_config()
frames = extract_frames(state["video_path"], config, job_id=state.get("job_id"))
span.set_output({"frames_extracted": len(frames)})
_emit(state, "extract_frames", "done")
return {"frames": frames, "stats": PipelineStats(frames_extracted=len(frames))}
def node_filter_scenes(state: DetectState) -> dict:
_emit(state, "filter_scenes", "running")
with trace_node(state, "filter_scenes") as span:
profile = _get_profile(state)
config = profile.scene_filter_config()
frames = state.get("frames", [])
kept = scene_filter(frames, config, job_id=state.get("job_id"))
span.set_output({"frames_in": len(frames), "frames_kept": len(kept)})
stats = state.get("stats", PipelineStats())
stats.frames_after_scene_filter = len(kept)
_emit(state, "filter_scenes", "done")
return {"filtered_frames": kept, "stats": stats}
def node_detect_edges(state: DetectState) -> dict:
_emit(state, "detect_edges", "running")
with trace_node(state, "detect_edges") as span:
profile = _get_profile(state)
config = profile.region_analysis_config()
frames = state.get("filtered_frames", [])
job_id = state.get("job_id")
regions = detect_edge_regions(
frames, config, inference_url=INFERENCE_URL, job_id=job_id,
)
total = sum(len(r) for r in regions.values())
span.set_output({"frames": len(frames), "edge_regions": total})
stats = state.get("stats", PipelineStats())
stats.cv_regions_detected = total
_emit(state, "detect_edges", "done")
return {"edge_regions_by_frame": regions, "stats": stats}
def node_detect_objects(state: DetectState) -> dict:
_emit(state, "detect_objects", "running")
with trace_node(state, "detect_objects") as span:
profile = _get_profile(state)
config = profile.detection_config()
frames = state.get("filtered_frames", [])
job_id = state.get("job_id")
all_boxes = detect_objects(frames, config, inference_url=INFERENCE_URL, job_id=job_id)
total_regions = sum(len(boxes) for boxes in all_boxes.values())
span.set_output({"frames": len(frames), "regions_detected": total_regions})
stats = state.get("stats", PipelineStats())
stats.regions_detected = total_regions
_emit(state, "detect_objects", "done")
return {"boxes_by_frame": all_boxes, "stats": stats}
def node_preprocess(state: DetectState) -> dict:
_emit(state, "preprocess", "running")
with trace_node(state, "preprocess") as span:
profile = _get_profile(state)
frames = state.get("filtered_frames", [])
boxes = state.get("boxes_by_frame", {})
job_id = state.get("job_id")
overrides = state.get("config_overrides", {})
prep_config = overrides.get("preprocessing", {})
do_contrast = prep_config.get("contrast", True)
do_deskew = prep_config.get("deskew", False)
do_binarize = prep_config.get("binarize", False)
preprocessed = preprocess_regions(
frames, boxes,
do_contrast=do_contrast,
do_deskew=do_deskew,
do_binarize=do_binarize,
inference_url=INFERENCE_URL,
job_id=job_id,
)
span.set_output({"regions_preprocessed": len(preprocessed)})
_emit(state, "preprocess", "done")
return {"preprocessed_crops": preprocessed}
def node_run_ocr(state: DetectState) -> dict:
_emit(state, "run_ocr", "running")
with trace_node(state, "run_ocr") as span:
profile = _get_profile(state)
config = profile.ocr_config()
frames = state.get("filtered_frames", [])
boxes = state.get("boxes_by_frame", {})
job_id = state.get("job_id")
candidates = run_ocr(frames, boxes, config, inference_url=INFERENCE_URL, job_id=job_id)
span.set_output({"regions_in": sum(len(b) for b in boxes.values()), "text_candidates": len(candidates)})
stats = state.get("stats", PipelineStats())
stats.regions_resolved_by_ocr = len(candidates)
_emit(state, "run_ocr", "done")
return {"text_candidates": candidates, "stats": stats}
def node_match_brands(state: DetectState) -> dict:
_emit(state, "match_brands", "running")
with trace_node(state, "match_brands") as span:
profile = _get_profile(state)
resolver_config = profile.resolver_config()
candidates = state.get("text_candidates", [])
session_brands = state.get("session_brands", {})
job_id = state.get("job_id")
source_asset_id = state.get("source_asset_id")
matched, unresolved = resolve_brands(
candidates, resolver_config,
session_brands=session_brands,
source_asset_id=source_asset_id,
content_type=profile.name, job_id=job_id,
)
span.set_output({"matched": len(matched), "unresolved": len(unresolved)})
_emit(state, "match_brands", "done")
return {"detections": matched, "unresolved_candidates": unresolved}
def node_escalate_vlm(state: DetectState) -> dict:
_emit(state, "escalate_vlm", "running")
with trace_node(state, "escalate_vlm") as span:
profile = _get_profile(state)
candidates = state.get("unresolved_candidates", [])
job_id = state.get("job_id")
vlm_matched, still_unresolved = escalate_vlm(
candidates,
vlm_prompt_fn=profile.vlm_prompt,
inference_url=INFERENCE_URL,
content_type=profile.name,
source_asset_id=state.get("source_asset_id"),
job_id=job_id,
)
stats = state.get("stats", PipelineStats())
stats.regions_escalated_to_local_vlm = len(candidates)
span.set_output({"candidates": len(candidates), "matched": len(vlm_matched),
"still_unresolved": len(still_unresolved)})
existing = state.get("detections", [])
vlm_skipped = os.environ.get("SKIP_VLM", "").strip() == "1"
_emit(state, "escalate_vlm", "skipped" if vlm_skipped else "done")
return {
"detections": existing + vlm_matched,
"unresolved_candidates": still_unresolved,
"stats": stats,
}
def node_escalate_cloud(state: DetectState) -> dict:
_emit(state, "escalate_cloud", "running")
with trace_node(state, "escalate_cloud") as span:
profile = _get_profile(state)
candidates = state.get("unresolved_candidates", [])
job_id = state.get("job_id")
stats = state.get("stats", PipelineStats())
cloud_matched = escalate_cloud(
candidates,
vlm_prompt_fn=profile.vlm_prompt,
stats=stats,
content_type=profile.name,
source_asset_id=state.get("source_asset_id"),
job_id=job_id,
)
span.set_output({"candidates": len(candidates), "matched": len(cloud_matched),
"cloud_calls": stats.cloud_llm_calls,
"cost_usd": stats.estimated_cloud_cost_usd})
existing = state.get("detections", [])
cloud_skipped = os.environ.get("SKIP_CLOUD", "").strip() == "1"
_emit(state, "escalate_cloud", "skipped" if cloud_skipped else "done")
return {"detections": existing + cloud_matched, "stats": stats}
def node_compile_report(state: DetectState) -> dict:
_emit(state, "compile_report", "running")
with trace_node(state, "compile_report") as span:
profile = _get_profile(state)
detections = state.get("detections", [])
stats = state.get("stats", PipelineStats())
job_id = state.get("job_id")
report = compile_report(
detections=detections,
stats=stats,
video_source=state.get("video_path", ""),
content_type=profile.name,
job_id=job_id,
)
span.set_output({"brands": len(report.brands), "detections": len(report.timeline)})
flush_traces()
_emit(state, "compile_report", "done")
return {"report": report}
NODE_FUNCTIONS = [
("extract_frames", node_extract_frames),
("filter_scenes", node_filter_scenes),
("detect_edges", node_detect_edges),
("detect_objects", node_detect_objects),
("preprocess", node_preprocess),
("run_ocr", node_run_ocr),
("match_brands", node_match_brands),
("escalate_vlm", node_escalate_vlm),
("escalate_cloud", node_escalate_cloud),
("compile_report", node_compile_report),
]

127
detect/graph/runner.py Normal file
View File

@@ -0,0 +1,127 @@
"""
Pipeline runner — executes stages sequentially with checkpointing and cancellation.
Currently wraps LangGraph for execution. Will be replaced with a lean
custom runner in Phase 3, with an executor socket for distributed dispatch.
"""
from __future__ import annotations
import os
from langgraph.graph import END, StateGraph
from detect.state import DetectState
from .nodes import NODES, NODE_FUNCTIONS
# --- Checkpoint wrapper ---
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
_frames_manifest: dict[str, dict[int, str]] = {} # job_id → manifest (cached per job)
_latest_checkpoint: dict[str, str] = {} # job_id → latest checkpoint_id
class PipelineCancelled(Exception):
"""Raised when a pipeline run is cancelled."""
pass
# Cancellation hook — set by the run endpoint, checked before each node
_cancel_check: dict[str, callable] = {}
def set_cancel_check(job_id: str, fn):
_cancel_check[job_id] = fn
def clear_cancel_check(job_id: str):
_cancel_check.pop(job_id, None)
def _checkpointing_node(node_name: str, node_fn):
"""Wrap a node function to auto-checkpoint after completion."""
def wrapper(state: DetectState) -> dict:
job_id = state.get("job_id", "")
check = _cancel_check.get(job_id)
if check and check():
raise PipelineCancelled(f"Cancelled before {node_name}")
result = node_fn(state)
job_id = state.get("job_id", "")
if not job_id:
return result
from detect.checkpoint import save_stage_output, save_frames
from detect.stages.base import _REGISTRY
merged = {**state, **result}
# Save frames once (first node), reuse manifest after
manifest = _frames_manifest.get(job_id)
if manifest is None and node_name == "extract_frames":
manifest = save_frames(job_id, merged.get("frames", []))
_frames_manifest[job_id] = manifest
# Serialize stage output using the stage's serialize_fn if available
stage_cls = _REGISTRY.get(node_name)
serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
if serialize_fn:
output_json = serialize_fn(merged, job_id)
else:
output_json = {}
parent_id = _latest_checkpoint.get(job_id)
new_checkpoint_id = save_stage_output(
timeline_id=job_id,
parent_checkpoint_id=parent_id,
stage_name=node_name,
output_json=output_json,
)
_latest_checkpoint[job_id] = new_checkpoint_id
return result
wrapper.__name__ = node_fn.__name__
return wrapper
# --- Graph construction ---
def build_graph(checkpoint: bool | None = None, start_from: str | None = None) -> StateGraph:
"""
Build the pipeline graph.
checkpoint: enable auto-checkpointing (default: MPR_CHECKPOINT env var)
start_from: skip nodes before this stage (for replay)
"""
do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
graph = StateGraph(DetectState)
# Filter to start_from if replaying
node_pairs = NODE_FUNCTIONS
if start_from:
start_idx = next(i for i, (name, _) in enumerate(NODE_FUNCTIONS) if name == start_from)
node_pairs = NODE_FUNCTIONS[start_idx:]
for name, fn in node_pairs:
wrapped = _checkpointing_node(name, fn) if do_checkpoint else fn
graph.add_node(name, wrapped)
# Wire edges
entry = node_pairs[0][0]
graph.set_entry_point(entry)
for i in range(len(node_pairs) - 1):
graph.add_edge(node_pairs[i][0], node_pairs[i + 1][0])
graph.add_edge(node_pairs[-1][0], END)
return graph
def get_pipeline(checkpoint: bool | None = None):
"""Return a compiled, runnable pipeline."""
return build_graph(checkpoint=checkpoint).compile()