#!/usr/bin/env python3
"""
Test local VLM (moondream2) via the inference server.

Creates test images with brand text/logos, sends them to the /vlm endpoint,
verifies moondream2 can identify the brand.

Usage:
    python tests/detect/manual/test_vlm_e2e.py [--url URL]

Requires: inference server running with moondream2 loaded (gpu/server.py)
"""
# Standard library
import argparse
import base64
import io
import logging
import sys  # NOTE(review): appears unused in this file — confirm before removing

# Third party
import numpy as np
import requests
from PIL import Image, ImageDraw, ImageFont

# Module-level logger; format keeps level/name columns aligned for readability.
logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s")
logger = logging.getLogger(__name__)
def make_brand_image(text: str, width: int = 300, height: int = 100) -> np.ndarray:
    """Render *text* in black on a white canvas and return it as an RGB array.

    Args:
        text: Brand name to draw (e.g. "NIKE").
        width: Canvas width in pixels.
        height: Canvas height in pixels.

    Returns:
        uint8 array of shape (height, width, 3).
    """
    img = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 42
        )
    except OSError:  # IOError is an alias of OSError in Python 3, so one type suffices
        # Fall back to PIL's built-in bitmap font when DejaVu is not installed.
        font = ImageFont.load_default()
    draw.text((10, 20), text, fill="black", font=font)
    return np.array(img)
def image_to_b64(image: np.ndarray) -> str:
    """Encode an RGB image array as a base64 string of its JPEG bytes."""
    buffer = io.BytesIO()
    Image.fromarray(image).save(buffer, "JPEG")
    encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode()
def test_health(url: str):
    """Hit the /health endpoint and log status, device, and loaded models."""
    logger.info("--- Health check ---")
    resp = requests.get(f"{url}/health")
    resp.raise_for_status()
    payload = resp.json()
    logger.info(
        "Status: %s, device: %s, models: %s",
        payload["status"],
        payload["device"],
        payload.get("loaded_models", []),
    )
def test_vlm(url: str, text: str, prompt: str):
    """Render *text* as an image, POST it to /vlm, and log whether the
    returned brand contains the expected text (case-insensitive).

    Returns the decoded JSON response from the server.
    """
    logger.info("--- VLM: image='%s' ---", text)
    b64 = image_to_b64(make_brand_image(text))

    resp = requests.post(f"{url}/vlm", json={"image": b64, "prompt": prompt})
    resp.raise_for_status()
    result = resp.json()

    logger.info(" brand: %s", result["brand"])
    logger.info(" confidence: %.2f", result["confidence"])
    logger.info(" reasoning: %s", result["reasoning"])

    # Soft check: a miss is logged as a warning, not a failure, since VLM
    # output is non-deterministic.
    matched = text.lower() in result["brand"].lower()
    if matched:
        logger.info(" PASS — matched")
    else:
        logger.warning(" MISS — expected '%s' in response", text)

    return result
def main():
    """Parse CLI args, run the health check, then a battery of VLM brand tests."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", default="http://mcrndeb:8000")
    args = parser.parse_args()

    url = args.url.rstrip("/")
    logger.info("Inference server: %s", url)
    # Manual gate so the operator can confirm the server is up first.
    input("\nPress Enter to start...")

    test_health(url)

    prompt = (
        "Identify the brand or sponsor visible in this image from a soccer broadcast. "
        "Respond with: brand, confidence (0-1), reasoning."
    )

    # Same brands, same order as before — just driven by a loop.
    for brand in ("NIKE", "EMIRATES", "Coca-Cola", "adidas"):
        test_vlm(url, brand, prompt)

    logger.info("All VLM tests complete.")
# Script entry point: run the manual e2e test when executed directly.
if __name__ == "__main__":
    main()