"""Database operations for detection pipeline — SQLModel.""" from __future__ import annotations from typing import Optional from uuid import UUID from sqlmodel import select from .connection import get_session from .models import ( DetectJob, StageCheckpoint, KnownBrand, SourceBrandSighting, ) # --------------------------------------------------------------------------- # DetectJob # --------------------------------------------------------------------------- def create_detect_job(**fields) -> DetectJob: job = DetectJob(**fields) with get_session() as session: session.add(job) session.commit() session.refresh(job) return job def get_detect_job(id: UUID) -> DetectJob | None: with get_session() as session: return session.get(DetectJob, id) def update_detect_job(job_id: UUID, **fields) -> None: with get_session() as session: job = session.get(DetectJob, job_id) if not job: return for k, v in fields.items(): setattr(job, k, v) session.commit() def list_detect_jobs( parent_job_id: Optional[UUID] = None, status: Optional[str] = None, ) -> list[DetectJob]: with get_session() as session: stmt = select(DetectJob) if parent_job_id: stmt = stmt.where(DetectJob.parent_job_id == parent_job_id) if status: stmt = stmt.where(DetectJob.status == status) return list(session.exec(stmt).all()) # --------------------------------------------------------------------------- # StageCheckpoint # --------------------------------------------------------------------------- def save_stage_checkpoint(**fields) -> StageCheckpoint: with get_session() as session: # Upsert: replace if same job_id + stage job_id = fields.get("job_id") stage = fields.get("stage") if job_id and stage: stmt = select(StageCheckpoint).where( StageCheckpoint.job_id == job_id, StageCheckpoint.stage == stage, ) existing = session.exec(stmt).first() if existing: for k, v in fields.items(): setattr(existing, k, v) session.commit() session.refresh(existing) return existing checkpoint = StageCheckpoint(**fields) session.add(checkpoint) session.commit() session.refresh(checkpoint) return checkpoint def get_stage_checkpoint(job_id: UUID, stage: str) -> StageCheckpoint | None: with get_session() as session: stmt = select(StageCheckpoint).where( StageCheckpoint.job_id == job_id, StageCheckpoint.stage == stage, ) return session.exec(stmt).first() def list_stage_checkpoints(job_id: UUID) -> list[str]: with get_session() as session: stmt = ( select(StageCheckpoint.stage) .where(StageCheckpoint.job_id == job_id) .order_by(StageCheckpoint.stage_index) ) return list(session.exec(stmt).all()) def delete_stage_checkpoints(job_id: UUID) -> None: with get_session() as session: stmt = select(StageCheckpoint).where(StageCheckpoint.job_id == job_id) for cp in session.exec(stmt).all(): session.delete(cp) session.commit() # --------------------------------------------------------------------------- # KnownBrand # --------------------------------------------------------------------------- def get_or_create_brand(canonical_name: str, aliases: Optional[list[str]] = None, source: str = "ocr") -> tuple[KnownBrand, bool]: normalized = canonical_name.strip() with get_session() as session: stmt = select(KnownBrand).where(KnownBrand.canonical_name.ilike(normalized)) brand = session.exec(stmt).first() if brand: return brand, False brand = KnownBrand( canonical_name=normalized, aliases=aliases or [], first_source=source, ) session.add(brand) session.commit() session.refresh(brand) return brand, True def find_brand_by_text(text: str) -> KnownBrand | None: normalized = text.strip().lower() with get_session() as session: stmt = select(KnownBrand).where(KnownBrand.canonical_name.ilike(normalized)) brand = session.exec(stmt).first() if brand: return brand # Alias search — check if normalized is in any brand's aliases all_brands = session.exec(select(KnownBrand)).all() for b in all_brands: if normalized in [a.lower() for a in (b.aliases or [])]: return b return None def list_all_brands() -> list[KnownBrand]: with get_session() as session: return list(session.exec(select(KnownBrand).order_by(KnownBrand.canonical_name)).all()) def update_brand(brand_id: UUID, **fields) -> None: with get_session() as session: brand = session.get(KnownBrand, brand_id) if not brand: return for k, v in fields.items(): setattr(brand, k, v) session.commit() # --------------------------------------------------------------------------- # SourceBrandSighting # --------------------------------------------------------------------------- def get_source_sightings(source_asset_id: UUID) -> list[SourceBrandSighting]: with get_session() as session: stmt = ( select(SourceBrandSighting) .where(SourceBrandSighting.source_asset_id == source_asset_id) .order_by(SourceBrandSighting.occurrences.desc()) ) return list(session.exec(stmt).all()) def record_sighting(source_asset_id: UUID, brand_id: UUID, brand_name: str, timestamp: float, confidence: float, source: str = "ocr") -> SourceBrandSighting: with get_session() as session: stmt = select(SourceBrandSighting).where( SourceBrandSighting.source_asset_id == source_asset_id, SourceBrandSighting.brand_id == brand_id, ) sighting = session.exec(stmt).first() if sighting: total_conf = sighting.avg_confidence * sighting.occurrences + confidence sighting.occurrences += 1 sighting.last_seen_timestamp = timestamp sighting.avg_confidence = total_conf / sighting.occurrences session.commit() session.refresh(sighting) return sighting sighting = SourceBrandSighting( source_asset_id=source_asset_id, brand_id=brand_id, brand_name=brand_name, first_seen_timestamp=timestamp, last_seen_timestamp=timestamp, occurrences=1, detection_source=source, avg_confidence=confidence, ) session.add(sighting) session.commit() session.refresh(sighting) return sighting