diff --git a/core/api/chunker_sse.py b/core/api/chunker_sse.py
index 2958e11..b684cbc 100644
--- a/core/api/chunker_sse.py
+++ b/core/api/chunker_sse.py
@@ -1,8 +1,8 @@
 """
 SSE endpoint for chunker pipeline events.
 
-Uses Redis as the event bus between Celery workers and the SSE stream.
-Celery worker pushes events via core.events, SSE endpoint polls them.
+Uses Redis as the event bus. Pipeline pushes events via core.events,
+SSE endpoint polls them.
 
 GET /chunker/stream/{job_id} → text/event-stream
 """
diff --git a/core/api/detect/__init__.py b/core/api/detect/__init__.py
new file mode 100644
index 0000000..1b551c1
--- /dev/null
+++ b/core/api/detect/__init__.py
@@ -0,0 +1,20 @@
+"""
+Detection API — aggregated router.
+
+Combines all detect sub-routers into a single include for main.py.
+"""
+
+from fastapi import APIRouter
+
+from .sources import router as sources_router
+from .run import router as run_router
+from .sse import router as sse_router
+from .replay import router as replay_router
+from .config import router as config_router
+
+router = APIRouter()
+router.include_router(sources_router)
+router.include_router(run_router)
+router.include_router(sse_router)
+router.include_router(replay_router)
+router.include_router(config_router)
diff --git a/core/api/detect_config.py b/core/api/detect/config.py
similarity index 100%
rename from core/api/detect_config.py
rename to core/api/detect/config.py
diff --git a/core/api/detect_replay.py b/core/api/detect/replay.py
similarity index 100%
rename from core/api/detect_replay.py
rename to core/api/detect/replay.py
diff --git a/core/api/detect_sources.py b/core/api/detect/run.py
similarity index 54%
rename from core/api/detect_sources.py
rename to core/api/detect/run.py
index 697f0be..d3049d5 100644
--- a/core/api/detect_sources.py
+++ b/core/api/detect/run.py
@@ -1,19 +1,9 @@
 """
-Source browser for detection pipeline.
+Pipeline run endpoints.
 
-Lists available media sources from blob storage (MinIO).
-All file-based sources go through MinIO — no host filesystem access.
-The pipeline downloads chunks to a temp path before processing.
-
-Source types (current and future):
-  - chunk_job: pre-chunked segments in MinIO (current)
-  - upload: user-uploaded file, lands in MinIO via upload endpoint (future)
-  - device: local camera/capture card via ffmpeg, no MinIO (future)
-  - stream: RTMP/HLS URL via ffmpeg, no MinIO (future)
-
-GET /detect/sources — list chunk jobs from blob store
-GET /detect/sources/{job_id}/chunks — list chunks for a specific job
-POST /detect/run — launch pipeline on selected source
+POST /detect/run — launch pipeline on selected source
+POST /detect/stop/{job_id} — cancel a running pipeline
+POST /detect/clear/{job_id} — clear events from Redis
 """
 
 from __future__ import annotations
@@ -31,23 +21,10 @@ logger = logging.getLogger(__name__)
 router = APIRouter(prefix="/detect", tags=["detect"])
 
 # In-process pipeline tracking
-_running_jobs: dict[str, "threading.Thread"] = {}
+_running_jobs: dict[str, threading.Thread] = {}
 _cancelled_jobs: set[str] = set()
 
 
-class ChunkInfo(BaseModel):
-    filename: str
-    key: str
-    size_bytes: int
-
-
-class SourceInfo(BaseModel):
-    job_id: str
-    source_type: str = "chunk_job"
-    chunk_count: int
-    total_bytes: int = 0
-
-
 class RunRequest(BaseModel):
     video_path: str  # storage key
     profile_name: str = "soccer_broadcast"
@@ -64,91 +41,6 @@ class RunResponse(BaseModel):
     video_path: str
 
 
-# ---------------------------------------------------------------------------
-# Source listing
-# ---------------------------------------------------------------------------
-
-def _list_sources() -> list[SourceInfo]:
-    """List chunk jobs from blob storage."""
-    from core.storage.blob import get_store
-
-    store = get_store("out")
-    try:
-        objects = store.list(prefix="chunks/")
-    except Exception as e:
-        logger.warning("Failed to list blob sources: %s", e)
-        return []
-
-    jobs: dict[str, int] = {}
-    job_bytes: dict[str, int] = {}
-    for obj in objects:
-        # Keys include store prefix: out/chunks/{job_id}/file.mp4
-        # Strip prefix to get: chunks/{job_id}/file.mp4
-        rel_key = obj.key.removeprefix(store.prefix)
-        parts = rel_key.split("/")
-        if len(parts) >= 3 and parts[0] == "chunks":
-            job_id = parts[1]
-            jobs[job_id] = jobs.get(job_id, 0) + 1
-            job_bytes[job_id] = job_bytes.get(job_id, 0) + obj.size_bytes
-
-    sources = []
-    for job_id, count in sorted(jobs.items()):
-        source = SourceInfo(
-            job_id=job_id,
-            source_type="chunk_job",
-            chunk_count=count,
-            total_bytes=job_bytes.get(job_id, 0),
-        )
-        sources.append(source)
-    return sources
-
-
-@router.get("/sources", response_model=list[SourceInfo])
-def list_sources():
-    """List available chunk jobs from blob storage."""
-    return _list_sources()
-
-
-@router.get("/sources/{source_job_id}/chunks", response_model=list[ChunkInfo])
-def list_chunks(source_job_id: str):
-    """List chunks for a specific source job."""
-    from core.storage.blob import get_store
-
-    store = get_store("out")
-    try:
-        objects = store.list(prefix=f"chunks/{source_job_id}/", extensions={".mp4"})
-    except Exception as e:
-        logger.warning("Failed to list chunks for %s: %s", source_job_id, e)
-        raise HTTPException(status_code=503, detail=f"Blob storage unavailable: {e}")
-
-    if not objects:
-        raise HTTPException(status_code=404, detail=f"Source not found: {source_job_id}")
-
-    chunks = []
-    for obj in objects:
-        info = ChunkInfo(filename=obj.filename, key=obj.key, size_bytes=obj.size_bytes)
-        chunks.append(info)
-    return sorted(chunks, key=lambda c: c.filename)
-
-
-@router.get("/sources/{source_job_id}/chunks/{filename}/url")
-def get_chunk_url(source_job_id: str, filename: str):
"""Return a presigned URL for previewing a chunk in the browser.""" - from core.storage.blob import get_store - - store = get_store("out") - key = f"chunks/{source_job_id}/{filename}" - try: - url = store.get_url(key, expires=3600) - except Exception as e: - raise HTTPException(status_code=503, detail=f"Could not generate URL: {e}") - return {"url": url} - - -# --------------------------------------------------------------------------- -# Run pipeline -# --------------------------------------------------------------------------- - def _resolve_video_path(video_path: str) -> str: """Download a chunk from blob storage to a temp file.""" from core.storage.blob import get_store @@ -216,7 +108,6 @@ def run_pipeline(req: RunRequest): emit.job_complete(job_id, {"status": "cancelled"}) except Exception as e: logger.exception("Pipeline run %s failed: %s", job_id, e) - # Mark the current/last stage as error in the graph from detect.graph import _node_states, NODES if job_id in _node_states: states = _node_states[job_id] diff --git a/core/api/detect/sources.py b/core/api/detect/sources.py new file mode 100644 index 0000000..25852ee --- /dev/null +++ b/core/api/detect/sources.py @@ -0,0 +1,108 @@ +""" +Source browser for detection pipeline. + +Lists available media sources from blob storage (MinIO). + +GET /detect/sources — list chunk jobs +GET /detect/sources/{job_id}/chunks — list chunks for a job +GET /detect/sources/{job_id}/chunks/{name}/url — presigned preview URL +""" + +from __future__ import annotations + +import logging + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/detect", tags=["detect"]) + + +class ChunkInfoResponse(BaseModel): + filename: str + key: str + size_bytes: int + + +class SourceInfoResponse(BaseModel): + job_id: str + source_type: str = "chunk_job" + chunk_count: int + total_bytes: int = 0 + + +def _list_sources() -> list[SourceInfoResponse]: + """List chunk jobs from blob storage.""" + from core.storage.blob import get_store + + store = get_store("out") + try: + objects = store.list(prefix="chunks/") + except Exception as e: + logger.warning("Failed to list blob sources: %s", e) + return [] + + jobs: dict[str, int] = {} + job_bytes: dict[str, int] = {} + for obj in objects: + rel_key = obj.key.removeprefix(store.prefix) + parts = rel_key.split("/") + if len(parts) >= 3 and parts[0] == "chunks": + job_id = parts[1] + jobs[job_id] = jobs.get(job_id, 0) + 1 + job_bytes[job_id] = job_bytes.get(job_id, 0) + obj.size_bytes + + sources = [] + for job_id, count in sorted(jobs.items()): + source = SourceInfoResponse( + job_id=job_id, + source_type="chunk_job", + chunk_count=count, + total_bytes=job_bytes.get(job_id, 0), + ) + sources.append(source) + return sources + + +@router.get("/sources", response_model=list[SourceInfoResponse]) +def list_sources(): + """List available chunk jobs from blob storage.""" + return _list_sources() + + +@router.get("/sources/{source_job_id}/chunks", response_model=list[ChunkInfoResponse]) +def list_chunks(source_job_id: str): + """List chunks for a specific source job.""" + from core.storage.blob import get_store + + store = get_store("out") + try: + objects = store.list(prefix=f"chunks/{source_job_id}/", extensions={".mp4"}) + except Exception as e: + logger.warning("Failed to list chunks for %s: %s", source_job_id, e) + raise HTTPException(status_code=503, detail=f"Blob storage unavailable: {e}") + + if not objects: + raise 
+        raise HTTPException(status_code=404, detail=f"Source not found: {source_job_id}")
+
+    chunks = []
+    for obj in objects:
+        info = ChunkInfoResponse(filename=obj.filename, key=obj.key, size_bytes=obj.size_bytes)
+        chunks.append(info)
+    return sorted(chunks, key=lambda c: c.filename)
+
+
+@router.get("/sources/{source_job_id}/chunks/{filename}/url")
+def get_chunk_url(source_job_id: str, filename: str):
+    """Return a presigned URL for previewing a chunk in the browser."""
+    from core.storage.blob import get_store
+
+    store = get_store("out")
+    key = f"chunks/{source_job_id}/{filename}"
+    try:
+        url = store.get_url(key, expires=3600)
+    except Exception as e:
+        raise HTTPException(status_code=503, detail=f"Could not generate URL: {e}")
+    return {"url": url}
diff --git a/core/api/detect_sse.py b/core/api/detect/sse.py
similarity index 100%
rename from core/api/detect_sse.py
rename to core/api/detect/sse.py
diff --git a/core/api/main.py b/core/api/main.py
index d3c38b8..ad76427 100644
--- a/core/api/main.py
+++ b/core/api/main.py
@@ -19,10 +19,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from strawberry.fastapi import GraphQLRouter
 
 from core.api.chunker_sse import router as chunker_router
-from core.api.detect_sse import router as detect_router
-from core.api.detect_replay import router as detect_replay_router
-from core.api.detect_config import router as detect_config_router
-from core.api.detect_sources import router as detect_sources_router
+from core.api.detect import router as detect_router
 from core.api.graphql import schema as graphql_schema
 
 CALLBACK_API_KEY = os.environ.get("CALLBACK_API_KEY", "")
@@ -61,18 +58,9 @@ app.include_router(graphql_router, prefix="/graphql")
 # Chunker SSE
 app.include_router(chunker_router)
 
-# Detection SSE
+# Detection API (sources, run, SSE, replay, config)
 app.include_router(detect_router)
 
-# Detection replay/retry
-app.include_router(detect_replay_router)
-
-# Detection config
-app.include_router(detect_config_router)
-
-# Detection sources + run launcher
-app.include_router(detect_sources_router)
-
 
 @app.get("/health")
 def health():
diff --git a/core/db/tables.py b/core/db/tables.py
index d1d8489..c6f59c6 100644
--- a/core/db/tables.py
+++ b/core/db/tables.py
@@ -48,7 +48,6 @@ class Job(SQLModel, table=True):
     brands_found: int = 0
     cloud_llm_calls: int = 0
     estimated_cost_usd: float = 0.0
-    celery_task_id: Optional[str] = None
     priority: int = 0
     created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
     started_at: Optional[datetime] = None
diff --git a/core/events.py b/core/events.py
index a6b30cb..762ebec 100644
--- a/core/events.py
+++ b/core/events.py
@@ -1,7 +1,7 @@
 """
 Redis-based event bus for pipeline job progress.
 
-Celery workers push events, SSE endpoints poll them.
+Pipeline stages push events, SSE endpoints poll them.
 Only depends on redis — safe to import from any context.
 """
diff --git a/core/jobs/__init__.py b/core/jobs/__init__.py
index 8827db9..c3be458 100644
--- a/core/jobs/__init__.py
+++ b/core/jobs/__init__.py
@@ -1,15 +1,13 @@
 """
 MPR Jobs Module
 
-Provides executor abstraction and task dispatch for job processing.
+Provides executor abstraction for job dispatch (local, Lambda, GCP).
""" from .executor import Executor, LocalExecutor, get_executor -from .task import run_job __all__ = [ "Executor", "LocalExecutor", "get_executor", - "run_job", ] diff --git a/core/jobs/executor.py b/core/jobs/executor.py index ef1e6cd..42f545e 100644 --- a/core/jobs/executor.py +++ b/core/jobs/executor.py @@ -42,7 +42,7 @@ class Executor(ABC): class LocalExecutor(Executor): - """Execute jobs locally using registered handlers.""" + """Execute jobs locally by calling the stage function directly.""" def run( self, @@ -51,16 +51,10 @@ class LocalExecutor(Executor): payload: Dict[str, Any], progress_callback: Optional[Callable[[int, Dict[str, Any]], None]] = None, ) -> bool: - """Execute job using the appropriate local handler.""" - from .registry import get_handler - - handler = get_handler(job_type) - result = handler.process( - job_id=job_id, - payload=payload, - progress_callback=progress_callback, + """Execute job locally. Socket for PipelineRunner integration.""" + raise NotImplementedError( + "LocalExecutor.run() — will be wired to PipelineRunner in Phase 3" ) - return result.get("status") == "completed" class LambdaExecutor(Executor): diff --git a/core/jobs/handlers/__init__.py b/core/jobs/handlers/__init__.py deleted file mode 100644 index 8ac89c6..0000000 --- a/core/jobs/handlers/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Job handlers — type-specific execution logic.""" - -from .base import Handler - -__all__ = ["Handler"] diff --git a/core/jobs/handlers/base.py b/core/jobs/handlers/base.py deleted file mode 100644 index f47328f..0000000 --- a/core/jobs/handlers/base.py +++ /dev/null @@ -1,33 +0,0 @@ -""" -Base Handler ABC — defines the interface for job-type-specific execution logic. - -A Handler knows HOW to execute a specific kind of job (transcode, chunk, etc.). -The Executor decides WHERE to run it (local, Lambda, GCP). -""" - -from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Optional - - -class Handler(ABC): - """Abstract base class for job handlers.""" - - @abstractmethod - def process( - self, - job_id: str, - payload: Dict[str, Any], - progress_callback: Optional[Callable[[int, Dict[str, Any]], None]] = None, - ) -> Dict[str, Any]: - """ - Execute job-specific logic. - - Args: - job_id: Unique job identifier - payload: Job-type-specific configuration - progress_callback: Called with (percent, details_dict) - - Returns: - Result dict with at least {"status": "completed"} or raises - """ - pass diff --git a/core/jobs/handlers/chunk.py b/core/jobs/handlers/chunk.py deleted file mode 100644 index 352a5fe..0000000 --- a/core/jobs/handlers/chunk.py +++ /dev/null @@ -1,125 +0,0 @@ -""" -ChunkHandler — job handler that wraps the chunker Pipeline. - -Downloads source from S3/MinIO, runs FFmpeg chunking pipeline, -writes mp4 segments + manifest to media/out/chunks/{job_id}/. -Pushes real-time events to Redis for SSE consumption. -""" - -import logging -import os -from typing import Any, Callable, Dict, Optional - -from core.events import push_event as push_chunk_event -from core.chunker import Pipeline -from core.storage import BUCKET_IN, download_to_temp - -from .base import Handler - -logger = logging.getLogger(__name__) - -MEDIA_OUT_DIR = os.environ.get("MEDIA_OUT_DIR", "/app/media/out") - - -class ChunkHandler(Handler): - """ - Handles chunk processing jobs by delegating to the chunker Pipeline. 
-
-    Expected payload keys:
-        source_key: str — S3 key of the source file in BUCKET_IN
-        chunk_duration: float — seconds per chunk (default: 10.0)
-        num_workers: int — concurrent workers (default: 4)
-        max_retries: int — retries per chunk (default: 3)
-        processor_type: str — "ffmpeg", "checksum", "simulated_decode", "composite"
-        queue_size: int — max queue depth (default: 10)
-    """
-
-    def process(
-        self,
-        job_id: str,
-        payload: Dict[str, Any],
-        progress_callback: Optional[Callable[[int, Dict[str, Any]], None]] = None,
-    ) -> Dict[str, Any]:
-        source_key = payload["source_key"]
-        processor_type = payload.get("processor_type", "ffmpeg")
-
-        logger.info(f"ChunkHandler starting job {job_id}: {source_key}")
-
-        # Download source from S3/MinIO
-        push_chunk_event(job_id, "pipeline_start", {"status": "downloading", "source_key": source_key})
-        tmp_source = download_to_temp(BUCKET_IN, source_key)
-
-        # Output directory: media/out/chunks/{job_id}/
-        output_dir = os.path.join(MEDIA_OUT_DIR, "chunks", job_id)
-        if processor_type == "ffmpeg":
-            os.makedirs(output_dir, exist_ok=True)
-
-        try:
-            def event_bridge(event_type: str, data: Dict[str, Any]) -> None:
-                """Bridge pipeline events to Redis + optional progress callback."""
-                push_chunk_event(job_id, event_type, data)
-
-                if progress_callback and event_type == "pipeline_complete":
-                    progress_callback(100, data)
-                elif progress_callback and event_type == "chunk_done":
-                    total = data.get("total_chunks", 1)
-                    if total > 0:
-                        pct = min(int((data.get("sequence", 0) + 1) / total * 100), 99)
-                        progress_callback(pct, data)
-
-            pipeline = Pipeline(
-                source=tmp_source,
-                chunk_duration=payload.get("chunk_duration", 10.0),
-                num_workers=payload.get("num_workers", 4),
-                max_retries=payload.get("max_retries", 3),
-                processor_type=processor_type,
-                queue_size=payload.get("queue_size", 10),
-                event_callback=event_bridge,
-                output_dir=output_dir if processor_type == "ffmpeg" else None,
-                start_time=payload.get("start_time"),
-                end_time=payload.get("end_time"),
-            )
-
-            result = pipeline.run()
-
-            # Files are already in media/out/chunks/{job_id}/
-            output_prefix = f"chunks/{job_id}"
-            output_files = [
-                f"{output_prefix}/{os.path.basename(f)}"
-                for f in result.chunk_files
-            ]
-
-            push_chunk_event(job_id, "pipeline_complete", {
-                "status": "completed",
-                "total_chunks": result.total_chunks,
-                "processed": result.processed,
-                "failed": result.failed,
-                "elapsed": result.elapsed_time,
-                "throughput_mbps": result.throughput_mbps,
-            })
-
-            return {
-                "status": "completed" if result.failed == 0 else "completed_with_errors",
-                "total_chunks": result.total_chunks,
-                "processed": result.processed,
-                "failed": result.failed,
-                "retries": result.retries,
-                "elapsed_time": result.elapsed_time,
-                "throughput_mbps": result.throughput_mbps,
-                "worker_stats": result.worker_stats,
-                "errors": result.errors,
-                "chunks_in_order": result.chunks_in_order,
-                "output_prefix": output_prefix,
-                "output_files": output_files,
-            }
-
-        except Exception as e:
-            push_chunk_event(job_id, "pipeline_error", {"status": "failed", "error": str(e)})
-            raise
-
-        finally:
-            # Cleanup temp source file only (output dir is persistent)
-            try:
-                os.unlink(tmp_source)
-            except OSError:
-                pass
diff --git a/core/jobs/handlers/detect.py b/core/jobs/handlers/detect.py
deleted file mode 100644
index 67ce4e0..0000000
--- a/core/jobs/handlers/detect.py
+++ /dev/null
@@ -1,130 +0,0 @@
-"""
-DetectHandler — runs the detection pipeline as a Celery job.
-
-Supports three modes via payload:
-  - Initial run: {"video_path": "...", "profile_name": "..."}
-  - Replay: {"replay_from": "run_ocr", "source_job_id": "...", "config_overrides": {...}}
-  - Retry: {"retry_from": "escalate_vlm", "source_job_id": "...", "config_overrides": {...}}
-"""
-
-import logging
-import os
-import uuid
-from typing import Any, Callable, Dict, Optional
-
-from .base import Handler
-
-logger = logging.getLogger(__name__)
-
-
-class DetectHandler(Handler):
-
-    def process(
-        self,
-        job_id: str,
-        payload: Dict[str, Any],
-        progress_callback: Optional[Callable[[int, Dict[str, Any]], None]] = None,
-    ) -> Dict[str, Any]:
-
-        replay_from = payload.get("replay_from")
-        source_job_id = payload.get("source_job_id")
-
-        if replay_from and source_job_id:
-            return self._run_replay(job_id, source_job_id, replay_from, payload, progress_callback)
-
-        return self._run_initial(job_id, payload, progress_callback)
-
-    def _run_initial(
-        self,
-        job_id: str,
-        payload: Dict[str, Any],
-        progress_callback: Optional[Callable],
-    ) -> Dict[str, Any]:
-        from detect import emit
-        from detect.graph import get_pipeline
-        from detect.state import DetectState
-
-        video_path = payload["video_path"]
-        profile_name = payload.get("profile_name", "soccer_broadcast")
-        source_asset_id = payload.get("source_asset_id", "")
-        checkpoint_enabled = payload.get("checkpoint", os.environ.get("MPR_CHECKPOINT") == "1")
-
-        emit.set_run_context(
-            run_id=job_id,
-            parent_job_id=payload.get("parent_job_id", job_id),
-            run_type="initial",
-        )
-
-        logger.info("DetectHandler: initial run job=%s video=%s profile=%s checkpoint=%s",
-                    job_id, video_path, profile_name, checkpoint_enabled)
-
-        if progress_callback:
-            progress_callback(0, {"stage": "starting"})
-
-        pipeline = get_pipeline(checkpoint=checkpoint_enabled)
-
-        initial_state = DetectState(
-            video_path=video_path,
-            job_id=job_id,
-            profile_name=profile_name,
-            source_asset_id=source_asset_id,
-        )
-
-        try:
-            result = pipeline.invoke(initial_state)
-        finally:
-            emit.clear_run_context()
-
-        detections = result.get("detections", [])
-        report = result.get("report")
-        brands_found = len(report.brands) if report else 0
-
-        if progress_callback:
-            progress_callback(100, {"stage": "completed"})
-
-        return {
-            "status": "completed",
-            "job_id": job_id,
-            "detections": len(detections),
-            "brands_found": brands_found,
-        }
-
-    def _run_replay(
-        self,
-        job_id: str,
-        source_job_id: str,
-        start_stage: str,
-        payload: Dict[str, Any],
-        progress_callback: Optional[Callable],
-    ) -> Dict[str, Any]:
-        from detect.checkpoint import replay_from
-
-        config_overrides = payload.get("config_overrides", {})
-
-        logger.info("DetectHandler: replay job=%s from=%s source=%s overrides=%s",
-                    job_id, start_stage, source_job_id, config_overrides)
-
-        if progress_callback:
-            progress_callback(0, {"stage": f"replaying from {start_stage}"})
-
-        result = replay_from(
-            job_id=source_job_id,
-            start_stage=start_stage,
-            config_overrides=config_overrides,
-        )
-
-        detections = result.get("detections", [])
-        report = result.get("report")
-        brands_found = len(report.brands) if report else 0
-
-        if progress_callback:
-            progress_callback(100, {"stage": "completed"})
-
-        return {
-            "status": "completed",
-            "job_id": job_id,
-            "source_job_id": source_job_id,
-            "replay_from": start_stage,
-            "detections": len(detections),
-            "brands_found": brands_found,
-        }
diff --git a/core/jobs/handlers/transcode.py b/core/jobs/handlers/transcode.py
deleted file mode 100644
index 6371e2f..0000000
--- a/core/jobs/handlers/transcode.py
+++ /dev/null
@@ -1,104 +0,0 @@
-"""
-TranscodeHandler — executes transcode/trim jobs using FFmpeg.
-
-Extracted from the old tasks.py Celery task logic.
-"""
-
-import logging
-import os
-import tempfile
-from pathlib import Path
-from typing import Any, Callable, Dict, Optional
-
-from core.ffmpeg.transcode import TranscodeConfig, transcode
-from core.storage import BUCKET_IN, BUCKET_OUT, download_to_temp, upload_file
-
-from .base import Handler
-
-logger = logging.getLogger(__name__)
-
-
-class TranscodeHandler(Handler):
-    """Handle transcode and trim jobs via FFmpeg."""
-
-    def process(
-        self,
-        job_id: str,
-        payload: Dict[str, Any],
-        progress_callback: Optional[Callable[[int, Dict[str, Any]], None]] = None,
-    ) -> Dict[str, Any]:
-        source_key = payload["source_key"]
-        output_key = payload["output_key"]
-        preset = payload.get("preset")
-        trim_start = payload.get("trim_start")
-        trim_end = payload.get("trim_end")
-        duration = payload.get("duration")
-
-        logger.info(f"TranscodeHandler: {source_key} -> {output_key}")
-
-        # Download source
-        tmp_source = download_to_temp(BUCKET_IN, source_key)
-
-        ext = Path(output_key).suffix or ".mp4"
-        fd, tmp_output = tempfile.mkstemp(suffix=ext)
-        os.close(fd)
-
-        try:
-            if preset:
-                config = TranscodeConfig(
-                    input_path=tmp_source,
-                    output_path=tmp_output,
-                    video_codec=preset.get("video_codec", "libx264"),
-                    video_bitrate=preset.get("video_bitrate"),
-                    video_crf=preset.get("video_crf"),
-                    video_preset=preset.get("video_preset"),
-                    resolution=preset.get("resolution"),
-                    framerate=preset.get("framerate"),
-                    audio_codec=preset.get("audio_codec", "aac"),
-                    audio_bitrate=preset.get("audio_bitrate"),
-                    audio_channels=preset.get("audio_channels"),
-                    audio_samplerate=preset.get("audio_samplerate"),
-                    container=preset.get("container", "mp4"),
-                    extra_args=preset.get("extra_args", []),
-                    trim_start=trim_start,
-                    trim_end=trim_end,
-                )
-            else:
-                config = TranscodeConfig(
-                    input_path=tmp_source,
-                    output_path=tmp_output,
-                    video_codec="copy",
-                    audio_codec="copy",
-                    trim_start=trim_start,
-                    trim_end=trim_end,
-                )
-
-            def wrapped_callback(percent: float, details: Dict[str, Any]) -> None:
-                if progress_callback:
-                    progress_callback(int(percent), details)
-
-            success = transcode(
-                config,
-                duration=duration,
-                progress_callback=wrapped_callback if progress_callback else None,
-            )
-
-            if not success:
-                raise RuntimeError("Transcode returned False")
-
-            # Upload result
-            logger.info(f"Uploading {output_key} to {BUCKET_OUT}")
-            upload_file(tmp_output, BUCKET_OUT, output_key)
-
-            return {
-                "status": "completed",
-                "job_id": job_id,
-                "output_key": output_key,
-            }
-
-        finally:
-            for f in [tmp_source, tmp_output]:
-                try:
-                    os.unlink(f)
-                except OSError:
-                    pass
diff --git a/core/jobs/registry.py b/core/jobs/registry.py
deleted file mode 100644
index 9956e31..0000000
--- a/core/jobs/registry.py
+++ /dev/null
@@ -1,35 +0,0 @@
-"""
-Handler registry — maps job_type strings to Handler classes.
-""" - -from typing import Dict, Type - -from .handlers.base import Handler - -_handlers: Dict[str, Type[Handler]] = {} - - -def register_handler(job_type: str, handler_class: Type[Handler]) -> None: - """Register a handler class for a job type.""" - _handlers[job_type] = handler_class - - -def get_handler(job_type: str) -> Handler: - """Get an instantiated handler for a job type.""" - if job_type not in _handlers: - raise ValueError(f"Unknown job type: {job_type}") - return _handlers[job_type]() - - -def _register_defaults() -> None: - """Register built-in handlers.""" - from .handlers.chunk import ChunkHandler - from .handlers.transcode import TranscodeHandler - from .handlers.detect import DetectHandler - - register_handler("transcode", TranscodeHandler) - register_handler("chunk", ChunkHandler) - register_handler("detect", DetectHandler) - - -_register_defaults() diff --git a/core/jobs/task.py b/core/jobs/task.py deleted file mode 100644 index 4c0a60d..0000000 --- a/core/jobs/task.py +++ /dev/null @@ -1,64 +0,0 @@ -""" -Celery task for job processing. - -Generic dispatcher — routes to the appropriate handler based on job_type. -""" - -import logging -from typing import Any, Dict - -from celery import shared_task - -from core.rpc.server import update_job_progress - -logger = logging.getLogger(__name__) - - -@shared_task(bind=True, max_retries=3, default_retry_delay=60) -def run_job( - self, - job_type: str, - job_id: str, - payload: Dict[str, Any], -) -> Dict[str, Any]: - """ - Generic Celery task — dispatches to the registered handler for job_type. - """ - logger.info(f"Starting {job_type} job {job_id}") - - update_job_progress(job_id, progress=0, status="processing") - - def progress_callback(percent: int, details: Dict[str, Any]) -> None: - update_job_progress( - job_id, - progress=percent, - current_time=details.get("time", 0.0), - status="processing", - ) - - try: - from .registry import get_handler - - handler = get_handler(job_type) - result = handler.process( - job_id=job_id, - payload=payload, - progress_callback=progress_callback, - ) - - logger.info(f"Job {job_id} completed successfully") - update_job_progress(job_id, progress=100, status="completed") - return result - - except Exception as e: - logger.exception(f"Job {job_id} failed: {e}") - update_job_progress(job_id, progress=0, status="failed", error=str(e)) - - if self.request.retries < self.max_retries: - raise self.retry(exc=e) - - return { - "status": "failed", - "job_id": job_id, - "error": str(e), - } diff --git a/core/rpc/server.py b/core/rpc/server.py index 5d412da..aac866d 100644 --- a/core/rpc/server.py +++ b/core/rpc/server.py @@ -29,14 +29,9 @@ _active_jobs: dict[str, dict] = {} class WorkerServicer(worker_pb2_grpc.WorkerServiceServicer): """gRPC service implementation for worker operations.""" - def __init__(self, celery_app=None): - """ - Initialize the servicer. 
-
-        Args:
-            celery_app: Optional Celery app for task dispatch
-        """
-        self.celery_app = celery_app
+    def __init__(self):
+        """Initialize the servicer."""
+        pass
 
     def SubmitJob(self, request, context):
         """Submit a transcode/trim job to the worker."""
@@ -57,28 +52,7 @@
             "error": None,
         }
 
-        # Dispatch to Celery if available
-        if self.celery_app:
-            from core.jobs.task import run_job
-
-            payload = {
-                "source_key": request.source_path,
-                "output_key": request.output_path,
-                "preset": preset,
-                "trim_start": request.trim_start
-                if request.HasField("trim_start")
-                else None,
-                "trim_end": request.trim_end
-                if request.HasField("trim_end")
-                else None,
-            }
-
-            task = run_job.delay(
-                job_type="transcode",
-                job_id=job_id,
-                payload=payload,
-            )
-            _active_jobs[job_id]["celery_task_id"] = task.id
+        # TODO: dispatch via executor (local/lambda/gcp/grpc)
 
         return worker_pb2.JobResponse(
             job_id=job_id,
@@ -155,12 +129,6 @@ class WorkerServicer(worker_pb2_grpc.WorkerServiceServicer):
        if job_id in _active_jobs:
            _active_jobs[job_id]["status"] = "cancelled"
 
-            # Revoke Celery task if available
-            if self.celery_app:
-                task_id = _active_jobs[job_id].get("celery_task_id")
-                if task_id:
-                    self.celery_app.control.revoke(task_id, terminate=True)
-
        return worker_pb2.CancelResponse(
            job_id=job_id,
            cancelled=True,
@@ -290,13 +258,12 @@ def update_job_progress(
         logger.warning(f"Failed to update job {job_id} in DB: {e}")
 
 
-def serve(port: int = None, celery_app=None) -> grpc.Server:
+def serve(port: int = None) -> grpc.Server:
     """
     Start the gRPC server.
 
     Args:
         port: Port to listen on (defaults to GRPC_PORT env var)
-        celery_app: Optional Celery app for task dispatch
 
     Returns:
         The running gRPC server
@@ -306,7 +273,7 @@
 
     server = grpc.server(futures.ThreadPoolExecutor(max_workers=GRPC_MAX_WORKERS))
     worker_pb2_grpc.add_WorkerServiceServicer_to_server(
-        WorkerServicer(celery_app=celery_app),
+        WorkerServicer(),
         server,
     )
     server.add_insecure_port(f"[::]:{port}")
diff --git a/core/schema/models/__init__.py b/core/schema/models/__init__.py
index f085456..5ca8c32 100644
--- a/core/schema/models/__init__.py
+++ b/core/schema/models/__init__.py
@@ -35,6 +35,7 @@ from .detect import DETECT_VIEWS  # noqa: F401 — discovered by modelgen generi
 from .inference import INFERENCE_VIEWS  # noqa: F401 — GPU inference server API types
 from .ui_state import UI_STATE_VIEWS  # noqa: F401 — UI store state types
 from .stages import StageConfigField, StageIO, StageDefinition, STAGE_VIEWS  # noqa: F401
+from .detect_api import RunRequest, RunResponse, DETECT_API_VIEWS  # noqa: F401
 
 from .views import ChunkEvent, ChunkOutputFile, PipelineStats, WorkerEvent
 from .sources import ChunkInfo, SourceJob, SourceType
diff --git a/core/schema/models/detect_api.py b/core/schema/models/detect_api.py
new file mode 100644
index 0000000..ff6f998
--- /dev/null
+++ b/core/schema/models/detect_api.py
@@ -0,0 +1,31 @@
+"""
+Detection API request/response models.
+
+Source of truth for detection pipeline API shapes.
+Generated to Pydantic via modelgen.
+""" + +from dataclasses import dataclass + + +@dataclass +class RunRequest: + """Request body for launching a detection pipeline run.""" + video_path: str # storage key + profile_name: str = "soccer_broadcast" + source_asset_id: str = "" + checkpoint: bool = True + skip_vlm: bool = False + skip_cloud: bool = False + log_level: str = "INFO" # INFO | DEBUG + + +@dataclass +class RunResponse: + """Response after starting a pipeline run.""" + status: str + job_id: str + video_path: str + + +DETECT_API_VIEWS = [RunRequest, RunResponse] diff --git a/core/schema/models/job.py b/core/schema/models/job.py index 013fa56..fda2d0c 100644 --- a/core/schema/models/job.py +++ b/core/schema/models/job.py @@ -56,7 +56,6 @@ class Job: estimated_cost_usd: float = 0.0 # Worker tracking - celery_task_id: Optional[str] = None priority: int = 0 # Timestamps diff --git a/core/schema/models/jobs.py b/core/schema/models/jobs.py deleted file mode 100644 index 0957034..0000000 --- a/core/schema/models/jobs.py +++ /dev/null @@ -1,133 +0,0 @@ -""" -Job Schema Definitions - -Source of truth for job data models. -TranscodeJob and ChunkJob share common lifecycle fields by convention. -""" - -from dataclasses import dataclass, field -from datetime import datetime -from enum import Enum -from typing import Any, Dict, List, Optional -from uuid import UUID - - -class JobStatus(str, Enum): - """Status of a transcode/trim job.""" - - PENDING = "pending" - PROCESSING = "processing" - COMPLETED = "completed" - FAILED = "failed" - CANCELLED = "cancelled" - - -@dataclass -class TranscodeJob: - """ - A transcoding or trimming job in the queue. - - Jobs can either: - - Transcode using a preset (full re-encode) - - Trim only (stream copy with -c:v copy -c:a copy) - - A trim-only job has no preset and uses stream copy. - """ - - id: UUID - - # Input - source_asset_id: UUID - - # Configuration - preset_id: Optional[UUID] = None - preset_snapshot: Dict[str, Any] = field( - default_factory=dict - ) # Copy at creation time - - # Trimming (optional) - trim_start: Optional[float] = None # seconds - trim_end: Optional[float] = None # seconds - - # Output - output_filename: str = "" - output_path: Optional[str] = None - output_asset_id: Optional[UUID] = None - - # Status & Progress - status: JobStatus = JobStatus.PENDING - progress: float = 0.0 # 0.0 to 100.0 - current_frame: Optional[int] = None - current_time: Optional[float] = None # seconds processed - speed: Optional[str] = None # "2.5x" - error_message: Optional[str] = None - - # Worker tracking - celery_task_id: Optional[str] = None - execution_arn: Optional[str] = None # AWS Step Functions execution ARN - priority: int = 0 # Lower = higher priority - - # Timestamps - created_at: Optional[datetime] = None - started_at: Optional[datetime] = None - completed_at: Optional[datetime] = None - - @property - def is_trim_only(self) -> bool: - """Check if this is a trim-only job (stream copy, no transcode).""" - return self.preset_id is None and ( - self.trim_start is not None or self.trim_end is not None - ) - - -class ChunkJobStatus(str, Enum): - """Status of a chunk pipeline job.""" - - PENDING = "pending" - CHUNKING = "chunking" - PROCESSING = "processing" - COLLECTING = "collecting" - COMPLETED = "completed" - FAILED = "failed" - CANCELLED = "cancelled" - - -@dataclass -class ChunkJob: - """ - A chunk pipeline job — splits a media file into chunks and processes them - through a concurrent worker pool. 
- """ - - id: UUID - - # Input - source_asset_id: UUID - - # Configuration - chunk_duration: float = 10.0 # seconds - num_workers: int = 4 - max_retries: int = 3 - processor_type: str = "ffmpeg" # "ffmpeg", "checksum", "simulated_decode", "composite" - - # Status & Progress - status: ChunkJobStatus = ChunkJobStatus.PENDING - progress: float = 0.0 # 0.0 to 100.0 - total_chunks: int = 0 - processed_chunks: int = 0 - failed_chunks: int = 0 - retry_count: int = 0 - error_message: Optional[str] = None - - # Result stats - throughput_mbps: Optional[float] = None - elapsed_seconds: Optional[float] = None - - # Worker tracking - celery_task_id: Optional[str] = None - priority: int = 0 # Lower = higher priority - - # Timestamps - created_at: Optional[datetime] = None - started_at: Optional[datetime] = None - completed_at: Optional[datetime] = None diff --git a/detect/checkpoint/__init__.py b/detect/checkpoint/__init__.py index 7f30b39..ec64509 100644 --- a/detect/checkpoint/__init__.py +++ b/detect/checkpoint/__init__.py @@ -5,7 +5,6 @@ Checkpoint system — Timeline + Checkpoint tree. frames.py — frame image S3 upload/download storage.py — Timeline + Checkpoint (Postgres + MinIO) replay.py — replay (TODO: migrate to new model) - tasks.py — retry_candidates Celery task """ from .storage import ( diff --git a/detect/checkpoint/tasks.py b/detect/checkpoint/tasks.py deleted file mode 100644 index ea9e6a6..0000000 --- a/detect/checkpoint/tasks.py +++ /dev/null @@ -1,71 +0,0 @@ -""" -Celery tasks for detection pipeline async operations. - -retry_candidates: re-run VLM/cloud escalation with different config. -""" - -from __future__ import annotations - -import logging -import uuid -from datetime import datetime, timezone - -from celery import shared_task - -logger = logging.getLogger(__name__) - - -@shared_task(bind=True, max_retries=1, default_retry_delay=30) -def retry_candidates( - self, - job_id: str, - config_overrides: dict | None = None, - start_stage: str = "escalate_vlm", -): - """ - Retry unresolved candidates with different config. - - Loads the checkpoint from the stage before start_stage, - applies config overrides (e.g. different cloud provider), - and runs from start_stage onward. - """ - from detect.checkpoint.replay import replay_from - - run_id = str(uuid.uuid4())[:8] - logger.info("Retry task %s: job=%s, from=%s, overrides=%s", - run_id, job_id, start_stage, config_overrides) - - try: - result = replay_from( - job_id=job_id, - start_stage=start_stage, - config_overrides=config_overrides, - ) - - detections = result.get("detections", []) - report = result.get("report") - brands_found = len(report.brands) if report else 0 - - logger.info("Retry %s complete: %d detections, %d brands", - run_id, len(detections), brands_found) - - return { - "status": "completed", - "run_id": run_id, - "job_id": job_id, - "detections": len(detections), - "brands_found": brands_found, - } - - except Exception as e: - logger.exception("Retry %s failed: %s", run_id, e) - - if self.request.retries < self.max_retries: - raise self.retry(exc=e) - - return { - "status": "failed", - "run_id": run_id, - "job_id": job_id, - "error": str(e), - } diff --git a/detect/graph/__init__.py b/detect/graph/__init__.py new file mode 100644 index 0000000..f1d7073 --- /dev/null +++ b/detect/graph/__init__.py @@ -0,0 +1,29 @@ +""" +Detection pipeline graph. 
+
+    detect/graph/
+        nodes.py — node functions (one per stage)
+        events.py — graph_update SSE emission
+        runner.py — pipeline execution (LangGraph wrapper, checkpoint, cancel)
+"""
+
+from .nodes import NODES, NODE_FUNCTIONS
+from .runner import (
+    PipelineCancelled,
+    build_graph,
+    clear_cancel_check,
+    get_pipeline,
+    set_cancel_check,
+)
+from .events import _node_states
+
+__all__ = [
+    "NODES",
+    "NODE_FUNCTIONS",
+    "PipelineCancelled",
+    "build_graph",
+    "get_pipeline",
+    "set_cancel_check",
+    "clear_cancel_check",
+    "_node_states",
+]
diff --git a/detect/graph/events.py b/detect/graph/events.py
new file mode 100644
index 0000000..c620955
--- /dev/null
+++ b/detect/graph/events.py
@@ -0,0 +1,27 @@
+"""
+Graph event emission — node state tracking + SSE graph_update events.
+"""
+
+from __future__ import annotations
+
+from detect import emit
+from detect.state import DetectState
+
+
+# Track node states across pipeline runs
+_node_states: dict[str, dict[str, str]] = {}
+
+
+def emit_transition(state: DetectState, node: str, status: str, node_list: list[str]):
+    """Update node status and emit graph_update SSE event."""
+    job_id = state.get("job_id")
+    if not job_id:
+        return
+
+    if job_id not in _node_states:
+        _node_states[job_id] = {n: "pending" for n in node_list}
+
+    _node_states[job_id][node] = status
+
+    nodes = [{"id": n, "status": _node_states[job_id][n]} for n in node_list]
+    emit.graph_update(job_id, nodes)
diff --git a/detect/graph.py b/detect/graph/nodes.py
similarity index 63%
rename from detect/graph.py
rename to detect/graph/nodes.py
index 4f0434d..f913aec 100644
--- a/detect/graph.py
+++ b/detect/graph/nodes.py
@@ -1,16 +1,13 @@
 """
-LangGraph pipeline graph for brand detection.
+Pipeline node functions — one per stage.
 
-Nodes execute real logic for extract+filter, stubs for the rest.
-Each node emits graph_update events so the UI can visualize transitions.
+Each node: reads state, runs stage logic, emits transitions, returns output dict.
""" from __future__ import annotations import os -from langgraph.graph import END, StateGraph - from detect import emit from detect.models import PipelineStats from detect.profiles import SoccerBroadcastProfile @@ -27,6 +24,8 @@ from detect.stages.vlm_cloud import escalate_cloud from detect.stages.aggregator import compile_report from detect.tracing import trace_node, flush as flush_traces +from .events import emit_transition + INFERENCE_URL = os.environ.get("INFERENCE_URL") # None = local mode NODES = [ @@ -58,41 +57,24 @@ def _get_profile(state: DetectState): return profile -# Track node states across the pipeline run -_node_states: dict[str, dict[str, str]] = {} - - -def _emit_transition(state: DetectState, node: str, status: str): - job_id = state.get("job_id") - if not job_id: - return - - # Initialize state tracking for this job - if job_id not in _node_states: - _node_states[job_id] = {n: "pending" for n in NODES} - - _node_states[job_id][node] = status - - nodes = [{"id": n, "status": _node_states[job_id][n]} for n in NODES] - emit.graph_update(job_id, nodes) +def _emit(state, node, status): + emit_transition(state, node, status, NODES) # --- Node functions --- def node_extract_frames(state: DetectState) -> dict: - # Set run context for initial runs (replays set it in replay_from) job_id = state.get("job_id", "") if job_id and not emit._run_context: emit.set_run_context(run_id=job_id, parent_job_id=job_id, run_type="initial") - # Load session brands from DB for this source source_asset_id = state.get("source_asset_id") if source_asset_id and not state.get("session_brands"): from detect.stages.brand_resolver import build_session_dict session_brands = build_session_dict(source_asset_id) state["session_brands"] = session_brands - _emit_transition(state, "extract_frames", "running") + _emit(state, "extract_frames", "running") with trace_node(state, "extract_frames") as span: profile = _get_profile(state) @@ -100,12 +82,12 @@ def node_extract_frames(state: DetectState) -> dict: frames = extract_frames(state["video_path"], config, job_id=state.get("job_id")) span.set_output({"frames_extracted": len(frames)}) - _emit_transition(state, "extract_frames", "done") + _emit(state, "extract_frames", "done") return {"frames": frames, "stats": PipelineStats(frames_extracted=len(frames))} def node_filter_scenes(state: DetectState) -> dict: - _emit_transition(state, "filter_scenes", "running") + _emit(state, "filter_scenes", "running") with trace_node(state, "filter_scenes") as span: profile = _get_profile(state) @@ -117,12 +99,12 @@ def node_filter_scenes(state: DetectState) -> dict: stats = state.get("stats", PipelineStats()) stats.frames_after_scene_filter = len(kept) - _emit_transition(state, "filter_scenes", "done") + _emit(state, "filter_scenes", "done") return {"filtered_frames": kept, "stats": stats} def node_detect_edges(state: DetectState) -> dict: - _emit_transition(state, "detect_edges", "running") + _emit(state, "detect_edges", "running") with trace_node(state, "detect_edges") as span: profile = _get_profile(state) @@ -139,12 +121,12 @@ def node_detect_edges(state: DetectState) -> dict: stats = state.get("stats", PipelineStats()) stats.cv_regions_detected = total - _emit_transition(state, "detect_edges", "done") + _emit(state, "detect_edges", "done") return {"edge_regions_by_frame": regions, "stats": stats} def node_detect_objects(state: DetectState) -> dict: - _emit_transition(state, "detect_objects", "running") + _emit(state, "detect_objects", "running") with trace_node(state, 
"detect_objects") as span: profile = _get_profile(state) @@ -159,12 +141,12 @@ def node_detect_objects(state: DetectState) -> dict: stats = state.get("stats", PipelineStats()) stats.regions_detected = total_regions - _emit_transition(state, "detect_objects", "done") + _emit(state, "detect_objects", "done") return {"boxes_by_frame": all_boxes, "stats": stats} def node_preprocess(state: DetectState) -> dict: - _emit_transition(state, "preprocess", "running") + _emit(state, "preprocess", "running") with trace_node(state, "preprocess") as span: profile = _get_profile(state) @@ -172,7 +154,6 @@ def node_preprocess(state: DetectState) -> dict: boxes = state.get("boxes_by_frame", {}) job_id = state.get("job_id") - # Get preprocessing config from profile overrides or defaults overrides = state.get("config_overrides", {}) prep_config = overrides.get("preprocessing", {}) do_contrast = prep_config.get("contrast", True) @@ -189,12 +170,12 @@ def node_preprocess(state: DetectState) -> dict: ) span.set_output({"regions_preprocessed": len(preprocessed)}) - _emit_transition(state, "preprocess", "done") + _emit(state, "preprocess", "done") return {"preprocessed_crops": preprocessed} def node_run_ocr(state: DetectState) -> dict: - _emit_transition(state, "run_ocr", "running") + _emit(state, "run_ocr", "running") with trace_node(state, "run_ocr") as span: profile = _get_profile(state) @@ -209,12 +190,12 @@ def node_run_ocr(state: DetectState) -> dict: stats = state.get("stats", PipelineStats()) stats.regions_resolved_by_ocr = len(candidates) - _emit_transition(state, "run_ocr", "done") + _emit(state, "run_ocr", "done") return {"text_candidates": candidates, "stats": stats} def node_match_brands(state: DetectState) -> dict: - _emit_transition(state, "match_brands", "running") + _emit(state, "match_brands", "running") with trace_node(state, "match_brands") as span: profile = _get_profile(state) @@ -232,12 +213,12 @@ def node_match_brands(state: DetectState) -> dict: ) span.set_output({"matched": len(matched), "unresolved": len(unresolved)}) - _emit_transition(state, "match_brands", "done") + _emit(state, "match_brands", "done") return {"detections": matched, "unresolved_candidates": unresolved} def node_escalate_vlm(state: DetectState) -> dict: - _emit_transition(state, "escalate_vlm", "running") + _emit(state, "escalate_vlm", "running") with trace_node(state, "escalate_vlm") as span: profile = _get_profile(state) @@ -261,7 +242,7 @@ def node_escalate_vlm(state: DetectState) -> dict: existing = state.get("detections", []) vlm_skipped = os.environ.get("SKIP_VLM", "").strip() == "1" - _emit_transition(state, "escalate_vlm", "skipped" if vlm_skipped else "done") + _emit(state, "escalate_vlm", "skipped" if vlm_skipped else "done") return { "detections": existing + vlm_matched, "unresolved_candidates": still_unresolved, @@ -270,7 +251,7 @@ def node_escalate_vlm(state: DetectState) -> dict: def node_escalate_cloud(state: DetectState) -> dict: - _emit_transition(state, "escalate_cloud", "running") + _emit(state, "escalate_cloud", "running") with trace_node(state, "escalate_cloud") as span: profile = _get_profile(state) @@ -294,12 +275,12 @@ def node_escalate_cloud(state: DetectState) -> dict: existing = state.get("detections", []) cloud_skipped = os.environ.get("SKIP_CLOUD", "").strip() == "1" - _emit_transition(state, "escalate_cloud", "skipped" if cloud_skipped else "done") + _emit(state, "escalate_cloud", "skipped" if cloud_skipped else "done") return {"detections": existing + cloud_matched, "stats": stats} def 
 def node_compile_report(state: DetectState) -> dict:
-    _emit_transition(state, "compile_report", "running")
+    _emit(state, "compile_report", "running")
 
     with trace_node(state, "compile_report") as span:
         profile = _get_profile(state)
@@ -318,85 +299,10 @@
         span.set_output({"brands": len(report.brands), "detections": len(report.timeline)})
         flush_traces()
 
-    _emit_transition(state, "compile_report", "done")
+    _emit(state, "compile_report", "done")
     return {"report": report}
 
 
-# --- Checkpoint wrapper ---
-
-_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
-_frames_manifest: dict[str, dict[int, str]] = {}  # job_id → manifest (cached per job)
-_latest_checkpoint: dict[str, str] = {}  # job_id → latest checkpoint_id
-
-
-class PipelineCancelled(Exception):
-    """Raised when a pipeline run is cancelled."""
-    pass
-
-
-# Cancellation hook — set by the run endpoint, checked before each node
-_cancel_check: dict[str, callable] = {}
-
-
-def set_cancel_check(job_id: str, fn):
-    _cancel_check[job_id] = fn
-
-
-def clear_cancel_check(job_id: str):
-    _cancel_check.pop(job_id, None)
-
-
-def _checkpointing_node(node_name: str, node_fn):
-    """Wrap a node function to auto-checkpoint after completion."""
-    stage_index = NODES.index(node_name)
-
-    def wrapper(state: DetectState) -> dict:
-        job_id = state.get("job_id", "")
-        check = _cancel_check.get(job_id)
-        if check and check():
-            raise PipelineCancelled(f"Cancelled before {node_name}")
-
-        result = node_fn(state)
-
-        job_id = state.get("job_id", "")
-        if not job_id:
-            return result
-
-        from detect.checkpoint import save_stage_output, save_frames
-        from detect.stages.base import _REGISTRY
-
-        merged = {**state, **result}
-
-        # Save frames once (first node), reuse manifest after
-        manifest = _frames_manifest.get(job_id)
-        if manifest is None and node_name == "extract_frames":
-            manifest = save_frames(job_id, merged.get("frames", []))
-            _frames_manifest[job_id] = manifest
-
-        # Serialize stage output using the stage's serialize_fn if available
-        stage_cls = _REGISTRY.get(node_name)
-        serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
-        if serialize_fn:
-            output_json = serialize_fn(merged, job_id)
-        else:
-            output_json = {}
-
-        parent_id = _latest_checkpoint.get(job_id)
-        new_checkpoint_id = save_stage_output(
-            timeline_id=job_id,
-            parent_checkpoint_id=parent_id,
-            stage_name=node_name,
-            output_json=output_json,
-        )
-        _latest_checkpoint[job_id] = new_checkpoint_id
-        return result
-
-    wrapper.__name__ = node_fn.__name__
-    return wrapper
-
-
-# --- Graph construction ---
-
 NODE_FUNCTIONS = [
     ("extract_frames", node_extract_frames),
     ("filter_scenes", node_filter_scenes),
     ("detect_edges", node_detect_edges),
     ("detect_objects", node_detect_objects),
     ("preprocess", node_preprocess),
     ("run_ocr", node_run_ocr),
     ("match_brands", node_match_brands),
     ("escalate_vlm", node_escalate_vlm),
     ("escalate_cloud", node_escalate_cloud),
     ("compile_report", node_compile_report),
 ]
-
-
-def build_graph(checkpoint: bool | None = None, start_from: str | None = None) -> StateGraph:
-    """
-    Build the pipeline graph.
-
-    checkpoint: enable auto-checkpointing (default: MPR_CHECKPOINT env var)
-    start_from: skip nodes before this stage (for replay)
-    """
-    do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
-
-    graph = StateGraph(DetectState)
-
-    # Filter to start_from if replaying
-    node_pairs = NODE_FUNCTIONS
-    if start_from:
-        start_idx = next(i for i, (name, _) in enumerate(NODE_FUNCTIONS) if name == start_from)
-        node_pairs = NODE_FUNCTIONS[start_idx:]
-
-    for name, fn in node_pairs:
-        wrapped = _checkpointing_node(name, fn) if do_checkpoint else fn
-        graph.add_node(name, wrapped)
-
-    # Wire edges
-    entry = node_pairs[0][0]
-    graph.set_entry_point(entry)
-
-    for i in range(len(node_pairs) - 1):
-        graph.add_edge(node_pairs[i][0], node_pairs[i + 1][0])
-
-    graph.add_edge(node_pairs[-1][0], END)
-
-    return graph
-
-
-def get_pipeline(checkpoint: bool | None = None):
-    """Return a compiled, runnable pipeline."""
-    return build_graph(checkpoint=checkpoint).compile()
diff --git a/detect/graph/runner.py b/detect/graph/runner.py
new file mode 100644
index 0000000..20997be
--- /dev/null
+++ b/detect/graph/runner.py
@@ -0,0 +1,127 @@
+"""
+Pipeline runner — executes stages sequentially with checkpointing and cancellation.
+
+Currently wraps LangGraph for execution. Will be replaced with a lean
+custom runner in Phase 3, with an executor socket for distributed dispatch.
+"""
+
+from __future__ import annotations
+
+import os
+
+from langgraph.graph import END, StateGraph
+
+from detect.state import DetectState
+from .nodes import NODES, NODE_FUNCTIONS
+
+
+# --- Checkpoint wrapper ---
+
+_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
+_frames_manifest: dict[str, dict[int, str]] = {}  # job_id → manifest (cached per job)
+_latest_checkpoint: dict[str, str] = {}  # job_id → latest checkpoint_id
+
+
+class PipelineCancelled(Exception):
+    """Raised when a pipeline run is cancelled."""
+    pass
+
+
+# Cancellation hook — set by the run endpoint, checked before each node
+_cancel_check: dict[str, callable] = {}
+
+
+def set_cancel_check(job_id: str, fn):
+    _cancel_check[job_id] = fn
+
+
+def clear_cancel_check(job_id: str):
+    _cancel_check.pop(job_id, None)
+
+
+def _checkpointing_node(node_name: str, node_fn):
+    """Wrap a node function to auto-checkpoint after completion."""
+
+    def wrapper(state: DetectState) -> dict:
+        job_id = state.get("job_id", "")
+        check = _cancel_check.get(job_id)
+        if check and check():
+            raise PipelineCancelled(f"Cancelled before {node_name}")
+
+        result = node_fn(state)
+
+        job_id = state.get("job_id", "")
+        if not job_id:
+            return result
+
+        from detect.checkpoint import save_stage_output, save_frames
+        from detect.stages.base import _REGISTRY
+
+        merged = {**state, **result}
+
+        # Save frames once (first node), reuse manifest after
+        manifest = _frames_manifest.get(job_id)
+        if manifest is None and node_name == "extract_frames":
+            manifest = save_frames(job_id, merged.get("frames", []))
+            _frames_manifest[job_id] = manifest
+
+        # Serialize stage output using the stage's serialize_fn if available
+        stage_cls = _REGISTRY.get(node_name)
+        serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
+        if serialize_fn:
+            output_json = serialize_fn(merged, job_id)
+        else:
+            output_json = {}
+
+        parent_id = _latest_checkpoint.get(job_id)
+        new_checkpoint_id = save_stage_output(
+            timeline_id=job_id,
+            parent_checkpoint_id=parent_id,
+            stage_name=node_name,
+            output_json=output_json,
+        )
+        _latest_checkpoint[job_id] = new_checkpoint_id
+        return result
+
+    wrapper.__name__ = node_fn.__name__
+    return wrapper
+
+
+# --- Graph construction ---
+
+def build_graph(checkpoint: bool | None = None, start_from: str | None = None) -> StateGraph:
+    """
+    Build the pipeline graph.
+
+    checkpoint: enable auto-checkpointing (default: MPR_CHECKPOINT env var)
+    start_from: skip nodes before this stage (for replay)
+    """
+    do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
+
+    graph = StateGraph(DetectState)
+
+    # Filter to start_from if replaying
+    node_pairs = NODE_FUNCTIONS
+    if start_from:
+        start_idx = next(i for i, (name, _) in enumerate(NODE_FUNCTIONS) if name == start_from)
+        node_pairs = NODE_FUNCTIONS[start_idx:]
+
+    for name, fn in node_pairs:
+        wrapped = _checkpointing_node(name, fn) if do_checkpoint else fn
+        graph.add_node(name, wrapped)
+
+    # Wire edges
+    entry = node_pairs[0][0]
+    graph.set_entry_point(entry)
+
+    for i in range(len(node_pairs) - 1):
+        graph.add_edge(node_pairs[i][0], node_pairs[i + 1][0])
+
+    graph.add_edge(node_pairs[-1][0], END)
+
+    return graph
+
+
+def get_pipeline(checkpoint: bool | None = None):
+    """Return a compiled, runnable pipeline."""
+    return build_graph(checkpoint=checkpoint).compile()
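
Usage sketch (illustration, not part of the patch): after this refactor a caller drives a run through detect.graph — get_pipeline() compiles the LangGraph wrapper from runner.py, set_cancel_check() installs the per-job hook that _checkpointing_node consults before each node (so it only takes effect with checkpointing on), and DetectState is seeded the same way the deleted DetectHandler._run_initial did it. The cancelled set below stands in for run.py's module-level _cancelled_jobs.

    from detect.graph import (
        PipelineCancelled,
        clear_cancel_check,
        get_pipeline,
        set_cancel_check,
    )
    from detect.state import DetectState

    cancelled: set[str] = set()  # stand-in for run.py's _cancelled_jobs

    def run_detection(job_id: str, video_path: str) -> dict:
        # checkpoint=True wraps each node with _checkpointing_node,
        # which also performs the pre-node cancellation check.
        pipeline = get_pipeline(checkpoint=True)
        set_cancel_check(job_id, lambda: job_id in cancelled)
        state = DetectState(
            video_path=video_path,
            job_id=job_id,
            profile_name="soccer_broadcast",
            source_asset_id="",
        )
        try:
            return pipeline.invoke(state)
        except PipelineCancelled:
            return {"status": "cancelled"}
        finally:
            clear_cancel_check(job_id)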
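Replay sketch (also illustrative): build_graph(start_from=...) drops the nodes before the given stage and re-wires the entry point, which is what resuming from a checkpoint relies on. The restore step is elided here — detect.checkpoint.replay_from handles it in the real flow; restored_state below is assumed to already hold the upstream stage outputs (e.g. preprocessed_crops for a run_ocr resume).

    from detect.graph import build_graph

    def resume_from_ocr(restored_state: dict) -> dict:
        # Entry point becomes run_ocr; edges are wired over the remaining nodes only.
        graph = build_graph(checkpoint=True, start_from="run_ocr")
        pipeline = graph.compile()
        return pipeline.invoke(restored_state)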