better agent
This commit is contained in:
@@ -1,13 +1,13 @@
|
||||
"""
|
||||
Agent runner — resolves provider, parses @-mentions, dispatches messages.
|
||||
Agent runner — resolves connection, parses @-mentions, dispatches messages,
|
||||
executes tool loop.
|
||||
|
||||
Provider selection (in order):
|
||||
Connection selection (in order):
|
||||
1. GROQ_API_KEY → OpenAI-compat / Groq
|
||||
2. OPENAI_API_KEY → OpenAI-compat / OpenAI
|
||||
3. (default) → Claude Code SDK (uses CC subscription)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
@@ -15,7 +15,29 @@ from pathlib import Path
|
||||
from threading import Thread
|
||||
from typing import Callable
|
||||
|
||||
from cht.agent.base import AgentProvider, FrameRef, TranscriptRef, SessionContext
|
||||
from cht.agent.base import (
|
||||
AssistantMessage,
|
||||
FrameRef,
|
||||
ImageBlock,
|
||||
Message,
|
||||
StreamEvent,
|
||||
TextBlock,
|
||||
TextDelta,
|
||||
ToolCallEnd,
|
||||
ToolCallStart,
|
||||
ToolContext,
|
||||
ToolResult,
|
||||
ToolUse,
|
||||
TranscriptBlock,
|
||||
TranscriptRef,
|
||||
UserMessage,
|
||||
Done,
|
||||
Error,
|
||||
Thread as AgentThread,
|
||||
save_thread,
|
||||
load_thread,
|
||||
)
|
||||
from cht.agent.tools import load_frames, load_transcript, BUILTIN_TOOLS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,6 +47,8 @@ ACTIONS: dict[str, str] = {
|
||||
"Answer": "answer",
|
||||
}
|
||||
|
||||
MAX_TOOL_TURNS = 10
|
||||
|
||||
|
||||
def check_claude_cli() -> str | None:
|
||||
"""Returns None if OK, or an error string if CLI is missing/unauthenticated."""
|
||||
@@ -43,16 +67,15 @@ def check_claude_cli() -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_provider() -> AgentProvider:
|
||||
def _resolve_connection():
|
||||
if os.environ.get("GROQ_API_KEY") or os.environ.get("OPENAI_API_KEY"):
|
||||
from cht.agent.openai_compat_provider import OpenAICompatProvider
|
||||
return OpenAICompatProvider()
|
||||
from cht.agent.claude_sdk_provider import ClaudeSDKProvider
|
||||
return ClaudeSDKProvider()
|
||||
from cht.agent.openai_connection import OpenAIConnection
|
||||
return OpenAIConnection()
|
||||
from cht.agent.claude_sdk_connection import ClaudeSDKConnection
|
||||
return ClaudeSDKConnection()
|
||||
|
||||
|
||||
def _expand_ref_nums(spec: str) -> list[int]:
|
||||
"""Expand a ref spec like '2-6' or '2,4,6' or '2-4,6,8-10' into sorted ints."""
|
||||
nums = set()
|
||||
for part in spec.split(","):
|
||||
part = part.strip()
|
||||
@@ -71,7 +94,6 @@ def _expand_ref_nums(spec: str) -> list[int]:
|
||||
|
||||
|
||||
def _parse_mentions(message: str, frames: list[FrameRef]) -> list[FrameRef]:
|
||||
"""Extract @F references. Accepts @F1, @F2-6, @F2,4,6, @F2-4,6,8-10."""
|
||||
mentioned = []
|
||||
seen = set()
|
||||
for match in re.finditer(r"@[Ff]([\d,\-]+)", message):
|
||||
@@ -85,49 +107,7 @@ def _parse_mentions(message: str, frames: list[FrameRef]) -> list[FrameRef]:
|
||||
return mentioned
|
||||
|
||||
|
||||
def _resolve_frame_path(frames_dir: Path, raw_path: str) -> Path | None:
|
||||
"""Resolve a frame path from index.json, handling mounted/remote sessions."""
|
||||
p = Path(raw_path)
|
||||
if p.exists():
|
||||
return p
|
||||
# Try relative to frames_dir (handles path prefix mismatch from remote)
|
||||
local = frames_dir / p.name
|
||||
if local.exists():
|
||||
return local
|
||||
return None
|
||||
|
||||
|
||||
def _load_frames(frames_dir: Path) -> list[FrameRef]:
|
||||
index_path = frames_dir / "index.json"
|
||||
if not index_path.exists():
|
||||
return []
|
||||
try:
|
||||
entries = json.loads(index_path.read_text())
|
||||
frames = []
|
||||
for e in entries:
|
||||
resolved = _resolve_frame_path(frames_dir, e["path"])
|
||||
if resolved:
|
||||
frames.append(FrameRef(id=e["id"], path=resolved, timestamp=e["timestamp"]))
|
||||
return frames
|
||||
except Exception as e:
|
||||
log.warning("Could not load frames index: %s", e)
|
||||
return []
|
||||
|
||||
|
||||
def _load_transcript(transcript_dir: Path) -> list[TranscriptRef]:
|
||||
index_path = transcript_dir / "index.json"
|
||||
if not index_path.exists():
|
||||
return []
|
||||
try:
|
||||
entries = json.loads(index_path.read_text())
|
||||
return [TranscriptRef(**e) for e in entries]
|
||||
except Exception as e:
|
||||
log.warning("Could not load transcript index: %s", e)
|
||||
return []
|
||||
|
||||
|
||||
def _parse_transcript_mentions(message: str, segments: list[TranscriptRef]) -> list[TranscriptRef]:
|
||||
"""Extract @T references. Accepts @T1, @T2-6, @T2,4,6, @T1-3,5,7-10."""
|
||||
mentioned = []
|
||||
seen = set()
|
||||
for match in re.finditer(r"@[Tt]([\d,\-]+)", message):
|
||||
@@ -141,84 +121,191 @@ def _parse_transcript_mentions(message: str, segments: list[TranscriptRef]) -> l
|
||||
return mentioned
|
||||
|
||||
|
||||
def _build_user_message(text: str, mentioned_frames: list[FrameRef],
|
||||
mentioned_transcripts: list[TranscriptRef]) -> UserMessage:
|
||||
"""Build a UserMessage with content blocks from text and @-mentions."""
|
||||
content: list = [TextBlock(text=text)]
|
||||
for f in mentioned_frames:
|
||||
content.append(ImageBlock(frame_id=f.id, path=f.path, timestamp=f.timestamp))
|
||||
for t in mentioned_transcripts:
|
||||
content.append(TranscriptBlock(
|
||||
transcript_id=t.id, start=t.start, end=t.end, text=t.text
|
||||
))
|
||||
return UserMessage(content=content)
|
||||
|
||||
|
||||
class AgentRunner:
|
||||
"""Runs agent queries in a background thread, streams chunks to a callback."""
|
||||
"""Runs agent queries in a background thread with tool execution loop."""
|
||||
|
||||
def __init__(self):
|
||||
self._provider: AgentProvider | None = None
|
||||
self._history: list[tuple[str, str]] = [] # (role, text)
|
||||
self.include_history = False # toggled by UI
|
||||
self._connection = None
|
||||
self._thread: AgentThread = AgentThread()
|
||||
self.include_history = False
|
||||
|
||||
def _get_provider(self) -> AgentProvider:
|
||||
if self._provider is None:
|
||||
self._provider = _resolve_provider()
|
||||
log.info("Agent provider: %s", self._provider.name)
|
||||
return self._provider
|
||||
@property
|
||||
def thread(self) -> AgentThread:
|
||||
return self._thread
|
||||
|
||||
def _get_connection(self):
|
||||
if self._connection is None:
|
||||
self._connection = _resolve_connection()
|
||||
log.info("Agent connection: %s", self._connection.name)
|
||||
return self._connection
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
try:
|
||||
return self._get_provider().name
|
||||
return self._get_connection().name
|
||||
except Exception:
|
||||
return "unknown"
|
||||
|
||||
@property
|
||||
def available_models(self) -> list[str]:
|
||||
try:
|
||||
return self._get_provider().available_models
|
||||
return self._get_connection().available_models()
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
@property
|
||||
def model(self) -> str:
|
||||
try:
|
||||
return self._get_provider().model
|
||||
return self._get_connection().get_model()
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
@model.setter
|
||||
def model(self, value: str):
|
||||
self._get_provider().model = value
|
||||
self._get_connection().set_model(value)
|
||||
|
||||
def clear_history(self):
|
||||
self._history.clear()
|
||||
self._thread = AgentThread()
|
||||
|
||||
def set_thread(self, thread: AgentThread):
|
||||
self._thread = thread
|
||||
|
||||
def load_from_session(self, session_dir: Path):
|
||||
loaded = load_thread(session_dir)
|
||||
if loaded:
|
||||
self._thread = loaded
|
||||
log.info("Loaded thread %s with %d messages", loaded.id, len(loaded.messages))
|
||||
else:
|
||||
self._thread = AgentThread()
|
||||
|
||||
def send(
|
||||
self,
|
||||
message: str,
|
||||
stream_mgr,
|
||||
tracker,
|
||||
on_chunk: Callable[[str], None],
|
||||
on_event: Callable[[StreamEvent], None],
|
||||
on_done: Callable[[str | None], None],
|
||||
):
|
||||
"""Dispatch message in a background thread.
|
||||
"""Dispatch message in a background thread with tool execution loop.
|
||||
|
||||
on_chunk(text) — called for each streamed chunk
|
||||
on_done(error_or_None) — called when complete
|
||||
on_event(StreamEvent) — called for each stream event (from bg thread)
|
||||
on_done(error_or_None) — called when complete (from bg thread)
|
||||
"""
|
||||
def _run():
|
||||
try:
|
||||
provider = self._get_provider()
|
||||
frames = _load_frames(stream_mgr.frames_dir)
|
||||
connection = self._get_connection()
|
||||
frames = load_frames(stream_mgr.frames_dir)
|
||||
mentioned_frames = _parse_mentions(message, frames)
|
||||
transcript = _load_transcript(stream_mgr.transcript_dir)
|
||||
transcript = load_transcript(stream_mgr.transcript_dir)
|
||||
mentioned_transcripts = _parse_transcript_mentions(message, transcript)
|
||||
context = SessionContext(
|
||||
|
||||
# Build and append user message
|
||||
user_msg = _build_user_message(message, mentioned_frames, mentioned_transcripts)
|
||||
self._thread.messages.append(user_msg)
|
||||
|
||||
# Build tool context
|
||||
tool_ctx = ToolContext(
|
||||
session_dir=stream_mgr.session_dir,
|
||||
frames=frames,
|
||||
duration=tracker.duration if tracker else 0.0,
|
||||
mentioned_frames=mentioned_frames,
|
||||
transcript_segments=transcript,
|
||||
mentioned_transcripts=mentioned_transcripts,
|
||||
history=list(self._history) if self.include_history else [],
|
||||
frames_dir=stream_mgr.frames_dir,
|
||||
transcript_dir=stream_mgr.transcript_dir,
|
||||
stream_mgr=stream_mgr,
|
||||
tracker=tracker,
|
||||
)
|
||||
self._history.append(("user", message))
|
||||
response_chunks = []
|
||||
for chunk in provider.stream(message, context):
|
||||
response_chunks.append(chunk)
|
||||
on_chunk(chunk)
|
||||
self._history.append(("assistant", "".join(response_chunks)))
|
||||
|
||||
# Tool registry
|
||||
tools_by_name = {t.name: t for t in BUILTIN_TOOLS}
|
||||
|
||||
# Messages to send — full thread or just last message
|
||||
def _get_messages():
|
||||
if self.include_history:
|
||||
return list(self._thread.messages)
|
||||
return [self._thread.messages[-1]]
|
||||
|
||||
# Tool execution loop
|
||||
full_text_parts = []
|
||||
for _turn in range(MAX_TOOL_TURNS):
|
||||
msgs = _get_messages()
|
||||
stop_reason = None
|
||||
|
||||
for event in connection.prompt(msgs, BUILTIN_TOOLS):
|
||||
on_event(event)
|
||||
|
||||
if isinstance(event, TextDelta):
|
||||
full_text_parts.append(event.text)
|
||||
|
||||
elif isinstance(event, ToolCallStart):
|
||||
tool_use = ToolUse(
|
||||
id=event.id,
|
||||
tool_name=event.name,
|
||||
input=event.input,
|
||||
status="running",
|
||||
)
|
||||
self._thread.messages.append(tool_use)
|
||||
|
||||
elif isinstance(event, ToolCallEnd):
|
||||
# Execute the tool
|
||||
tool = tools_by_name.get(event.id)
|
||||
# Find the ToolUse by id
|
||||
tool_use_msg = None
|
||||
for m in reversed(self._thread.messages):
|
||||
if isinstance(m, ToolUse) and m.id == event.id:
|
||||
tool_use_msg = m
|
||||
break
|
||||
|
||||
if tool_use_msg:
|
||||
tool_impl = tools_by_name.get(tool_use_msg.tool_name)
|
||||
if tool_impl:
|
||||
result = tool_impl.run(tool_use_msg.input, tool_ctx)
|
||||
result.tool_use_id = tool_use_msg.id
|
||||
tool_use_msg.status = "error" if result.error else "done"
|
||||
else:
|
||||
result = ToolResult(
|
||||
tool_use_id=tool_use_msg.id,
|
||||
error=f"Unknown tool: {tool_use_msg.tool_name}",
|
||||
)
|
||||
tool_use_msg.status = "error"
|
||||
self._thread.messages.append(result)
|
||||
on_event(result)
|
||||
|
||||
elif isinstance(event, Done):
|
||||
stop_reason = event.stop_reason
|
||||
break
|
||||
|
||||
elif isinstance(event, Error):
|
||||
on_done(event.message)
|
||||
return
|
||||
|
||||
# Build assistant message from accumulated text
|
||||
if full_text_parts:
|
||||
asst_msg = AssistantMessage(
|
||||
content=[TextBlock(text="".join(full_text_parts))],
|
||||
model=connection.get_model(),
|
||||
)
|
||||
self._thread.messages.append(asst_msg)
|
||||
|
||||
if stop_reason != "tool_use":
|
||||
break
|
||||
|
||||
# Reset text for next turn (tool loop continues)
|
||||
full_text_parts = []
|
||||
|
||||
# Save thread
|
||||
save_thread(self._thread, stream_mgr.session_dir)
|
||||
on_done(None)
|
||||
|
||||
except Exception as e:
|
||||
log.error("Agent error: %s", e)
|
||||
on_done(str(e))
|
||||
|
||||
Reference in New Issue
Block a user