some changes
This commit is contained in:
@@ -36,3 +36,19 @@ class AgentProvider(ABC):
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def available_models(self) -> list[str]:
|
||||
"""Return list of model IDs this provider supports."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def model(self) -> str:
|
||||
...
|
||||
|
||||
@model.setter
|
||||
@abstractmethod
|
||||
def model(self, value: str):
|
||||
...
|
||||
|
||||
@@ -51,16 +51,36 @@ def _build_prompt(message: str, context: SessionContext) -> str:
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
MODELS = [
|
||||
"claude-sonnet-4-6",
|
||||
"claude-opus-4-6",
|
||||
"claude-haiku-4-5",
|
||||
]
|
||||
|
||||
|
||||
class ClaudeSDKProvider(AgentProvider):
|
||||
"""Uses claude_agent_sdk — requires Claude Code CLI to be installed."""
|
||||
|
||||
def __init__(self, cwd: str | None = None, max_turns: int = 5):
|
||||
def __init__(self, cwd: str | None = None, max_turns: int = 5, model: str = MODELS[0]):
|
||||
self._cwd = cwd
|
||||
self._max_turns = max_turns
|
||||
self._model = model
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "claude-code-sdk"
|
||||
return f"claude-sdk/{self._model}"
|
||||
|
||||
@property
|
||||
def available_models(self) -> list[str]:
|
||||
return list(MODELS)
|
||||
|
||||
@property
|
||||
def model(self) -> str:
|
||||
return self._model
|
||||
|
||||
@model.setter
|
||||
def model(self, value: str):
|
||||
self._model = value
|
||||
|
||||
def stream(self, message: str, context: SessionContext) -> Iterator[str]:
|
||||
prompt = _build_prompt(message, context)
|
||||
@@ -70,6 +90,7 @@ class ClaudeSDKProvider(AgentProvider):
|
||||
async for msg in query(
|
||||
prompt=prompt,
|
||||
options=ClaudeAgentOptions(
|
||||
model=self._model,
|
||||
cwd=self._cwd or str(context.session_dir),
|
||||
allowed_tools=["Read"],
|
||||
system_prompt=SYSTEM_PROMPT,
|
||||
|
||||
@@ -18,21 +18,35 @@ SYSTEM_PROMPT = """You are an assistant integrated into CHT, a screen recording
|
||||
You help the user understand what happened during their recording session.
|
||||
Be concise and specific. Focus on what's visible in the provided frames."""
|
||||
|
||||
# Default models per provider
|
||||
_PROVIDER_DEFAULTS = {
|
||||
"groq": ("https://api.groq.com/openai/v1", "meta-llama/llama-4-maverick-17b-128e-instruct"),
|
||||
"openai": ("https://api.openai.com/v1", "gpt-4o"),
|
||||
# Provider configs: (base_url, default_model, available_models)
|
||||
_PROVIDER_CONFIGS = {
|
||||
"groq": (
|
||||
"https://api.groq.com/openai/v1",
|
||||
"meta-llama/llama-4-maverick-17b-128e-instruct",
|
||||
[
|
||||
"meta-llama/llama-4-maverick-17b-128e-instruct",
|
||||
"meta-llama/llama-4-scout-17b-16e-instruct",
|
||||
"qwen/qwen-2.5-vl-72b-instruct",
|
||||
],
|
||||
),
|
||||
"openai": (
|
||||
"https://api.openai.com/v1",
|
||||
"gpt-4o",
|
||||
["gpt-4o", "gpt-4o-mini", "gpt-4.1", "gpt-4.1-mini"],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _detect_provider() -> tuple[str, str, str] | None:
|
||||
"""Returns (api_key, base_url, model) or None if no key found."""
|
||||
def _detect_provider() -> tuple[str, str, str, list[str]] | None:
|
||||
"""Returns (api_key, base_url, model, available_models) or None."""
|
||||
if key := os.environ.get("GROQ_API_KEY"):
|
||||
base_url, model = _PROVIDER_DEFAULTS["groq"]
|
||||
return key, base_url, os.environ.get("CHT_MODEL", model)
|
||||
base_url, default_model, models = _PROVIDER_CONFIGS["groq"]
|
||||
model = os.environ.get("CHT_MODEL", default_model)
|
||||
return key, base_url, model, models
|
||||
if key := os.environ.get("OPENAI_API_KEY"):
|
||||
base_url, model = _PROVIDER_DEFAULTS["openai"]
|
||||
return key, base_url, os.environ.get("CHT_MODEL", model)
|
||||
base_url, default_model, models = _PROVIDER_CONFIGS["openai"]
|
||||
model = os.environ.get("CHT_MODEL", default_model)
|
||||
return key, base_url, model, models
|
||||
return None
|
||||
|
||||
|
||||
@@ -54,7 +68,7 @@ class OpenAICompatProvider(AgentProvider):
|
||||
raise RuntimeError(
|
||||
"No API key found. Set GROQ_API_KEY or OPENAI_API_KEY."
|
||||
)
|
||||
self._api_key, self._base_url, self._model = detected
|
||||
self._api_key, self._base_url, self._model, self._models = detected
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -62,6 +76,18 @@ class OpenAICompatProvider(AgentProvider):
|
||||
return f"groq/{self._model}"
|
||||
return f"openai-compat/{self._model}"
|
||||
|
||||
@property
|
||||
def available_models(self) -> list[str]:
|
||||
return list(self._models)
|
||||
|
||||
@property
|
||||
def model(self) -> str:
|
||||
return self._model
|
||||
|
||||
@model.setter
|
||||
def model(self, value: str):
|
||||
self._model = value
|
||||
|
||||
def stream(self, message: str, context: SessionContext) -> Iterator[str]:
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
@@ -19,12 +19,10 @@ from cht.agent.base import AgentProvider, FrameRef, SessionContext
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Predefined actions sent as messages with a fixed prompt
|
||||
# Predefined actions — label → verb prefix (frame ref appended by UI)
|
||||
ACTIONS: dict[str, str] = {
|
||||
"Summarize": "Summarize what happened in this recording so far. Look at the captured frames and describe the key content and any changes you notice.",
|
||||
"What changed": "Compare the captured frames in order and describe what changed between them. Focus on meaningful transitions.",
|
||||
"Key moments": "Identify the most important moments in the recording based on the frames. List them with timestamps.",
|
||||
"Describe now": "Look at the most recent frame and describe exactly what is currently on screen.",
|
||||
"Describe": "describe",
|
||||
"Answer": "answer",
|
||||
}
|
||||
|
||||
|
||||
@@ -106,6 +104,24 @@ class AgentRunner:
|
||||
except Exception:
|
||||
return "unknown"
|
||||
|
||||
@property
|
||||
def available_models(self) -> list[str]:
|
||||
try:
|
||||
return self._get_provider().available_models
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
@property
|
||||
def model(self) -> str:
|
||||
try:
|
||||
return self._get_provider().model
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
@model.setter
|
||||
def model(self, value: str):
|
||||
self._get_provider().model = value
|
||||
|
||||
def send(
|
||||
self,
|
||||
message: str,
|
||||
|
||||
Reference in New Issue
Block a user