""" Pipeline node functions — one per stage. Each node: reads state, gets config from profile dict, runs stage logic, emits transitions, returns output dict. """ from __future__ import annotations import os from core.detect import emit from core.detect.models import CropContext, PipelineStats from core.detect.profile import get_profile, get_stage_config, build_vlm_prompt, aggregate_detections from core.detect.stages.models import ( DetectionConfig, FieldSegmentationConfig, FrameExtractionConfig, OCRConfig, RegionAnalysisConfig, ResolverConfig, SceneFilterConfig, ) from core.detect.state import DetectState from core.detect.stages.frame_extractor import extract_frames from core.detect.stages.scene_filter import scene_filter from core.detect.stages.field_segmentation import run_field_segmentation from core.detect.stages.edge_detector import detect_edge_regions from core.detect.stages.yolo_detector import detect_objects from core.detect.stages.preprocess import preprocess_regions from core.detect.stages.ocr_stage import run_ocr from core.detect.stages.brand_resolver import resolve_brands from core.detect.stages.vlm_local import escalate_vlm from core.detect.stages.vlm_cloud import escalate_cloud from core.detect.stages.aggregator import compile_report from core.detect.tracing import trace_node, flush as flush_traces from .events import emit_transition INFERENCE_URL = os.environ.get("INFERENCE_URL") # None = local mode NODES = [ "extract_frames", "filter_scenes", "field_segmentation", "detect_edges", "detect_objects", "preprocess", "run_ocr", "match_brands", "escalate_vlm", "escalate_cloud", "compile_report", ] def _load_profile(state: DetectState) -> dict: """Load profile dict, apply config overrides if present.""" name = state.get("profile_name", "soccer_broadcast") profile = get_profile(name) overrides = state.get("config_overrides") if overrides: # Merge overrides into a copy of the profile configs merged_configs = dict(profile.get("configs", {})) for stage_name, stage_overrides in overrides.items(): if stage_name in merged_configs: merged_configs[stage_name] = {**merged_configs[stage_name], **stage_overrides} else: merged_configs[stage_name] = stage_overrides profile = {**profile, "configs": merged_configs} return profile def _emit(state, node, status): emit_transition(state, node, status, NODES) # --- Node functions --- def node_extract_frames(state: DetectState) -> dict: job_id = state.get("job_id", "") if job_id and not emit._run_context: emit.set_run_context(run_id=job_id, parent_job_id=job_id, run_type="initial") source_asset_id = state.get("source_asset_id") if source_asset_id and not state.get("session_brands"): from core.detect.stages.brand_resolver import build_session_dict session_brands = build_session_dict(source_asset_id) state["session_brands"] = session_brands _emit(state, "extract_frames", "running") with trace_node(state, "extract_frames") as span: profile = _load_profile(state) config = FrameExtractionConfig(**get_stage_config(profile, "extract_frames")) frames = extract_frames(state["video_path"], config, job_id=job_id) span.set_output({"frames_extracted": len(frames)}) # Cache frames on the timeline for reuse across jobs and UI timeline_id = state.get("timeline_id") if timeline_id: from core.detect.checkpoint.frames import cache_frames, cache_exists if not cache_exists(timeline_id): cache_frames(timeline_id, frames) from core.detect.checkpoint.storage import update_timeline_status update_timeline_status(timeline_id, "cached", frame_count=len(frames)) _emit(state, "extract_frames", "done") return {"frames": frames, "stats": PipelineStats(frames_extracted=len(frames))} def node_filter_scenes(state: DetectState) -> dict: _emit(state, "filter_scenes", "running") with trace_node(state, "filter_scenes") as span: profile = _load_profile(state) config = SceneFilterConfig(**get_stage_config(profile, "filter_scenes")) frames = state.get("frames", []) kept = scene_filter(frames, config, job_id=state.get("job_id")) span.set_output({"frames_in": len(frames), "frames_kept": len(kept)}) stats = state.get("stats", PipelineStats()) stats.frames_after_scene_filter = len(kept) _emit(state, "filter_scenes", "done") return {"filtered_frames": kept, "stats": stats} def node_field_segmentation(state: DetectState) -> dict: _emit(state, "field_segmentation", "running") with trace_node(state, "field_segmentation") as span: profile = _load_profile(state) config = FieldSegmentationConfig(**get_stage_config(profile, "field_segmentation")) frames = state.get("filtered_frames", []) job_id = state.get("job_id") result = run_field_segmentation(frames, config, inference_url=INFERENCE_URL, job_id=job_id) span.set_output({ "frames": len(frames), "avg_coverage": sum(result["field_coverage"].values()) / max(len(result["field_coverage"]), 1), }) _emit(state, "field_segmentation", "done") return { "field_masks": result["field_masks"], "field_boundaries": result["field_boundaries"], "field_coverage": result["field_coverage"], } def node_detect_edges(state: DetectState) -> dict: _emit(state, "detect_edges", "running") with trace_node(state, "detect_edges") as span: profile = _load_profile(state) config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges")) frames = state.get("filtered_frames", []) field_masks = state.get("field_masks", {}) job_id = state.get("job_id") regions = detect_edge_regions( frames, config, inference_url=INFERENCE_URL, job_id=job_id, field_masks=field_masks, ) total = sum(len(r) for r in regions.values()) span.set_output({"frames": len(frames), "edge_regions": total}) stats = state.get("stats", PipelineStats()) stats.cv_regions_detected = total _emit(state, "detect_edges", "done") return {"edge_regions_by_frame": regions, "stats": stats} def node_detect_objects(state: DetectState) -> dict: _emit(state, "detect_objects", "running") with trace_node(state, "detect_objects") as span: profile = _load_profile(state) config = DetectionConfig(**get_stage_config(profile, "detect_objects")) frames = state.get("filtered_frames", []) job_id = state.get("job_id") all_boxes = detect_objects(frames, config, inference_url=INFERENCE_URL, job_id=job_id) total_regions = sum(len(boxes) for boxes in all_boxes.values()) span.set_output({"frames": len(frames), "regions_detected": total_regions}) stats = state.get("stats", PipelineStats()) stats.regions_detected = total_regions _emit(state, "detect_objects", "done") return {"boxes_by_frame": all_boxes, "stats": stats} def node_preprocess(state: DetectState) -> dict: _emit(state, "preprocess", "running") with trace_node(state, "preprocess") as span: profile = _load_profile(state) prep_config = get_stage_config(profile, "preprocess") frames = state.get("filtered_frames", []) boxes = state.get("boxes_by_frame", {}) job_id = state.get("job_id") do_contrast = prep_config.get("contrast", True) do_deskew = prep_config.get("deskew", False) do_binarize = prep_config.get("binarize", False) preprocessed = preprocess_regions( frames, boxes, do_contrast=do_contrast, do_deskew=do_deskew, do_binarize=do_binarize, inference_url=INFERENCE_URL, job_id=job_id, ) span.set_output({"regions_preprocessed": len(preprocessed)}) _emit(state, "preprocess", "done") return {"preprocessed_crops": preprocessed} def node_run_ocr(state: DetectState) -> dict: _emit(state, "run_ocr", "running") with trace_node(state, "run_ocr") as span: profile = _load_profile(state) config = OCRConfig(**get_stage_config(profile, "run_ocr")) frames = state.get("filtered_frames", []) boxes = state.get("boxes_by_frame", {}) job_id = state.get("job_id") candidates = run_ocr(frames, boxes, config, inference_url=INFERENCE_URL, job_id=job_id) span.set_output({"regions_in": sum(len(b) for b in boxes.values()), "text_candidates": len(candidates)}) stats = state.get("stats", PipelineStats()) stats.regions_resolved_by_ocr = len(candidates) _emit(state, "run_ocr", "done") return {"text_candidates": candidates, "stats": stats} def node_match_brands(state: DetectState) -> dict: _emit(state, "match_brands", "running") with trace_node(state, "match_brands") as span: profile = _load_profile(state) config = ResolverConfig(**get_stage_config(profile, "match_brands")) candidates = state.get("text_candidates", []) session_brands = state.get("session_brands", {}) job_id = state.get("job_id") source_asset_id = state.get("source_asset_id") matched, unresolved = resolve_brands( candidates, config, session_brands=session_brands, source_asset_id=source_asset_id, content_type=profile["name"], job_id=job_id, ) span.set_output({"matched": len(matched), "unresolved": len(unresolved)}) _emit(state, "match_brands", "done") return {"detections": matched, "unresolved_candidates": unresolved} def node_escalate_vlm(state: DetectState) -> dict: _emit(state, "escalate_vlm", "running") with trace_node(state, "escalate_vlm") as span: profile = _load_profile(state) vlm_config = get_stage_config(profile, "escalate_vlm") vlm_template = vlm_config.get("vlm_prompt_template", "Identify the brand in this image.") candidates = state.get("unresolved_candidates", []) job_id = state.get("job_id") vlm_prompt_fn = lambda ctx: build_vlm_prompt(ctx, vlm_template) vlm_matched, still_unresolved = escalate_vlm( candidates, vlm_prompt_fn=vlm_prompt_fn, inference_url=INFERENCE_URL, content_type=profile["name"], source_asset_id=state.get("source_asset_id"), job_id=job_id, ) stats = state.get("stats", PipelineStats()) stats.regions_escalated_to_local_vlm = len(candidates) span.set_output({"candidates": len(candidates), "matched": len(vlm_matched), "still_unresolved": len(still_unresolved)}) existing = state.get("detections", []) vlm_skipped = os.environ.get("SKIP_VLM", "").strip() == "1" _emit(state, "escalate_vlm", "skipped" if vlm_skipped else "done") return { "detections": existing + vlm_matched, "unresolved_candidates": still_unresolved, "stats": stats, } def node_escalate_cloud(state: DetectState) -> dict: _emit(state, "escalate_cloud", "running") with trace_node(state, "escalate_cloud") as span: profile = _load_profile(state) vlm_config = get_stage_config(profile, "escalate_vlm") vlm_template = vlm_config.get("vlm_prompt_template", "Identify the brand in this image.") candidates = state.get("unresolved_candidates", []) job_id = state.get("job_id") stats = state.get("stats", PipelineStats()) vlm_prompt_fn = lambda ctx: build_vlm_prompt(ctx, vlm_template) cloud_matched = escalate_cloud( candidates, vlm_prompt_fn=vlm_prompt_fn, stats=stats, content_type=profile["name"], source_asset_id=state.get("source_asset_id"), job_id=job_id, ) span.set_output({"candidates": len(candidates), "matched": len(cloud_matched), "cloud_calls": stats.cloud_llm_calls, "cost_usd": stats.estimated_cloud_cost_usd}) existing = state.get("detections", []) cloud_skipped = os.environ.get("SKIP_CLOUD", "").strip() == "1" _emit(state, "escalate_cloud", "skipped" if cloud_skipped else "done") return {"detections": existing + cloud_matched, "stats": stats} def node_compile_report(state: DetectState) -> dict: _emit(state, "compile_report", "running") with trace_node(state, "compile_report") as span: profile = _load_profile(state) detections = state.get("detections", []) stats = state.get("stats", PipelineStats()) job_id = state.get("job_id") report = compile_report( detections=detections, stats=stats, video_source=state.get("video_path", ""), content_type=profile["name"], job_id=job_id, ) span.set_output({"brands": len(report.brands), "detections": len(report.timeline)}) flush_traces() _emit(state, "compile_report", "done") return {"report": report} NODE_FUNCTIONS = [ ("extract_frames", node_extract_frames), ("filter_scenes", node_filter_scenes), ("field_segmentation", node_field_segmentation), ("detect_edges", node_detect_edges), ("detect_objects", node_detect_objects), ("preprocess", node_preprocess), ("run_ocr", node_run_ocr), ("match_brands", node_match_brands), ("escalate_vlm", node_escalate_vlm), ("escalate_cloud", node_escalate_cloud), ("compile_report", node_compile_report), ]