This commit is contained in:
2026-03-28 08:46:06 -03:00
parent acc99e691d
commit 0bd3888155
30 changed files with 390 additions and 1044 deletions

View File

@@ -1,8 +1,8 @@
"""
SSE endpoint for chunker pipeline events.
Uses Redis as the event bus between Celery workers and the SSE stream.
Celery worker pushes events via core.events, SSE endpoint polls them.
Uses Redis as the event bus. Pipeline pushes events via core.events,
SSE endpoint polls them.
GET /chunker/stream/{job_id} → text/event-stream
"""

View File

@@ -0,0 +1,20 @@
"""
Detection API — aggregated router.
Combines all detect sub-routers into a single include for main.py.
"""
from fastapi import APIRouter
from .sources import router as sources_router
from .run import router as run_router
from .sse import router as sse_router
from .replay import router as replay_router
from .config import router as config_router
# Single aggregated router: main.py includes this one object instead of
# wiring each detect sub-router individually.
router = APIRouter()
for _sub_router in (
    sources_router,
    run_router,
    sse_router,
    replay_router,
    config_router,
):
    router.include_router(_sub_router)
del _sub_router  # keep the loop variable out of the package namespace

View File

@@ -1,19 +1,9 @@
"""
Source browser for detection pipeline.
Pipeline run endpoints.
Lists available media sources from blob storage (MinIO).
All file-based sources go through MinIO — no host filesystem access.
The pipeline downloads chunks to a temp path before processing.
Source types (current and future):
- chunk_job: pre-chunked segments in MinIO (current)
- upload: user-uploaded file, lands in MinIO via upload endpoint (future)
- device: local camera/capture card via ffmpeg, no MinIO (future)
- stream: RTMP/HLS URL via ffmpeg, no MinIO (future)
GET /detect/sources list chunk jobs from blob store
GET /detect/sources/{job_id}/chunks list chunks for a specific job
POST /detect/run launch pipeline on selected source
POST /detect/run launch pipeline on selected source
POST /detect/stop/{job_id} cancel a running pipeline
POST /detect/clear/{job_id} clear events from Redis
"""
from __future__ import annotations
@@ -31,23 +21,10 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/detect", tags=["detect"])
# In-process pipeline tracking
_running_jobs: dict[str, "threading.Thread"] = {}
_running_jobs: dict[str, threading.Thread] = {}
_cancelled_jobs: set[str] = set()
# Response item describing one chunk object found in blob storage.
# (Kept as comments rather than a docstring so the generated API schema
# is unchanged.)
class ChunkInfo(BaseModel):
    filename: str  # file name of the chunk object
    key: str  # blob-store key of the chunk
    size_bytes: int  # object size in bytes
# Response item summarising one chunk job: how many chunk objects it has
# and their combined size.
class SourceInfo(BaseModel):
    job_id: str  # second path segment of the chunk keys (chunks/{job_id}/...)
    source_type: str = "chunk_job"  # only value produced by the listing code
    chunk_count: int  # number of objects found under the job's prefix
    total_bytes: int = 0  # sum of size_bytes over those objects
class RunRequest(BaseModel):
video_path: str # storage key
profile_name: str = "soccer_broadcast"
@@ -64,91 +41,6 @@ class RunResponse(BaseModel):
video_path: str
# ---------------------------------------------------------------------------
# Source listing
# ---------------------------------------------------------------------------
def _list_sources() -> list[SourceInfo]:
    """Group blob-store objects under ``chunks/`` into one entry per chunk job.

    Returns:
        One SourceInfo per job id, ordered by job id. If listing the store
        fails, a warning is logged and an empty list is returned.
    """
    from core.storage.blob import get_store

    store = get_store("out")
    try:
        listing = store.list(prefix="chunks/")
    except Exception as e:
        logger.warning("Failed to list blob sources: %s", e)
        return []

    # job_id -> [chunk_count, total_bytes]
    totals: dict[str, list[int]] = {}
    for item in listing:
        # Keys include store prefix: out/chunks/{job_id}/file.mp4
        # Strip prefix to get: chunks/{job_id}/file.mp4
        segments = item.key.removeprefix(store.prefix).split("/")
        if len(segments) < 3 or segments[0] != "chunks":
            continue
        tally = totals.setdefault(segments[1], [0, 0])
        tally[0] += 1
        tally[1] += item.size_bytes

    return [
        SourceInfo(
            job_id=job_id,
            source_type="chunk_job",
            chunk_count=totals[job_id][0],
            total_bytes=totals[job_id][1],
        )
        for job_id in sorted(totals)
    ]
@router.get("/sources", response_model=list[SourceInfo])
def list_sources():
    """List available chunk jobs from blob storage."""
    # Thin HTTP wrapper: the grouping logic lives in _list_sources().
    sources = _list_sources()
    return sources
@router.get("/sources/{source_job_id}/chunks", response_model=list[ChunkInfo])
def list_chunks(source_job_id: str):
    """List chunks for a specific source job."""
    from core.storage.blob import get_store

    store = get_store("out")
    try:
        found = store.list(prefix=f"chunks/{source_job_id}/", extensions={".mp4"})
    except Exception as e:
        # Storage failure -> 503 so clients can tell it apart from "no such job".
        logger.warning("Failed to list chunks for %s: %s", source_job_id, e)
        raise HTTPException(status_code=503, detail=f"Blob storage unavailable: {e}")

    # An empty listing is reported as an unknown source.
    if not found:
        raise HTTPException(status_code=404, detail=f"Source not found: {source_job_id}")

    # Return the chunks ordered by file name.
    return sorted(
        (
            ChunkInfo(filename=item.filename, key=item.key, size_bytes=item.size_bytes)
            for item in found
        ),
        key=lambda chunk: chunk.filename,
    )
@router.get("/sources/{source_job_id}/chunks/{filename}/url")
def get_chunk_url(source_job_id: str, filename: str):
    """Return a presigned URL for previewing a chunk in the browser."""
    from core.storage.blob import get_store

    store = get_store("out")
    key = f"chunks/{source_job_id}/{filename}"
    try:
        # expires=3600 — presumably seconds (one hour); confirm against the store API.
        url = store.get_url(key, expires=3600)
    except Exception as e:
        # Record the failure server-side (the listing endpoints in this module
        # already do) and chain the cause so the traceback stays complete.
        logger.warning("Failed to generate URL for %s: %s", key, e)
        raise HTTPException(status_code=503, detail=f"Could not generate URL: {e}") from e
    return {"url": url}
# ---------------------------------------------------------------------------
# Run pipeline
# ---------------------------------------------------------------------------
def _resolve_video_path(video_path: str) -> str:
"""Download a chunk from blob storage to a temp file."""
from core.storage.blob import get_store
@@ -216,7 +108,6 @@ def run_pipeline(req: RunRequest):
emit.job_complete(job_id, {"status": "cancelled"})
except Exception as e:
logger.exception("Pipeline run %s failed: %s", job_id, e)
# Mark the current/last stage as error in the graph
from detect.graph import _node_states, NODES
if job_id in _node_states:
states = _node_states[job_id]

108
core/api/detect/sources.py Normal file
View File

@@ -0,0 +1,108 @@
"""
Source browser for detection pipeline.
Lists available media sources from blob storage (MinIO).
GET /detect/sources — list chunk jobs
GET /detect/sources/{job_id}/chunks — list chunks for a job
GET /detect/sources/{job_id}/chunks/{name}/url — presigned preview URL
"""
from __future__ import annotations
import logging
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/detect", tags=["detect"])
# Response item describing one chunk object found in blob storage.
# (Kept as comments rather than a docstring so the generated API schema
# is unchanged.)
class ChunkInfoResponse(BaseModel):
    filename: str  # file name of the chunk object
    key: str  # blob-store key of the chunk
    size_bytes: int  # object size in bytes
# Response item summarising one chunk job: how many chunk objects it has
# and their combined size.
class SourceInfoResponse(BaseModel):
    job_id: str  # second path segment of the chunk keys (chunks/{job_id}/...)
    source_type: str = "chunk_job"  # only value produced by the listing code
    chunk_count: int  # number of objects found under the job's prefix
    total_bytes: int = 0  # sum of size_bytes over those objects
def _list_sources() -> list[SourceInfoResponse]:
    """Group blob-store objects under ``chunks/`` into one entry per chunk job.

    Returns:
        One SourceInfoResponse per job id, ordered by job id. If listing the
        store fails, a warning is logged and an empty list is returned.
    """
    from core.storage.blob import get_store

    store = get_store("out")
    try:
        listing = store.list(prefix="chunks/")
    except Exception as e:
        logger.warning("Failed to list blob sources: %s", e)
        return []

    # job_id -> [chunk_count, total_bytes]
    totals: dict[str, list[int]] = {}
    for item in listing:
        # Drop the store's own prefix so the key reads chunks/{job_id}/{file}.
        segments = item.key.removeprefix(store.prefix).split("/")
        if len(segments) < 3 or segments[0] != "chunks":
            continue
        tally = totals.setdefault(segments[1], [0, 0])
        tally[0] += 1
        tally[1] += item.size_bytes

    return [
        SourceInfoResponse(
            job_id=job_id,
            source_type="chunk_job",
            chunk_count=totals[job_id][0],
            total_bytes=totals[job_id][1],
        )
        for job_id in sorted(totals)
    ]
@router.get("/sources", response_model=list[SourceInfoResponse])
def list_sources():
    """List available chunk jobs from blob storage."""
    # Thin HTTP wrapper: the grouping logic lives in _list_sources().
    sources = _list_sources()
    return sources
@router.get("/sources/{source_job_id}/chunks", response_model=list[ChunkInfoResponse])
def list_chunks(source_job_id: str):
    """List chunks for a specific source job."""
    from core.storage.blob import get_store

    store = get_store("out")
    try:
        found = store.list(prefix=f"chunks/{source_job_id}/", extensions={".mp4"})
    except Exception as e:
        # Storage failure -> 503 so clients can tell it apart from "no such job".
        logger.warning("Failed to list chunks for %s: %s", source_job_id, e)
        raise HTTPException(status_code=503, detail=f"Blob storage unavailable: {e}")

    # An empty listing is reported as an unknown source.
    if not found:
        raise HTTPException(status_code=404, detail=f"Source not found: {source_job_id}")

    # Return the chunks ordered by file name.
    return sorted(
        (
            ChunkInfoResponse(filename=item.filename, key=item.key, size_bytes=item.size_bytes)
            for item in found
        ),
        key=lambda chunk: chunk.filename,
    )
@router.get("/sources/{source_job_id}/chunks/{filename}/url")
def get_chunk_url(source_job_id: str, filename: str):
    """Return a presigned URL for previewing a chunk in the browser."""
    from core.storage.blob import get_store

    store = get_store("out")
    key = f"chunks/{source_job_id}/{filename}"
    try:
        # expires=3600 — presumably seconds (one hour); confirm against the store API.
        url = store.get_url(key, expires=3600)
    except Exception as e:
        # Record the failure server-side (the listing endpoints in this module
        # already do) and chain the cause so the traceback stays complete.
        logger.warning("Failed to generate URL for %s: %s", key, e)
        raise HTTPException(status_code=503, detail=f"Could not generate URL: {e}") from e
    return {"url": url}

View File

@@ -19,10 +19,7 @@ from fastapi.middleware.cors import CORSMiddleware
from strawberry.fastapi import GraphQLRouter
from core.api.chunker_sse import router as chunker_router
from core.api.detect_sse import router as detect_router
from core.api.detect_replay import router as detect_replay_router
from core.api.detect_config import router as detect_config_router
from core.api.detect_sources import router as detect_sources_router
from core.api.detect import router as detect_router
from core.api.graphql import schema as graphql_schema
CALLBACK_API_KEY = os.environ.get("CALLBACK_API_KEY", "")
@@ -61,18 +58,9 @@ app.include_router(graphql_router, prefix="/graphql")
# Chunker SSE
app.include_router(chunker_router)
# Detection SSE
# Detection API (sources, run, SSE, replay, config)
app.include_router(detect_router)
# Detection replay/retry
app.include_router(detect_replay_router)
# Detection config
app.include_router(detect_config_router)
# Detection sources + run launcher
app.include_router(detect_sources_router)
@app.get("/health")
def health():

View File

@@ -48,7 +48,6 @@ class Job(SQLModel, table=True):
brands_found: int = 0
cloud_llm_calls: int = 0
estimated_cost_usd: float = 0.0
celery_task_id: Optional[str] = None
priority: int = 0
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
started_at: Optional[datetime] = None

View File

@@ -1,7 +1,7 @@
"""
Redis-based event bus for pipeline job progress.
Celery workers push events, SSE endpoints poll them.
Pipeline stages push events, SSE endpoints poll them.
Only depends on redis — safe to import from any context.
"""

View File

@@ -1,15 +1,13 @@
"""
MPR Jobs Module
Provides executor abstraction and task dispatch for job processing.
Provides executor abstraction for job dispatch (local, Lambda, GCP).
"""
from .executor import Executor, LocalExecutor, get_executor
from .task import run_job
__all__ = [
"Executor",
"LocalExecutor",
"get_executor",
"run_job",
]

View File

@@ -42,7 +42,7 @@ class Executor(ABC):
class LocalExecutor(Executor):
"""Execute jobs locally using registered handlers."""
"""Execute jobs locally by calling the stage function directly."""
def run(
self,
@@ -51,16 +51,10 @@ class LocalExecutor(Executor):
payload: Dict[str, Any],
progress_callback: Optional[Callable[[int, Dict[str, Any]], None]] = None,
) -> bool:
"""Execute job using the appropriate local handler."""
from .registry import get_handler
handler = get_handler(job_type)
result = handler.process(
job_id=job_id,
payload=payload,
progress_callback=progress_callback,
"""Execute job locally. Socket for PipelineRunner integration."""
raise NotImplementedError(
"LocalExecutor.run() — will be wired to PipelineRunner in Phase 3"
)
return result.get("status") == "completed"
class LambdaExecutor(Executor):

View File

@@ -1,5 +0,0 @@
"""Job handlers — type-specific execution logic."""
from .base import Handler
__all__ = ["Handler"]

View File

@@ -1,33 +0,0 @@
"""
Base Handler ABC — defines the interface for job-type-specific execution logic.
A Handler knows HOW to execute a specific kind of job (transcode, chunk, etc.).
The Executor decides WHERE to run it (local, Lambda, GCP).
"""
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional
class Handler(ABC):
    """Interface implemented by every job handler.

    Concrete handlers override :meth:`process`; the class cannot be
    instantiated until they do.
    """

    @abstractmethod
    def process(
        self,
        job_id: str,
        payload: Dict[str, Any],
        progress_callback: Optional[Callable[[int, Dict[str, Any]], None]] = None,
    ) -> Dict[str, Any]:
        """Run the job and report its outcome.

        Args:
            job_id: Unique job identifier.
            payload: Job-type-specific configuration.
            progress_callback: Optional hook called as
                ``progress_callback(percent, details_dict)``.

        Returns:
            A result dict containing at least ``{"status": "completed"}``;
            failures are signalled by raising.
        """

View File

@@ -1,125 +0,0 @@
"""
ChunkHandler — job handler that wraps the chunker Pipeline.
Downloads source from S3/MinIO, runs FFmpeg chunking pipeline,
writes mp4 segments + manifest to media/out/chunks/{job_id}/.
Pushes real-time events to Redis for SSE consumption.
"""
import logging
import os
from typing import Any, Callable, Dict, Optional
from core.events import push_event as push_chunk_event
from core.chunker import Pipeline
from core.storage import BUCKET_IN, download_to_temp
from .base import Handler
logger = logging.getLogger(__name__)
MEDIA_OUT_DIR = os.environ.get("MEDIA_OUT_DIR", "/app/media/out")
class ChunkHandler(Handler):
    """
    Handles chunk processing jobs by delegating to the chunker Pipeline.

    Expected payload keys:
        source_key: str — S3 key of the source file in BUCKET_IN
        chunk_duration: float — seconds per chunk (default: 10.0)
        num_workers: int — concurrent workers (default: 4)
        max_retries: int — retries per chunk (default: 3)
        processor_type: str — "ffmpeg", "checksum", "simulated_decode", "composite"
        queue_size: int — max queue depth (default: 10)
        start_time: optional, forwarded to Pipeline unchanged (default: None)
        end_time: optional, forwarded to Pipeline unchanged (default: None)
    """

    def process(
        self,
        job_id: str,
        payload: Dict[str, Any],
        progress_callback: Optional[Callable[[int, Dict[str, Any]], None]] = None,
    ) -> Dict[str, Any]:
        """Download the source, run the chunker Pipeline and return its stats.

        Args:
            job_id: Job identifier; also names the output directory and the
                event stream the progress events are pushed to.
            payload: See the class docstring for the recognised keys.
            progress_callback: Optional hook called with (percent, details).

        Returns:
            Result dict whose "status" is "completed" when no chunk failed,
            otherwise "completed_with_errors", plus pipeline statistics and
            the list of produced files.

        Raises:
            KeyError: if ``payload`` has no "source_key".
            Exception: anything raised while preparing or running the
                pipeline is re-raised after a "pipeline_error" event.
        """
        source_key = payload["source_key"]
        processor_type = payload.get("processor_type", "ffmpeg")

        logger.info("ChunkHandler starting job %s: %s", job_id, source_key)

        # Download source from S3/MinIO
        push_chunk_event(job_id, "pipeline_start", {"status": "downloading", "source_key": source_key})
        tmp_source = download_to_temp(BUCKET_IN, source_key)

        # Output directory: media/out/chunks/{job_id}/
        output_dir = os.path.join(MEDIA_OUT_DIR, "chunks", job_id)
        writes_output = processor_type == "ffmpeg"

        try:
            # Created inside the try block so that a failure here still removes
            # the downloaded temp file and is reported as a pipeline_error.
            if writes_output:
                os.makedirs(output_dir, exist_ok=True)

            def event_bridge(event_type: str, data: Dict[str, Any]) -> None:
                """Bridge pipeline events to Redis + optional progress callback."""
                push_chunk_event(job_id, event_type, data)
                if not progress_callback:
                    return
                if event_type == "pipeline_complete":
                    progress_callback(100, data)
                elif event_type == "chunk_done":
                    total = data.get("total_chunks", 1)
                    if total > 0:
                        # Cap at 99: only pipeline_complete reports 100.
                        pct = min(int((data.get("sequence", 0) + 1) / total * 100), 99)
                        progress_callback(pct, data)

            pipeline = Pipeline(
                source=tmp_source,
                chunk_duration=payload.get("chunk_duration", 10.0),
                num_workers=payload.get("num_workers", 4),
                max_retries=payload.get("max_retries", 3),
                processor_type=processor_type,
                queue_size=payload.get("queue_size", 10),
                event_callback=event_bridge,
                output_dir=output_dir if writes_output else None,
                start_time=payload.get("start_time"),
                end_time=payload.get("end_time"),
            )

            result = pipeline.run()

            # Files are already in media/out/chunks/{job_id}/
            output_prefix = f"chunks/{job_id}"
            output_files = [
                f"{output_prefix}/{os.path.basename(f)}"
                for f in result.chunk_files
            ]

            push_chunk_event(job_id, "pipeline_complete", {
                "status": "completed",
                "total_chunks": result.total_chunks,
                "processed": result.processed,
                "failed": result.failed,
                "elapsed": result.elapsed_time,
                "throughput_mbps": result.throughput_mbps,
            })

            return {
                "status": "completed" if result.failed == 0 else "completed_with_errors",
                "total_chunks": result.total_chunks,
                "processed": result.processed,
                "failed": result.failed,
                "retries": result.retries,
                "elapsed_time": result.elapsed_time,
                "throughput_mbps": result.throughput_mbps,
                "worker_stats": result.worker_stats,
                "errors": result.errors,
                "chunks_in_order": result.chunks_in_order,
                "output_prefix": output_prefix,
                "output_files": output_files,
            }

        except Exception as e:
            push_chunk_event(job_id, "pipeline_error", {"status": "failed", "error": str(e)})
            raise

        finally:
            # Cleanup temp source file only (output dir is persistent)
            try:
                os.unlink(tmp_source)
            except OSError:
                pass

View File

@@ -1,130 +0,0 @@
"""
DetectHandler — runs the detection pipeline as a Celery job.
Supports three modes via payload:
- Initial run: {"video_path": "...", "profile_name": "..."}
- Replay: {"replay_from": "run_ocr", "source_job_id": "...", "config_overrides": {...}}
- Retry: {"retry_from": "escalate_vlm", "source_job_id": "...", "config_overrides": {...}}
"""
import logging
import os
import uuid
from typing import Any, Callable, Dict, Optional
from .base import Handler
logger = logging.getLogger(__name__)
class DetectHandler(Handler):
    """Job handler for the detection pipeline.

    Dispatches to a replay when the payload carries both "replay_from" and
    "source_job_id"; otherwise performs an initial run.
    """

    def process(
        self,
        job_id: str,
        payload: Dict[str, Any],
        progress_callback: Optional[Callable[[int, Dict[str, Any]], None]] = None,
    ) -> Dict[str, Any]:
        """Route the job to the replay path or the initial-run path.

        Args:
            job_id: Identifier of this run.
            payload: Run configuration; see _run_initial / _run_replay for the
                keys each path reads.
            progress_callback: Optional hook called with (percent, details).

        Returns:
            Result dict with "status": "completed" and detection counts.
        """
        replay_from = payload.get("replay_from")
        source_job_id = payload.get("source_job_id")

        if replay_from and source_job_id:
            return self._run_replay(job_id, source_job_id, replay_from, payload, progress_callback)

        return self._run_initial(job_id, payload, progress_callback)

    @staticmethod
    def _count_results(result) -> tuple:
        """Return (number of detections, number of brands in the report)."""
        detections = result.get("detections", [])
        report = result.get("report")
        brands_found = len(report.brands) if report else 0
        return len(detections), brands_found

    def _run_initial(
        self,
        job_id: str,
        payload: Dict[str, Any],
        progress_callback: Optional[Callable],
    ) -> Dict[str, Any]:
        """Run the full pipeline on payload["video_path"].

        Optional payload keys: "profile_name", "source_asset_id",
        "checkpoint" (defaults to whether MPR_CHECKPOINT == "1") and
        "parent_job_id" (defaults to job_id).

        Raises:
            KeyError: if the payload has no "video_path".
        """
        from detect import emit
        from detect.graph import get_pipeline
        from detect.state import DetectState

        video_path = payload["video_path"]
        profile_name = payload.get("profile_name", "soccer_broadcast")
        source_asset_id = payload.get("source_asset_id", "")
        checkpoint_enabled = payload.get("checkpoint", os.environ.get("MPR_CHECKPOINT") == "1")

        emit.set_run_context(
            run_id=job_id,
            parent_job_id=payload.get("parent_job_id", job_id),
            run_type="initial",
        )
        # Everything after set_run_context sits inside the try so the context
        # is cleared even when building the pipeline or its state fails.
        try:
            logger.info("DetectHandler: initial run job=%s video=%s profile=%s checkpoint=%s",
                        job_id, video_path, profile_name, checkpoint_enabled)

            if progress_callback:
                progress_callback(0, {"stage": "starting"})

            pipeline = get_pipeline(checkpoint=checkpoint_enabled)

            initial_state = DetectState(
                video_path=video_path,
                job_id=job_id,
                profile_name=profile_name,
                source_asset_id=source_asset_id,
            )

            result = pipeline.invoke(initial_state)
        finally:
            emit.clear_run_context()

        detection_count, brands_found = self._count_results(result)

        if progress_callback:
            progress_callback(100, {"stage": "completed"})

        return {
            "status": "completed",
            "job_id": job_id,
            "detections": detection_count,
            "brands_found": brands_found,
        }

    def _run_replay(
        self,
        job_id: str,
        source_job_id: str,
        start_stage: str,
        payload: Dict[str, Any],
        progress_callback: Optional[Callable],
    ) -> Dict[str, Any]:
        """Re-run source_job_id from start_stage with optional config overrides.

        Reads "config_overrides" from the payload (default: empty dict) and
        delegates to detect.checkpoint.replay_from.
        """
        from detect.checkpoint import replay_from

        config_overrides = payload.get("config_overrides", {})

        logger.info("DetectHandler: replay job=%s from=%s source=%s overrides=%s",
                    job_id, start_stage, source_job_id, config_overrides)

        if progress_callback:
            progress_callback(0, {"stage": f"replaying from {start_stage}"})

        result = replay_from(
            job_id=source_job_id,
            start_stage=start_stage,
            config_overrides=config_overrides,
        )

        detection_count, brands_found = self._count_results(result)

        if progress_callback:
            progress_callback(100, {"stage": "completed"})

        return {
            "status": "completed",
            "job_id": job_id,
            "source_job_id": source_job_id,
            "replay_from": start_stage,
            "detections": detection_count,
            "brands_found": brands_found,
        }

View File

@@ -1,104 +0,0 @@
"""
TranscodeHandler — executes transcode/trim jobs using FFmpeg.
Extracted from the old tasks.py Celery task logic.
"""
import logging
import os
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, Optional
from core.ffmpeg.transcode import TranscodeConfig, transcode
from core.storage import BUCKET_IN, BUCKET_OUT, download_to_temp, upload_file
from .base import Handler
logger = logging.getLogger(__name__)
class TranscodeHandler(Handler):
    """Handle transcode and trim jobs via FFmpeg."""

    def process(
        self,
        job_id: str,
        payload: Dict[str, Any],
        progress_callback: Optional[Callable[[int, Dict[str, Any]], None]] = None,
    ) -> Dict[str, Any]:
        """Download the source, transcode (or stream-copy trim) it, upload the result.

        Payload keys:
            source_key (required): key of the input in BUCKET_IN.
            output_key (required): key to upload to in BUCKET_OUT; its suffix
                picks the temp-file extension (".mp4" when it has none).
            preset: dict of encoder settings; when absent or empty both
                streams are copied ("copy" codecs).
            trim_start / trim_end: optional trim bounds passed to TranscodeConfig.
            duration: optional value forwarded to transcode().

        Args:
            job_id: Job identifier, echoed in the result.
            payload: See above.
            progress_callback: Optional hook called with (int percent, details).

        Returns:
            {"status": "completed", "job_id": ..., "output_key": ...}

        Raises:
            KeyError: if a required payload key is missing.
            RuntimeError: if transcode() returns a falsy value.
        """
        source_key = payload["source_key"]
        output_key = payload["output_key"]
        preset = payload.get("preset")
        trim_start = payload.get("trim_start")
        trim_end = payload.get("trim_end")
        duration = payload.get("duration")

        logger.info("TranscodeHandler: %s -> %s", source_key, output_key)

        # Download source
        tmp_source = download_to_temp(BUCKET_IN, source_key)
        tmp_output: Optional[str] = None

        try:
            # The output temp file is created inside the try block so the
            # downloaded source is still removed if mkstemp fails.
            ext = Path(output_key).suffix or ".mp4"
            fd, tmp_output = tempfile.mkstemp(suffix=ext)
            os.close(fd)

            if preset:
                config = TranscodeConfig(
                    input_path=tmp_source,
                    output_path=tmp_output,
                    video_codec=preset.get("video_codec", "libx264"),
                    video_bitrate=preset.get("video_bitrate"),
                    video_crf=preset.get("video_crf"),
                    video_preset=preset.get("video_preset"),
                    resolution=preset.get("resolution"),
                    framerate=preset.get("framerate"),
                    audio_codec=preset.get("audio_codec", "aac"),
                    audio_bitrate=preset.get("audio_bitrate"),
                    audio_channels=preset.get("audio_channels"),
                    audio_samplerate=preset.get("audio_samplerate"),
                    container=preset.get("container", "mp4"),
                    extra_args=preset.get("extra_args", []),
                    trim_start=trim_start,
                    trim_end=trim_end,
                )
            else:
                # No preset: stream copy, optionally trimmed.
                config = TranscodeConfig(
                    input_path=tmp_source,
                    output_path=tmp_output,
                    video_codec="copy",
                    audio_codec="copy",
                    trim_start=trim_start,
                    trim_end=trim_end,
                )

            def wrapped_callback(percent: float, details: Dict[str, Any]) -> None:
                # transcode() reports a float percent; callers expect an int.
                if progress_callback:
                    progress_callback(int(percent), details)

            success = transcode(
                config,
                duration=duration,
                progress_callback=wrapped_callback if progress_callback else None,
            )

            if not success:
                raise RuntimeError("Transcode returned False")

            # Upload result
            logger.info("Uploading %s to %s", output_key, BUCKET_OUT)
            upload_file(tmp_output, BUCKET_OUT, output_key)

            return {
                "status": "completed",
                "job_id": job_id,
                "output_key": output_key,
            }

        finally:
            # Best-effort removal of both temp files; tmp_output is None when
            # mkstemp itself failed.
            for f in (tmp_source, tmp_output):
                if f is None:
                    continue
                try:
                    os.unlink(f)
                except OSError:
                    pass

View File

@@ -1,35 +0,0 @@
"""
Handler registry — maps job_type strings to Handler classes.
"""
from typing import Dict, Type
from .handlers.base import Handler
_handlers: Dict[str, Type[Handler]] = {}
def register_handler(job_type: str, handler_class: Type[Handler]) -> None:
    """Map *job_type* to *handler_class*, replacing any earlier registration."""
    _handlers.update({job_type: handler_class})
def get_handler(job_type: str) -> Handler:
    """Instantiate the handler registered for *job_type*.

    Raises:
        ValueError: if nothing is registered under *job_type*.
    """
    if job_type in _handlers:
        handler_cls = _handlers[job_type]
        return handler_cls()
    raise ValueError(f"Unknown job type: {job_type}")
def _register_defaults() -> None:
    """Register the built-in transcode, chunk and detect handlers."""
    # Imported here rather than at module top, as in the original layout.
    from .handlers.chunk import ChunkHandler
    from .handlers.transcode import TranscodeHandler
    from .handlers.detect import DetectHandler

    builtin_handlers = (
        ("transcode", TranscodeHandler),
        ("chunk", ChunkHandler),
        ("detect", DetectHandler),
    )
    for job_type, handler_class in builtin_handlers:
        register_handler(job_type, handler_class)
_register_defaults()

View File

@@ -1,64 +0,0 @@
"""
Celery task for job processing.
Generic dispatcher — routes to the appropriate handler based on job_type.
"""
import logging
from typing import Any, Dict
from celery import shared_task
from core.rpc.server import update_job_progress
logger = logging.getLogger(__name__)
@shared_task(bind=True, max_retries=3, default_retry_delay=60)
def run_job(
    self,
    job_type: str,
    job_id: str,
    payload: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Celery entry point: run the handler registered for job_type.

    Progress and terminal status are recorded via update_job_progress. On
    failure the task asks Celery to retry until max_retries is used up, after
    which a {"status": "failed", ...} dict is returned instead of raising.
    """
    logger.info(f"Starting {job_type} job {job_id}")
    update_job_progress(job_id, progress=0, status="processing")

    def _report(percent: int, details: Dict[str, Any]) -> None:
        # Forward handler progress; "time" is optional in the details dict.
        elapsed = details.get("time", 0.0)
        update_job_progress(
            job_id,
            progress=percent,
            current_time=elapsed,
            status="processing",
        )

    try:
        # Imported inside the try so an import failure is handled like any
        # other job failure.
        from .registry import get_handler

        outcome = get_handler(job_type).process(
            job_id=job_id,
            payload=payload,
            progress_callback=_report,
        )
        logger.info(f"Job {job_id} completed successfully")
        update_job_progress(job_id, progress=100, status="completed")
        return outcome
    except Exception as e:
        logger.exception(f"Job {job_id} failed: {e}")
        update_job_progress(job_id, progress=0, status="failed", error=str(e))

        can_retry = self.request.retries < self.max_retries
        if can_retry:
            raise self.retry(exc=e)

        failure = {
            "status": "failed",
            "job_id": job_id,
            "error": str(e),
        }
        return failure

View File

@@ -29,14 +29,9 @@ _active_jobs: dict[str, dict] = {}
class WorkerServicer(worker_pb2_grpc.WorkerServiceServicer):
"""gRPC service implementation for worker operations."""
def __init__(self, celery_app=None):
"""
Initialize the servicer.
Args:
celery_app: Optional Celery app for task dispatch
"""
self.celery_app = celery_app
def __init__(self):
"""Initialize the servicer."""
pass
def SubmitJob(self, request, context):
"""Submit a transcode/trim job to the worker."""
@@ -57,28 +52,7 @@ class WorkerServicer(worker_pb2_grpc.WorkerServiceServicer):
"error": None,
}
# Dispatch to Celery if available
if self.celery_app:
from core.jobs.task import run_job
payload = {
"source_key": request.source_path,
"output_key": request.output_path,
"preset": preset,
"trim_start": request.trim_start
if request.HasField("trim_start")
else None,
"trim_end": request.trim_end
if request.HasField("trim_end")
else None,
}
task = run_job.delay(
job_type="transcode",
job_id=job_id,
payload=payload,
)
_active_jobs[job_id]["celery_task_id"] = task.id
# TODO: dispatch via executor (local/lambda/gcp/grpc)
return worker_pb2.JobResponse(
job_id=job_id,
@@ -155,12 +129,6 @@ class WorkerServicer(worker_pb2_grpc.WorkerServiceServicer):
if job_id in _active_jobs:
_active_jobs[job_id]["status"] = "cancelled"
# Revoke Celery task if available
if self.celery_app:
task_id = _active_jobs[job_id].get("celery_task_id")
if task_id:
self.celery_app.control.revoke(task_id, terminate=True)
return worker_pb2.CancelResponse(
job_id=job_id,
cancelled=True,
@@ -290,13 +258,12 @@ def update_job_progress(
logger.warning(f"Failed to update job {job_id} in DB: {e}")
def serve(port: int = None, celery_app=None) -> grpc.Server:
def serve(port: int = None) -> grpc.Server:
"""
Start the gRPC server.
Args:
port: Port to listen on (defaults to GRPC_PORT env var)
celery_app: Optional Celery app for task dispatch
Returns:
The running gRPC server
@@ -306,7 +273,7 @@ def serve(port: int = None, celery_app=None) -> grpc.Server:
server = grpc.server(futures.ThreadPoolExecutor(max_workers=GRPC_MAX_WORKERS))
worker_pb2_grpc.add_WorkerServiceServicer_to_server(
WorkerServicer(celery_app=celery_app),
WorkerServicer(),
server,
)
server.add_insecure_port(f"[::]:{port}")

View File

@@ -35,6 +35,7 @@ from .detect import DETECT_VIEWS # noqa: F401 — discovered by modelgen generi
from .inference import INFERENCE_VIEWS # noqa: F401 — GPU inference server API types
from .ui_state import UI_STATE_VIEWS # noqa: F401 — UI store state types
from .stages import StageConfigField, StageIO, StageDefinition, STAGE_VIEWS # noqa: F401
from .detect_api import RunRequest, RunResponse, DETECT_API_VIEWS # noqa: F401
from .views import ChunkEvent, ChunkOutputFile, PipelineStats, WorkerEvent
from .sources import ChunkInfo, SourceJob, SourceType

View File

@@ -0,0 +1,31 @@
"""
Detection API request/response models.
Source of truth for detection pipeline API shapes.
Generated to Pydantic via modelgen.
"""
from dataclasses import dataclass
@dataclass
class RunRequest:
    """Request body for launching a detection pipeline run."""

    # Only required field.
    video_path: str  # storage key

    # Optional settings, each with a default.
    profile_name: str = "soccer_broadcast"
    source_asset_id: str = ""
    checkpoint: bool = True
    skip_vlm: bool = False
    skip_cloud: bool = False
    log_level: str = "INFO"  # INFO | DEBUG
@dataclass
class RunResponse:
    """Response after starting a pipeline run."""

    # All three fields are required and positional, in this order.
    status: str
    job_id: str
    video_path: str
DETECT_API_VIEWS = [RunRequest, RunResponse]

View File

@@ -56,7 +56,6 @@ class Job:
estimated_cost_usd: float = 0.0
# Worker tracking
celery_task_id: Optional[str] = None
priority: int = 0
# Timestamps

View File

@@ -1,133 +0,0 @@
"""
Job Schema Definitions
Source of truth for job data models.
TranscodeJob and ChunkJob share common lifecycle fields by convention.
"""
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID
class JobStatus(str, Enum):
    """Lifecycle states of a transcode/trim job.

    Members are also ``str`` instances, so each compares equal to its raw value.
    """

    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
@dataclass
class TranscodeJob:
    """
    One queued transcode or trim job.

    Two flavours exist:
    - preset-based transcode (full re-encode)
    - trim only: no preset, stream copy (-c:v copy -c:a copy)
    """

    id: UUID

    # --- input ---
    source_asset_id: UUID

    # --- configuration ---
    preset_id: Optional[UUID] = None
    # Copy of the preset taken at creation time.
    preset_snapshot: Dict[str, Any] = field(default_factory=dict)

    # --- optional trim window, in seconds ---
    trim_start: Optional[float] = None
    trim_end: Optional[float] = None

    # --- output ---
    output_filename: str = ""
    output_path: Optional[str] = None
    output_asset_id: Optional[UUID] = None

    # --- status and progress ---
    status: JobStatus = JobStatus.PENDING
    progress: float = 0.0  # 0.0 to 100.0
    current_frame: Optional[int] = None
    current_time: Optional[float] = None  # seconds processed
    speed: Optional[str] = None  # "2.5x"
    error_message: Optional[str] = None

    # --- worker tracking ---
    celery_task_id: Optional[str] = None
    execution_arn: Optional[str] = None  # AWS Step Functions execution ARN
    priority: int = 0  # Lower = higher priority

    # --- timestamps ---
    created_at: Optional[datetime] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    @property
    def is_trim_only(self) -> bool:
        """True when no preset is set and at least one trim bound is given."""
        if self.preset_id is not None:
            return False
        has_no_trim = self.trim_start is None and self.trim_end is None
        return not has_no_trim
class ChunkJobStatus(str, Enum):
    """Lifecycle states of a chunk pipeline job.

    Members are also ``str`` instances, so each compares equal to its raw value.
    """

    PENDING = "pending"
    CHUNKING = "chunking"
    PROCESSING = "processing"
    COLLECTING = "collecting"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
@dataclass
class ChunkJob:
    """
    One chunk pipeline job: a media file is split into chunks which are then
    processed by a concurrent worker pool.
    """

    id: UUID

    # --- input ---
    source_asset_id: UUID

    # --- configuration ---
    chunk_duration: float = 10.0  # seconds
    num_workers: int = 4
    max_retries: int = 3
    processor_type: str = "ffmpeg"  # "ffmpeg", "checksum", "simulated_decode", "composite"

    # --- status and progress ---
    status: ChunkJobStatus = ChunkJobStatus.PENDING
    progress: float = 0.0  # 0.0 to 100.0
    total_chunks: int = 0
    processed_chunks: int = 0
    failed_chunks: int = 0
    retry_count: int = 0
    error_message: Optional[str] = None

    # --- result stats ---
    throughput_mbps: Optional[float] = None
    elapsed_seconds: Optional[float] = None

    # --- worker tracking ---
    celery_task_id: Optional[str] = None
    priority: int = 0  # Lower = higher priority

    # --- timestamps ---
    created_at: Optional[datetime] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

View File

@@ -5,7 +5,6 @@ Checkpoint system — Timeline + Checkpoint tree.
frames.py — frame image S3 upload/download
storage.py — Timeline + Checkpoint (Postgres + MinIO)
replay.py — replay (TODO: migrate to new model)
tasks.py — retry_candidates Celery task
"""
from .storage import (

View File

@@ -1,71 +0,0 @@
"""
Celery tasks for detection pipeline async operations.
retry_candidates: re-run VLM/cloud escalation with different config.
"""
from __future__ import annotations
import logging
import uuid
from datetime import datetime, timezone
from celery import shared_task
logger = logging.getLogger(__name__)
@shared_task(bind=True, max_retries=1, default_retry_delay=30)
def retry_candidates(
    self,
    job_id: str,
    config_overrides: dict | None = None,
    start_stage: str = "escalate_vlm",
):
    """
    Re-run a job's pipeline from start_stage with optional config overrides.

    Delegates to detect.checkpoint.replay.replay_from and returns a summary
    dict. On failure the task asks Celery to retry until max_retries is used
    up, after which a {"status": "failed", ...} dict is returned.
    """
    from detect.checkpoint.replay import replay_from

    # Short id used only to correlate this attempt's log lines and result.
    run_id = str(uuid.uuid4())[:8]

    logger.info("Retry task %s: job=%s, from=%s, overrides=%s",
                run_id, job_id, start_stage, config_overrides)

    try:
        outcome = replay_from(
            job_id=job_id,
            start_stage=start_stage,
            config_overrides=config_overrides,
        )

        detection_count = len(outcome.get("detections", []))
        report = outcome.get("report")
        if report:
            brands_found = len(report.brands)
        else:
            brands_found = 0

        logger.info("Retry %s complete: %d detections, %d brands",
                    run_id, detection_count, brands_found)

        return {
            "status": "completed",
            "run_id": run_id,
            "job_id": job_id,
            "detections": detection_count,
            "brands_found": brands_found,
        }

    except Exception as e:
        logger.exception("Retry %s failed: %s", run_id, e)

        can_retry = self.request.retries < self.max_retries
        if can_retry:
            raise self.retry(exc=e)

        return {
            "status": "failed",
            "run_id": run_id,
            "job_id": job_id,
            "error": str(e),
        }

29
detect/graph/__init__.py Normal file
View File

@@ -0,0 +1,29 @@
"""
Detection pipeline graph.
detect/graph/
nodes.py — node functions (one per stage)
events.py — graph_update SSE emission
runner.py — pipeline execution (LangGraph wrapper, checkpoint, cancel)
"""
from .nodes import NODES, NODE_FUNCTIONS
from .runner import (
PipelineCancelled,
build_graph,
clear_cancel_check,
get_pipeline,
set_cancel_check,
)
from .events import _node_states
__all__ = [
"NODES",
"NODE_FUNCTIONS",
"PipelineCancelled",
"build_graph",
"get_pipeline",
"set_cancel_check",
"clear_cancel_check",
"_node_states",
]

27
detect/graph/events.py Normal file
View File

@@ -0,0 +1,27 @@
"""
Graph event emission — node state tracking + SSE graph_update events.
"""
from __future__ import annotations
from detect import emit
from detect.state import DetectState
# Node statuses tracked across pipeline runs: job_id -> {node name: status}
_node_states: dict[str, dict[str, str]] = {}


def emit_transition(state: DetectState, node: str, status: str, node_list: list[str]):
    """
    Record a node's new status and emit a graph_update SSE event.

    Args:
        state: Pipeline state; only its ``job_id`` entry is read.
        node: Name of the node whose status changed.
        status: New status value to store for ``node``.
        node_list: Node names to report in the event, in display order.

    Does nothing when the state carries no (or an empty) ``job_id``.
    """
    job_id = state.get("job_id")
    if not job_id:
        return

    statuses = _node_states.setdefault(job_id, {})
    # Seed any node not seen yet for this job. Besides initialising the first
    # call, this keeps a later call with a different node_list from raising
    # KeyError when the event payload is built below.
    for name in node_list:
        statuses.setdefault(name, "pending")
    statuses[node] = status

    nodes = [{"id": name, "status": statuses[name]} for name in node_list]
    emit.graph_update(job_id, nodes)

View File

@@ -1,16 +1,13 @@
"""
LangGraph pipeline graph for brand detection.
Pipeline node functions — one per stage.
Nodes execute real logic for extract+filter, stubs for the rest.
Each node emits graph_update events so the UI can visualize transitions.
Each node: reads state, runs stage logic, emits transitions, returns output dict.
"""
from __future__ import annotations
import os
from langgraph.graph import END, StateGraph
from detect import emit
from detect.models import PipelineStats
from detect.profiles import SoccerBroadcastProfile
@@ -27,6 +24,8 @@ from detect.stages.vlm_cloud import escalate_cloud
from detect.stages.aggregator import compile_report
from detect.tracing import trace_node, flush as flush_traces
from .events import emit_transition
INFERENCE_URL = os.environ.get("INFERENCE_URL") # None = local mode
NODES = [
@@ -58,41 +57,24 @@ def _get_profile(state: DetectState):
return profile
# Track node states across the pipeline run
_node_states: dict[str, dict[str, str]] = {}
def _emit_transition(state: DetectState, node: str, status: str):
job_id = state.get("job_id")
if not job_id:
return
# Initialize state tracking for this job
if job_id not in _node_states:
_node_states[job_id] = {n: "pending" for n in NODES}
_node_states[job_id][node] = status
nodes = [{"id": n, "status": _node_states[job_id][n]} for n in NODES]
emit.graph_update(job_id, nodes)
def _emit(state, node, status):
emit_transition(state, node, status, NODES)
# --- Node functions ---
def node_extract_frames(state: DetectState) -> dict:
# Set run context for initial runs (replays set it in replay_from)
job_id = state.get("job_id", "")
if job_id and not emit._run_context:
emit.set_run_context(run_id=job_id, parent_job_id=job_id, run_type="initial")
# Load session brands from DB for this source
source_asset_id = state.get("source_asset_id")
if source_asset_id and not state.get("session_brands"):
from detect.stages.brand_resolver import build_session_dict
session_brands = build_session_dict(source_asset_id)
state["session_brands"] = session_brands
_emit_transition(state, "extract_frames", "running")
_emit(state, "extract_frames", "running")
with trace_node(state, "extract_frames") as span:
profile = _get_profile(state)
@@ -100,12 +82,12 @@ def node_extract_frames(state: DetectState) -> dict:
frames = extract_frames(state["video_path"], config, job_id=state.get("job_id"))
span.set_output({"frames_extracted": len(frames)})
_emit_transition(state, "extract_frames", "done")
_emit(state, "extract_frames", "done")
return {"frames": frames, "stats": PipelineStats(frames_extracted=len(frames))}
def node_filter_scenes(state: DetectState) -> dict:
_emit_transition(state, "filter_scenes", "running")
_emit(state, "filter_scenes", "running")
with trace_node(state, "filter_scenes") as span:
profile = _get_profile(state)
@@ -117,12 +99,12 @@ def node_filter_scenes(state: DetectState) -> dict:
stats = state.get("stats", PipelineStats())
stats.frames_after_scene_filter = len(kept)
_emit_transition(state, "filter_scenes", "done")
_emit(state, "filter_scenes", "done")
return {"filtered_frames": kept, "stats": stats}
def node_detect_edges(state: DetectState) -> dict:
_emit_transition(state, "detect_edges", "running")
_emit(state, "detect_edges", "running")
with trace_node(state, "detect_edges") as span:
profile = _get_profile(state)
@@ -139,12 +121,12 @@ def node_detect_edges(state: DetectState) -> dict:
stats = state.get("stats", PipelineStats())
stats.cv_regions_detected = total
_emit_transition(state, "detect_edges", "done")
_emit(state, "detect_edges", "done")
return {"edge_regions_by_frame": regions, "stats": stats}
def node_detect_objects(state: DetectState) -> dict:
_emit_transition(state, "detect_objects", "running")
_emit(state, "detect_objects", "running")
with trace_node(state, "detect_objects") as span:
profile = _get_profile(state)
@@ -159,12 +141,12 @@ def node_detect_objects(state: DetectState) -> dict:
stats = state.get("stats", PipelineStats())
stats.regions_detected = total_regions
_emit_transition(state, "detect_objects", "done")
_emit(state, "detect_objects", "done")
return {"boxes_by_frame": all_boxes, "stats": stats}
def node_preprocess(state: DetectState) -> dict:
_emit_transition(state, "preprocess", "running")
_emit(state, "preprocess", "running")
with trace_node(state, "preprocess") as span:
profile = _get_profile(state)
@@ -172,7 +154,6 @@ def node_preprocess(state: DetectState) -> dict:
boxes = state.get("boxes_by_frame", {})
job_id = state.get("job_id")
# Get preprocessing config from profile overrides or defaults
overrides = state.get("config_overrides", {})
prep_config = overrides.get("preprocessing", {})
do_contrast = prep_config.get("contrast", True)
@@ -189,12 +170,12 @@ def node_preprocess(state: DetectState) -> dict:
)
span.set_output({"regions_preprocessed": len(preprocessed)})
_emit_transition(state, "preprocess", "done")
_emit(state, "preprocess", "done")
return {"preprocessed_crops": preprocessed}
def node_run_ocr(state: DetectState) -> dict:
_emit_transition(state, "run_ocr", "running")
_emit(state, "run_ocr", "running")
with trace_node(state, "run_ocr") as span:
profile = _get_profile(state)
@@ -209,12 +190,12 @@ def node_run_ocr(state: DetectState) -> dict:
stats = state.get("stats", PipelineStats())
stats.regions_resolved_by_ocr = len(candidates)
_emit_transition(state, "run_ocr", "done")
_emit(state, "run_ocr", "done")
return {"text_candidates": candidates, "stats": stats}
def node_match_brands(state: DetectState) -> dict:
_emit_transition(state, "match_brands", "running")
_emit(state, "match_brands", "running")
with trace_node(state, "match_brands") as span:
profile = _get_profile(state)
@@ -232,12 +213,12 @@ def node_match_brands(state: DetectState) -> dict:
)
span.set_output({"matched": len(matched), "unresolved": len(unresolved)})
_emit_transition(state, "match_brands", "done")
_emit(state, "match_brands", "done")
return {"detections": matched, "unresolved_candidates": unresolved}
def node_escalate_vlm(state: DetectState) -> dict:
_emit_transition(state, "escalate_vlm", "running")
_emit(state, "escalate_vlm", "running")
with trace_node(state, "escalate_vlm") as span:
profile = _get_profile(state)
@@ -261,7 +242,7 @@ def node_escalate_vlm(state: DetectState) -> dict:
existing = state.get("detections", [])
vlm_skipped = os.environ.get("SKIP_VLM", "").strip() == "1"
_emit_transition(state, "escalate_vlm", "skipped" if vlm_skipped else "done")
_emit(state, "escalate_vlm", "skipped" if vlm_skipped else "done")
return {
"detections": existing + vlm_matched,
"unresolved_candidates": still_unresolved,
@@ -270,7 +251,7 @@ def node_escalate_vlm(state: DetectState) -> dict:
def node_escalate_cloud(state: DetectState) -> dict:
_emit_transition(state, "escalate_cloud", "running")
_emit(state, "escalate_cloud", "running")
with trace_node(state, "escalate_cloud") as span:
profile = _get_profile(state)
@@ -294,12 +275,12 @@ def node_escalate_cloud(state: DetectState) -> dict:
existing = state.get("detections", [])
cloud_skipped = os.environ.get("SKIP_CLOUD", "").strip() == "1"
_emit_transition(state, "escalate_cloud", "skipped" if cloud_skipped else "done")
_emit(state, "escalate_cloud", "skipped" if cloud_skipped else "done")
return {"detections": existing + cloud_matched, "stats": stats}
def node_compile_report(state: DetectState) -> dict:
_emit_transition(state, "compile_report", "running")
_emit(state, "compile_report", "running")
with trace_node(state, "compile_report") as span:
profile = _get_profile(state)
@@ -318,85 +299,10 @@ def node_compile_report(state: DetectState) -> dict:
span.set_output({"brands": len(report.brands), "detections": len(report.timeline)})
flush_traces()
_emit_transition(state, "compile_report", "done")
_emit(state, "compile_report", "done")
return {"report": report}
# --- Checkpoint wrapper ---
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
_frames_manifest: dict[str, dict[int, str]] = {} # job_id → manifest (cached per job)
_latest_checkpoint: dict[str, str] = {} # job_id → latest checkpoint_id
class PipelineCancelled(Exception):
"""Raised when a pipeline run is cancelled."""
pass
# Cancellation hook — set by the run endpoint, checked before each node
_cancel_check: dict[str, callable] = {}
def set_cancel_check(job_id: str, fn):
_cancel_check[job_id] = fn
def clear_cancel_check(job_id: str):
_cancel_check.pop(job_id, None)
def _checkpointing_node(node_name: str, node_fn):
"""Wrap a node function to auto-checkpoint after completion."""
stage_index = NODES.index(node_name)
def wrapper(state: DetectState) -> dict:
job_id = state.get("job_id", "")
check = _cancel_check.get(job_id)
if check and check():
raise PipelineCancelled(f"Cancelled before {node_name}")
result = node_fn(state)
job_id = state.get("job_id", "")
if not job_id:
return result
from detect.checkpoint import save_stage_output, save_frames
from detect.stages.base import _REGISTRY
merged = {**state, **result}
# Save frames once (first node), reuse manifest after
manifest = _frames_manifest.get(job_id)
if manifest is None and node_name == "extract_frames":
manifest = save_frames(job_id, merged.get("frames", []))
_frames_manifest[job_id] = manifest
# Serialize stage output using the stage's serialize_fn if available
stage_cls = _REGISTRY.get(node_name)
serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
if serialize_fn:
output_json = serialize_fn(merged, job_id)
else:
output_json = {}
parent_id = _latest_checkpoint.get(job_id)
new_checkpoint_id = save_stage_output(
timeline_id=job_id,
parent_checkpoint_id=parent_id,
stage_name=node_name,
output_json=output_json,
)
_latest_checkpoint[job_id] = new_checkpoint_id
return result
wrapper.__name__ = node_fn.__name__
return wrapper
# --- Graph construction ---
NODE_FUNCTIONS = [
("extract_frames", node_extract_frames),
("filter_scenes", node_filter_scenes),
@@ -409,41 +315,3 @@ NODE_FUNCTIONS = [
("escalate_cloud", node_escalate_cloud),
("compile_report", node_compile_report),
]
def build_graph(checkpoint: bool | None = None, start_from: str | None = None) -> StateGraph:
"""
Build the pipeline graph.
checkpoint: enable auto-checkpointing (default: MPR_CHECKPOINT env var)
start_from: skip nodes before this stage (for replay)
"""
do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
graph = StateGraph(DetectState)
# Filter to start_from if replaying
node_pairs = NODE_FUNCTIONS
if start_from:
start_idx = next(i for i, (name, _) in enumerate(NODE_FUNCTIONS) if name == start_from)
node_pairs = NODE_FUNCTIONS[start_idx:]
for name, fn in node_pairs:
wrapped = _checkpointing_node(name, fn) if do_checkpoint else fn
graph.add_node(name, wrapped)
# Wire edges
entry = node_pairs[0][0]
graph.set_entry_point(entry)
for i in range(len(node_pairs) - 1):
graph.add_edge(node_pairs[i][0], node_pairs[i + 1][0])
graph.add_edge(node_pairs[-1][0], END)
return graph
def get_pipeline(checkpoint: bool | None = None):
"""Return a compiled, runnable pipeline."""
return build_graph(checkpoint=checkpoint).compile()

127
detect/graph/runner.py Normal file
View File

@@ -0,0 +1,127 @@
"""
Pipeline runner — executes stages sequentially with checkpointing and cancellation.
Currently wraps LangGraph for execution. Will be replaced with a lean
custom runner in Phase 3, with an executor socket for distributed dispatch.
"""
from __future__ import annotations
import os
from langgraph.graph import END, StateGraph
from detect.state import DetectState
from .nodes import NODES, NODE_FUNCTIONS
# --- Checkpoint wrapper ---
# Default for auto-checkpointing: enabled only when the MPR_CHECKPOINT
# environment variable is exactly "1" (surrounding whitespace is ignored).
# Read once, at import time.
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
# job_id → frames manifest returned by save_frames (cached per job)
_frames_manifest: dict[str, dict[int, str]] = {}
# job_id → latest checkpoint_id; passed as the parent of the job's next checkpoint
_latest_checkpoint: dict[str, str] = {}
class PipelineCancelled(Exception):
    """Signals that a pipeline run was cancelled."""
# Cancellation hooks keyed by job_id — registered by the run endpoint and
# consulted before each node executes.
_cancel_check: dict[str, callable] = {}


def set_cancel_check(job_id: str, fn):
    """Register ``fn`` as the cancellation check for ``job_id``, replacing any previous one."""
    _cancel_check.update({job_id: fn})


def clear_cancel_check(job_id: str):
    """Drop the cancellation check for ``job_id``; unknown job ids are ignored."""
    try:
        del _cancel_check[job_id]
    except KeyError:
        pass
def _checkpointing_node(node_name: str, node_fn):
    """
    Wrap a node function with a cancellation check and auto-checkpointing.

    The returned wrapper:
      1. raises PipelineCancelled before running the node when the cancel
         check registered for the state's job_id returns a truthy value;
      2. runs ``node_fn``;
      3. if the state has a job_id, saves the stage output as a new
         checkpoint whose parent is the job's previous checkpoint
         (``save_frames`` is called once per job, on ``extract_frames``);
      4. returns the node's result unchanged.

    The wrapper takes over ``node_fn``'s ``__name__``.
    """

    def wrapper(state: DetectState) -> dict:
        cancel_requested = _cancel_check.get(state.get("job_id", ""))
        if cancel_requested and cancel_requested():
            raise PipelineCancelled(f"Cancelled before {node_name}")

        result = node_fn(state)

        # No job id means there is nothing to key a checkpoint on.
        job_id = state.get("job_id", "")
        if not job_id:
            return result

        from detect.checkpoint import save_stage_output, save_frames
        from detect.stages.base import _REGISTRY

        # View of the state as it looks after this node's output is applied.
        merged = {**state, **result}

        # Frames are saved by the first stage only, and only once per job;
        # the manifest is cached for the rest of the run.
        if node_name == "extract_frames" and _frames_manifest.get(job_id) is None:
            _frames_manifest[job_id] = save_frames(job_id, merged.get("frames", []))

        # Use the stage's own serialize_fn when it has one; otherwise the
        # checkpoint is stored with an empty payload.
        definition = getattr(_REGISTRY.get(node_name), "definition", None)
        serialize_fn = getattr(definition, "serialize_fn", None)
        output_json = serialize_fn(merged, job_id) if serialize_fn else {}

        # Chain the new checkpoint to the previous one and remember it.
        _latest_checkpoint[job_id] = save_stage_output(
            timeline_id=job_id,
            parent_checkpoint_id=_latest_checkpoint.get(job_id),
            stage_name=node_name,
            output_json=output_json,
        )
        return result

    wrapper.__name__ = node_fn.__name__
    return wrapper
# --- Graph construction ---
def build_graph(checkpoint: bool | None = None, start_from: str | None = None) -> StateGraph:
    """
    Build the pipeline graph as a linear chain of stages.

    Args:
        checkpoint: Enable auto-checkpointing. ``None`` (default) falls back
            to the MPR_CHECKPOINT environment setting.
        start_from: Skip every stage before this one (used for replay).
            ``None`` or an empty string builds the full pipeline.

    Returns:
        The uncompiled StateGraph; call ``.compile()`` to run it.

    Raises:
        ValueError: If ``start_from`` is not the name of a known stage.
    """
    do_checkpoint = _CHECKPOINT_ENABLED if checkpoint is None else checkpoint

    # Filter to start_from if replaying
    node_pairs = NODE_FUNCTIONS
    if start_from:
        stage_names = [name for name, _ in NODE_FUNCTIONS]
        if start_from not in stage_names:
            # Fail with a clear message instead of leaking StopIteration.
            raise ValueError(
                f"Unknown start_from stage {start_from!r}; expected one of {stage_names}"
            )
        node_pairs = NODE_FUNCTIONS[stage_names.index(start_from):]

    graph = StateGraph(DetectState)

    # NOTE(review): the cancel check lives inside the checkpoint wrapper, so
    # with checkpointing disabled nodes are added unwrapped — confirm that
    # cancellation is meant to be inactive in that mode.
    for name, fn in node_pairs:
        graph.add_node(name, _checkpointing_node(name, fn) if do_checkpoint else fn)

    # Wire edges: entry -> each stage in order -> END
    graph.set_entry_point(node_pairs[0][0])
    for (src, _), (dst, _) in zip(node_pairs, node_pairs[1:]):
        graph.add_edge(src, dst)
    graph.add_edge(node_pairs[-1][0], END)

    return graph
def get_pipeline(checkpoint: bool | None = None):
    """Build the full pipeline graph and return it compiled, ready to run."""
    graph = build_graph(checkpoint=checkpoint)
    return graph.compile()