Files
mitus/cht/agent/runner.py
2026-04-03 00:25:14 -03:00

217 lines
6.9 KiB
Python

"""
Agent runner — resolves provider, parses @-mentions, dispatches messages.
Provider selection (in order):
1. GROQ_API_KEY → OpenAI-compat / Groq
2. OPENAI_API_KEY → OpenAI-compat / OpenAI
3. (default) → Claude Code SDK (uses CC subscription)
"""
import json
import logging
import os
import re
from pathlib import Path
from threading import Lock, Thread
from typing import Callable

from cht.agent.base import AgentProvider, FrameRef, TranscriptRef, SessionContext
log = logging.getLogger(__name__)

# Predefined actions: display label -> verb prefix.
# The frame reference itself is appended by the UI, not stored here.
ACTIONS: dict[str, str] = dict(
    Describe="describe",
    Answer="answer",
)
def check_claude_cli() -> str | None:
"""Returns None if OK, or an error string if CLI is missing/unauthenticated."""
import shutil, subprocess
if not shutil.which("claude"):
return "Claude Code CLI not found. Install from https://claude.ai/code"
try:
result = subprocess.run(
["claude", "--version"],
capture_output=True, timeout=5
)
if result.returncode != 0:
return "Claude Code CLI error. Run `claude` in a terminal to check."
except Exception as e:
return f"Claude Code CLI check failed: {e}"
return None
def _resolve_provider() -> AgentProvider:
    """Choose and construct the agent provider from environment variables.

    If GROQ_API_KEY or OPENAI_API_KEY is set to a non-empty value, the
    OpenAI-compatible provider is used; otherwise the Claude Code SDK
    provider. Each provider module is imported only in its own branch.
    NOTE(review): which of Groq/OpenAI is actually targeted is presumably
    decided inside OpenAICompatProvider — not visible here.
    """
    wants_openai_compat = any(
        os.environ.get(var_name) for var_name in ("GROQ_API_KEY", "OPENAI_API_KEY")
    )
    if not wants_openai_compat:
        from cht.agent.claude_sdk_provider import ClaudeSDKProvider
        return ClaudeSDKProvider()
    from cht.agent.openai_compat_provider import OpenAICompatProvider
    return OpenAICompatProvider()
def _expand_ref_nums(spec: str) -> list[int]:
"""Expand a ref spec like '2-6' or '2,4,6' or '2-4,6,8-10' into sorted ints."""
nums = set()
for part in spec.split(","):
part = part.strip()
if "-" in part:
a, b = part.split("-", 1)
try:
nums.update(range(int(a), int(b) + 1))
except ValueError:
pass
elif part:
try:
nums.add(int(part))
except ValueError:
pass
return sorted(nums)
def _parse_mentions(message: str, frames: list[FrameRef]) -> list[FrameRef]:
    """Return the frames referenced by @F mentions in *message*.

    Accepts @F1, @F2-6, @F2,4,6, @F2-4,6,8-10 (the 'F' may be lower-case).
    Each number N maps to frame id ``f"F{N:04d}"`` (e.g. 7 -> "F0007").
    Frames are returned in order of first mention (ascending within a single
    mention), each at most once; numbers with no matching frame are ignored.
    """
    # Index the frames once instead of scanning the list for every number.
    # setdefault keeps the first frame for a duplicated id, matching the
    # first-match semantics of a linear search.
    by_id: dict[str, FrameRef] = {}
    for frame in frames:
        by_id.setdefault(frame.id, frame)

    mentioned: list[FrameRef] = []
    seen: set[str] = set()
    for match in re.finditer(r"@[Ff]([\d,\-]+)", message):
        for num in _expand_ref_nums(match.group(1)):
            fid = f"F{num:04d}"
            if fid in seen:
                continue
            frame = by_id.get(fid)
            if frame is not None:
                mentioned.append(frame)
                seen.add(fid)
    return mentioned
def _resolve_frame_path(frames_dir: Path, raw_path: str) -> Path | None:
"""Resolve a frame path from index.json, handling mounted/remote sessions."""
p = Path(raw_path)
if p.exists():
return p
# Try relative to frames_dir (handles path prefix mismatch from remote)
local = frames_dir / p.name
if local.exists():
return local
return None
def _load_frames(frames_dir: Path) -> list[FrameRef]:
    """Load ``frames_dir/index.json`` as a list of FrameRef objects.

    The index is read as UTF-8 JSON and must be a list of objects with
    ``id``, ``path`` and ``timestamp`` keys. Each entry's path is resolved
    via _resolve_frame_path; entries whose file is not found locally are
    skipped. Malformed entries are logged and skipped individually, so the
    remaining frames are still returned.

    Returns [] (silently) when the index does not exist, and [] with a
    warning when it cannot be read/decoded or is not a JSON list.
    """
    index_path = frames_dir / "index.json"
    try:
        entries = json.loads(index_path.read_text(encoding="utf-8"))
    except (FileNotFoundError, NotADirectoryError):
        return []
    except (OSError, ValueError) as e:
        # ValueError covers JSONDecodeError and UnicodeDecodeError.
        log.warning("Could not load frames index: %s", e)
        return []

    if not isinstance(entries, list):
        log.warning(
            "Could not load frames index: expected a JSON list, got %s",
            type(entries).__name__,
        )
        return []

    frames: list[FrameRef] = []
    for entry in entries:
        try:
            resolved = _resolve_frame_path(frames_dir, entry["path"])
            if resolved is None:
                continue
            frames.append(
                FrameRef(id=entry["id"], path=resolved, timestamp=entry["timestamp"])
            )
        except (KeyError, TypeError, ValueError, OSError) as e:
            log.warning("Skipping malformed frame index entry %r: %s", entry, e)
    return frames
def _load_transcript(transcript_dir: Path) -> list[TranscriptRef]:
    """Load ``transcript_dir/index.json`` as a list of TranscriptRef objects.

    The index is read as UTF-8 JSON and must be a list; each entry is passed
    as keyword arguments to TranscriptRef. Entries that TranscriptRef rejects
    are logged and skipped individually, so the remaining segments are still
    returned.

    Returns [] (silently) when the index does not exist, and [] with a
    warning when it cannot be read/decoded or is not a JSON list.
    """
    index_path = transcript_dir / "index.json"
    try:
        entries = json.loads(index_path.read_text(encoding="utf-8"))
    except (FileNotFoundError, NotADirectoryError):
        return []
    except (OSError, ValueError) as e:
        # ValueError covers JSONDecodeError and UnicodeDecodeError.
        log.warning("Could not load transcript index: %s", e)
        return []

    if not isinstance(entries, list):
        log.warning(
            "Could not load transcript index: expected a JSON list, got %s",
            type(entries).__name__,
        )
        return []

    segments: list[TranscriptRef] = []
    for entry in entries:
        try:
            # TypeError: entry is not a mapping, or has unexpected/missing keys.
            segments.append(TranscriptRef(**entry))
        except (TypeError, ValueError) as e:
            log.warning("Skipping malformed transcript index entry %r: %s", entry, e)
    return segments
def _parse_transcript_mentions(message: str, segments: list[TranscriptRef]) -> list[TranscriptRef]:
    """Return the transcript segments referenced by @T mentions in *message*.

    Accepts @T1, @T2-6, @T2,4,6, @T1-3,5,7-10 (the 'T' may be lower-case).
    Each number N maps to segment id ``f"T{N:04d}"`` (e.g. 7 -> "T0007").
    Segments are returned in order of first mention (ascending within a
    single mention), each at most once; numbers with no matching segment
    are ignored.
    """
    # Index the segments once instead of scanning the list for every number.
    # setdefault keeps the first segment for a duplicated id, matching the
    # first-match semantics of a linear search.
    by_id: dict[str, TranscriptRef] = {}
    for seg in segments:
        by_id.setdefault(seg.id, seg)

    mentioned: list[TranscriptRef] = []
    seen: set[str] = set()
    for match in re.finditer(r"@[Tt]([\d,\-]+)", message):
        for num in _expand_ref_nums(match.group(1)):
            tid = f"T{num:04d}"
            if tid in seen:
                continue
            seg = by_id.get(tid)
            if seg is not None:
                mentioned.append(seg)
                seen.add(tid)
    return mentioned
class AgentRunner:
    """Runs agent queries in a background thread, streaming chunks to a callback.

    The provider is resolved lazily on first use (via _resolve_provider) and
    then reused for the lifetime of the runner.
    """

    def __init__(self):
        self._provider: AgentProvider | None = None
        # Guards lazy provider creation. send() resolves the provider on a
        # worker thread while other callers may use the properties below
        # concurrently; without the lock, two racing resolutions could each
        # create a provider, and a model set on the discarded one would be lost.
        self._provider_lock = Lock()

    def _get_provider(self) -> AgentProvider:
        """Return the cached provider, creating it on the first call."""
        with self._provider_lock:
            if self._provider is None:
                self._provider = _resolve_provider()
                log.info("Agent provider: %s", self._provider.name)
            return self._provider

    @property
    def provider_name(self) -> str:
        """Name of the active provider, or "unknown" if it cannot be resolved."""
        try:
            return self._get_provider().name
        except Exception:
            return "unknown"

    @property
    def available_models(self) -> list[str]:
        """Models offered by the provider, or [] if it cannot be resolved."""
        try:
            return self._get_provider().available_models
        except Exception:
            return []

    @property
    def model(self) -> str:
        """Currently selected model, or "" if the provider cannot be resolved."""
        try:
            return self._get_provider().model
        except Exception:
            return ""

    @model.setter
    def model(self, value: str):
        # Unlike the getters, errors resolving the provider propagate here.
        self._get_provider().model = value

    def send(
        self,
        message: str,
        stream_mgr,
        tracker,
        on_chunk: Callable[[str], None],
        on_done: Callable[[str | None], None],
    ):
        """Dispatch *message* to the provider on a daemon background thread.

        Args:
            message: user text; @F... / @T... mentions select frames and
                transcript segments to include in the context.
            stream_mgr: object exposing ``frames_dir``, ``transcript_dir`` and
                ``session_dir``.
            tracker: object with a ``duration`` attribute, or a falsy value
                (duration then defaults to 0.0).
            on_chunk: called with each streamed text chunk.
            on_done: called once when finished — None on success, otherwise
                the error message.

        Both callbacks run on the worker thread, not the caller's thread.
        """
        def _run():
            try:
                provider = self._get_provider()
                frames = _load_frames(stream_mgr.frames_dir)
                mentioned_frames = _parse_mentions(message, frames)
                transcript = _load_transcript(stream_mgr.transcript_dir)
                mentioned_transcripts = _parse_transcript_mentions(message, transcript)
                context = SessionContext(
                    session_dir=stream_mgr.session_dir,
                    frames=frames,
                    duration=tracker.duration if tracker else 0.0,
                    mentioned_frames=mentioned_frames,
                    transcript_segments=transcript,
                    mentioned_transcripts=mentioned_transcripts,
                )
                for chunk in provider.stream(message, context):
                    on_chunk(chunk)
            except Exception as e:
                # log.exception keeps the traceback (log.error dropped it).
                log.exception("Agent error: %s", e)
                on_done(str(e))
            else:
                # Outside the try: if on_done itself raises, it must not be
                # invoked a second time with that error.
                on_done(None)

        Thread(target=_run, daemon=True, name="agent_runner").start()