Files
mitus/cht/agent/claude_sdk_connection.py
2026-04-09 14:46:29 -03:00

207 lines
6.9 KiB
Python

"""
AgentConnection for Claude Code SDK (claude_agent_sdk).
Uses your Claude Code subscription — no direct API costs.
Truly streams via a queue bridge between the async SDK generator
and the synchronous Iterator[StreamEvent] interface.
"""
import logging
import queue
from typing import Iterator
from cht.agent.base import (
AgentConnection,
AssistantMessage,
ImageBlock,
Message,
StreamEvent,
TextBlock,
TextDelta,
Tool,
ToolCallEnd,
ToolCallStart,
ToolResult,
ToolUse,
TranscriptBlock,
UserMessage,
Done,
Error,
)
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# System prompt passed to the SDK as ClaudeAgentOptions(system_prompt=...).
# It tells the model that frames are viewed through the Read tool, which is
# the only tool name _tool_schemas() hands to the SDK.
SYSTEM_PROMPT = """You are an assistant integrated into CHT, a screen recording and analysis tool.
You help the user understand what happened during their recording session.
You have access to frame screenshots extracted from the recording. When frames are mentioned,
use the Read tool to view them. Frame timestamps are in seconds from the start of the recording.
You also have tools to search transcripts, get session info, and capture new frames.
Be concise and specific. Focus on what's visible in the frames."""

# Model identifiers returned by ClaudeSDKConnection.available_models().
# MODELS[0] is the default for the connection's `model` constructor argument.
MODELS = [
    "claude-sonnet-4-6",
    "claude-opus-4-6",
    "claude-haiku-4-5",
]

# Unique end-of-stream marker: the background worker puts it on the queue in
# its `finally`, and the consuming loop in prompt() stops when it sees it.
# Compared by identity, so it can never collide with a real stream event.
_SENTINEL = object()
def _messages_to_prompt(messages: list[Message]) -> str:
"""Flatten structured messages into a text prompt for the SDK."""
lines = []
for msg in messages:
if isinstance(msg, UserMessage):
parts = []
for b in msg.content:
if isinstance(b, TextBlock):
parts.append(b.text)
elif isinstance(b, ImageBlock):
m, s = divmod(int(b.timestamp), 60)
parts.append(f"[Frame {b.frame_id} at {m:02d}:{s:02d}{b.path}]")
elif isinstance(b, TranscriptBlock):
m1, s1 = divmod(int(b.start), 60)
m2, s2 = divmod(int(b.end), 60)
parts.append(f"[Transcript {b.transcript_id} {m1:02d}:{s1:02d}-{m2:02d}:{s2:02d}: {b.text}]")
lines.append(f"User: {' '.join(parts)}")
elif isinstance(msg, AssistantMessage):
text = " ".join(b.text for b in msg.content if isinstance(b, TextBlock))
lines.append(f"Assistant: {text}")
elif isinstance(msg, ToolUse):
lines.append(f"[Tool call: {msg.tool_name}({msg.input})]")
elif isinstance(msg, ToolResult):
out = msg.output or msg.error or ""
lines.append(f"[Tool result: {out}]")
return "\n".join(lines)
def _tool_schemas(tools: list[Tool]) -> list[str]:
"""Extract tool names for the SDK's allowed_tools parameter."""
# The Claude SDK uses allowed_tools as a list of tool name strings.
# Our custom tools are executed by the runner, not by the SDK,
# so we only pass "Read" to the SDK (for frame viewing).
return ["Read"]
class ClaudeSDKConnection:
    """AgentConnection using claude_agent_sdk — requires Claude Code CLI.

    The SDK exposes an async generator, while callers consume a synchronous
    iterator of stream events. ``prompt`` bridges the two by running the SDK
    query on a background thread with its own event loop and handing events
    over through a ``queue.Queue``.
    """

    def __init__(self, cwd: str | None = None, max_turns: int = 5, model: str = MODELS[0]):
        """Store the connection settings; nothing is started until ``prompt``.

        Args:
            cwd: Working directory for the SDK. When falsy, ``prompt`` derives
                it from the image paths in the conversation, else uses ``"."``.
            max_turns: Passed to the SDK as ``max_turns``.
            model: Model identifier passed to the SDK.
        """
        self._cwd = cwd
        self._max_turns = max_turns
        self._model = model
        self._cancelled = False

    @property
    def name(self) -> str:
        """Connection label in the form ``claude-sdk/<model>``."""
        return f"claude-sdk/{self._model}"

    def available_models(self) -> list[str]:
        """Return a copy of the module's model list."""
        return list(MODELS)

    def get_model(self) -> str:
        """Return the currently selected model identifier."""
        return self._model

    def set_model(self, model: str) -> None:
        """Select the model used by subsequent ``prompt`` calls."""
        self._model = model

    @staticmethod
    def _infer_cwd(messages: list[Message]) -> str | None:
        """Derive a working directory from the most recent user image block.

        Walks the history newest-first and, for the first ``UserMessage`` that
        contains an ``ImageBlock``, returns the grandparent directory of that
        image's path. Returns ``None`` when no user message carries an image.
        """
        for msg in reversed(messages):
            if not isinstance(msg, UserMessage):
                continue
            for block in msg.content:
                if isinstance(block, ImageBlock):
                    return str(block.path.parent.parent)  # session_dir
        return None

    def prompt(
        self,
        messages: list[Message],
        tools: list[Tool],
    ) -> Iterator[StreamEvent]:
        """Send the conversation to the SDK and stream the reply.

        This is a generator: the SDK import, the worker thread and the query
        all start on the first ``next()``.

        Args:
            messages: Conversation history; flattened to a single text prompt.
            tools: Caller's tool list, forwarded to ``_tool_schemas``.

        Yields:
            ``TextDelta`` for each text block of each SDK assistant message.
            When a ``ResultMessage`` arrives, a ``TextDelta`` with its result
            (only if no assistant text was seen) followed by
            ``Done(stop_reason="end_turn")``. On failure, a single ``Error``
            event; the stream then ends.

        Raises:
            ImportError: If ``claude_agent_sdk`` is not installed (raised on
                the first ``next()``).
        """
        # Function-scope imports keep the module importable without the SDK.
        import asyncio
        import threading
        from contextlib import aclosing

        from claude_agent_sdk import (
            query,
            ClaudeAgentOptions,
            AssistantMessage as SDKAssistantMessage,
            TextBlock as SDKTextBlock,
            ResultMessage,
            CLINotFoundError,
            CLIConnectionError,
        )

        prompt_text = _messages_to_prompt(messages)
        self._cancelled = False
        q: queue.Queue = queue.Queue()
        # Set when the consumer stops iterating (generator closed or dropped),
        # so the worker does not keep draining the SDK for nobody.
        abandoned = threading.Event()

        # An explicit cwd wins; otherwise fall back to the session directory
        # implied by the most recent image path, if there is one.
        cwd = self._cwd or self._infer_cwd(messages)

        async def _run() -> None:
            try:
                got_assistant_text = False
                options = ClaudeAgentOptions(
                    model=self._model,
                    cwd=cwd or ".",
                    allowed_tools=_tool_schemas(tools),
                    system_prompt=SYSTEM_PROMPT,
                    max_turns=self._max_turns,
                )
                # aclosing() closes the SDK's async generator in this same
                # task when we break out early, instead of leaving it to be
                # finalized later during loop teardown.
                async with aclosing(query(prompt=prompt_text, options=options)) as stream:
                    async for msg in stream:
                        # Checked once per SDK message, so a cancel takes
                        # effect when the next message arrives.
                        if self._cancelled or abandoned.is_set():
                            break
                        if isinstance(msg, SDKAssistantMessage):
                            for block in msg.content:
                                if isinstance(block, SDKTextBlock):
                                    q.put(TextDelta(text=block.text))
                                    got_assistant_text = True
                        elif isinstance(msg, ResultMessage):
                            # Only use ResultMessage.result if we got no text
                            # from AssistantMessages (avoids duplicating it).
                            if msg.result and not got_assistant_text:
                                q.put(TextDelta(text=msg.result))
                            q.put(Done(stop_reason="end_turn"))
            except CLINotFoundError:
                q.put(Error(
                    message="Claude Code CLI not found.\n"
                            "Install it: https://claude.ai/code\n"
                            "Then run `claude` once in a terminal to authenticate."
                ))
            except CLIConnectionError as e:
                detail = str(e)
                lowered = detail.lower()
                if "auth" in lowered or "login" in lowered or "401" in detail:
                    q.put(Error(
                        message="Claude Code not authenticated.\n"
                                "Run `claude` in a terminal and complete the login flow, then retry."
                    ))
                else:
                    q.put(Error(message=detail))
            except Exception as e:
                # Boundary handler: keep the traceback in the log and never
                # surface an empty error message to the caller.
                log.exception("Claude SDK query failed")
                q.put(Error(message=str(e) or type(e).__name__))
            finally:
                # Always unblock the consumer, whatever happened above.
                q.put(_SENTINEL)

        def _thread() -> None:
            # asyncio.run() creates a private event loop for this thread and,
            # on exit, cancels leftover tasks, shuts down async generators and
            # closes the loop.
            asyncio.run(_run())

        worker = threading.Thread(target=_thread, daemon=True, name="claude_sdk_stream")
        worker.start()

        try:
            while True:
                item = q.get()
                if item is _SENTINEL:
                    break
                yield item
        finally:
            # Reached on normal completion and when the caller closes or
            # drops the generator; harmless in the former case.
            abandoned.set()

    def cancel(self) -> None:
        """Ask the in-flight ``prompt`` to stop.

        The flag is checked when the next SDK message arrives; the stream then
        ends without a ``Done`` event.
        """
        self._cancelled = True