""" Agent data model — structured messages, connections, tools, threads. Replaces the old flat AgentProvider/SessionContext with: - Typed content blocks and messages (conversation model) - StreamEvent types (connection output) - AgentConnection protocol (replaces AgentProvider) - Tool protocol - Thread (session state, serializable) """ from __future__ import annotations import json import uuid from dataclasses import dataclass, field from datetime import datetime, timezone from pathlib import Path from typing import Iterator, Literal, Protocol, runtime_checkable # --------------------------------------------------------------------------- # Shared refs (used by runner mention-parsing and tools) # --------------------------------------------------------------------------- @dataclass class FrameRef: id: str # "F0001" path: Path # absolute path to JPEG timestamp: float # seconds into recording @dataclass class TranscriptRef: id: str # "T0001" start: float # seconds into recording end: float # seconds into recording text: str # --------------------------------------------------------------------------- # Content blocks # --------------------------------------------------------------------------- @dataclass class TextBlock: text: str @dataclass class ImageBlock: frame_id: str # "F0042" path: Path timestamp: float @dataclass class TranscriptBlock: transcript_id: str # "T0012" start: float end: float text: str ContentBlock = TextBlock | ImageBlock | TranscriptBlock # --------------------------------------------------------------------------- # Messages # --------------------------------------------------------------------------- @dataclass class TokenUsage: input_tokens: int = 0 output_tokens: int = 0 def _msg_id() -> str: return uuid.uuid4().hex[:12] @dataclass class UserMessage: content: list[ContentBlock] id: str = field(default_factory=_msg_id) timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) @dataclass class AssistantMessage: content: 
list[ContentBlock] model: str = "" id: str = field(default_factory=_msg_id) timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) token_usage: TokenUsage | None = None @dataclass class ToolUse: tool_name: str input: dict id: str = field(default_factory=_msg_id) status: Literal["pending", "running", "done", "error"] = "pending" @dataclass class ToolResult: tool_use_id: str output: str | None = None error: str | None = None Message = UserMessage | AssistantMessage | ToolUse | ToolResult # --------------------------------------------------------------------------- # Thread (session state) # --------------------------------------------------------------------------- @dataclass class Thread: messages: list[Message] = field(default_factory=list) id: str = field(default_factory=lambda: uuid.uuid4().hex[:12]) created: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) updated: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) token_usage: TokenUsage | None = None # --------------------------------------------------------------------------- # Stream events (yielded by AgentConnection) # --------------------------------------------------------------------------- @dataclass class TextDelta: text: str @dataclass class ToolCallStart: id: str name: str input: dict @dataclass class ToolCallEnd: id: str @dataclass class Done: stop_reason: str # "end_turn", "tool_use", "max_tokens" @dataclass class Error: message: str StreamEvent = TextDelta | ToolCallStart | ToolCallEnd | Done | Error # --------------------------------------------------------------------------- # Tool protocol # --------------------------------------------------------------------------- @runtime_checkable class Tool(Protocol): name: str description: str def input_schema(self) -> dict: ... def run(self, input: dict, context: ToolContext) -> ToolResult: ... 
@dataclass
class ToolContext:
    """Runtime context passed to tools — references to live app objects."""

    session_dir: Path
    frames_dir: Path
    transcript_dir: Path
    stream_mgr: object | None = None  # StreamManager, optional
    tracker: object | None = None  # RecordingTracker, optional


# ---------------------------------------------------------------------------
# AgentConnection protocol (replaces AgentProvider)
# ---------------------------------------------------------------------------


@runtime_checkable
class AgentConnection(Protocol):
    """Structural interface for a model backend (replaces AgentProvider)."""

    name: str

    def available_models(self) -> list[str]:
        """Return the model identifiers this connection offers."""
        ...

    def get_model(self) -> str:
        """Return the currently selected model identifier."""
        ...

    def set_model(self, model: str) -> None:
        """Select the model to use for prompts."""
        ...

    def prompt(
        self,
        messages: list[Message],
        tools: list[Tool],
    ) -> Iterator[StreamEvent]:
        """Send messages to the model, yield stream events."""
        ...

    def cancel(self) -> None:
        """Request cancellation; exact semantics are implementation-defined."""
        ...


# ---------------------------------------------------------------------------
# Thread serialization
# ---------------------------------------------------------------------------


class _Encoder(json.JSONEncoder):
    """JSON encoder that also accepts Path (as str) and datetime (as ISO 8601)."""

    def default(self, o):
        if isinstance(o, Path):
            return str(o)
        if isinstance(o, datetime):
            return o.isoformat()
        return super().default(o)


def _usage_to_dict(u: TokenUsage) -> dict:
    """Serialize a TokenUsage to a plain dict."""
    return {"input_tokens": u.input_tokens, "output_tokens": u.output_tokens}


def _usage_from_dict(d: dict | None) -> TokenUsage | None:
    """Inverse of _usage_to_dict; a missing or empty value yields None."""
    return TokenUsage(**d) if d else None


def _block_to_dict(b: ContentBlock) -> dict:
    """Serialize a content block to a dict tagged with ``type``.

    Raises:
        TypeError: if *b* is not a known content block.
    """
    if isinstance(b, TextBlock):
        return {"type": "text", "text": b.text}
    if isinstance(b, ImageBlock):
        return {
            "type": "image",
            "frame_id": b.frame_id,
            "path": str(b.path),
            "timestamp": b.timestamp,
        }
    if isinstance(b, TranscriptBlock):
        return {
            "type": "transcript",
            "transcript_id": b.transcript_id,
            "start": b.start,
            "end": b.end,
            "text": b.text,
        }
    raise TypeError(f"Unknown block type: {type(b)}")


def _block_from_dict(d: dict) -> ContentBlock:
    """Rebuild a content block from its ``_block_to_dict`` form.

    Raises:
        ValueError: if ``d["type"]`` is not a known block type.
        KeyError: if a required key is missing.
    """
    t = d["type"]
    if t == "text":
        return TextBlock(text=d["text"])
    if t == "image":
        return ImageBlock(
            frame_id=d["frame_id"],
            path=Path(d["path"]),
            timestamp=d["timestamp"],
        )
    if t == "transcript":
        return TranscriptBlock(
            transcript_id=d["transcript_id"],
            start=d["start"],
            end=d["end"],
            text=d["text"],
        )
    raise ValueError(f"Unknown block type: {t}")


def _msg_to_dict(m: Message) -> dict:
    """Serialize a message to a dict tagged with ``type``.

    Raises:
        TypeError: if *m* is not a known message type.
    """
    if isinstance(m, UserMessage):
        return {
            "type": "user",
            "id": m.id,
            "content": [_block_to_dict(b) for b in m.content],
            "timestamp": m.timestamp.isoformat(),
        }
    if isinstance(m, AssistantMessage):
        d = {
            "type": "assistant",
            "id": m.id,
            "content": [_block_to_dict(b) for b in m.content],
            "timestamp": m.timestamp.isoformat(),
            "model": m.model,
        }
        if m.token_usage is not None:
            d["token_usage"] = _usage_to_dict(m.token_usage)
        return d
    if isinstance(m, ToolUse):
        return {
            "type": "tool_use",
            "id": m.id,
            "tool_name": m.tool_name,
            "input": m.input,
            "status": m.status,
        }
    if isinstance(m, ToolResult):
        return {
            "type": "tool_result",
            "tool_use_id": m.tool_use_id,
            "output": m.output,
            "error": m.error,
        }
    raise TypeError(f"Unknown message type: {type(m)}")


def _msg_from_dict(d: dict) -> Message:
    """Rebuild a message from its ``_msg_to_dict`` form.

    Raises:
        ValueError: for an unknown ``type`` or a malformed timestamp.
        KeyError: if a required key is missing.
    """
    t = d["type"]
    if t == "user":
        return UserMessage(
            id=d["id"],
            content=[_block_from_dict(b) for b in d["content"]],
            timestamp=datetime.fromisoformat(d["timestamp"]),
        )
    if t == "assistant":
        return AssistantMessage(
            id=d["id"],
            content=[_block_from_dict(b) for b in d["content"]],
            timestamp=datetime.fromisoformat(d["timestamp"]),
            model=d.get("model", ""),
            token_usage=_usage_from_dict(d.get("token_usage")),
        )
    if t == "tool_use":
        # A missing status loads as "done" (not the dataclass default "pending").
        return ToolUse(
            id=d["id"],
            tool_name=d["tool_name"],
            input=d["input"],
            status=d.get("status", "done"),
        )
    if t == "tool_result":
        return ToolResult(
            tool_use_id=d["tool_use_id"],
            output=d.get("output"),
            error=d.get("error"),
        )
    raise ValueError(f"Unknown message type: {t}")


def thread_to_dict(thread: Thread) -> dict:
    """Serialize *thread* to a JSON-ready dict (inverse of ``thread_from_dict``)."""
    d = {
        "id": thread.id,
        "messages": [_msg_to_dict(m) for m in thread.messages],
        "created": thread.created.isoformat(),
        "updated": thread.updated.isoformat(),
    }
    if thread.token_usage is not None:
        d["token_usage"] = _usage_to_dict(thread.token_usage)
    return d


def thread_from_dict(d: dict) -> Thread:
    """Rebuild a Thread from the dict produced by ``thread_to_dict``."""
    return Thread(
        id=d["id"],
        messages=[_msg_from_dict(m) for m in d["messages"]],
        created=datetime.fromisoformat(d["created"]),
        updated=datetime.fromisoformat(d["updated"]),
        token_usage=_usage_from_dict(d.get("token_usage")),
    )


def save_thread(thread: Thread, session_dir: Path) -> None:
    """Persist *thread* to ``<session_dir>/agent/thread.json``.

    Sets ``thread.updated`` to the current UTC time before writing.

    The JSON is written to a sibling temp file first and then renamed over the
    target. A crash or full disk mid-write therefore leaves the previous
    thread.json intact. Otherwise it would leave a truncated file, which
    ``load_thread`` would discard.

    Raises:
        OSError: if the directory cannot be created or the file cannot be written.
    """
    agent_dir = session_dir / "agent"
    agent_dir.mkdir(parents=True, exist_ok=True)
    path = agent_dir / "thread.json"
    thread.updated = datetime.now(timezone.utc)
    # Serialize before touching the disk so encoding errors cannot leave partial files.
    payload = json.dumps(thread_to_dict(thread), indent=2, cls=_Encoder)
    tmp = path.with_name(path.name + ".tmp")
    try:
        tmp.write_text(payload, encoding="utf-8")
        tmp.replace(path)  # os.replace semantics: atomic rename on POSIX
    except OSError:
        tmp.unlink(missing_ok=True)
        raise


def load_thread(session_dir: Path) -> Thread | None:
    """Load the thread saved by ``save_thread``.

    Returns:
        The Thread, or None if ``<session_dir>/agent/thread.json`` does not
        exist or cannot be read or parsed. Unreadable or corrupt files are
        logged as warnings rather than silently ignored.
    """
    import logging

    path = session_dir / "agent" / "thread.json"
    try:
        return thread_from_dict(json.loads(path.read_text(encoding="utf-8")))
    except FileNotFoundError:
        return None
    except (OSError, ValueError, KeyError, TypeError, AttributeError) as exc:
        # Best-effort: a bad file yields None, but leave a trace so data loss is visible.
        logging.getLogger(__name__).warning(
            "Could not load agent thread from %s: %r", path, exc
        )
        return None