Files
mitus/cht/agent/runner.py
2026-04-09 14:58:15 -03:00

328 lines
11 KiB
Python

"""
Agent runner — resolves connection, parses @-mentions, dispatches messages,
executes tool loop.
Connection selection (in order):
1. GROQ_API_KEY → OpenAI-compat / Groq
2. OPENAI_API_KEY → OpenAI-compat / OpenAI
3. (default) → Claude Code SDK (uses CC subscription)
"""
import logging
import os
import re
from pathlib import Path
from threading import Thread
from typing import Callable
from cht.agent.base import (
AssistantMessage,
FrameRef,
ImageBlock,
Message,
StreamEvent,
TextBlock,
TextDelta,
ToolCallEnd,
ToolCallStart,
ToolContext,
ToolResult,
ToolUse,
TranscriptBlock,
TranscriptRef,
UserMessage,
Done,
Error,
Thread as AgentThread,
save_thread,
load_thread,
)
from cht.agent.tools import load_frames, load_transcript, BUILTIN_TOOLS
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Predefined actions — label → verb prefix (frame ref appended by UI)
ACTIONS: dict[str, str] = {
    "Describe": "describe",
    "Answer": "answer",
}
# Maximum number of prompt rounds AgentRunner.send() runs for a single user
# message before its tool loop stops.
MAX_TOOL_TURNS = 10
def check_claude_cli() -> str | None:
"""Returns None if OK, or an error string if CLI is missing/unauthenticated."""
import shutil, subprocess
if not shutil.which("claude"):
return "Claude Code CLI not found. Install from https://claude.ai/code"
try:
result = subprocess.run(
["claude", "--version"],
capture_output=True, timeout=5
)
if result.returncode != 0:
return "Claude Code CLI error. Run `claude` in a terminal to check."
except Exception as e:
return f"Claude Code CLI check failed: {e}"
return None
def _resolve_connection():
    """Pick the agent connection based on environment variables.

    Returns an ``OpenAIConnection`` when ``GROQ_API_KEY`` or ``OPENAI_API_KEY``
    is set to a non-empty value, and a ``ClaudeSDKConnection`` otherwise.
    The backend modules are imported inside the function, so only the
    selected one is loaded.
    """
    api_key_vars = ("GROQ_API_KEY", "OPENAI_API_KEY")
    if not any(os.environ.get(var) for var in api_key_vars):
        from cht.agent.claude_sdk_connection import ClaudeSDKConnection
        return ClaudeSDKConnection()

    from cht.agent.openai_connection import OpenAIConnection
    return OpenAIConnection()
def _expand_ref_nums(spec: str) -> list[int]:
nums = set()
for part in spec.split(","):
part = part.strip()
if "-" in part:
a, b = part.split("-", 1)
try:
nums.update(range(int(a), int(b) + 1))
except ValueError:
pass
elif part:
try:
nums.add(int(part))
except ValueError:
pass
return sorted(nums)
def _parse_mentions(message: str, frames: list[FrameRef]) -> list[FrameRef]:
    """Collect the frames referenced by ``@F...`` mentions in *message*.

    A mention is ``@F`` or ``@f`` followed by digits, commas and dashes
    (for example ``@F3`` or ``@f1,4-6``). Each number is formatted as a
    zero-padded id (``F0003``) and matched against ``frame.id``. Frames are
    returned in mention order, each at most once; numbers with no matching
    frame are ignored.
    """
    resolved: list[FrameRef] = []
    taken: set[str] = set()
    for hit in re.finditer(r"@[Ff]([\d,\-]+)", message):
        for number in _expand_ref_nums(hit.group(1)):
            frame_id = f"F{number:04d}"
            if frame_id in taken:
                continue
            found = next((fr for fr in frames if fr.id == frame_id), None)
            if not found:
                continue
            resolved.append(found)
            taken.add(frame_id)
    return resolved
def _parse_transcript_mentions(message: str, segments: list[TranscriptRef]) -> list[TranscriptRef]:
    """Collect the transcript segments referenced by ``@T...`` mentions.

    A mention is ``@T`` or ``@t`` followed by digits, commas and dashes
    (for example ``@T12`` or ``@t2-4,9``). Each number is formatted as a
    zero-padded id (``T0012``) and matched against ``segment.id``. Segments
    are returned in mention order, each at most once; numbers with no
    matching segment are ignored.
    """
    resolved: list[TranscriptRef] = []
    taken: set[str] = set()
    for hit in re.finditer(r"@[Tt]([\d,\-]+)", message):
        for number in _expand_ref_nums(hit.group(1)):
            segment_id = f"T{number:04d}"
            if segment_id in taken:
                continue
            found = next((sg for sg in segments if sg.id == segment_id), None)
            if not found:
                continue
            resolved.append(found)
            taken.add(segment_id)
    return resolved
def _build_user_message(text: str, mentioned_frames: list[FrameRef],
                        mentioned_transcripts: list[TranscriptRef]) -> UserMessage:
    """Assemble a UserMessage from the raw text and its resolved @-mentions.

    The content starts with one TextBlock holding *text*, followed by one
    ImageBlock per mentioned frame and then one TranscriptBlock per mentioned
    transcript segment, each in the order given.
    """
    blocks: list = [TextBlock(text=text)]
    blocks.extend(
        ImageBlock(frame_id=frame.id, path=frame.path, timestamp=frame.timestamp)
        for frame in mentioned_frames
    )
    blocks.extend(
        TranscriptBlock(
            transcript_id=segment.id,
            start=segment.start,
            end=segment.end,
            text=segment.text,
        )
        for segment in mentioned_transcripts
    )
    return UserMessage(content=blocks)
class AgentRunner:
    """Runs agent queries in a background thread with a tool execution loop.

    The runner holds one conversation thread and a connection that is
    resolved on first use. ``include_history`` selects whether each prompt
    carries the whole thread or only its newest message.
    """

    def __init__(self):
        self._connection = None  # resolved lazily by _get_connection()
        self._thread: AgentThread = AgentThread()
        self.include_history = False

    @property
    def thread(self) -> AgentThread:
        """The conversation thread currently held by this runner."""
        return self._thread

    def _get_connection(self):
        """Return the cached connection, resolving and logging it on first use."""
        if self._connection is None:
            self._connection = _resolve_connection()
            log.info("Agent connection: %s", self._connection.name)
        return self._connection

    @property
    def provider_name(self) -> str:
        """Connection name, or ``"unknown"`` when it cannot be obtained."""
        try:
            return self._get_connection().name
        except Exception:
            return "unknown"

    @property
    def available_models(self) -> list[str]:
        """Models reported by the connection, or ``[]`` when that fails."""
        try:
            return self._get_connection().available_models()
        except Exception:
            return []

    @property
    def model(self) -> str:
        """Model reported by the connection, or ``""`` when that fails."""
        try:
            return self._get_connection().get_model()
        except Exception:
            return ""

    @model.setter
    def model(self, value: str):
        # Unlike the getter, errors from the connection propagate here.
        self._get_connection().set_model(value)

    @property
    def permission_mode(self) -> str:
        """The connection's ``_permission_mode``, or ``"default"`` if it has none."""
        conn = self._get_connection()
        return getattr(conn, "_permission_mode", "default")

    @permission_mode.setter
    def permission_mode(self, value: str):
        conn = self._get_connection()
        # Only connections that already carry the attribute are updated.
        if hasattr(conn, "_permission_mode"):
            conn._permission_mode = value
        # Mirror the value into the module-level config setting.
        import cht.config
        cht.config.AGENT_PERMISSION_MODE = value
        log.info("Permission mode set to %s", value)

    def clear_history(self):
        """Replace the current thread with a new, empty one."""
        self._thread = AgentThread()

    def set_thread(self, thread: AgentThread):
        """Make *thread* the runner's current conversation thread."""
        self._thread = thread

    def load_from_session(self, session_dir: Path):
        """Load the thread stored for *session_dir*, or start a new one.

        A falsy result from ``load_thread`` leads to a fresh ``AgentThread``.
        """
        loaded = load_thread(session_dir)
        if loaded:
            self._thread = loaded
            log.info("Loaded thread %s with %d messages", loaded.id, len(loaded.messages))
        else:
            self._thread = AgentThread()

    def send(
        self,
        message: str,
        stream_mgr,
        tracker,
        on_event: Callable[[StreamEvent], None],
        on_done: Callable[[str | None], None],
    ):
        """Dispatch message in a background thread with tool execution loop.

        on_event(StreamEvent) — called for each stream event (from bg thread)
        on_done(error_or_None) — called exactly once when complete (from bg thread)
        """
        def _run():
            try:
                error = self._run_tool_loop(message, stream_mgr, tracker, on_event)
            except Exception as e:
                # log.exception keeps the traceback alongside the message.
                log.exception("Agent error: %s", e)
                error = str(e)
            # Called outside the try so a raising callback cannot trigger a
            # second on_done invocation from the except branch.
            on_done(error)

        Thread(target=_run, daemon=True, name="agent_runner").start()

    def _run_tool_loop(self, message, stream_mgr, tracker, on_event) -> str | None:
        """Append the user message, then prompt until no more tool turns are requested.

        Returns the message of an ``Error`` event if one is received (the
        thread is not saved in that case), otherwise ``None`` after saving the
        thread to ``stream_mgr.session_dir``.
        """
        connection = self._get_connection()
        frames = load_frames(stream_mgr.frames_dir)
        mentioned_frames = _parse_mentions(message, frames)
        transcript = load_transcript(stream_mgr.transcript_dir)
        mentioned_transcripts = _parse_transcript_mentions(message, transcript)

        # Build and append user message
        user_msg = _build_user_message(message, mentioned_frames, mentioned_transcripts)
        self._thread.messages.append(user_msg)

        tool_ctx = ToolContext(
            session_dir=stream_mgr.session_dir,
            frames_dir=stream_mgr.frames_dir,
            transcript_dir=stream_mgr.transcript_dir,
            stream_mgr=stream_mgr,
            tracker=tracker,
        )
        # Tool registry, keyed by tool name.
        tools_by_name = {t.name: t for t in BUILTIN_TOOLS}

        for _turn in range(MAX_TOOL_TURNS):
            # Messages to send — full thread or just last message.
            # NOTE(review): without history, turns after the first send the
            # newest ToolResult/AssistantMessage rather than the user message;
            # confirm the connection is expected to keep its own context.
            if self.include_history:
                msgs = list(self._thread.messages)
            else:
                msgs = [self._thread.messages[-1]]

            stop_reason = None
            text_parts = []  # assistant text accumulated for this turn only
            for event in connection.prompt(msgs, BUILTIN_TOOLS):
                on_event(event)
                if isinstance(event, TextDelta):
                    text_parts.append(event.text)
                elif isinstance(event, ToolCallStart):
                    self._thread.messages.append(ToolUse(
                        id=event.id,
                        tool_name=event.name,
                        input=event.input,
                        status="running",
                    ))
                elif isinstance(event, ToolCallEnd):
                    result = self._execute_tool_call(event.id, tools_by_name, tool_ctx)
                    if result is not None:
                        self._thread.messages.append(result)
                        on_event(result)
                elif isinstance(event, Done):
                    stop_reason = event.stop_reason
                    break
                elif isinstance(event, Error):
                    return event.message

            # Build assistant message from accumulated text
            if text_parts:
                self._thread.messages.append(AssistantMessage(
                    content=[TextBlock(text="".join(text_parts))],
                    model=connection.get_model(),
                ))
            if stop_reason != "tool_use":
                break
        else:
            # Every allowed turn ended with another tool request.
            log.warning("Tool loop stopped after %d turns", MAX_TOOL_TURNS)

        save_thread(self._thread, stream_mgr.session_dir)
        return None

    def _execute_tool_call(self, call_id, tools_by_name, tool_ctx) -> ToolResult | None:
        """Run the tool for the most recent ``ToolUse`` whose id is *call_id*.

        Returns the tool's result, or an error ``ToolResult`` when the tool
        name is not registered, or ``None`` when the thread holds no matching
        ``ToolUse``. The ``ToolUse`` status becomes ``"done"`` or ``"error"``.
        """
        tool_use_msg = next(
            (
                m for m in reversed(self._thread.messages)
                if isinstance(m, ToolUse) and m.id == call_id
            ),
            None,
        )
        if tool_use_msg is None:
            log.warning("ToolCallEnd without a matching ToolUse (id=%s)", call_id)
            return None

        # The registry is keyed by tool name; the call id only locates the ToolUse.
        tool_impl = tools_by_name.get(tool_use_msg.tool_name)
        if tool_impl is None:
            tool_use_msg.status = "error"
            return ToolResult(
                tool_use_id=tool_use_msg.id,
                error=f"Unknown tool: {tool_use_msg.tool_name}",
            )

        result = tool_impl.run(tool_use_msg.input, tool_ctx)
        result.tool_use_id = tool_use_msg.id
        tool_use_msg.status = "error" if result.error else "done"
        return result