"""Database operations for detection pipeline — SQLModel.""" from __future__ import annotations from typing import Optional from uuid import UUID from sqlmodel import select from .connection import get_session from .models import ( DetectJob, Timeline, Checkpoint, KnownBrand, SourceBrandSighting, ) # --------------------------------------------------------------------------- # DetectJob # --------------------------------------------------------------------------- def create_detect_job(**fields) -> DetectJob: job = DetectJob(**fields) with get_session() as session: session.add(job) session.commit() session.refresh(job) return job def get_detect_job(id: UUID) -> DetectJob | None: with get_session() as session: return session.get(DetectJob, id) def update_detect_job(job_id: UUID, **fields) -> None: with get_session() as session: job = session.get(DetectJob, job_id) if not job: return for k, v in fields.items(): setattr(job, k, v) session.commit() def list_detect_jobs( parent_job_id: Optional[UUID] = None, status: Optional[str] = None, ) -> list[DetectJob]: with get_session() as session: stmt = select(DetectJob) if parent_job_id: stmt = stmt.where(DetectJob.parent_job_id == parent_job_id) if status: stmt = stmt.where(DetectJob.status == status) return list(session.exec(stmt).all()) # --------------------------------------------------------------------------- # Timeline # --------------------------------------------------------------------------- def create_timeline(**fields) -> Timeline: timeline = Timeline(**fields) with get_session() as session: session.add(timeline) session.commit() session.refresh(timeline) return timeline def get_timeline(timeline_id: UUID) -> Timeline | None: with get_session() as session: return session.get(Timeline, timeline_id) # --------------------------------------------------------------------------- # Checkpoint # --------------------------------------------------------------------------- def save_checkpoint(**fields) -> Checkpoint: checkpoint = Checkpoint(**fields) with get_session() as session: session.add(checkpoint) session.commit() session.refresh(checkpoint) return checkpoint def get_checkpoint(checkpoint_id: UUID) -> Checkpoint | None: with get_session() as session: return session.get(Checkpoint, checkpoint_id) def get_latest_checkpoint(timeline_id: UUID, parent_id: UUID | None = None) -> Checkpoint | None: """Get the most recent checkpoint for a timeline, optionally from a specific parent.""" with get_session() as session: stmt = ( select(Checkpoint) .where(Checkpoint.timeline_id == timeline_id) ) if parent_id is not None: stmt = stmt.where(Checkpoint.parent_id == parent_id) stmt = stmt.order_by(Checkpoint.created_at.desc()) return session.exec(stmt).first() def list_checkpoints(timeline_id: UUID) -> list[Checkpoint]: """List all checkpoints for a timeline.""" with get_session() as session: stmt = ( select(Checkpoint) .where(Checkpoint.timeline_id == timeline_id) .order_by(Checkpoint.created_at) ) return list(session.exec(stmt).all()) def get_root_checkpoint(timeline_id: UUID) -> Checkpoint | None: """Get the root checkpoint (no parent) for a timeline.""" with get_session() as session: stmt = select(Checkpoint).where( Checkpoint.timeline_id == timeline_id, Checkpoint.parent_id == None, ) return session.exec(stmt).first() def list_scenarios() -> list[Checkpoint]: """List all checkpoints marked as scenarios.""" with get_session() as session: stmt = ( select(Checkpoint) .where(Checkpoint.is_scenario == True) .order_by(Checkpoint.created_at.desc()) ) return list(session.exec(stmt).all()) # --------------------------------------------------------------------------- # KnownBrand # --------------------------------------------------------------------------- def get_or_create_brand(canonical_name: str, aliases: Optional[list[str]] = None, source: str = "ocr") -> tuple[KnownBrand, bool]: normalized = canonical_name.strip() with get_session() as session: stmt = select(KnownBrand).where(KnownBrand.canonical_name.ilike(normalized)) brand = session.exec(stmt).first() if brand: return brand, False brand = KnownBrand( canonical_name=normalized, aliases=aliases or [], first_source=source, ) session.add(brand) session.commit() session.refresh(brand) return brand, True def find_brand_by_text(text: str) -> KnownBrand | None: normalized = text.strip().lower() with get_session() as session: stmt = select(KnownBrand).where(KnownBrand.canonical_name.ilike(normalized)) brand = session.exec(stmt).first() if brand: return brand # Alias search — check if normalized is in any brand's aliases all_brands = session.exec(select(KnownBrand)).all() for b in all_brands: if normalized in [a.lower() for a in (b.aliases or [])]: return b return None def list_all_brands() -> list[KnownBrand]: with get_session() as session: return list(session.exec(select(KnownBrand).order_by(KnownBrand.canonical_name)).all()) def update_brand(brand_id: UUID, **fields) -> None: with get_session() as session: brand = session.get(KnownBrand, brand_id) if not brand: return for k, v in fields.items(): setattr(brand, k, v) session.commit() # --------------------------------------------------------------------------- # SourceBrandSighting # --------------------------------------------------------------------------- def get_source_sightings(source_asset_id: UUID) -> list[SourceBrandSighting]: with get_session() as session: stmt = ( select(SourceBrandSighting) .where(SourceBrandSighting.source_asset_id == source_asset_id) .order_by(SourceBrandSighting.occurrences.desc()) ) return list(session.exec(stmt).all()) def record_sighting(source_asset_id: UUID, brand_id: UUID, brand_name: str, timestamp: float, confidence: float, source: str = "ocr") -> SourceBrandSighting: with get_session() as session: stmt = select(SourceBrandSighting).where( SourceBrandSighting.source_asset_id == source_asset_id, SourceBrandSighting.brand_id == brand_id, ) sighting = session.exec(stmt).first() if sighting: total_conf = sighting.avg_confidence * sighting.occurrences + confidence sighting.occurrences += 1 sighting.last_seen_timestamp = timestamp sighting.avg_confidence = total_conf / sighting.occurrences session.commit() session.refresh(sighting) return sighting sighting = SourceBrandSighting( source_asset_id=source_asset_id, brand_id=brand_id, brand_name=brand_name, first_seen_timestamp=timestamp, last_seen_timestamp=timestamp, occurrences=1, detection_source=source, avg_confidence=confidence, ) session.add(sighting) session.commit() session.refresh(sighting) return sighting