"""
Pipeline node functions — one per stage.

Each node: reads state, runs stage logic, emits transitions, returns output dict.
"""

from __future__ import annotations

import os

from detect import emit
from detect.models import PipelineStats
from detect.profiles import SoccerBroadcastProfile
from detect.state import DetectState
from detect.stages.frame_extractor import extract_frames
from detect.stages.scene_filter import scene_filter
from detect.stages.edge_detector import detect_edge_regions
from detect.stages.yolo_detector import detect_objects
from detect.stages.preprocess import preprocess_regions
from detect.stages.ocr_stage import run_ocr
from detect.stages.brand_resolver import resolve_brands
from detect.stages.vlm_local import escalate_vlm
from detect.stages.vlm_cloud import escalate_cloud
from detect.stages.aggregator import compile_report
from detect.tracing import trace_node, flush as flush_traces

from .events import emit_transition

INFERENCE_URL = os.environ.get("INFERENCE_URL")  # None = local mode

# Stage names, passed to emit_transition with every status update.
NODES = [
    "extract_frames",
    "filter_scenes",
    "detect_edges",
    "detect_objects",
    "preprocess",
    "run_ocr",
    "match_brands",
    "escalate_vlm",
    "escalate_cloud",
    "compile_report",
]


def _get_profile(state: DetectState):
    """Build the content profile selected by ``state["profile_name"]``.

    Defaults to ``"soccer_broadcast"`` when no profile name is set. When
    ``state["config_overrides"]`` is truthy, the profile is wrapped in an
    ``OverrideProfile`` constructed from the base profile and the overrides.

    Raises:
        ValueError: if the profile name is not ``"soccer_broadcast"``.
    """
    name = state.get("profile_name", "soccer_broadcast")
    if name == "soccer_broadcast":
        profile = SoccerBroadcastProfile()
    else:
        raise ValueError(f"Unknown profile: {name}")

    overrides = state.get("config_overrides")
    if overrides:
        # Imported lazily, only when overrides are actually present.
        from detect.checkpoint.replay import OverrideProfile

        profile = OverrideProfile(profile, overrides)

    return profile


def _emit(state, node, status):
    """Forward a status transition for ``node`` together with the NODES list."""
    emit_transition(state, node, status, NODES)


# --- Node functions ---


def node_extract_frames(state: DetectState) -> dict:
    """Extract frames from ``state["video_path"]``.

    Before running the stage this node also:
      * sets the emit run context from ``job_id`` when a job id is present and
        no run context is set yet;
      * builds ``state["session_brands"]`` from ``source_asset_id`` when an
        asset id is present and no session brands are set yet.

    Returns:
        ``{"frames": ..., "stats": PipelineStats(frames_extracted=...)}``.
    """
    job_id = state.get("job_id", "")
    if job_id and not emit._run_context:
        emit.set_run_context(run_id=job_id, parent_job_id=job_id, run_type="initial")

    source_asset_id = state.get("source_asset_id")
    if source_asset_id and not state.get("session_brands"):
        from detect.stages.brand_resolver import build_session_dict

        session_brands = build_session_dict(source_asset_id)
        # NOTE(review): this mutates the incoming state in place and the key is
        # not part of the returned dict — confirm the graph runtime persists
        # in-place mutations, otherwise later nodes will not see it.
        state["session_brands"] = session_brands

    _emit(state, "extract_frames", "running")

    with trace_node(state, "extract_frames") as span:
        profile = _get_profile(state)
        config = profile.frame_extraction_config()
        frames = extract_frames(state["video_path"], config, job_id=state.get("job_id"))
        span.set_output({"frames_extracted": len(frames)})

    _emit(state, "extract_frames", "done")
    return {"frames": frames, "stats": PipelineStats(frames_extracted=len(frames))}


def node_filter_scenes(state: DetectState) -> dict:
    """Run the scene filter over ``state["frames"]``.

    Records the number of kept frames in ``stats.frames_after_scene_filter``.

    Returns:
        ``{"filtered_frames": ..., "stats": ...}``.
    """
    _emit(state, "filter_scenes", "running")

    with trace_node(state, "filter_scenes") as span:
        profile = _get_profile(state)
        config = profile.scene_filter_config()
        frames = state.get("frames", [])
        kept = scene_filter(frames, config, job_id=state.get("job_id"))
        span.set_output({"frames_in": len(frames), "frames_kept": len(kept)})

    stats = state.get("stats", PipelineStats())
    stats.frames_after_scene_filter = len(kept)

    _emit(state, "filter_scenes", "done")
    return {"filtered_frames": kept, "stats": stats}


def node_detect_edges(state: DetectState) -> dict:
    """Run edge-region detection over ``state["filtered_frames"]``.

    The total region count (summed over the values of the returned mapping)
    is recorded in ``stats.cv_regions_detected``.

    Returns:
        ``{"edge_regions_by_frame": ..., "stats": ...}``.
    """
    _emit(state, "detect_edges", "running")

    with trace_node(state, "detect_edges") as span:
        profile = _get_profile(state)
        config = profile.region_analysis_config()
        frames = state.get("filtered_frames", [])
        job_id = state.get("job_id")

        regions = detect_edge_regions(
            frames,
            config,
            inference_url=INFERENCE_URL,
            job_id=job_id,
        )

        total = sum(len(r) for r in regions.values())
        span.set_output({"frames": len(frames), "edge_regions": total})

    stats = state.get("stats", PipelineStats())
    stats.cv_regions_detected = total

    _emit(state, "detect_edges", "done")
    return {"edge_regions_by_frame": regions, "stats": stats}


def node_detect_objects(state: DetectState) -> dict:
    """Run object detection over ``state["filtered_frames"]``.

    The total box count (summed over the values of the returned mapping) is
    recorded in ``stats.regions_detected``.

    Returns:
        ``{"boxes_by_frame": ..., "stats": ...}``.
    """
    _emit(state, "detect_objects", "running")

    with trace_node(state, "detect_objects") as span:
        profile = _get_profile(state)
        config = profile.detection_config()
        frames = state.get("filtered_frames", [])
        job_id = state.get("job_id")

        all_boxes = detect_objects(frames, config, inference_url=INFERENCE_URL, job_id=job_id)
        total_regions = sum(len(boxes) for boxes in all_boxes.values())
        span.set_output({"frames": len(frames), "regions_detected": total_regions})

    stats = state.get("stats", PipelineStats())
    stats.regions_detected = total_regions

    _emit(state, "detect_objects", "done")
    return {"boxes_by_frame": all_boxes, "stats": stats}


def node_preprocess(state: DetectState) -> dict:
    """Preprocess the detected regions of ``state["filtered_frames"]``.

    The preprocessing switches come from
    ``state["config_overrides"]["preprocessing"]`` and default to
    ``contrast=True``, ``deskew=False``, ``binarize=False``.

    Returns:
        ``{"preprocessed_crops": ...}``.

    Raises:
        ValueError: if ``state["profile_name"]`` names an unknown profile.
    """
    _emit(state, "preprocess", "running")

    with trace_node(state, "preprocess") as span:
        # The profile object is not used by this stage; the call is kept so an
        # unknown profile name still fails here, as in every other node.
        _get_profile(state)

        frames = state.get("filtered_frames", [])
        boxes = state.get("boxes_by_frame", {})
        job_id = state.get("job_id")

        # ``dict.get(key, default)`` only applies the default when the key is
        # missing. ``_get_profile`` treats a falsy ``config_overrides`` as
        # "no overrides", so do the same here instead of calling ``.get`` on
        # an explicit ``None``.
        overrides = state.get("config_overrides") or {}
        prep_config = overrides.get("preprocessing") or {}
        do_contrast = prep_config.get("contrast", True)
        do_deskew = prep_config.get("deskew", False)
        do_binarize = prep_config.get("binarize", False)

        preprocessed = preprocess_regions(
            frames,
            boxes,
            do_contrast=do_contrast,
            do_deskew=do_deskew,
            do_binarize=do_binarize,
            inference_url=INFERENCE_URL,
            job_id=job_id,
        )

        span.set_output({"regions_preprocessed": len(preprocessed)})

    _emit(state, "preprocess", "done")
    return {"preprocessed_crops": preprocessed}


def node_run_ocr(state: DetectState) -> dict:
    """Run OCR over the boxes of ``state["filtered_frames"]``.

    The number of text candidates is recorded in
    ``stats.regions_resolved_by_ocr``.

    Returns:
        ``{"text_candidates": ..., "stats": ...}``.
    """
    _emit(state, "run_ocr", "running")

    with trace_node(state, "run_ocr") as span:
        profile = _get_profile(state)
        config = profile.ocr_config()
        frames = state.get("filtered_frames", [])
        boxes = state.get("boxes_by_frame", {})
        job_id = state.get("job_id")

        candidates = run_ocr(frames, boxes, config, inference_url=INFERENCE_URL, job_id=job_id)
        span.set_output({
            "regions_in": sum(len(b) for b in boxes.values()),
            "text_candidates": len(candidates),
        })

    stats = state.get("stats", PipelineStats())
    stats.regions_resolved_by_ocr = len(candidates)

    _emit(state, "run_ocr", "done")
    return {"text_candidates": candidates, "stats": stats}


def node_match_brands(state: DetectState) -> dict:
    """Resolve ``state["text_candidates"]`` against known brands.

    Returns:
        ``{"detections": matched, "unresolved_candidates": unresolved}``.
    """
    _emit(state, "match_brands", "running")

    with trace_node(state, "match_brands") as span:
        profile = _get_profile(state)
        resolver_config = profile.resolver_config()
        candidates = state.get("text_candidates", [])
        session_brands = state.get("session_brands", {})
        job_id = state.get("job_id")
        source_asset_id = state.get("source_asset_id")

        matched, unresolved = resolve_brands(
            candidates,
            resolver_config,
            session_brands=session_brands,
            source_asset_id=source_asset_id,
            content_type=profile.name,
            job_id=job_id,
        )

        span.set_output({"matched": len(matched), "unresolved": len(unresolved)})

    _emit(state, "match_brands", "done")
    return {"detections": matched, "unresolved_candidates": unresolved}


def node_escalate_vlm(state: DetectState) -> dict:
    """Escalate ``state["unresolved_candidates"]`` to the local VLM stage.

    The number of candidates handed to the stage is recorded in
    ``stats.regions_escalated_to_local_vlm``. The stage function is always
    called; the ``SKIP_VLM`` environment variable (value ``"1"`` after
    stripping) only changes the final emitted status to ``"skipped"``.

    Returns:
        ``{"detections": existing + new matches,
        "unresolved_candidates": ..., "stats": ...}``.
    """
    _emit(state, "escalate_vlm", "running")

    with trace_node(state, "escalate_vlm") as span:
        profile = _get_profile(state)
        candidates = state.get("unresolved_candidates", [])
        job_id = state.get("job_id")

        vlm_matched, still_unresolved = escalate_vlm(
            candidates,
            vlm_prompt_fn=profile.vlm_prompt,
            inference_url=INFERENCE_URL,
            content_type=profile.name,
            source_asset_id=state.get("source_asset_id"),
            job_id=job_id,
        )

        stats = state.get("stats", PipelineStats())
        stats.regions_escalated_to_local_vlm = len(candidates)

        span.set_output({
            "candidates": len(candidates),
            "matched": len(vlm_matched),
            "still_unresolved": len(still_unresolved),
        })

    existing = state.get("detections", [])

    vlm_skipped = os.environ.get("SKIP_VLM", "").strip() == "1"
    _emit(state, "escalate_vlm", "skipped" if vlm_skipped else "done")
    return {
        "detections": existing + vlm_matched,
        "unresolved_candidates": still_unresolved,
        "stats": stats,
    }


def node_escalate_cloud(state: DetectState) -> dict:
    """Escalate ``state["unresolved_candidates"]`` to the cloud VLM stage.

    The ``stats`` object is passed into the stage; its ``cloud_llm_calls`` and
    ``estimated_cloud_cost_usd`` are reported on the span afterwards. The
    stage function is always called; the ``SKIP_CLOUD`` environment variable
    (value ``"1"`` after stripping) only changes the final emitted status to
    ``"skipped"``.

    Returns:
        ``{"detections": existing + new matches, "stats": ...}``.
    """
    _emit(state, "escalate_cloud", "running")

    with trace_node(state, "escalate_cloud") as span:
        profile = _get_profile(state)
        candidates = state.get("unresolved_candidates", [])
        job_id = state.get("job_id")

        stats = state.get("stats", PipelineStats())

        cloud_matched = escalate_cloud(
            candidates,
            vlm_prompt_fn=profile.vlm_prompt,
            stats=stats,
            content_type=profile.name,
            source_asset_id=state.get("source_asset_id"),
            job_id=job_id,
        )

        span.set_output({
            "candidates": len(candidates),
            "matched": len(cloud_matched),
            "cloud_calls": stats.cloud_llm_calls,
            "cost_usd": stats.estimated_cloud_cost_usd,
        })

    existing = state.get("detections", [])

    cloud_skipped = os.environ.get("SKIP_CLOUD", "").strip() == "1"
    _emit(state, "escalate_cloud", "skipped" if cloud_skipped else "done")
    return {"detections": existing + cloud_matched, "stats": stats}


def node_compile_report(state: DetectState) -> dict:
    """Compile the final report from ``state["detections"]`` and ``stats``.

    Flushes traces after the report has been built.

    Returns:
        ``{"report": ...}``.
    """
    _emit(state, "compile_report", "running")

    with trace_node(state, "compile_report") as span:
        profile = _get_profile(state)
        detections = state.get("detections", [])
        stats = state.get("stats", PipelineStats())
        job_id = state.get("job_id")

        report = compile_report(
            detections=detections,
            stats=stats,
            video_source=state.get("video_path", ""),
            content_type=profile.name,
            job_id=job_id,
        )

        span.set_output({"brands": len(report.brands), "detections": len(report.timeline)})

    # Flush once the compile_report span block has exited.
    flush_traces()

    _emit(state, "compile_report", "done")
    return {"report": report}


# (stage name, node function) pairs, listed in the same order as NODES.
NODE_FUNCTIONS = [
    ("extract_frames", node_extract_frames),
    ("filter_scenes", node_filter_scenes),
    ("detect_edges", node_detect_edges),
    ("detect_objects", node_detect_objects),
    ("preprocess", node_preprocess),
    ("run_ocr", node_run_ocr),
    ("match_brands", node_match_brands),
    ("escalate_vlm", node_escalate_vlm),
    ("escalate_cloud", node_escalate_cloud),
    ("compile_report", node_compile_report),
]