This commit is contained in:
2026-03-28 10:05:59 -03:00
parent e46bbc419c
commit d0707333fd
12 changed files with 381 additions and 120 deletions

View File

@@ -5,6 +5,7 @@ Checkpoint system — Timeline + Checkpoint tree.
frames.py — frame image S3 upload/download
storage.py — Timeline + Checkpoint (Postgres + MinIO)
replay.py — replay (TODO: migrate to new model)
runner_bridge.py — checkpoint hook for PipelineRunner
"""
from .storage import (
@@ -15,3 +16,4 @@ from .storage import (
load_stage_output,
)
from .frames import save_frames, load_frames
from .runner_bridge import checkpoint_after_stage, reset_checkpoint_state

View File

@@ -0,0 +1,64 @@
"""
Runner bridge — checkpoint hook called by PipelineRunner after each stage.
Owns the per-job state (frame manifest cache, checkpoint chain) that
the runner shouldn't know about.
"""
from __future__ import annotations
import logging
logger = logging.getLogger(__name__)
# Per-job state
_frames_manifest: dict[str, dict[int, str]] = {}
_latest_checkpoint: dict[str, str] = {}
def reset_checkpoint_state(job_id: str):
"""Clean up per-job checkpoint state. Called when pipeline finishes."""
_frames_manifest.pop(job_id, None)
_latest_checkpoint.pop(job_id, None)
def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: dict):
    """
    Persist a checkpoint once *stage_name* has completed.

    Called by the runner after each stage. Responsibilities:
      - upload frames exactly once per job (on the extract_frames stage)
      - serialize the stage output via the stage registry's serialize_fn
      - extend the checkpoint chain (parent -> child)
    """
    if not job_id:
        return

    # Imports kept local so the runner stays import-light when
    # checkpointing is disabled.
    from .storage import save_stage_output
    from .frames import save_frames
    from detect.stages.base import _REGISTRY

    merged = {**state, **result}

    # Frames are uploaded only once per job, on the stage that produces them.
    if stage_name == "extract_frames" and _frames_manifest.get(job_id) is None:
        _frames_manifest[job_id] = save_frames(job_id, merged.get("frames", []))

    # Ask the registered stage's definition for a serializer, if it has one.
    definition = getattr(_REGISTRY.get(stage_name), "definition", None)
    serialize_fn = getattr(definition, "serialize_fn", None)
    output_json = serialize_fn(merged, job_id) if serialize_fn else {}

    # New checkpoint's parent is the job's previous latest checkpoint.
    _latest_checkpoint[job_id] = save_stage_output(
        timeline_id=job_id,
        parent_checkpoint_id=_latest_checkpoint.get(job_id),
        stage_name=stage_name,
        output_json=output_json,
    )

View File

@@ -4,12 +4,13 @@ Detection pipeline graph.
detect/graph/
nodes.py — node functions (one per stage)
events.py — graph_update SSE emission
runner.py — pipeline execution (LangGraph wrapper, checkpoint, cancel, pause)
runner.py — PipelineRunner (config-driven, checkpoint, cancel, pause)
"""
from .nodes import NODES, NODE_FUNCTIONS
from .runner import (
PipelineCancelled,
PipelineRunner,
build_graph,
clear_cancel_check,
clear_pause,
@@ -28,6 +29,7 @@ __all__ = [
"NODES",
"NODE_FUNCTIONS",
"PipelineCancelled",
"PipelineRunner",
"build_graph",
"get_pipeline",
"set_cancel_check",

View File

@@ -1,8 +1,10 @@
"""
Pipeline runner — executes stages sequentially with checkpointing and cancellation.
Pipeline runner — executes stages sequentially with checkpointing,
cancellation, and pause/resume.
Currently wraps LangGraph for execution. Will be replaced with a lean
custom runner in Phase 3, with an executor socket for distributed dispatch.
Reads PipelineConfig from the profile to determine what stages to run.
Flattens the graph into a linear sequence for now (serial execution).
Executor socket: all stages run via LocalExecutor (call function directly).
"""
from __future__ import annotations
@@ -11,19 +13,14 @@ import logging
import os
import threading
from langgraph.graph import END, StateGraph
from core.schema.models.pipeline_config import PipelineConfig
from detect.state import DetectState
from .nodes import NODES, NODE_FUNCTIONS
logger = logging.getLogger(__name__)
# --- Checkpoint wrapper ---
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
_frames_manifest: dict[str, dict[int, str]] = {} # job_id → manifest (cached per job)
_latest_checkpoint: dict[str, str] = {} # job_id → latest checkpoint_id
class PipelineCancelled(Exception):
@@ -53,10 +50,6 @@ def clear_cancel_check(job_id: str):
# ---------------------------------------------------------------------------
# Pause / Resume / Step — checked after each node completes
#
# _pause_gate: threading.Event per job. When cleared, the runner blocks.
# When set, the runner proceeds to the next node.
# _pause_after_stage: if True, automatically clear the gate after each node.
# ---------------------------------------------------------------------------
_pause_gate: dict[str, threading.Event] = {}
@@ -98,7 +91,7 @@ def step_pipeline(job_id: str):
_pause_after_stage[job_id] = True
gate = _pause_gate.get(job_id)
if gate:
gate.set() # unblock for one stage, _pause_after_stage re-pauses after
gate.set()
logger.info("Pipeline %s stepping", job_id)
@@ -106,7 +99,6 @@ def set_pause_after_stage(job_id: str, enabled: bool):
"""Toggle pause-after-each-stage mode."""
_pause_after_stage[job_id] = enabled
if not enabled:
# If disabling, also resume in case we're currently paused
gate = _pause_gate.get(job_id)
if gate:
gate.set()
@@ -124,106 +116,159 @@ def _wait_if_paused(job_id: str, node_name: str):
if gate is None:
return
# If pause-after-stage is on, pause now
if _pause_after_stage.get(job_id, False):
gate.clear()
from detect import emit
emit.log(job_id, "Pipeline", "INFO", f"Paused after {node_name}")
# Block until gate is set (resume/step) or cancelled
while not gate.wait(timeout=0.5):
check = _cancel_check.get(job_id)
if check and check():
raise PipelineCancelled(f"Cancelled while paused before next stage")
def _checkpointing_node(node_name: str, node_fn):
    """Wrap a node function to auto-checkpoint after completion."""
    def wrapper(state: DetectState) -> dict:
        # Cancel check before running the node at all.
        job_id = state.get("job_id", "")
        check = _cancel_check.get(job_id)
        if check and check():
            raise PipelineCancelled(f"Cancelled before {node_name}")
        result = node_fn(state)
        # NOTE(review): job_id was already read above — this re-read is redundant.
        job_id = state.get("job_id", "")
        if not job_id:
            # No job id to key checkpoints on; skip persistence entirely.
            return result
        # Local imports: checkpoint dependencies only load when checkpointing runs.
        from detect.checkpoint import save_stage_output, save_frames
        from detect.stages.base import _REGISTRY
        merged = {**state, **result}
        # Save frames once (first node), reuse manifest after
        manifest = _frames_manifest.get(job_id)
        if manifest is None and node_name == "extract_frames":
            manifest = save_frames(job_id, merged.get("frames", []))
            _frames_manifest[job_id] = manifest
        # Serialize stage output using the stage's serialize_fn if available
        stage_cls = _REGISTRY.get(node_name)
        serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
        if serialize_fn:
            output_json = serialize_fn(merged, job_id)
        else:
            output_json = {}
        # Chain checkpoints: parent is the job's latest checkpoint id.
        parent_id = _latest_checkpoint.get(job_id)
        new_checkpoint_id = save_stage_output(
            timeline_id=job_id,
            parent_checkpoint_id=parent_id,
            stage_name=node_name,
            output_json=output_json,
        )
        _latest_checkpoint[job_id] = new_checkpoint_id
        # Pause check — blocks if paused, respects cancel while waiting
        _wait_if_paused(job_id, node_name)
        return result
    wrapper.__name__ = node_fn.__name__
    return wrapper
# --- Graph construction ---
# ---------------------------------------------------------------------------
# Pipeline Runner
# ---------------------------------------------------------------------------
def build_graph(checkpoint: bool | None = None, start_from: str | None = None) -> StateGraph:
# Node function lookup — maps stage name to callable
_NODE_FN_MAP: dict[str, callable] = {name: fn for name, fn in NODE_FUNCTIONS}
def _flatten_config(config: PipelineConfig, start_from: str | None = None) -> list[str]:
"""
Build the pipeline graph.
Flatten a PipelineConfig into a linear stage sequence.
checkpoint: enable auto-checkpointing (default: MPR_CHECKPOINT env var)
start_from: skip nodes before this stage (for replay)
For now: topological sort via edges. Falls back to stage order if no edges.
Respects start_from for replay (skip stages before it).
"""
do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
if not config.edges:
# No edges defined — use stage order as-is
names = [s.name for s in config.stages]
else:
# Topological sort from edges
graph: dict[str, list[str]] = {}
in_degree: dict[str, int] = {}
stage_names = {s.name for s in config.stages}
graph = StateGraph(DetectState)
for name in stage_names:
graph[name] = []
in_degree[name] = 0
for edge in config.edges:
if edge.source in stage_names and edge.target in stage_names:
graph[edge.source].append(edge.target)
in_degree[edge.target] = in_degree.get(edge.target, 0) + 1
# Kahn's algorithm
queue = [n for n in stage_names if in_degree.get(n, 0) == 0]
# Stable sort: prefer order from config.stages
stage_order = {s.name: i for i, s in enumerate(config.stages)}
queue.sort(key=lambda n: stage_order.get(n, 999))
names = []
while queue:
node = queue.pop(0)
names.append(node)
for neighbor in graph.get(node, []):
in_degree[neighbor] -= 1
if in_degree[neighbor] == 0:
queue.append(neighbor)
queue.sort(key=lambda n: stage_order.get(n, 999))
# Filter to start_from if replaying
node_pairs = NODE_FUNCTIONS
if start_from:
start_idx = next(i for i, (name, _) in enumerate(NODE_FUNCTIONS) if name == start_from)
node_pairs = NODE_FUNCTIONS[start_idx:]
try:
idx = names.index(start_from)
names = names[idx:]
except ValueError:
raise ValueError(f"Stage {start_from!r} not in pipeline config")
for name, fn in node_pairs:
wrapped = _checkpointing_node(name, fn) if do_checkpoint else fn
graph.add_node(name, wrapped)
# Wire edges
entry = node_pairs[0][0]
graph.set_entry_point(entry)
for i in range(len(node_pairs) - 1):
graph.add_edge(node_pairs[i][0], node_pairs[i + 1][0])
graph.add_edge(node_pairs[-1][0], END)
return graph
return names
def get_pipeline(checkpoint: bool | None = None):
"""Return a compiled, runnable pipeline."""
return build_graph(checkpoint=checkpoint).compile()
class PipelineRunner:
    """
    Executes a pipeline defined by PipelineConfig.

    The config is flattened to a linear sequence up front; each stage then
    goes through five steps:
    1. Check cancel
    2. Run node function (via executor — local for now)
    3. Merge result into state
    4. Checkpoint (if enabled)
    5. Check pause

    Executor socket: currently calls node functions directly.
    Future: dispatch to LocalExecutor / GrpcExecutor / LambdaExecutor
    based on StageRef.execution_target.
    """

    def __init__(
        self,
        config: PipelineConfig,
        checkpoint: bool = False,
        start_from: str | None = None,
    ):
        # Flattening happens eagerly so a bad start_from fails at construction.
        self.config = config
        self.do_checkpoint = checkpoint
        self.stage_sequence = _flatten_config(config, start_from)

    def invoke(self, state: DetectState) -> DetectState:
        """Run the pipeline on the given state. Returns final state."""
        for stage in self.stage_sequence:
            job_id = state.get("job_id", "")

            # 1. Cancellation — bail out before starting the stage.
            cancel_fn = _cancel_check.get(job_id)
            if cancel_fn and cancel_fn():
                raise PipelineCancelled(f"Cancelled before {stage}")

            # 2. Run the stage's node function (local executor for now).
            fn = _NODE_FN_MAP.get(stage)
            if fn is None:
                logger.warning("No node function for stage %s, skipping", stage)
                continue
            stage_result = fn(state)

            # 3. Fold the stage's result back into the shared state.
            state.update(stage_result)

            # 4. Persist a checkpoint when enabled.
            if self.do_checkpoint:
                from detect.checkpoint import checkpoint_after_stage
                checkpoint_after_stage(job_id, stage, state, stage_result)

            # 5. Honour pause/step requests before moving on.
            _wait_if_paused(job_id, stage)
        return state
# ---------------------------------------------------------------------------
# Public API — backwards compatible with old get_pipeline/build_graph
# ---------------------------------------------------------------------------
def get_pipeline(
    checkpoint: bool | None = None,
    profile_name: str = "soccer_broadcast",
    start_from: str | None = None,
) -> PipelineRunner:
    """Build a PipelineRunner from the named profile's pipeline config.

    checkpoint=None defers to the MPR_CHECKPOINT environment flag
    (captured in _CHECKPOINT_ENABLED at import time).
    """
    from detect.profiles import get_profile

    if checkpoint is None:
        enable_checkpoint = _CHECKPOINT_ENABLED
    else:
        enable_checkpoint = checkpoint
    config = get_profile(profile_name).pipeline_config()
    return PipelineRunner(
        config=config,
        checkpoint=enable_checkpoint,
        start_from=start_from,
    )
def build_graph(checkpoint: bool | None = None, start_from: str | None = None):
    """Backwards-compatible alias for get_pipeline(); returns a PipelineRunner.

    Kept so callers of the old graph-building API keep working.
    """
    return get_pipeline(checkpoint=checkpoint, start_from=start_from)

View File

@@ -1,17 +1,20 @@
"""
ContentTypeProfile protocol and config dataclasses.
The pipeline graph is fixed — what varies per content type is configuration
and hooks. Each profile provides stage configs, a brand dictionary,
VLM prompt templates, and an aggregation strategy.
Each profile defines the pipeline topology (as a JSONB blob), stage configs,
brand dictionary, VLM prompt templates, and aggregation strategy.
When profiles are persisted, the pipeline field is a JSONB column.
For now, profiles are code-only and pipeline_config() returns a hardcoded value.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Protocol
from typing import Any, Dict, Protocol
from detect.models import BrandDetection, DetectionReport
from core.schema.models.pipeline_config import PipelineConfig, StageRef, Edge
@dataclass
@@ -64,9 +67,24 @@ class CropContext:
position_hint: str = ""
def pipeline_config_from_dict(data: Dict[str, Any]) -> PipelineConfig:
    """Rebuild a typed PipelineConfig from its JSONB dict representation.

    Missing keys fall back to empty values, so a partial blob still
    deserializes.
    """
    return PipelineConfig(
        name=data.get("name", ""),
        profile_name=data.get("profile_name", ""),
        stages=[StageRef(**entry) for entry in data.get("stages", [])],
        edges=[Edge(**entry) for entry in data.get("edges", [])],
        routing_rules=data.get("routing_rules", {}),
    )
class ContentTypeProfile(Protocol):
    """Structural interface for content-type profiles (PEP 544 Protocol).

    Any object exposing these attributes/methods qualifies — no inheritance
    required.
    """
    name: str
    # JSONB blob — PipelineConfig shape; stored as a JSONB column once
    # profiles are persisted (see module docstring).
    pipeline: Dict[str, Any]

    def pipeline_config(self) -> PipelineConfig: ...
    def frame_extraction_config(self) -> FrameExtractionConfig: ...
    def scene_filter_config(self) -> SceneFilterConfig: ...
    def region_analysis_config(self) -> RegionAnalysisConfig: ...

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
from core.schema.models.pipeline_config import PipelineConfig
from detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats
from .base import (
@@ -12,12 +13,46 @@ from .base import (
RegionAnalysisConfig,
ResolverConfig,
SceneFilterConfig,
pipeline_config_from_dict,
)
class SoccerBroadcastProfile:
    name = "soccer_broadcast"

    # Pipeline topology as JSONB — will be a DB field when profiles are persisted
    # NOTE(review): "branch" is set only on the first four stages — confirm
    # whether it is optional metadata or accidentally missing on the rest.
    pipeline = {
        "name": "soccer_broadcast",
        "profile_name": "soccer_broadcast",
        "stages": [
            {"name": "extract_frames", "branch": "trunk"},
            {"name": "filter_scenes", "branch": "trunk"},
            {"name": "detect_edges", "branch": "hoarding"},
            {"name": "detect_objects", "branch": "objects"},
            {"name": "preprocess"},
            {"name": "run_ocr"},
            {"name": "match_brands"},
            {"name": "escalate_vlm"},
            {"name": "escalate_cloud"},
            {"name": "compile_report"},
        ],
        # Diamond shape: filter_scenes fans out to detect_edges and
        # detect_objects, which re-join at preprocess; the rest is linear.
        "edges": [
            {"source": "extract_frames", "target": "filter_scenes"},
            {"source": "filter_scenes", "target": "detect_edges"},
            {"source": "filter_scenes", "target": "detect_objects"},
            {"source": "detect_edges", "target": "preprocess"},
            {"source": "detect_objects", "target": "preprocess"},
            {"source": "preprocess", "target": "run_ocr"},
            {"source": "run_ocr", "target": "match_brands"},
            {"source": "match_brands", "target": "escalate_vlm"},
            {"source": "escalate_vlm", "target": "escalate_cloud"},
            {"source": "escalate_cloud", "target": "compile_report"},
        ],
    }

    def pipeline_config(self) -> PipelineConfig:
        # Deserialize the JSONB-shaped dict into a typed PipelineConfig.
        return pipeline_config_from_dict(self.pipeline)

    def frame_extraction_config(self) -> FrameExtractionConfig:
        # Sample at 2 fps, capped at 500 frames per job.
        return FrameExtractionConfig(fps=2.0, max_frames=500)