"""
Agent runner — resolves the provider, parses @-mentions, and dispatches messages.

Provider selection (in order):
  1. GROQ_API_KEY   → OpenAI-compatible provider / Groq
  2. OPENAI_API_KEY → OpenAI-compatible provider / OpenAI
  3. (default)      → Claude Code SDK (uses the Claude Code subscription)
"""
|
|
|
|
import json
import logging
import os
import re
from pathlib import Path
from threading import Lock, Thread
from typing import Callable

from cht.agent.base import AgentProvider, FrameRef, SessionContext
|
|
|
|
log = logging.getLogger(__name__)

# Predefined quick actions: button label -> verb the prompt starts with.
# The UI appends the frame reference after the verb.
ACTIONS: dict[str, str] = dict(
    Describe="describe",
    Answer="answer",
)
|
|
|
|
|
|
def check_claude_cli() -> str | None:
|
|
"""Returns None if OK, or an error string if CLI is missing/unauthenticated."""
|
|
import shutil, subprocess
|
|
if not shutil.which("claude"):
|
|
return "Claude Code CLI not found. Install from https://claude.ai/code"
|
|
try:
|
|
result = subprocess.run(
|
|
["claude", "--version"],
|
|
capture_output=True, timeout=5
|
|
)
|
|
if result.returncode != 0:
|
|
return "Claude Code CLI error. Run `claude` in a terminal to check."
|
|
except Exception as e:
|
|
return f"Claude Code CLI check failed: {e}"
|
|
return None
|
|
|
|
|
|
def _resolve_provider() -> AgentProvider:
    """Instantiate the agent backend selected by environment variables.

    A non-empty GROQ_API_KEY or OPENAI_API_KEY selects the OpenAI-compatible
    provider; otherwise the Claude Code SDK provider is used. Each provider
    module is imported lazily, only when that backend is the one chosen.
    """
    wants_openai_compat = any(
        os.environ.get(var) for var in ("GROQ_API_KEY", "OPENAI_API_KEY")
    )
    if not wants_openai_compat:
        from cht.agent.claude_sdk_provider import ClaudeSDKProvider

        return ClaudeSDKProvider()

    from cht.agent.openai_compat_provider import OpenAICompatProvider

    return OpenAICompatProvider()
|
|
|
|
|
|
def _parse_mentions(message: str, frames: list[FrameRef]) -> list[FrameRef]:
    """Return the frames referenced by ``@``-mentions in *message*.

    A mention is ``@`` followed by an optional ``F``/``f`` and digits. The
    number is zero-padded to (at least) four digits to form the frame ID, so
    ``@F0001``, ``@f1``, ``@1`` and ``@001`` all refer to frame ``F0001``.

    Frames are returned in order of first mention, each at most once.
    Mentions whose ID matches no entry in *frames* are ignored.
    """
    picked = []
    handled_ids = set()
    for token in re.findall(r"@([Ff]?\d+)", message):
        frame_id = "F%04d" % int(token.lstrip("Ff"))
        if frame_id in handled_ids:
            continue
        handled_ids.add(frame_id)
        # First frame with this ID wins if the list contains duplicates.
        match = next((candidate for candidate in frames if candidate.id == frame_id), None)
        if match:
            picked.append(match)
    return picked
|
|
|
|
|
|
def _resolve_frame_path(frames_dir: Path, raw_path: str) -> Path | None:
|
|
"""Resolve a frame path from index.json, handling mounted/remote sessions."""
|
|
p = Path(raw_path)
|
|
if p.exists():
|
|
return p
|
|
# Try relative to frames_dir (handles path prefix mismatch from remote)
|
|
local = frames_dir / p.name
|
|
if local.exists():
|
|
return local
|
|
return None
|
|
|
|
|
|
def _load_frames(frames_dir: Path) -> list[FrameRef]:
    """Build FrameRefs from ``frames_dir/index.json``.

    The index is expected to be a JSON list of objects with ``"id"``,
    ``"path"`` and ``"timestamp"`` keys. Each path is located with
    ``_resolve_frame_path``; entries whose file cannot be found are skipped.

    Loading is best-effort:
      * a missing index yields ``[]`` silently;
      * an unreadable or invalid index logs a warning and yields ``[]``;
      * a malformed individual entry is logged and skipped without
        discarding the other entries.
    """
    index_path = frames_dir / "index.json"
    try:
        # JSON text is UTF-8 (RFC 8259); don't depend on the locale encoding.
        entries = json.loads(index_path.read_text(encoding="utf-8"))
    except FileNotFoundError:
        return []
    except (OSError, ValueError) as e:  # ValueError: JSON or Unicode decode errors
        log.warning("Could not load frames index: %s", e)
        return []
    if not isinstance(entries, list):
        log.warning(
            "Could not load frames index: expected a JSON list, got %s",
            type(entries).__name__,
        )
        return []

    frames = []
    for entry in entries:
        try:
            resolved = _resolve_frame_path(frames_dir, entry["path"])
            if resolved is not None:
                frames.append(
                    FrameRef(id=entry["id"], path=resolved, timestamp=entry["timestamp"])
                )
        except (KeyError, TypeError, ValueError, OSError) as e:
            # Missing keys, non-object entries or unusable paths: skip this one.
            log.warning("Skipping malformed frames index entry %r: %s", entry, e)
    return frames
|
|
|
|
|
|
class AgentRunner:
    """Runs agent queries in a background thread, streams chunks to a callback.

    The provider is resolved lazily on first use (via ``_resolve_provider``)
    and then reused for every later query.
    """

    def __init__(self):
        self._provider: AgentProvider | None = None
        # Guards lazy provider creation: send() resolves the provider on a
        # worker thread while other threads may read provider_name/model.
        self._provider_lock = Lock()

    def _get_provider(self) -> AgentProvider:
        """Return the cached provider, creating it exactly once on first call."""
        with self._provider_lock:
            if self._provider is None:
                self._provider = _resolve_provider()
                log.info("Agent provider: %s", self._provider.name)
            return self._provider

    @property
    def provider_name(self) -> str:
        """Name of the active provider, or ``"unknown"`` if it cannot be created."""
        try:
            return self._get_provider().name
        except Exception:
            # Best-effort for display purposes; record why for debugging.
            log.debug("Could not resolve agent provider", exc_info=True)
            return "unknown"

    @property
    def available_models(self) -> list[str]:
        """Models offered by the provider, or ``[]`` if it cannot be created."""
        try:
            return self._get_provider().available_models
        except Exception:
            log.debug("Could not list agent models", exc_info=True)
            return []

    @property
    def model(self) -> str:
        """Currently selected model, or ``""`` if the provider cannot be created."""
        try:
            return self._get_provider().model
        except Exception:
            log.debug("Could not read agent model", exc_info=True)
            return ""

    @model.setter
    def model(self, value: str):
        # Unlike the getter, errors creating the provider propagate here.
        self._get_provider().model = value

    def send(
        self,
        message: str,
        stream_mgr,
        tracker,
        on_chunk: Callable[[str], None],
        on_done: Callable[[str | None], None],
    ):
        """Dispatch message in a background thread.

        The frames index under ``stream_mgr.frames_dir`` is reloaded for every
        message and ``@``-mentions in *message* are resolved against it.

        on_chunk(text) — called for each streamed chunk
        on_done(error_or_None) — called exactly once when complete: None on
            success, otherwise a non-empty error description.

        Both callbacks run on the worker thread, not on the caller's thread.
        """
        def _run():
            try:
                provider = self._get_provider()
                frames = _load_frames(stream_mgr.frames_dir)
                mentioned = _parse_mentions(message, frames)
                context = SessionContext(
                    session_dir=stream_mgr.session_dir,
                    frames=frames,
                    duration=tracker.duration if tracker else 0.0,
                    mentioned_frames=mentioned,
                )
                for chunk in provider.stream(message, context):
                    on_chunk(chunk)
            except Exception as e:
                log.exception("Agent error: %s", e)
                # Exceptions with an empty message would otherwise report "",
                # which callers testing truthiness would mistake for success.
                error = str(e) or type(e).__name__
            else:
                error = None
            # Called outside the try so that an exception raised by on_done
            # itself never causes on_done to be invoked a second time.
            on_done(error)

        Thread(target=_run, daemon=True, name="agent_runner").start()
|