diff --git a/admin/mpr/media_assets/models.py b/admin/mpr/media_assets/models.py
index 5ef9005..e75c486 100644
--- a/admin/mpr/media_assets/models.py
+++ b/admin/mpr/media_assets/models.py
@@ -274,32 +274,3 @@ class SourceBrandSighting(models.Model):
 
     def __str__(self):
         return str(self.id)
-
-class SourceJob(models.Model):
-    """A group of chunks that belong together (same source video/session)."""
-
-    job_id = models.CharField(max_length=255)
-    source_type = models.CharField(max_length=255)
-    chunk_count = models.IntegerField()
-    total_bytes = models.IntegerField(default=0)
-
-    class Meta:
-        pass
-
-    def __str__(self):
-        return str(self.id)
-
-
-class ChunkInfo(models.Model):
-    """A single chunk (video segment) stored in blob storage."""
-
-    filename = models.CharField(max_length=500)
-    key = models.CharField(max_length=255)
-    size_bytes = models.IntegerField()
-
-    class Meta:
-        pass
-
-    def __str__(self):
-        return self.filename
-
diff --git a/core/api/main.py b/core/api/main.py
index 4c9d7f8..d3c38b8 100644
--- a/core/api/main.py
+++ b/core/api/main.py
@@ -12,12 +12,7 @@ from uuid import UUID
 # Add project root to path
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
 
-# Initialize Django before importing models
-os.environ.setdefault("DJANGO_SETTINGS_MODULE", "admin.mpr.settings")
-
-import django
-
-django.setup()
+from contextlib import asynccontextmanager
 
 from fastapi import FastAPI, Header, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
@@ -32,12 +27,22 @@ from core.api.graphql import schema as graphql_schema
 
 CALLBACK_API_KEY = os.environ.get("CALLBACK_API_KEY", "")
 
+
+@asynccontextmanager
+async def lifespan(app):
+    # Create DB tables on startup (create_all is a no-op for tables that already exist)
+    from core.db.connection import create_tables
+    create_tables()
+    yield
+
+
 app = FastAPI(
     title="MPR API",
     description="Media Processor — GraphQL API",
     version="0.1.0",
     docs_url="/docs",
     redoc_url="/redoc",
+    lifespan=lifespan,
 )
 
 # CORS
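Because table creation now lives in the lifespan hook, it only runs when the app actually starts. A minimal sketch of exercising it (not part of the diff; assumes DATABASE_URL points at a reachable database and that the app serves the /health route probed by the k8s manifests below):

```python
from fastapi.testclient import TestClient

from core.api.main import app

# TestClient must be used as a context manager: __enter__ runs lifespan
# startup (which calls create_tables), __exit__ runs shutdown.
with TestClient(app) as client:
    print(client.get("/health").status_code)
```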
diff --git a/core/db/__init__.py b/core/db/__init__.py
index 91689a9..899a17a 100644
--- a/core/db/__init__.py
+++ b/core/db/__init__.py
@@ -17,3 +17,4 @@ from .presets import (
     get_preset,
     list_presets,
 )
+from .connection import get_session, create_tables
diff --git a/core/db/assets.py b/core/db/assets.py
index ae368a1..8f849df 100644
--- a/core/db/assets.py
+++ b/core/db/assets.py
@@ -1,48 +1,58 @@
-"""Database operations for MediaAsset."""
+"""Database operations for MediaAsset — SQLModel."""
+
+from __future__ import annotations
 
 from typing import Optional
 from uuid import UUID
 
+from sqlmodel import select
 
-def list_assets(status: Optional[str] = None, search: Optional[str] = None):
-    from admin.mpr.media_assets.models import MediaAsset
-
-    qs = MediaAsset.objects.all()
-    if status:
-        qs = qs.filter(status=status)
-    if search:
-        qs = qs.filter(filename__icontains=search)
-    return list(qs)
+from .connection import get_session
+from .models import MediaAsset
 
 
-def get_asset(id: UUID):
-    from admin.mpr.media_assets.models import MediaAsset
+def list_assets(status: Optional[str] = None, search: Optional[str] = None) -> list[MediaAsset]:
+    with get_session() as session:
+        stmt = select(MediaAsset)
+        if status:
+            stmt = stmt.where(MediaAsset.status == status)
+        if search:
+            stmt = stmt.where(MediaAsset.filename.ilike(f"%{search}%"))
+        return list(session.exec(stmt).all())
 
-    return MediaAsset.objects.get(id=id)
+
+def get_asset(id: UUID) -> MediaAsset | None:
+    with get_session() as session:
+        return session.get(MediaAsset, id)
 
 
 def get_asset_filenames() -> set[str]:
-    from admin.mpr.media_assets.models import MediaAsset
-
-    return set(MediaAsset.objects.values_list("filename", flat=True))
+    with get_session() as session:
+        return set(session.exec(select(MediaAsset.filename)).all())
 
 
-def create_asset(*, filename: str, file_path: str, file_size: int):
-    from admin.mpr.media_assets.models import MediaAsset
-
-    return MediaAsset.objects.create(
-        filename=filename,
-        file_path=file_path,
-        file_size=file_size,
-    )
+def create_asset(*, filename: str, file_path: str, file_size: int) -> MediaAsset:
+    asset = MediaAsset(filename=filename, file_path=file_path, file_size=file_size)
+    with get_session() as session:
+        session.add(asset)
+        session.commit()
+        session.refresh(asset)
+    return asset
 
 
-def update_asset(asset, **fields):
-    for key, value in fields.items():
-        setattr(asset, key, value)
-    asset.save(update_fields=list(fields.keys()))
-    return asset
+def update_asset(id: UUID, **fields) -> None:
+    with get_session() as session:
+        asset = session.get(MediaAsset, id)
+        if not asset:
+            return
+        for k, v in fields.items():
+            setattr(asset, k, v)
+        session.commit()
 
 
-def delete_asset(asset):
-    asset.delete()
+def delete_asset(id: UUID) -> None:
+    with get_session() as session:
+        asset = session.get(MediaAsset, id)
+        if asset:
+            session.delete(asset)
+            session.commit()
diff --git a/core/db/connection.py b/core/db/connection.py
new file mode 100644
index 0000000..ef89009
--- /dev/null
+++ b/core/db/connection.py
@@ -0,0 +1,33 @@
+"""
+Database engine and session — SQLModel/SQLAlchemy, no Django.
+
+Reads DATABASE_URL from the environment.
+"""
+
+from __future__ import annotations
+
+import os
+
+from sqlalchemy import create_engine
+from sqlmodel import Session
+
+DATABASE_URL = os.environ.get("DATABASE_URL", "postgresql://mpr:mpr@localhost:5432/mpr")
+
+_engine = None
+
+
+def get_engine():
+    global _engine
+    if _engine is None:
+        _engine = create_engine(DATABASE_URL, pool_size=5, max_overflow=10)
+    return _engine
+
+
+def get_session() -> Session:
+    return Session(get_engine())
+
+
+def create_tables():
+    """Create all SQLModel tables."""
+    from .models import SQLModel  # noqa — registers all models
+    SQLModel.metadata.create_all(get_engine())
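get_session() returns a plain Session rather than a generator dependency; the `with get_session() as session:` blocks used throughout core/db work because SQLAlchemy sessions implement the context-manager protocol and close on exit. A sketch of the lifecycle (not part of the diff; the row values are hypothetical):

```python
from core.db.connection import get_session
from core.db.models import MediaAsset

with get_session() as session:
    asset = MediaAsset(filename="clip.mp4", file_path="/data/clip.mp4")  # hypothetical
    session.add(asset)
    session.commit()        # commit expires loaded attributes by default...
    session.refresh(asset)  # ...so refresh re-loads them before the session closes
# The session is closed here; `asset` is detached, but its already-loaded
# column values stay readable, which is why create_asset() can return it.
print(asset.id, asset.filename)
```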
diff --git a/core/db/detect.py b/core/db/detect.py
index 93872a1..3681b2d 100644
--- a/core/db/detect.py
+++ b/core/db/detect.py
@@ -1,175 +1,222 @@
-"""Database operations for DetectJob and StageCheckpoint."""
+"""Database operations for detection pipeline — SQLModel."""
+
+from __future__ import annotations
 
 from typing import Optional
 from uuid import UUID
 
+from sqlmodel import select
+
+from .connection import get_session
+from .models import (
+    DetectJob, StageCheckpoint, KnownBrand, SourceBrandSighting,
+)
+
 
 # ---------------------------------------------------------------------------
 # DetectJob
 # ---------------------------------------------------------------------------
 
-def create_detect_job(**fields):
-    from admin.mpr.media_assets.models import DetectJob
-    return DetectJob.objects.create(**fields)
+def create_detect_job(**fields) -> DetectJob:
+    job = DetectJob(**fields)
+    with get_session() as session:
+        session.add(job)
+        session.commit()
+        session.refresh(job)
+    return job
 
 
-def get_detect_job(id: UUID):
-    from admin.mpr.media_assets.models import DetectJob
-    return DetectJob.objects.get(id=id)
+def get_detect_job(id: UUID) -> DetectJob | None:
+    with get_session() as session:
+        return session.get(DetectJob, id)
 
 
-def update_detect_job(job_id: UUID, **fields):
-    from admin.mpr.media_assets.models import DetectJob
-    DetectJob.objects.filter(id=job_id).update(**fields)
+def update_detect_job(job_id: UUID, **fields) -> None:
+    with get_session() as session:
+        job = session.get(DetectJob, job_id)
+        if not job:
+            return
+        for k, v in fields.items():
+            setattr(job, k, v)
+        session.commit()
 
 
 def list_detect_jobs(
     parent_job_id: Optional[UUID] = None,
     status: Optional[str] = None,
-):
-    from admin.mpr.media_assets.models import DetectJob
-
-    qs = DetectJob.objects.all()
-    if parent_job_id:
-        qs = qs.filter(parent_job_id=parent_job_id)
-    if status:
-        qs = qs.filter(status=status)
-    return list(qs)
+) -> list[DetectJob]:
+    with get_session() as session:
+        stmt = select(DetectJob)
+        if parent_job_id:
+            stmt = stmt.where(DetectJob.parent_job_id == parent_job_id)
+        if status:
+            stmt = stmt.where(DetectJob.status == status)
+        return list(session.exec(stmt).all())
 
 
 # ---------------------------------------------------------------------------
 # StageCheckpoint
 # ---------------------------------------------------------------------------
 
-def save_stage_checkpoint(**fields):
-    from admin.mpr.media_assets.models import StageCheckpoint
-    return StageCheckpoint.objects.create(**fields)
+def save_stage_checkpoint(**fields) -> StageCheckpoint:
+    with get_session() as session:
+        # Upsert: replace if same job_id + stage
+        job_id = fields.get("job_id")
+        stage = fields.get("stage")
+        if job_id and stage:
+            stmt = select(StageCheckpoint).where(
+                StageCheckpoint.job_id == job_id,
+                StageCheckpoint.stage == stage,
+            )
+            existing = session.exec(stmt).first()
+            if existing:
+                for k, v in fields.items():
+                    setattr(existing, k, v)
+                session.commit()
+                session.refresh(existing)
+                return existing
+
+        checkpoint = StageCheckpoint(**fields)
+        session.add(checkpoint)
+        session.commit()
+        session.refresh(checkpoint)
+        return checkpoint
 
 
-def get_stage_checkpoint(job_id: UUID, stage: str):
-    from admin.mpr.media_assets.models import StageCheckpoint
-    return StageCheckpoint.objects.get(job_id=job_id, stage=stage)
+def get_stage_checkpoint(job_id: UUID, stage: str) -> StageCheckpoint | None:
+    with get_session() as session:
+        stmt = select(StageCheckpoint).where(
+            StageCheckpoint.job_id == job_id,
+            StageCheckpoint.stage == stage,
+        )
+        return session.exec(stmt).first()
 
 
 def list_stage_checkpoints(job_id: UUID) -> list[str]:
-    from admin.mpr.media_assets.models import StageCheckpoint
-
-    stages = (
-        StageCheckpoint.objects
-        .filter(job_id=job_id)
-        .order_by("stage_index")
-        .values_list("stage", flat=True)
-    )
-    return list(stages)
+    with get_session() as session:
+        stmt = (
+            select(StageCheckpoint.stage)
+            .where(StageCheckpoint.job_id == job_id)
+            .order_by(StageCheckpoint.stage_index)
+        )
+        return list(session.exec(stmt).all())
 
 
-def delete_stage_checkpoints(job_id: UUID):
-    from admin.mpr.media_assets.models import StageCheckpoint
-    StageCheckpoint.objects.filter(id=job_id).delete()
+def delete_stage_checkpoints(job_id: UUID) -> None:
+    with get_session() as session:
+        stmt = select(StageCheckpoint).where(StageCheckpoint.job_id == job_id)
+        for cp in session.exec(stmt).all():
+            session.delete(cp)
+        session.commit()
 
 
 # ---------------------------------------------------------------------------
 # KnownBrand
 # ---------------------------------------------------------------------------
 
-def get_or_create_brand(canonical_name: str, aliases: list[str] | None = None,
-                        source: str = "ocr") -> tuple:
-    """Get existing brand or create new one. Returns (brand, created)."""
-    from admin.mpr.media_assets.models import KnownBrand
-    import uuid
-
+def get_or_create_brand(canonical_name: str, aliases: Optional[list[str]] = None,
+                        source: str = "ocr") -> tuple[KnownBrand, bool]:
+    """Get an existing brand or create a new one. Returns (brand, created)."""
     normalized = canonical_name.strip()
-    brand = KnownBrand.objects.filter(canonical_name__iexact=normalized).first()
-    if brand:
-        return brand, False
+    with get_session() as session:
+        stmt = select(KnownBrand).where(KnownBrand.canonical_name.ilike(normalized))
+        brand = session.exec(stmt).first()
+        if brand:
+            return brand, False
 
-    # Check aliases of existing brands
-    for existing in KnownBrand.objects.all():
-        existing_aliases = [a.lower() for a in (existing.aliases or [])]
-        if normalized.lower() in existing_aliases:
-            return existing, False
-
-    brand = KnownBrand.objects.create(
-        id=uuid.uuid4(),
-        canonical_name=normalized,
-        aliases=aliases or [],
-        first_source=source,
-    )
-    return brand, True
+        # Check aliases of existing brands before creating a near-duplicate
+        for existing in session.exec(select(KnownBrand)).all():
+            if normalized.lower() in [a.lower() for a in (existing.aliases or [])]:
+                return existing, False
+
+        brand = KnownBrand(
+            canonical_name=normalized,
+            aliases=aliases or [],
+            first_source=source,
+        )
+        session.add(brand)
+        session.commit()
+        session.refresh(brand)
+        return brand, True
 
 
-def find_brand_by_text(text: str) -> Optional[object]:
-    """Find a known brand by canonical name or alias (case-insensitive)."""
-    from admin.mpr.media_assets.models import KnownBrand
-
+def find_brand_by_text(text: str) -> KnownBrand | None:
+    """Find a known brand by canonical name or alias (case-insensitive)."""
     normalized = text.strip().lower()
-
-    # Exact canonical match
-    brand = KnownBrand.objects.filter(canonical_name__iexact=normalized).first()
-    if brand:
-        return brand
-
-    # Search aliases (jsonb contains)
-    for brand in KnownBrand.objects.all():
-        brand_aliases = [a.lower() for a in (brand.aliases or [])]
-        if normalized in brand_aliases:
+    with get_session() as session:
+        stmt = select(KnownBrand).where(KnownBrand.canonical_name.ilike(normalized))
+        brand = session.exec(stmt).first()
+        if brand:
             return brand
-    return None
 
+        # Alias search — check if normalized is in any brand's aliases
+        all_brands = session.exec(select(KnownBrand)).all()
+        for b in all_brands:
+            if normalized in [a.lower() for a in (b.aliases or [])]:
+                return b
+        return None
 
-def list_all_brands() -> list:
-    from admin.mpr.media_assets.models import KnownBrand
-    return list(KnownBrand.objects.all().order_by("canonical_name"))
+
+def list_all_brands() -> list[KnownBrand]:
+    with get_session() as session:
+        return list(session.exec(select(KnownBrand).order_by(KnownBrand.canonical_name)).all())
 
 
-def update_brand(brand_id: UUID, **fields):
-    from admin.mpr.media_assets.models import KnownBrand
-    KnownBrand.objects.filter(id=brand_id).update(**fields)
+def update_brand(brand_id: UUID, **fields) -> None:
+    with get_session() as session:
+        brand = session.get(KnownBrand, brand_id)
+        if not brand:
+            return
+        for k, v in fields.items():
+            setattr(brand, k, v)
+        session.commit()
 
 
 # ---------------------------------------------------------------------------
 # SourceBrandSighting
 # ---------------------------------------------------------------------------
 
-def get_source_sightings(source_asset_id: UUID) -> list:
-    """Get all brand sightings for a specific source video."""
-    from admin.mpr.media_assets.models import SourceBrandSighting
-    return list(
-        SourceBrandSighting.objects
-        .filter(source_asset_id=source_asset_id)
-        .order_by("-occurrences")
-    )
+def get_source_sightings(source_asset_id: UUID) -> list[SourceBrandSighting]:
+    """Get all brand sightings for a specific source video."""
+    with get_session() as session:
+        stmt = (
+            select(SourceBrandSighting)
+            .where(SourceBrandSighting.source_asset_id == source_asset_id)
+            .order_by(SourceBrandSighting.occurrences.desc())
+        )
+        return list(session.exec(stmt).all())
 
 
 def record_sighting(source_asset_id: UUID, brand_id: UUID, brand_name: str,
-                    timestamp: float, confidence: float, source: str = "ocr"):
-    """Record or update a brand sighting for a source."""
-    from admin.mpr.media_assets.models import SourceBrandSighting
-    import uuid
+                    timestamp: float, confidence: float, source: str = "ocr") -> SourceBrandSighting:
+    """Record or update a brand sighting for a source."""
+    with get_session() as session:
+        stmt = select(SourceBrandSighting).where(
+            SourceBrandSighting.source_asset_id == source_asset_id,
+            SourceBrandSighting.brand_id == brand_id,
+        )
+        sighting = session.exec(stmt).first()
 
-    sighting = SourceBrandSighting.objects.filter(
-        source_asset_id=source_asset_id,
-        brand_id=brand_id,
-    ).first()
+        if sighting:
+            total_conf = sighting.avg_confidence * sighting.occurrences + confidence
+            sighting.occurrences += 1
+            sighting.last_seen_timestamp = timestamp
+            sighting.avg_confidence = total_conf / sighting.occurrences
+            session.commit()
+            session.refresh(sighting)
+            return sighting
 
-    if sighting:
-        sighting.occurrences += 1
-        sighting.last_seen_timestamp = timestamp
-        total_conf = sighting.avg_confidence * (sighting.occurrences - 1) + confidence
-        sighting.avg_confidence = total_conf / sighting.occurrences
-        sighting.save()
+        sighting = SourceBrandSighting(
+            source_asset_id=source_asset_id,
+            brand_id=brand_id,
+            brand_name=brand_name,
+            first_seen_timestamp=timestamp,
+            last_seen_timestamp=timestamp,
+            occurrences=1,
+            detection_source=source,
+            avg_confidence=confidence,
+        )
+        session.add(sighting)
+        session.commit()
+        session.refresh(sighting)
         return sighting
-
-    sighting = SourceBrandSighting.objects.create(
-        id=uuid.uuid4(),
-        source_asset_id=source_asset_id,
-        brand_id=brand_id,
-        brand_name=brand_name,
-        first_seen_timestamp=timestamp,
-        last_seen_timestamp=timestamp,
-        occurrences=1,
-        detection_source=source,
-        avg_confidence=confidence,
-    )
-    return sighting
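The rewritten record_sighting keeps the same running-mean update as the Django version, just reordered: avg' = (avg * n + c) / (n + 1). A quick arithmetic check (not part of the diff):

```python
# Two sightings at 0.7 and 0.9 have mean 0.80; a third arrives at 0.5.
avg, n = 0.80, 2
c = 0.50
total = avg * n + c   # 2.1 — matches total_conf = avg_confidence * occurrences + confidence
n += 1
avg = total / n       # 0.70 == mean(0.7, 0.9, 0.5)
print(round(avg, 2))
```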
diff --git a/core/db/jobs.py b/core/db/jobs.py
index ac79f6b..cc73abd 100644
--- a/core/db/jobs.py
+++ b/core/db/jobs.py
@@ -1,40 +1,49 @@
-"""Database operations for TranscodeJob."""
+"""Database operations for TranscodeJob — SQLModel."""
+
+from __future__ import annotations
 
 from typing import Optional
 from uuid import UUID
 
+from sqlmodel import select
 
-def list_jobs(status: Optional[str] = None, source_asset_id: Optional[UUID] = None):
-    from admin.mpr.media_assets.models import TranscodeJob
-
-    qs = TranscodeJob.objects.all()
-    if status:
-        qs = qs.filter(status=status)
-    if source_asset_id:
-        qs = qs.filter(source_asset_id=source_asset_id)
-    return list(qs)
+from .connection import get_session
+from .models import TranscodeJob
 
 
-def get_job(id: UUID):
-    from admin.mpr.media_assets.models import TranscodeJob
-
-    return TranscodeJob.objects.get(id=id)
+def list_jobs(status: Optional[str] = None, source_asset_id: Optional[UUID] = None) -> list[TranscodeJob]:
+    with get_session() as session:
+        stmt = select(TranscodeJob)
+        if status:
+            stmt = stmt.where(TranscodeJob.status == status)
+        if source_asset_id:
+            stmt = stmt.where(TranscodeJob.source_asset_id == source_asset_id)
+        return list(session.exec(stmt).all())
 
 
-def create_job(**fields):
-    from admin.mpr.media_assets.models import TranscodeJob
-
-    return TranscodeJob.objects.create(**fields)
+def get_job(id: UUID) -> TranscodeJob | None:
+    with get_session() as session:
+        return session.get(TranscodeJob, id)
 
 
-def update_job(job, **fields):
-    for key, value in fields.items():
-        setattr(job, key, value)
-    job.save(update_fields=list(fields.keys()))
-    return job
+def create_job(**fields) -> TranscodeJob:
+    job = TranscodeJob(**fields)
+    with get_session() as session:
+        session.add(job)
+        session.commit()
+        session.refresh(job)
+    return job
 
 
-def update_job_fields(job_id, **fields):
-    from admin.mpr.media_assets.models import TranscodeJob
+def update_job(id: UUID, **fields) -> None:
+    with get_session() as session:
+        job = session.get(TranscodeJob, id)
+        if not job:
+            return
+        for k, v in fields.items():
+            setattr(job, k, v)
+        session.commit()
 
-    TranscodeJob.objects.filter(id=job_id).update(**fields)
+
+def update_job_fields(job_id: UUID, **fields) -> None:
+    update_job(job_id, **fields)
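update_job now takes an id instead of a live model instance, and update_job_fields survives only as an alias so existing call sites keep working. Hypothetical usage (ids and field values invented for illustration):

```python
from uuid import UUID

from core.db.jobs import create_job, get_job, update_job, update_job_fields

src = UUID("00000000-0000-0000-0000-000000000001")  # hypothetical asset id
job = create_job(source_asset_id=src, output_filename="out.mp4")

update_job(job.id, status="processing", progress=0.25)
update_job_fields(job.id, progress=0.5)  # same code path as update_job
print(get_job(job.id).progress)
```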
diff --git a/core/db/models.py b/core/db/models.py
new file mode 100644
index 0000000..6dc6772
--- /dev/null
+++ b/core/db/models.py
@@ -0,0 +1,233 @@
+"""
+SQLModel Table Models - GENERATED FILE
+
+Do not edit directly. Regenerate using modelgen.
+"""
+
+from datetime import datetime
+from enum import Enum
+from typing import Any, Dict, List, Optional
+from uuid import UUID, uuid4
+
+from sqlmodel import SQLModel, Field, Column
+from sqlalchemy import JSON
+
+
+class AssetStatus(str, Enum):
+    PENDING = "pending"
+    READY = "ready"
+    ERROR = "error"
+
+
+class JobStatus(str, Enum):
+    PENDING = "pending"
+    PROCESSING = "processing"
+    COMPLETED = "completed"
+    FAILED = "failed"
+    CANCELLED = "cancelled"
+
+
+class ChunkJobStatus(str, Enum):
+    PENDING = "pending"
+    CHUNKING = "chunking"
+    PROCESSING = "processing"
+    COLLECTING = "collecting"
+    COMPLETED = "completed"
+    FAILED = "failed"
+    CANCELLED = "cancelled"
+
+
+class DetectJobStatus(str, Enum):
+    PENDING = "pending"
+    RUNNING = "running"
+    PAUSED = "paused"
+    COMPLETED = "completed"
+    FAILED = "failed"
+    CANCELLED = "cancelled"
+
+
+class RunType(str, Enum):
+    INITIAL = "initial"
+    REPLAY = "replay"
+    RETRY = "retry"
+
+
+class BrandSource(str, Enum):
+    OCR = "ocr"
+    VLM = "local_vlm"
+    CLOUD = "cloud_llm"
+    MANUAL = "manual"
+
+
+class SourceType(str, Enum):
+    CHUNK_JOB = "chunk_job"
+    UPLOAD = "upload"
+    DEVICE = "device"
+    STREAM = "stream"
+
+
+class MediaAsset(SQLModel, table=True):
+    """A video/audio file registered in the system."""
+    __tablename__ = "media_assets"
+
+    id: UUID = Field(default_factory=uuid4, primary_key=True)
+    filename: str
+    file_path: str
+    status: AssetStatus = "pending"
+    error_message: Optional[str] = None
+    file_size: Optional[int] = None
+    duration: Optional[float] = None
+    video_codec: Optional[str] = None
+    audio_codec: Optional[str] = None
+    width: Optional[int] = None
+    height: Optional[int] = None
+    framerate: Optional[float] = None
+    bitrate: Optional[int] = None
+    properties: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
+    comments: str = ""
+    tags: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
+    created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
+    updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
+
+
+class TranscodePreset(SQLModel, table=True):
+    """A reusable transcoding configuration (like Handbrake presets)."""
+    __tablename__ = "transcode_presets"
+
+    id: UUID = Field(default_factory=uuid4, primary_key=True)
+    name: str
+    description: str = ""
+    is_builtin: bool = False
+    container: str = "mp4"
+    video_codec: str = "libx264"
+    video_bitrate: Optional[str] = None
+    video_crf: Optional[int] = None
+    video_preset: Optional[str] = None
+    resolution: Optional[str] = None
+    framerate: Optional[float] = None
+    audio_codec: str = "aac"
+    audio_bitrate: Optional[str] = None
+    audio_channels: Optional[int] = None
+    audio_samplerate: Optional[int] = None
+    extra_args: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
+    created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
+    updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
+
+
+class TranscodeJob(SQLModel, table=True):
+    """A transcoding or trimming job in the queue."""
+    __tablename__ = "transcode_jobs"
+
+    id: UUID = Field(default_factory=uuid4, primary_key=True)
+    source_asset_id: UUID = Field(index=True)
+    preset_id: Optional[UUID] = None
+    preset_snapshot: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
+    trim_start: Optional[float] = None
+    trim_end: Optional[float] = None
+    output_filename: str = ""
+    output_path: Optional[str] = None
+    output_asset_id: Optional[UUID] = None
+    status: JobStatus = "pending"
+    progress: float = 0.0
+    current_frame: Optional[int] = None
+    current_time: Optional[float] = None
+    speed: Optional[str] = None
+    error_message: Optional[str] = None
+    celery_task_id: Optional[str] = None
+    execution_arn: Optional[str] = None
+    priority: int = 0
+    created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
+    started_at: Optional[datetime] = None
+    completed_at: Optional[datetime] = None
+
+
+class ChunkJob(SQLModel, table=True):
+    """A chunk pipeline job — splits a media file into chunks and processes them"""
+    __tablename__ = "chunk_jobs"
+
+    id: UUID = Field(default_factory=uuid4, primary_key=True)
+    source_asset_id: UUID = Field(index=True)
+    chunk_duration: float = 10.0
+    num_workers: int = 4
+    max_retries: int = 3
+    processor_type: str = "ffmpeg"
+    status: ChunkJobStatus = "pending"
+    progress: float = 0.0
+    total_chunks: int = 0
+    processed_chunks: int = 0
+    failed_chunks: int = 0
+    retry_count: int = 0
+    error_message: Optional[str] = None
+    throughput_mbps: Optional[float] = None
+    elapsed_seconds: Optional[float] = None
+    celery_task_id: Optional[str] = None
+    priority: int = 0
+    created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
+    started_at: Optional[datetime] = None
+    completed_at: Optional[datetime] = None
+
+
+class DetectJob(SQLModel, table=True):
+    """A detection pipeline job."""
+    __tablename__ = "detect_jobs"
+
+    id: UUID = Field(default_factory=uuid4, primary_key=True)
+    source_asset_id: UUID = Field(index=True)
+    video_path: str
+    profile_name: str = "soccer_broadcast"
+    parent_job_id: Optional[UUID] = Field(default=None, index=True)
+    run_type: RunType = "initial"
+    replay_from_stage: Optional[str] = None
+    config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
+    status: DetectJobStatus = "pending"
+    current_stage: Optional[str] = None
+    progress: float = 0.0
+    error_message: Optional[str] = None
+    total_detections: int = 0
+    brands_found: int = 0
+    cloud_llm_calls: int = 0
+    estimated_cost_usd: float = 0.0
+    celery_task_id: Optional[str] = None
+    priority: int = 0
+    created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
+    started_at: Optional[datetime] = None
+    completed_at: Optional[datetime] = None
+
+
+class StageCheckpoint(SQLModel, table=True):
+    """A checkpoint saved after a pipeline stage completes."""
+    __tablename__ = "stage_checkpoints"
+
+    id: UUID = Field(default_factory=uuid4, primary_key=True)
+    job_id: UUID = Field(index=True)
+    stage: str
+    stage_index: int
+    frames_prefix: str = ""
+    frames_manifest: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
+    frames_meta: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
+    filtered_frame_sequences: List[int] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
+    boxes_by_frame: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
+    text_candidates: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
+    unresolved_candidates: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
+    detections: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
+    stats: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
+    config_snapshot: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
+    config_overrides: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON, nullable=False, server_default='{}'))
+    video_path: str = ""
+    profile_name: str = ""
+    created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
+
+
+class KnownBrand(SQLModel, table=True):
+    """A brand discovered or registered in the system."""
+    __tablename__ = "known_brands"
+
+    id: UUID = Field(default_factory=uuid4, primary_key=True)
+    canonical_name: str = Field(index=True)
+    aliases: List[str] = Field(default_factory=list, sa_column=Column(JSON, nullable=False, server_default='[]'))
+    first_source: BrandSource = "ocr"
+    total_occurrences: int = 0
+    confirmed: bool = False
+    created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
+    updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
+
+
+class SourceBrandSighting(SQLModel, table=True):
+    """A brand seen in a specific source (video/asset)."""
+    __tablename__ = "source_brand_sightings"
+
+    id: UUID = Field(default_factory=uuid4, primary_key=True)
+    source_asset_id: UUID = Field(index=True)
+    brand_id: UUID
+    brand_name: str
+    first_seen_timestamp: float = 0.0
+    last_seen_timestamp: float = 0.0
+    occurrences: int = 0
+    detection_source: BrandSource = "ocr"
+    avg_confidence: float = 0.0
+    created_at: Optional[datetime] = Field(default_factory=datetime.utcnow)
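All Dict/List fields above are plain JSON columns (sa_column=Column(JSON)). One caveat that follows from that choice: SQLAlchemy does not change-track in-place mutation of a plain JSON value, so updates should reassign the attribute (or the columns would need MutableDict/MutableList wrappers). A sketch with a hypothetical asset id:

```python
from uuid import UUID

from core.db.connection import get_session
from core.db.models import MediaAsset

asset_id = UUID("00000000-0000-0000-0000-000000000002")  # hypothetical

with get_session() as session:
    asset = session.get(MediaAsset, asset_id)
    # asset.tags.append("broadcast")         # in-place: NOT detected as dirty
    asset.tags = asset.tags + ["broadcast"]  # reassignment: detected and persisted
    session.commit()
```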
diff --git a/core/db/presets.py b/core/db/presets.py
index 4145c06..72dda4a 100644
--- a/core/db/presets.py
+++ b/core/db/presets.py
@@ -1,15 +1,20 @@
-"""Database operations for TranscodePreset."""
+"""Database operations for TranscodePreset — SQLModel."""
+
+from __future__ import annotations
 
 from uuid import UUID
 
+from sqlmodel import select
 
-def list_presets():
-    from admin.mpr.media_assets.models import TranscodePreset
-
-    return list(TranscodePreset.objects.all())
+from .connection import get_session
+from .models import TranscodePreset
 
 
-def get_preset(id: UUID):
-    from admin.mpr.media_assets.models import TranscodePreset
+def list_presets() -> list[TranscodePreset]:
+    with get_session() as session:
+        return list(session.exec(select(TranscodePreset)).all())
 
-    return TranscodePreset.objects.get(id=id)
+
+def get_preset(id: UUID) -> TranscodePreset | None:
+    with get_session() as session:
+        return session.get(TranscodePreset, id)
diff --git a/core/schema/modelgen.json b/core/schema/modelgen.json
index 6705af6..37658ca 100644
--- a/core/schema/modelgen.json
+++ b/core/schema/modelgen.json
@@ -6,6 +6,11 @@
       "output": "admin/mpr/media_assets/models.py",
       "include": ["dataclasses", "enums"]
     },
+    {
+      "target": "sqlmodel",
+      "output": "core/db/models.py",
+      "include": ["dataclasses", "enums"]
+    },
     {
       "target": "graphene",
       "output": "core/api/schema/graphql.py",
diff --git a/core/schema/models/__init__.py b/core/schema/models/__init__.py
index eae0e4f..8222f9b 100644
--- a/core/schema/models/__init__.py
+++ b/core/schema/models/__init__.py
@@ -37,10 +37,9 @@ from .ui_state import UI_STATE_VIEWS  # noqa: F401 — UI store state types
 from .views import ChunkEvent, ChunkOutputFile, PipelineStats, WorkerEvent
 from .sources import ChunkInfo, SourceJob, SourceType
 
-# Core domain models - generates Django, Pydantic, TypeScript
+# Core domain models - generates Django, SQLModel, TypeScript
 DATACLASSES = [MediaAsset, TranscodePreset, TranscodeJob, ChunkJob,
-               DetectJob, StageCheckpoint, KnownBrand, SourceBrandSighting,
-               SourceJob, ChunkInfo]
+               DetectJob, StageCheckpoint, KnownBrand, SourceBrandSighting]
 
 # API request/response models - generates TypeScript only (no Django)
 # WorkerStatus from grpc.py is reused here
@@ -51,6 +50,8 @@ API_MODELS = [
     ScanResult,
     DeleteResult,
     WorkerStatus,
+    SourceJob,
+    ChunkInfo,
 ]
 
 # Status enums - included in generated code
diff --git a/ctrl/Tiltfile b/ctrl/Tiltfile
index 8feffa2..9562761 100644
--- a/ctrl/Tiltfile
+++ b/ctrl/Tiltfile
@@ -37,7 +37,8 @@ docker_build(
 
 k8s_resource('redis')
 k8s_resource('minio', port_forwards=['9000:9000', '9001:9001'])
-k8s_resource('fastapi', resource_deps=['redis', 'minio'])
+k8s_resource('postgres')
+k8s_resource('fastapi', resource_deps=['redis', 'minio', 'postgres'])
 k8s_resource('detection-ui')
 k8s_resource('gateway', resource_deps=['fastapi', 'detection-ui'],
              port_forwards=['8080:8080'])
@@ -45,6 +46,6 @@ k8s_resource('gateway', resource_deps=['fastapi', 'detection-ui'],
 # Group uncategorized resources (configmaps, namespace) under infra
 k8s_resource(
   objects=['mpr:namespace', 'mpr-config:configmap', 'minio-config:configmap',
-           'envoy-gateway-config:configmap'],
+           'postgres-config:configmap', 'envoy-gateway-config:configmap'],
   new_name='infra',
 )
diff --git a/ctrl/k8s/base/fastapi.yaml b/ctrl/k8s/base/fastapi.yaml
index f7d6fe7..e2104fe 100644
--- a/ctrl/k8s/base/fastapi.yaml
+++ b/ctrl/k8s/base/fastapi.yaml
@@ -24,6 +24,8 @@ spec:
                 name: mpr-config
             - configMapRef:
                 name: minio-config
+            - configMapRef:
+                name: postgres-config
           readinessProbe:
             httpGet:
               path: /health
diff --git a/ctrl/k8s/base/kustomization.yaml b/ctrl/k8s/base/kustomization.yaml
index 982abb3..3972248 100644
--- a/ctrl/k8s/base/kustomization.yaml
+++ b/ctrl/k8s/base/kustomization.yaml
@@ -8,6 +8,7 @@ resources:
   - configmap.yaml
   - redis.yaml
   - minio.yaml
+  - postgres.yaml
   - fastapi.yaml
   - detection-ui.yaml
   - gateway.yaml
diff --git a/ctrl/k8s/base/postgres.yaml b/ctrl/k8s/base/postgres.yaml
new file mode 100644
index 0000000..1f67078
--- /dev/null
+++ b/ctrl/k8s/base/postgres.yaml
@@ -0,0 +1,63 @@
+apiVersion: v1
+kind: ConfigMap
+metadata:
+  name: postgres-config
+  namespace: mpr
+data:
+  POSTGRES_DB: mpr
+  POSTGRES_USER: mpr
+  POSTGRES_PASSWORD: mpr
+  DATABASE_URL: postgresql://mpr:mpr@postgres:5432/mpr
+---
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+  name: postgres
+  namespace: mpr
+spec:
+  replicas: 1
+  selector:
+    matchLabels:
+      app: postgres
+  template:
+    metadata:
+      labels:
+        app: postgres
+    spec:
+      containers:
+        - name: postgres
+          image: postgres:16-alpine
+          ports:
+            - containerPort: 5432
+          envFrom:
+            - configMapRef:
+                name: postgres-config
+          readinessProbe:
+            exec:
+              command: ["pg_isready", "-U", "mpr"]
+            initialDelaySeconds: 5
+            periodSeconds: 10
+          resources:
+            requests:
+              memory: 128Mi
+              cpu: 100m
+            limits:
+              memory: 512Mi
+          volumeMounts:
+            - name: data
+              mountPath: /var/lib/postgresql/data
+      volumes:
+        - name: data
+          emptyDir: {}
+---
+apiVersion: v1
+kind: Service
+metadata:
+  name: postgres
+  namespace: mpr
+spec:
+  selector:
+    app: postgres
+  ports:
+    - port: 5432
+      targetPort: 5432
diff --git a/detect/checkpoint/frames.py b/detect/checkpoint/frames.py
index 64d5084..4f3a698 100644
--- a/detect/checkpoint/frames.py
+++ b/detect/checkpoint/frames.py
@@ -13,7 +13,7 @@ from detect.models import Frame
 
 logger = logging.getLogger(__name__)
 
-BUCKET = os.environ.get("S3_BUCKET_OUT", "out")
+BUCKET = os.environ.get("S3_BUCKET", "mpr")
 CHECKPOINT_PREFIX = "checkpoints"
diff --git a/detect/checkpoint/storage.py b/detect/checkpoint/storage.py
index 155cede..d79c083 100644
--- a/detect/checkpoint/storage.py
+++ b/detect/checkpoint/storage.py
@@ -24,13 +24,14 @@ logger = logging.getLogger(__name__)
 
 def _has_db() -> bool:
-    """Check if the DB layer is available (Django + models generated by modelgen)."""
+    """Check if Postgres is reachable."""
     try:
-        from core.db.detect import get_stage_checkpoint as _
-        # Quick check that the model exists (modelgen may not have run yet)
-        from admin.mpr.media_assets.models import StageCheckpoint as _
+        from core.db.connection import get_session
+        from sqlmodel import text
+        with get_session() as session:
+            session.exec(text("SELECT 1"))
         return True
-    except (ImportError, Exception):
+    except Exception:
         return False
@@ -69,17 +70,13 @@ def save_checkpoint(
 
 def _save_to_db(job_id: str, stage: str, stage_index: int, data: dict) -> str:
-    """Save checkpoint structured data to Postgres via core/db."""
-    import uuid
+    """Save checkpoint structured data to Postgres."""
     from core.db.detect import save_stage_checkpoint
 
-    job_uuid = uuid.UUID(job_id) if isinstance(job_id, str) else job_id
-    checkpoint_id = uuid.uuid4()
     frames_prefix = f"{CHECKPOINT_PREFIX}/{job_id}/frames/"
 
     checkpoint = save_stage_checkpoint(
-        id=checkpoint_id,
-        job_id=job_uuid,
+        job_id=job_id,
         stage=stage,
         stage_index=stage_index,
         frames_prefix=frames_prefix,
diff --git a/modelgen/generator/__init__.py b/modelgen/generator/__init__.py
index be49ab4..7e1b55c 100644
--- a/modelgen/generator/__init__.py
+++ b/modelgen/generator/__init__.py
@@ -17,6 +17,7 @@ from .django import DjangoGenerator
 from .prisma import PrismaGenerator
 from .protobuf import ProtobufGenerator
 from .pydantic import PydanticGenerator
+from .sqlmodel import SQLModelGenerator
 from .strawberry import StrawberryGenerator
 from .typescript import TypeScriptGenerator
 
@@ -24,6 +25,7 @@ from .typescript import TypeScriptGenerator
 GENERATORS: Dict[str, Type[BaseGenerator]] = {
     "pydantic": PydanticGenerator,
     "django": DjangoGenerator,
+    "sqlmodel": SQLModelGenerator,
     "typescript": TypeScriptGenerator,
     "ts": TypeScriptGenerator,  # Alias
     "protobuf": ProtobufGenerator,
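With the registry entry in place the new target resolves like any other; how a generator is instantiated and invoked is defined by BaseGenerator and is not shown in this diff, so the sketch stops at lookup:

```python
from modelgen.generator import GENERATORS

gen_cls = GENERATORS["sqlmodel"]  # same key as the "target" in modelgen.json
print(gen_cls.__name__)           # -> SQLModelGenerator
```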
diff --git a/modelgen/generator/sqlmodel.py b/modelgen/generator/sqlmodel.py
new file mode 100644
index 0000000..4498ed7
--- /dev/null
+++ b/modelgen/generator/sqlmodel.py
@@ -0,0 +1,186 @@
+"""
+SQLModel Generator
+
+Generates SQLModel table classes from model definitions.
+Extends the Pydantic generator — SQLModel classes *are* Pydantic models
+with table=True and SQLAlchemy column config for JSON fields.
+"""
+
+import dataclasses as dc
+import re
+from enum import Enum
+from typing import Any, List, get_type_hints
+
+from ..helpers import get_origin_name, get_type_name, unwrap_optional
+from .pydantic import PydanticGenerator
+
+
+# ---------------------------------------------------------------------------
+# Field resolvers — each returns a Field() string or None to fall through
+# ---------------------------------------------------------------------------
+
+def _resolve_special(name, _base, _origin, _optional, _default):
+    """id, created_at, updated_at get fixed Field() definitions."""
+    specials = {
+        "id": "Field(default_factory=uuid4, primary_key=True)",
+        "created_at": "Field(default_factory=datetime.utcnow)",
+        "updated_at": "Field(default_factory=datetime.utcnow)",
+    }
+    return specials.get(name)
+
+
+def _resolve_json(_name, _base, origin, _optional, _default):
+    """Dict and List fields → sa_column=Column(JSON)."""
+    mapping = {
+        "dict": ("dict", "{}"),
+        "list": ("list", "[]"),
+    }
+    entry = mapping.get(origin)
+    if not entry:
+        return None
+    factory, server_default = entry
+    return (
+        f"Field(default_factory={factory}, "
+        f"sa_column=Column(JSON, nullable=False, server_default='{server_default}'))"
+    )
+
+
+def _resolve_indexed(name, _base, _origin, optional, _default):
+    """Known indexed fields."""
+    indexed = {"source_asset_id", "parent_job_id", "job_id", "canonical_name"}
+    if name not in indexed:
+        return None
+    if optional:
+        return "Field(default=None, index=True)"
+    return "Field(index=True)"
+
+
+def _resolve_optional(_name, _base, _origin, optional, _default):
+    """Optional fields default to None."""
+    if optional:
+        return "None"
+    return None
+
+
+def _resolve_default(_name, _base, _origin, _optional, default):
+    """Fields with explicit defaults. Enum before str (str enums are both)."""
+    if default is dc.MISSING or default is None:
+        return None
+    if isinstance(default, Enum):
+        return f'"{default.value}"'
+    if isinstance(default, bool):
+        return str(default)
+    if isinstance(default, (int, float)):
+        return str(default)
+    if isinstance(default, str):
+        return f'"{default}"'
+    return None
+
+
+# Resolver chain — first non-None result wins
+_FIELD_RESOLVERS = [
+    _resolve_special,
+    _resolve_json,
+    _resolve_indexed,
+    _resolve_optional,
+    _resolve_default,
+]
+
+
+def _resolve_field(name, type_hint, default):
+    """Run the resolver chain for a field. Returns ' = ...' string."""
+    base, is_optional = unwrap_optional(type_hint)
+    origin = get_origin_name(base)
+
+    for resolver in _FIELD_RESOLVERS:
+        result = resolver(name, base, origin, is_optional, default)
+        if result is not None:
+            return f" = {result}"
+    return ""
+
+
+def _to_snake_plural(name):
+    """CamelCase → snake_case_plural for table names."""
+    s = re.sub(r"(?<=[a-z])(?=[A-Z])", "_", name).lower()
+    if s.endswith("y") and not s.endswith("ey"):
+        return s[:-1] + "ies"
+    if s.endswith("s"):
+        return s + "es"
+    return s + "s"
+
+
+_HEADER = [
+    '"""',
+    "SQLModel Table Models - GENERATED FILE",
+    "",
+    "Do not edit directly. Regenerate using modelgen.",
+    '"""',
+    "",
+    "from datetime import datetime",
+    "from enum import Enum",
+    "from typing import Any, Dict, List, Optional",
+    "from uuid import UUID, uuid4",
+    "",
+    "from sqlmodel import SQLModel, Field, Column",
+    "from sqlalchemy import JSON",
+    "",
+]
+
+
+class SQLModelGenerator(PydanticGenerator):
+    """Generates SQLModel table classes."""
+
+    def _generate_header(self) -> List[str]:
+        return list(_HEADER)
+
+    def _generate_model_from_dataclass(self, cls: type) -> List[str]:
+        return _build_table(
+            cls.__name__,
+            cls.__doc__ or cls.__name__,
+            get_type_hints(cls),
+            {f.name: f for f in dc.fields(cls)},
+            self._resolve_type,
+        )
+
+    def _generate_model_from_definition(self, model_def) -> List[str]:
+        hints = {f.name: f.type_hint for f in model_def.fields}
+        defaults = {f.name: f.default for f in model_def.fields}
+
+        class FakeField:
+            def __init__(self, default):
+                self.default = default
+
+        fields = {name: FakeField(defaults.get(name, dc.MISSING)) for name in hints}
+        return _build_table(
+            model_def.name,
+            model_def.docstring or model_def.name,
+            hints,
+            fields,
+            self._resolve_type,
+        )
+
+
+def _build_table(name, docstring, hints, fields, resolve_type_fn):
+    """Build a SQLModel table class from field data."""
+    table_name = _to_snake_plural(name)
+    lines = [
+        f"class {name}(SQLModel, table=True):",
+        f'    """{docstring.strip().split(chr(10))[0]}"""',
+        f'    __tablename__ = "{table_name}"',
+        "",
+    ]
+
+    for field_name, type_hint in hints.items():
+        if field_name.startswith("_"):
+            continue
+
+        field = fields.get(field_name)
+        default_val = dc.MISSING
+        if field and field.default is not dc.MISSING:
+            default_val = field.default
+
+        py_type = resolve_type_fn(type_hint, False)
+        field_extra = _resolve_field(field_name, type_hint, default_val)
+        lines.append(f"    {field_name}: {py_type}{field_extra}")
+
+    return lines
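_to_snake_plural is self-contained, so its output for this patch's models can be checked directly; note that the y→ies branch never fires for the current table names (the last two names below are hypothetical, chosen to exercise the other branches):

```python
from modelgen.generator.sqlmodel import _to_snake_plural

assert _to_snake_plural("MediaAsset") == "media_assets"
assert _to_snake_plural("StageCheckpoint") == "stage_checkpoints"
assert _to_snake_plural("JobStatus") == "job_statuses"  # trailing "s" -> "es"
assert _to_snake_plural("Category") == "categories"     # "y" (not "ey") -> "ies"
```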
diff --git a/requirements.txt b/requirements.txt
index 4ba097d..bda5b00 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,6 @@
-# Django
+# Django (admin viewer only — no app code depends on this)
 Django>=4.2,<5.0
 django-environ>=0.11.2
-psycopg2-binary>=2.9.9
 
 # FastAPI
 fastapi>=0.109.0
@@ -32,6 +31,10 @@ langfuse>=2.0.0
 # Cloud LLM providers (only needed for cloud escalation stage)
 anthropic>=0.40.0
 
+# Database (SQLModel/SQLAlchemy + psycopg2)
+sqlmodel>=0.0.14
+psycopg2-binary>=2.9.9
+
 # Detection pipeline orchestration
 numpy>=1.24.0
 Pillow>=10.0.0