151 lines
5.2 KiB
Python
151 lines
5.2 KiB
Python
"""
|
|
Agent provider using the Claude Code SDK (claude_agent_sdk).
|
|
|
|
Uses your Claude Code subscription — no direct API costs.
|
|
Passes frame paths in the prompt; Claude reads them visually via the Read tool.
|
|
"""
|
|
|
|
import logging
|
|
from typing import Iterator
|
|
|
|
import anyio
|
|
from claude_agent_sdk import query, ClaudeAgentOptions, AssistantMessage, TextBlock, ResultMessage
|
|
from claude_agent_sdk import CLINotFoundError, CLIConnectionError
|
|
|
|
from cht.agent.base import AgentProvider, SessionContext
|
|
|
|
log = logging.getLogger(__name__)

# Instructions passed as ``system_prompt`` on every query. They tell Claude it
# can open the frame screenshots listed in the user prompt with the Read tool.
SYSTEM_PROMPT = (
    "You are an assistant integrated into CHT, a screen recording and analysis tool.\n"
    "You help the user understand what happened during their recording session.\n"
    "\n"
    "You have access to frame screenshots extracted from the recording. When frames are mentioned,\n"
    "use the Read tool to view them. Frame timestamps are in seconds from the start of the recording.\n"
    "\n"
    "Be concise and specific. Focus on what's visible in the frames."
)
|
|
|
|
|
|
def _fmt_ts(seconds: float) -> str:
    """Render an offset in seconds as ``MM:SS``.

    Fractional seconds are truncated. Minutes are not wrapped into hours, so a
    62-minute offset renders as ``62:00``.
    """
    minutes, secs = divmod(int(seconds), 60)
    return f"{minutes:02d}:{secs:02d}"


def _fmt_frame(frame) -> str:
    """One prompt line for a frame; reads ``.id``, ``.timestamp`` and ``.path``."""
    return f" {frame.id} at {_fmt_ts(frame.timestamp)} — {frame.path}"


def _fmt_segment(segment) -> str:
    """One prompt line for a transcript segment; reads ``.id``, ``.start``, ``.end``, ``.text``."""
    return f" {segment.id} [{_fmt_ts(segment.start)}-{_fmt_ts(segment.end)}] {segment.text}"


def _build_prompt(message: str, context: "SessionContext") -> str:
    """Assemble the plain-text prompt sent to Claude for one user message.

    Sections appear in this order. The first two are always present; the
    others appear only when their source collection is non-empty:
    recording duration, frame count, all available frames, frames mentioned
    in the message, the full transcript, transcript segments mentioned in the
    message, conversation history, and finally the user's message.

    Args:
        message: The user's latest message.
        context: Session data. Attributes read are ``duration``, ``frames``,
            ``mentioned_frames``, ``transcript_segments``,
            ``mentioned_transcripts`` and ``history`` (iterable of
            ``(role, text)`` pairs).

    Returns:
        The prompt as newline-joined text.
    """
    lines = [
        f"Recording duration: {_fmt_ts(context.duration)}",
        f"Total frames captured: {len(context.frames)}",
    ]

    # Every frame is listed with its path so Claude can decide which ones to Read.
    if context.frames:
        lines.append("\nAvailable frames:")
        lines.extend(_fmt_frame(f) for f in context.frames)

    # Frames the user explicitly referenced get repeated so they stand out.
    if context.mentioned_frames:
        lines.append("\nFrames referenced in this message:")
        lines.extend(_fmt_frame(f) for f in context.mentioned_frames)

    if context.transcript_segments:
        lines.append(f"\nTranscript ({len(context.transcript_segments)} segments):")
        lines.extend(_fmt_segment(t) for t in context.transcript_segments)

    if context.mentioned_transcripts:
        lines.append("\nTranscript segments referenced in this message:")
        lines.extend(_fmt_segment(t) for t in context.mentioned_transcripts)

    if context.history:
        lines.append("\nConversation history:")
        for role, text in context.history:
            # Any role other than "user" is presented as the assistant.
            speaker = "User" if role == "user" else "Assistant"
            lines.append(f" {speaker}: {text}")

    lines.append(f"\nUser message: {message}")
    return "\n".join(lines)
|
|
|
|
|
|
# Model ids the provider offers; ClaudeSDKProvider uses the first one as its default.
MODELS = ["claude-sonnet-4-6", "claude-opus-4-6", "claude-haiku-4-5"]
|
|
|
|
|
|
class ClaudeSDKProvider(AgentProvider):
    """Agent provider backed by ``claude_agent_sdk``.

    Requires the Claude Code CLI to be installed (and logged in); see the
    error messages raised from ``stream``.
    """

    def __init__(self, cwd: str | None = None, max_turns: int = 5, model: str = MODELS[0]):
        """
        Args:
            cwd: Working directory for the Claude Code process. When falsy,
                each call uses ``context.session_dir`` instead.
            max_turns: Passed to ``ClaudeAgentOptions.max_turns``.
            model: Model id; defaults to the first entry of ``MODELS``.
        """
        self._cwd = cwd
        self._max_turns = max_turns
        self._model = model

    @property
    def name(self) -> str:
        """Provider label including the active model: ``claude-sdk/<model>``."""
        return f"claude-sdk/{self._model}"

    @property
    def available_models(self) -> list[str]:
        """A fresh copy of ``MODELS``, so callers cannot mutate the module list."""
        return list(MODELS)

    @property
    def model(self) -> str:
        """The model id used for subsequent queries."""
        return self._model

    @model.setter
    def model(self, value: str):
        # No validation against MODELS: any id the CLI accepts may be set.
        self._model = value

    def stream(self, message: str, context: SessionContext) -> Iterator[str]:
        """Ask Claude about the session and yield the text of its reply.

        The reply is buffered, not streamed. The query runs to completion on a
        fresh event loop (``anyio.run``) when the generator is first advanced.
        The collected text chunks are then yielded. ``anyio.run`` cannot start
        inside a thread that is already running an event loop.

        Args:
            message: The user's message.
            context: Session data used to build the prompt. Its
                ``session_dir`` is the working directory when no ``cwd`` was
                configured.

        Yields:
            The text of each assistant ``TextBlock``, in arrival order. If no
            text block arrived, the ``ResultMessage.result`` text (when
            non-empty) is yielded instead.

        Raises:
            RuntimeError: The Claude Code CLI was not found, or the connection
                failed with what looks like an authentication error.
            CLIConnectionError: Any other connection failure.
        """
        prompt = _build_prompt(message, context)
        options = ClaudeAgentOptions(
            model=self._model,
            cwd=self._cwd or str(context.session_dir),
            allowed_tools=["Read"],
            system_prompt=SYSTEM_PROMPT,
            max_turns=self._max_turns,
        )
        chunks: list[str] = []

        async def _run() -> None:
            result_text: str | None = None
            async for msg in query(prompt=prompt, options=options):
                if isinstance(msg, AssistantMessage):
                    chunks.extend(
                        block.text for block in msg.content if isinstance(block, TextBlock)
                    )
                elif isinstance(msg, ResultMessage):
                    if msg.is_error:
                        log.warning("Claude SDK query ended with an error result (%s)", msg.subtype)
                    result_text = msg.result
            # ResultMessage.result repeats the final assistant text; appending it
            # unconditionally would show the answer twice. Use it only as a fallback.
            if not chunks and result_text:
                chunks.append(result_text)

        try:
            # anyio.run (unlike a bare loop.close()) shuts down pending async
            # generators, so the SDK's query generator is finalized even on error.
            anyio.run(_run)
        except CLINotFoundError as e:
            raise RuntimeError(
                "Claude Code CLI not found.\n"
                "Install it: https://claude.ai/code\n"
                "Then run `claude` once in a terminal to authenticate."
            ) from e
        except CLIConnectionError as e:
            detail = str(e).lower()
            if "auth" in detail or "login" in detail or "401" in detail:
                raise RuntimeError(
                    "Claude Code not authenticated.\n"
                    "Run `claude` in a terminal and complete the login flow, then retry."
                ) from e
            raise

        yield from chunks
|