Compare commits

..

56 Commits

Author SHA1 Message Date
ddb4f17faa modelgen: recurse through nested generics in list resolvers
_get_list_inner (Pydantic) and _resolve_ts_list (TypeScript) collapsed
nested generics to "str" / "string[]" — so List[List[int]] became
List[str] and List[Dict[str,Any]] became List[str]. Fix recurses on
list-of-list and falls back to Dict[str,Any] / Record<string,unknown>
for list-of-dict.

Regenerated outputs:
- core/gpu/models/models.py: SegmentFieldResponse.boundary now
  List[List[int]], unblocking /segment_field which was 500ing on every
  request with pydantic ValidationError
- core/db/models.py + ui/common/types/generated.ts: Brand.airings now
  matches its source schema (List[Dict[str,Any]] / Record<string,unknown>[])
2026-04-29 08:48:21 -03:00
f66d3a273f migrate to uv + pyproject.toml
- root pyproject.toml replaces requirements.txt and requirements-worker.txt
  (worker = root + ffmpeg-python which root already had); test deps moved
  to [dependency-groups] dev
- core/gpu/pyproject.toml replaces core/gpu/requirements.txt; uses
  [tool.uv.sources] to pin torch/torchvision and paddlepaddle-gpu to their
  CUDA index URLs, replacing the manual reinstall dance from old comments
- Dockerfiles use uv sync --frozen against uv.lock for reproducible builds;
  PATH includes /app/.venv/bin so k8s manifests' bare uvicorn/celery
  commands resolve without wrapping in uv run
- core/gpu/run.sh local mode now does uv sync + uv run python server.py;
  errors out cleanly if uv is missing
2026-04-29 07:32:56 -03:00
d9e0794b83 fix vite resolution for soleprint-ui deps; align kind template
- vite: alias @vue-flow/core and uplot to detection-app's node_modules
  so they resolve when imported through the symlinked soleprint-ui
  framework (Vite follows the symlink and would otherwise look up
  from ui/framework/, where node_modules doesn't exist)
- kind-config.yaml.tpl: match the new port layout (gateway on 30080
  via 127.0.0.1, redis on 6382), avoiding port-80 conflict with Caddy;
  keeps the extraMount for /mnt/media so MinIO can seed from host
2026-04-29 06:11:40 -03:00
ed1f6b761b update detection-app lockfile for @vue-flow/core and uplot 2026-04-29 05:51:48 -03:00
020f3540d3 phase 5: edge transforms, soleprint-ui rename, infra fixes
- pipeline edge transforms: stages can declare accepted_transforms,
  edges carry a transform dict, runner injects per-stage and nodes
  apply (e.g. invert_mask before edge detection); editable from UI
  via PUT /config/edge-transform
- rename mpr-ui-framework -> soleprint-ui (now an external package
  synced via .spr from /home/mariano/wdir/spr); add @vue-flow/core
  and uplot to detection-app so linked package resolves them
- Tiltfile guards kubectl context, k8s commands pin --context kind-mpr
- kind-config: gateway on hostPort 30080 (Caddy fronts mpr.local.ar)
- modelgen: pyproject.toml, .spr marker, dict default_factory support
2026-04-29 05:31:08 -03:00
55e83e4203 compare view 2026-03-30 13:05:28 -03:00
aac27b8504 a 2026-03-30 09:53:10 -03:00
4220b0418e phase 4 2026-03-30 07:22:14 -03:00
d0707333fd phase 3 2026-03-28 10:05:59 -03:00
e46bbc419c phase 2 2026-03-28 09:40:07 -03:00
0bd3888155 phase 1 2026-03-28 08:51:25 -03:00
acc99e691d phase 9 2026-03-28 01:11:31 -03:00
8a90436f33 phase 8 2026-03-28 00:59:59 -03:00
f6ef95ebea phase 7 2026-03-28 00:24:18 -03:00
49da927da0 phase 5 2026-03-28 00:09:47 -03:00
886720c3ce phase 6 2026-03-27 23:27:03 -03:00
1c6af767eb phase 4 2026-03-27 22:57:45 -03:00
94c7b21ae5 phase 3 2026-03-27 22:54:58 -03:00
3d8e7291f3 phase 3 2026-03-27 22:48:31 -03:00
bf30acd4df phase 1 2026-03-27 22:15:50 -03:00
a3b51c458d timeline 2026-03-27 07:33:49 -03:00
51ce14a812 major refactor 2026-03-27 06:14:02 -03:00
bcf6f3dc71 use sqlalchemy pattern 2026-03-27 05:19:45 -03:00
291ac8dd40 refactor stage 1 2026-03-27 04:23:21 -03:00
df6bcb01e8 phase 1 2026-03-27 00:41:23 -03:00
65814b5b9e phase cv 0 2026-03-26 22:22:35 -03:00
beb0416280 add heavy loggin 2026-03-26 11:45:31 -03:00
a85722f96a move to postgresql 2026-03-26 10:59:33 -03:00
c9ba9e4f5f refactor storage minio for k8s 2026-03-26 09:20:23 -03:00
e27cb5bcc3 phase 12 2026-03-26 07:40:14 -03:00
731964ca10 phase 11 2026-03-26 05:23:37 -03:00
d58a90157a schema clean up and refactor 2026-03-26 05:14:33 -03:00
08c58a6a9d phase 10 2026-03-26 04:40:00 -03:00
08b67f2bb7 phase 9 2026-03-26 02:54:56 -03:00
dfa3c12514 phase 8 2026-03-26 01:30:26 -03:00
95246c5452 phase 7 2026-03-26 00:56:35 -03:00
3df9ed5ada phase 6 2026-03-23 16:55:13 -03:00
4fdbdfc6d3 phase 5 2026-03-23 15:52:03 -03:00
b57da622cb phase 4 2026-03-23 15:18:23 -03:00
5ed876d694 phase 3 2026-03-23 14:42:36 -03:00
71fd0510de phase 2 2026-03-23 11:13:30 -03:00
8186bb5fe6 phase 1 2026-03-23 09:58:40 -03:00
9c9c7dff09 merge chunker 2026-03-23 02:56:13 -03:00
b40bd68411 chunker ui redo 2026-03-15 16:03:53 -03:00
d5a3372d6b ui selector 2026-03-13 14:59:32 -03:00
5ceb8172ea docker build fix 2026-03-13 14:31:26 -03:00
ccc478fbaa chunker and ui 2026-03-13 14:29:38 -03:00
3eeedebb15 major refactor 2026-03-13 01:07:02 -03:00
eaaf2ad60c executor abstraction, graphene to strawberry 2026-03-12 23:27:34 -03:00
4e9d731cff Remove REST API, keep GraphQL as sole API
- Add missing GraphQL mutations: retryJob, updateAsset, deleteAsset
- Add UpdateAssetRequest and DeleteResult to schema source of truth
- Move Lambda callback endpoint to main.py (only REST endpoint)
- Remove REST routes, pydantic schemas, and deps
- Remove pydantic target from modelgen.json
- Update architecture diagrams and documentation
2026-02-12 20:07:51 -03:00
dbbaad5b94 Display architecture diagrams side-by-side for easier comparison 2026-02-12 19:56:21 -03:00
2ac31083e5 Update root docs index.html to reference new separate architecture diagrams 2026-02-12 19:55:00 -03:00
f481fa6cbe Remove old combined architecture diagram 2026-02-12 19:53:23 -03:00
cc1a1b9953 Split architecture diagram into separate local and AWS diagrams 2026-02-12 19:49:47 -03:00
da1ff62877 Merge aws-int: Add AWS integration with GraphQL, Step Functions, and Lambda
# Conflicts:
#	docs/architecture/index.html
2026-02-12 19:47:15 -03:00
9cead74fb3 updated docs 2026-02-12 19:46:12 -03:00
381 changed files with 39753 additions and 3401 deletions

View File

@@ -0,0 +1,11 @@
---
name: agent_sdk_future
description: Claude Agent SDK for general mpr tasks (not vision provider), uses OAuth not API keys
type: project
---
Claude Agent SDK (`claude-agent-sdk`) is for future general-purpose tasks in mpr, NOT for the cloud vision provider.
**Why:** The Agent SDK uses Claude Code CLI's OAuth (browser login, no API keys) and is designed for agentic tasks (file read/edit, bash, web search). The vision provider needs raw API calls with image payloads — use the `anthropic` SDK with `ANTHROPIC_API_KEY` for that.
**How to apply:** When adding Claude-powered automation to mpr (e.g., log analysis, config suggestions, code review on pipeline changes), use the Agent SDK. For the cloud LLM escalation stage (image crops → brand ID), keep using the `anthropic` SDK with API key auth.

30
.dockerignore Normal file
View File

@@ -0,0 +1,30 @@
# Python
.venv/
__pycache__/
*.pyc
*.egg-info/
.pytest_cache/
# Node
node_modules/
ui/*/node_modules/
ui/*/dist/
# Media (9.8GB — mounted via volume, never needed in image)
media/
# Git
.git/
# IDE / OS
.idea/
.vscode/
*.swp
.DS_Store
# Docker
ctrl/docker-compose.yml
# Docs
docs/
*.md

7
.gitignore vendored
View File

@@ -17,10 +17,8 @@ env/
*.pot *.pot
*.pyc *.pyc
db.sqlite3 db.sqlite3
media/in/* media/*
!media/in/.gitkeep !media/.gitkeep
media/out/*
!media/out/.gitkeep
# Node # Node
node_modules/ node_modules/
@@ -39,3 +37,4 @@ Thumbs.db
# Project specific # Project specific
def/ def/
ctrl/k8s/overlays/dev/local-config.yaml

View File

@@ -71,12 +71,12 @@ docker compose logs -f
docker compose logs -f celery docker compose logs -f celery
# Create admin user # Create admin user
docker compose exec django python manage.py createsuperuser docker compose exec django python admin/manage.py createsuperuser
``` ```
## Code Generation ## Code Generation
Models are defined as dataclasses in `schema/models/` and generated via `modelgen`: Models are defined as dataclasses in `core/schema/models/` and generated via `modelgen`:
- **Django ORM** models (`--include dataclasses,enums`) - **Django ORM** models (`--include dataclasses,enums`)
- **Pydantic** schemas (`--include dataclasses,enums`) - **Pydantic** schemas (`--include dataclasses,enums`)
- **TypeScript** types (`--include dataclasses,enums,api`) - **TypeScript** types (`--include dataclasses,enums,api`)
@@ -113,26 +113,29 @@ See [docs/media-storage.md](docs/media-storage.md) for full details.
``` ```
mpr/ mpr/
├── api/ # FastAPI application ├── admin/ # Django project
│ ├── routes/ # API endpoints │ ├── manage.py # Django management script
│ └── schemas/ # Pydantic models (generated) │ └── mpr/ # Django settings & app
├── core/ # Core utilities │ └── media_assets/# Django app
│ └── ffmpeg/ # FFmpeg wrappers ├── core/ # Core application logic
│ ├── api/ # FastAPI + GraphQL API
│ │ └── schema/ # GraphQL types (generated)
│ ├── ffmpeg/ # FFmpeg wrappers
│ ├── rpc/ # gRPC server & client
│ │ └── protos/ # Protobuf definitions (generated)
│ ├── schema/ # Source of truth
│ │ └── models/ # Dataclass definitions
│ ├── storage/ # S3/GCP/local storage backends
│ └── task/ # Celery job execution
│ ├── executor.py # Executor abstraction
│ └── tasks.py # Celery tasks
├── ctrl/ # Docker & deployment ├── ctrl/ # Docker & deployment
│ ├── docker-compose.yml │ ├── docker-compose.yml
│ └── nginx.conf │ └── nginx.conf
├── media/ ├── media/
│ ├── in/ # Source media files │ ├── in/ # Source media files
│ └── out/ # Transcoded output │ └── out/ # Transcoded output
├── rpc/ # gRPC server & client ├── modelgen/ # Code generation tool
│ └── protos/ # Protobuf definitions (generated)
├── mpr/ # Django project
│ └── media_assets/ # Django app
├── schema/ # Source of truth
│ └── models/ # Dataclass definitions
├── task/ # Celery job execution
│ ├── executor.py # Executor abstraction
│ └── tasks.py # Celery tasks
└── ui/ # Frontend └── ui/ # Frontend
└── timeline/ # React app └── timeline/ # React app
``` ```

View File

@@ -6,7 +6,9 @@ import sys
def main(): def main():
"""Run administrative tasks.""" """Run administrative tasks."""
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'mpr.settings') # Ensure project root is on sys.path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'admin.mpr.settings')
try: try:
from django.core.management import execute_from_command_line from django.core.management import execute_from_command_line
except ImportError as exc: except ImportError as exc:

View File

@@ -11,6 +11,6 @@ import os
from django.core.asgi import get_asgi_application from django.core.asgi import get_asgi_application
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'mpr.settings') os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'admin.mpr.settings')
application = get_asgi_application() application = get_asgi_application()

View File

@@ -2,9 +2,9 @@ import os
from celery import Celery from celery import Celery
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "mpr.settings") os.environ.setdefault("DJANGO_SETTINGS_MODULE", "admin.mpr.settings")
app = Celery("mpr") app = Celery("mpr")
app.config_from_object("django.conf:settings", namespace="CELERY") app.config_from_object("django.conf:settings", namespace="CELERY")
app.autodiscover_tasks() app.autodiscover_tasks()
app.autodiscover_tasks(["task"]) app.autodiscover_tasks(["core.jobs"])

View File

@@ -3,5 +3,6 @@ from django.apps import AppConfig
class MediaAssetsConfig(AppConfig): class MediaAssetsConfig(AppConfig):
default_auto_field = "django.db.models.BigAutoField" default_auto_field = "django.db.models.BigAutoField"
name = "mpr.media_assets" name = "admin.mpr.media_assets"
label = "media_assets"
verbose_name = "Media Assets" verbose_name = "Media Assets"

View File

@@ -4,10 +4,10 @@ from pathlib import Path
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from mpr.media_assets.models import TranscodePreset from admin.mpr.media_assets.models import TranscodePreset
sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent.parent.parent)) sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent.parent.parent.parent))
from schema.models import BUILTIN_PRESETS from core.schema.models import BUILTIN_PRESETS
class Command(BaseCommand): class Command(BaseCommand):

View File

@@ -1,8 +1,7 @@
# Generated by Django 6.0.1 on 2026-02-01 15:13 # Generated by Django 4.2.29 on 2026-03-13 04:04
import django.db.models.deletion
import uuid
from django.db import migrations, models from django.db import migrations, models
import uuid
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -13,47 +12,21 @@ class Migration(migrations.Migration):
] ]
operations = [ operations = [
migrations.CreateModel(
name='TranscodePreset',
fields=[
('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)),
('name', models.CharField(max_length=100, unique=True)),
('description', models.TextField(blank=True, default='')),
('is_builtin', models.BooleanField(default=False)),
('container', models.CharField(default='mp4', max_length=20)),
('video_codec', models.CharField(default='libx264', max_length=50)),
('video_bitrate', models.CharField(blank=True, max_length=20, null=True)),
('video_crf', models.IntegerField(blank=True, null=True)),
('video_preset', models.CharField(blank=True, max_length=20, null=True)),
('resolution', models.CharField(blank=True, max_length=20, null=True)),
('framerate', models.FloatField(blank=True, null=True)),
('audio_codec', models.CharField(default='aac', max_length=50)),
('audio_bitrate', models.CharField(blank=True, max_length=20, null=True)),
('audio_channels', models.IntegerField(blank=True, null=True)),
('audio_samplerate', models.IntegerField(blank=True, null=True)),
('extra_args', models.JSONField(blank=True, default=list)),
('created_at', models.DateTimeField(auto_now_add=True)),
('updated_at', models.DateTimeField(auto_now=True)),
],
options={
'ordering': ['name'],
},
),
migrations.CreateModel( migrations.CreateModel(
name='MediaAsset', name='MediaAsset',
fields=[ fields=[
('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)), ('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)),
('filename', models.CharField(max_length=500)), ('filename', models.CharField(max_length=500)),
('file_path', models.CharField(max_length=1000)), ('file_path', models.CharField(max_length=1000)),
('status', models.CharField(choices=[('pending', 'Pending Probe'), ('ready', 'Ready'), ('error', 'Error')], default='pending', max_length=20)), ('status', models.CharField(choices=[('pending', 'Pending'), ('ready', 'Ready'), ('error', 'Error')], default='pending', max_length=20)),
('error_message', models.TextField(blank=True, null=True)), ('error_message', models.TextField(blank=True, default='')),
('file_size', models.BigIntegerField(blank=True, null=True)), ('file_size', models.BigIntegerField(blank=True, null=True)),
('duration', models.FloatField(blank=True, null=True)), ('duration', models.FloatField(blank=True, default=None, null=True)),
('video_codec', models.CharField(blank=True, max_length=50, null=True)), ('video_codec', models.CharField(blank=True, max_length=255, null=True)),
('audio_codec', models.CharField(blank=True, max_length=50, null=True)), ('audio_codec', models.CharField(blank=True, max_length=255, null=True)),
('width', models.IntegerField(blank=True, null=True)), ('width', models.IntegerField(blank=True, default=None, null=True)),
('height', models.IntegerField(blank=True, null=True)), ('height', models.IntegerField(blank=True, default=None, null=True)),
('framerate', models.FloatField(blank=True, null=True)), ('framerate', models.FloatField(blank=True, default=None, null=True)),
('bitrate', models.BigIntegerField(blank=True, null=True)), ('bitrate', models.BigIntegerField(blank=True, null=True)),
('properties', models.JSONField(blank=True, default=dict)), ('properties', models.JSONField(blank=True, default=dict)),
('comments', models.TextField(blank=True, default='')), ('comments', models.TextField(blank=True, default='')),
@@ -63,36 +36,61 @@ class Migration(migrations.Migration):
], ],
options={ options={
'ordering': ['-created_at'], 'ordering': ['-created_at'],
'indexes': [models.Index(fields=['status'], name='media_asset_status_9ea2f2_idx'), models.Index(fields=['created_at'], name='media_asset_created_368039_idx')],
}, },
), ),
migrations.CreateModel( migrations.CreateModel(
name='TranscodeJob', name='TranscodeJob',
fields=[ fields=[
('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)), ('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)),
('source_asset_id', models.UUIDField()),
('preset_id', models.UUIDField(blank=True, null=True)),
('preset_snapshot', models.JSONField(blank=True, default=dict)), ('preset_snapshot', models.JSONField(blank=True, default=dict)),
('trim_start', models.FloatField(blank=True, null=True)), ('trim_start', models.FloatField(blank=True, default=None, null=True)),
('trim_end', models.FloatField(blank=True, null=True)), ('trim_end', models.FloatField(blank=True, default=None, null=True)),
('output_filename', models.CharField(max_length=500)), ('output_filename', models.CharField(max_length=500)),
('output_path', models.CharField(blank=True, max_length=1000, null=True)), ('output_path', models.CharField(blank=True, max_length=1000, null=True)),
('output_asset_id', models.UUIDField(blank=True, null=True)),
('status', models.CharField(choices=[('pending', 'Pending'), ('processing', 'Processing'), ('completed', 'Completed'), ('failed', 'Failed'), ('cancelled', 'Cancelled')], default='pending', max_length=20)), ('status', models.CharField(choices=[('pending', 'Pending'), ('processing', 'Processing'), ('completed', 'Completed'), ('failed', 'Failed'), ('cancelled', 'Cancelled')], default='pending', max_length=20)),
('progress', models.FloatField(default=0.0)), ('progress', models.FloatField(default=0.0)),
('current_frame', models.IntegerField(blank=True, null=True)), ('current_frame', models.IntegerField(blank=True, default=None, null=True)),
('current_time', models.FloatField(blank=True, null=True)), ('current_time', models.FloatField(blank=True, default=None, null=True)),
('speed', models.CharField(blank=True, max_length=20, null=True)), ('speed', models.CharField(blank=True, max_length=255, null=True)),
('error_message', models.TextField(blank=True, null=True)), ('error_message', models.TextField(blank=True, default='')),
('celery_task_id', models.CharField(blank=True, max_length=100, null=True)), ('celery_task_id', models.CharField(blank=True, max_length=255, null=True)),
('execution_arn', models.CharField(blank=True, max_length=255, null=True)),
('priority', models.IntegerField(default=0)), ('priority', models.IntegerField(default=0)),
('created_at', models.DateTimeField(auto_now_add=True)), ('created_at', models.DateTimeField(auto_now_add=True)),
('started_at', models.DateTimeField(blank=True, null=True)), ('started_at', models.DateTimeField(blank=True, null=True)),
('completed_at', models.DateTimeField(blank=True, null=True)), ('completed_at', models.DateTimeField(blank=True, null=True)),
('output_asset', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='source_jobs', to='media_assets.mediaasset')),
('source_asset', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='transcode_jobs', to='media_assets.mediaasset')),
('preset', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='jobs', to='media_assets.transcodepreset')),
], ],
options={ options={
'ordering': ['priority', 'created_at'], 'ordering': ['-created_at'],
'indexes': [models.Index(fields=['status', 'priority'], name='media_asset_status_e6ac18_idx'), models.Index(fields=['created_at'], name='media_asset_created_ba3a46_idx'), models.Index(fields=['celery_task_id'], name='media_asset_celery__81a88e_idx')], },
),
migrations.CreateModel(
name='TranscodePreset',
fields=[
('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)),
('name', models.CharField(max_length=255)),
('description', models.TextField(blank=True, default='')),
('is_builtin', models.BooleanField(default=False)),
('container', models.CharField(max_length=255)),
('video_codec', models.CharField(max_length=255)),
('video_bitrate', models.CharField(blank=True, max_length=255, null=True)),
('video_crf', models.IntegerField(blank=True, default=None, null=True)),
('video_preset', models.CharField(blank=True, max_length=255, null=True)),
('resolution', models.CharField(blank=True, max_length=255, null=True)),
('framerate', models.FloatField(blank=True, default=None, null=True)),
('audio_codec', models.CharField(max_length=255)),
('audio_bitrate', models.CharField(blank=True, max_length=255, null=True)),
('audio_channels', models.IntegerField(blank=True, default=None, null=True)),
('audio_samplerate', models.IntegerField(blank=True, default=None, null=True)),
('extra_args', models.JSONField(blank=True, default=list)),
('created_at', models.DateTimeField(auto_now_add=True)),
('updated_at', models.DateTimeField(auto_now=True)),
],
options={
'ordering': ['-created_at'],
}, },
), ),
] ]

View File

@@ -0,0 +1,222 @@
"""
Django ORM Models - GENERATED FILE
Do not edit directly. Regenerate using modelgen.
"""
import uuid
from django.db import models
class AssetStatus(models.TextChoices):
PENDING = "pending", "Pending"
READY = "ready", "Ready"
ERROR = "error", "Error"
class JobStatus(models.TextChoices):
PENDING = "pending", "Pending"
RUNNING = "running", "Running"
PAUSED = "paused", "Paused"
COMPLETED = "completed", "Completed"
FAILED = "failed", "Failed"
CANCELLED = "cancelled", "Cancelled"
class RunType(models.TextChoices):
INITIAL = "initial", "Initial"
REPLAY = "replay", "Replay"
RETRY = "retry", "Retry"
class BrandSource(models.TextChoices):
OCR = "ocr", "Ocr"
VLM = "local_vlm", "Vlm"
CLOUD = "cloud_llm", "Cloud"
MANUAL = "manual", "Manual"
class SourceType(models.TextChoices):
CHUNK_JOB = "chunk_job", "Chunk Job"
UPLOAD = "upload", "Upload"
DEVICE = "device", "Device"
STREAM = "stream", "Stream"
class MediaAsset(models.Model):
"""A video/audio file registered in the system."""
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
filename = models.CharField(max_length=500)
file_path = models.CharField(max_length=1000)
status = models.CharField(max_length=20, choices=AssetStatus.choices, default=AssetStatus.PENDING)
error_message = models.TextField(blank=True, default='')
file_size = models.BigIntegerField(null=True, blank=True)
duration = models.FloatField(null=True, blank=True, default=None)
video_codec = models.CharField(max_length=255, null=True, blank=True)
audio_codec = models.CharField(max_length=255, null=True, blank=True)
width = models.IntegerField(null=True, blank=True, default=None)
height = models.IntegerField(null=True, blank=True, default=None)
framerate = models.FloatField(null=True, blank=True, default=None)
bitrate = models.BigIntegerField(null=True, blank=True)
properties = models.JSONField(default=dict, blank=True)
comments = models.TextField(blank=True, default='')
tags = models.JSONField(default=list, blank=True)
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
class Meta:
ordering = ["-created_at"]
def __str__(self):
return self.filename
class TranscodePreset(models.Model):
"""A reusable transcoding configuration (like Handbrake presets)."""
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
name = models.CharField(max_length=255)
description = models.TextField(blank=True, default='')
is_builtin = models.BooleanField(default=False)
container = models.CharField(max_length=255)
video_codec = models.CharField(max_length=255)
video_bitrate = models.CharField(max_length=255, null=True, blank=True)
video_crf = models.IntegerField(null=True, blank=True, default=None)
video_preset = models.CharField(max_length=255, null=True, blank=True)
resolution = models.CharField(max_length=255, null=True, blank=True)
framerate = models.FloatField(null=True, blank=True, default=None)
audio_codec = models.CharField(max_length=255)
audio_bitrate = models.CharField(max_length=255, null=True, blank=True)
audio_channels = models.IntegerField(null=True, blank=True, default=None)
audio_samplerate = models.IntegerField(null=True, blank=True, default=None)
extra_args = models.JSONField(default=list, blank=True)
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
class Meta:
ordering = ["-created_at"]
def __str__(self):
return self.name
class Job(models.Model):
"""A pipeline job."""
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
source_asset_id = models.UUIDField()
video_path = models.CharField(max_length=1000)
profile_name = models.CharField(max_length=255)
timeline_id = models.UUIDField(null=True, blank=True)
parent_id = models.UUIDField(null=True, blank=True)
run_type = models.CharField(max_length=20, choices=RunType.choices, default=RunType.INITIAL)
config_overrides = models.JSONField(default=dict, blank=True)
status = models.CharField(max_length=20, choices=JobStatus.choices, default=JobStatus.PENDING)
current_stage = models.CharField(max_length=255, null=True, blank=True)
progress = models.FloatField(default=0.0)
error_message = models.TextField(blank=True, default='')
total_detections = models.IntegerField(default=0)
brands_found = models.IntegerField(default=0)
cloud_llm_calls = models.IntegerField(default=0)
estimated_cost_usd = models.FloatField(default=0.0)
priority = models.IntegerField(default=0)
created_at = models.DateTimeField(auto_now_add=True)
started_at = models.DateTimeField(null=True, blank=True)
completed_at = models.DateTimeField(null=True, blank=True)
class Meta:
ordering = ["-created_at"]
def __str__(self):
return str(self.id)
class Timeline(models.Model):
"""A user-created selection of source material."""
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
name = models.CharField(max_length=255)
source_asset_id = models.UUIDField(null=True, blank=True)
chunk_paths = models.JSONField(default=list, blank=True)
profile_name = models.CharField(max_length=255)
status = models.CharField(max_length=255)
fps = models.FloatField(default=2.0)
frame_count = models.IntegerField(default=0)
source_ephemeral = models.BooleanField(default=False)
created_at = models.DateTimeField(auto_now_add=True)
class Meta:
ordering = ["-created_at"]
def __str__(self):
return self.name
class Checkpoint(models.Model):
"""A snapshot of pipeline state on a timeline."""
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
timeline_id = models.UUIDField()
job_id = models.UUIDField(null=True, blank=True)
parent_id = models.UUIDField(null=True, blank=True)
stage_name = models.CharField(max_length=255)
config_overrides = models.JSONField(default=dict, blank=True)
stats = models.JSONField(default=dict, blank=True)
is_scenario = models.BooleanField(default=False)
scenario_label = models.CharField(max_length=255)
created_at = models.DateTimeField(auto_now_add=True)
class Meta:
ordering = ["-created_at"]
def __str__(self):
return str(self.id)
class StageOutput(models.Model):
"""Output of a single stage within a job."""
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
job_id = models.UUIDField()
timeline_id = models.UUIDField()
stage_name = models.CharField(max_length=255)
checkpoint_id = models.UUIDField(null=True, blank=True)
output = models.JSONField(default=dict, blank=True)
created_at = models.DateTimeField(auto_now_add=True)
class Meta:
ordering = ["-created_at"]
def __str__(self):
return str(self.id)
class Brand(models.Model):
"""A brand discovered or registered in the system."""
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
canonical_name = models.CharField(max_length=255)
aliases = models.JSONField(default=list, blank=True)
source = models.CharField(max_length=20, choices=BrandSource.choices, default=BrandSource.OCR)
confirmed = models.BooleanField(default=False)
airings = models.JSONField(default=list, blank=True)
total_airings = models.IntegerField(default=0)
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
class Meta:
ordering = ["-created_at"]
def __str__(self):
return str(self.id)
class Profile(models.Model):
"""A content type profile."""
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
name = models.CharField(max_length=255)
pipeline = models.JSONField(default=dict, blank=True)
configs = models.JSONField(default=dict, blank=True)
class Meta:
pass
def __str__(self):
return self.name

View File

@@ -7,7 +7,7 @@ from pathlib import Path
import environ import environ
BASE_DIR = Path(__file__).resolve().parent.parent BASE_DIR = Path(__file__).resolve().parent.parent.parent
env = environ.Env( env = environ.Env(
DEBUG=(bool, False), DEBUG=(bool, False),
@@ -27,7 +27,7 @@ INSTALLED_APPS = [
"django.contrib.sessions", "django.contrib.sessions",
"django.contrib.messages", "django.contrib.messages",
"django.contrib.staticfiles", "django.contrib.staticfiles",
"mpr.media_assets", "admin.mpr.media_assets",
] ]
MIDDLEWARE = [ MIDDLEWARE = [
@@ -40,7 +40,7 @@ MIDDLEWARE = [
"django.middleware.clickjacking.XFrameOptionsMiddleware", "django.middleware.clickjacking.XFrameOptionsMiddleware",
] ]
ROOT_URLCONF = "mpr.urls" ROOT_URLCONF = "admin.mpr.urls"
TEMPLATES = [ TEMPLATES = [
{ {
@@ -57,7 +57,7 @@ TEMPLATES = [
}, },
] ]
WSGI_APPLICATION = "mpr.wsgi.application" WSGI_APPLICATION = "admin.mpr.wsgi.application"
# Database # Database
DATABASE_URL = env("DATABASE_URL", default="sqlite:///db.sqlite3") DATABASE_URL = env("DATABASE_URL", default="sqlite:///db.sqlite3")

View File

@@ -11,6 +11,6 @@ import os
from django.core.wsgi import get_wsgi_application from django.core.wsgi import get_wsgi_application
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'mpr.settings') os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'admin.mpr.settings')
application = get_wsgi_application() application = get_wsgi_application()

View File

@@ -1,54 +0,0 @@
"""
FastAPI dependencies.
Provides database sessions, settings, and common dependencies.
"""
import os
from functools import lru_cache
from typing import Generator
import django
from django.conf import settings as django_settings
# Initialize Django
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "mpr.settings")
django.setup()
from mpr.media_assets.models import MediaAsset, TranscodeJob, TranscodePreset
@lru_cache
def get_settings():
"""Get Django settings."""
return django_settings
def get_asset(asset_id: str) -> MediaAsset:
"""Get asset by ID or raise 404."""
from fastapi import HTTPException
try:
return MediaAsset.objects.get(id=asset_id)
except MediaAsset.DoesNotExist:
raise HTTPException(status_code=404, detail="Asset not found")
def get_preset(preset_id: str) -> TranscodePreset:
"""Get preset by ID or raise 404."""
from fastapi import HTTPException
try:
return TranscodePreset.objects.get(id=preset_id)
except TranscodePreset.DoesNotExist:
raise HTTPException(status_code=404, detail="Preset not found")
def get_job(job_id: str) -> TranscodeJob:
"""Get job by ID or raise 404."""
from fastapi import HTTPException
try:
return TranscodeJob.objects.get(id=job_id)
except TranscodeJob.DoesNotExist:
raise HTTPException(status_code=404, detail="Job not found")

View File

@@ -1,251 +0,0 @@
"""
GraphQL API using graphene, mounted on FastAPI/Starlette.
Provides the same data as the REST API but via GraphQL queries and mutations.
Uses Django ORM directly for data access.
Types are generated from schema/ via modelgen — see api/schema/graphql.py.
"""
import os
import graphene
from api.schema.graphql import (
CreateJobInput,
MediaAssetType,
ScanResultType,
SystemStatusType,
TranscodeJobType,
TranscodePresetType,
)
from core.storage import BUCKET_IN, list_objects
# Media extensions (same as assets route)
VIDEO_EXTS = {".mp4", ".mkv", ".avi", ".mov", ".webm", ".flv", ".wmv", ".m4v"}
AUDIO_EXTS = {".mp3", ".wav", ".flac", ".aac", ".ogg", ".m4a"}
MEDIA_EXTS = VIDEO_EXTS | AUDIO_EXTS
# ---------------------------------------------------------------------------
# Queries
# ---------------------------------------------------------------------------
class Query(graphene.ObjectType):
assets = graphene.List(
MediaAssetType,
status=graphene.String(),
search=graphene.String(),
)
asset = graphene.Field(MediaAssetType, id=graphene.UUID(required=True))
jobs = graphene.List(
TranscodeJobType,
status=graphene.String(),
source_asset_id=graphene.UUID(),
)
job = graphene.Field(TranscodeJobType, id=graphene.UUID(required=True))
presets = graphene.List(TranscodePresetType)
system_status = graphene.Field(SystemStatusType)
def resolve_assets(self, info, status=None, search=None):
from mpr.media_assets.models import MediaAsset
qs = MediaAsset.objects.all()
if status:
qs = qs.filter(status=status)
if search:
qs = qs.filter(filename__icontains=search)
return qs
def resolve_asset(self, info, id):
from mpr.media_assets.models import MediaAsset
try:
return MediaAsset.objects.get(id=id)
except MediaAsset.DoesNotExist:
return None
def resolve_jobs(self, info, status=None, source_asset_id=None):
from mpr.media_assets.models import TranscodeJob
qs = TranscodeJob.objects.all()
if status:
qs = qs.filter(status=status)
if source_asset_id:
qs = qs.filter(source_asset_id=source_asset_id)
return qs
def resolve_job(self, info, id):
from mpr.media_assets.models import TranscodeJob
try:
return TranscodeJob.objects.get(id=id)
except TranscodeJob.DoesNotExist:
return None
def resolve_presets(self, info):
from mpr.media_assets.models import TranscodePreset
return TranscodePreset.objects.all()
def resolve_system_status(self, info):
return {"status": "ok", "version": "0.1.0"}
# ---------------------------------------------------------------------------
# Mutations
# ---------------------------------------------------------------------------
class ScanMediaFolder(graphene.Mutation):
class Arguments:
pass
Output = ScanResultType
def mutate(self, info):
from mpr.media_assets.models import MediaAsset
objects = list_objects(BUCKET_IN, extensions=MEDIA_EXTS)
existing = set(MediaAsset.objects.values_list("filename", flat=True))
registered = []
skipped = []
for obj in objects:
if obj["filename"] in existing:
skipped.append(obj["filename"])
continue
try:
MediaAsset.objects.create(
filename=obj["filename"],
file_path=obj["key"],
file_size=obj["size"],
)
registered.append(obj["filename"])
except Exception:
pass
return ScanResultType(
found=len(objects),
registered=len(registered),
skipped=len(skipped),
files=registered,
)
class CreateJob(graphene.Mutation):
class Arguments:
input = CreateJobInput(required=True)
Output = TranscodeJobType
def mutate(self, info, input):
from pathlib import Path
from mpr.media_assets.models import MediaAsset, TranscodeJob, TranscodePreset
try:
source = MediaAsset.objects.get(id=input.source_asset_id)
except MediaAsset.DoesNotExist:
raise Exception("Source asset not found")
preset = None
preset_snapshot = {}
if input.preset_id:
try:
preset = TranscodePreset.objects.get(id=input.preset_id)
preset_snapshot = {
"name": preset.name,
"container": preset.container,
"video_codec": preset.video_codec,
"audio_codec": preset.audio_codec,
}
except TranscodePreset.DoesNotExist:
raise Exception("Preset not found")
if not preset and not input.trim_start and not input.trim_end:
raise Exception("Must specify preset_id or trim_start/trim_end")
output_filename = input.output_filename
if not output_filename:
stem = Path(source.filename).stem
ext = preset_snapshot.get("container", "mp4") if preset else "mp4"
output_filename = f"{stem}_output.{ext}"
job = TranscodeJob.objects.create(
source_asset_id=source.id,
preset_id=preset.id if preset else None,
preset_snapshot=preset_snapshot,
trim_start=input.trim_start,
trim_end=input.trim_end,
output_filename=output_filename,
output_path=output_filename,
priority=input.priority or 0,
)
# Dispatch
executor_mode = os.environ.get("MPR_EXECUTOR", "local")
if executor_mode == "lambda":
from task.executor import get_executor
get_executor().run(
job_id=str(job.id),
source_path=source.file_path,
output_path=output_filename,
preset=preset_snapshot or None,
trim_start=input.trim_start,
trim_end=input.trim_end,
duration=source.duration,
)
else:
from task.tasks import run_transcode_job
result = run_transcode_job.delay(
job_id=str(job.id),
source_key=source.file_path,
output_key=output_filename,
preset=preset_snapshot or None,
trim_start=input.trim_start,
trim_end=input.trim_end,
duration=source.duration,
)
job.celery_task_id = result.id
job.save(update_fields=["celery_task_id"])
return job
class CancelJob(graphene.Mutation):
class Arguments:
id = graphene.UUID(required=True)
Output = TranscodeJobType
def mutate(self, info, id):
from mpr.media_assets.models import TranscodeJob
try:
job = TranscodeJob.objects.get(id=id)
except TranscodeJob.DoesNotExist:
raise Exception("Job not found")
if job.status not in ("pending", "processing"):
raise Exception(f"Cannot cancel job with status: {job.status}")
job.status = "cancelled"
job.save(update_fields=["status"])
return job
class Mutation(graphene.ObjectType):
scan_media_folder = ScanMediaFolder.Field()
create_job = CreateJob.Field()
cancel_job = CancelJob.Field()
# ---------------------------------------------------------------------------
# Schema
# ---------------------------------------------------------------------------
schema = graphene.Schema(query=Query, mutation=Mutation)

View File

@@ -1,61 +0,0 @@
"""
MPR FastAPI Application
Main entry point for the REST API.
"""
import os
import sys
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Initialize Django before importing models
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "mpr.settings")
import django
django.setup()
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from api.graphql import schema as graphql_schema
from api.routes import assets_router, jobs_router, presets_router, system_router
from starlette_graphene3 import GraphQLApp, make_graphiql_handler
app = FastAPI(
title="MPR API",
description="Media Processor REST API",
version="0.1.0",
docs_url="/docs",
redoc_url="/redoc",
)
# CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["http://mpr.local.ar", "http://localhost:5173"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Routes - all under /api prefix
app.include_router(system_router, prefix="/api")
app.include_router(assets_router, prefix="/api")
app.include_router(presets_router, prefix="/api")
app.include_router(jobs_router, prefix="/api")
# GraphQL
app.mount("/graphql", GraphQLApp(schema=graphql_schema, on_get=make_graphiql_handler()))
@app.get("/")
def root():
"""API root."""
return {
"name": "MPR API",
"version": "0.1.0",
"docs": "/docs",
}

View File

@@ -1,8 +0,0 @@
"""API Routes."""
from .assets import router as assets_router
from .jobs import router as jobs_router
from .presets import router as presets_router
from .system import router as system_router
__all__ = ["assets_router", "jobs_router", "presets_router", "system_router"]

View File

@@ -1,117 +0,0 @@
"""
Asset endpoints - media file registration and metadata.
"""
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query
from api.deps import get_asset
from api.schema import AssetCreate, AssetResponse, AssetUpdate
from core.storage import BUCKET_IN, list_objects
router = APIRouter(prefix="/assets", tags=["assets"])
# Supported media extensions
VIDEO_EXTS = {".mp4", ".mkv", ".avi", ".mov", ".webm", ".flv", ".wmv", ".m4v"}
AUDIO_EXTS = {".mp3", ".wav", ".flac", ".aac", ".ogg", ".m4a"}
MEDIA_EXTS = VIDEO_EXTS | AUDIO_EXTS
@router.post("/", response_model=AssetResponse, status_code=201)
def create_asset(data: AssetCreate):
"""Register a media file as an asset."""
from mpr.media_assets.models import MediaAsset
asset = MediaAsset.objects.create(
filename=data.filename or data.file_path.split("/")[-1],
file_path=data.file_path,
file_size=data.file_size,
)
return asset
@router.get("/", response_model=list[AssetResponse])
def list_assets(
status: Optional[str] = Query(None, description="Filter by status"),
limit: int = Query(50, ge=1, le=100),
offset: int = Query(0, ge=0),
):
"""List assets with optional filtering."""
from mpr.media_assets.models import MediaAsset
qs = MediaAsset.objects.all()
if status:
qs = qs.filter(status=status)
return list(qs[offset : offset + limit])
@router.get("/{asset_id}", response_model=AssetResponse)
def get_asset_detail(asset_id: UUID, asset=Depends(get_asset)):
"""Get asset details."""
return asset
@router.patch("/{asset_id}", response_model=AssetResponse)
def update_asset(asset_id: UUID, data: AssetUpdate, asset=Depends(get_asset)):
"""Update asset metadata (comments, tags)."""
update_fields = []
if data.comments is not None:
asset.comments = data.comments
update_fields.append("comments")
if data.tags is not None:
asset.tags = data.tags
update_fields.append("tags")
if update_fields:
asset.save(update_fields=update_fields)
return asset
@router.delete("/{asset_id}", status_code=204)
def delete_asset(asset_id: UUID, asset=Depends(get_asset)):
"""Delete an asset."""
asset.delete()
@router.post("/scan", response_model=dict)
def scan_media_folder():
"""
Scan the S3 media-in bucket for new video/audio files and register them as assets.
"""
from mpr.media_assets.models import MediaAsset
# List objects from S3 bucket
objects = list_objects(BUCKET_IN, extensions=MEDIA_EXTS)
# Get existing filenames to avoid duplicates
existing_filenames = set(MediaAsset.objects.values_list("filename", flat=True))
registered_files = []
skipped_files = []
for obj in objects:
if obj["filename"] in existing_filenames:
skipped_files.append(obj["filename"])
continue
try:
MediaAsset.objects.create(
filename=obj["filename"],
file_path=obj["key"],
file_size=obj["size"],
)
registered_files.append(obj["filename"])
except Exception as e:
print(f"Error registering {obj['filename']}: {e}")
return {
"found": len(objects),
"registered": len(registered_files),
"skipped": len(skipped_files),
"files": registered_files,
}

View File

@@ -1,233 +0,0 @@
"""
Job endpoints - transcode/trim job management.
"""
import os
from pathlib import Path
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends, Header, HTTPException, Query
from api.deps import get_asset, get_job, get_preset
from api.schema import JobCreate, JobResponse
router = APIRouter(prefix="/jobs", tags=["jobs"])
CALLBACK_API_KEY = os.environ.get("CALLBACK_API_KEY", "")
@router.post("/", response_model=JobResponse, status_code=201)
def create_job(data: JobCreate):
"""
Create a transcode or trim job.
- With preset_id: Full transcode using preset settings
- Without preset_id but with trim_start/end: Trim only (stream copy)
"""
from mpr.media_assets.models import MediaAsset, TranscodeJob, TranscodePreset
# Get source asset
try:
source = MediaAsset.objects.get(id=data.source_asset_id)
except MediaAsset.DoesNotExist:
raise HTTPException(status_code=404, detail="Source asset not found")
# Get preset if specified
preset = None
preset_snapshot = {}
if data.preset_id:
try:
preset = TranscodePreset.objects.get(id=data.preset_id)
preset_snapshot = {
"name": preset.name,
"container": preset.container,
"video_codec": preset.video_codec,
"video_bitrate": preset.video_bitrate,
"video_crf": preset.video_crf,
"video_preset": preset.video_preset,
"resolution": preset.resolution,
"framerate": preset.framerate,
"audio_codec": preset.audio_codec,
"audio_bitrate": preset.audio_bitrate,
"audio_channels": preset.audio_channels,
"audio_samplerate": preset.audio_samplerate,
"extra_args": preset.extra_args,
}
except TranscodePreset.DoesNotExist:
raise HTTPException(status_code=404, detail="Preset not found")
# Validate trim-only job
if not preset and not data.trim_start and not data.trim_end:
raise HTTPException(
status_code=400, detail="Must specify preset_id or trim_start/trim_end"
)
# Generate output filename - stored as S3 key in output bucket
output_filename = data.output_filename
if not output_filename:
stem = Path(source.filename).stem
ext = preset_snapshot.get("container", "mp4") if preset else "mp4"
output_filename = f"{stem}_output.{ext}"
# Create job
job = TranscodeJob.objects.create(
source_asset_id=source.id,
preset_id=preset.id if preset else None,
preset_snapshot=preset_snapshot,
trim_start=data.trim_start,
trim_end=data.trim_end,
output_filename=output_filename,
output_path=output_filename, # S3 key in output bucket
priority=data.priority or 0,
)
# Dispatch based on executor mode
executor_mode = os.environ.get("MPR_EXECUTOR", "local")
if executor_mode == "lambda":
_dispatch_lambda(job, source, preset_snapshot)
else:
_dispatch_celery(job, source, preset_snapshot)
return job
def _dispatch_celery(job, source, preset_snapshot):
"""Dispatch job to Celery worker."""
from task.tasks import run_transcode_job
result = run_transcode_job.delay(
job_id=str(job.id),
source_key=source.file_path,
output_key=job.output_filename,
preset=preset_snapshot or None,
trim_start=job.trim_start,
trim_end=job.trim_end,
duration=source.duration,
)
job.celery_task_id = result.id
job.save(update_fields=["celery_task_id"])
def _dispatch_lambda(job, source, preset_snapshot):
"""Dispatch job to AWS Step Functions."""
from task.executor import get_executor
executor = get_executor()
executor.run(
job_id=str(job.id),
source_path=source.file_path,
output_path=job.output_filename,
preset=preset_snapshot or None,
trim_start=job.trim_start,
trim_end=job.trim_end,
duration=source.duration,
)
@router.post("/{job_id}/callback")
def job_callback(
job_id: UUID,
payload: dict,
x_api_key: Optional[str] = Header(None),
):
"""
Callback endpoint for Lambda to report job completion.
Protected by API key.
"""
if CALLBACK_API_KEY and x_api_key != CALLBACK_API_KEY:
raise HTTPException(status_code=403, detail="Invalid API key")
from django.utils import timezone
from mpr.media_assets.models import TranscodeJob
try:
job = TranscodeJob.objects.get(id=job_id)
except TranscodeJob.DoesNotExist:
raise HTTPException(status_code=404, detail="Job not found")
status = payload.get("status", "failed")
job.status = status
job.progress = 100.0 if status == "completed" else job.progress
update_fields = ["status", "progress"]
if payload.get("error"):
job.error_message = payload["error"]
update_fields.append("error_message")
if status == "completed":
job.completed_at = timezone.now()
update_fields.append("completed_at")
elif status == "failed":
job.completed_at = timezone.now()
update_fields.append("completed_at")
job.save(update_fields=update_fields)
return {"ok": True}
@router.get("/", response_model=list[JobResponse])
def list_jobs(
status: Optional[str] = Query(None, description="Filter by status"),
source_asset_id: Optional[UUID] = Query(None),
limit: int = Query(50, ge=1, le=100),
offset: int = Query(0, ge=0),
):
"""List jobs with optional filtering."""
from mpr.media_assets.models import TranscodeJob
qs = TranscodeJob.objects.all()
if status:
qs = qs.filter(status=status)
if source_asset_id:
qs = qs.filter(source_asset_id=source_asset_id)
return list(qs[offset : offset + limit])
@router.get("/{job_id}", response_model=JobResponse)
def get_job_detail(job_id: UUID, job=Depends(get_job)):
"""Get job details including progress."""
return job
@router.get("/{job_id}/progress")
def get_job_progress(job_id: UUID, job=Depends(get_job)):
"""Get real-time job progress."""
return {
"job_id": str(job.id),
"status": job.status,
"progress": job.progress,
"current_frame": job.current_frame,
"current_time": job.current_time,
"speed": job.speed,
}
@router.post("/{job_id}/cancel", response_model=JobResponse)
def cancel_job(job_id: UUID, job=Depends(get_job)):
"""Cancel a pending or processing job."""
if job.status not in ("pending", "processing"):
raise HTTPException(
status_code=400, detail=f"Cannot cancel job with status: {job.status}"
)
job.status = "cancelled"
job.save(update_fields=["status"])
return job
@router.post("/{job_id}/retry", response_model=JobResponse)
def retry_job(job_id: UUID, job=Depends(get_job)):
"""Retry a failed job."""
if job.status != "failed":
raise HTTPException(status_code=400, detail="Only failed jobs can be retried")
job.status = "pending"
job.progress = 0
job.error_message = None
job.save(update_fields=["status", "progress", "error_message"])
return job

View File

@@ -1,100 +0,0 @@
"""
Preset endpoints - transcode configuration templates.
"""
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from api.deps import get_preset
from api.schema import PresetCreate, PresetResponse, PresetUpdate
router = APIRouter(prefix="/presets", tags=["presets"])
@router.post("/", response_model=PresetResponse, status_code=201)
def create_preset(data: PresetCreate):
"""Create a custom preset."""
from mpr.media_assets.models import TranscodePreset
preset = TranscodePreset.objects.create(
name=data.name,
description=data.description or "",
container=data.container or "mp4",
video_codec=data.video_codec or "libx264",
video_bitrate=data.video_bitrate,
video_crf=data.video_crf,
video_preset=data.video_preset,
resolution=data.resolution,
framerate=data.framerate,
audio_codec=data.audio_codec or "aac",
audio_bitrate=data.audio_bitrate,
audio_channels=data.audio_channels,
audio_samplerate=data.audio_samplerate,
extra_args=data.extra_args or [],
is_builtin=False,
)
return preset
@router.get("/", response_model=list[PresetResponse])
def list_presets(include_builtin: bool = True):
"""List all presets."""
from mpr.media_assets.models import TranscodePreset
qs = TranscodePreset.objects.all()
if not include_builtin:
qs = qs.filter(is_builtin=False)
return list(qs)
@router.get("/{preset_id}", response_model=PresetResponse)
def get_preset_detail(preset_id: UUID, preset=Depends(get_preset)):
"""Get preset details."""
return preset
@router.patch("/{preset_id}", response_model=PresetResponse)
def update_preset(preset_id: UUID, data: PresetUpdate, preset=Depends(get_preset)):
"""Update a custom preset. Builtin presets cannot be modified."""
if preset.is_builtin:
raise HTTPException(status_code=403, detail="Cannot modify builtin preset")
update_fields = []
for field in [
"name",
"description",
"container",
"video_codec",
"video_bitrate",
"video_crf",
"video_preset",
"resolution",
"framerate",
"audio_codec",
"audio_bitrate",
"audio_channels",
"audio_samplerate",
"extra_args",
]:
value = getattr(data, field, None)
if value is not None:
setattr(preset, field, value)
update_fields.append(field)
if update_fields:
preset.save(update_fields=update_fields)
return preset
@router.delete("/{preset_id}", status_code=204)
def delete_preset(preset_id: UUID, preset=Depends(get_preset)):
"""Delete a custom preset. Builtin presets cannot be deleted."""
if preset.is_builtin:
raise HTTPException(status_code=403, detail="Cannot delete builtin preset")
preset.delete()

View File

@@ -1,51 +0,0 @@
"""
System endpoints - health checks and FFmpeg capabilities.
"""
from fastapi import APIRouter
from core.ffmpeg import get_decoders, get_encoders, get_formats
router = APIRouter(prefix="/system", tags=["system"])
@router.get("/health")
def health_check():
"""Health check endpoint."""
return {"status": "healthy"}
@router.get("/status")
def system_status():
"""System status for UI."""
return {"status": "ok", "version": "0.1.0"}
@router.get("/worker")
def worker_status():
"""Worker status from gRPC."""
try:
from rpc.client import get_client
client = get_client()
status = client.get_worker_status()
if status:
return status
return {"available": False, "error": "No response from worker"}
except Exception as e:
return {"available": False, "error": str(e)}
@router.get("/ffmpeg/codecs")
def ffmpeg_codecs():
"""Get available FFmpeg encoders and decoders."""
return {
"encoders": get_encoders(),
"decoders": get_decoders(),
}
@router.get("/ffmpeg/formats")
def ffmpeg_formats():
"""Get available FFmpeg muxers and demuxers."""
return get_formats()

View File

@@ -1,10 +0,0 @@
"""API Schemas - GENERATED FILE"""
from .base import BaseSchema
from .asset import AssetCreate, AssetUpdate, AssetResponse
from .asset import AssetStatus
from .preset import PresetCreate, PresetUpdate, PresetResponse
from .job import JobCreate, JobUpdate, JobResponse
from .job import JobStatus
__all__ = ["BaseSchema", "AssetCreate", "AssetUpdate", "AssetResponse", "AssetStatus", "PresetCreate", "PresetUpdate", "PresetResponse", "JobCreate", "JobUpdate", "JobResponse", "JobStatus"]

View File

@@ -1,70 +0,0 @@
"""MediaAsset Schemas - GENERATED FILE"""
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID
from .base import BaseSchema
class AssetStatus(str, Enum):
PENDING = "pending"
READY = "ready"
ERROR = "error"
class AssetCreate(BaseSchema):
"""AssetCreate schema."""
filename: str
file_path: str
file_size: Optional[int] = None
duration: Optional[float] = None
video_codec: Optional[str] = None
audio_codec: Optional[str] = None
width: Optional[int] = None
height: Optional[int] = None
framerate: Optional[float] = None
bitrate: Optional[int] = None
properties: Dict[str, Any]
comments: str = ""
tags: List[str] = Field(default_factory=list)
class AssetUpdate(BaseSchema):
"""AssetUpdate schema."""
filename: Optional[str] = None
file_path: Optional[str] = None
status: Optional[AssetStatus] = None
error_message: Optional[str] = None
file_size: Optional[int] = None
duration: Optional[float] = None
video_codec: Optional[str] = None
audio_codec: Optional[str] = None
width: Optional[int] = None
height: Optional[int] = None
framerate: Optional[float] = None
bitrate: Optional[int] = None
properties: Optional[Dict[str, Any]] = None
comments: Optional[str] = None
tags: Optional[List[str]] = None
class AssetResponse(BaseSchema):
"""AssetResponse schema."""
id: UUID
filename: str
file_path: str
status: AssetStatus = "AssetStatus.PENDING"
error_message: Optional[str] = None
file_size: Optional[int] = None
duration: Optional[float] = None
video_codec: Optional[str] = None
audio_codec: Optional[str] = None
width: Optional[int] = None
height: Optional[int] = None
framerate: Optional[float] = None
bitrate: Optional[int] = None
properties: Dict[str, Any]
comments: str = ""
tags: List[str] = Field(default_factory=list)
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None

View File

@@ -1,8 +0,0 @@
"""Pydantic Base Schema - GENERATED FILE"""
from pydantic import BaseModel, ConfigDict
class BaseSchema(BaseModel):
"""Base schema with ORM mode."""
model_config = ConfigDict(from_attributes=True)

View File

@@ -1,129 +0,0 @@
"""
Graphene Types - GENERATED FILE
Do not edit directly. Regenerate using modelgen.
"""
import graphene
class AssetStatus(graphene.Enum):
PENDING = "pending"
READY = "ready"
ERROR = "error"
class JobStatus(graphene.Enum):
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class MediaAssetType(graphene.ObjectType):
"""A video/audio file registered in the system."""
id = graphene.UUID()
filename = graphene.String()
file_path = graphene.String()
status = graphene.String()
error_message = graphene.String()
file_size = graphene.Int()
duration = graphene.Float()
video_codec = graphene.String()
audio_codec = graphene.String()
width = graphene.Int()
height = graphene.Int()
framerate = graphene.Float()
bitrate = graphene.Int()
properties = graphene.JSONString()
comments = graphene.String()
tags = graphene.List(graphene.String)
created_at = graphene.DateTime()
updated_at = graphene.DateTime()
class TranscodePresetType(graphene.ObjectType):
"""A reusable transcoding configuration (like Handbrake presets)."""
id = graphene.UUID()
name = graphene.String()
description = graphene.String()
is_builtin = graphene.Boolean()
container = graphene.String()
video_codec = graphene.String()
video_bitrate = graphene.String()
video_crf = graphene.Int()
video_preset = graphene.String()
resolution = graphene.String()
framerate = graphene.Float()
audio_codec = graphene.String()
audio_bitrate = graphene.String()
audio_channels = graphene.Int()
audio_samplerate = graphene.Int()
extra_args = graphene.List(graphene.String)
created_at = graphene.DateTime()
updated_at = graphene.DateTime()
class TranscodeJobType(graphene.ObjectType):
"""A transcoding or trimming job in the queue."""
id = graphene.UUID()
source_asset_id = graphene.UUID()
preset_id = graphene.UUID()
preset_snapshot = graphene.JSONString()
trim_start = graphene.Float()
trim_end = graphene.Float()
output_filename = graphene.String()
output_path = graphene.String()
output_asset_id = graphene.UUID()
status = graphene.String()
progress = graphene.Float()
current_frame = graphene.Int()
current_time = graphene.Float()
speed = graphene.String()
error_message = graphene.String()
celery_task_id = graphene.String()
execution_arn = graphene.String()
priority = graphene.Int()
created_at = graphene.DateTime()
started_at = graphene.DateTime()
completed_at = graphene.DateTime()
class CreateJobInput(graphene.InputObjectType):
"""Request body for creating a transcode/trim job."""
source_asset_id = graphene.UUID(required=True)
preset_id = graphene.UUID()
trim_start = graphene.Float()
trim_end = graphene.Float()
output_filename = graphene.String()
priority = graphene.Int(default_value=0)
class SystemStatusType(graphene.ObjectType):
"""System status response."""
status = graphene.String()
version = graphene.String()
class ScanResultType(graphene.ObjectType):
"""Result of scanning the media input bucket."""
found = graphene.Int()
registered = graphene.Int()
skipped = graphene.Int()
files = graphene.List(graphene.String)
class WorkerStatusType(graphene.ObjectType):
"""Worker health and capabilities."""
available = graphene.Boolean()
active_jobs = graphene.Int()
supported_codecs = graphene.List(graphene.String)
gpu_available = graphene.Boolean()

View File

@@ -1,83 +0,0 @@
"""TranscodeJob Schemas - GENERATED FILE"""
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID
from .base import BaseSchema
class JobStatus(str, Enum):
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class JobCreate(BaseSchema):
"""JobCreate schema."""
source_asset_id: UUID
preset_id: Optional[UUID] = None
preset_snapshot: Dict[str, Any]
trim_start: Optional[float] = None
trim_end: Optional[float] = None
output_filename: str = ""
output_path: Optional[str] = None
output_asset_id: Optional[UUID] = None
progress: float = 0.0
current_frame: Optional[int] = None
current_time: Optional[float] = None
speed: Optional[str] = None
celery_task_id: Optional[str] = None
execution_arn: Optional[str] = None
priority: int = 0
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
class JobUpdate(BaseSchema):
"""JobUpdate schema."""
source_asset_id: Optional[UUID] = None
preset_id: Optional[UUID] = None
preset_snapshot: Optional[Dict[str, Any]] = None
trim_start: Optional[float] = None
trim_end: Optional[float] = None
output_filename: Optional[str] = None
output_path: Optional[str] = None
output_asset_id: Optional[UUID] = None
status: Optional[JobStatus] = None
progress: Optional[float] = None
current_frame: Optional[int] = None
current_time: Optional[float] = None
speed: Optional[str] = None
error_message: Optional[str] = None
celery_task_id: Optional[str] = None
execution_arn: Optional[str] = None
priority: Optional[int] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
class JobResponse(BaseSchema):
"""JobResponse schema."""
id: UUID
source_asset_id: UUID
preset_id: Optional[UUID] = None
preset_snapshot: Dict[str, Any]
trim_start: Optional[float] = None
trim_end: Optional[float] = None
output_filename: str = ""
output_path: Optional[str] = None
output_asset_id: Optional[UUID] = None
status: JobStatus = "JobStatus.PENDING"
progress: float = 0.0
current_frame: Optional[int] = None
current_time: Optional[float] = None
speed: Optional[str] = None
error_message: Optional[str] = None
celery_task_id: Optional[str] = None
execution_arn: Optional[str] = None
priority: int = 0
created_at: Optional[datetime] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None

View File

@@ -1,66 +0,0 @@
"""TranscodePreset Schemas - GENERATED FILE"""
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID
from .base import BaseSchema
class PresetCreate(BaseSchema):
"""PresetCreate schema."""
name: str
description: str = ""
is_builtin: bool = False
container: str = "mp4"
video_codec: str = "libx264"
video_bitrate: Optional[str] = None
video_crf: Optional[int] = None
video_preset: Optional[str] = None
resolution: Optional[str] = None
framerate: Optional[float] = None
audio_codec: str = "aac"
audio_bitrate: Optional[str] = None
audio_channels: Optional[int] = None
audio_samplerate: Optional[int] = None
extra_args: List[str] = Field(default_factory=list)
class PresetUpdate(BaseSchema):
"""PresetUpdate schema."""
name: Optional[str] = None
description: Optional[str] = None
is_builtin: Optional[bool] = None
container: Optional[str] = None
video_codec: Optional[str] = None
video_bitrate: Optional[str] = None
video_crf: Optional[int] = None
video_preset: Optional[str] = None
resolution: Optional[str] = None
framerate: Optional[float] = None
audio_codec: Optional[str] = None
audio_bitrate: Optional[str] = None
audio_channels: Optional[int] = None
audio_samplerate: Optional[int] = None
extra_args: Optional[List[str]] = None
class PresetResponse(BaseSchema):
"""PresetResponse schema."""
id: UUID
name: str
description: str = ""
is_builtin: bool = False
container: str = "mp4"
video_codec: str = "libx264"
video_bitrate: Optional[str] = None
video_crf: Optional[int] = None
video_preset: Optional[str] = None
resolution: Optional[str] = None
framerate: Optional[float] = None
audio_codec: str = "aac"
audio_bitrate: Optional[str] = None
audio_channels: Optional[int] = None
audio_samplerate: Optional[int] = None
extra_args: List[str] = Field(default_factory=list)
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None

View File

@@ -0,0 +1,22 @@
"""
Detection API — aggregated router.
Combines all detect sub-routers into a single include for main.py.
"""
from fastapi import APIRouter
from .sources import router as sources_router
from .run import router as run_router
from .sse import router as sse_router
from .replay import router as replay_router
from .config import router as config_router
from .timeline import router as timeline_router
router = APIRouter()
router.include_router(sources_router)
router.include_router(run_router)
router.include_router(sse_router)
router.include_router(replay_router)
router.include_router(config_router)
router.include_router(timeline_router)

203
core/api/detect/config.py Normal file
View File

@@ -0,0 +1,203 @@
"""
Runtime config endpoint for the detection pipeline.
GET /detect/config — read current config
PUT /detect/config — update config (takes effect on next run)
GET /detect/config/stages — list stage palette with config fields
"""
from __future__ import annotations
import logging
from fastapi import APIRouter
from pydantic import BaseModel
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/detect", tags=["detect"])
# In-memory config — persists until server restart.
# Phase 12+ moves this to DB.
_runtime_config: dict = {}
class ConfigUpdate(BaseModel):
detection: dict | None = None
ocr: dict | None = None
resolver: dict | None = None
escalation: dict | None = None
preprocessing: dict | None = None
class StageOutputHintInfo(BaseModel):
key: str
type: str
label: str = ""
default_opacity: float = 0.5
src_format: str = "png"
class TransformOptionInfo(BaseModel):
key: str
type: str
default: object = False
label: str = ""
description: str = ""
class StageConfigInfo(BaseModel):
name: str
label: str
description: str
category: str
config_fields: list[dict]
output_hints: list[StageOutputHintInfo] = []
accepted_transforms: list[TransformOptionInfo] = []
reads: list[str]
writes: list[str]
@router.get("/config")
def read_config():
return _runtime_config
@router.put("/config")
def write_config(update: ConfigUpdate):
changes = update.model_dump(exclude_none=True)
for section, values in changes.items():
if section not in _runtime_config:
_runtime_config[section] = {}
_runtime_config[section].update(values)
logger.info("Config updated: %s", list(changes.keys()))
return _runtime_config
@router.get("/config/profiles")
def get_profiles():
"""List available detection profiles."""
from core.detect.profile import list_profiles as _list
return [{"name": name} for name in _list()]
@router.get("/config/profiles/{profile_name}/pipeline")
def get_pipeline_config(profile_name: str):
"""Return the pipeline composition for a profile."""
from core.detect.profile import get_profile
from fastapi import HTTPException
try:
profile = get_profile(profile_name)
except ValueError:
raise HTTPException(status_code=404, detail=f"Unknown profile: {profile_name}")
return profile["pipeline"]
class UpdateEdgeTransformRequest(BaseModel):
profile_name: str = "soccer_broadcast"
source_stage: str
target_stage: str
transform: dict
@router.put("/config/edge-transform")
def update_edge_transform(req: UpdateEdgeTransformRequest):
"""Update the transform on an edge in a profile's pipeline config."""
from uuid import UUID
from core.db.models import Profile
from core.db.connection import get_session
from sqlmodel import select
from fastapi import HTTPException
with get_session() as session:
stmt = select(Profile).where(Profile.name == req.profile_name)
profile = session.exec(stmt).first()
if not profile:
raise HTTPException(status_code=404, detail=f"Profile not found: {req.profile_name}")
pipeline = dict(profile.pipeline)
edges = pipeline.get("edges", [])
found = False
for edge in edges:
if edge.get("source") == req.source_stage and edge.get("target") == req.target_stage:
edge["transform"] = req.transform
found = True
break
if not found:
raise HTTPException(
status_code=404,
detail=f"Edge not found: {req.source_stage}{req.target_stage}",
)
pipeline["edges"] = edges
profile.pipeline = pipeline
session.commit()
return {"status": "updated", "edge": f"{req.source_stage}{req.target_stage}", "transform": req.transform}
@router.get("/config/stages", response_model=list[StageConfigInfo])
def list_stage_configs():
"""Return the stage palette with config field metadata for the editor."""
from core.detect.stages import list_stages
result = []
for stage in list_stages():
info = _stage_to_info(stage)
result.append(info)
return result
@router.get("/config/stages/{stage_name}", response_model=StageConfigInfo)
def get_stage_config(stage_name: str):
"""Return config field metadata for a single stage."""
from core.detect.stages import get_stage
try:
stage = get_stage(stage_name)
except KeyError:
from fastapi import HTTPException
raise HTTPException(status_code=404, detail=f"Unknown stage: {stage_name}")
return _stage_to_info(stage)
def _stage_to_info(stage) -> StageConfigInfo:
return StageConfigInfo(
name=stage.name,
label=stage.label,
description=stage.description,
category=stage.category,
config_fields=[
{
"name": f.name,
"type": f.type,
"default": f.default,
"description": f.description,
"min": f.min,
"max": f.max,
"options": f.options,
}
for f in stage.config_fields
],
output_hints=[
StageOutputHintInfo(
key=h.key, type=h.type, label=h.label,
default_opacity=h.default_opacity, src_format=h.src_format,
)
for h in getattr(stage, "output_hints", [])
],
accepted_transforms=[
TransformOptionInfo(
key=t.key, type=t.type, default=t.default,
label=t.label, description=t.description,
)
for t in getattr(stage, "accepted_transforms", [])
],
reads=stage.io.reads,
writes=stage.io.writes,
)

521
core/api/detect/replay.py Normal file
View File

@@ -0,0 +1,521 @@
"""
API endpoints for checkpoint inspection, replay, retry, and GPU proxy.
GET /detect/checkpoints/{timeline_id} — list available checkpoints
POST /detect/replay — replay from a stage with config overrides
POST /detect/retry — queue async retry with different provider
POST /detect/replay-stage — replay single stage (fast path)
POST /detect/gpu/detect_edges — proxy to GPU inference server
POST /detect/gpu/detect_edges/debug — proxy with debug overlays
"""
from __future__ import annotations
import logging
import os
from fastapi import APIRouter, HTTPException, Request, Response
from pydantic import BaseModel
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/detect", tags=["detect"])
# --- Request/Response models ---
class CheckpointInfo(BaseModel):
stage: str
is_scenario: bool = False
scenario_label: str = ""
class ScenarioInfo(BaseModel):
timeline_id: str
stage: str
scenario_label: str
profile_name: str
video_path: str
frame_count: int = 0
created_at: str = ""
class ReplayRequest(BaseModel):
job_id: str
start_stage: str
config_overrides: dict | None = None
class ReplayResponse(BaseModel):
status: str
job_id: str
replay_job_id: str
start_stage: str
detections: int = 0
brands_found: int = 0
class ReplaySingleStageRequest(BaseModel):
job_id: str
stage: str
frame_refs: list[int] | None = None
config_overrides: dict | None = None
debug: bool = False
class ReplaySingleStageBox(BaseModel):
x: int
y: int
w: int
h: int
confidence: float
label: str
class FrameDebugOverlays(BaseModel):
edge_overlay_b64: str = ""
lines_overlay_b64: str = ""
horizontal_count: int = 0
pair_count: int = 0
class ReplaySingleStageResponse(BaseModel):
status: str
stage: str
frame_count: int = 0
region_count: int = 0
regions_by_frame: dict[str, list[ReplaySingleStageBox]] = {}
debug: dict[str, FrameDebugOverlays] = {} # keyed by frame seq
# --- Endpoints ---
@router.get("/checkpoints/{timeline_id}")
def list_checkpoints_endpoint(timeline_id: str) -> list[CheckpointInfo]:
"""List available checkpoint stages for a timeline."""
from core.detect.checkpoint.storage import get_checkpoints_for_timeline
try:
checkpoints = get_checkpoints_for_timeline(timeline_id)
except Exception as e:
raise HTTPException(status_code=404, detail=f"No checkpoints for timeline {timeline_id}: {e}")
result = [
CheckpointInfo(
stage=c["stage_name"],
is_scenario=c.get("is_scenario", False),
scenario_label=c.get("scenario_label", ""),
)
for c in checkpoints
if c["stage_name"]
]
return result
class CheckpointFrameInfo(BaseModel):
seq: int
timestamp: float
jpeg_b64: str
class CheckpointData(BaseModel):
timeline_id: str
stage: str
profile_name: str
video_path: str
is_scenario: bool
scenario_label: str
frames: list[CheckpointFrameInfo]
stats: dict = {}
config_snapshot: dict = {}
stage_output_key: str = ""
@router.get("/checkpoints/{timeline_id}/{stage}", response_model=CheckpointData)
def get_checkpoint_data(timeline_id: str, stage: str):
"""Load checkpoint frames + metadata for the editor UI.
Reads from the timeline's frame cache (local filesystem).
"""
from uuid import UUID
from core.db.models import Timeline, Checkpoint
from core.db.connection import get_session
from core.db.checkpoint import list_checkpoints
from core.detect.checkpoint.frames import load_cached_frames_b64
with get_session() as session:
timeline = session.get(Timeline, UUID(timeline_id))
if not timeline:
raise HTTPException(status_code=404, detail=f"Timeline not found: {timeline_id}")
checkpoints = list_checkpoints(session, UUID(timeline_id))
if not checkpoints:
raise HTTPException(status_code=404, detail=f"No checkpoints for timeline {timeline_id}")
# Prefer a checkpoint for this stage; fall back to latest
checkpoint = next(
(c for c in reversed(checkpoints) if c.stage_name == stage),
checkpoints[-1],
)
# Read from timeline's frame cache
frames_b64 = load_cached_frames_b64(timeline_id)
frame_list = [
CheckpointFrameInfo(seq=f["seq"], timestamp=f["timestamp"], jpeg_b64=f["jpeg_b64"])
for f in frames_b64
]
return CheckpointData(
timeline_id=timeline_id,
stage=stage,
profile_name=timeline.profile_name,
video_path=timeline.chunk_paths[0] if timeline.chunk_paths else "",
is_scenario=checkpoint.is_scenario,
scenario_label=checkpoint.scenario_label,
frames=frame_list,
stats=checkpoint.stats or {},
config_snapshot=checkpoint.config_overrides or {},
stage_output_key=stage,
)
@router.get("/scenarios", response_model=list[ScenarioInfo])
def list_scenarios_endpoint():
"""List all available scenarios (bookmarked checkpoints)."""
from core.db.models import Timeline
from core.db.connection import get_session
from core.db.checkpoint import list_scenarios
with get_session() as session:
scenarios = list_scenarios(session)
result = []
for s in scenarios:
timeline = session.get(Timeline, s.timeline_id)
if not timeline:
continue
info = ScenarioInfo(
timeline_id=str(s.timeline_id),
stage=s.stage_name,
scenario_label=s.scenario_label,
profile_name=timeline.profile_name,
video_path=timeline.chunk_paths[0] if timeline.chunk_paths else "",
created_at=str(s.created_at) if s.created_at else "",
)
result.append(info)
return result
@router.post("/replay", response_model=ReplayResponse)
def replay(req: ReplayRequest):
"""Replay pipeline from a specific stage with optional config overrides."""
from core.detect.checkpoint.replay import replay_from
try:
result = replay_from(
job_id=req.job_id,
start_stage=req.start_stage,
config_overrides=req.config_overrides,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Replay failed: {e}")
detections = result.get("detections", [])
report = result.get("report")
brands_found = len(report.brands) if report else 0
response = ReplayResponse(
status="completed",
job_id=req.job_id,
replay_job_id=result.get("job_id", ""),
start_stage=req.start_stage,
detections=len(detections),
brands_found=brands_found,
)
return response
@router.post("/replay-stage", response_model=ReplaySingleStageResponse)
def replay_single_stage(req: ReplaySingleStageRequest):
"""Replay a single stage on specific frames — fast path for interactive tuning."""
from core.detect.checkpoint.replay import replay_single_stage as _replay
try:
result = _replay(
job_id=req.job_id,
stage=req.stage,
frame_refs=req.frame_refs,
config_overrides=req.config_overrides,
debug=req.debug,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Single-stage replay failed: {e}")
# Convert result to response format
regions_by_frame = result.get("edge_regions_by_frame", {})
total_regions = 0
serialized = {}
for seq, boxes in regions_by_frame.items():
box_list = []
for b in boxes:
box = ReplaySingleStageBox(
x=b.x, y=b.y, w=b.w, h=b.h,
confidence=b.confidence, label=b.label,
)
box_list.append(box)
serialized[str(seq)] = box_list
total_regions += len(box_list)
# Serialize debug overlays if present
debug_out = {}
raw_debug = result.get("debug", {})
for seq, d in raw_debug.items():
debug_out[str(seq)] = FrameDebugOverlays(
edge_overlay_b64=d.get("edge_overlay_b64", ""),
lines_overlay_b64=d.get("lines_overlay_b64", ""),
horizontal_count=d.get("horizontal_count", 0),
pair_count=d.get("pair_count", 0),
)
return ReplaySingleStageResponse(
status="completed",
stage=req.stage,
frame_count=len(regions_by_frame),
region_count=total_regions,
regions_by_frame=serialized,
debug=debug_out,
)
# --- GPU proxy — thin passthrough to inference server for interactive editor ---
def _gpu_url() -> str:
url = os.environ.get("INFERENCE_URL", "http://localhost:8000")
return url.rstrip("/")
# --- Overlay cache — save/load debug overlay images ---
class SaveOverlaysRequest(BaseModel):
timeline_id: str
job_id: str
stage: str
seq: int
overlays: dict[str, str] # {overlay_key: base64_png}
@router.post("/overlays")
def save_overlays_endpoint(req: SaveOverlaysRequest):
"""Save debug overlay images to blob storage cache."""
from core.detect.checkpoint.frames import save_overlays
save_overlays(req.timeline_id, req.job_id, req.stage, req.seq, req.overlays)
return {"status": "saved", "count": len(req.overlays)}
@router.get("/overlays/{timeline_id}/{job_id}/{stage}/{seq}")
def load_overlays_endpoint(timeline_id: str, job_id: str, stage: str, seq: int):
"""Load cached debug overlay images."""
from core.detect.checkpoint.frames import load_overlays
overlays = load_overlays(timeline_id, job_id, stage, seq)
return {"overlays": overlays or {}}
def _generate_debug_overlays(job_id: str, stage: str, frame) -> dict[str, str] | None:
"""Generate debug overlay images for a single frame."""
import os
inference_url = os.environ.get("INFERENCE_URL")
if stage == "detect_edges":
from core.detect.profile import get_profile, get_stage_config
from core.detect.stages.models import RegionAnalysisConfig
from core.db.connection import get_session
from core.db.job import get_job
from uuid import UUID
with get_session() as session:
job = get_job(session, UUID(job_id))
if not job:
return None
profile = get_profile(job.profile_name)
config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
if inference_url:
from core.detect.inference import InferenceClient
client = InferenceClient(base_url=inference_url, job_id=job_id)
dr = client.detect_edges_debug(
image=frame.image,
edge_canny_low=config.edge_canny_low,
edge_canny_high=config.edge_canny_high,
edge_hough_threshold=config.edge_hough_threshold,
edge_hough_min_length=config.edge_hough_min_length,
edge_hough_max_gap=config.edge_hough_max_gap,
edge_pair_max_distance=config.edge_pair_max_distance,
edge_pair_min_distance=config.edge_pair_min_distance,
)
return {
"edge_overlay_b64": dr.edge_overlay_b64,
"lines_overlay_b64": dr.lines_overlay_b64,
}
else:
from core.detect.stages.edge_detector import _load_cv_edges
edges_mod = _load_cv_edges()
dr = edges_mod.detect_edges_debug(
frame.image,
canny_low=config.edge_canny_low,
canny_high=config.edge_canny_high,
hough_threshold=config.edge_hough_threshold,
hough_min_length=config.edge_hough_min_length,
hough_max_gap=config.edge_hough_max_gap,
pair_max_distance=config.edge_pair_max_distance,
pair_min_distance=config.edge_pair_min_distance,
)
return {
"edge_overlay_b64": dr["edge_overlay_b64"],
"lines_overlay_b64": dr["lines_overlay_b64"],
}
elif stage == "field_segmentation":
from core.detect.profile import get_profile, get_stage_config
from core.detect.stages.models import FieldSegmentationConfig
from core.db.connection import get_session
from core.db.job import get_job
from uuid import UUID
with get_session() as session:
job = get_job(session, UUID(job_id))
if not job:
return None
profile = get_profile(job.profile_name)
config = FieldSegmentationConfig(**get_stage_config(profile, "field_segmentation"))
if inference_url:
import httpx, json, base64, io
from PIL import Image
import numpy as np
buf = io.BytesIO()
Image.fromarray(frame.image).save(buf, format="JPEG", quality=85)
img_b64 = base64.b64encode(buf.getvalue()).decode()
resp = httpx.post(
f"{inference_url.rstrip('/')}/segment_field/debug",
json={
"image_b64": img_b64,
"hue_low": config.hue_low,
"hue_high": config.hue_high,
"sat_low": config.sat_low,
"sat_high": config.sat_high,
"val_low": config.val_low,
"val_high": config.val_high,
"morph_kernel": config.morph_kernel,
"min_area_ratio": config.min_area_ratio,
},
timeout=30.0,
)
if resp.status_code == 200:
data = resp.json()
return {"mask_overlay_b64": data.get("mask_b64", "")}
return None
return None
@router.get("/overlays/{timeline_id}/{job_id}/{stage}")
def list_overlay_frames_endpoint(timeline_id: str, job_id: str, stage: str):
"""List frame sequences that have cached overlays."""
from core.detect.checkpoint.frames import list_overlay_frames
seqs = list_overlay_frames(timeline_id, job_id, stage)
return {"frames": seqs}
# --- GPU proxy — thin passthrough to inference server for interactive editor ---
@router.post("/gpu/detect_edges")
async def gpu_detect_edges(request: Request):
"""Proxy to GPU inference server — browser can't reach it directly."""
import httpx
body = await request.body()
try:
async with httpx.AsyncClient(timeout=30.0) as client:
resp = await client.post(
f"{_gpu_url()}/detect_edges",
content=body,
headers={"Content-Type": "application/json"},
)
return Response(content=resp.content, status_code=resp.status_code,
media_type="application/json")
except Exception as e:
raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}")
@router.post("/gpu/detect_edges/debug")
async def gpu_detect_edges_debug(request: Request):
"""Proxy to GPU inference server debug endpoint."""
import httpx
body = await request.body()
try:
async with httpx.AsyncClient(timeout=30.0) as client:
resp = await client.post(
f"{_gpu_url()}/detect_edges/debug",
content=body,
headers={"Content-Type": "application/json"},
)
return Response(content=resp.content, status_code=resp.status_code,
media_type="application/json")
except Exception as e:
raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}")
@router.post("/gpu/segment_field")
async def gpu_segment_field(request: Request):
"""Proxy to GPU inference server — field segmentation."""
import httpx
body = await request.body()
try:
async with httpx.AsyncClient(timeout=30.0) as client:
resp = await client.post(
f"{_gpu_url()}/segment_field",
content=body,
headers={"Content-Type": "application/json"},
)
return Response(content=resp.content, status_code=resp.status_code,
media_type="application/json")
except Exception as e:
raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}")
@router.post("/gpu/segment_field/debug")
async def gpu_segment_field_debug(request: Request):
"""Proxy to GPU inference server — field segmentation with debug overlay."""
import httpx
body = await request.body()
try:
async with httpx.AsyncClient(timeout=30.0) as client:
resp = await client.post(
f"{_gpu_url()}/segment_field/debug",
content=body,
headers={"Content-Type": "application/json"},
)
return Response(content=resp.content, status_code=resp.status_code,
media_type="application/json")
except Exception as e:
raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}")

278
core/api/detect/run.py Normal file
View File

@@ -0,0 +1,278 @@
"""
Pipeline run endpoints.
POST /detect/run — launch pipeline on a timeline
POST /detect/stop/{job_id} — cancel a running pipeline
POST /detect/pause/{job_id} — pause after current stage
POST /detect/resume/{job_id} — resume a paused pipeline
POST /detect/step/{job_id} — run one stage then pause
POST /detect/clear/{job_id} — clear events from Redis
GET /detect/status/{job_id} — pipeline run status
"""
from __future__ import annotations
import logging
import os
import threading
import uuid
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/detect", tags=["detect"])
# In-process pipeline tracking
_running_jobs: dict[str, threading.Thread] = {}
_cancelled_jobs: set[str] = set()
class RunRequest(BaseModel):
timeline_id: str
profile_name: str = "soccer_broadcast"
checkpoint: bool = True
skip_vlm: bool = False
skip_cloud: bool = False
log_level: str = "INFO" # INFO | DEBUG
pause_after_stage: bool = False
config_overrides: dict | None = None
class RunResponse(BaseModel):
status: str
job_id: str
timeline_id: str
def _resolve_video_path(video_path: str) -> str:
"""Download a chunk from blob storage to a temp file."""
from core.storage.blob import get_store
store = get_store("out")
try:
return store.download_to_temp(video_path)
except Exception as e:
raise HTTPException(status_code=400, detail=f"Failed to download chunk: {e}")
@router.post("/run", response_model=RunResponse)
def run_pipeline(req: RunRequest):
"""Launch a detection pipeline run on a timeline."""
from core.detect import emit
from core.detect.graph import get_pipeline
from core.detect.state import DetectState
from core.detect.checkpoint.storage import get_timeline
from core.db.connection import get_session
from core.db.job import create_job, update_job_status
# Load timeline
try:
timeline = get_timeline(req.timeline_id)
except ValueError:
raise HTTPException(status_code=404, detail=f"Timeline not found: {req.timeline_id}")
chunk_paths = timeline["chunk_paths"]
if not chunk_paths:
raise HTTPException(status_code=400, detail="Timeline has no chunk paths")
# Resolve first chunk to local path for the pipeline
local_path = _resolve_video_path(chunk_paths[0])
# Create job in DB
source_asset_id_str = timeline.get("source_asset_id", "")
with get_session() as session:
from uuid import UUID as _UUID
source_asset_id = _UUID(source_asset_id_str) if source_asset_id_str else uuid.uuid4()
job = create_job(
session,
source_asset_id=source_asset_id,
video_path=chunk_paths[0],
timeline_id=_UUID(req.timeline_id),
profile_name=req.profile_name,
config_overrides=req.config_overrides,
)
job_id = str(job.id)
if req.skip_vlm:
os.environ["SKIP_VLM"] = "1"
elif "SKIP_VLM" in os.environ:
del os.environ["SKIP_VLM"]
if req.skip_cloud:
os.environ["SKIP_CLOUD"] = "1"
elif "SKIP_CLOUD" in os.environ:
del os.environ["SKIP_CLOUD"]
# Clear any stale events
from core.events import _get_redis
from core.detect.events import DETECT_EVENTS_PREFIX
r = _get_redis()
r.delete(f"{DETECT_EVENTS_PREFIX}:{job_id}")
emit.set_run_context(
run_id=job_id, parent_job_id=job_id, run_type="initial",
log_level=req.log_level,
)
pipeline = get_pipeline(checkpoint=req.checkpoint, profile_name=req.profile_name)
initial_state = DetectState(
video_path=local_path,
job_id=job_id,
profile_name=req.profile_name,
source_asset_id=source_asset_id_str or str(source_asset_id),
timeline_id=req.timeline_id,
config_overrides=req.config_overrides or {},
)
from core.detect.graph import (
PipelineCancelled, set_cancel_check, clear_cancel_check,
init_pause, clear_pause,
)
set_cancel_check(job_id, lambda: job_id in _cancelled_jobs)
init_pause(job_id, pause_after_stage=req.pause_after_stage)
def _update_job(status, stage=None, error=None):
from core.db.connection import get_session
from core.db.job import update_job_status
with get_session() as session:
update_job_status(session, _UUID(job_id), status,
current_stage=stage, error_message=error)
def _run():
try:
_update_job("running")
emit.log(job_id, "Pipeline", "INFO",
f"Starting pipeline: {chunk_paths[0]} (profile={req.profile_name})")
pipeline.invoke(initial_state)
_update_job("completed")
emit.log(job_id, "Pipeline", "INFO", "Pipeline completed successfully")
emit.job_complete(job_id, {"status": "completed"})
except PipelineCancelled:
_update_job("cancelled")
emit.log(job_id, "Pipeline", "INFO", "Pipeline cancelled")
emit.job_complete(job_id, {"status": "cancelled"})
except Exception as e:
logger.exception("Pipeline run %s failed: %s", job_id, e)
_update_job("failed", error=str(e))
from core.detect.graph import _node_states, NODES
if job_id in _node_states:
states = _node_states[job_id]
for node in reversed(NODES):
if states.get(node) in ("running", "done"):
states[node] = "error"
break
nodes = [{"id": n, "status": states[n]} for n in NODES]
emit.graph_update(job_id, nodes)
emit.log(job_id, "Pipeline", "ERROR", str(e))
emit.job_complete(job_id, {"status": "failed", "error": str(e)})
finally:
_running_jobs.pop(job_id, None)
_cancelled_jobs.discard(job_id)
clear_cancel_check(job_id)
clear_pause(job_id)
emit.clear_run_context()
from core.detect.checkpoint.runner_bridge import reset_checkpoint_state
reset_checkpoint_state(job_id)
thread = threading.Thread(target=_run, daemon=True, name=f"pipeline-{job_id}")
_running_jobs[job_id] = thread
thread.start()
return RunResponse(status="started", job_id=job_id, timeline_id=req.timeline_id)
@router.post("/stop/{job_id}")
def stop_pipeline(job_id: str):
"""Stop a running pipeline. Signals cancellation; the thread checks on next stage."""
from core.detect import emit
if job_id not in _running_jobs:
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
_cancelled_jobs.add(job_id)
emit.log(job_id, "Pipeline", "INFO", "Stop requested — cancelling after current stage")
return {"status": "stopping", "job_id": job_id}
@router.post("/pause/{job_id}")
def pause(job_id: str):
"""Pause a running pipeline after the current stage completes."""
from core.detect.graph import pause_pipeline
if job_id not in _running_jobs:
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
pause_pipeline(job_id)
return {"status": "pausing", "job_id": job_id}
@router.post("/resume/{job_id}")
def resume(job_id: str):
"""Resume a paused pipeline."""
from core.detect.graph import resume_pipeline
if job_id not in _running_jobs:
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
resume_pipeline(job_id)
return {"status": "running", "job_id": job_id}
@router.post("/step/{job_id}")
def step(job_id: str):
"""Run one stage then pause again."""
from core.detect.graph import step_pipeline
if job_id not in _running_jobs:
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
step_pipeline(job_id)
return {"status": "stepping", "job_id": job_id}
@router.post("/pause-after-stage/{job_id}")
def toggle_pause_after_stage(job_id: str, enabled: bool = True):
"""Toggle pause-after-each-stage mode."""
from core.detect.graph import set_pause_after_stage
if job_id not in _running_jobs:
raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")
set_pause_after_stage(job_id, enabled)
return {"status": "ok", "pause_after_stage": enabled, "job_id": job_id}
@router.get("/status/{job_id}")
def pipeline_status(job_id: str):
"""Get pipeline run status."""
from core.detect.graph import is_paused
running = job_id in _running_jobs
paused = is_paused(job_id)
cancelling = job_id in _cancelled_jobs
if cancelling:
status = "cancelling"
elif paused:
status = "paused"
elif running:
status = "running"
else:
status = "idle"
return {"status": status, "job_id": job_id}
@router.post("/clear/{job_id}")
def clear_pipeline(job_id: str):
"""Clear events for a job from Redis."""
from core.events import _get_redis
from core.detect.events import DETECT_EVENTS_PREFIX
r = _get_redis()
r.delete(f"{DETECT_EVENTS_PREFIX}:{job_id}")
return {"status": "cleared", "job_id": job_id}

108
core/api/detect/sources.py Normal file
View File

@@ -0,0 +1,108 @@
"""
Source browser for detection pipeline.
Lists available media sources from blob storage (MinIO).
GET /detect/sources — list chunk jobs
GET /detect/sources/{job_id}/chunks — list chunks for a job
GET /detect/sources/{job_id}/chunks/{name}/url — presigned preview URL
"""
from __future__ import annotations
import logging
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/detect", tags=["detect"])
class ChunkInfoResponse(BaseModel):
filename: str
key: str
size_bytes: int
class SourceInfoResponse(BaseModel):
job_id: str
source_type: str = "chunk_job"
chunk_count: int
total_bytes: int = 0
def _list_sources() -> list[SourceInfoResponse]:
"""List chunk jobs from blob storage."""
from core.storage.blob import get_store
store = get_store("out")
try:
objects = store.list(prefix="chunks/")
except Exception as e:
logger.warning("Failed to list blob sources: %s", e)
return []
jobs: dict[str, int] = {}
job_bytes: dict[str, int] = {}
for obj in objects:
rel_key = obj.key.removeprefix(store.prefix)
parts = rel_key.split("/")
if len(parts) >= 3 and parts[0] == "chunks":
job_id = parts[1]
jobs[job_id] = jobs.get(job_id, 0) + 1
job_bytes[job_id] = job_bytes.get(job_id, 0) + obj.size_bytes
sources = []
for job_id, count in sorted(jobs.items()):
source = SourceInfoResponse(
job_id=job_id,
source_type="chunk_job",
chunk_count=count,
total_bytes=job_bytes.get(job_id, 0),
)
sources.append(source)
return sources
@router.get("/sources", response_model=list[SourceInfoResponse])
def list_sources():
"""List available chunk jobs from blob storage."""
return _list_sources()
@router.get("/sources/{source_job_id}/chunks", response_model=list[ChunkInfoResponse])
def list_chunks(source_job_id: str):
"""List chunks for a specific source job."""
from core.storage.blob import get_store
store = get_store("out")
try:
objects = store.list(prefix=f"chunks/{source_job_id}/", extensions={".mp4"})
except Exception as e:
logger.warning("Failed to list chunks for %s: %s", source_job_id, e)
raise HTTPException(status_code=503, detail=f"Blob storage unavailable: {e}")
if not objects:
raise HTTPException(status_code=404, detail=f"Source not found: {source_job_id}")
chunks = []
for obj in objects:
info = ChunkInfoResponse(filename=obj.filename, key=obj.key, size_bytes=obj.size_bytes)
chunks.append(info)
return sorted(chunks, key=lambda c: c.filename)
@router.get("/sources/{source_job_id}/chunks/{filename}/url")
def get_chunk_url(source_job_id: str, filename: str):
"""Return a presigned URL for previewing a chunk in the browser."""
from core.storage.blob import get_store
store = get_store("out")
key = f"chunks/{source_job_id}/{filename}"
try:
url = store.get_url(key, expires=3600)
except Exception as e:
raise HTTPException(status_code=503, detail=f"Could not generate URL: {e}")
return {"url": url}

79
core/api/detect/sse.py Normal file
View File

@@ -0,0 +1,79 @@
"""
SSE endpoint for detection pipeline events.
Uses Redis as the event bus between pipeline workers and the SSE stream.
Mirrors chunker_sse.py but polls detect_events:{job_id}.
GET /detect/stream/{job_id} → text/event-stream
"""
import asyncio
import json
import logging
import time
from typing import AsyncGenerator
from fastapi import APIRouter
from starlette.responses import StreamingResponse
from core.events import poll_events
from core.detect.events import DETECT_EVENTS_PREFIX, TERMINAL_EVENTS
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/detect", tags=["detect"])
async def _event_generator(job_id: str) -> AsyncGenerator[str, None]:
cursor = 0
timeout = time.monotonic() + 3600 # 1 hour max
while time.monotonic() < timeout:
events, cursor = poll_events(job_id, cursor, prefix=DETECT_EVENTS_PREFIX)
if not events:
await asyncio.sleep(0.2)
continue
is_terminal = False
for data in events:
event_type = data.pop("event", "update")
payload = {**data, "job_id": job_id}
yield f"event: {event_type}\ndata: {json.dumps(payload)}\n\n"
if event_type in TERMINAL_EVENTS:
is_terminal = True
if is_terminal:
yield f"event: done\ndata: {json.dumps({'job_id': job_id})}\n\n"
# Don't return — keep connection alive so EventSource doesn't reconnect.
# Just idle until the client disconnects or timeout.
while time.monotonic() < timeout:
await asyncio.sleep(5)
return
await asyncio.sleep(0.05)
yield f"event: timeout\ndata: {json.dumps({'job_id': job_id})}\n\n"
@router.get("/stream/{job_id}")
async def stream_detect_job(job_id: str):
"""
SSE stream for a detection pipeline job.
The UI connects via native EventSource:
const es = new EventSource('/api/detect/stream/<job_id>');
es.addEventListener('graph_update', (e) => { ... });
es.addEventListener('detection', (e) => { ... });
"""
return StreamingResponse(
_event_generator(job_id),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)

226
core/api/detect/timeline.py Normal file
View File

@@ -0,0 +1,226 @@
"""
Timeline + Job management endpoints.
POST /detect/timeline — create timeline from chunk selection
GET /detect/timeline — list timelines
GET /detect/timeline/{id} — timeline detail
DELETE /detect/timeline/{id}/cache — clear frame cache
GET /detect/jobs — list jobs (optionally by timeline)
GET /detect/jobs/{id} — job detail + checkpoints + stage outputs
"""
from __future__ import annotations
import logging
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/detect", tags=["detect"])
# --- Request/Response models ---
class CreateTimelineRequest(BaseModel):
chunk_paths: list[str]
profile_name: str = "soccer_broadcast"
name: str = ""
source_asset_id: str = ""
fps: float = 2.0
class TimelineResponse(BaseModel):
id: str
name: str
chunk_paths: list[str]
profile_name: str
status: str
fps: float
frame_count: int
source_ephemeral: bool
created_at: str | None = None
class JobResponse(BaseModel):
id: str
timeline_id: str | None
source_asset_id: str
video_path: str
profile_name: str
run_type: str
status: str
current_stage: str | None
config_overrides: dict
error_message: str | None
created_at: str | None
started_at: str | None
completed_at: str | None
class JobDetailResponse(JobResponse):
checkpoints: list[dict]
stage_outputs: dict[str, dict]
# --- Timeline endpoints ---
@router.post("/timeline", response_model=TimelineResponse)
def create_timeline_endpoint(req: CreateTimelineRequest):
"""Create a timeline from a chunk selection."""
from uuid import UUID
from core.detect.checkpoint.storage import create_timeline
source_asset_id = UUID(req.source_asset_id) if req.source_asset_id else None
tid = create_timeline(
chunk_paths=req.chunk_paths,
profile_name=req.profile_name,
name=req.name,
source_asset_id=source_asset_id,
fps=req.fps,
)
from core.detect.checkpoint.storage import get_timeline
tl = get_timeline(tid)
return TimelineResponse(
id=tl["id"],
name=tl["name"],
chunk_paths=tl["chunk_paths"],
profile_name=tl["profile_name"],
status=tl["status"],
fps=tl["fps"],
frame_count=0,
source_ephemeral=False,
created_at=tl["created_at"],
)
@router.get("/timeline", response_model=list[TimelineResponse])
def list_timelines():
"""List all timelines."""
from sqlmodel import select
from core.db.models import Timeline
from core.db.connection import get_session
with get_session() as session:
stmt = select(Timeline).order_by(Timeline.created_at.desc())
timelines = session.exec(stmt).all()
return [
TimelineResponse(
id=str(t.id),
name=t.name,
chunk_paths=t.chunk_paths or [],
profile_name=t.profile_name,
status=t.status,
fps=t.fps,
frame_count=t.frame_count,
source_ephemeral=t.source_ephemeral,
created_at=str(t.created_at) if t.created_at else None,
)
for t in timelines
]
@router.get("/timeline/{timeline_id}", response_model=TimelineResponse)
def get_timeline_endpoint(timeline_id: str):
"""Get timeline detail."""
from core.detect.checkpoint.storage import get_timeline
try:
tl = get_timeline(timeline_id)
except ValueError:
raise HTTPException(status_code=404, detail=f"Timeline not found: {timeline_id}")
from core.detect.checkpoint.frames import cache_exists
from uuid import UUID
from core.db.models import Timeline
from core.db.connection import get_session
with get_session() as session:
timeline = session.get(Timeline, UUID(timeline_id))
return TimelineResponse(
id=tl["id"],
name=tl["name"],
chunk_paths=tl["chunk_paths"],
profile_name=tl["profile_name"],
status=tl["status"],
fps=tl["fps"],
frame_count=timeline.frame_count if timeline else 0,
source_ephemeral=timeline.source_ephemeral if timeline else False,
created_at=tl["created_at"],
)
@router.delete("/timeline/{timeline_id}/cache")
def clear_timeline_cache(timeline_id: str):
"""Clear the frame cache for a timeline."""
from core.detect.checkpoint.frames import clear_cache
from core.detect.checkpoint.storage import update_timeline_status
clear_cache(timeline_id)
update_timeline_status(timeline_id, "created")
return {"status": "cleared", "timeline_id": timeline_id}
# --- Job endpoints ---
def _job_to_response(job) -> JobResponse:
return JobResponse(
id=str(job.id),
timeline_id=str(job.timeline_id) if job.timeline_id else None,
source_asset_id=str(job.source_asset_id),
video_path=job.video_path,
profile_name=job.profile_name,
run_type=job.run_type,
status=job.status,
current_stage=job.current_stage,
config_overrides=job.config_overrides or {},
error_message=job.error_message,
created_at=str(job.created_at) if job.created_at else None,
started_at=str(job.started_at) if job.started_at else None,
completed_at=str(job.completed_at) if job.completed_at else None,
)
@router.get("/jobs", response_model=list[JobResponse])
def list_jobs_endpoint(timeline_id: str | None = Query(None)):
"""List jobs, optionally filtered by timeline."""
from uuid import UUID
from core.db.connection import get_session
from core.db.job import list_jobs
tid = UUID(timeline_id) if timeline_id else None
with get_session() as session:
jobs = list_jobs(session, timeline_id=tid)
return [_job_to_response(j) for j in jobs]
@router.get("/jobs/{job_id}", response_model=JobDetailResponse)
def get_job_endpoint(job_id: str):
"""Get job detail with checkpoints and stage outputs."""
from uuid import UUID
from core.db.connection import get_session
from core.db.job import get_job
from core.detect.checkpoint.storage import (
get_checkpoints_for_job,
load_stage_outputs_for_job,
)
with get_session() as session:
job = get_job(session, UUID(job_id))
if not job:
raise HTTPException(status_code=404, detail=f"Job not found: {job_id}")
checkpoints = get_checkpoints_for_job(job_id)
stage_outputs = load_stage_outputs_for_job(job_id)
base = _job_to_response(job)
return JobDetailResponse(
**base.model_dump(),
checkpoints=checkpoints,
stage_outputs=stage_outputs,
)

58
core/api/main.py Normal file
View File

@@ -0,0 +1,58 @@
"""
MPR FastAPI Application
"""
import os
import sys
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from core.api.detect import router as detect_router
@asynccontextmanager
async def lifespan(app):
from core.db.connection import create_tables
from core.db.seed import seed_profiles
create_tables()
seed_profiles()
yield
app = FastAPI(
title="MPR API",
version="0.1.0",
docs_url="/docs",
redoc_url="/redoc",
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=["http://mpr.local.ar", "http://k8s.mpr.local.ar", "http://localhost:5173"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Detection API (sources, run, SSE, replay, config)
app.include_router(detect_router)
@app.get("/health")
def health():
return {"status": "ok"}
@app.get("/")
def root():
return {
"name": "MPR API",
"version": "0.1.0",
}

24
core/db/__init__.py Normal file
View File

@@ -0,0 +1,24 @@
"""
Database layer.
tables.py — SQLModel table definitions (generated by modelgen, don't edit)
domain files — session-first query functions for non-trivial operations
Basic CRUD (create, get, update, delete) goes directly through the session:
session.add(Job(...))
session.get(Job, id)
session.get(Job, id); setattr(...); session.commit()
session.delete(obj); session.commit()
"""
from .connection import get_session, create_tables
from .models import MediaAsset, Job, Timeline, Checkpoint, Brand
from .assets import list_assets, get_asset_filenames
from .job import list_jobs
from .checkpoint import (
get_latest_checkpoint, get_root_checkpoint,
list_checkpoints, list_scenarios,
)
from .brand import get_or_create_brand, find_brand_by_text, list_brands, record_airing

23
core/db/assets.py Normal file
View File

@@ -0,0 +1,23 @@
"""MediaAsset queries."""
from __future__ import annotations
from typing import Optional
from uuid import UUID
from sqlmodel import Session, select
from .models import MediaAsset
def list_assets(session: Session, status: Optional[str] = None, search: Optional[str] = None) -> list[MediaAsset]:
stmt = select(MediaAsset)
if status:
stmt = stmt.where(MediaAsset.status == status)
if search:
stmt = stmt.where(MediaAsset.filename.ilike(f"%{search}%"))
return list(session.exec(stmt).all())
def get_asset_filenames(session: Session) -> set[str]:
return set(session.exec(select(MediaAsset.filename)).all())

61
core/db/brand.py Normal file
View File

@@ -0,0 +1,61 @@
"""Brand queries."""
from __future__ import annotations
from typing import Optional
from uuid import UUID
from sqlmodel import Session, select
from .models import Brand
def get_or_create_brand(session: Session, canonical_name: str,
aliases: Optional[list[str]] = None,
source: str = "ocr") -> tuple[Brand, bool]:
normalized = canonical_name.strip()
brand = session.exec(select(Brand).where(Brand.canonical_name.ilike(normalized))).first()
if brand:
return brand, False
brand = Brand(canonical_name=normalized, aliases=aliases or [], source=source)
session.add(brand)
session.flush()
return brand, True
def find_brand_by_text(session: Session, text: str) -> Brand | None:
normalized = text.strip().lower()
brand = session.exec(select(Brand).where(Brand.canonical_name.ilike(normalized))).first()
if brand:
return brand
for b in session.exec(select(Brand)).all():
if normalized in [a.lower() for a in (b.aliases or [])]:
return b
return None
def list_brands(session: Session) -> list[Brand]:
return list(session.exec(select(Brand).order_by(Brand.canonical_name)).all())
def record_airing(session: Session, brand_id: UUID, timeline_id: UUID,
frame_start: int, frame_end: int,
confidence: float, source: str = "ocr") -> Brand:
brand = session.get(Brand, brand_id)
if not brand:
raise ValueError(f"Brand not found: {brand_id}")
airing = {
"timeline_id": str(timeline_id),
"frame_start": frame_start,
"frame_end": frame_end,
"confidence": confidence,
"source": source,
}
airings = list(brand.airings or [])
airings.append(airing)
brand.airings = airings
brand.total_airings = len(airings)
return brand

43
core/db/checkpoint.py Normal file
View File

@@ -0,0 +1,43 @@
"""Checkpoint queries."""
from __future__ import annotations
from uuid import UUID
from sqlmodel import Session, select
from .models import Checkpoint
def get_latest_checkpoint(session: Session, timeline_id: UUID, parent_id: UUID | None = None) -> Checkpoint | None:
stmt = select(Checkpoint).where(Checkpoint.timeline_id == timeline_id)
if parent_id is not None:
stmt = stmt.where(Checkpoint.parent_id == parent_id)
stmt = stmt.order_by(Checkpoint.created_at.desc())
return session.exec(stmt).first()
def get_root_checkpoint(session: Session, timeline_id: UUID) -> Checkpoint | None:
stmt = select(Checkpoint).where(
Checkpoint.timeline_id == timeline_id,
Checkpoint.parent_id == None,
)
return session.exec(stmt).first()
def list_checkpoints(session: Session, timeline_id: UUID) -> list[Checkpoint]:
stmt = (
select(Checkpoint)
.where(Checkpoint.timeline_id == timeline_id)
.order_by(Checkpoint.created_at)
)
return list(session.exec(stmt).all())
def list_scenarios(session: Session) -> list[Checkpoint]:
stmt = (
select(Checkpoint)
.where(Checkpoint.is_scenario == True)
.order_by(Checkpoint.created_at.desc())
)
return list(session.exec(stmt).all())

34
core/db/connection.py Normal file
View File

@@ -0,0 +1,34 @@
"""
Database engine and session — SQLModel/SQLAlchemy, no Django.
Reads DATABASE_URL from the environment.
"""
from __future__ import annotations
import os
from sqlalchemy import create_engine
from sqlmodel import Session
DATABASE_URL = os.environ.get("DATABASE_URL", "postgresql://mpr:mpr@localhost:5432/mpr")
_engine = None
def get_engine():
global _engine
if _engine is None:
_engine = create_engine(DATABASE_URL, pool_size=5, max_overflow=10)
return _engine
def get_session() -> Session:
return Session(get_engine())
def create_tables():
"""Create all SQLModel tables."""
from sqlmodel import SQLModel
from . import models # noqa — registers all table classes
SQLModel.metadata.create_all(get_engine())

View File

@@ -0,0 +1,143 @@
{
"name": "soccer_broadcast",
"pipeline": {
"name": "soccer_broadcast",
"profile_name": "soccer_broadcast",
"stages": [
{
"name": "extract_frames",
"branch": "trunk"
},
{
"name": "filter_scenes",
"branch": "trunk"
},
{
"name": "field_segmentation",
"branch": "trunk"
},
{
"name": "detect_edges",
"branch": "hoarding"
},
{
"name": "detect_objects",
"branch": "objects"
},
{
"name": "preprocess"
},
{
"name": "run_ocr"
},
{
"name": "match_brands"
},
{
"name": "escalate_vlm"
},
{
"name": "escalate_cloud"
},
{
"name": "compile_report"
}
],
"edges": [
{
"source": "extract_frames",
"target": "filter_scenes"
},
{
"source": "filter_scenes",
"target": "field_segmentation"
},
{
"source": "field_segmentation",
"target": "detect_edges",
"transform": {"invert_mask": true}
},
{
"source": "field_segmentation",
"target": "detect_objects"
},
{
"source": "detect_edges",
"target": "preprocess"
},
{
"source": "detect_objects",
"target": "preprocess"
},
{
"source": "preprocess",
"target": "run_ocr"
},
{
"source": "run_ocr",
"target": "match_brands"
},
{
"source": "match_brands",
"target": "escalate_vlm"
},
{
"source": "escalate_vlm",
"target": "escalate_cloud"
},
{
"source": "escalate_cloud",
"target": "compile_report"
}
]
},
"configs": {
"extract_frames": {
"fps": 2.0,
"max_frames": 500
},
"filter_scenes": {
"hamming_threshold": 8,
"enabled": true
},
"field_segmentation": {
"enabled": true,
"hue_low": 30,
"hue_high": 85,
"sat_low": 30,
"sat_high": 255,
"val_low": 30,
"val_high": 255,
"morph_kernel": 15,
"min_area_ratio": 0.05
},
"detect_edges": {
"enabled": true,
"edge_canny_low": 50,
"edge_canny_high": 150,
"edge_hough_threshold": 80,
"edge_hough_min_length": 100,
"edge_hough_max_gap": 10,
"edge_pair_max_distance": 200,
"edge_pair_min_distance": 15
},
"detect_objects": {
"model_name": "yolov8n.pt",
"confidence_threshold": 0.3,
"target_classes": []
},
"run_ocr": {
"languages": [
"en",
"es"
],
"min_confidence": 0.5
},
"match_brands": {
"fuzzy_threshold": 75
},
"escalate_vlm": {
"vlm_prompt_template": "Identify the brand or sponsor visible in this cropped region from a soccer broadcast.{hint}{text} Respond with: brand, confidence (0-1), reasoning."
}
}
}

80
core/db/job.py Normal file
View File

@@ -0,0 +1,80 @@
"""Job queries."""
from __future__ import annotations
from datetime import datetime
from typing import Optional
from uuid import UUID
from sqlmodel import Session, select
from .models import Job
def create_job(
session: Session,
source_asset_id: UUID,
video_path: str,
timeline_id: UUID,
profile_name: str = "soccer_broadcast",
run_type: str = "initial",
parent_id: UUID | None = None,
config_overrides: dict | None = None,
) -> Job:
job = Job(
source_asset_id=source_asset_id,
video_path=video_path,
timeline_id=timeline_id,
profile_name=profile_name,
run_type=run_type,
parent_id=parent_id,
config_overrides=config_overrides or {},
status="pending",
)
session.add(job)
session.commit()
session.refresh(job)
return job
def update_job_status(
session: Session,
job_id: UUID,
status: str,
current_stage: str | None = None,
error_message: str | None = None,
):
job = session.get(Job, job_id)
if not job:
return
job.status = status
if current_stage is not None:
job.current_stage = current_stage
if error_message is not None:
job.error_message = error_message
if status == "running" and not job.started_at:
job.started_at = datetime.utcnow()
if status in ("completed", "failed", "cancelled"):
job.completed_at = datetime.utcnow()
session.commit()
def get_job(session: Session, job_id: UUID) -> Job | None:
return session.get(Job, job_id)
def list_jobs(
session: Session,
timeline_id: UUID | None = None,
parent_id: UUID | None = None,
status: str | None = None,
) -> list[Job]:
stmt = select(Job)
if timeline_id:
stmt = stmt.where(Job.timeline_id == timeline_id)
if parent_id:
stmt = stmt.where(Job.parent_id == parent_id)
if status:
stmt = stmt.where(Job.status == status)
stmt = stmt.order_by(Job.created_at.desc())
return list(session.exec(stmt).all())

179
core/db/models.py Normal file
View File

@@ -0,0 +1,179 @@
"""
SQLModel Table Models - GENERATED FILE
Do not edit directly. Regenerate using modelgen.
"""
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID, uuid4
from sqlmodel import SQLModel, Field, Column
from sqlalchemy import JSON
class AssetStatus(str, Enum):
PENDING = "pending"
READY = "ready"
ERROR = "error"
class JobStatus(str, Enum):
PENDING = "pending"
RUNNING = "running"
PAUSED = "paused"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class RunType(str, Enum):
INITIAL = "initial"
REPLAY = "replay"
RETRY = "retry"
class BrandSource(str, Enum):
OCR = "ocr"
VLM = "local_vlm"
CLOUD = "cloud_llm"
MANUAL = "manual"
class SourceType(str, Enum):
CHUNK_JOB = "chunk_job"
UPLOAD = "upload"
DEVICE = "device"
STREAM = "stream"
class MediaAsset(SQLModel, table=True):
"""A video/audio file registered in the system."""
__tablename__ = "media_asset"
id: UUID = Field(default_factory=uuid4, primary_key=True)
filename: str
file_path: str
status: AssetStatus = "pending"
error_message: Optional[str] = None
file_size: Optional[int] = None
duration: Optional[float] = None
video_codec: Optional[str] = None
audio_codec: Optional[str] = None
width: Optional[int] = None
height: Optional[int] = None
framerate: Optional[float] = None
bitrate: Optional[int] = None
properties: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
comments: str = ""
tags: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
class TranscodePreset(SQLModel, table=True):
"""A reusable transcoding configuration (like Handbrake presets)."""
__tablename__ = "transcode_preset"
id: UUID = Field(default_factory=uuid4, primary_key=True)
name: str
description: str = ""
is_builtin: bool = False
container: str = "mp4"
video_codec: str = "libx264"
video_bitrate: Optional[str] = None
video_crf: Optional[int] = None
video_preset: Optional[str] = None
resolution: Optional[str] = None
framerate: Optional[float] = None
audio_codec: str = "aac"
audio_bitrate: Optional[str] = None
audio_channels: Optional[int] = None
audio_samplerate: Optional[int] = None
extra_args: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
class Job(SQLModel, table=True):
"""A pipeline job."""
__tablename__ = "job"
id: UUID = Field(default_factory=uuid4, primary_key=True)
source_asset_id: UUID = Field(index=True)
video_path: str
profile_name: str = "soccer_broadcast"
timeline_id: Optional[UUID] = None
parent_id: Optional[UUID] = None
run_type: RunType = "initial"
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
status: JobStatus = "pending"
current_stage: Optional[str] = None
progress: float = 0.0
error_message: Optional[str] = None
total_detections: int = 0
brands_found: int = 0
cloud_llm_calls: int = 0
estimated_cost_usd: float = 0.0
priority: int = 0
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
class Timeline(SQLModel, table=True):
"""A user-created selection of source material."""
__tablename__ = "timeline"
id: UUID = Field(default_factory=uuid4, primary_key=True)
name: str = ""
source_asset_id: Optional[UUID] = Field(default=None, index=True)
chunk_paths: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
profile_name: str = ""
status: str = "created"
fps: float = 2.0
frame_count: int = 0
source_ephemeral: bool = False
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
class Checkpoint(SQLModel, table=True):
"""A snapshot of pipeline state on a timeline."""
__tablename__ = "checkpoint"
id: UUID = Field(default_factory=uuid4, primary_key=True)
timeline_id: UUID
job_id: Optional[UUID] = Field(default=None, index=True)
parent_id: Optional[UUID] = None
stage_name: str = ""
config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
stats: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
is_scenario: bool = False
scenario_label: str = ""
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
class StageOutput(SQLModel, table=True):
"""Output of a single stage within a job."""
__tablename__ = "stage_output"
id: UUID = Field(default_factory=uuid4, primary_key=True)
job_id: UUID = Field(index=True)
timeline_id: UUID
stage_name: str
checkpoint_id: Optional[UUID] = None
output: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
class Brand(SQLModel, table=True):
"""A brand discovered or registered in the system."""
__tablename__ = "brand"
id: UUID = Field(default_factory=uuid4, primary_key=True)
canonical_name: str = Field(index=True)
aliases: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
source: BrandSource = "ocr"
confirmed: bool = False
airings: List[Dict[str, Any]] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
total_airings: int = 0
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
class Profile(SQLModel, table=True):
"""A content type profile."""
__tablename__ = "profile"
id: UUID = Field(default_factory=uuid4, primary_key=True)
name: str
pipeline: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
configs: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))

43
core/db/seed.py Normal file
View File

@@ -0,0 +1,43 @@
"""
Seed data — insert initial profile rows if they don't exist.
Called on startup after create_tables().
"""
import json
import logging
from pathlib import Path
logger = logging.getLogger(__name__)
SEED_DIR = Path(__file__).parent / "fixtures"
def seed_profiles():
"""Insert seed profiles from JSON fixtures if not already present."""
from .connection import get_session
from .models import Profile
fixtures = list(SEED_DIR.glob("*.json"))
if not fixtures:
return
with get_session() as session:
for f in fixtures:
data = json.loads(f.read_text())
name = data["name"]
existing = session.query(Profile).filter(Profile.name == name).first()
if existing:
logger.debug("Profile %s already exists, skipping seed", name)
continue
profile = Profile(
name=name,
pipeline=data.get("pipeline", {}),
configs=data.get("configs", {}),
)
session.add(profile)
logger.info("Seeded profile: %s", name)
session.commit()

0
core/detect/__init__.py Normal file
View File

View File

@@ -0,0 +1,31 @@
"""
Checkpoint system — Timeline + Checkpoint tree + StageOutput.
detect/checkpoint/
frames.py — per-timeline frame cache (local filesystem)
storage.py — Timeline, Checkpoint, StageOutput persistence
replay.py — replay from checkpoint (TODO: rework in 5d)
runner_bridge.py — checkpoint hook for PipelineRunner
"""
from .storage import (
create_timeline,
get_timeline,
update_timeline_status,
save_checkpoint,
get_checkpoints_for_job,
get_checkpoints_for_timeline,
save_stage_output,
load_stage_output,
load_stage_outputs_for_job,
load_stage_outputs_for_timeline,
)
from .frames import (
cache_exists,
cache_frames,
load_cached_frames,
load_cached_frames_b64,
clear_cache,
frames_to_b64,
)
from .runner_bridge import checkpoint_after_stage, reset_checkpoint_state, get_latest_checkpoint

View File

@@ -0,0 +1,281 @@
"""
Frame cache — per-timeline frame storage in blob storage (S3/MinIO).
Frames are extracted from chunks once, cached as JPEGs at
cache/timelines/{timeline_id}/frames/{seq}.jpg in the app's
blob storage. Any job on the timeline reads from the cache.
Cache is clearable and rebuildable from chunks.
Uses the same storage backend as the rest of the app, so it
works across lambdas, GPU boxes, and local dev.
"""
from __future__ import annotations
import base64
import io
import logging
import os
import tempfile
import numpy as np
from PIL import Image
from core.detect.models import Frame
logger = logging.getLogger(__name__)
BUCKET = os.environ.get("S3_BUCKET", "mpr")
CACHE_PREFIX = "cache/timelines"
def _frame_key(timeline_id: str, seq: int) -> str:
return f"{CACHE_PREFIX}/{timeline_id}/frames/{seq}.jpg"
def _list_prefix(timeline_id: str) -> str:
return f"{CACHE_PREFIX}/{timeline_id}/frames/"
def cache_exists(timeline_id: str) -> bool:
"""Check if frame cache exists for a timeline."""
from core.storage.s3 import list_objects
objects = list_objects(BUCKET, _list_prefix(timeline_id))
return len(objects) > 0
def cache_frames(timeline_id: str, frames: list[Frame], quality: int = 85) -> int:
"""
Write frames to blob storage as JPEGs.
Returns number of frames cached.
"""
from core.storage.s3 import upload_file
for frame in frames:
key = _frame_key(timeline_id, frame.sequence)
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
img = Image.fromarray(frame.image)
img.save(tmp, format="JPEG", quality=quality)
tmp_path = tmp.name
try:
upload_file(tmp_path, BUCKET, key)
finally:
os.unlink(tmp_path)
logger.info("Cached %d frames for timeline %s", len(frames), timeline_id)
return len(frames)
def load_cached_frames(timeline_id: str) -> list[Frame]:
"""
Load all cached frames as Frame objects with numpy arrays.
Returns empty list if cache doesn't exist.
"""
from core.storage.s3 import list_objects, download_to_temp
objects = list_objects(BUCKET, _list_prefix(timeline_id))
if not objects:
return []
frames = []
for obj in objects:
key = obj["key"]
filename = key.rsplit("/", 1)[-1]
if not filename.endswith(".jpg"):
continue
seq = int(filename.replace(".jpg", ""))
tmp_path = download_to_temp(BUCKET, key)
try:
img = Image.open(tmp_path).convert("RGB")
image_array = np.array(img)
finally:
os.unlink(tmp_path)
frame = Frame(
sequence=seq,
chunk_id=0,
timestamp=0.0,
image=image_array,
perceptual_hash="",
)
frames.append(frame)
frames.sort(key=lambda f: f.sequence)
return frames
def load_cached_frames_b64(timeline_id: str) -> list[dict]:
"""
Load cached frames as base64 JPEGs for the UI.
Returns list of {seq, timestamp, jpeg_b64}.
"""
from core.storage.s3 import list_objects, download_to_temp
objects = list_objects(BUCKET, _list_prefix(timeline_id))
if not objects:
return []
result = []
for obj in objects:
key = obj["key"]
filename = key.rsplit("/", 1)[-1]
if not filename.endswith(".jpg"):
continue
seq = int(filename.replace(".jpg", ""))
tmp_path = download_to_temp(BUCKET, key)
try:
with open(tmp_path, "rb") as f:
jpeg_b64 = base64.b64encode(f.read()).decode()
finally:
os.unlink(tmp_path)
result.append({
"seq": seq,
"timestamp": 0.0,
"jpeg_b64": jpeg_b64,
})
result.sort(key=lambda f: f["seq"])
return result
# ---------------------------------------------------------------------------
# Debug overlay storage — per job/stage/frame
# ---------------------------------------------------------------------------
def _overlay_prefix(timeline_id: str, job_id: str, stage: str) -> str:
return f"{CACHE_PREFIX}/{timeline_id}/overlays/{job_id}/{stage}/"
def _overlay_key(timeline_id: str, job_id: str, stage: str, seq: int, name: str) -> str:
return f"{CACHE_PREFIX}/{timeline_id}/overlays/{job_id}/{stage}/{seq}_{name}.png"
def save_overlays(
timeline_id: str,
job_id: str,
stage: str,
seq: int,
overlays: dict[str, str],
):
"""
Save debug overlay images (base64 PNG) to blob storage.
overlays: {overlay_key: base64_png_string}
e.g. {"edge_overlay_b64": "iVBOR...", "lines_overlay_b64": "iVBOR..."}
"""
from core.storage.s3 import upload_file
import tempfile
for name, b64_data in overlays.items():
key = _overlay_key(timeline_id, job_id, stage, seq, name)
raw = base64.b64decode(b64_data)
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
tmp.write(raw)
tmp_path = tmp.name
try:
upload_file(tmp_path, BUCKET, key)
finally:
os.unlink(tmp_path)
logger.info("Saved %d overlays for timeline %s job %s stage %s frame %d",
len(overlays), timeline_id, job_id, stage, seq)
def load_overlays(
timeline_id: str,
job_id: str,
stage: str,
seq: int,
) -> dict[str, str] | None:
"""
Load debug overlay images from blob storage as base64 strings.
Returns {overlay_key: base64_png_string} or None if no overlays cached.
"""
from core.storage.s3 import list_objects, download_to_temp
prefix = _overlay_prefix(timeline_id, job_id, stage)
seq_prefix = f"{seq}_"
objects = list_objects(BUCKET, prefix)
overlays = {}
for obj in objects:
filename = obj["key"].rsplit("/", 1)[-1]
if not filename.startswith(seq_prefix):
continue
name = filename[len(seq_prefix):].replace(".png", "")
tmp_path = download_to_temp(BUCKET, obj["key"])
try:
with open(tmp_path, "rb") as f:
overlays[name] = base64.b64encode(f.read()).decode()
finally:
os.unlink(tmp_path)
return overlays if overlays else None
def list_overlay_frames(
timeline_id: str,
job_id: str,
stage: str,
) -> list[int]:
"""List frame sequences that have cached overlays."""
from core.storage.s3 import list_objects
prefix = _overlay_prefix(timeline_id, job_id, stage)
objects = list_objects(BUCKET, prefix)
seqs = set()
for obj in objects:
filename = obj["key"].rsplit("/", 1)[-1]
seq_str = filename.split("_")[0]
try:
seqs.add(int(seq_str))
except ValueError:
continue
return sorted(seqs)
def clear_cache(timeline_id: str):
"""Delete the frame cache for a timeline."""
from core.storage.s3 import delete_objects
prefix = _list_prefix(timeline_id)
delete_objects(BUCKET, prefix)
logger.info("Cleared frame cache for timeline %s", timeline_id)
def frames_to_b64(frames: list[Frame], quality: int = 75) -> list[dict]:
"""
Convert in-memory Frame objects to base64 JPEG dicts.
For API responses when frames are already in memory.
"""
result = []
for frame in frames:
buf = io.BytesIO()
img = Image.fromarray(frame.image)
img.save(buf, format="JPEG", quality=quality)
jpeg_b64 = base64.b64encode(buf.getvalue()).decode()
result.append({
"seq": frame.sequence,
"timestamp": frame.timestamp,
"jpeg_b64": jpeg_b64,
})
result.sort(key=lambda f: f["seq"])
return result

View File

@@ -0,0 +1,307 @@
"""
Pipeline replay — re-run from any stage with different config.
Loads stage outputs from DB, frames from timeline cache,
reconstitutes state, and runs from a target stage onward.
Creates a new Job (run_type=REPLAY) for each replay invocation.
"""
from __future__ import annotations
import logging
import os
import uuid
from core.detect import emit
from core.detect.graph import NODES, get_pipeline
from core.detect.graph.runner import PipelineRunner
logger = logging.getLogger(__name__)
def _build_state_for_replay(
job_id: str,
up_to_stage: str,
) -> dict:
"""
Reconstitute pipeline state from a completed job's stage outputs,
up to (but not including) the target stage.
Loads frames from timeline cache + stage outputs from DB.
"""
from .storage import load_stage_outputs_for_job, get_checkpoints_for_job
from .frames import load_cached_frames
from core.db.connection import get_session
from core.db.job import get_job
# Load the job to get timeline_id and profile
with get_session() as session:
job = get_job(session, uuid.UUID(job_id))
if not job:
raise ValueError(f"Job not found: {job_id}")
timeline_id = str(job.timeline_id) if job.timeline_id else ""
if not timeline_id:
raise ValueError(f"Job {job_id} has no timeline")
# Load frames from timeline cache
frames = load_cached_frames(timeline_id)
if not frames:
raise ValueError(f"No cached frames for timeline {timeline_id}. Run the pipeline first.")
# Load all stage outputs for this job
all_outputs = load_stage_outputs_for_job(job_id)
# Build state with envelope + frames
state = {
"job_id": job_id,
"timeline_id": timeline_id,
"video_path": job.video_path,
"profile_name": job.profile_name,
"source_asset_id": str(job.source_asset_id),
"frames": frames,
"config_overrides": {},
}
# Apply stage outputs in pipeline order, up to the target stage
target_idx = NODES.index(up_to_stage)
for stage_name in NODES[:target_idx]:
output = all_outputs.get(stage_name)
if output:
# Stage outputs contain serialized data — merge into state
# The stage registry's deserialize_fn can reconstitute if needed
for key, value in output.items():
state[key] = value
# Filtered frames: reconstruct from sequence list if present
filtered_seqs = state.get("filtered_frame_sequences")
if filtered_seqs:
seq_set = set(filtered_seqs)
state["filtered_frames"] = [f for f in frames if f.sequence in seq_set]
elif "filtered_frames" not in state:
state["filtered_frames"] = frames
return state
def replay_from(
job_id: str,
start_stage: str,
config_overrides: dict | None = None,
checkpoint: bool = True,
) -> dict:
"""
Replay the pipeline from a specific stage.
Loads state from the original job's stage outputs up to start_stage,
applies config overrides, and runs from start_stage onward.
Creates a new Job (run_type=REPLAY).
Returns the final state dict.
"""
if start_stage not in NODES:
raise ValueError(f"Unknown stage: {start_stage!r}. Options: {NODES}")
start_idx = NODES.index(start_stage)
if start_idx == 0:
raise ValueError("Cannot replay from the first stage — just run the full pipeline")
logger.info("Replaying job %s from %s", job_id, start_stage)
state = _build_state_for_replay(job_id, start_stage)
# Apply config overrides
if config_overrides:
state["config_overrides"] = config_overrides
# Create replay job
from core.db.connection import get_session
from core.db.job import create_job, get_job
with get_session() as session:
original = get_job(session, uuid.UUID(job_id))
replay_job = create_job(
session,
source_asset_id=original.source_asset_id,
video_path=original.video_path,
timeline_id=original.timeline_id,
profile_name=original.profile_name,
run_type="replay",
parent_id=original.id,
config_overrides=config_overrides,
)
replay_job_id = str(replay_job.id)
# Update state with new job ID
state["job_id"] = replay_job_id
# Set run context for SSE events
emit.set_run_context(
run_id=replay_job_id,
parent_job_id=job_id,
run_type="replay",
)
# Run from start_stage onward
pipeline = get_pipeline(
checkpoint=checkpoint,
profile_name=state["profile_name"],
start_from=start_stage,
)
try:
result = pipeline.invoke(state)
finally:
emit.clear_run_context()
return result
def replay_single_stage(
job_id: str,
stage: str,
frame_refs: list[int] | None = None,
config_overrides: dict | None = None,
debug: bool = False,
) -> dict:
"""
Replay a single stage on specific frames (or all frames from checkpoint).
Fast path for interactive parameter tuning — runs only the target stage
function, not the full pipeline tail. Returns the stage output directly.
"""
if stage not in NODES:
raise ValueError(f"Unknown stage: {stage!r}. Options: {NODES}")
stage_idx = NODES.index(stage)
if stage_idx == 0:
raise ValueError("Cannot replay the first stage — just run the full pipeline")
logger.info("Single-stage replay: job %s, stage %s (debug=%s)", job_id, stage, debug)
state = _build_state_for_replay(job_id, stage)
# Build profile with overrides
from core.detect.profile import get_profile, get_stage_config
profile = get_profile(state.get("profile_name", "soccer_broadcast"))
if config_overrides:
merged_configs = dict(profile.get("configs", {}))
for sname, soverrides in config_overrides.items():
if sname in merged_configs:
merged_configs[sname] = {**merged_configs[sname], **soverrides}
else:
merged_configs[sname] = soverrides
profile = {**profile, "configs": merged_configs}
# Subset frames if requested
frames = state.get("filtered_frames", state.get("frames", []))
if frame_refs:
ref_set = set(frame_refs)
frames = [f for f in frames if f.sequence in ref_set]
# Run the specific stage
if stage == "detect_edges":
return _replay_detect_edges(state, profile, frames, job_id, debug)
elif stage == "field_segmentation":
return _replay_field_segmentation(state, profile, frames, job_id, debug)
else:
raise ValueError(
f"Single-stage replay not yet implemented for {stage!r}. "
f"Use replay_from() for full pipeline replay."
)
def _replay_detect_edges(
state: dict,
profile,
frames: list,
job_id: str,
debug: bool,
) -> dict:
"""Run edge detection on checkpoint frames, optionally with debug overlays."""
from core.detect.stages.edge_detector import detect_edge_regions
from core.detect.profile import get_stage_config
from core.detect.stages.models import RegionAnalysisConfig
config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
inference_url = os.environ.get("INFERENCE_URL")
field_masks = state.get("field_masks", {})
result = detect_edge_regions(
frames=frames,
config=config,
inference_url=inference_url,
job_id=job_id,
field_masks=field_masks,
)
output = {"edge_regions_by_frame": result}
if debug and frames:
debug_data = {}
if inference_url:
from core.detect.inference import InferenceClient
client = InferenceClient(base_url=inference_url, job_id=job_id)
for frame in frames:
dr = client.detect_edges_debug(
image=frame.image,
edge_canny_low=config.edge_canny_low,
edge_canny_high=config.edge_canny_high,
edge_hough_threshold=config.edge_hough_threshold,
edge_hough_min_length=config.edge_hough_min_length,
edge_hough_max_gap=config.edge_hough_max_gap,
edge_pair_max_distance=config.edge_pair_max_distance,
edge_pair_min_distance=config.edge_pair_min_distance,
)
debug_data[frame.sequence] = {
"edge_overlay_b64": dr.edge_overlay_b64,
"lines_overlay_b64": dr.lines_overlay_b64,
"horizontal_count": dr.horizontal_count,
"pair_count": dr.pair_count,
}
else:
from core.detect.stages.edge_detector import _load_cv_edges
edges_mod = _load_cv_edges()
for frame in frames:
dr = edges_mod.detect_edges_debug(
frame.image,
canny_low=config.edge_canny_low,
canny_high=config.edge_canny_high,
hough_threshold=config.edge_hough_threshold,
hough_min_length=config.edge_hough_min_length,
hough_max_gap=config.edge_hough_max_gap,
pair_max_distance=config.edge_pair_max_distance,
pair_min_distance=config.edge_pair_min_distance,
)
debug_data[frame.sequence] = {
"edge_overlay_b64": dr["edge_overlay_b64"],
"lines_overlay_b64": dr["lines_overlay_b64"],
"horizontal_count": dr["horizontal_count"],
"pair_count": dr["pair_count"],
}
output["debug"] = debug_data
return output
def _replay_field_segmentation(
state: dict,
profile,
frames: list,
job_id: str,
debug: bool,
) -> dict:
"""Run field segmentation on checkpoint frames."""
from core.detect.stages.field_segmentation import run_field_segmentation
from core.detect.profile import get_stage_config
from core.detect.stages.models import FieldSegmentationConfig
config = FieldSegmentationConfig(**get_stage_config(profile, "field_segmentation"))
inference_url = os.environ.get("INFERENCE_URL")
result = run_field_segmentation(
frames=frames,
config=config,
inference_url=inference_url,
job_id=job_id,
)
return result

View File

@@ -0,0 +1,99 @@
"""
Runner bridge — checkpoint hook called by PipelineRunner after each stage.
Saves a checkpoint + stage output after each stage completes.
Timeline and Job are independent: timeline_id and job_id come from
the pipeline state (set at job creation time).
"""
from __future__ import annotations
import logging
logger = logging.getLogger(__name__)
# Per-job state: tracks the latest checkpoint so we can chain parent → child
_latest_checkpoint: dict[str, str] = {}
def reset_checkpoint_state(job_id: str):
"""Clean up per-job checkpoint state. Called when pipeline finishes."""
_latest_checkpoint.pop(job_id, None)
def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: dict):
"""
Save a checkpoint + stage output after a stage completes.
Called by the runner. Handles:
- Stage output serialization (via stage registry)
- Checkpoint chain (parent → child)
- Stage output as separate row in StageOutput table
"""
if not job_id:
return
timeline_id = state.get("timeline_id", "")
if not timeline_id:
logger.warning("No timeline_id in state for job %s, skipping checkpoint", job_id)
return
from .storage import save_checkpoint, save_stage_output
from core.detect.stages.base import _REGISTRY, _LEGACY_REGISTRY
merged = {**state, **result}
# Serialize stage output using the stage's serialize_fn if available
# Check new-style registry first, then legacy (some stages are in both)
serialize_fn = None
stage_cls = _REGISTRY.get(stage_name)
if stage_cls:
serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
if not serialize_fn:
legacy = _LEGACY_REGISTRY.get(stage_name)
if legacy:
serialize_fn = legacy.serialize_fn
if serialize_fn:
output_json = serialize_fn(merged, job_id)
else:
output_json = {}
# Convert stats dataclass to dict for JSONB storage
import dataclasses
raw_stats = state.get("stats", {})
if dataclasses.is_dataclass(raw_stats):
stats_dict = dataclasses.asdict(raw_stats)
elif isinstance(raw_stats, dict):
stats_dict = raw_stats
else:
stats_dict = {}
# Save checkpoint (lightweight tree node)
parent_id = _latest_checkpoint.get(job_id)
checkpoint_id = save_checkpoint(
timeline_id=timeline_id,
stage_name=stage_name,
parent_checkpoint_id=parent_id,
config_overrides=state.get("config_overrides"),
stats=stats_dict,
job_id=job_id,
)
_latest_checkpoint[job_id] = checkpoint_id
# Save stage output (separate row, upsert by job+stage)
if output_json:
save_stage_output(
job_id=job_id,
timeline_id=timeline_id,
stage_name=stage_name,
output=output_json,
checkpoint_id=checkpoint_id,
)
logger.info("Checkpoint %s + output for stage %s (job %s)", checkpoint_id, stage_name, job_id)
def get_latest_checkpoint(job_id: str) -> str | None:
"""Get the latest checkpoint_id for a running job."""
return _latest_checkpoint.get(job_id)

View File

@@ -0,0 +1,109 @@
"""
State serialization — DetectState ↔ JSON-compatible dict.
Delegates to each stage's serialize_fn/deserialize_fn via the registry.
This file has no model-specific knowledge — stages own their data format.
The only things serialized here are the "envelope" fields (job_id, video_path, etc.)
that don't belong to any stage.
Frames are ephemeral (in-memory during a run). Serialization stores
metadata only; frames are re-extracted from chunks when needed.
"""
from __future__ import annotations
from core.schema.serializers._common import serialize_dataclass
from core.schema.serializers.pipeline import (
deserialize_pipeline_stats,
deserialize_text_candidates,
)
# Envelope fields — not owned by any stage, always present
ENVELOPE_KEYS = ["job_id", "video_path", "profile_name", "timeline_id", "config_overrides"]
def serialize_state(state: dict) -> dict:
"""
Serialize DetectState to a JSON-compatible dict.
Calls each registered stage's serialize_fn for stage-owned data.
Envelope fields (job_id, etc.) are copied directly.
"""
from core.detect.stages.base import _REGISTRY
checkpoint = {}
# Envelope
for key in ENVELOPE_KEYS:
default = {} if key == "config_overrides" else ""
checkpoint[key] = state.get(key, default)
# Stats (shared across stages, not owned by one)
stats = state.get("stats")
if stats is not None:
checkpoint["stats"] = serialize_dataclass(stats)
else:
checkpoint["stats"] = {}
# Per-stage data
for name, stage_def in _REGISTRY.items():
if stage_def.serialize_fn is None:
continue
job_id = state.get("job_id", "")
stage_data = stage_def.serialize_fn(state, job_id)
checkpoint[f"stage_{name}"] = stage_data
return checkpoint
def deserialize_state(checkpoint: dict, frames: list) -> dict:
"""
Reconstitute DetectState from a checkpoint dict + frames.
Frames are provided by the caller (re-extracted from chunks).
Calls each stage's deserialize_fn to restore stage-owned data.
"""
from core.detect.stages.base import _REGISTRY
frame_map = {f.sequence: f for f in frames}
state = {}
# Envelope
for key in ENVELOPE_KEYS:
default = {} if key == "config_overrides" else ""
state[key] = checkpoint.get(key, default)
# Frames (provided externally, ephemeral)
state["frames"] = frames
# Stats
state["stats"] = deserialize_pipeline_stats(checkpoint.get("stats", {}))
# Per-stage data
for name, stage_def in _REGISTRY.items():
if stage_def.deserialize_fn is None:
continue
stage_key = f"stage_{name}"
if stage_key not in checkpoint:
continue
job_id = state.get("job_id", "")
stage_data = stage_def.deserialize_fn(checkpoint[stage_key], job_id)
for k, v in stage_data.items():
if k == "_filtered_sequences":
# Reconnect filtered frames from sequence list
seq_set = set(v)
state["filtered_frames"] = [f for f in frames if f.sequence in seq_set]
elif k.endswith("_raw"):
# Raw text candidates need frame reference reconnection
real_key = k.removeprefix("_").removesuffix("_raw")
state[real_key] = deserialize_text_candidates(v, frame_map)
else:
state[k] = v
return state

View File

@@ -0,0 +1,303 @@
"""
Checkpoint storage — Timeline, Checkpoint, StageOutput persistence.
Timeline: user-created source selection (chunk paths)
Checkpoint: lightweight tree node (parent_id, stage_name, config, stats)
StageOutput: per-stage result (flat table, one row per job+stage)
"""
from __future__ import annotations
import logging
from uuid import UUID
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Timeline
# ---------------------------------------------------------------------------
def create_timeline(
chunk_paths: list[str],
profile_name: str = "",
name: str = "",
source_asset_id: UUID | None = None,
fps: float = 2.0,
) -> str:
"""
Create a timeline from a chunk selection.
Called by the user (via API) before any pipeline runs.
Returns timeline_id.
"""
from core.db.models import Timeline
from core.db.connection import get_session
with get_session() as session:
timeline = Timeline(
name=name,
chunk_paths=chunk_paths,
profile_name=profile_name,
source_asset_id=source_asset_id,
fps=fps,
status="created",
)
session.add(timeline)
session.commit()
session.refresh(timeline)
tid = str(timeline.id)
logger.info("Timeline created: %s (%d chunks)", tid, len(chunk_paths))
return tid
def get_timeline(timeline_id: str) -> dict:
"""Load a timeline as a dict."""
from core.db.models import Timeline
from core.db.connection import get_session
with get_session() as session:
timeline = session.get(Timeline, UUID(timeline_id))
if not timeline:
raise ValueError(f"Timeline not found: {timeline_id}")
return {
"id": str(timeline.id),
"name": timeline.name,
"chunk_paths": timeline.chunk_paths,
"profile_name": timeline.profile_name,
"status": timeline.status,
"fps": timeline.fps,
"source_asset_id": str(timeline.source_asset_id) if timeline.source_asset_id else None,
"created_at": str(timeline.created_at) if timeline.created_at else None,
}
def update_timeline_status(timeline_id: str, status: str, frame_count: int | None = None):
"""Update timeline status and optionally frame count."""
from core.db.models import Timeline
from core.db.connection import get_session
with get_session() as session:
timeline = session.get(Timeline, UUID(timeline_id))
if timeline:
timeline.status = status
if frame_count is not None:
timeline.frame_count = frame_count
session.commit()
# ---------------------------------------------------------------------------
# Checkpoint
# ---------------------------------------------------------------------------
def save_checkpoint(
timeline_id: str,
stage_name: str,
parent_checkpoint_id: str | None = None,
config_overrides: dict | None = None,
stats: dict | None = None,
is_scenario: bool = False,
scenario_label: str = "",
job_id: str | None = None,
) -> str:
"""
Save a checkpoint (lightweight tree node).
No stage outputs — those go in StageOutput table separately.
Returns the new checkpoint ID.
"""
from core.db.models import Checkpoint
from core.db.connection import get_session
with get_session() as session:
checkpoint = Checkpoint(
timeline_id=UUID(timeline_id),
job_id=UUID(job_id) if job_id else None,
parent_id=UUID(parent_checkpoint_id) if parent_checkpoint_id else None,
stage_name=stage_name,
config_overrides=config_overrides or {},
stats=stats or {},
is_scenario=is_scenario,
scenario_label=scenario_label,
)
session.add(checkpoint)
session.commit()
session.refresh(checkpoint)
cid = str(checkpoint.id)
logger.info("Checkpoint saved: %s (timeline %s, stage %s, parent %s)",
cid, timeline_id, stage_name, parent_checkpoint_id)
return cid
def get_checkpoints_for_job(job_id: str) -> list[dict]:
"""List checkpoints for a job, ordered by creation time."""
from sqlmodel import select
from core.db.models import Checkpoint
from core.db.connection import get_session
with get_session() as session:
stmt = (
select(Checkpoint)
.where(Checkpoint.job_id == UUID(job_id))
.order_by(Checkpoint.created_at)
)
checkpoints = session.exec(stmt).all()
return [
{
"id": str(c.id),
"timeline_id": str(c.timeline_id),
"job_id": str(c.job_id) if c.job_id else None,
"parent_id": str(c.parent_id) if c.parent_id else None,
"stage_name": c.stage_name,
"config_overrides": c.config_overrides or {},
"stats": c.stats or {},
"is_scenario": c.is_scenario,
"scenario_label": c.scenario_label,
"created_at": str(c.created_at) if c.created_at else None,
}
for c in checkpoints
]
def get_checkpoints_for_timeline(timeline_id: str) -> list[dict]:
"""List all checkpoints on a timeline, ordered by creation time."""
from sqlmodel import select
from core.db.models import Checkpoint
from core.db.connection import get_session
with get_session() as session:
stmt = (
select(Checkpoint)
.where(Checkpoint.timeline_id == UUID(timeline_id))
.order_by(Checkpoint.created_at)
)
checkpoints = session.exec(stmt).all()
return [
{
"id": str(c.id),
"timeline_id": str(c.timeline_id),
"job_id": str(c.job_id) if c.job_id else None,
"parent_id": str(c.parent_id) if c.parent_id else None,
"stage_name": c.stage_name,
"config_overrides": c.config_overrides or {},
"stats": c.stats or {},
"is_scenario": c.is_scenario,
"scenario_label": c.scenario_label,
"created_at": str(c.created_at) if c.created_at else None,
}
for c in checkpoints
]
# ---------------------------------------------------------------------------
# StageOutput
# ---------------------------------------------------------------------------
def save_stage_output(
job_id: str,
timeline_id: str,
stage_name: str,
output: dict,
checkpoint_id: str | None = None,
) -> str:
"""
Save (upsert) a stage output. One row per (job_id, stage_name).
Returns the stage_output ID.
"""
from sqlmodel import select
from core.db.models import StageOutput
from core.db.connection import get_session
with get_session() as session:
# Upsert: check if exists
stmt = (
select(StageOutput)
.where(StageOutput.job_id == UUID(job_id))
.where(StageOutput.stage_name == stage_name)
)
existing = session.exec(stmt).first()
if existing:
existing.output = output
existing.checkpoint_id = UUID(checkpoint_id) if checkpoint_id else None
session.commit()
session.refresh(existing)
return str(existing.id)
stage_output = StageOutput(
job_id=UUID(job_id),
timeline_id=UUID(timeline_id),
stage_name=stage_name,
checkpoint_id=UUID(checkpoint_id) if checkpoint_id else None,
output=output,
)
session.add(stage_output)
session.commit()
session.refresh(stage_output)
return str(stage_output.id)
def load_stage_output(job_id: str, stage_name: str) -> dict | None:
"""Load a stage's output by job + stage name."""
from sqlmodel import select
from core.db.models import StageOutput
from core.db.connection import get_session
with get_session() as session:
stmt = (
select(StageOutput)
.where(StageOutput.job_id == UUID(job_id))
.where(StageOutput.stage_name == stage_name)
)
row = session.exec(stmt).first()
if not row:
return None
return row.output
def load_stage_outputs_for_job(job_id: str) -> dict[str, dict]:
"""Load all stage outputs for a job. Returns {stage_name: output}."""
from sqlmodel import select
from core.db.models import StageOutput
from core.db.connection import get_session
with get_session() as session:
stmt = (
select(StageOutput)
.where(StageOutput.job_id == UUID(job_id))
)
rows = session.exec(stmt).all()
return {row.stage_name: row.output for row in rows}
def load_stage_outputs_for_timeline(timeline_id: str, stage_name: str | None = None) -> list[dict]:
"""Load stage outputs for a timeline, optionally filtered by stage."""
from sqlmodel import select
from core.db.models import StageOutput
from core.db.connection import get_session
with get_session() as session:
stmt = select(StageOutput).where(StageOutput.timeline_id == UUID(timeline_id))
if stage_name:
stmt = stmt.where(StageOutput.stage_name == stage_name)
rows = session.exec(stmt).all()
return [
{
"id": str(r.id),
"job_id": str(r.job_id),
"stage_name": r.stage_name,
"checkpoint_id": str(r.checkpoint_id) if r.checkpoint_id else None,
"output": r.output,
"created_at": str(r.created_at) if r.created_at else None,
}
for r in rows
]

159
core/detect/emit.py Normal file
View File

@@ -0,0 +1,159 @@
"""
Event emission helpers for detection pipeline stages.
Single place that knows how to build event payloads.
Stages call these instead of constructing dicts or dataclasses directly.
Run context (run_id, parent_job_id) is set once at pipeline start via
set_run_context() and automatically injected into all events.
Log level is set per-run with optional per-stage overrides.
DEBUG events are only pushed when the run (or stage) log level allows it.
"""
from __future__ import annotations
import dataclasses
from datetime import datetime, timezone
from core.detect.events import push_detect_event
from core.detect.models import PipelineStats
# Log level ordering for comparison
_LEVEL_ORDER = {"DEBUG": 0, "INFO": 1, "WARN": 2, "ERROR": 3}
# Module-level run context — set once per pipeline invocation
_run_context: dict = {}
_run_log_level: str = "INFO"
_stage_log_levels: dict[str, str] = {} # stage_name → level override
def set_run_context(
run_id: str = "",
parent_job_id: str = "",
run_type: str = "initial",
log_level: str = "INFO",
):
"""Set the run context for all subsequent events in this pipeline invocation."""
global _run_context, _run_log_level
_run_context = {
"run_id": run_id,
"parent_job_id": parent_job_id,
"run_type": run_type,
}
_run_log_level = log_level.upper()
_stage_log_levels.clear()
def set_stage_log_level(stage: str, level: str):
"""Override log level for a specific stage."""
_stage_log_levels[stage] = level.upper()
def clear_stage_log_level(stage: str):
"""Remove per-stage log level override."""
_stage_log_levels.pop(stage, None)
def clear_run_context():
global _run_context, _run_log_level
_run_context = {}
_run_log_level = "INFO"
_stage_log_levels.clear()
def _should_emit(level: str, stage: str) -> bool:
"""Check if this log level should be emitted given run/stage settings."""
effective = _stage_log_levels.get(stage, _run_log_level)
return _LEVEL_ORDER.get(level.upper(), 1) >= _LEVEL_ORDER.get(effective, 1)
def _inject_context(payload: dict) -> dict:
"""Add run context fields to an event payload."""
if _run_context:
payload.update(_run_context)
return payload
def log(job_id: str | None, stage: str, level: str, msg: str) -> None:
if not job_id:
return
if not _should_emit(level, stage):
return
payload = {
"level": level,
"stage": stage,
"msg": msg,
"ts": datetime.now(timezone.utc).isoformat(),
}
_inject_context(payload)
push_detect_event(job_id, "log", payload)
def stats(job_id: str | None, **kwargs) -> None:
if not job_id:
return
s = PipelineStats(**kwargs)
payload = dataclasses.asdict(s)
_inject_context(payload)
push_detect_event(job_id, "stats_update", payload)
def frame_update(
job_id: str | None,
frame_ref: int,
timestamp: float,
jpeg_b64: str,
boxes: list[dict],
) -> None:
if not job_id:
return
payload = {
"frame_ref": frame_ref,
"timestamp": timestamp,
"jpeg_b64": jpeg_b64,
"boxes": boxes,
}
_inject_context(payload)
push_detect_event(job_id, "frame_update", payload)
def graph_update(job_id: str | None, nodes: list[dict]) -> None:
if not job_id:
return
payload = {"nodes": nodes}
_inject_context(payload)
push_detect_event(job_id, "graph_update", payload)
def detection(
job_id: str | None,
brand: str,
confidence: float,
source: str,
timestamp: float,
duration: float = 0.0,
content_type: str = "",
frame_ref: int | None = None,
) -> None:
if not job_id:
return
payload = {
"brand": brand,
"confidence": confidence,
"source": source,
"timestamp": timestamp,
"duration": duration,
"content_type": content_type,
"frame_ref": frame_ref,
}
_inject_context(payload)
push_detect_event(job_id, "detection", payload)
def job_complete(job_id: str | None, report: dict) -> None:
if not job_id:
return
payload = {"job_id": job_id, "report": report}
_inject_context(payload)
push_detect_event(job_id, "job_complete", payload)

42
core/detect/events.py Normal file
View File

@@ -0,0 +1,42 @@
"""
Detection pipeline event helpers.
Non-generated runtime code for pushing SSE events.
The event payload types are in sse_contract.py (generated by modelgen).
"""
from pydantic import BaseModel
from core.events import push_event
DETECT_EVENTS_PREFIX = "detect_events"
# SSE event type names
EVENT_GRAPH_UPDATE = "graph_update"
EVENT_STATS_UPDATE = "stats_update"
EVENT_FRAME_UPDATE = "frame_update"
EVENT_DETECTION = "detection"
EVENT_LOG = "log"
EVENT_JOB_COMPLETE = "job_complete"
ALL_EVENT_TYPES = [
EVENT_GRAPH_UPDATE,
EVENT_STATS_UPDATE,
EVENT_FRAME_UPDATE,
EVENT_DETECTION,
EVENT_LOG,
EVENT_JOB_COMPLETE,
]
TERMINAL_EVENTS = [EVENT_JOB_COMPLETE]
def push_detect_event(job_id: str, event_type: str, data: BaseModel | dict) -> None:
"""Push a detection event to Redis. Accepts Pydantic models or plain dicts."""
payload = data.model_dump(mode="json") if isinstance(data, BaseModel) else data
push_event(
job_id=job_id,
event_type=event_type,
data=payload,
prefix=DETECT_EVENTS_PREFIX,
)

View File

@@ -0,0 +1,45 @@
"""
Detection pipeline graph.
detect/graph/
nodes.py — node functions (one per stage)
events.py — graph_update SSE emission
runner.py — PipelineRunner (config-driven, checkpoint, cancel, pause)
"""
from .nodes import NODES, NODE_FUNCTIONS
from .runner import (
PipelineCancelled,
PipelineRunner,
build_graph,
clear_cancel_check,
clear_pause,
get_pipeline,
init_pause,
is_paused,
pause_pipeline,
resume_pipeline,
set_cancel_check,
set_pause_after_stage,
step_pipeline,
)
from .events import _node_states
__all__ = [
"NODES",
"NODE_FUNCTIONS",
"PipelineCancelled",
"PipelineRunner",
"build_graph",
"get_pipeline",
"set_cancel_check",
"clear_cancel_check",
"init_pause",
"clear_pause",
"pause_pipeline",
"resume_pipeline",
"step_pipeline",
"set_pause_after_stage",
"is_paused",
"_node_states",
]

View File

@@ -0,0 +1,27 @@
"""
Graph event emission — node state tracking + SSE graph_update events.
"""
from __future__ import annotations
from core.detect import emit
from core.detect.state import DetectState
# Track node states across pipeline runs
_node_states: dict[str, dict[str, str]] = {}
def emit_transition(state: DetectState, node: str, status: str, node_list: list[str]):
"""Update node status and emit graph_update SSE event."""
job_id = state.get("job_id")
if not job_id:
return
if job_id not in _node_states:
_node_states[job_id] = {n: "pending" for n in node_list}
_node_states[job_id][node] = status
nodes = [{"id": n, "status": _node_states[job_id][n]} for n in node_list]
emit.graph_update(job_id, nodes)

386
core/detect/graph/nodes.py Normal file
View File

@@ -0,0 +1,386 @@
"""
Pipeline node functions — one per stage.
Each node: reads state, gets config from profile dict, runs stage logic,
emits transitions, returns output dict.
"""
from __future__ import annotations
import os
from core.detect import emit
from core.detect.models import CropContext, PipelineStats
from core.detect.profile import get_profile, get_stage_config, build_vlm_prompt, aggregate_detections
from core.detect.stages.models import (
DetectionConfig,
FieldSegmentationConfig,
FrameExtractionConfig,
OCRConfig,
RegionAnalysisConfig,
ResolverConfig,
SceneFilterConfig,
)
from core.detect.state import DetectState
from core.detect.stages.frame_extractor import extract_frames
from core.detect.stages.scene_filter import scene_filter
from core.detect.stages.field_segmentation import run_field_segmentation
from core.detect.stages.edge_detector import detect_edge_regions
from core.detect.stages.yolo_detector import detect_objects
from core.detect.stages.preprocess import preprocess_regions
from core.detect.stages.ocr_stage import run_ocr
from core.detect.stages.brand_resolver import resolve_brands
from core.detect.stages.vlm_local import escalate_vlm
from core.detect.stages.vlm_cloud import escalate_cloud
from core.detect.stages.aggregator import compile_report
from core.detect.tracing import trace_node, flush as flush_traces
from .events import emit_transition
INFERENCE_URL = os.environ.get("INFERENCE_URL") # None = local mode
NODES = [
"extract_frames",
"filter_scenes",
"field_segmentation",
"detect_edges",
"detect_objects",
"preprocess",
"run_ocr",
"match_brands",
"escalate_vlm",
"escalate_cloud",
"compile_report",
]
def _load_profile(state: DetectState) -> dict:
"""Load profile dict, apply config overrides if present."""
name = state.get("profile_name", "soccer_broadcast")
profile = get_profile(name)
overrides = state.get("config_overrides")
if overrides:
# Merge overrides into a copy of the profile configs
merged_configs = dict(profile.get("configs", {}))
for stage_name, stage_overrides in overrides.items():
if stage_name in merged_configs:
merged_configs[stage_name] = {**merged_configs[stage_name], **stage_overrides}
else:
merged_configs[stage_name] = stage_overrides
profile = {**profile, "configs": merged_configs}
return profile
def _emit(state, node, status):
emit_transition(state, node, status, NODES)
# --- Node functions ---
def node_extract_frames(state: DetectState) -> dict:
job_id = state.get("job_id", "")
if job_id and not emit._run_context:
emit.set_run_context(run_id=job_id, parent_job_id=job_id, run_type="initial")
source_asset_id = state.get("source_asset_id")
if source_asset_id and not state.get("session_brands"):
from core.detect.stages.brand_resolver import build_session_dict
session_brands = build_session_dict(source_asset_id)
state["session_brands"] = session_brands
_emit(state, "extract_frames", "running")
with trace_node(state, "extract_frames") as span:
profile = _load_profile(state)
config = FrameExtractionConfig(**get_stage_config(profile, "extract_frames"))
frames = extract_frames(state["video_path"], config, job_id=job_id)
span.set_output({"frames_extracted": len(frames)})
# Cache frames on the timeline for reuse across jobs and UI
timeline_id = state.get("timeline_id")
if timeline_id:
from core.detect.checkpoint.frames import cache_frames, cache_exists
if not cache_exists(timeline_id):
cache_frames(timeline_id, frames)
from core.detect.checkpoint.storage import update_timeline_status
update_timeline_status(timeline_id, "cached", frame_count=len(frames))
_emit(state, "extract_frames", "done")
return {"frames": frames, "stats": PipelineStats(frames_extracted=len(frames))}
def node_filter_scenes(state: DetectState) -> dict:
_emit(state, "filter_scenes", "running")
with trace_node(state, "filter_scenes") as span:
profile = _load_profile(state)
config = SceneFilterConfig(**get_stage_config(profile, "filter_scenes"))
frames = state.get("frames", [])
kept = scene_filter(frames, config, job_id=state.get("job_id"))
span.set_output({"frames_in": len(frames), "frames_kept": len(kept)})
stats = state.get("stats", PipelineStats())
stats.frames_after_scene_filter = len(kept)
_emit(state, "filter_scenes", "done")
return {"filtered_frames": kept, "stats": stats}
def node_field_segmentation(state: DetectState) -> dict:
_emit(state, "field_segmentation", "running")
with trace_node(state, "field_segmentation") as span:
profile = _load_profile(state)
config = FieldSegmentationConfig(**get_stage_config(profile, "field_segmentation"))
frames = state.get("filtered_frames", [])
job_id = state.get("job_id")
result = run_field_segmentation(frames, config, inference_url=INFERENCE_URL, job_id=job_id)
span.set_output({
"frames": len(frames),
"avg_coverage": sum(result["field_coverage"].values()) / max(len(result["field_coverage"]), 1),
})
_emit(state, "field_segmentation", "done")
return {
"field_masks": result["field_masks"],
"field_mask_overlays": result.get("field_mask_overlays", {}),
"field_boundaries": result["field_boundaries"],
"field_coverage": result["field_coverage"],
}
def node_detect_edges(state: DetectState) -> dict:
_emit(state, "detect_edges", "running")
with trace_node(state, "detect_edges") as span:
profile = _load_profile(state)
config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
frames = state.get("filtered_frames", [])
field_masks = state.get("field_masks", {})
job_id = state.get("job_id")
# Apply edge transforms from upstream connections
edge_transforms = state.get("_edge_transforms", {})
for source_stage, transform in edge_transforms.items():
if transform.get("invert_mask") and field_masks:
import numpy as np
field_masks = {
seq: np.bitwise_not(mask) if mask is not None else None
for seq, mask in field_masks.items()
}
regions = detect_edge_regions(
frames, config, inference_url=INFERENCE_URL, job_id=job_id,
field_masks=field_masks,
)
total = sum(len(r) for r in regions.values())
span.set_output({"frames": len(frames), "edge_regions": total})
stats = state.get("stats", PipelineStats())
stats.cv_regions_detected = total
_emit(state, "detect_edges", "done")
return {"edge_regions_by_frame": regions, "stats": stats}
def node_detect_objects(state: DetectState) -> dict:
_emit(state, "detect_objects", "running")
with trace_node(state, "detect_objects") as span:
profile = _load_profile(state)
config = DetectionConfig(**get_stage_config(profile, "detect_objects"))
frames = state.get("filtered_frames", [])
job_id = state.get("job_id")
all_boxes = detect_objects(frames, config, inference_url=INFERENCE_URL, job_id=job_id)
total_regions = sum(len(boxes) for boxes in all_boxes.values())
span.set_output({"frames": len(frames), "regions_detected": total_regions})
stats = state.get("stats", PipelineStats())
stats.regions_detected = total_regions
_emit(state, "detect_objects", "done")
return {"boxes_by_frame": all_boxes, "stats": stats}
def node_preprocess(state: DetectState) -> dict:
_emit(state, "preprocess", "running")
with trace_node(state, "preprocess") as span:
profile = _load_profile(state)
prep_config = get_stage_config(profile, "preprocess")
frames = state.get("filtered_frames", [])
boxes = state.get("boxes_by_frame", {})
job_id = state.get("job_id")
do_contrast = prep_config.get("contrast", True)
do_deskew = prep_config.get("deskew", False)
do_binarize = prep_config.get("binarize", False)
preprocessed = preprocess_regions(
frames, boxes,
do_contrast=do_contrast,
do_deskew=do_deskew,
do_binarize=do_binarize,
inference_url=INFERENCE_URL,
job_id=job_id,
)
span.set_output({"regions_preprocessed": len(preprocessed)})
_emit(state, "preprocess", "done")
return {"preprocessed_crops": preprocessed}
def node_run_ocr(state: DetectState) -> dict:
_emit(state, "run_ocr", "running")
with trace_node(state, "run_ocr") as span:
profile = _load_profile(state)
config = OCRConfig(**get_stage_config(profile, "run_ocr"))
frames = state.get("filtered_frames", [])
boxes = state.get("boxes_by_frame", {})
job_id = state.get("job_id")
candidates = run_ocr(frames, boxes, config, inference_url=INFERENCE_URL, job_id=job_id)
span.set_output({"regions_in": sum(len(b) for b in boxes.values()), "text_candidates": len(candidates)})
stats = state.get("stats", PipelineStats())
stats.regions_resolved_by_ocr = len(candidates)
_emit(state, "run_ocr", "done")
return {"text_candidates": candidates, "stats": stats}
def node_match_brands(state: DetectState) -> dict:
_emit(state, "match_brands", "running")
with trace_node(state, "match_brands") as span:
profile = _load_profile(state)
config = ResolverConfig(**get_stage_config(profile, "match_brands"))
candidates = state.get("text_candidates", [])
session_brands = state.get("session_brands", {})
job_id = state.get("job_id")
source_asset_id = state.get("source_asset_id")
matched, unresolved = resolve_brands(
candidates, config,
session_brands=session_brands,
source_asset_id=source_asset_id,
content_type=profile["name"], job_id=job_id,
)
span.set_output({"matched": len(matched), "unresolved": len(unresolved)})
_emit(state, "match_brands", "done")
return {"detections": matched, "unresolved_candidates": unresolved}
def node_escalate_vlm(state: DetectState) -> dict:
_emit(state, "escalate_vlm", "running")
with trace_node(state, "escalate_vlm") as span:
profile = _load_profile(state)
vlm_config = get_stage_config(profile, "escalate_vlm")
vlm_template = vlm_config.get("vlm_prompt_template", "Identify the brand in this image.")
candidates = state.get("unresolved_candidates", [])
job_id = state.get("job_id")
vlm_prompt_fn = lambda ctx: build_vlm_prompt(ctx, vlm_template)
vlm_matched, still_unresolved = escalate_vlm(
candidates,
vlm_prompt_fn=vlm_prompt_fn,
inference_url=INFERENCE_URL,
content_type=profile["name"],
source_asset_id=state.get("source_asset_id"),
job_id=job_id,
)
stats = state.get("stats", PipelineStats())
stats.regions_escalated_to_local_vlm = len(candidates)
span.set_output({"candidates": len(candidates), "matched": len(vlm_matched),
"still_unresolved": len(still_unresolved)})
existing = state.get("detections", [])
vlm_skipped = os.environ.get("SKIP_VLM", "").strip() == "1"
_emit(state, "escalate_vlm", "skipped" if vlm_skipped else "done")
return {
"detections": existing + vlm_matched,
"unresolved_candidates": still_unresolved,
"stats": stats,
}
def node_escalate_cloud(state: DetectState) -> dict:
_emit(state, "escalate_cloud", "running")
with trace_node(state, "escalate_cloud") as span:
profile = _load_profile(state)
vlm_config = get_stage_config(profile, "escalate_vlm")
vlm_template = vlm_config.get("vlm_prompt_template", "Identify the brand in this image.")
candidates = state.get("unresolved_candidates", [])
job_id = state.get("job_id")
stats = state.get("stats", PipelineStats())
vlm_prompt_fn = lambda ctx: build_vlm_prompt(ctx, vlm_template)
cloud_matched = escalate_cloud(
candidates,
vlm_prompt_fn=vlm_prompt_fn,
stats=stats,
content_type=profile["name"],
source_asset_id=state.get("source_asset_id"),
job_id=job_id,
)
span.set_output({"candidates": len(candidates), "matched": len(cloud_matched),
"cloud_calls": stats.cloud_llm_calls,
"cost_usd": stats.estimated_cloud_cost_usd})
existing = state.get("detections", [])
cloud_skipped = os.environ.get("SKIP_CLOUD", "").strip() == "1"
_emit(state, "escalate_cloud", "skipped" if cloud_skipped else "done")
return {"detections": existing + cloud_matched, "stats": stats}
def node_compile_report(state: DetectState) -> dict:
_emit(state, "compile_report", "running")
with trace_node(state, "compile_report") as span:
profile = _load_profile(state)
detections = state.get("detections", [])
stats = state.get("stats", PipelineStats())
job_id = state.get("job_id")
report = compile_report(
detections=detections,
stats=stats,
video_source=state.get("video_path", ""),
content_type=profile["name"],
job_id=job_id,
)
span.set_output({"brands": len(report.brands), "detections": len(report.timeline)})
flush_traces()
_emit(state, "compile_report", "done")
return {"report": report}
NODE_FUNCTIONS = [
("extract_frames", node_extract_frames),
("filter_scenes", node_filter_scenes),
("field_segmentation", node_field_segmentation),
("detect_edges", node_detect_edges),
("detect_objects", node_detect_objects),
("preprocess", node_preprocess),
("run_ocr", node_run_ocr),
("match_brands", node_match_brands),
("escalate_vlm", node_escalate_vlm),
("escalate_cloud", node_escalate_cloud),
("compile_report", node_compile_report),
]

289
core/detect/graph/runner.py Normal file
View File

@@ -0,0 +1,289 @@
"""
Pipeline runner — executes stages sequentially with checkpointing,
cancellation, and pause/resume.
Reads PipelineConfig from the profile to determine what stages to run.
Flattens the graph into a linear sequence for now (serial execution).
Executor socket: all stages run via LocalExecutor (call function directly).
"""
from __future__ import annotations
import logging
import os
import threading
from core.detect.stages.models import PipelineConfig
from core.detect.state import DetectState
from .nodes import NODES, NODE_FUNCTIONS
logger = logging.getLogger(__name__)
_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"
class PipelineCancelled(Exception):
"""Raised when a pipeline run is cancelled."""
pass
class PipelinePaused(Exception):
"""Raised when a pipeline is paused (internally, for flow control)."""
pass
# ---------------------------------------------------------------------------
# Cancellation — checked before each node
# ---------------------------------------------------------------------------
_cancel_check: dict[str, callable] = {}
def set_cancel_check(job_id: str, fn):
_cancel_check[job_id] = fn
def clear_cancel_check(job_id: str):
_cancel_check.pop(job_id, None)
# ---------------------------------------------------------------------------
# Pause / Resume / Step — checked after each node completes
# ---------------------------------------------------------------------------
_pause_gate: dict[str, threading.Event] = {}
_pause_after_stage: dict[str, bool] = {}
def init_pause(job_id: str, pause_after_stage: bool = False):
"""Initialize pause state for a job. Called when pipeline starts."""
gate = threading.Event()
gate.set() # start unpaused
_pause_gate[job_id] = gate
_pause_after_stage[job_id] = pause_after_stage
def clear_pause(job_id: str):
"""Clean up pause state. Called when pipeline finishes."""
_pause_gate.pop(job_id, None)
_pause_after_stage.pop(job_id, None)
def pause_pipeline(job_id: str):
"""Pause a running pipeline. It will block after the current stage completes."""
gate = _pause_gate.get(job_id)
if gate:
gate.clear()
logger.info("Pipeline %s paused", job_id)
def resume_pipeline(job_id: str):
"""Resume a paused pipeline."""
gate = _pause_gate.get(job_id)
if gate:
gate.set()
logger.info("Pipeline %s resumed", job_id)
def step_pipeline(job_id: str):
"""Run one stage then pause again."""
_pause_after_stage[job_id] = True
gate = _pause_gate.get(job_id)
if gate:
gate.set()
logger.info("Pipeline %s stepping", job_id)
def set_pause_after_stage(job_id: str, enabled: bool):
"""Toggle pause-after-each-stage mode."""
_pause_after_stage[job_id] = enabled
if not enabled:
gate = _pause_gate.get(job_id)
if gate:
gate.set()
def is_paused(job_id: str) -> bool:
"""Check if a pipeline is currently paused."""
gate = _pause_gate.get(job_id)
return gate is not None and not gate.is_set()
def _wait_if_paused(job_id: str, node_name: str):
"""Block until resumed. Called after each node completes."""
gate = _pause_gate.get(job_id)
if gate is None:
return
if _pause_after_stage.get(job_id, False):
gate.clear()
from core.detect import emit
emit.log(job_id, "Pipeline", "INFO", f"Paused after {node_name}")
while not gate.wait(timeout=0.5):
check = _cancel_check.get(job_id)
if check and check():
raise PipelineCancelled(f"Cancelled while paused before next stage")
# ---------------------------------------------------------------------------
# Pipeline Runner
# ---------------------------------------------------------------------------
# Node function lookup — maps stage name to callable
_NODE_FN_MAP: dict[str, callable] = {name: fn for name, fn in NODE_FUNCTIONS}
def _flatten_config(config: PipelineConfig, start_from: str | None = None) -> list[str]:
"""
Flatten a PipelineConfig into a linear stage sequence.
For now: topological sort via edges. Falls back to stage order if no edges.
Respects start_from for replay (skip stages before it).
"""
if not config.edges:
# No edges defined — use stage order as-is
names = [s.name for s in config.stages]
else:
# Topological sort from edges
graph: dict[str, list[str]] = {}
in_degree: dict[str, int] = {}
stage_names = {s.name for s in config.stages}
for name in stage_names:
graph[name] = []
in_degree[name] = 0
for edge in config.edges:
if edge.source in stage_names and edge.target in stage_names:
graph[edge.source].append(edge.target)
in_degree[edge.target] = in_degree.get(edge.target, 0) + 1
# Kahn's algorithm
queue = [n for n in stage_names if in_degree.get(n, 0) == 0]
# Stable sort: prefer order from config.stages
stage_order = {s.name: i for i, s in enumerate(config.stages)}
queue.sort(key=lambda n: stage_order.get(n, 999))
names = []
while queue:
node = queue.pop(0)
names.append(node)
for neighbor in graph.get(node, []):
in_degree[neighbor] -= 1
if in_degree[neighbor] == 0:
queue.append(neighbor)
queue.sort(key=lambda n: stage_order.get(n, 999))
if start_from:
try:
idx = names.index(start_from)
names = names[idx:]
except ValueError:
raise ValueError(f"Stage {start_from!r} not in pipeline config")
return names
class PipelineRunner:
"""
Executes a pipeline defined by PipelineConfig.
Runs stages sequentially (flattened). Each stage:
1. Check cancel
2. Run node function (via executor — local for now)
3. Merge result into state
4. Checkpoint (if enabled)
5. Check pause
Executor socket: currently calls node functions directly.
Future: dispatch to LocalExecutor / GrpcExecutor / LambdaExecutor
based on StageRef.execution_target.
"""
def __init__(
self,
config: PipelineConfig,
checkpoint: bool = False,
start_from: str | None = None,
):
self.config = config
self.do_checkpoint = checkpoint
self.stage_sequence = _flatten_config(config, start_from)
# Build edge transform lookup: {target_stage: {source_stage: transform_dict}}
self._edge_transforms: dict[str, dict[str, dict]] = {}
for edge in config.edges:
if edge.transform:
if edge.target not in self._edge_transforms:
self._edge_transforms[edge.target] = {}
self._edge_transforms[edge.target][edge.source] = edge.transform
def invoke(self, state: DetectState) -> DetectState:
"""Run the pipeline on the given state. Returns final state."""
for stage_name in self.stage_sequence:
job_id = state.get("job_id", "")
# 1. Cancel check
check = _cancel_check.get(job_id)
if check and check():
raise PipelineCancelled(f"Cancelled before {stage_name}")
# Inject edge transforms into state so the stage can read them.
# Compatible with LangGraph — just a state dict key.
transforms = self._edge_transforms.get(stage_name, {})
if transforms:
state["_edge_transforms"] = transforms
elif "_edge_transforms" in state:
del state["_edge_transforms"]
# 2. Run node function
node_fn = _NODE_FN_MAP.get(stage_name)
if node_fn is None:
logger.warning("No node function for stage %s, skipping", stage_name)
continue
result = node_fn(state)
# 3. Merge result into state
state.update(result)
# 4. Checkpoint
if self.do_checkpoint:
from core.detect.checkpoint import checkpoint_after_stage
checkpoint_after_stage(job_id, stage_name, state, result)
# 5. Pause check
_wait_if_paused(job_id, stage_name)
return state
# ---------------------------------------------------------------------------
# Public API — backwards compatible with old get_pipeline/build_graph
# ---------------------------------------------------------------------------
def get_pipeline(
checkpoint: bool | None = None,
profile_name: str = "soccer_broadcast",
start_from: str | None = None,
) -> PipelineRunner:
"""Return a PipelineRunner for the given profile."""
from core.detect.profile import get_profile, pipeline_config_from_dict
do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
profile = get_profile(profile_name)
config = pipeline_config_from_dict(profile["pipeline"])
return PipelineRunner(
config=config,
checkpoint=do_checkpoint,
start_from=start_from,
)
def build_graph(checkpoint: bool | None = None, start_from: str | None = None):
"""Backwards-compatible wrapper. Returns a PipelineRunner."""
return get_pipeline(checkpoint=checkpoint, start_from=start_from)

View File

@@ -0,0 +1,4 @@
from .client import InferenceClient
from .types import DetectResult, OCRResult, VLMResult
__all__ = ["InferenceClient", "DetectResult", "OCRResult", "VLMResult"]

View File

@@ -0,0 +1,262 @@
"""
HTTP client for the inference server.
The pipeline stages call this instead of importing ML libraries directly.
The inference server runs on the GPU machine (or spot instance).
"""
from __future__ import annotations
import base64
import io
import logging
import os
import numpy as np
import requests
from PIL import Image
from .types import DetectResult, OCRResult, RegionDebugResult, RegionResult, ServerStatus, VLMResult
logger = logging.getLogger(__name__)
DEFAULT_URL = os.environ.get("INFERENCE_URL", "http://localhost:8000")
def _encode_image(image: np.ndarray) -> str:
"""Encode numpy array as base64 JPEG."""
img = Image.fromarray(image)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=85)
return base64.b64encode(buf.getvalue()).decode()
class InferenceClient:
"""HTTP client for the GPU inference server."""
def __init__(self, base_url: str | None = None, timeout: float = 60.0,
job_id: str = "", log_level: str = "INFO"):
self.base_url = (base_url or DEFAULT_URL).rstrip("/")
self.timeout = timeout
self.job_id = job_id
self.log_level = log_level
self.session = requests.Session()
if job_id:
self.session.headers["X-Job-Id"] = job_id
self.session.headers["X-Log-Level"] = log_level
def health(self) -> ServerStatus:
"""Check server health and loaded models."""
resp = self.session.get(f"{self.base_url}/health", timeout=self.timeout)
resp.raise_for_status()
data = resp.json()
return ServerStatus(
loaded_models=data.get("loaded_models", []),
vram_used_mb=data.get("vram_used_mb", 0),
vram_budget_mb=data.get("vram_budget_mb", 0),
strategy=data.get("strategy", "sequential"),
)
def detect(
self,
image: np.ndarray,
model: str = "yolov8n",
confidence: float = 0.3,
target_classes: list[str] | None = None,
) -> list[DetectResult]:
"""Run object detection on an image."""
payload = {
"image": _encode_image(image),
"model": model,
"confidence": confidence,
}
if target_classes:
payload["target_classes"] = target_classes
resp = self.session.post(
f"{self.base_url}/detect",
json=payload,
timeout=self.timeout,
)
resp.raise_for_status()
results = []
for d in resp.json().get("detections", []):
result = DetectResult(
x=d["x"], y=d["y"], w=d["w"], h=d["h"],
confidence=d["confidence"], label=d["label"],
)
results.append(result)
return results
def ocr(
self,
image: np.ndarray,
languages: list[str] | None = None,
) -> list[OCRResult]:
"""Run OCR on an image region."""
payload = {
"image": _encode_image(image),
}
if languages:
payload["languages"] = languages
resp = self.session.post(
f"{self.base_url}/ocr",
json=payload,
timeout=self.timeout,
)
resp.raise_for_status()
results = []
for d in resp.json().get("results", []):
result = OCRResult(
text=d["text"],
confidence=d["confidence"],
bbox=tuple(d["bbox"]),
)
results.append(result)
return results
def vlm(
self,
image: np.ndarray,
prompt: str,
model: str = "moondream2",
) -> VLMResult:
"""Query a visual language model with an image crop + prompt."""
payload = {
"image": _encode_image(image),
"prompt": prompt,
"model": model,
}
resp = self.session.post(
f"{self.base_url}/vlm",
json=payload,
timeout=self.timeout,
)
resp.raise_for_status()
data = resp.json()
return VLMResult(
brand=data.get("brand", ""),
confidence=data.get("confidence", 0.0),
reasoning=data.get("reasoning", ""),
)
def detect_edges(
self,
image: np.ndarray,
edge_canny_low: int = 50,
edge_canny_high: int = 150,
edge_hough_threshold: int = 80,
edge_hough_min_length: int = 100,
edge_hough_max_gap: int = 10,
edge_pair_max_distance: int = 200,
edge_pair_min_distance: int = 15,
) -> list[RegionResult]:
"""Run edge detection on an image."""
payload = {
"image": _encode_image(image),
"edge_canny_low": edge_canny_low,
"edge_canny_high": edge_canny_high,
"edge_hough_threshold": edge_hough_threshold,
"edge_hough_min_length": edge_hough_min_length,
"edge_hough_max_gap": edge_hough_max_gap,
"edge_pair_max_distance": edge_pair_max_distance,
"edge_pair_min_distance": edge_pair_min_distance,
}
resp = self.session.post(
f"{self.base_url}/detect_edges",
json=payload,
timeout=self.timeout,
)
resp.raise_for_status()
results = []
for r in resp.json().get("regions", []):
result = RegionResult(
x=r["x"], y=r["y"], w=r["w"], h=r["h"],
confidence=r["confidence"], label=r["label"],
)
results.append(result)
return results
def detect_edges_debug(
self,
image: np.ndarray,
edge_canny_low: int = 50,
edge_canny_high: int = 150,
edge_hough_threshold: int = 80,
edge_hough_min_length: int = 100,
edge_hough_max_gap: int = 10,
edge_pair_max_distance: int = 200,
edge_pair_min_distance: int = 15,
) -> RegionDebugResult:
"""Run edge detection with debug overlays."""
payload = {
"image": _encode_image(image),
"edge_canny_low": edge_canny_low,
"edge_canny_high": edge_canny_high,
"edge_hough_threshold": edge_hough_threshold,
"edge_hough_min_length": edge_hough_min_length,
"edge_hough_max_gap": edge_hough_max_gap,
"edge_pair_max_distance": edge_pair_max_distance,
"edge_pair_min_distance": edge_pair_min_distance,
}
resp = self.session.post(
f"{self.base_url}/detect_edges/debug",
json=payload,
timeout=self.timeout,
)
resp.raise_for_status()
data = resp.json()
regions = []
for r in data.get("regions", []):
region = RegionResult(
x=r["x"], y=r["y"], w=r["w"], h=r["h"],
confidence=r["confidence"], label=r["label"],
)
regions.append(region)
return RegionDebugResult(
regions=regions,
edge_overlay_b64=data.get("edge_overlay_b64", ""),
lines_overlay_b64=data.get("lines_overlay_b64", ""),
horizontal_count=data.get("horizontal_count", 0),
pair_count=data.get("pair_count", 0),
)
def post(self, path: str, payload: dict) -> dict | None:
"""Generic POST to the inference server. Returns JSON response or None on error."""
try:
resp = self.session.post(
f"{self.base_url}{path}",
json=payload,
timeout=self.timeout,
)
resp.raise_for_status()
return resp.json()
except Exception as e:
logger.warning("Inference POST %s failed: %s", path, e)
return None
def load_model(self, model: str, quantization: str = "fp16") -> None:
"""Request the server to load a model into VRAM."""
self.session.post(
f"{self.base_url}/models/load",
json={"model": model, "quantization": quantization},
timeout=self.timeout,
).raise_for_status()
def unload_model(self, model: str) -> None:
"""Request the server to unload a model from VRAM."""
self.session.post(
f"{self.base_url}/models/unload",
json={"model": model},
timeout=self.timeout,
).raise_for_status()

View File

@@ -0,0 +1,76 @@
"""
Inference response types.
These are the shapes returned by the inference server.
Kept separate from core.detect.models to avoid coupling the
inference protocol to pipeline internals.
"""
from __future__ import annotations
from dataclasses import dataclass, field
@dataclass
class DetectResult:
"""Single object detection from YOLO or similar."""
x: int
y: int
w: int
h: int
confidence: float
label: str
@dataclass
class OCRResult:
"""Text extracted from a region."""
text: str
confidence: float
bbox: tuple[int, int, int, int] # x, y, w, h
@dataclass
class VLMResult:
"""Visual language model response for a crop."""
brand: str
confidence: float
reasoning: str
@dataclass
class RegionResult:
"""A candidate region from CV analysis."""
x: int
y: int
w: int
h: int
confidence: float
label: str
@dataclass
class RegionDebugResult:
"""CV region analysis with debug overlays."""
regions: list[RegionResult] = field(default_factory=list)
edge_overlay_b64: str = ""
lines_overlay_b64: str = ""
horizontal_count: int = 0
pair_count: int = 0
@dataclass
class ModelInfo:
"""Info about a loaded model."""
name: str
vram_mb: float
quantization: str # fp32, fp16, int8, int4
@dataclass
class ServerStatus:
"""Inference server health response."""
loaded_models: list[ModelInfo] = field(default_factory=list)
vram_used_mb: float = 0.0
vram_budget_mb: float = 0.0
strategy: str = "sequential" # sequential, concurrent, auto

95
core/detect/models.py Normal file
View File

@@ -0,0 +1,95 @@
"""
Detection pipeline runtime models.
These are the data structures that flow between pipeline stages.
They contain runtime types (np.ndarray) so they live here, not in
core/schema/models/ (which is for modelgen source of truth).
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Literal
import numpy as np
@dataclass
class Frame:
sequence: int
chunk_id: int
timestamp: float # position in video (seconds)
image: np.ndarray
perceptual_hash: str = ""
@dataclass
class BoundingBox:
x: int
y: int
w: int
h: int
confidence: float
label: str
@dataclass
class TextCandidate:
frame: Frame
bbox: BoundingBox
text: str
ocr_confidence: float
@dataclass
class BrandDetection:
brand: str
timestamp: float
duration: float
confidence: float
source: Literal["ocr", "local_vlm", "cloud_llm", "logo_match", "auxiliary"]
bbox: BoundingBox | None = None
frame_ref: int | None = None
content_type: str = ""
@dataclass
class BrandStats:
total_appearances: int = 0
total_screen_time: float = 0.0
avg_confidence: float = 0.0
first_seen: float = 0.0
last_seen: float = 0.0
@dataclass
class PipelineStats:
frames_extracted: int = 0
frames_after_scene_filter: int = 0
cv_regions_detected: int = 0
regions_detected: int = 0
regions_resolved_by_ocr: int = 0
regions_escalated_to_local_vlm: int = 0
regions_escalated_to_cloud_llm: int = 0
auxiliary_detections: int = 0
cloud_llm_calls: int = 0
processing_time_seconds: float = 0.0
estimated_cloud_cost_usd: float = 0.0
@dataclass
class DetectionReport:
video_source: str
content_type: str
duration_seconds: float
brands: dict[str, BrandStats] = field(default_factory=dict)
timeline: list[BrandDetection] = field(default_factory=list)
pipeline_stats: PipelineStats = field(default_factory=PipelineStats)
@dataclass
class CropContext:
"""Runtime type — holds image bytes for VLM prompts."""
image: bytes
surrounding_text: str = ""
position_hint: str = ""

107
core/detect/profile.py Normal file
View File

@@ -0,0 +1,107 @@
"""
Profile registry and helpers.
Loads profile data from Postgres.
A profile is a dict with keys: name, pipeline, configs.
"""
from __future__ import annotations
import logging
from typing import Any, Dict
from core.detect.stages.models import PipelineConfig, StageRef, Edge
from core.detect.models import (
BrandDetection,
BrandStats,
CropContext,
DetectionReport,
PipelineStats,
)
logger = logging.getLogger(__name__)
def get_profile(name: str) -> Dict[str, Any]:
"""Get a profile dict by name from the database."""
from core.db.connection import get_session
from core.db.models import Profile
with get_session() as session:
row = session.query(Profile).filter(Profile.name == name).first()
if row is None:
raise ValueError(f"Unknown profile: {name!r}")
return {
"name": row.name,
"pipeline": row.pipeline or {},
"configs": row.configs or {},
}
def list_profiles() -> list[str]:
"""List available profile names from the database."""
from core.db.connection import get_session
from core.db.models import Profile
with get_session() as session:
rows = session.query(Profile.name).all()
return [r[0] for r in rows]
def get_stage_config(profile: Dict[str, Any], stage_name: str) -> dict:
"""Get config values for a stage from a profile."""
return profile.get("configs", {}).get(stage_name, {})
def pipeline_config_from_dict(data: Dict[str, Any]) -> PipelineConfig:
"""Deserialize a PipelineConfig from a JSONB dict."""
stages = [StageRef(**s) for s in data.get("stages", [])]
edges = [Edge(**e) for e in data.get("edges", [])]
return PipelineConfig(
name=data.get("name", ""),
profile_name=data.get("profile_name", ""),
stages=stages,
edges=edges,
routing_rules=data.get("routing_rules", {}),
)
def build_vlm_prompt(crop_context: CropContext, template: str) -> str:
"""Build a VLM prompt from a template and crop context."""
hint = f" Position: {crop_context.position_hint}." if crop_context.position_hint else ""
text = f" Nearby text: '{crop_context.surrounding_text}'." if crop_context.surrounding_text else ""
return template.format(hint=hint, text=text)
def aggregate_detections(
detections: list[BrandDetection],
content_type: str,
) -> DetectionReport:
"""Group detections by brand into a report."""
brands: dict[str, BrandStats] = {}
for d in detections:
if d.brand not in brands:
brands[d.brand] = BrandStats()
s = brands[d.brand]
s.total_appearances += 1
s.total_screen_time += d.duration
s.avg_confidence = (
(s.avg_confidence * (s.total_appearances - 1) + d.confidence)
/ s.total_appearances
)
if s.first_seen == 0.0 or d.timestamp < s.first_seen:
s.first_seen = d.timestamp
if d.timestamp > s.last_seen:
s.last_seen = d.timestamp
return DetectionReport(
video_source="",
content_type=content_type,
duration_seconds=0.0,
brands=brands,
timeline=sorted(detections, key=lambda d: d.timestamp),
pipeline_stats=PipelineStats(),
)

View File

@@ -0,0 +1,58 @@
"""
Cloud LLM provider registry.
Select provider via CLOUD_LLM_PROVIDER env var.
Each provider reads its own env vars for auth/config.
CLOUD_LLM_PROVIDER=groq → GROQ_API_KEY, GROQ_MODEL, GROQ_BASE_URL
CLOUD_LLM_PROVIDER=gemini → GEMINI_API_KEY, GEMINI_MODEL
CLOUD_LLM_PROVIDER=openai → OPENAI_API_KEY, OPENAI_MODEL, OPENAI_BASE_URL
CLOUD_LLM_PROVIDER=claude → ANTHROPIC_API_KEY, CLAUDE_MODEL
"""
from __future__ import annotations
import os
from .base import CloudProvider, ProviderResponse
from .groq import GroqProvider
from .gemini import GeminiProvider
from .openai_compat import OpenAICompatProvider
from .claude import ClaudeProvider
PROVIDERS: dict[str, type] = {
"groq": GroqProvider,
"gemini": GeminiProvider,
"openai": OpenAICompatProvider,
"claude": ClaudeProvider,
}
_cached: CloudProvider | None = None
def get_provider() -> CloudProvider:
"""Get the configured cloud provider (cached after first call)."""
global _cached
if _cached is not None:
return _cached
name = os.environ.get("CLOUD_LLM_PROVIDER", "groq")
cls = PROVIDERS.get(name)
if cls is None:
raise ValueError(f"Unknown provider: {name!r}. Options: {list(PROVIDERS)}")
_cached = cls()
return _cached
def has_api_key() -> bool:
"""Check if the configured provider has an API key set."""
name = os.environ.get("CLOUD_LLM_PROVIDER", "groq")
key_map = {
"groq": "GROQ_API_KEY",
"gemini": "GEMINI_API_KEY",
"openai": "OPENAI_API_KEY",
"claude": "ANTHROPIC_API_KEY",
}
env_var = key_map.get(name, "")
return bool(os.environ.get(env_var, ""))

View File

@@ -0,0 +1,36 @@
"""Cloud LLM provider protocol and model metadata."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol
@dataclass
class ModelInfo:
"""Metadata for a cloud LLM model."""
id: str
vision: bool = True
cost_per_input_token: float = 0.0
cost_per_output_token: float = 0.0
max_output_tokens: int = 4096
notes: str = ""
@dataclass
class ProviderResponse:
answer: str
total_tokens: int = 0
class CloudProvider(Protocol):
"""
Interface for cloud LLM providers.
Each provider handles its own auth, payload format, and response parsing.
The pipeline only calls call() and reads the response.
"""
name: str
models: dict[str, ModelInfo]
def call(self, image_b64: str, prompt: str) -> ProviderResponse: ...

View File

@@ -0,0 +1,73 @@
"""Anthropic Claude provider — uses the official SDK."""
from __future__ import annotations
import logging
import os
from .base import ModelInfo, ProviderResponse
logger = logging.getLogger(__name__)
# Claude-specific env vars
# ANTHROPIC_API_KEY is read by the SDK automatically
CLAUDE_MODEL = os.environ.get("CLAUDE_MODEL", "claude-sonnet-4-20250514")
MODELS = {
"claude-sonnet-4-20250514": ModelInfo(
id="claude-sonnet-4-20250514",
vision=True,
cost_per_input_token=0.000003,
cost_per_output_token=0.000015,
notes="Best balance of quality/cost with vision",
),
"claude-haiku-4-5-20251001": ModelInfo(
id="claude-haiku-4-5-20251001",
vision=True,
cost_per_input_token=0.0000008,
cost_per_output_token=0.000004,
notes="Fastest, cheapest, good for simple brand ID",
),
"claude-opus-4-6": ModelInfo(
id="claude-opus-4-6",
vision=True,
cost_per_input_token=0.000015,
cost_per_output_token=0.000075,
notes="Highest quality, use for ambiguous cases",
),
}
class ClaudeProvider:
name = "claude"
models = MODELS
def __init__(self):
from anthropic import Anthropic
self.client = Anthropic()
self.model = CLAUDE_MODEL
def call(self, image_b64: str, prompt: str) -> ProviderResponse:
message = self.client.messages.create(
model=self.model,
max_tokens=150,
messages=[{
"role": "user",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": image_b64,
},
},
{"type": "text", "text": prompt},
],
}],
)
answer = message.content[0].text.strip()
total_tokens = message.usage.input_tokens + message.usage.output_tokens
return ProviderResponse(answer=answer, total_tokens=total_tokens)

View File

@@ -0,0 +1,75 @@
"""Google Gemini provider — native REST API, not OpenAI-compatible."""
from __future__ import annotations
import logging
import os
import requests
from .base import ModelInfo, ProviderResponse
logger = logging.getLogger(__name__)
# Gemini-specific env vars
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-2.0-flash")
MODELS = {
"gemini-2.0-flash": ModelInfo(
id="gemini-2.0-flash",
vision=True,
cost_per_input_token=0.0000001,
cost_per_output_token=0.0000004,
notes="Fast, cheap, good vision",
),
"gemini-2.0-pro": ModelInfo(
id="gemini-2.0-pro",
vision=True,
cost_per_input_token=0.00000125,
cost_per_output_token=0.000005,
notes="Higher quality, slower",
),
"gemini-1.5-flash": ModelInfo(
id="gemini-1.5-flash",
vision=True,
cost_per_input_token=0.000000075,
cost_per_output_token=0.0000003,
notes="Cheapest option",
),
}
class GeminiProvider:
name = "gemini"
models = MODELS
def __init__(self):
self.api_key = GEMINI_API_KEY
self.model = GEMINI_MODEL
self.endpoint = (
f"https://generativelanguage.googleapis.com/v1beta/models/"
f"{self.model}:generateContent"
)
def call(self, image_b64: str, prompt: str) -> ProviderResponse:
payload = {
"contents": [{
"parts": [
{"text": prompt},
{"inline_data": {"mime_type": "image/jpeg", "data": image_b64}},
],
}],
"generationConfig": {"maxOutputTokens": 150},
}
url = f"{self.endpoint}?key={self.api_key}"
resp = requests.post(url, json=payload, timeout=30)
resp.raise_for_status()
data = resp.json()
answer = data["candidates"][0]["content"]["parts"][0]["text"].strip()
usage = data.get("usageMetadata", {})
total_tokens = usage.get("totalTokenCount", 0)
return ProviderResponse(answer=answer, total_tokens=total_tokens)

View File

@@ -0,0 +1,66 @@
"""Groq cloud provider — OpenAI-compatible API with vision."""
from __future__ import annotations
import logging
import os
import requests
from .base import ModelInfo, ProviderResponse
logger = logging.getLogger(__name__)
# Groq-specific env vars
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
GROQ_BASE_URL = os.environ.get("GROQ_BASE_URL", "https://api.groq.com/openai/v1")
GROQ_MODEL = os.environ.get("GROQ_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct")
MODELS = {
"meta-llama/llama-4-scout-17b-16e-instruct": ModelInfo(
id="meta-llama/llama-4-scout-17b-16e-instruct",
vision=True,
cost_per_input_token=0.0,
cost_per_output_token=0.0,
notes="Llama 4 Scout, only vision model on Groq free tier",
),
}
class GroqProvider:
name = "groq"
models = MODELS
def __init__(self):
self.api_key = GROQ_API_KEY
self.base_url = GROQ_BASE_URL
self.model = GROQ_MODEL
self.endpoint = f"{self.base_url.rstrip('/')}/chat/completions"
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
def call(self, image_b64: str, prompt: str) -> ProviderResponse:
payload = {
"model": self.model,
"messages": [{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {
"url": f"data:image/jpeg;base64,{image_b64}",
}},
],
}],
"max_tokens": 150,
}
resp = requests.post(self.endpoint, headers=self.headers, json=payload, timeout=30)
resp.raise_for_status()
data = resp.json()
answer = data["choices"][0]["message"]["content"].strip()
total_tokens = data.get("usage", {}).get("total_tokens", 0)
return ProviderResponse(answer=answer, total_tokens=total_tokens)

View File

@@ -0,0 +1,73 @@
"""Generic OpenAI-compatible provider (OpenAI, Together, etc.)."""
from __future__ import annotations
import logging
import os
import requests
from .base import ModelInfo, ProviderResponse
logger = logging.getLogger(__name__)
# OpenAI-compat specific env vars
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4o-mini")
MODELS = {
"gpt-4o-mini": ModelInfo(
id="gpt-4o-mini",
vision=True,
cost_per_input_token=0.00000015,
cost_per_output_token=0.0000006,
notes="Cheap, fast, decent vision",
),
"gpt-4o": ModelInfo(
id="gpt-4o",
vision=True,
cost_per_input_token=0.0000025,
cost_per_output_token=0.00001,
notes="Best OpenAI vision model",
),
}
class OpenAICompatProvider:
name = "openai"
models = MODELS
def __init__(self):
self.api_key = OPENAI_API_KEY
self.base_url = OPENAI_BASE_URL
self.model = OPENAI_MODEL
self.endpoint = f"{self.base_url.rstrip('/')}/chat/completions"
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
def call(self, image_b64: str, prompt: str) -> ProviderResponse:
payload = {
"model": self.model,
"messages": [{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {
"url": f"data:image/jpeg;base64,{image_b64}",
}},
],
}],
"max_tokens": 150,
}
resp = requests.post(self.endpoint, headers=self.headers, json=payload, timeout=30)
resp.raise_for_status()
data = resp.json()
answer = data["choices"][0]["message"]["content"].strip()
total_tokens = data.get("usage", {}).get("total_tokens", 0)
return ProviderResponse(answer=answer, total_tokens=total_tokens)

163
core/detect/sse.py Normal file
View File

@@ -0,0 +1,163 @@
"""
Pydantic Models - GENERATED FILE
Do not edit directly. Regenerate using modelgen.
"""
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel, Field
class GraphNode(BaseModel):
"""A pipeline stage node."""
id: str
status: str = "idle"
items_in: int = 0
items_out: int = 0
class GraphEdge(BaseModel):
"""An edge between pipeline stages."""
source: str
target: str
throughput: int = 0
class BoundingBoxEvent(BaseModel):
"""Bounding box in SSE event payloads."""
x: int
y: int
w: int
h: int
confidence: float
label: str
resolved_brand: Optional[str] = None
source: Optional[str] = None
stage: Optional[str] = None
class BrandSummary(BaseModel):
"""Per-brand stats in the final report."""
brand: str
total_appearances: int = 0
total_screen_time: float = 0.0
avg_confidence: float = 0.0
first_seen: float = 0.0
last_seen: float = 0.0
class GraphUpdate(BaseModel):
"""Pipeline node state transition. SSE event: graph_update"""
nodes: List[GraphNode] = Field(default_factory=list)
edges: List[GraphEdge] = Field(default_factory=list)
active_path: List[str] = Field(default_factory=list)
class StatsUpdate(BaseModel):
"""Funnel statistics snapshot. SSE event: stats_update"""
frames_extracted: int = 0
frames_after_scene_filter: int = 0
cv_regions_detected: int = 0
regions_detected: int = 0
regions_resolved_by_ocr: int = 0
regions_escalated_to_local_vlm: int = 0
regions_escalated_to_cloud_llm: int = 0
cloud_llm_calls: int = 0
processing_time_seconds: float = 0.0
estimated_cloud_cost_usd: float = 0.0
run_id: Optional[str] = None
parent_job_id: Optional[str] = None
run_type: str = "initial"
class FrameUpdate(BaseModel):
"""Current frame being processed. SSE event: frame_update"""
frame_ref: int
timestamp: float
jpeg_b64: str
boxes: List[BoundingBoxEvent] = Field(default_factory=list)
class Detection(BaseModel):
"""A confirmed brand detection. SSE event: detection"""
brand: str
timestamp: float
duration: float
confidence: float
source: str
content_type: str
bbox: Optional[BoundingBoxEvent] = None
frame_ref: Optional[int] = None
class LogEvent(BaseModel):
"""Pipeline log line. SSE event: log"""
level: str
stage: str
msg: str
ts: str
trace_id: Optional[str] = None
class DetectionReportSummary(BaseModel):
"""Final detection report summary."""
video_source: str
content_type: str
duration_seconds: float
total_detections: int = 0
brands: List[BrandSummary] = Field(default_factory=list)
stats: Optional[StatsUpdate] = None
class JobComplete(BaseModel):
"""Final report when pipeline finishes. SSE event: job_complete"""
job_id: str
report: Optional[DetectionReportSummary] = None
class RunContext(BaseModel):
"""Run context injected into all SSE events for grouping."""
run_id: str
parent_job_id: str
run_type: str = "initial"
class CheckpointInfo(BaseModel):
"""Available checkpoint for a stage."""
stage: str
is_scenario: bool = False
scenario_label: str = ""
class ReplayRequest(BaseModel):
"""Request to replay pipeline from a specific stage."""
job_id: str
start_stage: str
config_overrides: Optional[Dict[str, Any]] = None
class ReplayResponse(BaseModel):
"""Result of a replay invocation."""
status: str
job_id: str
start_stage: str
detections: int = 0
brands_found: int = 0
class RetryRequest(BaseModel):
"""Request to queue async retry with different config."""
job_id: str
config_overrides: Optional[Dict[str, Any]] = None
start_stage: str = "escalate_vlm"
schedule_seconds: Optional[float] = None
class RetryResponse(BaseModel):
"""Result of queueing a retry task."""
status: str
task_id: str
job_id: str
class RunRequest(BaseModel):
"""Request body for launching a detection pipeline run."""
video_path: str
profile_name: str = "soccer_broadcast"
source_asset_id: str = ""
checkpoint: bool = True
skip_vlm: bool = False
skip_cloud: bool = False
log_level: str = "INFO"
class RunResponse(BaseModel):
"""Response after starting a pipeline run."""
status: str
job_id: str
video_path: str

View File

@@ -0,0 +1,22 @@
"""
Pipeline stages.
Each stage is a file with a Stage subclass. Auto-discovered via
__init_subclass__ — importing the file registers the stage.
"""
from .base import (
Stage,
get_stage,
get_stage_instance,
list_stages,
list_stage_classes,
get_palette,
)
# Import all stage files to trigger auto-registration
from . import edge_detector # noqa: F401
from . import field_segmentation # noqa: F401
# Import registry for backward compat (other stages still use old pattern)
from . import registry # noqa: F401

View File

@@ -0,0 +1,116 @@
"""
Stage 8 — Report compilation
Groups all detections by brand, merges contiguous appearances,
and builds the final DetectionReport.
"""
from __future__ import annotations
import logging
from core.detect import emit
from core.detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats
logger = logging.getLogger(__name__)
def _merge_contiguous(detections: list[BrandDetection], gap_threshold: float = 2.0) -> list[BrandDetection]:
"""
Merge detections of the same brand that are close in time.
If two detections of the same brand are within gap_threshold seconds,
they're merged into one detection spanning the full range.
"""
if not detections:
return []
sorted_dets = sorted(detections, key=lambda d: (d.brand, d.timestamp))
merged: list[BrandDetection] = []
current = sorted_dets[0]
for det in sorted_dets[1:]:
if (det.brand == current.brand
and det.timestamp <= current.timestamp + current.duration + gap_threshold):
end = max(current.timestamp + current.duration,
det.timestamp + det.duration)
current = BrandDetection(
brand=current.brand,
timestamp=current.timestamp,
duration=end - current.timestamp,
confidence=max(current.confidence, det.confidence),
source=current.source,
bbox=current.bbox,
frame_ref=current.frame_ref,
content_type=current.content_type,
)
else:
merged.append(current)
current = det
merged.append(current)
return merged
def compile_report(
detections: list[BrandDetection],
stats: PipelineStats,
video_source: str = "",
content_type: str = "",
duration_seconds: float = 0.0,
job_id: str | None = None,
) -> DetectionReport:
"""
Build the final detection report from all accumulated detections.
Merges contiguous detections, computes per-brand stats,
and emits the job_complete event.
"""
merged = _merge_contiguous(detections)
brands: dict[str, BrandStats] = {}
for d in merged:
if d.brand not in brands:
brands[d.brand] = BrandStats()
s = brands[d.brand]
s.total_appearances += 1
s.total_screen_time += d.duration
s.avg_confidence = (
(s.avg_confidence * (s.total_appearances - 1) + d.confidence)
/ s.total_appearances
)
if s.first_seen == 0.0 or d.timestamp < s.first_seen:
s.first_seen = d.timestamp
if d.timestamp > s.last_seen:
s.last_seen = d.timestamp
report = DetectionReport(
video_source=video_source,
content_type=content_type,
duration_seconds=duration_seconds,
brands=brands,
timeline=sorted(merged, key=lambda d: d.timestamp),
pipeline_stats=stats,
)
emit.log(job_id, "Aggregator", "INFO",
f"Report: {len(brands)} brands, {len(merged)} detections "
f"(merged from {len(detections)} raw)")
emit.job_complete(job_id, {
"video_source": report.video_source,
"content_type": report.content_type,
"duration_seconds": report.duration_seconds,
"brands": {
k: {
"total_appearances": v.total_appearances,
"total_screen_time": v.total_screen_time,
"avg_confidence": round(v.avg_confidence, 3),
"first_seen": v.first_seen,
"last_seen": v.last_seen,
}
for k, v in brands.items()
},
})
return report

151
core/detect/stages/base.py Normal file
View File

@@ -0,0 +1,151 @@
"""
Stage base class — common interface for all pipeline stages.
Each stage is a file that subclasses Stage. Auto-discovered via
__init_subclass__. No manual registration needed.
A stage:
- Has a StageDefinition (generated from schema) with name, config, IO
- Implements run(frames, config) → output
- Owns its output serialization (opaque blob)
- Optionally has a TypeScript port for browser-side execution
The checkpoint layer stores stage output as blobs without knowing
the format. The stage that wrote it is the only one that can read it.
"""
from __future__ import annotations
from typing import Any
import numpy as np
from core.detect.stages.models import (
StageConfigField,
StageIO,
StageDefinition,
)
# Legacy runtime extension — adds callable fields for old-style stages.
# New stages use Stage subclass with serialize()/deserialize() methods instead.
class LegacyStageDefinition:
"""Wraps a StageDefinition with callable serialize/deserialize functions."""
def __init__(self, definition: StageDefinition, fn=None, serialize_fn=None, deserialize_fn=None):
self._definition = definition
self.fn = fn
self.serialize_fn = serialize_fn
self.deserialize_fn = deserialize_fn
def __getattr__(self, name):
return getattr(self._definition, name)
# ---------------------------------------------------------------------------
# Registry — auto-populated by __init_subclass__ (new stages)
# + register_stage() (legacy stages during migration)
# ---------------------------------------------------------------------------
_REGISTRY: dict[str, type['Stage']] = {}
_LEGACY_REGISTRY: dict[str, LegacyStageDefinition] = {}
def register_stage(
definition: StageDefinition,
fn=None,
serialize_fn=None,
deserialize_fn=None,
):
"""Legacy registration for stages not yet converted to Stage subclass."""
legacy = LegacyStageDefinition(definition, fn=fn, serialize_fn=serialize_fn, deserialize_fn=deserialize_fn)
_LEGACY_REGISTRY[definition.name] = legacy
class Stage:
"""
Base class for all pipeline stages.
Subclass this in detect/stages/<name>.py. Define `definition` as a
class attribute. Implement `run()`. Optionally override `serialize()`
and `deserialize()` for custom blob formats (default is JSON).
"""
definition: StageDefinition # set by each subclass
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if hasattr(cls, 'definition') and cls.definition is not None:
_REGISTRY[cls.definition.name] = cls
def run(self, frames: list, config: dict) -> Any:
raise NotImplementedError
def serialize(self, output: Any) -> bytes:
"""Serialize stage output to bytes for checkpoint storage."""
import json
return json.dumps(output, default=str).encode()
def deserialize(self, data: bytes) -> Any:
"""Deserialize stage output from checkpoint blob."""
import json
return json.loads(data)
# ---------------------------------------------------------------------------
# Discovery API
# ---------------------------------------------------------------------------
def _all_definitions():
"""Merge new Stage subclass registry + legacy registry.
Returns StageDefinition for new-style stages,
LegacyStageDefinition for legacy stages (has serialize_fn etc).
"""
merged = {}
for name, legacy in _LEGACY_REGISTRY.items():
merged[name] = legacy
for name, cls in _REGISTRY.items():
merged[name] = cls.definition
return merged
def get_stage(name: str) -> StageDefinition:
"""Get a stage definition by name (works for both new and legacy)."""
all_defs = _all_definitions()
if name not in all_defs:
raise KeyError(f"Unknown stage: {name!r}. Registered: {list(all_defs)}")
return all_defs[name]
def get_stage_class(name: str) -> type[Stage] | None:
"""Get a Stage subclass by name. Returns None for legacy stages."""
return _REGISTRY.get(name)
def get_stage_instance(name: str) -> Stage:
"""Get an instantiated Stage by name. Only works for new-style stages."""
cls = _REGISTRY.get(name)
if cls is None:
raise KeyError(f"No Stage subclass for {name!r}. Legacy stages don't have instances.")
return cls()
def list_stages() -> list[StageDefinition]:
"""List all registered stage definitions (new + legacy)."""
return list(_all_definitions().values())
def list_stage_classes() -> list[type[Stage]]:
"""List all registered Stage subclasses (new-style only)."""
return list(_REGISTRY.values())
def get_palette() -> dict[str, list[StageDefinition]]:
"""Group stages by category for the editor palette."""
palette: dict[str, list[StageDefinition]] = {}
for defn in _all_definitions().values():
if defn.category not in palette:
palette[defn.category] = []
palette[defn.category].append(defn)
return palette

View File

@@ -0,0 +1,216 @@
"""
Stage 5 — Brand Resolver (discovery mode)
Discovery-first brand matching. No static dictionary — all brands live in the DB.
Flow:
1. Check session brands first (brands already seen in this run, in-memory)
2. Check global known brands (accumulated across all runs)
3. Unresolved candidates → escalate to VLM/cloud
4. Confirmed brands get added to DB for future runs
"""
from __future__ import annotations
import logging
from rapidfuzz import fuzz
from core.detect import emit
from core.detect.models import BrandDetection, TextCandidate
from core.detect.stages.models import ResolverConfig
logger = logging.getLogger(__name__)
def _normalize(text: str) -> str:
return text.strip().lower()
def _has_db() -> bool:
try:
from core.db import find_brand_by_text as _
return True
except (ImportError, Exception):
return False
def _match_session(text: str, session_brands: dict[str, str]) -> str | None:
return session_brands.get(_normalize(text))
def _match_known(text: str, threshold: int) -> tuple[str | None, str | None]:
"""Check against global known brands in DB. Returns (canonical_name, brand_id) or (None, None)."""
if not _has_db():
return None, None
from core.db import find_brand_by_text, list_brands
from core.db.connection import get_session
with get_session() as session:
brand = find_brand_by_text(session, text)
if brand:
return brand.canonical_name, str(brand.id)
all_brands = list_brands(session)
normalized = _normalize(text)
best_brand = None
best_score = 0
for known in all_brands:
names = [known.canonical_name] + (known.aliases or [])
for name in names:
score = fuzz.ratio(normalized, _normalize(name))
if score > best_score and score >= threshold:
best_score = score
best_brand = known
if best_brand:
return best_brand.canonical_name, str(best_brand.id)
return None, None
def _register_brand(canonical_name: str, source: str) -> str | None:
"""Register a newly discovered brand in the DB. Returns brand_id."""
if not _has_db():
return None
from core.db import get_or_create_brand
from core.db.connection import get_session
with get_session() as session:
brand, created = get_or_create_brand(session, canonical_name, source=source)
session.commit()
if created:
logger.info("New brand discovered: %s (source=%s)", canonical_name, source)
return str(brand.id)
def _record_airing(timeline_id: str | None, brand_id: str,
frame_seq: int, confidence: float, source: str):
"""Record a brand airing on a timeline."""
if not _has_db() or not timeline_id:
return
from core.db import record_airing
from core.db.connection import get_session
from uuid import UUID
with get_session() as session:
record_airing(
session,
brand_id=UUID(brand_id),
timeline_id=UUID(timeline_id),
frame_start=frame_seq,
frame_end=frame_seq,
confidence=confidence,
source=source,
)
session.commit()
def build_session_dict(source_asset_id: str | None = None) -> dict[str, str]:
"""
Load known brands from DB as a session lookup dict.
Returns {normalized_name: canonical_name, ...} including aliases.
"""
if not _has_db():
return {}
from core.db import list_brands
from core.db.connection import get_session
with get_session() as session:
all_brands = list_brands(session)
session_dict = {}
for brand in all_brands:
session_dict[_normalize(brand.canonical_name)] = brand.canonical_name
for alias in (brand.aliases or []):
session_dict[_normalize(alias)] = brand.canonical_name
return session_dict
def resolve_brands(
candidates: list[TextCandidate],
config: ResolverConfig,
session_brands: dict[str, str] | None = None,
source_asset_id: str | None = None,
content_type: str = "",
job_id: str | None = None,
) -> tuple[list[BrandDetection], list[TextCandidate]]:
"""
Match text candidates against known brands (session → global → unresolved).
session_brands: pre-loaded session dict (from build_session_dict)
job_id: timeline_id — used to record airings
"""
if session_brands is None:
session_brands = {}
emit.log(job_id, "BrandResolver", "INFO",
f"Resolving {len(candidates)} candidates "
f"(session={len(session_brands)} brands, fuzzy={config.fuzzy_threshold})")
matched: list[BrandDetection] = []
unresolved: list[TextCandidate] = []
session_hits = 0
known_hits = 0
for candidate in candidates:
text = candidate.text
brand_name = None
brand_id = None
match_source = "ocr"
# 1. Check session (cheapest — in-memory dict)
brand_name = _match_session(text, session_brands)
if brand_name:
session_hits += 1
else:
# 2. Check global known brands (DB query + fuzzy)
brand_name, brand_id = _match_known(text, config.fuzzy_threshold)
if brand_name:
known_hits += 1
session_brands[_normalize(brand_name)] = brand_name
if brand_name:
detection = BrandDetection(
brand=brand_name,
timestamp=candidate.frame.timestamp,
duration=0.5,
confidence=candidate.ocr_confidence,
source=match_source,
bbox=candidate.bbox,
frame_ref=candidate.frame.sequence,
content_type=content_type,
)
matched.append(detection)
if brand_id:
_record_airing(
job_id, brand_id,
candidate.frame.sequence, candidate.ocr_confidence, match_source,
)
emit.detection(
job_id,
brand=brand_name,
confidence=candidate.ocr_confidence,
source=match_source,
timestamp=candidate.frame.timestamp,
content_type=content_type,
frame_ref=candidate.frame.sequence,
)
else:
unresolved.append(candidate)
emit.log(job_id, "BrandResolver", "INFO",
f"Session: {session_hits}, Known: {known_hits}, "
f"Unresolved: {len(unresolved)} → escalating")
return matched, unresolved

View File

@@ -0,0 +1,292 @@
"""
Stage — Edge Detection
Canny + HoughLinesP to find horizontal line pairs that bound
advertising hoardings. Pure OpenCV, no ML models.
Two modes:
- Remote: calls GPU inference server over HTTP
- Local: imports cv2 directly (OpenCV on same machine)
"""
from __future__ import annotations
import base64
import io
import json
import logging
import os
import time
from typing import Any
from PIL import Image
from core.detect import emit
from core.detect.models import BoundingBox, Frame
from core.detect.stages.base import Stage
from core.detect.stages.models import StageDefinition, StageConfigField, StageIO, StageOutputHint
logger = logging.getLogger(__name__)
class EdgeDetectionStage(Stage):
definition = StageDefinition(
name="detect_edges",
label="Edge Detection",
description="Canny + HoughLinesP — find horizontal line pairs (hoarding boundaries)",
category="cv_analysis",
io=StageIO(
reads=["filtered_frames"],
writes=["edge_regions_by_frame"],
),
config_fields=[
StageConfigField(name="edge_canny_low", type="int", default=50, description="Canny low threshold", min=0, max=255),
StageConfigField(name="edge_canny_high", type="int", default=150, description="Canny high threshold", min=0, max=255),
StageConfigField(name="edge_hough_threshold", type="int", default=80, description="Hough accumulator threshold", min=1, max=500),
StageConfigField(name="edge_hough_min_length", type="int", default=100, description="Min line length (px)", min=10, max=2000),
StageConfigField(name="edge_hough_max_gap", type="int", default=10, description="Max line gap (px)", min=1, max=100),
StageConfigField(name="edge_pair_max_distance", type="int", default=200, description="Max distance between line pair (px)", min=10, max=500),
StageConfigField(name="edge_pair_min_distance", type="int", default=15, description="Min distance between line pair (px)", min=5, max=200),
],
output_hints=[
StageOutputHint(key="edge_regions_by_frame", type="boxes_by_frame", label="Edge regions"),
StageOutputHint(key="edge_overlay_b64", type="overlay", label="Canny edges", default_opacity=0.25),
StageOutputHint(key="lines_overlay_b64", type="overlay", label="Hough lines", default_opacity=0.25),
],
tracks_element="edge_region",
)
def run(self, frames: list[Frame], config: dict) -> dict[int, list[BoundingBox]]:
"""
Run edge detection on all frames.
Config keys: enabled, edge_canny_low, edge_canny_high, edge_hough_threshold,
edge_hough_min_length, edge_hough_max_gap, edge_pair_max_distance, edge_pair_min_distance,
debug (bool), inference_url (str|None), job_id (str|None).
Returns dict mapping frame sequence → list of BoundingBox.
"""
enabled = config.get("enabled", True)
job_id = config.get("job_id")
inference_url = config.get("inference_url") or os.environ.get("INFERENCE_URL")
if not enabled:
emit.log(job_id, "EdgeDetection", "INFO", "Edge detection disabled, skipping")
return {}
mode = "remote" if inference_url else "local"
emit.log(job_id, "EdgeDetection", "INFO",
f"Detecting edges in {len(frames)} frames (mode={mode})")
all_boxes: dict[int, list[BoundingBox]] = {}
total_regions = 0
for frame in frames:
t0 = time.monotonic()
if inference_url:
boxes = self._run_remote(frame, config, inference_url, job_id or "")
else:
boxes = self._run_local(frame, config)
ms = (time.monotonic() - t0) * 1000
all_boxes[frame.sequence] = boxes
total_regions += len(boxes)
emit.log(job_id, "EdgeDetection", "DEBUG",
f"Frame {frame.sequence}: {len(boxes)} regions in {ms:.0f}ms"
+ (f" [{', '.join(b.label for b in boxes)}]" if boxes else ""))
if boxes and job_id:
box_dicts = [
{"x": b.x, "y": b.y, "w": b.w, "h": b.h,
"confidence": b.confidence, "label": b.label,
"stage": "detect_edges"}
for b in boxes
]
emit.frame_update(
job_id,
frame_ref=frame.sequence,
timestamp=frame.timestamp,
jpeg_b64=_frame_to_b64(frame),
boxes=box_dicts,
)
emit.log(job_id, "EdgeDetection", "INFO",
f"Found {total_regions} edge regions across {len(frames)} frames")
emit.stats(job_id, cv_regions_detected=total_regions)
return all_boxes
def serialize(self, output: Any) -> bytes:
"""Serialize edge regions to JSON blob."""
serialized = {}
for seq, boxes in output.items():
serialized[str(seq)] = [
{"x": b.x, "y": b.y, "w": b.w, "h": b.h,
"confidence": b.confidence, "label": b.label}
for b in boxes
]
return json.dumps(serialized).encode()
def deserialize(self, data: bytes) -> dict[int, list[BoundingBox]]:
"""Deserialize edge regions from JSON blob."""
raw = json.loads(data)
result = {}
for seq_str, box_dicts in raw.items():
boxes = [
BoundingBox(x=b["x"], y=b["y"], w=b["w"], h=b["h"],
confidence=b["confidence"], label=b["label"])
for b in box_dicts
]
result[int(seq_str)] = boxes
return result
# --- Private helpers ---
def _run_remote(self, frame: Frame, config: dict,
inference_url: str, job_id: str) -> list[BoundingBox]:
from core.detect.inference import InferenceClient
from core.detect.emit import _run_log_level
client = InferenceClient(
base_url=inference_url, job_id=job_id, log_level=_run_log_level,
)
results = client.detect_edges(
image=frame.image,
edge_canny_low=config.get("edge_canny_low", 50),
edge_canny_high=config.get("edge_canny_high", 150),
edge_hough_threshold=config.get("edge_hough_threshold", 80),
edge_hough_min_length=config.get("edge_hough_min_length", 100),
edge_hough_max_gap=config.get("edge_hough_max_gap", 10),
edge_pair_max_distance=config.get("edge_pair_max_distance", 200),
edge_pair_min_distance=config.get("edge_pair_min_distance", 15),
)
boxes = []
for r in results:
box = BoundingBox(
x=r.x, y=r.y, w=r.w, h=r.h,
confidence=r.confidence, label=r.label,
)
boxes.append(box)
return boxes
def _run_local(self, frame: Frame, config: dict) -> list[BoundingBox]:
detect_edges_fn = _load_cv_edges().detect_edges
edge_results = detect_edges_fn(
frame.image,
canny_low=config.get("edge_canny_low", 50),
canny_high=config.get("edge_canny_high", 150),
hough_threshold=config.get("edge_hough_threshold", 80),
hough_min_length=config.get("edge_hough_min_length", 100),
hough_max_gap=config.get("edge_hough_max_gap", 10),
pair_max_distance=config.get("edge_pair_max_distance", 200),
pair_min_distance=config.get("edge_pair_min_distance", 15),
)
boxes = []
for r in edge_results:
box = BoundingBox(
x=r["x"], y=r["y"], w=r["w"], h=r["h"],
confidence=r["confidence"], label=r["label"],
)
boxes.append(box)
return boxes
# --- Module-level helpers ---
def _frame_to_b64(frame: Frame) -> str:
img = Image.fromarray(frame.image)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=70)
return base64.b64encode(buf.getvalue()).decode()
_cv_edges_mod = None
def _load_cv_edges():
global _cv_edges_mod
if _cv_edges_mod is None:
import importlib.util
from pathlib import Path
spec = importlib.util.spec_from_file_location("cv_edges", Path("core/gpu/models/cv/edges.py"))
_cv_edges_mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(_cv_edges_mod)
return _cv_edges_mod
# --- Backward compat: standalone function for graph.py ---
def _filter_by_field_mask(boxes, mask, margin_px=50):
"""
Keep only boxes that are near the pitch boundary (hoarding zone).
The field mask has 255=pitch, 0=not pitch. Hoardings sit just
outside the pitch boundary. We dilate the mask to create a
"boundary zone" and keep boxes whose center falls in the zone
between the dilated mask edge and the original mask.
"""
import cv2
import numpy as np
if mask is None or not boxes:
return boxes
# Dilate the pitch mask — the expansion zone is where hoardings are
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (margin_px * 2, margin_px * 2))
dilated = cv2.dilate(mask, kernel)
# Boundary zone = dilated but NOT original pitch
boundary_zone = cv2.bitwise_and(dilated, cv2.bitwise_not(mask))
kept = []
for box in boxes:
cx = box.x + box.w // 2
cy = box.y + box.h // 2
# Clamp to image bounds
cy = min(cy, boundary_zone.shape[0] - 1)
cx = min(cx, boundary_zone.shape[1] - 1)
if boundary_zone[cy, cx] > 0:
kept.append(box)
return kept
def detect_edge_regions(frames, config, inference_url=None, job_id=None, field_masks=None):
"""Convenience wrapper — calls EdgeDetectionStage.run(), optionally filters by field mask."""
stage = EdgeDetectionStage()
cfg = {
"enabled": config.enabled,
"edge_canny_low": config.edge_canny_low,
"edge_canny_high": config.edge_canny_high,
"edge_hough_threshold": config.edge_hough_threshold,
"edge_hough_min_length": config.edge_hough_min_length,
"edge_hough_max_gap": config.edge_hough_max_gap,
"edge_pair_max_distance": config.edge_pair_max_distance,
"edge_pair_min_distance": config.edge_pair_min_distance,
"inference_url": inference_url,
"job_id": job_id,
}
all_boxes = stage.run(frames, cfg)
# Filter by field segmentation mask if available
if field_masks:
filtered_total = 0
original_total = sum(len(b) for b in all_boxes.values())
for seq, boxes in all_boxes.items():
mask = field_masks.get(seq)
if mask is not None:
all_boxes[seq] = _filter_by_field_mask(boxes, mask)
filtered_total += len(all_boxes[seq])
else:
filtered_total += len(boxes)
if original_total != filtered_total:
from core.detect import emit
emit.log(job_id, "EdgeDetection", "INFO",
f"Field mask filter: {original_total}{filtered_total} regions")
return all_boxes

View File

@@ -0,0 +1,151 @@
"""
Stage — Field Segmentation
Calls the GPU inference server to detect pitch boundaries via
HSV green mask + morphology. The CV code lives in core/gpu/models/cv/.
Outputs a mask and boundary that downstream stages use as spatial priors.
"""
from __future__ import annotations
import base64
import io
import logging
import numpy as np
from PIL import Image
from core.detect import emit
from core.detect.models import Frame
from core.detect.stages.base import Stage
from core.detect.stages.models import (
FieldSegmentationConfig,
StageConfigField,
StageDefinition,
StageIO,
StageOutputHint,
TransformOption,
)
logger = logging.getLogger(__name__)
class FieldSegmentationStage(Stage):
definition = StageDefinition(
name="field_segmentation",
label="Field Segmentation",
description="HSV green mask — detect pitch boundaries for spatial priors",
category="cv_analysis",
io=StageIO(
reads=["filtered_frames"],
writes=["field_mask"],
),
config_fields=[
StageConfigField(name="hue_low", type="int", default=30, description="HSV hue lower bound", min=0, max=180),
StageConfigField(name="hue_high", type="int", default=85, description="HSV hue upper bound", min=0, max=180),
StageConfigField(name="sat_low", type="int", default=30, description="HSV saturation lower bound", min=0, max=255),
StageConfigField(name="sat_high", type="int", default=255, description="HSV saturation upper bound", min=0, max=255),
StageConfigField(name="val_low", type="int", default=30, description="HSV value lower bound", min=0, max=255),
StageConfigField(name="val_high", type="int", default=255, description="HSV value upper bound", min=0, max=255),
StageConfigField(name="morph_kernel", type="int", default=15, description="Morphology kernel size", min=3, max=51),
StageConfigField(name="min_area_ratio", type="float", default=0.05, description="Min contour area as fraction of frame", min=0.01, max=0.5),
],
output_hints=[
StageOutputHint(key="mask_overlay_b64", type="overlay", label="Field mask", default_opacity=0.5, src_format="png"),
],
accepted_transforms=[
TransformOption(key="invert_mask", type="bool", default=False, label="Invert selection", description="Invert the mask so downstream stages look outside the detected area"),
],
)
def _frame_to_b64(frame: Frame) -> str:
"""Encode frame image as base64 JPEG."""
img = Image.fromarray(frame.image)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=85)
return base64.b64encode(buf.getvalue()).decode()
def _decode_mask_b64(mask_b64: str) -> np.ndarray:
"""Decode a base64 PNG mask back to numpy array."""
data = base64.b64decode(mask_b64)
img = Image.open(io.BytesIO(data)).convert("L")
return np.array(img)
def run_field_segmentation(
frames: list[Frame],
config: FieldSegmentationConfig,
inference_url: str | None = None,
job_id: str | None = None,
) -> dict:
"""
Run field segmentation on all frames via the inference server.
Returns dict with:
field_masks: {seq: np.ndarray}
field_boundaries: {seq: [(x,y), ...]}
field_coverage: {seq: float}
"""
if not config.enabled:
emit.log(job_id, "FieldSegmentation", "INFO", "Disabled, skipping")
return {"field_masks": {}, "field_boundaries": {}, "field_coverage": {}}
import os
url = inference_url or os.environ.get("INFERENCE_URL")
if not url:
emit.log(job_id, "FieldSegmentation", "WARNING",
"No INFERENCE_URL, skipping field segmentation")
return {"field_masks": {}, "field_boundaries": {}, "field_coverage": {}}
emit.log(job_id, "FieldSegmentation", "INFO",
f"Segmenting {len(frames)} frames (hue={config.hue_low}-{config.hue_high})")
from core.detect.inference import InferenceClient
from core.detect.emit import _run_log_level
client = InferenceClient(base_url=url, job_id=job_id or "", log_level=_run_log_level)
field_masks = {}
field_mask_overlays = {}
field_boundaries = {}
field_coverage = {}
for frame in frames:
image_b64 = _frame_to_b64(frame)
resp = client.post("/segment_field", {
"image": image_b64,
"hue_low": config.hue_low,
"hue_high": config.hue_high,
"sat_low": config.sat_low,
"sat_high": config.sat_high,
"val_low": config.val_low,
"val_high": config.val_high,
"morph_kernel": config.morph_kernel,
"min_area_ratio": config.min_area_ratio,
})
if resp is None:
continue
mask_b64 = resp.get("mask_b64", "")
if mask_b64:
field_masks[frame.sequence] = _decode_mask_b64(mask_b64)
field_mask_overlays[frame.sequence] = mask_b64
field_boundaries[frame.sequence] = resp.get("boundary", [])
field_coverage[frame.sequence] = resp.get("coverage", 0.0)
avg_coverage = sum(field_coverage.values()) / max(len(field_coverage), 1)
emit.log(job_id, "FieldSegmentation", "INFO",
f"Done: {len(frames)} frames, avg coverage {avg_coverage:.1%}")
return {
"field_masks": field_masks,
"field_mask_overlays": field_mask_overlays,
"field_boundaries": field_boundaries,
"field_coverage": field_coverage,
}

View File

@@ -0,0 +1,93 @@
"""
Stage 1 — Frame Extraction
Extracts frames from a video at a configurable FPS using the core ffmpeg module.
Emits log + stats_update SSE events as it works.
"""
from __future__ import annotations
import tempfile
import time
from pathlib import Path
import ffmpeg
import numpy as np
from PIL import Image
from core.ffmpeg.probe import probe_file
from core.detect import emit
from core.detect.models import Frame
from core.detect.stages.models import FrameExtractionConfig
def _load_frames(tmpdir: Path, fps: float) -> list[Frame]:
"""Load extracted JPEG files into Frame objects."""
frame_files = sorted(tmpdir.glob("frame_*.jpg"))
frames = []
for i, fpath in enumerate(frame_files):
img = Image.open(fpath)
frame = Frame(
sequence=i,
chunk_id=0,
timestamp=i / fps,
image=np.array(img),
)
frames.append(frame)
return frames
def extract_frames(
video_path: str,
config: FrameExtractionConfig,
job_id: str | None = None,
) -> list[Frame]:
"""
Extract frames from video at the configured FPS.
Uses ffmpeg-python to build the extraction pipeline,
outputs JPEG files to a temp dir, then loads as numpy arrays.
"""
probe = probe_file(video_path)
duration = probe.duration or 0.0
emit.log(job_id, "FrameExtractor", "INFO",
f"Starting extraction: {Path(video_path).name} "
f"({duration:.1f}s, {probe.width}x{probe.height}, fps={config.fps})")
emit.log(job_id, "FrameExtractor", "DEBUG",
f"Probe: codec={probe.video_codec}, bitrate={probe.video_bitrate}, max_frames={config.max_frames}")
with tempfile.TemporaryDirectory() as tmpdir:
pattern = str(Path(tmpdir) / "frame_%06d.jpg")
stream = (
ffmpeg
.input(video_path)
.filter("fps", fps=config.fps)
.output(pattern, qscale=2, frames=config.max_frames)
.overwrite_output()
)
t0 = time.monotonic()
try:
stream.run(capture_stdout=True, capture_stderr=True, quiet=True)
except ffmpeg.Error as e:
stderr = e.stderr.decode() if e.stderr else "unknown error"
emit.log(job_id, "FrameExtractor", "ERROR", f"FFmpeg failed: {stderr[:200]}")
raise RuntimeError(f"FFmpeg failed: {stderr}") from e
ffmpeg_ms = (time.monotonic() - t0) * 1000
emit.log(job_id, "FrameExtractor", "DEBUG", f"FFmpeg decode: {ffmpeg_ms:.0f}ms")
t0 = time.monotonic()
frames = _load_frames(Path(tmpdir), config.fps)
load_ms = (time.monotonic() - t0) * 1000
if frames:
h, w = frames[0].image.shape[:2]
mem_mb = sum(f.image.nbytes for f in frames) / (1024 * 1024)
emit.log(job_id, "FrameExtractor", "DEBUG",
f"Loaded {len(frames)} frames ({w}x{h}) in {load_ms:.0f}ms, {mem_mb:.1f}MB in memory")
emit.log(job_id, "FrameExtractor", "INFO", f"Extracted {len(frames)} frames")
emit.stats(job_id, frames_extracted=len(frames))
return frames

View File

@@ -0,0 +1,125 @@
"""
Pydantic Models - GENERATED FILE
Do not edit directly. Regenerate using modelgen.
"""
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel, Field
class StageConfigField(BaseModel):
"""A single tunable config parameter for the editor UI."""
name: str
type: str
default: Any
description: str = ""
min: Optional[float] = None
max: Optional[float] = None
options: Optional[List[str]] = None
class StageIO(BaseModel):
"""Declares what a stage reads and writes."""
reads: List[str] = Field(default_factory=list)
writes: List[str] = Field(default_factory=list)
optional_reads: List[str] = Field(default_factory=list)
class StageOutputHint(BaseModel):
"""How to render a stage output in the compare/editor views."""
key: str
type: str
label: str = ""
default_opacity: float = 0.5
src_format: str = "png"
class TransformOption(BaseModel):
"""A transform the stage accepts on its incoming edges."""
key: str
type: str
default: Any = False
label: str = ""
description: str = ""
class StageDefinition(BaseModel):
"""Complete metadata for a pipeline stage."""
name: str
label: str
description: str
category: str = "detection"
io: StageIO
config_fields: List[StageConfigField] = Field(default_factory=list)
output_hints: List[StageOutputHint] = Field(default_factory=list)
accepted_transforms: List[TransformOption] = Field(default_factory=list)
tracks_element: Optional[str] = None
class FrameExtractionConfig(BaseModel):
"""FrameExtractionConfig(fps: float = 2.0, max_frames: int = 500)"""
fps: float = 2.0
max_frames: int = 500
class SceneFilterConfig(BaseModel):
"""SceneFilterConfig(hamming_threshold: int = 8, enabled: bool = True)"""
hamming_threshold: int = 8
enabled: bool = True
class DetectionConfig(BaseModel):
"""DetectionConfig(model_name: str = 'yolov8n.pt', confidence_threshold: float = 0.3, target_classes: List[str] = <factory>)"""
model_name: str = "yolov8n.pt"
confidence_threshold: float = 0.3
target_classes: List[str]
class OCRConfig(BaseModel):
"""OCRConfig(languages: List[str] = <factory>, min_confidence: float = 0.5)"""
languages: List[str]
min_confidence: float = 0.5
class ResolverConfig(BaseModel):
"""ResolverConfig(fuzzy_threshold: int = 75)"""
fuzzy_threshold: int = 75
class RegionAnalysisConfig(BaseModel):
"""RegionAnalysisConfig(enabled: bool = True, edge_canny_low: int = 50, edge_canny_high: int = 150, edge_hough_threshold: int = 80, edge_hough_min_length: int = 100, edge_hough_max_gap: int = 10, edge_pair_max_distance: int = 200, edge_pair_min_distance: int = 15)"""
enabled: bool = True
edge_canny_low: int = 50
edge_canny_high: int = 150
edge_hough_threshold: int = 80
edge_hough_min_length: int = 100
edge_hough_max_gap: int = 10
edge_pair_max_distance: int = 200
edge_pair_min_distance: int = 15
class FieldSegmentationConfig(BaseModel):
"""FieldSegmentationConfig(enabled: bool = True, hue_low: int = 30, hue_high: int = 85, sat_low: int = 30, sat_high: int = 255, val_low: int = 30, val_high: int = 255, morph_kernel: int = 15, min_area_ratio: float = 0.05)"""
enabled: bool = True
hue_low: int = 30
hue_high: int = 85
sat_low: int = 30
sat_high: int = 255
val_low: int = 30
val_high: int = 255
morph_kernel: int = 15
min_area_ratio: float = 0.05
class StageRef(BaseModel):
"""Reference to a stage in the pipeline graph."""
name: str
branch: str = "trunk"
execution_target: str = "local"
class Edge(BaseModel):
"""Connection between stages in the graph."""
source: str
target: str
condition: str = ""
transform: Dict[str, Any] = Field(default_factory=dict)
class PipelineConfig(BaseModel):
"""Pipeline graph topology + routing rules."""
name: str
profile_name: str
stages: List[StageRef] = Field(default_factory=list)
edges: List[Edge] = Field(default_factory=list)
routing_rules: Dict[str, Any] = Field(default_factory=dict)

View File

@@ -0,0 +1,139 @@
"""
Stage 4 — OCR
Reads text from detected regions (YOLO bounding box crops).
Two modes:
- remote: calls inference server over HTTP (separate GPU box, or localhost)
- local: runs PaddleOCR in-process (single-box setup with enough VRAM)
The mode is selected by whether inference_url is provided.
Model instances are cached at module level so they survive across pipeline runs.
"""
from __future__ import annotations
import logging
import time
from typing import TYPE_CHECKING
import numpy as np
from core.detect import emit
from core.detect.models import BoundingBox, Frame, TextCandidate
from core.detect.stages.models import OCRConfig
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
# Module-level cache — avoids reloading the model for every crop or pipeline run
_local_ocr_cache: dict[str, object] = {}
def _crop_region(frame: Frame, box: BoundingBox) -> np.ndarray:
h, w = frame.image.shape[:2]
x1 = max(0, box.x)
y1 = max(0, box.y)
x2 = min(w, box.x + box.w)
y2 = min(h, box.y + box.h)
return frame.image[y1:y2, x1:x2]
def _get_local_model(lang: str):
if lang not in _local_ocr_cache:
from paddleocr import PaddleOCR
logger.info("Loading PaddleOCR locally (lang=%s)", lang)
_local_ocr_cache[lang] = PaddleOCR(lang=lang)
return _local_ocr_cache[lang]
def _parse_ocr_raw(raw, min_confidence: float) -> list[dict]:
"""Parse PaddleOCR 3.x result — handles dict-based and nested-list layouts."""
results = []
for page in (raw or []):
if not page:
continue
if isinstance(page, dict):
for text, confidence in zip(page.get("rec_texts", []), page.get("rec_scores", [])):
if float(confidence) >= min_confidence:
results.append({"text": text, "confidence": float(confidence)})
continue
for line in page:
if not line:
continue
rec = line[1]
if isinstance(rec, (list, tuple)) and len(rec) >= 2:
text, confidence = rec[0], rec[1]
if float(confidence) >= min_confidence:
results.append({"text": text, "confidence": float(confidence)})
return results
def run_ocr(
frames: list[Frame],
boxes_by_frame: dict[int, list[BoundingBox]],
config: OCRConfig,
inference_url: str | None = None,
job_id: str | None = None,
) -> list[TextCandidate]:
"""
Run OCR on cropped regions from YOLO detections.
inference_url=None → local in-process PaddleOCR (single-box)
inference_url=str → remote inference server (split or localhost)
"""
total_regions = sum(len(boxes) for boxes in boxes_by_frame.values())
mode = "remote" if inference_url else "local"
emit.log(job_id, "OCRStage", "INFO",
f"Running OCR on {total_regions} regions (mode={mode})")
# Build these once per pipeline run, not per crop
if inference_url:
from core.detect.inference import InferenceClient
from core.detect.emit import _run_log_level
client = InferenceClient(base_url=inference_url, job_id=job_id or "", log_level=_run_log_level)
else:
model = _get_local_model(config.languages[0])
frame_map = {f.sequence: f for f in frames}
candidates: list[TextCandidate] = []
for seq, boxes in boxes_by_frame.items():
frame = frame_map.get(seq)
if not frame:
continue
for box in boxes:
crop = _crop_region(frame, box)
if crop.size == 0:
continue
t0 = time.monotonic()
if inference_url:
raw_results = client.ocr(image=crop, languages=config.languages)
texts = [{"text": r.text, "confidence": r.confidence} for r in raw_results]
else:
raw = model.ocr(crop)
texts = _parse_ocr_raw(raw, config.min_confidence)
ocr_ms = (time.monotonic() - t0) * 1000
h, w = crop.shape[:2]
text_preview = ", ".join(t["text"][:30] for t in texts) if texts else "(none)"
emit.log(job_id, "OCRStage", "DEBUG",
f"Frame {seq} box {box.x},{box.y} ({w}x{h}): {ocr_ms:.0f}ms → {text_preview}")
for t in texts:
candidates.append(TextCandidate(
frame=frame,
bbox=box,
text=t["text"],
ocr_confidence=t["confidence"],
))
emit.log(job_id, "OCRStage", "INFO",
f"Extracted text from {len(candidates)} regions")
emit.stats(job_id, regions_resolved_by_ocr=len(candidates))
return candidates

View File

@@ -0,0 +1,128 @@
"""
Stage 3.5 — Preprocessing
Runs between YOLO detection and OCR. Applies configurable image
preprocessing to each detected region crop: contrast enhancement,
deskewing, binarization.
Operates on the crops derived from boxes_by_frame, produces
preprocessed_crops keyed by (frame_sequence, box_index).
"""
from __future__ import annotations
import logging
import numpy as np
from core.detect import emit
from core.detect.models import BoundingBox, Frame
logger = logging.getLogger(__name__)
def _crop_region(frame: Frame, box: BoundingBox) -> np.ndarray:
h, w = frame.image.shape[:2]
x1 = max(0, box.x)
y1 = max(0, box.y)
x2 = min(w, box.x + box.w)
y2 = min(h, box.y + box.h)
return frame.image[y1:y2, x1:x2]
def preprocess_regions(
frames: list[Frame],
boxes_by_frame: dict[int, list[BoundingBox]],
do_contrast: bool = True,
do_deskew: bool = False,
do_binarize: bool = False,
inference_url: str | None = None,
job_id: str | None = None,
) -> dict[str, np.ndarray]:
"""
Preprocess cropped regions from YOLO detections.
Returns dict keyed by "{frame_seq}_{box_idx}" → preprocessed crop.
These are passed to the OCR stage instead of raw crops.
"""
total_regions = sum(len(boxes) for boxes in boxes_by_frame.values())
any_active = do_contrast or do_deskew or do_binarize
if not any_active:
emit.log(job_id, "Preprocess", "INFO",
f"Preprocessing disabled, passing {total_regions} regions through")
return {}
mode = "remote" if inference_url else "local"
emit.log(job_id, "Preprocess", "INFO",
f"Preprocessing {total_regions} regions (mode={mode}, "
f"contrast={do_contrast}, deskew={do_deskew}, binarize={do_binarize})")
frame_map = {f.sequence: f for f in frames}
preprocessed: dict[str, np.ndarray] = {}
processed_count = 0
for seq, boxes in boxes_by_frame.items():
frame = frame_map.get(seq)
if not frame:
continue
for idx, box in enumerate(boxes):
crop = _crop_region(frame, box)
if crop.size == 0:
continue
key = f"{seq}_{idx}"
if inference_url:
result = _preprocess_remote(crop, inference_url,
do_contrast, do_deskew, do_binarize)
else:
result = _preprocess_local(crop, do_contrast, do_deskew, do_binarize)
preprocessed[key] = result
processed_count += 1
emit.log(job_id, "Preprocess", "INFO",
f"Preprocessed {processed_count} regions")
return preprocessed
def _preprocess_remote(crop: np.ndarray, inference_url: str,
do_contrast: bool, do_deskew: bool, do_binarize: bool) -> np.ndarray:
"""Call GPU server /preprocess endpoint."""
import base64
import io
import requests
from PIL import Image
img = Image.fromarray(crop)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=85)
image_b64 = base64.b64encode(buf.getvalue()).decode()
resp = requests.post(
f"{inference_url.rstrip('/')}/preprocess",
json={
"image": image_b64,
"contrast": do_contrast,
"deskew": do_deskew,
"binarize": do_binarize,
},
timeout=30,
)
resp.raise_for_status()
data = resp.json()
result_bytes = base64.b64decode(data["image"])
result_img = Image.open(io.BytesIO(result_bytes)).convert("RGB")
return np.array(result_img)
def _preprocess_local(crop: np.ndarray,
do_contrast: bool, do_deskew: bool, do_binarize: bool) -> np.ndarray:
"""Run preprocessing in-process (requires opencv-python-headless)."""
from core.gpu.models.preprocess import preprocess
return preprocess(crop, do_binarize=do_binarize, do_deskew=do_deskew, do_contrast=do_contrast)

View File

@@ -0,0 +1,31 @@
"""
Stage registry — registers all built-in stages.
Split by category:
preprocessing.py — extract_frames, filter_scenes
cv_analysis.py — detect_edges (+ future: detect_contours, detect_color, merge_regions)
detection.py — detect_objects, run_ocr
resolution.py — match_brands
escalation.py — escalate_vlm, escalate_cloud
output.py — compile_report
_serializers.py — shared serialization helpers
"""
from . import preprocessing
from . import cv_analysis
from . import detection
from . import resolution
from . import escalation
from . import output
def register_all():
preprocessing.register()
cv_analysis.register()
detection.register()
resolution.register()
escalation.register()
output.register()
register_all()

View File

@@ -0,0 +1,24 @@
"""
Re-export serializers from core/schema/serializers/.
Stage registry modules import from here for convenience.
All serialization logic lives in core/schema/serializers/.
"""
from core.schema.serializers._common import (
safe_construct,
serialize_dataclass,
serialize_dataclass_list,
)
from core.schema.serializers.pipeline import (
serialize_frame_meta,
serialize_frames_meta,
serialize_text_candidate,
serialize_text_candidates,
deserialize_text_candidate,
deserialize_text_candidates,
deserialize_bounding_box,
deserialize_brand_detection,
deserialize_pipeline_stats,
deserialize_detection_report,
)

View File

@@ -0,0 +1,83 @@
"""Registration for CV analysis stages: edge detection, field segmentation."""
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import serialize_dataclass_list, deserialize_bounding_box
def _ser_regions(state: dict, job_id: str) -> dict:
regions = state.get("edge_regions_by_frame", {})
serialized = {
str(seq): serialize_dataclass_list(bl) for seq, bl in regions.items()
}
return {"edge_regions_by_frame": serialized}
def _deser_regions(data: dict, job_id: str) -> dict:
regions = {}
for seq_str, box_dicts in data.get("edge_regions_by_frame", {}).items():
regions[int(seq_str)] = [deserialize_bounding_box(b) for b in box_dicts]
return {"edge_regions_by_frame": regions}
def _ser_field_seg(state: dict, job_id: str) -> dict:
"""Serialize field segmentation — boundaries + coverage + mask overlays."""
boundaries = state.get("field_boundaries", {})
coverage = state.get("field_coverage", {})
mask_overlays = state.get("field_mask_overlays", {})
return {
"field_boundaries": {str(k): v for k, v in boundaries.items()},
"field_coverage": {str(k): v for k, v in coverage.items()},
"mask_overlays_by_frame": {str(k): v for k, v in mask_overlays.items()},
}
def _deser_field_seg(data: dict, job_id: str) -> dict:
boundaries = {int(k): v for k, v in data.get("field_boundaries", {}).items()}
coverage = {int(k): v for k, v in data.get("field_coverage", {}).items()}
return {"field_boundaries": boundaries, "field_coverage": coverage}
def register():
edge_detection = StageDefinition(
name="detect_edges",
label="Edge Detection",
description="Canny + HoughLinesP — find horizontal line pairs (hoarding boundaries)",
category="cv_analysis",
io=StageIO(
reads=["filtered_frames"],
writes=["edge_regions_by_frame"],
),
config_fields=[
StageConfigField(name="edge_canny_low", type="int", default=50, description="Canny low threshold", min=0, max=255),
StageConfigField(name="edge_canny_high", type="int", default=150, description="Canny high threshold", min=0, max=255),
StageConfigField(name="edge_hough_threshold", type="int", default=80, description="Hough accumulator threshold", min=1, max=500),
StageConfigField(name="edge_hough_min_length", type="int", default=100, description="Min line length (px)", min=10, max=2000),
StageConfigField(name="edge_hough_max_gap", type="int", default=10, description="Max line gap (px)", min=1, max=100),
StageConfigField(name="edge_pair_max_distance", type="int", default=200, description="Max distance between line pair (px)", min=10, max=500),
StageConfigField(name="edge_pair_min_distance", type="int", default=15, description="Min distance between line pair (px)", min=5, max=200),
],
)
register_stage(edge_detection, serialize_fn=_ser_regions, deserialize_fn=_deser_regions)
field_seg = StageDefinition(
name="field_segmentation",
label="Field Segmentation",
description="HSV green mask — detect pitch boundaries",
category="cv_analysis",
io=StageIO(
reads=["filtered_frames"],
writes=["field_mask"],
),
config_fields=[
StageConfigField(name="hue_low", type="int", default=30, description="HSV hue lower bound", min=0, max=180),
StageConfigField(name="hue_high", type="int", default=85, description="HSV hue upper bound", min=0, max=180),
StageConfigField(name="sat_low", type="int", default=30, description="HSV saturation lower bound", min=0, max=255),
StageConfigField(name="sat_high", type="int", default=255, description="HSV saturation upper bound", min=0, max=255),
StageConfigField(name="val_low", type="int", default=30, description="HSV value lower bound", min=0, max=255),
StageConfigField(name="val_high", type="int", default=255, description="HSV value upper bound", min=0, max=255),
StageConfigField(name="morph_kernel", type="int", default=15, description="Morphology kernel size", min=3, max=51),
StageConfigField(name="min_area_ratio", type="float", default=0.05, description="Min contour area", min=0.01, max=0.5),
],
)
register_stage(field_seg, serialize_fn=_ser_field_seg, deserialize_fn=_deser_field_seg)

View File

@@ -0,0 +1,60 @@
"""Registration for detection stages: YOLO, OCR."""
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import (
serialize_dataclass_list,
serialize_text_candidates,
deserialize_bounding_box,
)
def _ser_detect(state: dict, job_id: str) -> dict:
boxes = state.get("boxes_by_frame", {})
serialized = {str(seq): serialize_dataclass_list(bl) for seq, bl in boxes.items()}
return {"boxes_by_frame": serialized}
def _deser_detect(data: dict, job_id: str) -> dict:
boxes = {}
for seq_str, box_dicts in data.get("boxes_by_frame", {}).items():
boxes[int(seq_str)] = [deserialize_bounding_box(b) for b in box_dicts]
return {"boxes_by_frame": boxes}
def _ser_ocr(state: dict, job_id: str) -> dict:
candidates = state.get("text_candidates", [])
return {"text_candidates": serialize_text_candidates(candidates)}
def _deser_ocr(data: dict, job_id: str) -> dict:
return {"_text_candidates_raw": data["text_candidates"]}
def register():
yolo = StageDefinition(
name="detect_objects",
label="Object Detection",
description="YOLO object detection on filtered frames",
category="detection",
io=StageIO(reads=["filtered_frames"], writes=["boxes_by_frame"]),
config_fields=[
StageConfigField(name="model_name", type="str", default="yolov8n.pt", description="YOLO model file"),
StageConfigField(name="confidence_threshold", type="float", default=0.3, description="Min detection confidence", min=0.0, max=1.0),
StageConfigField(name="target_classes", type="list[str]", default=[], description="YOLO classes to detect (empty = all)"),
],
)
register_stage(yolo, serialize_fn=_ser_detect, deserialize_fn=_deser_detect)
ocr = StageDefinition(
name="run_ocr",
label="OCR",
description="Extract text from detected regions",
category="detection",
io=StageIO(reads=["filtered_frames", "boxes_by_frame"], writes=["text_candidates"]),
config_fields=[
StageConfigField(name="languages", type="list[str]", default=["en"], description="OCR languages"),
StageConfigField(name="min_confidence", type="float", default=0.5, description="Min OCR confidence", min=0.0, max=1.0),
],
)
register_stage(ocr, serialize_fn=_ser_ocr, deserialize_fn=_deser_ocr)

View File

@@ -0,0 +1,60 @@
"""Registration for escalation stages: local VLM, cloud LLM."""
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import (
serialize_dataclass_list,
serialize_text_candidates,
deserialize_brand_detection,
)
def _ser_escalation(state: dict, job_id: str) -> dict:
matched = state.get("detections", [])
unresolved = state.get("unresolved_candidates", [])
return {
"detections": serialize_dataclass_list(matched),
"unresolved_candidates": serialize_text_candidates(unresolved),
}
def _deser_escalation(data: dict, job_id: str) -> dict:
detections = [deserialize_brand_detection(d) for d in data.get("detections", [])]
return {
"detections": detections,
"_unresolved_raw": data.get("unresolved_candidates", []),
}
def register():
vlm = StageDefinition(
name="escalate_vlm",
label="Local VLM",
description="Process unresolved crops with moondream2",
category="escalation",
io=StageIO(
reads=["unresolved_candidates"],
writes=["detections", "unresolved_candidates"],
optional_reads=["source_asset_id"],
),
config_fields=[
StageConfigField(name="min_confidence", type="float", default=0.5, description="Min VLM confidence", min=0.0, max=1.0),
],
)
register_stage(vlm, serialize_fn=_ser_escalation, deserialize_fn=_deser_escalation)
cloud = StageDefinition(
name="escalate_cloud",
label="Cloud LLM",
description="Escalate remaining crops to cloud provider",
category="escalation",
io=StageIO(
reads=["unresolved_candidates"],
writes=["detections"],
optional_reads=["source_asset_id"],
),
config_fields=[
StageConfigField(name="min_confidence", type="float", default=0.4, description="Min cloud confidence", min=0.0, max=1.0),
],
)
register_stage(cloud, serialize_fn=_ser_escalation, deserialize_fn=_deser_escalation)

View File

@@ -0,0 +1,30 @@
"""Registration for output stages: report compilation."""
from core.detect.stages.base import StageDefinition, StageIO, register_stage
from ._serializers import serialize_dataclass, deserialize_detection_report
def _ser_report(state: dict, job_id: str) -> dict:
report = state.get("report")
if report is None:
return {"report": None}
return {"report": serialize_dataclass(report)}
def _deser_report(data: dict, job_id: str) -> dict:
raw = data.get("report")
if raw is None:
return {"report": None}
return {"report": deserialize_detection_report(raw)}
def register():
report = StageDefinition(
name="compile_report",
label="Report",
description="Merge detections and compile final report",
category="output",
io=StageIO(reads=["detections"], writes=["report"]),
config_fields=[],
)
register_stage(report, serialize_fn=_ser_report, deserialize_fn=_deser_report)

View File

@@ -0,0 +1,82 @@
"""Registration for preprocessing stages: frame extraction, scene filter, image preprocessing."""
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import serialize_frame_meta
def _ser_extract(state: dict, job_id: str) -> dict:
frames = state.get("frames", [])
meta = [serialize_frame_meta(f) for f in frames]
return {"frames_meta": meta, "frame_count": len(frames)}
def _deser_extract(data: dict, job_id: str) -> dict:
# Frames are ephemeral — re-extract from chunks on demand.
# Store metadata so we know what was extracted.
return {"_frames_meta": data.get("frames_meta", [])}
def _ser_filter(state: dict, job_id: str) -> dict:
filtered = state.get("filtered_frames", [])
seqs = [f.sequence for f in filtered]
return {"filtered_frame_sequences": seqs}
def _deser_filter(data: dict, job_id: str) -> dict:
return {"_filtered_sequences": data["filtered_frame_sequences"]}
def _ser_preprocess(state: dict, job_id: str) -> dict:
# Preprocessed crops are numpy arrays — regenerable from frames + boxes + config
crops = state.get("preprocessed_crops", {})
return {"crop_keys": list(crops.keys()), "count": len(crops)}
def _deser_preprocess(data: dict, job_id: str) -> dict:
# Crops are regenerable — no need to restore from checkpoint
return {"preprocessed_crops": {}}
def register():
extract = StageDefinition(
name="extract_frames",
label="Frame Extraction",
description="Extract frames from video at configurable FPS",
category="preprocessing",
io=StageIO(reads=["video_path"], writes=["frames"]),
config_fields=[
StageConfigField(name="fps", type="float", default=2.0, description="Frames per second", min=0.1, max=30.0),
StageConfigField(name="max_frames", type="int", default=500, description="Maximum frames to extract", min=1, max=10000),
],
)
register_stage(extract, serialize_fn=_ser_extract, deserialize_fn=_deser_extract)
scene_filter = StageDefinition(
name="filter_scenes",
label="Scene Filter",
description="Deduplicate similar frames using perceptual hashing",
category="preprocessing",
io=StageIO(reads=["frames"], writes=["filtered_frames"]),
config_fields=[
StageConfigField(name="hamming_threshold", type="int", default=8, description="Hamming distance threshold", min=0, max=64),
],
)
register_stage(scene_filter, serialize_fn=_ser_filter, deserialize_fn=_deser_filter)
preprocess = StageDefinition(
name="preprocess",
label="Preprocess",
description="Image preprocessing on detected regions before OCR",
category="preprocessing",
io=StageIO(
reads=["filtered_frames", "boxes_by_frame"],
writes=["preprocessed_crops"],
),
config_fields=[
StageConfigField(name="contrast", type="bool", default=True, description="CLAHE contrast enhancement"),
StageConfigField(name="deskew", type="bool", default=False, description="Correct slight rotation"),
StageConfigField(name="binarize", type="bool", default=False, description="Otsu binarization"),
],
)
register_stage(preprocess, serialize_fn=_ser_preprocess, deserialize_fn=_deser_preprocess)

View File

@@ -0,0 +1,44 @@
"""Registration for resolution stages: brand resolver."""
from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import (
serialize_dataclass_list,
serialize_text_candidates,
deserialize_brand_detection,
)
def _ser_brands(state: dict, job_id: str) -> dict:
matched = state.get("detections", [])
unresolved = state.get("unresolved_candidates", [])
return {
"detections": serialize_dataclass_list(matched),
"unresolved_candidates": serialize_text_candidates(unresolved),
}
def _deser_brands(data: dict, job_id: str) -> dict:
detections = [deserialize_brand_detection(d) for d in data.get("detections", [])]
return {
"detections": detections,
"_unresolved_raw": data.get("unresolved_candidates", []),
}
def register():
resolver = StageDefinition(
name="match_brands",
label="Brand Resolver",
description="Match OCR text against known brands (session + global DB)",
category="resolution",
io=StageIO(
reads=["text_candidates"],
writes=["detections", "unresolved_candidates"],
optional_reads=["session_brands", "source_asset_id"],
),
config_fields=[
StageConfigField(name="fuzzy_threshold", type="int", default=75, description="Fuzzy match threshold", min=0, max=100),
],
)
register_stage(resolver, serialize_fn=_ser_brands, deserialize_fn=_deser_brands)

View File

@@ -0,0 +1,86 @@
"""
Stage 2 — Scene Filter
Removes near-duplicate frames using perceptual hashing (pHash).
Frames with a hamming distance below the threshold are considered
duplicates and dropped. This dramatically reduces work for downstream
CV stages without losing unique visual content.
"""
from __future__ import annotations
import time
import imagehash
from PIL import Image
from core.detect import emit
from core.detect.models import Frame
from core.detect.stages.models import SceneFilterConfig
def _compute_hashes(frames: list[Frame]) -> list[imagehash.ImageHash]:
"""Compute perceptual hashes for all frames."""
hashes = []
for f in frames:
img = Image.fromarray(f.image)
h = imagehash.phash(img)
f.perceptual_hash = str(h)
hashes.append(h)
return hashes
def _dedup(frames: list[Frame], hashes: list[imagehash.ImageHash], threshold: int) -> list[Frame]:
"""Greedy dedup: keep a frame if it's sufficiently different from all kept frames."""
kept = [frames[0]]
kept_hashes = [hashes[0]]
for i in range(1, len(frames)):
is_duplicate = any(hashes[i] - kh < threshold for kh in kept_hashes)
if not is_duplicate:
kept.append(frames[i])
kept_hashes.append(hashes[i])
return kept
def scene_filter(
frames: list[Frame],
config: SceneFilterConfig,
job_id: str | None = None,
) -> list[Frame]:
"""
Filter near-duplicate frames based on perceptual hash distance.
Keeps the first frame in each group of similar frames.
Returns a new list — does not mutate the input.
"""
if not config.enabled:
emit.log(job_id, "SceneFilter", "INFO", "Scene filter disabled, passing all frames through")
return frames
if not frames:
return []
emit.log(job_id, "SceneFilter", "INFO",
f"Filtering {len(frames)} frames (hamming_threshold={config.hamming_threshold})")
t0 = time.monotonic()
hashes = _compute_hashes(frames)
hash_ms = (time.monotonic() - t0) * 1000
emit.log(job_id, "SceneFilter", "DEBUG",
f"Computed {len(hashes)} perceptual hashes in {hash_ms:.0f}ms ({hash_ms/max(len(hashes),1):.1f}ms/frame)")
t0 = time.monotonic()
kept = _dedup(frames, hashes, config.hamming_threshold)
dedup_ms = (time.monotonic() - t0) * 1000
emit.log(job_id, "SceneFilter", "DEBUG", f"Dedup pass: {dedup_ms:.0f}ms")
dropped = len(frames) - len(kept)
pct = (dropped / len(frames) * 100) if frames else 0
emit.log(job_id, "SceneFilter", "INFO",
f"Kept {len(kept)} frames, dropped {dropped} ({pct:.0f}% reduction)")
emit.stats(job_id, frames_extracted=len(frames), frames_after_scene_filter=len(kept))
return kept

View File

@@ -0,0 +1,201 @@
"""
Stage 7 — Cloud LLM escalation
Last resort for crops the local VLM couldn't resolve.
Provider-agnostic — switch via CLOUD_LLM_PROVIDER env var.
Each provider has its own file under detect/providers/.
Tracks token usage and cost.
"""
from __future__ import annotations
import base64
import io
import logging
import os
import time
import numpy as np
from PIL import Image
from core.detect import emit
from core.detect.models import BrandDetection, PipelineStats, TextCandidate
from core.detect.models import CropContext
from core.detect.providers import get_provider, has_api_key
logger = logging.getLogger(__name__)
ESTIMATED_TOKENS_PER_CROP = 500
def _register_discovered_brand(brand: str, source_asset_id: str | None,
timestamp: float, confidence: float):
"""Register a cloud-confirmed brand in the DB."""
try:
from core.detect.stages.brand_resolver import _register_brand, _record_sighting
brand_id = _register_brand(brand, "cloud_llm")
if brand_id and source_asset_id:
_record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, "cloud_llm")
except Exception as e:
logger.debug("Failed to register brand %s: %s", brand, e)
def _encode_crop(crop: np.ndarray) -> str:
img = Image.fromarray(crop)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=85)
return base64.b64encode(buf.getvalue()).decode()
def _crop_image(candidate: TextCandidate) -> np.ndarray:
frame = candidate.frame
box = candidate.bbox
h, w = frame.image.shape[:2]
x1 = max(0, box.x)
y1 = max(0, box.y)
x2 = min(w, box.x + box.w)
y2 = min(h, box.y + box.h)
return frame.image[y1:y2, x1:x2]
def _parse_response(answer: str, total_tokens: int) -> dict:
"""Parse LLM free-text response into structured output."""
parts = [p.strip() for p in answer.split(",", 2)]
brand = parts[0] if parts else ""
confidence = 0.5
reasoning = answer
if len(parts) >= 2:
try:
confidence = float(parts[1])
confidence = max(0.0, min(1.0, confidence))
except ValueError:
pass
if len(parts) >= 3:
reasoning = parts[2]
return {
"brand": brand,
"confidence": confidence,
"reasoning": reasoning,
"tokens": total_tokens or ESTIMATED_TOKENS_PER_CROP,
}
def _call_cloud_api(image_b64: str, prompt: str) -> dict:
"""Route to the configured provider and parse the response."""
provider = get_provider()
result = provider.call(image_b64, prompt)
return _parse_response(result.answer, result.total_tokens)
def escalate_cloud(
candidates: list[TextCandidate],
vlm_prompt_fn,
stats: PipelineStats,
min_confidence: float = 0.4,
content_type: str = "",
source_asset_id: str | None = None,
job_id: str | None = None,
) -> list[BrandDetection]:
"""
Send remaining unresolved crops to cloud LLM.
Provider is selected via CLOUD_LLM_PROVIDER env var (groq, gemini, openai).
Updates stats with call count and cost.
"""
if not candidates:
return []
if os.environ.get("SKIP_CLOUD", "").strip() == "1":
emit.log(job_id, "CloudLLM", "INFO",
f"SKIP_CLOUD=1, skipping {len(candidates)} crops")
return []
if not has_api_key():
emit.log(job_id, "CloudLLM", "WARNING",
f"No API key set for cloud provider, skipping {len(candidates)} crops")
return []
provider = get_provider()
emit.log(job_id, "CloudLLM", "INFO",
f"Escalating {len(candidates)} crops to {provider.name}")
matched: list[BrandDetection] = []
total_cost = 0.0
for i, candidate in enumerate(candidates):
crop = _crop_image(candidate)
if crop.size == 0:
continue
crop_context = CropContext(
image=b"",
surrounding_text=candidate.text,
position_hint=f"frame {candidate.frame.sequence}",
)
prompt = vlm_prompt_fn(crop_context)
image_b64 = _encode_crop(crop)
t0 = time.monotonic()
try:
result = _call_cloud_api(image_b64, prompt)
except Exception as e:
call_ms = (time.monotonic() - t0) * 1000
emit.log(job_id, "CloudLLM", "DEBUG",
f"[{i+1}/{len(candidates)}] FAILED '{candidate.text[:30]}': {e} ({call_ms:.0f}ms)")
continue
call_ms = (time.monotonic() - t0) * 1000
stats.cloud_llm_calls += 1
model_info = provider.models.get(provider.model)
cost_per_token = model_info.cost_per_input_token if model_info else 0.00001
call_cost = result["tokens"] * cost_per_token
total_cost += call_cost
brand = result["brand"]
confidence = result["confidence"]
emit.log(job_id, "CloudLLM", "DEBUG",
f"[{i+1}/{len(candidates)}] '{candidate.text[:30]}'"
f"{'' + brand if brand else ''} "
f"(conf={confidence:.2f}, {result['tokens']}tok, ${call_cost:.4f}, {call_ms:.0f}ms)")
if brand and confidence >= min_confidence:
detection = BrandDetection(
brand=brand,
timestamp=candidate.frame.timestamp,
duration=0.5,
confidence=confidence,
source="cloud_llm",
bbox=candidate.bbox,
frame_ref=candidate.frame.sequence,
content_type=content_type,
)
matched.append(detection)
emit.detection(
job_id,
brand=brand,
confidence=confidence,
source="cloud_llm",
timestamp=candidate.frame.timestamp,
content_type=content_type,
frame_ref=candidate.frame.sequence,
)
# Register newly discovered brand in DB
_register_discovered_brand(brand, source_asset_id,
candidate.frame.timestamp, confidence)
stats.estimated_cloud_cost_usd += total_cost
stats.regions_escalated_to_cloud_llm = len(candidates)
emit.log(job_id, "CloudLLM", "INFO",
f"Cloud resolved {len(matched)}/{len(candidates)}"
f"cost ${total_cost:.4f} ({stats.cloud_llm_calls} calls total)")
return matched

View File

@@ -0,0 +1,157 @@
"""
Stage 6 — Local VLM escalation (moondream2)
Processes unresolved text candidates by sending crop images + prompt
to the local VLM on the inference server. Produces BrandDetection
objects for crops the VLM can identify.
"""
from __future__ import annotations
import logging
import os
import time
import numpy as np
from core.detect import emit
from core.detect.models import BrandDetection, TextCandidate
from core.detect.models import CropContext
logger = logging.getLogger(__name__)
def _register_discovered_brand(brand: str, source_asset_id: str | None,
timestamp: float, confidence: float, source: str):
"""Register a VLM-confirmed brand in the DB."""
try:
from core.detect.stages.brand_resolver import _register_brand, _record_sighting
brand_id = _register_brand(brand, source)
if brand_id and source_asset_id:
_record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, source)
except Exception as e:
logger.debug("Failed to register brand %s: %s", brand, e)
def _crop_image(candidate: TextCandidate) -> np.ndarray:
frame = candidate.frame
box = candidate.bbox
h, w = frame.image.shape[:2]
x1 = max(0, box.x)
y1 = max(0, box.y)
x2 = min(w, box.x + box.w)
y2 = min(h, box.y + box.h)
return frame.image[y1:y2, x1:x2]
def escalate_vlm(
candidates: list[TextCandidate],
vlm_prompt_fn,
inference_url: str | None = None,
min_confidence: float = 0.5,
content_type: str = "",
source_asset_id: str | None = None,
job_id: str | None = None,
) -> tuple[list[BrandDetection], list[TextCandidate]]:
"""
Send unresolved crops to local VLM for brand identification.
Returns:
- matched: BrandDetections the VLM confirmed
- still_unresolved: candidates the VLM couldn't resolve (→ cloud escalation)
"""
if not candidates:
return [], []
if os.environ.get("SKIP_VLM", "").strip() == "1":
emit.log(job_id, "VLMLocal", "INFO",
f"SKIP_VLM=1, skipping {len(candidates)} crops")
return [], candidates
emit.log(job_id, "VLMLocal", "INFO",
f"Processing {len(candidates)} unresolved crops with moondream2")
matched: list[BrandDetection] = []
still_unresolved: list[TextCandidate] = []
if inference_url:
from core.detect.inference import InferenceClient
from core.detect.emit import _run_log_level
client = InferenceClient(base_url=inference_url, job_id=job_id or "", log_level=_run_log_level)
for i, candidate in enumerate(candidates):
crop = _crop_image(candidate)
if crop.size == 0:
still_unresolved.append(candidate)
continue
crop_context = CropContext(
image=b"", # not used for prompt generation
surrounding_text=candidate.text,
position_hint=f"frame {candidate.frame.sequence}",
)
prompt = vlm_prompt_fn(crop_context)
t0 = time.monotonic()
try:
if inference_url:
result = client.vlm(image=crop, prompt=prompt)
brand = result.brand
confidence = result.confidence
reasoning = result.reasoning
else:
brand, confidence, reasoning = _vlm_local(crop, prompt)
except Exception as e:
vlm_ms = (time.monotonic() - t0) * 1000
emit.log(job_id, "VLMLocal", "DEBUG",
f"[{i+1}/{len(candidates)}] FAILED '{candidate.text[:30]}': {e} ({vlm_ms:.0f}ms)")
still_unresolved.append(candidate)
continue
vlm_ms = (time.monotonic() - t0) * 1000
emit.log(job_id, "VLMLocal", "DEBUG",
f"[{i+1}/{len(candidates)}] '{candidate.text[:30]}'"
f"{'' + brand if brand else '✗ unresolved'} "
f"(conf={confidence:.2f}, {vlm_ms:.0f}ms)")
if brand and confidence >= min_confidence:
detection = BrandDetection(
brand=brand,
timestamp=candidate.frame.timestamp,
duration=0.5,
confidence=confidence,
source="local_vlm",
bbox=candidate.bbox,
frame_ref=candidate.frame.sequence,
content_type=content_type,
)
matched.append(detection)
emit.detection(
job_id,
brand=brand,
confidence=confidence,
source="local_vlm",
timestamp=candidate.frame.timestamp,
content_type=content_type,
frame_ref=candidate.frame.sequence,
)
# Register newly discovered brand in DB
_register_discovered_brand(brand, source_asset_id,
candidate.frame.timestamp, confidence, "local_vlm")
logger.debug("VLM matched: %s (%.2f) — %s", brand, confidence, reasoning)
else:
still_unresolved.append(candidate)
emit.log(job_id, "VLMLocal", "INFO",
f"VLM resolved {len(matched)}, unresolved {len(still_unresolved)} → cloud")
return matched, still_unresolved
def _vlm_local(crop: np.ndarray, prompt: str) -> tuple[str, float, str]:
"""Run moondream2 in-process (single-box mode)."""
from core.gpu.models.vlm import query
result = query(crop, prompt)
return result["brand"], result["confidence"], result["reasoning"]

Some files were not shown because too many files have changed in this diff Show More