AI
This commit is contained in:
38
cht/agent/base.py
Normal file
38
cht/agent/base.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
Abstract base for agent providers.
|
||||
|
||||
Each provider takes a user message + session context and yields response
|
||||
text chunks for streaming into the UI.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
|
||||
@dataclass
class FrameRef:
    """Reference to a single frame image extracted from a recording."""

    # Frame identifier, e.g. "F0001".
    id: str
    # Absolute path to the JPEG file.
    path: Path
    # Position of the frame, in seconds into the recording.
    timestamp: float
|
||||
|
||||
|
||||
@dataclass
class SessionContext:
    """Session state handed to a provider together with the user message."""

    # Directory of the current session.
    session_dir: Path
    # All frames captured so far.
    frames: list[FrameRef]
    # Current recording duration, in seconds.
    duration: float
    # Frames @-referenced in the message; a fresh empty list when omitted.
    mentioned_frames: list[FrameRef] = field(default_factory=list)
|
||||
|
||||
|
||||
class AgentProvider(ABC):
    """Interface implemented by every agent backend."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Identifier of this provider, used for logging and status display."""
        ...

    @abstractmethod
    def stream(self, message: str, context: SessionContext) -> Iterator[str]:
        """Yield the response to ``message`` as a sequence of text chunks.

        Args:
            message: The user's message.
            context: Session state (frames, duration, mentioned frames).
        """
        ...
|
||||
108
cht/agent/claude_sdk_provider.py
Normal file
108
cht/agent/claude_sdk_provider.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
Agent provider using the Claude Code SDK (claude_agent_sdk).
|
||||
|
||||
Uses your Claude Code subscription — no direct API costs.
|
||||
Passes frame paths in the prompt; Claude reads them visually via the Read tool.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Iterator
|
||||
|
||||
import anyio
|
||||
from claude_agent_sdk import query, ClaudeAgentOptions, AssistantMessage, TextBlock, ResultMessage
|
||||
from claude_agent_sdk import CLINotFoundError, CLIConnectionError
|
||||
|
||||
from cht.agent.base import AgentProvider, SessionContext
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# System prompt for the Claude Code SDK provider. It tells the model to open
# frame screenshots with the Read tool, which is the only tool this provider allows.
SYSTEM_PROMPT = (
    "You are an assistant integrated into CHT, a screen recording and analysis tool.\n"
    "You help the user understand what happened during their recording session.\n"
    "\n"
    "You have access to frame screenshots extracted from the recording. When frames are mentioned,\n"
    "use the Read tool to view them. Frame timestamps are in seconds from the start of the recording.\n"
    "\n"
    "Be concise and specific. Focus on what's visible in the frames."
)
|
||||
|
||||
|
||||
def _build_prompt(message: str, context: SessionContext) -> str:
    """Render the session context and the user message into one prompt string.

    The prompt contains the recording duration, the frame count, a listing of
    every available frame (id, mm:ss timestamp, path), a second listing of the
    frames referenced in the message, and finally the message itself.
    """

    def clock(seconds: float) -> str:
        # Whole seconds rendered as zero-padded minutes:seconds.
        minutes, secs = divmod(int(seconds), 60)
        return f"{minutes:02d}:{secs:02d}"

    def frame_lines(frames) -> list[str]:
        # One listing line per frame; the path is what Claude opens with Read.
        return [f" {ref.id} at {clock(ref.timestamp)} — {ref.path}" for ref in frames]

    parts = [
        f"Recording duration: {clock(context.duration)}",
        f"Total frames captured: {len(context.frames)}",
    ]

    # Every frame is listed so the model can decide which ones to look at.
    if context.frames:
        parts.append("\nAvailable frames:")
        parts.extend(frame_lines(context.frames))

    # Frames the user pointed at explicitly get their own section.
    if context.mentioned_frames:
        parts.append("\nFrames referenced in this message:")
        parts.extend(frame_lines(context.mentioned_frames))

    parts.append(f"\nUser message: {message}")
    return "\n".join(parts)
|
||||
|
||||
|
||||
class ClaudeSDKProvider(AgentProvider):
    """Agent provider backed by claude_agent_sdk.

    Requires the Claude Code CLI to be installed and authenticated.
    """

    def __init__(self, cwd: str | None = None, max_turns: int = 5):
        """Configure the provider.

        Args:
            cwd: Working directory for the SDK query; when None, the
                session directory of the context is used.
            max_turns: Value passed as ``max_turns`` to the SDK options.
        """
        self._cwd = cwd
        self._max_turns = max_turns

    @property
    def name(self) -> str:
        """Identifier of this provider."""
        return "claude-code-sdk"

    def stream(self, message: str, context: SessionContext) -> Iterator[str]:
        """Run one SDK query and yield the response text.

        The whole query runs to completion before the first chunk is
        yielded, so callers receive the text in pieces but not incrementally.

        Raises:
            RuntimeError: If the CLI is missing, or a connection error
                looks like an authentication problem.
            CLIConnectionError: For any other connection failure.
        """
        import asyncio

        prompt = _build_prompt(message, context)
        chunks: list[str] = []

        async def _collect() -> None:
            final_result = None
            async for msg in query(
                prompt=prompt,
                options=ClaudeAgentOptions(
                    cwd=self._cwd or str(context.session_dir),
                    allowed_tools=["Read"],
                    system_prompt=SYSTEM_PROMPT,
                    max_turns=self._max_turns,
                ),
            ):
                if isinstance(msg, AssistantMessage):
                    chunks.extend(
                        block.text
                        for block in msg.content
                        if isinstance(block, TextBlock)
                    )
                elif isinstance(msg, ResultMessage) and msg.result:
                    final_result = msg.result
            # The result message repeats the assistant's final text, so it is
            # only used when no assistant text was collected; appending it
            # unconditionally would show the answer twice.
            if final_result and not chunks:
                chunks.append(final_result)

        try:
            # asyncio.run creates and closes a fresh loop and also finalizes
            # async generators such as the SDK query.
            asyncio.run(_collect())
        except CLINotFoundError as exc:
            raise RuntimeError(
                "Claude Code CLI not found.\n"
                "Install it: https://claude.ai/code\n"
                "Then run `claude` once in a terminal to authenticate."
            ) from exc
        except CLIConnectionError as exc:
            detail = str(exc).lower()
            if any(marker in detail for marker in ("auth", "login", "401")):
                raise RuntimeError(
                    "Claude Code not authenticated.\n"
                    "Run `claude` in a terminal and complete the login flow, then retry."
                ) from exc
            raise

        yield from chunks
|
||||
100
cht/agent/openai_compat_provider.py
Normal file
100
cht/agent/openai_compat_provider.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
Agent provider for OpenAI-compatible APIs (Groq, OpenAI, etc.).
|
||||
|
||||
Sends frame images as base64. Requires GROQ_API_KEY or OPENAI_API_KEY env var.
|
||||
Auto-detects provider from available env keys.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
from typing import Iterator
|
||||
|
||||
from cht.agent.base import AgentProvider, SessionContext, FrameRef
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# System prompt for OpenAI-compatible providers; frames are sent inline as images.
SYSTEM_PROMPT = (
    "You are an assistant integrated into CHT, a screen recording and analysis tool.\n"
    "You help the user understand what happened during their recording session.\n"
    "Be concise and specific. Focus on what's visible in the provided frames."
)
|
||||
|
||||
# Default (base_url, model) per provider
_PROVIDER_DEFAULTS = {
    "groq": ("https://api.groq.com/openai/v1", "meta-llama/llama-4-maverick-17b-128e-instruct"),
    "openai": ("https://api.openai.com/v1", "gpt-4o"),
}


def _detect_provider() -> tuple[str, str, str] | None:
    """Return (api_key, base_url, model) for the first configured provider.

    GROQ_API_KEY is checked before OPENAI_API_KEY. The model comes from
    CHT_MODEL when that variable is set to a non-empty value, otherwise from
    the provider's default. Returns None when neither key is set.
    """
    for env_var, provider in (("GROQ_API_KEY", "groq"), ("OPENAI_API_KEY", "openai")):
        if key := os.environ.get(env_var):
            base_url, default_model = _PROVIDER_DEFAULTS[provider]
            # `or` rather than a .get() default: an empty CHT_MODEL must not
            # produce an empty model name in the API request.
            return key, base_url, os.environ.get("CHT_MODEL") or default_model
    return None
|
||||
|
||||
|
||||
def _frame_to_image_content(frame: FrameRef) -> dict:
    """Read the frame's file and wrap it as an ``image_url`` content part.

    The file bytes are base64-encoded into a ``data:image/jpeg`` URL.
    """
    with open(frame.path, "rb") as image_file:
        raw = image_file.read()
    encoded = base64.b64encode(raw).decode()
    data_url = f"data:image/jpeg;base64,{encoded}"
    return {"type": "image_url", "image_url": {"url": data_url}}
|
||||
|
||||
|
||||
class OpenAICompatProvider(AgentProvider):
    """Agent provider for any OpenAI-compatible chat-completions API.

    The API key, base URL and model are auto-detected from environment
    variables when the provider is constructed.
    """

    def __init__(self):
        """Resolve credentials from the environment.

        Raises:
            RuntimeError: If neither GROQ_API_KEY nor OPENAI_API_KEY is set.
        """
        detected = _detect_provider()
        if not detected:
            raise RuntimeError(
                "No API key found. Set GROQ_API_KEY or OPENAI_API_KEY."
            )
        self._api_key, self._base_url, self._model = detected

    @property
    def name(self) -> str:
        """Provider identifier including the model name."""
        if "groq" in self._base_url:
            return f"groq/{self._model}"
        return f"openai-compat/{self._model}"

    def stream(self, message: str, context: SessionContext) -> Iterator[str]:
        """Send the message plus frame images and yield response deltas.

        Mentioned frames are attached as base64 images; when none are
        mentioned, the last three captured frames are sent instead.
        """
        from openai import OpenAI

        client = OpenAI(api_key=self._api_key, base_url=self._base_url)

        # Build context header
        m, s = divmod(int(context.duration), 60)
        ctx_text = (
            f"Recording duration: {m:02d}:{s:02d}\n"
            f"Total frames: {len(context.frames)}\n"
        )

        # Include mentioned frames as images, fall back to last 3 frames
        frames_to_send = context.mentioned_frames or context.frames[-3:]

        content: list[dict] = [{"type": "text", "text": ctx_text + message}]
        for frame in frames_to_send:
            # Encode first so that a frame that cannot be read contributes
            # neither an image nor a dangling label to the request.
            try:
                image_part = _frame_to_image_content(frame)
            except Exception as e:  # best-effort: skip unreadable frames
                log.warning("Could not encode frame %s: %s", frame.id, e)
                continue
            fm, fs = divmod(int(frame.timestamp), 60)
            content.append({"type": "text", "text": f"{frame.id} at {fm:02d}:{fs:02d}:"})
            content.append(image_part)

        response = client.chat.completions.create(
            model=self._model,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": content},
            ],
            stream=True,
        )
        for chunk in response:
            # Stream chunks may carry an empty `choices` list (e.g. a
            # usage-only chunk); indexing [0] on those would raise IndexError.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta.content
            if delta:
                yield delta
|
||||
140
cht/agent/runner.py
Normal file
140
cht/agent/runner.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
Agent runner — resolves provider, parses @-mentions, dispatches messages.
|
||||
|
||||
Provider selection (in order):
|
||||
1. GROQ_API_KEY → OpenAI-compat / Groq
|
||||
2. OPENAI_API_KEY → OpenAI-compat / OpenAI
|
||||
3. (default) → Claude Code SDK (uses CC subscription)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from typing import Callable
|
||||
|
||||
from cht.agent.base import AgentProvider, FrameRef, SessionContext
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Quick actions: button label -> fixed prompt sent to the agent as a message.
ACTIONS: dict[str, str] = {
    "Summarize": (
        "Summarize what happened in this recording so far. "
        "Look at the captured frames and describe the key content "
        "and any changes you notice."
    ),
    "What changed": (
        "Compare the captured frames in order and describe what changed "
        "between them. Focus on meaningful transitions."
    ),
    "Key moments": (
        "Identify the most important moments in the recording based on "
        "the frames. List them with timestamps."
    ),
    "Describe now": (
        "Look at the most recent frame and describe exactly what is "
        "currently on screen."
    ),
}
|
||||
|
||||
|
||||
def check_claude_cli() -> str | None:
    """Check that the Claude Code CLI is on PATH and runs.

    Returns:
        None when `claude --version` exits with status 0, otherwise a
        human-readable error string. This only verifies that the binary
        runs; it does not perform a login check.
    """
    import shutil
    import subprocess

    claude_path = shutil.which("claude")
    if not claude_path:
        return "Claude Code CLI not found. Install from https://claude.ai/code"
    try:
        # Run the resolved path instead of the bare name: which() honours
        # PATHEXT (e.g. a `claude.cmd` shim on Windows), a bare-name spawn
        # does not.
        result = subprocess.run(
            [claude_path, "--version"],
            capture_output=True,
            timeout=5,
        )
    except (OSError, subprocess.SubprocessError) as e:
        # Covers spawn failures and the 5-second timeout.
        return f"Claude Code CLI check failed: {e}"
    if result.returncode != 0:
        return "Claude Code CLI error. Run `claude` in a terminal to check."
    return None
|
||||
|
||||
|
||||
def _resolve_provider() -> AgentProvider:
    """Pick the provider: OpenAI-compatible if an API key is set, else Claude SDK.

    Provider modules are imported lazily so that only the selected backend's
    dependencies are loaded.
    """
    has_api_key = any(
        os.environ.get(var) for var in ("GROQ_API_KEY", "OPENAI_API_KEY")
    )
    if not has_api_key:
        from cht.agent.claude_sdk_provider import ClaudeSDKProvider

        return ClaudeSDKProvider()

    from cht.agent.openai_compat_provider import OpenAICompatProvider

    return OpenAICompatProvider()
|
||||
|
||||
|
||||
def _parse_mentions(message: str, frames: list[FrameRef]) -> list[FrameRef]:
    """Extract the frames @-referenced in ``message``.

    Accepts ``@F0001``, ``@f1``, ``@1`` and ``@001``; all of them match the
    frame whose id is ``F0001``. Frames are returned in order of first
    mention, without duplicates; references to unknown frames are ignored.
    """
    # Index the frames once instead of scanning the list for every mention.
    # setdefault keeps the first frame when two share an id.
    frames_by_id: dict[str, FrameRef] = {}
    for frame in frames:
        frames_by_id.setdefault(frame.id, frame)

    mentioned = []
    seen = set()
    for match in re.finditer(r"@([Ff]?\d+)", message):
        number = int(match.group(1).lstrip("Ff"))
        fid = f"F{number:04d}"  # normalise to the canonical zero-padded id
        if fid in seen:
            continue
        seen.add(fid)
        frame = frames_by_id.get(fid)
        if frame is not None:
            mentioned.append(frame)
    return mentioned
|
||||
|
||||
|
||||
def _load_frames(frames_dir: Path) -> list[FrameRef]:
    """Load frame references from ``frames_dir / "index.json"``.

    Each index entry must provide ``id``, ``path`` and ``timestamp``. Entries
    whose image file does not exist are skipped. Returns an empty list when
    the index is missing, unreadable or malformed (the problem is logged).
    """
    index_path = frames_dir / "index.json"
    if not index_path.exists():
        return []
    try:
        # JSON is UTF-8; do not depend on the locale's default encoding.
        entries = json.loads(index_path.read_text(encoding="utf-8"))
        frames = []
        for entry in entries:
            frame_path = Path(entry["path"])
            if frame_path.exists():
                frames.append(
                    FrameRef(
                        id=entry["id"],
                        path=frame_path,
                        timestamp=entry["timestamp"],
                    )
                )
        return frames
    except (OSError, ValueError, KeyError, TypeError) as e:
        # OSError: read failure; ValueError: bad JSON or encoding;
        # KeyError/TypeError: entries that do not have the expected shape.
        log.warning("Could not load frames index: %s", e)
        return []
|
||||
|
||||
|
||||
class AgentRunner:
    """Runs agent queries in a background thread, streams chunks to a callback."""

    def __init__(self):
        # Resolved lazily on first use; see _get_provider.
        self._provider: AgentProvider | None = None

    def _get_provider(self) -> AgentProvider:
        """Return the provider, resolving and caching it on first call."""
        if self._provider is None:
            self._provider = _resolve_provider()
            log.info("Agent provider: %s", self._provider.name)
        return self._provider

    @property
    def provider_name(self) -> str:
        """Name of the active provider, or "unknown" if it cannot be resolved."""
        try:
            return self._get_provider().name
        except Exception:
            return "unknown"

    def send(
        self,
        message: str,
        stream_mgr,
        tracker,
        on_chunk: Callable[[str], None],
        on_done: Callable[[str | None], None],
    ):
        """Dispatch message in a background thread.

        on_chunk(text) — called for each streamed chunk
        on_done(error_or_None) — called when complete

        Both callbacks are invoked on the worker thread. ``stream_mgr`` must
        provide ``frames_dir`` and ``session_dir``; ``tracker`` may be None,
        in which case the duration is reported as 0.0.
        """
        def _run():
            try:
                provider = self._get_provider()
                frames = _load_frames(stream_mgr.frames_dir)
                mentioned = _parse_mentions(message, frames)
                context = SessionContext(
                    session_dir=stream_mgr.session_dir,
                    frames=frames,
                    duration=tracker.duration if tracker else 0.0,
                    mentioned_frames=mentioned,
                )
                for chunk in provider.stream(message, context):
                    on_chunk(chunk)
            except Exception as e:
                # log.exception keeps the traceback for diagnosis.
                log.exception("Agent error: %s", e)
                on_done(str(e))
            else:
                # Outside the try block so that an exception raised by the
                # success callback cannot trigger a second on_done call.
                on_done(None)

        Thread(target=_run, daemon=True, name="agent_runner").start()
|
||||
@@ -15,6 +15,7 @@ from cht.ui.timeline import Timeline, TimelineControls
|
||||
from cht.ui.monitor import MonitorWidget
|
||||
from cht.stream.manager import StreamManager
|
||||
from cht.stream.tracker import RecordingTracker
|
||||
from cht.agent.runner import AgentRunner, ACTIONS, check_claude_cli
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -32,6 +33,7 @@ class ChtWindow(Adw.ApplicationWindow):
|
||||
|
||||
# Timeline is the central state machine
|
||||
self._timeline = Timeline()
|
||||
self._agent = AgentRunner()
|
||||
|
||||
# Main layout
|
||||
self._main_paned = Gtk.Paned(orientation=Gtk.Orientation.HORIZONTAL)
|
||||
@@ -61,8 +63,8 @@ class ChtWindow(Adw.ApplicationWindow):
|
||||
self.connect("close-request", self._on_close)
|
||||
log.info("Window initialized")
|
||||
|
||||
# Auto-connect on startup
|
||||
GLib.idle_add(self._start_stream)
|
||||
GLib.idle_add(self._check_agent_auth)
|
||||
|
||||
def _on_connect_clicked(self, button):
|
||||
if self._streaming:
|
||||
@@ -340,34 +342,90 @@ class ChtWindow(Adw.ApplicationWindow):
|
||||
return frame
|
||||
|
||||
def _build_agent_input(self):
|
||||
box = Gtk.Box(orientation=Gtk.Orientation.HORIZONTAL, spacing=4)
|
||||
box.set_margin_start(4)
|
||||
box.set_margin_end(4)
|
||||
box.set_margin_top(4)
|
||||
box.set_margin_bottom(4)
|
||||
outer = Gtk.Box(orientation=Gtk.Orientation.VERTICAL, spacing=4)
|
||||
outer.set_margin_start(4)
|
||||
outer.set_margin_end(4)
|
||||
outer.set_margin_top(4)
|
||||
outer.set_margin_bottom(4)
|
||||
|
||||
# Quick action buttons
|
||||
actions_box = Gtk.Box(orientation=Gtk.Orientation.HORIZONTAL, spacing=4)
|
||||
for label in ACTIONS:
|
||||
btn = Gtk.Button(label=label)
|
||||
btn.add_css_class("flat")
|
||||
btn.connect("clicked", lambda b, l=label: self._send_message(ACTIONS[l]))
|
||||
actions_box.append(btn)
|
||||
outer.append(actions_box)
|
||||
|
||||
# Text entry + send
|
||||
input_row = Gtk.Box(orientation=Gtk.Orientation.HORIZONTAL, spacing=4)
|
||||
self._input_entry = Gtk.Entry()
|
||||
self._input_entry.set_hexpand(True)
|
||||
self._input_entry.set_placeholder_text("Message agent... (use @ to reference frames/transcripts)")
|
||||
self._input_entry.set_placeholder_text("Message agent... (use @F0001 to reference a frame)")
|
||||
self._input_entry.connect("activate", lambda e: self._send_message())
|
||||
box.append(self._input_entry)
|
||||
input_row.append(self._input_entry)
|
||||
|
||||
send_btn = Gtk.Button(label="Send")
|
||||
send_btn.add_css_class("suggested-action")
|
||||
send_btn.connect("clicked", lambda b: self._send_message())
|
||||
box.append(send_btn)
|
||||
input_row.append(send_btn)
|
||||
outer.append(input_row)
|
||||
|
||||
frame = Gtk.Frame()
|
||||
frame.set_child(box)
|
||||
frame.set_child(outer)
|
||||
return frame
|
||||
|
||||
def _send_message(self):
|
||||
text = self._input_entry.get_text().strip()
|
||||
def _send_message(self, text: str | None = None):
|
||||
if text is None:
|
||||
text = self._input_entry.get_text().strip()
|
||||
self._input_entry.set_text("")
|
||||
if not text:
|
||||
return
|
||||
if not self._stream_mgr:
|
||||
self._append_agent_output("No active session.\n")
|
||||
return
|
||||
|
||||
self._append_agent_output(f"\n> {text}\n…\n")
|
||||
|
||||
self._agent.send(
|
||||
message=text,
|
||||
stream_mgr=self._stream_mgr,
|
||||
tracker=self._tracker,
|
||||
on_chunk=lambda chunk: GLib.idle_add(self._replace_thinking, chunk),
|
||||
on_done=lambda err: GLib.idle_add(
|
||||
self._append_agent_output,
|
||||
f"[Error: {err}]\n" if err else "\n"
|
||||
),
|
||||
)
|
||||
self._thinking_replaced = False
|
||||
|
||||
def _replace_thinking(self, chunk: str):
    """Replace the '…' placeholder with the first chunk, then append normally.

    Only the first chunk of a response removes the placeholder; later chunks
    are appended as-is.
    """
    if not self._thinking_replaced:
        self._thinking_replaced = True
        buf = self._agent_output_view.get_buffer()
        # The placeholder tail is '…\n', i.e. 2 characters.
        end = buf.get_end_iter()
        start = end.copy()
        start.backward_chars(2)
        # Delete only if the tail really is the placeholder; if other output
        # was appended in the meantime, a blind delete would remove two
        # characters of unrelated text.
        if buf.get_text(start, end, False) == "…\n":
            buf.delete(start, end)
    self._append_agent_output(chunk)
|
||||
|
||||
def _check_agent_auth(self):
    """Report in the agent output whether the Claude Code CLI is usable.

    Nothing is reported when an API key for an external provider is set,
    because the CLI is not needed in that case.
    """
    import os

    uses_api_key = any(
        os.environ.get(var) for var in ("GROQ_API_KEY", "OPENAI_API_KEY")
    )
    if uses_api_key:
        return  # using external provider, no CLI check needed

    problem = check_claude_cli()
    if not problem:
        status = f"Agent ready ({self._agent.provider_name})\n"
    else:
        status = f"⚠ {problem}\n"
    self._append_agent_output(status)
|
||||
|
||||
def _append_agent_output(self, text: str):
|
||||
buf = self._agent_output_view.get_buffer()
|
||||
buf.insert(buf.get_end_iter(), f"\n> {text}\n")
|
||||
self._input_entry.set_text("")
|
||||
buf.insert(buf.get_end_iter(), text)
|
||||
# Auto-scroll to bottom
|
||||
self._agent_output_view.scroll_to_iter(buf.get_end_iter(), 0, False, 0, 0)
|
||||
|
||||
# -- Frame thumbnails --
|
||||
|
||||
|
||||
Reference in New Issue
Block a user