better agent

This commit is contained in:
2026-04-09 14:46:29 -03:00
parent ade92069c0
commit 64ecdca71e
11 changed files with 1424 additions and 434 deletions

View File

@@ -1,16 +1,28 @@
"""
Abstract base for agent providers.
Agent data model — structured messages, connections, tools, threads.
Each provider takes a user message + session context and yields response
text chunks for streaming into the UI.
Replaces the old flat AgentProvider/SessionContext with:
- Typed content blocks and messages (conversation model)
- StreamEvent types (connection output)
- AgentConnection protocol (replaces AgentProvider)
- Tool protocol
- Thread (session state, serializable)
"""
from abc import ABC, abstractmethod
from __future__ import annotations
import json
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Iterator
from typing import Iterator, Literal, Protocol, runtime_checkable
# ---------------------------------------------------------------------------
# Shared refs (used by runner mention-parsing and tools)
# ---------------------------------------------------------------------------
@dataclass
class FrameRef:
id: str # "F0001"
@@ -26,40 +38,298 @@ class TranscriptRef:
text: str
# ---------------------------------------------------------------------------
# Content blocks
# ---------------------------------------------------------------------------
@dataclass
class SessionContext:
class TextBlock:
text: str
@dataclass
class ImageBlock:
    """A captured frame image referenced inline in a message."""
    frame_id: str  # "F0042"
    path: Path  # frame image file on disk
    timestamp: float  # seconds from the start of the recording
@dataclass
class TranscriptBlock:
    """A transcript segment referenced inline in a message."""
    transcript_id: str  # "T0012"
    start: float  # segment start, seconds from recording start
    end: float  # segment end, seconds from recording start
    text: str  # transcribed speech for this segment
ContentBlock = TextBlock | ImageBlock | TranscriptBlock
# ---------------------------------------------------------------------------
# Messages
# ---------------------------------------------------------------------------
@dataclass
class TokenUsage:
    """Token counts reported by a provider (per message or per thread)."""
    input_tokens: int = 0
    output_tokens: int = 0
def _msg_id() -> str:
return uuid.uuid4().hex[:12]
@dataclass
class UserMessage:
    """A user turn: ordered content blocks (text plus mentioned frames/transcripts)."""
    content: list[ContentBlock]
    id: str = field(default_factory=_msg_id)  # short random id
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))  # aware UTC
@dataclass
class AssistantMessage:
    """An assistant turn: content blocks plus the model that produced them."""
    content: list[ContentBlock]
    model: str = ""  # model id that produced this turn
    id: str = field(default_factory=_msg_id)  # short random id
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))  # aware UTC
    token_usage: TokenUsage | None = None  # set only when the provider reports usage
@dataclass
class ToolUse:
    """A model-requested tool invocation, tracked through its lifecycle."""
    tool_name: str
    input: dict  # tool arguments as parsed JSON
    id: str = field(default_factory=_msg_id)  # also used as the tool_call id
    status: Literal["pending", "running", "done", "error"] = "pending"
@dataclass
class ToolResult:
    """Outcome of a ToolUse: output text on success, error text on failure."""
    tool_use_id: str  # id of the ToolUse this answers
    output: str | None = None
    error: str | None = None
Message = UserMessage | AssistantMessage | ToolUse | ToolResult
# ---------------------------------------------------------------------------
# Thread (session state)
# ---------------------------------------------------------------------------
@dataclass
class Thread:
    """Serializable conversation state for one session.

    Persisted to <session_dir>/agent/thread.json via save_thread/load_thread.
    """
    messages: list[Message] = field(default_factory=list)
    id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])  # short random id
    created: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    updated: datetime = field(default_factory=lambda: datetime.now(timezone.utc))  # bumped on save
    token_usage: TokenUsage | None = None  # aggregate usage, if tracked
# ---------------------------------------------------------------------------
# Stream events (yielded by AgentConnection)
# ---------------------------------------------------------------------------
@dataclass
class TextDelta:
    """A chunk of assistant text streamed from the connection."""
    text: str
@dataclass
class ToolCallStart:
    """The model requested a tool call with fully-assembled input."""
    id: str  # tool-call id (pairs with ToolCallEnd / ToolResult)
    name: str  # tool name
    input: dict  # parsed tool arguments
@dataclass
class ToolCallEnd:
    """The tool call with the matching ToolCallStart id is complete."""
    id: str
@dataclass
class Done:
    """Terminal event: the model finished its turn."""
    stop_reason: str  # "end_turn", "tool_use", "max_tokens"
@dataclass
class Error:
    """Terminal event: the request failed; message is human-readable."""
    message: str
StreamEvent = TextDelta | ToolCallStart | ToolCallEnd | Done | Error
# ---------------------------------------------------------------------------
# Tool protocol
# ---------------------------------------------------------------------------
@runtime_checkable
class Tool(Protocol):
    """Structural interface for agent tools (executed by the runner)."""
    name: str  # tool name exposed to the model
    description: str  # natural-language description shown to the model
    def input_schema(self) -> dict: ...  # JSON-schema dict for the tool's input
    def run(self, input: dict, context: ToolContext) -> ToolResult: ...  # execute with parsed input
@dataclass
class ToolContext:
    """Runtime context passed to tools — references to live app objects.

    Fix: the original declared the non-default fields `frames_dir` and
    `transcript_dir` *after* fields with defaults, which makes @dataclass
    raise TypeError at class-definition time. All required fields now come
    first; defaulted fields follow.
    """
    session_dir: Path
    frames: list[FrameRef]  # all captured frames so far
    duration: float  # current recording duration (seconds)
    frames_dir: Path  # directory holding frame images + index.json
    transcript_dir: Path  # directory holding transcript segments + index.json
    mentioned_frames: list[FrameRef] = field(default_factory=list)
    transcript_segments: list[TranscriptRef] = field(default_factory=list)
    mentioned_transcripts: list[TranscriptRef] = field(default_factory=list)
    history: list[tuple[str, str]] = field(default_factory=list)  # [(role, text), ...]
    stream_mgr: object | None = None  # StreamManager, optional
    tracker: object | None = None  # RecordingTracker, optional
class AgentProvider(ABC):
@abstractmethod
def stream(self, message: str, context: SessionContext) -> Iterator[str]:
"""Yield response text chunks."""
# ---------------------------------------------------------------------------
# AgentConnection protocol (replaces AgentProvider)
# ---------------------------------------------------------------------------
@runtime_checkable
class AgentConnection(Protocol):
name: str
def available_models(self) -> list[str]: ...
def get_model(self) -> str: ...
def set_model(self, model: str) -> None: ...
def prompt(
self,
messages: list[Message],
tools: list[Tool],
) -> Iterator[StreamEvent]:
"""Send messages to the model, yield stream events."""
...
@property
@abstractmethod
def name(self) -> str:
...
def cancel(self) -> None: ...
@property
@abstractmethod
def available_models(self) -> list[str]:
"""Return list of model IDs this provider supports."""
...
@property
@abstractmethod
def model(self) -> str:
...
# ---------------------------------------------------------------------------
# Thread serialization
# ---------------------------------------------------------------------------
@model.setter
@abstractmethod
def model(self, value: str):
...
class _Encoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, Path):
return str(o)
if isinstance(o, datetime):
return o.isoformat()
return super().default(o)
def _block_to_dict(b: ContentBlock) -> dict:
    """Serialize one content block into a JSON-ready, type-tagged dict."""
    if isinstance(b, TranscriptBlock):
        return {
            "type": "transcript",
            "transcript_id": b.transcript_id,
            "start": b.start,
            "end": b.end,
            "text": b.text,
        }
    if isinstance(b, ImageBlock):
        return {
            "type": "image",
            "frame_id": b.frame_id,
            "path": str(b.path),
            "timestamp": b.timestamp,
        }
    if isinstance(b, TextBlock):
        return {"type": "text", "text": b.text}
    raise TypeError(f"Unknown block type: {type(b)}")
def _block_from_dict(d: dict) -> ContentBlock:
    """Inverse of _block_to_dict: rebuild a content block from its tagged dict."""
    kind = d["type"]
    if kind == "transcript":
        return TranscriptBlock(
            transcript_id=d["transcript_id"],
            start=d["start"],
            end=d["end"],
            text=d["text"],
        )
    if kind == "image":
        return ImageBlock(
            frame_id=d["frame_id"],
            path=Path(d["path"]),
            timestamp=d["timestamp"],
        )
    if kind == "text":
        return TextBlock(text=d["text"])
    raise ValueError(f"Unknown block type: {kind}")
def _msg_to_dict(m: Message) -> dict:
    """Serialize any message variant into a JSON-ready, type-tagged dict."""
    if isinstance(m, ToolResult):
        return {"type": "tool_result", "tool_use_id": m.tool_use_id, "output": m.output, "error": m.error}
    if isinstance(m, ToolUse):
        return {"type": "tool_use", "id": m.id, "tool_name": m.tool_name, "input": m.input, "status": m.status}
    if isinstance(m, (UserMessage, AssistantMessage)):
        out = {
            "type": "user" if isinstance(m, UserMessage) else "assistant",
            "id": m.id,
            "content": [_block_to_dict(b) for b in m.content],
            "timestamp": m.timestamp.isoformat(),
        }
        if isinstance(m, AssistantMessage):
            out["model"] = m.model
            # token_usage is emitted only when present, matching the loader
            if m.token_usage:
                out["token_usage"] = {
                    "input_tokens": m.token_usage.input_tokens,
                    "output_tokens": m.token_usage.output_tokens,
                }
        return out
    raise TypeError(f"Unknown message type: {type(m)}")
def _msg_from_dict(d: dict) -> Message:
    """Inverse of _msg_to_dict: rebuild a message from its tagged dict."""
    kind = d["type"]
    if kind == "tool_result":
        return ToolResult(tool_use_id=d["tool_use_id"], output=d.get("output"), error=d.get("error"))
    if kind == "tool_use":
        # Old saves may lack "status"; treat them as already finished.
        return ToolUse(id=d["id"], tool_name=d["tool_name"], input=d["input"], status=d.get("status", "done"))
    if kind == "user":
        return UserMessage(
            id=d["id"],
            content=[_block_from_dict(b) for b in d["content"]],
            timestamp=datetime.fromisoformat(d["timestamp"]),
        )
    if kind == "assistant":
        raw_usage = d.get("token_usage")
        return AssistantMessage(
            id=d["id"],
            content=[_block_from_dict(b) for b in d["content"]],
            timestamp=datetime.fromisoformat(d["timestamp"]),
            model=d.get("model", ""),
            token_usage=TokenUsage(**raw_usage) if raw_usage else None,
        )
    raise ValueError(f"Unknown message type: {kind}")
def thread_to_dict(thread: Thread) -> dict:
    """Serialize a Thread (and all its messages) to a JSON-ready dict."""
    out = {
        "id": thread.id,
        "messages": [_msg_to_dict(m) for m in thread.messages],
        "created": thread.created.isoformat(),
        "updated": thread.updated.isoformat(),
    }
    usage = thread.token_usage
    if usage:
        out["token_usage"] = {"input_tokens": usage.input_tokens, "output_tokens": usage.output_tokens}
    return out
def thread_from_dict(d: dict) -> Thread:
    """Inverse of thread_to_dict: rebuild a Thread from its dict form."""
    raw_usage = d.get("token_usage")
    return Thread(
        id=d["id"],
        messages=[_msg_from_dict(m) for m in d["messages"]],
        created=datetime.fromisoformat(d["created"]),
        updated=datetime.fromisoformat(d["updated"]),
        token_usage=TokenUsage(**raw_usage) if raw_usage else None,
    )
def save_thread(thread: Thread, session_dir: Path) -> None:
    """Persist *thread* as <session_dir>/agent/thread.json, bumping `updated`."""
    target = session_dir / "agent" / "thread.json"
    target.parent.mkdir(parents=True, exist_ok=True)
    thread.updated = datetime.now(timezone.utc)
    payload = json.dumps(thread_to_dict(thread), indent=2, cls=_Encoder)
    target.write_text(payload)
def load_thread(session_dir: Path) -> Thread | None:
    """Load a previously saved thread; None when absent or unreadable.

    Best-effort restore: a corrupt or unparseable file is treated the
    same as a missing one.
    """
    path = session_dir / "agent" / "thread.json"
    if not path.exists():
        return None
    try:
        raw = json.loads(path.read_text())
        return thread_from_dict(raw)
    except Exception:
        return None

50
cht/agent/buffer.py Normal file
View File

@@ -0,0 +1,50 @@
"""Streaming text buffer — smooth reveal of bursty network chunks.
Accumulates text from the provider and reveals it at ~60fps with a
~200ms target reveal pace, so the UI sees a smooth stream regardless
of network burst patterns.
All methods must be called from the GTK main thread.
"""
from typing import Callable
from gi.repository import GLib
class StreamingTextBuffer:
    """Reveal bursty streamed text smoothly at ~60 fps.

    Incoming chunks are appended to an internal backlog; a GLib timer
    drains a proportional slice every tick so the whole backlog empties
    in roughly REVEAL_TARGET_MS. All methods must be called from the GTK
    main thread.
    """

    TICK_MS = 16  # ~60fps
    REVEAL_TARGET_MS = 200

    def __init__(self, on_reveal: Callable[[str], None]):
        self._backlog = ""
        self._emit = on_reveal
        self._source: int | None = None  # active GLib timeout source id

    def push(self, text: str):
        """Queue more text; start the reveal timer if it is idle."""
        self._backlog += text
        if self._source is None:
            self._source = GLib.timeout_add(self.TICK_MS, self._tick)

    def _tick(self) -> bool:
        # Timer callback: reveal a slice sized so the backlog drains in
        # about REVEAL_TARGET_MS. Returning False removes the source.
        if not self._backlog:
            self._source = None
            return False
        count = max(1, len(self._backlog) * self.TICK_MS // self.REVEAL_TARGET_MS)
        chunk = self._backlog[:count]
        self._backlog = self._backlog[count:]
        self._emit(chunk)
        return True

    def flush(self):
        """Reveal everything still pending immediately and stop the timer."""
        if self._backlog:
            self._emit(self._backlog)
            self._backlog = ""
        self._stop_timer()

    def cancel(self):
        """Drop pending text without revealing it and stop the timer."""
        self._backlog = ""
        self._stop_timer()

    def _stop_timer(self):
        # Shared cleanup for flush()/cancel().
        if self._source is not None:
            GLib.source_remove(self._source)
            self._source = None

View File

@@ -0,0 +1,206 @@
"""
AgentConnection for Claude Code SDK (claude_agent_sdk).
Uses your Claude Code subscription — no direct API costs.
Truly streams via a queue bridge between the async SDK generator
and the synchronous Iterator[StreamEvent] interface.
"""
import logging
import queue
from typing import Iterator
from cht.agent.base import (
AgentConnection,
AssistantMessage,
ImageBlock,
Message,
StreamEvent,
TextBlock,
TextDelta,
Tool,
ToolCallEnd,
ToolCallStart,
ToolResult,
ToolUse,
TranscriptBlock,
UserMessage,
Done,
Error,
)
log = logging.getLogger(__name__)
# System prompt given to every SDK query.
SYSTEM_PROMPT = """You are an assistant integrated into CHT, a screen recording and analysis tool.
You help the user understand what happened during their recording session.
You have access to frame screenshots extracted from the recording. When frames are mentioned,
use the Read tool to view them. Frame timestamps are in seconds from the start of the recording.
You also have tools to search transcripts, get session info, and capture new frames.
Be concise and specific. Focus on what's visible in the frames."""

# Models selectable via set_model(); the first entry is the default.
MODELS = [
    "claude-sonnet-4-6",
    "claude-opus-4-6",
    "claude-haiku-4-5",
]

# Queue marker telling the consumer loop that the worker thread is done.
_SENTINEL = object()
def _messages_to_prompt(messages: list[Message]) -> str:
    """Flatten structured messages into a text prompt for the SDK.

    Each message becomes one line: "User: ..."/"Assistant: ..." for chat
    turns, bracketed markers for tool calls/results. Image and transcript
    blocks are rendered as inline text references — the SDK reads frame
    files itself via the Read tool.
    """
    lines = []
    for msg in messages:
        if isinstance(msg, UserMessage):
            parts = []
            for b in msg.content:
                if isinstance(b, TextBlock):
                    parts.append(b.text)
                elif isinstance(b, ImageBlock):
                    # mm:ss from the frame timestamp (seconds from start)
                    m, s = divmod(int(b.timestamp), 60)
                    # NOTE(review): there is no separator between the time
                    # and the path — possibly a character lost in transit;
                    # confirm against the original source.
                    parts.append(f"[Frame {b.frame_id} at {m:02d}:{s:02d}{b.path}]")
                elif isinstance(b, TranscriptBlock):
                    m1, s1 = divmod(int(b.start), 60)
                    m2, s2 = divmod(int(b.end), 60)
                    parts.append(f"[Transcript {b.transcript_id} {m1:02d}:{s1:02d}-{m2:02d}:{s2:02d}: {b.text}]")
            lines.append(f"User: {' '.join(parts)}")
        elif isinstance(msg, AssistantMessage):
            # Only text blocks matter for assistant history.
            text = " ".join(b.text for b in msg.content if isinstance(b, TextBlock))
            lines.append(f"Assistant: {text}")
        elif isinstance(msg, ToolUse):
            lines.append(f"[Tool call: {msg.tool_name}({msg.input})]")
        elif isinstance(msg, ToolResult):
            out = msg.output or msg.error or ""
            lines.append(f"[Tool result: {out}]")
    return "\n".join(lines)
def _tool_schemas(tools: list[Tool]) -> list[str]:
"""Extract tool names for the SDK's allowed_tools parameter."""
# The Claude SDK uses allowed_tools as a list of tool name strings.
# Our custom tools are executed by the runner, not by the SDK,
# so we only pass "Read" to the SDK (for frame viewing).
return ["Read"]
class ClaudeSDKConnection:
    """AgentConnection using claude_agent_sdk — requires Claude Code CLI.

    Bridges the SDK's async generator onto the synchronous
    Iterator[StreamEvent] interface via a background thread and a queue.
    """

    def __init__(self, cwd: str | None = None, max_turns: int = 5, model: str = MODELS[0]):
        # cwd: working directory handed to the SDK; when None it is
        # inferred from frame paths at prompt() time, falling back to ".".
        self._cwd = cwd
        self._max_turns = max_turns
        self._model = model
        self._cancelled = False  # set by cancel(); polled by the worker

    @property
    def name(self) -> str:
        """Human-readable connection id, e.g. "claude-sdk/claude-sonnet-4-6"."""
        return f"claude-sdk/{self._model}"

    def available_models(self) -> list[str]:
        """Models this connection can be switched to (copy of MODELS)."""
        return list(MODELS)

    def get_model(self) -> str:
        return self._model

    def set_model(self, model: str) -> None:
        self._model = model

    def prompt(
        self,
        messages: list[Message],
        tools: list[Tool],
    ) -> Iterator[StreamEvent]:
        """Send *messages* to the model; yield StreamEvents as they arrive.

        The SDK is async-only, so the query runs on a daemon thread with
        its own event loop; events cross back over a thread-safe queue
        terminated by _SENTINEL. Failures are yielded as Error events,
        never raised.
        """
        # Imported lazily so the module loads even without the SDK installed.
        from claude_agent_sdk import (
            query,
            ClaudeAgentOptions,
            AssistantMessage as SDKAssistantMessage,
            TextBlock as SDKTextBlock,
            ResultMessage,
            CLINotFoundError,
            CLIConnectionError,
        )

        prompt_text = _messages_to_prompt(messages)
        self._cancelled = False
        q: queue.Queue = queue.Queue()

        # Determine cwd from the last UserMessage's image paths if available
        cwd = self._cwd
        if not cwd:
            for msg in reversed(messages):
                if isinstance(msg, UserMessage):
                    for b in msg.content:
                        if isinstance(b, ImageBlock):
                            cwd = str(b.path.parent.parent)  # session_dir
                            break
                if cwd:
                    break

        async def _run():
            # Worker coroutine: forwards SDK messages into q as StreamEvents.
            try:
                got_assistant_text = False
                async for msg in query(
                    prompt=prompt_text,
                    options=ClaudeAgentOptions(
                        model=self._model,
                        cwd=cwd or ".",
                        allowed_tools=_tool_schemas(tools),
                        system_prompt=SYSTEM_PROMPT,
                        max_turns=self._max_turns,
                    ),
                ):
                    if self._cancelled:
                        break
                    if isinstance(msg, SDKAssistantMessage):
                        for block in msg.content:
                            if isinstance(block, SDKTextBlock):
                                q.put(TextDelta(text=block.text))
                                got_assistant_text = True
                    elif isinstance(msg, ResultMessage):
                        # Only use ResultMessage.result if we got no text from AssistantMessages
                        if msg.result and not got_assistant_text:
                            q.put(TextDelta(text=msg.result))
                        q.put(Done(stop_reason="end_turn"))
            except CLINotFoundError:
                q.put(Error(
                    message="Claude Code CLI not found.\n"
                    "Install it: https://claude.ai/code\n"
                    "Then run `claude` once in a terminal to authenticate."
                ))
            except CLIConnectionError as e:
                # Auth-like failures get an actionable message; others pass through.
                if "auth" in str(e).lower() or "login" in str(e).lower() or "401" in str(e):
                    q.put(Error(
                        message="Claude Code not authenticated.\n"
                        "Run `claude` in a terminal and complete the login flow, then retry."
                    ))
                else:
                    q.put(Error(message=str(e)))
            except Exception as e:
                q.put(Error(message=str(e)))
            finally:
                # Always unblock the consumer, whatever happened above.
                q.put(_SENTINEL)

        import asyncio
        import threading

        def _thread():
            # Fresh private event loop so no loop owned by the app is touched.
            loop = asyncio.new_event_loop()
            try:
                loop.run_until_complete(_run())
            finally:
                loop.close()

        t = threading.Thread(target=_thread, daemon=True, name="claude_sdk_stream")
        t.start()
        while True:
            item = q.get()
            if item is _SENTINEL:
                break
            yield item

    def cancel(self) -> None:
        """Ask the worker to stop; takes effect at the next SDK message."""
        self._cancelled = True

View File

@@ -1,132 +0,0 @@
"""
Agent provider using the Claude Code SDK (claude_agent_sdk).
Uses your Claude Code subscription — no direct API costs.
Passes frame paths in the prompt; Claude reads them visually via the Read tool.
"""
import logging
from typing import Iterator
import anyio
from claude_agent_sdk import query, ClaudeAgentOptions, AssistantMessage, TextBlock, ResultMessage
from claude_agent_sdk import CLINotFoundError, CLIConnectionError
from cht.agent.base import AgentProvider, SessionContext
log = logging.getLogger(__name__)
SYSTEM_PROMPT = """You are an assistant integrated into CHT, a screen recording and analysis tool.
You help the user understand what happened during their recording session.
You have access to frame screenshots extracted from the recording. When frames are mentioned,
use the Read tool to view them. Frame timestamps are in seconds from the start of the recording.
Be concise and specific. Focus on what's visible in the frames."""
def _build_prompt(message: str, context: SessionContext) -> str:
    """Assemble the full text prompt for one turn.

    Layout: session summary, @-mentioned frames (with on-disk paths so
    Claude can open them with the Read tool), @-mentioned transcript
    segments, prior conversation turns, then the new user message.
    """
    lines = []
    # Session summary
    m, s = divmod(int(context.duration), 60)
    lines.append(f"Recording duration: {m:02d}:{s:02d}")
    lines.append(f"Total frames captured: {len(context.frames)}")
    if context.mentioned_frames:
        lines.append("\nFrames:")
        for f in context.mentioned_frames:
            fm, fs = divmod(int(f.timestamp), 60)
            # NOTE(review): no separator between the time and the path —
            # possibly a character lost in rendering; confirm upstream.
            lines.append(f"  {f.id} at {fm:02d}:{fs:02d}{f.path}")
    if context.mentioned_transcripts:
        lines.append("\nTranscript:")
        for t in context.mentioned_transcripts:
            tm1, ts1 = divmod(int(t.start), 60)
            tm2, ts2 = divmod(int(t.end), 60)
            lines.append(f"  {t.id} [{tm1:02d}:{ts1:02d}-{tm2:02d}:{ts2:02d}] {t.text}")
    if context.history:
        lines.append("\nConversation history:")
        for role, text in context.history:
            prefix = "User" if role == "user" else "Assistant"
            lines.append(f"  {prefix}: {text}")
    lines.append(f"\nUser message: {message}")
    return "\n".join(lines)
MODELS = [
"claude-sonnet-4-6",
"claude-opus-4-6",
"claude-haiku-4-5",
]
class ClaudeSDKProvider(AgentProvider):
    """Uses claude_agent_sdk — requires Claude Code CLI to be installed.

    Note: all SDK output is collected first and yielded only after the
    query completes, so "streaming" here is chunk-at-the-end, not live.
    """

    def __init__(self, cwd: str | None = None, max_turns: int = 5, model: str = MODELS[0]):
        # cwd: working dir for the SDK; defaults to the session dir at stream() time
        self._cwd = cwd
        self._max_turns = max_turns
        self._model = model

    @property
    def name(self) -> str:
        return f"claude-sdk/{self._model}"

    @property
    def available_models(self) -> list[str]:
        return list(MODELS)

    @property
    def model(self) -> str:
        return self._model

    @model.setter
    def model(self, value: str):
        self._model = value

    def stream(self, message: str, context: SessionContext) -> Iterator[str]:
        """Yield response text chunks (buffered until the SDK call finishes).

        Raises RuntimeError with an actionable message when the Claude CLI
        is missing or unauthenticated.
        """
        prompt = _build_prompt(message, context)
        chunks = []

        async def _run():
            # Collect assistant text; also append ResultMessage.result if set.
            async for msg in query(
                prompt=prompt,
                options=ClaudeAgentOptions(
                    model=self._model,
                    cwd=self._cwd or str(context.session_dir),
                    allowed_tools=["Read"],
                    system_prompt=SYSTEM_PROMPT,
                    max_turns=self._max_turns,
                ),
            ):
                if isinstance(msg, AssistantMessage):
                    for block in msg.content:
                        if isinstance(block, TextBlock):
                            chunks.append(block.text)
                elif isinstance(msg, ResultMessage):
                    if msg.result:
                        chunks.append(msg.result)

        try:
            import asyncio
            # Fresh private event loop for the async SDK query.
            loop = asyncio.new_event_loop()
            try:
                loop.run_until_complete(_run())
            finally:
                loop.close()
        except CLINotFoundError:
            raise RuntimeError(
                "Claude Code CLI not found.\n"
                "Install it: https://claude.ai/code\n"
                "Then run `claude` once in a terminal to authenticate."
            )
        except CLIConnectionError as e:
            # Auth-like failures get an actionable message; others re-raise as-is.
            if "auth" in str(e).lower() or "login" in str(e).lower() or "401" in str(e):
                raise RuntimeError(
                    "Claude Code not authenticated.\n"
                    "Run `claude` in a terminal and complete the login flow, then retry."
                )
            raise
        yield from chunks

View File

@@ -1,134 +0,0 @@
"""
Agent provider for OpenAI-compatible APIs (Groq, OpenAI, etc.).
Sends frame images as base64. Requires GROQ_API_KEY or OPENAI_API_KEY env var.
Auto-detects provider from available env keys.
"""
import base64
import logging
import os
from typing import Iterator
from cht.agent.base import AgentProvider, SessionContext, FrameRef
log = logging.getLogger(__name__)
SYSTEM_PROMPT = """You are an assistant integrated into CHT, a screen recording and analysis tool.
You help the user understand what happened during their recording session.
Be concise and specific. Focus on what's visible in the provided frames."""
# Provider configs: (base_url, default_model, available_models)
_PROVIDER_CONFIGS = {
"groq": (
"https://api.groq.com/openai/v1",
"meta-llama/llama-4-maverick-17b-128e-instruct",
[
"meta-llama/llama-4-maverick-17b-128e-instruct",
"meta-llama/llama-4-scout-17b-16e-instruct",
"qwen/qwen-2.5-vl-72b-instruct",
],
),
"openai": (
"https://api.openai.com/v1",
"gpt-4o",
["gpt-4o", "gpt-4o-mini", "gpt-4.1", "gpt-4.1-mini"],
),
}
def _detect_provider() -> tuple[str, str, str, list[str]] | None:
    """Returns (api_key, base_url, model, available_models) or None.

    Priority: GROQ_API_KEY first, then OPENAI_API_KEY. The CHT_MODEL env
    var overrides the provider's default model.
    """
    if key := os.environ.get("GROQ_API_KEY"):
        base_url, default_model, models = _PROVIDER_CONFIGS["groq"]
        model = os.environ.get("CHT_MODEL", default_model)
        return key, base_url, model, models
    if key := os.environ.get("OPENAI_API_KEY"):
        base_url, default_model, models = _PROVIDER_CONFIGS["openai"]
        model = os.environ.get("CHT_MODEL", default_model)
        return key, base_url, model, models
    return None
def _frame_to_image_content(frame: FrameRef) -> dict:
with open(frame.path, "rb") as f:
data = base64.standard_b64encode(f.read()).decode()
return {
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{data}"},
}
class OpenAICompatProvider(AgentProvider):
    """Uses any OpenAI-compatible API. Auto-detects from env vars."""

    def __init__(self):
        # Provider choice (key, URL, model list) comes from the environment.
        detected = _detect_provider()
        if not detected:
            raise RuntimeError(
                "No API key found. Set GROQ_API_KEY or OPENAI_API_KEY."
            )
        self._api_key, self._base_url, self._model, self._models = detected

    @property
    def name(self) -> str:
        # Provider label derived from the base URL
        if "groq" in self._base_url:
            return f"groq/{self._model}"
        return f"openai-compat/{self._model}"

    @property
    def available_models(self) -> list[str]:
        return list(self._models)

    @property
    def model(self) -> str:
        return self._model

    @model.setter
    def model(self, value: str):
        self._model = value

    def stream(self, message: str, context: SessionContext) -> Iterator[str]:
        """Yield response text deltas from a streaming chat completion.

        Mentioned frames are inlined as base64 images; transcript excerpts
        and a short session summary are prepended as text context.
        """
        from openai import OpenAI

        client = OpenAI(api_key=self._api_key, base_url=self._base_url)
        # Build context header
        m, s = divmod(int(context.duration), 60)
        ctx_lines = [
            f"Recording duration: {m:02d}:{s:02d}",
            f"Total frames: {len(context.frames)}",
        ]
        if context.mentioned_transcripts:
            ctx_lines.append("\nTranscript:")
            for t in context.mentioned_transcripts:
                tm1, ts1 = divmod(int(t.start), 60)
                tm2, ts2 = divmod(int(t.end), 60)
                ctx_lines.append(f"  {t.id} [{tm1:02d}:{ts1:02d}-{tm2:02d}:{ts2:02d}] {t.text}")
        ctx_text = "\n".join(ctx_lines) + "\n"
        frames_to_send = context.mentioned_frames
        content: list[dict] = [{"type": "text", "text": ctx_text + message}]
        for frame in frames_to_send:
            # Label each image with its frame id and mm:ss timestamp.
            fm, fs = divmod(int(frame.timestamp), 60)
            content.append({"type": "text", "text": f"{frame.id} at {fm:02d}:{fs:02d}:"})
            try:
                content.append(_frame_to_image_content(frame))
            except Exception as e:
                # One unreadable frame should not abort the whole request.
                log.warning("Could not encode frame %s: %s", frame.id, e)
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]
        for role, text in context.history:
            messages.append({"role": role, "content": text})
        messages.append({"role": "user", "content": content})
        stream = client.chat.completions.create(
            model=self._model,
            messages=messages,
            stream=True,
        )
        for chunk in stream:
            delta = chunk.choices[0].delta.content
            if delta:
                yield delta

View File

@@ -0,0 +1,249 @@
"""
AgentConnection for OpenAI-compatible APIs (Groq, OpenAI, etc.).
Sends frame images as base64. Supports tool calls via function calling.
Requires GROQ_API_KEY or OPENAI_API_KEY env var.
"""
import base64
import json
import logging
import os
from typing import Iterator
from cht.agent.base import (
AssistantMessage,
ImageBlock,
Message,
StreamEvent,
TextBlock,
TextDelta,
Tool,
ToolCallEnd,
ToolCallStart,
ToolResult,
ToolUse,
TranscriptBlock,
UserMessage,
Done,
Error,
)
log = logging.getLogger(__name__)
# System prompt prepended to every request by _messages_to_openai().
SYSTEM_PROMPT = """You are an assistant integrated into CHT, a screen recording and analysis tool.
You help the user understand what happened during their recording session.
Be concise and specific. Focus on what's visible in the provided frames."""

# Per-provider config: (base_url, default_model, selectable models).
# Which provider is used is decided by _detect_provider() from env keys.
_PROVIDER_CONFIGS = {
    "groq": (
        "https://api.groq.com/openai/v1",
        "meta-llama/llama-4-maverick-17b-128e-instruct",
        [
            "meta-llama/llama-4-maverick-17b-128e-instruct",
            "meta-llama/llama-4-scout-17b-16e-instruct",
            "qwen/qwen-2.5-vl-72b-instruct",
        ],
    ),
    "openai": (
        "https://api.openai.com/v1",
        "gpt-4o",
        ["gpt-4o", "gpt-4o-mini", "gpt-4.1", "gpt-4.1-mini"],
    ),
}
def _detect_provider() -> tuple[str, str, str, list[str]] | None:
if key := os.environ.get("GROQ_API_KEY"):
base_url, default_model, models = _PROVIDER_CONFIGS["groq"]
model = os.environ.get("CHT_MODEL", default_model)
return key, base_url, model, models
if key := os.environ.get("OPENAI_API_KEY"):
base_url, default_model, models = _PROVIDER_CONFIGS["openai"]
model = os.environ.get("CHT_MODEL", default_model)
return key, base_url, model, models
return None
def _frame_to_base64(path) -> str | None:
try:
with open(path, "rb") as f:
return base64.standard_b64encode(f.read()).decode()
except Exception as e:
log.warning("Could not encode frame %s: %s", path, e)
return None
def _messages_to_openai(messages: list[Message]) -> list[dict]:
    """Convert structured messages to OpenAI chat format.

    Always prepends the system prompt. Image blocks become base64
    image_url parts (skipped when the file can't be read); transcript
    blocks become labelled text parts. ToolUse/ToolResult map onto the
    assistant `tool_calls` / role="tool" convention.
    """
    result = [{"role": "system", "content": SYSTEM_PROMPT}]
    for msg in messages:
        if isinstance(msg, UserMessage):
            content: list[dict] = []
            for b in msg.content:
                if isinstance(b, TextBlock):
                    content.append({"type": "text", "text": b.text})
                elif isinstance(b, ImageBlock):
                    # Label each image with its frame id and mm:ss timestamp.
                    m, s = divmod(int(b.timestamp), 60)
                    content.append({"type": "text", "text": f"{b.frame_id} at {m:02d}:{s:02d}:"})
                    data = _frame_to_base64(b.path)
                    if data:
                        content.append({
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{data}"},
                        })
                elif isinstance(b, TranscriptBlock):
                    m1, s1 = divmod(int(b.start), 60)
                    m2, s2 = divmod(int(b.end), 60)
                    content.append({
                        "type": "text",
                        "text": f"{b.transcript_id} [{m1:02d}:{s1:02d}-{m2:02d}:{s2:02d}] {b.text}",
                    })
            result.append({"role": "user", "content": content})
        elif isinstance(msg, AssistantMessage):
            # Only text blocks are replayed as assistant history.
            text = " ".join(b.text for b in msg.content if isinstance(b, TextBlock))
            result.append({"role": "assistant", "content": text})
        elif isinstance(msg, ToolUse):
            result.append({
                "role": "assistant",
                "content": None,
                "tool_calls": [{
                    "id": msg.id,
                    "type": "function",
                    "function": {"name": msg.tool_name, "arguments": json.dumps(msg.input)},
                }],
            })
        elif isinstance(msg, ToolResult):
            result.append({
                "role": "tool",
                "tool_call_id": msg.tool_use_id,
                "content": msg.output or msg.error or "",
            })
    return result
def _tools_to_openai(tools: list[Tool]) -> list[dict] | None:
if not tools:
return None
return [
{
"type": "function",
"function": {
"name": t.name,
"description": t.description,
"parameters": t.input_schema(),
},
}
for t in tools
]
class OpenAIConnection:
    """AgentConnection using any OpenAI-compatible API.

    Streams chat completions; tool calls arrive as incremental deltas and
    are accumulated until finish_reason == "tool_calls".
    """

    def __init__(self):
        # Provider choice (key, URL, model list) comes from the environment.
        detected = _detect_provider()
        if not detected:
            raise RuntimeError("No API key found. Set GROQ_API_KEY or OPENAI_API_KEY.")
        self._api_key, self._base_url, self._model, self._models = detected
        self._cancelled = False  # set by cancel(); polled per streamed chunk

    @property
    def name(self) -> str:
        # Provider label derived from the base URL
        if "groq" in self._base_url:
            return f"groq/{self._model}"
        return f"openai-compat/{self._model}"

    def available_models(self) -> list[str]:
        return list(self._models)

    def get_model(self) -> str:
        return self._model

    def set_model(self, model: str) -> None:
        self._model = model

    def prompt(
        self,
        messages: list[Message],
        tools: list[Tool],
    ) -> Iterator[StreamEvent]:
        """Send *messages* (plus optional *tools*) and yield StreamEvents.

        Request failures are yielded as Error events rather than raised.
        """
        from openai import OpenAI

        client = OpenAI(api_key=self._api_key, base_url=self._base_url)
        self._cancelled = False
        oai_messages = _messages_to_openai(messages)
        oai_tools = _tools_to_openai(tools)
        kwargs = {
            "model": self._model,
            "messages": oai_messages,
            "stream": True,
        }
        if oai_tools:
            kwargs["tools"] = oai_tools
        try:
            stream = client.chat.completions.create(**kwargs)
        except Exception as e:
            yield Error(message=str(e))
            return
        # Accumulate tool calls from streaming deltas
        tool_calls: dict[int, dict] = {}  # index → {id, name, arguments}
        for chunk in stream:
            if self._cancelled:
                break
            choice = chunk.choices[0] if chunk.choices else None
            if not choice:
                continue
            delta = choice.delta
            # Text content
            if delta.content:
                yield TextDelta(text=delta.content)
            # Tool calls (streamed incrementally)
            if delta.tool_calls:
                for tc_delta in delta.tool_calls:
                    idx = tc_delta.index
                    if idx not in tool_calls:
                        tool_calls[idx] = {"id": "", "name": "", "arguments": ""}
                    if tc_delta.id:
                        tool_calls[idx]["id"] = tc_delta.id
                    if tc_delta.function:
                        if tc_delta.function.name:
                            tool_calls[idx]["name"] = tc_delta.function.name
                        if tc_delta.function.arguments:
                            # Arguments arrive as JSON fragments; concatenate.
                            tool_calls[idx]["arguments"] += tc_delta.function.arguments
            # Check finish reason
            if choice.finish_reason:
                if choice.finish_reason == "tool_calls":
                    # Emit accumulated tool calls
                    for idx in sorted(tool_calls.keys()):
                        tc = tool_calls[idx]
                        try:
                            inp = json.loads(tc["arguments"]) if tc["arguments"] else {}
                        except json.JSONDecodeError:
                            # Malformed arguments → empty input rather than a crash.
                            inp = {}
                        yield ToolCallStart(id=tc["id"], name=tc["name"], input=inp)
                        yield ToolCallEnd(id=tc["id"])
                    yield Done(stop_reason="tool_use")
                elif choice.finish_reason == "stop":
                    yield Done(stop_reason="end_turn")
                elif choice.finish_reason == "length":
                    yield Done(stop_reason="max_tokens")
                else:
                    # Pass through any other provider-specific reason verbatim.
                    yield Done(stop_reason=choice.finish_reason)

    def cancel(self) -> None:
        """Stop streaming after the current chunk."""
        self._cancelled = True

View File

@@ -1,13 +1,13 @@
"""
Agent runner — resolves provider, parses @-mentions, dispatches messages.
Agent runner — resolves connection, parses @-mentions, dispatches messages,
executes tool loop.
Provider selection (in order):
Connection selection (in order):
1. GROQ_API_KEY → OpenAI-compat / Groq
2. OPENAI_API_KEY → OpenAI-compat / OpenAI
3. (default) → Claude Code SDK (uses CC subscription)
"""
import json
import logging
import os
import re
@@ -15,7 +15,29 @@ from pathlib import Path
from threading import Thread
from typing import Callable
from cht.agent.base import AgentProvider, FrameRef, TranscriptRef, SessionContext
from cht.agent.base import (
AssistantMessage,
FrameRef,
ImageBlock,
Message,
StreamEvent,
TextBlock,
TextDelta,
ToolCallEnd,
ToolCallStart,
ToolContext,
ToolResult,
ToolUse,
TranscriptBlock,
TranscriptRef,
UserMessage,
Done,
Error,
Thread as AgentThread,
save_thread,
load_thread,
)
from cht.agent.tools import load_frames, load_transcript, BUILTIN_TOOLS
log = logging.getLogger(__name__)
@@ -25,6 +47,8 @@ ACTIONS: dict[str, str] = {
"Answer": "answer",
}
MAX_TOOL_TURNS = 10
def check_claude_cli() -> str | None:
"""Returns None if OK, or an error string if CLI is missing/unauthenticated."""
@@ -43,16 +67,15 @@ def check_claude_cli() -> str | None:
return None
def _resolve_provider() -> AgentProvider:
def _resolve_connection():
if os.environ.get("GROQ_API_KEY") or os.environ.get("OPENAI_API_KEY"):
from cht.agent.openai_compat_provider import OpenAICompatProvider
return OpenAICompatProvider()
from cht.agent.claude_sdk_provider import ClaudeSDKProvider
return ClaudeSDKProvider()
from cht.agent.openai_connection import OpenAIConnection
return OpenAIConnection()
from cht.agent.claude_sdk_connection import ClaudeSDKConnection
return ClaudeSDKConnection()
def _expand_ref_nums(spec: str) -> list[int]:
"""Expand a ref spec like '2-6' or '2,4,6' or '2-4,6,8-10' into sorted ints."""
nums = set()
for part in spec.split(","):
part = part.strip()
@@ -71,7 +94,6 @@ def _expand_ref_nums(spec: str) -> list[int]:
def _parse_mentions(message: str, frames: list[FrameRef]) -> list[FrameRef]:
"""Extract @F references. Accepts @F1, @F2-6, @F2,4,6, @F2-4,6,8-10."""
mentioned = []
seen = set()
for match in re.finditer(r"@[Ff]([\d,\-]+)", message):
@@ -85,49 +107,7 @@ def _parse_mentions(message: str, frames: list[FrameRef]) -> list[FrameRef]:
return mentioned
def _resolve_frame_path(frames_dir: Path, raw_path: str) -> Path | None:
"""Resolve a frame path from index.json, handling mounted/remote sessions."""
p = Path(raw_path)
if p.exists():
return p
# Try relative to frames_dir (handles path prefix mismatch from remote)
local = frames_dir / p.name
if local.exists():
return local
return None
def _load_frames(frames_dir: Path) -> list[FrameRef]:
    """Load frame refs from frames_dir/index.json.

    Entries whose image file cannot be located are silently skipped; any
    parse or I/O error is logged and yields an empty list.
    """
    index_path = frames_dir / "index.json"
    if not index_path.exists():
        return []
    try:
        refs: list[FrameRef] = []
        for entry in json.loads(index_path.read_text()):
            path = _resolve_frame_path(frames_dir, entry["path"])
            if path:
                refs.append(FrameRef(id=entry["id"], path=path, timestamp=entry["timestamp"]))
        return refs
    except Exception as exc:
        log.warning("Could not load frames index: %s", exc)
        return []
def _load_transcript(transcript_dir: Path) -> list[TranscriptRef]:
    """Load transcript segments from transcript_dir/index.json.

    Returns [] when the index is missing or unreadable (failure is logged).
    """
    index_path = transcript_dir / "index.json"
    if not index_path.exists():
        return []
    try:
        raw = index_path.read_text()
        return [TranscriptRef(**entry) for entry in json.loads(raw)]
    except Exception as exc:
        log.warning("Could not load transcript index: %s", exc)
        return []
def _parse_transcript_mentions(message: str, segments: list[TranscriptRef]) -> list[TranscriptRef]:
"""Extract @T references. Accepts @T1, @T2-6, @T2,4,6, @T1-3,5,7-10."""
mentioned = []
seen = set()
for match in re.finditer(r"@[Tt]([\d,\-]+)", message):
@@ -141,84 +121,191 @@ def _parse_transcript_mentions(message: str, segments: list[TranscriptRef]) -> l
return mentioned
def _build_user_message(text: str, mentioned_frames: list[FrameRef],
                        mentioned_transcripts: list[TranscriptRef]) -> UserMessage:
    """Assemble a UserMessage: the raw text first, then one content block
    per @-mentioned frame and transcript segment."""
    blocks: list = [TextBlock(text=text)]
    blocks.extend(
        ImageBlock(frame_id=f.id, path=f.path, timestamp=f.timestamp)
        for f in mentioned_frames
    )
    blocks.extend(
        TranscriptBlock(transcript_id=t.id, start=t.start, end=t.end, text=t.text)
        for t in mentioned_transcripts
    )
    return UserMessage(content=blocks)
class AgentRunner:
    """Runs agent queries in a background thread with tool execution loop.

    Owns the conversation Thread, lazily resolves the backend connection,
    and drives the prompt → tool-call → tool-result loop until the model
    stops asking for tools (or MAX_TOOL_TURNS is reached).

    Note: the flattened diff left stale pre-refactor lines (old
    ``_provider``/``_history`` fields, old ``SessionContext`` kwargs, the
    old ``provider.stream`` loop) interleaved with the new code, plus a
    dead ``tool = tools_by_name.get(event.id)`` lookup (wrong key, result
    unused). Those are removed here.
    """

    def __init__(self):
        self._connection = None  # lazily resolved via _get_connection()
        self._thread: AgentThread = AgentThread()
        self.include_history = False  # toggled by UI: send full thread vs last msg

    @property
    def thread(self) -> AgentThread:
        return self._thread

    def _get_connection(self):
        """Resolve and cache the backend connection on first use."""
        if self._connection is None:
            self._connection = _resolve_connection()
            log.info("Agent connection: %s", self._connection.name)
        return self._connection

    @property
    def provider_name(self) -> str:
        """Connection name, or 'unknown' if resolution fails."""
        try:
            return self._get_connection().name
        except Exception:
            return "unknown"

    @property
    def available_models(self) -> list[str]:
        """Models offered by the connection; [] if resolution fails."""
        try:
            return self._get_connection().available_models()
        except Exception:
            return []

    @property
    def model(self) -> str:
        """Currently selected model name; '' if resolution fails."""
        try:
            return self._get_connection().get_model()
        except Exception:
            return ""

    @model.setter
    def model(self, value: str):
        self._get_connection().set_model(value)

    def clear_history(self):
        """Drop the current conversation and start a fresh, empty thread."""
        self._thread = AgentThread()

    def set_thread(self, thread: AgentThread):
        self._thread = thread

    def load_from_session(self, session_dir: Path):
        """Restore a persisted thread from *session_dir*, or start fresh."""
        loaded = load_thread(session_dir)
        if loaded:
            self._thread = loaded
            log.info("Loaded thread %s with %d messages", loaded.id, len(loaded.messages))
        else:
            self._thread = AgentThread()

    def send(
        self,
        message: str,
        stream_mgr,
        tracker,
        on_event: Callable[[StreamEvent], None],
        on_done: Callable[[str | None], None],
    ):
        """Dispatch message in a background thread with tool execution loop.

        on_event(StreamEvent) — called for each stream event (from bg thread)
        on_done(error_or_None) — called when complete (from bg thread)
        """
        def _run():
            try:
                connection = self._get_connection()

                # Resolve @F / @T mentions against current session data.
                frames = load_frames(stream_mgr.frames_dir)
                mentioned_frames = _parse_mentions(message, frames)
                transcript = load_transcript(stream_mgr.transcript_dir)
                mentioned_transcripts = _parse_transcript_mentions(message, transcript)

                # Build and append user message
                user_msg = _build_user_message(message, mentioned_frames, mentioned_transcripts)
                self._thread.messages.append(user_msg)

                # Context handed to tool implementations.
                tool_ctx = ToolContext(
                    session_dir=stream_mgr.session_dir,
                    frames_dir=stream_mgr.frames_dir,
                    transcript_dir=stream_mgr.transcript_dir,
                    stream_mgr=stream_mgr,
                    tracker=tracker,
                )

                # Tool registry, keyed by tool name for dispatch.
                tools_by_name = {t.name: t for t in BUILTIN_TOOLS}

                # Messages to send — full thread or just last message
                def _get_messages():
                    if self.include_history:
                        return list(self._thread.messages)
                    return [self._thread.messages[-1]]

                # Tool execution loop: keep prompting while the model stops
                # with "tool_use", bounded by MAX_TOOL_TURNS.
                full_text_parts = []
                for _turn in range(MAX_TOOL_TURNS):
                    msgs = _get_messages()
                    stop_reason = None
                    for event in connection.prompt(msgs, BUILTIN_TOOLS):
                        on_event(event)
                        if isinstance(event, TextDelta):
                            full_text_parts.append(event.text)
                        elif isinstance(event, ToolCallStart):
                            self._thread.messages.append(ToolUse(
                                id=event.id,
                                tool_name=event.name,
                                input=event.input,
                                status="running",
                            ))
                        elif isinstance(event, ToolCallEnd):
                            # Find the ToolUse recorded at ToolCallStart by id,
                            # then execute the matching tool implementation.
                            tool_use_msg = None
                            for m in reversed(self._thread.messages):
                                if isinstance(m, ToolUse) and m.id == event.id:
                                    tool_use_msg = m
                                    break
                            if tool_use_msg:
                                tool_impl = tools_by_name.get(tool_use_msg.tool_name)
                                if tool_impl:
                                    result = tool_impl.run(tool_use_msg.input, tool_ctx)
                                    result.tool_use_id = tool_use_msg.id
                                    tool_use_msg.status = "error" if result.error else "done"
                                else:
                                    result = ToolResult(
                                        tool_use_id=tool_use_msg.id,
                                        error=f"Unknown tool: {tool_use_msg.tool_name}",
                                    )
                                    tool_use_msg.status = "error"
                                self._thread.messages.append(result)
                                on_event(result)
                        elif isinstance(event, Done):
                            stop_reason = event.stop_reason
                            break
                        elif isinstance(event, Error):
                            on_done(event.message)
                            return

                    # Build assistant message from accumulated text
                    if full_text_parts:
                        self._thread.messages.append(AssistantMessage(
                            content=[TextBlock(text="".join(full_text_parts))],
                            model=connection.get_model(),
                        ))
                    if stop_reason != "tool_use":
                        break
                    # Reset text for next turn (tool loop continues)
                    full_text_parts = []

                # Persist the thread alongside the session data.
                save_thread(self._thread, stream_mgr.session_dir)
                on_done(None)
            except Exception as e:
                log.error("Agent error: %s", e)
                on_done(str(e))

        # NOTE(review): the diff elides the unchanged tail of this method;
        # per the docstring ("background thread") and the file's
        # `from threading import Thread` import, the worker is started here.
        Thread(target=_run, daemon=True).start()

201
cht/agent/tools.py Normal file
View File

@@ -0,0 +1,201 @@
"""Built-in agent tools — ReadFrame, SearchTranscript, GetSessionInfo, CaptureFrame.
Also contains shared data-loading functions (moved from runner.py).
"""
import json
import logging
from pathlib import Path
from cht.agent.base import FrameRef, TranscriptRef, ToolContext, ToolResult
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Data loading (shared by tools and runner)
# ---------------------------------------------------------------------------
def _resolve_frame_path(frames_dir: Path, raw_path: str) -> Path | None:
p = Path(raw_path)
if p.exists():
return p
local = frames_dir / p.name
if local.exists():
return local
return None
def load_frames(frames_dir: Path) -> list[FrameRef]:
    """Read frame metadata from frames_dir/index.json.

    Entries whose image file cannot be located are skipped; a missing index
    or any parse/I-O error yields an empty list (errors are logged).
    """
    index_path = frames_dir / "index.json"
    if not index_path.exists():
        return []
    try:
        result: list[FrameRef] = []
        for entry in json.loads(index_path.read_text()):
            path = _resolve_frame_path(frames_dir, entry["path"])
            if path:
                result.append(FrameRef(id=entry["id"], path=path, timestamp=entry["timestamp"]))
        return result
    except Exception as exc:
        log.warning("Could not load frames index: %s", exc)
        return []
def load_transcript(transcript_dir: Path) -> list[TranscriptRef]:
    """Read transcript segments from transcript_dir/index.json.

    Returns [] when the index is absent or unreadable (failure is logged).
    """
    index_path = transcript_dir / "index.json"
    if not index_path.exists():
        return []
    try:
        entries = json.loads(index_path.read_text())
    except Exception as exc:
        log.warning("Could not load transcript index: %s", exc)
        return []
    return [TranscriptRef(**entry) for entry in entries]
# ---------------------------------------------------------------------------
# Tool implementations
# ---------------------------------------------------------------------------
class ReadFrameTool:
    """Look up frame screenshots by ID and report their timestamps and paths."""

    name = "read_frame"
    description = "Read frame screenshots by ID. Returns file paths for visual inspection."

    def input_schema(self) -> dict:
        """JSON schema for the tool input: a required list of frame IDs."""
        return {
            "type": "object",
            "properties": {
                "frame_ids": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Frame IDs like ['F0001', 'F0003']",
                }
            },
            "required": ["frame_ids"],
        }

    def run(self, input: dict, context: ToolContext) -> ToolResult:
        """One output line per requested ID: 'F0001 at MM:SS — /path', or
        '<id>: not found' for unknown IDs.

        tool_use_id is left empty; the runner fills it in after execution.
        """
        frame_ids = input.get("frame_ids", [])
        by_id = {f.id: f for f in load_frames(context.frames_dir)}
        lines = []
        for fid in frame_ids:
            f = by_id.get(fid)
            if f:
                m, s = divmod(int(f.timestamp), 60)
                # Fix: there was no separator between the timestamp and the
                # path ("01:23/path"), producing unreadable output.
                lines.append(f"{f.id} at {m:02d}:{s:02d} — {f.path}")
            else:
                lines.append(f"{fid}: not found")
        return ToolResult(tool_use_id="", output="\n".join(lines))
class SearchTranscriptTool:
    """Filter transcript segments by text substring and/or time window."""

    name = "search_transcript"
    description = "Search transcript segments by text substring and/or time range."

    def input_schema(self) -> dict:
        """JSON schema: all three filters (query, start, end) are optional."""
        return {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "Text substring to search for (optional)"},
                "start": {"type": "number", "description": "Start time in seconds (optional)"},
                "end": {"type": "number", "description": "End time in seconds (optional)"},
            },
        }

    def run(self, input: dict, context: ToolContext) -> ToolResult:
        """Return matching segments as 'Txxxx [MM:SS-MM:SS] text' lines.

        tool_use_id is left empty; the runner fills it in after execution.
        """
        query = input.get("query", "").lower()
        start = input.get("start")
        end = input.get("end")

        def _keep(seg) -> bool:
            # Segment survives when it overlaps [start, end] and contains query.
            if start is not None and seg.end < start:
                return False
            if end is not None and seg.start > end:
                return False
            return not query or query in seg.text.lower()

        matches = [s for s in load_transcript(context.transcript_dir) if _keep(s)]
        if not matches:
            return ToolResult(tool_use_id="", output="No matching transcript segments found.")

        lines = []
        for seg in matches:
            m1, s1 = divmod(int(seg.start), 60)
            m2, s2 = divmod(int(seg.end), 60)
            lines.append(f"{seg.id} [{m1:02d}:{s1:02d}-{m2:02d}:{s2:02d}] {seg.text}")
        return ToolResult(tool_use_id="", output="\n".join(lines))
class GetSessionInfoTool:
    """Summarize the session: duration, frame/segment counts, recording files."""

    name = "get_session_info"
    description = "Get recording session information: duration, frame count, segment list."

    def input_schema(self) -> dict:
        """JSON schema: this tool takes no input."""
        return {"type": "object", "properties": {}}

    def run(self, input: dict, context: ToolContext) -> ToolResult:
        """Build a multi-line summary of the session.

        tool_use_id is left empty; the runner fills it in after execution.
        """
        duration = getattr(context.tracker, "duration", 0.0) if context.tracker else 0.0
        minutes, seconds = divmod(int(duration), 60)

        lines = [
            f"Recording duration: {minutes:02d}:{seconds:02d}",
            f"Frames captured: {len(load_frames(context.frames_dir))}",
            f"Transcript segments: {len(load_transcript(context.transcript_dir))}",
        ]

        # List recording segments from session dir
        stream_dir = context.session_dir / "stream"
        if stream_dir.exists():
            recordings = sorted(stream_dir.glob("recording_*.mp4"))
            lines.append(f"Recording files: {len(recordings)}")
            lines.extend(f"  {rec.name}" for rec in recordings)

        return ToolResult(tool_use_id="", output="\n".join(lines))
class CaptureFrameTool:
    """Trigger an on-demand frame capture on the active stream manager."""

    name = "capture_frame"
    description = "Capture a frame at the current recording position."

    def input_schema(self) -> dict:
        """JSON schema: this tool takes no input."""
        return {"type": "object", "properties": {}}

    def run(self, input: dict, context: ToolContext) -> ToolResult:
        """Request a capture and block (max 10 s) until its callback fires.

        Errors out when no stream manager is attached, the session is
        read-only, or the capture callback never fires within the timeout.
        tool_use_id is left empty; the runner fills it in after execution.
        """
        mgr = context.stream_mgr
        if mgr is None:
            return ToolResult(tool_use_id="", error="No active stream manager")
        if getattr(mgr, "readonly", False):
            return ToolResult(tool_use_id="", error="Session is read-only, cannot capture")

        import threading
        outcome = {"done": False, "error": None}
        finished = threading.Event()

        def _on_frames(frames):
            # Fired by the stream manager once the new frames are available.
            outcome["done"] = True
            finished.set()

        try:
            mgr.capture_now(on_new_frames=_on_frames)
            finished.wait(timeout=10)
            if not outcome["done"]:
                return ToolResult(tool_use_id="", error="Capture timed out")
            return ToolResult(tool_use_id="", output="Frame captured successfully.")
        except Exception as e:
            return ToolResult(tool_use_id="", error=str(e))
# All built-in tools — passed to connection.prompt() by the runner, which
# dispatches completed tool calls to these instances by their .name.
BUILTIN_TOOLS = [ReadFrameTool(), SearchTranscriptTool(), GetSessionInfoTool(), CaptureFrameTool()]