phase 7
This commit is contained in:
121
detect/stages/brand_resolver.py
Normal file
121
detect/stages/brand_resolver.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
Stage 5 — Brand Resolver
|
||||
|
||||
Matches OCR text against the profile's brand dictionary.
|
||||
Uses exact matching first, then fuzzy matching (rapidfuzz) as fallback.
|
||||
Emits detection events for confirmed brands.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from rapidfuzz import fuzz
|
||||
|
||||
from detect import emit
|
||||
from detect.models import BrandDetection, TextCandidate
|
||||
from detect.profiles.base import BrandDictionary, ResolverConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize(text: str) -> str:
|
||||
"""Normalize text for matching."""
|
||||
return text.strip().lower()
|
||||
|
||||
|
||||
def _exact_match(text: str, dictionary: BrandDictionary) -> str | None:
|
||||
"""Try exact match against all aliases."""
|
||||
normalized = _normalize(text)
|
||||
for canonical, aliases in dictionary.brands.items():
|
||||
if normalized == _normalize(canonical):
|
||||
return canonical
|
||||
for alias in aliases:
|
||||
if normalized == _normalize(alias):
|
||||
return canonical
|
||||
return None
|
||||
|
||||
|
||||
def _fuzzy_match(text: str, dictionary: BrandDictionary, threshold: int) -> tuple[str | None, int]:
    """Try fuzzy match, return (brand, score) or (None, 0).

    Scores every canonical name and alias with ``fuzz.ratio`` against the
    normalized *text* and keeps the single best score that clears
    *threshold*. Ties keep the earlier brand (strict ``>`` comparison).
    """
    target = _normalize(text)
    best: tuple[str | None, int] = (None, 0)

    for canonical, aliases in dictionary.brands.items():
        for name in (canonical, *aliases):
            score = fuzz.ratio(target, _normalize(name))
            # Must both beat the current best and clear the threshold;
            # the winning alias still maps back to its canonical brand.
            if score >= threshold and score > best[1]:
                best = (canonical, score)

    return best
|
||||
|
||||
|
||||
def resolve_brands(
    candidates: list[TextCandidate],
    dictionary: BrandDictionary,
    config: ResolverConfig,
    content_type: str = "",
    job_id: str | None = None,
) -> tuple[list[BrandDetection], list[TextCandidate]]:
    """
    Match text candidates against the brand dictionary.

    Each candidate is resolved with an exact lookup first, then a fuzzy
    lookup at ``config.fuzzy_threshold``. Confirmed brands are emitted as
    detection events; everything else is handed back for escalation.

    Returns:
        - matched: list of BrandDetection for confirmed brands
        - unresolved: list of TextCandidate that couldn't be matched
    """
    emit.log(job_id, "BrandResolver", "INFO",
             f"Matching {len(candidates)} candidates against "
             f"{len(dictionary.brands)} brands (fuzzy_threshold={config.fuzzy_threshold})")

    matched: list[BrandDetection] = []
    unresolved: list[TextCandidate] = []
    exact_hits = 0
    fuzzy_hits = 0

    for cand in candidates:
        source = "ocr"

        # Cheap exact lookup first; fall back to fuzzy only when it misses.
        brand = _exact_match(cand.text, dictionary)
        if brand:
            exact_hits += 1
        else:
            brand, _score = _fuzzy_match(cand.text, dictionary, config.fuzzy_threshold)
            if brand:
                fuzzy_hits += 1

        if not brand:
            unresolved.append(cand)
            continue

        # Confidence carried through is the OCR read confidence, not the
        # match score; duration is a fixed per-sighting window.
        matched.append(BrandDetection(
            brand=brand,
            timestamp=cand.frame.timestamp,
            duration=0.5,
            confidence=cand.ocr_confidence,
            source=source,
            bbox=cand.bbox,
            frame_ref=cand.frame.sequence,
            content_type=content_type,
        ))
        emit.detection(
            job_id,
            brand=brand,
            confidence=cand.ocr_confidence,
            source=source,
            timestamp=cand.frame.timestamp,
            content_type=content_type,
            frame_ref=cand.frame.sequence,
        )

    emit.log(job_id, "BrandResolver", "INFO",
             f"Exact: {exact_hits}, Fuzzy: {fuzzy_hits}, "
             f"Unresolved: {len(unresolved)} → escalating to VLM")

    return matched, unresolved
|
||||
130
detect/stages/ocr_stage.py
Normal file
130
detect/stages/ocr_stage.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""
|
||||
Stage 4 — OCR
|
||||
|
||||
Reads text from detected regions (YOLO bounding box crops).
|
||||
Two modes:
|
||||
- remote: calls inference server over HTTP (separate GPU box, or localhost)
|
||||
- local: runs PaddleOCR in-process (single-box setup with enough VRAM)
|
||||
|
||||
The mode is selected by whether inference_url is provided.
|
||||
Model instances are cached at module level so they survive across pipeline runs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from detect import emit
|
||||
from detect.models import BoundingBox, Frame, TextCandidate
|
||||
from detect.profiles.base import OCRConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Module-level cache — avoids reloading the model for every crop or pipeline run
|
||||
_local_ocr_cache: dict[str, object] = {}
|
||||
|
||||
|
||||
def _crop_region(frame: Frame, box: BoundingBox) -> np.ndarray:
|
||||
h, w = frame.image.shape[:2]
|
||||
x1 = max(0, box.x)
|
||||
y1 = max(0, box.y)
|
||||
x2 = min(w, box.x + box.w)
|
||||
y2 = min(h, box.y + box.h)
|
||||
return frame.image[y1:y2, x1:x2]
|
||||
|
||||
|
||||
def _get_local_model(lang: str):
    """Return a cached in-process PaddleOCR instance for *lang*.

    The model is loaded lazily on first request per language and kept in
    the module-level cache so later pipeline runs reuse it. PaddleOCR is
    imported here, not at module top, so the dependency is only required
    in local mode.
    """
    model = _local_ocr_cache.get(lang)
    if model is None:
        from paddleocr import PaddleOCR
        logger.info("Loading PaddleOCR locally (lang=%s)", lang)
        model = PaddleOCR(lang=lang)
        _local_ocr_cache[lang] = model
    return model
|
||||
|
||||
|
||||
def _parse_ocr_raw(raw, min_confidence: float) -> list[dict]:
|
||||
"""Parse PaddleOCR 3.x result — handles dict-based and nested-list layouts."""
|
||||
results = []
|
||||
for page in (raw or []):
|
||||
if not page:
|
||||
continue
|
||||
if isinstance(page, dict):
|
||||
for text, confidence in zip(page.get("rec_texts", []), page.get("rec_scores", [])):
|
||||
if float(confidence) >= min_confidence:
|
||||
results.append({"text": text, "confidence": float(confidence)})
|
||||
continue
|
||||
for line in page:
|
||||
if not line:
|
||||
continue
|
||||
rec = line[1]
|
||||
if isinstance(rec, (list, tuple)) and len(rec) >= 2:
|
||||
text, confidence = rec[0], rec[1]
|
||||
if float(confidence) >= min_confidence:
|
||||
results.append({"text": text, "confidence": float(confidence)})
|
||||
return results
|
||||
|
||||
|
||||
def run_ocr(
    frames: list[Frame],
    boxes_by_frame: dict[int, list[BoundingBox]],
    config: OCRConfig,
    inference_url: str | None = None,
    job_id: str | None = None,
) -> list[TextCandidate]:
    """
    Run OCR on cropped regions from YOLO detections.

    inference_url=None → local in-process PaddleOCR (single-box)
    inference_url=str → remote inference server (split or localhost)

    Each bounding box is cropped from its frame and read; every recognized
    string becomes one TextCandidate. Frames without a matching sequence
    number and empty crops are skipped silently.
    """
    region_total = sum(len(boxes) for boxes in boxes_by_frame.values())
    mode = "remote" if inference_url else "local"

    emit.log(job_id, "OCRStage", "INFO",
             f"Running OCR on {region_total} regions (mode={mode})")

    # Construct the OCR backend once per pipeline run, not per crop.
    if inference_url:
        from detect.inference import InferenceClient
        client = InferenceClient(base_url=inference_url)
    else:
        # NOTE: only the first configured language drives the local model.
        model = _get_local_model(config.languages[0])

    frames_by_seq = {frame.sequence: frame for frame in frames}
    candidates: list[TextCandidate] = []

    for seq, boxes in boxes_by_frame.items():
        frame = frames_by_seq.get(seq)
        if not frame:
            continue

        for box in boxes:
            crop = _crop_region(frame, box)
            if crop.size == 0:
                continue

            if inference_url:
                # Remote path: server applies its own filtering; results
                # arrive as objects with .text/.confidence.
                remote_results = client.ocr(image=crop, languages=config.languages)
                texts = [{"text": r.text, "confidence": r.confidence} for r in remote_results]
            else:
                texts = _parse_ocr_raw(model.ocr(crop), config.min_confidence)

            candidates.extend(
                TextCandidate(
                    frame=frame,
                    bbox=box,
                    text=entry["text"],
                    ocr_confidence=entry["confidence"],
                )
                for entry in texts
            )

    emit.log(job_id, "OCRStage", "INFO",
             f"Extracted text from {len(candidates)} regions")
    emit.stats(job_id, regions_resolved_by_ocr=len(candidates))

    return candidates
|
||||
Reference in New Issue
Block a user