better agent
This commit is contained in:
201
cht/agent/tools.py
Normal file
201
cht/agent/tools.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""Built-in agent tools — ReadFrame, SearchTranscript, GetSessionInfo, CaptureFrame.
|
||||
|
||||
Also contains shared data-loading functions (moved from runner.py).
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from cht.agent.base import FrameRef, TranscriptRef, ToolContext, ToolResult
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data loading (shared by tools and runner)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _resolve_frame_path(frames_dir: Path, raw_path: str) -> Path | None:
|
||||
p = Path(raw_path)
|
||||
if p.exists():
|
||||
return p
|
||||
local = frames_dir / p.name
|
||||
if local.exists():
|
||||
return local
|
||||
return None
|
||||
|
||||
|
||||
def load_frames(frames_dir: Path) -> list[FrameRef]:
|
||||
index_path = frames_dir / "index.json"
|
||||
if not index_path.exists():
|
||||
return []
|
||||
try:
|
||||
entries = json.loads(index_path.read_text())
|
||||
frames = []
|
||||
for e in entries:
|
||||
resolved = _resolve_frame_path(frames_dir, e["path"])
|
||||
if resolved:
|
||||
frames.append(FrameRef(id=e["id"], path=resolved, timestamp=e["timestamp"]))
|
||||
return frames
|
||||
except Exception as e:
|
||||
log.warning("Could not load frames index: %s", e)
|
||||
return []
|
||||
|
||||
|
||||
def load_transcript(transcript_dir: Path) -> list[TranscriptRef]:
|
||||
index_path = transcript_dir / "index.json"
|
||||
if not index_path.exists():
|
||||
return []
|
||||
try:
|
||||
entries = json.loads(index_path.read_text())
|
||||
return [TranscriptRef(**e) for e in entries]
|
||||
except Exception as e:
|
||||
log.warning("Could not load transcript index: %s", e)
|
||||
return []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool implementations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ReadFrameTool:
|
||||
name = "read_frame"
|
||||
description = "Read frame screenshots by ID. Returns file paths for visual inspection."
|
||||
|
||||
def input_schema(self) -> dict:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"frame_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Frame IDs like ['F0001', 'F0003']",
|
||||
}
|
||||
},
|
||||
"required": ["frame_ids"],
|
||||
}
|
||||
|
||||
def run(self, input: dict, context: ToolContext) -> ToolResult:
|
||||
frame_ids = input.get("frame_ids", [])
|
||||
frames = load_frames(context.frames_dir)
|
||||
by_id = {f.id: f for f in frames}
|
||||
lines = []
|
||||
for fid in frame_ids:
|
||||
f = by_id.get(fid)
|
||||
if f:
|
||||
m, s = divmod(int(f.timestamp), 60)
|
||||
lines.append(f"{f.id} at {m:02d}:{s:02d} — {f.path}")
|
||||
else:
|
||||
lines.append(f"{fid}: not found")
|
||||
return ToolResult(tool_use_id="", output="\n".join(lines))
|
||||
|
||||
|
||||
class SearchTranscriptTool:
|
||||
name = "search_transcript"
|
||||
description = "Search transcript segments by text substring and/or time range."
|
||||
|
||||
def input_schema(self) -> dict:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Text substring to search for (optional)"},
|
||||
"start": {"type": "number", "description": "Start time in seconds (optional)"},
|
||||
"end": {"type": "number", "description": "End time in seconds (optional)"},
|
||||
},
|
||||
}
|
||||
|
||||
def run(self, input: dict, context: ToolContext) -> ToolResult:
|
||||
segments = load_transcript(context.transcript_dir)
|
||||
query = input.get("query", "").lower()
|
||||
start = input.get("start")
|
||||
end = input.get("end")
|
||||
|
||||
matches = []
|
||||
for seg in segments:
|
||||
if start is not None and seg.end < start:
|
||||
continue
|
||||
if end is not None and seg.start > end:
|
||||
continue
|
||||
if query and query not in seg.text.lower():
|
||||
continue
|
||||
matches.append(seg)
|
||||
|
||||
if not matches:
|
||||
return ToolResult(tool_use_id="", output="No matching transcript segments found.")
|
||||
|
||||
lines = []
|
||||
for seg in matches:
|
||||
m1, s1 = divmod(int(seg.start), 60)
|
||||
m2, s2 = divmod(int(seg.end), 60)
|
||||
lines.append(f"{seg.id} [{m1:02d}:{s1:02d}-{m2:02d}:{s2:02d}] {seg.text}")
|
||||
return ToolResult(tool_use_id="", output="\n".join(lines))
|
||||
|
||||
|
||||
class GetSessionInfoTool:
|
||||
name = "get_session_info"
|
||||
description = "Get recording session information: duration, frame count, segment list."
|
||||
|
||||
def input_schema(self) -> dict:
|
||||
return {"type": "object", "properties": {}}
|
||||
|
||||
def run(self, input: dict, context: ToolContext) -> ToolResult:
|
||||
frames = load_frames(context.frames_dir)
|
||||
segments = load_transcript(context.transcript_dir)
|
||||
|
||||
duration = 0.0
|
||||
if context.tracker:
|
||||
duration = getattr(context.tracker, "duration", 0.0)
|
||||
|
||||
m, s = divmod(int(duration), 60)
|
||||
lines = [
|
||||
f"Recording duration: {m:02d}:{s:02d}",
|
||||
f"Frames captured: {len(frames)}",
|
||||
f"Transcript segments: {len(segments)}",
|
||||
]
|
||||
|
||||
# List recording segments from session dir
|
||||
stream_dir = context.session_dir / "stream"
|
||||
if stream_dir.exists():
|
||||
recordings = sorted(stream_dir.glob("recording_*.mp4"))
|
||||
lines.append(f"Recording files: {len(recordings)}")
|
||||
for rec in recordings:
|
||||
lines.append(f" {rec.name}")
|
||||
|
||||
return ToolResult(tool_use_id="", output="\n".join(lines))
|
||||
|
||||
|
||||
class CaptureFrameTool:
|
||||
name = "capture_frame"
|
||||
description = "Capture a frame at the current recording position."
|
||||
|
||||
def input_schema(self) -> dict:
|
||||
return {"type": "object", "properties": {}}
|
||||
|
||||
def run(self, input: dict, context: ToolContext) -> ToolResult:
|
||||
mgr = context.stream_mgr
|
||||
if mgr is None:
|
||||
return ToolResult(tool_use_id="", error="No active stream manager")
|
||||
if getattr(mgr, "readonly", False):
|
||||
return ToolResult(tool_use_id="", error="Session is read-only, cannot capture")
|
||||
|
||||
import threading
|
||||
result = {"done": False, "error": None}
|
||||
event = threading.Event()
|
||||
|
||||
def _on_frames(frames):
|
||||
result["done"] = True
|
||||
event.set()
|
||||
|
||||
try:
|
||||
mgr.capture_now(on_new_frames=_on_frames)
|
||||
event.wait(timeout=10)
|
||||
if not result["done"]:
|
||||
return ToolResult(tool_use_id="", error="Capture timed out")
|
||||
return ToolResult(tool_use_id="", output="Frame captured successfully.")
|
||||
except Exception as e:
|
||||
return ToolResult(tool_use_id="", error=str(e))
|
||||
|
||||
|
||||
# All built-in tools
|
||||
BUILTIN_TOOLS = [ReadFrameTool(), SearchTranscriptTool(), GetSessionInfoTool(), CaptureFrameTool()]
|
||||
Reference in New Issue
Block a user