Add runtime model selection to agent providers: expose available_models and a model getter/setter on the provider interface, per-provider model lists, and simplified UI actions

This commit is contained in:
2026-04-02 21:08:17 -03:00
parent 76ff720906
commit 8c1138c746
8 changed files with 1245 additions and 26 deletions

View File

@@ -36,3 +36,19 @@ class AgentProvider(ABC):
@abstractmethod
def name(self) -> str:
...
@property
@abstractmethod
def available_models(self) -> list[str]:
    """Return the list of model IDs this provider supports.

    Concrete providers return a fresh list (e.g. ``list(MODELS)``), so
    callers may mutate the result without affecting the provider.
    """
    ...
@property
@abstractmethod
def model(self) -> str:
    """The currently selected model ID for this provider."""
    ...

@model.setter
@abstractmethod
def model(self, value: str):
    # Abstract setter: concrete providers store the new model ID
    # (both visible implementations assign it to self._model).
    ...

View File

@@ -51,16 +51,36 @@ def _build_prompt(message: str, context: SessionContext) -> str:
return "\n".join(lines)
# Claude model IDs selectable for the SDK provider.  The first entry is the
# default (used as the default `model` argument in ClaudeSDKProvider.__init__).
MODELS = [
    "claude-sonnet-4-6",
    "claude-opus-4-6",
    "claude-haiku-4-5",
]
class ClaudeSDKProvider(AgentProvider):
"""Uses claude_agent_sdk — requires Claude Code CLI to be installed."""
def __init__(self, cwd: str | None = None, max_turns: int = 5, model: str = MODELS[0]):
    """Create a Claude SDK provider.

    Args:
        cwd: Working directory for the agent; when None, the session
            directory from the context is used at stream time.
        max_turns: Maximum number of agent turns per query.
        model: Model ID to use; defaults to the first entry of MODELS.
    """
    # Note: the stale pre-change signature (without `model`) that the diff
    # view interleaved here has been dropped; only the updated one remains.
    self._cwd = cwd
    self._max_turns = max_turns
    self._model = model
@property
def name(self) -> str:
    """Provider identifier including the active model, e.g. "claude-sdk/claude-sonnet-4-6".

    The stale pre-change `return "claude-code-sdk"` made the updated
    return unreachable; only the post-change return is kept.
    """
    return f"claude-sdk/{self._model}"
@property
def available_models(self) -> list[str]:
    """A copy of the supported Claude model IDs (module-level MODELS stays untouched)."""
    return MODELS.copy()
@property
def model(self) -> str:
    """Model ID currently in use (passed to ClaudeAgentOptions when streaming)."""
    return self._model

@model.setter
def model(self, value: str):
    # Not validated against MODELS — any string is accepted.  TODO confirm intended.
    self._model = value
def stream(self, message: str, context: SessionContext) -> Iterator[str]:
prompt = _build_prompt(message, context)
@@ -70,6 +90,7 @@ class ClaudeSDKProvider(AgentProvider):
async for msg in query(
prompt=prompt,
options=ClaudeAgentOptions(
model=self._model,
cwd=self._cwd or str(context.session_dir),
allowed_tools=["Read"],
system_prompt=SYSTEM_PROMPT,

View File

@@ -18,21 +18,35 @@ SYSTEM_PROMPT = """You are an assistant integrated into CHT, a screen recording
You help the user understand what happened during their recording session.
Be concise and specific. Focus on what's visible in the provided frames."""
# Provider configs: (base_url, default_model, available_models).
# The superseded _PROVIDER_DEFAULTS mapping that the diff view interleaved
# here has been removed; this table is its replacement.
_PROVIDER_CONFIGS = {
    "groq": (
        "https://api.groq.com/openai/v1",
        "meta-llama/llama-4-maverick-17b-128e-instruct",
        [
            "meta-llama/llama-4-maverick-17b-128e-instruct",
            "meta-llama/llama-4-scout-17b-16e-instruct",
            "qwen/qwen-2.5-vl-72b-instruct",
        ],
    ),
    "openai": (
        "https://api.openai.com/v1",
        "gpt-4o",
        ["gpt-4o", "gpt-4o-mini", "gpt-4.1", "gpt-4.1-mini"],
    ),
}
def _detect_provider() -> tuple[str, str, str, list[str]] | None:
    """Pick a provider based on which API key is present in the environment.

    GROQ_API_KEY is checked before OPENAI_API_KEY.  The CHT_MODEL env var,
    when set, overrides the provider's default model.

    Returns:
        (api_key, base_url, model, available_models), or None when no
        supported API key is set.
    """
    # The two branches were identical except for the provider name/env var,
    # so they are folded into one loop over (provider, env_var) pairs.
    for provider, env_var in (("groq", "GROQ_API_KEY"), ("openai", "OPENAI_API_KEY")):
        if key := os.environ.get(env_var):
            base_url, default_model, models = _PROVIDER_CONFIGS[provider]
            model = os.environ.get("CHT_MODEL", default_model)
            return key, base_url, model, models
    return None
@@ -54,7 +68,7 @@ class OpenAICompatProvider(AgentProvider):
raise RuntimeError(
"No API key found. Set GROQ_API_KEY or OPENAI_API_KEY."
)
self._api_key, self._base_url, self._model = detected
self._api_key, self._base_url, self._model, self._models = detected
@property
def name(self) -> str:
@@ -62,6 +76,18 @@ class OpenAICompatProvider(AgentProvider):
return f"groq/{self._model}"
return f"openai-compat/{self._model}"
@property
def available_models(self) -> list[str]:
    """A copy of the model IDs configured for the detected provider."""
    return [*self._models]
@property
def model(self) -> str:
    """The currently selected model ID."""
    return self._model

@model.setter
def model(self, value: str):
    # No check against self._models; callers may set arbitrary IDs.
    # TODO confirm intended.
    self._model = value
def stream(self, message: str, context: SessionContext) -> Iterator[str]:
from openai import OpenAI

View File

@@ -19,12 +19,10 @@ from cht.agent.base import AgentProvider, FrameRef, SessionContext
log = logging.getLogger(__name__)
# Predefined actions — label → verb prefix (frame ref appended by UI).
# The four stale full-prompt entries that the diff view interleaved here
# were removed by this change; only the two verb-prefix entries remain.
ACTIONS: dict[str, str] = {
    "Describe": "describe",
    "Answer": "answer",
}
@@ -106,6 +104,24 @@ class AgentRunner:
except Exception:
return "unknown"
@property
def available_models(self) -> list[str]:
    """Model IDs offered by the current provider; empty list when the provider can't be resolved."""
    try:
        models = self._get_provider().available_models
    except Exception:
        # Best-effort: provider construction/access failures degrade to [].
        return []
    return models
@property
def model(self) -> str:
    """The active provider's current model ID, or "" when the provider can't be resolved."""
    try:
        current = self._get_provider().model
    except Exception:
        # Best-effort: degrade to an empty string rather than raising.
        return ""
    return current

@model.setter
def model(self, value: str):
    # Delegates to the provider; errors propagate (no try/except here).
    provider = self._get_provider()
    provider.model = value
def send(
self,
message: str,