phase 4
This commit is contained in:
107
core/detect/profile.py
Normal file
107
core/detect/profile.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
Profile registry and helpers.
|
||||
|
||||
Loads profile data from Postgres.
|
||||
A profile is a dict with keys: name, pipeline, configs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from core.detect.stages.models import PipelineConfig, StageRef, Edge
|
||||
from core.detect.models import (
|
||||
BrandDetection,
|
||||
BrandStats,
|
||||
CropContext,
|
||||
DetectionReport,
|
||||
PipelineStats,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_profile(name: str) -> Dict[str, Any]:
|
||||
"""Get a profile dict by name from the database."""
|
||||
from core.db.connection import get_session
|
||||
from core.db.models import Profile
|
||||
|
||||
with get_session() as session:
|
||||
row = session.query(Profile).filter(Profile.name == name).first()
|
||||
|
||||
if row is None:
|
||||
raise ValueError(f"Unknown profile: {name!r}")
|
||||
|
||||
return {
|
||||
"name": row.name,
|
||||
"pipeline": row.pipeline or {},
|
||||
"configs": row.configs or {},
|
||||
}
|
||||
|
||||
|
||||
def list_profiles() -> list[str]:
|
||||
"""List available profile names from the database."""
|
||||
from core.db.connection import get_session
|
||||
from core.db.models import Profile
|
||||
|
||||
with get_session() as session:
|
||||
rows = session.query(Profile.name).all()
|
||||
|
||||
return [r[0] for r in rows]
|
||||
|
||||
|
||||
def get_stage_config(profile: Dict[str, Any], stage_name: str) -> dict:
|
||||
"""Get config values for a stage from a profile."""
|
||||
return profile.get("configs", {}).get(stage_name, {})
|
||||
|
||||
|
||||
def pipeline_config_from_dict(data: Dict[str, Any]) -> PipelineConfig:
|
||||
"""Deserialize a PipelineConfig from a JSONB dict."""
|
||||
stages = [StageRef(**s) for s in data.get("stages", [])]
|
||||
edges = [Edge(**e) for e in data.get("edges", [])]
|
||||
return PipelineConfig(
|
||||
name=data.get("name", ""),
|
||||
profile_name=data.get("profile_name", ""),
|
||||
stages=stages,
|
||||
edges=edges,
|
||||
routing_rules=data.get("routing_rules", {}),
|
||||
)
|
||||
|
||||
|
||||
def build_vlm_prompt(crop_context: CropContext, template: str) -> str:
|
||||
"""Build a VLM prompt from a template and crop context."""
|
||||
hint = f" Position: {crop_context.position_hint}." if crop_context.position_hint else ""
|
||||
text = f" Nearby text: '{crop_context.surrounding_text}'." if crop_context.surrounding_text else ""
|
||||
return template.format(hint=hint, text=text)
|
||||
|
||||
|
||||
def aggregate_detections(
|
||||
detections: list[BrandDetection],
|
||||
content_type: str,
|
||||
) -> DetectionReport:
|
||||
"""Group detections by brand into a report."""
|
||||
brands: dict[str, BrandStats] = {}
|
||||
for d in detections:
|
||||
if d.brand not in brands:
|
||||
brands[d.brand] = BrandStats()
|
||||
s = brands[d.brand]
|
||||
s.total_appearances += 1
|
||||
s.total_screen_time += d.duration
|
||||
s.avg_confidence = (
|
||||
(s.avg_confidence * (s.total_appearances - 1) + d.confidence)
|
||||
/ s.total_appearances
|
||||
)
|
||||
if s.first_seen == 0.0 or d.timestamp < s.first_seen:
|
||||
s.first_seen = d.timestamp
|
||||
if d.timestamp > s.last_seen:
|
||||
s.last_seen = d.timestamp
|
||||
|
||||
return DetectionReport(
|
||||
video_source="",
|
||||
content_type=content_type,
|
||||
duration_seconds=0.0,
|
||||
brands=brands,
|
||||
timeline=sorted(detections, key=lambda d: d.timestamp),
|
||||
pipeline_stats=PipelineStats(),
|
||||
)
|
||||
Reference in New Issue
Block a user