From 64ecdca71e1f1699f385ab561b1e23432255f323 Mon Sep 17 00:00:00 2001 From: buenosairesam Date: Thu, 9 Apr 2026 14:46:29 -0300 Subject: [PATCH] better agent --- cht/agent/base.py | 336 +++++++++++++++++++++++++--- cht/agent/buffer.py | 50 +++++ cht/agent/claude_sdk_connection.py | 206 +++++++++++++++++ cht/agent/claude_sdk_provider.py | 132 ----------- cht/agent/openai_compat_provider.py | 134 ----------- cht/agent/openai_connection.py | 249 +++++++++++++++++++++ cht/agent/runner.py | 261 ++++++++++++++------- cht/agent/tools.py | 201 +++++++++++++++++ cht/config.py | 4 + cht/ui/agent_output.py | 222 ++++++++++++++---- cht/window.py | 63 +++++- 11 files changed, 1424 insertions(+), 434 deletions(-) create mode 100644 cht/agent/buffer.py create mode 100644 cht/agent/claude_sdk_connection.py delete mode 100644 cht/agent/claude_sdk_provider.py delete mode 100644 cht/agent/openai_compat_provider.py create mode 100644 cht/agent/openai_connection.py create mode 100644 cht/agent/tools.py diff --git a/cht/agent/base.py b/cht/agent/base.py index 17be1f0..783bd04 100644 --- a/cht/agent/base.py +++ b/cht/agent/base.py @@ -1,16 +1,28 @@ """ -Abstract base for agent providers. +Agent data model — structured messages, connections, tools, threads. -Each provider takes a user message + session context and yields response -text chunks for streaming into the UI. +Replaces the old flat AgentProvider/SessionContext with: +- Typed content blocks and messages (conversation model) +- StreamEvent types (connection output) +- AgentConnection protocol (replaces AgentProvider) +- Tool protocol +- Thread (session state, serializable) """ -from abc import ABC, abstractmethod +from __future__ import annotations + +import json +import uuid from dataclasses import dataclass, field +from datetime import datetime, timezone from pathlib import Path -from typing import Iterator +from typing import Iterator, Literal, Protocol, runtime_checkable +# --------------------------------------------------------------------------- +# Shared refs (used by runner mention-parsing and tools) +# --------------------------------------------------------------------------- + @dataclass class FrameRef: id: str # "F0001" @@ -26,40 +38,298 @@ class TranscriptRef: text: str +# --------------------------------------------------------------------------- +# Content blocks +# --------------------------------------------------------------------------- + @dataclass -class SessionContext: +class TextBlock: + text: str + + +@dataclass +class ImageBlock: + frame_id: str # "F0042" + path: Path + timestamp: float + + +@dataclass +class TranscriptBlock: + transcript_id: str # "T0012" + start: float + end: float + text: str + + +ContentBlock = TextBlock | ImageBlock | TranscriptBlock + + +# --------------------------------------------------------------------------- +# Messages +# --------------------------------------------------------------------------- + +@dataclass +class TokenUsage: + input_tokens: int = 0 + output_tokens: int = 0 + + +def _msg_id() -> str: + return uuid.uuid4().hex[:12] + + +@dataclass +class UserMessage: + content: list[ContentBlock] + id: str = field(default_factory=_msg_id) + timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + + +@dataclass +class AssistantMessage: + content: list[ContentBlock] + model: str = "" + id: str = field(default_factory=_msg_id) + timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + token_usage: TokenUsage | None = None + + +@dataclass +class ToolUse: + tool_name: str + input: dict + id: str = field(default_factory=_msg_id) + status: Literal["pending", "running", "done", "error"] = "pending" + + +@dataclass +class ToolResult: + tool_use_id: str + output: str | None = None + error: str | None = None + + +Message = UserMessage | AssistantMessage | ToolUse | ToolResult + + +# --------------------------------------------------------------------------- +# Thread (session state) +# --------------------------------------------------------------------------- + +@dataclass +class Thread: + messages: list[Message] = field(default_factory=list) + id: str = field(default_factory=lambda: uuid.uuid4().hex[:12]) + created: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + updated: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + token_usage: TokenUsage | None = None + + +# --------------------------------------------------------------------------- +# Stream events (yielded by AgentConnection) +# --------------------------------------------------------------------------- + +@dataclass +class TextDelta: + text: str + + +@dataclass +class ToolCallStart: + id: str + name: str + input: dict + + +@dataclass +class ToolCallEnd: + id: str + + +@dataclass +class Done: + stop_reason: str # "end_turn", "tool_use", "max_tokens" + + +@dataclass +class Error: + message: str + + +StreamEvent = TextDelta | ToolCallStart | ToolCallEnd | Done | Error + + +# --------------------------------------------------------------------------- +# Tool protocol +# --------------------------------------------------------------------------- + +@runtime_checkable +class Tool(Protocol): + name: str + description: str + + def input_schema(self) -> dict: ... + def run(self, input: dict, context: ToolContext) -> ToolResult: ... + + +@dataclass +class ToolContext: + """Runtime context passed to tools — references to live app objects.""" session_dir: Path - frames: list[FrameRef] # all captured frames so far - duration: float # current recording duration (seconds) - mentioned_frames: list[FrameRef] = field(default_factory=list) - transcript_segments: list[TranscriptRef] = field(default_factory=list) - mentioned_transcripts: list[TranscriptRef] = field(default_factory=list) - history: list[tuple[str, str]] = field(default_factory=list) # [(role, text), ...] + frames_dir: Path + transcript_dir: Path + stream_mgr: object | None = None # StreamManager, optional + tracker: object | None = None # RecordingTracker, optional -class AgentProvider(ABC): - @abstractmethod - def stream(self, message: str, context: SessionContext) -> Iterator[str]: - """Yield response text chunks.""" +# --------------------------------------------------------------------------- +# AgentConnection protocol (replaces AgentProvider) +# --------------------------------------------------------------------------- + +@runtime_checkable +class AgentConnection(Protocol): + name: str + + def available_models(self) -> list[str]: ... + def get_model(self) -> str: ... + def set_model(self, model: str) -> None: ... + + def prompt( + self, + messages: list[Message], + tools: list[Tool], + ) -> Iterator[StreamEvent]: + """Send messages to the model, yield stream events.""" ... - @property - @abstractmethod - def name(self) -> str: - ... + def cancel(self) -> None: ... - @property - @abstractmethod - def available_models(self) -> list[str]: - """Return list of model IDs this provider supports.""" - ... - @property - @abstractmethod - def model(self) -> str: - ... +# --------------------------------------------------------------------------- +# Thread serialization +# --------------------------------------------------------------------------- - @model.setter - @abstractmethod - def model(self, value: str): - ... +class _Encoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, Path): + return str(o) + if isinstance(o, datetime): + return o.isoformat() + return super().default(o) + + +def _block_to_dict(b: ContentBlock) -> dict: + if isinstance(b, TextBlock): + return {"type": "text", "text": b.text} + if isinstance(b, ImageBlock): + return {"type": "image", "frame_id": b.frame_id, "path": str(b.path), "timestamp": b.timestamp} + if isinstance(b, TranscriptBlock): + return {"type": "transcript", "transcript_id": b.transcript_id, "start": b.start, "end": b.end, "text": b.text} + raise TypeError(f"Unknown block type: {type(b)}") + + +def _block_from_dict(d: dict) -> ContentBlock: + t = d["type"] + if t == "text": + return TextBlock(text=d["text"]) + if t == "image": + return ImageBlock(frame_id=d["frame_id"], path=Path(d["path"]), timestamp=d["timestamp"]) + if t == "transcript": + return TranscriptBlock(transcript_id=d["transcript_id"], start=d["start"], end=d["end"], text=d["text"]) + raise ValueError(f"Unknown block type: {t}") + + +def _msg_to_dict(m: Message) -> dict: + if isinstance(m, UserMessage): + return { + "type": "user", + "id": m.id, + "content": [_block_to_dict(b) for b in m.content], + "timestamp": m.timestamp.isoformat(), + } + if isinstance(m, AssistantMessage): + d = { + "type": "assistant", + "id": m.id, + "content": [_block_to_dict(b) for b in m.content], + "timestamp": m.timestamp.isoformat(), + "model": m.model, + } + if m.token_usage: + d["token_usage"] = {"input_tokens": m.token_usage.input_tokens, "output_tokens": m.token_usage.output_tokens} + return d + if isinstance(m, ToolUse): + return {"type": "tool_use", "id": m.id, "tool_name": m.tool_name, "input": m.input, "status": m.status} + if isinstance(m, ToolResult): + return {"type": "tool_result", "tool_use_id": m.tool_use_id, "output": m.output, "error": m.error} + raise TypeError(f"Unknown message type: {type(m)}") + + +def _msg_from_dict(d: dict) -> Message: + t = d["type"] + if t == "user": + return UserMessage( + id=d["id"], + content=[_block_from_dict(b) for b in d["content"]], + timestamp=datetime.fromisoformat(d["timestamp"]), + ) + if t == "assistant": + tu = None + if "token_usage" in d and d["token_usage"]: + tu = TokenUsage(**d["token_usage"]) + return AssistantMessage( + id=d["id"], + content=[_block_from_dict(b) for b in d["content"]], + timestamp=datetime.fromisoformat(d["timestamp"]), + model=d.get("model", ""), + token_usage=tu, + ) + if t == "tool_use": + return ToolUse(id=d["id"], tool_name=d["tool_name"], input=d["input"], status=d.get("status", "done")) + if t == "tool_result": + return ToolResult(tool_use_id=d["tool_use_id"], output=d.get("output"), error=d.get("error")) + raise ValueError(f"Unknown message type: {t}") + + +def thread_to_dict(thread: Thread) -> dict: + d = { + "id": thread.id, + "messages": [_msg_to_dict(m) for m in thread.messages], + "created": thread.created.isoformat(), + "updated": thread.updated.isoformat(), + } + if thread.token_usage: + d["token_usage"] = {"input_tokens": thread.token_usage.input_tokens, "output_tokens": thread.token_usage.output_tokens} + return d + + +def thread_from_dict(d: dict) -> Thread: + tu = None + if "token_usage" in d and d["token_usage"]: + tu = TokenUsage(**d["token_usage"]) + return Thread( + id=d["id"], + messages=[_msg_from_dict(m) for m in d["messages"]], + created=datetime.fromisoformat(d["created"]), + updated=datetime.fromisoformat(d["updated"]), + token_usage=tu, + ) + + +def save_thread(thread: Thread, session_dir: Path) -> None: + agent_dir = session_dir / "agent" + agent_dir.mkdir(parents=True, exist_ok=True) + path = agent_dir / "thread.json" + thread.updated = datetime.now(timezone.utc) + path.write_text(json.dumps(thread_to_dict(thread), indent=2, cls=_Encoder)) + + +def load_thread(session_dir: Path) -> Thread | None: + path = session_dir / "agent" / "thread.json" + if not path.exists(): + return None + try: + return thread_from_dict(json.loads(path.read_text())) + except Exception: + return None diff --git a/cht/agent/buffer.py b/cht/agent/buffer.py new file mode 100644 index 0000000..95e9796 --- /dev/null +++ b/cht/agent/buffer.py @@ -0,0 +1,50 @@ +"""Streaming text buffer — smooth reveal of bursty network chunks. + +Accumulates text from the provider and reveals it at ~60fps with a +~200ms target reveal pace, so the UI sees a smooth stream regardless +of network burst patterns. + +All methods must be called from the GTK main thread. +""" + +from typing import Callable + +from gi.repository import GLib + + +class StreamingTextBuffer: + TICK_MS = 16 # ~60fps + REVEAL_TARGET_MS = 200 + + def __init__(self, on_reveal: Callable[[str], None]): + self._pending = "" + self._on_reveal = on_reveal + self._timer_id: int | None = None + + def push(self, text: str): + self._pending += text + if self._timer_id is None: + self._timer_id = GLib.timeout_add(self.TICK_MS, self._tick) + + def _tick(self) -> bool: + if not self._pending: + self._timer_id = None + return False + n = max(1, len(self._pending) * self.TICK_MS // self.REVEAL_TARGET_MS) + reveal, self._pending = self._pending[:n], self._pending[n:] + self._on_reveal(reveal) + return True + + def flush(self): + if self._pending: + self._on_reveal(self._pending) + self._pending = "" + if self._timer_id is not None: + GLib.source_remove(self._timer_id) + self._timer_id = None + + def cancel(self): + self._pending = "" + if self._timer_id is not None: + GLib.source_remove(self._timer_id) + self._timer_id = None diff --git a/cht/agent/claude_sdk_connection.py b/cht/agent/claude_sdk_connection.py new file mode 100644 index 0000000..dacaa81 --- /dev/null +++ b/cht/agent/claude_sdk_connection.py @@ -0,0 +1,206 @@ +""" +AgentConnection for Claude Code SDK (claude_agent_sdk). + +Uses your Claude Code subscription — no direct API costs. +Truly streams via a queue bridge between the async SDK generator +and the synchronous Iterator[StreamEvent] interface. +""" + +import logging +import queue +from typing import Iterator + +from cht.agent.base import ( + AgentConnection, + AssistantMessage, + ImageBlock, + Message, + StreamEvent, + TextBlock, + TextDelta, + Tool, + ToolCallEnd, + ToolCallStart, + ToolResult, + ToolUse, + TranscriptBlock, + UserMessage, + Done, + Error, +) + +log = logging.getLogger(__name__) + +SYSTEM_PROMPT = """You are an assistant integrated into CHT, a screen recording and analysis tool. +You help the user understand what happened during their recording session. + +You have access to frame screenshots extracted from the recording. When frames are mentioned, +use the Read tool to view them. Frame timestamps are in seconds from the start of the recording. + +You also have tools to search transcripts, get session info, and capture new frames. + +Be concise and specific. Focus on what's visible in the frames.""" + +MODELS = [ + "claude-sonnet-4-6", + "claude-opus-4-6", + "claude-haiku-4-5", +] + +_SENTINEL = object() + + +def _messages_to_prompt(messages: list[Message]) -> str: + """Flatten structured messages into a text prompt for the SDK.""" + lines = [] + for msg in messages: + if isinstance(msg, UserMessage): + parts = [] + for b in msg.content: + if isinstance(b, TextBlock): + parts.append(b.text) + elif isinstance(b, ImageBlock): + m, s = divmod(int(b.timestamp), 60) + parts.append(f"[Frame {b.frame_id} at {m:02d}:{s:02d} — {b.path}]") + elif isinstance(b, TranscriptBlock): + m1, s1 = divmod(int(b.start), 60) + m2, s2 = divmod(int(b.end), 60) + parts.append(f"[Transcript {b.transcript_id} {m1:02d}:{s1:02d}-{m2:02d}:{s2:02d}: {b.text}]") + lines.append(f"User: {' '.join(parts)}") + elif isinstance(msg, AssistantMessage): + text = " ".join(b.text for b in msg.content if isinstance(b, TextBlock)) + lines.append(f"Assistant: {text}") + elif isinstance(msg, ToolUse): + lines.append(f"[Tool call: {msg.tool_name}({msg.input})]") + elif isinstance(msg, ToolResult): + out = msg.output or msg.error or "" + lines.append(f"[Tool result: {out}]") + return "\n".join(lines) + + +def _tool_schemas(tools: list[Tool]) -> list[str]: + """Extract tool names for the SDK's allowed_tools parameter.""" + # The Claude SDK uses allowed_tools as a list of tool name strings. + # Our custom tools are executed by the runner, not by the SDK, + # so we only pass "Read" to the SDK (for frame viewing). + return ["Read"] + + +class ClaudeSDKConnection: + """AgentConnection using claude_agent_sdk — requires Claude Code CLI.""" + + def __init__(self, cwd: str | None = None, max_turns: int = 5, model: str = MODELS[0]): + self._cwd = cwd + self._max_turns = max_turns + self._model = model + self._cancelled = False + + @property + def name(self) -> str: + return f"claude-sdk/{self._model}" + + def available_models(self) -> list[str]: + return list(MODELS) + + def get_model(self) -> str: + return self._model + + def set_model(self, model: str) -> None: + self._model = model + + def prompt( + self, + messages: list[Message], + tools: list[Tool], + ) -> Iterator[StreamEvent]: + from claude_agent_sdk import ( + query, + ClaudeAgentOptions, + AssistantMessage as SDKAssistantMessage, + TextBlock as SDKTextBlock, + ResultMessage, + CLINotFoundError, + CLIConnectionError, + ) + + prompt_text = _messages_to_prompt(messages) + self._cancelled = False + q: queue.Queue = queue.Queue() + + # Determine cwd from the last UserMessage's image paths if available + cwd = self._cwd + if not cwd: + for msg in reversed(messages): + if isinstance(msg, UserMessage): + for b in msg.content: + if isinstance(b, ImageBlock): + cwd = str(b.path.parent.parent) # session_dir + break + if cwd: + break + + async def _run(): + try: + got_assistant_text = False + async for msg in query( + prompt=prompt_text, + options=ClaudeAgentOptions( + model=self._model, + cwd=cwd or ".", + allowed_tools=_tool_schemas(tools), + system_prompt=SYSTEM_PROMPT, + max_turns=self._max_turns, + ), + ): + if self._cancelled: + break + if isinstance(msg, SDKAssistantMessage): + for block in msg.content: + if isinstance(block, SDKTextBlock): + q.put(TextDelta(text=block.text)) + got_assistant_text = True + elif isinstance(msg, ResultMessage): + # Only use ResultMessage.result if we got no text from AssistantMessages + if msg.result and not got_assistant_text: + q.put(TextDelta(text=msg.result)) + q.put(Done(stop_reason="end_turn")) + except CLINotFoundError: + q.put(Error( + message="Claude Code CLI not found.\n" + "Install it: https://claude.ai/code\n" + "Then run `claude` once in a terminal to authenticate." + )) + except CLIConnectionError as e: + if "auth" in str(e).lower() or "login" in str(e).lower() or "401" in str(e): + q.put(Error( + message="Claude Code not authenticated.\n" + "Run `claude` in a terminal and complete the login flow, then retry." + )) + else: + q.put(Error(message=str(e))) + except Exception as e: + q.put(Error(message=str(e))) + finally: + q.put(_SENTINEL) + + import asyncio + import threading + + def _thread(): + loop = asyncio.new_event_loop() + try: + loop.run_until_complete(_run()) + finally: + loop.close() + + t = threading.Thread(target=_thread, daemon=True, name="claude_sdk_stream") + t.start() + + while True: + item = q.get() + if item is _SENTINEL: + break + yield item + + def cancel(self) -> None: + self._cancelled = True diff --git a/cht/agent/claude_sdk_provider.py b/cht/agent/claude_sdk_provider.py deleted file mode 100644 index 1078860..0000000 --- a/cht/agent/claude_sdk_provider.py +++ /dev/null @@ -1,132 +0,0 @@ -""" -Agent provider using the Claude Code SDK (claude_agent_sdk). - -Uses your Claude Code subscription — no direct API costs. -Passes frame paths in the prompt; Claude reads them visually via the Read tool. -""" - -import logging -from typing import Iterator - -import anyio -from claude_agent_sdk import query, ClaudeAgentOptions, AssistantMessage, TextBlock, ResultMessage -from claude_agent_sdk import CLINotFoundError, CLIConnectionError - -from cht.agent.base import AgentProvider, SessionContext - -log = logging.getLogger(__name__) - -SYSTEM_PROMPT = """You are an assistant integrated into CHT, a screen recording and analysis tool. -You help the user understand what happened during their recording session. - -You have access to frame screenshots extracted from the recording. When frames are mentioned, -use the Read tool to view them. Frame timestamps are in seconds from the start of the recording. - -Be concise and specific. Focus on what's visible in the frames.""" - - -def _build_prompt(message: str, context: SessionContext) -> str: - lines = [] - # Session summary - m, s = divmod(int(context.duration), 60) - lines.append(f"Recording duration: {m:02d}:{s:02d}") - lines.append(f"Total frames captured: {len(context.frames)}") - - if context.mentioned_frames: - lines.append("\nFrames:") - for f in context.mentioned_frames: - fm, fs = divmod(int(f.timestamp), 60) - lines.append(f" {f.id} at {fm:02d}:{fs:02d} — {f.path}") - if context.mentioned_transcripts: - lines.append("\nTranscript:") - for t in context.mentioned_transcripts: - tm1, ts1 = divmod(int(t.start), 60) - tm2, ts2 = divmod(int(t.end), 60) - lines.append(f" {t.id} [{tm1:02d}:{ts1:02d}-{tm2:02d}:{ts2:02d}] {t.text}") - - if context.history: - lines.append("\nConversation history:") - for role, text in context.history: - prefix = "User" if role == "user" else "Assistant" - lines.append(f" {prefix}: {text}") - - lines.append(f"\nUser message: {message}") - return "\n".join(lines) - - -MODELS = [ - "claude-sonnet-4-6", - "claude-opus-4-6", - "claude-haiku-4-5", -] - - -class ClaudeSDKProvider(AgentProvider): - """Uses claude_agent_sdk — requires Claude Code CLI to be installed.""" - - def __init__(self, cwd: str | None = None, max_turns: int = 5, model: str = MODELS[0]): - self._cwd = cwd - self._max_turns = max_turns - self._model = model - - @property - def name(self) -> str: - return f"claude-sdk/{self._model}" - - @property - def available_models(self) -> list[str]: - return list(MODELS) - - @property - def model(self) -> str: - return self._model - - @model.setter - def model(self, value: str): - self._model = value - - def stream(self, message: str, context: SessionContext) -> Iterator[str]: - prompt = _build_prompt(message, context) - chunks = [] - - async def _run(): - async for msg in query( - prompt=prompt, - options=ClaudeAgentOptions( - model=self._model, - cwd=self._cwd or str(context.session_dir), - allowed_tools=["Read"], - system_prompt=SYSTEM_PROMPT, - max_turns=self._max_turns, - ), - ): - if isinstance(msg, AssistantMessage): - for block in msg.content: - if isinstance(block, TextBlock): - chunks.append(block.text) - elif isinstance(msg, ResultMessage): - if msg.result: - chunks.append(msg.result) - - try: - import asyncio - loop = asyncio.new_event_loop() - try: - loop.run_until_complete(_run()) - finally: - loop.close() - except CLINotFoundError: - raise RuntimeError( - "Claude Code CLI not found.\n" - "Install it: https://claude.ai/code\n" - "Then run `claude` once in a terminal to authenticate." - ) - except CLIConnectionError as e: - if "auth" in str(e).lower() or "login" in str(e).lower() or "401" in str(e): - raise RuntimeError( - "Claude Code not authenticated.\n" - "Run `claude` in a terminal and complete the login flow, then retry." - ) - raise - - yield from chunks diff --git a/cht/agent/openai_compat_provider.py b/cht/agent/openai_compat_provider.py deleted file mode 100644 index 89fb6ac..0000000 --- a/cht/agent/openai_compat_provider.py +++ /dev/null @@ -1,134 +0,0 @@ -""" -Agent provider for OpenAI-compatible APIs (Groq, OpenAI, etc.). - -Sends frame images as base64. Requires GROQ_API_KEY or OPENAI_API_KEY env var. -Auto-detects provider from available env keys. -""" - -import base64 -import logging -import os -from typing import Iterator - -from cht.agent.base import AgentProvider, SessionContext, FrameRef - -log = logging.getLogger(__name__) - -SYSTEM_PROMPT = """You are an assistant integrated into CHT, a screen recording and analysis tool. -You help the user understand what happened during their recording session. -Be concise and specific. Focus on what's visible in the provided frames.""" - -# Provider configs: (base_url, default_model, available_models) -_PROVIDER_CONFIGS = { - "groq": ( - "https://api.groq.com/openai/v1", - "meta-llama/llama-4-maverick-17b-128e-instruct", - [ - "meta-llama/llama-4-maverick-17b-128e-instruct", - "meta-llama/llama-4-scout-17b-16e-instruct", - "qwen/qwen-2.5-vl-72b-instruct", - ], - ), - "openai": ( - "https://api.openai.com/v1", - "gpt-4o", - ["gpt-4o", "gpt-4o-mini", "gpt-4.1", "gpt-4.1-mini"], - ), -} - - -def _detect_provider() -> tuple[str, str, str, list[str]] | None: - """Returns (api_key, base_url, model, available_models) or None.""" - if key := os.environ.get("GROQ_API_KEY"): - base_url, default_model, models = _PROVIDER_CONFIGS["groq"] - model = os.environ.get("CHT_MODEL", default_model) - return key, base_url, model, models - if key := os.environ.get("OPENAI_API_KEY"): - base_url, default_model, models = _PROVIDER_CONFIGS["openai"] - model = os.environ.get("CHT_MODEL", default_model) - return key, base_url, model, models - return None - - -def _frame_to_image_content(frame: FrameRef) -> dict: - with open(frame.path, "rb") as f: - data = base64.standard_b64encode(f.read()).decode() - return { - "type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{data}"}, - } - - -class OpenAICompatProvider(AgentProvider): - """Uses any OpenAI-compatible API. Auto-detects from env vars.""" - - def __init__(self): - detected = _detect_provider() - if not detected: - raise RuntimeError( - "No API key found. Set GROQ_API_KEY or OPENAI_API_KEY." - ) - self._api_key, self._base_url, self._model, self._models = detected - - @property - def name(self) -> str: - if "groq" in self._base_url: - return f"groq/{self._model}" - return f"openai-compat/{self._model}" - - @property - def available_models(self) -> list[str]: - return list(self._models) - - @property - def model(self) -> str: - return self._model - - @model.setter - def model(self, value: str): - self._model = value - - def stream(self, message: str, context: SessionContext) -> Iterator[str]: - from openai import OpenAI - - client = OpenAI(api_key=self._api_key, base_url=self._base_url) - - # Build context header - m, s = divmod(int(context.duration), 60) - ctx_lines = [ - f"Recording duration: {m:02d}:{s:02d}", - f"Total frames: {len(context.frames)}", - ] - if context.mentioned_transcripts: - ctx_lines.append("\nTranscript:") - for t in context.mentioned_transcripts: - tm1, ts1 = divmod(int(t.start), 60) - tm2, ts2 = divmod(int(t.end), 60) - ctx_lines.append(f" {t.id} [{tm1:02d}:{ts1:02d}-{tm2:02d}:{ts2:02d}] {t.text}") - ctx_text = "\n".join(ctx_lines) + "\n" - - frames_to_send = context.mentioned_frames - - content: list[dict] = [{"type": "text", "text": ctx_text + message}] - for frame in frames_to_send: - fm, fs = divmod(int(frame.timestamp), 60) - content.append({"type": "text", "text": f"{frame.id} at {fm:02d}:{fs:02d}:"}) - try: - content.append(_frame_to_image_content(frame)) - except Exception as e: - log.warning("Could not encode frame %s: %s", frame.id, e) - - messages = [{"role": "system", "content": SYSTEM_PROMPT}] - for role, text in context.history: - messages.append({"role": role, "content": text}) - messages.append({"role": "user", "content": content}) - - stream = client.chat.completions.create( - model=self._model, - messages=messages, - stream=True, - ) - for chunk in stream: - delta = chunk.choices[0].delta.content - if delta: - yield delta diff --git a/cht/agent/openai_connection.py b/cht/agent/openai_connection.py new file mode 100644 index 0000000..3f3fa50 --- /dev/null +++ b/cht/agent/openai_connection.py @@ -0,0 +1,249 @@ +""" +AgentConnection for OpenAI-compatible APIs (Groq, OpenAI, etc.). + +Sends frame images as base64. Supports tool calls via function calling. +Requires GROQ_API_KEY or OPENAI_API_KEY env var. +""" + +import base64 +import json +import logging +import os +from typing import Iterator + +from cht.agent.base import ( + AssistantMessage, + ImageBlock, + Message, + StreamEvent, + TextBlock, + TextDelta, + Tool, + ToolCallEnd, + ToolCallStart, + ToolResult, + ToolUse, + TranscriptBlock, + UserMessage, + Done, + Error, +) + +log = logging.getLogger(__name__) + +SYSTEM_PROMPT = """You are an assistant integrated into CHT, a screen recording and analysis tool. +You help the user understand what happened during their recording session. +Be concise and specific. Focus on what's visible in the provided frames.""" + +_PROVIDER_CONFIGS = { + "groq": ( + "https://api.groq.com/openai/v1", + "meta-llama/llama-4-maverick-17b-128e-instruct", + [ + "meta-llama/llama-4-maverick-17b-128e-instruct", + "meta-llama/llama-4-scout-17b-16e-instruct", + "qwen/qwen-2.5-vl-72b-instruct", + ], + ), + "openai": ( + "https://api.openai.com/v1", + "gpt-4o", + ["gpt-4o", "gpt-4o-mini", "gpt-4.1", "gpt-4.1-mini"], + ), +} + + +def _detect_provider() -> tuple[str, str, str, list[str]] | None: + if key := os.environ.get("GROQ_API_KEY"): + base_url, default_model, models = _PROVIDER_CONFIGS["groq"] + model = os.environ.get("CHT_MODEL", default_model) + return key, base_url, model, models + if key := os.environ.get("OPENAI_API_KEY"): + base_url, default_model, models = _PROVIDER_CONFIGS["openai"] + model = os.environ.get("CHT_MODEL", default_model) + return key, base_url, model, models + return None + + +def _frame_to_base64(path) -> str | None: + try: + with open(path, "rb") as f: + return base64.standard_b64encode(f.read()).decode() + except Exception as e: + log.warning("Could not encode frame %s: %s", path, e) + return None + + +def _messages_to_openai(messages: list[Message]) -> list[dict]: + """Convert structured messages to OpenAI chat format.""" + result = [{"role": "system", "content": SYSTEM_PROMPT}] + + for msg in messages: + if isinstance(msg, UserMessage): + content: list[dict] = [] + for b in msg.content: + if isinstance(b, TextBlock): + content.append({"type": "text", "text": b.text}) + elif isinstance(b, ImageBlock): + m, s = divmod(int(b.timestamp), 60) + content.append({"type": "text", "text": f"{b.frame_id} at {m:02d}:{s:02d}:"}) + data = _frame_to_base64(b.path) + if data: + content.append({ + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{data}"}, + }) + elif isinstance(b, TranscriptBlock): + m1, s1 = divmod(int(b.start), 60) + m2, s2 = divmod(int(b.end), 60) + content.append({ + "type": "text", + "text": f"{b.transcript_id} [{m1:02d}:{s1:02d}-{m2:02d}:{s2:02d}] {b.text}", + }) + result.append({"role": "user", "content": content}) + + elif isinstance(msg, AssistantMessage): + text = " ".join(b.text for b in msg.content if isinstance(b, TextBlock)) + result.append({"role": "assistant", "content": text}) + + elif isinstance(msg, ToolUse): + result.append({ + "role": "assistant", + "content": None, + "tool_calls": [{ + "id": msg.id, + "type": "function", + "function": {"name": msg.tool_name, "arguments": json.dumps(msg.input)}, + }], + }) + + elif isinstance(msg, ToolResult): + result.append({ + "role": "tool", + "tool_call_id": msg.tool_use_id, + "content": msg.output or msg.error or "", + }) + + return result + + +def _tools_to_openai(tools: list[Tool]) -> list[dict] | None: + if not tools: + return None + return [ + { + "type": "function", + "function": { + "name": t.name, + "description": t.description, + "parameters": t.input_schema(), + }, + } + for t in tools + ] + + +class OpenAIConnection: + """AgentConnection using any OpenAI-compatible API.""" + + def __init__(self): + detected = _detect_provider() + if not detected: + raise RuntimeError("No API key found. Set GROQ_API_KEY or OPENAI_API_KEY.") + self._api_key, self._base_url, self._model, self._models = detected + self._cancelled = False + + @property + def name(self) -> str: + if "groq" in self._base_url: + return f"groq/{self._model}" + return f"openai-compat/{self._model}" + + def available_models(self) -> list[str]: + return list(self._models) + + def get_model(self) -> str: + return self._model + + def set_model(self, model: str) -> None: + self._model = model + + def prompt( + self, + messages: list[Message], + tools: list[Tool], + ) -> Iterator[StreamEvent]: + from openai import OpenAI + + client = OpenAI(api_key=self._api_key, base_url=self._base_url) + self._cancelled = False + + oai_messages = _messages_to_openai(messages) + oai_tools = _tools_to_openai(tools) + + kwargs = { + "model": self._model, + "messages": oai_messages, + "stream": True, + } + if oai_tools: + kwargs["tools"] = oai_tools + + try: + stream = client.chat.completions.create(**kwargs) + except Exception as e: + yield Error(message=str(e)) + return + + # Accumulate tool calls from streaming deltas + tool_calls: dict[int, dict] = {} # index → {id, name, arguments} + + for chunk in stream: + if self._cancelled: + break + choice = chunk.choices[0] if chunk.choices else None + if not choice: + continue + + delta = choice.delta + + # Text content + if delta.content: + yield TextDelta(text=delta.content) + + # Tool calls (streamed incrementally) + if delta.tool_calls: + for tc_delta in delta.tool_calls: + idx = tc_delta.index + if idx not in tool_calls: + tool_calls[idx] = {"id": "", "name": "", "arguments": ""} + if tc_delta.id: + tool_calls[idx]["id"] = tc_delta.id + if tc_delta.function: + if tc_delta.function.name: + tool_calls[idx]["name"] = tc_delta.function.name + if tc_delta.function.arguments: + tool_calls[idx]["arguments"] += tc_delta.function.arguments + + # Check finish reason + if choice.finish_reason: + if choice.finish_reason == "tool_calls": + # Emit accumulated tool calls + for idx in sorted(tool_calls.keys()): + tc = tool_calls[idx] + try: + inp = json.loads(tc["arguments"]) if tc["arguments"] else {} + except json.JSONDecodeError: + inp = {} + yield ToolCallStart(id=tc["id"], name=tc["name"], input=inp) + yield ToolCallEnd(id=tc["id"]) + yield Done(stop_reason="tool_use") + elif choice.finish_reason == "stop": + yield Done(stop_reason="end_turn") + elif choice.finish_reason == "length": + yield Done(stop_reason="max_tokens") + else: + yield Done(stop_reason=choice.finish_reason) + + def cancel(self) -> None: + self._cancelled = True diff --git a/cht/agent/runner.py b/cht/agent/runner.py index e947be1..ece3785 100644 --- a/cht/agent/runner.py +++ b/cht/agent/runner.py @@ -1,13 +1,13 @@ """ -Agent runner — resolves provider, parses @-mentions, dispatches messages. +Agent runner — resolves connection, parses @-mentions, dispatches messages, +executes tool loop. -Provider selection (in order): +Connection selection (in order): 1. GROQ_API_KEY → OpenAI-compat / Groq 2. OPENAI_API_KEY → OpenAI-compat / OpenAI 3. (default) → Claude Code SDK (uses CC subscription) """ -import json import logging import os import re @@ -15,7 +15,29 @@ from pathlib import Path from threading import Thread from typing import Callable -from cht.agent.base import AgentProvider, FrameRef, TranscriptRef, SessionContext +from cht.agent.base import ( + AssistantMessage, + FrameRef, + ImageBlock, + Message, + StreamEvent, + TextBlock, + TextDelta, + ToolCallEnd, + ToolCallStart, + ToolContext, + ToolResult, + ToolUse, + TranscriptBlock, + TranscriptRef, + UserMessage, + Done, + Error, + Thread as AgentThread, + save_thread, + load_thread, +) +from cht.agent.tools import load_frames, load_transcript, BUILTIN_TOOLS log = logging.getLogger(__name__) @@ -25,6 +47,8 @@ ACTIONS: dict[str, str] = { "Answer": "answer", } +MAX_TOOL_TURNS = 10 + def check_claude_cli() -> str | None: """Returns None if OK, or an error string if CLI is missing/unauthenticated.""" @@ -43,16 +67,15 @@ def check_claude_cli() -> str | None: return None -def _resolve_provider() -> AgentProvider: +def _resolve_connection(): if os.environ.get("GROQ_API_KEY") or os.environ.get("OPENAI_API_KEY"): - from cht.agent.openai_compat_provider import OpenAICompatProvider - return OpenAICompatProvider() - from cht.agent.claude_sdk_provider import ClaudeSDKProvider - return ClaudeSDKProvider() + from cht.agent.openai_connection import OpenAIConnection + return OpenAIConnection() + from cht.agent.claude_sdk_connection import ClaudeSDKConnection + return ClaudeSDKConnection() def _expand_ref_nums(spec: str) -> list[int]: - """Expand a ref spec like '2-6' or '2,4,6' or '2-4,6,8-10' into sorted ints.""" nums = set() for part in spec.split(","): part = part.strip() @@ -71,7 +94,6 @@ def _expand_ref_nums(spec: str) -> list[int]: def _parse_mentions(message: str, frames: list[FrameRef]) -> list[FrameRef]: - """Extract @F references. Accepts @F1, @F2-6, @F2,4,6, @F2-4,6,8-10.""" mentioned = [] seen = set() for match in re.finditer(r"@[Ff]([\d,\-]+)", message): @@ -85,49 +107,7 @@ def _parse_mentions(message: str, frames: list[FrameRef]) -> list[FrameRef]: return mentioned -def _resolve_frame_path(frames_dir: Path, raw_path: str) -> Path | None: - """Resolve a frame path from index.json, handling mounted/remote sessions.""" - p = Path(raw_path) - if p.exists(): - return p - # Try relative to frames_dir (handles path prefix mismatch from remote) - local = frames_dir / p.name - if local.exists(): - return local - return None - - -def _load_frames(frames_dir: Path) -> list[FrameRef]: - index_path = frames_dir / "index.json" - if not index_path.exists(): - return [] - try: - entries = json.loads(index_path.read_text()) - frames = [] - for e in entries: - resolved = _resolve_frame_path(frames_dir, e["path"]) - if resolved: - frames.append(FrameRef(id=e["id"], path=resolved, timestamp=e["timestamp"])) - return frames - except Exception as e: - log.warning("Could not load frames index: %s", e) - return [] - - -def _load_transcript(transcript_dir: Path) -> list[TranscriptRef]: - index_path = transcript_dir / "index.json" - if not index_path.exists(): - return [] - try: - entries = json.loads(index_path.read_text()) - return [TranscriptRef(**e) for e in entries] - except Exception as e: - log.warning("Could not load transcript index: %s", e) - return [] - - def _parse_transcript_mentions(message: str, segments: list[TranscriptRef]) -> list[TranscriptRef]: - """Extract @T references. Accepts @T1, @T2-6, @T2,4,6, @T1-3,5,7-10.""" mentioned = [] seen = set() for match in re.finditer(r"@[Tt]([\d,\-]+)", message): @@ -141,84 +121,191 @@ def _parse_transcript_mentions(message: str, segments: list[TranscriptRef]) -> l return mentioned +def _build_user_message(text: str, mentioned_frames: list[FrameRef], + mentioned_transcripts: list[TranscriptRef]) -> UserMessage: + """Build a UserMessage with content blocks from text and @-mentions.""" + content: list = [TextBlock(text=text)] + for f in mentioned_frames: + content.append(ImageBlock(frame_id=f.id, path=f.path, timestamp=f.timestamp)) + for t in mentioned_transcripts: + content.append(TranscriptBlock( + transcript_id=t.id, start=t.start, end=t.end, text=t.text + )) + return UserMessage(content=content) + + class AgentRunner: - """Runs agent queries in a background thread, streams chunks to a callback.""" + """Runs agent queries in a background thread with tool execution loop.""" def __init__(self): - self._provider: AgentProvider | None = None - self._history: list[tuple[str, str]] = [] # (role, text) - self.include_history = False # toggled by UI + self._connection = None + self._thread: AgentThread = AgentThread() + self.include_history = False - def _get_provider(self) -> AgentProvider: - if self._provider is None: - self._provider = _resolve_provider() - log.info("Agent provider: %s", self._provider.name) - return self._provider + @property + def thread(self) -> AgentThread: + return self._thread + + def _get_connection(self): + if self._connection is None: + self._connection = _resolve_connection() + log.info("Agent connection: %s", self._connection.name) + return self._connection @property def provider_name(self) -> str: try: - return self._get_provider().name + return self._get_connection().name except Exception: return "unknown" @property def available_models(self) -> list[str]: try: - return self._get_provider().available_models + return self._get_connection().available_models() except Exception: return [] @property def model(self) -> str: try: - return self._get_provider().model + return self._get_connection().get_model() except Exception: return "" @model.setter def model(self, value: str): - self._get_provider().model = value + self._get_connection().set_model(value) def clear_history(self): - self._history.clear() + self._thread = AgentThread() + + def set_thread(self, thread: AgentThread): + self._thread = thread + + def load_from_session(self, session_dir: Path): + loaded = load_thread(session_dir) + if loaded: + self._thread = loaded + log.info("Loaded thread %s with %d messages", loaded.id, len(loaded.messages)) + else: + self._thread = AgentThread() def send( self, message: str, stream_mgr, tracker, - on_chunk: Callable[[str], None], + on_event: Callable[[StreamEvent], None], on_done: Callable[[str | None], None], ): - """Dispatch message in a background thread. + """Dispatch message in a background thread with tool execution loop. - on_chunk(text) — called for each streamed chunk - on_done(error_or_None) — called when complete + on_event(StreamEvent) — called for each stream event (from bg thread) + on_done(error_or_None) — called when complete (from bg thread) """ def _run(): try: - provider = self._get_provider() - frames = _load_frames(stream_mgr.frames_dir) + connection = self._get_connection() + frames = load_frames(stream_mgr.frames_dir) mentioned_frames = _parse_mentions(message, frames) - transcript = _load_transcript(stream_mgr.transcript_dir) + transcript = load_transcript(stream_mgr.transcript_dir) mentioned_transcripts = _parse_transcript_mentions(message, transcript) - context = SessionContext( + + # Build and append user message + user_msg = _build_user_message(message, mentioned_frames, mentioned_transcripts) + self._thread.messages.append(user_msg) + + # Build tool context + tool_ctx = ToolContext( session_dir=stream_mgr.session_dir, - frames=frames, - duration=tracker.duration if tracker else 0.0, - mentioned_frames=mentioned_frames, - transcript_segments=transcript, - mentioned_transcripts=mentioned_transcripts, - history=list(self._history) if self.include_history else [], + frames_dir=stream_mgr.frames_dir, + transcript_dir=stream_mgr.transcript_dir, + stream_mgr=stream_mgr, + tracker=tracker, ) - self._history.append(("user", message)) - response_chunks = [] - for chunk in provider.stream(message, context): - response_chunks.append(chunk) - on_chunk(chunk) - self._history.append(("assistant", "".join(response_chunks))) + + # Tool registry + tools_by_name = {t.name: t for t in BUILTIN_TOOLS} + + # Messages to send — full thread or just last message + def _get_messages(): + if self.include_history: + return list(self._thread.messages) + return [self._thread.messages[-1]] + + # Tool execution loop + full_text_parts = [] + for _turn in range(MAX_TOOL_TURNS): + msgs = _get_messages() + stop_reason = None + + for event in connection.prompt(msgs, BUILTIN_TOOLS): + on_event(event) + + if isinstance(event, TextDelta): + full_text_parts.append(event.text) + + elif isinstance(event, ToolCallStart): + tool_use = ToolUse( + id=event.id, + tool_name=event.name, + input=event.input, + status="running", + ) + self._thread.messages.append(tool_use) + + elif isinstance(event, ToolCallEnd): + # Execute the tool + tool = tools_by_name.get(event.id) + # Find the ToolUse by id + tool_use_msg = None + for m in reversed(self._thread.messages): + if isinstance(m, ToolUse) and m.id == event.id: + tool_use_msg = m + break + + if tool_use_msg: + tool_impl = tools_by_name.get(tool_use_msg.tool_name) + if tool_impl: + result = tool_impl.run(tool_use_msg.input, tool_ctx) + result.tool_use_id = tool_use_msg.id + tool_use_msg.status = "error" if result.error else "done" + else: + result = ToolResult( + tool_use_id=tool_use_msg.id, + error=f"Unknown tool: {tool_use_msg.tool_name}", + ) + tool_use_msg.status = "error" + self._thread.messages.append(result) + on_event(result) + + elif isinstance(event, Done): + stop_reason = event.stop_reason + break + + elif isinstance(event, Error): + on_done(event.message) + return + + # Build assistant message from accumulated text + if full_text_parts: + asst_msg = AssistantMessage( + content=[TextBlock(text="".join(full_text_parts))], + model=connection.get_model(), + ) + self._thread.messages.append(asst_msg) + + if stop_reason != "tool_use": + break + + # Reset text for next turn (tool loop continues) + full_text_parts = [] + + # Save thread + save_thread(self._thread, stream_mgr.session_dir) on_done(None) + except Exception as e: log.error("Agent error: %s", e) on_done(str(e)) diff --git a/cht/agent/tools.py b/cht/agent/tools.py new file mode 100644 index 0000000..649cb32 --- /dev/null +++ b/cht/agent/tools.py @@ -0,0 +1,201 @@ +"""Built-in agent tools — ReadFrame, SearchTranscript, GetSessionInfo, CaptureFrame. + +Also contains shared data-loading functions (moved from runner.py). +""" + +import json +import logging +from pathlib import Path + +from cht.agent.base import FrameRef, TranscriptRef, ToolContext, ToolResult + +log = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Data loading (shared by tools and runner) +# --------------------------------------------------------------------------- + +def _resolve_frame_path(frames_dir: Path, raw_path: str) -> Path | None: + p = Path(raw_path) + if p.exists(): + return p + local = frames_dir / p.name + if local.exists(): + return local + return None + + +def load_frames(frames_dir: Path) -> list[FrameRef]: + index_path = frames_dir / "index.json" + if not index_path.exists(): + return [] + try: + entries = json.loads(index_path.read_text()) + frames = [] + for e in entries: + resolved = _resolve_frame_path(frames_dir, e["path"]) + if resolved: + frames.append(FrameRef(id=e["id"], path=resolved, timestamp=e["timestamp"])) + return frames + except Exception as e: + log.warning("Could not load frames index: %s", e) + return [] + + +def load_transcript(transcript_dir: Path) -> list[TranscriptRef]: + index_path = transcript_dir / "index.json" + if not index_path.exists(): + return [] + try: + entries = json.loads(index_path.read_text()) + return [TranscriptRef(**e) for e in entries] + except Exception as e: + log.warning("Could not load transcript index: %s", e) + return [] + + +# --------------------------------------------------------------------------- +# Tool implementations +# --------------------------------------------------------------------------- + +class ReadFrameTool: + name = "read_frame" + description = "Read frame screenshots by ID. Returns file paths for visual inspection." + + def input_schema(self) -> dict: + return { + "type": "object", + "properties": { + "frame_ids": { + "type": "array", + "items": {"type": "string"}, + "description": "Frame IDs like ['F0001', 'F0003']", + } + }, + "required": ["frame_ids"], + } + + def run(self, input: dict, context: ToolContext) -> ToolResult: + frame_ids = input.get("frame_ids", []) + frames = load_frames(context.frames_dir) + by_id = {f.id: f for f in frames} + lines = [] + for fid in frame_ids: + f = by_id.get(fid) + if f: + m, s = divmod(int(f.timestamp), 60) + lines.append(f"{f.id} at {m:02d}:{s:02d} — {f.path}") + else: + lines.append(f"{fid}: not found") + return ToolResult(tool_use_id="", output="\n".join(lines)) + + +class SearchTranscriptTool: + name = "search_transcript" + description = "Search transcript segments by text substring and/or time range." + + def input_schema(self) -> dict: + return { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Text substring to search for (optional)"}, + "start": {"type": "number", "description": "Start time in seconds (optional)"}, + "end": {"type": "number", "description": "End time in seconds (optional)"}, + }, + } + + def run(self, input: dict, context: ToolContext) -> ToolResult: + segments = load_transcript(context.transcript_dir) + query = input.get("query", "").lower() + start = input.get("start") + end = input.get("end") + + matches = [] + for seg in segments: + if start is not None and seg.end < start: + continue + if end is not None and seg.start > end: + continue + if query and query not in seg.text.lower(): + continue + matches.append(seg) + + if not matches: + return ToolResult(tool_use_id="", output="No matching transcript segments found.") + + lines = [] + for seg in matches: + m1, s1 = divmod(int(seg.start), 60) + m2, s2 = divmod(int(seg.end), 60) + lines.append(f"{seg.id} [{m1:02d}:{s1:02d}-{m2:02d}:{s2:02d}] {seg.text}") + return ToolResult(tool_use_id="", output="\n".join(lines)) + + +class GetSessionInfoTool: + name = "get_session_info" + description = "Get recording session information: duration, frame count, segment list." + + def input_schema(self) -> dict: + return {"type": "object", "properties": {}} + + def run(self, input: dict, context: ToolContext) -> ToolResult: + frames = load_frames(context.frames_dir) + segments = load_transcript(context.transcript_dir) + + duration = 0.0 + if context.tracker: + duration = getattr(context.tracker, "duration", 0.0) + + m, s = divmod(int(duration), 60) + lines = [ + f"Recording duration: {m:02d}:{s:02d}", + f"Frames captured: {len(frames)}", + f"Transcript segments: {len(segments)}", + ] + + # List recording segments from session dir + stream_dir = context.session_dir / "stream" + if stream_dir.exists(): + recordings = sorted(stream_dir.glob("recording_*.mp4")) + lines.append(f"Recording files: {len(recordings)}") + for rec in recordings: + lines.append(f" {rec.name}") + + return ToolResult(tool_use_id="", output="\n".join(lines)) + + +class CaptureFrameTool: + name = "capture_frame" + description = "Capture a frame at the current recording position." + + def input_schema(self) -> dict: + return {"type": "object", "properties": {}} + + def run(self, input: dict, context: ToolContext) -> ToolResult: + mgr = context.stream_mgr + if mgr is None: + return ToolResult(tool_use_id="", error="No active stream manager") + if getattr(mgr, "readonly", False): + return ToolResult(tool_use_id="", error="Session is read-only, cannot capture") + + import threading + result = {"done": False, "error": None} + event = threading.Event() + + def _on_frames(frames): + result["done"] = True + event.set() + + try: + mgr.capture_now(on_new_frames=_on_frames) + event.wait(timeout=10) + if not result["done"]: + return ToolResult(tool_use_id="", error="Capture timed out") + return ToolResult(tool_use_id="", output="Frame captured successfully.") + except Exception as e: + return ToolResult(tool_use_id="", error=str(e)) + + +# All built-in tools +BUILTIN_TOOLS = [ReadFrameTool(), SearchTranscriptTool(), GetSessionInfoTool(), CaptureFrameTool()] diff --git a/cht/config.py b/cht/config.py index c4b7e09..867b2a6 100644 --- a/cht/config.py +++ b/cht/config.py @@ -31,3 +31,7 @@ WHISPER_MODEL = "medium" # "small" for speed, "medium" for accuracy WHISPER_DEVICE = "cuda" # "cuda" or "cpu" TRANSCRIBE_MIN_CHUNK_S = 5 # minimum seconds of audio before transcribing TRANSCRIBE_LINES_PER_GROUP = 3 # whisper segments grouped per transcript ID (1-5) + +# Agent settings +AGENT_PERMISSION_MODE = "bypassPermissions" # default|acceptEdits|plan|bypassPermissions|dontAsk +AGENT_MAX_TURNS = 5 diff --git a/cht/ui/agent_output.py b/cht/ui/agent_output.py index 315495c..0834633 100644 --- a/cht/ui/agent_output.py +++ b/cht/ui/agent_output.py @@ -1,20 +1,35 @@ -"""Agent output panel — scrollable text view with markdown rendering.""" +"""Agent output panel — single TextView, fully selectable/copy-pastable.""" + +import logging import gi gi.require_version("Gtk", "4.0") -from gi.repository import Gtk +from gi.repository import Gtk, GLib, Pango from cht.ui import markdown +from cht.agent.base import ( + AssistantMessage, + ImageBlock, + TextBlock, + Thread, + ToolResult, + ToolUse, + TranscriptBlock, + UserMessage, +) + +log = logging.getLogger(__name__) class AgentOutputPanel(Gtk.Frame): - """Scrollable text view that displays agent responses with markdown.""" + """Scrollable text view showing the full conversation, copy-pastable.""" def __init__(self, **kwargs): super().__init__(**kwargs) box = Gtk.Box(orientation=Gtk.Orientation.VERTICAL, spacing=0) + # Header header = Gtk.Box(orientation=Gtk.Orientation.HORIZONTAL, spacing=4) header.set_margin_start(8) header.set_margin_end(8) @@ -33,13 +48,16 @@ class AgentOutputPanel(Gtk.Frame): box.append(header) + # Single text view for the whole conversation self._view = Gtk.TextView() self._view.set_editable(False) self._view.set_wrap_mode(Gtk.WrapMode.WORD_CHAR) self._view.set_cursor_visible(False) self._view.set_left_margin(8) self._view.set_right_margin(8) - markdown.setup_tags(self._view.get_buffer()) + self._view.set_top_margin(4) + self._view.set_bottom_margin(4) + self._setup_tags() scroll = Gtk.ScrolledWindow() scroll.set_vexpand(True) @@ -48,52 +66,170 @@ class AgentOutputPanel(Gtk.Frame): self.set_child(box) - # Streaming state - self._thinking_replaced = False - self._response_start_mark = None - self._response_accum = [] + # Streaming state: track where the current assistant response starts + self._response_marks: dict[str, Gtk.TextMark] = {} + self._response_accum: dict[str, list[str]] = {} + + def _setup_tags(self): + buf = self._view.get_buffer() + markdown.setup_tags(buf) + buf.create_tag("user-prefix", weight=700, foreground="#7aafff") + buf.create_tag("assistant-prefix", weight=700, foreground="#8bc78b") + buf.create_tag("tool-prefix", weight=700, foreground="#d4a053") + buf.create_tag("tool-output", foreground="#aaaaaa", left_margin=16) + buf.create_tag("error", foreground="#ff6b6b") + buf.create_tag("ref-chip", foreground="#7aafff", style=Pango.Style.ITALIC) + buf.create_tag("status", foreground="#888888") + + # -- Public API -- def append(self, text: str) -> None: - """Append plain text and auto-scroll.""" + """Append a status/info line.""" buf = self._view.get_buffer() - buf.insert(buf.get_end_iter(), text) - self._view.scroll_to_iter(buf.get_end_iter(), 0, False, 0, 0) + end = buf.get_end_iter() + mark = buf.create_mark(None, end, True) + buf.insert(end, text) + start = buf.get_iter_at_mark(mark) + buf.apply_tag_by_name("status", start, buf.get_end_iter()) + buf.delete_mark(mark) + self._scroll_to_bottom() def clear(self) -> None: - """Clear all output.""" self._view.get_buffer().set_text("") + self._response_marks.clear() + self._response_accum.clear() - def begin_response(self) -> None: - """Reset streaming state for a new response (call before sending).""" - self._thinking_replaced = False - self._response_start_mark = None - self._response_accum = [] - - def replace_thinking(self, chunk: str) -> None: - """Replace the '...' placeholder with streamed chunks.""" + def add_user_message(self, text: str, frames: list | None = None, + transcripts: list | None = None) -> None: buf = self._view.get_buffer() - if not self._thinking_replaced: - self._thinking_replaced = True - end = buf.get_end_iter() - start = end.copy() - start.backward_chars(2) - buf.delete(start, end) - self._response_start_mark = buf.create_mark( - None, buf.get_end_iter(), left_gravity=True - ) - self._response_accum.append(chunk) - self.append(chunk) + end = buf.get_end_iter() - def finish_response(self, err: str | None) -> None: - """Finalize the response — render markdown or show error.""" - if err: - self.append(f"[Error: {err}]\n") - return - if self._response_start_mark and self._response_accum: - buf = self._view.get_buffer() - start = buf.get_iter_at_mark(self._response_start_mark) + # Prefix + mark = buf.create_mark(None, end, True) + buf.insert(end, "\n> ") + buf.apply_tag_by_name("user-prefix", buf.get_iter_at_mark(mark), buf.get_end_iter()) + buf.delete_mark(mark) + + # Text + buf.insert(buf.get_end_iter(), text) + + # Reference chips + refs = [] + if frames: + for f in frames: + fid = f.frame_id if hasattr(f, "frame_id") else (f.id if hasattr(f, "id") else str(f)) + refs.append(fid) + if transcripts: + for t in transcripts: + tid = t.transcript_id if hasattr(t, "transcript_id") else (t.id if hasattr(t, "id") else str(t)) + refs.append(tid) + if refs: end = buf.get_end_iter() - buf.delete(start, end) - markdown.render(buf, start, "".join(self._response_accum)) - buf.delete_mark(self._response_start_mark) - self.append("\n") + mark = buf.create_mark(None, end, True) + buf.insert(end, " [" + ", ".join(refs) + "]") + buf.apply_tag_by_name("ref-chip", buf.get_iter_at_mark(mark), buf.get_end_iter()) + buf.delete_mark(mark) + + buf.insert(buf.get_end_iter(), "\n") + self._scroll_to_bottom() + + def begin_assistant_message(self, msg_id: str) -> None: + buf = self._view.get_buffer() + end = buf.get_end_iter() + + # Mark where this response starts (for markdown re-render on finish) + self._response_marks[msg_id] = buf.create_mark(f"resp_{msg_id}", end, True) + self._response_accum[msg_id] = [] + + def append_to_assistant(self, msg_id: str, text: str) -> None: + if msg_id not in self._response_accum: + return + self._response_accum[msg_id].append(text) + buf = self._view.get_buffer() + buf.insert(buf.get_end_iter(), text) + self._scroll_to_bottom() + + def finish_assistant(self, msg_id: str, full_text: str) -> None: + mark = self._response_marks.pop(msg_id, None) + self._response_accum.pop(msg_id, None) + if not mark: + return + buf = self._view.get_buffer() + start = buf.get_iter_at_mark(mark) + end = buf.get_end_iter() + buf.delete(start, end) + it = buf.get_iter_at_mark(mark) + markdown.render(buf, it, full_text) + buf.insert(buf.get_end_iter(), "\n") + buf.delete_mark(mark) + + def add_tool_call(self, tool_use: ToolUse) -> None: + buf = self._view.get_buffer() + end = buf.get_end_iter() + + mark = buf.create_mark(None, end, True) + buf.insert(end, f" ▶ {tool_use.tool_name}") + buf.apply_tag_by_name("tool-prefix", buf.get_iter_at_mark(mark), buf.get_end_iter()) + buf.delete_mark(mark) + + if tool_use.input: + inp = str(tool_use.input) + if len(inp) > 80: + inp = inp[:77] + "..." + end = buf.get_end_iter() + mark = buf.create_mark(None, end, True) + buf.insert(end, f" {inp}") + buf.apply_tag_by_name("tool-output", buf.get_iter_at_mark(mark), buf.get_end_iter()) + buf.delete_mark(mark) + + buf.insert(buf.get_end_iter(), "\n") + self._scroll_to_bottom() + + def update_tool_result(self, tool_use_id: str, result: ToolResult) -> None: + buf = self._view.get_buffer() + text = result.error or result.output or "" + if not text: + return + end = buf.get_end_iter() + mark = buf.create_mark(None, end, True) + # Indent tool output + indented = "\n".join(f" {line}" for line in text.split("\n")) + tag = "error" if result.error else "tool-output" + buf.insert(end, indented + "\n") + buf.apply_tag_by_name(tag, buf.get_iter_at_mark(mark), buf.get_end_iter()) + buf.delete_mark(mark) + self._scroll_to_bottom() + + def load_thread(self, thread: Thread) -> None: + """Replay a thread to rebuild the conversation view.""" + self.clear() + for msg in thread.messages: + if isinstance(msg, UserMessage): + text_parts = [] + frames = [] + transcripts = [] + for b in msg.content: + if isinstance(b, TextBlock): + text_parts.append(b.text) + elif isinstance(b, ImageBlock): + frames.append(b) + elif isinstance(b, TranscriptBlock): + transcripts.append(b) + self.add_user_message(" ".join(text_parts), frames, transcripts) + elif isinstance(msg, AssistantMessage): + text = " ".join(b.text for b in msg.content if isinstance(b, TextBlock)) + buf = self._view.get_buffer() + it = buf.get_end_iter() + markdown.render(buf, it, text) + buf.insert(buf.get_end_iter(), "\n") + elif isinstance(msg, ToolUse): + self.add_tool_call(msg) + elif isinstance(msg, ToolResult): + self.update_tool_result(msg.tool_use_id, msg) + + def _scroll_to_bottom(self): + def _do(): + adj = self._view.get_parent().get_vadjustment() + adj.set_value(adj.get_upper() - adj.get_page_size()) + return False + GLib.idle_add(_do) diff --git a/cht/window.py b/cht/window.py index d3fb005..f6e8823 100644 --- a/cht/window.py +++ b/cht/window.py @@ -28,6 +28,8 @@ from cht.ui.session_dialog import SessionDialog from cht.session import load_frame_index, load_segment_manifest, rebuild_manifest, global_time_to_segment from cht.scrub.manager import ProxyManager from cht.agent.runner import AgentRunner, check_claude_cli +from cht.agent.base import TextDelta, ToolCallStart, ToolCallEnd, ToolResult, ToolUse, Done, Error +from cht.agent.buffer import StreamingTextBuffer from cht.telemetry import Telemetry log = logging.getLogger(__name__) @@ -245,6 +247,11 @@ class ChtWindow(Adw.ApplicationWindow): self._update_scrub_bar_manifest() self._populate_model_dropdown() + # Load persisted agent conversation + self._agent.load_from_session(mgr.session_dir) + if self._agent.thread.messages: + self._agent_output.load_thread(self._agent.thread) + def _reload_waveform(self, mgr): """Recompute waveform from existing segments in background.""" segments = mgr.recording_segments @@ -289,6 +296,9 @@ class ChtWindow(Adw.ApplicationWindow): self._load_existing_frames() self._load_existing_transcript() self._reload_waveform(mgr) + self._agent.load_from_session(mgr.session_dir) + if self._agent.thread.messages: + self._agent_output.load_thread(self._agent.thread) self.set_title(f"{APP_NAME} — {mgr.session_id}") log.info("Waiting for sender...") @@ -453,6 +463,11 @@ class ChtWindow(Adw.ApplicationWindow): mgr = self._lifecycle.stream_mgr last_session_id = mgr.session_id if mgr and not mgr.readonly else None + # Save agent thread before stopping + if mgr and self._agent.thread.messages: + from cht.agent.base import save_thread + save_thread(self._agent.thread, mgr.session_dir) + if self._telemetry: self._telemetry.close() self._telemetry = None @@ -604,15 +619,53 @@ class ChtWindow(Adw.ApplicationWindow): self._agent_output.append("No active session.\n") return - self._agent_output.append(f"\n> {text}\n…\n") + # Show user message in UI + from cht.agent.runner import _parse_mentions, _parse_transcript_mentions + from cht.agent.tools import load_frames, load_transcript + mgr = self._lifecycle.stream_mgr + frames = load_frames(mgr.frames_dir) + mentioned_frames = _parse_mentions(text, frames) + transcript = load_transcript(mgr.transcript_dir) + mentioned_transcripts = _parse_transcript_mentions(text, transcript) + self._agent_output.add_user_message(text, mentioned_frames, mentioned_transcripts) + + # Prepare streaming + from cht.agent.base import _msg_id + msg_id = _msg_id() + self._agent_output.begin_assistant_message(msg_id) + + full_text_parts = [] + buffer = StreamingTextBuffer( + on_reveal=lambda chunk: self._agent_output.append_to_assistant(msg_id, chunk) + ) + + def on_event(event): + if isinstance(event, TextDelta): + full_text_parts.append(event.text) + GLib.idle_add(buffer.push, event.text) + elif isinstance(event, ToolCallStart): + tu = ToolUse(id=event.id, tool_name=event.name, input=event.input, status="running") + GLib.idle_add(self._agent_output.add_tool_call, tu) + elif isinstance(event, ToolResult): + GLib.idle_add(self._agent_output.update_tool_result, event.tool_use_id, event) + elif isinstance(event, Error): + GLib.idle_add(self._agent_output.append, f"[Error: {event.message}]\n") + + def on_done(err): + def _finish(): + buffer.flush() + if err: + self._agent_output.append(f"[Error: {err}]\n") + else: + self._agent_output.finish_assistant(msg_id, "".join(full_text_parts)) + GLib.idle_add(_finish) - self._agent_output.begin_response() self._agent.send( message=text, - stream_mgr=self._lifecycle.stream_mgr, + stream_mgr=mgr, tracker=self._lifecycle.tracker, - on_chunk=lambda chunk: GLib.idle_add(self._agent_output.replace_thinking, chunk), - on_done=lambda err: GLib.idle_add(self._agent_output.finish_response, err), + on_event=on_event, + on_done=on_done, ) # -- Settings callbacks --