phase 3
@@ -64,6 +64,22 @@ def list_profiles():
     return [{"name": name} for name in _PROFILES]
 
 
+@router.get("/config/profiles/{profile_name}/pipeline")
+def get_pipeline_config(profile_name: str):
+    """Return the pipeline composition for a profile."""
+    from detect.profiles import get_profile
+    from fastapi import HTTPException
+    from dataclasses import asdict
+
+    try:
+        profile = get_profile(profile_name)
+    except ValueError:
+        raise HTTPException(status_code=404, detail=f"Unknown profile: {profile_name}")
+
+    config = profile.pipeline_config()
+    return asdict(config)
+
+
 @router.get("/config/stages", response_model=list[StageConfigInfo])
 def list_stage_configs():
     """Return the stage palette with config field metadata for the editor."""
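Note: a minimal sketch of calling the new endpoint from a client. The /api/detect prefix is taken from the UI code further down; the host, port, and use of the requests library are assumptions, not part of this commit.

import requests

# GET the pipeline composition for one profile; the body is asdict(PipelineConfig)
resp = requests.get("http://localhost:8000/api/detect/config/profiles/soccer_broadcast/pipeline")
resp.raise_for_status()
config = resp.json()
print(config["profile_name"])                      # "soccer_broadcast"
print([s["name"] for s in config["stages"]][:2])   # ["extract_frames", "filter_scenes"]
print(config["routing_rules"])                     # {} until routing rules are populated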
@@ -88,7 +88,7 @@ def run_pipeline(req: RunRequest):
         log_level=req.log_level,
     )
 
-    pipeline = get_pipeline(checkpoint=req.checkpoint)
+    pipeline = get_pipeline(checkpoint=req.checkpoint, profile_name=req.profile_name)
 
     initial_state = DetectState(
         video_path=local_path,
@@ -35,6 +35,7 @@ from .detect import DETECT_VIEWS  # noqa: F401 — discovered by modelgen generi
 from .inference import INFERENCE_VIEWS  # noqa: F401 — GPU inference server API types
 from .ui_state import UI_STATE_VIEWS  # noqa: F401 — UI store state types
 from .stages import StageConfigField, StageIO, StageDefinition, STAGE_VIEWS  # noqa: F401
+from .pipeline_config import StageRef, Edge, PipelineConfig, PIPELINE_CONFIG_VIEWS  # noqa: F401
 from .detect_api import RunRequest, RunResponse, DETECT_API_VIEWS  # noqa: F401
 from .views import ChunkEvent, ChunkOutputFile, PipelineStats, WorkerEvent
 from .sources import ChunkInfo, SourceJob, SourceType
core/schema/models/pipeline_config.py (new file, 46 lines)
@@ -0,0 +1,46 @@
+"""
+Pipeline composition config — source of truth for graph topology.
+
+Defines what stages run, in what order, with what branching.
+Belongs to a profile. Persisted as JSONB.
+
+The execution strategy (serial, parallel, distributed) is separate —
+the runner reads this config and flattens it into a sequence for now.
+"""
+
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
+
+
+@dataclass
+class StageRef:
+    """Reference to a stage in the pipeline graph."""
+    name: str  # stage name (matches StageDefinition.name)
+    branch: str = "trunk"  # which branch this belongs to
+    execution_target: str = "local"  # local | gpu | lambda | gcp
+
+
+@dataclass
+class Edge:
+    """Connection between stages in the graph."""
+    source: str  # stage name
+    target: str  # stage name
+    condition: str = ""  # empty = unconditional, otherwise a routing rule key
+
+
+@dataclass
+class PipelineConfig:
+    """
+    Pipeline graph topology + routing rules.
+
+    Holder model — stages/edges define the graph shape,
+    routing_rules is a JSONB blob for decision tree logic.
+    """
+    name: str
+    profile_name: str
+    stages: List[StageRef] = field(default_factory=list)
+    edges: List[Edge] = field(default_factory=list)
+    routing_rules: Dict[str, Any] = field(default_factory=dict)
+
+
+PIPELINE_CONFIG_VIEWS = [StageRef, Edge, PipelineConfig]
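Note: a quick illustration (not part of the commit) of how these dataclasses serialize. The stage names are made up; dataclasses.asdict recurses into the nested StageRef/Edge lists, which is what the API endpoint above relies on.

from dataclasses import asdict
from core.schema.models.pipeline_config import PipelineConfig, StageRef, Edge

# Hypothetical two-stage config
cfg = PipelineConfig(
    name="demo",
    profile_name="demo_profile",
    stages=[StageRef(name="a"), StageRef(name="b", execution_target="gpu")],
    edges=[Edge(source="a", target="b")],
)
data = asdict(cfg)
assert data["stages"][1]["execution_target"] == "gpu"
assert data["edges"][0] == {"source": "a", "target": "b", "condition": ""}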
@@ -5,6 +5,7 @@ Checkpoint system — Timeline + Checkpoint tree.
 frames.py — frame image S3 upload/download
 storage.py — Timeline + Checkpoint (Postgres + MinIO)
 replay.py — replay (TODO: migrate to new model)
+runner_bridge.py — checkpoint hook for PipelineRunner
 """
 
 from .storage import (
@@ -15,3 +16,4 @@ from .storage import (
     load_stage_output,
 )
 from .frames import save_frames, load_frames
+from .runner_bridge import checkpoint_after_stage, reset_checkpoint_state
detect/checkpoint/runner_bridge.py (new file, 64 lines)
@@ -0,0 +1,64 @@
+"""
+Runner bridge — checkpoint hook called by PipelineRunner after each stage.
+
+Owns the per-job state (frame manifest cache, checkpoint chain) that
+the runner shouldn't know about.
+"""
+
+from __future__ import annotations
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+# Per-job state
+_frames_manifest: dict[str, dict[int, str]] = {}
+_latest_checkpoint: dict[str, str] = {}
+
+
+def reset_checkpoint_state(job_id: str):
+    """Clean up per-job checkpoint state. Called when pipeline finishes."""
+    _frames_manifest.pop(job_id, None)
+    _latest_checkpoint.pop(job_id, None)
+
+
+def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: dict):
+    """
+    Save a checkpoint after a stage completes.
+
+    Called by the runner. Handles:
+    - Frame upload (once, on first stage)
+    - Stage output serialization (via stage registry)
+    - Checkpoint chain (parent → child)
+    """
+    if not job_id:
+        return
+
+    from .storage import save_stage_output
+    from .frames import save_frames
+    from detect.stages.base import _REGISTRY
+
+    merged = {**state, **result}
+
+    # Save frames once (first stage that produces them)
+    manifest = _frames_manifest.get(job_id)
+    if manifest is None and stage_name == "extract_frames":
+        manifest = save_frames(job_id, merged.get("frames", []))
+        _frames_manifest[job_id] = manifest
+
+    # Serialize stage output using the stage's serialize_fn if available
+    stage_cls = _REGISTRY.get(stage_name)
+    serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
+    if serialize_fn:
+        output_json = serialize_fn(merged, job_id)
+    else:
+        output_json = {}
+
+    parent_id = _latest_checkpoint.get(job_id)
+    new_checkpoint_id = save_stage_output(
+        timeline_id=job_id,
+        parent_checkpoint_id=parent_id,
+        stage_name=stage_name,
+        output_json=output_json,
+    )
+    _latest_checkpoint[job_id] = new_checkpoint_id
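Note: a rough sketch (not from this commit) of the call pattern the bridge expects — checkpoint_after_stage once per completed stage, reset_checkpoint_state when the job is done. The surrounding loop and the node_fns mapping are hypothetical stand-ins for the runner.

from detect.checkpoint import checkpoint_after_stage, reset_checkpoint_state

def run_with_checkpoints(job_id: str, stage_sequence: list[str], node_fns: dict, state: dict) -> dict:
    try:
        for stage_name in stage_sequence:
            result = node_fns[stage_name](state)   # run the stage
            state.update(result)                   # merge its output
            checkpoint_after_stage(job_id, stage_name, state, result)
        return state
    finally:
        reset_checkpoint_state(job_id)             # drop the per-job manifest/checkpoint cache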
@@ -4,12 +4,13 @@ Detection pipeline graph.
 detect/graph/
 nodes.py — node functions (one per stage)
 events.py — graph_update SSE emission
-runner.py — pipeline execution (LangGraph wrapper, checkpoint, cancel, pause)
+runner.py — PipelineRunner (config-driven, checkpoint, cancel, pause)
 """
 
 from .nodes import NODES, NODE_FUNCTIONS
 from .runner import (
     PipelineCancelled,
+    PipelineRunner,
     build_graph,
     clear_cancel_check,
     clear_pause,
@@ -28,6 +29,7 @@ __all__ = [
     "NODES",
     "NODE_FUNCTIONS",
     "PipelineCancelled",
+    "PipelineRunner",
     "build_graph",
     "get_pipeline",
     "set_cancel_check",
@@ -1,8 +1,10 @@
 """
-Pipeline runner — executes stages sequentially with checkpointing and cancellation.
+Pipeline runner — executes stages sequentially with checkpointing,
+cancellation, and pause/resume.
 
-Currently wraps LangGraph for execution. Will be replaced with a lean
-custom runner in Phase 3, with an executor socket for distributed dispatch.
+Reads PipelineConfig from the profile to determine what stages to run.
+Flattens the graph into a linear sequence for now (serial execution).
+Executor socket: all stages run via LocalExecutor (call function directly).
 """
 
 from __future__ import annotations
@@ -11,19 +13,14 @@ import logging
 import os
 import threading
 
-from langgraph.graph import END, StateGraph
+from core.schema.models.pipeline_config import PipelineConfig
 
 from detect.state import DetectState
 from .nodes import NODES, NODE_FUNCTIONS
 
 logger = logging.getLogger(__name__)
 
-
-# --- Checkpoint wrapper ---
-
 _CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
-_frames_manifest: dict[str, dict[int, str]] = {}  # job_id → manifest (cached per job)
-_latest_checkpoint: dict[str, str] = {}  # job_id → latest checkpoint_id
 
 
 class PipelineCancelled(Exception):
@@ -53,10 +50,6 @@ def clear_cancel_check(job_id: str):
 
 # ---------------------------------------------------------------------------
 # Pause / Resume / Step — checked after each node completes
-#
-# _pause_gate: threading.Event per job. When cleared, the runner blocks.
-# When set, the runner proceeds to the next node.
-# _pause_after_stage: if True, automatically clear the gate after each node.
 # ---------------------------------------------------------------------------
 
 _pause_gate: dict[str, threading.Event] = {}
@@ -98,7 +91,7 @@ def step_pipeline(job_id: str):
     _pause_after_stage[job_id] = True
     gate = _pause_gate.get(job_id)
     if gate:
-        gate.set()  # unblock for one stage, _pause_after_stage re-pauses after
+        gate.set()
     logger.info("Pipeline %s stepping", job_id)
 
 
@@ -106,7 +99,6 @@ def set_pause_after_stage(job_id: str, enabled: bool):
     """Toggle pause-after-each-stage mode."""
     _pause_after_stage[job_id] = enabled
     if not enabled:
-        # If disabling, also resume in case we're currently paused
         gate = _pause_gate.get(job_id)
         if gate:
             gate.set()
@@ -124,106 +116,159 @@ def _wait_if_paused(job_id: str, node_name: str):
     if gate is None:
         return
 
-    # If pause-after-stage is on, pause now
     if _pause_after_stage.get(job_id, False):
         gate.clear()
         from detect import emit
         emit.log(job_id, "Pipeline", "INFO", f"Paused after {node_name}")
 
-    # Block until gate is set (resume/step) or cancelled
     while not gate.wait(timeout=0.5):
         check = _cancel_check.get(job_id)
         if check and check():
             raise PipelineCancelled(f"Cancelled while paused before next stage")
 
 
-def _checkpointing_node(node_name: str, node_fn):
-    """Wrap a node function to auto-checkpoint after completion."""
-
-    def wrapper(state: DetectState) -> dict:
-        result = node_fn(state)
-
-        job_id = state.get("job_id", "")
-        if not job_id:
-            return result
-
-        from detect.checkpoint import save_stage_output, save_frames
-        from detect.stages.base import _REGISTRY
-
-        merged = {**state, **result}
-
-        # Save frames once (first node), reuse manifest after
-        manifest = _frames_manifest.get(job_id)
-        if manifest is None and node_name == "extract_frames":
-            manifest = save_frames(job_id, merged.get("frames", []))
-            _frames_manifest[job_id] = manifest
-
-        # Serialize stage output using the stage's serialize_fn if available
-        stage_cls = _REGISTRY.get(node_name)
-        serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
-        if serialize_fn:
-            output_json = serialize_fn(merged, job_id)
-        else:
-            output_json = {}
-
-        parent_id = _latest_checkpoint.get(job_id)
-        new_checkpoint_id = save_stage_output(
-            timeline_id=job_id,
-            parent_checkpoint_id=parent_id,
-            stage_name=node_name,
-            output_json=output_json,
-        )
-        _latest_checkpoint[job_id] = new_checkpoint_id
-
-        # Pause check — blocks if paused, respects cancel while waiting
-        _wait_if_paused(job_id, node_name)
-
-        return result
-
-    wrapper.__name__ = node_fn.__name__
-    return wrapper
-
-
-# --- Graph construction ---
-
-def build_graph(checkpoint: bool | None = None, start_from: str | None = None) -> StateGraph:
-    """
-    Build the pipeline graph.
-
-    checkpoint: enable auto-checkpointing (default: MPR_CHECKPOINT env var)
-    start_from: skip nodes before this stage (for replay)
-    """
-    do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
-
-    graph = StateGraph(DetectState)
-
-    # Filter to start_from if replaying
-    node_pairs = NODE_FUNCTIONS
-    if start_from:
-        start_idx = next(i for i, (name, _) in enumerate(NODE_FUNCTIONS) if name == start_from)
-        node_pairs = NODE_FUNCTIONS[start_idx:]
-
-    for name, fn in node_pairs:
-        wrapped = _checkpointing_node(name, fn) if do_checkpoint else fn
-        graph.add_node(name, wrapped)
-
-    # Wire edges
-    entry = node_pairs[0][0]
-    graph.set_entry_point(entry)
-
-    for i in range(len(node_pairs) - 1):
-        graph.add_edge(node_pairs[i][0], node_pairs[i + 1][0])
-
-    graph.add_edge(node_pairs[-1][0], END)
-
-    return graph
-
-
-def get_pipeline(checkpoint: bool | None = None):
-    """Return a compiled, runnable pipeline."""
-    return build_graph(checkpoint=checkpoint).compile()
+# ---------------------------------------------------------------------------
+# Pipeline Runner
+# ---------------------------------------------------------------------------
+
+# Node function lookup — maps stage name to callable
+_NODE_FN_MAP: dict[str, callable] = {name: fn for name, fn in NODE_FUNCTIONS}
+
+
+def _flatten_config(config: PipelineConfig, start_from: str | None = None) -> list[str]:
+    """
+    Flatten a PipelineConfig into a linear stage sequence.
+
+    For now: topological sort via edges. Falls back to stage order if no edges.
+    Respects start_from for replay (skip stages before it).
+    """
+    if not config.edges:
+        # No edges defined — use stage order as-is
+        names = [s.name for s in config.stages]
+    else:
+        # Topological sort from edges
+        graph: dict[str, list[str]] = {}
+        in_degree: dict[str, int] = {}
+        stage_names = {s.name for s in config.stages}
+
+        for name in stage_names:
+            graph[name] = []
+            in_degree[name] = 0
+
+        for edge in config.edges:
+            if edge.source in stage_names and edge.target in stage_names:
+                graph[edge.source].append(edge.target)
+                in_degree[edge.target] = in_degree.get(edge.target, 0) + 1
+
+        # Kahn's algorithm
+        queue = [n for n in stage_names if in_degree.get(n, 0) == 0]
+        # Stable sort: prefer order from config.stages
+        stage_order = {s.name: i for i, s in enumerate(config.stages)}
+        queue.sort(key=lambda n: stage_order.get(n, 999))
+
+        names = []
+        while queue:
+            node = queue.pop(0)
+            names.append(node)
+            for neighbor in graph.get(node, []):
+                in_degree[neighbor] -= 1
+                if in_degree[neighbor] == 0:
+                    queue.append(neighbor)
+            queue.sort(key=lambda n: stage_order.get(n, 999))
+
+    if start_from:
+        try:
+            idx = names.index(start_from)
+            names = names[idx:]
+        except ValueError:
+            raise ValueError(f"Stage {start_from!r} not in pipeline config")
+
+    return names
+
+
+class PipelineRunner:
+    """
+    Executes a pipeline defined by PipelineConfig.
+
+    Runs stages sequentially (flattened). Each stage:
+    1. Check cancel
+    2. Run node function (via executor — local for now)
+    3. Merge result into state
+    4. Checkpoint (if enabled)
+    5. Check pause
+
+    Executor socket: currently calls node functions directly.
+    Future: dispatch to LocalExecutor / GrpcExecutor / LambdaExecutor
+    based on StageRef.execution_target.
+    """
+
+    def __init__(
+        self,
+        config: PipelineConfig,
+        checkpoint: bool = False,
+        start_from: str | None = None,
+    ):
+        self.config = config
+        self.do_checkpoint = checkpoint
+        self.stage_sequence = _flatten_config(config, start_from)
+
+    def invoke(self, state: DetectState) -> DetectState:
+        """Run the pipeline on the given state. Returns final state."""
+        for stage_name in self.stage_sequence:
+            job_id = state.get("job_id", "")
+
+            # 1. Cancel check
+            check = _cancel_check.get(job_id)
+            if check and check():
+                raise PipelineCancelled(f"Cancelled before {stage_name}")
+
+            # 2. Run node function
+            node_fn = _NODE_FN_MAP.get(stage_name)
+            if node_fn is None:
+                logger.warning("No node function for stage %s, skipping", stage_name)
+                continue
+
+            result = node_fn(state)
+
+            # 3. Merge result into state
+            state.update(result)
+
+            # 4. Checkpoint
+            if self.do_checkpoint:
+                from detect.checkpoint import checkpoint_after_stage
+                checkpoint_after_stage(job_id, stage_name, state, result)
+
+            # 5. Pause check
+            _wait_if_paused(job_id, stage_name)
+
+        return state
+
+
+# ---------------------------------------------------------------------------
+# Public API — backwards compatible with old get_pipeline/build_graph
+# ---------------------------------------------------------------------------
+
+def get_pipeline(
+    checkpoint: bool | None = None,
+    profile_name: str = "soccer_broadcast",
+    start_from: str | None = None,
+) -> PipelineRunner:
+    """Return a PipelineRunner for the given profile."""
+    from detect.profiles import get_profile
+
+    do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
+    profile = get_profile(profile_name)
+    config = profile.pipeline_config()
+
+    return PipelineRunner(
+        config=config,
+        checkpoint=do_checkpoint,
+        start_from=start_from,
+    )
+
+
+def build_graph(checkpoint: bool | None = None, start_from: str | None = None):
+    """Backwards-compatible wrapper. Returns a PipelineRunner."""
+    return get_pipeline(checkpoint=checkpoint, start_from=start_from)
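Note: a hedged illustration of what _flatten_config does when edges are present — Kahn's algorithm orders stages by their dependencies rather than by their position in the stages list, and start_from trims the front of the sequence. The three stage names are made up.

from core.schema.models.pipeline_config import PipelineConfig, StageRef, Edge
from detect.graph.runner import _flatten_config

cfg = PipelineConfig(
    name="demo",
    profile_name="demo",
    stages=[StageRef(name="extract"), StageRef(name="report"), StageRef(name="ocr")],
    edges=[Edge(source="extract", target="ocr"), Edge(source="ocr", target="report")],
)
print(_flatten_config(cfg))                     # ["extract", "ocr", "report"]
print(_flatten_config(cfg, start_from="ocr"))   # ["ocr", "report"]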
@@ -1,17 +1,20 @@
 """
 ContentTypeProfile protocol and config dataclasses.
 
-The pipeline graph is fixed — what varies per content type is configuration
-and hooks. Each profile provides stage configs, a brand dictionary,
-VLM prompt templates, and an aggregation strategy.
+Each profile defines the pipeline topology (as a JSONB blob), stage configs,
+brand dictionary, VLM prompt templates, and aggregation strategy.
+
+When profiles are persisted, the pipeline field is a JSONB column.
+For now, profiles are code-only and pipeline_config() returns a hardcoded value.
 """
 
 from __future__ import annotations
 
 from dataclasses import dataclass, field
-from typing import Protocol
+from typing import Any, Dict, Protocol
 
 from detect.models import BrandDetection, DetectionReport
+from core.schema.models.pipeline_config import PipelineConfig, StageRef, Edge
 
 
 @dataclass
@@ -64,9 +67,24 @@ class CropContext:
     position_hint: str = ""
 
 
+def pipeline_config_from_dict(data: Dict[str, Any]) -> PipelineConfig:
+    """Deserialize a PipelineConfig from a JSONB dict."""
+    stages = [StageRef(**s) for s in data.get("stages", [])]
+    edges = [Edge(**e) for e in data.get("edges", [])]
+    return PipelineConfig(
+        name=data.get("name", ""),
+        profile_name=data.get("profile_name", ""),
+        stages=stages,
+        edges=edges,
+        routing_rules=data.get("routing_rules", {}),
+    )
+
+
 class ContentTypeProfile(Protocol):
     name: str
+    pipeline: Dict[str, Any]  # JSONB blob — PipelineConfig shape
+
+    def pipeline_config(self) -> PipelineConfig: ...
     def frame_extraction_config(self) -> FrameExtractionConfig: ...
     def scene_filter_config(self) -> SceneFilterConfig: ...
     def region_analysis_config(self) -> RegionAnalysisConfig: ...
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+from core.schema.models.pipeline_config import PipelineConfig
 from detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats
 
 from .base import (
@@ -12,12 +13,46 @@ from .base import (
     RegionAnalysisConfig,
     ResolverConfig,
     SceneFilterConfig,
+    pipeline_config_from_dict,
 )
 
 
 class SoccerBroadcastProfile:
     name = "soccer_broadcast"
 
+    # Pipeline topology as JSONB — will be a DB field when profiles are persisted
+    pipeline = {
+        "name": "soccer_broadcast",
+        "profile_name": "soccer_broadcast",
+        "stages": [
+            {"name": "extract_frames", "branch": "trunk"},
+            {"name": "filter_scenes", "branch": "trunk"},
+            {"name": "detect_edges", "branch": "hoarding"},
+            {"name": "detect_objects", "branch": "objects"},
+            {"name": "preprocess"},
+            {"name": "run_ocr"},
+            {"name": "match_brands"},
+            {"name": "escalate_vlm"},
+            {"name": "escalate_cloud"},
+            {"name": "compile_report"},
+        ],
+        "edges": [
+            {"source": "extract_frames", "target": "filter_scenes"},
+            {"source": "filter_scenes", "target": "detect_edges"},
+            {"source": "filter_scenes", "target": "detect_objects"},
+            {"source": "detect_edges", "target": "preprocess"},
+            {"source": "detect_objects", "target": "preprocess"},
+            {"source": "preprocess", "target": "run_ocr"},
+            {"source": "run_ocr", "target": "match_brands"},
+            {"source": "match_brands", "target": "escalate_vlm"},
+            {"source": "escalate_vlm", "target": "escalate_cloud"},
+            {"source": "escalate_cloud", "target": "compile_report"},
+        ],
+    }
+
+    def pipeline_config(self) -> PipelineConfig:
+        return pipeline_config_from_dict(self.pipeline)
+
     def frame_extraction_config(self) -> FrameExtractionConfig:
         return FrameExtractionConfig(fps=2.0, max_frames=500)
 
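Note: a small sketch tying the pieces together. It assumes the profile registry exposes SoccerBroadcastProfile through get_profile (as the endpoint and get_pipeline defaults above suggest); the printed values follow from the pipeline dict and the StageRef/Edge defaults.

from detect.profiles import get_profile

profile = get_profile("soccer_broadcast")
config = profile.pipeline_config()      # pipeline_config_from_dict(profile.pipeline)
print(type(config).__name__)            # PipelineConfig
print(config.stages[0].branch)          # "trunk" (explicit in the dict)
print(config.stages[4].branch)          # "trunk" (StageRef default when omitted)
print(config.edges[1].condition)        # "" — unconditional edge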
@@ -1,4 +1,4 @@
-"""Tests for the LangGraph detection pipeline."""
+"""Tests for the detection pipeline runner."""
 
 import os
 
@@ -33,9 +33,9 @@ def test_graph_compiles():
 
 
 def test_graph_has_all_nodes():
-    graph = build_graph()
+    runner = build_graph()
     for node in NODES:
-        assert node in graph.nodes
+        assert node in runner.stage_sequence
 
 
 @requires_inference
@@ -1,10 +1,17 @@
 <script setup lang="ts">
-import { ref, watch, computed } from 'vue'
+import { ref, watch, computed, onMounted } from 'vue'
 import { Panel, GraphRenderer } from 'mpr-ui-framework'
 import type { GraphNode, GraphMode, DataSource } from 'mpr-ui-framework'
 import { usePipelineStore } from '../stores/pipeline'
 import { useStageRegistry } from '../composables/useStageRegistry'
 
+interface PipelineConfigResponse {
+  name: string
+  profile_name: string
+  stages: { name: string; branch: string; execution_target: string }[]
+  edges: { source: string; target: string; condition: string }[]
+}
+
 const props = defineProps<{
   source: DataSource
   status?: 'idle' | 'live' | 'processing' | 'error'
@@ -14,6 +21,19 @@ const pipeline = usePipelineStore()
 const { stageNames, editableStages } = useStageRegistry()
 
 const nodes = ref<GraphNode[]>([])
+const pipelineConfig = ref<PipelineConfigResponse | null>(null)
+
+// Fetch pipeline config for a profile
+async function fetchPipelineConfig(profileName: string) {
+  try {
+    const resp = await fetch(`/api/detect/config/profiles/${profileName}/pipeline`)
+    if (!resp.ok) return
+    pipelineConfig.value = await resp.json()
+  } catch { /* ignore */ }
+}
+
+// Load default profile config on mount
+onMounted(() => fetchPipelineConfig('soccer_broadcast'))
 
 // Derive graph mode from pipeline layout mode
 const graphMode = computed<GraphMode>(() => {
@@ -23,27 +43,20 @@ const graphMode = computed<GraphMode>(() => {
   return 'observe'
 })
 
-// Initialize nodes from registry when it loads
+// Initialize nodes from pipeline config when it loads
+watch(pipelineConfig, (config) => {
+  if (config && config.stages.length > 0 && nodes.value.length === 0) {
+    nodes.value = config.stages.map((s) => ({ id: s.name, status: 'pending' }))
+  }
+}, { immediate: true })
+
+// Fallback: init from registry if no config loaded
 watch(stageNames, (names) => {
   if (names.length > 0 && nodes.value.length === 0) {
     nodes.value = names.map((id) => ({ id, status: 'pending' }))
   }
 }, { immediate: true })
 
-// Source selector: placeholders until a chunk is selected, then real stage names
-const displayNodes = computed<GraphNode[]>(() => {
-  if (pipeline.layoutMode === 'source_selector') {
-    if (pipeline.sourceHasSelection) {
-      return stageNames.value.map((id) => ({ id, status: 'pending' as const }))
-    }
-    return Array.from({ length: 10 }, (_, i) => ({
-      id: `_placeholder_${i}`,
-      status: 'placeholder' as const,
-    }))
-  }
-  return nodes.value
-})
-
 props.source.on<{ nodes: GraphNode[] }>('graph_update', (e) => {
   nodes.value = e.nodes
 })
@@ -63,6 +76,25 @@ props.source.on<{ report?: { status?: string } }>('job_complete', (e) => {
   }
 })
 
+// Source selector: placeholders until a chunk is selected, then real stage names
+const configStageNames = computed(() =>
+  pipelineConfig.value?.stages.map(s => s.name) ?? stageNames.value
+)
+
+const displayNodes = computed<GraphNode[]>(() => {
+  if (pipeline.layoutMode === 'source_selector') {
+    if (pipeline.sourceHasSelection) {
+      return configStageNames.value.map((id) => ({ id, status: 'pending' as const }))
+    }
+    const count = configStageNames.value.length || 10
+    return Array.from({ length: count }, (_, i) => ({
+      id: `_placeholder_${i}`,
+      status: 'placeholder' as const,
+    }))
+  }
+  return nodes.value
+})
+
 function onOpenRegionEditor(stage: string) {
   pipeline.openBBoxEditor(stage)
 }