Files
mediaproc/core/db/detect.py
2026-03-26 22:22:35 -03:00

225 lines
7.3 KiB
Python

"""Database operations for detection pipeline — SQLModel."""
from __future__ import annotations
from typing import Optional
from uuid import UUID
from sqlmodel import select
from .connection import get_session
from .models import (
DetectJob, StageCheckpoint, KnownBrand, SourceBrandSighting,
)
# ---------------------------------------------------------------------------
# DetectJob
# ---------------------------------------------------------------------------
def create_detect_job(**fields) -> DetectJob:
    """Insert a new DetectJob built from *fields* and return it.

    The instance is refreshed after the commit so that column values populated
    on insert (generated by the model or the database) are loaded onto it.
    """
    new_job = DetectJob(**fields)
    with get_session() as session:
        session.add(new_job)
        session.commit()
        session.refresh(new_job)
    return new_job
def get_detect_job(id: UUID) -> DetectJob | None:
    """Fetch a DetectJob by primary key.

    Returns ``None`` when no row has that id.
    """
    with get_session() as session:
        found = session.get(DetectJob, id)
    return found
def update_detect_job(job_id: UUID, **fields) -> None:
    """Assign each ``name=value`` in *fields* onto DetectJob *job_id* and commit.

    A missing job is silently ignored (nothing is written).
    """
    with get_session() as session:
        target = session.get(DetectJob, job_id)
        if target:
            for attr_name, value in fields.items():
                setattr(target, attr_name, value)
            session.commit()
def list_detect_jobs(
    parent_job_id: Optional[UUID] = None,
    status: Optional[str] = None,
) -> list[DetectJob]:
    """Return DetectJobs, optionally narrowed by parent job and/or status.

    Each filter is applied only when its argument is truthy; falsy values
    (``None``, ``""``) mean "do not filter on this column". When both filters
    are given they are combined with AND.
    """
    conditions = []
    if parent_job_id:
        conditions.append(DetectJob.parent_job_id == parent_job_id)
    if status:
        conditions.append(DetectJob.status == status)

    query = select(DetectJob)
    if conditions:
        query = query.where(*conditions)

    with get_session() as session:
        return list(session.exec(query).all())
# ---------------------------------------------------------------------------
# StageCheckpoint
# ---------------------------------------------------------------------------
def save_stage_checkpoint(**fields) -> StageCheckpoint:
    """Upsert a StageCheckpoint keyed on ``(job_id, stage)``.

    When both ``job_id`` and ``stage`` are present (and truthy) in *fields*
    and a checkpoint already exists for that pair, its attributes are
    overwritten with *fields*. Otherwise a new checkpoint is inserted. The
    persisted instance is refreshed and returned.
    """
    job_id = fields.get("job_id")
    stage = fields.get("stage")
    with get_session() as session:
        row = None
        if job_id and stage:
            lookup = select(StageCheckpoint).where(
                StageCheckpoint.job_id == job_id,
                StageCheckpoint.stage == stage,
            )
            row = session.exec(lookup).first()

        if row:
            for attr_name, value in fields.items():
                setattr(row, attr_name, value)
        else:
            # NOTE(review): select-then-insert is not atomic; unless the table
            # has a unique constraint on (job_id, stage), concurrent saves could
            # both insert — confirm against the model/migration.
            row = StageCheckpoint(**fields)
            session.add(row)

        session.commit()
        session.refresh(row)
        return row
def get_stage_checkpoint(job_id: UUID, stage: str) -> StageCheckpoint | None:
    """Return the checkpoint stored for (*job_id*, *stage*), or ``None`` if absent."""
    query = select(StageCheckpoint).where(
        StageCheckpoint.job_id == job_id,
        StageCheckpoint.stage == stage,
    )
    with get_session() as session:
        return session.exec(query).first()
def list_stage_checkpoints(job_id: UUID) -> list[str]:
    """Return the stage names checkpointed for *job_id*, ordered by ``stage_index``."""
    query = select(StageCheckpoint.stage).where(StageCheckpoint.job_id == job_id)
    query = query.order_by(StageCheckpoint.stage_index)
    with get_session() as session:
        stage_names = session.exec(query).all()
    return list(stage_names)
def list_scenarios() -> list[StageCheckpoint]:
    """Return every checkpoint flagged ``is_scenario``, newest ``created_at`` first."""
    newest_first = StageCheckpoint.created_at.desc()
    # `== True` is intentional: it builds a SQL comparison expression, which
    # `is True` would not.
    flagged = StageCheckpoint.is_scenario == True  # noqa: E712
    query = select(StageCheckpoint).where(flagged).order_by(newest_first)
    with get_session() as session:
        return list(session.exec(query).all())
def delete_stage_checkpoints(job_id: UUID) -> None:
    """Delete every StageCheckpoint belonging to *job_id*, committing once at the end."""
    query = select(StageCheckpoint).where(StageCheckpoint.job_id == job_id)
    with get_session() as session:
        doomed = session.exec(query).all()
        for checkpoint in doomed:
            session.delete(checkpoint)
        session.commit()
# ---------------------------------------------------------------------------
# KnownBrand
# ---------------------------------------------------------------------------
def get_or_create_brand(canonical_name: str, aliases: Optional[list[str]] = None,
                        source: str = "ocr") -> tuple[KnownBrand, bool]:
    """Return the brand whose canonical name matches case-insensitively, creating it if needed.

    Args:
        canonical_name: Brand name. Surrounding whitespace is stripped before
            both matching and storing.
        aliases: Alternative spellings stored on a newly created brand. Ignored
            when a matching brand already exists.
        source: Stored as ``first_source`` on a newly created brand.

    Returns:
        ``(brand, created)``, where *created* is True only when a new row was
        inserted.
    """
    normalized = canonical_name.strip()
    # ILIKE treats ``%`` and ``_`` as wildcards. Escape them (and the escape
    # character itself) so the name is matched literally; otherwise "A_B"
    # would match an existing "AXB" and no new brand would be created.
    pattern = normalized.replace("/", "//").replace("%", "/%").replace("_", "/_")
    with get_session() as session:
        stmt = select(KnownBrand).where(
            KnownBrand.canonical_name.ilike(pattern, escape="/")
        )
        brand = session.exec(stmt).first()
        if brand:
            return brand, False
        brand = KnownBrand(
            canonical_name=normalized,
            # Copy so later mutation of the caller's list cannot leak into the model.
            aliases=list(aliases) if aliases else [],
            first_source=source,
        )
        session.add(brand)
        session.commit()
        session.refresh(brand)
        return brand, True
def find_brand_by_text(text: str) -> KnownBrand | None:
    """Resolve free-form text to a known brand.

    Matching is case-insensitive and literal, after stripping whitespace from
    *text*. It tries ``canonical_name`` first, then each brand's ``aliases``.

    Returns:
        The first matching KnownBrand, or ``None`` when nothing matches.
    """
    normalized = text.strip().lower()
    # Escape LIKE metacharacters so text containing "%" or "_" is matched
    # literally rather than as a wildcard pattern (which gave false positives).
    pattern = normalized.replace("/", "//").replace("%", "/%").replace("_", "/_")
    with get_session() as session:
        stmt = select(KnownBrand).where(
            KnownBrand.canonical_name.ilike(pattern, escape="/")
        )
        brand = session.exec(stmt).first()
        if brand:
            return brand
        # Alias fallback loads every brand and compares in Python.
        # NOTE(review): this is a full-table scan; it may need a DB-side alias
        # lookup if the brand table grows large.
        for candidate in session.exec(select(KnownBrand)).all():
            if any(alias.lower() == normalized for alias in (candidate.aliases or [])):
                return candidate
        return None
def list_all_brands() -> list[KnownBrand]:
    """Return every KnownBrand, ordered by ``canonical_name``."""
    query = select(KnownBrand).order_by(KnownBrand.canonical_name)
    with get_session() as session:
        brands = session.exec(query).all()
        return list(brands)
def update_brand(brand_id: UUID, **fields) -> None:
    """Overwrite attributes of KnownBrand *brand_id* with *fields* and commit.

    Does nothing when no brand has that id.
    """
    with get_session() as session:
        brand = session.get(KnownBrand, brand_id)
        if not brand:
            return
        for name, value in fields.items():
            setattr(brand, name, value)
        session.commit()
# ---------------------------------------------------------------------------
# SourceBrandSighting
# ---------------------------------------------------------------------------
def get_source_sightings(source_asset_id: UUID) -> list[SourceBrandSighting]:
    """Return all brand sightings for *source_asset_id*, highest ``occurrences`` first."""
    belongs_to_asset = SourceBrandSighting.source_asset_id == source_asset_id
    most_frequent_first = SourceBrandSighting.occurrences.desc()
    query = select(SourceBrandSighting).where(belongs_to_asset).order_by(most_frequent_first)
    with get_session() as session:
        return list(session.exec(query).all())
def record_sighting(source_asset_id: UUID, brand_id: UUID, brand_name: str,
                    timestamp: float, confidence: float, source: str = "ocr") -> SourceBrandSighting:
    """Record one sighting of *brand_id* in *source_asset_id*, aggregated per (asset, brand).

    The first sighting for a pair inserts a row. Each later one increments
    ``occurrences``, folds *confidence* into the running mean
    ``avg_confidence``, and widens the
    ``[first_seen_timestamp, last_seen_timestamp]`` window to include
    *timestamp*. *brand_name* and *source* are stored only when the row is
    created.

    Returns:
        The inserted or updated sighting, refreshed after commit.
    """
    with get_session() as session:
        stmt = select(SourceBrandSighting).where(
            SourceBrandSighting.source_asset_id == source_asset_id,
            SourceBrandSighting.brand_id == brand_id,
        )
        sighting = session.exec(stmt).first()
        if sighting:
            # Incremental mean: new_avg = (old_avg * n + confidence) / (n + 1).
            total_conf = sighting.avg_confidence * sighting.occurrences + confidence
            sighting.occurrences += 1
            sighting.avg_confidence = total_conf / sighting.occurrences
            # Use min/max rather than overwriting, so a sighting recorded out of
            # timestamp order cannot move the window backwards. For in-order
            # input this is identical to assigning last_seen = timestamp.
            sighting.first_seen_timestamp = min(sighting.first_seen_timestamp, timestamp)
            sighting.last_seen_timestamp = max(sighting.last_seen_timestamp, timestamp)
            session.commit()
            session.refresh(sighting)
            return sighting
        sighting = SourceBrandSighting(
            source_asset_id=source_asset_id,
            brand_id=brand_id,
            brand_name=brand_name,
            first_seen_timestamp=timestamp,
            last_seen_timestamp=timestamp,
            occurrences=1,
            detection_source=source,
            avg_confidence=confidence,
        )
        session.add(sighting)
        session.commit()
        session.refresh(sighting)
        return sighting