Files
mitus/cht/agent/claude_sdk_connection.py

191 lines
6.3 KiB
Python

"""
AgentConnection for Claude Code SDK (claude_agent_sdk).
Uses your Claude Code subscription — no direct API costs.
Truly streams via a queue bridge between the async SDK generator
and the synchronous Iterator[StreamEvent] interface.
"""
import logging
import queue
from typing import Iterator
from cht.agent.base import (
AgentConnection,
AssistantMessage,
ImageBlock,
Message,
StreamEvent,
TextBlock,
TextDelta,
Tool,
ToolCallEnd,
ToolCallStart,
ToolResult,
ToolUse,
TranscriptBlock,
UserMessage,
Done,
Error,
)
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Model identifiers this connection offers. The first entry doubles as the
# default model for ClaudeSDKConnection.
MODELS = [
    "claude-sonnet-4-6",
    "claude-opus-4-6",
    "claude-haiku-4-5",
]

# Unique end-of-stream marker: the worker pushes it onto the event queue as
# its final item, and the consumer stops iterating when it sees it.
_SENTINEL = object()
def _format_mmss(seconds) -> str:
    """Render *seconds* as zero-padded ``MM:SS``.

    The value is truncated with ``int()`` first. Minutes are not wrapped at
    60, so durations of an hour or more render as e.g. ``61:40``.
    """
    minutes, secs = divmod(int(seconds), 60)
    return f"{minutes:02d}:{secs:02d}"


def _messages_to_prompt(messages: list[Message]) -> str:
    """Flatten structured messages into a text prompt for the SDK.

    Each message becomes one line of the result:

    * ``UserMessage``  -> ``User: ...`` with its blocks joined by spaces.
      Text blocks contribute their text, image blocks a
      ``[Frame <id> at MM:SS: <path>]`` tag, and transcript blocks a
      ``[Transcript <id> MM:SS-MM:SS: <text>]`` tag.
    * ``AssistantMessage`` -> ``Assistant: ...`` (text blocks only).
    * ``ToolUse``      -> ``[Tool call: name(input)]``.
    * ``ToolResult``   -> ``[Tool result: ...]`` using the output, else the
      error, else an empty string.

    Messages of any other type are skipped.

    Args:
        messages: Conversation history, oldest first.

    Returns:
        The lines joined with newlines; an empty string for no messages.
    """
    lines = []
    for msg in messages:
        if isinstance(msg, UserMessage):
            parts = []
            for b in msg.content:
                if isinstance(b, TextBlock):
                    parts.append(b.text)
                elif isinstance(b, ImageBlock):
                    # Separate the timestamp from the path with ": " (same
                    # shape as the transcript tag). Without it the two ran
                    # together, e.g. "at 01:05/session/frames/F1.jpg".
                    parts.append(
                        f"[Frame {b.frame_id} at {_format_mmss(b.timestamp)}: {b.path}]"
                    )
                elif isinstance(b, TranscriptBlock):
                    parts.append(
                        f"[Transcript {b.transcript_id} "
                        f"{_format_mmss(b.start)}-{_format_mmss(b.end)}: {b.text}]"
                    )
            lines.append(f"User: {' '.join(parts)}")
        elif isinstance(msg, AssistantMessage):
            text = " ".join(b.text for b in msg.content if isinstance(b, TextBlock))
            lines.append(f"Assistant: {text}")
        elif isinstance(msg, ToolUse):
            lines.append(f"[Tool call: {msg.tool_name}({msg.input})]")
        elif isinstance(msg, ToolResult):
            # Prefer the output; fall back to the error, then to nothing.
            out = msg.output or msg.error or ""
            lines.append(f"[Tool result: {out}]")
    return "\n".join(lines)
class ClaudeSDKConnection:
    """AgentConnection using claude_agent_sdk — requires Claude Code CLI.

    ``prompt()`` runs the SDK's async ``query()`` generator on a background
    thread and relays its output through a queue, so callers consume a plain
    synchronous iterator of stream events.
    """

    def __init__(self, cwd: str | None = None, max_turns: int | None = None,
                 model: str = MODELS[0], permission_mode: str | None = None):
        """Store connection settings.

        Args:
            cwd: Working directory handed to the SDK. When falsy, ``prompt()``
                derives one from the image paths in the conversation.
            max_turns: Turn limit handed to the SDK; a falsy value (``None``
                or ``0``) selects ``AGENT_MAX_TURNS`` from ``cht.config``.
            model: Model identifier handed to the SDK.
            permission_mode: Permission mode handed to the SDK; a falsy value
                selects ``AGENT_PERMISSION_MODE`` from ``cht.config``.
        """
        # NOTE(review): imported at call time rather than module level —
        # presumably to avoid an import cycle or early config load; confirm.
        from cht.config import AGENT_PERMISSION_MODE, AGENT_MAX_TURNS

        self._cwd = cwd
        self._max_turns = max_turns or AGENT_MAX_TURNS
        self._model = model
        self._permission_mode = permission_mode or AGENT_PERMISSION_MODE
        self._cancelled = False

    @property
    def name(self) -> str:
        """Display name of the form ``claude-sdk/<model>``."""
        return f"claude-sdk/{self._model}"

    def available_models(self) -> list[str]:
        """Return a fresh copy of the selectable model identifiers."""
        return list(MODELS)

    def get_model(self) -> str:
        """Return the currently selected model identifier."""
        return self._model

    def set_model(self, model: str) -> None:
        """Select the model used by subsequent ``prompt()`` calls."""
        self._model = model

    @staticmethod
    def _infer_cwd(messages: list[Message]) -> str | None:
        """Derive a working directory from the most recent user image.

        Scans user messages newest-first and returns the grandparent
        directory of the first image block's path, or ``None`` when no user
        message carries an image.
        """
        for msg in reversed(messages):
            if not isinstance(msg, UserMessage):
                continue
            for b in msg.content:
                if isinstance(b, ImageBlock):
                    return str(b.path.parent.parent)  # session_dir
        return None

    def prompt(
        self,
        messages: list[Message],
        tools: list[Tool],
    ) -> Iterator[StreamEvent]:
        """Send the conversation to Claude Code and stream back events.

        The messages are flattened to a single text prompt and passed to the
        SDK's ``query()``. Assistant text blocks are relayed as ``TextDelta``
        events. If no assistant text was seen, the final result text (when
        present) is relayed instead. A ``Done`` event closes a successful
        run; failures are reported as a single ``Error`` event rather than
        raised.

        Args:
            messages: Conversation history, oldest first.
            tools: Accepted for interface compatibility; not used by this
                method.

        Yields:
            ``TextDelta``, ``Done`` and ``Error`` events in arrival order.
        """
        import asyncio
        import threading

        from claude_agent_sdk import (
            query,
            ClaudeAgentOptions,
            AssistantMessage as SDKAssistantMessage,
            TextBlock as SDKTextBlock,
            ResultMessage,
            CLINotFoundError,
            CLIConnectionError,
        )

        prompt_text = _messages_to_prompt(messages)
        self._cancelled = False
        q: queue.Queue = queue.Queue()
        # Per-call stop flag, set when the consumer stops iterating. It is
        # kept separate from self._cancelled so that a stale iterator being
        # closed late cannot cancel a newer prompt() call.
        abandoned = threading.Event()

        # Explicit cwd wins; otherwise fall back to the session directory
        # inferred from the latest user image, if any.
        cwd = self._cwd or self._infer_cwd(messages)

        async def _run():
            try:
                got_assistant_text = False
                async for msg in query(
                    prompt=prompt_text,
                    options=ClaudeAgentOptions(
                        model=self._model,
                        cwd=cwd or ".",
                        max_turns=self._max_turns,
                        permission_mode=self._permission_mode,
                    ),
                ):
                    # Stop flags are only observed when the SDK yields a
                    # message, so stopping takes effect at the next message.
                    if self._cancelled or abandoned.is_set():
                        break
                    if isinstance(msg, SDKAssistantMessage):
                        for block in msg.content:
                            if isinstance(block, SDKTextBlock):
                                q.put(TextDelta(text=block.text))
                                got_assistant_text = True
                    elif isinstance(msg, ResultMessage):
                        # Only use ResultMessage.result if we got no text
                        # from AssistantMessages (avoids duplicating output).
                        if msg.result and not got_assistant_text:
                            q.put(TextDelta(text=msg.result))
                q.put(Done(stop_reason="end_turn"))
            except CLINotFoundError:
                q.put(Error(
                    message="Claude Code CLI not found.\n"
                            "Install it: https://claude.ai/code\n"
                            "Then run `claude` once in a terminal to authenticate."
                ))
            except CLIConnectionError as e:
                detail = str(e)
                lowered = detail.lower()
                if "auth" in lowered or "login" in lowered or "401" in detail:
                    q.put(Error(
                        message="Claude Code not authenticated.\n"
                                "Run `claude` in a terminal and complete the login flow, then retry."
                    ))
                else:
                    q.put(Error(message=detail))
            except Exception as e:
                # Boundary handler: surface the failure to the caller as an
                # event, but keep the traceback in the log for diagnosis.
                log.exception("Claude SDK query failed")
                q.put(Error(message=str(e) or repr(e)))
            finally:
                # Always unblock the consumer, whatever happened above.
                q.put(_SENTINEL)

        def _thread():
            # asyncio.run() creates a fresh loop for this thread and, unlike
            # a bare new_event_loop()/close() pair, cancels leftover tasks
            # and finalizes async generators before closing. That matters
            # when we break out of `async for` and leave the SDK generator
            # suspended.
            asyncio.run(_run())

        worker = threading.Thread(target=_thread, daemon=True, name="claude_sdk_stream")
        worker.start()

        try:
            while (item := q.get()) is not _SENTINEL:
                yield item
        finally:
            # Reached on normal completion (harmless) and when the caller
            # closes or drops the iterator early, so the worker stops instead
            # of driving the SDK with nobody listening.
            abandoned.set()

    def cancel(self) -> None:
        """Ask the in-flight ``prompt()`` to stop.

        Only sets a flag; the worker notices it when the SDK yields its next
        message.
        """
        self._cancelled = True