Files
mediaproc/core/detect/graph/nodes.py
2026-03-30 09:53:10 -03:00

376 lines
14 KiB
Python

"""
Pipeline node functions — one per stage.
Each node: reads state, gets config from profile dict, runs stage logic,
emits transitions, returns output dict.
"""
from __future__ import annotations
import os
from core.detect import emit
from core.detect.models import CropContext, PipelineStats
from core.detect.profile import get_profile, get_stage_config, build_vlm_prompt, aggregate_detections
from core.detect.stages.models import (
DetectionConfig,
FieldSegmentationConfig,
FrameExtractionConfig,
OCRConfig,
RegionAnalysisConfig,
ResolverConfig,
SceneFilterConfig,
)
from core.detect.state import DetectState
from core.detect.stages.frame_extractor import extract_frames
from core.detect.stages.scene_filter import scene_filter
from core.detect.stages.field_segmentation import run_field_segmentation
from core.detect.stages.edge_detector import detect_edge_regions
from core.detect.stages.yolo_detector import detect_objects
from core.detect.stages.preprocess import preprocess_regions
from core.detect.stages.ocr_stage import run_ocr
from core.detect.stages.brand_resolver import resolve_brands
from core.detect.stages.vlm_local import escalate_vlm
from core.detect.stages.vlm_cloud import escalate_cloud
from core.detect.stages.aggregator import compile_report
from core.detect.tracing import trace_node, flush as flush_traces
from .events import emit_transition
# Remote inference endpoint handed to every model-backed stage.
# Unset (None) selects local mode.
INFERENCE_URL = os.getenv("INFERENCE_URL")

# Ordered stage names; _emit passes this list to emit_transition with every
# status change.
NODES = [
    "extract_frames",
    "filter_scenes",
    "field_segmentation",
    "detect_edges",
    "detect_objects",
    "preprocess",
    "run_ocr",
    "match_brands",
    "escalate_vlm",
    "escalate_cloud",
    "compile_report",
]
def _load_profile(state: DetectState) -> dict:
    """Load the run's profile dict and apply per-stage config overrides.

    The profile is looked up by ``state["profile_name"]`` (default
    ``"soccer_broadcast"``). When ``state["config_overrides"]`` is non-empty,
    each ``{stage_name: {key: value}}`` entry is shallow-merged over that
    stage's config in ``profile["configs"]`` (override keys win). Stages
    the profile does not configure get the override values alone.

    When overrides are applied, a new profile dict with a new ``configs``
    dict is returned. The object from ``get_profile`` and the override dicts
    held in ``state`` are never mutated or aliased.
    """
    name = state.get("profile_name", "soccer_broadcast")
    profile = get_profile(name)
    overrides = state.get("config_overrides")
    if not overrides:
        return profile

    # ``or {}`` tolerates a "configs" entry that is present but None.
    merged_configs = dict(profile.get("configs") or {})
    for stage_name, stage_overrides in overrides.items():
        base = merged_configs.get(stage_name) or {}
        # Always build a fresh dict. Storing ``stage_overrides`` directly would
        # make the profile share (and expose to mutation) the dict in ``state``.
        merged_configs[stage_name] = {**base, **(stage_overrides or {})}
    return {**profile, "configs": merged_configs}
def _emit(state, node, status):
    """Forward a ``status`` change for ``node`` to ``emit_transition``.

    The module-level ``NODES`` order is always passed along, so callers only
    supply the node name and its new status.
    """
    stage_order = NODES
    emit_transition(state, node, status, stage_order)
# --- Node functions ---
def node_extract_frames(state: DetectState) -> dict:
    """First stage: extract frames from ``state["video_path"]``.

    Before extraction it:
      * sets the emit run context from ``job_id`` if none is active yet;
      * builds the session brand dictionary for ``source_asset_id`` when the
        state does not already carry ``session_brands``.

    After extraction, frames are cached on the timeline (when ``timeline_id``
    is set and no cache exists yet) and the timeline is marked ``"cached"``.

    Returns:
        State update with ``frames`` and a new ``PipelineStats``, plus
        ``session_brands`` when this node built it.
    """
    job_id = state.get("job_id", "")
    if job_id and not emit._run_context:
        emit.set_run_context(run_id=job_id, parent_job_id=job_id, run_type="initial")

    extra_updates: dict = {}
    source_asset_id = state.get("source_asset_id")
    if source_asset_id and not state.get("session_brands"):
        from core.detect.stages.brand_resolver import build_session_dict
        session_brands = build_session_dict(source_asset_id)
        # Keep the in-place write for runners that share one state dict, but
        # also return the value. Graph runtimes apply only the returned keys,
        # so the in-place write alone may never reach node_match_brands.
        state["session_brands"] = session_brands
        extra_updates["session_brands"] = session_brands

    _emit(state, "extract_frames", "running")
    with trace_node(state, "extract_frames") as span:
        profile = _load_profile(state)
        config = FrameExtractionConfig(**get_stage_config(profile, "extract_frames"))
        frames = extract_frames(state["video_path"], config, job_id=job_id)
        span.set_output({"frames_extracted": len(frames)})

    # Cache frames on the timeline for reuse across jobs and UI
    timeline_id = state.get("timeline_id")
    if timeline_id:
        from core.detect.checkpoint.frames import cache_frames, cache_exists
        if not cache_exists(timeline_id):
            cache_frames(timeline_id, frames)
            from core.detect.checkpoint.storage import update_timeline_status
            update_timeline_status(timeline_id, "cached", frame_count=len(frames))

    _emit(state, "extract_frames", "done")
    return {
        "frames": frames,
        "stats": PipelineStats(frames_extracted=len(frames)),
        **extra_updates,
    }
def node_filter_scenes(state: DetectState) -> dict:
    """Run the scene filter over the extracted frames.

    Reads ``frames``. Returns the kept frames as ``filtered_frames`` and the
    shared stats object with ``frames_after_scene_filter`` set.
    """
    _emit(state, "filter_scenes", "running")
    with trace_node(state, "filter_scenes") as span:
        stage_cfg = SceneFilterConfig(
            **get_stage_config(_load_profile(state), "filter_scenes")
        )
        all_frames = state.get("frames", [])
        survivors = scene_filter(all_frames, stage_cfg, job_id=state.get("job_id"))
        span.set_output({"frames_in": len(all_frames), "frames_kept": len(survivors)})

    pipeline_stats = state.get("stats", PipelineStats())
    pipeline_stats.frames_after_scene_filter = len(survivors)
    _emit(state, "filter_scenes", "done")
    return {"filtered_frames": survivors, "stats": pipeline_stats}
def node_field_segmentation(state: DetectState) -> dict:
    """Run field segmentation over the filtered frames.

    Returns the ``field_masks``, ``field_boundaries`` and ``field_coverage``
    entries from ``run_field_segmentation``. The trace records the mean of
    the per-frame coverage values (0 when there are none).
    """
    _emit(state, "field_segmentation", "running")
    with trace_node(state, "field_segmentation") as span:
        seg_cfg = FieldSegmentationConfig(
            **get_stage_config(_load_profile(state), "field_segmentation")
        )
        frames = state.get("filtered_frames", [])
        seg = run_field_segmentation(
            frames, seg_cfg, inference_url=INFERENCE_URL, job_id=state.get("job_id")
        )
        coverage = seg["field_coverage"]
        # max(..., 1) keeps the division safe when no coverage was reported.
        mean_coverage = sum(coverage.values()) / max(len(coverage), 1)
        span.set_output({"frames": len(frames), "avg_coverage": mean_coverage})

    _emit(state, "field_segmentation", "done")
    return {
        key: seg[key]
        for key in ("field_masks", "field_boundaries", "field_coverage")
    }
def node_detect_edges(state: DetectState) -> dict:
    """Find candidate regions with the edge detector.

    Reads ``filtered_frames`` and ``field_masks``; the masks are passed
    straight to ``detect_edge_regions``. Returns ``edge_regions_by_frame``
    and the shared stats with ``cv_regions_detected`` set to the total
    number of regions across all frames.
    """
    _emit(state, "detect_edges", "running")
    with trace_node(state, "detect_edges") as span:
        edge_cfg = RegionAnalysisConfig(
            **get_stage_config(_load_profile(state), "detect_edges")
        )
        frames = state.get("filtered_frames", [])
        by_frame = detect_edge_regions(
            frames,
            edge_cfg,
            inference_url=INFERENCE_URL,
            job_id=state.get("job_id"),
            field_masks=state.get("field_masks", {}),
        )
        region_count = sum(map(len, by_frame.values()))
        span.set_output({"frames": len(frames), "edge_regions": region_count})

    pipeline_stats = state.get("stats", PipelineStats())
    pipeline_stats.cv_regions_detected = region_count
    _emit(state, "detect_edges", "done")
    return {"edge_regions_by_frame": by_frame, "stats": pipeline_stats}
def node_detect_objects(state: DetectState) -> dict:
    """Run the object detector over the filtered frames.

    Returns the per-frame boxes as ``boxes_by_frame`` and the shared stats
    with ``regions_detected`` set to the total box count.
    """
    _emit(state, "detect_objects", "running")
    with trace_node(state, "detect_objects") as span:
        det_cfg = DetectionConfig(
            **get_stage_config(_load_profile(state), "detect_objects")
        )
        frames = state.get("filtered_frames", [])
        boxes_per_frame = detect_objects(
            frames, det_cfg, inference_url=INFERENCE_URL, job_id=state.get("job_id")
        )
        box_total = sum(map(len, boxes_per_frame.values()))
        span.set_output({"frames": len(frames), "regions_detected": box_total})

    pipeline_stats = state.get("stats", PipelineStats())
    pipeline_stats.regions_detected = box_total
    _emit(state, "detect_objects", "done")
    return {"boxes_by_frame": boxes_per_frame, "stats": pipeline_stats}
def node_preprocess(state: DetectState) -> dict:
    """Preprocess detected regions before OCR.

    The "preprocess" stage config is a plain dict of flags:
    ``contrast`` (default True), ``deskew`` (default False) and
    ``binarize`` (default False). Returns ``preprocessed_crops``.
    """
    _emit(state, "preprocess", "running")
    with trace_node(state, "preprocess") as span:
        prep_config = get_stage_config(_load_profile(state), "preprocess")
        frames = state.get("filtered_frames", [])
        boxes = state.get("boxes_by_frame", {})
        job_id = state.get("job_id")
        toggles = {
            "do_contrast": prep_config.get("contrast", True),
            "do_deskew": prep_config.get("deskew", False),
            "do_binarize": prep_config.get("binarize", False),
        }
        crops = preprocess_regions(
            frames, boxes, **toggles, inference_url=INFERENCE_URL, job_id=job_id
        )
        span.set_output({"regions_preprocessed": len(crops)})

    _emit(state, "preprocess", "done")
    return {"preprocessed_crops": crops}
def node_run_ocr(state: DetectState) -> dict:
    """Run OCR over the detected boxes and collect text candidates.

    Returns ``text_candidates`` and the shared stats with
    ``regions_resolved_by_ocr`` set to the number of candidates.
    """
    _emit(state, "run_ocr", "running")
    with trace_node(state, "run_ocr") as span:
        ocr_cfg = OCRConfig(**get_stage_config(_load_profile(state), "run_ocr"))
        frames = state.get("filtered_frames", [])
        boxes = state.get("boxes_by_frame", {})
        # NOTE(review): OCR is fed the raw frames + boxes_by_frame, not the
        # preprocessed_crops produced by node_preprocess. Confirm that is
        # intended.
        text_hits = run_ocr(
            frames, boxes, ocr_cfg, inference_url=INFERENCE_URL, job_id=state.get("job_id")
        )
        span.set_output({
            "regions_in": sum(map(len, boxes.values())),
            "text_candidates": len(text_hits),
        })

    pipeline_stats = state.get("stats", PipelineStats())
    # NOTE(review): this counts OCR candidates, some of which match_brands may
    # later leave unresolved. Confirm this is the intended meaning.
    pipeline_stats.regions_resolved_by_ocr = len(text_hits)
    _emit(state, "run_ocr", "done")
    return {"text_candidates": text_hits, "stats": pipeline_stats}
def node_match_brands(state: DetectState) -> dict:
    """Resolve OCR text candidates to brands.

    Uses ``session_brands`` and ``source_asset_id`` from the state. Returns
    the matches as ``detections`` and the rest as ``unresolved_candidates``
    for the escalation stages.
    """
    _emit(state, "match_brands", "running")
    with trace_node(state, "match_brands") as span:
        profile = _load_profile(state)
        resolver_cfg = ResolverConfig(**get_stage_config(profile, "match_brands"))
        hits, misses = resolve_brands(
            state.get("text_candidates", []),
            resolver_cfg,
            session_brands=state.get("session_brands", {}),
            source_asset_id=state.get("source_asset_id"),
            content_type=profile["name"],
            job_id=state.get("job_id"),
        )
        span.set_output({"matched": len(hits), "unresolved": len(misses)})

    _emit(state, "match_brands", "done")
    return {"detections": hits, "unresolved_candidates": misses}
def node_escalate_vlm(state: DetectState) -> dict:
    """Escalate unresolved candidates to the local VLM.

    Prompts are built from the stage's ``vlm_prompt_template`` (with a
    generic default). New matches are appended to ``detections``; whatever
    is still unresolved replaces ``unresolved_candidates``. The final
    status is reported as "skipped" when ``SKIP_VLM=1``.
    """
    _emit(state, "escalate_vlm", "running")
    with trace_node(state, "escalate_vlm") as span:
        profile = _load_profile(state)
        template = get_stage_config(profile, "escalate_vlm").get(
            "vlm_prompt_template", "Identify the brand in this image."
        )
        pending = state.get("unresolved_candidates", [])

        def prompt_for(ctx):
            return build_vlm_prompt(ctx, template)

        newly_matched, leftover = escalate_vlm(
            pending,
            vlm_prompt_fn=prompt_for,
            inference_url=INFERENCE_URL,
            content_type=profile["name"],
            source_asset_id=state.get("source_asset_id"),
            job_id=state.get("job_id"),
        )
        pipeline_stats = state.get("stats", PipelineStats())
        pipeline_stats.regions_escalated_to_local_vlm = len(pending)
        span.set_output({
            "candidates": len(pending),
            "matched": len(newly_matched),
            "still_unresolved": len(leftover),
        })

    was_skipped = os.environ.get("SKIP_VLM", "").strip() == "1"
    _emit(state, "escalate_vlm", "skipped" if was_skipped else "done")
    return {
        "detections": state.get("detections", []) + newly_matched,
        "unresolved_candidates": leftover,
        "stats": pipeline_stats,
    }
def node_escalate_cloud(state: DetectState) -> dict:
    """Escalate candidates still unresolved after the local VLM to the cloud.

    The shared stats object is handed to ``escalate_cloud``, and its
    ``cloud_llm_calls`` / ``estimated_cloud_cost_usd`` are traced afterwards.
    New matches are appended to ``detections``. The final status is reported
    as "skipped" when ``SKIP_CLOUD=1``.
    """
    _emit(state, "escalate_cloud", "running")
    with trace_node(state, "escalate_cloud") as span:
        profile = _load_profile(state)
        # NOTE(review): the prompt template comes from the "escalate_vlm"
        # stage config, so both escalation stages share one template. Confirm
        # that is intended.
        template = get_stage_config(profile, "escalate_vlm").get(
            "vlm_prompt_template", "Identify the brand in this image."
        )
        pending = state.get("unresolved_candidates", [])
        pipeline_stats = state.get("stats", PipelineStats())

        def prompt_for(ctx):
            return build_vlm_prompt(ctx, template)

        newly_matched = escalate_cloud(
            pending,
            vlm_prompt_fn=prompt_for,
            stats=pipeline_stats,
            content_type=profile["name"],
            source_asset_id=state.get("source_asset_id"),
            job_id=state.get("job_id"),
        )
        span.set_output({
            "candidates": len(pending),
            "matched": len(newly_matched),
            "cloud_calls": pipeline_stats.cloud_llm_calls,
            "cost_usd": pipeline_stats.estimated_cloud_cost_usd,
        })

    was_skipped = os.environ.get("SKIP_CLOUD", "").strip() == "1"
    _emit(state, "escalate_cloud", "skipped" if was_skipped else "done")
    return {
        "detections": state.get("detections", []) + newly_matched,
        "stats": pipeline_stats,
    }
def node_compile_report(state: DetectState) -> dict:
    """Last stage: compile the report from all detections, then flush traces.

    ``flush_traces()`` runs in a ``finally`` after the ``compile_report``
    span has exited. Any data the span records on exit is therefore
    included, and traces are still flushed when report compilation raises.
    The exception itself still propagates.

    Returns:
        State update with the compiled ``report``.
    """
    _emit(state, "compile_report", "running")
    try:
        with trace_node(state, "compile_report") as span:
            profile = _load_profile(state)
            report = compile_report(
                detections=state.get("detections", []),
                stats=state.get("stats", PipelineStats()),
                video_source=state.get("video_path", ""),
                content_type=profile["name"],
                job_id=state.get("job_id"),
            )
            span.set_output({"brands": len(report.brands), "detections": len(report.timeline)})
    finally:
        # Outside the ``with`` so the span is closed before flushing; in
        # ``finally`` so a failed final stage still gets its traces exported.
        flush_traces()
    _emit(state, "compile_report", "done")
    return {"report": report}
# Ordered (stage name, node function) pairs, in the same order and with the
# same names as NODES. Written as a mapping for readability; ``.items()``
# keeps insertion order, so the resulting list of tuples is unchanged.
NODE_FUNCTIONS = list({
    "extract_frames": node_extract_frames,
    "filter_scenes": node_filter_scenes,
    "field_segmentation": node_field_segmentation,
    "detect_edges": node_detect_edges,
    "detect_objects": node_detect_objects,
    "preprocess": node_preprocess,
    "run_ocr": node_run_ocr,
    "match_brands": node_match_brands,
    "escalate_vlm": node_escalate_vlm,
    "escalate_cloud": node_escalate_cloud,
    "compile_report": node_compile_report,
}.items())