"""Google Gemini provider — native REST API, not OpenAI-compatible."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import os
|
|
|
|
import requests
|
|
|
|
from .base import ModelInfo, ProviderResponse
|
|
|
|
logger = logging.getLogger(__name__)

# Gemini-specific env vars
# API key for the Generative Language API; falls back to "" when unset,
# which will cause request authentication to fail at call time.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
# Model id used to build the generateContent endpoint URL.
GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-2.0-flash")
|
|
|
|
# Known Gemini models and their pricing.
# Spec layout: (model id, USD per input token, USD per output token, notes).
_GEMINI_MODEL_SPECS = [
    ("gemini-2.0-flash", 0.0000001, 0.0000004, "Fast, cheap, good vision"),
    ("gemini-2.0-pro", 0.00000125, 0.000005, "Higher quality, slower"),
    ("gemini-1.5-flash", 0.000000075, 0.0000003, "Cheapest option"),
]

# Registry keyed by model id, as expected by GeminiProvider.models.
MODELS = {
    model_id: ModelInfo(
        id=model_id,
        vision=True,  # every catalogued Gemini model accepts image input
        cost_per_input_token=cost_in,
        cost_per_output_token=cost_out,
        notes=notes,
    )
    for model_id, cost_in, cost_out, notes in _GEMINI_MODEL_SPECS
}
|
|
|
|
|
|
class GeminiProvider:
    """Provider backed by Google's native Gemini generateContent REST API.

    Unlike the OpenAI-compatible providers, Gemini uses its own request and
    response schema (``contents``/``parts`` and ``candidates``).
    """

    name = "gemini"
    models = MODELS

    def __init__(self):
        self.api_key = GEMINI_API_KEY
        self.model = GEMINI_MODEL
        if not self.api_key:
            # Don't fail construction (callers may just be listing providers),
            # but make the misconfiguration visible before the first call.
            logger.warning("GEMINI_API_KEY is not set; Gemini calls will fail")
        self.endpoint = (
            f"https://generativelanguage.googleapis.com/v1beta/models/"
            f"{self.model}:generateContent"
        )

    def call(self, image_b64: str, prompt: str) -> ProviderResponse:
        """Send one image+prompt request to Gemini and return its answer.

        Args:
            image_b64: Base64-encoded JPEG image data.
            prompt: Text instruction to accompany the image.

        Returns:
            ProviderResponse with the stripped answer text and total token
            count reported by Gemini's ``usageMetadata`` (0 when absent).

        Raises:
            requests.HTTPError: on non-2xx responses.
            ValueError: when the API returns 200 but no candidates (e.g. the
                request was blocked by safety filters).
        """
        payload = {
            "contents": [{
                "parts": [
                    {"text": prompt},
                    {"inline_data": {"mime_type": "image/jpeg", "data": image_b64}},
                ],
            }],
            "generationConfig": {"maxOutputTokens": 150},
        }

        # Send the key via header rather than a ?key= query parameter so the
        # secret does not end up in server/proxy access logs or tracebacks.
        resp = requests.post(
            self.endpoint,
            json=payload,
            headers={"x-goog-api-key": self.api_key},
            timeout=30,
        )
        resp.raise_for_status()
        data = resp.json()

        # A 200 response may still carry no candidates (e.g. prompt or image
        # blocked by safety filters, reported under promptFeedback). Fail with
        # a descriptive error instead of a bare KeyError/IndexError.
        candidates = data.get("candidates")
        if not candidates:
            raise ValueError(
                f"Gemini returned no candidates: {data.get('promptFeedback')!r}"
            )

        # Concatenate all text parts; Gemini may split the answer across parts,
        # and non-text parts simply contribute nothing.
        parts = candidates[0].get("content", {}).get("parts", [])
        answer = "".join(part.get("text", "") for part in parts).strip()

        usage = data.get("usageMetadata", {})
        total_tokens = usage.get("totalTokenCount", 0)

        return ProviderResponse(answer=answer, total_tokens=total_tokens)
|