Files
mediaproc/tests/detect/manual/test_vlm_e2e.py
2026-03-26 02:54:56 -03:00

101 lines
2.8 KiB
Python

#!/usr/bin/env python3
"""
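Test local VLM (moondream2) via the inference server.
Creates test images with brand names rendered as text, sends them to the /vlm
endpoint, and logs whether the brand returned by moondream2 contains the expected name
(a mismatch is logged as a warning; it does not fail the script).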
Usage:
python tests/detect/manual/test_vlm_e2e.py [--url URL]
Requires: inference server running with moondream2 loaded (gpu/server.py)
"""
import argparse
import base64
import io
import logging
import sys
import numpy as np
import requests
from PIL import Image, ImageDraw, ImageFont
# Logging setup for this manual script: INFO and above, with the level name
# left-aligned in a 7-character column followed by the logger name.
# NOTE(review): there is no separator between %(name)s and %(message)s, so the
# logger name and the message run together in the output — confirm intended.
_LOG_FORMAT = "%(levelname)-7s %(name)s%(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
def make_brand_image(text: str, width: int = 300, height: int = 100) -> np.ndarray:
    """Render ``text`` in black on a white RGB canvas.

    Args:
        text: Brand name to draw onto the image.
        width: Canvas width in pixels.
        height: Canvas height in pixels.

    Returns:
        The rendered image converted to a NumPy array.
    """
    canvas = Image.new("RGB", (width, height), "white")
    try:
        brand_font = ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 42
        )
    except OSError:
        # IOError is an alias of OSError in Python 3, so this catches the same
        # failures. Fall back to Pillow's built-in font when the DejaVu file
        # cannot be opened.
        brand_font = ImageFont.load_default()
    pen = ImageDraw.Draw(canvas)
    pen.text((10, 20), text, fill="black", font=brand_font)
    return np.array(canvas)
def image_to_b64(image: np.ndarray) -> str:
    """Encode an image array as a base64 string of its JPEG bytes.

    Args:
        image: Image array accepted by ``PIL.Image.fromarray``.

    Returns:
        The base64 text of the in-memory JPEG encoding of ``image``.
    """
    jpeg_buffer = io.BytesIO()
    # Encode to JPEG entirely in memory; nothing is written to disk.
    Image.fromarray(image).save(jpeg_buffer, format="JPEG")
    encoded = base64.b64encode(jpeg_buffer.getvalue())
    return encoded.decode()
def test_health(url: str, timeout: float = 10.0) -> None:
    """Query the server's ``/health`` endpoint and log what it reports.

    Args:
        url: Base URL of the inference server, without a trailing slash.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` waits indefinitely, so an unreachable or
            stalled server would hang the script.

    Raises:
        requests.Timeout: If the server does not answer within ``timeout``.
        requests.HTTPError: If the server answers with an error status.
        KeyError: If the response lacks the ``status`` or ``device`` field.
    """
    logger.info("--- Health check ---")
    resp = requests.get(f"{url}/health", timeout=timeout)
    resp.raise_for_status()
    data = resp.json()
    logger.info(
        "Status: %s, device: %s, models: %s",
        data["status"],
        data["device"],
        # "loaded_models" is treated as optional; log an empty list if absent.
        data.get("loaded_models", []),
    )
def test_vlm(url: str, text: str, prompt: str, timeout: float = 120.0):
    """Send a generated brand image to ``/vlm`` and log the model's answer.

    An image containing ``text`` is rendered, JPEG/base64-encoded and posted
    together with ``prompt``. The response's ``brand``, ``confidence`` and
    ``reasoning`` fields are logged, followed by PASS when ``text`` occurs
    (case-insensitively) in the returned brand, or a MISS warning otherwise.
    A miss is only logged; it does not raise.

    Args:
        url: Base URL of the inference server, without a trailing slash.
        text: Brand name to render into the image and expect in the answer.
        prompt: Instruction sent to the model alongside the image.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` waits indefinitely, so a stalled inference
            call would hang the script. The default is generous to leave room
            for model inference.

    Returns:
        The decoded JSON body of the ``/vlm`` response.

    Raises:
        requests.Timeout: If the server does not answer within ``timeout``.
        requests.HTTPError: If the server answers with an error status.
        KeyError: If the response lacks one of the expected fields.
    """
    logger.info("--- VLM: image='%s' ---", text)
    image = make_brand_image(text)
    b64 = image_to_b64(image)
    resp = requests.post(
        f"{url}/vlm",
        json={"image": b64, "prompt": prompt},
        timeout=timeout,
    )
    resp.raise_for_status()
    data = resp.json()
    logger.info(" brand: %s", data["brand"])
    logger.info(" confidence: %.2f", data["confidence"])
    logger.info(" reasoning: %s", data["reasoning"])
    # Substring match, case-insensitive: the model may wrap the brand name in
    # extra words, so an exact comparison would be too strict.
    if text.lower() in data["brand"].lower():
        logger.info(" PASS — matched")
    else:
        logger.warning(" MISS — expected '%s' in response", text)
    return data
def main():
    """Run the health check and the brand-recognition checks interactively.

    Reads the server base URL from ``--url``, waits for the user to press
    Enter, checks ``/health``, then sends one generated image per brand to
    the VLM endpoint using a shared prompt.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--url", default="http://mcrndeb:8000")
    # Strip trailing slashes so endpoint paths can be appended with a single "/".
    base_url = cli.parse_args().url.rstrip("/")
    logger.info("Inference server: %s", base_url)
    input("\nPress Enter to start...")
    test_health(base_url)
    brand_prompt = (
        "Identify the brand or sponsor visible in this image from a soccer broadcast. "
        "Respond with: brand, confidence (0-1), reasoning."
    )
    for brand in ("NIKE", "EMIRATES", "Coca-Cola", "adidas"):
        test_vlm(base_url, brand, brand_prompt)
    logger.info("All VLM tests complete.")
# Run the checks only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()