2026-03-30 07:22:14 -03:00
parent d0707333fd
commit 4220b0418e
182 changed files with 3668 additions and 5231 deletions

View File

@@ -1,73 +0,0 @@
"""
SSE endpoint for chunker pipeline events.
Uses Redis as the event bus: the pipeline pushes events via core.events,
and the SSE endpoint polls them.
GET /chunker/stream/{job_id} → text/event-stream
"""
import asyncio
import json
import logging
import time
from typing import AsyncGenerator
from fastapi import APIRouter
from starlette.responses import StreamingResponse
from core.events import poll_events
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/chunker", tags=["chunker"])
async def _event_generator(job_id: str) -> AsyncGenerator[str, None]:
"""
Generate SSE events by polling Redis for chunk job events.
"""
cursor = 0
timeout = time.monotonic() + 600 # 10 min max
while time.monotonic() < timeout:
events, cursor = poll_events(job_id, cursor)
if not events:
yield f"event: waiting\ndata: {json.dumps({'job_id': job_id})}\n\n"
await asyncio.sleep(0.1)
continue
for data in events:
event_type = data.pop("event", "update")
payload = {**data, "job_id": job_id}
yield f"event: {event_type}\ndata: {json.dumps(payload)}\n\n"
if event_type in ("pipeline_complete", "pipeline_error", "cancelled"):
yield f"event: done\ndata: {json.dumps({'job_id': job_id})}\n\n"
return
await asyncio.sleep(0.05)
yield f"event: timeout\ndata: {json.dumps({'job_id': job_id})}\n\n"
@router.get("/stream/{job_id}")
async def stream_chunk_job(job_id: str):
"""
SSE stream for a chunk pipeline job.
The UI connects via native EventSource:
const es = new EventSource('/api/chunker/stream/<job_id>');
es.addEventListener('processing', (e) => { ... });
"""
return StreamingResponse(
_event_generator(job_id),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
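# --- Usage sketch (illustrative only, not part of this module) ---
# Consuming the stream from Python instead of a browser EventSource.
# The base URL and port are assumptions; adjust them for your deployment.
import asyncio

import httpx


async def _follow_stream(job_id: str, base_url: str = "http://localhost:8000") -> None:
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("GET", f"{base_url}/chunker/stream/{job_id}") as resp:
            async for line in resp.aiter_lines():
                if line:  # blank lines separate SSE events
                    print(line)

# asyncio.run(_follow_stream("<job_id>"))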

View File

@@ -58,32 +58,30 @@ def write_config(update: ConfigUpdate):
@router.get("/config/profiles")
def list_profiles():
def get_profiles():
"""List available detection profiles."""
from detect.profiles import _PROFILES
return [{"name": name} for name in _PROFILES]
from core.detect.profile import list_profiles as _list
return [{"name": name} for name in _list()]
@router.get("/config/profiles/{profile_name}/pipeline")
def get_pipeline_config(profile_name: str):
"""Return the pipeline composition for a profile."""
from detect.profiles import get_profile
from core.detect.profile import get_profile
from fastapi import HTTPException
from dataclasses import asdict
try:
profile = get_profile(profile_name)
except ValueError:
raise HTTPException(status_code=404, detail=f"Unknown profile: {profile_name}")
config = profile.pipeline_config()
return asdict(config)
return profile["pipeline"]
@router.get("/config/stages", response_model=list[StageConfigInfo])
def list_stage_configs():
"""Return the stage palette with config field metadata for the editor."""
from detect.stages import list_stages
from core.detect.stages import list_stages
result = []
for stage in list_stages():
@@ -95,7 +93,7 @@ def list_stage_configs():
@router.get("/config/stages/{stage_name}", response_model=StageConfigInfo)
def get_stage_config(stage_name: str):
"""Return config field metadata for a single stage."""
from detect.stages import get_stage
from core.detect.stages import get_stage
try:
stage = get_stage(stage_name)

View File

@@ -105,7 +105,7 @@ class ReplaySingleStageResponse(BaseModel):
@router.get("/checkpoints/{timeline_id}")
def list_checkpoints(timeline_id: str) -> list[CheckpointInfo]:
"""List available checkpoint stages for a job."""
from detect.checkpoint import list_checkpoints as _list
from core.detect.checkpoint import list_checkpoints as _list
try:
stages = _list(timeline_id)
@@ -139,10 +139,10 @@ class CheckpointData(BaseModel):
def get_checkpoint_data(timeline_id: str, stage: str):
"""Load checkpoint frames + metadata for the editor UI."""
from uuid import UUID
from core.db.tables import Timeline, Checkpoint
from core.db.models import Timeline, Checkpoint
from core.db.connection import get_session
from core.db.checkpoint import list_checkpoints
from detect.checkpoint.frames import load_frames_b64
from core.detect.checkpoint.frames import load_frames_b64
with get_session() as session:
timeline = session.get(Timeline, UUID(timeline_id))
@@ -184,7 +184,7 @@ def get_checkpoint_data(timeline_id: str, stage: str):
@router.get("/scenarios", response_model=list[ScenarioInfo])
def list_scenarios_endpoint():
"""List all available scenarios (bookmarked checkpoints)."""
from core.db.tables import Timeline
from core.db.models import Timeline
from core.db.connection import get_session
from core.db.checkpoint import list_scenarios
@@ -212,7 +212,7 @@ def list_scenarios_endpoint():
@router.post("/replay", response_model=ReplayResponse)
def replay(req: ReplayRequest):
"""Replay pipeline from a specific stage with optional config overrides."""
from detect.checkpoint import replay_from
from core.detect.checkpoint import replay_from
try:
result = replay_from(
@@ -242,7 +242,7 @@ def replay(req: ReplayRequest):
@router.post("/retry", response_model=RetryResponse)
def retry(req: RetryRequest):
"""Queue an async retry of unresolved candidates with different config."""
from detect.checkpoint.tasks import retry_candidates
from core.detect.checkpoint.tasks import retry_candidates
kwargs = {
"timeline_id": req.timeline_id,
@@ -266,7 +266,7 @@ def retry(req: RetryRequest):
@router.post("/replay-stage", response_model=ReplaySingleStageResponse)
def replay_single_stage(req: ReplaySingleStageRequest):
"""Replay a single stage on specific frames — fast path for interactive tuning."""
from detect.checkpoint.replay import replay_single_stage as _replay
from core.detect.checkpoint.replay import replay_single_stage as _replay
try:
result = _replay(
@@ -361,3 +361,41 @@ async def gpu_detect_edges_debug(request: Request):
media_type="application/json")
except Exception as e:
raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}")
@router.post("/gpu/segment_field")
async def gpu_segment_field(request: Request):
"""Proxy to GPU inference server — field segmentation."""
import httpx
body = await request.body()
try:
async with httpx.AsyncClient(timeout=30.0) as client:
resp = await client.post(
f"{_gpu_url()}/segment_field",
content=body,
headers={"Content-Type": "application/json"},
)
return Response(content=resp.content, status_code=resp.status_code,
media_type="application/json")
except Exception as e:
raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}")
@router.post("/gpu/segment_field/debug")
async def gpu_segment_field_debug(request: Request):
"""Proxy to GPU inference server — field segmentation with debug overlay."""
import httpx
body = await request.body()
try:
async with httpx.AsyncClient(timeout=30.0) as client:
resp = await client.post(
f"{_gpu_url()}/segment_field/debug",
content=body,
headers={"Content-Type": "application/json"},
)
return Response(content=resp.content, status_code=resp.status_code,
media_type="application/json")
except Exception as e:
raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}")

View File

@@ -60,9 +60,9 @@ def _resolve_video_path(video_path: str) -> str:
@router.post("/run", response_model=RunResponse)
def run_pipeline(req: RunRequest):
"""Launch a detection pipeline run on a source chunk."""
from detect import emit
from detect.graph import get_pipeline
from detect.state import DetectState
from core.detect import emit
from core.detect.graph import get_pipeline
from core.detect.state import DetectState
local_path = _resolve_video_path(req.video_path)
job_id = str(uuid.uuid4())
@@ -79,7 +79,7 @@ def run_pipeline(req: RunRequest):
# Clear any stale events from a previous run with same job_id
from core.events import _get_redis
from detect.events import DETECT_EVENTS_PREFIX
from core.detect.events import DETECT_EVENTS_PREFIX
r = _get_redis()
r.delete(f"{DETECT_EVENTS_PREFIX}:{job_id}")
@@ -97,7 +97,7 @@ def run_pipeline(req: RunRequest):
source_asset_id=req.source_asset_id,
)
from detect.graph import (
from core.detect.graph import (
PipelineCancelled, set_cancel_check, clear_cancel_check,
init_pause, clear_pause,
)
@@ -117,7 +117,7 @@ def run_pipeline(req: RunRequest):
emit.job_complete(job_id, {"status": "cancelled"})
except Exception as e:
logger.exception("Pipeline run %s failed: %s", job_id, e)
from detect.graph import _node_states, NODES
from core.detect.graph import _node_states, NODES
if job_id in _node_states:
states = _node_states[job_id]
for node in reversed(NODES):
@@ -145,7 +145,7 @@ def run_pipeline(req: RunRequest):
@router.post("/stop/{job_id}")
def stop_pipeline(job_id: str):
"""Stop a running pipeline. Signals cancellation; the thread checks on next stage."""
from detect import emit
from core.detect import emit
if job_id not in _running_jobs:
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
@@ -158,7 +158,7 @@ def stop_pipeline(job_id: str):
@router.post("/pause/{job_id}")
def pause(job_id: str):
"""Pause a running pipeline after the current stage completes."""
from detect.graph import pause_pipeline
from core.detect.graph import pause_pipeline
if job_id not in _running_jobs:
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
@@ -170,7 +170,7 @@ def pause(job_id: str):
@router.post("/resume/{job_id}")
def resume(job_id: str):
"""Resume a paused pipeline."""
from detect.graph import resume_pipeline
from core.detect.graph import resume_pipeline
if job_id not in _running_jobs:
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
@@ -182,7 +182,7 @@ def resume(job_id: str):
@router.post("/step/{job_id}")
def step(job_id: str):
"""Run one stage then pause again."""
from detect.graph import step_pipeline
from core.detect.graph import step_pipeline
if job_id not in _running_jobs:
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
@@ -194,7 +194,7 @@ def step(job_id: str):
@router.post("/pause-after-stage/{job_id}")
def toggle_pause_after_stage(job_id: str, enabled: bool = True):
"""Toggle pause-after-each-stage mode."""
from detect.graph import set_pause_after_stage
from core.detect.graph import set_pause_after_stage
if job_id not in _running_jobs:
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
@@ -206,7 +206,7 @@ def toggle_pause_after_stage(job_id: str, enabled: bool = True):
@router.get("/status/{job_id}")
def pipeline_status(job_id: str):
"""Get pipeline run status."""
from detect.graph import is_paused
from core.detect.graph import is_paused
running = job_id in _running_jobs
paused = is_paused(job_id)
@@ -224,11 +224,23 @@ def pipeline_status(job_id: str):
return {"status": status, "job_id": job_id}
@router.get("/timeline/{job_id}")
def get_timeline_for_job(job_id: str):
"""Get the timeline_id for a running or completed job."""
from core.detect.checkpoint.runner_bridge import get_timeline_id
tid = get_timeline_id(job_id)
if tid is None:
raise HTTPException(status_code=404, detail=f"No timeline for job: {job_id}")
return {"timeline_id": tid, "job_id": job_id}
@router.post("/clear/{job_id}")
def clear_pipeline(job_id: str):
"""Clear events for a job from Redis."""
from core.events import _get_redis
from detect.events import DETECT_EVENTS_PREFIX
from core.detect.events import DETECT_EVENTS_PREFIX
r = _get_redis()
r.delete(f"{DETECT_EVENTS_PREFIX}:{job_id}")

View File

@@ -17,7 +17,7 @@ from fastapi import APIRouter
from starlette.responses import StreamingResponse
from core.events import poll_events
from detect.events import DETECT_EVENTS_PREFIX, TERMINAL_EVENTS
from core.detect.events import DETECT_EVENTS_PREFIX, TERMINAL_EVENTS
logger = logging.getLogger(__name__)

View File

@@ -1,384 +0,0 @@
"""
GraphQL API using strawberry, served via FastAPI.
Primary API for MPR — all client interactions go through GraphQL.
Uses core.db for data access.
Types are generated from schema/ via modelgen — see api/schema/graphql.py.
"""
import os
from typing import List, Optional
from uuid import UUID
import strawberry
from strawberry.schema.config import StrawberryConfig
from strawberry.types import Info
from core.api.schema.graphql import (
CancelResultType,
ChunkJobType,
ChunkOutputFileType,
CreateChunkJobInput,
CreateJobInput,
DeleteResultType,
MediaAssetType,
ScanResultType,
SystemStatusType,
TranscodeJobType,
TranscodePresetType,
UpdateAssetInput,
)
from core.storage import BUCKET_IN, list_objects, upload_file
VIDEO_EXTS = {".mp4", ".mkv", ".avi", ".mov", ".webm", ".flv", ".wmv", ".m4v"}
AUDIO_EXTS = {".mp3", ".wav", ".flac", ".aac", ".ogg", ".m4a"}
MEDIA_EXTS = VIDEO_EXTS | AUDIO_EXTS
# ---------------------------------------------------------------------------
# Queries
# ---------------------------------------------------------------------------
@strawberry.type
class Query:
@strawberry.field
def assets(
self,
info: Info,
status: Optional[str] = None,
search: Optional[str] = None,
) -> List[MediaAssetType]:
from core.db import list_assets
return list_assets(status=status, search=search)
@strawberry.field
def asset(self, info: Info, id: UUID) -> Optional[MediaAssetType]:
from core.db import get_asset
try:
return get_asset(id)
except Exception:
return None
@strawberry.field
def jobs(
self,
info: Info,
status: Optional[str] = None,
source_asset_id: Optional[UUID] = None,
) -> List[TranscodeJobType]:
from core.db import list_jobs
return list_jobs(status=status, source_asset_id=source_asset_id)
@strawberry.field
def job(self, info: Info, id: UUID) -> Optional[TranscodeJobType]:
from core.db import get_job
try:
return get_job(id)
except Exception:
return None
@strawberry.field
def presets(self, info: Info) -> List[TranscodePresetType]:
from core.db import list_presets
return list_presets()
@strawberry.field
def system_status(self, info: Info) -> SystemStatusType:
return SystemStatusType(status="ok", version="0.1.0")
@strawberry.field
def chunk_output_files(self, info: Info, job_id: str) -> List[ChunkOutputFileType]:
"""List output chunk files for a completed job from media/out/."""
from pathlib import Path
media_out = os.environ.get("MEDIA_OUT_DIR", "/app/media/out")
output_dir = Path(media_out) / "chunks" / job_id
if not output_dir.is_dir():
return []
return [
ChunkOutputFileType(
key=f.name,
size=f.stat().st_size,
url=f"/media/out/chunks/{job_id}/{f.name}",
)
for f in sorted(output_dir.iterdir())
if f.is_file()
]
# ---------------------------------------------------------------------------
# Mutations
# ---------------------------------------------------------------------------
@strawberry.type
class Mutation:
@strawberry.mutation
def scan_media_folder(self, info: Info) -> ScanResultType:
import logging
from pathlib import Path
from core.db import create_asset, get_asset_filenames
logger = logging.getLogger(__name__)
# Sync local media/in/ files to MinIO (handles fresh installs / pruned volumes)
local_media = Path("/app/media/in")
if local_media.is_dir():
existing_keys = {o["key"] for o in list_objects(BUCKET_IN)}
for f in local_media.iterdir():
if f.is_file() and f.suffix.lower() in MEDIA_EXTS:
if f.name not in existing_keys:
try:
upload_file(str(f), BUCKET_IN, f.name)
logger.info("Uploaded %s to MinIO", f.name)
except Exception as e:
logger.warning("Failed to upload %s: %s", f.name, e)
objects = list_objects(BUCKET_IN, extensions=MEDIA_EXTS)
existing = get_asset_filenames()
registered = []
skipped = []
for obj in objects:
if obj["filename"] in existing:
skipped.append(obj["filename"])
continue
try:
create_asset(
filename=obj["filename"],
file_path=obj["key"],
file_size=obj["size"],
)
registered.append(obj["filename"])
except Exception:
pass
return ScanResultType(
found=len(objects),
registered=len(registered),
skipped=len(skipped),
files=registered,
)
@strawberry.mutation
def create_job(self, info: Info, input: CreateJobInput) -> TranscodeJobType:
from pathlib import Path
from core.db import create_job, get_asset, get_preset
try:
source = get_asset(input.source_asset_id)
except Exception:
raise Exception("Source asset not found")
preset = None
preset_snapshot = {}
if input.preset_id:
try:
preset = get_preset(input.preset_id)
preset_snapshot = {
"name": preset.name,
"container": preset.container,
"video_codec": preset.video_codec,
"audio_codec": preset.audio_codec,
}
except Exception:
raise Exception("Preset not found")
if not preset and not input.trim_start and not input.trim_end:
raise Exception("Must specify preset_id or trim_start/trim_end")
output_filename = input.output_filename
if not output_filename:
stem = Path(source.filename).stem
ext = preset_snapshot.get("container", "mp4") if preset else "mp4"
output_filename = f"{stem}_output.{ext}"
job = create_job(
source_asset_id=source.id,
preset_id=preset.id if preset else None,
preset_snapshot=preset_snapshot,
trim_start=input.trim_start,
trim_end=input.trim_end,
output_filename=output_filename,
output_path=output_filename,
priority=input.priority or 0,
)
payload = {
"source_key": source.file_path,
"output_key": output_filename,
"preset": preset_snapshot or None,
"trim_start": input.trim_start,
"trim_end": input.trim_end,
"duration": source.duration,
}
executor_mode = os.environ.get("MPR_EXECUTOR", "local")
if executor_mode in ("lambda", "gcp"):
from core.jobs.executor import get_executor
get_executor().run(
job_type="transcode",
job_id=str(job.id),
payload=payload,
)
else:
from core.jobs.task import run_job
result = run_job.delay(
job_type="transcode",
job_id=str(job.id),
payload=payload,
)
job.celery_task_id = result.id
job.save(update_fields=["celery_task_id"])
return job
@strawberry.mutation
def cancel_job(self, info: Info, id: UUID) -> TranscodeJobType:
from core.db import get_job, update_job
try:
job = get_job(id)
except Exception:
raise Exception("Job not found")
if job.status not in ("pending", "processing"):
raise Exception(f"Cannot cancel job with status: {job.status}")
return update_job(job, status="cancelled")
@strawberry.mutation
def retry_job(self, info: Info, id: UUID) -> TranscodeJobType:
from core.db import get_job, update_job
try:
job = get_job(id)
except Exception:
raise Exception("Job not found")
if job.status != "failed":
raise Exception("Only failed jobs can be retried")
return update_job(job, status="pending", progress=0, error_message=None)
@strawberry.mutation
def update_asset(self, info: Info, id: UUID, input: UpdateAssetInput) -> MediaAssetType:
from core.db import get_asset, update_asset
try:
asset = get_asset(id)
except Exception:
raise Exception("Asset not found")
fields = {}
if input.comments is not None:
fields["comments"] = input.comments
if input.tags is not None:
fields["tags"] = input.tags
if fields:
asset = update_asset(asset, **fields)
return asset
@strawberry.mutation
def delete_asset(self, info: Info, id: UUID) -> DeleteResultType:
from core.db import delete_asset, get_asset
try:
asset = get_asset(id)
delete_asset(asset)
return DeleteResultType(ok=True)
except Exception:
raise Exception("Asset not found")
@strawberry.mutation
def create_chunk_job(self, info: Info, input: CreateChunkJobInput) -> ChunkJobType:
"""Create and dispatch a chunk pipeline job."""
import uuid
from core.db import get_asset
try:
source = get_asset(input.source_asset_id)
except Exception:
raise Exception("Source asset not found")
job_id = str(uuid.uuid4())
payload = {
"source_key": source.file_path,
"chunk_duration": input.chunk_duration,
"num_workers": input.num_workers,
"max_retries": input.max_retries,
"processor_type": input.processor_type,
"start_time": input.start_time,
"end_time": input.end_time,
}
executor_mode = os.environ.get("MPR_EXECUTOR", "local")
celery_task_id = None
if executor_mode in ("lambda", "gcp"):
from core.jobs.executor import get_executor
get_executor().run(
job_type="chunk",
job_id=job_id,
payload=payload,
)
else:
from core.jobs.task import run_job
result = run_job.delay(
job_type="chunk",
job_id=job_id,
payload=payload,
)
celery_task_id = result.id
return ChunkJobType(
id=uuid.UUID(job_id),
source_asset_id=input.source_asset_id,
chunk_duration=input.chunk_duration,
num_workers=input.num_workers,
max_retries=input.max_retries,
processor_type=input.processor_type,
status="pending",
progress=0.0,
priority=input.priority,
celery_task_id=celery_task_id,
)
@strawberry.mutation
def cancel_chunk_job(self, info: Info, celery_task_id: str) -> CancelResultType:
"""Cancel a running chunk job by revoking its Celery task."""
try:
from admin.mpr.celery import app as celery_app
celery_app.control.revoke(celery_task_id, terminate=True, signal="SIGTERM")
return CancelResultType(ok=True, message="Task revoked")
except Exception as e:
return CancelResultType(ok=False, message=str(e))
# ---------------------------------------------------------------------------
# Schema
# ---------------------------------------------------------------------------
schema = strawberry.Schema(
query=Query,
mutation=Mutation,
config=StrawberryConfig(auto_camel_case=False),
)
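# --- Usage sketch (illustrative only) ---
# Because auto_camel_case=False, field names stay snake_case in queries.
# system_status resolves in memory, so this particular query runs without a database.
if __name__ == "__main__":
    result = schema.execute_sync("{ system_status { status version } }")
    print(result.errors)  # None on success
    print(result.data)    # {'system_status': {'status': 'ok', 'version': '0.1.0'}}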

View File

@@ -1,48 +1,38 @@
"""
MPR FastAPI Application
Serves GraphQL API and Lambda callback endpoint.
"""
import os
import sys
from typing import Optional
from uuid import UUID
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from contextlib import asynccontextmanager
from fastapi import FastAPI, Header, HTTPException
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from strawberry.fastapi import GraphQLRouter
from core.api.chunker_sse import router as chunker_router
from core.api.detect import router as detect_router
from core.api.graphql import schema as graphql_schema
CALLBACK_API_KEY = os.environ.get("CALLBACK_API_KEY", "")
@asynccontextmanager
async def lifespan(app):
# Create/reset DB tables on startup
from core.db.connection import create_tables
from core.db.seed import seed_profiles
create_tables()
seed_profiles()
yield
app = FastAPI(
title="MPR API",
description="Media Processor — GraphQL API",
version="0.1.0",
docs_url="/docs",
redoc_url="/redoc",
lifespan=lifespan,
)
# CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["http://mpr.local.ar", "http://k8s.mpr.local.ar", "http://localhost:5173"],
@@ -51,13 +41,6 @@ app.add_middleware(
allow_headers=["*"],
)
# GraphQL
graphql_router = GraphQLRouter(schema=graphql_schema, graphql_ide="graphiql")
app.include_router(graphql_router, prefix="/graphql")
# Chunker SSE
app.include_router(chunker_router)
# Detection API (sources, run, SSE, replay, config)
app.include_router(detect_router)
@@ -69,48 +52,7 @@ def health():
@app.get("/")
def root():
"""API root."""
return {
"name": "MPR API",
"version": "0.1.0",
"graphql": "/graphql",
}
@app.post("/api/jobs/{job_id}/callback")
def job_callback(
job_id: UUID,
payload: dict,
x_api_key: Optional[str] = Header(None),
):
"""
Callback endpoint for Lambda to report job completion.
Protected by API key.
"""
if CALLBACK_API_KEY and x_api_key != CALLBACK_API_KEY:
raise HTTPException(status_code=403, detail="Invalid API key")
from django.utils import timezone
from core.db import get_job, update_job
try:
job = get_job(job_id)
except Exception:
raise HTTPException(status_code=404, detail="Job not found")
status = payload.get("status", "failed")
fields = {
"status": status,
"progress": 100.0 if status == "completed" else job.progress,
}
if payload.get("error"):
fields["error_message"] = payload["error"]
if status in ("completed", "failed"):
fields["completed_at"] = timezone.now()
update_job(job, **fields)
return {"ok": True}

View File

@@ -1,226 +0,0 @@
"""
Strawberry Types - GENERATED FILE
Do not edit directly. Regenerate using modelgen.
"""
import strawberry
from enum import Enum
from typing import List, Optional
from uuid import UUID
from datetime import datetime
from strawberry.scalars import JSON
@strawberry.enum
class AssetStatus(Enum):
PENDING = "pending"
READY = "ready"
ERROR = "error"
@strawberry.enum
class JobStatus(Enum):
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
@strawberry.type
class MediaAssetType:
"""A video/audio file registered in the system."""
id: Optional[UUID] = None
filename: Optional[str] = None
file_path: Optional[str] = None
status: Optional[str] = None
error_message: Optional[str] = None
file_size: Optional[float] = None
duration: Optional[float] = None
video_codec: Optional[str] = None
audio_codec: Optional[str] = None
width: Optional[int] = None
height: Optional[int] = None
framerate: Optional[float] = None
bitrate: Optional[int] = None
properties: Optional[JSON] = None
comments: Optional[str] = None
tags: Optional[List[str]] = None
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
@strawberry.type
class TranscodePresetType:
"""A reusable transcoding configuration (like Handbrake presets)."""
id: Optional[UUID] = None
name: Optional[str] = None
description: Optional[str] = None
is_builtin: Optional[bool] = None
container: Optional[str] = None
video_codec: Optional[str] = None
video_bitrate: Optional[str] = None
video_crf: Optional[int] = None
video_preset: Optional[str] = None
resolution: Optional[str] = None
framerate: Optional[float] = None
audio_codec: Optional[str] = None
audio_bitrate: Optional[str] = None
audio_channels: Optional[int] = None
audio_samplerate: Optional[int] = None
extra_args: Optional[List[str]] = None
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
@strawberry.type
class TranscodeJobType:
"""A transcoding or trimming job in the queue."""
id: Optional[UUID] = None
source_asset_id: Optional[UUID] = None
preset_id: Optional[UUID] = None
preset_snapshot: Optional[JSON] = None
trim_start: Optional[float] = None
trim_end: Optional[float] = None
output_filename: Optional[str] = None
output_path: Optional[str] = None
output_asset_id: Optional[UUID] = None
status: Optional[str] = None
progress: Optional[float] = None
current_frame: Optional[int] = None
current_time: Optional[float] = None
speed: Optional[str] = None
error_message: Optional[str] = None
celery_task_id: Optional[str] = None
execution_arn: Optional[str] = None
priority: Optional[int] = None
created_at: Optional[datetime] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
@strawberry.input
class CreateJobInput:
"""Request body for creating a transcode/trim job."""
source_asset_id: UUID
preset_id: Optional[UUID] = None
trim_start: Optional[float] = None
trim_end: Optional[float] = None
output_filename: Optional[str] = None
priority: int = 0
@strawberry.input
class UpdateAssetInput:
"""Request body for updating asset metadata."""
comments: Optional[str] = None
tags: Optional[List[str]] = None
@strawberry.type
class SystemStatusType:
"""System status response."""
status: Optional[str] = None
version: Optional[str] = None
@strawberry.type
class ScanResultType:
"""Result of scanning the media input bucket."""
found: Optional[int] = None
registered: Optional[int] = None
skipped: Optional[int] = None
files: Optional[List[str]] = None
@strawberry.type
class DeleteResultType:
"""Result of a delete operation."""
ok: Optional[bool] = None
@strawberry.type
class WorkerStatusType:
"""Worker health and capabilities."""
available: Optional[bool] = None
active_jobs: Optional[int] = None
supported_codecs: Optional[List[str]] = None
gpu_available: Optional[bool] = None
@strawberry.enum
class ChunkJobStatus(Enum):
PENDING = "pending"
CHUNKING = "chunking"
PROCESSING = "processing"
COLLECTING = "collecting"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
@strawberry.type
class ChunkJobType:
"""A chunk pipeline job."""
id: Optional[UUID] = None
source_asset_id: Optional[UUID] = None
chunk_duration: Optional[float] = None
num_workers: Optional[int] = None
max_retries: Optional[int] = None
processor_type: Optional[str] = None
status: Optional[str] = None
progress: Optional[float] = None
total_chunks: Optional[int] = None
processed_chunks: Optional[int] = None
failed_chunks: Optional[int] = None
retry_count: Optional[int] = None
error_message: Optional[str] = None
throughput_mbps: Optional[float] = None
elapsed_seconds: Optional[float] = None
celery_task_id: Optional[str] = None
priority: Optional[int] = None
created_at: Optional[datetime] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
@strawberry.input
class CreateChunkJobInput:
"""Request body for creating a chunk pipeline job."""
source_asset_id: UUID
chunk_duration: float = 10.0
num_workers: int = 4
max_retries: int = 3
processor_type: str = "ffmpeg"
priority: int = 0
start_time: Optional[float] = None
end_time: Optional[float] = None
@strawberry.type
class CancelResultType:
"""Result of cancelling a chunk job."""
ok: bool = False
message: Optional[str] = None
@strawberry.type
class ChunkOutputFileType:
"""A chunk output file in S3/MinIO with presigned download URL."""
key: str
size: int = 0
url: str = ""

View File

@@ -1,64 +0,0 @@
"""
Chunker pipeline — splits files into chunks, processes concurrently, reassembles in order.
Public API:
Pipeline — orchestrates the full pipeline
PipelineResult — aggregate result dataclass
Chunker — file → Chunk generator
ChunkQueue — bounded thread-safe queue
WorkerPool — manages N worker threads
ResultCollector — heapq-based ordered reassembly
"""
from .chunker import Chunker
from .collector import ResultCollector
from .exceptions import (
ChunkChecksumError,
ChunkError,
ChunkReadError,
PipelineError,
ProcessingError,
ProcessorFailureError,
ProcessorTimeoutError,
ReassemblyError,
)
from .models import Chunk, ChunkResult, PipelineResult
from .pipeline import Pipeline
from .pool import WorkerPool
from .processor import (
ChecksumProcessor,
CompositeProcessor,
FFmpegExtractProcessor,
Processor,
SimulatedDecodeProcessor,
)
from .queue import ChunkQueue
__all__ = [
# Core
"Pipeline",
"PipelineResult",
# Components
"Chunker",
"ChunkQueue",
"WorkerPool",
"ResultCollector",
# Models
"Chunk",
"ChunkResult",
# Processors
"Processor",
"ChecksumProcessor",
"SimulatedDecodeProcessor",
"CompositeProcessor",
"FFmpegExtractProcessor",
# Exceptions
"PipelineError",
"ChunkError",
"ChunkReadError",
"ChunkChecksumError",
"ProcessingError",
"ProcessorFailureError",
"ProcessorTimeoutError",
"ReassemblyError",
]

View File

@@ -1,101 +0,0 @@
"""
Chunker — probes a media file and yields time-based Chunk objects.
Demonstrates:
- Function parameters and defaults (Interview Topic 1)
- List comprehensions and efficient iteration / generators (Interview Topic 3)
"""
import math
import os
from typing import Generator
from core.ffmpeg.probe import probe_file
from .exceptions import ChunkReadError
from .models import Chunk
class Chunker:
"""
Splits a media file into time-based chunks via a generator.
Uses FFmpeg probe to get duration, then yields Chunk objects
representing time segments (no data read — extraction happens in the processor).
Args:
file_path: Path to the source media file
chunk_duration: Duration of each chunk in seconds (default: 10.0)
"""
def __init__(
self,
file_path: str,
chunk_duration: float = 10.0,
start_time: float | None = None,
end_time: float | None = None,
):
if not os.path.isfile(file_path):
raise ChunkReadError(f"File not found: {file_path}")
if chunk_duration <= 0:
raise ValueError("chunk_duration must be positive")
self.file_path = file_path
self.chunk_duration = chunk_duration
self.file_size = os.path.getsize(file_path)
full_duration = self._probe_duration()
# Apply time range
self.range_start = max(start_time or 0.0, 0.0)
self.range_end = min(end_time or full_duration, full_duration)
if self.range_start >= self.range_end:
raise ValueError(
f"Invalid range: start={self.range_start} >= end={self.range_end}"
)
self.source_duration = self.range_end - self.range_start
def _probe_duration(self) -> float:
"""Get source file duration via FFmpeg probe."""
try:
result = probe_file(self.file_path)
if result.duration is None or result.duration <= 0:
raise ChunkReadError(
f"Cannot determine duration for {self.file_path}"
)
return result.duration
except ChunkReadError:
raise
except Exception as e:
raise ChunkReadError(
f"Failed to probe {self.file_path}: {e}"
) from e
@property
def expected_chunks(self) -> int:
"""Calculate expected number of chunks (last chunk may be shorter)."""
if self.source_duration <= 0:
return 0
return math.ceil(self.source_duration / self.chunk_duration)
def chunks(self) -> Generator[Chunk, None, None]:
"""
Yield Chunk objects representing time segments of the source file.
Generator-based: chunks are yielded on demand.
Each chunk defines a time range — actual extraction is done by the processor.
"""
total = self.expected_chunks
for sequence in range(total):
start_time = self.range_start + sequence * self.chunk_duration
end_time = min(
start_time + self.chunk_duration, self.range_end
)
duration = end_time - start_time
yield Chunk(
sequence=sequence,
start_time=start_time,
end_time=end_time,
source_path=self.file_path,
duration=duration,
)
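# --- Usage sketch (illustrative only) ---
# "sample.mp4" is a placeholder path; Chunker probes it with FFmpeg on construction,
# and chunks() yields segments lazily.
def _example_iterate_chunks() -> None:
    chunker = Chunker("sample.mp4", chunk_duration=10.0)
    print("expected:", chunker.expected_chunks)  # ceil(source_duration / chunk_duration)
    for chunk in chunker.chunks():
        print(chunk.sequence, chunk.start_time, chunk.end_time, chunk.duration)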

View File

@@ -1,98 +0,0 @@
"""
ResultCollector — reassembles chunk results in sequence order using a min-heap.
Demonstrates:
- Algorithms and sorting (Interview Topic 6) — heapq for ordered reassembly
- Core data structures (Interview Topic 5) — heap, deque
"""
import heapq
from collections import deque
from typing import List
from .exceptions import ReassemblyError
from .models import ChunkResult
class ResultCollector:
"""
Receives ChunkResults out of order, emits them in sequence order.
Uses a min-heap keyed on sequence number. Only emits a chunk when
all prior sequences have been accounted for.
Args:
total_chunks: Expected total number of chunks
"""
def __init__(self, total_chunks: int):
self.total_chunks = total_chunks
self._heap: List[tuple[int, ChunkResult]] = []
self._next_sequence = 0
self._emitted: List[ChunkResult] = []
self._seen_sequences: set[int] = set()
# Sliding window for throughput calculation
self._recent_times: deque[float] = deque(maxlen=50)
def add(self, result: ChunkResult) -> List[ChunkResult]:
"""
Add a result and return any newly emittable results in order.
Args:
result: A ChunkResult (may arrive out of order)
Returns:
List of results that can now be emitted in sequence order
(may be empty if we're still waiting for earlier sequences)
Raises:
ReassemblyError: If a duplicate sequence is received
"""
if result.sequence in self._seen_sequences:
raise ReassemblyError(
f"Duplicate sequence number: {result.sequence}"
)
self._seen_sequences.add(result.sequence)
# Track processing time for throughput
if result.processing_time > 0:
self._recent_times.append(result.processing_time)
# Push to min-heap
heapq.heappush(self._heap, (result.sequence, result))
# Emit all consecutive results starting from _next_sequence
newly_emitted = []
while self._heap and self._heap[0][0] == self._next_sequence:
_, emitted_result = heapq.heappop(self._heap)
self._emitted.append(emitted_result)
newly_emitted.append(emitted_result)
self._next_sequence += 1
return newly_emitted
@property
def is_complete(self) -> bool:
"""True if all expected chunks have been emitted in order."""
return self._next_sequence == self.total_chunks
@property
def buffered_count(self) -> int:
"""Number of results waiting in the heap (arrived out of order)."""
return len(self._heap)
@property
def emitted_count(self) -> int:
"""Number of results emitted in sequence order."""
return len(self._emitted)
@property
def avg_processing_time(self) -> float:
"""Average processing time from recent results (sliding window)."""
if not self._recent_times:
return 0.0
return sum(self._recent_times) / len(self._recent_times)
def get_ordered_results(self) -> List[ChunkResult]:
"""Get all emitted results in sequence order."""
return list(self._emitted)
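# --- Usage sketch (illustrative only) ---
# Results arriving out of order are buffered on the heap until the missing
# sequence shows up, then emitted as a contiguous run.
def _example_out_of_order() -> None:
    collector = ResultCollector(total_chunks=3)
    r2 = ChunkResult(sequence=2, success=True)
    r0 = ChunkResult(sequence=0, success=True)
    r1 = ChunkResult(sequence=1, success=True)
    print([r.sequence for r in collector.add(r2)])  # [] (still waiting for 0 and 1)
    print([r.sequence for r in collector.add(r0)])  # [0]
    print([r.sequence for r in collector.add(r1)])  # [1, 2]
    print(collector.is_complete)                    # True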

View File

@@ -1,64 +0,0 @@
"""
Chunker exception hierarchy.
Demonstrates: Managing exceptions and writing resilient code (Interview Topic 7).
"""
class PipelineError(Exception):
"""Base exception for all chunker pipeline errors."""
pass
class ChunkError(PipelineError):
"""Errors related to chunk creation or validation."""
pass
class ChunkReadError(ChunkError):
"""Failed to read chunk data from source file."""
pass
class ChunkChecksumError(ChunkError):
"""Chunk data integrity validation failed."""
def __init__(self, sequence: int, expected: str, actual: str):
self.sequence = sequence
self.expected = expected
self.actual = actual
super().__init__(
f"Chunk {sequence}: checksum mismatch "
f"(expected={expected}, actual={actual})"
)
class ProcessingError(PipelineError):
"""Errors during chunk processing by workers."""
pass
class ProcessorTimeoutError(ProcessingError):
"""Processor exceeded allowed time for a chunk."""
def __init__(self, sequence: int, timeout: float):
self.sequence = sequence
self.timeout = timeout
super().__init__(f"Chunk {sequence}: processor timed out after {timeout}s")
class ProcessorFailureError(ProcessingError):
"""Processor failed to process a chunk after all retries."""
def __init__(self, sequence: int, retries: int, original_error: Exception):
self.sequence = sequence
self.retries = retries
self.original_error = original_error
super().__init__(
f"Chunk {sequence}: failed after {retries} retries — {original_error}"
)
class ReassemblyError(PipelineError):
"""Errors during result collection and ordering."""
pass
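# --- Usage sketch (illustrative only) ---
# The hierarchy lets callers choose how broadly to catch: handle ProcessingError
# subclasses specifically, or fall back to PipelineError for anything else.
def _example_catch(seq: int) -> None:
    try:
        raise ProcessorTimeoutError(sequence=seq, timeout=30.0)
    except ProcessorFailureError:
        print("exhausted retries")       # not hit in this example
    except ProcessingError as e:
        print("processing failed:", e)   # catches the timeout
    except PipelineError:
        print("some other pipeline error")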

View File

@@ -1,54 +0,0 @@
"""
Internal data models for the chunker pipeline.
These are pipeline-internal dataclasses, not schema models.
Schema-level ChunkJob is in core/schema/models/jobs.py.
Demonstrates: Core data structures (Interview Topic 5).
"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
@dataclass
class Chunk:
"""A time-based segment of the source media file."""
sequence: int
start_time: float # seconds
end_time: float # seconds
source_path: str # path to source file
duration: float # end_time - start_time
checksum: str = "" # computed after extraction
@dataclass
class ChunkResult:
"""Result of processing a single chunk."""
sequence: int
success: bool
checksum_valid: bool = True
processing_time: float = 0.0
error: Optional[str] = None
retries: int = 0
worker_id: Optional[str] = None
output_file: Optional[str] = None
@dataclass
class PipelineResult:
"""Aggregate result of the entire pipeline run."""
total_chunks: int = 0
processed: int = 0
failed: int = 0
retries: int = 0
elapsed_time: float = 0.0
throughput_mbps: float = 0.0
worker_stats: Dict[str, Any] = field(default_factory=dict)
errors: List[str] = field(default_factory=list)
chunks_in_order: bool = True
output_dir: Optional[str] = None
chunk_files: List[str] = field(default_factory=list)

View File

@@ -1,279 +0,0 @@
"""
Pipeline — orchestrates the entire chunker pipeline.
Wires: Chunker → ChunkQueue → WorkerPool → ResultCollector → PipelineResult
Demonstrates:
- Function parameters and defaults (Interview Topic 1) — configurable pipeline
- Concurrency (Interview Topic 2) — producer thread + worker pool
- OOP design (Interview Topic 4) — composition of pipeline components
- Exception handling (Interview Topic 7) — graceful error propagation
"""
import json
import logging
import threading
import time
from pathlib import Path
from typing import Any, Callable, Dict, Optional
from .chunker import Chunker
from .collector import ResultCollector
from .exceptions import PipelineError
from .models import PipelineResult
from .pool import WorkerPool
from .queue import ChunkQueue
logger = logging.getLogger(__name__)
class Pipeline:
"""
Orchestrates the chunk processing pipeline.
The pipeline runs in three stages:
1. Producer thread: Chunker probes file → pushes time-based chunks to ChunkQueue
2. Worker pool: N workers pull from queue → extract mp4 segments → emit results
3. Collector: ResultCollector reassembles results in sequence order
Args:
source: Path to the source media file
chunk_duration: Duration of each chunk in seconds (default: 10.0)
num_workers: Number of concurrent worker threads (default: 4)
max_retries: Max retry attempts per chunk (default: 3)
processor_type: Processor to use — "ffmpeg", "checksum", "simulated_decode", "composite"
queue_size: Max chunks buffered in queue (default: 10)
event_callback: Optional callback for real-time events
output_dir: Directory for output chunk files (required for "ffmpeg" processor)
"""
def __init__(
self,
source: str,
chunk_duration: float = 10.0,
num_workers: int = 4,
max_retries: int = 3,
processor_type: str = "checksum",
queue_size: int = 10,
event_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None,
output_dir: Optional[str] = None,
start_time: Optional[float] = None,
end_time: Optional[float] = None,
):
self.source = source
self.chunk_duration = chunk_duration
self.num_workers = num_workers
self.max_retries = max_retries
self.processor_type = processor_type
self.queue_size = queue_size
self.event_callback = event_callback
self.output_dir = output_dir
self.start_time = start_time
self.end_time = end_time
def _emit(self, event_type: str, data: Dict[str, Any]) -> None:
"""Emit an event if callback is registered."""
if self.event_callback:
self.event_callback(event_type, data)
def _produce_chunks(
self, chunker: Chunker, chunk_queue: ChunkQueue
) -> None:
"""Producer thread: probe file and enqueue time-based chunks."""
try:
for chunk in chunker.chunks():
chunk_queue.put(chunk, timeout=30.0)
self._emit("chunk_queued", {
"sequence": chunk.sequence,
"start_time": chunk.start_time,
"end_time": chunk.end_time,
"duration": chunk.duration,
"queue_size": chunk_queue.qsize(),
})
except Exception as e:
logger.error(f"Producer error: {e}")
self._emit("producer_error", {"error": str(e)})
finally:
chunk_queue.close()
def _monitor_progress(
self, start_time: float, file_size: int, stop_event: threading.Event
) -> None:
"""Monitor thread: emit pipeline_progress every 500ms."""
while not stop_event.is_set():
elapsed = time.monotonic() - start_time
mb = file_size / (1024 * 1024)
self._emit("pipeline_progress", {
"elapsed": round(elapsed, 2),
"throughput_mbps": round(mb / elapsed, 2) if elapsed > 0 else 0,
})
stop_event.wait(0.5)
def _write_manifest(
self, result: PipelineResult, source_duration: float
) -> None:
"""Write manifest.json to output_dir with segment metadata."""
if not self.output_dir:
return
manifest = {
"source": self.source,
"source_duration": source_duration,
"chunk_duration": self.chunk_duration,
"total_chunks": result.total_chunks,
"processed": result.processed,
"failed": result.failed,
"elapsed_time": result.elapsed_time,
"throughput_mbps": result.throughput_mbps,
"segments": [
{
"sequence": i,
"file": f"chunk_{i:04d}.mp4",
"start": i * self.chunk_duration,
"end": min(
(i + 1) * self.chunk_duration, source_duration
),
}
for i in range(result.total_chunks)
],
}
manifest_path = Path(self.output_dir) / "manifest.json"
manifest_path.write_text(json.dumps(manifest, indent=2))
logger.info(f"Manifest written to {manifest_path}")
def run(self) -> PipelineResult:
"""
Execute the full pipeline.
Returns:
PipelineResult with aggregate stats
Raises:
PipelineError: If the pipeline fails catastrophically
"""
start_time = time.monotonic()
self._emit("pipeline_start", {
"source": self.source,
"chunk_duration": self.chunk_duration,
"num_workers": self.num_workers,
"processor_type": self.processor_type,
})
try:
# Stage 1: Set up chunker (probes file for duration)
chunker = Chunker(
self.source,
self.chunk_duration,
start_time=self.start_time,
end_time=self.end_time,
)
total_chunks = chunker.expected_chunks
if total_chunks == 0:
self._emit("pipeline_complete", {"total_chunks": 0})
return PipelineResult(chunks_in_order=True)
self._emit("pipeline_info", {
"file_size": chunker.file_size,
"source_duration": chunker.source_duration,
"total_chunks": total_chunks,
})
# Stage 2: Set up queue and worker pool
chunk_queue = ChunkQueue(maxsize=self.queue_size)
pool = WorkerPool(
num_workers=self.num_workers,
chunk_queue=chunk_queue,
processor_type=self.processor_type,
max_retries=self.max_retries,
event_callback=self.event_callback,
output_dir=self.output_dir,
)
# Stage 3: Start workers, monitor, then produce chunks
pool.start()
monitor_stop = threading.Event()
monitor = threading.Thread(
target=self._monitor_progress,
args=(start_time, chunker.file_size, monitor_stop),
name="progress-monitor",
daemon=True,
)
monitor.start()
producer = threading.Thread(
target=self._produce_chunks,
args=(chunker, chunk_queue),
name="chunk-producer",
daemon=True,
)
producer.start()
# Stage 4: Wait for all workers to finish
all_results = pool.wait()
producer.join(timeout=5.0)
# Stop monitor
monitor_stop.set()
monitor.join(timeout=2.0)
# Stage 5: Collect results in order
collector = ResultCollector(total_chunks)
for r in all_results:
collector.add(r)
self._emit("chunk_collected", {
"sequence": r.sequence,
"success": r.success,
"buffered": collector.buffered_count,
"emitted": collector.emitted_count,
})
# Build result
elapsed = time.monotonic() - start_time
file_size_mb = chunker.file_size / (1024 * 1024)
throughput = file_size_mb / elapsed if elapsed > 0 else 0.0
failed_results = [r for r in all_results if not r.success]
total_retries = sum(r.retries for r in all_results)
chunk_files = [
r.output_file for r in all_results
if r.success and r.output_file
]
result = PipelineResult(
total_chunks=total_chunks,
processed=len(all_results),
failed=len(failed_results),
retries=total_retries,
elapsed_time=elapsed,
throughput_mbps=throughput,
worker_stats=pool.get_worker_stats(),
errors=[r.error for r in failed_results if r.error],
chunks_in_order=collector.is_complete,
output_dir=self.output_dir,
chunk_files=chunk_files,
)
# Write manifest if output_dir is set
self._write_manifest(result, chunker.source_duration)
pool.shutdown()
self._emit("pipeline_complete", {
"total_chunks": result.total_chunks,
"processed": result.processed,
"failed": result.failed,
"elapsed": result.elapsed_time,
"throughput_mbps": result.throughput_mbps,
})
return result
except PipelineError:
raise
except Exception as e:
self._emit("pipeline_error", {"error": str(e)})
raise PipelineError(f"Pipeline failed: {e}") from e
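# --- Usage sketch (illustrative only) ---
# "sample.mp4" is a placeholder path. The simulated_decode processor avoids real
# FFmpeg extraction, so this exercises the producer/worker/collector wiring only.
def _example_run() -> None:
    def on_event(event_type: str, data: dict) -> None:
        print(event_type, data)

    pipeline = Pipeline(
        source="sample.mp4",
        chunk_duration=5.0,
        num_workers=2,
        processor_type="simulated_decode",
        event_callback=on_event,
    )
    result = pipeline.run()
    print(result.processed, result.failed, round(result.throughput_mbps, 2))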

View File

@@ -1,125 +0,0 @@
"""
WorkerPool — manages N worker threads via ThreadPoolExecutor.
Demonstrates: Python concurrency — threading (Interview Topic 2).
"""
import logging
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional
from .models import ChunkResult
from .processor import (
ChecksumProcessor,
CompositeProcessor,
FFmpegExtractProcessor,
Processor,
SimulatedDecodeProcessor,
)
from .queue import ChunkQueue
from .worker import Worker
logger = logging.getLogger(__name__)
def create_processor(
processor_type: str = "checksum",
output_dir: Optional[str] = None,
) -> Processor:
"""Factory for processor instances."""
if processor_type == "ffmpeg":
if not output_dir:
raise ValueError("output_dir required for ffmpeg processor")
return FFmpegExtractProcessor(output_dir=output_dir)
elif processor_type == "checksum":
return ChecksumProcessor()
elif processor_type == "simulated_decode":
return SimulatedDecodeProcessor()
elif processor_type == "composite":
return CompositeProcessor([
ChecksumProcessor(),
SimulatedDecodeProcessor(ms_per_second=50.0),
])
else:
raise ValueError(f"Unknown processor type: {processor_type}")
class WorkerPool:
"""
Manages N worker threads that process chunks concurrently.
Args:
num_workers: Number of concurrent worker threads (default: 4)
chunk_queue: Shared queue to pull chunks from
processor_type: Type of processor for each worker (default: "checksum")
max_retries: Max retry attempts per chunk (default: 3)
event_callback: Optional callback for real-time events
"""
def __init__(
self,
num_workers: int = 4,
chunk_queue: Optional[ChunkQueue] = None,
processor_type: str = "checksum",
max_retries: int = 3,
event_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None,
output_dir: Optional[str] = None,
):
self.num_workers = num_workers
self.chunk_queue = chunk_queue or ChunkQueue()
self.processor_type = processor_type
self.max_retries = max_retries
self.event_callback = event_callback
self.output_dir = output_dir
self.shutdown_event = threading.Event()
self._executor: Optional[ThreadPoolExecutor] = None
self._futures: List[Future] = []
self._workers: List[Worker] = []
def start(self) -> None:
"""Start all worker threads."""
self._executor = ThreadPoolExecutor(
max_workers=self.num_workers,
thread_name_prefix="chunk-worker",
)
for i in range(self.num_workers):
worker = Worker(
worker_id=f"worker-{i}",
chunk_queue=self.chunk_queue,
processor=create_processor(self.processor_type, output_dir=self.output_dir),
max_retries=self.max_retries,
event_callback=self.event_callback,
)
self._workers.append(worker)
future = self._executor.submit(worker.run)
self._futures.append(future)
logger.info(f"WorkerPool started with {self.num_workers} workers")
def wait(self) -> List[ChunkResult]:
"""Wait for all workers to finish and collect results."""
all_results = []
for future in self._futures:
results = future.result()
all_results.extend(results)
return all_results
def shutdown(self) -> None:
"""Signal shutdown and cleanup."""
self.shutdown_event.set()
self.chunk_queue.close()
if self._executor:
self._executor.shutdown(wait=True)
def get_worker_stats(self) -> Dict[str, Any]:
"""Get per-worker statistics."""
return {
w.worker_id: {
"processed": w.processed_count,
"errors": w.error_count,
"retries": w.retry_count,
}
for w in self._workers
}
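# --- Usage sketch (illustrative only) ---
# Wiring a pool by hand (Pipeline normally does this): start the workers, feed
# Chunk objects into the shared queue, close it so the sentinel reaches every
# worker, then wait() for the collected results.
def _example_pool(chunks) -> list:
    # chunks: an iterable of Chunk objects, e.g. from Chunker.chunks()
    q = ChunkQueue(maxsize=4)
    pool = WorkerPool(num_workers=2, chunk_queue=q, processor_type="checksum")
    pool.start()
    for chunk in chunks:
        q.put(chunk, timeout=30.0)
    q.close()
    results = pool.wait()
    pool.shutdown()
    return results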

View File

@@ -1,173 +0,0 @@
"""
Processor ABC and concrete implementations.
Demonstrates: OOP design principles — ABC, inheritance, composition (Interview Topic 4).
"""
import hashlib
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List
from .exceptions import ChunkChecksumError
from .models import Chunk, ChunkResult
class Processor(ABC):
"""
Abstract base class for chunk processors.
Each processor defines how a single chunk is processed.
The Worker calls processor.process(chunk) and handles retries.
"""
@abstractmethod
def process(self, chunk: Chunk) -> ChunkResult:
"""Process a single chunk and return the result."""
pass
class FFmpegExtractProcessor(Processor):
"""
Extracts a time segment from the source file using FFmpeg stream copy.
Produces a playable mp4 file per chunk — no re-encoding.
Args:
output_dir: Directory to write chunk mp4 files
"""
def __init__(self, output_dir: str):
self.output_dir = output_dir
Path(output_dir).mkdir(parents=True, exist_ok=True)
def process(self, chunk: Chunk) -> ChunkResult:
from core.ffmpeg.transcode import TranscodeConfig, transcode
start = time.monotonic()
output_file = str(
Path(self.output_dir) / f"chunk_{chunk.sequence:04d}.mp4"
)
config = TranscodeConfig(
input_path=chunk.source_path,
output_path=output_file,
video_codec="copy",
audio_codec="copy",
trim_start=chunk.start_time,
trim_end=chunk.end_time,
)
transcode(config)
# Compute checksum of output file
md5 = hashlib.md5()
with open(output_file, "rb") as f:
for block in iter(lambda: f.read(8192), b""):
md5.update(block)
checksum = md5.hexdigest()
elapsed = time.monotonic() - start
return ChunkResult(
sequence=chunk.sequence,
success=True,
checksum_valid=True,
processing_time=elapsed,
output_file=output_file,
)
class ChecksumProcessor(Processor):
"""
Validates chunk metadata consistency.
For time-based chunks, verifies the time range is valid.
Raises ChunkChecksumError on invalid ranges.
"""
def process(self, chunk: Chunk) -> ChunkResult:
start = time.monotonic()
valid = chunk.duration > 0 and chunk.end_time > chunk.start_time
if not valid:
raise ChunkChecksumError(
sequence=chunk.sequence,
expected="valid time range",
actual=f"{chunk.start_time}-{chunk.end_time}",
)
elapsed = time.monotonic() - start
return ChunkResult(
sequence=chunk.sequence,
success=True,
checksum_valid=True,
processing_time=elapsed,
)
class SimulatedDecodeProcessor(Processor):
"""
Simulates decode work by sleeping proportional to chunk duration.
Useful for demonstrating concurrency behavior without real FFmpeg.
Args:
ms_per_second: Milliseconds of simulated work per second of chunk duration (default: 100)
"""
def __init__(self, ms_per_second: float = 100.0):
self.ms_per_second = ms_per_second
def process(self, chunk: Chunk) -> ChunkResult:
start = time.monotonic()
sleep_time = (self.ms_per_second * chunk.duration) / 1000.0
time.sleep(sleep_time)
elapsed = time.monotonic() - start
return ChunkResult(
sequence=chunk.sequence,
success=True,
checksum_valid=True,
processing_time=elapsed,
)
class CompositeProcessor(Processor):
"""
Chains multiple processors — runs each in sequence on the same chunk.
Demonstrates OOP composition pattern.
Args:
processors: List of processors to chain
"""
def __init__(self, processors: List[Processor]):
if not processors:
raise ValueError("CompositeProcessor requires at least one processor")
self.processors = processors
def process(self, chunk: Chunk) -> ChunkResult:
start = time.monotonic()
last_result = None
for proc in self.processors:
last_result = proc.process(chunk)
if not last_result.success:
return last_result
elapsed = time.monotonic() - start
return ChunkResult(
sequence=chunk.sequence,
success=True,
checksum_valid=last_result.checksum_valid if last_result else True,
processing_time=elapsed,
)
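# --- Usage sketch (illustrative only) ---
# A minimal hypothetical processor showing the extension point: subclass Processor
# and implement process(); the Worker handles retries and error reporting around it.
class LoggingProcessor(Processor):
    """Hypothetical processor that just records which segment it saw."""

    def process(self, chunk: Chunk) -> ChunkResult:
        start = time.monotonic()
        print(f"chunk {chunk.sequence}: {chunk.start_time:.1f}s to {chunk.end_time:.1f}s")
        return ChunkResult(
            sequence=chunk.sequence,
            success=True,
            processing_time=time.monotonic() - start,
        )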

View File

@@ -1,76 +0,0 @@
"""
ChunkQueue — bounded, thread-safe queue with sentinel-based shutdown.
Demonstrates: Core data structures — queue.Queue (Interview Topic 5).
"""
import queue
from typing import Optional
from .models import Chunk
# Sentinel value to signal workers to stop
_SENTINEL = object()
class ChunkQueue:
"""
Thread-safe bounded queue for chunks.
Provides backpressure: producers block when the queue is full,
preventing unbounded memory usage.
Args:
maxsize: Maximum number of chunks in the queue (default: 10)
"""
def __init__(self, maxsize: int = 10):
self._queue: queue.Queue = queue.Queue(maxsize=maxsize)
self._closed = False
self.maxsize = maxsize
def put(self, chunk: Chunk, timeout: Optional[float] = None) -> None:
"""
Add a chunk to the queue. Blocks if full (backpressure).
Args:
chunk: The chunk to enqueue
timeout: Max seconds to wait (None = block forever)
Raises:
queue.Full: If timeout expires while queue is full
"""
self._queue.put(chunk, timeout=timeout)
def get(self, timeout: Optional[float] = None) -> Optional[Chunk]:
"""
Get next chunk from queue. Returns None if queue is closed.
Args:
timeout: Max seconds to wait (None = block forever)
Returns:
Chunk or None (if sentinel received, meaning queue is closed)
Raises:
queue.Empty: If timeout expires while queue is empty
"""
item = self._queue.get(timeout=timeout)
if item is _SENTINEL:
# Re-put sentinel so other workers also see it
self._queue.put(_SENTINEL)
return None
return item
def close(self) -> None:
"""Signal all consumers to stop by inserting a sentinel."""
self._closed = True
self._queue.put(_SENTINEL)
@property
def is_closed(self) -> bool:
return self._closed
def qsize(self) -> int:
"""Current number of items in the queue (approximate)."""
return self._queue.qsize()
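# --- Usage sketch (illustrative only) ---
# After close(), every consumer eventually receives the sentinel and get() returns
# None, because the sentinel is re-queued for the next reader.
def _example_sentinel() -> None:
    q = ChunkQueue(maxsize=2)
    q.put(Chunk(sequence=0, start_time=0.0, end_time=5.0, source_path="sample.mp4", duration=5.0))
    q.close()
    print(q.get())   # the chunk
    print(q.get())   # None (sentinel reached, queue closed)
    print(q.get())   # None again, for any other consumer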

View File

@@ -1,143 +0,0 @@
"""
Worker — pulls chunks from queue, processes with retry logic.
Demonstrates:
- Exception handling and resilient code (Interview Topic 7)
- Concurrency (Interview Topic 2) — workers run in thread pool
"""
import logging
import queue
import time
from typing import Any, Callable, Dict, Optional
from .exceptions import ProcessorFailureError
from .models import Chunk, ChunkResult
from .processor import Processor
from .queue import ChunkQueue
logger = logging.getLogger(__name__)
class Worker:
"""
Processes chunks from a queue with retry and exponential backoff.
Args:
worker_id: Identifier for this worker (e.g. "worker-0")
chunk_queue: Source queue to pull chunks from
processor: Processor instance to use
max_retries: Maximum retry attempts per chunk (default: 3)
event_callback: Optional callback for real-time status updates
"""
def __init__(
self,
worker_id: str,
chunk_queue: ChunkQueue,
processor: Processor,
max_retries: int = 3,
event_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None,
):
self.worker_id = worker_id
self.chunk_queue = chunk_queue
self.processor = processor
self.max_retries = max_retries
self.event_callback = event_callback
self.processed_count = 0
self.error_count = 0
self.retry_count = 0
def _emit(self, event_type: str, data: Dict[str, Any]) -> None:
"""Emit an event if callback is registered."""
if self.event_callback:
self.event_callback(event_type, {"worker_id": self.worker_id, **data})
def _process_with_retry(self, chunk: Chunk) -> ChunkResult:
"""
Process a chunk with exponential backoff retry.
Retry delays: 0.1s, 0.2s, 0.4s, ... (doubles each attempt)
"""
last_error = None
for attempt in range(self.max_retries + 1):
try:
if attempt > 0:
backoff = 0.1 * (2 ** (attempt - 1))
self._emit("chunk_retry", {
"sequence": chunk.sequence,
"attempt": attempt,
"backoff": backoff,
})
time.sleep(backoff)
self.retry_count += 1
result = self.processor.process(chunk)
result.retries = attempt
result.worker_id = self.worker_id
return result
except Exception as e:
last_error = e
logger.warning(
f"{self.worker_id}: chunk {chunk.sequence} "
f"attempt {attempt + 1}/{self.max_retries + 1} failed: {e}"
)
# All retries exhausted
self.error_count += 1
self._emit("chunk_error", {
"sequence": chunk.sequence,
"error": str(last_error),
"retries": self.max_retries,
})
return ChunkResult(
sequence=chunk.sequence,
success=False,
processing_time=0.0,
error=str(last_error),
retries=self.max_retries,
worker_id=self.worker_id,
)
def run(self) -> list[ChunkResult]:
"""
Main worker loop — pull chunks and process until queue is closed.
Returns:
List of ChunkResults processed by this worker
"""
results = []
self._emit("worker_status", {"state": "idle"})
while True:
try:
chunk = self.chunk_queue.get(timeout=1.0)
except queue.Empty:
continue
if chunk is None: # Sentinel received
break
self._emit("chunk_processing", {
"sequence": chunk.sequence,
"state": "processing",
"queue_size": self.chunk_queue.qsize(),
})
result = self._process_with_retry(chunk)
results.append(result)
self.processed_count += 1
self._emit("chunk_done", {
"sequence": chunk.sequence,
"success": result.success,
"processing_time": result.processing_time,
"retries": result.retries,
"queue_size": self.chunk_queue.qsize(),
})
self._emit("worker_status", {"state": "stopped"})
return results
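A hedged sketch of how several workers were wired to one queue; Processor() construction and produce_chunks() are placeholders not shown in this diff:

# Illustrative wiring — run() loops until the queue sentinel arrives, so closing the
# queue after the producer finishes is what shuts the pool down.
from concurrent.futures import ThreadPoolExecutor
from .processor import Processor
from .queue import ChunkQueue

q = ChunkQueue(maxsize=10)
workers = [Worker(f"worker-{i}", q, Processor(), max_retries=3) for i in range(4)]

with ThreadPoolExecutor(max_workers=len(workers)) as pool:
    futures = [pool.submit(w.run) for w in workers]
    for chunk in produce_chunks():   # hypothetical producer
        q.put(chunk)
    q.close()
    results = [r for f in futures for r in f.result()]
# Backoff between attempts follows 0.1 * 2**(attempt - 1): 0.1s, 0.2s, 0.4s for max_retries=3.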

View File

@@ -13,7 +13,7 @@ Basic CRUD (create, get, update, delete) goes directly through the session:
from .connection import get_session, create_tables
from .tables import MediaAsset, Job, Timeline, Checkpoint, Brand
from .models import MediaAsset, Job, Timeline, Checkpoint, Brand
from .assets import list_assets, get_asset_filenames
from .job import list_jobs

View File

@@ -7,7 +7,7 @@ from uuid import UUID
from sqlmodel import Session, select
from .tables import MediaAsset
from .models import MediaAsset
def list_assets(session: Session, status: Optional[str] = None, search: Optional[str] = None) -> list[MediaAsset]:

View File

@@ -7,7 +7,7 @@ from uuid import UUID
from sqlmodel import Session, select
from .tables import Brand
from .models import Brand
def get_or_create_brand(session: Session, canonical_name: str,

View File

@@ -6,7 +6,7 @@ from uuid import UUID
from sqlmodel import Session, select
from .tables import Checkpoint
from .models import Checkpoint
def get_latest_checkpoint(session: Session, timeline_id: UUID, parent_id: UUID | None = None) -> Checkpoint | None:

View File

@@ -30,5 +30,5 @@ def get_session() -> Session:
def create_tables():
"""Create all SQLModel tables."""
from sqlmodel import SQLModel
from . import tables # noqa — registers all table classes
from . import models # noqa — registers all table classes
SQLModel.metadata.create_all(get_engine())

View File

@@ -0,0 +1,142 @@
{
"name": "soccer_broadcast",
"pipeline": {
"name": "soccer_broadcast",
"profile_name": "soccer_broadcast",
"stages": [
{
"name": "extract_frames",
"branch": "trunk"
},
{
"name": "filter_scenes",
"branch": "trunk"
},
{
"name": "field_segmentation",
"branch": "trunk"
},
{
"name": "detect_edges",
"branch": "hoarding"
},
{
"name": "detect_objects",
"branch": "objects"
},
{
"name": "preprocess"
},
{
"name": "run_ocr"
},
{
"name": "match_brands"
},
{
"name": "escalate_vlm"
},
{
"name": "escalate_cloud"
},
{
"name": "compile_report"
}
],
"edges": [
{
"source": "extract_frames",
"target": "filter_scenes"
},
{
"source": "filter_scenes",
"target": "field_segmentation"
},
{
"source": "field_segmentation",
"target": "detect_edges"
},
{
"source": "field_segmentation",
"target": "detect_objects"
},
{
"source": "detect_edges",
"target": "preprocess"
},
{
"source": "detect_objects",
"target": "preprocess"
},
{
"source": "preprocess",
"target": "run_ocr"
},
{
"source": "run_ocr",
"target": "match_brands"
},
{
"source": "match_brands",
"target": "escalate_vlm"
},
{
"source": "escalate_vlm",
"target": "escalate_cloud"
},
{
"source": "escalate_cloud",
"target": "compile_report"
}
]
},
"configs": {
"extract_frames": {
"fps": 2.0,
"max_frames": 500
},
"filter_scenes": {
"hamming_threshold": 8,
"enabled": true
},
"field_segmentation": {
"enabled": true,
"hue_low": 30,
"hue_high": 85,
"sat_low": 30,
"sat_high": 255,
"val_low": 30,
"val_high": 255,
"morph_kernel": 15,
"min_area_ratio": 0.05
},
"detect_edges": {
"enabled": true,
"edge_canny_low": 50,
"edge_canny_high": 150,
"edge_hough_threshold": 80,
"edge_hough_min_length": 100,
"edge_hough_max_gap": 10,
"edge_pair_max_distance": 200,
"edge_pair_min_distance": 15
},
"detect_objects": {
"model_name": "yolov8n.pt",
"confidence_threshold": 0.3,
"target_classes": []
},
"run_ocr": {
"languages": [
"en",
"es"
],
"min_confidence": 0.5
},
"match_brands": {
"fuzzy_threshold": 75
},
"escalate_vlm": {
"vlm_prompt_template": "Identify the brand or sponsor visible in this cropped region from a soccer broadcast.{hint}{text} Respond with: brand, confidence (0-1), reasoning."
}
}
}
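For reference, the edges above flatten to a single serial order; a hedged sketch using Kahn's algorithm, mirroring what the runner later does when it flattens a PipelineConfig (the fixture path is an assumption — only the fixtures directory is fixed by seed.py below):

import json
from collections import deque
from pathlib import Path

data = json.loads(Path("core/db/fixtures/soccer_broadcast.json").read_text())  # path assumed
stages = [s["name"] for s in data["pipeline"]["stages"]]
edges = [(e["source"], e["target"]) for e in data["pipeline"]["edges"]]

indeg = {s: 0 for s in stages}
adj = {s: [] for s in stages}
for src, dst in edges:
    adj[src].append(dst)
    indeg[dst] += 1

order, ready = [], deque(s for s in stages if indeg[s] == 0)
while ready:
    node = ready.popleft()
    order.append(node)
    for nxt in adj[node]:
        indeg[nxt] -= 1
        if indeg[nxt] == 0:
            ready.append(nxt)

# order matches the stage list above: the two branch stages (detect_edges, detect_objects)
# both feed preprocess, so they run back to back before the shared tail.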

View File

@@ -7,7 +7,7 @@ from uuid import UUID
from sqlmodel import Session, select
from .tables import Job
from .models import Job
def list_jobs(session: Session, parent_id: Optional[UUID] = None, status: Optional[str] = None) -> list[Job]:

View File

@@ -44,7 +44,7 @@ class SourceType(str, Enum):
class MediaAsset(SQLModel, table=True):
"""A video/audio file registered in the system."""
__tablename__ = "media_assets"
__tablename__ = "media_asset"
id: UUID = Field(default_factory=uuid4, primary_key=True)
filename: str
@@ -67,7 +67,7 @@ class MediaAsset(SQLModel, table=True):
class TranscodePreset(SQLModel, table=True):
"""A reusable transcoding configuration (like Handbrake presets)."""
__tablename__ = "transcode_presets"
__tablename__ = "transcode_preset"
id: UUID = Field(default_factory=uuid4, primary_key=True)
name: str
@@ -90,12 +90,13 @@ class TranscodePreset(SQLModel, table=True):
class Job(SQLModel, table=True):
"""A pipeline job."""
__tablename__ = "jobs"
__tablename__ = "job"
id: UUID = Field(default_factory=uuid4, primary_key=True)
source_asset_id: UUID = Field(index=True)
video_path: str
profile_name: str = "soccer_broadcast"
timeline_id: Optional[UUID] = None
parent_id: Optional[UUID] = None
run_type: RunType = "initial"
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
@@ -107,7 +108,6 @@ class Job(SQLModel, table=True):
brands_found: int = 0
cloud_llm_calls: int = 0
estimated_cost_usd: float = 0.0
celery_task_id: Optional[str] = None
priority: int = 0
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
started_at: Optional[datetime] = None
@@ -115,7 +115,7 @@ class Job(SQLModel, table=True):
class Timeline(SQLModel, table=True):
"""The frame sequence from a source video."""
__tablename__ = "timelines"
__tablename__ = "timeline"
id: UUID = Field(default_factory=uuid4, primary_key=True)
source_asset_id: Optional[UUID] = Field(default=None, index=True)
@@ -129,10 +129,11 @@ class Timeline(SQLModel, table=True):
class Checkpoint(SQLModel, table=True):
"""A snapshot of pipeline state on a timeline."""
__tablename__ = "checkpoints"
__tablename__ = "checkpoint"
id: UUID = Field(default_factory=uuid4, primary_key=True)
timeline_id: UUID
job_id: Optional[UUID] = Field(default=None, index=True)
parent_id: Optional[UUID] = None
stage_outputs: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
@@ -143,7 +144,7 @@ class Checkpoint(SQLModel, table=True):
class Brand(SQLModel, table=True):
"""A brand discovered or registered in the system."""
__tablename__ = "brands"
__tablename__ = "brand"
id: UUID = Field(default_factory=uuid4, primary_key=True)
canonical_name: str = Field(index=True)
@@ -154,3 +155,12 @@ class Brand(SQLModel, table=True):
total_airings: int = 0
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
class Profile(SQLModel, table=True):
"""A content type profile."""
__tablename__ = "profile"
id: UUID = Field(default_factory=uuid4, primary_key=True)
name: str
pipeline: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
configs: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))

43
core/db/seed.py Normal file
View File

@@ -0,0 +1,43 @@
"""
Seed data — insert initial profile rows if they don't exist.
Called on startup after create_tables().
"""
import json
import logging
from pathlib import Path
logger = logging.getLogger(__name__)
SEED_DIR = Path(__file__).parent / "fixtures"
def seed_profiles():
"""Insert seed profiles from JSON fixtures if not already present."""
from .connection import get_session
from .models import Profile
fixtures = list(SEED_DIR.glob("*.json"))
if not fixtures:
return
with get_session() as session:
for f in fixtures:
data = json.loads(f.read_text())
name = data["name"]
existing = session.query(Profile).filter(Profile.name == name).first()
if existing:
logger.debug("Profile %s already exists, skipping seed", name)
continue
profile = Profile(
name=name,
pipeline=data.get("pipeline", {}),
configs=data.get("configs", {}),
)
session.add(profile)
logger.info("Seeded profile: %s", name)
session.commit()
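The docstring above pins the intended startup order; a minimal sketch of that wiring (the actual startup hook is not part of this diff):

from core.db.connection import create_tables
from core.db.seed import seed_profiles

def on_startup():
    create_tables()    # create missing tables first
    seed_profiles()    # then insert fixture profiles that are not already present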

View File

@@ -1,96 +0,0 @@
"""
SQLModel table definitions.
Generated by modelgen from core/schema/models/. Do not edit directly.
"""
from __future__ import annotations
from datetime import datetime
from typing import Any, Dict, List, Optional
from uuid import UUID, uuid4
from sqlalchemy import JSON
from sqlmodel import Column, Field, SQLModel
class MediaAsset(SQLModel, table=True):
__tablename__ = "media_asset"
id: UUID = Field(default_factory=uuid4, primary_key=True)
filename: str
path: str
status: str = "pending"
size_bytes: int = 0
duration_seconds: float = 0.0
width: Optional[int] = None
height: Optional[int] = None
fps: Optional[float] = None
codec: Optional[str] = None
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
class Job(SQLModel, table=True):
__tablename__ = "job"
id: UUID = Field(default_factory=uuid4, primary_key=True)
source_asset_id: UUID = Field(index=True)
video_path: str
profile_name: str = "soccer_broadcast"
parent_id: Optional[UUID] = Field(default=None, index=True)
run_type: str = "initial"
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
status: str = "pending"
current_stage: Optional[str] = None
progress: float = 0.0
error_message: Optional[str] = None
total_detections: int = 0
brands_found: int = 0
cloud_llm_calls: int = 0
estimated_cost_usd: float = 0.0
priority: int = 0
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
class Timeline(SQLModel, table=True):
__tablename__ = "timeline"
id: UUID = Field(default_factory=uuid4, primary_key=True)
source_asset_id: Optional[UUID] = Field(default=None, index=True)
source_video: str = ""
profile_name: str = ""
fps: float = 2.0
frames_prefix: str = ""
frames_manifest: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
frames_meta: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
class Checkpoint(SQLModel, table=True):
__tablename__ = "checkpoint"
id: UUID = Field(default_factory=uuid4, primary_key=True)
timeline_id: UUID = Field(index=True)
parent_id: Optional[UUID] = Field(default=None, index=True)
stage_outputs: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
stats: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
is_scenario: bool = False
scenario_label: str = ""
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
class Brand(SQLModel, table=True):
__tablename__ = "brand"
id: UUID = Field(default_factory=uuid4, primary_key=True)
canonical_name: str = Field(index=True)
aliases: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
source: str = "ocr"
confirmed: bool = False
airings: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
total_airings: int = 0
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow)

0
core/detect/__init__.py Normal file
View File

View File

@@ -0,0 +1,19 @@
"""
Checkpoint system — Timeline + Checkpoint tree.
detect/checkpoint/
frames.py — frame image S3 upload/download
storage.py — Timeline + Checkpoint (Postgres + MinIO)
replay.py — replay (TODO: migrate to new model)
runner_bridge.py — checkpoint hook for PipelineRunner
"""
from .storage import (
create_timeline,
get_timeline_frames,
get_timeline_frames_b64,
save_stage_output,
load_stage_output,
)
from .frames import save_frames, load_frames
from .runner_bridge import checkpoint_after_stage, reset_checkpoint_state, get_timeline_id

View File

@@ -0,0 +1,111 @@
"""Frame image storage — save/load to S3/MinIO as JPEGs."""
from __future__ import annotations
import logging
import os
import tempfile
import numpy as np
from PIL import Image
from core.detect.models import Frame
logger = logging.getLogger(__name__)
BUCKET = os.environ.get("S3_BUCKET", "mpr")
CHECKPOINT_PREFIX = "checkpoints"
def save_frames(job_id: str, frames: list[Frame]) -> dict[int, str]:
"""
Save frame images to S3 as JPEGs.
Returns manifest: {sequence: s3_key}
"""
from core.storage.s3 import upload_file
manifest = {}
for frame in frames:
key = f"{CHECKPOINT_PREFIX}/{job_id}/frames/{frame.sequence}.jpg"
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
img = Image.fromarray(frame.image)
img.save(tmp, format="JPEG", quality=85)
tmp_path = tmp.name
try:
upload_file(tmp_path, BUCKET, key)
finally:
os.unlink(tmp_path)
manifest[frame.sequence] = key
logger.info("Saved %d frames to s3://%s/%s/%s/frames/",
len(frames), BUCKET, CHECKPOINT_PREFIX, job_id)
return manifest
def load_frames(manifest: dict[int, str], frame_metadata: list[dict]) -> list[Frame]:
"""
Load frame images from S3 and reconstitute Frame objects.
frame_metadata: list of dicts with sequence, chunk_id, timestamp, perceptual_hash.
"""
from core.storage.s3 import download_to_temp
meta_map = {m["sequence"]: m for m in frame_metadata}
frames = []
for seq, key in manifest.items():
tmp_path = download_to_temp(BUCKET, key)
try:
img = Image.open(tmp_path).convert("RGB")
image_array = np.array(img)
finally:
os.unlink(tmp_path)
meta = meta_map.get(seq, {})
frame = Frame(
sequence=seq,
chunk_id=meta.get("chunk_id", 0),
timestamp=meta.get("timestamp", 0.0),
image=image_array,
perceptual_hash=meta.get("perceptual_hash", ""),
)
frames.append(frame)
frames.sort(key=lambda f: f.sequence)
return frames
def load_frames_b64(manifest: dict[int, str], frame_metadata: list[dict]) -> list[dict]:
"""
Load frame images from S3 as base64 JPEG — lightweight, no numpy.
Returns list of dicts: {seq, timestamp, jpeg_b64}
"""
import base64
from core.storage.s3 import download_to_temp
meta_map = {m["sequence"]: m for m in frame_metadata}
frames = []
for seq, key in manifest.items():
tmp_path = download_to_temp(BUCKET, key)
try:
with open(tmp_path, "rb") as f:
jpeg_bytes = f.read()
finally:
os.unlink(tmp_path)
meta = meta_map.get(seq, {})
frames.append({
"seq": seq,
"timestamp": meta.get("timestamp", 0.0),
"jpeg_b64": base64.b64encode(jpeg_bytes).decode(),
})
frames.sort(key=lambda f: f["seq"])
return frames
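A hedged round-trip sketch for the two helpers above, assuming a reachable S3/MinIO bucket and the Frame fields used in load_frames:

import numpy as np
from core.detect.checkpoint import save_frames, load_frames
from core.detect.models import Frame

frames = [
    Frame(sequence=i, chunk_id=0, timestamp=i / 2.0,
          image=np.zeros((720, 1280, 3), dtype=np.uint8), perceptual_hash="")
    for i in range(3)
]
manifest = save_frames("job-123", frames)   # {0: "checkpoints/job-123/frames/0.jpg", ...}
meta = [{"sequence": f.sequence, "chunk_id": f.chunk_id,
         "timestamp": f.timestamp, "perceptual_hash": f.perceptual_hash} for f in frames]
restored = load_frames(manifest, meta)      # Frame objects again, sorted by sequence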

View File

@@ -0,0 +1,232 @@
"""
Pipeline replay — re-run from any stage with different config.
Loads a checkpoint, applies config overrides, builds a subgraph
starting from the target stage, and invokes it.
"""
from __future__ import annotations
import logging
import uuid
from core.detect import emit
# TODO: migrate to Timeline/Branch/Checkpoint model
# These old functions no longer exist — replay needs rework
def _not_migrated(*args, **kwargs):
raise NotImplementedError("Replay not yet migrated to Timeline/Branch/Checkpoint model")
load_checkpoint = _not_migrated
list_checkpoints = _not_migrated
from core.detect.graph import NODES, build_graph
logger = logging.getLogger(__name__)
# OverrideProfile removed — config overrides are now handled by dict merging
# in _load_profile() (nodes.py) and replay_single_stage (below).
def replay_from(
job_id: str,
start_stage: str,
config_overrides: dict | None = None,
checkpoint: bool = True,
) -> dict:
"""
Replay the pipeline from a specific stage.
Loads the checkpoint from the stage immediately before start_stage,
applies config overrides, and runs the subgraph from start_stage onward.
Returns the final state dict.
"""
if start_stage not in NODES:
raise ValueError(f"Unknown stage: {start_stage!r}. Options: {NODES}")
start_idx = NODES.index(start_stage)
# Load checkpoint from the stage before start_stage
if start_idx == 0:
raise ValueError("Cannot replay from the first stage — just run the full pipeline")
previous_stage = NODES[start_idx - 1]
available = list_checkpoints(job_id)
if previous_stage not in available:
raise ValueError(
f"No checkpoint for stage {previous_stage!r} (job {job_id}). "
f"Available: {available}"
)
logger.info("Replaying job %s from %s (loading checkpoint: %s)",
job_id, start_stage, previous_stage)
state = load_checkpoint(job_id, previous_stage)
# Apply config overrides
if config_overrides:
state["config_overrides"] = config_overrides
# Set run context for SSE events
run_id = str(uuid.uuid4())[:8]
emit.set_run_context(
run_id=run_id,
parent_job_id=job_id,
run_type="replay",
)
# Build subgraph starting from start_stage
graph = build_graph(checkpoint=checkpoint, start_from=start_stage)
pipeline = graph.compile()
try:
result = pipeline.invoke(state)
finally:
emit.clear_run_context()
return result
def replay_single_stage(
job_id: str,
stage: str,
frame_refs: list[int] | None = None,
config_overrides: dict | None = None,
debug: bool = False,
) -> dict:
"""
Replay a single stage on specific frames (or all frames from checkpoint).
Fast path for interactive parameter tuning — runs only the target stage
function, not the full pipeline tail. Returns the stage output directly.
When debug=True and stage is detect_edges, returns additional overlay
data (Canny edges, Hough lines) for visual feedback in the editor.
For detect_edges: returns {"edge_regions_by_frame": {seq: [box, ...]}}
With debug=True, also returns {"debug": {seq: {edge_overlay_b64, lines_overlay_b64, ...}}}
"""
if stage not in NODES:
raise ValueError(f"Unknown stage: {stage!r}. Options: {NODES}")
stage_idx = NODES.index(stage)
if stage_idx == 0:
raise ValueError("Cannot replay the first stage — just run the full pipeline")
previous_stage = NODES[stage_idx - 1]
available = list_checkpoints(job_id)
if previous_stage not in available:
raise ValueError(
f"No checkpoint for stage {previous_stage!r} (job {job_id}). "
f"Available: {available}"
)
logger.info("Single-stage replay: job %s, stage %s (loading checkpoint: %s, debug=%s)",
job_id, stage, previous_stage, debug)
state = load_checkpoint(job_id, previous_stage)
# Build profile with overrides
from core.detect.profile import get_profile, get_stage_config
profile = get_profile(state.get("profile_name", "soccer_broadcast"))
if config_overrides:
merged_configs = dict(profile.get("configs", {}))
for sname, soverrides in config_overrides.items():
if sname in merged_configs:
merged_configs[sname] = {**merged_configs[sname], **soverrides}
else:
merged_configs[sname] = soverrides
profile = {**profile, "configs": merged_configs}
# Run the stage function directly (not through the graph)
if stage == "detect_edges":
return _replay_detect_edges(state, profile, frame_refs, job_id, debug)
else:
raise ValueError(
f"Single-stage replay not yet implemented for {stage!r}. "
f"Use replay_from() for full pipeline replay."
)
def _replay_detect_edges(
state: dict,
profile,
frame_refs: list[int] | None,
job_id: str,
debug: bool,
) -> dict:
"""Run edge detection on checkpoint frames, optionally with debug overlays."""
import os
from core.detect.stages.edge_detector import detect_edge_regions
from core.detect.profile import get_stage_config
from core.detect.stages.models import RegionAnalysisConfig
config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
frames = state.get("filtered_frames", [])
if frame_refs:
ref_set = set(frame_refs)
frames = [f for f in frames if f.sequence in ref_set]
inference_url = os.environ.get("INFERENCE_URL")
# Normal run — always needed for the boxes
result = detect_edge_regions(
frames=frames,
config=config,
inference_url=inference_url,
job_id=job_id,
)
output = {"edge_regions_by_frame": result}
# Debug overlays — call debug endpoint (remote) or local debug function
if debug and frames:
debug_data = {}
if inference_url:
from core.detect.inference import InferenceClient
client = InferenceClient(base_url=inference_url, job_id=job_id)
for frame in frames:
dr = client.detect_edges_debug(
image=frame.image,
edge_canny_low=config.edge_canny_low,
edge_canny_high=config.edge_canny_high,
edge_hough_threshold=config.edge_hough_threshold,
edge_hough_min_length=config.edge_hough_min_length,
edge_hough_max_gap=config.edge_hough_max_gap,
edge_pair_max_distance=config.edge_pair_max_distance,
edge_pair_min_distance=config.edge_pair_min_distance,
)
debug_data[frame.sequence] = {
"edge_overlay_b64": dr.edge_overlay_b64,
"lines_overlay_b64": dr.lines_overlay_b64,
"horizontal_count": dr.horizontal_count,
"pair_count": dr.pair_count,
}
else:
# Local mode — import GPU module directly
from core.detect.stages.edge_detector import _load_cv_edges
edges_mod = _load_cv_edges()
for frame in frames:
dr = edges_mod.detect_edges_debug(
frame.image,
canny_low=config.edge_canny_low,
canny_high=config.edge_canny_high,
hough_threshold=config.edge_hough_threshold,
hough_min_length=config.edge_hough_min_length,
hough_max_gap=config.edge_hough_max_gap,
pair_max_distance=config.edge_pair_max_distance,
pair_min_distance=config.edge_pair_min_distance,
)
debug_data[frame.sequence] = {
"edge_overlay_b64": dr["edge_overlay_b64"],
"lines_overlay_b64": dr["lines_overlay_b64"],
"horizontal_count": dr["horizontal_count"],
"pair_count": dr["pair_count"],
}
output["debug"] = debug_data
return output
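The override handling here (and in _load_profile in nodes.py) is a per-stage shallow merge; a small worked example of what it produces:

profile_configs = {"detect_edges": {"edge_canny_low": 50, "edge_canny_high": 150}}
config_overrides = {"detect_edges": {"edge_canny_low": 80},
                    "run_ocr": {"min_confidence": 0.7}}

merged = dict(profile_configs)
for stage, overrides in config_overrides.items():
    merged[stage] = {**merged.get(stage, {}), **overrides}

# merged == {"detect_edges": {"edge_canny_low": 80, "edge_canny_high": 150},
#            "run_ocr": {"min_confidence": 0.7}}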

View File

@@ -0,0 +1,97 @@
"""
Runner bridge — checkpoint hook called by PipelineRunner after each stage.
Owns the per-job state (timeline, frame manifest, checkpoint chain) that
the runner shouldn't know about.
Timeline and Job are independent entities:
- One Timeline can serve multiple Jobs (re-run with different params)
- One Job operates on one Timeline (set after frame extraction)
- Checkpoints belong to Timeline, tagged with the Job that created them
"""
from __future__ import annotations
import logging
logger = logging.getLogger(__name__)
# Per-job state
_timeline_id: dict[str, str] = {}
_frames_manifest: dict[str, dict[int, str]] = {}
_latest_checkpoint: dict[str, str] = {}
def reset_checkpoint_state(job_id: str):
"""Clean up per-job checkpoint state. Called when pipeline finishes."""
_timeline_id.pop(job_id, None)
_frames_manifest.pop(job_id, None)
_latest_checkpoint.pop(job_id, None)
def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: dict):
"""
Save a checkpoint after a stage completes.
Called by the runner. Handles:
- Timeline creation (once, on extract_frames)
- Frame upload (via create_timeline)
- Stage output serialization (via stage registry)
- Checkpoint chain (parent → child)
"""
if not job_id:
return
from .storage import create_timeline, save_stage_output
from core.detect.stages.base import _REGISTRY
merged = {**state, **result}
# On extract_frames: create Timeline + upload frames + root checkpoint
if stage_name == "extract_frames" and job_id not in _timeline_id:
frames = merged.get("frames", [])
video_path = merged.get("video_path", "")
profile_name = merged.get("profile_name", "")
tid, cid = create_timeline(
source_video=video_path,
profile_name=profile_name,
frames=frames,
)
_timeline_id[job_id] = tid
_latest_checkpoint[job_id] = cid
logger.info("Job %s → Timeline %s (root checkpoint %s)", job_id, tid, cid)
# Emit timeline_id via SSE so the UI can use it for checkpoint loads
from core.detect import emit
emit.log(job_id, "Checkpoint", "INFO", f"timeline_id={tid}")
return
# For subsequent stages: save checkpoint on the timeline
tid = _timeline_id.get(job_id)
if not tid:
logger.warning("No timeline for job %s, skipping checkpoint", job_id)
return
# Serialize stage output using the stage's serialize_fn if available
stage_cls = _REGISTRY.get(stage_name)
serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
if serialize_fn:
output_json = serialize_fn(merged, job_id)
else:
output_json = {}
parent_id = _latest_checkpoint.get(job_id)
new_checkpoint_id = save_stage_output(
timeline_id=tid,
parent_checkpoint_id=parent_id,
stage_name=stage_name,
output_json=output_json,
job_id=job_id,
)
_latest_checkpoint[job_id] = new_checkpoint_id
def get_timeline_id(job_id: str) -> str | None:
"""Get the timeline_id for a running job. Used by the UI to load checkpoints."""
return _timeline_id.get(job_id)
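A hedged sketch of the call sequence the runner drives through this hook (requires Postgres and MinIO; the state/result values are placeholders shaped after nodes.py):

from core.detect.checkpoint import checkpoint_after_stage, get_timeline_id, reset_checkpoint_state

state = {"video_path": "match.mp4", "profile_name": "soccer_broadcast"}  # assumed inputs
frames = []                                       # placeholder; real Frame objects in practice

checkpoint_after_stage("job-1", "extract_frames", state, {"frames": frames})
print(get_timeline_id("job-1"))                   # timeline created by the call above
reset_checkpoint_state("job-1")                   # clean up once the pipeline finishes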

View File

@@ -0,0 +1,108 @@
"""
State serialization — DetectState ↔ JSON-compatible dict.
Delegates to each stage's serialize_fn/deserialize_fn via the registry.
This file has no model-specific knowledge — stages own their data format.
The only things serialized here are the "envelope" fields (job_id, video_path, etc.)
that don't belong to any stage.
"""
from __future__ import annotations
from core.schema.serializers._common import serialize_dataclass
from core.schema.serializers.pipeline import (
deserialize_pipeline_stats,
deserialize_text_candidates,
)
# Envelope fields — not owned by any stage, always present
ENVELOPE_KEYS = ["job_id", "video_path", "profile_name", "config_overrides"]
def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
"""
Serialize DetectState to a JSON-compatible dict.
Calls each registered stage's serialize_fn for stage-owned data.
Envelope fields (job_id, etc.) are copied directly.
"""
from core.detect.stages.base import _REGISTRY
checkpoint = {}
# Envelope
for key in ENVELOPE_KEYS:
default = {} if key == "config_overrides" else ""
checkpoint[key] = state.get(key, default)
# Frames manifest (needed by frame-loading stages)
checkpoint["frames_manifest"] = {str(k): v for k, v in frames_manifest.items()}
# Stats (shared across stages, not owned by one)
stats = state.get("stats")
if stats is not None:
checkpoint["stats"] = serialize_dataclass(stats)
else:
checkpoint["stats"] = {}
# Per-stage data
for name, stage_def in _REGISTRY.items():
if stage_def.serialize_fn is None:
continue
job_id = state.get("job_id", "")
stage_data = stage_def.serialize_fn(state, job_id)
checkpoint[f"stage_{name}"] = stage_data
return checkpoint
def deserialize_state(checkpoint: dict, frames: list) -> dict:
"""
Reconstitute DetectState from a checkpoint dict + loaded frames.
Calls each stage's deserialize_fn to restore stage-owned data.
"""
from core.detect.stages.base import _REGISTRY
frame_map = {f.sequence: f for f in frames}
state = {}
# Envelope
for key in ENVELOPE_KEYS:
default = {} if key == "config_overrides" else ""
state[key] = checkpoint.get(key, default)
# Frames (always present, loaded externally)
state["frames"] = frames
# Stats
state["stats"] = deserialize_pipeline_stats(checkpoint.get("stats", {}))
# Per-stage data
for name, stage_def in _REGISTRY.items():
if stage_def.deserialize_fn is None:
continue
stage_key = f"stage_{name}"
if stage_key not in checkpoint:
continue
job_id = state.get("job_id", "")
stage_data = stage_def.deserialize_fn(checkpoint[stage_key], job_id)
for k, v in stage_data.items():
if k == "_filtered_sequences":
# Reconnect filtered frames from sequence list
seq_set = set(v)
state["filtered_frames"] = [f for f in frames if f.sequence in seq_set]
elif k.endswith("_raw"):
# Raw text candidates need frame reference reconnection
real_key = k.removeprefix("_").removesuffix("_raw")
state[real_key] = deserialize_text_candidates(v, frame_map)
else:
state[k] = v
return state
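For orientation, the dict produced by serialize_state looks roughly like this; the envelope keys are fixed above, while each stage_<name> block is whatever that stage's serialize_fn returns (the example payloads are assumptions):

checkpoint = {
    "job_id": "job-1",
    "video_path": "match.mp4",
    "profile_name": "soccer_broadcast",
    "config_overrides": {},
    "frames_manifest": {"0": "checkpoints/job-1/frames/0.jpg"},
    "stats": {"frames_extracted": 1},
    "stage_filter_scenes": {"_filtered_sequences": [0]},  # stage-owned payload (assumed shape)
    "stage_run_ocr": {},                                  # stage-owned payload, opaque here
}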

View File

@@ -0,0 +1,177 @@
"""
Checkpoint storage — Timeline + Checkpoint (tree of snapshots).
Timeline: frame sequence from source video (frames in MinIO)
Checkpoint: snapshot of pipeline state (stage outputs as JSONB in Postgres)
parent_id forms a tree — multiple children = different config tries
"""
from __future__ import annotations
import logging
from uuid import UUID
from .frames import save_frames, load_frames, CHECKPOINT_PREFIX
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Timeline
# ---------------------------------------------------------------------------
def create_timeline(
source_video: str,
profile_name: str,
frames: list,
fps: float = 2.0,
source_asset_id: UUID | None = None,
) -> tuple[str, str]:
"""
Create a timeline from frames. Uploads frame images to MinIO,
creates Timeline + root Checkpoint in Postgres.
Returns (timeline_id, checkpoint_id).
"""
from core.db.models import Timeline, Checkpoint
from core.db.connection import get_session
with get_session() as session:
timeline = Timeline(
source_video=source_video,
profile_name=profile_name,
source_asset_id=source_asset_id,
fps=fps,
)
session.add(timeline)
session.flush()
tid = str(timeline.id)
# Upload frames to MinIO
manifest = save_frames(tid, frames)
frames_meta = [
{
"sequence": f.sequence,
"chunk_id": getattr(f, "chunk_id", 0),
"timestamp": f.timestamp,
"perceptual_hash": getattr(f, "perceptual_hash", ""),
}
for f in frames
]
timeline.frames_prefix = f"{CHECKPOINT_PREFIX}/{tid}/frames/"
timeline.frames_manifest = {str(k): v for k, v in manifest.items()}
timeline.frames_meta = frames_meta
checkpoint = Checkpoint(
timeline_id=timeline.id,
parent_id=None,
stage_outputs={},
stats={"frames_extracted": len(frames)},
)
session.add(checkpoint)
session.commit()
session.refresh(checkpoint)
cid = str(checkpoint.id)
logger.info("Timeline created: %s (%d frames, root checkpoint %s)", tid, len(frames), cid)
return tid, cid
def get_timeline_frames(timeline_id: str) -> list:
"""Load frames from a timeline (from MinIO) as Frame objects."""
from core.db.models import Timeline
from core.db.connection import get_session
with get_session() as session:
timeline = session.get(Timeline, UUID(timeline_id))
if not timeline:
raise ValueError(f"Timeline not found: {timeline_id}")
raw_manifest = timeline.frames_manifest or {}
manifest = {int(k): v for k, v in raw_manifest.items()}
return load_frames(manifest, timeline.frames_meta or [])
def get_timeline_frames_b64(timeline_id: str) -> list[dict]:
"""Load frames as base64 JPEG (lightweight, no numpy)."""
from core.db.models import Timeline
from core.db.connection import get_session
from .frames import load_frames_b64
with get_session() as session:
timeline = session.get(Timeline, UUID(timeline_id))
if not timeline:
raise ValueError(f"Timeline not found: {timeline_id}")
raw_manifest = timeline.frames_manifest or {}
manifest = {int(k): v for k, v in raw_manifest.items()}
return load_frames_b64(manifest, timeline.frames_meta or [])
# ---------------------------------------------------------------------------
# Checkpoint
# ---------------------------------------------------------------------------
def save_stage_output(
timeline_id: str,
parent_checkpoint_id: str | None,
stage_name: str,
output_json: dict,
config_overrides: dict | None = None,
stats: dict | None = None,
is_scenario: bool = False,
scenario_label: str = "",
job_id: str | None = None,
) -> str:
"""
Save a stage's output as a new checkpoint (child of parent).
Carries forward stage outputs from parent + adds the new one.
Returns the new checkpoint ID.
"""
from core.db.models import Checkpoint
from core.db.connection import get_session
with get_session() as session:
parent_outputs = {}
parent_stats = {}
parent_config = {}
if parent_checkpoint_id:
parent = session.get(Checkpoint, UUID(parent_checkpoint_id))
if parent:
parent_outputs = dict(parent.stage_outputs or {})
parent_stats = dict(parent.stats or {})
parent_config = dict(parent.config_overrides or {})
checkpoint = Checkpoint(
timeline_id=UUID(timeline_id),
job_id=UUID(job_id) if job_id else None,
parent_id=UUID(parent_checkpoint_id) if parent_checkpoint_id else None,
stage_outputs={**parent_outputs, stage_name: output_json},
config_overrides={**parent_config, **(config_overrides or {})},
stats={**parent_stats, **(stats or {})},
is_scenario=is_scenario,
scenario_label=scenario_label,
)
session.add(checkpoint)
session.commit()
session.refresh(checkpoint)
cid = str(checkpoint.id)
logger.info("Checkpoint saved: %s (timeline %s, stage %s, parent %s)",
cid, timeline_id, stage_name, parent_checkpoint_id)
return cid
def load_stage_output(checkpoint_id: str, stage_name: str) -> dict | None:
"""Load a stage's output from a checkpoint."""
from core.db.models import Checkpoint
from core.db.connection import get_session
with get_session() as session:
checkpoint = session.get(Checkpoint, UUID(checkpoint_id))
if not checkpoint:
return None
return (checkpoint.stage_outputs or {}).get(stage_name)
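A hedged sketch of the tree semantics: each child carries its parent's stage outputs forward, so the newest checkpoint answers for every earlier stage (requires Postgres; the output_json payloads are placeholders):

from core.detect.checkpoint import create_timeline, save_stage_output, load_stage_output

tid, root = create_timeline("match.mp4", "soccer_broadcast", frames=[])  # empty frames for brevity
c1 = save_stage_output(tid, root, "filter_scenes", {"kept": [0, 3]})
c2 = save_stage_output(tid, c1, "detect_edges", {"edge_regions_by_frame": {}})

print(load_stage_output(c2, "filter_scenes"))   # {"kept": [0, 3]} — carried forward from c1
print(load_stage_output(c2, "detect_edges"))    # {"edge_regions_by_frame": {}}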

159
core/detect/emit.py Normal file
View File

@@ -0,0 +1,159 @@
"""
Event emission helpers for detection pipeline stages.
Single place that knows how to build event payloads.
Stages call these instead of constructing dicts or dataclasses directly.
Run context (run_id, parent_job_id) is set once at pipeline start via
set_run_context() and automatically injected into all events.
Log level is set per-run with optional per-stage overrides.
DEBUG events are only pushed when the run (or stage) log level allows it.
"""
from __future__ import annotations
import dataclasses
from datetime import datetime, timezone
from core.detect.events import push_detect_event
from core.detect.models import PipelineStats
# Log level ordering for comparison
_LEVEL_ORDER = {"DEBUG": 0, "INFO": 1, "WARN": 2, "ERROR": 3}
# Module-level run context — set once per pipeline invocation
_run_context: dict = {}
_run_log_level: str = "INFO"
_stage_log_levels: dict[str, str] = {} # stage_name → level override
def set_run_context(
run_id: str = "",
parent_job_id: str = "",
run_type: str = "initial",
log_level: str = "INFO",
):
"""Set the run context for all subsequent events in this pipeline invocation."""
global _run_context, _run_log_level
_run_context = {
"run_id": run_id,
"parent_job_id": parent_job_id,
"run_type": run_type,
}
_run_log_level = log_level.upper()
_stage_log_levels.clear()
def set_stage_log_level(stage: str, level: str):
"""Override log level for a specific stage."""
_stage_log_levels[stage] = level.upper()
def clear_stage_log_level(stage: str):
"""Remove per-stage log level override."""
_stage_log_levels.pop(stage, None)
def clear_run_context():
global _run_context, _run_log_level
_run_context = {}
_run_log_level = "INFO"
_stage_log_levels.clear()
def _should_emit(level: str, stage: str) -> bool:
"""Check if this log level should be emitted given run/stage settings."""
effective = _stage_log_levels.get(stage, _run_log_level)
return _LEVEL_ORDER.get(level.upper(), 1) >= _LEVEL_ORDER.get(effective, 1)
def _inject_context(payload: dict) -> dict:
"""Add run context fields to an event payload."""
if _run_context:
payload.update(_run_context)
return payload
def log(job_id: str | None, stage: str, level: str, msg: str) -> None:
if not job_id:
return
if not _should_emit(level, stage):
return
payload = {
"level": level,
"stage": stage,
"msg": msg,
"ts": datetime.now(timezone.utc).isoformat(),
}
_inject_context(payload)
push_detect_event(job_id, "log", payload)
def stats(job_id: str | None, **kwargs) -> None:
if not job_id:
return
s = PipelineStats(**kwargs)
payload = dataclasses.asdict(s)
_inject_context(payload)
push_detect_event(job_id, "stats_update", payload)
def frame_update(
job_id: str | None,
frame_ref: int,
timestamp: float,
jpeg_b64: str,
boxes: list[dict],
) -> None:
if not job_id:
return
payload = {
"frame_ref": frame_ref,
"timestamp": timestamp,
"jpeg_b64": jpeg_b64,
"boxes": boxes,
}
_inject_context(payload)
push_detect_event(job_id, "frame_update", payload)
def graph_update(job_id: str | None, nodes: list[dict]) -> None:
if not job_id:
return
payload = {"nodes": nodes}
_inject_context(payload)
push_detect_event(job_id, "graph_update", payload)
def detection(
job_id: str | None,
brand: str,
confidence: float,
source: str,
timestamp: float,
duration: float = 0.0,
content_type: str = "",
frame_ref: int | None = None,
) -> None:
if not job_id:
return
payload = {
"brand": brand,
"confidence": confidence,
"source": source,
"timestamp": timestamp,
"duration": duration,
"content_type": content_type,
"frame_ref": frame_ref,
}
_inject_context(payload)
push_detect_event(job_id, "detection", payload)
def job_complete(job_id: str | None, report: dict) -> None:
if not job_id:
return
payload = {"job_id": job_id, "report": report}
_inject_context(payload)
push_detect_event(job_id, "job_complete", payload)
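A hedged usage sketch of the level filtering above (events end up in Redis via push_detect_event, so this assumes a reachable Redis):

from core.detect import emit

emit.set_run_context(run_id="r1", parent_job_id="job-1", run_type="replay", log_level="INFO")
emit.set_stage_log_level("detect_edges", "DEBUG")

emit.log("job-1", "detect_edges", "DEBUG", "per-frame timing ...")  # emitted: stage override allows DEBUG
emit.log("job-1", "run_ocr", "DEBUG", "raw OCR payload ...")        # suppressed: run level is INFO
emit.log("job-1", "run_ocr", "INFO", "42 text candidates")          # emitted
emit.clear_run_context()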

42
core/detect/events.py Normal file
View File

@@ -0,0 +1,42 @@
"""
Detection pipeline event helpers.
Non-generated runtime code for pushing SSE events.
The event payload types are in sse_contract.py (generated by modelgen).
"""
from pydantic import BaseModel
from core.events import push_event
DETECT_EVENTS_PREFIX = "detect_events"
# SSE event type names
EVENT_GRAPH_UPDATE = "graph_update"
EVENT_STATS_UPDATE = "stats_update"
EVENT_FRAME_UPDATE = "frame_update"
EVENT_DETECTION = "detection"
EVENT_LOG = "log"
EVENT_JOB_COMPLETE = "job_complete"
ALL_EVENT_TYPES = [
EVENT_GRAPH_UPDATE,
EVENT_STATS_UPDATE,
EVENT_FRAME_UPDATE,
EVENT_DETECTION,
EVENT_LOG,
EVENT_JOB_COMPLETE,
]
TERMINAL_EVENTS = [EVENT_JOB_COMPLETE]
def push_detect_event(job_id: str, event_type: str, data: BaseModel | dict) -> None:
"""Push a detection event to Redis. Accepts Pydantic models or plain dicts."""
payload = data.model_dump(mode="json") if isinstance(data, BaseModel) else data
push_event(
job_id=job_id,
event_type=event_type,
data=payload,
prefix=DETECT_EVENTS_PREFIX,
)

View File

@@ -0,0 +1,45 @@
"""
Detection pipeline graph.
detect/graph/
nodes.py — node functions (one per stage)
events.py — graph_update SSE emission
runner.py — PipelineRunner (config-driven, checkpoint, cancel, pause)
"""
from .nodes import NODES, NODE_FUNCTIONS
from .runner import (
PipelineCancelled,
PipelineRunner,
build_graph,
clear_cancel_check,
clear_pause,
get_pipeline,
init_pause,
is_paused,
pause_pipeline,
resume_pipeline,
set_cancel_check,
set_pause_after_stage,
step_pipeline,
)
from .events import _node_states
__all__ = [
"NODES",
"NODE_FUNCTIONS",
"PipelineCancelled",
"PipelineRunner",
"build_graph",
"get_pipeline",
"set_cancel_check",
"clear_cancel_check",
"init_pause",
"clear_pause",
"pause_pipeline",
"resume_pipeline",
"step_pipeline",
"set_pause_after_stage",
"is_paused",
"_node_states",
]

View File

@@ -0,0 +1,27 @@
"""
Graph event emission — node state tracking + SSE graph_update events.
"""
from __future__ import annotations
from core.detect import emit
from core.detect.state import DetectState
# Track node states across pipeline runs
_node_states: dict[str, dict[str, str]] = {}
def emit_transition(state: DetectState, node: str, status: str, node_list: list[str]):
"""Update node status and emit graph_update SSE event."""
job_id = state.get("job_id")
if not job_id:
return
if job_id not in _node_states:
_node_states[job_id] = {n: "pending" for n in node_list}
_node_states[job_id][node] = status
nodes = [{"id": n, "status": _node_states[job_id][n]} for n in node_list]
emit.graph_update(job_id, nodes)

366
core/detect/graph/nodes.py Normal file
View File

@@ -0,0 +1,366 @@
"""
Pipeline node functions — one per stage.
Each node: reads state, gets config from profile dict, runs stage logic,
emits transitions, returns output dict.
"""
from __future__ import annotations
import os
from core.detect import emit
from core.detect.models import CropContext, PipelineStats
from core.detect.profile import get_profile, get_stage_config, build_vlm_prompt, aggregate_detections
from core.detect.stages.models import (
DetectionConfig,
FieldSegmentationConfig,
FrameExtractionConfig,
OCRConfig,
RegionAnalysisConfig,
ResolverConfig,
SceneFilterConfig,
)
from core.detect.state import DetectState
from core.detect.stages.frame_extractor import extract_frames
from core.detect.stages.scene_filter import scene_filter
from core.detect.stages.field_segmentation import run_field_segmentation
from core.detect.stages.edge_detector import detect_edge_regions
from core.detect.stages.yolo_detector import detect_objects
from core.detect.stages.preprocess import preprocess_regions
from core.detect.stages.ocr_stage import run_ocr
from core.detect.stages.brand_resolver import resolve_brands
from core.detect.stages.vlm_local import escalate_vlm
from core.detect.stages.vlm_cloud import escalate_cloud
from core.detect.stages.aggregator import compile_report
from core.detect.tracing import trace_node, flush as flush_traces
from .events import emit_transition
INFERENCE_URL = os.environ.get("INFERENCE_URL") # None = local mode
NODES = [
"extract_frames",
"filter_scenes",
"field_segmentation",
"detect_edges",
"detect_objects",
"preprocess",
"run_ocr",
"match_brands",
"escalate_vlm",
"escalate_cloud",
"compile_report",
]
def _load_profile(state: DetectState) -> dict:
"""Load profile dict, apply config overrides if present."""
name = state.get("profile_name", "soccer_broadcast")
profile = get_profile(name)
overrides = state.get("config_overrides")
if overrides:
# Merge overrides into a copy of the profile configs
merged_configs = dict(profile.get("configs", {}))
for stage_name, stage_overrides in overrides.items():
if stage_name in merged_configs:
merged_configs[stage_name] = {**merged_configs[stage_name], **stage_overrides}
else:
merged_configs[stage_name] = stage_overrides
profile = {**profile, "configs": merged_configs}
return profile
def _emit(state, node, status):
emit_transition(state, node, status, NODES)
# --- Node functions ---
def node_extract_frames(state: DetectState) -> dict:
job_id = state.get("job_id", "")
if job_id and not emit._run_context:
emit.set_run_context(run_id=job_id, parent_job_id=job_id, run_type="initial")
source_asset_id = state.get("source_asset_id")
if source_asset_id and not state.get("session_brands"):
from core.detect.stages.brand_resolver import build_session_dict
session_brands = build_session_dict(source_asset_id)
state["session_brands"] = session_brands
_emit(state, "extract_frames", "running")
with trace_node(state, "extract_frames") as span:
profile = _load_profile(state)
config = FrameExtractionConfig(**get_stage_config(profile, "extract_frames"))
frames = extract_frames(state["video_path"], config, job_id=job_id)
span.set_output({"frames_extracted": len(frames)})
_emit(state, "extract_frames", "done")
return {"frames": frames, "stats": PipelineStats(frames_extracted=len(frames))}
def node_filter_scenes(state: DetectState) -> dict:
_emit(state, "filter_scenes", "running")
with trace_node(state, "filter_scenes") as span:
profile = _load_profile(state)
config = SceneFilterConfig(**get_stage_config(profile, "filter_scenes"))
frames = state.get("frames", [])
kept = scene_filter(frames, config, job_id=state.get("job_id"))
span.set_output({"frames_in": len(frames), "frames_kept": len(kept)})
stats = state.get("stats", PipelineStats())
stats.frames_after_scene_filter = len(kept)
_emit(state, "filter_scenes", "done")
return {"filtered_frames": kept, "stats": stats}
def node_field_segmentation(state: DetectState) -> dict:
_emit(state, "field_segmentation", "running")
with trace_node(state, "field_segmentation") as span:
profile = _load_profile(state)
config = FieldSegmentationConfig(**get_stage_config(profile, "field_segmentation"))
frames = state.get("filtered_frames", [])
job_id = state.get("job_id")
result = run_field_segmentation(frames, config, inference_url=INFERENCE_URL, job_id=job_id)
span.set_output({
"frames": len(frames),
"avg_coverage": sum(result["field_coverage"].values()) / max(len(result["field_coverage"]), 1),
})
_emit(state, "field_segmentation", "done")
return {
"field_masks": result["field_masks"],
"field_boundaries": result["field_boundaries"],
"field_coverage": result["field_coverage"],
}
def node_detect_edges(state: DetectState) -> dict:
_emit(state, "detect_edges", "running")
with trace_node(state, "detect_edges") as span:
profile = _load_profile(state)
config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
frames = state.get("filtered_frames", [])
field_masks = state.get("field_masks", {})
job_id = state.get("job_id")
regions = detect_edge_regions(
frames, config, inference_url=INFERENCE_URL, job_id=job_id,
field_masks=field_masks,
)
total = sum(len(r) for r in regions.values())
span.set_output({"frames": len(frames), "edge_regions": total})
stats = state.get("stats", PipelineStats())
stats.cv_regions_detected = total
_emit(state, "detect_edges", "done")
return {"edge_regions_by_frame": regions, "stats": stats}
def node_detect_objects(state: DetectState) -> dict:
_emit(state, "detect_objects", "running")
with trace_node(state, "detect_objects") as span:
profile = _load_profile(state)
config = DetectionConfig(**get_stage_config(profile, "detect_objects"))
frames = state.get("filtered_frames", [])
job_id = state.get("job_id")
all_boxes = detect_objects(frames, config, inference_url=INFERENCE_URL, job_id=job_id)
total_regions = sum(len(boxes) for boxes in all_boxes.values())
span.set_output({"frames": len(frames), "regions_detected": total_regions})
stats = state.get("stats", PipelineStats())
stats.regions_detected = total_regions
_emit(state, "detect_objects", "done")
return {"boxes_by_frame": all_boxes, "stats": stats}
def node_preprocess(state: DetectState) -> dict:
_emit(state, "preprocess", "running")
with trace_node(state, "preprocess") as span:
profile = _load_profile(state)
prep_config = get_stage_config(profile, "preprocess")
frames = state.get("filtered_frames", [])
boxes = state.get("boxes_by_frame", {})
job_id = state.get("job_id")
do_contrast = prep_config.get("contrast", True)
do_deskew = prep_config.get("deskew", False)
do_binarize = prep_config.get("binarize", False)
preprocessed = preprocess_regions(
frames, boxes,
do_contrast=do_contrast,
do_deskew=do_deskew,
do_binarize=do_binarize,
inference_url=INFERENCE_URL,
job_id=job_id,
)
span.set_output({"regions_preprocessed": len(preprocessed)})
_emit(state, "preprocess", "done")
return {"preprocessed_crops": preprocessed}
def node_run_ocr(state: DetectState) -> dict:
_emit(state, "run_ocr", "running")
with trace_node(state, "run_ocr") as span:
profile = _load_profile(state)
config = OCRConfig(**get_stage_config(profile, "run_ocr"))
frames = state.get("filtered_frames", [])
boxes = state.get("boxes_by_frame", {})
job_id = state.get("job_id")
candidates = run_ocr(frames, boxes, config, inference_url=INFERENCE_URL, job_id=job_id)
span.set_output({"regions_in": sum(len(b) for b in boxes.values()), "text_candidates": len(candidates)})
stats = state.get("stats", PipelineStats())
stats.regions_resolved_by_ocr = len(candidates)
_emit(state, "run_ocr", "done")
return {"text_candidates": candidates, "stats": stats}
def node_match_brands(state: DetectState) -> dict:
_emit(state, "match_brands", "running")
with trace_node(state, "match_brands") as span:
profile = _load_profile(state)
config = ResolverConfig(**get_stage_config(profile, "match_brands"))
candidates = state.get("text_candidates", [])
session_brands = state.get("session_brands", {})
job_id = state.get("job_id")
source_asset_id = state.get("source_asset_id")
matched, unresolved = resolve_brands(
candidates, config,
session_brands=session_brands,
source_asset_id=source_asset_id,
content_type=profile["name"], job_id=job_id,
)
span.set_output({"matched": len(matched), "unresolved": len(unresolved)})
_emit(state, "match_brands", "done")
return {"detections": matched, "unresolved_candidates": unresolved}
def node_escalate_vlm(state: DetectState) -> dict:
_emit(state, "escalate_vlm", "running")
with trace_node(state, "escalate_vlm") as span:
profile = _load_profile(state)
vlm_config = get_stage_config(profile, "escalate_vlm")
vlm_template = vlm_config.get("vlm_prompt_template", "Identify the brand in this image.")
candidates = state.get("unresolved_candidates", [])
job_id = state.get("job_id")
vlm_prompt_fn = lambda ctx: build_vlm_prompt(ctx, vlm_template)
vlm_matched, still_unresolved = escalate_vlm(
candidates,
vlm_prompt_fn=vlm_prompt_fn,
inference_url=INFERENCE_URL,
content_type=profile["name"],
source_asset_id=state.get("source_asset_id"),
job_id=job_id,
)
stats = state.get("stats", PipelineStats())
stats.regions_escalated_to_local_vlm = len(candidates)
span.set_output({"candidates": len(candidates), "matched": len(vlm_matched),
"still_unresolved": len(still_unresolved)})
existing = state.get("detections", [])
vlm_skipped = os.environ.get("SKIP_VLM", "").strip() == "1"
_emit(state, "escalate_vlm", "skipped" if vlm_skipped else "done")
return {
"detections": existing + vlm_matched,
"unresolved_candidates": still_unresolved,
"stats": stats,
}
def node_escalate_cloud(state: DetectState) -> dict:
_emit(state, "escalate_cloud", "running")
with trace_node(state, "escalate_cloud") as span:
profile = _load_profile(state)
vlm_config = get_stage_config(profile, "escalate_vlm")
vlm_template = vlm_config.get("vlm_prompt_template", "Identify the brand in this image.")
candidates = state.get("unresolved_candidates", [])
job_id = state.get("job_id")
stats = state.get("stats", PipelineStats())
vlm_prompt_fn = lambda ctx: build_vlm_prompt(ctx, vlm_template)
cloud_matched = escalate_cloud(
candidates,
vlm_prompt_fn=vlm_prompt_fn,
stats=stats,
content_type=profile["name"],
source_asset_id=state.get("source_asset_id"),
job_id=job_id,
)
span.set_output({"candidates": len(candidates), "matched": len(cloud_matched),
"cloud_calls": stats.cloud_llm_calls,
"cost_usd": stats.estimated_cloud_cost_usd})
existing = state.get("detections", [])
cloud_skipped = os.environ.get("SKIP_CLOUD", "").strip() == "1"
_emit(state, "escalate_cloud", "skipped" if cloud_skipped else "done")
return {"detections": existing + cloud_matched, "stats": stats}
def node_compile_report(state: DetectState) -> dict:
_emit(state, "compile_report", "running")
with trace_node(state, "compile_report") as span:
profile = _load_profile(state)
detections = state.get("detections", [])
stats = state.get("stats", PipelineStats())
job_id = state.get("job_id")
report = compile_report(
detections=detections,
stats=stats,
video_source=state.get("video_path", ""),
content_type=profile["name"],
job_id=job_id,
)
span.set_output({"brands": len(report.brands), "detections": len(report.timeline)})
flush_traces()
_emit(state, "compile_report", "done")
return {"report": report}
NODE_FUNCTIONS = [
("extract_frames", node_extract_frames),
("filter_scenes", node_filter_scenes),
("field_segmentation", node_field_segmentation),
("detect_edges", node_detect_edges),
("detect_objects", node_detect_objects),
("preprocess", node_preprocess),
("run_ocr", node_run_ocr),
("match_brands", node_match_brands),
("escalate_vlm", node_escalate_vlm),
("escalate_cloud", node_escalate_cloud),
("compile_report", node_compile_report),
]
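A hypothetical node following the same shape (the stage name, config block, and logic are invented for illustration; a real stage would also need entries in NODES, NODE_FUNCTIONS, and the profile fixture):

def node_detect_logos(state: DetectState) -> dict:
    _emit(state, "detect_logos", "running")
    with trace_node(state, "detect_logos") as span:
        profile = _load_profile(state)
        config = get_stage_config(profile, "detect_logos")   # assumes the profile defines this block
        frames = state.get("filtered_frames", [])
        logos_by_frame = {f.sequence: [] for f in frames}    # placeholder stage logic
        span.set_output({"frames": len(frames)})
    _emit(state, "detect_logos", "done")
    return {"logos_by_frame": logos_by_frame}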

274
core/detect/graph/runner.py Normal file
View File

@@ -0,0 +1,274 @@
"""
Pipeline runner — executes stages sequentially with checkpointing,
cancellation, and pause/resume.
Reads PipelineConfig from the profile to determine what stages to run.
Flattens the graph into a linear sequence for now (serial execution).
Executor socket: all stages currently run in-process (their node functions are called directly).
"""
from __future__ import annotations
import logging
import os
import threading
from core.detect.stages.models import PipelineConfig
from core.detect.state import DetectState
from .nodes import NODES, NODE_FUNCTIONS
logger = logging.getLogger(__name__)
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
class PipelineCancelled(Exception):
"""Raised when a pipeline run is cancelled."""
pass
class PipelinePaused(Exception):
"""Raised when a pipeline is paused (internally, for flow control)."""
pass
# ---------------------------------------------------------------------------
# Cancellation — checked before each node
# ---------------------------------------------------------------------------
_cancel_check: dict[str, callable] = {}
def set_cancel_check(job_id: str, fn):
_cancel_check[job_id] = fn
def clear_cancel_check(job_id: str):
_cancel_check.pop(job_id, None)
# ---------------------------------------------------------------------------
# Pause / Resume / Step — checked after each node completes
# ---------------------------------------------------------------------------
_pause_gate: dict[str, threading.Event] = {}
_pause_after_stage: dict[str, bool] = {}
def init_pause(job_id: str, pause_after_stage: bool = False):
"""Initialize pause state for a job. Called when pipeline starts."""
gate = threading.Event()
gate.set() # start unpaused
_pause_gate[job_id] = gate
_pause_after_stage[job_id] = pause_after_stage
def clear_pause(job_id: str):
"""Clean up pause state. Called when pipeline finishes."""
_pause_gate.pop(job_id, None)
_pause_after_stage.pop(job_id, None)
def pause_pipeline(job_id: str):
"""Pause a running pipeline. It will block after the current stage completes."""
gate = _pause_gate.get(job_id)
if gate:
gate.clear()
logger.info("Pipeline %s paused", job_id)
def resume_pipeline(job_id: str):
"""Resume a paused pipeline."""
gate = _pause_gate.get(job_id)
if gate:
gate.set()
logger.info("Pipeline %s resumed", job_id)
def step_pipeline(job_id: str):
"""Run one stage then pause again."""
_pause_after_stage[job_id] = True
gate = _pause_gate.get(job_id)
if gate:
gate.set()
logger.info("Pipeline %s stepping", job_id)
def set_pause_after_stage(job_id: str, enabled: bool):
"""Toggle pause-after-each-stage mode."""
_pause_after_stage[job_id] = enabled
if not enabled:
gate = _pause_gate.get(job_id)
if gate:
gate.set()
def is_paused(job_id: str) -> bool:
"""Check if a pipeline is currently paused."""
gate = _pause_gate.get(job_id)
return gate is not None and not gate.is_set()
def _wait_if_paused(job_id: str, node_name: str):
"""Block until resumed. Called after each node completes."""
gate = _pause_gate.get(job_id)
if gate is None:
return
if _pause_after_stage.get(job_id, False):
gate.clear()
from core.detect import emit
emit.log(job_id, "Pipeline", "INFO", f"Paused after {node_name}")
while not gate.wait(timeout=0.5):
check = _cancel_check.get(job_id)
if check and check():
raise PipelineCancelled(f"Cancelled while paused before next stage")
# ---------------------------------------------------------------------------
# Pipeline Runner
# ---------------------------------------------------------------------------
# Node function lookup — maps stage name to callable
_NODE_FN_MAP: dict[str, callable] = {name: fn for name, fn in NODE_FUNCTIONS}
def _flatten_config(config: PipelineConfig, start_from: str | None = None) -> list[str]:
"""
Flatten a PipelineConfig into a linear stage sequence.
For now: topological sort via edges. Falls back to stage order if no edges.
Respects start_from for replay (skip stages before it).
"""
if not config.edges:
# No edges defined — use stage order as-is
names = [s.name for s in config.stages]
else:
# Topological sort from edges
graph: dict[str, list[str]] = {}
in_degree: dict[str, int] = {}
stage_names = {s.name for s in config.stages}
for name in stage_names:
graph[name] = []
in_degree[name] = 0
for edge in config.edges:
if edge.source in stage_names and edge.target in stage_names:
graph[edge.source].append(edge.target)
in_degree[edge.target] = in_degree.get(edge.target, 0) + 1
# Kahn's algorithm
queue = [n for n in stage_names if in_degree.get(n, 0) == 0]
# Stable sort: prefer order from config.stages
stage_order = {s.name: i for i, s in enumerate(config.stages)}
queue.sort(key=lambda n: stage_order.get(n, 999))
names = []
while queue:
node = queue.pop(0)
names.append(node)
for neighbor in graph.get(node, []):
in_degree[neighbor] -= 1
if in_degree[neighbor] == 0:
queue.append(neighbor)
queue.sort(key=lambda n: stage_order.get(n, 999))
if start_from:
try:
idx = names.index(start_from)
names = names[idx:]
except ValueError:
raise ValueError(f"Stage {start_from!r} not in pipeline config")
return names
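# Hedged worked example (illustrative stage names, not a real profile). Ties in
# Kahn's algorithm are broken by declaration order in config.stages, so "b"
# sorts ahead of "c" even though both become ready at the same time.
# StageRef and Edge are assumed to live in core.detect.stages.models.
from core.detect.stages.models import Edge, StageRef

_EXAMPLE_CONFIG = PipelineConfig(
    name="example",
    profile_name="example",
    stages=[StageRef(name="a"), StageRef(name="b"), StageRef(name="c")],
    edges=[Edge(source="a", target="c"), Edge(source="a", target="b")],
    routing_rules={},
)
# _flatten_config(_EXAMPLE_CONFIG)                  -> ["a", "b", "c"]
# _flatten_config(_EXAMPLE_CONFIG, start_from="b")  -> ["b", "c"]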
class PipelineRunner:
"""
Executes a pipeline defined by PipelineConfig.
Runs stages sequentially (flattened). Each stage:
1. Check cancel
2. Run node function (via executor — local for now)
3. Merge result into state
4. Checkpoint (if enabled)
5. Check pause
Executor socket: currently calls node functions directly.
Future: dispatch to LocalExecutor / GrpcExecutor / LambdaExecutor
based on StageRef.execution_target.
"""
def __init__(
self,
config: PipelineConfig,
checkpoint: bool = False,
start_from: str | None = None,
):
self.config = config
self.do_checkpoint = checkpoint
self.stage_sequence = _flatten_config(config, start_from)
def invoke(self, state: DetectState) -> DetectState:
"""Run the pipeline on the given state. Returns final state."""
for stage_name in self.stage_sequence:
job_id = state.get("job_id", "")
# 1. Cancel check
check = _cancel_check.get(job_id)
if check and check():
raise PipelineCancelled(f"Cancelled before {stage_name}")
# 2. Run node function
node_fn = _NODE_FN_MAP.get(stage_name)
if node_fn is None:
logger.warning("No node function for stage %s, skipping", stage_name)
continue
result = node_fn(state)
# 3. Merge result into state
state.update(result)
# 4. Checkpoint
if self.do_checkpoint:
from core.detect.checkpoint import checkpoint_after_stage
checkpoint_after_stage(job_id, stage_name, state, result)
# 5. Pause check
_wait_if_paused(job_id, stage_name)
return state
# ---------------------------------------------------------------------------
# Public API — backwards compatible with old get_pipeline/build_graph
# ---------------------------------------------------------------------------
def get_pipeline(
checkpoint: bool | None = None,
profile_name: str = "soccer_broadcast",
start_from: str | None = None,
) -> PipelineRunner:
"""Return a PipelineRunner for the given profile."""
from core.detect.profile import get_profile, pipeline_config_from_dict
do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
profile = get_profile(profile_name)
config = pipeline_config_from_dict(profile["pipeline"])
return PipelineRunner(
config=config,
checkpoint=do_checkpoint,
start_from=start_from,
)
def build_graph(checkpoint: bool | None = None, start_from: str | None = None):
"""Backwards-compatible wrapper. Returns a PipelineRunner."""
return get_pipeline(checkpoint=checkpoint, start_from=start_from)
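# Hedged usage sketch: the profile name matches the default above, but the state
# keys other than "job_id" are assumptions; DetectState's real shape is defined
# in core.detect.models. Needs the profile table in Postgres to be reachable.
if __name__ == "__main__":
    runner = get_pipeline(profile_name="soccer_broadcast", checkpoint=False)
    state = {"job_id": "local-debug", "video_path": "/tmp/sample.mp4"}  # illustrative keys
    final_state = runner.invoke(state)
    print(list(final_state))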

View File

@@ -0,0 +1,4 @@
from .client import InferenceClient
from .types import DetectResult, OCRResult, VLMResult
__all__ = ["InferenceClient", "DetectResult", "OCRResult", "VLMResult"]

View File

@@ -0,0 +1,262 @@
"""
HTTP client for the inference server.
The pipeline stages call this instead of importing ML libraries directly.
The inference server runs on the GPU machine (or spot instance).
"""
from __future__ import annotations
import base64
import io
import logging
import os
import numpy as np
import requests
from PIL import Image
from .types import DetectResult, OCRResult, RegionDebugResult, RegionResult, ServerStatus, VLMResult
logger = logging.getLogger(__name__)
DEFAULT_URL = os.environ.get("INFERENCE_URL", "http://localhost:8000")
def _encode_image(image: np.ndarray) -> str:
"""Encode numpy array as base64 JPEG."""
img = Image.fromarray(image)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=85)
return base64.b64encode(buf.getvalue()).decode()
class InferenceClient:
"""HTTP client for the GPU inference server."""
def __init__(self, base_url: str | None = None, timeout: float = 60.0,
job_id: str = "", log_level: str = "INFO"):
self.base_url = (base_url or DEFAULT_URL).rstrip("/")
self.timeout = timeout
self.job_id = job_id
self.log_level = log_level
self.session = requests.Session()
if job_id:
self.session.headers["X-Job-Id"] = job_id
self.session.headers["X-Log-Level"] = log_level
def health(self) -> ServerStatus:
"""Check server health and loaded models."""
resp = self.session.get(f"{self.base_url}/health", timeout=self.timeout)
resp.raise_for_status()
data = resp.json()
return ServerStatus(
loaded_models=data.get("loaded_models", []),
vram_used_mb=data.get("vram_used_mb", 0),
vram_budget_mb=data.get("vram_budget_mb", 0),
strategy=data.get("strategy", "sequential"),
)
def detect(
self,
image: np.ndarray,
model: str = "yolov8n",
confidence: float = 0.3,
target_classes: list[str] | None = None,
) -> list[DetectResult]:
"""Run object detection on an image."""
payload = {
"image": _encode_image(image),
"model": model,
"confidence": confidence,
}
if target_classes:
payload["target_classes"] = target_classes
resp = self.session.post(
f"{self.base_url}/detect",
json=payload,
timeout=self.timeout,
)
resp.raise_for_status()
results = []
for d in resp.json().get("detections", []):
result = DetectResult(
x=d["x"], y=d["y"], w=d["w"], h=d["h"],
confidence=d["confidence"], label=d["label"],
)
results.append(result)
return results
def ocr(
self,
image: np.ndarray,
languages: list[str] | None = None,
) -> list[OCRResult]:
"""Run OCR on an image region."""
payload = {
"image": _encode_image(image),
}
if languages:
payload["languages"] = languages
resp = self.session.post(
f"{self.base_url}/ocr",
json=payload,
timeout=self.timeout,
)
resp.raise_for_status()
results = []
for d in resp.json().get("results", []):
result = OCRResult(
text=d["text"],
confidence=d["confidence"],
bbox=tuple(d["bbox"]),
)
results.append(result)
return results
def vlm(
self,
image: np.ndarray,
prompt: str,
model: str = "moondream2",
) -> VLMResult:
"""Query a visual language model with an image crop + prompt."""
payload = {
"image": _encode_image(image),
"prompt": prompt,
"model": model,
}
resp = self.session.post(
f"{self.base_url}/vlm",
json=payload,
timeout=self.timeout,
)
resp.raise_for_status()
data = resp.json()
return VLMResult(
brand=data.get("brand", ""),
confidence=data.get("confidence", 0.0),
reasoning=data.get("reasoning", ""),
)
def detect_edges(
self,
image: np.ndarray,
edge_canny_low: int = 50,
edge_canny_high: int = 150,
edge_hough_threshold: int = 80,
edge_hough_min_length: int = 100,
edge_hough_max_gap: int = 10,
edge_pair_max_distance: int = 200,
edge_pair_min_distance: int = 15,
) -> list[RegionResult]:
"""Run edge detection on an image."""
payload = {
"image": _encode_image(image),
"edge_canny_low": edge_canny_low,
"edge_canny_high": edge_canny_high,
"edge_hough_threshold": edge_hough_threshold,
"edge_hough_min_length": edge_hough_min_length,
"edge_hough_max_gap": edge_hough_max_gap,
"edge_pair_max_distance": edge_pair_max_distance,
"edge_pair_min_distance": edge_pair_min_distance,
}
resp = self.session.post(
f"{self.base_url}/detect_edges",
json=payload,
timeout=self.timeout,
)
resp.raise_for_status()
results = []
for r in resp.json().get("regions", []):
result = RegionResult(
x=r["x"], y=r["y"], w=r["w"], h=r["h"],
confidence=r["confidence"], label=r["label"],
)
results.append(result)
return results
def detect_edges_debug(
self,
image: np.ndarray,
edge_canny_low: int = 50,
edge_canny_high: int = 150,
edge_hough_threshold: int = 80,
edge_hough_min_length: int = 100,
edge_hough_max_gap: int = 10,
edge_pair_max_distance: int = 200,
edge_pair_min_distance: int = 15,
) -> RegionDebugResult:
"""Run edge detection with debug overlays."""
payload = {
"image": _encode_image(image),
"edge_canny_low": edge_canny_low,
"edge_canny_high": edge_canny_high,
"edge_hough_threshold": edge_hough_threshold,
"edge_hough_min_length": edge_hough_min_length,
"edge_hough_max_gap": edge_hough_max_gap,
"edge_pair_max_distance": edge_pair_max_distance,
"edge_pair_min_distance": edge_pair_min_distance,
}
resp = self.session.post(
f"{self.base_url}/detect_edges/debug",
json=payload,
timeout=self.timeout,
)
resp.raise_for_status()
data = resp.json()
regions = []
for r in data.get("regions", []):
region = RegionResult(
x=r["x"], y=r["y"], w=r["w"], h=r["h"],
confidence=r["confidence"], label=r["label"],
)
regions.append(region)
return RegionDebugResult(
regions=regions,
edge_overlay_b64=data.get("edge_overlay_b64", ""),
lines_overlay_b64=data.get("lines_overlay_b64", ""),
horizontal_count=data.get("horizontal_count", 0),
pair_count=data.get("pair_count", 0),
)
def post(self, path: str, payload: dict) -> dict | None:
"""Generic POST to the inference server. Returns JSON response or None on error."""
try:
resp = self.session.post(
f"{self.base_url}{path}",
json=payload,
timeout=self.timeout,
)
resp.raise_for_status()
return resp.json()
except Exception as e:
logger.warning("Inference POST %s failed: %s", path, e)
return None
def load_model(self, model: str, quantization: str = "fp16") -> None:
"""Request the server to load a model into VRAM."""
self.session.post(
f"{self.base_url}/models/load",
json={"model": model, "quantization": quantization},
timeout=self.timeout,
).raise_for_status()
def unload_model(self, model: str) -> None:
"""Request the server to unload a model from VRAM."""
self.session.post(
f"{self.base_url}/models/unload",
json={"model": model},
timeout=self.timeout,
).raise_for_status()
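# Hedged usage sketch: exercises the client against a locally running inference
# server. The URL and the synthetic black frame are illustrative.
if __name__ == "__main__":
    client = InferenceClient(base_url="http://localhost:8000", job_id="debug")
    frame = np.zeros((720, 1280, 3), dtype=np.uint8)  # stand-in for a decoded video frame
    print(client.health())
    print(client.detect(frame, model="yolov8n", confidence=0.3))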

View File

@@ -0,0 +1,76 @@
"""
Inference response types.
These are the shapes returned by the inference server.
Kept separate from core.detect.models to avoid coupling the
inference protocol to pipeline internals.
"""
from __future__ import annotations
from dataclasses import dataclass, field
@dataclass
class DetectResult:
"""Single object detection from YOLO or similar."""
x: int
y: int
w: int
h: int
confidence: float
label: str
@dataclass
class OCRResult:
"""Text extracted from a region."""
text: str
confidence: float
bbox: tuple[int, int, int, int] # x, y, w, h
@dataclass
class VLMResult:
"""Visual language model response for a crop."""
brand: str
confidence: float
reasoning: str
@dataclass
class RegionResult:
"""A candidate region from CV analysis."""
x: int
y: int
w: int
h: int
confidence: float
label: str
@dataclass
class RegionDebugResult:
"""CV region analysis with debug overlays."""
regions: list[RegionResult] = field(default_factory=list)
edge_overlay_b64: str = ""
lines_overlay_b64: str = ""
horizontal_count: int = 0
pair_count: int = 0
@dataclass
class ModelInfo:
"""Info about a loaded model."""
name: str
vram_mb: float
quantization: str # fp32, fp16, int8, int4
@dataclass
class ServerStatus:
"""Inference server health response."""
loaded_models: list[ModelInfo] = field(default_factory=list)
vram_used_mb: float = 0.0
vram_budget_mb: float = 0.0
strategy: str = "sequential" # sequential, concurrent, auto

View File

@@ -2,8 +2,8 @@
Detection pipeline runtime models.
These are the data structures that flow between pipeline stages.
They contain runtime types (np.ndarray), so modelgen skips them
(not generated to SQLModel or TypeScript).
They contain runtime types (np.ndarray) so they live here, not in
core/schema/models/ (which is for modelgen source of truth).
"""
from __future__ import annotations
@@ -85,3 +85,11 @@ class DetectionReport:
brands: dict[str, BrandStats] = field(default_factory=dict)
timeline: list[BrandDetection] = field(default_factory=list)
pipeline_stats: PipelineStats = field(default_factory=PipelineStats)
@dataclass
class CropContext:
"""Runtime type — holds image bytes for VLM prompts."""
image: bytes
surrounding_text: str = ""
position_hint: str = ""

107
core/detect/profile.py Normal file
View File

@@ -0,0 +1,107 @@
"""
Profile registry and helpers.
Loads profile data from Postgres.
A profile is a dict with keys: name, pipeline, configs.
"""
from __future__ import annotations
import logging
from typing import Any, Dict
from core.detect.stages.models import PipelineConfig, StageRef, Edge
from core.detect.models import (
BrandDetection,
BrandStats,
CropContext,
DetectionReport,
PipelineStats,
)
logger = logging.getLogger(__name__)
def get_profile(name: str) -> Dict[str, Any]:
"""Get a profile dict by name from the database."""
from core.db.connection import get_session
from core.db.models import Profile
with get_session() as session:
row = session.query(Profile).filter(Profile.name == name).first()
if row is None:
raise ValueError(f"Unknown profile: {name!r}")
return {
"name": row.name,
"pipeline": row.pipeline or {},
"configs": row.configs or {},
}
def list_profiles() -> list[str]:
"""List available profile names from the database."""
from core.db.connection import get_session
from core.db.models import Profile
with get_session() as session:
rows = session.query(Profile.name).all()
return [r[0] for r in rows]
def get_stage_config(profile: Dict[str, Any], stage_name: str) -> dict:
"""Get config values for a stage from a profile."""
return profile.get("configs", {}).get(stage_name, {})
def pipeline_config_from_dict(data: Dict[str, Any]) -> PipelineConfig:
"""Deserialize a PipelineConfig from a JSONB dict."""
stages = [StageRef(**s) for s in data.get("stages", [])]
edges = [Edge(**e) for e in data.get("edges", [])]
return PipelineConfig(
name=data.get("name", ""),
profile_name=data.get("profile_name", ""),
stages=stages,
edges=edges,
routing_rules=data.get("routing_rules", {}),
)
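# Hedged sketch of the JSONB shape stored in Profile.pipeline; the stage and
# edge names are illustrative only.
_EXAMPLE_PIPELINE_JSON = {
    "name": "minimal",
    "profile_name": "example",
    "stages": [{"name": "extract_frames"}, {"name": "detect_edges"}],
    "edges": [{"source": "extract_frames", "target": "detect_edges"}],
}
# pipeline_config_from_dict(_EXAMPLE_PIPELINE_JSON) yields a PipelineConfig with
# two StageRefs and one Edge; keys missing from the dict fall back to the
# defaults used above.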
def build_vlm_prompt(crop_context: CropContext, template: str) -> str:
"""Build a VLM prompt from a template and crop context."""
hint = f" Position: {crop_context.position_hint}." if crop_context.position_hint else ""
text = f" Nearby text: '{crop_context.surrounding_text}'." if crop_context.surrounding_text else ""
return template.format(hint=hint, text=text)
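# Hedged example: the template wording is illustrative, but any template must
# contain the {hint} and {text} placeholders that build_vlm_prompt fills in.
_EXAMPLE_TEMPLATE = (
    "What brand is shown in this crop?{hint}{text} Answer with the brand name only."
)
# build_vlm_prompt(CropContext(image=b"", surrounding_text="FLY EMIRATES"), _EXAMPLE_TEMPLATE)
# -> "What brand is shown in this crop? Nearby text: 'FLY EMIRATES'. Answer with the brand name only."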
def aggregate_detections(
detections: list[BrandDetection],
content_type: str,
) -> DetectionReport:
"""Group detections by brand into a report."""
brands: dict[str, BrandStats] = {}
for d in detections:
if d.brand not in brands:
brands[d.brand] = BrandStats()
s = brands[d.brand]
s.total_appearances += 1
s.total_screen_time += d.duration
s.avg_confidence = (
(s.avg_confidence * (s.total_appearances - 1) + d.confidence)
/ s.total_appearances
)
if s.first_seen == 0.0 or d.timestamp < s.first_seen:
s.first_seen = d.timestamp
if d.timestamp > s.last_seen:
s.last_seen = d.timestamp
return DetectionReport(
video_source="",
content_type=content_type,
duration_seconds=0.0,
brands=brands,
timeline=sorted(detections, key=lambda d: d.timestamp),
pipeline_stats=PipelineStats(),
)

View File

@@ -0,0 +1,58 @@
"""
Cloud LLM provider registry.
Select provider via CLOUD_LLM_PROVIDER env var.
Each provider reads its own env vars for auth/config.
CLOUD_LLM_PROVIDER=groq → GROQ_API_KEY, GROQ_MODEL, GROQ_BASE_URL
CLOUD_LLM_PROVIDER=gemini → GEMINI_API_KEY, GEMINI_MODEL
CLOUD_LLM_PROVIDER=openai → OPENAI_API_KEY, OPENAI_MODEL, OPENAI_BASE_URL
CLOUD_LLM_PROVIDER=claude → ANTHROPIC_API_KEY, CLAUDE_MODEL
"""
from __future__ import annotations
import os
from .base import CloudProvider, ProviderResponse
from .groq import GroqProvider
from .gemini import GeminiProvider
from .openai_compat import OpenAICompatProvider
from .claude import ClaudeProvider
PROVIDERS: dict[str, type] = {
"groq": GroqProvider,
"gemini": GeminiProvider,
"openai": OpenAICompatProvider,
"claude": ClaudeProvider,
}
_cached: CloudProvider | None = None
def get_provider() -> CloudProvider:
"""Get the configured cloud provider (cached after first call)."""
global _cached
if _cached is not None:
return _cached
name = os.environ.get("CLOUD_LLM_PROVIDER", "groq")
cls = PROVIDERS.get(name)
if cls is None:
raise ValueError(f"Unknown provider: {name!r}. Options: {list(PROVIDERS)}")
_cached = cls()
return _cached
def has_api_key() -> bool:
"""Check if the configured provider has an API key set."""
name = os.environ.get("CLOUD_LLM_PROVIDER", "groq")
key_map = {
"groq": "GROQ_API_KEY",
"gemini": "GEMINI_API_KEY",
"openai": "OPENAI_API_KEY",
"claude": "ANTHROPIC_API_KEY",
}
env_var = key_map.get(name, "")
return bool(os.environ.get(env_var, ""))
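# Hedged usage sketch: the provider choice and the crop payload are illustrative;
# each provider still needs its own API key in the environment.
if __name__ == "__main__":
    os.environ.setdefault("CLOUD_LLM_PROVIDER", "groq")
    if has_api_key():
        provider = get_provider()
        resp = provider.call(image_b64="<base64 JPEG crop>", prompt="Which brand is on this hoarding?")
        print(provider.name, resp.answer, resp.total_tokens)
    else:
        print("No API key set for provider:", os.environ["CLOUD_LLM_PROVIDER"])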

View File

@@ -0,0 +1,36 @@
"""Cloud LLM provider protocol and model metadata."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol
@dataclass
class ModelInfo:
"""Metadata for a cloud LLM model."""
id: str
vision: bool = True
cost_per_input_token: float = 0.0
cost_per_output_token: float = 0.0
max_output_tokens: int = 4096
notes: str = ""
@dataclass
class ProviderResponse:
answer: str
total_tokens: int = 0
class CloudProvider(Protocol):
"""
Interface for cloud LLM providers.
Each provider handles its own auth, payload format, and response parsing.
The pipeline only calls call() and reads the response.
"""
name: str
models: dict[str, ModelInfo]
def call(self, image_b64: str, prompt: str) -> ProviderResponse: ...

View File

@@ -0,0 +1,73 @@
"""Anthropic Claude provider — uses the official SDK."""
from __future__ import annotations
import logging
import os
from .base import ModelInfo, ProviderResponse
logger = logging.getLogger(__name__)
# Claude-specific env vars
# ANTHROPIC_API_KEY is read by the SDK automatically
CLAUDE_MODEL = os.environ.get("CLAUDE_MODEL", "claude-sonnet-4-20250514")
MODELS = {
"claude-sonnet-4-20250514": ModelInfo(
id="claude-sonnet-4-20250514",
vision=True,
cost_per_input_token=0.000003,
cost_per_output_token=0.000015,
notes="Best balance of quality/cost with vision",
),
"claude-haiku-4-5-20251001": ModelInfo(
id="claude-haiku-4-5-20251001",
vision=True,
cost_per_input_token=0.0000008,
cost_per_output_token=0.000004,
notes="Fastest, cheapest, good for simple brand ID",
),
"claude-opus-4-6": ModelInfo(
id="claude-opus-4-6",
vision=True,
cost_per_input_token=0.000015,
cost_per_output_token=0.000075,
notes="Highest quality, use for ambiguous cases",
),
}
class ClaudeProvider:
name = "claude"
models = MODELS
def __init__(self):
from anthropic import Anthropic
self.client = Anthropic()
self.model = CLAUDE_MODEL
def call(self, image_b64: str, prompt: str) -> ProviderResponse:
message = self.client.messages.create(
model=self.model,
max_tokens=150,
messages=[{
"role": "user",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": image_b64,
},
},
{"type": "text", "text": prompt},
],
}],
)
answer = message.content[0].text.strip()
total_tokens = message.usage.input_tokens + message.usage.output_tokens
return ProviderResponse(answer=answer, total_tokens=total_tokens)

View File

@@ -0,0 +1,75 @@
"""Google Gemini provider — native REST API, not OpenAI-compatible."""
from __future__ import annotations
import logging
import os
import requests
from .base import ModelInfo, ProviderResponse
logger = logging.getLogger(__name__)
# Gemini-specific env vars
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-2.0-flash")
MODELS = {
"gemini-2.0-flash": ModelInfo(
id="gemini-2.0-flash",
vision=True,
cost_per_input_token=0.0000001,
cost_per_output_token=0.0000004,
notes="Fast, cheap, good vision",
),
"gemini-2.0-pro": ModelInfo(
id="gemini-2.0-pro",
vision=True,
cost_per_input_token=0.00000125,
cost_per_output_token=0.000005,
notes="Higher quality, slower",
),
"gemini-1.5-flash": ModelInfo(
id="gemini-1.5-flash",
vision=True,
cost_per_input_token=0.000000075,
cost_per_output_token=0.0000003,
notes="Cheapest option",
),
}
class GeminiProvider:
name = "gemini"
models = MODELS
def __init__(self):
self.api_key = GEMINI_API_KEY
self.model = GEMINI_MODEL
self.endpoint = (
f"https://generativelanguage.googleapis.com/v1beta/models/"
f"{self.model}:generateContent"
)
def call(self, image_b64: str, prompt: str) -> ProviderResponse:
payload = {
"contents": [{
"parts": [
{"text": prompt},
{"inline_data": {"mime_type": "image/jpeg", "data": image_b64}},
],
}],
"generationConfig": {"maxOutputTokens": 150},
}
url = f"{self.endpoint}?key={self.api_key}"
resp = requests.post(url, json=payload, timeout=30)
resp.raise_for_status()
data = resp.json()
answer = data["candidates"][0]["content"]["parts"][0]["text"].strip()
usage = data.get("usageMetadata", {})
total_tokens = usage.get("totalTokenCount", 0)
return ProviderResponse(answer=answer, total_tokens=total_tokens)

View File

@@ -0,0 +1,66 @@
"""Groq cloud provider — OpenAI-compatible API with vision."""
from __future__ import annotations
import logging
import os
import requests
from .base import ModelInfo, ProviderResponse
logger = logging.getLogger(__name__)
# Groq-specific env vars
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
GROQ_BASE_URL = os.environ.get("GROQ_BASE_URL", "https://api.groq.com/openai/v1")
GROQ_MODEL = os.environ.get("GROQ_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct")
MODELS = {
"meta-llama/llama-4-scout-17b-16e-instruct": ModelInfo(
id="meta-llama/llama-4-scout-17b-16e-instruct",
vision=True,
cost_per_input_token=0.0,
cost_per_output_token=0.0,
notes="Llama 4 Scout, only vision model on Groq free tier",
),
}
class GroqProvider:
name = "groq"
models = MODELS
def __init__(self):
self.api_key = GROQ_API_KEY
self.base_url = GROQ_BASE_URL
self.model = GROQ_MODEL
self.endpoint = f"{self.base_url.rstrip('/')}/chat/completions"
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
def call(self, image_b64: str, prompt: str) -> ProviderResponse:
payload = {
"model": self.model,
"messages": [{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {
"url": f"data:image/jpeg;base64,{image_b64}",
}},
],
}],
"max_tokens": 150,
}
resp = requests.post(self.endpoint, headers=self.headers, json=payload, timeout=30)
resp.raise_for_status()
data = resp.json()
answer = data["choices"][0]["message"]["content"].strip()
total_tokens = data.get("usage", {}).get("total_tokens", 0)
return ProviderResponse(answer=answer, total_tokens=total_tokens)

View File

@@ -0,0 +1,73 @@
"""Generic OpenAI-compatible provider (OpenAI, Together, etc.)."""
from __future__ import annotations
import logging
import os
import requests
from .base import ModelInfo, ProviderResponse
logger = logging.getLogger(__name__)
# OpenAI-compat specific env vars
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4o-mini")
MODELS = {
"gpt-4o-mini": ModelInfo(
id="gpt-4o-mini",
vision=True,
cost_per_input_token=0.00000015,
cost_per_output_token=0.0000006,
notes="Cheap, fast, decent vision",
),
"gpt-4o": ModelInfo(
id="gpt-4o",
vision=True,
cost_per_input_token=0.0000025,
cost_per_output_token=0.00001,
notes="Best OpenAI vision model",
),
}
class OpenAICompatProvider:
name = "openai"
models = MODELS
def __init__(self):
self.api_key = OPENAI_API_KEY
self.base_url = OPENAI_BASE_URL
self.model = OPENAI_MODEL
self.endpoint = f"{self.base_url.rstrip('/')}/chat/completions"
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
def call(self, image_b64: str, prompt: str) -> ProviderResponse:
payload = {
"model": self.model,
"messages": [{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {
"url": f"data:image/jpeg;base64,{image_b64}",
}},
],
}],
"max_tokens": 150,
}
resp = requests.post(self.endpoint, headers=self.headers, json=payload, timeout=30)
resp.raise_for_status()
data = resp.json()
answer = data["choices"][0]["message"]["content"].strip()
total_tokens = data.get("usage", {}).get("total_tokens", 0)
return ProviderResponse(answer=answer, total_tokens=total_tokens)

163
core/detect/sse.py Normal file
View File

@@ -0,0 +1,163 @@
"""
Pydantic Models - GENERATED FILE
Do not edit directly. Regenerate using modelgen.
"""
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel, Field
class GraphNode(BaseModel):
"""A pipeline stage node."""
id: str
status: str = "idle"
items_in: int = 0
items_out: int = 0
class GraphEdge(BaseModel):
"""An edge between pipeline stages."""
source: str
target: str
throughput: int = 0
class BoundingBoxEvent(BaseModel):
"""Bounding box in SSE event payloads."""
x: int
y: int
w: int
h: int
confidence: float
label: str
resolved_brand: Optional[str] = None
source: Optional[str] = None
stage: Optional[str] = None
class BrandSummary(BaseModel):
"""Per-brand stats in the final report."""
brand: str
total_appearances: int = 0
total_screen_time: float = 0.0
avg_confidence: float = 0.0
first_seen: float = 0.0
last_seen: float = 0.0
class GraphUpdate(BaseModel):
"""Pipeline node state transition. SSE event: graph_update"""
nodes: List[GraphNode] = Field(default_factory=list)
edges: List[GraphEdge] = Field(default_factory=list)
active_path: List[str] = Field(default_factory=list)
class StatsUpdate(BaseModel):
"""Funnel statistics snapshot. SSE event: stats_update"""
frames_extracted: int = 0
frames_after_scene_filter: int = 0
cv_regions_detected: int = 0
regions_detected: int = 0
regions_resolved_by_ocr: int = 0
regions_escalated_to_local_vlm: int = 0
regions_escalated_to_cloud_llm: int = 0
cloud_llm_calls: int = 0
processing_time_seconds: float = 0.0
estimated_cloud_cost_usd: float = 0.0
run_id: Optional[str] = None
parent_job_id: Optional[str] = None
run_type: str = "initial"
class FrameUpdate(BaseModel):
"""Current frame being processed. SSE event: frame_update"""
frame_ref: int
timestamp: float
jpeg_b64: str
boxes: List[BoundingBoxEvent] = Field(default_factory=list)
class Detection(BaseModel):
"""A confirmed brand detection. SSE event: detection"""
brand: str
timestamp: float
duration: float
confidence: float
source: str
content_type: str
bbox: Optional[BoundingBoxEvent] = None
frame_ref: Optional[int] = None
class LogEvent(BaseModel):
"""Pipeline log line. SSE event: log"""
level: str
stage: str
msg: str
ts: str
trace_id: Optional[str] = None
class DetectionReportSummary(BaseModel):
"""Final detection report summary."""
video_source: str
content_type: str
duration_seconds: float
total_detections: int = 0
brands: List[BrandSummary] = Field(default_factory=list)
stats: Optional[StatsUpdate] = None
class JobComplete(BaseModel):
"""Final report when pipeline finishes. SSE event: job_complete"""
job_id: str
report: Optional[DetectionReportSummary] = None
class RunContext(BaseModel):
"""Run context injected into all SSE events for grouping."""
run_id: str
parent_job_id: str
run_type: str = "initial"
class CheckpointInfo(BaseModel):
"""Available checkpoint for a stage."""
stage: str
is_scenario: bool = False
scenario_label: str = ""
class ReplayRequest(BaseModel):
"""Request to replay pipeline from a specific stage."""
job_id: str
start_stage: str
config_overrides: Optional[Dict[str, Any]] = None
class ReplayResponse(BaseModel):
"""Result of a replay invocation."""
status: str
job_id: str
start_stage: str
detections: int = 0
brands_found: int = 0
class RetryRequest(BaseModel):
"""Request to queue async retry with different config."""
job_id: str
config_overrides: Optional[Dict[str, Any]] = None
start_stage: str = "escalate_vlm"
schedule_seconds: Optional[float] = None
class RetryResponse(BaseModel):
"""Result of queueing a retry task."""
status: str
task_id: str
job_id: str
class RunRequest(BaseModel):
"""Request body for launching a detection pipeline run."""
video_path: str
profile_name: str = "soccer_broadcast"
source_asset_id: str = ""
checkpoint: bool = True
skip_vlm: bool = False
skip_cloud: bool = False
log_level: str = "INFO"
class RunResponse(BaseModel):
"""Response after starting a pipeline run."""
status: str
job_id: str
video_path: str

View File

@@ -0,0 +1,22 @@
"""
Pipeline stages.
Each stage is a file with a Stage subclass. Auto-discovered via
__init_subclass__ — importing the file registers the stage.
"""
from .base import (
Stage,
get_stage,
get_stage_instance,
list_stages,
list_stage_classes,
get_palette,
)
# Import all stage files to trigger auto-registration
from . import edge_detector # noqa: F401
from . import field_segmentation # noqa: F401
# Import registry for backward compat (other stages still use old pattern)
from . import registry # noqa: F401

View File

@@ -0,0 +1,116 @@
"""
Stage 8 — Report compilation
Groups all detections by brand, merges contiguous appearances,
and builds the final DetectionReport.
"""
from __future__ import annotations
import logging
from core.detect import emit
from core.detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats
logger = logging.getLogger(__name__)
def _merge_contiguous(detections: list[BrandDetection], gap_threshold: float = 2.0) -> list[BrandDetection]:
"""
Merge detections of the same brand that are close in time.
If two detections of the same brand are within gap_threshold seconds,
they're merged into one detection spanning the full range.
"""
if not detections:
return []
sorted_dets = sorted(detections, key=lambda d: (d.brand, d.timestamp))
merged: list[BrandDetection] = []
current = sorted_dets[0]
for det in sorted_dets[1:]:
if (det.brand == current.brand
and det.timestamp <= current.timestamp + current.duration + gap_threshold):
end = max(current.timestamp + current.duration,
det.timestamp + det.duration)
current = BrandDetection(
brand=current.brand,
timestamp=current.timestamp,
duration=end - current.timestamp,
confidence=max(current.confidence, det.confidence),
source=current.source,
bbox=current.bbox,
frame_ref=current.frame_ref,
content_type=current.content_type,
)
else:
merged.append(current)
current = det
merged.append(current)
return merged
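# Hedged illustration of the merge rule (values made up): two detections of the
# same brand 1.0s apart fall inside the default 2.0s gap threshold and collapse
# into one 2.0s detection carrying the higher confidence. bbox=None is a
# stand-in here; real detections carry a BoundingBox.
_EXAMPLE_A = BrandDetection(brand="Nike", timestamp=10.0, duration=0.5, confidence=0.8,
                            source="ocr", bbox=None, frame_ref=20, content_type="match")
_EXAMPLE_B = BrandDetection(brand="Nike", timestamp=11.5, duration=0.5, confidence=0.9,
                            source="ocr", bbox=None, frame_ref=23, content_type="match")
# _merge_contiguous([_EXAMPLE_A, _EXAMPLE_B]) -> one detection with
#   timestamp=10.0, duration=2.0, confidence=0.9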
def compile_report(
detections: list[BrandDetection],
stats: PipelineStats,
video_source: str = "",
content_type: str = "",
duration_seconds: float = 0.0,
job_id: str | None = None,
) -> DetectionReport:
"""
Build the final detection report from all accumulated detections.
Merges contiguous detections, computes per-brand stats,
and emits the job_complete event.
"""
merged = _merge_contiguous(detections)
brands: dict[str, BrandStats] = {}
for d in merged:
if d.brand not in brands:
brands[d.brand] = BrandStats()
s = brands[d.brand]
s.total_appearances += 1
s.total_screen_time += d.duration
s.avg_confidence = (
(s.avg_confidence * (s.total_appearances - 1) + d.confidence)
/ s.total_appearances
)
if s.first_seen == 0.0 or d.timestamp < s.first_seen:
s.first_seen = d.timestamp
if d.timestamp > s.last_seen:
s.last_seen = d.timestamp
report = DetectionReport(
video_source=video_source,
content_type=content_type,
duration_seconds=duration_seconds,
brands=brands,
timeline=sorted(merged, key=lambda d: d.timestamp),
pipeline_stats=stats,
)
emit.log(job_id, "Aggregator", "INFO",
f"Report: {len(brands)} brands, {len(merged)} detections "
f"(merged from {len(detections)} raw)")
emit.job_complete(job_id, {
"video_source": report.video_source,
"content_type": report.content_type,
"duration_seconds": report.duration_seconds,
"brands": {
k: {
"total_appearances": v.total_appearances,
"total_screen_time": v.total_screen_time,
"avg_confidence": round(v.avg_confidence, 3),
"first_seen": v.first_seen,
"last_seen": v.last_seen,
}
for k, v in brands.items()
},
})
return report

151
core/detect/stages/base.py Normal file
View File

@@ -0,0 +1,151 @@
"""
Stage base class — common interface for all pipeline stages.
Each stage is a file that subclasses Stage. Auto-discovered via
__init_subclass__. No manual registration needed.
A stage:
- Has a StageDefinition (generated from schema) with name, config, IO
- Implements run(frames, config) → output
- Owns its output serialization (opaque blob)
- Optionally has a TypeScript port for browser-side execution
The checkpoint layer stores stage output as blobs without knowing
the format. The stage that wrote it is the only one that can read it.
"""
from __future__ import annotations
from typing import Any
import numpy as np
from core.detect.stages.models import (
StageConfigField,
StageIO,
StageDefinition,
)
# Legacy runtime extension — adds callable fields for old-style stages.
# New stages use Stage subclass with serialize()/deserialize() methods instead.
class LegacyStageDefinition:
"""Wraps a StageDefinition with callable serialize/deserialize functions."""
def __init__(self, definition: StageDefinition, fn=None, serialize_fn=None, deserialize_fn=None):
self._definition = definition
self.fn = fn
self.serialize_fn = serialize_fn
self.deserialize_fn = deserialize_fn
def __getattr__(self, name):
return getattr(self._definition, name)
# ---------------------------------------------------------------------------
# Registry — auto-populated by __init_subclass__ (new stages)
# + register_stage() (legacy stages during migration)
# ---------------------------------------------------------------------------
_REGISTRY: dict[str, type['Stage']] = {}
_LEGACY_REGISTRY: dict[str, LegacyStageDefinition] = {}
def register_stage(
definition: StageDefinition,
fn=None,
serialize_fn=None,
deserialize_fn=None,
):
"""Legacy registration for stages not yet converted to Stage subclass."""
legacy = LegacyStageDefinition(definition, fn=fn, serialize_fn=serialize_fn, deserialize_fn=deserialize_fn)
_LEGACY_REGISTRY[definition.name] = legacy
class Stage:
"""
Base class for all pipeline stages.
Subclass this in detect/stages/<name>.py. Define `definition` as a
class attribute. Implement `run()`. Optionally override `serialize()`
and `deserialize()` for custom blob formats (default is JSON).
"""
definition: StageDefinition # set by each subclass
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if hasattr(cls, 'definition') and cls.definition is not None:
_REGISTRY[cls.definition.name] = cls
def run(self, frames: list, config: dict) -> Any:
raise NotImplementedError
def serialize(self, output: Any) -> bytes:
"""Serialize stage output to bytes for checkpoint storage."""
import json
return json.dumps(output, default=str).encode()
def deserialize(self, data: bytes) -> Any:
"""Deserialize stage output from checkpoint blob."""
import json
return json.loads(data)
# ---------------------------------------------------------------------------
# Discovery API
# ---------------------------------------------------------------------------
def _all_definitions():
"""Merge new Stage subclass registry + legacy registry.
Returns StageDefinition for new-style stages,
LegacyStageDefinition for legacy stages (has serialize_fn etc).
"""
merged = {}
for name, legacy in _LEGACY_REGISTRY.items():
merged[name] = legacy
for name, cls in _REGISTRY.items():
merged[name] = cls.definition
return merged
def get_stage(name: str) -> StageDefinition:
"""Get a stage definition by name (works for both new and legacy)."""
all_defs = _all_definitions()
if name not in all_defs:
raise KeyError(f"Unknown stage: {name!r}. Registered: {list(all_defs)}")
return all_defs[name]
def get_stage_class(name: str) -> type[Stage] | None:
"""Get a Stage subclass by name. Returns None for legacy stages."""
return _REGISTRY.get(name)
def get_stage_instance(name: str) -> Stage:
"""Get an instantiated Stage by name. Only works for new-style stages."""
cls = _REGISTRY.get(name)
if cls is None:
raise KeyError(f"No Stage subclass for {name!r}. Legacy stages don't have instances.")
return cls()
def list_stages() -> list[StageDefinition]:
"""List all registered stage definitions (new + legacy)."""
return list(_all_definitions().values())
def list_stage_classes() -> list[type[Stage]]:
"""List all registered Stage subclasses (new-style only)."""
return list(_REGISTRY.values())
def get_palette() -> dict[str, list[StageDefinition]]:
"""Group stages by category for the editor palette."""
palette: dict[str, list[StageDefinition]] = {}
for defn in _all_definitions().values():
if defn.category not in palette:
palette[defn.category] = []
palette[defn.category].append(defn)
return palette
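# Hedged sketch of a new-style stage; the name, IO keys, and config field are
# illustrative. Defining the subclass is enough: __init_subclass__ registers it,
# so get_stage("example_stage"), get_stage_instance("example_stage") and
# get_palette() all see it with no extra wiring.
class _ExampleStage(Stage):
    definition = StageDefinition(
        name="example_stage",
        label="Example Stage",
        description="Counts frames (illustration only)",
        category="detection",
        io=StageIO(reads=["filtered_frames"], writes=["frame_count"]),
        config_fields=[StageConfigField(name="enabled", type="bool", default=True)],
    )

    def run(self, frames: list, config: dict) -> int:
        return len(frames)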

View File

@@ -0,0 +1,216 @@
"""
Stage 5 — Brand Resolver (discovery mode)
Discovery-first brand matching. No static dictionary — all brands live in the DB.
Flow:
1. Check session brands first (brands already seen in this run, in-memory)
2. Check global known brands (accumulated across all runs)
3. Unresolved candidates → escalate to VLM/cloud
4. Confirmed brands get added to DB for future runs
"""
from __future__ import annotations
import logging
from rapidfuzz import fuzz
from core.detect import emit
from core.detect.models import BrandDetection, TextCandidate
from core.detect.stages.models import ResolverConfig
logger = logging.getLogger(__name__)
def _normalize(text: str) -> str:
return text.strip().lower()
def _has_db() -> bool:
try:
from core.db import find_brand_by_text as _
return True
except Exception:  # any import/DB failure means "no DB configured"
return False
def _match_session(text: str, session_brands: dict[str, str]) -> str | None:
return session_brands.get(_normalize(text))
def _match_known(text: str, threshold: int) -> tuple[str | None, str | None]:
"""Check against global known brands in DB. Returns (canonical_name, brand_id) or (None, None)."""
if not _has_db():
return None, None
from core.db import find_brand_by_text, list_brands
from core.db.connection import get_session
with get_session() as session:
brand = find_brand_by_text(session, text)
if brand:
return brand.canonical_name, str(brand.id)
all_brands = list_brands(session)
normalized = _normalize(text)
best_brand = None
best_score = 0
for known in all_brands:
names = [known.canonical_name] + (known.aliases or [])
for name in names:
score = fuzz.ratio(normalized, _normalize(name))
if score > best_score and score >= threshold:
best_score = score
best_brand = known
if best_brand:
return best_brand.canonical_name, str(best_brand.id)
return None, None
def _register_brand(canonical_name: str, source: str) -> str | None:
"""Register a newly discovered brand in the DB. Returns brand_id."""
if not _has_db():
return None
from core.db import get_or_create_brand
from core.db.connection import get_session
with get_session() as session:
brand, created = get_or_create_brand(session, canonical_name, source=source)
session.commit()
if created:
logger.info("New brand discovered: %s (source=%s)", canonical_name, source)
return str(brand.id)
def _record_airing(timeline_id: str | None, brand_id: str,
frame_seq: int, confidence: float, source: str):
"""Record a brand airing on a timeline."""
if not _has_db() or not timeline_id:
return
from core.db import record_airing
from core.db.connection import get_session
from uuid import UUID
with get_session() as session:
record_airing(
session,
brand_id=UUID(brand_id),
timeline_id=UUID(timeline_id),
frame_start=frame_seq,
frame_end=frame_seq,
confidence=confidence,
source=source,
)
session.commit()
def build_session_dict(source_asset_id: str | None = None) -> dict[str, str]:
"""
Load known brands from DB as a session lookup dict.
Returns {normalized_name: canonical_name, ...} including aliases.
"""
if not _has_db():
return {}
from core.db import list_brands
from core.db.connection import get_session
with get_session() as session:
all_brands = list_brands(session)
session_dict = {}
for brand in all_brands:
session_dict[_normalize(brand.canonical_name)] = brand.canonical_name
for alias in (brand.aliases or []):
session_dict[_normalize(alias)] = brand.canonical_name
return session_dict
def resolve_brands(
candidates: list[TextCandidate],
config: ResolverConfig,
session_brands: dict[str, str] | None = None,
source_asset_id: str | None = None,
content_type: str = "",
job_id: str | None = None,
) -> tuple[list[BrandDetection], list[TextCandidate]]:
"""
Match text candidates against known brands (session → global → unresolved).
session_brands: pre-loaded session dict (from build_session_dict)
job_id: timeline_id — used to record airings
"""
if session_brands is None:
session_brands = {}
emit.log(job_id, "BrandResolver", "INFO",
f"Resolving {len(candidates)} candidates "
f"(session={len(session_brands)} brands, fuzzy={config.fuzzy_threshold})")
matched: list[BrandDetection] = []
unresolved: list[TextCandidate] = []
session_hits = 0
known_hits = 0
for candidate in candidates:
text = candidate.text
brand_name = None
brand_id = None
match_source = "ocr"
# 1. Check session (cheapest — in-memory dict)
brand_name = _match_session(text, session_brands)
if brand_name:
session_hits += 1
else:
# 2. Check global known brands (DB query + fuzzy)
brand_name, brand_id = _match_known(text, config.fuzzy_threshold)
if brand_name:
known_hits += 1
session_brands[_normalize(brand_name)] = brand_name
if brand_name:
detection = BrandDetection(
brand=brand_name,
timestamp=candidate.frame.timestamp,
duration=0.5,
confidence=candidate.ocr_confidence,
source=match_source,
bbox=candidate.bbox,
frame_ref=candidate.frame.sequence,
content_type=content_type,
)
matched.append(detection)
if brand_id:
_record_airing(
job_id, brand_id,
candidate.frame.sequence, candidate.ocr_confidence, match_source,
)
emit.detection(
job_id,
brand=brand_name,
confidence=candidate.ocr_confidence,
source=match_source,
timestamp=candidate.frame.timestamp,
content_type=content_type,
frame_ref=candidate.frame.sequence,
)
else:
unresolved.append(candidate)
emit.log(job_id, "BrandResolver", "INFO",
f"Session: {session_hits}, Known: {known_hits}, "
f"Unresolved: {len(unresolved)} → escalating")
return matched, unresolved
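# Hedged usage sketch: an empty candidate list keeps this runnable without the
# OCR stage or a database; real callers pass the TextCandidates from OCR and the
# timeline id as job_id so airings get recorded.
if __name__ == "__main__":
    session = build_session_dict()  # {} when no DB is configured
    matched, unresolved = resolve_brands(
        [],
        ResolverConfig(fuzzy_threshold=80),
        session_brands=session,
        content_type="match",
    )
    print(len(matched), len(unresolved))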

View File

@@ -0,0 +1,288 @@
"""
Stage — Edge Detection
Canny + HoughLinesP to find horizontal line pairs that bound
advertising hoardings. Pure OpenCV, no ML models.
Two modes:
- Remote: calls GPU inference server over HTTP
- Local: imports cv2 directly (OpenCV on same machine)
"""
from __future__ import annotations
import base64
import io
import json
import logging
import os
import time
from typing import Any
from PIL import Image
from core.detect import emit
from core.detect.models import BoundingBox, Frame
from core.detect.stages.base import Stage
from core.detect.stages.models import StageDefinition, StageConfigField, StageIO
logger = logging.getLogger(__name__)
class EdgeDetectionStage(Stage):
definition = StageDefinition(
name="detect_edges",
label="Edge Detection",
description="Canny + HoughLinesP — find horizontal line pairs (hoarding boundaries)",
category="cv_analysis",
io=StageIO(
reads=["filtered_frames"],
writes=["edge_regions_by_frame"],
),
config_fields=[
StageConfigField(name="enabled", type="bool", default=True, description="Enable edge detection"),
StageConfigField(name="edge_canny_low", type="int", default=50, description="Canny low threshold", min=0, max=255),
StageConfigField(name="edge_canny_high", type="int", default=150, description="Canny high threshold", min=0, max=255),
StageConfigField(name="edge_hough_threshold", type="int", default=80, description="Hough accumulator threshold", min=1, max=500),
StageConfigField(name="edge_hough_min_length", type="int", default=100, description="Min line length (px)", min=10, max=2000),
StageConfigField(name="edge_hough_max_gap", type="int", default=10, description="Max line gap (px)", min=1, max=100),
StageConfigField(name="edge_pair_max_distance", type="int", default=200, description="Max distance between line pair (px)", min=10, max=500),
StageConfigField(name="edge_pair_min_distance", type="int", default=15, description="Min distance between line pair (px)", min=5, max=200),
],
tracks_element="edge_region",
)
def run(self, frames: list[Frame], config: dict) -> dict[int, list[BoundingBox]]:
"""
Run edge detection on all frames.
Config keys: enabled, edge_canny_low, edge_canny_high, edge_hough_threshold,
edge_hough_min_length, edge_hough_max_gap, edge_pair_max_distance, edge_pair_min_distance,
debug (bool), inference_url (str|None), job_id (str|None).
Returns dict mapping frame sequence → list of BoundingBox.
"""
enabled = config.get("enabled", True)
job_id = config.get("job_id")
inference_url = config.get("inference_url") or os.environ.get("INFERENCE_URL")
if not enabled:
emit.log(job_id, "EdgeDetection", "INFO", "Edge detection disabled, skipping")
return {}
mode = "remote" if inference_url else "local"
emit.log(job_id, "EdgeDetection", "INFO",
f"Detecting edges in {len(frames)} frames (mode={mode})")
all_boxes: dict[int, list[BoundingBox]] = {}
total_regions = 0
for frame in frames:
t0 = time.monotonic()
if inference_url:
boxes = self._run_remote(frame, config, inference_url, job_id or "")
else:
boxes = self._run_local(frame, config)
ms = (time.monotonic() - t0) * 1000
all_boxes[frame.sequence] = boxes
total_regions += len(boxes)
emit.log(job_id, "EdgeDetection", "DEBUG",
f"Frame {frame.sequence}: {len(boxes)} regions in {ms:.0f}ms"
+ (f" [{', '.join(b.label for b in boxes)}]" if boxes else ""))
if boxes and job_id:
box_dicts = [
{"x": b.x, "y": b.y, "w": b.w, "h": b.h,
"confidence": b.confidence, "label": b.label,
"stage": "detect_edges"}
for b in boxes
]
emit.frame_update(
job_id,
frame_ref=frame.sequence,
timestamp=frame.timestamp,
jpeg_b64=_frame_to_b64(frame),
boxes=box_dicts,
)
emit.log(job_id, "EdgeDetection", "INFO",
f"Found {total_regions} edge regions across {len(frames)} frames")
emit.stats(job_id, cv_regions_detected=total_regions)
return all_boxes
def serialize(self, output: Any) -> bytes:
"""Serialize edge regions to JSON blob."""
serialized = {}
for seq, boxes in output.items():
serialized[str(seq)] = [
{"x": b.x, "y": b.y, "w": b.w, "h": b.h,
"confidence": b.confidence, "label": b.label}
for b in boxes
]
return json.dumps(serialized).encode()
def deserialize(self, data: bytes) -> dict[int, list[BoundingBox]]:
"""Deserialize edge regions from JSON blob."""
raw = json.loads(data)
result = {}
for seq_str, box_dicts in raw.items():
boxes = [
BoundingBox(x=b["x"], y=b["y"], w=b["w"], h=b["h"],
confidence=b["confidence"], label=b["label"])
for b in box_dicts
]
result[int(seq_str)] = boxes
return result
# --- Private helpers ---
def _run_remote(self, frame: Frame, config: dict,
inference_url: str, job_id: str) -> list[BoundingBox]:
from core.detect.inference import InferenceClient
from core.detect.emit import _run_log_level
client = InferenceClient(
base_url=inference_url, job_id=job_id, log_level=_run_log_level,
)
results = client.detect_edges(
image=frame.image,
edge_canny_low=config.get("edge_canny_low", 50),
edge_canny_high=config.get("edge_canny_high", 150),
edge_hough_threshold=config.get("edge_hough_threshold", 80),
edge_hough_min_length=config.get("edge_hough_min_length", 100),
edge_hough_max_gap=config.get("edge_hough_max_gap", 10),
edge_pair_max_distance=config.get("edge_pair_max_distance", 200),
edge_pair_min_distance=config.get("edge_pair_min_distance", 15),
)
boxes = []
for r in results:
box = BoundingBox(
x=r.x, y=r.y, w=r.w, h=r.h,
confidence=r.confidence, label=r.label,
)
boxes.append(box)
return boxes
def _run_local(self, frame: Frame, config: dict) -> list[BoundingBox]:
detect_edges_fn = _load_cv_edges().detect_edges
edge_results = detect_edges_fn(
frame.image,
canny_low=config.get("edge_canny_low", 50),
canny_high=config.get("edge_canny_high", 150),
hough_threshold=config.get("edge_hough_threshold", 80),
hough_min_length=config.get("edge_hough_min_length", 100),
hough_max_gap=config.get("edge_hough_max_gap", 10),
pair_max_distance=config.get("edge_pair_max_distance", 200),
pair_min_distance=config.get("edge_pair_min_distance", 15),
)
boxes = []
for r in edge_results:
box = BoundingBox(
x=r["x"], y=r["y"], w=r["w"], h=r["h"],
confidence=r["confidence"], label=r["label"],
)
boxes.append(box)
return boxes
# --- Module-level helpers ---
def _frame_to_b64(frame: Frame) -> str:
img = Image.fromarray(frame.image)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=70)
return base64.b64encode(buf.getvalue()).decode()
_cv_edges_mod = None
def _load_cv_edges():
global _cv_edges_mod
if _cv_edges_mod is None:
import importlib.util
from pathlib import Path
spec = importlib.util.spec_from_file_location("cv_edges", Path("core/gpu/models/cv/edges.py"))
_cv_edges_mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(_cv_edges_mod)
return _cv_edges_mod
# --- Backward compat: standalone function for graph.py ---
def _filter_by_field_mask(boxes, mask, margin_px=50):
"""
Keep only boxes that are near the pitch boundary (hoarding zone).
The field mask has 255=pitch, 0=not pitch. Hoardings sit just
outside the pitch boundary. We dilate the mask to create a
"boundary zone" and keep boxes whose center falls in the zone
between the dilated mask edge and the original mask.
"""
import cv2
import numpy as np
if mask is None or not boxes:
return boxes
# Dilate the pitch mask — the expansion zone is where hoardings are
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (margin_px * 2, margin_px * 2))
dilated = cv2.dilate(mask, kernel)
# Boundary zone = dilated but NOT original pitch
boundary_zone = cv2.bitwise_and(dilated, cv2.bitwise_not(mask))
kept = []
for box in boxes:
cx = box.x + box.w // 2
cy = box.y + box.h // 2
# Clamp to image bounds
cy = min(cy, boundary_zone.shape[0] - 1)
cx = min(cx, boundary_zone.shape[1] - 1)
if boundary_zone[cy, cx] > 0:
kept.append(box)
return kept
def detect_edge_regions(frames, config, inference_url=None, job_id=None, field_masks=None):
"""Convenience wrapper — calls EdgeDetectionStage.run(), optionally filters by field mask."""
stage = EdgeDetectionStage()
cfg = {
"enabled": config.enabled,
"edge_canny_low": config.edge_canny_low,
"edge_canny_high": config.edge_canny_high,
"edge_hough_threshold": config.edge_hough_threshold,
"edge_hough_min_length": config.edge_hough_min_length,
"edge_hough_max_gap": config.edge_hough_max_gap,
"edge_pair_max_distance": config.edge_pair_max_distance,
"edge_pair_min_distance": config.edge_pair_min_distance,
"inference_url": inference_url,
"job_id": job_id,
}
all_boxes = stage.run(frames, cfg)
# Filter by field segmentation mask if available
if field_masks:
filtered_total = 0
original_total = sum(len(b) for b in all_boxes.values())
for seq, boxes in all_boxes.items():
mask = field_masks.get(seq)
if mask is not None:
all_boxes[seq] = _filter_by_field_mask(boxes, mask)
filtered_total += len(all_boxes[seq])
else:
filtered_total += len(boxes)
if original_total != filtered_total:
from core.detect import emit
emit.log(job_id, "EdgeDetection", "INFO",
f"Field mask filter: {original_total}{filtered_total} regions")
return all_boxes
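# Hedged usage sketch: an empty frame list keeps this runnable without a video
# or an inference server; real callers pass the filtered_frames from the scene
# filter stage and set INFERENCE_URL to run remotely.
if __name__ == "__main__":
    stage = EdgeDetectionStage()
    print(stage.run([], {"enabled": True, "edge_canny_low": 50, "edge_canny_high": 150}))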

View File

@@ -0,0 +1,141 @@
"""
Stage — Field Segmentation
Calls the GPU inference server to detect pitch boundaries via
HSV green mask + morphology. The CV code lives in core/gpu/models/cv/.
Outputs a mask and boundary that downstream stages use as spatial priors.
"""
from __future__ import annotations
import base64
import io
import logging
import numpy as np
from PIL import Image
from core.detect import emit
from core.detect.models import Frame
from core.detect.stages.base import Stage
from core.detect.stages.models import (
FieldSegmentationConfig,
StageConfigField,
StageDefinition,
StageIO,
)
logger = logging.getLogger(__name__)
class FieldSegmentationStage(Stage):
definition = StageDefinition(
name="field_segmentation",
label="Field Segmentation",
description="HSV green mask — detect pitch boundaries for spatial priors",
category="cv_analysis",
io=StageIO(
reads=["filtered_frames"],
writes=["field_mask"],
),
config_fields=[
StageConfigField(name="enabled", type="bool", default=True, description="Enable field segmentation"),
StageConfigField(name="hue_low", type="int", default=30, description="HSV hue lower bound", min=0, max=180),
StageConfigField(name="hue_high", type="int", default=85, description="HSV hue upper bound", min=0, max=180),
StageConfigField(name="sat_low", type="int", default=30, description="HSV saturation lower bound", min=0, max=255),
StageConfigField(name="sat_high", type="int", default=255, description="HSV saturation upper bound", min=0, max=255),
StageConfigField(name="val_low", type="int", default=30, description="HSV value lower bound", min=0, max=255),
StageConfigField(name="val_high", type="int", default=255, description="HSV value upper bound", min=0, max=255),
StageConfigField(name="morph_kernel", type="int", default=15, description="Morphology kernel size", min=3, max=51),
StageConfigField(name="min_area_ratio", type="float", default=0.05, description="Min contour area as fraction of frame", min=0.01, max=0.5),
],
)
def _frame_to_b64(frame: Frame) -> str:
"""Encode frame image as base64 JPEG."""
img = Image.fromarray(frame.image)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=85)
return base64.b64encode(buf.getvalue()).decode()
def _decode_mask_b64(mask_b64: str) -> np.ndarray:
"""Decode a base64 PNG mask back to numpy array."""
data = base64.b64decode(mask_b64)
img = Image.open(io.BytesIO(data)).convert("L")
return np.array(img)
def run_field_segmentation(
frames: list[Frame],
config: FieldSegmentationConfig,
inference_url: str | None = None,
job_id: str | None = None,
) -> dict:
"""
Run field segmentation on all frames via the inference server.
Returns dict with:
field_masks: {seq: np.ndarray}
field_boundaries: {seq: [(x,y), ...]}
field_coverage: {seq: float}
"""
if not config.enabled:
emit.log(job_id, "FieldSegmentation", "INFO", "Disabled, skipping")
return {"field_masks": {}, "field_boundaries": {}, "field_coverage": {}}
import os
url = inference_url or os.environ.get("INFERENCE_URL")
if not url:
emit.log(job_id, "FieldSegmentation", "WARNING",
"No INFERENCE_URL, skipping field segmentation")
return {"field_masks": {}, "field_boundaries": {}, "field_coverage": {}}
emit.log(job_id, "FieldSegmentation", "INFO",
f"Segmenting {len(frames)} frames (hue={config.hue_low}-{config.hue_high})")
from core.detect.inference import InferenceClient
from core.detect.emit import _run_log_level
client = InferenceClient(base_url=url, job_id=job_id or "", log_level=_run_log_level)
field_masks = {}
field_boundaries = {}
field_coverage = {}
for frame in frames:
image_b64 = _frame_to_b64(frame)
resp = client.post("/segment_field", {
"image": image_b64,
"hue_low": config.hue_low,
"hue_high": config.hue_high,
"sat_low": config.sat_low,
"sat_high": config.sat_high,
"val_low": config.val_low,
"val_high": config.val_high,
"morph_kernel": config.morph_kernel,
"min_area_ratio": config.min_area_ratio,
})
if resp is None:
continue
mask_b64 = resp.get("mask_b64", "")
if mask_b64:
field_masks[frame.sequence] = _decode_mask_b64(mask_b64)
field_boundaries[frame.sequence] = resp.get("boundary", [])
field_coverage[frame.sequence] = resp.get("coverage", 0.0)
avg_coverage = sum(field_coverage.values()) / max(len(field_coverage), 1)
emit.log(job_id, "FieldSegmentation", "INFO",
f"Done: {len(frames)} frames, avg coverage {avg_coverage:.1%}")
return {
"field_masks": field_masks,
"field_boundaries": field_boundaries,
"field_coverage": field_coverage,
}
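# Hedged usage sketch: with INFERENCE_URL unset this logs a warning and returns
# empty dicts, so it stays runnable without the GPU server; real callers pass
# the filtered frames and the job's timeline id.
if __name__ == "__main__":
    out = run_field_segmentation([], FieldSegmentationConfig(), inference_url=None)
    print(sorted(out))  # ['field_boundaries', 'field_coverage', 'field_masks']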

View File

@@ -0,0 +1,93 @@
"""
Stage 1 — Frame Extraction
Extracts frames from a video at a configurable FPS using the core ffmpeg module.
Emits log + stats_update SSE events as it works.
"""
from __future__ import annotations
import tempfile
import time
from pathlib import Path
import ffmpeg
import numpy as np
from PIL import Image
from core.ffmpeg.probe import probe_file
from core.detect import emit
from core.detect.models import Frame
from core.detect.stages.models import FrameExtractionConfig
def _load_frames(tmpdir: Path, fps: float) -> list[Frame]:
"""Load extracted JPEG files into Frame objects."""
frame_files = sorted(tmpdir.glob("frame_*.jpg"))
frames = []
for i, fpath in enumerate(frame_files):
img = Image.open(fpath)
frame = Frame(
sequence=i,
chunk_id=0,
timestamp=i / fps,
image=np.array(img),
)
frames.append(frame)
return frames
def extract_frames(
video_path: str,
config: FrameExtractionConfig,
job_id: str | None = None,
) -> list[Frame]:
"""
Extract frames from video at the configured FPS.
Uses ffmpeg-python to build the extraction pipeline,
outputs JPEG files to a temp dir, then loads as numpy arrays.
"""
probe = probe_file(video_path)
duration = probe.duration or 0.0
emit.log(job_id, "FrameExtractor", "INFO",
f"Starting extraction: {Path(video_path).name} "
f"({duration:.1f}s, {probe.width}x{probe.height}, fps={config.fps})")
emit.log(job_id, "FrameExtractor", "DEBUG",
f"Probe: codec={probe.video_codec}, bitrate={probe.video_bitrate}, max_frames={config.max_frames}")
with tempfile.TemporaryDirectory() as tmpdir:
pattern = str(Path(tmpdir) / "frame_%06d.jpg")
stream = (
ffmpeg
.input(video_path)
.filter("fps", fps=config.fps)
.output(pattern, qscale=2, frames=config.max_frames)
.overwrite_output()
)
t0 = time.monotonic()
try:
stream.run(capture_stdout=True, capture_stderr=True, quiet=True)
except ffmpeg.Error as e:
stderr = e.stderr.decode() if e.stderr else "unknown error"
emit.log(job_id, "FrameExtractor", "ERROR", f"FFmpeg failed: {stderr[:200]}")
raise RuntimeError(f"FFmpeg failed: {stderr}") from e
ffmpeg_ms = (time.monotonic() - t0) * 1000
emit.log(job_id, "FrameExtractor", "DEBUG", f"FFmpeg decode: {ffmpeg_ms:.0f}ms")
t0 = time.monotonic()
frames = _load_frames(Path(tmpdir), config.fps)
load_ms = (time.monotonic() - t0) * 1000
if frames:
h, w = frames[0].image.shape[:2]
mem_mb = sum(f.image.nbytes for f in frames) / (1024 * 1024)
emit.log(job_id, "FrameExtractor", "DEBUG",
f"Loaded {len(frames)} frames ({w}x{h}) in {load_ms:.0f}ms, {mem_mb:.1f}MB in memory")
emit.log(job_id, "FrameExtractor", "INFO", f"Extracted {len(frames)} frames")
emit.stats(job_id, frames_extracted=len(frames))
return frames
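# --- Illustrative usage (sketch only, not called by the pipeline) ---
# Assumes a local test clip exists at "sample.mp4"; config defaults come from
# FrameExtractionConfig (fps=2.0, max_frames=500).
if __name__ == "__main__":
    demo_config = FrameExtractionConfig(fps=1.0, max_frames=10)
    demo_frames = extract_frames("sample.mp4", demo_config)
    for f in demo_frames[:3]:
        print(f.sequence, f"{f.timestamp:.1f}s", f.image.shape)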

View File

@@ -0,0 +1,106 @@
"""
Pydantic Models - GENERATED FILE
Do not edit directly. Regenerate using modelgen.
"""
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel, Field
class StageConfigField(BaseModel):
"""A single tunable config parameter for the editor UI."""
name: str
type: str
default: Any
description: str = ""
min: Optional[float] = None
max: Optional[float] = None
options: Optional[List[str]] = None
class StageIO(BaseModel):
"""Declares what a stage reads and writes."""
reads: List[str] = Field(default_factory=list)
writes: List[str] = Field(default_factory=list)
optional_reads: List[str] = Field(default_factory=list)
class StageDefinition(BaseModel):
"""Complete metadata for a pipeline stage."""
name: str
label: str
description: str
category: str = "detection"
io: StageIO
config_fields: List[StageConfigField] = Field(default_factory=list)
tracks_element: Optional[str] = None
class FrameExtractionConfig(BaseModel):
"""FrameExtractionConfig(fps: float = 2.0, max_frames: int = 500)"""
fps: float = 2.0
max_frames: int = 500
class SceneFilterConfig(BaseModel):
"""SceneFilterConfig(hamming_threshold: int = 8, enabled: bool = True)"""
hamming_threshold: int = 8
enabled: bool = True
class DetectionConfig(BaseModel):
"""DetectionConfig(model_name: str = 'yolov8n.pt', confidence_threshold: float = 0.3, target_classes: List[str] = <factory>)"""
model_name: str = "yolov8n.pt"
confidence_threshold: float = 0.3
target_classes: List[str] = Field(default_factory=list)
class OCRConfig(BaseModel):
"""OCRConfig(languages: List[str] = <factory>, min_confidence: float = 0.5)"""
languages: List[str] = Field(default_factory=lambda: ["en"])
min_confidence: float = 0.5
class ResolverConfig(BaseModel):
"""ResolverConfig(fuzzy_threshold: int = 75)"""
fuzzy_threshold: int = 75
class RegionAnalysisConfig(BaseModel):
"""RegionAnalysisConfig(enabled: bool = True, edge_canny_low: int = 50, edge_canny_high: int = 150, edge_hough_threshold: int = 80, edge_hough_min_length: int = 100, edge_hough_max_gap: int = 10, edge_pair_max_distance: int = 200, edge_pair_min_distance: int = 15)"""
enabled: bool = True
edge_canny_low: int = 50
edge_canny_high: int = 150
edge_hough_threshold: int = 80
edge_hough_min_length: int = 100
edge_hough_max_gap: int = 10
edge_pair_max_distance: int = 200
edge_pair_min_distance: int = 15
class FieldSegmentationConfig(BaseModel):
"""FieldSegmentationConfig(enabled: bool = True, hue_low: int = 30, hue_high: int = 85, sat_low: int = 30, sat_high: int = 255, val_low: int = 30, val_high: int = 255, morph_kernel: int = 15, min_area_ratio: float = 0.05)"""
enabled: bool = True
hue_low: int = 30
hue_high: int = 85
sat_low: int = 30
sat_high: int = 255
val_low: int = 30
val_high: int = 255
morph_kernel: int = 15
min_area_ratio: float = 0.05
class StageRef(BaseModel):
"""Reference to a stage in the pipeline graph."""
name: str
branch: str = "trunk"
execution_target: str = "local"
class Edge(BaseModel):
"""Connection between stages in the graph."""
source: str
target: str
condition: str = ""
class PipelineConfig(BaseModel):
"""Pipeline graph topology + routing rules."""
name: str
profile_name: str
stages: List[StageRef] = Field(default_factory=list)
edges: List[Edge] = Field(default_factory=list)
routing_rules: Dict[str, Any]
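# --- Illustrative sketch (belongs in calling code, shown here for reference):
# wiring a minimal pipeline graph from these models. Stage names mirror the
# registry; routing_rules has no default, so an empty dict must be passed.
if __name__ == "__main__":
    demo = PipelineConfig(
        name="demo",
        profile_name="default",
        stages=[
            StageRef(name="extract_frames"),
            StageRef(name="detect_objects", execution_target="remote"),
        ],
        edges=[Edge(source="extract_frames", target="detect_objects")],
        routing_rules={},
    )
    print(demo)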

View File

@@ -0,0 +1,139 @@
"""
Stage 4 — OCR
Reads text from detected regions (YOLO bounding box crops).
Two modes:
- remote: calls inference server over HTTP (separate GPU box, or localhost)
- local: runs PaddleOCR in-process (single-box setup with enough VRAM)
The mode is selected by whether inference_url is provided.
Model instances are cached at module level so they survive across pipeline runs.
"""
from __future__ import annotations
import logging
import time
from typing import TYPE_CHECKING
import numpy as np
from core.detect import emit
from core.detect.models import BoundingBox, Frame, TextCandidate
from core.detect.stages.models import OCRConfig
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
# Module-level cache — avoids reloading the model for every crop or pipeline run
_local_ocr_cache: dict[str, object] = {}
def _crop_region(frame: Frame, box: BoundingBox) -> np.ndarray:
h, w = frame.image.shape[:2]
x1 = max(0, box.x)
y1 = max(0, box.y)
x2 = min(w, box.x + box.w)
y2 = min(h, box.y + box.h)
return frame.image[y1:y2, x1:x2]
def _get_local_model(lang: str):
if lang not in _local_ocr_cache:
from paddleocr import PaddleOCR
logger.info("Loading PaddleOCR locally (lang=%s)", lang)
_local_ocr_cache[lang] = PaddleOCR(lang=lang)
return _local_ocr_cache[lang]
def _parse_ocr_raw(raw, min_confidence: float) -> list[dict]:
"""Parse PaddleOCR 3.x result — handles dict-based and nested-list layouts."""
results = []
for page in (raw or []):
if not page:
continue
if isinstance(page, dict):
for text, confidence in zip(page.get("rec_texts", []), page.get("rec_scores", [])):
if float(confidence) >= min_confidence:
results.append({"text": text, "confidence": float(confidence)})
continue
for line in page:
if not line:
continue
rec = line[1]
if isinstance(rec, (list, tuple)) and len(rec) >= 2:
text, confidence = rec[0], rec[1]
if float(confidence) >= min_confidence:
results.append({"text": text, "confidence": float(confidence)})
return results
def run_ocr(
frames: list[Frame],
boxes_by_frame: dict[int, list[BoundingBox]],
config: OCRConfig,
inference_url: str | None = None,
job_id: str | None = None,
) -> list[TextCandidate]:
"""
Run OCR on cropped regions from YOLO detections.
inference_url=None → local in-process PaddleOCR (single-box)
inference_url=str → remote inference server (split or localhost)
"""
total_regions = sum(len(boxes) for boxes in boxes_by_frame.values())
mode = "remote" if inference_url else "local"
emit.log(job_id, "OCRStage", "INFO",
f"Running OCR on {total_regions} regions (mode={mode})")
# Build these once per pipeline run, not per crop
if inference_url:
from core.detect.inference import InferenceClient
from core.detect.emit import _run_log_level
client = InferenceClient(base_url=inference_url, job_id=job_id or "", log_level=_run_log_level)
else:
model = _get_local_model(config.languages[0])
frame_map = {f.sequence: f for f in frames}
candidates: list[TextCandidate] = []
for seq, boxes in boxes_by_frame.items():
frame = frame_map.get(seq)
if not frame:
continue
for box in boxes:
crop = _crop_region(frame, box)
if crop.size == 0:
continue
t0 = time.monotonic()
if inference_url:
raw_results = client.ocr(image=crop, languages=config.languages)
texts = [{"text": r.text, "confidence": r.confidence} for r in raw_results]
else:
raw = model.ocr(crop)
texts = _parse_ocr_raw(raw, config.min_confidence)
ocr_ms = (time.monotonic() - t0) * 1000
h, w = crop.shape[:2]
text_preview = ", ".join(t["text"][:30] for t in texts) if texts else "(none)"
emit.log(job_id, "OCRStage", "DEBUG",
f"Frame {seq} box {box.x},{box.y} ({w}x{h}): {ocr_ms:.0f}ms → {text_preview}")
for t in texts:
candidates.append(TextCandidate(
frame=frame,
bbox=box,
text=t["text"],
ocr_confidence=t["confidence"],
))
emit.log(job_id, "OCRStage", "INFO",
f"Extracted text from {len(candidates)} regions")
emit.stats(job_id, regions_resolved_by_ocr=len(candidates))
return candidates
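# --- Illustrative check (sketch): the two PaddleOCR result layouts handled by
# _parse_ocr_raw. Both fixtures below are hand-written, not captured output.
if __name__ == "__main__":
    dict_layout = [{"rec_texts": ["ACME", "noise"], "rec_scores": [0.91, 0.32]}]
    nested_layout = [[[[0, 0], [80, 0], [80, 20], [0, 20]], ("ACME", 0.91)]]
    print(_parse_ocr_raw(dict_layout, min_confidence=0.5))    # [{'text': 'ACME', 'confidence': 0.91}]
    print(_parse_ocr_raw(nested_layout, min_confidence=0.5))  # [{'text': 'ACME', 'confidence': 0.91}]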

View File

@@ -0,0 +1,128 @@
"""
Stage 3.5 — Preprocessing
Runs between YOLO detection and OCR. Applies configurable image
preprocessing to each detected region crop: contrast enhancement,
deskewing, binarization.
Operates on the crops derived from boxes_by_frame, produces
preprocessed_crops keyed by (frame_sequence, box_index).
"""
from __future__ import annotations
import logging
import numpy as np
from core.detect import emit
from core.detect.models import BoundingBox, Frame
logger = logging.getLogger(__name__)
def _crop_region(frame: Frame, box: BoundingBox) -> np.ndarray:
h, w = frame.image.shape[:2]
x1 = max(0, box.x)
y1 = max(0, box.y)
x2 = min(w, box.x + box.w)
y2 = min(h, box.y + box.h)
return frame.image[y1:y2, x1:x2]
def preprocess_regions(
frames: list[Frame],
boxes_by_frame: dict[int, list[BoundingBox]],
do_contrast: bool = True,
do_deskew: bool = False,
do_binarize: bool = False,
inference_url: str | None = None,
job_id: str | None = None,
) -> dict[str, np.ndarray]:
"""
Preprocess cropped regions from YOLO detections.
Returns dict keyed by "{frame_seq}_{box_idx}" → preprocessed crop.
These are passed to the OCR stage instead of raw crops.
"""
total_regions = sum(len(boxes) for boxes in boxes_by_frame.values())
any_active = do_contrast or do_deskew or do_binarize
if not any_active:
emit.log(job_id, "Preprocess", "INFO",
f"Preprocessing disabled, passing {total_regions} regions through")
return {}
mode = "remote" if inference_url else "local"
emit.log(job_id, "Preprocess", "INFO",
f"Preprocessing {total_regions} regions (mode={mode}, "
f"contrast={do_contrast}, deskew={do_deskew}, binarize={do_binarize})")
frame_map = {f.sequence: f for f in frames}
preprocessed: dict[str, np.ndarray] = {}
processed_count = 0
for seq, boxes in boxes_by_frame.items():
frame = frame_map.get(seq)
if not frame:
continue
for idx, box in enumerate(boxes):
crop = _crop_region(frame, box)
if crop.size == 0:
continue
key = f"{seq}_{idx}"
if inference_url:
result = _preprocess_remote(crop, inference_url,
do_contrast, do_deskew, do_binarize)
else:
result = _preprocess_local(crop, do_contrast, do_deskew, do_binarize)
preprocessed[key] = result
processed_count += 1
emit.log(job_id, "Preprocess", "INFO",
f"Preprocessed {processed_count} regions")
return preprocessed
def _preprocess_remote(crop: np.ndarray, inference_url: str,
do_contrast: bool, do_deskew: bool, do_binarize: bool) -> np.ndarray:
"""Call GPU server /preprocess endpoint."""
import base64
import io
import requests
from PIL import Image
img = Image.fromarray(crop)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=85)
image_b64 = base64.b64encode(buf.getvalue()).decode()
resp = requests.post(
f"{inference_url.rstrip('/')}/preprocess",
json={
"image": image_b64,
"contrast": do_contrast,
"deskew": do_deskew,
"binarize": do_binarize,
},
timeout=30,
)
resp.raise_for_status()
data = resp.json()
result_bytes = base64.b64decode(data["image"])
result_img = Image.open(io.BytesIO(result_bytes)).convert("RGB")
return np.array(result_img)
def _preprocess_local(crop: np.ndarray,
do_contrast: bool, do_deskew: bool, do_binarize: bool) -> np.ndarray:
"""Run preprocessing in-process (requires opencv-python-headless)."""
from core.gpu.models.preprocess import preprocess
return preprocess(crop, do_binarize=do_binarize, do_deskew=do_deskew, do_contrast=do_contrast)
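# --- Illustrative usage (sketch, local mode): requires opencv-python-headless
# and the core.gpu package on the import path. Frame/BoundingBox fields mirror
# how the other stages construct them.
if __name__ == "__main__":
    frame = Frame(sequence=0, chunk_id=0, timestamp=0.0,
                  image=np.zeros((120, 160, 3), dtype=np.uint8))
    boxes = {0: [BoundingBox(x=10, y=10, w=64, h=32, confidence=0.9, label="banner")]}
    crops = preprocess_regions([frame], boxes, do_contrast=True)
    print(sorted(crops))  # keys follow "{frame_seq}_{box_idx}" → ['0_0']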

View File

@@ -0,0 +1,31 @@
"""
Stage registry — registers all built-in stages.
Split by category:
preprocessing.py — extract_frames, filter_scenes
cv_analysis.py — detect_edges (+ future: detect_contours, detect_color, merge_regions)
detection.py — detect_objects, run_ocr
resolution.py — match_brands
escalation.py — escalate_vlm, escalate_cloud
output.py — compile_report
_serializers.py — shared serialization helpers
"""
from . import preprocessing
from . import cv_analysis
from . import detection
from . import resolution
from . import escalation
from . import output
def register_all():
preprocessing.register()
cv_analysis.register()
detection.register()
resolution.register()
escalation.register()
output.register()
register_all()

View File

@@ -0,0 +1,25 @@
"""
Re-export serializers from core/schema/serializers/.
Stage registry modules import from here for convenience.
All serialization logic lives in core/schema/serializers/.
"""
from core.schema.serializers._common import (
safe_construct,
serialize_dataclass,
serialize_dataclass_list,
)
from core.schema.serializers.pipeline import (
serialize_frame_meta,
serialize_frames_with_upload as serialize_frames,
deserialize_frames_with_download as deserialize_frames,
serialize_text_candidate,
serialize_text_candidates,
deserialize_text_candidate,
deserialize_text_candidates,
deserialize_bounding_box,
deserialize_brand_detection,
deserialize_pipeline_stats,
deserialize_detection_report,
)

View File

@@ -0,0 +1,44 @@
"""Registration for CV analysis stages: edge detection."""
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import serialize_dataclass_list, deserialize_bounding_box
def _ser_regions(state: dict, job_id: str) -> dict:
regions = state.get("edge_regions_by_frame", {})
serialized = {
str(seq): serialize_dataclass_list(bl) for seq, bl in regions.items()
}
return {"edge_regions_by_frame": serialized}
def _deser_regions(data: dict, job_id: str) -> dict:
regions = {}
for seq_str, box_dicts in data.get("edge_regions_by_frame", {}).items():
regions[int(seq_str)] = [deserialize_bounding_box(b) for b in box_dicts]
return {"edge_regions_by_frame": regions}
def register():
edge_detection = StageDefinition(
name="detect_edges",
label="Edge Detection",
description="Canny + HoughLinesP — find horizontal line pairs (hoarding boundaries)",
category="cv_analysis",
io=StageIO(
reads=["filtered_frames"],
writes=["edge_regions_by_frame"],
),
config_fields=[
StageConfigField(name="enabled", type="bool", default=True, description="Enable region analysis"),
StageConfigField(name="edge_canny_low", type="int", default=50, description="Canny low threshold", min=0, max=255),
StageConfigField(name="edge_canny_high", type="int", default=150, description="Canny high threshold", min=0, max=255),
StageConfigField(name="edge_hough_threshold", type="int", default=80, description="Hough accumulator threshold", min=1, max=500),
StageConfigField(name="edge_hough_min_length", type="int", default=100, description="Min line length (px)", min=10, max=2000),
StageConfigField(name="edge_hough_max_gap", type="int", default=10, description="Max line gap (px)", min=1, max=100),
StageConfigField(name="edge_pair_max_distance", type="int", default=200, description="Max distance between line pair (px)", min=10, max=500),
StageConfigField(name="edge_pair_min_distance", type="int", default=15, description="Min distance between line pair (px)", min=5, max=200),
],
)
register_stage(edge_detection, serialize_fn=_ser_regions, deserialize_fn=_deser_regions)
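# --- Sketch: checkpoint round trip for edge regions (run with `python -m`
# because of the relative import above). Keys are stringified for JSON and
# restored to int; assumes serialize_dataclass_list maps an empty list to [].
if __name__ == "__main__":
    blob = _ser_regions({"edge_regions_by_frame": {3: []}}, job_id="demo")
    print(blob)                                 # {'edge_regions_by_frame': {'3': []}}
    print(_deser_regions(blob, job_id="demo"))  # {'edge_regions_by_frame': {3: []}}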

View File

@@ -0,0 +1,60 @@
"""Registration for detection stages: YOLO, OCR."""
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import (
serialize_dataclass_list,
serialize_text_candidates,
deserialize_bounding_box,
)
def _ser_detect(state: dict, job_id: str) -> dict:
boxes = state.get("boxes_by_frame", {})
serialized = {str(seq): serialize_dataclass_list(bl) for seq, bl in boxes.items()}
return {"boxes_by_frame": serialized}
def _deser_detect(data: dict, job_id: str) -> dict:
boxes = {}
for seq_str, box_dicts in data.get("boxes_by_frame", {}).items():
boxes[int(seq_str)] = [deserialize_bounding_box(b) for b in box_dicts]
return {"boxes_by_frame": boxes}
def _ser_ocr(state: dict, job_id: str) -> dict:
candidates = state.get("text_candidates", [])
return {"text_candidates": serialize_text_candidates(candidates)}
def _deser_ocr(data: dict, job_id: str) -> dict:
return {"_text_candidates_raw": data["text_candidates"]}
def register():
yolo = StageDefinition(
name="detect_objects",
label="Object Detection",
description="YOLO object detection on filtered frames",
category="detection",
io=StageIO(reads=["filtered_frames"], writes=["boxes_by_frame"]),
config_fields=[
StageConfigField(name="model_name", type="str", default="yolov8n.pt", description="YOLO model file"),
StageConfigField(name="confidence_threshold", type="float", default=0.3, description="Min detection confidence", min=0.0, max=1.0),
StageConfigField(name="target_classes", type="list[str]", default=[], description="YOLO classes to detect (empty = all)"),
],
)
register_stage(yolo, serialize_fn=_ser_detect, deserialize_fn=_deser_detect)
ocr = StageDefinition(
name="run_ocr",
label="OCR",
description="Extract text from detected regions",
category="detection",
io=StageIO(reads=["filtered_frames", "boxes_by_frame"], writes=["text_candidates"]),
config_fields=[
StageConfigField(name="languages", type="list[str]", default=["en"], description="OCR languages"),
StageConfigField(name="min_confidence", type="float", default=0.5, description="Min OCR confidence", min=0.0, max=1.0),
],
)
register_stage(ocr, serialize_fn=_ser_ocr, deserialize_fn=_deser_ocr)

View File

@@ -0,0 +1,60 @@
"""Registration for escalation stages: local VLM, cloud LLM."""
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import (
serialize_dataclass_list,
serialize_text_candidates,
deserialize_brand_detection,
)
def _ser_escalation(state: dict, job_id: str) -> dict:
matched = state.get("detections", [])
unresolved = state.get("unresolved_candidates", [])
return {
"detections": serialize_dataclass_list(matched),
"unresolved_candidates": serialize_text_candidates(unresolved),
}
def _deser_escalation(data: dict, job_id: str) -> dict:
detections = [deserialize_brand_detection(d) for d in data.get("detections", [])]
return {
"detections": detections,
"_unresolved_raw": data.get("unresolved_candidates", []),
}
def register():
vlm = StageDefinition(
name="escalate_vlm",
label="Local VLM",
description="Process unresolved crops with moondream2",
category="escalation",
io=StageIO(
reads=["unresolved_candidates"],
writes=["detections", "unresolved_candidates"],
optional_reads=["source_asset_id"],
),
config_fields=[
StageConfigField(name="min_confidence", type="float", default=0.5, description="Min VLM confidence", min=0.0, max=1.0),
],
)
register_stage(vlm, serialize_fn=_ser_escalation, deserialize_fn=_deser_escalation)
cloud = StageDefinition(
name="escalate_cloud",
label="Cloud LLM",
description="Escalate remaining crops to cloud provider",
category="escalation",
io=StageIO(
reads=["unresolved_candidates"],
writes=["detections"],
optional_reads=["source_asset_id"],
),
config_fields=[
StageConfigField(name="min_confidence", type="float", default=0.4, description="Min cloud confidence", min=0.0, max=1.0),
],
)
register_stage(cloud, serialize_fn=_ser_escalation, deserialize_fn=_deser_escalation)

View File

@@ -0,0 +1,30 @@
"""Registration for output stages: report compilation."""
from core.detect.stages.models import StageDefinition, StageIO
from core.detect.stages.base import register_stage
from ._serializers import serialize_dataclass, deserialize_detection_report
def _ser_report(state: dict, job_id: str) -> dict:
report = state.get("report")
if report is None:
return {"report": None}
return {"report": serialize_dataclass(report)}
def _deser_report(data: dict, job_id: str) -> dict:
raw = data.get("report")
if raw is None:
return {"report": None}
return {"report": deserialize_detection_report(raw)}
def register():
report = StageDefinition(
name="compile_report",
label="Report",
description="Merge detections and compile final report",
category="output",
io=StageIO(reads=["detections"], writes=["report"]),
config_fields=[],
)
register_stage(report, serialize_fn=_ser_report, deserialize_fn=_deser_report)

View File

@@ -0,0 +1,82 @@
"""Registration for preprocessing stages: frame extraction, scene filter, image preprocessing."""
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import serialize_frames, deserialize_frames
def _ser_extract(state: dict, job_id: str) -> dict:
frames = state.get("frames", [])
meta, manifest = serialize_frames(frames, job_id)
return {"frames_meta": meta, "frames_manifest": manifest}
def _deser_extract(data: dict, job_id: str) -> dict:
frames = deserialize_frames(data["frames_meta"], data["frames_manifest"], job_id)
return {"frames": frames}
def _ser_filter(state: dict, job_id: str) -> dict:
filtered = state.get("filtered_frames", [])
seqs = [f.sequence for f in filtered]
return {"filtered_frame_sequences": seqs}
def _deser_filter(data: dict, job_id: str) -> dict:
return {"_filtered_sequences": data["filtered_frame_sequences"]}
def _ser_preprocess(state: dict, job_id: str) -> dict:
# Preprocessed crops are numpy arrays — regenerable from frames + boxes + config
crops = state.get("preprocessed_crops", {})
return {"crop_keys": list(crops.keys()), "count": len(crops)}
def _deser_preprocess(data: dict, job_id: str) -> dict:
# Crops are regenerable — no need to restore from checkpoint
return {"preprocessed_crops": {}}
def register():
extract = StageDefinition(
name="extract_frames",
label="Frame Extraction",
description="Extract frames from video at configurable FPS",
category="preprocessing",
io=StageIO(reads=["video_path"], writes=["frames"]),
config_fields=[
StageConfigField(name="fps", type="float", default=2.0, description="Frames per second", min=0.1, max=30.0),
StageConfigField(name="max_frames", type="int", default=500, description="Maximum frames to extract", min=1, max=10000),
],
)
register_stage(extract, serialize_fn=_ser_extract, deserialize_fn=_deser_extract)
scene_filter = StageDefinition(
name="filter_scenes",
label="Scene Filter",
description="Deduplicate similar frames using perceptual hashing",
category="preprocessing",
io=StageIO(reads=["frames"], writes=["filtered_frames"]),
config_fields=[
StageConfigField(name="hamming_threshold", type="int", default=8, description="Hamming distance threshold", min=0, max=64),
StageConfigField(name="enabled", type="bool", default=True, description="Enable scene filtering"),
],
)
register_stage(scene_filter, serialize_fn=_ser_filter, deserialize_fn=_deser_filter)
preprocess = StageDefinition(
name="preprocess",
label="Preprocess",
description="Image preprocessing on detected regions before OCR",
category="preprocessing",
io=StageIO(
reads=["filtered_frames", "boxes_by_frame"],
writes=["preprocessed_crops"],
),
config_fields=[
StageConfigField(name="contrast", type="bool", default=True, description="CLAHE contrast enhancement"),
StageConfigField(name="deskew", type="bool", default=False, description="Correct slight rotation"),
StageConfigField(name="binarize", type="bool", default=False, description="Otsu binarization"),
],
)
register_stage(preprocess, serialize_fn=_ser_preprocess, deserialize_fn=_deser_preprocess)

View File

@@ -0,0 +1,44 @@
"""Registration for resolution stages: brand resolver."""
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import (
serialize_dataclass_list,
serialize_text_candidates,
deserialize_brand_detection,
)
def _ser_brands(state: dict, job_id: str) -> dict:
matched = state.get("detections", [])
unresolved = state.get("unresolved_candidates", [])
return {
"detections": serialize_dataclass_list(matched),
"unresolved_candidates": serialize_text_candidates(unresolved),
}
def _deser_brands(data: dict, job_id: str) -> dict:
detections = [deserialize_brand_detection(d) for d in data.get("detections", [])]
return {
"detections": detections,
"_unresolved_raw": data.get("unresolved_candidates", []),
}
def register():
resolver = StageDefinition(
name="match_brands",
label="Brand Resolver",
description="Match OCR text against known brands (session + global DB)",
category="resolution",
io=StageIO(
reads=["text_candidates"],
writes=["detections", "unresolved_candidates"],
optional_reads=["session_brands", "source_asset_id"],
),
config_fields=[
StageConfigField(name="fuzzy_threshold", type="int", default=75, description="Fuzzy match threshold", min=0, max=100),
],
)
register_stage(resolver, serialize_fn=_ser_brands, deserialize_fn=_deser_brands)

View File

@@ -0,0 +1,86 @@
"""
Stage 2 — Scene Filter
Removes near-duplicate frames using perceptual hashing (pHash).
Frames with a Hamming distance below the threshold are considered
duplicates and dropped. This dramatically reduces work for downstream
CV stages without losing unique visual content.
"""
from __future__ import annotations
import time
import imagehash
from PIL import Image
from core.detect import emit
from core.detect.models import Frame
from core.detect.stages.models import SceneFilterConfig
def _compute_hashes(frames: list[Frame]) -> list[imagehash.ImageHash]:
"""Compute perceptual hashes for all frames."""
hashes = []
for f in frames:
img = Image.fromarray(f.image)
h = imagehash.phash(img)
f.perceptual_hash = str(h)
hashes.append(h)
return hashes
def _dedup(frames: list[Frame], hashes: list[imagehash.ImageHash], threshold: int) -> list[Frame]:
"""Greedy dedup: keep a frame if it's sufficiently different from all kept frames."""
kept = [frames[0]]
kept_hashes = [hashes[0]]
for i in range(1, len(frames)):
is_duplicate = any(hashes[i] - kh < threshold for kh in kept_hashes)
if not is_duplicate:
kept.append(frames[i])
kept_hashes.append(hashes[i])
return kept
def scene_filter(
frames: list[Frame],
config: SceneFilterConfig,
job_id: str | None = None,
) -> list[Frame]:
"""
Filter near-duplicate frames based on perceptual hash distance.
Keeps the first frame in each group of similar frames.
Returns a new list — does not mutate the input.
"""
if not config.enabled:
emit.log(job_id, "SceneFilter", "INFO", "Scene filter disabled, passing all frames through")
return frames
if not frames:
return []
emit.log(job_id, "SceneFilter", "INFO",
f"Filtering {len(frames)} frames (hamming_threshold={config.hamming_threshold})")
t0 = time.monotonic()
hashes = _compute_hashes(frames)
hash_ms = (time.monotonic() - t0) * 1000
emit.log(job_id, "SceneFilter", "DEBUG",
f"Computed {len(hashes)} perceptual hashes in {hash_ms:.0f}ms ({hash_ms/max(len(hashes),1):.1f}ms/frame)")
t0 = time.monotonic()
kept = _dedup(frames, hashes, config.hamming_threshold)
dedup_ms = (time.monotonic() - t0) * 1000
emit.log(job_id, "SceneFilter", "DEBUG", f"Dedup pass: {dedup_ms:.0f}ms")
dropped = len(frames) - len(kept)
pct = (dropped / len(frames) * 100) if frames else 0
emit.log(job_id, "SceneFilter", "INFO",
f"Kept {len(kept)} frames, dropped {dropped} ({pct:.0f}% reduction)")
emit.stats(job_id, frames_extracted=len(frames), frames_after_scene_filter=len(kept))
return kept
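# --- Worked example (sketch): greedy dedup on three synthetic frames. The two
# identical black frames hash to the same value, so the second one is dropped;
# the random-noise frame is far from both and survives.
if __name__ == "__main__":
    import numpy as np
    rng = np.random.default_rng(0)
    black = np.zeros((64, 64, 3), dtype=np.uint8)
    noise = rng.integers(0, 256, (64, 64, 3), dtype=np.uint8)
    demo_frames = [Frame(sequence=i, chunk_id=0, timestamp=float(i), image=img)
                   for i, img in enumerate([black, black, noise])]
    kept = scene_filter(demo_frames, SceneFilterConfig(hamming_threshold=8))
    print([f.sequence for f in kept])  # expected: [0, 2]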

View File

@@ -0,0 +1,201 @@
"""
Stage 7 — Cloud LLM escalation
Last resort for crops the local VLM couldn't resolve.
Provider-agnostic — switch via CLOUD_LLM_PROVIDER env var.
Each provider has its own file under detect/providers/.
Tracks token usage and cost.
"""
from __future__ import annotations
import base64
import io
import logging
import os
import time
import numpy as np
from PIL import Image
from core.detect import emit
from core.detect.models import BrandDetection, PipelineStats, TextCandidate
from core.detect.models import CropContext
from core.detect.providers import get_provider, has_api_key
logger = logging.getLogger(__name__)
ESTIMATED_TOKENS_PER_CROP = 500
def _register_discovered_brand(brand: str, source_asset_id: str | None,
timestamp: float, confidence: float):
"""Register a cloud-confirmed brand in the DB."""
try:
from core.detect.stages.brand_resolver import _register_brand, _record_sighting
brand_id = _register_brand(brand, "cloud_llm")
if brand_id and source_asset_id:
_record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, "cloud_llm")
except Exception as e:
logger.debug("Failed to register brand %s: %s", brand, e)
def _encode_crop(crop: np.ndarray) -> str:
img = Image.fromarray(crop)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=85)
return base64.b64encode(buf.getvalue()).decode()
def _crop_image(candidate: TextCandidate) -> np.ndarray:
frame = candidate.frame
box = candidate.bbox
h, w = frame.image.shape[:2]
x1 = max(0, box.x)
y1 = max(0, box.y)
x2 = min(w, box.x + box.w)
y2 = min(h, box.y + box.h)
return frame.image[y1:y2, x1:x2]
def _parse_response(answer: str, total_tokens: int) -> dict:
"""Parse LLM free-text response into structured output."""
parts = [p.strip() for p in answer.split(",", 2)]
brand = parts[0] if parts else ""
confidence = 0.5
reasoning = answer
if len(parts) >= 2:
try:
confidence = float(parts[1])
confidence = max(0.0, min(1.0, confidence))
except ValueError:
pass
if len(parts) >= 3:
reasoning = parts[2]
return {
"brand": brand,
"confidence": confidence,
"reasoning": reasoning,
"tokens": total_tokens or ESTIMATED_TOKENS_PER_CROP,
}
def _call_cloud_api(image_b64: str, prompt: str) -> dict:
"""Route to the configured provider and parse the response."""
provider = get_provider()
result = provider.call(image_b64, prompt)
return _parse_response(result.answer, result.total_tokens)
def escalate_cloud(
candidates: list[TextCandidate],
vlm_prompt_fn,
stats: PipelineStats,
min_confidence: float = 0.4,
content_type: str = "",
source_asset_id: str | None = None,
job_id: str | None = None,
) -> list[BrandDetection]:
"""
Send remaining unresolved crops to cloud LLM.
Provider is selected via CLOUD_LLM_PROVIDER env var (groq, gemini, openai).
Updates stats with call count and cost.
"""
if not candidates:
return []
if os.environ.get("SKIP_CLOUD", "").strip() == "1":
emit.log(job_id, "CloudLLM", "INFO",
f"SKIP_CLOUD=1, skipping {len(candidates)} crops")
return []
if not has_api_key():
emit.log(job_id, "CloudLLM", "WARNING",
f"No API key set for cloud provider, skipping {len(candidates)} crops")
return []
provider = get_provider()
emit.log(job_id, "CloudLLM", "INFO",
f"Escalating {len(candidates)} crops to {provider.name}")
matched: list[BrandDetection] = []
total_cost = 0.0
for i, candidate in enumerate(candidates):
crop = _crop_image(candidate)
if crop.size == 0:
continue
crop_context = CropContext(
image=b"",
surrounding_text=candidate.text,
position_hint=f"frame {candidate.frame.sequence}",
)
prompt = vlm_prompt_fn(crop_context)
image_b64 = _encode_crop(crop)
t0 = time.monotonic()
try:
result = _call_cloud_api(image_b64, prompt)
except Exception as e:
call_ms = (time.monotonic() - t0) * 1000
emit.log(job_id, "CloudLLM", "DEBUG",
f"[{i+1}/{len(candidates)}] FAILED '{candidate.text[:30]}': {e} ({call_ms:.0f}ms)")
continue
call_ms = (time.monotonic() - t0) * 1000
stats.cloud_llm_calls += 1
model_info = provider.models.get(provider.model)
cost_per_token = model_info.cost_per_input_token if model_info else 0.00001
call_cost = result["tokens"] * cost_per_token
total_cost += call_cost
brand = result["brand"]
confidence = result["confidence"]
emit.log(job_id, "CloudLLM", "DEBUG",
f"[{i+1}/{len(candidates)}] '{candidate.text[:30]}'"
f"{'' + brand if brand else ''} "
f"(conf={confidence:.2f}, {result['tokens']}tok, ${call_cost:.4f}, {call_ms:.0f}ms)")
if brand and confidence >= min_confidence:
detection = BrandDetection(
brand=brand,
timestamp=candidate.frame.timestamp,
duration=0.5,
confidence=confidence,
source="cloud_llm",
bbox=candidate.bbox,
frame_ref=candidate.frame.sequence,
content_type=content_type,
)
matched.append(detection)
emit.detection(
job_id,
brand=brand,
confidence=confidence,
source="cloud_llm",
timestamp=candidate.frame.timestamp,
content_type=content_type,
frame_ref=candidate.frame.sequence,
)
# Register newly discovered brand in DB
_register_discovered_brand(brand, source_asset_id,
candidate.frame.timestamp, confidence)
stats.estimated_cloud_cost_usd += total_cost
stats.regions_escalated_to_cloud_llm = len(candidates)
emit.log(job_id, "CloudLLM", "INFO",
f"Cloud resolved {len(matched)}/{len(candidates)}"
f"cost ${total_cost:.4f} ({stats.cloud_llm_calls} calls total)")
return matched
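# --- Illustrative check (sketch): the "brand, confidence, reasoning" reply
# format that _parse_response expects from the provider.
if __name__ == "__main__":
    print(_parse_response("Coca-Cola, 0.85, red script logo on hoarding", total_tokens=412))
    # {'brand': 'Coca-Cola', 'confidence': 0.85, 'reasoning': 'red script logo on hoarding', 'tokens': 412}
    print(_parse_response("unreadable", total_tokens=0))
    # single-part answer: confidence falls back to 0.5, tokens to ESTIMATED_TOKENS_PER_CROP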

View File

@@ -0,0 +1,157 @@
"""
Stage 6 — Local VLM escalation (moondream2)
Processes unresolved text candidates by sending crop images + prompt
to the local VLM on the inference server. Produces BrandDetection
objects for crops the VLM can identify.
"""
from __future__ import annotations
import logging
import os
import time
import numpy as np
from core.detect import emit
from core.detect.models import BrandDetection, TextCandidate
from core.detect.models import CropContext
logger = logging.getLogger(__name__)
def _register_discovered_brand(brand: str, source_asset_id: str | None,
timestamp: float, confidence: float, source: str):
"""Register a VLM-confirmed brand in the DB."""
try:
from core.detect.stages.brand_resolver import _register_brand, _record_sighting
brand_id = _register_brand(brand, source)
if brand_id and source_asset_id:
_record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, source)
except Exception as e:
logger.debug("Failed to register brand %s: %s", brand, e)
def _crop_image(candidate: TextCandidate) -> np.ndarray:
frame = candidate.frame
box = candidate.bbox
h, w = frame.image.shape[:2]
x1 = max(0, box.x)
y1 = max(0, box.y)
x2 = min(w, box.x + box.w)
y2 = min(h, box.y + box.h)
return frame.image[y1:y2, x1:x2]
def escalate_vlm(
candidates: list[TextCandidate],
vlm_prompt_fn,
inference_url: str | None = None,
min_confidence: float = 0.5,
content_type: str = "",
source_asset_id: str | None = None,
job_id: str | None = None,
) -> tuple[list[BrandDetection], list[TextCandidate]]:
"""
Send unresolved crops to local VLM for brand identification.
Returns:
- matched: BrandDetections the VLM confirmed
- still_unresolved: candidates the VLM couldn't resolve (→ cloud escalation)
"""
if not candidates:
return [], []
if os.environ.get("SKIP_VLM", "").strip() == "1":
emit.log(job_id, "VLMLocal", "INFO",
f"SKIP_VLM=1, skipping {len(candidates)} crops")
return [], candidates
emit.log(job_id, "VLMLocal", "INFO",
f"Processing {len(candidates)} unresolved crops with moondream2")
matched: list[BrandDetection] = []
still_unresolved: list[TextCandidate] = []
if inference_url:
from core.detect.inference import InferenceClient
from core.detect.emit import _run_log_level
client = InferenceClient(base_url=inference_url, job_id=job_id or "", log_level=_run_log_level)
for i, candidate in enumerate(candidates):
crop = _crop_image(candidate)
if crop.size == 0:
still_unresolved.append(candidate)
continue
crop_context = CropContext(
image=b"", # not used for prompt generation
surrounding_text=candidate.text,
position_hint=f"frame {candidate.frame.sequence}",
)
prompt = vlm_prompt_fn(crop_context)
t0 = time.monotonic()
try:
if inference_url:
result = client.vlm(image=crop, prompt=prompt)
brand = result.brand
confidence = result.confidence
reasoning = result.reasoning
else:
brand, confidence, reasoning = _vlm_local(crop, prompt)
except Exception as e:
vlm_ms = (time.monotonic() - t0) * 1000
emit.log(job_id, "VLMLocal", "DEBUG",
f"[{i+1}/{len(candidates)}] FAILED '{candidate.text[:30]}': {e} ({vlm_ms:.0f}ms)")
still_unresolved.append(candidate)
continue
vlm_ms = (time.monotonic() - t0) * 1000
emit.log(job_id, "VLMLocal", "DEBUG",
f"[{i+1}/{len(candidates)}] '{candidate.text[:30]}'"
f"{'' + brand if brand else '✗ unresolved'} "
f"(conf={confidence:.2f}, {vlm_ms:.0f}ms)")
if brand and confidence >= min_confidence:
detection = BrandDetection(
brand=brand,
timestamp=candidate.frame.timestamp,
duration=0.5,
confidence=confidence,
source="local_vlm",
bbox=candidate.bbox,
frame_ref=candidate.frame.sequence,
content_type=content_type,
)
matched.append(detection)
emit.detection(
job_id,
brand=brand,
confidence=confidence,
source="local_vlm",
timestamp=candidate.frame.timestamp,
content_type=content_type,
frame_ref=candidate.frame.sequence,
)
# Register newly discovered brand in DB
_register_discovered_brand(brand, source_asset_id,
candidate.frame.timestamp, confidence, "local_vlm")
logger.debug("VLM matched: %s (%.2f) — %s", brand, confidence, reasoning)
else:
still_unresolved.append(candidate)
emit.log(job_id, "VLMLocal", "INFO",
f"VLM resolved {len(matched)}, unresolved {len(still_unresolved)} → cloud")
return matched, still_unresolved
def _vlm_local(crop: np.ndarray, prompt: str) -> tuple[str, float, str]:
"""Run moondream2 in-process (single-box mode)."""
from core.gpu.models.vlm import query
result = query(crop, prompt)
return result["brand"], result["confidence"], result["reasoning"]

View File

@@ -0,0 +1,138 @@
"""
Stage 3 — YOLO Object Detection
Detects regions of interest (logos, text, banners) in frames.
Two modes:
- Remote: calls inference server over HTTP (GPU on another machine)
- Local: imports ultralytics directly (GPU on same machine)
Emits frame_update events with bounding boxes for the UI.
"""
from __future__ import annotations
import base64
import io
import logging
import time
from PIL import Image
from core.detect import emit
from core.detect.models import BoundingBox, Frame
from core.detect.stages.models import DetectionConfig
logger = logging.getLogger(__name__)
def _frame_to_b64(frame: Frame) -> str:
"""Encode frame as base64 JPEG for SSE frame_update events."""
img = Image.fromarray(frame.image)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=70)
return base64.b64encode(buf.getvalue()).decode()
def _detect_remote(frame: Frame, config: DetectionConfig, inference_url: str,
job_id: str = "", log_level: str = "INFO") -> list[BoundingBox]:
"""Call the inference server over HTTP."""
from core.detect.inference import InferenceClient
client = InferenceClient(base_url=inference_url, job_id=job_id, log_level=log_level)
results = client.detect(
image=frame.image,
model=config.model_name,
confidence=config.confidence_threshold,
target_classes=config.target_classes,
)
boxes = []
for r in results:
box = BoundingBox(
x=r.x, y=r.y, w=r.w, h=r.h,
confidence=r.confidence, label=r.label,
)
boxes.append(box)
return boxes
def _detect_local(frame: Frame, config: DetectionConfig) -> list[BoundingBox]:
"""Run YOLO in-process (requires ultralytics installed)."""
from ultralytics import YOLO
model = YOLO(config.model_name)
results = model(frame.image, conf=config.confidence_threshold, verbose=False)
boxes = []
for r in results:
for det in r.boxes:
x1, y1, x2, y2 = det.xyxy[0].tolist()
label = r.names[int(det.cls[0])]
if config.target_classes and label not in config.target_classes:
continue
box = BoundingBox(
x=int(x1), y=int(y1),
w=int(x2 - x1), h=int(y2 - y1),
confidence=float(det.conf[0]),
label=label,
)
boxes.append(box)
return boxes
def detect_objects(
frames: list[Frame],
config: DetectionConfig,
inference_url: str | None = None,
job_id: str | None = None,
) -> dict[int, list[BoundingBox]]:
"""
Run object detection on all frames.
If inference_url is provided, calls the remote GPU server.
Otherwise, imports ultralytics and runs locally.
Returns a dict mapping frame sequence → list of bounding boxes.
"""
mode = "remote" if inference_url else "local"
emit.log(job_id, "YOLODetector", "INFO",
f"Detecting objects in {len(frames)} frames "
f"(model={config.model_name}, conf={config.confidence_threshold}, mode={mode})")
all_boxes: dict[int, list[BoundingBox]] = {}
total_regions = 0
for i, frame in enumerate(frames):
t0 = time.monotonic()
if inference_url:
from core.detect.emit import _run_log_level
boxes = _detect_remote(frame, config, inference_url,
job_id=job_id or "", log_level=_run_log_level)
else:
boxes = _detect_local(frame, config)
det_ms = (time.monotonic() - t0) * 1000
all_boxes[frame.sequence] = boxes
total_regions += len(boxes)
emit.log(job_id, "YOLODetector", "DEBUG",
f"Frame {frame.sequence}: {len(boxes)} regions in {det_ms:.0f}ms"
f" [{', '.join(b.label for b in boxes)}]" if boxes else
f"Frame {frame.sequence}: 0 regions in {det_ms:.0f}ms")
if boxes and job_id:
box_dicts = [{"x": b.x, "y": b.y, "w": b.w, "h": b.h,
"confidence": b.confidence, "label": b.label}
for b in boxes]
emit.frame_update(
job_id,
frame_ref=frame.sequence,
timestamp=frame.timestamp,
jpeg_b64=_frame_to_b64(frame),
boxes=box_dicts,
)
emit.log(job_id, "YOLODetector", "INFO",
f"Detected {total_regions} regions across {len(frames)} frames")
emit.stats(job_id, regions_detected=total_regions)
return all_boxes
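# --- Illustrative usage (sketch, local mode): needs ultralytics installed and
# will download yolov8n.pt on first run. A blank frame should yield no boxes.
if __name__ == "__main__":
    import numpy as np
    blank = Frame(sequence=0, chunk_id=0, timestamp=0.0,
                  image=np.zeros((480, 640, 3), dtype=np.uint8))
    result = detect_objects([blank], DetectionConfig(target_classes=[]))
    print(result)  # likely {0: []}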

43
core/detect/state.py Normal file
View File

@@ -0,0 +1,43 @@
"""
LangGraph state definition for the detection pipeline.
This TypedDict flows through all graph nodes. Each node reads what
it needs and writes its outputs. LangGraph manages the state transitions.
"""
from __future__ import annotations
from typing import TypedDict
from core.detect.models import BoundingBox, BrandDetection, DetectionReport, Frame, PipelineStats, TextCandidate
class DetectState(TypedDict, total=False):
# Input
video_path: str
job_id: str
profile_name: str
source_asset_id: str # UUID of the source MediaAsset
# Stage outputs
frames: list[Frame]
filtered_frames: list[Frame]
field_masks: dict # {seq: np.ndarray} — pitch mask per frame
field_boundaries: dict # {seq: [(x,y), ...]} — pitch boundary per frame
field_coverage: dict # {seq: float} — pitch coverage ratio per frame
edge_regions_by_frame: dict[int, list[BoundingBox]]
boxes_by_frame: dict[int, list[BoundingBox]]
preprocessed_crops: dict # "{frame_seq}_{box_idx}" → np.ndarray
text_candidates: list[TextCandidate]
unresolved_candidates: list[TextCandidate]
detections: list[BrandDetection]
report: DetectionReport
# Session brands (accumulated during the run, persisted to DB)
session_brands: dict # {normalized_name: canonical_name}
# Running stats (updated by each stage)
stats: PipelineStats
# Config overrides for replay (merged into profile configs dict)
config_overrides: dict
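# --- Sketch: how a graph node consumes and extends this state. The node body is
# illustrative; real nodes live in the graph module and call the stage functions.
if __name__ == "__main__":
    def node_filter_scenes(state: DetectState) -> DetectState:
        frames = state.get("frames", [])
        # a real node would run the scene filter stage here
        return {"filtered_frames": frames}

    demo: DetectState = {"video_path": "sample.mp4", "job_id": "demo", "frames": []}
    demo.update(node_filter_scenes(demo))
    print(sorted(demo))  # ['filtered_frames', 'frames', 'job_id', 'video_path']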

131
core/detect/tracing.py Normal file
View File

@@ -0,0 +1,131 @@
"""
Langfuse tracing for the detection pipeline.
Provides span helpers that graph nodes use to record timing, frame counts,
and stage-level metadata. The Langfuse client is optional — if not configured
(no LANGFUSE_SECRET_KEY), tracing is a no-op.
Usage in graph nodes:
from core.detect.tracing import trace_node
def node_extract_frames(state):
with trace_node(state, "extract_frames") as span:
...
span.set_output({"frames": len(frames)})
return {...}
"""
from __future__ import annotations
import logging
import os
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
logger = logging.getLogger(__name__)
_client = None
_enabled: bool | None = None
def _get_client():
"""Lazy-init Langfuse client. Returns None if not configured."""
global _client, _enabled
if _enabled is False:
return None
if _client is not None:
return _client
secret = os.environ.get("LANGFUSE_SECRET_KEY", "")
if not secret:
_enabled = False
logger.info("Langfuse not configured (no LANGFUSE_SECRET_KEY), tracing disabled")
return None
try:
from langfuse import Langfuse
_client = Langfuse()
_enabled = True
logger.info("Langfuse tracing enabled")
return _client
except Exception as e:
_enabled = False
logger.warning("Langfuse init failed: %s — tracing disabled", e)
return None
@dataclass
class SpanContext:
"""Wraps a Langfuse span with convenience methods."""
_span: object | None = None
_start: float = field(default_factory=time.monotonic)
metadata: dict = field(default_factory=dict)
def set_output(self, output: dict) -> None:
self.metadata.update(output)
def set_error(self, error: str) -> None:
self.metadata["error"] = error
def _finish(self, status: str = "ok") -> None:
elapsed = time.monotonic() - self._start
self.metadata["duration_seconds"] = round(elapsed, 3)
self.metadata["status"] = status
if self._span is not None:
try:
self._span.update(
output=self.metadata,
level="ERROR" if status == "error" else "DEFAULT",
)
self._span.end()
except Exception as e:
logger.debug("Failed to end Langfuse span: %s", e)
@contextmanager
def trace_node(state: dict, node_name: str):
"""
Context manager that creates a Langfuse span for a pipeline node.
Usage:
with trace_node(state, "extract_frames") as span:
frames = do_work()
span.set_output({"frames": len(frames)})
"""
job_id = state.get("job_id", "unknown")
profile = state.get("profile_name", "")
client = _get_client()
span_obj = None
if client is not None:
try:
trace = client.trace(
name=f"detect:{job_id}",
session_id=job_id,
metadata={"profile": profile},
)
span_obj = trace.span(
name=node_name,
input={"job_id": job_id, "profile": profile},
)
except Exception as e:
logger.debug("Failed to create Langfuse span: %s", e)
ctx = SpanContext(_span=span_obj)
try:
yield ctx
ctx._finish("ok")
except Exception:
ctx._finish("error")
raise
def flush():
"""Flush pending Langfuse events. Call at pipeline end."""
if _client is not None:
try:
_client.flush()
except Exception as e:
logger.debug("Langfuse flush failed: %s", e)

View File

@@ -1,13 +1,6 @@
from .capabilities import get_decoders, get_encoders, get_formats
from .probe import ProbeResult, probe_file
from .transcode import TranscodeConfig, transcode
__all__ = [
"probe_file",
"ProbeResult",
"transcode",
"TranscodeConfig",
"get_encoders",
"get_decoders",
"get_formats",
]

View File

@@ -1,145 +0,0 @@
"""
FFmpeg capabilities - Discover available codecs and formats using ffmpeg-python.
"""
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, Dict, List
import ffmpeg
@dataclass
class Codec:
"""An FFmpeg encoder or decoder."""
name: str
description: str
type: str # 'video' or 'audio'
@dataclass
class Format:
"""An FFmpeg format (muxer/demuxer)."""
name: str
description: str
can_demux: bool
can_mux: bool
@lru_cache(maxsize=1)
def _get_ffmpeg_info() -> Dict[str, Any]:
"""Get FFmpeg capabilities info."""
# ffmpeg-python doesn't have a direct way to get codecs/formats
# but we can use probe on a dummy or parse -codecs output
# For now, return common codecs that are typically available
return {
"video_encoders": [
{"name": "libx264", "description": "H.264 / AVC"},
{"name": "libx265", "description": "H.265 / HEVC"},
{"name": "mpeg4", "description": "MPEG-4 Part 2"},
{"name": "libvpx", "description": "VP8"},
{"name": "libvpx-vp9", "description": "VP9"},
{"name": "h264_nvenc", "description": "NVIDIA NVENC H.264"},
{"name": "hevc_nvenc", "description": "NVIDIA NVENC H.265"},
{"name": "h264_vaapi", "description": "VAAPI H.264"},
{"name": "prores_ks", "description": "Apple ProRes"},
{"name": "dnxhd", "description": "Avid DNxHD/DNxHR"},
{"name": "copy", "description": "Stream copy (no encoding)"},
],
"audio_encoders": [
{"name": "aac", "description": "AAC"},
{"name": "libmp3lame", "description": "MP3"},
{"name": "libopus", "description": "Opus"},
{"name": "libvorbis", "description": "Vorbis"},
{"name": "pcm_s16le", "description": "PCM signed 16-bit little-endian"},
{"name": "flac", "description": "FLAC"},
{"name": "copy", "description": "Stream copy (no encoding)"},
],
"formats": [
{"name": "mp4", "description": "MP4", "can_demux": True, "can_mux": True},
{
"name": "mov",
"description": "QuickTime / MOV",
"can_demux": True,
"can_mux": True,
},
{
"name": "mkv",
"description": "Matroska",
"can_demux": True,
"can_mux": True,
},
{"name": "webm", "description": "WebM", "can_demux": True, "can_mux": True},
{"name": "avi", "description": "AVI", "can_demux": True, "can_mux": True},
{"name": "flv", "description": "FLV", "can_demux": True, "can_mux": True},
{
"name": "ts",
"description": "MPEG-TS",
"can_demux": True,
"can_mux": True,
},
{
"name": "mpegts",
"description": "MPEG-TS",
"can_demux": True,
"can_mux": True,
},
{"name": "hls", "description": "HLS", "can_demux": True, "can_mux": True},
],
}
def get_encoders() -> List[Codec]:
"""Get available encoders (video + audio)."""
info = _get_ffmpeg_info()
codecs = []
for c in info["video_encoders"]:
codecs.append(Codec(name=c["name"], description=c["description"], type="video"))
for c in info["audio_encoders"]:
codecs.append(Codec(name=c["name"], description=c["description"], type="audio"))
return codecs
def get_decoders() -> List[Codec]:
"""Get available decoders."""
# Most encoders can also decode
return get_encoders()
def get_formats() -> List[Format]:
"""Get available formats."""
info = _get_ffmpeg_info()
return [
Format(
name=f["name"],
description=f["description"],
can_demux=f["can_demux"],
can_mux=f["can_mux"],
)
for f in info["formats"]
]
def get_video_encoders() -> List[Codec]:
"""Get available video encoders."""
return [c for c in get_encoders() if c.type == "video"]
def get_audio_encoders() -> List[Codec]:
"""Get available audio encoders."""
return [c for c in get_encoders() if c.type == "audio"]
def get_muxers() -> List[Format]:
"""Get available output formats (muxers)."""
return [f for f in get_formats() if f.can_mux]
def get_demuxers() -> List[Format]:
"""Get available input formats (demuxers)."""
return [f for f in get_formats() if f.can_demux]

View File

@@ -1,225 +0,0 @@
"""
FFmpeg transcode module - Transcode media files using ffmpeg-python.
"""
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
import ffmpeg
@dataclass
class TranscodeConfig:
"""Configuration for a transcode operation."""
input_path: str
output_path: str
# Video
video_codec: str = "libx264"
video_bitrate: Optional[str] = None
video_crf: Optional[int] = None
video_preset: Optional[str] = None
resolution: Optional[str] = None
framerate: Optional[float] = None
# Audio
audio_codec: str = "aac"
audio_bitrate: Optional[str] = None
audio_channels: Optional[int] = None
audio_samplerate: Optional[int] = None
# Trimming
trim_start: Optional[float] = None
trim_end: Optional[float] = None
# Container
container: str = "mp4"
# Extra args (key-value pairs)
extra_args: List[str] = field(default_factory=list)
@property
def is_copy(self) -> bool:
"""Check if this is a stream copy (no transcoding)."""
return self.video_codec == "copy" and self.audio_codec == "copy"
def build_stream(config: TranscodeConfig):
"""
Build an ffmpeg-python stream from config.
Returns the stream object ready to run.
"""
# Input options
input_kwargs = {}
if config.trim_start is not None:
input_kwargs["ss"] = config.trim_start
stream = ffmpeg.input(config.input_path, **input_kwargs)
# Output options
output_kwargs = {
"vcodec": config.video_codec,
"acodec": config.audio_codec,
}
# Trimming duration
if config.trim_end is not None:
if config.trim_start is not None:
output_kwargs["t"] = config.trim_end - config.trim_start
else:
output_kwargs["t"] = config.trim_end
# Video options (skip if copy)
if config.video_codec != "copy":
if config.video_crf is not None:
output_kwargs["crf"] = config.video_crf
elif config.video_bitrate:
output_kwargs["video_bitrate"] = config.video_bitrate
if config.video_preset:
output_kwargs["preset"] = config.video_preset
if config.resolution:
output_kwargs["s"] = config.resolution
if config.framerate:
output_kwargs["r"] = config.framerate
# Audio options (skip if copy)
if config.audio_codec != "copy":
if config.audio_bitrate:
output_kwargs["audio_bitrate"] = config.audio_bitrate
if config.audio_channels:
output_kwargs["ac"] = config.audio_channels
if config.audio_samplerate:
output_kwargs["ar"] = config.audio_samplerate
# Parse extra args into kwargs
extra_kwargs = parse_extra_args(config.extra_args)
output_kwargs.update(extra_kwargs)
stream = ffmpeg.output(stream, config.output_path, **output_kwargs)
stream = ffmpeg.overwrite_output(stream)
return stream
def parse_extra_args(extra_args: List[str]) -> Dict[str, Any]:
"""
Parse extra args list into kwargs dict.
["-vtag", "xvid", "-pix_fmt", "yuv420p"] -> {"vtag": "xvid", "pix_fmt": "yuv420p"}
"""
kwargs = {}
i = 0
while i < len(extra_args):
key = extra_args[i].lstrip("-")
if i + 1 < len(extra_args) and not extra_args[i + 1].startswith("-"):
kwargs[key] = extra_args[i + 1]
i += 2
else:
# Flag without value
kwargs[key] = None
i += 1
return kwargs
def transcode(
config: TranscodeConfig,
duration: Optional[float] = None,
progress_callback: Optional[Callable[[float, Dict[str, Any]], None]] = None,
) -> bool:
"""
Transcode a media file.
Args:
config: Transcode configuration
duration: Total duration in seconds (for progress calculation, optional)
progress_callback: Called with (percent, details_dict) - requires duration
Returns:
True if successful
Raises:
ffmpeg.Error: If transcoding fails
"""
# Ensure output directory exists
Path(config.output_path).parent.mkdir(parents=True, exist_ok=True)
stream = build_stream(config)
if progress_callback and duration:
# Run with progress tracking using run_async
return _run_with_progress(stream, config, duration, progress_callback)
else:
# Run synchronously
ffmpeg.run(stream, capture_stdout=True, capture_stderr=True)
return True
def _run_with_progress(
stream,
config: TranscodeConfig,
duration: float,
progress_callback: Callable[[float, Dict[str, Any]], None],
) -> bool:
"""Run FFmpeg with progress tracking using run_async and stderr parsing."""
import re
# Calculate effective duration
effective_duration = duration
if config.trim_start and config.trim_end:
effective_duration = config.trim_end - config.trim_start
elif config.trim_end:
effective_duration = config.trim_end
elif config.trim_start:
effective_duration = duration - config.trim_start
# Run async to get process handle
process = ffmpeg.run_async(stream, pipe_stdout=True, pipe_stderr=True)
# Parse stderr for progress (time=HH:MM:SS.ms pattern)
time_pattern = re.compile(r"time=(\d+):(\d+):(\d+)\.(\d+)")
while True:
line = process.stderr.readline()
if not line:
break
line = line.decode("utf-8", errors="ignore")
match = time_pattern.search(line)
if match:
hours = int(match.group(1))
minutes = int(match.group(2))
seconds = int(match.group(3))
ms = int(match.group(4))
current_time = hours * 3600 + minutes * 60 + seconds + ms / 100
percent = min(100.0, (current_time / effective_duration) * 100)
progress_callback(
percent,
{
"time": current_time,
"percent": percent,
},
)
# Wait for completion
process.wait()
if process.returncode != 0:
raise ffmpeg.Error(
"ffmpeg", stdout=process.stdout.read(), stderr=process.stderr.read()
)
# Final callback
progress_callback(
100.0, {"time": effective_duration, "percent": 100.0, "done": True}
)
return True

18
core/gpu/.env.template Normal file
View File

@@ -0,0 +1,18 @@
# Inference server configuration
HOST=0.0.0.0
PORT=8000
# VRAM management
VRAM_BUDGET_MB=10240
STRATEGY=sequential # sequential | concurrent | auto
# Model defaults
YOLO_MODEL=yolov8n.pt
YOLO_CONFIDENCE=0.3
# OCR
OCR_LANGUAGES=en,es
OCR_MIN_CONFIDENCE=0.5
# Device
DEVICE=auto # auto | cpu | cuda | cuda:0

18
core/gpu/Dockerfile Normal file
View File

@@ -0,0 +1,18 @@
FROM python:3.11-slim
RUN pip install --no-cache-dir uv
RUN apt-get update && apt-get install -y \
libgl1 libglib2.0-0 \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /app
COPY requirements.txt .
RUN uv pip install --system --no-cache -r requirements.txt
COPY . .
EXPOSE 8000
CMD ["python", "server.py"]

0
core/gpu/__init__.py Normal file
View File

39
core/gpu/config.py Normal file
View File

@@ -0,0 +1,39 @@
"""
Runtime config — loaded from env, mutable via API.
The UI config panel is just a visual editor for these same values.
"""
from __future__ import annotations
import os
_config = {
"device": os.environ.get("DEVICE", "auto"),
"yolo_model": os.environ.get("YOLO_MODEL", "yolov8n.pt"),
"yolo_confidence": float(os.environ.get("YOLO_CONFIDENCE", "0.3")),
"vram_budget_mb": int(os.environ.get("VRAM_BUDGET_MB", "10240")),
"strategy": os.environ.get("STRATEGY", "sequential"),
"ocr_languages": os.environ.get("OCR_LANGUAGES", "en").split(","),
"ocr_min_confidence": float(os.environ.get("OCR_MIN_CONFIDENCE", "0.5")),
}
def get_config() -> dict:
return _config
def update_config(changes: dict) -> dict:
_config.update(changes)
return _config
def get_device() -> str:
device = _config["device"]
if device != "auto":
return device
try:
import torch
return "cuda" if torch.cuda.is_available() else "cpu"
except ImportError:
return "cpu"

52
core/gpu/emit.py Normal file
View File

@@ -0,0 +1,52 @@
"""
Lightweight event emitter for the GPU inference server.
Pushes debug logs to the same Redis stream as the pipeline orchestrator,
so GPU-side details (model load, VRAM, inference timing) appear in the
same log panel.
Only active when the request includes X-Job-Id header.
No dependency on the detect package.
"""
from __future__ import annotations
import json
import os
from datetime import datetime, timezone
import redis
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
EVENTS_PREFIX = "detect_events"
_LEVEL_ORDER = {"DEBUG": 0, "INFO": 1, "WARN": 2, "ERROR": 3}
_redis_client = None
def _get_redis():
global _redis_client
if _redis_client is None:
_redis_client = redis.from_url(REDIS_URL, decode_responses=True)
return _redis_client
def log(job_id: str, stage: str, level: str, msg: str, log_level: str = "INFO"):
"""Push a log event to Redis if the level meets the threshold."""
if not job_id:
return
if _LEVEL_ORDER.get(level.upper(), 1) < _LEVEL_ORDER.get(log_level.upper(), 1):
return
r = _get_redis()
key = f"{EVENTS_PREFIX}:{job_id}"
event = json.dumps({
"event": "log",
"level": level,
"stage": stage,
"msg": msg,
"ts": datetime.now(timezone.utc).isoformat(),
})
r.rpush(key, event)
r.expire(key, 3600)
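A minimal consumer sketch for these events; the actual poller in the main app may differ. Events are plain JSON strings pushed onto detect_events:<job_id> with a one-hour TTL, so reading from a cursor with LRANGE is enough to tail them.
import json
import redis

def read_events(job_id: str, cursor: int = 0, url: str = "redis://localhost:6379/0"):
    # Returns (new events, next cursor); key layout matches emit.py above.
    r = redis.from_url(url, decode_responses=True)
    raw = r.lrange(f"detect_events:{job_id}", cursor, -1)
    return [json.loads(item) for item in raw], cursor + len(raw)

events, cursor = read_events("job-123")   # "job-123" is a placeholder job id
for ev in events:
    print(ev["level"], ev["stage"], ev["msg"])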

View File

@@ -0,0 +1,6 @@
# GPU models — standalone container imports.
# When running as a container (cd gpu && python server.py), bare imports work.
# When imported from the main app (core.gpu.models.preprocess), only
# individual modules should be imported directly, not this __init__.
#
# The server.py imports detect/ocr/vlm directly, not through this file.

View File

@@ -0,0 +1 @@
"""CV operations — pure OpenCV, no ML models."""

258
core/gpu/models/cv/edges.py Normal file
View File

@@ -0,0 +1,258 @@
"""
Edge detection — Canny + HoughLinesP → parallel line pairs → bounding boxes.
Finds horizontal line pairs with consistent spacing, which correspond to
the top and bottom edges of advertising hoardings.
"""
from __future__ import annotations
import base64
import io
import cv2
import numpy as np
def detect_edges(
image: np.ndarray,
canny_low: int = 50,
canny_high: int = 150,
hough_threshold: int = 80,
hough_min_length: int = 100,
hough_max_gap: int = 10,
pair_max_distance: int = 200,
pair_min_distance: int = 15,
) -> list[dict]:
"""
Find horizontal line pairs that likely bound advertising hoardings.
Returns list of dicts with keys: x, y, w, h, confidence, label.
Each box represents the region between a detected pair of parallel
horizontal lines.
"""
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
edges = cv2.Canny(gray, canny_low, canny_high)
raw_lines = cv2.HoughLinesP(
edges,
rho=1,
theta=np.pi / 180,
threshold=hough_threshold,
minLineLength=hough_min_length,
maxLineGap=hough_max_gap,
)
if raw_lines is None:
return []
# Filter to near-horizontal lines (within 10 degrees)
horizontals = _filter_horizontal(raw_lines, max_angle_deg=10)
if len(horizontals) < 2:
return []
# Find pairs of parallel horizontals with consistent spacing
pairs = _find_line_pairs(
horizontals,
min_distance=pair_min_distance,
max_distance=pair_max_distance,
)
# Convert pairs to bounding boxes
h, w = image.shape[:2]
results = []
for top_line, bottom_line in pairs:
box = _pair_to_bbox(top_line, bottom_line, frame_width=w, frame_height=h)
if box is not None:
results.append(box)
return results
def _filter_horizontal(lines: np.ndarray, max_angle_deg: float = 10) -> list[tuple]:
"""Keep only lines within max_angle_deg of horizontal."""
max_slope = np.tan(np.radians(max_angle_deg))
result = []
for line in lines:
x1, y1, x2, y2 = line[0]
dx = x2 - x1
if dx == 0:
continue
slope = abs((y2 - y1) / dx)
if slope <= max_slope:
y_mid = (y1 + y2) / 2
x_min = min(x1, x2)
x_max = max(x1, x2)
length = np.sqrt(dx**2 + (y2 - y1) ** 2)
result.append((x_min, x_max, y_mid, length))
return result
def _find_line_pairs(
horizontals: list[tuple],
min_distance: int,
max_distance: int,
) -> list[tuple]:
"""
Find pairs of horizontal lines that could be top/bottom of a hoarding.
Lines must overlap horizontally and be spaced within [min_distance, max_distance].
"""
# Sort by y position
sorted_lines = sorted(horizontals, key=lambda l: l[2])
pairs = []
used = set()
for i, top in enumerate(sorted_lines):
if i in used:
continue
for j, bottom in enumerate(sorted_lines[i + 1 :], start=i + 1):
if j in used:
continue
y_gap = bottom[2] - top[2]
if y_gap < min_distance:
continue
if y_gap > max_distance:
break # sorted by y, no point checking further
# Check horizontal overlap
overlap_start = max(top[0], bottom[0])
overlap_end = min(top[1], bottom[1])
overlap = overlap_end - overlap_start
# Require at least 50% overlap relative to shorter line
shorter_length = min(top[1] - top[0], bottom[1] - bottom[0])
if shorter_length > 0 and overlap / shorter_length >= 0.5:
pairs.append((top, bottom))
used.add(i)
used.add(j)
break
return pairs
def _pair_to_bbox(
top: tuple,
bottom: tuple,
frame_width: int,
frame_height: int,
) -> dict | None:
"""Convert a line pair to a bounding box dict."""
x = int(max(0, min(top[0], bottom[0])))
y = int(max(0, top[2]))
x2 = int(min(frame_width, max(top[1], bottom[1])))
y2 = int(min(frame_height, bottom[2]))
w = x2 - x
h = y2 - y
if w < 20 or h < 5:
return None
# Confidence based on line lengths relative to box width
avg_line_length = (top[3] + bottom[3]) / 2
coverage = min(1.0, avg_line_length / max(w, 1))
return {
"x": x,
"y": y,
"w": w,
"h": h,
"confidence": round(coverage, 3),
"label": "edge_region",
}
def _np_to_b64_jpeg(image: np.ndarray, quality: int = 70) -> str:
"""Encode a numpy image (BGR or grayscale) as base64 JPEG."""
ok, buf = cv2.imencode(".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, quality])
if not ok:
return ""
return base64.b64encode(buf.tobytes()).decode()
def detect_edges_debug(
image: np.ndarray,
canny_low: int = 50,
canny_high: int = 150,
hough_threshold: int = 80,
hough_min_length: int = 100,
hough_max_gap: int = 10,
pair_max_distance: int = 200,
pair_min_distance: int = 15,
) -> dict:
"""
Same as detect_edges but returns intermediate visualizations.
Returns dict with:
regions: list[dict] — same boxes as detect_edges
edge_overlay_b64: str — Canny edge image as base64 JPEG
lines_overlay_b64: str — frame with Hough lines drawn
horizontal_count: int — number of horizontal lines found
pair_count: int — number of line pairs found
"""
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
edges = cv2.Canny(gray, canny_low, canny_high)
# Edge overlay — Canny output as-is (white edges on black)
edge_overlay_b64 = _np_to_b64_jpeg(edges)
raw_lines = cv2.HoughLinesP(
edges,
rho=1,
theta=np.pi / 180,
threshold=hough_threshold,
minLineLength=hough_min_length,
maxLineGap=hough_max_gap,
)
# Lines overlay — draw all Hough lines on a copy of the frame
lines_vis = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
if raw_lines is not None:
for line in raw_lines:
x1, y1, x2, y2 = line[0]
cv2.line(lines_vis, (x1, y1), (x2, y2), (0, 0, 255), 1)
horizontals = []
if raw_lines is not None:
horizontals = _filter_horizontal(raw_lines, max_angle_deg=10)
# Draw horizontal lines in cyan, thicker
for h_line in horizontals:
x_min, x_max, y_mid, _ = h_line
cv2.line(lines_vis, (int(x_min), int(y_mid)), (int(x_max), int(y_mid)), (255, 255, 0), 2)
pairs = []
if len(horizontals) >= 2:
pairs = _find_line_pairs(
horizontals,
min_distance=pair_min_distance,
max_distance=pair_max_distance,
)
# Draw paired lines in green
for top_line, bottom_line in pairs:
cv2.line(lines_vis, (int(top_line[0]), int(top_line[2])),
(int(top_line[1]), int(top_line[2])), (0, 255, 0), 2)
cv2.line(lines_vis, (int(bottom_line[0]), int(bottom_line[2])),
(int(bottom_line[1]), int(bottom_line[2])), (0, 255, 0), 2)
lines_overlay_b64 = _np_to_b64_jpeg(lines_vis)
# Build region boxes (same logic as detect_edges)
h, w = image.shape[:2]
regions = []
for top_line, bottom_line in pairs:
box = _pair_to_bbox(top_line, bottom_line, frame_width=w, frame_height=h)
if box is not None:
regions.append(box)
return {
"regions": regions,
"edge_overlay_b64": edge_overlay_b64,
"lines_overlay_b64": lines_overlay_b64,
"horizontal_count": len(horizontals),
"pair_count": len(pairs),
}
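A usage sketch for the detectors above. They expect an RGB array (note the RGB2GRAY conversions), while cv2.imread returns BGR, hence the explicit conversion; "frame.jpg" is a placeholder path.
import cv2

frame_bgr = cv2.imread("frame.jpg")
frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)

# Candidate hoarding regions between paired horizontal lines
for box in detect_edges(frame_rgb, canny_low=50, canny_high=150):
    print(box["label"], box["confidence"], (box["x"], box["y"], box["w"], box["h"]))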

View File

@@ -0,0 +1,86 @@
"""
Field segmentation — HSV green mask → pitch boundary contour.
Pure OpenCV. Called by the inference server endpoint.
"""
from __future__ import annotations
import base64
import cv2
import numpy as np
def segment_field(
image: np.ndarray,
hue_low: int = 30,
hue_high: int = 85,
sat_low: int = 30,
sat_high: int = 255,
val_low: int = 30,
val_high: int = 255,
morph_kernel: int = 15,
min_area_ratio: float = 0.05,
) -> dict:
"""
Detect the pitch area using HSV green thresholding.
Returns dict with:
boundary: list of [x, y] points
coverage: float (fraction of frame)
"""
hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
lower = np.array([hue_low, sat_low, val_low])
upper = np.array([hue_high, sat_high, val_high])
mask = cv2.inRange(hsv, lower, upper)
k = morph_kernel
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
h, w = image.shape[:2]
min_area = min_area_ratio * h * w
boundary = []
coverage = 0.0
if contours:
large = [c for c in contours if cv2.contourArea(c) >= min_area]
if large:
pitch_contour = max(large, key=cv2.contourArea)
boundary = pitch_contour.reshape(-1, 2).tolist()
coverage = cv2.contourArea(pitch_contour) / (h * w)
refined = np.zeros_like(mask)
cv2.drawContours(refined, [pitch_contour], -1, 255, cv2.FILLED)
mask = refined
return {
"boundary": boundary,
"coverage": coverage,
"mask": mask,
}
def segment_field_debug(
image: np.ndarray,
**kwargs,
) -> dict:
"""Same as segment_field but includes a mask overlay for the editor."""
result = segment_field(image, **kwargs)
mask = result["mask"]
# RGBA overlay: solid green where mask, fully transparent elsewhere
h, w = image.shape[:2]
overlay = np.zeros((h, w, 4), dtype=np.uint8)
overlay[mask > 0] = [0, 255, 0, 255]
_, buf = cv2.imencode(".png", overlay)
result["mask_overlay_b64"] = base64.b64encode(buf.tobytes()).decode()
# Don't send the raw mask over HTTP
del result["mask"]
return result
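A usage sketch with the same RGB expectation and a placeholder path; the returned mask can gate downstream detections to the pitch area.
import cv2

frame_rgb = cv2.cvtColor(cv2.imread("frame.jpg"), cv2.COLOR_BGR2RGB)
result = segment_field(frame_rgb, hue_low=30, hue_high=85)

print(f"pitch coverage: {result['coverage']:.1%}")
inside_pitch = result["mask"] > 0   # boolean H x W mask, True on the pitch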

136
core/gpu/models/models.py Normal file
View File

@@ -0,0 +1,136 @@
"""
Pydantic Models - GENERATED FILE
Do not edit directly. Regenerate using modelgen.
"""
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel, Field
class DetectRequest(BaseModel):
"""Request body for object detection."""
image: str
model: Optional[str] = None
confidence: Optional[float] = None
target_classes: Optional[List[str]] = None
class BBox(BaseModel):
"""A detected bounding box."""
x: int
y: int
w: int
h: int
confidence: float
label: str
class DetectResponse(BaseModel):
"""Response from object detection."""
detections: List[BBox] = Field(default_factory=list)
class OCRRequest(BaseModel):
"""Request body for OCR."""
image: str
languages: Optional[List[str]] = None
class OCRTextResult(BaseModel):
"""A single OCR text extraction result."""
text: str
confidence: float
bbox: List[int] = Field(default_factory=list)
class OCRResponse(BaseModel):
"""Response from OCR."""
results: List[OCRTextResult] = Field(default_factory=list)
class PreprocessRequest(BaseModel):
"""Request body for image preprocessing."""
image: str
binarize: bool = False
deskew: bool = False
contrast: bool = True
class PreprocessResponse(BaseModel):
"""Response from preprocessing."""
image: str
class VLMRequest(BaseModel):
"""Request body for visual language model query."""
image: str
prompt: str
model: Optional[str] = None
class VLMResponse(BaseModel):
"""Response from VLM."""
brand: str
confidence: float
reasoning: str
class AnalyzeRegionsRequest(BaseModel):
"""Request body for CV region analysis."""
image: str
edge_canny_low: int = 50
edge_canny_high: int = 150
edge_hough_threshold: int = 80
edge_hough_min_length: int = 100
edge_hough_max_gap: int = 10
edge_pair_max_distance: int = 200
edge_pair_min_distance: int = 15
class RegionBox(BaseModel):
"""A candidate region from CV analysis."""
x: int
y: int
w: int
h: int
confidence: float
label: str
class AnalyzeRegionsResponse(BaseModel):
"""Response from CV region analysis."""
regions: List[RegionBox] = Field(default_factory=list)
class AnalyzeRegionsDebugResponse(BaseModel):
"""Response from CV region analysis with debug overlays."""
regions: List[RegionBox] = Field(default_factory=list)
edge_overlay_b64: str = ""
lines_overlay_b64: str = ""
horizontal_count: int = 0
pair_count: int = 0
class SegmentFieldRequest(BaseModel):
"""Request body for field segmentation."""
image: str
hue_low: int = 30
hue_high: int = 85
sat_low: int = 30
sat_high: int = 255
val_low: int = 30
val_high: int = 255
morph_kernel: int = 15
min_area_ratio: float = 0.05
class SegmentFieldResponse(BaseModel):
"""Response from field segmentation."""
boundary: List[List[int]] = Field(default_factory=list)
coverage: float = 0.0
mask_b64: str = ""
class SegmentFieldDebugResponse(BaseModel):
"""Response from field segmentation with debug overlay."""
boundary: List[List[int]] = Field(default_factory=list)
coverage: float = 0.0
mask_overlay_b64: str = ""
class ConfigUpdate(BaseModel):
"""Request body for updating server configuration."""
device: Optional[str] = None
yolo_model: Optional[str] = None
yolo_confidence: Optional[float] = None
vram_budget_mb: Optional[int] = None
strategy: Optional[str] = None
ocr_languages: Optional[List[str]] = None
ocr_min_confidence: Optional[float] = None
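A sketch of building a DetectRequest from a local file ("crop.jpg" is a placeholder); the image field carries the raw bytes as base64, matching _decode_image in server.py below.
import base64

with open("crop.jpg", "rb") as f:
    payload = DetectRequest(
        image=base64.b64encode(f.read()).decode(),
        confidence=0.4,
        target_classes=["person"],
    )
print(payload.model_dump(exclude_none=True))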

105
core/gpu/models/ocr.py Normal file
View File

@@ -0,0 +1,105 @@
"""PaddleOCR 3.x text extraction wrapper."""
from __future__ import annotations
import logging
from models import registry
from config import get_config
logger = logging.getLogger(__name__)
def _load(languages: list[str]):
from paddleocr import PaddleOCR
key = f"ocr_{'_'.join(languages)}"
model = PaddleOCR(lang=languages[0])
registry.put(key, model)
return model
def _get(languages: list[str] | None = None):
langs = languages or get_config()["ocr_languages"]
key = f"ocr_{'_'.join(langs)}"
model = registry.get(key)
if model is None:
model = _load(langs)
return model
def _parse_raw(raw) -> list[tuple[list, str, float]]:
"""
Parse PaddleOCR output into (points, text, confidence) tuples.
PaddleOCR 3.x changed the result format. Two known layouts:
Layout A — dict-based (new pipeline API):
raw = [{'rec_texts': [...], 'rec_scores': [...], 'dt_polys': [...]}]
Layout B — nested list (2.x compat / some 3.x builds):
raw = [[ [points, [text, score]], ... ]]
raw = [[ [points, [text, score], [cls, cls_score]], ... ]] # with angle cls
"""
results = []
for page in raw:
if not page:
continue
# Layout A: dict with parallel lists
if isinstance(page, dict):
texts = page.get("rec_texts", [])
scores = page.get("rec_scores", [])
polys = page.get("dt_polys", [])
for points, text, confidence in zip(polys, texts, scores):
results.append((points, text, float(confidence)))
continue
# Layout B: list of per-line entries
for line in page:
if not line:
continue
# line[0] is always the polygon points
points = line[0]
# line[1] is [text, score] — ignore any extra elements (angle cls etc.)
rec = line[1]
if isinstance(rec, (list, tuple)) and len(rec) >= 2:
text, confidence = rec[0], rec[1]
else:
logger.warning("Unexpected OCR line format: %s", line)
continue
results.append((points, str(text), float(confidence)))
return results
def ocr(image, languages: list[str] | None = None, min_confidence: float | None = None) -> list[dict]:
"""Run OCR on an image, return list of text result dicts."""
cfg = get_config()
min_conf = min_confidence if min_confidence is not None else cfg["ocr_min_confidence"]
model = _get(languages)
raw = model.ocr(image)
logger.debug("OCR raw: %s", raw)
parsed = _parse_raw(raw)
results = []
for points, text, confidence in parsed:
if confidence < min_conf:
continue
xs = [p[0] for p in points]
ys = [p[1] for p in points]
results.append({
"text": text,
"confidence": confidence,
"bbox": [int(min(xs)), int(min(ys)),
int(max(xs) - min(xs)), int(max(ys) - min(ys))],
})
return results
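A sketch of the two raw layouts described in _parse_raw, fed through it directly; the polygon, text, and score values are made up for illustration.
layout_a = [{
    "rec_texts": ["FLY EMIRATES"],
    "rec_scores": [0.91],
    "dt_polys": [[[10, 5], [210, 5], [210, 40], [10, 40]]],
}]
layout_b = [[
    [[[10, 5], [210, 5], [210, 40], [10, 40]], ["FLY EMIRATES", 0.91]],
]]

# Both layouts normalize to the same (points, text, confidence) tuples
assert _parse_raw(layout_a) == _parse_raw(layout_b)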

View File

@@ -0,0 +1,117 @@
"""
Image preprocessing pipeline for crops before OCR.
Each step is independently toggleable via config.
Operates on numpy arrays (BGR or RGB), returns processed array.
"""
from __future__ import annotations
import logging
import numpy as np
logger = logging.getLogger(__name__)
def binarize(image: np.ndarray, threshold: int = 128) -> np.ndarray:
"""Convert to grayscale and apply Otsu binarization."""
import cv2
if len(image.shape) == 3:
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
else:
gray = image
_, binary = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
# Convert back to 3-channel for downstream compatibility
result = cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB)
return result
def deskew(image: np.ndarray) -> np.ndarray:
"""Correct slight rotation using minimum area rectangle."""
import cv2
if len(image.shape) == 3:
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
else:
gray = image
coords = np.column_stack(np.where(gray < 128))
if len(coords) < 10:
return image
rect = cv2.minAreaRect(coords)
angle = rect[-1]
# Normalize angle
if angle < -45:
angle = -(90 + angle)
else:
angle = -angle
if abs(angle) < 0.5:
return image
h, w = image.shape[:2]
center = (w // 2, h // 2)
rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
result = cv2.warpAffine(
image, rotation_matrix, (w, h),
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_REPLICATE,
)
return result
def enhance_contrast(image: np.ndarray) -> np.ndarray:
"""Apply CLAHE (adaptive histogram equalization) for contrast normalization."""
import cv2
if len(image.shape) == 3:
lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
l_channel = lab[:, :, 0]
else:
l_channel = image
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
enhanced = clahe.apply(l_channel)
if len(image.shape) == 3:
lab[:, :, 0] = enhanced
result = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
else:
result = enhanced
return result
def preprocess(
image: np.ndarray,
do_binarize: bool = False,
do_deskew: bool = False,
do_contrast: bool = True,
) -> np.ndarray:
"""
Run the preprocessing pipeline on a crop image.
Each step is independently toggleable. Order: contrast → deskew → binarize.
Contrast first (works best on color), binarize last (destroys color info).
"""
result = image
if do_contrast:
result = enhance_contrast(result)
logger.debug("Preprocessing: contrast enhanced")
if do_deskew:
result = deskew(result)
logger.debug("Preprocessing: deskewed")
if do_binarize:
result = binarize(result)
logger.debug("Preprocessing: binarized")
return result
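A usage sketch for a crop headed to OCR, with contrast and deskew enabled and binarization left off; the image path is a placeholder.
import cv2

crop = cv2.cvtColor(cv2.imread("crop.jpg"), cv2.COLOR_BGR2RGB)
clean = preprocess(crop, do_binarize=False, do_deskew=True, do_contrast=True)
print(clean.shape, clean.dtype)   # same H x W as the input crop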

View File

@@ -0,0 +1,37 @@
"""
Model registry — manages loaded models and VRAM lifecycle.
"""
from __future__ import annotations
import logging
logger = logging.getLogger(__name__)
_models: dict[str, object] = {}
def get(name: str) -> object | None:
return _models.get(name)
def put(name: str, model: object) -> None:
_models[name] = model
logger.info("Loaded %s", name)
def unload(name: str) -> bool:
if name in _models:
del _models[name]
logger.info("Unloaded %s", name)
return True
return False
def loaded() -> list[str]:
return list(_models.keys())
def clear() -> None:
_models.clear()
logger.info("All models unloaded")

100
core/gpu/models/vlm.py Normal file
View File

@@ -0,0 +1,100 @@
"""moondream2 visual language model wrapper."""
from __future__ import annotations
import logging
from models import registry
from config import get_config
logger = logging.getLogger(__name__)
_MODEL_KEY = "vlm_moondream2"
def _load():
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
device = get_config().get("device", "auto")
if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info("Loading moondream2 (device=%s)...", device)
model_id = "vikhyatk/moondream2"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
dtype = torch.float16 if "cuda" in device else torch.float32
model = AutoModelForCausalLM.from_pretrained(
model_id,
trust_remote_code=True,
dtype=dtype,
device_map=device,
)
wrapper = {"model": model, "tokenizer": tokenizer}
registry.put(_MODEL_KEY, wrapper)
logger.info("moondream2 loaded")
return wrapper
def _get():
wrapper = registry.get(_MODEL_KEY)
if wrapper is None:
wrapper = _load()
return wrapper
def query(image, prompt: str) -> dict:
"""
Query moondream2 with an image crop and prompt.
Returns {"brand": str, "confidence": float, "reasoning": str}
"""
from PIL import Image as PILImage
wrapper = _get()
model = wrapper["model"]
tokenizer = wrapper["tokenizer"]
# Convert numpy array to PIL if needed
if not isinstance(image, PILImage.Image):
image = PILImage.fromarray(image)
enc_image = model.encode_image(image)
answer = model.answer_question(enc_image, prompt, tokenizer)
# Parse response — moondream2 returns free text, extract brand + confidence
result = _parse_vlm_response(answer)
return result
def _parse_vlm_response(answer: str) -> dict:
"""
Parse moondream2 free-text response into structured output.
Expected format from prompt: "brand, confidence (0-1), reasoning"
Falls back gracefully if format doesn't match.
"""
answer = answer.strip()
parts = [p.strip() for p in answer.split(",", 2)]
brand = parts[0] if parts else ""
confidence = 0.5
reasoning = answer
if len(parts) >= 2:
try:
confidence = float(parts[1])
confidence = max(0.0, min(1.0, confidence))
except ValueError:
pass
if len(parts) >= 3:
reasoning = parts[2]
return {
"brand": brand,
"confidence": confidence,
"reasoning": reasoning,
}
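A sketch of the free-text contract _parse_vlm_response expects; the prompt wording is illustrative, only the "brand, confidence, reasoning" shape matters.
prompt = ("What brand is shown on this advertising board? "
          "Answer as: brand, confidence (0-1), short reasoning.")

answer = "Heineken, 0.85, green logo with a red star is clearly visible"
print(_parse_vlm_response(answer))
# {'brand': 'Heineken', 'confidence': 0.85,
#  'reasoning': 'green logo with a red star is clearly visible'}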

54
core/gpu/models/yolo.py Normal file
View File

@@ -0,0 +1,54 @@
"""YOLO object detection model wrapper."""
from __future__ import annotations
import logging
from models import registry
from config import get_config, get_device
logger = logging.getLogger(__name__)
def _load(model_name: str):
from ultralytics import YOLO
device = get_device()
model = YOLO(model_name)
model.to(device)
registry.put(model_name, model)
return model
def _get(model_name: str | None = None):
name = model_name or get_config()["yolo_model"]
model = registry.get(name)
if model is None:
model = _load(name)
return model
def detect(image, model_name: str | None = None, confidence: float | None = None, target_classes: list[str] | None = None) -> list[dict]:
"""Run YOLO detection, return list of bbox dicts."""
cfg = get_config()
conf = confidence if confidence is not None else cfg["yolo_confidence"]
model = _get(model_name)
results = model(image, conf=conf, verbose=False)
detections = []
for r in results:
for box in r.boxes:
x1, y1, x2, y2 = box.xyxy[0].tolist()
label = r.names[int(box.cls[0])]
if target_classes and label not in target_classes:
continue
detections.append({
"x": int(x1), "y": int(y1),
"w": int(x2 - x1), "h": int(y2 - y1),
"confidence": float(box.conf[0]),
"label": label,
})
return detections
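A usage sketch with an explicit confidence override and class filter (placeholder path); when model_name is None the default comes from config.
import cv2

frame = cv2.cvtColor(cv2.imread("frame.jpg"), cv2.COLOR_BGR2RGB)
people = detect(frame, confidence=0.5, target_classes=["person"])
print(len(people), "person detections")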

31
core/gpu/requirements.txt Normal file
View File

@@ -0,0 +1,31 @@
fastapi>=0.109.0
uvicorn[standard]>=0.27.0
rapidfuzz>=3.0.0
Pillow>=10.0.0
redis>=5.0.0
# --- GPU-specific installs (mcrn: RTX 3080, CUDA toolkit 12.8) ---
#
# torch: must be installed from the PyTorch index, NOT from PyPI.
# cu126 is the closest build to CUDA 12.8 (no cu128 wheel yet; cu126 is forward-compatible).
# Install with:
# uv pip install --reinstall torch torchvision --index-url https://download.pytorch.org/whl/cu126
#
# ultralytics pulls torch as a dependency — reinstall torch after ultralytics to ensure
# the correct CUDA build. Mixing the PyPI torch with CUDA 12.8 causes NCCL symbol errors.
ultralytics>=8.0.0
# paddlepaddle-gpu: NOT available on PyPI. Install from PaddlePaddle's package index.
# cu126 build works on CUDA 12.8.
# Install with:
# uv pip install paddlepaddle-gpu==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
paddleocr>=3.0.0
# VLM (moondream2) — uses torch (already installed above)
# Pinned <5: transformers 5.x broke moondream2's custom model code
# (all_tied_weights_keys API change). Also needs accelerate for device_map.
transformers>=4.40.0,<5
accelerate>=0.27.0
# Preprocessing (phase 12)
opencv-python-headless>=4.8.0

54
core/gpu/run.sh Executable file
View File

@@ -0,0 +1,54 @@
#!/bin/bash
# Run the inference server
#
# Usage:
# ./run.sh # Local (pip install -r requirements.txt first)
# ./run.sh docker # Docker (CPU)
# ./run.sh docker-gpu # Docker with GPU
# ./run.sh stop # Stop Docker container
set -e
cd "$(dirname "${BASH_SOURCE[0]}")"
# Load env (create from template if missing)
if [ ! -f .env ]; then
if [ -f .env.template ]; then
cp .env.template .env
echo "Created .env from template — edit as needed"
fi
fi
if [ -f .env ]; then
set -a
source .env
set +a
fi
case "${1:-local}" in
local)
python server.py
;;
docker)
docker build -t mpr-inference .
ENV_FLAG=""; [ -f .env ] && ENV_FLAG="--env-file .env"
docker run --rm -p "${PORT:-8000}:8000" \
$ENV_FLAG \
--name mpr-inference \
mpr-inference
;;
docker-gpu)
docker build -t mpr-inference .
ENV_FLAG=""; [ -f .env ] && ENV_FLAG="--env-file .env"
docker run --rm --gpus all -p "${PORT:-8000}:8000" \
$ENV_FLAG \
--name mpr-inference \
mpr-inference
;;
stop)
docker stop mpr-inference 2>/dev/null || true
;;
*)
echo "Usage: ./run.sh [local|docker|docker-gpu|stop]"
exit 1
;;
esac

399
core/gpu/server.py Normal file
View File

@@ -0,0 +1,399 @@
"""
Inference server — thin HTTP routes over model wrappers.
Config lives in config.py, model logic in models/.
This file is just the FastAPI glue.
Usage:
cd gpu && python server.py
"""
from __future__ import annotations
import base64
import io
import logging
import os
import time
from contextlib import asynccontextmanager
import numpy as np
from fastapi import FastAPI, HTTPException, Request
from PIL import Image
from pydantic import BaseModel
from emit import log as emit_log
from config import get_config, get_device, update_config
from models import registry
from models.yolo import detect as yolo_detect
from models.ocr import ocr as ocr_run
from models.vlm import query as vlm_query
logger = logging.getLogger(__name__)
def _decode_image(b64: str) -> np.ndarray:
data = base64.b64decode(b64)
img = Image.open(io.BytesIO(data)).convert("RGB")
return np.array(img)
def _job_ctx(request: Request) -> tuple[str, str]:
"""Extract job_id and log_level from request headers."""
job_id = request.headers.get("x-job-id", "")
log_level = request.headers.get("x-log-level", "INFO")
return job_id, log_level
def _gpu_log(job_id: str, log_level: str, stage: str, level: str, msg: str):
"""Emit a log event if job context is present."""
if job_id:
emit_log(job_id, stage, level, msg, log_level=log_level)
# --- Request/Response models (generated from core/schema/models/inference.py) ---
from models.models import (
AnalyzeRegionsDebugResponse,
AnalyzeRegionsRequest,
AnalyzeRegionsResponse,
BBox,
ConfigUpdate,
DetectRequest,
DetectResponse,
OCRRequest,
OCRResponse,
OCRTextResult,
PreprocessRequest,
PreprocessResponse,
RegionBox,
SegmentFieldRequest,
SegmentFieldResponse,
SegmentFieldDebugResponse,
VLMRequest,
VLMResponse,
)
# --- App ---
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Inference server starting (device=%s)", get_device())
yield
logger.info("Shutting down")
registry.clear()
app = FastAPI(title="MPR Inference Server", lifespan=lifespan)
@app.get("/health")
def health():
cfg = get_config()
return {
"status": "ok",
"device": get_device(),
"loaded_models": registry.loaded(),
"vram_budget_mb": cfg["vram_budget_mb"],
"strategy": cfg["strategy"],
}
@app.get("/config")
def read_config():
return {**get_config(), "device_resolved": get_device()}
@app.put("/config")
def write_config(update: ConfigUpdate):
changes = update.model_dump(exclude_none=True)
if not changes:
return get_config()
# Unload model if it changed
old_model = get_config().get("yolo_model")
if "yolo_model" in changes and changes["yolo_model"] != old_model:
registry.unload(old_model)
update_config(changes)
logger.info("Config updated: %s", changes)
return {**get_config(), "device_resolved": get_device()}
@app.post("/models/unload")
def unload_model(body: dict):
name = body.get("model", "")
unloaded = registry.unload(name)
return {"status": "unloaded" if unloaded else "not_loaded", "model": name}
@app.post("/detect", response_model=DetectResponse)
def detect(req: DetectRequest, request: Request):
job_id, log_level = _job_ctx(request)
try:
t0 = time.monotonic()
image = _decode_image(req.image)
decode_ms = (time.monotonic() - t0) * 1000
h, w = image.shape[:2]
_gpu_log(job_id, log_level, "GPU:YOLO", "DEBUG",
f"Decoded {w}x{h} image in {decode_ms:.0f}ms")
except Exception as e:
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
try:
t0 = time.monotonic()
results = yolo_detect(
image,
model_name=req.model,
confidence=req.confidence,
target_classes=req.target_classes,
)
infer_ms = (time.monotonic() - t0) * 1000
_gpu_log(job_id, log_level, "GPU:YOLO", "DEBUG",
f"Inference: {len(results)} detections in {infer_ms:.0f}ms "
f"(model={req.model}, conf={req.confidence})")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Detection failed: {e}")
return DetectResponse(detections=[BBox(**r) for r in results])
@app.post("/ocr", response_model=OCRResponse)
def ocr(req: OCRRequest, request: Request):
job_id, log_level = _job_ctx(request)
try:
image = _decode_image(req.image)
h, w = image.shape[:2]
except Exception as e:
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
try:
t0 = time.monotonic()
results = ocr_run(image, languages=req.languages)
infer_ms = (time.monotonic() - t0) * 1000
texts = [r["text"][:20] for r in results]
_gpu_log(job_id, log_level, "GPU:OCR", "DEBUG",
f"OCR {w}x{h}: {infer_ms:.0f}ms → {len(results)} results {texts}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"OCR failed: {e}")
return OCRResponse(results=[OCRTextResult(**r) for r in results])
@app.post("/preprocess", response_model=PreprocessResponse)
def preprocess_image(req: PreprocessRequest):
try:
image = _decode_image(req.image)
except Exception as e:
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
try:
from models.preprocess import preprocess
processed = preprocess(
image,
do_binarize=req.binarize,
do_deskew=req.deskew,
do_contrast=req.contrast,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Preprocessing failed: {e}")
from PIL import Image as PILImage
img = PILImage.fromarray(processed)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=90)
result_b64 = base64.b64encode(buf.getvalue()).decode()
return PreprocessResponse(image=result_b64)
@app.post("/vlm", response_model=VLMResponse)
def vlm(req: VLMRequest, request: Request):
job_id, log_level = _job_ctx(request)
try:
image = _decode_image(req.image)
h, w = image.shape[:2]
except Exception as e:
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
try:
t0 = time.monotonic()
result = vlm_query(image, req.prompt)
infer_ms = (time.monotonic() - t0) * 1000
_gpu_log(job_id, log_level, "GPU:VLM", "DEBUG",
f"VLM {w}x{h}: {infer_ms:.0f}ms → "
f"brand='{result.get('brand', '')}' conf={result.get('confidence', 0):.2f}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"VLM failed: {e}")
return VLMResponse(**result)
@app.post("/detect_edges", response_model=AnalyzeRegionsResponse)
def detect_edges_endpoint(req: AnalyzeRegionsRequest, request: Request):
job_id, log_level = _job_ctx(request)
try:
image = _decode_image(req.image)
h, w = image.shape[:2]
except Exception as e:
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
try:
t0 = time.monotonic()
from models.cv.edges import detect_edges
edge_regions = detect_edges(
image,
canny_low=req.edge_canny_low,
canny_high=req.edge_canny_high,
hough_threshold=req.edge_hough_threshold,
hough_min_length=req.edge_hough_min_length,
hough_max_gap=req.edge_hough_max_gap,
pair_max_distance=req.edge_pair_max_distance,
pair_min_distance=req.edge_pair_min_distance,
)
infer_ms = (time.monotonic() - t0) * 1000
_gpu_log(job_id, log_level, "GPU:CV", "DEBUG",
f"Edge analysis {w}x{h}: {infer_ms:.0f}ms → {len(edge_regions)} regions")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Region analysis failed: {e}")
boxes = [RegionBox(**r) for r in edge_regions]
return AnalyzeRegionsResponse(regions=boxes)
@app.post("/detect_edges/debug", response_model=AnalyzeRegionsDebugResponse)
def detect_edges_debug_endpoint(req: AnalyzeRegionsRequest, request: Request):
job_id, log_level = _job_ctx(request)
try:
image = _decode_image(req.image)
h, w = image.shape[:2]
except Exception as e:
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
try:
t0 = time.monotonic()
from models.cv.edges import detect_edges_debug
result = detect_edges_debug(
image,
canny_low=req.edge_canny_low,
canny_high=req.edge_canny_high,
hough_threshold=req.edge_hough_threshold,
hough_min_length=req.edge_hough_min_length,
hough_max_gap=req.edge_hough_max_gap,
pair_max_distance=req.edge_pair_max_distance,
pair_min_distance=req.edge_pair_min_distance,
)
infer_ms = (time.monotonic() - t0) * 1000
_gpu_log(job_id, log_level, "GPU:CV", "DEBUG",
f"Edge debug {w}x{h}: {infer_ms:.0f}ms → {len(result['regions'])} regions, "
f"{result['horizontal_count']} horizontals, {result['pair_count']} pairs")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Region debug analysis failed: {e}")
boxes = [RegionBox(**r) for r in result["regions"]]
response = AnalyzeRegionsDebugResponse(
regions=boxes,
edge_overlay_b64=result["edge_overlay_b64"],
lines_overlay_b64=result["lines_overlay_b64"],
horizontal_count=result["horizontal_count"],
pair_count=result["pair_count"],
)
return response
@app.post("/segment_field", response_model=SegmentFieldResponse)
def segment_field_endpoint(req: SegmentFieldRequest, request: Request):
job_id, log_level = _job_ctx(request)
try:
image = _decode_image(req.image)
h, w = image.shape[:2]
except Exception as e:
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
try:
t0 = time.monotonic()
from models.cv.segmentation import segment_field
result = segment_field(
image,
hue_low=req.hue_low,
hue_high=req.hue_high,
sat_low=req.sat_low,
sat_high=req.sat_high,
val_low=req.val_low,
val_high=req.val_high,
morph_kernel=req.morph_kernel,
min_area_ratio=req.min_area_ratio,
)
infer_ms = (time.monotonic() - t0) * 1000
# Encode mask as base64 PNG for downstream use
import cv2
_, buf = cv2.imencode(".png", result["mask"])
mask_b64 = base64.b64encode(buf.tobytes()).decode()
_gpu_log(job_id, log_level, "GPU:CV", "DEBUG",
f"Field segmentation {w}x{h}: {infer_ms:.0f}ms, coverage={result['coverage']:.1%}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Field segmentation failed: {e}")
return SegmentFieldResponse(
boundary=result["boundary"],
coverage=result["coverage"],
mask_b64=mask_b64,
)
@app.post("/segment_field/debug", response_model=SegmentFieldDebugResponse)
def segment_field_debug_endpoint(req: SegmentFieldRequest, request: Request):
job_id, log_level = _job_ctx(request)
try:
image = _decode_image(req.image)
except Exception as e:
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
try:
from models.cv.segmentation import segment_field_debug
result = segment_field_debug(
image,
hue_low=req.hue_low,
hue_high=req.hue_high,
sat_low=req.sat_low,
sat_high=req.sat_high,
val_low=req.val_low,
val_high=req.val_high,
morph_kernel=req.morph_kernel,
min_area_ratio=req.min_area_ratio,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Field segmentation debug failed: {e}")
return SegmentFieldDebugResponse(
boundary=result["boundary"],
coverage=result["coverage"],
mask_overlay_b64=result["mask_overlay_b64"],
)
if __name__ == "__main__":
import uvicorn
logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s %(message)s")
host = os.environ.get("HOST", "0.0.0.0")
port = int(os.environ.get("PORT", "8000"))
uvicorn.run(app, host=host, port=port)
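A client sketch against the /detect route, passing the X-Job-Id header so GPU-side timing logs land in the job's Redis stream. It uses the stdlib to stay dependency-free; the host, port, job id, and image path are placeholders.
import base64
import json
import urllib.request

with open("frame.jpg", "rb") as f:
    body = json.dumps({
        "image": base64.b64encode(f.read()).decode(),
        "confidence": 0.4,
    }).encode()

req = urllib.request.Request(
    "http://localhost:8000/detect",
    data=body,
    headers={"Content-Type": "application/json",
             "X-Job-Id": "job-123", "X-Log-Level": "DEBUG"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["detections"])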

View File

@@ -180,24 +180,11 @@ class WorkerServicer(worker_pb2_grpc.WorkerServiceServicer):
def GetWorkerStatus(self, request, context):
"""Get worker health and capabilities."""
try:
from core.ffmpeg import get_encoders
encoders = get_encoders()
codec_names = [e["name"] for e in encoders.get("video", [])]
except Exception:
codec_names = []
# Check for GPU encoders
gpu_available = any(
"nvenc" in name or "vaapi" in name or "qsv" in name for name in codec_names
)
return worker_pb2.WorkerStatus(
available=True,
active_jobs=len(_active_jobs),
supported_codecs=codec_names[:20], # Limit to 20
gpu_available=gpu_available,
supported_codecs=[],
gpu_available=False,
)

View File

@@ -28,7 +28,7 @@
},
{
"target": "pydantic",
"output": "detect/sse_contract.py",
"output": "core/detect/sse.py",
"include": ["detect_views"]
},
{
@@ -43,8 +43,13 @@
},
{
"target": "pydantic",
"output": "gpu/models/inference_contract.py",
"output": "core/gpu/models/models.py",
"include": ["inference_views"]
},
{
"target": "pydantic",
"output": "core/detect/stages/models.py",
"include": ["stage_views"]
}
]
}

Some files were not shown because too many files have changed in this diff.