This commit is contained in:
2026-03-26 04:24:32 -03:00
parent 08b67f2bb7
commit 08c58a6a9d
43 changed files with 2627 additions and 252 deletions

View File

@@ -1,9 +1,17 @@
"""
Stage 5 — Brand Resolver
Stage 5 — Brand Resolver (discovery mode)
Matches OCR text against the profile's brand dictionary.
Uses exact matching first, then fuzzy matching (rapidfuzz) as fallback.
Emits detection events for confirmed brands.
Discovery-first brand matching. No static dictionary — all brands live in the DB.
Flow:
1. Check session sightings first (brands already seen in this source)
2. Check global known brands (accumulated across all runs)
3. Unresolved candidates → escalate to VLM/cloud
4. Confirmed brands get added to DB for future runs
The resolver is an enricher, not a gatekeeper. Every OCR text candidate
passes through — the question is whether we can resolve it cheaply (DB lookup)
or need to escalate (VLM/cloud).
"""
from __future__ import annotations
@@ -14,99 +22,199 @@ from rapidfuzz import fuzz
from detect import emit
from detect.models import BrandDetection, TextCandidate
from detect.profiles.base import BrandDictionary, ResolverConfig
from detect.profiles.base import ResolverConfig
logger = logging.getLogger(__name__)
def _normalize(text: str) -> str:
"""Normalize text for matching."""
return text.strip().lower()
def _exact_match(text: str, dictionary: BrandDictionary) -> str | None:
"""Try exact match against all aliases."""
def _has_db() -> bool:
try:
from core.db.detect import find_brand_by_text as _
from admin.mpr.media_assets.models import KnownBrand as _
return True
except (ImportError, Exception):
return False
def _match_session(text: str, session_brands: dict[str, str]) -> str | None:
    """
    Check against session brands (already seen in this source).

    session_brands: {normalized_name: canonical_name, ...}
    Includes aliases.

    Returns the canonical brand name, or None when this text has not been
    seen in the current session. (Cheapest resolution path: pure in-memory
    dict lookup, no DB round-trip.)
    """
    # NOTE(review): the diff had left the removed body of the old
    # _exact_match (which referenced an undefined `dictionary`) interleaved
    # here; this is the clean discovery-mode implementation.
    normalized = _normalize(text)
    return session_brands.get(normalized)
def _match_known(text: str, threshold: int) -> tuple[str | None, str | None]:
    """
    Check against global known brands in DB.

    Tries an exact DB lookup first, then falls back to fuzzy matching
    (rapidfuzz ratio) across every known brand's canonical name and aliases.
    Only scores at or above `threshold` are accepted.

    Returns (canonical_name, brand_id) or (None, None).
    """
    if not _has_db():
        return None, None
    from core.db.detect import find_brand_by_text, list_all_brands

    # Fast path: exact lookup delegated to the DB layer.
    brand = find_brand_by_text(text)
    if brand:
        return brand.canonical_name, str(brand.id)

    # Slow path: fuzzy match against all known brands.
    all_brands = list_all_brands()
    normalized = _normalize(text)
    best_brand = None
    best_score = 0
    for known in all_brands:
        names = [known.canonical_name] + (known.aliases or [])
        for name in names:
            score = fuzz.ratio(normalized, _normalize(name))
            if score > best_score and score >= threshold:
                best_score = score
                best_brand = known
    if best_brand:
        return best_brand.canonical_name, str(best_brand.id)
    return None, None
def _register_brand(canonical_name: str, source: str) -> str | None:
    """Register a newly discovered brand in the DB. Returns brand_id."""
    if not _has_db():
        return None
    from core.db.detect import get_or_create_brand

    record, was_created = get_or_create_brand(canonical_name, source=source)
    if was_created:
        logger.info("New brand discovered: %s (source=%s)", canonical_name, source)
    return str(record.id)
def _record_sighting(source_asset_id: str | None, brand_id: str,
                     brand_name: str, timestamp: float,
                     confidence: float, source: str):
    """Record a brand sighting for this source (no-op without a DB/asset id)."""
    if not _has_db() or not source_asset_id:
        return
    import uuid

    from core.db.detect import record_sighting

    # IDs may arrive as strings or already-constructed UUIDs; normalize both.
    asset_uuid = source_asset_id
    if isinstance(asset_uuid, str):
        asset_uuid = uuid.UUID(asset_uuid)
    brand_uuid = brand_id
    if isinstance(brand_uuid, str):
        brand_uuid = uuid.UUID(brand_uuid)
    record_sighting(asset_uuid, brand_uuid, brand_name, timestamp, confidence, source)
def build_session_dict(source_asset_id: str | None) -> dict[str, str]:
    """
    Load session brands from DB for this source.

    Returns {normalized_name: canonical_name, ...} including aliases, so
    later candidates from the same source resolve with a pure in-memory
    lookup (see _match_session). Empty dict when no DB or no asset id.
    """
    if not _has_db() or not source_asset_id:
        return {}
    import uuid

    from core.db.detect import get_source_sightings, list_all_brands

    asset_id = uuid.UUID(source_asset_id) if isinstance(source_asset_id, str) else source_asset_id
    sightings = get_source_sightings(asset_id)
    session: dict[str, str] = {}
    for s in sightings:
        canonical = s.brand_name
        session[_normalize(canonical)] = canonical
    # Also load aliases from KnownBrand for each sighted brand.
    # (The redundant inner `if _has_db():` re-check was dropped — the guard
    # at the top of the function already guarantees the DB is available.)
    sighted_names = {s.brand_name for s in sightings}
    for brand in list_all_brands():
        if brand.canonical_name in sighted_names:
            for alias in (brand.aliases or []):
                session[_normalize(alias)] = brand.canonical_name
    return session
def resolve_brands(
candidates: list[TextCandidate],
dictionary: BrandDictionary,
config: ResolverConfig,
session_brands: dict[str, str] | None = None,
source_asset_id: str | None = None,
content_type: str = "",
job_id: str | None = None,
) -> tuple[list[BrandDetection], list[TextCandidate]]:
"""
Match text candidates against the brand dictionary.
Match text candidates against known brands (session → global → unresolved).
Returns:
- matched: list of BrandDetection for confirmed brands
- unresolved: list of TextCandidate that couldn't be matched
session_brands: pre-loaded session dict (from build_session_dict)
source_asset_id: for recording new sightings in DB
"""
if session_brands is None:
session_brands = {}
emit.log(job_id, "BrandResolver", "INFO",
f"Matching {len(candidates)} candidates against "
f"{len(dictionary.brands)} brands (fuzzy_threshold={config.fuzzy_threshold})")
f"Resolving {len(candidates)} candidates "
f"(session={len(session_brands)} brands, fuzzy={config.fuzzy_threshold})")
matched: list[BrandDetection] = []
unresolved: list[TextCandidate] = []
exact_count = 0
fuzzy_count = 0
session_hits = 0
known_hits = 0
for candidate in candidates:
# Try exact match first
brand = _exact_match(candidate.text, dictionary)
source = "ocr"
text = candidate.text
brand_name = None
brand_id = None
match_source = "ocr"
if brand:
exact_count += 1
# 1. Check session (cheapest — in-memory dict)
brand_name = _match_session(text, session_brands)
if brand_name:
session_hits += 1
else:
# Try fuzzy match
brand, score = _fuzzy_match(candidate.text, dictionary, config.fuzzy_threshold)
if brand:
fuzzy_count += 1
# 2. Check global known brands (DB query + fuzzy)
brand_name, brand_id = _match_known(text, config.fuzzy_threshold)
if brand_name:
known_hits += 1
# Add to session for subsequent candidates in this run
session_brands[_normalize(brand_name)] = brand_name
if brand:
if brand_name:
detection = BrandDetection(
brand=brand,
brand=brand_name,
timestamp=candidate.frame.timestamp,
duration=0.5,
confidence=candidate.ocr_confidence,
source=source,
source=match_source,
bbox=candidate.bbox,
frame_ref=candidate.frame.sequence,
content_type=content_type,
)
matched.append(detection)
# Record sighting in DB
if brand_id:
_record_sighting(
source_asset_id, brand_id, brand_name,
candidate.frame.timestamp, candidate.ocr_confidence, match_source,
)
emit.detection(
job_id,
brand=brand,
brand=brand_name,
confidence=candidate.ocr_confidence,
source=source,
source=match_source,
timestamp=candidate.frame.timestamp,
content_type=content_type,
frame_ref=candidate.frame.sequence,
@@ -115,7 +223,7 @@ def resolve_brands(
unresolved.append(candidate)
emit.log(job_id, "BrandResolver", "INFO",
f"Exact: {exact_count}, Fuzzy: {fuzzy_count}, "
f"Unresolved: {len(unresolved)} → escalating to VLM")
f"Session: {session_hits}, Known: {known_hits}, "
f"Unresolved: {len(unresolved)} → escalating")
return matched, unresolved

View File

@@ -27,6 +27,18 @@ logger = logging.getLogger(__name__)
ESTIMATED_TOKENS_PER_CROP = 500
def _register_discovered_brand(brand: str, source_asset_id: str | None,
                               timestamp: float, confidence: float):
    """Register a cloud-confirmed brand in the DB (best-effort, never raises)."""
    try:
        from detect.stages.brand_resolver import _record_sighting, _register_brand

        new_brand_id = _register_brand(brand, "cloud_llm")
        if new_brand_id and source_asset_id:
            _record_sighting(source_asset_id, new_brand_id, brand,
                             timestamp, confidence, "cloud_llm")
    except Exception as exc:
        # Registration is opportunistic — a DB problem must not fail detection.
        logger.debug("Failed to register brand %s: %s", brand, exc)
def _encode_crop(crop: np.ndarray) -> str:
img = Image.fromarray(crop)
buf = io.BytesIO()
@@ -84,6 +96,7 @@ def escalate_cloud(
stats: PipelineStats,
min_confidence: float = 0.4,
content_type: str = "",
source_asset_id: str | None = None,
job_id: str | None = None,
) -> list[BrandDetection]:
"""
@@ -158,6 +171,10 @@ def escalate_cloud(
frame_ref=candidate.frame.sequence,
)
# Register newly discovered brand in DB
_register_discovered_brand(brand, source_asset_id,
candidate.frame.timestamp, confidence)
stats.estimated_cloud_cost_usd += total_cost
stats.regions_escalated_to_cloud_llm = len(candidates)

View File

@@ -19,6 +19,18 @@ from detect.profiles.base import CropContext
logger = logging.getLogger(__name__)
def _register_discovered_brand(brand: str, source_asset_id: str | None,
                               timestamp: float, confidence: float, source: str):
    """Register a VLM-confirmed brand in the DB (best-effort, never raises)."""
    try:
        from detect.stages.brand_resolver import _record_sighting, _register_brand

        new_brand_id = _register_brand(brand, source)
        if new_brand_id and source_asset_id:
            _record_sighting(source_asset_id, new_brand_id, brand,
                             timestamp, confidence, source)
    except Exception as exc:
        # Best-effort: a DB hiccup must not break the detection pipeline.
        logger.debug("Failed to register brand %s: %s", brand, exc)
def _crop_image(candidate: TextCandidate) -> np.ndarray:
frame = candidate.frame
box = candidate.bbox
@@ -36,6 +48,7 @@ def escalate_vlm(
inference_url: str | None = None,
min_confidence: float = 0.5,
content_type: str = "",
source_asset_id: str | None = None,
job_id: str | None = None,
) -> tuple[list[BrandDetection], list[TextCandidate]]:
"""
@@ -107,6 +120,10 @@ def escalate_vlm(
frame_ref=candidate.frame.sequence,
)
# Register newly discovered brand in DB
_register_discovered_brand(brand, source_asset_id,
candidate.frame.timestamp, confidence, "local_vlm")
logger.debug("VLM matched: %s (%.2f) — %s", brand, confidence, reasoning)
else:
still_unresolved.append(candidate)