#!/usr/bin/env python3
"""
Test the OCR stage end-to-end by sending real images to the inference server.

Creates test images with known text, sends them through the /ocr endpoint,
and verifies that the text comes back. Tests both the inference server itself
and the remote path of the ocr_stage module.

Usage:
    python tests/detect/manual/test_ocr_e2e.py [--url URL]

Requires: a running inference server (gpu/server.py).
"""

import argparse
|
|
import base64
|
|
import io
|
|
import json
|
|
import logging
|
|
import sys
|
|
|
|
import numpy as np
|
|
import requests
|
|
from PIL import Image, ImageDraw, ImageFont
|
|
|
|
# Script-wide logging: INFO and above, level name left-aligned in 7 columns
# so the logger name and message start at the same column on every line.
_LOG_FORMAT = "%(levelname)-7s %(name)s — %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

logger = logging.getLogger(__name__)


def _default_font(size: int):
    """Return Pillow's built-in font, scaled to *size* where Pillow allows it."""
    try:
        # Pillow >= 10.1 accepts ``size`` and can return a scalable font, which
        # keeps the fallback text roughly as large as the TrueType path.
        return ImageFont.load_default(size=size)
    except TypeError:
        # Older Pillow: only the fixed-size bitmap default font is available.
        return ImageFont.load_default()


def make_text_image(
    text: str,
    width: int = 300,
    height: int = 80,
    *,
    font_path: str = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
    font_size: int = 36,
) -> np.ndarray:
    """Create a white RGB image with *text* drawn in black, for OCR testing.

    Args:
        text: The string to render.
        width: Image width in pixels.
        height: Image height in pixels.
        font_path: TrueType font to render with. If it cannot be loaded,
            Pillow's built-in default font is used instead.
        font_size: Font size passed to Pillow for both the TrueType font and
            (where supported) the built-in fallback.

    Returns:
        The rendered image as a numpy array built from an ``"RGB"`` PIL image,
        i.e. shape ``(height, width, 3)``.
    """
    img = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype(font_path, font_size)
    except OSError:  # IOError is just an alias of OSError on Python 3
        font = _default_font(font_size)
    # Fixed offset from the top-left corner; the text is not centred.
    draw.text((10, 15), text, fill="black", font=font)
    return np.array(img)


def image_to_b64(image: np.ndarray) -> str:
    """Encode an image array as a base64 string of JPEG bytes.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``. Arrays that map to
            a PIL mode other than ``"RGB"`` or ``"L"`` (e.g. 4-channel
            ``"RGBA"``) are converted to RGB first, because JPEG cannot
            store them.

    Returns:
        The JPEG bytes, base64-encoded and decoded to ``str``.
    """
    img = Image.fromarray(image)
    # JPEG has no alpha/palette support; Pillow raises when saving e.g. RGBA.
    if img.mode not in ("RGB", "L"):
        img = img.convert("RGB")
    buf = io.BytesIO()
    img.save(buf, format="JPEG")
    # base64 output is pure ASCII, so decoding to text is lossless.
    return base64.b64encode(buf.getvalue()).decode("ascii")


def test_health(url: str, *, timeout: float = 10.0):
    """Query ``{url}/health`` and log the status and device it reports.

    Args:
        url: Base URL of the inference server, without a trailing slash.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        ``True`` once the server has answered with a 2xx status.

    Raises:
        requests.RequestException: on connection failure, timeout, or a
            non-2xx response (raised by ``raise_for_status``).
    """
    logger.info("--- Health check ---")
    # Without a timeout, requests waits indefinitely on an unresponsive host.
    resp = requests.get(f"{url}/health", timeout=timeout)
    resp.raise_for_status()
    data = resp.json()
    # .get(): a payload missing one of these fields is logged as None instead
    # of aborting the whole run with a KeyError.
    logger.info("Status: %s, device: %s", data.get("status"), data.get("device"))
    return True


def test_ocr_endpoint(url: str, text: str, *, timeout: float = 120.0):
    """Send a rendered image of *text* to ``{url}/ocr`` and check it is read back.

    Each returned region is logged. The result is reported as PASS when
    *text* (case-insensitive) is a substring of at least one region's text,
    otherwise as MISS (a warning, not an exception).

    Args:
        url: Base URL of the inference server, without a trailing slash.
        text: Text to render and expect back.
        timeout: Seconds to wait for the OCR response; increase it if the
            first requests against a cold server are slow.

    Returns:
        The ``"results"`` list from the JSON response (empty if the key is
        absent). Each entry is read for ``text``, ``confidence`` and ``bbox``.

    Raises:
        requests.RequestException: on connection failure, timeout, or a
            non-2xx response.
    """
    logger.info("--- OCR endpoint: '%s' ---", text)
    b64 = image_to_b64(make_text_image(text))

    # An unbounded wait would hang the script if the server stalls.
    resp = requests.post(f"{url}/ocr", json={"image": b64}, timeout=timeout)
    resp.raise_for_status()
    results = resp.json().get("results", [])
    logger.info("Results: %d text regions", len(results))

    for r in results:
        logger.info(" text=%r confidence=%.3f bbox=%s", r["text"], r["confidence"], r["bbox"])

    needle = text.lower()
    if any(needle in r["text"].lower() for r in results):
        logger.info("PASS — found '%s' in OCR output", text)
    else:
        logger.warning("MISS — '%s' not found (may be font/rendering issue, check results above)", text)

    return results


def _ensure_project_on_path() -> None:
    """Put the directory containing the project's ``detect`` package on sys.path."""
    from pathlib import Path

    script_dir = Path(__file__).resolve().parent
    # Match on a file only the real package has: per the Usage path this
    # script lives under ``tests/detect``, a directory that is also named
    # ``detect`` and must not be mistaken for (or shadow) the real package.
    marker = Path("detect") / "stages" / "ocr_stage.py"
    root = next(
        (str(p) for p in (script_dir, *script_dir.parents) if (p / marker).is_file()),
        ".",  # not found: fall back to the current working directory
    )
    # Avoid stacking duplicate entries when called more than once.
    if root not in sys.path:
        sys.path.insert(0, root)


def test_ocr_stage_remote(url: str, text: str = "EMIRATES"):
    """Exercise the remote path of ``detect/stages/ocr_stage.py``.

    Renders *text* into a single ``Frame`` with one ``BoundingBox`` covering
    the whole image, and calls ``run_ocr`` with ``inference_url=url``.

    Reports the outcome as follows:

    - PASS: at least one candidate's text contains *text* (case-insensitive).
    - MISS: no candidates were returned, or none of them contain *text*.

    Args:
        url: Base URL of the inference server, without a trailing slash.
        text: Text to render and expect back.

    Returns:
        The candidates returned by ``run_ocr``.
    """
    logger.info("--- OCR stage (remote mode) ---")

    # Project imports are deferred so the HTTP-only checks run without them.
    _ensure_project_on_path()
    from detect.models import BoundingBox, Frame
    from detect.profiles.base import OCRConfig
    from detect.stages.ocr_stage import run_ocr

    image = make_text_image(text)
    height, width = image.shape[:2]
    frame = Frame(sequence=0, chunk_id=0, timestamp=1.0, image=image)
    box = BoundingBox(x=0, y=0, w=width, h=height, confidence=0.9, label="text")
    config = OCRConfig(languages=["en"], min_confidence=0.3)

    candidates = run_ocr(
        frames=[frame],
        # Presumably keyed by Frame.sequence (0 above) — confirm against run_ocr.
        boxes_by_frame={0: [box]},
        config=config,
        inference_url=url,
    )

    logger.info("Candidates: %d", len(candidates))
    for c in candidates:
        logger.info(" text=%r confidence=%.3f", c.text, c.ocr_confidence)

    needle = text.lower()
    if not candidates:
        logger.warning("MISS — no candidates returned (check inference server logs)")
    elif any(needle in c.text.lower() for c in candidates):
        logger.info("PASS — ocr_stage remote path returned '%s'", text)
    else:
        logger.warning("MISS — candidates returned but none contain '%s' (check results above)", text)

    return candidates


def main():
    """Parse CLI options, then run every OCR check against one inference server.

    Unless ``--no-prompt`` is given, waits for Enter before starting. When
    stdin is closed or not interactive, the prompt is skipped instead of
    crashing with EOFError. Checks run in a fixed order; a request failure
    propagates and stops the run.
    """
    parser = argparse.ArgumentParser(description="End-to-end OCR checks against the inference server.")
    parser.add_argument(
        "--url",
        default="http://mcrndeb:8000",
        help="Inference server base URL (default: %(default)s)",
    )
    parser.add_argument(
        "--no-prompt",
        action="store_true",
        help="Start immediately instead of waiting for Enter",
    )
    args = parser.parse_args()

    url = args.url.rstrip("/")
    logger.info("Inference server: %s", url)
    if not args.no_prompt:
        try:
            input("\nPress Enter to start...")
        except EOFError:
            # No interactive stdin (CI, piped input): proceed without waiting.
            pass

    test_health(url)
    test_ocr_endpoint(url, "NIKE")
    test_ocr_endpoint(url, "Coca-Cola")
    test_ocr_endpoint(url, "EMIRATES")
    test_ocr_stage_remote(url)

    logger.info("All OCR tests complete.")


if __name__ == "__main__":
    # Run the checks only when executed as a script; importing this module
    # (e.g. to reuse the helpers) does not call main().
    main()