Files
mediaproc/tests/detect/manual/test_cloud_provider.py
2026-03-26 02:54:56 -03:00

108 lines
3.6 KiB
Python

#!/usr/bin/env python3
"""
Manual smoke test for the cloud LLM provider, using real API calls.

Renders a few synthetic brand-name images, sends each one to the configured
cloud provider, and logs whether the brand name appears in the answer.
Set your provider env vars before running; values found in ctrl/.env are
used as defaults when that file exists. The script waits for Enter before
making any API call.

Usage:
    # Groq (default)
    CLOUD_LLM_PROVIDER=groq GROQ_API_KEY=gsk_... python tests/detect/manual/test_cloud_provider.py

    # Gemini
    CLOUD_LLM_PROVIDER=gemini GEMINI_API_KEY=AIza... python tests/detect/manual/test_cloud_provider.py

    # Claude
    CLOUD_LLM_PROVIDER=claude ANTHROPIC_API_KEY=sk-ant-... python tests/detect/manual/test_cloud_provider.py

    # OpenAI-compatible
    CLOUD_LLM_PROVIDER=openai OPENAI_API_KEY=sk-... python tests/detect/manual/test_cloud_provider.py
"""
import base64
import io
import logging
import os
import sys
from pathlib import Path
# Load .env from ctrl/ (same file docker-compose uses). Variables that are
# already set in the real environment win, because setdefault never overwrites.
env_file = Path(__file__).resolve().parents[3] / "ctrl" / ".env"
if env_file.exists():
    for line in env_file.read_text(encoding="utf-8").splitlines():
        line = line.strip()
        # Skip blanks, comments, and anything that is not a KEY=VALUE pair.
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        key, value = key.strip(), value.strip()
        if not key:
            # "=value" has no variable name; os.environ rejects an empty key.
            continue
        # docker-compose drops one pair of matching quotes around a value;
        # do the same so KEY="secret" does not keep the literal quotes.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        os.environ.setdefault(key, value)
import numpy as np
from PIL import Image, ImageDraw, ImageFont
# Put the current working directory first on the import path so the `detect`
# package resolves when the script is run from the directory that contains it.
sys.path.insert(0, ".")

# Level (padded to 7 chars), logger name, then message. The ": " after
# %(name)s keeps the logger name from running straight into the message text.
LOG_FORMAT = "%(levelname)-7s %(name)s: %(message)s"
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
logger = logging.getLogger(__name__)
def make_brand_image(text: str, width: int = 300, height: int = 100, font_size: int = 42) -> str:
    """Render *text* in black on a white canvas and return it as base64 JPEG.

    Args:
        text: Brand name to draw.
        width: Canvas width in pixels.
        height: Canvas height in pixels.
        font_size: Requested size of the rendered text.

    Returns:
        The JPEG bytes encoded as a base64 string (no data-URI prefix).
    """
    img = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", font_size)
    except OSError:  # IOError is an alias of OSError, so one name is enough
        # DejaVu is not installed at that path. Newer Pillow releases can scale
        # the built-in font; older ones reject the keyword, so fall back to the
        # fixed-size default rather than failing.
        try:
            font = ImageFont.load_default(size=font_size)
        except TypeError:
            font = ImageFont.load_default()
    draw.text((10, 20), text, fill="black", font=font)
    buf = io.BytesIO()
    img.save(buf, "JPEG")
    return base64.b64encode(buf.getvalue()).decode()
def main():
    """Send one synthetic image per brand to the configured provider and log results.

    For each brand name a test image is rendered and passed to the provider
    together with a fixed prompt. A response that does not contain the brand
    text is logged as MISS (warning only); a call that raises is logged as
    FAIL.

    Exits with status 1 when no API key is configured or when at least one
    provider call raised, so shell callers can detect a broken setup.
    """
    # Imported lazily so the module-level sys.path and .env setup run first.
    from detect.providers import get_provider, has_api_key, PROVIDERS

    provider_name = os.environ.get("CLOUD_LLM_PROVIDER", "groq")
    logger.info("Provider: %s", provider_name)
    logger.info("Available providers: %s", list(PROVIDERS.keys()))

    if not has_api_key():
        logger.error("No API key set for provider '%s'", provider_name)
        logger.error("Set the appropriate env var (see usage in docstring)")
        sys.exit(1)

    provider = get_provider()
    logger.info("Model: %s", provider.model)
    logger.info("Available models: %s", list(provider.models.keys()))

    # Give the operator a chance to abort before real API calls are made.
    input("\nPress Enter to start...")

    prompt = (
        "Identify the brand or sponsor visible in this image from a soccer broadcast. "
        "Respond with: brand, confidence (0-1), reasoning."
    )
    test_cases = ["NIKE", "EMIRATES", "Coca-Cola", "adidas"]
    failures = []  # brand names whose provider call raised

    for text in test_cases:
        logger.info("--- Testing: '%s' ---", text)
        image_b64 = make_brand_image(text)
        try:
            result = provider.call(image_b64, prompt)
            logger.info(" answer: %s", result.answer)
            logger.info(" tokens: %d", result.total_tokens)
            if text.lower() in result.answer.lower():
                logger.info(" PASS — found '%s' in response", text)
            else:
                logger.warning(" MISS — '%s' not in response (may be correct, check answer)", text)
        except Exception as e:
            # Top-level boundary of a manual test: report and move on to the
            # next brand instead of aborting the whole run.
            failures.append(text)
            logger.error(" FAIL — %s: %s", type(e).__name__, e)
            # HTTP-client errors often carry the server reply; read it
            # defensively so the error handler itself cannot raise.
            body = getattr(getattr(e, "response", None), "text", None)
            if body:
                logger.error(" Response: %s", str(body)[:500])

    logger.info("All provider tests complete.")
    if failures:
        logger.error(
            "%d of %d provider calls failed: %s",
            len(failures), len(test_cases), ", ".join(failures),
        )
        sys.exit(1)


if __name__ == "__main__":
    main()