108 lines
3.2 KiB
Python
108 lines
3.2 KiB
Python
"""
|
|
Profile registry and helpers.
|
|
|
|
Loads profile data from Postgres.
|
|
A profile is a dict with keys: name, pipeline, configs.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import Any, Dict
|
|
|
|
from core.detect.stages.models import PipelineConfig, StageRef, Edge
|
|
from core.detect.models import (
|
|
BrandDetection,
|
|
BrandStats,
|
|
CropContext,
|
|
DetectionReport,
|
|
PipelineStats,
|
|
)
|
|
|
|
# Module-level logger named after this module (no call site appears in the
# visible code of this file — may currently be unused).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def get_profile(name: str) -> Dict[str, Any]:
    """Load one profile from the database and return it as a plain dict.

    Args:
        name: Exact ``Profile.name`` value to look up.

    Returns:
        ``{"name": ..., "pipeline": ..., "configs": ...}``. A falsy
        ``pipeline`` or ``configs`` column value (e.g. ``None``) is returned
        as ``{}``.

    Raises:
        ValueError: If no profile with that name exists.
    """
    # DB modules are imported at call time rather than at module import
    # (presumably to keep importing this module cheap / avoid cycles — verify).
    from core.db.connection import get_session
    from core.db.models import Profile

    with get_session() as session:
        matching = session.query(Profile).filter(Profile.name == name)
        profile_row = matching.first()
        if profile_row is not None:
            # Read the row's attributes before the session context exits.
            return {
                "name": profile_row.name,
                "pipeline": profile_row.pipeline or {},
                "configs": profile_row.configs or {},
            }
        raise ValueError(f"Unknown profile: {name!r}")
|
|
|
|
|
|
def list_profiles() -> list[str]:
    """Return the names of every profile stored in the database.

    No ORDER BY is applied, so names come back in whatever order the
    database yields them.
    """
    # Deferred import, same pattern as get_profile (DB layer loaded on call).
    from core.db.connection import get_session
    from core.db.models import Profile

    with get_session() as session:
        name_rows = session.query(Profile.name).all()

    # Each result row holds a single column: the profile name.
    return [profile_name for (profile_name,) in name_rows]
|
|
|
|
|
|
def get_stage_config(profile: Dict[str, Any], stage_name: str) -> dict:
    """Return the config dict for *stage_name* from *profile*.

    Looks up ``profile["configs"][stage_name]``. Returns ``{}`` when the
    profile has no ``configs`` entry, when that entry is ``None``, or when the
    stage is missing or mapped to ``None`` — so callers always get a dict.
    A non-empty stored stage config is returned as-is (same object, not a copy).
    """
    # ``or {}`` instead of a .get() default: an explicit None (e.g. a JSON
    # null) keeps the key present, and .get's default would not apply.
    configs = profile.get("configs") or {}
    return configs.get(stage_name) or {}
|
|
|
|
|
|
def pipeline_config_from_dict(data: Dict[str, Any]) -> PipelineConfig:
    """Deserialize a PipelineConfig from a JSONB dict.

    Recognized keys (all optional):
        name, profile_name: strings, default ``""``.
        stages: list of mappings, each unpacked as ``StageRef(**item)``.
        edges: list of mappings, each unpacked as ``Edge(**item)``.
        routing_rules: passed through unchanged, default ``{}``.

    ``stages``, ``edges`` and ``routing_rules`` set to ``None`` (JSON null)
    are treated the same as missing.
    """
    # ``or`` rather than a .get() default: a JSON null keeps the key present
    # with value None, which would make the comprehensions raise TypeError.
    stage_dicts = data.get("stages") or []
    edge_dicts = data.get("edges") or []
    return PipelineConfig(
        name=data.get("name", ""),
        profile_name=data.get("profile_name", ""),
        stages=[StageRef(**s) for s in stage_dicts],
        edges=[Edge(**e) for e in edge_dicts],
        routing_rules=data.get("routing_rules") or {},
    )
|
|
|
|
|
|
def build_vlm_prompt(crop_context: CropContext, template: str) -> str:
    """Render *template* with optional position and nearby-text fragments.

    *template* is filled with ``str.format`` using two keyword fields,
    ``{hint}`` and ``{text}``. Each field is ``""`` when the matching
    ``crop_context`` attribute (``position_hint`` / ``surrounding_text``) is
    falsy; otherwise it is a sentence fragment that starts with a space.
    """
    position = crop_context.position_hint
    nearby = crop_context.surrounding_text

    fragments = {"hint": "", "text": ""}
    if position:
        fragments["hint"] = f" Position: {position}."
    if nearby:
        fragments["text"] = f" Nearby text: '{nearby}'."
    return template.format(**fragments)
|
|
|
|
|
|
def aggregate_detections(
    detections: list[BrandDetection],
    content_type: str,
) -> DetectionReport:
    """Group detections by brand into a report.

    For each distinct ``brand`` a ``BrandStats`` accumulates the appearance
    count, summed ``duration`` (screen time), running mean of ``confidence``,
    and the earliest / latest ``timestamp`` seen.

    Args:
        detections: Detections to aggregate, in any order.
        content_type: Copied onto the returned report unchanged.

    Returns:
        A ``DetectionReport`` whose ``brands`` maps brand -> stats and whose
        ``timeline`` is a new list of the detections sorted by timestamp (the
        input list is not reordered). ``video_source`` is ``""``,
        ``duration_seconds`` is ``0.0`` and ``pipeline_stats`` is a default
        ``PipelineStats()``; presumably the caller fills these in — verify.
    """
    brands: dict[str, BrandStats] = {}
    for d in detections:
        s = brands.get(d.brand)
        first_sighting = s is None
        if first_sighting:
            s = BrandStats()
            brands[d.brand] = s

        s.total_appearances += 1
        s.total_screen_time += d.duration
        # Incremental mean: fold this confidence into the previous average.
        s.avg_confidence = (
            (s.avg_confidence * (s.total_appearances - 1) + d.confidence)
            / s.total_appearances
        )

        if first_sighting:
            # Seed both bounds from the first detection. Using
            # ``first_seen == 0.0`` as an "unset" marker is wrong: a real
            # detection at timestamp 0.0 would be overwritten by the next one.
            s.first_seen = d.timestamp
            s.last_seen = d.timestamp
        else:
            if d.timestamp < s.first_seen:
                s.first_seen = d.timestamp
            if d.timestamp > s.last_seen:
                s.last_seen = d.timestamp

    return DetectionReport(
        video_source="",
        content_type=content_type,
        duration_seconds=0.0,
        brands=brands,
        timeline=sorted(detections, key=lambda d: d.timestamp),
        pipeline_stats=PipelineStats(),
    )
|