62 lines
1.8 KiB
Python
62 lines
1.8 KiB
Python
"""Brand queries."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Optional
|
|
from uuid import UUID
|
|
|
|
from sqlmodel import Session, select
|
|
|
|
from .tables import Brand
|
|
|
|
|
|
def get_or_create_brand(session: Session, canonical_name: str,
                        aliases: list[str] | None = None,
                        source: str = "ocr") -> tuple[Brand, bool]:
    """Return the brand matching *canonical_name*, creating it if absent.

    The lookup is a case-insensitive exact match of ``Brand.canonical_name``
    against the whitespace-stripped input. SQL ``LIKE`` metacharacters
    (``%`` and ``_``) in the input are escaped, so e.g. ``"A_B"`` does not
    match an existing ``"AXB"``.

    When no match exists, a new ``Brand`` is added to *session* and flushed.
    This function does not commit; that is left to the caller.

    Args:
        session: Open database session.
        canonical_name: Brand name. Leading/trailing whitespace is stripped
            before lookup and storage.
        aliases: Alternative names stored on a newly created brand. Ignored
            when the brand already exists.
        source: Provenance label stored on a newly created brand. Ignored
            when the brand already exists.

    Returns:
        ``(brand, created)``, where *created* is True only if a new brand
        was added by this call.
    """
    normalized = canonical_name.strip()
    # Escape LIKE metacharacters so ilike() acts as case-insensitive equality.
    # "/" is the escape character, so it must be escaped first.
    pattern = normalized.replace("/", "//").replace("%", "/%").replace("_", "/_")
    existing = session.exec(
        select(Brand).where(Brand.canonical_name.ilike(pattern, escape="/"))
    ).first()
    if existing is not None:
        return existing, False

    brand = Brand(canonical_name=normalized, aliases=aliases or [], source=source)
    session.add(brand)
    # Flush (not commit) so the INSERT is issued inside the caller's transaction.
    session.flush()
    return brand, True
|
|
|
|
|
|
def find_brand_by_text(session: Session, text: str) -> Brand | None:
    """Look up a brand by canonical name or alias, case-insensitively.

    First tries a SQL ``ilike`` match on ``Brand.canonical_name``, with LIKE
    metacharacters (``%``, ``_``) escaped so they match literally. If that
    finds nothing, every brand's ``aliases`` list is scanned in Python for an
    exact match after lower-casing.

    Args:
        session: Open database session.
        text: Candidate brand text. Surrounding whitespace is ignored.

    Returns:
        The first matching ``Brand``, or None if *text* is blank or nothing
        matches.
    """
    normalized = text.strip().lower()
    if not normalized:
        # Blank input would only match empty names/aliases and would force a
        # scan over every brand; treat it as "no match".
        return None

    # "/" is the LIKE escape character, so it must be escaped first.
    pattern = normalized.replace("/", "//").replace("%", "/%").replace("_", "/_")
    brand = session.exec(
        select(Brand).where(Brand.canonical_name.ilike(pattern, escape="/"))
    ).first()
    if brand is not None:
        return brand

    # NOTE(review): this loads every brand to check aliases in Python, so it
    # is O(#brands) per call. Consider a DB-side alias query if the column
    # type of `aliases` supports one.
    for candidate in session.exec(select(Brand)).all():
        if any(alias.lower() == normalized for alias in (candidate.aliases or [])):
            return candidate
    return None
|
|
|
|
|
|
def list_brands(session: Session) -> list[Brand]:
    """Fetch every brand, ordered by ``canonical_name``.

    The result is a new list that the caller may mutate freely.
    """
    ordered_query = select(Brand).order_by(Brand.canonical_name)
    fetched = session.exec(ordered_query).all()
    return list(fetched)
|
|
|
|
|
|
def record_airing(session: Session, brand_id: UUID, timeline_id: UUID,
                  frame_start: int, frame_end: int,
                  confidence: float, source: str = "ocr") -> Brand:
    """Append one airing record to a brand and refresh its airing count.

    The airing is stored as a dict with keys ``timeline_id`` (stringified),
    ``frame_start``, ``frame_end``, ``confidence`` and ``source``.
    ``total_airings`` is set to the new length of ``airings``. This function
    does not flush or commit the session.

    Args:
        session: Open database session.
        brand_id: Primary key of the brand to update.
        timeline_id: Timeline the airing belongs to.
        frame_start: First frame of the airing.
        frame_end: Last frame of the airing; must not be before *frame_start*.
        confidence: Detection confidence recorded as-is.
        source: Provenance label for this airing.

    Returns:
        The updated ``Brand``.

    Raises:
        ValueError: If ``frame_end < frame_start``, or if no brand with
            *brand_id* exists.
    """
    # Validate the cheap invariant before touching the database.
    if frame_end < frame_start:
        raise ValueError(
            f"Invalid frame range: frame_end ({frame_end}) < frame_start ({frame_start})"
        )

    brand = session.get(Brand, brand_id)
    if brand is None:
        raise ValueError(f"Brand not found: {brand_id}")

    airing = {
        "timeline_id": str(timeline_id),
        "frame_start": frame_start,
        "frame_end": frame_end,
        "confidence": confidence,
        "source": source,
    }
    # Copy into a new list and reassign instead of appending in place.
    # Presumably `airings` is a JSON column without in-place mutation
    # tracking, so only attribute reassignment marks it dirty.
    # NOTE(review): confirm against tables.Brand.
    airings = list(brand.airings or [])
    airings.append(airing)
    brand.airings = airings
    brand.total_airings = len(airings)
    return brand
|