phase 1
This commit is contained in:
@@ -1,8 +1,8 @@
|
|||||||
"""
|
"""
|
||||||
SSE endpoint for chunker pipeline events.
|
SSE endpoint for chunker pipeline events.
|
||||||
|
|
||||||
Uses Redis as the event bus between Celery workers and the SSE stream.
|
Uses Redis as the event bus. Pipeline pushes events via core.events,
|
||||||
Celery worker pushes events via core.events, SSE endpoint polls them.
|
SSE endpoint polls them.
|
||||||
|
|
||||||
GET /chunker/stream/{job_id} → text/event-stream
|
GET /chunker/stream/{job_id} → text/event-stream
|
||||||
"""
|
"""
|
||||||
|
|||||||
20
core/api/detect/__init__.py
Normal file
20
core/api/detect/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
"""
|
||||||
|
Detection API — aggregated router.
|
||||||
|
|
||||||
|
Combines all detect sub-routers into a single include for main.py.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from .sources import router as sources_router
|
||||||
|
from .run import router as run_router
|
||||||
|
from .sse import router as sse_router
|
||||||
|
from .replay import router as replay_router
|
||||||
|
from .config import router as config_router
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
router.include_router(sources_router)
|
||||||
|
router.include_router(run_router)
|
||||||
|
router.include_router(sse_router)
|
||||||
|
router.include_router(replay_router)
|
||||||
|
router.include_router(config_router)
|
||||||
@@ -1,19 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
Source browser for detection pipeline.
|
Pipeline run endpoints.
|
||||||
|
|
||||||
Lists available media sources from blob storage (MinIO).
|
POST /detect/run — launch pipeline on selected source
|
||||||
All file-based sources go through MinIO — no host filesystem access.
|
POST /detect/stop/{job_id} — cancel a running pipeline
|
||||||
The pipeline downloads chunks to a temp path before processing.
|
POST /detect/clear/{job_id} — clear events from Redis
|
||||||
|
|
||||||
Source types (current and future):
|
|
||||||
- chunk_job: pre-chunked segments in MinIO (current)
|
|
||||||
- upload: user-uploaded file, lands in MinIO via upload endpoint (future)
|
|
||||||
- device: local camera/capture card via ffmpeg, no MinIO (future)
|
|
||||||
- stream: RTMP/HLS URL via ffmpeg, no MinIO (future)
|
|
||||||
|
|
||||||
GET /detect/sources — list chunk jobs from blob store
|
|
||||||
GET /detect/sources/{job_id}/chunks — list chunks for a specific job
|
|
||||||
POST /detect/run — launch pipeline on selected source
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -31,23 +21,10 @@ logger = logging.getLogger(__name__)
|
|||||||
router = APIRouter(prefix="/detect", tags=["detect"])
|
router = APIRouter(prefix="/detect", tags=["detect"])
|
||||||
|
|
||||||
# In-process pipeline tracking
|
# In-process pipeline tracking
|
||||||
_running_jobs: dict[str, "threading.Thread"] = {}
|
_running_jobs: dict[str, threading.Thread] = {}
|
||||||
_cancelled_jobs: set[str] = set()
|
_cancelled_jobs: set[str] = set()
|
||||||
|
|
||||||
|
|
||||||
class ChunkInfo(BaseModel):
|
|
||||||
filename: str
|
|
||||||
key: str
|
|
||||||
size_bytes: int
|
|
||||||
|
|
||||||
|
|
||||||
class SourceInfo(BaseModel):
|
|
||||||
job_id: str
|
|
||||||
source_type: str = "chunk_job"
|
|
||||||
chunk_count: int
|
|
||||||
total_bytes: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
class RunRequest(BaseModel):
|
class RunRequest(BaseModel):
|
||||||
video_path: str # storage key
|
video_path: str # storage key
|
||||||
profile_name: str = "soccer_broadcast"
|
profile_name: str = "soccer_broadcast"
|
||||||
@@ -64,91 +41,6 @@ class RunResponse(BaseModel):
|
|||||||
video_path: str
|
video_path: str
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Source listing
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _list_sources() -> list[SourceInfo]:
|
|
||||||
"""List chunk jobs from blob storage."""
|
|
||||||
from core.storage.blob import get_store
|
|
||||||
|
|
||||||
store = get_store("out")
|
|
||||||
try:
|
|
||||||
objects = store.list(prefix="chunks/")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning("Failed to list blob sources: %s", e)
|
|
||||||
return []
|
|
||||||
|
|
||||||
jobs: dict[str, int] = {}
|
|
||||||
job_bytes: dict[str, int] = {}
|
|
||||||
for obj in objects:
|
|
||||||
# Keys include store prefix: out/chunks/{job_id}/file.mp4
|
|
||||||
# Strip prefix to get: chunks/{job_id}/file.mp4
|
|
||||||
rel_key = obj.key.removeprefix(store.prefix)
|
|
||||||
parts = rel_key.split("/")
|
|
||||||
if len(parts) >= 3 and parts[0] == "chunks":
|
|
||||||
job_id = parts[1]
|
|
||||||
jobs[job_id] = jobs.get(job_id, 0) + 1
|
|
||||||
job_bytes[job_id] = job_bytes.get(job_id, 0) + obj.size_bytes
|
|
||||||
|
|
||||||
sources = []
|
|
||||||
for job_id, count in sorted(jobs.items()):
|
|
||||||
source = SourceInfo(
|
|
||||||
job_id=job_id,
|
|
||||||
source_type="chunk_job",
|
|
||||||
chunk_count=count,
|
|
||||||
total_bytes=job_bytes.get(job_id, 0),
|
|
||||||
)
|
|
||||||
sources.append(source)
|
|
||||||
return sources
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/sources", response_model=list[SourceInfo])
|
|
||||||
def list_sources():
|
|
||||||
"""List available chunk jobs from blob storage."""
|
|
||||||
return _list_sources()
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/sources/{source_job_id}/chunks", response_model=list[ChunkInfo])
|
|
||||||
def list_chunks(source_job_id: str):
|
|
||||||
"""List chunks for a specific source job."""
|
|
||||||
from core.storage.blob import get_store
|
|
||||||
|
|
||||||
store = get_store("out")
|
|
||||||
try:
|
|
||||||
objects = store.list(prefix=f"chunks/{source_job_id}/", extensions={".mp4"})
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning("Failed to list chunks for %s: %s", source_job_id, e)
|
|
||||||
raise HTTPException(status_code=503, detail=f"Blob storage unavailable: {e}")
|
|
||||||
|
|
||||||
if not objects:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Source not found: {source_job_id}")
|
|
||||||
|
|
||||||
chunks = []
|
|
||||||
for obj in objects:
|
|
||||||
info = ChunkInfo(filename=obj.filename, key=obj.key, size_bytes=obj.size_bytes)
|
|
||||||
chunks.append(info)
|
|
||||||
return sorted(chunks, key=lambda c: c.filename)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/sources/{source_job_id}/chunks/{filename}/url")
|
|
||||||
def get_chunk_url(source_job_id: str, filename: str):
|
|
||||||
"""Return a presigned URL for previewing a chunk in the browser."""
|
|
||||||
from core.storage.blob import get_store
|
|
||||||
|
|
||||||
store = get_store("out")
|
|
||||||
key = f"chunks/{source_job_id}/{filename}"
|
|
||||||
try:
|
|
||||||
url = store.get_url(key, expires=3600)
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=503, detail=f"Could not generate URL: {e}")
|
|
||||||
return {"url": url}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Run pipeline
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _resolve_video_path(video_path: str) -> str:
|
def _resolve_video_path(video_path: str) -> str:
|
||||||
"""Download a chunk from blob storage to a temp file."""
|
"""Download a chunk from blob storage to a temp file."""
|
||||||
from core.storage.blob import get_store
|
from core.storage.blob import get_store
|
||||||
@@ -216,7 +108,6 @@ def run_pipeline(req: RunRequest):
|
|||||||
emit.job_complete(job_id, {"status": "cancelled"})
|
emit.job_complete(job_id, {"status": "cancelled"})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Pipeline run %s failed: %s", job_id, e)
|
logger.exception("Pipeline run %s failed: %s", job_id, e)
|
||||||
# Mark the current/last stage as error in the graph
|
|
||||||
from detect.graph import _node_states, NODES
|
from detect.graph import _node_states, NODES
|
||||||
if job_id in _node_states:
|
if job_id in _node_states:
|
||||||
states = _node_states[job_id]
|
states = _node_states[job_id]
|
||||||
108
core/api/detect/sources.py
Normal file
108
core/api/detect/sources.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""
|
||||||
|
Source browser for detection pipeline.
|
||||||
|
|
||||||
|
Lists available media sources from blob storage (MinIO).
|
||||||
|
|
||||||
|
GET /detect/sources — list chunk jobs
|
||||||
|
GET /detect/sources/{job_id}/chunks — list chunks for a job
|
||||||
|
GET /detect/sources/{job_id}/chunks/{name}/url — presigned preview URL
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/detect", tags=["detect"])
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkInfoResponse(BaseModel):
|
||||||
|
filename: str
|
||||||
|
key: str
|
||||||
|
size_bytes: int
|
||||||
|
|
||||||
|
|
||||||
|
class SourceInfoResponse(BaseModel):
|
||||||
|
job_id: str
|
||||||
|
source_type: str = "chunk_job"
|
||||||
|
chunk_count: int
|
||||||
|
total_bytes: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
def _list_sources() -> list[SourceInfoResponse]:
|
||||||
|
"""List chunk jobs from blob storage."""
|
||||||
|
from core.storage.blob import get_store
|
||||||
|
|
||||||
|
store = get_store("out")
|
||||||
|
try:
|
||||||
|
objects = store.list(prefix="chunks/")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to list blob sources: %s", e)
|
||||||
|
return []
|
||||||
|
|
||||||
|
jobs: dict[str, int] = {}
|
||||||
|
job_bytes: dict[str, int] = {}
|
||||||
|
for obj in objects:
|
||||||
|
rel_key = obj.key.removeprefix(store.prefix)
|
||||||
|
parts = rel_key.split("/")
|
||||||
|
if len(parts) >= 3 and parts[0] == "chunks":
|
||||||
|
job_id = parts[1]
|
||||||
|
jobs[job_id] = jobs.get(job_id, 0) + 1
|
||||||
|
job_bytes[job_id] = job_bytes.get(job_id, 0) + obj.size_bytes
|
||||||
|
|
||||||
|
sources = []
|
||||||
|
for job_id, count in sorted(jobs.items()):
|
||||||
|
source = SourceInfoResponse(
|
||||||
|
job_id=job_id,
|
||||||
|
source_type="chunk_job",
|
||||||
|
chunk_count=count,
|
||||||
|
total_bytes=job_bytes.get(job_id, 0),
|
||||||
|
)
|
||||||
|
sources.append(source)
|
||||||
|
return sources
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sources", response_model=list[SourceInfoResponse])
|
||||||
|
def list_sources():
|
||||||
|
"""List available chunk jobs from blob storage."""
|
||||||
|
return _list_sources()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sources/{source_job_id}/chunks", response_model=list[ChunkInfoResponse])
|
||||||
|
def list_chunks(source_job_id: str):
|
||||||
|
"""List chunks for a specific source job."""
|
||||||
|
from core.storage.blob import get_store
|
||||||
|
|
||||||
|
store = get_store("out")
|
||||||
|
try:
|
||||||
|
objects = store.list(prefix=f"chunks/{source_job_id}/", extensions={".mp4"})
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to list chunks for %s: %s", source_job_id, e)
|
||||||
|
raise HTTPException(status_code=503, detail=f"Blob storage unavailable: {e}")
|
||||||
|
|
||||||
|
if not objects:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Source not found: {source_job_id}")
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
for obj in objects:
|
||||||
|
info = ChunkInfoResponse(filename=obj.filename, key=obj.key, size_bytes=obj.size_bytes)
|
||||||
|
chunks.append(info)
|
||||||
|
return sorted(chunks, key=lambda c: c.filename)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sources/{source_job_id}/chunks/{filename}/url")
|
||||||
|
def get_chunk_url(source_job_id: str, filename: str):
|
||||||
|
"""Return a presigned URL for previewing a chunk in the browser."""
|
||||||
|
from core.storage.blob import get_store
|
||||||
|
|
||||||
|
store = get_store("out")
|
||||||
|
key = f"chunks/{source_job_id}/{filename}"
|
||||||
|
try:
|
||||||
|
url = store.get_url(key, expires=3600)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=503, detail=f"Could not generate URL: {e}")
|
||||||
|
return {"url": url}
|
||||||
@@ -19,10 +19,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
from strawberry.fastapi import GraphQLRouter
|
from strawberry.fastapi import GraphQLRouter
|
||||||
|
|
||||||
from core.api.chunker_sse import router as chunker_router
|
from core.api.chunker_sse import router as chunker_router
|
||||||
from core.api.detect_sse import router as detect_router
|
from core.api.detect import router as detect_router
|
||||||
from core.api.detect_replay import router as detect_replay_router
|
|
||||||
from core.api.detect_config import router as detect_config_router
|
|
||||||
from core.api.detect_sources import router as detect_sources_router
|
|
||||||
from core.api.graphql import schema as graphql_schema
|
from core.api.graphql import schema as graphql_schema
|
||||||
|
|
||||||
CALLBACK_API_KEY = os.environ.get("CALLBACK_API_KEY", "")
|
CALLBACK_API_KEY = os.environ.get("CALLBACK_API_KEY", "")
|
||||||
@@ -61,18 +58,9 @@ app.include_router(graphql_router, prefix="/graphql")
|
|||||||
# Chunker SSE
|
# Chunker SSE
|
||||||
app.include_router(chunker_router)
|
app.include_router(chunker_router)
|
||||||
|
|
||||||
# Detection SSE
|
# Detection API (sources, run, SSE, replay, config)
|
||||||
app.include_router(detect_router)
|
app.include_router(detect_router)
|
||||||
|
|
||||||
# Detection replay/retry
|
|
||||||
app.include_router(detect_replay_router)
|
|
||||||
|
|
||||||
# Detection config
|
|
||||||
app.include_router(detect_config_router)
|
|
||||||
|
|
||||||
# Detection sources + run launcher
|
|
||||||
app.include_router(detect_sources_router)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
def health():
|
def health():
|
||||||
|
|||||||
@@ -48,7 +48,6 @@ class Job(SQLModel, table=True):
|
|||||||
brands_found: int = 0
|
brands_found: int = 0
|
||||||
cloud_llm_calls: int = 0
|
cloud_llm_calls: int = 0
|
||||||
estimated_cost_usd: float = 0.0
|
estimated_cost_usd: float = 0.0
|
||||||
celery_task_id: Optional[str] = None
|
|
||||||
priority: int = 0
|
priority: int = 0
|
||||||
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
||||||
started_at: Optional[datetime] = None
|
started_at: Optional[datetime] = None
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Redis-based event bus for pipeline job progress.
|
Redis-based event bus for pipeline job progress.
|
||||||
|
|
||||||
Celery workers push events, SSE endpoints poll them.
|
Pipeline stages push events, SSE endpoints poll them.
|
||||||
Only depends on redis — safe to import from any context.
|
Only depends on redis — safe to import from any context.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,13 @@
|
|||||||
"""
|
"""
|
||||||
MPR Jobs Module
|
MPR Jobs Module
|
||||||
|
|
||||||
Provides executor abstraction and task dispatch for job processing.
|
Provides executor abstraction for job dispatch (local, Lambda, GCP).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .executor import Executor, LocalExecutor, get_executor
|
from .executor import Executor, LocalExecutor, get_executor
|
||||||
from .task import run_job
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Executor",
|
"Executor",
|
||||||
"LocalExecutor",
|
"LocalExecutor",
|
||||||
"get_executor",
|
"get_executor",
|
||||||
"run_job",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ class Executor(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class LocalExecutor(Executor):
|
class LocalExecutor(Executor):
|
||||||
"""Execute jobs locally using registered handlers."""
|
"""Execute jobs locally by calling the stage function directly."""
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
@@ -51,16 +51,10 @@ class LocalExecutor(Executor):
|
|||||||
payload: Dict[str, Any],
|
payload: Dict[str, Any],
|
||||||
progress_callback: Optional[Callable[[int, Dict[str, Any]], None]] = None,
|
progress_callback: Optional[Callable[[int, Dict[str, Any]], None]] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Execute job using the appropriate local handler."""
|
"""Execute job locally. Socket for PipelineRunner integration."""
|
||||||
from .registry import get_handler
|
raise NotImplementedError(
|
||||||
|
"LocalExecutor.run() — will be wired to PipelineRunner in Phase 3"
|
||||||
handler = get_handler(job_type)
|
|
||||||
result = handler.process(
|
|
||||||
job_id=job_id,
|
|
||||||
payload=payload,
|
|
||||||
progress_callback=progress_callback,
|
|
||||||
)
|
)
|
||||||
return result.get("status") == "completed"
|
|
||||||
|
|
||||||
|
|
||||||
class LambdaExecutor(Executor):
|
class LambdaExecutor(Executor):
|
||||||
|
|||||||
@@ -1,5 +0,0 @@
|
|||||||
"""Job handlers — type-specific execution logic."""
|
|
||||||
|
|
||||||
from .base import Handler
|
|
||||||
|
|
||||||
__all__ = ["Handler"]
|
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
"""
|
|
||||||
Base Handler ABC — defines the interface for job-type-specific execution logic.
|
|
||||||
|
|
||||||
A Handler knows HOW to execute a specific kind of job (transcode, chunk, etc.).
|
|
||||||
The Executor decides WHERE to run it (local, Lambda, GCP).
|
|
||||||
"""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any, Callable, Dict, Optional
|
|
||||||
|
|
||||||
|
|
||||||
class Handler(ABC):
|
|
||||||
"""Abstract base class for job handlers."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def process(
|
|
||||||
self,
|
|
||||||
job_id: str,
|
|
||||||
payload: Dict[str, Any],
|
|
||||||
progress_callback: Optional[Callable[[int, Dict[str, Any]], None]] = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Execute job-specific logic.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
job_id: Unique job identifier
|
|
||||||
payload: Job-type-specific configuration
|
|
||||||
progress_callback: Called with (percent, details_dict)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Result dict with at least {"status": "completed"} or raises
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
@@ -1,125 +0,0 @@
|
|||||||
"""
|
|
||||||
ChunkHandler — job handler that wraps the chunker Pipeline.
|
|
||||||
|
|
||||||
Downloads source from S3/MinIO, runs FFmpeg chunking pipeline,
|
|
||||||
writes mp4 segments + manifest to media/out/chunks/{job_id}/.
|
|
||||||
Pushes real-time events to Redis for SSE consumption.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from typing import Any, Callable, Dict, Optional
|
|
||||||
|
|
||||||
from core.events import push_event as push_chunk_event
|
|
||||||
from core.chunker import Pipeline
|
|
||||||
from core.storage import BUCKET_IN, download_to_temp
|
|
||||||
|
|
||||||
from .base import Handler
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
MEDIA_OUT_DIR = os.environ.get("MEDIA_OUT_DIR", "/app/media/out")
|
|
||||||
|
|
||||||
|
|
||||||
class ChunkHandler(Handler):
|
|
||||||
"""
|
|
||||||
Handles chunk processing jobs by delegating to the chunker Pipeline.
|
|
||||||
|
|
||||||
Expected payload keys:
|
|
||||||
source_key: str — S3 key of the source file in BUCKET_IN
|
|
||||||
chunk_duration: float — seconds per chunk (default: 10.0)
|
|
||||||
num_workers: int — concurrent workers (default: 4)
|
|
||||||
max_retries: int — retries per chunk (default: 3)
|
|
||||||
processor_type: str — "ffmpeg", "checksum", "simulated_decode", "composite"
|
|
||||||
queue_size: int — max queue depth (default: 10)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def process(
|
|
||||||
self,
|
|
||||||
job_id: str,
|
|
||||||
payload: Dict[str, Any],
|
|
||||||
progress_callback: Optional[Callable[[int, Dict[str, Any]], None]] = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
source_key = payload["source_key"]
|
|
||||||
processor_type = payload.get("processor_type", "ffmpeg")
|
|
||||||
|
|
||||||
logger.info(f"ChunkHandler starting job {job_id}: {source_key}")
|
|
||||||
|
|
||||||
# Download source from S3/MinIO
|
|
||||||
push_chunk_event(job_id, "pipeline_start", {"status": "downloading", "source_key": source_key})
|
|
||||||
tmp_source = download_to_temp(BUCKET_IN, source_key)
|
|
||||||
|
|
||||||
# Output directory: media/out/chunks/{job_id}/
|
|
||||||
output_dir = os.path.join(MEDIA_OUT_DIR, "chunks", job_id)
|
|
||||||
if processor_type == "ffmpeg":
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
|
|
||||||
try:
|
|
||||||
def event_bridge(event_type: str, data: Dict[str, Any]) -> None:
|
|
||||||
"""Bridge pipeline events to Redis + optional progress callback."""
|
|
||||||
push_chunk_event(job_id, event_type, data)
|
|
||||||
|
|
||||||
if progress_callback and event_type == "pipeline_complete":
|
|
||||||
progress_callback(100, data)
|
|
||||||
elif progress_callback and event_type == "chunk_done":
|
|
||||||
total = data.get("total_chunks", 1)
|
|
||||||
if total > 0:
|
|
||||||
pct = min(int((data.get("sequence", 0) + 1) / total * 100), 99)
|
|
||||||
progress_callback(pct, data)
|
|
||||||
|
|
||||||
pipeline = Pipeline(
|
|
||||||
source=tmp_source,
|
|
||||||
chunk_duration=payload.get("chunk_duration", 10.0),
|
|
||||||
num_workers=payload.get("num_workers", 4),
|
|
||||||
max_retries=payload.get("max_retries", 3),
|
|
||||||
processor_type=processor_type,
|
|
||||||
queue_size=payload.get("queue_size", 10),
|
|
||||||
event_callback=event_bridge,
|
|
||||||
output_dir=output_dir if processor_type == "ffmpeg" else None,
|
|
||||||
start_time=payload.get("start_time"),
|
|
||||||
end_time=payload.get("end_time"),
|
|
||||||
)
|
|
||||||
|
|
||||||
result = pipeline.run()
|
|
||||||
|
|
||||||
# Files are already in media/out/chunks/{job_id}/
|
|
||||||
output_prefix = f"chunks/{job_id}"
|
|
||||||
output_files = [
|
|
||||||
f"{output_prefix}/{os.path.basename(f)}"
|
|
||||||
for f in result.chunk_files
|
|
||||||
]
|
|
||||||
|
|
||||||
push_chunk_event(job_id, "pipeline_complete", {
|
|
||||||
"status": "completed",
|
|
||||||
"total_chunks": result.total_chunks,
|
|
||||||
"processed": result.processed,
|
|
||||||
"failed": result.failed,
|
|
||||||
"elapsed": result.elapsed_time,
|
|
||||||
"throughput_mbps": result.throughput_mbps,
|
|
||||||
})
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "completed" if result.failed == 0 else "completed_with_errors",
|
|
||||||
"total_chunks": result.total_chunks,
|
|
||||||
"processed": result.processed,
|
|
||||||
"failed": result.failed,
|
|
||||||
"retries": result.retries,
|
|
||||||
"elapsed_time": result.elapsed_time,
|
|
||||||
"throughput_mbps": result.throughput_mbps,
|
|
||||||
"worker_stats": result.worker_stats,
|
|
||||||
"errors": result.errors,
|
|
||||||
"chunks_in_order": result.chunks_in_order,
|
|
||||||
"output_prefix": output_prefix,
|
|
||||||
"output_files": output_files,
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
push_chunk_event(job_id, "pipeline_error", {"status": "failed", "error": str(e)})
|
|
||||||
raise
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# Cleanup temp source file only (output dir is persistent)
|
|
||||||
try:
|
|
||||||
os.unlink(tmp_source)
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
@@ -1,130 +0,0 @@
|
|||||||
"""
|
|
||||||
DetectHandler — runs the detection pipeline as a Celery job.
|
|
||||||
|
|
||||||
Supports three modes via payload:
|
|
||||||
- Initial run: {"video_path": "...", "profile_name": "..."}
|
|
||||||
- Replay: {"replay_from": "run_ocr", "source_job_id": "...", "config_overrides": {...}}
|
|
||||||
- Retry: {"retry_from": "escalate_vlm", "source_job_id": "...", "config_overrides": {...}}
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import uuid
|
|
||||||
from typing import Any, Callable, Dict, Optional
|
|
||||||
|
|
||||||
from .base import Handler
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class DetectHandler(Handler):
|
|
||||||
|
|
||||||
def process(
|
|
||||||
self,
|
|
||||||
job_id: str,
|
|
||||||
payload: Dict[str, Any],
|
|
||||||
progress_callback: Optional[Callable[[int, Dict[str, Any]], None]] = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
|
|
||||||
replay_from = payload.get("replay_from")
|
|
||||||
source_job_id = payload.get("source_job_id")
|
|
||||||
|
|
||||||
if replay_from and source_job_id:
|
|
||||||
return self._run_replay(job_id, source_job_id, replay_from, payload, progress_callback)
|
|
||||||
|
|
||||||
return self._run_initial(job_id, payload, progress_callback)
|
|
||||||
|
|
||||||
def _run_initial(
|
|
||||||
self,
|
|
||||||
job_id: str,
|
|
||||||
payload: Dict[str, Any],
|
|
||||||
progress_callback: Optional[Callable],
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
from detect import emit
|
|
||||||
from detect.graph import get_pipeline
|
|
||||||
from detect.state import DetectState
|
|
||||||
|
|
||||||
video_path = payload["video_path"]
|
|
||||||
profile_name = payload.get("profile_name", "soccer_broadcast")
|
|
||||||
source_asset_id = payload.get("source_asset_id", "")
|
|
||||||
checkpoint_enabled = payload.get("checkpoint", os.environ.get("MPR_CHECKPOINT") == "1")
|
|
||||||
|
|
||||||
emit.set_run_context(
|
|
||||||
run_id=job_id,
|
|
||||||
parent_job_id=payload.get("parent_job_id", job_id),
|
|
||||||
run_type="initial",
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("DetectHandler: initial run job=%s video=%s profile=%s checkpoint=%s",
|
|
||||||
job_id, video_path, profile_name, checkpoint_enabled)
|
|
||||||
|
|
||||||
if progress_callback:
|
|
||||||
progress_callback(0, {"stage": "starting"})
|
|
||||||
|
|
||||||
pipeline = get_pipeline(checkpoint=checkpoint_enabled)
|
|
||||||
|
|
||||||
initial_state = DetectState(
|
|
||||||
video_path=video_path,
|
|
||||||
job_id=job_id,
|
|
||||||
profile_name=profile_name,
|
|
||||||
source_asset_id=source_asset_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = pipeline.invoke(initial_state)
|
|
||||||
finally:
|
|
||||||
emit.clear_run_context()
|
|
||||||
|
|
||||||
detections = result.get("detections", [])
|
|
||||||
report = result.get("report")
|
|
||||||
brands_found = len(report.brands) if report else 0
|
|
||||||
|
|
||||||
if progress_callback:
|
|
||||||
progress_callback(100, {"stage": "completed"})
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "completed",
|
|
||||||
"job_id": job_id,
|
|
||||||
"detections": len(detections),
|
|
||||||
"brands_found": brands_found,
|
|
||||||
}
|
|
||||||
|
|
||||||
def _run_replay(
|
|
||||||
self,
|
|
||||||
job_id: str,
|
|
||||||
source_job_id: str,
|
|
||||||
start_stage: str,
|
|
||||||
payload: Dict[str, Any],
|
|
||||||
progress_callback: Optional[Callable],
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
from detect.checkpoint import replay_from
|
|
||||||
|
|
||||||
config_overrides = payload.get("config_overrides", {})
|
|
||||||
|
|
||||||
logger.info("DetectHandler: replay job=%s from=%s source=%s overrides=%s",
|
|
||||||
job_id, start_stage, source_job_id, config_overrides)
|
|
||||||
|
|
||||||
if progress_callback:
|
|
||||||
progress_callback(0, {"stage": f"replaying from {start_stage}"})
|
|
||||||
|
|
||||||
result = replay_from(
|
|
||||||
job_id=source_job_id,
|
|
||||||
start_stage=start_stage,
|
|
||||||
config_overrides=config_overrides,
|
|
||||||
)
|
|
||||||
|
|
||||||
detections = result.get("detections", [])
|
|
||||||
report = result.get("report")
|
|
||||||
brands_found = len(report.brands) if report else 0
|
|
||||||
|
|
||||||
if progress_callback:
|
|
||||||
progress_callback(100, {"stage": "completed"})
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "completed",
|
|
||||||
"job_id": job_id,
|
|
||||||
"source_job_id": source_job_id,
|
|
||||||
"replay_from": start_stage,
|
|
||||||
"detections": len(detections),
|
|
||||||
"brands_found": brands_found,
|
|
||||||
}
|
|
||||||
@@ -1,104 +0,0 @@
|
|||||||
"""
|
|
||||||
TranscodeHandler — executes transcode/trim jobs using FFmpeg.
|
|
||||||
|
|
||||||
Extracted from the old tasks.py Celery task logic.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Callable, Dict, Optional
|
|
||||||
|
|
||||||
from core.ffmpeg.transcode import TranscodeConfig, transcode
|
|
||||||
from core.storage import BUCKET_IN, BUCKET_OUT, download_to_temp, upload_file
|
|
||||||
|
|
||||||
from .base import Handler
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class TranscodeHandler(Handler):
|
|
||||||
"""Handle transcode and trim jobs via FFmpeg."""
|
|
||||||
|
|
||||||
def process(
|
|
||||||
self,
|
|
||||||
job_id: str,
|
|
||||||
payload: Dict[str, Any],
|
|
||||||
progress_callback: Optional[Callable[[int, Dict[str, Any]], None]] = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
source_key = payload["source_key"]
|
|
||||||
output_key = payload["output_key"]
|
|
||||||
preset = payload.get("preset")
|
|
||||||
trim_start = payload.get("trim_start")
|
|
||||||
trim_end = payload.get("trim_end")
|
|
||||||
duration = payload.get("duration")
|
|
||||||
|
|
||||||
logger.info(f"TranscodeHandler: {source_key} -> {output_key}")
|
|
||||||
|
|
||||||
# Download source
|
|
||||||
tmp_source = download_to_temp(BUCKET_IN, source_key)
|
|
||||||
|
|
||||||
ext = Path(output_key).suffix or ".mp4"
|
|
||||||
fd, tmp_output = tempfile.mkstemp(suffix=ext)
|
|
||||||
os.close(fd)
|
|
||||||
|
|
||||||
try:
|
|
||||||
if preset:
|
|
||||||
config = TranscodeConfig(
|
|
||||||
input_path=tmp_source,
|
|
||||||
output_path=tmp_output,
|
|
||||||
video_codec=preset.get("video_codec", "libx264"),
|
|
||||||
video_bitrate=preset.get("video_bitrate"),
|
|
||||||
video_crf=preset.get("video_crf"),
|
|
||||||
video_preset=preset.get("video_preset"),
|
|
||||||
resolution=preset.get("resolution"),
|
|
||||||
framerate=preset.get("framerate"),
|
|
||||||
audio_codec=preset.get("audio_codec", "aac"),
|
|
||||||
audio_bitrate=preset.get("audio_bitrate"),
|
|
||||||
audio_channels=preset.get("audio_channels"),
|
|
||||||
audio_samplerate=preset.get("audio_samplerate"),
|
|
||||||
container=preset.get("container", "mp4"),
|
|
||||||
extra_args=preset.get("extra_args", []),
|
|
||||||
trim_start=trim_start,
|
|
||||||
trim_end=trim_end,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
config = TranscodeConfig(
|
|
||||||
input_path=tmp_source,
|
|
||||||
output_path=tmp_output,
|
|
||||||
video_codec="copy",
|
|
||||||
audio_codec="copy",
|
|
||||||
trim_start=trim_start,
|
|
||||||
trim_end=trim_end,
|
|
||||||
)
|
|
||||||
|
|
||||||
def wrapped_callback(percent: float, details: Dict[str, Any]) -> None:
|
|
||||||
if progress_callback:
|
|
||||||
progress_callback(int(percent), details)
|
|
||||||
|
|
||||||
success = transcode(
|
|
||||||
config,
|
|
||||||
duration=duration,
|
|
||||||
progress_callback=wrapped_callback if progress_callback else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not success:
|
|
||||||
raise RuntimeError("Transcode returned False")
|
|
||||||
|
|
||||||
# Upload result
|
|
||||||
logger.info(f"Uploading {output_key} to {BUCKET_OUT}")
|
|
||||||
upload_file(tmp_output, BUCKET_OUT, output_key)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "completed",
|
|
||||||
"job_id": job_id,
|
|
||||||
"output_key": output_key,
|
|
||||||
}
|
|
||||||
|
|
||||||
finally:
|
|
||||||
for f in [tmp_source, tmp_output]:
|
|
||||||
try:
|
|
||||||
os.unlink(f)
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
"""
|
|
||||||
Handler registry — maps job_type strings to Handler classes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, Type
|
|
||||||
|
|
||||||
from .handlers.base import Handler
|
|
||||||
|
|
||||||
_handlers: Dict[str, Type[Handler]] = {}
|
|
||||||
|
|
||||||
|
|
||||||
def register_handler(job_type: str, handler_class: Type[Handler]) -> None:
|
|
||||||
"""Register a handler class for a job type."""
|
|
||||||
_handlers[job_type] = handler_class
|
|
||||||
|
|
||||||
|
|
||||||
def get_handler(job_type: str) -> Handler:
|
|
||||||
"""Get an instantiated handler for a job type."""
|
|
||||||
if job_type not in _handlers:
|
|
||||||
raise ValueError(f"Unknown job type: {job_type}")
|
|
||||||
return _handlers[job_type]()
|
|
||||||
|
|
||||||
|
|
||||||
def _register_defaults() -> None:
|
|
||||||
"""Register built-in handlers."""
|
|
||||||
from .handlers.chunk import ChunkHandler
|
|
||||||
from .handlers.transcode import TranscodeHandler
|
|
||||||
from .handlers.detect import DetectHandler
|
|
||||||
|
|
||||||
register_handler("transcode", TranscodeHandler)
|
|
||||||
register_handler("chunk", ChunkHandler)
|
|
||||||
register_handler("detect", DetectHandler)
|
|
||||||
|
|
||||||
|
|
||||||
_register_defaults()
|
|
||||||
@@ -1,64 +0,0 @@
|
|||||||
"""
|
|
||||||
Celery task for job processing.
|
|
||||||
|
|
||||||
Generic dispatcher — routes to the appropriate handler based on job_type.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
from celery import shared_task
|
|
||||||
|
|
||||||
from core.rpc.server import update_job_progress
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@shared_task(bind=True, max_retries=3, default_retry_delay=60)
|
|
||||||
def run_job(
|
|
||||||
self,
|
|
||||||
job_type: str,
|
|
||||||
job_id: str,
|
|
||||||
payload: Dict[str, Any],
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Generic Celery task — dispatches to the registered handler for job_type.
|
|
||||||
"""
|
|
||||||
logger.info(f"Starting {job_type} job {job_id}")
|
|
||||||
|
|
||||||
update_job_progress(job_id, progress=0, status="processing")
|
|
||||||
|
|
||||||
def progress_callback(percent: int, details: Dict[str, Any]) -> None:
|
|
||||||
update_job_progress(
|
|
||||||
job_id,
|
|
||||||
progress=percent,
|
|
||||||
current_time=details.get("time", 0.0),
|
|
||||||
status="processing",
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from .registry import get_handler
|
|
||||||
|
|
||||||
handler = get_handler(job_type)
|
|
||||||
result = handler.process(
|
|
||||||
job_id=job_id,
|
|
||||||
payload=payload,
|
|
||||||
progress_callback=progress_callback,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Job {job_id} completed successfully")
|
|
||||||
update_job_progress(job_id, progress=100, status="completed")
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"Job {job_id} failed: {e}")
|
|
||||||
update_job_progress(job_id, progress=0, status="failed", error=str(e))
|
|
||||||
|
|
||||||
if self.request.retries < self.max_retries:
|
|
||||||
raise self.retry(exc=e)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "failed",
|
|
||||||
"job_id": job_id,
|
|
||||||
"error": str(e),
|
|
||||||
}
|
|
||||||
@@ -29,14 +29,9 @@ _active_jobs: dict[str, dict] = {}
|
|||||||
class WorkerServicer(worker_pb2_grpc.WorkerServiceServicer):
|
class WorkerServicer(worker_pb2_grpc.WorkerServiceServicer):
|
||||||
"""gRPC service implementation for worker operations."""
|
"""gRPC service implementation for worker operations."""
|
||||||
|
|
||||||
def __init__(self, celery_app=None):
|
def __init__(self):
|
||||||
"""
|
"""Initialize the servicer."""
|
||||||
Initialize the servicer.
|
pass
|
||||||
|
|
||||||
Args:
|
|
||||||
celery_app: Optional Celery app for task dispatch
|
|
||||||
"""
|
|
||||||
self.celery_app = celery_app
|
|
||||||
|
|
||||||
def SubmitJob(self, request, context):
|
def SubmitJob(self, request, context):
|
||||||
"""Submit a transcode/trim job to the worker."""
|
"""Submit a transcode/trim job to the worker."""
|
||||||
@@ -57,28 +52,7 @@ class WorkerServicer(worker_pb2_grpc.WorkerServiceServicer):
|
|||||||
"error": None,
|
"error": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Dispatch to Celery if available
|
# TODO: dispatch via executor (local/lambda/gcp/grpc)
|
||||||
if self.celery_app:
|
|
||||||
from core.jobs.task import run_job
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"source_key": request.source_path,
|
|
||||||
"output_key": request.output_path,
|
|
||||||
"preset": preset,
|
|
||||||
"trim_start": request.trim_start
|
|
||||||
if request.HasField("trim_start")
|
|
||||||
else None,
|
|
||||||
"trim_end": request.trim_end
|
|
||||||
if request.HasField("trim_end")
|
|
||||||
else None,
|
|
||||||
}
|
|
||||||
|
|
||||||
task = run_job.delay(
|
|
||||||
job_type="transcode",
|
|
||||||
job_id=job_id,
|
|
||||||
payload=payload,
|
|
||||||
)
|
|
||||||
_active_jobs[job_id]["celery_task_id"] = task.id
|
|
||||||
|
|
||||||
return worker_pb2.JobResponse(
|
return worker_pb2.JobResponse(
|
||||||
job_id=job_id,
|
job_id=job_id,
|
||||||
@@ -155,12 +129,6 @@ class WorkerServicer(worker_pb2_grpc.WorkerServiceServicer):
|
|||||||
if job_id in _active_jobs:
|
if job_id in _active_jobs:
|
||||||
_active_jobs[job_id]["status"] = "cancelled"
|
_active_jobs[job_id]["status"] = "cancelled"
|
||||||
|
|
||||||
# Revoke Celery task if available
|
|
||||||
if self.celery_app:
|
|
||||||
task_id = _active_jobs[job_id].get("celery_task_id")
|
|
||||||
if task_id:
|
|
||||||
self.celery_app.control.revoke(task_id, terminate=True)
|
|
||||||
|
|
||||||
return worker_pb2.CancelResponse(
|
return worker_pb2.CancelResponse(
|
||||||
job_id=job_id,
|
job_id=job_id,
|
||||||
cancelled=True,
|
cancelled=True,
|
||||||
@@ -290,13 +258,12 @@ def update_job_progress(
|
|||||||
logger.warning(f"Failed to update job {job_id} in DB: {e}")
|
logger.warning(f"Failed to update job {job_id} in DB: {e}")
|
||||||
|
|
||||||
|
|
||||||
def serve(port: int = None, celery_app=None) -> grpc.Server:
|
def serve(port: int = None) -> grpc.Server:
|
||||||
"""
|
"""
|
||||||
Start the gRPC server.
|
Start the gRPC server.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
port: Port to listen on (defaults to GRPC_PORT env var)
|
port: Port to listen on (defaults to GRPC_PORT env var)
|
||||||
celery_app: Optional Celery app for task dispatch
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The running gRPC server
|
The running gRPC server
|
||||||
@@ -306,7 +273,7 @@ def serve(port: int = None, celery_app=None) -> grpc.Server:
|
|||||||
|
|
||||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=GRPC_MAX_WORKERS))
|
server = grpc.server(futures.ThreadPoolExecutor(max_workers=GRPC_MAX_WORKERS))
|
||||||
worker_pb2_grpc.add_WorkerServiceServicer_to_server(
|
worker_pb2_grpc.add_WorkerServiceServicer_to_server(
|
||||||
WorkerServicer(celery_app=celery_app),
|
WorkerServicer(),
|
||||||
server,
|
server,
|
||||||
)
|
)
|
||||||
server.add_insecure_port(f"[::]:{port}")
|
server.add_insecure_port(f"[::]:{port}")
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from .detect import DETECT_VIEWS # noqa: F401 — discovered by modelgen generi
|
|||||||
from .inference import INFERENCE_VIEWS # noqa: F401 — GPU inference server API types
|
from .inference import INFERENCE_VIEWS # noqa: F401 — GPU inference server API types
|
||||||
from .ui_state import UI_STATE_VIEWS # noqa: F401 — UI store state types
|
from .ui_state import UI_STATE_VIEWS # noqa: F401 — UI store state types
|
||||||
from .stages import StageConfigField, StageIO, StageDefinition, STAGE_VIEWS # noqa: F401
|
from .stages import StageConfigField, StageIO, StageDefinition, STAGE_VIEWS # noqa: F401
|
||||||
|
from .detect_api import RunRequest, RunResponse, DETECT_API_VIEWS # noqa: F401
|
||||||
from .views import ChunkEvent, ChunkOutputFile, PipelineStats, WorkerEvent
|
from .views import ChunkEvent, ChunkOutputFile, PipelineStats, WorkerEvent
|
||||||
from .sources import ChunkInfo, SourceJob, SourceType
|
from .sources import ChunkInfo, SourceJob, SourceType
|
||||||
|
|
||||||
|
|||||||
31
core/schema/models/detect_api.py
Normal file
31
core/schema/models/detect_api.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
"""
|
||||||
|
Detection API request/response models.
|
||||||
|
|
||||||
|
Source of truth for detection pipeline API shapes.
|
||||||
|
Generated to Pydantic via modelgen.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RunRequest:
|
||||||
|
"""Request body for launching a detection pipeline run."""
|
||||||
|
video_path: str # storage key
|
||||||
|
profile_name: str = "soccer_broadcast"
|
||||||
|
source_asset_id: str = ""
|
||||||
|
checkpoint: bool = True
|
||||||
|
skip_vlm: bool = False
|
||||||
|
skip_cloud: bool = False
|
||||||
|
log_level: str = "INFO" # INFO | DEBUG
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RunResponse:
|
||||||
|
"""Response after starting a pipeline run."""
|
||||||
|
status: str
|
||||||
|
job_id: str
|
||||||
|
video_path: str
|
||||||
|
|
||||||
|
|
||||||
|
DETECT_API_VIEWS = [RunRequest, RunResponse]
|
||||||
@@ -56,7 +56,6 @@ class Job:
|
|||||||
estimated_cost_usd: float = 0.0
|
estimated_cost_usd: float = 0.0
|
||||||
|
|
||||||
# Worker tracking
|
# Worker tracking
|
||||||
celery_task_id: Optional[str] = None
|
|
||||||
priority: int = 0
|
priority: int = 0
|
||||||
|
|
||||||
# Timestamps
|
# Timestamps
|
||||||
|
|||||||
@@ -1,133 +0,0 @@
|
|||||||
"""
|
|
||||||
Job Schema Definitions
|
|
||||||
|
|
||||||
Source of truth for job data models.
|
|
||||||
TranscodeJob and ChunkJob share common lifecycle fields by convention.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from datetime import datetime
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
|
|
||||||
class JobStatus(str, Enum):
|
|
||||||
"""Status of a transcode/trim job."""
|
|
||||||
|
|
||||||
PENDING = "pending"
|
|
||||||
PROCESSING = "processing"
|
|
||||||
COMPLETED = "completed"
|
|
||||||
FAILED = "failed"
|
|
||||||
CANCELLED = "cancelled"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TranscodeJob:
|
|
||||||
"""
|
|
||||||
A transcoding or trimming job in the queue.
|
|
||||||
|
|
||||||
Jobs can either:
|
|
||||||
- Transcode using a preset (full re-encode)
|
|
||||||
- Trim only (stream copy with -c:v copy -c:a copy)
|
|
||||||
|
|
||||||
A trim-only job has no preset and uses stream copy.
|
|
||||||
"""
|
|
||||||
|
|
||||||
id: UUID
|
|
||||||
|
|
||||||
# Input
|
|
||||||
source_asset_id: UUID
|
|
||||||
|
|
||||||
# Configuration
|
|
||||||
preset_id: Optional[UUID] = None
|
|
||||||
preset_snapshot: Dict[str, Any] = field(
|
|
||||||
default_factory=dict
|
|
||||||
) # Copy at creation time
|
|
||||||
|
|
||||||
# Trimming (optional)
|
|
||||||
trim_start: Optional[float] = None # seconds
|
|
||||||
trim_end: Optional[float] = None # seconds
|
|
||||||
|
|
||||||
# Output
|
|
||||||
output_filename: str = ""
|
|
||||||
output_path: Optional[str] = None
|
|
||||||
output_asset_id: Optional[UUID] = None
|
|
||||||
|
|
||||||
# Status & Progress
|
|
||||||
status: JobStatus = JobStatus.PENDING
|
|
||||||
progress: float = 0.0 # 0.0 to 100.0
|
|
||||||
current_frame: Optional[int] = None
|
|
||||||
current_time: Optional[float] = None # seconds processed
|
|
||||||
speed: Optional[str] = None # "2.5x"
|
|
||||||
error_message: Optional[str] = None
|
|
||||||
|
|
||||||
# Worker tracking
|
|
||||||
celery_task_id: Optional[str] = None
|
|
||||||
execution_arn: Optional[str] = None # AWS Step Functions execution ARN
|
|
||||||
priority: int = 0 # Lower = higher priority
|
|
||||||
|
|
||||||
# Timestamps
|
|
||||||
created_at: Optional[datetime] = None
|
|
||||||
started_at: Optional[datetime] = None
|
|
||||||
completed_at: Optional[datetime] = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_trim_only(self) -> bool:
|
|
||||||
"""Check if this is a trim-only job (stream copy, no transcode)."""
|
|
||||||
return self.preset_id is None and (
|
|
||||||
self.trim_start is not None or self.trim_end is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ChunkJobStatus(str, Enum):
|
|
||||||
"""Status of a chunk pipeline job."""
|
|
||||||
|
|
||||||
PENDING = "pending"
|
|
||||||
CHUNKING = "chunking"
|
|
||||||
PROCESSING = "processing"
|
|
||||||
COLLECTING = "collecting"
|
|
||||||
COMPLETED = "completed"
|
|
||||||
FAILED = "failed"
|
|
||||||
CANCELLED = "cancelled"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ChunkJob:
|
|
||||||
"""
|
|
||||||
A chunk pipeline job — splits a media file into chunks and processes them
|
|
||||||
through a concurrent worker pool.
|
|
||||||
"""
|
|
||||||
|
|
||||||
id: UUID
|
|
||||||
|
|
||||||
# Input
|
|
||||||
source_asset_id: UUID
|
|
||||||
|
|
||||||
# Configuration
|
|
||||||
chunk_duration: float = 10.0 # seconds
|
|
||||||
num_workers: int = 4
|
|
||||||
max_retries: int = 3
|
|
||||||
processor_type: str = "ffmpeg" # "ffmpeg", "checksum", "simulated_decode", "composite"
|
|
||||||
|
|
||||||
# Status & Progress
|
|
||||||
status: ChunkJobStatus = ChunkJobStatus.PENDING
|
|
||||||
progress: float = 0.0 # 0.0 to 100.0
|
|
||||||
total_chunks: int = 0
|
|
||||||
processed_chunks: int = 0
|
|
||||||
failed_chunks: int = 0
|
|
||||||
retry_count: int = 0
|
|
||||||
error_message: Optional[str] = None
|
|
||||||
|
|
||||||
# Result stats
|
|
||||||
throughput_mbps: Optional[float] = None
|
|
||||||
elapsed_seconds: Optional[float] = None
|
|
||||||
|
|
||||||
# Worker tracking
|
|
||||||
celery_task_id: Optional[str] = None
|
|
||||||
priority: int = 0 # Lower = higher priority
|
|
||||||
|
|
||||||
# Timestamps
|
|
||||||
created_at: Optional[datetime] = None
|
|
||||||
started_at: Optional[datetime] = None
|
|
||||||
completed_at: Optional[datetime] = None
|
|
||||||
@@ -5,7 +5,6 @@ Checkpoint system — Timeline + Checkpoint tree.
|
|||||||
frames.py — frame image S3 upload/download
|
frames.py — frame image S3 upload/download
|
||||||
storage.py — Timeline + Checkpoint (Postgres + MinIO)
|
storage.py — Timeline + Checkpoint (Postgres + MinIO)
|
||||||
replay.py — replay (TODO: migrate to new model)
|
replay.py — replay (TODO: migrate to new model)
|
||||||
tasks.py — retry_candidates Celery task
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .storage import (
|
from .storage import (
|
||||||
|
|||||||
@@ -1,71 +0,0 @@
|
|||||||
"""
|
|
||||||
Celery tasks for detection pipeline async operations.
|
|
||||||
|
|
||||||
retry_candidates: re-run VLM/cloud escalation with different config.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
from celery import shared_task
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@shared_task(bind=True, max_retries=1, default_retry_delay=30)
|
|
||||||
def retry_candidates(
|
|
||||||
self,
|
|
||||||
job_id: str,
|
|
||||||
config_overrides: dict | None = None,
|
|
||||||
start_stage: str = "escalate_vlm",
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Retry unresolved candidates with different config.
|
|
||||||
|
|
||||||
Loads the checkpoint from the stage before start_stage,
|
|
||||||
applies config overrides (e.g. different cloud provider),
|
|
||||||
and runs from start_stage onward.
|
|
||||||
"""
|
|
||||||
from detect.checkpoint.replay import replay_from
|
|
||||||
|
|
||||||
run_id = str(uuid.uuid4())[:8]
|
|
||||||
logger.info("Retry task %s: job=%s, from=%s, overrides=%s",
|
|
||||||
run_id, job_id, start_stage, config_overrides)
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = replay_from(
|
|
||||||
job_id=job_id,
|
|
||||||
start_stage=start_stage,
|
|
||||||
config_overrides=config_overrides,
|
|
||||||
)
|
|
||||||
|
|
||||||
detections = result.get("detections", [])
|
|
||||||
report = result.get("report")
|
|
||||||
brands_found = len(report.brands) if report else 0
|
|
||||||
|
|
||||||
logger.info("Retry %s complete: %d detections, %d brands",
|
|
||||||
run_id, len(detections), brands_found)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "completed",
|
|
||||||
"run_id": run_id,
|
|
||||||
"job_id": job_id,
|
|
||||||
"detections": len(detections),
|
|
||||||
"brands_found": brands_found,
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("Retry %s failed: %s", run_id, e)
|
|
||||||
|
|
||||||
if self.request.retries < self.max_retries:
|
|
||||||
raise self.retry(exc=e)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "failed",
|
|
||||||
"run_id": run_id,
|
|
||||||
"job_id": job_id,
|
|
||||||
"error": str(e),
|
|
||||||
}
|
|
||||||
29
detect/graph/__init__.py
Normal file
29
detect/graph/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
"""
|
||||||
|
Detection pipeline graph.
|
||||||
|
|
||||||
|
detect/graph/
|
||||||
|
nodes.py — node functions (one per stage)
|
||||||
|
events.py — graph_update SSE emission
|
||||||
|
runner.py — pipeline execution (LangGraph wrapper, checkpoint, cancel)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .nodes import NODES, NODE_FUNCTIONS
|
||||||
|
from .runner import (
|
||||||
|
PipelineCancelled,
|
||||||
|
build_graph,
|
||||||
|
clear_cancel_check,
|
||||||
|
get_pipeline,
|
||||||
|
set_cancel_check,
|
||||||
|
)
|
||||||
|
from .events import _node_states
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"NODES",
|
||||||
|
"NODE_FUNCTIONS",
|
||||||
|
"PipelineCancelled",
|
||||||
|
"build_graph",
|
||||||
|
"get_pipeline",
|
||||||
|
"set_cancel_check",
|
||||||
|
"clear_cancel_check",
|
||||||
|
"_node_states",
|
||||||
|
]
|
||||||
27
detect/graph/events.py
Normal file
27
detect/graph/events.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
"""
|
||||||
|
Graph event emission — node state tracking + SSE graph_update events.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from detect import emit
|
||||||
|
from detect.state import DetectState
|
||||||
|
|
||||||
|
|
||||||
|
# Track node states across pipeline runs
|
||||||
|
_node_states: dict[str, dict[str, str]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def emit_transition(state: DetectState, node: str, status: str, node_list: list[str]):
|
||||||
|
"""Update node status and emit graph_update SSE event."""
|
||||||
|
job_id = state.get("job_id")
|
||||||
|
if not job_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
if job_id not in _node_states:
|
||||||
|
_node_states[job_id] = {n: "pending" for n in node_list}
|
||||||
|
|
||||||
|
_node_states[job_id][node] = status
|
||||||
|
|
||||||
|
nodes = [{"id": n, "status": _node_states[job_id][n]} for n in node_list]
|
||||||
|
emit.graph_update(job_id, nodes)
|
||||||
@@ -1,16 +1,13 @@
|
|||||||
"""
|
"""
|
||||||
LangGraph pipeline graph for brand detection.
|
Pipeline node functions — one per stage.
|
||||||
|
|
||||||
Nodes execute real logic for extract+filter, stubs for the rest.
|
Each node: reads state, runs stage logic, emits transitions, returns output dict.
|
||||||
Each node emits graph_update events so the UI can visualize transitions.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from langgraph.graph import END, StateGraph
|
|
||||||
|
|
||||||
from detect import emit
|
from detect import emit
|
||||||
from detect.models import PipelineStats
|
from detect.models import PipelineStats
|
||||||
from detect.profiles import SoccerBroadcastProfile
|
from detect.profiles import SoccerBroadcastProfile
|
||||||
@@ -27,6 +24,8 @@ from detect.stages.vlm_cloud import escalate_cloud
|
|||||||
from detect.stages.aggregator import compile_report
|
from detect.stages.aggregator import compile_report
|
||||||
from detect.tracing import trace_node, flush as flush_traces
|
from detect.tracing import trace_node, flush as flush_traces
|
||||||
|
|
||||||
|
from .events import emit_transition
|
||||||
|
|
||||||
INFERENCE_URL = os.environ.get("INFERENCE_URL") # None = local mode
|
INFERENCE_URL = os.environ.get("INFERENCE_URL") # None = local mode
|
||||||
|
|
||||||
NODES = [
|
NODES = [
|
||||||
@@ -58,41 +57,24 @@ def _get_profile(state: DetectState):
|
|||||||
return profile
|
return profile
|
||||||
|
|
||||||
|
|
||||||
# Track node states across the pipeline run
|
def _emit(state, node, status):
|
||||||
_node_states: dict[str, dict[str, str]] = {}
|
emit_transition(state, node, status, NODES)
|
||||||
|
|
||||||
|
|
||||||
def _emit_transition(state: DetectState, node: str, status: str):
|
|
||||||
job_id = state.get("job_id")
|
|
||||||
if not job_id:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Initialize state tracking for this job
|
|
||||||
if job_id not in _node_states:
|
|
||||||
_node_states[job_id] = {n: "pending" for n in NODES}
|
|
||||||
|
|
||||||
_node_states[job_id][node] = status
|
|
||||||
|
|
||||||
nodes = [{"id": n, "status": _node_states[job_id][n]} for n in NODES]
|
|
||||||
emit.graph_update(job_id, nodes)
|
|
||||||
|
|
||||||
|
|
||||||
# --- Node functions ---
|
# --- Node functions ---
|
||||||
|
|
||||||
def node_extract_frames(state: DetectState) -> dict:
|
def node_extract_frames(state: DetectState) -> dict:
|
||||||
# Set run context for initial runs (replays set it in replay_from)
|
|
||||||
job_id = state.get("job_id", "")
|
job_id = state.get("job_id", "")
|
||||||
if job_id and not emit._run_context:
|
if job_id and not emit._run_context:
|
||||||
emit.set_run_context(run_id=job_id, parent_job_id=job_id, run_type="initial")
|
emit.set_run_context(run_id=job_id, parent_job_id=job_id, run_type="initial")
|
||||||
|
|
||||||
# Load session brands from DB for this source
|
|
||||||
source_asset_id = state.get("source_asset_id")
|
source_asset_id = state.get("source_asset_id")
|
||||||
if source_asset_id and not state.get("session_brands"):
|
if source_asset_id and not state.get("session_brands"):
|
||||||
from detect.stages.brand_resolver import build_session_dict
|
from detect.stages.brand_resolver import build_session_dict
|
||||||
session_brands = build_session_dict(source_asset_id)
|
session_brands = build_session_dict(source_asset_id)
|
||||||
state["session_brands"] = session_brands
|
state["session_brands"] = session_brands
|
||||||
|
|
||||||
_emit_transition(state, "extract_frames", "running")
|
_emit(state, "extract_frames", "running")
|
||||||
|
|
||||||
with trace_node(state, "extract_frames") as span:
|
with trace_node(state, "extract_frames") as span:
|
||||||
profile = _get_profile(state)
|
profile = _get_profile(state)
|
||||||
@@ -100,12 +82,12 @@ def node_extract_frames(state: DetectState) -> dict:
|
|||||||
frames = extract_frames(state["video_path"], config, job_id=state.get("job_id"))
|
frames = extract_frames(state["video_path"], config, job_id=state.get("job_id"))
|
||||||
span.set_output({"frames_extracted": len(frames)})
|
span.set_output({"frames_extracted": len(frames)})
|
||||||
|
|
||||||
_emit_transition(state, "extract_frames", "done")
|
_emit(state, "extract_frames", "done")
|
||||||
return {"frames": frames, "stats": PipelineStats(frames_extracted=len(frames))}
|
return {"frames": frames, "stats": PipelineStats(frames_extracted=len(frames))}
|
||||||
|
|
||||||
|
|
||||||
def node_filter_scenes(state: DetectState) -> dict:
|
def node_filter_scenes(state: DetectState) -> dict:
|
||||||
_emit_transition(state, "filter_scenes", "running")
|
_emit(state, "filter_scenes", "running")
|
||||||
|
|
||||||
with trace_node(state, "filter_scenes") as span:
|
with trace_node(state, "filter_scenes") as span:
|
||||||
profile = _get_profile(state)
|
profile = _get_profile(state)
|
||||||
@@ -117,12 +99,12 @@ def node_filter_scenes(state: DetectState) -> dict:
|
|||||||
stats = state.get("stats", PipelineStats())
|
stats = state.get("stats", PipelineStats())
|
||||||
stats.frames_after_scene_filter = len(kept)
|
stats.frames_after_scene_filter = len(kept)
|
||||||
|
|
||||||
_emit_transition(state, "filter_scenes", "done")
|
_emit(state, "filter_scenes", "done")
|
||||||
return {"filtered_frames": kept, "stats": stats}
|
return {"filtered_frames": kept, "stats": stats}
|
||||||
|
|
||||||
|
|
||||||
def node_detect_edges(state: DetectState) -> dict:
|
def node_detect_edges(state: DetectState) -> dict:
|
||||||
_emit_transition(state, "detect_edges", "running")
|
_emit(state, "detect_edges", "running")
|
||||||
|
|
||||||
with trace_node(state, "detect_edges") as span:
|
with trace_node(state, "detect_edges") as span:
|
||||||
profile = _get_profile(state)
|
profile = _get_profile(state)
|
||||||
@@ -139,12 +121,12 @@ def node_detect_edges(state: DetectState) -> dict:
|
|||||||
stats = state.get("stats", PipelineStats())
|
stats = state.get("stats", PipelineStats())
|
||||||
stats.cv_regions_detected = total
|
stats.cv_regions_detected = total
|
||||||
|
|
||||||
_emit_transition(state, "detect_edges", "done")
|
_emit(state, "detect_edges", "done")
|
||||||
return {"edge_regions_by_frame": regions, "stats": stats}
|
return {"edge_regions_by_frame": regions, "stats": stats}
|
||||||
|
|
||||||
|
|
||||||
def node_detect_objects(state: DetectState) -> dict:
|
def node_detect_objects(state: DetectState) -> dict:
|
||||||
_emit_transition(state, "detect_objects", "running")
|
_emit(state, "detect_objects", "running")
|
||||||
|
|
||||||
with trace_node(state, "detect_objects") as span:
|
with trace_node(state, "detect_objects") as span:
|
||||||
profile = _get_profile(state)
|
profile = _get_profile(state)
|
||||||
@@ -159,12 +141,12 @@ def node_detect_objects(state: DetectState) -> dict:
|
|||||||
stats = state.get("stats", PipelineStats())
|
stats = state.get("stats", PipelineStats())
|
||||||
stats.regions_detected = total_regions
|
stats.regions_detected = total_regions
|
||||||
|
|
||||||
_emit_transition(state, "detect_objects", "done")
|
_emit(state, "detect_objects", "done")
|
||||||
return {"boxes_by_frame": all_boxes, "stats": stats}
|
return {"boxes_by_frame": all_boxes, "stats": stats}
|
||||||
|
|
||||||
|
|
||||||
def node_preprocess(state: DetectState) -> dict:
|
def node_preprocess(state: DetectState) -> dict:
|
||||||
_emit_transition(state, "preprocess", "running")
|
_emit(state, "preprocess", "running")
|
||||||
|
|
||||||
with trace_node(state, "preprocess") as span:
|
with trace_node(state, "preprocess") as span:
|
||||||
profile = _get_profile(state)
|
profile = _get_profile(state)
|
||||||
@@ -172,7 +154,6 @@ def node_preprocess(state: DetectState) -> dict:
|
|||||||
boxes = state.get("boxes_by_frame", {})
|
boxes = state.get("boxes_by_frame", {})
|
||||||
job_id = state.get("job_id")
|
job_id = state.get("job_id")
|
||||||
|
|
||||||
# Get preprocessing config from profile overrides or defaults
|
|
||||||
overrides = state.get("config_overrides", {})
|
overrides = state.get("config_overrides", {})
|
||||||
prep_config = overrides.get("preprocessing", {})
|
prep_config = overrides.get("preprocessing", {})
|
||||||
do_contrast = prep_config.get("contrast", True)
|
do_contrast = prep_config.get("contrast", True)
|
||||||
@@ -189,12 +170,12 @@ def node_preprocess(state: DetectState) -> dict:
|
|||||||
)
|
)
|
||||||
span.set_output({"regions_preprocessed": len(preprocessed)})
|
span.set_output({"regions_preprocessed": len(preprocessed)})
|
||||||
|
|
||||||
_emit_transition(state, "preprocess", "done")
|
_emit(state, "preprocess", "done")
|
||||||
return {"preprocessed_crops": preprocessed}
|
return {"preprocessed_crops": preprocessed}
|
||||||
|
|
||||||
|
|
||||||
def node_run_ocr(state: DetectState) -> dict:
|
def node_run_ocr(state: DetectState) -> dict:
|
||||||
_emit_transition(state, "run_ocr", "running")
|
_emit(state, "run_ocr", "running")
|
||||||
|
|
||||||
with trace_node(state, "run_ocr") as span:
|
with trace_node(state, "run_ocr") as span:
|
||||||
profile = _get_profile(state)
|
profile = _get_profile(state)
|
||||||
@@ -209,12 +190,12 @@ def node_run_ocr(state: DetectState) -> dict:
|
|||||||
stats = state.get("stats", PipelineStats())
|
stats = state.get("stats", PipelineStats())
|
||||||
stats.regions_resolved_by_ocr = len(candidates)
|
stats.regions_resolved_by_ocr = len(candidates)
|
||||||
|
|
||||||
_emit_transition(state, "run_ocr", "done")
|
_emit(state, "run_ocr", "done")
|
||||||
return {"text_candidates": candidates, "stats": stats}
|
return {"text_candidates": candidates, "stats": stats}
|
||||||
|
|
||||||
|
|
||||||
def node_match_brands(state: DetectState) -> dict:
|
def node_match_brands(state: DetectState) -> dict:
|
||||||
_emit_transition(state, "match_brands", "running")
|
_emit(state, "match_brands", "running")
|
||||||
|
|
||||||
with trace_node(state, "match_brands") as span:
|
with trace_node(state, "match_brands") as span:
|
||||||
profile = _get_profile(state)
|
profile = _get_profile(state)
|
||||||
@@ -232,12 +213,12 @@ def node_match_brands(state: DetectState) -> dict:
|
|||||||
)
|
)
|
||||||
span.set_output({"matched": len(matched), "unresolved": len(unresolved)})
|
span.set_output({"matched": len(matched), "unresolved": len(unresolved)})
|
||||||
|
|
||||||
_emit_transition(state, "match_brands", "done")
|
_emit(state, "match_brands", "done")
|
||||||
return {"detections": matched, "unresolved_candidates": unresolved}
|
return {"detections": matched, "unresolved_candidates": unresolved}
|
||||||
|
|
||||||
|
|
||||||
def node_escalate_vlm(state: DetectState) -> dict:
|
def node_escalate_vlm(state: DetectState) -> dict:
|
||||||
_emit_transition(state, "escalate_vlm", "running")
|
_emit(state, "escalate_vlm", "running")
|
||||||
|
|
||||||
with trace_node(state, "escalate_vlm") as span:
|
with trace_node(state, "escalate_vlm") as span:
|
||||||
profile = _get_profile(state)
|
profile = _get_profile(state)
|
||||||
@@ -261,7 +242,7 @@ def node_escalate_vlm(state: DetectState) -> dict:
|
|||||||
existing = state.get("detections", [])
|
existing = state.get("detections", [])
|
||||||
|
|
||||||
vlm_skipped = os.environ.get("SKIP_VLM", "").strip() == "1"
|
vlm_skipped = os.environ.get("SKIP_VLM", "").strip() == "1"
|
||||||
_emit_transition(state, "escalate_vlm", "skipped" if vlm_skipped else "done")
|
_emit(state, "escalate_vlm", "skipped" if vlm_skipped else "done")
|
||||||
return {
|
return {
|
||||||
"detections": existing + vlm_matched,
|
"detections": existing + vlm_matched,
|
||||||
"unresolved_candidates": still_unresolved,
|
"unresolved_candidates": still_unresolved,
|
||||||
@@ -270,7 +251,7 @@ def node_escalate_vlm(state: DetectState) -> dict:
|
|||||||
|
|
||||||
|
|
||||||
def node_escalate_cloud(state: DetectState) -> dict:
|
def node_escalate_cloud(state: DetectState) -> dict:
|
||||||
_emit_transition(state, "escalate_cloud", "running")
|
_emit(state, "escalate_cloud", "running")
|
||||||
|
|
||||||
with trace_node(state, "escalate_cloud") as span:
|
with trace_node(state, "escalate_cloud") as span:
|
||||||
profile = _get_profile(state)
|
profile = _get_profile(state)
|
||||||
@@ -294,12 +275,12 @@ def node_escalate_cloud(state: DetectState) -> dict:
|
|||||||
existing = state.get("detections", [])
|
existing = state.get("detections", [])
|
||||||
|
|
||||||
cloud_skipped = os.environ.get("SKIP_CLOUD", "").strip() == "1"
|
cloud_skipped = os.environ.get("SKIP_CLOUD", "").strip() == "1"
|
||||||
_emit_transition(state, "escalate_cloud", "skipped" if cloud_skipped else "done")
|
_emit(state, "escalate_cloud", "skipped" if cloud_skipped else "done")
|
||||||
return {"detections": existing + cloud_matched, "stats": stats}
|
return {"detections": existing + cloud_matched, "stats": stats}
|
||||||
|
|
||||||
|
|
||||||
def node_compile_report(state: DetectState) -> dict:
|
def node_compile_report(state: DetectState) -> dict:
|
||||||
_emit_transition(state, "compile_report", "running")
|
_emit(state, "compile_report", "running")
|
||||||
|
|
||||||
with trace_node(state, "compile_report") as span:
|
with trace_node(state, "compile_report") as span:
|
||||||
profile = _get_profile(state)
|
profile = _get_profile(state)
|
||||||
@@ -318,85 +299,10 @@ def node_compile_report(state: DetectState) -> dict:
|
|||||||
span.set_output({"brands": len(report.brands), "detections": len(report.timeline)})
|
span.set_output({"brands": len(report.brands), "detections": len(report.timeline)})
|
||||||
|
|
||||||
flush_traces()
|
flush_traces()
|
||||||
_emit_transition(state, "compile_report", "done")
|
_emit(state, "compile_report", "done")
|
||||||
return {"report": report}
|
return {"report": report}
|
||||||
|
|
||||||
|
|
||||||
# --- Checkpoint wrapper ---
|
|
||||||
|
|
||||||
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
|
|
||||||
_frames_manifest: dict[str, dict[int, str]] = {} # job_id → manifest (cached per job)
|
|
||||||
_latest_checkpoint: dict[str, str] = {} # job_id → latest checkpoint_id
|
|
||||||
|
|
||||||
|
|
||||||
class PipelineCancelled(Exception):
|
|
||||||
"""Raised when a pipeline run is cancelled."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# Cancellation hook — set by the run endpoint, checked before each node
|
|
||||||
_cancel_check: dict[str, callable] = {}
|
|
||||||
|
|
||||||
|
|
||||||
def set_cancel_check(job_id: str, fn):
|
|
||||||
_cancel_check[job_id] = fn
|
|
||||||
|
|
||||||
|
|
||||||
def clear_cancel_check(job_id: str):
|
|
||||||
_cancel_check.pop(job_id, None)
|
|
||||||
|
|
||||||
|
|
||||||
def _checkpointing_node(node_name: str, node_fn):
|
|
||||||
"""Wrap a node function to auto-checkpoint after completion."""
|
|
||||||
stage_index = NODES.index(node_name)
|
|
||||||
|
|
||||||
def wrapper(state: DetectState) -> dict:
|
|
||||||
job_id = state.get("job_id", "")
|
|
||||||
check = _cancel_check.get(job_id)
|
|
||||||
if check and check():
|
|
||||||
raise PipelineCancelled(f"Cancelled before {node_name}")
|
|
||||||
|
|
||||||
result = node_fn(state)
|
|
||||||
|
|
||||||
job_id = state.get("job_id", "")
|
|
||||||
if not job_id:
|
|
||||||
return result
|
|
||||||
|
|
||||||
from detect.checkpoint import save_stage_output, save_frames
|
|
||||||
from detect.stages.base import _REGISTRY
|
|
||||||
|
|
||||||
merged = {**state, **result}
|
|
||||||
|
|
||||||
# Save frames once (first node), reuse manifest after
|
|
||||||
manifest = _frames_manifest.get(job_id)
|
|
||||||
if manifest is None and node_name == "extract_frames":
|
|
||||||
manifest = save_frames(job_id, merged.get("frames", []))
|
|
||||||
_frames_manifest[job_id] = manifest
|
|
||||||
|
|
||||||
# Serialize stage output using the stage's serialize_fn if available
|
|
||||||
stage_cls = _REGISTRY.get(node_name)
|
|
||||||
serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
|
|
||||||
if serialize_fn:
|
|
||||||
output_json = serialize_fn(merged, job_id)
|
|
||||||
else:
|
|
||||||
output_json = {}
|
|
||||||
|
|
||||||
parent_id = _latest_checkpoint.get(job_id)
|
|
||||||
new_checkpoint_id = save_stage_output(
|
|
||||||
timeline_id=job_id,
|
|
||||||
parent_checkpoint_id=parent_id,
|
|
||||||
stage_name=node_name,
|
|
||||||
output_json=output_json,
|
|
||||||
)
|
|
||||||
_latest_checkpoint[job_id] = new_checkpoint_id
|
|
||||||
return result
|
|
||||||
|
|
||||||
wrapper.__name__ = node_fn.__name__
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
# --- Graph construction ---
|
|
||||||
|
|
||||||
NODE_FUNCTIONS = [
|
NODE_FUNCTIONS = [
|
||||||
("extract_frames", node_extract_frames),
|
("extract_frames", node_extract_frames),
|
||||||
("filter_scenes", node_filter_scenes),
|
("filter_scenes", node_filter_scenes),
|
||||||
@@ -409,41 +315,3 @@ NODE_FUNCTIONS = [
|
|||||||
("escalate_cloud", node_escalate_cloud),
|
("escalate_cloud", node_escalate_cloud),
|
||||||
("compile_report", node_compile_report),
|
("compile_report", node_compile_report),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def build_graph(checkpoint: bool | None = None, start_from: str | None = None) -> StateGraph:
|
|
||||||
"""
|
|
||||||
Build the pipeline graph.
|
|
||||||
|
|
||||||
checkpoint: enable auto-checkpointing (default: MPR_CHECKPOINT env var)
|
|
||||||
start_from: skip nodes before this stage (for replay)
|
|
||||||
"""
|
|
||||||
do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
|
|
||||||
|
|
||||||
graph = StateGraph(DetectState)
|
|
||||||
|
|
||||||
# Filter to start_from if replaying
|
|
||||||
node_pairs = NODE_FUNCTIONS
|
|
||||||
if start_from:
|
|
||||||
start_idx = next(i for i, (name, _) in enumerate(NODE_FUNCTIONS) if name == start_from)
|
|
||||||
node_pairs = NODE_FUNCTIONS[start_idx:]
|
|
||||||
|
|
||||||
for name, fn in node_pairs:
|
|
||||||
wrapped = _checkpointing_node(name, fn) if do_checkpoint else fn
|
|
||||||
graph.add_node(name, wrapped)
|
|
||||||
|
|
||||||
# Wire edges
|
|
||||||
entry = node_pairs[0][0]
|
|
||||||
graph.set_entry_point(entry)
|
|
||||||
|
|
||||||
for i in range(len(node_pairs) - 1):
|
|
||||||
graph.add_edge(node_pairs[i][0], node_pairs[i + 1][0])
|
|
||||||
|
|
||||||
graph.add_edge(node_pairs[-1][0], END)
|
|
||||||
|
|
||||||
return graph
|
|
||||||
|
|
||||||
|
|
||||||
def get_pipeline(checkpoint: bool | None = None):
|
|
||||||
"""Return a compiled, runnable pipeline."""
|
|
||||||
return build_graph(checkpoint=checkpoint).compile()
|
|
||||||
127
detect/graph/runner.py
Normal file
127
detect/graph/runner.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
"""
|
||||||
|
Pipeline runner — executes stages sequentially with checkpointing and cancellation.
|
||||||
|
|
||||||
|
Currently wraps LangGraph for execution. Will be replaced with a lean
|
||||||
|
custom runner in Phase 3, with an executor socket for distributed dispatch.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from langgraph.graph import END, StateGraph
|
||||||
|
|
||||||
|
from detect.state import DetectState
|
||||||
|
from .nodes import NODES, NODE_FUNCTIONS
|
||||||
|
|
||||||
|
|
||||||
|
# --- Checkpoint wrapper ---
|
||||||
|
|
||||||
|
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
|
||||||
|
_frames_manifest: dict[str, dict[int, str]] = {} # job_id → manifest (cached per job)
|
||||||
|
_latest_checkpoint: dict[str, str] = {} # job_id → latest checkpoint_id
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineCancelled(Exception):
|
||||||
|
"""Raised when a pipeline run is cancelled."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Cancellation hook — set by the run endpoint, checked before each node
|
||||||
|
_cancel_check: dict[str, callable] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def set_cancel_check(job_id: str, fn):
|
||||||
|
_cancel_check[job_id] = fn
|
||||||
|
|
||||||
|
|
||||||
|
def clear_cancel_check(job_id: str):
|
||||||
|
_cancel_check.pop(job_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
def _checkpointing_node(node_name: str, node_fn):
|
||||||
|
"""Wrap a node function to auto-checkpoint after completion."""
|
||||||
|
|
||||||
|
def wrapper(state: DetectState) -> dict:
|
||||||
|
job_id = state.get("job_id", "")
|
||||||
|
check = _cancel_check.get(job_id)
|
||||||
|
if check and check():
|
||||||
|
raise PipelineCancelled(f"Cancelled before {node_name}")
|
||||||
|
|
||||||
|
result = node_fn(state)
|
||||||
|
|
||||||
|
job_id = state.get("job_id", "")
|
||||||
|
if not job_id:
|
||||||
|
return result
|
||||||
|
|
||||||
|
from detect.checkpoint import save_stage_output, save_frames
|
||||||
|
from detect.stages.base import _REGISTRY
|
||||||
|
|
||||||
|
merged = {**state, **result}
|
||||||
|
|
||||||
|
# Save frames once (first node), reuse manifest after
|
||||||
|
manifest = _frames_manifest.get(job_id)
|
||||||
|
if manifest is None and node_name == "extract_frames":
|
||||||
|
manifest = save_frames(job_id, merged.get("frames", []))
|
||||||
|
_frames_manifest[job_id] = manifest
|
||||||
|
|
||||||
|
# Serialize stage output using the stage's serialize_fn if available
|
||||||
|
stage_cls = _REGISTRY.get(node_name)
|
||||||
|
serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
|
||||||
|
if serialize_fn:
|
||||||
|
output_json = serialize_fn(merged, job_id)
|
||||||
|
else:
|
||||||
|
output_json = {}
|
||||||
|
|
||||||
|
parent_id = _latest_checkpoint.get(job_id)
|
||||||
|
new_checkpoint_id = save_stage_output(
|
||||||
|
timeline_id=job_id,
|
||||||
|
parent_checkpoint_id=parent_id,
|
||||||
|
stage_name=node_name,
|
||||||
|
output_json=output_json,
|
||||||
|
)
|
||||||
|
_latest_checkpoint[job_id] = new_checkpoint_id
|
||||||
|
return result
|
||||||
|
|
||||||
|
wrapper.__name__ = node_fn.__name__
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
# --- Graph construction ---
|
||||||
|
|
||||||
|
def build_graph(checkpoint: bool | None = None, start_from: str | None = None) -> StateGraph:
|
||||||
|
"""
|
||||||
|
Build the pipeline graph.
|
||||||
|
|
||||||
|
checkpoint: enable auto-checkpointing (default: MPR_CHECKPOINT env var)
|
||||||
|
start_from: skip nodes before this stage (for replay)
|
||||||
|
"""
|
||||||
|
do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
|
||||||
|
|
||||||
|
graph = StateGraph(DetectState)
|
||||||
|
|
||||||
|
# Filter to start_from if replaying
|
||||||
|
node_pairs = NODE_FUNCTIONS
|
||||||
|
if start_from:
|
||||||
|
start_idx = next(i for i, (name, _) in enumerate(NODE_FUNCTIONS) if name == start_from)
|
||||||
|
node_pairs = NODE_FUNCTIONS[start_idx:]
|
||||||
|
|
||||||
|
for name, fn in node_pairs:
|
||||||
|
wrapped = _checkpointing_node(name, fn) if do_checkpoint else fn
|
||||||
|
graph.add_node(name, wrapped)
|
||||||
|
|
||||||
|
# Wire edges
|
||||||
|
entry = node_pairs[0][0]
|
||||||
|
graph.set_entry_point(entry)
|
||||||
|
|
||||||
|
for i in range(len(node_pairs) - 1):
|
||||||
|
graph.add_edge(node_pairs[i][0], node_pairs[i + 1][0])
|
||||||
|
|
||||||
|
graph.add_edge(node_pairs[-1][0], END)
|
||||||
|
|
||||||
|
return graph
|
||||||
|
|
||||||
|
|
||||||
|
def get_pipeline(checkpoint: bool | None = None):
|
||||||
|
"""Return a compiled, runnable pipeline."""
|
||||||
|
return build_graph(checkpoint=checkpoint).compile()
|
||||||
Reference in New Issue
Block a user