From d0707333fd8da5ddab29dfd1faa999c10180b2ed Mon Sep 17 00:00:00 2001
From: buenosairesam
Date: Sat, 28 Mar 2026 10:05:59 -0300
Subject: [PATCH] Phase 3: config-driven pipeline runner

---
 core/api/detect/config.py                    |  16 ++
 core/api/detect/run.py                       |   2 +-
 core/schema/models/__init__.py               |   1 +
 core/schema/models/pipeline_config.py        |  63 +++++
 detect/checkpoint/__init__.py                |   2 +
 detect/checkpoint/runner_bridge.py           |  64 +++++
 detect/graph/__init__.py                     |   4 +-
 detect/graph/runner.py                       | 235 +++++++++++-------
 detect/profiles/base.py                      |  26 +-
 detect/profiles/soccer.py                    |  35 +++
 tests/detect/test_graph.py                   |   6 +-
 .../src/panels/PipelineGraphPanel.vue        |  64 +++--
 12 files changed, 398 insertions(+), 120 deletions(-)
 create mode 100644 core/schema/models/pipeline_config.py
 create mode 100644 detect/checkpoint/runner_bridge.py

diff --git a/core/api/detect/config.py b/core/api/detect/config.py
index ed12bf7..43f8a42 100644
--- a/core/api/detect/config.py
+++ b/core/api/detect/config.py
@@ -64,6 +64,22 @@ def list_profiles():
     return [{"name": name} for name in _PROFILES]
 
 
+@router.get("/config/profiles/{profile_name}/pipeline")
+def get_pipeline_config(profile_name: str):
+    """Return the pipeline composition for a profile."""
+    from detect.profiles import get_profile
+    from fastapi import HTTPException
+    from dataclasses import asdict
+
+    try:
+        profile = get_profile(profile_name)
+    except ValueError:
+        raise HTTPException(status_code=404, detail=f"Unknown profile: {profile_name}")
+
+    config = profile.pipeline_config()
+    return asdict(config)
+
+
 @router.get("/config/stages", response_model=list[StageConfigInfo])
 def list_stage_configs():
     """Return the stage palette with config field metadata for the editor."""
diff --git a/core/api/detect/run.py b/core/api/detect/run.py
index 9b34bc3..3533629 100644
--- a/core/api/detect/run.py
+++ b/core/api/detect/run.py
@@ -88,7 +88,7 @@ def run_pipeline(req: RunRequest):
         log_level=req.log_level,
     )
 
-    pipeline = get_pipeline(checkpoint=req.checkpoint)
+    pipeline = get_pipeline(checkpoint=req.checkpoint, profile_name=req.profile_name)
 
     initial_state = DetectState(
         video_path=local_path,
diff --git a/core/schema/models/__init__.py b/core/schema/models/__init__.py
index 5ca8c32..f4a1416 100644
--- a/core/schema/models/__init__.py
+++ b/core/schema/models/__init__.py
@@ -35,6 +35,7 @@ from .detect import DETECT_VIEWS  # noqa: F401 — discovered by modelgen generi
 from .inference import INFERENCE_VIEWS  # noqa: F401 — GPU inference server API types
 from .ui_state import UI_STATE_VIEWS  # noqa: F401 — UI store state types
 from .stages import StageConfigField, StageIO, StageDefinition, STAGE_VIEWS  # noqa: F401
+from .pipeline_config import StageRef, Edge, PipelineConfig, PIPELINE_CONFIG_VIEWS  # noqa: F401
 from .detect_api import RunRequest, RunResponse, DETECT_API_VIEWS  # noqa: F401
 from .views import ChunkEvent, ChunkOutputFile, PipelineStats, WorkerEvent
 from .sources import ChunkInfo, SourceJob, SourceType
diff --git a/core/schema/models/pipeline_config.py b/core/schema/models/pipeline_config.py
new file mode 100644
index 0000000..b2a710e
--- /dev/null
+++ b/core/schema/models/pipeline_config.py
@@ -0,0 +1,63 @@
+"""
+Pipeline composition config — source of truth for graph topology.
+
+Defines what stages run, in what order, with what branching.
+Belongs to a profile. Persisted as JSONB.
+
+The execution strategy (serial, parallel, distributed) is separate —
+the runner reads this config and flattens it into a sequence for now.
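+
+Example payload (mirrors the hardcoded soccer_broadcast topology in
+detect/profiles/soccer.py, trimmed to two stages and one edge; an
+illustrative sketch, not an exhaustive schema):
+
+    {
+        "name": "soccer_broadcast",
+        "profile_name": "soccer_broadcast",
+        "stages": [
+            {"name": "extract_frames", "branch": "trunk"},
+            {"name": "filter_scenes", "branch": "trunk"}
+        ],
+        "edges": [
+            {"source": "extract_frames", "target": "filter_scenes"}
+        ],
+        "routing_rules": {}
+    }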
+""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + + +@dataclass +class StageRef: + """Reference to a stage in the pipeline graph.""" + name: str # stage name (matches StageDefinition.name) + branch: str = "trunk" # which branch this belongs to + execution_target: str = "local" # local | gpu | lambda | gcp + + +@dataclass +class Edge: + """Connection between stages in the graph.""" + source: str # stage name + target: str # stage name + condition: str = "" # empty = unconditional, otherwise a routing rule key + + +@dataclass +class PipelineConfig: + """ + Pipeline graph topology + routing rules. + + Holder model — stages/edges define the graph shape, + routing_rules is a JSONB blob for decision tree logic. + """ + name: str + profile_name: str + stages: List[StageRef] = field(default_factory=list) + edges: List[Edge] = field(default_factory=list) + routing_rules: Dict[str, Any] = field(default_factory=dict) + + +PIPELINE_CONFIG_VIEWS = [StageRef, Edge, PipelineConfig] diff --git a/detect/checkpoint/__init__.py b/detect/checkpoint/__init__.py index ec64509..32b1c97 100644 --- a/detect/checkpoint/__init__.py +++ b/detect/checkpoint/__init__.py @@ -5,6 +5,7 @@ Checkpoint system — Timeline + Checkpoint tree. frames.py — frame image S3 upload/download storage.py — Timeline + Checkpoint (Postgres + MinIO) replay.py — replay (TODO: migrate to new model) + runner_bridge.py — checkpoint hook for PipelineRunner """ from .storage import ( @@ -15,3 +16,4 @@ from .storage import ( load_stage_output, ) from .frames import save_frames, load_frames +from .runner_bridge import checkpoint_after_stage, reset_checkpoint_state diff --git a/detect/checkpoint/runner_bridge.py b/detect/checkpoint/runner_bridge.py new file mode 100644 index 0000000..e2b53df --- /dev/null +++ b/detect/checkpoint/runner_bridge.py @@ -0,0 +1,64 @@ +""" +Runner bridge — checkpoint hook called by PipelineRunner after each stage. + +Owns the per-job state (frame manifest cache, checkpoint chain) that +the runner shouldn't know about. +""" + +from __future__ import annotations + +import logging + +logger = logging.getLogger(__name__) + +# Per-job state +_frames_manifest: dict[str, dict[int, str]] = {} +_latest_checkpoint: dict[str, str] = {} + + +def reset_checkpoint_state(job_id: str): + """Clean up per-job checkpoint state. Called when pipeline finishes.""" + _frames_manifest.pop(job_id, None) + _latest_checkpoint.pop(job_id, None) + + +def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: dict): + """ + Save a checkpoint after a stage completes. + + Called by the runner. 
+    - Frame upload (once, on first stage)
+    - Stage output serialization (via stage registry)
+    - Checkpoint chain (parent → child)
+    """
+    if not job_id:
+        return
+
+    from .storage import save_stage_output
+    from .frames import save_frames
+    from detect.stages.base import _REGISTRY
+
+    merged = {**state, **result}
+
+    # Save frames once (first stage that produces them)
+    manifest = _frames_manifest.get(job_id)
+    if manifest is None and stage_name == "extract_frames":
+        manifest = save_frames(job_id, merged.get("frames", []))
+        _frames_manifest[job_id] = manifest
+
+    # Serialize stage output using the stage's serialize_fn if available
+    stage_cls = _REGISTRY.get(stage_name)
+    serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
+    if serialize_fn:
+        output_json = serialize_fn(merged, job_id)
+    else:
+        output_json = {}
+
+    parent_id = _latest_checkpoint.get(job_id)
+    new_checkpoint_id = save_stage_output(
+        timeline_id=job_id,
+        parent_checkpoint_id=parent_id,
+        stage_name=stage_name,
+        output_json=output_json,
+    )
+    _latest_checkpoint[job_id] = new_checkpoint_id
diff --git a/detect/graph/__init__.py b/detect/graph/__init__.py
index 48c5f5c..26fd1b6 100644
--- a/detect/graph/__init__.py
+++ b/detect/graph/__init__.py
@@ -4,12 +4,13 @@ Detection pipeline graph.
 detect/graph/
     nodes.py  — node functions (one per stage)
     events.py — graph_update SSE emission
-    runner.py — pipeline execution (LangGraph wrapper, checkpoint, cancel, pause)
+    runner.py — PipelineRunner (config-driven, checkpoint, cancel, pause)
 """
 
 from .nodes import NODES, NODE_FUNCTIONS
 from .runner import (
     PipelineCancelled,
+    PipelineRunner,
     build_graph,
     clear_cancel_check,
     clear_pause,
@@ -28,6 +29,7 @@ __all__ = [
     "NODES",
     "NODE_FUNCTIONS",
     "PipelineCancelled",
+    "PipelineRunner",
     "build_graph",
     "get_pipeline",
     "set_cancel_check",
diff --git a/detect/graph/runner.py b/detect/graph/runner.py
index ed27ffc..e3218b5 100644
--- a/detect/graph/runner.py
+++ b/detect/graph/runner.py
@@ -1,8 +1,10 @@
 """
-Pipeline runner — executes stages sequentially with checkpointing and cancellation.
+Pipeline runner — executes stages sequentially with checkpointing,
+cancellation, and pause/resume.
 
-Currently wraps LangGraph for execution. Will be replaced with a lean
-custom runner in Phase 3, with an executor socket for distributed dispatch.
+Reads PipelineConfig from the profile to determine what stages to run.
+Flattens the graph into a linear sequence for now (serial execution).
+Executor socket: all stages currently run locally (node functions called directly).
 """
 
 from __future__ import annotations
@@ -11,19 +13,14 @@
 import logging
 import os
 import threading
 
-from langgraph.graph import END, StateGraph
-
+from core.schema.models.pipeline_config import PipelineConfig
 from detect.state import DetectState
 
 from .nodes import NODES, NODE_FUNCTIONS
 
 logger = logging.getLogger(__name__)
 
-# --- Checkpoint wrapper ---
-
 _CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
-_frames_manifest: dict[str, dict[int, str]] = {}  # job_id → manifest (cached per job)
-_latest_checkpoint: dict[str, str] = {}  # job_id → latest checkpoint_id
 
 
 class PipelineCancelled(Exception):
@@ -53,10 +50,6 @@ def clear_cancel_check(job_id: str):
 
 # ---------------------------------------------------------------------------
 # Pause / Resume / Step — checked after each node completes
-#
-# _pause_gate: threading.Event per job. When cleared, the runner blocks.
-# When set, the runner proceeds to the next node.
-# _pause_after_stage: if True, automatically clear the gate after each node.
 # ---------------------------------------------------------------------------
 
 _pause_gate: dict[str, threading.Event] = {}
@@ -98,7 +91,7 @@ def step_pipeline(job_id: str):
     _pause_after_stage[job_id] = True
     gate = _pause_gate.get(job_id)
     if gate:
-        gate.set()  # unblock for one stage, _pause_after_stage re-pauses after
+        gate.set()
     logger.info("Pipeline %s stepping", job_id)
 
 
@@ -106,7 +99,6 @@ def set_pause_after_stage(job_id: str, enabled: bool):
     """Toggle pause-after-each-stage mode."""
     _pause_after_stage[job_id] = enabled
     if not enabled:
-        # If disabling, also resume in case we're currently paused
        gate = _pause_gate.get(job_id)
        if gate:
            gate.set()
@@ -124,106 +116,159 @@
     if gate is None:
         return
 
-    # If pause-after-stage is on, pause now
     if _pause_after_stage.get(job_id, False):
         gate.clear()
         from detect import emit
         emit.log(job_id, "Pipeline", "INFO", f"Paused after {node_name}")
 
-    # Block until gate is set (resume/step) or cancelled
     while not gate.wait(timeout=0.5):
         check = _cancel_check.get(job_id)
         if check and check():
             raise PipelineCancelled(f"Cancelled while paused before next stage")
 
 
-def _checkpointing_node(node_name: str, node_fn):
-    """Wrap a node function to auto-checkpoint after completion."""
-
-    def wrapper(state: DetectState) -> dict:
-        job_id = state.get("job_id", "")
-        check = _cancel_check.get(job_id)
-        if check and check():
-            raise PipelineCancelled(f"Cancelled before {node_name}")
-
-        result = node_fn(state)
-
-        job_id = state.get("job_id", "")
-        if not job_id:
-            return result
-
-        from detect.checkpoint import save_stage_output, save_frames
-        from detect.stages.base import _REGISTRY
-
-        merged = {**state, **result}
-
-        # Save frames once (first node), reuse manifest after
-        manifest = _frames_manifest.get(job_id)
-        if manifest is None and node_name == "extract_frames":
-            manifest = save_frames(job_id, merged.get("frames", []))
-            _frames_manifest[job_id] = manifest
-
-        # Serialize stage output using the stage's serialize_fn if available
-        stage_cls = _REGISTRY.get(node_name)
-        serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
-        if serialize_fn:
-            output_json = serialize_fn(merged, job_id)
-        else:
-            output_json = {}
-
-        parent_id = _latest_checkpoint.get(job_id)
-        new_checkpoint_id = save_stage_output(
-            timeline_id=job_id,
-            parent_checkpoint_id=parent_id,
-            stage_name=node_name,
-            output_json=output_json,
-        )
-        _latest_checkpoint[job_id] = new_checkpoint_id
-
-        # Pause check — blocks if paused, respects cancel while waiting
-        _wait_if_paused(job_id, node_name)
-
-        return result
-
-    wrapper.__name__ = node_fn.__name__
-    return wrapper
 
 
-# --- Graph construction ---
+# ---------------------------------------------------------------------------
+# Pipeline Runner
+# ---------------------------------------------------------------------------
 
-def build_graph(checkpoint: bool | None = None, start_from: str | None = None) -> StateGraph:
+# Node function lookup — maps stage name to callable
+_NODE_FN_MAP: dict[str, callable] = {name: fn for name, fn in NODE_FUNCTIONS}
+
+
+def _flatten_config(config: PipelineConfig, start_from: str | None = None) -> list[str]:
     """
-    Build the pipeline graph.
+    Flatten a PipelineConfig into a linear stage sequence.
 
-    checkpoint: enable auto-checkpointing (default: MPR_CHECKPOINT env var)
-    start_from: skip nodes before this stage (for replay)
+    For now: topological sort via edges. Falls back to stage order if no edges.
+    Respects start_from for replay (skip stages before it).
     """
-    do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
+    if not config.edges:
+        # No edges defined — use stage order as-is
+        names = [s.name for s in config.stages]
+    else:
+        # Topological sort from edges
+        graph: dict[str, list[str]] = {}
+        in_degree: dict[str, int] = {}
+        stage_names = {s.name for s in config.stages}
 
-    graph = StateGraph(DetectState)
+        for name in stage_names:
+            graph[name] = []
+            in_degree[name] = 0
+
+        for edge in config.edges:
+            if edge.source in stage_names and edge.target in stage_names:
+                graph[edge.source].append(edge.target)
+                in_degree[edge.target] = in_degree.get(edge.target, 0) + 1
+
+        # Kahn's algorithm
+        queue = [n for n in stage_names if in_degree.get(n, 0) == 0]
+        # Stable sort: prefer order from config.stages
+        stage_order = {s.name: i for i, s in enumerate(config.stages)}
+        queue.sort(key=lambda n: stage_order.get(n, 999))
+
+        names = []
+        while queue:
+            node = queue.pop(0)
+            names.append(node)
+            for neighbor in graph.get(node, []):
+                in_degree[neighbor] -= 1
+                if in_degree[neighbor] == 0:
+                    queue.append(neighbor)
+            queue.sort(key=lambda n: stage_order.get(n, 999))
 
-    # Filter to start_from if replaying
-    node_pairs = NODE_FUNCTIONS
     if start_from:
-        start_idx = next(i for i, (name, _) in enumerate(NODE_FUNCTIONS) if name == start_from)
-        node_pairs = NODE_FUNCTIONS[start_idx:]
+        try:
+            idx = names.index(start_from)
+            names = names[idx:]
+        except ValueError:
+            raise ValueError(f"Stage {start_from!r} not in pipeline config")
 
-    for name, fn in node_pairs:
-        wrapped = _checkpointing_node(name, fn) if do_checkpoint else fn
-        graph.add_node(name, wrapped)
-
-    # Wire edges
-    entry = node_pairs[0][0]
-    graph.set_entry_point(entry)
-
-    for i in range(len(node_pairs) - 1):
-        graph.add_edge(node_pairs[i][0], node_pairs[i + 1][0])
-
-    graph.add_edge(node_pairs[-1][0], END)
-
-    return graph
+    return names
 
 
-def get_pipeline(checkpoint: bool | None = None):
-    """Return a compiled, runnable pipeline."""
-    return build_graph(checkpoint=checkpoint).compile()
+class PipelineRunner:
+    """
+    Executes a pipeline defined by PipelineConfig.
+
+    Runs stages sequentially (flattened). Each stage:
+      1. Check cancel
+      2. Run node function (via executor — local for now)
+      3. Merge result into state
+      4. Checkpoint (if enabled)
+      5. Check pause
+
+    Executor socket: currently calls node functions directly.
+    Future: dispatch to LocalExecutor / GrpcExecutor / LambdaExecutor
+    based on StageRef.execution_target.
+    """
+
+    def __init__(
+        self,
+        config: PipelineConfig,
+        checkpoint: bool = False,
+        start_from: str | None = None,
+    ):
+        self.config = config
+        self.do_checkpoint = checkpoint
+        self.stage_sequence = _flatten_config(config, start_from)
+
+    def invoke(self, state: DetectState) -> DetectState:
+        """Run the pipeline on the given state. Returns final state."""
+        for stage_name in self.stage_sequence:
+            job_id = state.get("job_id", "")
+
+            # 1. Cancel check
+            check = _cancel_check.get(job_id)
+            if check and check():
+                raise PipelineCancelled(f"Cancelled before {stage_name}")
+
+            # 2. Run node function
+            node_fn = _NODE_FN_MAP.get(stage_name)
+            if node_fn is None:
+                logger.warning("No node function for stage %s, skipping", stage_name)
+                continue
+
+            result = node_fn(state)
+
+            # 3. Merge result into state
+            state.update(result)
+
+            # 4. Checkpoint
+            if self.do_checkpoint:
+                from detect.checkpoint import checkpoint_after_stage
+                checkpoint_after_stage(job_id, stage_name, state, result)
+
+            # 5. Pause check
+            _wait_if_paused(job_id, stage_name)
+
+        return state
+
+
+# ---------------------------------------------------------------------------
+# Public API — backwards compatible with old get_pipeline/build_graph
+# ---------------------------------------------------------------------------
+
+def get_pipeline(
+    checkpoint: bool | None = None,
+    profile_name: str = "soccer_broadcast",
+    start_from: str | None = None,
+) -> PipelineRunner:
+    """Return a PipelineRunner for the given profile."""
+    from detect.profiles import get_profile
+
+    do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
+    profile = get_profile(profile_name)
+    config = profile.pipeline_config()
+
+    return PipelineRunner(
+        config=config,
+        checkpoint=do_checkpoint,
+        start_from=start_from,
+    )
+
+
+def build_graph(checkpoint: bool | None = None, start_from: str | None = None):
+    """Backwards-compatible wrapper. Returns a PipelineRunner."""
+    return get_pipeline(checkpoint=checkpoint, start_from=start_from)
diff --git a/detect/profiles/base.py b/detect/profiles/base.py
index b0c4c03..f5ab78e 100644
--- a/detect/profiles/base.py
+++ b/detect/profiles/base.py
@@ -1,17 +1,20 @@
 """
 ContentTypeProfile protocol and config dataclasses.
 
-The pipeline graph is fixed — what varies per content type is configuration
-and hooks. Each profile provides stage configs, a brand dictionary,
-VLM prompt templates, and an aggregation strategy.
+Each profile defines the pipeline topology (as a JSONB blob), stage configs,
+brand dictionary, VLM prompt templates, and aggregation strategy.
+
+When profiles are persisted, the pipeline field is a JSONB column.
+For now, profiles are code-only and pipeline_config() returns a hardcoded value.
 """
 
 from __future__ import annotations
 
 from dataclasses import dataclass, field
-from typing import Protocol
+from typing import Any, Dict, Protocol
 
 from detect.models import BrandDetection, DetectionReport
+from core.schema.models.pipeline_config import PipelineConfig, StageRef, Edge
 
 
 @dataclass
@@ -64,9 +67,24 @@ class CropContext:
     position_hint: str = ""
 
 
+def pipeline_config_from_dict(data: Dict[str, Any]) -> PipelineConfig:
+    """Deserialize a PipelineConfig from a JSONB dict."""
+    stages = [StageRef(**s) for s in data.get("stages", [])]
+    edges = [Edge(**e) for e in data.get("edges", [])]
+    return PipelineConfig(
+        name=data.get("name", ""),
+        profile_name=data.get("profile_name", ""),
+        stages=stages,
+        edges=edges,
+        routing_rules=data.get("routing_rules", {}),
+    )
+
+
 class ContentTypeProfile(Protocol):
     name: str
 
+    pipeline: Dict[str, Any]  # JSONB blob — PipelineConfig shape
+    def pipeline_config(self) -> PipelineConfig: ...
     def frame_extraction_config(self) -> FrameExtractionConfig: ...
     def scene_filter_config(self) -> SceneFilterConfig: ...
     def region_analysis_config(self) -> RegionAnalysisConfig: ...
diff --git a/detect/profiles/soccer.py b/detect/profiles/soccer.py index 60e2a66..c35727f 100644 --- a/detect/profiles/soccer.py +++ b/detect/profiles/soccer.py @@ -2,6 +2,7 @@ from __future__ import annotations +from core.schema.models.pipeline_config import PipelineConfig from detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats from .base import ( @@ -12,12 +13,46 @@ from .base import ( RegionAnalysisConfig, ResolverConfig, SceneFilterConfig, + pipeline_config_from_dict, ) class SoccerBroadcastProfile: name = "soccer_broadcast" + # Pipeline topology as JSONB — will be a DB field when profiles are persisted + pipeline = { + "name": "soccer_broadcast", + "profile_name": "soccer_broadcast", + "stages": [ + {"name": "extract_frames", "branch": "trunk"}, + {"name": "filter_scenes", "branch": "trunk"}, + {"name": "detect_edges", "branch": "hoarding"}, + {"name": "detect_objects", "branch": "objects"}, + {"name": "preprocess"}, + {"name": "run_ocr"}, + {"name": "match_brands"}, + {"name": "escalate_vlm"}, + {"name": "escalate_cloud"}, + {"name": "compile_report"}, + ], + "edges": [ + {"source": "extract_frames", "target": "filter_scenes"}, + {"source": "filter_scenes", "target": "detect_edges"}, + {"source": "filter_scenes", "target": "detect_objects"}, + {"source": "detect_edges", "target": "preprocess"}, + {"source": "detect_objects", "target": "preprocess"}, + {"source": "preprocess", "target": "run_ocr"}, + {"source": "run_ocr", "target": "match_brands"}, + {"source": "match_brands", "target": "escalate_vlm"}, + {"source": "escalate_vlm", "target": "escalate_cloud"}, + {"source": "escalate_cloud", "target": "compile_report"}, + ], + } + + def pipeline_config(self) -> PipelineConfig: + return pipeline_config_from_dict(self.pipeline) + def frame_extraction_config(self) -> FrameExtractionConfig: return FrameExtractionConfig(fps=2.0, max_frames=500) diff --git a/tests/detect/test_graph.py b/tests/detect/test_graph.py index 1b7455d..61ac5eb 100644 --- a/tests/detect/test_graph.py +++ b/tests/detect/test_graph.py @@ -1,4 +1,4 @@ -"""Tests for the LangGraph detection pipeline.""" +"""Tests for the detection pipeline runner.""" import os @@ -33,9 +33,9 @@ def test_graph_compiles(): def test_graph_has_all_nodes(): - graph = build_graph() + runner = build_graph() for node in NODES: - assert node in graph.nodes + assert node in runner.stage_sequence @requires_inference diff --git a/ui/detection-app/src/panels/PipelineGraphPanel.vue b/ui/detection-app/src/panels/PipelineGraphPanel.vue index ecd63ea..87bf4f7 100644 --- a/ui/detection-app/src/panels/PipelineGraphPanel.vue +++ b/ui/detection-app/src/panels/PipelineGraphPanel.vue @@ -1,10 +1,17 @@
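
A quick way to exercise the new entry points end to end, as a minimal sketch:
it assumes a working detect environment (checkpoint=True additionally needs
the Timeline/MinIO backends behind detect.checkpoint), and the "demo-job" id
and sample.mp4 path are placeholders. Note that invoke() runs the real node
functions, so use a short clip.

    from detect.graph import get_pipeline
    from detect.checkpoint import reset_checkpoint_state

    # Stage order comes from _flatten_config: Kahn's topological sort over
    # the profile's edges, tie-broken by the declared order of config.stages.
    runner = get_pipeline(checkpoint=True, profile_name="soccer_broadcast")
    print(runner.stage_sequence)
    # ['extract_frames', 'filter_scenes', 'detect_edges', 'detect_objects',
    #  'preprocess', 'run_ocr', 'match_brands', 'escalate_vlm',
    #  'escalate_cloud', 'compile_report']

    # invoke() runs each stage's node function and merges its result dict
    # into the state it returns.
    state = runner.invoke({"job_id": "demo-job", "video_path": "sample.mp4"})

    # The runner never resets the bridge's per-job caches itself, so clear
    # the frame-manifest/checkpoint-chain state once the job is done.
    reset_checkpoint_state("demo-job")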