"""Groq cloud provider — OpenAI-compatible API with vision."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import os
|
|
|
|
import requests
|
|
|
|
from .base import ModelInfo, ProviderResponse
|
|
|
|
logger = logging.getLogger(__name__)

# Runtime configuration, read once from the environment at import time.
# GROQ_API_KEY defaults to empty so the module imports cleanly without
# credentials; requests made without a key will be rejected by the API.
GROQ_API_KEY = os.getenv("GROQ_API_KEY", "")
GROQ_BASE_URL = os.getenv("GROQ_BASE_URL", "https://api.groq.com/openai/v1")
GROQ_MODEL = os.getenv("GROQ_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct")
|
|
|
|
# Registry of models this provider exposes, keyed by model id.
# The id literal is hoisted so key and ModelInfo.id cannot drift apart.
_LLAMA4_SCOUT = "meta-llama/llama-4-scout-17b-16e-instruct"

MODELS = {
    _LLAMA4_SCOUT: ModelInfo(
        id=_LLAMA4_SCOUT,
        vision=True,
        # Groq free tier: no per-token billing.
        cost_per_input_token=0.0,
        cost_per_output_token=0.0,
        notes="Llama 4 Scout, only vision model on Groq free tier",
    ),
}
|
|
|
|
|
|
class GroqProvider:
    """Vision-capable provider backed by Groq's OpenAI-compatible API.

    Configuration (API key, base URL, model id) is taken from the
    module-level GROQ_* constants, which are read from the environment
    at import time.
    """

    name = "groq"
    models = MODELS

    def __init__(self):
        self.api_key = GROQ_API_KEY
        self.base_url = GROQ_BASE_URL
        self.model = GROQ_MODEL
        # rstrip('/') so a trailing slash in GROQ_BASE_URL does not
        # produce a double slash in the endpoint URL.
        self.endpoint = f"{self.base_url.rstrip('/')}/chat/completions"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        if not self.api_key:
            # Without a key every request will fail with 401; warn now so
            # the misconfiguration is visible before the first call.
            logger.warning("GROQ_API_KEY is not set; Groq requests will be rejected")

    def call(self, image_b64: str, prompt: str) -> ProviderResponse:
        """Send one text+image chat-completion request and return the reply.

        Args:
            image_b64: Base64-encoded JPEG image payload (no data-URL prefix).
            prompt: Text prompt sent alongside the image.

        Returns:
            ProviderResponse with the stripped answer text and the total
            token count reported by the API (0 if usage is absent).

        Raises:
            requests.HTTPError: On any non-2xx response.
            requests.RequestException: On connection failure or timeout.
        """
        payload = {
            "model": self.model,
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {
                        # OpenAI-compatible inline image via data URL.
                        "url": f"data:image/jpeg;base64,{image_b64}",
                    }},
                ],
            }],
            "max_tokens": 150,
        }

        resp = requests.post(self.endpoint, headers=self.headers, json=payload, timeout=30)
        resp.raise_for_status()
        data = resp.json()

        # content may be JSON null (e.g. a refusal/filtered reply); coerce to
        # "" so .strip() cannot raise AttributeError on None.
        content = data["choices"][0]["message"]["content"] or ""
        answer = content.strip()
        total_tokens = data.get("usage", {}).get("total_tokens", 0)

        return ProviderResponse(answer=answer, total_tokens=total_tokens)
|