some changes
This commit is contained in:
@@ -18,21 +18,35 @@ SYSTEM_PROMPT = """You are an assistant integrated into CHT, a screen recording
|
||||
You help the user understand what happened during their recording session.
|
||||
Be concise and specific. Focus on what's visible in the provided frames."""
|
||||
|
||||
# Default models per provider
|
||||
_PROVIDER_DEFAULTS = {
|
||||
"groq": ("https://api.groq.com/openai/v1", "meta-llama/llama-4-maverick-17b-128e-instruct"),
|
||||
"openai": ("https://api.openai.com/v1", "gpt-4o"),
|
||||
# Provider configs: (base_url, default_model, available_models)
|
||||
_PROVIDER_CONFIGS = {
|
||||
"groq": (
|
||||
"https://api.groq.com/openai/v1",
|
||||
"meta-llama/llama-4-maverick-17b-128e-instruct",
|
||||
[
|
||||
"meta-llama/llama-4-maverick-17b-128e-instruct",
|
||||
"meta-llama/llama-4-scout-17b-16e-instruct",
|
||||
"qwen/qwen-2.5-vl-72b-instruct",
|
||||
],
|
||||
),
|
||||
"openai": (
|
||||
"https://api.openai.com/v1",
|
||||
"gpt-4o",
|
||||
["gpt-4o", "gpt-4o-mini", "gpt-4.1", "gpt-4.1-mini"],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _detect_provider() -> tuple[str, str, str] | None:
|
||||
"""Returns (api_key, base_url, model) or None if no key found."""
|
||||
def _detect_provider() -> tuple[str, str, str, list[str]] | None:
|
||||
"""Returns (api_key, base_url, model, available_models) or None."""
|
||||
if key := os.environ.get("GROQ_API_KEY"):
|
||||
base_url, model = _PROVIDER_DEFAULTS["groq"]
|
||||
return key, base_url, os.environ.get("CHT_MODEL", model)
|
||||
base_url, default_model, models = _PROVIDER_CONFIGS["groq"]
|
||||
model = os.environ.get("CHT_MODEL", default_model)
|
||||
return key, base_url, model, models
|
||||
if key := os.environ.get("OPENAI_API_KEY"):
|
||||
base_url, model = _PROVIDER_DEFAULTS["openai"]
|
||||
return key, base_url, os.environ.get("CHT_MODEL", model)
|
||||
base_url, default_model, models = _PROVIDER_CONFIGS["openai"]
|
||||
model = os.environ.get("CHT_MODEL", default_model)
|
||||
return key, base_url, model, models
|
||||
return None
|
||||
|
||||
|
||||
@@ -54,7 +68,7 @@ class OpenAICompatProvider(AgentProvider):
|
||||
raise RuntimeError(
|
||||
"No API key found. Set GROQ_API_KEY or OPENAI_API_KEY."
|
||||
)
|
||||
self._api_key, self._base_url, self._model = detected
|
||||
self._api_key, self._base_url, self._model, self._models = detected
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -62,6 +76,18 @@ class OpenAICompatProvider(AgentProvider):
|
||||
return f"groq/{self._model}"
|
||||
return f"openai-compat/{self._model}"
|
||||
|
||||
@property
|
||||
def available_models(self) -> list[str]:
|
||||
return list(self._models)
|
||||
|
||||
@property
|
||||
def model(self) -> str:
|
||||
return self._model
|
||||
|
||||
@model.setter
|
||||
def model(self, value: str):
|
||||
self._model = value
|
||||
|
||||
def stream(self, message: str, context: SessionContext) -> Iterator[str]:
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
Reference in New Issue
Block a user