AI
This commit is contained in:
140
cht/agent/runner.py
Normal file
140
cht/agent/runner.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
Agent runner — resolves provider, parses @-mentions, dispatches messages.
|
||||
|
||||
Provider selection (in order):
|
||||
1. GROQ_API_KEY → OpenAI-compat / Groq
|
||||
2. OPENAI_API_KEY → OpenAI-compat / OpenAI
|
||||
3. (default) → Claude Code SDK (uses CC subscription)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from typing import Callable
|
||||
|
||||
from cht.agent.base import AgentProvider, FrameRef, SessionContext
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Predefined actions sent as messages with a fixed prompt
|
||||
ACTIONS: dict[str, str] = {
|
||||
"Summarize": "Summarize what happened in this recording so far. Look at the captured frames and describe the key content and any changes you notice.",
|
||||
"What changed": "Compare the captured frames in order and describe what changed between them. Focus on meaningful transitions.",
|
||||
"Key moments": "Identify the most important moments in the recording based on the frames. List them with timestamps.",
|
||||
"Describe now": "Look at the most recent frame and describe exactly what is currently on screen.",
|
||||
}
|
||||
|
||||
|
||||
def check_claude_cli() -> str | None:
|
||||
"""Returns None if OK, or an error string if CLI is missing/unauthenticated."""
|
||||
import shutil, subprocess
|
||||
if not shutil.which("claude"):
|
||||
return "Claude Code CLI not found. Install from https://claude.ai/code"
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["claude", "--version"],
|
||||
capture_output=True, timeout=5
|
||||
)
|
||||
if result.returncode != 0:
|
||||
return "Claude Code CLI error. Run `claude` in a terminal to check."
|
||||
except Exception as e:
|
||||
return f"Claude Code CLI check failed: {e}"
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_provider() -> AgentProvider:
|
||||
if os.environ.get("GROQ_API_KEY") or os.environ.get("OPENAI_API_KEY"):
|
||||
from cht.agent.openai_compat_provider import OpenAICompatProvider
|
||||
return OpenAICompatProvider()
|
||||
from cht.agent.claude_sdk_provider import ClaudeSDKProvider
|
||||
return ClaudeSDKProvider()
|
||||
|
||||
|
||||
def _parse_mentions(message: str, frames: list[FrameRef]) -> list[FrameRef]:
|
||||
"""Extract @-references from message. Accepts:
|
||||
@F0001 @f1 @1 @001 — all match frame F0001
|
||||
"""
|
||||
mentioned = []
|
||||
seen = set()
|
||||
for match in re.finditer(r"@([Ff]?\d+)", message):
|
||||
raw = match.group(1).lstrip("Ff")
|
||||
num = int(raw)
|
||||
fid = f"F{num:04d}"
|
||||
if fid not in seen:
|
||||
frame = next((f for f in frames if f.id == fid), None)
|
||||
if frame:
|
||||
mentioned.append(frame)
|
||||
seen.add(fid)
|
||||
return mentioned
|
||||
|
||||
|
||||
def _load_frames(frames_dir: Path) -> list[FrameRef]:
|
||||
index_path = frames_dir / "index.json"
|
||||
if not index_path.exists():
|
||||
return []
|
||||
try:
|
||||
entries = json.loads(index_path.read_text())
|
||||
return [
|
||||
FrameRef(id=e["id"], path=Path(e["path"]), timestamp=e["timestamp"])
|
||||
for e in entries
|
||||
if Path(e["path"]).exists()
|
||||
]
|
||||
except Exception as e:
|
||||
log.warning("Could not load frames index: %s", e)
|
||||
return []
|
||||
|
||||
|
||||
class AgentRunner:
|
||||
"""Runs agent queries in a background thread, streams chunks to a callback."""
|
||||
|
||||
def __init__(self):
|
||||
self._provider: AgentProvider | None = None
|
||||
|
||||
def _get_provider(self) -> AgentProvider:
|
||||
if self._provider is None:
|
||||
self._provider = _resolve_provider()
|
||||
log.info("Agent provider: %s", self._provider.name)
|
||||
return self._provider
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
try:
|
||||
return self._get_provider().name
|
||||
except Exception:
|
||||
return "unknown"
|
||||
|
||||
def send(
|
||||
self,
|
||||
message: str,
|
||||
stream_mgr,
|
||||
tracker,
|
||||
on_chunk: Callable[[str], None],
|
||||
on_done: Callable[[str | None], None],
|
||||
):
|
||||
"""Dispatch message in a background thread.
|
||||
|
||||
on_chunk(text) — called for each streamed chunk
|
||||
on_done(error_or_None) — called when complete
|
||||
"""
|
||||
def _run():
|
||||
try:
|
||||
provider = self._get_provider()
|
||||
frames = _load_frames(stream_mgr.frames_dir)
|
||||
mentioned = _parse_mentions(message, frames)
|
||||
context = SessionContext(
|
||||
session_dir=stream_mgr.session_dir,
|
||||
frames=frames,
|
||||
duration=tracker.duration if tracker else 0.0,
|
||||
mentioned_frames=mentioned,
|
||||
)
|
||||
for chunk in provider.stream(message, context):
|
||||
on_chunk(chunk)
|
||||
on_done(None)
|
||||
except Exception as e:
|
||||
log.error("Agent error: %s", e)
|
||||
on_done(str(e))
|
||||
|
||||
Thread(target=_run, daemon=True, name="agent_runner").start()
|
||||
Reference in New Issue
Block a user