major refactor

This commit is contained in:
2026-03-27 06:02:58 -03:00
parent bcf6f3dc71
commit 51ce14a812
18 changed files with 351 additions and 523 deletions

View File

@@ -4,14 +4,10 @@ Stage 5 — Brand Resolver (discovery mode)
Discovery-first brand matching. No static dictionary — all brands live in the DB.
Flow:
1. Check session sightings first (brands already seen in this source)
1. Check session brands first (brands already seen in this run, in-memory)
2. Check global known brands (accumulated across all runs)
3. Unresolved candidates → escalate to VLM/cloud
4. Confirmed brands get added to DB for future runs
The resolver is an enricher, not a gatekeeper. Every OCR text candidate
passes through — the question is whether we can resolve it cheaply (DB lookup)
or need to escalate (VLM/cloud).
"""
from __future__ import annotations
@@ -33,41 +29,30 @@ def _normalize(text: str) -> str:
def _has_db() -> bool:
try:
from core.db.detect import find_brand_by_text as _
from admin.mpr.media_assets.models import KnownBrand as _
from core.db import find_brand_by_text as _
return True
except (ImportError, Exception):
return False
def _match_session(text: str, session_brands: dict[str, str]) -> str | None:
"""
Check against session brands (already seen in this source).
session_brands: {normalized_name: canonical_name, ...}
Includes aliases.
"""
normalized = _normalize(text)
return session_brands.get(normalized)
return session_brands.get(_normalize(text))
def _match_known(text: str, threshold: int) -> tuple[str | None, str | None]:
"""
Check against global known brands in DB.
Returns (canonical_name, brand_id) or (None, None).
"""
"""Check against global known brands in DB. Returns (canonical_name, brand_id) or (None, None)."""
if not _has_db():
return None, None
from core.db.detect import find_brand_by_text
brand = find_brand_by_text(text)
if brand:
return brand.canonical_name, str(brand.id)
from core.db import find_brand_by_text, list_brands
from core.db.connection import get_session
# Fuzzy match against all known brands
from core.db.detect import list_all_brands
all_brands = list_all_brands()
with get_session() as session:
brand = find_brand_by_text(session, text)
if brand:
return brand.canonical_name, str(brand.id)
all_brands = list_brands(session)
normalized = _normalize(text)
best_brand = None
@@ -92,58 +77,62 @@ def _register_brand(canonical_name: str, source: str) -> str | None:
if not _has_db():
return None
from core.db.detect import get_or_create_brand
brand, created = get_or_create_brand(canonical_name, source=source)
from core.db import get_or_create_brand
from core.db.connection import get_session
with get_session() as session:
brand, created = get_or_create_brand(session, canonical_name, source=source)
session.commit()
if created:
logger.info("New brand discovered: %s (source=%s)", canonical_name, source)
return str(brand.id)
def _record_sighting(source_asset_id: str | None, brand_id: str,
brand_name: str, timestamp: float,
confidence: float, source: str):
"""Record a brand sighting for this source."""
if not _has_db() or not source_asset_id:
def _record_airing(timeline_id: str | None, brand_id: str,
frame_seq: int, confidence: float, source: str):
"""Record a brand airing on a timeline."""
if not _has_db() or not timeline_id:
return
from core.db.detect import record_sighting
import uuid
asset_id = uuid.UUID(source_asset_id) if isinstance(source_asset_id, str) else source_asset_id
brand_uuid = uuid.UUID(brand_id) if isinstance(brand_id, str) else brand_id
record_sighting(asset_id, brand_uuid, brand_name, timestamp, confidence, source)
from core.db import record_airing
from core.db.connection import get_session
from uuid import UUID
with get_session() as session:
record_airing(
session,
brand_id=UUID(brand_id),
timeline_id=UUID(timeline_id),
frame_start=frame_seq,
frame_end=frame_seq,
confidence=confidence,
source=source,
)
session.commit()
def build_session_dict(source_asset_id: str | None) -> dict[str, str]:
def build_session_dict(source_asset_id: str | None = None) -> dict[str, str]:
"""
Load session brands from DB for this source.
Load known brands from DB as a session lookup dict.
Returns {normalized_name: canonical_name, ...} including aliases.
"""
if not _has_db() or not source_asset_id:
if not _has_db():
return {}
from core.db.detect import get_source_sightings
import uuid
from core.db import list_brands
from core.db.connection import get_session
asset_id = uuid.UUID(source_asset_id) if isinstance(source_asset_id, str) else source_asset_id
sightings = get_source_sightings(asset_id)
with get_session() as session:
all_brands = list_brands(session)
session = {}
for s in sightings:
canonical = s.brand_name
session[_normalize(canonical)] = canonical
session_dict = {}
for brand in all_brands:
session_dict[_normalize(brand.canonical_name)] = brand.canonical_name
for alias in (brand.aliases or []):
session_dict[_normalize(alias)] = brand.canonical_name
# Also load aliases from KnownBrand for each sighted brand
if _has_db():
from core.db.detect import list_all_brands
all_brands = list_all_brands()
sighted_names = {s.brand_name for s in sightings}
for brand in all_brands:
if brand.canonical_name in sighted_names:
for alias in (brand.aliases or []):
session[_normalize(alias)] = brand.canonical_name
return session
return session_dict
def resolve_brands(
@@ -158,7 +147,7 @@ def resolve_brands(
Match text candidates against known brands (session → global → unresolved).
session_brands: pre-loaded session dict (from build_session_dict)
source_asset_id: for recording new sightings in DB
job_id: timeline_id — used to record airings
"""
if session_brands is None:
session_brands = {}
@@ -187,7 +176,6 @@ def resolve_brands(
brand_name, brand_id = _match_known(text, config.fuzzy_threshold)
if brand_name:
known_hits += 1
# Add to session for subsequent candidates in this run
session_brands[_normalize(brand_name)] = brand_name
if brand_name:
@@ -203,11 +191,10 @@ def resolve_brands(
)
matched.append(detection)
# Record sighting in DB
if brand_id:
_record_sighting(
source_asset_id, brand_id, brand_name,
candidate.frame.timestamp, candidate.ocr_confidence, match_source,
_record_airing(
job_id, brand_id,
candidate.frame.sequence, candidate.ocr_confidence, match_source,
)
emit.detection(

View File

@@ -10,7 +10,7 @@ from core.schema.serializers._common import (
serialize_dataclass,
serialize_dataclass_list,
)
from core.schema.serializers.detect_pipeline import (
from core.schema.serializers.pipeline import (
serialize_frame_meta,
serialize_frames_with_upload as serialize_frames,
deserialize_frames_with_download as deserialize_frames,