108 lines
3.6 KiB
Python
108 lines
3.6 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test cloud LLM provider with a real API call.

Sends a test image to the configured cloud provider and verifies
the response. Set your provider env vars before running.

Usage:
    # Groq (default)
    CLOUD_LLM_PROVIDER=groq GROQ_API_KEY=gsk_... python tests/detect/manual/test_cloud_provider.py

    # Gemini
    CLOUD_LLM_PROVIDER=gemini GEMINI_API_KEY=AIza... python tests/detect/manual/test_cloud_provider.py

    # Claude
    CLOUD_LLM_PROVIDER=claude ANTHROPIC_API_KEY=sk-ant-... python tests/detect/manual/test_cloud_provider.py

    # OpenAI-compatible
    CLOUD_LLM_PROVIDER=openai OPENAI_API_KEY=sk-... python tests/detect/manual/test_cloud_provider.py
|
|
"""
|
|
|
|
import base64
|
|
import io
|
|
import logging
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
def load_env_file(path: Path) -> None:
    """Copy ``KEY=VALUE`` pairs from a dotenv-style file into ``os.environ``.

    Blank lines, lines starting with ``#`` and lines without ``=`` are
    skipped. Keys and values are whitespace-stripped, and one pair of
    matching surrounding quotes is removed from the value so that
    ``KEY="abc"`` yields ``abc`` rather than ``"abc"``. Variables that are
    already present in the environment are left untouched.

    Args:
        path: Location of the env file. If it is not a regular file
            (missing, or a directory) the call does nothing.
    """
    if not path.is_file():
        return

    for raw_line in path.read_text(encoding="utf-8").splitlines():
        entry = raw_line.strip()
        if not entry or entry.startswith("#") or "=" not in entry:
            continue

        key, _, value = entry.partition("=")
        key = key.strip()
        value = value.strip()

        # An empty variable name (a line such as "=value") cannot be exported.
        if not key:
            continue

        # Drop one pair of matching quotes: KEY="v" / KEY='v' -> v
        if len(value) >= 2 and value[0] == value[-1] and value[0] in ("'", '"'):
            value = value[1:-1]

        # setdefault: values exported in the real environment win over the file.
        os.environ.setdefault(key, value)


# Load .env from ctrl/ (same as docker-compose uses)
_script_path = Path(__file__).resolve()
# parents[3] is three directories above the one holding this script. Guard
# the lookup so a copy of the script in a shallow directory does not raise
# IndexError at import time; it then simply finds no env file.
_env_root = _script_path.parents[3] if len(_script_path.parents) > 3 else _script_path.parent
env_file = _env_root / "ctrl" / ".env"
load_env_file(env_file)
|
|
|
|
import numpy as np
|
|
from PIL import Image, ImageDraw, ImageFont
|
|
|
|
# Put the current working directory first on the import path. main() imports
# the ``core`` package, so this presumably expects the script to be launched
# from the repository root, as in the usage examples — confirm.
sys.path.insert(0, ".")

# Log lines look like: "INFO    <logger name> — <message>".
logging.basicConfig(
    format="%(levelname)-7s %(name)s — %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def make_brand_image(text: str, width: int = 300, height: int = 100) -> str:
    """Render *text* on a white canvas and return it as a base64 JPEG string.

    The text is drawn in black at pixel offset (10, 20). The DejaVu Sans Bold
    font at size 42 is used when it can be loaded from its fixed system path;
    otherwise Pillow's built-in default font is used.

    Args:
        text: String to draw on the image.
        width: Canvas width in pixels.
        height: Canvas height in pixels.

    Returns:
        The JPEG-encoded image bytes, base64-encoded and decoded to ``str``.
    """
    font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
    try:
        typeface = ImageFont.truetype(font_path, 42)
    except OSError:  # IOError is an alias of OSError, so one name covers both
        typeface = ImageFont.load_default()

    canvas = Image.new("RGB", (width, height), "white")
    ImageDraw.Draw(canvas).text((10, 20), text, fill="black", font=typeface)

    # Encode to JPEG in memory; nothing is written to disk.
    jpeg_stream = io.BytesIO()
    canvas.save(jpeg_stream, "JPEG")
    jpeg_bytes = jpeg_stream.getvalue()
    return base64.b64encode(jpeg_bytes).decode()
|
|
|
|
|
|
def _log_error_response(exc: Exception) -> None:
    """Log up to 500 characters of the response body attached to *exc*, if any."""
    response = getattr(exc, "response", None)
    if response is None:
        return
    body = getattr(response, "text", None)
    if body is None:
        return
    # str() keeps this handler from raising when ``text`` is not a plain
    # string; a failure here would abort the remaining test cases.
    logger.error(" Response: %s", str(body)[:500])


def main() -> None:
    """Run the manual smoke test against the configured cloud provider.

    Steps:
      1. Log the provider name taken from ``CLOUD_LLM_PROVIDER`` (falling
         back to ``"groq"``) and the keys of ``PROVIDERS``.
      2. Exit with status 1 if ``has_api_key()`` is falsy.
      3. Log the provider's model information and wait for Enter.
      4. For each brand string, render it with ``make_brand_image``, send the
         image and prompt through ``provider.call`` and log the answer and
         token count. A case-insensitive substring match of the brand in the
         answer is logged as PASS, anything else as MISS.

    Any exception raised while handling one brand is logged as FAIL, together
    with the attached response body when there is one, and the loop moves on
    to the next brand.

    Raises:
        SystemExit: With status 1 when no API key is available.
    """
    # Imported here rather than at module top, so it runs after the
    # module-level .env loading and sys.path adjustment.
    from core.detect.providers import get_provider, has_api_key, PROVIDERS

    provider_name = os.environ.get("CLOUD_LLM_PROVIDER", "groq")
    logger.info("Provider: %s", provider_name)
    logger.info("Available providers: %s", list(PROVIDERS.keys()))

    # NOTE(review): has_api_key() and get_provider() take no arguments, so
    # they presumably read CLOUD_LLM_PROVIDER themselves — confirm that their
    # default matches the "groq" fallback logged above.
    if not has_api_key():
        logger.error("No API key set for provider '%s'", provider_name)
        logger.error("Set the appropriate env var (see usage in docstring)")
        sys.exit(1)

    provider = get_provider()
    logger.info("Model: %s", provider.model)
    logger.info("Available models: %s", list(provider.models.keys()))
    input("\nPress Enter to start...")

    prompt = (
        "Identify the brand or sponsor visible in this image from a soccer broadcast. "
        "Respond with: brand, confidence (0-1), reasoning."
    )

    test_cases = ["NIKE", "EMIRATES", "Coca-Cola", "adidas"]

    for text in test_cases:
        logger.info("--- Testing: '%s' ---", text)
        image_b64 = make_brand_image(text)

        try:
            result = provider.call(image_b64, prompt)
            logger.info(" answer: %s", result.answer)
            logger.info(" tokens: %d", result.total_tokens)

            if text.lower() in result.answer.lower():
                logger.info(" PASS — found '%s' in response", text)
            else:
                logger.warning(" MISS — '%s' not in response (may be correct, check answer)", text)

        except Exception as e:
            # Deliberately broad: one failing case must not stop the others.
            logger.error(" FAIL — %s: %s", type(e).__name__, e)
            _log_error_response(e)

    logger.info("All provider tests complete.")
|
|
|
|
|
|
# Run the smoke test only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
|