phase 7
This commit is contained in:
176
tests/detect/manual/test_brand_table_e2e.py
Normal file
176
tests/detect/manual/test_brand_table_e2e.py
Normal file
@@ -0,0 +1,176 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Push OCR + brand detection events to test the BrandTablePanel live.
|
||||
|
||||
Simulates what the OCR and BrandResolver stages emit: detection events
|
||||
with brand names, confidence scores, sources, and frame refs. Watch
|
||||
the BrandTablePanel in the UI populate and sort in real time.
|
||||
|
||||
Usage:
|
||||
python tests/detect/manual/test_brand_table_e2e.py [--job JOB_ID] [--port PORT] [--delay SECS]
|
||||
|
||||
Opens: http://mpr.local.ar/detection/?job=<JOB_ID>
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import redis
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Simulated detection stream, pushed in this order by main().
# Each row is (brand, confidence, source, timestamp, frame_ref):
#   source    — which stage produced the match: "ocr", "local_vlm" or "cloud_llm"
#   timestamp — position in the video, logged as seconds by main()
#   frame_ref — integer frame reference attached to the detection event
DETECTIONS = [
    ("Nike", 0.97, "ocr", 2.0, 4),
    ("Nike", 0.95, "ocr", 3.5, 7),
    ("Emirates", 0.92, "ocr", 5.0, 10),
    ("Adidas", 0.89, "ocr", 7.5, 15),
    ("Coca-Cola", 0.85, "ocr", 10.0, 20),
    ("Nike", 0.94, "ocr", 12.5, 25),
    ("Emirates", 0.88, "ocr", 15.0, 30),
    ("Mastercard", 0.78, "local_vlm", 18.0, 36),
    ("Heineken", 0.72, "cloud_llm", 22.5, 45),
    ("Adidas", 0.91, "ocr", 25.0, 50),
    ("Nike", 0.96, "ocr", 27.5, 55),
    ("Emirates", 0.90, "ocr", 30.0, 60),
    ("Unknown Brand", 0.65, "cloud_llm", 33.0, 66),
    ("Coca-Cola", 0.87, "ocr", 35.5, 71),
    ("Nike", 0.93, "ocr", 38.0, 76),
]
|
||||
|
||||
|
||||
def ts() -> str:
    """Return the current UTC time as an ISO-8601 string (with a +00:00 offset)."""
    return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def push(r, key, event):
    """Append *event* to the Redis list *key* as a JSON string.

    Args:
        r: Client object exposing ``rpush(key, value)``.
        key: Name of the list to append to.
        event: JSON-serialisable dict. It is modified in place: a ``"ts"``
            entry is added when the caller did not supply one.

    Returns:
        The same ``event`` dict, now guaranteed to carry a ``"ts"`` entry.
    """
    # Only stamp events that lack a timestamp; a caller-supplied "ts" wins.
    if "ts" not in event:
        event["ts"] = ts()
    r.rpush(key, json.dumps(event))
    return event
|
||||
|
||||
|
||||
def main():
    """Replay a simulated detection run into Redis so the UI can be watched live.

    Command-line options:
        --job    job id used to build the ``detect_events:<job>`` list key
        --port   Redis port to connect to
        --delay  pause, in seconds, between pushed events

    The run clears the target list, waits for Enter, then pushes log,
    graph_update, stats_update and detection events in pipeline order.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--job", default="brand-table-test")
    parser.add_argument("--port", type=int, default=6382)
    parser.add_argument("--delay", type=float, default=0.6)
    args = parser.parse_args()

    r = redis.Redis(port=args.port, decode_responses=True)
    key = f"detect_events:{args.job}"

    # Start from an empty list so events from an earlier run are not replayed.
    r.delete(key)

    logger.info("Pushing %d detections to %s", len(DETECTIONS), key)
    logger.info("Open: http://mpr.local.ar/detection/?job=%s", args.job)
    input("\nPress Enter to start...")

    # Progressive stats — mimics real pipeline stages so the funnel chart draws lines.
    # One snapshot per early node: extract_frames, filter_scenes, detect_objects.
    STATS_PROGRESSION = [
        {"event": "stats_update",
         "frames_extracted": 120, "frames_after_scene_filter": 0,
         "regions_detected": 0, "regions_resolved_by_ocr": 0,
         "regions_escalated_to_local_vlm": 0, "regions_escalated_to_cloud_llm": 0,
         "cloud_llm_calls": 0, "processing_time_seconds": 3.2, "estimated_cloud_cost_usd": 0},
        {"event": "stats_update",
         "frames_extracted": 120, "frames_after_scene_filter": 45,
         "regions_detected": 0, "regions_resolved_by_ocr": 0,
         "regions_escalated_to_local_vlm": 0, "regions_escalated_to_cloud_llm": 0,
         "cloud_llm_calls": 0, "processing_time_seconds": 5.1, "estimated_cloud_cost_usd": 0},
        {"event": "stats_update",
         "frames_extracted": 120, "frames_after_scene_filter": 45,
         "regions_detected": 32, "regions_resolved_by_ocr": 0,
         "regions_escalated_to_local_vlm": 0, "regions_escalated_to_cloud_llm": 0,
         "cloud_llm_calls": 0, "processing_time_seconds": 12.4, "estimated_cloud_cost_usd": 0},
    ]

    # Pipeline nodes in execution order; push_graph() relies on this ordering.
    NODES = ["extract_frames", "filter_scenes", "detect_objects", "run_ocr",
             "match_brands", "escalate_vlm", "escalate_cloud", "compile_report"]

    def push_graph(r, key, active_node, status, delay):
        """Push a graph_update event, then sleep for *delay* seconds.

        Nodes before *active_node* are reported as "done", the active node
        gets *status*, and every later node is "pending".
        """
        # Look the active position up once instead of twice per node.
        active_idx = NODES.index(active_node)
        nodes = []
        for idx, node_id in enumerate(NODES):
            if idx == active_idx:
                node_status = status
            elif idx < active_idx:
                node_status = "done"
            else:
                node_status = "pending"
            nodes.append({"id": node_id, "status": node_status})
        push(r, key, {"event": "graph_update", "nodes": nodes})
        time.sleep(delay)

    push(r, key, {"event": "log", "level": "INFO", "stage": "BrandResolver",
                  "msg": f"Starting brand matching — {len(DETECTIONS)} candidates"})
    time.sleep(args.delay)

    # Simulate pipeline progression: extract → filter → detect.
    # Each node goes running → stats snapshot → done.
    for node_id, stats in zip(NODES[:3], STATS_PROGRESSION):
        push_graph(r, key, node_id, "running", args.delay)
        push(r, key, stats)
        time.sleep(args.delay)
        push_graph(r, key, node_id, "done", args.delay)

    # Detections are emitted while the OCR node is shown as running.
    push_graph(r, key, "run_ocr", "running", args.delay)

    for i, (brand, conf, source, timestamp, frame_ref) in enumerate(DETECTIONS):
        push(r, key, {"event": "detection",
                      "brand": brand,
                      "confidence": conf,
                      "source": source,
                      "timestamp": timestamp,
                      "duration": 0.5,
                      "content_type": "soccer_broadcast",
                      "frame_ref": frame_ref})

        logger.info("[%d/%d] %s conf=%.2f source=%s t=%.1fs frame=%d",
                    i + 1, len(DETECTIONS), brand, conf, source, timestamp, frame_ref)
        time.sleep(args.delay)

    push_graph(r, key, "run_ocr", "done", args.delay)

    # The remaining intermediate nodes just flip running → done.
    for node_id in ("match_brands", "escalate_vlm", "escalate_cloud"):
        push_graph(r, key, node_id, "running", args.delay)
        push_graph(r, key, node_id, "done", args.delay)

    push_graph(r, key, "compile_report", "running", args.delay)

    # Final stats, pushed while the report node is still running.
    push(r, key, {"event": "stats_update",
                  "frames_extracted": 120,
                  "frames_after_scene_filter": 45,
                  "regions_detected": 32,
                  "regions_resolved_by_ocr": 24,
                  "regions_escalated_to_local_vlm": 6,
                  "regions_escalated_to_cloud_llm": 2,
                  "cloud_llm_calls": 2,
                  "processing_time_seconds": 31.4,
                  "estimated_cloud_cost_usd": 0.0038})
    time.sleep(args.delay)

    push_graph(r, key, "compile_report", "done", args.delay)

    push(r, key, {"event": "log", "level": "INFO", "stage": "BrandResolver",
                  "msg": "Brand matching complete — "
                         f"{len(DETECTIONS)} detections, "
                         f"{len(set(d[0] for d in DETECTIONS))} unique brands"})

    logger.info("Done. Watch the BrandTablePanel — try sorting by confidence and brand.")


if __name__ == "__main__":
    main()
|
||||
135
tests/detect/manual/test_ocr_e2e.py
Normal file
135
tests/detect/manual/test_ocr_e2e.py
Normal file
@@ -0,0 +1,135 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test OCR stage end-to-end — sends real images to the inference server.
|
||||
|
||||
Creates test images with known text, sends them through the /ocr endpoint,
|
||||
verifies the text comes back. Tests both the inference server and the
|
||||
ocr_stage module's remote path.
|
||||
|
||||
Usage:
|
||||
python tests/detect/manual/test_ocr_e2e.py [--url URL]
|
||||
|
||||
Requires: inference server running (gpu/server.py)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def make_text_image(text: str, width: int = 300, height: int = 80) -> np.ndarray:
    """Create a white image with black text for OCR testing.

    Args:
        text: String to draw onto the image.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        The rendered RGB image converted to a NumPy array.
    """
    img = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 36)
    except OSError:
        # Font file not available on this machine: fall back to Pillow's
        # built-in default font. (IOError is just an alias of OSError.)
        font = ImageFont.load_default()
    draw.text((10, 15), text, fill="black", font=font)
    return np.array(img)
|
||||
|
||||
|
||||
def image_to_b64(image: np.ndarray) -> str:
    """Encode an image array as a base64 string of JPEG bytes.

    Args:
        image: Pixel array accepted by ``PIL.Image.fromarray``.

    Returns:
        ASCII base64 text of the JPEG-encoded image.
    """
    img = Image.fromarray(image)
    # JPEG can only store greyscale, RGB and CMYK data; anything else
    # (e.g. an array with an alpha channel) is converted to RGB first.
    if img.mode not in ("L", "RGB", "CMYK"):
        img = img.convert("RGB")
    buf = io.BytesIO()
    img.save(buf, "JPEG")
    return base64.b64encode(buf.getvalue()).decode()
|
||||
|
||||
|
||||
def test_health(url: str, timeout: float = 10.0):
    """Call ``GET {url}/health`` and log the reported status and device.

    Args:
        url: Base URL of the inference server, without a trailing slash.
        timeout: Seconds to wait for the server before giving up, so an
            unreachable host fails fast instead of hanging the script.

    Returns:
        True when the request succeeded.

    Raises:
        requests.RequestException: on connection failure, timeout or a
            non-2xx response.
        KeyError: if the response JSON lacks "status" or "device".
    """
    logger.info("--- Health check ---")
    resp = requests.get(f"{url}/health", timeout=timeout)
    resp.raise_for_status()
    data = resp.json()
    logger.info("Status: %s, device: %s", data["status"], data["device"])
    return True
|
||||
|
||||
|
||||
def test_ocr_endpoint(url: str, text: str, timeout: float = 60.0):
    """Render *text* into an image, POST it to ``{url}/ocr`` and log the outcome.

    Args:
        url: Base URL of the inference server, without a trailing slash.
        text: Text to render and then look for in the OCR output.
        timeout: Seconds to wait for the OCR response; without it a stalled
            server would block the script indefinitely.

    Returns:
        The list stored under "results" in the response JSON (empty if absent).

    Raises:
        requests.RequestException: on connection failure, timeout or a
            non-2xx response.
    """
    logger.info("--- OCR endpoint: '%s' ---", text)
    image = make_text_image(text)
    b64 = image_to_b64(image)

    resp = requests.post(f"{url}/ocr", json={"image": b64}, timeout=timeout)
    resp.raise_for_status()
    data = resp.json()

    results = data.get("results", [])
    logger.info("Results: %d text regions", len(results))

    # Case-insensitive substring match; every region is logged, so the loop
    # does not stop at the first hit.
    needle = text.lower()
    found = False
    for r in results:
        logger.info(" text=%r confidence=%.3f bbox=%s", r["text"], r["confidence"], r["bbox"])
        if needle in r["text"].lower():
            found = True

    if found:
        logger.info("PASS — found '%s' in OCR output", text)
    else:
        # A miss is only a warning: it is reported but does not abort the run.
        logger.warning("MISS — '%s' not found (may be font/rendering issue, check results above)", text)

    return results
|
||||
|
||||
|
||||
def test_ocr_stage_remote(url: str):
    """Test the detect/stages/ocr_stage.py remote path.

    Builds a single frame containing the text "EMIRATES", with one bounding
    box covering the whole image, and passes it to ``run_ocr`` with
    ``inference_url`` set to *url*. The outcome is logged as PASS/MISS.
    """
    logger.info("--- OCR stage (remote mode) ---")

    # Make the project package importable when the script is launched from
    # the repository root (the path is relative to the current directory).
    sys.path.insert(0, ".")
    from detect.models import BoundingBox, Frame
    from detect.profiles.base import OCRConfig
    from detect.stages.ocr_stage import run_ocr

    # Create a frame with text baked in
    image = make_text_image("EMIRATES")
    frame = Frame(sequence=0, chunk_id=0, timestamp=1.0, image=image)
    # image.shape is (rows, cols, ...), so shape[1] is the width and shape[0] the height.
    box = BoundingBox(x=0, y=0, w=image.shape[1], h=image.shape[0], confidence=0.9, label="text")
    config = OCRConfig(languages=["en"], min_confidence=0.3)

    candidates = run_ocr(
        frames=[frame],
        boxes_by_frame={0: [box]},
        config=config,
        inference_url=url,
    )

    logger.info("Candidates: %d", len(candidates))
    for c in candidates:
        logger.info(" text=%r confidence=%.3f", c.text, c.ocr_confidence)

    if candidates:
        logger.info("PASS — ocr_stage remote path returned results")
    else:
        logger.warning("MISS — no candidates returned (check inference server logs)")
|
||||
|
||||
|
||||
def main():
    """Run the manual OCR checks against the inference server.

    Command-line options:
        --url  base URL of the inference server

    Waits for Enter, then runs the health check, three endpoint checks and
    the ocr_stage remote-path check, in that order.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", default="http://mcrndeb:8000")
    args = parser.parse_args()

    # Drop any trailing slash so f"{url}/path" never yields a double slash.
    url = args.url.rstrip("/")
    logger.info("Inference server: %s", url)
    input("\nPress Enter to start...")

    test_health(url)
    test_ocr_endpoint(url, "NIKE")
    test_ocr_endpoint(url, "Coca-Cola")
    test_ocr_endpoint(url, "EMIRATES")
    test_ocr_stage_remote(url)

    logger.info("All OCR tests complete.")


if __name__ == "__main__":
    main()
|
||||
92
tests/detect/test_brand_resolver.py
Normal file
92
tests/detect/test_brand_resolver.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Tests for BrandResolver stage."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from detect.models import BoundingBox, Frame, TextCandidate
|
||||
from detect.profiles.base import BrandDictionary, ResolverConfig
|
||||
from detect.stages.brand_resolver import resolve_brands, _exact_match, _fuzzy_match
|
||||
|
||||
|
||||
# Brand dictionary fixture shared by the tests below:
# canonical brand name -> list of accepted spellings / aliases.
DICTIONARY = BrandDictionary(brands={
    "Nike": ["nike", "NIKE", "swoosh"],
    "Adidas": ["adidas", "ADIDAS"],
    "Coca-Cola": ["coca-cola", "coca cola", "coke", "COCA-COLA"],
    "Emirates": ["emirates", "fly emirates", "EMIRATES"],
})

# Resolver settings shared by the tests; the fuzzy tests use the same
# threshold value (75) when calling _fuzzy_match directly.
CONFIG = ResolverConfig(fuzzy_threshold=75)
|
||||
|
||||
|
||||
def _make_candidate(text: str, confidence: float = 0.9) -> TextCandidate:
    """Build a TextCandidate carrying *text* on a dummy 10x10 black frame.

    Args:
        text: OCR text to put on the candidate.
        confidence: Value stored as the candidate's ``ocr_confidence``.
    """
    dummy_frame = Frame(
        sequence=0,
        chunk_id=0,
        timestamp=1.0,
        image=np.zeros((10, 10, 3), dtype=np.uint8),
    )
    dummy_box = BoundingBox(x=0, y=0, w=10, h=10, confidence=0.8, label="text")
    return TextCandidate(frame=dummy_frame, bbox=dummy_box, text=text, ocr_confidence=confidence)
|
||||
|
||||
|
||||
def test_exact_match():
    """_exact_match maps known spellings to the canonical brand, else None."""
    assert _exact_match("Nike", DICTIONARY) == "Nike"
    assert _exact_match("nike", DICTIONARY) == "Nike"
    assert _exact_match("COCA-COLA", DICTIONARY) == "Coca-Cola"
    assert _exact_match("fly emirates", DICTIONARY) == "Emirates"
    assert _exact_match("unknown brand", DICTIONARY) is None
|
||||
|
||||
|
||||
def test_fuzzy_match():
    """_fuzzy_match recovers near-miss spellings and rejects unrelated text."""
    brand, score = _fuzzy_match("Nik3", DICTIONARY, threshold=75)
    assert brand == "Nike"
    assert score >= 75

    # Only the resolved brand is checked for the remaining cases.
    brand, _ = _fuzzy_match("adldas", DICTIONARY, threshold=75)
    assert brand == "Adidas"

    brand, _ = _fuzzy_match("xyzxyzxyz", DICTIONARY, threshold=75)
    assert brand is None
|
||||
|
||||
|
||||
def test_resolve_exact():
    """Exact dictionary hits are all matched, in input order, with none unresolved."""
    candidates = [_make_candidate("Nike"), _make_candidate("EMIRATES")]
    matched, unresolved = resolve_brands(candidates, DICTIONARY, CONFIG)
    assert len(matched) == 2
    assert len(unresolved) == 0
    assert matched[0].brand == "Nike"
    assert matched[1].brand == "Emirates"
|
||||
|
||||
|
||||
def test_resolve_fuzzy():
    """An OCR misread ("coIa" with a capital I) still resolves to Coca-Cola."""
    candidates = [_make_candidate("coca coIa")]  # OCR misread
    matched, _unresolved = resolve_brands(candidates, DICTIONARY, CONFIG)
    assert len(matched) == 1
    assert matched[0].brand == "Coca-Cola"
|
||||
|
||||
|
||||
def test_resolve_unresolved():
    """Text that matches no brand ends up in the unresolved list."""
    candidates = [_make_candidate("random garbage text")]
    matched, unresolved = resolve_brands(candidates, DICTIONARY, CONFIG)
    assert len(matched) == 0
    assert len(unresolved) == 1
|
||||
|
||||
|
||||
def test_resolve_mixed():
    """A mixed batch is split into matched and unresolved candidates."""
    candidates = [
        _make_candidate("Nike"),
        _make_candidate("unknown"),
        _make_candidate("adldas"),
    ]
    matched, unresolved = resolve_brands(candidates, DICTIONARY, CONFIG)
    assert len(matched) == 2  # Nike exact + Adidas fuzzy
    assert len(unresolved) == 1
|
||||
|
||||
|
||||
def test_events_emitted(monkeypatch):
    """resolve_brands emits "log" and "detection" events when given a job_id."""
    events = []
    # Capture (event type, payload) pairs instead of pushing real events.
    monkeypatch.setattr(
        "detect.emit.push_detect_event",
        lambda job_id, etype, data: events.append((etype, data)),
    )

    candidates = [_make_candidate("Nike")]
    resolve_brands(candidates, DICTIONARY, CONFIG, job_id="test-job")

    event_types = [e[0] for e in events]
    assert "log" in event_types
    assert "detection" in event_types
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Tests for the LangGraph detection pipeline."""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from detect.graph import NODES, build_graph, get_pipeline
|
||||
@@ -9,6 +11,22 @@ from detect.state import DetectState
|
||||
VIDEO = "media/out/chunks/95043d50-4df6-4ac8-bbd5-2ba873117c6e/chunk_0000.mp4"
|
||||
|
||||
|
||||
def _has_inference() -> bool:
    """Report whether an inference backend is available for the graph tests.

    Returns True when the INFERENCE_URL environment variable is set to a
    non-empty value, or when the ``ultralytics`` package can be located.
    """
    if os.environ.get("INFERENCE_URL"):
        return True

    # Locate the package without importing it: this runs at collection time
    # and only needs to know whether the package is installed.
    import importlib.util

    try:
        return importlib.util.find_spec("ultralytics") is not None
    except (ImportError, ValueError):
        # find_spec can raise for a broken or half-initialised module entry.
        return False
|
||||
|
||||
|
||||
# Skip marker for tests that need a working inference backend.
# The condition is evaluated once, when this module is imported.
requires_inference = pytest.mark.skipif(
    not _has_inference(),
    reason="Needs INFERENCE_URL or ultralytics installed",
)
|
||||
|
||||
|
||||
def test_graph_compiles():
|
||||
pipeline = get_pipeline()
|
||||
assert pipeline is not None
|
||||
@@ -20,6 +38,7 @@ def test_graph_has_all_nodes():
|
||||
assert node in graph.nodes
|
||||
|
||||
|
||||
@requires_inference
|
||||
def test_graph_runs_end_to_end(monkeypatch):
|
||||
"""Run the full graph with mocked event emission."""
|
||||
events = []
|
||||
@@ -52,6 +71,7 @@ def test_graph_runs_end_to_end(monkeypatch):
|
||||
assert len(complete_events) == 1
|
||||
|
||||
|
||||
@requires_inference
|
||||
def test_graph_node_transitions(monkeypatch):
|
||||
"""Verify each node emits running → done transitions."""
|
||||
events = []
|
||||
|
||||
141
tests/detect/test_ocr_stage.py
Normal file
141
tests/detect/test_ocr_stage.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Tests for OCR stage."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from detect.models import BoundingBox, Frame
|
||||
from detect.profiles.base import OCRConfig
|
||||
from detect.stages.ocr_stage import _crop_region, _parse_ocr_raw, run_ocr
|
||||
|
||||
|
||||
def _has_paddleocr() -> bool:
    """Return True when the ``paddleocr`` package can be located."""
    # Locate the package without importing it: this runs at collection time
    # and only needs to know whether the package is installed.
    import importlib.util

    try:
        return importlib.util.find_spec("paddleocr") is not None
    except (ImportError, ValueError):
        # find_spec can raise for a broken or half-initialised module entry.
        return False
|
||||
|
||||
|
||||
def _make_frame(seq: int = 0, w: int = 100, h: int = 80) -> Frame:
    """Build a Frame holding an all-black (h, w, 3) uint8 image.

    The frame's timestamp is simply ``float(seq)`` and chunk_id is 0.
    """
    image = np.zeros((h, w, 3), dtype=np.uint8)
    return Frame(sequence=seq, chunk_id=0, timestamp=float(seq), image=image)
|
||||
|
||||
|
||||
def _make_box(x=10, y=10, w=30, h=20) -> BoundingBox:
    """Build a "text" BoundingBox with confidence 0.9 at the given geometry."""
    return BoundingBox(x=x, y=y, w=w, h=h, confidence=0.9, label="text")
|
||||
|
||||
|
||||
# --- _crop_region ---
|
||||
|
||||
def test_crop_basic():
    """A box fully inside the frame yields a crop of shape (h, w, 3)."""
    frame = _make_frame()
    box = _make_box(x=10, y=20, w=30, h=15)
    crop = _crop_region(frame, box)
    assert crop.shape == (15, 30, 3)
|
||||
|
||||
|
||||
def test_crop_clamps_to_frame():
    """A box running past the right/bottom edges is clamped to the frame."""
    frame = _make_frame(w=50, h=40)
    box = _make_box(x=30, y=25, w=100, h=100)
    crop = _crop_region(frame, box)
    assert crop.shape[0] == 15  # 40 - 25
    assert crop.shape[1] == 20  # 50 - 30
|
||||
|
||||
|
||||
def test_crop_negative_origin():
    """A box starting at negative coordinates is clamped to start at 0."""
    frame = _make_frame()
    box = _make_box(x=-5, y=-5, w=20, h=20)
    crop = _crop_region(frame, box)
    assert crop.shape[0] == 15  # min(80, -5+20) - 0
    assert crop.shape[1] == 15  # min(100, -5+20) - 0
|
||||
|
||||
|
||||
# --- _parse_ocr_raw ---
|
||||
|
||||
def test_parse_nested_list_layout():
    """Nested-list OCR output is parsed and entries below min_confidence dropped."""
    # Each entry is [four corner points, [text, score]].
    raw = [[
        [[[0, 0], [10, 0], [10, 10], [0, 10]], ["hello", 0.95]],
        [[[0, 0], [10, 0], [10, 10], [0, 10]], ["low", 0.2]],
    ]]
    results = _parse_ocr_raw(raw, min_confidence=0.5)
    assert len(results) == 1
    assert results[0]["text"] == "hello"
    assert results[0]["confidence"] == 0.95
|
||||
|
||||
|
||||
def test_parse_dict_layout():
    """Dict-style OCR output (rec_texts / rec_scores) is parsed and filtered."""
    raw = [{"rec_texts": ["brand", "noise"], "rec_scores": [0.9, 0.3]}]
    results = _parse_ocr_raw(raw, min_confidence=0.5)
    assert len(results) == 1
    assert results[0]["text"] == "brand"
|
||||
|
||||
|
||||
def test_parse_empty():
    """None, an empty list and a list holding an empty page all parse to []."""
    assert _parse_ocr_raw(None, 0.5) == []
    assert _parse_ocr_raw([], 0.5) == []
    assert _parse_ocr_raw([[]], 0.5) == []
|
||||
|
||||
|
||||
# --- run_ocr (remote, mocked) ---
|
||||
|
||||
def test_run_ocr_remote(monkeypatch):
    """run_ocr with an inference_url returns what the (faked) client reports."""
    events = []
    # Capture emitted events instead of pushing them anywhere.
    monkeypatch.setattr(
        "detect.emit.push_detect_event",
        lambda job_id, etype, data: events.append((etype, data)),
    )

    class FakeResult:
        """Stand-in for one OCR result returned by the inference client."""

        def __init__(self, text, confidence):
            self.text = text
            self.confidence = confidence

    class FakeClient:
        """Stand-in inference client that always recognises "NIKE"."""

        def __init__(self, base_url):
            pass

        def ocr(self, image, languages):
            return [FakeResult("NIKE", 0.92)]

    # Patch both places the client class may be looked up:
    # - the name on the ocr_stage module; raising=False tolerates the
    #   attribute not existing there
    # - the class on detect.inference (presumably the import path used
    #   inside run_ocr — confirm against ocr_stage)
    monkeypatch.setattr("detect.stages.ocr_stage.InferenceClient", FakeClient,
                        raising=False)
    monkeypatch.setattr("detect.inference.InferenceClient", FakeClient)

    frame = _make_frame()
    box = _make_box()
    config = OCRConfig(languages=["en"], min_confidence=0.5)

    candidates = run_ocr(
        frames=[frame],
        boxes_by_frame={0: [box]},
        config=config,
        inference_url="http://fake:8000",
        job_id="test",
    )

    assert len(candidates) == 1
    assert candidates[0].text == "NIKE"
    assert candidates[0].ocr_confidence == 0.92
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not _has_paddleocr(),
    reason="Needs paddleocr installed (GPU box)",
)
def test_run_ocr_skips_empty_crop(monkeypatch):
    """With no inference_url, a box lying outside the frame yields no candidates."""
    events = []
    # Capture emitted events instead of pushing them anywhere.
    monkeypatch.setattr(
        "detect.emit.push_detect_event",
        lambda job_id, etype, data: events.append((etype, data)),
    )

    frame = _make_frame(w=10, h=10)
    box = _make_box(x=100, y=100, w=5, h=5)  # outside frame → empty crop
    config = OCRConfig(languages=["en"], min_confidence=0.5)

    candidates = run_ocr(
        frames=[frame],
        boxes_by_frame={0: [box]},
        config=config,
        inference_url=None,
        job_id="test",
    )

    assert len(candidates) == 0
|
||||
@@ -22,7 +22,7 @@ def test_soccer_frame_extraction_config():
|
||||
def test_soccer_detection_config():
|
||||
cfg = SoccerBroadcastProfile().detection_config()
|
||||
assert 0 < cfg.confidence_threshold < 1
|
||||
assert len(cfg.target_classes) > 0
|
||||
assert isinstance(cfg.target_classes, list)
|
||||
|
||||
|
||||
def test_soccer_brand_dictionary_non_empty():
|
||||
|
||||
Reference in New Issue
Block a user