phase 10
This commit is contained in:
127
detect/graph.py
127
detect/graph.py
@@ -42,8 +42,16 @@ NODES = [
|
||||
def _get_profile(state: DetectState):
|
||||
name = state.get("profile_name", "soccer_broadcast")
|
||||
if name == "soccer_broadcast":
|
||||
return SoccerBroadcastProfile()
|
||||
raise ValueError(f"Unknown profile: {name}")
|
||||
profile = SoccerBroadcastProfile()
|
||||
else:
|
||||
raise ValueError(f"Unknown profile: {name}")
|
||||
|
||||
overrides = state.get("config_overrides")
|
||||
if overrides:
|
||||
from detect.checkpoint.replay import OverrideProfile
|
||||
profile = OverrideProfile(profile, overrides)
|
||||
|
||||
return profile
|
||||
|
||||
|
||||
# Track node states across the pipeline run
|
||||
@@ -68,6 +76,18 @@ def _emit_transition(state: DetectState, node: str, status: str):
|
||||
# --- Node functions ---
|
||||
|
||||
def node_extract_frames(state: DetectState) -> dict:
|
||||
# Set run context for initial runs (replays set it in replay_from)
|
||||
job_id = state.get("job_id", "")
|
||||
if job_id and not emit._run_context:
|
||||
emit.set_run_context(run_id=job_id, parent_job_id=job_id, run_type="initial")
|
||||
|
||||
# Load session brands from DB for this source
|
||||
source_asset_id = state.get("source_asset_id")
|
||||
if source_asset_id and not state.get("session_brands"):
|
||||
from detect.stages.brand_resolver import build_session_dict
|
||||
session_brands = build_session_dict(source_asset_id)
|
||||
state["session_brands"] = session_brands
|
||||
|
||||
_emit_transition(state, "extract_frames", "running")
|
||||
|
||||
with trace_node(state, "extract_frames") as span:
|
||||
@@ -142,13 +162,16 @@ def node_match_brands(state: DetectState) -> dict:
|
||||
|
||||
with trace_node(state, "match_brands") as span:
|
||||
profile = _get_profile(state)
|
||||
dictionary = profile.brand_dictionary()
|
||||
resolver_config = profile.resolver_config()
|
||||
candidates = state.get("text_candidates", [])
|
||||
session_brands = state.get("session_brands", {})
|
||||
job_id = state.get("job_id")
|
||||
source_asset_id = state.get("source_asset_id")
|
||||
|
||||
matched, unresolved = resolve_brands(
|
||||
candidates, dictionary, resolver_config,
|
||||
candidates, resolver_config,
|
||||
session_brands=session_brands,
|
||||
source_asset_id=source_asset_id,
|
||||
content_type=profile.name, job_id=job_id,
|
||||
)
|
||||
span.set_output({"matched": len(matched), "unresolved": len(unresolved)})
|
||||
@@ -170,6 +193,7 @@ def node_escalate_vlm(state: DetectState) -> dict:
|
||||
vlm_prompt_fn=profile.vlm_prompt,
|
||||
inference_url=INFERENCE_URL,
|
||||
content_type=profile.name,
|
||||
source_asset_id=state.get("source_asset_id"),
|
||||
job_id=job_id,
|
||||
)
|
||||
|
||||
@@ -202,6 +226,7 @@ def node_escalate_cloud(state: DetectState) -> dict:
|
||||
vlm_prompt_fn=profile.vlm_prompt,
|
||||
stats=stats,
|
||||
content_type=profile.name,
|
||||
source_asset_id=state.get("source_asset_id"),
|
||||
job_id=job_id,
|
||||
)
|
||||
|
||||
@@ -239,33 +264,87 @@ def node_compile_report(state: DetectState) -> dict:
|
||||
return {"report": report}
|
||||
|
||||
|
||||
# --- Checkpoint wrapper ---
|
||||
|
||||
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
|
||||
_frames_manifest: dict[str, dict[int, str]] = {} # job_id → manifest (cached per job)
|
||||
|
||||
|
||||
def _checkpointing_node(node_name: str, node_fn):
|
||||
"""Wrap a node function to auto-checkpoint after completion."""
|
||||
stage_index = NODES.index(node_name)
|
||||
|
||||
def wrapper(state: DetectState) -> dict:
|
||||
result = node_fn(state)
|
||||
|
||||
job_id = state.get("job_id", "")
|
||||
if not job_id:
|
||||
return result
|
||||
|
||||
from detect.checkpoint import save_checkpoint, save_frames
|
||||
|
||||
merged = {**state, **result}
|
||||
|
||||
# Save frames once (first checkpoint), reuse manifest after
|
||||
manifest = _frames_manifest.get(job_id)
|
||||
if manifest is None and node_name == "extract_frames":
|
||||
manifest = save_frames(job_id, merged.get("frames", []))
|
||||
_frames_manifest[job_id] = manifest
|
||||
|
||||
save_checkpoint(job_id, node_name, stage_index, merged, frames_manifest=manifest)
|
||||
return result
|
||||
|
||||
wrapper.__name__ = node_fn.__name__
|
||||
return wrapper
|
||||
|
||||
|
||||
# --- Graph construction ---
|
||||
|
||||
def build_graph() -> StateGraph:
|
||||
NODE_FUNCTIONS = [
|
||||
("extract_frames", node_extract_frames),
|
||||
("filter_scenes", node_filter_scenes),
|
||||
("detect_objects", node_detect_objects),
|
||||
("run_ocr", node_run_ocr),
|
||||
("match_brands", node_match_brands),
|
||||
("escalate_vlm", node_escalate_vlm),
|
||||
("escalate_cloud", node_escalate_cloud),
|
||||
("compile_report", node_compile_report),
|
||||
]
|
||||
|
||||
|
||||
def build_graph(checkpoint: bool | None = None, start_from: str | None = None) -> StateGraph:
|
||||
"""
|
||||
Build the pipeline graph.
|
||||
|
||||
checkpoint: enable auto-checkpointing (default: MPR_CHECKPOINT env var)
|
||||
start_from: skip nodes before this stage (for replay)
|
||||
"""
|
||||
do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
|
||||
|
||||
graph = StateGraph(DetectState)
|
||||
|
||||
graph.add_node("extract_frames", node_extract_frames)
|
||||
graph.add_node("filter_scenes", node_filter_scenes)
|
||||
graph.add_node("detect_objects", node_detect_objects)
|
||||
graph.add_node("run_ocr", node_run_ocr)
|
||||
graph.add_node("match_brands", node_match_brands)
|
||||
graph.add_node("escalate_vlm", node_escalate_vlm)
|
||||
graph.add_node("escalate_cloud", node_escalate_cloud)
|
||||
graph.add_node("compile_report", node_compile_report)
|
||||
# Filter to start_from if replaying
|
||||
node_pairs = NODE_FUNCTIONS
|
||||
if start_from:
|
||||
start_idx = next(i for i, (name, _) in enumerate(NODE_FUNCTIONS) if name == start_from)
|
||||
node_pairs = NODE_FUNCTIONS[start_idx:]
|
||||
|
||||
graph.set_entry_point("extract_frames")
|
||||
graph.add_edge("extract_frames", "filter_scenes")
|
||||
graph.add_edge("filter_scenes", "detect_objects")
|
||||
graph.add_edge("detect_objects", "run_ocr")
|
||||
graph.add_edge("run_ocr", "match_brands")
|
||||
graph.add_edge("match_brands", "escalate_vlm")
|
||||
graph.add_edge("escalate_vlm", "escalate_cloud")
|
||||
graph.add_edge("escalate_cloud", "compile_report")
|
||||
graph.add_edge("compile_report", END)
|
||||
for name, fn in node_pairs:
|
||||
wrapped = _checkpointing_node(name, fn) if do_checkpoint else fn
|
||||
graph.add_node(name, wrapped)
|
||||
|
||||
# Wire edges
|
||||
entry = node_pairs[0][0]
|
||||
graph.set_entry_point(entry)
|
||||
|
||||
for i in range(len(node_pairs) - 1):
|
||||
graph.add_edge(node_pairs[i][0], node_pairs[i + 1][0])
|
||||
|
||||
graph.add_edge(node_pairs[-1][0], END)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
def get_pipeline():
|
||||
def get_pipeline(checkpoint: bool | None = None):
|
||||
"""Return a compiled, runnable pipeline."""
|
||||
return build_graph().compile()
|
||||
return build_graph(checkpoint=checkpoint).compile()
|
||||
|
||||
Reference in New Issue
Block a user