""" LangGraph pipeline graph for brand detection. Nodes execute real logic for extract+filter, stubs for the rest. Each node emits graph_update events so the UI can visualize transitions. """ from __future__ import annotations import os from langgraph.graph import END, StateGraph from detect import emit from detect.models import PipelineStats from detect.profiles import SoccerBroadcastProfile from detect.state import DetectState from detect.stages.frame_extractor import extract_frames from detect.stages.scene_filter import scene_filter from detect.stages.yolo_detector import detect_objects from detect.stages.preprocess import preprocess_regions from detect.stages.ocr_stage import run_ocr from detect.stages.brand_resolver import resolve_brands from detect.stages.vlm_local import escalate_vlm from detect.stages.vlm_cloud import escalate_cloud from detect.stages.aggregator import compile_report from detect.tracing import trace_node, flush as flush_traces INFERENCE_URL = os.environ.get("INFERENCE_URL") # None = local mode NODES = [ "extract_frames", "filter_scenes", "detect_objects", "preprocess", "run_ocr", "match_brands", "escalate_vlm", "escalate_cloud", "compile_report", ] def _get_profile(state: DetectState): name = state.get("profile_name", "soccer_broadcast") if name == "soccer_broadcast": profile = SoccerBroadcastProfile() else: raise ValueError(f"Unknown profile: {name}") overrides = state.get("config_overrides") if overrides: from detect.checkpoint.replay import OverrideProfile profile = OverrideProfile(profile, overrides) return profile # Track node states across the pipeline run _node_states: dict[str, dict[str, str]] = {} def _emit_transition(state: DetectState, node: str, status: str): job_id = state.get("job_id") if not job_id: return # Initialize state tracking for this job if job_id not in _node_states: _node_states[job_id] = {n: "pending" for n in NODES} _node_states[job_id][node] = status nodes = [{"id": n, "status": _node_states[job_id][n]} for n in NODES] emit.graph_update(job_id, nodes) # --- Node functions --- def node_extract_frames(state: DetectState) -> dict: # Set run context for initial runs (replays set it in replay_from) job_id = state.get("job_id", "") if job_id and not emit._run_context: emit.set_run_context(run_id=job_id, parent_job_id=job_id, run_type="initial") # Load session brands from DB for this source source_asset_id = state.get("source_asset_id") if source_asset_id and not state.get("session_brands"): from detect.stages.brand_resolver import build_session_dict session_brands = build_session_dict(source_asset_id) state["session_brands"] = session_brands _emit_transition(state, "extract_frames", "running") with trace_node(state, "extract_frames") as span: profile = _get_profile(state) config = profile.frame_extraction_config() frames = extract_frames(state["video_path"], config, job_id=state.get("job_id")) span.set_output({"frames_extracted": len(frames)}) _emit_transition(state, "extract_frames", "done") return {"frames": frames, "stats": PipelineStats(frames_extracted=len(frames))} def node_filter_scenes(state: DetectState) -> dict: _emit_transition(state, "filter_scenes", "running") with trace_node(state, "filter_scenes") as span: profile = _get_profile(state) config = profile.scene_filter_config() frames = state.get("frames", []) kept = scene_filter(frames, config, job_id=state.get("job_id")) span.set_output({"frames_in": len(frames), "frames_kept": len(kept)}) stats = state.get("stats", PipelineStats()) stats.frames_after_scene_filter = len(kept) _emit_transition(state, "filter_scenes", "done") return {"filtered_frames": kept, "stats": stats} def node_detect_objects(state: DetectState) -> dict: _emit_transition(state, "detect_objects", "running") with trace_node(state, "detect_objects") as span: profile = _get_profile(state) config = profile.detection_config() frames = state.get("filtered_frames", []) job_id = state.get("job_id") all_boxes = detect_objects(frames, config, inference_url=INFERENCE_URL, job_id=job_id) total_regions = sum(len(boxes) for boxes in all_boxes.values()) span.set_output({"frames": len(frames), "regions_detected": total_regions}) stats = state.get("stats", PipelineStats()) stats.regions_detected = total_regions _emit_transition(state, "detect_objects", "done") return {"boxes_by_frame": all_boxes, "stats": stats} def node_preprocess(state: DetectState) -> dict: _emit_transition(state, "preprocess", "running") with trace_node(state, "preprocess") as span: profile = _get_profile(state) frames = state.get("filtered_frames", []) boxes = state.get("boxes_by_frame", {}) job_id = state.get("job_id") # Get preprocessing config from profile overrides or defaults overrides = state.get("config_overrides", {}) prep_config = overrides.get("preprocessing", {}) do_contrast = prep_config.get("contrast", True) do_deskew = prep_config.get("deskew", False) do_binarize = prep_config.get("binarize", False) preprocessed = preprocess_regions( frames, boxes, do_contrast=do_contrast, do_deskew=do_deskew, do_binarize=do_binarize, inference_url=INFERENCE_URL, job_id=job_id, ) span.set_output({"regions_preprocessed": len(preprocessed)}) _emit_transition(state, "preprocess", "done") return {"preprocessed_crops": preprocessed} def node_run_ocr(state: DetectState) -> dict: _emit_transition(state, "run_ocr", "running") with trace_node(state, "run_ocr") as span: profile = _get_profile(state) config = profile.ocr_config() frames = state.get("filtered_frames", []) boxes = state.get("boxes_by_frame", {}) job_id = state.get("job_id") candidates = run_ocr(frames, boxes, config, inference_url=INFERENCE_URL, job_id=job_id) span.set_output({"regions_in": sum(len(b) for b in boxes.values()), "text_candidates": len(candidates)}) stats = state.get("stats", PipelineStats()) stats.regions_resolved_by_ocr = len(candidates) _emit_transition(state, "run_ocr", "done") return {"text_candidates": candidates, "stats": stats} def node_match_brands(state: DetectState) -> dict: _emit_transition(state, "match_brands", "running") with trace_node(state, "match_brands") as span: profile = _get_profile(state) resolver_config = profile.resolver_config() candidates = state.get("text_candidates", []) session_brands = state.get("session_brands", {}) job_id = state.get("job_id") source_asset_id = state.get("source_asset_id") matched, unresolved = resolve_brands( candidates, resolver_config, session_brands=session_brands, source_asset_id=source_asset_id, content_type=profile.name, job_id=job_id, ) span.set_output({"matched": len(matched), "unresolved": len(unresolved)}) _emit_transition(state, "match_brands", "done") return {"detections": matched, "unresolved_candidates": unresolved} def node_escalate_vlm(state: DetectState) -> dict: _emit_transition(state, "escalate_vlm", "running") with trace_node(state, "escalate_vlm") as span: profile = _get_profile(state) candidates = state.get("unresolved_candidates", []) job_id = state.get("job_id") vlm_matched, still_unresolved = escalate_vlm( candidates, vlm_prompt_fn=profile.vlm_prompt, inference_url=INFERENCE_URL, content_type=profile.name, source_asset_id=state.get("source_asset_id"), job_id=job_id, ) stats = state.get("stats", PipelineStats()) stats.regions_escalated_to_local_vlm = len(candidates) span.set_output({"candidates": len(candidates), "matched": len(vlm_matched), "still_unresolved": len(still_unresolved)}) existing = state.get("detections", []) _emit_transition(state, "escalate_vlm", "done") return { "detections": existing + vlm_matched, "unresolved_candidates": still_unresolved, "stats": stats, } def node_escalate_cloud(state: DetectState) -> dict: _emit_transition(state, "escalate_cloud", "running") with trace_node(state, "escalate_cloud") as span: profile = _get_profile(state) candidates = state.get("unresolved_candidates", []) job_id = state.get("job_id") stats = state.get("stats", PipelineStats()) cloud_matched = escalate_cloud( candidates, vlm_prompt_fn=profile.vlm_prompt, stats=stats, content_type=profile.name, source_asset_id=state.get("source_asset_id"), job_id=job_id, ) span.set_output({"candidates": len(candidates), "matched": len(cloud_matched), "cloud_calls": stats.cloud_llm_calls, "cost_usd": stats.estimated_cloud_cost_usd}) existing = state.get("detections", []) _emit_transition(state, "escalate_cloud", "done") return {"detections": existing + cloud_matched, "stats": stats} def node_compile_report(state: DetectState) -> dict: _emit_transition(state, "compile_report", "running") with trace_node(state, "compile_report") as span: profile = _get_profile(state) detections = state.get("detections", []) stats = state.get("stats", PipelineStats()) job_id = state.get("job_id") report = compile_report( detections=detections, stats=stats, video_source=state.get("video_path", ""), content_type=profile.name, job_id=job_id, ) span.set_output({"brands": len(report.brands), "detections": len(report.timeline)}) flush_traces() _emit_transition(state, "compile_report", "done") return {"report": report} # --- Checkpoint wrapper --- _CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1" _frames_manifest: dict[str, dict[int, str]] = {} # job_id → manifest (cached per job) def _checkpointing_node(node_name: str, node_fn): """Wrap a node function to auto-checkpoint after completion.""" stage_index = NODES.index(node_name) def wrapper(state: DetectState) -> dict: result = node_fn(state) job_id = state.get("job_id", "") if not job_id: return result from detect.checkpoint import save_checkpoint, save_frames merged = {**state, **result} # Save frames once (first checkpoint), reuse manifest after manifest = _frames_manifest.get(job_id) if manifest is None and node_name == "extract_frames": manifest = save_frames(job_id, merged.get("frames", [])) _frames_manifest[job_id] = manifest save_checkpoint(job_id, node_name, stage_index, merged, frames_manifest=manifest) return result wrapper.__name__ = node_fn.__name__ return wrapper # --- Graph construction --- NODE_FUNCTIONS = [ ("extract_frames", node_extract_frames), ("filter_scenes", node_filter_scenes), ("detect_objects", node_detect_objects), ("preprocess", node_preprocess), ("run_ocr", node_run_ocr), ("match_brands", node_match_brands), ("escalate_vlm", node_escalate_vlm), ("escalate_cloud", node_escalate_cloud), ("compile_report", node_compile_report), ] def build_graph(checkpoint: bool | None = None, start_from: str | None = None) -> StateGraph: """ Build the pipeline graph. checkpoint: enable auto-checkpointing (default: MPR_CHECKPOINT env var) start_from: skip nodes before this stage (for replay) """ do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED graph = StateGraph(DetectState) # Filter to start_from if replaying node_pairs = NODE_FUNCTIONS if start_from: start_idx = next(i for i, (name, _) in enumerate(NODE_FUNCTIONS) if name == start_from) node_pairs = NODE_FUNCTIONS[start_idx:] for name, fn in node_pairs: wrapped = _checkpointing_node(name, fn) if do_checkpoint else fn graph.add_node(name, wrapped) # Wire edges entry = node_pairs[0][0] graph.set_entry_point(entry) for i in range(len(node_pairs) - 1): graph.add_edge(node_pairs[i][0], node_pairs[i + 1][0]) graph.add_edge(node_pairs[-1][0], END) return graph def get_pipeline(checkpoint: bool | None = None): """Return a compiled, runnable pipeline.""" return build_graph(checkpoint=checkpoint).compile()