phase 9
This commit is contained in:
100
tests/detect/manual/test_vlm_e2e.py
Normal file
100
tests/detect/manual/test_vlm_e2e.py
Normal file
@@ -0,0 +1,100 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
Test local VLM (moondream2) via the inference server.

Creates test images containing brand names rendered as text, sends them to
the /vlm endpoint, and verifies that moondream2 can identify the brand.

Usage:
    python tests/detect/manual/test_vlm_e2e.py [--url URL]

Requires: inference server running with moondream2 loaded (gpu/server.py)
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
# Script-wide logging: INFO and above, rendered as "<LEVEL> <logger name> — <message>".
logging.basicConfig(
    format="%(levelname)-7s %(name)s — %(message)s",
    level=logging.INFO,
)
# Module-level logger shared by every helper in this script.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def make_brand_image(
    text: str,
    width: int = 300,
    height: int = 100,
    font_path: str = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
    font_size: int = 42,
) -> np.ndarray:
    """Render ``text`` in black on a white RGB canvas and return it as an array.

    Args:
        text: Brand name to draw onto the image.
        width: Canvas width in pixels.
        height: Canvas height in pixels.
        font_path: TrueType font file to try first. Defaults to the DejaVu
            Sans Bold location used on Debian/Ubuntu systems.
        font_size: Point size used when ``font_path`` can be loaded.

    Returns:
        The rendered image converted with ``np.array`` (for an RGB PIL image
        this is a ``(height, width, 3)`` uint8 array).
    """
    img = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype(font_path, font_size)
    except OSError:
        # IOError is just an alias of OSError on Python 3, so one name suffices.
        # Make the fallback visible: PIL's built-in font may render at a very
        # different size than requested, which can explain later MISS results.
        logger.warning("Font %s not available; falling back to PIL default font", font_path)
        font = ImageFont.load_default()
    # Fixed offset keeps the text away from the top-left edge of the canvas.
    draw.text((10, 20), text, fill="black", font=font)
    return np.array(img)
|
||||
|
||||
|
||||
def image_to_b64(image: np.ndarray) -> str:
    """Encode an image array as a base64 string of its JPEG representation.

    Args:
        image: Pixel array accepted by ``PIL.Image.fromarray``.

    Returns:
        The JPEG bytes of ``image``, base64-encoded and decoded to ``str``.
    """
    jpeg_buffer = io.BytesIO()
    # Serialize to JPEG in memory; nothing is written to disk.
    Image.fromarray(image).save(jpeg_buffer, "JPEG")
    encoded = base64.b64encode(jpeg_buffer.getvalue())
    return encoded.decode()
|
||||
|
||||
|
||||
def test_health(url: str, timeout: float = 10.0) -> None:
    """Query the server's ``/health`` endpoint and log what it reports.

    Args:
        url: Base URL of the inference server, without a trailing slash.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` would block indefinitely on an unreachable
            or stalled server.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            non-2xx response (via ``raise_for_status``).
        KeyError: If the response JSON lacks ``status`` or ``device``.
    """
    logger.info("--- Health check ---")
    resp = requests.get(f"{url}/health", timeout=timeout)
    resp.raise_for_status()
    data = resp.json()
    # "loaded_models" is treated as optional; "status" and "device" are required.
    logger.info("Status: %s, device: %s, models: %s", data["status"], data["device"], data.get("loaded_models", []))
|
||||
|
||||
|
||||
def test_vlm(url: str, text: str, prompt: str, timeout: float = 120.0):
    """Send a generated brand image to ``/vlm`` and log whether it was recognized.

    An image containing ``text`` is rendered, JPEG/base64-encoded and posted
    together with ``prompt``. The check passes when ``text`` appears
    (case-insensitively) in the ``brand`` field of the response; a mismatch is
    only logged as a warning, it does not raise.

    Args:
        url: Base URL of the inference server, without a trailing slash.
        text: Brand name to render into the image and expect in the answer.
        prompt: Instruction sent to the model alongside the image.
        timeout: Seconds to wait for the response. Without a timeout
            ``requests`` would block indefinitely if the server stalls; the
            default is generous because model inference can be slow.

    Returns:
        The parsed JSON response from the server.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            non-2xx response (via ``raise_for_status``).
        KeyError: If the response lacks ``brand``, ``confidence`` or
            ``reasoning``.
    """
    logger.info("--- VLM: image='%s' ---", text)
    image = make_brand_image(text)
    b64 = image_to_b64(image)

    resp = requests.post(f"{url}/vlm", json={"image": b64, "prompt": prompt}, timeout=timeout)
    resp.raise_for_status()
    data = resp.json()

    logger.info(" brand: %s", data["brand"])
    logger.info(" confidence: %.2f", data["confidence"])
    logger.info(" reasoning: %s", data["reasoning"])

    # A null or non-string brand should count as a MISS rather than crash the
    # run with AttributeError on ``.lower()``.
    brand = str(data["brand"] or "")
    if text.lower() in brand.lower():
        logger.info(" PASS — matched")
    else:
        logger.warning(" MISS — expected '%s' in response", text)

    return data
|
||||
|
||||
|
||||
def main():
    """Run the manual end-to-end check: health probe, then one VLM query per brand.

    Reads the server base URL from ``--url``, waits for the operator to press
    Enter, and then exercises the server with a fixed list of brand names.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--url", default="http://mcrndeb:8000")
    options = cli.parse_args()

    # Normalize so endpoint paths can be appended with a single slash.
    url = options.url.rstrip("/")
    logger.info("Inference server: %s", url)
    input("\nPress Enter to start...")

    test_health(url)

    prompt = " ".join(
        (
            "Identify the brand or sponsor visible in this image from a soccer broadcast.",
            "Respond with: brand, confidence (0-1), reasoning.",
        )
    )

    # Same prompt for every brand; only the rendered text changes.
    for brand in ("NIKE", "EMIRATES", "Coca-Cola", "adidas"):
        test_vlm(url, brand, prompt)

    logger.info("All VLM tests complete.")
|
||||
|
||||
|
||||
# Run the manual checks only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user