phase 10
This commit is contained in:
@@ -1,20 +1,13 @@
|
||||
"""Tests for BrandResolver stage."""
|
||||
"""Tests for BrandResolver stage (discovery mode)."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from detect.models import BoundingBox, Frame, TextCandidate
|
||||
from detect.profiles.base import BrandDictionary, ResolverConfig
|
||||
from detect.stages.brand_resolver import resolve_brands, _exact_match, _fuzzy_match
|
||||
from detect.profiles.base import ResolverConfig
|
||||
from detect.stages.brand_resolver import resolve_brands, _normalize, _match_session
|
||||
|
||||
|
||||
DICTIONARY = BrandDictionary(brands={
|
||||
"Nike": ["nike", "NIKE", "swoosh"],
|
||||
"Adidas": ["adidas", "ADIDAS"],
|
||||
"Coca-Cola": ["coca-cola", "coca cola", "coke", "COCA-COLA"],
|
||||
"Emirates": ["emirates", "fly emirates", "EMIRATES"],
|
||||
})
|
||||
|
||||
CONFIG = ResolverConfig(fuzzy_threshold=75)
|
||||
|
||||
|
||||
@@ -25,57 +18,76 @@ def _make_candidate(text: str, confidence: float = 0.9) -> TextCandidate:
|
||||
return TextCandidate(frame=dummy_frame, bbox=dummy_box, text=text, ocr_confidence=confidence)
|
||||
|
||||
|
||||
def test_exact_match():
|
||||
assert _exact_match("Nike", DICTIONARY) == "Nike"
|
||||
assert _exact_match("nike", DICTIONARY) == "Nike"
|
||||
assert _exact_match("COCA-COLA", DICTIONARY) == "Coca-Cola"
|
||||
assert _exact_match("fly emirates", DICTIONARY) == "Emirates"
|
||||
assert _exact_match("unknown brand", DICTIONARY) is None
|
||||
def test_session_match():
|
||||
session = {"nike": "Nike", "fly emirates": "Emirates"}
|
||||
assert _match_session("Nike", session) == "Nike"
|
||||
assert _match_session("nike", session) == "Nike"
|
||||
assert _match_session("FLY EMIRATES", session) == "Emirates"
|
||||
assert _match_session("unknown", session) is None
|
||||
|
||||
|
||||
def test_fuzzy_match():
|
||||
brand, score = _fuzzy_match("Nik3", DICTIONARY, threshold=75)
|
||||
assert brand == "Nike"
|
||||
assert score >= 75
|
||||
def test_resolve_with_session(monkeypatch):
|
||||
events = []
|
||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||
lambda job_id, etype, data: events.append((etype, data)))
|
||||
|
||||
brand, score = _fuzzy_match("adldas", DICTIONARY, threshold=75)
|
||||
assert brand == "Adidas"
|
||||
|
||||
brand, score = _fuzzy_match("xyzxyzxyz", DICTIONARY, threshold=75)
|
||||
assert brand is None
|
||||
|
||||
|
||||
def test_resolve_exact():
|
||||
session = {"nike": "Nike", "emirates": "Emirates"}
|
||||
candidates = [_make_candidate("Nike"), _make_candidate("EMIRATES")]
|
||||
matched, unresolved = resolve_brands(candidates, DICTIONARY, CONFIG)
|
||||
|
||||
matched, unresolved = resolve_brands(
|
||||
candidates, CONFIG, session_brands=session,
|
||||
)
|
||||
|
||||
assert len(matched) == 2
|
||||
assert len(unresolved) == 0
|
||||
assert matched[0].brand == "Nike"
|
||||
assert matched[1].brand == "Emirates"
|
||||
|
||||
|
||||
def test_resolve_fuzzy():
|
||||
candidates = [_make_candidate("coca coIa")] # OCR misread
|
||||
matched, unresolved = resolve_brands(candidates, DICTIONARY, CONFIG)
|
||||
assert len(matched) == 1
|
||||
assert matched[0].brand == "Coca-Cola"
|
||||
def test_resolve_unresolved_without_db(monkeypatch):
|
||||
events = []
|
||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||
lambda job_id, etype, data: events.append((etype, data)))
|
||||
|
||||
|
||||
def test_resolve_unresolved():
|
||||
candidates = [_make_candidate("random garbage text")]
|
||||
matched, unresolved = resolve_brands(candidates, DICTIONARY, CONFIG)
|
||||
|
||||
matched, unresolved = resolve_brands(
|
||||
candidates, CONFIG, session_brands={},
|
||||
)
|
||||
|
||||
assert len(matched) == 0
|
||||
assert len(unresolved) == 1
|
||||
|
||||
|
||||
def test_resolve_mixed():
|
||||
def test_resolve_empty(monkeypatch):
|
||||
events = []
|
||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||
lambda job_id, etype, data: events.append((etype, data)))
|
||||
|
||||
matched, unresolved = resolve_brands([], CONFIG, session_brands={})
|
||||
|
||||
assert len(matched) == 0
|
||||
assert len(unresolved) == 0
|
||||
|
||||
|
||||
def test_resolve_builds_session_during_run(monkeypatch):
|
||||
"""Session brands accumulate during a single run — second candidate benefits."""
|
||||
events = []
|
||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||
lambda job_id, etype, data: events.append((etype, data)))
|
||||
|
||||
session = {"nike": "Nike"}
|
||||
candidates = [
|
||||
_make_candidate("Nike"),
|
||||
_make_candidate("unknown"),
|
||||
_make_candidate("adldas"),
|
||||
_make_candidate("Nike"), # hits session
|
||||
_make_candidate("unknown"), # misses everything
|
||||
]
|
||||
matched, unresolved = resolve_brands(candidates, DICTIONARY, CONFIG)
|
||||
assert len(matched) == 2 # Nike exact + Adidas fuzzy
|
||||
|
||||
matched, unresolved = resolve_brands(
|
||||
candidates, CONFIG, session_brands=session,
|
||||
)
|
||||
|
||||
assert len(matched) == 1
|
||||
assert matched[0].brand == "Nike"
|
||||
assert len(unresolved) == 1
|
||||
|
||||
|
||||
@@ -84,8 +96,10 @@ def test_events_emitted(monkeypatch):
|
||||
monkeypatch.setattr("detect.emit.push_detect_event",
|
||||
lambda job_id, etype, data: events.append((etype, data)))
|
||||
|
||||
session = {"nike": "Nike"}
|
||||
candidates = [_make_candidate("Nike")]
|
||||
resolve_brands(candidates, DICTIONARY, CONFIG, job_id="test-job")
|
||||
|
||||
resolve_brands(candidates, CONFIG, session_brands=session, job_id="test-job")
|
||||
|
||||
event_types = [e[0] for e in events]
|
||||
assert "log" in event_types
|
||||
|
||||
Reference in New Issue
Block a user