""" LangGraph pipeline graph for brand detection. Nodes execute real logic for extract+filter, stubs for the rest. Each node emits graph_update events so the UI can visualize transitions. """ from __future__ import annotations import os from langgraph.graph import END, StateGraph from detect import emit from detect.models import PipelineStats from detect.profiles import SoccerBroadcastProfile from detect.state import DetectState from detect.stages.frame_extractor import extract_frames from detect.stages.scene_filter import scene_filter from detect.stages.yolo_detector import detect_objects from detect.stages.ocr_stage import run_ocr from detect.stages.brand_resolver import resolve_brands INFERENCE_URL = os.environ.get("INFERENCE_URL") # None = local mode NODES = [ "extract_frames", "filter_scenes", "detect_objects", "run_ocr", "match_brands", "escalate_vlm", "escalate_cloud", "compile_report", ] def _get_profile(state: DetectState): name = state.get("profile_name", "soccer_broadcast") if name == "soccer_broadcast": return SoccerBroadcastProfile() raise ValueError(f"Unknown profile: {name}") # Track node states across the pipeline run _node_states: dict[str, dict[str, str]] = {} def _emit_transition(state: DetectState, node: str, status: str): job_id = state.get("job_id") if not job_id: return # Initialize state tracking for this job if job_id not in _node_states: _node_states[job_id] = {n: "pending" for n in NODES} _node_states[job_id][node] = status nodes = [{"id": n, "status": _node_states[job_id][n]} for n in NODES] emit.graph_update(job_id, nodes) # --- Node functions --- def node_extract_frames(state: DetectState) -> dict: _emit_transition(state, "extract_frames", "running") profile = _get_profile(state) config = profile.frame_extraction_config() frames = extract_frames(state["video_path"], config, job_id=state.get("job_id")) _emit_transition(state, "extract_frames", "done") return {"frames": frames, "stats": PipelineStats(frames_extracted=len(frames))} def node_filter_scenes(state: DetectState) -> dict: _emit_transition(state, "filter_scenes", "running") profile = _get_profile(state) config = profile.scene_filter_config() frames = state.get("frames", []) kept = scene_filter(frames, config, job_id=state.get("job_id")) stats = state.get("stats", PipelineStats()) stats.frames_after_scene_filter = len(kept) _emit_transition(state, "filter_scenes", "done") return {"filtered_frames": kept, "stats": stats} def node_detect_objects(state: DetectState) -> dict: _emit_transition(state, "detect_objects", "running") profile = _get_profile(state) config = profile.detection_config() frames = state.get("filtered_frames", []) job_id = state.get("job_id") all_boxes = detect_objects(frames, config, inference_url=INFERENCE_URL, job_id=job_id) stats = state.get("stats", PipelineStats()) stats.regions_detected = sum(len(boxes) for boxes in all_boxes.values()) _emit_transition(state, "detect_objects", "done") return {"boxes_by_frame": all_boxes, "stats": stats} def node_run_ocr(state: DetectState) -> dict: _emit_transition(state, "run_ocr", "running") profile = _get_profile(state) config = profile.ocr_config() frames = state.get("filtered_frames", []) boxes = state.get("boxes_by_frame", {}) job_id = state.get("job_id") candidates = run_ocr(frames, boxes, config, inference_url=INFERENCE_URL, job_id=job_id) stats = state.get("stats", PipelineStats()) stats.regions_resolved_by_ocr = len(candidates) _emit_transition(state, "run_ocr", "done") return {"text_candidates": candidates, "stats": stats} def node_match_brands(state: DetectState) -> dict: _emit_transition(state, "match_brands", "running") profile = _get_profile(state) dictionary = profile.brand_dictionary() resolver_config = profile.resolver_config() candidates = state.get("text_candidates", []) job_id = state.get("job_id") matched, unresolved = resolve_brands( candidates, dictionary, resolver_config, content_type=profile.name, job_id=job_id, ) _emit_transition(state, "match_brands", "done") return {"detections": matched, "unresolved_candidates": unresolved} def node_escalate_vlm(state: DetectState) -> dict: _emit_transition(state, "escalate_vlm", "running") job_id = state.get("job_id") emit.log(job_id, "VLMLocal", "INFO", "Stub: VLM escalation not yet implemented") _emit_transition(state, "escalate_vlm", "done") return {} def node_escalate_cloud(state: DetectState) -> dict: _emit_transition(state, "escalate_cloud", "running") job_id = state.get("job_id") emit.log(job_id, "CloudLLM", "INFO", "Stub: cloud LLM escalation not yet implemented") _emit_transition(state, "escalate_cloud", "done") return {} def node_compile_report(state: DetectState) -> dict: _emit_transition(state, "compile_report", "running") job_id = state.get("job_id") profile = _get_profile(state) detections = state.get("detections", []) report = profile.aggregate(detections) report.video_source = state.get("video_path", "") emit.log(job_id, "Aggregator", "INFO", f"Report: {len(report.brands)} brands, {len(report.timeline)} detections") emit.job_complete(job_id, { "video_source": report.video_source, "content_type": report.content_type, "brands": {k: {"total_appearances": v.total_appearances} for k, v in report.brands.items()}, }) _emit_transition(state, "compile_report", "done") return {"report": report} # --- Graph construction --- def build_graph() -> StateGraph: graph = StateGraph(DetectState) graph.add_node("extract_frames", node_extract_frames) graph.add_node("filter_scenes", node_filter_scenes) graph.add_node("detect_objects", node_detect_objects) graph.add_node("run_ocr", node_run_ocr) graph.add_node("match_brands", node_match_brands) graph.add_node("escalate_vlm", node_escalate_vlm) graph.add_node("escalate_cloud", node_escalate_cloud) graph.add_node("compile_report", node_compile_report) graph.set_entry_point("extract_frames") graph.add_edge("extract_frames", "filter_scenes") graph.add_edge("filter_scenes", "detect_objects") graph.add_edge("detect_objects", "run_ocr") graph.add_edge("run_ocr", "match_brands") graph.add_edge("match_brands", "escalate_vlm") graph.add_edge("escalate_vlm", "escalate_cloud") graph.add_edge("escalate_cloud", "compile_report") graph.add_edge("compile_report", END) return graph def get_pipeline(): """Return a compiled, runnable pipeline.""" return build_graph().compile()