This commit is contained in:
2026-03-23 15:52:03 -03:00
parent b57da622cb
commit 4fdbdfc6d3
11 changed files with 599 additions and 5 deletions

View File

@@ -35,6 +35,14 @@ def stats(job_id: str | None, **kwargs) -> None:
push_detect_event(job_id, "stats_update", dataclasses.asdict(s))
def graph_update(job_id: str | None, nodes: list[dict]) -> None:
"""Emit a graph_update event with node states."""
if not job_id:
return
payload = {"nodes": nodes}
push_detect_event(job_id, "graph_update", payload)
def detection(
job_id: str | None,
brand: str,

175
detect/graph.py Normal file
View File

@@ -0,0 +1,175 @@
"""
LangGraph pipeline graph for brand detection.
Nodes execute real logic for extract+filter, stubs for the rest.
Each node emits graph_update events so the UI can visualize transitions.
"""
from __future__ import annotations
from langgraph.graph import END, StateGraph
from detect import emit
from detect.models import PipelineStats
from detect.profiles import SoccerBroadcastProfile
from detect.state import DetectState
from detect.stages.frame_extractor import extract_frames
from detect.stages.scene_filter import scene_filter
# Canonical, ordered list of pipeline node names. The order matters:
# build_graph() wires edges following this sequence, and _emit_transition()
# reports one status entry per name in this order for the UI.
NODES = [
    "extract_frames",
    "filter_scenes",
    "detect_objects",
    "run_ocr",
    "match_brands",
    "escalate_vlm",
    "escalate_cloud",
    "compile_report",
]
def _get_profile(state: DetectState):
name = state.get("profile_name", "soccer_broadcast")
if name == "soccer_broadcast":
return SoccerBroadcastProfile()
raise ValueError(f"Unknown profile: {name}")
# Track node states across the pipeline run.
# Maps job_id -> {node_name: status}, where status is one of
# "pending" (initialized), "running", or "done" (set by the node functions
# via _emit_transition).
_node_states: dict[str, dict[str, str]] = {}
def _emit_transition(state: DetectState, node: str, status: str):
    """Record *node* entering *status* and broadcast a full snapshot.

    Emits a graph_update event containing the status of every node (not
    just the one that changed) so the UI can render the whole pipeline.
    No-op when the state carries no job id.
    """
    job_id = state.get("job_id")
    if not job_id:
        return
    # Lazily create this job's tracking map, starting every node at "pending".
    tracked = _node_states.setdefault(job_id, dict.fromkeys(NODES, "pending"))
    tracked[node] = status
    snapshot = [{"id": name, "status": tracked[name]} for name in NODES]
    emit.graph_update(job_id, snapshot)
# --- Node functions ---
def node_extract_frames(state: DetectState) -> dict:
    """Extract frames from the input video per the active profile's config."""
    _emit_transition(state, "extract_frames", "running")
    cfg = _get_profile(state).frame_extraction_config()
    extracted = extract_frames(state["video_path"], cfg, job_id=state.get("job_id"))
    _emit_transition(state, "extract_frames", "done")
    return {
        "frames": extracted,
        "stats": PipelineStats(frames_extracted=len(extracted)),
    }
def node_filter_scenes(state: DetectState) -> dict:
    """Filter extracted frames by scene change and update running stats."""
    _emit_transition(state, "filter_scenes", "running")
    cfg = _get_profile(state).scene_filter_config()
    surviving = scene_filter(state.get("frames", []), cfg, job_id=state.get("job_id"))
    # Carry the stats object forward, recording how many frames survived.
    running_stats = state.get("stats", PipelineStats())
    running_stats.frames_after_scene_filter = len(surviving)
    _emit_transition(state, "filter_scenes", "done")
    return {"filtered_frames": surviving, "stats": running_stats}
def node_detect_objects(state: DetectState) -> dict:
    """Stub node: object detection is not implemented yet; logs and passes through."""
    _emit_transition(state, "detect_objects", "running")
    emit.log(state.get("job_id"), "YOLODetector", "INFO",
             "Stub: object detection not yet implemented")
    _emit_transition(state, "detect_objects", "done")
    return {}
def node_run_ocr(state: DetectState) -> dict:
    """Stub node: OCR is not implemented yet; logs and passes through."""
    _emit_transition(state, "run_ocr", "running")
    emit.log(state.get("job_id"), "OCRStage", "INFO",
             "Stub: OCR not yet implemented")
    _emit_transition(state, "run_ocr", "done")
    return {}
def node_match_brands(state: DetectState) -> dict:
    """Stub node: brand matching is not implemented yet; yields no detections."""
    _emit_transition(state, "match_brands", "running")
    emit.log(state.get("job_id"), "BrandResolver", "INFO",
             "Stub: brand matching not yet implemented")
    _emit_transition(state, "match_brands", "done")
    return {"detections": []}
def node_escalate_vlm(state: DetectState) -> dict:
    """Stub node: local VLM escalation is not implemented yet."""
    _emit_transition(state, "escalate_vlm", "running")
    emit.log(state.get("job_id"), "VLMLocal", "INFO",
             "Stub: VLM escalation not yet implemented")
    _emit_transition(state, "escalate_vlm", "done")
    return {}
def node_escalate_cloud(state: DetectState) -> dict:
    """Stub node: cloud LLM escalation is not implemented yet."""
    _emit_transition(state, "escalate_cloud", "running")
    emit.log(state.get("job_id"), "CloudLLM", "INFO",
             "Stub: cloud LLM escalation not yet implemented")
    _emit_transition(state, "escalate_cloud", "done")
    return {}
def node_compile_report(state: DetectState) -> dict:
    """Aggregate detections into the final report and signal job completion.

    Builds the report via the active profile's ``aggregate``, emits a summary
    log line plus a ``job_complete`` event, then releases this job's entry in
    the module-level ``_node_states`` tracker.

    Returns:
        dict with the finished ``report`` merged into the pipeline state.
    """
    _emit_transition(state, "compile_report", "running")
    job_id = state.get("job_id")
    profile = _get_profile(state)
    detections = state.get("detections", [])
    report = profile.aggregate(detections)
    report.video_source = state.get("video_path", "")
    emit.log(job_id, "Aggregator", "INFO",
             f"Report: {len(report.brands)} brands, {len(report.timeline)} detections")
    emit.job_complete(job_id, {
        "video_source": report.video_source,
        "content_type": report.content_type,
        "brands": {k: {"total_appearances": v.total_appearances} for k, v in report.brands.items()},
    })
    _emit_transition(state, "compile_report", "done")
    # Fix: drop the per-job node-state map now that the pipeline is finished.
    # Without this, _node_states grows without bound across jobs in a
    # long-running process.
    if job_id:
        _node_states.pop(job_id, None)
    return {"report": report}
# --- Graph construction ---
def build_graph() -> StateGraph:
    """Assemble the linear detection pipeline as a LangGraph StateGraph.

    Nodes are registered and chained in the order declared by ``NODES``:
    extract -> filter -> detect -> ocr -> match -> vlm -> cloud -> report.
    """
    # Node name -> implementation, keyed to match NODES.
    handlers = {
        "extract_frames": node_extract_frames,
        "filter_scenes": node_filter_scenes,
        "detect_objects": node_detect_objects,
        "run_ocr": node_run_ocr,
        "match_brands": node_match_brands,
        "escalate_vlm": node_escalate_vlm,
        "escalate_cloud": node_escalate_cloud,
        "compile_report": node_compile_report,
    }
    graph = StateGraph(DetectState)
    for name in NODES:
        graph.add_node(name, handlers[name])
    graph.set_entry_point(NODES[0])
    # Chain consecutive nodes, then terminate after the last one.
    for src, dst in zip(NODES, NODES[1:]):
        graph.add_edge(src, dst)
    graph.add_edge(NODES[-1], END)
    return graph
def get_pipeline():
    """Return a compiled, runnable pipeline."""
    graph = build_graph()
    return graph.compile()

28
detect/state.py Normal file
View File

@@ -0,0 +1,28 @@
"""
LangGraph state definition for the detection pipeline.
This TypedDict flows through all graph nodes. Each node reads what
it needs and writes its outputs. LangGraph manages the state transitions.
"""
from __future__ import annotations
from typing import TypedDict
from detect.models import BrandDetection, DetectionReport, Frame, PipelineStats
class DetectState(TypedDict, total=False):
    """Shared pipeline state passed between LangGraph nodes.

    ``total=False`` makes every key optional: nodes read the keys they need
    with ``.get()`` and return only the keys they produce.
    """
    # Input
    video_path: str  # path to the source video file
    job_id: str  # identifier used to route emitted events; may be absent
    profile_name: str  # content profile selector (defaults to "soccer_broadcast")
    # Stage outputs
    frames: list[Frame]  # produced by extract_frames
    filtered_frames: list[Frame]  # frames surviving the scene filter
    detections: list[BrandDetection]  # produced by match_brands
    report: DetectionReport  # final aggregate built by compile_report
    # Running stats (updated by each stage)
    stats: PipelineStats