"""
LangGraph pipeline graph for brand detection.

Each node wraps one stage from ``detect.stages`` (frame extraction, scene
filtering, object detection, OCR, brand matching, local/cloud VLM escalation,
report compilation). Some stages may still be stub implementations; see the
individual stage modules. Every node emits graph_update events so the UI can
visualize transitions.
"""

from __future__ import annotations

import os

from langgraph.graph import END, StateGraph

from detect import emit
from detect.models import PipelineStats
from detect.profiles import SoccerBroadcastProfile
from detect.state import DetectState
from detect.stages.frame_extractor import extract_frames
from detect.stages.scene_filter import scene_filter
from detect.stages.yolo_detector import detect_objects
from detect.stages.ocr_stage import run_ocr
from detect.stages.brand_resolver import resolve_brands
from detect.stages.vlm_local import escalate_vlm
from detect.stages.vlm_cloud import escalate_cloud
from detect.stages.aggregator import compile_report
from detect.tracing import trace_node, flush as flush_traces

# Inference service URL passed to the detection/OCR/local-VLM stages.
# None = local mode.
INFERENCE_URL = os.environ.get("INFERENCE_URL")

# Pipeline nodes in execution order. build_graph() wires the edges in exactly
# this order, and _emit_transition() reports statuses to the UI in this order,
# so this list is the single source of truth for the pipeline topology.
NODES = [
    "extract_frames",
    "filter_scenes",
    "detect_objects",
    "run_ocr",
    "match_brands",
    "escalate_vlm",
    "escalate_cloud",
    "compile_report",
]


def _get_profile(state: DetectState):
    """Return the content profile named by ``state["profile_name"]``.

    Falls back to ``"soccer_broadcast"`` when the state has no profile name.

    Raises:
        ValueError: if the profile name is not recognised.
    """
    name = state.get("profile_name", "soccer_broadcast")
    if name == "soccer_broadcast":
        return SoccerBroadcastProfile()
    raise ValueError(f"Unknown profile: {name}")


# Per-job node statuses: job_id -> {node name -> status}.
# An entry is reset when a run starts at the entry node. It is dropped once
# the final node reports "done", so the dict does not keep growing in a
# long-lived process.
# NOTE(review): a run that raises mid-pipeline leaves its entry behind until
# the same job_id runs again.
_node_states: dict[str, dict[str, str]] = {}


def _discard_node_states(state: DetectState) -> None:
    """Forget any tracked node statuses for this state's job (no-op without a job_id)."""
    job_id = state.get("job_id")
    if job_id:
        _node_states.pop(job_id, None)


def _emit_transition(state: DetectState, node: str, status: str) -> None:
    """Record ``node``'s new ``status`` and push the full node list to the UI.

    Does nothing when the state carries no ``job_id``. Nodes that have not
    reported yet are shown as ``"pending"``.
    """
    job_id = state.get("job_id")
    if not job_id:
        return
    # Lazily create the tracking entry on the first transition of a job.
    if job_id not in _node_states:
        _node_states[job_id] = {n: "pending" for n in NODES}
    statuses = _node_states[job_id]
    statuses[node] = status
    nodes = [{"id": n, "status": statuses[n]} for n in NODES]
    emit.graph_update(job_id, nodes)


# --- Node functions ---
# NOTE(review): if a stage raises, its node is left reported as "running";
# no failure status is emitted here.


def node_extract_frames(state: DetectState) -> dict:
    """Extract frames from ``state["video_path"]`` and start fresh pipeline stats.

    This is the graph's entry node, so it also clears any node statuses left
    over from a previous run of the same job.
    """
    _discard_node_states(state)
    _emit_transition(state, "extract_frames", "running")
    with trace_node(state, "extract_frames") as span:
        profile = _get_profile(state)
        config = profile.frame_extraction_config()
        frames = extract_frames(state["video_path"], config, job_id=state.get("job_id"))
        span.set_output({"frames_extracted": len(frames)})
    _emit_transition(state, "extract_frames", "done")
    return {"frames": frames, "stats": PipelineStats(frames_extracted=len(frames))}


def node_filter_scenes(state: DetectState) -> dict:
    """Keep only the extracted frames that pass the profile's scene filter."""
    _emit_transition(state, "filter_scenes", "running")
    with trace_node(state, "filter_scenes") as span:
        profile = _get_profile(state)
        config = profile.scene_filter_config()
        frames = state.get("frames", [])
        kept = scene_filter(frames, config, job_id=state.get("job_id"))
        span.set_output({"frames_in": len(frames), "frames_kept": len(kept)})
        stats = state.get("stats", PipelineStats())
        stats.frames_after_scene_filter = len(kept)
    _emit_transition(state, "filter_scenes", "done")
    return {"filtered_frames": kept, "stats": stats}


def node_detect_objects(state: DetectState) -> dict:
    """Run object detection over the filtered frames.

    ``detect_objects`` returns a mapping whose values are per-frame box
    collections. The total number of boxes is recorded as
    ``stats.regions_detected``.
    """
    _emit_transition(state, "detect_objects", "running")
    with trace_node(state, "detect_objects") as span:
        profile = _get_profile(state)
        config = profile.detection_config()
        frames = state.get("filtered_frames", [])
        job_id = state.get("job_id")
        all_boxes = detect_objects(frames, config, inference_url=INFERENCE_URL, job_id=job_id)
        total_regions = sum(len(boxes) for boxes in all_boxes.values())
        span.set_output({"frames": len(frames), "regions_detected": total_regions})
        stats = state.get("stats", PipelineStats())
        stats.regions_detected = total_regions
    _emit_transition(state, "detect_objects", "done")
    return {"boxes_by_frame": all_boxes, "stats": stats}


def node_run_ocr(state: DetectState) -> dict:
    """Run OCR on the detected regions and collect text candidates."""
    _emit_transition(state, "run_ocr", "running")
    with trace_node(state, "run_ocr") as span:
        profile = _get_profile(state)
        config = profile.ocr_config()
        frames = state.get("filtered_frames", [])
        boxes = state.get("boxes_by_frame", {})
        job_id = state.get("job_id")
        candidates = run_ocr(frames, boxes, config, inference_url=INFERENCE_URL, job_id=job_id)
        span.set_output({
            "regions_in": sum(len(b) for b in boxes.values()),
            "text_candidates": len(candidates),
        })
        stats = state.get("stats", PipelineStats())
        # NOTE(review): this records the number of OCR text candidates, before
        # brand matching runs. Confirm this is the intended meaning of
        # "regions_resolved_by_ocr".
        stats.regions_resolved_by_ocr = len(candidates)
    _emit_transition(state, "run_ocr", "done")
    return {"text_candidates": candidates, "stats": stats}


def node_match_brands(state: DetectState) -> dict:
    """Match OCR text candidates against the profile's brand dictionary.

    Returns the matched detections and the candidates that could not be
    resolved; the latter feed the VLM escalation nodes.
    """
    _emit_transition(state, "match_brands", "running")
    with trace_node(state, "match_brands") as span:
        profile = _get_profile(state)
        dictionary = profile.brand_dictionary()
        resolver_config = profile.resolver_config()
        candidates = state.get("text_candidates", [])
        job_id = state.get("job_id")
        matched, unresolved = resolve_brands(
            candidates,
            dictionary,
            resolver_config,
            content_type=profile.name,
            job_id=job_id,
        )
        span.set_output({"matched": len(matched), "unresolved": len(unresolved)})
    _emit_transition(state, "match_brands", "done")
    return {"detections": matched, "unresolved_candidates": unresolved}


def node_escalate_vlm(state: DetectState) -> dict:
    """Send unresolved candidates to the local VLM.

    New matches are appended to the existing detections; candidates the VLM
    could not resolve are passed on to the cloud escalation node.
    """
    _emit_transition(state, "escalate_vlm", "running")
    with trace_node(state, "escalate_vlm") as span:
        profile = _get_profile(state)
        candidates = state.get("unresolved_candidates", [])
        job_id = state.get("job_id")
        vlm_matched, still_unresolved = escalate_vlm(
            candidates,
            vlm_prompt_fn=profile.vlm_prompt,
            inference_url=INFERENCE_URL,
            content_type=profile.name,
            job_id=job_id,
        )
        stats = state.get("stats", PipelineStats())
        stats.regions_escalated_to_local_vlm = len(candidates)
        span.set_output({
            "candidates": len(candidates),
            "matched": len(vlm_matched),
            "still_unresolved": len(still_unresolved),
        })
        existing = state.get("detections", [])
    _emit_transition(state, "escalate_vlm", "done")
    return {
        "detections": existing + vlm_matched,
        "unresolved_candidates": still_unresolved,
        "stats": stats,
    }


def node_escalate_cloud(state: DetectState) -> dict:
    """Send candidates still unresolved after the local VLM to the cloud VLM.

    ``stats`` is passed to ``escalate_cloud``, which presumably updates the
    cloud call/cost counters in place: they are read back for the trace span
    right after the call.
    """
    _emit_transition(state, "escalate_cloud", "running")
    with trace_node(state, "escalate_cloud") as span:
        profile = _get_profile(state)
        candidates = state.get("unresolved_candidates", [])
        job_id = state.get("job_id")
        stats = state.get("stats", PipelineStats())
        cloud_matched = escalate_cloud(
            candidates,
            vlm_prompt_fn=profile.vlm_prompt,
            stats=stats,
            content_type=profile.name,
            job_id=job_id,
        )
        span.set_output({
            "candidates": len(candidates),
            "matched": len(cloud_matched),
            "cloud_calls": stats.cloud_llm_calls,
            "cost_usd": stats.estimated_cloud_cost_usd,
        })
        existing = state.get("detections", [])
    _emit_transition(state, "escalate_cloud", "done")
    return {"detections": existing + cloud_matched, "stats": stats}


def node_compile_report(state: DetectState) -> dict:
    """Aggregate all detections and stats into the final report.

    This is the last node. It flushes traces and then drops the job's
    node-status tracking once the final "done" event has been emitted.
    """
    _emit_transition(state, "compile_report", "running")
    with trace_node(state, "compile_report") as span:
        profile = _get_profile(state)
        detections = state.get("detections", [])
        stats = state.get("stats", PipelineStats())
        job_id = state.get("job_id")
        report = compile_report(
            detections=detections,
            stats=stats,
            video_source=state.get("video_path", ""),
            content_type=profile.name,
            job_id=job_id,
        )
        span.set_output({"brands": len(report.brands), "detections": len(report.timeline)})
    # Flush only after the span above has exited, so this node's own span is
    # included (assumes trace_node finalizes the span on exit; confirm).
    flush_traces()
    _emit_transition(state, "compile_report", "done")
    # Final event sent; release the per-job tracking entry.
    _discard_node_states(state)
    return {"report": report}


# --- Graph construction ---


def build_graph() -> StateGraph:
    """Build the (uncompiled) linear pipeline graph.

    Nodes are added and chained in ``NODES`` order: the first node is the
    entry point and the last one leads to ``END``.
    """
    node_funcs = {
        "extract_frames": node_extract_frames,
        "filter_scenes": node_filter_scenes,
        "detect_objects": node_detect_objects,
        "run_ocr": node_run_ocr,
        "match_brands": node_match_brands,
        "escalate_vlm": node_escalate_vlm,
        "escalate_cloud": node_escalate_cloud,
        "compile_report": node_compile_report,
    }
    graph = StateGraph(DetectState)
    for name in NODES:
        graph.add_node(name, node_funcs[name])
    graph.set_entry_point(NODES[0])
    # Wire consecutive nodes so the graph topology cannot drift from NODES,
    # which is also the order the UI displays.
    for src, dst in zip(NODES, NODES[1:]):
        graph.add_edge(src, dst)
    graph.add_edge(NODES[-1], END)
    return graph


def get_pipeline():
    """Return a compiled, runnable pipeline."""
    return build_graph().compile()