Files
mitus/cht/agent/claude_sdk_connection.py
2026-04-09 14:58:15 -03:00

204 lines
7.0 KiB
Python

"""
AgentConnection for Claude Code SDK (claude_agent_sdk).
Uses your Claude Code subscription — no direct API costs.
Truly streams via a queue bridge between the async SDK generator
and the synchronous Iterator[StreamEvent] interface.
"""
import logging
import queue
from typing import Iterator
from cht.agent.base import (
AgentConnection,
AssistantMessage,
ImageBlock,
Message,
StreamEvent,
TextBlock,
TextDelta,
Tool,
ToolCallEnd,
ToolCallStart,
ToolResult,
ToolUse,
TranscriptBlock,
UserMessage,
Done,
Error,
)
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# System prompt passed as `system_prompt` on every SDK query made by
# ClaudeSDKConnection.prompt(). This is runtime text sent to the model.
SYSTEM_PROMPT = """You are an assistant integrated into CHT, a screen recording and analysis tool.
You help the user understand what happened during their recording session.
You have access to frame screenshots extracted from the recording. When frames are mentioned,
use the Read tool to view them. Frame timestamps are in seconds from the start of the recording.
You can use any available tools including WebFetch and WebSearch when the user asks you to
look something up. Use them freely — all tools are pre-authorized.
Your primary role is description and analysis, not code generation. Be concise and specific.
Focus on what's visible in the frames and what's in the transcript."""
# Model identifiers reported by ClaudeSDKConnection.available_models().
# The first entry is the default `model` argument of the constructor.
MODELS = [
"claude-sonnet-4-6",
"claude-opus-4-6",
"claude-haiku-4-5",
]
# Unique end-of-stream marker: the worker side of the queue bridge puts it on
# the queue last, and the consuming generator stops when it sees it (compared
# by identity, so it can never collide with a real StreamEvent).
_SENTINEL = object()
def _mmss(seconds: float) -> str:
    """Format *seconds* as zero-padded ``MM:SS`` (fractional part truncated)."""
    minutes, secs = divmod(int(seconds), 60)
    return f"{minutes:02d}:{secs:02d}"


def _messages_to_prompt(messages: list[Message]) -> str:
    """Flatten structured messages into a text prompt for the SDK.

    Each message becomes one line of the result, joined with newlines:

    * ``UserMessage``      -> ``User: <parts>`` where the parts are the text
      blocks verbatim, ``[Frame <id> at MM:SS: <path>]`` for image blocks and
      ``[Transcript <id> MM:SS-MM:SS: <text>]`` for transcript blocks,
      separated by single spaces.
    * ``AssistantMessage`` -> ``Assistant: <text blocks joined by spaces>``.
    * ``ToolUse``          -> ``[Tool call: <tool_name>(<input>)]``.
    * ``ToolResult``       -> ``[Tool result: <output, else error, else "">]``.

    Messages and content blocks of any other type are skipped.

    Args:
        messages: Conversation history, oldest first.

    Returns:
        The flattened prompt; an empty string when *messages* is empty.
    """
    lines = []
    for msg in messages:
        if isinstance(msg, UserMessage):
            parts = []
            for b in msg.content:
                if isinstance(b, TextBlock):
                    parts.append(b.text)
                elif isinstance(b, ImageBlock):
                    # The ": " separator keeps the path distinct from the
                    # timestamp so it can be read back out of the prompt
                    # (same layout as the transcript marker below).
                    parts.append(
                        f"[Frame {b.frame_id} at {_mmss(b.timestamp)}: {b.path}]"
                    )
                elif isinstance(b, TranscriptBlock):
                    parts.append(
                        f"[Transcript {b.transcript_id} "
                        f"{_mmss(b.start)}-{_mmss(b.end)}: {b.text}]"
                    )
            joined = " ".join(parts)
            lines.append(f"User: {joined}")
        elif isinstance(msg, AssistantMessage):
            # Only text blocks are carried over from assistant turns.
            text = " ".join(b.text for b in msg.content if isinstance(b, TextBlock))
            lines.append(f"Assistant: {text}")
        elif isinstance(msg, ToolUse):
            lines.append(f"[Tool call: {msg.tool_name}({msg.input})]")
        elif isinstance(msg, ToolResult):
            # Prefer the output; fall back to the error text, then to nothing.
            out = msg.output or msg.error or ""
            lines.append(f"[Tool result: {out}]")
    return "\n".join(lines)
class ClaudeSDKConnection:
    """AgentConnection using claude_agent_sdk — requires Claude Code CLI.

    The SDK exposes an async generator, while callers of ``prompt()`` expect a
    plain synchronous iterator. ``prompt()`` bridges the two by running the
    SDK query on a private event loop in a daemon thread and handing events
    to the caller through a ``queue.Queue``.
    """

    def __init__(self, cwd: str | None = None, max_turns: int | None = None,
                 model: str = MODELS[0], permission_mode: str | None = None):
        """Create a connection.

        Args:
            cwd: Working directory for the SDK. When falsy, ``prompt()``
                derives one from the image paths in the conversation.
            max_turns: Turn limit passed to the SDK; falsy values fall back
                to ``cht.config.AGENT_MAX_TURNS``.
            model: Model identifier passed to the SDK.
            permission_mode: Permission mode passed to the SDK; falsy values
                fall back to ``cht.config.AGENT_PERMISSION_MODE``.
        """
        # Resolved at construction time rather than at module import time.
        from cht.config import AGENT_PERMISSION_MODE, AGENT_MAX_TURNS

        self._cwd = cwd
        self._max_turns = max_turns or AGENT_MAX_TURNS
        self._model = model
        self._permission_mode = permission_mode or AGENT_PERMISSION_MODE
        # Set by cancel(); polled by the worker between SDK messages.
        self._cancelled = False

    @property
    def name(self) -> str:
        """Display name of this connection, including the current model."""
        return f"claude-sdk/{self._model}"

    def available_models(self) -> list[str]:
        """Return a fresh copy of the selectable model identifiers."""
        return list(MODELS)

    def get_model(self) -> str:
        """Return the model identifier used for subsequent prompts."""
        return self._model

    def set_model(self, model: str) -> None:
        """Select the model identifier used for subsequent prompts."""
        self._model = model

    @staticmethod
    def _session_dir_from(messages: list[Message]) -> str | None:
        """Guess the session directory from the newest user message with an image.

        Returns the directory two levels above the first image path found
        (searching user messages newest-first), or ``None`` if no user
        message carries an image.
        """
        for msg in reversed(messages):
            if not isinstance(msg, UserMessage):
                continue
            for b in msg.content:
                if isinstance(b, ImageBlock):
                    return str(b.path.parent.parent)  # session_dir
        return None

    def prompt(
        self,
        messages: list[Message],
        tools: list[Tool],
    ) -> Iterator[StreamEvent]:
        """Send the conversation to the SDK and stream back events.

        Args:
            messages: Conversation history; flattened to a single text prompt.
            tools: Part of the connection interface; not used by this
                implementation.

        Yields:
            ``TextDelta`` for each assistant text block (or once for the final
            result text if no assistant text was seen), then ``Done`` when the
            SDK stream finishes without raising, or ``Error`` if the CLI is
            missing, unauthenticated, or the query fails.
        """
        import asyncio
        import threading

        # Imported lazily so this module can be imported without the SDK.
        from claude_agent_sdk import (
            query,
            ClaudeAgentOptions,
            AssistantMessage as SDKAssistantMessage,
            TextBlock as SDKTextBlock,
            ResultMessage,
            CLINotFoundError,
            CLIConnectionError,
        )

        prompt_text = _messages_to_prompt(messages)
        self._cancelled = False
        q: queue.Queue = queue.Queue()

        # Determine cwd from the last UserMessage's image paths if available
        cwd = self._cwd or self._session_dir_from(messages)

        async def _run():
            try:
                got_assistant_text = False
                stream = query(
                    prompt=prompt_text,
                    options=ClaudeAgentOptions(
                        model=self._model,
                        cwd=cwd or ".",
                        system_prompt=SYSTEM_PROMPT,
                        max_turns=self._max_turns,
                        permission_mode=self._permission_mode,
                    ),
                )
                try:
                    async for msg in stream:
                        if self._cancelled:
                            break
                        if isinstance(msg, SDKAssistantMessage):
                            for block in msg.content:
                                if isinstance(block, SDKTextBlock):
                                    q.put(TextDelta(text=block.text))
                                    got_assistant_text = True
                        elif isinstance(msg, ResultMessage):
                            # Only use ResultMessage.result if we got no text from AssistantMessages
                            if msg.result and not got_assistant_text:
                                q.put(TextDelta(text=msg.result))
                finally:
                    # Breaking out on cancel leaves the SDK stream suspended;
                    # close it here, in the task that iterated it, so its
                    # cleanup runs before the event loop is closed.
                    aclose = getattr(stream, "aclose", None)
                    if aclose is not None:
                        try:
                            await aclose()
                        except Exception:
                            log.debug("Error while closing Claude SDK stream", exc_info=True)
                # NOTE(review): emitted after the loop, so also after cancel();
                # confirm this matches the intended nesting.
                q.put(Done(stop_reason="end_turn"))
            except CLINotFoundError:
                q.put(Error(
                    message="Claude Code CLI not found.\n"
                            "Install it: https://claude.ai/code\n"
                            "Then run `claude` once in a terminal to authenticate."
                ))
            except CLIConnectionError as e:
                detail = str(e)
                lowered = detail.lower()
                if "auth" in lowered or "login" in lowered or "401" in detail:
                    q.put(Error(
                        message="Claude Code not authenticated.\n"
                                "Run `claude` in a terminal and complete the login flow, then retry."
                    ))
                else:
                    q.put(Error(message=detail))
            except Exception as e:
                log.exception("Claude SDK query failed")
                q.put(Error(message=str(e)))

        def _thread():
            try:
                loop = asyncio.new_event_loop()
                try:
                    loop.run_until_complete(_run())
                finally:
                    loop.close()
            except Exception as e:
                # Failures outside _run (e.g. event-loop setup/teardown).
                log.exception("Claude SDK worker thread failed")
                q.put(Error(message=str(e)))
            finally:
                # Always release the consumer, however the worker ended.
                q.put(_SENTINEL)

        t = threading.Thread(target=_thread, daemon=True, name="claude_sdk_stream")
        t.start()

        while True:
            item = q.get()
            if item is _SENTINEL:
                break
            yield item

    def cancel(self) -> None:
        """Ask the running prompt to stop; takes effect at the next SDK message."""
        self._cancelled = True