better agent
This commit is contained in:
206
cht/agent/claude_sdk_connection.py
Normal file
206
cht/agent/claude_sdk_connection.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""
|
||||
AgentConnection for Claude Code SDK (claude_agent_sdk).
|
||||
|
||||
Uses your Claude Code subscription — no direct API costs.
|
||||
Truly streams via a queue bridge between the async SDK generator
|
||||
and the synchronous Iterator[StreamEvent] interface.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import queue
|
||||
from typing import Iterator
|
||||
|
||||
from cht.agent.base import (
|
||||
AgentConnection,
|
||||
AssistantMessage,
|
||||
ImageBlock,
|
||||
Message,
|
||||
StreamEvent,
|
||||
TextBlock,
|
||||
TextDelta,
|
||||
Tool,
|
||||
ToolCallEnd,
|
||||
ToolCallStart,
|
||||
ToolResult,
|
||||
ToolUse,
|
||||
TranscriptBlock,
|
||||
UserMessage,
|
||||
Done,
|
||||
Error,
|
||||
)
|
||||
|
||||
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)

# System prompt handed to the SDK with every query.
SYSTEM_PROMPT = """You are an assistant integrated into CHT, a screen recording and analysis tool.
You help the user understand what happened during their recording session.

You have access to frame screenshots extracted from the recording. When frames are mentioned,
use the Read tool to view them. Frame timestamps are in seconds from the start of the recording.

You also have tools to search transcripts, get session info, and capture new frames.

Be concise and specific. Focus on what's visible in the frames."""

# Model identifiers reported by available_models(); the first entry is the
# default model for a new connection.
MODELS = ["claude-sonnet-4-6", "claude-opus-4-6", "claude-haiku-4-5"]

# Unique marker the worker thread enqueues last to signal end-of-stream.
_SENTINEL = object()
|
||||
|
||||
|
||||
def _messages_to_prompt(messages: list[Message]) -> str:
|
||||
"""Flatten structured messages into a text prompt for the SDK."""
|
||||
lines = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, UserMessage):
|
||||
parts = []
|
||||
for b in msg.content:
|
||||
if isinstance(b, TextBlock):
|
||||
parts.append(b.text)
|
||||
elif isinstance(b, ImageBlock):
|
||||
m, s = divmod(int(b.timestamp), 60)
|
||||
parts.append(f"[Frame {b.frame_id} at {m:02d}:{s:02d} — {b.path}]")
|
||||
elif isinstance(b, TranscriptBlock):
|
||||
m1, s1 = divmod(int(b.start), 60)
|
||||
m2, s2 = divmod(int(b.end), 60)
|
||||
parts.append(f"[Transcript {b.transcript_id} {m1:02d}:{s1:02d}-{m2:02d}:{s2:02d}: {b.text}]")
|
||||
lines.append(f"User: {' '.join(parts)}")
|
||||
elif isinstance(msg, AssistantMessage):
|
||||
text = " ".join(b.text for b in msg.content if isinstance(b, TextBlock))
|
||||
lines.append(f"Assistant: {text}")
|
||||
elif isinstance(msg, ToolUse):
|
||||
lines.append(f"[Tool call: {msg.tool_name}({msg.input})]")
|
||||
elif isinstance(msg, ToolResult):
|
||||
out = msg.output or msg.error or ""
|
||||
lines.append(f"[Tool result: {out}]")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _tool_schemas(tools: list[Tool]) -> list[str]:
|
||||
"""Extract tool names for the SDK's allowed_tools parameter."""
|
||||
# The Claude SDK uses allowed_tools as a list of tool name strings.
|
||||
# Our custom tools are executed by the runner, not by the SDK,
|
||||
# so we only pass "Read" to the SDK (for frame viewing).
|
||||
return ["Read"]
|
||||
|
||||
|
||||
class ClaudeSDKConnection:
    """AgentConnection using claude_agent_sdk — requires Claude Code CLI.

    Each call to ``prompt`` runs the SDK's async ``query`` generator on a
    dedicated daemon thread and relays what it produces through a
    ``queue.Queue``, so the caller consumes a plain synchronous iterator.
    """

    def __init__(self, cwd: str | None = None, max_turns: int = 5, model: str = MODELS[0]):
        """Store connection settings; no work starts until ``prompt`` is called.

        Args:
            cwd: Working directory passed to the SDK. When falsy, ``prompt``
                derives one from image paths found in the conversation.
            max_turns: Forwarded to the SDK as ``max_turns``.
            model: Model identifier forwarded to the SDK; defaults to the
                first entry of ``MODELS``.
        """
        self._cwd = cwd
        self._max_turns = max_turns
        self._model = model
        # Polled by the worker between SDK messages; set by cancel().
        self._cancelled = False

    @property
    def name(self) -> str:
        """Display name of this connection: ``claude-sdk/<model>``."""
        return f"claude-sdk/{self._model}"

    def available_models(self) -> list[str]:
        """Return a fresh copy of the selectable model identifiers."""
        return list(MODELS)

    def get_model(self) -> str:
        """Return the currently selected model identifier."""
        return self._model

    def set_model(self, model: str) -> None:
        """Select the model used by subsequent queries (not validated)."""
        self._model = model

    @staticmethod
    def _infer_cwd(messages: list[Message]) -> str | None:
        """Guess the session directory from the most recent frame reference.

        Walks the conversation backwards and, for the first ``UserMessage``
        that contains an ``ImageBlock``, returns the grandparent directory of
        that block's path. Returns ``None`` when no image block exists.
        """
        for msg in reversed(messages):
            if not isinstance(msg, UserMessage):
                continue
            for block in msg.content:
                if isinstance(block, ImageBlock):
                    return str(block.path.parent.parent)  # session_dir
        return None

    def prompt(
        self,
        messages: list[Message],
        tools: list[Tool],
    ) -> Iterator[StreamEvent]:
        """Send the conversation to the SDK and stream back events.

        Args:
            messages: Conversation history; flattened to a single text prompt.
            tools: Passed to ``_tool_schemas`` to build ``allowed_tools``.

        Yields:
            ``TextDelta`` for every text block of every SDK assistant message
            (or, if none arrived, for the final result text), followed by
            ``Done``. On failure an ``Error`` event is yielded instead of
            ``Done``.

        Raises:
            ImportError: If ``claude_agent_sdk`` is not installed (raised on
                first iteration, since this is a generator).
        """
        import asyncio
        import threading

        from claude_agent_sdk import (
            query,
            ClaudeAgentOptions,
            AssistantMessage as SDKAssistantMessage,
            TextBlock as SDKTextBlock,
            ResultMessage,
            CLINotFoundError,
            CLIConnectionError,
        )

        prompt_text = _messages_to_prompt(messages)
        self._cancelled = False
        q: queue.Queue = queue.Queue()

        # An explicitly configured cwd wins; otherwise fall back to the
        # directory implied by the frames referenced in the conversation.
        cwd = self._cwd or self._infer_cwd(messages)

        async def _run() -> None:
            got_assistant_text = False
            try:
                options = ClaudeAgentOptions(
                    model=self._model,
                    cwd=cwd or ".",
                    allowed_tools=_tool_schemas(tools),
                    system_prompt=SYSTEM_PROMPT,
                    max_turns=self._max_turns,
                )
                async for msg in query(prompt=prompt_text, options=options):
                    if self._cancelled:
                        break
                    if isinstance(msg, SDKAssistantMessage):
                        for block in msg.content:
                            if isinstance(block, SDKTextBlock):
                                q.put(TextDelta(text=block.text))
                                got_assistant_text = True
                    elif isinstance(msg, ResultMessage):
                        # Only use ResultMessage.result if no text was
                        # already streamed from assistant messages, so the
                        # answer is not emitted twice.
                        if msg.result and not got_assistant_text:
                            q.put(TextDelta(text=msg.result))
                q.put(Done(stop_reason="end_turn"))
            except CLINotFoundError:
                log.warning("Claude Code CLI not found")
                q.put(Error(
                    message="Claude Code CLI not found.\n"
                            "Install it: https://claude.ai/code\n"
                            "Then run `claude` once in a terminal to authenticate."
                ))
            except CLIConnectionError as e:
                detail = str(e)
                lowered = detail.lower()
                log.warning("Claude Code CLI connection failed: %s", detail)
                if "auth" in lowered or "login" in lowered or "401" in detail:
                    q.put(Error(
                        message="Claude Code not authenticated.\n"
                                "Run `claude` in a terminal and complete the login flow, then retry."
                    ))
                else:
                    q.put(Error(message=detail or repr(e)))
            except Exception as e:
                # Boundary of the worker thread: keep the traceback in the
                # log and surface a non-empty message to the consumer.
                log.exception("Claude SDK query failed")
                q.put(Error(message=str(e) or repr(e)))
            finally:
                # Always unblock the consumer, whatever happened above.
                q.put(_SENTINEL)

        def _worker() -> None:
            # asyncio.run() creates a fresh loop for this thread and, unlike
            # a bare run_until_complete()/close() pair, finalises async
            # generators left suspended (e.g. the SDK stream after a cancel
            # break) before closing the loop.
            asyncio.run(_run())

        worker = threading.Thread(target=_worker, daemon=True, name="claude_sdk_stream")
        worker.start()

        try:
            while (item := q.get()) is not _SENTINEL:
                yield item
        except GeneratorExit:
            # The consumer dropped the iterator early: tell the worker to
            # stop at the next SDK message instead of running to completion.
            self._cancelled = True
            raise

    def cancel(self) -> None:
        """Ask the in-flight query to stop at the next SDK message."""
        self._cancelled = True
|
||||
Reference in New Issue
Block a user