240 lines
7.7 KiB
Python
240 lines
7.7 KiB
Python
"""Database operations for detection pipeline — SQLModel."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Optional
|
|
from uuid import UUID
|
|
|
|
from sqlmodel import select
|
|
|
|
from .connection import get_session
|
|
from .models import (
|
|
DetectJob, Timeline, Checkpoint,
|
|
KnownBrand, SourceBrandSighting,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# DetectJob
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def create_detect_job(**fields) -> DetectJob:
    """Insert a new ``DetectJob`` built from ``fields`` and return it.

    The row is committed and the instance is refreshed from the database
    before it is returned.
    """
    new_job = DetectJob(**fields)
    with get_session() as db:
        db.add(new_job)
        db.commit()
        db.refresh(new_job)
        return new_job
|
|
|
|
|
|
def get_detect_job(id: UUID) -> DetectJob | None:
    """Look up a ``DetectJob`` by primary key.

    Returns whatever ``session.get`` yields for ``id`` (``None`` when no row
    has that key).
    """
    with get_session() as db:
        found = db.get(DetectJob, id)
        return found
|
|
|
|
|
|
def update_detect_job(job_id: UUID, **fields) -> None:
    """Set each entry of ``fields`` as an attribute on the job and commit.

    Nothing is written when no job with ``job_id`` is found.
    """
    with get_session() as db:
        target = db.get(DetectJob, job_id)
        if target:
            for attr, value in fields.items():
                setattr(target, attr, value)
            db.commit()
|
|
|
|
|
|
def list_detect_jobs(
    parent_job_id: Optional[UUID] = None,
    status: Optional[str] = None,
) -> list[DetectJob]:
    """Return detect jobs, optionally narrowed by parent job and/or status.

    A filter is applied only when its argument is truthy, so ``None`` (and an
    empty ``status`` string) means "do not filter on this column".
    """
    criteria = []
    if parent_job_id:
        criteria.append(DetectJob.parent_job_id == parent_job_id)
    if status:
        criteria.append(DetectJob.status == status)

    with get_session() as db:
        query = select(DetectJob)
        for criterion in criteria:
            query = query.where(criterion)
        return list(db.exec(query).all())
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Timeline
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def create_timeline(**fields) -> Timeline:
    """Insert a new ``Timeline`` built from ``fields`` and return it.

    The row is committed and the instance is refreshed from the database
    before it is returned.
    """
    new_timeline = Timeline(**fields)
    with get_session() as db:
        db.add(new_timeline)
        db.commit()
        db.refresh(new_timeline)
        return new_timeline
|
|
|
|
|
|
def get_timeline(timeline_id: UUID) -> Timeline | None:
    """Look up a ``Timeline`` by primary key (``None`` when not found)."""
    with get_session() as db:
        found = db.get(Timeline, timeline_id)
        return found
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Checkpoint
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def save_checkpoint(**fields) -> Checkpoint:
    """Insert a new ``Checkpoint`` built from ``fields`` and return it.

    The row is committed and the instance is refreshed from the database
    before it is returned.
    """
    new_checkpoint = Checkpoint(**fields)
    with get_session() as db:
        db.add(new_checkpoint)
        db.commit()
        db.refresh(new_checkpoint)
        return new_checkpoint
|
|
|
|
|
|
def get_checkpoint(checkpoint_id: UUID) -> Checkpoint | None:
    """Look up a ``Checkpoint`` by primary key (``None`` when not found)."""
    with get_session() as db:
        found = db.get(Checkpoint, checkpoint_id)
        return found
|
|
|
|
|
|
def get_latest_checkpoint(
    timeline_id: UUID,
    parent_id: UUID | None = None,
) -> Checkpoint | None:
    """Return the newest checkpoint (by ``created_at``) of a timeline.

    When ``parent_id`` is given, only checkpoints with that parent are
    considered. Returns ``None`` if nothing matches.
    """
    with get_session() as db:
        query = select(Checkpoint).where(Checkpoint.timeline_id == timeline_id)
        if parent_id is not None:
            query = query.where(Checkpoint.parent_id == parent_id)
        newest_first = query.order_by(Checkpoint.created_at.desc())
        return db.exec(newest_first).first()
|
|
|
|
|
|
def list_checkpoints(timeline_id: UUID) -> list[Checkpoint]:
    """Return every checkpoint of a timeline, ordered by ``created_at``."""
    with get_session() as db:
        belongs_to_timeline = Checkpoint.timeline_id == timeline_id
        query = select(Checkpoint).where(belongs_to_timeline)
        query = query.order_by(Checkpoint.created_at)
        rows = db.exec(query).all()
        return list(rows)
|
|
|
|
|
|
def get_root_checkpoint(timeline_id: UUID) -> Checkpoint | None:
    """Return a checkpoint of the timeline that has no parent, or ``None``."""
    with get_session() as db:
        in_timeline = Checkpoint.timeline_id == timeline_id
        # ``== None`` is deliberate: on a mapped column it builds the SQL
        # ``IS NULL`` test rather than performing a Python comparison.
        has_no_parent = Checkpoint.parent_id == None  # noqa: E711
        query = select(Checkpoint).where(in_timeline, has_no_parent)
        return db.exec(query).first()
|
|
|
|
|
|
def list_scenarios() -> list[Checkpoint]:
    """Return checkpoints flagged ``is_scenario``, newest ``created_at`` first."""
    with get_session() as db:
        # ``== True`` builds a SQL expression on the mapped column; it is not
        # a Python truth test, so it must not be simplified away.
        flagged = Checkpoint.is_scenario == True  # noqa: E712
        query = select(Checkpoint).where(flagged)
        query = query.order_by(Checkpoint.created_at.desc())
        return list(db.exec(query).all())
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# KnownBrand
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def get_or_create_brand(canonical_name: str, aliases: Optional[list[str]] = None,
                        source: str = "ocr") -> tuple[KnownBrand, bool]:
    """Fetch the brand whose canonical name matches, creating it if missing.

    The lookup is a case-insensitive match on the whitespace-stripped
    ``canonical_name``.

    Args:
        canonical_name: Brand name; surrounding whitespace is stripped before
            both the lookup and the insert.
        aliases: Aliases stored on a newly created brand (unused when the
            brand already exists). ``None`` is stored as an empty list.
        source: Stored as ``first_source`` on a newly created brand.

    Returns:
        ``(brand, created)`` where ``created`` is ``True`` only when a new
        row was inserted by this call.
    """
    normalized = canonical_name.strip()
    # ILIKE treats "%" and "_" as wildcards, so a name such as "7_Up" or
    # "100% Pure" would otherwise match unrelated brands. Escape them (and
    # the escape character itself) so the match is a literal comparison.
    pattern = (
        normalized.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    )
    with get_session() as session:
        stmt = select(KnownBrand).where(
            KnownBrand.canonical_name.ilike(pattern, escape="\\")
        )
        brand = session.exec(stmt).first()
        if brand:
            return brand, False

        # NOTE(review): the lookup and the insert are separate statements, so
        # two concurrent callers could both insert the same name unless the
        # table enforces uniqueness — confirm against the model definition.
        brand = KnownBrand(
            canonical_name=normalized,
            aliases=aliases or [],
            first_source=source,
        )
        session.add(brand)
        session.commit()
        session.refresh(brand)
        return brand, True
|
|
|
|
|
|
def find_brand_by_text(text: str) -> KnownBrand | None:
    """Find a brand whose canonical name or one of its aliases equals ``text``.

    Matching is case-insensitive and ignores surrounding whitespace in
    ``text``. The canonical name is checked first with a SQL query; if that
    finds nothing, every brand is loaded and its aliases are compared in
    Python.

    Args:
        text: Candidate brand text.

    Returns:
        The first matching brand, or ``None`` when nothing matches or
        ``text`` is blank.
    """
    normalized = text.strip().lower()
    if not normalized:
        # Blank text cannot identify a brand; skip the query and the
        # full-table alias scan below.
        return None

    # ILIKE treats "%" and "_" as wildcards; escape them (and the escape
    # character itself) so the text is matched literally.
    pattern = (
        normalized.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    )
    with get_session() as session:
        stmt = select(KnownBrand).where(
            KnownBrand.canonical_name.ilike(pattern, escape="\\")
        )
        brand = session.exec(stmt).first()
        if brand:
            return brand

        # Alias search — check if normalized is in any brand's aliases.
        for candidate in session.exec(select(KnownBrand)).all():
            if any(alias.lower() == normalized for alias in (candidate.aliases or [])):
                return candidate
        return None
|
|
|
|
|
|
def list_all_brands() -> list[KnownBrand]:
    """Return every known brand ordered by ``canonical_name``."""
    with get_session() as db:
        query = select(KnownBrand).order_by(KnownBrand.canonical_name)
        brands = db.exec(query).all()
        return list(brands)
|
|
|
|
|
|
def update_brand(brand_id: UUID, **fields) -> None:
    """Set each entry of ``fields`` as an attribute on the brand and commit.

    Nothing is written when no brand with ``brand_id`` is found.
    """
    with get_session() as db:
        target = db.get(KnownBrand, brand_id)
        if target:
            for attr, value in fields.items():
                setattr(target, attr, value)
            db.commit()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# SourceBrandSighting
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def get_source_sightings(source_asset_id: UUID) -> list[SourceBrandSighting]:
    """Return the sightings of one source asset, most ``occurrences`` first."""
    with get_session() as db:
        for_asset = SourceBrandSighting.source_asset_id == source_asset_id
        query = select(SourceBrandSighting).where(for_asset)
        query = query.order_by(SourceBrandSighting.occurrences.desc())
        rows = db.exec(query).all()
        return list(rows)
|
|
|
|
|
|
def record_sighting(source_asset_id: UUID, brand_id: UUID, brand_name: str,
                    timestamp: float, confidence: float, source: str = "ocr") -> SourceBrandSighting:
    """Record one detection of a brand within a source asset.

    There is one sighting row per ``(source_asset_id, brand_id)`` pair. If the
    row exists it is updated in place; otherwise it is created.

    Args:
        source_asset_id: Asset the brand was seen in.
        brand_id: Brand that was seen.
        brand_name: Stored on a newly created row only.
        timestamp: Timestamp of this detection.
        confidence: Confidence of this detection; folded into the running
            mean ``avg_confidence``.
        source: Stored as ``detection_source`` on a newly created row only.

    Returns:
        The created or updated sighting, refreshed after the commit.
    """
    with get_session() as session:
        stmt = select(SourceBrandSighting).where(
            SourceBrandSighting.source_asset_id == source_asset_id,
            SourceBrandSighting.brand_id == brand_id,
        )
        sighting = session.exec(stmt).first()

        if sighting is None:
            sighting = SourceBrandSighting(
                source_asset_id=source_asset_id,
                brand_id=brand_id,
                brand_name=brand_name,
                first_seen_timestamp=timestamp,
                last_seen_timestamp=timestamp,
                occurrences=1,
                detection_source=source,
                avg_confidence=confidence,
            )
            session.add(sighting)
        else:
            # Incremental mean: rebuild the previous sum, add the new sample,
            # then divide by the new count.
            total_conf = sighting.avg_confidence * sighting.occurrences + confidence
            sighting.occurrences += 1
            sighting.avg_confidence = total_conf / sighting.occurrences

            # Keep first/last seen as a true min/max so that a detection
            # recorded out of chronological order cannot move
            # ``last_seen_timestamp`` backwards or leave
            # ``first_seen_timestamp`` later than an earlier detection.
            first_seen = sighting.first_seen_timestamp
            if first_seen is None or timestamp < first_seen:
                sighting.first_seen_timestamp = timestamp
            last_seen = sighting.last_seen_timestamp
            if last_seen is None or timestamp > last_seen:
                sighting.last_seen_timestamp = timestamp

        session.commit()
        session.refresh(sighting)
        return sighting
|