367 lines
13 KiB
Python
367 lines
13 KiB
Python
"""
|
|
Pipeline node functions — one per stage.
|
|
|
|
Each node: reads state, gets config from profile dict, runs stage logic,
|
|
emits transitions, returns output dict.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
|
|
from core.detect import emit
|
|
from core.detect.models import CropContext, PipelineStats
|
|
from core.detect.profile import get_profile, get_stage_config, build_vlm_prompt, aggregate_detections
|
|
from core.detect.stages.models import (
|
|
DetectionConfig,
|
|
FieldSegmentationConfig,
|
|
FrameExtractionConfig,
|
|
OCRConfig,
|
|
RegionAnalysisConfig,
|
|
ResolverConfig,
|
|
SceneFilterConfig,
|
|
)
|
|
from core.detect.state import DetectState
|
|
from core.detect.stages.frame_extractor import extract_frames
|
|
from core.detect.stages.scene_filter import scene_filter
|
|
from core.detect.stages.field_segmentation import run_field_segmentation
|
|
from core.detect.stages.edge_detector import detect_edge_regions
|
|
from core.detect.stages.yolo_detector import detect_objects
|
|
from core.detect.stages.preprocess import preprocess_regions
|
|
from core.detect.stages.ocr_stage import run_ocr
|
|
from core.detect.stages.brand_resolver import resolve_brands
|
|
from core.detect.stages.vlm_local import escalate_vlm
|
|
from core.detect.stages.vlm_cloud import escalate_cloud
|
|
from core.detect.stages.aggregator import compile_report
|
|
from core.detect.tracing import trace_node, flush as flush_traces
|
|
|
|
from .events import emit_transition
|
|
|
|
# Endpoint handed to the model-backed stages as ``inference_url``.
# ``None`` (environment variable unset) means local mode.
INFERENCE_URL = os.environ.get("INFERENCE_URL")

# Stage names in pipeline order, kept in step with NODE_FUNCTIONS below.
# The whole list is passed to emit_transition (through _emit) on every status change.
NODES = [
    "extract_frames", "filter_scenes", "field_segmentation",
    "detect_edges", "detect_objects", "preprocess",
    "run_ocr", "match_brands",
    "escalate_vlm", "escalate_cloud",
    "compile_report",
]
|
|
|
|
|
|
def _load_profile(state: DetectState) -> dict:
    """Load the profile named in ``state`` and apply per-stage config overrides.

    The profile name comes from ``state["profile_name"]`` (default
    ``"soccer_broadcast"``).

    Override handling:
      * ``state["config_overrides"]`` is read as ``{stage_name: {key: value}}``.
      * When it is non-empty, each listed stage's config is shallow-merged with
        its overrides, and override values win.
      * Stages absent from the profile receive a copy of their overrides.

    Returns:
        The profile from ``get_profile`` unchanged when there are no overrides.
        Otherwise a new profile dict with a new ``"configs"`` mapping. The
        profile returned by ``get_profile`` is never mutated, and the caller's
        override dicts are never stored by reference.
    """
    name = state.get("profile_name", "soccer_broadcast")
    profile = get_profile(name)

    overrides = state.get("config_overrides")
    if not overrides:
        return profile

    # Copy the stage-config mapping so the profile from get_profile()
    # (possibly a shared/cached object) is left untouched.
    # ``or {}`` also tolerates a profile whose "configs" is None.
    merged_configs = dict(profile.get("configs") or {})
    for stage_name, stage_overrides in overrides.items():
        base = merged_configs.get(stage_name) or {}
        # Always build a fresh dict. Storing ``stage_overrides`` directly would
        # alias state["config_overrides"], so any later edit to the stage
        # config would leak back into the caller's overrides.
        merged_configs[stage_name] = {**base, **stage_overrides}

    return {**profile, "configs": merged_configs}
|
|
|
|
|
|
def _emit(state, node, status):
    """Report a status change for one stage through ``emit_transition``.

    This is a thin adapter that always supplies the module-level ``NODES``
    ordering. Call sites therefore only name the stage and its status
    (this module uses "running", "done" and "skipped").
    """
    stage_order = NODES
    emit_transition(state, node, status, stage_order)
|
|
|
|
|
|
# --- Node functions ---
|
|
|
|
def node_extract_frames(state: DetectState) -> dict:
    """Pipeline entry stage: extract frames from ``state["video_path"]``.

    Setup done before the stage itself runs:
      * If ``state["job_id"]`` is set and the emit module has no run context
        yet, start an ``"initial"`` run context keyed on that job id.
      * If ``state["source_asset_id"]`` is set and ``state["session_brands"]``
        is empty or missing, build it with ``build_session_dict``.

    Returns:
        ``{"frames": [...], "stats": PipelineStats(frames_extracted=N)}``.
        ``"session_brands"`` is added when it was built here.
    """
    job_id = state.get("job_id", "")
    if job_id and not emit._run_context:
        emit.set_run_context(run_id=job_id, parent_job_id=job_id, run_type="initial")

    session_update: dict = {}
    source_asset_id = state.get("source_asset_id")
    if source_asset_id and not state.get("session_brands"):
        # Imported at call time rather than at module level. The reason is not
        # visible here; it may avoid an import cycle (confirm before hoisting).
        from core.detect.stages.brand_resolver import build_session_dict

        session_brands = build_session_dict(source_asset_id)
        # Keep the in-place write for anything that relies on it. Also return
        # the value, because nodes hand updates back through their output
        # dict. Depending on the graph runner, in-place edits to ``state`` may
        # not be persisted, which would leave later stages (match_brands)
        # without the session dictionary.
        state["session_brands"] = session_brands
        session_update["session_brands"] = session_brands

    _emit(state, "extract_frames", "running")

    with trace_node(state, "extract_frames") as span:
        profile = _load_profile(state)
        config = FrameExtractionConfig(**get_stage_config(profile, "extract_frames"))
        frames = extract_frames(state["video_path"], config, job_id=job_id)
        span.set_output({"frames_extracted": len(frames)})

    _emit(state, "extract_frames", "done")
    return {
        "frames": frames,
        "stats": PipelineStats(frames_extracted=len(frames)),
        **session_update,
    }
|
|
|
|
|
|
def node_filter_scenes(state: DetectState) -> dict:
    """Scene-filter stage: keep the frames that pass ``scene_filter``.

    Reads ``state["frames"]`` and applies the profile's ``filter_scenes``
    config (as ``SceneFilterConfig``). The kept count is recorded on the shared
    ``PipelineStats`` object. Returns ``{"filtered_frames", "stats"}``.
    """
    _emit(state, "filter_scenes", "running")

    with trace_node(state, "filter_scenes") as span:
        stage_cfg = SceneFilterConfig(
            **get_stage_config(_load_profile(state), "filter_scenes")
        )
        incoming = state.get("frames", [])
        survivors = scene_filter(incoming, stage_cfg, job_id=state.get("job_id"))
        span.set_output({"frames_in": len(incoming), "frames_kept": len(survivors)})

    pipeline_stats = state.get("stats", PipelineStats())
    pipeline_stats.frames_after_scene_filter = len(survivors)

    _emit(state, "filter_scenes", "done")
    return {"filtered_frames": survivors, "stats": pipeline_stats}
|
|
|
|
|
|
def node_field_segmentation(state: DetectState) -> dict:
    """Field-segmentation stage over ``state["filtered_frames"]``.

    Runs ``run_field_segmentation`` with the profile's ``field_segmentation``
    config, using ``INFERENCE_URL`` (local mode when unset).

    Returns:
        The result's ``field_masks``, ``field_boundaries`` and
        ``field_coverage`` entries. The trace span records the frame count and
        the mean of the ``field_coverage`` values (0 when there are none).
    """
    _emit(state, "field_segmentation", "running")

    with trace_node(state, "field_segmentation") as span:
        seg_config = FieldSegmentationConfig(
            **get_stage_config(_load_profile(state), "field_segmentation")
        )
        frames = state.get("filtered_frames", [])
        seg = run_field_segmentation(
            frames, seg_config, inference_url=INFERENCE_URL, job_id=state.get("job_id"),
        )
        coverage = seg["field_coverage"]
        # max(..., 1) avoids dividing by zero when no coverage values came back.
        mean_coverage = sum(coverage.values()) / max(len(coverage), 1)
        span.set_output({"frames": len(frames), "avg_coverage": mean_coverage})

    _emit(state, "field_segmentation", "done")
    return {
        key: seg[key]
        for key in ("field_masks", "field_boundaries", "field_coverage")
    }
|
|
|
|
|
|
def node_detect_edges(state: DetectState) -> dict:
    """Edge-based region detection over ``state["filtered_frames"]``.

    Uses the profile's ``detect_edges`` config (as ``RegionAnalysisConfig``)
    and ``state["field_masks"]`` from the segmentation stage. The total number
    of edge regions across all frames is stored in ``stats.cv_regions_detected``.
    Returns ``{"edge_regions_by_frame", "stats"}``.
    """
    _emit(state, "detect_edges", "running")

    with trace_node(state, "detect_edges") as span:
        edge_cfg = RegionAnalysisConfig(
            **get_stage_config(_load_profile(state), "detect_edges")
        )
        frames = state.get("filtered_frames", [])
        regions_by_frame = detect_edge_regions(
            frames,
            edge_cfg,
            inference_url=INFERENCE_URL,
            job_id=state.get("job_id"),
            field_masks=state.get("field_masks", {}),
        )
        region_count = 0
        for frame_regions in regions_by_frame.values():
            region_count += len(frame_regions)
        span.set_output({"frames": len(frames), "edge_regions": region_count})

    stats = state.get("stats", PipelineStats())
    stats.cv_regions_detected = region_count

    _emit(state, "detect_edges", "done")
    return {"edge_regions_by_frame": regions_by_frame, "stats": stats}
|
|
|
|
|
|
def node_detect_objects(state: DetectState) -> dict:
    """Object-detection stage over ``state["filtered_frames"]``.

    Runs ``detect_objects`` with the profile's ``detect_objects`` config (as
    ``DetectionConfig``). The total number of boxes across all frames is stored
    in ``stats.regions_detected``. Returns ``{"boxes_by_frame", "stats"}``.
    """
    _emit(state, "detect_objects", "running")

    with trace_node(state, "detect_objects") as span:
        det_cfg = DetectionConfig(
            **get_stage_config(_load_profile(state), "detect_objects")
        )
        frames = state.get("filtered_frames", [])
        boxes_by_frame = detect_objects(
            frames, det_cfg, inference_url=INFERENCE_URL, job_id=state.get("job_id"),
        )
        box_total = sum(map(len, boxes_by_frame.values()))
        span.set_output({"frames": len(frames), "regions_detected": box_total})

    stats = state.get("stats", PipelineStats())
    stats.regions_detected = box_total

    _emit(state, "detect_objects", "done")
    return {"boxes_by_frame": boxes_by_frame, "stats": stats}
|
|
|
|
|
|
def node_preprocess(state: DetectState) -> dict:
    """Preprocess the detected regions before OCR.

    Reads the ``contrast``, ``deskew`` and ``binarize`` flags from the
    profile's ``preprocess`` config and passes them, with the filtered frames
    and ``boxes_by_frame``, to ``preprocess_regions``.
    Returns ``{"preprocessed_crops": ...}``.
    """
    _emit(state, "preprocess", "running")

    with trace_node(state, "preprocess") as span:
        options = get_stage_config(_load_profile(state), "preprocess")
        # Defaults when the profile omits a flag: contrast on, deskew and
        # binarize off.
        flags = {
            "do_contrast": options.get("contrast", True),
            "do_deskew": options.get("deskew", False),
            "do_binarize": options.get("binarize", False),
        }
        preprocessed = preprocess_regions(
            state.get("filtered_frames", []),
            state.get("boxes_by_frame", {}),
            inference_url=INFERENCE_URL,
            job_id=state.get("job_id"),
            **flags,
        )
        span.set_output({"regions_preprocessed": len(preprocessed)})

    _emit(state, "preprocess", "done")
    return {"preprocessed_crops": preprocessed}
|
|
|
|
|
|
def node_run_ocr(state: DetectState) -> dict:
    """OCR stage: extract text candidates from the detected boxes.

    Calls ``run_ocr`` on ``state["filtered_frames"]`` and
    ``state["boxes_by_frame"]`` with the profile's ``run_ocr`` config (as
    ``OCRConfig``). The number of text candidates is written to
    ``stats.regions_resolved_by_ocr``. Returns ``{"text_candidates", "stats"}``.
    """
    _emit(state, "run_ocr", "running")

    with trace_node(state, "run_ocr") as span:
        ocr_cfg = OCRConfig(**get_stage_config(_load_profile(state), "run_ocr"))
        frames = state.get("filtered_frames", [])
        # NOTE(review): this stage reads the raw frames/boxes, not the
        # ``preprocessed_crops`` returned by node_preprocess. Confirm that
        # run_ocr is meant to ignore the preprocessing output.
        boxes = state.get("boxes_by_frame", {})
        text_candidates = run_ocr(
            frames, boxes, ocr_cfg, inference_url=INFERENCE_URL, job_id=state.get("job_id"),
        )
        box_count = sum(len(frame_boxes) for frame_boxes in boxes.values())
        span.set_output({"regions_in": box_count, "text_candidates": len(text_candidates)})

    stats = state.get("stats", PipelineStats())
    stats.regions_resolved_by_ocr = len(text_candidates)

    _emit(state, "run_ocr", "done")
    return {"text_candidates": text_candidates, "stats": stats}
|
|
|
|
|
|
def node_match_brands(state: DetectState) -> dict:
    """Resolve OCR text candidates to brands.

    ``resolve_brands`` splits ``state["text_candidates"]`` into matched
    detections and still-unresolved candidates. It receives
    ``state["session_brands"]``, the source asset id and the profile name
    (as content type).

    Returns ``{"detections", "unresolved_candidates"}``; this stage does not
    touch the pipeline stats.
    """
    _emit(state, "match_brands", "running")

    with trace_node(state, "match_brands") as span:
        profile = _load_profile(state)
        resolver_cfg = ResolverConfig(**get_stage_config(profile, "match_brands"))
        matched, unresolved = resolve_brands(
            state.get("text_candidates", []),
            resolver_cfg,
            session_brands=state.get("session_brands", {}),
            source_asset_id=state.get("source_asset_id"),
            content_type=profile["name"],
            job_id=state.get("job_id"),
        )
        span.set_output({"matched": len(matched), "unresolved": len(unresolved)})

    _emit(state, "match_brands", "done")
    return {"detections": matched, "unresolved_candidates": unresolved}
|
|
|
|
|
|
def node_escalate_vlm(state: DetectState) -> dict:
    """Send the still-unresolved candidates to the local VLM.

    The prompt template is ``vlm_prompt_template`` from the profile's
    ``escalate_vlm`` config, falling back to a generic instruction.
    ``stats.regions_escalated_to_local_vlm`` is set to the number of
    candidates handed over.

    Returns:
        The VLM matches appended to the existing ``detections``, the remaining
        ``unresolved_candidates``, and ``stats``. The final transition is
        "skipped" when the ``SKIP_VLM`` environment variable is ``"1"``.
    """
    _emit(state, "escalate_vlm", "running")

    with trace_node(state, "escalate_vlm") as span:
        profile = _load_profile(state)
        template = get_stage_config(profile, "escalate_vlm").get(
            "vlm_prompt_template", "Identify the brand in this image."
        )
        pending = state.get("unresolved_candidates", [])

        def prompt_for(ctx):
            # Fill the profile's prompt template for one candidate context.
            return build_vlm_prompt(ctx, template)

        vlm_matched, still_unresolved = escalate_vlm(
            pending,
            vlm_prompt_fn=prompt_for,
            inference_url=INFERENCE_URL,
            content_type=profile["name"],
            source_asset_id=state.get("source_asset_id"),
            job_id=state.get("job_id"),
        )

        stats = state.get("stats", PipelineStats())
        stats.regions_escalated_to_local_vlm = len(pending)
        span.set_output({
            "candidates": len(pending),
            "matched": len(vlm_matched),
            "still_unresolved": len(still_unresolved),
        })

    prior_detections = state.get("detections", [])

    # SKIP_VLM only changes the reported status here; escalate_vlm() was still
    # called above. Presumably it honours SKIP_VLM itself (confirm).
    skipped = os.environ.get("SKIP_VLM", "").strip() == "1"
    _emit(state, "escalate_vlm", "skipped" if skipped else "done")
    return {
        "detections": prior_detections + vlm_matched,
        "unresolved_candidates": still_unresolved,
        "stats": stats,
    }
|
|
|
|
|
|
def node_escalate_cloud(state: DetectState) -> dict:
    """Last escalation step: send unresolved candidates to the cloud model.

    The prompt template is read from the ``escalate_vlm`` stage config (shared
    with the local VLM stage), not from an ``escalate_cloud`` entry.
    ``stats`` is passed into ``escalate_cloud``. Its ``cloud_llm_calls`` and
    ``estimated_cloud_cost_usd`` fields are read back for the trace afterwards.

    Returns:
        The cloud matches appended to the existing ``detections``, plus
        ``stats``. The final transition is "skipped" when the ``SKIP_CLOUD``
        environment variable is ``"1"``.
    """
    _emit(state, "escalate_cloud", "running")

    with trace_node(state, "escalate_cloud") as span:
        profile = _load_profile(state)
        template = get_stage_config(profile, "escalate_vlm").get(
            "vlm_prompt_template", "Identify the brand in this image."
        )
        pending = state.get("unresolved_candidates", [])
        stats = state.get("stats", PipelineStats())

        def prompt_for(ctx):
            # Same prompt construction as the local VLM stage.
            return build_vlm_prompt(ctx, template)

        cloud_matched = escalate_cloud(
            pending,
            vlm_prompt_fn=prompt_for,
            stats=stats,
            content_type=profile["name"],
            source_asset_id=state.get("source_asset_id"),
            job_id=state.get("job_id"),
        )

        # The cloud counters are read after the call; presumably
        # escalate_cloud updates them on ``stats`` in place.
        span.set_output({
            "candidates": len(pending),
            "matched": len(cloud_matched),
            "cloud_calls": stats.cloud_llm_calls,
            "cost_usd": stats.estimated_cloud_cost_usd,
        })

    prior_detections = state.get("detections", [])

    skipped = os.environ.get("SKIP_CLOUD", "").strip() == "1"
    _emit(state, "escalate_cloud", "skipped" if skipped else "done")
    return {"detections": prior_detections + cloud_matched, "stats": stats}
|
|
|
|
|
|
def node_compile_report(state: DetectState) -> dict:
    """Final stage: build the report from the accumulated detections and stats.

    Passes the following to ``compile_report``:
      * ``state["detections"]`` and ``state["stats"]``;
      * the video path;
      * the profile name, as content type.

    Buffered traces are then flushed before the final "done" transition.
    Returns ``{"report": report}``.
    """
    _emit(state, "compile_report", "running")

    with trace_node(state, "compile_report") as span:
        profile = _load_profile(state)
        report = compile_report(
            detections=state.get("detections", []),
            stats=state.get("stats", PipelineStats()),
            video_source=state.get("video_path", ""),
            content_type=profile["name"],
            job_id=state.get("job_id"),
        )
        span.set_output({
            "brands": len(report.brands),
            "detections": len(report.timeline),
        })

    # Flush once this node's span has closed, presumably so it is included.
    flush_traces()
    _emit(state, "compile_report", "done")
    return {"report": report}
|
|
|
|
|
|
# (stage name, node callable) pairs, listed in the same order as NODES.
NODE_FUNCTIONS = [
    ("extract_frames",     node_extract_frames),
    ("filter_scenes",      node_filter_scenes),
    ("field_segmentation", node_field_segmentation),
    ("detect_edges",       node_detect_edges),
    ("detect_objects",     node_detect_objects),
    ("preprocess",         node_preprocess),
    ("run_ocr",            node_run_ocr),
    ("match_brands",       node_match_brands),
    ("escalate_vlm",       node_escalate_vlm),
    ("escalate_cloud",     node_escalate_cloud),
    ("compile_report",     node_compile_report),
]
|