# Files: mitus/cht/agent/tools.py
# 2026-04-09 14:46:29 -03:00
# 202 lines, 6.8 KiB, Python

"""Built-in agent tools — ReadFrame, SearchTranscript, GetSessionInfo, CaptureFrame.
Also contains shared data-loading functions (moved from runner.py).
"""
import json
import logging
from pathlib import Path
from cht.agent.base import FrameRef, TranscriptRef, ToolContext, ToolResult
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Data loading (shared by tools and runner)
# ---------------------------------------------------------------------------
def _resolve_frame_path(frames_dir: Path, raw_path: str) -> Path | None:
p = Path(raw_path)
if p.exists():
return p
local = frames_dir / p.name
if local.exists():
return local
return None
def load_frames(frames_dir: Path) -> list[FrameRef]:
    """Read ``index.json`` in ``frames_dir`` and build the list of frames.

    Index entries whose image cannot be located by ``_resolve_frame_path``
    are left out of the result.

    Args:
        frames_dir: Directory holding ``index.json``.

    Returns:
        One ``FrameRef`` per index entry whose image was found. An empty list
        is returned when the index file is absent, and also when anything
        fails while reading or interpreting it (a warning is logged then).
    """
    manifest = frames_dir / "index.json"
    if manifest.exists():
        try:
            return [
                FrameRef(id=entry["id"], path=found, timestamp=entry["timestamp"])
                for entry in json.loads(manifest.read_text())
                if (found := _resolve_frame_path(frames_dir, entry["path"]))
            ]
        except Exception as exc:
            # Best-effort load: any failure discards the whole index.
            log.warning("Could not load frames index: %s", exc)
    return []
def load_transcript(transcript_dir: Path) -> list[TranscriptRef]:
    """Read ``index.json`` in ``transcript_dir`` and build transcript segments.

    Each index entry is expanded as keyword arguments to ``TranscriptRef``.

    Args:
        transcript_dir: Directory holding ``index.json``.

    Returns:
        One ``TranscriptRef`` per index entry. An empty list is returned when
        the index file is absent, and also when anything fails while reading
        or interpreting it (a warning is logged then).
    """
    manifest = transcript_dir / "index.json"
    if manifest.exists():
        try:
            segments = []
            for fields in json.loads(manifest.read_text()):
                segments.append(TranscriptRef(**fields))
            return segments
        except Exception as exc:
            # Best-effort load: any failure discards the whole index.
            log.warning("Could not load transcript index: %s", exc)
    return []
# ---------------------------------------------------------------------------
# Tool implementations
# ---------------------------------------------------------------------------
class ReadFrameTool:
    """Agent tool that looks up captured frames by ID and reports their paths."""

    name = "read_frame"
    description = "Read frame screenshots by ID. Returns file paths for visual inspection."

    def input_schema(self) -> dict:
        """Return the JSON schema for the tool input (a required ``frame_ids`` array)."""
        return {
            "type": "object",
            "properties": {
                "frame_ids": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Frame IDs like ['F0001', 'F0003']",
                }
            },
            "required": ["frame_ids"],
        }

    @staticmethod
    def _describe(frame) -> str:
        """Render one frame as ``<id> at MM:SS: <path>``."""
        minutes, seconds = divmod(int(frame.timestamp), 60)
        # The ": " separator keeps the timestamp and the path from running
        # together, and mirrors the "<id>: not found" line format.
        return f"{frame.id} at {minutes:02d}:{seconds:02d}: {frame.path}"

    def run(self, input: dict, context: ToolContext) -> ToolResult:
        """Report the timestamp and path of every requested frame.

        Args:
            input: Tool input; ``frame_ids`` is the list of IDs to look up.
                A missing or null value is treated as an empty list, and a
                bare string is treated as a single ID.
            context: Supplies ``frames_dir``, from which the frames are loaded.

        Returns:
            A ``ToolResult`` whose output has one line per requested ID, in
            request order: either the frame description or ``<id>: not found``.
        """
        requested = input.get("frame_ids") or []
        if isinstance(requested, str):
            # Iterating a bare string would look up each character separately.
            requested = [requested]
        by_id = {frame.id: frame for frame in load_frames(context.frames_dir)}
        lines = [
            self._describe(by_id[fid]) if fid in by_id else f"{fid}: not found"
            for fid in requested
        ]
        return ToolResult(tool_use_id="", output="\n".join(lines))
class SearchTranscriptTool:
    """Agent tool that filters transcript segments by text and/or time range."""

    name = "search_transcript"
    description = "Search transcript segments by text substring and/or time range."

    def input_schema(self) -> dict:
        """Return the JSON schema for the tool input (all fields optional)."""
        return {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "Text substring to search for (optional)"},
                "start": {"type": "number", "description": "Start time in seconds (optional)"},
                "end": {"type": "number", "description": "End time in seconds (optional)"},
            },
        }

    @staticmethod
    def _select(segments, query, start, end) -> list:
        """Return the segments that pass the text and time-range filters.

        A segment is kept when it does not end before ``start``, does not begin
        after ``end``, and contains ``query`` case-insensitively. Any filter
        given as ``None`` (or an empty query) is not applied.
        """
        # An explicit null query must behave like an absent one; calling
        # .lower() on None would raise.
        needle = "" if query is None else str(query).lower()
        selected = []
        for seg in segments:
            if start is not None and seg.end < start:
                continue
            if end is not None and seg.start > end:
                continue
            if needle and needle not in seg.text.lower():
                continue
            selected.append(seg)
        return selected

    def run(self, input: dict, context: ToolContext) -> ToolResult:
        """Search the transcript and list the matching segments.

        Args:
            input: Tool input with optional ``query``, ``start`` and ``end``.
            context: Supplies ``transcript_dir``, from which segments are loaded.

        Returns:
            A ``ToolResult`` whose output has one line per matching segment
            (``<id> [MM:SS-MM:SS] <text>``), or a fixed message when nothing
            matches.
        """
        segments = load_transcript(context.transcript_dir)
        matches = self._select(
            segments, input.get("query"), input.get("start"), input.get("end")
        )
        if not matches:
            return ToolResult(tool_use_id="", output="No matching transcript segments found.")
        lines = []
        for seg in matches:
            m1, s1 = divmod(int(seg.start), 60)
            m2, s2 = divmod(int(seg.end), 60)
            lines.append(f"{seg.id} [{m1:02d}:{s1:02d}-{m2:02d}:{s2:02d}] {seg.text}")
        return ToolResult(tool_use_id="", output="\n".join(lines))
class GetSessionInfoTool:
    """Agent tool that summarises the session: duration, counts, recording files."""

    name = "get_session_info"
    description = "Get recording session information: duration, frame count, segment list."

    def input_schema(self) -> dict:
        """Return the JSON schema for the tool input (an object with no properties)."""
        return {"type": "object", "properties": {}}

    def run(self, input: dict, context: ToolContext) -> ToolResult:
        """Build a plain-text summary of the session.

        Args:
            input: Tool input; not read.
            context: Supplies ``frames_dir``, ``transcript_dir``, ``tracker``
                and ``session_dir``.

        Returns:
            A ``ToolResult`` whose output lists the duration (MM:SS), the
            number of frames and transcript segments and, when the session
            has a ``stream`` directory, its ``recording_*.mp4`` file names.
        """
        frame_count = len(load_frames(context.frames_dir))
        segment_count = len(load_transcript(context.transcript_dir))

        # Without a tracker (or a ``duration`` attribute on it) report zero.
        total = getattr(context.tracker, "duration", 0.0) if context.tracker else 0.0
        minutes, seconds = divmod(int(total), 60)

        summary = [
            f"Recording duration: {minutes:02d}:{seconds:02d}",
            f"Frames captured: {frame_count}",
            f"Transcript segments: {segment_count}",
        ]

        # Recording files live under <session_dir>/stream, listed in sorted order.
        stream_root = context.session_dir / "stream"
        if stream_root.exists():
            clips = sorted(stream_root.glob("recording_*.mp4"))
            summary.append(f"Recording files: {len(clips)}")
            summary.extend(f" {clip.name}" for clip in clips)

        return ToolResult(tool_use_id="", output="\n".join(summary))
class CaptureFrameTool:
    """Agent tool that asks the stream manager to capture a frame now."""

    name = "capture_frame"
    description = "Capture a frame at the current recording position."

    def input_schema(self) -> dict:
        """Return the JSON schema for the tool input (an object with no properties)."""
        return {"type": "object", "properties": {}}

    def run(self, input: dict, context: ToolContext) -> ToolResult:
        """Trigger a capture and wait for the stream manager to report it.

        Args:
            input: Tool input; not read.
            context: Supplies ``stream_mgr``, the object asked to capture.

        Returns:
            A ``ToolResult`` with a success message once the callback passed
            to ``capture_now`` has fired, or with ``error`` set when there is
            no stream manager, the manager is flagged ``readonly``, the
            callback does not fire within 10 seconds, or an exception is raised.
        """
        stream = context.stream_mgr
        if stream is None:
            return ToolResult(tool_use_id="", error="No active stream manager")
        if getattr(stream, "readonly", False):
            return ToolResult(tool_use_id="", error="Session is read-only, cannot capture")

        import threading

        captured = False
        finished = threading.Event()

        def _on_frames(frames):
            # Only the fact that the callback fired matters; frames are unused.
            nonlocal captured
            captured = True
            finished.set()

        try:
            stream.capture_now(on_new_frames=_on_frames)
            finished.wait(timeout=10)
            if captured:
                return ToolResult(tool_use_id="", output="Frame captured successfully.")
            return ToolResult(tool_use_id="", error="Capture timed out")
        except Exception as exc:
            return ToolResult(tool_use_id="", error=str(exc))
# One ready-to-use instance of each built-in tool class in this module.
BUILTIN_TOOLS = [
    ReadFrameTool(),
    SearchTranscriptTool(),
    GetSessionInfoTool(),
    CaptureFrameTool(),
]