328 lines
11 KiB
Python
328 lines
11 KiB
Python
"""
|
|
Agent runner — resolves connection, parses @-mentions, dispatches messages,
|
|
executes tool loop.
|
|
|
|
Connection selection (in order):
|
|
1. GROQ_API_KEY → OpenAI-compat / Groq
|
|
2. OPENAI_API_KEY → OpenAI-compat / OpenAI
|
|
3. (default) → Claude Code SDK (uses CC subscription)
|
|
"""
|
|
|
|
import logging
|
|
import os
|
|
import re
|
|
from pathlib import Path
|
|
from threading import Thread
|
|
from typing import Callable
|
|
|
|
from cht.agent.base import (
|
|
AssistantMessage,
|
|
FrameRef,
|
|
ImageBlock,
|
|
Message,
|
|
StreamEvent,
|
|
TextBlock,
|
|
TextDelta,
|
|
ToolCallEnd,
|
|
ToolCallStart,
|
|
ToolContext,
|
|
ToolResult,
|
|
ToolUse,
|
|
TranscriptBlock,
|
|
TranscriptRef,
|
|
UserMessage,
|
|
Done,
|
|
Error,
|
|
Thread as AgentThread,
|
|
save_thread,
|
|
load_thread,
|
|
)
|
|
from cht.agent.tools import load_frames, load_transcript, BUILTIN_TOOLS
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Predefined actions: UI label -> verb prefix (the UI appends the frame ref).
ACTIONS: dict[str, str] = {"Describe": "describe", "Answer": "answer"}

# Maximum number of iterations of the tool loop in AgentRunner.send.
MAX_TOOL_TURNS = 10
|
|
|
|
|
|
def check_claude_cli() -> str | None:
|
|
"""Returns None if OK, or an error string if CLI is missing/unauthenticated."""
|
|
import shutil, subprocess
|
|
if not shutil.which("claude"):
|
|
return "Claude Code CLI not found. Install from https://claude.ai/code"
|
|
try:
|
|
result = subprocess.run(
|
|
["claude", "--version"],
|
|
capture_output=True, timeout=5
|
|
)
|
|
if result.returncode != 0:
|
|
return "Claude Code CLI error. Run `claude` in a terminal to check."
|
|
except Exception as e:
|
|
return f"Claude Code CLI check failed: {e}"
|
|
return None
|
|
|
|
|
|
def _resolve_connection():
    """Create the backend connection selected by the environment.

    If either ``GROQ_API_KEY`` or ``OPENAI_API_KEY`` is set to a non-empty
    value, an ``OpenAIConnection`` is returned; this function does not pass
    the key along or choose between the two providers itself. With neither
    key set, a ``ClaudeSDKConnection`` is returned.

    The connection modules are imported lazily so only the chosen backend is
    loaded.
    """
    openai_compat_keys = ("GROQ_API_KEY", "OPENAI_API_KEY")
    if not any(os.environ.get(key) for key in openai_compat_keys):
        from cht.agent.claude_sdk_connection import ClaudeSDKConnection
        return ClaudeSDKConnection()

    from cht.agent.openai_connection import OpenAIConnection
    return OpenAIConnection()
|
|
|
|
|
|
def _expand_ref_nums(spec: str) -> list[int]:
    """Expand a comma-separated list of numbers and ranges.

    Each comma-separated part is either a single integer (``"4"``) or an
    inclusive range (``"7-9"``). Range bounds may be given in either order,
    so ``"9-7"`` expands the same as ``"7-9"``. Parts that cannot be parsed
    (``"-5"``, ``"5-"``, ``"1-2-3"``, ``"x"``) are skipped.

    Args:
        spec: Text such as ``"3,7-9"``.

    Returns:
        The referenced numbers, de-duplicated and sorted ascending.
    """
    nums: set[int] = set()
    for raw in spec.split(","):
        part = raw.strip()
        if not part:
            continue
        # Split on the first hyphen only; anything after it is the upper
        # bound, so a second hyphen makes int() fail and the part is skipped.
        lo_text, dash, hi_text = part.partition("-")
        try:
            lo = int(lo_text)
            hi = int(hi_text) if dash else lo
        except ValueError:
            continue
        if hi < 0:
            # "5--3" parses as an upper bound of -3; treat it as malformed
            # rather than expanding into negative numbers.
            continue
        if lo > hi:
            lo, hi = hi, lo
        nums.update(range(lo, hi + 1))
    return sorted(nums)
|
|
|
|
|
|
def _parse_mentions(message: str, frames: list[FrameRef]) -> list[FrameRef]:
    """Resolve ``@F…`` mentions in *message* to frames.

    A mention is ``@F`` or ``@f`` followed by digits, commas and hyphens
    (for example ``@F3`` or ``@f1-4,9``). Each number is turned into an id of
    the form ``F0003`` and matched against ``frame.id``.

    Args:
        message: The user's text.
        frames: Candidate frames; only their ``id`` attribute is read here.

    Returns:
        The matching frames in first-mention order, each at most once.
        Numbers with no matching frame are ignored.
    """
    # Collect the distinct ids in mention order first, so no index is built
    # for messages without any mention.
    wanted: list[str] = []
    seen: set[str] = set()
    for match in re.finditer(r"@[Ff]([\d,\-]+)", message):
        for num in _expand_ref_nums(match.group(1)):
            fid = f"F{num:04d}"
            if fid not in seen:
                seen.add(fid)
                wanted.append(fid)
    if not wanted:
        return []

    # Index once instead of rescanning the list per mention. setdefault
    # keeps the first frame for a duplicated id, as a front-to-back scan would.
    frames_by_id: dict = {}
    for frame in frames:
        frames_by_id.setdefault(frame.id, frame)

    mentioned = []
    for fid in wanted:
        frame = frames_by_id.get(fid)
        if frame:
            mentioned.append(frame)
    return mentioned
|
|
|
|
|
|
def _parse_transcript_mentions(message: str, segments: list[TranscriptRef]) -> list[TranscriptRef]:
    """Resolve ``@T…`` mentions in *message* to transcript segments.

    A mention is ``@T`` or ``@t`` followed by digits, commas and hyphens
    (for example ``@T12`` or ``@t1-4,9``). Each number is turned into an id
    of the form ``T0012`` and matched against ``segment.id``.

    Args:
        message: The user's text.
        segments: Candidate segments; only their ``id`` attribute is read here.

    Returns:
        The matching segments in first-mention order, each at most once.
        Numbers with no matching segment are ignored.
    """
    # Collect the distinct ids in mention order first, so no index is built
    # for messages without any mention.
    wanted: list[str] = []
    seen: set[str] = set()
    for match in re.finditer(r"@[Tt]([\d,\-]+)", message):
        for num in _expand_ref_nums(match.group(1)):
            tid = f"T{num:04d}"
            if tid not in seen:
                seen.add(tid)
                wanted.append(tid)
    if not wanted:
        return []

    # Index once instead of rescanning the list per mention. setdefault
    # keeps the first segment for a duplicated id, as a front-to-back scan would.
    segments_by_id: dict = {}
    for segment in segments:
        segments_by_id.setdefault(segment.id, segment)

    mentioned = []
    for tid in wanted:
        seg = segments_by_id.get(tid)
        if seg:
            mentioned.append(seg)
    return mentioned
|
|
|
|
|
|
def _build_user_message(text: str, mentioned_frames: list[FrameRef],
                        mentioned_transcripts: list[TranscriptRef]) -> UserMessage:
    """Assemble a UserMessage from the raw text and the resolved @-mentions.

    The content list starts with a TextBlock holding *text*, followed by one
    ImageBlock per mentioned frame and then one TranscriptBlock per mentioned
    transcript segment, each in the order given.
    """
    blocks: list = [TextBlock(text=text)]
    blocks.extend(
        ImageBlock(frame_id=frame.id, path=frame.path, timestamp=frame.timestamp)
        for frame in mentioned_frames
    )
    blocks.extend(
        TranscriptBlock(
            transcript_id=segment.id,
            start=segment.start,
            end=segment.end,
            text=segment.text,
        )
        for segment in mentioned_transcripts
    )
    return UserMessage(content=blocks)
|
|
|
|
|
|
class AgentRunner:
    """Runs agent queries in a background thread with a tool execution loop.

    The backend connection is created lazily on first use. Conversation
    state is kept in ``thread``. ``include_history`` selects what is sent to
    the connection: the whole thread when true, otherwise only the messages
    of the exchange currently being processed.
    """

    def __init__(self):
        # Created on demand by _get_connection().
        self._connection = None
        self._thread: AgentThread = AgentThread()
        self.include_history = False

    @property
    def thread(self) -> AgentThread:
        """The conversation thread messages are currently appended to."""
        return self._thread

    def _get_connection(self):
        """Return the cached connection, resolving and logging it on first use."""
        if self._connection is None:
            self._connection = _resolve_connection()
            log.info("Agent connection: %s", self._connection.name)
        return self._connection

    @property
    def provider_name(self) -> str:
        """The connection's ``name``, or ``"unknown"`` if it cannot be obtained."""
        try:
            return self._get_connection().name
        except Exception:
            return "unknown"

    @property
    def available_models(self) -> list[str]:
        """Models reported by the connection, or ``[]`` if the query fails."""
        try:
            return self._get_connection().available_models()
        except Exception:
            return []

    @property
    def model(self) -> str:
        """The connection's current model, or ``""`` if it cannot be obtained."""
        try:
            return self._get_connection().get_model()
        except Exception:
            return ""

    @model.setter
    def model(self, value: str):
        # Unlike the getter, errors from the connection propagate to the caller.
        self._get_connection().set_model(value)

    @property
    def permission_mode(self) -> str:
        """The connection's ``_permission_mode``, or ``"default"`` if it has none."""
        conn = self._get_connection()
        return getattr(conn, "_permission_mode", "default")

    @permission_mode.setter
    def permission_mode(self, value: str):
        conn = self._get_connection()
        # Only connections that already carry the attribute are updated.
        if hasattr(conn, "_permission_mode"):
            conn._permission_mode = value
        import cht.config
        cht.config.AGENT_PERMISSION_MODE = value
        log.info("Permission mode set to %s", value)

    def clear_history(self):
        """Discard the current thread and start a fresh one."""
        self._thread = AgentThread()

    def set_thread(self, thread: AgentThread):
        """Replace the current thread with *thread*."""
        self._thread = thread

    def load_from_session(self, session_dir: Path):
        """Load the thread stored for *session_dir*, or start a fresh one.

        A fresh thread is used whenever ``load_thread`` returns a falsy value.
        """
        loaded = load_thread(session_dir)
        if loaded:
            self._thread = loaded
            log.info("Loaded thread %s with %d messages", loaded.id, len(loaded.messages))
        else:
            self._thread = AgentThread()

    def _messages_for_turn(self, turn_start: int) -> list:
        """Select the messages to send to the connection for one turn.

        Args:
            turn_start: Index in the thread of the user message that opened
                the exchange being processed.

        Returns:
            A copy of the whole thread when ``include_history`` is set;
            otherwise a copy of the messages from *turn_start* onward. On
            the first turn that is just the user message. On later tool-loop
            turns it also includes the tool calls, tool results and text
            produced so far, so a tool result is never sent without the
            request it answers.
        """
        if self.include_history:
            return list(self._thread.messages)
        return list(self._thread.messages[turn_start:])

    def _find_tool_use(self, tool_use_id):
        """Return the most recent ToolUse in the thread with this id, or None."""
        for msg in reversed(self._thread.messages):
            if isinstance(msg, ToolUse) and msg.id == tool_use_id:
                return msg
        return None

    @staticmethod
    def _run_tool(tool_use_msg, tools_by_name, tool_ctx):
        """Run the tool named by *tool_use_msg* and record its final status.

        Returns the tool's result with ``tool_use_id`` set, or an error
        ToolResult when no tool of that name is registered. The status of
        *tool_use_msg* is set to ``"done"`` or ``"error"`` accordingly.
        """
        tool_impl = tools_by_name.get(tool_use_msg.tool_name)
        if tool_impl is None:
            tool_use_msg.status = "error"
            return ToolResult(
                tool_use_id=tool_use_msg.id,
                error=f"Unknown tool: {tool_use_msg.tool_name}",
            )
        result = tool_impl.run(tool_use_msg.input, tool_ctx)
        result.tool_use_id = tool_use_msg.id
        tool_use_msg.status = "error" if result.error else "done"
        return result

    def send(
        self,
        message: str,
        stream_mgr,
        tracker,
        on_event: Callable[[StreamEvent], None],
        on_done: Callable[[str | None], None],
    ):
        """Dispatch message in a background thread with tool execution loop.

        Args:
            message: User text; ``@F…`` / ``@T…`` mentions are resolved against
                the frames and transcript loaded from *stream_mgr*.
            stream_mgr: Supplies ``frames_dir``, ``transcript_dir`` and
                ``session_dir``; also handed to tools via the ToolContext.
            tracker: Handed to tools via the ToolContext.
            on_event: Called for each stream event, and with each ToolResult
                after a tool has run (from the background thread).
            on_done: Called once with ``None`` on success or an error string
                on failure (from the background thread).

        The method returns immediately after starting a daemon thread. The
        thread is saved to the session directory only on success.
        """
        def _run():
            try:
                connection = self._get_connection()
                frames = load_frames(stream_mgr.frames_dir)
                mentioned_frames = _parse_mentions(message, frames)
                transcript = load_transcript(stream_mgr.transcript_dir)
                mentioned_transcripts = _parse_transcript_mentions(message, transcript)

                # Build and append user message
                user_msg = _build_user_message(message, mentioned_frames, mentioned_transcripts)
                self._thread.messages.append(user_msg)
                # Where this exchange starts, for requests without history.
                turn_start = len(self._thread.messages) - 1

                tool_ctx = ToolContext(
                    session_dir=stream_mgr.session_dir,
                    frames_dir=stream_mgr.frames_dir,
                    transcript_dir=stream_mgr.transcript_dir,
                    stream_mgr=stream_mgr,
                    tracker=tracker,
                )
                tools_by_name = {t.name: t for t in BUILTIN_TOOLS}

                # Tool execution loop: one connection round-trip per turn,
                # repeated while the model stops to wait for tool results.
                full_text_parts = []
                for _turn in range(MAX_TOOL_TURNS):
                    msgs = self._messages_for_turn(turn_start)
                    stop_reason = None

                    for event in connection.prompt(msgs, BUILTIN_TOOLS):
                        on_event(event)

                        if isinstance(event, TextDelta):
                            full_text_parts.append(event.text)

                        elif isinstance(event, ToolCallStart):
                            self._thread.messages.append(ToolUse(
                                id=event.id,
                                tool_name=event.name,
                                input=event.input,
                                status="running",
                            ))

                        elif isinstance(event, ToolCallEnd):
                            # The tool runs when its call is complete; the
                            # matching ToolUse was recorded at ToolCallStart.
                            tool_use_msg = self._find_tool_use(event.id)
                            if tool_use_msg is not None:
                                result = self._run_tool(tool_use_msg, tools_by_name, tool_ctx)
                                self._thread.messages.append(result)
                                on_event(result)

                        elif isinstance(event, Done):
                            stop_reason = event.stop_reason
                            break

                        elif isinstance(event, Error):
                            on_done(event.message)
                            return

                    # Build assistant message from accumulated text
                    if full_text_parts:
                        self._thread.messages.append(AssistantMessage(
                            content=[TextBlock(text="".join(full_text_parts))],
                            model=connection.get_model(),
                        ))

                    if stop_reason != "tool_use":
                        break

                    # Reset text for next turn (tool loop continues)
                    full_text_parts = []
                else:
                    # Every turn ended in "tool_use"; stop rather than loop forever.
                    log.warning(
                        "Tool loop stopped after %d turns with tool calls still pending",
                        MAX_TOOL_TURNS,
                    )

                save_thread(self._thread, stream_mgr.session_dir)
                on_done(None)

            except Exception as e:
                # log.exception keeps the traceback, which log.error dropped.
                log.exception("Agent error: %s", e)
                on_done(str(e))

        Thread(target=_run, daemon=True, name="agent_runner").start()
|