better agent

This commit is contained in:
2026-04-09 14:46:29 -03:00
parent ade92069c0
commit 64ecdca71e
11 changed files with 1424 additions and 434 deletions

View File

@@ -1,13 +1,13 @@
"""
Agent runner — resolves provider, parses @-mentions, dispatches messages.
Agent runner — resolves connection, parses @-mentions, dispatches messages,
executes tool loop.
Provider selection (in order):
Connection selection (in order):
1. GROQ_API_KEY → OpenAI-compat / Groq
2. OPENAI_API_KEY → OpenAI-compat / OpenAI
3. (default) → Claude Code SDK (uses CC subscription)
"""
import json
import logging
import os
import re
@@ -15,7 +15,29 @@ from pathlib import Path
from threading import Thread
from typing import Callable
from cht.agent.base import AgentProvider, FrameRef, TranscriptRef, SessionContext
from cht.agent.base import (
AssistantMessage,
FrameRef,
ImageBlock,
Message,
StreamEvent,
TextBlock,
TextDelta,
ToolCallEnd,
ToolCallStart,
ToolContext,
ToolResult,
ToolUse,
TranscriptBlock,
TranscriptRef,
UserMessage,
Done,
Error,
Thread as AgentThread,
save_thread,
load_thread,
)
from cht.agent.tools import load_frames, load_transcript, BUILTIN_TOOLS
log = logging.getLogger(__name__)
@@ -25,6 +47,8 @@ ACTIONS: dict[str, str] = {
"Answer": "answer",
}
MAX_TOOL_TURNS = 10
def check_claude_cli() -> str | None:
"""Returns None if OK, or an error string if CLI is missing/unauthenticated."""
@@ -43,16 +67,15 @@ def check_claude_cli() -> str | None:
return None
def _resolve_provider() -> AgentProvider:
def _resolve_connection():
if os.environ.get("GROQ_API_KEY") or os.environ.get("OPENAI_API_KEY"):
from cht.agent.openai_compat_provider import OpenAICompatProvider
return OpenAICompatProvider()
from cht.agent.claude_sdk_provider import ClaudeSDKProvider
return ClaudeSDKProvider()
from cht.agent.openai_connection import OpenAIConnection
return OpenAIConnection()
from cht.agent.claude_sdk_connection import ClaudeSDKConnection
return ClaudeSDKConnection()
def _expand_ref_nums(spec: str) -> list[int]:
"""Expand a ref spec like '2-6' or '2,4,6' or '2-4,6,8-10' into sorted ints."""
nums = set()
for part in spec.split(","):
part = part.strip()
@@ -71,7 +94,6 @@ def _expand_ref_nums(spec: str) -> list[int]:
def _parse_mentions(message: str, frames: list[FrameRef]) -> list[FrameRef]:
"""Extract @F references. Accepts @F1, @F2-6, @F2,4,6, @F2-4,6,8-10."""
mentioned = []
seen = set()
for match in re.finditer(r"@[Ff]([\d,\-]+)", message):
@@ -85,49 +107,7 @@ def _parse_mentions(message: str, frames: list[FrameRef]) -> list[FrameRef]:
return mentioned
def _resolve_frame_path(frames_dir: Path, raw_path: str) -> Path | None:
"""Resolve a frame path from index.json, handling mounted/remote sessions."""
p = Path(raw_path)
if p.exists():
return p
# Try relative to frames_dir (handles path prefix mismatch from remote)
local = frames_dir / p.name
if local.exists():
return local
return None
def _load_frames(frames_dir: Path) -> list[FrameRef]:
index_path = frames_dir / "index.json"
if not index_path.exists():
return []
try:
entries = json.loads(index_path.read_text())
frames = []
for e in entries:
resolved = _resolve_frame_path(frames_dir, e["path"])
if resolved:
frames.append(FrameRef(id=e["id"], path=resolved, timestamp=e["timestamp"]))
return frames
except Exception as e:
log.warning("Could not load frames index: %s", e)
return []
def _load_transcript(transcript_dir: Path) -> list[TranscriptRef]:
index_path = transcript_dir / "index.json"
if not index_path.exists():
return []
try:
entries = json.loads(index_path.read_text())
return [TranscriptRef(**e) for e in entries]
except Exception as e:
log.warning("Could not load transcript index: %s", e)
return []
def _parse_transcript_mentions(message: str, segments: list[TranscriptRef]) -> list[TranscriptRef]:
"""Extract @T references. Accepts @T1, @T2-6, @T2,4,6, @T1-3,5,7-10."""
mentioned = []
seen = set()
for match in re.finditer(r"@[Tt]([\d,\-]+)", message):
@@ -141,84 +121,191 @@ def _parse_transcript_mentions(message: str, segments: list[TranscriptRef]) -> l
return mentioned
def _build_user_message(text: str, mentioned_frames: list[FrameRef],
mentioned_transcripts: list[TranscriptRef]) -> UserMessage:
"""Build a UserMessage with content blocks from text and @-mentions."""
content: list = [TextBlock(text=text)]
for f in mentioned_frames:
content.append(ImageBlock(frame_id=f.id, path=f.path, timestamp=f.timestamp))
for t in mentioned_transcripts:
content.append(TranscriptBlock(
transcript_id=t.id, start=t.start, end=t.end, text=t.text
))
return UserMessage(content=content)
class AgentRunner:
"""Runs agent queries in a background thread, streams chunks to a callback."""
"""Runs agent queries in a background thread with tool execution loop."""
def __init__(self):
self._provider: AgentProvider | None = None
self._history: list[tuple[str, str]] = [] # (role, text)
self.include_history = False # toggled by UI
self._connection = None
self._thread: AgentThread = AgentThread()
self.include_history = False
def _get_provider(self) -> AgentProvider:
if self._provider is None:
self._provider = _resolve_provider()
log.info("Agent provider: %s", self._provider.name)
return self._provider
@property
def thread(self) -> AgentThread:
return self._thread
def _get_connection(self):
if self._connection is None:
self._connection = _resolve_connection()
log.info("Agent connection: %s", self._connection.name)
return self._connection
@property
def provider_name(self) -> str:
try:
return self._get_provider().name
return self._get_connection().name
except Exception:
return "unknown"
@property
def available_models(self) -> list[str]:
try:
return self._get_provider().available_models
return self._get_connection().available_models()
except Exception:
return []
@property
def model(self) -> str:
try:
return self._get_provider().model
return self._get_connection().get_model()
except Exception:
return ""
@model.setter
def model(self, value: str):
self._get_provider().model = value
self._get_connection().set_model(value)
def clear_history(self):
self._history.clear()
self._thread = AgentThread()
def set_thread(self, thread: AgentThread):
self._thread = thread
def load_from_session(self, session_dir: Path):
loaded = load_thread(session_dir)
if loaded:
self._thread = loaded
log.info("Loaded thread %s with %d messages", loaded.id, len(loaded.messages))
else:
self._thread = AgentThread()
def send(
self,
message: str,
stream_mgr,
tracker,
on_chunk: Callable[[str], None],
on_event: Callable[[StreamEvent], None],
on_done: Callable[[str | None], None],
):
"""Dispatch message in a background thread.
"""Dispatch message in a background thread with tool execution loop.
on_chunk(text) — called for each streamed chunk
on_done(error_or_None) — called when complete
on_event(StreamEvent) — called for each stream event (from bg thread)
on_done(error_or_None) — called when complete (from bg thread)
"""
def _run():
try:
provider = self._get_provider()
frames = _load_frames(stream_mgr.frames_dir)
connection = self._get_connection()
frames = load_frames(stream_mgr.frames_dir)
mentioned_frames = _parse_mentions(message, frames)
transcript = _load_transcript(stream_mgr.transcript_dir)
transcript = load_transcript(stream_mgr.transcript_dir)
mentioned_transcripts = _parse_transcript_mentions(message, transcript)
context = SessionContext(
# Build and append user message
user_msg = _build_user_message(message, mentioned_frames, mentioned_transcripts)
self._thread.messages.append(user_msg)
# Build tool context
tool_ctx = ToolContext(
session_dir=stream_mgr.session_dir,
frames=frames,
duration=tracker.duration if tracker else 0.0,
mentioned_frames=mentioned_frames,
transcript_segments=transcript,
mentioned_transcripts=mentioned_transcripts,
history=list(self._history) if self.include_history else [],
frames_dir=stream_mgr.frames_dir,
transcript_dir=stream_mgr.transcript_dir,
stream_mgr=stream_mgr,
tracker=tracker,
)
self._history.append(("user", message))
response_chunks = []
for chunk in provider.stream(message, context):
response_chunks.append(chunk)
on_chunk(chunk)
self._history.append(("assistant", "".join(response_chunks)))
# Tool registry
tools_by_name = {t.name: t for t in BUILTIN_TOOLS}
# Messages to send — full thread or just last message
def _get_messages():
if self.include_history:
return list(self._thread.messages)
return [self._thread.messages[-1]]
# Tool execution loop
full_text_parts = []
for _turn in range(MAX_TOOL_TURNS):
msgs = _get_messages()
stop_reason = None
for event in connection.prompt(msgs, BUILTIN_TOOLS):
on_event(event)
if isinstance(event, TextDelta):
full_text_parts.append(event.text)
elif isinstance(event, ToolCallStart):
tool_use = ToolUse(
id=event.id,
tool_name=event.name,
input=event.input,
status="running",
)
self._thread.messages.append(tool_use)
elif isinstance(event, ToolCallEnd):
# Execute the tool
tool = tools_by_name.get(event.id)
# Find the ToolUse by id
tool_use_msg = None
for m in reversed(self._thread.messages):
if isinstance(m, ToolUse) and m.id == event.id:
tool_use_msg = m
break
if tool_use_msg:
tool_impl = tools_by_name.get(tool_use_msg.tool_name)
if tool_impl:
result = tool_impl.run(tool_use_msg.input, tool_ctx)
result.tool_use_id = tool_use_msg.id
tool_use_msg.status = "error" if result.error else "done"
else:
result = ToolResult(
tool_use_id=tool_use_msg.id,
error=f"Unknown tool: {tool_use_msg.tool_name}",
)
tool_use_msg.status = "error"
self._thread.messages.append(result)
on_event(result)
elif isinstance(event, Done):
stop_reason = event.stop_reason
break
elif isinstance(event, Error):
on_done(event.message)
return
# Build assistant message from accumulated text
if full_text_parts:
asst_msg = AssistantMessage(
content=[TextBlock(text="".join(full_text_parts))],
model=connection.get_model(),
)
self._thread.messages.append(asst_msg)
if stop_reason != "tool_use":
break
# Reset text for next turn (tool loop continues)
full_text_parts = []
# Save thread
save_thread(self._thread, stream_mgr.session_dir)
on_done(None)
except Exception as e:
log.error("Agent error: %s", e)
on_done(str(e))