From 4220b0418e4c33d341fd06ef708ae0f8ad07fc8f Mon Sep 17 00:00:00 2001 From: buenosairesam Date: Mon, 30 Mar 2026 07:22:14 -0300 Subject: [PATCH] phase 4 --- admin/mpr/media_assets/models.py | 18 +- core/api/chunker_sse.py | 73 -- core/api/detect/config.py | 16 +- core/api/detect/replay.py | 52 +- core/api/detect/run.py | 38 +- core/api/detect/sse.py | 2 +- core/api/graphql.py | 384 ----------- core/api/main.py | 64 +- core/api/schema/graphql.py | 226 ------- core/chunker/__init__.py | 64 -- core/chunker/chunker.py | 101 --- core/chunker/collector.py | 98 --- core/chunker/exceptions.py | 64 -- core/chunker/models.py | 54 -- core/chunker/pipeline.py | 279 -------- core/chunker/pool.py | 125 ---- core/chunker/processor.py | 173 ----- core/chunker/queue.py | 76 --- core/chunker/worker.py | 143 ---- core/db/__init__.py | 2 +- core/db/assets.py | 2 +- core/db/brand.py | 2 +- core/db/checkpoint.py | 2 +- core/db/connection.py | 2 +- core/db/fixtures/soccer_broadcast.json | 142 ++++ core/db/job.py | 2 +- core/db/models.py | 24 +- core/db/seed.py | 43 ++ core/db/tables.py | 96 --- {detect => core/detect}/__init__.py | 0 .../detect}/checkpoint/__init__.py | 2 +- {detect => core/detect}/checkpoint/frames.py | 2 +- {detect => core/detect}/checkpoint/replay.py | 80 +-- core/detect/checkpoint/runner_bridge.py | 97 +++ .../detect}/checkpoint/serializer.py | 4 +- {detect => core/detect}/checkpoint/storage.py | 12 +- {detect => core/detect}/emit.py | 4 +- {detect => core/detect}/events.py | 0 {detect => core/detect}/graph/__init__.py | 0 {detect => core/detect}/graph/events.py | 4 +- {detect => core/detect}/graph/nodes.py | 149 ++-- {detect => core/detect}/graph/runner.py | 12 +- {detect => core/detect}/inference/__init__.py | 0 {detect => core/detect}/inference/client.py | 14 + {detect => core/detect}/inference/types.py | 2 +- .../models/pipeline.py => detect/models.py} | 12 +- core/detect/profile.py | 107 +++ {detect => core/detect}/providers/__init__.py | 0 {detect => 
core/detect}/providers/base.py | 0 {detect => core/detect}/providers/claude.py | 0 {detect => core/detect}/providers/gemini.py | 0 {detect => core/detect}/providers/groq.py | 0 .../detect}/providers/openai_compat.py | 0 detect/sse_contract.py => core/detect/sse.py | 16 + {detect => core/detect}/stages/__init__.py | 1 + {detect => core/detect}/stages/aggregator.py | 4 +- {detect => core/detect}/stages/base.py | 56 +- .../detect}/stages/brand_resolver.py | 6 +- .../detect}/stages/edge_detector.py | 90 ++- core/detect/stages/field_segmentation.py | 141 ++++ .../detect}/stages/frame_extractor.py | 6 +- core/detect/stages/models.py | 106 +++ {detect => core/detect}/stages/ocr_stage.py | 10 +- {detect => core/detect}/stages/preprocess.py | 6 +- .../detect}/stages/registry/__init__.py | 0 .../detect}/stages/registry/_serializers.py | 0 core/detect/stages/registry/cv_analysis.py | 44 ++ .../detect}/stages/registry/detection.py | 21 +- .../detect}/stages/registry/escalation.py | 15 +- .../detect}/stages/registry/output.py | 6 +- .../detect}/stages/registry/preprocessing.py | 29 +- .../detect}/stages/registry/resolution.py | 9 +- .../detect}/stages/scene_filter.py | 6 +- {detect => core/detect}/stages/vlm_cloud.py | 10 +- {detect => core/detect}/stages/vlm_local.py | 14 +- .../detect}/stages/yolo_detector.py | 10 +- {detect => core/detect}/state.py | 7 +- {detect => core/detect}/tracing.py | 2 +- core/ffmpeg/__init__.py | 7 - core/ffmpeg/capabilities.py | 145 ---- core/ffmpeg/transcode.py | 225 ------- {gpu => core/gpu}/.env.template | 0 {gpu => core/gpu}/Dockerfile | 0 {gpu => core/gpu}/__init__.py | 0 {gpu => core/gpu}/config.py | 0 {gpu => core/gpu}/emit.py | 0 core/gpu/models/__init__.py | 6 + {gpu => core/gpu}/models/cv/__init__.py | 0 {gpu => core/gpu}/models/cv/edges.py | 0 core/gpu/models/cv/segmentation.py | 86 +++ .../gpu/models/models.py | 24 + {gpu => core/gpu}/models/ocr.py | 0 {gpu => core/gpu}/models/preprocess.py | 0 {gpu => core/gpu}/models/registry.py | 0 
{gpu => core/gpu}/models/vlm.py | 0 {gpu => core/gpu}/models/yolo.py | 0 {gpu => core/gpu}/requirements.txt | 0 {gpu => core/gpu}/run.sh | 0 {gpu => core/gpu}/server.py | 82 ++- core/rpc/server.py | 17 +- core/schema/modelgen.json | 9 +- core/schema/models/__init__.py | 71 +- core/schema/models/checkpoint.py | 1 + core/schema/models/detect_api.py | 31 - core/schema/models/{detect.py => event.py} | 25 + core/schema/models/inference.py | 36 + core/schema/models/job.py | 3 + core/schema/models/pipeline_config.py | 46 -- core/schema/models/{presets.py => preset.py} | 0 core/schema/models/profile.py | 30 + core/schema/models/{sources.py => source.py} | 0 core/schema/models/stage.py | 153 +++++ core/schema/models/stages.py | 69 -- core/schema/models/{views.py => view.py} | 0 core/schema/serializers/pipeline.py | 6 +- ctrl/Tiltfile | 2 +- ctrl/sync.sh | 8 +- detect/checkpoint/runner_bridge.py | 64 -- detect/models.py | 21 - detect/profiles/__init__.py | 35 - detect/profiles/base.py | 96 --- detect/profiles/soccer.py | 122 ---- detect/profiles/stubs.py | 98 --- detect/stages/registry/cv_analysis.py | 45 -- gpu/models/__init__.py | 5 - modelgen/generator/pydantic.py | 1 + modelgen/generator/sqlmodel.py | 13 +- modelgen/types.py | 1 + requirements.txt | 1 + tests/chunker/__init__.py | 0 tests/chunker/conftest.py | 76 --- tests/chunker/test_chunker.py | 149 ---- tests/chunker/test_collector.py | 103 --- tests/chunker/test_exceptions.py | 69 -- tests/chunker/test_pipeline.py | 144 ---- tests/chunker/test_processor.py | 98 --- tests/chunker/test_queue.py | 115 ---- tests/chunker/test_worker.py | 127 ---- tests/detect/manual/run_extract_filter.py | 6 +- tests/detect/manual/run_graph.py | 4 +- tests/detect/manual/run_region_analysis.py | 16 +- tests/detect/manual/seed_scenario.py | 4 +- tests/detect/manual/test_cloud_provider.py | 2 +- .../detect/manual/test_frame_extractor_e2e.py | 4 +- tests/detect/manual/test_ocr_e2e.py | 6 +- tests/detect/manual/test_replay.py | 8 +- 
tests/detect/test_aggregator.py | 6 +- tests/detect/test_brand_resolver.py | 16 +- tests/detect/test_checkpoint.py | 49 +- tests/detect/test_config_endpoint.py | 2 +- tests/detect/test_edge_sensitivity.py | 2 +- tests/detect/test_frame_extractor.py | 6 +- tests/detect/test_graph.py | 10 +- tests/detect/test_ocr_stage.py | 16 +- tests/detect/test_preprocess.py | 10 +- tests/detect/test_profiles.py | 71 +- tests/detect/test_region_analyzer.py | 32 +- tests/detect/test_replay.py | 96 ++- tests/detect/test_scene_filter.py | 8 +- tests/detect/test_sse_contract.py | 2 +- tests/detect/test_stage_registry.py | 4 +- tests/detect/test_tracing.py | 2 +- tests/detect/test_vlm_cloud.py | 16 +- ui/common/types/generated.ts | 10 +- ui/detection-app/.gitignore | 4 + ui/detection-app/package.json | 7 +- ui/detection-app/pnpm-lock.yaml | 8 + ui/detection-app/src/App.vue | 39 +- .../src/components/StageConfig.vue | 634 ++++++++++++++++++ .../src/components/StageConfigSliders.vue | 385 ----------- .../src/composables/useCheckpointLoader.ts | 26 +- .../src/composables/useEditorState.ts | 131 ++-- ui/detection-app/src/cv/edges.ts | 443 ++++++------ ui/detection-app/src/cv/edgesTs.ts | 278 ++++++++ ui/detection-app/src/cv/index.ts | 95 +-- ui/detection-app/src/cv/opencv.ts | 84 +++ ui/detection-app/src/cv/segmentation.ts | 212 ++++++ ui/detection-app/src/cv/wasmBridge.ts | 121 ++++ ui/detection-app/src/cv/worker.ts | 45 +- ui/detection-app/src/types/sse-contract.ts | 16 + ui/detection-app/src/vite-env.d.ts | 1 + ui/framework/src/tokens.css | 14 + 182 files changed, 3668 insertions(+), 5231 deletions(-) delete mode 100644 core/api/chunker_sse.py delete mode 100644 core/api/graphql.py delete mode 100644 core/api/schema/graphql.py delete mode 100644 core/chunker/__init__.py delete mode 100644 core/chunker/chunker.py delete mode 100644 core/chunker/collector.py delete mode 100644 core/chunker/exceptions.py delete mode 100644 core/chunker/models.py delete mode 100644 core/chunker/pipeline.py 
delete mode 100644 core/chunker/pool.py delete mode 100644 core/chunker/processor.py delete mode 100644 core/chunker/queue.py delete mode 100644 core/chunker/worker.py create mode 100644 core/db/fixtures/soccer_broadcast.json create mode 100644 core/db/seed.py delete mode 100644 core/db/tables.py rename {detect => core/detect}/__init__.py (100%) rename {detect => core/detect}/checkpoint/__init__.py (95%) rename {detect => core/detect}/checkpoint/frames.py (98%) rename {detect => core/detect}/checkpoint/replay.py (76%) create mode 100644 core/detect/checkpoint/runner_bridge.py rename {detect => core/detect}/checkpoint/serializer.py (97%) rename {detect => core/detect}/checkpoint/storage.py (94%) rename {detect => core/detect}/emit.py (97%) rename {detect => core/detect}/events.py (100%) rename {detect => core/detect}/graph/__init__.py (100%) rename {detect => core/detect}/graph/events.py (90%) rename {detect => core/detect}/graph/nodes.py (64%) rename {detect => core/detect}/graph/runner.py (96%) rename {detect => core/detect}/inference/__init__.py (100%) rename {detect => core/detect}/inference/client.py (93%) rename {detect => core/detect}/inference/types.py (96%) rename core/{schema/models/pipeline.py => detect/models.py} (86%) create mode 100644 core/detect/profile.py rename {detect => core/detect}/providers/__init__.py (100%) rename {detect => core/detect}/providers/base.py (100%) rename {detect => core/detect}/providers/claude.py (100%) rename {detect => core/detect}/providers/gemini.py (100%) rename {detect => core/detect}/providers/groq.py (100%) rename {detect => core/detect}/providers/openai_compat.py (100%) rename detect/sse_contract.py => core/detect/sse.py (90%) rename {detect => core/detect}/stages/__init__.py (91%) rename {detect => core/detect}/stages/aggregator.py (96%) rename {detect => core/detect}/stages/base.py (72%) rename {detect => core/detect}/stages/brand_resolver.py (97%) rename {detect => core/detect}/stages/edge_detector.py (68%) create 
mode 100644 core/detect/stages/field_segmentation.py rename {detect => core/detect}/stages/frame_extractor.py (95%) create mode 100644 core/detect/stages/models.py rename {detect => core/detect}/stages/ocr_stage.py (94%) rename {detect => core/detect}/stages/preprocess.py (96%) rename {detect => core/detect}/stages/registry/__init__.py (100%) rename {detect => core/detect}/stages/registry/_serializers.py (100%) create mode 100644 core/detect/stages/registry/cv_analysis.py rename {detect => core/detect}/stages/registry/detection.py (62%) rename {detect => core/detect}/stages/registry/escalation.py (73%) rename {detect => core/detect}/stages/registry/output.py (82%) rename {detect => core/detect}/stages/registry/preprocessing.py (64%) rename {detect => core/detect}/stages/registry/resolution.py (77%) rename {detect => core/detect}/stages/scene_filter.py (95%) rename {detect => core/detect}/stages/vlm_cloud.py (95%) rename {detect => core/detect}/stages/vlm_local.py (93%) rename {detect => core/detect}/stages/yolo_detector.py (94%) rename {detect => core/detect}/state.py (72%) rename {detect => core/detect}/tracing.py (98%) delete mode 100644 core/ffmpeg/capabilities.py delete mode 100644 core/ffmpeg/transcode.py rename {gpu => core/gpu}/.env.template (100%) rename {gpu => core/gpu}/Dockerfile (100%) rename {gpu => core/gpu}/__init__.py (100%) rename {gpu => core/gpu}/config.py (100%) rename {gpu => core/gpu}/emit.py (100%) create mode 100644 core/gpu/models/__init__.py rename {gpu => core/gpu}/models/cv/__init__.py (100%) rename {gpu => core/gpu}/models/cv/edges.py (100%) create mode 100644 core/gpu/models/cv/segmentation.py rename gpu/models/inference_contract.py => core/gpu/models/models.py (80%) rename {gpu => core/gpu}/models/ocr.py (100%) rename {gpu => core/gpu}/models/preprocess.py (100%) rename {gpu => core/gpu}/models/registry.py (100%) rename {gpu => core/gpu}/models/vlm.py (100%) rename {gpu => core/gpu}/models/yolo.py (100%) rename {gpu => 
core/gpu}/requirements.txt (100%) rename {gpu => core/gpu}/run.sh (100%) rename {gpu => core/gpu}/server.py (79%) delete mode 100644 core/schema/models/detect_api.py rename core/schema/models/{detect.py => event.py} (91%) delete mode 100644 core/schema/models/pipeline_config.py rename core/schema/models/{presets.py => preset.py} (100%) create mode 100644 core/schema/models/profile.py rename core/schema/models/{sources.py => source.py} (100%) create mode 100644 core/schema/models/stage.py delete mode 100644 core/schema/models/stages.py rename core/schema/models/{views.py => view.py} (100%) delete mode 100644 detect/checkpoint/runner_bridge.py delete mode 100644 detect/models.py delete mode 100644 detect/profiles/__init__.py delete mode 100644 detect/profiles/base.py delete mode 100644 detect/profiles/soccer.py delete mode 100644 detect/profiles/stubs.py delete mode 100644 detect/stages/registry/cv_analysis.py delete mode 100644 gpu/models/__init__.py delete mode 100644 tests/chunker/__init__.py delete mode 100644 tests/chunker/conftest.py delete mode 100644 tests/chunker/test_chunker.py delete mode 100644 tests/chunker/test_collector.py delete mode 100644 tests/chunker/test_exceptions.py delete mode 100644 tests/chunker/test_pipeline.py delete mode 100644 tests/chunker/test_processor.py delete mode 100644 tests/chunker/test_queue.py delete mode 100644 tests/chunker/test_worker.py create mode 100644 ui/detection-app/.gitignore create mode 100644 ui/detection-app/src/components/StageConfig.vue delete mode 100644 ui/detection-app/src/components/StageConfigSliders.vue create mode 100644 ui/detection-app/src/cv/edgesTs.ts create mode 100644 ui/detection-app/src/cv/opencv.ts create mode 100644 ui/detection-app/src/cv/segmentation.ts create mode 100644 ui/detection-app/src/cv/wasmBridge.ts create mode 100644 ui/detection-app/src/vite-env.d.ts diff --git a/admin/mpr/media_assets/models.py b/admin/mpr/media_assets/models.py index 111f03d..7e3f991 100644 --- 
a/admin/mpr/media_assets/models.py +++ b/admin/mpr/media_assets/models.py @@ -102,6 +102,7 @@ class Job(models.Model): source_asset_id = models.UUIDField() video_path = models.CharField(max_length=1000) profile_name = models.CharField(max_length=255) + timeline_id = models.UUIDField(null=True, blank=True) parent_id = models.UUIDField(null=True, blank=True) run_type = models.CharField(max_length=20, choices=RunType.choices, default=RunType.INITIAL) config_overrides = models.JSONField(default=dict, blank=True) @@ -113,7 +114,6 @@ class Job(models.Model): brands_found = models.IntegerField(default=0) cloud_llm_calls = models.IntegerField(default=0) estimated_cost_usd = models.FloatField(default=0.0) - celery_task_id = models.CharField(max_length=255, null=True, blank=True) priority = models.IntegerField(default=0) created_at = models.DateTimeField(auto_now_add=True) started_at = models.DateTimeField(null=True, blank=True) @@ -151,6 +151,7 @@ class Checkpoint(models.Model): id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) timeline_id = models.UUIDField() + job_id = models.UUIDField(null=True, blank=True) parent_id = models.UUIDField(null=True, blank=True) stage_outputs = models.JSONField(default=dict, blank=True) config_overrides = models.JSONField(default=dict, blank=True) @@ -185,3 +186,18 @@ class Brand(models.Model): def __str__(self): return str(self.id) + +class Profile(models.Model): + """A content type profile.""" + + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + name = models.CharField(max_length=255) + pipeline = models.JSONField(default=dict, blank=True) + configs = models.JSONField(default=dict, blank=True) + + class Meta: + pass + + def __str__(self): + return self.name + diff --git a/core/api/chunker_sse.py b/core/api/chunker_sse.py deleted file mode 100644 index b684cbc..0000000 --- a/core/api/chunker_sse.py +++ /dev/null @@ -1,73 +0,0 @@ -""" -SSE endpoint for chunker pipeline events. 
- -Uses Redis as the event bus. Pipeline pushes events via core.events, -SSE endpoint polls them. - -GET /chunker/stream/{job_id} → text/event-stream -""" - -import asyncio -import json -import logging -import time -from typing import AsyncGenerator - -from fastapi import APIRouter -from starlette.responses import StreamingResponse - -from core.events import poll_events - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/chunker", tags=["chunker"]) - - -async def _event_generator(job_id: str) -> AsyncGenerator[str, None]: - """ - Generate SSE events by polling Redis for chunk job events. - """ - cursor = 0 - timeout = time.monotonic() + 600 # 10 min max - - while time.monotonic() < timeout: - events, cursor = poll_events(job_id, cursor) - - if not events: - yield f"event: waiting\ndata: {json.dumps({'job_id': job_id})}\n\n" - await asyncio.sleep(0.1) - continue - - for data in events: - event_type = data.pop("event", "update") - payload = {**data, "job_id": job_id} - - yield f"event: {event_type}\ndata: {json.dumps(payload)}\n\n" - - if event_type in ("pipeline_complete", "pipeline_error", "cancelled"): - yield f"event: done\ndata: {json.dumps({'job_id': job_id})}\n\n" - return - - await asyncio.sleep(0.05) - - yield f"event: timeout\ndata: {json.dumps({'job_id': job_id})}\n\n" - - -@router.get("/stream/{job_id}") -async def stream_chunk_job(job_id: str): - """ - SSE stream for a chunk pipeline job. - - The UI connects via native EventSource: - const es = new EventSource('/api/chunker/stream/'); - es.addEventListener('processing', (e) => { ... 
}); - """ - return StreamingResponse( - _event_generator(job_id), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) diff --git a/core/api/detect/config.py b/core/api/detect/config.py index 43f8a42..2436e61 100644 --- a/core/api/detect/config.py +++ b/core/api/detect/config.py @@ -58,32 +58,30 @@ def write_config(update: ConfigUpdate): @router.get("/config/profiles") -def list_profiles(): +def get_profiles(): """List available detection profiles.""" - from detect.profiles import _PROFILES - return [{"name": name} for name in _PROFILES] + from core.detect.profile import list_profiles as _list + return [{"name": name} for name in _list()] @router.get("/config/profiles/{profile_name}/pipeline") def get_pipeline_config(profile_name: str): """Return the pipeline composition for a profile.""" - from detect.profiles import get_profile + from core.detect.profile import get_profile from fastapi import HTTPException - from dataclasses import asdict try: profile = get_profile(profile_name) except ValueError: raise HTTPException(status_code=404, detail=f"Unknown profile: {profile_name}") - config = profile.pipeline_config() - return asdict(config) + return profile["pipeline"] @router.get("/config/stages", response_model=list[StageConfigInfo]) def list_stage_configs(): """Return the stage palette with config field metadata for the editor.""" - from detect.stages import list_stages + from core.detect.stages import list_stages result = [] for stage in list_stages(): @@ -95,7 +93,7 @@ def list_stage_configs(): @router.get("/config/stages/{stage_name}", response_model=StageConfigInfo) def get_stage_config(stage_name: str): """Return config field metadata for a single stage.""" - from detect.stages import get_stage + from core.detect.stages import get_stage try: stage = get_stage(stage_name) diff --git a/core/api/detect/replay.py b/core/api/detect/replay.py index 2ab36f6..02091bc 100644 --- 
a/core/api/detect/replay.py +++ b/core/api/detect/replay.py @@ -105,7 +105,7 @@ class ReplaySingleStageResponse(BaseModel): @router.get("/checkpoints/{timeline_id}") def list_checkpoints(timeline_id: str) -> list[CheckpointInfo]: """List available checkpoint stages for a job.""" - from detect.checkpoint import list_checkpoints as _list + from core.detect.checkpoint import list_checkpoints as _list try: stages = _list(timeline_id) @@ -139,10 +139,10 @@ class CheckpointData(BaseModel): def get_checkpoint_data(timeline_id: str, stage: str): """Load checkpoint frames + metadata for the editor UI.""" from uuid import UUID - from core.db.tables import Timeline, Checkpoint + from core.db.models import Timeline, Checkpoint from core.db.connection import get_session from core.db.checkpoint import list_checkpoints - from detect.checkpoint.frames import load_frames_b64 + from core.detect.checkpoint.frames import load_frames_b64 with get_session() as session: timeline = session.get(Timeline, UUID(timeline_id)) @@ -184,7 +184,7 @@ def get_checkpoint_data(timeline_id: str, stage: str): @router.get("/scenarios", response_model=list[ScenarioInfo]) def list_scenarios_endpoint(): """List all available scenarios (bookmarked checkpoints).""" - from core.db.tables import Timeline + from core.db.models import Timeline from core.db.connection import get_session from core.db.checkpoint import list_scenarios @@ -212,7 +212,7 @@ def list_scenarios_endpoint(): @router.post("/replay", response_model=ReplayResponse) def replay(req: ReplayRequest): """Replay pipeline from a specific stage with optional config overrides.""" - from detect.checkpoint import replay_from + from core.detect.checkpoint import replay_from try: result = replay_from( @@ -242,7 +242,7 @@ def replay(req: ReplayRequest): @router.post("/retry", response_model=RetryResponse) def retry(req: RetryRequest): """Queue an async retry of unresolved candidates with different config.""" - from detect.checkpoint.tasks import 
retry_candidates + from core.detect.checkpoint.tasks import retry_candidates kwargs = { "timeline_id": req.timeline_id, @@ -266,7 +266,7 @@ def retry(req: RetryRequest): @router.post("/replay-stage", response_model=ReplaySingleStageResponse) def replay_single_stage(req: ReplaySingleStageRequest): """Replay a single stage on specific frames — fast path for interactive tuning.""" - from detect.checkpoint.replay import replay_single_stage as _replay + from core.detect.checkpoint.replay import replay_single_stage as _replay try: result = _replay( @@ -361,3 +361,41 @@ async def gpu_detect_edges_debug(request: Request): media_type="application/json") except Exception as e: raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}") + + +@router.post("/gpu/segment_field") +async def gpu_segment_field(request: Request): + """Proxy to GPU inference server — field segmentation.""" + import httpx + + body = await request.body() + try: + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.post( + f"{_gpu_url()}/segment_field", + content=body, + headers={"Content-Type": "application/json"}, + ) + return Response(content=resp.content, status_code=resp.status_code, + media_type="application/json") + except Exception as e: + raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}") + + +@router.post("/gpu/segment_field/debug") +async def gpu_segment_field_debug(request: Request): + """Proxy to GPU inference server — field segmentation with debug overlay.""" + import httpx + + body = await request.body() + try: + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.post( + f"{_gpu_url()}/segment_field/debug", + content=body, + headers={"Content-Type": "application/json"}, + ) + return Response(content=resp.content, status_code=resp.status_code, + media_type="application/json") + except Exception as e: + raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}") diff --git 
a/core/api/detect/run.py b/core/api/detect/run.py index 3533629..404520b 100644 --- a/core/api/detect/run.py +++ b/core/api/detect/run.py @@ -60,9 +60,9 @@ def _resolve_video_path(video_path: str) -> str: @router.post("/run", response_model=RunResponse) def run_pipeline(req: RunRequest): """Launch a detection pipeline run on a source chunk.""" - from detect import emit - from detect.graph import get_pipeline - from detect.state import DetectState + from core.detect import emit + from core.detect.graph import get_pipeline + from core.detect.state import DetectState local_path = _resolve_video_path(req.video_path) job_id = str(uuid.uuid4()) @@ -79,7 +79,7 @@ def run_pipeline(req: RunRequest): # Clear any stale events from a previous run with same job_id from core.events import _get_redis - from detect.events import DETECT_EVENTS_PREFIX + from core.detect.events import DETECT_EVENTS_PREFIX r = _get_redis() r.delete(f"{DETECT_EVENTS_PREFIX}:{job_id}") @@ -97,7 +97,7 @@ def run_pipeline(req: RunRequest): source_asset_id=req.source_asset_id, ) - from detect.graph import ( + from core.detect.graph import ( PipelineCancelled, set_cancel_check, clear_cancel_check, init_pause, clear_pause, ) @@ -117,7 +117,7 @@ def run_pipeline(req: RunRequest): emit.job_complete(job_id, {"status": "cancelled"}) except Exception as e: logger.exception("Pipeline run %s failed: %s", job_id, e) - from detect.graph import _node_states, NODES + from core.detect.graph import _node_states, NODES if job_id in _node_states: states = _node_states[job_id] for node in reversed(NODES): @@ -145,7 +145,7 @@ def run_pipeline(req: RunRequest): @router.post("/stop/{job_id}") def stop_pipeline(job_id: str): """Stop a running pipeline. 
Signals cancellation; the thread checks on next stage.""" - from detect import emit + from core.detect import emit if job_id not in _running_jobs: raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}") @@ -158,7 +158,7 @@ def stop_pipeline(job_id: str): @router.post("/pause/{job_id}") def pause(job_id: str): """Pause a running pipeline after the current stage completes.""" - from detect.graph import pause_pipeline + from core.detect.graph import pause_pipeline if job_id not in _running_jobs: raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}") @@ -170,7 +170,7 @@ def pause(job_id: str): @router.post("/resume/{job_id}") def resume(job_id: str): """Resume a paused pipeline.""" - from detect.graph import resume_pipeline + from core.detect.graph import resume_pipeline if job_id not in _running_jobs: raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}") @@ -182,7 +182,7 @@ def resume(job_id: str): @router.post("/step/{job_id}") def step(job_id: str): """Run one stage then pause again.""" - from detect.graph import step_pipeline + from core.detect.graph import step_pipeline if job_id not in _running_jobs: raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}") @@ -194,7 +194,7 @@ def step(job_id: str): @router.post("/pause-after-stage/{job_id}") def toggle_pause_after_stage(job_id: str, enabled: bool = True): """Toggle pause-after-each-stage mode.""" - from detect.graph import set_pause_after_stage + from core.detect.graph import set_pause_after_stage if job_id not in _running_jobs: raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}") @@ -206,7 +206,7 @@ def toggle_pause_after_stage(job_id: str, enabled: bool = True): @router.get("/status/{job_id}") def pipeline_status(job_id: str): """Get pipeline run status.""" - from detect.graph import is_paused + from core.detect.graph import is_paused running = job_id in _running_jobs paused = is_paused(job_id) @@ 
-224,11 +224,23 @@ def pipeline_status(job_id: str): return {"status": status, "job_id": job_id} +@router.get("/timeline/{job_id}") +def get_timeline_for_job(job_id: str): + """Get the timeline_id for a running or completed job.""" + from core.detect.checkpoint.runner_bridge import get_timeline_id + + tid = get_timeline_id(job_id) + if tid is None: + raise HTTPException(status_code=404, detail=f"No timeline for job: {job_id}") + + return {"timeline_id": tid, "job_id": job_id} + + @router.post("/clear/{job_id}") def clear_pipeline(job_id: str): """Clear events for a job from Redis.""" from core.events import _get_redis - from detect.events import DETECT_EVENTS_PREFIX + from core.detect.events import DETECT_EVENTS_PREFIX r = _get_redis() r.delete(f"{DETECT_EVENTS_PREFIX}:{job_id}") diff --git a/core/api/detect/sse.py b/core/api/detect/sse.py index 41064f6..dc5c6bf 100644 --- a/core/api/detect/sse.py +++ b/core/api/detect/sse.py @@ -17,7 +17,7 @@ from fastapi import APIRouter from starlette.responses import StreamingResponse from core.events import poll_events -from detect.events import DETECT_EVENTS_PREFIX, TERMINAL_EVENTS +from core.detect.events import DETECT_EVENTS_PREFIX, TERMINAL_EVENTS logger = logging.getLogger(__name__) diff --git a/core/api/graphql.py b/core/api/graphql.py deleted file mode 100644 index 6cd3b41..0000000 --- a/core/api/graphql.py +++ /dev/null @@ -1,384 +0,0 @@ -""" -GraphQL API using strawberry, served via FastAPI. - -Primary API for MPR — all client interactions go through GraphQL. -Uses core.db for data access. -Types are generated from schema/ via modelgen — see api/schema/graphql.py. 
-""" - -import os -from typing import List, Optional -from uuid import UUID - -import strawberry -from strawberry.schema.config import StrawberryConfig -from strawberry.types import Info - -from core.api.schema.graphql import ( - CancelResultType, - ChunkJobType, - ChunkOutputFileType, - CreateChunkJobInput, - CreateJobInput, - DeleteResultType, - MediaAssetType, - ScanResultType, - SystemStatusType, - TranscodeJobType, - TranscodePresetType, - UpdateAssetInput, -) -from core.storage import BUCKET_IN, list_objects, upload_file - -VIDEO_EXTS = {".mp4", ".mkv", ".avi", ".mov", ".webm", ".flv", ".wmv", ".m4v"} -AUDIO_EXTS = {".mp3", ".wav", ".flac", ".aac", ".ogg", ".m4a"} -MEDIA_EXTS = VIDEO_EXTS | AUDIO_EXTS - - -# --------------------------------------------------------------------------- -# Queries -# --------------------------------------------------------------------------- - - -@strawberry.type -class Query: - @strawberry.field - def assets( - self, - info: Info, - status: Optional[str] = None, - search: Optional[str] = None, - ) -> List[MediaAssetType]: - from core.db import list_assets - - return list_assets(status=status, search=search) - - @strawberry.field - def asset(self, info: Info, id: UUID) -> Optional[MediaAssetType]: - from core.db import get_asset - - try: - return get_asset(id) - except Exception: - return None - - @strawberry.field - def jobs( - self, - info: Info, - status: Optional[str] = None, - source_asset_id: Optional[UUID] = None, - ) -> List[TranscodeJobType]: - from core.db import list_jobs - - return list_jobs(status=status, source_asset_id=source_asset_id) - - @strawberry.field - def job(self, info: Info, id: UUID) -> Optional[TranscodeJobType]: - from core.db import get_job - - try: - return get_job(id) - except Exception: - return None - - @strawberry.field - def presets(self, info: Info) -> List[TranscodePresetType]: - from core.db import list_presets - - return list_presets() - - @strawberry.field - def system_status(self, info: 
Info) -> SystemStatusType: - return SystemStatusType(status="ok", version="0.1.0") - - @strawberry.field - def chunk_output_files(self, info: Info, job_id: str) -> List[ChunkOutputFileType]: - """List output chunk files for a completed job from media/out/.""" - from pathlib import Path - - media_out = os.environ.get("MEDIA_OUT_DIR", "/app/media/out") - output_dir = Path(media_out) / "chunks" / job_id - if not output_dir.is_dir(): - return [] - return [ - ChunkOutputFileType( - key=f.name, - size=f.stat().st_size, - url=f"/media/out/chunks/{job_id}/{f.name}", - ) - for f in sorted(output_dir.iterdir()) - if f.is_file() - ] - - -# --------------------------------------------------------------------------- -# Mutations -# --------------------------------------------------------------------------- - - -@strawberry.type -class Mutation: - @strawberry.mutation - def scan_media_folder(self, info: Info) -> ScanResultType: - import logging - from pathlib import Path - - from core.db import create_asset, get_asset_filenames - - logger = logging.getLogger(__name__) - - # Sync local media/in/ files to MinIO (handles fresh installs / pruned volumes) - local_media = Path("/app/media/in") - if local_media.is_dir(): - existing_keys = {o["key"] for o in list_objects(BUCKET_IN)} - for f in local_media.iterdir(): - if f.is_file() and f.suffix.lower() in MEDIA_EXTS: - if f.name not in existing_keys: - try: - upload_file(str(f), BUCKET_IN, f.name) - logger.info("Uploaded %s to MinIO", f.name) - except Exception as e: - logger.warning("Failed to upload %s: %s", f.name, e) - - objects = list_objects(BUCKET_IN, extensions=MEDIA_EXTS) - existing = get_asset_filenames() - - registered = [] - skipped = [] - - for obj in objects: - if obj["filename"] in existing: - skipped.append(obj["filename"]) - continue - try: - create_asset( - filename=obj["filename"], - file_path=obj["key"], - file_size=obj["size"], - ) - registered.append(obj["filename"]) - except Exception: - pass - - return 
ScanResultType( - found=len(objects), - registered=len(registered), - skipped=len(skipped), - files=registered, - ) - - @strawberry.mutation - def create_job(self, info: Info, input: CreateJobInput) -> TranscodeJobType: - from pathlib import Path - - from core.db import create_job, get_asset, get_preset - - try: - source = get_asset(input.source_asset_id) - except Exception: - raise Exception("Source asset not found") - - preset = None - preset_snapshot = {} - if input.preset_id: - try: - preset = get_preset(input.preset_id) - preset_snapshot = { - "name": preset.name, - "container": preset.container, - "video_codec": preset.video_codec, - "audio_codec": preset.audio_codec, - } - except Exception: - raise Exception("Preset not found") - - if not preset and not input.trim_start and not input.trim_end: - raise Exception("Must specify preset_id or trim_start/trim_end") - - output_filename = input.output_filename - if not output_filename: - stem = Path(source.filename).stem - ext = preset_snapshot.get("container", "mp4") if preset else "mp4" - output_filename = f"{stem}_output.{ext}" - - job = create_job( - source_asset_id=source.id, - preset_id=preset.id if preset else None, - preset_snapshot=preset_snapshot, - trim_start=input.trim_start, - trim_end=input.trim_end, - output_filename=output_filename, - output_path=output_filename, - priority=input.priority or 0, - ) - - payload = { - "source_key": source.file_path, - "output_key": output_filename, - "preset": preset_snapshot or None, - "trim_start": input.trim_start, - "trim_end": input.trim_end, - "duration": source.duration, - } - - executor_mode = os.environ.get("MPR_EXECUTOR", "local") - if executor_mode in ("lambda", "gcp"): - from core.jobs.executor import get_executor - - get_executor().run( - job_type="transcode", - job_id=str(job.id), - payload=payload, - ) - else: - from core.jobs.task import run_job - - result = run_job.delay( - job_type="transcode", - job_id=str(job.id), - payload=payload, - ) - 
job.celery_task_id = result.id - job.save(update_fields=["celery_task_id"]) - - return job - - @strawberry.mutation - def cancel_job(self, info: Info, id: UUID) -> TranscodeJobType: - from core.db import get_job, update_job - - try: - job = get_job(id) - except Exception: - raise Exception("Job not found") - - if job.status not in ("pending", "processing"): - raise Exception(f"Cannot cancel job with status: {job.status}") - - return update_job(job, status="cancelled") - - @strawberry.mutation - def retry_job(self, info: Info, id: UUID) -> TranscodeJobType: - from core.db import get_job, update_job - - try: - job = get_job(id) - except Exception: - raise Exception("Job not found") - - if job.status != "failed": - raise Exception("Only failed jobs can be retried") - - return update_job(job, status="pending", progress=0, error_message=None) - - @strawberry.mutation - def update_asset(self, info: Info, id: UUID, input: UpdateAssetInput) -> MediaAssetType: - from core.db import get_asset, update_asset - - try: - asset = get_asset(id) - except Exception: - raise Exception("Asset not found") - - fields = {} - if input.comments is not None: - fields["comments"] = input.comments - if input.tags is not None: - fields["tags"] = input.tags - - if fields: - asset = update_asset(asset, **fields) - - return asset - - @strawberry.mutation - def delete_asset(self, info: Info, id: UUID) -> DeleteResultType: - from core.db import delete_asset, get_asset - - try: - asset = get_asset(id) - delete_asset(asset) - return DeleteResultType(ok=True) - except Exception: - raise Exception("Asset not found") - - @strawberry.mutation - def create_chunk_job(self, info: Info, input: CreateChunkJobInput) -> ChunkJobType: - """Create and dispatch a chunk pipeline job.""" - import uuid - - from core.db import get_asset - - try: - source = get_asset(input.source_asset_id) - except Exception: - raise Exception("Source asset not found") - - job_id = str(uuid.uuid4()) - - payload = { - "source_key": 
source.file_path, - "chunk_duration": input.chunk_duration, - "num_workers": input.num_workers, - "max_retries": input.max_retries, - "processor_type": input.processor_type, - "start_time": input.start_time, - "end_time": input.end_time, - } - - executor_mode = os.environ.get("MPR_EXECUTOR", "local") - celery_task_id = None - - if executor_mode in ("lambda", "gcp"): - from core.jobs.executor import get_executor - - get_executor().run( - job_type="chunk", - job_id=job_id, - payload=payload, - ) - else: - from core.jobs.task import run_job - - result = run_job.delay( - job_type="chunk", - job_id=job_id, - payload=payload, - ) - celery_task_id = result.id - - return ChunkJobType( - id=uuid.UUID(job_id), - source_asset_id=input.source_asset_id, - chunk_duration=input.chunk_duration, - num_workers=input.num_workers, - max_retries=input.max_retries, - processor_type=input.processor_type, - status="pending", - progress=0.0, - priority=input.priority, - celery_task_id=celery_task_id, - ) - - @strawberry.mutation - def cancel_chunk_job(self, info: Info, celery_task_id: str) -> CancelResultType: - """Cancel a running chunk job by revoking its Celery task.""" - try: - from admin.mpr.celery import app as celery_app - - celery_app.control.revoke(celery_task_id, terminate=True, signal="SIGTERM") - return CancelResultType(ok=True, message="Task revoked") - except Exception as e: - return CancelResultType(ok=False, message=str(e)) - - -# --------------------------------------------------------------------------- -# Schema -# --------------------------------------------------------------------------- - -schema = strawberry.Schema( - query=Query, - mutation=Mutation, - config=StrawberryConfig(auto_camel_case=False), -) diff --git a/core/api/main.py b/core/api/main.py index ad76427..4820c20 100644 --- a/core/api/main.py +++ b/core/api/main.py @@ -1,48 +1,38 @@ """ MPR FastAPI Application - -Serves GraphQL API and Lambda callback endpoint. 
""" import os import sys -from typing import Optional -from uuid import UUID # Add project root to path sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) from contextlib import asynccontextmanager -from fastapi import FastAPI, Header, HTTPException +from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from strawberry.fastapi import GraphQLRouter -from core.api.chunker_sse import router as chunker_router from core.api.detect import router as detect_router -from core.api.graphql import schema as graphql_schema - -CALLBACK_API_KEY = os.environ.get("CALLBACK_API_KEY", "") @asynccontextmanager async def lifespan(app): - # Create/reset DB tables on startup from core.db.connection import create_tables + from core.db.seed import seed_profiles create_tables() + seed_profiles() yield app = FastAPI( title="MPR API", - description="Media Processor — GraphQL API", version="0.1.0", docs_url="/docs", redoc_url="/redoc", lifespan=lifespan, ) -# CORS app.add_middleware( CORSMiddleware, allow_origins=["http://mpr.local.ar", "http://k8s.mpr.local.ar", "http://localhost:5173"], @@ -51,13 +41,6 @@ app.add_middleware( allow_headers=["*"], ) -# GraphQL -graphql_router = GraphQLRouter(schema=graphql_schema, graphql_ide="graphiql") -app.include_router(graphql_router, prefix="/graphql") - -# Chunker SSE -app.include_router(chunker_router) - # Detection API (sources, run, SSE, replay, config) app.include_router(detect_router) @@ -69,48 +52,7 @@ def health(): @app.get("/") def root(): - """API root.""" return { "name": "MPR API", "version": "0.1.0", - "graphql": "/graphql", } - - -@app.post("/api/jobs/{job_id}/callback") -def job_callback( - job_id: UUID, - payload: dict, - x_api_key: Optional[str] = Header(None), -): - """ - Callback endpoint for Lambda to report job completion. - Protected by API key. 
- """ - if CALLBACK_API_KEY and x_api_key != CALLBACK_API_KEY: - raise HTTPException(status_code=403, detail="Invalid API key") - - from django.utils import timezone - - from core.db import get_job, update_job - - try: - job = get_job(job_id) - except Exception: - raise HTTPException(status_code=404, detail="Job not found") - - status = payload.get("status", "failed") - fields = { - "status": status, - "progress": 100.0 if status == "completed" else job.progress, - } - - if payload.get("error"): - fields["error_message"] = payload["error"] - - if status in ("completed", "failed"): - fields["completed_at"] = timezone.now() - - update_job(job, **fields) - - return {"ok": True} diff --git a/core/api/schema/graphql.py b/core/api/schema/graphql.py deleted file mode 100644 index cf4c9a3..0000000 --- a/core/api/schema/graphql.py +++ /dev/null @@ -1,226 +0,0 @@ -""" -Strawberry Types - GENERATED FILE - -Do not edit directly. Regenerate using modelgen. -""" - -import strawberry -from enum import Enum -from typing import List, Optional -from uuid import UUID -from datetime import datetime -from strawberry.scalars import JSON - - -@strawberry.enum -class AssetStatus(Enum): - PENDING = "pending" - READY = "ready" - ERROR = "error" - - -@strawberry.enum -class JobStatus(Enum): - PENDING = "pending" - PROCESSING = "processing" - COMPLETED = "completed" - FAILED = "failed" - CANCELLED = "cancelled" - - -@strawberry.type -class MediaAssetType: - """A video/audio file registered in the system.""" - - id: Optional[UUID] = None - filename: Optional[str] = None - file_path: Optional[str] = None - status: Optional[str] = None - error_message: Optional[str] = None - file_size: Optional[float] = None - duration: Optional[float] = None - video_codec: Optional[str] = None - audio_codec: Optional[str] = None - width: Optional[int] = None - height: Optional[int] = None - framerate: Optional[float] = None - bitrate: Optional[int] = None - properties: Optional[JSON] = None - comments: 
Optional[str] = None - tags: Optional[List[str]] = None - created_at: Optional[datetime] = None - updated_at: Optional[datetime] = None - - -@strawberry.type -class TranscodePresetType: - """A reusable transcoding configuration (like Handbrake presets).""" - - id: Optional[UUID] = None - name: Optional[str] = None - description: Optional[str] = None - is_builtin: Optional[bool] = None - container: Optional[str] = None - video_codec: Optional[str] = None - video_bitrate: Optional[str] = None - video_crf: Optional[int] = None - video_preset: Optional[str] = None - resolution: Optional[str] = None - framerate: Optional[float] = None - audio_codec: Optional[str] = None - audio_bitrate: Optional[str] = None - audio_channels: Optional[int] = None - audio_samplerate: Optional[int] = None - extra_args: Optional[List[str]] = None - created_at: Optional[datetime] = None - updated_at: Optional[datetime] = None - - -@strawberry.type -class TranscodeJobType: - """A transcoding or trimming job in the queue.""" - - id: Optional[UUID] = None - source_asset_id: Optional[UUID] = None - preset_id: Optional[UUID] = None - preset_snapshot: Optional[JSON] = None - trim_start: Optional[float] = None - trim_end: Optional[float] = None - output_filename: Optional[str] = None - output_path: Optional[str] = None - output_asset_id: Optional[UUID] = None - status: Optional[str] = None - progress: Optional[float] = None - current_frame: Optional[int] = None - current_time: Optional[float] = None - speed: Optional[str] = None - error_message: Optional[str] = None - celery_task_id: Optional[str] = None - execution_arn: Optional[str] = None - priority: Optional[int] = None - created_at: Optional[datetime] = None - started_at: Optional[datetime] = None - completed_at: Optional[datetime] = None - - -@strawberry.input -class CreateJobInput: - """Request body for creating a transcode/trim job.""" - - source_asset_id: UUID - preset_id: Optional[UUID] = None - trim_start: Optional[float] = None - 
trim_end: Optional[float] = None - output_filename: Optional[str] = None - priority: int = 0 - - -@strawberry.input -class UpdateAssetInput: - """Request body for updating asset metadata.""" - - comments: Optional[str] = None - tags: Optional[List[str]] = None - - -@strawberry.type -class SystemStatusType: - """System status response.""" - - status: Optional[str] = None - version: Optional[str] = None - - -@strawberry.type -class ScanResultType: - """Result of scanning the media input bucket.""" - - found: Optional[int] = None - registered: Optional[int] = None - skipped: Optional[int] = None - files: Optional[List[str]] = None - - -@strawberry.type -class DeleteResultType: - """Result of a delete operation.""" - - ok: Optional[bool] = None - - -@strawberry.type -class WorkerStatusType: - """Worker health and capabilities.""" - - available: Optional[bool] = None - active_jobs: Optional[int] = None - supported_codecs: Optional[List[str]] = None - gpu_available: Optional[bool] = None - - -@strawberry.enum -class ChunkJobStatus(Enum): - PENDING = "pending" - CHUNKING = "chunking" - PROCESSING = "processing" - COLLECTING = "collecting" - COMPLETED = "completed" - FAILED = "failed" - CANCELLED = "cancelled" - - -@strawberry.type -class ChunkJobType: - """A chunk pipeline job.""" - - id: Optional[UUID] = None - source_asset_id: Optional[UUID] = None - chunk_duration: Optional[float] = None - num_workers: Optional[int] = None - max_retries: Optional[int] = None - processor_type: Optional[str] = None - status: Optional[str] = None - progress: Optional[float] = None - total_chunks: Optional[int] = None - processed_chunks: Optional[int] = None - failed_chunks: Optional[int] = None - retry_count: Optional[int] = None - error_message: Optional[str] = None - throughput_mbps: Optional[float] = None - elapsed_seconds: Optional[float] = None - celery_task_id: Optional[str] = None - priority: Optional[int] = None - created_at: Optional[datetime] = None - started_at: 
Optional[datetime] = None - completed_at: Optional[datetime] = None - - -@strawberry.input -class CreateChunkJobInput: - """Request body for creating a chunk pipeline job.""" - - source_asset_id: UUID - chunk_duration: float = 10.0 - num_workers: int = 4 - max_retries: int = 3 - processor_type: str = "ffmpeg" - priority: int = 0 - start_time: Optional[float] = None - end_time: Optional[float] = None - - -@strawberry.type -class CancelResultType: - """Result of cancelling a chunk job.""" - - ok: bool = False - message: Optional[str] = None - - -@strawberry.type -class ChunkOutputFileType: - """A chunk output file in S3/MinIO with presigned download URL.""" - - key: str - size: int = 0 - url: str = "" diff --git a/core/chunker/__init__.py b/core/chunker/__init__.py deleted file mode 100644 index 81effee..0000000 --- a/core/chunker/__init__.py +++ /dev/null @@ -1,64 +0,0 @@ -""" -Chunker pipeline — splits files into chunks, processes concurrently, reassembles in order. - -Public API: - Pipeline — orchestrates the full pipeline - PipelineResult — aggregate result dataclass - Chunker — file → Chunk generator - ChunkQueue — bounded thread-safe queue - WorkerPool — manages N worker threads - ResultCollector — heapq-based ordered reassembly -""" - -from .chunker import Chunker -from .collector import ResultCollector -from .exceptions import ( - ChunkChecksumError, - ChunkError, - ChunkReadError, - PipelineError, - ProcessingError, - ProcessorFailureError, - ProcessorTimeoutError, - ReassemblyError, -) -from .models import Chunk, ChunkResult, PipelineResult -from .pipeline import Pipeline -from .pool import WorkerPool -from .processor import ( - ChecksumProcessor, - CompositeProcessor, - FFmpegExtractProcessor, - Processor, - SimulatedDecodeProcessor, -) -from .queue import ChunkQueue - -__all__ = [ - # Core - "Pipeline", - "PipelineResult", - # Components - "Chunker", - "ChunkQueue", - "WorkerPool", - "ResultCollector", - # Models - "Chunk", - "ChunkResult", - # Processors 
- "Processor", - "ChecksumProcessor", - "SimulatedDecodeProcessor", - "CompositeProcessor", - "FFmpegExtractProcessor", - # Exceptions - "PipelineError", - "ChunkError", - "ChunkReadError", - "ChunkChecksumError", - "ProcessingError", - "ProcessorFailureError", - "ProcessorTimeoutError", - "ReassemblyError", -] diff --git a/core/chunker/chunker.py b/core/chunker/chunker.py deleted file mode 100644 index 53a2eb0..0000000 --- a/core/chunker/chunker.py +++ /dev/null @@ -1,101 +0,0 @@ -""" -Chunker — probes a media file and yields time-based Chunk objects. - -Demonstrates: -- Function parameters and defaults (Interview Topic 1) -- List comprehensions and efficient iteration / generators (Interview Topic 3) -""" - -import math -import os -from typing import Generator - -from core.ffmpeg.probe import probe_file - -from .exceptions import ChunkReadError -from .models import Chunk - - -class Chunker: - """ - Splits a media file into time-based chunks via a generator. - - Uses FFmpeg probe to get duration, then yields Chunk objects - representing time segments (no data read — extraction happens in the processor). 
- - Args: - file_path: Path to the source media file - chunk_duration: Duration of each chunk in seconds (default: 10.0) - """ - - def __init__( - self, - file_path: str, - chunk_duration: float = 10.0, - start_time: float | None = None, - end_time: float | None = None, - ): - if not os.path.isfile(file_path): - raise ChunkReadError(f"File not found: {file_path}") - if chunk_duration <= 0: - raise ValueError("chunk_duration must be positive") - - self.file_path = file_path - self.chunk_duration = chunk_duration - self.file_size = os.path.getsize(file_path) - full_duration = self._probe_duration() - - # Apply time range - self.range_start = max(start_time or 0.0, 0.0) - self.range_end = min(end_time or full_duration, full_duration) - if self.range_start >= self.range_end: - raise ValueError( - f"Invalid range: start={self.range_start} >= end={self.range_end}" - ) - self.source_duration = self.range_end - self.range_start - - def _probe_duration(self) -> float: - """Get source file duration via FFmpeg probe.""" - try: - result = probe_file(self.file_path) - if result.duration is None or result.duration <= 0: - raise ChunkReadError( - f"Cannot determine duration for {self.file_path}" - ) - return result.duration - except ChunkReadError: - raise - except Exception as e: - raise ChunkReadError( - f"Failed to probe {self.file_path}: {e}" - ) from e - - @property - def expected_chunks(self) -> int: - """Calculate expected number of chunks (last chunk may be shorter).""" - if self.source_duration <= 0: - return 0 - return math.ceil(self.source_duration / self.chunk_duration) - - def chunks(self) -> Generator[Chunk, None, None]: - """ - Yield Chunk objects representing time segments of the source file. - - Generator-based: chunks are yielded on demand. - Each chunk defines a time range — actual extraction is done by the processor. 
- """ - total = self.expected_chunks - for sequence in range(total): - start_time = self.range_start + sequence * self.chunk_duration - end_time = min( - start_time + self.chunk_duration, self.range_end - ) - duration = end_time - start_time - - yield Chunk( - sequence=sequence, - start_time=start_time, - end_time=end_time, - source_path=self.file_path, - duration=duration, - ) diff --git a/core/chunker/collector.py b/core/chunker/collector.py deleted file mode 100644 index 4e4b5fd..0000000 --- a/core/chunker/collector.py +++ /dev/null @@ -1,98 +0,0 @@ -""" -ResultCollector — reassembles chunk results in sequence order using a min-heap. - -Demonstrates: -- Algorithms and sorting (Interview Topic 6) — heapq for ordered reassembly -- Core data structures (Interview Topic 5) — heap, deque -""" - -import heapq -from collections import deque -from typing import List - -from .exceptions import ReassemblyError -from .models import ChunkResult - - -class ResultCollector: - """ - Receives ChunkResults out of order, emits them in sequence order. - - Uses a min-heap keyed on sequence number. Only emits a chunk when - all prior sequences have been accounted for. - - Args: - total_chunks: Expected total number of chunks - """ - - def __init__(self, total_chunks: int): - self.total_chunks = total_chunks - self._heap: List[tuple[int, ChunkResult]] = [] - self._next_sequence = 0 - self._emitted: List[ChunkResult] = [] - self._seen_sequences: set[int] = set() - # Sliding window for throughput calculation - self._recent_times: deque[float] = deque(maxlen=50) - - def add(self, result: ChunkResult) -> List[ChunkResult]: - """ - Add a result and return any newly emittable results in order. 
- - Args: - result: A ChunkResult (may arrive out of order) - - Returns: - List of results that can now be emitted in sequence order - (may be empty if we're still waiting for earlier sequences) - - Raises: - ReassemblyError: If a duplicate sequence is received - """ - if result.sequence in self._seen_sequences: - raise ReassemblyError( - f"Duplicate sequence number: {result.sequence}" - ) - self._seen_sequences.add(result.sequence) - - # Track processing time for throughput - if result.processing_time > 0: - self._recent_times.append(result.processing_time) - - # Push to min-heap - heapq.heappush(self._heap, (result.sequence, result)) - - # Emit all consecutive results starting from _next_sequence - newly_emitted = [] - while self._heap and self._heap[0][0] == self._next_sequence: - _, emitted_result = heapq.heappop(self._heap) - self._emitted.append(emitted_result) - newly_emitted.append(emitted_result) - self._next_sequence += 1 - - return newly_emitted - - @property - def is_complete(self) -> bool: - """True if all expected chunks have been emitted in order.""" - return self._next_sequence == self.total_chunks - - @property - def buffered_count(self) -> int: - """Number of results waiting in the heap (arrived out of order).""" - return len(self._heap) - - @property - def emitted_count(self) -> int: - """Number of results emitted in sequence order.""" - return len(self._emitted) - - @property - def avg_processing_time(self) -> float: - """Average processing time from recent results (sliding window).""" - if not self._recent_times: - return 0.0 - return sum(self._recent_times) / len(self._recent_times) - - def get_ordered_results(self) -> List[ChunkResult]: - """Get all emitted results in sequence order.""" - return list(self._emitted) diff --git a/core/chunker/exceptions.py b/core/chunker/exceptions.py deleted file mode 100644 index 2426152..0000000 --- a/core/chunker/exceptions.py +++ /dev/null @@ -1,64 +0,0 @@ -""" -Chunker exception hierarchy. 
- -Demonstrates: Managing exceptions and writing resilient code (Interview Topic 7). -""" - - -class PipelineError(Exception): - """Base exception for all chunker pipeline errors.""" - pass - - -class ChunkError(PipelineError): - """Errors related to chunk creation or validation.""" - pass - - -class ChunkReadError(ChunkError): - """Failed to read chunk data from source file.""" - pass - - -class ChunkChecksumError(ChunkError): - """Chunk data integrity validation failed.""" - - def __init__(self, sequence: int, expected: str, actual: str): - self.sequence = sequence - self.expected = expected - self.actual = actual - super().__init__( - f"Chunk {sequence}: checksum mismatch " - f"(expected={expected}, actual={actual})" - ) - - -class ProcessingError(PipelineError): - """Errors during chunk processing by workers.""" - pass - - -class ProcessorTimeoutError(ProcessingError): - """Processor exceeded allowed time for a chunk.""" - - def __init__(self, sequence: int, timeout: float): - self.sequence = sequence - self.timeout = timeout - super().__init__(f"Chunk {sequence}: processor timed out after {timeout}s") - - -class ProcessorFailureError(ProcessingError): - """Processor failed to process a chunk after all retries.""" - - def __init__(self, sequence: int, retries: int, original_error: Exception): - self.sequence = sequence - self.retries = retries - self.original_error = original_error - super().__init__( - f"Chunk {sequence}: failed after {retries} retries — {original_error}" - ) - - -class ReassemblyError(PipelineError): - """Errors during result collection and ordering.""" - pass diff --git a/core/chunker/models.py b/core/chunker/models.py deleted file mode 100644 index d2f6a7d..0000000 --- a/core/chunker/models.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -Internal data models for the chunker pipeline. - -These are pipeline-internal dataclasses, not schema models. -Schema-level ChunkJob is in core/schema/models/jobs.py. 
- -Demonstrates: Core data structures (Interview Topic 5). -""" - -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional - - -@dataclass -class Chunk: - """A time-based segment of the source media file.""" - - sequence: int - start_time: float # seconds - end_time: float # seconds - source_path: str # path to source file - duration: float # end_time - start_time - checksum: str = "" # computed after extraction - - -@dataclass -class ChunkResult: - """Result of processing a single chunk.""" - - sequence: int - success: bool - checksum_valid: bool = True - processing_time: float = 0.0 - error: Optional[str] = None - retries: int = 0 - worker_id: Optional[str] = None - output_file: Optional[str] = None - - -@dataclass -class PipelineResult: - """Aggregate result of the entire pipeline run.""" - - total_chunks: int = 0 - processed: int = 0 - failed: int = 0 - retries: int = 0 - elapsed_time: float = 0.0 - throughput_mbps: float = 0.0 - worker_stats: Dict[str, Any] = field(default_factory=dict) - errors: List[str] = field(default_factory=list) - chunks_in_order: bool = True - output_dir: Optional[str] = None - chunk_files: List[str] = field(default_factory=list) diff --git a/core/chunker/pipeline.py b/core/chunker/pipeline.py deleted file mode 100644 index 975c249..0000000 --- a/core/chunker/pipeline.py +++ /dev/null @@ -1,279 +0,0 @@ -""" -Pipeline — orchestrates the entire chunker pipeline. 
- -Wires: Chunker → ChunkQueue → WorkerPool → ResultCollector → PipelineResult - -Demonstrates: -- Function parameters and defaults (Interview Topic 1) — configurable pipeline -- Concurrency (Interview Topic 2) — producer thread + worker pool -- OOP design (Interview Topic 4) — composition of pipeline components -- Exception handling (Interview Topic 7) — graceful error propagation -""" - -import json -import logging -import threading -import time -from pathlib import Path -from typing import Any, Callable, Dict, Optional - -from .chunker import Chunker -from .collector import ResultCollector -from .exceptions import PipelineError -from .models import PipelineResult -from .pool import WorkerPool -from .queue import ChunkQueue - -logger = logging.getLogger(__name__) - - -class Pipeline: - """ - Orchestrates the chunk processing pipeline. - - The pipeline runs in three stages: - 1. Producer thread: Chunker probes file → pushes time-based chunks to ChunkQueue - 2. Worker pool: N workers pull from queue → extract mp4 segments → emit results - 3. 
Collector: ResultCollector reassembles results in sequence order - - Args: - source: Path to the source media file - chunk_duration: Duration of each chunk in seconds (default: 10.0) - num_workers: Number of concurrent worker threads (default: 4) - max_retries: Max retry attempts per chunk (default: 3) - processor_type: Processor to use — "ffmpeg", "checksum", "simulated_decode", "composite" - queue_size: Max chunks buffered in queue (default: 10) - event_callback: Optional callback for real-time events - output_dir: Directory for output chunk files (required for "ffmpeg" processor) - """ - - def __init__( - self, - source: str, - chunk_duration: float = 10.0, - num_workers: int = 4, - max_retries: int = 3, - processor_type: str = "checksum", - queue_size: int = 10, - event_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None, - output_dir: Optional[str] = None, - start_time: Optional[float] = None, - end_time: Optional[float] = None, - ): - self.source = source - self.chunk_duration = chunk_duration - self.num_workers = num_workers - self.max_retries = max_retries - self.processor_type = processor_type - self.queue_size = queue_size - self.event_callback = event_callback - self.output_dir = output_dir - self.start_time = start_time - self.end_time = end_time - - def _emit(self, event_type: str, data: Dict[str, Any]) -> None: - """Emit an event if callback is registered.""" - if self.event_callback: - self.event_callback(event_type, data) - - def _produce_chunks( - self, chunker: Chunker, chunk_queue: ChunkQueue - ) -> None: - """Producer thread: probe file and enqueue time-based chunks.""" - try: - for chunk in chunker.chunks(): - chunk_queue.put(chunk, timeout=30.0) - self._emit("chunk_queued", { - "sequence": chunk.sequence, - "start_time": chunk.start_time, - "end_time": chunk.end_time, - "duration": chunk.duration, - "queue_size": chunk_queue.qsize(), - }) - except Exception as e: - logger.error(f"Producer error: {e}") - 
self._emit("producer_error", {"error": str(e)}) - finally: - chunk_queue.close() - - def _monitor_progress( - self, start_time: float, file_size: int, stop_event: threading.Event - ) -> None: - """Monitor thread: emit pipeline_progress every 500ms.""" - while not stop_event.is_set(): - elapsed = time.monotonic() - start_time - mb = file_size / (1024 * 1024) - self._emit("pipeline_progress", { - "elapsed": round(elapsed, 2), - "throughput_mbps": round(mb / elapsed, 2) if elapsed > 0 else 0, - }) - stop_event.wait(0.5) - - def _write_manifest( - self, result: PipelineResult, source_duration: float - ) -> None: - """Write manifest.json to output_dir with segment metadata.""" - if not self.output_dir: - return - - manifest = { - "source": self.source, - "source_duration": source_duration, - "chunk_duration": self.chunk_duration, - "total_chunks": result.total_chunks, - "processed": result.processed, - "failed": result.failed, - "elapsed_time": result.elapsed_time, - "throughput_mbps": result.throughput_mbps, - "segments": [ - { - "sequence": i, - "file": f"chunk_{i:04d}.mp4", - "start": i * self.chunk_duration, - "end": min( - (i + 1) * self.chunk_duration, source_duration - ), - } - for i in range(result.total_chunks) - if i < result.total_chunks - ], - } - - manifest_path = Path(self.output_dir) / "manifest.json" - manifest_path.write_text(json.dumps(manifest, indent=2)) - logger.info(f"Manifest written to {manifest_path}") - - def run(self) -> PipelineResult: - """ - Execute the full pipeline. 
- - Returns: - PipelineResult with aggregate stats - - Raises: - PipelineError: If the pipeline fails catastrophically - """ - start_time = time.monotonic() - self._emit("pipeline_start", { - "source": self.source, - "chunk_duration": self.chunk_duration, - "num_workers": self.num_workers, - "processor_type": self.processor_type, - }) - - try: - # Stage 1: Set up chunker (probes file for duration) - chunker = Chunker( - self.source, - self.chunk_duration, - start_time=self.start_time, - end_time=self.end_time, - ) - total_chunks = chunker.expected_chunks - - if total_chunks == 0: - self._emit("pipeline_complete", {"total_chunks": 0}) - return PipelineResult(chunks_in_order=True) - - self._emit("pipeline_info", { - "file_size": chunker.file_size, - "source_duration": chunker.source_duration, - "total_chunks": total_chunks, - }) - - # Stage 2: Set up queue and worker pool - chunk_queue = ChunkQueue(maxsize=self.queue_size) - pool = WorkerPool( - num_workers=self.num_workers, - chunk_queue=chunk_queue, - processor_type=self.processor_type, - max_retries=self.max_retries, - event_callback=self.event_callback, - output_dir=self.output_dir, - ) - - # Stage 3: Start workers, monitor, then produce chunks - pool.start() - - monitor_stop = threading.Event() - monitor = threading.Thread( - target=self._monitor_progress, - args=(start_time, chunker.file_size, monitor_stop), - name="progress-monitor", - daemon=True, - ) - monitor.start() - - producer = threading.Thread( - target=self._produce_chunks, - args=(chunker, chunk_queue), - name="chunk-producer", - daemon=True, - ) - producer.start() - - # Stage 4: Wait for all workers to finish - all_results = pool.wait() - producer.join(timeout=5.0) - - # Stop monitor - monitor_stop.set() - monitor.join(timeout=2.0) - - # Stage 5: Collect results in order - collector = ResultCollector(total_chunks) - for r in all_results: - collector.add(r) - self._emit("chunk_collected", { - "sequence": r.sequence, - "success": r.success, - 
"buffered": collector.buffered_count, - "emitted": collector.emitted_count, - }) - - # Build result - elapsed = time.monotonic() - start_time - file_size_mb = chunker.file_size / (1024 * 1024) - throughput = file_size_mb / elapsed if elapsed > 0 else 0.0 - - failed_results = [r for r in all_results if not r.success] - total_retries = sum(r.retries for r in all_results) - chunk_files = [ - r.output_file for r in all_results - if r.success and r.output_file - ] - - result = PipelineResult( - total_chunks=total_chunks, - processed=len(all_results), - failed=len(failed_results), - retries=total_retries, - elapsed_time=elapsed, - throughput_mbps=throughput, - worker_stats=pool.get_worker_stats(), - errors=[r.error for r in failed_results if r.error], - chunks_in_order=collector.is_complete, - output_dir=self.output_dir, - chunk_files=chunk_files, - ) - - # Write manifest if output_dir is set - self._write_manifest(result, chunker.source_duration) - - pool.shutdown() - - self._emit("pipeline_complete", { - "total_chunks": result.total_chunks, - "processed": result.processed, - "failed": result.failed, - "elapsed": result.elapsed_time, - "throughput_mbps": result.throughput_mbps, - }) - - return result - - except PipelineError: - raise - except Exception as e: - self._emit("pipeline_error", {"error": str(e)}) - raise PipelineError(f"Pipeline failed: {e}") from e diff --git a/core/chunker/pool.py b/core/chunker/pool.py deleted file mode 100644 index bc86d04..0000000 --- a/core/chunker/pool.py +++ /dev/null @@ -1,125 +0,0 @@ -""" -WorkerPool — manages N worker threads via ThreadPoolExecutor. - -Demonstrates: Python concurrency — threading (Interview Topic 2). 
-""" - -import logging -import threading -from concurrent.futures import Future, ThreadPoolExecutor -from typing import Any, Callable, Dict, List, Optional - -from .models import ChunkResult -from .processor import ( - ChecksumProcessor, - CompositeProcessor, - FFmpegExtractProcessor, - Processor, - SimulatedDecodeProcessor, -) -from .queue import ChunkQueue -from .worker import Worker - -logger = logging.getLogger(__name__) - - -def create_processor( - processor_type: str = "checksum", - output_dir: Optional[str] = None, -) -> Processor: - """Factory for processor instances.""" - if processor_type == "ffmpeg": - if not output_dir: - raise ValueError("output_dir required for ffmpeg processor") - return FFmpegExtractProcessor(output_dir=output_dir) - elif processor_type == "checksum": - return ChecksumProcessor() - elif processor_type == "simulated_decode": - return SimulatedDecodeProcessor() - elif processor_type == "composite": - return CompositeProcessor([ - ChecksumProcessor(), - SimulatedDecodeProcessor(ms_per_second=50.0), - ]) - else: - raise ValueError(f"Unknown processor type: {processor_type}") - - -class WorkerPool: - """ - Manages N worker threads that process chunks concurrently. 
- - Args: - num_workers: Number of concurrent worker threads (default: 4) - chunk_queue: Shared queue to pull chunks from - processor_type: Type of processor for each worker (default: "checksum") - max_retries: Max retry attempts per chunk (default: 3) - event_callback: Optional callback for real-time events - """ - - def __init__( - self, - num_workers: int = 4, - chunk_queue: Optional[ChunkQueue] = None, - processor_type: str = "checksum", - max_retries: int = 3, - event_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None, - output_dir: Optional[str] = None, - ): - self.num_workers = num_workers - self.chunk_queue = chunk_queue or ChunkQueue() - self.processor_type = processor_type - self.max_retries = max_retries - self.event_callback = event_callback - self.output_dir = output_dir - self.shutdown_event = threading.Event() - self._executor: Optional[ThreadPoolExecutor] = None - self._futures: List[Future] = [] - self._workers: List[Worker] = [] - - def start(self) -> None: - """Start all worker threads.""" - self._executor = ThreadPoolExecutor( - max_workers=self.num_workers, - thread_name_prefix="chunk-worker", - ) - - for i in range(self.num_workers): - worker = Worker( - worker_id=f"worker-{i}", - chunk_queue=self.chunk_queue, - processor=create_processor(self.processor_type, output_dir=self.output_dir), - max_retries=self.max_retries, - event_callback=self.event_callback, - ) - self._workers.append(worker) - future = self._executor.submit(worker.run) - self._futures.append(future) - - logger.info(f"WorkerPool started with {self.num_workers} workers") - - def wait(self) -> List[ChunkResult]: - """Wait for all workers to finish and collect results.""" - all_results = [] - for future in self._futures: - results = future.result() - all_results.extend(results) - return all_results - - def shutdown(self) -> None: - """Signal shutdown and cleanup.""" - self.shutdown_event.set() - self.chunk_queue.close() - if self._executor: - 
self._executor.shutdown(wait=True) - - def get_worker_stats(self) -> Dict[str, Any]: - """Get per-worker statistics.""" - return { - w.worker_id: { - "processed": w.processed_count, - "errors": w.error_count, - "retries": w.retry_count, - } - for w in self._workers - } diff --git a/core/chunker/processor.py b/core/chunker/processor.py deleted file mode 100644 index dd5d772..0000000 --- a/core/chunker/processor.py +++ /dev/null @@ -1,173 +0,0 @@ -""" -Processor ABC and concrete implementations. - -Demonstrates: OOP design principles — ABC, inheritance, composition (Interview Topic 4). -""" - -import hashlib -import time -from abc import ABC, abstractmethod -from pathlib import Path -from typing import List - -from .exceptions import ChunkChecksumError -from .models import Chunk, ChunkResult - - -class Processor(ABC): - """ - Abstract base class for chunk processors. - - Each processor defines how a single chunk is processed. - The Worker calls processor.process(chunk) and handles retries. - """ - - @abstractmethod - def process(self, chunk: Chunk) -> ChunkResult: - """Process a single chunk and return the result.""" - pass - - -class FFmpegExtractProcessor(Processor): - """ - Extracts a time segment from the source file using FFmpeg stream copy. - - Produces a playable mp4 file per chunk — no re-encoding. 
- - Args: - output_dir: Directory to write chunk mp4 files - """ - - def __init__(self, output_dir: str): - self.output_dir = output_dir - Path(output_dir).mkdir(parents=True, exist_ok=True) - - def process(self, chunk: Chunk) -> ChunkResult: - from core.ffmpeg.transcode import TranscodeConfig, transcode - - start = time.monotonic() - - output_file = str( - Path(self.output_dir) / f"chunk_{chunk.sequence:04d}.mp4" - ) - - config = TranscodeConfig( - input_path=chunk.source_path, - output_path=output_file, - video_codec="copy", - audio_codec="copy", - trim_start=chunk.start_time, - trim_end=chunk.end_time, - ) - - transcode(config) - - # Compute checksum of output file - md5 = hashlib.md5() - with open(output_file, "rb") as f: - for block in iter(lambda: f.read(8192), b""): - md5.update(block) - checksum = md5.hexdigest() - - elapsed = time.monotonic() - start - - return ChunkResult( - sequence=chunk.sequence, - success=True, - checksum_valid=True, - processing_time=elapsed, - output_file=output_file, - ) - - -class ChecksumProcessor(Processor): - """ - Validates chunk metadata consistency. - - For time-based chunks, verifies the time range is valid. - Raises ChunkChecksumError on invalid ranges. - """ - - def process(self, chunk: Chunk) -> ChunkResult: - start = time.monotonic() - - valid = chunk.duration > 0 and chunk.end_time > chunk.start_time - - if not valid: - raise ChunkChecksumError( - sequence=chunk.sequence, - expected="valid time range", - actual=f"{chunk.start_time}-{chunk.end_time}", - ) - - elapsed = time.monotonic() - start - - return ChunkResult( - sequence=chunk.sequence, - success=True, - checksum_valid=True, - processing_time=elapsed, - ) - - -class SimulatedDecodeProcessor(Processor): - """ - Simulates decode work by sleeping proportional to chunk duration. - - Useful for demonstrating concurrency behavior without real FFmpeg. 
- - Args: - ms_per_second: Milliseconds of simulated work per second of chunk duration (default: 100) - """ - - def __init__(self, ms_per_second: float = 100.0): - self.ms_per_second = ms_per_second - - def process(self, chunk: Chunk) -> ChunkResult: - start = time.monotonic() - - sleep_time = (self.ms_per_second * chunk.duration) / 1000.0 - time.sleep(sleep_time) - - elapsed = time.monotonic() - start - - return ChunkResult( - sequence=chunk.sequence, - success=True, - checksum_valid=True, - processing_time=elapsed, - ) - - -class CompositeProcessor(Processor): - """ - Chains multiple processors — runs each in sequence on the same chunk. - - Demonstrates OOP composition pattern. - - Args: - processors: List of processors to chain - """ - - def __init__(self, processors: List[Processor]): - if not processors: - raise ValueError("CompositeProcessor requires at least one processor") - self.processors = processors - - def process(self, chunk: Chunk) -> ChunkResult: - start = time.monotonic() - last_result = None - - for proc in self.processors: - last_result = proc.process(chunk) - if not last_result.success: - return last_result - - elapsed = time.monotonic() - start - - return ChunkResult( - sequence=chunk.sequence, - success=True, - checksum_valid=last_result.checksum_valid if last_result else True, - processing_time=elapsed, - ) diff --git a/core/chunker/queue.py b/core/chunker/queue.py deleted file mode 100644 index 191a219..0000000 --- a/core/chunker/queue.py +++ /dev/null @@ -1,76 +0,0 @@ -""" -ChunkQueue — bounded, thread-safe queue with sentinel-based shutdown. - -Demonstrates: Core data structures — queue.Queue (Interview Topic 5). -""" - -import queue -from typing import Optional - -from .models import Chunk - -# Sentinel value to signal workers to stop -_SENTINEL = object() - - -class ChunkQueue: - """ - Thread-safe bounded queue for chunks. - - Provides backpressure: producers block when the queue is full, - preventing unbounded memory usage. 
- - Args: - maxsize: Maximum number of chunks in the queue (default: 10) - """ - - def __init__(self, maxsize: int = 10): - self._queue: queue.Queue = queue.Queue(maxsize=maxsize) - self._closed = False - self.maxsize = maxsize - - def put(self, chunk: Chunk, timeout: Optional[float] = None) -> None: - """ - Add a chunk to the queue. Blocks if full (backpressure). - - Args: - chunk: The chunk to enqueue - timeout: Max seconds to wait (None = block forever) - - Raises: - queue.Full: If timeout expires while queue is full - """ - self._queue.put(chunk, timeout=timeout) - - def get(self, timeout: Optional[float] = None) -> Optional[Chunk]: - """ - Get next chunk from queue. Returns None if queue is closed. - - Args: - timeout: Max seconds to wait (None = block forever) - - Returns: - Chunk or None (if sentinel received, meaning queue is closed) - - Raises: - queue.Empty: If timeout expires while queue is empty - """ - item = self._queue.get(timeout=timeout) - if item is _SENTINEL: - # Re-put sentinel so other workers also see it - self._queue.put(_SENTINEL) - return None - return item - - def close(self) -> None: - """Signal all consumers to stop by inserting a sentinel.""" - self._closed = True - self._queue.put(_SENTINEL) - - @property - def is_closed(self) -> bool: - return self._closed - - def qsize(self) -> int: - """Current number of items in the queue (approximate).""" - return self._queue.qsize() diff --git a/core/chunker/worker.py b/core/chunker/worker.py deleted file mode 100644 index de094ca..0000000 --- a/core/chunker/worker.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Worker — pulls chunks from queue, processes with retry logic. 
- -Demonstrates: -- Exception handling and resilient code (Interview Topic 7) -- Concurrency (Interview Topic 2) — workers run in thread pool -""" - -import logging -import queue -import time -from typing import Any, Callable, Dict, Optional - -from .exceptions import ProcessorFailureError -from .models import Chunk, ChunkResult -from .processor import Processor -from .queue import ChunkQueue - -logger = logging.getLogger(__name__) - - -class Worker: - """ - Processes chunks from a queue with retry and exponential backoff. - - Args: - worker_id: Identifier for this worker (e.g. "worker-0") - chunk_queue: Source queue to pull chunks from - processor: Processor instance to use - max_retries: Maximum retry attempts per chunk (default: 3) - event_callback: Optional callback for real-time status updates - """ - - def __init__( - self, - worker_id: str, - chunk_queue: ChunkQueue, - processor: Processor, - max_retries: int = 3, - event_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None, - ): - self.worker_id = worker_id - self.chunk_queue = chunk_queue - self.processor = processor - self.max_retries = max_retries - self.event_callback = event_callback - self.processed_count = 0 - self.error_count = 0 - self.retry_count = 0 - - def _emit(self, event_type: str, data: Dict[str, Any]) -> None: - """Emit an event if callback is registered.""" - if self.event_callback: - self.event_callback(event_type, {"worker_id": self.worker_id, **data}) - - def _process_with_retry(self, chunk: Chunk) -> ChunkResult: - """ - Process a chunk with exponential backoff retry. - - Retry delays: 0.1s, 0.2s, 0.4s, ... 
(doubles each attempt) - """ - last_error = None - - for attempt in range(self.max_retries + 1): - try: - if attempt > 0: - backoff = 0.1 * (2 ** (attempt - 1)) - self._emit("chunk_retry", { - "sequence": chunk.sequence, - "attempt": attempt, - "backoff": backoff, - }) - time.sleep(backoff) - self.retry_count += 1 - - result = self.processor.process(chunk) - result.retries = attempt - result.worker_id = self.worker_id - return result - - except Exception as e: - last_error = e - logger.warning( - f"{self.worker_id}: chunk {chunk.sequence} " - f"attempt {attempt + 1}/{self.max_retries + 1} failed: {e}" - ) - - # All retries exhausted - self.error_count += 1 - self._emit("chunk_error", { - "sequence": chunk.sequence, - "error": str(last_error), - "retries": self.max_retries, - }) - - return ChunkResult( - sequence=chunk.sequence, - success=False, - processing_time=0.0, - error=str(last_error), - retries=self.max_retries, - worker_id=self.worker_id, - ) - - def run(self) -> list[ChunkResult]: - """ - Main worker loop — pull chunks and process until queue is closed. 
- - Returns: - List of ChunkResults processed by this worker - """ - results = [] - self._emit("worker_status", {"state": "idle"}) - - while True: - try: - chunk = self.chunk_queue.get(timeout=1.0) - except queue.Empty: - continue - - if chunk is None: # Sentinel received - break - - self._emit("chunk_processing", { - "sequence": chunk.sequence, - "state": "processing", - "queue_size": self.chunk_queue.qsize(), - }) - - result = self._process_with_retry(chunk) - results.append(result) - self.processed_count += 1 - - self._emit("chunk_done", { - "sequence": chunk.sequence, - "success": result.success, - "processing_time": result.processing_time, - "retries": result.retries, - "queue_size": self.chunk_queue.qsize(), - }) - - self._emit("worker_status", {"state": "stopped"}) - return results diff --git a/core/db/__init__.py b/core/db/__init__.py index 971b297..172a776 100644 --- a/core/db/__init__.py +++ b/core/db/__init__.py @@ -13,7 +13,7 @@ Basic CRUD (create, get, update, delete) goes directly through the session: from .connection import get_session, create_tables -from .tables import MediaAsset, Job, Timeline, Checkpoint, Brand +from .models import MediaAsset, Job, Timeline, Checkpoint, Brand from .assets import list_assets, get_asset_filenames from .job import list_jobs diff --git a/core/db/assets.py b/core/db/assets.py index b0972c1..765e17a 100644 --- a/core/db/assets.py +++ b/core/db/assets.py @@ -7,7 +7,7 @@ from uuid import UUID from sqlmodel import Session, select -from .tables import MediaAsset +from .models import MediaAsset def list_assets(session: Session, status: Optional[str] = None, search: Optional[str] = None) -> list[MediaAsset]: diff --git a/core/db/brand.py b/core/db/brand.py index c8a9ee9..325c724 100644 --- a/core/db/brand.py +++ b/core/db/brand.py @@ -7,7 +7,7 @@ from uuid import UUID from sqlmodel import Session, select -from .tables import Brand +from .models import Brand def get_or_create_brand(session: Session, canonical_name: str, diff 
--git a/core/db/checkpoint.py b/core/db/checkpoint.py index 1669b6b..6561b46 100644 --- a/core/db/checkpoint.py +++ b/core/db/checkpoint.py @@ -6,7 +6,7 @@ from uuid import UUID from sqlmodel import Session, select -from .tables import Checkpoint +from .models import Checkpoint def get_latest_checkpoint(session: Session, timeline_id: UUID, parent_id: UUID | None = None) -> Checkpoint | None: diff --git a/core/db/connection.py b/core/db/connection.py index a72b213..885c4f6 100644 --- a/core/db/connection.py +++ b/core/db/connection.py @@ -30,5 +30,5 @@ def get_session() -> Session: def create_tables(): """Create all SQLModel tables.""" from sqlmodel import SQLModel - from . import tables # noqa — registers all table classes + from . import models # noqa — registers all table classes SQLModel.metadata.create_all(get_engine()) diff --git a/core/db/fixtures/soccer_broadcast.json b/core/db/fixtures/soccer_broadcast.json new file mode 100644 index 0000000..cad76eb --- /dev/null +++ b/core/db/fixtures/soccer_broadcast.json @@ -0,0 +1,142 @@ +{ + "name": "soccer_broadcast", + "pipeline": { + "name": "soccer_broadcast", + "profile_name": "soccer_broadcast", + "stages": [ + { + "name": "extract_frames", + "branch": "trunk" + }, + { + "name": "filter_scenes", + "branch": "trunk" + }, + { + "name": "field_segmentation", + "branch": "trunk" + }, + { + "name": "detect_edges", + "branch": "hoarding" + }, + { + "name": "detect_objects", + "branch": "objects" + }, + { + "name": "preprocess" + }, + { + "name": "run_ocr" + }, + { + "name": "match_brands" + }, + { + "name": "escalate_vlm" + }, + { + "name": "escalate_cloud" + }, + { + "name": "compile_report" + } + ], + "edges": [ + { + "source": "extract_frames", + "target": "filter_scenes" + }, + { + "source": "filter_scenes", + "target": "field_segmentation" + }, + { + "source": "field_segmentation", + "target": "detect_edges" + }, + { + "source": "field_segmentation", + "target": "detect_objects" + }, + { + "source": 
"detect_edges", + "target": "preprocess" + }, + { + "source": "detect_objects", + "target": "preprocess" + }, + { + "source": "preprocess", + "target": "run_ocr" + }, + { + "source": "run_ocr", + "target": "match_brands" + }, + { + "source": "match_brands", + "target": "escalate_vlm" + }, + { + "source": "escalate_vlm", + "target": "escalate_cloud" + }, + { + "source": "escalate_cloud", + "target": "compile_report" + } + ] + }, + "configs": { + "extract_frames": { + "fps": 2.0, + "max_frames": 500 + }, + "filter_scenes": { + "hamming_threshold": 8, + "enabled": true + }, + "field_segmentation": { + "enabled": true, + "hue_low": 30, + "hue_high": 85, + "sat_low": 30, + "sat_high": 255, + "val_low": 30, + "val_high": 255, + "morph_kernel": 15, + "min_area_ratio": 0.05 + }, + "detect_edges": { + "enabled": true, + "edge_canny_low": 50, + "edge_canny_high": 150, + "edge_hough_threshold": 80, + "edge_hough_min_length": 100, + "edge_hough_max_gap": 10, + "edge_pair_max_distance": 200, + "edge_pair_min_distance": 15 + }, + "detect_objects": { + "model_name": "yolov8n.pt", + "confidence_threshold": 0.3, + "target_classes": [] + }, + "run_ocr": { + "languages": [ + "en", + "es" + ], + "min_confidence": 0.5 + }, + "match_brands": { + "fuzzy_threshold": 75 + }, + "escalate_vlm": { + "vlm_prompt_template": "Identify the brand or sponsor visible in this cropped region from a soccer broadcast.{hint}{text} Respond with: brand, confidence (0-1), reasoning." 
+ } + } +} diff --git a/core/db/job.py b/core/db/job.py index 0d17927..2f0bc6c 100644 --- a/core/db/job.py +++ b/core/db/job.py @@ -7,7 +7,7 @@ from uuid import UUID from sqlmodel import Session, select -from .tables import Job +from .models import Job def list_jobs(session: Session, parent_id: Optional[UUID] = None, status: Optional[str] = None) -> list[Job]: diff --git a/core/db/models.py b/core/db/models.py index ee8a4fd..dbcf8d1 100644 --- a/core/db/models.py +++ b/core/db/models.py @@ -44,7 +44,7 @@ class SourceType(str, Enum): class MediaAsset(SQLModel, table=True): """A video/audio file registered in the system.""" - __tablename__ = "media_assets" + __tablename__ = "media_asset" id: UUID = Field(default_factory=uuid4, primary_key=True) filename: str @@ -67,7 +67,7 @@ class MediaAsset(SQLModel, table=True): class TranscodePreset(SQLModel, table=True): """A reusable transcoding configuration (like Handbrake presets).""" - __tablename__ = "transcode_presets" + __tablename__ = "transcode_preset" id: UUID = Field(default_factory=uuid4, primary_key=True) name: str @@ -90,12 +90,13 @@ class TranscodePreset(SQLModel, table=True): class Job(SQLModel, table=True): """A pipeline job.""" - __tablename__ = "jobs" + __tablename__ = "job" id: UUID = Field(default_factory=uuid4, primary_key=True) source_asset_id: UUID = Field(index=True) video_path: str profile_name: str = "soccer_broadcast" + timeline_id: Optional[UUID] = None parent_id: Optional[UUID] = None run_type: RunType = "initial" config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}')) @@ -107,7 +108,6 @@ class Job(SQLModel, table=True): brands_found: int = 0 cloud_llm_calls: int = 0 estimated_cost_usd: float = 0.0 - celery_task_id: Optional[str] = None priority: int = 0 created_at: Optional[datetime] = Field(default_factory=datetime.utcnow) started_at: Optional[datetime] = None @@ -115,7 +115,7 @@ class Job(SQLModel, table=True): class 
Timeline(SQLModel, table=True): """The frame sequence from a source video.""" - __tablename__ = "timelines" + __tablename__ = "timeline" id: UUID = Field(default_factory=uuid4, primary_key=True) source_asset_id: Optional[UUID] = Field(default=None, index=True) @@ -129,10 +129,11 @@ class Timeline(SQLModel, table=True): class Checkpoint(SQLModel, table=True): """A snapshot of pipeline state on a timeline.""" - __tablename__ = "checkpoints" + __tablename__ = "checkpoint" id: UUID = Field(default_factory=uuid4, primary_key=True) timeline_id: UUID + job_id: Optional[UUID] = Field(default=None, index=True) parent_id: Optional[UUID] = None stage_outputs: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}')) config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}')) @@ -143,7 +144,7 @@ class Checkpoint(SQLModel, table=True): class Brand(SQLModel, table=True): """A brand discovered or registered in the system.""" - __tablename__ = "brands" + __tablename__ = "brand" id: UUID = Field(default_factory=uuid4, primary_key=True) canonical_name: str = Field(index=True) @@ -154,3 +155,12 @@ class Brand(SQLModel, table=True): total_airings: int = 0 created_at: Optional[datetime] = Field(default_factory=datetime.utcnow) updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow) + +class Profile(SQLModel, table=True): + """A content type profile.""" + __tablename__ = "profile" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + name: str + pipeline: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}')) + configs: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}')) diff --git a/core/db/seed.py b/core/db/seed.py new file mode 100644 index 0000000..8df375c --- /dev/null +++ b/core/db/seed.py @@ -0,0 +1,43 @@ +""" +Seed data — insert initial 
profile rows if they don't exist. + +Called on startup after create_tables(). +""" + +import json +import logging +from pathlib import Path + +logger = logging.getLogger(__name__) + +SEED_DIR = Path(__file__).parent / "fixtures" + + +def seed_profiles(): + """Insert seed profiles from JSON fixtures if not already present.""" + from .connection import get_session + from .models import Profile + + fixtures = list(SEED_DIR.glob("*.json")) + if not fixtures: + return + + with get_session() as session: + for f in fixtures: + data = json.loads(f.read_text()) + name = data["name"] + + existing = session.query(Profile).filter(Profile.name == name).first() + if existing: + logger.debug("Profile %s already exists, skipping seed", name) + continue + + profile = Profile( + name=name, + pipeline=data.get("pipeline", {}), + configs=data.get("configs", {}), + ) + session.add(profile) + logger.info("Seeded profile: %s", name) + + session.commit() diff --git a/core/db/tables.py b/core/db/tables.py deleted file mode 100644 index c6f59c6..0000000 --- a/core/db/tables.py +++ /dev/null @@ -1,96 +0,0 @@ -""" -SQLModel table definitions. - -Generated by modelgen from core/schema/models/. Do not edit directly. 
-""" - -from __future__ import annotations - -from datetime import datetime -from typing import Any, Dict, List, Optional -from uuid import UUID, uuid4 - -from sqlalchemy import JSON -from sqlmodel import Column, Field, SQLModel - - -class MediaAsset(SQLModel, table=True): - __tablename__ = "media_asset" - - id: UUID = Field(default_factory=uuid4, primary_key=True) - filename: str - path: str - status: str = "pending" - size_bytes: int = 0 - duration_seconds: float = 0.0 - width: Optional[int] = None - height: Optional[int] = None - fps: Optional[float] = None - codec: Optional[str] = None - created_at: Optional[datetime] = Field(default_factory=datetime.utcnow) - - -class Job(SQLModel, table=True): - __tablename__ = "job" - - id: UUID = Field(default_factory=uuid4, primary_key=True) - source_asset_id: UUID = Field(index=True) - video_path: str - profile_name: str = "soccer_broadcast" - parent_id: Optional[UUID] = Field(default=None, index=True) - run_type: str = "initial" - config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}')) - status: str = "pending" - current_stage: Optional[str] = None - progress: float = 0.0 - error_message: Optional[str] = None - total_detections: int = 0 - brands_found: int = 0 - cloud_llm_calls: int = 0 - estimated_cost_usd: float = 0.0 - priority: int = 0 - created_at: Optional[datetime] = Field(default_factory=datetime.utcnow) - started_at: Optional[datetime] = None - completed_at: Optional[datetime] = None - - -class Timeline(SQLModel, table=True): - __tablename__ = "timeline" - - id: UUID = Field(default_factory=uuid4, primary_key=True) - source_asset_id: Optional[UUID] = Field(default=None, index=True) - source_video: str = "" - profile_name: str = "" - fps: float = 2.0 - frames_prefix: str = "" - frames_manifest: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}')) - frames_meta: List[str] = 
Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]')) - created_at: Optional[datetime] = Field(default_factory=datetime.utcnow) - - -class Checkpoint(SQLModel, table=True): - __tablename__ = "checkpoint" - - id: UUID = Field(default_factory=uuid4, primary_key=True) - timeline_id: UUID = Field(index=True) - parent_id: Optional[UUID] = Field(default=None, index=True) - stage_outputs: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}')) - config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}')) - stats: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}')) - is_scenario: bool = False - scenario_label: str = "" - created_at: Optional[datetime] = Field(default_factory=datetime.utcnow) - - -class Brand(SQLModel, table=True): - __tablename__ = "brand" - - id: UUID = Field(default_factory=uuid4, primary_key=True) - canonical_name: str = Field(index=True) - aliases: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]')) - source: str = "ocr" - confirmed: bool = False - airings: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]')) - total_airings: int = 0 - created_at: Optional[datetime] = Field(default_factory=datetime.utcnow) - updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow) diff --git a/detect/__init__.py b/core/detect/__init__.py similarity index 100% rename from detect/__init__.py rename to core/detect/__init__.py diff --git a/detect/checkpoint/__init__.py b/core/detect/checkpoint/__init__.py similarity index 95% rename from detect/checkpoint/__init__.py rename to core/detect/checkpoint/__init__.py index 32b1c97..a2741d0 100644 --- a/detect/checkpoint/__init__.py +++ b/core/detect/checkpoint/__init__.py @@ -16,4 +16,4 @@ from .storage import ( 
load_stage_output, ) from .frames import save_frames, load_frames -from .runner_bridge import checkpoint_after_stage, reset_checkpoint_state +from .runner_bridge import checkpoint_after_stage, reset_checkpoint_state, get_timeline_id diff --git a/detect/checkpoint/frames.py b/core/detect/checkpoint/frames.py similarity index 98% rename from detect/checkpoint/frames.py rename to core/detect/checkpoint/frames.py index ec2e0fa..c1d81ca 100644 --- a/detect/checkpoint/frames.py +++ b/core/detect/checkpoint/frames.py @@ -9,7 +9,7 @@ import tempfile import numpy as np from PIL import Image -from detect.models import Frame +from core.detect.models import Frame logger = logging.getLogger(__name__) diff --git a/detect/checkpoint/replay.py b/core/detect/checkpoint/replay.py similarity index 76% rename from detect/checkpoint/replay.py rename to core/detect/checkpoint/replay.py index 785d422..59fa962 100644 --- a/detect/checkpoint/replay.py +++ b/core/detect/checkpoint/replay.py @@ -11,7 +11,7 @@ import logging import uuid -from detect import emit +from core.detect import emit # TODO: migrate to Timeline/Branch/Checkpoint model # These old functions no longer exist — replay needs rework def _not_migrated(*args, **kwargs): @@ -19,66 +19,14 @@ def _not_migrated(*args, **kwargs): load_checkpoint = _not_migrated list_checkpoints = _not_migrated -from detect.graph import NODES, build_graph +from core.detect.graph import NODES, build_graph logger = logging.getLogger(__name__) -class OverrideProfile: - """ - Wraps a ContentTypeProfile and patches config methods with overrides. 
- Override dict structure: - { - "frame_extraction": {"fps": 1.0}, - "scene_filter": {"hamming_threshold": 12}, - "region_analysis": {"edge_canny_low": 30, "edge_canny_high": 120}, - "detection": {"confidence_threshold": 0.5}, - "ocr": {"languages": ["en", "es"], "min_confidence": 0.3}, - "resolver": {"fuzzy_threshold": 60}, - } - """ - - def __init__(self, base, overrides: dict): - self._base = base - self._overrides = overrides - - def __getattr__(self, name): - return getattr(self._base, name) - - def _patch(self, config, key: str): - patches = self._overrides.get(key, {}) - for k, v in patches.items(): - if hasattr(config, k): - setattr(config, k, v) - return config - - def frame_extraction_config(self): - return self._patch(self._base.frame_extraction_config(), "frame_extraction") - - def scene_filter_config(self): - return self._patch(self._base.scene_filter_config(), "scene_filter") - - def region_analysis_config(self): - return self._patch(self._base.region_analysis_config(), "region_analysis") - - def detection_config(self): - return self._patch(self._base.detection_config(), "detection") - - def ocr_config(self): - return self._patch(self._base.ocr_config(), "ocr") - - def resolver_config(self): - return self._patch(self._base.resolver_config(), "resolver") - - def vlm_prompt(self, crop_context): - return self._base.vlm_prompt(crop_context) - - def aggregate(self, detections): - return self._base.aggregate(detections) - - def auxiliary_detections(self, source): - return self._base.auxiliary_detections(source) +# OverrideProfile removed — config overrides are now handled by dict merging +# in _load_profile() (nodes.py) and replay_single_stage (below). 
def replay_from( @@ -183,10 +131,16 @@ def replay_single_stage( state = load_checkpoint(job_id, previous_stage) # Build profile with overrides - from detect.profiles import get_profile + from core.detect.profile import get_profile, get_stage_config profile = get_profile(state.get("profile_name", "soccer_broadcast")) if config_overrides: - profile = OverrideProfile(profile, config_overrides) + merged_configs = dict(profile.get("configs", {})) + for sname, soverrides in config_overrides.items(): + if sname in merged_configs: + merged_configs[sname] = {**merged_configs[sname], **soverrides} + else: + merged_configs[sname] = soverrides + profile = {**profile, "configs": merged_configs} # Run the stage function directly (not through the graph) if stage == "detect_edges": @@ -207,9 +161,11 @@ def _replay_detect_edges( ) -> dict: """Run edge detection on checkpoint frames, optionally with debug overlays.""" import os - from detect.stages.edge_detector import detect_edge_regions + from core.detect.stages.edge_detector import detect_edge_regions - config = profile.region_analysis_config() + from core.detect.profile import get_stage_config + from core.detect.stages.models import RegionAnalysisConfig + config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges")) frames = state.get("filtered_frames", []) if frame_refs: @@ -231,7 +187,7 @@ def _replay_detect_edges( if debug and frames: debug_data = {} if inference_url: - from detect.inference import InferenceClient + from core.detect.inference import InferenceClient client = InferenceClient(base_url=inference_url, job_id=job_id) for frame in frames: dr = client.detect_edges_debug( @@ -252,7 +208,7 @@ def _replay_detect_edges( } else: # Local mode — import GPU module directly - from detect.stages.edge_detector import _load_cv_edges + from core.detect.stages.edge_detector import _load_cv_edges edges_mod = _load_cv_edges() for frame in frames: dr = edges_mod.detect_edges_debug( diff --git 
a/core/detect/checkpoint/runner_bridge.py b/core/detect/checkpoint/runner_bridge.py new file mode 100644 index 0000000..831b921 --- /dev/null +++ b/core/detect/checkpoint/runner_bridge.py @@ -0,0 +1,97 @@ +""" +Runner bridge — checkpoint hook called by PipelineRunner after each stage. + +Owns the per-job state (timeline, frame manifest, checkpoint chain) that +the runner shouldn't know about. + +Timeline and Job are independent entities: +- One Timeline can serve multiple Jobs (re-run with different params) +- One Job operates on one Timeline (set after frame extraction) +- Checkpoints belong to Timeline, tagged with the Job that created them +""" + +from __future__ import annotations + +import logging + +logger = logging.getLogger(__name__) + +# Per-job state +_timeline_id: dict[str, str] = {} +_frames_manifest: dict[str, dict[int, str]] = {} +_latest_checkpoint: dict[str, str] = {} + + +def reset_checkpoint_state(job_id: str): + """Clean up per-job checkpoint state. Called when pipeline finishes.""" + _timeline_id.pop(job_id, None) + _frames_manifest.pop(job_id, None) + _latest_checkpoint.pop(job_id, None) + + +def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: dict): + """ + Save a checkpoint after a stage completes. + + Called by the runner. 
Handles: + - Timeline creation (once, on extract_frames) + - Frame upload (via create_timeline) + - Stage output serialization (via stage registry) + - Checkpoint chain (parent → child) + """ + if not job_id: + return + + from .storage import create_timeline, save_stage_output + from core.detect.stages.base import _REGISTRY + + merged = {**state, **result} + + # On extract_frames: create Timeline + upload frames + root checkpoint + if stage_name == "extract_frames" and job_id not in _timeline_id: + frames = merged.get("frames", []) + video_path = merged.get("video_path", "") + profile_name = merged.get("profile_name", "") + + tid, cid = create_timeline( + source_video=video_path, + profile_name=profile_name, + frames=frames, + ) + _timeline_id[job_id] = tid + _latest_checkpoint[job_id] = cid + logger.info("Job %s → Timeline %s (root checkpoint %s)", job_id, tid, cid) + + # Emit timeline_id via SSE so the UI can use it for checkpoint loads + from core.detect import emit + emit.log(job_id, "Checkpoint", "INFO", f"timeline_id={tid}") + return + + # For subsequent stages: save checkpoint on the timeline + tid = _timeline_id.get(job_id) + if not tid: + logger.warning("No timeline for job %s, skipping checkpoint", job_id) + return + + # Serialize stage output using the stage's serialize_fn if available + stage_cls = _REGISTRY.get(stage_name) + serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None) + if serialize_fn: + output_json = serialize_fn(merged, job_id) + else: + output_json = {} + + parent_id = _latest_checkpoint.get(job_id) + new_checkpoint_id = save_stage_output( + timeline_id=tid, + parent_checkpoint_id=parent_id, + stage_name=stage_name, + output_json=output_json, + job_id=job_id, + ) + _latest_checkpoint[job_id] = new_checkpoint_id + + +def get_timeline_id(job_id: str) -> str | None: + """Get the timeline_id for a running job. 
Used by the UI to load checkpoints.""" + return _timeline_id.get(job_id) diff --git a/detect/checkpoint/serializer.py b/core/detect/checkpoint/serializer.py similarity index 97% rename from detect/checkpoint/serializer.py rename to core/detect/checkpoint/serializer.py index 466b1eb..0894a3d 100644 --- a/detect/checkpoint/serializer.py +++ b/core/detect/checkpoint/serializer.py @@ -28,7 +28,7 @@ def serialize_state(state: dict, frames_manifest: dict[int, str]) -> dict: Calls each registered stage's serialize_fn for stage-owned data. Envelope fields (job_id, etc.) are copied directly. """ - from detect.stages.base import _REGISTRY + from core.detect.stages.base import _REGISTRY checkpoint = {} @@ -64,7 +64,7 @@ def deserialize_state(checkpoint: dict, frames: list) -> dict: Calls each stage's deserialize_fn to restore stage-owned data. """ - from detect.stages.base import _REGISTRY + from core.detect.stages.base import _REGISTRY frame_map = {f.sequence: f for f in frames} diff --git a/detect/checkpoint/storage.py b/core/detect/checkpoint/storage.py similarity index 94% rename from detect/checkpoint/storage.py rename to core/detect/checkpoint/storage.py index f0b6aad..1bdc3c9 100644 --- a/detect/checkpoint/storage.py +++ b/core/detect/checkpoint/storage.py @@ -33,7 +33,7 @@ def create_timeline( Returns (timeline_id, checkpoint_id). 
""" - from core.db.tables import Timeline, Checkpoint + from core.db.models import Timeline, Checkpoint from core.db.connection import get_session with get_session() as session: @@ -81,7 +81,7 @@ def create_timeline( def get_timeline_frames(timeline_id: str) -> list: """Load frames from a timeline (from MinIO) as Frame objects.""" - from core.db.tables import Timeline + from core.db.models import Timeline from core.db.connection import get_session with get_session() as session: @@ -96,7 +96,7 @@ def get_timeline_frames(timeline_id: str) -> list: def get_timeline_frames_b64(timeline_id: str) -> list[dict]: """Load frames as base64 JPEG (lightweight, no numpy).""" - from core.db.tables import Timeline + from core.db.models import Timeline from core.db.connection import get_session from .frames import load_frames_b64 @@ -123,6 +123,7 @@ def save_stage_output( stats: dict | None = None, is_scenario: bool = False, scenario_label: str = "", + job_id: str | None = None, ) -> str: """ Save a stage's output as a new checkpoint (child of parent). @@ -130,7 +131,7 @@ def save_stage_output( Carries forward stage outputs from parent + adds the new one. Returns the new checkpoint ID. 
""" - from core.db.tables import Checkpoint + from core.db.models import Checkpoint from core.db.connection import get_session with get_session() as session: @@ -146,6 +147,7 @@ def save_stage_output( checkpoint = Checkpoint( timeline_id=UUID(timeline_id), + job_id=UUID(job_id) if job_id else None, parent_id=UUID(parent_checkpoint_id) if parent_checkpoint_id else None, stage_outputs={**parent_outputs, stage_name: output_json}, config_overrides={**parent_config, **(config_overrides or {})}, @@ -165,7 +167,7 @@ def save_stage_output( def load_stage_output(checkpoint_id: str, stage_name: str) -> dict | None: """Load a stage's output from a checkpoint.""" - from core.db.tables import Checkpoint + from core.db.models import Checkpoint from core.db.connection import get_session with get_session() as session: diff --git a/detect/emit.py b/core/detect/emit.py similarity index 97% rename from detect/emit.py rename to core/detect/emit.py index 6b0b0af..29d0483 100644 --- a/detect/emit.py +++ b/core/detect/emit.py @@ -16,8 +16,8 @@ from __future__ import annotations import dataclasses from datetime import datetime, timezone -from detect.events import push_detect_event -from detect.models import PipelineStats +from core.detect.events import push_detect_event +from core.detect.models import PipelineStats # Log level ordering for comparison _LEVEL_ORDER = {"DEBUG": 0, "INFO": 1, "WARN": 2, "ERROR": 3} diff --git a/detect/events.py b/core/detect/events.py similarity index 100% rename from detect/events.py rename to core/detect/events.py diff --git a/detect/graph/__init__.py b/core/detect/graph/__init__.py similarity index 100% rename from detect/graph/__init__.py rename to core/detect/graph/__init__.py diff --git a/detect/graph/events.py b/core/detect/graph/events.py similarity index 90% rename from detect/graph/events.py rename to core/detect/graph/events.py index c620955..794cb61 100644 --- a/detect/graph/events.py +++ b/core/detect/graph/events.py @@ -4,8 +4,8 @@ Graph event 
emission — node state tracking + SSE graph_update events. from __future__ import annotations -from detect import emit -from detect.state import DetectState +from core.detect import emit +from core.detect.state import DetectState # Track node states across pipeline runs diff --git a/detect/graph/nodes.py b/core/detect/graph/nodes.py similarity index 64% rename from detect/graph/nodes.py rename to core/detect/graph/nodes.py index f913aec..76737ce 100644 --- a/detect/graph/nodes.py +++ b/core/detect/graph/nodes.py @@ -1,28 +1,39 @@ """ Pipeline node functions — one per stage. -Each node: reads state, runs stage logic, emits transitions, returns output dict. +Each node: reads state, gets config from profile dict, runs stage logic, +emits transitions, returns output dict. """ from __future__ import annotations import os -from detect import emit -from detect.models import PipelineStats -from detect.profiles import SoccerBroadcastProfile -from detect.state import DetectState -from detect.stages.frame_extractor import extract_frames -from detect.stages.scene_filter import scene_filter -from detect.stages.edge_detector import detect_edge_regions -from detect.stages.yolo_detector import detect_objects -from detect.stages.preprocess import preprocess_regions -from detect.stages.ocr_stage import run_ocr -from detect.stages.brand_resolver import resolve_brands -from detect.stages.vlm_local import escalate_vlm -from detect.stages.vlm_cloud import escalate_cloud -from detect.stages.aggregator import compile_report -from detect.tracing import trace_node, flush as flush_traces +from core.detect import emit +from core.detect.models import CropContext, PipelineStats +from core.detect.profile import get_profile, get_stage_config, build_vlm_prompt, aggregate_detections +from core.detect.stages.models import ( + DetectionConfig, + FieldSegmentationConfig, + FrameExtractionConfig, + OCRConfig, + RegionAnalysisConfig, + ResolverConfig, + SceneFilterConfig, +) +from core.detect.state 
import DetectState +from core.detect.stages.frame_extractor import extract_frames +from core.detect.stages.scene_filter import scene_filter +from core.detect.stages.field_segmentation import run_field_segmentation +from core.detect.stages.edge_detector import detect_edge_regions +from core.detect.stages.yolo_detector import detect_objects +from core.detect.stages.preprocess import preprocess_regions +from core.detect.stages.ocr_stage import run_ocr +from core.detect.stages.brand_resolver import resolve_brands +from core.detect.stages.vlm_local import escalate_vlm +from core.detect.stages.vlm_cloud import escalate_cloud +from core.detect.stages.aggregator import compile_report +from core.detect.tracing import trace_node, flush as flush_traces from .events import emit_transition @@ -31,6 +42,7 @@ INFERENCE_URL = os.environ.get("INFERENCE_URL") # None = local mode NODES = [ "extract_frames", "filter_scenes", + "field_segmentation", "detect_edges", "detect_objects", "preprocess", @@ -42,17 +54,21 @@ NODES = [ ] -def _get_profile(state: DetectState): +def _load_profile(state: DetectState) -> dict: + """Load profile dict, apply config overrides if present.""" name = state.get("profile_name", "soccer_broadcast") - if name == "soccer_broadcast": - profile = SoccerBroadcastProfile() - else: - raise ValueError(f"Unknown profile: {name}") + profile = get_profile(name) overrides = state.get("config_overrides") if overrides: - from detect.checkpoint.replay import OverrideProfile - profile = OverrideProfile(profile, overrides) + # Merge overrides into a copy of the profile configs + merged_configs = dict(profile.get("configs", {})) + for stage_name, stage_overrides in overrides.items(): + if stage_name in merged_configs: + merged_configs[stage_name] = {**merged_configs[stage_name], **stage_overrides} + else: + merged_configs[stage_name] = stage_overrides + profile = {**profile, "configs": merged_configs} return profile @@ -70,16 +86,16 @@ def node_extract_frames(state: 
DetectState) -> dict: source_asset_id = state.get("source_asset_id") if source_asset_id and not state.get("session_brands"): - from detect.stages.brand_resolver import build_session_dict + from core.detect.stages.brand_resolver import build_session_dict session_brands = build_session_dict(source_asset_id) state["session_brands"] = session_brands _emit(state, "extract_frames", "running") with trace_node(state, "extract_frames") as span: - profile = _get_profile(state) - config = profile.frame_extraction_config() - frames = extract_frames(state["video_path"], config, job_id=state.get("job_id")) + profile = _load_profile(state) + config = FrameExtractionConfig(**get_stage_config(profile, "extract_frames")) + frames = extract_frames(state["video_path"], config, job_id=job_id) span.set_output({"frames_extracted": len(frames)}) _emit(state, "extract_frames", "done") @@ -90,8 +106,8 @@ def node_filter_scenes(state: DetectState) -> dict: _emit(state, "filter_scenes", "running") with trace_node(state, "filter_scenes") as span: - profile = _get_profile(state) - config = profile.scene_filter_config() + profile = _load_profile(state) + config = SceneFilterConfig(**get_stage_config(profile, "filter_scenes")) frames = state.get("frames", []) kept = scene_filter(frames, config, job_id=state.get("job_id")) span.set_output({"frames_in": len(frames), "frames_kept": len(kept)}) @@ -103,17 +119,42 @@ def node_filter_scenes(state: DetectState) -> dict: return {"filtered_frames": kept, "stats": stats} +def node_field_segmentation(state: DetectState) -> dict: + _emit(state, "field_segmentation", "running") + + with trace_node(state, "field_segmentation") as span: + profile = _load_profile(state) + config = FieldSegmentationConfig(**get_stage_config(profile, "field_segmentation")) + frames = state.get("filtered_frames", []) + job_id = state.get("job_id") + + result = run_field_segmentation(frames, config, inference_url=INFERENCE_URL, job_id=job_id) + span.set_output({ + "frames": 
len(frames), + "avg_coverage": sum(result["field_coverage"].values()) / max(len(result["field_coverage"]), 1), + }) + + _emit(state, "field_segmentation", "done") + return { + "field_masks": result["field_masks"], + "field_boundaries": result["field_boundaries"], + "field_coverage": result["field_coverage"], + } + + def node_detect_edges(state: DetectState) -> dict: _emit(state, "detect_edges", "running") with trace_node(state, "detect_edges") as span: - profile = _get_profile(state) - config = profile.region_analysis_config() + profile = _load_profile(state) + config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges")) frames = state.get("filtered_frames", []) + field_masks = state.get("field_masks", {}) job_id = state.get("job_id") regions = detect_edge_regions( frames, config, inference_url=INFERENCE_URL, job_id=job_id, + field_masks=field_masks, ) total = sum(len(r) for r in regions.values()) span.set_output({"frames": len(frames), "edge_regions": total}) @@ -129,8 +170,8 @@ def node_detect_objects(state: DetectState) -> dict: _emit(state, "detect_objects", "running") with trace_node(state, "detect_objects") as span: - profile = _get_profile(state) - config = profile.detection_config() + profile = _load_profile(state) + config = DetectionConfig(**get_stage_config(profile, "detect_objects")) frames = state.get("filtered_frames", []) job_id = state.get("job_id") @@ -149,13 +190,12 @@ def node_preprocess(state: DetectState) -> dict: _emit(state, "preprocess", "running") with trace_node(state, "preprocess") as span: - profile = _get_profile(state) + profile = _load_profile(state) + prep_config = get_stage_config(profile, "preprocess") frames = state.get("filtered_frames", []) boxes = state.get("boxes_by_frame", {}) job_id = state.get("job_id") - overrides = state.get("config_overrides", {}) - prep_config = overrides.get("preprocessing", {}) do_contrast = prep_config.get("contrast", True) do_deskew = prep_config.get("deskew", False) do_binarize = 
prep_config.get("binarize", False) @@ -178,8 +218,8 @@ def node_run_ocr(state: DetectState) -> dict: _emit(state, "run_ocr", "running") with trace_node(state, "run_ocr") as span: - profile = _get_profile(state) - config = profile.ocr_config() + profile = _load_profile(state) + config = OCRConfig(**get_stage_config(profile, "run_ocr")) frames = state.get("filtered_frames", []) boxes = state.get("boxes_by_frame", {}) job_id = state.get("job_id") @@ -198,18 +238,18 @@ def node_match_brands(state: DetectState) -> dict: _emit(state, "match_brands", "running") with trace_node(state, "match_brands") as span: - profile = _get_profile(state) - resolver_config = profile.resolver_config() + profile = _load_profile(state) + config = ResolverConfig(**get_stage_config(profile, "match_brands")) candidates = state.get("text_candidates", []) session_brands = state.get("session_brands", {}) job_id = state.get("job_id") source_asset_id = state.get("source_asset_id") matched, unresolved = resolve_brands( - candidates, resolver_config, + candidates, config, session_brands=session_brands, source_asset_id=source_asset_id, - content_type=profile.name, job_id=job_id, + content_type=profile["name"], job_id=job_id, ) span.set_output({"matched": len(matched), "unresolved": len(unresolved)}) @@ -221,15 +261,19 @@ def node_escalate_vlm(state: DetectState) -> dict: _emit(state, "escalate_vlm", "running") with trace_node(state, "escalate_vlm") as span: - profile = _get_profile(state) + profile = _load_profile(state) + vlm_config = get_stage_config(profile, "escalate_vlm") + vlm_template = vlm_config.get("vlm_prompt_template", "Identify the brand in this image.") candidates = state.get("unresolved_candidates", []) job_id = state.get("job_id") + vlm_prompt_fn = lambda ctx: build_vlm_prompt(ctx, vlm_template) + vlm_matched, still_unresolved = escalate_vlm( candidates, - vlm_prompt_fn=profile.vlm_prompt, + vlm_prompt_fn=vlm_prompt_fn, inference_url=INFERENCE_URL, - content_type=profile.name, + 
content_type=profile["name"], source_asset_id=state.get("source_asset_id"), job_id=job_id, ) @@ -254,16 +298,20 @@ def node_escalate_cloud(state: DetectState) -> dict: _emit(state, "escalate_cloud", "running") with trace_node(state, "escalate_cloud") as span: - profile = _get_profile(state) + profile = _load_profile(state) + vlm_config = get_stage_config(profile, "escalate_vlm") + vlm_template = vlm_config.get("vlm_prompt_template", "Identify the brand in this image.") candidates = state.get("unresolved_candidates", []) job_id = state.get("job_id") stats = state.get("stats", PipelineStats()) + vlm_prompt_fn = lambda ctx: build_vlm_prompt(ctx, vlm_template) + cloud_matched = escalate_cloud( candidates, - vlm_prompt_fn=profile.vlm_prompt, + vlm_prompt_fn=vlm_prompt_fn, stats=stats, - content_type=profile.name, + content_type=profile["name"], source_asset_id=state.get("source_asset_id"), job_id=job_id, ) @@ -283,7 +331,7 @@ def node_compile_report(state: DetectState) -> dict: _emit(state, "compile_report", "running") with trace_node(state, "compile_report") as span: - profile = _get_profile(state) + profile = _load_profile(state) detections = state.get("detections", []) stats = state.get("stats", PipelineStats()) job_id = state.get("job_id") @@ -292,7 +340,7 @@ def node_compile_report(state: DetectState) -> dict: detections=detections, stats=stats, video_source=state.get("video_path", ""), - content_type=profile.name, + content_type=profile["name"], job_id=job_id, ) @@ -306,6 +354,7 @@ def node_compile_report(state: DetectState) -> dict: NODE_FUNCTIONS = [ ("extract_frames", node_extract_frames), ("filter_scenes", node_filter_scenes), + ("field_segmentation", node_field_segmentation), ("detect_edges", node_detect_edges), ("detect_objects", node_detect_objects), ("preprocess", node_preprocess), diff --git a/detect/graph/runner.py b/core/detect/graph/runner.py similarity index 96% rename from detect/graph/runner.py rename to core/detect/graph/runner.py index 
e3218b5..92461be 100644 --- a/detect/graph/runner.py +++ b/core/detect/graph/runner.py @@ -13,8 +13,8 @@ import logging import os import threading -from core.schema.models.pipeline_config import PipelineConfig -from detect.state import DetectState +from core.detect.stages.models import PipelineConfig +from core.detect.state import DetectState from .nodes import NODES, NODE_FUNCTIONS logger = logging.getLogger(__name__) @@ -118,7 +118,7 @@ def _wait_if_paused(job_id: str, node_name: str): if _pause_after_stage.get(job_id, False): gate.clear() - from detect import emit + from core.detect import emit emit.log(job_id, "Pipeline", "INFO", f"Paused after {node_name}") while not gate.wait(timeout=0.5): @@ -237,7 +237,7 @@ class PipelineRunner: # 4. Checkpoint if self.do_checkpoint: - from detect.checkpoint import checkpoint_after_stage + from core.detect.checkpoint import checkpoint_after_stage checkpoint_after_stage(job_id, stage_name, state, result) # 5. Pause check @@ -256,11 +256,11 @@ def get_pipeline( start_from: str | None = None, ) -> PipelineRunner: """Return a PipelineRunner for the given profile.""" - from detect.profiles import get_profile + from core.detect.profile import get_profile, pipeline_config_from_dict do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED profile = get_profile(profile_name) - config = profile.pipeline_config() + config = pipeline_config_from_dict(profile["pipeline"]) return PipelineRunner( config=config, diff --git a/detect/inference/__init__.py b/core/detect/inference/__init__.py similarity index 100% rename from detect/inference/__init__.py rename to core/detect/inference/__init__.py diff --git a/detect/inference/client.py b/core/detect/inference/client.py similarity index 93% rename from detect/inference/client.py rename to core/detect/inference/client.py index 2f1bb90..01e502b 100644 --- a/detect/inference/client.py +++ b/core/detect/inference/client.py @@ -231,6 +231,20 @@ class InferenceClient: 
pair_count=data.get("pair_count", 0), ) + def post(self, path: str, payload: dict) -> dict | None: + """Generic POST to the inference server. Returns JSON response or None on error.""" + try: + resp = self.session.post( + f"{self.base_url}{path}", + json=payload, + timeout=self.timeout, + ) + resp.raise_for_status() + return resp.json() + except Exception as e: + logger.warning("Inference POST %s failed: %s", path, e) + return None + def load_model(self, model: str, quantization: str = "fp16") -> None: """Request the server to load a model into VRAM.""" self.session.post( diff --git a/detect/inference/types.py b/core/detect/inference/types.py similarity index 96% rename from detect/inference/types.py rename to core/detect/inference/types.py index 90e3ba7..498e010 100644 --- a/detect/inference/types.py +++ b/core/detect/inference/types.py @@ -2,7 +2,7 @@ Inference response types. These are the shapes returned by the inference server. -Kept separate from detect.models to avoid coupling the +Kept separate from core.detect.models to avoid coupling the inference protocol to pipeline internals. """ diff --git a/core/schema/models/pipeline.py b/core/detect/models.py similarity index 86% rename from core/schema/models/pipeline.py rename to core/detect/models.py index 8a8c763..80e2b20 100644 --- a/core/schema/models/pipeline.py +++ b/core/detect/models.py @@ -2,8 +2,8 @@ Detection pipeline runtime models. These are the data structures that flow between pipeline stages. -They contain runtime types (np.ndarray) so modelgen skips them — -not generated to SQLModel or TypeScript. +They contain runtime types (np.ndarray) so they live here, not in +core/schema/models/ (which is for modelgen source of truth). 
""" from __future__ import annotations @@ -85,3 +85,11 @@ class DetectionReport: brands: dict[str, BrandStats] = field(default_factory=dict) timeline: list[BrandDetection] = field(default_factory=list) pipeline_stats: PipelineStats = field(default_factory=PipelineStats) + + +@dataclass +class CropContext: + """Runtime type — holds image bytes for VLM prompts.""" + image: bytes + surrounding_text: str = "" + position_hint: str = "" diff --git a/core/detect/profile.py b/core/detect/profile.py new file mode 100644 index 0000000..6bb5781 --- /dev/null +++ b/core/detect/profile.py @@ -0,0 +1,107 @@ +""" +Profile registry and helpers. + +Loads profile data from Postgres. +A profile is a dict with keys: name, pipeline, configs. +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict + +from core.detect.stages.models import PipelineConfig, StageRef, Edge +from core.detect.models import ( + BrandDetection, + BrandStats, + CropContext, + DetectionReport, + PipelineStats, +) + +logger = logging.getLogger(__name__) + + +def get_profile(name: str) -> Dict[str, Any]: + """Get a profile dict by name from the database.""" + from core.db.connection import get_session + from core.db.models import Profile + + with get_session() as session: + row = session.query(Profile).filter(Profile.name == name).first() + + if row is None: + raise ValueError(f"Unknown profile: {name!r}") + + return { + "name": row.name, + "pipeline": row.pipeline or {}, + "configs": row.configs or {}, + } + + +def list_profiles() -> list[str]: + """List available profile names from the database.""" + from core.db.connection import get_session + from core.db.models import Profile + + with get_session() as session: + rows = session.query(Profile.name).all() + + return [r[0] for r in rows] + + +def get_stage_config(profile: Dict[str, Any], stage_name: str) -> dict: + """Get config values for a stage from a profile.""" + return profile.get("configs", {}).get(stage_name, {}) + + +def 
pipeline_config_from_dict(data: Dict[str, Any]) -> PipelineConfig: + """Deserialize a PipelineConfig from a JSONB dict.""" + stages = [StageRef(**s) for s in data.get("stages", [])] + edges = [Edge(**e) for e in data.get("edges", [])] + return PipelineConfig( + name=data.get("name", ""), + profile_name=data.get("profile_name", ""), + stages=stages, + edges=edges, + routing_rules=data.get("routing_rules", {}), + ) + + +def build_vlm_prompt(crop_context: CropContext, template: str) -> str: + """Build a VLM prompt from a template and crop context.""" + hint = f" Position: {crop_context.position_hint}." if crop_context.position_hint else "" + text = f" Nearby text: '{crop_context.surrounding_text}'." if crop_context.surrounding_text else "" + return template.format(hint=hint, text=text) + + +def aggregate_detections( + detections: list[BrandDetection], + content_type: str, +) -> DetectionReport: + """Group detections by brand into a report.""" + brands: dict[str, BrandStats] = {} + for d in detections: + if d.brand not in brands: + brands[d.brand] = BrandStats() + s = brands[d.brand] + s.total_appearances += 1 + s.total_screen_time += d.duration + s.avg_confidence = ( + (s.avg_confidence * (s.total_appearances - 1) + d.confidence) + / s.total_appearances + ) + if s.first_seen == 0.0 or d.timestamp < s.first_seen: + s.first_seen = d.timestamp + if d.timestamp > s.last_seen: + s.last_seen = d.timestamp + + return DetectionReport( + video_source="", + content_type=content_type, + duration_seconds=0.0, + brands=brands, + timeline=sorted(detections, key=lambda d: d.timestamp), + pipeline_stats=PipelineStats(), + ) diff --git a/detect/providers/__init__.py b/core/detect/providers/__init__.py similarity index 100% rename from detect/providers/__init__.py rename to core/detect/providers/__init__.py diff --git a/detect/providers/base.py b/core/detect/providers/base.py similarity index 100% rename from detect/providers/base.py rename to core/detect/providers/base.py diff --git 
a/detect/providers/claude.py b/core/detect/providers/claude.py similarity index 100% rename from detect/providers/claude.py rename to core/detect/providers/claude.py diff --git a/detect/providers/gemini.py b/core/detect/providers/gemini.py similarity index 100% rename from detect/providers/gemini.py rename to core/detect/providers/gemini.py diff --git a/detect/providers/groq.py b/core/detect/providers/groq.py similarity index 100% rename from detect/providers/groq.py rename to core/detect/providers/groq.py diff --git a/detect/providers/openai_compat.py b/core/detect/providers/openai_compat.py similarity index 100% rename from detect/providers/openai_compat.py rename to core/detect/providers/openai_compat.py diff --git a/detect/sse_contract.py b/core/detect/sse.py similarity index 90% rename from detect/sse_contract.py rename to core/detect/sse.py index 9bc77c8..2bdc2da 100644 --- a/detect/sse_contract.py +++ b/core/detect/sse.py @@ -145,3 +145,19 @@ class RetryResponse(BaseModel): status: str task_id: str job_id: str + +class RunRequest(BaseModel): + """Request body for launching a detection pipeline run.""" + video_path: str + profile_name: str = "soccer_broadcast" + source_asset_id: str = "" + checkpoint: bool = True + skip_vlm: bool = False + skip_cloud: bool = False + log_level: str = "INFO" + +class RunResponse(BaseModel): + """Response after starting a pipeline run.""" + status: str + job_id: str + video_path: str diff --git a/detect/stages/__init__.py b/core/detect/stages/__init__.py similarity index 91% rename from detect/stages/__init__.py rename to core/detect/stages/__init__.py index 13817aa..4fb9561 100644 --- a/detect/stages/__init__.py +++ b/core/detect/stages/__init__.py @@ -16,6 +16,7 @@ from .base import ( # Import all stage files to trigger auto-registration from . import edge_detector # noqa: F401 +from . import field_segmentation # noqa: F401 # Import registry for backward compat (other stages still use old pattern) from . 
import registry # noqa: F401 diff --git a/detect/stages/aggregator.py b/core/detect/stages/aggregator.py similarity index 96% rename from detect/stages/aggregator.py rename to core/detect/stages/aggregator.py index 433f7f5..229e750 100644 --- a/detect/stages/aggregator.py +++ b/core/detect/stages/aggregator.py @@ -9,8 +9,8 @@ from __future__ import annotations import logging -from detect import emit -from detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats +from core.detect import emit +from core.detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats logger = logging.getLogger(__name__) diff --git a/detect/stages/base.py b/core/detect/stages/base.py similarity index 72% rename from detect/stages/base.py rename to core/detect/stages/base.py index 03f7679..5dd8dfb 100644 --- a/detect/stages/base.py +++ b/core/detect/stages/base.py @@ -5,7 +5,7 @@ Each stage is a file that subclasses Stage. Auto-discovered via __init_subclass__. No manual registration needed. A stage: - - Has a StageDefinition (from schema) with name, config, IO + - Has a StageDefinition (generated from schema) with name, config, IO - Implements run(frames, config) → output - Owns its output serialization (opaque blob) - Optionally has a TypeScript port for browser-side execution @@ -16,12 +16,30 @@ the format. The stage that wrote it is the only one that can read it. from __future__ import annotations -from dataclasses import dataclass, field from typing import Any import numpy as np -from core.schema.models.stages import StageConfigField, StageIO, StageDefinition +from core.detect.stages.models import ( + StageConfigField, + StageIO, + StageDefinition, +) + + +# Legacy runtime extension — adds callable fields for old-style stages. +# New stages use Stage subclass with serialize()/deserialize() methods instead. 
+class LegacyStageDefinition: + """Wraps a StageDefinition with callable serialize/deserialize functions.""" + + def __init__(self, definition: StageDefinition, fn=None, serialize_fn=None, deserialize_fn=None): + self._definition = definition + self.fn = fn + self.serialize_fn = serialize_fn + self.deserialize_fn = deserialize_fn + + def __getattr__(self, name): + return getattr(self._definition, name) # --------------------------------------------------------------------------- @@ -30,12 +48,18 @@ from core.schema.models.stages import StageConfigField, StageIO, StageDefinition # --------------------------------------------------------------------------- _REGISTRY: dict[str, type['Stage']] = {} -_LEGACY_REGISTRY: dict[str, StageDefinition] = {} +_LEGACY_REGISTRY: dict[str, LegacyStageDefinition] = {} -def register_stage(definition: StageDefinition): +def register_stage( + definition: StageDefinition, + fn=None, + serialize_fn=None, + deserialize_fn=None, +): """Legacy registration for stages not yet converted to Stage subclass.""" - _LEGACY_REGISTRY[definition.name] = definition + legacy = LegacyStageDefinition(definition, fn=fn, serialize_fn=serialize_fn, deserialize_fn=deserialize_fn) + _LEGACY_REGISTRY[definition.name] = legacy class Stage: @@ -55,13 +79,6 @@ class Stage: _REGISTRY[cls.definition.name] = cls def run(self, frames: list, config: dict) -> Any: - """ - Run the stage on a list of frames with the given config. - - Config is a dict of parameter values (from slider UI or profile). - Returns the stage output — whatever shape this stage produces. - Debug overlays are included when config has debug=True. 
- """ raise NotImplementedError def serialize(self, output: Any) -> bytes: @@ -79,12 +96,15 @@ class Stage: # Discovery API # --------------------------------------------------------------------------- -def _all_definitions() -> dict[str, StageDefinition]: - """Merge new Stage subclass registry + legacy registry.""" +def _all_definitions(): + """Merge new Stage subclass registry + legacy registry. + + Returns StageDefinition for new-style stages, + LegacyStageDefinition for legacy stages (has serialize_fn etc). + """ merged = {} - # Legacy first, new overwrites (new takes precedence) - for name, defn in _LEGACY_REGISTRY.items(): - merged[name] = defn + for name, legacy in _LEGACY_REGISTRY.items(): + merged[name] = legacy for name, cls in _REGISTRY.items(): merged[name] = cls.definition return merged diff --git a/detect/stages/brand_resolver.py b/core/detect/stages/brand_resolver.py similarity index 97% rename from detect/stages/brand_resolver.py rename to core/detect/stages/brand_resolver.py index 7bc4590..008ae26 100644 --- a/detect/stages/brand_resolver.py +++ b/core/detect/stages/brand_resolver.py @@ -16,9 +16,9 @@ import logging from rapidfuzz import fuzz -from detect import emit -from detect.models import BrandDetection, TextCandidate -from detect.profiles.base import ResolverConfig +from core.detect import emit +from core.detect.models import BrandDetection, TextCandidate +from core.detect.stages.models import ResolverConfig logger = logging.getLogger(__name__) diff --git a/detect/stages/edge_detector.py b/core/detect/stages/edge_detector.py similarity index 68% rename from detect/stages/edge_detector.py rename to core/detect/stages/edge_detector.py index 85a2e0e..5755dac 100644 --- a/detect/stages/edge_detector.py +++ b/core/detect/stages/edge_detector.py @@ -21,10 +21,10 @@ from typing import Any from PIL import Image -from detect import emit -from detect.models import BoundingBox, Frame -from detect.stages.base import Stage -from core.schema.models.stages 
import StageDefinition, StageConfigField, StageIO +from core.detect import emit +from core.detect.models import BoundingBox, Frame +from core.detect.stages.base import Stage +from core.detect.stages.models import StageDefinition, StageConfigField, StageIO logger = logging.getLogger(__name__) @@ -42,14 +42,14 @@ class EdgeDetectionStage(Stage): writes=["edge_regions_by_frame"], ), config_fields=[ - StageConfigField("enabled", "bool", True, "Enable edge detection"), - StageConfigField("edge_canny_low", "int", 50, "Canny low threshold", min=0, max=255), - StageConfigField("edge_canny_high", "int", 150, "Canny high threshold", min=0, max=255), - StageConfigField("edge_hough_threshold", "int", 80, "Hough accumulator threshold", min=1, max=500), - StageConfigField("edge_hough_min_length", "int", 100, "Min line length (px)", min=10, max=2000), - StageConfigField("edge_hough_max_gap", "int", 10, "Max line gap (px)", min=1, max=100), - StageConfigField("edge_pair_max_distance", "int", 200, "Max distance between line pair (px)", min=10, max=500), - StageConfigField("edge_pair_min_distance", "int", 15, "Min distance between line pair (px)", min=5, max=200), + StageConfigField(name="enabled", type="bool", default=True, description="Enable edge detection"), + StageConfigField(name="edge_canny_low", type="int", default=50, description="Canny low threshold", min=0, max=255), + StageConfigField(name="edge_canny_high", type="int", default=150, description="Canny high threshold", min=0, max=255), + StageConfigField(name="edge_hough_threshold", type="int", default=80, description="Hough accumulator threshold", min=1, max=500), + StageConfigField(name="edge_hough_min_length", type="int", default=100, description="Min line length (px)", min=10, max=2000), + StageConfigField(name="edge_hough_max_gap", type="int", default=10, description="Max line gap (px)", min=1, max=100), + StageConfigField(name="edge_pair_max_distance", type="int", default=200, description="Max distance between line 
pair (px)", min=10, max=500), + StageConfigField(name="edge_pair_min_distance", type="int", default=15, description="Min distance between line pair (px)", min=5, max=200), ], tracks_element="edge_region", ) @@ -143,8 +143,8 @@ class EdgeDetectionStage(Stage): def _run_remote(self, frame: Frame, config: dict, inference_url: str, job_id: str) -> list[BoundingBox]: - from detect.inference import InferenceClient - from detect.emit import _run_log_level + from core.detect.inference import InferenceClient + from core.detect.emit import _run_log_level client = InferenceClient( base_url=inference_url, job_id=job_id, log_level=_run_log_level, @@ -208,7 +208,7 @@ def _load_cv_edges(): if _cv_edges_mod is None: import importlib.util from pathlib import Path - spec = importlib.util.spec_from_file_location("cv_edges", Path("gpu/models/cv/edges.py")) + spec = importlib.util.spec_from_file_location("cv_edges", Path("core/gpu/models/cv/edges.py")) _cv_edges_mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(_cv_edges_mod) return _cv_edges_mod @@ -216,8 +216,43 @@ def _load_cv_edges(): # --- Backward compat: standalone function for graph.py --- -def detect_edge_regions(frames, config, inference_url=None, job_id=None): - """Convenience wrapper — calls EdgeDetectionStage.run().""" +def _filter_by_field_mask(boxes, mask, margin_px=50): + """ + Keep only boxes that are near the pitch boundary (hoarding zone). + + The field mask has 255=pitch, 0=not pitch. Hoardings sit just + outside the pitch boundary. We dilate the mask to create a + "boundary zone" and keep boxes whose center falls in the zone + between the dilated mask edge and the original mask. 
+ """ + import cv2 + import numpy as np + + if mask is None or not boxes: + return boxes + + # Dilate the pitch mask — the expansion zone is where hoardings are + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (margin_px * 2, margin_px * 2)) + dilated = cv2.dilate(mask, kernel) + + # Boundary zone = dilated but NOT original pitch + boundary_zone = cv2.bitwise_and(dilated, cv2.bitwise_not(mask)) + + kept = [] + for box in boxes: + cx = box.x + box.w // 2 + cy = box.y + box.h // 2 + # Clamp to image bounds + cy = min(cy, boundary_zone.shape[0] - 1) + cx = min(cx, boundary_zone.shape[1] - 1) + if boundary_zone[cy, cx] > 0: + kept.append(box) + + return kept + + +def detect_edge_regions(frames, config, inference_url=None, job_id=None, field_masks=None): + """Convenience wrapper — calls EdgeDetectionStage.run(), optionally filters by field mask.""" stage = EdgeDetectionStage() cfg = { "enabled": config.enabled, @@ -231,4 +266,23 @@ def detect_edge_regions(frames, config, inference_url=None, job_id=None): "inference_url": inference_url, "job_id": job_id, } - return stage.run(frames, cfg) + all_boxes = stage.run(frames, cfg) + + # Filter by field segmentation mask if available + if field_masks: + filtered_total = 0 + original_total = sum(len(b) for b in all_boxes.values()) + for seq, boxes in all_boxes.items(): + mask = field_masks.get(seq) + if mask is not None: + all_boxes[seq] = _filter_by_field_mask(boxes, mask) + filtered_total += len(all_boxes[seq]) + else: + filtered_total += len(boxes) + + if original_total != filtered_total: + from core.detect import emit + emit.log(job_id, "EdgeDetection", "INFO", + f"Field mask filter: {original_total} → {filtered_total} regions") + + return all_boxes diff --git a/core/detect/stages/field_segmentation.py b/core/detect/stages/field_segmentation.py new file mode 100644 index 0000000..953d9d7 --- /dev/null +++ b/core/detect/stages/field_segmentation.py @@ -0,0 +1,141 @@ +""" +Stage — Field Segmentation + +Calls the GPU 
inference server to detect pitch boundaries via +HSV green mask + morphology. The CV code lives in core/gpu/models/cv/. + +Outputs a mask and boundary that downstream stages use as spatial priors. +""" + +from __future__ import annotations + +import base64 +import io +import logging + +import numpy as np +from PIL import Image + +from core.detect import emit +from core.detect.models import Frame +from core.detect.stages.base import Stage +from core.detect.stages.models import ( + FieldSegmentationConfig, + StageConfigField, + StageDefinition, + StageIO, +) + +logger = logging.getLogger(__name__) + + +class FieldSegmentationStage(Stage): + + definition = StageDefinition( + name="field_segmentation", + label="Field Segmentation", + description="HSV green mask — detect pitch boundaries for spatial priors", + category="cv_analysis", + io=StageIO( + reads=["filtered_frames"], + writes=["field_mask"], + ), + config_fields=[ + StageConfigField(name="enabled", type="bool", default=True, description="Enable field segmentation"), + StageConfigField(name="hue_low", type="int", default=30, description="HSV hue lower bound", min=0, max=180), + StageConfigField(name="hue_high", type="int", default=85, description="HSV hue upper bound", min=0, max=180), + StageConfigField(name="sat_low", type="int", default=30, description="HSV saturation lower bound", min=0, max=255), + StageConfigField(name="sat_high", type="int", default=255, description="HSV saturation upper bound", min=0, max=255), + StageConfigField(name="val_low", type="int", default=30, description="HSV value lower bound", min=0, max=255), + StageConfigField(name="val_high", type="int", default=255, description="HSV value upper bound", min=0, max=255), + StageConfigField(name="morph_kernel", type="int", default=15, description="Morphology kernel size", min=3, max=51), + StageConfigField(name="min_area_ratio", type="float", default=0.05, description="Min contour area as fraction of frame", min=0.01, max=0.5), + ], + ) + + 
+def _frame_to_b64(frame: Frame) -> str: + """Encode frame image as base64 JPEG.""" + img = Image.fromarray(frame.image) + buf = io.BytesIO() + img.save(buf, format="JPEG", quality=85) + return base64.b64encode(buf.getvalue()).decode() + + +def _decode_mask_b64(mask_b64: str) -> np.ndarray: + """Decode a base64 PNG mask back to numpy array.""" + data = base64.b64decode(mask_b64) + img = Image.open(io.BytesIO(data)).convert("L") + return np.array(img) + + +def run_field_segmentation( + frames: list[Frame], + config: FieldSegmentationConfig, + inference_url: str | None = None, + job_id: str | None = None, +) -> dict: + """ + Run field segmentation on all frames via the inference server. + + Returns dict with: + field_masks: {seq: np.ndarray} + field_boundaries: {seq: [(x,y), ...]} + field_coverage: {seq: float} + """ + if not config.enabled: + emit.log(job_id, "FieldSegmentation", "INFO", "Disabled, skipping") + return {"field_masks": {}, "field_boundaries": {}, "field_coverage": {}} + + import os + url = inference_url or os.environ.get("INFERENCE_URL") + if not url: + emit.log(job_id, "FieldSegmentation", "WARNING", + "No INFERENCE_URL, skipping field segmentation") + return {"field_masks": {}, "field_boundaries": {}, "field_coverage": {}} + + emit.log(job_id, "FieldSegmentation", "INFO", + f"Segmenting {len(frames)} frames (hue={config.hue_low}-{config.hue_high})") + + from core.detect.inference import InferenceClient + from core.detect.emit import _run_log_level + client = InferenceClient(base_url=url, job_id=job_id or "", log_level=_run_log_level) + + field_masks = {} + field_boundaries = {} + field_coverage = {} + + for frame in frames: + image_b64 = _frame_to_b64(frame) + + resp = client.post("/segment_field", { + "image": image_b64, + "hue_low": config.hue_low, + "hue_high": config.hue_high, + "sat_low": config.sat_low, + "sat_high": config.sat_high, + "val_low": config.val_low, + "val_high": config.val_high, + "morph_kernel": config.morph_kernel, + 
"min_area_ratio": config.min_area_ratio, + }) + + if resp is None: + continue + + mask_b64 = resp.get("mask_b64", "") + if mask_b64: + field_masks[frame.sequence] = _decode_mask_b64(mask_b64) + + field_boundaries[frame.sequence] = resp.get("boundary", []) + field_coverage[frame.sequence] = resp.get("coverage", 0.0) + + avg_coverage = sum(field_coverage.values()) / max(len(field_coverage), 1) + emit.log(job_id, "FieldSegmentation", "INFO", + f"Done: {len(frames)} frames, avg coverage {avg_coverage:.1%}") + + return { + "field_masks": field_masks, + "field_boundaries": field_boundaries, + "field_coverage": field_coverage, + } diff --git a/detect/stages/frame_extractor.py b/core/detect/stages/frame_extractor.py similarity index 95% rename from detect/stages/frame_extractor.py rename to core/detect/stages/frame_extractor.py index c4e5d4b..df11add 100644 --- a/detect/stages/frame_extractor.py +++ b/core/detect/stages/frame_extractor.py @@ -16,9 +16,9 @@ import numpy as np from PIL import Image from core.ffmpeg.probe import probe_file -from detect import emit -from detect.models import Frame -from detect.profiles.base import FrameExtractionConfig +from core.detect import emit +from core.detect.models import Frame +from core.detect.stages.models import FrameExtractionConfig def _load_frames(tmpdir: Path, fps: float) -> list[Frame]: diff --git a/core/detect/stages/models.py b/core/detect/stages/models.py new file mode 100644 index 0000000..985b8ce --- /dev/null +++ b/core/detect/stages/models.py @@ -0,0 +1,106 @@ +""" +Pydantic Models - GENERATED FILE + +Do not edit directly. Regenerate using modelgen. 
+""" + +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional +from uuid import UUID + +from pydantic import BaseModel, Field + +class StageConfigField(BaseModel): + """A single tunable config parameter for the editor UI.""" + name: str + type: str + default: Any + description: str = "" + min: Optional[float] = None + max: Optional[float] = None + options: Optional[List[str]] = None + +class StageIO(BaseModel): + """Declares what a stage reads and writes.""" + reads: List[str] = Field(default_factory=list) + writes: List[str] = Field(default_factory=list) + optional_reads: List[str] = Field(default_factory=list) + +class StageDefinition(BaseModel): + """Complete metadata for a pipeline stage.""" + name: str + label: str + description: str + category: str = "detection" + io: StageIO + config_fields: List[StageConfigField] = Field(default_factory=list) + tracks_element: Optional[str] = None + +class FrameExtractionConfig(BaseModel): + """FrameExtractionConfig(fps: float = 2.0, max_frames: int = 500)""" + fps: float = 2.0 + max_frames: int = 500 + +class SceneFilterConfig(BaseModel): + """SceneFilterConfig(hamming_threshold: int = 8, enabled: bool = True)""" + hamming_threshold: int = 8 + enabled: bool = True + +class DetectionConfig(BaseModel): + """DetectionConfig(model_name: str = 'yolov8n.pt', confidence_threshold: float = 0.3, target_classes: List[str] = )""" + model_name: str = "yolov8n.pt" + confidence_threshold: float = 0.3 + target_classes: List[str] + +class OCRConfig(BaseModel): + """OCRConfig(languages: List[str] = , min_confidence: float = 0.5)""" + languages: List[str] + min_confidence: float = 0.5 + +class ResolverConfig(BaseModel): + """ResolverConfig(fuzzy_threshold: int = 75)""" + fuzzy_threshold: int = 75 + +class RegionAnalysisConfig(BaseModel): + """RegionAnalysisConfig(enabled: bool = True, edge_canny_low: int = 50, edge_canny_high: int = 150, edge_hough_threshold: int = 80, edge_hough_min_length: int 
= 100, edge_hough_max_gap: int = 10, edge_pair_max_distance: int = 200, edge_pair_min_distance: int = 15)""" + enabled: bool = True + edge_canny_low: int = 50 + edge_canny_high: int = 150 + edge_hough_threshold: int = 80 + edge_hough_min_length: int = 100 + edge_hough_max_gap: int = 10 + edge_pair_max_distance: int = 200 + edge_pair_min_distance: int = 15 + +class FieldSegmentationConfig(BaseModel): + """FieldSegmentationConfig(enabled: bool = True, hue_low: int = 30, hue_high: int = 85, sat_low: int = 30, sat_high: int = 255, val_low: int = 30, val_high: int = 255, morph_kernel: int = 15, min_area_ratio: float = 0.05)""" + enabled: bool = True + hue_low: int = 30 + hue_high: int = 85 + sat_low: int = 30 + sat_high: int = 255 + val_low: int = 30 + val_high: int = 255 + morph_kernel: int = 15 + min_area_ratio: float = 0.05 + +class StageRef(BaseModel): + """Reference to a stage in the pipeline graph.""" + name: str + branch: str = "trunk" + execution_target: str = "local" + +class Edge(BaseModel): + """Connection between stages in the graph.""" + source: str + target: str + condition: str = "" + +class PipelineConfig(BaseModel): + """Pipeline graph topology + routing rules.""" + name: str + profile_name: str + stages: List[StageRef] = Field(default_factory=list) + edges: List[Edge] = Field(default_factory=list) + routing_rules: Dict[str, Any] diff --git a/detect/stages/ocr_stage.py b/core/detect/stages/ocr_stage.py similarity index 94% rename from detect/stages/ocr_stage.py rename to core/detect/stages/ocr_stage.py index 608e038..6d643ad 100644 --- a/detect/stages/ocr_stage.py +++ b/core/detect/stages/ocr_stage.py @@ -18,9 +18,9 @@ from typing import TYPE_CHECKING import numpy as np -from detect import emit -from detect.models import BoundingBox, Frame, TextCandidate -from detect.profiles.base import OCRConfig +from core.detect import emit +from core.detect.models import BoundingBox, Frame, TextCandidate +from core.detect.stages.models import OCRConfig if 
TYPE_CHECKING: pass @@ -91,8 +91,8 @@ def run_ocr( # Build these once per pipeline run, not per crop if inference_url: - from detect.inference import InferenceClient - from detect.emit import _run_log_level + from core.detect.inference import InferenceClient + from core.detect.emit import _run_log_level client = InferenceClient(base_url=inference_url, job_id=job_id or "", log_level=_run_log_level) else: model = _get_local_model(config.languages[0]) diff --git a/detect/stages/preprocess.py b/core/detect/stages/preprocess.py similarity index 96% rename from detect/stages/preprocess.py rename to core/detect/stages/preprocess.py index a63c9be..9a22ff9 100644 --- a/detect/stages/preprocess.py +++ b/core/detect/stages/preprocess.py @@ -15,8 +15,8 @@ import logging import numpy as np -from detect import emit -from detect.models import BoundingBox, Frame +from core.detect import emit +from core.detect.models import BoundingBox, Frame logger = logging.getLogger(__name__) @@ -124,5 +124,5 @@ def _preprocess_remote(crop: np.ndarray, inference_url: str, def _preprocess_local(crop: np.ndarray, do_contrast: bool, do_deskew: bool, do_binarize: bool) -> np.ndarray: """Run preprocessing in-process (requires opencv-python-headless).""" - from gpu.models.preprocess import preprocess + from core.gpu.models.preprocess import preprocess return preprocess(crop, do_binarize=do_binarize, do_deskew=do_deskew, do_contrast=do_contrast) diff --git a/detect/stages/registry/__init__.py b/core/detect/stages/registry/__init__.py similarity index 100% rename from detect/stages/registry/__init__.py rename to core/detect/stages/registry/__init__.py diff --git a/detect/stages/registry/_serializers.py b/core/detect/stages/registry/_serializers.py similarity index 100% rename from detect/stages/registry/_serializers.py rename to core/detect/stages/registry/_serializers.py diff --git a/core/detect/stages/registry/cv_analysis.py b/core/detect/stages/registry/cv_analysis.py new file mode 100644 index 
0000000..10f4645 --- /dev/null +++ b/core/detect/stages/registry/cv_analysis.py @@ -0,0 +1,44 @@ +"""Registration for CV analysis stages: edge detection.""" + +from core.detect.stages.models import StageDefinition, StageIO, StageConfigField +from core.detect.stages.base import register_stage +from ._serializers import serialize_dataclass_list, deserialize_bounding_box + + +def _ser_regions(state: dict, job_id: str) -> dict: + regions = state.get("edge_regions_by_frame", {}) + serialized = { + str(seq): serialize_dataclass_list(bl) for seq, bl in regions.items() + } + return {"edge_regions_by_frame": serialized} + + +def _deser_regions(data: dict, job_id: str) -> dict: + regions = {} + for seq_str, box_dicts in data.get("edge_regions_by_frame", {}).items(): + regions[int(seq_str)] = [deserialize_bounding_box(b) for b in box_dicts] + return {"edge_regions_by_frame": regions} + + +def register(): + edge_detection = StageDefinition( + name="detect_edges", + label="Edge Detection", + description="Canny + HoughLinesP — find horizontal line pairs (hoarding boundaries)", + category="cv_analysis", + io=StageIO( + reads=["filtered_frames"], + writes=["edge_regions_by_frame"], + ), + config_fields=[ + StageConfigField(name="enabled", type="bool", default=True, description="Enable region analysis"), + StageConfigField(name="edge_canny_low", type="int", default=50, description="Canny low threshold", min=0, max=255), + StageConfigField(name="edge_canny_high", type="int", default=150, description="Canny high threshold", min=0, max=255), + StageConfigField(name="edge_hough_threshold", type="int", default=80, description="Hough accumulator threshold", min=1, max=500), + StageConfigField(name="edge_hough_min_length", type="int", default=100, description="Min line length (px)", min=10, max=2000), + StageConfigField(name="edge_hough_max_gap", type="int", default=10, description="Max line gap (px)", min=1, max=100), + StageConfigField(name="edge_pair_max_distance", type="int", 
default=200, description="Max distance between line pair (px)", min=10, max=500), + StageConfigField(name="edge_pair_min_distance", type="int", default=15, description="Min distance between line pair (px)", min=5, max=200), + ], + ) + register_stage(edge_detection, serialize_fn=_ser_regions, deserialize_fn=_deser_regions) diff --git a/detect/stages/registry/detection.py b/core/detect/stages/registry/detection.py similarity index 62% rename from detect/stages/registry/detection.py rename to core/detect/stages/registry/detection.py index 14b671d..3b848b9 100644 --- a/detect/stages/registry/detection.py +++ b/core/detect/stages/registry/detection.py @@ -1,6 +1,7 @@ """Registration for detection stages: YOLO, OCR.""" -from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage +from core.detect.stages.models import StageDefinition, StageIO, StageConfigField +from core.detect.stages.base import register_stage from ._serializers import ( serialize_dataclass_list, serialize_text_candidates, @@ -38,14 +39,12 @@ def register(): category="detection", io=StageIO(reads=["filtered_frames"], writes=["boxes_by_frame"]), config_fields=[ - StageConfigField("model_name", "str", "yolov8n.pt", "YOLO model file"), - StageConfigField("confidence_threshold", "float", 0.3, "Min detection confidence", min=0.0, max=1.0), - StageConfigField("target_classes", "list[str]", [], "YOLO classes to detect (empty = all)"), + StageConfigField(name="model_name", type="str", default="yolov8n.pt", description="YOLO model file"), + StageConfigField(name="confidence_threshold", type="float", default=0.3, description="Min detection confidence", min=0.0, max=1.0), + StageConfigField(name="target_classes", type="list[str]", default=[], description="YOLO classes to detect (empty = all)"), ], - serialize_fn=_ser_detect, - deserialize_fn=_deser_detect, ) - register_stage(yolo) + register_stage(yolo, serialize_fn=_ser_detect, deserialize_fn=_deser_detect) ocr = StageDefinition( 
name="run_ocr", @@ -54,10 +53,8 @@ def register(): category="detection", io=StageIO(reads=["filtered_frames", "boxes_by_frame"], writes=["text_candidates"]), config_fields=[ - StageConfigField("languages", "list[str]", ["en"], "OCR languages"), - StageConfigField("min_confidence", "float", 0.5, "Min OCR confidence", min=0.0, max=1.0), + StageConfigField(name="languages", type="list[str]", default=["en"], description="OCR languages"), + StageConfigField(name="min_confidence", type="float", default=0.5, description="Min OCR confidence", min=0.0, max=1.0), ], - serialize_fn=_ser_ocr, - deserialize_fn=_deser_ocr, ) - register_stage(ocr) + register_stage(ocr, serialize_fn=_ser_ocr, deserialize_fn=_deser_ocr) diff --git a/detect/stages/registry/escalation.py b/core/detect/stages/registry/escalation.py similarity index 73% rename from detect/stages/registry/escalation.py rename to core/detect/stages/registry/escalation.py index fca0222..741ea53 100644 --- a/detect/stages/registry/escalation.py +++ b/core/detect/stages/registry/escalation.py @@ -1,6 +1,7 @@ """Registration for escalation stages: local VLM, cloud LLM.""" -from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage +from core.detect.stages.models import StageDefinition, StageIO, StageConfigField +from core.detect.stages.base import register_stage from ._serializers import ( serialize_dataclass_list, serialize_text_candidates, @@ -37,12 +38,10 @@ def register(): optional_reads=["source_asset_id"], ), config_fields=[ - StageConfigField("min_confidence", "float", 0.5, "Min VLM confidence", min=0.0, max=1.0), + StageConfigField(name="min_confidence", type="float", default=0.5, description="Min VLM confidence", min=0.0, max=1.0), ], - serialize_fn=_ser_escalation, - deserialize_fn=_deser_escalation, ) - register_stage(vlm) + register_stage(vlm, serialize_fn=_ser_escalation, deserialize_fn=_deser_escalation) cloud = StageDefinition( name="escalate_cloud", @@ -55,9 +54,7 @@ def 
register(): optional_reads=["source_asset_id"], ), config_fields=[ - StageConfigField("min_confidence", "float", 0.4, "Min cloud confidence", min=0.0, max=1.0), + StageConfigField(name="min_confidence", type="float", default=0.4, description="Min cloud confidence", min=0.0, max=1.0), ], - serialize_fn=_ser_escalation, - deserialize_fn=_deser_escalation, ) - register_stage(cloud) + register_stage(cloud, serialize_fn=_ser_escalation, deserialize_fn=_deser_escalation) diff --git a/detect/stages/registry/output.py b/core/detect/stages/registry/output.py similarity index 82% rename from detect/stages/registry/output.py rename to core/detect/stages/registry/output.py index 4efa4b0..9808d8a 100644 --- a/detect/stages/registry/output.py +++ b/core/detect/stages/registry/output.py @@ -1,6 +1,6 @@ """Registration for output stages: report compilation.""" -from detect.stages.base import StageDefinition, StageIO, register_stage +from core.detect.stages.base import StageDefinition, StageIO, register_stage from ._serializers import serialize_dataclass, deserialize_detection_report @@ -26,7 +26,5 @@ def register(): category="output", io=StageIO(reads=["detections"], writes=["report"]), config_fields=[], - serialize_fn=_ser_report, - deserialize_fn=_deser_report, ) - register_stage(report) + register_stage(report, serialize_fn=_ser_report, deserialize_fn=_deser_report) diff --git a/detect/stages/registry/preprocessing.py b/core/detect/stages/registry/preprocessing.py similarity index 64% rename from detect/stages/registry/preprocessing.py rename to core/detect/stages/registry/preprocessing.py index cdc3f4f..deb353a 100644 --- a/detect/stages/registry/preprocessing.py +++ b/core/detect/stages/registry/preprocessing.py @@ -1,6 +1,7 @@ """Registration for preprocessing stages: frame extraction, scene filter, image preprocessing.""" -from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage +from core.detect.stages.models import StageDefinition, 
StageIO, StageConfigField +from core.detect.stages.base import register_stage from ._serializers import serialize_frames, deserialize_frames @@ -44,13 +45,11 @@ def register(): category="preprocessing", io=StageIO(reads=["video_path"], writes=["frames"]), config_fields=[ - StageConfigField("fps", "float", 2.0, "Frames per second", min=0.1, max=30.0), - StageConfigField("max_frames", "int", 500, "Maximum frames to extract", min=1, max=10000), + StageConfigField(name="fps", type="float", default=2.0, description="Frames per second", min=0.1, max=30.0), + StageConfigField(name="max_frames", type="int", default=500, description="Maximum frames to extract", min=1, max=10000), ], - serialize_fn=_ser_extract, - deserialize_fn=_deser_extract, ) - register_stage(extract) + register_stage(extract, serialize_fn=_ser_extract, deserialize_fn=_deser_extract) scene_filter = StageDefinition( name="filter_scenes", @@ -59,13 +58,11 @@ def register(): category="preprocessing", io=StageIO(reads=["frames"], writes=["filtered_frames"]), config_fields=[ - StageConfigField("hamming_threshold", "int", 8, "Hamming distance threshold", min=0, max=64), - StageConfigField("enabled", "bool", True, "Enable scene filtering"), + StageConfigField(name="hamming_threshold", type="int", default=8, description="Hamming distance threshold", min=0, max=64), + StageConfigField(name="enabled", type="bool", default=True, description="Enable scene filtering"), ], - serialize_fn=_ser_filter, - deserialize_fn=_deser_filter, ) - register_stage(scene_filter) + register_stage(scene_filter, serialize_fn=_ser_filter, deserialize_fn=_deser_filter) preprocess = StageDefinition( name="preprocess", @@ -77,11 +74,9 @@ def register(): writes=["preprocessed_crops"], ), config_fields=[ - StageConfigField("contrast", "bool", True, "CLAHE contrast enhancement"), - StageConfigField("deskew", "bool", False, "Correct slight rotation"), - StageConfigField("binarize", "bool", False, "Otsu binarization"), + 
StageConfigField(name="contrast", type="bool", default=True, description="CLAHE contrast enhancement"), + StageConfigField(name="deskew", type="bool", default=False, description="Correct slight rotation"), + StageConfigField(name="binarize", type="bool", default=False, description="Otsu binarization"), ], - serialize_fn=_ser_preprocess, - deserialize_fn=_deser_preprocess, ) - register_stage(preprocess) + register_stage(preprocess, serialize_fn=_ser_preprocess, deserialize_fn=_deser_preprocess) diff --git a/detect/stages/registry/resolution.py b/core/detect/stages/registry/resolution.py similarity index 77% rename from detect/stages/registry/resolution.py rename to core/detect/stages/registry/resolution.py index a2affe8..cd838cc 100644 --- a/detect/stages/registry/resolution.py +++ b/core/detect/stages/registry/resolution.py @@ -1,6 +1,7 @@ """Registration for resolution stages: brand resolver.""" -from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage +from core.detect.stages.models import StageDefinition, StageIO, StageConfigField +from core.detect.stages.base import register_stage from ._serializers import ( serialize_dataclass_list, serialize_text_candidates, @@ -37,9 +38,7 @@ def register(): optional_reads=["session_brands", "source_asset_id"], ), config_fields=[ - StageConfigField("fuzzy_threshold", "int", 75, "Fuzzy match threshold", min=0, max=100), + StageConfigField(name="fuzzy_threshold", type="int", default=75, description="Fuzzy match threshold", min=0, max=100), ], - serialize_fn=_ser_brands, - deserialize_fn=_deser_brands, ) - register_stage(resolver) + register_stage(resolver, serialize_fn=_ser_brands, deserialize_fn=_deser_brands) diff --git a/detect/stages/scene_filter.py b/core/detect/stages/scene_filter.py similarity index 95% rename from detect/stages/scene_filter.py rename to core/detect/stages/scene_filter.py index cdeb197..f6675a8 100644 --- a/detect/stages/scene_filter.py +++ 
b/core/detect/stages/scene_filter.py @@ -14,9 +14,9 @@ import time import imagehash from PIL import Image -from detect import emit -from detect.models import Frame -from detect.profiles.base import SceneFilterConfig +from core.detect import emit +from core.detect.models import Frame +from core.detect.stages.models import SceneFilterConfig def _compute_hashes(frames: list[Frame]) -> list[imagehash.ImageHash]: diff --git a/detect/stages/vlm_cloud.py b/core/detect/stages/vlm_cloud.py similarity index 95% rename from detect/stages/vlm_cloud.py rename to core/detect/stages/vlm_cloud.py index df721e4..40f57ea 100644 --- a/detect/stages/vlm_cloud.py +++ b/core/detect/stages/vlm_cloud.py @@ -19,10 +19,10 @@ import time import numpy as np from PIL import Image -from detect import emit -from detect.models import BrandDetection, PipelineStats, TextCandidate -from detect.profiles.base import CropContext -from detect.providers import get_provider, has_api_key +from core.detect import emit +from core.detect.models import BrandDetection, PipelineStats, TextCandidate +from core.detect.models import CropContext +from core.detect.providers import get_provider, has_api_key logger = logging.getLogger(__name__) @@ -33,7 +33,7 @@ def _register_discovered_brand(brand: str, source_asset_id: str | None, timestamp: float, confidence: float): """Register a cloud-confirmed brand in the DB.""" try: - from detect.stages.brand_resolver import _register_brand, _record_sighting + from core.detect.stages.brand_resolver import _register_brand, _record_sighting brand_id = _register_brand(brand, "cloud_llm") if brand_id and source_asset_id: _record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, "cloud_llm") diff --git a/detect/stages/vlm_local.py b/core/detect/stages/vlm_local.py similarity index 93% rename from detect/stages/vlm_local.py rename to core/detect/stages/vlm_local.py index 2e8e1aa..a62cef3 100644 --- a/detect/stages/vlm_local.py +++ b/core/detect/stages/vlm_local.py @@ 
-14,9 +14,9 @@ import time import numpy as np -from detect import emit -from detect.models import BrandDetection, TextCandidate -from detect.profiles.base import CropContext +from core.detect import emit +from core.detect.models import BrandDetection, TextCandidate +from core.detect.models import CropContext logger = logging.getLogger(__name__) @@ -25,7 +25,7 @@ def _register_discovered_brand(brand: str, source_asset_id: str | None, timestamp: float, confidence: float, source: str): """Register a VLM-confirmed brand in the DB.""" try: - from detect.stages.brand_resolver import _register_brand, _record_sighting + from core.detect.stages.brand_resolver import _register_brand, _record_sighting brand_id = _register_brand(brand, source) if brand_id and source_asset_id: _record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, source) @@ -75,8 +75,8 @@ def escalate_vlm( still_unresolved: list[TextCandidate] = [] if inference_url: - from detect.inference import InferenceClient - from detect.emit import _run_log_level + from core.detect.inference import InferenceClient + from core.detect.emit import _run_log_level client = InferenceClient(base_url=inference_url, job_id=job_id or "", log_level=_run_log_level) for i, candidate in enumerate(candidates): @@ -152,6 +152,6 @@ def escalate_vlm( def _vlm_local(crop: np.ndarray, prompt: str) -> tuple[str, float, str]: """Run moondream2 in-process (single-box mode).""" - from gpu.models.vlm import query + from core.gpu.models.vlm import query result = query(crop, prompt) return result["brand"], result["confidence"], result["reasoning"] diff --git a/detect/stages/yolo_detector.py b/core/detect/stages/yolo_detector.py similarity index 94% rename from detect/stages/yolo_detector.py rename to core/detect/stages/yolo_detector.py index bdb8902..c7c303a 100644 --- a/detect/stages/yolo_detector.py +++ b/core/detect/stages/yolo_detector.py @@ -18,9 +18,9 @@ import time from PIL import Image -from detect import emit -from 
detect.models import BoundingBox, Frame -from detect.profiles.base import DetectionConfig +from core.detect import emit +from core.detect.models import BoundingBox, Frame +from core.detect.stages.models import DetectionConfig logger = logging.getLogger(__name__) @@ -36,7 +36,7 @@ def _frame_to_b64(frame: Frame) -> str: def _detect_remote(frame: Frame, config: DetectionConfig, inference_url: str, job_id: str = "", log_level: str = "INFO") -> list[BoundingBox]: """Call the inference server over HTTP.""" - from detect.inference import InferenceClient + from core.detect.inference import InferenceClient client = InferenceClient(base_url=inference_url, job_id=job_id, log_level=log_level) results = client.detect( image=frame.image, @@ -104,7 +104,7 @@ def detect_objects( for i, frame in enumerate(frames): t0 = time.monotonic() if inference_url: - from detect.emit import _run_log_level + from core.detect.emit import _run_log_level boxes = _detect_remote(frame, config, inference_url, job_id=job_id or "", log_level=_run_log_level) else: diff --git a/detect/state.py b/core/detect/state.py similarity index 72% rename from detect/state.py rename to core/detect/state.py index bff6cbe..344223c 100644 --- a/detect/state.py +++ b/core/detect/state.py @@ -9,7 +9,7 @@ from __future__ import annotations from typing import TypedDict -from detect.models import BoundingBox, BrandDetection, DetectionReport, Frame, PipelineStats, TextCandidate +from core.detect.models import BoundingBox, BrandDetection, DetectionReport, Frame, PipelineStats, TextCandidate class DetectState(TypedDict, total=False): @@ -22,6 +22,9 @@ class DetectState(TypedDict, total=False): # Stage outputs frames: list[Frame] filtered_frames: list[Frame] + field_masks: dict # {seq: np.ndarray} — pitch mask per frame + field_boundaries: dict # {seq: [(x,y), ...]} — pitch boundary per frame + field_coverage: dict # {seq: float} — pitch coverage ratio per frame edge_regions_by_frame: dict[int, list[BoundingBox]] 
boxes_by_frame: dict[int, list[BoundingBox]] preprocessed_crops: dict # "{frame_seq}_{box_idx}" → np.ndarray @@ -36,5 +39,5 @@ class DetectState(TypedDict, total=False): # Running stats (updated by each stage) stats: PipelineStats - # Config overrides for replay (applied via OverrideProfile) + # Config overrides for replay (merged into profile configs dict) config_overrides: dict diff --git a/detect/tracing.py b/core/detect/tracing.py similarity index 98% rename from detect/tracing.py rename to core/detect/tracing.py index dda8053..123a4ed 100644 --- a/detect/tracing.py +++ b/core/detect/tracing.py @@ -6,7 +6,7 @@ and stage-level metadata. The Langfuse client is optional — if not configured (no LANGFUSE_SECRET_KEY), tracing is a no-op. Usage in graph nodes: - from detect.tracing import trace_node + from core.detect.tracing import trace_node def node_extract_frames(state): with trace_node(state, "extract_frames") as span: diff --git a/core/ffmpeg/__init__.py b/core/ffmpeg/__init__.py index 6ba3ced..ca3c19d 100644 --- a/core/ffmpeg/__init__.py +++ b/core/ffmpeg/__init__.py @@ -1,13 +1,6 @@ -from .capabilities import get_decoders, get_encoders, get_formats from .probe import ProbeResult, probe_file -from .transcode import TranscodeConfig, transcode __all__ = [ "probe_file", "ProbeResult", - "transcode", - "TranscodeConfig", - "get_encoders", - "get_decoders", - "get_formats", ] diff --git a/core/ffmpeg/capabilities.py b/core/ffmpeg/capabilities.py deleted file mode 100644 index 3f4d6df..0000000 --- a/core/ffmpeg/capabilities.py +++ /dev/null @@ -1,145 +0,0 @@ -""" -FFmpeg capabilities - Discover available codecs and formats using ffmpeg-python. 
-""" - -from dataclasses import dataclass -from functools import lru_cache -from typing import Any, Dict, List - -import ffmpeg - - -@dataclass -class Codec: - """An FFmpeg encoder or decoder.""" - - name: str - description: str - type: str # 'video' or 'audio' - - -@dataclass -class Format: - """An FFmpeg format (muxer/demuxer).""" - - name: str - description: str - can_demux: bool - can_mux: bool - - -@lru_cache(maxsize=1) -def _get_ffmpeg_info() -> Dict[str, Any]: - """Get FFmpeg capabilities info.""" - # ffmpeg-python doesn't have a direct way to get codecs/formats - # but we can use probe on a dummy or parse -codecs output - # For now, return common codecs that are typically available - return { - "video_encoders": [ - {"name": "libx264", "description": "H.264 / AVC"}, - {"name": "libx265", "description": "H.265 / HEVC"}, - {"name": "mpeg4", "description": "MPEG-4 Part 2"}, - {"name": "libvpx", "description": "VP8"}, - {"name": "libvpx-vp9", "description": "VP9"}, - {"name": "h264_nvenc", "description": "NVIDIA NVENC H.264"}, - {"name": "hevc_nvenc", "description": "NVIDIA NVENC H.265"}, - {"name": "h264_vaapi", "description": "VAAPI H.264"}, - {"name": "prores_ks", "description": "Apple ProRes"}, - {"name": "dnxhd", "description": "Avid DNxHD/DNxHR"}, - {"name": "copy", "description": "Stream copy (no encoding)"}, - ], - "audio_encoders": [ - {"name": "aac", "description": "AAC"}, - {"name": "libmp3lame", "description": "MP3"}, - {"name": "libopus", "description": "Opus"}, - {"name": "libvorbis", "description": "Vorbis"}, - {"name": "pcm_s16le", "description": "PCM signed 16-bit little-endian"}, - {"name": "flac", "description": "FLAC"}, - {"name": "copy", "description": "Stream copy (no encoding)"}, - ], - "formats": [ - {"name": "mp4", "description": "MP4", "can_demux": True, "can_mux": True}, - { - "name": "mov", - "description": "QuickTime / MOV", - "can_demux": True, - "can_mux": True, - }, - { - "name": "mkv", - "description": "Matroska", - "can_demux": 
True, - "can_mux": True, - }, - {"name": "webm", "description": "WebM", "can_demux": True, "can_mux": True}, - {"name": "avi", "description": "AVI", "can_demux": True, "can_mux": True}, - {"name": "flv", "description": "FLV", "can_demux": True, "can_mux": True}, - { - "name": "ts", - "description": "MPEG-TS", - "can_demux": True, - "can_mux": True, - }, - { - "name": "mpegts", - "description": "MPEG-TS", - "can_demux": True, - "can_mux": True, - }, - {"name": "hls", "description": "HLS", "can_demux": True, "can_mux": True}, - ], - } - - -def get_encoders() -> List[Codec]: - """Get available encoders (video + audio).""" - info = _get_ffmpeg_info() - codecs = [] - - for c in info["video_encoders"]: - codecs.append(Codec(name=c["name"], description=c["description"], type="video")) - - for c in info["audio_encoders"]: - codecs.append(Codec(name=c["name"], description=c["description"], type="audio")) - - return codecs - - -def get_decoders() -> List[Codec]: - """Get available decoders.""" - # Most encoders can also decode - return get_encoders() - - -def get_formats() -> List[Format]: - """Get available formats.""" - info = _get_ffmpeg_info() - return [ - Format( - name=f["name"], - description=f["description"], - can_demux=f["can_demux"], - can_mux=f["can_mux"], - ) - for f in info["formats"] - ] - - -def get_video_encoders() -> List[Codec]: - """Get available video encoders.""" - return [c for c in get_encoders() if c.type == "video"] - - -def get_audio_encoders() -> List[Codec]: - """Get available audio encoders.""" - return [c for c in get_encoders() if c.type == "audio"] - - -def get_muxers() -> List[Format]: - """Get available output formats (muxers).""" - return [f for f in get_formats() if f.can_mux] - - -def get_demuxers() -> List[Format]: - """Get available input formats (demuxers).""" - return [f for f in get_formats() if f.can_demux] diff --git a/core/ffmpeg/transcode.py b/core/ffmpeg/transcode.py deleted file mode 100644 index f12084f..0000000 --- 
a/core/ffmpeg/transcode.py +++ /dev/null @@ -1,225 +0,0 @@ -""" -FFmpeg transcode module - Transcode media files using ffmpeg-python. -""" - -import sys -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional - -import ffmpeg - - -@dataclass -class TranscodeConfig: - """Configuration for a transcode operation.""" - - input_path: str - output_path: str - - # Video - video_codec: str = "libx264" - video_bitrate: Optional[str] = None - video_crf: Optional[int] = None - video_preset: Optional[str] = None - resolution: Optional[str] = None - framerate: Optional[float] = None - - # Audio - audio_codec: str = "aac" - audio_bitrate: Optional[str] = None - audio_channels: Optional[int] = None - audio_samplerate: Optional[int] = None - - # Trimming - trim_start: Optional[float] = None - trim_end: Optional[float] = None - - # Container - container: str = "mp4" - - # Extra args (key-value pairs) - extra_args: List[str] = field(default_factory=list) - - @property - def is_copy(self) -> bool: - """Check if this is a stream copy (no transcoding).""" - return self.video_codec == "copy" and self.audio_codec == "copy" - - -def build_stream(config: TranscodeConfig): - """ - Build an ffmpeg-python stream from config. - - Returns the stream object ready to run. 
- """ - # Input options - input_kwargs = {} - if config.trim_start is not None: - input_kwargs["ss"] = config.trim_start - - stream = ffmpeg.input(config.input_path, **input_kwargs) - - # Output options - output_kwargs = { - "vcodec": config.video_codec, - "acodec": config.audio_codec, - } - - # Trimming duration - if config.trim_end is not None: - if config.trim_start is not None: - output_kwargs["t"] = config.trim_end - config.trim_start - else: - output_kwargs["t"] = config.trim_end - - # Video options (skip if copy) - if config.video_codec != "copy": - if config.video_crf is not None: - output_kwargs["crf"] = config.video_crf - elif config.video_bitrate: - output_kwargs["video_bitrate"] = config.video_bitrate - - if config.video_preset: - output_kwargs["preset"] = config.video_preset - - if config.resolution: - output_kwargs["s"] = config.resolution - - if config.framerate: - output_kwargs["r"] = config.framerate - - # Audio options (skip if copy) - if config.audio_codec != "copy": - if config.audio_bitrate: - output_kwargs["audio_bitrate"] = config.audio_bitrate - if config.audio_channels: - output_kwargs["ac"] = config.audio_channels - if config.audio_samplerate: - output_kwargs["ar"] = config.audio_samplerate - - # Parse extra args into kwargs - extra_kwargs = parse_extra_args(config.extra_args) - output_kwargs.update(extra_kwargs) - - stream = ffmpeg.output(stream, config.output_path, **output_kwargs) - stream = ffmpeg.overwrite_output(stream) - - return stream - - -def parse_extra_args(extra_args: List[str]) -> Dict[str, Any]: - """ - Parse extra args list into kwargs dict. 
- - ["-vtag", "xvid", "-pix_fmt", "yuv420p"] -> {"vtag": "xvid", "pix_fmt": "yuv420p"} - """ - kwargs = {} - i = 0 - while i < len(extra_args): - key = extra_args[i].lstrip("-") - if i + 1 < len(extra_args) and not extra_args[i + 1].startswith("-"): - kwargs[key] = extra_args[i + 1] - i += 2 - else: - # Flag without value - kwargs[key] = None - i += 1 - return kwargs - - -def transcode( - config: TranscodeConfig, - duration: Optional[float] = None, - progress_callback: Optional[Callable[[float, Dict[str, Any]], None]] = None, -) -> bool: - """ - Transcode a media file. - - Args: - config: Transcode configuration - duration: Total duration in seconds (for progress calculation, optional) - progress_callback: Called with (percent, details_dict) - requires duration - - Returns: - True if successful - - Raises: - ffmpeg.Error: If transcoding fails - """ - # Ensure output directory exists - Path(config.output_path).parent.mkdir(parents=True, exist_ok=True) - - stream = build_stream(config) - - if progress_callback and duration: - # Run with progress tracking using run_async - return _run_with_progress(stream, config, duration, progress_callback) - else: - # Run synchronously - ffmpeg.run(stream, capture_stdout=True, capture_stderr=True) - return True - - -def _run_with_progress( - stream, - config: TranscodeConfig, - duration: float, - progress_callback: Callable[[float, Dict[str, Any]], None], -) -> bool: - """Run FFmpeg with progress tracking using run_async and stderr parsing.""" - import re - - # Calculate effective duration - effective_duration = duration - if config.trim_start and config.trim_end: - effective_duration = config.trim_end - config.trim_start - elif config.trim_end: - effective_duration = config.trim_end - elif config.trim_start: - effective_duration = duration - config.trim_start - - # Run async to get process handle - process = ffmpeg.run_async(stream, pipe_stdout=True, pipe_stderr=True) - - # Parse stderr for progress (time=HH:MM:SS.ms pattern) - 
time_pattern = re.compile(r"time=(\d+):(\d+):(\d+)\.(\d+)") - - while True: - line = process.stderr.readline() - if not line: - break - - line = line.decode("utf-8", errors="ignore") - match = time_pattern.search(line) - if match: - hours = int(match.group(1)) - minutes = int(match.group(2)) - seconds = int(match.group(3)) - ms = int(match.group(4)) - - current_time = hours * 3600 + minutes * 60 + seconds + ms / 100 - percent = min(100.0, (current_time / effective_duration) * 100) - - progress_callback( - percent, - { - "time": current_time, - "percent": percent, - }, - ) - - # Wait for completion - process.wait() - - if process.returncode != 0: - raise ffmpeg.Error( - "ffmpeg", stdout=process.stdout.read(), stderr=process.stderr.read() - ) - - # Final callback - progress_callback( - 100.0, {"time": effective_duration, "percent": 100.0, "done": True} - ) - - return True diff --git a/gpu/.env.template b/core/gpu/.env.template similarity index 100% rename from gpu/.env.template rename to core/gpu/.env.template diff --git a/gpu/Dockerfile b/core/gpu/Dockerfile similarity index 100% rename from gpu/Dockerfile rename to core/gpu/Dockerfile diff --git a/gpu/__init__.py b/core/gpu/__init__.py similarity index 100% rename from gpu/__init__.py rename to core/gpu/__init__.py diff --git a/gpu/config.py b/core/gpu/config.py similarity index 100% rename from gpu/config.py rename to core/gpu/config.py diff --git a/gpu/emit.py b/core/gpu/emit.py similarity index 100% rename from gpu/emit.py rename to core/gpu/emit.py diff --git a/core/gpu/models/__init__.py b/core/gpu/models/__init__.py new file mode 100644 index 0000000..6551d0f --- /dev/null +++ b/core/gpu/models/__init__.py @@ -0,0 +1,6 @@ +# GPU models — standalone container imports. +# When running as a container (cd gpu && python server.py), bare imports work. +# When imported from the main app (core.gpu.models.preprocess), only +# individual modules should be imported directly, not this __init__. 
+# +# The server.py imports detect/ocr/vlm directly, not through this file. diff --git a/gpu/models/cv/__init__.py b/core/gpu/models/cv/__init__.py similarity index 100% rename from gpu/models/cv/__init__.py rename to core/gpu/models/cv/__init__.py diff --git a/gpu/models/cv/edges.py b/core/gpu/models/cv/edges.py similarity index 100% rename from gpu/models/cv/edges.py rename to core/gpu/models/cv/edges.py diff --git a/core/gpu/models/cv/segmentation.py b/core/gpu/models/cv/segmentation.py new file mode 100644 index 0000000..cf8ee38 --- /dev/null +++ b/core/gpu/models/cv/segmentation.py @@ -0,0 +1,86 @@ +""" +Field segmentation — HSV green mask → pitch boundary contour. + +Pure OpenCV. Called by the inference server endpoint. +""" + +from __future__ import annotations + +import base64 + +import cv2 +import numpy as np + + +def segment_field( + image: np.ndarray, + hue_low: int = 30, + hue_high: int = 85, + sat_low: int = 30, + sat_high: int = 255, + val_low: int = 30, + val_high: int = 255, + morph_kernel: int = 15, + min_area_ratio: float = 0.05, +) -> dict: + """ + Detect the pitch area using HSV green thresholding. 
+ + Returns dict with: + boundary: list of [x, y] points + coverage: float (fraction of frame) + """ + hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV) + + lower = np.array([hue_low, sat_low, val_low]) + upper = np.array([hue_high, sat_high, val_high]) + mask = cv2.inRange(hsv, lower, upper) + + k = morph_kernel + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)) + mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) + mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel) + + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + h, w = image.shape[:2] + min_area = min_area_ratio * h * w + boundary = [] + coverage = 0.0 + + if contours: + large = [c for c in contours if cv2.contourArea(c) >= min_area] + if large: + pitch_contour = max(large, key=cv2.contourArea) + boundary = pitch_contour.reshape(-1, 2).tolist() + coverage = cv2.contourArea(pitch_contour) / (h * w) + + refined = np.zeros_like(mask) + cv2.drawContours(refined, [pitch_contour], -1, 255, cv2.FILLED) + mask = refined + + return { + "boundary": boundary, + "coverage": coverage, + "mask": mask, + } + + +def segment_field_debug( + image: np.ndarray, + **kwargs, +) -> dict: + """Same as segment_field but includes a mask overlay for the editor.""" + result = segment_field(image, **kwargs) + mask = result["mask"] + + # RGBA overlay: solid green where mask, fully transparent elsewhere + h, w = image.shape[:2] + overlay = np.zeros((h, w, 4), dtype=np.uint8) + overlay[mask > 0] = [0, 255, 0, 255] + _, buf = cv2.imencode(".png", overlay) + result["mask_overlay_b64"] = base64.b64encode(buf.tobytes()).decode() + + # Don't send the raw mask over HTTP + del result["mask"] + return result diff --git a/gpu/models/inference_contract.py b/core/gpu/models/models.py similarity index 80% rename from gpu/models/inference_contract.py rename to core/gpu/models/models.py index 5f63639..f2a90f5 100644 --- a/gpu/models/inference_contract.py +++ b/core/gpu/models/models.py @@ -101,6 +101,30 
@@ class AnalyzeRegionsDebugResponse(BaseModel): horizontal_count: int = 0 pair_count: int = 0 +class SegmentFieldRequest(BaseModel): + """Request body for field segmentation.""" + image: str + hue_low: int = 30 + hue_high: int = 85 + sat_low: int = 30 + sat_high: int = 255 + val_low: int = 30 + val_high: int = 255 + morph_kernel: int = 15 + min_area_ratio: float = 0.05 + +class SegmentFieldResponse(BaseModel): + """Response from field segmentation.""" + boundary: List[List[int]] = Field(default_factory=list) + coverage: float = 0.0 + mask_b64: str = "" + +class SegmentFieldDebugResponse(BaseModel): + """Response from field segmentation with debug overlay.""" + boundary: List[List[int]] = Field(default_factory=list) + coverage: float = 0.0 + mask_overlay_b64: str = "" + class ConfigUpdate(BaseModel): """Request body for updating server configuration.""" device: Optional[str] = None diff --git a/gpu/models/ocr.py b/core/gpu/models/ocr.py similarity index 100% rename from gpu/models/ocr.py rename to core/gpu/models/ocr.py diff --git a/gpu/models/preprocess.py b/core/gpu/models/preprocess.py similarity index 100% rename from gpu/models/preprocess.py rename to core/gpu/models/preprocess.py diff --git a/gpu/models/registry.py b/core/gpu/models/registry.py similarity index 100% rename from gpu/models/registry.py rename to core/gpu/models/registry.py diff --git a/gpu/models/vlm.py b/core/gpu/models/vlm.py similarity index 100% rename from gpu/models/vlm.py rename to core/gpu/models/vlm.py diff --git a/gpu/models/yolo.py b/core/gpu/models/yolo.py similarity index 100% rename from gpu/models/yolo.py rename to core/gpu/models/yolo.py diff --git a/gpu/requirements.txt b/core/gpu/requirements.txt similarity index 100% rename from gpu/requirements.txt rename to core/gpu/requirements.txt diff --git a/gpu/run.sh b/core/gpu/run.sh similarity index 100% rename from gpu/run.sh rename to core/gpu/run.sh diff --git a/gpu/server.py b/core/gpu/server.py similarity index 79% rename from 
gpu/server.py rename to core/gpu/server.py index 76ce8a6..5f828aa 100644 --- a/gpu/server.py +++ b/core/gpu/server.py @@ -54,7 +54,7 @@ def _gpu_log(job_id: str, log_level: str, stage: str, level: str, msg: str): # --- Request/Response models (generated from core/schema/models/inference.py) --- -from models.inference_contract import ( +from models.models import ( AnalyzeRegionsDebugResponse, AnalyzeRegionsRequest, AnalyzeRegionsResponse, @@ -68,6 +68,9 @@ from models.inference_contract import ( PreprocessRequest, PreprocessResponse, RegionBox, + SegmentFieldRequest, + SegmentFieldResponse, + SegmentFieldDebugResponse, VLMRequest, VLMResponse, ) @@ -310,6 +313,83 @@ def detect_edges_debug_endpoint(req: AnalyzeRegionsRequest, request: Request): return response +@app.post("/segment_field", response_model=SegmentFieldResponse) +def segment_field_endpoint(req: SegmentFieldRequest, request: Request): + job_id, log_level = _job_ctx(request) + + try: + image = _decode_image(req.image) + h, w = image.shape[:2] + except Exception as e: + raise HTTPException(status_code=400, detail=f"Bad image: {e}") + + try: + t0 = time.monotonic() + from models.cv.segmentation import segment_field + + result = segment_field( + image, + hue_low=req.hue_low, + hue_high=req.hue_high, + sat_low=req.sat_low, + sat_high=req.sat_high, + val_low=req.val_low, + val_high=req.val_high, + morph_kernel=req.morph_kernel, + min_area_ratio=req.min_area_ratio, + ) + infer_ms = (time.monotonic() - t0) * 1000 + + # Encode mask as base64 PNG for downstream use + import cv2 + _, buf = cv2.imencode(".png", result["mask"]) + mask_b64 = base64.b64encode(buf.tobytes()).decode() + + _gpu_log(job_id, log_level, "GPU:CV", "DEBUG", + f"Field segmentation {w}x{h}: {infer_ms:.0f}ms, coverage={result['coverage']:.1%}") + except Exception as e: + raise HTTPException(status_code=500, detail=f"Field segmentation failed: {e}") + + return SegmentFieldResponse( + boundary=result["boundary"], + coverage=result["coverage"], + 
mask_b64=mask_b64, + ) + + +@app.post("/segment_field/debug", response_model=SegmentFieldDebugResponse) +def segment_field_debug_endpoint(req: SegmentFieldRequest, request: Request): + job_id, log_level = _job_ctx(request) + + try: + image = _decode_image(req.image) + except Exception as e: + raise HTTPException(status_code=400, detail=f"Bad image: {e}") + + try: + from models.cv.segmentation import segment_field_debug + + result = segment_field_debug( + image, + hue_low=req.hue_low, + hue_high=req.hue_high, + sat_low=req.sat_low, + sat_high=req.sat_high, + val_low=req.val_low, + val_high=req.val_high, + morph_kernel=req.morph_kernel, + min_area_ratio=req.min_area_ratio, + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Field segmentation debug failed: {e}") + + return SegmentFieldDebugResponse( + boundary=result["boundary"], + coverage=result["coverage"], + mask_overlay_b64=result["mask_overlay_b64"], + ) + + if __name__ == "__main__": import uvicorn diff --git a/core/rpc/server.py b/core/rpc/server.py index aac866d..5f02f06 100644 --- a/core/rpc/server.py +++ b/core/rpc/server.py @@ -180,24 +180,11 @@ class WorkerServicer(worker_pb2_grpc.WorkerServiceServicer): def GetWorkerStatus(self, request, context): """Get worker health and capabilities.""" - try: - from core.ffmpeg import get_encoders - - encoders = get_encoders() - codec_names = [e["name"] for e in encoders.get("video", [])] - except Exception: - codec_names = [] - - # Check for GPU encoders - gpu_available = any( - "nvenc" in name or "vaapi" in name or "qsv" in name for name in codec_names - ) - return worker_pb2.WorkerStatus( available=True, active_jobs=len(_active_jobs), - supported_codecs=codec_names[:20], # Limit to 20 - gpu_available=gpu_available, + supported_codecs=[], + gpu_available=False, ) diff --git a/core/schema/modelgen.json b/core/schema/modelgen.json index 9783025..b0b4b36 100644 --- a/core/schema/modelgen.json +++ b/core/schema/modelgen.json @@ -28,7 +28,7 @@ 
}, { "target": "pydantic", - "output": "detect/sse_contract.py", + "output": "core/detect/sse.py", "include": ["detect_views"] }, { @@ -43,8 +43,13 @@ }, { "target": "pydantic", - "output": "gpu/models/inference_contract.py", + "output": "core/gpu/models/models.py", "include": ["inference_views"] + }, + { + "target": "pydantic", + "output": "core/detect/stages/models.py", + "include": ["stage_views"] } ] } diff --git a/core/schema/models/__init__.py b/core/schema/models/__init__.py index f4a1416..1006e3d 100644 --- a/core/schema/models/__init__.py +++ b/core/schema/models/__init__.py @@ -30,19 +30,18 @@ from .timeline import Timeline from .checkpoint import Checkpoint from .brand import BrandSource, Brand from .media import AssetStatus, MediaAsset -from .presets import BUILTIN_PRESETS, TranscodePreset -from .detect import DETECT_VIEWS # noqa: F401 — discovered by modelgen generic loader -from .inference import INFERENCE_VIEWS # noqa: F401 — GPU inference server API types -from .ui_state import UI_STATE_VIEWS # noqa: F401 — UI store state types -from .stages import StageConfigField, StageIO, StageDefinition, STAGE_VIEWS # noqa: F401 -from .pipeline_config import StageRef, Edge, PipelineConfig, PIPELINE_CONFIG_VIEWS # noqa: F401 -from .detect_api import RunRequest, RunResponse, DETECT_API_VIEWS # noqa: F401 -from .views import ChunkEvent, ChunkOutputFile, PipelineStats, WorkerEvent -from .sources import ChunkInfo, SourceJob, SourceType +from .profile import Profile +from .preset import BUILTIN_PRESETS, TranscodePreset +from .event import DETECT_VIEWS # noqa: F401 +from .inference import INFERENCE_VIEWS # noqa: F401 +from .ui_state import UI_STATE_VIEWS # noqa: F401 +from .stage import STAGE_VIEWS # noqa: F401 +from .view import ChunkEvent, ChunkOutputFile, PipelineStats, WorkerEvent +from .source import ChunkInfo, SourceJob, SourceType # Core domain models - generates SQLModel, TypeScript DATACLASSES = [MediaAsset, TranscodePreset, - Job, Timeline, Checkpoint, Brand] 
+ Job, Timeline, Checkpoint, Brand, Profile] # API request/response models API_MODELS = [ @@ -75,55 +74,3 @@ GRPC_MESSAGES = [ ChunkStreamRequest, ChunkPipelineEvent, ] - -__all__ = [ - # Models - "MediaAsset", - "TranscodePreset", - "Job", - "Timeline", - "Checkpoint", - # Enums - "AssetStatus", - "JobStatus", - "RunType", - "BrandSource", - "SourceType", - # Stages - "StageConfigField", - "StageIO", - "StageDefinition", - # API - "CreateJobRequest", - "UpdateAssetRequest", - "DeleteResult", - "ScanResult", - "SystemStatus", - # gRPC - "GRPC_SERVICE", - "JobRequest", - "JobResponse", - "ProgressRequest", - "ProgressUpdate", - "CancelRequest", - "CancelResponse", - "WorkerStatus", - "Empty", - "ChunkStreamRequest", - "ChunkPipelineEvent", - # Views - "ChunkEvent", - "WorkerEvent", - "PipelineStats", - "ChunkOutputFile", - # Sources - "SourceJob", - "ChunkInfo", - # For generator - "DATACLASSES", - "API_MODELS", - "ENUMS", - "VIEWS", - "GRPC_MESSAGES", - "BUILTIN_PRESETS", -] diff --git a/core/schema/models/checkpoint.py b/core/schema/models/checkpoint.py index 3e9b594..038aee7 100644 --- a/core/schema/models/checkpoint.py +++ b/core/schema/models/checkpoint.py @@ -20,6 +20,7 @@ class Checkpoint: id: UUID timeline_id: UUID + job_id: Optional[UUID] = None # which job created this checkpoint parent_id: Optional[UUID] = None # null = root checkpoint # Stage outputs — JSONB per stage, opaque to the checkpoint layer diff --git a/core/schema/models/detect_api.py b/core/schema/models/detect_api.py deleted file mode 100644 index ff6f998..0000000 --- a/core/schema/models/detect_api.py +++ /dev/null @@ -1,31 +0,0 @@ -""" -Detection API request/response models. - -Source of truth for detection pipeline API shapes. -Generated to Pydantic via modelgen. 
-""" - -from dataclasses import dataclass - - -@dataclass -class RunRequest: - """Request body for launching a detection pipeline run.""" - video_path: str # storage key - profile_name: str = "soccer_broadcast" - source_asset_id: str = "" - checkpoint: bool = True - skip_vlm: bool = False - skip_cloud: bool = False - log_level: str = "INFO" # INFO | DEBUG - - -@dataclass -class RunResponse: - """Response after starting a pipeline run.""" - status: str - job_id: str - video_path: str - - -DETECT_API_VIEWS = [RunRequest, RunResponse] diff --git a/core/schema/models/detect.py b/core/schema/models/event.py similarity index 91% rename from core/schema/models/detect.py rename to core/schema/models/event.py index 2450af7..8f3b755 100644 --- a/core/schema/models/detect.py +++ b/core/schema/models/event.py @@ -214,6 +214,29 @@ class RetryResponse: job_id: str +# --- API request/response --- + + +@dataclass +class RunRequest: + """Request body for launching a detection pipeline run.""" + video_path: str + profile_name: str = "soccer_broadcast" + source_asset_id: str = "" + checkpoint: bool = True + skip_vlm: bool = False + skip_cloud: bool = False + log_level: str = "INFO" + + +@dataclass +class RunResponse: + """Response after starting a pipeline run.""" + status: str + job_id: str + video_path: str + + # --- Export lists for modelgen --- DETECT_VIEWS = [ @@ -234,4 +257,6 @@ DETECT_VIEWS = [ ReplayResponse, RetryRequest, RetryResponse, + RunRequest, + RunResponse, ] diff --git a/core/schema/models/inference.py b/core/schema/models/inference.py index 117c129..cad4ee4 100644 --- a/core/schema/models/inference.py +++ b/core/schema/models/inference.py @@ -160,6 +160,39 @@ class AnalyzeRegionsDebugResponse: pair_count: int = 0 +# --- Field Segmentation --- + + +@dataclass +class SegmentFieldRequest: + """Request body for field segmentation.""" + image: str # base64 JPEG + hue_low: int = 30 + hue_high: int = 85 + sat_low: int = 30 + sat_high: int = 255 + val_low: int = 30 + 
val_high: int = 255 + morph_kernel: int = 15 + min_area_ratio: float = 0.05 + + +@dataclass +class SegmentFieldResponse: + """Response from field segmentation.""" + boundary: List[List[int]] = field(default_factory=list) + coverage: float = 0.0 + mask_b64: str = "" # binary mask as base64 PNG (for downstream stages) + + +@dataclass +class SegmentFieldDebugResponse: + """Response from field segmentation with debug overlay.""" + boundary: List[List[int]] = field(default_factory=list) + coverage: float = 0.0 + mask_overlay_b64: str = "" + + # --- Server Config --- @@ -193,5 +226,8 @@ INFERENCE_VIEWS = [ RegionBox, AnalyzeRegionsResponse, AnalyzeRegionsDebugResponse, + SegmentFieldRequest, + SegmentFieldResponse, + SegmentFieldDebugResponse, ConfigUpdate, ] diff --git a/core/schema/models/job.py b/core/schema/models/job.py index fda2d0c..5cf3ba4 100644 --- a/core/schema/models/job.py +++ b/core/schema/models/job.py @@ -38,6 +38,9 @@ class Job: video_path: str profile_name: str = "soccer_broadcast" + # Timeline — set after frame extraction, or upfront for replay jobs + timeline_id: Optional[UUID] = None + # Lineage parent_id: Optional[UUID] = None run_type: RunType = RunType.INITIAL diff --git a/core/schema/models/pipeline_config.py b/core/schema/models/pipeline_config.py deleted file mode 100644 index b2a710e..0000000 --- a/core/schema/models/pipeline_config.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -Pipeline composition config — source of truth for graph topology. - -Defines what stages run, in what order, with what branching. -Belongs to a profile. Persisted as JSONB. - -The execution strategy (serial, parallel, distributed) is separate — -the runner reads this config and flattens it into a sequence for now. 
-""" - -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional - - -@dataclass -class StageRef: - """Reference to a stage in the pipeline graph.""" - name: str # stage name (matches StageDefinition.name) - branch: str = "trunk" # which branch this belongs to - execution_target: str = "local" # local | gpu | lambda | gcp - - -@dataclass -class Edge: - """Connection between stages in the graph.""" - source: str # stage name - target: str # stage name - condition: str = "" # empty = unconditional, otherwise a routing rule key - - -@dataclass -class PipelineConfig: - """ - Pipeline graph topology + routing rules. - - Holder model — stages/edges define the graph shape, - routing_rules is a JSONB blob for decision tree logic. - """ - name: str - profile_name: str - stages: List[StageRef] = field(default_factory=list) - edges: List[Edge] = field(default_factory=list) - routing_rules: Dict[str, Any] = field(default_factory=dict) - - -PIPELINE_CONFIG_VIEWS = [StageRef, Edge, PipelineConfig] diff --git a/core/schema/models/presets.py b/core/schema/models/preset.py similarity index 100% rename from core/schema/models/presets.py rename to core/schema/models/preset.py diff --git a/core/schema/models/profile.py b/core/schema/models/profile.py new file mode 100644 index 0000000..4c6c011 --- /dev/null +++ b/core/schema/models/profile.py @@ -0,0 +1,30 @@ +""" +Profile schema — source of truth for content type profiles. + +A profile has two JSONB fields: + - pipeline: graph topology (stages, edges, routing rules) + - configs: per-stage config values keyed by stage name + +Validated at read time using generated contracts (StageConfigField, PipelineConfig). +""" + +from dataclasses import dataclass, field +from typing import Any, Dict +from uuid import UUID + + +@dataclass +class Profile: + """ + A content type profile. + + Defines what pipeline to run and how each stage is configured. + Seed data inserted via JSON fixtures on startup. 
+ """ + id: UUID + name: str + pipeline: Dict[str, Any] = field(default_factory=dict) + configs: Dict[str, Any] = field(default_factory=dict) + + +PROFILE_VIEWS = [Profile] diff --git a/core/schema/models/sources.py b/core/schema/models/source.py similarity index 100% rename from core/schema/models/sources.py rename to core/schema/models/source.py diff --git a/core/schema/models/stage.py b/core/schema/models/stage.py new file mode 100644 index 0000000..46f8477 --- /dev/null +++ b/core/schema/models/stage.py @@ -0,0 +1,153 @@ +""" +Stage & Pipeline Schema Definitions + +Source of truth for: +- Stage metadata (StageDefinition, config fields, IO) +- Stage config shapes (FrameExtractionConfig, etc.) +- Pipeline topology (StageRef, Edge, PipelineConfig) + +Generates: Pydantic (detect/contract.py), TypeScript via modelgen. +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + + +# --- Stage metadata --- + +@dataclass +class StageConfigField: + """A single tunable config parameter for the editor UI.""" + name: str + type: str # "float", "int", "str", "bool" + default: Any + description: str = "" + min: Optional[float] = None + max: Optional[float] = None + options: Optional[List[str]] = None + + +@dataclass +class StageIO: + """Declares what a stage reads and writes.""" + reads: List[str] = field(default_factory=list) + writes: List[str] = field(default_factory=list) + optional_reads: List[str] = field(default_factory=list) + + +@dataclass +class StageDefinition: + """Complete metadata for a pipeline stage.""" + name: str + label: str + description: str + category: str = "detection" + io: StageIO = field(default_factory=StageIO) + config_fields: List[StageConfigField] = field(default_factory=list) + tracks_element: Optional[str] = None + + +# --- Stage config shapes --- + +@dataclass +class FrameExtractionConfig: + fps: float = 2.0 + max_frames: int = 500 + + +@dataclass +class SceneFilterConfig: + hamming_threshold: int = 8 + 
enabled: bool = True + + +@dataclass +class DetectionConfig: + model_name: str = "yolov8n.pt" + confidence_threshold: float = 0.3 + target_classes: List[str] = field(default_factory=lambda: ["logo", "text"]) + + +@dataclass +class OCRConfig: + languages: List[str] = field(default_factory=lambda: ["en"]) + min_confidence: float = 0.5 + + +@dataclass +class ResolverConfig: + fuzzy_threshold: int = 75 + + +@dataclass +class RegionAnalysisConfig: + enabled: bool = True + edge_canny_low: int = 50 + edge_canny_high: int = 150 + edge_hough_threshold: int = 80 + edge_hough_min_length: int = 100 + edge_hough_max_gap: int = 10 + edge_pair_max_distance: int = 200 + edge_pair_min_distance: int = 15 + + +@dataclass +class FieldSegmentationConfig: + enabled: bool = True + # HSV green range for pitch detection + hue_low: int = 30 + hue_high: int = 85 + sat_low: int = 30 + sat_high: int = 255 + val_low: int = 30 + val_high: int = 255 + # Morphology + morph_kernel: int = 15 # kernel size for close/open + min_area_ratio: float = 0.05 # minimum contour area as fraction of frame + + +# --- Pipeline topology --- + +@dataclass +class StageRef: + """Reference to a stage in the pipeline graph.""" + name: str + branch: str = "trunk" + execution_target: str = "local" + + +@dataclass +class Edge: + """Connection between stages in the graph.""" + source: str + target: str + condition: str = "" + + +@dataclass +class PipelineConfig: + """Pipeline graph topology + routing rules.""" + name: str + profile_name: str + stages: List[StageRef] = field(default_factory=list) + edges: List[Edge] = field(default_factory=list) + routing_rules: Dict[str, Any] = field(default_factory=dict) + + +# --- Export for modelgen --- + +STAGE_VIEWS = [ + StageConfigField, + StageIO, + StageDefinition, + FrameExtractionConfig, + SceneFilterConfig, + DetectionConfig, + OCRConfig, + ResolverConfig, + RegionAnalysisConfig, + FieldSegmentationConfig, + StageRef, + Edge, + PipelineConfig, +] diff --git 
a/core/schema/models/stages.py b/core/schema/models/stages.py deleted file mode 100644 index d7ba1d4..0000000 --- a/core/schema/models/stages.py +++ /dev/null @@ -1,69 +0,0 @@ -""" -Stage Schema Definitions - -Source of truth for pipeline stage metadata. -Generates: Pydantic, TypeScript via modelgen. - -Each stage is defined by its config fields. The implementation -lives in detect/stages/.py as a Stage subclass. -""" - -from dataclasses import dataclass, field -from typing import Any, List, Optional - - -@dataclass -class StageConfigField: - """A single tunable config parameter for the editor UI.""" - name: str - type: str # "float", "int", "str", "bool" - default: Any - description: str = "" - min: Optional[float] = None - max: Optional[float] = None - options: Optional[List[str]] = None - - -@dataclass -class StageIO: - """Declares what a stage reads and writes.""" - reads: List[str] = field(default_factory=list) - writes: List[str] = field(default_factory=list) - optional_reads: List[str] = field(default_factory=list) - - -@dataclass -class StageDefinition: - """ - Complete metadata for a pipeline stage. - - Lives in schema as the source of truth. Each stage implementation - references a StageDefinition. The editor, graph, and checkpoint - system all consume this. - """ - name: str - label: str - description: str - category: str = "detection" - io: StageIO = field(default_factory=StageIO) - config_fields: List[StageConfigField] = field(default_factory=list) - - # The box label this stage produces that should be time-tracked in the editor. - # Set to the label string (e.g. "edge_region") for stages that have a - # meaningful temporal element. None means no motion tracker overlay. - tracks_element: Optional[str] = None - - # Legacy fields — used by old registry pattern during migration. - # New stages use Stage subclass instead. 
- fn: Any = None - serialize_fn: Any = None - deserialize_fn: Any = None - - -# --- Export for modelgen --- - -STAGE_VIEWS = [ - StageConfigField, - StageIO, - StageDefinition, -] diff --git a/core/schema/models/views.py b/core/schema/models/view.py similarity index 100% rename from core/schema/models/views.py rename to core/schema/models/view.py diff --git a/core/schema/serializers/pipeline.py b/core/schema/serializers/pipeline.py index e8b1440..92e78a7 100644 --- a/core/schema/serializers/pipeline.py +++ b/core/schema/serializers/pipeline.py @@ -11,7 +11,7 @@ from __future__ import annotations import dataclasses -from core.schema.models.pipeline import ( +from core.detect.models import ( BoundingBox, BrandDetection, BrandStats, @@ -36,7 +36,7 @@ def serialize_frame_meta(frame: Frame) -> dict: def serialize_frames_with_upload(frames: list[Frame], job_id: str) -> tuple[list[dict], dict[int, str]]: """Upload frame images to S3, return metadata + manifest.""" - from detect.checkpoint.frames import save_frames + from core.detect.checkpoint.frames import save_frames manifest = save_frames(job_id, frames) meta = [serialize_frame_meta(f) for f in frames] @@ -45,7 +45,7 @@ def serialize_frames_with_upload(frames: list[Frame], job_id: str) -> tuple[list def deserialize_frames_with_download(meta: list[dict], manifest: dict, job_id: str) -> list[Frame]: """Load frames from S3 + metadata.""" - from detect.checkpoint.frames import load_frames + from core.detect.checkpoint.frames import load_frames int_manifest = {int(k): v for k, v in manifest.items()} return load_frames(int_manifest, meta) diff --git a/ctrl/Tiltfile b/ctrl/Tiltfile index 43521e7..4ba229f 100644 --- a/ctrl/Tiltfile +++ b/ctrl/Tiltfile @@ -18,7 +18,7 @@ docker_build( 'mpr-fastapi', context='..', dockerfile='Dockerfile', - ignore=['.git', 'def', 'docs', 'media', 'ui', 'gpu', 'modelgen', '.claude', 'tests'], + ignore=['.git', 'def', 'docs', 'media', 'ui', 'modelgen', '.claude', 'tests'], live_update=[ sync('..', 
'/app'), ], diff --git a/ctrl/sync.sh b/ctrl/sync.sh index 53476fc..c040959 100755 --- a/ctrl/sync.sh +++ b/ctrl/sync.sh @@ -1,9 +1,9 @@ #!/bin/bash -# Sync gpu/ folder to the GPU machine +# Sync core/gpu/ folder to the GPU machine # Usage: ./ctrl/sync.sh [HOST] [DEST] # # Examples: -# ./ctrl/sync.sh # defaults: mcrn:~/mpr/gpu +# ./ctrl/sync.sh # defaults: mcrndeb:~/wdir/mpr/gpu # ./ctrl/sync.sh 192.168.1.3 # custom host # ./ctrl/sync.sh mcrn ~/inference # custom host + dest @@ -13,11 +13,11 @@ cd "$(dirname "$0")/.." HOST="${1:-mcrndeb}" DEST="${2:-~/wdir/mpr/gpu}" -echo "Syncing gpu/ to ${HOST}:${DEST}..." +echo "Syncing core/gpu/ to ${HOST}:${DEST}..." rsync -avz --exclude='.git' --exclude='__pycache__' \ --exclude='*.pyc' --exclude='.env' \ --filter=':- .gitignore' \ - gpu/ "${HOST}:${DEST}/" + core/gpu/ "${HOST}:${DEST}/" echo "Done. Run on ${HOST}:" echo " cd ${DEST} && cp .env.template .env && ./run.sh" diff --git a/detect/checkpoint/runner_bridge.py b/detect/checkpoint/runner_bridge.py deleted file mode 100644 index e2b53df..0000000 --- a/detect/checkpoint/runner_bridge.py +++ /dev/null @@ -1,64 +0,0 @@ -""" -Runner bridge — checkpoint hook called by PipelineRunner after each stage. - -Owns the per-job state (frame manifest cache, checkpoint chain) that -the runner shouldn't know about. -""" - -from __future__ import annotations - -import logging - -logger = logging.getLogger(__name__) - -# Per-job state -_frames_manifest: dict[str, dict[int, str]] = {} -_latest_checkpoint: dict[str, str] = {} - - -def reset_checkpoint_state(job_id: str): - """Clean up per-job checkpoint state. Called when pipeline finishes.""" - _frames_manifest.pop(job_id, None) - _latest_checkpoint.pop(job_id, None) - - -def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: dict): - """ - Save a checkpoint after a stage completes. - - Called by the runner. 
Handles: - - Frame upload (once, on first stage) - - Stage output serialization (via stage registry) - - Checkpoint chain (parent → child) - """ - if not job_id: - return - - from .storage import save_stage_output - from .frames import save_frames - from detect.stages.base import _REGISTRY - - merged = {**state, **result} - - # Save frames once (first stage that produces them) - manifest = _frames_manifest.get(job_id) - if manifest is None and stage_name == "extract_frames": - manifest = save_frames(job_id, merged.get("frames", [])) - _frames_manifest[job_id] = manifest - - # Serialize stage output using the stage's serialize_fn if available - stage_cls = _REGISTRY.get(stage_name) - serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None) - if serialize_fn: - output_json = serialize_fn(merged, job_id) - else: - output_json = {} - - parent_id = _latest_checkpoint.get(job_id) - new_checkpoint_id = save_stage_output( - timeline_id=job_id, - parent_checkpoint_id=parent_id, - stage_name=stage_name, - output_json=output_json, - ) - _latest_checkpoint[job_id] = new_checkpoint_id diff --git a/detect/models.py b/detect/models.py deleted file mode 100644 index 2e2aeca..0000000 --- a/detect/models.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Re-export pipeline runtime models from core/schema/models/pipeline.py.""" - -from core.schema.models.pipeline import ( - BoundingBox, - BrandDetection, - BrandStats, - DetectionReport, - Frame, - PipelineStats, - TextCandidate, -) - -__all__ = [ - "BoundingBox", - "BrandDetection", - "BrandStats", - "DetectionReport", - "Frame", - "PipelineStats", - "TextCandidate", -] diff --git a/detect/profiles/__init__.py b/detect/profiles/__init__.py deleted file mode 100644 index 4b21b5b..0000000 --- a/detect/profiles/__init__.py +++ /dev/null @@ -1,35 +0,0 @@ -from .base import ( - ContentTypeProfile, - CropContext, - DetectionConfig, - FrameExtractionConfig, - OCRConfig, - ResolverConfig, - SceneFilterConfig, -) -from .soccer 
import SoccerBroadcastProfile - -_PROFILES: dict[str, type] = { - "soccer_broadcast": SoccerBroadcastProfile, -} - - -def get_profile(name: str) -> ContentTypeProfile: - """Get a profile instance by name.""" - cls = _PROFILES.get(name) - if cls is None: - raise ValueError(f"Unknown profile: {name!r}. Available: {list(_PROFILES)}") - return cls() - - -__all__ = [ - "ContentTypeProfile", - "CropContext", - "DetectionConfig", - "FrameExtractionConfig", - "OCRConfig", - "ResolverConfig", - "SceneFilterConfig", - "SoccerBroadcastProfile", - "get_profile", -] diff --git a/detect/profiles/base.py b/detect/profiles/base.py deleted file mode 100644 index f5ab78e..0000000 --- a/detect/profiles/base.py +++ /dev/null @@ -1,96 +0,0 @@ -""" -ContentTypeProfile protocol and config dataclasses. - -Each profile defines the pipeline topology (as a JSONB blob), stage configs, -brand dictionary, VLM prompt templates, and aggregation strategy. - -When profiles are persisted, the pipeline field is a JSONB column. -For now, profiles are code-only and pipeline_config() returns a hardcoded value. 
-""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import Any, Dict, Protocol - -from detect.models import BrandDetection, DetectionReport -from core.schema.models.pipeline_config import PipelineConfig, StageRef, Edge - - -@dataclass -class FrameExtractionConfig: - fps: float = 2.0 - max_frames: int = 500 - - -@dataclass -class SceneFilterConfig: - hamming_threshold: int = 8 - enabled: bool = True - - -@dataclass -class DetectionConfig: - model_name: str = "yolov8n.pt" - confidence_threshold: float = 0.3 - target_classes: list[str] = field(default_factory=lambda: ["logo", "text"]) - - -@dataclass -class OCRConfig: - languages: list[str] = field(default_factory=lambda: ["en"]) - min_confidence: float = 0.5 - - -@dataclass -class ResolverConfig: - fuzzy_threshold: int = 75 - - -@dataclass -class RegionAnalysisConfig: - enabled: bool = True - # Edge detection (Canny + HoughLinesP) - edge_canny_low: int = 50 - edge_canny_high: int = 150 - edge_hough_threshold: int = 80 - edge_hough_min_length: int = 100 - edge_hough_max_gap: int = 10 - edge_pair_max_distance: int = 200 - edge_pair_min_distance: int = 15 - - -@dataclass -class CropContext: - image: bytes - surrounding_text: str = "" - position_hint: str = "" - - -def pipeline_config_from_dict(data: Dict[str, Any]) -> PipelineConfig: - """Deserialize a PipelineConfig from a JSONB dict.""" - stages = [StageRef(**s) for s in data.get("stages", [])] - edges = [Edge(**e) for e in data.get("edges", [])] - return PipelineConfig( - name=data.get("name", ""), - profile_name=data.get("profile_name", ""), - stages=stages, - edges=edges, - routing_rules=data.get("routing_rules", {}), - ) - - -class ContentTypeProfile(Protocol): - name: str - pipeline: Dict[str, Any] # JSONB blob — PipelineConfig shape - - def pipeline_config(self) -> PipelineConfig: ... - def frame_extraction_config(self) -> FrameExtractionConfig: ... - def scene_filter_config(self) -> SceneFilterConfig: ... 
- def region_analysis_config(self) -> RegionAnalysisConfig: ... - def detection_config(self) -> DetectionConfig: ... - def ocr_config(self) -> OCRConfig: ... - def resolver_config(self) -> ResolverConfig: ... - def vlm_prompt(self, crop_context: CropContext) -> str: ... - def aggregate(self, detections: list[BrandDetection]) -> DetectionReport: ... - def auxiliary_detections(self, source: str) -> list[BrandDetection]: ... diff --git a/detect/profiles/soccer.py b/detect/profiles/soccer.py deleted file mode 100644 index c35727f..0000000 --- a/detect/profiles/soccer.py +++ /dev/null @@ -1,122 +0,0 @@ -"""Soccer broadcast profile — pitch hoardings, kits, scoreboards.""" - -from __future__ import annotations - -from core.schema.models.pipeline_config import PipelineConfig -from detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats - -from .base import ( - CropContext, - DetectionConfig, - FrameExtractionConfig, - OCRConfig, - RegionAnalysisConfig, - ResolverConfig, - SceneFilterConfig, - pipeline_config_from_dict, -) - - -class SoccerBroadcastProfile: - name = "soccer_broadcast" - - # Pipeline topology as JSONB — will be a DB field when profiles are persisted - pipeline = { - "name": "soccer_broadcast", - "profile_name": "soccer_broadcast", - "stages": [ - {"name": "extract_frames", "branch": "trunk"}, - {"name": "filter_scenes", "branch": "trunk"}, - {"name": "detect_edges", "branch": "hoarding"}, - {"name": "detect_objects", "branch": "objects"}, - {"name": "preprocess"}, - {"name": "run_ocr"}, - {"name": "match_brands"}, - {"name": "escalate_vlm"}, - {"name": "escalate_cloud"}, - {"name": "compile_report"}, - ], - "edges": [ - {"source": "extract_frames", "target": "filter_scenes"}, - {"source": "filter_scenes", "target": "detect_edges"}, - {"source": "filter_scenes", "target": "detect_objects"}, - {"source": "detect_edges", "target": "preprocess"}, - {"source": "detect_objects", "target": "preprocess"}, - {"source": "preprocess", "target": 
"run_ocr"}, - {"source": "run_ocr", "target": "match_brands"}, - {"source": "match_brands", "target": "escalate_vlm"}, - {"source": "escalate_vlm", "target": "escalate_cloud"}, - {"source": "escalate_cloud", "target": "compile_report"}, - ], - } - - def pipeline_config(self) -> PipelineConfig: - return pipeline_config_from_dict(self.pipeline) - - def frame_extraction_config(self) -> FrameExtractionConfig: - return FrameExtractionConfig(fps=2.0, max_frames=500) - - def scene_filter_config(self) -> SceneFilterConfig: - return SceneFilterConfig(hamming_threshold=8, enabled=True) - - def region_analysis_config(self) -> RegionAnalysisConfig: - return RegionAnalysisConfig( - edge_canny_low=50, - edge_canny_high=150, - edge_hough_threshold=80, - edge_hough_min_length=100, - edge_hough_max_gap=10, - edge_pair_max_distance=200, - edge_pair_min_distance=15, - ) - - def detection_config(self) -> DetectionConfig: - return DetectionConfig( - model_name="yolov8n.pt", - confidence_threshold=0.3, - target_classes=[], # empty = accept all COCO classes (until custom model) - ) - - def ocr_config(self) -> OCRConfig: - return OCRConfig(languages=["en", "es"], min_confidence=0.5) - - def resolver_config(self) -> ResolverConfig: - return ResolverConfig(fuzzy_threshold=75) - - def vlm_prompt(self, crop_context: CropContext) -> str: - hint = f" Position: {crop_context.position_hint}." if crop_context.position_hint else "" - text = f" Nearby text: '{crop_context.surrounding_text}'." if crop_context.surrounding_text else "" - return ( - f"Identify the brand or sponsor visible in this cropped region " - f"from a soccer broadcast.{hint}{text} " - f"Respond with: brand, confidence (0-1), reasoning." 
- ) - - def aggregate(self, detections: list[BrandDetection]) -> DetectionReport: - brands: dict[str, BrandStats] = {} - for d in detections: - if d.brand not in brands: - brands[d.brand] = BrandStats() - s = brands[d.brand] - s.total_appearances += 1 - s.total_screen_time += d.duration - s.avg_confidence = ( - (s.avg_confidence * (s.total_appearances - 1) + d.confidence) - / s.total_appearances - ) - if s.first_seen == 0.0 or d.timestamp < s.first_seen: - s.first_seen = d.timestamp - if d.timestamp > s.last_seen: - s.last_seen = d.timestamp - - return DetectionReport( - video_source="", - content_type=self.name, - duration_seconds=0.0, - brands=brands, - timeline=sorted(detections, key=lambda d: d.timestamp), - pipeline_stats=PipelineStats(), - ) - - def auxiliary_detections(self, source: str) -> list[BrandDetection]: - return [] diff --git a/detect/profiles/stubs.py b/detect/profiles/stubs.py deleted file mode 100644 index ba43163..0000000 --- a/detect/profiles/stubs.py +++ /dev/null @@ -1,98 +0,0 @@ -"""Stub profiles — interfaces defined, not yet implemented.""" - -from __future__ import annotations - -from detect.models import BrandDetection, DetectionReport - -from .base import ( - CropContext, - DetectionConfig, - FrameExtractionConfig, - OCRConfig, - ResolverConfig, - SceneFilterConfig, -) - - -class NewsBroadcastProfile: - name = "news_broadcast" - - def frame_extraction_config(self) -> FrameExtractionConfig: - raise NotImplementedError - - def scene_filter_config(self) -> SceneFilterConfig: - raise NotImplementedError - - def detection_config(self) -> DetectionConfig: - raise NotImplementedError - - def ocr_config(self) -> OCRConfig: - raise NotImplementedError - - def resolver_config(self) -> ResolverConfig: - raise NotImplementedError - - def vlm_prompt(self, crop_context: CropContext) -> str: - raise NotImplementedError - - def aggregate(self, detections: list[BrandDetection]) -> DetectionReport: - raise NotImplementedError - - def 
auxiliary_detections(self, source: str) -> list[BrandDetection]: - raise NotImplementedError - - -class AdvertisingProfile: - name = "advertising" - - def frame_extraction_config(self) -> FrameExtractionConfig: - raise NotImplementedError - - def scene_filter_config(self) -> SceneFilterConfig: - raise NotImplementedError - - def detection_config(self) -> DetectionConfig: - raise NotImplementedError - - def ocr_config(self) -> OCRConfig: - raise NotImplementedError - - def resolver_config(self) -> ResolverConfig: - raise NotImplementedError - - def vlm_prompt(self, crop_context: CropContext) -> str: - raise NotImplementedError - - def aggregate(self, detections: list[BrandDetection]) -> DetectionReport: - raise NotImplementedError - - def auxiliary_detections(self, source: str) -> list[BrandDetection]: - raise NotImplementedError - - -class TranscriptProfile: - name = "transcript" - - def frame_extraction_config(self) -> FrameExtractionConfig: - raise NotImplementedError - - def scene_filter_config(self) -> SceneFilterConfig: - raise NotImplementedError - - def detection_config(self) -> DetectionConfig: - raise NotImplementedError - - def ocr_config(self) -> OCRConfig: - raise NotImplementedError - - def resolver_config(self) -> ResolverConfig: - raise NotImplementedError - - def vlm_prompt(self, crop_context: CropContext) -> str: - raise NotImplementedError - - def aggregate(self, detections: list[BrandDetection]) -> DetectionReport: - raise NotImplementedError - - def auxiliary_detections(self, source: str) -> list[BrandDetection]: - raise NotImplementedError diff --git a/detect/stages/registry/cv_analysis.py b/detect/stages/registry/cv_analysis.py deleted file mode 100644 index 8ffb6c8..0000000 --- a/detect/stages/registry/cv_analysis.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Registration for CV analysis stages: edge detection.""" - -from detect.stages.base import StageDefinition, StageIO, StageConfigField, register_stage -from ._serializers import 
serialize_dataclass_list, deserialize_bounding_box - - -def _ser_regions(state: dict, job_id: str) -> dict: - regions = state.get("edge_regions_by_frame", {}) - serialized = { - str(seq): serialize_dataclass_list(bl) for seq, bl in regions.items() - } - return {"edge_regions_by_frame": serialized} - - -def _deser_regions(data: dict, job_id: str) -> dict: - regions = {} - for seq_str, box_dicts in data.get("edge_regions_by_frame", {}).items(): - regions[int(seq_str)] = [deserialize_bounding_box(b) for b in box_dicts] - return {"edge_regions_by_frame": regions} - - -def register(): - edge_detection = StageDefinition( - name="detect_edges", - label="Edge Detection", - description="Canny + HoughLinesP — find horizontal line pairs (hoarding boundaries)", - category="cv_analysis", - io=StageIO( - reads=["filtered_frames"], - writes=["edge_regions_by_frame"], - ), - config_fields=[ - StageConfigField("enabled", "bool", True, "Enable region analysis"), - StageConfigField("edge_canny_low", "int", 50, "Canny low threshold", min=0, max=255), - StageConfigField("edge_canny_high", "int", 150, "Canny high threshold", min=0, max=255), - StageConfigField("edge_hough_threshold", "int", 80, "Hough accumulator threshold", min=1, max=500), - StageConfigField("edge_hough_min_length", "int", 100, "Min line length (px)", min=10, max=2000), - StageConfigField("edge_hough_max_gap", "int", 10, "Max line gap (px)", min=1, max=100), - StageConfigField("edge_pair_max_distance", "int", 200, "Max distance between line pair (px)", min=10, max=500), - StageConfigField("edge_pair_min_distance", "int", 15, "Min distance between line pair (px)", min=5, max=200), - ], - serialize_fn=_ser_regions, - deserialize_fn=_deser_regions, - ) - register_stage(edge_detection) diff --git a/gpu/models/__init__.py b/gpu/models/__init__.py deleted file mode 100644 index 3dd5327..0000000 --- a/gpu/models/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from . 
import registry -from .yolo import detect -from .ocr import ocr - -__all__ = ["registry", "detect", "ocr"] diff --git a/modelgen/generator/pydantic.py b/modelgen/generator/pydantic.py index 4cef31d..75077fd 100644 --- a/modelgen/generator/pydantic.py +++ b/modelgen/generator/pydantic.py @@ -246,6 +246,7 @@ class PydanticGenerator(BaseGenerator): "", ] + def _generate_enum(self, enum_def: EnumDefinition) -> List[str]: lines = [f"class {enum_def.name}(str, Enum):"] for name, value in enum_def.values: diff --git a/modelgen/generator/sqlmodel.py b/modelgen/generator/sqlmodel.py index 4498ed7..40c574b 100644 --- a/modelgen/generator/sqlmodel.py +++ b/modelgen/generator/sqlmodel.py @@ -99,14 +99,9 @@ def _resolve_field(name, type_hint, default): return "" -def _to_snake_plural(name): - """CamelCase → snake_case_plural for table names.""" - s = re.sub(r"(?<=[a-z])(?=[A-Z])", "_", name).lower() - if s.endswith("y") and not s.endswith("ey"): - return s[:-1] + "ies" - if s.endswith("s"): - return s + "es" - return s + "s" +def _to_snake(name): + """CamelCase → snake_case for table names.""" + return re.sub(r"(?<=[a-z])(?=[A-Z])", "_", name).lower() _HEADER = [ @@ -162,7 +157,7 @@ class SQLModelGenerator(PydanticGenerator): def _build_table(name, docstring, hints, fields, resolve_type_fn): """Build a SQLModel table class from field data.""" - table_name = _to_snake_plural(name) + table_name = _to_snake(name) lines = [ f"class {name}(SQLModel, table=True):", f' """{docstring.strip().split(chr(10))[0]}"""', diff --git a/modelgen/types.py b/modelgen/types.py index 9f7300a..274d14a 100644 --- a/modelgen/types.py +++ b/modelgen/types.py @@ -54,6 +54,7 @@ PYDANTIC_RESOLVERS: dict[Any, Callable[[Any], str]] = { int: lambda _: "int", float: lambda _: "float", bool: lambda _: "bool", + Any: lambda _: "Any", "UUID": lambda _: "UUID", "datetime": lambda _: "datetime", "dict": lambda _: "Dict[str, Any]", diff --git a/requirements.txt b/requirements.txt index bda5b00..f4305cd 100644 --- 
a/requirements.txt +++ b/requirements.txt @@ -36,6 +36,7 @@ sqlmodel>=0.0.14 psycopg2-binary>=2.9.9 # Detection pipeline orchestration +opencv-python-headless>=4.8.0 numpy>=1.24.0 Pillow>=10.0.0 imagehash>=4.3.0 diff --git a/tests/chunker/__init__.py b/tests/chunker/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/chunker/conftest.py b/tests/chunker/conftest.py deleted file mode 100644 index 1ddafe6..0000000 --- a/tests/chunker/conftest.py +++ /dev/null @@ -1,76 +0,0 @@ -""" -Shared fixtures for chunker tests. - -Demonstrates: TDD and unit testing best practices (Interview Topic 8) — fixtures, temp files. -""" - -import os -import tempfile - -import pytest - -from core.chunker.models import Chunk, ChunkResult - - -@pytest.fixture -def temp_file(): - """Create a temporary file with known content, cleaned up after test.""" - files = [] - - def _create(content: bytes = b"x" * 4096): - f = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") - f.write(content) - f.close() - files.append(f.name) - return f.name - - yield _create - - for path in files: - if os.path.exists(path): - os.unlink(path) - - -@pytest.fixture -def sample_chunk(temp_file): - """Create a sample time-based Chunk with valid time range.""" - path = temp_file(b"x" * 1024) - return Chunk( - sequence=0, - start_time=0.0, - end_time=10.0, - source_path=path, - duration=10.0, - ) - - -@pytest.fixture -def make_chunk(temp_file): - """Factory fixture for creating time-based chunks with specific sequence numbers.""" - path = temp_file(b"x" * 1024) - - def _make(sequence: int, duration: float = 10.0) -> Chunk: - start = sequence * duration - return Chunk( - sequence=sequence, - start_time=start, - end_time=start + duration, - source_path=path, - duration=duration, - ) - - return _make - - -@pytest.fixture -def make_result(): - """Factory fixture for creating ChunkResults.""" - - def _make(sequence: int, success: bool = True, processing_time: float = 0.01) -> ChunkResult: - 
return ChunkResult( - sequence=sequence, - success=success, - processing_time=processing_time, - ) - - return _make diff --git a/tests/chunker/test_chunker.py b/tests/chunker/test_chunker.py deleted file mode 100644 index 7f132fa..0000000 --- a/tests/chunker/test_chunker.py +++ /dev/null @@ -1,149 +0,0 @@ -""" -Tests for Chunker — time-based segmentation, chunk counts, sequence numbers, generator behavior. - -Demonstrates: TDD (Interview Topic 8) — parametrized tests, edge cases, mocking. -""" - -from unittest.mock import patch, MagicMock - -import pytest - -from core.chunker import Chunker -from core.chunker.exceptions import ChunkReadError - - -def mock_probe(duration): - """Create a mock probe_file that returns the given duration.""" - result = MagicMock() - result.duration = duration - return result - - -class TestChunker: - @patch("core.chunker.chunker.probe_file") - def test_basic_chunking(self, mock_pf, temp_file): - """File splits into expected number of time-based chunks.""" - path = temp_file(b"x" * 1000) - mock_pf.return_value = mock_probe(30.0) - - chunker = Chunker(path, chunk_duration=10.0) - chunks = list(chunker.chunks()) - - assert len(chunks) == 3 - assert chunks[0].start_time == 0.0 - assert chunks[0].end_time == 10.0 - assert chunks[0].duration == 10.0 - assert chunks[1].start_time == 10.0 - assert chunks[2].start_time == 20.0 - - @patch("core.chunker.chunker.probe_file") - def test_sequence_numbers(self, mock_pf, temp_file): - """Chunks have sequential sequence numbers starting at 0.""" - path = temp_file(b"x" * 100) - mock_pf.return_value = mock_probe(40.0) - - chunker = Chunker(path, chunk_duration=10.0) - chunks = list(chunker.chunks()) - sequences = [c.sequence for c in chunks] - - assert sequences == [0, 1, 2, 3] - - @patch("core.chunker.chunker.probe_file") - def test_time_ranges(self, mock_pf, temp_file): - """Each chunk has correct start_time and end_time.""" - path = temp_file(b"x" * 100) - mock_pf.return_value = mock_probe(25.0) - - 
chunker = Chunker(path, chunk_duration=10.0) - chunks = list(chunker.chunks()) - - assert chunks[0].start_time == 0.0 - assert chunks[0].end_time == 10.0 - assert chunks[1].start_time == 10.0 - assert chunks[1].end_time == 20.0 - assert chunks[2].start_time == 20.0 - assert chunks[2].end_time == 25.0 # last chunk shorter - assert chunks[2].duration == 5.0 - - @patch("core.chunker.chunker.probe_file") - def test_expected_chunks_property(self, mock_pf, temp_file): - """expected_chunks calculates correctly before iteration.""" - path = temp_file(b"x" * 100) - mock_pf.return_value = mock_probe(25.0) - - chunker = Chunker(path, chunk_duration=10.0) - assert chunker.expected_chunks == 3 # ceil(25/10) - - @patch("core.chunker.chunker.probe_file") - def test_source_path_on_chunks(self, mock_pf, temp_file): - """Each chunk carries the source file path.""" - path = temp_file(b"x" * 100) - mock_pf.return_value = mock_probe(10.0) - - chunker = Chunker(path, chunk_duration=10.0) - chunks = list(chunker.chunks()) - - assert all(c.source_path == path for c in chunks) - - def test_file_not_found(self): - """Non-existent file raises ChunkReadError.""" - with pytest.raises(ChunkReadError, match="File not found"): - Chunker("/nonexistent/file.mp4") - - @patch("core.chunker.chunker.probe_file") - def test_invalid_chunk_duration(self, mock_pf, temp_file): - """Zero or negative chunk_duration raises ValueError.""" - path = temp_file(b"x" * 100) - - with pytest.raises(ValueError, match="chunk_duration must be positive"): - Chunker(path, chunk_duration=0) - - with pytest.raises(ValueError, match="chunk_duration must be positive"): - Chunker(path, chunk_duration=-1) - - @patch("core.chunker.chunker.probe_file") - def test_generator_laziness(self, mock_pf, temp_file): - """Chunks are yielded lazily, not pre-loaded.""" - path = temp_file(b"x" * 100) - mock_pf.return_value = mock_probe(30.0) - - chunker = Chunker(path, chunk_duration=10.0) - gen = chunker.chunks() - first = next(gen) - assert 
first.sequence == 0 - # Generator is not exhausted — remaining chunks still pending - - @pytest.mark.parametrize("duration,chunk_dur,expected", [ - (10.0, 10.0, 1), - (10.1, 10.0, 2), - (1.0, 1.0, 1), - (100.0, 1.0, 100), - (5.0, 100.0, 1), - ]) - @patch("core.chunker.chunker.probe_file") - def test_expected_chunks_parametrized(self, mock_pf, temp_file, duration, chunk_dur, expected): - """Parametrized: various duration/chunk_duration combos.""" - path = temp_file(b"x" * 100) - mock_pf.return_value = mock_probe(duration) - chunker = Chunker(path, chunk_duration=chunk_dur) - assert chunker.expected_chunks == expected - - @patch("core.chunker.chunker.probe_file") - def test_exact_multiple(self, mock_pf, temp_file): - """Duration exactly divisible by chunk_duration.""" - path = temp_file(b"x" * 100) - mock_pf.return_value = mock_probe(30.0) - - chunker = Chunker(path, chunk_duration=10.0) - chunks = list(chunker.chunks()) - assert len(chunks) == 3 - assert all(c.duration == 10.0 for c in chunks) - - @patch("core.chunker.chunker.probe_file") - def test_probe_failure(self, mock_pf, temp_file): - """Probe failure raises ChunkReadError.""" - path = temp_file(b"x" * 100) - mock_pf.side_effect = Exception("ffprobe failed") - - with pytest.raises(ChunkReadError, match="Failed to probe"): - Chunker(path, chunk_duration=10.0) diff --git a/tests/chunker/test_collector.py b/tests/chunker/test_collector.py deleted file mode 100644 index 8dc1e64..0000000 --- a/tests/chunker/test_collector.py +++ /dev/null @@ -1,103 +0,0 @@ -""" -Tests for ResultCollector — ordered reassembly, out-of-order buffering, duplicates. - -Demonstrates: TDD (Interview Topic 8) — testing algorithms (heapq reassembly). 
-""" - -import pytest - -from core.chunker.collector import ResultCollector -from core.chunker.exceptions import ReassemblyError - - -class TestResultCollector: - def test_in_order_emission(self, make_result): - """Results arriving in order are emitted immediately.""" - collector = ResultCollector(total_chunks=3) - - emitted = collector.add(make_result(0)) - assert len(emitted) == 1 - assert emitted[0].sequence == 0 - - emitted = collector.add(make_result(1)) - assert len(emitted) == 1 - - emitted = collector.add(make_result(2)) - assert len(emitted) == 1 - - assert collector.is_complete - - def test_out_of_order_buffering(self, make_result): - """Out-of-order results are buffered until gaps fill.""" - collector = ResultCollector(total_chunks=3) - - # Arrive: 2, 0, 1 - emitted = collector.add(make_result(2)) - assert len(emitted) == 0 - assert collector.buffered_count == 1 - - emitted = collector.add(make_result(0)) - assert len(emitted) == 1 # Only 0 emitted, 1 still missing - - emitted = collector.add(make_result(1)) - assert len(emitted) == 2 # 1 and 2 now emittable - assert collector.is_complete - - def test_reverse_order(self, make_result): - """All results arrive in reverse — only last add emits everything.""" - collector = ResultCollector(total_chunks=4) - - for seq in [3, 2, 1]: - emitted = collector.add(make_result(seq)) - assert len(emitted) == 0 - - emitted = collector.add(make_result(0)) - assert len(emitted) == 4 - assert collector.is_complete - - def test_duplicate_raises(self, make_result): - """Duplicate sequence number raises ReassemblyError.""" - collector = ResultCollector(total_chunks=3) - collector.add(make_result(0)) - - with pytest.raises(ReassemblyError, match="Duplicate"): - collector.add(make_result(0)) - - def test_emitted_count(self, make_result): - """emitted_count tracks correctly.""" - collector = ResultCollector(total_chunks=3) - assert collector.emitted_count == 0 - - collector.add(make_result(0)) - assert collector.emitted_count == 
1 - - collector.add(make_result(2)) # buffered - assert collector.emitted_count == 1 - - collector.add(make_result(1)) # releases 1 and 2 - assert collector.emitted_count == 3 - - def test_get_ordered_results(self, make_result): - """get_ordered_results returns all emitted results in order.""" - collector = ResultCollector(total_chunks=3) - collector.add(make_result(2)) - collector.add(make_result(0)) - collector.add(make_result(1)) - - ordered = collector.get_ordered_results() - assert [r.sequence for r in ordered] == [0, 1, 2] - - def test_avg_processing_time(self, make_result): - """Average processing time from sliding window.""" - collector = ResultCollector(total_chunks=2) - collector.add(make_result(0, processing_time=0.1)) - collector.add(make_result(1, processing_time=0.3)) - - assert abs(collector.avg_processing_time - 0.2) < 0.001 - - def test_not_complete_when_partial(self, make_result): - """is_complete is False until all chunks emitted.""" - collector = ResultCollector(total_chunks=3) - collector.add(make_result(0)) - collector.add(make_result(1)) - assert not collector.is_complete diff --git a/tests/chunker/test_exceptions.py b/tests/chunker/test_exceptions.py deleted file mode 100644 index 91ff59e..0000000 --- a/tests/chunker/test_exceptions.py +++ /dev/null @@ -1,69 +0,0 @@ -""" -Tests for exception hierarchy — catch patterns, attributes. - -Demonstrates: TDD (Interview Topic 8) — testing exception design. 
-""" - -import pytest - -from core.chunker.exceptions import ( - ChunkChecksumError, - ChunkError, - ChunkReadError, - PipelineError, - ProcessingError, - ProcessorFailureError, - ProcessorTimeoutError, - ReassemblyError, -) - - -class TestExceptionHierarchy: - """Verify the exception class hierarchy and catch patterns.""" - - def test_pipeline_error_is_base(self): - """All chunker exceptions inherit from PipelineError.""" - assert issubclass(ChunkError, PipelineError) - assert issubclass(ProcessingError, PipelineError) - assert issubclass(ReassemblyError, PipelineError) - - def test_chunk_error_subtypes(self): - """ChunkReadError and ChunkChecksumError are ChunkErrors.""" - assert issubclass(ChunkReadError, ChunkError) - assert issubclass(ChunkChecksumError, ChunkError) - - def test_processing_error_subtypes(self): - """ProcessorTimeoutError and ProcessorFailureError are ProcessingErrors.""" - assert issubclass(ProcessorTimeoutError, ProcessingError) - assert issubclass(ProcessorFailureError, ProcessingError) - - def test_catch_pipeline_error_catches_all(self): - """Catching PipelineError catches any subtype.""" - with pytest.raises(PipelineError): - raise ChunkReadError("test") - - with pytest.raises(PipelineError): - raise ReassemblyError("test") - - def test_checksum_error_attributes(self): - """ChunkChecksumError carries sequence, expected, actual.""" - err = ChunkChecksumError(sequence=5, expected="aaa", actual="bbb") - assert err.sequence == 5 - assert err.expected == "aaa" - assert err.actual == "bbb" - assert "5" in str(err) - - def test_timeout_error_attributes(self): - """ProcessorTimeoutError carries sequence and timeout.""" - err = ProcessorTimeoutError(sequence=3, timeout=30.0) - assert err.sequence == 3 - assert err.timeout == 30.0 - - def test_failure_error_attributes(self): - """ProcessorFailureError carries sequence, retries, original error.""" - original = RuntimeError("boom") - err = ProcessorFailureError(sequence=1, retries=3, 
original_error=original) - assert err.sequence == 1 - assert err.retries == 3 - assert err.original_error is original - assert "boom" in str(err) diff --git a/tests/chunker/test_pipeline.py b/tests/chunker/test_pipeline.py deleted file mode 100644 index e12e2d9..0000000 --- a/tests/chunker/test_pipeline.py +++ /dev/null @@ -1,144 +0,0 @@ -""" -Tests for Pipeline — end-to-end orchestration, stats, error handling. - -Demonstrates: TDD (Interview Topic 8) — integration testing with mocked FFmpeg probe. -""" - -from unittest.mock import MagicMock, patch - -import pytest - -from core.chunker import Pipeline -from core.chunker.exceptions import PipelineError - - -def mock_probe(duration): - """Create a mock ProbeResult with the given duration.""" - result = MagicMock() - result.duration = duration - return result - - -class TestPipeline: - @patch("core.chunker.chunker.probe_file") - def test_end_to_end(self, mock_pf, temp_file): - """Full pipeline processes a file successfully.""" - path = temp_file(b"x" * 4096) - mock_pf.return_value = mock_probe(40.0) - - result = Pipeline( - source=path, - chunk_duration=10.0, - num_workers=2, - processor_type="checksum", - ).run() - - assert result.total_chunks == 4 - assert result.processed == 4 - assert result.failed == 0 - assert result.elapsed_time > 0 - assert result.chunks_in_order is True - - @patch("core.chunker.chunker.probe_file") - def test_throughput_calculated(self, mock_pf, temp_file): - """Pipeline calculates throughput.""" - path = temp_file(b"x" * 10000) - mock_pf.return_value = mock_probe(30.0) - - result = Pipeline(source=path, chunk_duration=10.0, num_workers=2).run() - - assert result.throughput_mbps > 0 - - @patch("core.chunker.chunker.probe_file") - def test_worker_stats(self, mock_pf, temp_file): - """Pipeline reports per-worker stats.""" - path = temp_file(b"x" * 4000) - mock_pf.return_value = mock_probe(40.0) - - result = Pipeline( - source=path, chunk_duration=10.0, num_workers=2 - ).run() - - assert 
len(result.worker_stats) == 2 - for worker_id, stats in result.worker_stats.items(): - assert "processed" in stats - assert "errors" in stats - - def test_nonexistent_file(self): - """Non-existent file raises PipelineError.""" - with pytest.raises(PipelineError): - Pipeline(source="/nonexistent/file.mp4").run() - - @patch("core.chunker.chunker.probe_file") - def test_event_callback(self, mock_pf, temp_file): - """Pipeline emits events through callback.""" - path = temp_file(b"x" * 2048) - mock_pf.return_value = mock_probe(20.0) - events = [] - - def capture(event_type, data): - events.append(event_type) - - Pipeline( - source=path, - chunk_duration=10.0, - num_workers=1, - event_callback=capture, - ).run() - - assert "pipeline_start" in events - assert "pipeline_complete" in events - assert "chunk_queued" in events - - @patch("core.chunker.chunker.probe_file") - def test_simulated_decode_processor(self, mock_pf, temp_file): - """Pipeline works with simulated_decode processor.""" - path = temp_file(b"x" * 2048) - mock_pf.return_value = mock_probe(20.0) - - result = Pipeline( - source=path, - chunk_duration=10.0, - num_workers=2, - processor_type="simulated_decode", - ).run() - - assert result.total_chunks == 2 - assert result.failed == 0 - - @patch("core.chunker.chunker.probe_file") - def test_single_chunk_file(self, mock_pf, temp_file): - """Duration shorter than chunk_duration produces one chunk.""" - path = temp_file(b"x" * 100) - mock_pf.return_value = mock_probe(5.0) - - result = Pipeline(source=path, chunk_duration=10.0).run() - - assert result.total_chunks == 1 - assert result.processed == 1 - - @patch("core.chunker.chunker.probe_file") - def test_retries_tracked(self, mock_pf, temp_file): - """Pipeline result tracks total retries.""" - path = temp_file(b"x" * 2048) - mock_pf.return_value = mock_probe(20.0) - - result = Pipeline(source=path, chunk_duration=10.0).run() - - assert result.retries >= 0 # Might be 0 if no failures - - 
@patch("core.chunker.chunker.probe_file") - def test_output_dir_and_chunk_files(self, mock_pf, temp_file): - """Pipeline tracks output_dir and chunk_files when set.""" - path = temp_file(b"x" * 1024) - mock_pf.return_value = mock_probe(10.0) - - result = Pipeline( - source=path, - chunk_duration=10.0, - processor_type="checksum", - ).run() - - # No output_dir set, so chunk_files should be empty - assert result.output_dir is None - assert result.chunk_files == [] diff --git a/tests/chunker/test_processor.py b/tests/chunker/test_processor.py deleted file mode 100644 index 68980f2..0000000 --- a/tests/chunker/test_processor.py +++ /dev/null @@ -1,98 +0,0 @@ -""" -Tests for Processor implementations — ChecksumProcessor, SimulatedDecodeProcessor, CompositeProcessor. - -Demonstrates: TDD (Interview Topic 8) — ABC contract, parametrized tests. -""" - -import pytest - -from core.chunker.exceptions import ChunkChecksumError -from core.chunker.models import Chunk -from core.chunker.processor import ( - ChecksumProcessor, - CompositeProcessor, - Processor, - SimulatedDecodeProcessor, -) - - -class TestChecksumProcessor: - def test_valid_time_range(self, sample_chunk): - """Valid time range passes.""" - proc = ChecksumProcessor() - result = proc.process(sample_chunk) - assert result.success is True - assert result.checksum_valid is True - assert result.processing_time > 0 - - def test_invalid_time_range(self): - """Invalid time range raises ChunkChecksumError.""" - chunk = Chunk( - sequence=0, - start_time=10.0, - end_time=10.0, # zero duration - source_path="/fake.mp4", - duration=0.0, - ) - proc = ChecksumProcessor() - with pytest.raises(ChunkChecksumError) as exc_info: - proc.process(chunk) - assert exc_info.value.sequence == 0 - - def test_sequence_preserved(self, make_chunk): - """Result carries the chunk's sequence number.""" - chunk = make_chunk(42) - proc = ChecksumProcessor() - result = proc.process(chunk) - assert result.sequence == 42 - - -class 
TestSimulatedDecodeProcessor: - def test_processes_successfully(self, sample_chunk): - """Simulated decode always succeeds.""" - proc = SimulatedDecodeProcessor(ms_per_second=1.0) - result = proc.process(sample_chunk) - assert result.success is True - assert result.processing_time > 0 - - def test_time_proportional_to_duration(self): - """Longer chunks take longer.""" - short = Chunk(0, 0.0, 1.0, "/fake.mp4", 1.0) - long = Chunk(1, 0.0, 10.0, "/fake.mp4", 10.0) - - proc = SimulatedDecodeProcessor(ms_per_second=50.0) - r_short = proc.process(short) - r_long = proc.process(long) - - assert r_long.processing_time > r_short.processing_time - - -class TestCompositeProcessor: - def test_chains_processors(self, sample_chunk): - """Composite runs all processors in sequence.""" - proc = CompositeProcessor([ - ChecksumProcessor(), - SimulatedDecodeProcessor(ms_per_second=1.0), - ]) - result = proc.process(sample_chunk) - assert result.success is True - - def test_stops_on_failure(self): - """If first processor raises, composite propagates the error.""" - bad_chunk = Chunk(0, 10.0, 10.0, "/fake.mp4", 0.0) # invalid range - proc = CompositeProcessor([ - ChecksumProcessor(), - SimulatedDecodeProcessor(ms_per_second=1.0), - ]) - with pytest.raises(ChunkChecksumError): - proc.process(bad_chunk) - - def test_requires_at_least_one(self): - """Empty processor list raises ValueError.""" - with pytest.raises(ValueError, match="at least one"): - CompositeProcessor([]) - - def test_is_processor(self): - """CompositeProcessor is a Processor.""" - proc = CompositeProcessor([ChecksumProcessor()]) - assert isinstance(proc, Processor) diff --git a/tests/chunker/test_queue.py b/tests/chunker/test_queue.py deleted file mode 100644 index 7ebee6b..0000000 --- a/tests/chunker/test_queue.py +++ /dev/null @@ -1,115 +0,0 @@ -""" -Tests for ChunkQueue — backpressure, sentinel shutdown, timeout behavior. - -Demonstrates: TDD (Interview Topic 8) — concurrency testing. 
-""" - -import queue -import threading - -import pytest - -from core.chunker.queue import ChunkQueue - - -class TestChunkQueue: - def test_put_and_get(self, make_chunk): - """Basic put/get cycle.""" - q = ChunkQueue(maxsize=5) - chunk = make_chunk(0) - q.put(chunk) - result = q.get(timeout=1.0) - assert result.sequence == 0 - - def test_fifo_order(self, make_chunk): - """Items come out in FIFO order.""" - q = ChunkQueue(maxsize=5) - for i in range(3): - q.put(make_chunk(i)) - - for i in range(3): - assert q.get(timeout=1.0).sequence == i - - def test_close_returns_none(self, make_chunk): - """After close(), get() returns None (sentinel).""" - q = ChunkQueue(maxsize=5) - q.put(make_chunk(0)) - q.close() - - result = q.get(timeout=1.0) - assert result.sequence == 0 - - # Next get should hit sentinel - result = q.get(timeout=1.0) - assert result is None - - def test_close_propagates_to_multiple_consumers(self, make_chunk): - """Sentinel propagates: multiple consumers all get None.""" - q = ChunkQueue(maxsize=5) - q.close() - - # Multiple consumers should all see None - assert q.get(timeout=1.0) is None - assert q.get(timeout=1.0) is None - - def test_is_closed(self): - """is_closed reflects state.""" - q = ChunkQueue() - assert not q.is_closed - q.close() - assert q.is_closed - - def test_qsize(self, make_chunk): - """qsize tracks approximate queue depth.""" - q = ChunkQueue(maxsize=10) - assert q.qsize() == 0 - - q.put(make_chunk(0)) - q.put(make_chunk(1)) - assert q.qsize() == 2 - - q.get(timeout=1.0) - assert q.qsize() == 1 - - def test_backpressure_blocks(self, make_chunk): - """Put blocks when queue is full (backpressure).""" - q = ChunkQueue(maxsize=2) - q.put(make_chunk(0)) - q.put(make_chunk(1)) - - # Queue is full — put with short timeout should raise - with pytest.raises(queue.Full): - q.put(make_chunk(2), timeout=0.05) - - def test_get_timeout(self): - """Get on empty queue with timeout raises Empty.""" - q = ChunkQueue(maxsize=5) - - with 
pytest.raises(queue.Empty): - q.get(timeout=0.05) - - def test_concurrent_put_get(self, make_chunk): - """Producer/consumer threads work correctly.""" - q = ChunkQueue(maxsize=3) - results = [] - - def producer(): - for i in range(10): - q.put(make_chunk(i)) - q.close() - - def consumer(): - while True: - item = q.get(timeout=2.0) - if item is None: - break - results.append(item.sequence) - - t1 = threading.Thread(target=producer) - t2 = threading.Thread(target=consumer) - t1.start() - t2.start() - t1.join(timeout=5.0) - t2.join(timeout=5.0) - - assert sorted(results) == list(range(10)) diff --git a/tests/chunker/test_worker.py b/tests/chunker/test_worker.py deleted file mode 100644 index 12af394..0000000 --- a/tests/chunker/test_worker.py +++ /dev/null @@ -1,127 +0,0 @@ -""" -Tests for Worker — processing, retry with backoff, error handling. - -Demonstrates: TDD (Interview Topic 8) — mocking processors, testing retry logic. -""" - -from unittest.mock import MagicMock - -import pytest - -from core.chunker.models import Chunk, ChunkResult -from core.chunker.processor import Processor -from core.chunker.queue import ChunkQueue -from core.chunker.worker import Worker - - -class FailNTimesProcessor(Processor): - """Test processor that fails N times then succeeds.""" - - def __init__(self, fail_count: int): - self.fail_count = fail_count - self.call_count = 0 - - def process(self, chunk: Chunk) -> ChunkResult: - self.call_count += 1 - if self.call_count <= self.fail_count: - raise RuntimeError(f"Simulated failure #{self.call_count}") - return ChunkResult( - sequence=chunk.sequence, - success=True, - processing_time=0.001, - ) - - -class AlwaysFailProcessor(Processor): - """Test processor that always fails.""" - - def process(self, chunk: Chunk) -> ChunkResult: - raise RuntimeError("Always fails") - - -class TestWorker: - def test_processes_chunks(self, make_chunk): - """Worker processes all chunks from queue.""" - q = ChunkQueue(maxsize=5) - for i in range(3): - 
q.put(make_chunk(i)) - q.close() - - from core.chunker.processor import ChecksumProcessor - worker = Worker("w-0", q, ChecksumProcessor(), max_retries=0) - results = worker.run() - - assert len(results) == 3 - assert all(r.success for r in results) - - def test_retry_on_failure(self, make_chunk): - """Worker retries on processor failure.""" - q = ChunkQueue(maxsize=5) - q.put(make_chunk(0)) - q.close() - - proc = FailNTimesProcessor(fail_count=2) - worker = Worker("w-0", q, proc, max_retries=3) - results = worker.run() - - assert len(results) == 1 - assert results[0].success is True - assert results[0].retries == 2 - assert proc.call_count == 3 # 2 failures + 1 success - - def test_max_retries_exceeded(self, make_chunk): - """Worker gives up after max retries.""" - q = ChunkQueue(maxsize=5) - q.put(make_chunk(0)) - q.close() - - worker = Worker("w-0", q, AlwaysFailProcessor(), max_retries=2) - results = worker.run() - - assert len(results) == 1 - assert results[0].success is False - assert results[0].error is not None - assert worker.error_count == 1 - - def test_worker_id_on_results(self, make_chunk): - """Worker stamps its ID on results.""" - q = ChunkQueue(maxsize=5) - q.put(make_chunk(0)) - q.close() - - from core.chunker.processor import ChecksumProcessor - worker = Worker("worker-7", q, ChecksumProcessor()) - results = worker.run() - - assert results[0].worker_id == "worker-7" - - def test_event_callback(self, make_chunk): - """Worker emits events via callback.""" - q = ChunkQueue(maxsize=5) - q.put(make_chunk(0)) - q.close() - - events = [] - callback = MagicMock(side_effect=lambda t, d: events.append((t, d))) - - from core.chunker.processor import ChecksumProcessor - worker = Worker("w-0", q, ChecksumProcessor(), event_callback=callback) - worker.run() - - event_types = [e[0] for e in events] - assert "worker_status" in event_types - assert "chunk_processing" in event_types - assert "chunk_done" in event_types - - def test_processed_count(self, make_chunk): 
- """Worker tracks processed count.""" - q = ChunkQueue(maxsize=10) - for i in range(5): - q.put(make_chunk(i)) - q.close() - - from core.chunker.processor import ChecksumProcessor - worker = Worker("w-0", q, ChecksumProcessor()) - worker.run() - - assert worker.processed_count == 5 diff --git a/tests/detect/manual/run_extract_filter.py b/tests/detect/manual/run_extract_filter.py index 4b967eb..13745c1 100644 --- a/tests/detect/manual/run_extract_filter.py +++ b/tests/detect/manual/run_extract_filter.py @@ -24,9 +24,9 @@ logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(m sys.path.insert(0, ".") -from detect.profiles.soccer import SoccerBroadcastProfile -from detect.stages.frame_extractor import extract_frames -from detect.stages.scene_filter import scene_filter +from core.detect.profile import get_profile +from core.detect.stages.frame_extractor import extract_frames +from core.detect.stages.scene_filter import scene_filter logger = logging.getLogger(__name__) diff --git a/tests/detect/manual/run_graph.py b/tests/detect/manual/run_graph.py index 267445c..b4f5555 100644 --- a/tests/detect/manual/run_graph.py +++ b/tests/detect/manual/run_graph.py @@ -24,8 +24,8 @@ logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(m sys.path.insert(0, ".") -from detect.graph import get_pipeline -from detect.state import DetectState +from core.detect.graph import get_pipeline +from core.detect.state import DetectState logger = logging.getLogger(__name__) diff --git a/tests/detect/manual/run_region_analysis.py b/tests/detect/manual/run_region_analysis.py index f0fe663..ac6c123 100644 --- a/tests/detect/manual/run_region_analysis.py +++ b/tests/detect/manual/run_region_analysis.py @@ -39,13 +39,13 @@ sys.path.insert(0, ".") from langgraph.graph import END, StateGraph -from detect import emit -from detect.models import PipelineStats -from detect.profiles.soccer import SoccerBroadcastProfile -from detect.stages.frame_extractor import 
extract_frames -from detect.stages.scene_filter import scene_filter -from detect.stages.edge_detector import detect_edge_regions -from detect.state import DetectState +from core.detect import emit +from core.detect.models import PipelineStats +from core.detect.profile import get_profile +from core.detect.stages.frame_extractor import extract_frames +from core.detect.stages.scene_filter import scene_filter +from core.detect.stages.edge_detector import detect_edge_regions +from core.detect.state import DetectState logger = logging.getLogger(__name__) @@ -166,7 +166,7 @@ def main(): # --- Parameter sensitivity --- logger.info("=== Parameter sensitivity (local debug) ===") - from detect.stages.edge_detector import _load_cv_edges + from core.detect.stages.edge_detector import _load_cv_edges edges_mod = _load_cv_edges() filtered = result.get("filtered_frames", []) diff --git a/tests/detect/manual/seed_scenario.py b/tests/detect/manual/seed_scenario.py index 81bf80d..d97dbfe 100644 --- a/tests/detect/manual/seed_scenario.py +++ b/tests/detect/manual/seed_scenario.py @@ -58,7 +58,7 @@ def extract_frames_ffmpeg(video_path: str, fps: float, max_frames: int): import numpy as np from PIL import Image - from detect.models import Frame + from core.detect.models import Frame tmpdir = tempfile.mkdtemp(prefix="scenario_") pattern = os.path.join(tmpdir, "frame_%04d.jpg") @@ -111,7 +111,7 @@ def main(): logger.info("Extracted %d frames", len(frames)) # Create timeline + branch + checkpoint - from detect.checkpoint.storage import create_timeline, save_stage_output + from core.detect.checkpoint.storage import create_timeline, save_stage_output timeline_id, branch_id = create_timeline( source_video=video_path, diff --git a/tests/detect/manual/test_cloud_provider.py b/tests/detect/manual/test_cloud_provider.py index 1a7936a..39adc41 100644 --- a/tests/detect/manual/test_cloud_provider.py +++ b/tests/detect/manual/test_cloud_provider.py @@ -58,7 +58,7 @@ def make_brand_image(text: str, 
width: int = 300, height: int = 100) -> str: def main(): - from detect.providers import get_provider, has_api_key, PROVIDERS + from core.detect.providers import get_provider, has_api_key, PROVIDERS provider_name = os.environ.get("CLOUD_LLM_PROVIDER", "groq") logger.info("Provider: %s", provider_name) diff --git a/tests/detect/manual/test_frame_extractor_e2e.py b/tests/detect/manual/test_frame_extractor_e2e.py index 0b1e23b..88867b4 100644 --- a/tests/detect/manual/test_frame_extractor_e2e.py +++ b/tests/detect/manual/test_frame_extractor_e2e.py @@ -13,8 +13,8 @@ import sys sys.path.insert(0, ".") -from detect.profiles.soccer import SoccerBroadcastProfile -from detect.stages.frame_extractor import extract_frames +from core.detect.profile import get_profile +from core.detect.stages.frame_extractor import extract_frames logger = logging.getLogger(__name__) diff --git a/tests/detect/manual/test_ocr_e2e.py b/tests/detect/manual/test_ocr_e2e.py index d3c0e90..2e0d25a 100644 --- a/tests/detect/manual/test_ocr_e2e.py +++ b/tests/detect/manual/test_ocr_e2e.py @@ -86,9 +86,9 @@ def test_ocr_stage_remote(url: str): logger.info("--- OCR stage (remote mode) ---") sys.path.insert(0, ".") - from detect.models import BoundingBox, Frame - from detect.profiles.base import OCRConfig - from detect.stages.ocr_stage import run_ocr + from core.detect.models import BoundingBox, Frame + from core.detect.stages.models import OCRConfig + from core.detect.stages.ocr_stage import run_ocr # Create a frame with text baked in image = make_text_image("EMIRATES") diff --git a/tests/detect/manual/test_replay.py b/tests/detect/manual/test_replay.py index e319882..02a75f9 100644 --- a/tests/detect/manual/test_replay.py +++ b/tests/detect/manual/test_replay.py @@ -48,10 +48,10 @@ def main(): # Override Redis to localhost (ctrl/.env has k8s hostname) os.environ["REDIS_URL"] = f"redis://localhost:{args.port}/0" - from detect.graph import get_pipeline, NODES - from detect.checkpoint import 
list_checkpoints - from detect.checkpoint import replay_from - from detect.state import DetectState + from core.detect.graph import get_pipeline, NODES + from core.detect.checkpoint import list_checkpoints + from core.detect.checkpoint import replay_from + from core.detect.state import DetectState VIDEO = "media/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e/chunk_0000.mp4" diff --git a/tests/detect/test_aggregator.py b/tests/detect/test_aggregator.py index 7d268fa..fb3cf78 100644 --- a/tests/detect/test_aggregator.py +++ b/tests/detect/test_aggregator.py @@ -2,8 +2,8 @@ import pytest -from detect.models import BoundingBox, BrandDetection, PipelineStats -from detect.stages.aggregator import compile_report, _merge_contiguous +from core.detect.models import BoundingBox, BrandDetection, PipelineStats +from core.detect.stages.aggregator import compile_report, _merge_contiguous def _make_detection(brand: str, timestamp: float, duration: float = 0.5, @@ -43,7 +43,7 @@ def test_merge_empty(): def test_compile_report(monkeypatch): events = [] - monkeypatch.setattr("detect.emit.push_detect_event", + monkeypatch.setattr("core.detect.emit.push_detect_event", lambda job_id, etype, data: events.append((etype, data))) dets = [ diff --git a/tests/detect/test_brand_resolver.py b/tests/detect/test_brand_resolver.py index 72add92..8fbd745 100644 --- a/tests/detect/test_brand_resolver.py +++ b/tests/detect/test_brand_resolver.py @@ -3,9 +3,9 @@ import numpy as np import pytest -from detect.models import BoundingBox, Frame, TextCandidate -from detect.profiles.base import ResolverConfig -from detect.stages.brand_resolver import resolve_brands, _normalize, _match_session +from core.detect.models import BoundingBox, Frame, TextCandidate +from core.detect.stages.models import ResolverConfig +from core.detect.stages.brand_resolver import resolve_brands, _normalize, _match_session CONFIG = ResolverConfig(fuzzy_threshold=75) @@ -28,7 +28,7 @@ def test_session_match(): def 
test_resolve_with_session(monkeypatch): events = [] - monkeypatch.setattr("detect.emit.push_detect_event", + monkeypatch.setattr("core.detect.emit.push_detect_event", lambda job_id, etype, data: events.append((etype, data))) session = {"nike": "Nike", "emirates": "Emirates"} @@ -46,7 +46,7 @@ def test_resolve_with_session(monkeypatch): def test_resolve_unresolved_without_db(monkeypatch): events = [] - monkeypatch.setattr("detect.emit.push_detect_event", + monkeypatch.setattr("core.detect.emit.push_detect_event", lambda job_id, etype, data: events.append((etype, data))) candidates = [_make_candidate("random garbage text")] @@ -61,7 +61,7 @@ def test_resolve_unresolved_without_db(monkeypatch): def test_resolve_empty(monkeypatch): events = [] - monkeypatch.setattr("detect.emit.push_detect_event", + monkeypatch.setattr("core.detect.emit.push_detect_event", lambda job_id, etype, data: events.append((etype, data))) matched, unresolved = resolve_brands([], CONFIG, session_brands={}) @@ -73,7 +73,7 @@ def test_resolve_empty(monkeypatch): def test_resolve_builds_session_during_run(monkeypatch): """Session brands accumulate during a single run — second candidate benefits.""" events = [] - monkeypatch.setattr("detect.emit.push_detect_event", + monkeypatch.setattr("core.detect.emit.push_detect_event", lambda job_id, etype, data: events.append((etype, data))) session = {"nike": "Nike"} @@ -93,7 +93,7 @@ def test_resolve_builds_session_during_run(monkeypatch): def test_events_emitted(monkeypatch): events = [] - monkeypatch.setattr("detect.emit.push_detect_event", + monkeypatch.setattr("core.detect.emit.push_detect_event", lambda job_id, etype, data: events.append((etype, data))) session = {"nike": "Nike"} diff --git a/tests/detect/test_checkpoint.py b/tests/detect/test_checkpoint.py index c530cb3..f59070e 100644 --- a/tests/detect/test_checkpoint.py +++ b/tests/detect/test_checkpoint.py @@ -5,7 +5,7 @@ import json import numpy as np import pytest -from detect.models import 
BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate +from core.detect.models import BoundingBox, BrandDetection, Frame, PipelineStats, TextCandidate from core.schema.serializers._common import safe_construct from core.schema.serializers.pipeline import ( serialize_frame_meta, @@ -163,34 +163,39 @@ def test_all_serialized_is_json_compatible(): assert roundtrip["frame_meta"]["sequence"] == frame.sequence -# --- OverrideProfile --- +# --- Config overrides (dict merging, replaces OverrideProfile) --- -def test_override_profile_region_analysis(): - """OverrideProfile must patch region_analysis_config with overrides.""" - from detect.checkpoint.replay import OverrideProfile - from detect.profiles.soccer import SoccerBroadcastProfile - from detect.profiles.base import RegionAnalysisConfig +def test_config_override_region_analysis(): + """Config overrides must patch stage config values.""" + from core.detect.profile import get_profile, get_stage_config + from core.detect.stages.models import RegionAnalysisConfig - base = SoccerBroadcastProfile() - original = base.region_analysis_config() + profile = get_profile("soccer_broadcast") + original = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges")) - overrides = {"region_analysis": {"edge_canny_low": 25, "edge_canny_high": 200}} - wrapped = OverrideProfile(base, overrides) - patched = wrapped.region_analysis_config() + overrides = {"detect_edges": {"edge_canny_low": 25, "edge_canny_high": 200}} + merged_configs = {**profile["configs"]} + merged_configs["detect_edges"] = {**merged_configs["detect_edges"], **overrides["detect_edges"]} + patched_profile = {**profile, "configs": merged_configs} + + patched = RegionAnalysisConfig(**get_stage_config(patched_profile, "detect_edges")) - assert isinstance(patched, RegionAnalysisConfig) assert patched.edge_canny_low == 25 assert patched.edge_canny_high == 200 - # Unmodified fields keep their defaults assert patched.edge_hough_threshold == 
original.edge_hough_threshold -def test_override_profile_passthrough(): - """OverrideProfile without region_analysis key passes through unchanged.""" - from detect.checkpoint.replay import OverrideProfile - from detect.profiles.soccer import SoccerBroadcastProfile +def test_config_override_passthrough(): + """Overrides for other stages don't affect unrelated stages.""" + from core.detect.profile import get_profile, get_stage_config + from core.detect.stages.models import RegionAnalysisConfig - base = SoccerBroadcastProfile() - wrapped = OverrideProfile(base, {"ocr": {"min_confidence": 0.1}}) - config = wrapped.region_analysis_config() - assert config.edge_canny_low == base.region_analysis_config().edge_canny_low + profile = get_profile("soccer_broadcast") + original = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges")) + + overrides = {"run_ocr": {"min_confidence": 0.1}} + merged_configs = {**profile["configs"], **overrides} + patched_profile = {**profile, "configs": merged_configs} + + patched = RegionAnalysisConfig(**get_stage_config(patched_profile, "detect_edges")) + assert patched.edge_canny_low == original.edge_canny_low diff --git a/tests/detect/test_config_endpoint.py b/tests/detect/test_config_endpoint.py index 291d92f..c21337b 100644 --- a/tests/detect/test_config_endpoint.py +++ b/tests/detect/test_config_endpoint.py @@ -1,6 +1,6 @@ """Tests for the config endpoint and stage palette.""" -from detect.stages import list_stages, get_palette +from core.detect.stages import list_stages, get_palette def test_stage_palette_has_config_fields(): diff --git a/tests/detect/test_edge_sensitivity.py b/tests/detect/test_edge_sensitivity.py index dac30d0..13634c5 100644 --- a/tests/detect/test_edge_sensitivity.py +++ b/tests/detect/test_edge_sensitivity.py @@ -15,7 +15,7 @@ import pytest # Load edges module directly _spec = importlib.util.spec_from_file_location( - "cv_edges", Path("gpu/models/cv/edges.py"), + "cv_edges", 
Path("core/gpu/models/cv/edges.py"), ) _edges_mod = importlib.util.module_from_spec(_spec) _spec.loader.exec_module(_edges_mod) diff --git a/tests/detect/test_frame_extractor.py b/tests/detect/test_frame_extractor.py index 71ef37d..f23a48e 100644 --- a/tests/detect/test_frame_extractor.py +++ b/tests/detect/test_frame_extractor.py @@ -5,8 +5,8 @@ from pathlib import Path import pytest -from detect.profiles.base import FrameExtractionConfig -from detect.stages.frame_extractor import extract_frames +from core.detect.stages.models import FrameExtractionConfig +from core.detect.stages.frame_extractor import extract_frames SAMPLE_DIR = Path("media/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e") @@ -61,7 +61,7 @@ def test_extract_frames_with_events(monkeypatch): def mock_push(job_id, event_type, data): events.append((job_id, event_type, data)) - monkeypatch.setattr("detect.emit.push_detect_event", mock_push) + monkeypatch.setattr("core.detect.emit.push_detect_event", mock_push) video = _get_sample_video() config = FrameExtractionConfig(fps=1, max_frames=5) diff --git a/tests/detect/test_graph.py b/tests/detect/test_graph.py index 61ac5eb..e16e505 100644 --- a/tests/detect/test_graph.py +++ b/tests/detect/test_graph.py @@ -4,9 +4,9 @@ import os import pytest -from detect.graph import NODES, build_graph, get_pipeline -from detect.models import PipelineStats -from detect.state import DetectState +from core.detect.graph import NODES, build_graph, get_pipeline +from core.detect.models import PipelineStats +from core.detect.state import DetectState VIDEO = "media/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e/chunk_0000.mp4" @@ -42,7 +42,7 @@ def test_graph_has_all_nodes(): def test_graph_runs_end_to_end(monkeypatch): """Run the full graph with mocked event emission.""" events = [] - monkeypatch.setattr("detect.emit.push_detect_event", + monkeypatch.setattr("core.detect.emit.push_detect_event", lambda job_id, etype, data: events.append((etype, data))) pipeline = 
get_pipeline() @@ -75,7 +75,7 @@ def test_graph_runs_end_to_end(monkeypatch): def test_graph_node_transitions(monkeypatch): """Verify each node emits running → done transitions.""" events = [] - monkeypatch.setattr("detect.emit.push_detect_event", + monkeypatch.setattr("core.detect.emit.push_detect_event", lambda job_id, etype, data: events.append((etype, data))) pipeline = get_pipeline() diff --git a/tests/detect/test_ocr_stage.py b/tests/detect/test_ocr_stage.py index b65e340..4f10733 100644 --- a/tests/detect/test_ocr_stage.py +++ b/tests/detect/test_ocr_stage.py @@ -3,9 +3,9 @@ import numpy as np import pytest -from detect.models import BoundingBox, Frame -from detect.profiles.base import OCRConfig -from detect.stages.ocr_stage import _crop_region, _parse_ocr_raw, run_ocr +from core.detect.models import BoundingBox, Frame +from core.detect.stages.models import OCRConfig +from core.detect.stages.ocr_stage import _crop_region, _parse_ocr_raw, run_ocr def _has_paddleocr() -> bool: @@ -80,7 +80,7 @@ def test_parse_empty(): def test_run_ocr_remote(monkeypatch): events = [] - monkeypatch.setattr("detect.emit.push_detect_event", + monkeypatch.setattr("core.detect.emit.push_detect_event", lambda job_id, etype, data: events.append((etype, data))) class FakeResult: @@ -94,11 +94,11 @@ def test_run_ocr_remote(monkeypatch): def ocr(self, image, languages): return [FakeResult("NIKE", 0.92)] - monkeypatch.setattr("detect.stages.ocr_stage.InferenceClient", FakeClient, + monkeypatch.setattr("core.detect.stages.ocr_stage.InferenceClient", FakeClient, raising=False) # Patch the import path used in the function - import detect.stages.ocr_stage as mod - monkeypatch.setattr("detect.inference.InferenceClient", FakeClient) + import core.detect.stages.ocr_stage as mod + monkeypatch.setattr("core.detect.inference.InferenceClient", FakeClient) frame = _make_frame() box = _make_box() @@ -123,7 +123,7 @@ def test_run_ocr_remote(monkeypatch): ) def 
test_run_ocr_skips_empty_crop(monkeypatch): events = [] - monkeypatch.setattr("detect.emit.push_detect_event", + monkeypatch.setattr("core.detect.emit.push_detect_event", lambda job_id, etype, data: events.append((etype, data))) frame = _make_frame(w=10, h=10) diff --git a/tests/detect/test_preprocess.py b/tests/detect/test_preprocess.py index d2363b9..8defef5 100644 --- a/tests/detect/test_preprocess.py +++ b/tests/detect/test_preprocess.py @@ -26,7 +26,7 @@ def _make_image(w: int = 200, h: int = 60) -> np.ndarray: @requires_cv2 def test_binarize(): - from gpu.models.preprocess import binarize + from core.gpu.models.preprocess import binarize img = _make_image() result = binarize(img) @@ -40,7 +40,7 @@ def test_binarize(): @requires_cv2 def test_enhance_contrast(): - from gpu.models.preprocess import enhance_contrast + from core.gpu.models.preprocess import enhance_contrast img = _make_image() result = enhance_contrast(img) @@ -51,7 +51,7 @@ def test_enhance_contrast(): @requires_cv2 def test_deskew_no_rotation(): - from gpu.models.preprocess import deskew + from core.gpu.models.preprocess import deskew img = _make_image() result = deskew(img) @@ -63,7 +63,7 @@ def test_deskew_no_rotation(): @requires_cv2 def test_preprocess_pipeline(): - from gpu.models.preprocess import preprocess + from core.gpu.models.preprocess import preprocess img = _make_image() @@ -76,7 +76,7 @@ def test_preprocess_pipeline(): @requires_cv2 def test_preprocess_all_disabled(): - from gpu.models.preprocess import preprocess + from core.gpu.models.preprocess import preprocess img = _make_image() result = preprocess(img, do_binarize=False, do_deskew=False, do_contrast=False) diff --git a/tests/detect/test_profiles.py b/tests/detect/test_profiles.py index 8e95d5b..02e5729 100644 --- a/tests/detect/test_profiles.py +++ b/tests/detect/test_profiles.py @@ -1,55 +1,70 @@ -"""Tests for ContentTypeProfile implementations.""" +"""Tests for profile data and helper functions.""" -import pytest - -from 
detect.models import BrandDetection -from detect.profiles.base import ContentTypeProfile, CropContext -from detect.profiles.soccer import SoccerBroadcastProfile -from detect.profiles.stubs import AdvertisingProfile, NewsBroadcastProfile, TranscriptProfile +from core.detect.models import BrandDetection, CropContext +from core.detect.profile import get_profile, get_stage_config, build_vlm_prompt, aggregate_detections, pipeline_config_from_dict +from core.detect.stages.models import FrameExtractionConfig, DetectionConfig, ResolverConfig -def test_soccer_satisfies_protocol(): - profile: ContentTypeProfile = SoccerBroadcastProfile() - assert profile.name == "soccer_broadcast" +def test_soccer_profile_exists(): + profile = get_profile("soccer_broadcast") + assert profile["name"] == "soccer_broadcast" + + +def test_soccer_has_pipeline(): + profile = get_profile("soccer_broadcast") + assert "stages" in profile["pipeline"] + assert "edges" in profile["pipeline"] + + +def test_soccer_has_configs(): + profile = get_profile("soccer_broadcast") + configs = profile["configs"] + assert "extract_frames" in configs + assert "filter_scenes" in configs + assert "detect_edges" in configs def test_soccer_frame_extraction_config(): - cfg = SoccerBroadcastProfile().frame_extraction_config() + profile = get_profile("soccer_broadcast") + cfg = FrameExtractionConfig(**get_stage_config(profile, "extract_frames")) assert cfg.fps > 0 assert cfg.max_frames > 0 def test_soccer_detection_config(): - cfg = SoccerBroadcastProfile().detection_config() + profile = get_profile("soccer_broadcast") + cfg = DetectionConfig(**get_stage_config(profile, "detect_objects")) assert 0 < cfg.confidence_threshold < 1 assert isinstance(cfg.target_classes, list) def test_soccer_resolver_config(): - cfg = SoccerBroadcastProfile().resolver_config() + profile = get_profile("soccer_broadcast") + cfg = ResolverConfig(**get_stage_config(profile, "match_brands")) assert cfg.fuzzy_threshold > 0 -def 
test_soccer_vlm_prompt(): +def test_vlm_prompt(): ctx = CropContext(image=b"fake", surrounding_text="Emirates", position_hint="top-center") - prompt = SoccerBroadcastProfile().vlm_prompt(ctx) + template = get_profile("soccer_broadcast")["configs"]["escalate_vlm"]["vlm_prompt_template"] + prompt = build_vlm_prompt(ctx, template) assert "brand" in prompt.lower() assert "Emirates" in prompt -def test_soccer_aggregate_empty(): - report = SoccerBroadcastProfile().aggregate([]) +def test_aggregate_empty(): + report = aggregate_detections([], "soccer_broadcast") assert len(report.brands) == 0 assert len(report.timeline) == 0 -def test_soccer_aggregate_groups(): +def test_aggregate_groups(): detections = [ BrandDetection(brand="Nike", timestamp=1.0, duration=0.5, confidence=0.9, source="ocr"), BrandDetection(brand="Nike", timestamp=2.0, duration=0.5, confidence=0.8, source="ocr"), BrandDetection(brand="Adidas", timestamp=3.0, duration=0.5, confidence=0.7, source="logo_match"), ] - report = SoccerBroadcastProfile().aggregate(detections) + report = aggregate_detections(detections, "soccer_broadcast") assert "Nike" in report.brands assert "Adidas" in report.brands assert report.brands["Nike"].total_appearances == 2 @@ -57,15 +72,9 @@ def test_soccer_aggregate_groups(): assert report.timeline == sorted(report.timeline, key=lambda d: d.timestamp) -def test_soccer_auxiliary_returns_empty(): - assert SoccerBroadcastProfile().auxiliary_detections("test.mp4") == [] - - -@pytest.mark.parametrize("stub_cls", [NewsBroadcastProfile, AdvertisingProfile, TranscriptProfile]) -def test_stubs_raise(stub_cls): - stub = stub_cls() - assert isinstance(stub.name, str) - with pytest.raises(NotImplementedError): - stub.frame_extraction_config() - with pytest.raises(NotImplementedError): - stub.resolver_config() +def test_pipeline_config(): + profile = get_profile("soccer_broadcast") + config = pipeline_config_from_dict(profile["pipeline"]) + assert config.name == "soccer_broadcast" + assert 
len(config.stages) > 0 + assert len(config.edges) > 0 diff --git a/tests/detect/test_region_analyzer.py b/tests/detect/test_region_analyzer.py index 8260bf4..301e5d5 100644 --- a/tests/detect/test_region_analyzer.py +++ b/tests/detect/test_region_analyzer.py @@ -6,14 +6,14 @@ from pathlib import Path import numpy as np import pytest -from detect.models import BoundingBox, Frame -from detect.profiles.base import RegionAnalysisConfig -from detect.profiles.soccer import SoccerBroadcastProfile +from core.detect.models import BoundingBox, Frame +from core.detect.stages.models import RegionAnalysisConfig +from core.detect.profile import get_profile, get_stage_config # Load edges module directly — gpu/models/__init__.py has GPU-only imports _spec = importlib.util.spec_from_file_location( - "cv_edges", Path("gpu/models/cv/edges.py"), + "cv_edges", Path("core/gpu/models/cv/edges.py"), ) _edges_mod = importlib.util.module_from_spec(_spec) _spec.loader.exec_module(_edges_mod) @@ -40,8 +40,8 @@ def _make_frame_with_lines(seq: int = 0) -> Frame: # --- Config --- def test_soccer_profile_has_region_analysis_config(): - profile = SoccerBroadcastProfile() - config = profile.region_analysis_config() + profile = get_profile("soccer_broadcast") + config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges")) assert isinstance(config, RegionAnalysisConfig) assert config.enabled is True @@ -133,9 +133,9 @@ def test_detect_edges_debug_blank_frame(): def test_stage_disabled(monkeypatch): """When disabled, returns empty dict.""" - monkeypatch.setattr("detect.emit.push_detect_event", lambda *a, **kw: None) + monkeypatch.setattr("core.detect.emit.push_detect_event", lambda *a, **kw: None) - from detect.stages.edge_detector import detect_edge_regions + from core.detect.stages.edge_detector import detect_edge_regions config = RegionAnalysisConfig(enabled=False) result = detect_edge_regions([_make_frame()], config, job_id="test") @@ -144,9 +144,9 @@ def 
test_stage_disabled(monkeypatch): def test_stage_local_blank(monkeypatch): """Local mode on blank frames returns empty boxes.""" - monkeypatch.setattr("detect.emit.push_detect_event", lambda *a, **kw: None) + monkeypatch.setattr("core.detect.emit.push_detect_event", lambda *a, **kw: None) - from detect.stages.edge_detector import detect_edge_regions + from core.detect.stages.edge_detector import detect_edge_regions config = RegionAnalysisConfig() result = detect_edge_regions([_make_frame()], config, job_id="test") @@ -156,9 +156,9 @@ def test_stage_local_blank(monkeypatch): def test_stage_local_with_lines(monkeypatch): """Local mode on frame with lines should find regions.""" - monkeypatch.setattr("detect.emit.push_detect_event", lambda *a, **kw: None) + monkeypatch.setattr("core.detect.emit.push_detect_event", lambda *a, **kw: None) - from detect.stages.edge_detector import detect_edge_regions + from core.detect.stages.edge_detector import detect_edge_regions config = RegionAnalysisConfig() frame = _make_frame_with_lines() @@ -174,22 +174,22 @@ def test_stage_local_with_lines(monkeypatch): def test_detect_edges_in_nodes(): """detect_edges must be in the pipeline node list.""" - from detect.graph import NODES, NODE_FUNCTIONS + from core.detect.graph import NODES, NODE_FUNCTIONS assert "detect_edges" in NODES node_names = [name for name, _ in NODE_FUNCTIONS] assert "detect_edges" in node_names - # Must be after filter_scenes, before detect_objects + # Must be after field_segmentation, before detect_objects idx = NODES.index("detect_edges") - assert NODES[idx - 1] == "filter_scenes" + assert NODES[idx - 1] == "field_segmentation" assert NODES[idx + 1] == "detect_objects" # --- State --- def test_state_has_edge_regions_field(): - from detect.state import DetectState + from core.detect.state import DetectState hints = DetectState.__annotations__ assert "edge_regions_by_frame" in hints diff --git a/tests/detect/test_replay.py b/tests/detect/test_replay.py index 
4b3cabb..c5965ee 100644 --- a/tests/detect/test_replay.py +++ b/tests/detect/test_replay.py @@ -1,87 +1,67 @@ -"""Tests for replay and OverrideProfile.""" +"""Tests for config overrides and replay.""" import pytest -from detect.profiles.soccer import SoccerBroadcastProfile -from detect.profiles.base import RegionAnalysisConfig -from detect.checkpoint.replay import OverrideProfile, replay_single_stage +from core.detect.profile import get_profile, get_stage_config +from core.detect.stages.models import RegionAnalysisConfig, OCRConfig, ResolverConfig +from core.detect.checkpoint.replay import replay_single_stage -def test_override_profile_patches_ocr(): - base = SoccerBroadcastProfile() - overrides = {"ocr": {"min_confidence": 0.3, "languages": ["en", "es", "pt"]}} - profile = OverrideProfile(base, overrides) +def _apply_overrides(profile, overrides): + """Apply config overrides to a profile dict (same logic as nodes._load_profile).""" + merged_configs = dict(profile.get("configs", {})) + for stage_name, stage_overrides in overrides.items(): + if stage_name in merged_configs: + merged_configs[stage_name] = {**merged_configs[stage_name], **stage_overrides} + else: + merged_configs[stage_name] = stage_overrides + return {**profile, "configs": merged_configs} - config = profile.ocr_config() + +def test_override_patches_ocr(): + profile = get_profile("soccer_broadcast") + overrides = {"run_ocr": {"min_confidence": 0.3, "languages": ["en", "es", "pt"]}} + patched = _apply_overrides(profile, overrides) + + config = OCRConfig(**get_stage_config(patched, "run_ocr")) assert config.min_confidence == 0.3 assert config.languages == ["en", "es", "pt"] -def test_override_profile_patches_resolver(): - base = SoccerBroadcastProfile() - overrides = {"resolver": {"fuzzy_threshold": 60}} - profile = OverrideProfile(base, overrides) +def test_override_patches_resolver(): + profile = get_profile("soccer_broadcast") + overrides = {"match_brands": {"fuzzy_threshold": 60}} + patched = 
_apply_overrides(profile, overrides) - config = profile.resolver_config() + config = ResolverConfig(**get_stage_config(patched, "match_brands")) assert config.fuzzy_threshold == 60 -def test_override_profile_patches_detection(): - base = SoccerBroadcastProfile() - overrides = {"detection": {"confidence_threshold": 0.5}} - profile = OverrideProfile(base, overrides) +def test_override_no_overrides(): + profile = get_profile("soccer_broadcast") + patched = _apply_overrides(profile, {}) - config = profile.detection_config() - - assert config.confidence_threshold == 0.5 - - -def test_override_profile_no_overrides(): - base = SoccerBroadcastProfile() - profile = OverrideProfile(base, {}) - - ocr = profile.ocr_config() - base_ocr = base.ocr_config() + ocr = OCRConfig(**get_stage_config(patched, "run_ocr")) + base_ocr = OCRConfig(**get_stage_config(profile, "run_ocr")) assert ocr.min_confidence == base_ocr.min_confidence assert ocr.languages == base_ocr.languages -def test_override_profile_delegates_non_config(): - base = SoccerBroadcastProfile() - profile = OverrideProfile(base, {"ocr": {"min_confidence": 0.1}}) +def test_override_patches_region_analysis(): + profile = get_profile("soccer_broadcast") + overrides = {"detect_edges": {"edge_canny_low": 25, "edge_canny_high": 200}} + patched = _apply_overrides(profile, overrides) - assert profile.name == "soccer_broadcast" - assert profile.resolver_config().fuzzy_threshold > 0 + config = RegionAnalysisConfig(**get_stage_config(patched, "detect_edges")) - -def test_override_profile_ignores_unknown_fields(): - base = SoccerBroadcastProfile() - overrides = {"ocr": {"nonexistent_field": 42}} - profile = OverrideProfile(base, overrides) - - config = profile.ocr_config() - - assert not hasattr(config, "nonexistent_field") - assert config.min_confidence == base.ocr_config().min_confidence - - -# --- OverrideProfile for region_analysis --- - -def test_override_profile_patches_region_analysis(): - base = SoccerBroadcastProfile() - 
overrides = {"region_analysis": {"edge_canny_low": 25, "edge_canny_high": 200}} - profile = OverrideProfile(base, overrides) - - config = profile.region_analysis_config() - - assert isinstance(config, RegionAnalysisConfig) assert config.edge_canny_low == 25 assert config.edge_canny_high == 200 - # Unchanged fields keep defaults - assert config.edge_hough_threshold == base.region_analysis_config().edge_hough_threshold + # Unchanged fields keep defaults from profile + base_config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges")) + assert config.edge_hough_threshold == base_config.edge_hough_threshold # --- replay_single_stage --- diff --git a/tests/detect/test_scene_filter.py b/tests/detect/test_scene_filter.py index 19db46e..2b41d1a 100644 --- a/tests/detect/test_scene_filter.py +++ b/tests/detect/test_scene_filter.py @@ -3,9 +3,9 @@ import numpy as np import pytest -from detect.models import Frame -from detect.profiles.base import SceneFilterConfig -from detect.stages.scene_filter import scene_filter +from core.detect.models import Frame +from core.detect.stages.models import SceneFilterConfig +from core.detect.stages.scene_filter import scene_filter def _make_frame(seq: int, color: tuple[int, int, int] = (128, 128, 128)) -> Frame: @@ -72,7 +72,7 @@ def test_hashes_populated(): def test_events_emitted(monkeypatch): events = [] - monkeypatch.setattr("detect.emit.push_detect_event", + monkeypatch.setattr("core.detect.emit.push_detect_event", lambda job_id, etype, data: events.append((etype, data))) frames = [_make_frame(i) for i in range(5)] diff --git a/tests/detect/test_sse_contract.py b/tests/detect/test_sse_contract.py index 6cf8be7..7dfc7eb 100644 --- a/tests/detect/test_sse_contract.py +++ b/tests/detect/test_sse_contract.py @@ -1,6 +1,6 @@ """Round-trip serialization tests for SSE contract models.""" -from detect.sse_contract import ( +from core.detect.sse import ( BoundingBoxEvent, BrandSummary, Detection, diff --git 
a/tests/detect/test_stage_registry.py b/tests/detect/test_stage_registry.py index 8a94bee..9ddbd82 100644 --- a/tests/detect/test_stage_registry.py +++ b/tests/detect/test_stage_registry.py @@ -1,7 +1,7 @@ """Tests for the stage registry.""" -from detect.stages import list_stages, get_stage, get_palette -from detect.stages.base import get_stage_class +from core.detect.stages import list_stages, get_stage, get_palette +from core.detect.stages.base import get_stage_class EXPECTED_STAGES = [ diff --git a/tests/detect/test_tracing.py b/tests/detect/test_tracing.py index 4f35c97..b0b94e4 100644 --- a/tests/detect/test_tracing.py +++ b/tests/detect/test_tracing.py @@ -2,7 +2,7 @@ import pytest -from detect.tracing import trace_node, SpanContext, flush +from core.detect.tracing import trace_node, SpanContext, flush def test_trace_node_noop(): diff --git a/tests/detect/test_vlm_cloud.py b/tests/detect/test_vlm_cloud.py index 1538a0b..6710271 100644 --- a/tests/detect/test_vlm_cloud.py +++ b/tests/detect/test_vlm_cloud.py @@ -3,8 +3,8 @@ import numpy as np import pytest -from detect.models import BoundingBox, Frame, PipelineStats, TextCandidate -from detect.stages.vlm_cloud import escalate_cloud, _parse_response +from core.detect.models import BoundingBox, Frame, PipelineStats, TextCandidate +from core.detect.stages.vlm_cloud import escalate_cloud, _parse_response def _make_candidate(text: str = "unknown", confidence: float = 0.4) -> TextCandidate: @@ -30,14 +30,14 @@ def test_parse_response_no_confidence(): def test_escalate_skips_without_api_key(monkeypatch): events = [] - monkeypatch.setattr("detect.emit.push_detect_event", + monkeypatch.setattr("core.detect.emit.push_detect_event", lambda job_id, etype, data: events.append((etype, data))) monkeypatch.delenv("GROQ_API_KEY", raising=False) monkeypatch.delenv("GEMINI_API_KEY", raising=False) monkeypatch.delenv("OPENAI_API_KEY", raising=False) monkeypatch.setenv("CLOUD_LLM_PROVIDER", "groq") # Reset cached provider - import 
detect.providers as prov + import core.detect.providers as prov monkeypatch.setattr(prov, "_cached", None) candidates = [_make_candidate()] @@ -54,7 +54,7 @@ def test_escalate_skips_without_api_key(monkeypatch): def test_escalate_empty_candidates(monkeypatch): events = [] - monkeypatch.setattr("detect.emit.push_detect_event", + monkeypatch.setattr("core.detect.emit.push_detect_event", lambda job_id, etype, data: events.append((etype, data))) stats = PipelineStats() @@ -66,18 +66,18 @@ def test_escalate_empty_candidates(monkeypatch): def test_escalate_with_mock_api(monkeypatch): events = [] - monkeypatch.setattr("detect.emit.push_detect_event", + monkeypatch.setattr("core.detect.emit.push_detect_event", lambda job_id, etype, data: events.append((etype, data))) monkeypatch.setenv("GROQ_API_KEY", "test-key") monkeypatch.setenv("CLOUD_LLM_PROVIDER", "groq") # Reset cached provider - import detect.providers as prov + import core.detect.providers as prov monkeypatch.setattr(prov, "_cached", None) def mock_call(image_b64, prompt): return {"brand": "Heineken", "confidence": 0.75, "reasoning": "green logo", "tokens": 300} - monkeypatch.setattr("detect.stages.vlm_cloud._call_cloud_api", mock_call) + monkeypatch.setattr("core.detect.stages.vlm_cloud._call_cloud_api", mock_call) candidates = [_make_candidate("unknown logo")] stats = PipelineStats() diff --git a/ui/common/types/generated.ts b/ui/common/types/generated.ts index 0dd674c..e6ff86f 100644 --- a/ui/common/types/generated.ts +++ b/ui/common/types/generated.ts @@ -57,6 +57,7 @@ export interface Job { source_asset_id: string; video_path: string; profile_name: string; + timeline_id: string | null; parent_id: string | null; run_type: RunType; config_overrides: Record; @@ -68,7 +69,6 @@ export interface Job { brands_found: number; cloud_llm_calls: number; estimated_cost_usd: number; - celery_task_id: string | null; priority: number; created_at: string | null; started_at: string | null; @@ -90,6 +90,7 @@ export interface 
Timeline { export interface Checkpoint { id: string; timeline_id: string; + job_id: string | null; parent_id: string | null; stage_outputs: Record; config_overrides: Record; @@ -111,6 +112,13 @@ export interface Brand { updated_at: string | null; } +export interface Profile { + id: string; + name: string; + pipeline: Record; + configs: Record; +} + export interface CreateJobRequest { source_asset_id: string; preset_id: string | null; diff --git a/ui/detection-app/.gitignore b/ui/detection-app/.gitignore new file mode 100644 index 0000000..f0c31e5 --- /dev/null +++ b/ui/detection-app/.gitignore @@ -0,0 +1,4 @@ +node_modules/ +dist/ +public/opencv.js +public/opencv_js.wasm diff --git a/ui/detection-app/package.json b/ui/detection-app/package.json index 9b87bd9..6c702b6 100644 --- a/ui/detection-app/package.json +++ b/ui/detection-app/package.json @@ -10,14 +10,15 @@ "typecheck": "vue-tsc --noEmit" }, "dependencies": { - "vue": "^3.5", + "@techstark/opencv-js": "4.12.0-release.1", + "mpr-ui-framework": "link:../framework", "pinia": "^2.2", - "mpr-ui-framework": "link:../framework" + "vue": "^3.5" }, "devDependencies": { + "@vitejs/plugin-vue": "^5", "typescript": "^5.6", "vite": "^6", - "@vitejs/plugin-vue": "^5", "vue-tsc": "^2" } } diff --git a/ui/detection-app/pnpm-lock.yaml b/ui/detection-app/pnpm-lock.yaml index 40cd245..0f4adac 100644 --- a/ui/detection-app/pnpm-lock.yaml +++ b/ui/detection-app/pnpm-lock.yaml @@ -8,6 +8,9 @@ importers: .: dependencies: + '@techstark/opencv-js': + specifier: 4.12.0-release.1 + version: 4.12.0-release.1 mpr-ui-framework: specifier: link:../framework version: link:../framework @@ -334,6 +337,9 @@ packages: cpu: [x64] os: [win32] + '@techstark/opencv-js@4.12.0-release.1': + resolution: {integrity: sha512-LtTaph9v/HqLPXEg3m1xs2h7QJh10pUpuDT0nj8g77lelWnTwwQrehtd+fXElLOdrkqc4Fea6Z/sJBvEJLYPfw==} + '@types/estree@1.0.8': resolution: {integrity: 
sha512-dWHzHa2WqEXI/O1E9OjrocMTKJl2mSrEolh1Iomrv6U+JuNwaHXsXx9bLu5gG7BUWFIN0skIQJQ/L1rIex4X6w==} @@ -735,6 +741,8 @@ snapshots: '@rollup/rollup-win32-x64-msvc@4.60.0': optional: true + '@techstark/opencv-js@4.12.0-release.1': {} + '@types/estree@1.0.8': {} '@vitejs/plugin-vue@5.2.4(vite@6.4.1)(vue@3.5.30(typescript@5.9.3))': diff --git a/ui/detection-app/src/App.vue b/ui/detection-app/src/App.vue index 30ff8dc..31ada30 100644 --- a/ui/detection-app/src/App.vue +++ b/ui/detection-app/src/App.vue @@ -1,5 +1,5 @@ + + + + diff --git a/ui/detection-app/src/components/StageConfigSliders.vue b/ui/detection-app/src/components/StageConfigSliders.vue deleted file mode 100644 index 19d50d2..0000000 --- a/ui/detection-app/src/components/StageConfigSliders.vue +++ /dev/null @@ -1,385 +0,0 @@ - - - - - diff --git a/ui/detection-app/src/composables/useCheckpointLoader.ts b/ui/detection-app/src/composables/useCheckpointLoader.ts index f14ca6e..126c90c 100644 --- a/ui/detection-app/src/composables/useCheckpointLoader.ts +++ b/ui/detection-app/src/composables/useCheckpointLoader.ts @@ -29,15 +29,37 @@ export function useCheckpointLoader( stripSelEndOverride.value ?? Math.max(0, checkpointFrames.value.length - 1), ) + // Cache job_id → timeline_id mappings + const timelineCache = new Map() + // Track current frame from SSE source.on<{ frame_ref: number; jpeg_b64: string }>('frame_update', (e) => { currentFrameImage.value = e.jpeg_b64 currentFrameRef.value = e.frame_ref }) + async function resolveTimelineId(job: string): Promise { + if (timelineCache.has(job)) return timelineCache.get(job)! 
+ + try { + const resp = await fetch(`/api/detect/timeline/${job}`) + if (!resp.ok) return null + const data = await resp.json() + const tid = data.timeline_id + if (tid) timelineCache.set(job, tid) + return tid + } catch { + return null + } + } + async function loadCheckpoint(job: string, stage: string) { try { - const resp = await fetch(`/api/detect/checkpoints/${job}/${stage}`) + // Resolve timeline_id from job_id + const timelineId = await resolveTimelineId(job) + const lookupId = timelineId ?? job + + const resp = await fetch(`/api/detect/checkpoints/${lookupId}/${stage}`) if (!resp.ok) return const data = await resp.json() @@ -69,8 +91,6 @@ export function useCheckpointLoader( const { stages, checkpointStageFor } = useStageRegistry() // Auto-load checkpoint when entering editor mode. - // Also watches stages — the registry fetch is async, so on first load - // stages may be empty. When they arrive, re-evaluate. watch( () => [pipeline.layoutMode, pipeline.editorStage, jobId.value, stages.value.length] as const, ([mode, stage, job]) => { diff --git a/ui/detection-app/src/composables/useEditorState.ts b/ui/detection-app/src/composables/useEditorState.ts index 06ebe5e..125d3c6 100644 --- a/ui/detection-app/src/composables/useEditorState.ts +++ b/ui/detection-app/src/composables/useEditorState.ts @@ -1,7 +1,7 @@ import { ref } from 'vue' import type { Ref } from 'vue' import type { FrameOverlay, FrameBBox } from 'mpr-ui-framework' -import { matchTracks, renderTracksToImageData, imageDataToPngB64 } from '@/cv' +import type { StageResult } from '@/components/StageConfig.vue' export type RegionBox = { x: number @@ -15,83 +15,80 @@ export type RegionBox = { export function useEditorState(currentFrameRef: Ref) { const editorOverlays = ref([]) const editorBoxes = ref([]) + const activeStage = ref(null) const allFrameRegions = ref>({}) - const allFrameDebug = ref>({}) + const allFrameOverlays = ref>>({}) + const allFrameStats = ref>>({}) const frameDimensions = ref<{ w: 
number; h: number } | null>(null) + // Overlay definitions per stage — maps overlay keys to display labels + const STAGE_OVERLAYS: Record = { + detect_edges: [ + { key: 'edge_overlay_b64', label: 'Canny edges', defaultOpacity: 0.25 }, + { key: 'lines_overlay_b64', label: 'Hough lines', defaultOpacity: 0.25 }, + ], + field_segmentation: [ + { key: 'mask_overlay_b64', label: 'Field mask', defaultOpacity: 0.5, srcFormat: 'png' }, + ], + } + function updateDisplayForFrame(seq: number) { + const stage = activeStage.value ?? 'detect_edges' + + // Boxes — only for stages that produce regions const regions = allFrameRegions.value[seq] ?? [] editorBoxes.value = regions.map(r => ({ x: r.x, y: r.y, w: r.w, h: r.h, confidence: r.confidence, - label: r.label ?? 'edge_region', - stage: 'detect_edges', + label: r.label ?? 'region', + stage, })) - const debug = allFrameDebug.value[seq] - if (debug) { + // Overlays — driven by stage overlay definitions + const overlayData = allFrameOverlays.value[seq] + const overlayDefs = STAGE_OVERLAYS[stage] ?? [] + + if (overlayData && overlayDefs.length > 0) { const overlays: FrameOverlay[] = [] - if (debug.edge_overlay_b64) { - const existing = editorOverlays.value.find(o => o.label === 'Canny edges') - overlays.push({ src: debug.edge_overlay_b64, label: 'Canny edges', visible: existing?.visible ?? true, opacity: existing?.opacity ?? 0.25 }) + for (const def of overlayDefs) { + const src = overlayData[def.key] + if (!src) continue + const existing = editorOverlays.value.find(o => o.label === def.label) + overlays.push({ + src, + label: def.label, + visible: existing?.visible ?? true, + opacity: existing?.opacity ?? def.defaultOpacity, + srcFormat: def.srcFormat, + }) } - if (debug.lines_overlay_b64) { - const existing = editorOverlays.value.find(o => o.label === 'Hough lines') - overlays.push({ src: debug.lines_overlay_b64, label: 'Hough lines', visible: existing?.visible ?? true, opacity: existing?.opacity ?? 
0.25 }) - } - const trackOverlay = editorOverlays.value.find(o => o.label === 'Motion tracks') - if (trackOverlay) overlays.push(trackOverlay) editorOverlays.value = overlays - } - - if (Object.keys(allFrameRegions.value).length >= 2 && frameDimensions.value) { - updateTrackOverlay(seq) + } else if (overlayDefs.length === 0) { + editorOverlays.value = [] } } - async function updateTrackOverlay(currentSeq: number) { - const dims = frameDimensions.value - if (!dims || Object.keys(allFrameRegions.value).length < 2) return - const tracks = matchTracks(allFrameRegions.value) - const imageData = renderTracksToImageData(tracks, dims.w, dims.h, currentSeq) - const b64 = await imageDataToPngB64(imageData) - const existing = editorOverlays.value.find(o => o.label === 'Motion tracks') - const trackOverlay: FrameOverlay = { - src: b64, - label: 'Motion tracks', - visible: existing?.visible ?? true, - opacity: existing?.opacity ?? 0.9, - srcFormat: 'png', - } - editorOverlays.value = [ - ...editorOverlays.value.filter(o => o.label !== 'Motion tracks'), - trackOverlay, - ] - } - - function onReplayResult(result: { - regions_by_frame?: Record - debug?: Record - frameWidth?: number - frameHeight?: number - }) { + function onReplayResult(result: StageResult) { if (result.frameWidth && result.frameHeight) { frameDimensions.value = { w: result.frameWidth, h: result.frameHeight } } if (result.regions_by_frame) { for (const [seqStr, regions] of Object.entries(result.regions_by_frame)) { - allFrameRegions.value[Number(seqStr)] = regions + allFrameRegions.value[Number(seqStr)] = regions as RegionBox[] } } - if (result.debug) { - for (const [seqStr, dbg] of Object.entries(result.debug)) { - allFrameDebug.value[Number(seqStr)] = { - edge_overlay_b64: dbg.edge_overlay_b64, - lines_overlay_b64: dbg.lines_overlay_b64, - } + if (result.overlays_by_frame) { + for (const [seqStr, overlayMap] of Object.entries(result.overlays_by_frame)) { + allFrameOverlays.value[Number(seqStr)] = overlayMap + } 
+ } + + if (result.stats_by_frame) { + for (const [seqStr, stats] of Object.entries(result.stats_by_frame)) { + allFrameStats.value[Number(seqStr)] = stats } } @@ -99,9 +96,30 @@ export function useEditorState(currentFrameRef: Ref) { updateDisplayForFrame(currentSeq) } - function resetEditorState() { + function setActiveStage(stage: string) { + activeStage.value = stage + // Clear accumulated state when switching stages allFrameRegions.value = {} - allFrameDebug.value = {} + allFrameOverlays.value = {} + allFrameStats.value = {} + frameDimensions.value = null + editorBoxes.value = [] + + // Initialize overlay controls from stage definitions (visible before any results) + const defs = STAGE_OVERLAYS[stage] ?? [] + editorOverlays.value = defs.map(def => ({ + src: '', + label: def.label, + visible: true, + opacity: def.defaultOpacity, + })) + } + + function resetEditorState() { + activeStage.value = null + allFrameRegions.value = {} + allFrameOverlays.value = {} + allFrameStats.value = {} frameDimensions.value = null editorOverlays.value = [] editorBoxes.value = [] @@ -110,11 +128,14 @@ export function useEditorState(currentFrameRef: Ref) { return { editorOverlays, editorBoxes, + activeStage, allFrameRegions, - allFrameDebug, + allFrameOverlays, + allFrameStats, frameDimensions, updateDisplayForFrame, onReplayResult, + setActiveStage, resetEditorState, } } diff --git a/ui/detection-app/src/cv/edges.ts b/ui/detection-app/src/cv/edges.ts index 9831ea2..1bcdabd 100644 --- a/ui/detection-app/src/cv/edges.ts +++ b/ui/detection-app/src/cv/edges.ts @@ -1,12 +1,11 @@ /** - * Edge detection — TypeScript port of gpu/models/cv/edges.py + * Edge detection — OpenCV WASM version. * - * 1:1 with the Python version. Same algorithm, same parameters, - * same output format. Runs in the browser, no network. + * 1:1 with gpu/models/cv/edges.py. Same algorithm, same parameters, + * same output format. Uses cv.Canny() and cv.HoughLinesP() via WASM. 
*/ -import { toGrayscale, canny } from './imageOps' -import { houghLinesP, type LineSegment } from './hough' +import { getCV } from './opencv' export interface EdgeRegion { x: number @@ -32,15 +31,243 @@ export interface EdgeDetectionResult { } export interface EdgeDetectionDebugResult extends EdgeDetectionResult { - edgeImageData: ImageData // Canny output for overlay - linesImageData: ImageData // Frame with Hough lines drawn + edgeImageData: ImageData + linesImageData: ImageData horizontalCount: number pairCount: number } type HLine = { xMin: number; xMax: number; yMid: number; length: number } -/** Set a pixel on ImageData with bounds check */ +const DEFAULT_PARAMS: EdgeDetectionParams = { + cannyLow: 50, + cannyHigh: 150, + houghThreshold: 80, + houghMinLength: 100, + houghMaxGap: 10, + pairMaxDistance: 200, + pairMinDistance: 15, +} + +/** + * Detect edges in an RGBA ImageData using OpenCV WASM. + */ +export async function detectEdges( + imageData: ImageData, + params: Partial = {}, +): Promise { + const cv = await getCV() + if (!cv) throw new Error('OpenCV WASM not available') + const p = { ...DEFAULT_PARAMS, ...params } + const { width, height } = imageData + + const src = cv.matFromImageData(imageData) + const gray = new cv.Mat() + const edges = new cv.Mat() + + try { + cv.cvtColor(src, gray, cv.COLOR_RGBA2GRAY) + cv.Canny(gray, edges, p.cannyLow, p.cannyHigh) + + const lines = new cv.Mat() + try { + cv.HoughLinesP(edges, lines, 1, Math.PI / 180, p.houghThreshold, p.houghMinLength, p.houghMaxGap) + const horizontals = filterHorizontal(lines) + if (horizontals.length < 2) return { regions: [] } + + const pairs = findLinePairs(horizontals, p.pairMinDistance, p.pairMaxDistance) + const regions = pairsToBoxes(pairs, width, height) + return { regions } + } finally { + lines.delete() + } + } finally { + src.delete() + gray.delete() + edges.delete() + } +} + +/** + * Detect edges with debug visualizations using OpenCV WASM. 
+ */ +export async function detectEdgesDebug( + imageData: ImageData, + params: Partial = {}, +): Promise { + const cv = await getCV() + if (!cv) throw new Error('OpenCV WASM not available') + const p = { ...DEFAULT_PARAMS, ...params } + const { width, height, data } = imageData + + const src = cv.matFromImageData(imageData) + const gray = new cv.Mat() + const edges = new cv.Mat() + + try { + cv.cvtColor(src, gray, cv.COLOR_RGBA2GRAY) + cv.Canny(gray, edges, p.cannyLow, p.cannyHigh) + + // Edge overlay — white edges on black + const edgeRgba = new cv.Mat() + try { + cv.cvtColor(edges, edgeRgba, cv.COLOR_GRAY2RGBA) + const edgeImageData = new ImageData( + new Uint8ClampedArray(edgeRgba.data), + width, + height, + ) + + const lines = new cv.Mat() + try { + cv.HoughLinesP(edges, lines, 1, Math.PI / 180, p.houghThreshold, p.houghMinLength, p.houghMaxGap) + const horizontals = filterHorizontal(lines) + + // Lines overlay — darken original, draw lines + const linesImageData = new ImageData(new Uint8ClampedArray(data), width, height) + for (let i = 0; i < linesImageData.data.length; i += 4) { + linesImageData.data[i] = Math.round(linesImageData.data[i] * 0.3) + linesImageData.data[i + 1] = Math.round(linesImageData.data[i + 1] * 0.3) + linesImageData.data[i + 2] = Math.round(linesImageData.data[i + 2] * 0.3) + } + + // All Hough lines in red + for (let i = 0; i < lines.rows; i++) { + const x1 = lines.data32S[i * 4] + const y1 = lines.data32S[i * 4 + 1] + const x2 = lines.data32S[i * 4 + 2] + const y2 = lines.data32S[i * 4 + 3] + drawLineThick(linesImageData, x1, y1, x2, y2, 255, 50, 50, 2) + } + + // Horizontal lines in cyan + for (const h of horizontals) { + drawLineThick(linesImageData, Math.round(h.xMin), Math.round(h.yMid), Math.round(h.xMax), Math.round(h.yMid), 0, 255, 255, 3) + } + + const pairs = horizontals.length >= 2 + ? 
findLinePairs(horizontals, p.pairMinDistance, p.pairMaxDistance) + : [] + + // Paired lines in green + for (const [top, bottom] of pairs) { + drawLineThick(linesImageData, Math.round(top.xMin), Math.round(top.yMid), Math.round(top.xMax), Math.round(top.yMid), 0, 255, 0, 4) + drawLineThick(linesImageData, Math.round(bottom.xMin), Math.round(bottom.yMid), Math.round(bottom.xMax), Math.round(bottom.yMid), 0, 255, 0, 4) + } + + const regions = pairsToBoxes(pairs, width, height) + + return { + regions, + edgeImageData, + linesImageData, + horizontalCount: horizontals.length, + pairCount: pairs.length, + } + } finally { + lines.delete() + } + } finally { + edgeRgba.delete() + } + } finally { + src.delete() + gray.delete() + edges.delete() + } +} + +// --- Line analysis (same logic as Python version) --- + +function filterHorizontal(lines: any, maxAngleDeg: number = 10): HLine[] { + const maxSlope = Math.tan((maxAngleDeg * Math.PI) / 180) + const result: HLine[] = [] + + for (let i = 0; i < lines.rows; i++) { + const x1 = lines.data32S[i * 4] + const y1 = lines.data32S[i * 4 + 1] + const x2 = lines.data32S[i * 4 + 2] + const y2 = lines.data32S[i * 4 + 3] + const dx = x2 - x1 + if (dx === 0) continue + const slope = Math.abs((y2 - y1) / dx) + if (slope <= maxSlope) { + const yMid = (y1 + y2) / 2 + const xMin = Math.min(x1, x2) + const xMax = Math.max(x1, x2) + const length = Math.sqrt(dx * dx + (y2 - y1) ** 2) + result.push({ xMin, xMax, yMid, length }) + } + } + return result +} + +function findLinePairs( + horizontals: HLine[], + minDistance: number, + maxDistance: number, +): [HLine, HLine][] { + const sorted = [...horizontals].sort((a, b) => a.yMid - b.yMid) + const pairs: [HLine, HLine][] = [] + const used = new Set() + + for (let i = 0; i < sorted.length; i++) { + if (used.has(i)) continue + const top = sorted[i] + + for (let j = i + 1; j < sorted.length; j++) { + if (used.has(j)) continue + const bottom = sorted[j] + const yGap = bottom.yMid - top.yMid + + if (yGap 
< minDistance) continue + if (yGap > maxDistance) break + + const overlapStart = Math.max(top.xMin, bottom.xMin) + const overlapEnd = Math.min(top.xMax, bottom.xMax) + const overlap = overlapEnd - overlapStart + const shorterLength = Math.min(top.xMax - top.xMin, bottom.xMax - bottom.xMin) + + if (shorterLength > 0 && overlap / shorterLength >= 0.5) { + pairs.push([top, bottom]) + used.add(i) + used.add(j) + break + } + } + } + return pairs +} + +function pairsToBoxes(pairs: [HLine, HLine][], frameWidth: number, frameHeight: number): EdgeRegion[] { + const regions: EdgeRegion[] = [] + for (const [top, bottom] of pairs) { + const x = Math.max(0, Math.min(top.xMin, bottom.xMin)) + const y = Math.max(0, top.yMid) + const x2 = Math.min(frameWidth, Math.max(top.xMax, bottom.xMax)) + const y2 = Math.min(frameHeight, bottom.yMid) + const w = x2 - x + const h = y2 - y + + if (w < 20 || h < 5) continue + + const avgLineLength = (top.length + bottom.length) / 2 + const coverage = Math.min(1.0, avgLineLength / Math.max(w, 1)) + + regions.push({ + x: Math.round(x), + y: Math.round(y), + w: Math.round(w), + h: Math.round(h), + confidence: Math.round(coverage * 1000) / 1000, + label: 'edge_region', + }) + } + return regions +} + +// --- Drawing helpers --- + function setPixel(img: ImageData, x: number, y: number, r: number, g: number, b: number) { if (x >= 0 && x < img.width && y >= 0 && y < img.height) { const p = (y * img.width + x) * 4 @@ -51,7 +278,6 @@ function setPixel(img: ImageData, x: number, y: number, r: number, g: number, b: } } -/** Bresenham line drawing with thickness */ function drawLineThick( img: ImageData, x0: number, y0: number, x1: number, y1: number, @@ -77,202 +303,3 @@ function drawLineThick( if (e2 < dx) { err += dx; y0 += sy } } } - -const DEFAULT_PARAMS: EdgeDetectionParams = { - cannyLow: 50, - cannyHigh: 150, - houghThreshold: 80, - houghMinLength: 100, - houghMaxGap: 10, - pairMaxDistance: 200, - pairMinDistance: 15, -} - -/** Filter to 
near-horizontal lines (within 10 degrees) */ -function filterHorizontal(lines: LineSegment[], maxAngleDeg: number = 10): HLine[] { - const maxSlope = Math.tan((maxAngleDeg * Math.PI) / 180) - const result: HLine[] = [] - - for (const line of lines) { - const dx = line.x2 - line.x1 - if (dx === 0) continue - const slope = Math.abs((line.y2 - line.y1) / dx) - if (slope <= maxSlope) { - const yMid = (line.y1 + line.y2) / 2 - const xMin = Math.min(line.x1, line.x2) - const xMax = Math.max(line.x1, line.x2) - const length = Math.sqrt(dx * dx + (line.y2 - line.y1) ** 2) - result.push({ xMin, xMax, yMid, length }) - } - } - return result -} - -/** Find pairs of horizontal lines that could be top/bottom of a hoarding */ -function findLinePairs( - horizontals: HLine[], - minDistance: number, - maxDistance: number, -): [HLine, HLine][] { - const sorted = [...horizontals].sort((a, b) => a.yMid - b.yMid) - const pairs: [HLine, HLine][] = [] - const used = new Set() - - for (let i = 0; i < sorted.length; i++) { - if (used.has(i)) continue - const top = sorted[i] - - for (let j = i + 1; j < sorted.length; j++) { - if (used.has(j)) continue - const bottom = sorted[j] - const yGap = bottom.yMid - top.yMid - - if (yGap < minDistance) continue - if (yGap > maxDistance) break - - // Check horizontal overlap (50% of shorter line) - const overlapStart = Math.max(top.xMin, bottom.xMin) - const overlapEnd = Math.min(top.xMax, bottom.xMax) - const overlap = overlapEnd - overlapStart - const shorterLength = Math.min(top.xMax - top.xMin, bottom.xMax - bottom.xMin) - - if (shorterLength > 0 && overlap / shorterLength >= 0.5) { - pairs.push([top, bottom]) - used.add(i) - used.add(j) - break - } - } - } - return pairs -} - -/** Convert a line pair to a bounding box */ -function pairToBox( - top: HLine, - bottom: HLine, - frameWidth: number, - frameHeight: number, -): EdgeRegion | null { - const x = Math.max(0, Math.min(top.xMin, bottom.xMin)) - const y = Math.max(0, top.yMid) - const x2 = 
Math.min(frameWidth, Math.max(top.xMax, bottom.xMax)) - const y2 = Math.min(frameHeight, bottom.yMid) - const w = x2 - x - const h = y2 - y - - if (w < 20 || h < 5) return null - - const avgLineLength = (top.length + bottom.length) / 2 - const coverage = Math.min(1.0, avgLineLength / Math.max(w, 1)) - - return { - x: Math.round(x), - y: Math.round(y), - w: Math.round(w), - h: Math.round(h), - confidence: Math.round(coverage * 1000) / 1000, - label: 'edge_region', - } -} - -/** - * Detect edges in an RGBA ImageData. - * - * Equivalent to gpu/models/cv/edges.py detect_edges() - */ -export function detectEdges( - imageData: ImageData, - params: Partial = {}, -): EdgeDetectionResult { - const p = { ...DEFAULT_PARAMS, ...params } - const { width, height } = imageData - - const gray = toGrayscale(imageData.data, width, height) - const edges = canny(gray, width, height, p.cannyLow, p.cannyHigh) - - const rawLines = houghLinesP(edges, width, height, p.houghThreshold, p.houghMinLength, p.houghMaxGap) - const horizontals = filterHorizontal(rawLines) - - if (horizontals.length < 2) return { regions: [] } - - const pairs = findLinePairs(horizontals, p.pairMinDistance, p.pairMaxDistance) - const regions: EdgeRegion[] = [] - for (const [top, bottom] of pairs) { - const box = pairToBox(top, bottom, width, height) - if (box) regions.push(box) - } - - return { regions } -} - -/** - * Detect edges with debug visualizations. 
- * - * Equivalent to gpu/models/cv/edges.py detect_edges_debug() - */ -export function detectEdgesDebug( - imageData: ImageData, - params: Partial = {}, -): EdgeDetectionDebugResult { - const p = { ...DEFAULT_PARAMS, ...params } - const { width, height, data } = imageData - - const gray = toGrayscale(data, width, height) - const edges = canny(gray, width, height, p.cannyLow, p.cannyHigh) - - // Edge overlay — white edges on black - const edgeImageData = new ImageData(width, height) - for (let i = 0; i < edges.length; i++) { - const px = i * 4 - edgeImageData.data[px] = edges[i] - edgeImageData.data[px + 1] = edges[i] - edgeImageData.data[px + 2] = edges[i] - edgeImageData.data[px + 3] = 255 - } - - const rawLines = houghLinesP(edges, width, height, p.houghThreshold, p.houghMinLength, p.houghMaxGap) - const horizontals = filterHorizontal(rawLines) - - // Lines overlay — darken original frame so lines pop, then draw - const linesImageData = new ImageData(new Uint8ClampedArray(data), width, height) - for (let i = 0; i < linesImageData.data.length; i += 4) { - linesImageData.data[i] = Math.round(linesImageData.data[i] * 0.3) - linesImageData.data[i + 1] = Math.round(linesImageData.data[i + 1] * 0.3) - linesImageData.data[i + 2] = Math.round(linesImageData.data[i + 2] * 0.3) - } - - // Draw all Hough lines in red (3px thick) - for (const line of rawLines) { - drawLineThick(linesImageData, line.x1, line.y1, line.x2, line.y2, 255, 50, 50, 2) - } - - // Draw horizontal lines in cyan (3px thick) - for (const h of horizontals) { - drawLineThick(linesImageData, Math.round(h.xMin), Math.round(h.yMid), Math.round(h.xMax), Math.round(h.yMid), 0, 255, 255, 3) - } - - const pairs = horizontals.length >= 2 - ? 
findLinePairs(horizontals, p.pairMinDistance, p.pairMaxDistance) - : [] - - // Draw paired lines in bright green (4px thick) - for (const [top, bottom] of pairs) { - drawLineThick(linesImageData, Math.round(top.xMin), Math.round(top.yMid), Math.round(top.xMax), Math.round(top.yMid), 0, 255, 0, 4) - drawLineThick(linesImageData, Math.round(bottom.xMin), Math.round(bottom.yMid), Math.round(bottom.xMax), Math.round(bottom.yMid), 0, 255, 0, 4) - } - - const regions: EdgeRegion[] = [] - for (const [top, bottom] of pairs) { - const box = pairToBox(top, bottom, width, height) - if (box) regions.push(box) - } - - return { - regions, - edgeImageData, - linesImageData, - horizontalCount: horizontals.length, - pairCount: pairs.length, - } -} diff --git a/ui/detection-app/src/cv/edgesTs.ts b/ui/detection-app/src/cv/edgesTs.ts new file mode 100644 index 0000000..9831ea2 --- /dev/null +++ b/ui/detection-app/src/cv/edgesTs.ts @@ -0,0 +1,278 @@ +/** + * Edge detection — TypeScript port of gpu/models/cv/edges.py + * + * 1:1 with the Python version. Same algorithm, same parameters, + * same output format. Runs in the browser, no network. 
+ */ + +import { toGrayscale, canny } from './imageOps' +import { houghLinesP, type LineSegment } from './hough' + +export interface EdgeRegion { + x: number + y: number + w: number + h: number + confidence: number + label: string +} + +export interface EdgeDetectionParams { + cannyLow: number + cannyHigh: number + houghThreshold: number + houghMinLength: number + houghMaxGap: number + pairMaxDistance: number + pairMinDistance: number +} + +export interface EdgeDetectionResult { + regions: EdgeRegion[] +} + +export interface EdgeDetectionDebugResult extends EdgeDetectionResult { + edgeImageData: ImageData // Canny output for overlay + linesImageData: ImageData // Frame with Hough lines drawn + horizontalCount: number + pairCount: number +} + +type HLine = { xMin: number; xMax: number; yMid: number; length: number } + +/** Set a pixel on ImageData with bounds check */ +function setPixel(img: ImageData, x: number, y: number, r: number, g: number, b: number) { + if (x >= 0 && x < img.width && y >= 0 && y < img.height) { + const p = (y * img.width + x) * 4 + img.data[p] = r + img.data[p + 1] = g + img.data[p + 2] = b + img.data[p + 3] = 255 + } +} + +/** Bresenham line drawing with thickness */ +function drawLineThick( + img: ImageData, + x0: number, y0: number, x1: number, y1: number, + r: number, g: number, b: number, + thickness: number = 1, +) { + const dx = Math.abs(x1 - x0) + const dy = Math.abs(y1 - y0) + const sx = x0 < x1 ? 1 : -1 + const sy = y0 < y1 ? 
1 : -1 + let err = dx - dy + const half = Math.floor(thickness / 2) + + while (true) { + for (let oy = -half; oy <= half; oy++) { + for (let ox = -half; ox <= half; ox++) { + setPixel(img, x0 + ox, y0 + oy, r, g, b) + } + } + if (x0 === x1 && y0 === y1) break + const e2 = 2 * err + if (e2 > -dy) { err -= dy; x0 += sx } + if (e2 < dx) { err += dx; y0 += sy } + } +} + +const DEFAULT_PARAMS: EdgeDetectionParams = { + cannyLow: 50, + cannyHigh: 150, + houghThreshold: 80, + houghMinLength: 100, + houghMaxGap: 10, + pairMaxDistance: 200, + pairMinDistance: 15, +} + +/** Filter to near-horizontal lines (within 10 degrees) */ +function filterHorizontal(lines: LineSegment[], maxAngleDeg: number = 10): HLine[] { + const maxSlope = Math.tan((maxAngleDeg * Math.PI) / 180) + const result: HLine[] = [] + + for (const line of lines) { + const dx = line.x2 - line.x1 + if (dx === 0) continue + const slope = Math.abs((line.y2 - line.y1) / dx) + if (slope <= maxSlope) { + const yMid = (line.y1 + line.y2) / 2 + const xMin = Math.min(line.x1, line.x2) + const xMax = Math.max(line.x1, line.x2) + const length = Math.sqrt(dx * dx + (line.y2 - line.y1) ** 2) + result.push({ xMin, xMax, yMid, length }) + } + } + return result +} + +/** Find pairs of horizontal lines that could be top/bottom of a hoarding */ +function findLinePairs( + horizontals: HLine[], + minDistance: number, + maxDistance: number, +): [HLine, HLine][] { + const sorted = [...horizontals].sort((a, b) => a.yMid - b.yMid) + const pairs: [HLine, HLine][] = [] + const used = new Set() + + for (let i = 0; i < sorted.length; i++) { + if (used.has(i)) continue + const top = sorted[i] + + for (let j = i + 1; j < sorted.length; j++) { + if (used.has(j)) continue + const bottom = sorted[j] + const yGap = bottom.yMid - top.yMid + + if (yGap < minDistance) continue + if (yGap > maxDistance) break + + // Check horizontal overlap (50% of shorter line) + const overlapStart = Math.max(top.xMin, bottom.xMin) + const overlapEnd = 
Math.min(top.xMax, bottom.xMax) + const overlap = overlapEnd - overlapStart + const shorterLength = Math.min(top.xMax - top.xMin, bottom.xMax - bottom.xMin) + + if (shorterLength > 0 && overlap / shorterLength >= 0.5) { + pairs.push([top, bottom]) + used.add(i) + used.add(j) + break + } + } + } + return pairs +} + +/** Convert a line pair to a bounding box */ +function pairToBox( + top: HLine, + bottom: HLine, + frameWidth: number, + frameHeight: number, +): EdgeRegion | null { + const x = Math.max(0, Math.min(top.xMin, bottom.xMin)) + const y = Math.max(0, top.yMid) + const x2 = Math.min(frameWidth, Math.max(top.xMax, bottom.xMax)) + const y2 = Math.min(frameHeight, bottom.yMid) + const w = x2 - x + const h = y2 - y + + if (w < 20 || h < 5) return null + + const avgLineLength = (top.length + bottom.length) / 2 + const coverage = Math.min(1.0, avgLineLength / Math.max(w, 1)) + + return { + x: Math.round(x), + y: Math.round(y), + w: Math.round(w), + h: Math.round(h), + confidence: Math.round(coverage * 1000) / 1000, + label: 'edge_region', + } +} + +/** + * Detect edges in an RGBA ImageData. 
+ * + * Equivalent to gpu/models/cv/edges.py detect_edges() + */ +export function detectEdges( + imageData: ImageData, + params: Partial = {}, +): EdgeDetectionResult { + const p = { ...DEFAULT_PARAMS, ...params } + const { width, height } = imageData + + const gray = toGrayscale(imageData.data, width, height) + const edges = canny(gray, width, height, p.cannyLow, p.cannyHigh) + + const rawLines = houghLinesP(edges, width, height, p.houghThreshold, p.houghMinLength, p.houghMaxGap) + const horizontals = filterHorizontal(rawLines) + + if (horizontals.length < 2) return { regions: [] } + + const pairs = findLinePairs(horizontals, p.pairMinDistance, p.pairMaxDistance) + const regions: EdgeRegion[] = [] + for (const [top, bottom] of pairs) { + const box = pairToBox(top, bottom, width, height) + if (box) regions.push(box) + } + + return { regions } +} + +/** + * Detect edges with debug visualizations. + * + * Equivalent to gpu/models/cv/edges.py detect_edges_debug() + */ +export function detectEdgesDebug( + imageData: ImageData, + params: Partial = {}, +): EdgeDetectionDebugResult { + const p = { ...DEFAULT_PARAMS, ...params } + const { width, height, data } = imageData + + const gray = toGrayscale(data, width, height) + const edges = canny(gray, width, height, p.cannyLow, p.cannyHigh) + + // Edge overlay — white edges on black + const edgeImageData = new ImageData(width, height) + for (let i = 0; i < edges.length; i++) { + const px = i * 4 + edgeImageData.data[px] = edges[i] + edgeImageData.data[px + 1] = edges[i] + edgeImageData.data[px + 2] = edges[i] + edgeImageData.data[px + 3] = 255 + } + + const rawLines = houghLinesP(edges, width, height, p.houghThreshold, p.houghMinLength, p.houghMaxGap) + const horizontals = filterHorizontal(rawLines) + + // Lines overlay — darken original frame so lines pop, then draw + const linesImageData = new ImageData(new Uint8ClampedArray(data), width, height) + for (let i = 0; i < linesImageData.data.length; i += 4) { + 
linesImageData.data[i] = Math.round(linesImageData.data[i] * 0.3) + linesImageData.data[i + 1] = Math.round(linesImageData.data[i + 1] * 0.3) + linesImageData.data[i + 2] = Math.round(linesImageData.data[i + 2] * 0.3) + } + + // Draw all Hough lines in red (2px thick) + for (const line of rawLines) { + drawLineThick(linesImageData, line.x1, line.y1, line.x2, line.y2, 255, 50, 50, 2) + } + + // Draw horizontal lines in cyan (3px thick) + for (const h of horizontals) { + drawLineThick(linesImageData, Math.round(h.xMin), Math.round(h.yMid), Math.round(h.xMax), Math.round(h.yMid), 0, 255, 255, 3) + } + + const pairs = horizontals.length >= 2 + ? findLinePairs(horizontals, p.pairMinDistance, p.pairMaxDistance) + : [] + + // Draw paired lines in bright green (4px thick) + for (const [top, bottom] of pairs) { + drawLineThick(linesImageData, Math.round(top.xMin), Math.round(top.yMid), Math.round(top.xMax), Math.round(top.yMid), 0, 255, 0, 4) + drawLineThick(linesImageData, Math.round(bottom.xMin), Math.round(bottom.yMid), Math.round(bottom.xMax), Math.round(bottom.yMid), 0, 255, 0, 4) + } + + const regions: EdgeRegion[] = [] + for (const [top, bottom] of pairs) { + const box = pairToBox(top, bottom, width, height) + if (box) regions.push(box) + } + + return { + regions, + edgeImageData, + linesImageData, + horizontalCount: horizontals.length, + pairCount: pairs.length, + } +} diff --git a/ui/detection-app/src/cv/index.ts b/ui/detection-app/src/cv/index.ts index 0f48e4e..b644049 100644 --- a/ui/detection-app/src/cv/index.ts +++ b/ui/detection-app/src/cv/index.ts @@ -1,52 +1,77 @@ /** * Browser-side CV — public API. * - * Runs edge detection directly on the main thread. - * Pure TypeScript, no WASM, no dependencies. - * ~10-50ms per 1080p frame — fast enough for slider feedback. 
+ * Three execution backends for edge detection: + * - TS: pure TypeScript (always available, ~10-50ms per 1080p frame) + * - WASM: OpenCV.js via WebAssembly (async load, same perf, same API as GPU server) + * - Server: GPU box over HTTP * - * TODO: Move to Web Worker when processing larger batches. - * - * Usage: - * import { runEdgeDetection, runEdgeDetectionDebug } from '@/cv' - * const result = await runEdgeDetection(imageData, params) + * Field segmentation only has WASM + Server (no TS port). */ -import { detectEdges, detectEdgesDebug, type EdgeRegion, type EdgeDetectionParams } from './edges' +// WASM versions (edges + segmentation) +import { detectEdges, detectEdgesDebug } from './edges' +import { segmentField, segmentFieldDebug } from './segmentation' -export type { EdgeRegion, EdgeDetectionParams } from './edges' -export type { EdgeDetectionResult, EdgeDetectionDebugResult } from './edges' +// Pure TS versions (edges only — the original fallback) +import { detectEdges as detectEdgesTs, detectEdgesDebug as detectEdgesTsDebug } from './edgesTs' + +export type { EdgeRegion, EdgeDetectionParams } from './edgesTs' +export type { EdgeDetectionResult, EdgeDetectionDebugResult } from './edgesTs' +export type { SegmentationParams, SegmentationResult, SegmentationDebugResult } from './segmentation' export { matchTracks, renderTracksToImageData } from './tracks' export type { Track, TrackPoint } from './tracks' -/** Run edge detection. Returns bounding boxes. 
*/ +// --- Edge detection: pure TS (always works, no async load) --- + +export function runEdgeDetectionTs( + imageData: ImageData, + params: Partial = {}, +) { + return detectEdgesTs(imageData, params) +} + +export function runEdgeDetectionTsDebug( + imageData: ImageData, + params: Partial = {}, +) { + return detectEdgesTsDebug(imageData, params) +} + +// --- Edge detection: WASM (needs async opencv load) --- + export async function runEdgeDetection( imageData: ImageData, - params: Partial = {}, -): Promise<{ regions: EdgeRegion[] }> { + params: Partial = {}, +) { return detectEdges(imageData, params) } -/** Run edge detection with debug overlays. Returns boxes + visualization ImageData. */ export async function runEdgeDetectionDebug( imageData: ImageData, - params: Partial = {}, -): Promise<{ - regions: EdgeRegion[] - edgeImageData: ImageData - linesImageData: ImageData - horizontalCount: number - pairCount: number -}> { + params: Partial = {}, +) { return detectEdgesDebug(imageData, params) } -/** - * Decode a base64 JPEG string to ImageData. - * - * Used to convert the checkpoint frame (base64) into ImageData - * that the CV functions can process. - */ +// --- Segmentation: WASM only --- + +export async function runSegmentation( + imageData: ImageData, + params: Partial = {}, +) { + return segmentField(imageData, params) +} + +export async function runSegmentationDebug( + imageData: ImageData, + params: Partial = {}, +) { + return segmentFieldDebug(imageData, params) +} + +// --- Utilities --- + export function b64ToImageData(b64: string): Promise { return new Promise((resolve, reject) => { const img = new Image() @@ -61,12 +86,6 @@ export function b64ToImageData(b64: string): Promise { }) } -/** - * Encode ImageData to base64 PNG string (preserves transparency). - * - * Used for overlays that need a transparent background (e.g. motion tracks). - * Pair with srcFormat: 'png' on the FrameOverlay. 
- */ export async function imageDataToPngB64(imageData: ImageData): Promise { const canvas = new OffscreenCanvas(imageData.width, imageData.height) const ctx = canvas.getContext('2d')! @@ -81,12 +100,6 @@ export async function imageDataToPngB64(imageData: ImageData): Promise { return btoa(binary) } -/** - * Encode ImageData to base64 JPEG string. - * - * Used to convert debug overlay ImageData back to base64 - * for the FrameRenderer overlays prop. - */ export async function imageDataToB64(imageData: ImageData): Promise { const canvas = new OffscreenCanvas(imageData.width, imageData.height) const ctx = canvas.getContext('2d')! diff --git a/ui/detection-app/src/cv/opencv.ts b/ui/detection-app/src/cv/opencv.ts new file mode 100644 index 0000000..60a44a3 --- /dev/null +++ b/ui/detection-app/src/cv/opencv.ts @@ -0,0 +1,84 @@ +/** + * OpenCV WASM loader — split build. + * + * public/opencv.js — 146KB JS loader (stripped of embedded WASM) + * public/opencv_js.wasm — 3.1MB WASM binary (custom build: imgproc + core only) + * + * The JS parses instantly. The WASM compiles asynchronously in the + * background — no main thread freeze. + */ + +let cvInstance: any = null +let cvPromise: Promise | null = null +let loadFailed = false +let loading = false + +export async function getCV(): Promise { + if (cvInstance) return cvInstance + if (loadFailed) return null + + if (!cvPromise) { + cvPromise = loadCV() + } + + return cvPromise +} + +export function isCVLoaded(): boolean { + return cvInstance != null +} + +export function isCVLoading(): boolean { + return loading +} + +export function isCVFailed(): boolean { + return loadFailed +} + +async function loadCV(): Promise { + loading = true + try { + const base = import.meta.env.BASE_URL ?? 
'/' + + // Module config must exist before script loads — the UMD wrapper reads it + ;(globalThis as any).Module = { + locateFile: (path: string) => `${base}${path}`, + } + + await new Promise((resolve, reject) => { + const script = document.createElement('script') + script.src = `${base}opencv.js` + script.async = true + + const timeout = setTimeout(() => { + reject(new Error('opencv.js load timed out')) + }, 10000) + + script.onload = () => { + clearTimeout(timeout) + resolve() + } + script.onerror = () => { + clearTimeout(timeout) + reject(new Error('Failed to load opencv.js')) + } + document.head.appendChild(script) + }) + + // OpenCV 4.12 UMD wrapper calls the factory and assigns the readyPromise to window.cv + const cvReady = (globalThis as any).cv + if (!cvReady) throw new Error('cv not on globalThis after script load') + + cvInstance = await cvReady + + loading = false + return cvInstance + } catch (e) { + loading = false + loadFailed = true + cvPromise = null + console.warn('[opencv] Load failed:', e) + return null + } +} diff --git a/ui/detection-app/src/cv/segmentation.ts b/ui/detection-app/src/cv/segmentation.ts new file mode 100644 index 0000000..17c2e2d --- /dev/null +++ b/ui/detection-app/src/cv/segmentation.ts @@ -0,0 +1,212 @@ +/** + * Field segmentation — OpenCV WASM version. + * + * 1:1 with gpu/models/cv/segmentation.py. HSV green mask + + * morphology + contour → pitch boundary detection. 
+ */ + +import { getCV } from './opencv' + +export interface SegmentationParams { + hueLow: number + hueHigh: number + satLow: number + satHigh: number + valLow: number + valHigh: number + morphKernel: number + minAreaRatio: number +} + +export interface SegmentationResult { + boundary: [number, number][] + coverage: number + maskImageData: ImageData +} + +export interface SegmentationDebugResult extends SegmentationResult { + overlayImageData: ImageData +} + +const DEFAULT_PARAMS: SegmentationParams = { + hueLow: 30, + hueHigh: 85, + satLow: 30, + satHigh: 255, + valLow: 30, + valHigh: 255, + morphKernel: 15, + minAreaRatio: 0.05, +} + +/** + * Detect the pitch area using HSV green thresholding. + */ +export async function segmentField( + imageData: ImageData, + params: Partial = {}, +): Promise { + const cv = await getCV() + if (!cv) throw new Error('OpenCV WASM not available') + const p = { ...DEFAULT_PARAMS, ...params } + const { width, height } = imageData + + const src = cv.matFromImageData(imageData) + const rgb = new cv.Mat() + const hsv = new cv.Mat() + const mask = new cv.Mat() + + try { + // ImageData is RGBA, convert to RGB then HSV + cv.cvtColor(src, rgb, cv.COLOR_RGBA2RGB) + cv.cvtColor(rgb, hsv, cv.COLOR_RGB2HSV) + + const lower = new cv.Mat(1, 1, cv.CV_8UC3, new cv.Scalar(p.hueLow, p.satLow, p.valLow)) + const upper = new cv.Mat(1, 1, cv.CV_8UC3, new cv.Scalar(p.hueHigh, p.satHigh, p.valHigh)) + + try { + cv.inRange(hsv, lower, upper, mask) + } finally { + lower.delete() + upper.delete() + } + + // Morphology — close then open + const k = p.morphKernel + const kernel = cv.getStructuringElement(cv.MORPH_ELLIPSE, new cv.Size(k, k)) + try { + cv.morphologyEx(mask, mask, cv.MORPH_CLOSE, kernel) + cv.morphologyEx(mask, mask, cv.MORPH_OPEN, kernel) + } finally { + kernel.delete() + } + + // Find contours + const contours = new cv.MatVector() + const hierarchy = new cv.Mat() + try { + cv.findContours(mask, contours, hierarchy, cv.RETR_EXTERNAL, 
cv.CHAIN_APPROX_SIMPLE) + + const minArea = p.minAreaRatio * height * width + let boundary: [number, number][] = [] + let coverage = 0 + + // Find the largest contour above min area + let bestContour: any = null + let bestArea = 0 + + for (let i = 0; i < contours.size(); i++) { + const c = contours.get(i) + const area = cv.contourArea(c) + if (area >= minArea && area > bestArea) { + if (bestContour) bestContour.delete() // free the displaced candidate + bestContour = c; bestArea = area + } else { c.delete() } // Mats returned by MatVector.get() must be freed + } + + if (bestContour) { + coverage = bestArea / (height * width) + + // Extract boundary points + for (let j = 0; j < bestContour.rows; j++) { + boundary.push([bestContour.data32S[j * 2], bestContour.data32S[j * 2 + 1]]) + } + + // Refine mask to just the largest contour + mask.setTo(new cv.Scalar(0)) + const contourVec = new cv.MatVector() + contourVec.push_back(bestContour) + try { + cv.drawContours(mask, contourVec, -1, new cv.Scalar(255), cv.FILLED) + } finally { + contourVec.delete(); bestContour.delete() + } + } + + // Convert mask to RGBA ImageData + const maskRgba = new cv.Mat() + try { + cv.cvtColor(mask, maskRgba, cv.COLOR_GRAY2RGBA) + const maskImageData = new ImageData( + new Uint8ClampedArray(maskRgba.data), + width, + height, + ) + + return { boundary, coverage, maskImageData } + } finally { + maskRgba.delete() + } + } finally { + contours.delete() + hierarchy.delete() + } + } finally { + src.delete() + rgb.delete() + hsv.delete() + mask.delete() + } +} + +/** + * Same as segmentField but includes a blended overlay for the editor. 
+ */ +export async function segmentFieldDebug( + imageData: ImageData, + params: Partial = {}, +): Promise { + const cv = await getCV() + if (!cv) throw new Error('OpenCV WASM not available') + const result = await segmentField(imageData, params) + const { width, height } = imageData + + // Build green overlay blended with original frame + const src = cv.matFromImageData(imageData) + const rgb = new cv.Mat() + const overlay = new cv.Mat(height, width, cv.CV_8UC3, new cv.Scalar(0, 0, 0)) + const maskGray = new cv.Mat() + + try { + cv.cvtColor(src, rgb, cv.COLOR_RGBA2RGB) + + // Recreate mask from result maskImageData + const maskSrc = cv.matFromImageData(result.maskImageData) + try { + cv.cvtColor(maskSrc, maskGray, cv.COLOR_RGBA2GRAY) + } finally { + maskSrc.delete() + } + + // Green overlay where mask > 0 + overlay.setTo(new cv.Scalar(0, 255, 0), maskGray) + + // Blend: 70% original + 30% overlay + const blended = new cv.Mat() + try { + cv.addWeighted(rgb, 0.7, overlay, 0.3, 0, blended) + + // Convert to RGBA ImageData + const blendedRgba = new cv.Mat() + try { + cv.cvtColor(blended, blendedRgba, cv.COLOR_RGB2RGBA) + const overlayImageData = new ImageData( + new Uint8ClampedArray(blendedRgba.data), + width, + height, + ) + + return { ...result, overlayImageData } + } finally { + blendedRgba.delete() + } + } finally { + blended.delete() + } + } finally { + src.delete() + rgb.delete() + overlay.delete() + maskGray.delete() + } +} diff --git a/ui/detection-app/src/cv/wasmBridge.ts b/ui/detection-app/src/cv/wasmBridge.ts new file mode 100644 index 0000000..5783ac9 --- /dev/null +++ b/ui/detection-app/src/cv/wasmBridge.ts @@ -0,0 +1,121 @@ +/** + * WASM Worker Bridge — runs OpenCV operations in a Web Worker. + * + * The worker loads opencv.js (~10MB) in its own thread. + * Main thread stays responsive during WASM compilation. + * + * Lazy-creates the worker on first call. Sends ImageData + + * params, gets back results via transferable buffers. 
 + */ + +import type { EdgeDetectionParams, EdgeDetectionResult, EdgeDetectionDebugResult } from './edges' +import type { SegmentationParams, SegmentationDebugResult } from './segmentation' + +let worker: Worker | null = null +let messageId = 0 +const pending = new Map void; reject: (e: Error) => void }>() +let initPromise: Promise | null = null +let ready = false +let failed = false + +function getWorker(): Worker { + if (!worker) { + worker = new Worker(new URL('./worker.ts', import.meta.url), { type: 'module' }) + worker.onmessage = (event) => { + const { id = pending.keys().next().value, type, error: errMsg, ...data } = event.data // worker replies carry no id — it answers in FIFO order, so match the oldest pending request + const handler = pending.get(id) + if (!handler) return + pending.delete(id) + + if (type === 'error') { + handler.reject(new Error(errMsg ?? data.message ?? 'Worker error')) + } else { + handler.resolve(data) + } + } + worker.onerror = (event) => { + // Reject all pending + for (const [, handler] of pending) { + handler.reject(new Error(event.message ?? 'Worker crashed')) + } + pending.clear() + } + } + return worker +} + +function postMessage(type: string, imageData: ImageData, params: Record): Promise { + return new Promise((resolve, reject) => { + const id = ++messageId + pending.set(id, { resolve, reject }) + getWorker().postMessage({ id, type, imageData, params }, [imageData.data.buffer]) + }) +} + +/** Initialize the worker + load WASM. Returns true if ready. 
*/ +export async function initWasm(): Promise { + if (ready) return true + if (failed) return false + if (initPromise) return initPromise + + initPromise = (async () => { + try { + // Probe with a tiny detect_edges — the worker has no 'ping' handler and would reject it as an unknown type + await postMessage('detect_edges', new ImageData(1, 1), {}) + ready = true + return true + } catch { + failed = true + return false + } + })() + + return initPromise +} + +export function isWasmReady(): boolean { + return ready +} + +export function isWasmFailed(): boolean { + return failed +} + +// --- Edge detection --- + +export async function detectEdgesWasm( + imageData: ImageData, + params: Partial, +): Promise { + const data = await postMessage('detect_edges', imageData, params as Record) + return { regions: data.regions } +} + +export async function detectEdgesWasmDebug( + imageData: ImageData, + params: Partial, +): Promise { + const data = await postMessage('detect_edges_debug', imageData, params as Record) + return { + regions: data.regions, + edgeImageData: data.edgeImageData, + linesImageData: data.linesImageData, + horizontalCount: data.horizontalCount, + pairCount: data.pairCount, + } +} + +// --- Segmentation --- + +export async function segmentFieldWasmDebug( + imageData: ImageData, + params: Partial, +): Promise { + const data = await postMessage('segment_field_debug', imageData, params as Record) + return { + boundary: data.boundary, + coverage: data.coverage, + maskImageData: data.maskImageData, + overlayImageData: data.overlayImageData, + } +} diff --git a/ui/detection-app/src/cv/worker.ts b/ui/detection-app/src/cv/worker.ts index be614bc..15e58a9 100644 --- a/ui/detection-app/src/cv/worker.ts +++ b/ui/detection-app/src/cv/worker.ts @@ -1,25 +1,27 @@ /** - * CV Web Worker — runs edge detection off the main thread. + * CV Web Worker — runs CV operations off the main thread. 
* * Message protocol: - * Main → Worker: { type: 'detect_edges', imageData: ImageData, params: {...} } - * Main → Worker: { type: 'detect_edges_debug', imageData: ImageData, params: {...} } - * Worker → Main: { type: 'result', regions: [...] } - * Worker → Main: { type: 'debug_result', regions: [...], edgeImageData, linesImageData, horizontalCount, pairCount } - * Worker → Main: { type: 'error', message: string } + * Main → Worker: { type: 'detect_edges', imageData, params } + * Main → Worker: { type: 'detect_edges_debug', imageData, params } + * Main → Worker: { type: 'segment_field', imageData, params } + * Main → Worker: { type: 'segment_field_debug', imageData, params } + * Worker → Main: { type: 'result', ... } + * Worker → Main: { type: 'error', message } */ -import { detectEdges, detectEdgesDebug, type EdgeDetectionParams } from './edges' +import { detectEdges, detectEdgesDebug } from './edges' +import { segmentField, segmentFieldDebug } from './segmentation' -self.onmessage = (event: MessageEvent) => { +self.onmessage = async (event: MessageEvent) => { const { type, imageData, params } = event.data try { if (type === 'detect_edges') { - const result = detectEdges(imageData, params) + const result = await detectEdges(imageData, params) self.postMessage({ type: 'result', regions: result.regions }) } else if (type === 'detect_edges_debug') { - const result = detectEdgesDebug(imageData, params) + const result = await detectEdgesDebug(imageData, params) self.postMessage({ type: 'debug_result', regions: result.regions, @@ -28,10 +30,31 @@ self.onmessage = (event: MessageEvent) => { horizontalCount: result.horizontalCount, pairCount: result.pairCount, }, [ - // Transfer ownership of the backing buffers for zero-copy result.edgeImageData.data.buffer, result.linesImageData.data.buffer, ]) + } else if (type === 'segment_field') { + const result = await segmentField(imageData, params) + self.postMessage({ + type: 'result', + boundary: result.boundary, + coverage: 
result.coverage, + maskImageData: result.maskImageData, + }, [ + result.maskImageData.data.buffer, + ]) + } else if (type === 'segment_field_debug') { + const result = await segmentFieldDebug(imageData, params) + self.postMessage({ + type: 'debug_result', + boundary: result.boundary, + coverage: result.coverage, + maskImageData: result.maskImageData, + overlayImageData: result.overlayImageData, + }, [ + result.maskImageData.data.buffer, + result.overlayImageData.data.buffer, + ]) } else { self.postMessage({ type: 'error', message: `Unknown message type: ${type}` }) } diff --git a/ui/detection-app/src/types/sse-contract.ts b/ui/detection-app/src/types/sse-contract.ts index cf55d7c..cfed5ed 100644 --- a/ui/detection-app/src/types/sse-contract.ts +++ b/ui/detection-app/src/types/sse-contract.ts @@ -139,3 +139,19 @@ export interface RetryResponse { task_id: string; job_id: string; } + +export interface RunRequest { + video_path: string; + profile_name: string; + source_asset_id: string; + checkpoint: boolean; + skip_vlm: boolean; + skip_cloud: boolean; + log_level: string; +} + +export interface RunResponse { + status: string; + job_id: string; + video_path: string; +} diff --git a/ui/detection-app/src/vite-env.d.ts b/ui/detection-app/src/vite-env.d.ts new file mode 100644 index 0000000..11f02fe --- /dev/null +++ b/ui/detection-app/src/vite-env.d.ts @@ -0,0 +1 @@ +/// diff --git a/ui/framework/src/tokens.css b/ui/framework/src/tokens.css index d719c20..4858ec3 100644 --- a/ui/framework/src/tokens.css +++ b/ui/framework/src/tokens.css @@ -43,3 +43,17 @@ --panel-border: 1px solid var(--border); --panel-header-height: 36px; } + +/* Animated gradient outline for buttons in a waiting state. + Usage: add class="waiting" to any button/element. 
*/ +@keyframes waiting-glow { + 0% { box-shadow: 0 0 3px 1px var(--status-processing); } + 33% { box-shadow: 0 0 3px 1px var(--status-live); } + 66% { box-shadow: 0 0 3px 1px var(--status-escalating); } + 100% { box-shadow: 0 0 3px 1px var(--status-processing); } +} + +.waiting { + animation: waiting-glow 2s linear infinite; + outline: 1px solid transparent; +}