202 lines
6.8 KiB
Python
202 lines
6.8 KiB
Python
"""Built-in agent tools — ReadFrame, SearchTranscript, GetSessionInfo, CaptureFrame.
|
|
|
|
Also contains shared data-loading functions (moved from runner.py).
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
from pathlib import Path
|
|
|
|
from cht.agent.base import FrameRef, TranscriptRef, ToolContext, ToolResult
|
|
|
|
# Module-level logger; the index loaders below use it to warn about unreadable index files.
log = logging.getLogger(__name__)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Data loading (shared by tools and runner)
# ---------------------------------------------------------------------------
|
|
|
|
def _resolve_frame_path(frames_dir: Path, raw_path: str) -> Path | None:
|
|
p = Path(raw_path)
|
|
if p.exists():
|
|
return p
|
|
local = frames_dir / p.name
|
|
if local.exists():
|
|
return local
|
|
return None
|
|
|
|
|
|
def load_frames(frames_dir: Path) -> list[FrameRef]:
    """Build the list of frames described by ``frames_dir/index.json``.

    Each index entry is read for its ``"path"``, ``"id"`` and ``"timestamp"``
    keys. Entries whose image cannot be located on disk are left out.

    Args:
        frames_dir: Directory containing ``index.json`` and the frame images.

    Returns:
        One ``FrameRef`` per entry whose image was found. An empty list is
        returned when the index file is missing, or when reading or parsing
        it fails for any reason (a warning is logged in the failure case).
    """
    manifest = frames_dir / "index.json"
    if manifest.exists():
        try:
            return [
                FrameRef(id=entry["id"], path=found, timestamp=entry["timestamp"])
                for entry in json.loads(manifest.read_text())
                if (found := _resolve_frame_path(frames_dir, entry["path"]))
            ]
        except Exception as exc:
            # Best effort: a broken index must not take the caller down.
            log.warning("Could not load frames index: %s", exc)
    return []
|
|
|
|
|
|
def load_transcript(transcript_dir: Path) -> list[TranscriptRef]:
    """Build the list of transcript segments from ``transcript_dir/index.json``.

    Every entry of the index is passed to ``TranscriptRef`` as keyword
    arguments, so its keys must match that constructor.

    Args:
        transcript_dir: Directory containing the transcript ``index.json``.

    Returns:
        One ``TranscriptRef`` per index entry. An empty list is returned when
        the index file is missing, or when reading, parsing or constructing
        any segment fails (a warning is logged in the failure case).
    """
    manifest = transcript_dir / "index.json"
    if manifest.exists():
        try:
            raw_segments = json.loads(manifest.read_text())
            return [TranscriptRef(**fields) for fields in raw_segments]
        except Exception as exc:
            # Best effort: a broken index must not take the caller down.
            log.warning("Could not load transcript index: %s", exc)
    return []
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Tool implementations
# ---------------------------------------------------------------------------
|
|
|
|
class ReadFrameTool:
    """Tool that looks up captured frames by ID and reports where each image lives."""

    name = "read_frame"
    description = "Read frame screenshots by ID. Returns file paths for visual inspection."

    def input_schema(self) -> dict:
        """Return the JSON schema of the input: a required array of frame-ID strings."""
        frame_ids_spec = {
            "type": "array",
            "items": {"type": "string"},
            "description": "Frame IDs like ['F0001', 'F0003']",
        }
        return {
            "type": "object",
            "properties": {"frame_ids": frame_ids_spec},
            "required": ["frame_ids"],
        }

    def run(self, input: dict, context: ToolContext) -> ToolResult:
        """Describe each requested frame, one output line per ID.

        The frames index is reloaded from ``context.frames_dir`` on every call.

        Args:
            input: Tool input; ``"frame_ids"`` lists the IDs to look up
                (treated as empty when absent).
            context: Supplies ``frames_dir``.

        Returns:
            A ``ToolResult`` whose output has, in request order, either
            ``"<id> at MM:SS — <path>"`` for a known frame or
            ``"<id>: not found"``. The MM:SS part assumes the frame timestamp
            is expressed in seconds.
        """
        requested = input.get("frame_ids", [])
        known = {frame.id: frame for frame in load_frames(context.frames_dir)}

        def describe(frame_id):
            frame = known.get(frame_id)
            if not frame:
                return f"{frame_id}: not found"
            minutes, seconds = divmod(int(frame.timestamp), 60)
            return f"{frame.id} at {minutes:02d}:{seconds:02d} — {frame.path}"

        report = "\n".join(describe(frame_id) for frame_id in requested)
        return ToolResult(tool_use_id="", output=report)
|
|
|
|
|
|
class SearchTranscriptTool:
    """Tool that filters transcript segments by text and/or by time window."""

    name = "search_transcript"
    description = "Search transcript segments by text substring and/or time range."

    def input_schema(self) -> dict:
        """Return the JSON schema of the input; all three fields are optional."""
        return {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "Text substring to search for (optional)"},
                "start": {"type": "number", "description": "Start time in seconds (optional)"},
                "end": {"type": "number", "description": "End time in seconds (optional)"},
            },
        }

    @staticmethod
    def _clock(seconds: float) -> str:
        """Format a number of seconds as zero-padded ``MM:SS``."""
        minutes, secs = divmod(int(seconds), 60)
        return f"{minutes:02d}:{secs:02d}"

    def run(self, input: dict, context: ToolContext) -> ToolResult:
        """Return the transcript segments that satisfy every supplied filter.

        The transcript index is reloaded from ``context.transcript_dir`` on
        every call.

        Args:
            input: Tool input with optional ``"query"`` (case-insensitive
                substring), ``"start"`` and ``"end"`` (seconds). A missing or
                null value disables the corresponding filter.
            context: Supplies ``transcript_dir``.

        Returns:
            A ``ToolResult`` whose output has one line per matching segment,
            ``"<id> [MM:SS-MM:SS] <text>"``, or a fixed "no match" message
            when nothing qualifies.
        """
        segments = load_transcript(context.transcript_dir)
        # An optional field may arrive as an explicit null; ``.get`` only
        # applies its default for an absent key, so normalise None here
        # instead of crashing on ``None.lower()``.
        query = (input.get("query") or "").lower()
        start = input.get("start")
        end = input.get("end")

        # A segment is kept when it overlaps the [start, end] window (either
        # bound may be open) and contains the query text, if one was given.
        matches = [
            seg
            for seg in segments
            if (start is None or seg.end >= start)
            and (end is None or seg.start <= end)
            and (not query or query in seg.text.lower())
        ]

        if not matches:
            return ToolResult(tool_use_id="", output="No matching transcript segments found.")

        lines = [
            f"{seg.id} [{self._clock(seg.start)}-{self._clock(seg.end)}] {seg.text}"
            for seg in matches
        ]
        return ToolResult(tool_use_id="", output="\n".join(lines))
|
|
|
|
|
|
class GetSessionInfoTool:
    """Tool that summarises the current recording session."""

    name = "get_session_info"
    description = "Get recording session information: duration, frame count, segment list."

    def input_schema(self) -> dict:
        """Return the JSON schema of the input; the tool takes no arguments."""
        return {"type": "object", "properties": {}}

    def run(self, input: dict, context: ToolContext) -> ToolResult:
        """Report duration, frame count, transcript size and recording files.

        Args:
            input: Unused.
            context: Supplies ``frames_dir``, ``transcript_dir``, ``tracker``
                and ``session_dir``.

        Returns:
            A ``ToolResult`` whose output lists the duration as MM:SS, the
            number of loadable frames and transcript segments and, when
            ``session_dir/stream`` exists, the sorted ``recording_*.mp4``
            file names found in it.
        """
        frame_count = len(load_frames(context.frames_dir))
        segment_count = len(load_transcript(context.transcript_dir))

        # Without a tracker (or without a ``duration`` on it) report zero.
        tracker = context.tracker
        duration = getattr(tracker, "duration", 0.0) if tracker else 0.0
        minutes, seconds = divmod(int(duration), 60)

        report = [
            f"Recording duration: {minutes:02d}:{seconds:02d}",
            f"Frames captured: {frame_count}",
            f"Transcript segments: {segment_count}",
        ]

        # Recording files are listed only when the stream directory is present.
        stream_dir = context.session_dir / "stream"
        if stream_dir.exists():
            recordings = sorted(stream_dir.glob("recording_*.mp4"))
            report.append(f"Recording files: {len(recordings)}")
            report.extend(f" {rec.name}" for rec in recordings)

        return ToolResult(tool_use_id="", output="\n".join(report))
|
|
|
|
|
|
class CaptureFrameTool:
    """Tool that asks the active stream manager to grab a frame right now."""

    name = "capture_frame"
    description = "Capture a frame at the current recording position."

    # Seconds to wait for the stream manager to report the new frame.
    capture_timeout = 10

    def input_schema(self) -> dict:
        """Return the JSON schema of the input; the tool takes no arguments."""
        return {"type": "object", "properties": {}}

    def run(self, input: dict, context: ToolContext) -> ToolResult:
        """Trigger a capture and wait until the stream manager reports it.

        Args:
            input: Unused.
            context: Supplies ``stream_mgr``; it must expose
                ``capture_now(on_new_frames=...)`` and may expose a
                ``readonly`` flag.

        Returns:
            A ``ToolResult`` with a success message in ``output``, or with
            ``error`` set when there is no stream manager, the manager is
            read-only, ``capture_now`` raises, or the callback has not fired
            within ``capture_timeout`` seconds.
        """
        mgr = context.stream_mgr
        if mgr is None:
            return ToolResult(tool_use_id="", error="No active stream manager")
        if getattr(mgr, "readonly", False):
            return ToolResult(tool_use_id="", error="Session is read-only, cannot capture")

        import threading

        # The event alone records completion; no separate flag is needed.
        captured = threading.Event()

        def _on_frames(frames):
            captured.set()

        try:
            mgr.capture_now(on_new_frames=_on_frames)
        except Exception as e:
            return ToolResult(tool_use_id="", error=str(e))

        # ``Event.wait`` returns False only when the timeout elapsed first.
        if not captured.wait(timeout=self.capture_timeout):
            return ToolResult(tool_use_id="", error="Capture timed out")
        return ToolResult(tool_use_id="", output="Frame captured successfully.")
|
|
|
|
|
|
# Every built-in tool, instantiated once at import time.
BUILTIN_TOOLS = [
    ReadFrameTool(),
    SearchTranscriptTool(),
    GetSessionInfoTool(),
    CaptureFrameTool(),
]
|