176 lines
5.5 KiB
Python
176 lines
5.5 KiB
Python
"""Database operations for DetectJob and StageCheckpoint."""
|
|
|
|
from typing import Optional
|
|
from uuid import UUID
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# DetectJob
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def create_detect_job(**fields):
    """Insert a new DetectJob built from ``fields`` and return the created instance.

    ``fields`` are forwarded unchanged as keyword arguments to
    ``DetectJob.objects.create``.
    """
    from admin.mpr.media_assets.models import DetectJob

    job = DetectJob.objects.create(**fields)
    return job
|
|
|
|
|
|
def get_detect_job(id: UUID):
    """Fetch the single DetectJob whose ``id`` field equals ``id``.

    Errors from Django's ``QuerySet.get`` (``DoesNotExist`` /
    ``MultipleObjectsReturned``) propagate unchanged.
    """
    # The parameter name shadows the ``id`` builtin; it is kept because
    # callers may pass it by keyword.
    from admin.mpr.media_assets.models import DetectJob

    manager = DetectJob.objects
    return manager.get(id=id)
|
|
|
|
|
|
def update_detect_job(job_id: UUID, **fields):
    """Apply ``fields`` as a bulk UPDATE to the DetectJob with id ``job_id``.

    Uses ``QuerySet.update``, so the model's ``save()`` is not invoked, and an
    unknown ``job_id`` simply updates nothing. Returns None.
    """
    from admin.mpr.media_assets.models import DetectJob

    matching = DetectJob.objects.filter(id=job_id)
    matching.update(**fields)
|
|
|
|
|
|
def list_detect_jobs(
    parent_job_id: Optional[UUID] = None,
    status: Optional[str] = None,
):
    """Return DetectJobs as a list, optionally narrowed by parent job and status.

    Each filter is applied only when its argument is truthy, so ``None`` (and
    an empty string for ``status``) means "do not filter on this field".
    With no filters, every DetectJob is returned.
    """
    from admin.mpr.media_assets.models import DetectJob

    # Collect the active filters, then issue a single filter() call; both
    # fields are plain columns, so this matches chaining filter() calls.
    criteria = {}
    if parent_job_id:
        criteria["parent_job_id"] = parent_job_id
    if status:
        criteria["status"] = status
    return list(DetectJob.objects.filter(**criteria))
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# StageCheckpoint
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def save_stage_checkpoint(**fields):
    """Insert a StageCheckpoint built from ``fields`` and return it.

    This always creates a new row; it does not look for an existing
    checkpoint with the same job/stage first.
    """
    # NOTE(review): if (job_id, stage) is not unique at the DB level, saving
    # the same stage twice would make get_stage_checkpoint raise
    # MultipleObjectsReturned — confirm the model's constraints.
    from admin.mpr.media_assets.models import StageCheckpoint

    checkpoint = StageCheckpoint.objects.create(**fields)
    return checkpoint
|
|
|
|
|
|
def get_stage_checkpoint(job_id: UUID, stage: str):
    """Look up the StageCheckpoint recorded for ``job_id`` at ``stage``.

    Errors from Django's ``QuerySet.get`` (``DoesNotExist`` /
    ``MultipleObjectsReturned``) propagate unchanged.
    """
    from admin.mpr.media_assets.models import StageCheckpoint

    lookup = {"job_id": job_id, "stage": stage}
    return StageCheckpoint.objects.get(**lookup)
|
|
|
|
|
|
def list_stage_checkpoints(job_id: UUID) -> list[str]:
    """Return the ``stage`` names checkpointed for ``job_id``, ordered by ``stage_index``."""
    from admin.mpr.media_assets.models import StageCheckpoint

    rows = StageCheckpoint.objects.filter(job_id=job_id).order_by("stage_index")
    return list(rows.values_list("stage", flat=True))
|
|
|
|
|
|
def delete_stage_checkpoints(job_id: UUID):
    """Delete every StageCheckpoint whose ``job_id`` matches; returns None."""
    from admin.mpr.media_assets.models import StageCheckpoint

    doomed = StageCheckpoint.objects.filter(job_id=job_id)
    doomed.delete()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# KnownBrand
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def get_or_create_brand(canonical_name: str, aliases: list[str] | None = None,
                        source: str = "ocr") -> tuple:
    """Return ``(brand, created)`` for ``canonical_name``, creating the brand if unknown.

    Matching is done on the whitespace-stripped name, in this order:
      1. an existing brand whose ``canonical_name`` matches (``iexact``);
      2. an existing brand whose ``aliases`` contain the name, compared
         lower-cased in Python.
    Only when neither matches is a new KnownBrand created. It gets a fresh
    UUID4 id, ``aliases`` (or ``[]``) and ``first_source=source``.
    ``created`` is True only in that last case. The ``aliases`` argument is
    ignored when an existing brand is returned.

    NOTE(review): step 2 iterates over KnownBrand rows in Python, and the
    lookup and create are separate queries with no visible locking, so
    concurrent callers could insert duplicate brands — confirm acceptable.
    """
    from admin.mpr.media_assets.models import KnownBrand
    import uuid

    name = canonical_name.strip()

    by_name = KnownBrand.objects.filter(canonical_name__iexact=name).first()
    if by_name:
        return by_name, False

    # Alias lists are compared case-insensitively against the stripped name.
    wanted = name.lower()
    for candidate in KnownBrand.objects.all():
        lowered_aliases = [alias.lower() for alias in (candidate.aliases or [])]
        if wanted in lowered_aliases:
            return candidate, False

    new_brand = KnownBrand.objects.create(
        id=uuid.uuid4(),
        canonical_name=name,
        aliases=aliases or [],
        first_source=source,
    )
    return new_brand, True
|
|
|
|
|
|
def find_brand_by_text(text: str) -> Optional[object]:
    """Return the KnownBrand matching ``text``, or None when nothing matches.

    ``text`` is stripped and lower-cased. It is compared first against
    ``canonical_name`` with a case-insensitive ``iexact`` query. Failing
    that, it is compared against each brand's ``aliases`` list, lower-cased
    in Python; the first brand (in queryset order) with a matching alias wins.

    The alias search is a Python-side scan over KnownBrand rows, not a
    database-level JSON containment query.
    """
    from admin.mpr.media_assets.models import KnownBrand

    needle = text.strip().lower()

    exact = KnownBrand.objects.filter(canonical_name__iexact=needle).first()
    if exact:
        return exact

    alias_hits = (
        candidate
        for candidate in KnownBrand.objects.all()
        if needle in [alias.lower() for alias in (candidate.aliases or [])]
    )
    return next(alias_hits, None)
|
|
|
|
|
|
def list_all_brands() -> list:
    """Return every KnownBrand as a list, sorted by ``canonical_name`` ascending."""
    from admin.mpr.media_assets.models import KnownBrand

    ordered = KnownBrand.objects.order_by("canonical_name")
    return list(ordered)
|
|
|
|
|
|
def update_brand(brand_id: UUID, **fields):
    """Bulk-update the KnownBrand with id ``brand_id`` using ``fields``.

    Goes through ``QuerySet.update``, so the model's ``save()`` is not called
    and an unknown ``brand_id`` updates nothing. Returns None.
    """
    from admin.mpr.media_assets.models import KnownBrand

    target = KnownBrand.objects.filter(id=brand_id)
    target.update(**fields)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# SourceBrandSighting
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def get_source_sightings(source_asset_id: UUID) -> list:
    """Return all brand sightings recorded for one source asset.

    Results are ordered by ``occurrences`` descending (most frequent first).
    """
    from admin.mpr.media_assets.models import SourceBrandSighting

    sightings = SourceBrandSighting.objects.filter(source_asset_id=source_asset_id)
    return list(sightings.order_by("-occurrences"))
|
|
|
|
|
|
def record_sighting(source_asset_id: UUID, brand_id: UUID, brand_name: str,
                    timestamp: float, confidence: float, source: str = "ocr"):
    """Record one sighting of ``brand_id`` in ``source_asset_id`` and return the row.

    If a SourceBrandSighting already exists for this (source asset, brand)
    pair, it is updated in place:
      * ``occurrences`` is incremented by one;
      * ``avg_confidence`` becomes the running mean including ``confidence``;
      * ``first_seen_timestamp`` / ``last_seen_timestamp`` are widened to
        include ``timestamp`` (min / max). A sighting recorded out of
        chronological order therefore never moves the seen-window backwards.
    ``brand_name`` and ``source`` are only used when a new row is created.

    Otherwise a new row is created with a fresh UUID4 id, ``occurrences=1``,
    both timestamps set to ``timestamp`` and ``avg_confidence=confidence``.

    NOTE(review): the existing-row path is a read-modify-write with no visible
    locking; concurrent calls for the same pair could lose increments.
    """
    from admin.mpr.media_assets.models import SourceBrandSighting
    import uuid

    sighting = SourceBrandSighting.objects.filter(
        source_asset_id=source_asset_id,
        brand_id=brand_id,
    ).first()

    if sighting:
        previous_count = sighting.occurrences
        sighting.occurrences = previous_count + 1
        # Running mean: recover the old total, add the new sample, re-divide.
        total_conf = sighting.avg_confidence * previous_count + confidence
        sighting.avg_confidence = total_conf / sighting.occurrences
        # If sightings can arrive out of order, blindly overwriting
        # last_seen (and never lowering first_seen) would corrupt the window,
        # so widen it instead. A missing (None) bound is simply filled in.
        first_seen = sighting.first_seen_timestamp
        if first_seen is None or timestamp < first_seen:
            sighting.first_seen_timestamp = timestamp
        last_seen = sighting.last_seen_timestamp
        if last_seen is None or timestamp > last_seen:
            sighting.last_seen_timestamp = timestamp
        sighting.save()
        return sighting

    sighting = SourceBrandSighting.objects.create(
        id=uuid.uuid4(),
        source_asset_id=source_asset_id,
        brand_id=brand_id,
        brand_name=brand_name,
        first_seen_timestamp=timestamp,
        last_seen_timestamp=timestamp,
        occurrences=1,
        detection_source=source,
        avg_confidence=confidence,
    )
    return sighting
|