AI
This commit is contained in:
108
cht/agent/claude_sdk_provider.py
Normal file
108
cht/agent/claude_sdk_provider.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
Agent provider using the Claude Code SDK (claude_agent_sdk).
|
||||
|
||||
Uses your Claude Code subscription — no direct API costs.
|
||||
Passes frame paths in the prompt; Claude reads them visually via the Read tool.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Iterator
|
||||
|
||||
import anyio
|
||||
from claude_agent_sdk import query, ClaudeAgentOptions, AssistantMessage, TextBlock, ResultMessage
|
||||
from claude_agent_sdk import CLINotFoundError, CLIConnectionError
|
||||
|
||||
from cht.agent.base import AgentProvider, SessionContext
|
||||
|
||||
# Module-level logger named after this module.
log = logging.getLogger(__name__)

# System prompt passed to the agent via ClaudeAgentOptions(system_prompt=...)
# in ClaudeSDKProvider.stream. It tells the model that frame screenshots are
# viewed with the Read tool, which is the only tool the provider allows.
SYSTEM_PROMPT = """You are an assistant integrated into CHT, a screen recording and analysis tool.
You help the user understand what happened during their recording session.

You have access to frame screenshots extracted from the recording. When frames are mentioned,
use the Read tool to view them. Frame timestamps are in seconds from the start of the recording.

Be concise and specific. Focus on what's visible in the frames."""
|
||||
|
||||
|
||||
def _format_frame_line(frame) -> str:
    """Render one frame as a prompt line: id, MM:SS timestamp and file path."""
    minutes, seconds = divmod(int(frame.timestamp), 60)
    return f" {frame.id} at {minutes:02d}:{seconds:02d} — {frame.path}"


def _build_prompt(message: str, context: SessionContext) -> str:
    """Assemble the text prompt sent to the agent for one user message.

    The prompt consists of, in order:

    1. a session summary (recording duration as MM:SS and the frame count),
    2. every available frame (so the model can decide which ones to open),
    3. the frames explicitly referenced by this message,
    4. the user's message itself.

    Sections 2 and 3 are omitted when their frame list is empty.

    Args:
        message: The user's message text.
        context: Session data; this function reads ``duration``, ``frames``
            and ``mentioned_frames``. Each frame must expose ``id``,
            ``timestamp`` and ``path``.

    Returns:
        The prompt as a single newline-joined string.
    """
    total_minutes, total_seconds = divmod(int(context.duration), 60)
    lines = [
        f"Recording duration: {total_minutes:02d}:{total_seconds:02d}",
        f"Total frames captured: {len(context.frames)}",
    ]

    # Both frame listings share the same layout; only the heading differs.
    # The leading "\n" in each heading yields a blank separator line once the
    # parts are joined.
    sections = (
        ("\nAvailable frames:", context.frames),
        ("\nFrames referenced in this message:", context.mentioned_frames),
    )
    for heading, frames in sections:
        if frames:
            lines.append(heading)
            lines.extend(_format_frame_line(frame) for frame in frames)

    lines.append(f"\nUser message: {message}")
    return "\n".join(lines)
|
||||
|
||||
|
||||
class ClaudeSDKProvider(AgentProvider):
    """Agent provider backed by ``claude_agent_sdk``.

    Requires the Claude Code CLI to be installed and authenticated.
    """

    def __init__(self, cwd: str | None = None, max_turns: int = 5):
        """Store the provider configuration.

        Args:
            cwd: Working directory handed to the agent. When falsy, the
                ``session_dir`` of the context passed to ``stream`` is used.
            max_turns: Forwarded as ``ClaudeAgentOptions(max_turns=...)``.
        """
        self._cwd = cwd
        self._max_turns = max_turns

    @property
    def name(self) -> str:
        """Identifier string of this provider."""
        return "claude-code-sdk"

    def stream(self, message: str, context: SessionContext) -> Iterator[str]:
        """Run one agent query and yield the reply text in chunks.

        The SDK's async ``query`` is driven to completion on a private event
        loop first, and the collected text is yielded afterwards, so the
        chunks are not delivered incrementally. Because this is a generator,
        nothing runs (and no error is raised) until it is first iterated.

        Args:
            message: The user's message text.
            context: Session data used to build the prompt; its
                ``session_dir`` is the fallback working directory.

        Yields:
            Text of each assistant ``TextBlock`` in arrival order, or the
            ``ResultMessage.result`` text if no assistant text was received.

        Raises:
            RuntimeError: The Claude Code CLI is missing, or the connection
                error looks like an authentication failure.
            CLIConnectionError: Any other connection failure, re-raised as is.
        """
        prompt = _build_prompt(message, context)
        chunks: list[str] = []

        async def _run() -> None:
            options = ClaudeAgentOptions(
                cwd=self._cwd or str(context.session_dir),
                allowed_tools=["Read"],
                system_prompt=SYSTEM_PROMPT,
                max_turns=self._max_turns,
            )
            async for msg in query(prompt=prompt, options=options):
                if isinstance(msg, AssistantMessage):
                    chunks.extend(
                        block.text
                        for block in msg.content
                        if isinstance(block, TextBlock)
                    )
                elif isinstance(msg, ResultMessage):
                    # The final result carries the response text that was
                    # already delivered through assistant messages; appending
                    # it unconditionally would emit the reply twice. Use it
                    # only as a fallback when no assistant text arrived.
                    if msg.result and not chunks:
                        chunks.append(msg.result)

        try:
            import asyncio

            # A dedicated loop keeps this synchronous entry point independent
            # of whatever event loop state the calling thread has.
            loop = asyncio.new_event_loop()
            try:
                loop.run_until_complete(_run())
            finally:
                loop.close()
        except CLINotFoundError as exc:
            raise RuntimeError(
                "Claude Code CLI not found.\n"
                "Install it: https://claude.ai/code\n"
                "Then run `claude` once in a terminal to authenticate."
            ) from exc
        except CLIConnectionError as exc:
            detail = str(exc)
            lowered = detail.lower()
            # Heuristic: treat connection errors mentioning auth/login/401 as
            # "not logged in" and give the user an actionable hint.
            if "auth" in lowered or "login" in lowered or "401" in detail:
                raise RuntimeError(
                    "Claude Code not authenticated.\n"
                    "Run `claude` in a terminal and complete the login flow, then retry."
                ) from exc
            raise

        yield from chunks
|
||||
Reference in New Issue
Block a user