Files
mitus/cht/agent/openai_connection.py
2026-04-09 14:46:29 -03:00

250 lines
8.1 KiB
Python

"""
AgentConnection for OpenAI-compatible APIs (Groq, OpenAI, etc.).
Sends frame images as base64. Supports tool calls via function calling.
Requires GROQ_API_KEY or OPENAI_API_KEY env var.
"""
import base64
import json
import logging
import os
from typing import Iterator
from cht.agent.base import (
AssistantMessage,
ImageBlock,
Message,
StreamEvent,
TextBlock,
TextDelta,
Tool,
ToolCallEnd,
ToolCallStart,
ToolResult,
ToolUse,
TranscriptBlock,
UserMessage,
Done,
Error,
)
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# System message that _messages_to_openai places first in every request.
SYSTEM_PROMPT = """You are an assistant integrated into CHT, a screen recording and analysis tool.
You help the user understand what happened during their recording session.
Be concise and specific. Focus on what's visible in the provided frames."""

# Per-provider settings, keyed by provider name. Each value is a
# (base URL, default model, list of selectable models) tuple; the tuple is
# unpacked in that order by _detect_provider.
_PROVIDER_CONFIGS = {
    "groq": (
        "https://api.groq.com/openai/v1",
        "meta-llama/llama-4-maverick-17b-128e-instruct",
        [
            "meta-llama/llama-4-maverick-17b-128e-instruct",
            "meta-llama/llama-4-scout-17b-16e-instruct",
            "qwen/qwen-2.5-vl-72b-instruct",
        ],
    ),
    "openai": (
        "https://api.openai.com/v1",
        "gpt-4o",
        ["gpt-4o", "gpt-4o-mini", "gpt-4.1", "gpt-4.1-mini"],
    ),
}
def _detect_provider() -> tuple[str, str, str, list[str]] | None:
if key := os.environ.get("GROQ_API_KEY"):
base_url, default_model, models = _PROVIDER_CONFIGS["groq"]
model = os.environ.get("CHT_MODEL", default_model)
return key, base_url, model, models
if key := os.environ.get("OPENAI_API_KEY"):
base_url, default_model, models = _PROVIDER_CONFIGS["openai"]
model = os.environ.get("CHT_MODEL", default_model)
return key, base_url, model, models
return None
def _frame_to_base64(path) -> str | None:
try:
with open(path, "rb") as f:
return base64.standard_b64encode(f.read()).decode()
except Exception as e:
log.warning("Could not encode frame %s: %s", path, e)
return None
def _mmss(seconds) -> str:
    """Format a time offset in seconds as zero-padded ``MM:SS``."""
    minutes, secs = divmod(int(seconds), 60)
    return f"{minutes:02d}:{secs:02d}"


def _messages_to_openai(messages: list[Message]) -> list[dict]:
    """Convert structured messages to OpenAI chat format.

    The returned list always starts with the system prompt. After that:

    * ``UserMessage`` becomes one ``user`` message whose content parts are
      text, a text label followed by a base64 ``image_url`` part for each
      image (the image part is omitted if the frame cannot be read), and a
      text line per transcript block.
    * ``AssistantMessage`` becomes an ``assistant`` message holding the
      space-joined text blocks.
    * ``ToolUse`` becomes an ``assistant`` message with ``tool_calls``.
      Consecutive ``ToolUse`` messages are merged into a single assistant
      message, because the chat API expects every assistant ``tool_calls``
      message to be followed by the ``tool`` replies for all of its calls;
      one assistant message per parallel call would violate that.
    * ``ToolResult`` becomes a ``tool`` message carrying the output, or the
      error text when there is no output.

    Messages of any other type are skipped.
    """
    result: list[dict] = [{"role": "system", "content": SYSTEM_PROMPT}]
    # True while result[-1] is the assistant tool-call message created by the
    # immediately preceding ToolUse, i.e. while merging is still valid.
    prev_was_tool_use = False

    for msg in messages:
        is_tool_use = False

        if isinstance(msg, UserMessage):
            content: list[dict] = []
            for b in msg.content:
                if isinstance(b, TextBlock):
                    content.append({"type": "text", "text": b.text})
                elif isinstance(b, ImageBlock):
                    # Label first so the model can refer to the frame by id.
                    content.append({
                        "type": "text",
                        "text": f"{b.frame_id} at {_mmss(b.timestamp)}:",
                    })
                    data = _frame_to_base64(b.path)
                    if data:
                        content.append({
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{data}"},
                        })
                elif isinstance(b, TranscriptBlock):
                    content.append({
                        "type": "text",
                        "text": f"{b.transcript_id} [{_mmss(b.start)}-{_mmss(b.end)}] {b.text}",
                    })
            result.append({"role": "user", "content": content})

        elif isinstance(msg, AssistantMessage):
            text = " ".join(b.text for b in msg.content if isinstance(b, TextBlock))
            result.append({"role": "assistant", "content": text})

        elif isinstance(msg, ToolUse):
            is_tool_use = True
            call = {
                "id": msg.id,
                "type": "function",
                "function": {"name": msg.tool_name, "arguments": json.dumps(msg.input)},
            }
            if prev_was_tool_use:
                # Parallel call: attach to the assistant message already open.
                result[-1]["tool_calls"].append(call)
            else:
                result.append({
                    "role": "assistant",
                    "content": None,
                    "tool_calls": [call],
                })

        elif isinstance(msg, ToolResult):
            result.append({
                "role": "tool",
                "tool_call_id": msg.tool_use_id,
                "content": msg.output or msg.error or "",
            })

        prev_was_tool_use = is_tool_use

    return result
def _tools_to_openai(tools: list[Tool]) -> list[dict] | None:
    """Describe *tools* in the OpenAI function-calling format.

    Returns ``None`` when there are no tools; otherwise one ``function``
    entry per tool built from its name, description and input schema.
    """
    if not tools:
        return None

    specs = []
    for tool in tools:
        function_spec = {
            "name": tool.name,
            "description": tool.description,
            "parameters": tool.input_schema(),
        }
        specs.append({"type": "function", "function": function_spec})
    return specs
class OpenAIConnection:
    """AgentConnection using any OpenAI-compatible API.

    The API key, base URL, model and model list are chosen once at
    construction time by ``_detect_provider``; the model can be changed
    afterwards with ``set_model``.
    """

    # Maps OpenAI ``finish_reason`` values to the stop reasons reported in
    # ``Done`` events. Values not listed here are passed through unchanged.
    _STOP_REASONS = {
        "tool_calls": "tool_use",
        "stop": "end_turn",
        "length": "max_tokens",
    }

    def __init__(self):
        """Load provider settings from the environment.

        Raises:
            RuntimeError: if neither GROQ_API_KEY nor OPENAI_API_KEY is set.
        """
        detected = _detect_provider()
        if not detected:
            raise RuntimeError("No API key found. Set GROQ_API_KEY or OPENAI_API_KEY.")
        self._api_key, self._base_url, self._model, self._models = detected
        self._cancelled = False

    @property
    def name(self) -> str:
        """Display name: ``groq/<model>`` for Groq URLs, else ``openai-compat/<model>``."""
        if "groq" in self._base_url:
            return f"groq/{self._model}"
        return f"openai-compat/{self._model}"

    def available_models(self) -> list[str]:
        """Return a copy of the provider's selectable model names."""
        return list(self._models)

    def get_model(self) -> str:
        """Return the model name used for subsequent prompts."""
        return self._model

    def set_model(self, model: str) -> None:
        """Set the model name used for subsequent prompts (not validated)."""
        self._model = model

    @staticmethod
    def _parse_tool_arguments(raw: str, tool_name: str = "") -> dict:
        """Decode the accumulated JSON argument text of one tool call.

        Returns ``{}`` for empty text. Malformed JSON, or JSON that is not an
        object, also yields ``{}`` but is logged so the problem is visible.
        """
        if not raw:
            return {}
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError as exc:
            log.warning("Malformed arguments for tool call %s: %s", tool_name, exc)
            return {}
        if not isinstance(parsed, dict):
            log.warning("Non-object arguments for tool call %s: %r", tool_name, parsed)
            return {}
        return parsed

    def prompt(
        self,
        messages: list[Message],
        tools: list[Tool],
    ) -> Iterator[StreamEvent]:
        """Send *messages* to the model and stream back events.

        Yields ``TextDelta`` for each text fragment. When the model finishes
        with tool calls, yields a ``ToolCallStart``/``ToolCallEnd`` pair per
        call (in index order), then ``Done``. A failure while starting or
        reading the stream is reported as an ``Error`` event instead of being
        raised. If ``cancel`` is called, iteration stops at the next chunk
        without emitting ``Done``.
        """
        from openai import OpenAI

        client = OpenAI(api_key=self._api_key, base_url=self._base_url)
        self._cancelled = False

        kwargs = {
            "model": self._model,
            "messages": _messages_to_openai(messages),
            "stream": True,
        }
        oai_tools = _tools_to_openai(tools)
        if oai_tools:
            kwargs["tools"] = oai_tools

        try:
            stream = client.chat.completions.create(**kwargs)
        except Exception as e:
            yield Error(message=str(e))
            return

        # Accumulate tool calls from streaming deltas
        tool_calls: dict[int, dict] = {}  # index → {id, name, arguments}
        try:
            for chunk in stream:
                if self._cancelled:
                    break
                choice = chunk.choices[0] if chunk.choices else None
                if not choice:
                    continue

                delta = choice.delta
                if delta is not None:
                    # Text content
                    if delta.content:
                        yield TextDelta(text=delta.content)

                    # Tool calls arrive in pieces: id and name once, the
                    # argument JSON as string fragments to concatenate.
                    for tc_delta in delta.tool_calls or ():
                        entry = tool_calls.setdefault(
                            tc_delta.index, {"id": "", "name": "", "arguments": ""}
                        )
                        if tc_delta.id:
                            entry["id"] = tc_delta.id
                        if tc_delta.function:
                            if tc_delta.function.name:
                                entry["name"] = tc_delta.function.name
                            if tc_delta.function.arguments:
                                entry["arguments"] += tc_delta.function.arguments

                finish_reason = choice.finish_reason
                if not finish_reason:
                    continue

                if finish_reason == "tool_calls":
                    # Emit accumulated tool calls
                    for idx in sorted(tool_calls):
                        tc = tool_calls[idx]
                        inp = self._parse_tool_arguments(tc["arguments"], tc["name"])
                        yield ToolCallStart(id=tc["id"], name=tc["name"], input=inp)
                        yield ToolCallEnd(id=tc["id"])
                yield Done(stop_reason=self._STOP_REASONS.get(finish_reason, finish_reason))
        except Exception as e:
            # Network/API failures can surface mid-stream, after create()
            # has already returned; report them like start-up failures.
            yield Error(message=str(e))
        finally:
            # Release the HTTP connection even when cancelled, on error, or
            # when the consumer abandons the generator early.
            close = getattr(stream, "close", None)
            if callable(close):
                close()

    def cancel(self) -> None:
        """Ask a running ``prompt`` iteration to stop at the next chunk."""
        self._cancelled = True