Compare commits: 5ceb8172ea...main (48 commits)

| SHA1 |
|---|
| 0728fc6be3 |
| 41dd488fe6 |
| ddb4f17faa |
| f66d3a273f |
| d9e0794b83 |
| ed1f6b761b |
| 020f3540d3 |
| 55e83e4203 |
| aac27b8504 |
| 4220b0418e |
| d0707333fd |
| e46bbc419c |
| 0bd3888155 |
| acc99e691d |
| 8a90436f33 |
| f6ef95ebea |
| 49da927da0 |
| 886720c3ce |
| 1c6af767eb |
| 94c7b21ae5 |
| 3d8e7291f3 |
| bf30acd4df |
| a3b51c458d |
| 51ce14a812 |
| bcf6f3dc71 |
| 291ac8dd40 |
| df6bcb01e8 |
| 65814b5b9e |
| beb0416280 |
| a85722f96a |
| c9ba9e4f5f |
| e27cb5bcc3 |
| 731964ca10 |
| d58a90157a |
| 08c58a6a9d |
| 08b67f2bb7 |
| dfa3c12514 |
| 95246c5452 |
| 3df9ed5ada |
| 4fdbdfc6d3 |
| b57da622cb |
| 5ed876d694 |
| 71fd0510de |
| 8186bb5fe6 |
| 9c9c7dff09 |
| b40bd68411 |
| d5a3372d6b |
| ccc478fbaa |
@@ -0,0 +1,11 @@
---
name: agent_sdk_future
description: Claude Agent SDK for general mpr tasks (not vision provider), uses OAuth not API keys
type: project
---

Claude Agent SDK (`claude-agent-sdk`) is for future general-purpose tasks in mpr, NOT for the cloud vision provider.

**Why:** The Agent SDK uses Claude Code CLI's OAuth (browser login, no API keys) and is designed for agentic tasks (file read/edit, bash, web search). The vision provider needs raw API calls with image payloads — use the `anthropic` SDK with `ANTHROPIC_API_KEY` for that.

**How to apply:** When adding Claude-powered automation to mpr (e.g., log analysis, config suggestions, code review on pipeline changes), use the Agent SDK. For the cloud LLM escalation stage (image crops → brand ID), keep using the `anthropic` SDK with API key auth.
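To make the split concrete, here is a minimal sketch of the raw-API path the note reserves for the vision provider: the `anthropic` SDK with an image payload and `ANTHROPIC_API_KEY` auth. The model id and crop path are illustrative placeholders, not values from this repo.

```python
# Sketch only: cloud vision escalation via the `anthropic` SDK (API key auth).
# Assumes ANTHROPIC_API_KEY is set; model id and crop path are placeholders.
import base64

import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment

with open("crop.png", "rb") as f:  # a logo crop from the pipeline (hypothetical path)
    img_b64 = base64.standard_b64encode(f.read()).decode()

msg = client.messages.create(
    model="claude-sonnet-4-20250514",  # placeholder model id
    max_tokens=256,
    messages=[{
        "role": "user",
        "content": [
            {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": img_b64}},
            {"type": "text", "text": "Which brand logo appears in this crop?"},
        ],
    }],
)
print(msg.content[0].text)
```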
.gitignore (vendored, 7 lines changed)
@@ -17,10 +17,8 @@ env/
*.pot
*.pyc
db.sqlite3
media/in/*
!media/in/.gitkeep
media/out/*
!media/out/.gitkeep
media/*
!media/.gitkeep

# Node
node_modules/
@@ -39,3 +37,4 @@ Thumbs.db

# Project specific
def/
ctrl/k8s/overlays/dev/local-config.yaml
README.md (161 lines changed)
@@ -1,161 +1,10 @@
# MPR - Media Processor
# MPR

A web-based media transcoding tool with Django admin, FastAPI backend, and React timeline UI.
Brand and logo detection pipeline for video — extracts frames, segments the field, runs YOLO + OCR, and escalates unresolved detections to local or cloud VLMs.

## Architecture

```
Browser (mpr.local.ar)
        │
     nginx:80
        │
   ┌────┴────┐
   │         │
/admin    /api, /ui
   │         │
Django    FastAPI ◄── Timeline UI
   │         │
   │    ┌────┘
   │    │
   └───►│  (job operations)
        │
   gRPC Server
        │
  Celery Worker
```

- **Django** (`/admin`): Admin interface for data management
- **FastAPI** (`/api`): REST API and gRPC client
- **Timeline UI** (`/ui`): React app for video editing
- **gRPC Server**: Worker communication with progress streaming
- **Celery**: Job execution via FFmpeg

## Prerequisites

- Docker & Docker Compose

## Quick Start
## Docs

```bash
# Add to /etc/hosts
echo "127.0.0.1 mpr.local.ar" | sudo tee -a /etc/hosts

# Start all services
cd ctrl
cp .env.template .env
docker compose up -d
python -m http.server 8000 --directory docs
# open http://localhost:8000
```

## Access Points

| URL | Description |
|-----|-------------|
| http://mpr.local.ar/admin | Django Admin |
| http://mpr.local.ar/api/docs | FastAPI Swagger |
| http://mpr.local.ar/ui | Timeline UI |

## Commands

```bash
cd ctrl

# Start/stop
docker compose up -d
docker compose down

# Rebuild after code changes
docker compose up -d --build

# View logs
docker compose logs -f
docker compose logs -f celery

# Create admin user
docker compose exec django python admin/manage.py createsuperuser
```

## Code Generation

Models are defined as dataclasses in `core/schema/models/` and generated via `modelgen`:
- **Django ORM** models (`--include dataclasses,enums`)
- **Pydantic** schemas (`--include dataclasses,enums`)
- **TypeScript** types (`--include dataclasses,enums,api`)
- **Protobuf** definitions (`--include grpc`)

Each target only gets the model groups it needs via the `--include` flag.

```bash
# Regenerate all targets
bash ctrl/generate.sh
```

## Media Storage

MPR separates media into **input** (`MEDIA_IN`) and **output** (`MEDIA_OUT`) paths, each independently configurable. File paths are stored relative for cloud portability.

### Local Development
- Source files: `/app/media/in/video.mp4`
- Output files: `/app/media/out/video_h264.mp4`
- Served via: `http://mpr.local.ar/media/in/video.mp4` (nginx alias)

### AWS/Cloud Deployment
Input and output can be different buckets/locations:
```bash
MEDIA_IN=s3://source-bucket/media/
MEDIA_OUT=s3://output-bucket/transcoded/
```

**Scan Endpoint**: `POST /api/assets/scan` recursively scans `MEDIA_IN` and registers new files with relative paths.

See [docs/media-storage.md](docs/media-storage.md) for full details.

## Project Structure

```
mpr/
├── admin/                  # Django project
│   ├── manage.py           # Django management script
│   └── mpr/                # Django settings & app
│       └── media_assets/   # Django app
├── core/                   # Core application logic
│   ├── api/                # FastAPI + GraphQL API
│   │   └── schema/         # GraphQL types (generated)
│   ├── ffmpeg/             # FFmpeg wrappers
│   ├── rpc/                # gRPC server & client
│   │   └── protos/         # Protobuf definitions (generated)
│   ├── schema/             # Source of truth
│   │   └── models/         # Dataclass definitions
│   ├── storage/            # S3/GCP/local storage backends
│   └── task/               # Celery job execution
│       ├── executor.py     # Executor abstraction
│       └── tasks.py        # Celery tasks
├── ctrl/                   # Docker & deployment
│   ├── docker-compose.yml
│   └── nginx.conf
├── media/
│   ├── in/                 # Source media files
│   └── out/                # Transcoded output
├── modelgen/               # Code generation tool
└── ui/                     # Frontend
    └── timeline/           # React app
```

## Environment Variables

See `ctrl/.env.template` for all configuration options.

| Variable | Default | Description |
|----------|---------|-------------|
| `DATABASE_URL` | sqlite | PostgreSQL connection string |
| `REDIS_URL` | redis://localhost:6379 | Redis for Celery |
| `GRPC_HOST` | grpc | gRPC server hostname |
| `GRPC_PORT` | 50051 | gRPC server port |
| `MPR_EXECUTOR` | local | Executor type (local/lambda) |
| `MEDIA_IN` | /app/media/in | Source media files directory |
| `MEDIA_OUT` | /app/media/out | Transcoded output directory |
| `MEDIA_BASE_URL` | /media/ | Base URL for serving media (use S3 URL for cloud) |
| `VITE_ALLOWED_HOSTS` | - | Comma-separated allowed hosts for Vite dev server |

## License

MIT
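The scan endpoint mentioned above is the piece most often driven from scripts. A minimal sketch, assuming the local dev URL from this README; the exact response shape is an assumption:

```python
# Sketch: register new files dropped into MEDIA_IN via the scan endpoint.
# Uses the local nginx URL from the README; response fields are assumptions.
import httpx

resp = httpx.post("http://mpr.local.ar/api/assets/scan")
resp.raise_for_status()
print(resp.json())  # e.g. a summary of newly registered assets
```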
@@ -7,4 +7,4 @@ os.environ.setdefault("DJANGO_SETTINGS_MODULE", "admin.mpr.settings")
app = Celery("mpr")
app.config_from_object("django.conf:settings", namespace="CELERY")
app.autodiscover_tasks()
app.autodiscover_tasks(["core.task"])
app.autodiscover_tasks(["core.jobs"])
@@ -14,11 +14,29 @@ class AssetStatus(models.TextChoices):

class JobStatus(models.TextChoices):
    PENDING = "pending", "Pending"
    PROCESSING = "processing", "Processing"
    RUNNING = "running", "Running"
    PAUSED = "paused", "Paused"
    COMPLETED = "completed", "Completed"
    FAILED = "failed", "Failed"
    CANCELLED = "cancelled", "Cancelled"

class RunType(models.TextChoices):
    INITIAL = "initial", "Initial"
    REPLAY = "replay", "Replay"
    RETRY = "retry", "Retry"

class BrandSource(models.TextChoices):
    OCR = "ocr", "Ocr"
    VLM = "local_vlm", "Vlm"
    CLOUD = "cloud_llm", "Cloud"
    MANUAL = "manual", "Manual"

class SourceType(models.TextChoices):
    CHUNK_JOB = "chunk_job", "Chunk Job"
    UPLOAD = "upload", "Upload"
    DEVICE = "device", "Device"
    STREAM = "stream", "Stream"

class MediaAsset(models.Model):
    """A video/audio file registered in the system."""

@@ -77,26 +95,25 @@ class TranscodePreset(models.Model):
        return self.name


class TranscodeJob(models.Model):
    """A transcoding or trimming job in the queue."""
class Job(models.Model):
    """A pipeline job."""

    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    source_asset_id = models.UUIDField()
    preset_id = models.UUIDField(null=True, blank=True)
    preset_snapshot = models.JSONField(default=dict, blank=True)
    trim_start = models.FloatField(null=True, blank=True, default=None)
    trim_end = models.FloatField(null=True, blank=True, default=None)
    output_filename = models.CharField(max_length=500)
    output_path = models.CharField(max_length=1000, null=True, blank=True)
    output_asset_id = models.UUIDField(null=True, blank=True)
    video_path = models.CharField(max_length=1000)
    profile_name = models.CharField(max_length=255)
    timeline_id = models.UUIDField(null=True, blank=True)
    parent_id = models.UUIDField(null=True, blank=True)
    run_type = models.CharField(max_length=20, choices=RunType.choices, default=RunType.INITIAL)
    config_overrides = models.JSONField(default=dict, blank=True)
    status = models.CharField(max_length=20, choices=JobStatus.choices, default=JobStatus.PENDING)
    current_stage = models.CharField(max_length=255, null=True, blank=True)
    progress = models.FloatField(default=0.0)
    current_frame = models.IntegerField(null=True, blank=True, default=None)
    current_time = models.FloatField(null=True, blank=True, default=None)
    speed = models.CharField(max_length=255, null=True, blank=True)
    error_message = models.TextField(blank=True, default='')
    celery_task_id = models.CharField(max_length=255, null=True, blank=True)
    execution_arn = models.CharField(max_length=255, null=True, blank=True)
    total_detections = models.IntegerField(default=0)
    brands_found = models.IntegerField(default=0)
    cloud_llm_calls = models.IntegerField(default=0)
    estimated_cost_usd = models.FloatField(default=0.0)
    priority = models.IntegerField(default=0)
    created_at = models.DateTimeField(auto_now_add=True)
    started_at = models.DateTimeField(null=True, blank=True)
@@ -108,3 +125,98 @@ class TranscodeJob(models.Model):
    def __str__(self):
        return str(self.id)


class Timeline(models.Model):
    """A user-created selection of source material."""

    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    name = models.CharField(max_length=255)
    source_asset_id = models.UUIDField(null=True, blank=True)
    chunk_paths = models.JSONField(default=list, blank=True)
    profile_name = models.CharField(max_length=255)
    status = models.CharField(max_length=255)
    fps = models.FloatField(default=2.0)
    frame_count = models.IntegerField(default=0)
    source_ephemeral = models.BooleanField(default=False)
    created_at = models.DateTimeField(auto_now_add=True)

    class Meta:
        ordering = ["-created_at"]

    def __str__(self):
        return self.name


class Checkpoint(models.Model):
    """A snapshot of pipeline state on a timeline."""

    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    timeline_id = models.UUIDField()
    job_id = models.UUIDField(null=True, blank=True)
    parent_id = models.UUIDField(null=True, blank=True)
    stage_name = models.CharField(max_length=255)
    config_overrides = models.JSONField(default=dict, blank=True)
    stats = models.JSONField(default=dict, blank=True)
    is_scenario = models.BooleanField(default=False)
    scenario_label = models.CharField(max_length=255)
    created_at = models.DateTimeField(auto_now_add=True)

    class Meta:
        ordering = ["-created_at"]

    def __str__(self):
        return str(self.id)


class StageOutput(models.Model):
    """Output of a single stage within a job."""

    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    job_id = models.UUIDField()
    timeline_id = models.UUIDField()
    stage_name = models.CharField(max_length=255)
    checkpoint_id = models.UUIDField(null=True, blank=True)
    output = models.JSONField(default=dict, blank=True)
    created_at = models.DateTimeField(auto_now_add=True)

    class Meta:
        ordering = ["-created_at"]

    def __str__(self):
        return str(self.id)


class Brand(models.Model):
    """A brand discovered or registered in the system."""

    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    canonical_name = models.CharField(max_length=255)
    aliases = models.JSONField(default=list, blank=True)
    source = models.CharField(max_length=20, choices=BrandSource.choices, default=BrandSource.OCR)
    confirmed = models.BooleanField(default=False)
    airings = models.JSONField(default=list, blank=True)
    total_airings = models.IntegerField(default=0)
    created_at = models.DateTimeField(auto_now_add=True)
    updated_at = models.DateTimeField(auto_now=True)

    class Meta:
        ordering = ["-created_at"]

    def __str__(self):
        return str(self.id)


class Profile(models.Model):
    """A content type profile."""

    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    name = models.CharField(max_length=255)
    pipeline = models.JSONField(default=dict, blank=True)
    configs = models.JSONField(default=dict, blank=True)

    class Meta:
        pass

    def __str__(self):
        return self.name
core/api/detect/__init__.py (new file, 22 lines)
@@ -0,0 +1,22 @@
"""
Detection API — aggregated router.

Combines all detect sub-routers into a single include for main.py.
"""

from fastapi import APIRouter

from .sources import router as sources_router
from .run import router as run_router
from .sse import router as sse_router
from .replay import router as replay_router
from .config import router as config_router
from .timeline import router as timeline_router

router = APIRouter()
router.include_router(sources_router)
router.include_router(run_router)
router.include_router(sse_router)
router.include_router(replay_router)
router.include_router(config_router)
router.include_router(timeline_router)
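For orientation, this is roughly how such an aggregated router gets mounted. The `main.py` below is a hedged sketch, since that file is not part of this diff:

```python
# Sketch of the consuming side (not in this diff): main.py mounting the
# aggregated detect router onto the FastAPI app.
from fastapi import FastAPI

from core.api.detect import router as detect_router

app = FastAPI()
app.include_router(detect_router)  # exposes the /detect/... endpoints
```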
core/api/detect/config.py (new file, 203 lines)
@@ -0,0 +1,203 @@
"""
Runtime config endpoint for the detection pipeline.

GET /detect/config — read current config
PUT /detect/config — update config (takes effect on next run)
GET /detect/config/stages — list stage palette with config fields
"""

from __future__ import annotations

import logging

from fastapi import APIRouter
from pydantic import BaseModel

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/detect", tags=["detect"])

# In-memory config — persists until server restart.
# Phase 12+ moves this to DB.
_runtime_config: dict = {}


class ConfigUpdate(BaseModel):
    detection: dict | None = None
    ocr: dict | None = None
    resolver: dict | None = None
    escalation: dict | None = None
    preprocessing: dict | None = None


class StageOutputHintInfo(BaseModel):
    key: str
    type: str
    label: str = ""
    default_opacity: float = 0.5
    src_format: str = "png"


class TransformOptionInfo(BaseModel):
    key: str
    type: str
    default: object = False
    label: str = ""
    description: str = ""


class StageConfigInfo(BaseModel):
    name: str
    label: str
    description: str
    category: str
    config_fields: list[dict]
    output_hints: list[StageOutputHintInfo] = []
    accepted_transforms: list[TransformOptionInfo] = []
    reads: list[str]
    writes: list[str]


@router.get("/config")
def read_config():
    return _runtime_config


@router.put("/config")
def write_config(update: ConfigUpdate):
    changes = update.model_dump(exclude_none=True)
    for section, values in changes.items():
        if section not in _runtime_config:
            _runtime_config[section] = {}
        _runtime_config[section].update(values)

    logger.info("Config updated: %s", list(changes.keys()))
    return _runtime_config


@router.get("/config/profiles")
def get_profiles():
    """List available detection profiles."""
    from core.detect.profile import list_profiles as _list
    return [{"name": name} for name in _list()]


@router.get("/config/profiles/{profile_name}/pipeline")
def get_pipeline_config(profile_name: str):
    """Return the pipeline composition for a profile."""
    from core.detect.profile import get_profile
    from fastapi import HTTPException

    try:
        profile = get_profile(profile_name)
    except ValueError:
        raise HTTPException(status_code=404, detail=f"Unknown profile: {profile_name}")

    return profile["pipeline"]


class UpdateEdgeTransformRequest(BaseModel):
    profile_name: str = "soccer_broadcast"
    source_stage: str
    target_stage: str
    transform: dict


@router.put("/config/edge-transform")
def update_edge_transform(req: UpdateEdgeTransformRequest):
    """Update the transform on an edge in a profile's pipeline config."""
    from uuid import UUID
    from core.db.models import Profile
    from core.db.connection import get_session
    from sqlmodel import select
    from fastapi import HTTPException

    with get_session() as session:
        stmt = select(Profile).where(Profile.name == req.profile_name)
        profile = session.exec(stmt).first()
        if not profile:
            raise HTTPException(status_code=404, detail=f"Profile not found: {req.profile_name}")

        pipeline = dict(profile.pipeline)
        edges = pipeline.get("edges", [])

        found = False
        for edge in edges:
            if edge.get("source") == req.source_stage and edge.get("target") == req.target_stage:
                edge["transform"] = req.transform
                found = True
                break

        if not found:
            raise HTTPException(
                status_code=404,
                detail=f"Edge not found: {req.source_stage} → {req.target_stage}",
            )

        pipeline["edges"] = edges
        profile.pipeline = pipeline
        session.commit()

    return {"status": "updated", "edge": f"{req.source_stage} → {req.target_stage}", "transform": req.transform}


@router.get("/config/stages", response_model=list[StageConfigInfo])
def list_stage_configs():
    """Return the stage palette with config field metadata for the editor."""
    from core.detect.stages import list_stages

    result = []
    for stage in list_stages():
        info = _stage_to_info(stage)
        result.append(info)
    return result


@router.get("/config/stages/{stage_name}", response_model=StageConfigInfo)
def get_stage_config(stage_name: str):
    """Return config field metadata for a single stage."""
    from core.detect.stages import get_stage

    try:
        stage = get_stage(stage_name)
    except KeyError:
        from fastapi import HTTPException
        raise HTTPException(status_code=404, detail=f"Unknown stage: {stage_name}")
    return _stage_to_info(stage)


def _stage_to_info(stage) -> StageConfigInfo:
    return StageConfigInfo(
        name=stage.name,
        label=stage.label,
        description=stage.description,
        category=stage.category,
        config_fields=[
            {
                "name": f.name,
                "type": f.type,
                "default": f.default,
                "description": f.description,
                "min": f.min,
                "max": f.max,
                "options": f.options,
            }
            for f in stage.config_fields
        ],
        output_hints=[
            StageOutputHintInfo(
                key=h.key, type=h.type, label=h.label,
                default_opacity=h.default_opacity, src_format=h.src_format,
            )
            for h in getattr(stage, "output_hints", [])
        ],
        accepted_transforms=[
            TransformOptionInfo(
                key=t.key, type=t.type, default=t.default,
                label=t.label, description=t.description,
            )
            for t in getattr(stage, "accepted_transforms", [])
        ],
        reads=stage.io.reads,
        writes=stage.io.writes,
    )
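A short client-side sketch of the read/update cycle these endpoints support. The base URL and the `min_confidence` field are illustrative assumptions, not values from this diff:

```python
# Sketch: read the runtime config, then merge an override into one section.
# Base URL and the example field are assumptions.
import httpx

BASE = "http://mpr.local.ar/api/detect"

print(httpx.get(f"{BASE}/config").json())  # current in-memory config

# PUT merges per-section dicts; sections left as None are ignored.
resp = httpx.put(f"{BASE}/config", json={"ocr": {"min_confidence": 0.6}})
print(resp.json())  # updated config, picked up on the next run
```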
core/api/detect/replay.py (new file, 521 lines)
@@ -0,0 +1,521 @@
"""
API endpoints for checkpoint inspection, replay, retry, and GPU proxy.

GET /detect/checkpoints/{timeline_id} — list available checkpoints
POST /detect/replay — replay from a stage with config overrides
POST /detect/retry — queue async retry with different provider
POST /detect/replay-stage — replay single stage (fast path)
POST /detect/gpu/detect_edges — proxy to GPU inference server
POST /detect/gpu/detect_edges/debug — proxy with debug overlays
"""

from __future__ import annotations

import logging
import os

from fastapi import APIRouter, HTTPException, Request, Response
from pydantic import BaseModel

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/detect", tags=["detect"])


# --- Request/Response models ---

class CheckpointInfo(BaseModel):
    stage: str
    is_scenario: bool = False
    scenario_label: str = ""


class ScenarioInfo(BaseModel):
    timeline_id: str
    stage: str
    scenario_label: str
    profile_name: str
    video_path: str
    frame_count: int = 0
    created_at: str = ""


class ReplayRequest(BaseModel):
    job_id: str
    start_stage: str
    config_overrides: dict | None = None


class ReplayResponse(BaseModel):
    status: str
    job_id: str
    replay_job_id: str
    start_stage: str
    detections: int = 0
    brands_found: int = 0


class ReplaySingleStageRequest(BaseModel):
    job_id: str
    stage: str
    frame_refs: list[int] | None = None
    config_overrides: dict | None = None
    debug: bool = False


class ReplaySingleStageBox(BaseModel):
    x: int
    y: int
    w: int
    h: int
    confidence: float
    label: str


class FrameDebugOverlays(BaseModel):
    edge_overlay_b64: str = ""
    lines_overlay_b64: str = ""
    horizontal_count: int = 0
    pair_count: int = 0


class ReplaySingleStageResponse(BaseModel):
    status: str
    stage: str
    frame_count: int = 0
    region_count: int = 0
    regions_by_frame: dict[str, list[ReplaySingleStageBox]] = {}
    debug: dict[str, FrameDebugOverlays] = {}  # keyed by frame seq


# --- Endpoints ---

@router.get("/checkpoints/{timeline_id}")
def list_checkpoints_endpoint(timeline_id: str) -> list[CheckpointInfo]:
    """List available checkpoint stages for a timeline."""
    from core.detect.checkpoint.storage import get_checkpoints_for_timeline

    try:
        checkpoints = get_checkpoints_for_timeline(timeline_id)
    except Exception as e:
        raise HTTPException(status_code=404, detail=f"No checkpoints for timeline {timeline_id}: {e}")

    result = [
        CheckpointInfo(
            stage=c["stage_name"],
            is_scenario=c.get("is_scenario", False),
            scenario_label=c.get("scenario_label", ""),
        )
        for c in checkpoints
        if c["stage_name"]
    ]
    return result


class CheckpointFrameInfo(BaseModel):
    seq: int
    timestamp: float
    jpeg_b64: str


class CheckpointData(BaseModel):
    timeline_id: str
    stage: str
    profile_name: str
    video_path: str
    is_scenario: bool
    scenario_label: str
    frames: list[CheckpointFrameInfo]
    stats: dict = {}
    config_snapshot: dict = {}
    stage_output_key: str = ""


@router.get("/checkpoints/{timeline_id}/{stage}", response_model=CheckpointData)
def get_checkpoint_data(timeline_id: str, stage: str):
    """Load checkpoint frames + metadata for the editor UI.

    Reads from the timeline's frame cache (local filesystem).
    """
    from uuid import UUID
    from core.db.models import Timeline, Checkpoint
    from core.db.connection import get_session
    from core.db.checkpoint import list_checkpoints
    from core.detect.checkpoint.frames import load_cached_frames_b64

    with get_session() as session:
        timeline = session.get(Timeline, UUID(timeline_id))
        if not timeline:
            raise HTTPException(status_code=404, detail=f"Timeline not found: {timeline_id}")

        checkpoints = list_checkpoints(session, UUID(timeline_id))
        if not checkpoints:
            raise HTTPException(status_code=404, detail=f"No checkpoints for timeline {timeline_id}")
        # Prefer a checkpoint for this stage; fall back to latest
        checkpoint = next(
            (c for c in reversed(checkpoints) if c.stage_name == stage),
            checkpoints[-1],
        )

        # Read from timeline's frame cache
        frames_b64 = load_cached_frames_b64(timeline_id)
        frame_list = [
            CheckpointFrameInfo(seq=f["seq"], timestamp=f["timestamp"], jpeg_b64=f["jpeg_b64"])
            for f in frames_b64
        ]

        return CheckpointData(
            timeline_id=timeline_id,
            stage=stage,
            profile_name=timeline.profile_name,
            video_path=timeline.chunk_paths[0] if timeline.chunk_paths else "",
            is_scenario=checkpoint.is_scenario,
            scenario_label=checkpoint.scenario_label,
            frames=frame_list,
            stats=checkpoint.stats or {},
            config_snapshot=checkpoint.config_overrides or {},
            stage_output_key=stage,
        )


@router.get("/scenarios", response_model=list[ScenarioInfo])
def list_scenarios_endpoint():
    """List all available scenarios (bookmarked checkpoints)."""
    from core.db.models import Timeline
    from core.db.connection import get_session
    from core.db.checkpoint import list_scenarios

    with get_session() as session:
        scenarios = list_scenarios(session)
        result = []
        for s in scenarios:
            timeline = session.get(Timeline, s.timeline_id)
            if not timeline:
                continue
            info = ScenarioInfo(
                timeline_id=str(s.timeline_id),
                stage=s.stage_name,
                scenario_label=s.scenario_label,
                profile_name=timeline.profile_name,
                video_path=timeline.chunk_paths[0] if timeline.chunk_paths else "",
                created_at=str(s.created_at) if s.created_at else "",
            )
            result.append(info)
        return result


@router.post("/replay", response_model=ReplayResponse)
def replay(req: ReplayRequest):
    """Replay pipeline from a specific stage with optional config overrides."""
    from core.detect.checkpoint.replay import replay_from

    try:
        result = replay_from(
            job_id=req.job_id,
            start_stage=req.start_stage,
            config_overrides=req.config_overrides,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Replay failed: {e}")

    detections = result.get("detections", [])
    report = result.get("report")
    brands_found = len(report.brands) if report else 0

    response = ReplayResponse(
        status="completed",
        job_id=req.job_id,
        replay_job_id=result.get("job_id", ""),
        start_stage=req.start_stage,
        detections=len(detections),
        brands_found=brands_found,
    )
    return response


@router.post("/replay-stage", response_model=ReplaySingleStageResponse)
def replay_single_stage(req: ReplaySingleStageRequest):
    """Replay a single stage on specific frames — fast path for interactive tuning."""
    from core.detect.checkpoint.replay import replay_single_stage as _replay

    try:
        result = _replay(
            job_id=req.job_id,
            stage=req.stage,
            frame_refs=req.frame_refs,
            config_overrides=req.config_overrides,
            debug=req.debug,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Single-stage replay failed: {e}")

    # Convert result to response format
    regions_by_frame = result.get("edge_regions_by_frame", {})
    total_regions = 0
    serialized = {}
    for seq, boxes in regions_by_frame.items():
        box_list = []
        for b in boxes:
            box = ReplaySingleStageBox(
                x=b.x, y=b.y, w=b.w, h=b.h,
                confidence=b.confidence, label=b.label,
            )
            box_list.append(box)
        serialized[str(seq)] = box_list
        total_regions += len(box_list)

    # Serialize debug overlays if present
    debug_out = {}
    raw_debug = result.get("debug", {})
    for seq, d in raw_debug.items():
        debug_out[str(seq)] = FrameDebugOverlays(
            edge_overlay_b64=d.get("edge_overlay_b64", ""),
            lines_overlay_b64=d.get("lines_overlay_b64", ""),
            horizontal_count=d.get("horizontal_count", 0),
            pair_count=d.get("pair_count", 0),
        )

    return ReplaySingleStageResponse(
        status="completed",
        stage=req.stage,
        frame_count=len(regions_by_frame),
        region_count=total_regions,
        regions_by_frame=serialized,
        debug=debug_out,
    )


# --- GPU proxy — thin passthrough to inference server for interactive editor ---


def _gpu_url() -> str:
    url = os.environ.get("INFERENCE_URL", "http://localhost:8000")
    return url.rstrip("/")


# --- Overlay cache — save/load debug overlay images ---


class SaveOverlaysRequest(BaseModel):
    timeline_id: str
    job_id: str
    stage: str
    seq: int
    overlays: dict[str, str]  # {overlay_key: base64_png}


@router.post("/overlays")
def save_overlays_endpoint(req: SaveOverlaysRequest):
    """Save debug overlay images to blob storage cache."""
    from core.detect.checkpoint.frames import save_overlays

    save_overlays(req.timeline_id, req.job_id, req.stage, req.seq, req.overlays)
    return {"status": "saved", "count": len(req.overlays)}


@router.get("/overlays/{timeline_id}/{job_id}/{stage}/{seq}")
def load_overlays_endpoint(timeline_id: str, job_id: str, stage: str, seq: int):
    """Load cached debug overlay images."""
    from core.detect.checkpoint.frames import load_overlays

    overlays = load_overlays(timeline_id, job_id, stage, seq)
    return {"overlays": overlays or {}}


def _generate_debug_overlays(job_id: str, stage: str, frame) -> dict[str, str] | None:
    """Generate debug overlay images for a single frame."""
    import os

    inference_url = os.environ.get("INFERENCE_URL")

    if stage == "detect_edges":
        from core.detect.profile import get_profile, get_stage_config
        from core.detect.stages.models import RegionAnalysisConfig
        from core.db.connection import get_session
        from core.db.job import get_job
        from uuid import UUID

        with get_session() as session:
            job = get_job(session, UUID(job_id))
            if not job:
                return None

            profile = get_profile(job.profile_name)
            config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))

        if inference_url:
            from core.detect.inference import InferenceClient
            client = InferenceClient(base_url=inference_url, job_id=job_id)
            dr = client.detect_edges_debug(
                image=frame.image,
                edge_canny_low=config.edge_canny_low,
                edge_canny_high=config.edge_canny_high,
                edge_hough_threshold=config.edge_hough_threshold,
                edge_hough_min_length=config.edge_hough_min_length,
                edge_hough_max_gap=config.edge_hough_max_gap,
                edge_pair_max_distance=config.edge_pair_max_distance,
                edge_pair_min_distance=config.edge_pair_min_distance,
            )
            return {
                "edge_overlay_b64": dr.edge_overlay_b64,
                "lines_overlay_b64": dr.lines_overlay_b64,
            }
        else:
            from core.detect.stages.edge_detector import _load_cv_edges
            edges_mod = _load_cv_edges()
            dr = edges_mod.detect_edges_debug(
                frame.image,
                canny_low=config.edge_canny_low,
                canny_high=config.edge_canny_high,
                hough_threshold=config.edge_hough_threshold,
                hough_min_length=config.edge_hough_min_length,
                hough_max_gap=config.edge_hough_max_gap,
                pair_max_distance=config.edge_pair_max_distance,
                pair_min_distance=config.edge_pair_min_distance,
            )
            return {
                "edge_overlay_b64": dr["edge_overlay_b64"],
                "lines_overlay_b64": dr["lines_overlay_b64"],
            }

    elif stage == "field_segmentation":
        from core.detect.profile import get_profile, get_stage_config
        from core.detect.stages.models import FieldSegmentationConfig
        from core.db.connection import get_session
        from core.db.job import get_job
        from uuid import UUID

        with get_session() as session:
            job = get_job(session, UUID(job_id))
            if not job:
                return None

            profile = get_profile(job.profile_name)
            config = FieldSegmentationConfig(**get_stage_config(profile, "field_segmentation"))

        if inference_url:
            import httpx, json, base64, io
            from PIL import Image
            import numpy as np

            buf = io.BytesIO()
            Image.fromarray(frame.image).save(buf, format="JPEG", quality=85)
            img_b64 = base64.b64encode(buf.getvalue()).decode()

            resp = httpx.post(
                f"{inference_url.rstrip('/')}/segment_field/debug",
                json={
                    "image_b64": img_b64,
                    "hue_low": config.hue_low,
                    "hue_high": config.hue_high,
                    "sat_low": config.sat_low,
                    "sat_high": config.sat_high,
                    "val_low": config.val_low,
                    "val_high": config.val_high,
                    "morph_kernel": config.morph_kernel,
                    "min_area_ratio": config.min_area_ratio,
                },
                timeout=30.0,
            )
            if resp.status_code == 200:
                data = resp.json()
                return {"mask_overlay_b64": data.get("mask_b64", "")}

        return None

    return None


@router.get("/overlays/{timeline_id}/{job_id}/{stage}")
def list_overlay_frames_endpoint(timeline_id: str, job_id: str, stage: str):
    """List frame sequences that have cached overlays."""
    from core.detect.checkpoint.frames import list_overlay_frames

    seqs = list_overlay_frames(timeline_id, job_id, stage)
    return {"frames": seqs}


# --- GPU proxy — thin passthrough to inference server for interactive editor ---


@router.post("/gpu/detect_edges")
async def gpu_detect_edges(request: Request):
    """Proxy to GPU inference server — browser can't reach it directly."""
    import httpx

    body = await request.body()
    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(
                f"{_gpu_url()}/detect_edges",
                content=body,
                headers={"Content-Type": "application/json"},
            )
            return Response(content=resp.content, status_code=resp.status_code,
                            media_type="application/json")
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}")


@router.post("/gpu/detect_edges/debug")
async def gpu_detect_edges_debug(request: Request):
    """Proxy to GPU inference server debug endpoint."""
    import httpx

    body = await request.body()
    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(
                f"{_gpu_url()}/detect_edges/debug",
                content=body,
                headers={"Content-Type": "application/json"},
            )
            return Response(content=resp.content, status_code=resp.status_code,
                            media_type="application/json")
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}")


@router.post("/gpu/segment_field")
async def gpu_segment_field(request: Request):
    """Proxy to GPU inference server — field segmentation."""
    import httpx

    body = await request.body()
    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(
                f"{_gpu_url()}/segment_field",
                content=body,
                headers={"Content-Type": "application/json"},
            )
            return Response(content=resp.content, status_code=resp.status_code,
                            media_type="application/json")
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}")


@router.post("/gpu/segment_field/debug")
async def gpu_segment_field_debug(request: Request):
    """Proxy to GPU inference server — field segmentation with debug overlay."""
    import httpx

    body = await request.body()
    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(
                f"{_gpu_url()}/segment_field/debug",
                content=body,
                headers={"Content-Type": "application/json"},
            )
            return Response(content=resp.content, status_code=resp.status_code,
                            media_type="application/json")
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"GPU server unreachable: {e}")
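Putting the fast path together, a hedged sketch of one interactive tuning round against `/detect/replay-stage`. The job id, override shape, and base URL are illustrative placeholders:

```python
# Sketch: re-run one stage on a few frames with overridden config and
# request debug overlays. All concrete values here are placeholders.
import httpx

BASE = "http://mpr.local.ar/api/detect"

resp = httpx.post(f"{BASE}/replay-stage", json={
    "job_id": "00000000-0000-0000-0000-000000000000",
    "stage": "detect_edges",
    "frame_refs": [0, 10, 20],
    "config_overrides": {"detect_edges": {"edge_canny_low": 40}},  # shape is an assumption
    "debug": True,
}, timeout=120.0)
data = resp.json()
print(data["region_count"], "regions across", data["frame_count"], "frames")
```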
core/api/detect/run.py (new file, 278 lines)
@@ -0,0 +1,278 @@
"""
Pipeline run endpoints.

POST /detect/run — launch pipeline on a timeline
POST /detect/stop/{job_id} — cancel a running pipeline
POST /detect/pause/{job_id} — pause after current stage
POST /detect/resume/{job_id} — resume a paused pipeline
POST /detect/step/{job_id} — run one stage then pause
POST /detect/clear/{job_id} — clear events from Redis
GET /detect/status/{job_id} — pipeline run status
"""

from __future__ import annotations

import logging
import os
import threading
import uuid

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/detect", tags=["detect"])

# In-process pipeline tracking
_running_jobs: dict[str, threading.Thread] = {}
_cancelled_jobs: set[str] = set()


class RunRequest(BaseModel):
    timeline_id: str
    profile_name: str = "soccer_broadcast"
    checkpoint: bool = True
    skip_vlm: bool = False
    skip_cloud: bool = False
    log_level: str = "INFO"  # INFO | DEBUG
    pause_after_stage: bool = False
    config_overrides: dict | None = None


class RunResponse(BaseModel):
    status: str
    job_id: str
    timeline_id: str


def _resolve_video_path(video_path: str) -> str:
    """Download a chunk from blob storage to a temp file."""
    from core.storage.blob import get_store

    store = get_store("out")
    try:
        return store.download_to_temp(video_path)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to download chunk: {e}")


@router.post("/run", response_model=RunResponse)
def run_pipeline(req: RunRequest):
    """Launch a detection pipeline run on a timeline."""
    from core.detect import emit
    from core.detect.graph import get_pipeline
    from core.detect.state import DetectState
    from core.detect.checkpoint.storage import get_timeline
    from core.db.connection import get_session
    from core.db.job import create_job, update_job_status

    # Load timeline
    try:
        timeline = get_timeline(req.timeline_id)
    except ValueError:
        raise HTTPException(status_code=404, detail=f"Timeline not found: {req.timeline_id}")

    chunk_paths = timeline["chunk_paths"]
    if not chunk_paths:
        raise HTTPException(status_code=400, detail="Timeline has no chunk paths")

    # Resolve first chunk to local path for the pipeline
    local_path = _resolve_video_path(chunk_paths[0])

    # Create job in DB
    source_asset_id_str = timeline.get("source_asset_id", "")
    with get_session() as session:
        from uuid import UUID as _UUID
        source_asset_id = _UUID(source_asset_id_str) if source_asset_id_str else uuid.uuid4()
        job = create_job(
            session,
            source_asset_id=source_asset_id,
            video_path=chunk_paths[0],
            timeline_id=_UUID(req.timeline_id),
            profile_name=req.profile_name,
            config_overrides=req.config_overrides,
        )
        job_id = str(job.id)

    if req.skip_vlm:
        os.environ["SKIP_VLM"] = "1"
    elif "SKIP_VLM" in os.environ:
        del os.environ["SKIP_VLM"]

    if req.skip_cloud:
        os.environ["SKIP_CLOUD"] = "1"
    elif "SKIP_CLOUD" in os.environ:
        del os.environ["SKIP_CLOUD"]

    # Clear any stale events
    from core.events import _get_redis
    from core.detect.events import DETECT_EVENTS_PREFIX
    r = _get_redis()
    r.delete(f"{DETECT_EVENTS_PREFIX}:{job_id}")

    emit.set_run_context(
        run_id=job_id, parent_job_id=job_id, run_type="initial",
        log_level=req.log_level,
    )

    pipeline = get_pipeline(checkpoint=req.checkpoint, profile_name=req.profile_name)

    initial_state = DetectState(
        video_path=local_path,
        job_id=job_id,
        profile_name=req.profile_name,
        source_asset_id=source_asset_id_str or str(source_asset_id),
        timeline_id=req.timeline_id,
        config_overrides=req.config_overrides or {},
    )

    from core.detect.graph import (
        PipelineCancelled, set_cancel_check, clear_cancel_check,
        init_pause, clear_pause,
    )

    set_cancel_check(job_id, lambda: job_id in _cancelled_jobs)
    init_pause(job_id, pause_after_stage=req.pause_after_stage)

    def _update_job(status, stage=None, error=None):
        from core.db.connection import get_session
        from core.db.job import update_job_status
        with get_session() as session:
            update_job_status(session, _UUID(job_id), status,
                              current_stage=stage, error_message=error)

    def _run():
        try:
            _update_job("running")
            emit.log(job_id, "Pipeline", "INFO",
                     f"Starting pipeline: {chunk_paths[0]} (profile={req.profile_name})")
            pipeline.invoke(initial_state)
            _update_job("completed")
            emit.log(job_id, "Pipeline", "INFO", "Pipeline completed successfully")
            emit.job_complete(job_id, {"status": "completed"})
        except PipelineCancelled:
            _update_job("cancelled")
            emit.log(job_id, "Pipeline", "INFO", "Pipeline cancelled")
            emit.job_complete(job_id, {"status": "cancelled"})
        except Exception as e:
            logger.exception("Pipeline run %s failed: %s", job_id, e)
            _update_job("failed", error=str(e))
            from core.detect.graph import _node_states, NODES
            if job_id in _node_states:
                states = _node_states[job_id]
                for node in reversed(NODES):
                    if states.get(node) in ("running", "done"):
                        states[node] = "error"
                        break
                nodes = [{"id": n, "status": states[n]} for n in NODES]
                emit.graph_update(job_id, nodes)
            emit.log(job_id, "Pipeline", "ERROR", str(e))
            emit.job_complete(job_id, {"status": "failed", "error": str(e)})
        finally:
            _running_jobs.pop(job_id, None)
            _cancelled_jobs.discard(job_id)
            clear_cancel_check(job_id)
            clear_pause(job_id)
            emit.clear_run_context()
            from core.detect.checkpoint.runner_bridge import reset_checkpoint_state
            reset_checkpoint_state(job_id)

    thread = threading.Thread(target=_run, daemon=True, name=f"pipeline-{job_id}")
    _running_jobs[job_id] = thread
    thread.start()

    return RunResponse(status="started", job_id=job_id, timeline_id=req.timeline_id)


@router.post("/stop/{job_id}")
def stop_pipeline(job_id: str):
    """Stop a running pipeline. Signals cancellation; the thread checks on next stage."""
    from core.detect import emit

    if job_id not in _running_jobs:
        raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")

    _cancelled_jobs.add(job_id)
    emit.log(job_id, "Pipeline", "INFO", "Stop requested — cancelling after current stage")
    return {"status": "stopping", "job_id": job_id}


@router.post("/pause/{job_id}")
def pause(job_id: str):
    """Pause a running pipeline after the current stage completes."""
    from core.detect.graph import pause_pipeline

    if job_id not in _running_jobs:
        raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")

    pause_pipeline(job_id)
    return {"status": "pausing", "job_id": job_id}


@router.post("/resume/{job_id}")
def resume(job_id: str):
    """Resume a paused pipeline."""
    from core.detect.graph import resume_pipeline

    if job_id not in _running_jobs:
        raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")

    resume_pipeline(job_id)
    return {"status": "running", "job_id": job_id}


@router.post("/step/{job_id}")
def step(job_id: str):
    """Run one stage then pause again."""
    from core.detect.graph import step_pipeline

    if job_id not in _running_jobs:
        raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")

    step_pipeline(job_id)
    return {"status": "stepping", "job_id": job_id}


@router.post("/pause-after-stage/{job_id}")
def toggle_pause_after_stage(job_id: str, enabled: bool = True):
    """Toggle pause-after-each-stage mode."""
    from core.detect.graph import set_pause_after_stage

    if job_id not in _running_jobs:
        raise HTTPException(status_code=404, detail=f"No running pipeline: {job_id}")

    set_pause_after_stage(job_id, enabled)
    return {"status": "ok", "pause_after_stage": enabled, "job_id": job_id}


@router.get("/status/{job_id}")
def pipeline_status(job_id: str):
    """Get pipeline run status."""
    from core.detect.graph import is_paused

    running = job_id in _running_jobs
    paused = is_paused(job_id)
    cancelling = job_id in _cancelled_jobs

    if cancelling:
        status = "cancelling"
    elif paused:
        status = "paused"
    elif running:
        status = "running"
    else:
        status = "idle"

    return {"status": status, "job_id": job_id}


@router.post("/clear/{job_id}")
def clear_pipeline(job_id: str):
    """Clear events for a job from Redis."""
    from core.events import _get_redis
    from core.detect.events import DETECT_EVENTS_PREFIX

    r = _get_redis()
    r.delete(f"{DETECT_EVENTS_PREFIX}:{job_id}")
    return {"status": "cleared", "job_id": job_id}
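End to end, launching a run and polling it looks roughly like this. The timeline id and base URL are placeholders:

```python
# Sketch: start a pipeline run on a timeline, then poll its status.
import time

import httpx

BASE = "http://mpr.local.ar/api/detect"  # placeholder base URL

run = httpx.post(f"{BASE}/run", json={
    "timeline_id": "00000000-0000-0000-0000-000000000000",  # placeholder
    "profile_name": "soccer_broadcast",
    "skip_cloud": True,  # avoid cloud LLM spend while iterating
}).json()
job_id = run["job_id"]

# /status reports "idle" once the worker thread finishes.
while httpx.get(f"{BASE}/status/{job_id}").json()["status"] == "running":
    time.sleep(1.0)
```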
core/api/detect/sources.py (new file, 108 lines)
@@ -0,0 +1,108 @@
"""
Source browser for detection pipeline.

Lists available media sources from blob storage (MinIO).

GET /detect/sources — list chunk jobs
GET /detect/sources/{job_id}/chunks — list chunks for a job
GET /detect/sources/{job_id}/chunks/{name}/url — presigned preview URL
"""

from __future__ import annotations

import logging

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/detect", tags=["detect"])


class ChunkInfoResponse(BaseModel):
    filename: str
    key: str
    size_bytes: int


class SourceInfoResponse(BaseModel):
    job_id: str
    source_type: str = "chunk_job"
    chunk_count: int
    total_bytes: int = 0


def _list_sources() -> list[SourceInfoResponse]:
    """List chunk jobs from blob storage."""
    from core.storage.blob import get_store

    store = get_store("out")
    try:
        objects = store.list(prefix="chunks/")
    except Exception as e:
        logger.warning("Failed to list blob sources: %s", e)
        return []

    jobs: dict[str, int] = {}
    job_bytes: dict[str, int] = {}
    for obj in objects:
        rel_key = obj.key.removeprefix(store.prefix)
        parts = rel_key.split("/")
        if len(parts) >= 3 and parts[0] == "chunks":
            job_id = parts[1]
            jobs[job_id] = jobs.get(job_id, 0) + 1
            job_bytes[job_id] = job_bytes.get(job_id, 0) + obj.size_bytes

    sources = []
    for job_id, count in sorted(jobs.items()):
        source = SourceInfoResponse(
            job_id=job_id,
            source_type="chunk_job",
            chunk_count=count,
            total_bytes=job_bytes.get(job_id, 0),
        )
        sources.append(source)
    return sources


@router.get("/sources", response_model=list[SourceInfoResponse])
def list_sources():
    """List available chunk jobs from blob storage."""
    return _list_sources()


@router.get("/sources/{source_job_id}/chunks", response_model=list[ChunkInfoResponse])
def list_chunks(source_job_id: str):
    """List chunks for a specific source job."""
    from core.storage.blob import get_store

    store = get_store("out")
    try:
        objects = store.list(prefix=f"chunks/{source_job_id}/", extensions={".mp4"})
    except Exception as e:
        logger.warning("Failed to list chunks for %s: %s", source_job_id, e)
        raise HTTPException(status_code=503, detail=f"Blob storage unavailable: {e}")

    if not objects:
        raise HTTPException(status_code=404, detail=f"Source not found: {source_job_id}")

    chunks = []
    for obj in objects:
        info = ChunkInfoResponse(filename=obj.filename, key=obj.key, size_bytes=obj.size_bytes)
        chunks.append(info)
    return sorted(chunks, key=lambda c: c.filename)


@router.get("/sources/{source_job_id}/chunks/{filename}/url")
def get_chunk_url(source_job_id: str, filename: str):
    """Return a presigned URL for previewing a chunk in the browser."""
    from core.storage.blob import get_store

    store = get_store("out")
    key = f"chunks/{source_job_id}/{filename}"
    try:
        url = store.get_url(key, expires=3600)
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"Could not generate URL: {e}")
    return {"url": url}
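Browsing sources from a script follows the same pattern. A minimal sketch with a placeholder base URL:

```python
# Sketch: walk the source browser endpoints (placeholder base URL).
import httpx

BASE = "http://mpr.local.ar/api/detect"

for src in httpx.get(f"{BASE}/sources").json():
    print(src["job_id"], src["chunk_count"], "chunks,", src["total_bytes"], "bytes")
    chunks = httpx.get(f"{BASE}/sources/{src['job_id']}/chunks").json()
    for c in chunks[:3]:
        url = httpx.get(
            f"{BASE}/sources/{src['job_id']}/chunks/{c['filename']}/url"
        ).json()["url"]
        print("  preview:", c["filename"], url)
```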
core/api/detect/sse.py (new file, 79 lines)
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
SSE endpoint for detection pipeline events.
|
||||
|
||||
Uses Redis as the event bus between pipeline workers and the SSE stream.
|
||||
Mirrors chunker_sse.py but polls detect_events:{job_id}.
|
||||
|
||||
GET /detect/stream/{job_id} → text/event-stream
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import APIRouter
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from core.events import poll_events
|
||||
from core.detect.events import DETECT_EVENTS_PREFIX, TERMINAL_EVENTS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/detect", tags=["detect"])
|
||||
|
||||
|
||||
async def _event_generator(job_id: str) -> AsyncGenerator[str, None]:
|
||||
cursor = 0
|
||||
timeout = time.monotonic() + 3600 # 1 hour max
|
||||
|
||||
while time.monotonic() < timeout:
|
||||
events, cursor = poll_events(job_id, cursor, prefix=DETECT_EVENTS_PREFIX)
|
||||
|
||||
if not events:
|
||||
await asyncio.sleep(0.2)
|
||||
continue
|
||||
|
||||
is_terminal = False
|
||||
for data in events:
|
||||
event_type = data.pop("event", "update")
|
||||
payload = {**data, "job_id": job_id}
|
||||
|
||||
yield f"event: {event_type}\ndata: {json.dumps(payload)}\n\n"
|
||||
|
||||
if event_type in TERMINAL_EVENTS:
|
||||
is_terminal = True
|
||||
|
||||
if is_terminal:
|
||||
yield f"event: done\ndata: {json.dumps({'job_id': job_id})}\n\n"
|
||||
# Don't return — keep connection alive so EventSource doesn't reconnect.
|
||||
# Just idle until the client disconnects or timeout.
|
||||
while time.monotonic() < timeout:
|
||||
await asyncio.sleep(5)
|
||||
return
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
yield f"event: timeout\ndata: {json.dumps({'job_id': job_id})}\n\n"
|
||||
|
||||
|
||||
@router.get("/stream/{job_id}")
|
||||
async def stream_detect_job(job_id: str):
|
||||
"""
|
||||
SSE stream for a detection pipeline job.
|
||||
|
||||
The UI connects via native EventSource:
|
||||
const es = new EventSource('/api/detect/stream/<job_id>');
|
||||
es.addEventListener('graph_update', (e) => { ... });
|
||||
es.addEventListener('detection', (e) => { ... });
|
||||
"""
|
||||
return StreamingResponse(
|
||||
_event_generator(job_id),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
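Reviewer note: to smoke-test the stream outside the browser, something like the sketch below works. It assumes the `httpx` package and an API reachable at localhost:8000; neither is pinned down by this diff.

# Minimal SSE smoke test for /detect/stream/{job_id}.
# Assumptions: httpx is installed and the API runs at localhost:8000.
import httpx

def follow_detect_stream(job_id: str, base: str = "http://localhost:8000") -> None:
    with httpx.stream("GET", f"{base}/detect/stream/{job_id}", timeout=None) as resp:
        event = None
        for line in resp.iter_lines():
            if line.startswith("event: "):
                event = line[len("event: "):]
            elif line.startswith("data: "):
                print(event, line[len("data: "):])
                if event in ("done", "timeout"):
                    return  # the server idles after terminal events, so close client-side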
226 core/api/detect/timeline.py Normal file
@@ -0,0 +1,226 @@
"""
Timeline + Job management endpoints.

POST   /detect/timeline              — create timeline from chunk selection
GET    /detect/timeline              — list timelines
GET    /detect/timeline/{id}         — timeline detail
DELETE /detect/timeline/{id}/cache   — clear frame cache

GET /detect/jobs       — list jobs (optionally by timeline)
GET /detect/jobs/{id}  — job detail + checkpoints + stage outputs
"""

from __future__ import annotations

import logging

from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/detect", tags=["detect"])


# --- Request/Response models ---

class CreateTimelineRequest(BaseModel):
    chunk_paths: list[str]
    profile_name: str = "soccer_broadcast"
    name: str = ""
    source_asset_id: str = ""
    fps: float = 2.0


class TimelineResponse(BaseModel):
    id: str
    name: str
    chunk_paths: list[str]
    profile_name: str
    status: str
    fps: float
    frame_count: int
    source_ephemeral: bool
    created_at: str | None = None


class JobResponse(BaseModel):
    id: str
    timeline_id: str | None
    source_asset_id: str
    video_path: str
    profile_name: str
    run_type: str
    status: str
    current_stage: str | None
    config_overrides: dict
    error_message: str | None
    created_at: str | None
    started_at: str | None
    completed_at: str | None


class JobDetailResponse(JobResponse):
    checkpoints: list[dict]
    stage_outputs: dict[str, dict]


# --- Timeline endpoints ---

@router.post("/timeline", response_model=TimelineResponse)
def create_timeline_endpoint(req: CreateTimelineRequest):
    """Create a timeline from a chunk selection."""
    from uuid import UUID
    from core.detect.checkpoint.storage import create_timeline, get_timeline

    source_asset_id = UUID(req.source_asset_id) if req.source_asset_id else None
    tid = create_timeline(
        chunk_paths=req.chunk_paths,
        profile_name=req.profile_name,
        name=req.name,
        source_asset_id=source_asset_id,
        fps=req.fps,
    )

    tl = get_timeline(tid)
    return TimelineResponse(
        id=tl["id"],
        name=tl["name"],
        chunk_paths=tl["chunk_paths"],
        profile_name=tl["profile_name"],
        status=tl["status"],
        fps=tl["fps"],
        frame_count=0,
        source_ephemeral=False,
        created_at=tl["created_at"],
    )


@router.get("/timeline", response_model=list[TimelineResponse])
def list_timelines():
    """List all timelines."""
    from sqlmodel import select
    from core.db.models import Timeline
    from core.db.connection import get_session

    with get_session() as session:
        stmt = select(Timeline).order_by(Timeline.created_at.desc())
        timelines = session.exec(stmt).all()

    return [
        TimelineResponse(
            id=str(t.id),
            name=t.name,
            chunk_paths=t.chunk_paths or [],
            profile_name=t.profile_name,
            status=t.status,
            fps=t.fps,
            frame_count=t.frame_count,
            source_ephemeral=t.source_ephemeral,
            created_at=str(t.created_at) if t.created_at else None,
        )
        for t in timelines
    ]


@router.get("/timeline/{timeline_id}", response_model=TimelineResponse)
def get_timeline_endpoint(timeline_id: str):
    """Get timeline detail."""
    from uuid import UUID
    from core.detect.checkpoint.storage import get_timeline
    from core.db.models import Timeline
    from core.db.connection import get_session

    try:
        tl = get_timeline(timeline_id)
    except ValueError:
        raise HTTPException(status_code=404, detail=f"Timeline not found: {timeline_id}")

    with get_session() as session:
        timeline = session.get(Timeline, UUID(timeline_id))

    return TimelineResponse(
        id=tl["id"],
        name=tl["name"],
        chunk_paths=tl["chunk_paths"],
        profile_name=tl["profile_name"],
        status=tl["status"],
        fps=tl["fps"],
        frame_count=timeline.frame_count if timeline else 0,
        source_ephemeral=timeline.source_ephemeral if timeline else False,
        created_at=tl["created_at"],
    )


@router.delete("/timeline/{timeline_id}/cache")
def clear_timeline_cache(timeline_id: str):
    """Clear the frame cache for a timeline."""
    from core.detect.checkpoint.frames import clear_cache
    from core.detect.checkpoint.storage import update_timeline_status

    clear_cache(timeline_id)
    update_timeline_status(timeline_id, "created")
    return {"status": "cleared", "timeline_id": timeline_id}


# --- Job endpoints ---

def _job_to_response(job) -> JobResponse:
    return JobResponse(
        id=str(job.id),
        timeline_id=str(job.timeline_id) if job.timeline_id else None,
        source_asset_id=str(job.source_asset_id),
        video_path=job.video_path,
        profile_name=job.profile_name,
        run_type=job.run_type,
        status=job.status,
        current_stage=job.current_stage,
        config_overrides=job.config_overrides or {},
        error_message=job.error_message,
        created_at=str(job.created_at) if job.created_at else None,
        started_at=str(job.started_at) if job.started_at else None,
        completed_at=str(job.completed_at) if job.completed_at else None,
    )


@router.get("/jobs", response_model=list[JobResponse])
def list_jobs_endpoint(timeline_id: str | None = Query(None)):
    """List jobs, optionally filtered by timeline."""
    from uuid import UUID
    from core.db.connection import get_session
    from core.db.job import list_jobs

    tid = UUID(timeline_id) if timeline_id else None
    with get_session() as session:
        jobs = list_jobs(session, timeline_id=tid)

    return [_job_to_response(j) for j in jobs]


@router.get("/jobs/{job_id}", response_model=JobDetailResponse)
def get_job_endpoint(job_id: str):
    """Get job detail with checkpoints and stage outputs."""
    from uuid import UUID
    from core.db.connection import get_session
    from core.db.job import get_job
    from core.detect.checkpoint.storage import (
        get_checkpoints_for_job,
        load_stage_outputs_for_job,
    )

    with get_session() as session:
        job = get_job(session, UUID(job_id))
        if not job:
            raise HTTPException(status_code=404, detail=f"Job not found: {job_id}")

    checkpoints = get_checkpoints_for_job(job_id)
    stage_outputs = load_stage_outputs_for_job(job_id)

    base = _job_to_response(job)
    return JobDetailResponse(
        **base.model_dump(),
        checkpoints=checkpoints,
        stage_outputs=stage_outputs,
    )
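Reviewer note: end to end, the endpoints above compose as in this sketch. The host, the chunk key, and the `requests` dependency are assumptions, not part of the diff.

# Hypothetical walkthrough of the timeline -> job flow above.
# The host, the chunk key, and the requests dependency are all assumed.
import requests

BASE = "http://localhost:8000"

tl = requests.post(f"{BASE}/detect/timeline", json={
    "chunk_paths": ["chunks/<source_job_id>/chunk_000.mp4"],  # placeholder key
    "profile_name": "soccer_broadcast",
    "fps": 2.0,
}).json()

jobs = requests.get(f"{BASE}/detect/jobs", params={"timeline_id": tl["id"]}).json()
for job in jobs:
    detail = requests.get(f"{BASE}/detect/jobs/{job['id']}").json()
    print(detail["status"], detail["current_stage"], list(detail["stage_outputs"]))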
@@ -1,273 +0,0 @@
"""
GraphQL API using strawberry, served via FastAPI.

Primary API for MPR — all client interactions go through GraphQL.
Uses core.db for data access.
Types are generated from schema/ via modelgen — see api/schema/graphql.py.
"""

import os
from typing import List, Optional
from uuid import UUID

import strawberry
from strawberry.schema.config import StrawberryConfig
from strawberry.types import Info

from core.api.schema.graphql import (
    CreateJobInput,
    DeleteResultType,
    MediaAssetType,
    ScanResultType,
    SystemStatusType,
    TranscodeJobType,
    TranscodePresetType,
    UpdateAssetInput,
)
from core.storage import BUCKET_IN, list_objects

VIDEO_EXTS = {".mp4", ".mkv", ".avi", ".mov", ".webm", ".flv", ".wmv", ".m4v"}
AUDIO_EXTS = {".mp3", ".wav", ".flac", ".aac", ".ogg", ".m4a"}
MEDIA_EXTS = VIDEO_EXTS | AUDIO_EXTS


# ---------------------------------------------------------------------------
# Queries
# ---------------------------------------------------------------------------


@strawberry.type
class Query:
    @strawberry.field
    def assets(
        self,
        info: Info,
        status: Optional[str] = None,
        search: Optional[str] = None,
    ) -> List[MediaAssetType]:
        from core.db import list_assets

        return list_assets(status=status, search=search)

    @strawberry.field
    def asset(self, info: Info, id: UUID) -> Optional[MediaAssetType]:
        from core.db import get_asset

        try:
            return get_asset(id)
        except Exception:
            return None

    @strawberry.field
    def jobs(
        self,
        info: Info,
        status: Optional[str] = None,
        source_asset_id: Optional[UUID] = None,
    ) -> List[TranscodeJobType]:
        from core.db import list_jobs

        return list_jobs(status=status, source_asset_id=source_asset_id)

    @strawberry.field
    def job(self, info: Info, id: UUID) -> Optional[TranscodeJobType]:
        from core.db import get_job

        try:
            return get_job(id)
        except Exception:
            return None

    @strawberry.field
    def presets(self, info: Info) -> List[TranscodePresetType]:
        from core.db import list_presets

        return list_presets()

    @strawberry.field
    def system_status(self, info: Info) -> SystemStatusType:
        return SystemStatusType(status="ok", version="0.1.0")


# ---------------------------------------------------------------------------
# Mutations
# ---------------------------------------------------------------------------


@strawberry.type
class Mutation:
    @strawberry.mutation
    def scan_media_folder(self, info: Info) -> ScanResultType:
        from core.db import create_asset, get_asset_filenames

        objects = list_objects(BUCKET_IN, extensions=MEDIA_EXTS)
        existing = get_asset_filenames()

        registered = []
        skipped = []

        for obj in objects:
            if obj["filename"] in existing:
                skipped.append(obj["filename"])
                continue
            try:
                create_asset(
                    filename=obj["filename"],
                    file_path=obj["key"],
                    file_size=obj["size"],
                )
                registered.append(obj["filename"])
            except Exception:
                pass

        return ScanResultType(
            found=len(objects),
            registered=len(registered),
            skipped=len(skipped),
            files=registered,
        )

    @strawberry.mutation
    def create_job(self, info: Info, input: CreateJobInput) -> TranscodeJobType:
        from pathlib import Path

        from core.db import create_job, get_asset, get_preset

        try:
            source = get_asset(input.source_asset_id)
        except Exception:
            raise Exception("Source asset not found")

        preset = None
        preset_snapshot = {}
        if input.preset_id:
            try:
                preset = get_preset(input.preset_id)
                preset_snapshot = {
                    "name": preset.name,
                    "container": preset.container,
                    "video_codec": preset.video_codec,
                    "audio_codec": preset.audio_codec,
                }
            except Exception:
                raise Exception("Preset not found")

        if not preset and not input.trim_start and not input.trim_end:
            raise Exception("Must specify preset_id or trim_start/trim_end")

        output_filename = input.output_filename
        if not output_filename:
            stem = Path(source.filename).stem
            ext = preset_snapshot.get("container", "mp4") if preset else "mp4"
            output_filename = f"{stem}_output.{ext}"

        job = create_job(
            source_asset_id=source.id,
            preset_id=preset.id if preset else None,
            preset_snapshot=preset_snapshot,
            trim_start=input.trim_start,
            trim_end=input.trim_end,
            output_filename=output_filename,
            output_path=output_filename,
            priority=input.priority or 0,
        )

        executor_mode = os.environ.get("MPR_EXECUTOR", "local")
        if executor_mode in ("lambda", "gcp"):
            from core.task.executor import get_executor

            get_executor().run(
                job_id=str(job.id),
                source_path=source.file_path,
                output_path=output_filename,
                preset=preset_snapshot or None,
                trim_start=input.trim_start,
                trim_end=input.trim_end,
                duration=source.duration,
            )
        else:
            from core.task.tasks import run_transcode_job

            result = run_transcode_job.delay(
                job_id=str(job.id),
                source_key=source.file_path,
                output_key=output_filename,
                preset=preset_snapshot or None,
                trim_start=input.trim_start,
                trim_end=input.trim_end,
                duration=source.duration,
            )
            job.celery_task_id = result.id
            job.save(update_fields=["celery_task_id"])

        return job

    @strawberry.mutation
    def cancel_job(self, info: Info, id: UUID) -> TranscodeJobType:
        from core.db import get_job, update_job

        try:
            job = get_job(id)
        except Exception:
            raise Exception("Job not found")

        if job.status not in ("pending", "processing"):
            raise Exception(f"Cannot cancel job with status: {job.status}")

        return update_job(job, status="cancelled")

    @strawberry.mutation
    def retry_job(self, info: Info, id: UUID) -> TranscodeJobType:
        from core.db import get_job, update_job

        try:
            job = get_job(id)
        except Exception:
            raise Exception("Job not found")

        if job.status != "failed":
            raise Exception("Only failed jobs can be retried")

        return update_job(job, status="pending", progress=0, error_message=None)

    @strawberry.mutation
    def update_asset(self, info: Info, id: UUID, input: UpdateAssetInput) -> MediaAssetType:
        from core.db import get_asset, update_asset

        try:
            asset = get_asset(id)
        except Exception:
            raise Exception("Asset not found")

        fields = {}
        if input.comments is not None:
            fields["comments"] = input.comments
        if input.tags is not None:
            fields["tags"] = input.tags

        if fields:
            asset = update_asset(asset, **fields)

        return asset

    @strawberry.mutation
    def delete_asset(self, info: Info, id: UUID) -> DeleteResultType:
        from core.db import delete_asset, get_asset

        try:
            asset = get_asset(id)
            delete_asset(asset)
            return DeleteResultType(ok=True)
        except Exception:
            raise Exception("Asset not found")


# ---------------------------------------------------------------------------
# Schema
# ---------------------------------------------------------------------------

schema = strawberry.Schema(
    query=Query,
    mutation=Mutation,
    config=StrawberryConfig(auto_camel_case=False),
)
@@ -1,98 +1,58 @@
 """
 MPR FastAPI Application

 Serves GraphQL API and Lambda callback endpoint.
 """

 import os
 import sys
-from typing import Optional
-from uuid import UUID

 # Add project root to path
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

-# Initialize Django before importing models
-os.environ.setdefault("DJANGO_SETTINGS_MODULE", "admin.mpr.settings")
-
-import django
-
-django.setup()
+from contextlib import asynccontextmanager

-from fastapi import FastAPI, Header, HTTPException
+from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from strawberry.fastapi import GraphQLRouter

 from core.api.graphql import schema as graphql_schema
+from core.api.detect import router as detect_router


+@asynccontextmanager
+async def lifespan(app):
+    from core.db.connection import create_tables
+    from core.db.seed import seed_profiles
+    create_tables()
+    seed_profiles()
+    yield

-CALLBACK_API_KEY = os.environ.get("CALLBACK_API_KEY", "")

 app = FastAPI(
     title="MPR API",
     description="Media Processor — GraphQL API",
     version="0.1.0",
     docs_url="/docs",
     redoc_url="/redoc",
+    lifespan=lifespan,
 )

 # CORS
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["http://mpr.local.ar", "http://localhost:5173"],
+    allow_origins=["http://mpr.local.ar", "http://k8s.mpr.local.ar", "http://localhost:5173"],
     allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
 )

 # GraphQL
 graphql_router = GraphQLRouter(schema=graphql_schema, graphql_ide="graphiql")
 app.include_router(graphql_router, prefix="/graphql")
+# Detection API (sources, run, SSE, replay, config)
+app.include_router(detect_router)


 @app.get("/health")
 def health():
     return {"status": "ok"}


 @app.get("/")
 def root():
     """API root."""
     return {
         "name": "MPR API",
         "version": "0.1.0",
         "graphql": "/graphql",
     }


-@app.post("/api/jobs/{job_id}/callback")
-def job_callback(
-    job_id: UUID,
-    payload: dict,
-    x_api_key: Optional[str] = Header(None),
-):
-    """
-    Callback endpoint for Lambda to report job completion.
-    Protected by API key.
-    """
-    if CALLBACK_API_KEY and x_api_key != CALLBACK_API_KEY:
-        raise HTTPException(status_code=403, detail="Invalid API key")
-
-    from django.utils import timezone
-
-    from core.db import get_job, update_job
-
-    try:
-        job = get_job(job_id)
-    except Exception:
-        raise HTTPException(status_code=404, detail="Job not found")
-
-    status = payload.get("status", "failed")
-    fields = {
-        "status": status,
-        "progress": 100.0 if status == "completed" else job.progress,
-    }
-
-    if payload.get("error"):
-        fields["error_message"] = payload["error"]
-
-    if status in ("completed", "failed"):
-        fields["completed_at"] = timezone.now()
-
-    update_job(job, **fields)
-
-    return {"ok": True}
@@ -1,158 +0,0 @@
"""
Strawberry Types - GENERATED FILE

Do not edit directly. Regenerate using modelgen.
"""

import strawberry
from enum import Enum
from typing import List, Optional
from uuid import UUID
from datetime import datetime
from strawberry.scalars import JSON


@strawberry.enum
class AssetStatus(Enum):
    PENDING = "pending"
    READY = "ready"
    ERROR = "error"


@strawberry.enum
class JobStatus(Enum):
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


@strawberry.type
class MediaAssetType:
    """A video/audio file registered in the system."""

    id: Optional[UUID] = None
    filename: Optional[str] = None
    file_path: Optional[str] = None
    status: Optional[str] = None
    error_message: Optional[str] = None
    file_size: Optional[int] = None
    duration: Optional[float] = None
    video_codec: Optional[str] = None
    audio_codec: Optional[str] = None
    width: Optional[int] = None
    height: Optional[int] = None
    framerate: Optional[float] = None
    bitrate: Optional[int] = None
    properties: Optional[JSON] = None
    comments: Optional[str] = None
    tags: Optional[List[str]] = None
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None


@strawberry.type
class TranscodePresetType:
    """A reusable transcoding configuration (like Handbrake presets)."""

    id: Optional[UUID] = None
    name: Optional[str] = None
    description: Optional[str] = None
    is_builtin: Optional[bool] = None
    container: Optional[str] = None
    video_codec: Optional[str] = None
    video_bitrate: Optional[str] = None
    video_crf: Optional[int] = None
    video_preset: Optional[str] = None
    resolution: Optional[str] = None
    framerate: Optional[float] = None
    audio_codec: Optional[str] = None
    audio_bitrate: Optional[str] = None
    audio_channels: Optional[int] = None
    audio_samplerate: Optional[int] = None
    extra_args: Optional[List[str]] = None
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None


@strawberry.type
class TranscodeJobType:
    """A transcoding or trimming job in the queue."""

    id: Optional[UUID] = None
    source_asset_id: Optional[UUID] = None
    preset_id: Optional[UUID] = None
    preset_snapshot: Optional[JSON] = None
    trim_start: Optional[float] = None
    trim_end: Optional[float] = None
    output_filename: Optional[str] = None
    output_path: Optional[str] = None
    output_asset_id: Optional[UUID] = None
    status: Optional[str] = None
    progress: Optional[float] = None
    current_frame: Optional[int] = None
    current_time: Optional[float] = None
    speed: Optional[str] = None
    error_message: Optional[str] = None
    celery_task_id: Optional[str] = None
    execution_arn: Optional[str] = None
    priority: Optional[int] = None
    created_at: Optional[datetime] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None


@strawberry.input
class CreateJobInput:
    """Request body for creating a transcode/trim job."""

    source_asset_id: UUID
    preset_id: Optional[UUID] = None
    trim_start: Optional[float] = None
    trim_end: Optional[float] = None
    output_filename: Optional[str] = None
    priority: int = 0


@strawberry.input
class UpdateAssetInput:
    """Request body for updating asset metadata."""

    comments: Optional[str] = None
    tags: Optional[List[str]] = None


@strawberry.type
class SystemStatusType:
    """System status response."""

    status: Optional[str] = None
    version: Optional[str] = None


@strawberry.type
class ScanResultType:
    """Result of scanning the media input bucket."""

    found: Optional[int] = None
    registered: Optional[int] = None
    skipped: Optional[int] = None
    files: Optional[List[str]] = None


@strawberry.type
class DeleteResultType:
    """Result of a delete operation."""

    ok: Optional[bool] = None


@strawberry.type
class WorkerStatusType:
    """Worker health and capabilities."""

    available: Optional[bool] = None
    active_jobs: Optional[int] = None
    supported_codecs: Optional[List[str]] = None
    gpu_available: Optional[bool] = None
@@ -1,19 +1,24 @@
-from .assets import (
-    create_asset,
-    delete_asset,
-    get_asset,
-    get_asset_filenames,
-    list_assets,
-    update_asset,
-)
-from .jobs import (
-    create_job,
-    get_job,
-    list_jobs,
-    update_job,
-    update_job_fields,
-)
-from .presets import (
-    get_preset,
-    list_presets,
-)
+"""
+Database layer.
+
+models.py — SQLModel table definitions (generated by modelgen, don't edit)
+domain files — session-first query functions for non-trivial operations
+
+Basic CRUD (create, get, update, delete) goes directly through the session:
+    session.add(Job(...))
+    session.get(Job, id)
+    session.get(Job, id); setattr(...); session.commit()
+    session.delete(obj); session.commit()
+"""
+
+from .connection import get_session, create_tables
+
+from .models import MediaAsset, Job, Timeline, Checkpoint, Brand
+
+from .assets import list_assets, get_asset_filenames
+from .job import list_jobs
+from .checkpoint import (
+    get_latest_checkpoint, get_root_checkpoint,
+    list_checkpoints, list_scenarios,
+)
+from .brand import get_or_create_brand, find_brand_by_text, list_brands, record_airing
@@ -1,48 +1,23 @@
-"""Database operations for MediaAsset."""
+"""MediaAsset queries."""
+
+from __future__ import annotations

 from typing import Optional
 from uuid import UUID

+from sqlmodel import Session, select
+
+from .models import MediaAsset

-def list_assets(status: Optional[str] = None, search: Optional[str] = None):
-    from admin.mpr.media_assets.models import MediaAsset
-
-    qs = MediaAsset.objects.all()
+def list_assets(session: Session, status: Optional[str] = None, search: Optional[str] = None) -> list[MediaAsset]:
+    stmt = select(MediaAsset)
     if status:
-        qs = qs.filter(status=status)
+        stmt = stmt.where(MediaAsset.status == status)
     if search:
-        qs = qs.filter(filename__icontains=search)
-    return list(qs)
+        stmt = stmt.where(MediaAsset.filename.ilike(f"%{search}%"))
+    return list(session.exec(stmt).all())


-def get_asset(id: UUID):
-    from admin.mpr.media_assets.models import MediaAsset
-
-    return MediaAsset.objects.get(id=id)
-
-
-def get_asset_filenames() -> set[str]:
-    from admin.mpr.media_assets.models import MediaAsset
-
-    return set(MediaAsset.objects.values_list("filename", flat=True))
-
-
-def create_asset(*, filename: str, file_path: str, file_size: int):
-    from admin.mpr.media_assets.models import MediaAsset
-
-    return MediaAsset.objects.create(
-        filename=filename,
-        file_path=file_path,
-        file_size=file_size,
-    )
-
-
-def update_asset(asset, **fields):
-    for key, value in fields.items():
-        setattr(asset, key, value)
-    asset.save(update_fields=list(fields.keys()))
-    return asset
-
-
-def delete_asset(asset):
-    asset.delete()
+def get_asset_filenames(session: Session) -> set[str]:
+    return set(session.exec(select(MediaAsset.filename)).all())
61 core/db/brand.py Normal file
@@ -0,0 +1,61 @@
"""Brand queries."""

from __future__ import annotations

from typing import Optional
from uuid import UUID

from sqlmodel import Session, select

from .models import Brand


def get_or_create_brand(session: Session, canonical_name: str,
                        aliases: Optional[list[str]] = None,
                        source: str = "ocr") -> tuple[Brand, bool]:
    normalized = canonical_name.strip()
    brand = session.exec(select(Brand).where(Brand.canonical_name.ilike(normalized))).first()
    if brand:
        return brand, False

    brand = Brand(canonical_name=normalized, aliases=aliases or [], source=source)
    session.add(brand)
    session.flush()
    return brand, True


def find_brand_by_text(session: Session, text: str) -> Brand | None:
    normalized = text.strip().lower()
    brand = session.exec(select(Brand).where(Brand.canonical_name.ilike(normalized))).first()
    if brand:
        return brand

    for b in session.exec(select(Brand)).all():
        if normalized in [a.lower() for a in (b.aliases or [])]:
            return b
    return None


def list_brands(session: Session) -> list[Brand]:
    return list(session.exec(select(Brand).order_by(Brand.canonical_name)).all())


def record_airing(session: Session, brand_id: UUID, timeline_id: UUID,
                  frame_start: int, frame_end: int,
                  confidence: float, source: str = "ocr") -> Brand:
    brand = session.get(Brand, brand_id)
    if not brand:
        raise ValueError(f"Brand not found: {brand_id}")

    airing = {
        "timeline_id": str(timeline_id),
        "frame_start": frame_start,
        "frame_end": frame_end,
        "confidence": confidence,
        "source": source,
    }
    airings = list(brand.airings or [])
    airings.append(airing)
    brand.airings = airings
    brand.total_airings = len(airings)
    return brand
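Reviewer note: `find_brand_by_text` falls back to a linear scan over every brand's aliases, which is fine at current scale but worth revisiting if the brand table grows. Typical usage from a matching stage might look like this sketch; the session, OCR text, and timeline id are invented.

# Hypothetical match_brands usage; the OCR text and timeline id are made up.
from uuid import uuid4
from core.db.connection import get_session
from core.db.brand import get_or_create_brand, find_brand_by_text, record_airing

with get_session() as session:
    brand = find_brand_by_text(session, "heineken")
    if brand is None:
        brand, created = get_or_create_brand(
            session, "Heineken", aliases=["heineken", "HEINEKEN 0.0"], source="ocr"
        )
    record_airing(session, brand.id, timeline_id=uuid4(),
                  frame_start=120, frame_end=168, confidence=0.91, source="ocr")
    session.commit()  # get_or_create_brand only flushes; commit persists everything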
43 core/db/checkpoint.py Normal file
@@ -0,0 +1,43 @@
"""Checkpoint queries."""

from __future__ import annotations

from uuid import UUID

from sqlmodel import Session, select

from .models import Checkpoint


def get_latest_checkpoint(session: Session, timeline_id: UUID, parent_id: UUID | None = None) -> Checkpoint | None:
    stmt = select(Checkpoint).where(Checkpoint.timeline_id == timeline_id)
    if parent_id is not None:
        stmt = stmt.where(Checkpoint.parent_id == parent_id)
    stmt = stmt.order_by(Checkpoint.created_at.desc())
    return session.exec(stmt).first()


def get_root_checkpoint(session: Session, timeline_id: UUID) -> Checkpoint | None:
    stmt = select(Checkpoint).where(
        Checkpoint.timeline_id == timeline_id,
        Checkpoint.parent_id == None,
    )
    return session.exec(stmt).first()


def list_checkpoints(session: Session, timeline_id: UUID) -> list[Checkpoint]:
    stmt = (
        select(Checkpoint)
        .where(Checkpoint.timeline_id == timeline_id)
        .order_by(Checkpoint.created_at)
    )
    return list(session.exec(stmt).all())


def list_scenarios(session: Session) -> list[Checkpoint]:
    stmt = (
        select(Checkpoint)
        .where(Checkpoint.is_scenario == True)
        .order_by(Checkpoint.created_at.desc())
    )
    return list(session.exec(stmt).all())
34 core/db/connection.py Normal file
@@ -0,0 +1,34 @@
"""
Database engine and session — SQLModel/SQLAlchemy, no Django.

Reads DATABASE_URL from the environment.
"""

from __future__ import annotations

import os

from sqlalchemy import create_engine
from sqlmodel import Session

DATABASE_URL = os.environ.get("DATABASE_URL", "postgresql://mpr:mpr@localhost:5432/mpr")

_engine = None


def get_engine():
    global _engine
    if _engine is None:
        _engine = create_engine(DATABASE_URL, pool_size=5, max_overflow=10)
    return _engine


def get_session() -> Session:
    return Session(get_engine())


def create_tables():
    """Create all SQLModel tables."""
    from sqlmodel import SQLModel
    from . import models  # noqa — registers all table classes

    SQLModel.metadata.create_all(get_engine())
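Reviewer note: `Session` is itself a context manager, which is why callers throughout this diff write `with get_session() as session:`. A minimal usage sketch, assuming a reachable Postgres at DATABASE_URL:

# Assumes DATABASE_URL points at a live Postgres; nothing here is mocked.
from sqlmodel import select
from core.db.connection import create_tables, get_session
from core.db.models import Timeline

create_tables()  # idempotent: create_all skips tables that already exist
with get_session() as session:
    session.add(Timeline(name="demo", profile_name="soccer_broadcast"))
    session.commit()
    print(session.exec(select(Timeline)).all())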
143 core/db/fixtures/soccer_broadcast.json Normal file
@@ -0,0 +1,143 @@
{
  "name": "soccer_broadcast",
  "pipeline": {
    "name": "soccer_broadcast",
    "profile_name": "soccer_broadcast",
    "stages": [
      {"name": "extract_frames", "branch": "trunk"},
      {"name": "filter_scenes", "branch": "trunk"},
      {"name": "field_segmentation", "branch": "trunk"},
      {"name": "detect_edges", "branch": "hoarding"},
      {"name": "detect_objects", "branch": "objects"},
      {"name": "preprocess"},
      {"name": "run_ocr"},
      {"name": "match_brands"},
      {"name": "escalate_vlm"},
      {"name": "escalate_cloud"},
      {"name": "compile_report"}
    ],
    "edges": [
      {"source": "extract_frames", "target": "filter_scenes"},
      {"source": "filter_scenes", "target": "field_segmentation"},
      {"source": "field_segmentation", "target": "detect_edges", "transform": {"invert_mask": true}},
      {"source": "field_segmentation", "target": "detect_objects"},
      {"source": "detect_edges", "target": "preprocess"},
      {"source": "detect_objects", "target": "preprocess"},
      {"source": "preprocess", "target": "run_ocr"},
      {"source": "run_ocr", "target": "match_brands"},
      {"source": "match_brands", "target": "escalate_vlm"},
      {"source": "escalate_vlm", "target": "escalate_cloud"},
      {"source": "escalate_cloud", "target": "compile_report"}
    ]
  },
  "configs": {
    "extract_frames": {"fps": 2.0, "max_frames": 500},
    "filter_scenes": {"hamming_threshold": 8, "enabled": true},
    "field_segmentation": {
      "enabled": true,
      "hue_low": 30, "hue_high": 85,
      "sat_low": 30, "sat_high": 255,
      "val_low": 30, "val_high": 255,
      "morph_kernel": 15,
      "min_area_ratio": 0.05
    },
    "detect_edges": {
      "enabled": true,
      "edge_canny_low": 50, "edge_canny_high": 150,
      "edge_hough_threshold": 80,
      "edge_hough_min_length": 100,
      "edge_hough_max_gap": 10,
      "edge_pair_max_distance": 200,
      "edge_pair_min_distance": 15
    },
    "detect_objects": {"model_name": "yolov8n.pt", "confidence_threshold": 0.3, "target_classes": []},
    "run_ocr": {"languages": ["en", "es"], "min_confidence": 0.5},
    "match_brands": {"fuzzy_threshold": 75},
    "escalate_vlm": {
      "vlm_prompt_template": "Identify the brand or sponsor visible in this cropped region from a soccer broadcast.{hint}{text} Respond with: brand, confidence (0-1), reasoning."
    }
  }
}
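Reviewer note: the pipeline above is a DAG that fans out from field_segmentation into the hoarding and objects branches and rejoins at preprocess. A standalone way to sanity-check a fixture like this (stdlib only, not part of the diff) is a topological sort:

# Standalone sanity check for a profile fixture's pipeline DAG (stdlib only).
# The fixture path is the one added in this diff; nothing else is assumed.
import json
from graphlib import TopologicalSorter  # Python 3.9+

def check_pipeline(path: str) -> list[str]:
    pipeline = json.loads(open(path).read())["pipeline"]
    stages = {s["name"] for s in pipeline["stages"]}
    preds: dict[str, set[str]] = {name: set() for name in stages}
    for edge in pipeline["edges"]:
        assert edge["source"] in stages and edge["target"] in stages, edge
        preds[edge["target"]].add(edge["source"])
    # static_order() raises CycleError if an edit ever introduces a cycle.
    return list(TopologicalSorter(preds).static_order())

print(check_pipeline("core/db/fixtures/soccer_broadcast.json"))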
80 core/db/job.py Normal file
@@ -0,0 +1,80 @@
"""Job queries."""

from __future__ import annotations

from datetime import datetime
from typing import Optional
from uuid import UUID

from sqlmodel import Session, select

from .models import Job


def create_job(
    session: Session,
    source_asset_id: UUID,
    video_path: str,
    timeline_id: UUID,
    profile_name: str = "soccer_broadcast",
    run_type: str = "initial",
    parent_id: UUID | None = None,
    config_overrides: dict | None = None,
) -> Job:
    job = Job(
        source_asset_id=source_asset_id,
        video_path=video_path,
        timeline_id=timeline_id,
        profile_name=profile_name,
        run_type=run_type,
        parent_id=parent_id,
        config_overrides=config_overrides or {},
        status="pending",
    )
    session.add(job)
    session.commit()
    session.refresh(job)
    return job


def update_job_status(
    session: Session,
    job_id: UUID,
    status: str,
    current_stage: str | None = None,
    error_message: str | None = None,
):
    job = session.get(Job, job_id)
    if not job:
        return
    job.status = status
    if current_stage is not None:
        job.current_stage = current_stage
    if error_message is not None:
        job.error_message = error_message
    if status == "running" and not job.started_at:
        job.started_at = datetime.utcnow()
    if status in ("completed", "failed", "cancelled"):
        job.completed_at = datetime.utcnow()
    session.commit()


def get_job(session: Session, job_id: UUID) -> Job | None:
    return session.get(Job, job_id)


def list_jobs(
    session: Session,
    timeline_id: UUID | None = None,
    parent_id: UUID | None = None,
    status: str | None = None,
) -> list[Job]:
    stmt = select(Job)
    if timeline_id:
        stmt = stmt.where(Job.timeline_id == timeline_id)
    if parent_id:
        stmt = stmt.where(Job.parent_id == parent_id)
    if status:
        stmt = stmt.where(Job.status == status)
    stmt = stmt.order_by(Job.created_at.desc())
    return list(session.exec(stmt).all())
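Reviewer note: the intended call pattern seems to be create, then status updates as stages progress, then a terminal status that stamps completed_at. A sketch with placeholder UUIDs:

# Hypothetical job lifecycle using core.db.job; the UUIDs are placeholders.
from uuid import uuid4
from core.db.connection import get_session
from core.db.job import create_job, update_job_status, get_job

with get_session() as session:
    job = create_job(
        session,
        source_asset_id=uuid4(),   # placeholder: normally a real MediaAsset id
        video_path="chunks/demo/chunk_000.mp4",
        timeline_id=uuid4(),       # placeholder: normally a real Timeline id
    )
    update_job_status(session, job.id, "running", current_stage="extract_frames")
    update_job_status(session, job.id, "completed")  # sets completed_at
    print(get_job(session, job.id).status)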
@@ -1,40 +0,0 @@
"""Database operations for TranscodeJob."""

from typing import Optional
from uuid import UUID


def list_jobs(status: Optional[str] = None, source_asset_id: Optional[UUID] = None):
    from admin.mpr.media_assets.models import TranscodeJob

    qs = TranscodeJob.objects.all()
    if status:
        qs = qs.filter(status=status)
    if source_asset_id:
        qs = qs.filter(source_asset_id=source_asset_id)
    return list(qs)


def get_job(id: UUID):
    from admin.mpr.media_assets.models import TranscodeJob

    return TranscodeJob.objects.get(id=id)


def create_job(**fields):
    from admin.mpr.media_assets.models import TranscodeJob

    return TranscodeJob.objects.create(**fields)


def update_job(job, **fields):
    for key, value in fields.items():
        setattr(job, key, value)
    job.save(update_fields=list(fields.keys()))
    return job


def update_job_fields(job_id, **fields):
    from admin.mpr.media_assets.models import TranscodeJob

    TranscodeJob.objects.filter(id=job_id).update(**fields)
179 core/db/models.py Normal file
@@ -0,0 +1,179 @@
"""
SQLModel Table Models - GENERATED FILE

Do not edit directly. Regenerate using modelgen.
"""

from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID, uuid4

from sqlmodel import SQLModel, Field, Column
from sqlalchemy import JSON


class AssetStatus(str, Enum):
    PENDING = "pending"
    READY = "ready"
    ERROR = "error"


class JobStatus(str, Enum):
    PENDING = "pending"
    RUNNING = "running"
    PAUSED = "paused"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class RunType(str, Enum):
    INITIAL = "initial"
    REPLAY = "replay"
    RETRY = "retry"


class BrandSource(str, Enum):
    OCR = "ocr"
    VLM = "local_vlm"
    CLOUD = "cloud_llm"
    MANUAL = "manual"


class SourceType(str, Enum):
    CHUNK_JOB = "chunk_job"
    UPLOAD = "upload"
    DEVICE = "device"
    STREAM = "stream"


class MediaAsset(SQLModel, table=True):
    """A video/audio file registered in the system."""
    __tablename__ = "media_asset"

    id: UUID = Field(default_factory=uuid4, primary_key=True)
    filename: str
    file_path: str
    status: AssetStatus = "pending"
    error_message: Optional[str] = None
    file_size: Optional[int] = None
    duration: Optional[float] = None
    video_codec: Optional[str] = None
    audio_codec: Optional[str] = None
    width: Optional[int] = None
    height: Optional[int] = None
    framerate: Optional[float] = None
    bitrate: Optional[int] = None
    properties: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
    comments: str = ""
    tags: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
    created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
    updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow)


class TranscodePreset(SQLModel, table=True):
    """A reusable transcoding configuration (like Handbrake presets)."""
    __tablename__ = "transcode_preset"

    id: UUID = Field(default_factory=uuid4, primary_key=True)
    name: str
    description: str = ""
    is_builtin: bool = False
    container: str = "mp4"
    video_codec: str = "libx264"
    video_bitrate: Optional[str] = None
    video_crf: Optional[int] = None
    video_preset: Optional[str] = None
    resolution: Optional[str] = None
    framerate: Optional[float] = None
    audio_codec: str = "aac"
    audio_bitrate: Optional[str] = None
    audio_channels: Optional[int] = None
    audio_samplerate: Optional[int] = None
    extra_args: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
    created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
    updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow)


class Job(SQLModel, table=True):
    """A pipeline job."""
    __tablename__ = "job"

    id: UUID = Field(default_factory=uuid4, primary_key=True)
    source_asset_id: UUID = Field(index=True)
    video_path: str
    profile_name: str = "soccer_broadcast"
    timeline_id: Optional[UUID] = None
    parent_id: Optional[UUID] = None
    run_type: RunType = "initial"
    config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
    status: JobStatus = "pending"
    current_stage: Optional[str] = None
    progress: float = 0.0
    error_message: Optional[str] = None
    total_detections: int = 0
    brands_found: int = 0
    cloud_llm_calls: int = 0
    estimated_cost_usd: float = 0.0
    priority: int = 0
    created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None


class Timeline(SQLModel, table=True):
    """A user-created selection of source material."""
    __tablename__ = "timeline"

    id: UUID = Field(default_factory=uuid4, primary_key=True)
    name: str = ""
    source_asset_id: Optional[UUID] = Field(default=None, index=True)
    chunk_paths: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
    profile_name: str = ""
    status: str = "created"
    fps: float = 2.0
    frame_count: int = 0
    source_ephemeral: bool = False
    created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)


class Checkpoint(SQLModel, table=True):
    """A snapshot of pipeline state on a timeline."""
    __tablename__ = "checkpoint"

    id: UUID = Field(default_factory=uuid4, primary_key=True)
    timeline_id: UUID
    job_id: Optional[UUID] = Field(default=None, index=True)
    parent_id: Optional[UUID] = None
    stage_name: str = ""
    config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
    stats: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
    is_scenario: bool = False
    scenario_label: str = ""
    created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)


class StageOutput(SQLModel, table=True):
    """Output of a single stage within a job."""
    __tablename__ = "stage_output"

    id: UUID = Field(default_factory=uuid4, primary_key=True)
    job_id: UUID = Field(index=True)
    timeline_id: UUID
    stage_name: str
    checkpoint_id: Optional[UUID] = None
    output: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
    created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)


class Brand(SQLModel, table=True):
    """A brand discovered or registered in the system."""
    __tablename__ = "brand"

    id: UUID = Field(default_factory=uuid4, primary_key=True)
    canonical_name: str = Field(index=True)
    aliases: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
    source: BrandSource = "ocr"
    confirmed: bool = False
    airings: List[Dict[str, Any]] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
    total_airings: int = 0
    created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
    updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow)


class Profile(SQLModel, table=True):
    """A content type profile."""
    __tablename__ = "profile"

    id: UUID = Field(default_factory=uuid4, primary_key=True)
    name: str
    pipeline: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
    configs: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
@@ -1,15 +0,0 @@
"""Database operations for TranscodePreset."""

from uuid import UUID


def list_presets():
    from admin.mpr.media_assets.models import TranscodePreset

    return list(TranscodePreset.objects.all())


def get_preset(id: UUID):
    from admin.mpr.media_assets.models import TranscodePreset

    return TranscodePreset.objects.get(id=id)
43 core/db/seed.py Normal file
@@ -0,0 +1,43 @@
"""
Seed data — insert initial profile rows if they don't exist.

Called on startup after create_tables().
"""

import json
import logging
from pathlib import Path

logger = logging.getLogger(__name__)

SEED_DIR = Path(__file__).parent / "fixtures"


def seed_profiles():
    """Insert seed profiles from JSON fixtures if not already present."""
    from .connection import get_session
    from .models import Profile

    fixtures = list(SEED_DIR.glob("*.json"))
    if not fixtures:
        return

    with get_session() as session:
        for f in fixtures:
            data = json.loads(f.read_text())
            name = data["name"]

            existing = session.query(Profile).filter(Profile.name == name).first()
            if existing:
                logger.debug("Profile %s already exists, skipping seed", name)
                continue

            profile = Profile(
                name=name,
                pipeline=data.get("pipeline", {}),
                configs=data.get("configs", {}),
            )
            session.add(profile)
            logger.info("Seeded profile: %s", name)

        session.commit()
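Reviewer note: seeding is keyed on the profile name, so repeated startups are safe. A quick check, assuming a reachable DATABASE_URL:

# Both calls are safe to repeat; the second seed run skips existing profiles.
from core.db.connection import create_tables
from core.db.seed import seed_profiles

create_tables()
seed_profiles()  # first run inserts soccer_broadcast
seed_profiles()  # second run logs "already exists" and does nothing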
0 core/detect/__init__.py Normal file
31 core/detect/checkpoint/__init__.py Normal file
@@ -0,0 +1,31 @@
"""
Checkpoint system — Timeline + Checkpoint tree + StageOutput.

detect/checkpoint/
    frames.py        — per-timeline frame cache (blob storage)
    storage.py       — Timeline, Checkpoint, StageOutput persistence
    replay.py        — replay from checkpoint (TODO: rework in 5d)
    runner_bridge.py — checkpoint hook for PipelineRunner
"""

from .storage import (
    create_timeline,
    get_timeline,
    update_timeline_status,
    save_checkpoint,
    get_checkpoints_for_job,
    get_checkpoints_for_timeline,
    save_stage_output,
    load_stage_output,
    load_stage_outputs_for_job,
    load_stage_outputs_for_timeline,
)
from .frames import (
    cache_exists,
    cache_frames,
    load_cached_frames,
    load_cached_frames_b64,
    clear_cache,
    frames_to_b64,
)
from .runner_bridge import checkpoint_after_stage, reset_checkpoint_state, get_latest_checkpoint
281
core/detect/checkpoint/frames.py
Normal file
281
core/detect/checkpoint/frames.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""
|
||||
Frame cache — per-timeline frame storage in blob storage (S3/MinIO).
|
||||
|
||||
Frames are extracted from chunks once, cached as JPEGs at
|
||||
cache/timelines/{timeline_id}/frames/{seq}.jpg in the app's
|
||||
blob storage. Any job on the timeline reads from the cache.
|
||||
Cache is clearable and rebuildable from chunks.
|
||||
|
||||
Uses the same storage backend as the rest of the app, so it
|
||||
works across lambdas, GPU boxes, and local dev.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from core.detect.models import Frame
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BUCKET = os.environ.get("S3_BUCKET", "mpr")
|
||||
CACHE_PREFIX = "cache/timelines"
|
||||
|
||||
|
||||
def _frame_key(timeline_id: str, seq: int) -> str:
|
||||
return f"{CACHE_PREFIX}/{timeline_id}/frames/{seq}.jpg"
|
||||
|
||||
|
||||
def _list_prefix(timeline_id: str) -> str:
|
||||
return f"{CACHE_PREFIX}/{timeline_id}/frames/"
|
||||
|
||||
|
||||
def cache_exists(timeline_id: str) -> bool:
|
||||
"""Check if frame cache exists for a timeline."""
|
||||
from core.storage.s3 import list_objects
|
||||
|
||||
objects = list_objects(BUCKET, _list_prefix(timeline_id))
|
||||
return len(objects) > 0
|
||||
|
||||
|
||||
def cache_frames(timeline_id: str, frames: list[Frame], quality: int = 85) -> int:
|
||||
"""
|
||||
Write frames to blob storage as JPEGs.
|
||||
|
||||
Returns number of frames cached.
|
||||
"""
|
||||
from core.storage.s3 import upload_file
|
||||
|
||||
for frame in frames:
|
||||
key = _frame_key(timeline_id, frame.sequence)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
|
||||
img = Image.fromarray(frame.image)
|
||||
img.save(tmp, format="JPEG", quality=quality)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
upload_file(tmp_path, BUCKET, key)
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
logger.info("Cached %d frames for timeline %s", len(frames), timeline_id)
|
||||
return len(frames)
|
||||
|
||||
|
||||
def load_cached_frames(timeline_id: str) -> list[Frame]:
|
||||
"""
|
||||
Load all cached frames as Frame objects with numpy arrays.
|
||||
|
||||
Returns empty list if cache doesn't exist.
|
||||
"""
|
||||
from core.storage.s3 import list_objects, download_to_temp
|
||||
|
||||
objects = list_objects(BUCKET, _list_prefix(timeline_id))
|
||||
if not objects:
|
||||
return []
|
||||
|
||||
frames = []
|
||||
for obj in objects:
|
||||
key = obj["key"]
|
||||
filename = key.rsplit("/", 1)[-1]
|
||||
if not filename.endswith(".jpg"):
|
||||
continue
|
||||
seq = int(filename.replace(".jpg", ""))
|
||||
|
||||
tmp_path = download_to_temp(BUCKET, key)
|
||||
try:
|
||||
img = Image.open(tmp_path).convert("RGB")
|
||||
image_array = np.array(img)
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
frame = Frame(
|
||||
sequence=seq,
|
||||
chunk_id=0,
|
||||
timestamp=0.0,
|
||||
image=image_array,
|
||||
perceptual_hash="",
|
||||
)
|
||||
frames.append(frame)
|
||||
|
||||
frames.sort(key=lambda f: f.sequence)
|
    return frames


def load_cached_frames_b64(timeline_id: str) -> list[dict]:
    """
    Load cached frames as base64 JPEGs for the UI.

    Returns list of {seq, timestamp, jpeg_b64}.
    """
    from core.storage.s3 import list_objects, download_to_temp

    objects = list_objects(BUCKET, _list_prefix(timeline_id))
    if not objects:
        return []

    result = []
    for obj in objects:
        key = obj["key"]
        filename = key.rsplit("/", 1)[-1]
        if not filename.endswith(".jpg"):
            continue
        seq = int(filename.replace(".jpg", ""))

        tmp_path = download_to_temp(BUCKET, key)
        try:
            with open(tmp_path, "rb") as f:
                jpeg_b64 = base64.b64encode(f.read()).decode()
        finally:
            os.unlink(tmp_path)

        result.append({
            "seq": seq,
            "timestamp": 0.0,  # timestamps are not stored in the cache key, so left at 0.0
            "jpeg_b64": jpeg_b64,
        })

    result.sort(key=lambda f: f["seq"])
    return result


# ---------------------------------------------------------------------------
# Debug overlay storage — per job/stage/frame
# ---------------------------------------------------------------------------


def _overlay_prefix(timeline_id: str, job_id: str, stage: str) -> str:
    return f"{CACHE_PREFIX}/{timeline_id}/overlays/{job_id}/{stage}/"


def _overlay_key(timeline_id: str, job_id: str, stage: str, seq: int, name: str) -> str:
    return f"{CACHE_PREFIX}/{timeline_id}/overlays/{job_id}/{stage}/{seq}_{name}.png"


def save_overlays(
    timeline_id: str,
    job_id: str,
    stage: str,
    seq: int,
    overlays: dict[str, str],
):
    """
    Save debug overlay images (base64 PNG) to blob storage.

    overlays: {overlay_key: base64_png_string}
    e.g. {"edge_overlay_b64": "iVBOR...", "lines_overlay_b64": "iVBOR..."}
    """
    from core.storage.s3 import upload_file
    import tempfile

    for name, b64_data in overlays.items():
        key = _overlay_key(timeline_id, job_id, stage, seq, name)
        raw = base64.b64decode(b64_data)

        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            tmp.write(raw)
            tmp_path = tmp.name

        try:
            upload_file(tmp_path, BUCKET, key)
        finally:
            os.unlink(tmp_path)

    logger.info("Saved %d overlays for timeline %s job %s stage %s frame %d",
                len(overlays), timeline_id, job_id, stage, seq)


def load_overlays(
    timeline_id: str,
    job_id: str,
    stage: str,
    seq: int,
) -> dict[str, str] | None:
    """
    Load debug overlay images from blob storage as base64 strings.

    Returns {overlay_key: base64_png_string} or None if no overlays cached.
    """
    from core.storage.s3 import list_objects, download_to_temp

    prefix = _overlay_prefix(timeline_id, job_id, stage)
    seq_prefix = f"{seq}_"
    objects = list_objects(BUCKET, prefix)

    overlays = {}
    for obj in objects:
        filename = obj["key"].rsplit("/", 1)[-1]
        if not filename.startswith(seq_prefix):
            continue
        name = filename[len(seq_prefix):].replace(".png", "")

        tmp_path = download_to_temp(BUCKET, obj["key"])
        try:
            with open(tmp_path, "rb") as f:
                overlays[name] = base64.b64encode(f.read()).decode()
        finally:
            os.unlink(tmp_path)

    return overlays if overlays else None


def list_overlay_frames(
    timeline_id: str,
    job_id: str,
    stage: str,
) -> list[int]:
    """List frame sequences that have cached overlays."""
    from core.storage.s3 import list_objects

    prefix = _overlay_prefix(timeline_id, job_id, stage)
    objects = list_objects(BUCKET, prefix)

    seqs = set()
    for obj in objects:
        filename = obj["key"].rsplit("/", 1)[-1]
        seq_str = filename.split("_")[0]
        try:
            seqs.add(int(seq_str))
        except ValueError:
            continue

    return sorted(seqs)


def clear_cache(timeline_id: str):
    """Delete the frame cache for a timeline."""
    from core.storage.s3 import delete_objects

    prefix = _list_prefix(timeline_id)
    delete_objects(BUCKET, prefix)
    logger.info("Cleared frame cache for timeline %s", timeline_id)


def frames_to_b64(frames: list[Frame], quality: int = 75) -> list[dict]:
    """
    Convert in-memory Frame objects to base64 JPEG dicts.

    For API responses when frames are already in memory.
    """
    result = []
    for frame in frames:
        buf = io.BytesIO()
        img = Image.fromarray(frame.image)
        img.save(buf, format="JPEG", quality=quality)
        jpeg_b64 = base64.b64encode(buf.getvalue()).decode()

        result.append({
            "seq": frame.sequence,
            "timestamp": frame.timestamp,
            "jpeg_b64": jpeg_b64,
        })

    result.sort(key=lambda f: f["seq"])
    return result
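Taken together, these helpers give a read/write/clear surface over the S3 frame cache. A minimal sketch of how an API layer might expose them; the route paths and FastAPI wiring here are illustrative assumptions, not part of this change:

```python
# Illustrative wiring only: the routes and app object are assumptions.
# The imported helpers are the ones defined above.
from fastapi import FastAPI

from core.detect.checkpoint.frames import clear_cache, load_cached_frames_b64

app = FastAPI()


@app.get("/timelines/{timeline_id}/frames")
def get_frames(timeline_id: str) -> list[dict]:
    # Each item: {"seq": int, "timestamp": float, "jpeg_b64": str}.
    # An empty list means no frames are cached for this timeline yet.
    return load_cached_frames_b64(timeline_id)


@app.delete("/timelines/{timeline_id}/frames")
def drop_frame_cache(timeline_id: str) -> dict:
    clear_cache(timeline_id)
    return {"status": "cleared"}
```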
307
core/detect/checkpoint/replay.py
Normal file
@@ -0,0 +1,307 @@
"""
Pipeline replay — re-run from any stage with different config.

Loads stage outputs from DB, frames from timeline cache,
reconstitutes state, and runs from a target stage onward.

Creates a new Job (run_type=REPLAY) for each replay invocation.
"""

from __future__ import annotations

import logging
import os
import uuid

from core.detect import emit
from core.detect.graph import NODES, get_pipeline
from core.detect.graph.runner import PipelineRunner

logger = logging.getLogger(__name__)


def _build_state_for_replay(
    job_id: str,
    up_to_stage: str,
) -> dict:
    """
    Reconstitute pipeline state from a completed job's stage outputs,
    up to (but not including) the target stage.

    Loads frames from timeline cache + stage outputs from DB.
    """
    from .storage import load_stage_outputs_for_job, get_checkpoints_for_job
    from .frames import load_cached_frames
    from core.db.connection import get_session
    from core.db.job import get_job

    # Load the job to get timeline_id and profile
    with get_session() as session:
        job = get_job(session, uuid.UUID(job_id))
        if not job:
            raise ValueError(f"Job not found: {job_id}")

        timeline_id = str(job.timeline_id) if job.timeline_id else ""
        if not timeline_id:
            raise ValueError(f"Job {job_id} has no timeline")

    # Load frames from timeline cache
    frames = load_cached_frames(timeline_id)
    if not frames:
        raise ValueError(f"No cached frames for timeline {timeline_id}. Run the pipeline first.")

    # Load all stage outputs for this job
    all_outputs = load_stage_outputs_for_job(job_id)

    # Build state with envelope + frames
    state = {
        "job_id": job_id,
        "timeline_id": timeline_id,
        "video_path": job.video_path,
        "profile_name": job.profile_name,
        "source_asset_id": str(job.source_asset_id),
        "frames": frames,
        "config_overrides": {},
    }

    # Apply stage outputs in pipeline order, up to the target stage
    target_idx = NODES.index(up_to_stage)
    for stage_name in NODES[:target_idx]:
        output = all_outputs.get(stage_name)
        if output:
            # Stage outputs contain serialized data — merge into state
            # The stage registry's deserialize_fn can reconstitute if needed
            for key, value in output.items():
                state[key] = value

    # Filtered frames: reconstruct from sequence list if present
    filtered_seqs = state.get("filtered_frame_sequences")
    if filtered_seqs:
        seq_set = set(filtered_seqs)
        state["filtered_frames"] = [f for f in frames if f.sequence in seq_set]
    elif "filtered_frames" not in state:
        state["filtered_frames"] = frames

    return state


def replay_from(
    job_id: str,
    start_stage: str,
    config_overrides: dict | None = None,
    checkpoint: bool = True,
) -> dict:
    """
    Replay the pipeline from a specific stage.

    Loads state from the original job's stage outputs up to start_stage,
    applies config overrides, and runs from start_stage onward.

    Creates a new Job (run_type=REPLAY).
    Returns the final state dict.
    """
    if start_stage not in NODES:
        raise ValueError(f"Unknown stage: {start_stage!r}. Options: {NODES}")

    start_idx = NODES.index(start_stage)
    if start_idx == 0:
        raise ValueError("Cannot replay from the first stage — just run the full pipeline")

    logger.info("Replaying job %s from %s", job_id, start_stage)

    state = _build_state_for_replay(job_id, start_stage)

    # Apply config overrides
    if config_overrides:
        state["config_overrides"] = config_overrides

    # Create replay job
    from core.db.connection import get_session
    from core.db.job import create_job, get_job
    with get_session() as session:
        original = get_job(session, uuid.UUID(job_id))
        replay_job = create_job(
            session,
            source_asset_id=original.source_asset_id,
            video_path=original.video_path,
            timeline_id=original.timeline_id,
            profile_name=original.profile_name,
            run_type="replay",
            parent_id=original.id,
            config_overrides=config_overrides,
        )
        replay_job_id = str(replay_job.id)

    # Update state with new job ID
    state["job_id"] = replay_job_id

    # Set run context for SSE events
    emit.set_run_context(
        run_id=replay_job_id,
        parent_job_id=job_id,
        run_type="replay",
    )

    # Run from start_stage onward
    pipeline = get_pipeline(
        checkpoint=checkpoint,
        profile_name=state["profile_name"],
        start_from=start_stage,
    )

    try:
        result = pipeline.invoke(state)
    finally:
        emit.clear_run_context()

    return result


def replay_single_stage(
    job_id: str,
    stage: str,
    frame_refs: list[int] | None = None,
    config_overrides: dict | None = None,
    debug: bool = False,
) -> dict:
    """
    Replay a single stage on specific frames (or all frames from checkpoint).

    Fast path for interactive parameter tuning — runs only the target stage
    function, not the full pipeline tail. Returns the stage output directly.
    """
    if stage not in NODES:
        raise ValueError(f"Unknown stage: {stage!r}. Options: {NODES}")

    stage_idx = NODES.index(stage)
    if stage_idx == 0:
        raise ValueError("Cannot replay the first stage — just run the full pipeline")

    logger.info("Single-stage replay: job %s, stage %s (debug=%s)", job_id, stage, debug)

    state = _build_state_for_replay(job_id, stage)

    # Build profile with overrides
    from core.detect.profile import get_profile, get_stage_config
    profile = get_profile(state.get("profile_name", "soccer_broadcast"))
    if config_overrides:
        merged_configs = dict(profile.get("configs", {}))
        for sname, soverrides in config_overrides.items():
            if sname in merged_configs:
                merged_configs[sname] = {**merged_configs[sname], **soverrides}
            else:
                merged_configs[sname] = soverrides
        profile = {**profile, "configs": merged_configs}

    # Subset frames if requested
    frames = state.get("filtered_frames", state.get("frames", []))
    if frame_refs:
        ref_set = set(frame_refs)
        frames = [f for f in frames if f.sequence in ref_set]

    # Run the specific stage
    if stage == "detect_edges":
        return _replay_detect_edges(state, profile, frames, job_id, debug)
    elif stage == "field_segmentation":
        return _replay_field_segmentation(state, profile, frames, job_id, debug)
    else:
        raise ValueError(
            f"Single-stage replay not yet implemented for {stage!r}. "
            f"Use replay_from() for full pipeline replay."
        )


def _replay_detect_edges(
    state: dict,
    profile,
    frames: list,
    job_id: str,
    debug: bool,
) -> dict:
    """Run edge detection on checkpoint frames, optionally with debug overlays."""
    from core.detect.stages.edge_detector import detect_edge_regions
    from core.detect.profile import get_stage_config
    from core.detect.stages.models import RegionAnalysisConfig

    config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
    inference_url = os.environ.get("INFERENCE_URL")
    field_masks = state.get("field_masks", {})

    result = detect_edge_regions(
        frames=frames,
        config=config,
        inference_url=inference_url,
        job_id=job_id,
        field_masks=field_masks,
    )
    output = {"edge_regions_by_frame": result}

    if debug and frames:
        debug_data = {}
        if inference_url:
            from core.detect.inference import InferenceClient
            client = InferenceClient(base_url=inference_url, job_id=job_id)
            for frame in frames:
                dr = client.detect_edges_debug(
                    image=frame.image,
                    edge_canny_low=config.edge_canny_low,
                    edge_canny_high=config.edge_canny_high,
                    edge_hough_threshold=config.edge_hough_threshold,
                    edge_hough_min_length=config.edge_hough_min_length,
                    edge_hough_max_gap=config.edge_hough_max_gap,
                    edge_pair_max_distance=config.edge_pair_max_distance,
                    edge_pair_min_distance=config.edge_pair_min_distance,
                )
                debug_data[frame.sequence] = {
                    "edge_overlay_b64": dr.edge_overlay_b64,
                    "lines_overlay_b64": dr.lines_overlay_b64,
                    "horizontal_count": dr.horizontal_count,
                    "pair_count": dr.pair_count,
                }
        else:
            from core.detect.stages.edge_detector import _load_cv_edges
            edges_mod = _load_cv_edges()
            for frame in frames:
                dr = edges_mod.detect_edges_debug(
                    frame.image,
                    canny_low=config.edge_canny_low,
                    canny_high=config.edge_canny_high,
                    hough_threshold=config.edge_hough_threshold,
                    hough_min_length=config.edge_hough_min_length,
                    hough_max_gap=config.edge_hough_max_gap,
                    pair_max_distance=config.edge_pair_max_distance,
                    pair_min_distance=config.edge_pair_min_distance,
                )
                debug_data[frame.sequence] = {
                    "edge_overlay_b64": dr["edge_overlay_b64"],
                    "lines_overlay_b64": dr["lines_overlay_b64"],
                    "horizontal_count": dr["horizontal_count"],
                    "pair_count": dr["pair_count"],
                }
        output["debug"] = debug_data

    return output


def _replay_field_segmentation(
    state: dict,
    profile,
    frames: list,
    job_id: str,
    debug: bool,
) -> dict:
    """Run field segmentation on checkpoint frames."""
    from core.detect.stages.field_segmentation import run_field_segmentation
    from core.detect.profile import get_stage_config
    from core.detect.stages.models import FieldSegmentationConfig

    config = FieldSegmentationConfig(**get_stage_config(profile, "field_segmentation"))
    inference_url = os.environ.get("INFERENCE_URL")

    result = run_field_segmentation(
        frames=frames,
        config=config,
        inference_url=inference_url,
        job_id=job_id,
    )
    return result
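Two entry points, two granularities. A usage sketch under the signatures above; the job id is a placeholder, while the stage name and the `edge_canny_low` config key both appear in this file:

```python
from core.detect.checkpoint.replay import replay_from, replay_single_stage

job_id = "3f6d0c1e-..."  # placeholder: a completed job with cached frames

# Full replay: creates a new Job (run_type="replay") and runs detect_edges
# onward with a lower Canny threshold.
final_state = replay_from(
    job_id,
    start_stage="detect_edges",
    config_overrides={"detect_edges": {"edge_canny_low": 30}},
)

# Fast path: only the detect_edges stage, on two frames, with debug overlays.
out = replay_single_stage(
    job_id,
    stage="detect_edges",
    frame_refs=[12, 13],
    config_overrides={"detect_edges": {"edge_canny_low": 30}},
    debug=True,
)
```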
99
core/detect/checkpoint/runner_bridge.py
Normal file
@@ -0,0 +1,99 @@
"""
Runner bridge — checkpoint hook called by PipelineRunner after each stage.

Saves a checkpoint + stage output after each stage completes.
Timeline and Job are independent: timeline_id and job_id come from
the pipeline state (set at job creation time).
"""

from __future__ import annotations

import logging

logger = logging.getLogger(__name__)

# Per-job state: tracks the latest checkpoint so we can chain parent → child
_latest_checkpoint: dict[str, str] = {}


def reset_checkpoint_state(job_id: str):
    """Clean up per-job checkpoint state. Called when pipeline finishes."""
    _latest_checkpoint.pop(job_id, None)


def checkpoint_after_stage(job_id: str, stage_name: str, state: dict, result: dict):
    """
    Save a checkpoint + stage output after a stage completes.

    Called by the runner. Handles:
    - Stage output serialization (via stage registry)
    - Checkpoint chain (parent → child)
    - Stage output as separate row in StageOutput table
    """
    if not job_id:
        return

    timeline_id = state.get("timeline_id", "")
    if not timeline_id:
        logger.warning("No timeline_id in state for job %s, skipping checkpoint", job_id)
        return

    from .storage import save_checkpoint, save_stage_output
    from core.detect.stages.base import _REGISTRY, _LEGACY_REGISTRY

    merged = {**state, **result}

    # Serialize stage output using the stage's serialize_fn if available
    # Check new-style registry first, then legacy (some stages are in both)
    serialize_fn = None
    stage_cls = _REGISTRY.get(stage_name)
    if stage_cls:
        serialize_fn = getattr(getattr(stage_cls, "definition", None), "serialize_fn", None)
    if not serialize_fn:
        legacy = _LEGACY_REGISTRY.get(stage_name)
        if legacy:
            serialize_fn = legacy.serialize_fn

    if serialize_fn:
        output_json = serialize_fn(merged, job_id)
    else:
        output_json = {}

    # Convert stats dataclass to dict for JSONB storage
    import dataclasses
    raw_stats = state.get("stats", {})
    if dataclasses.is_dataclass(raw_stats):
        stats_dict = dataclasses.asdict(raw_stats)
    elif isinstance(raw_stats, dict):
        stats_dict = raw_stats
    else:
        stats_dict = {}

    # Save checkpoint (lightweight tree node)
    parent_id = _latest_checkpoint.get(job_id)
    checkpoint_id = save_checkpoint(
        timeline_id=timeline_id,
        stage_name=stage_name,
        parent_checkpoint_id=parent_id,
        config_overrides=state.get("config_overrides"),
        stats=stats_dict,
        job_id=job_id,
    )
    _latest_checkpoint[job_id] = checkpoint_id

    # Save stage output (separate row, upsert by job+stage)
    if output_json:
        save_stage_output(
            job_id=job_id,
            timeline_id=timeline_id,
            stage_name=stage_name,
            output=output_json,
            checkpoint_id=checkpoint_id,
        )

    logger.info("Checkpoint %s + output for stage %s (job %s)", checkpoint_id, stage_name, job_id)


def get_latest_checkpoint(job_id: str) -> str | None:
    """Get the latest checkpoint_id for a running job."""
    return _latest_checkpoint.get(job_id)
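The bridge assumes a specific call pattern from the runner. A sketch of that contract; the loop below is a stand-in for PipelineRunner, not its actual implementation:

```python
from core.detect.checkpoint.runner_bridge import (
    checkpoint_after_stage,
    reset_checkpoint_state,
)


def run_with_checkpoints(job_id: str, state: dict, stages) -> dict:
    """stages: iterable of (name, fn) pairs, e.g. NODE_FUNCTIONS."""
    try:
        for name, fn in stages:
            result = fn(state)                                   # run one stage
            checkpoint_after_stage(job_id, name, state, result)  # persist chain + output
            state.update(result)                                 # merge into state
    finally:
        reset_checkpoint_state(job_id)  # drop this job's parent-chain entry
    return state
```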
109
core/detect/checkpoint/serializer.py
Normal file
@@ -0,0 +1,109 @@
"""
State serialization — DetectState ↔ JSON-compatible dict.

Delegates to each stage's serialize_fn/deserialize_fn via the registry.
This file has no model-specific knowledge — stages own their data format.

The only things serialized here are the "envelope" fields (job_id, video_path, etc.)
that don't belong to any stage.

Frames are ephemeral (in-memory during a run). Serialization stores
metadata only; frames are re-extracted from chunks when needed.
"""

from __future__ import annotations

from core.schema.serializers._common import serialize_dataclass
from core.schema.serializers.pipeline import (
    deserialize_pipeline_stats,
    deserialize_text_candidates,
)


# Envelope fields — not owned by any stage, always present
ENVELOPE_KEYS = ["job_id", "video_path", "profile_name", "timeline_id", "config_overrides"]


def serialize_state(state: dict) -> dict:
    """
    Serialize DetectState to a JSON-compatible dict.

    Calls each registered stage's serialize_fn for stage-owned data.
    Envelope fields (job_id, etc.) are copied directly.
    """
    from core.detect.stages.base import _REGISTRY

    checkpoint = {}

    # Envelope
    for key in ENVELOPE_KEYS:
        default = {} if key == "config_overrides" else ""
        checkpoint[key] = state.get(key, default)

    # Stats (shared across stages, not owned by one)
    stats = state.get("stats")
    if stats is not None:
        checkpoint["stats"] = serialize_dataclass(stats)
    else:
        checkpoint["stats"] = {}

    # Per-stage data
    for name, stage_def in _REGISTRY.items():
        if stage_def.serialize_fn is None:
            continue
        job_id = state.get("job_id", "")
        stage_data = stage_def.serialize_fn(state, job_id)
        checkpoint[f"stage_{name}"] = stage_data

    return checkpoint


def deserialize_state(checkpoint: dict, frames: list) -> dict:
    """
    Reconstitute DetectState from a checkpoint dict + frames.

    Frames are provided by the caller (re-extracted from chunks).
    Calls each stage's deserialize_fn to restore stage-owned data.
    """
    from core.detect.stages.base import _REGISTRY

    frame_map = {f.sequence: f for f in frames}

    state = {}

    # Envelope
    for key in ENVELOPE_KEYS:
        default = {} if key == "config_overrides" else ""
        state[key] = checkpoint.get(key, default)

    # Frames (provided externally, ephemeral)
    state["frames"] = frames

    # Stats
    state["stats"] = deserialize_pipeline_stats(checkpoint.get("stats", {}))

    # Per-stage data
    for name, stage_def in _REGISTRY.items():
        if stage_def.deserialize_fn is None:
            continue

        stage_key = f"stage_{name}"
        if stage_key not in checkpoint:
            continue

        job_id = state.get("job_id", "")
        stage_data = stage_def.deserialize_fn(checkpoint[stage_key], job_id)

        for k, v in stage_data.items():
            if k == "_filtered_sequences":
                # Reconnect filtered frames from sequence list
                seq_set = set(v)
                state["filtered_frames"] = [f for f in frames if f.sequence in seq_set]
            elif k.endswith("_raw"):
                # Raw text candidates need frame reference reconnection
                real_key = k.removeprefix("_").removesuffix("_raw")
                state[real_key] = deserialize_text_candidates(v, frame_map)
            else:
                state[k] = v

    return state
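A round-trip sketch. The envelope values are placeholders, and a real run carries stage-owned keys as well; the property worth noting is that frames never enter the serialized dict and must be supplied again on the way back:

```python
import json

from core.detect.checkpoint.frames import load_cached_frames
from core.detect.checkpoint.serializer import deserialize_state, serialize_state

state = {"job_id": "job-1", "timeline_id": "tl-1",          # placeholder envelope;
         "video_path": "/data/match.mp4",                   # a live DetectState
         "profile_name": "soccer_broadcast",                # has much more
         "config_overrides": {}}

snapshot = serialize_state(state)   # JSON-compatible: envelope + stage_* keys
blob = json.dumps(snapshot)         # safe: no ndarray frames inside

frames = load_cached_frames(state["timeline_id"])  # caller re-supplies frames
restored = deserialize_state(json.loads(blob), frames)
assert restored["job_id"] == state["job_id"]
```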
303
core/detect/checkpoint/storage.py
Normal file
@@ -0,0 +1,303 @@
"""
Checkpoint storage — Timeline, Checkpoint, StageOutput persistence.

Timeline: user-created source selection (chunk paths)
Checkpoint: lightweight tree node (parent_id, stage_name, config, stats)
StageOutput: per-stage result (flat table, one row per job+stage)
"""

from __future__ import annotations

import logging
from uuid import UUID

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Timeline
# ---------------------------------------------------------------------------

def create_timeline(
    chunk_paths: list[str],
    profile_name: str = "",
    name: str = "",
    source_asset_id: UUID | None = None,
    fps: float = 2.0,
) -> str:
    """
    Create a timeline from a chunk selection.

    Called by the user (via API) before any pipeline runs.
    Returns timeline_id.
    """
    from core.db.models import Timeline
    from core.db.connection import get_session

    with get_session() as session:
        timeline = Timeline(
            name=name,
            chunk_paths=chunk_paths,
            profile_name=profile_name,
            source_asset_id=source_asset_id,
            fps=fps,
            status="created",
        )
        session.add(timeline)
        session.commit()
        session.refresh(timeline)
        tid = str(timeline.id)

    logger.info("Timeline created: %s (%d chunks)", tid, len(chunk_paths))
    return tid


def get_timeline(timeline_id: str) -> dict:
    """Load a timeline as a dict."""
    from core.db.models import Timeline
    from core.db.connection import get_session

    with get_session() as session:
        timeline = session.get(Timeline, UUID(timeline_id))
        if not timeline:
            raise ValueError(f"Timeline not found: {timeline_id}")

        return {
            "id": str(timeline.id),
            "name": timeline.name,
            "chunk_paths": timeline.chunk_paths,
            "profile_name": timeline.profile_name,
            "status": timeline.status,
            "fps": timeline.fps,
            "source_asset_id": str(timeline.source_asset_id) if timeline.source_asset_id else None,
            "created_at": str(timeline.created_at) if timeline.created_at else None,
        }


def update_timeline_status(timeline_id: str, status: str, frame_count: int | None = None):
    """Update timeline status and optionally frame count."""
    from core.db.models import Timeline
    from core.db.connection import get_session

    with get_session() as session:
        timeline = session.get(Timeline, UUID(timeline_id))
        if timeline:
            timeline.status = status
            if frame_count is not None:
                timeline.frame_count = frame_count
            session.commit()


# ---------------------------------------------------------------------------
# Checkpoint
# ---------------------------------------------------------------------------

def save_checkpoint(
    timeline_id: str,
    stage_name: str,
    parent_checkpoint_id: str | None = None,
    config_overrides: dict | None = None,
    stats: dict | None = None,
    is_scenario: bool = False,
    scenario_label: str = "",
    job_id: str | None = None,
) -> str:
    """
    Save a checkpoint (lightweight tree node).

    No stage outputs — those go in StageOutput table separately.
    Returns the new checkpoint ID.
    """
    from core.db.models import Checkpoint
    from core.db.connection import get_session

    with get_session() as session:
        checkpoint = Checkpoint(
            timeline_id=UUID(timeline_id),
            job_id=UUID(job_id) if job_id else None,
            parent_id=UUID(parent_checkpoint_id) if parent_checkpoint_id else None,
            stage_name=stage_name,
            config_overrides=config_overrides or {},
            stats=stats or {},
            is_scenario=is_scenario,
            scenario_label=scenario_label,
        )
        session.add(checkpoint)
        session.commit()
        session.refresh(checkpoint)
        cid = str(checkpoint.id)

    logger.info("Checkpoint saved: %s (timeline %s, stage %s, parent %s)",
                cid, timeline_id, stage_name, parent_checkpoint_id)
    return cid


def get_checkpoints_for_job(job_id: str) -> list[dict]:
    """List checkpoints for a job, ordered by creation time."""
    from sqlmodel import select
    from core.db.models import Checkpoint
    from core.db.connection import get_session

    with get_session() as session:
        stmt = (
            select(Checkpoint)
            .where(Checkpoint.job_id == UUID(job_id))
            .order_by(Checkpoint.created_at)
        )
        checkpoints = session.exec(stmt).all()

        return [
            {
                "id": str(c.id),
                "timeline_id": str(c.timeline_id),
                "job_id": str(c.job_id) if c.job_id else None,
                "parent_id": str(c.parent_id) if c.parent_id else None,
                "stage_name": c.stage_name,
                "config_overrides": c.config_overrides or {},
                "stats": c.stats or {},
                "is_scenario": c.is_scenario,
                "scenario_label": c.scenario_label,
                "created_at": str(c.created_at) if c.created_at else None,
            }
            for c in checkpoints
        ]


def get_checkpoints_for_timeline(timeline_id: str) -> list[dict]:
    """List all checkpoints on a timeline, ordered by creation time."""
    from sqlmodel import select
    from core.db.models import Checkpoint
    from core.db.connection import get_session

    with get_session() as session:
        stmt = (
            select(Checkpoint)
            .where(Checkpoint.timeline_id == UUID(timeline_id))
            .order_by(Checkpoint.created_at)
        )
        checkpoints = session.exec(stmt).all()

        return [
            {
                "id": str(c.id),
                "timeline_id": str(c.timeline_id),
                "job_id": str(c.job_id) if c.job_id else None,
                "parent_id": str(c.parent_id) if c.parent_id else None,
                "stage_name": c.stage_name,
                "config_overrides": c.config_overrides or {},
                "stats": c.stats or {},
                "is_scenario": c.is_scenario,
                "scenario_label": c.scenario_label,
                "created_at": str(c.created_at) if c.created_at else None,
            }
            for c in checkpoints
        ]


# ---------------------------------------------------------------------------
# StageOutput
# ---------------------------------------------------------------------------

def save_stage_output(
    job_id: str,
    timeline_id: str,
    stage_name: str,
    output: dict,
    checkpoint_id: str | None = None,
) -> str:
    """
    Save (upsert) a stage output. One row per (job_id, stage_name).

    Returns the stage_output ID.
    """
    from sqlmodel import select
    from core.db.models import StageOutput
    from core.db.connection import get_session

    with get_session() as session:
        # Upsert: check if exists
        stmt = (
            select(StageOutput)
            .where(StageOutput.job_id == UUID(job_id))
            .where(StageOutput.stage_name == stage_name)
        )
        existing = session.exec(stmt).first()

        if existing:
            existing.output = output
            existing.checkpoint_id = UUID(checkpoint_id) if checkpoint_id else None
            session.commit()
            session.refresh(existing)
            return str(existing.id)

        stage_output = StageOutput(
            job_id=UUID(job_id),
            timeline_id=UUID(timeline_id),
            stage_name=stage_name,
            checkpoint_id=UUID(checkpoint_id) if checkpoint_id else None,
            output=output,
        )
        session.add(stage_output)
        session.commit()
        session.refresh(stage_output)
        return str(stage_output.id)


def load_stage_output(job_id: str, stage_name: str) -> dict | None:
    """Load a stage's output by job + stage name."""
    from sqlmodel import select
    from core.db.models import StageOutput
    from core.db.connection import get_session

    with get_session() as session:
        stmt = (
            select(StageOutput)
            .where(StageOutput.job_id == UUID(job_id))
            .where(StageOutput.stage_name == stage_name)
        )
        row = session.exec(stmt).first()

        if not row:
            return None
        return row.output


def load_stage_outputs_for_job(job_id: str) -> dict[str, dict]:
    """Load all stage outputs for a job. Returns {stage_name: output}."""
    from sqlmodel import select
    from core.db.models import StageOutput
    from core.db.connection import get_session

    with get_session() as session:
        stmt = (
            select(StageOutput)
            .where(StageOutput.job_id == UUID(job_id))
        )
        rows = session.exec(stmt).all()

        return {row.stage_name: row.output for row in rows}


def load_stage_outputs_for_timeline(timeline_id: str, stage_name: str | None = None) -> list[dict]:
    """Load stage outputs for a timeline, optionally filtered by stage."""
    from sqlmodel import select
    from core.db.models import StageOutput
    from core.db.connection import get_session

    with get_session() as session:
        stmt = select(StageOutput).where(StageOutput.timeline_id == UUID(timeline_id))
        if stage_name:
            stmt = stmt.where(StageOutput.stage_name == stage_name)
        rows = session.exec(stmt).all()

        return [
            {
                "id": str(r.id),
                "job_id": str(r.job_id),
                "stage_name": r.stage_name,
                "checkpoint_id": str(r.checkpoint_id) if r.checkpoint_id else None,
                "output": r.output,
                "created_at": str(r.created_at) if r.created_at else None,
            }
            for r in rows
        ]
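The three tables compose like this in practice. A sketch with placeholder ids, paths, and output payload; in a real run the job comes from core.db.job and the output from a stage's serialize_fn:

```python
import uuid

from core.detect.checkpoint.storage import (
    create_timeline,
    get_checkpoints_for_timeline,
    save_checkpoint,
    save_stage_output,
)

timeline_id = create_timeline(
    chunk_paths=["chunks/match-01/000.mp4", "chunks/match-01/001.mp4"],
    profile_name="soccer_broadcast",
    name="match-01 first half",
)
job_id = str(uuid.uuid4())  # placeholder; normally created via core.db.job

cp_id = save_checkpoint(timeline_id, stage_name="extract_frames", job_id=job_id)
save_stage_output(
    job_id, timeline_id, "extract_frames",
    output={"frames_extracted": 240},  # illustrative serialized stage data
    checkpoint_id=cp_id,
)

for cp in get_checkpoints_for_timeline(timeline_id):
    print(cp["stage_name"], "<-", cp["parent_id"])
```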
159
core/detect/emit.py
Normal file
@@ -0,0 +1,159 @@
"""
Event emission helpers for detection pipeline stages.

Single place that knows how to build event payloads.
Stages call these instead of constructing dicts or dataclasses directly.

Run context (run_id, parent_job_id) is set once at pipeline start via
set_run_context() and automatically injected into all events.

Log level is set per-run with optional per-stage overrides.
DEBUG events are only pushed when the run (or stage) log level allows it.
"""

from __future__ import annotations

import dataclasses
from datetime import datetime, timezone

from core.detect.events import push_detect_event
from core.detect.models import PipelineStats

# Log level ordering for comparison
_LEVEL_ORDER = {"DEBUG": 0, "INFO": 1, "WARN": 2, "ERROR": 3}

# Module-level run context — set once per pipeline invocation
_run_context: dict = {}
_run_log_level: str = "INFO"
_stage_log_levels: dict[str, str] = {}  # stage_name → level override


def set_run_context(
    run_id: str = "",
    parent_job_id: str = "",
    run_type: str = "initial",
    log_level: str = "INFO",
):
    """Set the run context for all subsequent events in this pipeline invocation."""
    global _run_context, _run_log_level
    _run_context = {
        "run_id": run_id,
        "parent_job_id": parent_job_id,
        "run_type": run_type,
    }
    _run_log_level = log_level.upper()
    _stage_log_levels.clear()


def set_stage_log_level(stage: str, level: str):
    """Override log level for a specific stage."""
    _stage_log_levels[stage] = level.upper()


def clear_stage_log_level(stage: str):
    """Remove per-stage log level override."""
    _stage_log_levels.pop(stage, None)


def clear_run_context():
    global _run_context, _run_log_level
    _run_context = {}
    _run_log_level = "INFO"
    _stage_log_levels.clear()


def _should_emit(level: str, stage: str) -> bool:
    """Check if this log level should be emitted given run/stage settings."""
    effective = _stage_log_levels.get(stage, _run_log_level)
    return _LEVEL_ORDER.get(level.upper(), 1) >= _LEVEL_ORDER.get(effective, 1)


def _inject_context(payload: dict) -> dict:
    """Add run context fields to an event payload."""
    if _run_context:
        payload.update(_run_context)
    return payload


def log(job_id: str | None, stage: str, level: str, msg: str) -> None:
    if not job_id:
        return
    if not _should_emit(level, stage):
        return
    payload = {
        "level": level,
        "stage": stage,
        "msg": msg,
        "ts": datetime.now(timezone.utc).isoformat(),
    }
    _inject_context(payload)
    push_detect_event(job_id, "log", payload)


def stats(job_id: str | None, **kwargs) -> None:
    if not job_id:
        return
    s = PipelineStats(**kwargs)
    payload = dataclasses.asdict(s)
    _inject_context(payload)
    push_detect_event(job_id, "stats_update", payload)


def frame_update(
    job_id: str | None,
    frame_ref: int,
    timestamp: float,
    jpeg_b64: str,
    boxes: list[dict],
) -> None:
    if not job_id:
        return
    payload = {
        "frame_ref": frame_ref,
        "timestamp": timestamp,
        "jpeg_b64": jpeg_b64,
        "boxes": boxes,
    }
    _inject_context(payload)
    push_detect_event(job_id, "frame_update", payload)


def graph_update(job_id: str | None, nodes: list[dict]) -> None:
    if not job_id:
        return
    payload = {"nodes": nodes}
    _inject_context(payload)
    push_detect_event(job_id, "graph_update", payload)


def detection(
    job_id: str | None,
    brand: str,
    confidence: float,
    source: str,
    timestamp: float,
    duration: float = 0.0,
    content_type: str = "",
    frame_ref: int | None = None,
) -> None:
    if not job_id:
        return
    payload = {
        "brand": brand,
        "confidence": confidence,
        "source": source,
        "timestamp": timestamp,
        "duration": duration,
        "content_type": content_type,
        "frame_ref": frame_ref,
    }
    _inject_context(payload)
    push_detect_event(job_id, "detection", payload)


def job_complete(job_id: str | None, report: dict) -> None:
    if not job_id:
        return
    payload = {"job_id": job_id, "report": report}
    _inject_context(payload)
    push_detect_event(job_id, "job_complete", payload)
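How a run wires this up: context once at the start, then fire-and-forget helpers from any stage. The ids and brand name below are placeholders:

```python
from core.detect import emit

job_id = "job-1"  # placeholder

emit.set_run_context(run_id=job_id, parent_job_id=job_id,
                     run_type="initial", log_level="INFO")
emit.set_stage_log_level("run_ocr", "DEBUG")  # verbose for one stage only

emit.log(job_id, "run_ocr", "DEBUG", "retrying crop with PSM 7")  # emitted
emit.log(job_id, "detect_edges", "DEBUG", "raw Hough lines: 42")  # suppressed at INFO

emit.detection(job_id, brand="acme", confidence=0.91, source="ocr",
               timestamp=12.5, frame_ref=25)

emit.clear_run_context()
```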
42
core/detect/events.py
Normal file
@@ -0,0 +1,42 @@
"""
Detection pipeline event helpers.

Non-generated runtime code for pushing SSE events.
The event payload types are in sse_contract.py (generated by modelgen).
"""

from pydantic import BaseModel

from core.events import push_event

DETECT_EVENTS_PREFIX = "detect_events"

# SSE event type names
EVENT_GRAPH_UPDATE = "graph_update"
EVENT_STATS_UPDATE = "stats_update"
EVENT_FRAME_UPDATE = "frame_update"
EVENT_DETECTION = "detection"
EVENT_LOG = "log"
EVENT_JOB_COMPLETE = "job_complete"

ALL_EVENT_TYPES = [
    EVENT_GRAPH_UPDATE,
    EVENT_STATS_UPDATE,
    EVENT_FRAME_UPDATE,
    EVENT_DETECTION,
    EVENT_LOG,
    EVENT_JOB_COMPLETE,
]

TERMINAL_EVENTS = [EVENT_JOB_COMPLETE]


def push_detect_event(job_id: str, event_type: str, data: BaseModel | dict) -> None:
    """Push a detection event to Redis. Accepts Pydantic models or plain dicts."""
    payload = data.model_dump(mode="json") if isinstance(data, BaseModel) else data
    push_event(
        job_id=job_id,
        event_type=event_type,
        data=payload,
        prefix=DETECT_EVENTS_PREFIX,
    )
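Everything in emit.py funnels into this one Redis push. A direct call with a plain dict; the payload shape mirrors what emit.log() builds, and the values are placeholders:

```python
from core.detect.events import EVENT_LOG, push_detect_event

push_detect_event(
    job_id="job-1",  # placeholder
    event_type=EVENT_LOG,
    data={"level": "INFO", "stage": "run_ocr", "msg": "hello"},
)
```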
45
core/detect/graph/__init__.py
Normal file
@@ -0,0 +1,45 @@
"""
Detection pipeline graph.

detect/graph/
    nodes.py  — node functions (one per stage)
    events.py — graph_update SSE emission
    runner.py — PipelineRunner (config-driven, checkpoint, cancel, pause)
"""

from .nodes import NODES, NODE_FUNCTIONS
from .runner import (
    PipelineCancelled,
    PipelineRunner,
    build_graph,
    clear_cancel_check,
    clear_pause,
    get_pipeline,
    init_pause,
    is_paused,
    pause_pipeline,
    resume_pipeline,
    set_cancel_check,
    set_pause_after_stage,
    step_pipeline,
)
from .events import _node_states

__all__ = [
    "NODES",
    "NODE_FUNCTIONS",
    "PipelineCancelled",
    "PipelineRunner",
    "build_graph",
    "get_pipeline",
    "set_cancel_check",
    "clear_cancel_check",
    "init_pause",
    "clear_pause",
    "pause_pipeline",
    "resume_pipeline",
    "step_pipeline",
    "set_pause_after_stage",
    "is_paused",
    "_node_states",
]
27
core/detect/graph/events.py
Normal file
@@ -0,0 +1,27 @@
"""
Graph event emission — node state tracking + SSE graph_update events.
"""

from __future__ import annotations

from core.detect import emit
from core.detect.state import DetectState


# Track node states across pipeline runs
_node_states: dict[str, dict[str, str]] = {}


def emit_transition(state: DetectState, node: str, status: str, node_list: list[str]):
    """Update node status and emit graph_update SSE event."""
    job_id = state.get("job_id")
    if not job_id:
        return

    if job_id not in _node_states:
        _node_states[job_id] = {n: "pending" for n in node_list}

    _node_states[job_id][node] = status

    nodes = [{"id": n, "status": _node_states[job_id][n]} for n in node_list]
    emit.graph_update(job_id, nodes)
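A transition sketch: the first call lazily seeds every node as "pending", and each later call flips one status and re-emits the whole node list. Ids below are placeholders:

```python
from core.detect.graph.events import emit_transition

state = {"job_id": "job-1"}  # placeholder
node_list = ["extract_frames", "filter_scenes", "compile_report"]

emit_transition(state, "extract_frames", "running", node_list)
emit_transition(state, "extract_frames", "done", node_list)
# Each call pushes a graph_update event carrying all three nodes' statuses.
```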
386
core/detect/graph/nodes.py
Normal file
@@ -0,0 +1,386 @@
"""
Pipeline node functions — one per stage.

Each node: reads state, gets config from profile dict, runs stage logic,
emits transitions, returns output dict.
"""

from __future__ import annotations

import os

from core.detect import emit
from core.detect.models import CropContext, PipelineStats
from core.detect.profile import get_profile, get_stage_config, build_vlm_prompt, aggregate_detections
from core.detect.stages.models import (
    DetectionConfig,
    FieldSegmentationConfig,
    FrameExtractionConfig,
    OCRConfig,
    RegionAnalysisConfig,
    ResolverConfig,
    SceneFilterConfig,
)
from core.detect.state import DetectState
from core.detect.stages.frame_extractor import extract_frames
from core.detect.stages.scene_filter import scene_filter
from core.detect.stages.field_segmentation import run_field_segmentation
from core.detect.stages.edge_detector import detect_edge_regions
from core.detect.stages.yolo_detector import detect_objects
from core.detect.stages.preprocess import preprocess_regions
from core.detect.stages.ocr_stage import run_ocr
from core.detect.stages.brand_resolver import resolve_brands
from core.detect.stages.vlm_local import escalate_vlm
from core.detect.stages.vlm_cloud import escalate_cloud
from core.detect.stages.aggregator import compile_report
from core.detect.tracing import trace_node, flush as flush_traces

from .events import emit_transition

INFERENCE_URL = os.environ.get("INFERENCE_URL")  # None = local mode

NODES = [
    "extract_frames",
    "filter_scenes",
    "field_segmentation",
    "detect_edges",
    "detect_objects",
    "preprocess",
    "run_ocr",
    "match_brands",
    "escalate_vlm",
    "escalate_cloud",
    "compile_report",
]


def _load_profile(state: DetectState) -> dict:
    """Load profile dict, apply config overrides if present."""
    name = state.get("profile_name", "soccer_broadcast")
    profile = get_profile(name)

    overrides = state.get("config_overrides")
    if overrides:
        # Merge overrides into a copy of the profile configs
        merged_configs = dict(profile.get("configs", {}))
        for stage_name, stage_overrides in overrides.items():
            if stage_name in merged_configs:
                merged_configs[stage_name] = {**merged_configs[stage_name], **stage_overrides}
            else:
                merged_configs[stage_name] = stage_overrides
        profile = {**profile, "configs": merged_configs}

    return profile


def _emit(state, node, status):
    emit_transition(state, node, status, NODES)


# --- Node functions ---

def node_extract_frames(state: DetectState) -> dict:
    job_id = state.get("job_id", "")
    if job_id and not emit._run_context:
        emit.set_run_context(run_id=job_id, parent_job_id=job_id, run_type="initial")

    source_asset_id = state.get("source_asset_id")
    if source_asset_id and not state.get("session_brands"):
        from core.detect.stages.brand_resolver import build_session_dict
        session_brands = build_session_dict(source_asset_id)
        state["session_brands"] = session_brands

    _emit(state, "extract_frames", "running")

    with trace_node(state, "extract_frames") as span:
        profile = _load_profile(state)
        config = FrameExtractionConfig(**get_stage_config(profile, "extract_frames"))
        frames = extract_frames(state["video_path"], config, job_id=job_id)
        span.set_output({"frames_extracted": len(frames)})

    # Cache frames on the timeline for reuse across jobs and UI
    timeline_id = state.get("timeline_id")
    if timeline_id:
        from core.detect.checkpoint.frames import cache_frames, cache_exists
        if not cache_exists(timeline_id):
            cache_frames(timeline_id, frames)
            from core.detect.checkpoint.storage import update_timeline_status
            update_timeline_status(timeline_id, "cached", frame_count=len(frames))

    _emit(state, "extract_frames", "done")
    return {"frames": frames, "stats": PipelineStats(frames_extracted=len(frames))}


def node_filter_scenes(state: DetectState) -> dict:
    _emit(state, "filter_scenes", "running")

    with trace_node(state, "filter_scenes") as span:
        profile = _load_profile(state)
        config = SceneFilterConfig(**get_stage_config(profile, "filter_scenes"))
        frames = state.get("frames", [])
        kept = scene_filter(frames, config, job_id=state.get("job_id"))
        span.set_output({"frames_in": len(frames), "frames_kept": len(kept)})

    stats = state.get("stats", PipelineStats())
    stats.frames_after_scene_filter = len(kept)

    _emit(state, "filter_scenes", "done")
    return {"filtered_frames": kept, "stats": stats}


def node_field_segmentation(state: DetectState) -> dict:
    _emit(state, "field_segmentation", "running")

    with trace_node(state, "field_segmentation") as span:
        profile = _load_profile(state)
        config = FieldSegmentationConfig(**get_stage_config(profile, "field_segmentation"))
        frames = state.get("filtered_frames", [])
        job_id = state.get("job_id")

        result = run_field_segmentation(frames, config, inference_url=INFERENCE_URL, job_id=job_id)
        span.set_output({
            "frames": len(frames),
            "avg_coverage": sum(result["field_coverage"].values()) / max(len(result["field_coverage"]), 1),
        })

    _emit(state, "field_segmentation", "done")
    return {
        "field_masks": result["field_masks"],
        "field_mask_overlays": result.get("field_mask_overlays", {}),
        "field_boundaries": result["field_boundaries"],
        "field_coverage": result["field_coverage"],
    }


def node_detect_edges(state: DetectState) -> dict:
    _emit(state, "detect_edges", "running")

    with trace_node(state, "detect_edges") as span:
        profile = _load_profile(state)
        config = RegionAnalysisConfig(**get_stage_config(profile, "detect_edges"))
        frames = state.get("filtered_frames", [])
        field_masks = state.get("field_masks", {})
        job_id = state.get("job_id")

        # Apply edge transforms from upstream connections
        edge_transforms = state.get("_edge_transforms", {})
        for source_stage, transform in edge_transforms.items():
            if transform.get("invert_mask") and field_masks:
                import numpy as np
                field_masks = {
                    seq: np.bitwise_not(mask) if mask is not None else None
                    for seq, mask in field_masks.items()
                }

        regions = detect_edge_regions(
            frames, config, inference_url=INFERENCE_URL, job_id=job_id,
            field_masks=field_masks,
        )
        total = sum(len(r) for r in regions.values())
        span.set_output({"frames": len(frames), "edge_regions": total})

    stats = state.get("stats", PipelineStats())
    stats.cv_regions_detected = total

    _emit(state, "detect_edges", "done")
    return {"edge_regions_by_frame": regions, "stats": stats}


def node_detect_objects(state: DetectState) -> dict:
    _emit(state, "detect_objects", "running")

    with trace_node(state, "detect_objects") as span:
        profile = _load_profile(state)
        config = DetectionConfig(**get_stage_config(profile, "detect_objects"))
        frames = state.get("filtered_frames", [])
        job_id = state.get("job_id")

        all_boxes = detect_objects(frames, config, inference_url=INFERENCE_URL, job_id=job_id)
        total_regions = sum(len(boxes) for boxes in all_boxes.values())
        span.set_output({"frames": len(frames), "regions_detected": total_regions})

    stats = state.get("stats", PipelineStats())
    stats.regions_detected = total_regions

    _emit(state, "detect_objects", "done")
    return {"boxes_by_frame": all_boxes, "stats": stats}


def node_preprocess(state: DetectState) -> dict:
    _emit(state, "preprocess", "running")

    with trace_node(state, "preprocess") as span:
        profile = _load_profile(state)
        prep_config = get_stage_config(profile, "preprocess")
        frames = state.get("filtered_frames", [])
        boxes = state.get("boxes_by_frame", {})
        job_id = state.get("job_id")

        do_contrast = prep_config.get("contrast", True)
        do_deskew = prep_config.get("deskew", False)
        do_binarize = prep_config.get("binarize", False)

        preprocessed = preprocess_regions(
            frames, boxes,
            do_contrast=do_contrast,
            do_deskew=do_deskew,
            do_binarize=do_binarize,
            inference_url=INFERENCE_URL,
            job_id=job_id,
        )
        span.set_output({"regions_preprocessed": len(preprocessed)})

    _emit(state, "preprocess", "done")
    return {"preprocessed_crops": preprocessed}


def node_run_ocr(state: DetectState) -> dict:
    _emit(state, "run_ocr", "running")

    with trace_node(state, "run_ocr") as span:
        profile = _load_profile(state)
        config = OCRConfig(**get_stage_config(profile, "run_ocr"))
        frames = state.get("filtered_frames", [])
        boxes = state.get("boxes_by_frame", {})
        job_id = state.get("job_id")

        candidates = run_ocr(frames, boxes, config, inference_url=INFERENCE_URL, job_id=job_id)
        span.set_output({"regions_in": sum(len(b) for b in boxes.values()), "text_candidates": len(candidates)})

    stats = state.get("stats", PipelineStats())
    stats.regions_resolved_by_ocr = len(candidates)

    _emit(state, "run_ocr", "done")
    return {"text_candidates": candidates, "stats": stats}


def node_match_brands(state: DetectState) -> dict:
    _emit(state, "match_brands", "running")

    with trace_node(state, "match_brands") as span:
        profile = _load_profile(state)
        config = ResolverConfig(**get_stage_config(profile, "match_brands"))
        candidates = state.get("text_candidates", [])
        session_brands = state.get("session_brands", {})
        job_id = state.get("job_id")
        source_asset_id = state.get("source_asset_id")

        matched, unresolved = resolve_brands(
            candidates, config,
            session_brands=session_brands,
            source_asset_id=source_asset_id,
            content_type=profile["name"], job_id=job_id,
        )
        span.set_output({"matched": len(matched), "unresolved": len(unresolved)})

    _emit(state, "match_brands", "done")
    return {"detections": matched, "unresolved_candidates": unresolved}


def node_escalate_vlm(state: DetectState) -> dict:
    _emit(state, "escalate_vlm", "running")

    with trace_node(state, "escalate_vlm") as span:
        profile = _load_profile(state)
        vlm_config = get_stage_config(profile, "escalate_vlm")
        vlm_template = vlm_config.get("vlm_prompt_template", "Identify the brand in this image.")
        candidates = state.get("unresolved_candidates", [])
        job_id = state.get("job_id")

        vlm_prompt_fn = lambda ctx: build_vlm_prompt(ctx, vlm_template)

        vlm_matched, still_unresolved = escalate_vlm(
            candidates,
            vlm_prompt_fn=vlm_prompt_fn,
            inference_url=INFERENCE_URL,
            content_type=profile["name"],
            source_asset_id=state.get("source_asset_id"),
            job_id=job_id,
        )

        stats = state.get("stats", PipelineStats())
        stats.regions_escalated_to_local_vlm = len(candidates)
        span.set_output({"candidates": len(candidates), "matched": len(vlm_matched),
                         "still_unresolved": len(still_unresolved)})

    existing = state.get("detections", [])

    vlm_skipped = os.environ.get("SKIP_VLM", "").strip() == "1"
    _emit(state, "escalate_vlm", "skipped" if vlm_skipped else "done")
    return {
        "detections": existing + vlm_matched,
        "unresolved_candidates": still_unresolved,
        "stats": stats,
    }


def node_escalate_cloud(state: DetectState) -> dict:
    _emit(state, "escalate_cloud", "running")

    with trace_node(state, "escalate_cloud") as span:
        profile = _load_profile(state)
        vlm_config = get_stage_config(profile, "escalate_vlm")
        vlm_template = vlm_config.get("vlm_prompt_template", "Identify the brand in this image.")
        candidates = state.get("unresolved_candidates", [])
        job_id = state.get("job_id")
        stats = state.get("stats", PipelineStats())

        vlm_prompt_fn = lambda ctx: build_vlm_prompt(ctx, vlm_template)

        cloud_matched = escalate_cloud(
            candidates,
            vlm_prompt_fn=vlm_prompt_fn,
            stats=stats,
            content_type=profile["name"],
            source_asset_id=state.get("source_asset_id"),
            job_id=job_id,
        )

        span.set_output({"candidates": len(candidates), "matched": len(cloud_matched),
                         "cloud_calls": stats.cloud_llm_calls,
                         "cost_usd": stats.estimated_cloud_cost_usd})

    existing = state.get("detections", [])

    cloud_skipped = os.environ.get("SKIP_CLOUD", "").strip() == "1"
    _emit(state, "escalate_cloud", "skipped" if cloud_skipped else "done")
    return {"detections": existing + cloud_matched, "stats": stats}


def node_compile_report(state: DetectState) -> dict:
    _emit(state, "compile_report", "running")

    with trace_node(state, "compile_report") as span:
        profile = _load_profile(state)
        detections = state.get("detections", [])
        stats = state.get("stats", PipelineStats())
        job_id = state.get("job_id")

        report = compile_report(
            detections=detections,
            stats=stats,
            video_source=state.get("video_path", ""),
            content_type=profile["name"],
            job_id=job_id,
        )

        span.set_output({"brands": len(report.brands), "detections": len(report.timeline)})

    flush_traces()
    _emit(state, "compile_report", "done")
    return {"report": report}


NODE_FUNCTIONS = [
    ("extract_frames", node_extract_frames),
    ("filter_scenes", node_filter_scenes),
    ("field_segmentation", node_field_segmentation),
    ("detect_edges", node_detect_edges),
    ("detect_objects", node_detect_objects),
    ("preprocess", node_preprocess),
    ("run_ocr", node_run_ocr),
    ("match_brands", node_match_brands),
    ("escalate_vlm", node_escalate_vlm),
    ("escalate_cloud", node_escalate_cloud),
    ("compile_report", node_compile_report),
]
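Stripped of checkpointing, cancellation, and pause handling, the runner's loop over these nodes reduces to folding partial-state dicts into one state. A sketch with placeholder inputs:

```python
from core.detect.graph.nodes import NODE_FUNCTIONS

state = {
    "job_id": "job-1",                  # placeholders throughout
    "timeline_id": "tl-1",
    "video_path": "/data/match.mp4",
    "profile_name": "soccer_broadcast",
    "source_asset_id": "asset-1",
}

for name, fn in NODE_FUNCTIONS:
    state.update(fn(state))             # each node returns a partial-state dict

report = state["report"]                # set by node_compile_report
```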
289
core/detect/graph/runner.py
Normal file
@@ -0,0 +1,289 @@
"""
Pipeline runner — executes stages sequentially with checkpointing,
cancellation, and pause/resume.

Reads PipelineConfig from the profile to determine what stages to run.
Flattens the graph into a linear sequence for now (serial execution).
Executor socket: all stages run via LocalExecutor (call function directly).
"""

from __future__ import annotations

import logging
import os
import threading

from core.detect.stages.models import PipelineConfig
from core.detect.state import DetectState
from .nodes import NODES, NODE_FUNCTIONS

logger = logging.getLogger(__name__)


_CHECKPOINT_ENABLED = os.environ.get("MPR_CHECKPOINT", "").strip() == "1"


class PipelineCancelled(Exception):
    """Raised when a pipeline run is cancelled."""
    pass


class PipelinePaused(Exception):
    """Raised when a pipeline is paused (internally, for flow control)."""
    pass


# ---------------------------------------------------------------------------
# Cancellation — checked before each node
# ---------------------------------------------------------------------------

_cancel_check: dict[str, callable] = {}


def set_cancel_check(job_id: str, fn):
    _cancel_check[job_id] = fn


def clear_cancel_check(job_id: str):
    _cancel_check.pop(job_id, None)


# ---------------------------------------------------------------------------
# Pause / Resume / Step — checked after each node completes
# ---------------------------------------------------------------------------

_pause_gate: dict[str, threading.Event] = {}
_pause_after_stage: dict[str, bool] = {}


def init_pause(job_id: str, pause_after_stage: bool = False):
    """Initialize pause state for a job. Called when pipeline starts."""
    gate = threading.Event()
    gate.set()  # start unpaused
    _pause_gate[job_id] = gate
    _pause_after_stage[job_id] = pause_after_stage


def clear_pause(job_id: str):
    """Clean up pause state. Called when pipeline finishes."""
    _pause_gate.pop(job_id, None)
    _pause_after_stage.pop(job_id, None)


def pause_pipeline(job_id: str):
    """Pause a running pipeline. It will block after the current stage completes."""
    gate = _pause_gate.get(job_id)
    if gate:
        gate.clear()
        logger.info("Pipeline %s paused", job_id)


def resume_pipeline(job_id: str):
    """Resume a paused pipeline."""
    gate = _pause_gate.get(job_id)
    if gate:
        gate.set()
        logger.info("Pipeline %s resumed", job_id)


def step_pipeline(job_id: str):
    """Run one stage then pause again."""
    _pause_after_stage[job_id] = True
    gate = _pause_gate.get(job_id)
    if gate:
        gate.set()
        logger.info("Pipeline %s stepping", job_id)


def set_pause_after_stage(job_id: str, enabled: bool):
    """Toggle pause-after-each-stage mode."""
    _pause_after_stage[job_id] = enabled
    if not enabled:
        gate = _pause_gate.get(job_id)
        if gate:
            gate.set()


def is_paused(job_id: str) -> bool:
    """Check if a pipeline is currently paused."""
    gate = _pause_gate.get(job_id)
    return gate is not None and not gate.is_set()


def _wait_if_paused(job_id: str, node_name: str):
    """Block until resumed. Called after each node completes."""
    gate = _pause_gate.get(job_id)
    if gate is None:
        return

    if _pause_after_stage.get(job_id, False):
        gate.clear()
        from core.detect import emit
        emit.log(job_id, "Pipeline", "INFO", f"Paused after {node_name}")

    while not gate.wait(timeout=0.5):
        check = _cancel_check.get(job_id)
        if check and check():
            raise PipelineCancelled("Cancelled while paused before next stage")


# ---------------------------------------------------------------------------
# Pipeline Runner
# ---------------------------------------------------------------------------

# Node function lookup — maps stage name to callable
_NODE_FN_MAP: dict[str, callable] = {name: fn for name, fn in NODE_FUNCTIONS}


def _flatten_config(config: PipelineConfig, start_from: str | None = None) -> list[str]:
    """
    Flatten a PipelineConfig into a linear stage sequence.

    For now: topological sort via edges. Falls back to stage order if no edges.
    Respects start_from for replay (skip stages before it).
    """
    if not config.edges:
        # No edges defined — use stage order as-is
        names = [s.name for s in config.stages]
    else:
        # Topological sort from edges
        graph: dict[str, list[str]] = {}
        in_degree: dict[str, int] = {}
        stage_names = {s.name for s in config.stages}

        for name in stage_names:
            graph[name] = []
            in_degree[name] = 0

        for edge in config.edges:
            if edge.source in stage_names and edge.target in stage_names:
                graph[edge.source].append(edge.target)
                in_degree[edge.target] = in_degree.get(edge.target, 0) + 1

        # Kahn's algorithm
        queue = [n for n in stage_names if in_degree.get(n, 0) == 0]
        # Stable sort: prefer order from config.stages
        stage_order = {s.name: i for i, s in enumerate(config.stages)}
        queue.sort(key=lambda n: stage_order.get(n, 999))

        names = []
        while queue:
            node = queue.pop(0)
            names.append(node)
            for neighbor in graph.get(node, []):
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)
            queue.sort(key=lambda n: stage_order.get(n, 999))

    if start_from:
||||
try:
|
||||
idx = names.index(start_from)
|
||||
names = names[idx:]
|
||||
except ValueError:
|
||||
raise ValueError(f"Stage {start_from!r} not in pipeline config")
|
||||
|
||||
return names
|
||||
|
||||
|
||||
class PipelineRunner:
|
||||
"""
|
||||
Executes a pipeline defined by PipelineConfig.
|
||||
|
||||
Runs stages sequentially (flattened). Each stage:
|
||||
1. Check cancel
|
||||
2. Run node function (via executor — local for now)
|
||||
3. Merge result into state
|
||||
4. Checkpoint (if enabled)
|
||||
5. Check pause
|
||||
|
||||
Executor socket: currently calls node functions directly.
|
||||
Future: dispatch to LocalExecutor / GrpcExecutor / LambdaExecutor
|
||||
based on StageRef.execution_target.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PipelineConfig,
|
||||
checkpoint: bool = False,
|
||||
start_from: str | None = None,
|
||||
):
|
||||
self.config = config
|
||||
self.do_checkpoint = checkpoint
|
||||
self.stage_sequence = _flatten_config(config, start_from)
|
||||
# Build edge transform lookup: {target_stage: {source_stage: transform_dict}}
|
||||
self._edge_transforms: dict[str, dict[str, dict]] = {}
|
||||
for edge in config.edges:
|
||||
if edge.transform:
|
||||
if edge.target not in self._edge_transforms:
|
||||
self._edge_transforms[edge.target] = {}
|
||||
self._edge_transforms[edge.target][edge.source] = edge.transform
|
||||
|
||||
def invoke(self, state: DetectState) -> DetectState:
|
||||
"""Run the pipeline on the given state. Returns final state."""
|
||||
for stage_name in self.stage_sequence:
|
||||
job_id = state.get("job_id", "")
|
||||
|
||||
# 1. Cancel check
|
||||
check = _cancel_check.get(job_id)
|
||||
if check and check():
|
||||
raise PipelineCancelled(f"Cancelled before {stage_name}")
|
||||
|
||||
# Inject edge transforms into state so the stage can read them.
|
||||
# Compatible with LangGraph — just a state dict key.
|
||||
transforms = self._edge_transforms.get(stage_name, {})
|
||||
if transforms:
|
||||
state["_edge_transforms"] = transforms
|
||||
elif "_edge_transforms" in state:
|
||||
del state["_edge_transforms"]
|
||||
|
||||
# 2. Run node function
|
||||
node_fn = _NODE_FN_MAP.get(stage_name)
|
||||
if node_fn is None:
|
||||
logger.warning("No node function for stage %s, skipping", stage_name)
|
||||
continue
|
||||
|
||||
result = node_fn(state)
|
||||
|
||||
# 3. Merge result into state
|
||||
state.update(result)
|
||||
|
||||
# 4. Checkpoint
|
||||
if self.do_checkpoint:
|
||||
from core.detect.checkpoint import checkpoint_after_stage
|
||||
checkpoint_after_stage(job_id, stage_name, state, result)
|
||||
|
||||
# 5. Pause check
|
||||
_wait_if_paused(job_id, stage_name)
|
||||
|
||||
return state
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API — backwards compatible with old get_pipeline/build_graph
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def get_pipeline(
|
||||
checkpoint: bool | None = None,
|
||||
profile_name: str = "soccer_broadcast",
|
||||
start_from: str | None = None,
|
||||
) -> PipelineRunner:
|
||||
"""Return a PipelineRunner for the given profile."""
|
||||
from core.detect.profile import get_profile, pipeline_config_from_dict
|
||||
|
||||
do_checkpoint = checkpoint if checkpoint is not None else _CHECKPOINT_ENABLED
|
||||
profile = get_profile(profile_name)
|
||||
config = pipeline_config_from_dict(profile["pipeline"])
|
||||
|
||||
return PipelineRunner(
|
||||
config=config,
|
||||
checkpoint=do_checkpoint,
|
||||
start_from=start_from,
|
||||
)
|
||||
|
||||
|
||||
def build_graph(checkpoint: bool | None = None, start_from: str | None = None):
|
||||
"""Backwards-compatible wrapper. Returns a PipelineRunner."""
|
||||
return get_pipeline(checkpoint=checkpoint, start_from=start_from)
|
||||
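Reviewer note: a minimal sketch of driving the pause/step controls above from another thread, using only the functions this file defines. The job id and the bare state dict are illustrative, and `get_pipeline` assumes a profile row exists in the DB — real runs build a full `DetectState`.

```python
# Hypothetical driver for pause/step — "demo-job" is not a real job id.
import threading

from core.detect.graph.runner import (
    clear_pause, get_pipeline, init_pause, resume_pipeline, step_pipeline,
)

job_id = "demo-job"
runner = get_pipeline(profile_name="soccer_broadcast")

init_pause(job_id, pause_after_stage=True)   # gate closes after every stage
state = {"job_id": job_id}                   # real runs carry frames, config, etc.

worker = threading.Thread(target=runner.invoke, args=(state,))
worker.start()

step_pipeline(job_id)    # release the gate for exactly one stage
resume_pipeline(job_id)  # or run freely until pause_pipeline(job_id) is called

worker.join()
clear_pause(job_id)
```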
4  core/detect/inference/__init__.py  Normal file
@@ -0,0 +1,4 @@
from .client import InferenceClient
from .types import DetectResult, OCRResult, VLMResult

__all__ = ["InferenceClient", "DetectResult", "OCRResult", "VLMResult"]
262  core/detect/inference/client.py  Normal file
@@ -0,0 +1,262 @@
"""
HTTP client for the inference server.

The pipeline stages call this instead of importing ML libraries directly.
The inference server runs on the GPU machine (or spot instance).
"""

from __future__ import annotations

import base64
import io
import logging
import os

import numpy as np
import requests
from PIL import Image

from .types import DetectResult, OCRResult, RegionDebugResult, RegionResult, ServerStatus, VLMResult

logger = logging.getLogger(__name__)

DEFAULT_URL = os.environ.get("INFERENCE_URL", "http://localhost:8000")


def _encode_image(image: np.ndarray) -> str:
    """Encode numpy array as base64 JPEG."""
    img = Image.fromarray(image)
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=85)
    return base64.b64encode(buf.getvalue()).decode()


class InferenceClient:
    """HTTP client for the GPU inference server."""

    def __init__(self, base_url: str | None = None, timeout: float = 60.0,
                 job_id: str = "", log_level: str = "INFO"):
        self.base_url = (base_url or DEFAULT_URL).rstrip("/")
        self.timeout = timeout
        self.job_id = job_id
        self.log_level = log_level
        self.session = requests.Session()
        if job_id:
            self.session.headers["X-Job-Id"] = job_id
            self.session.headers["X-Log-Level"] = log_level

    def health(self) -> ServerStatus:
        """Check server health and loaded models."""
        resp = self.session.get(f"{self.base_url}/health", timeout=self.timeout)
        resp.raise_for_status()
        data = resp.json()
        return ServerStatus(
            loaded_models=data.get("loaded_models", []),
            vram_used_mb=data.get("vram_used_mb", 0),
            vram_budget_mb=data.get("vram_budget_mb", 0),
            strategy=data.get("strategy", "sequential"),
        )

    def detect(
        self,
        image: np.ndarray,
        model: str = "yolov8n",
        confidence: float = 0.3,
        target_classes: list[str] | None = None,
    ) -> list[DetectResult]:
        """Run object detection on an image."""
        payload = {
            "image": _encode_image(image),
            "model": model,
            "confidence": confidence,
        }
        if target_classes:
            payload["target_classes"] = target_classes

        resp = self.session.post(
            f"{self.base_url}/detect",
            json=payload,
            timeout=self.timeout,
        )
        resp.raise_for_status()

        results = []
        for d in resp.json().get("detections", []):
            result = DetectResult(
                x=d["x"], y=d["y"], w=d["w"], h=d["h"],
                confidence=d["confidence"], label=d["label"],
            )
            results.append(result)
        return results

    def ocr(
        self,
        image: np.ndarray,
        languages: list[str] | None = None,
    ) -> list[OCRResult]:
        """Run OCR on an image region."""
        payload = {
            "image": _encode_image(image),
        }
        if languages:
            payload["languages"] = languages

        resp = self.session.post(
            f"{self.base_url}/ocr",
            json=payload,
            timeout=self.timeout,
        )
        resp.raise_for_status()

        results = []
        for d in resp.json().get("results", []):
            result = OCRResult(
                text=d["text"],
                confidence=d["confidence"],
                bbox=tuple(d["bbox"]),
            )
            results.append(result)
        return results

    def vlm(
        self,
        image: np.ndarray,
        prompt: str,
        model: str = "moondream2",
    ) -> VLMResult:
        """Query a visual language model with an image crop + prompt."""
        payload = {
            "image": _encode_image(image),
            "prompt": prompt,
            "model": model,
        }

        resp = self.session.post(
            f"{self.base_url}/vlm",
            json=payload,
            timeout=self.timeout,
        )
        resp.raise_for_status()

        data = resp.json()
        return VLMResult(
            brand=data.get("brand", ""),
            confidence=data.get("confidence", 0.0),
            reasoning=data.get("reasoning", ""),
        )

    def detect_edges(
        self,
        image: np.ndarray,
        edge_canny_low: int = 50,
        edge_canny_high: int = 150,
        edge_hough_threshold: int = 80,
        edge_hough_min_length: int = 100,
        edge_hough_max_gap: int = 10,
        edge_pair_max_distance: int = 200,
        edge_pair_min_distance: int = 15,
    ) -> list[RegionResult]:
        """Run edge detection on an image."""
        payload = {
            "image": _encode_image(image),
            "edge_canny_low": edge_canny_low,
            "edge_canny_high": edge_canny_high,
            "edge_hough_threshold": edge_hough_threshold,
            "edge_hough_min_length": edge_hough_min_length,
            "edge_hough_max_gap": edge_hough_max_gap,
            "edge_pair_max_distance": edge_pair_max_distance,
            "edge_pair_min_distance": edge_pair_min_distance,
        }

        resp = self.session.post(
            f"{self.base_url}/detect_edges",
            json=payload,
            timeout=self.timeout,
        )
        resp.raise_for_status()

        results = []
        for r in resp.json().get("regions", []):
            result = RegionResult(
                x=r["x"], y=r["y"], w=r["w"], h=r["h"],
                confidence=r["confidence"], label=r["label"],
            )
            results.append(result)
        return results

    def detect_edges_debug(
        self,
        image: np.ndarray,
        edge_canny_low: int = 50,
        edge_canny_high: int = 150,
        edge_hough_threshold: int = 80,
        edge_hough_min_length: int = 100,
        edge_hough_max_gap: int = 10,
        edge_pair_max_distance: int = 200,
        edge_pair_min_distance: int = 15,
    ) -> RegionDebugResult:
        """Run edge detection with debug overlays."""
        payload = {
            "image": _encode_image(image),
            "edge_canny_low": edge_canny_low,
            "edge_canny_high": edge_canny_high,
            "edge_hough_threshold": edge_hough_threshold,
            "edge_hough_min_length": edge_hough_min_length,
            "edge_hough_max_gap": edge_hough_max_gap,
            "edge_pair_max_distance": edge_pair_max_distance,
            "edge_pair_min_distance": edge_pair_min_distance,
        }

        resp = self.session.post(
            f"{self.base_url}/detect_edges/debug",
            json=payload,
            timeout=self.timeout,
        )
        resp.raise_for_status()

        data = resp.json()
        regions = []
        for r in data.get("regions", []):
            region = RegionResult(
                x=r["x"], y=r["y"], w=r["w"], h=r["h"],
                confidence=r["confidence"], label=r["label"],
            )
            regions.append(region)

        return RegionDebugResult(
            regions=regions,
            edge_overlay_b64=data.get("edge_overlay_b64", ""),
            lines_overlay_b64=data.get("lines_overlay_b64", ""),
            horizontal_count=data.get("horizontal_count", 0),
            pair_count=data.get("pair_count", 0),
        )

    def post(self, path: str, payload: dict) -> dict | None:
        """Generic POST to the inference server. Returns JSON response or None on error."""
        try:
            resp = self.session.post(
                f"{self.base_url}{path}",
                json=payload,
                timeout=self.timeout,
            )
            resp.raise_for_status()
            return resp.json()
        except Exception as e:
            logger.warning("Inference POST %s failed: %s", path, e)
            return None

    def load_model(self, model: str, quantization: str = "fp16") -> None:
        """Request the server to load a model into VRAM."""
        self.session.post(
            f"{self.base_url}/models/load",
            json={"model": model, "quantization": quantization},
            timeout=self.timeout,
        ).raise_for_status()

    def unload_model(self, model: str) -> None:
        """Request the server to unload a model from VRAM."""
        self.session.post(
            f"{self.base_url}/models/unload",
            json={"model": model},
            timeout=self.timeout,
        ).raise_for_status()
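Reviewer note: typical call path for the client above — detect, crop, OCR, then escalate to the local VLM when OCR comes back empty. The file path, job id, and prompt are placeholders; it assumes an inference server is listening at `INFERENCE_URL`.

```python
# Sketch: one frame through the escalation ladder (detect → OCR → VLM).
import numpy as np
from PIL import Image

from core.detect.inference.client import InferenceClient

client = InferenceClient(job_id="demo-job")           # job id is illustrative
frame = np.asarray(Image.open("frame.jpg").convert("RGB"))  # placeholder input

for det in client.detect(frame, confidence=0.3):
    crop = frame[det.y:det.y + det.h, det.x:det.x + det.w]
    texts = client.ocr(crop)
    if not texts:                                     # nothing readable → escalate
        answer = client.vlm(crop, prompt="Which brand logo is this?")
        print(det.label, answer.brand, answer.confidence)
```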
76  core/detect/inference/types.py  Normal file
@@ -0,0 +1,76 @@
"""
Inference response types.

These are the shapes returned by the inference server.
Kept separate from core.detect.models to avoid coupling the
inference protocol to pipeline internals.
"""

from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class DetectResult:
    """Single object detection from YOLO or similar."""
    x: int
    y: int
    w: int
    h: int
    confidence: float
    label: str


@dataclass
class OCRResult:
    """Text extracted from a region."""
    text: str
    confidence: float
    bbox: tuple[int, int, int, int]  # x, y, w, h


@dataclass
class VLMResult:
    """Visual language model response for a crop."""
    brand: str
    confidence: float
    reasoning: str


@dataclass
class RegionResult:
    """A candidate region from CV analysis."""
    x: int
    y: int
    w: int
    h: int
    confidence: float
    label: str


@dataclass
class RegionDebugResult:
    """CV region analysis with debug overlays."""
    regions: list[RegionResult] = field(default_factory=list)
    edge_overlay_b64: str = ""
    lines_overlay_b64: str = ""
    horizontal_count: int = 0
    pair_count: int = 0


@dataclass
class ModelInfo:
    """Info about a loaded model."""
    name: str
    vram_mb: float
    quantization: str  # fp32, fp16, int8, int4


@dataclass
class ServerStatus:
    """Inference server health response."""
    loaded_models: list[ModelInfo] = field(default_factory=list)
    vram_used_mb: float = 0.0
    vram_budget_mb: float = 0.0
    strategy: str = "sequential"  # sequential, concurrent, auto
95  core/detect/models.py  Normal file
@@ -0,0 +1,95 @@
"""
Detection pipeline runtime models.

These are the data structures that flow between pipeline stages.
They contain runtime types (np.ndarray) so they live here, not in
core/schema/models/ (which is for modelgen source of truth).
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Literal

import numpy as np


@dataclass
class Frame:
    sequence: int
    chunk_id: int
    timestamp: float  # position in video (seconds)
    image: np.ndarray
    perceptual_hash: str = ""


@dataclass
class BoundingBox:
    x: int
    y: int
    w: int
    h: int
    confidence: float
    label: str


@dataclass
class TextCandidate:
    frame: Frame
    bbox: BoundingBox
    text: str
    ocr_confidence: float


@dataclass
class BrandDetection:
    brand: str
    timestamp: float
    duration: float
    confidence: float
    source: Literal["ocr", "local_vlm", "cloud_llm", "logo_match", "auxiliary"]
    bbox: BoundingBox | None = None
    frame_ref: int | None = None
    content_type: str = ""


@dataclass
class BrandStats:
    total_appearances: int = 0
    total_screen_time: float = 0.0
    avg_confidence: float = 0.0
    first_seen: float = 0.0
    last_seen: float = 0.0


@dataclass
class PipelineStats:
    frames_extracted: int = 0
    frames_after_scene_filter: int = 0
    cv_regions_detected: int = 0
    regions_detected: int = 0
    regions_resolved_by_ocr: int = 0
    regions_escalated_to_local_vlm: int = 0
    regions_escalated_to_cloud_llm: int = 0
    auxiliary_detections: int = 0
    cloud_llm_calls: int = 0
    processing_time_seconds: float = 0.0
    estimated_cloud_cost_usd: float = 0.0


@dataclass
class DetectionReport:
    video_source: str
    content_type: str
    duration_seconds: float
    brands: dict[str, BrandStats] = field(default_factory=dict)
    timeline: list[BrandDetection] = field(default_factory=list)
    pipeline_stats: PipelineStats = field(default_factory=PipelineStats)


@dataclass
class CropContext:
    """Runtime type — holds image bytes for VLM prompts."""
    image: bytes
    surrounding_text: str = ""
    position_hint: str = ""
107  core/detect/profile.py  Normal file
@@ -0,0 +1,107 @@
"""
Profile registry and helpers.

Loads profile data from Postgres.
A profile is a dict with keys: name, pipeline, configs.
"""

from __future__ import annotations

import logging
from typing import Any, Dict

from core.detect.stages.models import PipelineConfig, StageRef, Edge
from core.detect.models import (
    BrandDetection,
    BrandStats,
    CropContext,
    DetectionReport,
    PipelineStats,
)

logger = logging.getLogger(__name__)


def get_profile(name: str) -> Dict[str, Any]:
    """Get a profile dict by name from the database."""
    from core.db.connection import get_session
    from core.db.models import Profile

    with get_session() as session:
        row = session.query(Profile).filter(Profile.name == name).first()

        if row is None:
            raise ValueError(f"Unknown profile: {name!r}")

        return {
            "name": row.name,
            "pipeline": row.pipeline or {},
            "configs": row.configs or {},
        }


def list_profiles() -> list[str]:
    """List available profile names from the database."""
    from core.db.connection import get_session
    from core.db.models import Profile

    with get_session() as session:
        rows = session.query(Profile.name).all()

    return [r[0] for r in rows]


def get_stage_config(profile: Dict[str, Any], stage_name: str) -> dict:
    """Get config values for a stage from a profile."""
    return profile.get("configs", {}).get(stage_name, {})


def pipeline_config_from_dict(data: Dict[str, Any]) -> PipelineConfig:
    """Deserialize a PipelineConfig from a JSONB dict."""
    stages = [StageRef(**s) for s in data.get("stages", [])]
    edges = [Edge(**e) for e in data.get("edges", [])]
    return PipelineConfig(
        name=data.get("name", ""),
        profile_name=data.get("profile_name", ""),
        stages=stages,
        edges=edges,
        routing_rules=data.get("routing_rules", {}),
    )


def build_vlm_prompt(crop_context: CropContext, template: str) -> str:
    """Build a VLM prompt from a template and crop context."""
    hint = f" Position: {crop_context.position_hint}." if crop_context.position_hint else ""
    text = f" Nearby text: '{crop_context.surrounding_text}'." if crop_context.surrounding_text else ""
    return template.format(hint=hint, text=text)


def aggregate_detections(
    detections: list[BrandDetection],
    content_type: str,
) -> DetectionReport:
    """Group detections by brand into a report."""
    brands: dict[str, BrandStats] = {}
    for d in detections:
        if d.brand not in brands:
            brands[d.brand] = BrandStats()
        s = brands[d.brand]
        s.total_appearances += 1
        s.total_screen_time += d.duration
        s.avg_confidence = (
            (s.avg_confidence * (s.total_appearances - 1) + d.confidence)
            / s.total_appearances
        )
        if s.first_seen == 0.0 or d.timestamp < s.first_seen:
            s.first_seen = d.timestamp
        if d.timestamp > s.last_seen:
            s.last_seen = d.timestamp

    return DetectionReport(
        video_source="",
        content_type=content_type,
        duration_seconds=0.0,
        brands=brands,
        timeline=sorted(detections, key=lambda d: d.timestamp),
        pipeline_stats=PipelineStats(),
    )
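Reviewer note: the incremental mean in `aggregate_detections` avoids keeping per-detection confidences around. A quick worked check with a hypothetical brand name:

```python
# Confidences 0.9, 0.6, 0.6 should average to (0.9 + 0.6 + 0.6) / 3 = 0.7.
from core.detect.models import BrandDetection
from core.detect.profile import aggregate_detections

dets = [
    BrandDetection(brand="acme", timestamp=t, duration=0.5, confidence=c, source="ocr")
    for t, c in [(1.0, 0.9), (3.0, 0.6), (5.0, 0.6)]
]
report = aggregate_detections(dets, content_type="soccer_broadcast")
stats = report.brands["acme"]
assert stats.total_appearances == 3
assert abs(stats.avg_confidence - 0.7) < 1e-9   # (0.9·1+0.6)/2=0.75, (0.75·2+0.6)/3=0.7
assert (stats.first_seen, stats.last_seen) == (1.0, 5.0)
```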
58  core/detect/providers/__init__.py  Normal file
@@ -0,0 +1,58 @@
"""
Cloud LLM provider registry.

Select provider via CLOUD_LLM_PROVIDER env var.
Each provider reads its own env vars for auth/config.

CLOUD_LLM_PROVIDER=groq   → GROQ_API_KEY, GROQ_MODEL, GROQ_BASE_URL
CLOUD_LLM_PROVIDER=gemini → GEMINI_API_KEY, GEMINI_MODEL
CLOUD_LLM_PROVIDER=openai → OPENAI_API_KEY, OPENAI_MODEL, OPENAI_BASE_URL
CLOUD_LLM_PROVIDER=claude → ANTHROPIC_API_KEY, CLAUDE_MODEL
"""

from __future__ import annotations

import os

from .base import CloudProvider, ProviderResponse
from .groq import GroqProvider
from .gemini import GeminiProvider
from .openai_compat import OpenAICompatProvider
from .claude import ClaudeProvider

PROVIDERS: dict[str, type] = {
    "groq": GroqProvider,
    "gemini": GeminiProvider,
    "openai": OpenAICompatProvider,
    "claude": ClaudeProvider,
}

_cached: CloudProvider | None = None


def get_provider() -> CloudProvider:
    """Get the configured cloud provider (cached after first call)."""
    global _cached
    if _cached is not None:
        return _cached

    name = os.environ.get("CLOUD_LLM_PROVIDER", "groq")
    cls = PROVIDERS.get(name)
    if cls is None:
        raise ValueError(f"Unknown provider: {name!r}. Options: {list(PROVIDERS)}")

    _cached = cls()
    return _cached


def has_api_key() -> bool:
    """Check if the configured provider has an API key set."""
    name = os.environ.get("CLOUD_LLM_PROVIDER", "groq")
    key_map = {
        "groq": "GROQ_API_KEY",
        "gemini": "GEMINI_API_KEY",
        "openai": "OPENAI_API_KEY",
        "claude": "ANTHROPIC_API_KEY",
    }
    env_var = key_map.get(name, "")
    return bool(os.environ.get(env_var, ""))
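Reviewer note: a sketch of how the cloud escalation stage would pick and call a provider through this registry, gated on `has_api_key()`. The prompt text and the base64 placeholder are invented for illustration.

```python
# Provider selection is env-driven; the call itself is provider-agnostic.
import os

from core.detect.providers import get_provider, has_api_key

os.environ.setdefault("CLOUD_LLM_PROVIDER", "groq")

if has_api_key():
    provider = get_provider()          # instance is cached after the first call
    resp = provider.call(
        image_b64="<jpeg-base64>",     # crop produced earlier in the pipeline
        prompt="Name the brand on this hoarding. Answer with the brand only.",
    )
    print(provider.name, resp.answer, resp.total_tokens)
else:
    print("cloud escalation disabled — no API key for the selected provider")
```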
36  core/detect/providers/base.py  Normal file
@@ -0,0 +1,36 @@
"""Cloud LLM provider protocol and model metadata."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Protocol


@dataclass
class ModelInfo:
    """Metadata for a cloud LLM model."""
    id: str
    vision: bool = True
    cost_per_input_token: float = 0.0
    cost_per_output_token: float = 0.0
    max_output_tokens: int = 4096
    notes: str = ""


@dataclass
class ProviderResponse:
    answer: str
    total_tokens: int = 0


class CloudProvider(Protocol):
    """
    Interface for cloud LLM providers.

    Each provider handles its own auth, payload format, and response parsing.
    The pipeline only calls call() and reads the response.
    """
    name: str
    models: dict[str, ModelInfo]

    def call(self, image_b64: str, prompt: str) -> ProviderResponse: ...
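Reviewer note: because `CloudProvider` is a structural `Protocol`, a new backend only has to match the shape — no inheritance and nothing to change in this file. A stub that satisfies it (handy for offline tests; wiring it into `PROVIDERS` would be a separate, hypothetical change):

```python
# StubProvider is illustrative — not registered anywhere in this PR.
from core.detect.providers.base import CloudProvider, ModelInfo, ProviderResponse


class StubProvider:
    name = "stub"
    models = {"stub-v0": ModelInfo(id="stub-v0", notes="returns a canned answer")}

    def call(self, image_b64: str, prompt: str) -> ProviderResponse:
        # No network: echo a fixed answer so the escalation stage can dry-run.
        return ProviderResponse(answer="unknown", total_tokens=0)


provider: CloudProvider = StubProvider()  # structural conformance, no subclassing
```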
73  core/detect/providers/claude.py  Normal file
@@ -0,0 +1,73 @@
"""Anthropic Claude provider — uses the official SDK."""

from __future__ import annotations

import logging
import os

from .base import ModelInfo, ProviderResponse

logger = logging.getLogger(__name__)

# Claude-specific env vars
# ANTHROPIC_API_KEY is read by the SDK automatically
CLAUDE_MODEL = os.environ.get("CLAUDE_MODEL", "claude-sonnet-4-20250514")

MODELS = {
    "claude-sonnet-4-20250514": ModelInfo(
        id="claude-sonnet-4-20250514",
        vision=True,
        cost_per_input_token=0.000003,
        cost_per_output_token=0.000015,
        notes="Best balance of quality/cost with vision",
    ),
    "claude-haiku-4-5-20251001": ModelInfo(
        id="claude-haiku-4-5-20251001",
        vision=True,
        cost_per_input_token=0.0000008,
        cost_per_output_token=0.000004,
        notes="Fastest, cheapest, good for simple brand ID",
    ),
    "claude-opus-4-6": ModelInfo(
        id="claude-opus-4-6",
        vision=True,
        cost_per_input_token=0.000015,
        cost_per_output_token=0.000075,
        notes="Highest quality, use for ambiguous cases",
    ),
}


class ClaudeProvider:
    name = "claude"
    models = MODELS

    def __init__(self):
        from anthropic import Anthropic
        self.client = Anthropic()
        self.model = CLAUDE_MODEL

    def call(self, image_b64: str, prompt: str) -> ProviderResponse:
        message = self.client.messages.create(
            model=self.model,
            max_tokens=150,
            messages=[{
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": image_b64,
                        },
                    },
                    {"type": "text", "text": prompt},
                ],
            }],
        )

        answer = message.content[0].text.strip()
        total_tokens = message.usage.input_tokens + message.usage.output_tokens

        return ProviderResponse(answer=answer, total_tokens=total_tokens)
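Reviewer note: the per-token rates in `MODELS` make escalation cost easy to sanity-check. A back-of-envelope calculation with assumed token counts (roughly one JPEG crop plus a short answer — the counts are made up, the rates are the ones declared above):

```python
from core.detect.providers.claude import MODELS

m = MODELS["claude-haiku-4-5-20251001"]
input_tokens, output_tokens = 1600, 40   # assumed: one image crop + short reply
cost = input_tokens * m.cost_per_input_token + output_tokens * m.cost_per_output_token
# 1600 * 0.0000008 + 40 * 0.000004 = 0.00128 + 0.00016 = $0.00144 per call
print(f"{m.id}: ${cost:.5f} per call, ${cost * 1000:.2f} per 1k escalations")
```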
75  core/detect/providers/gemini.py  Normal file
@@ -0,0 +1,75 @@
"""Google Gemini provider — native REST API, not OpenAI-compatible."""

from __future__ import annotations

import logging
import os

import requests

from .base import ModelInfo, ProviderResponse

logger = logging.getLogger(__name__)

# Gemini-specific env vars
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-2.0-flash")

MODELS = {
    "gemini-2.0-flash": ModelInfo(
        id="gemini-2.0-flash",
        vision=True,
        cost_per_input_token=0.0000001,
        cost_per_output_token=0.0000004,
        notes="Fast, cheap, good vision",
    ),
    "gemini-2.0-pro": ModelInfo(
        id="gemini-2.0-pro",
        vision=True,
        cost_per_input_token=0.00000125,
        cost_per_output_token=0.000005,
        notes="Higher quality, slower",
    ),
    "gemini-1.5-flash": ModelInfo(
        id="gemini-1.5-flash",
        vision=True,
        cost_per_input_token=0.000000075,
        cost_per_output_token=0.0000003,
        notes="Cheapest option",
    ),
}


class GeminiProvider:
    name = "gemini"
    models = MODELS

    def __init__(self):
        self.api_key = GEMINI_API_KEY
        self.model = GEMINI_MODEL
        self.endpoint = (
            f"https://generativelanguage.googleapis.com/v1beta/models/"
            f"{self.model}:generateContent"
        )

    def call(self, image_b64: str, prompt: str) -> ProviderResponse:
        payload = {
            "contents": [{
                "parts": [
                    {"text": prompt},
                    {"inline_data": {"mime_type": "image/jpeg", "data": image_b64}},
                ],
            }],
            "generationConfig": {"maxOutputTokens": 150},
        }

        url = f"{self.endpoint}?key={self.api_key}"
        resp = requests.post(url, json=payload, timeout=30)
        resp.raise_for_status()
        data = resp.json()

        answer = data["candidates"][0]["content"]["parts"][0]["text"].strip()
        usage = data.get("usageMetadata", {})
        total_tokens = usage.get("totalTokenCount", 0)

        return ProviderResponse(answer=answer, total_tokens=total_tokens)
66  core/detect/providers/groq.py  Normal file
@@ -0,0 +1,66 @@
"""Groq cloud provider — OpenAI-compatible API with vision."""

from __future__ import annotations

import logging
import os

import requests

from .base import ModelInfo, ProviderResponse

logger = logging.getLogger(__name__)

# Groq-specific env vars
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
GROQ_BASE_URL = os.environ.get("GROQ_BASE_URL", "https://api.groq.com/openai/v1")
GROQ_MODEL = os.environ.get("GROQ_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct")

MODELS = {
    "meta-llama/llama-4-scout-17b-16e-instruct": ModelInfo(
        id="meta-llama/llama-4-scout-17b-16e-instruct",
        vision=True,
        cost_per_input_token=0.0,
        cost_per_output_token=0.0,
        notes="Llama 4 Scout, only vision model on Groq free tier",
    ),
}


class GroqProvider:
    name = "groq"
    models = MODELS

    def __init__(self):
        self.api_key = GROQ_API_KEY
        self.base_url = GROQ_BASE_URL
        self.model = GROQ_MODEL
        self.endpoint = f"{self.base_url.rstrip('/')}/chat/completions"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def call(self, image_b64: str, prompt: str) -> ProviderResponse:
        payload = {
            "model": self.model,
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {
                        "url": f"data:image/jpeg;base64,{image_b64}",
                    }},
                ],
            }],
            "max_tokens": 150,
        }

        resp = requests.post(self.endpoint, headers=self.headers, json=payload, timeout=30)
        resp.raise_for_status()
        data = resp.json()

        answer = data["choices"][0]["message"]["content"].strip()
        total_tokens = data.get("usage", {}).get("total_tokens", 0)

        return ProviderResponse(answer=answer, total_tokens=total_tokens)
73  core/detect/providers/openai_compat.py  Normal file
@@ -0,0 +1,73 @@
"""Generic OpenAI-compatible provider (OpenAI, Together, etc.)."""

from __future__ import annotations

import logging
import os

import requests

from .base import ModelInfo, ProviderResponse

logger = logging.getLogger(__name__)

# OpenAI-compat specific env vars
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4o-mini")

MODELS = {
    "gpt-4o-mini": ModelInfo(
        id="gpt-4o-mini",
        vision=True,
        cost_per_input_token=0.00000015,
        cost_per_output_token=0.0000006,
        notes="Cheap, fast, decent vision",
    ),
    "gpt-4o": ModelInfo(
        id="gpt-4o",
        vision=True,
        cost_per_input_token=0.0000025,
        cost_per_output_token=0.00001,
        notes="Best OpenAI vision model",
    ),
}


class OpenAICompatProvider:
    name = "openai"
    models = MODELS

    def __init__(self):
        self.api_key = OPENAI_API_KEY
        self.base_url = OPENAI_BASE_URL
        self.model = OPENAI_MODEL
        self.endpoint = f"{self.base_url.rstrip('/')}/chat/completions"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def call(self, image_b64: str, prompt: str) -> ProviderResponse:
        payload = {
            "model": self.model,
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {
                        "url": f"data:image/jpeg;base64,{image_b64}",
                    }},
                ],
            }],
            "max_tokens": 150,
        }

        resp = requests.post(self.endpoint, headers=self.headers, json=payload, timeout=30)
        resp.raise_for_status()
        data = resp.json()

        answer = data["choices"][0]["message"]["content"].strip()
        total_tokens = data.get("usage", {}).get("total_tokens", 0)

        return ProviderResponse(answer=answer, total_tokens=total_tokens)
163  core/detect/sse.py  Normal file
@@ -0,0 +1,163 @@
"""
Pydantic Models - GENERATED FILE

Do not edit directly. Regenerate using modelgen.
"""

from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID

from pydantic import BaseModel, Field


class GraphNode(BaseModel):
    """A pipeline stage node."""
    id: str
    status: str = "idle"
    items_in: int = 0
    items_out: int = 0


class GraphEdge(BaseModel):
    """An edge between pipeline stages."""
    source: str
    target: str
    throughput: int = 0


class BoundingBoxEvent(BaseModel):
    """Bounding box in SSE event payloads."""
    x: int
    y: int
    w: int
    h: int
    confidence: float
    label: str
    resolved_brand: Optional[str] = None
    source: Optional[str] = None
    stage: Optional[str] = None


class BrandSummary(BaseModel):
    """Per-brand stats in the final report."""
    brand: str
    total_appearances: int = 0
    total_screen_time: float = 0.0
    avg_confidence: float = 0.0
    first_seen: float = 0.0
    last_seen: float = 0.0


class GraphUpdate(BaseModel):
    """Pipeline node state transition. SSE event: graph_update"""
    nodes: List[GraphNode] = Field(default_factory=list)
    edges: List[GraphEdge] = Field(default_factory=list)
    active_path: List[str] = Field(default_factory=list)


class StatsUpdate(BaseModel):
    """Funnel statistics snapshot. SSE event: stats_update"""
    frames_extracted: int = 0
    frames_after_scene_filter: int = 0
    cv_regions_detected: int = 0
    regions_detected: int = 0
    regions_resolved_by_ocr: int = 0
    regions_escalated_to_local_vlm: int = 0
    regions_escalated_to_cloud_llm: int = 0
    cloud_llm_calls: int = 0
    processing_time_seconds: float = 0.0
    estimated_cloud_cost_usd: float = 0.0
    run_id: Optional[str] = None
    parent_job_id: Optional[str] = None
    run_type: str = "initial"


class FrameUpdate(BaseModel):
    """Current frame being processed. SSE event: frame_update"""
    frame_ref: int
    timestamp: float
    jpeg_b64: str
    boxes: List[BoundingBoxEvent] = Field(default_factory=list)


class Detection(BaseModel):
    """A confirmed brand detection. SSE event: detection"""
    brand: str
    timestamp: float
    duration: float
    confidence: float
    source: str
    content_type: str
    bbox: Optional[BoundingBoxEvent] = None
    frame_ref: Optional[int] = None


class LogEvent(BaseModel):
    """Pipeline log line. SSE event: log"""
    level: str
    stage: str
    msg: str
    ts: str
    trace_id: Optional[str] = None


class DetectionReportSummary(BaseModel):
    """Final detection report summary."""
    video_source: str
    content_type: str
    duration_seconds: float
    total_detections: int = 0
    brands: List[BrandSummary] = Field(default_factory=list)
    stats: Optional[StatsUpdate] = None


class JobComplete(BaseModel):
    """Final report when pipeline finishes. SSE event: job_complete"""
    job_id: str
    report: Optional[DetectionReportSummary] = None


class RunContext(BaseModel):
    """Run context injected into all SSE events for grouping."""
    run_id: str
    parent_job_id: str
    run_type: str = "initial"


class CheckpointInfo(BaseModel):
    """Available checkpoint for a stage."""
    stage: str
    is_scenario: bool = False
    scenario_label: str = ""


class ReplayRequest(BaseModel):
    """Request to replay pipeline from a specific stage."""
    job_id: str
    start_stage: str
    config_overrides: Optional[Dict[str, Any]] = None


class ReplayResponse(BaseModel):
    """Result of a replay invocation."""
    status: str
    job_id: str
    start_stage: str
    detections: int = 0
    brands_found: int = 0


class RetryRequest(BaseModel):
    """Request to queue async retry with different config."""
    job_id: str
    config_overrides: Optional[Dict[str, Any]] = None
    start_stage: str = "escalate_vlm"
    schedule_seconds: Optional[float] = None


class RetryResponse(BaseModel):
    """Result of queueing a retry task."""
    status: str
    task_id: str
    job_id: str


class RunRequest(BaseModel):
    """Request body for launching a detection pipeline run."""
    video_path: str
    profile_name: str = "soccer_broadcast"
    source_asset_id: str = ""
    checkpoint: bool = True
    skip_vlm: bool = False
    skip_cloud: bool = False
    log_level: str = "INFO"


class RunResponse(BaseModel):
    """Response after starting a pipeline run."""
    status: str
    job_id: str
    video_path: str
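Reviewer note: for reference, how one of these models would map onto the SSE wire format. The event name comes from the model docstring; the framing is standard SSE, and `model_dump_json()` assumes pydantic v2 (a v1 install would use `.json()` instead). The field values are placeholders.

```python
from core.detect.sse import LogEvent

event = LogEvent(level="INFO", stage="run_ocr", msg="42 candidates", ts="2025-01-01T00:00:00Z")
frame = f"event: log\ndata: {event.model_dump_json()}\n\n"  # one SSE message
```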
22  core/detect/stages/__init__.py  Normal file
@@ -0,0 +1,22 @@
"""
Pipeline stages.

Each stage is a file with a Stage subclass. Auto-discovered via
__init_subclass__ — importing the file registers the stage.
"""

from .base import (
    Stage,
    get_stage,
    get_stage_instance,
    list_stages,
    list_stage_classes,
    get_palette,
)

# Import all stage files to trigger auto-registration
from . import edge_detector  # noqa: F401
from . import field_segmentation  # noqa: F401

# Import registry for backward compat (other stages still use old pattern)
from . import registry  # noqa: F401
116  core/detect/stages/aggregator.py  Normal file
@@ -0,0 +1,116 @@
"""
Stage 8 — Report compilation

Groups all detections by brand, merges contiguous appearances,
and builds the final DetectionReport.
"""

from __future__ import annotations

import logging

from core.detect import emit
from core.detect.models import BrandDetection, BrandStats, DetectionReport, PipelineStats

logger = logging.getLogger(__name__)


def _merge_contiguous(detections: list[BrandDetection], gap_threshold: float = 2.0) -> list[BrandDetection]:
    """
    Merge detections of the same brand that are close in time.

    If two detections of the same brand are within gap_threshold seconds,
    they're merged into one detection spanning the full range.
    """
    if not detections:
        return []

    sorted_dets = sorted(detections, key=lambda d: (d.brand, d.timestamp))
    merged: list[BrandDetection] = []
    current = sorted_dets[0]

    for det in sorted_dets[1:]:
        if (det.brand == current.brand
                and det.timestamp <= current.timestamp + current.duration + gap_threshold):
            end = max(current.timestamp + current.duration,
                      det.timestamp + det.duration)
            current = BrandDetection(
                brand=current.brand,
                timestamp=current.timestamp,
                duration=end - current.timestamp,
                confidence=max(current.confidence, det.confidence),
                source=current.source,
                bbox=current.bbox,
                frame_ref=current.frame_ref,
                content_type=current.content_type,
            )
        else:
            merged.append(current)
            current = det

    merged.append(current)
    return merged


def compile_report(
    detections: list[BrandDetection],
    stats: PipelineStats,
    video_source: str = "",
    content_type: str = "",
    duration_seconds: float = 0.0,
    job_id: str | None = None,
) -> DetectionReport:
    """
    Build the final detection report from all accumulated detections.

    Merges contiguous detections, computes per-brand stats,
    and emits the job_complete event.
    """
    merged = _merge_contiguous(detections)

    brands: dict[str, BrandStats] = {}
    for d in merged:
        if d.brand not in brands:
            brands[d.brand] = BrandStats()
        s = brands[d.brand]
        s.total_appearances += 1
        s.total_screen_time += d.duration
        s.avg_confidence = (
            (s.avg_confidence * (s.total_appearances - 1) + d.confidence)
            / s.total_appearances
        )
        if s.first_seen == 0.0 or d.timestamp < s.first_seen:
            s.first_seen = d.timestamp
        if d.timestamp > s.last_seen:
            s.last_seen = d.timestamp

    report = DetectionReport(
        video_source=video_source,
        content_type=content_type,
        duration_seconds=duration_seconds,
        brands=brands,
        timeline=sorted(merged, key=lambda d: d.timestamp),
        pipeline_stats=stats,
    )

    emit.log(job_id, "Aggregator", "INFO",
             f"Report: {len(brands)} brands, {len(merged)} detections "
             f"(merged from {len(detections)} raw)")

    emit.job_complete(job_id, {
        "video_source": report.video_source,
        "content_type": report.content_type,
        "duration_seconds": report.duration_seconds,
        "brands": {
            k: {
                "total_appearances": v.total_appearances,
                "total_screen_time": v.total_screen_time,
                "avg_confidence": round(v.avg_confidence, 3),
                "first_seen": v.first_seen,
                "last_seen": v.last_seen,
            }
            for k, v in brands.items()
        },
    })

    return report
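Reviewer note: a worked check of the gap merge in `_merge_contiguous` (the private helper above). Two hits of the same hypothetical brand 1.5 s apart fall inside the default 2 s gap and collapse into one span; a third hit ~10 s later stays separate.

```python
from core.detect.models import BrandDetection
from core.detect.stages.aggregator import _merge_contiguous

dets = [
    BrandDetection(brand="acme", timestamp=10.0, duration=0.5, confidence=0.8, source="ocr"),
    BrandDetection(brand="acme", timestamp=12.0, duration=0.5, confidence=0.9, source="ocr"),  # 12.0 <= 10.5 + 2.0 → merge
    BrandDetection(brand="acme", timestamp=22.0, duration=0.5, confidence=0.7, source="ocr"),  # outside the gap → keep
]
merged = _merge_contiguous(dets)
assert len(merged) == 2
assert merged[0].timestamp == 10.0 and merged[0].duration == 2.5  # spans 10.0–12.5
assert merged[0].confidence == 0.9                                # max of the merged pair
```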
151  core/detect/stages/base.py  Normal file
@@ -0,0 +1,151 @@
"""
Stage base class — common interface for all pipeline stages.

Each stage is a file that subclasses Stage. Auto-discovered via
__init_subclass__. No manual registration needed.

A stage:
- Has a StageDefinition (generated from schema) with name, config, IO
- Implements run(frames, config) → output
- Owns its output serialization (opaque blob)
- Optionally has a TypeScript port for browser-side execution

The checkpoint layer stores stage output as blobs without knowing
the format. The stage that wrote it is the only one that can read it.
"""

from __future__ import annotations

from typing import Any

import numpy as np

from core.detect.stages.models import (
    StageConfigField,
    StageIO,
    StageDefinition,
)


# Legacy runtime extension — adds callable fields for old-style stages.
# New stages use Stage subclass with serialize()/deserialize() methods instead.
class LegacyStageDefinition:
    """Wraps a StageDefinition with callable serialize/deserialize functions."""

    def __init__(self, definition: StageDefinition, fn=None, serialize_fn=None, deserialize_fn=None):
        self._definition = definition
        self.fn = fn
        self.serialize_fn = serialize_fn
        self.deserialize_fn = deserialize_fn

    def __getattr__(self, name):
        return getattr(self._definition, name)


# ---------------------------------------------------------------------------
# Registry — auto-populated by __init_subclass__ (new stages)
# + register_stage() (legacy stages during migration)
# ---------------------------------------------------------------------------

_REGISTRY: dict[str, type['Stage']] = {}
_LEGACY_REGISTRY: dict[str, LegacyStageDefinition] = {}


def register_stage(
    definition: StageDefinition,
    fn=None,
    serialize_fn=None,
    deserialize_fn=None,
):
    """Legacy registration for stages not yet converted to Stage subclass."""
    legacy = LegacyStageDefinition(definition, fn=fn, serialize_fn=serialize_fn, deserialize_fn=deserialize_fn)
    _LEGACY_REGISTRY[definition.name] = legacy


class Stage:
    """
    Base class for all pipeline stages.

    Subclass this in detect/stages/<name>.py. Define `definition` as a
    class attribute. Implement `run()`. Optionally override `serialize()`
    and `deserialize()` for custom blob formats (default is JSON).
    """

    definition: StageDefinition  # set by each subclass

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if hasattr(cls, 'definition') and cls.definition is not None:
            _REGISTRY[cls.definition.name] = cls

    def run(self, frames: list, config: dict) -> Any:
        raise NotImplementedError

    def serialize(self, output: Any) -> bytes:
        """Serialize stage output to bytes for checkpoint storage."""
        import json
        return json.dumps(output, default=str).encode()

    def deserialize(self, data: bytes) -> Any:
        """Deserialize stage output from checkpoint blob."""
        import json
        return json.loads(data)


# ---------------------------------------------------------------------------
# Discovery API
# ---------------------------------------------------------------------------

def _all_definitions():
    """Merge new Stage subclass registry + legacy registry.

    Returns StageDefinition for new-style stages,
    LegacyStageDefinition for legacy stages (has serialize_fn etc).
    """
    merged = {}
    for name, legacy in _LEGACY_REGISTRY.items():
        merged[name] = legacy
    for name, cls in _REGISTRY.items():
        merged[name] = cls.definition
    return merged


def get_stage(name: str) -> StageDefinition:
    """Get a stage definition by name (works for both new and legacy)."""
    all_defs = _all_definitions()
    if name not in all_defs:
        raise KeyError(f"Unknown stage: {name!r}. Registered: {list(all_defs)}")
    return all_defs[name]


def get_stage_class(name: str) -> type[Stage] | None:
    """Get a Stage subclass by name. Returns None for legacy stages."""
    return _REGISTRY.get(name)


def get_stage_instance(name: str) -> Stage:
    """Get an instantiated Stage by name. Only works for new-style stages."""
    cls = _REGISTRY.get(name)
    if cls is None:
        raise KeyError(f"No Stage subclass for {name!r}. Legacy stages don't have instances.")
    return cls()


def list_stages() -> list[StageDefinition]:
    """List all registered stage definitions (new + legacy)."""
    return list(_all_definitions().values())


def list_stage_classes() -> list[type[Stage]]:
    """List all registered Stage subclasses (new-style only)."""
    return list(_REGISTRY.values())


def get_palette() -> dict[str, list[StageDefinition]]:
    """Group stages by category for the editor palette."""
    palette: dict[str, list[StageDefinition]] = {}
    for defn in _all_definitions().values():
        if defn.category not in palette:
            palette[defn.category] = []
        palette[defn.category].append(defn)
    return palette
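Reviewer note: what the auto-registration buys — defining a `Stage` subclass is enough to make it discoverable. A toy stage sketch; the `StageDefinition(name=..., category=...)` fields are assumed for illustration, since real definitions are generated from the schema with full config/IO metadata.

```python
from core.detect.stages.base import Stage, get_stage_instance
from core.detect.stages.models import StageDefinition


class FrameCounter(Stage):
    # Hypothetical definition — real ones come from modelgen.
    definition = StageDefinition(name="frame_counter", category="debug")

    def run(self, frames: list, config: dict) -> dict:
        # Trivial payload; the default serialize() JSON-encodes it for checkpoints.
        return {"count": len(frames)}


stage = get_stage_instance("frame_counter")  # registered just by class creation
print(stage.run(frames=[], config={}))       # {'count': 0}
```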
216
core/detect/stages/brand_resolver.py
Normal file
216
core/detect/stages/brand_resolver.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""
|
||||
Stage 5 — Brand Resolver (discovery mode)
|
||||
|
||||
Discovery-first brand matching. No static dictionary — all brands live in the DB.
|
||||
|
||||
Flow:
|
||||
1. Check session brands first (brands already seen in this run, in-memory)
|
||||
2. Check global known brands (accumulated across all runs)
|
||||
3. Unresolved candidates → escalate to VLM/cloud
|
||||
4. Confirmed brands get added to DB for future runs
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from rapidfuzz import fuzz
|
||||
|
||||
from core.detect import emit
|
||||
from core.detect.models import BrandDetection, TextCandidate
|
||||
from core.detect.stages.models import ResolverConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize(text: str) -> str:
|
||||
return text.strip().lower()
|
||||
|
||||
|
||||
def _has_db() -> bool:
|
||||
try:
|
||||
from core.db import find_brand_by_text as _
|
||||
return True
|
||||
except (ImportError, Exception):
|
||||
return False


def _match_session(text: str, session_brands: dict[str, str]) -> str | None:
    return session_brands.get(_normalize(text))


def _match_known(text: str, threshold: int) -> tuple[str | None, str | None]:
    """Check against global known brands in DB. Returns (canonical_name, brand_id) or (None, None)."""
    if not _has_db():
        return None, None

    from core.db import find_brand_by_text, list_brands
    from core.db.connection import get_session

    with get_session() as session:
        brand = find_brand_by_text(session, text)
        if brand:
            return brand.canonical_name, str(brand.id)

        all_brands = list_brands(session)

        normalized = _normalize(text)
        best_brand = None
        best_score = 0

        for known in all_brands:
            names = [known.canonical_name] + (known.aliases or [])
            for name in names:
                score = fuzz.ratio(normalized, _normalize(name))
                if score > best_score and score >= threshold:
                    best_score = score
                    best_brand = known

        if best_brand:
            return best_brand.canonical_name, str(best_brand.id)

    return None, None


def _register_brand(canonical_name: str, source: str) -> str | None:
    """Register a newly discovered brand in the DB. Returns brand_id."""
    if not _has_db():
        return None

    from core.db import get_or_create_brand
    from core.db.connection import get_session

    with get_session() as session:
        brand, created = get_or_create_brand(session, canonical_name, source=source)
        session.commit()
        if created:
            logger.info("New brand discovered: %s (source=%s)", canonical_name, source)
        return str(brand.id)


def _record_airing(timeline_id: str | None, brand_id: str,
                   frame_seq: int, confidence: float, source: str):
    """Record a brand airing on a timeline."""
    if not _has_db() or not timeline_id:
        return

    from core.db import record_airing
    from core.db.connection import get_session
    from uuid import UUID

    with get_session() as session:
        record_airing(
            session,
            brand_id=UUID(brand_id),
            timeline_id=UUID(timeline_id),
            frame_start=frame_seq,
            frame_end=frame_seq,
            confidence=confidence,
            source=source,
        )
        session.commit()


def build_session_dict(source_asset_id: str | None = None) -> dict[str, str]:
    """
    Load known brands from DB as a session lookup dict.

    Returns {normalized_name: canonical_name, ...} including aliases.
    """
    if not _has_db():
        return {}

    from core.db import list_brands
    from core.db.connection import get_session

    with get_session() as session:
        all_brands = list_brands(session)

        session_dict = {}
        for brand in all_brands:
            session_dict[_normalize(brand.canonical_name)] = brand.canonical_name
            for alias in (brand.aliases or []):
                session_dict[_normalize(alias)] = brand.canonical_name

    return session_dict


def resolve_brands(
    candidates: list[TextCandidate],
    config: ResolverConfig,
    session_brands: dict[str, str] | None = None,
    source_asset_id: str | None = None,
    content_type: str = "",
    job_id: str | None = None,
) -> tuple[list[BrandDetection], list[TextCandidate]]:
    """
    Match text candidates against known brands (session → global → unresolved).

    session_brands: pre-loaded session dict (from build_session_dict)
    job_id: timeline_id — used to record airings
    """
    if session_brands is None:
        session_brands = {}

    emit.log(job_id, "BrandResolver", "INFO",
             f"Resolving {len(candidates)} candidates "
             f"(session={len(session_brands)} brands, fuzzy={config.fuzzy_threshold})")

    matched: list[BrandDetection] = []
    unresolved: list[TextCandidate] = []
    session_hits = 0
    known_hits = 0

    for candidate in candidates:
        text = candidate.text
        brand_name = None
        brand_id = None
        match_source = "ocr"

        # 1. Check session (cheapest — in-memory dict)
        brand_name = _match_session(text, session_brands)
        if brand_name:
            session_hits += 1
        else:
            # 2. Check global known brands (DB query + fuzzy)
            brand_name, brand_id = _match_known(text, config.fuzzy_threshold)
            if brand_name:
                known_hits += 1
                session_brands[_normalize(brand_name)] = brand_name

        if brand_name:
            detection = BrandDetection(
                brand=brand_name,
                timestamp=candidate.frame.timestamp,
                duration=0.5,
                confidence=candidate.ocr_confidence,
                source=match_source,
                bbox=candidate.bbox,
                frame_ref=candidate.frame.sequence,
                content_type=content_type,
            )
            matched.append(detection)

            if brand_id:
                _record_airing(
                    job_id, brand_id,
                    candidate.frame.sequence, candidate.ocr_confidence, match_source,
                )

            emit.detection(
                job_id,
                brand=brand_name,
                confidence=candidate.ocr_confidence,
                source=match_source,
                timestamp=candidate.frame.timestamp,
                content_type=content_type,
                frame_ref=candidate.frame.sequence,
            )
        else:
            unresolved.append(candidate)

    emit.log(job_id, "BrandResolver", "INFO",
             f"Session: {session_hits}, Known: {known_hits}, "
             f"Unresolved: {len(unresolved)} → escalating")

    return matched, unresolved
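

# A minimal, runnable sketch of the fuzzy step in _match_known, assuming the
# `fuzz` imported above is rapidfuzz-compatible (an assumption; only the
# ratio() call is exercised). Run this module directly to see how a
# fuzzy_threshold of 75 treats OCR-mangled text:
if __name__ == "__main__":
    from rapidfuzz import fuzz as _fuzz  # assumed stand-in for the module's `fuzz`

    known = ["coca-cola", "heineken", "emirates"]
    for ocr_text in ("C0CA COLA", "HEINEKEN", "fly emirates", "qatar"):
        scores = {k: _fuzz.ratio(_normalize(ocr_text), k) for k in known}
        best = max(scores, key=scores.get)
        hit = best if scores[best] >= 75 else None
        print(f"{ocr_text!r} -> {hit} (best score {scores[best]:.0f})")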
292 core/detect/stages/edge_detector.py (Normal file)
@@ -0,0 +1,292 @@
"""
Stage — Edge Detection

Canny + HoughLinesP to find horizontal line pairs that bound
advertising hoardings. Pure OpenCV, no ML models.

Two modes:
- Remote: calls GPU inference server over HTTP
- Local: imports cv2 directly (OpenCV on same machine)
"""

from __future__ import annotations

import base64
import io
import json
import logging
import os
import time
from typing import Any

from PIL import Image

from core.detect import emit
from core.detect.models import BoundingBox, Frame
from core.detect.stages.base import Stage
from core.detect.stages.models import StageDefinition, StageConfigField, StageIO, StageOutputHint


logger = logging.getLogger(__name__)


class EdgeDetectionStage(Stage):

    definition = StageDefinition(
        name="detect_edges",
        label="Edge Detection",
        description="Canny + HoughLinesP — find horizontal line pairs (hoarding boundaries)",
        category="cv_analysis",
        io=StageIO(
            reads=["filtered_frames"],
            writes=["edge_regions_by_frame"],
        ),
        config_fields=[
            StageConfigField(name="edge_canny_low", type="int", default=50, description="Canny low threshold", min=0, max=255),
            StageConfigField(name="edge_canny_high", type="int", default=150, description="Canny high threshold", min=0, max=255),
            StageConfigField(name="edge_hough_threshold", type="int", default=80, description="Hough accumulator threshold", min=1, max=500),
            StageConfigField(name="edge_hough_min_length", type="int", default=100, description="Min line length (px)", min=10, max=2000),
            StageConfigField(name="edge_hough_max_gap", type="int", default=10, description="Max line gap (px)", min=1, max=100),
            StageConfigField(name="edge_pair_max_distance", type="int", default=200, description="Max distance between line pair (px)", min=10, max=500),
            StageConfigField(name="edge_pair_min_distance", type="int", default=15, description="Min distance between line pair (px)", min=5, max=200),
        ],
        output_hints=[
            StageOutputHint(key="edge_regions_by_frame", type="boxes_by_frame", label="Edge regions"),
            StageOutputHint(key="edge_overlay_b64", type="overlay", label="Canny edges", default_opacity=0.25),
            StageOutputHint(key="lines_overlay_b64", type="overlay", label="Hough lines", default_opacity=0.25),
        ],
        tracks_element="edge_region",
    )

    def run(self, frames: list[Frame], config: dict) -> dict[int, list[BoundingBox]]:
        """
        Run edge detection on all frames.

        Config keys: enabled, edge_canny_low, edge_canny_high, edge_hough_threshold,
        edge_hough_min_length, edge_hough_max_gap, edge_pair_max_distance, edge_pair_min_distance,
        debug (bool), inference_url (str|None), job_id (str|None).

        Returns dict mapping frame sequence → list of BoundingBox.
        """
        enabled = config.get("enabled", True)
        job_id = config.get("job_id")
        inference_url = config.get("inference_url") or os.environ.get("INFERENCE_URL")

        if not enabled:
            emit.log(job_id, "EdgeDetection", "INFO", "Edge detection disabled, skipping")
            return {}

        mode = "remote" if inference_url else "local"
        emit.log(job_id, "EdgeDetection", "INFO",
                 f"Detecting edges in {len(frames)} frames (mode={mode})")

        all_boxes: dict[int, list[BoundingBox]] = {}
        total_regions = 0

        for frame in frames:
            t0 = time.monotonic()
            if inference_url:
                boxes = self._run_remote(frame, config, inference_url, job_id or "")
            else:
                boxes = self._run_local(frame, config)
            ms = (time.monotonic() - t0) * 1000

            all_boxes[frame.sequence] = boxes
            total_regions += len(boxes)

            emit.log(job_id, "EdgeDetection", "DEBUG",
                     f"Frame {frame.sequence}: {len(boxes)} regions in {ms:.0f}ms"
                     + (f" [{', '.join(b.label for b in boxes)}]" if boxes else ""))

            if boxes and job_id:
                box_dicts = [
                    {"x": b.x, "y": b.y, "w": b.w, "h": b.h,
                     "confidence": b.confidence, "label": b.label,
                     "stage": "detect_edges"}
                    for b in boxes
                ]
                emit.frame_update(
                    job_id,
                    frame_ref=frame.sequence,
                    timestamp=frame.timestamp,
                    jpeg_b64=_frame_to_b64(frame),
                    boxes=box_dicts,
                )

        emit.log(job_id, "EdgeDetection", "INFO",
                 f"Found {total_regions} edge regions across {len(frames)} frames")
        emit.stats(job_id, cv_regions_detected=total_regions)

        return all_boxes

    def serialize(self, output: Any) -> bytes:
        """Serialize edge regions to JSON blob."""
        serialized = {}
        for seq, boxes in output.items():
            serialized[str(seq)] = [
                {"x": b.x, "y": b.y, "w": b.w, "h": b.h,
                 "confidence": b.confidence, "label": b.label}
                for b in boxes
            ]
        return json.dumps(serialized).encode()

    def deserialize(self, data: bytes) -> dict[int, list[BoundingBox]]:
        """Deserialize edge regions from JSON blob."""
        raw = json.loads(data)
        result = {}
        for seq_str, box_dicts in raw.items():
            boxes = [
                BoundingBox(x=b["x"], y=b["y"], w=b["w"], h=b["h"],
                            confidence=b["confidence"], label=b["label"])
                for b in box_dicts
            ]
            result[int(seq_str)] = boxes
        return result

    # --- Private helpers ---

    def _run_remote(self, frame: Frame, config: dict,
                    inference_url: str, job_id: str) -> list[BoundingBox]:
        from core.detect.inference import InferenceClient
        from core.detect.emit import _run_log_level

        client = InferenceClient(
            base_url=inference_url, job_id=job_id, log_level=_run_log_level,
        )
        results = client.detect_edges(
            image=frame.image,
            edge_canny_low=config.get("edge_canny_low", 50),
            edge_canny_high=config.get("edge_canny_high", 150),
            edge_hough_threshold=config.get("edge_hough_threshold", 80),
            edge_hough_min_length=config.get("edge_hough_min_length", 100),
            edge_hough_max_gap=config.get("edge_hough_max_gap", 10),
            edge_pair_max_distance=config.get("edge_pair_max_distance", 200),
            edge_pair_min_distance=config.get("edge_pair_min_distance", 15),
        )
        boxes = []
        for r in results:
            box = BoundingBox(
                x=r.x, y=r.y, w=r.w, h=r.h,
                confidence=r.confidence, label=r.label,
            )
            boxes.append(box)
        return boxes

    def _run_local(self, frame: Frame, config: dict) -> list[BoundingBox]:
        detect_edges_fn = _load_cv_edges().detect_edges

        edge_results = detect_edges_fn(
            frame.image,
            canny_low=config.get("edge_canny_low", 50),
            canny_high=config.get("edge_canny_high", 150),
            hough_threshold=config.get("edge_hough_threshold", 80),
            hough_min_length=config.get("edge_hough_min_length", 100),
            hough_max_gap=config.get("edge_hough_max_gap", 10),
            pair_max_distance=config.get("edge_pair_max_distance", 200),
            pair_min_distance=config.get("edge_pair_min_distance", 15),
        )

        boxes = []
        for r in edge_results:
            box = BoundingBox(
                x=r["x"], y=r["y"], w=r["w"], h=r["h"],
                confidence=r["confidence"], label=r["label"],
            )
            boxes.append(box)
        return boxes


# --- Module-level helpers ---

def _frame_to_b64(frame: Frame) -> str:
    img = Image.fromarray(frame.image)
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=70)
    return base64.b64encode(buf.getvalue()).decode()


_cv_edges_mod = None

def _load_cv_edges():
    global _cv_edges_mod
    if _cv_edges_mod is None:
        import importlib.util
        from pathlib import Path
        spec = importlib.util.spec_from_file_location("cv_edges", Path("core/gpu/models/cv/edges.py"))
        _cv_edges_mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(_cv_edges_mod)
    return _cv_edges_mod


# --- Backward compat: standalone function for graph.py ---

def _filter_by_field_mask(boxes, mask, margin_px=50):
    """
    Keep only boxes that are near the pitch boundary (hoarding zone).

    The field mask has 255=pitch, 0=not pitch. Hoardings sit just
    outside the pitch boundary. We dilate the mask to create a
    "boundary zone" and keep boxes whose center falls in the zone
    between the dilated mask edge and the original mask.
    """
    import cv2
    import numpy as np

    if mask is None or not boxes:
        return boxes

    # Dilate the pitch mask — the expansion zone is where hoardings are
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (margin_px * 2, margin_px * 2))
    dilated = cv2.dilate(mask, kernel)

    # Boundary zone = dilated but NOT original pitch
    boundary_zone = cv2.bitwise_and(dilated, cv2.bitwise_not(mask))

    kept = []
    for box in boxes:
        cx = box.x + box.w // 2
        cy = box.y + box.h // 2
        # Clamp to image bounds
        cy = min(cy, boundary_zone.shape[0] - 1)
        cx = min(cx, boundary_zone.shape[1] - 1)
        if boundary_zone[cy, cx] > 0:
            kept.append(box)

    return kept


def detect_edge_regions(frames, config, inference_url=None, job_id=None, field_masks=None):
    """Convenience wrapper — calls EdgeDetectionStage.run(), optionally filters by field mask."""
    stage = EdgeDetectionStage()
    cfg = {
        "enabled": config.enabled,
        "edge_canny_low": config.edge_canny_low,
        "edge_canny_high": config.edge_canny_high,
        "edge_hough_threshold": config.edge_hough_threshold,
        "edge_hough_min_length": config.edge_hough_min_length,
        "edge_hough_max_gap": config.edge_hough_max_gap,
        "edge_pair_max_distance": config.edge_pair_max_distance,
        "edge_pair_min_distance": config.edge_pair_min_distance,
        "inference_url": inference_url,
        "job_id": job_id,
    }
    all_boxes = stage.run(frames, cfg)

    # Filter by field segmentation mask if available
    if field_masks:
        filtered_total = 0
        original_total = sum(len(b) for b in all_boxes.values())
        for seq, boxes in all_boxes.items():
            mask = field_masks.get(seq)
            if mask is not None:
                all_boxes[seq] = _filter_by_field_mask(boxes, mask)
                filtered_total += len(all_boxes[seq])
            else:
                filtered_total += len(boxes)

        if original_total != filtered_total:
            from core.detect import emit
            emit.log(job_id, "EdgeDetection", "INFO",
                     f"Field mask filter: {original_total} → {filtered_total} regions")

    return all_boxes
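

# A runnable sketch of the boundary-zone filter above. The SimpleNamespace
# stand-ins only mimic the x/y/w/h attributes _filter_by_field_mask reads;
# they are not the pipeline's BoundingBox. With margin_px=50, only boxes
# whose center sits in the 50px band just outside the mask should survive:
if __name__ == "__main__":
    from types import SimpleNamespace

    import cv2
    import numpy as np

    mask = np.zeros((200, 200), np.uint8)
    cv2.rectangle(mask, (50, 50), (150, 150), 255, -1)  # the "pitch"
    inside = SimpleNamespace(x=90, y=90, w=20, h=20)    # center on the pitch
    fringe = SimpleNamespace(x=150, y=90, w=30, h=20)   # center ~15px outside
    far = SimpleNamespace(x=0, y=0, w=10, h=10)         # far from the pitch
    kept = _filter_by_field_mask([inside, fringe, far], mask, margin_px=50)
    print([(b.x, b.y) for b in kept])  # expected: only the fringe box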
151 core/detect/stages/field_segmentation.py (Normal file)
@@ -0,0 +1,151 @@
"""
Stage — Field Segmentation

Calls the GPU inference server to detect pitch boundaries via
HSV green mask + morphology. The CV code lives in core/gpu/models/cv/.

Outputs a mask and boundary that downstream stages use as spatial priors.
"""

from __future__ import annotations

import base64
import io
import logging

import numpy as np
from PIL import Image

from core.detect import emit
from core.detect.models import Frame
from core.detect.stages.base import Stage
from core.detect.stages.models import (
    FieldSegmentationConfig,
    StageConfigField,
    StageDefinition,
    StageIO,
    StageOutputHint,
    TransformOption,
)

logger = logging.getLogger(__name__)


class FieldSegmentationStage(Stage):

    definition = StageDefinition(
        name="field_segmentation",
        label="Field Segmentation",
        description="HSV green mask — detect pitch boundaries for spatial priors",
        category="cv_analysis",
        io=StageIO(
            reads=["filtered_frames"],
            writes=["field_mask"],
        ),
        config_fields=[
            StageConfigField(name="hue_low", type="int", default=30, description="HSV hue lower bound", min=0, max=180),
            StageConfigField(name="hue_high", type="int", default=85, description="HSV hue upper bound", min=0, max=180),
            StageConfigField(name="sat_low", type="int", default=30, description="HSV saturation lower bound", min=0, max=255),
            StageConfigField(name="sat_high", type="int", default=255, description="HSV saturation upper bound", min=0, max=255),
            StageConfigField(name="val_low", type="int", default=30, description="HSV value lower bound", min=0, max=255),
            StageConfigField(name="val_high", type="int", default=255, description="HSV value upper bound", min=0, max=255),
            StageConfigField(name="morph_kernel", type="int", default=15, description="Morphology kernel size", min=3, max=51),
            StageConfigField(name="min_area_ratio", type="float", default=0.05, description="Min contour area as fraction of frame", min=0.01, max=0.5),
        ],
        output_hints=[
            StageOutputHint(key="mask_overlay_b64", type="overlay", label="Field mask", default_opacity=0.5, src_format="png"),
        ],
        accepted_transforms=[
            TransformOption(key="invert_mask", type="bool", default=False, label="Invert selection", description="Invert the mask so downstream stages look outside the detected area"),
        ],
    )


def _frame_to_b64(frame: Frame) -> str:
    """Encode frame image as base64 JPEG."""
    img = Image.fromarray(frame.image)
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=85)
    return base64.b64encode(buf.getvalue()).decode()


def _decode_mask_b64(mask_b64: str) -> np.ndarray:
    """Decode a base64 PNG mask back to numpy array."""
    data = base64.b64decode(mask_b64)
    img = Image.open(io.BytesIO(data)).convert("L")
    return np.array(img)


def run_field_segmentation(
    frames: list[Frame],
    config: FieldSegmentationConfig,
    inference_url: str | None = None,
    job_id: str | None = None,
) -> dict:
"""
|
||||
Run field segmentation on all frames via the inference server.
|
||||
|
||||
Returns dict with:
|
||||
field_masks: {seq: np.ndarray}
|
||||
field_boundaries: {seq: [(x,y), ...]}
|
||||
field_coverage: {seq: float}
|
||||
"""
|
||||
    if not config.enabled:
        emit.log(job_id, "FieldSegmentation", "INFO", "Disabled, skipping")
        return {"field_masks": {}, "field_boundaries": {}, "field_coverage": {}}

    import os
    url = inference_url or os.environ.get("INFERENCE_URL")
    if not url:
        emit.log(job_id, "FieldSegmentation", "WARNING",
                 "No INFERENCE_URL, skipping field segmentation")
        return {"field_masks": {}, "field_boundaries": {}, "field_coverage": {}}

    emit.log(job_id, "FieldSegmentation", "INFO",
             f"Segmenting {len(frames)} frames (hue={config.hue_low}-{config.hue_high})")

    from core.detect.inference import InferenceClient
    from core.detect.emit import _run_log_level
    client = InferenceClient(base_url=url, job_id=job_id or "", log_level=_run_log_level)

    field_masks = {}
    field_mask_overlays = {}
    field_boundaries = {}
    field_coverage = {}

    for frame in frames:
        image_b64 = _frame_to_b64(frame)

        resp = client.post("/segment_field", {
            "image": image_b64,
            "hue_low": config.hue_low,
            "hue_high": config.hue_high,
            "sat_low": config.sat_low,
            "sat_high": config.sat_high,
            "val_low": config.val_low,
            "val_high": config.val_high,
            "morph_kernel": config.morph_kernel,
            "min_area_ratio": config.min_area_ratio,
        })

        if resp is None:
            continue

        mask_b64 = resp.get("mask_b64", "")
        if mask_b64:
            field_masks[frame.sequence] = _decode_mask_b64(mask_b64)
            field_mask_overlays[frame.sequence] = mask_b64

        field_boundaries[frame.sequence] = resp.get("boundary", [])
        field_coverage[frame.sequence] = resp.get("coverage", 0.0)

    avg_coverage = sum(field_coverage.values()) / max(len(field_coverage), 1)
    emit.log(job_id, "FieldSegmentation", "INFO",
             f"Done: {len(frames)} frames, avg coverage {avg_coverage:.1%}")

    return {
        "field_masks": field_masks,
        "field_mask_overlays": field_mask_overlays,
        "field_boundaries": field_boundaries,
        "field_coverage": field_coverage,
    }
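

# What the /segment_field endpoint is expected to do, sketched locally with
# OpenCV. This is an assumption about the server internals (the real code
# lives in core/gpu/models/cv/); it only illustrates the HSV-inRange plus
# morphological-close idea behind the config fields above:
if __name__ == "__main__":
    import cv2

    frame_rgb = np.zeros((120, 160, 3), np.uint8)
    frame_rgb[40:, :] = (60, 160, 60)  # green "pitch" in the lower band
    hsv = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2HSV)
    m = cv2.inRange(hsv, (30, 30, 30), (85, 255, 255))
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    m = cv2.morphologyEx(m, cv2.MORPH_CLOSE, kernel)
    print("coverage:", round(float(m.mean()) / 255, 2))  # fraction flagged as pitch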
93 core/detect/stages/frame_extractor.py (Normal file)
@@ -0,0 +1,93 @@
"""
Stage 1 — Frame Extraction

Extracts frames from a video at a configurable FPS using the core ffmpeg module.
Emits log + stats_update SSE events as it works.
"""

from __future__ import annotations

import tempfile
import time
from pathlib import Path

import ffmpeg
import numpy as np
from PIL import Image

from core.ffmpeg.probe import probe_file
from core.detect import emit
from core.detect.models import Frame
from core.detect.stages.models import FrameExtractionConfig


def _load_frames(tmpdir: Path, fps: float) -> list[Frame]:
    """Load extracted JPEG files into Frame objects."""
    frame_files = sorted(tmpdir.glob("frame_*.jpg"))
    frames = []
    for i, fpath in enumerate(frame_files):
        img = Image.open(fpath)
        frame = Frame(
            sequence=i,
            chunk_id=0,
            timestamp=i / fps,
            image=np.array(img),
        )
        frames.append(frame)
    return frames


def extract_frames(
    video_path: str,
    config: FrameExtractionConfig,
    job_id: str | None = None,
) -> list[Frame]:
    """
    Extract frames from video at the configured FPS.

    Uses ffmpeg-python to build the extraction pipeline,
    outputs JPEG files to a temp dir, then loads as numpy arrays.
    """
    probe = probe_file(video_path)
    duration = probe.duration or 0.0

    emit.log(job_id, "FrameExtractor", "INFO",
             f"Starting extraction: {Path(video_path).name} "
             f"({duration:.1f}s, {probe.width}x{probe.height}, fps={config.fps})")
    emit.log(job_id, "FrameExtractor", "DEBUG",
             f"Probe: codec={probe.video_codec}, bitrate={probe.video_bitrate}, max_frames={config.max_frames}")

    with tempfile.TemporaryDirectory() as tmpdir:
        pattern = str(Path(tmpdir) / "frame_%06d.jpg")

        stream = (
            ffmpeg
            .input(video_path)
            .filter("fps", fps=config.fps)
            .output(pattern, qscale=2, frames=config.max_frames)
            .overwrite_output()
        )

        t0 = time.monotonic()
        try:
            stream.run(capture_stdout=True, capture_stderr=True, quiet=True)
        except ffmpeg.Error as e:
            stderr = e.stderr.decode() if e.stderr else "unknown error"
            emit.log(job_id, "FrameExtractor", "ERROR", f"FFmpeg failed: {stderr[:200]}")
            raise RuntimeError(f"FFmpeg failed: {stderr}") from e
        ffmpeg_ms = (time.monotonic() - t0) * 1000
        emit.log(job_id, "FrameExtractor", "DEBUG", f"FFmpeg decode: {ffmpeg_ms:.0f}ms")

        t0 = time.monotonic()
        frames = _load_frames(Path(tmpdir), config.fps)
        load_ms = (time.monotonic() - t0) * 1000
        if frames:
            h, w = frames[0].image.shape[:2]
            mem_mb = sum(f.image.nbytes for f in frames) / (1024 * 1024)
            emit.log(job_id, "FrameExtractor", "DEBUG",
                     f"Loaded {len(frames)} frames ({w}x{h}) in {load_ms:.0f}ms, {mem_mb:.1f}MB in memory")

    emit.log(job_id, "FrameExtractor", "INFO", f"Extracted {len(frames)} frames")
    emit.stats(job_id, frames_extracted=len(frames))

    return frames
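

# ffmpeg-python only assembles arguments; to inspect the exact command the
# stream above will run, compile it instead of running it. A standalone
# sketch with placeholder paths (mirrors the filter/output chain used in
# extract_frames):
if __name__ == "__main__":
    cmd = (
        ffmpeg
        .input("input.mp4")
        .filter("fps", fps=2.0)
        .output("/tmp/frames/frame_%06d.jpg", qscale=2, frames=500)
        .overwrite_output()
        .compile()
    )
    print(" ".join(cmd))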
125 core/detect/stages/models.py (Normal file)
@@ -0,0 +1,125 @@
"""
Pydantic Models - GENERATED FILE

Do not edit directly. Regenerate using modelgen.
"""

from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID

from pydantic import BaseModel, Field

class StageConfigField(BaseModel):
    """A single tunable config parameter for the editor UI."""
    name: str
    type: str
    default: Any
    description: str = ""
    min: Optional[float] = None
    max: Optional[float] = None
    options: Optional[List[str]] = None

class StageIO(BaseModel):
    """Declares what a stage reads and writes."""
    reads: List[str] = Field(default_factory=list)
    writes: List[str] = Field(default_factory=list)
    optional_reads: List[str] = Field(default_factory=list)

class StageOutputHint(BaseModel):
    """How to render a stage output in the compare/editor views."""
    key: str
    type: str
    label: str = ""
    default_opacity: float = 0.5
    src_format: str = "png"

class TransformOption(BaseModel):
    """A transform the stage accepts on its incoming edges."""
    key: str
    type: str
    default: Any = False
    label: str = ""
    description: str = ""

class StageDefinition(BaseModel):
    """Complete metadata for a pipeline stage."""
    name: str
    label: str
    description: str
    category: str = "detection"
    io: StageIO
    config_fields: List[StageConfigField] = Field(default_factory=list)
    output_hints: List[StageOutputHint] = Field(default_factory=list)
    accepted_transforms: List[TransformOption] = Field(default_factory=list)
    tracks_element: Optional[str] = None

class FrameExtractionConfig(BaseModel):
    """FrameExtractionConfig(fps: float = 2.0, max_frames: int = 500)"""
    fps: float = 2.0
    max_frames: int = 500

class SceneFilterConfig(BaseModel):
    """SceneFilterConfig(hamming_threshold: int = 8, enabled: bool = True)"""
    hamming_threshold: int = 8
    enabled: bool = True

class DetectionConfig(BaseModel):
    """DetectionConfig(model_name: str = 'yolov8n.pt', confidence_threshold: float = 0.3, target_classes: List[str] = <factory>)"""
    model_name: str = "yolov8n.pt"
    confidence_threshold: float = 0.3
    target_classes: List[str] = Field(default_factory=list)

class OCRConfig(BaseModel):
    """OCRConfig(languages: List[str] = <factory>, min_confidence: float = 0.5)"""
    languages: List[str] = Field(default_factory=lambda: ["en"])
    min_confidence: float = 0.5

class ResolverConfig(BaseModel):
    """ResolverConfig(fuzzy_threshold: int = 75)"""
    fuzzy_threshold: int = 75

class RegionAnalysisConfig(BaseModel):
    """RegionAnalysisConfig(enabled: bool = True, edge_canny_low: int = 50, edge_canny_high: int = 150, edge_hough_threshold: int = 80, edge_hough_min_length: int = 100, edge_hough_max_gap: int = 10, edge_pair_max_distance: int = 200, edge_pair_min_distance: int = 15)"""
    enabled: bool = True
    edge_canny_low: int = 50
    edge_canny_high: int = 150
    edge_hough_threshold: int = 80
    edge_hough_min_length: int = 100
    edge_hough_max_gap: int = 10
    edge_pair_max_distance: int = 200
    edge_pair_min_distance: int = 15

class FieldSegmentationConfig(BaseModel):
    """FieldSegmentationConfig(enabled: bool = True, hue_low: int = 30, hue_high: int = 85, sat_low: int = 30, sat_high: int = 255, val_low: int = 30, val_high: int = 255, morph_kernel: int = 15, min_area_ratio: float = 0.05)"""
    enabled: bool = True
    hue_low: int = 30
    hue_high: int = 85
    sat_low: int = 30
    sat_high: int = 255
    val_low: int = 30
    val_high: int = 255
    morph_kernel: int = 15
    min_area_ratio: float = 0.05

class StageRef(BaseModel):
    """Reference to a stage in the pipeline graph."""
    name: str
    branch: str = "trunk"
    execution_target: str = "local"

class Edge(BaseModel):
    """Connection between stages in the graph."""
    source: str
    target: str
    condition: str = ""
    transform: Dict[str, Any] = Field(default_factory=dict)

class PipelineConfig(BaseModel):
    """Pipeline graph topology + routing rules."""
    name: str
    profile_name: str
    stages: List[StageRef] = Field(default_factory=list)
    edges: List[Edge] = Field(default_factory=list)
    routing_rules: Dict[str, Any] = Field(default_factory=dict)
139 core/detect/stages/ocr_stage.py (Normal file)
@@ -0,0 +1,139 @@
"""
Stage 4 — OCR

Reads text from detected regions (YOLO bounding box crops).
Two modes:
- remote: calls inference server over HTTP (separate GPU box, or localhost)
- local: runs PaddleOCR in-process (single-box setup with enough VRAM)

The mode is selected by whether inference_url is provided.
Model instances are cached at module level so they survive across pipeline runs.
"""

from __future__ import annotations

import logging
import time
from typing import TYPE_CHECKING

import numpy as np

from core.detect import emit
from core.detect.models import BoundingBox, Frame, TextCandidate
from core.detect.stages.models import OCRConfig

if TYPE_CHECKING:
    pass

logger = logging.getLogger(__name__)

# Module-level cache — avoids reloading the model for every crop or pipeline run
_local_ocr_cache: dict[str, object] = {}


def _crop_region(frame: Frame, box: BoundingBox) -> np.ndarray:
    h, w = frame.image.shape[:2]
    x1 = max(0, box.x)
    y1 = max(0, box.y)
    x2 = min(w, box.x + box.w)
    y2 = min(h, box.y + box.h)
    return frame.image[y1:y2, x1:x2]


def _get_local_model(lang: str):
    if lang not in _local_ocr_cache:
        from paddleocr import PaddleOCR
        logger.info("Loading PaddleOCR locally (lang=%s)", lang)
        _local_ocr_cache[lang] = PaddleOCR(lang=lang)
    return _local_ocr_cache[lang]


def _parse_ocr_raw(raw, min_confidence: float) -> list[dict]:
    """Parse PaddleOCR 3.x result — handles dict-based and nested-list layouts."""
    results = []
    for page in (raw or []):
        if not page:
            continue
        if isinstance(page, dict):
            for text, confidence in zip(page.get("rec_texts", []), page.get("rec_scores", [])):
                if float(confidence) >= min_confidence:
                    results.append({"text": text, "confidence": float(confidence)})
            continue
        for line in page:
            if not line:
                continue
            rec = line[1]
            if isinstance(rec, (list, tuple)) and len(rec) >= 2:
                text, confidence = rec[0], rec[1]
                if float(confidence) >= min_confidence:
                    results.append({"text": text, "confidence": float(confidence)})
    return results


def run_ocr(
    frames: list[Frame],
    boxes_by_frame: dict[int, list[BoundingBox]],
    config: OCRConfig,
    inference_url: str | None = None,
    job_id: str | None = None,
) -> list[TextCandidate]:
    """
    Run OCR on cropped regions from YOLO detections.

    inference_url=None → local in-process PaddleOCR (single-box)
    inference_url=str → remote inference server (split or localhost)
    """
    total_regions = sum(len(boxes) for boxes in boxes_by_frame.values())
    mode = "remote" if inference_url else "local"

    emit.log(job_id, "OCRStage", "INFO",
             f"Running OCR on {total_regions} regions (mode={mode})")

    # Build these once per pipeline run, not per crop
    if inference_url:
        from core.detect.inference import InferenceClient
        from core.detect.emit import _run_log_level
        client = InferenceClient(base_url=inference_url, job_id=job_id or "", log_level=_run_log_level)
    else:
        model = _get_local_model(config.languages[0])

    frame_map = {f.sequence: f for f in frames}
    candidates: list[TextCandidate] = []

    for seq, boxes in boxes_by_frame.items():
        frame = frame_map.get(seq)
        if not frame:
            continue

        for box in boxes:
            crop = _crop_region(frame, box)
            if crop.size == 0:
                continue

            t0 = time.monotonic()
            if inference_url:
                raw_results = client.ocr(image=crop, languages=config.languages)
                texts = [{"text": r.text, "confidence": r.confidence} for r in raw_results]
            else:
                raw = model.ocr(crop)
                texts = _parse_ocr_raw(raw, config.min_confidence)
            ocr_ms = (time.monotonic() - t0) * 1000

            h, w = crop.shape[:2]
            text_preview = ", ".join(t["text"][:30] for t in texts) if texts else "(none)"
            emit.log(job_id, "OCRStage", "DEBUG",
                     f"Frame {seq} box {box.x},{box.y} ({w}x{h}): {ocr_ms:.0f}ms → {text_preview}")

            for t in texts:
                candidates.append(TextCandidate(
                    frame=frame,
                    bbox=box,
                    text=t["text"],
                    ocr_confidence=t["confidence"],
                ))

    emit.log(job_id, "OCRStage", "INFO",
             f"Extracted text from {len(candidates)} regions")
    emit.stats(job_id, regions_resolved_by_ocr=len(candidates))

    return candidates
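

# _parse_ocr_raw is pure, so both accepted layouts can be checked with
# hand-built payloads (shapes inferred from the parser above, not captured
# from a live PaddleOCR run):
if __name__ == "__main__":
    dict_style = [{"rec_texts": ["FLY EMIRATES", "smudge"], "rec_scores": [0.91, 0.32]}]
    nested_style = [[[[0, 0], [10, 0], [10, 5], [0, 5]], ("HEINEKEN", 0.88)]]
    print(_parse_ocr_raw(dict_style, min_confidence=0.5))    # keeps FLY EMIRATES only
    print(_parse_ocr_raw(nested_style, min_confidence=0.5))  # keeps HEINEKEN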
128 core/detect/stages/preprocess.py (Normal file)
@@ -0,0 +1,128 @@
"""
Stage 3.5 — Preprocessing

Runs between YOLO detection and OCR. Applies configurable image
preprocessing to each detected region crop: contrast enhancement,
deskewing, binarization.

Operates on the crops derived from boxes_by_frame, produces
preprocessed_crops keyed by (frame_sequence, box_index).
"""

from __future__ import annotations

import logging

import numpy as np

from core.detect import emit
from core.detect.models import BoundingBox, Frame

logger = logging.getLogger(__name__)


def _crop_region(frame: Frame, box: BoundingBox) -> np.ndarray:
    h, w = frame.image.shape[:2]
    x1 = max(0, box.x)
    y1 = max(0, box.y)
    x2 = min(w, box.x + box.w)
    y2 = min(h, box.y + box.h)
    return frame.image[y1:y2, x1:x2]


def preprocess_regions(
    frames: list[Frame],
    boxes_by_frame: dict[int, list[BoundingBox]],
    do_contrast: bool = True,
    do_deskew: bool = False,
    do_binarize: bool = False,
    inference_url: str | None = None,
    job_id: str | None = None,
) -> dict[str, np.ndarray]:
    """
    Preprocess cropped regions from YOLO detections.

    Returns dict keyed by "{frame_seq}_{box_idx}" → preprocessed crop.
    These are passed to the OCR stage instead of raw crops.
    """
    total_regions = sum(len(boxes) for boxes in boxes_by_frame.values())
    any_active = do_contrast or do_deskew or do_binarize

    if not any_active:
        emit.log(job_id, "Preprocess", "INFO",
                 f"Preprocessing disabled, passing {total_regions} regions through")
        return {}

    mode = "remote" if inference_url else "local"
    emit.log(job_id, "Preprocess", "INFO",
             f"Preprocessing {total_regions} regions (mode={mode}, "
             f"contrast={do_contrast}, deskew={do_deskew}, binarize={do_binarize})")

    frame_map = {f.sequence: f for f in frames}
    preprocessed: dict[str, np.ndarray] = {}
    processed_count = 0

    for seq, boxes in boxes_by_frame.items():
        frame = frame_map.get(seq)
        if not frame:
            continue

        for idx, box in enumerate(boxes):
            crop = _crop_region(frame, box)
            if crop.size == 0:
                continue

            key = f"{seq}_{idx}"

            if inference_url:
                result = _preprocess_remote(crop, inference_url,
                                            do_contrast, do_deskew, do_binarize)
            else:
                result = _preprocess_local(crop, do_contrast, do_deskew, do_binarize)

            preprocessed[key] = result
            processed_count += 1

    emit.log(job_id, "Preprocess", "INFO",
             f"Preprocessed {processed_count} regions")

    return preprocessed


def _preprocess_remote(crop: np.ndarray, inference_url: str,
                       do_contrast: bool, do_deskew: bool, do_binarize: bool) -> np.ndarray:
    """Call GPU server /preprocess endpoint."""
    import base64
    import io

    import requests
    from PIL import Image

    img = Image.fromarray(crop)
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=85)
    image_b64 = base64.b64encode(buf.getvalue()).decode()

    resp = requests.post(
        f"{inference_url.rstrip('/')}/preprocess",
        json={
            "image": image_b64,
            "contrast": do_contrast,
            "deskew": do_deskew,
            "binarize": do_binarize,
        },
        timeout=30,
    )
    resp.raise_for_status()
    data = resp.json()

    result_bytes = base64.b64decode(data["image"])
    result_img = Image.open(io.BytesIO(result_bytes)).convert("RGB")
    return np.array(result_img)


def _preprocess_local(crop: np.ndarray,
                      do_contrast: bool, do_deskew: bool, do_binarize: bool) -> np.ndarray:
    """Run preprocessing in-process (requires opencv-python-headless)."""
    from core.gpu.models.preprocess import preprocess
    return preprocess(crop, do_binarize=do_binarize, do_deskew=do_deskew, do_contrast=do_contrast)
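

# _crop_region clamps to the frame, so a partially out-of-bounds detection
# degrades to a smaller crop instead of raising. Runnable sketch (assumes
# Frame and BoundingBox accept exactly the keyword fields used elsewhere in
# this change):
if __name__ == "__main__":
    frame = Frame(sequence=0, chunk_id=0, timestamp=0.0,
                  image=np.zeros((100, 100, 3), np.uint8))
    box = BoundingBox(x=80, y=90, w=50, h=50, confidence=0.9, label="banner")
    print(_crop_region(frame, box).shape)  # (10, 20, 3), clamped at both edges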
31 core/detect/stages/registry/__init__.py (Normal file)
@@ -0,0 +1,31 @@
"""
Stage registry — registers all built-in stages.

Split by category:
    preprocessing.py — extract_frames, filter_scenes
    cv_analysis.py   — detect_edges (+ future: detect_contours, detect_color, merge_regions)
    detection.py     — detect_objects, run_ocr
    resolution.py    — match_brands
    escalation.py    — escalate_vlm, escalate_cloud
    output.py        — compile_report
    _serializers.py  — shared serialization helpers
"""

from . import preprocessing
from . import cv_analysis
from . import detection
from . import resolution
from . import escalation
from . import output


def register_all():
    preprocessing.register()
    cv_analysis.register()
    detection.register()
    resolution.register()
    escalation.register()
    output.register()


register_all()
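

# With everything registered, a pipeline is just a PipelineConfig naming
# these stages. The topology below is illustrative, and model_dump_json
# assumes pydantic v2 (use .json() on v1):
if __name__ == "__main__":
    from core.detect.stages.models import Edge, PipelineConfig, StageRef

    cfg = PipelineConfig(
        name="brand-detect",
        profile_name="default",
        stages=[StageRef(name="extract_frames"),
                StageRef(name="filter_scenes"),
                StageRef(name="detect_objects")],
        edges=[Edge(source="extract_frames", target="filter_scenes"),
               Edge(source="filter_scenes", target="detect_objects")],
    )
    print(cfg.model_dump_json(indent=2))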
24 core/detect/stages/registry/_serializers.py (Normal file)
@@ -0,0 +1,24 @@
"""
Re-export serializers from core/schema/serializers/.

Stage registry modules import from here for convenience.
All serialization logic lives in core/schema/serializers/.
"""

from core.schema.serializers._common import (
    safe_construct,
    serialize_dataclass,
    serialize_dataclass_list,
)
from core.schema.serializers.pipeline import (
    serialize_frame_meta,
    serialize_frames_meta,
    serialize_text_candidate,
    serialize_text_candidates,
    deserialize_text_candidate,
    deserialize_text_candidates,
    deserialize_bounding_box,
    deserialize_brand_detection,
    deserialize_pipeline_stats,
    deserialize_detection_report,
)
83 core/detect/stages/registry/cv_analysis.py (Normal file)
@@ -0,0 +1,83 @@
"""Registration for CV analysis stages: edge detection, field segmentation."""

from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import serialize_dataclass_list, deserialize_bounding_box


def _ser_regions(state: dict, job_id: str) -> dict:
    regions = state.get("edge_regions_by_frame", {})
    serialized = {
        str(seq): serialize_dataclass_list(bl) for seq, bl in regions.items()
    }
    return {"edge_regions_by_frame": serialized}


def _deser_regions(data: dict, job_id: str) -> dict:
    regions = {}
    for seq_str, box_dicts in data.get("edge_regions_by_frame", {}).items():
        regions[int(seq_str)] = [deserialize_bounding_box(b) for b in box_dicts]
    return {"edge_regions_by_frame": regions}


def _ser_field_seg(state: dict, job_id: str) -> dict:
    """Serialize field segmentation — boundaries + coverage + mask overlays."""
    boundaries = state.get("field_boundaries", {})
    coverage = state.get("field_coverage", {})
    mask_overlays = state.get("field_mask_overlays", {})
    return {
        "field_boundaries": {str(k): v for k, v in boundaries.items()},
        "field_coverage": {str(k): v for k, v in coverage.items()},
        "mask_overlays_by_frame": {str(k): v for k, v in mask_overlays.items()},
    }


def _deser_field_seg(data: dict, job_id: str) -> dict:
    boundaries = {int(k): v for k, v in data.get("field_boundaries", {}).items()}
    coverage = {int(k): v for k, v in data.get("field_coverage", {}).items()}
    return {"field_boundaries": boundaries, "field_coverage": coverage}


def register():
    edge_detection = StageDefinition(
        name="detect_edges",
        label="Edge Detection",
        description="Canny + HoughLinesP — find horizontal line pairs (hoarding boundaries)",
        category="cv_analysis",
        io=StageIO(
            reads=["filtered_frames"],
            writes=["edge_regions_by_frame"],
        ),
        config_fields=[
            StageConfigField(name="edge_canny_low", type="int", default=50, description="Canny low threshold", min=0, max=255),
            StageConfigField(name="edge_canny_high", type="int", default=150, description="Canny high threshold", min=0, max=255),
            StageConfigField(name="edge_hough_threshold", type="int", default=80, description="Hough accumulator threshold", min=1, max=500),
            StageConfigField(name="edge_hough_min_length", type="int", default=100, description="Min line length (px)", min=10, max=2000),
            StageConfigField(name="edge_hough_max_gap", type="int", default=10, description="Max line gap (px)", min=1, max=100),
            StageConfigField(name="edge_pair_max_distance", type="int", default=200, description="Max distance between line pair (px)", min=10, max=500),
            StageConfigField(name="edge_pair_min_distance", type="int", default=15, description="Min distance between line pair (px)", min=5, max=200),
        ],
    )
    register_stage(edge_detection, serialize_fn=_ser_regions, deserialize_fn=_deser_regions)

    field_seg = StageDefinition(
        name="field_segmentation",
        label="Field Segmentation",
        description="HSV green mask — detect pitch boundaries",
        category="cv_analysis",
        io=StageIO(
            reads=["filtered_frames"],
            writes=["field_mask"],
        ),
        config_fields=[
            StageConfigField(name="hue_low", type="int", default=30, description="HSV hue lower bound", min=0, max=180),
            StageConfigField(name="hue_high", type="int", default=85, description="HSV hue upper bound", min=0, max=180),
            StageConfigField(name="sat_low", type="int", default=30, description="HSV saturation lower bound", min=0, max=255),
            StageConfigField(name="sat_high", type="int", default=255, description="HSV saturation upper bound", min=0, max=255),
            StageConfigField(name="val_low", type="int", default=30, description="HSV value lower bound", min=0, max=255),
            StageConfigField(name="val_high", type="int", default=255, description="HSV value upper bound", min=0, max=255),
            StageConfigField(name="morph_kernel", type="int", default=15, description="Morphology kernel size", min=3, max=51),
            StageConfigField(name="min_area_ratio", type="float", default=0.05, description="Min contour area", min=0.01, max=0.5),
        ],
    )
    register_stage(field_seg, serialize_fn=_ser_field_seg, deserialize_fn=_deser_field_seg)
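

# Checkpoint round-trip for the edge stage, sketched as comments because it
# depends on serializer internals (assumes serialize_dataclass_list yields
# plain dicts that deserialize_bounding_box can rebuild):
#
#     blob = _ser_regions({"edge_regions_by_frame": {0: boxes}}, job_id="j1")
#     state = _deser_regions(blob, job_id="j1")
#     assert state["edge_regions_by_frame"][0][0].label == boxes[0].label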
60 core/detect/stages/registry/detection.py (Normal file)
@@ -0,0 +1,60 @@
"""Registration for detection stages: YOLO, OCR."""

from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import (
    serialize_dataclass_list,
    serialize_text_candidates,
    deserialize_bounding_box,
)


def _ser_detect(state: dict, job_id: str) -> dict:
    boxes = state.get("boxes_by_frame", {})
    serialized = {str(seq): serialize_dataclass_list(bl) for seq, bl in boxes.items()}
    return {"boxes_by_frame": serialized}


def _deser_detect(data: dict, job_id: str) -> dict:
    boxes = {}
    for seq_str, box_dicts in data.get("boxes_by_frame", {}).items():
        boxes[int(seq_str)] = [deserialize_bounding_box(b) for b in box_dicts]
    return {"boxes_by_frame": boxes}


def _ser_ocr(state: dict, job_id: str) -> dict:
    candidates = state.get("text_candidates", [])
    return {"text_candidates": serialize_text_candidates(candidates)}


def _deser_ocr(data: dict, job_id: str) -> dict:
    return {"_text_candidates_raw": data["text_candidates"]}


def register():
    yolo = StageDefinition(
        name="detect_objects",
        label="Object Detection",
        description="YOLO object detection on filtered frames",
        category="detection",
        io=StageIO(reads=["filtered_frames"], writes=["boxes_by_frame"]),
        config_fields=[
            StageConfigField(name="model_name", type="str", default="yolov8n.pt", description="YOLO model file"),
            StageConfigField(name="confidence_threshold", type="float", default=0.3, description="Min detection confidence", min=0.0, max=1.0),
            StageConfigField(name="target_classes", type="list[str]", default=[], description="YOLO classes to detect (empty = all)"),
        ],
    )
    register_stage(yolo, serialize_fn=_ser_detect, deserialize_fn=_deser_detect)

    ocr = StageDefinition(
        name="run_ocr",
        label="OCR",
        description="Extract text from detected regions",
        category="detection",
        io=StageIO(reads=["filtered_frames", "boxes_by_frame"], writes=["text_candidates"]),
        config_fields=[
            StageConfigField(name="languages", type="list[str]", default=["en"], description="OCR languages"),
            StageConfigField(name="min_confidence", type="float", default=0.5, description="Min OCR confidence", min=0.0, max=1.0),
        ],
    )
    register_stage(ocr, serialize_fn=_ser_ocr, deserialize_fn=_deser_ocr)
60 core/detect/stages/registry/escalation.py (Normal file)
@@ -0,0 +1,60 @@
"""Registration for escalation stages: local VLM, cloud LLM."""

from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import (
    serialize_dataclass_list,
    serialize_text_candidates,
    deserialize_brand_detection,
)


def _ser_escalation(state: dict, job_id: str) -> dict:
    matched = state.get("detections", [])
    unresolved = state.get("unresolved_candidates", [])
    return {
        "detections": serialize_dataclass_list(matched),
        "unresolved_candidates": serialize_text_candidates(unresolved),
    }


def _deser_escalation(data: dict, job_id: str) -> dict:
    detections = [deserialize_brand_detection(d) for d in data.get("detections", [])]
    return {
        "detections": detections,
        "_unresolved_raw": data.get("unresolved_candidates", []),
    }


def register():
    vlm = StageDefinition(
        name="escalate_vlm",
        label="Local VLM",
        description="Process unresolved crops with moondream2",
        category="escalation",
        io=StageIO(
            reads=["unresolved_candidates"],
            writes=["detections", "unresolved_candidates"],
            optional_reads=["source_asset_id"],
        ),
        config_fields=[
            StageConfigField(name="min_confidence", type="float", default=0.5, description="Min VLM confidence", min=0.0, max=1.0),
        ],
    )
    register_stage(vlm, serialize_fn=_ser_escalation, deserialize_fn=_deser_escalation)

    cloud = StageDefinition(
        name="escalate_cloud",
        label="Cloud LLM",
        description="Escalate remaining crops to cloud provider",
        category="escalation",
        io=StageIO(
            reads=["unresolved_candidates"],
            writes=["detections"],
            optional_reads=["source_asset_id"],
        ),
        config_fields=[
            StageConfigField(name="min_confidence", type="float", default=0.4, description="Min cloud confidence", min=0.0, max=1.0),
        ],
    )
    register_stage(cloud, serialize_fn=_ser_escalation, deserialize_fn=_deser_escalation)
30 core/detect/stages/registry/output.py (Normal file)
@@ -0,0 +1,30 @@
"""Registration for output stages: report compilation."""

from core.detect.stages.models import StageDefinition, StageIO
from core.detect.stages.base import register_stage
from ._serializers import serialize_dataclass, deserialize_detection_report


def _ser_report(state: dict, job_id: str) -> dict:
    report = state.get("report")
    if report is None:
        return {"report": None}
    return {"report": serialize_dataclass(report)}


def _deser_report(data: dict, job_id: str) -> dict:
    raw = data.get("report")
    if raw is None:
        return {"report": None}
    return {"report": deserialize_detection_report(raw)}


def register():
    report = StageDefinition(
        name="compile_report",
        label="Report",
        description="Merge detections and compile final report",
        category="output",
        io=StageIO(reads=["detections"], writes=["report"]),
        config_fields=[],
    )
    register_stage(report, serialize_fn=_ser_report, deserialize_fn=_deser_report)
82 core/detect/stages/registry/preprocessing.py (Normal file)
@@ -0,0 +1,82 @@
"""Registration for preprocessing stages: frame extraction, scene filter, image preprocessing."""

from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import serialize_frame_meta


def _ser_extract(state: dict, job_id: str) -> dict:
    frames = state.get("frames", [])
    meta = [serialize_frame_meta(f) for f in frames]
    return {"frames_meta": meta, "frame_count": len(frames)}


def _deser_extract(data: dict, job_id: str) -> dict:
    # Frames are ephemeral — re-extract from chunks on demand.
    # Store metadata so we know what was extracted.
    return {"_frames_meta": data.get("frames_meta", [])}


def _ser_filter(state: dict, job_id: str) -> dict:
    filtered = state.get("filtered_frames", [])
    seqs = [f.sequence for f in filtered]
    return {"filtered_frame_sequences": seqs}


def _deser_filter(data: dict, job_id: str) -> dict:
    return {"_filtered_sequences": data["filtered_frame_sequences"]}


def _ser_preprocess(state: dict, job_id: str) -> dict:
    # Preprocessed crops are numpy arrays — regenerable from frames + boxes + config
    crops = state.get("preprocessed_crops", {})
    return {"crop_keys": list(crops.keys()), "count": len(crops)}


def _deser_preprocess(data: dict, job_id: str) -> dict:
    # Crops are regenerable — no need to restore from checkpoint
    return {"preprocessed_crops": {}}


def register():
    extract = StageDefinition(
        name="extract_frames",
        label="Frame Extraction",
        description="Extract frames from video at configurable FPS",
        category="preprocessing",
        io=StageIO(reads=["video_path"], writes=["frames"]),
        config_fields=[
            StageConfigField(name="fps", type="float", default=2.0, description="Frames per second", min=0.1, max=30.0),
            StageConfigField(name="max_frames", type="int", default=500, description="Maximum frames to extract", min=1, max=10000),
        ],
    )
    register_stage(extract, serialize_fn=_ser_extract, deserialize_fn=_deser_extract)

    scene_filter = StageDefinition(
        name="filter_scenes",
        label="Scene Filter",
        description="Deduplicate similar frames using perceptual hashing",
        category="preprocessing",
        io=StageIO(reads=["frames"], writes=["filtered_frames"]),
        config_fields=[
            StageConfigField(name="hamming_threshold", type="int", default=8, description="Hamming distance threshold", min=0, max=64),
        ],
    )
    register_stage(scene_filter, serialize_fn=_ser_filter, deserialize_fn=_deser_filter)

    preprocess = StageDefinition(
        name="preprocess",
        label="Preprocess",
        description="Image preprocessing on detected regions before OCR",
        category="preprocessing",
        io=StageIO(
            reads=["filtered_frames", "boxes_by_frame"],
            writes=["preprocessed_crops"],
        ),
        config_fields=[
            StageConfigField(name="contrast", type="bool", default=True, description="CLAHE contrast enhancement"),
            StageConfigField(name="deskew", type="bool", default=False, description="Correct slight rotation"),
            StageConfigField(name="binarize", type="bool", default=False, description="Otsu binarization"),
        ],
    )
    register_stage(preprocess, serialize_fn=_ser_preprocess, deserialize_fn=_deser_preprocess)
44 core/detect/stages/registry/resolution.py (Normal file)
@@ -0,0 +1,44 @@
"""Registration for resolution stages: brand resolver."""

from core.detect.stages.models import StageDefinition, StageIO, StageConfigField
from core.detect.stages.base import register_stage
from ._serializers import (
    serialize_dataclass_list,
    serialize_text_candidates,
    deserialize_brand_detection,
)


def _ser_brands(state: dict, job_id: str) -> dict:
    matched = state.get("detections", [])
    unresolved = state.get("unresolved_candidates", [])
    return {
        "detections": serialize_dataclass_list(matched),
        "unresolved_candidates": serialize_text_candidates(unresolved),
    }


def _deser_brands(data: dict, job_id: str) -> dict:
    detections = [deserialize_brand_detection(d) for d in data.get("detections", [])]
    return {
        "detections": detections,
        "_unresolved_raw": data.get("unresolved_candidates", []),
    }


def register():
    resolver = StageDefinition(
        name="match_brands",
        label="Brand Resolver",
        description="Match OCR text against known brands (session + global DB)",
        category="resolution",
        io=StageIO(
            reads=["text_candidates"],
            writes=["detections", "unresolved_candidates"],
            optional_reads=["session_brands", "source_asset_id"],
        ),
        config_fields=[
            StageConfigField(name="fuzzy_threshold", type="int", default=75, description="Fuzzy match threshold", min=0, max=100),
        ],
    )
    register_stage(resolver, serialize_fn=_ser_brands, deserialize_fn=_deser_brands)
86 core/detect/stages/scene_filter.py (Normal file)
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
Stage 2 — Scene Filter
|
||||
|
||||
Removes near-duplicate frames using perceptual hashing (pHash).
|
||||
Frames with a hamming distance below the threshold are considered
|
||||
duplicates and dropped. This dramatically reduces work for downstream
|
||||
CV stages without losing unique visual content.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import imagehash
|
||||
from PIL import Image
|
||||
|
||||
from core.detect import emit
|
||||
from core.detect.models import Frame
|
||||
from core.detect.stages.models import SceneFilterConfig
|
||||
|
||||
|
||||
def _compute_hashes(frames: list[Frame]) -> list[imagehash.ImageHash]:
|
||||
"""Compute perceptual hashes for all frames."""
|
||||
hashes = []
|
||||
for f in frames:
|
||||
img = Image.fromarray(f.image)
|
||||
h = imagehash.phash(img)
|
||||
f.perceptual_hash = str(h)
|
||||
hashes.append(h)
|
||||
return hashes
|
||||
|
||||
|
||||
def _dedup(frames: list[Frame], hashes: list[imagehash.ImageHash], threshold: int) -> list[Frame]:
|
||||
"""Greedy dedup: keep a frame if it's sufficiently different from all kept frames."""
|
||||
kept = [frames[0]]
|
||||
kept_hashes = [hashes[0]]
|
||||
|
||||
for i in range(1, len(frames)):
|
||||
is_duplicate = any(hashes[i] - kh < threshold for kh in kept_hashes)
|
||||
if not is_duplicate:
|
||||
kept.append(frames[i])
|
||||
kept_hashes.append(hashes[i])
|
||||
|
||||
return kept
|
||||
|
||||
|
||||
def scene_filter(
|
||||
frames: list[Frame],
|
||||
config: SceneFilterConfig,
|
||||
job_id: str | None = None,
|
||||
) -> list[Frame]:
|
||||
"""
|
||||
Filter near-duplicate frames based on perceptual hash distance.
|
||||
|
||||
Keeps the first frame in each group of similar frames.
|
||||
Returns a new list — does not mutate the input.
|
||||
"""
|
||||
if not config.enabled:
|
||||
emit.log(job_id, "SceneFilter", "INFO", "Scene filter disabled, passing all frames through")
|
||||
return frames
|
||||
|
||||
if not frames:
|
||||
return []
|
||||
|
||||
emit.log(job_id, "SceneFilter", "INFO",
|
||||
f"Filtering {len(frames)} frames (hamming_threshold={config.hamming_threshold})")
|
||||
|
||||
t0 = time.monotonic()
|
||||
hashes = _compute_hashes(frames)
|
||||
hash_ms = (time.monotonic() - t0) * 1000
|
||||
emit.log(job_id, "SceneFilter", "DEBUG",
|
||||
f"Computed {len(hashes)} perceptual hashes in {hash_ms:.0f}ms ({hash_ms/max(len(hashes),1):.1f}ms/frame)")
|
||||
|
||||
t0 = time.monotonic()
|
||||
kept = _dedup(frames, hashes, config.hamming_threshold)
|
||||
dedup_ms = (time.monotonic() - t0) * 1000
|
||||
emit.log(job_id, "SceneFilter", "DEBUG", f"Dedup pass: {dedup_ms:.0f}ms")
|
||||
|
||||
dropped = len(frames) - len(kept)
|
||||
pct = (dropped / len(frames) * 100) if frames else 0
|
||||
|
||||
emit.log(job_id, "SceneFilter", "INFO",
|
||||
f"Kept {len(kept)} frames, dropped {dropped} ({pct:.0f}% reduction)")
|
||||
emit.stats(job_id, frames_extracted=len(frames), frames_after_scene_filter=len(kept))
|
||||
|
||||
return kept
|
||||
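A note on the comparison driving `_dedup` above: `imagehash` overloads subtraction, so `h1 - h2` returns the hamming distance between two hashes directly. A minimal sketch of that check (the image paths are placeholders, not files from this repo):

```python
# Minimal sketch of the pHash comparison used by _dedup.
import imagehash
from PIL import Image

h1 = imagehash.phash(Image.open("frame_a.jpg"))  # placeholder path
h2 = imagehash.phash(Image.open("frame_b.jpg"))  # placeholder path

# Subtracting two ImageHash objects yields the hamming distance
# (0..64 for the default 8x8 pHash). Below the threshold, the second
# frame counts as a duplicate and is dropped.
if h1 - h2 < 8:
    print("near-duplicate, would be dropped")
```
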
core/detect/stages/vlm_cloud.py (new file, 201 lines)
@@ -0,0 +1,201 @@
"""
Stage 7 — Cloud LLM escalation

Last resort for crops the local VLM couldn't resolve.
Provider-agnostic — switch via CLOUD_LLM_PROVIDER env var.
Each provider has its own file under detect/providers/.

Tracks token usage and cost.
"""

from __future__ import annotations

import base64
import io
import logging
import os
import time

import numpy as np
from PIL import Image

from core.detect import emit
from core.detect.models import BrandDetection, PipelineStats, TextCandidate
from core.detect.models import CropContext
from core.detect.providers import get_provider, has_api_key

logger = logging.getLogger(__name__)

ESTIMATED_TOKENS_PER_CROP = 500


def _register_discovered_brand(brand: str, source_asset_id: str | None,
                               timestamp: float, confidence: float):
    """Register a cloud-confirmed brand in the DB."""
    try:
        from core.detect.stages.brand_resolver import _register_brand, _record_sighting
        brand_id = _register_brand(brand, "cloud_llm")
        if brand_id and source_asset_id:
            _record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, "cloud_llm")
    except Exception as e:
        logger.debug("Failed to register brand %s: %s", brand, e)


def _encode_crop(crop: np.ndarray) -> str:
    img = Image.fromarray(crop)
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=85)
    return base64.b64encode(buf.getvalue()).decode()


def _crop_image(candidate: TextCandidate) -> np.ndarray:
    frame = candidate.frame
    box = candidate.bbox
    h, w = frame.image.shape[:2]
    x1 = max(0, box.x)
    y1 = max(0, box.y)
    x2 = min(w, box.x + box.w)
    y2 = min(h, box.y + box.h)
    return frame.image[y1:y2, x1:x2]


def _parse_response(answer: str, total_tokens: int) -> dict:
    """Parse LLM free-text response into structured output."""
    parts = [p.strip() for p in answer.split(",", 2)]

    brand = parts[0] if parts else ""
    confidence = 0.5
    reasoning = answer

    if len(parts) >= 2:
        try:
            confidence = float(parts[1])
            confidence = max(0.0, min(1.0, confidence))
        except ValueError:
            pass

    if len(parts) >= 3:
        reasoning = parts[2]

    return {
        "brand": brand,
        "confidence": confidence,
        "reasoning": reasoning,
        "tokens": total_tokens or ESTIMATED_TOKENS_PER_CROP,
    }


def _call_cloud_api(image_b64: str, prompt: str) -> dict:
    """Route to the configured provider and parse the response."""
    provider = get_provider()
    result = provider.call(image_b64, prompt)
    return _parse_response(result.answer, result.total_tokens)


def escalate_cloud(
    candidates: list[TextCandidate],
    vlm_prompt_fn,
    stats: PipelineStats,
    min_confidence: float = 0.4,
    content_type: str = "",
    source_asset_id: str | None = None,
    job_id: str | None = None,
) -> list[BrandDetection]:
    """
    Send remaining unresolved crops to cloud LLM.

    Provider is selected via CLOUD_LLM_PROVIDER env var (groq, gemini, openai).
    Updates stats with call count and cost.
    """
    if not candidates:
        return []

    if os.environ.get("SKIP_CLOUD", "").strip() == "1":
        emit.log(job_id, "CloudLLM", "INFO",
                 f"SKIP_CLOUD=1, skipping {len(candidates)} crops")
        return []

    if not has_api_key():
        emit.log(job_id, "CloudLLM", "WARNING",
                 f"No API key set for cloud provider, skipping {len(candidates)} crops")
        return []

    provider = get_provider()
    emit.log(job_id, "CloudLLM", "INFO",
             f"Escalating {len(candidates)} crops to {provider.name}")

    matched: list[BrandDetection] = []
    total_cost = 0.0

    for i, candidate in enumerate(candidates):
        crop = _crop_image(candidate)
        if crop.size == 0:
            continue

        crop_context = CropContext(
            image=b"",
            surrounding_text=candidate.text,
            position_hint=f"frame {candidate.frame.sequence}",
        )
        prompt = vlm_prompt_fn(crop_context)
        image_b64 = _encode_crop(crop)

        t0 = time.monotonic()
        try:
            result = _call_cloud_api(image_b64, prompt)
        except Exception as e:
            call_ms = (time.monotonic() - t0) * 1000
            emit.log(job_id, "CloudLLM", "DEBUG",
                     f"[{i+1}/{len(candidates)}] FAILED '{candidate.text[:30]}': {e} ({call_ms:.0f}ms)")
            continue
        call_ms = (time.monotonic() - t0) * 1000

        stats.cloud_llm_calls += 1
        model_info = provider.models.get(provider.model)
        cost_per_token = model_info.cost_per_input_token if model_info else 0.00001
        call_cost = result["tokens"] * cost_per_token
        total_cost += call_cost

        brand = result["brand"]
        confidence = result["confidence"]

        emit.log(job_id, "CloudLLM", "DEBUG",
                 f"[{i+1}/{len(candidates)}] '{candidate.text[:30]}' → "
                 f"{'✓ ' + brand if brand else '✗'} "
                 f"(conf={confidence:.2f}, {result['tokens']}tok, ${call_cost:.4f}, {call_ms:.0f}ms)")

        if brand and confidence >= min_confidence:
            detection = BrandDetection(
                brand=brand,
                timestamp=candidate.frame.timestamp,
                duration=0.5,
                confidence=confidence,
                source="cloud_llm",
                bbox=candidate.bbox,
                frame_ref=candidate.frame.sequence,
                content_type=content_type,
            )
            matched.append(detection)

            emit.detection(
                job_id,
                brand=brand,
                confidence=confidence,
                source="cloud_llm",
                timestamp=candidate.frame.timestamp,
                content_type=content_type,
                frame_ref=candidate.frame.sequence,
            )

            # Register newly discovered brand in DB
            _register_discovered_brand(brand, source_asset_id,
                                       candidate.frame.timestamp, confidence)

    stats.estimated_cloud_cost_usd += total_cost
    stats.regions_escalated_to_cloud_llm = len(candidates)

    emit.log(job_id, "CloudLLM", "INFO",
             f"Cloud resolved {len(matched)}/{len(candidates)} — "
             f"cost ${total_cost:.4f} ({stats.cloud_llm_calls} calls total)")

    return matched

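`_parse_response` above assumes the provider answers in a loose `brand, confidence, reasoning` CSV shape. A quick trace of how the three cases fall out (the answers are made up):

```python
# Hypothetical answers, tracing _parse_response's three cases.
_parse_response("Adidas, 0.9, three-stripe mark visible", total_tokens=320)
# → {"brand": "Adidas", "confidence": 0.9,
#    "reasoning": "three-stripe mark visible", "tokens": 320}

_parse_response("Adidas, high", total_tokens=0)
# → confidence stays at the 0.5 default (non-numeric second field),
#   reasoning is the whole answer, and tokens falls back to
#   ESTIMATED_TOKENS_PER_CROP (500).

_parse_response("", total_tokens=0)
# → brand is "", so the caller's `if brand and ...` gate treats the
#   crop as unresolved.
```
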
core/detect/stages/vlm_local.py (new file, 157 lines)
@@ -0,0 +1,157 @@
"""
Stage 6 — Local VLM escalation (moondream2)

Processes unresolved text candidates by sending crop images + prompt
to the local VLM on the inference server. Produces BrandDetection
objects for crops the VLM can identify.
"""

from __future__ import annotations

import logging
import os
import time

import numpy as np

from core.detect import emit
from core.detect.models import BrandDetection, TextCandidate
from core.detect.models import CropContext

logger = logging.getLogger(__name__)


def _register_discovered_brand(brand: str, source_asset_id: str | None,
                               timestamp: float, confidence: float, source: str):
    """Register a VLM-confirmed brand in the DB."""
    try:
        from core.detect.stages.brand_resolver import _register_brand, _record_sighting
        brand_id = _register_brand(brand, source)
        if brand_id and source_asset_id:
            _record_sighting(source_asset_id, brand_id, brand, timestamp, confidence, source)
    except Exception as e:
        logger.debug("Failed to register brand %s: %s", brand, e)


def _crop_image(candidate: TextCandidate) -> np.ndarray:
    frame = candidate.frame
    box = candidate.bbox
    h, w = frame.image.shape[:2]
    x1 = max(0, box.x)
    y1 = max(0, box.y)
    x2 = min(w, box.x + box.w)
    y2 = min(h, box.y + box.h)
    return frame.image[y1:y2, x1:x2]


def escalate_vlm(
    candidates: list[TextCandidate],
    vlm_prompt_fn,
    inference_url: str | None = None,
    min_confidence: float = 0.5,
    content_type: str = "",
    source_asset_id: str | None = None,
    job_id: str | None = None,
) -> tuple[list[BrandDetection], list[TextCandidate]]:
    """
    Send unresolved crops to local VLM for brand identification.

    Returns:
        - matched: BrandDetections the VLM confirmed
        - still_unresolved: candidates the VLM couldn't resolve (→ cloud escalation)
    """
    if not candidates:
        return [], []

    if os.environ.get("SKIP_VLM", "").strip() == "1":
        emit.log(job_id, "VLMLocal", "INFO",
                 f"SKIP_VLM=1, skipping {len(candidates)} crops")
        return [], candidates

    emit.log(job_id, "VLMLocal", "INFO",
             f"Processing {len(candidates)} unresolved crops with moondream2")

    matched: list[BrandDetection] = []
    still_unresolved: list[TextCandidate] = []

    if inference_url:
        from core.detect.inference import InferenceClient
        from core.detect.emit import _run_log_level
        client = InferenceClient(base_url=inference_url, job_id=job_id or "", log_level=_run_log_level)

    for i, candidate in enumerate(candidates):
        crop = _crop_image(candidate)
        if crop.size == 0:
            still_unresolved.append(candidate)
            continue

        crop_context = CropContext(
            image=b"",  # not used for prompt generation
            surrounding_text=candidate.text,
            position_hint=f"frame {candidate.frame.sequence}",
        )
        prompt = vlm_prompt_fn(crop_context)

        t0 = time.monotonic()
        try:
            if inference_url:
                result = client.vlm(image=crop, prompt=prompt)
                brand = result.brand
                confidence = result.confidence
                reasoning = result.reasoning
            else:
                brand, confidence, reasoning = _vlm_local(crop, prompt)
        except Exception as e:
            vlm_ms = (time.monotonic() - t0) * 1000
            emit.log(job_id, "VLMLocal", "DEBUG",
                     f"[{i+1}/{len(candidates)}] FAILED '{candidate.text[:30]}': {e} ({vlm_ms:.0f}ms)")
            still_unresolved.append(candidate)
            continue
        vlm_ms = (time.monotonic() - t0) * 1000
        emit.log(job_id, "VLMLocal", "DEBUG",
                 f"[{i+1}/{len(candidates)}] '{candidate.text[:30]}' → "
                 f"{'✓ ' + brand if brand else '✗ unresolved'} "
                 f"(conf={confidence:.2f}, {vlm_ms:.0f}ms)")

        if brand and confidence >= min_confidence:
            detection = BrandDetection(
                brand=brand,
                timestamp=candidate.frame.timestamp,
                duration=0.5,
                confidence=confidence,
                source="local_vlm",
                bbox=candidate.bbox,
                frame_ref=candidate.frame.sequence,
                content_type=content_type,
            )
            matched.append(detection)

            emit.detection(
                job_id,
                brand=brand,
                confidence=confidence,
                source="local_vlm",
                timestamp=candidate.frame.timestamp,
                content_type=content_type,
                frame_ref=candidate.frame.sequence,
            )

            # Register newly discovered brand in DB
            _register_discovered_brand(brand, source_asset_id,
                                       candidate.frame.timestamp, confidence, "local_vlm")

            logger.debug("VLM matched: %s (%.2f) — %s", brand, confidence, reasoning)
        else:
            still_unresolved.append(candidate)

    emit.log(job_id, "VLMLocal", "INFO",
             f"VLM resolved {len(matched)}, unresolved {len(still_unresolved)} → cloud")

    return matched, still_unresolved


def _vlm_local(crop: np.ndarray, prompt: str) -> tuple[str, float, str]:
    """Run moondream2 in-process (single-box mode)."""
    from core.gpu.models.vlm import query
    result = query(crop, prompt)
    return result["brand"], result["confidence"], result["reasoning"]

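Stages 6 and 7 chain on the second element of the tuple `escalate_vlm` returns. A minimal sketch of that hand-off, assuming the surrounding node already has `candidates`, `prompt_fn`, `stats`, and `job_id` in scope (the URL is a placeholder):

```python
# Sketch of the local → cloud escalation chain.
matched, still_unresolved = escalate_vlm(
    candidates, prompt_fn, inference_url="http://gpu:8000", job_id=job_id,
)
# Only what the local VLM could not resolve incurs cloud cost.
cloud_matched = escalate_cloud(
    still_unresolved, prompt_fn, stats, job_id=job_id,
)
detections = matched + cloud_matched
```
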
core/detect/stages/yolo_detector.py (new file, 138 lines)
@@ -0,0 +1,138 @@
"""
Stage 3 — YOLO Object Detection

Detects regions of interest (logos, text, banners) in frames.
Two modes:
- Remote: calls inference server over HTTP (GPU on another machine)
- Local: imports ultralytics directly (GPU on same machine)

Emits frame_update events with bounding boxes for the UI.
"""

from __future__ import annotations

import base64
import io
import logging
import time

from PIL import Image

from core.detect import emit
from core.detect.models import BoundingBox, Frame
from core.detect.stages.models import DetectionConfig

logger = logging.getLogger(__name__)


def _frame_to_b64(frame: Frame) -> str:
    """Encode frame as base64 JPEG for SSE frame_update events."""
    img = Image.fromarray(frame.image)
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=70)
    return base64.b64encode(buf.getvalue()).decode()


def _detect_remote(frame: Frame, config: DetectionConfig, inference_url: str,
                   job_id: str = "", log_level: str = "INFO") -> list[BoundingBox]:
    """Call the inference server over HTTP."""
    from core.detect.inference import InferenceClient
    client = InferenceClient(base_url=inference_url, job_id=job_id, log_level=log_level)
    results = client.detect(
        image=frame.image,
        model=config.model_name,
        confidence=config.confidence_threshold,
        target_classes=config.target_classes,
    )
    boxes = []
    for r in results:
        box = BoundingBox(
            x=r.x, y=r.y, w=r.w, h=r.h,
            confidence=r.confidence, label=r.label,
        )
        boxes.append(box)
    return boxes


def _detect_local(frame: Frame, config: DetectionConfig) -> list[BoundingBox]:
    """Run YOLO in-process (requires ultralytics installed)."""
    from ultralytics import YOLO
    model = YOLO(config.model_name)
    results = model(frame.image, conf=config.confidence_threshold, verbose=False)

    boxes = []
    for r in results:
        for det in r.boxes:
            x1, y1, x2, y2 = det.xyxy[0].tolist()
            label = r.names[int(det.cls[0])]

            if config.target_classes and label not in config.target_classes:
                continue

            box = BoundingBox(
                x=int(x1), y=int(y1),
                w=int(x2 - x1), h=int(y2 - y1),
                confidence=float(det.conf[0]),
                label=label,
            )
            boxes.append(box)
    return boxes


def detect_objects(
    frames: list[Frame],
    config: DetectionConfig,
    inference_url: str | None = None,
    job_id: str | None = None,
) -> dict[int, list[BoundingBox]]:
    """
    Run object detection on all frames.

    If inference_url is provided, calls the remote GPU server.
    Otherwise, imports ultralytics and runs locally.

    Returns a dict mapping frame sequence → list of bounding boxes.
    """
    mode = "remote" if inference_url else "local"
    emit.log(job_id, "YOLODetector", "INFO",
             f"Detecting objects in {len(frames)} frames "
             f"(model={config.model_name}, conf={config.confidence_threshold}, mode={mode})")

    all_boxes: dict[int, list[BoundingBox]] = {}
    total_regions = 0

    for i, frame in enumerate(frames):
        t0 = time.monotonic()
        if inference_url:
            from core.detect.emit import _run_log_level
            boxes = _detect_remote(frame, config, inference_url,
                                   job_id=job_id or "", log_level=_run_log_level)
        else:
            boxes = _detect_local(frame, config)
        det_ms = (time.monotonic() - t0) * 1000

        all_boxes[frame.sequence] = boxes
        total_regions += len(boxes)

        labels = f" [{', '.join(b.label for b in boxes)}]" if boxes else ""
        emit.log(job_id, "YOLODetector", "DEBUG",
                 f"Frame {frame.sequence}: {len(boxes)} regions in {det_ms:.0f}ms{labels}")

        if boxes and job_id:
            box_dicts = [{"x": b.x, "y": b.y, "w": b.w, "h": b.h,
                          "confidence": b.confidence, "label": b.label}
                         for b in boxes]
            emit.frame_update(
                job_id,
                frame_ref=frame.sequence,
                timestamp=frame.timestamp,
                jpeg_b64=_frame_to_b64(frame),
                boxes=box_dicts,
            )

    emit.log(job_id, "YOLODetector", "INFO",
             f"Detected {total_regions} regions across {len(frames)} frames")
    emit.stats(job_id, regions_detected=total_regions)

    return all_boxes

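A minimal driver for this stage, assuming `frames` came out of the extraction stage and a `DetectionConfig` is already in hand (the inference URL is a placeholder):

```python
# Placeholder URL; pass inference_url=None to run ultralytics in-process.
boxes_by_frame = detect_objects(
    frames,
    config,
    inference_url="http://gpu:8000",
    job_id="job-123",
)
for seq, boxes in boxes_by_frame.items():
    print(seq, [(b.label, round(b.confidence, 2)) for b in boxes])
```
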
core/detect/state.py (new file, 44 lines)
@@ -0,0 +1,44 @@
"""
LangGraph state definition for the detection pipeline.

This TypedDict flows through all graph nodes. Each node reads what
it needs and writes its outputs. LangGraph manages the state transitions.
"""

from __future__ import annotations

from typing import TypedDict

from core.detect.models import BoundingBox, BrandDetection, DetectionReport, Frame, PipelineStats, TextCandidate


class DetectState(TypedDict, total=False):
    # Input
    video_path: str
    job_id: str
    timeline_id: str
    profile_name: str
    source_asset_id: str  # UUID of the source MediaAsset

    # Stage outputs
    frames: list[Frame]
    filtered_frames: list[Frame]
    field_masks: dict  # {seq: np.ndarray} — pitch mask per frame
    field_boundaries: dict  # {seq: [(x,y), ...]} — pitch boundary per frame
    field_coverage: dict  # {seq: float} — pitch coverage ratio per frame
    edge_regions_by_frame: dict[int, list[BoundingBox]]
    boxes_by_frame: dict[int, list[BoundingBox]]
    preprocessed_crops: dict  # "{frame_seq}_{box_idx}" → np.ndarray
    text_candidates: list[TextCandidate]
    unresolved_candidates: list[TextCandidate]
    detections: list[BrandDetection]
    report: DetectionReport

    # Session brands (accumulated during the run, persisted to DB)
    session_brands: dict  # {normalized_name: canonical_name}

    # Running stats (updated by each stage)
    stats: PipelineStats

    # Config overrides for replay (merged into profile configs dict)
    config_overrides: dict

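Because the class is declared with `total=False`, a node returns only the keys it produces and LangGraph merges that partial dict back into the state. A hedged sketch of a node over this state (`scene_filter` and `cfg` are assumed to be in scope; the node body is illustrative, not code from this repo):

```python
# Illustrative node: reads "frames", writes "filtered_frames".
def node_scene_filter(state: DetectState) -> dict:
    frames = state["frames"]
    kept = scene_filter(frames, cfg, job_id=state.get("job_id"))
    # Return only the keys this node writes; LangGraph merges the
    # partial dict into the full DetectState.
    return {"filtered_frames": kept}
```
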
core/detect/tracing.py (new file, 131 lines)
@@ -0,0 +1,131 @@
"""
Langfuse tracing for the detection pipeline.

Provides span helpers that graph nodes use to record timing, frame counts,
and stage-level metadata. The Langfuse client is optional — if not configured
(no LANGFUSE_SECRET_KEY), tracing is a no-op.

Usage in graph nodes:
    from core.detect.tracing import trace_node

    def node_extract_frames(state):
        with trace_node(state, "extract_frames") as span:
            ...
            span.set_output({"frames": len(frames)})
        return {...}
"""

from __future__ import annotations

import logging
import os
import time
from contextlib import contextmanager
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)

_client = None
_enabled: bool | None = None


def _get_client():
    """Lazy-init Langfuse client. Returns None if not configured."""
    global _client, _enabled
    if _enabled is False:
        return None
    if _client is not None:
        return _client

    secret = os.environ.get("LANGFUSE_SECRET_KEY", "")
    if not secret:
        _enabled = False
        logger.info("Langfuse not configured (no LANGFUSE_SECRET_KEY), tracing disabled")
        return None

    try:
        from langfuse import Langfuse
        _client = Langfuse()
        _enabled = True
        logger.info("Langfuse tracing enabled")
        return _client
    except Exception as e:
        _enabled = False
        logger.warning("Langfuse init failed: %s — tracing disabled", e)
        return None


@dataclass
class SpanContext:
    """Wraps a Langfuse span with convenience methods."""
    _span: object | None = None
    _start: float = field(default_factory=time.monotonic)
    metadata: dict = field(default_factory=dict)

    def set_output(self, output: dict) -> None:
        self.metadata.update(output)

    def set_error(self, error: str) -> None:
        self.metadata["error"] = error

    def _finish(self, status: str = "ok") -> None:
        elapsed = time.monotonic() - self._start
        self.metadata["duration_seconds"] = round(elapsed, 3)
        self.metadata["status"] = status

        if self._span is not None:
            try:
                self._span.update(
                    output=self.metadata,
                    level="ERROR" if status == "error" else "DEFAULT",
                )
                self._span.end()
            except Exception as e:
                logger.debug("Failed to end Langfuse span: %s", e)


@contextmanager
def trace_node(state: dict, node_name: str):
    """
    Context manager that creates a Langfuse span for a pipeline node.

    Usage:
        with trace_node(state, "extract_frames") as span:
            frames = do_work()
            span.set_output({"frames": len(frames)})
    """
    job_id = state.get("job_id", "unknown")
    profile = state.get("profile_name", "")
    client = _get_client()

    span_obj = None
    if client is not None:
        try:
            trace = client.trace(
                name=f"detect:{job_id}",
                session_id=job_id,
                metadata={"profile": profile},
            )
            span_obj = trace.span(
                name=node_name,
                input={"job_id": job_id, "profile": profile},
            )
        except Exception as e:
            logger.debug("Failed to create Langfuse span: %s", e)

    ctx = SpanContext(_span=span_obj)
    try:
        yield ctx
        ctx._finish("ok")
    except Exception:
        ctx._finish("error")
        raise


def flush():
    """Flush pending Langfuse events. Call at pipeline end."""
    if _client is not None:
        try:
            _client.flush()
        except Exception as e:
            logger.debug("Langfuse flush failed: %s", e)

core/events.py (new file, 44 lines)
@@ -0,0 +1,44 @@
"""
Redis-based event bus for pipeline job progress.

Pipeline stages push events, SSE endpoints poll them.
Only depends on redis — safe to import from any context.
"""

import json
import os

import redis

REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")


def _get_redis():
    return redis.from_url(REDIS_URL, decode_responses=True)


def push_event(
    job_id: str, event_type: str, data: dict, prefix: str = "chunk_events"
) -> None:
    """Push an event to the Redis list for a job."""
    r = _get_redis()
    key = f"{prefix}:{job_id}"
    event = json.dumps({"event": event_type, **data})
    r.rpush(key, event)
    r.expire(key, 3600)


def poll_events(
    job_id: str, cursor: int = 0, prefix: str = "chunk_events"
) -> tuple[list[dict], int]:
    """Poll new events from Redis. Returns (events, new_cursor)."""
    r = _get_redis()
    key = f"{prefix}:{job_id}"
    raw_events = r.lrange(key, cursor, -1)
    parsed = []
    for raw in raw_events:
        try:
            parsed.append(json.loads(raw))
        except (json.JSONDecodeError, TypeError):
            pass
    return parsed, cursor + len(raw_events)

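The cursor returned by `poll_events` is just the list offset, so a consumer can drain incrementally without re-reading old events. A minimal polling loop, assuming something upstream is pushing events for the job (the job id and sleep interval are placeholders):

```python
import time

cursor = 0
while True:
    events, cursor = poll_events("job-123", cursor)
    for ev in events:
        print(ev["event"], ev)  # e.g. forward as an SSE message
    time.sleep(0.5)  # arbitrary poll interval
```
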
@@ -1,13 +1,6 @@
from .capabilities import get_decoders, get_encoders, get_formats
from .probe import ProbeResult, probe_file
from .transcode import TranscodeConfig, transcode

__all__ = [
    "probe_file",
    "ProbeResult",
    "transcode",
    "TranscodeConfig",
    "get_encoders",
    "get_decoders",
    "get_formats",
]

@@ -1,145 +0,0 @@
"""
FFmpeg capabilities - Discover available codecs and formats using ffmpeg-python.
"""

from dataclasses import dataclass
from functools import lru_cache
from typing import Any, Dict, List

import ffmpeg


@dataclass
class Codec:
    """An FFmpeg encoder or decoder."""

    name: str
    description: str
    type: str  # 'video' or 'audio'


@dataclass
class Format:
    """An FFmpeg format (muxer/demuxer)."""

    name: str
    description: str
    can_demux: bool
    can_mux: bool


@lru_cache(maxsize=1)
def _get_ffmpeg_info() -> Dict[str, Any]:
    """Get FFmpeg capabilities info."""
    # ffmpeg-python doesn't have a direct way to get codecs/formats
    # but we can use probe on a dummy or parse -codecs output
    # For now, return common codecs that are typically available
    return {
        "video_encoders": [
            {"name": "libx264", "description": "H.264 / AVC"},
            {"name": "libx265", "description": "H.265 / HEVC"},
            {"name": "mpeg4", "description": "MPEG-4 Part 2"},
            {"name": "libvpx", "description": "VP8"},
            {"name": "libvpx-vp9", "description": "VP9"},
            {"name": "h264_nvenc", "description": "NVIDIA NVENC H.264"},
            {"name": "hevc_nvenc", "description": "NVIDIA NVENC H.265"},
            {"name": "h264_vaapi", "description": "VAAPI H.264"},
            {"name": "prores_ks", "description": "Apple ProRes"},
            {"name": "dnxhd", "description": "Avid DNxHD/DNxHR"},
            {"name": "copy", "description": "Stream copy (no encoding)"},
        ],
        "audio_encoders": [
            {"name": "aac", "description": "AAC"},
            {"name": "libmp3lame", "description": "MP3"},
            {"name": "libopus", "description": "Opus"},
            {"name": "libvorbis", "description": "Vorbis"},
            {"name": "pcm_s16le", "description": "PCM signed 16-bit little-endian"},
            {"name": "flac", "description": "FLAC"},
            {"name": "copy", "description": "Stream copy (no encoding)"},
        ],
        "formats": [
            {"name": "mp4", "description": "MP4", "can_demux": True, "can_mux": True},
            {"name": "mov", "description": "QuickTime / MOV", "can_demux": True, "can_mux": True},
            {"name": "mkv", "description": "Matroska", "can_demux": True, "can_mux": True},
            {"name": "webm", "description": "WebM", "can_demux": True, "can_mux": True},
            {"name": "avi", "description": "AVI", "can_demux": True, "can_mux": True},
            {"name": "flv", "description": "FLV", "can_demux": True, "can_mux": True},
            {"name": "ts", "description": "MPEG-TS", "can_demux": True, "can_mux": True},
            {"name": "mpegts", "description": "MPEG-TS", "can_demux": True, "can_mux": True},
            {"name": "hls", "description": "HLS", "can_demux": True, "can_mux": True},
        ],
    }


def get_encoders() -> List[Codec]:
    """Get available encoders (video + audio)."""
    info = _get_ffmpeg_info()
    codecs = []

    for c in info["video_encoders"]:
        codecs.append(Codec(name=c["name"], description=c["description"], type="video"))

    for c in info["audio_encoders"]:
        codecs.append(Codec(name=c["name"], description=c["description"], type="audio"))

    return codecs


def get_decoders() -> List[Codec]:
    """Get available decoders."""
    # Most encoders can also decode
    return get_encoders()


def get_formats() -> List[Format]:
    """Get available formats."""
    info = _get_ffmpeg_info()
    return [
        Format(
            name=f["name"],
            description=f["description"],
            can_demux=f["can_demux"],
            can_mux=f["can_mux"],
        )
        for f in info["formats"]
    ]


def get_video_encoders() -> List[Codec]:
    """Get available video encoders."""
    return [c for c in get_encoders() if c.type == "video"]


def get_audio_encoders() -> List[Codec]:
    """Get available audio encoders."""
    return [c for c in get_encoders() if c.type == "audio"]


def get_muxers() -> List[Format]:
    """Get available output formats (muxers)."""
    return [f for f in get_formats() if f.can_mux]


def get_demuxers() -> List[Format]:
    """Get available input formats (demuxers)."""
    return [f for f in get_formats() if f.can_demux]

@@ -1,225 +0,0 @@
"""
FFmpeg transcode module - Transcode media files using ffmpeg-python.
"""

import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import ffmpeg


@dataclass
class TranscodeConfig:
    """Configuration for a transcode operation."""

    input_path: str
    output_path: str

    # Video
    video_codec: str = "libx264"
    video_bitrate: Optional[str] = None
    video_crf: Optional[int] = None
    video_preset: Optional[str] = None
    resolution: Optional[str] = None
    framerate: Optional[float] = None

    # Audio
    audio_codec: str = "aac"
    audio_bitrate: Optional[str] = None
    audio_channels: Optional[int] = None
    audio_samplerate: Optional[int] = None

    # Trimming
    trim_start: Optional[float] = None
    trim_end: Optional[float] = None

    # Container
    container: str = "mp4"

    # Extra args (key-value pairs)
    extra_args: List[str] = field(default_factory=list)

    @property
    def is_copy(self) -> bool:
        """Check if this is a stream copy (no transcoding)."""
        return self.video_codec == "copy" and self.audio_codec == "copy"


def build_stream(config: TranscodeConfig):
    """
    Build an ffmpeg-python stream from config.

    Returns the stream object ready to run.
    """
    # Input options
    input_kwargs = {}
    if config.trim_start is not None:
        input_kwargs["ss"] = config.trim_start

    stream = ffmpeg.input(config.input_path, **input_kwargs)

    # Output options
    output_kwargs = {
        "vcodec": config.video_codec,
        "acodec": config.audio_codec,
    }

    # Trimming duration
    if config.trim_end is not None:
        if config.trim_start is not None:
            output_kwargs["t"] = config.trim_end - config.trim_start
        else:
            output_kwargs["t"] = config.trim_end

    # Video options (skip if copy)
    if config.video_codec != "copy":
        if config.video_crf is not None:
            output_kwargs["crf"] = config.video_crf
        elif config.video_bitrate:
            output_kwargs["video_bitrate"] = config.video_bitrate

        if config.video_preset:
            output_kwargs["preset"] = config.video_preset

        if config.resolution:
            output_kwargs["s"] = config.resolution

        if config.framerate:
            output_kwargs["r"] = config.framerate

    # Audio options (skip if copy)
    if config.audio_codec != "copy":
        if config.audio_bitrate:
            output_kwargs["audio_bitrate"] = config.audio_bitrate
        if config.audio_channels:
            output_kwargs["ac"] = config.audio_channels
        if config.audio_samplerate:
            output_kwargs["ar"] = config.audio_samplerate

    # Parse extra args into kwargs
    extra_kwargs = parse_extra_args(config.extra_args)
    output_kwargs.update(extra_kwargs)

    stream = ffmpeg.output(stream, config.output_path, **output_kwargs)
    stream = ffmpeg.overwrite_output(stream)

    return stream


def parse_extra_args(extra_args: List[str]) -> Dict[str, Any]:
    """
    Parse extra args list into kwargs dict.

    ["-vtag", "xvid", "-pix_fmt", "yuv420p"] -> {"vtag": "xvid", "pix_fmt": "yuv420p"}
    """
    kwargs = {}
    i = 0
    while i < len(extra_args):
        key = extra_args[i].lstrip("-")
        if i + 1 < len(extra_args) and not extra_args[i + 1].startswith("-"):
            kwargs[key] = extra_args[i + 1]
            i += 2
        else:
            # Flag without value
            kwargs[key] = None
            i += 1
    return kwargs


def transcode(
    config: TranscodeConfig,
    duration: Optional[float] = None,
    progress_callback: Optional[Callable[[float, Dict[str, Any]], None]] = None,
) -> bool:
    """
    Transcode a media file.

    Args:
        config: Transcode configuration
        duration: Total duration in seconds (for progress calculation, optional)
        progress_callback: Called with (percent, details_dict) - requires duration

    Returns:
        True if successful

    Raises:
        ffmpeg.Error: If transcoding fails
    """
    # Ensure output directory exists
    Path(config.output_path).parent.mkdir(parents=True, exist_ok=True)

    stream = build_stream(config)

    if progress_callback and duration:
        # Run with progress tracking using run_async
        return _run_with_progress(stream, config, duration, progress_callback)
    else:
        # Run synchronously
        ffmpeg.run(stream, capture_stdout=True, capture_stderr=True)
        return True


def _run_with_progress(
    stream,
    config: TranscodeConfig,
    duration: float,
    progress_callback: Callable[[float, Dict[str, Any]], None],
) -> bool:
    """Run FFmpeg with progress tracking using run_async and stderr parsing."""
    import re

    # Calculate effective duration
    effective_duration = duration
    if config.trim_start and config.trim_end:
        effective_duration = config.trim_end - config.trim_start
    elif config.trim_end:
        effective_duration = config.trim_end
    elif config.trim_start:
        effective_duration = duration - config.trim_start

    # Run async to get process handle
    process = ffmpeg.run_async(stream, pipe_stdout=True, pipe_stderr=True)

    # Parse stderr for progress (time=HH:MM:SS.cs pattern; the fraction
    # ffmpeg prints is centiseconds, hence the /100 below)
    time_pattern = re.compile(r"time=(\d+):(\d+):(\d+)\.(\d+)")

    while True:
        line = process.stderr.readline()
        if not line:
            break

        line = line.decode("utf-8", errors="ignore")
        match = time_pattern.search(line)
        if match:
            hours = int(match.group(1))
            minutes = int(match.group(2))
            seconds = int(match.group(3))
            centis = int(match.group(4))

            current_time = hours * 3600 + minutes * 60 + seconds + centis / 100
            percent = min(100.0, (current_time / effective_duration) * 100)

            progress_callback(
                percent,
                {
                    "time": current_time,
                    "percent": percent,
                },
            )

    # Wait for completion
    process.wait()

    if process.returncode != 0:
        raise ffmpeg.Error(
            "ffmpeg", stdout=process.stdout.read(), stderr=process.stderr.read()
        )

    # Final callback
    progress_callback(
        100.0, {"time": effective_duration, "percent": 100.0, "done": True}
    )

    return True

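One behavior of the removed `parse_extra_args` worth noting, read off the code above: a flag with no following value maps to `None`, which ffmpeg-python renders as a bare flag (it skips appending `None` values when building the command line). Illustration:

```python
parse_extra_args(["-vtag", "xvid", "-an"])
# → {"vtag": "xvid", "an": None}
# ffmpeg-python turns the None-valued kwarg into a bare "-an" flag.
```
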
core/gpu/.env.template (new file, 18 lines)
@@ -0,0 +1,18 @@
# Inference server configuration
HOST=0.0.0.0
PORT=8000

# VRAM management
VRAM_BUDGET_MB=10240
STRATEGY=sequential  # sequential | concurrent | auto

# Model defaults
YOLO_MODEL=yolov8n.pt
YOLO_CONFIDENCE=0.3

# OCR
OCR_LANGUAGES=en,es
OCR_MIN_CONFIDENCE=0.5

# Device
DEVICE=auto  # auto | cpu | cuda | cuda:0

core/gpu/Dockerfile (new file, 21 lines)
@@ -0,0 +1,21 @@
FROM python:3.11-slim

RUN pip install --no-cache-dir uv

RUN apt-get update && apt-get install -y \
    libgl1 libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

ENV PATH="/app/.venv/bin:$PATH"

# uv.lock is generated on first build (dev box can't reach paddlepaddle's index)
COPY pyproject.toml ./
RUN uv sync --no-install-project --no-dev

COPY . .

EXPOSE 8000

CMD ["python", "server.py"]

core/gpu/__init__.py (new file, 0 lines)

core/gpu/config.py (new file, 39 lines)
@@ -0,0 +1,39 @@
"""
Runtime config — loaded from env, mutable via API.

The UI config panel is just a visual editor for these same values.
"""

from __future__ import annotations

import os

_config = {
    "device": os.environ.get("DEVICE", "auto"),
    "yolo_model": os.environ.get("YOLO_MODEL", "yolov8n.pt"),
    "yolo_confidence": float(os.environ.get("YOLO_CONFIDENCE", "0.3")),
    "vram_budget_mb": int(os.environ.get("VRAM_BUDGET_MB", "10240")),
    "strategy": os.environ.get("STRATEGY", "sequential"),
    "ocr_languages": os.environ.get("OCR_LANGUAGES", "en").split(","),
    "ocr_min_confidence": float(os.environ.get("OCR_MIN_CONFIDENCE", "0.5")),
}


def get_config() -> dict:
    return _config


def update_config(changes: dict) -> dict:
    _config.update(changes)
    return _config


def get_device() -> str:
    device = _config["device"]
    if device != "auto":
        return device
    try:
        import torch
        return "cuda" if torch.cuda.is_available() else "cpu"
    except ImportError:
        return "cpu"

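Because `_config` is a module-level dict mutated in place, changes made through `update_config` are visible to every caller immediately, and `get_device()` re-resolves `"auto"` on each call rather than caching it. A quick sketch:

```python
update_config({"yolo_confidence": 0.5})
assert get_config()["yolo_confidence"] == 0.5

# "auto" is resolved lazily: "cuda" if torch sees a GPU, otherwise "cpu".
print(get_device())
```
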
core/gpu/emit.py (new file, 52 lines)
@@ -0,0 +1,52 @@
"""
Lightweight event emitter for the GPU inference server.

Pushes debug logs to the same Redis stream as the pipeline orchestrator,
so GPU-side details (model load, VRAM, inference timing) appear in the
same log panel.

Only active when the request includes an X-Job-Id header.
No dependency on the detect package.
"""

from __future__ import annotations

import json
import os
from datetime import datetime, timezone

import redis

REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
EVENTS_PREFIX = "detect_events"

_LEVEL_ORDER = {"DEBUG": 0, "INFO": 1, "WARN": 2, "ERROR": 3}

_redis_client = None


def _get_redis():
    global _redis_client
    if _redis_client is None:
        _redis_client = redis.from_url(REDIS_URL, decode_responses=True)
    return _redis_client


def log(job_id: str, stage: str, level: str, msg: str, log_level: str = "INFO"):
    """Push a log event to Redis if the level meets the threshold."""
    if not job_id:
        return
    if _LEVEL_ORDER.get(level.upper(), 1) < _LEVEL_ORDER.get(log_level.upper(), 1):
        return

    r = _get_redis()
    key = f"{EVENTS_PREFIX}:{job_id}"
    event = json.dumps({
        "event": "log",
        "level": level,
        "stage": stage,
        "msg": msg,
        "ts": datetime.now(timezone.utc).isoformat(),
    })
    r.rpush(key, event)
    r.expire(key, 3600)

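The threshold check compares numeric ranks, so a message below the run's log level is silently dropped. For instance:

```python
# Threshold INFO (rank 1): DEBUG (rank 0) is dropped, ERROR (rank 3) goes out.
log("job-123", "YOLO", "DEBUG", "loaded yolov8n.pt", log_level="INFO")   # no-op
log("job-123", "YOLO", "ERROR", "CUDA out of memory", log_level="INFO")  # pushed
# Note: level strings missing from the table (e.g. "WARNING" vs its "WARN")
# fall back to rank 1, so they are filtered like INFO.
```
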
core/gpu/models/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
# GPU models — standalone container imports.
# When running as a container (cd gpu && python server.py), bare imports work.
# When imported from the main app (core.gpu.models.preprocess), only
# individual modules should be imported directly, not this __init__.
#
# The server.py imports detect/ocr/vlm directly, not through this file.

core/gpu/models/cv/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""CV operations — pure OpenCV, no ML models."""

core/gpu/models/cv/edges.py (new file, 258 lines)
@@ -0,0 +1,258 @@
"""
Edge detection — Canny + HoughLinesP → parallel line pairs → bounding boxes.

Finds horizontal line pairs with consistent spacing, which correspond to
the top and bottom edges of advertising hoardings.
"""

from __future__ import annotations

import base64
import io

import cv2
import numpy as np


def detect_edges(
    image: np.ndarray,
    canny_low: int = 50,
    canny_high: int = 150,
    hough_threshold: int = 80,
    hough_min_length: int = 100,
    hough_max_gap: int = 10,
    pair_max_distance: int = 200,
    pair_min_distance: int = 15,
) -> list[dict]:
    """
    Find horizontal line pairs that likely bound advertising hoardings.

    Returns list of dicts with keys: x, y, w, h, confidence, label.
    Each box represents the region between a detected pair of parallel
    horizontal lines.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, canny_low, canny_high)

    raw_lines = cv2.HoughLinesP(
        edges,
        rho=1,
        theta=np.pi / 180,
        threshold=hough_threshold,
        minLineLength=hough_min_length,
        maxLineGap=hough_max_gap,
    )

    if raw_lines is None:
        return []

    # Filter to near-horizontal lines (within 10 degrees)
    horizontals = _filter_horizontal(raw_lines, max_angle_deg=10)

    if len(horizontals) < 2:
        return []

    # Find pairs of parallel horizontals with consistent spacing
    pairs = _find_line_pairs(
        horizontals,
        min_distance=pair_min_distance,
        max_distance=pair_max_distance,
    )

    # Convert pairs to bounding boxes
    h, w = image.shape[:2]
    results = []
    for top_line, bottom_line in pairs:
        box = _pair_to_bbox(top_line, bottom_line, frame_width=w, frame_height=h)
        if box is not None:
            results.append(box)

    return results


def _filter_horizontal(lines: np.ndarray, max_angle_deg: float = 10) -> list[tuple]:
    """Keep only lines within max_angle_deg of horizontal."""
    max_slope = np.tan(np.radians(max_angle_deg))
    result = []
    for line in lines:
        x1, y1, x2, y2 = line[0]
        dx = x2 - x1
        if dx == 0:
            continue
        slope = abs((y2 - y1) / dx)
        if slope <= max_slope:
            y_mid = (y1 + y2) / 2
            x_min = min(x1, x2)
            x_max = max(x1, x2)
            length = np.sqrt(dx**2 + (y2 - y1) ** 2)
            result.append((x_min, x_max, y_mid, length))
    return result


def _find_line_pairs(
    horizontals: list[tuple],
    min_distance: int,
    max_distance: int,
) -> list[tuple]:
    """
    Find pairs of horizontal lines that could be top/bottom of a hoarding.

    Lines must overlap horizontally and be spaced within [min_distance, max_distance].
    """
    # Sort by y position
    sorted_lines = sorted(horizontals, key=lambda l: l[2])

    pairs = []
    used = set()

    for i, top in enumerate(sorted_lines):
        if i in used:
            continue
        for j, bottom in enumerate(sorted_lines[i + 1 :], start=i + 1):
            if j in used:
                continue

            y_gap = bottom[2] - top[2]
            if y_gap < min_distance:
                continue
            if y_gap > max_distance:
                break  # sorted by y, no point checking further

            # Check horizontal overlap
            overlap_start = max(top[0], bottom[0])
            overlap_end = min(top[1], bottom[1])
            overlap = overlap_end - overlap_start

            # Require at least 50% overlap relative to shorter line
            shorter_length = min(top[1] - top[0], bottom[1] - bottom[0])
            if shorter_length > 0 and overlap / shorter_length >= 0.5:
                pairs.append((top, bottom))
                used.add(i)
                used.add(j)
                break

    return pairs


def _pair_to_bbox(
    top: tuple,
    bottom: tuple,
    frame_width: int,
    frame_height: int,
) -> dict | None:
    """Convert a line pair to a bounding box dict."""
    x = int(max(0, min(top[0], bottom[0])))
    y = int(max(0, top[2]))
    x2 = int(min(frame_width, max(top[1], bottom[1])))
    y2 = int(min(frame_height, bottom[2]))
    w = x2 - x
    h = y2 - y

    if w < 20 or h < 5:
        return None

    # Confidence based on line lengths relative to box width
    avg_line_length = (top[3] + bottom[3]) / 2
    coverage = min(1.0, avg_line_length / max(w, 1))

    return {
        "x": x,
        "y": y,
        "w": w,
        "h": h,
        "confidence": round(coverage, 3),
        "label": "edge_region",
    }


def _np_to_b64_jpeg(image: np.ndarray, quality: int = 70) -> str:
    """Encode a numpy image (BGR or grayscale) as base64 JPEG."""
    ok, buf = cv2.imencode(".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, quality])
    if not ok:
        return ""
    return base64.b64encode(buf.tobytes()).decode()


def detect_edges_debug(
    image: np.ndarray,
    canny_low: int = 50,
    canny_high: int = 150,
    hough_threshold: int = 80,
    hough_min_length: int = 100,
    hough_max_gap: int = 10,
    pair_max_distance: int = 200,
    pair_min_distance: int = 15,
) -> dict:
    """
    Same as detect_edges but returns intermediate visualizations.

    Returns dict with:
        regions: list[dict] — same boxes as detect_edges
        edge_overlay_b64: str — Canny edge image as base64 JPEG
        lines_overlay_b64: str — frame with Hough lines drawn
        horizontal_count: int — number of horizontal lines found
        pair_count: int — number of line pairs found
    """
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, canny_low, canny_high)

    # Edge overlay — Canny output as-is (white edges on black)
    edge_overlay_b64 = _np_to_b64_jpeg(edges)

    raw_lines = cv2.HoughLinesP(
        edges,
        rho=1,
        theta=np.pi / 180,
        threshold=hough_threshold,
        minLineLength=hough_min_length,
        maxLineGap=hough_max_gap,
    )

    # Lines overlay — draw all Hough lines on a copy of the frame
    lines_vis = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    if raw_lines is not None:
        for line in raw_lines:
            x1, y1, x2, y2 = line[0]
            cv2.line(lines_vis, (x1, y1), (x2, y2), (0, 0, 255), 1)

    horizontals = []
    if raw_lines is not None:
        horizontals = _filter_horizontal(raw_lines, max_angle_deg=10)

    # Draw horizontal lines in cyan, thicker
    for h_line in horizontals:
        x_min, x_max, y_mid, _ = h_line
        cv2.line(lines_vis, (int(x_min), int(y_mid)), (int(x_max), int(y_mid)), (255, 255, 0), 2)

    pairs = []
    if len(horizontals) >= 2:
        pairs = _find_line_pairs(
            horizontals,
            min_distance=pair_min_distance,
            max_distance=pair_max_distance,
        )

    # Draw paired lines in green
    for top_line, bottom_line in pairs:
        cv2.line(lines_vis, (int(top_line[0]), int(top_line[2])),
                 (int(top_line[1]), int(top_line[2])), (0, 255, 0), 2)
        cv2.line(lines_vis, (int(bottom_line[0]), int(bottom_line[2])),
                 (int(bottom_line[1]), int(bottom_line[2])), (0, 255, 0), 2)

    lines_overlay_b64 = _np_to_b64_jpeg(lines_vis)

    # Build region boxes (same logic as detect_edges)
    h, w = image.shape[:2]
    regions = []
    for top_line, bottom_line in pairs:
        box = _pair_to_bbox(top_line, bottom_line, frame_width=w, frame_height=h)
        if box is not None:
            regions.append(box)

    return {
        "regions": regions,
        "edge_overlay_b64": edge_overlay_b64,
        "lines_overlay_b64": lines_overlay_b64,
        "horizontal_count": len(horizontals),
        "pair_count": len(pairs),
    }

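A minimal driver for the edge detector, assuming an RGB frame loaded as a numpy array (the file path is a placeholder):

```python
import cv2

# detect_edges expects RGB input; cv2.imread returns BGR.
bgr = cv2.imread("frame.jpg")  # placeholder path
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

for region in detect_edges(rgb, pair_min_distance=20, pair_max_distance=150):
    print(f"{region['label']}: ({region['x']},{region['y']}) "
          f"{region['w']}x{region['h']} conf={region['confidence']}")
```
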
core/gpu/models/cv/segmentation.py (new file, 86 lines)
@@ -0,0 +1,86 @@
"""
Field segmentation — HSV green mask → pitch boundary contour.

Pure OpenCV. Called by the inference server endpoint.
"""

from __future__ import annotations

import base64

import cv2
import numpy as np


def segment_field(
    image: np.ndarray,
    hue_low: int = 30,
    hue_high: int = 85,
    sat_low: int = 30,
    sat_high: int = 255,
    val_low: int = 30,
    val_high: int = 255,
    morph_kernel: int = 15,
    min_area_ratio: float = 0.05,
) -> dict:
    """
    Detect the pitch area using HSV green thresholding.

    Returns dict with:
        boundary: list of [x, y] points
        coverage: float (fraction of frame)
        mask: np.ndarray (refined binary mask)
    """
    hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)

    lower = np.array([hue_low, sat_low, val_low])
    upper = np.array([hue_high, sat_high, val_high])
    mask = cv2.inRange(hsv, lower, upper)

    k = morph_kernel
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    h, w = image.shape[:2]
    min_area = min_area_ratio * h * w
    boundary = []
    coverage = 0.0

    if contours:
        large = [c for c in contours if cv2.contourArea(c) >= min_area]
        if large:
            pitch_contour = max(large, key=cv2.contourArea)
            boundary = pitch_contour.reshape(-1, 2).tolist()
            coverage = cv2.contourArea(pitch_contour) / (h * w)

            refined = np.zeros_like(mask)
            cv2.drawContours(refined, [pitch_contour], -1, 255, cv2.FILLED)
            mask = refined

    return {
        "boundary": boundary,
        "coverage": coverage,
        "mask": mask,
    }


def segment_field_debug(
    image: np.ndarray,
    **kwargs,
) -> dict:
    """Same as segment_field but includes a mask overlay for the editor."""
    result = segment_field(image, **kwargs)
    mask = result["mask"]

    # RGBA overlay: solid green where mask, fully transparent elsewhere
    h, w = image.shape[:2]
    overlay = np.zeros((h, w, 4), dtype=np.uint8)
    overlay[mask > 0] = [0, 255, 0, 255]
    _, buf = cv2.imencode(".png", overlay)
    result["mask_overlay_b64"] = base64.b64encode(buf.tobytes()).decode()

    # Don't send the raw mask over HTTP
    del result["mask"]
    return result

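Worth remembering when tuning these defaults: OpenCV's 8-bit hue channel runs 0 to 179 (not 0 to 359), so the 30 to 85 window covers roughly yellow-green through green. A tiny check, assuming an RGB `frame` array is already in scope:

```python
# Inspect how much of a frame the default green window captures.
result = segment_field(frame)  # frame: RGB np.ndarray, assumed in scope
print(f"pitch coverage: {result['coverage']:.1%}, "
      f"boundary points: {len(result['boundary'])}")
```
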
136
core/gpu/models/models.py
Normal file
136
core/gpu/models/models.py
Normal file
@@ -0,0 +1,136 @@
"""
Pydantic Models - GENERATED FILE

Do not edit directly. Regenerate using modelgen.
"""

from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID

from pydantic import BaseModel, Field


class DetectRequest(BaseModel):
    """Request body for object detection."""
    image: str
    model: Optional[str] = None
    confidence: Optional[float] = None
    target_classes: Optional[List[str]] = None


class BBox(BaseModel):
    """A detected bounding box."""
    x: int
    y: int
    w: int
    h: int
    confidence: float
    label: str


class DetectResponse(BaseModel):
    """Response from object detection."""
    detections: List[BBox] = Field(default_factory=list)


class OCRRequest(BaseModel):
    """Request body for OCR."""
    image: str
    languages: Optional[List[str]] = None


class OCRTextResult(BaseModel):
    """A single OCR text extraction result."""
    text: str
    confidence: float
    bbox: List[int] = Field(default_factory=list)


class OCRResponse(BaseModel):
    """Response from OCR."""
    results: List[OCRTextResult] = Field(default_factory=list)


class PreprocessRequest(BaseModel):
    """Request body for image preprocessing."""
    image: str
    binarize: bool = False
    deskew: bool = False
    contrast: bool = True


class PreprocessResponse(BaseModel):
    """Response from preprocessing."""
    image: str


class VLMRequest(BaseModel):
    """Request body for visual language model query."""
    image: str
    prompt: str
    model: Optional[str] = None


class VLMResponse(BaseModel):
    """Response from VLM."""
    brand: str
    confidence: float
    reasoning: str


class AnalyzeRegionsRequest(BaseModel):
    """Request body for CV region analysis."""
    image: str
    edge_canny_low: int = 50
    edge_canny_high: int = 150
    edge_hough_threshold: int = 80
    edge_hough_min_length: int = 100
    edge_hough_max_gap: int = 10
    edge_pair_max_distance: int = 200
    edge_pair_min_distance: int = 15


class RegionBox(BaseModel):
    """A candidate region from CV analysis."""
    x: int
    y: int
    w: int
    h: int
    confidence: float
    label: str


class AnalyzeRegionsResponse(BaseModel):
    """Response from CV region analysis."""
    regions: List[RegionBox] = Field(default_factory=list)


class AnalyzeRegionsDebugResponse(BaseModel):
    """Response from CV region analysis with debug overlays."""
    regions: List[RegionBox] = Field(default_factory=list)
    edge_overlay_b64: str = ""
    lines_overlay_b64: str = ""
    horizontal_count: int = 0
    pair_count: int = 0


class SegmentFieldRequest(BaseModel):
    """Request body for field segmentation."""
    image: str
    hue_low: int = 30
    hue_high: int = 85
    sat_low: int = 30
    sat_high: int = 255
    val_low: int = 30
    val_high: int = 255
    morph_kernel: int = 15
    min_area_ratio: float = 0.05


class SegmentFieldResponse(BaseModel):
    """Response from field segmentation."""
    boundary: List[List[int]] = Field(default_factory=list)
    coverage: float = 0.0
    mask_b64: str = ""


class SegmentFieldDebugResponse(BaseModel):
    """Response from field segmentation with debug overlay."""
    boundary: List[List[int]] = Field(default_factory=list)
    coverage: float = 0.0
    mask_overlay_b64: str = ""


class ConfigUpdate(BaseModel):
    """Request body for updating server configuration."""
    device: Optional[str] = None
    yolo_model: Optional[str] = None
    yolo_confidence: Optional[float] = None
    vram_budget_mb: Optional[int] = None
    strategy: Optional[str] = None
    ocr_languages: Optional[List[str]] = None
    ocr_min_confidence: Optional[float] = None
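A sketch of how these generated models validate traffic on the wire (assumes Pydantic v2, which FastAPI >= 0.109 uses; payload values are placeholders):

```python
import base64

from models.models import DetectRequest, DetectResponse

# Placeholder image bytes — in practice a JPEG/PNG crop from the pipeline
req = DetectRequest(image=base64.b64encode(b"...").decode(), confidence=0.4)

# Server-side JSON validates straight into the response model
resp = DetectResponse.model_validate({"detections": [
    {"x": 10, "y": 20, "w": 64, "h": 32, "confidence": 0.91, "label": "board"},
]})
print(resp.detections[0].label)
```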
105 core/gpu/models/ocr.py Normal file
@@ -0,0 +1,105 @@
"""PaddleOCR 3.x text extraction wrapper."""

from __future__ import annotations

import logging

from models import registry
from config import get_config

logger = logging.getLogger(__name__)


def _load(languages: list[str]):
    from paddleocr import PaddleOCR
    key = f"ocr_{'_'.join(languages)}"
    model = PaddleOCR(lang=languages[0])
    registry.put(key, model)
    return model


def _get(languages: list[str] | None = None):
    langs = languages or get_config()["ocr_languages"]
    key = f"ocr_{'_'.join(langs)}"
    model = registry.get(key)
    if model is None:
        model = _load(langs)
    return model


def _parse_raw(raw) -> list[tuple[list, str, float]]:
    """
    Parse PaddleOCR output into (points, text, confidence) tuples.

    PaddleOCR 3.x changed the result format. Two known layouts:

    Layout A — dict-based (new pipeline API):
        raw = [{'rec_texts': [...], 'rec_scores': [...], 'dt_polys': [...]}]

    Layout B — nested list (2.x compat / some 3.x builds):
        raw = [[ [points, [text, score]], ... ]]
        raw = [[ [points, [text, score], [cls, cls_score]], ... ]]  # with angle cls
    """
    results = []

    for page in raw:
        if not page:
            continue

        # Layout A: dict with parallel lists
        if isinstance(page, dict):
            texts = page.get("rec_texts", [])
            scores = page.get("rec_scores", [])
            polys = page.get("dt_polys", [])
            for points, text, confidence in zip(polys, texts, scores):
                results.append((points, text, float(confidence)))
            continue

        # Layout B: list of per-line entries
        for line in page:
            if not line:
                continue

            # line[0] is always the polygon points
            points = line[0]

            # line[1] is [text, score] — ignore any extra elements (angle cls etc.)
            rec = line[1]
            if isinstance(rec, (list, tuple)) and len(rec) >= 2:
                text, confidence = rec[0], rec[1]
            else:
                logger.warning("Unexpected OCR line format: %s", line)
                continue

            results.append((points, str(text), float(confidence)))

    return results
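To make the two layouts concrete, a hypothetical Layout B page and what `_parse_raw` yields from it:

```python
from models.ocr import _parse_raw

raw = [[
    # one line entry: polygon points, then (text, score)
    [[[0, 0], [50, 0], [50, 20], [0, 20]], ("NIKE", 0.97)],
]]

assert _parse_raw(raw) == [([[0, 0], [50, 0], [50, 20], [0, 20]], "NIKE", 0.97)]
```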
def ocr(image, languages: list[str] | None = None, min_confidence: float | None = None) -> list[dict]:
    """Run OCR on an image, return list of text result dicts."""
    cfg = get_config()
    min_conf = min_confidence if min_confidence is not None else cfg["ocr_min_confidence"]
    model = _get(languages)

    raw = model.ocr(image)
    logger.debug("OCR raw: %s", raw)

    parsed = _parse_raw(raw)

    results = []
    for points, text, confidence in parsed:
        if confidence < min_conf:
            continue

        xs = [p[0] for p in points]
        ys = [p[1] for p in points]

        results.append({
            "text": text,
            "confidence": confidence,
            "bbox": [int(min(xs)), int(min(ys)),
                     int(max(xs) - min(xs)), int(max(ys) - min(ys))],
        })

    return results
117 core/gpu/models/preprocess.py Normal file
@@ -0,0 +1,117 @@
"""
Image preprocessing pipeline for crops before OCR.

Each step is independently toggleable via config.
Operates on numpy arrays (BGR or RGB), returns processed array.
"""

from __future__ import annotations

import logging

import numpy as np

logger = logging.getLogger(__name__)


def binarize(image: np.ndarray, threshold: int = 128) -> np.ndarray:
    """Convert to grayscale and apply Otsu binarization.

    With THRESH_OTSU set, OpenCV picks the threshold from the image
    histogram, so the `threshold` argument is effectively ignored.
    """
    import cv2

    if len(image.shape) == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    else:
        gray = image

    _, binary = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Convert back to 3-channel for downstream compatibility
    result = cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB)
    return result


def deskew(image: np.ndarray) -> np.ndarray:
    """Correct slight rotation using minimum area rectangle."""
    import cv2

    if len(image.shape) == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    else:
        gray = image

    coords = np.column_stack(np.where(gray < 128))
    if len(coords) < 10:
        return image

    rect = cv2.minAreaRect(coords)
    angle = rect[-1]

    # Normalize angle
    if angle < -45:
        angle = -(90 + angle)
    else:
        angle = -angle

    if abs(angle) < 0.5:
        return image

    h, w = image.shape[:2]
    center = (w // 2, h // 2)
    rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
    result = cv2.warpAffine(
        image, rotation_matrix, (w, h),
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_REPLICATE,
    )
    return result


def enhance_contrast(image: np.ndarray) -> np.ndarray:
    """Apply CLAHE (adaptive histogram equalization) for contrast normalization."""
    import cv2

    if len(image.shape) == 3:
        lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
        l_channel = lab[:, :, 0]
    else:
        l_channel = image

    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(l_channel)

    if len(image.shape) == 3:
        lab[:, :, 0] = enhanced
        result = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
    else:
        result = enhanced

    return result


def preprocess(
    image: np.ndarray,
    do_binarize: bool = False,
    do_deskew: bool = False,
    do_contrast: bool = True,
) -> np.ndarray:
    """
    Run the preprocessing pipeline on a crop image.

    Each step is independently toggleable. Order: contrast → deskew → binarize.
    Contrast first (works best on color), binarize last (destroys color info).
    """
    result = image

    if do_contrast:
        result = enhance_contrast(result)
        logger.debug("Preprocessing: contrast enhanced")

    if do_deskew:
        result = deskew(result)
        logger.debug("Preprocessing: deskewed")

    if do_binarize:
        result = binarize(result)
        logger.debug("Preprocessing: binarized")

    return result
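A minimal sketch exercising the pipeline on a synthetic crop (the import path assumes the module layout above):

```python
import numpy as np

from models.preprocess import preprocess

crop = np.full((64, 128, 3), 128, dtype=np.uint8)  # flat gray RGB crop

# Runs contrast → deskew → binarize, per the documented order
out = preprocess(crop, do_binarize=True, do_deskew=True, do_contrast=True)
assert out.shape == crop.shape and out.dtype == np.uint8
```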
37 core/gpu/models/registry.py Normal file
@@ -0,0 +1,37 @@
"""
Model registry — manages loaded models and VRAM lifecycle.
"""

from __future__ import annotations

import logging

logger = logging.getLogger(__name__)

_models: dict[str, object] = {}


def get(name: str) -> object | None:
    return _models.get(name)


def put(name: str, model: object) -> None:
    _models[name] = model
    logger.info("Loaded %s", name)


def unload(name: str) -> bool:
    if name in _models:
        del _models[name]
        logger.info("Unloaded %s", name)
        return True
    return False


def loaded() -> list[str]:
    return list(_models.keys())


def clear() -> None:
    _models.clear()
    logger.info("All models unloaded")
100 core/gpu/models/vlm.py Normal file
@@ -0,0 +1,100 @@
"""moondream2 visual language model wrapper."""

from __future__ import annotations

import logging

from models import registry
from config import get_config

logger = logging.getLogger(__name__)

_MODEL_KEY = "vlm_moondream2"


def _load():
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = get_config().get("device", "auto")
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"

    logger.info("Loading moondream2 (device=%s)...", device)

    model_id = "vikhyatk/moondream2"
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    dtype = torch.float16 if "cuda" in device else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        dtype=dtype,
        device_map=device,
    )

    wrapper = {"model": model, "tokenizer": tokenizer}
    registry.put(_MODEL_KEY, wrapper)
    logger.info("moondream2 loaded")
    return wrapper


def _get():
    wrapper = registry.get(_MODEL_KEY)
    if wrapper is None:
        wrapper = _load()
    return wrapper


def query(image, prompt: str) -> dict:
    """
    Query moondream2 with an image crop and prompt.

    Returns {"brand": str, "confidence": float, "reasoning": str}
    """
    from PIL import Image as PILImage

    wrapper = _get()
    model = wrapper["model"]
    tokenizer = wrapper["tokenizer"]

    # Convert numpy array to PIL if needed
    if not isinstance(image, PILImage.Image):
        image = PILImage.fromarray(image)

    enc_image = model.encode_image(image)
    answer = model.answer_question(enc_image, prompt, tokenizer)

    # Parse response — moondream2 returns free text, extract brand + confidence
    result = _parse_vlm_response(answer)
    return result


def _parse_vlm_response(answer: str) -> dict:
    """
    Parse moondream2 free-text response into structured output.

    Expected format from prompt: "brand, confidence (0-1), reasoning"
    Falls back gracefully if format doesn't match.
    """
    answer = answer.strip()
    parts = [p.strip() for p in answer.split(",", 2)]

    brand = parts[0] if parts else ""
    confidence = 0.5
    reasoning = answer

    if len(parts) >= 2:
        try:
            confidence = float(parts[1])
            confidence = max(0.0, min(1.0, confidence))
        except ValueError:
            pass

    if len(parts) >= 3:
        reasoning = parts[2]

    return {
        "brand": brand,
        "confidence": confidence,
        "reasoning": reasoning,
    }
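For illustration, how the fallback parsing behaves on a well-formed answer versus free text (values are hypothetical):

```python
from models.vlm import _parse_vlm_response

# Well-formed "brand, confidence, reasoning" answer
_parse_vlm_response("adidas, 0.85, three white stripes on a blue board")
# -> {"brand": "adidas", "confidence": 0.85,
#     "reasoning": "three white stripes on a blue board"}

# Free text: brand falls back to the first comma-separated chunk,
# confidence stays at the 0.5 default, reasoning is the whole answer
_parse_vlm_response("This looks like a Heineken banner")
# -> {"brand": "This looks like a Heineken banner", "confidence": 0.5,
#     "reasoning": "This looks like a Heineken banner"}
```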
54 core/gpu/models/yolo.py Normal file
@@ -0,0 +1,54 @@
"""YOLO object detection model wrapper."""

from __future__ import annotations

import logging

from models import registry
from config import get_config, get_device

logger = logging.getLogger(__name__)


def _load(model_name: str):
    from ultralytics import YOLO
    device = get_device()
    model = YOLO(model_name)
    model.to(device)
    registry.put(model_name, model)
    return model


def _get(model_name: str | None = None):
    name = model_name or get_config()["yolo_model"]
    model = registry.get(name)
    if model is None:
        model = _load(name)
    return model


def detect(image, model_name: str | None = None, confidence: float | None = None, target_classes: list[str] | None = None) -> list[dict]:
    """Run YOLO detection, return list of bbox dicts."""
    cfg = get_config()
    conf = confidence if confidence is not None else cfg["yolo_confidence"]
    model = _get(model_name)

    results = model(image, conf=conf, verbose=False)

    detections = []
    for r in results:
        for box in r.boxes:
            x1, y1, x2, y2 = box.xyxy[0].tolist()
            label = r.names[int(box.cls[0])]

            if target_classes and label not in target_classes:
                continue

            detections.append({
                "x": int(x1), "y": int(y1),
                "w": int(x2 - x1), "h": int(y2 - y1),
                "confidence": float(box.conf[0]),
                "label": label,
            })

    return detections
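A usage sketch (the model name and class filter are assumptions, not pipeline defaults; omitted arguments fall back to config):

```python
import numpy as np

from models.yolo import detect

frame = np.zeros((720, 1280, 3), dtype=np.uint8)  # placeholder frame

boxes = detect(frame, model_name="yolov8n.pt", target_classes=["person"])
for b in boxes:
    print(b["label"], b["confidence"], (b["x"], b["y"], b["w"], b["h"]))
```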
42 core/gpu/pyproject.toml Normal file
@@ -0,0 +1,42 @@
[project]
name = "mpr-gpu"
version = "0.1.0"
description = "MPR remote inference server (GPU)"
requires-python = ">=3.11"
dependencies = [
    "fastapi>=0.109.0",
    "uvicorn[standard]>=0.27.0",
    "rapidfuzz>=3.0.0",
    "Pillow>=10.0.0",
    "redis>=5.0.0",
    "ultralytics>=8.0.0",
    "paddleocr>=3.0.0",
    "paddlepaddle-gpu==3.0.0",
    "transformers>=4.40.0,<5",
    "accelerate>=0.27.0",
    "torch",
    "torchvision",
    "opencv-python-headless>=4.8.0",
]

# RTX 3080 / CUDA toolkit 12.8 — cu126 wheels are forward-compatible
# (no cu128 wheels yet on either index). Mixing PyPI torch with CUDA 12.8
# causes NCCL symbol errors, so the explicit index pins prevent uv from
# pulling torch transitively from PyPI via ultralytics.
[tool.uv.sources]
torch = { index = "pytorch-cu126" }
torchvision = { index = "pytorch-cu126" }
paddlepaddle-gpu = { index = "paddle-cu126" }

[[tool.uv.index]]
name = "pytorch-cu126"
url = "https://download.pytorch.org/whl/cu126"
explicit = true

[[tool.uv.index]]
name = "paddle-cu126"
url = "https://www.paddlepaddle.org.cn/packages/stable/cu126/"
explicit = true

[tool.uv]
package = false
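After `uv sync`, a quick sanity check (a sketch, not part of the repo) that the cu126 wheels actually resolved:

```python
import torch

print(torch.__version__, torch.version.cuda)  # expect a +cu126 build / "12.6"
print(torch.cuda.is_available())

import paddle

paddle.utils.run_check()  # reports whether paddlepaddle-gpu can use the GPU
```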
59 core/gpu/run.sh Executable file
@@ -0,0 +1,59 @@
#!/bin/bash
# Run the inference server
#
# Usage:
#   ./run.sh              # Local: uv sync + run server.py (auto-installs/activates .venv)
#   ./run.sh docker       # Docker (CPU)
#   ./run.sh docker-gpu   # Docker with GPU
#   ./run.sh stop         # Stop Docker container

set -e
cd "$(dirname "${BASH_SOURCE[0]}")"

# Load env (create from template if missing)
if [ ! -f .env ]; then
    if [ -f .env.template ]; then
        cp .env.template .env
        echo "Created .env from template — edit as needed"
    fi
fi

if [ -f .env ]; then
    set -a
    source .env
    set +a
fi

case "${1:-local}" in
    local)
        if ! command -v uv >/dev/null 2>&1; then
            echo "uv not found. Install: curl -LsSf https://astral.sh/uv/install.sh | sh"
            exit 1
        fi
        uv sync
        uv run python server.py
        ;;
    docker)
        docker build -t mpr-inference .
        ENV_FLAG=""; [ -f .env ] && ENV_FLAG="--env-file .env"
        docker run --rm -p "${PORT:-8000}:8000" \
            $ENV_FLAG \
            --name mpr-inference \
            mpr-inference
        ;;
    docker-gpu)
        docker build -t mpr-inference .
        ENV_FLAG=""; [ -f .env ] && ENV_FLAG="--env-file .env"
        docker run --rm --gpus all -p "${PORT:-8000}:8000" \
            $ENV_FLAG \
            --name mpr-inference \
            mpr-inference
        ;;
    stop)
        docker stop mpr-inference 2>/dev/null || true
        ;;
    *)
        echo "Usage: ./run.sh [local|docker|docker-gpu|stop]"
        exit 1
        ;;
esac
399 core/gpu/server.py Normal file
@@ -0,0 +1,399 @@
"""
Inference server — thin HTTP routes over model wrappers.

Config lives in config.py, model logic in models/.
This file is just the FastAPI glue.

Usage:
    cd gpu && python server.py
"""

from __future__ import annotations

import base64
import io
import logging
import os
import time
from contextlib import asynccontextmanager

import numpy as np
from fastapi import FastAPI, HTTPException, Request
from PIL import Image
from pydantic import BaseModel

from emit import log as emit_log

from config import get_config, get_device, update_config
from models import registry
from models.yolo import detect as yolo_detect
from models.ocr import ocr as ocr_run
from models.vlm import query as vlm_query

logger = logging.getLogger(__name__)


def _decode_image(b64: str) -> np.ndarray:
    data = base64.b64decode(b64)
    img = Image.open(io.BytesIO(data)).convert("RGB")
    return np.array(img)


def _job_ctx(request: Request) -> tuple[str, str]:
    """Extract job_id and log_level from request headers."""
    job_id = request.headers.get("x-job-id", "")
    log_level = request.headers.get("x-log-level", "INFO")
    return job_id, log_level


def _gpu_log(job_id: str, log_level: str, stage: str, level: str, msg: str):
    """Emit a log event if job context is present."""
    if job_id:
        emit_log(job_id, stage, level, msg, log_level=log_level)


# --- Request/Response models (generated from core/schema/models/inference.py) ---

from models.models import (
    AnalyzeRegionsDebugResponse,
    AnalyzeRegionsRequest,
    AnalyzeRegionsResponse,
    BBox,
    ConfigUpdate,
    DetectRequest,
    DetectResponse,
    OCRRequest,
    OCRResponse,
    OCRTextResult,
    PreprocessRequest,
    PreprocessResponse,
    RegionBox,
    SegmentFieldRequest,
    SegmentFieldResponse,
    SegmentFieldDebugResponse,
    VLMRequest,
    VLMResponse,
)


# --- App ---

@asynccontextmanager
async def lifespan(app: FastAPI):
    logger.info("Inference server starting (device=%s)", get_device())
    yield
    logger.info("Shutting down")
    registry.clear()


app = FastAPI(title="MPR Inference Server", lifespan=lifespan)


@app.get("/health")
def health():
    cfg = get_config()
    return {
        "status": "ok",
        "device": get_device(),
        "loaded_models": registry.loaded(),
        "vram_budget_mb": cfg["vram_budget_mb"],
        "strategy": cfg["strategy"],
    }


@app.get("/config")
def read_config():
    return {**get_config(), "device_resolved": get_device()}


@app.put("/config")
def write_config(update: ConfigUpdate):
    changes = update.model_dump(exclude_none=True)
    if not changes:
        return get_config()

    # Unload model if it changed
    old_model = get_config().get("yolo_model")
    if "yolo_model" in changes and changes["yolo_model"] != old_model:
        registry.unload(old_model)

    update_config(changes)
    logger.info("Config updated: %s", changes)
    return {**get_config(), "device_resolved": get_device()}


@app.post("/models/unload")
def unload_model(body: dict):
    name = body.get("model", "")
    unloaded = registry.unload(name)
    return {"status": "unloaded" if unloaded else "not_loaded", "model": name}


@app.post("/detect", response_model=DetectResponse)
def detect(req: DetectRequest, request: Request):
    job_id, log_level = _job_ctx(request)

    try:
        t0 = time.monotonic()
        image = _decode_image(req.image)
        decode_ms = (time.monotonic() - t0) * 1000
        h, w = image.shape[:2]
        _gpu_log(job_id, log_level, "GPU:YOLO", "DEBUG",
                 f"Decoded {w}x{h} image in {decode_ms:.0f}ms")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Bad image: {e}")

    try:
        t0 = time.monotonic()
        results = yolo_detect(
            image,
            model_name=req.model,
            confidence=req.confidence,
            target_classes=req.target_classes,
        )
        infer_ms = (time.monotonic() - t0) * 1000
        _gpu_log(job_id, log_level, "GPU:YOLO", "DEBUG",
                 f"Inference: {len(results)} detections in {infer_ms:.0f}ms "
                 f"(model={req.model}, conf={req.confidence})")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Detection failed: {e}")

    return DetectResponse(detections=[BBox(**r) for r in results])
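A minimal client sketch for this endpoint (the `requests` dependency, host, and job id are assumptions). The two headers are what `_job_ctx` reads, so GPU-side timing logs land under the right pipeline job:

```python
import base64

import requests

with open("crop.jpg", "rb") as f:
    b64 = base64.b64encode(f.read()).decode()

resp = requests.post(
    "http://gpu-host:8000/detect",
    json={"image": b64, "confidence": 0.4},
    headers={"x-job-id": "job-123", "x-log-level": "DEBUG"},
)
resp.raise_for_status()
print(resp.json()["detections"])
```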
@app.post("/ocr", response_model=OCRResponse)
|
||||
def ocr(req: OCRRequest, request: Request):
|
||||
job_id, log_level = _job_ctx(request)
|
||||
|
||||
try:
|
||||
image = _decode_image(req.image)
|
||||
h, w = image.shape[:2]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
||||
|
||||
try:
|
||||
t0 = time.monotonic()
|
||||
results = ocr_run(image, languages=req.languages)
|
||||
infer_ms = (time.monotonic() - t0) * 1000
|
||||
texts = [r["text"][:20] for r in results]
|
||||
_gpu_log(job_id, log_level, "GPU:OCR", "DEBUG",
|
||||
f"OCR {w}x{h}: {infer_ms:.0f}ms → {len(results)} results {texts}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"OCR failed: {e}")
|
||||
|
||||
return OCRResponse(results=[OCRTextResult(**r) for r in results])
|
||||
|
||||
|
||||
@app.post("/preprocess", response_model=PreprocessResponse)
|
||||
def preprocess_image(req: PreprocessRequest):
|
||||
try:
|
||||
image = _decode_image(req.image)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
||||
|
||||
try:
|
||||
from models.preprocess import preprocess
|
||||
processed = preprocess(
|
||||
image,
|
||||
do_binarize=req.binarize,
|
||||
do_deskew=req.deskew,
|
||||
do_contrast=req.contrast,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Preprocessing failed: {e}")
|
||||
|
||||
from PIL import Image as PILImage
|
||||
import io
|
||||
img = PILImage.fromarray(processed)
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format="JPEG", quality=90)
|
||||
result_b64 = base64.b64encode(buf.getvalue()).decode()
|
||||
|
||||
return PreprocessResponse(image=result_b64)
|
||||
|
||||
|
||||
@app.post("/vlm", response_model=VLMResponse)
|
||||
def vlm(req: VLMRequest, request: Request):
|
||||
job_id, log_level = _job_ctx(request)
|
||||
|
||||
try:
|
||||
image = _decode_image(req.image)
|
||||
h, w = image.shape[:2]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
||||
|
||||
try:
|
||||
t0 = time.monotonic()
|
||||
result = vlm_query(image, req.prompt)
|
||||
infer_ms = (time.monotonic() - t0) * 1000
|
||||
_gpu_log(job_id, log_level, "GPU:VLM", "DEBUG",
|
||||
f"VLM {w}x{h}: {infer_ms:.0f}ms → "
|
||||
f"brand='{result.get('brand', '')}' conf={result.get('confidence', 0):.2f}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"VLM failed: {e}")
|
||||
|
||||
return VLMResponse(**result)
|
||||
|
||||
|
||||
@app.post("/detect_edges", response_model=AnalyzeRegionsResponse)
|
||||
def detect_edges_endpoint(req: AnalyzeRegionsRequest, request: Request):
|
||||
job_id, log_level = _job_ctx(request)
|
||||
|
||||
try:
|
||||
image = _decode_image(req.image)
|
||||
h, w = image.shape[:2]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
||||
|
||||
try:
|
||||
t0 = time.monotonic()
|
||||
from models.cv.edges import detect_edges
|
||||
|
||||
edge_regions = detect_edges(
|
||||
image,
|
||||
canny_low=req.edge_canny_low,
|
||||
canny_high=req.edge_canny_high,
|
||||
hough_threshold=req.edge_hough_threshold,
|
||||
hough_min_length=req.edge_hough_min_length,
|
||||
hough_max_gap=req.edge_hough_max_gap,
|
||||
pair_max_distance=req.edge_pair_max_distance,
|
||||
pair_min_distance=req.edge_pair_min_distance,
|
||||
)
|
||||
infer_ms = (time.monotonic() - t0) * 1000
|
||||
|
||||
_gpu_log(job_id, log_level, "GPU:CV", "DEBUG",
|
||||
f"Edge analysis {w}x{h}: {infer_ms:.0f}ms → {len(edge_regions)} regions")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Region analysis failed: {e}")
|
||||
|
||||
boxes = [RegionBox(**r) for r in edge_regions]
|
||||
return AnalyzeRegionsResponse(regions=boxes)
|
||||
|
||||
|
||||
@app.post("/detect_edges/debug", response_model=AnalyzeRegionsDebugResponse)
|
||||
def detect_edges_debug_endpoint(req: AnalyzeRegionsRequest, request: Request):
|
||||
job_id, log_level = _job_ctx(request)
|
||||
|
||||
try:
|
||||
image = _decode_image(req.image)
|
||||
h, w = image.shape[:2]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
||||
|
||||
try:
|
||||
t0 = time.monotonic()
|
||||
from models.cv.edges import detect_edges_debug
|
||||
|
||||
result = detect_edges_debug(
|
||||
image,
|
||||
canny_low=req.edge_canny_low,
|
||||
canny_high=req.edge_canny_high,
|
||||
hough_threshold=req.edge_hough_threshold,
|
||||
hough_min_length=req.edge_hough_min_length,
|
||||
hough_max_gap=req.edge_hough_max_gap,
|
||||
pair_max_distance=req.edge_pair_max_distance,
|
||||
pair_min_distance=req.edge_pair_min_distance,
|
||||
)
|
||||
infer_ms = (time.monotonic() - t0) * 1000
|
||||
|
||||
_gpu_log(job_id, log_level, "GPU:CV", "DEBUG",
|
||||
f"Edge debug {w}x{h}: {infer_ms:.0f}ms → {len(result['regions'])} regions, "
|
||||
f"{result['horizontal_count']} horizontals, {result['pair_count']} pairs")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Region debug analysis failed: {e}")
|
||||
|
||||
boxes = [RegionBox(**r) for r in result["regions"]]
|
||||
response = AnalyzeRegionsDebugResponse(
|
||||
regions=boxes,
|
||||
edge_overlay_b64=result["edge_overlay_b64"],
|
||||
lines_overlay_b64=result["lines_overlay_b64"],
|
||||
horizontal_count=result["horizontal_count"],
|
||||
pair_count=result["pair_count"],
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@app.post("/segment_field", response_model=SegmentFieldResponse)
|
||||
def segment_field_endpoint(req: SegmentFieldRequest, request: Request):
|
||||
job_id, log_level = _job_ctx(request)
|
||||
|
||||
try:
|
||||
image = _decode_image(req.image)
|
||||
h, w = image.shape[:2]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
||||
|
||||
try:
|
||||
t0 = time.monotonic()
|
||||
from models.cv.segmentation import segment_field
|
||||
|
||||
result = segment_field(
|
||||
image,
|
||||
hue_low=req.hue_low,
|
||||
hue_high=req.hue_high,
|
||||
sat_low=req.sat_low,
|
||||
sat_high=req.sat_high,
|
||||
val_low=req.val_low,
|
||||
val_high=req.val_high,
|
||||
morph_kernel=req.morph_kernel,
|
||||
min_area_ratio=req.min_area_ratio,
|
||||
)
|
||||
infer_ms = (time.monotonic() - t0) * 1000
|
||||
|
||||
# Encode mask as base64 PNG for downstream use
|
||||
import cv2
|
||||
_, buf = cv2.imencode(".png", result["mask"])
|
||||
mask_b64 = base64.b64encode(buf.tobytes()).decode()
|
||||
|
||||
_gpu_log(job_id, log_level, "GPU:CV", "DEBUG",
|
||||
f"Field segmentation {w}x{h}: {infer_ms:.0f}ms, coverage={result['coverage']:.1%}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Field segmentation failed: {e}")
|
||||
|
||||
return SegmentFieldResponse(
|
||||
boundary=result["boundary"],
|
||||
coverage=result["coverage"],
|
||||
mask_b64=mask_b64,
|
||||
)
|
||||
|
||||
|
||||
@app.post("/segment_field/debug", response_model=SegmentFieldDebugResponse)
|
||||
def segment_field_debug_endpoint(req: SegmentFieldRequest, request: Request):
|
||||
job_id, log_level = _job_ctx(request)
|
||||
|
||||
try:
|
||||
image = _decode_image(req.image)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
||||
|
||||
try:
|
||||
from models.cv.segmentation import segment_field_debug
|
||||
|
||||
result = segment_field_debug(
|
||||
image,
|
||||
hue_low=req.hue_low,
|
||||
hue_high=req.hue_high,
|
||||
sat_low=req.sat_low,
|
||||
sat_high=req.sat_high,
|
||||
val_low=req.val_low,
|
||||
val_high=req.val_high,
|
||||
morph_kernel=req.morph_kernel,
|
||||
min_area_ratio=req.min_area_ratio,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Field segmentation debug failed: {e}")
|
||||
|
||||
return SegmentFieldDebugResponse(
|
||||
boundary=result["boundary"],
|
||||
coverage=result["coverage"],
|
||||
mask_overlay_b64=result["mask_overlay_b64"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s")
|
||||
host = os.environ.get("HOST", "0.0.0.0")
|
||||
port = int(os.environ.get("PORT", "8000"))
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
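A sketch of the config round-trip against these routes (host is an assumption). Only non-None fields are applied, and per `write_config` above, changing `yolo_model` unloads the old weights while tweaking `yolo_confidence` does not:

```python
import requests

base = "http://gpu-host:8000"

print(requests.get(f"{base}/config").json())

resp = requests.put(f"{base}/config", json={"yolo_confidence": 0.35})
print(resp.json()["yolo_confidence"], resp.json()["device_resolved"])
```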
13 core/jobs/__init__.py Normal file
@@ -0,0 +1,13 @@
"""
MPR Jobs Module

Provides executor abstraction for job dispatch (local, Lambda, GCP).
"""

from .executor import Executor, LocalExecutor, get_executor

__all__ = [
    "Executor",
    "LocalExecutor",
    "get_executor",
]
core/jobs/executor.py
@@ -1,17 +1,16 @@
"""
|
||||
Executor abstraction for job processing.
|
||||
|
||||
Supports different backends:
|
||||
- LocalExecutor: FFmpeg via Celery (default)
|
||||
- LambdaExecutor: AWS Lambda (future)
|
||||
Determines WHERE jobs run:
|
||||
- LocalExecutor: delegates to registered Handler (default)
|
||||
- LambdaExecutor: AWS Step Functions
|
||||
- GCPExecutor: Google Cloud Run Jobs
|
||||
"""
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from core.ffmpeg.transcode import TranscodeConfig, transcode
|
||||
|
||||
# Configuration from environment
|
||||
MPR_EXECUTOR = os.environ.get("MPR_EXECUTOR", "local")
|
||||
|
||||
@@ -22,26 +21,18 @@ class Executor(ABC):
|
||||
@abstractmethod
|
||||
def run(
|
||||
self,
|
||||
job_type: str,
|
||||
job_id: str,
|
||||
source_path: str,
|
||||
output_path: str,
|
||||
preset: Optional[Dict[str, Any]] = None,
|
||||
trim_start: Optional[float] = None,
|
||||
trim_end: Optional[float] = None,
|
||||
duration: Optional[float] = None,
|
||||
payload: Dict[str, Any],
|
||||
progress_callback: Optional[Callable[[int, Dict[str, Any]], None]] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Execute a transcode/trim job.
|
||||
Execute a job.
|
||||
|
||||
Args:
|
||||
job_type: Type of job ("transcode", "chunk", etc.)
|
||||
job_id: Unique job identifier
|
||||
source_path: Path to source file
|
||||
output_path: Path for output file
|
||||
preset: Transcode preset dict (optional, None = trim only)
|
||||
trim_start: Trim start time in seconds (optional)
|
||||
trim_end: Trim end time in seconds (optional)
|
||||
duration: Source duration in seconds (for progress calculation)
|
||||
payload: Job-type-specific configuration dict
|
||||
progress_callback: Called with (percent, details_dict)
|
||||
|
||||
Returns:
|
||||
@@ -51,61 +42,18 @@ class Executor(ABC):
|
||||
|
||||
|
||||
class LocalExecutor(Executor):
|
||||
"""Execute jobs locally using FFmpeg."""
|
||||
"""Execute jobs locally by calling the stage function directly."""
|
||||
|
||||
def run(
|
||||
self,
|
||||
job_type: str,
|
||||
job_id: str,
|
||||
source_path: str,
|
||||
output_path: str,
|
||||
preset: Optional[Dict[str, Any]] = None,
|
||||
trim_start: Optional[float] = None,
|
||||
trim_end: Optional[float] = None,
|
||||
duration: Optional[float] = None,
|
||||
payload: Dict[str, Any],
|
||||
progress_callback: Optional[Callable[[int, Dict[str, Any]], None]] = None,
|
||||
) -> bool:
|
||||
"""Execute job using local FFmpeg."""
|
||||
|
||||
# Build config from preset or use stream copy for trim-only
|
||||
if preset:
|
||||
config = TranscodeConfig(
|
||||
input_path=source_path,
|
||||
output_path=output_path,
|
||||
video_codec=preset.get("video_codec", "libx264"),
|
||||
video_bitrate=preset.get("video_bitrate"),
|
||||
video_crf=preset.get("video_crf"),
|
||||
video_preset=preset.get("video_preset"),
|
||||
resolution=preset.get("resolution"),
|
||||
framerate=preset.get("framerate"),
|
||||
audio_codec=preset.get("audio_codec", "aac"),
|
||||
audio_bitrate=preset.get("audio_bitrate"),
|
||||
audio_channels=preset.get("audio_channels"),
|
||||
audio_samplerate=preset.get("audio_samplerate"),
|
||||
container=preset.get("container", "mp4"),
|
||||
extra_args=preset.get("extra_args", []),
|
||||
trim_start=trim_start,
|
||||
trim_end=trim_end,
|
||||
)
|
||||
else:
|
||||
# Trim-only: stream copy
|
||||
config = TranscodeConfig(
|
||||
input_path=source_path,
|
||||
output_path=output_path,
|
||||
video_codec="copy",
|
||||
audio_codec="copy",
|
||||
trim_start=trim_start,
|
||||
trim_end=trim_end,
|
||||
)
|
||||
|
||||
# Wrapper to convert float percent to int
|
||||
def wrapped_callback(percent: float, details: Dict[str, Any]) -> None:
|
||||
if progress_callback:
|
||||
progress_callback(int(percent), details)
|
||||
|
||||
return transcode(
|
||||
config,
|
||||
duration=duration,
|
||||
progress_callback=wrapped_callback if progress_callback else None,
|
||||
"""Execute job locally. Socket for PipelineRunner integration."""
|
||||
raise NotImplementedError(
|
||||
"LocalExecutor.run() — will be wired to PipelineRunner in Phase 3"
|
||||
)
|
||||
|
||||
|
||||
@@ -123,26 +71,18 @@ class LambdaExecutor(Executor):
|
||||
|
||||
def run(
|
||||
self,
|
||||
job_type: str,
|
||||
job_id: str,
|
||||
source_path: str,
|
||||
output_path: str,
|
||||
preset: Optional[Dict[str, Any]] = None,
|
||||
trim_start: Optional[float] = None,
|
||||
trim_end: Optional[float] = None,
|
||||
duration: Optional[float] = None,
|
||||
payload: Dict[str, Any],
|
||||
progress_callback: Optional[Callable[[int, Dict[str, Any]], None]] = None,
|
||||
) -> bool:
|
||||
"""Start a Step Functions execution for this job."""
|
||||
import json
|
||||
|
||||
payload = {
|
||||
sfn_payload = {
|
||||
"job_type": job_type,
|
||||
"job_id": job_id,
|
||||
"source_key": source_path,
|
||||
"output_key": output_path,
|
||||
"preset": preset,
|
||||
"trim_start": trim_start,
|
||||
"trim_end": trim_end,
|
||||
"duration": duration,
|
||||
**payload,
|
||||
"callback_url": self.callback_url,
|
||||
"api_key": self.callback_api_key,
|
||||
}
|
||||
@@ -150,10 +90,9 @@ class LambdaExecutor(Executor):
|
||||
response = self.sfn.start_execution(
|
||||
stateMachineArn=self.state_machine_arn,
|
||||
name=f"mpr-{job_id}",
|
||||
input=json.dumps(payload),
|
||||
input=json.dumps(sfn_payload),
|
||||
)
|
||||
|
||||
# Store execution ARN on the job
|
||||
execution_arn = response["executionArn"]
|
||||
try:
|
||||
from core.db import update_job_fields
|
||||
@@ -179,13 +118,9 @@ class GCPExecutor(Executor):
|
||||
|
||||
def run(
|
||||
self,
|
||||
job_type: str,
|
||||
job_id: str,
|
||||
source_path: str,
|
||||
output_path: str,
|
||||
preset: Optional[Dict[str, Any]] = None,
|
||||
trim_start: Optional[float] = None,
|
||||
trim_end: Optional[float] = None,
|
||||
duration: Optional[float] = None,
|
||||
payload: Dict[str, Any],
|
||||
progress_callback: Optional[Callable[[int, Dict[str, Any]], None]] = None,
|
||||
) -> bool:
|
||||
"""Trigger a Cloud Run Job execution for this job."""
|
||||
@@ -193,14 +128,10 @@ class GCPExecutor(Executor):
|
||||
|
||||
from google.cloud import run_v2
|
||||
|
||||
payload = {
|
||||
gcp_payload = {
|
||||
"job_type": job_type,
|
||||
"job_id": job_id,
|
||||
"source_key": source_path,
|
||||
"output_key": output_path,
|
||||
"preset": preset,
|
||||
"trim_start": trim_start,
|
||||
"trim_end": trim_end,
|
||||
"duration": duration,
|
||||
**payload,
|
||||
"callback_url": self.callback_url,
|
||||
"api_key": self.callback_api_key,
|
||||
}
|
||||
@@ -216,7 +147,8 @@ class GCPExecutor(Executor):
|
||||
run_v2.RunJobRequest.Overrides.ContainerOverride(
|
||||
env=[
|
||||
run_v2.EnvVar(
|
||||
name="MPR_JOB_PAYLOAD", value=json.dumps(payload)
|
||||
name="MPR_JOB_PAYLOAD",
|
||||
value=json.dumps(gcp_payload),
|
||||
)
|
||||
]
|
||||
)
|
||||
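For orientation, a hypothetical call against the new signature: everything job-specific now travels in one `payload` dict (the keys shown are illustrative, not a fixed schema), and `LocalExecutor.run()` raises NotImplementedError until the Phase 3 wiring lands:

```python
from core.jobs import get_executor

executor = get_executor()  # backend chosen via MPR_EXECUTOR

ok = executor.run(
    job_type="transcode",
    job_id="job-123",
    payload={
        "source_key": "media/in/clip.mp4",
        "output_key": "media/out/clip.mp4",
        "preset": {"video_codec": "libx264", "container": "mp4"},
    },
    progress_callback=lambda pct, details: print(pct, details),
)
```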
Some files were not shown because too many files have changed in this diff.