This commit is contained in:
2026-03-30 07:22:14 -03:00
parent d0707333fd
commit 4220b0418e
182 changed files with 3668 additions and 5231 deletions

View File

@@ -102,6 +102,7 @@ class Job(models.Model):
source_asset_id = models.UUIDField()
video_path = models.CharField(max_length=1000)
profile_name = models.CharField(max_length=255)
timeline_id = models.UUIDField(null=True, blank=True)
parent_id = models.UUIDField(null=True, blank=True)
run_type = models.CharField(max_length=20, choices=RunType.choices, default=RunType.INITIAL)
config_overrides = models.JSONField(default=dict, blank=True)
@@ -113,7 +114,6 @@ class Job(models.Model):
brands_found = models.IntegerField(default=0)
cloud_llm_calls = models.IntegerField(default=0)
estimated_cost_usd = models.FloatField(default=0.0)
celery_task_id = models.CharField(max_length=255, null=True, blank=True)
priority = models.IntegerField(default=0)
created_at = models.DateTimeField(auto_now_add=True)
started_at = models.DateTimeField(null=True, blank=True)
@@ -151,6 +151,7 @@ class Checkpoint(models.Model):
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
timeline_id = models.UUIDField()
job_id = models.UUIDField(null=True, blank=True)
parent_id = models.UUIDField(null=True, blank=True)
stage_outputs = models.JSONField(default=dict, blank=True)
config_overrides = models.JSONField(default=dict, blank=True)
@@ -185,3 +186,18 @@ class Brand(models.Model):
def __str__(self):
return str(self.id)
class Profile(models.Model):
    """A content type profile: a named pipeline composition plus per-stage configs."""

    # Surrogate primary key, generated app-side (no DB round-trip needed for the id).
    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    # Human-readable profile name (also used as the string representation below).
    name = models.CharField(max_length=255)
    # Pipeline composition as free-form JSON; empty dict by default.
    pipeline = models.JSONField(default=dict, blank=True)
    # Per-stage configuration values as free-form JSON; empty dict by default.
    configs = models.JSONField(default=dict, blank=True)

    class Meta:
        # No custom table options yet; placeholder kept for parity with sibling models.
        pass

    def __str__(self):
        # Shown in Django admin and debug output.
        return self.name

View File

@@ -1,73 +0,0 @@
"""
SSE endpoint for chunker pipeline events.
Uses Redis as the event bus. Pipeline pushes events via core.events,
SSE endpoint polls them.
GET /chunker/stream/{job_id} → text/event-stream
"""
import asyncio
import json
import logging
import time
from typing import AsyncGenerator
from fastapi import APIRouter
from starlette.responses import StreamingResponse
from core.events import poll_events
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/chunker", tags=["chunker"])
async def _event_generator(job_id: str) -> AsyncGenerator[str, None]:
    """
    Yield SSE frames for one chunk job by polling the Redis event bus.

    Emits `waiting` heartbeats while idle, relays each pipeline event as its
    own SSE event, closes with a `done` frame on a terminal event, and falls
    back to a `timeout` frame after 10 minutes.
    """

    def frame(name: str, payload: dict) -> str:
        # One wire-format SSE frame: event line, JSON data line, blank separator.
        return f"event: {name}\ndata: {json.dumps(payload)}\n\n"

    cursor = 0
    deadline = time.monotonic() + 600  # hard cap: 10 min per stream
    while time.monotonic() < deadline:
        events, cursor = poll_events(job_id, cursor)
        if not events:
            # Heartbeat keeps the connection warm and tells the UI we're alive.
            yield frame("waiting", {"job_id": job_id})
            await asyncio.sleep(0.1)
            continue
        for data in events:
            kind = data.pop("event", "update")
            yield frame(kind, {**data, "job_id": job_id})
            if kind in ("pipeline_complete", "pipeline_error", "cancelled"):
                # Terminal event: tell EventSource clients we're finished and stop.
                yield frame("done", {"job_id": job_id})
                return
        await asyncio.sleep(0.05)
    yield frame("timeout", {"job_id": job_id})
@router.get("/stream/{job_id}")
async def stream_chunk_job(job_id: str):
    """
    SSE stream for a chunk pipeline job.

    Clients attach with a native EventSource:
        const es = new EventSource('/api/chunker/stream/<job_id>');
        es.addEventListener('processing', (e) => { ... });
    """
    sse_headers = {
        "Cache-Control": "no-cache",   # event streams must never be cached
        "Connection": "keep-alive",    # long-lived response
        "X-Accel-Buffering": "no",     # disable reverse-proxy buffering
    }
    return StreamingResponse(
        _event_generator(job_id),
        media_type="text/event-stream",
        headers=sse_headers,
    )

View File

@@ -58,32 +58,30 @@ def write_config(update: ConfigUpdate):
@router.get("/config/profiles")
def list_profiles():
def get_profiles():
"""List available detection profiles."""
from detect.profiles import _PROFILES
return [{"name": name} for name in _PROFILES]
from core.detect.profile import list_profiles as _list
return [{"name": name} for name in _list()]
@router.get("/config/profiles/{profile_name}/pipeline")
def get_pipeline_config(profile_name: str):
"""Return the pipeline composition for a profile."""
from detect.profiles import get_profile
from core.detect.profile import get_profile
from fastapi import HTTPException
from dataclasses import asdict
try:
profile = get_profile(profile_name)
except ValueError:
raise HTTPException(status_code=404, detail=f"Unknown profile: {profile_name}")
config = profile.pipeline_config()
return asdict(config)
return profile["pipeline"]
@router.get("/config/stages", response_model=list[StageConfigInfo])
def list_stage_configs():
"""Return the stage palette with config field metadata for the editor."""
from detect.stages import list_stages
from core.detect.stages import list_stages
result = []
for stage in list_stages():
@@ -95,7 +93,7 @@ def list_stage_configs():
@router.get("/config/stages/{stage_name}", response_model=StageConfigInfo)
def get_stage_config(stage_name: str):
"""Return config field metadata for a single stage."""
from detect.stages import get_stage
from core.detect.stages import get_stage
try:
stage = get_stage(stage_name)

View File

@@ -105,7 +105,7 @@ class ReplaySingleStageResponse(BaseModel):
@router.get("/checkpoints/{timeline_id}")
def list_checkpoints(timeline_id: str) -> list[CheckpointInfo]:
"""List available checkpoint stages for a job."""
from detect.checkpoint import list_checkpoints as _list
from core.detect.checkpoint import list_checkpoints as _list
try:
stages = _list(timeline_id)
@@ -139,10 +139,10 @@ class CheckpointData(BaseModel):
def get_checkpoint_data(timeline_id: str, stage: str):
"""Load checkpoint frames + metadata for the editor UI."""
from uuid import UUID
from core.db.tables import Timeline, Checkpoint
from core.db.models import Timeline, Checkpoint
from core.db.connection import get_session
from core.db.checkpoint import list_checkpoints
from detect.checkpoint.frames import load_frames_b64
from core.detect.checkpoint.frames import load_frames_b64
with get_session() as session:
timeline = session.get(Timeline, UUID(timeline_id))
@@ -184,7 +184,7 @@ def get_checkpoint_data(timeline_id: str, stage: str):
@router.get("/scenarios", response_model=list[ScenarioInfo])
def list_scenarios_endpoint():
"""List all available scenarios (bookmarked checkpoints)."""
from core.db.tables import Timeline
from core.db.models import Timeline
from core.db.connection import get_session
from core.db.checkpoint import list_scenarios
@@ -212,7 +212,7 @@ def list_scenarios_endpoint():
@router.post("/replay", response_model=ReplayResponse)
def replay(req: ReplayRequest):
"""Replay pipeline from a specific stage with optional config overrides."""
from detect.checkpoint import replay_from
from core.detect.checkpoint import replay_from
try:
result = replay_from(
@@ -242,7 +242,7 @@ def replay(req: ReplayRequest):
@router.post("/retry", response_model=RetryResponse)
def retry(req: RetryRequest):
"""Queue an async retry of unresolved candidates with different config."""
from detect.checkpoint.tasks import retry_candidates
from core.detect.checkpoint.tasks import retry_candidates
kwargs = {
"timeline_id": req.timeline_id,
@@ -266,7 +266,7 @@ def retry(req: RetryRequest):
@router.post("/replay-stage", response_model=ReplaySingleStageResponse)
def replay_single_stage(req: ReplaySingleStageRequest):
"""Replay a single stage on specific frames — fast path for interactive tuning."""
from detect.checkpoint.replay import replay_single_stage as _replay
from core.detect.checkpoint.replay import replay_single_stage as _replay
try:
result = _replay(
@@ -361,3 +361,41 @@ async def gpu_detect_edges_debug(request: Request):
media_type="application/json")
except Exception as e:
raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}")
@router.post("/gpu/segment_field")
async def gpu_segment_field(request: Request):
    """Proxy to GPU inference server — field segmentation."""
    import httpx

    raw = await request.body()
    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            upstream = await client.post(
                f"{_gpu_url()}/segment_field",
                content=raw,
                headers={"Content-Type": "application/json"},
            )
            # Relay the upstream status code and JSON body unchanged.
            return Response(
                content=upstream.content,
                status_code=upstream.status_code,
                media_type="application/json",
            )
    except Exception as e:
        # Any failure (DNS, connection refused, timeout) surfaces as a 502.
        raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}")
@router.post("/gpu/segment_field/debug")
async def gpu_segment_field_debug(request: Request):
    """Proxy to GPU inference server — field segmentation with debug overlay."""
    import httpx

    raw = await request.body()
    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            upstream = await client.post(
                f"{_gpu_url()}/segment_field/debug",
                content=raw,
                headers={"Content-Type": "application/json"},
            )
            # Relay the upstream status code and JSON body unchanged.
            return Response(
                content=upstream.content,
                status_code=upstream.status_code,
                media_type="application/json",
            )
    except Exception as e:
        # Any failure (DNS, connection refused, timeout) surfaces as a 502.
        raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}")

View File

@@ -60,9 +60,9 @@ def _resolve_video_path(video_path: str) -> str:
@router.post("/run", response_model=RunResponse)
def run_pipeline(req: RunRequest):
"""Launch a detection pipeline run on a source chunk."""
from detect import emit
from detect.graph import get_pipeline
from detect.state import DetectState
from core.detect import emit
from core.detect.graph import get_pipeline
from core.detect.state import DetectState
local_path = _resolve_video_path(req.video_path)
job_id = str(uuid.uuid4())
@@ -79,7 +79,7 @@ def run_pipeline(req: RunRequest):
# Clear any stale events from a previous run with same job_id
from core.events import _get_redis
from detect.events import DETECT_EVENTS_PREFIX
from core.detect.events import DETECT_EVENTS_PREFIX
r = _get_redis()
r.delete(f"{DETECT_EVENTS_PREFIX}:{job_id}")
@@ -97,7 +97,7 @@ def run_pipeline(req: RunRequest):
source_asset_id=req.source_asset_id,
)
from detect.graph import (
from core.detect.graph import (
PipelineCancelled, set_cancel_check, clear_cancel_check,
init_pause, clear_pause,
)
@@ -117,7 +117,7 @@ def run_pipeline(req: RunRequest):
emit.job_complete(job_id, {"status": "cancelled"})
except Exception as e:
logger.exception("Pipeline run %s failed: %s", job_id, e)
from detect.graph import _node_states, NODES
from core.detect.graph import _node_states, NODES
if job_id in _node_states:
states = _node_states[job_id]
for node in reversed(NODES):
@@ -145,7 +145,7 @@ def run_pipeline(req: RunRequest):
@router.post("/stop/{job_id}")
def stop_pipeline(job_id: str):
"""Stop a running pipeline. Signals cancellation; the thread checks on next stage."""
from detect import emit
from core.detect import emit
if job_id not in _running_jobs:
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
@@ -158,7 +158,7 @@ def stop_pipeline(job_id: str):
@router.post("/pause/{job_id}")
def pause(job_id: str):
"""Pause a running pipeline after the current stage completes."""
from detect.graph import pause_pipeline
from core.detect.graph import pause_pipeline
if job_id not in _running_jobs:
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
@@ -170,7 +170,7 @@ def pause(job_id: str):
@router.post("/resume/{job_id}")
def resume(job_id: str):
"""Resume a paused pipeline."""
from detect.graph import resume_pipeline
from core.detect.graph import resume_pipeline
if job_id not in _running_jobs:
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
@@ -182,7 +182,7 @@ def resume(job_id: str):
@router.post("/step/{job_id}")
def step(job_id: str):
"""Run one stage then pause again."""
from detect.graph import step_pipeline
from core.detect.graph import step_pipeline
if job_id not in _running_jobs:
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
@@ -194,7 +194,7 @@ def step(job_id: str):
@router.post("/pause-after-stage/{job_id}")
def toggle_pause_after_stage(job_id: str, enabled: bool = True):
"""Toggle pause-after-each-stage mode."""
from detect.graph import set_pause_after_stage
from core.detect.graph import set_pause_after_stage
if job_id not in _running_jobs:
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
@@ -206,7 +206,7 @@ def toggle_pause_after_stage(job_id: str, enabled: bool = True):
@router.get("/status/{job_id}")
def pipeline_status(job_id: str):
"""Get pipeline run status."""
from detect.graph import is_paused
from core.detect.graph import is_paused
running = job_id in _running_jobs
paused = is_paused(job_id)
@@ -224,11 +224,23 @@ def pipeline_status(job_id: str):
return {"status": status, "job_id": job_id}
@router.get("/timeline/{job_id}")
def get_timeline_for_job(job_id: str):
    """Resolve the timeline_id for a running or completed job."""
    from core.detect.checkpoint.runner_bridge import get_timeline_id

    timeline_id = get_timeline_id(job_id)
    if timeline_id is not None:
        return {"timeline_id": timeline_id, "job_id": job_id}
    raise HTTPException(status_code=404, detail=f"No timeline for job: {job_id}")
@router.post("/clear/{job_id}")
def clear_pipeline(job_id: str):
    """Delete all buffered pipeline events for a job from Redis.

    Used to reset the event stream before re-running a job with the same id.
    """
    from core.events import _get_redis
    # The events module moved from `detect.events` to `core.detect.events`
    # in the package refactor; the stale duplicate import of
    # DETECT_EVENTS_PREFIX from the old path is removed here.
    from core.detect.events import DETECT_EVENTS_PREFIX

    r = _get_redis()
    r.delete(f"{DETECT_EVENTS_PREFIX}:{job_id}")

View File

@@ -17,7 +17,7 @@ from fastapi import APIRouter
from starlette.responses import StreamingResponse
from core.events import poll_events
from detect.events import DETECT_EVENTS_PREFIX, TERMINAL_EVENTS
from core.detect.events import DETECT_EVENTS_PREFIX, TERMINAL_EVENTS
logger = logging.getLogger(__name__)

View File

@@ -1,384 +0,0 @@
"""
GraphQL API using strawberry, served via FastAPI.
Primary API for MPR — all client interactions go through GraphQL.
Uses core.db for data access.
Types are generated from schema/ via modelgen — see api/schema/graphql.py.
"""
import os
from typing import List, Optional
from uuid import UUID
import strawberry
from strawberry.schema.config import StrawberryConfig
from strawberry.types import Info
from core.api.schema.graphql import (
CancelResultType,
ChunkJobType,
ChunkOutputFileType,
CreateChunkJobInput,
CreateJobInput,
DeleteResultType,
MediaAssetType,
ScanResultType,
SystemStatusType,
TranscodeJobType,
TranscodePresetType,
UpdateAssetInput,
)
from core.storage import BUCKET_IN, list_objects, upload_file
VIDEO_EXTS = {".mp4", ".mkv", ".avi", ".mov", ".webm", ".flv", ".wmv", ".m4v"}
AUDIO_EXTS = {".mp3", ".wav", ".flac", ".aac", ".ogg", ".m4a"}
MEDIA_EXTS = VIDEO_EXTS | AUDIO_EXTS
# ---------------------------------------------------------------------------
# Queries
# ---------------------------------------------------------------------------
@strawberry.type
class Query:
    # NOTE: resolvers import from core.db lazily (inside each function) so this
    # schema module can be imported before the DB layer is fully initialized.

    @strawberry.field
    def assets(
        self,
        info: Info,
        status: Optional[str] = None,
        search: Optional[str] = None,
    ) -> List[MediaAssetType]:
        """List media assets, optionally filtered by status and/or a search term."""
        from core.db import list_assets
        return list_assets(status=status, search=search)

    @strawberry.field
    def asset(self, info: Info, id: UUID) -> Optional[MediaAssetType]:
        """Fetch a single asset by id; resolves to null when the lookup fails."""
        from core.db import get_asset
        try:
            return get_asset(id)
        except Exception:
            # Missing/invalid id yields null rather than a GraphQL error.
            return None

    @strawberry.field
    def jobs(
        self,
        info: Info,
        status: Optional[str] = None,
        source_asset_id: Optional[UUID] = None,
    ) -> List[TranscodeJobType]:
        """List transcode jobs, optionally filtered by status and/or source asset."""
        from core.db import list_jobs
        return list_jobs(status=status, source_asset_id=source_asset_id)

    @strawberry.field
    def job(self, info: Info, id: UUID) -> Optional[TranscodeJobType]:
        """Fetch a single job by id; resolves to null when the lookup fails."""
        from core.db import get_job
        try:
            return get_job(id)
        except Exception:
            return None

    @strawberry.field
    def presets(self, info: Info) -> List[TranscodePresetType]:
        """List all transcode presets."""
        from core.db import list_presets
        return list_presets()

    @strawberry.field
    def system_status(self, info: Info) -> SystemStatusType:
        """Static health/version probe for the API."""
        return SystemStatusType(status="ok", version="0.1.0")

    @strawberry.field
    def chunk_output_files(self, info: Info, job_id: str) -> List[ChunkOutputFileType]:
        """List output chunk files for a completed job from media/out/."""
        from pathlib import Path
        # Output root is configurable; defaults to the in-container media path.
        media_out = os.environ.get("MEDIA_OUT_DIR", "/app/media/out")
        output_dir = Path(media_out) / "chunks" / job_id
        if not output_dir.is_dir():
            # Unknown job id (or job produced nothing) → empty list, not an error.
            return []
        return [
            ChunkOutputFileType(
                key=f.name,
                size=f.stat().st_size,
                # URL is a static-file path, not a presigned link.
                url=f"/media/out/chunks/{job_id}/{f.name}",
            )
            for f in sorted(output_dir.iterdir())
            if f.is_file()
        ]
# ---------------------------------------------------------------------------
# Mutations
# ---------------------------------------------------------------------------
@strawberry.type
class Mutation:
    @strawberry.mutation
    def scan_media_folder(self, info: Info) -> ScanResultType:
        """Sync local media/in files to MinIO, then register new objects as assets."""
        import logging
        from pathlib import Path
        from core.db import create_asset, get_asset_filenames
        logger = logging.getLogger(__name__)
        # Sync local media/in/ files to MinIO (handles fresh installs / pruned volumes)
        local_media = Path("/app/media/in")
        if local_media.is_dir():
            existing_keys = {o["key"] for o in list_objects(BUCKET_IN)}
            for f in local_media.iterdir():
                if f.is_file() and f.suffix.lower() in MEDIA_EXTS:
                    if f.name not in existing_keys:
                        try:
                            upload_file(str(f), BUCKET_IN, f.name)
                            logger.info("Uploaded %s to MinIO", f.name)
                        except Exception as e:
                            # Best-effort: a failed upload must not abort the scan.
                            logger.warning("Failed to upload %s: %s", f.name, e)
        objects = list_objects(BUCKET_IN, extensions=MEDIA_EXTS)
        existing = get_asset_filenames()
        registered = []
        skipped = []
        for obj in objects:
            if obj["filename"] in existing:
                # Already registered under this filename.
                skipped.append(obj["filename"])
                continue
            try:
                create_asset(
                    filename=obj["filename"],
                    file_path=obj["key"],
                    file_size=obj["size"],
                )
                registered.append(obj["filename"])
            except Exception:
                # NOTE(review): registration failures are silently dropped
                # (neither counted nor logged) — confirm this is intended.
                pass
        return ScanResultType(
            found=len(objects),
            registered=len(registered),
            skipped=len(skipped),
            files=registered,
        )

    @strawberry.mutation
    def create_job(self, info: Info, input: CreateJobInput) -> TranscodeJobType:
        """Create a transcode/trim job and dispatch it to the configured executor."""
        from pathlib import Path
        from core.db import create_job, get_asset, get_preset
        try:
            source = get_asset(input.source_asset_id)
        except Exception:
            raise Exception("Source asset not found")
        preset = None
        preset_snapshot = {}
        if input.preset_id:
            try:
                preset = get_preset(input.preset_id)
                # Snapshot the preset so later preset edits don't change queued jobs.
                preset_snapshot = {
                    "name": preset.name,
                    "container": preset.container,
                    "video_codec": preset.video_codec,
                    "audio_codec": preset.audio_codec,
                }
            except Exception:
                raise Exception("Preset not found")
        # A job must either transcode (preset) or trim (start/end) — or both.
        if not preset and not input.trim_start and not input.trim_end:
            raise Exception("Must specify preset_id or trim_start/trim_end")
        output_filename = input.output_filename
        if not output_filename:
            # Derive "<stem>_output.<container>" when the caller gave no name.
            stem = Path(source.filename).stem
            ext = preset_snapshot.get("container", "mp4") if preset else "mp4"
            output_filename = f"{stem}_output.{ext}"
        job = create_job(
            source_asset_id=source.id,
            preset_id=preset.id if preset else None,
            preset_snapshot=preset_snapshot,
            trim_start=input.trim_start,
            trim_end=input.trim_end,
            output_filename=output_filename,
            output_path=output_filename,
            priority=input.priority or 0,
        )
        payload = {
            "source_key": source.file_path,
            "output_key": output_filename,
            "preset": preset_snapshot or None,
            "trim_start": input.trim_start,
            "trim_end": input.trim_end,
            "duration": source.duration,
        }
        executor_mode = os.environ.get("MPR_EXECUTOR", "local")
        if executor_mode in ("lambda", "gcp"):
            # Cloud executor path (lambda/gcp): no Celery task id in this mode.
            from core.jobs.executor import get_executor
            get_executor().run(
                job_type="transcode",
                job_id=str(job.id),
                payload=payload,
            )
        else:
            # Local mode: enqueue on Celery and persist the task id for cancellation.
            from core.jobs.task import run_job
            result = run_job.delay(
                job_type="transcode",
                job_id=str(job.id),
                payload=payload,
            )
            job.celery_task_id = result.id
            job.save(update_fields=["celery_task_id"])
        return job

    @strawberry.mutation
    def cancel_job(self, info: Info, id: UUID) -> TranscodeJobType:
        """Mark a pending/processing job as cancelled."""
        from core.db import get_job, update_job
        try:
            job = get_job(id)
        except Exception:
            raise Exception("Job not found")
        if job.status not in ("pending", "processing"):
            raise Exception(f"Cannot cancel job with status: {job.status}")
        # NOTE(review): only flips DB status; no Celery revoke here — confirm workers
        # observe the status change.
        return update_job(job, status="cancelled")

    @strawberry.mutation
    def retry_job(self, info: Info, id: UUID) -> TranscodeJobType:
        """Requeue a failed job by resetting it to pending with cleared progress/error."""
        from core.db import get_job, update_job
        try:
            job = get_job(id)
        except Exception:
            raise Exception("Job not found")
        if job.status != "failed":
            raise Exception("Only failed jobs can be retried")
        return update_job(job, status="pending", progress=0, error_message=None)

    @strawberry.mutation
    def update_asset(self, info: Info, id: UUID, input: UpdateAssetInput) -> MediaAssetType:
        """Patch mutable asset metadata (comments, tags); omitted fields are untouched."""
        from core.db import get_asset, update_asset
        try:
            asset = get_asset(id)
        except Exception:
            raise Exception("Asset not found")
        fields = {}
        if input.comments is not None:
            fields["comments"] = input.comments
        if input.tags is not None:
            fields["tags"] = input.tags
        if fields:
            asset = update_asset(asset, **fields)
        return asset

    @strawberry.mutation
    def delete_asset(self, info: Info, id: UUID) -> DeleteResultType:
        """Delete an asset by id."""
        from core.db import delete_asset, get_asset
        try:
            asset = get_asset(id)
            delete_asset(asset)
            return DeleteResultType(ok=True)
        except Exception:
            # NOTE(review): a failure inside delete_asset is also reported as
            # "Asset not found" — confirm this conflation is acceptable.
            raise Exception("Asset not found")

    @strawberry.mutation
    def create_chunk_job(self, info: Info, input: CreateChunkJobInput) -> ChunkJobType:
        """Create and dispatch a chunk pipeline job."""
        import uuid
        from core.db import get_asset
        try:
            source = get_asset(input.source_asset_id)
        except Exception:
            raise Exception("Source asset not found")
        job_id = str(uuid.uuid4())
        payload = {
            "source_key": source.file_path,
            "chunk_duration": input.chunk_duration,
            "num_workers": input.num_workers,
            "max_retries": input.max_retries,
            "processor_type": input.processor_type,
            "start_time": input.start_time,
            "end_time": input.end_time,
        }
        executor_mode = os.environ.get("MPR_EXECUTOR", "local")
        celery_task_id = None
        if executor_mode in ("lambda", "gcp"):
            from core.jobs.executor import get_executor
            get_executor().run(
                job_type="chunk",
                job_id=job_id,
                payload=payload,
            )
        else:
            from core.jobs.task import run_job
            result = run_job.delay(
                job_type="chunk",
                job_id=job_id,
                payload=payload,
            )
            celery_task_id = result.id
        # NOTE(review): no DB row is created here — the returned ChunkJobType is an
        # in-memory snapshot of the dispatched job.
        return ChunkJobType(
            id=uuid.UUID(job_id),
            source_asset_id=input.source_asset_id,
            chunk_duration=input.chunk_duration,
            num_workers=input.num_workers,
            max_retries=input.max_retries,
            processor_type=input.processor_type,
            status="pending",
            progress=0.0,
            priority=input.priority,
            celery_task_id=celery_task_id,
        )

    @strawberry.mutation
    def cancel_chunk_job(self, info: Info, celery_task_id: str) -> CancelResultType:
        """Cancel a running chunk job by revoking its Celery task."""
        try:
            from admin.mpr.celery import app as celery_app
            # terminate=True sends SIGTERM to the worker process running the task.
            celery_app.control.revoke(celery_task_id, terminate=True, signal="SIGTERM")
            return CancelResultType(ok=True, message="Task revoked")
        except Exception as e:
            # Errors are reported in-band rather than raised.
            return CancelResultType(ok=False, message=str(e))
# ---------------------------------------------------------------------------
# Schema
# ---------------------------------------------------------------------------
schema = strawberry.Schema(
query=Query,
mutation=Mutation,
config=StrawberryConfig(auto_camel_case=False),
)

View File

@@ -1,48 +1,38 @@
"""
MPR FastAPI Application
Serves GraphQL API and Lambda callback endpoint.
"""
import os
import sys
from typing import Optional
from uuid import UUID
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from contextlib import asynccontextmanager
from fastapi import FastAPI, Header, HTTPException
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from strawberry.fastapi import GraphQLRouter
from core.api.chunker_sse import router as chunker_router
from core.api.detect import router as detect_router
from core.api.graphql import schema as graphql_schema
CALLBACK_API_KEY = os.environ.get("CALLBACK_API_KEY", "")
@asynccontextmanager
async def lifespan(app):
    """FastAPI lifespan hook: prepare the database before serving requests."""
    # Create/reset DB tables on startup
    from core.db.connection import create_tables
    from core.db.seed import seed_profiles
    create_tables()
    seed_profiles()
    # No shutdown work needed; yield hands control to the running app.
    yield
app = FastAPI(
title="MPR API",
description="Media Processor — GraphQL API",
version="0.1.0",
docs_url="/docs",
redoc_url="/redoc",
lifespan=lifespan,
)
# CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["http://mpr.local.ar", "http://k8s.mpr.local.ar", "http://localhost:5173"],
@@ -51,13 +41,6 @@ app.add_middleware(
allow_headers=["*"],
)
# GraphQL
graphql_router = GraphQLRouter(schema=graphql_schema, graphql_ide="graphiql")
app.include_router(graphql_router, prefix="/graphql")
# Chunker SSE
app.include_router(chunker_router)
# Detection API (sources, run, SSE, replay, config)
app.include_router(detect_router)
@@ -69,48 +52,7 @@ def health():
@app.get("/")
def root():
    """API root: service name, version, and the GraphQL endpoint path."""
    return {
        "name": "MPR API",
        "version": "0.1.0",
        "graphql": "/graphql",
    }
@app.post("/api/jobs/{job_id}/callback")
def job_callback(
    job_id: UUID,
    payload: dict,
    x_api_key: Optional[str] = Header(None),
):
    """
    Callback endpoint for Lambda to report job completion.
    Protected by API key.

    Payload keys read here: "status" (defaults to "failed") and optional "error".
    """
    # Auth is enforced only when CALLBACK_API_KEY is configured (non-empty).
    if CALLBACK_API_KEY and x_api_key != CALLBACK_API_KEY:
        raise HTTPException(status_code=403, detail="Invalid API key")
    from django.utils import timezone
    from core.db import get_job, update_job
    try:
        job = get_job(job_id)
    except Exception:
        raise HTTPException(status_code=404, detail="Job not found")
    status = payload.get("status", "failed")
    fields = {
        "status": status,
        # Completed jobs are pinned to 100%; otherwise keep the last reported progress.
        "progress": 100.0 if status == "completed" else job.progress,
    }
    if payload.get("error"):
        fields["error_message"] = payload["error"]
    if status in ("completed", "failed"):
        # Terminal states stamp the completion time.
        fields["completed_at"] = timezone.now()
    update_job(job, **fields)
    return {"ok": True}

View File

@@ -1,226 +0,0 @@
"""
Strawberry Types - GENERATED FILE
Do not edit directly. Regenerate using modelgen.
"""
import strawberry
from enum import Enum
from typing import List, Optional
from uuid import UUID
from datetime import datetime
from strawberry.scalars import JSON
@strawberry.enum
class AssetStatus(Enum):
PENDING = "pending"
READY = "ready"
ERROR = "error"
@strawberry.enum
class JobStatus(Enum):
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
@strawberry.type
class MediaAssetType:
"""A video/audio file registered in the system."""
id: Optional[UUID] = None
filename: Optional[str] = None
file_path: Optional[str] = None
status: Optional[str] = None
error_message: Optional[str] = None
file_size: Optional[float] = None
duration: Optional[float] = None
video_codec: Optional[str] = None
audio_codec: Optional[str] = None
width: Optional[int] = None
height: Optional[int] = None
framerate: Optional[float] = None
bitrate: Optional[int] = None
properties: Optional[JSON] = None
comments: Optional[str] = None
tags: Optional[List[str]] = None
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
@strawberry.type
class TranscodePresetType:
"""A reusable transcoding configuration (like Handbrake presets)."""
id: Optional[UUID] = None
name: Optional[str] = None
description: Optional[str] = None
is_builtin: Optional[bool] = None
container: Optional[str] = None
video_codec: Optional[str] = None
video_bitrate: Optional[str] = None
video_crf: Optional[int] = None
video_preset: Optional[str] = None
resolution: Optional[str] = None
framerate: Optional[float] = None
audio_codec: Optional[str] = None
audio_bitrate: Optional[str] = None
audio_channels: Optional[int] = None
audio_samplerate: Optional[int] = None
extra_args: Optional[List[str]] = None
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
@strawberry.type
class TranscodeJobType:
"""A transcoding or trimming job in the queue."""
id: Optional[UUID] = None
source_asset_id: Optional[UUID] = None
preset_id: Optional[UUID] = None
preset_snapshot: Optional[JSON] = None
trim_start: Optional[float] = None
trim_end: Optional[float] = None
output_filename: Optional[str] = None
output_path: Optional[str] = None
output_asset_id: Optional[UUID] = None
status: Optional[str] = None
progress: Optional[float] = None
current_frame: Optional[int] = None
current_time: Optional[float] = None
speed: Optional[str] = None
error_message: Optional[str] = None
celery_task_id: Optional[str] = None
execution_arn: Optional[str] = None
priority: Optional[int] = None
created_at: Optional[datetime] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
@strawberry.input
class CreateJobInput:
"""Request body for creating a transcode/trim job."""
source_asset_id: UUID
preset_id: Optional[UUID] = None
trim_start: Optional[float] = None
trim_end: Optional[float] = None
output_filename: Optional[str] = None
priority: int = 0
@strawberry.input
class UpdateAssetInput:
"""Request body for updating asset metadata."""
comments: Optional[str] = None
tags: Optional[List[str]] = None
@strawberry.type
class SystemStatusType:
"""System status response."""
status: Optional[str] = None
version: Optional[str] = None
@strawberry.type
class ScanResultType:
"""Result of scanning the media input bucket."""
found: Optional[int] = None
registered: Optional[int] = None
skipped: Optional[int] = None
files: Optional[List[str]] = None
@strawberry.type
class DeleteResultType:
"""Result of a delete operation."""
ok: Optional[bool] = None
@strawberry.type
class WorkerStatusType:
"""Worker health and capabilities."""
available: Optional[bool] = None
active_jobs: Optional[int] = None
supported_codecs: Optional[List[str]] = None
gpu_available: Optional[bool] = None
@strawberry.enum
class ChunkJobStatus(Enum):
PENDING = "pending"
CHUNKING = "chunking"
PROCESSING = "processing"
COLLECTING = "collecting"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
@strawberry.type
class ChunkJobType:
"""A chunk pipeline job."""
id: Optional[UUID] = None
source_asset_id: Optional[UUID] = None
chunk_duration: Optional[float] = None
num_workers: Optional[int] = None
max_retries: Optional[int] = None
processor_type: Optional[str] = None
status: Optional[str] = None
progress: Optional[float] = None
total_chunks: Optional[int] = None
processed_chunks: Optional[int] = None
failed_chunks: Optional[int] = None
retry_count: Optional[int] = None
error_message: Optional[str] = None
throughput_mbps: Optional[float] = None
elapsed_seconds: Optional[float] = None
celery_task_id: Optional[str] = None
priority: Optional[int] = None
created_at: Optional[datetime] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
@strawberry.input
class CreateChunkJobInput:
"""Request body for creating a chunk pipeline job."""
source_asset_id: UUID
chunk_duration: float = 10.0
num_workers: int = 4
max_retries: int = 3
processor_type: str = "ffmpeg"
priority: int = 0
start_time: Optional[float] = None
end_time: Optional[float] = None
@strawberry.type
class CancelResultType:
"""Result of cancelling a chunk job."""
ok: bool = False
message: Optional[str] = None
@strawberry.type
class ChunkOutputFileType:
"""A chunk output file in S3/MinIO with presigned download URL."""
key: str
size: int = 0
url: str = ""

View File

@@ -1,64 +0,0 @@
"""
Chunker pipeline — splits files into chunks, processes concurrently, reassembles in order.
Public API:
Pipeline — orchestrates the full pipeline
PipelineResult — aggregate result dataclass
Chunker — file → Chunk generator
ChunkQueue — bounded thread-safe queue
WorkerPool — manages N worker threads
ResultCollector — heapq-based ordered reassembly
"""
from .chunker import Chunker
from .collector import ResultCollector
from .exceptions import (
ChunkChecksumError,
ChunkError,
ChunkReadError,
PipelineError,
ProcessingError,
ProcessorFailureError,
ProcessorTimeoutError,
ReassemblyError,
)
from .models import Chunk, ChunkResult, PipelineResult
from .pipeline import Pipeline
from .pool import WorkerPool
from .processor import (
ChecksumProcessor,
CompositeProcessor,
FFmpegExtractProcessor,
Processor,
SimulatedDecodeProcessor,
)
from .queue import ChunkQueue
__all__ = [
# Core
"Pipeline",
"PipelineResult",
# Components
"Chunker",
"ChunkQueue",
"WorkerPool",
"ResultCollector",
# Models
"Chunk",
"ChunkResult",
# Processors
"Processor",
"ChecksumProcessor",
"SimulatedDecodeProcessor",
"CompositeProcessor",
"FFmpegExtractProcessor",
# Exceptions
"PipelineError",
"ChunkError",
"ChunkReadError",
"ChunkChecksumError",
"ProcessingError",
"ProcessorFailureError",
"ProcessorTimeoutError",
"ReassemblyError",
]

View File

@@ -1,101 +0,0 @@
"""
Chunker — probes a media file and yields time-based Chunk objects.
Demonstrates:
- Function parameters and defaults (Interview Topic 1)
- List comprehensions and efficient iteration / generators (Interview Topic 3)
"""
import math
import os
from typing import Generator
from core.ffmpeg.probe import probe_file
from .exceptions import ChunkReadError
from .models import Chunk
class Chunker:
"""
Splits a media file into time-based chunks via a generator.
Uses FFmpeg probe to get duration, then yields Chunk objects
representing time segments (no data read — extraction happens in the processor).
Args:
file_path: Path to the source media file
chunk_duration: Duration of each chunk in seconds (default: 10.0)
"""
def __init__(
self,
file_path: str,
chunk_duration: float = 10.0,
start_time: float | None = None,
end_time: float | None = None,
):
if not os.path.isfile(file_path):
raise ChunkReadError(f"File not found: {file_path}")
if chunk_duration <= 0:
raise ValueError("chunk_duration must be positive")
self.file_path = file_path
self.chunk_duration = chunk_duration
self.file_size = os.path.getsize(file_path)
full_duration = self._probe_duration()
# Apply time range
self.range_start = max(start_time or 0.0, 0.0)
self.range_end = min(end_time or full_duration, full_duration)
if self.range_start >= self.range_end:
raise ValueError(
f"Invalid range: start={self.range_start} >= end={self.range_end}"
)
self.source_duration = self.range_end - self.range_start
def _probe_duration(self) -> float:
"""Get source file duration via FFmpeg probe."""
try:
result = probe_file(self.file_path)
if result.duration is None or result.duration <= 0:
raise ChunkReadError(
f"Cannot determine duration for {self.file_path}"
)
return result.duration
except ChunkReadError:
raise
except Exception as e:
raise ChunkReadError(
f"Failed to probe {self.file_path}: {e}"
) from e
@property
def expected_chunks(self) -> int:
"""Calculate expected number of chunks (last chunk may be shorter)."""
if self.source_duration <= 0:
return 0
return math.ceil(self.source_duration / self.chunk_duration)
def chunks(self) -> Generator[Chunk, None, None]:
"""
Yield Chunk objects representing time segments of the source file.
Generator-based: chunks are yielded on demand.
Each chunk defines a time range — actual extraction is done by the processor.
"""
total = self.expected_chunks
for sequence in range(total):
start_time = self.range_start + sequence * self.chunk_duration
end_time = min(
start_time + self.chunk_duration, self.range_end
)
duration = end_time - start_time
yield Chunk(
sequence=sequence,
start_time=start_time,
end_time=end_time,
source_path=self.file_path,
duration=duration,
)

View File

@@ -1,98 +0,0 @@
"""
ResultCollector — reassembles chunk results in sequence order using a min-heap.
Demonstrates:
- Algorithms and sorting (Interview Topic 6) — heapq for ordered reassembly
- Core data structures (Interview Topic 5) — heap, deque
"""
import heapq
from collections import deque
from typing import List
from .exceptions import ReassemblyError
from .models import ChunkResult
class ResultCollector:
"""
Receives ChunkResults out of order, emits them in sequence order.
Uses a min-heap keyed on sequence number. Only emits a chunk when
all prior sequences have been accounted for.
Args:
total_chunks: Expected total number of chunks
"""
def __init__(self, total_chunks: int):
self.total_chunks = total_chunks
self._heap: List[tuple[int, ChunkResult]] = []
self._next_sequence = 0
self._emitted: List[ChunkResult] = []
self._seen_sequences: set[int] = set()
# Sliding window for throughput calculation
self._recent_times: deque[float] = deque(maxlen=50)
def add(self, result: ChunkResult) -> List[ChunkResult]:
"""
Add a result and return any newly emittable results in order.
Args:
result: A ChunkResult (may arrive out of order)
Returns:
List of results that can now be emitted in sequence order
(may be empty if we're still waiting for earlier sequences)
Raises:
ReassemblyError: If a duplicate sequence is received
"""
if result.sequence in self._seen_sequences:
raise ReassemblyError(
f"Duplicate sequence number: {result.sequence}"
)
self._seen_sequences.add(result.sequence)
# Track processing time for throughput
if result.processing_time > 0:
self._recent_times.append(result.processing_time)
# Push to min-heap
heapq.heappush(self._heap, (result.sequence, result))
# Emit all consecutive results starting from _next_sequence
newly_emitted = []
while self._heap and self._heap[0][0] == self._next_sequence:
_, emitted_result = heapq.heappop(self._heap)
self._emitted.append(emitted_result)
newly_emitted.append(emitted_result)
self._next_sequence += 1
return newly_emitted
@property
def is_complete(self) -> bool:
"""True if all expected chunks have been emitted in order."""
return self._next_sequence == self.total_chunks
@property
def buffered_count(self) -> int:
"""Number of results waiting in the heap (arrived out of order)."""
return len(self._heap)
@property
def emitted_count(self) -> int:
"""Number of results emitted in sequence order."""
return len(self._emitted)
@property
def avg_processing_time(self) -> float:
"""Average processing time from recent results (sliding window)."""
if not self._recent_times:
return 0.0
return sum(self._recent_times) / len(self._recent_times)
def get_ordered_results(self) -> List[ChunkResult]:
"""Get all emitted results in sequence order."""
return list(self._emitted)

View File

@@ -1,64 +0,0 @@
"""
Chunker exception hierarchy.
Demonstrates: Managing exceptions and writing resilient code (Interview Topic 7).
"""
class PipelineError(Exception):
"""Base exception for all chunker pipeline errors."""
pass
class ChunkError(PipelineError):
"""Errors related to chunk creation or validation."""
pass
class ChunkReadError(ChunkError):
"""Failed to read chunk data from source file."""
pass
class ChunkChecksumError(ChunkError):
"""Chunk data integrity validation failed."""
def __init__(self, sequence: int, expected: str, actual: str):
self.sequence = sequence
self.expected = expected
self.actual = actual
super().__init__(
f"Chunk {sequence}: checksum mismatch "
f"(expected={expected}, actual={actual})"
)
class ProcessingError(PipelineError):
"""Errors during chunk processing by workers."""
pass
class ProcessorTimeoutError(ProcessingError):
"""Processor exceeded allowed time for a chunk."""
def __init__(self, sequence: int, timeout: float):
self.sequence = sequence
self.timeout = timeout
super().__init__(f"Chunk {sequence}: processor timed out after {timeout}s")
class ProcessorFailureError(ProcessingError):
"""Processor failed to process a chunk after all retries."""
def __init__(self, sequence: int, retries: int, original_error: Exception):
self.sequence = sequence
self.retries = retries
self.original_error = original_error
super().__init__(
f"Chunk {sequence}: failed after {retries} retries — {original_error}"
)
class ReassemblyError(PipelineError):
"""Errors during result collection and ordering."""
pass

View File

@@ -1,54 +0,0 @@
"""
Internal data models for the chunker pipeline.
These are pipeline-internal dataclasses, not schema models.
Schema-level ChunkJob is in core/schema/models/jobs.py.
Demonstrates: Core data structures (Interview Topic 5).
"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
@dataclass
class Chunk:
"""A time-based segment of the source media file."""
sequence: int
start_time: float # seconds
end_time: float # seconds
source_path: str # path to source file
duration: float # end_time - start_time
checksum: str = "" # computed after extraction
@dataclass
class ChunkResult:
"""Result of processing a single chunk."""
sequence: int
success: bool
checksum_valid: bool = True
processing_time: float = 0.0
error: Optional[str] = None
retries: int = 0
worker_id: Optional[str] = None
output_file: Optional[str] = None
@dataclass
class PipelineResult:
"""Aggregate result of the entire pipeline run."""
total_chunks: int = 0
processed: int = 0
failed: int = 0
retries: int = 0
elapsed_time: float = 0.0
throughput_mbps: float = 0.0
worker_stats: Dict[str, Any] = field(default_factory=dict)
errors: List[str] = field(default_factory=list)
chunks_in_order: bool = True
output_dir: Optional[str] = None
chunk_files: List[str] = field(default_factory=list)

View File

@@ -1,279 +0,0 @@
"""
Pipeline — orchestrates the entire chunker pipeline.
Wires: Chunker → ChunkQueue → WorkerPool → ResultCollector → PipelineResult
Demonstrates:
- Function parameters and defaults (Interview Topic 1) — configurable pipeline
- Concurrency (Interview Topic 2) — producer thread + worker pool
- OOP design (Interview Topic 4) — composition of pipeline components
- Exception handling (Interview Topic 7) — graceful error propagation
"""
import json
import logging
import threading
import time
from pathlib import Path
from typing import Any, Callable, Dict, Optional
from .chunker import Chunker
from .collector import ResultCollector
from .exceptions import PipelineError
from .models import PipelineResult
from .pool import WorkerPool
from .queue import ChunkQueue
logger = logging.getLogger(__name__)
class Pipeline:
"""
Orchestrates the chunk processing pipeline.
The pipeline runs in three stages:
1. Producer thread: Chunker probes file → pushes time-based chunks to ChunkQueue
2. Worker pool: N workers pull from queue → extract mp4 segments → emit results
3. Collector: ResultCollector reassembles results in sequence order
Args:
source: Path to the source media file
chunk_duration: Duration of each chunk in seconds (default: 10.0)
num_workers: Number of concurrent worker threads (default: 4)
max_retries: Max retry attempts per chunk (default: 3)
processor_type: Processor to use — "ffmpeg", "checksum", "simulated_decode", "composite"
queue_size: Max chunks buffered in queue (default: 10)
event_callback: Optional callback for real-time events
output_dir: Directory for output chunk files (required for "ffmpeg" processor)
"""
def __init__(
self,
source: str,
chunk_duration: float = 10.0,
num_workers: int = 4,
max_retries: int = 3,
processor_type: str = "checksum",
queue_size: int = 10,
event_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None,
output_dir: Optional[str] = None,
start_time: Optional[float] = None,
end_time: Optional[float] = None,
):
self.source = source
self.chunk_duration = chunk_duration
self.num_workers = num_workers
self.max_retries = max_retries
self.processor_type = processor_type
self.queue_size = queue_size
self.event_callback = event_callback
self.output_dir = output_dir
self.start_time = start_time
self.end_time = end_time
def _emit(self, event_type: str, data: Dict[str, Any]) -> None:
"""Emit an event if callback is registered."""
if self.event_callback:
self.event_callback(event_type, data)
def _produce_chunks(
self, chunker: Chunker, chunk_queue: ChunkQueue
) -> None:
"""Producer thread: probe file and enqueue time-based chunks."""
try:
for chunk in chunker.chunks():
chunk_queue.put(chunk, timeout=30.0)
self._emit("chunk_queued", {
"sequence": chunk.sequence,
"start_time": chunk.start_time,
"end_time": chunk.end_time,
"duration": chunk.duration,
"queue_size": chunk_queue.qsize(),
})
except Exception as e:
logger.error(f"Producer error: {e}")
self._emit("producer_error", {"error": str(e)})
finally:
chunk_queue.close()
def _monitor_progress(
self, start_time: float, file_size: int, stop_event: threading.Event
) -> None:
"""Monitor thread: emit pipeline_progress every 500ms."""
while not stop_event.is_set():
elapsed = time.monotonic() - start_time
mb = file_size / (1024 * 1024)
self._emit("pipeline_progress", {
"elapsed": round(elapsed, 2),
"throughput_mbps": round(mb / elapsed, 2) if elapsed > 0 else 0,
})
stop_event.wait(0.5)
def _write_manifest(
self, result: PipelineResult, source_duration: float
) -> None:
"""Write manifest.json to output_dir with segment metadata."""
if not self.output_dir:
return
manifest = {
"source": self.source,
"source_duration": source_duration,
"chunk_duration": self.chunk_duration,
"total_chunks": result.total_chunks,
"processed": result.processed,
"failed": result.failed,
"elapsed_time": result.elapsed_time,
"throughput_mbps": result.throughput_mbps,
"segments": [
{
"sequence": i,
"file": f"chunk_{i:04d}.mp4",
"start": i * self.chunk_duration,
"end": min(
(i + 1) * self.chunk_duration, source_duration
),
}
for i in range(result.total_chunks)
if i < result.total_chunks
],
}
manifest_path = Path(self.output_dir) / "manifest.json"
manifest_path.write_text(json.dumps(manifest, indent=2))
logger.info(f"Manifest written to {manifest_path}")
def run(self) -> PipelineResult:
"""
Execute the full pipeline.
Returns:
PipelineResult with aggregate stats
Raises:
PipelineError: If the pipeline fails catastrophically
"""
start_time = time.monotonic()
self._emit("pipeline_start", {
"source": self.source,
"chunk_duration": self.chunk_duration,
"num_workers": self.num_workers,
"processor_type": self.processor_type,
})
try:
# Stage 1: Set up chunker (probes file for duration)
chunker = Chunker(
self.source,
self.chunk_duration,
start_time=self.start_time,
end_time=self.end_time,
)
total_chunks = chunker.expected_chunks
if total_chunks == 0:
self._emit("pipeline_complete", {"total_chunks": 0})
return PipelineResult(chunks_in_order=True)
self._emit("pipeline_info", {
"file_size": chunker.file_size,
"source_duration": chunker.source_duration,
"total_chunks": total_chunks,
})
# Stage 2: Set up queue and worker pool
chunk_queue = ChunkQueue(maxsize=self.queue_size)
pool = WorkerPool(
num_workers=self.num_workers,
chunk_queue=chunk_queue,
processor_type=self.processor_type,
max_retries=self.max_retries,
event_callback=self.event_callback,
output_dir=self.output_dir,
)
# Stage 3: Start workers, monitor, then produce chunks
pool.start()
monitor_stop = threading.Event()
monitor = threading.Thread(
target=self._monitor_progress,
args=(start_time, chunker.file_size, monitor_stop),
name="progress-monitor",
daemon=True,
)
monitor.start()
producer = threading.Thread(
target=self._produce_chunks,
args=(chunker, chunk_queue),
name="chunk-producer",
daemon=True,
)
producer.start()
# Stage 4: Wait for all workers to finish
all_results = pool.wait()
producer.join(timeout=5.0)
# Stop monitor
monitor_stop.set()
monitor.join(timeout=2.0)
# Stage 5: Collect results in order
collector = ResultCollector(total_chunks)
for r in all_results:
collector.add(r)
self._emit("chunk_collected", {
"sequence": r.sequence,
"success": r.success,
"buffered": collector.buffered_count,
"emitted": collector.emitted_count,
})
# Build result
elapsed = time.monotonic() - start_time
file_size_mb = chunker.file_size / (1024 * 1024)
throughput = file_size_mb / elapsed if elapsed > 0 else 0.0
failed_results = [r for r in all_results if not r.success]
total_retries = sum(r.retries for r in all_results)
chunk_files = [
r.output_file for r in all_results
if r.success and r.output_file
]
result = PipelineResult(
total_chunks=total_chunks,
processed=len(all_results),
failed=len(failed_results),
retries=total_retries,
elapsed_time=elapsed,
throughput_mbps=throughput,
worker_stats=pool.get_worker_stats(),
errors=[r.error for r in failed_results if r.error],
chunks_in_order=collector.is_complete,
output_dir=self.output_dir,
chunk_files=chunk_files,
)
# Write manifest if output_dir is set
self._write_manifest(result, chunker.source_duration)
pool.shutdown()
self._emit("pipeline_complete", {
"total_chunks": result.total_chunks,
"processed": result.processed,
"failed": result.failed,
"elapsed": result.elapsed_time,
"throughput_mbps": result.throughput_mbps,
})
return result
except PipelineError:
raise
except Exception as e:
self._emit("pipeline_error", {"error": str(e)})
raise PipelineError(f"Pipeline failed: {e}") from e

View File

@@ -1,125 +0,0 @@
"""
WorkerPool — manages N worker threads via ThreadPoolExecutor.
Demonstrates: Python concurrency — threading (Interview Topic 2).
"""
import logging
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional
from .models import ChunkResult
from .processor import (
ChecksumProcessor,
CompositeProcessor,
FFmpegExtractProcessor,
Processor,
SimulatedDecodeProcessor,
)
from .queue import ChunkQueue
from .worker import Worker
logger = logging.getLogger(__name__)
def create_processor(
processor_type: str = "checksum",
output_dir: Optional[str] = None,
) -> Processor:
"""Factory for processor instances."""
if processor_type == "ffmpeg":
if not output_dir:
raise ValueError("output_dir required for ffmpeg processor")
return FFmpegExtractProcessor(output_dir=output_dir)
elif processor_type == "checksum":
return ChecksumProcessor()
elif processor_type == "simulated_decode":
return SimulatedDecodeProcessor()
elif processor_type == "composite":
return CompositeProcessor([
ChecksumProcessor(),
SimulatedDecodeProcessor(ms_per_second=50.0),
])
else:
raise ValueError(f"Unknown processor type: {processor_type}")
class WorkerPool:
"""
Manages N worker threads that process chunks concurrently.
Args:
num_workers: Number of concurrent worker threads (default: 4)
chunk_queue: Shared queue to pull chunks from
processor_type: Type of processor for each worker (default: "checksum")
max_retries: Max retry attempts per chunk (default: 3)
event_callback: Optional callback for real-time events
"""
def __init__(
self,
num_workers: int = 4,
chunk_queue: Optional[ChunkQueue] = None,
processor_type: str = "checksum",
max_retries: int = 3,
event_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None,
output_dir: Optional[str] = None,
):
self.num_workers = num_workers
self.chunk_queue = chunk_queue or ChunkQueue()
self.processor_type = processor_type
self.max_retries = max_retries
self.event_callback = event_callback
self.output_dir = output_dir
self.shutdown_event = threading.Event()
self._executor: Optional[ThreadPoolExecutor] = None
self._futures: List[Future] = []
self._workers: List[Worker] = []
def start(self) -> None:
"""Start all worker threads."""
self._executor = ThreadPoolExecutor(
max_workers=self.num_workers,
thread_name_prefix="chunk-worker",
)
for i in range(self.num_workers):
worker = Worker(
worker_id=f"worker-{i}",
chunk_queue=self.chunk_queue,
processor=create_processor(self.processor_type, output_dir=self.output_dir),
max_retries=self.max_retries,
event_callback=self.event_callback,
)
self._workers.append(worker)
future = self._executor.submit(worker.run)
self._futures.append(future)
logger.info(f"WorkerPool started with {self.num_workers} workers")
def wait(self) -> List[ChunkResult]:
"""Wait for all workers to finish and collect results."""
all_results = []
for future in self._futures:
results = future.result()
all_results.extend(results)
return all_results
def shutdown(self) -> None:
"""Signal shutdown and cleanup."""
self.shutdown_event.set()
self.chunk_queue.close()
if self._executor:
self._executor.shutdown(wait=True)
def get_worker_stats(self) -> Dict[str, Any]:
"""Get per-worker statistics."""
return {
w.worker_id: {
"processed": w.processed_count,
"errors": w.error_count,
"retries": w.retry_count,
}
for w in self._workers
}

View File

@@ -1,173 +0,0 @@
"""
Processor ABC and concrete implementations.
Demonstrates: OOP design principles — ABC, inheritance, composition (Interview Topic 4).
"""
import hashlib
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List
from .exceptions import ChunkChecksumError
from .models import Chunk, ChunkResult
class Processor(ABC):
"""
Abstract base class for chunk processors.
Each processor defines how a single chunk is processed.
The Worker calls processor.process(chunk) and handles retries.
"""
@abstractmethod
def process(self, chunk: Chunk) -> ChunkResult:
"""Process a single chunk and return the result."""
pass
class FFmpegExtractProcessor(Processor):
"""
Extracts a time segment from the source file using FFmpeg stream copy.
Produces a playable mp4 file per chunk — no re-encoding.
Args:
output_dir: Directory to write chunk mp4 files
"""
def __init__(self, output_dir: str):
self.output_dir = output_dir
Path(output_dir).mkdir(parents=True, exist_ok=True)
def process(self, chunk: Chunk) -> ChunkResult:
from core.ffmpeg.transcode import TranscodeConfig, transcode
start = time.monotonic()
output_file = str(
Path(self.output_dir) / f"chunk_{chunk.sequence:04d}.mp4"
)
config = TranscodeConfig(
input_path=chunk.source_path,
output_path=output_file,
video_codec="copy",
audio_codec="copy",
trim_start=chunk.start_time,
trim_end=chunk.end_time,
)
transcode(config)
# Compute checksum of output file
md5 = hashlib.md5()
with open(output_file, "rb") as f:
for block in iter(lambda: f.read(8192), b""):
md5.update(block)
checksum = md5.hexdigest()
elapsed = time.monotonic() - start
return ChunkResult(
sequence=chunk.sequence,
success=True,
checksum_valid=True,
processing_time=elapsed,
output_file=output_file,
)
class ChecksumProcessor(Processor):
"""
Validates chunk metadata consistency.
For time-based chunks, verifies the time range is valid.
Raises ChunkChecksumError on invalid ranges.
"""
def process(self, chunk: Chunk) -> ChunkResult:
start = time.monotonic()
valid = chunk.duration > 0 and chunk.end_time > chunk.start_time
if not valid:
raise ChunkChecksumError(
sequence=chunk.sequence,
expected="valid time range",
actual=f"{chunk.start_time}-{chunk.end_time}",
)
elapsed = time.monotonic() - start
return ChunkResult(
sequence=chunk.sequence,
success=True,
checksum_valid=True,
processing_time=elapsed,
)
class SimulatedDecodeProcessor(Processor):
"""
Simulates decode work by sleeping proportional to chunk duration.
Useful for demonstrating concurrency behavior without real FFmpeg.
Args:
ms_per_second: Milliseconds of simulated work per second of chunk duration (default: 100)
"""
def __init__(self, ms_per_second: float = 100.0):
self.ms_per_second = ms_per_second
def process(self, chunk: Chunk) -> ChunkResult:
start = time.monotonic()
sleep_time = (self.ms_per_second * chunk.duration) / 1000.0
time.sleep(sleep_time)
elapsed = time.monotonic() - start
return ChunkResult(
sequence=chunk.sequence,
success=True,
checksum_valid=True,
processing_time=elapsed,
)
class CompositeProcessor(Processor):
"""
Chains multiple processors — runs each in sequence on the same chunk.
Demonstrates OOP composition pattern.
Args:
processors: List of processors to chain
"""
def __init__(self, processors: List[Processor]):
if not processors:
raise ValueError("CompositeProcessor requires at least one processor")
self.processors = processors
def process(self, chunk: Chunk) -> ChunkResult:
start = time.monotonic()
last_result = None
for proc in self.processors:
last_result = proc.process(chunk)
if not last_result.success:
return last_result
elapsed = time.monotonic() - start
return ChunkResult(
sequence=chunk.sequence,
success=True,
checksum_valid=last_result.checksum_valid if last_result else True,
processing_time=elapsed,
)

View File

@@ -1,76 +0,0 @@
"""
ChunkQueue — bounded, thread-safe queue with sentinel-based shutdown.
Demonstrates: Core data structures — queue.Queue (Interview Topic 5).
"""
import queue
from typing import Optional
from .models import Chunk
# Sentinel value to signal workers to stop
_SENTINEL = object()
class ChunkQueue:
"""
Thread-safe bounded queue for chunks.
Provides backpressure: producers block when the queue is full,
preventing unbounded memory usage.
Args:
maxsize: Maximum number of chunks in the queue (default: 10)
"""
def __init__(self, maxsize: int = 10):
self._queue: queue.Queue = queue.Queue(maxsize=maxsize)
self._closed = False
self.maxsize = maxsize
def put(self, chunk: Chunk, timeout: Optional[float] = None) -> None:
"""
Add a chunk to the queue. Blocks if full (backpressure).
Args:
chunk: The chunk to enqueue
timeout: Max seconds to wait (None = block forever)
Raises:
queue.Full: If timeout expires while queue is full
"""
self._queue.put(chunk, timeout=timeout)
def get(self, timeout: Optional[float] = None) -> Optional[Chunk]:
"""
Get next chunk from queue. Returns None if queue is closed.
Args:
timeout: Max seconds to wait (None = block forever)
Returns:
Chunk or None (if sentinel received, meaning queue is closed)
Raises:
queue.Empty: If timeout expires while queue is empty
"""
item = self._queue.get(timeout=timeout)
if item is _SENTINEL:
# Re-put sentinel so other workers also see it
self._queue.put(_SENTINEL)
return None
return item
def close(self) -> None:
"""Signal all consumers to stop by inserting a sentinel."""
self._closed = True
self._queue.put(_SENTINEL)
@property
def is_closed(self) -> bool:
return self._closed
def qsize(self) -> int:
"""Current number of items in the queue (approximate)."""
return self._queue.qsize()

View File

@@ -1,143 +0,0 @@
"""
Worker — pulls chunks from queue, processes with retry logic.
Demonstrates:
- Exception handling and resilient code (Interview Topic 7)
- Concurrency (Interview Topic 2) — workers run in thread pool
"""
import logging
import queue
import time
from typing import Any, Callable, Dict, Optional
from .exceptions import ProcessorFailureError
from .models import Chunk, ChunkResult
from .processor import Processor
from .queue import ChunkQueue
logger = logging.getLogger(__name__)
class Worker:
"""
Processes chunks from a queue with retry and exponential backoff.
Args:
worker_id: Identifier for this worker (e.g. "worker-0")
chunk_queue: Source queue to pull chunks from
processor: Processor instance to use
max_retries: Maximum retry attempts per chunk (default: 3)
event_callback: Optional callback for real-time status updates
"""
def __init__(
self,
worker_id: str,
chunk_queue: ChunkQueue,
processor: Processor,
max_retries: int = 3,
event_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None,
):
self.worker_id = worker_id
self.chunk_queue = chunk_queue
self.processor = processor
self.max_retries = max_retries
self.event_callback = event_callback
self.processed_count = 0
self.error_count = 0
self.retry_count = 0
def _emit(self, event_type: str, data: Dict[str, Any]) -> None:
"""Emit an event if callback is registered."""
if self.event_callback:
self.event_callback(event_type, {"worker_id": self.worker_id, **data})
def _process_with_retry(self, chunk: Chunk) -> ChunkResult:
"""
Process a chunk with exponential backoff retry.
Retry delays: 0.1s, 0.2s, 0.4s, ... (doubles each attempt)
"""
last_error = None
for attempt in range(self.max_retries + 1):
try:
if attempt > 0:
backoff = 0.1 * (2 ** (attempt - 1))
self._emit("chunk_retry", {
"sequence": chunk.sequence,
"attempt": attempt,
"backoff": backoff,
})
time.sleep(backoff)
self.retry_count += 1
result = self.processor.process(chunk)
result.retries = attempt
result.worker_id = self.worker_id
return result
except Exception as e:
last_error = e
logger.warning(
f"{self.worker_id}: chunk {chunk.sequence} "
f"attempt {attempt + 1}/{self.max_retries + 1} failed: {e}"
)
# All retries exhausted
self.error_count += 1
self._emit("chunk_error", {
"sequence": chunk.sequence,
"error": str(last_error),
"retries": self.max_retries,
})
return ChunkResult(
sequence=chunk.sequence,
success=False,
processing_time=0.0,
error=str(last_error),
retries=self.max_retries,
worker_id=self.worker_id,
)
def run(self) -> list[ChunkResult]:
"""
Main worker loop — pull chunks and process until queue is closed.
Returns:
List of ChunkResults processed by this worker
"""
results = []
self._emit("worker_status", {"state": "idle"})
while True:
try:
chunk = self.chunk_queue.get(timeout=1.0)
except queue.Empty:
continue
if chunk is None: # Sentinel received
break
self._emit("chunk_processing", {
"sequence": chunk.sequence,
"state": "processing",
"queue_size": self.chunk_queue.qsize(),
})
result = self._process_with_retry(chunk)
results.append(result)
self.processed_count += 1
self._emit("chunk_done", {
"sequence": chunk.sequence,
"success": result.success,
"processing_time": result.processing_time,
"retries": result.retries,
"queue_size": self.chunk_queue.qsize(),
})
self._emit("worker_status", {"state": "stopped"})
return results

View File

@@ -13,7 +13,7 @@ Basic CRUD (create, get, update, delete) goes directly through the session:
from .connection import get_session, create_tables
from .tables import MediaAsset, Job, Timeline, Checkpoint, Brand
from .models import MediaAsset, Job, Timeline, Checkpoint, Brand
from .assets import list_assets, get_asset_filenames
from .job import list_jobs

View File

@@ -7,7 +7,7 @@ from uuid import UUID
from sqlmodel import Session, select
from .tables import MediaAsset
from .models import MediaAsset
def list_assets(session: Session, status: Optional[str] = None, search: Optional[str] = None) -> list[MediaAsset]:

View File

@@ -7,7 +7,7 @@ from uuid import UUID
from sqlmodel import Session, select
from .tables import Brand
from .models import Brand
def get_or_create_brand(session: Session, canonical_name: str,

View File

@@ -6,7 +6,7 @@ from uuid import UUID
from sqlmodel import Session, select
from .tables import Checkpoint
from .models import Checkpoint
def get_latest_checkpoint(session: Session, timeline_id: UUID, parent_id: UUID | None = None) -> Checkpoint | None:

View File

@@ -30,5 +30,5 @@ def get_session() -> Session:
def create_tables():
"""Create all SQLModel tables."""
from sqlmodel import SQLModel
from . import tables # noqa — registers all table classes
from . import models # noqa — registers all table classes
SQLModel.metadata.create_all(get_engine())

View File

@@ -0,0 +1,142 @@
{
"name": "soccer_broadcast",
"pipeline": {
"name": "soccer_broadcast",
"profile_name": "soccer_broadcast",
"stages": [
{
"name": "extract_frames",
"branch": "trunk"
},
{
"name": "filter_scenes",
"branch": "trunk"
},
{
"name": "field_segmentation",
"branch": "trunk"
},
{
"name": "detect_edges",
"branch": "hoarding"
},
{
"name": "detect_objects",
"branch": "objects"
},
{
"name": "preprocess"
},
{
"name": "run_ocr"
},
{
"name": "match_brands"
},
{
"name": "escalate_vlm"
},
{
"name": "escalate_cloud"
},
{
"name": "compile_report"
}
],
"edges": [
{
"source": "extract_frames",
"target": "filter_scenes"
},
{
"source": "filter_scenes",
"target": "field_segmentation"
},
{
"source": "field_segmentation",
"target": "detect_edges"
},
{
"source": "field_segmentation",
"target": "detect_objects"
},
{
"source": "detect_edges",
"target": "preprocess"
},
{
"source": "detect_objects",
"target": "preprocess"
},
{
"source": "preprocess",
"target": "run_ocr"
},
{
"source": "run_ocr",
"target": "match_brands"
},
{
"source": "match_brands",
"target": "escalate_vlm"
},
{
"source": "escalate_vlm",
"target": "escalate_cloud"
},
{
"source": "escalate_cloud",
"target": "compile_report"
}
]
},
"configs": {
"extract_frames": {
"fps": 2.0,
"max_frames": 500
},
"filter_scenes": {
"hamming_threshold": 8,
"enabled": true
},
"field_segmentation": {
"enabled": true,
"hue_low": 30,
"hue_high": 85,
"sat_low": 30,
"sat_high": 255,
"val_low": 30,
"val_high": 255,
"morph_kernel": 15,
"min_area_ratio": 0.05
},
"detect_edges": {
"enabled": true,
"edge_canny_low": 50,
"edge_canny_high": 150,
"edge_hough_threshold": 80,
"edge_hough_min_length": 100,
"edge_hough_max_gap": 10,
"edge_pair_max_distance": 200,
"edge_pair_min_distance": 15
},
"detect_objects": {
"model_name": "yolov8n.pt",
"confidence_threshold": 0.3,
"target_classes": []
},
"run_ocr": {
"languages": [
"en",
"es"
],
"min_confidence": 0.5
},
"match_brands": {
"fuzzy_threshold": 75
},
"escalate_vlm": {
"vlm_prompt_template": "Identify the brand or sponsor visible in this cropped region from a soccer broadcast.{hint}{text} Respond with: brand, confidence (0-1), reasoning."
}
}
}

View File

@@ -7,7 +7,7 @@ from uuid import UUID
from sqlmodel import Session, select
from .tables import Job
from .models import Job
def list_jobs(session: Session, parent_id: Optional[UUID] = None, status: Optional[str] = None) -> list[Job]:

View File

@@ -44,7 +44,7 @@ class SourceType(str, Enum):
class MediaAsset(SQLModel, table=True):
"""A video/audio file registered in the system."""
__tablename__ = "media_assets"
__tablename__ = "media_asset"
id: UUID = Field(default_factory=uuid4, primary_key=True)
filename: str
@@ -67,7 +67,7 @@ class MediaAsset(SQLModel, table=True):
class TranscodePreset(SQLModel, table=True):
"""A reusable transcoding configuration (like Handbrake presets)."""
__tablename__ = "transcode_presets"
__tablename__ = "transcode_preset"
id: UUID = Field(default_factory=uuid4, primary_key=True)
name: str
@@ -90,12 +90,13 @@ class TranscodePreset(SQLModel, table=True):
class Job(SQLModel, table=True):
"""A pipeline job."""
__tablename__ = "jobs"
__tablename__ = "job"
id: UUID = Field(default_factory=uuid4, primary_key=True)
source_asset_id: UUID = Field(index=True)
video_path: str
profile_name: str = "soccer_broadcast"
timeline_id: Optional[UUID] = None
parent_id: Optional[UUID] = None
run_type: RunType = "initial"
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
@@ -107,7 +108,6 @@ class Job(SQLModel, table=True):
brands_found: int = 0
cloud_llm_calls: int = 0
estimated_cost_usd: float = 0.0
celery_task_id: Optional[str] = None
priority: int = 0
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
started_at: Optional[datetime] = None
@@ -115,7 +115,7 @@ class Job(SQLModel, table=True):
class Timeline(SQLModel, table=True):
"""The frame sequence from a source video."""
__tablename__ = "timelines"
__tablename__ = "timeline"
id: UUID = Field(default_factory=uuid4, primary_key=True)
source_asset_id: Optional[UUID] = Field(default=None, index=True)
@@ -129,10 +129,11 @@ class Timeline(SQLModel, table=True):
class Checkpoint(SQLModel, table=True):
"""A snapshot of pipeline state on a timeline."""
__tablename__ = "checkpoints"
__tablename__ = "checkpoint"
id: UUID = Field(default_factory=uuid4, primary_key=True)
timeline_id: UUID
job_id: Optional[UUID] = Field(default=None, index=True)
parent_id: Optional[UUID] = None
stage_outputs: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
@@ -143,7 +144,7 @@ class Checkpoint(SQLModel, table=True):
class Brand(SQLModel, table=True):
"""A brand discovered or registered in the system."""
__tablename__ = "brands"
__tablename__ = "brand"
id: UUID = Field(default_factory=uuid4, primary_key=True)
canonical_name: str = Field(index=True)
@@ -154,3 +155,12 @@ class Brand(SQLModel, table=True):
total_airings: int = 0
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
class Profile(SQLModel, table=True):
"""A content type profile."""
__tablename__ = "profile"
id: UUID = Field(default_factory=uuid4, primary_key=True)
name: str
pipeline: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
configs: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))

43
core/db/seed.py Normal file
View File

@@ -0,0 +1,43 @@
"""
Seed data — insert initial profile rows if they don't exist.
Called on startup after create_tables().
"""
import json
import logging
from pathlib import Path
logger = logging.getLogger(__name__)
SEED_DIR = Path(__file__).parent / "fixtures"
def seed_profiles():
    """Insert seed profiles from JSON fixtures if not already present.

    Scans ``SEED_DIR`` for ``*.json`` fixture files. Each fixture must contain
    a ``name`` key plus optional ``pipeline`` and ``configs`` dicts. A profile
    whose name already exists in the database is skipped, so this function is
    idempotent and safe to call on every startup.

    A malformed fixture (unreadable file, invalid JSON, missing ``name``) is
    logged and skipped instead of aborting the seeding of the remaining
    fixtures — this runs during startup, where one bad file should not take
    the whole service down.
    """
    from .connection import get_session
    from .models import Profile

    # sorted() gives a deterministic seeding order across runs/filesystems.
    fixtures = sorted(SEED_DIR.glob("*.json"))
    if not fixtures:
        return

    with get_session() as session:
        for fixture in fixtures:
            try:
                data = json.loads(fixture.read_text())
                name = data["name"]
            except (OSError, ValueError, KeyError) as exc:
                # json.JSONDecodeError is a ValueError subclass; KeyError
                # covers a fixture without a "name" field.
                logger.error("Skipping invalid seed fixture %s: %s", fixture, exc)
                continue
            existing = session.query(Profile).filter(Profile.name == name).first()
            if existing:
                logger.debug("Profile %s already exists, skipping seed", name)
                continue
            profile = Profile(
                name=name,
                pipeline=data.get("pipeline", {}),
                configs=data.get("configs", {}),
            )
            session.add(profile)
            logger.info("Seeded profile: %s", name)
        session.commit()

View File

@@ -1,96 +0,0 @@
"""
SQLModel table definitions.
Generated by modelgen from core/schema/models/. Do not edit directly.
"""
from __future__ import annotations
from datetime import datetime
from typing import Any, Dict, List, Optional
from uuid import UUID, uuid4
from sqlalchemy import JSON
from sqlmodel import Column, Field, SQLModel
class MediaAsset(SQLModel, table=True):
__tablename__ = "media_asset"
id: UUID = Field(default_factory=uuid4, primary_key=True)
filename: str
path: str
status: str = "pending"
size_bytes: int = 0
duration_seconds: float = 0.0
width: Optional[int] = None
height: Optional[int] = None
fps: Optional[float] = None
codec: Optional[str] = None
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
class Job(SQLModel, table=True):
__tablename__ = "job"
id: UUID = Field(default_factory=uuid4, primary_key=True)
source_asset_id: UUID = Field(index=True)
video_path: str
profile_name: str = "soccer_broadcast"
parent_id: Optional[UUID] = Field(default=None, index=True)
run_type: str = "initial"
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
status: str = "pending"
current_stage: Optional[str] = None
progress: float = 0.0
error_message: Optional[str] = None
total_detections: int = 0
brands_found: int = 0
cloud_llm_calls: int = 0
estimated_cost_usd: float = 0.0
priority: int = 0
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
class Timeline(SQLModel, table=True):
__tablename__ = "timeline"
id: UUID = Field(default_factory=uuid4, primary_key=True)
source_asset_id: Optional[UUID] = Field(default=None, index=True)
source_video: str = ""
profile_name: str = ""
fps: float = 2.0
frames_prefix: str = ""
frames_manifest: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
frames_meta: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
class Checkpoint(SQLModel, table=True):
__tablename__ = "checkpoint"
id: UUID = Field(default_factory=uuid4, primary_key=True)
timeline_id: UUID = Field(index=True)
parent_id: Optional[UUID] = Field(default=None, index=True)
stage_outputs: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
stats: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
is_scenario: bool = False
scenario_label: str = ""
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
class Brand(SQLModel, table=True):
__tablename__ = "brand"
id: UUID = Field(default_factory=uuid4, primary_key=True)
canonical_name: str = Field(index=True)
aliases: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
source: str = "ocr"
confirmed: bool = False
airings: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
total_airings: int = 0
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow)

View File

@@ -16,4 +16,4 @@ from .storage import (
load_stage_output,
)
from .frames import save_frames, load_frames
from .runner_bridge import checkpoint_after_stage, reset_checkpoint_state
from .runner_bridge import checkpoint_after_stage, reset_checkpoint_state, get_timeline_id

View File

@@ -9,7 +9,7 @@ import tempfile
import numpy as np
from PIL import Image
from detect.models import Frame
from core.detect.models import Frame
logger = logging.getLogger(__name__)

View File

@@ -11,7 +11,7 @@ import logging
import uuid
from detect import emit
from core.detect import emit
# TODO: migrate to Timeline/Branch/Checkpoint model
# These old functions no longer exist — replay needs rework
def _not_migrated(*args, **kwargs):
@@ -19,66 +19,14 @@ def _not_migrated(*args, **kwargs):
load_checkpoint = _not_migrated
list_checkpoints = _not_migrated
from detect.graph import NODES, build_graph
from core.detect.graph import NODES, build_graph
logger = logging.getLogger(__name__)
class OverrideProfile:
"""
Wraps a ContentTypeProfile and patches config methods with overrides.
Override dict structure:
{
"frame_extraction": {"fps": 1.0},
"scene_filter": {"hamming_threshold": 12},
"region_analysis": {"edge_canny_low": 30, "edge_canny_high": 120},
"detection": {"confidence_threshold": 0.5},
"ocr": {"languages": ["en", "es"], "min_confidence": 0.3},
"resolver": {"fuzzy_threshold": 60},
}
"""
def __init__(self, base, overrides: dict):
self._base = base
self._overrides = overrides
def __getattr__(self, name):
return getattr(self._base, name)
def _patch(self, config, key: str):
patches = self._overrides.get(key, {})
for k, v in patches.items():
if hasattr(config, k):
setattr(config, k, v)
return config
def frame_extraction_config(self):
return self._patch(self._base.frame_extraction_config(), "frame_extraction")
def scene_filter_config(self):
return self._patch(self._base.scene_filter_config(), "scene_filter")
def region_analysis_config(self):
return self._patch(self._base.region_analysis_config(), "region_analysis")
def detection_config(self):
return self._patch(self._base.detection_config(), "detection")
def ocr_config(self):
return self._patch(self._base.ocr_config(), "ocr")
def resolver_config(self):
return self._patch(self._base.resolver_config(), "resolver")
def vlm_prompt(self, crop_context):
return self._base.vlm_prompt(crop_context)
def aggregate(self, detections):
return self._base.aggregate(detections)
def auxiliary_detections(self, source):
return self._base.auxiliary_detections(source)
# OverrideProfile removed — config overrides are now handled by dict merging
# in _load_profile() (nodes.py) and replay_single_stage (below).
def replay_from(
@@ -183,10 +131,16 @@ def replay_single_stage(
state = load_checkpoint(job_id, previous_stage)
# Build profile with overrides
from detect.profiles import get_profile
from core.detect.profile import get_profile, get_stage_config
profile = get_profile(state.get("profile_name", "soccer_broadcast"))
if config_overrides:
profile = OverrideProfile(profile, config_overrides)
merged_configs = dict(profile.get("configs", {}))
for sname, soverrides in config_overrides.items():
if sname in merged_configs:
merged_configs[sname] = {**merged_configs[sname], **soverrides}
else:
merged_configs[sname] = soverrides
profile = {**profile, "configs": merged_configs}
# Run the stage function directly (not through the graph)
if stage == "detect_edges":
@@ -207,9 +161,11 @@ def _replay_detect_edges(
) -> dict:
"""Run edge detection on checkpoint frames, optionally with debug overlays."""
import os
from detect.stages.edge_detector import detect_edge_regions
from core.detect.stages.edge_detector import detect_edge_regions
config = profile.region_analysis_config()
from core.detect.profile import get_stage_config
from core.detect.stages.models import RegionAnalysisConfig
config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
frames = state.get("filtered_frames", [])
if frame_refs:
@@ -231,7 +187,7 @@ def _replay_detect_edges(
if debug and frames:
debug_data = {}
if inference_url:
from detect.inference import InferenceClient
from core.detect.inference import InferenceClient
client = InferenceClient(base_url=inference_url, job_id=job_id)
for frame in frames:
dr = client.detect_edges_debug(
@@ -252,7 +208,7 @@ def _replay_detect_edges(
}
else:
# Local mode — import GPU module directly
from detect.stages.edge_detector import _load_cv_edges
from core.detect.stages.edge_detector import _load_cv_edges
edges_mod = _load_cv_edges()
for frame in frames:
dr = edges_mod.detect_edges_debug(

View File

@@ -0,0 +1,97 @@
"""
Runner bridge — checkpoint hook called by PipelineRunner after each stage.
Owns the per-job state (timeline, frame manifest, checkpoint chain) that
the runner shouldn't know about.
Timeline and Job are independent entities:
- One Timeline can serve multiple Jobs (re-run with different params)
- One Job operates on one Timeline (set after frame extraction)
- Checkpoints belong to Timeline, tagged with the Job that created them
"""
from __future__ import annotations
import logging
logger = logging.getLogger(__name__)
# Per-job bookkeeping, keyed by job_id.
_timeline_id: dict[str, str] = {}
_frames_manifest: dict[str, dict[int, str]] = {}
_latest_checkpoint: dict[str, str] = {}


def reset_checkpoint_state(job_id: str):
    """Drop all per-job checkpoint bookkeeping. Called when a pipeline run finishes."""
    for registry in (_timeline_id, _frames_manifest, _latest_checkpoint):
        registry.pop(job_id, None)
def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: dict):
    """
    Save a checkpoint after a stage completes.

    Called by the runner. Handles:
    - Timeline creation (once, on extract_frames)
    - Frame upload (via create_timeline)
    - Stage output serialization (via stage registry)
    - Checkpoint chain (parent → child)

    Args:
        job_id: pipeline job identifier; a falsy job_id disables checkpointing.
        stage_name: name of the stage that just finished.
        state: pipeline state dict as it was before the stage ran.
        result: the stage's output dict; on key collisions, ``result`` values
            override ``state`` values in the merged view below.
    """
    if not job_id:
        return
    # Imported lazily — presumably to avoid import cycles at module load; confirm.
    from .storage import create_timeline, save_stage_output
    from core.detect.stages.base import _REGISTRY
    merged = {**state, **result}
    # On extract_frames: create Timeline + upload frames + root checkpoint.
    # The `job_id not in _timeline_id` guard makes this a once-per-job step.
    if stage_name == "extract_frames" and job_id not in _timeline_id:
        frames = merged.get("frames", [])
        video_path = merged.get("video_path", "")
        profile_name = merged.get("profile_name", "")
        tid, cid = create_timeline(
            source_video=video_path,
            profile_name=profile_name,
            frames=frames,
        )
        _timeline_id[job_id] = tid
        _latest_checkpoint[job_id] = cid
        logger.info("Job %s → Timeline %s (root checkpoint %s)", job_id, tid, cid)
        # Emit timeline_id via SSE so the UI can use it for checkpoint loads
        from core.detect import emit
        emit.log(job_id, "Checkpoint", "INFO", f"timeline_id={tid}")
        return
    # For subsequent stages: save checkpoint on the timeline
    tid = _timeline_id.get(job_id)
    if not tid:
        # No timeline means extract_frames never ran for this job (or state
        # was already reset) — nothing to attach a checkpoint to.
        logger.warning("No timeline for job %s, skipping checkpoint", job_id)
        return
    # Serialize stage output using the stage's serialize_fn if available;
    # stages without one still get a (empty-payload) checkpoint so the chain
    # stays unbroken.
    stage_cls = _REGISTRY.get(stage_name)
    serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
    if serialize_fn:
        output_json = serialize_fn(merged, job_id)
    else:
        output_json = {}
    parent_id = _latest_checkpoint.get(job_id)
    new_checkpoint_id = save_stage_output(
        timeline_id=tid,
        parent_checkpoint_id=parent_id,
        stage_name=stage_name,
        output_json=output_json,
        job_id=job_id,
    )
    # Advance the chain head so the next stage parents onto this checkpoint.
    _latest_checkpoint[job_id] = new_checkpoint_id
def get_timeline_id(job_id: str) -> str | None:
    """Return the timeline_id of a running job, or None if the job has no
    timeline yet. Used by the UI to load checkpoints."""
    try:
        return _timeline_id[job_id]
    except KeyError:
        return None

View File

@@ -28,7 +28,7 @@ def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
Calls each registered stage's serialize_fn for stage-owned data.
Envelope fields (job_id, etc.) are copied directly.
"""
from detect.stages.base import _REGISTRY
from core.detect.stages.base import _REGISTRY
checkpoint = {}
@@ -64,7 +64,7 @@ def deserialize_state(checkpoint: dict, frames: list) -> dict:
Calls each stage's deserialize_fn to restore stage-owned data.
"""
from detect.stages.base import _REGISTRY
from core.detect.stages.base import _REGISTRY
frame_map = {f.sequence: f for f in frames}

View File

@@ -33,7 +33,7 @@ def create_timeline(
Returns (timeline_id, checkpoint_id).
"""
from core.db.tables import Timeline, Checkpoint
from core.db.models import Timeline, Checkpoint
from core.db.connection import get_session
with get_session() as session:
@@ -81,7 +81,7 @@ def create_timeline(
def get_timeline_frames(timeline_id: str) -> list:
"""Load frames from a timeline (from MinIO) as Frame objects."""
from core.db.tables import Timeline
from core.db.models import Timeline
from core.db.connection import get_session
with get_session() as session:
@@ -96,7 +96,7 @@ def get_timeline_frames(timeline_id: str) -> list:
def get_timeline_frames_b64(timeline_id: str) -> list[dict]:
"""Load frames as base64 JPEG (lightweight, no numpy)."""
from core.db.tables import Timeline
from core.db.models import Timeline
from core.db.connection import get_session
from .frames import load_frames_b64
@@ -123,6 +123,7 @@ def save_stage_output(
stats: dict | None = None,
is_scenario: bool = False,
scenario_label: str = "",
job_id: str | None = None,
) -> str:
"""
Save a stage's output as a new checkpoint (child of parent).
@@ -130,7 +131,7 @@ def save_stage_output(
Carries forward stage outputs from parent + adds the new one.
Returns the new checkpoint ID.
"""
from core.db.tables import Checkpoint
from core.db.models import Checkpoint
from core.db.connection import get_session
with get_session() as session:
@@ -146,6 +147,7 @@ def save_stage_output(
checkpoint = Checkpoint(
timeline_id=UUID(timeline_id),
job_id=UUID(job_id) if job_id else None,
parent_id=UUID(parent_checkpoint_id) if parent_checkpoint_id else None,
stage_outputs={**parent_outputs, stage_name: output_json},
config_overrides={**parent_config, **(config_overrides or {})},
@@ -165,7 +167,7 @@ def save_stage_output(
def load_stage_output(checkpoint_id: str, stage_name: str) -> dict | None:
"""Load a stage's output from a checkpoint."""
from core.db.tables import Checkpoint
from core.db.models import Checkpoint
from core.db.connection import get_session
with get_session() as session:

View File

@@ -16,8 +16,8 @@ from __future__ import annotations
import dataclasses
from datetime import datetime, timezone
from detect.events import push_detect_event
from detect.models import PipelineStats
from core.detect.events import push_detect_event
from core.detect.models import PipelineStats
# Log level ordering for comparison
_LEVEL_ORDER = {"DEBUG": 0, "INFO": 1, "WARN": 2, "ERROR": 3}

View File

@@ -4,8 +4,8 @@ Graph event emission — node state tracking + SSE graph_update events.
from __future__ import annotations
from detect import emit
from detect.state import DetectState
from core.detect import emit
from core.detect.state import DetectState
# Track node states across pipeline runs

View File

@@ -1,28 +1,39 @@
"""
Pipeline node functions — one per stage.
Each node: reads state, runs stage logic, emits transitions, returns output dict.
Each node: reads state, gets config from profile dict, runs stage logic,
emits transitions, returns output dict.
"""
from __future__ import annotations
import os
from detect import emit
from detect.models import PipelineStats
from detect.profiles import SoccerBroadcastProfile
from detect.state import DetectState
from detect.stages.frame_extractor import extract_frames
from detect.stages.scene_filter import scene_filter
from detect.stages.edge_detector import detect_edge_regions
from detect.stages.yolo_detector import detect_objects
from detect.stages.preprocess import preprocess_regions
from detect.stages.ocr_stage import run_ocr
from detect.stages.brand_resolver import resolve_brands
from detect.stages.vlm_local import escalate_vlm
from detect.stages.vlm_cloud import escalate_cloud
from detect.stages.aggregator import compile_report
from detect.tracing import trace_node, flush as flush_traces
from core.detect import emit
from core.detect.models import CropContext, PipelineStats
from core.detect.profile import get_profile, get_stage_config, build_vlm_prompt, aggregate_detections
from core.detect.stages.models import (
DetectionConfig,
FieldSegmentationConfig,
FrameExtractionConfig,
OCRConfig,
RegionAnalysisConfig,
ResolverConfig,
SceneFilterConfig,
)
from core.detect.state import DetectState
from core.detect.stages.frame_extractor import extract_frames
from core.detect.stages.scene_filter import scene_filter
from core.detect.stages.field_segmentation import run_field_segmentation
from core.detect.stages.edge_detector import detect_edge_regions
from core.detect.stages.yolo_detector import detect_objects
from core.detect.stages.preprocess import preprocess_regions
from core.detect.stages.ocr_stage import run_ocr
from core.detect.stages.brand_resolver import resolve_brands
from core.detect.stages.vlm_local import escalate_vlm
from core.detect.stages.vlm_cloud import escalate_cloud
from core.detect.stages.aggregator import compile_report
from core.detect.tracing import trace_node, flush as flush_traces
from .events import emit_transition
@@ -31,6 +42,7 @@ INFERENCE_URL = os.environ.get("INFERENCE_URL") # None = local mode
NODES = [
"extract_frames",
"filter_scenes",
"field_segmentation",
"detect_edges",
"detect_objects",
"preprocess",
@@ -42,17 +54,21 @@ NODES = [
]
def _get_profile(state: DetectState):
def _load_profile(state: DetectState) -> dict:
"""Load profile dict, apply config overrides if present."""
name = state.get("profile_name", "soccer_broadcast")
if name == "soccer_broadcast":
profile = SoccerBroadcastProfile()
else:
raise ValueError(f"Unknown profile: {name}")
profile = get_profile(name)
overrides = state.get("config_overrides")
if overrides:
from detect.checkpoint.replay import OverrideProfile
profile = OverrideProfile(profile, overrides)
# Merge overrides into a copy of the profile configs
merged_configs = dict(profile.get("configs", {}))
for stage_name, stage_overrides in overrides.items():
if stage_name in merged_configs:
merged_configs[stage_name] = {**merged_configs[stage_name], **stage_overrides}
else:
merged_configs[stage_name] = stage_overrides
profile = {**profile, "configs": merged_configs}
return profile
@@ -70,16 +86,16 @@ def node_extract_frames(state: DetectState) -> dict:
source_asset_id = state.get("source_asset_id")
if source_asset_id and not state.get("session_brands"):
from detect.stages.brand_resolver import build_session_dict
from core.detect.stages.brand_resolver import build_session_dict
session_brands = build_session_dict(source_asset_id)
state["session_brands"] = session_brands
_emit(state, "extract_frames", "running")
with trace_node(state, "extract_frames") as span:
profile = _get_profile(state)
config = profile.frame_extraction_config()
frames = extract_frames(state["video_path"], config, job_id=state.get("job_id"))
profile = _load_profile(state)
config = FrameExtractionConfig(**get_stage_config(profile, "extract_frames"))
frames = extract_frames(state["video_path"], config, job_id=job_id)
span.set_output({"frames_extracted": len(frames)})
_emit(state, "extract_frames", "done")
@@ -90,8 +106,8 @@ def node_filter_scenes(state: DetectState) -> dict:
_emit(state, "filter_scenes", "running")
with trace_node(state, "filter_scenes") as span:
profile = _get_profile(state)
config = profile.scene_filter_config()
profile = _load_profile(state)
config = SceneFilterConfig(**get_stage_config(profile, "filter_scenes"))
frames = state.get("frames", [])
kept = scene_filter(frames, config, job_id=state.get("job_id"))
span.set_output({"frames_in": len(frames), "frames_kept": len(kept)})
@@ -103,17 +119,42 @@ def node_filter_scenes(state: DetectState) -> dict:
return {"filtered_frames": kept, "stats": stats}
def node_field_segmentation(state: DetectState) -> dict:
_emit(state, "field_segmentation", "running")
with trace_node(state, "field_segmentation") as span:
profile = _load_profile(state)
config = FieldSegmentationConfig(**get_stage_config(profile, "field_segmentation"))
frames = state.get("filtered_frames", [])
job_id = state.get("job_id")
result = run_field_segmentation(frames, config, inference_url=INFERENCE_URL, job_id=job_id)
span.set_output({
"frames": len(frames),
"avg_coverage": sum(result["field_coverage"].values()) / max(len(result["field_coverage"]), 1),
})
_emit(state, "field_segmentation", "done")
return {
"field_masks": result["field_masks"],
"field_boundaries": result["field_boundaries"],
"field_coverage": result["field_coverage"],
}
def node_detect_edges(state: DetectState) -> dict:
_emit(state, "detect_edges", "running")
with trace_node(state, "detect_edges") as span:
profile = _get_profile(state)
config = profile.region_analysis_config()
profile = _load_profile(state)
config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
frames = state.get("filtered_frames", [])
field_masks = state.get("field_masks", {})
job_id = state.get("job_id")
regions = detect_edge_regions(
frames, config, inference_url=INFERENCE_URL, job_id=job_id,
field_masks=field_masks,
)
total = sum(len(r) for r in regions.values())
span.set_output({"frames": len(frames), "edge_regions": total})
@@ -129,8 +170,8 @@ def node_detect_objects(state: DetectState) -> dict:
_emit(state, "detect_objects", "running")
with trace_node(state, "detect_objects") as span:
profile = _get_profile(state)
config = profile.detection_config()
profile = _load_profile(state)
config = DetectionConfig(**get_stage_config(profile, "detect_objects"))
frames = state.get("filtered_frames", [])
job_id = state.get("job_id")
@@ -149,13 +190,12 @@ def node_preprocess(state: DetectState) -> dict:
_emit(state, "preprocess", "running")
with trace_node(state, "preprocess") as span:
profile = _get_profile(state)
profile = _load_profile(state)
prep_config = get_stage_config(profile, "preprocess")
frames = state.get("filtered_frames", [])
boxes = state.get("boxes_by_frame", {})
job_id = state.get("job_id")
overrides = state.get("config_overrides", {})
prep_config = overrides.get("preprocessing", {})
do_contrast = prep_config.get("contrast", True)
do_deskew = prep_config.get("deskew", False)
do_binarize = prep_config.get("binarize", False)
@@ -178,8 +218,8 @@ def node_run_ocr(state: DetectState) -> dict:
_emit(state, "run_ocr", "running")
with trace_node(state, "run_ocr") as span:
profile = _get_profile(state)
config = profile.ocr_config()
profile = _load_profile(state)
config = OCRConfig(**get_stage_config(profile, "run_ocr"))
frames = state.get("filtered_frames", [])
boxes = state.get("boxes_by_frame", {})
job_id = state.get("job_id")
@@ -198,18 +238,18 @@ def node_match_brands(state: DetectState) -> dict:
_emit(state, "match_brands", "running")
with trace_node(state, "match_brands") as span:
profile = _get_profile(state)
resolver_config = profile.resolver_config()
profile = _load_profile(state)
config = ResolverConfig(**get_stage_config(profile, "match_brands"))
candidates = state.get("text_candidates", [])
session_brands = state.get("session_brands", {})
job_id = state.get("job_id")
source_asset_id = state.get("source_asset_id")
matched, unresolved = resolve_brands(
candidates, resolver_config,
candidates, config,
session_brands=session_brands,
source_asset_id=source_asset_id,
content_type=profile.name, job_id=job_id,
content_type=profile["name"], job_id=job_id,
)
span.set_output({"matched": len(matched), "unresolved": len(unresolved)})
@@ -221,15 +261,19 @@ def node_escalate_vlm(state: DetectState) -> dict:
_emit(state, "escalate_vlm", "running")
with trace_node(state, "escalate_vlm") as span:
profile = _get_profile(state)
profile = _load_profile(state)
vlm_config = get_stage_config(profile, "escalate_vlm")
vlm_template = vlm_config.get("vlm_prompt_template", "Identify the brand in this image.")
candidates = state.get("unresolved_candidates", [])
job_id = state.get("job_id")
vlm_prompt_fn = lambda ctx: build_vlm_prompt(ctx, vlm_template)
vlm_matched, still_unresolved = escalate_vlm(
candidates,
vlm_prompt_fn=profile.vlm_prompt,
vlm_prompt_fn=vlm_prompt_fn,
inference_url=INFERENCE_URL,
content_type=profile.name,
content_type=profile["name"],
source_asset_id=state.get("source_asset_id"),
job_id=job_id,
)
@@ -254,16 +298,20 @@ def node_escalate_cloud(state: DetectState) -> dict:
_emit(state, "escalate_cloud", "running")
with trace_node(state, "escalate_cloud") as span:
profile = _get_profile(state)
profile = _load_profile(state)
vlm_config = get_stage_config(profile, "escalate_vlm")
vlm_template = vlm_config.get("vlm_prompt_template", "Identify the brand in this image.")
candidates = state.get("unresolved_candidates", [])
job_id = state.get("job_id")
stats = state.get("stats", PipelineStats())
vlm_prompt_fn = lambda ctx: build_vlm_prompt(ctx, vlm_template)
cloud_matched = escalate_cloud(
candidates,
vlm_prompt_fn=profile.vlm_prompt,
vlm_prompt_fn=vlm_prompt_fn,
stats=stats,
content_type=profile.name,
content_type=profile["name"],
source_asset_id=state.get("source_asset_id"),
job_id=job_id,
)
@@ -283,7 +331,7 @@ def node_compile_report(state: DetectState) -> dict:
_emit(state, "compile_report", "running")
with trace_node(state, "compile_report") as span:
profile = _get_profile(state)
profile = _load_profile(state)
detections = state.get("detections", [])
stats = state.get("stats", PipelineStats())
job_id = state.get("job_id")
@@ -292,7 +340,7 @@ def node_compile_report(state: DetectState) -> dict:
detections=detections,
stats=stats,
video_source=state.get("video_path", ""),
content_type=profile.name,
content_type=profile["name"],
job_id=job_id,
)
@@ -306,6 +354,7 @@ def node_compile_report(state: DetectState) -> dict:
NODE_FUNCTIONS = [
("extract_frames", node_extract_frames),
("filter_scenes", node_filter_scenes),
("field_segmentation", node_field_segmentation),
("detect_edges", node_detect_edges),
("detect_objects", node_detect_objects),
("preprocess", node_preprocess),

View File

@@ -13,8 +13,8 @@ import logging
import os
import threading
from core.schema.models.pipeline_config import PipelineConfig
from detect.state import DetectState
from core.detect.stages.models import PipelineConfig
from core.detect.state import DetectState
from .nodes import NODES, NODE_FUNCTIONS
logger = logging.getLogger(__name__)
@@ -118,7 +118,7 @@ def _wait_if_paused(job_id: str, node_name: str):
if _pause_after_stage.get(job_id, False):
gate.clear()
from detect import emit
from core.detect import emit
emit.log(job_id, "Pipeline", "INFO", f"Paused after {node_name}")
while not gate.wait(timeout=0.5):
@@ -237,7 +237,7 @@ class PipelineRunner:
# 4. Checkpoint
if self.do_checkpoint:
from detect.checkpoint import checkpoint_after_stage
from core.detect.checkpoint import checkpoint_after_stage
checkpoint_after_stage(job_id, stage_name, state, result)
# 5. Pause check
@@ -256,11 +256,11 @@ def get_pipeline(
start_from: str | None = None,
) -> PipelineRunner:
"""Return a PipelineRunner for the given profile."""
from detect.profiles import get_profile
from core.detect.profile import get_profile, pipeline_config_from_dict
do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
profile = get_profile(profile_name)
config = profile.pipeline_config()
config = pipeline_config_from_dict(profile["pipeline"])
return PipelineRunner(
config=config,

View File

@@ -231,6 +231,20 @@ class InferenceClient:
pair_count=data.get("pair_count", 0),
)
def post(self, path: str, payload: dict) -> dict | None:
"""Generic POST to the inference server. Returns JSON response or None on error."""
try:
resp = self.session.post(
f"{self.base_url}{path}",
json=payload,
timeout=self.timeout,
)
resp.raise_for_status()
return resp.json()
except Exception as e:
logger.warning("Inference POST %s failed: %s", path, e)
return None
def load_model(self, model: str, quantization: str = "fp16") -> None:
"""Request the server to load a model into VRAM."""
self.session.post(

View File

@@ -2,7 +2,7 @@
Inference response types.
These are the shapes returned by the inference server.
Kept separate from detect.models to avoid coupling the
Kept separate from core.detect.models to avoid coupling the
inference protocol to pipeline internals.
"""

View File

@@ -2,8 +2,8 @@
Detection pipeline runtime models.
These are the data structures that flow between pipeline stages.
They contain runtime types (np.ndarray) so modelgen skips them
not generated to SQLModel or TypeScript.
They contain runtime types (np.ndarray) so they live here, not in
core/schema/models/ (which is for modelgen source of truth).
"""
from __future__ import annotations
@@ -85,3 +85,11 @@ class DetectionReport:
brands: dict[str, BrandStats] = field(default_factory=dict)
timeline: list[BrandDetection] = field(default_factory=list)
pipeline_stats: PipelineStats = field(default_factory=PipelineStats)
@dataclass
class CropContext:
"""Runtime type — holds image bytes for VLM prompts."""
image: bytes
surrounding_text: str = ""
position_hint: str = ""

107
core/detect/profile.py Normal file
View File

@@ -0,0 +1,107 @@
"""
Profile registry and helpers.
Loads profile data from Postgres.
A profile is a dict with keys: name, pipeline, configs.
"""
from __future__ import annotations
import logging
from typing import Any, Dict
from core.detect.stages.models import PipelineConfig, StageRef, Edge
from core.detect.models import (
BrandDetection,
BrandStats,
CropContext,
DetectionReport,
PipelineStats,
)
logger = logging.getLogger(__name__)
def get_profile(name: str) -> Dict[str, Any]:
    """Load the profile named *name* from Postgres as a plain dict.

    The dict has keys ``name``, ``pipeline`` and ``configs`` (the latter
    two fall back to {} when NULL in the DB).

    Raises:
        ValueError: when no profile row with that name exists.
    """
    # Imported lazily so this module can be imported without a DB configured.
    from core.db.connection import get_session
    from core.db.models import Profile

    with get_session() as session:
        record = session.query(Profile).filter(Profile.name == name).first()
        if record is None:
            raise ValueError(f"Unknown profile: {name!r}")
        return {
            "name": record.name,
            "pipeline": record.pipeline or {},
            "configs": record.configs or {},
        }
def list_profiles() -> list[str]:
    """Return the names of all profiles stored in the database."""
    # Imported lazily so this module can be imported without a DB configured.
    from core.db.connection import get_session
    from core.db.models import Profile

    with get_session() as session:
        # query(Profile.name) yields 1-tuples; unpack each to its name.
        return [name for (name,) in session.query(Profile.name).all()]
def get_stage_config(profile: Dict[str, Any], stage_name: str) -> dict:
    """Return the config dict for *stage_name* within *profile*.

    Missing ``configs`` section or missing stage entry both yield {}.
    """
    stage_configs = profile.get("configs", {})
    return stage_configs.get(stage_name, {})
def pipeline_config_from_dict(data: Dict[str, Any]) -> PipelineConfig:
    """Rebuild a PipelineConfig from its JSONB dict representation.

    Every key is optional: missing stages/edges become empty lists,
    missing names become "" and missing routing rules {}.
    """
    return PipelineConfig(
        name=data.get("name", ""),
        profile_name=data.get("profile_name", ""),
        stages=[StageRef(**item) for item in data.get("stages", [])],
        edges=[Edge(**item) for item in data.get("edges", [])],
        routing_rules=data.get("routing_rules", {}),
    )
def build_vlm_prompt(crop_context: CropContext, template: str) -> str:
    """Render *template* with ``{hint}``/``{text}`` drawn from *crop_context*.

    Either fragment is empty when the corresponding context field is empty,
    so the template degrades gracefully when context is missing.
    """
    if crop_context.position_hint:
        hint = f" Position: {crop_context.position_hint}."
    else:
        hint = ""
    if crop_context.surrounding_text:
        text = f" Nearby text: '{crop_context.surrounding_text}'."
    else:
        text = ""
    return template.format(hint=hint, text=text)
def aggregate_detections(
    detections: list[BrandDetection],
    content_type: str,
) -> DetectionReport:
    """Fold a flat list of detections into a per-brand DetectionReport.

    For each brand this accumulates appearance count, total screen time,
    a running mean confidence, and the first/last timestamps seen. The
    report's timeline is the input sorted by timestamp.
    """
    stats_by_brand: dict[str, BrandStats] = {}
    for det in detections:
        stats = stats_by_brand.setdefault(det.brand, BrandStats())
        stats.total_appearances += 1
        stats.total_screen_time += det.duration
        # Incremental mean: new_avg = (old_avg * (n-1) + x) / n
        n = stats.total_appearances
        stats.avg_confidence = (stats.avg_confidence * (n - 1) + det.confidence) / n
        # first_seen == 0.0 is the "unset" sentinel on a fresh BrandStats.
        if stats.first_seen == 0.0 or det.timestamp < stats.first_seen:
            stats.first_seen = det.timestamp
        if det.timestamp > stats.last_seen:
            stats.last_seen = det.timestamp
    return DetectionReport(
        video_source="",
        content_type=content_type,
        duration_seconds=0.0,
        brands=stats_by_brand,
        timeline=sorted(detections, key=lambda det: det.timestamp),
        pipeline_stats=PipelineStats(),
    )

View File

@@ -145,3 +145,19 @@ class RetryResponse(BaseModel):
status: str
task_id: str
job_id: str
class RunRequest(BaseModel):
    """Request body for launching a detection pipeline run."""

    video_path: str  # path to the video to analyze
    profile_name: str = "soccer_broadcast"  # profile whose pipeline/configs to use
    source_asset_id: str = ""  # optional asset UUID to associate with the run
    checkpoint: bool = True  # checkpoint stage outputs during the run
    skip_vlm: bool = False  # presumably skips the local-VLM escalation stage — confirm
    skip_cloud: bool = False  # presumably skips the cloud-LLM escalation stage — confirm
    log_level: str = "INFO"  # run log verbosity
class RunResponse(BaseModel):
    """Response after starting a pipeline run."""

    status: str  # launch status indicator
    job_id: str  # ID of the job created for this run
    video_path: str  # echoes the requested video path

View File

@@ -16,6 +16,7 @@ from .base import (
# Import all stage files to trigger auto-registration
from . import edge_detector # noqa: F401
from . import field_segmentation # noqa: F401
# Import registry for backward compat (other stages still use old pattern)
from . import registry # noqa: F401

View File

@@ -9,8 +9,8 @@ from __future__ import annotations
import logging
from detect import emit
from detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats
from core.detect import emit
from core.detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats
logger = logging.getLogger(__name__)

View File

@@ -5,7 +5,7 @@ Each stage is a file that subclasses Stage. Auto-discovered via
__init_subclass__. No manual registration needed.
A stage:
- Has a StageDefinition (from schema) with name, config, IO
- Has a StageDefinition (generated from schema) with name, config, IO
- Implements run(frames, config) output
- Owns its output serialization (opaque blob)
- Optionally has a TypeScript port for browser-side execution
@@ -16,12 +16,30 @@ the format. The stage that wrote it is the only one that can read it.
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
import numpy as np
from core.schema.models.stages import StageConfigField, StageIO, StageDefinition
from core.detect.stages.models import (
StageConfigField,
StageIO,
StageDefinition,
)
# Legacy runtime extension — adds callable fields for old-style stages.
# New stages use Stage subclass with serialize()/deserialize() methods instead.
class LegacyStageDefinition:
    """Adapter that gives a plain StageDefinition the callable hooks
    (``fn`` / ``serialize_fn`` / ``deserialize_fn``) old-style stages expect.

    Any other attribute access is forwarded to the wrapped definition, so
    this object can stand in wherever a StageDefinition is expected.
    """

    def __init__(self, definition, fn=None, serialize_fn=None, deserialize_fn=None):
        self.fn = fn
        self.serialize_fn = serialize_fn
        self.deserialize_fn = deserialize_fn
        self._definition = definition

    def __getattr__(self, item):
        # Only invoked for names NOT set on the wrapper itself,
        # so fn/serialize_fn/deserialize_fn never reach this path.
        return getattr(self._definition, item)
# ---------------------------------------------------------------------------
@@ -30,12 +48,18 @@ from core.schema.models.stages import StageConfigField, StageIO, StageDefinition
# ---------------------------------------------------------------------------
_REGISTRY: dict[str, type['Stage']] = {}
_LEGACY_REGISTRY: dict[str, StageDefinition] = {}
_LEGACY_REGISTRY: dict[str, LegacyStageDefinition] = {}
def register_stage(definition: StageDefinition):
def register_stage(
definition: StageDefinition,
fn=None,
serialize_fn=None,
deserialize_fn=None,
):
"""Legacy registration for stages not yet converted to Stage subclass."""
_LEGACY_REGISTRY[definition.name] = definition
legacy = LegacyStageDefinition(definition, fn=fn, serialize_fn=serialize_fn, deserialize_fn=deserialize_fn)
_LEGACY_REGISTRY[definition.name] = legacy
class Stage:
@@ -55,13 +79,6 @@ class Stage:
_REGISTRY[cls.definition.name] = cls
def run(self, frames: list, config: dict) -> Any:
"""
Run the stage on a list of frames with the given config.
Config is a dict of parameter values (from slider UI or profile).
Returns the stage output whatever shape this stage produces.
Debug overlays are included when config has debug=True.
"""
raise NotImplementedError
def serialize(self, output: Any) -> bytes:
@@ -79,12 +96,15 @@ class Stage:
# Discovery API
# ---------------------------------------------------------------------------
def _all_definitions() -> dict[str, StageDefinition]:
"""Merge new Stage subclass registry + legacy registry."""
def _all_definitions():
"""Merge new Stage subclass registry + legacy registry.
Returns StageDefinition for new-style stages,
LegacyStageDefinition for legacy stages (has serialize_fn etc).
"""
merged = {}
# Legacy first, new overwrites (new takes precedence)
for name, defn in _LEGACY_REGISTRY.items():
merged[name] = defn
for name, legacy in _LEGACY_REGISTRY.items():
merged[name] = legacy
for name, cls in _REGISTRY.items():
merged[name] = cls.definition
return merged

View File

@@ -16,9 +16,9 @@ import logging
from rapidfuzz import fuzz
from detect import emit
from detect.models import BrandDetection, TextCandidate
from detect.profiles.base import ResolverConfig
from core.detect import emit
from core.detect.models import BrandDetection, TextCandidate
from core.detect.stages.models import ResolverConfig
logger = logging.getLogger(__name__)

View File

@@ -21,10 +21,10 @@ from typing import Any
from PIL import Image
from detect import emit
from detect.models import BoundingBox, Frame
from detect.stages.base import Stage
from core.schema.models.stages import StageDefinition, StageConfigField, StageIO
from core.detect import emit
from core.detect.models import BoundingBox, Frame
from core.detect.stages.base import Stage
from core.detect.stages.models import StageDefinition, StageConfigField, StageIO
logger = logging.getLogger(__name__)
@@ -42,14 +42,14 @@ class EdgeDetectionStage(Stage):
writes=["edge_regions_by_frame"],
),
config_fields=[
StageConfigField("enabled", "bool", True, "Enable edge detection"),
StageConfigField("edge_canny_low", "int", 50, "Canny low threshold", min=0, max=255),
StageConfigField("edge_canny_high", "int", 150, "Canny high threshold", min=0, max=255),
StageConfigField("edge_hough_threshold", "int", 80, "Hough accumulator threshold", min=1, max=500),
StageConfigField("edge_hough_min_length", "int", 100, "Min line length (px)", min=10, max=2000),
StageConfigField("edge_hough_max_gap", "int", 10, "Max line gap (px)", min=1, max=100),
StageConfigField("edge_pair_max_distance", "int", 200, "Max distance between line pair (px)", min=10, max=500),
StageConfigField("edge_pair_min_distance", "int", 15, "Min distance between line pair (px)", min=5, max=200),
StageConfigField(name="enabled", type="bool", default=True, description="Enable edge detection"),
StageConfigField(name="edge_canny_low", type="int", default=50, description="Canny low threshold", min=0, max=255),
StageConfigField(name="edge_canny_high", type="int", default=150, description="Canny high threshold", min=0, max=255),
StageConfigField(name="edge_hough_threshold", type="int", default=80, description="Hough accumulator threshold", min=1, max=500),
StageConfigField(name="edge_hough_min_length", type="int", default=100, description="Min line length (px)", min=10, max=2000),
StageConfigField(name="edge_hough_max_gap", type="int", default=10, description="Max line gap (px)", min=1, max=100),
StageConfigField(name="edge_pair_max_distance", type="int", default=200, description="Max distance between line pair (px)", min=10, max=500),
StageConfigField(name="edge_pair_min_distance", type="int", default=15, description="Min distance between line pair (px)", min=5, max=200),
],
tracks_element="edge_region",
)
@@ -143,8 +143,8 @@ class EdgeDetectionStage(Stage):
def _run_remote(self, frame: Frame, config: dict,
inference_url: str, job_id: str) -> list[BoundingBox]:
from detect.inference import InferenceClient
from detect.emit import _run_log_level
from core.detect.inference import InferenceClient
from core.detect.emit import _run_log_level
client = InferenceClient(
base_url=inference_url, job_id=job_id, log_level=_run_log_level,
@@ -208,7 +208,7 @@ def _load_cv_edges():
if _cv_edges_mod is None:
import importlib.util
from pathlib import Path
spec = importlib.util.spec_from_file_location("cv_edges", Path("gpu/models/cv/edges.py"))
spec = importlib.util.spec_from_file_location("cv_edges", Path("core/gpu/models/cv/edges.py"))
_cv_edges_mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(_cv_edges_mod)
return _cv_edges_mod
@@ -216,8 +216,43 @@ def _load_cv_edges():
# --- Backward compat: standalone function for graph.py ---
def detect_edge_regions(frames, config, inference_url=None, job_id=None):
"""Convenience wrapper — calls EdgeDetectionStage.run()."""
def _filter_by_field_mask(boxes, mask, margin_px=50):
    """
    Keep only boxes that are near the pitch boundary (hoarding zone).

    The field mask has 255=pitch, 0=not pitch. Hoardings sit just
    outside the pitch boundary. We dilate the mask to create a
    "boundary zone" and keep boxes whose center falls in the zone
    between the dilated mask edge and the original mask.

    Args:
        boxes: BoundingBox-like objects with x/y/w/h attributes.
        mask: uint8 field mask (255=pitch), or None to disable filtering.
        margin_px: how far outside the pitch the boundary zone extends.

    Returns:
        The subset of ``boxes`` whose center lies in the boundary zone;
        ``boxes`` unchanged when mask is None or boxes is empty.
    """
    import cv2

    if mask is None or not boxes:
        return boxes
    # Dilate the pitch mask — the expansion zone is where hoardings are
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (margin_px * 2, margin_px * 2))
    dilated = cv2.dilate(mask, kernel)
    # Boundary zone = dilated but NOT original pitch
    boundary_zone = cv2.bitwise_and(dilated, cv2.bitwise_not(mask))
    height, width = boundary_zone.shape[:2]
    kept = []
    for box in boxes:
        cx = box.x + box.w // 2
        cy = box.y + box.h // 2
        # Clamp the center to image bounds in BOTH directions. The original
        # only clamped the upper edge; a negative center (box partially off
        # the left/top edge) would wrap around via negative indexing and
        # sample the wrong pixel.
        cx = min(max(cx, 0), width - 1)
        cy = min(max(cy, 0), height - 1)
        if boundary_zone[cy, cx] > 0:
            kept.append(box)
    return kept
def detect_edge_regions(frames, config, inference_url=None, job_id=None, field_masks=None):
"""Convenience wrapper — calls EdgeDetectionStage.run(), optionally filters by field mask."""
stage = EdgeDetectionStage()
cfg = {
"enabled": config.enabled,
@@ -231,4 +266,23 @@ def detect_edge_regions(frames, config, inference_url=None, job_id=None):
"inference_url": inference_url,
"job_id": job_id,
}
return stage.run(frames, cfg)
all_boxes = stage.run(frames, cfg)
# Filter by field segmentation mask if available
if field_masks:
filtered_total = 0
original_total = sum(len(b) for b in all_boxes.values())
for seq, boxes in all_boxes.items():
mask = field_masks.get(seq)
if mask is not None:
all_boxes[seq] = _filter_by_field_mask(boxes, mask)
filtered_total += len(all_boxes[seq])
else:
filtered_total += len(boxes)
if original_total != filtered_total:
from core.detect import emit
emit.log(job_id, "EdgeDetection", "INFO",
f"Field mask filter: {original_total}{filtered_total} regions")
return all_boxes

View File

@@ -0,0 +1,141 @@
"""
Stage — Field Segmentation
Calls the GPU inference server to detect pitch boundaries via
HSV green mask + morphology. The CV code lives in core/gpu/models/cv/.
Outputs a mask and boundary that downstream stages use as spatial priors.
"""
from __future__ import annotations
import base64
import io
import logging
import numpy as np
from PIL import Image
from core.detect import emit
from core.detect.models import Frame
from core.detect.stages.base import Stage
from core.detect.stages.models import (
FieldSegmentationConfig,
StageConfigField,
StageDefinition,
StageIO,
)
logger = logging.getLogger(__name__)
class FieldSegmentationStage(Stage):
    """Declarative stage for HSV-based pitch segmentation.

    Subclassing Stage auto-registers ``definition`` in the stage registry
    (via __init_subclass__). The per-frame work is performed by
    run_field_segmentation() in this module, which calls the inference
    server's /segment_field endpoint.
    NOTE(review): no run() override is visible here — confirm the pipeline
    invokes run_field_segmentation() directly for this stage.
    """

    definition = StageDefinition(
        name="field_segmentation",
        label="Field Segmentation",
        description="HSV green mask — detect pitch boundaries for spatial priors",
        category="cv_analysis",
        io=StageIO(
            reads=["filtered_frames"],
            writes=["field_mask"],
        ),
        config_fields=[
            # HSV thresholds select the green pitch; morphology then cleans
            # the mask and min_area_ratio drops small spurious contours.
            StageConfigField(name="enabled", type="bool", default=True, description="Enable field segmentation"),
            StageConfigField(name="hue_low", type="int", default=30, description="HSV hue lower bound", min=0, max=180),
            StageConfigField(name="hue_high", type="int", default=85, description="HSV hue upper bound", min=0, max=180),
            StageConfigField(name="sat_low", type="int", default=30, description="HSV saturation lower bound", min=0, max=255),
            StageConfigField(name="sat_high", type="int", default=255, description="HSV saturation upper bound", min=0, max=255),
            StageConfigField(name="val_low", type="int", default=30, description="HSV value lower bound", min=0, max=255),
            StageConfigField(name="val_high", type="int", default=255, description="HSV value upper bound", min=0, max=255),
            StageConfigField(name="morph_kernel", type="int", default=15, description="Morphology kernel size", min=3, max=51),
            StageConfigField(name="min_area_ratio", type="float", default=0.05, description="Min contour area as fraction of frame", min=0.01, max=0.5),
        ],
    )
def _frame_to_b64(frame: Frame) -> str:
    """Encode a frame's image as a base64 JPEG string (quality 85)."""
    buffer = io.BytesIO()
    Image.fromarray(frame.image).save(buffer, format="JPEG", quality=85)
    return base64.b64encode(buffer.getvalue()).decode()
def _decode_mask_b64(mask_b64: str) -> np.ndarray:
    """Decode a base64 PNG mask back into a 2-D grayscale numpy array."""
    raw = base64.b64decode(mask_b64)
    grayscale = Image.open(io.BytesIO(raw)).convert("L")
    return np.array(grayscale)
def run_field_segmentation(
    frames: list[Frame],
    config: FieldSegmentationConfig,
    inference_url: str | None = None,
    job_id: str | None = None,
) -> dict:
    """
    Run field segmentation on all frames via the inference server.

    Fails soft at every step: when disabled, when no inference URL is
    available, or when an individual frame's request fails, it returns
    (or skips to) empty/partial results instead of raising.

    Returns dict with:
        field_masks: {seq: np.ndarray}
        field_boundaries: {seq: [(x,y), ...]}
        field_coverage: {seq: float}
    """
    if not config.enabled:
        emit.log(job_id, "FieldSegmentation", "INFO", "Disabled, skipping")
        return {"field_masks": {}, "field_boundaries": {}, "field_coverage": {}}
    import os
    # Explicit argument wins; fall back to the INFERENCE_URL env var.
    url = inference_url or os.environ.get("INFERENCE_URL")
    if not url:
        emit.log(job_id, "FieldSegmentation", "WARNING",
                 "No INFERENCE_URL, skipping field segmentation")
        return {"field_masks": {}, "field_boundaries": {}, "field_coverage": {}}
    emit.log(job_id, "FieldSegmentation", "INFO",
             f"Segmenting {len(frames)} frames (hue={config.hue_low}-{config.hue_high})")
    from core.detect.inference import InferenceClient
    from core.detect.emit import _run_log_level
    client = InferenceClient(base_url=url, job_id=job_id or "", log_level=_run_log_level)
    field_masks = {}
    field_boundaries = {}
    field_coverage = {}
    for frame in frames:
        # One request per frame; client.post returns None on any failure,
        # in which case the frame is simply omitted from all three outputs.
        image_b64 = _frame_to_b64(frame)
        resp = client.post("/segment_field", {
            "image": image_b64,
            "hue_low": config.hue_low,
            "hue_high": config.hue_high,
            "sat_low": config.sat_low,
            "sat_high": config.sat_high,
            "val_low": config.val_low,
            "val_high": config.val_high,
            "morph_kernel": config.morph_kernel,
            "min_area_ratio": config.min_area_ratio,
        })
        if resp is None:
            continue
        mask_b64 = resp.get("mask_b64", "")
        if mask_b64:
            field_masks[frame.sequence] = _decode_mask_b64(mask_b64)
        # NOTE(review): boundary/coverage are recorded even when mask_b64 is
        # empty — assumed intentional (source indentation was ambiguous); confirm.
        field_boundaries[frame.sequence] = resp.get("boundary", [])
        field_coverage[frame.sequence] = resp.get("coverage", 0.0)
    # max(len, 1) guards the division when every frame failed.
    avg_coverage = sum(field_coverage.values()) / max(len(field_coverage), 1)
    emit.log(job_id, "FieldSegmentation", "INFO",
             f"Done: {len(frames)} frames, avg coverage {avg_coverage:.1%}")
    return {
        "field_masks": field_masks,
        "field_boundaries": field_boundaries,
        "field_coverage": field_coverage,
    }

View File

@@ -16,9 +16,9 @@ import numpy as np
from PIL import Image
from core.ffmpeg.probe import probe_file
from detect import emit
from detect.models import Frame
from detect.profiles.base import FrameExtractionConfig
from core.detect import emit
from core.detect.models import Frame
from core.detect.stages.models import FrameExtractionConfig
def _load_frames(tmpdir: Path, fps: float) -> list[Frame]:

View File

@@ -0,0 +1,106 @@
"""
Pydantic Models - GENERATED FILE
Do not edit directly. Regenerate using modelgen.
"""
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel, Field
class StageConfigField(BaseModel):
    """A single tunable config parameter for the editor UI."""

    name: str  # config key, as used in stage config dicts
    type: str  # value type tag, e.g. "int", "bool", "float", "list[str]"
    default: Any  # default value shown in the UI
    description: str = ""  # human-readable help text
    min: Optional[float] = None  # lower bound for numeric fields
    max: Optional[float] = None  # upper bound for numeric fields
    options: Optional[List[str]] = None  # allowed values — presumably for enum-style fields; confirm
class StageIO(BaseModel):
    """Declares what a stage reads and writes."""

    reads: List[str] = Field(default_factory=list)  # state keys the stage consumes
    writes: List[str] = Field(default_factory=list)  # state keys the stage produces
    optional_reads: List[str] = Field(default_factory=list)  # consumed when present, not required
class StageDefinition(BaseModel):
    """Complete metadata for a pipeline stage."""

    name: str  # unique registry key, e.g. "detect_edges"
    label: str  # display name for the UI
    description: str  # one-line summary of the stage
    category: str = "detection"  # grouping, e.g. "preprocessing", "cv_analysis", "output"
    io: StageIO  # declared reads/writes
    config_fields: List[StageConfigField] = Field(default_factory=list)  # tunable params for the UI
    tracks_element: Optional[str] = None  # element type tracked by this stage, e.g. "edge_region"
class FrameExtractionConfig(BaseModel):
    """Frame extraction settings: sampling rate and frame cap."""

    fps: float = 2.0  # frames per second to sample
    max_frames: int = 500  # maximum frames to extract
class SceneFilterConfig(BaseModel):
    """Scene-duplicate filter settings (perceptual-hash based)."""

    hamming_threshold: int = 8  # hamming distance below which frames count as duplicates
    enabled: bool = True  # master switch for scene filtering
class DetectionConfig(BaseModel):
    """YOLO detection stage settings.

    Mirrors the source dataclass
    ``DetectionConfig(model_name='yolov8n.pt', confidence_threshold=0.3,
    target_classes=<factory>)``.
    """

    model_name: str = "yolov8n.pt"  # YOLO weights file
    confidence_threshold: float = 0.3  # min detection confidence
    # The source dataclass declares a default factory for this field (its
    # docstring shows "target_classes: List[str] = <factory>"), but the
    # generator dropped it, making the field required. Restore the empty-list
    # default; the stage registration documents "empty = all".
    target_classes: List[str] = Field(default_factory=list)
class OCRConfig(BaseModel):
    """OCR stage settings.

    Mirrors the source dataclass
    ``OCRConfig(languages=<factory>, min_confidence=0.5)``.
    """

    # The source dataclass declares a default factory here (its docstring
    # shows "languages: List[str] = <factory>"), but the generator dropped
    # it, making the field required. The stage registration's default is
    # ["en"], so restore that — TODO confirm against the source dataclass.
    languages: List[str] = Field(default_factory=lambda: ["en"])
    min_confidence: float = 0.5  # min OCR confidence to keep a text candidate
class ResolverConfig(BaseModel):
    """Brand resolver settings."""

    fuzzy_threshold: int = 75  # rapidfuzz match threshold (0-100)
class RegionAnalysisConfig(BaseModel):
    """Edge-detection (region analysis) tuning parameters."""

    enabled: bool = True  # master switch for the stage
    edge_canny_low: int = 50  # Canny low threshold
    edge_canny_high: int = 150  # Canny high threshold
    edge_hough_threshold: int = 80  # Hough accumulator threshold
    edge_hough_min_length: int = 100  # min line length (px)
    edge_hough_max_gap: int = 10  # max line gap (px)
    edge_pair_max_distance: int = 200  # max distance between line pair (px)
    edge_pair_min_distance: int = 15  # min distance between line pair (px)
class FieldSegmentationConfig(BaseModel):
    """Field (pitch) segmentation tuning parameters — HSV green mask."""

    enabled: bool = True  # master switch for the stage
    hue_low: int = 30  # HSV hue lower bound (0-180)
    hue_high: int = 85  # HSV hue upper bound (0-180)
    sat_low: int = 30  # HSV saturation lower bound (0-255)
    sat_high: int = 255  # HSV saturation upper bound (0-255)
    val_low: int = 30  # HSV value lower bound (0-255)
    val_high: int = 255  # HSV value upper bound (0-255)
    morph_kernel: int = 15  # morphology kernel size
    min_area_ratio: float = 0.05  # min contour area as fraction of the frame
class StageRef(BaseModel):
    """Reference to a stage in the pipeline graph."""

    name: str  # registry name of the stage
    branch: str = "trunk"  # graph branch this stage belongs to
    execution_target: str = "local"  # where the stage executes, e.g. "local"
class Edge(BaseModel):
    """Connection between stages in the graph."""

    source: str  # name of the upstream stage
    target: str  # name of the downstream stage
    condition: str = ""  # optional routing condition; "" means unconditional
class PipelineConfig(BaseModel):
    """Pipeline graph topology + routing rules."""

    name: str  # pipeline name
    profile_name: str  # owning profile
    stages: List[StageRef] = Field(default_factory=list)  # nodes of the graph
    edges: List[Edge] = Field(default_factory=list)  # connections between stages
    # Was a bare ``Dict[str, Any]`` with no default, which made the field
    # required — inconsistent with stages/edges above and with builders like
    # pipeline_config_from_dict, which treat missing routing rules as {}.
    routing_rules: Dict[str, Any] = Field(default_factory=dict)

View File

@@ -18,9 +18,9 @@ from typing import TYPE_CHECKING
import numpy as np
from detect import emit
from detect.models import BoundingBox, Frame, TextCandidate
from detect.profiles.base import OCRConfig
from core.detect import emit
from core.detect.models import BoundingBox, Frame, TextCandidate
from core.detect.stages.models import OCRConfig
if TYPE_CHECKING:
pass
@@ -91,8 +91,8 @@ def run_ocr(
# Build these once per pipeline run, not per crop
if inference_url:
from detect.inference import InferenceClient
from detect.emit import _run_log_level
from core.detect.inference import InferenceClient
from core.detect.emit import _run_log_level
client = InferenceClient(base_url=inference_url, job_id=job_id or "", log_level=_run_log_level)
else:
model = _get_local_model(config.languages[0])

View File

@@ -15,8 +15,8 @@ import logging
import numpy as np
from detect import emit
from detect.models import BoundingBox, Frame
from core.detect import emit
from core.detect.models import BoundingBox, Frame
logger = logging.getLogger(__name__)
@@ -124,5 +124,5 @@ def _preprocess_remote(crop: np.ndarray, inference_url: str,
def _preprocess_local(crop: np.ndarray,
do_contrast: bool, do_deskew: bool, do_binarize: bool) -> np.ndarray:
"""Run preprocessing in-process (requires opencv-python-headless)."""
from gpu.models.preprocess import preprocess
from core.gpu.models.preprocess import preprocess
return preprocess(crop, do_binarize=do_binarize, do_deskew=do_deskew, do_contrast=do_contrast)

View File

@@ -0,0 +1,44 @@
"""Registration for CV analysis stages: edge detection."""
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import serialize_dataclass_list, deserialize_bounding_box
def _ser_regions(state: dict, job_id: str) -> dict:
    """Serialize edge regions: {seq: [BoundingBox]} -> {str(seq): [dict]}."""
    by_frame = state.get("edge_regions_by_frame", {})
    return {
        "edge_regions_by_frame": {
            str(seq): serialize_dataclass_list(box_list)
            for seq, box_list in by_frame.items()
        }
    }
def _deser_regions(data: dict, job_id: str) -> dict:
    """Inverse of _ser_regions: restore int keys and BoundingBox objects."""
    restored = {
        int(seq_str): [deserialize_bounding_box(item) for item in box_dicts]
        for seq_str, box_dicts in data.get("edge_regions_by_frame", {}).items()
    }
    return {"edge_regions_by_frame": restored}
def register():
    """Register the edge-detection stage via the legacy registration path.

    Builds the StageDefinition and hands it to register_stage together
    with this module's serialize/deserialize hooks for its output blob.
    """
    edge_detection = StageDefinition(
        name="detect_edges",
        label="Edge Detection",
        description="Canny + HoughLinesP — find horizontal line pairs (hoarding boundaries)",
        category="cv_analysis",
        io=StageIO(
            reads=["filtered_frames"],
            writes=["edge_regions_by_frame"],
        ),
        config_fields=[
            StageConfigField(name="enabled", type="bool", default=True, description="Enable region analysis"),
            StageConfigField(name="edge_canny_low", type="int", default=50, description="Canny low threshold", min=0, max=255),
            StageConfigField(name="edge_canny_high", type="int", default=150, description="Canny high threshold", min=0, max=255),
            StageConfigField(name="edge_hough_threshold", type="int", default=80, description="Hough accumulator threshold", min=1, max=500),
            StageConfigField(name="edge_hough_min_length", type="int", default=100, description="Min line length (px)", min=10, max=2000),
            StageConfigField(name="edge_hough_max_gap", type="int", default=10, description="Max line gap (px)", min=1, max=100),
            StageConfigField(name="edge_pair_max_distance", type="int", default=200, description="Max distance between line pair (px)", min=10, max=500),
            StageConfigField(name="edge_pair_min_distance", type="int", default=15, description="Min distance between line pair (px)", min=5, max=200),
        ],
    )
    register_stage(edge_detection, serialize_fn=_ser_regions, deserialize_fn=_deser_regions)

View File

@@ -1,6 +1,7 @@
"""Registration for detection stages: YOLO, OCR."""
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import (
serialize_dataclass_list,
serialize_text_candidates,
@@ -38,14 +39,12 @@ def register():
category="detection",
io=StageIO(reads=["filtered_frames"], writes=["boxes_by_frame"]),
config_fields=[
StageConfigField("model_name", "str", "yolov8n.pt", "YOLO model file"),
StageConfigField("confidence_threshold", "float", 0.3, "Min detection confidence", min=0.0, max=1.0),
StageConfigField("target_classes", "list[str]", [], "YOLO classes to detect (empty = all)"),
StageConfigField(name="model_name", type="str", default="yolov8n.pt", description="YOLO model file"),
StageConfigField(name="confidence_threshold", type="float", default=0.3, description="Min detection confidence", min=0.0, max=1.0),
StageConfigField(name="target_classes", type="list[str]", default=[], description="YOLO classes to detect (empty = all)"),
],
serialize_fn=_ser_detect,
deserialize_fn=_deser_detect,
)
register_stage(yolo)
register_stage(yolo, serialize_fn=_ser_detect, deserialize_fn=_deser_detect)
ocr = StageDefinition(
name="run_ocr",
@@ -54,10 +53,8 @@ def register():
category="detection",
io=StageIO(reads=["filtered_frames", "boxes_by_frame"], writes=["text_candidates"]),
config_fields=[
StageConfigField("languages", "list[str]", ["en"], "OCR languages"),
StageConfigField("min_confidence", "float", 0.5, "Min OCR confidence", min=0.0, max=1.0),
StageConfigField(name="languages", type="list[str]", default=["en"], description="OCR languages"),
StageConfigField(name="min_confidence", type="float", default=0.5, description="Min OCR confidence", min=0.0, max=1.0),
],
serialize_fn=_ser_ocr,
deserialize_fn=_deser_ocr,
)
register_stage(ocr)
register_stage(ocr, serialize_fn=_ser_ocr, deserialize_fn=_deser_ocr)

View File

@@ -1,6 +1,7 @@
"""Registration for escalation stages: local VLM, cloud LLM."""
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import (
serialize_dataclass_list,
serialize_text_candidates,
@@ -37,12 +38,10 @@ def register():
optional_reads=["source_asset_id"],
),
config_fields=[
StageConfigField("min_confidence", "float", 0.5, "Min VLM confidence", min=0.0, max=1.0),
StageConfigField(name="min_confidence", type="float", default=0.5, description="Min VLM confidence", min=0.0, max=1.0),
],
serialize_fn=_ser_escalation,
deserialize_fn=_deser_escalation,
)
register_stage(vlm)
register_stage(vlm, serialize_fn=_ser_escalation, deserialize_fn=_deser_escalation)
cloud = StageDefinition(
name="escalate_cloud",
@@ -55,9 +54,7 @@ def register():
optional_reads=["source_asset_id"],
),
config_fields=[
StageConfigField("min_confidence", "float", 0.4, "Min cloud confidence", min=0.0, max=1.0),
StageConfigField(name="min_confidence", type="float", default=0.4, description="Min cloud confidence", min=0.0, max=1.0),
],
serialize_fn=_ser_escalation,
deserialize_fn=_deser_escalation,
)
register_stage(cloud)
register_stage(cloud, serialize_fn=_ser_escalation, deserialize_fn=_deser_escalation)

View File

@@ -1,6 +1,6 @@
"""Registration for output stages: report compilation."""
from detect.stages.base import StageDefinition, StageIO, register_stage
from core.detect.stages.base import StageDefinition, StageIO, register_stage
from ._serializers import serialize_dataclass, deserialize_detection_report
@@ -26,7 +26,5 @@ def register():
category="output",
io=StageIO(reads=["detections"], writes=["report"]),
config_fields=[],
serialize_fn=_ser_report,
deserialize_fn=_deser_report,
)
register_stage(report)
register_stage(report, serialize_fn=_ser_report, deserialize_fn=_deser_report)

View File

@@ -1,6 +1,7 @@
"""Registration for preprocessing stages: frame extraction, scene filter, image preprocessing."""
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import serialize_frames, deserialize_frames
@@ -44,13 +45,11 @@ def register():
category="preprocessing",
io=StageIO(reads=["video_path"], writes=["frames"]),
config_fields=[
StageConfigField("fps", "float", 2.0, "Frames per second", min=0.1, max=30.0),
StageConfigField("max_frames", "int", 500, "Maximum frames to extract", min=1, max=10000),
StageConfigField(name="fps", type="float", default=2.0, description="Frames per second", min=0.1, max=30.0),
StageConfigField(name="max_frames", type="int", default=500, description="Maximum frames to extract", min=1, max=10000),
],
serialize_fn=_ser_extract,
deserialize_fn=_deser_extract,
)
register_stage(extract)
register_stage(extract, serialize_fn=_ser_extract, deserialize_fn=_deser_extract)
scene_filter = StageDefinition(
name="filter_scenes",
@@ -59,13 +58,11 @@ def register():
category="preprocessing",
io=StageIO(reads=["frames"], writes=["filtered_frames"]),
config_fields=[
StageConfigField("hamming_threshold", "int", 8, "Hamming distance threshold", min=0, max=64),
StageConfigField("enabled", "bool", True, "Enable scene filtering"),
StageConfigField(name="hamming_threshold", type="int", default=8, description="Hamming distance threshold", min=0, max=64),
StageConfigField(name="enabled", type="bool", default=True, description="Enable scene filtering"),
],
serialize_fn=_ser_filter,
deserialize_fn=_deser_filter,
)
register_stage(scene_filter)
register_stage(scene_filter, serialize_fn=_ser_filter, deserialize_fn=_deser_filter)
preprocess = StageDefinition(
name="preprocess",
@@ -77,11 +74,9 @@ def register():
writes=["preprocessed_crops"],
),
config_fields=[
StageConfigField("contrast", "bool", True, "CLAHE contrast enhancement"),
StageConfigField("deskew", "bool", False, "Correct slight rotation"),
StageConfigField("binarize", "bool", False, "Otsu binarization"),
StageConfigField(name="contrast", type="bool", default=True, description="CLAHE contrast enhancement"),
StageConfigField(name="deskew", type="bool", default=False, description="Correct slight rotation"),
StageConfigField(name="binarize", type="bool", default=False, description="Otsu binarization"),
],
serialize_fn=_ser_preprocess,
deserialize_fn=_deser_preprocess,
)
register_stage(preprocess)
register_stage(preprocess, serialize_fn=_ser_preprocess, deserialize_fn=_deser_preprocess)

View File

@@ -1,6 +1,7 @@
"""Registration for resolution stages: brand resolver."""
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import (
serialize_dataclass_list,
serialize_text_candidates,
@@ -37,9 +38,7 @@ def register():
optional_reads=["session_brands", "source_asset_id"],
),
config_fields=[
StageConfigField("fuzzy_threshold", "int", 75, "Fuzzy match threshold", min=0, max=100),
StageConfigField(name="fuzzy_threshold", type="int", default=75, description="Fuzzy match threshold", min=0, max=100),
],
serialize_fn=_ser_brands,
deserialize_fn=_deser_brands,
)
register_stage(resolver)
register_stage(resolver, serialize_fn=_ser_brands, deserialize_fn=_deser_brands)

View File

@@ -14,9 +14,9 @@ import time
import imagehash
from PIL import Image
from detect import emit
from detect.models import Frame
from detect.profiles.base import SceneFilterConfig
from core.detect import emit
from core.detect.models import Frame
from core.detect.stages.models import SceneFilterConfig
def _compute_hashes(frames: list[Frame]) -> list[imagehash.ImageHash]:

View File

@@ -19,10 +19,10 @@ import time
import numpy as np
from PIL import Image
from detect import emit
from detect.models import BrandDetection, PipelineStats, TextCandidate
from detect.profiles.base import CropContext
from detect.providers import get_provider, has_api_key
from core.detect import emit
from core.detect.models import BrandDetection, PipelineStats, TextCandidate
from core.detect.models import CropContext
from core.detect.providers import get_provider, has_api_key
logger = logging.getLogger(__name__)
@@ -33,7 +33,7 @@ def _register_discovered_brand(brand: str, source_asset_id: str | None,
timestamp: float, confidence: float):
"""Register a cloud-confirmed brand in the DB."""
try:
from detect.stages.brand_resolver import _register_brand, _record_sighting
from core.detect.stages.brand_resolver import _register_brand, _record_sighting
brand_id = _register_brand(brand, "cloud_llm")
if brand_id and source_asset_id:
_record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, "cloud_llm")

View File

@@ -14,9 +14,9 @@ import time
import numpy as np
from detect import emit
from detect.models import BrandDetection, TextCandidate
from detect.profiles.base import CropContext
from core.detect import emit
from core.detect.models import BrandDetection, TextCandidate
from core.detect.models import CropContext
logger = logging.getLogger(__name__)
@@ -25,7 +25,7 @@ def _register_discovered_brand(brand: str, source_asset_id: str | None,
timestamp: float, confidence: float, source: str):
"""Register a VLM-confirmed brand in the DB."""
try:
from detect.stages.brand_resolver import _register_brand, _record_sighting
from core.detect.stages.brand_resolver import _register_brand, _record_sighting
brand_id = _register_brand(brand, source)
if brand_id and source_asset_id:
_record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, source)
@@ -75,8 +75,8 @@ def escalate_vlm(
still_unresolved: list[TextCandidate] = []
if inference_url:
from detect.inference import InferenceClient
from detect.emit import _run_log_level
from core.detect.inference import InferenceClient
from core.detect.emit import _run_log_level
client = InferenceClient(base_url=inference_url, job_id=job_id or "", log_level=_run_log_level)
for i, candidate in enumerate(candidates):
@@ -152,6 +152,6 @@ def escalate_vlm(
def _vlm_local(crop: np.ndarray, prompt: str) -> tuple[str, float, str]:
"""Run moondream2 in-process (single-box mode)."""
from gpu.models.vlm import query
from core.gpu.models.vlm import query
result = query(crop, prompt)
return result["brand"], result["confidence"], result["reasoning"]

View File

@@ -18,9 +18,9 @@ import time
from PIL import Image
from detect import emit
from detect.models import BoundingBox, Frame
from detect.profiles.base import DetectionConfig
from core.detect import emit
from core.detect.models import BoundingBox, Frame
from core.detect.stages.models import DetectionConfig
logger = logging.getLogger(__name__)
@@ -36,7 +36,7 @@ def _frame_to_b64(frame: Frame) -> str:
def _detect_remote(frame: Frame, config: DetectionConfig, inference_url: str,
job_id: str = "", log_level: str = "INFO") -> list[BoundingBox]:
"""Call the inference server over HTTP."""
from detect.inference import InferenceClient
from core.detect.inference import InferenceClient
client = InferenceClient(base_url=inference_url, job_id=job_id, log_level=log_level)
results = client.detect(
image=frame.image,
@@ -104,7 +104,7 @@ def detect_objects(
for i, frame in enumerate(frames):
t0 = time.monotonic()
if inference_url:
from detect.emit import _run_log_level
from core.detect.emit import _run_log_level
boxes = _detect_remote(frame, config, inference_url,
job_id=job_id or "", log_level=_run_log_level)
else:

View File

@@ -9,7 +9,7 @@ from __future__ import annotations
from typing import TypedDict
from detect.models import BoundingBox, BrandDetection, DetectionReport, Frame, PipelineStats, TextCandidate
from core.detect.models import BoundingBox, BrandDetection, DetectionReport, Frame, PipelineStats, TextCandidate
class DetectState(TypedDict, total=False):
@@ -22,6 +22,9 @@ class DetectState(TypedDict, total=False):
# Stage outputs
frames: list[Frame]
filtered_frames: list[Frame]
field_masks: dict # {seq: np.ndarray} — pitch mask per frame
field_boundaries: dict # {seq: [(x,y), ...]} — pitch boundary per frame
field_coverage: dict # {seq: float} — pitch coverage ratio per frame
edge_regions_by_frame: dict[int, list[BoundingBox]]
boxes_by_frame: dict[int, list[BoundingBox]]
preprocessed_crops: dict # "{frame_seq}_{box_idx}" → np.ndarray
@@ -36,5 +39,5 @@ class DetectState(TypedDict, total=False):
# Running stats (updated by each stage)
stats: PipelineStats
# Config overrides for replay (applied via OverrideProfile)
# Config overrides for replay (merged into profile configs dict)
config_overrides: dict

View File

@@ -6,7 +6,7 @@ and stage-level metadata. The Langfuse client is optional — if not configured
(no LANGFUSE_SECRET_KEY), tracing is a no-op.
Usage in graph nodes:
from detect.tracing import trace_node
from core.detect.tracing import trace_node
def node_extract_frames(state):
with trace_node(state, "extract_frames") as span:

View File

@@ -1,13 +1,6 @@
from .capabilities import get_decoders, get_encoders, get_formats
from .probe import ProbeResult, probe_file
from .transcode import TranscodeConfig, transcode
__all__ = [
"probe_file",
"ProbeResult",
"transcode",
"TranscodeConfig",
"get_encoders",
"get_decoders",
"get_formats",
]

View File

@@ -1,145 +0,0 @@
"""
FFmpeg capabilities - Discover available codecs and formats using ffmpeg-python.
"""
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, Dict, List
import ffmpeg
@dataclass
class Codec:
    """An FFmpeg encoder or decoder."""
    name: str  # FFmpeg codec identifier, e.g. "libx264"
    description: str  # human-readable description
    type: str  # 'video' or 'audio'
@dataclass
class Format:
    """An FFmpeg format (muxer/demuxer)."""
    name: str  # container identifier, e.g. "mp4"
    description: str  # human-readable description
    can_demux: bool  # usable as an input format
    can_mux: bool  # usable as an output format
@lru_cache(maxsize=1)
def _get_ffmpeg_info() -> Dict[str, Any]:
"""Get FFmpeg capabilities info."""
# ffmpeg-python doesn't have a direct way to get codecs/formats
# but we can use probe on a dummy or parse -codecs output
# For now, return common codecs that are typically available
return {
"video_encoders": [
{"name": "libx264", "description": "H.264 / AVC"},
{"name": "libx265", "description": "H.265 / HEVC"},
{"name": "mpeg4", "description": "MPEG-4 Part 2"},
{"name": "libvpx", "description": "VP8"},
{"name": "libvpx-vp9", "description": "VP9"},
{"name": "h264_nvenc", "description": "NVIDIA NVENC H.264"},
{"name": "hevc_nvenc", "description": "NVIDIA NVENC H.265"},
{"name": "h264_vaapi", "description": "VAAPI H.264"},
{"name": "prores_ks", "description": "Apple ProRes"},
{"name": "dnxhd", "description": "Avid DNxHD/DNxHR"},
{"name": "copy", "description": "Stream copy (no encoding)"},
],
"audio_encoders": [
{"name": "aac", "description": "AAC"},
{"name": "libmp3lame", "description": "MP3"},
{"name": "libopus", "description": "Opus"},
{"name": "libvorbis", "description": "Vorbis"},
{"name": "pcm_s16le", "description": "PCM signed 16-bit little-endian"},
{"name": "flac", "description": "FLAC"},
{"name": "copy", "description": "Stream copy (no encoding)"},
],
"formats": [
{"name": "mp4", "description": "MP4", "can_demux": True, "can_mux": True},
{
"name": "mov",
"description": "QuickTime / MOV",
"can_demux": True,
"can_mux": True,
},
{
"name": "mkv",
"description": "Matroska",
"can_demux": True,
"can_mux": True,
},
{"name": "webm", "description": "WebM", "can_demux": True, "can_mux": True},
{"name": "avi", "description": "AVI", "can_demux": True, "can_mux": True},
{"name": "flv", "description": "FLV", "can_demux": True, "can_mux": True},
{
"name": "ts",
"description": "MPEG-TS",
"can_demux": True,
"can_mux": True,
},
{
"name": "mpegts",
"description": "MPEG-TS",
"can_demux": True,
"can_mux": True,
},
{"name": "hls", "description": "HLS", "can_demux": True, "can_mux": True},
],
}
def get_encoders() -> List[Codec]:
    """Get available encoders (video + audio), video entries first."""
    info = _get_ffmpeg_info()
    encoders: List[Codec] = []
    for kind, key in (("video", "video_encoders"), ("audio", "audio_encoders")):
        for entry in info[key]:
            encoders.append(
                Codec(name=entry["name"], description=entry["description"], type=kind)
            )
    return encoders
def get_decoders() -> List[Codec]:
    """Get available decoders.

    The static capability table does not distinguish decoders, and most
    listed encoders can also decode, so the encoder list is reused as-is.
    """
    return get_encoders()
def get_formats() -> List[Format]:
    """Get available container formats."""
    formats: List[Format] = []
    for entry in _get_ffmpeg_info()["formats"]:
        formats.append(
            Format(
                name=entry["name"],
                description=entry["description"],
                can_demux=entry["can_demux"],
                can_mux=entry["can_mux"],
            )
        )
    return formats
def get_video_encoders() -> List[Codec]:
    """Get available video encoders."""
    return [codec for codec in get_encoders() if codec.type == "video"]
def get_audio_encoders() -> List[Codec]:
    """Get available audio encoders."""
    return [codec for codec in get_encoders() if codec.type == "audio"]
def get_muxers() -> List[Format]:
    """Get available output formats (muxers)."""
    return [fmt for fmt in get_formats() if fmt.can_mux]
def get_demuxers() -> List[Format]:
    """Get available input formats (demuxers)."""
    return [fmt for fmt in get_formats() if fmt.can_demux]

View File

@@ -1,225 +0,0 @@
"""
FFmpeg transcode module - Transcode media files using ffmpeg-python.
"""
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
import ffmpeg
@dataclass
class TranscodeConfig:
    """Configuration for a transcode operation."""
    input_path: str  # source media file
    output_path: str  # destination file (parent directory is created by transcode())
    # Video
    video_codec: str = "libx264"  # "copy" = pass the stream through unmodified
    video_bitrate: Optional[str] = None  # e.g. "2M"; ignored when video_crf is set
    video_crf: Optional[int] = None  # constant-rate-factor; takes priority over bitrate
    video_preset: Optional[str] = None  # encoder preset, e.g. "fast"
    resolution: Optional[str] = None  # output size string passed as -s, e.g. "1280x720"
    framerate: Optional[float] = None  # output frame rate passed as -r
    # Audio
    audio_codec: str = "aac"  # "copy" = pass the stream through unmodified
    audio_bitrate: Optional[str] = None  # e.g. "128k"
    audio_channels: Optional[int] = None  # channel count passed as -ac
    audio_samplerate: Optional[int] = None  # sample rate in Hz passed as -ar
    # Trimming
    trim_start: Optional[float] = None  # seek point in seconds (input-side -ss)
    trim_end: Optional[float] = None  # absolute end time in seconds (converted to -t)
    # Container
    container: str = "mp4"
    # Extra args (key-value pairs)
    extra_args: List[str] = field(default_factory=list)  # raw "-key value" CLI pairs

    @property
    def is_copy(self) -> bool:
        """Check if this is a stream copy (no transcoding)."""
        return self.video_codec == "copy" and self.audio_codec == "copy"
def build_stream(config: TranscodeConfig):
    """
    Build an ffmpeg-python stream from config.

    Returns the stream object ready to run.
    """
    # Seek before decoding (-ss as an input option) when trimming the start.
    in_opts = {}
    if config.trim_start is not None:
        in_opts["ss"] = config.trim_start
    stream = ffmpeg.input(config.input_path, **in_opts)

    out_opts = {
        "vcodec": config.video_codec,
        "acodec": config.audio_codec,
    }

    # Output duration (-t): relative to the seek point when one was given.
    if config.trim_end is not None:
        if config.trim_start is not None:
            out_opts["t"] = config.trim_end - config.trim_start
        else:
            out_opts["t"] = config.trim_end

    # Video encoding options — meaningless for a stream copy.
    if config.video_codec != "copy":
        # CRF wins over a target bitrate when both are configured.
        if config.video_crf is not None:
            out_opts["crf"] = config.video_crf
        elif config.video_bitrate:
            out_opts["video_bitrate"] = config.video_bitrate
        if config.video_preset:
            out_opts["preset"] = config.video_preset
        if config.resolution:
            out_opts["s"] = config.resolution
        if config.framerate:
            out_opts["r"] = config.framerate

    # Audio encoding options — meaningless for a stream copy.
    if config.audio_codec != "copy":
        if config.audio_bitrate:
            out_opts["audio_bitrate"] = config.audio_bitrate
        if config.audio_channels:
            out_opts["ac"] = config.audio_channels
        if config.audio_samplerate:
            out_opts["ar"] = config.audio_samplerate

    # Caller-supplied raw "-key value" pairs override/extend the above.
    out_opts.update(parse_extra_args(config.extra_args))

    stream = ffmpeg.output(stream, config.output_path, **out_opts)
    return ffmpeg.overwrite_output(stream)
def parse_extra_args(extra_args: List[str]) -> Dict[str, Any]:
    """
    Parse extra args list into kwargs dict.

    ["-vtag", "xvid", "-pix_fmt", "yuv420p"] -> {"vtag": "xvid", "pix_fmt": "yuv420p"}

    A "-flag" not followed by a value maps to None.
    """
    parsed: Dict[str, Any] = {}
    idx = 0
    total = len(extra_args)
    while idx < total:
        name = extra_args[idx].lstrip("-")
        # The next token is this key's value unless it is itself an option.
        has_value = idx + 1 < total and not extra_args[idx + 1].startswith("-")
        if has_value:
            parsed[name] = extra_args[idx + 1]
            idx += 2
        else:
            parsed[name] = None  # bare flag without a value
            idx += 1
    return parsed
def transcode(
    config: TranscodeConfig,
    duration: Optional[float] = None,
    progress_callback: Optional[Callable[[float, Dict[str, Any]], None]] = None,
) -> bool:
    """
    Transcode a media file.

    Args:
        config: Transcode configuration
        duration: Total duration in seconds (for progress calculation, optional)
        progress_callback: Called with (percent, details_dict) - requires duration

    Returns:
        True if successful

    Raises:
        ffmpeg.Error: If transcoding fails
    """
    # Ensure output directory exists
    Path(config.output_path).parent.mkdir(parents=True, exist_ok=True)
    stream = build_stream(config)
    if progress_callback and duration:
        # Progress tracking needs BOTH a callback and a known duration;
        # missing either falls through to the blocking run below.
        return _run_with_progress(stream, config, duration, progress_callback)
    else:
        # Run synchronously; capturing output lets ffmpeg.Error carry stderr.
        ffmpeg.run(stream, capture_stdout=True, capture_stderr=True)
        return True
def _run_with_progress(
    stream,
    config: TranscodeConfig,
    duration: float,
    progress_callback: Callable[[float, Dict[str, Any]], None],
) -> bool:
    """Run FFmpeg with progress tracking using run_async and stderr parsing.

    Scans ffmpeg's stderr for ``time=HH:MM:SS.cc`` lines and reports percent
    progress against the effective (post-trim) duration. Raises ffmpeg.Error
    on a non-zero exit code.
    """
    import re
    # Effective duration shrinks when the config trims the clip.
    # NOTE(review): truthiness tests mean trim_start=0.0 counts as "no trim";
    # harmless here since subtracting 0 leaves the duration unchanged.
    effective_duration = duration
    if config.trim_start and config.trim_end:
        effective_duration = config.trim_end - config.trim_start
    elif config.trim_end:
        effective_duration = config.trim_end
    elif config.trim_start:
        effective_duration = duration - config.trim_start
    # Run async to get a process handle whose stderr we can stream.
    process = ffmpeg.run_async(stream, pipe_stdout=True, pipe_stderr=True)
    # ffmpeg progress lines look like "time=00:01:23.45"; the fractional part
    # is centiseconds, hence the /100 below.
    time_pattern = re.compile(r"time=(\d+):(\d+):(\d+)\.(\d+)")
    while True:
        line = process.stderr.readline()
        if not line:
            break  # EOF — ffmpeg closed stderr, i.e. it is finishing
        line = line.decode("utf-8", errors="ignore")
        match = time_pattern.search(line)
        if match:
            hours = int(match.group(1))
            minutes = int(match.group(2))
            seconds = int(match.group(3))
            ms = int(match.group(4))
            current_time = hours * 3600 + minutes * 60 + seconds + ms / 100
            # Clamp: ffmpeg may report slightly past the computed duration.
            percent = min(100.0, (current_time / effective_duration) * 100)
            progress_callback(
                percent,
                {
                    "time": current_time,
                    "percent": percent,
                },
            )
    # NOTE(review): stdout is piped but only drained after exit; if ffmpeg
    # wrote bulk data to stdout the pipe could fill and stall the process —
    # confirm output always goes to the file, never stdout.
    process.wait()
    if process.returncode != 0:
        raise ffmpeg.Error(
            "ffmpeg", stdout=process.stdout.read(), stderr=process.stderr.read()
        )
    # Always finish with an explicit 100% / done callback.
    progress_callback(
        100.0, {"time": effective_duration, "percent": 100.0, "done": True}
    )
    return True

View File

@@ -0,0 +1,6 @@
# GPU models — standalone container imports.
# When running as a container (cd gpu && python server.py), bare imports work.
# When imported from the main app (core.gpu.models.preprocess), only
# individual modules should be imported directly, not this __init__.
#
# The server.py imports detect/ocr/vlm directly, not through this file.

View File

@@ -0,0 +1,86 @@
"""
Field segmentation — HSV green mask → pitch boundary contour.
Pure OpenCV. Called by the inference server endpoint.
"""
from __future__ import annotations
import base64
import cv2
import numpy as np
def segment_field(
    image: np.ndarray,
    hue_low: int = 30,
    hue_high: int = 85,
    sat_low: int = 30,
    sat_high: int = 255,
    val_low: int = 30,
    val_high: int = 255,
    morph_kernel: int = 15,
    min_area_ratio: float = 0.05,
) -> dict:
    """
    Detect the pitch area using HSV green thresholding.

    Returns dict with:
        boundary: list of [x, y] points
        coverage: float (fraction of frame)
        mask: binary uint8 mask of the detected pitch
    """
    height, width = image.shape[:2]

    # Threshold in HSV space: everything inside the configured green band.
    hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
    mask = cv2.inRange(
        hsv,
        np.array([hue_low, sat_low, val_low]),
        np.array([hue_high, sat_high, val_high]),
    )

    # Close small holes inside the field, then drop small specks outside it.
    kernel = cv2.getStructuringElement(
        cv2.MORPH_ELLIPSE, (morph_kernel, morph_kernel)
    )
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    boundary: list = []
    coverage = 0.0
    min_area = min_area_ratio * height * width
    candidates = [c for c in contours if cv2.contourArea(c) >= min_area]
    if candidates:
        pitch_contour = max(candidates, key=cv2.contourArea)
        boundary = pitch_contour.reshape(-1, 2).tolist()
        coverage = cv2.contourArea(pitch_contour) / (height * width)
        # Replace the raw threshold mask with just the pitch contour, filled.
        mask = np.zeros_like(mask)
        cv2.drawContours(mask, [pitch_contour], -1, 255, cv2.FILLED)

    return {
        "boundary": boundary,
        "coverage": coverage,
        "mask": mask,
    }
def segment_field_debug(
    image: np.ndarray,
    **kwargs,
) -> dict:
    """Same as segment_field but includes a mask overlay for the editor."""
    result = segment_field(image, **kwargs)
    # Pop the raw mask — it must not be sent over HTTP.
    pitch_mask = result.pop("mask")

    # RGBA overlay: solid green where the pitch is, fully transparent elsewhere.
    height, width = image.shape[:2]
    overlay = np.zeros((height, width, 4), dtype=np.uint8)
    overlay[pitch_mask > 0] = [0, 255, 0, 255]

    _, png = cv2.imencode(".png", overlay)
    result["mask_overlay_b64"] = base64.b64encode(png.tobytes()).decode()
    return result

View File

@@ -101,6 +101,30 @@ class AnalyzeRegionsDebugResponse(BaseModel):
horizontal_count: int = 0
pair_count: int = 0
class SegmentFieldRequest(BaseModel):
    """Request body for field segmentation."""
    image: str  # encoded frame; decoded server-side by _decode_image (presumably base64 — confirm)
    hue_low: int = 30  # HSV green-band bounds used for pitch thresholding
    hue_high: int = 85
    sat_low: int = 30
    sat_high: int = 255
    val_low: int = 30
    val_high: int = 255
    morph_kernel: int = 15  # ellipse kernel size for the morphological close/open passes
    min_area_ratio: float = 0.05  # contours smaller than this fraction of the frame are ignored
class SegmentFieldResponse(BaseModel):
    """Response from field segmentation."""
    boundary: List[List[int]] = Field(default_factory=list)  # pitch contour as [x, y] points
    coverage: float = 0.0  # fraction of the frame covered by the pitch
    mask_b64: str = ""  # base64-encoded PNG of the binary pitch mask
class SegmentFieldDebugResponse(BaseModel):
    """Response from field segmentation with debug overlay."""
    boundary: List[List[int]] = Field(default_factory=list)  # pitch contour as [x, y] points
    coverage: float = 0.0  # fraction of the frame covered by the pitch
    mask_overlay_b64: str = ""  # base64 PNG, RGBA: green on the pitch, transparent elsewhere
class ConfigUpdate(BaseModel):
"""Request body for updating server configuration."""
device: Optional[str] = None

View File

@@ -54,7 +54,7 @@ def _gpu_log(job_id: str, log_level: str, stage: str, level: str, msg: str):
# --- Request/Response models (generated from core/schema/models/inference.py) ---
from models.inference_contract import (
from models.models import (
AnalyzeRegionsDebugResponse,
AnalyzeRegionsRequest,
AnalyzeRegionsResponse,
@@ -68,6 +68,9 @@ from models.inference_contract import (
PreprocessRequest,
PreprocessResponse,
RegionBox,
SegmentFieldRequest,
SegmentFieldResponse,
SegmentFieldDebugResponse,
VLMRequest,
VLMResponse,
)
@@ -310,6 +313,83 @@ def detect_edges_debug_endpoint(req: AnalyzeRegionsRequest, request: Request):
return response
@app.post("/segment_field", response_model=SegmentFieldResponse)
def segment_field_endpoint(req: SegmentFieldRequest, request: Request):
    """Segment the pitch in one frame; returns boundary, coverage and a b64 PNG mask."""
    job_id, log_level = _job_ctx(request)
    try:
        image = _decode_image(req.image)
        h, w = image.shape[:2]
    except Exception as e:
        # An undecodable payload is the client's fault → 400.
        raise HTTPException(status_code=400, detail=f"Bad image: {e}")
    try:
        t0 = time.monotonic()
        # Lazy import: the CV module is loaded on first request, not at startup.
        from models.cv.segmentation import segment_field
        result = segment_field(
            image,
            hue_low=req.hue_low,
            hue_high=req.hue_high,
            sat_low=req.sat_low,
            sat_high=req.sat_high,
            val_low=req.val_low,
            val_high=req.val_high,
            morph_kernel=req.morph_kernel,
            min_area_ratio=req.min_area_ratio,
        )
        infer_ms = (time.monotonic() - t0) * 1000
        # Encode mask as base64 PNG for downstream use
        import cv2
        _, buf = cv2.imencode(".png", result["mask"])
        mask_b64 = base64.b64encode(buf.tobytes()).decode()
        _gpu_log(job_id, log_level, "GPU:CV", "DEBUG",
                 f"Field segmentation {w}x{h}: {infer_ms:.0f}ms, coverage={result['coverage']:.1%}")
    except Exception as e:
        # Any failure past decoding is treated as a server-side error → 500.
        raise HTTPException(status_code=500, detail=f"Field segmentation failed: {e}")
    return SegmentFieldResponse(
        boundary=result["boundary"],
        coverage=result["coverage"],
        mask_b64=mask_b64,
    )
@app.post("/segment_field/debug", response_model=SegmentFieldDebugResponse)
def segment_field_debug_endpoint(req: SegmentFieldRequest, request: Request):
    """Like /segment_field but returns an RGBA overlay PNG instead of the raw mask."""
    job_id, log_level = _job_ctx(request)
    try:
        image = _decode_image(req.image)
    except Exception as e:
        # An undecodable payload is the client's fault → 400.
        raise HTTPException(status_code=400, detail=f"Bad image: {e}")
    try:
        # Lazy import: the CV module is loaded on first request, not at startup.
        from models.cv.segmentation import segment_field_debug
        result = segment_field_debug(
            image,
            hue_low=req.hue_low,
            hue_high=req.hue_high,
            sat_low=req.sat_low,
            sat_high=req.sat_high,
            val_low=req.val_low,
            val_high=req.val_high,
            morph_kernel=req.morph_kernel,
            min_area_ratio=req.min_area_ratio,
        )
    except Exception as e:
        # Any failure past decoding is treated as a server-side error → 500.
        raise HTTPException(status_code=500, detail=f"Field segmentation debug failed: {e}")
    return SegmentFieldDebugResponse(
        boundary=result["boundary"],
        coverage=result["coverage"],
        mask_overlay_b64=result["mask_overlay_b64"],
    )
if __name__ == "__main__":
import uvicorn

View File

@@ -180,24 +180,11 @@ class WorkerServicer(worker_pb2_grpc.WorkerServiceServicer):
def GetWorkerStatus(self, request, context):
"""Get worker health and capabilities."""
try:
from core.ffmpeg import get_encoders
encoders = get_encoders()
codec_names = [e["name"] for e in encoders.get("video", [])]
except Exception:
codec_names = []
# Check for GPU encoders
gpu_available = any(
"nvenc" in name or "vaapi" in name or "qsv" in name for name in codec_names
)
return worker_pb2.WorkerStatus(
available=True,
active_jobs=len(_active_jobs),
supported_codecs=codec_names[:20], # Limit to 20
gpu_available=gpu_available,
supported_codecs=[],
gpu_available=False,
)

Some files were not shown because too many files have changed in this diff Show More