""" AgentConnection for OpenAI-compatible APIs (Groq, OpenAI, etc.). Sends frame images as base64. Supports tool calls via function calling. Requires GROQ_API_KEY or OPENAI_API_KEY env var. """ import base64 import json import logging import os from typing import Iterator from cht.agent.base import ( AssistantMessage, ImageBlock, Message, StreamEvent, TextBlock, TextDelta, Tool, ToolCallEnd, ToolCallStart, ToolResult, ToolUse, TranscriptBlock, UserMessage, Done, Error, ) log = logging.getLogger(__name__) _PROVIDER_CONFIGS = { "groq": ( "https://api.groq.com/openai/v1", "meta-llama/llama-4-maverick-17b-128e-instruct", [ "meta-llama/llama-4-maverick-17b-128e-instruct", "meta-llama/llama-4-scout-17b-16e-instruct", "qwen/qwen-2.5-vl-72b-instruct", ], ), "openai": ( "https://api.openai.com/v1", "gpt-4o", ["gpt-4o", "gpt-4o-mini", "gpt-4.1", "gpt-4.1-mini"], ), } def _detect_provider() -> tuple[str, str, str, list[str]] | None: if key := os.environ.get("GROQ_API_KEY"): base_url, default_model, models = _PROVIDER_CONFIGS["groq"] model = os.environ.get("CHT_MODEL", default_model) return key, base_url, model, models if key := os.environ.get("OPENAI_API_KEY"): base_url, default_model, models = _PROVIDER_CONFIGS["openai"] model = os.environ.get("CHT_MODEL", default_model) return key, base_url, model, models return None def _frame_to_base64(path) -> str | None: try: with open(path, "rb") as f: return base64.standard_b64encode(f.read()).decode() except Exception as e: log.warning("Could not encode frame %s: %s", path, e) return None def _messages_to_openai(messages: list[Message]) -> list[dict]: """Convert structured messages to OpenAI chat format.""" result: list[dict] = [] for msg in messages: if isinstance(msg, UserMessage): content: list[dict] = [] for b in msg.content: if isinstance(b, TextBlock): content.append({"type": "text", "text": b.text}) elif isinstance(b, ImageBlock): m, s = divmod(int(b.timestamp), 60) content.append({"type": "text", "text": f"{b.frame_id} at {m:02d}:{s:02d}:"}) data = _frame_to_base64(b.path) if data: content.append({ "type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{data}"}, }) elif isinstance(b, TranscriptBlock): m1, s1 = divmod(int(b.start), 60) m2, s2 = divmod(int(b.end), 60) content.append({ "type": "text", "text": f"{b.transcript_id} [{m1:02d}:{s1:02d}-{m2:02d}:{s2:02d}] {b.text}", }) result.append({"role": "user", "content": content}) elif isinstance(msg, AssistantMessage): text = " ".join(b.text for b in msg.content if isinstance(b, TextBlock)) result.append({"role": "assistant", "content": text}) elif isinstance(msg, ToolUse): result.append({ "role": "assistant", "content": None, "tool_calls": [{ "id": msg.id, "type": "function", "function": {"name": msg.tool_name, "arguments": json.dumps(msg.input)}, }], }) elif isinstance(msg, ToolResult): result.append({ "role": "tool", "tool_call_id": msg.tool_use_id, "content": msg.output or msg.error or "", }) return result def _tools_to_openai(tools: list[Tool]) -> list[dict] | None: if not tools: return None return [ { "type": "function", "function": { "name": t.name, "description": t.description, "parameters": t.input_schema(), }, } for t in tools ] class OpenAIConnection: """AgentConnection using any OpenAI-compatible API.""" def __init__(self): detected = _detect_provider() if not detected: raise RuntimeError("No API key found. Set GROQ_API_KEY or OPENAI_API_KEY.") self._api_key, self._base_url, self._model, self._models = detected self._cancelled = False @property def name(self) -> str: if "groq" in self._base_url: return f"groq/{self._model}" return f"openai-compat/{self._model}" def available_models(self) -> list[str]: return list(self._models) def get_model(self) -> str: return self._model def set_model(self, model: str) -> None: self._model = model def prompt( self, messages: list[Message], tools: list[Tool], ) -> Iterator[StreamEvent]: from openai import OpenAI client = OpenAI(api_key=self._api_key, base_url=self._base_url) self._cancelled = False oai_messages = _messages_to_openai(messages) oai_tools = _tools_to_openai(tools) kwargs = { "model": self._model, "messages": oai_messages, "stream": True, } if oai_tools: kwargs["tools"] = oai_tools try: stream = client.chat.completions.create(**kwargs) except Exception as e: yield Error(message=str(e)) return # Accumulate tool calls from streaming deltas tool_calls: dict[int, dict] = {} # index → {id, name, arguments} for chunk in stream: if self._cancelled: break choice = chunk.choices[0] if chunk.choices else None if not choice: continue delta = choice.delta # Text content if delta.content: yield TextDelta(text=delta.content) # Tool calls (streamed incrementally) if delta.tool_calls: for tc_delta in delta.tool_calls: idx = tc_delta.index if idx not in tool_calls: tool_calls[idx] = {"id": "", "name": "", "arguments": ""} if tc_delta.id: tool_calls[idx]["id"] = tc_delta.id if tc_delta.function: if tc_delta.function.name: tool_calls[idx]["name"] = tc_delta.function.name if tc_delta.function.arguments: tool_calls[idx]["arguments"] += tc_delta.function.arguments # Check finish reason if choice.finish_reason: if choice.finish_reason == "tool_calls": # Emit accumulated tool calls for idx in sorted(tool_calls.keys()): tc = tool_calls[idx] try: inp = json.loads(tc["arguments"]) if tc["arguments"] else {} except json.JSONDecodeError: inp = {} yield ToolCallStart(id=tc["id"], name=tc["name"], input=inp) yield ToolCallEnd(id=tc["id"]) yield Done(stop_reason="tool_use") elif choice.finish_reason == "stop": yield Done(stop_reason="end_turn") elif choice.finish_reason == "length": yield Done(stop_reason="max_tokens") else: yield Done(stop_reason=choice.finish_reason) def cancel(self) -> None: self._cancelled = True