phase 4
This commit is contained in:
@@ -102,6 +102,7 @@ class Job(models.Model):
|
||||
source_asset_id = models.UUIDField()
|
||||
video_path = models.CharField(max_length=1000)
|
||||
profile_name = models.CharField(max_length=255)
|
||||
timeline_id = models.UUIDField(null=True, blank=True)
|
||||
parent_id = models.UUIDField(null=True, blank=True)
|
||||
run_type = models.CharField(max_length=20, choices=RunType.choices, default=RunType.INITIAL)
|
||||
config_overrides = models.JSONField(default=dict, blank=True)
|
||||
@@ -113,7 +114,6 @@ class Job(models.Model):
|
||||
brands_found = models.IntegerField(default=0)
|
||||
cloud_llm_calls = models.IntegerField(default=0)
|
||||
estimated_cost_usd = models.FloatField(default=0.0)
|
||||
celery_task_id = models.CharField(max_length=255, null=True, blank=True)
|
||||
priority = models.IntegerField(default=0)
|
||||
created_at = models.DateTimeField(auto_now_add=True)
|
||||
started_at = models.DateTimeField(null=True, blank=True)
|
||||
@@ -151,6 +151,7 @@ class Checkpoint(models.Model):
|
||||
|
||||
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
|
||||
timeline_id = models.UUIDField()
|
||||
job_id = models.UUIDField(null=True, blank=True)
|
||||
parent_id = models.UUIDField(null=True, blank=True)
|
||||
stage_outputs = models.JSONField(default=dict, blank=True)
|
||||
config_overrides = models.JSONField(default=dict, blank=True)
|
||||
@@ -185,3 +186,18 @@ class Brand(models.Model):
|
||||
def __str__(self):
|
||||
return str(self.id)
|
||||
|
||||
|
||||
class Profile(models.Model):
|
||||
"""A content type profile."""
|
||||
|
||||
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
|
||||
name = models.CharField(max_length=255)
|
||||
pipeline = models.JSONField(default=dict, blank=True)
|
||||
configs = models.JSONField(default=dict, blank=True)
|
||||
|
||||
class Meta:
|
||||
pass
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
"""
|
||||
SSE endpoint for chunker pipeline events.
|
||||
|
||||
Uses Redis as the event bus. Pipeline pushes events via core.events,
|
||||
SSE endpoint polls them.
|
||||
|
||||
GET /chunker/stream/{job_id} → text/event-stream
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import APIRouter
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from core.events import poll_events
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/chunker", tags=["chunker"])
|
||||
|
||||
|
||||
async def _event_generator(job_id: str) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Generate SSE events by polling Redis for chunk job events.
|
||||
"""
|
||||
cursor = 0
|
||||
timeout = time.monotonic() + 600 # 10 min max
|
||||
|
||||
while time.monotonic() < timeout:
|
||||
events, cursor = poll_events(job_id, cursor)
|
||||
|
||||
if not events:
|
||||
yield f"event: waiting\ndata: {json.dumps({'job_id': job_id})}\n\n"
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
|
||||
for data in events:
|
||||
event_type = data.pop("event", "update")
|
||||
payload = {**data, "job_id": job_id}
|
||||
|
||||
yield f"event: {event_type}\ndata: {json.dumps(payload)}\n\n"
|
||||
|
||||
if event_type in ("pipeline_complete", "pipeline_error", "cancelled"):
|
||||
yield f"event: done\ndata: {json.dumps({'job_id': job_id})}\n\n"
|
||||
return
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
yield f"event: timeout\ndata: {json.dumps({'job_id': job_id})}\n\n"
|
||||
|
||||
|
||||
@router.get("/stream/{job_id}")
|
||||
async def stream_chunk_job(job_id: str):
|
||||
"""
|
||||
SSE stream for a chunk pipeline job.
|
||||
|
||||
The UI connects via native EventSource:
|
||||
const es = new EventSource('/api/chunker/stream/<job_id>');
|
||||
es.addEventListener('processing', (e) => { ... });
|
||||
"""
|
||||
return StreamingResponse(
|
||||
_event_generator(job_id),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
@@ -58,32 +58,30 @@ def write_config(update: ConfigUpdate):
|
||||
|
||||
|
||||
@router.get("/config/profiles")
|
||||
def list_profiles():
|
||||
def get_profiles():
|
||||
"""List available detection profiles."""
|
||||
from detect.profiles import _PROFILES
|
||||
return [{"name": name} for name in _PROFILES]
|
||||
from core.detect.profile import list_profiles as _list
|
||||
return [{"name": name} for name in _list()]
|
||||
|
||||
|
||||
@router.get("/config/profiles/{profile_name}/pipeline")
|
||||
def get_pipeline_config(profile_name: str):
|
||||
"""Return the pipeline composition for a profile."""
|
||||
from detect.profiles import get_profile
|
||||
from core.detect.profile import get_profile
|
||||
from fastapi import HTTPException
|
||||
from dataclasses import asdict
|
||||
|
||||
try:
|
||||
profile = get_profile(profile_name)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown profile: {profile_name}")
|
||||
|
||||
config = profile.pipeline_config()
|
||||
return asdict(config)
|
||||
return profile["pipeline"]
|
||||
|
||||
|
||||
@router.get("/config/stages", response_model=list[StageConfigInfo])
|
||||
def list_stage_configs():
|
||||
"""Return the stage palette with config field metadata for the editor."""
|
||||
from detect.stages import list_stages
|
||||
from core.detect.stages import list_stages
|
||||
|
||||
result = []
|
||||
for stage in list_stages():
|
||||
@@ -95,7 +93,7 @@ def list_stage_configs():
|
||||
@router.get("/config/stages/{stage_name}", response_model=StageConfigInfo)
|
||||
def get_stage_config(stage_name: str):
|
||||
"""Return config field metadata for a single stage."""
|
||||
from detect.stages import get_stage
|
||||
from core.detect.stages import get_stage
|
||||
|
||||
try:
|
||||
stage = get_stage(stage_name)
|
||||
|
||||
@@ -105,7 +105,7 @@ class ReplaySingleStageResponse(BaseModel):
|
||||
@router.get("/checkpoints/{timeline_id}")
|
||||
def list_checkpoints(timeline_id: str) -> list[CheckpointInfo]:
|
||||
"""List available checkpoint stages for a job."""
|
||||
from detect.checkpoint import list_checkpoints as _list
|
||||
from core.detect.checkpoint import list_checkpoints as _list
|
||||
|
||||
try:
|
||||
stages = _list(timeline_id)
|
||||
@@ -139,10 +139,10 @@ class CheckpointData(BaseModel):
|
||||
def get_checkpoint_data(timeline_id: str, stage: str):
|
||||
"""Load checkpoint frames + metadata for the editor UI."""
|
||||
from uuid import UUID
|
||||
from core.db.tables import Timeline, Checkpoint
|
||||
from core.db.models import Timeline, Checkpoint
|
||||
from core.db.connection import get_session
|
||||
from core.db.checkpoint import list_checkpoints
|
||||
from detect.checkpoint.frames import load_frames_b64
|
||||
from core.detect.checkpoint.frames import load_frames_b64
|
||||
|
||||
with get_session() as session:
|
||||
timeline = session.get(Timeline, UUID(timeline_id))
|
||||
@@ -184,7 +184,7 @@ def get_checkpoint_data(timeline_id: str, stage: str):
|
||||
@router.get("/scenarios", response_model=list[ScenarioInfo])
|
||||
def list_scenarios_endpoint():
|
||||
"""List all available scenarios (bookmarked checkpoints)."""
|
||||
from core.db.tables import Timeline
|
||||
from core.db.models import Timeline
|
||||
from core.db.connection import get_session
|
||||
from core.db.checkpoint import list_scenarios
|
||||
|
||||
@@ -212,7 +212,7 @@ def list_scenarios_endpoint():
|
||||
@router.post("/replay", response_model=ReplayResponse)
|
||||
def replay(req: ReplayRequest):
|
||||
"""Replay pipeline from a specific stage with optional config overrides."""
|
||||
from detect.checkpoint import replay_from
|
||||
from core.detect.checkpoint import replay_from
|
||||
|
||||
try:
|
||||
result = replay_from(
|
||||
@@ -242,7 +242,7 @@ def replay(req: ReplayRequest):
|
||||
@router.post("/retry", response_model=RetryResponse)
|
||||
def retry(req: RetryRequest):
|
||||
"""Queue an async retry of unresolved candidates with different config."""
|
||||
from detect.checkpoint.tasks import retry_candidates
|
||||
from core.detect.checkpoint.tasks import retry_candidates
|
||||
|
||||
kwargs = {
|
||||
"timeline_id": req.timeline_id,
|
||||
@@ -266,7 +266,7 @@ def retry(req: RetryRequest):
|
||||
@router.post("/replay-stage", response_model=ReplaySingleStageResponse)
|
||||
def replay_single_stage(req: ReplaySingleStageRequest):
|
||||
"""Replay a single stage on specific frames — fast path for interactive tuning."""
|
||||
from detect.checkpoint.replay import replay_single_stage as _replay
|
||||
from core.detect.checkpoint.replay import replay_single_stage as _replay
|
||||
|
||||
try:
|
||||
result = _replay(
|
||||
@@ -361,3 +361,41 @@ async def gpu_detect_edges_debug(request: Request):
|
||||
media_type="application/json")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}")
|
||||
|
||||
|
||||
@router.post("/gpu/segment_field")
|
||||
async def gpu_segment_field(request: Request):
|
||||
"""Proxy to GPU inference server — field segmentation."""
|
||||
import httpx
|
||||
|
||||
body = await request.body()
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.post(
|
||||
f"{_gpu_url()}/segment_field",
|
||||
content=body,
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
return Response(content=resp.content, status_code=resp.status_code,
|
||||
media_type="application/json")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}")
|
||||
|
||||
|
||||
@router.post("/gpu/segment_field/debug")
|
||||
async def gpu_segment_field_debug(request: Request):
|
||||
"""Proxy to GPU inference server — field segmentation with debug overlay."""
|
||||
import httpx
|
||||
|
||||
body = await request.body()
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.post(
|
||||
f"{_gpu_url()}/segment_field/debug",
|
||||
content=body,
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
return Response(content=resp.content, status_code=resp.status_code,
|
||||
media_type="application/json")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}")
|
||||
|
||||
@@ -60,9 +60,9 @@ def _resolve_video_path(video_path: str) -> str:
|
||||
@router.post("/run", response_model=RunResponse)
|
||||
def run_pipeline(req: RunRequest):
|
||||
"""Launch a detection pipeline run on a source chunk."""
|
||||
from detect import emit
|
||||
from detect.graph import get_pipeline
|
||||
from detect.state import DetectState
|
||||
from core.detect import emit
|
||||
from core.detect.graph import get_pipeline
|
||||
from core.detect.state import DetectState
|
||||
|
||||
local_path = _resolve_video_path(req.video_path)
|
||||
job_id = str(uuid.uuid4())
|
||||
@@ -79,7 +79,7 @@ def run_pipeline(req: RunRequest):
|
||||
|
||||
# Clear any stale events from a previous run with same job_id
|
||||
from core.events import _get_redis
|
||||
from detect.events import DETECT_EVENTS_PREFIX
|
||||
from core.detect.events import DETECT_EVENTS_PREFIX
|
||||
r = _get_redis()
|
||||
r.delete(f"{DETECT_EVENTS_PREFIX}:{job_id}")
|
||||
|
||||
@@ -97,7 +97,7 @@ def run_pipeline(req: RunRequest):
|
||||
source_asset_id=req.source_asset_id,
|
||||
)
|
||||
|
||||
from detect.graph import (
|
||||
from core.detect.graph import (
|
||||
PipelineCancelled, set_cancel_check, clear_cancel_check,
|
||||
init_pause, clear_pause,
|
||||
)
|
||||
@@ -117,7 +117,7 @@ def run_pipeline(req: RunRequest):
|
||||
emit.job_complete(job_id, {"status": "cancelled"})
|
||||
except Exception as e:
|
||||
logger.exception("Pipeline run %s failed: %s", job_id, e)
|
||||
from detect.graph import _node_states, NODES
|
||||
from core.detect.graph import _node_states, NODES
|
||||
if job_id in _node_states:
|
||||
states = _node_states[job_id]
|
||||
for node in reversed(NODES):
|
||||
@@ -145,7 +145,7 @@ def run_pipeline(req: RunRequest):
|
||||
@router.post("/stop/{job_id}")
|
||||
def stop_pipeline(job_id: str):
|
||||
"""Stop a running pipeline. Signals cancellation; the thread checks on next stage."""
|
||||
from detect import emit
|
||||
from core.detect import emit
|
||||
|
||||
if job_id not in _running_jobs:
|
||||
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
|
||||
@@ -158,7 +158,7 @@ def stop_pipeline(job_id: str):
|
||||
@router.post("/pause/{job_id}")
|
||||
def pause(job_id: str):
|
||||
"""Pause a running pipeline after the current stage completes."""
|
||||
from detect.graph import pause_pipeline
|
||||
from core.detect.graph import pause_pipeline
|
||||
|
||||
if job_id not in _running_jobs:
|
||||
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
|
||||
@@ -170,7 +170,7 @@ def pause(job_id: str):
|
||||
@router.post("/resume/{job_id}")
|
||||
def resume(job_id: str):
|
||||
"""Resume a paused pipeline."""
|
||||
from detect.graph import resume_pipeline
|
||||
from core.detect.graph import resume_pipeline
|
||||
|
||||
if job_id not in _running_jobs:
|
||||
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
|
||||
@@ -182,7 +182,7 @@ def resume(job_id: str):
|
||||
@router.post("/step/{job_id}")
|
||||
def step(job_id: str):
|
||||
"""Run one stage then pause again."""
|
||||
from detect.graph import step_pipeline
|
||||
from core.detect.graph import step_pipeline
|
||||
|
||||
if job_id not in _running_jobs:
|
||||
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
|
||||
@@ -194,7 +194,7 @@ def step(job_id: str):
|
||||
@router.post("/pause-after-stage/{job_id}")
|
||||
def toggle_pause_after_stage(job_id: str, enabled: bool = True):
|
||||
"""Toggle pause-after-each-stage mode."""
|
||||
from detect.graph import set_pause_after_stage
|
||||
from core.detect.graph import set_pause_after_stage
|
||||
|
||||
if job_id not in _running_jobs:
|
||||
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
|
||||
@@ -206,7 +206,7 @@ def toggle_pause_after_stage(job_id: str, enabled: bool = True):
|
||||
@router.get("/status/{job_id}")
|
||||
def pipeline_status(job_id: str):
|
||||
"""Get pipeline run status."""
|
||||
from detect.graph import is_paused
|
||||
from core.detect.graph import is_paused
|
||||
|
||||
running = job_id in _running_jobs
|
||||
paused = is_paused(job_id)
|
||||
@@ -224,11 +224,23 @@ def pipeline_status(job_id: str):
|
||||
return {"status": status, "job_id": job_id}
|
||||
|
||||
|
||||
@router.get("/timeline/{job_id}")
|
||||
def get_timeline_for_job(job_id: str):
|
||||
"""Get the timeline_id for a running or completed job."""
|
||||
from core.detect.checkpoint.runner_bridge import get_timeline_id
|
||||
|
||||
tid = get_timeline_id(job_id)
|
||||
if tid is None:
|
||||
raise HTTPException(status_code=404, detail=f"No timeline for job: {job_id}")
|
||||
|
||||
return {"timeline_id": tid, "job_id": job_id}
|
||||
|
||||
|
||||
@router.post("/clear/{job_id}")
|
||||
def clear_pipeline(job_id: str):
|
||||
"""Clear events for a job from Redis."""
|
||||
from core.events import _get_redis
|
||||
from detect.events import DETECT_EVENTS_PREFIX
|
||||
from core.detect.events import DETECT_EVENTS_PREFIX
|
||||
|
||||
r = _get_redis()
|
||||
r.delete(f"{DETECT_EVENTS_PREFIX}:{job_id}")
|
||||
|
||||
@@ -17,7 +17,7 @@ from fastapi import APIRouter
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from core.events import poll_events
|
||||
from detect.events import DETECT_EVENTS_PREFIX, TERMINAL_EVENTS
|
||||
from core.detect.events import DETECT_EVENTS_PREFIX, TERMINAL_EVENTS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,384 +0,0 @@
|
||||
"""
|
||||
GraphQL API using strawberry, served via FastAPI.
|
||||
|
||||
Primary API for MPR — all client interactions go through GraphQL.
|
||||
Uses core.db for data access.
|
||||
Types are generated from schema/ via modelgen — see api/schema/graphql.py.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
import strawberry
|
||||
from strawberry.schema.config import StrawberryConfig
|
||||
from strawberry.types import Info
|
||||
|
||||
from core.api.schema.graphql import (
|
||||
CancelResultType,
|
||||
ChunkJobType,
|
||||
ChunkOutputFileType,
|
||||
CreateChunkJobInput,
|
||||
CreateJobInput,
|
||||
DeleteResultType,
|
||||
MediaAssetType,
|
||||
ScanResultType,
|
||||
SystemStatusType,
|
||||
TranscodeJobType,
|
||||
TranscodePresetType,
|
||||
UpdateAssetInput,
|
||||
)
|
||||
from core.storage import BUCKET_IN, list_objects, upload_file
|
||||
|
||||
VIDEO_EXTS = {".mp4", ".mkv", ".avi", ".mov", ".webm", ".flv", ".wmv", ".m4v"}
|
||||
AUDIO_EXTS = {".mp3", ".wav", ".flac", ".aac", ".ogg", ".m4a"}
|
||||
MEDIA_EXTS = VIDEO_EXTS | AUDIO_EXTS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Queries
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class Query:
|
||||
@strawberry.field
|
||||
def assets(
|
||||
self,
|
||||
info: Info,
|
||||
status: Optional[str] = None,
|
||||
search: Optional[str] = None,
|
||||
) -> List[MediaAssetType]:
|
||||
from core.db import list_assets
|
||||
|
||||
return list_assets(status=status, search=search)
|
||||
|
||||
@strawberry.field
|
||||
def asset(self, info: Info, id: UUID) -> Optional[MediaAssetType]:
|
||||
from core.db import get_asset
|
||||
|
||||
try:
|
||||
return get_asset(id)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@strawberry.field
|
||||
def jobs(
|
||||
self,
|
||||
info: Info,
|
||||
status: Optional[str] = None,
|
||||
source_asset_id: Optional[UUID] = None,
|
||||
) -> List[TranscodeJobType]:
|
||||
from core.db import list_jobs
|
||||
|
||||
return list_jobs(status=status, source_asset_id=source_asset_id)
|
||||
|
||||
@strawberry.field
|
||||
def job(self, info: Info, id: UUID) -> Optional[TranscodeJobType]:
|
||||
from core.db import get_job
|
||||
|
||||
try:
|
||||
return get_job(id)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@strawberry.field
|
||||
def presets(self, info: Info) -> List[TranscodePresetType]:
|
||||
from core.db import list_presets
|
||||
|
||||
return list_presets()
|
||||
|
||||
@strawberry.field
|
||||
def system_status(self, info: Info) -> SystemStatusType:
|
||||
return SystemStatusType(status="ok", version="0.1.0")
|
||||
|
||||
@strawberry.field
|
||||
def chunk_output_files(self, info: Info, job_id: str) -> List[ChunkOutputFileType]:
|
||||
"""List output chunk files for a completed job from media/out/."""
|
||||
from pathlib import Path
|
||||
|
||||
media_out = os.environ.get("MEDIA_OUT_DIR", "/app/media/out")
|
||||
output_dir = Path(media_out) / "chunks" / job_id
|
||||
if not output_dir.is_dir():
|
||||
return []
|
||||
return [
|
||||
ChunkOutputFileType(
|
||||
key=f.name,
|
||||
size=f.stat().st_size,
|
||||
url=f"/media/out/chunks/{job_id}/{f.name}",
|
||||
)
|
||||
for f in sorted(output_dir.iterdir())
|
||||
if f.is_file()
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mutations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class Mutation:
|
||||
@strawberry.mutation
|
||||
def scan_media_folder(self, info: Info) -> ScanResultType:
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from core.db import create_asset, get_asset_filenames
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Sync local media/in/ files to MinIO (handles fresh installs / pruned volumes)
|
||||
local_media = Path("/app/media/in")
|
||||
if local_media.is_dir():
|
||||
existing_keys = {o["key"] for o in list_objects(BUCKET_IN)}
|
||||
for f in local_media.iterdir():
|
||||
if f.is_file() and f.suffix.lower() in MEDIA_EXTS:
|
||||
if f.name not in existing_keys:
|
||||
try:
|
||||
upload_file(str(f), BUCKET_IN, f.name)
|
||||
logger.info("Uploaded %s to MinIO", f.name)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to upload %s: %s", f.name, e)
|
||||
|
||||
objects = list_objects(BUCKET_IN, extensions=MEDIA_EXTS)
|
||||
existing = get_asset_filenames()
|
||||
|
||||
registered = []
|
||||
skipped = []
|
||||
|
||||
for obj in objects:
|
||||
if obj["filename"] in existing:
|
||||
skipped.append(obj["filename"])
|
||||
continue
|
||||
try:
|
||||
create_asset(
|
||||
filename=obj["filename"],
|
||||
file_path=obj["key"],
|
||||
file_size=obj["size"],
|
||||
)
|
||||
registered.append(obj["filename"])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return ScanResultType(
|
||||
found=len(objects),
|
||||
registered=len(registered),
|
||||
skipped=len(skipped),
|
||||
files=registered,
|
||||
)
|
||||
|
||||
@strawberry.mutation
|
||||
def create_job(self, info: Info, input: CreateJobInput) -> TranscodeJobType:
|
||||
from pathlib import Path
|
||||
|
||||
from core.db import create_job, get_asset, get_preset
|
||||
|
||||
try:
|
||||
source = get_asset(input.source_asset_id)
|
||||
except Exception:
|
||||
raise Exception("Source asset not found")
|
||||
|
||||
preset = None
|
||||
preset_snapshot = {}
|
||||
if input.preset_id:
|
||||
try:
|
||||
preset = get_preset(input.preset_id)
|
||||
preset_snapshot = {
|
||||
"name": preset.name,
|
||||
"container": preset.container,
|
||||
"video_codec": preset.video_codec,
|
||||
"audio_codec": preset.audio_codec,
|
||||
}
|
||||
except Exception:
|
||||
raise Exception("Preset not found")
|
||||
|
||||
if not preset and not input.trim_start and not input.trim_end:
|
||||
raise Exception("Must specify preset_id or trim_start/trim_end")
|
||||
|
||||
output_filename = input.output_filename
|
||||
if not output_filename:
|
||||
stem = Path(source.filename).stem
|
||||
ext = preset_snapshot.get("container", "mp4") if preset else "mp4"
|
||||
output_filename = f"{stem}_output.{ext}"
|
||||
|
||||
job = create_job(
|
||||
source_asset_id=source.id,
|
||||
preset_id=preset.id if preset else None,
|
||||
preset_snapshot=preset_snapshot,
|
||||
trim_start=input.trim_start,
|
||||
trim_end=input.trim_end,
|
||||
output_filename=output_filename,
|
||||
output_path=output_filename,
|
||||
priority=input.priority or 0,
|
||||
)
|
||||
|
||||
payload = {
|
||||
"source_key": source.file_path,
|
||||
"output_key": output_filename,
|
||||
"preset": preset_snapshot or None,
|
||||
"trim_start": input.trim_start,
|
||||
"trim_end": input.trim_end,
|
||||
"duration": source.duration,
|
||||
}
|
||||
|
||||
executor_mode = os.environ.get("MPR_EXECUTOR", "local")
|
||||
if executor_mode in ("lambda", "gcp"):
|
||||
from core.jobs.executor import get_executor
|
||||
|
||||
get_executor().run(
|
||||
job_type="transcode",
|
||||
job_id=str(job.id),
|
||||
payload=payload,
|
||||
)
|
||||
else:
|
||||
from core.jobs.task import run_job
|
||||
|
||||
result = run_job.delay(
|
||||
job_type="transcode",
|
||||
job_id=str(job.id),
|
||||
payload=payload,
|
||||
)
|
||||
job.celery_task_id = result.id
|
||||
job.save(update_fields=["celery_task_id"])
|
||||
|
||||
return job
|
||||
|
||||
@strawberry.mutation
|
||||
def cancel_job(self, info: Info, id: UUID) -> TranscodeJobType:
|
||||
from core.db import get_job, update_job
|
||||
|
||||
try:
|
||||
job = get_job(id)
|
||||
except Exception:
|
||||
raise Exception("Job not found")
|
||||
|
||||
if job.status not in ("pending", "processing"):
|
||||
raise Exception(f"Cannot cancel job with status: {job.status}")
|
||||
|
||||
return update_job(job, status="cancelled")
|
||||
|
||||
@strawberry.mutation
|
||||
def retry_job(self, info: Info, id: UUID) -> TranscodeJobType:
|
||||
from core.db import get_job, update_job
|
||||
|
||||
try:
|
||||
job = get_job(id)
|
||||
except Exception:
|
||||
raise Exception("Job not found")
|
||||
|
||||
if job.status != "failed":
|
||||
raise Exception("Only failed jobs can be retried")
|
||||
|
||||
return update_job(job, status="pending", progress=0, error_message=None)
|
||||
|
||||
@strawberry.mutation
|
||||
def update_asset(self, info: Info, id: UUID, input: UpdateAssetInput) -> MediaAssetType:
|
||||
from core.db import get_asset, update_asset
|
||||
|
||||
try:
|
||||
asset = get_asset(id)
|
||||
except Exception:
|
||||
raise Exception("Asset not found")
|
||||
|
||||
fields = {}
|
||||
if input.comments is not None:
|
||||
fields["comments"] = input.comments
|
||||
if input.tags is not None:
|
||||
fields["tags"] = input.tags
|
||||
|
||||
if fields:
|
||||
asset = update_asset(asset, **fields)
|
||||
|
||||
return asset
|
||||
|
||||
@strawberry.mutation
|
||||
def delete_asset(self, info: Info, id: UUID) -> DeleteResultType:
|
||||
from core.db import delete_asset, get_asset
|
||||
|
||||
try:
|
||||
asset = get_asset(id)
|
||||
delete_asset(asset)
|
||||
return DeleteResultType(ok=True)
|
||||
except Exception:
|
||||
raise Exception("Asset not found")
|
||||
|
||||
@strawberry.mutation
|
||||
def create_chunk_job(self, info: Info, input: CreateChunkJobInput) -> ChunkJobType:
|
||||
"""Create and dispatch a chunk pipeline job."""
|
||||
import uuid
|
||||
|
||||
from core.db import get_asset
|
||||
|
||||
try:
|
||||
source = get_asset(input.source_asset_id)
|
||||
except Exception:
|
||||
raise Exception("Source asset not found")
|
||||
|
||||
job_id = str(uuid.uuid4())
|
||||
|
||||
payload = {
|
||||
"source_key": source.file_path,
|
||||
"chunk_duration": input.chunk_duration,
|
||||
"num_workers": input.num_workers,
|
||||
"max_retries": input.max_retries,
|
||||
"processor_type": input.processor_type,
|
||||
"start_time": input.start_time,
|
||||
"end_time": input.end_time,
|
||||
}
|
||||
|
||||
executor_mode = os.environ.get("MPR_EXECUTOR", "local")
|
||||
celery_task_id = None
|
||||
|
||||
if executor_mode in ("lambda", "gcp"):
|
||||
from core.jobs.executor import get_executor
|
||||
|
||||
get_executor().run(
|
||||
job_type="chunk",
|
||||
job_id=job_id,
|
||||
payload=payload,
|
||||
)
|
||||
else:
|
||||
from core.jobs.task import run_job
|
||||
|
||||
result = run_job.delay(
|
||||
job_type="chunk",
|
||||
job_id=job_id,
|
||||
payload=payload,
|
||||
)
|
||||
celery_task_id = result.id
|
||||
|
||||
return ChunkJobType(
|
||||
id=uuid.UUID(job_id),
|
||||
source_asset_id=input.source_asset_id,
|
||||
chunk_duration=input.chunk_duration,
|
||||
num_workers=input.num_workers,
|
||||
max_retries=input.max_retries,
|
||||
processor_type=input.processor_type,
|
||||
status="pending",
|
||||
progress=0.0,
|
||||
priority=input.priority,
|
||||
celery_task_id=celery_task_id,
|
||||
)
|
||||
|
||||
@strawberry.mutation
|
||||
def cancel_chunk_job(self, info: Info, celery_task_id: str) -> CancelResultType:
|
||||
"""Cancel a running chunk job by revoking its Celery task."""
|
||||
try:
|
||||
from admin.mpr.celery import app as celery_app
|
||||
|
||||
celery_app.control.revoke(celery_task_id, terminate=True, signal="SIGTERM")
|
||||
return CancelResultType(ok=True, message="Task revoked")
|
||||
except Exception as e:
|
||||
return CancelResultType(ok=False, message=str(e))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
schema = strawberry.Schema(
|
||||
query=Query,
|
||||
mutation=Mutation,
|
||||
config=StrawberryConfig(auto_camel_case=False),
|
||||
)
|
||||
@@ -1,48 +1,38 @@
|
||||
"""
|
||||
MPR FastAPI Application
|
||||
|
||||
Serves GraphQL API and Lambda callback endpoint.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, Header, HTTPException
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from strawberry.fastapi import GraphQLRouter
|
||||
|
||||
from core.api.chunker_sse import router as chunker_router
|
||||
from core.api.detect import router as detect_router
|
||||
from core.api.graphql import schema as graphql_schema
|
||||
|
||||
CALLBACK_API_KEY = os.environ.get("CALLBACK_API_KEY", "")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app):
|
||||
# Create/reset DB tables on startup
|
||||
from core.db.connection import create_tables
|
||||
from core.db.seed import seed_profiles
|
||||
create_tables()
|
||||
seed_profiles()
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="MPR API",
|
||||
description="Media Processor — GraphQL API",
|
||||
version="0.1.0",
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["http://mpr.local.ar", "http://k8s.mpr.local.ar", "http://localhost:5173"],
|
||||
@@ -51,13 +41,6 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# GraphQL
|
||||
graphql_router = GraphQLRouter(schema=graphql_schema, graphql_ide="graphiql")
|
||||
app.include_router(graphql_router, prefix="/graphql")
|
||||
|
||||
# Chunker SSE
|
||||
app.include_router(chunker_router)
|
||||
|
||||
# Detection API (sources, run, SSE, replay, config)
|
||||
app.include_router(detect_router)
|
||||
|
||||
@@ -69,48 +52,7 @@ def health():
|
||||
|
||||
@app.get("/")
|
||||
def root():
|
||||
"""API root."""
|
||||
return {
|
||||
"name": "MPR API",
|
||||
"version": "0.1.0",
|
||||
"graphql": "/graphql",
|
||||
}
|
||||
|
||||
|
||||
@app.post("/api/jobs/{job_id}/callback")
|
||||
def job_callback(
|
||||
job_id: UUID,
|
||||
payload: dict,
|
||||
x_api_key: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
Callback endpoint for Lambda to report job completion.
|
||||
Protected by API key.
|
||||
"""
|
||||
if CALLBACK_API_KEY and x_api_key != CALLBACK_API_KEY:
|
||||
raise HTTPException(status_code=403, detail="Invalid API key")
|
||||
|
||||
from django.utils import timezone
|
||||
|
||||
from core.db import get_job, update_job
|
||||
|
||||
try:
|
||||
job = get_job(job_id)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
status = payload.get("status", "failed")
|
||||
fields = {
|
||||
"status": status,
|
||||
"progress": 100.0 if status == "completed" else job.progress,
|
||||
}
|
||||
|
||||
if payload.get("error"):
|
||||
fields["error_message"] = payload["error"]
|
||||
|
||||
if status in ("completed", "failed"):
|
||||
fields["completed_at"] = timezone.now()
|
||||
|
||||
update_job(job, **fields)
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
@@ -1,226 +0,0 @@
|
||||
"""
|
||||
Strawberry Types - GENERATED FILE
|
||||
|
||||
Do not edit directly. Regenerate using modelgen.
|
||||
"""
|
||||
|
||||
import strawberry
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
from strawberry.scalars import JSON
|
||||
|
||||
|
||||
@strawberry.enum
|
||||
class AssetStatus(Enum):
|
||||
PENDING = "pending"
|
||||
READY = "ready"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
@strawberry.enum
|
||||
class JobStatus(Enum):
|
||||
PENDING = "pending"
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class MediaAssetType:
|
||||
"""A video/audio file registered in the system."""
|
||||
|
||||
id: Optional[UUID] = None
|
||||
filename: Optional[str] = None
|
||||
file_path: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
file_size: Optional[float] = None
|
||||
duration: Optional[float] = None
|
||||
video_codec: Optional[str] = None
|
||||
audio_codec: Optional[str] = None
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
framerate: Optional[float] = None
|
||||
bitrate: Optional[int] = None
|
||||
properties: Optional[JSON] = None
|
||||
comments: Optional[str] = None
|
||||
tags: Optional[List[str]] = None
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class TranscodePresetType:
|
||||
"""A reusable transcoding configuration (like Handbrake presets)."""
|
||||
|
||||
id: Optional[UUID] = None
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
is_builtin: Optional[bool] = None
|
||||
container: Optional[str] = None
|
||||
video_codec: Optional[str] = None
|
||||
video_bitrate: Optional[str] = None
|
||||
video_crf: Optional[int] = None
|
||||
video_preset: Optional[str] = None
|
||||
resolution: Optional[str] = None
|
||||
framerate: Optional[float] = None
|
||||
audio_codec: Optional[str] = None
|
||||
audio_bitrate: Optional[str] = None
|
||||
audio_channels: Optional[int] = None
|
||||
audio_samplerate: Optional[int] = None
|
||||
extra_args: Optional[List[str]] = None
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class TranscodeJobType:
|
||||
"""A transcoding or trimming job in the queue."""
|
||||
|
||||
id: Optional[UUID] = None
|
||||
source_asset_id: Optional[UUID] = None
|
||||
preset_id: Optional[UUID] = None
|
||||
preset_snapshot: Optional[JSON] = None
|
||||
trim_start: Optional[float] = None
|
||||
trim_end: Optional[float] = None
|
||||
output_filename: Optional[str] = None
|
||||
output_path: Optional[str] = None
|
||||
output_asset_id: Optional[UUID] = None
|
||||
status: Optional[str] = None
|
||||
progress: Optional[float] = None
|
||||
current_frame: Optional[int] = None
|
||||
current_time: Optional[float] = None
|
||||
speed: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
celery_task_id: Optional[str] = None
|
||||
execution_arn: Optional[str] = None
|
||||
priority: Optional[int] = None
|
||||
created_at: Optional[datetime] = None
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
|
||||
@strawberry.input
|
||||
class CreateJobInput:
|
||||
"""Request body for creating a transcode/trim job."""
|
||||
|
||||
source_asset_id: UUID
|
||||
preset_id: Optional[UUID] = None
|
||||
trim_start: Optional[float] = None
|
||||
trim_end: Optional[float] = None
|
||||
output_filename: Optional[str] = None
|
||||
priority: int = 0
|
||||
|
||||
|
||||
@strawberry.input
|
||||
class UpdateAssetInput:
|
||||
"""Request body for updating asset metadata."""
|
||||
|
||||
comments: Optional[str] = None
|
||||
tags: Optional[List[str]] = None
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class SystemStatusType:
|
||||
"""System status response."""
|
||||
|
||||
status: Optional[str] = None
|
||||
version: Optional[str] = None
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class ScanResultType:
|
||||
"""Result of scanning the media input bucket."""
|
||||
|
||||
found: Optional[int] = None
|
||||
registered: Optional[int] = None
|
||||
skipped: Optional[int] = None
|
||||
files: Optional[List[str]] = None
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class DeleteResultType:
|
||||
"""Result of a delete operation."""
|
||||
|
||||
ok: Optional[bool] = None
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class WorkerStatusType:
|
||||
"""Worker health and capabilities."""
|
||||
|
||||
available: Optional[bool] = None
|
||||
active_jobs: Optional[int] = None
|
||||
supported_codecs: Optional[List[str]] = None
|
||||
gpu_available: Optional[bool] = None
|
||||
|
||||
|
||||
@strawberry.enum
|
||||
class ChunkJobStatus(Enum):
|
||||
PENDING = "pending"
|
||||
CHUNKING = "chunking"
|
||||
PROCESSING = "processing"
|
||||
COLLECTING = "collecting"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class ChunkJobType:
|
||||
"""A chunk pipeline job."""
|
||||
|
||||
id: Optional[UUID] = None
|
||||
source_asset_id: Optional[UUID] = None
|
||||
chunk_duration: Optional[float] = None
|
||||
num_workers: Optional[int] = None
|
||||
max_retries: Optional[int] = None
|
||||
processor_type: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
progress: Optional[float] = None
|
||||
total_chunks: Optional[int] = None
|
||||
processed_chunks: Optional[int] = None
|
||||
failed_chunks: Optional[int] = None
|
||||
retry_count: Optional[int] = None
|
||||
error_message: Optional[str] = None
|
||||
throughput_mbps: Optional[float] = None
|
||||
elapsed_seconds: Optional[float] = None
|
||||
celery_task_id: Optional[str] = None
|
||||
priority: Optional[int] = None
|
||||
created_at: Optional[datetime] = None
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
|
||||
@strawberry.input
|
||||
class CreateChunkJobInput:
|
||||
"""Request body for creating a chunk pipeline job."""
|
||||
|
||||
source_asset_id: UUID
|
||||
chunk_duration: float = 10.0
|
||||
num_workers: int = 4
|
||||
max_retries: int = 3
|
||||
processor_type: str = "ffmpeg"
|
||||
priority: int = 0
|
||||
start_time: Optional[float] = None
|
||||
end_time: Optional[float] = None
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class CancelResultType:
|
||||
"""Result of cancelling a chunk job."""
|
||||
|
||||
ok: bool = False
|
||||
message: Optional[str] = None
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class ChunkOutputFileType:
|
||||
"""A chunk output file in S3/MinIO with presigned download URL."""
|
||||
|
||||
key: str
|
||||
size: int = 0
|
||||
url: str = ""
|
||||
@@ -1,64 +0,0 @@
|
||||
"""
|
||||
Chunker pipeline — splits files into chunks, processes concurrently, reassembles in order.
|
||||
|
||||
Public API:
|
||||
Pipeline — orchestrates the full pipeline
|
||||
PipelineResult — aggregate result dataclass
|
||||
Chunker — file → Chunk generator
|
||||
ChunkQueue — bounded thread-safe queue
|
||||
WorkerPool — manages N worker threads
|
||||
ResultCollector — heapq-based ordered reassembly
|
||||
"""
|
||||
|
||||
from .chunker import Chunker
|
||||
from .collector import ResultCollector
|
||||
from .exceptions import (
|
||||
ChunkChecksumError,
|
||||
ChunkError,
|
||||
ChunkReadError,
|
||||
PipelineError,
|
||||
ProcessingError,
|
||||
ProcessorFailureError,
|
||||
ProcessorTimeoutError,
|
||||
ReassemblyError,
|
||||
)
|
||||
from .models import Chunk, ChunkResult, PipelineResult
|
||||
from .pipeline import Pipeline
|
||||
from .pool import WorkerPool
|
||||
from .processor import (
|
||||
ChecksumProcessor,
|
||||
CompositeProcessor,
|
||||
FFmpegExtractProcessor,
|
||||
Processor,
|
||||
SimulatedDecodeProcessor,
|
||||
)
|
||||
from .queue import ChunkQueue
|
||||
|
||||
__all__ = [
|
||||
# Core
|
||||
"Pipeline",
|
||||
"PipelineResult",
|
||||
# Components
|
||||
"Chunker",
|
||||
"ChunkQueue",
|
||||
"WorkerPool",
|
||||
"ResultCollector",
|
||||
# Models
|
||||
"Chunk",
|
||||
"ChunkResult",
|
||||
# Processors
|
||||
"Processor",
|
||||
"ChecksumProcessor",
|
||||
"SimulatedDecodeProcessor",
|
||||
"CompositeProcessor",
|
||||
"FFmpegExtractProcessor",
|
||||
# Exceptions
|
||||
"PipelineError",
|
||||
"ChunkError",
|
||||
"ChunkReadError",
|
||||
"ChunkChecksumError",
|
||||
"ProcessingError",
|
||||
"ProcessorFailureError",
|
||||
"ProcessorTimeoutError",
|
||||
"ReassemblyError",
|
||||
]
|
||||
@@ -1,101 +0,0 @@
|
||||
"""
|
||||
Chunker — probes a media file and yields time-based Chunk objects.
|
||||
|
||||
Demonstrates:
|
||||
- Function parameters and defaults (Interview Topic 1)
|
||||
- List comprehensions and efficient iteration / generators (Interview Topic 3)
|
||||
"""
|
||||
|
||||
import math
|
||||
import os
|
||||
from typing import Generator
|
||||
|
||||
from core.ffmpeg.probe import probe_file
|
||||
|
||||
from .exceptions import ChunkReadError
|
||||
from .models import Chunk
|
||||
|
||||
|
||||
class Chunker:
|
||||
"""
|
||||
Splits a media file into time-based chunks via a generator.
|
||||
|
||||
Uses FFmpeg probe to get duration, then yields Chunk objects
|
||||
representing time segments (no data read — extraction happens in the processor).
|
||||
|
||||
Args:
|
||||
file_path: Path to the source media file
|
||||
chunk_duration: Duration of each chunk in seconds (default: 10.0)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
chunk_duration: float = 10.0,
|
||||
start_time: float | None = None,
|
||||
end_time: float | None = None,
|
||||
):
|
||||
if not os.path.isfile(file_path):
|
||||
raise ChunkReadError(f"File not found: {file_path}")
|
||||
if chunk_duration <= 0:
|
||||
raise ValueError("chunk_duration must be positive")
|
||||
|
||||
self.file_path = file_path
|
||||
self.chunk_duration = chunk_duration
|
||||
self.file_size = os.path.getsize(file_path)
|
||||
full_duration = self._probe_duration()
|
||||
|
||||
# Apply time range
|
||||
self.range_start = max(start_time or 0.0, 0.0)
|
||||
self.range_end = min(end_time or full_duration, full_duration)
|
||||
if self.range_start >= self.range_end:
|
||||
raise ValueError(
|
||||
f"Invalid range: start={self.range_start} >= end={self.range_end}"
|
||||
)
|
||||
self.source_duration = self.range_end - self.range_start
|
||||
|
||||
def _probe_duration(self) -> float:
|
||||
"""Get source file duration via FFmpeg probe."""
|
||||
try:
|
||||
result = probe_file(self.file_path)
|
||||
if result.duration is None or result.duration <= 0:
|
||||
raise ChunkReadError(
|
||||
f"Cannot determine duration for {self.file_path}"
|
||||
)
|
||||
return result.duration
|
||||
except ChunkReadError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise ChunkReadError(
|
||||
f"Failed to probe {self.file_path}: {e}"
|
||||
) from e
|
||||
|
||||
@property
|
||||
def expected_chunks(self) -> int:
|
||||
"""Calculate expected number of chunks (last chunk may be shorter)."""
|
||||
if self.source_duration <= 0:
|
||||
return 0
|
||||
return math.ceil(self.source_duration / self.chunk_duration)
|
||||
|
||||
def chunks(self) -> Generator[Chunk, None, None]:
|
||||
"""
|
||||
Yield Chunk objects representing time segments of the source file.
|
||||
|
||||
Generator-based: chunks are yielded on demand.
|
||||
Each chunk defines a time range — actual extraction is done by the processor.
|
||||
"""
|
||||
total = self.expected_chunks
|
||||
for sequence in range(total):
|
||||
start_time = self.range_start + sequence * self.chunk_duration
|
||||
end_time = min(
|
||||
start_time + self.chunk_duration, self.range_end
|
||||
)
|
||||
duration = end_time - start_time
|
||||
|
||||
yield Chunk(
|
||||
sequence=sequence,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
source_path=self.file_path,
|
||||
duration=duration,
|
||||
)
|
||||
@@ -1,98 +0,0 @@
|
||||
"""
|
||||
ResultCollector — reassembles chunk results in sequence order using a min-heap.
|
||||
|
||||
Demonstrates:
|
||||
- Algorithms and sorting (Interview Topic 6) — heapq for ordered reassembly
|
||||
- Core data structures (Interview Topic 5) — heap, deque
|
||||
"""
|
||||
|
||||
import heapq
|
||||
from collections import deque
|
||||
from typing import List
|
||||
|
||||
from .exceptions import ReassemblyError
|
||||
from .models import ChunkResult
|
||||
|
||||
|
||||
class ResultCollector:
|
||||
"""
|
||||
Receives ChunkResults out of order, emits them in sequence order.
|
||||
|
||||
Uses a min-heap keyed on sequence number. Only emits a chunk when
|
||||
all prior sequences have been accounted for.
|
||||
|
||||
Args:
|
||||
total_chunks: Expected total number of chunks
|
||||
"""
|
||||
|
||||
def __init__(self, total_chunks: int):
|
||||
self.total_chunks = total_chunks
|
||||
self._heap: List[tuple[int, ChunkResult]] = []
|
||||
self._next_sequence = 0
|
||||
self._emitted: List[ChunkResult] = []
|
||||
self._seen_sequences: set[int] = set()
|
||||
# Sliding window for throughput calculation
|
||||
self._recent_times: deque[float] = deque(maxlen=50)
|
||||
|
||||
def add(self, result: ChunkResult) -> List[ChunkResult]:
|
||||
"""
|
||||
Add a result and return any newly emittable results in order.
|
||||
|
||||
Args:
|
||||
result: A ChunkResult (may arrive out of order)
|
||||
|
||||
Returns:
|
||||
List of results that can now be emitted in sequence order
|
||||
(may be empty if we're still waiting for earlier sequences)
|
||||
|
||||
Raises:
|
||||
ReassemblyError: If a duplicate sequence is received
|
||||
"""
|
||||
if result.sequence in self._seen_sequences:
|
||||
raise ReassemblyError(
|
||||
f"Duplicate sequence number: {result.sequence}"
|
||||
)
|
||||
self._seen_sequences.add(result.sequence)
|
||||
|
||||
# Track processing time for throughput
|
||||
if result.processing_time > 0:
|
||||
self._recent_times.append(result.processing_time)
|
||||
|
||||
# Push to min-heap
|
||||
heapq.heappush(self._heap, (result.sequence, result))
|
||||
|
||||
# Emit all consecutive results starting from _next_sequence
|
||||
newly_emitted = []
|
||||
while self._heap and self._heap[0][0] == self._next_sequence:
|
||||
_, emitted_result = heapq.heappop(self._heap)
|
||||
self._emitted.append(emitted_result)
|
||||
newly_emitted.append(emitted_result)
|
||||
self._next_sequence += 1
|
||||
|
||||
return newly_emitted
|
||||
|
||||
@property
|
||||
def is_complete(self) -> bool:
|
||||
"""True if all expected chunks have been emitted in order."""
|
||||
return self._next_sequence == self.total_chunks
|
||||
|
||||
@property
|
||||
def buffered_count(self) -> int:
|
||||
"""Number of results waiting in the heap (arrived out of order)."""
|
||||
return len(self._heap)
|
||||
|
||||
@property
|
||||
def emitted_count(self) -> int:
|
||||
"""Number of results emitted in sequence order."""
|
||||
return len(self._emitted)
|
||||
|
||||
@property
|
||||
def avg_processing_time(self) -> float:
|
||||
"""Average processing time from recent results (sliding window)."""
|
||||
if not self._recent_times:
|
||||
return 0.0
|
||||
return sum(self._recent_times) / len(self._recent_times)
|
||||
|
||||
def get_ordered_results(self) -> List[ChunkResult]:
|
||||
"""Get all emitted results in sequence order."""
|
||||
return list(self._emitted)
|
||||
@@ -1,64 +0,0 @@
|
||||
"""
|
||||
Chunker exception hierarchy.
|
||||
|
||||
Demonstrates: Managing exceptions and writing resilient code (Interview Topic 7).
|
||||
"""
|
||||
|
||||
|
||||
class PipelineError(Exception):
|
||||
"""Base exception for all chunker pipeline errors."""
|
||||
pass
|
||||
|
||||
|
||||
class ChunkError(PipelineError):
|
||||
"""Errors related to chunk creation or validation."""
|
||||
pass
|
||||
|
||||
|
||||
class ChunkReadError(ChunkError):
|
||||
"""Failed to read chunk data from source file."""
|
||||
pass
|
||||
|
||||
|
||||
class ChunkChecksumError(ChunkError):
|
||||
"""Chunk data integrity validation failed."""
|
||||
|
||||
def __init__(self, sequence: int, expected: str, actual: str):
|
||||
self.sequence = sequence
|
||||
self.expected = expected
|
||||
self.actual = actual
|
||||
super().__init__(
|
||||
f"Chunk {sequence}: checksum mismatch "
|
||||
f"(expected={expected}, actual={actual})"
|
||||
)
|
||||
|
||||
|
||||
class ProcessingError(PipelineError):
|
||||
"""Errors during chunk processing by workers."""
|
||||
pass
|
||||
|
||||
|
||||
class ProcessorTimeoutError(ProcessingError):
|
||||
"""Processor exceeded allowed time for a chunk."""
|
||||
|
||||
def __init__(self, sequence: int, timeout: float):
|
||||
self.sequence = sequence
|
||||
self.timeout = timeout
|
||||
super().__init__(f"Chunk {sequence}: processor timed out after {timeout}s")
|
||||
|
||||
|
||||
class ProcessorFailureError(ProcessingError):
|
||||
"""Processor failed to process a chunk after all retries."""
|
||||
|
||||
def __init__(self, sequence: int, retries: int, original_error: Exception):
|
||||
self.sequence = sequence
|
||||
self.retries = retries
|
||||
self.original_error = original_error
|
||||
super().__init__(
|
||||
f"Chunk {sequence}: failed after {retries} retries — {original_error}"
|
||||
)
|
||||
|
||||
|
||||
class ReassemblyError(PipelineError):
|
||||
"""Errors during result collection and ordering."""
|
||||
pass
|
||||
@@ -1,54 +0,0 @@
|
||||
"""
|
||||
Internal data models for the chunker pipeline.
|
||||
|
||||
These are pipeline-internal dataclasses, not schema models.
|
||||
Schema-level ChunkJob is in core/schema/models/jobs.py.
|
||||
|
||||
Demonstrates: Core data structures (Interview Topic 5).
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class Chunk:
|
||||
"""A time-based segment of the source media file."""
|
||||
|
||||
sequence: int
|
||||
start_time: float # seconds
|
||||
end_time: float # seconds
|
||||
source_path: str # path to source file
|
||||
duration: float # end_time - start_time
|
||||
checksum: str = "" # computed after extraction
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChunkResult:
|
||||
"""Result of processing a single chunk."""
|
||||
|
||||
sequence: int
|
||||
success: bool
|
||||
checksum_valid: bool = True
|
||||
processing_time: float = 0.0
|
||||
error: Optional[str] = None
|
||||
retries: int = 0
|
||||
worker_id: Optional[str] = None
|
||||
output_file: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineResult:
|
||||
"""Aggregate result of the entire pipeline run."""
|
||||
|
||||
total_chunks: int = 0
|
||||
processed: int = 0
|
||||
failed: int = 0
|
||||
retries: int = 0
|
||||
elapsed_time: float = 0.0
|
||||
throughput_mbps: float = 0.0
|
||||
worker_stats: Dict[str, Any] = field(default_factory=dict)
|
||||
errors: List[str] = field(default_factory=list)
|
||||
chunks_in_order: bool = True
|
||||
output_dir: Optional[str] = None
|
||||
chunk_files: List[str] = field(default_factory=list)
|
||||
@@ -1,279 +0,0 @@
|
||||
"""
|
||||
Pipeline — orchestrates the entire chunker pipeline.
|
||||
|
||||
Wires: Chunker → ChunkQueue → WorkerPool → ResultCollector → PipelineResult
|
||||
|
||||
Demonstrates:
|
||||
- Function parameters and defaults (Interview Topic 1) — configurable pipeline
|
||||
- Concurrency (Interview Topic 2) — producer thread + worker pool
|
||||
- OOP design (Interview Topic 4) — composition of pipeline components
|
||||
- Exception handling (Interview Topic 7) — graceful error propagation
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from .chunker import Chunker
|
||||
from .collector import ResultCollector
|
||||
from .exceptions import PipelineError
|
||||
from .models import PipelineResult
|
||||
from .pool import WorkerPool
|
||||
from .queue import ChunkQueue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Pipeline:
|
||||
"""
|
||||
Orchestrates the chunk processing pipeline.
|
||||
|
||||
The pipeline runs in three stages:
|
||||
1. Producer thread: Chunker probes file → pushes time-based chunks to ChunkQueue
|
||||
2. Worker pool: N workers pull from queue → extract mp4 segments → emit results
|
||||
3. Collector: ResultCollector reassembles results in sequence order
|
||||
|
||||
Args:
|
||||
source: Path to the source media file
|
||||
chunk_duration: Duration of each chunk in seconds (default: 10.0)
|
||||
num_workers: Number of concurrent worker threads (default: 4)
|
||||
max_retries: Max retry attempts per chunk (default: 3)
|
||||
processor_type: Processor to use — "ffmpeg", "checksum", "simulated_decode", "composite"
|
||||
queue_size: Max chunks buffered in queue (default: 10)
|
||||
event_callback: Optional callback for real-time events
|
||||
output_dir: Directory for output chunk files (required for "ffmpeg" processor)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source: str,
|
||||
chunk_duration: float = 10.0,
|
||||
num_workers: int = 4,
|
||||
max_retries: int = 3,
|
||||
processor_type: str = "checksum",
|
||||
queue_size: int = 10,
|
||||
event_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
start_time: Optional[float] = None,
|
||||
end_time: Optional[float] = None,
|
||||
):
|
||||
self.source = source
|
||||
self.chunk_duration = chunk_duration
|
||||
self.num_workers = num_workers
|
||||
self.max_retries = max_retries
|
||||
self.processor_type = processor_type
|
||||
self.queue_size = queue_size
|
||||
self.event_callback = event_callback
|
||||
self.output_dir = output_dir
|
||||
self.start_time = start_time
|
||||
self.end_time = end_time
|
||||
|
||||
def _emit(self, event_type: str, data: Dict[str, Any]) -> None:
|
||||
"""Emit an event if callback is registered."""
|
||||
if self.event_callback:
|
||||
self.event_callback(event_type, data)
|
||||
|
||||
def _produce_chunks(
|
||||
self, chunker: Chunker, chunk_queue: ChunkQueue
|
||||
) -> None:
|
||||
"""Producer thread: probe file and enqueue time-based chunks."""
|
||||
try:
|
||||
for chunk in chunker.chunks():
|
||||
chunk_queue.put(chunk, timeout=30.0)
|
||||
self._emit("chunk_queued", {
|
||||
"sequence": chunk.sequence,
|
||||
"start_time": chunk.start_time,
|
||||
"end_time": chunk.end_time,
|
||||
"duration": chunk.duration,
|
||||
"queue_size": chunk_queue.qsize(),
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Producer error: {e}")
|
||||
self._emit("producer_error", {"error": str(e)})
|
||||
finally:
|
||||
chunk_queue.close()
|
||||
|
||||
def _monitor_progress(
|
||||
self, start_time: float, file_size: int, stop_event: threading.Event
|
||||
) -> None:
|
||||
"""Monitor thread: emit pipeline_progress every 500ms."""
|
||||
while not stop_event.is_set():
|
||||
elapsed = time.monotonic() - start_time
|
||||
mb = file_size / (1024 * 1024)
|
||||
self._emit("pipeline_progress", {
|
||||
"elapsed": round(elapsed, 2),
|
||||
"throughput_mbps": round(mb / elapsed, 2) if elapsed > 0 else 0,
|
||||
})
|
||||
stop_event.wait(0.5)
|
||||
|
||||
def _write_manifest(
|
||||
self, result: PipelineResult, source_duration: float
|
||||
) -> None:
|
||||
"""Write manifest.json to output_dir with segment metadata."""
|
||||
if not self.output_dir:
|
||||
return
|
||||
|
||||
manifest = {
|
||||
"source": self.source,
|
||||
"source_duration": source_duration,
|
||||
"chunk_duration": self.chunk_duration,
|
||||
"total_chunks": result.total_chunks,
|
||||
"processed": result.processed,
|
||||
"failed": result.failed,
|
||||
"elapsed_time": result.elapsed_time,
|
||||
"throughput_mbps": result.throughput_mbps,
|
||||
"segments": [
|
||||
{
|
||||
"sequence": i,
|
||||
"file": f"chunk_{i:04d}.mp4",
|
||||
"start": i * self.chunk_duration,
|
||||
"end": min(
|
||||
(i + 1) * self.chunk_duration, source_duration
|
||||
),
|
||||
}
|
||||
for i in range(result.total_chunks)
|
||||
if i < result.total_chunks
|
||||
],
|
||||
}
|
||||
|
||||
manifest_path = Path(self.output_dir) / "manifest.json"
|
||||
manifest_path.write_text(json.dumps(manifest, indent=2))
|
||||
logger.info(f"Manifest written to {manifest_path}")
|
||||
|
||||
def run(self) -> PipelineResult:
|
||||
"""
|
||||
Execute the full pipeline.
|
||||
|
||||
Returns:
|
||||
PipelineResult with aggregate stats
|
||||
|
||||
Raises:
|
||||
PipelineError: If the pipeline fails catastrophically
|
||||
"""
|
||||
start_time = time.monotonic()
|
||||
self._emit("pipeline_start", {
|
||||
"source": self.source,
|
||||
"chunk_duration": self.chunk_duration,
|
||||
"num_workers": self.num_workers,
|
||||
"processor_type": self.processor_type,
|
||||
})
|
||||
|
||||
try:
|
||||
# Stage 1: Set up chunker (probes file for duration)
|
||||
chunker = Chunker(
|
||||
self.source,
|
||||
self.chunk_duration,
|
||||
start_time=self.start_time,
|
||||
end_time=self.end_time,
|
||||
)
|
||||
total_chunks = chunker.expected_chunks
|
||||
|
||||
if total_chunks == 0:
|
||||
self._emit("pipeline_complete", {"total_chunks": 0})
|
||||
return PipelineResult(chunks_in_order=True)
|
||||
|
||||
self._emit("pipeline_info", {
|
||||
"file_size": chunker.file_size,
|
||||
"source_duration": chunker.source_duration,
|
||||
"total_chunks": total_chunks,
|
||||
})
|
||||
|
||||
# Stage 2: Set up queue and worker pool
|
||||
chunk_queue = ChunkQueue(maxsize=self.queue_size)
|
||||
pool = WorkerPool(
|
||||
num_workers=self.num_workers,
|
||||
chunk_queue=chunk_queue,
|
||||
processor_type=self.processor_type,
|
||||
max_retries=self.max_retries,
|
||||
event_callback=self.event_callback,
|
||||
output_dir=self.output_dir,
|
||||
)
|
||||
|
||||
# Stage 3: Start workers, monitor, then produce chunks
|
||||
pool.start()
|
||||
|
||||
monitor_stop = threading.Event()
|
||||
monitor = threading.Thread(
|
||||
target=self._monitor_progress,
|
||||
args=(start_time, chunker.file_size, monitor_stop),
|
||||
name="progress-monitor",
|
||||
daemon=True,
|
||||
)
|
||||
monitor.start()
|
||||
|
||||
producer = threading.Thread(
|
||||
target=self._produce_chunks,
|
||||
args=(chunker, chunk_queue),
|
||||
name="chunk-producer",
|
||||
daemon=True,
|
||||
)
|
||||
producer.start()
|
||||
|
||||
# Stage 4: Wait for all workers to finish
|
||||
all_results = pool.wait()
|
||||
producer.join(timeout=5.0)
|
||||
|
||||
# Stop monitor
|
||||
monitor_stop.set()
|
||||
monitor.join(timeout=2.0)
|
||||
|
||||
# Stage 5: Collect results in order
|
||||
collector = ResultCollector(total_chunks)
|
||||
for r in all_results:
|
||||
collector.add(r)
|
||||
self._emit("chunk_collected", {
|
||||
"sequence": r.sequence,
|
||||
"success": r.success,
|
||||
"buffered": collector.buffered_count,
|
||||
"emitted": collector.emitted_count,
|
||||
})
|
||||
|
||||
# Build result
|
||||
elapsed = time.monotonic() - start_time
|
||||
file_size_mb = chunker.file_size / (1024 * 1024)
|
||||
throughput = file_size_mb / elapsed if elapsed > 0 else 0.0
|
||||
|
||||
failed_results = [r for r in all_results if not r.success]
|
||||
total_retries = sum(r.retries for r in all_results)
|
||||
chunk_files = [
|
||||
r.output_file for r in all_results
|
||||
if r.success and r.output_file
|
||||
]
|
||||
|
||||
result = PipelineResult(
|
||||
total_chunks=total_chunks,
|
||||
processed=len(all_results),
|
||||
failed=len(failed_results),
|
||||
retries=total_retries,
|
||||
elapsed_time=elapsed,
|
||||
throughput_mbps=throughput,
|
||||
worker_stats=pool.get_worker_stats(),
|
||||
errors=[r.error for r in failed_results if r.error],
|
||||
chunks_in_order=collector.is_complete,
|
||||
output_dir=self.output_dir,
|
||||
chunk_files=chunk_files,
|
||||
)
|
||||
|
||||
# Write manifest if output_dir is set
|
||||
self._write_manifest(result, chunker.source_duration)
|
||||
|
||||
pool.shutdown()
|
||||
|
||||
self._emit("pipeline_complete", {
|
||||
"total_chunks": result.total_chunks,
|
||||
"processed": result.processed,
|
||||
"failed": result.failed,
|
||||
"elapsed": result.elapsed_time,
|
||||
"throughput_mbps": result.throughput_mbps,
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
except PipelineError:
|
||||
raise
|
||||
except Exception as e:
|
||||
self._emit("pipeline_error", {"error": str(e)})
|
||||
raise PipelineError(f"Pipeline failed: {e}") from e
|
||||
@@ -1,125 +0,0 @@
|
||||
"""
|
||||
WorkerPool — manages N worker threads via ThreadPoolExecutor.
|
||||
|
||||
Demonstrates: Python concurrency — threading (Interview Topic 2).
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from .models import ChunkResult
|
||||
from .processor import (
|
||||
ChecksumProcessor,
|
||||
CompositeProcessor,
|
||||
FFmpegExtractProcessor,
|
||||
Processor,
|
||||
SimulatedDecodeProcessor,
|
||||
)
|
||||
from .queue import ChunkQueue
|
||||
from .worker import Worker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_processor(
|
||||
processor_type: str = "checksum",
|
||||
output_dir: Optional[str] = None,
|
||||
) -> Processor:
|
||||
"""Factory for processor instances."""
|
||||
if processor_type == "ffmpeg":
|
||||
if not output_dir:
|
||||
raise ValueError("output_dir required for ffmpeg processor")
|
||||
return FFmpegExtractProcessor(output_dir=output_dir)
|
||||
elif processor_type == "checksum":
|
||||
return ChecksumProcessor()
|
||||
elif processor_type == "simulated_decode":
|
||||
return SimulatedDecodeProcessor()
|
||||
elif processor_type == "composite":
|
||||
return CompositeProcessor([
|
||||
ChecksumProcessor(),
|
||||
SimulatedDecodeProcessor(ms_per_second=50.0),
|
||||
])
|
||||
else:
|
||||
raise ValueError(f"Unknown processor type: {processor_type}")
|
||||
|
||||
|
||||
class WorkerPool:
|
||||
"""
|
||||
Manages N worker threads that process chunks concurrently.
|
||||
|
||||
Args:
|
||||
num_workers: Number of concurrent worker threads (default: 4)
|
||||
chunk_queue: Shared queue to pull chunks from
|
||||
processor_type: Type of processor for each worker (default: "checksum")
|
||||
max_retries: Max retry attempts per chunk (default: 3)
|
||||
event_callback: Optional callback for real-time events
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_workers: int = 4,
|
||||
chunk_queue: Optional[ChunkQueue] = None,
|
||||
processor_type: str = "checksum",
|
||||
max_retries: int = 3,
|
||||
event_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
):
|
||||
self.num_workers = num_workers
|
||||
self.chunk_queue = chunk_queue or ChunkQueue()
|
||||
self.processor_type = processor_type
|
||||
self.max_retries = max_retries
|
||||
self.event_callback = event_callback
|
||||
self.output_dir = output_dir
|
||||
self.shutdown_event = threading.Event()
|
||||
self._executor: Optional[ThreadPoolExecutor] = None
|
||||
self._futures: List[Future] = []
|
||||
self._workers: List[Worker] = []
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start all worker threads."""
|
||||
self._executor = ThreadPoolExecutor(
|
||||
max_workers=self.num_workers,
|
||||
thread_name_prefix="chunk-worker",
|
||||
)
|
||||
|
||||
for i in range(self.num_workers):
|
||||
worker = Worker(
|
||||
worker_id=f"worker-{i}",
|
||||
chunk_queue=self.chunk_queue,
|
||||
processor=create_processor(self.processor_type, output_dir=self.output_dir),
|
||||
max_retries=self.max_retries,
|
||||
event_callback=self.event_callback,
|
||||
)
|
||||
self._workers.append(worker)
|
||||
future = self._executor.submit(worker.run)
|
||||
self._futures.append(future)
|
||||
|
||||
logger.info(f"WorkerPool started with {self.num_workers} workers")
|
||||
|
||||
def wait(self) -> List[ChunkResult]:
|
||||
"""Wait for all workers to finish and collect results."""
|
||||
all_results = []
|
||||
for future in self._futures:
|
||||
results = future.result()
|
||||
all_results.extend(results)
|
||||
return all_results
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Signal shutdown and cleanup."""
|
||||
self.shutdown_event.set()
|
||||
self.chunk_queue.close()
|
||||
if self._executor:
|
||||
self._executor.shutdown(wait=True)
|
||||
|
||||
def get_worker_stats(self) -> Dict[str, Any]:
|
||||
"""Get per-worker statistics."""
|
||||
return {
|
||||
w.worker_id: {
|
||||
"processed": w.processed_count,
|
||||
"errors": w.error_count,
|
||||
"retries": w.retry_count,
|
||||
}
|
||||
for w in self._workers
|
||||
}
|
||||
@@ -1,173 +0,0 @@
|
||||
"""
|
||||
Processor ABC and concrete implementations.
|
||||
|
||||
Demonstrates: OOP design principles — ABC, inheritance, composition (Interview Topic 4).
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from .exceptions import ChunkChecksumError
|
||||
from .models import Chunk, ChunkResult
|
||||
|
||||
|
||||
class Processor(ABC):
|
||||
"""
|
||||
Abstract base class for chunk processors.
|
||||
|
||||
Each processor defines how a single chunk is processed.
|
||||
The Worker calls processor.process(chunk) and handles retries.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def process(self, chunk: Chunk) -> ChunkResult:
|
||||
"""Process a single chunk and return the result."""
|
||||
pass
|
||||
|
||||
|
||||
class FFmpegExtractProcessor(Processor):
|
||||
"""
|
||||
Extracts a time segment from the source file using FFmpeg stream copy.
|
||||
|
||||
Produces a playable mp4 file per chunk — no re-encoding.
|
||||
|
||||
Args:
|
||||
output_dir: Directory to write chunk mp4 files
|
||||
"""
|
||||
|
||||
def __init__(self, output_dir: str):
|
||||
self.output_dir = output_dir
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def process(self, chunk: Chunk) -> ChunkResult:
|
||||
from core.ffmpeg.transcode import TranscodeConfig, transcode
|
||||
|
||||
start = time.monotonic()
|
||||
|
||||
output_file = str(
|
||||
Path(self.output_dir) / f"chunk_{chunk.sequence:04d}.mp4"
|
||||
)
|
||||
|
||||
config = TranscodeConfig(
|
||||
input_path=chunk.source_path,
|
||||
output_path=output_file,
|
||||
video_codec="copy",
|
||||
audio_codec="copy",
|
||||
trim_start=chunk.start_time,
|
||||
trim_end=chunk.end_time,
|
||||
)
|
||||
|
||||
transcode(config)
|
||||
|
||||
# Compute checksum of output file
|
||||
md5 = hashlib.md5()
|
||||
with open(output_file, "rb") as f:
|
||||
for block in iter(lambda: f.read(8192), b""):
|
||||
md5.update(block)
|
||||
checksum = md5.hexdigest()
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
return ChunkResult(
|
||||
sequence=chunk.sequence,
|
||||
success=True,
|
||||
checksum_valid=True,
|
||||
processing_time=elapsed,
|
||||
output_file=output_file,
|
||||
)
|
||||
|
||||
|
||||
class ChecksumProcessor(Processor):
|
||||
"""
|
||||
Validates chunk metadata consistency.
|
||||
|
||||
For time-based chunks, verifies the time range is valid.
|
||||
Raises ChunkChecksumError on invalid ranges.
|
||||
"""
|
||||
|
||||
def process(self, chunk: Chunk) -> ChunkResult:
|
||||
start = time.monotonic()
|
||||
|
||||
valid = chunk.duration > 0 and chunk.end_time > chunk.start_time
|
||||
|
||||
if not valid:
|
||||
raise ChunkChecksumError(
|
||||
sequence=chunk.sequence,
|
||||
expected="valid time range",
|
||||
actual=f"{chunk.start_time}-{chunk.end_time}",
|
||||
)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
return ChunkResult(
|
||||
sequence=chunk.sequence,
|
||||
success=True,
|
||||
checksum_valid=True,
|
||||
processing_time=elapsed,
|
||||
)
|
||||
|
||||
|
||||
class SimulatedDecodeProcessor(Processor):
|
||||
"""
|
||||
Simulates decode work by sleeping proportional to chunk duration.
|
||||
|
||||
Useful for demonstrating concurrency behavior without real FFmpeg.
|
||||
|
||||
Args:
|
||||
ms_per_second: Milliseconds of simulated work per second of chunk duration (default: 100)
|
||||
"""
|
||||
|
||||
def __init__(self, ms_per_second: float = 100.0):
|
||||
self.ms_per_second = ms_per_second
|
||||
|
||||
def process(self, chunk: Chunk) -> ChunkResult:
|
||||
start = time.monotonic()
|
||||
|
||||
sleep_time = (self.ms_per_second * chunk.duration) / 1000.0
|
||||
time.sleep(sleep_time)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
return ChunkResult(
|
||||
sequence=chunk.sequence,
|
||||
success=True,
|
||||
checksum_valid=True,
|
||||
processing_time=elapsed,
|
||||
)
|
||||
|
||||
|
||||
class CompositeProcessor(Processor):
|
||||
"""
|
||||
Chains multiple processors — runs each in sequence on the same chunk.
|
||||
|
||||
Demonstrates OOP composition pattern.
|
||||
|
||||
Args:
|
||||
processors: List of processors to chain
|
||||
"""
|
||||
|
||||
def __init__(self, processors: List[Processor]):
|
||||
if not processors:
|
||||
raise ValueError("CompositeProcessor requires at least one processor")
|
||||
self.processors = processors
|
||||
|
||||
def process(self, chunk: Chunk) -> ChunkResult:
|
||||
start = time.monotonic()
|
||||
last_result = None
|
||||
|
||||
for proc in self.processors:
|
||||
last_result = proc.process(chunk)
|
||||
if not last_result.success:
|
||||
return last_result
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
return ChunkResult(
|
||||
sequence=chunk.sequence,
|
||||
success=True,
|
||||
checksum_valid=last_result.checksum_valid if last_result else True,
|
||||
processing_time=elapsed,
|
||||
)
|
||||
@@ -1,76 +0,0 @@
|
||||
"""
|
||||
ChunkQueue — bounded, thread-safe queue with sentinel-based shutdown.
|
||||
|
||||
Demonstrates: Core data structures — queue.Queue (Interview Topic 5).
|
||||
"""
|
||||
|
||||
import queue
|
||||
from typing import Optional
|
||||
|
||||
from .models import Chunk
|
||||
|
||||
# Sentinel value to signal workers to stop
|
||||
_SENTINEL = object()
|
||||
|
||||
|
||||
class ChunkQueue:
|
||||
"""
|
||||
Thread-safe bounded queue for chunks.
|
||||
|
||||
Provides backpressure: producers block when the queue is full,
|
||||
preventing unbounded memory usage.
|
||||
|
||||
Args:
|
||||
maxsize: Maximum number of chunks in the queue (default: 10)
|
||||
"""
|
||||
|
||||
def __init__(self, maxsize: int = 10):
|
||||
self._queue: queue.Queue = queue.Queue(maxsize=maxsize)
|
||||
self._closed = False
|
||||
self.maxsize = maxsize
|
||||
|
||||
def put(self, chunk: Chunk, timeout: Optional[float] = None) -> None:
|
||||
"""
|
||||
Add a chunk to the queue. Blocks if full (backpressure).
|
||||
|
||||
Args:
|
||||
chunk: The chunk to enqueue
|
||||
timeout: Max seconds to wait (None = block forever)
|
||||
|
||||
Raises:
|
||||
queue.Full: If timeout expires while queue is full
|
||||
"""
|
||||
self._queue.put(chunk, timeout=timeout)
|
||||
|
||||
def get(self, timeout: Optional[float] = None) -> Optional[Chunk]:
|
||||
"""
|
||||
Get next chunk from queue. Returns None if queue is closed.
|
||||
|
||||
Args:
|
||||
timeout: Max seconds to wait (None = block forever)
|
||||
|
||||
Returns:
|
||||
Chunk or None (if sentinel received, meaning queue is closed)
|
||||
|
||||
Raises:
|
||||
queue.Empty: If timeout expires while queue is empty
|
||||
"""
|
||||
item = self._queue.get(timeout=timeout)
|
||||
if item is _SENTINEL:
|
||||
# Re-put sentinel so other workers also see it
|
||||
self._queue.put(_SENTINEL)
|
||||
return None
|
||||
return item
|
||||
|
||||
def close(self) -> None:
|
||||
"""Signal all consumers to stop by inserting a sentinel."""
|
||||
self._closed = True
|
||||
self._queue.put(_SENTINEL)
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return self._closed
|
||||
|
||||
def qsize(self) -> int:
|
||||
"""Current number of items in the queue (approximate)."""
|
||||
return self._queue.qsize()
|
||||
@@ -1,143 +0,0 @@
|
||||
"""
|
||||
Worker — pulls chunks from queue, processes with retry logic.
|
||||
|
||||
Demonstrates:
|
||||
- Exception handling and resilient code (Interview Topic 7)
|
||||
- Concurrency (Interview Topic 2) — workers run in thread pool
|
||||
"""
|
||||
|
||||
import logging
|
||||
import queue
|
||||
import time
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from .exceptions import ProcessorFailureError
|
||||
from .models import Chunk, ChunkResult
|
||||
from .processor import Processor
|
||||
from .queue import ChunkQueue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Worker:
|
||||
"""
|
||||
Processes chunks from a queue with retry and exponential backoff.
|
||||
|
||||
Args:
|
||||
worker_id: Identifier for this worker (e.g. "worker-0")
|
||||
chunk_queue: Source queue to pull chunks from
|
||||
processor: Processor instance to use
|
||||
max_retries: Maximum retry attempts per chunk (default: 3)
|
||||
event_callback: Optional callback for real-time status updates
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
worker_id: str,
|
||||
chunk_queue: ChunkQueue,
|
||||
processor: Processor,
|
||||
max_retries: int = 3,
|
||||
event_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None,
|
||||
):
|
||||
self.worker_id = worker_id
|
||||
self.chunk_queue = chunk_queue
|
||||
self.processor = processor
|
||||
self.max_retries = max_retries
|
||||
self.event_callback = event_callback
|
||||
self.processed_count = 0
|
||||
self.error_count = 0
|
||||
self.retry_count = 0
|
||||
|
||||
def _emit(self, event_type: str, data: Dict[str, Any]) -> None:
|
||||
"""Emit an event if callback is registered."""
|
||||
if self.event_callback:
|
||||
self.event_callback(event_type, {"worker_id": self.worker_id, **data})
|
||||
|
||||
def _process_with_retry(self, chunk: Chunk) -> ChunkResult:
|
||||
"""
|
||||
Process a chunk with exponential backoff retry.
|
||||
|
||||
Retry delays: 0.1s, 0.2s, 0.4s, ... (doubles each attempt)
|
||||
"""
|
||||
last_error = None
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
if attempt > 0:
|
||||
backoff = 0.1 * (2 ** (attempt - 1))
|
||||
self._emit("chunk_retry", {
|
||||
"sequence": chunk.sequence,
|
||||
"attempt": attempt,
|
||||
"backoff": backoff,
|
||||
})
|
||||
time.sleep(backoff)
|
||||
self.retry_count += 1
|
||||
|
||||
result = self.processor.process(chunk)
|
||||
result.retries = attempt
|
||||
result.worker_id = self.worker_id
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
logger.warning(
|
||||
f"{self.worker_id}: chunk {chunk.sequence} "
|
||||
f"attempt {attempt + 1}/{self.max_retries + 1} failed: {e}"
|
||||
)
|
||||
|
||||
# All retries exhausted
|
||||
self.error_count += 1
|
||||
self._emit("chunk_error", {
|
||||
"sequence": chunk.sequence,
|
||||
"error": str(last_error),
|
||||
"retries": self.max_retries,
|
||||
})
|
||||
|
||||
return ChunkResult(
|
||||
sequence=chunk.sequence,
|
||||
success=False,
|
||||
processing_time=0.0,
|
||||
error=str(last_error),
|
||||
retries=self.max_retries,
|
||||
worker_id=self.worker_id,
|
||||
)
|
||||
|
||||
def run(self) -> list[ChunkResult]:
|
||||
"""
|
||||
Main worker loop — pull chunks and process until queue is closed.
|
||||
|
||||
Returns:
|
||||
List of ChunkResults processed by this worker
|
||||
"""
|
||||
results = []
|
||||
self._emit("worker_status", {"state": "idle"})
|
||||
|
||||
while True:
|
||||
try:
|
||||
chunk = self.chunk_queue.get(timeout=1.0)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
if chunk is None: # Sentinel received
|
||||
break
|
||||
|
||||
self._emit("chunk_processing", {
|
||||
"sequence": chunk.sequence,
|
||||
"state": "processing",
|
||||
"queue_size": self.chunk_queue.qsize(),
|
||||
})
|
||||
|
||||
result = self._process_with_retry(chunk)
|
||||
results.append(result)
|
||||
self.processed_count += 1
|
||||
|
||||
self._emit("chunk_done", {
|
||||
"sequence": chunk.sequence,
|
||||
"success": result.success,
|
||||
"processing_time": result.processing_time,
|
||||
"retries": result.retries,
|
||||
"queue_size": self.chunk_queue.qsize(),
|
||||
})
|
||||
|
||||
self._emit("worker_status", {"state": "stopped"})
|
||||
return results
|
||||
@@ -13,7 +13,7 @@ Basic CRUD (create, get, update, delete) goes directly through the session:
|
||||
|
||||
from .connection import get_session, create_tables
|
||||
|
||||
from .tables import MediaAsset, Job, Timeline, Checkpoint, Brand
|
||||
from .models import MediaAsset, Job, Timeline, Checkpoint, Brand
|
||||
|
||||
from .assets import list_assets, get_asset_filenames
|
||||
from .job import list_jobs
|
||||
|
||||
@@ -7,7 +7,7 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from .tables import MediaAsset
|
||||
from .models import MediaAsset
|
||||
|
||||
|
||||
def list_assets(session: Session, status: Optional[str] = None, search: Optional[str] = None) -> list[MediaAsset]:
|
||||
|
||||
@@ -7,7 +7,7 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from .tables import Brand
|
||||
from .models import Brand
|
||||
|
||||
|
||||
def get_or_create_brand(session: Session, canonical_name: str,
|
||||
|
||||
@@ -6,7 +6,7 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from .tables import Checkpoint
|
||||
from .models import Checkpoint
|
||||
|
||||
|
||||
def get_latest_checkpoint(session: Session, timeline_id: UUID, parent_id: UUID | None = None) -> Checkpoint | None:
|
||||
|
||||
@@ -30,5 +30,5 @@ def get_session() -> Session:
|
||||
def create_tables():
|
||||
"""Create all SQLModel tables."""
|
||||
from sqlmodel import SQLModel
|
||||
from . import tables # noqa — registers all table classes
|
||||
from . import models # noqa — registers all table classes
|
||||
SQLModel.metadata.create_all(get_engine())
|
||||
|
||||
142
core/db/fixtures/soccer_broadcast.json
Normal file
142
core/db/fixtures/soccer_broadcast.json
Normal file
@@ -0,0 +1,142 @@
|
||||
{
|
||||
"name": "soccer_broadcast",
|
||||
"pipeline": {
|
||||
"name": "soccer_broadcast",
|
||||
"profile_name": "soccer_broadcast",
|
||||
"stages": [
|
||||
{
|
||||
"name": "extract_frames",
|
||||
"branch": "trunk"
|
||||
},
|
||||
{
|
||||
"name": "filter_scenes",
|
||||
"branch": "trunk"
|
||||
},
|
||||
{
|
||||
"name": "field_segmentation",
|
||||
"branch": "trunk"
|
||||
},
|
||||
{
|
||||
"name": "detect_edges",
|
||||
"branch": "hoarding"
|
||||
},
|
||||
{
|
||||
"name": "detect_objects",
|
||||
"branch": "objects"
|
||||
},
|
||||
{
|
||||
"name": "preprocess"
|
||||
},
|
||||
{
|
||||
"name": "run_ocr"
|
||||
},
|
||||
{
|
||||
"name": "match_brands"
|
||||
},
|
||||
{
|
||||
"name": "escalate_vlm"
|
||||
},
|
||||
{
|
||||
"name": "escalate_cloud"
|
||||
},
|
||||
{
|
||||
"name": "compile_report"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": "extract_frames",
|
||||
"target": "filter_scenes"
|
||||
},
|
||||
{
|
||||
"source": "filter_scenes",
|
||||
"target": "field_segmentation"
|
||||
},
|
||||
{
|
||||
"source": "field_segmentation",
|
||||
"target": "detect_edges"
|
||||
},
|
||||
{
|
||||
"source": "field_segmentation",
|
||||
"target": "detect_objects"
|
||||
},
|
||||
{
|
||||
"source": "detect_edges",
|
||||
"target": "preprocess"
|
||||
},
|
||||
{
|
||||
"source": "detect_objects",
|
||||
"target": "preprocess"
|
||||
},
|
||||
{
|
||||
"source": "preprocess",
|
||||
"target": "run_ocr"
|
||||
},
|
||||
{
|
||||
"source": "run_ocr",
|
||||
"target": "match_brands"
|
||||
},
|
||||
{
|
||||
"source": "match_brands",
|
||||
"target": "escalate_vlm"
|
||||
},
|
||||
{
|
||||
"source": "escalate_vlm",
|
||||
"target": "escalate_cloud"
|
||||
},
|
||||
{
|
||||
"source": "escalate_cloud",
|
||||
"target": "compile_report"
|
||||
}
|
||||
]
|
||||
},
|
||||
"configs": {
|
||||
"extract_frames": {
|
||||
"fps": 2.0,
|
||||
"max_frames": 500
|
||||
},
|
||||
"filter_scenes": {
|
||||
"hamming_threshold": 8,
|
||||
"enabled": true
|
||||
},
|
||||
"field_segmentation": {
|
||||
"enabled": true,
|
||||
"hue_low": 30,
|
||||
"hue_high": 85,
|
||||
"sat_low": 30,
|
||||
"sat_high": 255,
|
||||
"val_low": 30,
|
||||
"val_high": 255,
|
||||
"morph_kernel": 15,
|
||||
"min_area_ratio": 0.05
|
||||
},
|
||||
"detect_edges": {
|
||||
"enabled": true,
|
||||
"edge_canny_low": 50,
|
||||
"edge_canny_high": 150,
|
||||
"edge_hough_threshold": 80,
|
||||
"edge_hough_min_length": 100,
|
||||
"edge_hough_max_gap": 10,
|
||||
"edge_pair_max_distance": 200,
|
||||
"edge_pair_min_distance": 15
|
||||
},
|
||||
"detect_objects": {
|
||||
"model_name": "yolov8n.pt",
|
||||
"confidence_threshold": 0.3,
|
||||
"target_classes": []
|
||||
},
|
||||
"run_ocr": {
|
||||
"languages": [
|
||||
"en",
|
||||
"es"
|
||||
],
|
||||
"min_confidence": 0.5
|
||||
},
|
||||
"match_brands": {
|
||||
"fuzzy_threshold": 75
|
||||
},
|
||||
"escalate_vlm": {
|
||||
"vlm_prompt_template": "Identify the brand or sponsor visible in this cropped region from a soccer broadcast.{hint}{text} Respond with: brand, confidence (0-1), reasoning."
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -7,7 +7,7 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from .tables import Job
|
||||
from .models import Job
|
||||
|
||||
|
||||
def list_jobs(session: Session, parent_id: Optional[UUID] = None, status: Optional[str] = None) -> list[Job]:
|
||||
|
||||
@@ -44,7 +44,7 @@ class SourceType(str, Enum):
|
||||
|
||||
class MediaAsset(SQLModel, table=True):
|
||||
"""A video/audio file registered in the system."""
|
||||
__tablename__ = "media_assets"
|
||||
__tablename__ = "media_asset"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
filename: str
|
||||
@@ -67,7 +67,7 @@ class MediaAsset(SQLModel, table=True):
|
||||
|
||||
class TranscodePreset(SQLModel, table=True):
|
||||
"""A reusable transcoding configuration (like Handbrake presets)."""
|
||||
__tablename__ = "transcode_presets"
|
||||
__tablename__ = "transcode_preset"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
name: str
|
||||
@@ -90,12 +90,13 @@ class TranscodePreset(SQLModel, table=True):
|
||||
|
||||
class Job(SQLModel, table=True):
|
||||
"""A pipeline job."""
|
||||
__tablename__ = "jobs"
|
||||
__tablename__ = "job"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
source_asset_id: UUID = Field(index=True)
|
||||
video_path: str
|
||||
profile_name: str = "soccer_broadcast"
|
||||
timeline_id: Optional[UUID] = None
|
||||
parent_id: Optional[UUID] = None
|
||||
run_type: RunType = "initial"
|
||||
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||
@@ -107,7 +108,6 @@ class Job(SQLModel, table=True):
|
||||
brands_found: int = 0
|
||||
cloud_llm_calls: int = 0
|
||||
estimated_cost_usd: float = 0.0
|
||||
celery_task_id: Optional[str] = None
|
||||
priority: int = 0
|
||||
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
||||
started_at: Optional[datetime] = None
|
||||
@@ -115,7 +115,7 @@ class Job(SQLModel, table=True):
|
||||
|
||||
class Timeline(SQLModel, table=True):
|
||||
"""The frame sequence from a source video."""
|
||||
__tablename__ = "timelines"
|
||||
__tablename__ = "timeline"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
source_asset_id: Optional[UUID] = Field(default=None, index=True)
|
||||
@@ -129,10 +129,11 @@ class Timeline(SQLModel, table=True):
|
||||
|
||||
class Checkpoint(SQLModel, table=True):
|
||||
"""A snapshot of pipeline state on a timeline."""
|
||||
__tablename__ = "checkpoints"
|
||||
__tablename__ = "checkpoint"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
timeline_id: UUID
|
||||
job_id: Optional[UUID] = Field(default=None, index=True)
|
||||
parent_id: Optional[UUID] = None
|
||||
stage_outputs: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||
@@ -143,7 +144,7 @@ class Checkpoint(SQLModel, table=True):
|
||||
|
||||
class Brand(SQLModel, table=True):
|
||||
"""A brand discovered or registered in the system."""
|
||||
__tablename__ = "brands"
|
||||
__tablename__ = "brand"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
canonical_name: str = Field(index=True)
|
||||
@@ -154,3 +155,12 @@ class Brand(SQLModel, table=True):
|
||||
total_airings: int = 0
|
||||
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
||||
updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
||||
|
||||
class Profile(SQLModel, table=True):
|
||||
"""A content type profile."""
|
||||
__tablename__ = "profile"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
name: str
|
||||
pipeline: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||
configs: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||
|
||||
43
core/db/seed.py
Normal file
43
core/db/seed.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
Seed data — insert initial profile rows if they don't exist.
|
||||
|
||||
Called on startup after create_tables().
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SEED_DIR = Path(__file__).parent / "fixtures"
|
||||
|
||||
|
||||
def seed_profiles():
|
||||
"""Insert seed profiles from JSON fixtures if not already present."""
|
||||
from .connection import get_session
|
||||
from .models import Profile
|
||||
|
||||
fixtures = list(SEED_DIR.glob("*.json"))
|
||||
if not fixtures:
|
||||
return
|
||||
|
||||
with get_session() as session:
|
||||
for f in fixtures:
|
||||
data = json.loads(f.read_text())
|
||||
name = data["name"]
|
||||
|
||||
existing = session.query(Profile).filter(Profile.name == name).first()
|
||||
if existing:
|
||||
logger.debug("Profile %s already exists, skipping seed", name)
|
||||
continue
|
||||
|
||||
profile = Profile(
|
||||
name=name,
|
||||
pipeline=data.get("pipeline", {}),
|
||||
configs=data.get("configs", {}),
|
||||
)
|
||||
session.add(profile)
|
||||
logger.info("Seeded profile: %s", name)
|
||||
|
||||
session.commit()
|
||||
@@ -1,96 +0,0 @@
|
||||
"""
|
||||
SQLModel table definitions.
|
||||
|
||||
Generated by modelgen from core/schema/models/. Do not edit directly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import JSON
|
||||
from sqlmodel import Column, Field, SQLModel
|
||||
|
||||
|
||||
class MediaAsset(SQLModel, table=True):
|
||||
__tablename__ = "media_asset"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
filename: str
|
||||
path: str
|
||||
status: str = "pending"
|
||||
size_bytes: int = 0
|
||||
duration_seconds: float = 0.0
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
fps: Optional[float] = None
|
||||
codec: Optional[str] = None
|
||||
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class Job(SQLModel, table=True):
|
||||
__tablename__ = "job"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
source_asset_id: UUID = Field(index=True)
|
||||
video_path: str
|
||||
profile_name: str = "soccer_broadcast"
|
||||
parent_id: Optional[UUID] = Field(default=None, index=True)
|
||||
run_type: str = "initial"
|
||||
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||
status: str = "pending"
|
||||
current_stage: Optional[str] = None
|
||||
progress: float = 0.0
|
||||
error_message: Optional[str] = None
|
||||
total_detections: int = 0
|
||||
brands_found: int = 0
|
||||
cloud_llm_calls: int = 0
|
||||
estimated_cost_usd: float = 0.0
|
||||
priority: int = 0
|
||||
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class Timeline(SQLModel, table=True):
|
||||
__tablename__ = "timeline"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
source_asset_id: Optional[UUID] = Field(default=None, index=True)
|
||||
source_video: str = ""
|
||||
profile_name: str = ""
|
||||
fps: float = 2.0
|
||||
frames_prefix: str = ""
|
||||
frames_manifest: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||
frames_meta: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
|
||||
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class Checkpoint(SQLModel, table=True):
|
||||
__tablename__ = "checkpoint"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
timeline_id: UUID = Field(index=True)
|
||||
parent_id: Optional[UUID] = Field(default=None, index=True)
|
||||
stage_outputs: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||
stats: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||
is_scenario: bool = False
|
||||
scenario_label: str = ""
|
||||
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class Brand(SQLModel, table=True):
|
||||
__tablename__ = "brand"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
canonical_name: str = Field(index=True)
|
||||
aliases: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
|
||||
source: str = "ocr"
|
||||
confirmed: bool = False
|
||||
airings: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
|
||||
total_airings: int = 0
|
||||
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
||||
updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
||||
@@ -16,4 +16,4 @@ from .storage import (
|
||||
load_stage_output,
|
||||
)
|
||||
from .frames import save_frames, load_frames
|
||||
from .runner_bridge import checkpoint_after_stage, reset_checkpoint_state
|
||||
from .runner_bridge import checkpoint_after_stage, reset_checkpoint_state, get_timeline_id
|
||||
@@ -9,7 +9,7 @@ import tempfile
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from detect.models import Frame
|
||||
from core.detect.models import Frame
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -11,7 +11,7 @@ import logging
|
||||
|
||||
import uuid
|
||||
|
||||
from detect import emit
|
||||
from core.detect import emit
|
||||
# TODO: migrate to Timeline/Branch/Checkpoint model
|
||||
# These old functions no longer exist — replay needs rework
|
||||
def _not_migrated(*args, **kwargs):
|
||||
@@ -19,66 +19,14 @@ def _not_migrated(*args, **kwargs):
|
||||
|
||||
load_checkpoint = _not_migrated
|
||||
list_checkpoints = _not_migrated
|
||||
from detect.graph import NODES, build_graph
|
||||
from core.detect.graph import NODES, build_graph
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OverrideProfile:
|
||||
"""
|
||||
Wraps a ContentTypeProfile and patches config methods with overrides.
|
||||
|
||||
Override dict structure:
|
||||
{
|
||||
"frame_extraction": {"fps": 1.0},
|
||||
"scene_filter": {"hamming_threshold": 12},
|
||||
"region_analysis": {"edge_canny_low": 30, "edge_canny_high": 120},
|
||||
"detection": {"confidence_threshold": 0.5},
|
||||
"ocr": {"languages": ["en", "es"], "min_confidence": 0.3},
|
||||
"resolver": {"fuzzy_threshold": 60},
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, base, overrides: dict):
|
||||
self._base = base
|
||||
self._overrides = overrides
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._base, name)
|
||||
|
||||
def _patch(self, config, key: str):
|
||||
patches = self._overrides.get(key, {})
|
||||
for k, v in patches.items():
|
||||
if hasattr(config, k):
|
||||
setattr(config, k, v)
|
||||
return config
|
||||
|
||||
def frame_extraction_config(self):
|
||||
return self._patch(self._base.frame_extraction_config(), "frame_extraction")
|
||||
|
||||
def scene_filter_config(self):
|
||||
return self._patch(self._base.scene_filter_config(), "scene_filter")
|
||||
|
||||
def region_analysis_config(self):
|
||||
return self._patch(self._base.region_analysis_config(), "region_analysis")
|
||||
|
||||
def detection_config(self):
|
||||
return self._patch(self._base.detection_config(), "detection")
|
||||
|
||||
def ocr_config(self):
|
||||
return self._patch(self._base.ocr_config(), "ocr")
|
||||
|
||||
def resolver_config(self):
|
||||
return self._patch(self._base.resolver_config(), "resolver")
|
||||
|
||||
def vlm_prompt(self, crop_context):
|
||||
return self._base.vlm_prompt(crop_context)
|
||||
|
||||
def aggregate(self, detections):
|
||||
return self._base.aggregate(detections)
|
||||
|
||||
def auxiliary_detections(self, source):
|
||||
return self._base.auxiliary_detections(source)
|
||||
# OverrideProfile removed — config overrides are now handled by dict merging
|
||||
# in _load_profile() (nodes.py) and replay_single_stage (below).
|
||||
|
||||
|
||||
def replay_from(
|
||||
@@ -183,10 +131,16 @@ def replay_single_stage(
|
||||
state = load_checkpoint(job_id, previous_stage)
|
||||
|
||||
# Build profile with overrides
|
||||
from detect.profiles import get_profile
|
||||
from core.detect.profile import get_profile, get_stage_config
|
||||
profile = get_profile(state.get("profile_name", "soccer_broadcast"))
|
||||
if config_overrides:
|
||||
profile = OverrideProfile(profile, config_overrides)
|
||||
merged_configs = dict(profile.get("configs", {}))
|
||||
for sname, soverrides in config_overrides.items():
|
||||
if sname in merged_configs:
|
||||
merged_configs[sname] = {**merged_configs[sname], **soverrides}
|
||||
else:
|
||||
merged_configs[sname] = soverrides
|
||||
profile = {**profile, "configs": merged_configs}
|
||||
|
||||
# Run the stage function directly (not through the graph)
|
||||
if stage == "detect_edges":
|
||||
@@ -207,9 +161,11 @@ def _replay_detect_edges(
|
||||
) -> dict:
|
||||
"""Run edge detection on checkpoint frames, optionally with debug overlays."""
|
||||
import os
|
||||
from detect.stages.edge_detector import detect_edge_regions
|
||||
from core.detect.stages.edge_detector import detect_edge_regions
|
||||
|
||||
config = profile.region_analysis_config()
|
||||
from core.detect.profile import get_stage_config
|
||||
from core.detect.stages.models import RegionAnalysisConfig
|
||||
config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
|
||||
frames = state.get("filtered_frames", [])
|
||||
|
||||
if frame_refs:
|
||||
@@ -231,7 +187,7 @@ def _replay_detect_edges(
|
||||
if debug and frames:
|
||||
debug_data = {}
|
||||
if inference_url:
|
||||
from detect.inference import InferenceClient
|
||||
from core.detect.inference import InferenceClient
|
||||
client = InferenceClient(base_url=inference_url, job_id=job_id)
|
||||
for frame in frames:
|
||||
dr = client.detect_edges_debug(
|
||||
@@ -252,7 +208,7 @@ def _replay_detect_edges(
|
||||
}
|
||||
else:
|
||||
# Local mode — import GPU module directly
|
||||
from detect.stages.edge_detector import _load_cv_edges
|
||||
from core.detect.stages.edge_detector import _load_cv_edges
|
||||
edges_mod = _load_cv_edges()
|
||||
for frame in frames:
|
||||
dr = edges_mod.detect_edges_debug(
|
||||
97
core/detect/checkpoint/runner_bridge.py
Normal file
97
core/detect/checkpoint/runner_bridge.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Runner bridge — checkpoint hook called by PipelineRunner after each stage.
|
||||
|
||||
Owns the per-job state (timeline, frame manifest, checkpoint chain) that
|
||||
the runner shouldn't know about.
|
||||
|
||||
Timeline and Job are independent entities:
|
||||
- One Timeline can serve multiple Jobs (re-run with different params)
|
||||
- One Job operates on one Timeline (set after frame extraction)
|
||||
- Checkpoints belong to Timeline, tagged with the Job that created them
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Per-job state
|
||||
_timeline_id: dict[str, str] = {}
|
||||
_frames_manifest: dict[str, dict[int, str]] = {}
|
||||
_latest_checkpoint: dict[str, str] = {}
|
||||
|
||||
|
||||
def reset_checkpoint_state(job_id: str):
|
||||
"""Clean up per-job checkpoint state. Called when pipeline finishes."""
|
||||
_timeline_id.pop(job_id, None)
|
||||
_frames_manifest.pop(job_id, None)
|
||||
_latest_checkpoint.pop(job_id, None)
|
||||
|
||||
|
||||
def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: dict):
|
||||
"""
|
||||
Save a checkpoint after a stage completes.
|
||||
|
||||
Called by the runner. Handles:
|
||||
- Timeline creation (once, on extract_frames)
|
||||
- Frame upload (via create_timeline)
|
||||
- Stage output serialization (via stage registry)
|
||||
- Checkpoint chain (parent → child)
|
||||
"""
|
||||
if not job_id:
|
||||
return
|
||||
|
||||
from .storage import create_timeline, save_stage_output
|
||||
from core.detect.stages.base import _REGISTRY
|
||||
|
||||
merged = {**state, **result}
|
||||
|
||||
# On extract_frames: create Timeline + upload frames + root checkpoint
|
||||
if stage_name == "extract_frames" and job_id not in _timeline_id:
|
||||
frames = merged.get("frames", [])
|
||||
video_path = merged.get("video_path", "")
|
||||
profile_name = merged.get("profile_name", "")
|
||||
|
||||
tid, cid = create_timeline(
|
||||
source_video=video_path,
|
||||
profile_name=profile_name,
|
||||
frames=frames,
|
||||
)
|
||||
_timeline_id[job_id] = tid
|
||||
_latest_checkpoint[job_id] = cid
|
||||
logger.info("Job %s → Timeline %s (root checkpoint %s)", job_id, tid, cid)
|
||||
|
||||
# Emit timeline_id via SSE so the UI can use it for checkpoint loads
|
||||
from core.detect import emit
|
||||
emit.log(job_id, "Checkpoint", "INFO", f"timeline_id={tid}")
|
||||
return
|
||||
|
||||
# For subsequent stages: save checkpoint on the timeline
|
||||
tid = _timeline_id.get(job_id)
|
||||
if not tid:
|
||||
logger.warning("No timeline for job %s, skipping checkpoint", job_id)
|
||||
return
|
||||
|
||||
# Serialize stage output using the stage's serialize_fn if available
|
||||
stage_cls = _REGISTRY.get(stage_name)
|
||||
serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
|
||||
if serialize_fn:
|
||||
output_json = serialize_fn(merged, job_id)
|
||||
else:
|
||||
output_json = {}
|
||||
|
||||
parent_id = _latest_checkpoint.get(job_id)
|
||||
new_checkpoint_id = save_stage_output(
|
||||
timeline_id=tid,
|
||||
parent_checkpoint_id=parent_id,
|
||||
stage_name=stage_name,
|
||||
output_json=output_json,
|
||||
job_id=job_id,
|
||||
)
|
||||
_latest_checkpoint[job_id] = new_checkpoint_id
|
||||
|
||||
|
||||
def get_timeline_id(job_id: str) -> str | None:
|
||||
"""Get the timeline_id for a running job. Used by the UI to load checkpoints."""
|
||||
return _timeline_id.get(job_id)
|
||||
@@ -28,7 +28,7 @@ def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
|
||||
Calls each registered stage's serialize_fn for stage-owned data.
|
||||
Envelope fields (job_id, etc.) are copied directly.
|
||||
"""
|
||||
from detect.stages.base import _REGISTRY
|
||||
from core.detect.stages.base import _REGISTRY
|
||||
|
||||
checkpoint = {}
|
||||
|
||||
@@ -64,7 +64,7 @@ def deserialize_state(checkpoint: dict, frames: list) -> dict:
|
||||
|
||||
Calls each stage's deserialize_fn to restore stage-owned data.
|
||||
"""
|
||||
from detect.stages.base import _REGISTRY
|
||||
from core.detect.stages.base import _REGISTRY
|
||||
|
||||
frame_map = {f.sequence: f for f in frames}
|
||||
|
||||
@@ -33,7 +33,7 @@ def create_timeline(
|
||||
|
||||
Returns (timeline_id, checkpoint_id).
|
||||
"""
|
||||
from core.db.tables import Timeline, Checkpoint
|
||||
from core.db.models import Timeline, Checkpoint
|
||||
from core.db.connection import get_session
|
||||
|
||||
with get_session() as session:
|
||||
@@ -81,7 +81,7 @@ def create_timeline(
|
||||
|
||||
def get_timeline_frames(timeline_id: str) -> list:
|
||||
"""Load frames from a timeline (from MinIO) as Frame objects."""
|
||||
from core.db.tables import Timeline
|
||||
from core.db.models import Timeline
|
||||
from core.db.connection import get_session
|
||||
|
||||
with get_session() as session:
|
||||
@@ -96,7 +96,7 @@ def get_timeline_frames(timeline_id: str) -> list:
|
||||
|
||||
def get_timeline_frames_b64(timeline_id: str) -> list[dict]:
|
||||
"""Load frames as base64 JPEG (lightweight, no numpy)."""
|
||||
from core.db.tables import Timeline
|
||||
from core.db.models import Timeline
|
||||
from core.db.connection import get_session
|
||||
from .frames import load_frames_b64
|
||||
|
||||
@@ -123,6 +123,7 @@ def save_stage_output(
|
||||
stats: dict | None = None,
|
||||
is_scenario: bool = False,
|
||||
scenario_label: str = "",
|
||||
job_id: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Save a stage's output as a new checkpoint (child of parent).
|
||||
@@ -130,7 +131,7 @@ def save_stage_output(
|
||||
Carries forward stage outputs from parent + adds the new one.
|
||||
Returns the new checkpoint ID.
|
||||
"""
|
||||
from core.db.tables import Checkpoint
|
||||
from core.db.models import Checkpoint
|
||||
from core.db.connection import get_session
|
||||
|
||||
with get_session() as session:
|
||||
@@ -146,6 +147,7 @@ def save_stage_output(
|
||||
|
||||
checkpoint = Checkpoint(
|
||||
timeline_id=UUID(timeline_id),
|
||||
job_id=UUID(job_id) if job_id else None,
|
||||
parent_id=UUID(parent_checkpoint_id) if parent_checkpoint_id else None,
|
||||
stage_outputs={**parent_outputs, stage_name: output_json},
|
||||
config_overrides={**parent_config, **(config_overrides or {})},
|
||||
@@ -165,7 +167,7 @@ def save_stage_output(
|
||||
|
||||
def load_stage_output(checkpoint_id: str, stage_name: str) -> dict | None:
|
||||
"""Load a stage's output from a checkpoint."""
|
||||
from core.db.tables import Checkpoint
|
||||
from core.db.models import Checkpoint
|
||||
from core.db.connection import get_session
|
||||
|
||||
with get_session() as session:
|
||||
@@ -16,8 +16,8 @@ from __future__ import annotations
|
||||
import dataclasses
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from detect.events import push_detect_event
|
||||
from detect.models import PipelineStats
|
||||
from core.detect.events import push_detect_event
|
||||
from core.detect.models import PipelineStats
|
||||
|
||||
# Log level ordering for comparison
|
||||
_LEVEL_ORDER = {"DEBUG": 0, "INFO": 1, "WARN": 2, "ERROR": 3}
|
||||
@@ -4,8 +4,8 @@ Graph event emission — node state tracking + SSE graph_update events.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from detect import emit
|
||||
from detect.state import DetectState
|
||||
from core.detect import emit
|
||||
from core.detect.state import DetectState
|
||||
|
||||
|
||||
# Track node states across pipeline runs
|
||||
@@ -1,28 +1,39 @@
|
||||
"""
|
||||
Pipeline node functions — one per stage.
|
||||
|
||||
Each node: reads state, runs stage logic, emits transitions, returns output dict.
|
||||
Each node: reads state, gets config from profile dict, runs stage logic,
|
||||
emits transitions, returns output dict.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from detect import emit
|
||||
from detect.models import PipelineStats
|
||||
from detect.profiles import SoccerBroadcastProfile
|
||||
from detect.state import DetectState
|
||||
from detect.stages.frame_extractor import extract_frames
|
||||
from detect.stages.scene_filter import scene_filter
|
||||
from detect.stages.edge_detector import detect_edge_regions
|
||||
from detect.stages.yolo_detector import detect_objects
|
||||
from detect.stages.preprocess import preprocess_regions
|
||||
from detect.stages.ocr_stage import run_ocr
|
||||
from detect.stages.brand_resolver import resolve_brands
|
||||
from detect.stages.vlm_local import escalate_vlm
|
||||
from detect.stages.vlm_cloud import escalate_cloud
|
||||
from detect.stages.aggregator import compile_report
|
||||
from detect.tracing import trace_node, flush as flush_traces
|
||||
from core.detect import emit
|
||||
from core.detect.models import CropContext, PipelineStats
|
||||
from core.detect.profile import get_profile, get_stage_config, build_vlm_prompt, aggregate_detections
|
||||
from core.detect.stages.models import (
|
||||
DetectionConfig,
|
||||
FieldSegmentationConfig,
|
||||
FrameExtractionConfig,
|
||||
OCRConfig,
|
||||
RegionAnalysisConfig,
|
||||
ResolverConfig,
|
||||
SceneFilterConfig,
|
||||
)
|
||||
from core.detect.state import DetectState
|
||||
from core.detect.stages.frame_extractor import extract_frames
|
||||
from core.detect.stages.scene_filter import scene_filter
|
||||
from core.detect.stages.field_segmentation import run_field_segmentation
|
||||
from core.detect.stages.edge_detector import detect_edge_regions
|
||||
from core.detect.stages.yolo_detector import detect_objects
|
||||
from core.detect.stages.preprocess import preprocess_regions
|
||||
from core.detect.stages.ocr_stage import run_ocr
|
||||
from core.detect.stages.brand_resolver import resolve_brands
|
||||
from core.detect.stages.vlm_local import escalate_vlm
|
||||
from core.detect.stages.vlm_cloud import escalate_cloud
|
||||
from core.detect.stages.aggregator import compile_report
|
||||
from core.detect.tracing import trace_node, flush as flush_traces
|
||||
|
||||
from .events import emit_transition
|
||||
|
||||
@@ -31,6 +42,7 @@ INFERENCE_URL = os.environ.get("INFERENCE_URL") # None = local mode
|
||||
NODES = [
|
||||
"extract_frames",
|
||||
"filter_scenes",
|
||||
"field_segmentation",
|
||||
"detect_edges",
|
||||
"detect_objects",
|
||||
"preprocess",
|
||||
@@ -42,17 +54,21 @@ NODES = [
|
||||
]
|
||||
|
||||
|
||||
def _get_profile(state: DetectState):
|
||||
def _load_profile(state: DetectState) -> dict:
|
||||
"""Load profile dict, apply config overrides if present."""
|
||||
name = state.get("profile_name", "soccer_broadcast")
|
||||
if name == "soccer_broadcast":
|
||||
profile = SoccerBroadcastProfile()
|
||||
else:
|
||||
raise ValueError(f"Unknown profile: {name}")
|
||||
profile = get_profile(name)
|
||||
|
||||
overrides = state.get("config_overrides")
|
||||
if overrides:
|
||||
from detect.checkpoint.replay import OverrideProfile
|
||||
profile = OverrideProfile(profile, overrides)
|
||||
# Merge overrides into a copy of the profile configs
|
||||
merged_configs = dict(profile.get("configs", {}))
|
||||
for stage_name, stage_overrides in overrides.items():
|
||||
if stage_name in merged_configs:
|
||||
merged_configs[stage_name] = {**merged_configs[stage_name], **stage_overrides}
|
||||
else:
|
||||
merged_configs[stage_name] = stage_overrides
|
||||
profile = {**profile, "configs": merged_configs}
|
||||
|
||||
return profile
|
||||
|
||||
@@ -70,16 +86,16 @@ def node_extract_frames(state: DetectState) -> dict:
|
||||
|
||||
source_asset_id = state.get("source_asset_id")
|
||||
if source_asset_id and not state.get("session_brands"):
|
||||
from detect.stages.brand_resolver import build_session_dict
|
||||
from core.detect.stages.brand_resolver import build_session_dict
|
||||
session_brands = build_session_dict(source_asset_id)
|
||||
state["session_brands"] = session_brands
|
||||
|
||||
_emit(state, "extract_frames", "running")
|
||||
|
||||
with trace_node(state, "extract_frames") as span:
|
||||
profile = _get_profile(state)
|
||||
config = profile.frame_extraction_config()
|
||||
frames = extract_frames(state["video_path"], config, job_id=state.get("job_id"))
|
||||
profile = _load_profile(state)
|
||||
config = FrameExtractionConfig(**get_stage_config(profile, "extract_frames"))
|
||||
frames = extract_frames(state["video_path"], config, job_id=job_id)
|
||||
span.set_output({"frames_extracted": len(frames)})
|
||||
|
||||
_emit(state, "extract_frames", "done")
|
||||
@@ -90,8 +106,8 @@ def node_filter_scenes(state: DetectState) -> dict:
|
||||
_emit(state, "filter_scenes", "running")
|
||||
|
||||
with trace_node(state, "filter_scenes") as span:
|
||||
profile = _get_profile(state)
|
||||
config = profile.scene_filter_config()
|
||||
profile = _load_profile(state)
|
||||
config = SceneFilterConfig(**get_stage_config(profile, "filter_scenes"))
|
||||
frames = state.get("frames", [])
|
||||
kept = scene_filter(frames, config, job_id=state.get("job_id"))
|
||||
span.set_output({"frames_in": len(frames), "frames_kept": len(kept)})
|
||||
@@ -103,17 +119,42 @@ def node_filter_scenes(state: DetectState) -> dict:
|
||||
return {"filtered_frames": kept, "stats": stats}
|
||||
|
||||
|
||||
def node_field_segmentation(state: DetectState) -> dict:
|
||||
_emit(state, "field_segmentation", "running")
|
||||
|
||||
with trace_node(state, "field_segmentation") as span:
|
||||
profile = _load_profile(state)
|
||||
config = FieldSegmentationConfig(**get_stage_config(profile, "field_segmentation"))
|
||||
frames = state.get("filtered_frames", [])
|
||||
job_id = state.get("job_id")
|
||||
|
||||
result = run_field_segmentation(frames, config, inference_url=INFERENCE_URL, job_id=job_id)
|
||||
span.set_output({
|
||||
"frames": len(frames),
|
||||
"avg_coverage": sum(result["field_coverage"].values()) / max(len(result["field_coverage"]), 1),
|
||||
})
|
||||
|
||||
_emit(state, "field_segmentation", "done")
|
||||
return {
|
||||
"field_masks": result["field_masks"],
|
||||
"field_boundaries": result["field_boundaries"],
|
||||
"field_coverage": result["field_coverage"],
|
||||
}
|
||||
|
||||
|
||||
def node_detect_edges(state: DetectState) -> dict:
|
||||
_emit(state, "detect_edges", "running")
|
||||
|
||||
with trace_node(state, "detect_edges") as span:
|
||||
profile = _get_profile(state)
|
||||
config = profile.region_analysis_config()
|
||||
profile = _load_profile(state)
|
||||
config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
|
||||
frames = state.get("filtered_frames", [])
|
||||
field_masks = state.get("field_masks", {})
|
||||
job_id = state.get("job_id")
|
||||
|
||||
regions = detect_edge_regions(
|
||||
frames, config, inference_url=INFERENCE_URL, job_id=job_id,
|
||||
field_masks=field_masks,
|
||||
)
|
||||
total = sum(len(r) for r in regions.values())
|
||||
span.set_output({"frames": len(frames), "edge_regions": total})
|
||||
@@ -129,8 +170,8 @@ def node_detect_objects(state: DetectState) -> dict:
|
||||
_emit(state, "detect_objects", "running")
|
||||
|
||||
with trace_node(state, "detect_objects") as span:
|
||||
profile = _get_profile(state)
|
||||
config = profile.detection_config()
|
||||
profile = _load_profile(state)
|
||||
config = DetectionConfig(**get_stage_config(profile, "detect_objects"))
|
||||
frames = state.get("filtered_frames", [])
|
||||
job_id = state.get("job_id")
|
||||
|
||||
@@ -149,13 +190,12 @@ def node_preprocess(state: DetectState) -> dict:
|
||||
_emit(state, "preprocess", "running")
|
||||
|
||||
with trace_node(state, "preprocess") as span:
|
||||
profile = _get_profile(state)
|
||||
profile = _load_profile(state)
|
||||
prep_config = get_stage_config(profile, "preprocess")
|
||||
frames = state.get("filtered_frames", [])
|
||||
boxes = state.get("boxes_by_frame", {})
|
||||
job_id = state.get("job_id")
|
||||
|
||||
overrides = state.get("config_overrides", {})
|
||||
prep_config = overrides.get("preprocessing", {})
|
||||
do_contrast = prep_config.get("contrast", True)
|
||||
do_deskew = prep_config.get("deskew", False)
|
||||
do_binarize = prep_config.get("binarize", False)
|
||||
@@ -178,8 +218,8 @@ def node_run_ocr(state: DetectState) -> dict:
|
||||
_emit(state, "run_ocr", "running")
|
||||
|
||||
with trace_node(state, "run_ocr") as span:
|
||||
profile = _get_profile(state)
|
||||
config = profile.ocr_config()
|
||||
profile = _load_profile(state)
|
||||
config = OCRConfig(**get_stage_config(profile, "run_ocr"))
|
||||
frames = state.get("filtered_frames", [])
|
||||
boxes = state.get("boxes_by_frame", {})
|
||||
job_id = state.get("job_id")
|
||||
@@ -198,18 +238,18 @@ def node_match_brands(state: DetectState) -> dict:
|
||||
_emit(state, "match_brands", "running")
|
||||
|
||||
with trace_node(state, "match_brands") as span:
|
||||
profile = _get_profile(state)
|
||||
resolver_config = profile.resolver_config()
|
||||
profile = _load_profile(state)
|
||||
config = ResolverConfig(**get_stage_config(profile, "match_brands"))
|
||||
candidates = state.get("text_candidates", [])
|
||||
session_brands = state.get("session_brands", {})
|
||||
job_id = state.get("job_id")
|
||||
source_asset_id = state.get("source_asset_id")
|
||||
|
||||
matched, unresolved = resolve_brands(
|
||||
candidates, resolver_config,
|
||||
candidates, config,
|
||||
session_brands=session_brands,
|
||||
source_asset_id=source_asset_id,
|
||||
content_type=profile.name, job_id=job_id,
|
||||
content_type=profile["name"], job_id=job_id,
|
||||
)
|
||||
span.set_output({"matched": len(matched), "unresolved": len(unresolved)})
|
||||
|
||||
@@ -221,15 +261,19 @@ def node_escalate_vlm(state: DetectState) -> dict:
|
||||
_emit(state, "escalate_vlm", "running")
|
||||
|
||||
with trace_node(state, "escalate_vlm") as span:
|
||||
profile = _get_profile(state)
|
||||
profile = _load_profile(state)
|
||||
vlm_config = get_stage_config(profile, "escalate_vlm")
|
||||
vlm_template = vlm_config.get("vlm_prompt_template", "Identify the brand in this image.")
|
||||
candidates = state.get("unresolved_candidates", [])
|
||||
job_id = state.get("job_id")
|
||||
|
||||
vlm_prompt_fn = lambda ctx: build_vlm_prompt(ctx, vlm_template)
|
||||
|
||||
vlm_matched, still_unresolved = escalate_vlm(
|
||||
candidates,
|
||||
vlm_prompt_fn=profile.vlm_prompt,
|
||||
vlm_prompt_fn=vlm_prompt_fn,
|
||||
inference_url=INFERENCE_URL,
|
||||
content_type=profile.name,
|
||||
content_type=profile["name"],
|
||||
source_asset_id=state.get("source_asset_id"),
|
||||
job_id=job_id,
|
||||
)
|
||||
@@ -254,16 +298,20 @@ def node_escalate_cloud(state: DetectState) -> dict:
|
||||
_emit(state, "escalate_cloud", "running")
|
||||
|
||||
with trace_node(state, "escalate_cloud") as span:
|
||||
profile = _get_profile(state)
|
||||
profile = _load_profile(state)
|
||||
vlm_config = get_stage_config(profile, "escalate_vlm")
|
||||
vlm_template = vlm_config.get("vlm_prompt_template", "Identify the brand in this image.")
|
||||
candidates = state.get("unresolved_candidates", [])
|
||||
job_id = state.get("job_id")
|
||||
stats = state.get("stats", PipelineStats())
|
||||
|
||||
vlm_prompt_fn = lambda ctx: build_vlm_prompt(ctx, vlm_template)
|
||||
|
||||
cloud_matched = escalate_cloud(
|
||||
candidates,
|
||||
vlm_prompt_fn=profile.vlm_prompt,
|
||||
vlm_prompt_fn=vlm_prompt_fn,
|
||||
stats=stats,
|
||||
content_type=profile.name,
|
||||
content_type=profile["name"],
|
||||
source_asset_id=state.get("source_asset_id"),
|
||||
job_id=job_id,
|
||||
)
|
||||
@@ -283,7 +331,7 @@ def node_compile_report(state: DetectState) -> dict:
|
||||
_emit(state, "compile_report", "running")
|
||||
|
||||
with trace_node(state, "compile_report") as span:
|
||||
profile = _get_profile(state)
|
||||
profile = _load_profile(state)
|
||||
detections = state.get("detections", [])
|
||||
stats = state.get("stats", PipelineStats())
|
||||
job_id = state.get("job_id")
|
||||
@@ -292,7 +340,7 @@ def node_compile_report(state: DetectState) -> dict:
|
||||
detections=detections,
|
||||
stats=stats,
|
||||
video_source=state.get("video_path", ""),
|
||||
content_type=profile.name,
|
||||
content_type=profile["name"],
|
||||
job_id=job_id,
|
||||
)
|
||||
|
||||
@@ -306,6 +354,7 @@ def node_compile_report(state: DetectState) -> dict:
|
||||
NODE_FUNCTIONS = [
|
||||
("extract_frames", node_extract_frames),
|
||||
("filter_scenes", node_filter_scenes),
|
||||
("field_segmentation", node_field_segmentation),
|
||||
("detect_edges", node_detect_edges),
|
||||
("detect_objects", node_detect_objects),
|
||||
("preprocess", node_preprocess),
|
||||
@@ -13,8 +13,8 @@ import logging
|
||||
import os
|
||||
import threading
|
||||
|
||||
from core.schema.models.pipeline_config import PipelineConfig
|
||||
from detect.state import DetectState
|
||||
from core.detect.stages.models import PipelineConfig
|
||||
from core.detect.state import DetectState
|
||||
from .nodes import NODES, NODE_FUNCTIONS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -118,7 +118,7 @@ def _wait_if_paused(job_id: str, node_name: str):
|
||||
|
||||
if _pause_after_stage.get(job_id, False):
|
||||
gate.clear()
|
||||
from detect import emit
|
||||
from core.detect import emit
|
||||
emit.log(job_id, "Pipeline", "INFO", f"Paused after {node_name}")
|
||||
|
||||
while not gate.wait(timeout=0.5):
|
||||
@@ -237,7 +237,7 @@ class PipelineRunner:
|
||||
|
||||
# 4. Checkpoint
|
||||
if self.do_checkpoint:
|
||||
from detect.checkpoint import checkpoint_after_stage
|
||||
from core.detect.checkpoint import checkpoint_after_stage
|
||||
checkpoint_after_stage(job_id, stage_name, state, result)
|
||||
|
||||
# 5. Pause check
|
||||
@@ -256,11 +256,11 @@ def get_pipeline(
|
||||
start_from: str | None = None,
|
||||
) -> PipelineRunner:
|
||||
"""Return a PipelineRunner for the given profile."""
|
||||
from detect.profiles import get_profile
|
||||
from core.detect.profile import get_profile, pipeline_config_from_dict
|
||||
|
||||
do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
|
||||
profile = get_profile(profile_name)
|
||||
config = profile.pipeline_config()
|
||||
config = pipeline_config_from_dict(profile["pipeline"])
|
||||
|
||||
return PipelineRunner(
|
||||
config=config,
|
||||
@@ -231,6 +231,20 @@ class InferenceClient:
|
||||
pair_count=data.get("pair_count", 0),
|
||||
)
|
||||
|
||||
def post(self, path: str, payload: dict) -> dict | None:
|
||||
"""Generic POST to the inference server. Returns JSON response or None on error."""
|
||||
try:
|
||||
resp = self.session.post(
|
||||
f"{self.base_url}{path}",
|
||||
json=payload,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
except Exception as e:
|
||||
logger.warning("Inference POST %s failed: %s", path, e)
|
||||
return None
|
||||
|
||||
def load_model(self, model: str, quantization: str = "fp16") -> None:
|
||||
"""Request the server to load a model into VRAM."""
|
||||
self.session.post(
|
||||
@@ -2,7 +2,7 @@
|
||||
Inference response types.
|
||||
|
||||
These are the shapes returned by the inference server.
|
||||
Kept separate from detect.models to avoid coupling the
|
||||
Kept separate from core.detect.models to avoid coupling the
|
||||
inference protocol to pipeline internals.
|
||||
"""
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
Detection pipeline runtime models.
|
||||
|
||||
These are the data structures that flow between pipeline stages.
|
||||
They contain runtime types (np.ndarray) so modelgen skips them —
|
||||
not generated to SQLModel or TypeScript.
|
||||
They contain runtime types (np.ndarray) so they live here, not in
|
||||
core/schema/models/ (which is for modelgen source of truth).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -85,3 +85,11 @@ class DetectionReport:
|
||||
brands: dict[str, BrandStats] = field(default_factory=dict)
|
||||
timeline: list[BrandDetection] = field(default_factory=list)
|
||||
pipeline_stats: PipelineStats = field(default_factory=PipelineStats)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CropContext:
|
||||
"""Runtime type — holds image bytes for VLM prompts."""
|
||||
image: bytes
|
||||
surrounding_text: str = ""
|
||||
position_hint: str = ""
|
||||
107
core/detect/profile.py
Normal file
107
core/detect/profile.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
Profile registry and helpers.
|
||||
|
||||
Loads profile data from Postgres.
|
||||
A profile is a dict with keys: name, pipeline, configs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from core.detect.stages.models import PipelineConfig, StageRef, Edge
|
||||
from core.detect.models import (
|
||||
BrandDetection,
|
||||
BrandStats,
|
||||
CropContext,
|
||||
DetectionReport,
|
||||
PipelineStats,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_profile(name: str) -> Dict[str, Any]:
|
||||
"""Get a profile dict by name from the database."""
|
||||
from core.db.connection import get_session
|
||||
from core.db.models import Profile
|
||||
|
||||
with get_session() as session:
|
||||
row = session.query(Profile).filter(Profile.name == name).first()
|
||||
|
||||
if row is None:
|
||||
raise ValueError(f"Unknown profile: {name!r}")
|
||||
|
||||
return {
|
||||
"name": row.name,
|
||||
"pipeline": row.pipeline or {},
|
||||
"configs": row.configs or {},
|
||||
}
|
||||
|
||||
|
||||
def list_profiles() -> list[str]:
|
||||
"""List available profile names from the database."""
|
||||
from core.db.connection import get_session
|
||||
from core.db.models import Profile
|
||||
|
||||
with get_session() as session:
|
||||
rows = session.query(Profile.name).all()
|
||||
|
||||
return [r[0] for r in rows]
|
||||
|
||||
|
||||
def get_stage_config(profile: Dict[str, Any], stage_name: str) -> dict:
|
||||
"""Get config values for a stage from a profile."""
|
||||
return profile.get("configs", {}).get(stage_name, {})
|
||||
|
||||
|
||||
def pipeline_config_from_dict(data: Dict[str, Any]) -> PipelineConfig:
|
||||
"""Deserialize a PipelineConfig from a JSONB dict."""
|
||||
stages = [StageRef(**s) for s in data.get("stages", [])]
|
||||
edges = [Edge(**e) for e in data.get("edges", [])]
|
||||
return PipelineConfig(
|
||||
name=data.get("name", ""),
|
||||
profile_name=data.get("profile_name", ""),
|
||||
stages=stages,
|
||||
edges=edges,
|
||||
routing_rules=data.get("routing_rules", {}),
|
||||
)
|
||||
|
||||
|
||||
def build_vlm_prompt(crop_context: CropContext, template: str) -> str:
|
||||
"""Build a VLM prompt from a template and crop context."""
|
||||
hint = f" Position: {crop_context.position_hint}." if crop_context.position_hint else ""
|
||||
text = f" Nearby text: '{crop_context.surrounding_text}'." if crop_context.surrounding_text else ""
|
||||
return template.format(hint=hint, text=text)
|
||||
|
||||
|
||||
def aggregate_detections(
|
||||
detections: list[BrandDetection],
|
||||
content_type: str,
|
||||
) -> DetectionReport:
|
||||
"""Group detections by brand into a report."""
|
||||
brands: dict[str, BrandStats] = {}
|
||||
for d in detections:
|
||||
if d.brand not in brands:
|
||||
brands[d.brand] = BrandStats()
|
||||
s = brands[d.brand]
|
||||
s.total_appearances += 1
|
||||
s.total_screen_time += d.duration
|
||||
s.avg_confidence = (
|
||||
(s.avg_confidence * (s.total_appearances - 1) + d.confidence)
|
||||
/ s.total_appearances
|
||||
)
|
||||
if s.first_seen == 0.0 or d.timestamp < s.first_seen:
|
||||
s.first_seen = d.timestamp
|
||||
if d.timestamp > s.last_seen:
|
||||
s.last_seen = d.timestamp
|
||||
|
||||
return DetectionReport(
|
||||
video_source="",
|
||||
content_type=content_type,
|
||||
duration_seconds=0.0,
|
||||
brands=brands,
|
||||
timeline=sorted(detections, key=lambda d: d.timestamp),
|
||||
pipeline_stats=PipelineStats(),
|
||||
)
|
||||
@@ -145,3 +145,19 @@ class RetryResponse(BaseModel):
|
||||
status: str
|
||||
task_id: str
|
||||
job_id: str
|
||||
|
||||
class RunRequest(BaseModel):
|
||||
"""Request body for launching a detection pipeline run."""
|
||||
video_path: str
|
||||
profile_name: str = "soccer_broadcast"
|
||||
source_asset_id: str = ""
|
||||
checkpoint: bool = True
|
||||
skip_vlm: bool = False
|
||||
skip_cloud: bool = False
|
||||
log_level: str = "INFO"
|
||||
|
||||
class RunResponse(BaseModel):
|
||||
"""Response after starting a pipeline run."""
|
||||
status: str
|
||||
job_id: str
|
||||
video_path: str
|
||||
@@ -16,6 +16,7 @@ from .base import (
|
||||
|
||||
# Import all stage files to trigger auto-registration
|
||||
from . import edge_detector # noqa: F401
|
||||
from . import field_segmentation # noqa: F401
|
||||
|
||||
# Import registry for backward compat (other stages still use old pattern)
|
||||
from . import registry # noqa: F401
|
||||
@@ -9,8 +9,8 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from detect import emit
|
||||
from detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats
|
||||
from core.detect import emit
|
||||
from core.detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -5,7 +5,7 @@ Each stage is a file that subclasses Stage. Auto-discovered via
|
||||
__init_subclass__. No manual registration needed.
|
||||
|
||||
A stage:
|
||||
- Has a StageDefinition (from schema) with name, config, IO
|
||||
- Has a StageDefinition (generated from schema) with name, config, IO
|
||||
- Implements run(frames, config) → output
|
||||
- Owns its output serialization (opaque blob)
|
||||
- Optionally has a TypeScript port for browser-side execution
|
||||
@@ -16,12 +16,30 @@ the format. The stage that wrote it is the only one that can read it.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from core.schema.models.stages import StageConfigField, StageIO, StageDefinition
|
||||
from core.detect.stages.models import (
|
||||
StageConfigField,
|
||||
StageIO,
|
||||
StageDefinition,
|
||||
)
|
||||
|
||||
|
||||
# Legacy runtime extension — adds callable fields for old-style stages.
|
||||
# New stages use Stage subclass with serialize()/deserialize() methods instead.
|
||||
class LegacyStageDefinition:
|
||||
"""Wraps a StageDefinition with callable serialize/deserialize functions."""
|
||||
|
||||
def __init__(self, definition: StageDefinition, fn=None, serialize_fn=None, deserialize_fn=None):
|
||||
self._definition = definition
|
||||
self.fn = fn
|
||||
self.serialize_fn = serialize_fn
|
||||
self.deserialize_fn = deserialize_fn
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._definition, name)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -30,12 +48,18 @@ from core.schema.models.stages import StageConfigField, StageIO, StageDefinition
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_REGISTRY: dict[str, type['Stage']] = {}
|
||||
_LEGACY_REGISTRY: dict[str, StageDefinition] = {}
|
||||
_LEGACY_REGISTRY: dict[str, LegacyStageDefinition] = {}
|
||||
|
||||
|
||||
def register_stage(definition: StageDefinition):
|
||||
def register_stage(
|
||||
definition: StageDefinition,
|
||||
fn=None,
|
||||
serialize_fn=None,
|
||||
deserialize_fn=None,
|
||||
):
|
||||
"""Legacy registration for stages not yet converted to Stage subclass."""
|
||||
_LEGACY_REGISTRY[definition.name] = definition
|
||||
legacy = LegacyStageDefinition(definition, fn=fn, serialize_fn=serialize_fn, deserialize_fn=deserialize_fn)
|
||||
_LEGACY_REGISTRY[definition.name] = legacy
|
||||
|
||||
|
||||
class Stage:
|
||||
@@ -55,13 +79,6 @@ class Stage:
|
||||
_REGISTRY[cls.definition.name] = cls
|
||||
|
||||
def run(self, frames: list, config: dict) -> Any:
|
||||
"""
|
||||
Run the stage on a list of frames with the given config.
|
||||
|
||||
Config is a dict of parameter values (from slider UI or profile).
|
||||
Returns the stage output — whatever shape this stage produces.
|
||||
Debug overlays are included when config has debug=True.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def serialize(self, output: Any) -> bytes:
|
||||
@@ -79,12 +96,15 @@ class Stage:
|
||||
# Discovery API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _all_definitions() -> dict[str, StageDefinition]:
|
||||
"""Merge new Stage subclass registry + legacy registry."""
|
||||
def _all_definitions():
|
||||
"""Merge new Stage subclass registry + legacy registry.
|
||||
|
||||
Returns StageDefinition for new-style stages,
|
||||
LegacyStageDefinition for legacy stages (has serialize_fn etc).
|
||||
"""
|
||||
merged = {}
|
||||
# Legacy first, new overwrites (new takes precedence)
|
||||
for name, defn in _LEGACY_REGISTRY.items():
|
||||
merged[name] = defn
|
||||
for name, legacy in _LEGACY_REGISTRY.items():
|
||||
merged[name] = legacy
|
||||
for name, cls in _REGISTRY.items():
|
||||
merged[name] = cls.definition
|
||||
return merged
|
||||
@@ -16,9 +16,9 @@ import logging
|
||||
|
||||
from rapidfuzz import fuzz
|
||||
|
||||
from detect import emit
|
||||
from detect.models import BrandDetection, TextCandidate
|
||||
from detect.profiles.base import ResolverConfig
|
||||
from core.detect import emit
|
||||
from core.detect.models import BrandDetection, TextCandidate
|
||||
from core.detect.stages.models import ResolverConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -21,10 +21,10 @@ from typing import Any
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from detect import emit
|
||||
from detect.models import BoundingBox, Frame
|
||||
from detect.stages.base import Stage
|
||||
from core.schema.models.stages import StageDefinition, StageConfigField, StageIO
|
||||
from core.detect import emit
|
||||
from core.detect.models import BoundingBox, Frame
|
||||
from core.detect.stages.base import Stage
|
||||
from core.detect.stages.models import StageDefinition, StageConfigField, StageIO
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -42,14 +42,14 @@ class EdgeDetectionStage(Stage):
|
||||
writes=["edge_regions_by_frame"],
|
||||
),
|
||||
config_fields=[
|
||||
StageConfigField("enabled", "bool", True, "Enable edge detection"),
|
||||
StageConfigField("edge_canny_low", "int", 50, "Canny low threshold", min=0, max=255),
|
||||
StageConfigField("edge_canny_high", "int", 150, "Canny high threshold", min=0, max=255),
|
||||
StageConfigField("edge_hough_threshold", "int", 80, "Hough accumulator threshold", min=1, max=500),
|
||||
StageConfigField("edge_hough_min_length", "int", 100, "Min line length (px)", min=10, max=2000),
|
||||
StageConfigField("edge_hough_max_gap", "int", 10, "Max line gap (px)", min=1, max=100),
|
||||
StageConfigField("edge_pair_max_distance", "int", 200, "Max distance between line pair (px)", min=10, max=500),
|
||||
StageConfigField("edge_pair_min_distance", "int", 15, "Min distance between line pair (px)", min=5, max=200),
|
||||
StageConfigField(name="enabled", type="bool", default=True, description="Enable edge detection"),
|
||||
StageConfigField(name="edge_canny_low", type="int", default=50, description="Canny low threshold", min=0, max=255),
|
||||
StageConfigField(name="edge_canny_high", type="int", default=150, description="Canny high threshold", min=0, max=255),
|
||||
StageConfigField(name="edge_hough_threshold", type="int", default=80, description="Hough accumulator threshold", min=1, max=500),
|
||||
StageConfigField(name="edge_hough_min_length", type="int", default=100, description="Min line length (px)", min=10, max=2000),
|
||||
StageConfigField(name="edge_hough_max_gap", type="int", default=10, description="Max line gap (px)", min=1, max=100),
|
||||
StageConfigField(name="edge_pair_max_distance", type="int", default=200, description="Max distance between line pair (px)", min=10, max=500),
|
||||
StageConfigField(name="edge_pair_min_distance", type="int", default=15, description="Min distance between line pair (px)", min=5, max=200),
|
||||
],
|
||||
tracks_element="edge_region",
|
||||
)
|
||||
@@ -143,8 +143,8 @@ class EdgeDetectionStage(Stage):
|
||||
|
||||
def _run_remote(self, frame: Frame, config: dict,
|
||||
inference_url: str, job_id: str) -> list[BoundingBox]:
|
||||
from detect.inference import InferenceClient
|
||||
from detect.emit import _run_log_level
|
||||
from core.detect.inference import InferenceClient
|
||||
from core.detect.emit import _run_log_level
|
||||
|
||||
client = InferenceClient(
|
||||
base_url=inference_url, job_id=job_id, log_level=_run_log_level,
|
||||
@@ -208,7 +208,7 @@ def _load_cv_edges():
|
||||
if _cv_edges_mod is None:
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
spec = importlib.util.spec_from_file_location("cv_edges", Path("gpu/models/cv/edges.py"))
|
||||
spec = importlib.util.spec_from_file_location("cv_edges", Path("core/gpu/models/cv/edges.py"))
|
||||
_cv_edges_mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(_cv_edges_mod)
|
||||
return _cv_edges_mod
|
||||
@@ -216,8 +216,43 @@ def _load_cv_edges():
|
||||
|
||||
# --- Backward compat: standalone function for graph.py ---
|
||||
|
||||
def detect_edge_regions(frames, config, inference_url=None, job_id=None):
|
||||
"""Convenience wrapper — calls EdgeDetectionStage.run()."""
|
||||
def _filter_by_field_mask(boxes, mask, margin_px=50):
|
||||
"""
|
||||
Keep only boxes that are near the pitch boundary (hoarding zone).
|
||||
|
||||
The field mask has 255=pitch, 0=not pitch. Hoardings sit just
|
||||
outside the pitch boundary. We dilate the mask to create a
|
||||
"boundary zone" and keep boxes whose center falls in the zone
|
||||
between the dilated mask edge and the original mask.
|
||||
"""
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
if mask is None or not boxes:
|
||||
return boxes
|
||||
|
||||
# Dilate the pitch mask — the expansion zone is where hoardings are
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (margin_px * 2, margin_px * 2))
|
||||
dilated = cv2.dilate(mask, kernel)
|
||||
|
||||
# Boundary zone = dilated but NOT original pitch
|
||||
boundary_zone = cv2.bitwise_and(dilated, cv2.bitwise_not(mask))
|
||||
|
||||
kept = []
|
||||
for box in boxes:
|
||||
cx = box.x + box.w // 2
|
||||
cy = box.y + box.h // 2
|
||||
# Clamp to image bounds
|
||||
cy = min(cy, boundary_zone.shape[0] - 1)
|
||||
cx = min(cx, boundary_zone.shape[1] - 1)
|
||||
if boundary_zone[cy, cx] > 0:
|
||||
kept.append(box)
|
||||
|
||||
return kept
|
||||
|
||||
|
||||
def detect_edge_regions(frames, config, inference_url=None, job_id=None, field_masks=None):
|
||||
"""Convenience wrapper — calls EdgeDetectionStage.run(), optionally filters by field mask."""
|
||||
stage = EdgeDetectionStage()
|
||||
cfg = {
|
||||
"enabled": config.enabled,
|
||||
@@ -231,4 +266,23 @@ def detect_edge_regions(frames, config, inference_url=None, job_id=None):
|
||||
"inference_url": inference_url,
|
||||
"job_id": job_id,
|
||||
}
|
||||
return stage.run(frames, cfg)
|
||||
all_boxes = stage.run(frames, cfg)
|
||||
|
||||
# Filter by field segmentation mask if available
|
||||
if field_masks:
|
||||
filtered_total = 0
|
||||
original_total = sum(len(b) for b in all_boxes.values())
|
||||
for seq, boxes in all_boxes.items():
|
||||
mask = field_masks.get(seq)
|
||||
if mask is not None:
|
||||
all_boxes[seq] = _filter_by_field_mask(boxes, mask)
|
||||
filtered_total += len(all_boxes[seq])
|
||||
else:
|
||||
filtered_total += len(boxes)
|
||||
|
||||
if original_total != filtered_total:
|
||||
from core.detect import emit
|
||||
emit.log(job_id, "EdgeDetection", "INFO",
|
||||
f"Field mask filter: {original_total} → {filtered_total} regions")
|
||||
|
||||
return all_boxes
|
||||
141
core/detect/stages/field_segmentation.py
Normal file
141
core/detect/stages/field_segmentation.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Stage — Field Segmentation
|
||||
|
||||
Calls the GPU inference server to detect pitch boundaries via
|
||||
HSV green mask + morphology. The CV code lives in core/gpu/models/cv/.
|
||||
|
||||
Outputs a mask and boundary that downstream stages use as spatial priors.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from core.detect import emit
|
||||
from core.detect.models import Frame
|
||||
from core.detect.stages.base import Stage
|
||||
from core.detect.stages.models import (
|
||||
FieldSegmentationConfig,
|
||||
StageConfigField,
|
||||
StageDefinition,
|
||||
StageIO,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FieldSegmentationStage(Stage):
|
||||
|
||||
definition = StageDefinition(
|
||||
name="field_segmentation",
|
||||
label="Field Segmentation",
|
||||
description="HSV green mask — detect pitch boundaries for spatial priors",
|
||||
category="cv_analysis",
|
||||
io=StageIO(
|
||||
reads=["filtered_frames"],
|
||||
writes=["field_mask"],
|
||||
),
|
||||
config_fields=[
|
||||
StageConfigField(name="enabled", type="bool", default=True, description="Enable field segmentation"),
|
||||
StageConfigField(name="hue_low", type="int", default=30, description="HSV hue lower bound", min=0, max=180),
|
||||
StageConfigField(name="hue_high", type="int", default=85, description="HSV hue upper bound", min=0, max=180),
|
||||
StageConfigField(name="sat_low", type="int", default=30, description="HSV saturation lower bound", min=0, max=255),
|
||||
StageConfigField(name="sat_high", type="int", default=255, description="HSV saturation upper bound", min=0, max=255),
|
||||
StageConfigField(name="val_low", type="int", default=30, description="HSV value lower bound", min=0, max=255),
|
||||
StageConfigField(name="val_high", type="int", default=255, description="HSV value upper bound", min=0, max=255),
|
||||
StageConfigField(name="morph_kernel", type="int", default=15, description="Morphology kernel size", min=3, max=51),
|
||||
StageConfigField(name="min_area_ratio", type="float", default=0.05, description="Min contour area as fraction of frame", min=0.01, max=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _frame_to_b64(frame: Frame) -> str:
|
||||
"""Encode frame image as base64 JPEG."""
|
||||
img = Image.fromarray(frame.image)
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format="JPEG", quality=85)
|
||||
return base64.b64encode(buf.getvalue()).decode()
|
||||
|
||||
|
||||
def _decode_mask_b64(mask_b64: str) -> np.ndarray:
|
||||
"""Decode a base64 PNG mask back to numpy array."""
|
||||
data = base64.b64decode(mask_b64)
|
||||
img = Image.open(io.BytesIO(data)).convert("L")
|
||||
return np.array(img)
|
||||
|
||||
|
||||
def run_field_segmentation(
|
||||
frames: list[Frame],
|
||||
config: FieldSegmentationConfig,
|
||||
inference_url: str | None = None,
|
||||
job_id: str | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Run field segmentation on all frames via the inference server.
|
||||
|
||||
Returns dict with:
|
||||
field_masks: {seq: np.ndarray}
|
||||
field_boundaries: {seq: [(x,y), ...]}
|
||||
field_coverage: {seq: float}
|
||||
"""
|
||||
if not config.enabled:
|
||||
emit.log(job_id, "FieldSegmentation", "INFO", "Disabled, skipping")
|
||||
return {"field_masks": {}, "field_boundaries": {}, "field_coverage": {}}
|
||||
|
||||
import os
|
||||
url = inference_url or os.environ.get("INFERENCE_URL")
|
||||
if not url:
|
||||
emit.log(job_id, "FieldSegmentation", "WARNING",
|
||||
"No INFERENCE_URL, skipping field segmentation")
|
||||
return {"field_masks": {}, "field_boundaries": {}, "field_coverage": {}}
|
||||
|
||||
emit.log(job_id, "FieldSegmentation", "INFO",
|
||||
f"Segmenting {len(frames)} frames (hue={config.hue_low}-{config.hue_high})")
|
||||
|
||||
from core.detect.inference import InferenceClient
|
||||
from core.detect.emit import _run_log_level
|
||||
client = InferenceClient(base_url=url, job_id=job_id or "", log_level=_run_log_level)
|
||||
|
||||
field_masks = {}
|
||||
field_boundaries = {}
|
||||
field_coverage = {}
|
||||
|
||||
for frame in frames:
|
||||
image_b64 = _frame_to_b64(frame)
|
||||
|
||||
resp = client.post("/segment_field", {
|
||||
"image": image_b64,
|
||||
"hue_low": config.hue_low,
|
||||
"hue_high": config.hue_high,
|
||||
"sat_low": config.sat_low,
|
||||
"sat_high": config.sat_high,
|
||||
"val_low": config.val_low,
|
||||
"val_high": config.val_high,
|
||||
"morph_kernel": config.morph_kernel,
|
||||
"min_area_ratio": config.min_area_ratio,
|
||||
})
|
||||
|
||||
if resp is None:
|
||||
continue
|
||||
|
||||
mask_b64 = resp.get("mask_b64", "")
|
||||
if mask_b64:
|
||||
field_masks[frame.sequence] = _decode_mask_b64(mask_b64)
|
||||
|
||||
field_boundaries[frame.sequence] = resp.get("boundary", [])
|
||||
field_coverage[frame.sequence] = resp.get("coverage", 0.0)
|
||||
|
||||
avg_coverage = sum(field_coverage.values()) / max(len(field_coverage), 1)
|
||||
emit.log(job_id, "FieldSegmentation", "INFO",
|
||||
f"Done: {len(frames)} frames, avg coverage {avg_coverage:.1%}")
|
||||
|
||||
return {
|
||||
"field_masks": field_masks,
|
||||
"field_boundaries": field_boundaries,
|
||||
"field_coverage": field_coverage,
|
||||
}
|
||||
@@ -16,9 +16,9 @@ import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from core.ffmpeg.probe import probe_file
|
||||
from detect import emit
|
||||
from detect.models import Frame
|
||||
from detect.profiles.base import FrameExtractionConfig
|
||||
from core.detect import emit
|
||||
from core.detect.models import Frame
|
||||
from core.detect.stages.models import FrameExtractionConfig
|
||||
|
||||
|
||||
def _load_frames(tmpdir: Path, fps: float) -> list[Frame]:
|
||||
106
core/detect/stages/models.py
Normal file
106
core/detect/stages/models.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
Pydantic Models - GENERATED FILE
|
||||
|
||||
Do not edit directly. Regenerate using modelgen.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class StageConfigField(BaseModel):
|
||||
"""A single tunable config parameter for the editor UI."""
|
||||
name: str
|
||||
type: str
|
||||
default: Any
|
||||
description: str = ""
|
||||
min: Optional[float] = None
|
||||
max: Optional[float] = None
|
||||
options: Optional[List[str]] = None
|
||||
|
||||
class StageIO(BaseModel):
|
||||
"""Declares what a stage reads and writes."""
|
||||
reads: List[str] = Field(default_factory=list)
|
||||
writes: List[str] = Field(default_factory=list)
|
||||
optional_reads: List[str] = Field(default_factory=list)
|
||||
|
||||
class StageDefinition(BaseModel):
|
||||
"""Complete metadata for a pipeline stage."""
|
||||
name: str
|
||||
label: str
|
||||
description: str
|
||||
category: str = "detection"
|
||||
io: StageIO
|
||||
config_fields: List[StageConfigField] = Field(default_factory=list)
|
||||
tracks_element: Optional[str] = None
|
||||
|
||||
class FrameExtractionConfig(BaseModel):
|
||||
"""FrameExtractionConfig(fps: float = 2.0, max_frames: int = 500)"""
|
||||
fps: float = 2.0
|
||||
max_frames: int = 500
|
||||
|
||||
class SceneFilterConfig(BaseModel):
|
||||
"""SceneFilterConfig(hamming_threshold: int = 8, enabled: bool = True)"""
|
||||
hamming_threshold: int = 8
|
||||
enabled: bool = True
|
||||
|
||||
class DetectionConfig(BaseModel):
|
||||
"""DetectionConfig(model_name: str = 'yolov8n.pt', confidence_threshold: float = 0.3, target_classes: List[str] = <factory>)"""
|
||||
model_name: str = "yolov8n.pt"
|
||||
confidence_threshold: float = 0.3
|
||||
target_classes: List[str]
|
||||
|
||||
class OCRConfig(BaseModel):
|
||||
"""OCRConfig(languages: List[str] = <factory>, min_confidence: float = 0.5)"""
|
||||
languages: List[str]
|
||||
min_confidence: float = 0.5
|
||||
|
||||
class ResolverConfig(BaseModel):
|
||||
"""ResolverConfig(fuzzy_threshold: int = 75)"""
|
||||
fuzzy_threshold: int = 75
|
||||
|
||||
class RegionAnalysisConfig(BaseModel):
|
||||
"""RegionAnalysisConfig(enabled: bool = True, edge_canny_low: int = 50, edge_canny_high: int = 150, edge_hough_threshold: int = 80, edge_hough_min_length: int = 100, edge_hough_max_gap: int = 10, edge_pair_max_distance: int = 200, edge_pair_min_distance: int = 15)"""
|
||||
enabled: bool = True
|
||||
edge_canny_low: int = 50
|
||||
edge_canny_high: int = 150
|
||||
edge_hough_threshold: int = 80
|
||||
edge_hough_min_length: int = 100
|
||||
edge_hough_max_gap: int = 10
|
||||
edge_pair_max_distance: int = 200
|
||||
edge_pair_min_distance: int = 15
|
||||
|
||||
class FieldSegmentationConfig(BaseModel):
|
||||
"""FieldSegmentationConfig(enabled: bool = True, hue_low: int = 30, hue_high: int = 85, sat_low: int = 30, sat_high: int = 255, val_low: int = 30, val_high: int = 255, morph_kernel: int = 15, min_area_ratio: float = 0.05)"""
|
||||
enabled: bool = True
|
||||
hue_low: int = 30
|
||||
hue_high: int = 85
|
||||
sat_low: int = 30
|
||||
sat_high: int = 255
|
||||
val_low: int = 30
|
||||
val_high: int = 255
|
||||
morph_kernel: int = 15
|
||||
min_area_ratio: float = 0.05
|
||||
|
||||
class StageRef(BaseModel):
|
||||
"""Reference to a stage in the pipeline graph."""
|
||||
name: str
|
||||
branch: str = "trunk"
|
||||
execution_target: str = "local"
|
||||
|
||||
class Edge(BaseModel):
|
||||
"""Connection between stages in the graph."""
|
||||
source: str
|
||||
target: str
|
||||
condition: str = ""
|
||||
|
||||
class PipelineConfig(BaseModel):
|
||||
"""Pipeline graph topology + routing rules."""
|
||||
name: str
|
||||
profile_name: str
|
||||
stages: List[StageRef] = Field(default_factory=list)
|
||||
edges: List[Edge] = Field(default_factory=list)
|
||||
routing_rules: Dict[str, Any]
|
||||
@@ -18,9 +18,9 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from detect import emit
|
||||
from detect.models import BoundingBox, Frame, TextCandidate
|
||||
from detect.profiles.base import OCRConfig
|
||||
from core.detect import emit
|
||||
from core.detect.models import BoundingBox, Frame, TextCandidate
|
||||
from core.detect.stages.models import OCRConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
@@ -91,8 +91,8 @@ def run_ocr(
|
||||
|
||||
# Build these once per pipeline run, not per crop
|
||||
if inference_url:
|
||||
from detect.inference import InferenceClient
|
||||
from detect.emit import _run_log_level
|
||||
from core.detect.inference import InferenceClient
|
||||
from core.detect.emit import _run_log_level
|
||||
client = InferenceClient(base_url=inference_url, job_id=job_id or "", log_level=_run_log_level)
|
||||
else:
|
||||
model = _get_local_model(config.languages[0])
|
||||
@@ -15,8 +15,8 @@ import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
from detect import emit
|
||||
from detect.models import BoundingBox, Frame
|
||||
from core.detect import emit
|
||||
from core.detect.models import BoundingBox, Frame
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -124,5 +124,5 @@ def _preprocess_remote(crop: np.ndarray, inference_url: str,
|
||||
def _preprocess_local(crop: np.ndarray,
|
||||
do_contrast: bool, do_deskew: bool, do_binarize: bool) -> np.ndarray:
|
||||
"""Run preprocessing in-process (requires opencv-python-headless)."""
|
||||
from gpu.models.preprocess import preprocess
|
||||
from core.gpu.models.preprocess import preprocess
|
||||
return preprocess(crop, do_binarize=do_binarize, do_deskew=do_deskew, do_contrast=do_contrast)
|
||||
44
core/detect/stages/registry/cv_analysis.py
Normal file
44
core/detect/stages/registry/cv_analysis.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Registration for CV analysis stages: edge detection."""
|
||||
|
||||
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
|
||||
from core.detect.stages.base import register_stage
|
||||
from ._serializers import serialize_dataclass_list, deserialize_bounding_box
|
||||
|
||||
|
||||
def _ser_regions(state: dict, job_id: str) -> dict:
|
||||
regions = state.get("edge_regions_by_frame", {})
|
||||
serialized = {
|
||||
str(seq): serialize_dataclass_list(bl) for seq, bl in regions.items()
|
||||
}
|
||||
return {"edge_regions_by_frame": serialized}
|
||||
|
||||
|
||||
def _deser_regions(data: dict, job_id: str) -> dict:
|
||||
regions = {}
|
||||
for seq_str, box_dicts in data.get("edge_regions_by_frame", {}).items():
|
||||
regions[int(seq_str)] = [deserialize_bounding_box(b) for b in box_dicts]
|
||||
return {"edge_regions_by_frame": regions}
|
||||
|
||||
|
||||
def register():
|
||||
edge_detection = StageDefinition(
|
||||
name="detect_edges",
|
||||
label="Edge Detection",
|
||||
description="Canny + HoughLinesP — find horizontal line pairs (hoarding boundaries)",
|
||||
category="cv_analysis",
|
||||
io=StageIO(
|
||||
reads=["filtered_frames"],
|
||||
writes=["edge_regions_by_frame"],
|
||||
),
|
||||
config_fields=[
|
||||
StageConfigField(name="enabled", type="bool", default=True, description="Enable region analysis"),
|
||||
StageConfigField(name="edge_canny_low", type="int", default=50, description="Canny low threshold", min=0, max=255),
|
||||
StageConfigField(name="edge_canny_high", type="int", default=150, description="Canny high threshold", min=0, max=255),
|
||||
StageConfigField(name="edge_hough_threshold", type="int", default=80, description="Hough accumulator threshold", min=1, max=500),
|
||||
StageConfigField(name="edge_hough_min_length", type="int", default=100, description="Min line length (px)", min=10, max=2000),
|
||||
StageConfigField(name="edge_hough_max_gap", type="int", default=10, description="Max line gap (px)", min=1, max=100),
|
||||
StageConfigField(name="edge_pair_max_distance", type="int", default=200, description="Max distance between line pair (px)", min=10, max=500),
|
||||
StageConfigField(name="edge_pair_min_distance", type="int", default=15, description="Min distance between line pair (px)", min=5, max=200),
|
||||
],
|
||||
)
|
||||
register_stage(edge_detection, serialize_fn=_ser_regions, deserialize_fn=_deser_regions)
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Registration for detection stages: YOLO, OCR."""
|
||||
|
||||
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
|
||||
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
|
||||
from core.detect.stages.base import register_stage
|
||||
from ._serializers import (
|
||||
serialize_dataclass_list,
|
||||
serialize_text_candidates,
|
||||
@@ -38,14 +39,12 @@ def register():
|
||||
category="detection",
|
||||
io=StageIO(reads=["filtered_frames"], writes=["boxes_by_frame"]),
|
||||
config_fields=[
|
||||
StageConfigField("model_name", "str", "yolov8n.pt", "YOLO model file"),
|
||||
StageConfigField("confidence_threshold", "float", 0.3, "Min detection confidence", min=0.0, max=1.0),
|
||||
StageConfigField("target_classes", "list[str]", [], "YOLO classes to detect (empty = all)"),
|
||||
StageConfigField(name="model_name", type="str", default="yolov8n.pt", description="YOLO model file"),
|
||||
StageConfigField(name="confidence_threshold", type="float", default=0.3, description="Min detection confidence", min=0.0, max=1.0),
|
||||
StageConfigField(name="target_classes", type="list[str]", default=[], description="YOLO classes to detect (empty = all)"),
|
||||
],
|
||||
serialize_fn=_ser_detect,
|
||||
deserialize_fn=_deser_detect,
|
||||
)
|
||||
register_stage(yolo)
|
||||
register_stage(yolo, serialize_fn=_ser_detect, deserialize_fn=_deser_detect)
|
||||
|
||||
ocr = StageDefinition(
|
||||
name="run_ocr",
|
||||
@@ -54,10 +53,8 @@ def register():
|
||||
category="detection",
|
||||
io=StageIO(reads=["filtered_frames", "boxes_by_frame"], writes=["text_candidates"]),
|
||||
config_fields=[
|
||||
StageConfigField("languages", "list[str]", ["en"], "OCR languages"),
|
||||
StageConfigField("min_confidence", "float", 0.5, "Min OCR confidence", min=0.0, max=1.0),
|
||||
StageConfigField(name="languages", type="list[str]", default=["en"], description="OCR languages"),
|
||||
StageConfigField(name="min_confidence", type="float", default=0.5, description="Min OCR confidence", min=0.0, max=1.0),
|
||||
],
|
||||
serialize_fn=_ser_ocr,
|
||||
deserialize_fn=_deser_ocr,
|
||||
)
|
||||
register_stage(ocr)
|
||||
register_stage(ocr, serialize_fn=_ser_ocr, deserialize_fn=_deser_ocr)
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Registration for escalation stages: local VLM, cloud LLM."""
|
||||
|
||||
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
|
||||
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
|
||||
from core.detect.stages.base import register_stage
|
||||
from ._serializers import (
|
||||
serialize_dataclass_list,
|
||||
serialize_text_candidates,
|
||||
@@ -37,12 +38,10 @@ def register():
|
||||
optional_reads=["source_asset_id"],
|
||||
),
|
||||
config_fields=[
|
||||
StageConfigField("min_confidence", "float", 0.5, "Min VLM confidence", min=0.0, max=1.0),
|
||||
StageConfigField(name="min_confidence", type="float", default=0.5, description="Min VLM confidence", min=0.0, max=1.0),
|
||||
],
|
||||
serialize_fn=_ser_escalation,
|
||||
deserialize_fn=_deser_escalation,
|
||||
)
|
||||
register_stage(vlm)
|
||||
register_stage(vlm, serialize_fn=_ser_escalation, deserialize_fn=_deser_escalation)
|
||||
|
||||
cloud = StageDefinition(
|
||||
name="escalate_cloud",
|
||||
@@ -55,9 +54,7 @@ def register():
|
||||
optional_reads=["source_asset_id"],
|
||||
),
|
||||
config_fields=[
|
||||
StageConfigField("min_confidence", "float", 0.4, "Min cloud confidence", min=0.0, max=1.0),
|
||||
StageConfigField(name="min_confidence", type="float", default=0.4, description="Min cloud confidence", min=0.0, max=1.0),
|
||||
],
|
||||
serialize_fn=_ser_escalation,
|
||||
deserialize_fn=_deser_escalation,
|
||||
)
|
||||
register_stage(cloud)
|
||||
register_stage(cloud, serialize_fn=_ser_escalation, deserialize_fn=_deser_escalation)
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Registration for output stages: report compilation."""
|
||||
|
||||
from detect.stages.base import StageDefinition, StageIO, register_stage
|
||||
from core.detect.stages.base import StageDefinition, StageIO, register_stage
|
||||
from ._serializers import serialize_dataclass, deserialize_detection_report
|
||||
|
||||
|
||||
@@ -26,7 +26,5 @@ def register():
|
||||
category="output",
|
||||
io=StageIO(reads=["detections"], writes=["report"]),
|
||||
config_fields=[],
|
||||
serialize_fn=_ser_report,
|
||||
deserialize_fn=_deser_report,
|
||||
)
|
||||
register_stage(report)
|
||||
register_stage(report, serialize_fn=_ser_report, deserialize_fn=_deser_report)
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Registration for preprocessing stages: frame extraction, scene filter, image preprocessing."""
|
||||
|
||||
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
|
||||
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
|
||||
from core.detect.stages.base import register_stage
|
||||
from ._serializers import serialize_frames, deserialize_frames
|
||||
|
||||
|
||||
@@ -44,13 +45,11 @@ def register():
|
||||
category="preprocessing",
|
||||
io=StageIO(reads=["video_path"], writes=["frames"]),
|
||||
config_fields=[
|
||||
StageConfigField("fps", "float", 2.0, "Frames per second", min=0.1, max=30.0),
|
||||
StageConfigField("max_frames", "int", 500, "Maximum frames to extract", min=1, max=10000),
|
||||
StageConfigField(name="fps", type="float", default=2.0, description="Frames per second", min=0.1, max=30.0),
|
||||
StageConfigField(name="max_frames", type="int", default=500, description="Maximum frames to extract", min=1, max=10000),
|
||||
],
|
||||
serialize_fn=_ser_extract,
|
||||
deserialize_fn=_deser_extract,
|
||||
)
|
||||
register_stage(extract)
|
||||
register_stage(extract, serialize_fn=_ser_extract, deserialize_fn=_deser_extract)
|
||||
|
||||
scene_filter = StageDefinition(
|
||||
name="filter_scenes",
|
||||
@@ -59,13 +58,11 @@ def register():
|
||||
category="preprocessing",
|
||||
io=StageIO(reads=["frames"], writes=["filtered_frames"]),
|
||||
config_fields=[
|
||||
StageConfigField("hamming_threshold", "int", 8, "Hamming distance threshold", min=0, max=64),
|
||||
StageConfigField("enabled", "bool", True, "Enable scene filtering"),
|
||||
StageConfigField(name="hamming_threshold", type="int", default=8, description="Hamming distance threshold", min=0, max=64),
|
||||
StageConfigField(name="enabled", type="bool", default=True, description="Enable scene filtering"),
|
||||
],
|
||||
serialize_fn=_ser_filter,
|
||||
deserialize_fn=_deser_filter,
|
||||
)
|
||||
register_stage(scene_filter)
|
||||
register_stage(scene_filter, serialize_fn=_ser_filter, deserialize_fn=_deser_filter)
|
||||
|
||||
preprocess = StageDefinition(
|
||||
name="preprocess",
|
||||
@@ -77,11 +74,9 @@ def register():
|
||||
writes=["preprocessed_crops"],
|
||||
),
|
||||
config_fields=[
|
||||
StageConfigField("contrast", "bool", True, "CLAHE contrast enhancement"),
|
||||
StageConfigField("deskew", "bool", False, "Correct slight rotation"),
|
||||
StageConfigField("binarize", "bool", False, "Otsu binarization"),
|
||||
StageConfigField(name="contrast", type="bool", default=True, description="CLAHE contrast enhancement"),
|
||||
StageConfigField(name="deskew", type="bool", default=False, description="Correct slight rotation"),
|
||||
StageConfigField(name="binarize", type="bool", default=False, description="Otsu binarization"),
|
||||
],
|
||||
serialize_fn=_ser_preprocess,
|
||||
deserialize_fn=_deser_preprocess,
|
||||
)
|
||||
register_stage(preprocess)
|
||||
register_stage(preprocess, serialize_fn=_ser_preprocess, deserialize_fn=_deser_preprocess)
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Registration for resolution stages: brand resolver."""
|
||||
|
||||
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
|
||||
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
|
||||
from core.detect.stages.base import register_stage
|
||||
from ._serializers import (
|
||||
serialize_dataclass_list,
|
||||
serialize_text_candidates,
|
||||
@@ -37,9 +38,7 @@ def register():
|
||||
optional_reads=["session_brands", "source_asset_id"],
|
||||
),
|
||||
config_fields=[
|
||||
StageConfigField("fuzzy_threshold", "int", 75, "Fuzzy match threshold", min=0, max=100),
|
||||
StageConfigField(name="fuzzy_threshold", type="int", default=75, description="Fuzzy match threshold", min=0, max=100),
|
||||
],
|
||||
serialize_fn=_ser_brands,
|
||||
deserialize_fn=_deser_brands,
|
||||
)
|
||||
register_stage(resolver)
|
||||
register_stage(resolver, serialize_fn=_ser_brands, deserialize_fn=_deser_brands)
|
||||
@@ -14,9 +14,9 @@ import time
|
||||
import imagehash
|
||||
from PIL import Image
|
||||
|
||||
from detect import emit
|
||||
from detect.models import Frame
|
||||
from detect.profiles.base import SceneFilterConfig
|
||||
from core.detect import emit
|
||||
from core.detect.models import Frame
|
||||
from core.detect.stages.models import SceneFilterConfig
|
||||
|
||||
|
||||
def _compute_hashes(frames: list[Frame]) -> list[imagehash.ImageHash]:
|
||||
@@ -19,10 +19,10 @@ import time
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from detect import emit
|
||||
from detect.models import BrandDetection, PipelineStats, TextCandidate
|
||||
from detect.profiles.base import CropContext
|
||||
from detect.providers import get_provider, has_api_key
|
||||
from core.detect import emit
|
||||
from core.detect.models import BrandDetection, PipelineStats, TextCandidate
|
||||
from core.detect.models import CropContext
|
||||
from core.detect.providers import get_provider, has_api_key
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -33,7 +33,7 @@ def _register_discovered_brand(brand: str, source_asset_id: str | None,
|
||||
timestamp: float, confidence: float):
|
||||
"""Register a cloud-confirmed brand in the DB."""
|
||||
try:
|
||||
from detect.stages.brand_resolver import _register_brand, _record_sighting
|
||||
from core.detect.stages.brand_resolver import _register_brand, _record_sighting
|
||||
brand_id = _register_brand(brand, "cloud_llm")
|
||||
if brand_id and source_asset_id:
|
||||
_record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, "cloud_llm")
|
||||
@@ -14,9 +14,9 @@ import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from detect import emit
|
||||
from detect.models import BrandDetection, TextCandidate
|
||||
from detect.profiles.base import CropContext
|
||||
from core.detect import emit
|
||||
from core.detect.models import BrandDetection, TextCandidate
|
||||
from core.detect.models import CropContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,7 +25,7 @@ def _register_discovered_brand(brand: str, source_asset_id: str | None,
|
||||
timestamp: float, confidence: float, source: str):
|
||||
"""Register a VLM-confirmed brand in the DB."""
|
||||
try:
|
||||
from detect.stages.brand_resolver import _register_brand, _record_sighting
|
||||
from core.detect.stages.brand_resolver import _register_brand, _record_sighting
|
||||
brand_id = _register_brand(brand, source)
|
||||
if brand_id and source_asset_id:
|
||||
_record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, source)
|
||||
@@ -75,8 +75,8 @@ def escalate_vlm(
|
||||
still_unresolved: list[TextCandidate] = []
|
||||
|
||||
if inference_url:
|
||||
from detect.inference import InferenceClient
|
||||
from detect.emit import _run_log_level
|
||||
from core.detect.inference import InferenceClient
|
||||
from core.detect.emit import _run_log_level
|
||||
client = InferenceClient(base_url=inference_url, job_id=job_id or "", log_level=_run_log_level)
|
||||
|
||||
for i, candidate in enumerate(candidates):
|
||||
@@ -152,6 +152,6 @@ def escalate_vlm(
|
||||
|
||||
def _vlm_local(crop: np.ndarray, prompt: str) -> tuple[str, float, str]:
|
||||
"""Run moondream2 in-process (single-box mode)."""
|
||||
from gpu.models.vlm import query
|
||||
from core.gpu.models.vlm import query
|
||||
result = query(crop, prompt)
|
||||
return result["brand"], result["confidence"], result["reasoning"]
|
||||
@@ -18,9 +18,9 @@ import time
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from detect import emit
|
||||
from detect.models import BoundingBox, Frame
|
||||
from detect.profiles.base import DetectionConfig
|
||||
from core.detect import emit
|
||||
from core.detect.models import BoundingBox, Frame
|
||||
from core.detect.stages.models import DetectionConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -36,7 +36,7 @@ def _frame_to_b64(frame: Frame) -> str:
|
||||
def _detect_remote(frame: Frame, config: DetectionConfig, inference_url: str,
|
||||
job_id: str = "", log_level: str = "INFO") -> list[BoundingBox]:
|
||||
"""Call the inference server over HTTP."""
|
||||
from detect.inference import InferenceClient
|
||||
from core.detect.inference import InferenceClient
|
||||
client = InferenceClient(base_url=inference_url, job_id=job_id, log_level=log_level)
|
||||
results = client.detect(
|
||||
image=frame.image,
|
||||
@@ -104,7 +104,7 @@ def detect_objects(
|
||||
for i, frame in enumerate(frames):
|
||||
t0 = time.monotonic()
|
||||
if inference_url:
|
||||
from detect.emit import _run_log_level
|
||||
from core.detect.emit import _run_log_level
|
||||
boxes = _detect_remote(frame, config, inference_url,
|
||||
job_id=job_id or "", log_level=_run_log_level)
|
||||
else:
|
||||
@@ -9,7 +9,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import TypedDict
|
||||
|
||||
from detect.models import BoundingBox, BrandDetection, DetectionReport, Frame, PipelineStats, TextCandidate
|
||||
from core.detect.models import BoundingBox, BrandDetection, DetectionReport, Frame, PipelineStats, TextCandidate
|
||||
|
||||
|
||||
class DetectState(TypedDict, total=False):
|
||||
@@ -22,6 +22,9 @@ class DetectState(TypedDict, total=False):
|
||||
# Stage outputs
|
||||
frames: list[Frame]
|
||||
filtered_frames: list[Frame]
|
||||
field_masks: dict # {seq: np.ndarray} — pitch mask per frame
|
||||
field_boundaries: dict # {seq: [(x,y), ...]} — pitch boundary per frame
|
||||
field_coverage: dict # {seq: float} — pitch coverage ratio per frame
|
||||
edge_regions_by_frame: dict[int, list[BoundingBox]]
|
||||
boxes_by_frame: dict[int, list[BoundingBox]]
|
||||
preprocessed_crops: dict # "{frame_seq}_{box_idx}" → np.ndarray
|
||||
@@ -36,5 +39,5 @@ class DetectState(TypedDict, total=False):
|
||||
# Running stats (updated by each stage)
|
||||
stats: PipelineStats
|
||||
|
||||
# Config overrides for replay (applied via OverrideProfile)
|
||||
# Config overrides for replay (merged into profile configs dict)
|
||||
config_overrides: dict
|
||||
@@ -6,7 +6,7 @@ and stage-level metadata. The Langfuse client is optional — if not configured
|
||||
(no LANGFUSE_SECRET_KEY), tracing is a no-op.
|
||||
|
||||
Usage in graph nodes:
|
||||
from detect.tracing import trace_node
|
||||
from core.detect.tracing import trace_node
|
||||
|
||||
def node_extract_frames(state):
|
||||
with trace_node(state, "extract_frames") as span:
|
||||
@@ -1,13 +1,6 @@
|
||||
from .capabilities import get_decoders, get_encoders, get_formats
|
||||
from .probe import ProbeResult, probe_file
|
||||
from .transcode import TranscodeConfig, transcode
|
||||
|
||||
__all__ = [
|
||||
"probe_file",
|
||||
"ProbeResult",
|
||||
"transcode",
|
||||
"TranscodeConfig",
|
||||
"get_encoders",
|
||||
"get_decoders",
|
||||
"get_formats",
|
||||
]
|
||||
|
||||
@@ -1,145 +0,0 @@
|
||||
"""
|
||||
FFmpeg capabilities - Discover available codecs and formats using ffmpeg-python.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import ffmpeg
|
||||
|
||||
|
||||
@dataclass
|
||||
class Codec:
|
||||
"""An FFmpeg encoder or decoder."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
type: str # 'video' or 'audio'
|
||||
|
||||
|
||||
@dataclass
|
||||
class Format:
|
||||
"""An FFmpeg format (muxer/demuxer)."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
can_demux: bool
|
||||
can_mux: bool
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_ffmpeg_info() -> Dict[str, Any]:
|
||||
"""Get FFmpeg capabilities info."""
|
||||
# ffmpeg-python doesn't have a direct way to get codecs/formats
|
||||
# but we can use probe on a dummy or parse -codecs output
|
||||
# For now, return common codecs that are typically available
|
||||
return {
|
||||
"video_encoders": [
|
||||
{"name": "libx264", "description": "H.264 / AVC"},
|
||||
{"name": "libx265", "description": "H.265 / HEVC"},
|
||||
{"name": "mpeg4", "description": "MPEG-4 Part 2"},
|
||||
{"name": "libvpx", "description": "VP8"},
|
||||
{"name": "libvpx-vp9", "description": "VP9"},
|
||||
{"name": "h264_nvenc", "description": "NVIDIA NVENC H.264"},
|
||||
{"name": "hevc_nvenc", "description": "NVIDIA NVENC H.265"},
|
||||
{"name": "h264_vaapi", "description": "VAAPI H.264"},
|
||||
{"name": "prores_ks", "description": "Apple ProRes"},
|
||||
{"name": "dnxhd", "description": "Avid DNxHD/DNxHR"},
|
||||
{"name": "copy", "description": "Stream copy (no encoding)"},
|
||||
],
|
||||
"audio_encoders": [
|
||||
{"name": "aac", "description": "AAC"},
|
||||
{"name": "libmp3lame", "description": "MP3"},
|
||||
{"name": "libopus", "description": "Opus"},
|
||||
{"name": "libvorbis", "description": "Vorbis"},
|
||||
{"name": "pcm_s16le", "description": "PCM signed 16-bit little-endian"},
|
||||
{"name": "flac", "description": "FLAC"},
|
||||
{"name": "copy", "description": "Stream copy (no encoding)"},
|
||||
],
|
||||
"formats": [
|
||||
{"name": "mp4", "description": "MP4", "can_demux": True, "can_mux": True},
|
||||
{
|
||||
"name": "mov",
|
||||
"description": "QuickTime / MOV",
|
||||
"can_demux": True,
|
||||
"can_mux": True,
|
||||
},
|
||||
{
|
||||
"name": "mkv",
|
||||
"description": "Matroska",
|
||||
"can_demux": True,
|
||||
"can_mux": True,
|
||||
},
|
||||
{"name": "webm", "description": "WebM", "can_demux": True, "can_mux": True},
|
||||
{"name": "avi", "description": "AVI", "can_demux": True, "can_mux": True},
|
||||
{"name": "flv", "description": "FLV", "can_demux": True, "can_mux": True},
|
||||
{
|
||||
"name": "ts",
|
||||
"description": "MPEG-TS",
|
||||
"can_demux": True,
|
||||
"can_mux": True,
|
||||
},
|
||||
{
|
||||
"name": "mpegts",
|
||||
"description": "MPEG-TS",
|
||||
"can_demux": True,
|
||||
"can_mux": True,
|
||||
},
|
||||
{"name": "hls", "description": "HLS", "can_demux": True, "can_mux": True},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def get_encoders() -> List[Codec]:
|
||||
"""Get available encoders (video + audio)."""
|
||||
info = _get_ffmpeg_info()
|
||||
codecs = []
|
||||
|
||||
for c in info["video_encoders"]:
|
||||
codecs.append(Codec(name=c["name"], description=c["description"], type="video"))
|
||||
|
||||
for c in info["audio_encoders"]:
|
||||
codecs.append(Codec(name=c["name"], description=c["description"], type="audio"))
|
||||
|
||||
return codecs
|
||||
|
||||
|
||||
def get_decoders() -> List[Codec]:
|
||||
"""Get available decoders."""
|
||||
# Most encoders can also decode
|
||||
return get_encoders()
|
||||
|
||||
|
||||
def get_formats() -> List[Format]:
|
||||
"""Get available formats."""
|
||||
info = _get_ffmpeg_info()
|
||||
return [
|
||||
Format(
|
||||
name=f["name"],
|
||||
description=f["description"],
|
||||
can_demux=f["can_demux"],
|
||||
can_mux=f["can_mux"],
|
||||
)
|
||||
for f in info["formats"]
|
||||
]
|
||||
|
||||
|
||||
def get_video_encoders() -> List[Codec]:
|
||||
"""Get available video encoders."""
|
||||
return [c for c in get_encoders() if c.type == "video"]
|
||||
|
||||
|
||||
def get_audio_encoders() -> List[Codec]:
|
||||
"""Get available audio encoders."""
|
||||
return [c for c in get_encoders() if c.type == "audio"]
|
||||
|
||||
|
||||
def get_muxers() -> List[Format]:
|
||||
"""Get available output formats (muxers)."""
|
||||
return [f for f in get_formats() if f.can_mux]
|
||||
|
||||
|
||||
def get_demuxers() -> List[Format]:
|
||||
"""Get available input formats (demuxers)."""
|
||||
return [f for f in get_formats() if f.can_demux]
|
||||
@@ -1,225 +0,0 @@
|
||||
"""
|
||||
FFmpeg transcode module - Transcode media files using ffmpeg-python.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import ffmpeg
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscodeConfig:
|
||||
"""Configuration for a transcode operation."""
|
||||
|
||||
input_path: str
|
||||
output_path: str
|
||||
|
||||
# Video
|
||||
video_codec: str = "libx264"
|
||||
video_bitrate: Optional[str] = None
|
||||
video_crf: Optional[int] = None
|
||||
video_preset: Optional[str] = None
|
||||
resolution: Optional[str] = None
|
||||
framerate: Optional[float] = None
|
||||
|
||||
# Audio
|
||||
audio_codec: str = "aac"
|
||||
audio_bitrate: Optional[str] = None
|
||||
audio_channels: Optional[int] = None
|
||||
audio_samplerate: Optional[int] = None
|
||||
|
||||
# Trimming
|
||||
trim_start: Optional[float] = None
|
||||
trim_end: Optional[float] = None
|
||||
|
||||
# Container
|
||||
container: str = "mp4"
|
||||
|
||||
# Extra args (key-value pairs)
|
||||
extra_args: List[str] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def is_copy(self) -> bool:
|
||||
"""Check if this is a stream copy (no transcoding)."""
|
||||
return self.video_codec == "copy" and self.audio_codec == "copy"
|
||||
|
||||
|
||||
def build_stream(config: TranscodeConfig):
|
||||
"""
|
||||
Build an ffmpeg-python stream from config.
|
||||
|
||||
Returns the stream object ready to run.
|
||||
"""
|
||||
# Input options
|
||||
input_kwargs = {}
|
||||
if config.trim_start is not None:
|
||||
input_kwargs["ss"] = config.trim_start
|
||||
|
||||
stream = ffmpeg.input(config.input_path, **input_kwargs)
|
||||
|
||||
# Output options
|
||||
output_kwargs = {
|
||||
"vcodec": config.video_codec,
|
||||
"acodec": config.audio_codec,
|
||||
}
|
||||
|
||||
# Trimming duration
|
||||
if config.trim_end is not None:
|
||||
if config.trim_start is not None:
|
||||
output_kwargs["t"] = config.trim_end - config.trim_start
|
||||
else:
|
||||
output_kwargs["t"] = config.trim_end
|
||||
|
||||
# Video options (skip if copy)
|
||||
if config.video_codec != "copy":
|
||||
if config.video_crf is not None:
|
||||
output_kwargs["crf"] = config.video_crf
|
||||
elif config.video_bitrate:
|
||||
output_kwargs["video_bitrate"] = config.video_bitrate
|
||||
|
||||
if config.video_preset:
|
||||
output_kwargs["preset"] = config.video_preset
|
||||
|
||||
if config.resolution:
|
||||
output_kwargs["s"] = config.resolution
|
||||
|
||||
if config.framerate:
|
||||
output_kwargs["r"] = config.framerate
|
||||
|
||||
# Audio options (skip if copy)
|
||||
if config.audio_codec != "copy":
|
||||
if config.audio_bitrate:
|
||||
output_kwargs["audio_bitrate"] = config.audio_bitrate
|
||||
if config.audio_channels:
|
||||
output_kwargs["ac"] = config.audio_channels
|
||||
if config.audio_samplerate:
|
||||
output_kwargs["ar"] = config.audio_samplerate
|
||||
|
||||
# Parse extra args into kwargs
|
||||
extra_kwargs = parse_extra_args(config.extra_args)
|
||||
output_kwargs.update(extra_kwargs)
|
||||
|
||||
stream = ffmpeg.output(stream, config.output_path, **output_kwargs)
|
||||
stream = ffmpeg.overwrite_output(stream)
|
||||
|
||||
return stream
|
||||
|
||||
|
||||
def parse_extra_args(extra_args: List[str]) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse extra args list into kwargs dict.
|
||||
|
||||
["-vtag", "xvid", "-pix_fmt", "yuv420p"] -> {"vtag": "xvid", "pix_fmt": "yuv420p"}
|
||||
"""
|
||||
kwargs = {}
|
||||
i = 0
|
||||
while i < len(extra_args):
|
||||
key = extra_args[i].lstrip("-")
|
||||
if i + 1 < len(extra_args) and not extra_args[i + 1].startswith("-"):
|
||||
kwargs[key] = extra_args[i + 1]
|
||||
i += 2
|
||||
else:
|
||||
# Flag without value
|
||||
kwargs[key] = None
|
||||
i += 1
|
||||
return kwargs
|
||||
|
||||
|
||||
def transcode(
|
||||
config: TranscodeConfig,
|
||||
duration: Optional[float] = None,
|
||||
progress_callback: Optional[Callable[[float, Dict[str, Any]], None]] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Transcode a media file.
|
||||
|
||||
Args:
|
||||
config: Transcode configuration
|
||||
duration: Total duration in seconds (for progress calculation, optional)
|
||||
progress_callback: Called with (percent, details_dict) - requires duration
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
|
||||
Raises:
|
||||
ffmpeg.Error: If transcoding fails
|
||||
"""
|
||||
# Ensure output directory exists
|
||||
Path(config.output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
stream = build_stream(config)
|
||||
|
||||
if progress_callback and duration:
|
||||
# Run with progress tracking using run_async
|
||||
return _run_with_progress(stream, config, duration, progress_callback)
|
||||
else:
|
||||
# Run synchronously
|
||||
ffmpeg.run(stream, capture_stdout=True, capture_stderr=True)
|
||||
return True
|
||||
|
||||
|
||||
def _run_with_progress(
|
||||
stream,
|
||||
config: TranscodeConfig,
|
||||
duration: float,
|
||||
progress_callback: Callable[[float, Dict[str, Any]], None],
|
||||
) -> bool:
|
||||
"""Run FFmpeg with progress tracking using run_async and stderr parsing."""
|
||||
import re
|
||||
|
||||
# Calculate effective duration
|
||||
effective_duration = duration
|
||||
if config.trim_start and config.trim_end:
|
||||
effective_duration = config.trim_end - config.trim_start
|
||||
elif config.trim_end:
|
||||
effective_duration = config.trim_end
|
||||
elif config.trim_start:
|
||||
effective_duration = duration - config.trim_start
|
||||
|
||||
# Run async to get process handle
|
||||
process = ffmpeg.run_async(stream, pipe_stdout=True, pipe_stderr=True)
|
||||
|
||||
# Parse stderr for progress (time=HH:MM:SS.ms pattern)
|
||||
time_pattern = re.compile(r"time=(\d+):(\d+):(\d+)\.(\d+)")
|
||||
|
||||
while True:
|
||||
line = process.stderr.readline()
|
||||
if not line:
|
||||
break
|
||||
|
||||
line = line.decode("utf-8", errors="ignore")
|
||||
match = time_pattern.search(line)
|
||||
if match:
|
||||
hours = int(match.group(1))
|
||||
minutes = int(match.group(2))
|
||||
seconds = int(match.group(3))
|
||||
ms = int(match.group(4))
|
||||
|
||||
current_time = hours * 3600 + minutes * 60 + seconds + ms / 100
|
||||
percent = min(100.0, (current_time / effective_duration) * 100)
|
||||
|
||||
progress_callback(
|
||||
percent,
|
||||
{
|
||||
"time": current_time,
|
||||
"percent": percent,
|
||||
},
|
||||
)
|
||||
|
||||
# Wait for completion
|
||||
process.wait()
|
||||
|
||||
if process.returncode != 0:
|
||||
raise ffmpeg.Error(
|
||||
"ffmpeg", stdout=process.stdout.read(), stderr=process.stderr.read()
|
||||
)
|
||||
|
||||
# Final callback
|
||||
progress_callback(
|
||||
100.0, {"time": effective_duration, "percent": 100.0, "done": True}
|
||||
)
|
||||
|
||||
return True
|
||||
6
core/gpu/models/__init__.py
Normal file
6
core/gpu/models/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# GPU models — standalone container imports.
|
||||
# When running as a container (cd gpu && python server.py), bare imports work.
|
||||
# When imported from the main app (core.gpu.models.preprocess), only
|
||||
# individual modules should be imported directly, not this __init__.
|
||||
#
|
||||
# The server.py imports detect/ocr/vlm directly, not through this file.
|
||||
86
core/gpu/models/cv/segmentation.py
Normal file
86
core/gpu/models/cv/segmentation.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
Field segmentation — HSV green mask → pitch boundary contour.
|
||||
|
||||
Pure OpenCV. Called by the inference server endpoint.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def segment_field(
|
||||
image: np.ndarray,
|
||||
hue_low: int = 30,
|
||||
hue_high: int = 85,
|
||||
sat_low: int = 30,
|
||||
sat_high: int = 255,
|
||||
val_low: int = 30,
|
||||
val_high: int = 255,
|
||||
morph_kernel: int = 15,
|
||||
min_area_ratio: float = 0.05,
|
||||
) -> dict:
|
||||
"""
|
||||
Detect the pitch area using HSV green thresholding.
|
||||
|
||||
Returns dict with:
|
||||
boundary: list of [x, y] points
|
||||
coverage: float (fraction of frame)
|
||||
"""
|
||||
hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
|
||||
|
||||
lower = np.array([hue_low, sat_low, val_low])
|
||||
upper = np.array([hue_high, sat_high, val_high])
|
||||
mask = cv2.inRange(hsv, lower, upper)
|
||||
|
||||
k = morph_kernel
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
|
||||
mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
|
||||
mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
|
||||
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
h, w = image.shape[:2]
|
||||
min_area = min_area_ratio * h * w
|
||||
boundary = []
|
||||
coverage = 0.0
|
||||
|
||||
if contours:
|
||||
large = [c for c in contours if cv2.contourArea(c) >= min_area]
|
||||
if large:
|
||||
pitch_contour = max(large, key=cv2.contourArea)
|
||||
boundary = pitch_contour.reshape(-1, 2).tolist()
|
||||
coverage = cv2.contourArea(pitch_contour) / (h * w)
|
||||
|
||||
refined = np.zeros_like(mask)
|
||||
cv2.drawContours(refined, [pitch_contour], -1, 255, cv2.FILLED)
|
||||
mask = refined
|
||||
|
||||
return {
|
||||
"boundary": boundary,
|
||||
"coverage": coverage,
|
||||
"mask": mask,
|
||||
}
|
||||
|
||||
|
||||
def segment_field_debug(
|
||||
image: np.ndarray,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""Same as segment_field but includes a mask overlay for the editor."""
|
||||
result = segment_field(image, **kwargs)
|
||||
mask = result["mask"]
|
||||
|
||||
# RGBA overlay: solid green where mask, fully transparent elsewhere
|
||||
h, w = image.shape[:2]
|
||||
overlay = np.zeros((h, w, 4), dtype=np.uint8)
|
||||
overlay[mask > 0] = [0, 255, 0, 255]
|
||||
_, buf = cv2.imencode(".png", overlay)
|
||||
result["mask_overlay_b64"] = base64.b64encode(buf.tobytes()).decode()
|
||||
|
||||
# Don't send the raw mask over HTTP
|
||||
del result["mask"]
|
||||
return result
|
||||
@@ -101,6 +101,30 @@ class AnalyzeRegionsDebugResponse(BaseModel):
|
||||
horizontal_count: int = 0
|
||||
pair_count: int = 0
|
||||
|
||||
class SegmentFieldRequest(BaseModel):
|
||||
"""Request body for field segmentation."""
|
||||
image: str
|
||||
hue_low: int = 30
|
||||
hue_high: int = 85
|
||||
sat_low: int = 30
|
||||
sat_high: int = 255
|
||||
val_low: int = 30
|
||||
val_high: int = 255
|
||||
morph_kernel: int = 15
|
||||
min_area_ratio: float = 0.05
|
||||
|
||||
class SegmentFieldResponse(BaseModel):
|
||||
"""Response from field segmentation."""
|
||||
boundary: List[List[int]] = Field(default_factory=list)
|
||||
coverage: float = 0.0
|
||||
mask_b64: str = ""
|
||||
|
||||
class SegmentFieldDebugResponse(BaseModel):
|
||||
"""Response from field segmentation with debug overlay."""
|
||||
boundary: List[List[int]] = Field(default_factory=list)
|
||||
coverage: float = 0.0
|
||||
mask_overlay_b64: str = ""
|
||||
|
||||
class ConfigUpdate(BaseModel):
|
||||
"""Request body for updating server configuration."""
|
||||
device: Optional[str] = None
|
||||
@@ -54,7 +54,7 @@ def _gpu_log(job_id: str, log_level: str, stage: str, level: str, msg: str):
|
||||
|
||||
# --- Request/Response models (generated from core/schema/models/inference.py) ---
|
||||
|
||||
from models.inference_contract import (
|
||||
from models.models import (
|
||||
AnalyzeRegionsDebugResponse,
|
||||
AnalyzeRegionsRequest,
|
||||
AnalyzeRegionsResponse,
|
||||
@@ -68,6 +68,9 @@ from models.inference_contract import (
|
||||
PreprocessRequest,
|
||||
PreprocessResponse,
|
||||
RegionBox,
|
||||
SegmentFieldRequest,
|
||||
SegmentFieldResponse,
|
||||
SegmentFieldDebugResponse,
|
||||
VLMRequest,
|
||||
VLMResponse,
|
||||
)
|
||||
@@ -310,6 +313,83 @@ def detect_edges_debug_endpoint(req: AnalyzeRegionsRequest, request: Request):
|
||||
return response
|
||||
|
||||
|
||||
@app.post("/segment_field", response_model=SegmentFieldResponse)
|
||||
def segment_field_endpoint(req: SegmentFieldRequest, request: Request):
|
||||
job_id, log_level = _job_ctx(request)
|
||||
|
||||
try:
|
||||
image = _decode_image(req.image)
|
||||
h, w = image.shape[:2]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
||||
|
||||
try:
|
||||
t0 = time.monotonic()
|
||||
from models.cv.segmentation import segment_field
|
||||
|
||||
result = segment_field(
|
||||
image,
|
||||
hue_low=req.hue_low,
|
||||
hue_high=req.hue_high,
|
||||
sat_low=req.sat_low,
|
||||
sat_high=req.sat_high,
|
||||
val_low=req.val_low,
|
||||
val_high=req.val_high,
|
||||
morph_kernel=req.morph_kernel,
|
||||
min_area_ratio=req.min_area_ratio,
|
||||
)
|
||||
infer_ms = (time.monotonic() - t0) * 1000
|
||||
|
||||
# Encode mask as base64 PNG for downstream use
|
||||
import cv2
|
||||
_, buf = cv2.imencode(".png", result["mask"])
|
||||
mask_b64 = base64.b64encode(buf.tobytes()).decode()
|
||||
|
||||
_gpu_log(job_id, log_level, "GPU:CV", "DEBUG",
|
||||
f"Field segmentation {w}x{h}: {infer_ms:.0f}ms, coverage={result['coverage']:.1%}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Field segmentation failed: {e}")
|
||||
|
||||
return SegmentFieldResponse(
|
||||
boundary=result["boundary"],
|
||||
coverage=result["coverage"],
|
||||
mask_b64=mask_b64,
|
||||
)
|
||||
|
||||
|
||||
@app.post("/segment_field/debug", response_model=SegmentFieldDebugResponse)
|
||||
def segment_field_debug_endpoint(req: SegmentFieldRequest, request: Request):
|
||||
job_id, log_level = _job_ctx(request)
|
||||
|
||||
try:
|
||||
image = _decode_image(req.image)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
||||
|
||||
try:
|
||||
from models.cv.segmentation import segment_field_debug
|
||||
|
||||
result = segment_field_debug(
|
||||
image,
|
||||
hue_low=req.hue_low,
|
||||
hue_high=req.hue_high,
|
||||
sat_low=req.sat_low,
|
||||
sat_high=req.sat_high,
|
||||
val_low=req.val_low,
|
||||
val_high=req.val_high,
|
||||
morph_kernel=req.morph_kernel,
|
||||
min_area_ratio=req.min_area_ratio,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Field segmentation debug failed: {e}")
|
||||
|
||||
return SegmentFieldDebugResponse(
|
||||
boundary=result["boundary"],
|
||||
coverage=result["coverage"],
|
||||
mask_overlay_b64=result["mask_overlay_b64"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
@@ -180,24 +180,11 @@ class WorkerServicer(worker_pb2_grpc.WorkerServiceServicer):
|
||||
|
||||
def GetWorkerStatus(self, request, context):
|
||||
"""Get worker health and capabilities."""
|
||||
try:
|
||||
from core.ffmpeg import get_encoders
|
||||
|
||||
encoders = get_encoders()
|
||||
codec_names = [e["name"] for e in encoders.get("video", [])]
|
||||
except Exception:
|
||||
codec_names = []
|
||||
|
||||
# Check for GPU encoders
|
||||
gpu_available = any(
|
||||
"nvenc" in name or "vaapi" in name or "qsv" in name for name in codec_names
|
||||
)
|
||||
|
||||
return worker_pb2.WorkerStatus(
|
||||
available=True,
|
||||
active_jobs=len(_active_jobs),
|
||||
supported_codecs=codec_names[:20], # Limit to 20
|
||||
gpu_available=gpu_available,
|
||||
supported_codecs=[],
|
||||
gpu_available=False,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user