phase 4
This commit is contained in:
@@ -102,6 +102,7 @@ class Job(models.Model):
|
|||||||
source_asset_id = models.UUIDField()
|
source_asset_id = models.UUIDField()
|
||||||
video_path = models.CharField(max_length=1000)
|
video_path = models.CharField(max_length=1000)
|
||||||
profile_name = models.CharField(max_length=255)
|
profile_name = models.CharField(max_length=255)
|
||||||
|
timeline_id = models.UUIDField(null=True, blank=True)
|
||||||
parent_id = models.UUIDField(null=True, blank=True)
|
parent_id = models.UUIDField(null=True, blank=True)
|
||||||
run_type = models.CharField(max_length=20, choices=RunType.choices, default=RunType.INITIAL)
|
run_type = models.CharField(max_length=20, choices=RunType.choices, default=RunType.INITIAL)
|
||||||
config_overrides = models.JSONField(default=dict, blank=True)
|
config_overrides = models.JSONField(default=dict, blank=True)
|
||||||
@@ -113,7 +114,6 @@ class Job(models.Model):
|
|||||||
brands_found = models.IntegerField(default=0)
|
brands_found = models.IntegerField(default=0)
|
||||||
cloud_llm_calls = models.IntegerField(default=0)
|
cloud_llm_calls = models.IntegerField(default=0)
|
||||||
estimated_cost_usd = models.FloatField(default=0.0)
|
estimated_cost_usd = models.FloatField(default=0.0)
|
||||||
celery_task_id = models.CharField(max_length=255, null=True, blank=True)
|
|
||||||
priority = models.IntegerField(default=0)
|
priority = models.IntegerField(default=0)
|
||||||
created_at = models.DateTimeField(auto_now_add=True)
|
created_at = models.DateTimeField(auto_now_add=True)
|
||||||
started_at = models.DateTimeField(null=True, blank=True)
|
started_at = models.DateTimeField(null=True, blank=True)
|
||||||
@@ -151,6 +151,7 @@ class Checkpoint(models.Model):
|
|||||||
|
|
||||||
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
|
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
|
||||||
timeline_id = models.UUIDField()
|
timeline_id = models.UUIDField()
|
||||||
|
job_id = models.UUIDField(null=True, blank=True)
|
||||||
parent_id = models.UUIDField(null=True, blank=True)
|
parent_id = models.UUIDField(null=True, blank=True)
|
||||||
stage_outputs = models.JSONField(default=dict, blank=True)
|
stage_outputs = models.JSONField(default=dict, blank=True)
|
||||||
config_overrides = models.JSONField(default=dict, blank=True)
|
config_overrides = models.JSONField(default=dict, blank=True)
|
||||||
@@ -185,3 +186,18 @@ class Brand(models.Model):
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return str(self.id)
|
return str(self.id)
|
||||||
|
|
||||||
|
|
||||||
|
class Profile(models.Model):
|
||||||
|
"""A content type profile."""
|
||||||
|
|
||||||
|
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
|
||||||
|
name = models.CharField(max_length=255)
|
||||||
|
pipeline = models.JSONField(default=dict, blank=True)
|
||||||
|
configs = models.JSONField(default=dict, blank=True)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|||||||
@@ -1,73 +0,0 @@
|
|||||||
"""
|
|
||||||
SSE endpoint for chunker pipeline events.
|
|
||||||
|
|
||||||
Uses Redis as the event bus. Pipeline pushes events via core.events,
|
|
||||||
SSE endpoint polls them.
|
|
||||||
|
|
||||||
GET /chunker/stream/{job_id} → text/event-stream
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from typing import AsyncGenerator
|
|
||||||
|
|
||||||
from fastapi import APIRouter
|
|
||||||
from starlette.responses import StreamingResponse
|
|
||||||
|
|
||||||
from core.events import poll_events
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/chunker", tags=["chunker"])
|
|
||||||
|
|
||||||
|
|
||||||
async def _event_generator(job_id: str) -> AsyncGenerator[str, None]:
|
|
||||||
"""
|
|
||||||
Generate SSE events by polling Redis for chunk job events.
|
|
||||||
"""
|
|
||||||
cursor = 0
|
|
||||||
timeout = time.monotonic() + 600 # 10 min max
|
|
||||||
|
|
||||||
while time.monotonic() < timeout:
|
|
||||||
events, cursor = poll_events(job_id, cursor)
|
|
||||||
|
|
||||||
if not events:
|
|
||||||
yield f"event: waiting\ndata: {json.dumps({'job_id': job_id})}\n\n"
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
for data in events:
|
|
||||||
event_type = data.pop("event", "update")
|
|
||||||
payload = {**data, "job_id": job_id}
|
|
||||||
|
|
||||||
yield f"event: {event_type}\ndata: {json.dumps(payload)}\n\n"
|
|
||||||
|
|
||||||
if event_type in ("pipeline_complete", "pipeline_error", "cancelled"):
|
|
||||||
yield f"event: done\ndata: {json.dumps({'job_id': job_id})}\n\n"
|
|
||||||
return
|
|
||||||
|
|
||||||
await asyncio.sleep(0.05)
|
|
||||||
|
|
||||||
yield f"event: timeout\ndata: {json.dumps({'job_id': job_id})}\n\n"
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/stream/{job_id}")
|
|
||||||
async def stream_chunk_job(job_id: str):
|
|
||||||
"""
|
|
||||||
SSE stream for a chunk pipeline job.
|
|
||||||
|
|
||||||
The UI connects via native EventSource:
|
|
||||||
const es = new EventSource('/api/chunker/stream/<job_id>');
|
|
||||||
es.addEventListener('processing', (e) => { ... });
|
|
||||||
"""
|
|
||||||
return StreamingResponse(
|
|
||||||
_event_generator(job_id),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
headers={
|
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
"Connection": "keep-alive",
|
|
||||||
"X-Accel-Buffering": "no",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
@@ -58,32 +58,30 @@ def write_config(update: ConfigUpdate):
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/config/profiles")
|
@router.get("/config/profiles")
|
||||||
def list_profiles():
|
def get_profiles():
|
||||||
"""List available detection profiles."""
|
"""List available detection profiles."""
|
||||||
from detect.profiles import _PROFILES
|
from core.detect.profile import list_profiles as _list
|
||||||
return [{"name": name} for name in _PROFILES]
|
return [{"name": name} for name in _list()]
|
||||||
|
|
||||||
|
|
||||||
@router.get("/config/profiles/{profile_name}/pipeline")
|
@router.get("/config/profiles/{profile_name}/pipeline")
|
||||||
def get_pipeline_config(profile_name: str):
|
def get_pipeline_config(profile_name: str):
|
||||||
"""Return the pipeline composition for a profile."""
|
"""Return the pipeline composition for a profile."""
|
||||||
from detect.profiles import get_profile
|
from core.detect.profile import get_profile
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from dataclasses import asdict
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
profile = get_profile(profile_name)
|
profile = get_profile(profile_name)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise HTTPException(status_code=404, detail=f"Unknown profile: {profile_name}")
|
raise HTTPException(status_code=404, detail=f"Unknown profile: {profile_name}")
|
||||||
|
|
||||||
config = profile.pipeline_config()
|
return profile["pipeline"]
|
||||||
return asdict(config)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/config/stages", response_model=list[StageConfigInfo])
|
@router.get("/config/stages", response_model=list[StageConfigInfo])
|
||||||
def list_stage_configs():
|
def list_stage_configs():
|
||||||
"""Return the stage palette with config field metadata for the editor."""
|
"""Return the stage palette with config field metadata for the editor."""
|
||||||
from detect.stages import list_stages
|
from core.detect.stages import list_stages
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
for stage in list_stages():
|
for stage in list_stages():
|
||||||
@@ -95,7 +93,7 @@ def list_stage_configs():
|
|||||||
@router.get("/config/stages/{stage_name}", response_model=StageConfigInfo)
|
@router.get("/config/stages/{stage_name}", response_model=StageConfigInfo)
|
||||||
def get_stage_config(stage_name: str):
|
def get_stage_config(stage_name: str):
|
||||||
"""Return config field metadata for a single stage."""
|
"""Return config field metadata for a single stage."""
|
||||||
from detect.stages import get_stage
|
from core.detect.stages import get_stage
|
||||||
|
|
||||||
try:
|
try:
|
||||||
stage = get_stage(stage_name)
|
stage = get_stage(stage_name)
|
||||||
|
|||||||
@@ -105,7 +105,7 @@ class ReplaySingleStageResponse(BaseModel):
|
|||||||
@router.get("/checkpoints/{timeline_id}")
|
@router.get("/checkpoints/{timeline_id}")
|
||||||
def list_checkpoints(timeline_id: str) -> list[CheckpointInfo]:
|
def list_checkpoints(timeline_id: str) -> list[CheckpointInfo]:
|
||||||
"""List available checkpoint stages for a job."""
|
"""List available checkpoint stages for a job."""
|
||||||
from detect.checkpoint import list_checkpoints as _list
|
from core.detect.checkpoint import list_checkpoints as _list
|
||||||
|
|
||||||
try:
|
try:
|
||||||
stages = _list(timeline_id)
|
stages = _list(timeline_id)
|
||||||
@@ -139,10 +139,10 @@ class CheckpointData(BaseModel):
|
|||||||
def get_checkpoint_data(timeline_id: str, stage: str):
|
def get_checkpoint_data(timeline_id: str, stage: str):
|
||||||
"""Load checkpoint frames + metadata for the editor UI."""
|
"""Load checkpoint frames + metadata for the editor UI."""
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from core.db.tables import Timeline, Checkpoint
|
from core.db.models import Timeline, Checkpoint
|
||||||
from core.db.connection import get_session
|
from core.db.connection import get_session
|
||||||
from core.db.checkpoint import list_checkpoints
|
from core.db.checkpoint import list_checkpoints
|
||||||
from detect.checkpoint.frames import load_frames_b64
|
from core.detect.checkpoint.frames import load_frames_b64
|
||||||
|
|
||||||
with get_session() as session:
|
with get_session() as session:
|
||||||
timeline = session.get(Timeline, UUID(timeline_id))
|
timeline = session.get(Timeline, UUID(timeline_id))
|
||||||
@@ -184,7 +184,7 @@ def get_checkpoint_data(timeline_id: str, stage: str):
|
|||||||
@router.get("/scenarios", response_model=list[ScenarioInfo])
|
@router.get("/scenarios", response_model=list[ScenarioInfo])
|
||||||
def list_scenarios_endpoint():
|
def list_scenarios_endpoint():
|
||||||
"""List all available scenarios (bookmarked checkpoints)."""
|
"""List all available scenarios (bookmarked checkpoints)."""
|
||||||
from core.db.tables import Timeline
|
from core.db.models import Timeline
|
||||||
from core.db.connection import get_session
|
from core.db.connection import get_session
|
||||||
from core.db.checkpoint import list_scenarios
|
from core.db.checkpoint import list_scenarios
|
||||||
|
|
||||||
@@ -212,7 +212,7 @@ def list_scenarios_endpoint():
|
|||||||
@router.post("/replay", response_model=ReplayResponse)
|
@router.post("/replay", response_model=ReplayResponse)
|
||||||
def replay(req: ReplayRequest):
|
def replay(req: ReplayRequest):
|
||||||
"""Replay pipeline from a specific stage with optional config overrides."""
|
"""Replay pipeline from a specific stage with optional config overrides."""
|
||||||
from detect.checkpoint import replay_from
|
from core.detect.checkpoint import replay_from
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = replay_from(
|
result = replay_from(
|
||||||
@@ -242,7 +242,7 @@ def replay(req: ReplayRequest):
|
|||||||
@router.post("/retry", response_model=RetryResponse)
|
@router.post("/retry", response_model=RetryResponse)
|
||||||
def retry(req: RetryRequest):
|
def retry(req: RetryRequest):
|
||||||
"""Queue an async retry of unresolved candidates with different config."""
|
"""Queue an async retry of unresolved candidates with different config."""
|
||||||
from detect.checkpoint.tasks import retry_candidates
|
from core.detect.checkpoint.tasks import retry_candidates
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"timeline_id": req.timeline_id,
|
"timeline_id": req.timeline_id,
|
||||||
@@ -266,7 +266,7 @@ def retry(req: RetryRequest):
|
|||||||
@router.post("/replay-stage", response_model=ReplaySingleStageResponse)
|
@router.post("/replay-stage", response_model=ReplaySingleStageResponse)
|
||||||
def replay_single_stage(req: ReplaySingleStageRequest):
|
def replay_single_stage(req: ReplaySingleStageRequest):
|
||||||
"""Replay a single stage on specific frames — fast path for interactive tuning."""
|
"""Replay a single stage on specific frames — fast path for interactive tuning."""
|
||||||
from detect.checkpoint.replay import replay_single_stage as _replay
|
from core.detect.checkpoint.replay import replay_single_stage as _replay
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = _replay(
|
result = _replay(
|
||||||
@@ -361,3 +361,41 @@ async def gpu_detect_edges_debug(request: Request):
|
|||||||
media_type="application/json")
|
media_type="application/json")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}")
|
raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/gpu/segment_field")
|
||||||
|
async def gpu_segment_field(request: Request):
|
||||||
|
"""Proxy to GPU inference server — field segmentation."""
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
body = await request.body()
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{_gpu_url()}/segment_field",
|
||||||
|
content=body,
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
)
|
||||||
|
return Response(content=resp.content, status_code=resp.status_code,
|
||||||
|
media_type="application/json")
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/gpu/segment_field/debug")
|
||||||
|
async def gpu_segment_field_debug(request: Request):
|
||||||
|
"""Proxy to GPU inference server — field segmentation with debug overlay."""
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
body = await request.body()
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{_gpu_url()}/segment_field/debug",
|
||||||
|
content=body,
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
)
|
||||||
|
return Response(content=resp.content, status_code=resp.status_code,
|
||||||
|
media_type="application/json")
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}")
|
||||||
|
|||||||
@@ -60,9 +60,9 @@ def _resolve_video_path(video_path: str) -> str:
|
|||||||
@router.post("/run", response_model=RunResponse)
|
@router.post("/run", response_model=RunResponse)
|
||||||
def run_pipeline(req: RunRequest):
|
def run_pipeline(req: RunRequest):
|
||||||
"""Launch a detection pipeline run on a source chunk."""
|
"""Launch a detection pipeline run on a source chunk."""
|
||||||
from detect import emit
|
from core.detect import emit
|
||||||
from detect.graph import get_pipeline
|
from core.detect.graph import get_pipeline
|
||||||
from detect.state import DetectState
|
from core.detect.state import DetectState
|
||||||
|
|
||||||
local_path = _resolve_video_path(req.video_path)
|
local_path = _resolve_video_path(req.video_path)
|
||||||
job_id = str(uuid.uuid4())
|
job_id = str(uuid.uuid4())
|
||||||
@@ -79,7 +79,7 @@ def run_pipeline(req: RunRequest):
|
|||||||
|
|
||||||
# Clear any stale events from a previous run with same job_id
|
# Clear any stale events from a previous run with same job_id
|
||||||
from core.events import _get_redis
|
from core.events import _get_redis
|
||||||
from detect.events import DETECT_EVENTS_PREFIX
|
from core.detect.events import DETECT_EVENTS_PREFIX
|
||||||
r = _get_redis()
|
r = _get_redis()
|
||||||
r.delete(f"{DETECT_EVENTS_PREFIX}:{job_id}")
|
r.delete(f"{DETECT_EVENTS_PREFIX}:{job_id}")
|
||||||
|
|
||||||
@@ -97,7 +97,7 @@ def run_pipeline(req: RunRequest):
|
|||||||
source_asset_id=req.source_asset_id,
|
source_asset_id=req.source_asset_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
from detect.graph import (
|
from core.detect.graph import (
|
||||||
PipelineCancelled, set_cancel_check, clear_cancel_check,
|
PipelineCancelled, set_cancel_check, clear_cancel_check,
|
||||||
init_pause, clear_pause,
|
init_pause, clear_pause,
|
||||||
)
|
)
|
||||||
@@ -117,7 +117,7 @@ def run_pipeline(req: RunRequest):
|
|||||||
emit.job_complete(job_id, {"status": "cancelled"})
|
emit.job_complete(job_id, {"status": "cancelled"})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Pipeline run %s failed: %s", job_id, e)
|
logger.exception("Pipeline run %s failed: %s", job_id, e)
|
||||||
from detect.graph import _node_states, NODES
|
from core.detect.graph import _node_states, NODES
|
||||||
if job_id in _node_states:
|
if job_id in _node_states:
|
||||||
states = _node_states[job_id]
|
states = _node_states[job_id]
|
||||||
for node in reversed(NODES):
|
for node in reversed(NODES):
|
||||||
@@ -145,7 +145,7 @@ def run_pipeline(req: RunRequest):
|
|||||||
@router.post("/stop/{job_id}")
|
@router.post("/stop/{job_id}")
|
||||||
def stop_pipeline(job_id: str):
|
def stop_pipeline(job_id: str):
|
||||||
"""Stop a running pipeline. Signals cancellation; the thread checks on next stage."""
|
"""Stop a running pipeline. Signals cancellation; the thread checks on next stage."""
|
||||||
from detect import emit
|
from core.detect import emit
|
||||||
|
|
||||||
if job_id not in _running_jobs:
|
if job_id not in _running_jobs:
|
||||||
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
|
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
|
||||||
@@ -158,7 +158,7 @@ def stop_pipeline(job_id: str):
|
|||||||
@router.post("/pause/{job_id}")
|
@router.post("/pause/{job_id}")
|
||||||
def pause(job_id: str):
|
def pause(job_id: str):
|
||||||
"""Pause a running pipeline after the current stage completes."""
|
"""Pause a running pipeline after the current stage completes."""
|
||||||
from detect.graph import pause_pipeline
|
from core.detect.graph import pause_pipeline
|
||||||
|
|
||||||
if job_id not in _running_jobs:
|
if job_id not in _running_jobs:
|
||||||
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
|
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
|
||||||
@@ -170,7 +170,7 @@ def pause(job_id: str):
|
|||||||
@router.post("/resume/{job_id}")
|
@router.post("/resume/{job_id}")
|
||||||
def resume(job_id: str):
|
def resume(job_id: str):
|
||||||
"""Resume a paused pipeline."""
|
"""Resume a paused pipeline."""
|
||||||
from detect.graph import resume_pipeline
|
from core.detect.graph import resume_pipeline
|
||||||
|
|
||||||
if job_id not in _running_jobs:
|
if job_id not in _running_jobs:
|
||||||
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
|
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
|
||||||
@@ -182,7 +182,7 @@ def resume(job_id: str):
|
|||||||
@router.post("/step/{job_id}")
|
@router.post("/step/{job_id}")
|
||||||
def step(job_id: str):
|
def step(job_id: str):
|
||||||
"""Run one stage then pause again."""
|
"""Run one stage then pause again."""
|
||||||
from detect.graph import step_pipeline
|
from core.detect.graph import step_pipeline
|
||||||
|
|
||||||
if job_id not in _running_jobs:
|
if job_id not in _running_jobs:
|
||||||
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
|
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
|
||||||
@@ -194,7 +194,7 @@ def step(job_id: str):
|
|||||||
@router.post("/pause-after-stage/{job_id}")
|
@router.post("/pause-after-stage/{job_id}")
|
||||||
def toggle_pause_after_stage(job_id: str, enabled: bool = True):
|
def toggle_pause_after_stage(job_id: str, enabled: bool = True):
|
||||||
"""Toggle pause-after-each-stage mode."""
|
"""Toggle pause-after-each-stage mode."""
|
||||||
from detect.graph import set_pause_after_stage
|
from core.detect.graph import set_pause_after_stage
|
||||||
|
|
||||||
if job_id not in _running_jobs:
|
if job_id not in _running_jobs:
|
||||||
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
|
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
|
||||||
@@ -206,7 +206,7 @@ def toggle_pause_after_stage(job_id: str, enabled: bool = True):
|
|||||||
@router.get("/status/{job_id}")
|
@router.get("/status/{job_id}")
|
||||||
def pipeline_status(job_id: str):
|
def pipeline_status(job_id: str):
|
||||||
"""Get pipeline run status."""
|
"""Get pipeline run status."""
|
||||||
from detect.graph import is_paused
|
from core.detect.graph import is_paused
|
||||||
|
|
||||||
running = job_id in _running_jobs
|
running = job_id in _running_jobs
|
||||||
paused = is_paused(job_id)
|
paused = is_paused(job_id)
|
||||||
@@ -224,11 +224,23 @@ def pipeline_status(job_id: str):
|
|||||||
return {"status": status, "job_id": job_id}
|
return {"status": status, "job_id": job_id}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/timeline/{job_id}")
|
||||||
|
def get_timeline_for_job(job_id: str):
|
||||||
|
"""Get the timeline_id for a running or completed job."""
|
||||||
|
from core.detect.checkpoint.runner_bridge import get_timeline_id
|
||||||
|
|
||||||
|
tid = get_timeline_id(job_id)
|
||||||
|
if tid is None:
|
||||||
|
raise HTTPException(status_code=404, detail=f"No timeline for job: {job_id}")
|
||||||
|
|
||||||
|
return {"timeline_id": tid, "job_id": job_id}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/clear/{job_id}")
|
@router.post("/clear/{job_id}")
|
||||||
def clear_pipeline(job_id: str):
|
def clear_pipeline(job_id: str):
|
||||||
"""Clear events for a job from Redis."""
|
"""Clear events for a job from Redis."""
|
||||||
from core.events import _get_redis
|
from core.events import _get_redis
|
||||||
from detect.events import DETECT_EVENTS_PREFIX
|
from core.detect.events import DETECT_EVENTS_PREFIX
|
||||||
|
|
||||||
r = _get_redis()
|
r = _get_redis()
|
||||||
r.delete(f"{DETECT_EVENTS_PREFIX}:{job_id}")
|
r.delete(f"{DETECT_EVENTS_PREFIX}:{job_id}")
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from fastapi import APIRouter
|
|||||||
from starlette.responses import StreamingResponse
|
from starlette.responses import StreamingResponse
|
||||||
|
|
||||||
from core.events import poll_events
|
from core.events import poll_events
|
||||||
from detect.events import DETECT_EVENTS_PREFIX, TERMINAL_EVENTS
|
from core.detect.events import DETECT_EVENTS_PREFIX, TERMINAL_EVENTS
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -1,384 +0,0 @@
|
|||||||
"""
|
|
||||||
GraphQL API using strawberry, served via FastAPI.
|
|
||||||
|
|
||||||
Primary API for MPR — all client interactions go through GraphQL.
|
|
||||||
Uses core.db for data access.
|
|
||||||
Types are generated from schema/ via modelgen — see api/schema/graphql.py.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import List, Optional
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
import strawberry
|
|
||||||
from strawberry.schema.config import StrawberryConfig
|
|
||||||
from strawberry.types import Info
|
|
||||||
|
|
||||||
from core.api.schema.graphql import (
|
|
||||||
CancelResultType,
|
|
||||||
ChunkJobType,
|
|
||||||
ChunkOutputFileType,
|
|
||||||
CreateChunkJobInput,
|
|
||||||
CreateJobInput,
|
|
||||||
DeleteResultType,
|
|
||||||
MediaAssetType,
|
|
||||||
ScanResultType,
|
|
||||||
SystemStatusType,
|
|
||||||
TranscodeJobType,
|
|
||||||
TranscodePresetType,
|
|
||||||
UpdateAssetInput,
|
|
||||||
)
|
|
||||||
from core.storage import BUCKET_IN, list_objects, upload_file
|
|
||||||
|
|
||||||
VIDEO_EXTS = {".mp4", ".mkv", ".avi", ".mov", ".webm", ".flv", ".wmv", ".m4v"}
|
|
||||||
AUDIO_EXTS = {".mp3", ".wav", ".flac", ".aac", ".ogg", ".m4a"}
|
|
||||||
MEDIA_EXTS = VIDEO_EXTS | AUDIO_EXTS
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Queries
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@strawberry.type
|
|
||||||
class Query:
|
|
||||||
@strawberry.field
|
|
||||||
def assets(
|
|
||||||
self,
|
|
||||||
info: Info,
|
|
||||||
status: Optional[str] = None,
|
|
||||||
search: Optional[str] = None,
|
|
||||||
) -> List[MediaAssetType]:
|
|
||||||
from core.db import list_assets
|
|
||||||
|
|
||||||
return list_assets(status=status, search=search)
|
|
||||||
|
|
||||||
@strawberry.field
|
|
||||||
def asset(self, info: Info, id: UUID) -> Optional[MediaAssetType]:
|
|
||||||
from core.db import get_asset
|
|
||||||
|
|
||||||
try:
|
|
||||||
return get_asset(id)
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
@strawberry.field
|
|
||||||
def jobs(
|
|
||||||
self,
|
|
||||||
info: Info,
|
|
||||||
status: Optional[str] = None,
|
|
||||||
source_asset_id: Optional[UUID] = None,
|
|
||||||
) -> List[TranscodeJobType]:
|
|
||||||
from core.db import list_jobs
|
|
||||||
|
|
||||||
return list_jobs(status=status, source_asset_id=source_asset_id)
|
|
||||||
|
|
||||||
@strawberry.field
|
|
||||||
def job(self, info: Info, id: UUID) -> Optional[TranscodeJobType]:
|
|
||||||
from core.db import get_job
|
|
||||||
|
|
||||||
try:
|
|
||||||
return get_job(id)
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
@strawberry.field
|
|
||||||
def presets(self, info: Info) -> List[TranscodePresetType]:
|
|
||||||
from core.db import list_presets
|
|
||||||
|
|
||||||
return list_presets()
|
|
||||||
|
|
||||||
@strawberry.field
|
|
||||||
def system_status(self, info: Info) -> SystemStatusType:
|
|
||||||
return SystemStatusType(status="ok", version="0.1.0")
|
|
||||||
|
|
||||||
@strawberry.field
|
|
||||||
def chunk_output_files(self, info: Info, job_id: str) -> List[ChunkOutputFileType]:
|
|
||||||
"""List output chunk files for a completed job from media/out/."""
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
media_out = os.environ.get("MEDIA_OUT_DIR", "/app/media/out")
|
|
||||||
output_dir = Path(media_out) / "chunks" / job_id
|
|
||||||
if not output_dir.is_dir():
|
|
||||||
return []
|
|
||||||
return [
|
|
||||||
ChunkOutputFileType(
|
|
||||||
key=f.name,
|
|
||||||
size=f.stat().st_size,
|
|
||||||
url=f"/media/out/chunks/{job_id}/{f.name}",
|
|
||||||
)
|
|
||||||
for f in sorted(output_dir.iterdir())
|
|
||||||
if f.is_file()
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Mutations
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@strawberry.type
|
|
||||||
class Mutation:
|
|
||||||
@strawberry.mutation
|
|
||||||
def scan_media_folder(self, info: Info) -> ScanResultType:
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from core.db import create_asset, get_asset_filenames
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Sync local media/in/ files to MinIO (handles fresh installs / pruned volumes)
|
|
||||||
local_media = Path("/app/media/in")
|
|
||||||
if local_media.is_dir():
|
|
||||||
existing_keys = {o["key"] for o in list_objects(BUCKET_IN)}
|
|
||||||
for f in local_media.iterdir():
|
|
||||||
if f.is_file() and f.suffix.lower() in MEDIA_EXTS:
|
|
||||||
if f.name not in existing_keys:
|
|
||||||
try:
|
|
||||||
upload_file(str(f), BUCKET_IN, f.name)
|
|
||||||
logger.info("Uploaded %s to MinIO", f.name)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning("Failed to upload %s: %s", f.name, e)
|
|
||||||
|
|
||||||
objects = list_objects(BUCKET_IN, extensions=MEDIA_EXTS)
|
|
||||||
existing = get_asset_filenames()
|
|
||||||
|
|
||||||
registered = []
|
|
||||||
skipped = []
|
|
||||||
|
|
||||||
for obj in objects:
|
|
||||||
if obj["filename"] in existing:
|
|
||||||
skipped.append(obj["filename"])
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
create_asset(
|
|
||||||
filename=obj["filename"],
|
|
||||||
file_path=obj["key"],
|
|
||||||
file_size=obj["size"],
|
|
||||||
)
|
|
||||||
registered.append(obj["filename"])
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return ScanResultType(
|
|
||||||
found=len(objects),
|
|
||||||
registered=len(registered),
|
|
||||||
skipped=len(skipped),
|
|
||||||
files=registered,
|
|
||||||
)
|
|
||||||
|
|
||||||
@strawberry.mutation
|
|
||||||
def create_job(self, info: Info, input: CreateJobInput) -> TranscodeJobType:
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from core.db import create_job, get_asset, get_preset
|
|
||||||
|
|
||||||
try:
|
|
||||||
source = get_asset(input.source_asset_id)
|
|
||||||
except Exception:
|
|
||||||
raise Exception("Source asset not found")
|
|
||||||
|
|
||||||
preset = None
|
|
||||||
preset_snapshot = {}
|
|
||||||
if input.preset_id:
|
|
||||||
try:
|
|
||||||
preset = get_preset(input.preset_id)
|
|
||||||
preset_snapshot = {
|
|
||||||
"name": preset.name,
|
|
||||||
"container": preset.container,
|
|
||||||
"video_codec": preset.video_codec,
|
|
||||||
"audio_codec": preset.audio_codec,
|
|
||||||
}
|
|
||||||
except Exception:
|
|
||||||
raise Exception("Preset not found")
|
|
||||||
|
|
||||||
if not preset and not input.trim_start and not input.trim_end:
|
|
||||||
raise Exception("Must specify preset_id or trim_start/trim_end")
|
|
||||||
|
|
||||||
output_filename = input.output_filename
|
|
||||||
if not output_filename:
|
|
||||||
stem = Path(source.filename).stem
|
|
||||||
ext = preset_snapshot.get("container", "mp4") if preset else "mp4"
|
|
||||||
output_filename = f"{stem}_output.{ext}"
|
|
||||||
|
|
||||||
job = create_job(
|
|
||||||
source_asset_id=source.id,
|
|
||||||
preset_id=preset.id if preset else None,
|
|
||||||
preset_snapshot=preset_snapshot,
|
|
||||||
trim_start=input.trim_start,
|
|
||||||
trim_end=input.trim_end,
|
|
||||||
output_filename=output_filename,
|
|
||||||
output_path=output_filename,
|
|
||||||
priority=input.priority or 0,
|
|
||||||
)
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"source_key": source.file_path,
|
|
||||||
"output_key": output_filename,
|
|
||||||
"preset": preset_snapshot or None,
|
|
||||||
"trim_start": input.trim_start,
|
|
||||||
"trim_end": input.trim_end,
|
|
||||||
"duration": source.duration,
|
|
||||||
}
|
|
||||||
|
|
||||||
executor_mode = os.environ.get("MPR_EXECUTOR", "local")
|
|
||||||
if executor_mode in ("lambda", "gcp"):
|
|
||||||
from core.jobs.executor import get_executor
|
|
||||||
|
|
||||||
get_executor().run(
|
|
||||||
job_type="transcode",
|
|
||||||
job_id=str(job.id),
|
|
||||||
payload=payload,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
from core.jobs.task import run_job
|
|
||||||
|
|
||||||
result = run_job.delay(
|
|
||||||
job_type="transcode",
|
|
||||||
job_id=str(job.id),
|
|
||||||
payload=payload,
|
|
||||||
)
|
|
||||||
job.celery_task_id = result.id
|
|
||||||
job.save(update_fields=["celery_task_id"])
|
|
||||||
|
|
||||||
return job
|
|
||||||
|
|
||||||
@strawberry.mutation
|
|
||||||
def cancel_job(self, info: Info, id: UUID) -> TranscodeJobType:
|
|
||||||
from core.db import get_job, update_job
|
|
||||||
|
|
||||||
try:
|
|
||||||
job = get_job(id)
|
|
||||||
except Exception:
|
|
||||||
raise Exception("Job not found")
|
|
||||||
|
|
||||||
if job.status not in ("pending", "processing"):
|
|
||||||
raise Exception(f"Cannot cancel job with status: {job.status}")
|
|
||||||
|
|
||||||
return update_job(job, status="cancelled")
|
|
||||||
|
|
||||||
@strawberry.mutation
def retry_job(self, info: Info, id: UUID) -> TranscodeJobType:
    """Reset a failed transcode job back to the pending state."""
    from core.db import get_job, update_job

    try:
        job = get_job(id)
    except Exception:
        raise Exception("Job not found")

    if job.status == "failed":
        # Clear failure bookkeeping so the job can be picked up again.
        return update_job(job, status="pending", progress=0, error_message=None)
    raise Exception("Only failed jobs can be retried")
|
|
||||||
|
|
||||||
@strawberry.mutation
def update_asset(self, info: Info, id: UUID, input: UpdateAssetInput) -> MediaAssetType:
    """Patch user-editable asset metadata (comments, tags).

    Fields left as None in the input are not modified.
    """
    from core.db import get_asset, update_asset

    try:
        asset = get_asset(id)
    except Exception:
        raise Exception("Asset not found")

    # Collect only the fields the caller actually supplied.
    changes = {
        name: value
        for name, value in (("comments", input.comments), ("tags", input.tags))
        if value is not None
    }
    if changes:
        asset = update_asset(asset, **changes)

    return asset
|
|
||||||
|
|
||||||
@strawberry.mutation
def delete_asset(self, info: Info, id: UUID) -> DeleteResultType:
    """Delete a media asset by id.

    Raises:
        Exception: "Asset not found" when no asset exists for ``id``.
    """
    from core.db import delete_asset, get_asset

    # Only the lookup maps to "Asset not found"; a failure while deleting an
    # asset that DOES exist must propagate instead of masquerading as a 404
    # (the original try covered both calls and misreported delete errors).
    try:
        asset = get_asset(id)
    except Exception:
        raise Exception("Asset not found")

    delete_asset(asset)
    return DeleteResultType(ok=True)
|
|
||||||
|
|
||||||
@strawberry.mutation
def create_chunk_job(self, info: Info, input: CreateChunkJobInput) -> ChunkJobType:
    """Create and dispatch a chunk pipeline job."""
    import uuid

    from core.db import get_asset

    try:
        source = get_asset(input.source_asset_id)
    except Exception:
        raise Exception("Source asset not found")

    job_id = str(uuid.uuid4())

    payload = {
        "source_key": source.file_path,
        "chunk_duration": input.chunk_duration,
        "num_workers": input.num_workers,
        "max_retries": input.max_retries,
        "processor_type": input.processor_type,
        "start_time": input.start_time,
        "end_time": input.end_time,
    }

    celery_task_id = None
    if os.environ.get("MPR_EXECUTOR", "local") in ("lambda", "gcp"):
        # Serverless path: the executor tracks its own invocation.
        from core.jobs.executor import get_executor

        get_executor().run(job_type="chunk", job_id=job_id, payload=payload)
    else:
        # Local path: enqueue on Celery and keep the task id for revocation.
        from core.jobs.task import run_job

        celery_task_id = run_job.delay(
            job_type="chunk",
            job_id=job_id,
            payload=payload,
        ).id

    return ChunkJobType(
        id=uuid.UUID(job_id),
        source_asset_id=input.source_asset_id,
        chunk_duration=input.chunk_duration,
        num_workers=input.num_workers,
        max_retries=input.max_retries,
        processor_type=input.processor_type,
        status="pending",
        progress=0.0,
        priority=input.priority,
        celery_task_id=celery_task_id,
    )
|
|
||||||
|
|
||||||
@strawberry.mutation
def cancel_chunk_job(self, info: Info, celery_task_id: str) -> CancelResultType:
    """Cancel a running chunk job by revoking its Celery task."""
    try:
        from admin.mpr.celery import app as celery_app

        # terminate=True sends SIGTERM to a task that is already executing.
        celery_app.control.revoke(celery_task_id, terminate=True, signal="SIGTERM")
    except Exception as e:
        return CancelResultType(ok=False, message=str(e))
    return CancelResultType(ok=True, message="Task revoked")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Schema
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
# Root GraphQL schema. auto_camel_case=False keeps field names in snake_case
# exactly as declared on the Python types.
schema = strawberry.Schema(
    query=Query,
    mutation=Mutation,
    config=StrawberryConfig(auto_camel_case=False),
)
|
|
||||||
@@ -1,48 +1,38 @@
|
|||||||
"""
|
"""
|
||||||
MPR FastAPI Application
|
MPR FastAPI Application
|
||||||
|
|
||||||
Serves GraphQL API and Lambda callback endpoint.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from typing import Optional
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
# Add project root to path
|
# Add project root to path
|
||||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||||
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI, Header, HTTPException
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from strawberry.fastapi import GraphQLRouter
|
|
||||||
|
|
||||||
from core.api.chunker_sse import router as chunker_router
|
|
||||||
from core.api.detect import router as detect_router
|
from core.api.detect import router as detect_router
|
||||||
from core.api.graphql import schema as graphql_schema
|
|
||||||
|
|
||||||
CALLBACK_API_KEY = os.environ.get("CALLBACK_API_KEY", "")
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
async def lifespan(app):
    """FastAPI lifespan: prepare the database before serving requests."""
    # Create DB tables and seed built-in profiles on startup.
    from core.db.connection import create_tables
    from core.db.seed import seed_profiles

    create_tables()
    seed_profiles()
    yield
|
||||||
|
|
||||||
|
|
||||||
# Application instance; interactive docs exposed at /docs and /redoc.
app = FastAPI(
    title="MPR API",
    version="0.1.0",
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan,
)
|
||||||
|
|
||||||
# CORS
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["http://mpr.local.ar", "http://k8s.mpr.local.ar", "http://localhost:5173"],
|
allow_origins=["http://mpr.local.ar", "http://k8s.mpr.local.ar", "http://localhost:5173"],
|
||||||
@@ -51,13 +41,6 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# GraphQL
|
|
||||||
graphql_router = GraphQLRouter(schema=graphql_schema, graphql_ide="graphiql")
|
|
||||||
app.include_router(graphql_router, prefix="/graphql")
|
|
||||||
|
|
||||||
# Chunker SSE
|
|
||||||
app.include_router(chunker_router)
|
|
||||||
|
|
||||||
# Detection API (sources, run, SSE, replay, config)
|
# Detection API (sources, run, SSE, replay, config)
|
||||||
app.include_router(detect_router)
|
app.include_router(detect_router)
|
||||||
|
|
||||||
@@ -69,48 +52,7 @@ def health():
|
|||||||
|
|
||||||
@app.get("/")
def root():
    """Return basic service identity information."""
    return {
        "name": "MPR API",
        "version": "0.1.0",
    }
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/jobs/{job_id}/callback")
def job_callback(
    job_id: UUID,
    payload: dict,
    x_api_key: Optional[str] = Header(None),
):
    """
    Callback endpoint for Lambda to report job completion.
    Protected by API key.

    Raises:
        HTTPException: 403 on a bad API key, 404 when the job does not exist.
    """
    import hmac

    # Constant-time comparison: plain != leaks key bytes via timing.
    if CALLBACK_API_KEY and not hmac.compare_digest(x_api_key or "", CALLBACK_API_KEY):
        raise HTTPException(status_code=403, detail="Invalid API key")

    from django.utils import timezone

    from core.db import get_job, update_job

    try:
        job = get_job(job_id)
    except Exception:
        raise HTTPException(status_code=404, detail="Job not found")

    status = payload.get("status", "failed")
    fields = {
        "status": status,
        # Only a completed run forces 100%; otherwise keep the last reported value.
        "progress": 100.0 if status == "completed" else job.progress,
    }

    if payload.get("error"):
        fields["error_message"] = payload["error"]

    if status in ("completed", "failed"):
        fields["completed_at"] = timezone.now()

    update_job(job, **fields)

    return {"ok": True}
|
|
||||||
|
|||||||
@@ -1,226 +0,0 @@
|
|||||||
"""
|
|
||||||
Strawberry Types - GENERATED FILE
|
|
||||||
|
|
||||||
Do not edit directly. Regenerate using modelgen.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import strawberry
|
|
||||||
from enum import Enum
|
|
||||||
from typing import List, Optional
|
|
||||||
from uuid import UUID
|
|
||||||
from datetime import datetime
|
|
||||||
from strawberry.scalars import JSON
|
|
||||||
|
|
||||||
|
|
||||||
@strawberry.enum
class AssetStatus(Enum):
    """Processing status values for a media asset."""

    PENDING = "pending"
    READY = "ready"
    ERROR = "error"
|
|
||||||
|
|
||||||
|
|
||||||
@strawberry.enum
class JobStatus(Enum):
    """Status values for a transcode job."""

    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
|
||||||
|
|
||||||
|
|
||||||
@strawberry.type
class MediaAssetType:
    """A video/audio file registered in the system."""

    # Identity and storage
    id: Optional[UUID] = None
    filename: Optional[str] = None
    file_path: Optional[str] = None
    # Ingest state
    status: Optional[str] = None
    error_message: Optional[str] = None
    # Probed media properties
    file_size: Optional[float] = None
    duration: Optional[float] = None
    video_codec: Optional[str] = None
    audio_codec: Optional[str] = None
    width: Optional[int] = None
    height: Optional[int] = None
    framerate: Optional[float] = None
    bitrate: Optional[int] = None
    properties: Optional[JSON] = None
    # User-editable metadata
    comments: Optional[str] = None
    tags: Optional[List[str]] = None
    # Timestamps
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
|
|
||||||
|
|
||||||
|
|
||||||
@strawberry.type
class TranscodePresetType:
    """A reusable transcoding configuration (like Handbrake presets)."""

    # Identity
    id: Optional[UUID] = None
    name: Optional[str] = None
    description: Optional[str] = None
    is_builtin: Optional[bool] = None
    # Container / video settings
    container: Optional[str] = None
    video_codec: Optional[str] = None
    video_bitrate: Optional[str] = None
    video_crf: Optional[int] = None
    video_preset: Optional[str] = None
    resolution: Optional[str] = None
    framerate: Optional[float] = None
    # Audio settings
    audio_codec: Optional[str] = None
    audio_bitrate: Optional[str] = None
    audio_channels: Optional[int] = None
    audio_samplerate: Optional[int] = None
    # Extra raw arguments passed through to the encoder
    extra_args: Optional[List[str]] = None
    # Timestamps
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
|
|
||||||
|
|
||||||
|
|
||||||
@strawberry.type
class TranscodeJobType:
    """A transcoding or trimming job in the queue."""

    # Identity and source
    id: Optional[UUID] = None
    source_asset_id: Optional[UUID] = None
    # Preset configuration (snapshot taken at job creation)
    preset_id: Optional[UUID] = None
    preset_snapshot: Optional[JSON] = None
    trim_start: Optional[float] = None
    trim_end: Optional[float] = None
    # Output
    output_filename: Optional[str] = None
    output_path: Optional[str] = None
    output_asset_id: Optional[UUID] = None
    # Progress / state
    status: Optional[str] = None
    progress: Optional[float] = None
    current_frame: Optional[int] = None
    current_time: Optional[float] = None
    speed: Optional[str] = None
    error_message: Optional[str] = None
    # Dispatch tracking (Celery task or cloud execution)
    celery_task_id: Optional[str] = None
    execution_arn: Optional[str] = None
    priority: Optional[int] = None
    # Timestamps
    created_at: Optional[datetime] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
|
|
||||||
|
|
||||||
|
|
||||||
@strawberry.input
class CreateJobInput:
    """Request body for creating a transcode/trim job."""

    source_asset_id: UUID
    preset_id: Optional[UUID] = None
    # Optional trim window in seconds
    trim_start: Optional[float] = None
    trim_end: Optional[float] = None
    output_filename: Optional[str] = None
    priority: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
@strawberry.input
class UpdateAssetInput:
    """Request body for updating asset metadata; None means leave unchanged."""

    comments: Optional[str] = None
    tags: Optional[List[str]] = None
|
|
||||||
|
|
||||||
|
|
||||||
@strawberry.type
class SystemStatusType:
    """System status response."""

    status: Optional[str] = None
    version: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
@strawberry.type
class ScanResultType:
    """Result of scanning the media input bucket."""

    found: Optional[int] = None
    registered: Optional[int] = None
    skipped: Optional[int] = None
    files: Optional[List[str]] = None
|
|
||||||
|
|
||||||
|
|
||||||
@strawberry.type
class DeleteResultType:
    """Result of a delete operation."""

    ok: Optional[bool] = None
|
|
||||||
|
|
||||||
|
|
||||||
@strawberry.type
class WorkerStatusType:
    """Worker health and capabilities."""

    available: Optional[bool] = None
    active_jobs: Optional[int] = None
    supported_codecs: Optional[List[str]] = None
    gpu_available: Optional[bool] = None
|
|
||||||
|
|
||||||
|
|
||||||
@strawberry.enum
class ChunkJobStatus(Enum):
    """Lifecycle states of a chunk pipeline job."""

    PENDING = "pending"
    CHUNKING = "chunking"
    PROCESSING = "processing"
    COLLECTING = "collecting"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
|
||||||
|
|
||||||
|
|
||||||
@strawberry.type
class ChunkJobType:
    """A chunk pipeline job."""

    # Identity and source
    id: Optional[UUID] = None
    source_asset_id: Optional[UUID] = None
    # Pipeline configuration
    chunk_duration: Optional[float] = None
    num_workers: Optional[int] = None
    max_retries: Optional[int] = None
    processor_type: Optional[str] = None
    # Progress / state
    status: Optional[str] = None
    progress: Optional[float] = None
    total_chunks: Optional[int] = None
    processed_chunks: Optional[int] = None
    failed_chunks: Optional[int] = None
    retry_count: Optional[int] = None
    error_message: Optional[str] = None
    # Runtime metrics
    throughput_mbps: Optional[float] = None
    elapsed_seconds: Optional[float] = None
    # Dispatch tracking
    celery_task_id: Optional[str] = None
    priority: Optional[int] = None
    # Timestamps
    created_at: Optional[datetime] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
|
|
||||||
|
|
||||||
|
|
||||||
@strawberry.input
class CreateChunkJobInput:
    """Request body for creating a chunk pipeline job."""

    source_asset_id: UUID
    chunk_duration: float = 10.0
    num_workers: int = 4
    max_retries: int = 3
    processor_type: str = "ffmpeg"
    priority: int = 0
    # Optional time window (seconds) limiting which part of the source is chunked
    start_time: Optional[float] = None
    end_time: Optional[float] = None
|
|
||||||
|
|
||||||
|
|
||||||
@strawberry.type
class CancelResultType:
    """Result of cancelling a chunk job."""

    ok: bool = False
    message: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
@strawberry.type
class ChunkOutputFileType:
    """A chunk output file in S3/MinIO with presigned download URL."""

    key: str
    size: int = 0
    url: str = ""
|
|
||||||
@@ -1,64 +0,0 @@
|
|||||||
"""
|
|
||||||
Chunker pipeline — splits files into chunks, processes concurrently, reassembles in order.
|
|
||||||
|
|
||||||
Public API:
|
|
||||||
Pipeline — orchestrates the full pipeline
|
|
||||||
PipelineResult — aggregate result dataclass
|
|
||||||
Chunker — file → Chunk generator
|
|
||||||
ChunkQueue — bounded thread-safe queue
|
|
||||||
WorkerPool — manages N worker threads
|
|
||||||
ResultCollector — heapq-based ordered reassembly
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .chunker import Chunker
|
|
||||||
from .collector import ResultCollector
|
|
||||||
from .exceptions import (
|
|
||||||
ChunkChecksumError,
|
|
||||||
ChunkError,
|
|
||||||
ChunkReadError,
|
|
||||||
PipelineError,
|
|
||||||
ProcessingError,
|
|
||||||
ProcessorFailureError,
|
|
||||||
ProcessorTimeoutError,
|
|
||||||
ReassemblyError,
|
|
||||||
)
|
|
||||||
from .models import Chunk, ChunkResult, PipelineResult
|
|
||||||
from .pipeline import Pipeline
|
|
||||||
from .pool import WorkerPool
|
|
||||||
from .processor import (
|
|
||||||
ChecksumProcessor,
|
|
||||||
CompositeProcessor,
|
|
||||||
FFmpegExtractProcessor,
|
|
||||||
Processor,
|
|
||||||
SimulatedDecodeProcessor,
|
|
||||||
)
|
|
||||||
from .queue import ChunkQueue
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
# Core
|
|
||||||
"Pipeline",
|
|
||||||
"PipelineResult",
|
|
||||||
# Components
|
|
||||||
"Chunker",
|
|
||||||
"ChunkQueue",
|
|
||||||
"WorkerPool",
|
|
||||||
"ResultCollector",
|
|
||||||
# Models
|
|
||||||
"Chunk",
|
|
||||||
"ChunkResult",
|
|
||||||
# Processors
|
|
||||||
"Processor",
|
|
||||||
"ChecksumProcessor",
|
|
||||||
"SimulatedDecodeProcessor",
|
|
||||||
"CompositeProcessor",
|
|
||||||
"FFmpegExtractProcessor",
|
|
||||||
# Exceptions
|
|
||||||
"PipelineError",
|
|
||||||
"ChunkError",
|
|
||||||
"ChunkReadError",
|
|
||||||
"ChunkChecksumError",
|
|
||||||
"ProcessingError",
|
|
||||||
"ProcessorFailureError",
|
|
||||||
"ProcessorTimeoutError",
|
|
||||||
"ReassemblyError",
|
|
||||||
]
|
|
||||||
@@ -1,101 +0,0 @@
|
|||||||
"""
|
|
||||||
Chunker — probes a media file and yields time-based Chunk objects.
|
|
||||||
|
|
||||||
Demonstrates:
|
|
||||||
- Function parameters and defaults (Interview Topic 1)
|
|
||||||
- List comprehensions and efficient iteration / generators (Interview Topic 3)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
from typing import Generator
|
|
||||||
|
|
||||||
from core.ffmpeg.probe import probe_file
|
|
||||||
|
|
||||||
from .exceptions import ChunkReadError
|
|
||||||
from .models import Chunk
|
|
||||||
|
|
||||||
|
|
||||||
class Chunker:
    """
    Splits a media file into time-based chunks via a generator.

    Uses FFmpeg probe to get duration, then yields Chunk objects
    representing time segments (no data read — extraction happens in the processor).

    Args:
        file_path: Path to the source media file
        chunk_duration: Duration of each chunk in seconds (default: 10.0)
        start_time: Optional range start in seconds (clamped to >= 0)
        end_time: Optional range end in seconds (clamped to the file duration)
    """

    def __init__(
        self,
        file_path: str,
        chunk_duration: float = 10.0,
        start_time: float | None = None,
        end_time: float | None = None,
    ):
        if not os.path.isfile(file_path):
            raise ChunkReadError(f"File not found: {file_path}")
        if chunk_duration <= 0:
            raise ValueError("chunk_duration must be positive")

        self.file_path = file_path
        self.chunk_duration = chunk_duration
        self.file_size = os.path.getsize(file_path)
        full_duration = self._probe_duration()

        # Clamp the requested window into [0, full_duration].
        self.range_start = max(start_time or 0.0, 0.0)
        self.range_end = min(end_time or full_duration, full_duration)
        if self.range_start >= self.range_end:
            raise ValueError(
                f"Invalid range: start={self.range_start} >= end={self.range_end}"
            )
        self.source_duration = self.range_end - self.range_start

    def _probe_duration(self) -> float:
        """Get source file duration via FFmpeg probe."""
        try:
            probed = probe_file(self.file_path)
        except Exception as e:
            raise ChunkReadError(
                f"Failed to probe {self.file_path}: {e}"
            ) from e
        if probed.duration is None or probed.duration <= 0:
            raise ChunkReadError(
                f"Cannot determine duration for {self.file_path}"
            )
        return probed.duration

    @property
    def expected_chunks(self) -> int:
        """Calculate expected number of chunks (last chunk may be shorter)."""
        if self.source_duration <= 0:
            return 0
        return math.ceil(self.source_duration / self.chunk_duration)

    def chunks(self) -> Generator[Chunk, None, None]:
        """
        Yield Chunk objects representing time segments of the source file.

        Generator-based: chunks are yielded on demand.
        Each chunk defines a time range — actual extraction is done by the processor.
        """
        step = self.chunk_duration
        for seq in range(self.expected_chunks):
            begin = self.range_start + seq * step
            finish = min(begin + step, self.range_end)
            yield Chunk(
                sequence=seq,
                start_time=begin,
                end_time=finish,
                source_path=self.file_path,
                duration=finish - begin,
            )
|
|
||||||
@@ -1,98 +0,0 @@
|
|||||||
"""
|
|
||||||
ResultCollector — reassembles chunk results in sequence order using a min-heap.
|
|
||||||
|
|
||||||
Demonstrates:
|
|
||||||
- Algorithms and sorting (Interview Topic 6) — heapq for ordered reassembly
|
|
||||||
- Core data structures (Interview Topic 5) — heap, deque
|
|
||||||
"""
|
|
||||||
|
|
||||||
import heapq
|
|
||||||
from collections import deque
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from .exceptions import ReassemblyError
|
|
||||||
from .models import ChunkResult
|
|
||||||
|
|
||||||
|
|
||||||
class ResultCollector:
    """
    Receives ChunkResults out of order, emits them in sequence order.

    Uses a min-heap keyed on sequence number. Only emits a chunk when
    all prior sequences have been accounted for.

    Args:
        total_chunks: Expected total number of chunks
    """

    def __init__(self, total_chunks: int):
        self.total_chunks = total_chunks
        self._heap: List[tuple[int, ChunkResult]] = []
        self._next_sequence = 0
        self._emitted: List[ChunkResult] = []
        self._seen_sequences: set[int] = set()
        # Sliding window of recent processing times for throughput stats
        self._recent_times: deque[float] = deque(maxlen=50)

    def add(self, result: ChunkResult) -> List[ChunkResult]:
        """
        Add a result and return any newly emittable results in order.

        Args:
            result: A ChunkResult (may arrive out of order)

        Returns:
            List of results that can now be emitted in sequence order
            (empty while earlier sequences are still outstanding)

        Raises:
            ReassemblyError: If a duplicate sequence is received
        """
        seq = result.sequence
        if seq in self._seen_sequences:
            raise ReassemblyError(
                f"Duplicate sequence number: {seq}"
            )
        self._seen_sequences.add(seq)

        # Track processing time for the throughput window
        if result.processing_time > 0:
            self._recent_times.append(result.processing_time)

        heapq.heappush(self._heap, (seq, result))

        # Drain every consecutive result starting at _next_sequence
        ready: List[ChunkResult] = []
        while self._heap and self._heap[0][0] == self._next_sequence:
            ready.append(heapq.heappop(self._heap)[1])
            self._next_sequence += 1
        self._emitted.extend(ready)
        return ready

    @property
    def is_complete(self) -> bool:
        """True if all expected chunks have been emitted in order."""
        return self._next_sequence == self.total_chunks

    @property
    def buffered_count(self) -> int:
        """Number of results waiting in the heap (arrived out of order)."""
        return len(self._heap)

    @property
    def emitted_count(self) -> int:
        """Number of results emitted in sequence order."""
        return len(self._emitted)

    @property
    def avg_processing_time(self) -> float:
        """Average processing time from recent results (sliding window)."""
        window = self._recent_times
        return sum(window) / len(window) if window else 0.0

    def get_ordered_results(self) -> List[ChunkResult]:
        """Get all emitted results in sequence order."""
        return list(self._emitted)
|
|
||||||
@@ -1,64 +0,0 @@
|
|||||||
"""
|
|
||||||
Chunker exception hierarchy.
|
|
||||||
|
|
||||||
Demonstrates: Managing exceptions and writing resilient code (Interview Topic 7).
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class PipelineError(Exception):
    """Base exception for all chunker pipeline errors."""


class ChunkError(PipelineError):
    """Errors related to chunk creation or validation."""


class ChunkReadError(ChunkError):
    """Failed to read chunk data from source file."""


class ChunkChecksumError(ChunkError):
    """Chunk data integrity validation failed."""

    def __init__(self, sequence: int, expected: str, actual: str):
        # Keep the raw values around for programmatic inspection.
        self.sequence = sequence
        self.expected = expected
        self.actual = actual
        message = (
            f"Chunk {sequence}: checksum mismatch "
            f"(expected={expected}, actual={actual})"
        )
        super().__init__(message)


class ProcessingError(PipelineError):
    """Errors during chunk processing by workers."""


class ProcessorTimeoutError(ProcessingError):
    """Processor exceeded allowed time for a chunk."""

    def __init__(self, sequence: int, timeout: float):
        self.sequence = sequence
        self.timeout = timeout
        super().__init__(f"Chunk {sequence}: processor timed out after {timeout}s")


class ProcessorFailureError(ProcessingError):
    """Processor failed to process a chunk after all retries."""

    def __init__(self, sequence: int, retries: int, original_error: Exception):
        self.sequence = sequence
        self.retries = retries
        self.original_error = original_error
        super().__init__(
            f"Chunk {sequence}: failed after {retries} retries — {original_error}"
        )


class ReassemblyError(PipelineError):
    """Errors during result collection and ordering."""
|
|
||||||
@@ -1,54 +0,0 @@
|
|||||||
"""
|
|
||||||
Internal data models for the chunker pipeline.
|
|
||||||
|
|
||||||
These are pipeline-internal dataclasses, not schema models.
|
|
||||||
Schema-level ChunkJob is in core/schema/models/jobs.py.
|
|
||||||
|
|
||||||
Demonstrates: Core data structures (Interview Topic 5).
|
|
||||||
"""
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class Chunk:
    """A time-based segment of the source media file."""

    sequence: int     # 0-based position of this chunk within the source
    start_time: float  # seconds
    end_time: float  # seconds
    source_path: str  # path to source file
    duration: float  # end_time - start_time
    checksum: str = ""  # computed after extraction


@dataclass
class ChunkResult:
    """Result of processing a single chunk."""

    sequence: int  # matches Chunk.sequence for ordered reassembly
    success: bool
    checksum_valid: bool = True
    processing_time: float = 0.0  # time spent processing this chunk
    error: Optional[str] = None  # error description when success is False
    retries: int = 0  # attempts consumed before this result
    worker_id: Optional[str] = None  # which worker produced the result
    output_file: Optional[str] = None  # extracted segment path, if any


@dataclass
class PipelineResult:
    """Aggregate result of the entire pipeline run."""

    total_chunks: int = 0
    processed: int = 0
    failed: int = 0
    retries: int = 0
    elapsed_time: float = 0.0
    throughput_mbps: float = 0.0
    worker_stats: Dict[str, Any] = field(default_factory=dict)
    errors: List[str] = field(default_factory=list)
    chunks_in_order: bool = True  # False if reassembly could not keep order
    output_dir: Optional[str] = None
    chunk_files: List[str] = field(default_factory=list)
|
|
||||||
@@ -1,279 +0,0 @@
|
|||||||
"""
|
|
||||||
Pipeline — orchestrates the entire chunker pipeline.
|
|
||||||
|
|
||||||
Wires: Chunker → ChunkQueue → WorkerPool → ResultCollector → PipelineResult
|
|
||||||
|
|
||||||
Demonstrates:
|
|
||||||
- Function parameters and defaults (Interview Topic 1) — configurable pipeline
|
|
||||||
- Concurrency (Interview Topic 2) — producer thread + worker pool
|
|
||||||
- OOP design (Interview Topic 4) — composition of pipeline components
|
|
||||||
- Exception handling (Interview Topic 7) — graceful error propagation
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Callable, Dict, Optional
|
|
||||||
|
|
||||||
from .chunker import Chunker
|
|
||||||
from .collector import ResultCollector
|
|
||||||
from .exceptions import PipelineError
|
|
||||||
from .models import PipelineResult
|
|
||||||
from .pool import WorkerPool
|
|
||||||
from .queue import ChunkQueue
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
|
||||||
"""
|
|
||||||
Orchestrates the chunk processing pipeline.
|
|
||||||
|
|
||||||
The pipeline runs in three stages:
|
|
||||||
1. Producer thread: Chunker probes file → pushes time-based chunks to ChunkQueue
|
|
||||||
2. Worker pool: N workers pull from queue → extract mp4 segments → emit results
|
|
||||||
3. Collector: ResultCollector reassembles results in sequence order
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source: Path to the source media file
|
|
||||||
chunk_duration: Duration of each chunk in seconds (default: 10.0)
|
|
||||||
num_workers: Number of concurrent worker threads (default: 4)
|
|
||||||
max_retries: Max retry attempts per chunk (default: 3)
|
|
||||||
processor_type: Processor to use — "ffmpeg", "checksum", "simulated_decode", "composite"
|
|
||||||
queue_size: Max chunks buffered in queue (default: 10)
|
|
||||||
event_callback: Optional callback for real-time events
|
|
||||||
output_dir: Directory for output chunk files (required for "ffmpeg" processor)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
source: str,
|
|
||||||
chunk_duration: float = 10.0,
|
|
||||||
num_workers: int = 4,
|
|
||||||
max_retries: int = 3,
|
|
||||||
processor_type: str = "checksum",
|
|
||||||
queue_size: int = 10,
|
|
||||||
event_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None,
|
|
||||||
output_dir: Optional[str] = None,
|
|
||||||
start_time: Optional[float] = None,
|
|
||||||
end_time: Optional[float] = None,
|
|
||||||
):
|
|
||||||
self.source = source
|
|
||||||
self.chunk_duration = chunk_duration
|
|
||||||
self.num_workers = num_workers
|
|
||||||
self.max_retries = max_retries
|
|
||||||
self.processor_type = processor_type
|
|
||||||
self.queue_size = queue_size
|
|
||||||
self.event_callback = event_callback
|
|
||||||
self.output_dir = output_dir
|
|
||||||
self.start_time = start_time
|
|
||||||
self.end_time = end_time
|
|
||||||
|
|
||||||
def _emit(self, event_type: str, data: Dict[str, Any]) -> None:
|
|
||||||
"""Emit an event if callback is registered."""
|
|
||||||
if self.event_callback:
|
|
||||||
self.event_callback(event_type, data)
|
|
||||||
|
|
||||||
def _produce_chunks(
|
|
||||||
self, chunker: Chunker, chunk_queue: ChunkQueue
|
|
||||||
) -> None:
|
|
||||||
"""Producer thread: probe file and enqueue time-based chunks."""
|
|
||||||
try:
|
|
||||||
for chunk in chunker.chunks():
|
|
||||||
chunk_queue.put(chunk, timeout=30.0)
|
|
||||||
self._emit("chunk_queued", {
|
|
||||||
"sequence": chunk.sequence,
|
|
||||||
"start_time": chunk.start_time,
|
|
||||||
"end_time": chunk.end_time,
|
|
||||||
"duration": chunk.duration,
|
|
||||||
"queue_size": chunk_queue.qsize(),
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Producer error: {e}")
|
|
||||||
self._emit("producer_error", {"error": str(e)})
|
|
||||||
finally:
|
|
||||||
chunk_queue.close()
|
|
||||||
|
|
||||||
def _monitor_progress(
|
|
||||||
self, start_time: float, file_size: int, stop_event: threading.Event
|
|
||||||
) -> None:
|
|
||||||
"""Monitor thread: emit pipeline_progress every 500ms."""
|
|
||||||
while not stop_event.is_set():
|
|
||||||
elapsed = time.monotonic() - start_time
|
|
||||||
mb = file_size / (1024 * 1024)
|
|
||||||
self._emit("pipeline_progress", {
|
|
||||||
"elapsed": round(elapsed, 2),
|
|
||||||
"throughput_mbps": round(mb / elapsed, 2) if elapsed > 0 else 0,
|
|
||||||
})
|
|
||||||
stop_event.wait(0.5)
|
|
||||||
|
|
||||||
def _write_manifest(
|
|
||||||
self, result: PipelineResult, source_duration: float
|
|
||||||
) -> None:
|
|
||||||
"""Write manifest.json to output_dir with segment metadata."""
|
|
||||||
if not self.output_dir:
|
|
||||||
return
|
|
||||||
|
|
||||||
manifest = {
|
|
||||||
"source": self.source,
|
|
||||||
"source_duration": source_duration,
|
|
||||||
"chunk_duration": self.chunk_duration,
|
|
||||||
"total_chunks": result.total_chunks,
|
|
||||||
"processed": result.processed,
|
|
||||||
"failed": result.failed,
|
|
||||||
"elapsed_time": result.elapsed_time,
|
|
||||||
"throughput_mbps": result.throughput_mbps,
|
|
||||||
"segments": [
|
|
||||||
{
|
|
||||||
"sequence": i,
|
|
||||||
"file": f"chunk_{i:04d}.mp4",
|
|
||||||
"start": i * self.chunk_duration,
|
|
||||||
"end": min(
|
|
||||||
(i + 1) * self.chunk_duration, source_duration
|
|
||||||
),
|
|
||||||
}
|
|
||||||
for i in range(result.total_chunks)
|
|
||||||
if i < result.total_chunks
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
manifest_path = Path(self.output_dir) / "manifest.json"
|
|
||||||
manifest_path.write_text(json.dumps(manifest, indent=2))
|
|
||||||
logger.info(f"Manifest written to {manifest_path}")
|
|
||||||
|
|
||||||
def run(self) -> PipelineResult:
|
|
||||||
"""
|
|
||||||
Execute the full pipeline.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
PipelineResult with aggregate stats
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
PipelineError: If the pipeline fails catastrophically
|
|
||||||
"""
|
|
||||||
start_time = time.monotonic()
|
|
||||||
self._emit("pipeline_start", {
|
|
||||||
"source": self.source,
|
|
||||||
"chunk_duration": self.chunk_duration,
|
|
||||||
"num_workers": self.num_workers,
|
|
||||||
"processor_type": self.processor_type,
|
|
||||||
})
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Stage 1: Set up chunker (probes file for duration)
|
|
||||||
chunker = Chunker(
|
|
||||||
self.source,
|
|
||||||
self.chunk_duration,
|
|
||||||
start_time=self.start_time,
|
|
||||||
end_time=self.end_time,
|
|
||||||
)
|
|
||||||
total_chunks = chunker.expected_chunks
|
|
||||||
|
|
||||||
if total_chunks == 0:
|
|
||||||
self._emit("pipeline_complete", {"total_chunks": 0})
|
|
||||||
return PipelineResult(chunks_in_order=True)
|
|
||||||
|
|
||||||
self._emit("pipeline_info", {
|
|
||||||
"file_size": chunker.file_size,
|
|
||||||
"source_duration": chunker.source_duration,
|
|
||||||
"total_chunks": total_chunks,
|
|
||||||
})
|
|
||||||
|
|
||||||
# Stage 2: Set up queue and worker pool
|
|
||||||
chunk_queue = ChunkQueue(maxsize=self.queue_size)
|
|
||||||
pool = WorkerPool(
|
|
||||||
num_workers=self.num_workers,
|
|
||||||
chunk_queue=chunk_queue,
|
|
||||||
processor_type=self.processor_type,
|
|
||||||
max_retries=self.max_retries,
|
|
||||||
event_callback=self.event_callback,
|
|
||||||
output_dir=self.output_dir,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Stage 3: Start workers, monitor, then produce chunks
|
|
||||||
pool.start()
|
|
||||||
|
|
||||||
monitor_stop = threading.Event()
|
|
||||||
monitor = threading.Thread(
|
|
||||||
target=self._monitor_progress,
|
|
||||||
args=(start_time, chunker.file_size, monitor_stop),
|
|
||||||
name="progress-monitor",
|
|
||||||
daemon=True,
|
|
||||||
)
|
|
||||||
monitor.start()
|
|
||||||
|
|
||||||
producer = threading.Thread(
|
|
||||||
target=self._produce_chunks,
|
|
||||||
args=(chunker, chunk_queue),
|
|
||||||
name="chunk-producer",
|
|
||||||
daemon=True,
|
|
||||||
)
|
|
||||||
producer.start()
|
|
||||||
|
|
||||||
# Stage 4: Wait for all workers to finish
|
|
||||||
all_results = pool.wait()
|
|
||||||
producer.join(timeout=5.0)
|
|
||||||
|
|
||||||
# Stop monitor
|
|
||||||
monitor_stop.set()
|
|
||||||
monitor.join(timeout=2.0)
|
|
||||||
|
|
||||||
# Stage 5: Collect results in order
|
|
||||||
collector = ResultCollector(total_chunks)
|
|
||||||
for r in all_results:
|
|
||||||
collector.add(r)
|
|
||||||
self._emit("chunk_collected", {
|
|
||||||
"sequence": r.sequence,
|
|
||||||
"success": r.success,
|
|
||||||
"buffered": collector.buffered_count,
|
|
||||||
"emitted": collector.emitted_count,
|
|
||||||
})
|
|
||||||
|
|
||||||
# Build result
|
|
||||||
elapsed = time.monotonic() - start_time
|
|
||||||
file_size_mb = chunker.file_size / (1024 * 1024)
|
|
||||||
throughput = file_size_mb / elapsed if elapsed > 0 else 0.0
|
|
||||||
|
|
||||||
failed_results = [r for r in all_results if not r.success]
|
|
||||||
total_retries = sum(r.retries for r in all_results)
|
|
||||||
chunk_files = [
|
|
||||||
r.output_file for r in all_results
|
|
||||||
if r.success and r.output_file
|
|
||||||
]
|
|
||||||
|
|
||||||
result = PipelineResult(
|
|
||||||
total_chunks=total_chunks,
|
|
||||||
processed=len(all_results),
|
|
||||||
failed=len(failed_results),
|
|
||||||
retries=total_retries,
|
|
||||||
elapsed_time=elapsed,
|
|
||||||
throughput_mbps=throughput,
|
|
||||||
worker_stats=pool.get_worker_stats(),
|
|
||||||
errors=[r.error for r in failed_results if r.error],
|
|
||||||
chunks_in_order=collector.is_complete,
|
|
||||||
output_dir=self.output_dir,
|
|
||||||
chunk_files=chunk_files,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Write manifest if output_dir is set
|
|
||||||
self._write_manifest(result, chunker.source_duration)
|
|
||||||
|
|
||||||
pool.shutdown()
|
|
||||||
|
|
||||||
self._emit("pipeline_complete", {
|
|
||||||
"total_chunks": result.total_chunks,
|
|
||||||
"processed": result.processed,
|
|
||||||
"failed": result.failed,
|
|
||||||
"elapsed": result.elapsed_time,
|
|
||||||
"throughput_mbps": result.throughput_mbps,
|
|
||||||
})
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except PipelineError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
self._emit("pipeline_error", {"error": str(e)})
|
|
||||||
raise PipelineError(f"Pipeline failed: {e}") from e
|
|
||||||
@@ -1,125 +0,0 @@
|
|||||||
"""
|
|
||||||
WorkerPool — manages N worker threads via ThreadPoolExecutor.
|
|
||||||
|
|
||||||
Demonstrates: Python concurrency — threading (Interview Topic 2).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import threading
|
|
||||||
from concurrent.futures import Future, ThreadPoolExecutor
|
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
|
||||||
|
|
||||||
from .models import ChunkResult
|
|
||||||
from .processor import (
|
|
||||||
ChecksumProcessor,
|
|
||||||
CompositeProcessor,
|
|
||||||
FFmpegExtractProcessor,
|
|
||||||
Processor,
|
|
||||||
SimulatedDecodeProcessor,
|
|
||||||
)
|
|
||||||
from .queue import ChunkQueue
|
|
||||||
from .worker import Worker
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def create_processor(
|
|
||||||
processor_type: str = "checksum",
|
|
||||||
output_dir: Optional[str] = None,
|
|
||||||
) -> Processor:
|
|
||||||
"""Factory for processor instances."""
|
|
||||||
if processor_type == "ffmpeg":
|
|
||||||
if not output_dir:
|
|
||||||
raise ValueError("output_dir required for ffmpeg processor")
|
|
||||||
return FFmpegExtractProcessor(output_dir=output_dir)
|
|
||||||
elif processor_type == "checksum":
|
|
||||||
return ChecksumProcessor()
|
|
||||||
elif processor_type == "simulated_decode":
|
|
||||||
return SimulatedDecodeProcessor()
|
|
||||||
elif processor_type == "composite":
|
|
||||||
return CompositeProcessor([
|
|
||||||
ChecksumProcessor(),
|
|
||||||
SimulatedDecodeProcessor(ms_per_second=50.0),
|
|
||||||
])
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown processor type: {processor_type}")
|
|
||||||
|
|
||||||
|
|
||||||
class WorkerPool:
|
|
||||||
"""
|
|
||||||
Manages N worker threads that process chunks concurrently.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
num_workers: Number of concurrent worker threads (default: 4)
|
|
||||||
chunk_queue: Shared queue to pull chunks from
|
|
||||||
processor_type: Type of processor for each worker (default: "checksum")
|
|
||||||
max_retries: Max retry attempts per chunk (default: 3)
|
|
||||||
event_callback: Optional callback for real-time events
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
num_workers: int = 4,
|
|
||||||
chunk_queue: Optional[ChunkQueue] = None,
|
|
||||||
processor_type: str = "checksum",
|
|
||||||
max_retries: int = 3,
|
|
||||||
event_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None,
|
|
||||||
output_dir: Optional[str] = None,
|
|
||||||
):
|
|
||||||
self.num_workers = num_workers
|
|
||||||
self.chunk_queue = chunk_queue or ChunkQueue()
|
|
||||||
self.processor_type = processor_type
|
|
||||||
self.max_retries = max_retries
|
|
||||||
self.event_callback = event_callback
|
|
||||||
self.output_dir = output_dir
|
|
||||||
self.shutdown_event = threading.Event()
|
|
||||||
self._executor: Optional[ThreadPoolExecutor] = None
|
|
||||||
self._futures: List[Future] = []
|
|
||||||
self._workers: List[Worker] = []
|
|
||||||
|
|
||||||
def start(self) -> None:
|
|
||||||
"""Start all worker threads."""
|
|
||||||
self._executor = ThreadPoolExecutor(
|
|
||||||
max_workers=self.num_workers,
|
|
||||||
thread_name_prefix="chunk-worker",
|
|
||||||
)
|
|
||||||
|
|
||||||
for i in range(self.num_workers):
|
|
||||||
worker = Worker(
|
|
||||||
worker_id=f"worker-{i}",
|
|
||||||
chunk_queue=self.chunk_queue,
|
|
||||||
processor=create_processor(self.processor_type, output_dir=self.output_dir),
|
|
||||||
max_retries=self.max_retries,
|
|
||||||
event_callback=self.event_callback,
|
|
||||||
)
|
|
||||||
self._workers.append(worker)
|
|
||||||
future = self._executor.submit(worker.run)
|
|
||||||
self._futures.append(future)
|
|
||||||
|
|
||||||
logger.info(f"WorkerPool started with {self.num_workers} workers")
|
|
||||||
|
|
||||||
def wait(self) -> List[ChunkResult]:
|
|
||||||
"""Wait for all workers to finish and collect results."""
|
|
||||||
all_results = []
|
|
||||||
for future in self._futures:
|
|
||||||
results = future.result()
|
|
||||||
all_results.extend(results)
|
|
||||||
return all_results
|
|
||||||
|
|
||||||
def shutdown(self) -> None:
|
|
||||||
"""Signal shutdown and cleanup."""
|
|
||||||
self.shutdown_event.set()
|
|
||||||
self.chunk_queue.close()
|
|
||||||
if self._executor:
|
|
||||||
self._executor.shutdown(wait=True)
|
|
||||||
|
|
||||||
def get_worker_stats(self) -> Dict[str, Any]:
|
|
||||||
"""Get per-worker statistics."""
|
|
||||||
return {
|
|
||||||
w.worker_id: {
|
|
||||||
"processed": w.processed_count,
|
|
||||||
"errors": w.error_count,
|
|
||||||
"retries": w.retry_count,
|
|
||||||
}
|
|
||||||
for w in self._workers
|
|
||||||
}
|
|
||||||
@@ -1,173 +0,0 @@
|
|||||||
"""
|
|
||||||
Processor ABC and concrete implementations.
|
|
||||||
|
|
||||||
Demonstrates: OOP design principles — ABC, inheritance, composition (Interview Topic 4).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import time
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from .exceptions import ChunkChecksumError
|
|
||||||
from .models import Chunk, ChunkResult
|
|
||||||
|
|
||||||
|
|
||||||
class Processor(ABC):
|
|
||||||
"""
|
|
||||||
Abstract base class for chunk processors.
|
|
||||||
|
|
||||||
Each processor defines how a single chunk is processed.
|
|
||||||
The Worker calls processor.process(chunk) and handles retries.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def process(self, chunk: Chunk) -> ChunkResult:
|
|
||||||
"""Process a single chunk and return the result."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class FFmpegExtractProcessor(Processor):
|
|
||||||
"""
|
|
||||||
Extracts a time segment from the source file using FFmpeg stream copy.
|
|
||||||
|
|
||||||
Produces a playable mp4 file per chunk — no re-encoding.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
output_dir: Directory to write chunk mp4 files
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, output_dir: str):
|
|
||||||
self.output_dir = output_dir
|
|
||||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
def process(self, chunk: Chunk) -> ChunkResult:
|
|
||||||
from core.ffmpeg.transcode import TranscodeConfig, transcode
|
|
||||||
|
|
||||||
start = time.monotonic()
|
|
||||||
|
|
||||||
output_file = str(
|
|
||||||
Path(self.output_dir) / f"chunk_{chunk.sequence:04d}.mp4"
|
|
||||||
)
|
|
||||||
|
|
||||||
config = TranscodeConfig(
|
|
||||||
input_path=chunk.source_path,
|
|
||||||
output_path=output_file,
|
|
||||||
video_codec="copy",
|
|
||||||
audio_codec="copy",
|
|
||||||
trim_start=chunk.start_time,
|
|
||||||
trim_end=chunk.end_time,
|
|
||||||
)
|
|
||||||
|
|
||||||
transcode(config)
|
|
||||||
|
|
||||||
# Compute checksum of output file
|
|
||||||
md5 = hashlib.md5()
|
|
||||||
with open(output_file, "rb") as f:
|
|
||||||
for block in iter(lambda: f.read(8192), b""):
|
|
||||||
md5.update(block)
|
|
||||||
checksum = md5.hexdigest()
|
|
||||||
|
|
||||||
elapsed = time.monotonic() - start
|
|
||||||
|
|
||||||
return ChunkResult(
|
|
||||||
sequence=chunk.sequence,
|
|
||||||
success=True,
|
|
||||||
checksum_valid=True,
|
|
||||||
processing_time=elapsed,
|
|
||||||
output_file=output_file,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ChecksumProcessor(Processor):
|
|
||||||
"""
|
|
||||||
Validates chunk metadata consistency.
|
|
||||||
|
|
||||||
For time-based chunks, verifies the time range is valid.
|
|
||||||
Raises ChunkChecksumError on invalid ranges.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def process(self, chunk: Chunk) -> ChunkResult:
|
|
||||||
start = time.monotonic()
|
|
||||||
|
|
||||||
valid = chunk.duration > 0 and chunk.end_time > chunk.start_time
|
|
||||||
|
|
||||||
if not valid:
|
|
||||||
raise ChunkChecksumError(
|
|
||||||
sequence=chunk.sequence,
|
|
||||||
expected="valid time range",
|
|
||||||
actual=f"{chunk.start_time}-{chunk.end_time}",
|
|
||||||
)
|
|
||||||
|
|
||||||
elapsed = time.monotonic() - start
|
|
||||||
|
|
||||||
return ChunkResult(
|
|
||||||
sequence=chunk.sequence,
|
|
||||||
success=True,
|
|
||||||
checksum_valid=True,
|
|
||||||
processing_time=elapsed,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SimulatedDecodeProcessor(Processor):
|
|
||||||
"""
|
|
||||||
Simulates decode work by sleeping proportional to chunk duration.
|
|
||||||
|
|
||||||
Useful for demonstrating concurrency behavior without real FFmpeg.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ms_per_second: Milliseconds of simulated work per second of chunk duration (default: 100)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, ms_per_second: float = 100.0):
|
|
||||||
self.ms_per_second = ms_per_second
|
|
||||||
|
|
||||||
def process(self, chunk: Chunk) -> ChunkResult:
|
|
||||||
start = time.monotonic()
|
|
||||||
|
|
||||||
sleep_time = (self.ms_per_second * chunk.duration) / 1000.0
|
|
||||||
time.sleep(sleep_time)
|
|
||||||
|
|
||||||
elapsed = time.monotonic() - start
|
|
||||||
|
|
||||||
return ChunkResult(
|
|
||||||
sequence=chunk.sequence,
|
|
||||||
success=True,
|
|
||||||
checksum_valid=True,
|
|
||||||
processing_time=elapsed,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CompositeProcessor(Processor):
|
|
||||||
"""
|
|
||||||
Chains multiple processors — runs each in sequence on the same chunk.
|
|
||||||
|
|
||||||
Demonstrates OOP composition pattern.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
processors: List of processors to chain
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, processors: List[Processor]):
|
|
||||||
if not processors:
|
|
||||||
raise ValueError("CompositeProcessor requires at least one processor")
|
|
||||||
self.processors = processors
|
|
||||||
|
|
||||||
def process(self, chunk: Chunk) -> ChunkResult:
|
|
||||||
start = time.monotonic()
|
|
||||||
last_result = None
|
|
||||||
|
|
||||||
for proc in self.processors:
|
|
||||||
last_result = proc.process(chunk)
|
|
||||||
if not last_result.success:
|
|
||||||
return last_result
|
|
||||||
|
|
||||||
elapsed = time.monotonic() - start
|
|
||||||
|
|
||||||
return ChunkResult(
|
|
||||||
sequence=chunk.sequence,
|
|
||||||
success=True,
|
|
||||||
checksum_valid=last_result.checksum_valid if last_result else True,
|
|
||||||
processing_time=elapsed,
|
|
||||||
)
|
|
||||||
@@ -1,76 +0,0 @@
|
|||||||
"""
|
|
||||||
ChunkQueue — bounded, thread-safe queue with sentinel-based shutdown.
|
|
||||||
|
|
||||||
Demonstrates: Core data structures — queue.Queue (Interview Topic 5).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import queue
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from .models import Chunk
|
|
||||||
|
|
||||||
# Sentinel value to signal workers to stop
|
|
||||||
_SENTINEL = object()
|
|
||||||
|
|
||||||
|
|
||||||
class ChunkQueue:
|
|
||||||
"""
|
|
||||||
Thread-safe bounded queue for chunks.
|
|
||||||
|
|
||||||
Provides backpressure: producers block when the queue is full,
|
|
||||||
preventing unbounded memory usage.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
maxsize: Maximum number of chunks in the queue (default: 10)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, maxsize: int = 10):
|
|
||||||
self._queue: queue.Queue = queue.Queue(maxsize=maxsize)
|
|
||||||
self._closed = False
|
|
||||||
self.maxsize = maxsize
|
|
||||||
|
|
||||||
def put(self, chunk: Chunk, timeout: Optional[float] = None) -> None:
|
|
||||||
"""
|
|
||||||
Add a chunk to the queue. Blocks if full (backpressure).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chunk: The chunk to enqueue
|
|
||||||
timeout: Max seconds to wait (None = block forever)
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
queue.Full: If timeout expires while queue is full
|
|
||||||
"""
|
|
||||||
self._queue.put(chunk, timeout=timeout)
|
|
||||||
|
|
||||||
def get(self, timeout: Optional[float] = None) -> Optional[Chunk]:
|
|
||||||
"""
|
|
||||||
Get next chunk from queue. Returns None if queue is closed.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
timeout: Max seconds to wait (None = block forever)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Chunk or None (if sentinel received, meaning queue is closed)
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
queue.Empty: If timeout expires while queue is empty
|
|
||||||
"""
|
|
||||||
item = self._queue.get(timeout=timeout)
|
|
||||||
if item is _SENTINEL:
|
|
||||||
# Re-put sentinel so other workers also see it
|
|
||||||
self._queue.put(_SENTINEL)
|
|
||||||
return None
|
|
||||||
return item
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
"""Signal all consumers to stop by inserting a sentinel."""
|
|
||||||
self._closed = True
|
|
||||||
self._queue.put(_SENTINEL)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_closed(self) -> bool:
|
|
||||||
return self._closed
|
|
||||||
|
|
||||||
def qsize(self) -> int:
|
|
||||||
"""Current number of items in the queue (approximate)."""
|
|
||||||
return self._queue.qsize()
|
|
||||||
@@ -1,143 +0,0 @@
|
|||||||
"""
|
|
||||||
Worker — pulls chunks from queue, processes with retry logic.
|
|
||||||
|
|
||||||
Demonstrates:
|
|
||||||
- Exception handling and resilient code (Interview Topic 7)
|
|
||||||
- Concurrency (Interview Topic 2) — workers run in thread pool
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import queue
|
|
||||||
import time
|
|
||||||
from typing import Any, Callable, Dict, Optional
|
|
||||||
|
|
||||||
from .exceptions import ProcessorFailureError
|
|
||||||
from .models import Chunk, ChunkResult
|
|
||||||
from .processor import Processor
|
|
||||||
from .queue import ChunkQueue
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class Worker:
|
|
||||||
"""
|
|
||||||
Processes chunks from a queue with retry and exponential backoff.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
worker_id: Identifier for this worker (e.g. "worker-0")
|
|
||||||
chunk_queue: Source queue to pull chunks from
|
|
||||||
processor: Processor instance to use
|
|
||||||
max_retries: Maximum retry attempts per chunk (default: 3)
|
|
||||||
event_callback: Optional callback for real-time status updates
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
worker_id: str,
|
|
||||||
chunk_queue: ChunkQueue,
|
|
||||||
processor: Processor,
|
|
||||||
max_retries: int = 3,
|
|
||||||
event_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None,
|
|
||||||
):
|
|
||||||
self.worker_id = worker_id
|
|
||||||
self.chunk_queue = chunk_queue
|
|
||||||
self.processor = processor
|
|
||||||
self.max_retries = max_retries
|
|
||||||
self.event_callback = event_callback
|
|
||||||
self.processed_count = 0
|
|
||||||
self.error_count = 0
|
|
||||||
self.retry_count = 0
|
|
||||||
|
|
||||||
def _emit(self, event_type: str, data: Dict[str, Any]) -> None:
|
|
||||||
"""Emit an event if callback is registered."""
|
|
||||||
if self.event_callback:
|
|
||||||
self.event_callback(event_type, {"worker_id": self.worker_id, **data})
|
|
||||||
|
|
||||||
def _process_with_retry(self, chunk: Chunk) -> ChunkResult:
|
|
||||||
"""
|
|
||||||
Process a chunk with exponential backoff retry.
|
|
||||||
|
|
||||||
Retry delays: 0.1s, 0.2s, 0.4s, ... (doubles each attempt)
|
|
||||||
"""
|
|
||||||
last_error = None
|
|
||||||
|
|
||||||
for attempt in range(self.max_retries + 1):
|
|
||||||
try:
|
|
||||||
if attempt > 0:
|
|
||||||
backoff = 0.1 * (2 ** (attempt - 1))
|
|
||||||
self._emit("chunk_retry", {
|
|
||||||
"sequence": chunk.sequence,
|
|
||||||
"attempt": attempt,
|
|
||||||
"backoff": backoff,
|
|
||||||
})
|
|
||||||
time.sleep(backoff)
|
|
||||||
self.retry_count += 1
|
|
||||||
|
|
||||||
result = self.processor.process(chunk)
|
|
||||||
result.retries = attempt
|
|
||||||
result.worker_id = self.worker_id
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
last_error = e
|
|
||||||
logger.warning(
|
|
||||||
f"{self.worker_id}: chunk {chunk.sequence} "
|
|
||||||
f"attempt {attempt + 1}/{self.max_retries + 1} failed: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# All retries exhausted
|
|
||||||
self.error_count += 1
|
|
||||||
self._emit("chunk_error", {
|
|
||||||
"sequence": chunk.sequence,
|
|
||||||
"error": str(last_error),
|
|
||||||
"retries": self.max_retries,
|
|
||||||
})
|
|
||||||
|
|
||||||
return ChunkResult(
|
|
||||||
sequence=chunk.sequence,
|
|
||||||
success=False,
|
|
||||||
processing_time=0.0,
|
|
||||||
error=str(last_error),
|
|
||||||
retries=self.max_retries,
|
|
||||||
worker_id=self.worker_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
def run(self) -> list[ChunkResult]:
|
|
||||||
"""
|
|
||||||
Main worker loop — pull chunks and process until queue is closed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of ChunkResults processed by this worker
|
|
||||||
"""
|
|
||||||
results = []
|
|
||||||
self._emit("worker_status", {"state": "idle"})
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
chunk = self.chunk_queue.get(timeout=1.0)
|
|
||||||
except queue.Empty:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if chunk is None: # Sentinel received
|
|
||||||
break
|
|
||||||
|
|
||||||
self._emit("chunk_processing", {
|
|
||||||
"sequence": chunk.sequence,
|
|
||||||
"state": "processing",
|
|
||||||
"queue_size": self.chunk_queue.qsize(),
|
|
||||||
})
|
|
||||||
|
|
||||||
result = self._process_with_retry(chunk)
|
|
||||||
results.append(result)
|
|
||||||
self.processed_count += 1
|
|
||||||
|
|
||||||
self._emit("chunk_done", {
|
|
||||||
"sequence": chunk.sequence,
|
|
||||||
"success": result.success,
|
|
||||||
"processing_time": result.processing_time,
|
|
||||||
"retries": result.retries,
|
|
||||||
"queue_size": self.chunk_queue.qsize(),
|
|
||||||
})
|
|
||||||
|
|
||||||
self._emit("worker_status", {"state": "stopped"})
|
|
||||||
return results
|
|
||||||
@@ -13,7 +13,7 @@ Basic CRUD (create, get, update, delete) goes directly through the session:
|
|||||||
|
|
||||||
from .connection import get_session, create_tables
|
from .connection import get_session, create_tables
|
||||||
|
|
||||||
from .tables import MediaAsset, Job, Timeline, Checkpoint, Brand
|
from .models import MediaAsset, Job, Timeline, Checkpoint, Brand
|
||||||
|
|
||||||
from .assets import list_assets, get_asset_filenames
|
from .assets import list_assets, get_asset_filenames
|
||||||
from .job import list_jobs
|
from .job import list_jobs
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from sqlmodel import Session, select
|
from sqlmodel import Session, select
|
||||||
|
|
||||||
from .tables import MediaAsset
|
from .models import MediaAsset
|
||||||
|
|
||||||
|
|
||||||
def list_assets(session: Session, status: Optional[str] = None, search: Optional[str] = None) -> list[MediaAsset]:
|
def list_assets(session: Session, status: Optional[str] = None, search: Optional[str] = None) -> list[MediaAsset]:
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from sqlmodel import Session, select
|
from sqlmodel import Session, select
|
||||||
|
|
||||||
from .tables import Brand
|
from .models import Brand
|
||||||
|
|
||||||
|
|
||||||
def get_or_create_brand(session: Session, canonical_name: str,
|
def get_or_create_brand(session: Session, canonical_name: str,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from sqlmodel import Session, select
|
from sqlmodel import Session, select
|
||||||
|
|
||||||
from .tables import Checkpoint
|
from .models import Checkpoint
|
||||||
|
|
||||||
|
|
||||||
def get_latest_checkpoint(session: Session, timeline_id: UUID, parent_id: UUID | None = None) -> Checkpoint | None:
|
def get_latest_checkpoint(session: Session, timeline_id: UUID, parent_id: UUID | None = None) -> Checkpoint | None:
|
||||||
|
|||||||
@@ -30,5 +30,5 @@ def get_session() -> Session:
|
|||||||
def create_tables():
|
def create_tables():
|
||||||
"""Create all SQLModel tables."""
|
"""Create all SQLModel tables."""
|
||||||
from sqlmodel import SQLModel
|
from sqlmodel import SQLModel
|
||||||
from . import tables # noqa — registers all table classes
|
from . import models # noqa — registers all table classes
|
||||||
SQLModel.metadata.create_all(get_engine())
|
SQLModel.metadata.create_all(get_engine())
|
||||||
|
|||||||
142
core/db/fixtures/soccer_broadcast.json
Normal file
142
core/db/fixtures/soccer_broadcast.json
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
{
|
||||||
|
"name": "soccer_broadcast",
|
||||||
|
"pipeline": {
|
||||||
|
"name": "soccer_broadcast",
|
||||||
|
"profile_name": "soccer_broadcast",
|
||||||
|
"stages": [
|
||||||
|
{
|
||||||
|
"name": "extract_frames",
|
||||||
|
"branch": "trunk"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "filter_scenes",
|
||||||
|
"branch": "trunk"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "field_segmentation",
|
||||||
|
"branch": "trunk"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "detect_edges",
|
||||||
|
"branch": "hoarding"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "detect_objects",
|
||||||
|
"branch": "objects"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "preprocess"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "run_ocr"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "match_brands"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "escalate_vlm"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "escalate_cloud"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "compile_report"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"edges": [
|
||||||
|
{
|
||||||
|
"source": "extract_frames",
|
||||||
|
"target": "filter_scenes"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source": "filter_scenes",
|
||||||
|
"target": "field_segmentation"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source": "field_segmentation",
|
||||||
|
"target": "detect_edges"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source": "field_segmentation",
|
||||||
|
"target": "detect_objects"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source": "detect_edges",
|
||||||
|
"target": "preprocess"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source": "detect_objects",
|
||||||
|
"target": "preprocess"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source": "preprocess",
|
||||||
|
"target": "run_ocr"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source": "run_ocr",
|
||||||
|
"target": "match_brands"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source": "match_brands",
|
||||||
|
"target": "escalate_vlm"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source": "escalate_vlm",
|
||||||
|
"target": "escalate_cloud"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source": "escalate_cloud",
|
||||||
|
"target": "compile_report"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"configs": {
|
||||||
|
"extract_frames": {
|
||||||
|
"fps": 2.0,
|
||||||
|
"max_frames": 500
|
||||||
|
},
|
||||||
|
"filter_scenes": {
|
||||||
|
"hamming_threshold": 8,
|
||||||
|
"enabled": true
|
||||||
|
},
|
||||||
|
"field_segmentation": {
|
||||||
|
"enabled": true,
|
||||||
|
"hue_low": 30,
|
||||||
|
"hue_high": 85,
|
||||||
|
"sat_low": 30,
|
||||||
|
"sat_high": 255,
|
||||||
|
"val_low": 30,
|
||||||
|
"val_high": 255,
|
||||||
|
"morph_kernel": 15,
|
||||||
|
"min_area_ratio": 0.05
|
||||||
|
},
|
||||||
|
"detect_edges": {
|
||||||
|
"enabled": true,
|
||||||
|
"edge_canny_low": 50,
|
||||||
|
"edge_canny_high": 150,
|
||||||
|
"edge_hough_threshold": 80,
|
||||||
|
"edge_hough_min_length": 100,
|
||||||
|
"edge_hough_max_gap": 10,
|
||||||
|
"edge_pair_max_distance": 200,
|
||||||
|
"edge_pair_min_distance": 15
|
||||||
|
},
|
||||||
|
"detect_objects": {
|
||||||
|
"model_name": "yolov8n.pt",
|
||||||
|
"confidence_threshold": 0.3,
|
||||||
|
"target_classes": []
|
||||||
|
},
|
||||||
|
"run_ocr": {
|
||||||
|
"languages": [
|
||||||
|
"en",
|
||||||
|
"es"
|
||||||
|
],
|
||||||
|
"min_confidence": 0.5
|
||||||
|
},
|
||||||
|
"match_brands": {
|
||||||
|
"fuzzy_threshold": 75
|
||||||
|
},
|
||||||
|
"escalate_vlm": {
|
||||||
|
"vlm_prompt_template": "Identify the brand or sponsor visible in this cropped region from a soccer broadcast.{hint}{text} Respond with: brand, confidence (0-1), reasoning."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -7,7 +7,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from sqlmodel import Session, select
|
from sqlmodel import Session, select
|
||||||
|
|
||||||
from .tables import Job
|
from .models import Job
|
||||||
|
|
||||||
|
|
||||||
def list_jobs(session: Session, parent_id: Optional[UUID] = None, status: Optional[str] = None) -> list[Job]:
|
def list_jobs(session: Session, parent_id: Optional[UUID] = None, status: Optional[str] = None) -> list[Job]:
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ class SourceType(str, Enum):
|
|||||||
|
|
||||||
class MediaAsset(SQLModel, table=True):
|
class MediaAsset(SQLModel, table=True):
|
||||||
"""A video/audio file registered in the system."""
|
"""A video/audio file registered in the system."""
|
||||||
__tablename__ = "media_assets"
|
__tablename__ = "media_asset"
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||||
filename: str
|
filename: str
|
||||||
@@ -67,7 +67,7 @@ class MediaAsset(SQLModel, table=True):
|
|||||||
|
|
||||||
class TranscodePreset(SQLModel, table=True):
|
class TranscodePreset(SQLModel, table=True):
|
||||||
"""A reusable transcoding configuration (like Handbrake presets)."""
|
"""A reusable transcoding configuration (like Handbrake presets)."""
|
||||||
__tablename__ = "transcode_presets"
|
__tablename__ = "transcode_preset"
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||||
name: str
|
name: str
|
||||||
@@ -90,12 +90,13 @@ class TranscodePreset(SQLModel, table=True):
|
|||||||
|
|
||||||
class Job(SQLModel, table=True):
|
class Job(SQLModel, table=True):
|
||||||
"""A pipeline job."""
|
"""A pipeline job."""
|
||||||
__tablename__ = "jobs"
|
__tablename__ = "job"
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||||
source_asset_id: UUID = Field(index=True)
|
source_asset_id: UUID = Field(index=True)
|
||||||
video_path: str
|
video_path: str
|
||||||
profile_name: str = "soccer_broadcast"
|
profile_name: str = "soccer_broadcast"
|
||||||
|
timeline_id: Optional[UUID] = None
|
||||||
parent_id: Optional[UUID] = None
|
parent_id: Optional[UUID] = None
|
||||||
run_type: RunType = "initial"
|
run_type: RunType = "initial"
|
||||||
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||||
@@ -107,7 +108,6 @@ class Job(SQLModel, table=True):
|
|||||||
brands_found: int = 0
|
brands_found: int = 0
|
||||||
cloud_llm_calls: int = 0
|
cloud_llm_calls: int = 0
|
||||||
estimated_cost_usd: float = 0.0
|
estimated_cost_usd: float = 0.0
|
||||||
celery_task_id: Optional[str] = None
|
|
||||||
priority: int = 0
|
priority: int = 0
|
||||||
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
||||||
started_at: Optional[datetime] = None
|
started_at: Optional[datetime] = None
|
||||||
@@ -115,7 +115,7 @@ class Job(SQLModel, table=True):
|
|||||||
|
|
||||||
class Timeline(SQLModel, table=True):
|
class Timeline(SQLModel, table=True):
|
||||||
"""The frame sequence from a source video."""
|
"""The frame sequence from a source video."""
|
||||||
__tablename__ = "timelines"
|
__tablename__ = "timeline"
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||||
source_asset_id: Optional[UUID] = Field(default=None, index=True)
|
source_asset_id: Optional[UUID] = Field(default=None, index=True)
|
||||||
@@ -129,10 +129,11 @@ class Timeline(SQLModel, table=True):
|
|||||||
|
|
||||||
class Checkpoint(SQLModel, table=True):
|
class Checkpoint(SQLModel, table=True):
|
||||||
"""A snapshot of pipeline state on a timeline."""
|
"""A snapshot of pipeline state on a timeline."""
|
||||||
__tablename__ = "checkpoints"
|
__tablename__ = "checkpoint"
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||||
timeline_id: UUID
|
timeline_id: UUID
|
||||||
|
job_id: Optional[UUID] = Field(default=None, index=True)
|
||||||
parent_id: Optional[UUID] = None
|
parent_id: Optional[UUID] = None
|
||||||
stage_outputs: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
stage_outputs: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||||
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||||
@@ -143,7 +144,7 @@ class Checkpoint(SQLModel, table=True):
|
|||||||
|
|
||||||
class Brand(SQLModel, table=True):
|
class Brand(SQLModel, table=True):
|
||||||
"""A brand discovered or registered in the system."""
|
"""A brand discovered or registered in the system."""
|
||||||
__tablename__ = "brands"
|
__tablename__ = "brand"
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||||
canonical_name: str = Field(index=True)
|
canonical_name: str = Field(index=True)
|
||||||
@@ -154,3 +155,12 @@ class Brand(SQLModel, table=True):
|
|||||||
total_airings: int = 0
|
total_airings: int = 0
|
||||||
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
||||||
updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
||||||
|
|
||||||
|
class Profile(SQLModel, table=True):
|
||||||
|
"""A content type profile."""
|
||||||
|
__tablename__ = "profile"
|
||||||
|
|
||||||
|
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||||
|
name: str
|
||||||
|
pipeline: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||||
|
configs: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
||||||
|
|||||||
43
core/db/seed.py
Normal file
43
core/db/seed.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
"""
|
||||||
|
Seed data — insert initial profile rows if they don't exist.
|
||||||
|
|
||||||
|
Called on startup after create_tables().
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
SEED_DIR = Path(__file__).parent / "fixtures"
|
||||||
|
|
||||||
|
|
||||||
|
def seed_profiles():
|
||||||
|
"""Insert seed profiles from JSON fixtures if not already present."""
|
||||||
|
from .connection import get_session
|
||||||
|
from .models import Profile
|
||||||
|
|
||||||
|
fixtures = list(SEED_DIR.glob("*.json"))
|
||||||
|
if not fixtures:
|
||||||
|
return
|
||||||
|
|
||||||
|
with get_session() as session:
|
||||||
|
for f in fixtures:
|
||||||
|
data = json.loads(f.read_text())
|
||||||
|
name = data["name"]
|
||||||
|
|
||||||
|
existing = session.query(Profile).filter(Profile.name == name).first()
|
||||||
|
if existing:
|
||||||
|
logger.debug("Profile %s already exists, skipping seed", name)
|
||||||
|
continue
|
||||||
|
|
||||||
|
profile = Profile(
|
||||||
|
name=name,
|
||||||
|
pipeline=data.get("pipeline", {}),
|
||||||
|
configs=data.get("configs", {}),
|
||||||
|
)
|
||||||
|
session.add(profile)
|
||||||
|
logger.info("Seeded profile: %s", name)
|
||||||
|
|
||||||
|
session.commit()
|
||||||
@@ -1,96 +0,0 @@
|
|||||||
"""
|
|
||||||
SQLModel table definitions.
|
|
||||||
|
|
||||||
Generated by modelgen from core/schema/models/. Do not edit directly.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
from uuid import UUID, uuid4
|
|
||||||
|
|
||||||
from sqlalchemy import JSON
|
|
||||||
from sqlmodel import Column, Field, SQLModel
|
|
||||||
|
|
||||||
|
|
||||||
class MediaAsset(SQLModel, table=True):
|
|
||||||
__tablename__ = "media_asset"
|
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
|
||||||
filename: str
|
|
||||||
path: str
|
|
||||||
status: str = "pending"
|
|
||||||
size_bytes: int = 0
|
|
||||||
duration_seconds: float = 0.0
|
|
||||||
width: Optional[int] = None
|
|
||||||
height: Optional[int] = None
|
|
||||||
fps: Optional[float] = None
|
|
||||||
codec: Optional[str] = None
|
|
||||||
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
|
||||||
|
|
||||||
|
|
||||||
class Job(SQLModel, table=True):
|
|
||||||
__tablename__ = "job"
|
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
|
||||||
source_asset_id: UUID = Field(index=True)
|
|
||||||
video_path: str
|
|
||||||
profile_name: str = "soccer_broadcast"
|
|
||||||
parent_id: Optional[UUID] = Field(default=None, index=True)
|
|
||||||
run_type: str = "initial"
|
|
||||||
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
|
||||||
status: str = "pending"
|
|
||||||
current_stage: Optional[str] = None
|
|
||||||
progress: float = 0.0
|
|
||||||
error_message: Optional[str] = None
|
|
||||||
total_detections: int = 0
|
|
||||||
brands_found: int = 0
|
|
||||||
cloud_llm_calls: int = 0
|
|
||||||
estimated_cost_usd: float = 0.0
|
|
||||||
priority: int = 0
|
|
||||||
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
|
||||||
started_at: Optional[datetime] = None
|
|
||||||
completed_at: Optional[datetime] = None
|
|
||||||
|
|
||||||
|
|
||||||
class Timeline(SQLModel, table=True):
|
|
||||||
__tablename__ = "timeline"
|
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
|
||||||
source_asset_id: Optional[UUID] = Field(default=None, index=True)
|
|
||||||
source_video: str = ""
|
|
||||||
profile_name: str = ""
|
|
||||||
fps: float = 2.0
|
|
||||||
frames_prefix: str = ""
|
|
||||||
frames_manifest: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
|
||||||
frames_meta: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
|
|
||||||
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
|
||||||
|
|
||||||
|
|
||||||
class Checkpoint(SQLModel, table=True):
|
|
||||||
__tablename__ = "checkpoint"
|
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
|
||||||
timeline_id: UUID = Field(index=True)
|
|
||||||
parent_id: Optional[UUID] = Field(default=None, index=True)
|
|
||||||
stage_outputs: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
|
||||||
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
|
||||||
stats: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
|
|
||||||
is_scenario: bool = False
|
|
||||||
scenario_label: str = ""
|
|
||||||
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
|
||||||
|
|
||||||
|
|
||||||
class Brand(SQLModel, table=True):
|
|
||||||
__tablename__ = "brand"
|
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
|
||||||
canonical_name: str = Field(index=True)
|
|
||||||
aliases: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
|
|
||||||
source: str = "ocr"
|
|
||||||
confirmed: bool = False
|
|
||||||
airings: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
|
|
||||||
total_airings: int = 0
|
|
||||||
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
|
||||||
updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
|
|
||||||
@@ -16,4 +16,4 @@ from .storage import (
|
|||||||
load_stage_output,
|
load_stage_output,
|
||||||
)
|
)
|
||||||
from .frames import save_frames, load_frames
|
from .frames import save_frames, load_frames
|
||||||
from .runner_bridge import checkpoint_after_stage, reset_checkpoint_state
|
from .runner_bridge import checkpoint_after_stage, reset_checkpoint_state, get_timeline_id
|
||||||
@@ -9,7 +9,7 @@ import tempfile
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from detect.models import Frame
|
from core.detect.models import Frame
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -11,7 +11,7 @@ import logging
|
|||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from detect import emit
|
from core.detect import emit
|
||||||
# TODO: migrate to Timeline/Branch/Checkpoint model
|
# TODO: migrate to Timeline/Branch/Checkpoint model
|
||||||
# These old functions no longer exist — replay needs rework
|
# These old functions no longer exist — replay needs rework
|
||||||
def _not_migrated(*args, **kwargs):
|
def _not_migrated(*args, **kwargs):
|
||||||
@@ -19,66 +19,14 @@ def _not_migrated(*args, **kwargs):
|
|||||||
|
|
||||||
load_checkpoint = _not_migrated
|
load_checkpoint = _not_migrated
|
||||||
list_checkpoints = _not_migrated
|
list_checkpoints = _not_migrated
|
||||||
from detect.graph import NODES, build_graph
|
from core.detect.graph import NODES, build_graph
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class OverrideProfile:
|
|
||||||
"""
|
|
||||||
Wraps a ContentTypeProfile and patches config methods with overrides.
|
|
||||||
|
|
||||||
Override dict structure:
|
# OverrideProfile removed — config overrides are now handled by dict merging
|
||||||
{
|
# in _load_profile() (nodes.py) and replay_single_stage (below).
|
||||||
"frame_extraction": {"fps": 1.0},
|
|
||||||
"scene_filter": {"hamming_threshold": 12},
|
|
||||||
"region_analysis": {"edge_canny_low": 30, "edge_canny_high": 120},
|
|
||||||
"detection": {"confidence_threshold": 0.5},
|
|
||||||
"ocr": {"languages": ["en", "es"], "min_confidence": 0.3},
|
|
||||||
"resolver": {"fuzzy_threshold": 60},
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, base, overrides: dict):
|
|
||||||
self._base = base
|
|
||||||
self._overrides = overrides
|
|
||||||
|
|
||||||
def __getattr__(self, name):
|
|
||||||
return getattr(self._base, name)
|
|
||||||
|
|
||||||
def _patch(self, config, key: str):
|
|
||||||
patches = self._overrides.get(key, {})
|
|
||||||
for k, v in patches.items():
|
|
||||||
if hasattr(config, k):
|
|
||||||
setattr(config, k, v)
|
|
||||||
return config
|
|
||||||
|
|
||||||
def frame_extraction_config(self):
|
|
||||||
return self._patch(self._base.frame_extraction_config(), "frame_extraction")
|
|
||||||
|
|
||||||
def scene_filter_config(self):
|
|
||||||
return self._patch(self._base.scene_filter_config(), "scene_filter")
|
|
||||||
|
|
||||||
def region_analysis_config(self):
|
|
||||||
return self._patch(self._base.region_analysis_config(), "region_analysis")
|
|
||||||
|
|
||||||
def detection_config(self):
|
|
||||||
return self._patch(self._base.detection_config(), "detection")
|
|
||||||
|
|
||||||
def ocr_config(self):
|
|
||||||
return self._patch(self._base.ocr_config(), "ocr")
|
|
||||||
|
|
||||||
def resolver_config(self):
|
|
||||||
return self._patch(self._base.resolver_config(), "resolver")
|
|
||||||
|
|
||||||
def vlm_prompt(self, crop_context):
|
|
||||||
return self._base.vlm_prompt(crop_context)
|
|
||||||
|
|
||||||
def aggregate(self, detections):
|
|
||||||
return self._base.aggregate(detections)
|
|
||||||
|
|
||||||
def auxiliary_detections(self, source):
|
|
||||||
return self._base.auxiliary_detections(source)
|
|
||||||
|
|
||||||
|
|
||||||
def replay_from(
|
def replay_from(
|
||||||
@@ -183,10 +131,16 @@ def replay_single_stage(
|
|||||||
state = load_checkpoint(job_id, previous_stage)
|
state = load_checkpoint(job_id, previous_stage)
|
||||||
|
|
||||||
# Build profile with overrides
|
# Build profile with overrides
|
||||||
from detect.profiles import get_profile
|
from core.detect.profile import get_profile, get_stage_config
|
||||||
profile = get_profile(state.get("profile_name", "soccer_broadcast"))
|
profile = get_profile(state.get("profile_name", "soccer_broadcast"))
|
||||||
if config_overrides:
|
if config_overrides:
|
||||||
profile = OverrideProfile(profile, config_overrides)
|
merged_configs = dict(profile.get("configs", {}))
|
||||||
|
for sname, soverrides in config_overrides.items():
|
||||||
|
if sname in merged_configs:
|
||||||
|
merged_configs[sname] = {**merged_configs[sname], **soverrides}
|
||||||
|
else:
|
||||||
|
merged_configs[sname] = soverrides
|
||||||
|
profile = {**profile, "configs": merged_configs}
|
||||||
|
|
||||||
# Run the stage function directly (not through the graph)
|
# Run the stage function directly (not through the graph)
|
||||||
if stage == "detect_edges":
|
if stage == "detect_edges":
|
||||||
@@ -207,9 +161,11 @@ def _replay_detect_edges(
|
|||||||
) -> dict:
|
) -> dict:
|
||||||
"""Run edge detection on checkpoint frames, optionally with debug overlays."""
|
"""Run edge detection on checkpoint frames, optionally with debug overlays."""
|
||||||
import os
|
import os
|
||||||
from detect.stages.edge_detector import detect_edge_regions
|
from core.detect.stages.edge_detector import detect_edge_regions
|
||||||
|
|
||||||
config = profile.region_analysis_config()
|
from core.detect.profile import get_stage_config
|
||||||
|
from core.detect.stages.models import RegionAnalysisConfig
|
||||||
|
config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
|
||||||
frames = state.get("filtered_frames", [])
|
frames = state.get("filtered_frames", [])
|
||||||
|
|
||||||
if frame_refs:
|
if frame_refs:
|
||||||
@@ -231,7 +187,7 @@ def _replay_detect_edges(
|
|||||||
if debug and frames:
|
if debug and frames:
|
||||||
debug_data = {}
|
debug_data = {}
|
||||||
if inference_url:
|
if inference_url:
|
||||||
from detect.inference import InferenceClient
|
from core.detect.inference import InferenceClient
|
||||||
client = InferenceClient(base_url=inference_url, job_id=job_id)
|
client = InferenceClient(base_url=inference_url, job_id=job_id)
|
||||||
for frame in frames:
|
for frame in frames:
|
||||||
dr = client.detect_edges_debug(
|
dr = client.detect_edges_debug(
|
||||||
@@ -252,7 +208,7 @@ def _replay_detect_edges(
|
|||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
# Local mode — import GPU module directly
|
# Local mode — import GPU module directly
|
||||||
from detect.stages.edge_detector import _load_cv_edges
|
from core.detect.stages.edge_detector import _load_cv_edges
|
||||||
edges_mod = _load_cv_edges()
|
edges_mod = _load_cv_edges()
|
||||||
for frame in frames:
|
for frame in frames:
|
||||||
dr = edges_mod.detect_edges_debug(
|
dr = edges_mod.detect_edges_debug(
|
||||||
97
core/detect/checkpoint/runner_bridge.py
Normal file
97
core/detect/checkpoint/runner_bridge.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
"""
|
||||||
|
Runner bridge — checkpoint hook called by PipelineRunner after each stage.
|
||||||
|
|
||||||
|
Owns the per-job state (timeline, frame manifest, checkpoint chain) that
|
||||||
|
the runner shouldn't know about.
|
||||||
|
|
||||||
|
Timeline and Job are independent entities:
|
||||||
|
- One Timeline can serve multiple Jobs (re-run with different params)
|
||||||
|
- One Job operates on one Timeline (set after frame extraction)
|
||||||
|
- Checkpoints belong to Timeline, tagged with the Job that created them
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Per-job state
|
||||||
|
_timeline_id: dict[str, str] = {}
|
||||||
|
_frames_manifest: dict[str, dict[int, str]] = {}
|
||||||
|
_latest_checkpoint: dict[str, str] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def reset_checkpoint_state(job_id: str):
|
||||||
|
"""Clean up per-job checkpoint state. Called when pipeline finishes."""
|
||||||
|
_timeline_id.pop(job_id, None)
|
||||||
|
_frames_manifest.pop(job_id, None)
|
||||||
|
_latest_checkpoint.pop(job_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: dict):
|
||||||
|
"""
|
||||||
|
Save a checkpoint after a stage completes.
|
||||||
|
|
||||||
|
Called by the runner. Handles:
|
||||||
|
- Timeline creation (once, on extract_frames)
|
||||||
|
- Frame upload (via create_timeline)
|
||||||
|
- Stage output serialization (via stage registry)
|
||||||
|
- Checkpoint chain (parent → child)
|
||||||
|
"""
|
||||||
|
if not job_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
from .storage import create_timeline, save_stage_output
|
||||||
|
from core.detect.stages.base import _REGISTRY
|
||||||
|
|
||||||
|
merged = {**state, **result}
|
||||||
|
|
||||||
|
# On extract_frames: create Timeline + upload frames + root checkpoint
|
||||||
|
if stage_name == "extract_frames" and job_id not in _timeline_id:
|
||||||
|
frames = merged.get("frames", [])
|
||||||
|
video_path = merged.get("video_path", "")
|
||||||
|
profile_name = merged.get("profile_name", "")
|
||||||
|
|
||||||
|
tid, cid = create_timeline(
|
||||||
|
source_video=video_path,
|
||||||
|
profile_name=profile_name,
|
||||||
|
frames=frames,
|
||||||
|
)
|
||||||
|
_timeline_id[job_id] = tid
|
||||||
|
_latest_checkpoint[job_id] = cid
|
||||||
|
logger.info("Job %s → Timeline %s (root checkpoint %s)", job_id, tid, cid)
|
||||||
|
|
||||||
|
# Emit timeline_id via SSE so the UI can use it for checkpoint loads
|
||||||
|
from core.detect import emit
|
||||||
|
emit.log(job_id, "Checkpoint", "INFO", f"timeline_id={tid}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# For subsequent stages: save checkpoint on the timeline
|
||||||
|
tid = _timeline_id.get(job_id)
|
||||||
|
if not tid:
|
||||||
|
logger.warning("No timeline for job %s, skipping checkpoint", job_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Serialize stage output using the stage's serialize_fn if available
|
||||||
|
stage_cls = _REGISTRY.get(stage_name)
|
||||||
|
serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
|
||||||
|
if serialize_fn:
|
||||||
|
output_json = serialize_fn(merged, job_id)
|
||||||
|
else:
|
||||||
|
output_json = {}
|
||||||
|
|
||||||
|
parent_id = _latest_checkpoint.get(job_id)
|
||||||
|
new_checkpoint_id = save_stage_output(
|
||||||
|
timeline_id=tid,
|
||||||
|
parent_checkpoint_id=parent_id,
|
||||||
|
stage_name=stage_name,
|
||||||
|
output_json=output_json,
|
||||||
|
job_id=job_id,
|
||||||
|
)
|
||||||
|
_latest_checkpoint[job_id] = new_checkpoint_id
|
||||||
|
|
||||||
|
|
||||||
|
def get_timeline_id(job_id: str) -> str | None:
|
||||||
|
"""Get the timeline_id for a running job. Used by the UI to load checkpoints."""
|
||||||
|
return _timeline_id.get(job_id)
|
||||||
@@ -28,7 +28,7 @@ def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict:
|
|||||||
Calls each registered stage's serialize_fn for stage-owned data.
|
Calls each registered stage's serialize_fn for stage-owned data.
|
||||||
Envelope fields (job_id, etc.) are copied directly.
|
Envelope fields (job_id, etc.) are copied directly.
|
||||||
"""
|
"""
|
||||||
from detect.stages.base import _REGISTRY
|
from core.detect.stages.base import _REGISTRY
|
||||||
|
|
||||||
checkpoint = {}
|
checkpoint = {}
|
||||||
|
|
||||||
@@ -64,7 +64,7 @@ def deserialize_state(checkpoint: dict, frames: list) -> dict:
|
|||||||
|
|
||||||
Calls each stage's deserialize_fn to restore stage-owned data.
|
Calls each stage's deserialize_fn to restore stage-owned data.
|
||||||
"""
|
"""
|
||||||
from detect.stages.base import _REGISTRY
|
from core.detect.stages.base import _REGISTRY
|
||||||
|
|
||||||
frame_map = {f.sequence: f for f in frames}
|
frame_map = {f.sequence: f for f in frames}
|
||||||
|
|
||||||
@@ -33,7 +33,7 @@ def create_timeline(
|
|||||||
|
|
||||||
Returns (timeline_id, checkpoint_id).
|
Returns (timeline_id, checkpoint_id).
|
||||||
"""
|
"""
|
||||||
from core.db.tables import Timeline, Checkpoint
|
from core.db.models import Timeline, Checkpoint
|
||||||
from core.db.connection import get_session
|
from core.db.connection import get_session
|
||||||
|
|
||||||
with get_session() as session:
|
with get_session() as session:
|
||||||
@@ -81,7 +81,7 @@ def create_timeline(
|
|||||||
|
|
||||||
def get_timeline_frames(timeline_id: str) -> list:
|
def get_timeline_frames(timeline_id: str) -> list:
|
||||||
"""Load frames from a timeline (from MinIO) as Frame objects."""
|
"""Load frames from a timeline (from MinIO) as Frame objects."""
|
||||||
from core.db.tables import Timeline
|
from core.db.models import Timeline
|
||||||
from core.db.connection import get_session
|
from core.db.connection import get_session
|
||||||
|
|
||||||
with get_session() as session:
|
with get_session() as session:
|
||||||
@@ -96,7 +96,7 @@ def get_timeline_frames(timeline_id: str) -> list:
|
|||||||
|
|
||||||
def get_timeline_frames_b64(timeline_id: str) -> list[dict]:
|
def get_timeline_frames_b64(timeline_id: str) -> list[dict]:
|
||||||
"""Load frames as base64 JPEG (lightweight, no numpy)."""
|
"""Load frames as base64 JPEG (lightweight, no numpy)."""
|
||||||
from core.db.tables import Timeline
|
from core.db.models import Timeline
|
||||||
from core.db.connection import get_session
|
from core.db.connection import get_session
|
||||||
from .frames import load_frames_b64
|
from .frames import load_frames_b64
|
||||||
|
|
||||||
@@ -123,6 +123,7 @@ def save_stage_output(
|
|||||||
stats: dict | None = None,
|
stats: dict | None = None,
|
||||||
is_scenario: bool = False,
|
is_scenario: bool = False,
|
||||||
scenario_label: str = "",
|
scenario_label: str = "",
|
||||||
|
job_id: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Save a stage's output as a new checkpoint (child of parent).
|
Save a stage's output as a new checkpoint (child of parent).
|
||||||
@@ -130,7 +131,7 @@ def save_stage_output(
|
|||||||
Carries forward stage outputs from parent + adds the new one.
|
Carries forward stage outputs from parent + adds the new one.
|
||||||
Returns the new checkpoint ID.
|
Returns the new checkpoint ID.
|
||||||
"""
|
"""
|
||||||
from core.db.tables import Checkpoint
|
from core.db.models import Checkpoint
|
||||||
from core.db.connection import get_session
|
from core.db.connection import get_session
|
||||||
|
|
||||||
with get_session() as session:
|
with get_session() as session:
|
||||||
@@ -146,6 +147,7 @@ def save_stage_output(
|
|||||||
|
|
||||||
checkpoint = Checkpoint(
|
checkpoint = Checkpoint(
|
||||||
timeline_id=UUID(timeline_id),
|
timeline_id=UUID(timeline_id),
|
||||||
|
job_id=UUID(job_id) if job_id else None,
|
||||||
parent_id=UUID(parent_checkpoint_id) if parent_checkpoint_id else None,
|
parent_id=UUID(parent_checkpoint_id) if parent_checkpoint_id else None,
|
||||||
stage_outputs={**parent_outputs, stage_name: output_json},
|
stage_outputs={**parent_outputs, stage_name: output_json},
|
||||||
config_overrides={**parent_config, **(config_overrides or {})},
|
config_overrides={**parent_config, **(config_overrides or {})},
|
||||||
@@ -165,7 +167,7 @@ def save_stage_output(
|
|||||||
|
|
||||||
def load_stage_output(checkpoint_id: str, stage_name: str) -> dict | None:
|
def load_stage_output(checkpoint_id: str, stage_name: str) -> dict | None:
|
||||||
"""Load a stage's output from a checkpoint."""
|
"""Load a stage's output from a checkpoint."""
|
||||||
from core.db.tables import Checkpoint
|
from core.db.models import Checkpoint
|
||||||
from core.db.connection import get_session
|
from core.db.connection import get_session
|
||||||
|
|
||||||
with get_session() as session:
|
with get_session() as session:
|
||||||
@@ -16,8 +16,8 @@ from __future__ import annotations
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from detect.events import push_detect_event
|
from core.detect.events import push_detect_event
|
||||||
from detect.models import PipelineStats
|
from core.detect.models import PipelineStats
|
||||||
|
|
||||||
# Log level ordering for comparison
|
# Log level ordering for comparison
|
||||||
_LEVEL_ORDER = {"DEBUG": 0, "INFO": 1, "WARN": 2, "ERROR": 3}
|
_LEVEL_ORDER = {"DEBUG": 0, "INFO": 1, "WARN": 2, "ERROR": 3}
|
||||||
@@ -4,8 +4,8 @@ Graph event emission — node state tracking + SSE graph_update events.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from detect import emit
|
from core.detect import emit
|
||||||
from detect.state import DetectState
|
from core.detect.state import DetectState
|
||||||
|
|
||||||
|
|
||||||
# Track node states across pipeline runs
|
# Track node states across pipeline runs
|
||||||
@@ -1,28 +1,39 @@
|
|||||||
"""
|
"""
|
||||||
Pipeline node functions — one per stage.
|
Pipeline node functions — one per stage.
|
||||||
|
|
||||||
Each node: reads state, runs stage logic, emits transitions, returns output dict.
|
Each node: reads state, gets config from profile dict, runs stage logic,
|
||||||
|
emits transitions, returns output dict.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from detect import emit
|
from core.detect import emit
|
||||||
from detect.models import PipelineStats
|
from core.detect.models import CropContext, PipelineStats
|
||||||
from detect.profiles import SoccerBroadcastProfile
|
from core.detect.profile import get_profile, get_stage_config, build_vlm_prompt, aggregate_detections
|
||||||
from detect.state import DetectState
|
from core.detect.stages.models import (
|
||||||
from detect.stages.frame_extractor import extract_frames
|
DetectionConfig,
|
||||||
from detect.stages.scene_filter import scene_filter
|
FieldSegmentationConfig,
|
||||||
from detect.stages.edge_detector import detect_edge_regions
|
FrameExtractionConfig,
|
||||||
from detect.stages.yolo_detector import detect_objects
|
OCRConfig,
|
||||||
from detect.stages.preprocess import preprocess_regions
|
RegionAnalysisConfig,
|
||||||
from detect.stages.ocr_stage import run_ocr
|
ResolverConfig,
|
||||||
from detect.stages.brand_resolver import resolve_brands
|
SceneFilterConfig,
|
||||||
from detect.stages.vlm_local import escalate_vlm
|
)
|
||||||
from detect.stages.vlm_cloud import escalate_cloud
|
from core.detect.state import DetectState
|
||||||
from detect.stages.aggregator import compile_report
|
from core.detect.stages.frame_extractor import extract_frames
|
||||||
from detect.tracing import trace_node, flush as flush_traces
|
from core.detect.stages.scene_filter import scene_filter
|
||||||
|
from core.detect.stages.field_segmentation import run_field_segmentation
|
||||||
|
from core.detect.stages.edge_detector import detect_edge_regions
|
||||||
|
from core.detect.stages.yolo_detector import detect_objects
|
||||||
|
from core.detect.stages.preprocess import preprocess_regions
|
||||||
|
from core.detect.stages.ocr_stage import run_ocr
|
||||||
|
from core.detect.stages.brand_resolver import resolve_brands
|
||||||
|
from core.detect.stages.vlm_local import escalate_vlm
|
||||||
|
from core.detect.stages.vlm_cloud import escalate_cloud
|
||||||
|
from core.detect.stages.aggregator import compile_report
|
||||||
|
from core.detect.tracing import trace_node, flush as flush_traces
|
||||||
|
|
||||||
from .events import emit_transition
|
from .events import emit_transition
|
||||||
|
|
||||||
@@ -31,6 +42,7 @@ INFERENCE_URL = os.environ.get("INFERENCE_URL") # None = local mode
|
|||||||
NODES = [
|
NODES = [
|
||||||
"extract_frames",
|
"extract_frames",
|
||||||
"filter_scenes",
|
"filter_scenes",
|
||||||
|
"field_segmentation",
|
||||||
"detect_edges",
|
"detect_edges",
|
||||||
"detect_objects",
|
"detect_objects",
|
||||||
"preprocess",
|
"preprocess",
|
||||||
@@ -42,17 +54,21 @@ NODES = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _get_profile(state: DetectState):
|
def _load_profile(state: DetectState) -> dict:
|
||||||
|
"""Load profile dict, apply config overrides if present."""
|
||||||
name = state.get("profile_name", "soccer_broadcast")
|
name = state.get("profile_name", "soccer_broadcast")
|
||||||
if name == "soccer_broadcast":
|
profile = get_profile(name)
|
||||||
profile = SoccerBroadcastProfile()
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown profile: {name}")
|
|
||||||
|
|
||||||
overrides = state.get("config_overrides")
|
overrides = state.get("config_overrides")
|
||||||
if overrides:
|
if overrides:
|
||||||
from detect.checkpoint.replay import OverrideProfile
|
# Merge overrides into a copy of the profile configs
|
||||||
profile = OverrideProfile(profile, overrides)
|
merged_configs = dict(profile.get("configs", {}))
|
||||||
|
for stage_name, stage_overrides in overrides.items():
|
||||||
|
if stage_name in merged_configs:
|
||||||
|
merged_configs[stage_name] = {**merged_configs[stage_name], **stage_overrides}
|
||||||
|
else:
|
||||||
|
merged_configs[stage_name] = stage_overrides
|
||||||
|
profile = {**profile, "configs": merged_configs}
|
||||||
|
|
||||||
return profile
|
return profile
|
||||||
|
|
||||||
@@ -70,16 +86,16 @@ def node_extract_frames(state: DetectState) -> dict:
|
|||||||
|
|
||||||
source_asset_id = state.get("source_asset_id")
|
source_asset_id = state.get("source_asset_id")
|
||||||
if source_asset_id and not state.get("session_brands"):
|
if source_asset_id and not state.get("session_brands"):
|
||||||
from detect.stages.brand_resolver import build_session_dict
|
from core.detect.stages.brand_resolver import build_session_dict
|
||||||
session_brands = build_session_dict(source_asset_id)
|
session_brands = build_session_dict(source_asset_id)
|
||||||
state["session_brands"] = session_brands
|
state["session_brands"] = session_brands
|
||||||
|
|
||||||
_emit(state, "extract_frames", "running")
|
_emit(state, "extract_frames", "running")
|
||||||
|
|
||||||
with trace_node(state, "extract_frames") as span:
|
with trace_node(state, "extract_frames") as span:
|
||||||
profile = _get_profile(state)
|
profile = _load_profile(state)
|
||||||
config = profile.frame_extraction_config()
|
config = FrameExtractionConfig(**get_stage_config(profile, "extract_frames"))
|
||||||
frames = extract_frames(state["video_path"], config, job_id=state.get("job_id"))
|
frames = extract_frames(state["video_path"], config, job_id=job_id)
|
||||||
span.set_output({"frames_extracted": len(frames)})
|
span.set_output({"frames_extracted": len(frames)})
|
||||||
|
|
||||||
_emit(state, "extract_frames", "done")
|
_emit(state, "extract_frames", "done")
|
||||||
@@ -90,8 +106,8 @@ def node_filter_scenes(state: DetectState) -> dict:
|
|||||||
_emit(state, "filter_scenes", "running")
|
_emit(state, "filter_scenes", "running")
|
||||||
|
|
||||||
with trace_node(state, "filter_scenes") as span:
|
with trace_node(state, "filter_scenes") as span:
|
||||||
profile = _get_profile(state)
|
profile = _load_profile(state)
|
||||||
config = profile.scene_filter_config()
|
config = SceneFilterConfig(**get_stage_config(profile, "filter_scenes"))
|
||||||
frames = state.get("frames", [])
|
frames = state.get("frames", [])
|
||||||
kept = scene_filter(frames, config, job_id=state.get("job_id"))
|
kept = scene_filter(frames, config, job_id=state.get("job_id"))
|
||||||
span.set_output({"frames_in": len(frames), "frames_kept": len(kept)})
|
span.set_output({"frames_in": len(frames), "frames_kept": len(kept)})
|
||||||
@@ -103,17 +119,42 @@ def node_filter_scenes(state: DetectState) -> dict:
|
|||||||
return {"filtered_frames": kept, "stats": stats}
|
return {"filtered_frames": kept, "stats": stats}
|
||||||
|
|
||||||
|
|
||||||
|
def node_field_segmentation(state: DetectState) -> dict:
|
||||||
|
_emit(state, "field_segmentation", "running")
|
||||||
|
|
||||||
|
with trace_node(state, "field_segmentation") as span:
|
||||||
|
profile = _load_profile(state)
|
||||||
|
config = FieldSegmentationConfig(**get_stage_config(profile, "field_segmentation"))
|
||||||
|
frames = state.get("filtered_frames", [])
|
||||||
|
job_id = state.get("job_id")
|
||||||
|
|
||||||
|
result = run_field_segmentation(frames, config, inference_url=INFERENCE_URL, job_id=job_id)
|
||||||
|
span.set_output({
|
||||||
|
"frames": len(frames),
|
||||||
|
"avg_coverage": sum(result["field_coverage"].values()) / max(len(result["field_coverage"]), 1),
|
||||||
|
})
|
||||||
|
|
||||||
|
_emit(state, "field_segmentation", "done")
|
||||||
|
return {
|
||||||
|
"field_masks": result["field_masks"],
|
||||||
|
"field_boundaries": result["field_boundaries"],
|
||||||
|
"field_coverage": result["field_coverage"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def node_detect_edges(state: DetectState) -> dict:
|
def node_detect_edges(state: DetectState) -> dict:
|
||||||
_emit(state, "detect_edges", "running")
|
_emit(state, "detect_edges", "running")
|
||||||
|
|
||||||
with trace_node(state, "detect_edges") as span:
|
with trace_node(state, "detect_edges") as span:
|
||||||
profile = _get_profile(state)
|
profile = _load_profile(state)
|
||||||
config = profile.region_analysis_config()
|
config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
|
||||||
frames = state.get("filtered_frames", [])
|
frames = state.get("filtered_frames", [])
|
||||||
|
field_masks = state.get("field_masks", {})
|
||||||
job_id = state.get("job_id")
|
job_id = state.get("job_id")
|
||||||
|
|
||||||
regions = detect_edge_regions(
|
regions = detect_edge_regions(
|
||||||
frames, config, inference_url=INFERENCE_URL, job_id=job_id,
|
frames, config, inference_url=INFERENCE_URL, job_id=job_id,
|
||||||
|
field_masks=field_masks,
|
||||||
)
|
)
|
||||||
total = sum(len(r) for r in regions.values())
|
total = sum(len(r) for r in regions.values())
|
||||||
span.set_output({"frames": len(frames), "edge_regions": total})
|
span.set_output({"frames": len(frames), "edge_regions": total})
|
||||||
@@ -129,8 +170,8 @@ def node_detect_objects(state: DetectState) -> dict:
|
|||||||
_emit(state, "detect_objects", "running")
|
_emit(state, "detect_objects", "running")
|
||||||
|
|
||||||
with trace_node(state, "detect_objects") as span:
|
with trace_node(state, "detect_objects") as span:
|
||||||
profile = _get_profile(state)
|
profile = _load_profile(state)
|
||||||
config = profile.detection_config()
|
config = DetectionConfig(**get_stage_config(profile, "detect_objects"))
|
||||||
frames = state.get("filtered_frames", [])
|
frames = state.get("filtered_frames", [])
|
||||||
job_id = state.get("job_id")
|
job_id = state.get("job_id")
|
||||||
|
|
||||||
@@ -149,13 +190,12 @@ def node_preprocess(state: DetectState) -> dict:
|
|||||||
_emit(state, "preprocess", "running")
|
_emit(state, "preprocess", "running")
|
||||||
|
|
||||||
with trace_node(state, "preprocess") as span:
|
with trace_node(state, "preprocess") as span:
|
||||||
profile = _get_profile(state)
|
profile = _load_profile(state)
|
||||||
|
prep_config = get_stage_config(profile, "preprocess")
|
||||||
frames = state.get("filtered_frames", [])
|
frames = state.get("filtered_frames", [])
|
||||||
boxes = state.get("boxes_by_frame", {})
|
boxes = state.get("boxes_by_frame", {})
|
||||||
job_id = state.get("job_id")
|
job_id = state.get("job_id")
|
||||||
|
|
||||||
overrides = state.get("config_overrides", {})
|
|
||||||
prep_config = overrides.get("preprocessing", {})
|
|
||||||
do_contrast = prep_config.get("contrast", True)
|
do_contrast = prep_config.get("contrast", True)
|
||||||
do_deskew = prep_config.get("deskew", False)
|
do_deskew = prep_config.get("deskew", False)
|
||||||
do_binarize = prep_config.get("binarize", False)
|
do_binarize = prep_config.get("binarize", False)
|
||||||
@@ -178,8 +218,8 @@ def node_run_ocr(state: DetectState) -> dict:
|
|||||||
_emit(state, "run_ocr", "running")
|
_emit(state, "run_ocr", "running")
|
||||||
|
|
||||||
with trace_node(state, "run_ocr") as span:
|
with trace_node(state, "run_ocr") as span:
|
||||||
profile = _get_profile(state)
|
profile = _load_profile(state)
|
||||||
config = profile.ocr_config()
|
config = OCRConfig(**get_stage_config(profile, "run_ocr"))
|
||||||
frames = state.get("filtered_frames", [])
|
frames = state.get("filtered_frames", [])
|
||||||
boxes = state.get("boxes_by_frame", {})
|
boxes = state.get("boxes_by_frame", {})
|
||||||
job_id = state.get("job_id")
|
job_id = state.get("job_id")
|
||||||
@@ -198,18 +238,18 @@ def node_match_brands(state: DetectState) -> dict:
|
|||||||
_emit(state, "match_brands", "running")
|
_emit(state, "match_brands", "running")
|
||||||
|
|
||||||
with trace_node(state, "match_brands") as span:
|
with trace_node(state, "match_brands") as span:
|
||||||
profile = _get_profile(state)
|
profile = _load_profile(state)
|
||||||
resolver_config = profile.resolver_config()
|
config = ResolverConfig(**get_stage_config(profile, "match_brands"))
|
||||||
candidates = state.get("text_candidates", [])
|
candidates = state.get("text_candidates", [])
|
||||||
session_brands = state.get("session_brands", {})
|
session_brands = state.get("session_brands", {})
|
||||||
job_id = state.get("job_id")
|
job_id = state.get("job_id")
|
||||||
source_asset_id = state.get("source_asset_id")
|
source_asset_id = state.get("source_asset_id")
|
||||||
|
|
||||||
matched, unresolved = resolve_brands(
|
matched, unresolved = resolve_brands(
|
||||||
candidates, resolver_config,
|
candidates, config,
|
||||||
session_brands=session_brands,
|
session_brands=session_brands,
|
||||||
source_asset_id=source_asset_id,
|
source_asset_id=source_asset_id,
|
||||||
content_type=profile.name, job_id=job_id,
|
content_type=profile["name"], job_id=job_id,
|
||||||
)
|
)
|
||||||
span.set_output({"matched": len(matched), "unresolved": len(unresolved)})
|
span.set_output({"matched": len(matched), "unresolved": len(unresolved)})
|
||||||
|
|
||||||
@@ -221,15 +261,19 @@ def node_escalate_vlm(state: DetectState) -> dict:
|
|||||||
_emit(state, "escalate_vlm", "running")
|
_emit(state, "escalate_vlm", "running")
|
||||||
|
|
||||||
with trace_node(state, "escalate_vlm") as span:
|
with trace_node(state, "escalate_vlm") as span:
|
||||||
profile = _get_profile(state)
|
profile = _load_profile(state)
|
||||||
|
vlm_config = get_stage_config(profile, "escalate_vlm")
|
||||||
|
vlm_template = vlm_config.get("vlm_prompt_template", "Identify the brand in this image.")
|
||||||
candidates = state.get("unresolved_candidates", [])
|
candidates = state.get("unresolved_candidates", [])
|
||||||
job_id = state.get("job_id")
|
job_id = state.get("job_id")
|
||||||
|
|
||||||
|
vlm_prompt_fn = lambda ctx: build_vlm_prompt(ctx, vlm_template)
|
||||||
|
|
||||||
vlm_matched, still_unresolved = escalate_vlm(
|
vlm_matched, still_unresolved = escalate_vlm(
|
||||||
candidates,
|
candidates,
|
||||||
vlm_prompt_fn=profile.vlm_prompt,
|
vlm_prompt_fn=vlm_prompt_fn,
|
||||||
inference_url=INFERENCE_URL,
|
inference_url=INFERENCE_URL,
|
||||||
content_type=profile.name,
|
content_type=profile["name"],
|
||||||
source_asset_id=state.get("source_asset_id"),
|
source_asset_id=state.get("source_asset_id"),
|
||||||
job_id=job_id,
|
job_id=job_id,
|
||||||
)
|
)
|
||||||
@@ -254,16 +298,20 @@ def node_escalate_cloud(state: DetectState) -> dict:
|
|||||||
_emit(state, "escalate_cloud", "running")
|
_emit(state, "escalate_cloud", "running")
|
||||||
|
|
||||||
with trace_node(state, "escalate_cloud") as span:
|
with trace_node(state, "escalate_cloud") as span:
|
||||||
profile = _get_profile(state)
|
profile = _load_profile(state)
|
||||||
|
vlm_config = get_stage_config(profile, "escalate_vlm")
|
||||||
|
vlm_template = vlm_config.get("vlm_prompt_template", "Identify the brand in this image.")
|
||||||
candidates = state.get("unresolved_candidates", [])
|
candidates = state.get("unresolved_candidates", [])
|
||||||
job_id = state.get("job_id")
|
job_id = state.get("job_id")
|
||||||
stats = state.get("stats", PipelineStats())
|
stats = state.get("stats", PipelineStats())
|
||||||
|
|
||||||
|
vlm_prompt_fn = lambda ctx: build_vlm_prompt(ctx, vlm_template)
|
||||||
|
|
||||||
cloud_matched = escalate_cloud(
|
cloud_matched = escalate_cloud(
|
||||||
candidates,
|
candidates,
|
||||||
vlm_prompt_fn=profile.vlm_prompt,
|
vlm_prompt_fn=vlm_prompt_fn,
|
||||||
stats=stats,
|
stats=stats,
|
||||||
content_type=profile.name,
|
content_type=profile["name"],
|
||||||
source_asset_id=state.get("source_asset_id"),
|
source_asset_id=state.get("source_asset_id"),
|
||||||
job_id=job_id,
|
job_id=job_id,
|
||||||
)
|
)
|
||||||
@@ -283,7 +331,7 @@ def node_compile_report(state: DetectState) -> dict:
|
|||||||
_emit(state, "compile_report", "running")
|
_emit(state, "compile_report", "running")
|
||||||
|
|
||||||
with trace_node(state, "compile_report") as span:
|
with trace_node(state, "compile_report") as span:
|
||||||
profile = _get_profile(state)
|
profile = _load_profile(state)
|
||||||
detections = state.get("detections", [])
|
detections = state.get("detections", [])
|
||||||
stats = state.get("stats", PipelineStats())
|
stats = state.get("stats", PipelineStats())
|
||||||
job_id = state.get("job_id")
|
job_id = state.get("job_id")
|
||||||
@@ -292,7 +340,7 @@ def node_compile_report(state: DetectState) -> dict:
|
|||||||
detections=detections,
|
detections=detections,
|
||||||
stats=stats,
|
stats=stats,
|
||||||
video_source=state.get("video_path", ""),
|
video_source=state.get("video_path", ""),
|
||||||
content_type=profile.name,
|
content_type=profile["name"],
|
||||||
job_id=job_id,
|
job_id=job_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -306,6 +354,7 @@ def node_compile_report(state: DetectState) -> dict:
|
|||||||
NODE_FUNCTIONS = [
|
NODE_FUNCTIONS = [
|
||||||
("extract_frames", node_extract_frames),
|
("extract_frames", node_extract_frames),
|
||||||
("filter_scenes", node_filter_scenes),
|
("filter_scenes", node_filter_scenes),
|
||||||
|
("field_segmentation", node_field_segmentation),
|
||||||
("detect_edges", node_detect_edges),
|
("detect_edges", node_detect_edges),
|
||||||
("detect_objects", node_detect_objects),
|
("detect_objects", node_detect_objects),
|
||||||
("preprocess", node_preprocess),
|
("preprocess", node_preprocess),
|
||||||
@@ -13,8 +13,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
from core.schema.models.pipeline_config import PipelineConfig
|
from core.detect.stages.models import PipelineConfig
|
||||||
from detect.state import DetectState
|
from core.detect.state import DetectState
|
||||||
from .nodes import NODES, NODE_FUNCTIONS
|
from .nodes import NODES, NODE_FUNCTIONS
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -118,7 +118,7 @@ def _wait_if_paused(job_id: str, node_name: str):
|
|||||||
|
|
||||||
if _pause_after_stage.get(job_id, False):
|
if _pause_after_stage.get(job_id, False):
|
||||||
gate.clear()
|
gate.clear()
|
||||||
from detect import emit
|
from core.detect import emit
|
||||||
emit.log(job_id, "Pipeline", "INFO", f"Paused after {node_name}")
|
emit.log(job_id, "Pipeline", "INFO", f"Paused after {node_name}")
|
||||||
|
|
||||||
while not gate.wait(timeout=0.5):
|
while not gate.wait(timeout=0.5):
|
||||||
@@ -237,7 +237,7 @@ class PipelineRunner:
|
|||||||
|
|
||||||
# 4. Checkpoint
|
# 4. Checkpoint
|
||||||
if self.do_checkpoint:
|
if self.do_checkpoint:
|
||||||
from detect.checkpoint import checkpoint_after_stage
|
from core.detect.checkpoint import checkpoint_after_stage
|
||||||
checkpoint_after_stage(job_id, stage_name, state, result)
|
checkpoint_after_stage(job_id, stage_name, state, result)
|
||||||
|
|
||||||
# 5. Pause check
|
# 5. Pause check
|
||||||
@@ -256,11 +256,11 @@ def get_pipeline(
|
|||||||
start_from: str | None = None,
|
start_from: str | None = None,
|
||||||
) -> PipelineRunner:
|
) -> PipelineRunner:
|
||||||
"""Return a PipelineRunner for the given profile."""
|
"""Return a PipelineRunner for the given profile."""
|
||||||
from detect.profiles import get_profile
|
from core.detect.profile import get_profile, pipeline_config_from_dict
|
||||||
|
|
||||||
do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
|
do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
|
||||||
profile = get_profile(profile_name)
|
profile = get_profile(profile_name)
|
||||||
config = profile.pipeline_config()
|
config = pipeline_config_from_dict(profile["pipeline"])
|
||||||
|
|
||||||
return PipelineRunner(
|
return PipelineRunner(
|
||||||
config=config,
|
config=config,
|
||||||
@@ -231,6 +231,20 @@ class InferenceClient:
|
|||||||
pair_count=data.get("pair_count", 0),
|
pair_count=data.get("pair_count", 0),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def post(self, path: str, payload: dict) -> dict | None:
|
||||||
|
"""Generic POST to the inference server. Returns JSON response or None on error."""
|
||||||
|
try:
|
||||||
|
resp = self.session.post(
|
||||||
|
f"{self.base_url}{path}",
|
||||||
|
json=payload,
|
||||||
|
timeout=self.timeout,
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Inference POST %s failed: %s", path, e)
|
||||||
|
return None
|
||||||
|
|
||||||
def load_model(self, model: str, quantization: str = "fp16") -> None:
|
def load_model(self, model: str, quantization: str = "fp16") -> None:
|
||||||
"""Request the server to load a model into VRAM."""
|
"""Request the server to load a model into VRAM."""
|
||||||
self.session.post(
|
self.session.post(
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
Inference response types.
|
Inference response types.
|
||||||
|
|
||||||
These are the shapes returned by the inference server.
|
These are the shapes returned by the inference server.
|
||||||
Kept separate from detect.models to avoid coupling the
|
Kept separate from core.detect.models to avoid coupling the
|
||||||
inference protocol to pipeline internals.
|
inference protocol to pipeline internals.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -2,8 +2,8 @@
|
|||||||
Detection pipeline runtime models.
|
Detection pipeline runtime models.
|
||||||
|
|
||||||
These are the data structures that flow between pipeline stages.
|
These are the data structures that flow between pipeline stages.
|
||||||
They contain runtime types (np.ndarray) so modelgen skips them —
|
They contain runtime types (np.ndarray) so they live here, not in
|
||||||
not generated to SQLModel or TypeScript.
|
core/schema/models/ (which is for modelgen source of truth).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -85,3 +85,11 @@ class DetectionReport:
|
|||||||
brands: dict[str, BrandStats] = field(default_factory=dict)
|
brands: dict[str, BrandStats] = field(default_factory=dict)
|
||||||
timeline: list[BrandDetection] = field(default_factory=list)
|
timeline: list[BrandDetection] = field(default_factory=list)
|
||||||
pipeline_stats: PipelineStats = field(default_factory=PipelineStats)
|
pipeline_stats: PipelineStats = field(default_factory=PipelineStats)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CropContext:
|
||||||
|
"""Runtime type — holds image bytes for VLM prompts."""
|
||||||
|
image: bytes
|
||||||
|
surrounding_text: str = ""
|
||||||
|
position_hint: str = ""
|
||||||
107
core/detect/profile.py
Normal file
107
core/detect/profile.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
"""
|
||||||
|
Profile registry and helpers.
|
||||||
|
|
||||||
|
Loads profile data from Postgres.
|
||||||
|
A profile is a dict with keys: name, pipeline, configs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from core.detect.stages.models import PipelineConfig, StageRef, Edge
|
||||||
|
from core.detect.models import (
|
||||||
|
BrandDetection,
|
||||||
|
BrandStats,
|
||||||
|
CropContext,
|
||||||
|
DetectionReport,
|
||||||
|
PipelineStats,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_profile(name: str) -> Dict[str, Any]:
|
||||||
|
"""Get a profile dict by name from the database."""
|
||||||
|
from core.db.connection import get_session
|
||||||
|
from core.db.models import Profile
|
||||||
|
|
||||||
|
with get_session() as session:
|
||||||
|
row = session.query(Profile).filter(Profile.name == name).first()
|
||||||
|
|
||||||
|
if row is None:
|
||||||
|
raise ValueError(f"Unknown profile: {name!r}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"name": row.name,
|
||||||
|
"pipeline": row.pipeline or {},
|
||||||
|
"configs": row.configs or {},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def list_profiles() -> list[str]:
|
||||||
|
"""List available profile names from the database."""
|
||||||
|
from core.db.connection import get_session
|
||||||
|
from core.db.models import Profile
|
||||||
|
|
||||||
|
with get_session() as session:
|
||||||
|
rows = session.query(Profile.name).all()
|
||||||
|
|
||||||
|
return [r[0] for r in rows]
|
||||||
|
|
||||||
|
|
||||||
|
def get_stage_config(profile: Dict[str, Any], stage_name: str) -> dict:
|
||||||
|
"""Get config values for a stage from a profile."""
|
||||||
|
return profile.get("configs", {}).get(stage_name, {})
|
||||||
|
|
||||||
|
|
||||||
|
def pipeline_config_from_dict(data: Dict[str, Any]) -> PipelineConfig:
|
||||||
|
"""Deserialize a PipelineConfig from a JSONB dict."""
|
||||||
|
stages = [StageRef(**s) for s in data.get("stages", [])]
|
||||||
|
edges = [Edge(**e) for e in data.get("edges", [])]
|
||||||
|
return PipelineConfig(
|
||||||
|
name=data.get("name", ""),
|
||||||
|
profile_name=data.get("profile_name", ""),
|
||||||
|
stages=stages,
|
||||||
|
edges=edges,
|
||||||
|
routing_rules=data.get("routing_rules", {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_vlm_prompt(crop_context: CropContext, template: str) -> str:
|
||||||
|
"""Build a VLM prompt from a template and crop context."""
|
||||||
|
hint = f" Position: {crop_context.position_hint}." if crop_context.position_hint else ""
|
||||||
|
text = f" Nearby text: '{crop_context.surrounding_text}'." if crop_context.surrounding_text else ""
|
||||||
|
return template.format(hint=hint, text=text)
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_detections(
|
||||||
|
detections: list[BrandDetection],
|
||||||
|
content_type: str,
|
||||||
|
) -> DetectionReport:
|
||||||
|
"""Group detections by brand into a report."""
|
||||||
|
brands: dict[str, BrandStats] = {}
|
||||||
|
for d in detections:
|
||||||
|
if d.brand not in brands:
|
||||||
|
brands[d.brand] = BrandStats()
|
||||||
|
s = brands[d.brand]
|
||||||
|
s.total_appearances += 1
|
||||||
|
s.total_screen_time += d.duration
|
||||||
|
s.avg_confidence = (
|
||||||
|
(s.avg_confidence * (s.total_appearances - 1) + d.confidence)
|
||||||
|
/ s.total_appearances
|
||||||
|
)
|
||||||
|
if s.first_seen == 0.0 or d.timestamp < s.first_seen:
|
||||||
|
s.first_seen = d.timestamp
|
||||||
|
if d.timestamp > s.last_seen:
|
||||||
|
s.last_seen = d.timestamp
|
||||||
|
|
||||||
|
return DetectionReport(
|
||||||
|
video_source="",
|
||||||
|
content_type=content_type,
|
||||||
|
duration_seconds=0.0,
|
||||||
|
brands=brands,
|
||||||
|
timeline=sorted(detections, key=lambda d: d.timestamp),
|
||||||
|
pipeline_stats=PipelineStats(),
|
||||||
|
)
|
||||||
@@ -145,3 +145,19 @@ class RetryResponse(BaseModel):
|
|||||||
status: str
|
status: str
|
||||||
task_id: str
|
task_id: str
|
||||||
job_id: str
|
job_id: str
|
||||||
|
|
||||||
|
class RunRequest(BaseModel):
|
||||||
|
"""Request body for launching a detection pipeline run."""
|
||||||
|
video_path: str
|
||||||
|
profile_name: str = "soccer_broadcast"
|
||||||
|
source_asset_id: str = ""
|
||||||
|
checkpoint: bool = True
|
||||||
|
skip_vlm: bool = False
|
||||||
|
skip_cloud: bool = False
|
||||||
|
log_level: str = "INFO"
|
||||||
|
|
||||||
|
class RunResponse(BaseModel):
|
||||||
|
"""Response after starting a pipeline run."""
|
||||||
|
status: str
|
||||||
|
job_id: str
|
||||||
|
video_path: str
|
||||||
@@ -16,6 +16,7 @@ from .base import (
|
|||||||
|
|
||||||
# Import all stage files to trigger auto-registration
|
# Import all stage files to trigger auto-registration
|
||||||
from . import edge_detector # noqa: F401
|
from . import edge_detector # noqa: F401
|
||||||
|
from . import field_segmentation # noqa: F401
|
||||||
|
|
||||||
# Import registry for backward compat (other stages still use old pattern)
|
# Import registry for backward compat (other stages still use old pattern)
|
||||||
from . import registry # noqa: F401
|
from . import registry # noqa: F401
|
||||||
@@ -9,8 +9,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from detect import emit
|
from core.detect import emit
|
||||||
from detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats
|
from core.detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -5,7 +5,7 @@ Each stage is a file that subclasses Stage. Auto-discovered via
|
|||||||
__init_subclass__. No manual registration needed.
|
__init_subclass__. No manual registration needed.
|
||||||
|
|
||||||
A stage:
|
A stage:
|
||||||
- Has a StageDefinition (from schema) with name, config, IO
|
- Has a StageDefinition (generated from schema) with name, config, IO
|
||||||
- Implements run(frames, config) → output
|
- Implements run(frames, config) → output
|
||||||
- Owns its output serialization (opaque blob)
|
- Owns its output serialization (opaque blob)
|
||||||
- Optionally has a TypeScript port for browser-side execution
|
- Optionally has a TypeScript port for browser-side execution
|
||||||
@@ -16,12 +16,30 @@ the format. The stage that wrote it is the only one that can read it.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from core.schema.models.stages import StageConfigField, StageIO, StageDefinition
|
from core.detect.stages.models import (
|
||||||
|
StageConfigField,
|
||||||
|
StageIO,
|
||||||
|
StageDefinition,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Legacy runtime extension — adds callable fields for old-style stages.
|
||||||
|
# New stages use Stage subclass with serialize()/deserialize() methods instead.
|
||||||
|
class LegacyStageDefinition:
|
||||||
|
"""Wraps a StageDefinition with callable serialize/deserialize functions."""
|
||||||
|
|
||||||
|
def __init__(self, definition: StageDefinition, fn=None, serialize_fn=None, deserialize_fn=None):
|
||||||
|
self._definition = definition
|
||||||
|
self.fn = fn
|
||||||
|
self.serialize_fn = serialize_fn
|
||||||
|
self.deserialize_fn = deserialize_fn
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return getattr(self._definition, name)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -30,12 +48,18 @@ from core.schema.models.stages import StageConfigField, StageIO, StageDefinition
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
_REGISTRY: dict[str, type['Stage']] = {}
|
_REGISTRY: dict[str, type['Stage']] = {}
|
||||||
_LEGACY_REGISTRY: dict[str, StageDefinition] = {}
|
_LEGACY_REGISTRY: dict[str, LegacyStageDefinition] = {}
|
||||||
|
|
||||||
|
|
||||||
def register_stage(definition: StageDefinition):
|
def register_stage(
|
||||||
|
definition: StageDefinition,
|
||||||
|
fn=None,
|
||||||
|
serialize_fn=None,
|
||||||
|
deserialize_fn=None,
|
||||||
|
):
|
||||||
"""Legacy registration for stages not yet converted to Stage subclass."""
|
"""Legacy registration for stages not yet converted to Stage subclass."""
|
||||||
_LEGACY_REGISTRY[definition.name] = definition
|
legacy = LegacyStageDefinition(definition, fn=fn, serialize_fn=serialize_fn, deserialize_fn=deserialize_fn)
|
||||||
|
_LEGACY_REGISTRY[definition.name] = legacy
|
||||||
|
|
||||||
|
|
||||||
class Stage:
|
class Stage:
|
||||||
@@ -55,13 +79,6 @@ class Stage:
|
|||||||
_REGISTRY[cls.definition.name] = cls
|
_REGISTRY[cls.definition.name] = cls
|
||||||
|
|
||||||
def run(self, frames: list, config: dict) -> Any:
|
def run(self, frames: list, config: dict) -> Any:
|
||||||
"""
|
|
||||||
Run the stage on a list of frames with the given config.
|
|
||||||
|
|
||||||
Config is a dict of parameter values (from slider UI or profile).
|
|
||||||
Returns the stage output — whatever shape this stage produces.
|
|
||||||
Debug overlays are included when config has debug=True.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def serialize(self, output: Any) -> bytes:
|
def serialize(self, output: Any) -> bytes:
|
||||||
@@ -79,12 +96,15 @@ class Stage:
|
|||||||
# Discovery API
|
# Discovery API
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def _all_definitions() -> dict[str, StageDefinition]:
|
def _all_definitions():
|
||||||
"""Merge new Stage subclass registry + legacy registry."""
|
"""Merge new Stage subclass registry + legacy registry.
|
||||||
|
|
||||||
|
Returns StageDefinition for new-style stages,
|
||||||
|
LegacyStageDefinition for legacy stages (has serialize_fn etc).
|
||||||
|
"""
|
||||||
merged = {}
|
merged = {}
|
||||||
# Legacy first, new overwrites (new takes precedence)
|
for name, legacy in _LEGACY_REGISTRY.items():
|
||||||
for name, defn in _LEGACY_REGISTRY.items():
|
merged[name] = legacy
|
||||||
merged[name] = defn
|
|
||||||
for name, cls in _REGISTRY.items():
|
for name, cls in _REGISTRY.items():
|
||||||
merged[name] = cls.definition
|
merged[name] = cls.definition
|
||||||
return merged
|
return merged
|
||||||
@@ -16,9 +16,9 @@ import logging
|
|||||||
|
|
||||||
from rapidfuzz import fuzz
|
from rapidfuzz import fuzz
|
||||||
|
|
||||||
from detect import emit
|
from core.detect import emit
|
||||||
from detect.models import BrandDetection, TextCandidate
|
from core.detect.models import BrandDetection, TextCandidate
|
||||||
from detect.profiles.base import ResolverConfig
|
from core.detect.stages.models import ResolverConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -21,10 +21,10 @@ from typing import Any
|
|||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from detect import emit
|
from core.detect import emit
|
||||||
from detect.models import BoundingBox, Frame
|
from core.detect.models import BoundingBox, Frame
|
||||||
from detect.stages.base import Stage
|
from core.detect.stages.base import Stage
|
||||||
from core.schema.models.stages import StageDefinition, StageConfigField, StageIO
|
from core.detect.stages.models import StageDefinition, StageConfigField, StageIO
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -42,14 +42,14 @@ class EdgeDetectionStage(Stage):
|
|||||||
writes=["edge_regions_by_frame"],
|
writes=["edge_regions_by_frame"],
|
||||||
),
|
),
|
||||||
config_fields=[
|
config_fields=[
|
||||||
StageConfigField("enabled", "bool", True, "Enable edge detection"),
|
StageConfigField(name="enabled", type="bool", default=True, description="Enable edge detection"),
|
||||||
StageConfigField("edge_canny_low", "int", 50, "Canny low threshold", min=0, max=255),
|
StageConfigField(name="edge_canny_low", type="int", default=50, description="Canny low threshold", min=0, max=255),
|
||||||
StageConfigField("edge_canny_high", "int", 150, "Canny high threshold", min=0, max=255),
|
StageConfigField(name="edge_canny_high", type="int", default=150, description="Canny high threshold", min=0, max=255),
|
||||||
StageConfigField("edge_hough_threshold", "int", 80, "Hough accumulator threshold", min=1, max=500),
|
StageConfigField(name="edge_hough_threshold", type="int", default=80, description="Hough accumulator threshold", min=1, max=500),
|
||||||
StageConfigField("edge_hough_min_length", "int", 100, "Min line length (px)", min=10, max=2000),
|
StageConfigField(name="edge_hough_min_length", type="int", default=100, description="Min line length (px)", min=10, max=2000),
|
||||||
StageConfigField("edge_hough_max_gap", "int", 10, "Max line gap (px)", min=1, max=100),
|
StageConfigField(name="edge_hough_max_gap", type="int", default=10, description="Max line gap (px)", min=1, max=100),
|
||||||
StageConfigField("edge_pair_max_distance", "int", 200, "Max distance between line pair (px)", min=10, max=500),
|
StageConfigField(name="edge_pair_max_distance", type="int", default=200, description="Max distance between line pair (px)", min=10, max=500),
|
||||||
StageConfigField("edge_pair_min_distance", "int", 15, "Min distance between line pair (px)", min=5, max=200),
|
StageConfigField(name="edge_pair_min_distance", type="int", default=15, description="Min distance between line pair (px)", min=5, max=200),
|
||||||
],
|
],
|
||||||
tracks_element="edge_region",
|
tracks_element="edge_region",
|
||||||
)
|
)
|
||||||
@@ -143,8 +143,8 @@ class EdgeDetectionStage(Stage):
|
|||||||
|
|
||||||
def _run_remote(self, frame: Frame, config: dict,
|
def _run_remote(self, frame: Frame, config: dict,
|
||||||
inference_url: str, job_id: str) -> list[BoundingBox]:
|
inference_url: str, job_id: str) -> list[BoundingBox]:
|
||||||
from detect.inference import InferenceClient
|
from core.detect.inference import InferenceClient
|
||||||
from detect.emit import _run_log_level
|
from core.detect.emit import _run_log_level
|
||||||
|
|
||||||
client = InferenceClient(
|
client = InferenceClient(
|
||||||
base_url=inference_url, job_id=job_id, log_level=_run_log_level,
|
base_url=inference_url, job_id=job_id, log_level=_run_log_level,
|
||||||
@@ -208,7 +208,7 @@ def _load_cv_edges():
|
|||||||
if _cv_edges_mod is None:
|
if _cv_edges_mod is None:
|
||||||
import importlib.util
|
import importlib.util
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
spec = importlib.util.spec_from_file_location("cv_edges", Path("gpu/models/cv/edges.py"))
|
spec = importlib.util.spec_from_file_location("cv_edges", Path("core/gpu/models/cv/edges.py"))
|
||||||
_cv_edges_mod = importlib.util.module_from_spec(spec)
|
_cv_edges_mod = importlib.util.module_from_spec(spec)
|
||||||
spec.loader.exec_module(_cv_edges_mod)
|
spec.loader.exec_module(_cv_edges_mod)
|
||||||
return _cv_edges_mod
|
return _cv_edges_mod
|
||||||
@@ -216,8 +216,43 @@ def _load_cv_edges():
|
|||||||
|
|
||||||
# --- Backward compat: standalone function for graph.py ---
|
# --- Backward compat: standalone function for graph.py ---
|
||||||
|
|
||||||
def detect_edge_regions(frames, config, inference_url=None, job_id=None):
|
def _filter_by_field_mask(boxes, mask, margin_px=50):
|
||||||
"""Convenience wrapper — calls EdgeDetectionStage.run()."""
|
"""
|
||||||
|
Keep only boxes that are near the pitch boundary (hoarding zone).
|
||||||
|
|
||||||
|
The field mask has 255=pitch, 0=not pitch. Hoardings sit just
|
||||||
|
outside the pitch boundary. We dilate the mask to create a
|
||||||
|
"boundary zone" and keep boxes whose center falls in the zone
|
||||||
|
between the dilated mask edge and the original mask.
|
||||||
|
"""
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
if mask is None or not boxes:
|
||||||
|
return boxes
|
||||||
|
|
||||||
|
# Dilate the pitch mask — the expansion zone is where hoardings are
|
||||||
|
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (margin_px * 2, margin_px * 2))
|
||||||
|
dilated = cv2.dilate(mask, kernel)
|
||||||
|
|
||||||
|
# Boundary zone = dilated but NOT original pitch
|
||||||
|
boundary_zone = cv2.bitwise_and(dilated, cv2.bitwise_not(mask))
|
||||||
|
|
||||||
|
kept = []
|
||||||
|
for box in boxes:
|
||||||
|
cx = box.x + box.w // 2
|
||||||
|
cy = box.y + box.h // 2
|
||||||
|
# Clamp to image bounds
|
||||||
|
cy = min(cy, boundary_zone.shape[0] - 1)
|
||||||
|
cx = min(cx, boundary_zone.shape[1] - 1)
|
||||||
|
if boundary_zone[cy, cx] > 0:
|
||||||
|
kept.append(box)
|
||||||
|
|
||||||
|
return kept
|
||||||
|
|
||||||
|
|
||||||
|
def detect_edge_regions(frames, config, inference_url=None, job_id=None, field_masks=None):
|
||||||
|
"""Convenience wrapper — calls EdgeDetectionStage.run(), optionally filters by field mask."""
|
||||||
stage = EdgeDetectionStage()
|
stage = EdgeDetectionStage()
|
||||||
cfg = {
|
cfg = {
|
||||||
"enabled": config.enabled,
|
"enabled": config.enabled,
|
||||||
@@ -231,4 +266,23 @@ def detect_edge_regions(frames, config, inference_url=None, job_id=None):
|
|||||||
"inference_url": inference_url,
|
"inference_url": inference_url,
|
||||||
"job_id": job_id,
|
"job_id": job_id,
|
||||||
}
|
}
|
||||||
return stage.run(frames, cfg)
|
all_boxes = stage.run(frames, cfg)
|
||||||
|
|
||||||
|
# Filter by field segmentation mask if available
|
||||||
|
if field_masks:
|
||||||
|
filtered_total = 0
|
||||||
|
original_total = sum(len(b) for b in all_boxes.values())
|
||||||
|
for seq, boxes in all_boxes.items():
|
||||||
|
mask = field_masks.get(seq)
|
||||||
|
if mask is not None:
|
||||||
|
all_boxes[seq] = _filter_by_field_mask(boxes, mask)
|
||||||
|
filtered_total += len(all_boxes[seq])
|
||||||
|
else:
|
||||||
|
filtered_total += len(boxes)
|
||||||
|
|
||||||
|
if original_total != filtered_total:
|
||||||
|
from core.detect import emit
|
||||||
|
emit.log(job_id, "EdgeDetection", "INFO",
|
||||||
|
f"Field mask filter: {original_total} → {filtered_total} regions")
|
||||||
|
|
||||||
|
return all_boxes
|
||||||
141
core/detect/stages/field_segmentation.py
Normal file
141
core/detect/stages/field_segmentation.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
"""
|
||||||
|
Stage — Field Segmentation
|
||||||
|
|
||||||
|
Calls the GPU inference server to detect pitch boundaries via
|
||||||
|
HSV green mask + morphology. The CV code lives in core/gpu/models/cv/.
|
||||||
|
|
||||||
|
Outputs a mask and boundary that downstream stages use as spatial priors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from core.detect import emit
|
||||||
|
from core.detect.models import Frame
|
||||||
|
from core.detect.stages.base import Stage
|
||||||
|
from core.detect.stages.models import (
|
||||||
|
FieldSegmentationConfig,
|
||||||
|
StageConfigField,
|
||||||
|
StageDefinition,
|
||||||
|
StageIO,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FieldSegmentationStage(Stage):
|
||||||
|
|
||||||
|
definition = StageDefinition(
|
||||||
|
name="field_segmentation",
|
||||||
|
label="Field Segmentation",
|
||||||
|
description="HSV green mask — detect pitch boundaries for spatial priors",
|
||||||
|
category="cv_analysis",
|
||||||
|
io=StageIO(
|
||||||
|
reads=["filtered_frames"],
|
||||||
|
writes=["field_mask"],
|
||||||
|
),
|
||||||
|
config_fields=[
|
||||||
|
StageConfigField(name="enabled", type="bool", default=True, description="Enable field segmentation"),
|
||||||
|
StageConfigField(name="hue_low", type="int", default=30, description="HSV hue lower bound", min=0, max=180),
|
||||||
|
StageConfigField(name="hue_high", type="int", default=85, description="HSV hue upper bound", min=0, max=180),
|
||||||
|
StageConfigField(name="sat_low", type="int", default=30, description="HSV saturation lower bound", min=0, max=255),
|
||||||
|
StageConfigField(name="sat_high", type="int", default=255, description="HSV saturation upper bound", min=0, max=255),
|
||||||
|
StageConfigField(name="val_low", type="int", default=30, description="HSV value lower bound", min=0, max=255),
|
||||||
|
StageConfigField(name="val_high", type="int", default=255, description="HSV value upper bound", min=0, max=255),
|
||||||
|
StageConfigField(name="morph_kernel", type="int", default=15, description="Morphology kernel size", min=3, max=51),
|
||||||
|
StageConfigField(name="min_area_ratio", type="float", default=0.05, description="Min contour area as fraction of frame", min=0.01, max=0.5),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _frame_to_b64(frame: Frame) -> str:
|
||||||
|
"""Encode frame image as base64 JPEG."""
|
||||||
|
img = Image.fromarray(frame.image)
|
||||||
|
buf = io.BytesIO()
|
||||||
|
img.save(buf, format="JPEG", quality=85)
|
||||||
|
return base64.b64encode(buf.getvalue()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_mask_b64(mask_b64: str) -> np.ndarray:
|
||||||
|
"""Decode a base64 PNG mask back to numpy array."""
|
||||||
|
data = base64.b64decode(mask_b64)
|
||||||
|
img = Image.open(io.BytesIO(data)).convert("L")
|
||||||
|
return np.array(img)
|
||||||
|
|
||||||
|
|
||||||
|
def run_field_segmentation(
|
||||||
|
frames: list[Frame],
|
||||||
|
config: FieldSegmentationConfig,
|
||||||
|
inference_url: str | None = None,
|
||||||
|
job_id: str | None = None,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Run field segmentation on all frames via the inference server.
|
||||||
|
|
||||||
|
Returns dict with:
|
||||||
|
field_masks: {seq: np.ndarray}
|
||||||
|
field_boundaries: {seq: [(x,y), ...]}
|
||||||
|
field_coverage: {seq: float}
|
||||||
|
"""
|
||||||
|
if not config.enabled:
|
||||||
|
emit.log(job_id, "FieldSegmentation", "INFO", "Disabled, skipping")
|
||||||
|
return {"field_masks": {}, "field_boundaries": {}, "field_coverage": {}}
|
||||||
|
|
||||||
|
import os
|
||||||
|
url = inference_url or os.environ.get("INFERENCE_URL")
|
||||||
|
if not url:
|
||||||
|
emit.log(job_id, "FieldSegmentation", "WARNING",
|
||||||
|
"No INFERENCE_URL, skipping field segmentation")
|
||||||
|
return {"field_masks": {}, "field_boundaries": {}, "field_coverage": {}}
|
||||||
|
|
||||||
|
emit.log(job_id, "FieldSegmentation", "INFO",
|
||||||
|
f"Segmenting {len(frames)} frames (hue={config.hue_low}-{config.hue_high})")
|
||||||
|
|
||||||
|
from core.detect.inference import InferenceClient
|
||||||
|
from core.detect.emit import _run_log_level
|
||||||
|
client = InferenceClient(base_url=url, job_id=job_id or "", log_level=_run_log_level)
|
||||||
|
|
||||||
|
field_masks = {}
|
||||||
|
field_boundaries = {}
|
||||||
|
field_coverage = {}
|
||||||
|
|
||||||
|
for frame in frames:
|
||||||
|
image_b64 = _frame_to_b64(frame)
|
||||||
|
|
||||||
|
resp = client.post("/segment_field", {
|
||||||
|
"image": image_b64,
|
||||||
|
"hue_low": config.hue_low,
|
||||||
|
"hue_high": config.hue_high,
|
||||||
|
"sat_low": config.sat_low,
|
||||||
|
"sat_high": config.sat_high,
|
||||||
|
"val_low": config.val_low,
|
||||||
|
"val_high": config.val_high,
|
||||||
|
"morph_kernel": config.morph_kernel,
|
||||||
|
"min_area_ratio": config.min_area_ratio,
|
||||||
|
})
|
||||||
|
|
||||||
|
if resp is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
mask_b64 = resp.get("mask_b64", "")
|
||||||
|
if mask_b64:
|
||||||
|
field_masks[frame.sequence] = _decode_mask_b64(mask_b64)
|
||||||
|
|
||||||
|
field_boundaries[frame.sequence] = resp.get("boundary", [])
|
||||||
|
field_coverage[frame.sequence] = resp.get("coverage", 0.0)
|
||||||
|
|
||||||
|
avg_coverage = sum(field_coverage.values()) / max(len(field_coverage), 1)
|
||||||
|
emit.log(job_id, "FieldSegmentation", "INFO",
|
||||||
|
f"Done: {len(frames)} frames, avg coverage {avg_coverage:.1%}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"field_masks": field_masks,
|
||||||
|
"field_boundaries": field_boundaries,
|
||||||
|
"field_coverage": field_coverage,
|
||||||
|
}
|
||||||
@@ -16,9 +16,9 @@ import numpy as np
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from core.ffmpeg.probe import probe_file
|
from core.ffmpeg.probe import probe_file
|
||||||
from detect import emit
|
from core.detect import emit
|
||||||
from detect.models import Frame
|
from core.detect.models import Frame
|
||||||
from detect.profiles.base import FrameExtractionConfig
|
from core.detect.stages.models import FrameExtractionConfig
|
||||||
|
|
||||||
|
|
||||||
def _load_frames(tmpdir: Path, fps: float) -> list[Frame]:
|
def _load_frames(tmpdir: Path, fps: float) -> list[Frame]:
|
||||||
106
core/detect/stages/models.py
Normal file
106
core/detect/stages/models.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
"""
|
||||||
|
Pydantic Models - GENERATED FILE
|
||||||
|
|
||||||
|
Do not edit directly. Regenerate using modelgen.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
class StageConfigField(BaseModel):
|
||||||
|
"""A single tunable config parameter for the editor UI."""
|
||||||
|
name: str
|
||||||
|
type: str
|
||||||
|
default: Any
|
||||||
|
description: str = ""
|
||||||
|
min: Optional[float] = None
|
||||||
|
max: Optional[float] = None
|
||||||
|
options: Optional[List[str]] = None
|
||||||
|
|
||||||
|
class StageIO(BaseModel):
|
||||||
|
"""Declares what a stage reads and writes."""
|
||||||
|
reads: List[str] = Field(default_factory=list)
|
||||||
|
writes: List[str] = Field(default_factory=list)
|
||||||
|
optional_reads: List[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
class StageDefinition(BaseModel):
|
||||||
|
"""Complete metadata for a pipeline stage."""
|
||||||
|
name: str
|
||||||
|
label: str
|
||||||
|
description: str
|
||||||
|
category: str = "detection"
|
||||||
|
io: StageIO
|
||||||
|
config_fields: List[StageConfigField] = Field(default_factory=list)
|
||||||
|
tracks_element: Optional[str] = None
|
||||||
|
|
||||||
|
class FrameExtractionConfig(BaseModel):
|
||||||
|
"""FrameExtractionConfig(fps: float = 2.0, max_frames: int = 500)"""
|
||||||
|
fps: float = 2.0
|
||||||
|
max_frames: int = 500
|
||||||
|
|
||||||
|
class SceneFilterConfig(BaseModel):
|
||||||
|
"""SceneFilterConfig(hamming_threshold: int = 8, enabled: bool = True)"""
|
||||||
|
hamming_threshold: int = 8
|
||||||
|
enabled: bool = True
|
||||||
|
|
||||||
|
class DetectionConfig(BaseModel):
|
||||||
|
"""DetectionConfig(model_name: str = 'yolov8n.pt', confidence_threshold: float = 0.3, target_classes: List[str] = <factory>)"""
|
||||||
|
model_name: str = "yolov8n.pt"
|
||||||
|
confidence_threshold: float = 0.3
|
||||||
|
target_classes: List[str]
|
||||||
|
|
||||||
|
class OCRConfig(BaseModel):
|
||||||
|
"""OCRConfig(languages: List[str] = <factory>, min_confidence: float = 0.5)"""
|
||||||
|
languages: List[str]
|
||||||
|
min_confidence: float = 0.5
|
||||||
|
|
||||||
|
class ResolverConfig(BaseModel):
|
||||||
|
"""ResolverConfig(fuzzy_threshold: int = 75)"""
|
||||||
|
fuzzy_threshold: int = 75
|
||||||
|
|
||||||
|
class RegionAnalysisConfig(BaseModel):
|
||||||
|
"""RegionAnalysisConfig(enabled: bool = True, edge_canny_low: int = 50, edge_canny_high: int = 150, edge_hough_threshold: int = 80, edge_hough_min_length: int = 100, edge_hough_max_gap: int = 10, edge_pair_max_distance: int = 200, edge_pair_min_distance: int = 15)"""
|
||||||
|
enabled: bool = True
|
||||||
|
edge_canny_low: int = 50
|
||||||
|
edge_canny_high: int = 150
|
||||||
|
edge_hough_threshold: int = 80
|
||||||
|
edge_hough_min_length: int = 100
|
||||||
|
edge_hough_max_gap: int = 10
|
||||||
|
edge_pair_max_distance: int = 200
|
||||||
|
edge_pair_min_distance: int = 15
|
||||||
|
|
||||||
|
class FieldSegmentationConfig(BaseModel):
|
||||||
|
"""FieldSegmentationConfig(enabled: bool = True, hue_low: int = 30, hue_high: int = 85, sat_low: int = 30, sat_high: int = 255, val_low: int = 30, val_high: int = 255, morph_kernel: int = 15, min_area_ratio: float = 0.05)"""
|
||||||
|
enabled: bool = True
|
||||||
|
hue_low: int = 30
|
||||||
|
hue_high: int = 85
|
||||||
|
sat_low: int = 30
|
||||||
|
sat_high: int = 255
|
||||||
|
val_low: int = 30
|
||||||
|
val_high: int = 255
|
||||||
|
morph_kernel: int = 15
|
||||||
|
min_area_ratio: float = 0.05
|
||||||
|
|
||||||
|
class StageRef(BaseModel):
|
||||||
|
"""Reference to a stage in the pipeline graph."""
|
||||||
|
name: str
|
||||||
|
branch: str = "trunk"
|
||||||
|
execution_target: str = "local"
|
||||||
|
|
||||||
|
class Edge(BaseModel):
|
||||||
|
"""Connection between stages in the graph."""
|
||||||
|
source: str
|
||||||
|
target: str
|
||||||
|
condition: str = ""
|
||||||
|
|
||||||
|
class PipelineConfig(BaseModel):
|
||||||
|
"""Pipeline graph topology + routing rules."""
|
||||||
|
name: str
|
||||||
|
profile_name: str
|
||||||
|
stages: List[StageRef] = Field(default_factory=list)
|
||||||
|
edges: List[Edge] = Field(default_factory=list)
|
||||||
|
routing_rules: Dict[str, Any]
|
||||||
@@ -18,9 +18,9 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from detect import emit
|
from core.detect import emit
|
||||||
from detect.models import BoundingBox, Frame, TextCandidate
|
from core.detect.models import BoundingBox, Frame, TextCandidate
|
||||||
from detect.profiles.base import OCRConfig
|
from core.detect.stages.models import OCRConfig
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
pass
|
pass
|
||||||
@@ -91,8 +91,8 @@ def run_ocr(
|
|||||||
|
|
||||||
# Build these once per pipeline run, not per crop
|
# Build these once per pipeline run, not per crop
|
||||||
if inference_url:
|
if inference_url:
|
||||||
from detect.inference import InferenceClient
|
from core.detect.inference import InferenceClient
|
||||||
from detect.emit import _run_log_level
|
from core.detect.emit import _run_log_level
|
||||||
client = InferenceClient(base_url=inference_url, job_id=job_id or "", log_level=_run_log_level)
|
client = InferenceClient(base_url=inference_url, job_id=job_id or "", log_level=_run_log_level)
|
||||||
else:
|
else:
|
||||||
model = _get_local_model(config.languages[0])
|
model = _get_local_model(config.languages[0])
|
||||||
@@ -15,8 +15,8 @@ import logging
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from detect import emit
|
from core.detect import emit
|
||||||
from detect.models import BoundingBox, Frame
|
from core.detect.models import BoundingBox, Frame
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -124,5 +124,5 @@ def _preprocess_remote(crop: np.ndarray, inference_url: str,
|
|||||||
def _preprocess_local(crop: np.ndarray,
|
def _preprocess_local(crop: np.ndarray,
|
||||||
do_contrast: bool, do_deskew: bool, do_binarize: bool) -> np.ndarray:
|
do_contrast: bool, do_deskew: bool, do_binarize: bool) -> np.ndarray:
|
||||||
"""Run preprocessing in-process (requires opencv-python-headless)."""
|
"""Run preprocessing in-process (requires opencv-python-headless)."""
|
||||||
from gpu.models.preprocess import preprocess
|
from core.gpu.models.preprocess import preprocess
|
||||||
return preprocess(crop, do_binarize=do_binarize, do_deskew=do_deskew, do_contrast=do_contrast)
|
return preprocess(crop, do_binarize=do_binarize, do_deskew=do_deskew, do_contrast=do_contrast)
|
||||||
44
core/detect/stages/registry/cv_analysis.py
Normal file
44
core/detect/stages/registry/cv_analysis.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
"""Registration for CV analysis stages: edge detection."""
|
||||||
|
|
||||||
|
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
|
||||||
|
from core.detect.stages.base import register_stage
|
||||||
|
from ._serializers import serialize_dataclass_list, deserialize_bounding_box
|
||||||
|
|
||||||
|
|
||||||
|
def _ser_regions(state: dict, job_id: str) -> dict:
|
||||||
|
regions = state.get("edge_regions_by_frame", {})
|
||||||
|
serialized = {
|
||||||
|
str(seq): serialize_dataclass_list(bl) for seq, bl in regions.items()
|
||||||
|
}
|
||||||
|
return {"edge_regions_by_frame": serialized}
|
||||||
|
|
||||||
|
|
||||||
|
def _deser_regions(data: dict, job_id: str) -> dict:
|
||||||
|
regions = {}
|
||||||
|
for seq_str, box_dicts in data.get("edge_regions_by_frame", {}).items():
|
||||||
|
regions[int(seq_str)] = [deserialize_bounding_box(b) for b in box_dicts]
|
||||||
|
return {"edge_regions_by_frame": regions}
|
||||||
|
|
||||||
|
|
||||||
|
def register():
|
||||||
|
edge_detection = StageDefinition(
|
||||||
|
name="detect_edges",
|
||||||
|
label="Edge Detection",
|
||||||
|
description="Canny + HoughLinesP — find horizontal line pairs (hoarding boundaries)",
|
||||||
|
category="cv_analysis",
|
||||||
|
io=StageIO(
|
||||||
|
reads=["filtered_frames"],
|
||||||
|
writes=["edge_regions_by_frame"],
|
||||||
|
),
|
||||||
|
config_fields=[
|
||||||
|
StageConfigField(name="enabled", type="bool", default=True, description="Enable region analysis"),
|
||||||
|
StageConfigField(name="edge_canny_low", type="int", default=50, description="Canny low threshold", min=0, max=255),
|
||||||
|
StageConfigField(name="edge_canny_high", type="int", default=150, description="Canny high threshold", min=0, max=255),
|
||||||
|
StageConfigField(name="edge_hough_threshold", type="int", default=80, description="Hough accumulator threshold", min=1, max=500),
|
||||||
|
StageConfigField(name="edge_hough_min_length", type="int", default=100, description="Min line length (px)", min=10, max=2000),
|
||||||
|
StageConfigField(name="edge_hough_max_gap", type="int", default=10, description="Max line gap (px)", min=1, max=100),
|
||||||
|
StageConfigField(name="edge_pair_max_distance", type="int", default=200, description="Max distance between line pair (px)", min=10, max=500),
|
||||||
|
StageConfigField(name="edge_pair_min_distance", type="int", default=15, description="Min distance between line pair (px)", min=5, max=200),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
register_stage(edge_detection, serialize_fn=_ser_regions, deserialize_fn=_deser_regions)
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Registration for detection stages: YOLO, OCR."""
|
"""Registration for detection stages: YOLO, OCR."""
|
||||||
|
|
||||||
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
|
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
|
||||||
|
from core.detect.stages.base import register_stage
|
||||||
from ._serializers import (
|
from ._serializers import (
|
||||||
serialize_dataclass_list,
|
serialize_dataclass_list,
|
||||||
serialize_text_candidates,
|
serialize_text_candidates,
|
||||||
@@ -38,14 +39,12 @@ def register():
|
|||||||
category="detection",
|
category="detection",
|
||||||
io=StageIO(reads=["filtered_frames"], writes=["boxes_by_frame"]),
|
io=StageIO(reads=["filtered_frames"], writes=["boxes_by_frame"]),
|
||||||
config_fields=[
|
config_fields=[
|
||||||
StageConfigField("model_name", "str", "yolov8n.pt", "YOLO model file"),
|
StageConfigField(name="model_name", type="str", default="yolov8n.pt", description="YOLO model file"),
|
||||||
StageConfigField("confidence_threshold", "float", 0.3, "Min detection confidence", min=0.0, max=1.0),
|
StageConfigField(name="confidence_threshold", type="float", default=0.3, description="Min detection confidence", min=0.0, max=1.0),
|
||||||
StageConfigField("target_classes", "list[str]", [], "YOLO classes to detect (empty = all)"),
|
StageConfigField(name="target_classes", type="list[str]", default=[], description="YOLO classes to detect (empty = all)"),
|
||||||
],
|
],
|
||||||
serialize_fn=_ser_detect,
|
|
||||||
deserialize_fn=_deser_detect,
|
|
||||||
)
|
)
|
||||||
register_stage(yolo)
|
register_stage(yolo, serialize_fn=_ser_detect, deserialize_fn=_deser_detect)
|
||||||
|
|
||||||
ocr = StageDefinition(
|
ocr = StageDefinition(
|
||||||
name="run_ocr",
|
name="run_ocr",
|
||||||
@@ -54,10 +53,8 @@ def register():
|
|||||||
category="detection",
|
category="detection",
|
||||||
io=StageIO(reads=["filtered_frames", "boxes_by_frame"], writes=["text_candidates"]),
|
io=StageIO(reads=["filtered_frames", "boxes_by_frame"], writes=["text_candidates"]),
|
||||||
config_fields=[
|
config_fields=[
|
||||||
StageConfigField("languages", "list[str]", ["en"], "OCR languages"),
|
StageConfigField(name="languages", type="list[str]", default=["en"], description="OCR languages"),
|
||||||
StageConfigField("min_confidence", "float", 0.5, "Min OCR confidence", min=0.0, max=1.0),
|
StageConfigField(name="min_confidence", type="float", default=0.5, description="Min OCR confidence", min=0.0, max=1.0),
|
||||||
],
|
],
|
||||||
serialize_fn=_ser_ocr,
|
|
||||||
deserialize_fn=_deser_ocr,
|
|
||||||
)
|
)
|
||||||
register_stage(ocr)
|
register_stage(ocr, serialize_fn=_ser_ocr, deserialize_fn=_deser_ocr)
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Registration for escalation stages: local VLM, cloud LLM."""
|
"""Registration for escalation stages: local VLM, cloud LLM."""
|
||||||
|
|
||||||
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
|
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
|
||||||
|
from core.detect.stages.base import register_stage
|
||||||
from ._serializers import (
|
from ._serializers import (
|
||||||
serialize_dataclass_list,
|
serialize_dataclass_list,
|
||||||
serialize_text_candidates,
|
serialize_text_candidates,
|
||||||
@@ -37,12 +38,10 @@ def register():
|
|||||||
optional_reads=["source_asset_id"],
|
optional_reads=["source_asset_id"],
|
||||||
),
|
),
|
||||||
config_fields=[
|
config_fields=[
|
||||||
StageConfigField("min_confidence", "float", 0.5, "Min VLM confidence", min=0.0, max=1.0),
|
StageConfigField(name="min_confidence", type="float", default=0.5, description="Min VLM confidence", min=0.0, max=1.0),
|
||||||
],
|
],
|
||||||
serialize_fn=_ser_escalation,
|
|
||||||
deserialize_fn=_deser_escalation,
|
|
||||||
)
|
)
|
||||||
register_stage(vlm)
|
register_stage(vlm, serialize_fn=_ser_escalation, deserialize_fn=_deser_escalation)
|
||||||
|
|
||||||
cloud = StageDefinition(
|
cloud = StageDefinition(
|
||||||
name="escalate_cloud",
|
name="escalate_cloud",
|
||||||
@@ -55,9 +54,7 @@ def register():
|
|||||||
optional_reads=["source_asset_id"],
|
optional_reads=["source_asset_id"],
|
||||||
),
|
),
|
||||||
config_fields=[
|
config_fields=[
|
||||||
StageConfigField("min_confidence", "float", 0.4, "Min cloud confidence", min=0.0, max=1.0),
|
StageConfigField(name="min_confidence", type="float", default=0.4, description="Min cloud confidence", min=0.0, max=1.0),
|
||||||
],
|
],
|
||||||
serialize_fn=_ser_escalation,
|
|
||||||
deserialize_fn=_deser_escalation,
|
|
||||||
)
|
)
|
||||||
register_stage(cloud)
|
register_stage(cloud, serialize_fn=_ser_escalation, deserialize_fn=_deser_escalation)
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Registration for output stages: report compilation."""
|
"""Registration for output stages: report compilation."""
|
||||||
|
|
||||||
from detect.stages.base import StageDefinition, StageIO, register_stage
|
from core.detect.stages.base import StageDefinition, StageIO, register_stage
|
||||||
from ._serializers import serialize_dataclass, deserialize_detection_report
|
from ._serializers import serialize_dataclass, deserialize_detection_report
|
||||||
|
|
||||||
|
|
||||||
@@ -26,7 +26,5 @@ def register():
|
|||||||
category="output",
|
category="output",
|
||||||
io=StageIO(reads=["detections"], writes=["report"]),
|
io=StageIO(reads=["detections"], writes=["report"]),
|
||||||
config_fields=[],
|
config_fields=[],
|
||||||
serialize_fn=_ser_report,
|
|
||||||
deserialize_fn=_deser_report,
|
|
||||||
)
|
)
|
||||||
register_stage(report)
|
register_stage(report, serialize_fn=_ser_report, deserialize_fn=_deser_report)
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Registration for preprocessing stages: frame extraction, scene filter, image preprocessing."""
|
"""Registration for preprocessing stages: frame extraction, scene filter, image preprocessing."""
|
||||||
|
|
||||||
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
|
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
|
||||||
|
from core.detect.stages.base import register_stage
|
||||||
from ._serializers import serialize_frames, deserialize_frames
|
from ._serializers import serialize_frames, deserialize_frames
|
||||||
|
|
||||||
|
|
||||||
@@ -44,13 +45,11 @@ def register():
|
|||||||
category="preprocessing",
|
category="preprocessing",
|
||||||
io=StageIO(reads=["video_path"], writes=["frames"]),
|
io=StageIO(reads=["video_path"], writes=["frames"]),
|
||||||
config_fields=[
|
config_fields=[
|
||||||
StageConfigField("fps", "float", 2.0, "Frames per second", min=0.1, max=30.0),
|
StageConfigField(name="fps", type="float", default=2.0, description="Frames per second", min=0.1, max=30.0),
|
||||||
StageConfigField("max_frames", "int", 500, "Maximum frames to extract", min=1, max=10000),
|
StageConfigField(name="max_frames", type="int", default=500, description="Maximum frames to extract", min=1, max=10000),
|
||||||
],
|
],
|
||||||
serialize_fn=_ser_extract,
|
|
||||||
deserialize_fn=_deser_extract,
|
|
||||||
)
|
)
|
||||||
register_stage(extract)
|
register_stage(extract, serialize_fn=_ser_extract, deserialize_fn=_deser_extract)
|
||||||
|
|
||||||
scene_filter = StageDefinition(
|
scene_filter = StageDefinition(
|
||||||
name="filter_scenes",
|
name="filter_scenes",
|
||||||
@@ -59,13 +58,11 @@ def register():
|
|||||||
category="preprocessing",
|
category="preprocessing",
|
||||||
io=StageIO(reads=["frames"], writes=["filtered_frames"]),
|
io=StageIO(reads=["frames"], writes=["filtered_frames"]),
|
||||||
config_fields=[
|
config_fields=[
|
||||||
StageConfigField("hamming_threshold", "int", 8, "Hamming distance threshold", min=0, max=64),
|
StageConfigField(name="hamming_threshold", type="int", default=8, description="Hamming distance threshold", min=0, max=64),
|
||||||
StageConfigField("enabled", "bool", True, "Enable scene filtering"),
|
StageConfigField(name="enabled", type="bool", default=True, description="Enable scene filtering"),
|
||||||
],
|
],
|
||||||
serialize_fn=_ser_filter,
|
|
||||||
deserialize_fn=_deser_filter,
|
|
||||||
)
|
)
|
||||||
register_stage(scene_filter)
|
register_stage(scene_filter, serialize_fn=_ser_filter, deserialize_fn=_deser_filter)
|
||||||
|
|
||||||
preprocess = StageDefinition(
|
preprocess = StageDefinition(
|
||||||
name="preprocess",
|
name="preprocess",
|
||||||
@@ -77,11 +74,9 @@ def register():
|
|||||||
writes=["preprocessed_crops"],
|
writes=["preprocessed_crops"],
|
||||||
),
|
),
|
||||||
config_fields=[
|
config_fields=[
|
||||||
StageConfigField("contrast", "bool", True, "CLAHE contrast enhancement"),
|
StageConfigField(name="contrast", type="bool", default=True, description="CLAHE contrast enhancement"),
|
||||||
StageConfigField("deskew", "bool", False, "Correct slight rotation"),
|
StageConfigField(name="deskew", type="bool", default=False, description="Correct slight rotation"),
|
||||||
StageConfigField("binarize", "bool", False, "Otsu binarization"),
|
StageConfigField(name="binarize", type="bool", default=False, description="Otsu binarization"),
|
||||||
],
|
],
|
||||||
serialize_fn=_ser_preprocess,
|
|
||||||
deserialize_fn=_deser_preprocess,
|
|
||||||
)
|
)
|
||||||
register_stage(preprocess)
|
register_stage(preprocess, serialize_fn=_ser_preprocess, deserialize_fn=_deser_preprocess)
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Registration for resolution stages: brand resolver."""
|
"""Registration for resolution stages: brand resolver."""
|
||||||
|
|
||||||
from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage
|
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
|
||||||
|
from core.detect.stages.base import register_stage
|
||||||
from ._serializers import (
|
from ._serializers import (
|
||||||
serialize_dataclass_list,
|
serialize_dataclass_list,
|
||||||
serialize_text_candidates,
|
serialize_text_candidates,
|
||||||
@@ -37,9 +38,7 @@ def register():
|
|||||||
optional_reads=["session_brands", "source_asset_id"],
|
optional_reads=["session_brands", "source_asset_id"],
|
||||||
),
|
),
|
||||||
config_fields=[
|
config_fields=[
|
||||||
StageConfigField("fuzzy_threshold", "int", 75, "Fuzzy match threshold", min=0, max=100),
|
StageConfigField(name="fuzzy_threshold", type="int", default=75, description="Fuzzy match threshold", min=0, max=100),
|
||||||
],
|
],
|
||||||
serialize_fn=_ser_brands,
|
|
||||||
deserialize_fn=_deser_brands,
|
|
||||||
)
|
)
|
||||||
register_stage(resolver)
|
register_stage(resolver, serialize_fn=_ser_brands, deserialize_fn=_deser_brands)
|
||||||
@@ -14,9 +14,9 @@ import time
|
|||||||
import imagehash
|
import imagehash
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from detect import emit
|
from core.detect import emit
|
||||||
from detect.models import Frame
|
from core.detect.models import Frame
|
||||||
from detect.profiles.base import SceneFilterConfig
|
from core.detect.stages.models import SceneFilterConfig
|
||||||
|
|
||||||
|
|
||||||
def _compute_hashes(frames: list[Frame]) -> list[imagehash.ImageHash]:
|
def _compute_hashes(frames: list[Frame]) -> list[imagehash.ImageHash]:
|
||||||
@@ -19,10 +19,10 @@ import time
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from detect import emit
|
from core.detect import emit
|
||||||
from detect.models import BrandDetection, PipelineStats, TextCandidate
|
from core.detect.models import BrandDetection, PipelineStats, TextCandidate
|
||||||
from detect.profiles.base import CropContext
|
from core.detect.models import CropContext
|
||||||
from detect.providers import get_provider, has_api_key
|
from core.detect.providers import get_provider, has_api_key
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -33,7 +33,7 @@ def _register_discovered_brand(brand: str, source_asset_id: str | None,
|
|||||||
timestamp: float, confidence: float):
|
timestamp: float, confidence: float):
|
||||||
"""Register a cloud-confirmed brand in the DB."""
|
"""Register a cloud-confirmed brand in the DB."""
|
||||||
try:
|
try:
|
||||||
from detect.stages.brand_resolver import _register_brand, _record_sighting
|
from core.detect.stages.brand_resolver import _register_brand, _record_sighting
|
||||||
brand_id = _register_brand(brand, "cloud_llm")
|
brand_id = _register_brand(brand, "cloud_llm")
|
||||||
if brand_id and source_asset_id:
|
if brand_id and source_asset_id:
|
||||||
_record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, "cloud_llm")
|
_record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, "cloud_llm")
|
||||||
@@ -14,9 +14,9 @@ import time
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from detect import emit
|
from core.detect import emit
|
||||||
from detect.models import BrandDetection, TextCandidate
|
from core.detect.models import BrandDetection, TextCandidate
|
||||||
from detect.profiles.base import CropContext
|
from core.detect.models import CropContext
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -25,7 +25,7 @@ def _register_discovered_brand(brand: str, source_asset_id: str | None,
|
|||||||
timestamp: float, confidence: float, source: str):
|
timestamp: float, confidence: float, source: str):
|
||||||
"""Register a VLM-confirmed brand in the DB."""
|
"""Register a VLM-confirmed brand in the DB."""
|
||||||
try:
|
try:
|
||||||
from detect.stages.brand_resolver import _register_brand, _record_sighting
|
from core.detect.stages.brand_resolver import _register_brand, _record_sighting
|
||||||
brand_id = _register_brand(brand, source)
|
brand_id = _register_brand(brand, source)
|
||||||
if brand_id and source_asset_id:
|
if brand_id and source_asset_id:
|
||||||
_record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, source)
|
_record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, source)
|
||||||
@@ -75,8 +75,8 @@ def escalate_vlm(
|
|||||||
still_unresolved: list[TextCandidate] = []
|
still_unresolved: list[TextCandidate] = []
|
||||||
|
|
||||||
if inference_url:
|
if inference_url:
|
||||||
from detect.inference import InferenceClient
|
from core.detect.inference import InferenceClient
|
||||||
from detect.emit import _run_log_level
|
from core.detect.emit import _run_log_level
|
||||||
client = InferenceClient(base_url=inference_url, job_id=job_id or "", log_level=_run_log_level)
|
client = InferenceClient(base_url=inference_url, job_id=job_id or "", log_level=_run_log_level)
|
||||||
|
|
||||||
for i, candidate in enumerate(candidates):
|
for i, candidate in enumerate(candidates):
|
||||||
@@ -152,6 +152,6 @@ def escalate_vlm(
|
|||||||
|
|
||||||
def _vlm_local(crop: np.ndarray, prompt: str) -> tuple[str, float, str]:
|
def _vlm_local(crop: np.ndarray, prompt: str) -> tuple[str, float, str]:
|
||||||
"""Run moondream2 in-process (single-box mode)."""
|
"""Run moondream2 in-process (single-box mode)."""
|
||||||
from gpu.models.vlm import query
|
from core.gpu.models.vlm import query
|
||||||
result = query(crop, prompt)
|
result = query(crop, prompt)
|
||||||
return result["brand"], result["confidence"], result["reasoning"]
|
return result["brand"], result["confidence"], result["reasoning"]
|
||||||
@@ -18,9 +18,9 @@ import time
|
|||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from detect import emit
|
from core.detect import emit
|
||||||
from detect.models import BoundingBox, Frame
|
from core.detect.models import BoundingBox, Frame
|
||||||
from detect.profiles.base import DetectionConfig
|
from core.detect.stages.models import DetectionConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -36,7 +36,7 @@ def _frame_to_b64(frame: Frame) -> str:
|
|||||||
def _detect_remote(frame: Frame, config: DetectionConfig, inference_url: str,
|
def _detect_remote(frame: Frame, config: DetectionConfig, inference_url: str,
|
||||||
job_id: str = "", log_level: str = "INFO") -> list[BoundingBox]:
|
job_id: str = "", log_level: str = "INFO") -> list[BoundingBox]:
|
||||||
"""Call the inference server over HTTP."""
|
"""Call the inference server over HTTP."""
|
||||||
from detect.inference import InferenceClient
|
from core.detect.inference import InferenceClient
|
||||||
client = InferenceClient(base_url=inference_url, job_id=job_id, log_level=log_level)
|
client = InferenceClient(base_url=inference_url, job_id=job_id, log_level=log_level)
|
||||||
results = client.detect(
|
results = client.detect(
|
||||||
image=frame.image,
|
image=frame.image,
|
||||||
@@ -104,7 +104,7 @@ def detect_objects(
|
|||||||
for i, frame in enumerate(frames):
|
for i, frame in enumerate(frames):
|
||||||
t0 = time.monotonic()
|
t0 = time.monotonic()
|
||||||
if inference_url:
|
if inference_url:
|
||||||
from detect.emit import _run_log_level
|
from core.detect.emit import _run_log_level
|
||||||
boxes = _detect_remote(frame, config, inference_url,
|
boxes = _detect_remote(frame, config, inference_url,
|
||||||
job_id=job_id or "", log_level=_run_log_level)
|
job_id=job_id or "", log_level=_run_log_level)
|
||||||
else:
|
else:
|
||||||
@@ -9,7 +9,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
|
|
||||||
from detect.models import BoundingBox, BrandDetection, DetectionReport, Frame, PipelineStats, TextCandidate
|
from core.detect.models import BoundingBox, BrandDetection, DetectionReport, Frame, PipelineStats, TextCandidate
|
||||||
|
|
||||||
|
|
||||||
class DetectState(TypedDict, total=False):
|
class DetectState(TypedDict, total=False):
|
||||||
@@ -22,6 +22,9 @@ class DetectState(TypedDict, total=False):
|
|||||||
# Stage outputs
|
# Stage outputs
|
||||||
frames: list[Frame]
|
frames: list[Frame]
|
||||||
filtered_frames: list[Frame]
|
filtered_frames: list[Frame]
|
||||||
|
field_masks: dict # {seq: np.ndarray} — pitch mask per frame
|
||||||
|
field_boundaries: dict # {seq: [(x,y), ...]} — pitch boundary per frame
|
||||||
|
field_coverage: dict # {seq: float} — pitch coverage ratio per frame
|
||||||
edge_regions_by_frame: dict[int, list[BoundingBox]]
|
edge_regions_by_frame: dict[int, list[BoundingBox]]
|
||||||
boxes_by_frame: dict[int, list[BoundingBox]]
|
boxes_by_frame: dict[int, list[BoundingBox]]
|
||||||
preprocessed_crops: dict # "{frame_seq}_{box_idx}" → np.ndarray
|
preprocessed_crops: dict # "{frame_seq}_{box_idx}" → np.ndarray
|
||||||
@@ -36,5 +39,5 @@ class DetectState(TypedDict, total=False):
|
|||||||
# Running stats (updated by each stage)
|
# Running stats (updated by each stage)
|
||||||
stats: PipelineStats
|
stats: PipelineStats
|
||||||
|
|
||||||
# Config overrides for replay (applied via OverrideProfile)
|
# Config overrides for replay (merged into profile configs dict)
|
||||||
config_overrides: dict
|
config_overrides: dict
|
||||||
@@ -6,7 +6,7 @@ and stage-level metadata. The Langfuse client is optional — if not configured
|
|||||||
(no LANGFUSE_SECRET_KEY), tracing is a no-op.
|
(no LANGFUSE_SECRET_KEY), tracing is a no-op.
|
||||||
|
|
||||||
Usage in graph nodes:
|
Usage in graph nodes:
|
||||||
from detect.tracing import trace_node
|
from core.detect.tracing import trace_node
|
||||||
|
|
||||||
def node_extract_frames(state):
|
def node_extract_frames(state):
|
||||||
with trace_node(state, "extract_frames") as span:
|
with trace_node(state, "extract_frames") as span:
|
||||||
@@ -1,13 +1,6 @@
|
|||||||
from .capabilities import get_decoders, get_encoders, get_formats
|
|
||||||
from .probe import ProbeResult, probe_file
|
from .probe import ProbeResult, probe_file
|
||||||
from .transcode import TranscodeConfig, transcode
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"probe_file",
|
"probe_file",
|
||||||
"ProbeResult",
|
"ProbeResult",
|
||||||
"transcode",
|
|
||||||
"TranscodeConfig",
|
|
||||||
"get_encoders",
|
|
||||||
"get_decoders",
|
|
||||||
"get_formats",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,145 +0,0 @@
|
|||||||
"""
|
|
||||||
FFmpeg capabilities - Discover available codecs and formats using ffmpeg-python.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from functools import lru_cache
|
|
||||||
from typing import Any, Dict, List
|
|
||||||
|
|
||||||
import ffmpeg
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Codec:
|
|
||||||
"""An FFmpeg encoder or decoder."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
type: str # 'video' or 'audio'
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Format:
|
|
||||||
"""An FFmpeg format (muxer/demuxer)."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
can_demux: bool
|
|
||||||
can_mux: bool
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
|
||||||
def _get_ffmpeg_info() -> Dict[str, Any]:
|
|
||||||
"""Get FFmpeg capabilities info."""
|
|
||||||
# ffmpeg-python doesn't have a direct way to get codecs/formats
|
|
||||||
# but we can use probe on a dummy or parse -codecs output
|
|
||||||
# For now, return common codecs that are typically available
|
|
||||||
return {
|
|
||||||
"video_encoders": [
|
|
||||||
{"name": "libx264", "description": "H.264 / AVC"},
|
|
||||||
{"name": "libx265", "description": "H.265 / HEVC"},
|
|
||||||
{"name": "mpeg4", "description": "MPEG-4 Part 2"},
|
|
||||||
{"name": "libvpx", "description": "VP8"},
|
|
||||||
{"name": "libvpx-vp9", "description": "VP9"},
|
|
||||||
{"name": "h264_nvenc", "description": "NVIDIA NVENC H.264"},
|
|
||||||
{"name": "hevc_nvenc", "description": "NVIDIA NVENC H.265"},
|
|
||||||
{"name": "h264_vaapi", "description": "VAAPI H.264"},
|
|
||||||
{"name": "prores_ks", "description": "Apple ProRes"},
|
|
||||||
{"name": "dnxhd", "description": "Avid DNxHD/DNxHR"},
|
|
||||||
{"name": "copy", "description": "Stream copy (no encoding)"},
|
|
||||||
],
|
|
||||||
"audio_encoders": [
|
|
||||||
{"name": "aac", "description": "AAC"},
|
|
||||||
{"name": "libmp3lame", "description": "MP3"},
|
|
||||||
{"name": "libopus", "description": "Opus"},
|
|
||||||
{"name": "libvorbis", "description": "Vorbis"},
|
|
||||||
{"name": "pcm_s16le", "description": "PCM signed 16-bit little-endian"},
|
|
||||||
{"name": "flac", "description": "FLAC"},
|
|
||||||
{"name": "copy", "description": "Stream copy (no encoding)"},
|
|
||||||
],
|
|
||||||
"formats": [
|
|
||||||
{"name": "mp4", "description": "MP4", "can_demux": True, "can_mux": True},
|
|
||||||
{
|
|
||||||
"name": "mov",
|
|
||||||
"description": "QuickTime / MOV",
|
|
||||||
"can_demux": True,
|
|
||||||
"can_mux": True,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "mkv",
|
|
||||||
"description": "Matroska",
|
|
||||||
"can_demux": True,
|
|
||||||
"can_mux": True,
|
|
||||||
},
|
|
||||||
{"name": "webm", "description": "WebM", "can_demux": True, "can_mux": True},
|
|
||||||
{"name": "avi", "description": "AVI", "can_demux": True, "can_mux": True},
|
|
||||||
{"name": "flv", "description": "FLV", "can_demux": True, "can_mux": True},
|
|
||||||
{
|
|
||||||
"name": "ts",
|
|
||||||
"description": "MPEG-TS",
|
|
||||||
"can_demux": True,
|
|
||||||
"can_mux": True,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "mpegts",
|
|
||||||
"description": "MPEG-TS",
|
|
||||||
"can_demux": True,
|
|
||||||
"can_mux": True,
|
|
||||||
},
|
|
||||||
{"name": "hls", "description": "HLS", "can_demux": True, "can_mux": True},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_encoders() -> List[Codec]:
|
|
||||||
"""Get available encoders (video + audio)."""
|
|
||||||
info = _get_ffmpeg_info()
|
|
||||||
codecs = []
|
|
||||||
|
|
||||||
for c in info["video_encoders"]:
|
|
||||||
codecs.append(Codec(name=c["name"], description=c["description"], type="video"))
|
|
||||||
|
|
||||||
for c in info["audio_encoders"]:
|
|
||||||
codecs.append(Codec(name=c["name"], description=c["description"], type="audio"))
|
|
||||||
|
|
||||||
return codecs
|
|
||||||
|
|
||||||
|
|
||||||
def get_decoders() -> List[Codec]:
|
|
||||||
"""Get available decoders."""
|
|
||||||
# Most encoders can also decode
|
|
||||||
return get_encoders()
|
|
||||||
|
|
||||||
|
|
||||||
def get_formats() -> List[Format]:
|
|
||||||
"""Get available formats."""
|
|
||||||
info = _get_ffmpeg_info()
|
|
||||||
return [
|
|
||||||
Format(
|
|
||||||
name=f["name"],
|
|
||||||
description=f["description"],
|
|
||||||
can_demux=f["can_demux"],
|
|
||||||
can_mux=f["can_mux"],
|
|
||||||
)
|
|
||||||
for f in info["formats"]
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def get_video_encoders() -> List[Codec]:
|
|
||||||
"""Get available video encoders."""
|
|
||||||
return [c for c in get_encoders() if c.type == "video"]
|
|
||||||
|
|
||||||
|
|
||||||
def get_audio_encoders() -> List[Codec]:
|
|
||||||
"""Get available audio encoders."""
|
|
||||||
return [c for c in get_encoders() if c.type == "audio"]
|
|
||||||
|
|
||||||
|
|
||||||
def get_muxers() -> List[Format]:
|
|
||||||
"""Get available output formats (muxers)."""
|
|
||||||
return [f for f in get_formats() if f.can_mux]
|
|
||||||
|
|
||||||
|
|
||||||
def get_demuxers() -> List[Format]:
|
|
||||||
"""Get available input formats (demuxers)."""
|
|
||||||
return [f for f in get_formats() if f.can_demux]
|
|
||||||
@@ -1,225 +0,0 @@
|
|||||||
"""
|
|
||||||
FFmpeg transcode module - Transcode media files using ffmpeg-python.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
|
||||||
|
|
||||||
import ffmpeg
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TranscodeConfig:
|
|
||||||
"""Configuration for a transcode operation."""
|
|
||||||
|
|
||||||
input_path: str
|
|
||||||
output_path: str
|
|
||||||
|
|
||||||
# Video
|
|
||||||
video_codec: str = "libx264"
|
|
||||||
video_bitrate: Optional[str] = None
|
|
||||||
video_crf: Optional[int] = None
|
|
||||||
video_preset: Optional[str] = None
|
|
||||||
resolution: Optional[str] = None
|
|
||||||
framerate: Optional[float] = None
|
|
||||||
|
|
||||||
# Audio
|
|
||||||
audio_codec: str = "aac"
|
|
||||||
audio_bitrate: Optional[str] = None
|
|
||||||
audio_channels: Optional[int] = None
|
|
||||||
audio_samplerate: Optional[int] = None
|
|
||||||
|
|
||||||
# Trimming
|
|
||||||
trim_start: Optional[float] = None
|
|
||||||
trim_end: Optional[float] = None
|
|
||||||
|
|
||||||
# Container
|
|
||||||
container: str = "mp4"
|
|
||||||
|
|
||||||
# Extra args (key-value pairs)
|
|
||||||
extra_args: List[str] = field(default_factory=list)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_copy(self) -> bool:
|
|
||||||
"""Check if this is a stream copy (no transcoding)."""
|
|
||||||
return self.video_codec == "copy" and self.audio_codec == "copy"
|
|
||||||
|
|
||||||
|
|
||||||
def build_stream(config: TranscodeConfig):
|
|
||||||
"""
|
|
||||||
Build an ffmpeg-python stream from config.
|
|
||||||
|
|
||||||
Returns the stream object ready to run.
|
|
||||||
"""
|
|
||||||
# Input options
|
|
||||||
input_kwargs = {}
|
|
||||||
if config.trim_start is not None:
|
|
||||||
input_kwargs["ss"] = config.trim_start
|
|
||||||
|
|
||||||
stream = ffmpeg.input(config.input_path, **input_kwargs)
|
|
||||||
|
|
||||||
# Output options
|
|
||||||
output_kwargs = {
|
|
||||||
"vcodec": config.video_codec,
|
|
||||||
"acodec": config.audio_codec,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Trimming duration
|
|
||||||
if config.trim_end is not None:
|
|
||||||
if config.trim_start is not None:
|
|
||||||
output_kwargs["t"] = config.trim_end - config.trim_start
|
|
||||||
else:
|
|
||||||
output_kwargs["t"] = config.trim_end
|
|
||||||
|
|
||||||
# Video options (skip if copy)
|
|
||||||
if config.video_codec != "copy":
|
|
||||||
if config.video_crf is not None:
|
|
||||||
output_kwargs["crf"] = config.video_crf
|
|
||||||
elif config.video_bitrate:
|
|
||||||
output_kwargs["video_bitrate"] = config.video_bitrate
|
|
||||||
|
|
||||||
if config.video_preset:
|
|
||||||
output_kwargs["preset"] = config.video_preset
|
|
||||||
|
|
||||||
if config.resolution:
|
|
||||||
output_kwargs["s"] = config.resolution
|
|
||||||
|
|
||||||
if config.framerate:
|
|
||||||
output_kwargs["r"] = config.framerate
|
|
||||||
|
|
||||||
# Audio options (skip if copy)
|
|
||||||
if config.audio_codec != "copy":
|
|
||||||
if config.audio_bitrate:
|
|
||||||
output_kwargs["audio_bitrate"] = config.audio_bitrate
|
|
||||||
if config.audio_channels:
|
|
||||||
output_kwargs["ac"] = config.audio_channels
|
|
||||||
if config.audio_samplerate:
|
|
||||||
output_kwargs["ar"] = config.audio_samplerate
|
|
||||||
|
|
||||||
# Parse extra args into kwargs
|
|
||||||
extra_kwargs = parse_extra_args(config.extra_args)
|
|
||||||
output_kwargs.update(extra_kwargs)
|
|
||||||
|
|
||||||
stream = ffmpeg.output(stream, config.output_path, **output_kwargs)
|
|
||||||
stream = ffmpeg.overwrite_output(stream)
|
|
||||||
|
|
||||||
return stream
|
|
||||||
|
|
||||||
|
|
||||||
def parse_extra_args(extra_args: List[str]) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Parse extra args list into kwargs dict.
|
|
||||||
|
|
||||||
["-vtag", "xvid", "-pix_fmt", "yuv420p"] -> {"vtag": "xvid", "pix_fmt": "yuv420p"}
|
|
||||||
"""
|
|
||||||
kwargs = {}
|
|
||||||
i = 0
|
|
||||||
while i < len(extra_args):
|
|
||||||
key = extra_args[i].lstrip("-")
|
|
||||||
if i + 1 < len(extra_args) and not extra_args[i + 1].startswith("-"):
|
|
||||||
kwargs[key] = extra_args[i + 1]
|
|
||||||
i += 2
|
|
||||||
else:
|
|
||||||
# Flag without value
|
|
||||||
kwargs[key] = None
|
|
||||||
i += 1
|
|
||||||
return kwargs
|
|
||||||
|
|
||||||
|
|
||||||
def transcode(
|
|
||||||
config: TranscodeConfig,
|
|
||||||
duration: Optional[float] = None,
|
|
||||||
progress_callback: Optional[Callable[[float, Dict[str, Any]], None]] = None,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Transcode a media file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config: Transcode configuration
|
|
||||||
duration: Total duration in seconds (for progress calculation, optional)
|
|
||||||
progress_callback: Called with (percent, details_dict) - requires duration
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if successful
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ffmpeg.Error: If transcoding fails
|
|
||||||
"""
|
|
||||||
# Ensure output directory exists
|
|
||||||
Path(config.output_path).parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
stream = build_stream(config)
|
|
||||||
|
|
||||||
if progress_callback and duration:
|
|
||||||
# Run with progress tracking using run_async
|
|
||||||
return _run_with_progress(stream, config, duration, progress_callback)
|
|
||||||
else:
|
|
||||||
# Run synchronously
|
|
||||||
ffmpeg.run(stream, capture_stdout=True, capture_stderr=True)
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def _run_with_progress(
|
|
||||||
stream,
|
|
||||||
config: TranscodeConfig,
|
|
||||||
duration: float,
|
|
||||||
progress_callback: Callable[[float, Dict[str, Any]], None],
|
|
||||||
) -> bool:
|
|
||||||
"""Run FFmpeg with progress tracking using run_async and stderr parsing."""
|
|
||||||
import re
|
|
||||||
|
|
||||||
# Calculate effective duration
|
|
||||||
effective_duration = duration
|
|
||||||
if config.trim_start and config.trim_end:
|
|
||||||
effective_duration = config.trim_end - config.trim_start
|
|
||||||
elif config.trim_end:
|
|
||||||
effective_duration = config.trim_end
|
|
||||||
elif config.trim_start:
|
|
||||||
effective_duration = duration - config.trim_start
|
|
||||||
|
|
||||||
# Run async to get process handle
|
|
||||||
process = ffmpeg.run_async(stream, pipe_stdout=True, pipe_stderr=True)
|
|
||||||
|
|
||||||
# Parse stderr for progress (time=HH:MM:SS.ms pattern)
|
|
||||||
time_pattern = re.compile(r"time=(\d+):(\d+):(\d+)\.(\d+)")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
line = process.stderr.readline()
|
|
||||||
if not line:
|
|
||||||
break
|
|
||||||
|
|
||||||
line = line.decode("utf-8", errors="ignore")
|
|
||||||
match = time_pattern.search(line)
|
|
||||||
if match:
|
|
||||||
hours = int(match.group(1))
|
|
||||||
minutes = int(match.group(2))
|
|
||||||
seconds = int(match.group(3))
|
|
||||||
ms = int(match.group(4))
|
|
||||||
|
|
||||||
current_time = hours * 3600 + minutes * 60 + seconds + ms / 100
|
|
||||||
percent = min(100.0, (current_time / effective_duration) * 100)
|
|
||||||
|
|
||||||
progress_callback(
|
|
||||||
percent,
|
|
||||||
{
|
|
||||||
"time": current_time,
|
|
||||||
"percent": percent,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Wait for completion
|
|
||||||
process.wait()
|
|
||||||
|
|
||||||
if process.returncode != 0:
|
|
||||||
raise ffmpeg.Error(
|
|
||||||
"ffmpeg", stdout=process.stdout.read(), stderr=process.stderr.read()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Final callback
|
|
||||||
progress_callback(
|
|
||||||
100.0, {"time": effective_duration, "percent": 100.0, "done": True}
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
|
||||||
6
core/gpu/models/__init__.py
Normal file
6
core/gpu/models/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
# GPU models — standalone container imports.
|
||||||
|
# When running as a container (cd gpu && python server.py), bare imports work.
|
||||||
|
# When imported from the main app (core.gpu.models.preprocess), only
|
||||||
|
# individual modules should be imported directly, not this __init__.
|
||||||
|
#
|
||||||
|
# The server.py imports detect/ocr/vlm directly, not through this file.
|
||||||
86
core/gpu/models/cv/segmentation.py
Normal file
86
core/gpu/models/cv/segmentation.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
"""
|
||||||
|
Field segmentation — HSV green mask → pitch boundary contour.
|
||||||
|
|
||||||
|
Pure OpenCV. Called by the inference server endpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def segment_field(
|
||||||
|
image: np.ndarray,
|
||||||
|
hue_low: int = 30,
|
||||||
|
hue_high: int = 85,
|
||||||
|
sat_low: int = 30,
|
||||||
|
sat_high: int = 255,
|
||||||
|
val_low: int = 30,
|
||||||
|
val_high: int = 255,
|
||||||
|
morph_kernel: int = 15,
|
||||||
|
min_area_ratio: float = 0.05,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Detect the pitch area using HSV green thresholding.
|
||||||
|
|
||||||
|
Returns dict with:
|
||||||
|
boundary: list of [x, y] points
|
||||||
|
coverage: float (fraction of frame)
|
||||||
|
"""
|
||||||
|
hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
|
||||||
|
|
||||||
|
lower = np.array([hue_low, sat_low, val_low])
|
||||||
|
upper = np.array([hue_high, sat_high, val_high])
|
||||||
|
mask = cv2.inRange(hsv, lower, upper)
|
||||||
|
|
||||||
|
k = morph_kernel
|
||||||
|
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
|
||||||
|
mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
|
||||||
|
mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
|
||||||
|
|
||||||
|
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||||
|
|
||||||
|
h, w = image.shape[:2]
|
||||||
|
min_area = min_area_ratio * h * w
|
||||||
|
boundary = []
|
||||||
|
coverage = 0.0
|
||||||
|
|
||||||
|
if contours:
|
||||||
|
large = [c for c in contours if cv2.contourArea(c) >= min_area]
|
||||||
|
if large:
|
||||||
|
pitch_contour = max(large, key=cv2.contourArea)
|
||||||
|
boundary = pitch_contour.reshape(-1, 2).tolist()
|
||||||
|
coverage = cv2.contourArea(pitch_contour) / (h * w)
|
||||||
|
|
||||||
|
refined = np.zeros_like(mask)
|
||||||
|
cv2.drawContours(refined, [pitch_contour], -1, 255, cv2.FILLED)
|
||||||
|
mask = refined
|
||||||
|
|
||||||
|
return {
|
||||||
|
"boundary": boundary,
|
||||||
|
"coverage": coverage,
|
||||||
|
"mask": mask,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def segment_field_debug(
|
||||||
|
image: np.ndarray,
|
||||||
|
**kwargs,
|
||||||
|
) -> dict:
|
||||||
|
"""Same as segment_field but includes a mask overlay for the editor."""
|
||||||
|
result = segment_field(image, **kwargs)
|
||||||
|
mask = result["mask"]
|
||||||
|
|
||||||
|
# RGBA overlay: solid green where mask, fully transparent elsewhere
|
||||||
|
h, w = image.shape[:2]
|
||||||
|
overlay = np.zeros((h, w, 4), dtype=np.uint8)
|
||||||
|
overlay[mask > 0] = [0, 255, 0, 255]
|
||||||
|
_, buf = cv2.imencode(".png", overlay)
|
||||||
|
result["mask_overlay_b64"] = base64.b64encode(buf.tobytes()).decode()
|
||||||
|
|
||||||
|
# Don't send the raw mask over HTTP
|
||||||
|
del result["mask"]
|
||||||
|
return result
|
||||||
@@ -101,6 +101,30 @@ class AnalyzeRegionsDebugResponse(BaseModel):
|
|||||||
horizontal_count: int = 0
|
horizontal_count: int = 0
|
||||||
pair_count: int = 0
|
pair_count: int = 0
|
||||||
|
|
||||||
|
class SegmentFieldRequest(BaseModel):
|
||||||
|
"""Request body for field segmentation."""
|
||||||
|
image: str
|
||||||
|
hue_low: int = 30
|
||||||
|
hue_high: int = 85
|
||||||
|
sat_low: int = 30
|
||||||
|
sat_high: int = 255
|
||||||
|
val_low: int = 30
|
||||||
|
val_high: int = 255
|
||||||
|
morph_kernel: int = 15
|
||||||
|
min_area_ratio: float = 0.05
|
||||||
|
|
||||||
|
class SegmentFieldResponse(BaseModel):
|
||||||
|
"""Response from field segmentation."""
|
||||||
|
boundary: List[List[int]] = Field(default_factory=list)
|
||||||
|
coverage: float = 0.0
|
||||||
|
mask_b64: str = ""
|
||||||
|
|
||||||
|
class SegmentFieldDebugResponse(BaseModel):
|
||||||
|
"""Response from field segmentation with debug overlay."""
|
||||||
|
boundary: List[List[int]] = Field(default_factory=list)
|
||||||
|
coverage: float = 0.0
|
||||||
|
mask_overlay_b64: str = ""
|
||||||
|
|
||||||
class ConfigUpdate(BaseModel):
|
class ConfigUpdate(BaseModel):
|
||||||
"""Request body for updating server configuration."""
|
"""Request body for updating server configuration."""
|
||||||
device: Optional[str] = None
|
device: Optional[str] = None
|
||||||
@@ -54,7 +54,7 @@ def _gpu_log(job_id: str, log_level: str, stage: str, level: str, msg: str):
|
|||||||
|
|
||||||
# --- Request/Response models (generated from core/schema/models/inference.py) ---
|
# --- Request/Response models (generated from core/schema/models/inference.py) ---
|
||||||
|
|
||||||
from models.inference_contract import (
|
from models.models import (
|
||||||
AnalyzeRegionsDebugResponse,
|
AnalyzeRegionsDebugResponse,
|
||||||
AnalyzeRegionsRequest,
|
AnalyzeRegionsRequest,
|
||||||
AnalyzeRegionsResponse,
|
AnalyzeRegionsResponse,
|
||||||
@@ -68,6 +68,9 @@ from models.inference_contract import (
|
|||||||
PreprocessRequest,
|
PreprocessRequest,
|
||||||
PreprocessResponse,
|
PreprocessResponse,
|
||||||
RegionBox,
|
RegionBox,
|
||||||
|
SegmentFieldRequest,
|
||||||
|
SegmentFieldResponse,
|
||||||
|
SegmentFieldDebugResponse,
|
||||||
VLMRequest,
|
VLMRequest,
|
||||||
VLMResponse,
|
VLMResponse,
|
||||||
)
|
)
|
||||||
@@ -310,6 +313,83 @@ def detect_edges_debug_endpoint(req: AnalyzeRegionsRequest, request: Request):
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/segment_field", response_model=SegmentFieldResponse)
def segment_field_endpoint(req: SegmentFieldRequest, request: Request):
    """Segment the playing field in a single frame via HSV thresholding.

    Decodes the base64 image from the request, runs ``segment_field`` with
    the request's HSV/morphology parameters, and returns the boundary
    polygon, coverage value, and the binary mask as a base64 PNG.

    Raises:
        HTTPException: 400 if the request image cannot be decoded;
            500 if segmentation or mask encoding fails.
    """
    job_id, log_level = _job_ctx(request)

    try:
        image = _decode_image(req.image)
        h, w = image.shape[:2]
    except Exception as e:
        # Chain the cause so the original decode error survives in tracebacks.
        raise HTTPException(status_code=400, detail=f"Bad image: {e}") from e

    try:
        t0 = time.monotonic()
        # Deferred import keeps heavy CV dependencies off the import path
        # of endpoints that don't need them.
        from models.cv.segmentation import segment_field

        result = segment_field(
            image,
            hue_low=req.hue_low,
            hue_high=req.hue_high,
            sat_low=req.sat_low,
            sat_high=req.sat_high,
            val_low=req.val_low,
            val_high=req.val_high,
            morph_kernel=req.morph_kernel,
            min_area_ratio=req.min_area_ratio,
        )
        infer_ms = (time.monotonic() - t0) * 1000

        # Encode mask as base64 PNG for downstream use.
        import cv2

        ok, buf = cv2.imencode(".png", result["mask"])
        if not ok:
            # Don't silently base64-encode a failed buffer (the original
            # discarded imencode's success flag).
            raise RuntimeError("PNG encoding of field mask failed")
        mask_b64 = base64.b64encode(buf.tobytes()).decode()

        _gpu_log(job_id, log_level, "GPU:CV", "DEBUG",
                 f"Field segmentation {w}x{h}: {infer_ms:.0f}ms, coverage={result['coverage']:.1%}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Field segmentation failed: {e}") from e

    return SegmentFieldResponse(
        boundary=result["boundary"],
        coverage=result["coverage"],
        mask_b64=mask_b64,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/segment_field/debug", response_model=SegmentFieldDebugResponse)
def segment_field_debug_endpoint(req: SegmentFieldRequest, request: Request):
    """Segment the playing field and return a debug overlay visualization.

    Same parameters as ``/segment_field``, but delegates to
    ``segment_field_debug`` and returns its pre-encoded overlay image
    instead of the raw mask.

    Raises:
        HTTPException: 400 if the request image cannot be decoded;
            500 if segmentation fails.
    """
    # NOTE(review): job context is fetched but not logged here, unlike
    # /segment_field — presumably intentional; confirm.
    job_id, log_level = _job_ctx(request)

    try:
        image = _decode_image(req.image)
    except Exception as e:
        # Chain the cause so the original decode error survives in tracebacks.
        raise HTTPException(status_code=400, detail=f"Bad image: {e}") from e

    try:
        # Deferred import keeps heavy CV dependencies off the import path
        # of endpoints that don't need them.
        from models.cv.segmentation import segment_field_debug

        result = segment_field_debug(
            image,
            hue_low=req.hue_low,
            hue_high=req.hue_high,
            sat_low=req.sat_low,
            sat_high=req.sat_high,
            val_low=req.val_low,
            val_high=req.val_high,
            morph_kernel=req.morph_kernel,
            min_area_ratio=req.min_area_ratio,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Field segmentation debug failed: {e}") from e

    return SegmentFieldDebugResponse(
        boundary=result["boundary"],
        coverage=result["coverage"],
        mask_overlay_b64=result["mask_overlay_b64"],
    )
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
@@ -180,24 +180,11 @@ class WorkerServicer(worker_pb2_grpc.WorkerServiceServicer):
|
|||||||
|
|
||||||
def GetWorkerStatus(self, request, context):
    """Get worker health and capabilities.

    Reports the worker as available with the current active-job count.
    Codec and GPU fields are static placeholders (empty list / False).
    """
    status = worker_pb2.WorkerStatus(
        available=True,
        active_jobs=len(_active_jobs),
        supported_codecs=[],
        gpu_available=False,
    )
    return status
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user