better agent
This commit is contained in:
@@ -1,16 +1,28 @@
|
|||||||
"""
|
"""
|
||||||
Abstract base for agent providers.
|
Agent data model — structured messages, connections, tools, threads.
|
||||||
|
|
||||||
Each provider takes a user message + session context and yields response
|
Replaces the old flat AgentProvider/SessionContext with:
|
||||||
text chunks for streaming into the UI.
|
- Typed content blocks and messages (conversation model)
|
||||||
|
- StreamEvent types (connection output)
|
||||||
|
- AgentConnection protocol (replaces AgentProvider)
|
||||||
|
- Tool protocol
|
||||||
|
- Thread (session state, serializable)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterator
|
from typing import Iterator, Literal, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Shared refs (used by runner mention-parsing and tools)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FrameRef:
|
class FrameRef:
|
||||||
id: str # "F0001"
|
id: str # "F0001"
|
||||||
@@ -26,40 +38,298 @@ class TranscriptRef:
|
|||||||
text: str
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Content blocks
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SessionContext:
|
class TextBlock:
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ImageBlock:
|
||||||
|
frame_id: str # "F0042"
|
||||||
|
path: Path
|
||||||
|
timestamp: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TranscriptBlock:
|
||||||
|
transcript_id: str # "T0012"
|
||||||
|
start: float
|
||||||
|
end: float
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
ContentBlock = TextBlock | ImageBlock | TranscriptBlock
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Messages
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TokenUsage:
|
||||||
|
input_tokens: int = 0
|
||||||
|
output_tokens: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
def _msg_id() -> str:
|
||||||
|
return uuid.uuid4().hex[:12]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UserMessage:
|
||||||
|
content: list[ContentBlock]
|
||||||
|
id: str = field(default_factory=_msg_id)
|
||||||
|
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AssistantMessage:
|
||||||
|
content: list[ContentBlock]
|
||||||
|
model: str = ""
|
||||||
|
id: str = field(default_factory=_msg_id)
|
||||||
|
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
token_usage: TokenUsage | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolUse:
|
||||||
|
tool_name: str
|
||||||
|
input: dict
|
||||||
|
id: str = field(default_factory=_msg_id)
|
||||||
|
status: Literal["pending", "running", "done", "error"] = "pending"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolResult:
|
||||||
|
tool_use_id: str
|
||||||
|
output: str | None = None
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
Message = UserMessage | AssistantMessage | ToolUse | ToolResult
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Thread (session state)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Thread:
|
||||||
|
messages: list[Message] = field(default_factory=list)
|
||||||
|
id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
|
||||||
|
created: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
updated: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
token_usage: TokenUsage | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Stream events (yielded by AgentConnection)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TextDelta:
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolCallStart:
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
input: dict
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolCallEnd:
|
||||||
|
id: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Done:
|
||||||
|
stop_reason: str # "end_turn", "tool_use", "max_tokens"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Error:
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
StreamEvent = TextDelta | ToolCallStart | ToolCallEnd | Done | Error
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tool protocol
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class Tool(Protocol):
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
|
||||||
|
def input_schema(self) -> dict: ...
|
||||||
|
def run(self, input: dict, context: ToolContext) -> ToolResult: ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolContext:
|
||||||
|
"""Runtime context passed to tools — references to live app objects."""
|
||||||
session_dir: Path
|
session_dir: Path
|
||||||
frames: list[FrameRef] # all captured frames so far
|
frames_dir: Path
|
||||||
duration: float # current recording duration (seconds)
|
transcript_dir: Path
|
||||||
mentioned_frames: list[FrameRef] = field(default_factory=list)
|
stream_mgr: object | None = None # StreamManager, optional
|
||||||
transcript_segments: list[TranscriptRef] = field(default_factory=list)
|
tracker: object | None = None # RecordingTracker, optional
|
||||||
mentioned_transcripts: list[TranscriptRef] = field(default_factory=list)
|
|
||||||
history: list[tuple[str, str]] = field(default_factory=list) # [(role, text), ...]
|
|
||||||
|
|
||||||
|
|
||||||
class AgentProvider(ABC):
|
# ---------------------------------------------------------------------------
|
||||||
@abstractmethod
|
# AgentConnection protocol (replaces AgentProvider)
|
||||||
def stream(self, message: str, context: SessionContext) -> Iterator[str]:
|
# ---------------------------------------------------------------------------
|
||||||
"""Yield response text chunks."""
|
|
||||||
|
@runtime_checkable
|
||||||
|
class AgentConnection(Protocol):
|
||||||
|
name: str
|
||||||
|
|
||||||
|
def available_models(self) -> list[str]: ...
|
||||||
|
def get_model(self) -> str: ...
|
||||||
|
def set_model(self, model: str) -> None: ...
|
||||||
|
|
||||||
|
def prompt(
|
||||||
|
self,
|
||||||
|
messages: list[Message],
|
||||||
|
tools: list[Tool],
|
||||||
|
) -> Iterator[StreamEvent]:
|
||||||
|
"""Send messages to the model, yield stream events."""
|
||||||
...
|
...
|
||||||
|
|
||||||
@property
|
def cancel(self) -> None: ...
|
||||||
@abstractmethod
|
|
||||||
def name(self) -> str:
|
|
||||||
...
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def available_models(self) -> list[str]:
|
|
||||||
"""Return list of model IDs this provider supports."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@property
|
# ---------------------------------------------------------------------------
|
||||||
@abstractmethod
|
# Thread serialization
|
||||||
def model(self) -> str:
|
# ---------------------------------------------------------------------------
|
||||||
...
|
|
||||||
|
|
||||||
@model.setter
|
class _Encoder(json.JSONEncoder):
|
||||||
@abstractmethod
|
def default(self, o):
|
||||||
def model(self, value: str):
|
if isinstance(o, Path):
|
||||||
...
|
return str(o)
|
||||||
|
if isinstance(o, datetime):
|
||||||
|
return o.isoformat()
|
||||||
|
return super().default(o)
|
||||||
|
|
||||||
|
|
||||||
|
def _block_to_dict(b: ContentBlock) -> dict:
|
||||||
|
if isinstance(b, TextBlock):
|
||||||
|
return {"type": "text", "text": b.text}
|
||||||
|
if isinstance(b, ImageBlock):
|
||||||
|
return {"type": "image", "frame_id": b.frame_id, "path": str(b.path), "timestamp": b.timestamp}
|
||||||
|
if isinstance(b, TranscriptBlock):
|
||||||
|
return {"type": "transcript", "transcript_id": b.transcript_id, "start": b.start, "end": b.end, "text": b.text}
|
||||||
|
raise TypeError(f"Unknown block type: {type(b)}")
|
||||||
|
|
||||||
|
|
||||||
|
def _block_from_dict(d: dict) -> ContentBlock:
|
||||||
|
t = d["type"]
|
||||||
|
if t == "text":
|
||||||
|
return TextBlock(text=d["text"])
|
||||||
|
if t == "image":
|
||||||
|
return ImageBlock(frame_id=d["frame_id"], path=Path(d["path"]), timestamp=d["timestamp"])
|
||||||
|
if t == "transcript":
|
||||||
|
return TranscriptBlock(transcript_id=d["transcript_id"], start=d["start"], end=d["end"], text=d["text"])
|
||||||
|
raise ValueError(f"Unknown block type: {t}")
|
||||||
|
|
||||||
|
|
||||||
|
def _msg_to_dict(m: Message) -> dict:
|
||||||
|
if isinstance(m, UserMessage):
|
||||||
|
return {
|
||||||
|
"type": "user",
|
||||||
|
"id": m.id,
|
||||||
|
"content": [_block_to_dict(b) for b in m.content],
|
||||||
|
"timestamp": m.timestamp.isoformat(),
|
||||||
|
}
|
||||||
|
if isinstance(m, AssistantMessage):
|
||||||
|
d = {
|
||||||
|
"type": "assistant",
|
||||||
|
"id": m.id,
|
||||||
|
"content": [_block_to_dict(b) for b in m.content],
|
||||||
|
"timestamp": m.timestamp.isoformat(),
|
||||||
|
"model": m.model,
|
||||||
|
}
|
||||||
|
if m.token_usage:
|
||||||
|
d["token_usage"] = {"input_tokens": m.token_usage.input_tokens, "output_tokens": m.token_usage.output_tokens}
|
||||||
|
return d
|
||||||
|
if isinstance(m, ToolUse):
|
||||||
|
return {"type": "tool_use", "id": m.id, "tool_name": m.tool_name, "input": m.input, "status": m.status}
|
||||||
|
if isinstance(m, ToolResult):
|
||||||
|
return {"type": "tool_result", "tool_use_id": m.tool_use_id, "output": m.output, "error": m.error}
|
||||||
|
raise TypeError(f"Unknown message type: {type(m)}")
|
||||||
|
|
||||||
|
|
||||||
|
def _msg_from_dict(d: dict) -> Message:
|
||||||
|
t = d["type"]
|
||||||
|
if t == "user":
|
||||||
|
return UserMessage(
|
||||||
|
id=d["id"],
|
||||||
|
content=[_block_from_dict(b) for b in d["content"]],
|
||||||
|
timestamp=datetime.fromisoformat(d["timestamp"]),
|
||||||
|
)
|
||||||
|
if t == "assistant":
|
||||||
|
tu = None
|
||||||
|
if "token_usage" in d and d["token_usage"]:
|
||||||
|
tu = TokenUsage(**d["token_usage"])
|
||||||
|
return AssistantMessage(
|
||||||
|
id=d["id"],
|
||||||
|
content=[_block_from_dict(b) for b in d["content"]],
|
||||||
|
timestamp=datetime.fromisoformat(d["timestamp"]),
|
||||||
|
model=d.get("model", ""),
|
||||||
|
token_usage=tu,
|
||||||
|
)
|
||||||
|
if t == "tool_use":
|
||||||
|
return ToolUse(id=d["id"], tool_name=d["tool_name"], input=d["input"], status=d.get("status", "done"))
|
||||||
|
if t == "tool_result":
|
||||||
|
return ToolResult(tool_use_id=d["tool_use_id"], output=d.get("output"), error=d.get("error"))
|
||||||
|
raise ValueError(f"Unknown message type: {t}")
|
||||||
|
|
||||||
|
|
||||||
|
def thread_to_dict(thread: Thread) -> dict:
|
||||||
|
d = {
|
||||||
|
"id": thread.id,
|
||||||
|
"messages": [_msg_to_dict(m) for m in thread.messages],
|
||||||
|
"created": thread.created.isoformat(),
|
||||||
|
"updated": thread.updated.isoformat(),
|
||||||
|
}
|
||||||
|
if thread.token_usage:
|
||||||
|
d["token_usage"] = {"input_tokens": thread.token_usage.input_tokens, "output_tokens": thread.token_usage.output_tokens}
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def thread_from_dict(d: dict) -> Thread:
|
||||||
|
tu = None
|
||||||
|
if "token_usage" in d and d["token_usage"]:
|
||||||
|
tu = TokenUsage(**d["token_usage"])
|
||||||
|
return Thread(
|
||||||
|
id=d["id"],
|
||||||
|
messages=[_msg_from_dict(m) for m in d["messages"]],
|
||||||
|
created=datetime.fromisoformat(d["created"]),
|
||||||
|
updated=datetime.fromisoformat(d["updated"]),
|
||||||
|
token_usage=tu,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def save_thread(thread: Thread, session_dir: Path) -> None:
|
||||||
|
agent_dir = session_dir / "agent"
|
||||||
|
agent_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
path = agent_dir / "thread.json"
|
||||||
|
thread.updated = datetime.now(timezone.utc)
|
||||||
|
path.write_text(json.dumps(thread_to_dict(thread), indent=2, cls=_Encoder))
|
||||||
|
|
||||||
|
|
||||||
|
def load_thread(session_dir: Path) -> Thread | None:
|
||||||
|
path = session_dir / "agent" / "thread.json"
|
||||||
|
if not path.exists():
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return thread_from_dict(json.loads(path.read_text()))
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|||||||
50
cht/agent/buffer.py
Normal file
50
cht/agent/buffer.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""Streaming text buffer — smooth reveal of bursty network chunks.
|
||||||
|
|
||||||
|
Accumulates text from the provider and reveals it at ~60fps with a
|
||||||
|
~200ms target reveal pace, so the UI sees a smooth stream regardless
|
||||||
|
of network burst patterns.
|
||||||
|
|
||||||
|
All methods must be called from the GTK main thread.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from gi.repository import GLib
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingTextBuffer:
|
||||||
|
TICK_MS = 16 # ~60fps
|
||||||
|
REVEAL_TARGET_MS = 200
|
||||||
|
|
||||||
|
def __init__(self, on_reveal: Callable[[str], None]):
|
||||||
|
self._pending = ""
|
||||||
|
self._on_reveal = on_reveal
|
||||||
|
self._timer_id: int | None = None
|
||||||
|
|
||||||
|
def push(self, text: str):
|
||||||
|
self._pending += text
|
||||||
|
if self._timer_id is None:
|
||||||
|
self._timer_id = GLib.timeout_add(self.TICK_MS, self._tick)
|
||||||
|
|
||||||
|
def _tick(self) -> bool:
|
||||||
|
if not self._pending:
|
||||||
|
self._timer_id = None
|
||||||
|
return False
|
||||||
|
n = max(1, len(self._pending) * self.TICK_MS // self.REVEAL_TARGET_MS)
|
||||||
|
reveal, self._pending = self._pending[:n], self._pending[n:]
|
||||||
|
self._on_reveal(reveal)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def flush(self):
|
||||||
|
if self._pending:
|
||||||
|
self._on_reveal(self._pending)
|
||||||
|
self._pending = ""
|
||||||
|
if self._timer_id is not None:
|
||||||
|
GLib.source_remove(self._timer_id)
|
||||||
|
self._timer_id = None
|
||||||
|
|
||||||
|
def cancel(self):
|
||||||
|
self._pending = ""
|
||||||
|
if self._timer_id is not None:
|
||||||
|
GLib.source_remove(self._timer_id)
|
||||||
|
self._timer_id = None
|
||||||
206
cht/agent/claude_sdk_connection.py
Normal file
206
cht/agent/claude_sdk_connection.py
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
"""
|
||||||
|
AgentConnection for Claude Code SDK (claude_agent_sdk).
|
||||||
|
|
||||||
|
Uses your Claude Code subscription — no direct API costs.
|
||||||
|
Truly streams via a queue bridge between the async SDK generator
|
||||||
|
and the synchronous Iterator[StreamEvent] interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import queue
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
|
from cht.agent.base import (
|
||||||
|
AgentConnection,
|
||||||
|
AssistantMessage,
|
||||||
|
ImageBlock,
|
||||||
|
Message,
|
||||||
|
StreamEvent,
|
||||||
|
TextBlock,
|
||||||
|
TextDelta,
|
||||||
|
Tool,
|
||||||
|
ToolCallEnd,
|
||||||
|
ToolCallStart,
|
||||||
|
ToolResult,
|
||||||
|
ToolUse,
|
||||||
|
TranscriptBlock,
|
||||||
|
UserMessage,
|
||||||
|
Done,
|
||||||
|
Error,
|
||||||
|
)
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
SYSTEM_PROMPT = """You are an assistant integrated into CHT, a screen recording and analysis tool.
|
||||||
|
You help the user understand what happened during their recording session.
|
||||||
|
|
||||||
|
You have access to frame screenshots extracted from the recording. When frames are mentioned,
|
||||||
|
use the Read tool to view them. Frame timestamps are in seconds from the start of the recording.
|
||||||
|
|
||||||
|
You also have tools to search transcripts, get session info, and capture new frames.
|
||||||
|
|
||||||
|
Be concise and specific. Focus on what's visible in the frames."""
|
||||||
|
|
||||||
|
MODELS = [
|
||||||
|
"claude-sonnet-4-6",
|
||||||
|
"claude-opus-4-6",
|
||||||
|
"claude-haiku-4-5",
|
||||||
|
]
|
||||||
|
|
||||||
|
_SENTINEL = object()
|
||||||
|
|
||||||
|
|
||||||
|
def _messages_to_prompt(messages: list[Message]) -> str:
|
||||||
|
"""Flatten structured messages into a text prompt for the SDK."""
|
||||||
|
lines = []
|
||||||
|
for msg in messages:
|
||||||
|
if isinstance(msg, UserMessage):
|
||||||
|
parts = []
|
||||||
|
for b in msg.content:
|
||||||
|
if isinstance(b, TextBlock):
|
||||||
|
parts.append(b.text)
|
||||||
|
elif isinstance(b, ImageBlock):
|
||||||
|
m, s = divmod(int(b.timestamp), 60)
|
||||||
|
parts.append(f"[Frame {b.frame_id} at {m:02d}:{s:02d} — {b.path}]")
|
||||||
|
elif isinstance(b, TranscriptBlock):
|
||||||
|
m1, s1 = divmod(int(b.start), 60)
|
||||||
|
m2, s2 = divmod(int(b.end), 60)
|
||||||
|
parts.append(f"[Transcript {b.transcript_id} {m1:02d}:{s1:02d}-{m2:02d}:{s2:02d}: {b.text}]")
|
||||||
|
lines.append(f"User: {' '.join(parts)}")
|
||||||
|
elif isinstance(msg, AssistantMessage):
|
||||||
|
text = " ".join(b.text for b in msg.content if isinstance(b, TextBlock))
|
||||||
|
lines.append(f"Assistant: {text}")
|
||||||
|
elif isinstance(msg, ToolUse):
|
||||||
|
lines.append(f"[Tool call: {msg.tool_name}({msg.input})]")
|
||||||
|
elif isinstance(msg, ToolResult):
|
||||||
|
out = msg.output or msg.error or ""
|
||||||
|
lines.append(f"[Tool result: {out}]")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_schemas(tools: list[Tool]) -> list[str]:
|
||||||
|
"""Extract tool names for the SDK's allowed_tools parameter."""
|
||||||
|
# The Claude SDK uses allowed_tools as a list of tool name strings.
|
||||||
|
# Our custom tools are executed by the runner, not by the SDK,
|
||||||
|
# so we only pass "Read" to the SDK (for frame viewing).
|
||||||
|
return ["Read"]
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeSDKConnection:
|
||||||
|
"""AgentConnection using claude_agent_sdk — requires Claude Code CLI."""
|
||||||
|
|
||||||
|
def __init__(self, cwd: str | None = None, max_turns: int = 5, model: str = MODELS[0]):
|
||||||
|
self._cwd = cwd
|
||||||
|
self._max_turns = max_turns
|
||||||
|
self._model = model
|
||||||
|
self._cancelled = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return f"claude-sdk/{self._model}"
|
||||||
|
|
||||||
|
def available_models(self) -> list[str]:
|
||||||
|
return list(MODELS)
|
||||||
|
|
||||||
|
def get_model(self) -> str:
|
||||||
|
return self._model
|
||||||
|
|
||||||
|
def set_model(self, model: str) -> None:
|
||||||
|
self._model = model
|
||||||
|
|
||||||
|
def prompt(
|
||||||
|
self,
|
||||||
|
messages: list[Message],
|
||||||
|
tools: list[Tool],
|
||||||
|
) -> Iterator[StreamEvent]:
|
||||||
|
from claude_agent_sdk import (
|
||||||
|
query,
|
||||||
|
ClaudeAgentOptions,
|
||||||
|
AssistantMessage as SDKAssistantMessage,
|
||||||
|
TextBlock as SDKTextBlock,
|
||||||
|
ResultMessage,
|
||||||
|
CLINotFoundError,
|
||||||
|
CLIConnectionError,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_text = _messages_to_prompt(messages)
|
||||||
|
self._cancelled = False
|
||||||
|
q: queue.Queue = queue.Queue()
|
||||||
|
|
||||||
|
# Determine cwd from the last UserMessage's image paths if available
|
||||||
|
cwd = self._cwd
|
||||||
|
if not cwd:
|
||||||
|
for msg in reversed(messages):
|
||||||
|
if isinstance(msg, UserMessage):
|
||||||
|
for b in msg.content:
|
||||||
|
if isinstance(b, ImageBlock):
|
||||||
|
cwd = str(b.path.parent.parent) # session_dir
|
||||||
|
break
|
||||||
|
if cwd:
|
||||||
|
break
|
||||||
|
|
||||||
|
async def _run():
|
||||||
|
try:
|
||||||
|
got_assistant_text = False
|
||||||
|
async for msg in query(
|
||||||
|
prompt=prompt_text,
|
||||||
|
options=ClaudeAgentOptions(
|
||||||
|
model=self._model,
|
||||||
|
cwd=cwd or ".",
|
||||||
|
allowed_tools=_tool_schemas(tools),
|
||||||
|
system_prompt=SYSTEM_PROMPT,
|
||||||
|
max_turns=self._max_turns,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
if self._cancelled:
|
||||||
|
break
|
||||||
|
if isinstance(msg, SDKAssistantMessage):
|
||||||
|
for block in msg.content:
|
||||||
|
if isinstance(block, SDKTextBlock):
|
||||||
|
q.put(TextDelta(text=block.text))
|
||||||
|
got_assistant_text = True
|
||||||
|
elif isinstance(msg, ResultMessage):
|
||||||
|
# Only use ResultMessage.result if we got no text from AssistantMessages
|
||||||
|
if msg.result and not got_assistant_text:
|
||||||
|
q.put(TextDelta(text=msg.result))
|
||||||
|
q.put(Done(stop_reason="end_turn"))
|
||||||
|
except CLINotFoundError:
|
||||||
|
q.put(Error(
|
||||||
|
message="Claude Code CLI not found.\n"
|
||||||
|
"Install it: https://claude.ai/code\n"
|
||||||
|
"Then run `claude` once in a terminal to authenticate."
|
||||||
|
))
|
||||||
|
except CLIConnectionError as e:
|
||||||
|
if "auth" in str(e).lower() or "login" in str(e).lower() or "401" in str(e):
|
||||||
|
q.put(Error(
|
||||||
|
message="Claude Code not authenticated.\n"
|
||||||
|
"Run `claude` in a terminal and complete the login flow, then retry."
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
q.put(Error(message=str(e)))
|
||||||
|
except Exception as e:
|
||||||
|
q.put(Error(message=str(e)))
|
||||||
|
finally:
|
||||||
|
q.put(_SENTINEL)
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
|
||||||
|
def _thread():
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(_run())
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
t = threading.Thread(target=_thread, daemon=True, name="claude_sdk_stream")
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
item = q.get()
|
||||||
|
if item is _SENTINEL:
|
||||||
|
break
|
||||||
|
yield item
|
||||||
|
|
||||||
|
def cancel(self) -> None:
|
||||||
|
self._cancelled = True
|
||||||
@@ -1,132 +0,0 @@
|
|||||||
"""
|
|
||||||
Agent provider using the Claude Code SDK (claude_agent_sdk).
|
|
||||||
|
|
||||||
Uses your Claude Code subscription — no direct API costs.
|
|
||||||
Passes frame paths in the prompt; Claude reads them visually via the Read tool.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Iterator
|
|
||||||
|
|
||||||
import anyio
|
|
||||||
from claude_agent_sdk import query, ClaudeAgentOptions, AssistantMessage, TextBlock, ResultMessage
|
|
||||||
from claude_agent_sdk import CLINotFoundError, CLIConnectionError
|
|
||||||
|
|
||||||
from cht.agent.base import AgentProvider, SessionContext
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
SYSTEM_PROMPT = """You are an assistant integrated into CHT, a screen recording and analysis tool.
|
|
||||||
You help the user understand what happened during their recording session.
|
|
||||||
|
|
||||||
You have access to frame screenshots extracted from the recording. When frames are mentioned,
|
|
||||||
use the Read tool to view them. Frame timestamps are in seconds from the start of the recording.
|
|
||||||
|
|
||||||
Be concise and specific. Focus on what's visible in the frames."""
|
|
||||||
|
|
||||||
|
|
||||||
def _build_prompt(message: str, context: SessionContext) -> str:
|
|
||||||
lines = []
|
|
||||||
# Session summary
|
|
||||||
m, s = divmod(int(context.duration), 60)
|
|
||||||
lines.append(f"Recording duration: {m:02d}:{s:02d}")
|
|
||||||
lines.append(f"Total frames captured: {len(context.frames)}")
|
|
||||||
|
|
||||||
if context.mentioned_frames:
|
|
||||||
lines.append("\nFrames:")
|
|
||||||
for f in context.mentioned_frames:
|
|
||||||
fm, fs = divmod(int(f.timestamp), 60)
|
|
||||||
lines.append(f" {f.id} at {fm:02d}:{fs:02d} — {f.path}")
|
|
||||||
if context.mentioned_transcripts:
|
|
||||||
lines.append("\nTranscript:")
|
|
||||||
for t in context.mentioned_transcripts:
|
|
||||||
tm1, ts1 = divmod(int(t.start), 60)
|
|
||||||
tm2, ts2 = divmod(int(t.end), 60)
|
|
||||||
lines.append(f" {t.id} [{tm1:02d}:{ts1:02d}-{tm2:02d}:{ts2:02d}] {t.text}")
|
|
||||||
|
|
||||||
if context.history:
|
|
||||||
lines.append("\nConversation history:")
|
|
||||||
for role, text in context.history:
|
|
||||||
prefix = "User" if role == "user" else "Assistant"
|
|
||||||
lines.append(f" {prefix}: {text}")
|
|
||||||
|
|
||||||
lines.append(f"\nUser message: {message}")
|
|
||||||
return "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
MODELS = [
|
|
||||||
"claude-sonnet-4-6",
|
|
||||||
"claude-opus-4-6",
|
|
||||||
"claude-haiku-4-5",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class ClaudeSDKProvider(AgentProvider):
|
|
||||||
"""Uses claude_agent_sdk — requires Claude Code CLI to be installed."""
|
|
||||||
|
|
||||||
def __init__(self, cwd: str | None = None, max_turns: int = 5, model: str = MODELS[0]):
|
|
||||||
self._cwd = cwd
|
|
||||||
self._max_turns = max_turns
|
|
||||||
self._model = model
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return f"claude-sdk/{self._model}"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def available_models(self) -> list[str]:
|
|
||||||
return list(MODELS)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def model(self) -> str:
|
|
||||||
return self._model
|
|
||||||
|
|
||||||
@model.setter
|
|
||||||
def model(self, value: str):
|
|
||||||
self._model = value
|
|
||||||
|
|
||||||
def stream(self, message: str, context: SessionContext) -> Iterator[str]:
|
|
||||||
prompt = _build_prompt(message, context)
|
|
||||||
chunks = []
|
|
||||||
|
|
||||||
async def _run():
|
|
||||||
async for msg in query(
|
|
||||||
prompt=prompt,
|
|
||||||
options=ClaudeAgentOptions(
|
|
||||||
model=self._model,
|
|
||||||
cwd=self._cwd or str(context.session_dir),
|
|
||||||
allowed_tools=["Read"],
|
|
||||||
system_prompt=SYSTEM_PROMPT,
|
|
||||||
max_turns=self._max_turns,
|
|
||||||
),
|
|
||||||
):
|
|
||||||
if isinstance(msg, AssistantMessage):
|
|
||||||
for block in msg.content:
|
|
||||||
if isinstance(block, TextBlock):
|
|
||||||
chunks.append(block.text)
|
|
||||||
elif isinstance(msg, ResultMessage):
|
|
||||||
if msg.result:
|
|
||||||
chunks.append(msg.result)
|
|
||||||
|
|
||||||
try:
|
|
||||||
import asyncio
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(_run())
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
except CLINotFoundError:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Claude Code CLI not found.\n"
|
|
||||||
"Install it: https://claude.ai/code\n"
|
|
||||||
"Then run `claude` once in a terminal to authenticate."
|
|
||||||
)
|
|
||||||
except CLIConnectionError as e:
|
|
||||||
if "auth" in str(e).lower() or "login" in str(e).lower() or "401" in str(e):
|
|
||||||
raise RuntimeError(
|
|
||||||
"Claude Code not authenticated.\n"
|
|
||||||
"Run `claude` in a terminal and complete the login flow, then retry."
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
yield from chunks
|
|
||||||
@@ -1,134 +0,0 @@
|
|||||||
"""
|
|
||||||
Agent provider for OpenAI-compatible APIs (Groq, OpenAI, etc.).
|
|
||||||
|
|
||||||
Sends frame images as base64. Requires GROQ_API_KEY or OPENAI_API_KEY env var.
|
|
||||||
Auto-detects provider from available env keys.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from typing import Iterator
|
|
||||||
|
|
||||||
from cht.agent.base import AgentProvider, SessionContext, FrameRef
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
SYSTEM_PROMPT = """You are an assistant integrated into CHT, a screen recording and analysis tool.
|
|
||||||
You help the user understand what happened during their recording session.
|
|
||||||
Be concise and specific. Focus on what's visible in the provided frames."""
|
|
||||||
|
|
||||||
# Provider configs: (base_url, default_model, available_models)
|
|
||||||
_PROVIDER_CONFIGS = {
|
|
||||||
"groq": (
|
|
||||||
"https://api.groq.com/openai/v1",
|
|
||||||
"meta-llama/llama-4-maverick-17b-128e-instruct",
|
|
||||||
[
|
|
||||||
"meta-llama/llama-4-maverick-17b-128e-instruct",
|
|
||||||
"meta-llama/llama-4-scout-17b-16e-instruct",
|
|
||||||
"qwen/qwen-2.5-vl-72b-instruct",
|
|
||||||
],
|
|
||||||
),
|
|
||||||
"openai": (
|
|
||||||
"https://api.openai.com/v1",
|
|
||||||
"gpt-4o",
|
|
||||||
["gpt-4o", "gpt-4o-mini", "gpt-4.1", "gpt-4.1-mini"],
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _detect_provider() -> tuple[str, str, str, list[str]] | None:
|
|
||||||
"""Returns (api_key, base_url, model, available_models) or None."""
|
|
||||||
if key := os.environ.get("GROQ_API_KEY"):
|
|
||||||
base_url, default_model, models = _PROVIDER_CONFIGS["groq"]
|
|
||||||
model = os.environ.get("CHT_MODEL", default_model)
|
|
||||||
return key, base_url, model, models
|
|
||||||
if key := os.environ.get("OPENAI_API_KEY"):
|
|
||||||
base_url, default_model, models = _PROVIDER_CONFIGS["openai"]
|
|
||||||
model = os.environ.get("CHT_MODEL", default_model)
|
|
||||||
return key, base_url, model, models
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _frame_to_image_content(frame: FrameRef) -> dict:
|
|
||||||
with open(frame.path, "rb") as f:
|
|
||||||
data = base64.standard_b64encode(f.read()).decode()
|
|
||||||
return {
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {"url": f"data:image/jpeg;base64,{data}"},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAICompatProvider(AgentProvider):
    """Uses any OpenAI-compatible API. Auto-detects from env vars.

    Legacy provider (superseded by OpenAIConnection in the new agent model).
    Provider/model selection is delegated to _detect_provider(); the chat
    request is built from the SessionContext fields this class reads:
    duration, frames, mentioned_frames, mentioned_transcripts, history.
    """

    def __init__(self):
        # Resolve (api_key, base_url, model, available_models) from env vars.
        detected = _detect_provider()
        if not detected:
            raise RuntimeError(
                "No API key found. Set GROQ_API_KEY or OPENAI_API_KEY."
            )
        self._api_key, self._base_url, self._model, self._models = detected

    @property
    def name(self) -> str:
        # Label derived from the base URL, not a stored provider name.
        if "groq" in self._base_url:
            return f"groq/{self._model}"
        return f"openai-compat/{self._model}"

    @property
    def available_models(self) -> list[str]:
        # Copy so callers cannot mutate the internal list.
        return list(self._models)

    @property
    def model(self) -> str:
        return self._model

    @model.setter
    def model(self, value: str):
        # No validation against available_models — any string is accepted.
        self._model = value

    def stream(self, message: str, context: SessionContext) -> Iterator[str]:
        """Yield response text chunks for *message* given the session context.

        Builds one user turn: a context header (duration, frame count,
        @-mentioned transcript lines) prepended to the message text, followed
        by a label + image pair for each @-mentioned frame.
        """
        # Imported lazily so the openai package is only required when used.
        from openai import OpenAI

        client = OpenAI(api_key=self._api_key, base_url=self._base_url)

        # Build context header
        m, s = divmod(int(context.duration), 60)
        ctx_lines = [
            f"Recording duration: {m:02d}:{s:02d}",
            f"Total frames: {len(context.frames)}",
        ]
        if context.mentioned_transcripts:
            ctx_lines.append("\nTranscript:")
            for t in context.mentioned_transcripts:
                tm1, ts1 = divmod(int(t.start), 60)
                tm2, ts2 = divmod(int(t.end), 60)
                ctx_lines.append(f"  {t.id} [{tm1:02d}:{ts1:02d}-{tm2:02d}:{ts2:02d}] {t.text}")
        ctx_text = "\n".join(ctx_lines) + "\n"

        # Only frames the user explicitly @-mentioned are sent as images.
        frames_to_send = context.mentioned_frames

        content: list[dict] = [{"type": "text", "text": ctx_text + message}]
        for frame in frames_to_send:
            fm, fs = divmod(int(frame.timestamp), 60)
            # Text label precedes each image so the model can tie ID to image.
            content.append({"type": "text", "text": f"{frame.id} at {fm:02d}:{fs:02d}:"})
            try:
                content.append(_frame_to_image_content(frame))
            except Exception as e:
                # Best-effort: a missing/unreadable frame is logged, not fatal.
                log.warning("Could not encode frame %s: %s", frame.id, e)

        # system prompt, then prior (role, text) history, then this user turn.
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]
        for role, text in context.history:
            messages.append({"role": role, "content": text})
        messages.append({"role": "user", "content": content})

        stream = client.chat.completions.create(
            model=self._model,
            messages=messages,
            stream=True,
        )
        for chunk in stream:
            # Only non-empty content deltas are yielded to the UI.
            delta = chunk.choices[0].delta.content
            if delta:
                yield delta
|
|
||||||
249
cht/agent/openai_connection.py
Normal file
249
cht/agent/openai_connection.py
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
"""
|
||||||
|
AgentConnection for OpenAI-compatible APIs (Groq, OpenAI, etc.).
|
||||||
|
|
||||||
|
Sends frame images as base64. Supports tool calls via function calling.
|
||||||
|
Requires GROQ_API_KEY or OPENAI_API_KEY env var.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
|
from cht.agent.base import (
|
||||||
|
AssistantMessage,
|
||||||
|
ImageBlock,
|
||||||
|
Message,
|
||||||
|
StreamEvent,
|
||||||
|
TextBlock,
|
||||||
|
TextDelta,
|
||||||
|
Tool,
|
||||||
|
ToolCallEnd,
|
||||||
|
ToolCallStart,
|
||||||
|
ToolResult,
|
||||||
|
ToolUse,
|
||||||
|
TranscriptBlock,
|
||||||
|
UserMessage,
|
||||||
|
Done,
|
||||||
|
Error,
|
||||||
|
)
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
SYSTEM_PROMPT = """You are an assistant integrated into CHT, a screen recording and analysis tool.
|
||||||
|
You help the user understand what happened during their recording session.
|
||||||
|
Be concise and specific. Focus on what's visible in the provided frames."""
|
||||||
|
|
||||||
|
# Provider registry used by _detect_provider().
# Maps provider key -> (base_url, default_model, available_models).
_PROVIDER_CONFIGS = {
    "groq": (
        "https://api.groq.com/openai/v1",
        "meta-llama/llama-4-maverick-17b-128e-instruct",
        [
            "meta-llama/llama-4-maverick-17b-128e-instruct",
            "meta-llama/llama-4-scout-17b-16e-instruct",
            "qwen/qwen-2.5-vl-72b-instruct",
        ],
    ),
    "openai": (
        "https://api.openai.com/v1",
        "gpt-4o",
        ["gpt-4o", "gpt-4o-mini", "gpt-4.1", "gpt-4.1-mini"],
    ),
}
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_provider() -> tuple[str, str, str, list[str]] | None:
|
||||||
|
if key := os.environ.get("GROQ_API_KEY"):
|
||||||
|
base_url, default_model, models = _PROVIDER_CONFIGS["groq"]
|
||||||
|
model = os.environ.get("CHT_MODEL", default_model)
|
||||||
|
return key, base_url, model, models
|
||||||
|
if key := os.environ.get("OPENAI_API_KEY"):
|
||||||
|
base_url, default_model, models = _PROVIDER_CONFIGS["openai"]
|
||||||
|
model = os.environ.get("CHT_MODEL", default_model)
|
||||||
|
return key, base_url, model, models
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _frame_to_base64(path) -> str | None:
|
||||||
|
try:
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
return base64.standard_b64encode(f.read()).decode()
|
||||||
|
except Exception as e:
|
||||||
|
log.warning("Could not encode frame %s: %s", path, e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _messages_to_openai(messages: list[Message]) -> list[dict]:
    """Convert structured messages to OpenAI chat format.

    The system prompt always comes first; each structured message then maps to
    one chat entry (user content list, assistant text, assistant tool_calls,
    or tool result).
    """
    out: list[dict] = [{"role": "system", "content": SYSTEM_PROMPT}]

    for msg in messages:
        if isinstance(msg, UserMessage):
            parts: list[dict] = []
            for block in msg.content:
                if isinstance(block, TextBlock):
                    parts.append({"type": "text", "text": block.text})
                elif isinstance(block, ImageBlock):
                    # Label each image with its frame ID and mm:ss timestamp
                    # so the model can reference frames by ID.
                    mins, secs = divmod(int(block.timestamp), 60)
                    parts.append(
                        {"type": "text", "text": f"{block.frame_id} at {mins:02d}:{secs:02d}:"}
                    )
                    encoded = _frame_to_base64(block.path)
                    if encoded:
                        parts.append({
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{encoded}"},
                        })
                elif isinstance(block, TranscriptBlock):
                    sm, ss = divmod(int(block.start), 60)
                    em, es = divmod(int(block.end), 60)
                    parts.append({
                        "type": "text",
                        "text": f"{block.transcript_id} [{sm:02d}:{ss:02d}-{em:02d}:{es:02d}] {block.text}",
                    })
            out.append({"role": "user", "content": parts})

        elif isinstance(msg, AssistantMessage):
            # Flatten assistant text blocks into a single space-joined string.
            joined = " ".join(b.text for b in msg.content if isinstance(b, TextBlock))
            out.append({"role": "assistant", "content": joined})

        elif isinstance(msg, ToolUse):
            # One tool call per ToolUse message, arguments re-serialized as JSON.
            call = {
                "id": msg.id,
                "type": "function",
                "function": {"name": msg.tool_name, "arguments": json.dumps(msg.input)},
            }
            out.append({"role": "assistant", "content": None, "tool_calls": [call]})

        elif isinstance(msg, ToolResult):
            out.append({
                "role": "tool",
                "tool_call_id": msg.tool_use_id,
                "content": msg.output or msg.error or "",
            })

    return out
|
||||||
|
|
||||||
|
|
||||||
|
def _tools_to_openai(tools: list[Tool]) -> list[dict] | None:
    """Translate Tool implementations into OpenAI function-calling specs.

    Returns None (not []) for an empty tool list so callers can omit the
    ``tools`` request parameter entirely.
    """
    if not tools:
        return None
    specs = []
    for tool in tools:
        specs.append({
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.input_schema(),
            },
        })
    return specs
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIConnection:
    """AgentConnection using any OpenAI-compatible API.

    Provider/model selection comes from _detect_provider() (env vars).
    prompt() streams StreamEvent objects: TextDelta for text, then
    ToolCallStart/ToolCallEnd pairs and a Done when the turn finishes.
    Cancellation is cooperative via the _cancelled flag checked per chunk.
    """

    def __init__(self):
        detected = _detect_provider()
        if not detected:
            raise RuntimeError("No API key found. Set GROQ_API_KEY or OPENAI_API_KEY.")
        self._api_key, self._base_url, self._model, self._models = detected
        # Cooperative cancel flag; set by cancel(), polled in prompt().
        self._cancelled = False

    @property
    def name(self) -> str:
        # Label derived from the base URL, not a stored provider name.
        if "groq" in self._base_url:
            return f"groq/{self._model}"
        return f"openai-compat/{self._model}"

    def available_models(self) -> list[str]:
        # Copy so callers cannot mutate the internal list.
        return list(self._models)

    def get_model(self) -> str:
        return self._model

    def set_model(self, model: str) -> None:
        # No validation against available_models — any string is accepted.
        self._model = model

    def prompt(
        self,
        messages: list[Message],
        tools: list[Tool],
    ) -> Iterator[StreamEvent]:
        """Stream one model turn for *messages*, advertising *tools*.

        Yields TextDelta for content chunks; on a tool_calls finish,
        ToolCallStart/ToolCallEnd per accumulated call then Done("tool_use");
        otherwise Done with a normalized stop reason.  Request failures yield
        a single Error event.
        """
        # Imported lazily so the openai package is only required when used.
        from openai import OpenAI

        client = OpenAI(api_key=self._api_key, base_url=self._base_url)
        # Reset so a cancel() from a previous turn does not kill this one.
        self._cancelled = False

        oai_messages = _messages_to_openai(messages)
        oai_tools = _tools_to_openai(tools)

        kwargs = {
            "model": self._model,
            "messages": oai_messages,
            "stream": True,
        }
        if oai_tools:
            kwargs["tools"] = oai_tools

        try:
            stream = client.chat.completions.create(**kwargs)
        except Exception as e:
            # Surface creation failures as a single Error event, not a raise.
            yield Error(message=str(e))
            return

        # Accumulate tool calls from streaming deltas
        tool_calls: dict[int, dict] = {}  # index → {id, name, arguments}

        for chunk in stream:
            if self._cancelled:
                break
            choice = chunk.choices[0] if chunk.choices else None
            if not choice:
                continue

            delta = choice.delta

            # Text content
            if delta.content:
                yield TextDelta(text=delta.content)

            # Tool calls (streamed incrementally): id/name arrive once,
            # arguments arrive as string fragments that must be concatenated.
            if delta.tool_calls:
                for tc_delta in delta.tool_calls:
                    idx = tc_delta.index
                    if idx not in tool_calls:
                        tool_calls[idx] = {"id": "", "name": "", "arguments": ""}
                    if tc_delta.id:
                        tool_calls[idx]["id"] = tc_delta.id
                    if tc_delta.function:
                        if tc_delta.function.name:
                            tool_calls[idx]["name"] = tc_delta.function.name
                        if tc_delta.function.arguments:
                            tool_calls[idx]["arguments"] += tc_delta.function.arguments

            # Check finish reason
            # NOTE(review): the loop continues consuming the stream after
            # yielding Done — assumes the server sends nothing meaningful
            # after finish_reason; confirm, or break here.
            if choice.finish_reason:
                if choice.finish_reason == "tool_calls":
                    # Emit accumulated tool calls
                    for idx in sorted(tool_calls.keys()):
                        tc = tool_calls[idx]
                        try:
                            # Malformed/empty argument JSON degrades to {}.
                            inp = json.loads(tc["arguments"]) if tc["arguments"] else {}
                        except json.JSONDecodeError:
                            inp = {}
                        yield ToolCallStart(id=tc["id"], name=tc["name"], input=inp)
                        yield ToolCallEnd(id=tc["id"])
                    yield Done(stop_reason="tool_use")
                elif choice.finish_reason == "stop":
                    yield Done(stop_reason="end_turn")
                elif choice.finish_reason == "length":
                    yield Done(stop_reason="max_tokens")
                else:
                    # Pass through any other provider-specific reason verbatim.
                    yield Done(stop_reason=choice.finish_reason)

    def cancel(self) -> None:
        """Request cooperative cancellation of the in-flight prompt() stream."""
        self._cancelled = True
|
||||||
@@ -1,13 +1,13 @@
|
|||||||
"""
|
"""
|
||||||
Agent runner — resolves provider, parses @-mentions, dispatches messages.
|
Agent runner — resolves connection, parses @-mentions, dispatches messages,
|
||||||
|
executes tool loop.
|
||||||
|
|
||||||
Provider selection (in order):
|
Connection selection (in order):
|
||||||
1. GROQ_API_KEY → OpenAI-compat / Groq
|
1. GROQ_API_KEY → OpenAI-compat / Groq
|
||||||
2. OPENAI_API_KEY → OpenAI-compat / OpenAI
|
2. OPENAI_API_KEY → OpenAI-compat / OpenAI
|
||||||
3. (default) → Claude Code SDK (uses CC subscription)
|
3. (default) → Claude Code SDK (uses CC subscription)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
@@ -15,7 +15,29 @@ from pathlib import Path
|
|||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
from cht.agent.base import AgentProvider, FrameRef, TranscriptRef, SessionContext
|
from cht.agent.base import (
|
||||||
|
AssistantMessage,
|
||||||
|
FrameRef,
|
||||||
|
ImageBlock,
|
||||||
|
Message,
|
||||||
|
StreamEvent,
|
||||||
|
TextBlock,
|
||||||
|
TextDelta,
|
||||||
|
ToolCallEnd,
|
||||||
|
ToolCallStart,
|
||||||
|
ToolContext,
|
||||||
|
ToolResult,
|
||||||
|
ToolUse,
|
||||||
|
TranscriptBlock,
|
||||||
|
TranscriptRef,
|
||||||
|
UserMessage,
|
||||||
|
Done,
|
||||||
|
Error,
|
||||||
|
Thread as AgentThread,
|
||||||
|
save_thread,
|
||||||
|
load_thread,
|
||||||
|
)
|
||||||
|
from cht.agent.tools import load_frames, load_transcript, BUILTIN_TOOLS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -25,6 +47,8 @@ ACTIONS: dict[str, str] = {
|
|||||||
"Answer": "answer",
|
"Answer": "answer",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MAX_TOOL_TURNS = 10
|
||||||
|
|
||||||
|
|
||||||
def check_claude_cli() -> str | None:
|
def check_claude_cli() -> str | None:
|
||||||
"""Returns None if OK, or an error string if CLI is missing/unauthenticated."""
|
"""Returns None if OK, or an error string if CLI is missing/unauthenticated."""
|
||||||
@@ -43,16 +67,15 @@ def check_claude_cli() -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _resolve_provider() -> AgentProvider:
|
def _resolve_connection():
|
||||||
if os.environ.get("GROQ_API_KEY") or os.environ.get("OPENAI_API_KEY"):
|
if os.environ.get("GROQ_API_KEY") or os.environ.get("OPENAI_API_KEY"):
|
||||||
from cht.agent.openai_compat_provider import OpenAICompatProvider
|
from cht.agent.openai_connection import OpenAIConnection
|
||||||
return OpenAICompatProvider()
|
return OpenAIConnection()
|
||||||
from cht.agent.claude_sdk_provider import ClaudeSDKProvider
|
from cht.agent.claude_sdk_connection import ClaudeSDKConnection
|
||||||
return ClaudeSDKProvider()
|
return ClaudeSDKConnection()
|
||||||
|
|
||||||
|
|
||||||
def _expand_ref_nums(spec: str) -> list[int]:
|
def _expand_ref_nums(spec: str) -> list[int]:
|
||||||
"""Expand a ref spec like '2-6' or '2,4,6' or '2-4,6,8-10' into sorted ints."""
|
|
||||||
nums = set()
|
nums = set()
|
||||||
for part in spec.split(","):
|
for part in spec.split(","):
|
||||||
part = part.strip()
|
part = part.strip()
|
||||||
@@ -71,7 +94,6 @@ def _expand_ref_nums(spec: str) -> list[int]:
|
|||||||
|
|
||||||
|
|
||||||
def _parse_mentions(message: str, frames: list[FrameRef]) -> list[FrameRef]:
|
def _parse_mentions(message: str, frames: list[FrameRef]) -> list[FrameRef]:
|
||||||
"""Extract @F references. Accepts @F1, @F2-6, @F2,4,6, @F2-4,6,8-10."""
|
|
||||||
mentioned = []
|
mentioned = []
|
||||||
seen = set()
|
seen = set()
|
||||||
for match in re.finditer(r"@[Ff]([\d,\-]+)", message):
|
for match in re.finditer(r"@[Ff]([\d,\-]+)", message):
|
||||||
@@ -85,49 +107,7 @@ def _parse_mentions(message: str, frames: list[FrameRef]) -> list[FrameRef]:
|
|||||||
return mentioned
|
return mentioned
|
||||||
|
|
||||||
|
|
||||||
def _resolve_frame_path(frames_dir: Path, raw_path: str) -> Path | None:
|
|
||||||
"""Resolve a frame path from index.json, handling mounted/remote sessions."""
|
|
||||||
p = Path(raw_path)
|
|
||||||
if p.exists():
|
|
||||||
return p
|
|
||||||
# Try relative to frames_dir (handles path prefix mismatch from remote)
|
|
||||||
local = frames_dir / p.name
|
|
||||||
if local.exists():
|
|
||||||
return local
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _load_frames(frames_dir: Path) -> list[FrameRef]:
|
|
||||||
index_path = frames_dir / "index.json"
|
|
||||||
if not index_path.exists():
|
|
||||||
return []
|
|
||||||
try:
|
|
||||||
entries = json.loads(index_path.read_text())
|
|
||||||
frames = []
|
|
||||||
for e in entries:
|
|
||||||
resolved = _resolve_frame_path(frames_dir, e["path"])
|
|
||||||
if resolved:
|
|
||||||
frames.append(FrameRef(id=e["id"], path=resolved, timestamp=e["timestamp"]))
|
|
||||||
return frames
|
|
||||||
except Exception as e:
|
|
||||||
log.warning("Could not load frames index: %s", e)
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def _load_transcript(transcript_dir: Path) -> list[TranscriptRef]:
|
|
||||||
index_path = transcript_dir / "index.json"
|
|
||||||
if not index_path.exists():
|
|
||||||
return []
|
|
||||||
try:
|
|
||||||
entries = json.loads(index_path.read_text())
|
|
||||||
return [TranscriptRef(**e) for e in entries]
|
|
||||||
except Exception as e:
|
|
||||||
log.warning("Could not load transcript index: %s", e)
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_transcript_mentions(message: str, segments: list[TranscriptRef]) -> list[TranscriptRef]:
|
def _parse_transcript_mentions(message: str, segments: list[TranscriptRef]) -> list[TranscriptRef]:
|
||||||
"""Extract @T references. Accepts @T1, @T2-6, @T2,4,6, @T1-3,5,7-10."""
|
|
||||||
mentioned = []
|
mentioned = []
|
||||||
seen = set()
|
seen = set()
|
||||||
for match in re.finditer(r"@[Tt]([\d,\-]+)", message):
|
for match in re.finditer(r"@[Tt]([\d,\-]+)", message):
|
||||||
@@ -141,84 +121,191 @@ def _parse_transcript_mentions(message: str, segments: list[TranscriptRef]) -> l
|
|||||||
return mentioned
|
return mentioned
|
||||||
|
|
||||||
|
|
||||||
|
def _build_user_message(text: str, mentioned_frames: list[FrameRef],
                        mentioned_transcripts: list[TranscriptRef]) -> UserMessage:
    """Build a UserMessage: the prompt text first, then blocks for each
    @-mentioned frame and transcript segment, preserving mention order."""
    blocks: list = [TextBlock(text=text)]
    blocks.extend(
        ImageBlock(frame_id=ref.id, path=ref.path, timestamp=ref.timestamp)
        for ref in mentioned_frames
    )
    blocks.extend(
        TranscriptBlock(transcript_id=seg.id, start=seg.start, end=seg.end, text=seg.text)
        for seg in mentioned_transcripts
    )
    return UserMessage(content=blocks)
|
||||||
|
|
||||||
|
|
||||||
class AgentRunner:
|
class AgentRunner:
|
||||||
"""Runs agent queries in a background thread, streams chunks to a callback."""
|
"""Runs agent queries in a background thread with tool execution loop."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._provider: AgentProvider | None = None
|
self._connection = None
|
||||||
self._history: list[tuple[str, str]] = [] # (role, text)
|
self._thread: AgentThread = AgentThread()
|
||||||
self.include_history = False # toggled by UI
|
self.include_history = False
|
||||||
|
|
||||||
def _get_provider(self) -> AgentProvider:
|
@property
|
||||||
if self._provider is None:
|
def thread(self) -> AgentThread:
|
||||||
self._provider = _resolve_provider()
|
return self._thread
|
||||||
log.info("Agent provider: %s", self._provider.name)
|
|
||||||
return self._provider
|
def _get_connection(self):
|
||||||
|
if self._connection is None:
|
||||||
|
self._connection = _resolve_connection()
|
||||||
|
log.info("Agent connection: %s", self._connection.name)
|
||||||
|
return self._connection
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def provider_name(self) -> str:
|
def provider_name(self) -> str:
|
||||||
try:
|
try:
|
||||||
return self._get_provider().name
|
return self._get_connection().name
|
||||||
except Exception:
|
except Exception:
|
||||||
return "unknown"
|
return "unknown"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def available_models(self) -> list[str]:
|
def available_models(self) -> list[str]:
|
||||||
try:
|
try:
|
||||||
return self._get_provider().available_models
|
return self._get_connection().available_models()
|
||||||
except Exception:
|
except Exception:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model(self) -> str:
|
def model(self) -> str:
|
||||||
try:
|
try:
|
||||||
return self._get_provider().model
|
return self._get_connection().get_model()
|
||||||
except Exception:
|
except Exception:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
@model.setter
|
@model.setter
|
||||||
def model(self, value: str):
|
def model(self, value: str):
|
||||||
self._get_provider().model = value
|
self._get_connection().set_model(value)
|
||||||
|
|
||||||
def clear_history(self):
|
def clear_history(self):
|
||||||
self._history.clear()
|
self._thread = AgentThread()
|
||||||
|
|
||||||
|
def set_thread(self, thread: AgentThread):
|
||||||
|
self._thread = thread
|
||||||
|
|
||||||
|
def load_from_session(self, session_dir: Path):
|
||||||
|
loaded = load_thread(session_dir)
|
||||||
|
if loaded:
|
||||||
|
self._thread = loaded
|
||||||
|
log.info("Loaded thread %s with %d messages", loaded.id, len(loaded.messages))
|
||||||
|
else:
|
||||||
|
self._thread = AgentThread()
|
||||||
|
|
||||||
def send(
|
def send(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
stream_mgr,
|
stream_mgr,
|
||||||
tracker,
|
tracker,
|
||||||
on_chunk: Callable[[str], None],
|
on_event: Callable[[StreamEvent], None],
|
||||||
on_done: Callable[[str | None], None],
|
on_done: Callable[[str | None], None],
|
||||||
):
|
):
|
||||||
"""Dispatch message in a background thread.
|
"""Dispatch message in a background thread with tool execution loop.
|
||||||
|
|
||||||
on_chunk(text) — called for each streamed chunk
|
on_event(StreamEvent) — called for each stream event (from bg thread)
|
||||||
on_done(error_or_None) — called when complete
|
on_done(error_or_None) — called when complete (from bg thread)
|
||||||
"""
|
"""
|
||||||
def _run():
|
def _run():
|
||||||
try:
|
try:
|
||||||
provider = self._get_provider()
|
connection = self._get_connection()
|
||||||
frames = _load_frames(stream_mgr.frames_dir)
|
frames = load_frames(stream_mgr.frames_dir)
|
||||||
mentioned_frames = _parse_mentions(message, frames)
|
mentioned_frames = _parse_mentions(message, frames)
|
||||||
transcript = _load_transcript(stream_mgr.transcript_dir)
|
transcript = load_transcript(stream_mgr.transcript_dir)
|
||||||
mentioned_transcripts = _parse_transcript_mentions(message, transcript)
|
mentioned_transcripts = _parse_transcript_mentions(message, transcript)
|
||||||
context = SessionContext(
|
|
||||||
|
# Build and append user message
|
||||||
|
user_msg = _build_user_message(message, mentioned_frames, mentioned_transcripts)
|
||||||
|
self._thread.messages.append(user_msg)
|
||||||
|
|
||||||
|
# Build tool context
|
||||||
|
tool_ctx = ToolContext(
|
||||||
session_dir=stream_mgr.session_dir,
|
session_dir=stream_mgr.session_dir,
|
||||||
frames=frames,
|
frames_dir=stream_mgr.frames_dir,
|
||||||
duration=tracker.duration if tracker else 0.0,
|
transcript_dir=stream_mgr.transcript_dir,
|
||||||
mentioned_frames=mentioned_frames,
|
stream_mgr=stream_mgr,
|
||||||
transcript_segments=transcript,
|
tracker=tracker,
|
||||||
mentioned_transcripts=mentioned_transcripts,
|
|
||||||
history=list(self._history) if self.include_history else [],
|
|
||||||
)
|
)
|
||||||
self._history.append(("user", message))
|
|
||||||
response_chunks = []
|
# Tool registry
|
||||||
for chunk in provider.stream(message, context):
|
tools_by_name = {t.name: t for t in BUILTIN_TOOLS}
|
||||||
response_chunks.append(chunk)
|
|
||||||
on_chunk(chunk)
|
# Messages to send — full thread or just last message
|
||||||
self._history.append(("assistant", "".join(response_chunks)))
|
def _get_messages():
|
||||||
|
if self.include_history:
|
||||||
|
return list(self._thread.messages)
|
||||||
|
return [self._thread.messages[-1]]
|
||||||
|
|
||||||
|
# Tool execution loop
|
||||||
|
full_text_parts = []
|
||||||
|
for _turn in range(MAX_TOOL_TURNS):
|
||||||
|
msgs = _get_messages()
|
||||||
|
stop_reason = None
|
||||||
|
|
||||||
|
for event in connection.prompt(msgs, BUILTIN_TOOLS):
|
||||||
|
on_event(event)
|
||||||
|
|
||||||
|
if isinstance(event, TextDelta):
|
||||||
|
full_text_parts.append(event.text)
|
||||||
|
|
||||||
|
elif isinstance(event, ToolCallStart):
|
||||||
|
tool_use = ToolUse(
|
||||||
|
id=event.id,
|
||||||
|
tool_name=event.name,
|
||||||
|
input=event.input,
|
||||||
|
status="running",
|
||||||
|
)
|
||||||
|
self._thread.messages.append(tool_use)
|
||||||
|
|
||||||
|
elif isinstance(event, ToolCallEnd):
|
||||||
|
# Execute the tool
|
||||||
|
tool = tools_by_name.get(event.id)
|
||||||
|
# Find the ToolUse by id
|
||||||
|
tool_use_msg = None
|
||||||
|
for m in reversed(self._thread.messages):
|
||||||
|
if isinstance(m, ToolUse) and m.id == event.id:
|
||||||
|
tool_use_msg = m
|
||||||
|
break
|
||||||
|
|
||||||
|
if tool_use_msg:
|
||||||
|
tool_impl = tools_by_name.get(tool_use_msg.tool_name)
|
||||||
|
if tool_impl:
|
||||||
|
result = tool_impl.run(tool_use_msg.input, tool_ctx)
|
||||||
|
result.tool_use_id = tool_use_msg.id
|
||||||
|
tool_use_msg.status = "error" if result.error else "done"
|
||||||
|
else:
|
||||||
|
result = ToolResult(
|
||||||
|
tool_use_id=tool_use_msg.id,
|
||||||
|
error=f"Unknown tool: {tool_use_msg.tool_name}",
|
||||||
|
)
|
||||||
|
tool_use_msg.status = "error"
|
||||||
|
self._thread.messages.append(result)
|
||||||
|
on_event(result)
|
||||||
|
|
||||||
|
elif isinstance(event, Done):
|
||||||
|
stop_reason = event.stop_reason
|
||||||
|
break
|
||||||
|
|
||||||
|
elif isinstance(event, Error):
|
||||||
|
on_done(event.message)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Build assistant message from accumulated text
|
||||||
|
if full_text_parts:
|
||||||
|
asst_msg = AssistantMessage(
|
||||||
|
content=[TextBlock(text="".join(full_text_parts))],
|
||||||
|
model=connection.get_model(),
|
||||||
|
)
|
||||||
|
self._thread.messages.append(asst_msg)
|
||||||
|
|
||||||
|
if stop_reason != "tool_use":
|
||||||
|
break
|
||||||
|
|
||||||
|
# Reset text for next turn (tool loop continues)
|
||||||
|
full_text_parts = []
|
||||||
|
|
||||||
|
# Save thread
|
||||||
|
save_thread(self._thread, stream_mgr.session_dir)
|
||||||
on_done(None)
|
on_done(None)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error("Agent error: %s", e)
|
log.error("Agent error: %s", e)
|
||||||
on_done(str(e))
|
on_done(str(e))
|
||||||
|
|||||||
201
cht/agent/tools.py
Normal file
201
cht/agent/tools.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
"""Built-in agent tools — ReadFrame, SearchTranscript, GetSessionInfo, CaptureFrame.
|
||||||
|
|
||||||
|
Also contains shared data-loading functions (moved from runner.py).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from cht.agent.base import FrameRef, TranscriptRef, ToolContext, ToolResult
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Data loading (shared by tools and runner)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _resolve_frame_path(frames_dir: Path, raw_path: str) -> Path | None:
|
||||||
|
p = Path(raw_path)
|
||||||
|
if p.exists():
|
||||||
|
return p
|
||||||
|
local = frames_dir / p.name
|
||||||
|
if local.exists():
|
||||||
|
return local
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def load_frames(frames_dir: Path) -> list[FrameRef]:
    """Load frame refs from frames_dir/index.json.

    Entries whose file cannot be resolved on disk are silently dropped.
    Returns [] when the index is missing or unreadable (logged).
    """
    index_path = frames_dir / "index.json"
    if not index_path.exists():
        return []
    try:
        refs = []
        for entry in json.loads(index_path.read_text()):
            resolved = _resolve_frame_path(frames_dir, entry["path"])
            if resolved is not None:
                refs.append(FrameRef(id=entry["id"], path=resolved, timestamp=entry["timestamp"]))
        return refs
    except Exception as e:
        log.warning("Could not load frames index: %s", e)
        return []
|
||||||
|
|
||||||
|
|
||||||
|
def load_transcript(transcript_dir: Path) -> list[TranscriptRef]:
    """Load transcript refs from transcript_dir/index.json.

    Returns [] when the index is missing or unreadable (logged).
    """
    index_path = transcript_dir / "index.json"
    if not index_path.exists():
        return []
    try:
        return [TranscriptRef(**entry) for entry in json.loads(index_path.read_text())]
    except Exception as e:
        log.warning("Could not load transcript index: %s", e)
        return []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tool implementations
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class ReadFrameTool:
    """Tool: resolve frame IDs to their on-disk screenshot paths."""

    name = "read_frame"
    description = "Read frame screenshots by ID. Returns file paths for visual inspection."

    def input_schema(self) -> dict:
        """JSON Schema for the tool's input: a required list of frame IDs."""
        return {
            "type": "object",
            "properties": {
                "frame_ids": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Frame IDs like ['F0001', 'F0003']",
                }
            },
            "required": ["frame_ids"],
        }

    def run(self, input: dict, context: ToolContext) -> ToolResult:
        """Emit one line per requested ID: timestamp + path, or 'not found'."""
        requested = input.get("frame_ids", [])
        known = {f.id: f for f in load_frames(context.frames_dir)}
        report = []
        for fid in requested:
            frame = known.get(fid)
            if frame is None:
                report.append(f"{fid}: not found")
            else:
                mins, secs = divmod(int(frame.timestamp), 60)
                report.append(f"{frame.id} at {mins:02d}:{secs:02d} — {frame.path}")
        return ToolResult(tool_use_id="", output="\n".join(report))
|
||||||
|
|
||||||
|
|
||||||
|
class SearchTranscriptTool:
    """Tool: filter transcript segments by substring and/or time window."""

    name = "search_transcript"
    description = "Search transcript segments by text substring and/or time range."

    def input_schema(self) -> dict:
        """JSON Schema for the tool's input; all three fields are optional."""
        return {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "Text substring to search for (optional)"},
                "start": {"type": "number", "description": "Start time in seconds (optional)"},
                "end": {"type": "number", "description": "End time in seconds (optional)"},
            },
        }

    def run(self, input: dict, context: ToolContext) -> ToolResult:
        """Return matching segments formatted as '<id> [mm:ss-mm:ss] <text>'."""
        needle = input.get("query", "").lower()
        window_start = input.get("start")
        window_end = input.get("end")

        def _keep(seg) -> bool:
            # A segment matches when it overlaps the time window (if any)
            # and contains the query substring (case-insensitive, if any).
            if window_start is not None and seg.end < window_start:
                return False
            if window_end is not None and seg.start > window_end:
                return False
            return not needle or needle in seg.text.lower()

        hits = [seg for seg in load_transcript(context.transcript_dir) if _keep(seg)]
        if not hits:
            return ToolResult(tool_use_id="", output="No matching transcript segments found.")

        formatted = []
        for seg in hits:
            sm, ss = divmod(int(seg.start), 60)
            em, es = divmod(int(seg.end), 60)
            formatted.append(f"{seg.id} [{sm:02d}:{ss:02d}-{em:02d}:{es:02d}] {seg.text}")
        return ToolResult(tool_use_id="", output="\n".join(formatted))
|
||||||
|
|
||||||
|
|
||||||
|
class GetSessionInfoTool:
    """Tool: summarize the session — duration, counts, recording files."""

    name = "get_session_info"
    description = "Get recording session information: duration, frame count, segment list."

    def input_schema(self) -> dict:
        """This tool takes no input."""
        return {"type": "object", "properties": {}}

    def run(self, input: dict, context: ToolContext) -> ToolResult:
        """Build a plain-text session summary from indexes and the session dir."""
        frames = load_frames(context.frames_dir)
        segments = load_transcript(context.transcript_dir)

        # Duration comes from the tracker when present; 0.0 otherwise.
        duration = getattr(context.tracker, "duration", 0.0) if context.tracker else 0.0
        mins, secs = divmod(int(duration), 60)
        report = [
            f"Recording duration: {mins:02d}:{secs:02d}",
            f"Frames captured: {len(frames)}",
            f"Transcript segments: {len(segments)}",
        ]

        # List recording segments from session dir
        stream_dir = context.session_dir / "stream"
        if stream_dir.exists():
            recordings = sorted(stream_dir.glob("recording_*.mp4"))
            report.append(f"Recording files: {len(recordings)}")
            report.extend(f"  {rec.name}" for rec in recordings)

        return ToolResult(tool_use_id="", output="\n".join(report))
|
||||||
|
|
||||||
|
|
||||||
|
class CaptureFrameTool:
|
||||||
|
name = "capture_frame"
|
||||||
|
description = "Capture a frame at the current recording position."
|
||||||
|
|
||||||
|
def input_schema(self) -> dict:
|
||||||
|
return {"type": "object", "properties": {}}
|
||||||
|
|
||||||
|
def run(self, input: dict, context: ToolContext) -> ToolResult:
|
||||||
|
mgr = context.stream_mgr
|
||||||
|
if mgr is None:
|
||||||
|
return ToolResult(tool_use_id="", error="No active stream manager")
|
||||||
|
if getattr(mgr, "readonly", False):
|
||||||
|
return ToolResult(tool_use_id="", error="Session is read-only, cannot capture")
|
||||||
|
|
||||||
|
import threading
|
||||||
|
result = {"done": False, "error": None}
|
||||||
|
event = threading.Event()
|
||||||
|
|
||||||
|
def _on_frames(frames):
|
||||||
|
result["done"] = True
|
||||||
|
event.set()
|
||||||
|
|
||||||
|
try:
|
||||||
|
mgr.capture_now(on_new_frames=_on_frames)
|
||||||
|
event.wait(timeout=10)
|
||||||
|
if not result["done"]:
|
||||||
|
return ToolResult(tool_use_id="", error="Capture timed out")
|
||||||
|
return ToolResult(tool_use_id="", output="Frame captured successfully.")
|
||||||
|
except Exception as e:
|
||||||
|
return ToolResult(tool_use_id="", error=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
# All built-in tools
|
||||||
|
BUILTIN_TOOLS = [ReadFrameTool(), SearchTranscriptTool(), GetSessionInfoTool(), CaptureFrameTool()]
|
||||||
@@ -31,3 +31,7 @@ WHISPER_MODEL = "medium" # "small" for speed, "medium" for accuracy
|
|||||||
WHISPER_DEVICE = "cuda" # "cuda" or "cpu"
|
WHISPER_DEVICE = "cuda" # "cuda" or "cpu"
|
||||||
TRANSCRIBE_MIN_CHUNK_S = 5 # minimum seconds of audio before transcribing
|
TRANSCRIBE_MIN_CHUNK_S = 5 # minimum seconds of audio before transcribing
|
||||||
TRANSCRIBE_LINES_PER_GROUP = 3 # whisper segments grouped per transcript ID (1-5)
|
TRANSCRIBE_LINES_PER_GROUP = 3 # whisper segments grouped per transcript ID (1-5)
|
||||||
|
|
||||||
|
# Agent settings
|
||||||
|
AGENT_PERMISSION_MODE = "bypassPermissions" # default|acceptEdits|plan|bypassPermissions|dontAsk
|
||||||
|
AGENT_MAX_TURNS = 5
|
||||||
|
|||||||
@@ -1,20 +1,35 @@
|
|||||||
"""Agent output panel — scrollable text view with markdown rendering."""
|
"""Agent output panel — single TextView, fully selectable/copy-pastable."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
import gi
|
import gi
|
||||||
gi.require_version("Gtk", "4.0")
|
gi.require_version("Gtk", "4.0")
|
||||||
from gi.repository import Gtk
|
from gi.repository import Gtk, GLib, Pango
|
||||||
|
|
||||||
from cht.ui import markdown
|
from cht.ui import markdown
|
||||||
|
from cht.agent.base import (
|
||||||
|
AssistantMessage,
|
||||||
|
ImageBlock,
|
||||||
|
TextBlock,
|
||||||
|
Thread,
|
||||||
|
ToolResult,
|
||||||
|
ToolUse,
|
||||||
|
TranscriptBlock,
|
||||||
|
UserMessage,
|
||||||
|
)
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AgentOutputPanel(Gtk.Frame):
|
class AgentOutputPanel(Gtk.Frame):
|
||||||
"""Scrollable text view that displays agent responses with markdown."""
|
"""Scrollable text view showing the full conversation, copy-pastable."""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
box = Gtk.Box(orientation=Gtk.Orientation.VERTICAL, spacing=0)
|
box = Gtk.Box(orientation=Gtk.Orientation.VERTICAL, spacing=0)
|
||||||
|
|
||||||
|
# Header
|
||||||
header = Gtk.Box(orientation=Gtk.Orientation.HORIZONTAL, spacing=4)
|
header = Gtk.Box(orientation=Gtk.Orientation.HORIZONTAL, spacing=4)
|
||||||
header.set_margin_start(8)
|
header.set_margin_start(8)
|
||||||
header.set_margin_end(8)
|
header.set_margin_end(8)
|
||||||
@@ -33,13 +48,16 @@ class AgentOutputPanel(Gtk.Frame):
|
|||||||
|
|
||||||
box.append(header)
|
box.append(header)
|
||||||
|
|
||||||
|
# Single text view for the whole conversation
|
||||||
self._view = Gtk.TextView()
|
self._view = Gtk.TextView()
|
||||||
self._view.set_editable(False)
|
self._view.set_editable(False)
|
||||||
self._view.set_wrap_mode(Gtk.WrapMode.WORD_CHAR)
|
self._view.set_wrap_mode(Gtk.WrapMode.WORD_CHAR)
|
||||||
self._view.set_cursor_visible(False)
|
self._view.set_cursor_visible(False)
|
||||||
self._view.set_left_margin(8)
|
self._view.set_left_margin(8)
|
||||||
self._view.set_right_margin(8)
|
self._view.set_right_margin(8)
|
||||||
markdown.setup_tags(self._view.get_buffer())
|
self._view.set_top_margin(4)
|
||||||
|
self._view.set_bottom_margin(4)
|
||||||
|
self._setup_tags()
|
||||||
|
|
||||||
scroll = Gtk.ScrolledWindow()
|
scroll = Gtk.ScrolledWindow()
|
||||||
scroll.set_vexpand(True)
|
scroll.set_vexpand(True)
|
||||||
@@ -48,52 +66,170 @@ class AgentOutputPanel(Gtk.Frame):
|
|||||||
|
|
||||||
self.set_child(box)
|
self.set_child(box)
|
||||||
|
|
||||||
# Streaming state
|
# Streaming state: track where the current assistant response starts
|
||||||
self._thinking_replaced = False
|
self._response_marks: dict[str, Gtk.TextMark] = {}
|
||||||
self._response_start_mark = None
|
self._response_accum: dict[str, list[str]] = {}
|
||||||
self._response_accum = []
|
|
||||||
|
def _setup_tags(self):
|
||||||
|
buf = self._view.get_buffer()
|
||||||
|
markdown.setup_tags(buf)
|
||||||
|
buf.create_tag("user-prefix", weight=700, foreground="#7aafff")
|
||||||
|
buf.create_tag("assistant-prefix", weight=700, foreground="#8bc78b")
|
||||||
|
buf.create_tag("tool-prefix", weight=700, foreground="#d4a053")
|
||||||
|
buf.create_tag("tool-output", foreground="#aaaaaa", left_margin=16)
|
||||||
|
buf.create_tag("error", foreground="#ff6b6b")
|
||||||
|
buf.create_tag("ref-chip", foreground="#7aafff", style=Pango.Style.ITALIC)
|
||||||
|
buf.create_tag("status", foreground="#888888")
|
||||||
|
|
||||||
|
# -- Public API --
|
||||||
|
|
||||||
def append(self, text: str) -> None:
|
def append(self, text: str) -> None:
|
||||||
"""Append plain text and auto-scroll."""
|
"""Append a status/info line."""
|
||||||
buf = self._view.get_buffer()
|
buf = self._view.get_buffer()
|
||||||
buf.insert(buf.get_end_iter(), text)
|
end = buf.get_end_iter()
|
||||||
self._view.scroll_to_iter(buf.get_end_iter(), 0, False, 0, 0)
|
mark = buf.create_mark(None, end, True)
|
||||||
|
buf.insert(end, text)
|
||||||
|
start = buf.get_iter_at_mark(mark)
|
||||||
|
buf.apply_tag_by_name("status", start, buf.get_end_iter())
|
||||||
|
buf.delete_mark(mark)
|
||||||
|
self._scroll_to_bottom()
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""Clear all output."""
|
|
||||||
self._view.get_buffer().set_text("")
|
self._view.get_buffer().set_text("")
|
||||||
|
self._response_marks.clear()
|
||||||
|
self._response_accum.clear()
|
||||||
|
|
||||||
def begin_response(self) -> None:
|
def add_user_message(self, text: str, frames: list | None = None,
|
||||||
"""Reset streaming state for a new response (call before sending)."""
|
transcripts: list | None = None) -> None:
|
||||||
self._thinking_replaced = False
|
|
||||||
self._response_start_mark = None
|
|
||||||
self._response_accum = []
|
|
||||||
|
|
||||||
def replace_thinking(self, chunk: str) -> None:
|
|
||||||
"""Replace the '...' placeholder with streamed chunks."""
|
|
||||||
buf = self._view.get_buffer()
|
buf = self._view.get_buffer()
|
||||||
if not self._thinking_replaced:
|
end = buf.get_end_iter()
|
||||||
self._thinking_replaced = True
|
|
||||||
end = buf.get_end_iter()
|
|
||||||
start = end.copy()
|
|
||||||
start.backward_chars(2)
|
|
||||||
buf.delete(start, end)
|
|
||||||
self._response_start_mark = buf.create_mark(
|
|
||||||
None, buf.get_end_iter(), left_gravity=True
|
|
||||||
)
|
|
||||||
self._response_accum.append(chunk)
|
|
||||||
self.append(chunk)
|
|
||||||
|
|
||||||
def finish_response(self, err: str | None) -> None:
|
# Prefix
|
||||||
"""Finalize the response — render markdown or show error."""
|
mark = buf.create_mark(None, end, True)
|
||||||
if err:
|
buf.insert(end, "\n> ")
|
||||||
self.append(f"[Error: {err}]\n")
|
buf.apply_tag_by_name("user-prefix", buf.get_iter_at_mark(mark), buf.get_end_iter())
|
||||||
return
|
buf.delete_mark(mark)
|
||||||
if self._response_start_mark and self._response_accum:
|
|
||||||
buf = self._view.get_buffer()
|
# Text
|
||||||
start = buf.get_iter_at_mark(self._response_start_mark)
|
buf.insert(buf.get_end_iter(), text)
|
||||||
|
|
||||||
|
# Reference chips
|
||||||
|
refs = []
|
||||||
|
if frames:
|
||||||
|
for f in frames:
|
||||||
|
fid = f.frame_id if hasattr(f, "frame_id") else (f.id if hasattr(f, "id") else str(f))
|
||||||
|
refs.append(fid)
|
||||||
|
if transcripts:
|
||||||
|
for t in transcripts:
|
||||||
|
tid = t.transcript_id if hasattr(t, "transcript_id") else (t.id if hasattr(t, "id") else str(t))
|
||||||
|
refs.append(tid)
|
||||||
|
if refs:
|
||||||
end = buf.get_end_iter()
|
end = buf.get_end_iter()
|
||||||
buf.delete(start, end)
|
mark = buf.create_mark(None, end, True)
|
||||||
markdown.render(buf, start, "".join(self._response_accum))
|
buf.insert(end, " [" + ", ".join(refs) + "]")
|
||||||
buf.delete_mark(self._response_start_mark)
|
buf.apply_tag_by_name("ref-chip", buf.get_iter_at_mark(mark), buf.get_end_iter())
|
||||||
self.append("\n")
|
buf.delete_mark(mark)
|
||||||
|
|
||||||
|
buf.insert(buf.get_end_iter(), "\n")
|
||||||
|
self._scroll_to_bottom()
|
||||||
|
|
||||||
|
def begin_assistant_message(self, msg_id: str) -> None:
|
||||||
|
buf = self._view.get_buffer()
|
||||||
|
end = buf.get_end_iter()
|
||||||
|
|
||||||
|
# Mark where this response starts (for markdown re-render on finish)
|
||||||
|
self._response_marks[msg_id] = buf.create_mark(f"resp_{msg_id}", end, True)
|
||||||
|
self._response_accum[msg_id] = []
|
||||||
|
|
||||||
|
def append_to_assistant(self, msg_id: str, text: str) -> None:
|
||||||
|
if msg_id not in self._response_accum:
|
||||||
|
return
|
||||||
|
self._response_accum[msg_id].append(text)
|
||||||
|
buf = self._view.get_buffer()
|
||||||
|
buf.insert(buf.get_end_iter(), text)
|
||||||
|
self._scroll_to_bottom()
|
||||||
|
|
||||||
|
def finish_assistant(self, msg_id: str, full_text: str) -> None:
|
||||||
|
mark = self._response_marks.pop(msg_id, None)
|
||||||
|
self._response_accum.pop(msg_id, None)
|
||||||
|
if not mark:
|
||||||
|
return
|
||||||
|
buf = self._view.get_buffer()
|
||||||
|
start = buf.get_iter_at_mark(mark)
|
||||||
|
end = buf.get_end_iter()
|
||||||
|
buf.delete(start, end)
|
||||||
|
it = buf.get_iter_at_mark(mark)
|
||||||
|
markdown.render(buf, it, full_text)
|
||||||
|
buf.insert(buf.get_end_iter(), "\n")
|
||||||
|
buf.delete_mark(mark)
|
||||||
|
|
||||||
|
def add_tool_call(self, tool_use: ToolUse) -> None:
|
||||||
|
buf = self._view.get_buffer()
|
||||||
|
end = buf.get_end_iter()
|
||||||
|
|
||||||
|
mark = buf.create_mark(None, end, True)
|
||||||
|
buf.insert(end, f" ▶ {tool_use.tool_name}")
|
||||||
|
buf.apply_tag_by_name("tool-prefix", buf.get_iter_at_mark(mark), buf.get_end_iter())
|
||||||
|
buf.delete_mark(mark)
|
||||||
|
|
||||||
|
if tool_use.input:
|
||||||
|
inp = str(tool_use.input)
|
||||||
|
if len(inp) > 80:
|
||||||
|
inp = inp[:77] + "..."
|
||||||
|
end = buf.get_end_iter()
|
||||||
|
mark = buf.create_mark(None, end, True)
|
||||||
|
buf.insert(end, f" {inp}")
|
||||||
|
buf.apply_tag_by_name("tool-output", buf.get_iter_at_mark(mark), buf.get_end_iter())
|
||||||
|
buf.delete_mark(mark)
|
||||||
|
|
||||||
|
buf.insert(buf.get_end_iter(), "\n")
|
||||||
|
self._scroll_to_bottom()
|
||||||
|
|
||||||
|
def update_tool_result(self, tool_use_id: str, result: ToolResult) -> None:
|
||||||
|
buf = self._view.get_buffer()
|
||||||
|
text = result.error or result.output or ""
|
||||||
|
if not text:
|
||||||
|
return
|
||||||
|
end = buf.get_end_iter()
|
||||||
|
mark = buf.create_mark(None, end, True)
|
||||||
|
# Indent tool output
|
||||||
|
indented = "\n".join(f" {line}" for line in text.split("\n"))
|
||||||
|
tag = "error" if result.error else "tool-output"
|
||||||
|
buf.insert(end, indented + "\n")
|
||||||
|
buf.apply_tag_by_name(tag, buf.get_iter_at_mark(mark), buf.get_end_iter())
|
||||||
|
buf.delete_mark(mark)
|
||||||
|
self._scroll_to_bottom()
|
||||||
|
|
||||||
|
def load_thread(self, thread: Thread) -> None:
|
||||||
|
"""Replay a thread to rebuild the conversation view."""
|
||||||
|
self.clear()
|
||||||
|
for msg in thread.messages:
|
||||||
|
if isinstance(msg, UserMessage):
|
||||||
|
text_parts = []
|
||||||
|
frames = []
|
||||||
|
transcripts = []
|
||||||
|
for b in msg.content:
|
||||||
|
if isinstance(b, TextBlock):
|
||||||
|
text_parts.append(b.text)
|
||||||
|
elif isinstance(b, ImageBlock):
|
||||||
|
frames.append(b)
|
||||||
|
elif isinstance(b, TranscriptBlock):
|
||||||
|
transcripts.append(b)
|
||||||
|
self.add_user_message(" ".join(text_parts), frames, transcripts)
|
||||||
|
elif isinstance(msg, AssistantMessage):
|
||||||
|
text = " ".join(b.text for b in msg.content if isinstance(b, TextBlock))
|
||||||
|
buf = self._view.get_buffer()
|
||||||
|
it = buf.get_end_iter()
|
||||||
|
markdown.render(buf, it, text)
|
||||||
|
buf.insert(buf.get_end_iter(), "\n")
|
||||||
|
elif isinstance(msg, ToolUse):
|
||||||
|
self.add_tool_call(msg)
|
||||||
|
elif isinstance(msg, ToolResult):
|
||||||
|
self.update_tool_result(msg.tool_use_id, msg)
|
||||||
|
|
||||||
|
def _scroll_to_bottom(self):
|
||||||
|
def _do():
|
||||||
|
adj = self._view.get_parent().get_vadjustment()
|
||||||
|
adj.set_value(adj.get_upper() - adj.get_page_size())
|
||||||
|
return False
|
||||||
|
GLib.idle_add(_do)
|
||||||
|
|||||||
@@ -28,6 +28,8 @@ from cht.ui.session_dialog import SessionDialog
|
|||||||
from cht.session import load_frame_index, load_segment_manifest, rebuild_manifest, global_time_to_segment
|
from cht.session import load_frame_index, load_segment_manifest, rebuild_manifest, global_time_to_segment
|
||||||
from cht.scrub.manager import ProxyManager
|
from cht.scrub.manager import ProxyManager
|
||||||
from cht.agent.runner import AgentRunner, check_claude_cli
|
from cht.agent.runner import AgentRunner, check_claude_cli
|
||||||
|
from cht.agent.base import TextDelta, ToolCallStart, ToolCallEnd, ToolResult, ToolUse, Done, Error
|
||||||
|
from cht.agent.buffer import StreamingTextBuffer
|
||||||
from cht.telemetry import Telemetry
|
from cht.telemetry import Telemetry
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@@ -245,6 +247,11 @@ class ChtWindow(Adw.ApplicationWindow):
|
|||||||
self._update_scrub_bar_manifest()
|
self._update_scrub_bar_manifest()
|
||||||
self._populate_model_dropdown()
|
self._populate_model_dropdown()
|
||||||
|
|
||||||
|
# Load persisted agent conversation
|
||||||
|
self._agent.load_from_session(mgr.session_dir)
|
||||||
|
if self._agent.thread.messages:
|
||||||
|
self._agent_output.load_thread(self._agent.thread)
|
||||||
|
|
||||||
def _reload_waveform(self, mgr):
|
def _reload_waveform(self, mgr):
|
||||||
"""Recompute waveform from existing segments in background."""
|
"""Recompute waveform from existing segments in background."""
|
||||||
segments = mgr.recording_segments
|
segments = mgr.recording_segments
|
||||||
@@ -289,6 +296,9 @@ class ChtWindow(Adw.ApplicationWindow):
|
|||||||
self._load_existing_frames()
|
self._load_existing_frames()
|
||||||
self._load_existing_transcript()
|
self._load_existing_transcript()
|
||||||
self._reload_waveform(mgr)
|
self._reload_waveform(mgr)
|
||||||
|
self._agent.load_from_session(mgr.session_dir)
|
||||||
|
if self._agent.thread.messages:
|
||||||
|
self._agent_output.load_thread(self._agent.thread)
|
||||||
|
|
||||||
self.set_title(f"{APP_NAME} — {mgr.session_id}")
|
self.set_title(f"{APP_NAME} — {mgr.session_id}")
|
||||||
log.info("Waiting for sender...")
|
log.info("Waiting for sender...")
|
||||||
@@ -453,6 +463,11 @@ class ChtWindow(Adw.ApplicationWindow):
|
|||||||
mgr = self._lifecycle.stream_mgr
|
mgr = self._lifecycle.stream_mgr
|
||||||
last_session_id = mgr.session_id if mgr and not mgr.readonly else None
|
last_session_id = mgr.session_id if mgr and not mgr.readonly else None
|
||||||
|
|
||||||
|
# Save agent thread before stopping
|
||||||
|
if mgr and self._agent.thread.messages:
|
||||||
|
from cht.agent.base import save_thread
|
||||||
|
save_thread(self._agent.thread, mgr.session_dir)
|
||||||
|
|
||||||
if self._telemetry:
|
if self._telemetry:
|
||||||
self._telemetry.close()
|
self._telemetry.close()
|
||||||
self._telemetry = None
|
self._telemetry = None
|
||||||
@@ -604,15 +619,53 @@ class ChtWindow(Adw.ApplicationWindow):
|
|||||||
self._agent_output.append("No active session.\n")
|
self._agent_output.append("No active session.\n")
|
||||||
return
|
return
|
||||||
|
|
||||||
self._agent_output.append(f"\n> {text}\n…\n")
|
# Show user message in UI
|
||||||
|
from cht.agent.runner import _parse_mentions, _parse_transcript_mentions
|
||||||
|
from cht.agent.tools import load_frames, load_transcript
|
||||||
|
mgr = self._lifecycle.stream_mgr
|
||||||
|
frames = load_frames(mgr.frames_dir)
|
||||||
|
mentioned_frames = _parse_mentions(text, frames)
|
||||||
|
transcript = load_transcript(mgr.transcript_dir)
|
||||||
|
mentioned_transcripts = _parse_transcript_mentions(text, transcript)
|
||||||
|
self._agent_output.add_user_message(text, mentioned_frames, mentioned_transcripts)
|
||||||
|
|
||||||
|
# Prepare streaming
|
||||||
|
from cht.agent.base import _msg_id
|
||||||
|
msg_id = _msg_id()
|
||||||
|
self._agent_output.begin_assistant_message(msg_id)
|
||||||
|
|
||||||
|
full_text_parts = []
|
||||||
|
buffer = StreamingTextBuffer(
|
||||||
|
on_reveal=lambda chunk: self._agent_output.append_to_assistant(msg_id, chunk)
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_event(event):
|
||||||
|
if isinstance(event, TextDelta):
|
||||||
|
full_text_parts.append(event.text)
|
||||||
|
GLib.idle_add(buffer.push, event.text)
|
||||||
|
elif isinstance(event, ToolCallStart):
|
||||||
|
tu = ToolUse(id=event.id, tool_name=event.name, input=event.input, status="running")
|
||||||
|
GLib.idle_add(self._agent_output.add_tool_call, tu)
|
||||||
|
elif isinstance(event, ToolResult):
|
||||||
|
GLib.idle_add(self._agent_output.update_tool_result, event.tool_use_id, event)
|
||||||
|
elif isinstance(event, Error):
|
||||||
|
GLib.idle_add(self._agent_output.append, f"[Error: {event.message}]\n")
|
||||||
|
|
||||||
|
def on_done(err):
|
||||||
|
def _finish():
|
||||||
|
buffer.flush()
|
||||||
|
if err:
|
||||||
|
self._agent_output.append(f"[Error: {err}]\n")
|
||||||
|
else:
|
||||||
|
self._agent_output.finish_assistant(msg_id, "".join(full_text_parts))
|
||||||
|
GLib.idle_add(_finish)
|
||||||
|
|
||||||
self._agent_output.begin_response()
|
|
||||||
self._agent.send(
|
self._agent.send(
|
||||||
message=text,
|
message=text,
|
||||||
stream_mgr=self._lifecycle.stream_mgr,
|
stream_mgr=mgr,
|
||||||
tracker=self._lifecycle.tracker,
|
tracker=self._lifecycle.tracker,
|
||||||
on_chunk=lambda chunk: GLib.idle_add(self._agent_output.replace_thinking, chunk),
|
on_event=on_event,
|
||||||
on_done=lambda err: GLib.idle_add(self._agent_output.finish_response, err),
|
on_done=on_done,
|
||||||
)
|
)
|
||||||
|
|
||||||
# -- Settings callbacks --
|
# -- Settings callbacks --
|
||||||
|
|||||||
Reference in New Issue
Block a user