phase 5
This commit is contained in:
@@ -35,6 +35,14 @@ def stats(job_id: str | None, **kwargs) -> None:
|
||||
push_detect_event(job_id, "stats_update", dataclasses.asdict(s))
|
||||
|
||||
|
||||
def graph_update(job_id: str | None, nodes: list[dict]) -> None:
|
||||
"""Emit a graph_update event with node states."""
|
||||
if not job_id:
|
||||
return
|
||||
payload = {"nodes": nodes}
|
||||
push_detect_event(job_id, "graph_update", payload)
|
||||
|
||||
|
||||
def detection(
|
||||
job_id: str | None,
|
||||
brand: str,
|
||||
|
||||
175
detect/graph.py
Normal file
175
detect/graph.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
LangGraph pipeline graph for brand detection.
|
||||
|
||||
Nodes execute real logic for extract+filter, stubs for the rest.
|
||||
Each node emits graph_update events so the UI can visualize transitions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langgraph.graph import END, StateGraph
|
||||
|
||||
from detect import emit
|
||||
from detect.models import PipelineStats
|
||||
from detect.profiles import SoccerBroadcastProfile
|
||||
from detect.state import DetectState
|
||||
from detect.stages.frame_extractor import extract_frames
|
||||
from detect.stages.scene_filter import scene_filter
|
||||
|
||||
NODES = [
|
||||
"extract_frames",
|
||||
"filter_scenes",
|
||||
"detect_objects",
|
||||
"run_ocr",
|
||||
"match_brands",
|
||||
"escalate_vlm",
|
||||
"escalate_cloud",
|
||||
"compile_report",
|
||||
]
|
||||
|
||||
|
||||
def _get_profile(state: DetectState):
|
||||
name = state.get("profile_name", "soccer_broadcast")
|
||||
if name == "soccer_broadcast":
|
||||
return SoccerBroadcastProfile()
|
||||
raise ValueError(f"Unknown profile: {name}")
|
||||
|
||||
|
||||
# Track node states across the pipeline run
|
||||
_node_states: dict[str, dict[str, str]] = {}
|
||||
|
||||
|
||||
def _emit_transition(state: DetectState, node: str, status: str):
|
||||
job_id = state.get("job_id")
|
||||
if not job_id:
|
||||
return
|
||||
|
||||
# Initialize state tracking for this job
|
||||
if job_id not in _node_states:
|
||||
_node_states[job_id] = {n: "pending" for n in NODES}
|
||||
|
||||
_node_states[job_id][node] = status
|
||||
|
||||
nodes = [{"id": n, "status": _node_states[job_id][n]} for n in NODES]
|
||||
emit.graph_update(job_id, nodes)
|
||||
|
||||
|
||||
# --- Node functions ---
|
||||
|
||||
def node_extract_frames(state: DetectState) -> dict:
|
||||
_emit_transition(state, "extract_frames", "running")
|
||||
|
||||
profile = _get_profile(state)
|
||||
config = profile.frame_extraction_config()
|
||||
frames = extract_frames(state["video_path"], config, job_id=state.get("job_id"))
|
||||
|
||||
_emit_transition(state, "extract_frames", "done")
|
||||
return {"frames": frames, "stats": PipelineStats(frames_extracted=len(frames))}
|
||||
|
||||
|
||||
def node_filter_scenes(state: DetectState) -> dict:
|
||||
_emit_transition(state, "filter_scenes", "running")
|
||||
|
||||
profile = _get_profile(state)
|
||||
config = profile.scene_filter_config()
|
||||
frames = state.get("frames", [])
|
||||
kept = scene_filter(frames, config, job_id=state.get("job_id"))
|
||||
|
||||
stats = state.get("stats", PipelineStats())
|
||||
stats.frames_after_scene_filter = len(kept)
|
||||
|
||||
_emit_transition(state, "filter_scenes", "done")
|
||||
return {"filtered_frames": kept, "stats": stats}
|
||||
|
||||
|
||||
def node_detect_objects(state: DetectState) -> dict:
|
||||
_emit_transition(state, "detect_objects", "running")
|
||||
job_id = state.get("job_id")
|
||||
emit.log(job_id, "YOLODetector", "INFO", "Stub: object detection not yet implemented")
|
||||
_emit_transition(state, "detect_objects", "done")
|
||||
return {}
|
||||
|
||||
|
||||
def node_run_ocr(state: DetectState) -> dict:
|
||||
_emit_transition(state, "run_ocr", "running")
|
||||
job_id = state.get("job_id")
|
||||
emit.log(job_id, "OCRStage", "INFO", "Stub: OCR not yet implemented")
|
||||
_emit_transition(state, "run_ocr", "done")
|
||||
return {}
|
||||
|
||||
|
||||
def node_match_brands(state: DetectState) -> dict:
|
||||
_emit_transition(state, "match_brands", "running")
|
||||
job_id = state.get("job_id")
|
||||
emit.log(job_id, "BrandResolver", "INFO", "Stub: brand matching not yet implemented")
|
||||
_emit_transition(state, "match_brands", "done")
|
||||
return {"detections": []}
|
||||
|
||||
|
||||
def node_escalate_vlm(state: DetectState) -> dict:
|
||||
_emit_transition(state, "escalate_vlm", "running")
|
||||
job_id = state.get("job_id")
|
||||
emit.log(job_id, "VLMLocal", "INFO", "Stub: VLM escalation not yet implemented")
|
||||
_emit_transition(state, "escalate_vlm", "done")
|
||||
return {}
|
||||
|
||||
|
||||
def node_escalate_cloud(state: DetectState) -> dict:
|
||||
_emit_transition(state, "escalate_cloud", "running")
|
||||
job_id = state.get("job_id")
|
||||
emit.log(job_id, "CloudLLM", "INFO", "Stub: cloud LLM escalation not yet implemented")
|
||||
_emit_transition(state, "escalate_cloud", "done")
|
||||
return {}
|
||||
|
||||
|
||||
def node_compile_report(state: DetectState) -> dict:
|
||||
_emit_transition(state, "compile_report", "running")
|
||||
job_id = state.get("job_id")
|
||||
|
||||
profile = _get_profile(state)
|
||||
detections = state.get("detections", [])
|
||||
report = profile.aggregate(detections)
|
||||
report.video_source = state.get("video_path", "")
|
||||
|
||||
emit.log(job_id, "Aggregator", "INFO",
|
||||
f"Report: {len(report.brands)} brands, {len(report.timeline)} detections")
|
||||
emit.job_complete(job_id, {
|
||||
"video_source": report.video_source,
|
||||
"content_type": report.content_type,
|
||||
"brands": {k: {"total_appearances": v.total_appearances} for k, v in report.brands.items()},
|
||||
})
|
||||
|
||||
_emit_transition(state, "compile_report", "done")
|
||||
return {"report": report}
|
||||
|
||||
|
||||
# --- Graph construction ---
|
||||
|
||||
def build_graph() -> StateGraph:
|
||||
graph = StateGraph(DetectState)
|
||||
|
||||
graph.add_node("extract_frames", node_extract_frames)
|
||||
graph.add_node("filter_scenes", node_filter_scenes)
|
||||
graph.add_node("detect_objects", node_detect_objects)
|
||||
graph.add_node("run_ocr", node_run_ocr)
|
||||
graph.add_node("match_brands", node_match_brands)
|
||||
graph.add_node("escalate_vlm", node_escalate_vlm)
|
||||
graph.add_node("escalate_cloud", node_escalate_cloud)
|
||||
graph.add_node("compile_report", node_compile_report)
|
||||
|
||||
graph.set_entry_point("extract_frames")
|
||||
graph.add_edge("extract_frames", "filter_scenes")
|
||||
graph.add_edge("filter_scenes", "detect_objects")
|
||||
graph.add_edge("detect_objects", "run_ocr")
|
||||
graph.add_edge("run_ocr", "match_brands")
|
||||
graph.add_edge("match_brands", "escalate_vlm")
|
||||
graph.add_edge("escalate_vlm", "escalate_cloud")
|
||||
graph.add_edge("escalate_cloud", "compile_report")
|
||||
graph.add_edge("compile_report", END)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
def get_pipeline():
|
||||
"""Return a compiled, runnable pipeline."""
|
||||
return build_graph().compile()
|
||||
28
detect/state.py
Normal file
28
detect/state.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
LangGraph state definition for the detection pipeline.
|
||||
|
||||
This TypedDict flows through all graph nodes. Each node reads what
|
||||
it needs and writes its outputs. LangGraph manages the state transitions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypedDict
|
||||
|
||||
from detect.models import BrandDetection, DetectionReport, Frame, PipelineStats
|
||||
|
||||
|
||||
class DetectState(TypedDict, total=False):
|
||||
# Input
|
||||
video_path: str
|
||||
job_id: str
|
||||
profile_name: str
|
||||
|
||||
# Stage outputs
|
||||
frames: list[Frame]
|
||||
filtered_frames: list[Frame]
|
||||
detections: list[BrandDetection]
|
||||
report: DetectionReport
|
||||
|
||||
# Running stats (updated by each stage)
|
||||
stats: PipelineStats
|
||||
Reference in New Issue
Block a user