336 lines
9.6 KiB
Python
336 lines
9.6 KiB
Python
"""
|
|
Agent data model — structured messages, connections, tools, threads.
|
|
|
|
Replaces the old flat AgentProvider/SessionContext with:
|
|
- Typed content blocks and messages (conversation model)
|
|
- StreamEvent types (connection output)
|
|
- AgentConnection protocol (replaces AgentProvider)
|
|
- Tool protocol
|
|
- Thread (session state, serializable)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Iterator, Literal, Protocol, runtime_checkable
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Shared refs (used by runner mention-parsing and tools)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@dataclass
class FrameRef:
    """Reference to one captured frame, identified by id, file and time."""

    # Frame identifier, e.g. "F0001".
    id: str
    # Absolute path to the frame's JPEG file.
    path: Path
    # Offset of the frame in seconds into the recording.
    timestamp: float
|
|
|
|
|
|
@dataclass
class TranscriptRef:
    """Reference to one transcript segment and the time span it covers."""

    # Segment identifier, e.g. "T0001".
    id: str
    # Start of the segment in seconds into the recording.
    start: float
    # End of the segment in seconds into the recording.
    end: float
    # Transcribed text of the segment.
    text: str
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Content blocks
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@dataclass
class TextBlock:
    """Content block that carries plain text."""

    # The text payload of this block.
    text: str
|
|
|
|
|
|
@dataclass
class ImageBlock:
    """Content block that points at a frame image."""

    # Identifier of the referenced frame, e.g. "F0042".
    frame_id: str
    # Location of the image file.
    path: Path
    # Time associated with the frame (presumably seconds into the
    # recording, as on FrameRef — confirm against callers).
    timestamp: float
|
|
|
|
|
|
@dataclass
class TranscriptBlock:
    """Content block that quotes a transcript segment."""

    # Identifier of the referenced segment, e.g. "T0012".
    transcript_id: str
    # Start of the quoted span.
    start: float
    # End of the quoted span.
    end: float
    # Text of the quoted span.
    text: str
|
|
|
|
|
|
# Union of every block type that may appear in a message's ``content`` list.
ContentBlock = TextBlock | ImageBlock | TranscriptBlock
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Messages
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@dataclass
class TokenUsage:
    """Input/output token counters; both start at zero."""

    # Number of input tokens counted.
    input_tokens: int = 0
    # Number of output tokens counted.
    output_tokens: int = 0
|
|
|
|
|
|
def _msg_id() -> str:
    """Return a new identifier: the first 12 hex characters of a UUID4."""
    full_hex = uuid.uuid4().hex
    return full_hex[:12]
|
|
|
|
|
|
@dataclass
class UserMessage:
    """A user-authored message made of content blocks."""

    # Ordered content blocks that make up the message.
    content: list[ContentBlock]
    # Message identifier; generated by ``_msg_id`` when not supplied.
    id: str = field(default_factory=_msg_id)
    # Creation time; defaults to the current time in UTC.
    timestamp: datetime = field(
        default_factory=lambda: datetime.now(timezone.utc)
    )
|
|
|
|
|
|
@dataclass
class AssistantMessage:
    """An assistant-authored message made of content blocks."""

    # Ordered content blocks that make up the message.
    content: list[ContentBlock]
    # Name of the model recorded for this message; empty when not set.
    model: str = ""
    # Message identifier; generated by ``_msg_id`` when not supplied.
    id: str = field(default_factory=_msg_id)
    # Creation time; defaults to the current time in UTC.
    timestamp: datetime = field(
        default_factory=lambda: datetime.now(timezone.utc)
    )
    # Token counters for this message, when available.
    token_usage: TokenUsage | None = None
|
|
|
|
|
|
@dataclass
class ToolUse:
    """Record of a tool invocation: which tool, its input and its status."""

    # Name of the tool being invoked.
    tool_name: str
    # Input payload for the tool.
    input: dict
    # Invocation identifier; generated by ``_msg_id`` when not supplied.
    id: str = field(default_factory=_msg_id)
    # Lifecycle state of the invocation; a new ToolUse starts as "pending".
    status: Literal["pending", "running", "done", "error"] = "pending"
|
|
|
|
|
|
@dataclass
|
|
class ToolResult:
|
|
tool_use_id: str
|
|
output: str | None = None
|
|
error: str | None = None
|
|
|
|
|
|
# Union of every entry type that a Thread's ``messages`` list may hold.
Message = UserMessage | AssistantMessage | ToolUse | ToolResult
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Thread (session state)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@dataclass
class Thread:
    """Conversation state: an ordered message list plus bookkeeping fields."""

    # Conversation entries in order; a new thread starts empty.
    messages: list[Message] = field(default_factory=list)
    # Thread identifier: first 12 hex characters of a fresh UUID4.
    id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
    # Creation time; defaults to the current time in UTC.
    created: datetime = field(
        default_factory=lambda: datetime.now(timezone.utc)
    )
    # Last-update time; defaults to the current time in UTC.
    updated: datetime = field(
        default_factory=lambda: datetime.now(timezone.utc)
    )
    # Token counters for the thread, when available.
    token_usage: TokenUsage | None = None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Stream events (yielded by AgentConnection)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@dataclass
class TextDelta:
    """Stream event carrying a piece of text."""

    # The text carried by this event.
    text: str
|
|
|
|
|
|
@dataclass
class ToolCallStart:
    """Stream event announcing a tool call with its id, name and input."""

    # Identifier of the tool call.
    id: str
    # Name of the tool being called.
    name: str
    # Input payload for the call.
    input: dict
|
|
|
|
|
|
@dataclass
class ToolCallEnd:
    """Stream event marking the end of the tool call with the given id."""

    # Identifier of the tool call that ended.
    id: str
|
|
|
|
|
|
@dataclass
class Done:
    """Stream event signalling that the stream finished, with the reason."""

    # Why the stream stopped, e.g. "end_turn", "tool_use", "max_tokens".
    stop_reason: str
|
|
|
|
|
|
@dataclass
class Error:
    """Stream event reporting a failure as a message string."""

    # Description of what went wrong.
    message: str
|
|
|
|
|
|
# Union of the event types yielded by ``AgentConnection.prompt``.
StreamEvent = TextDelta | ToolCallStart | ToolCallEnd | Done | Error
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tool protocol
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@runtime_checkable
class Tool(Protocol):
    """Structural interface of a tool.

    Any object exposing ``name``, ``description``, ``input_schema`` and
    ``run`` satisfies this protocol; because it is runtime-checkable,
    ``isinstance(obj, Tool)`` can be used to test for those members.
    """

    # Name of the tool.
    name: str
    # Description of the tool.
    description: str

    def input_schema(self) -> dict:
        """Return a dict describing the input the tool accepts."""
        ...

    def run(self, input: dict, context: ToolContext) -> ToolResult:
        """Run the tool on ``input`` with ``context`` and return its result."""
        ...
|
|
|
|
|
|
@dataclass
|
|
class ToolContext:
|
|
"""Runtime context passed to tools — references to live app objects."""
|
|
session_dir: Path
|
|
frames_dir: Path
|
|
transcript_dir: Path
|
|
stream_mgr: object | None = None # StreamManager, optional
|
|
tracker: object | None = None # RecordingTracker, optional
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# AgentConnection protocol (replaces AgentProvider)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@runtime_checkable
class AgentConnection(Protocol):
    """Structural interface of a model connection.

    Because the protocol is runtime-checkable, ``isinstance`` can be used to
    test whether an object exposes the members listed here.
    """

    # Name of the connection.
    name: str

    def available_models(self) -> list[str]:
        """Return the model names this connection offers."""
        ...

    def get_model(self) -> str:
        """Return the currently selected model name."""
        ...

    def set_model(self, model: str) -> None:
        """Select ``model`` for this connection."""
        ...

    def prompt(
        self,
        messages: list[Message],
        tools: list[Tool],
    ) -> Iterator[StreamEvent]:
        """Send messages to the model, yield stream events."""
        ...

    def cancel(self) -> None:
        """Cancel work on this connection (exact semantics are up to the implementation)."""
        ...
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Thread serialization
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class _Encoder(json.JSONEncoder):
|
|
def default(self, o):
|
|
if isinstance(o, Path):
|
|
return str(o)
|
|
if isinstance(o, datetime):
|
|
return o.isoformat()
|
|
return super().default(o)
|
|
|
|
|
|
def _block_to_dict(b: ContentBlock) -> dict:
|
|
if isinstance(b, TextBlock):
|
|
return {"type": "text", "text": b.text}
|
|
if isinstance(b, ImageBlock):
|
|
return {"type": "image", "frame_id": b.frame_id, "path": str(b.path), "timestamp": b.timestamp}
|
|
if isinstance(b, TranscriptBlock):
|
|
return {"type": "transcript", "transcript_id": b.transcript_id, "start": b.start, "end": b.end, "text": b.text}
|
|
raise TypeError(f"Unknown block type: {type(b)}")
|
|
|
|
|
|
def _block_from_dict(d: dict) -> ContentBlock:
|
|
t = d["type"]
|
|
if t == "text":
|
|
return TextBlock(text=d["text"])
|
|
if t == "image":
|
|
return ImageBlock(frame_id=d["frame_id"], path=Path(d["path"]), timestamp=d["timestamp"])
|
|
if t == "transcript":
|
|
return TranscriptBlock(transcript_id=d["transcript_id"], start=d["start"], end=d["end"], text=d["text"])
|
|
raise ValueError(f"Unknown block type: {t}")
|
|
|
|
|
|
def _msg_to_dict(m: Message) -> dict:
|
|
if isinstance(m, UserMessage):
|
|
return {
|
|
"type": "user",
|
|
"id": m.id,
|
|
"content": [_block_to_dict(b) for b in m.content],
|
|
"timestamp": m.timestamp.isoformat(),
|
|
}
|
|
if isinstance(m, AssistantMessage):
|
|
d = {
|
|
"type": "assistant",
|
|
"id": m.id,
|
|
"content": [_block_to_dict(b) for b in m.content],
|
|
"timestamp": m.timestamp.isoformat(),
|
|
"model": m.model,
|
|
}
|
|
if m.token_usage:
|
|
d["token_usage"] = {"input_tokens": m.token_usage.input_tokens, "output_tokens": m.token_usage.output_tokens}
|
|
return d
|
|
if isinstance(m, ToolUse):
|
|
return {"type": "tool_use", "id": m.id, "tool_name": m.tool_name, "input": m.input, "status": m.status}
|
|
if isinstance(m, ToolResult):
|
|
return {"type": "tool_result", "tool_use_id": m.tool_use_id, "output": m.output, "error": m.error}
|
|
raise TypeError(f"Unknown message type: {type(m)}")
|
|
|
|
|
|
def _msg_from_dict(d: dict) -> Message:
|
|
t = d["type"]
|
|
if t == "user":
|
|
return UserMessage(
|
|
id=d["id"],
|
|
content=[_block_from_dict(b) for b in d["content"]],
|
|
timestamp=datetime.fromisoformat(d["timestamp"]),
|
|
)
|
|
if t == "assistant":
|
|
tu = None
|
|
if "token_usage" in d and d["token_usage"]:
|
|
tu = TokenUsage(**d["token_usage"])
|
|
return AssistantMessage(
|
|
id=d["id"],
|
|
content=[_block_from_dict(b) for b in d["content"]],
|
|
timestamp=datetime.fromisoformat(d["timestamp"]),
|
|
model=d.get("model", ""),
|
|
token_usage=tu,
|
|
)
|
|
if t == "tool_use":
|
|
return ToolUse(id=d["id"], tool_name=d["tool_name"], input=d["input"], status=d.get("status", "done"))
|
|
if t == "tool_result":
|
|
return ToolResult(tool_use_id=d["tool_use_id"], output=d.get("output"), error=d.get("error"))
|
|
raise ValueError(f"Unknown message type: {t}")
|
|
|
|
|
|
def thread_to_dict(thread: Thread) -> dict:
    """Serialize a thread into a plain dict.

    The result holds ``id``, ``messages`` (each via ``_msg_to_dict``) and the
    ``created`` / ``updated`` times as ISO-8601 strings. A ``"token_usage"``
    key is added only when the thread has usage data.
    """
    serialized = {
        "id": thread.id,
        "messages": [_msg_to_dict(entry) for entry in thread.messages],
        "created": thread.created.isoformat(),
        "updated": thread.updated.isoformat(),
    }
    usage = thread.token_usage
    if usage:
        serialized["token_usage"] = {
            "input_tokens": usage.input_tokens,
            "output_tokens": usage.output_tokens,
        }
    return serialized
|
|
|
|
|
|
def thread_from_dict(d: dict) -> Thread:
    """Rebuild a ``Thread`` from its dict form.

    ``id``, ``messages``, ``created`` and ``updated`` are required keys;
    ``token_usage`` is optional and yields ``None`` when absent or empty.

    Raises:
        KeyError: If a required key is missing.
        ValueError: If a timestamp or a nested message cannot be parsed.
    """
    # Usage is parsed first, before any other field is read.
    raw_usage = d.get("token_usage")
    usage = TokenUsage(**raw_usage) if raw_usage else None

    thread_id = d["id"]
    entries = [_msg_from_dict(raw) for raw in d["messages"]]
    created_at = datetime.fromisoformat(d["created"])
    updated_at = datetime.fromisoformat(d["updated"])
    return Thread(
        id=thread_id,
        messages=entries,
        created=created_at,
        updated=updated_at,
        token_usage=usage,
    )
|
|
|
|
|
|
def save_thread(thread: Thread, session_dir: Path) -> None:
    """Persist ``thread`` to ``<session_dir>/agent/thread.json``.

    The ``agent`` directory is created if needed and ``thread.updated`` is
    set to the current UTC time before serialization (the passed-in object is
    mutated).

    The JSON is written to a sibling temporary file and then renamed over the
    target, so an interrupted save cannot leave a truncated ``thread.json``
    behind.

    Args:
        thread: Thread to persist; its ``updated`` field is refreshed.
        session_dir: Session directory that contains the ``agent`` folder.

    Raises:
        OSError: If the directory or file cannot be written.
    """
    agent_dir = session_dir / "agent"
    agent_dir.mkdir(parents=True, exist_ok=True)
    path = agent_dir / "thread.json"

    thread.updated = datetime.now(timezone.utc)
    # Serialize fully before touching the disk so a serialization error
    # never affects the existing file.
    payload = json.dumps(thread_to_dict(thread), indent=2, cls=_Encoder)

    # Write-then-rename: the temp file lives in the same directory so the
    # rename stays on one filesystem. A stale temp file left by a failed
    # write is simply overwritten by the next save.
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(payload, encoding="utf-8")
    tmp_path.replace(path)
|
|
|
|
|
|
def load_thread(session_dir: Path) -> Thread | None:
    """Load the thread stored at ``<session_dir>/agent/thread.json``.

    Loading is best-effort: a missing file, an unreadable file, or content
    that cannot be turned back into a ``Thread`` all yield ``None`` instead
    of raising. Read and parse failures are logged as warnings so a corrupt
    thread file does not disappear silently; a missing file is not logged.

    Args:
        session_dir: Session directory that contains the ``agent`` folder.

    Returns:
        The deserialized ``Thread``, or ``None`` when nothing usable exists.
    """
    import logging  # local import keeps the module's import block untouched

    path = session_dir / "agent" / "thread.json"
    if not path.exists():
        return None
    try:
        # UTF-8 is pinned so decoding does not depend on the platform locale.
        # json.dumps escapes non-ASCII by default, so ASCII-only files decode
        # the same way under UTF-8.
        raw = path.read_text(encoding="utf-8")
        data = json.loads(raw)
        return thread_from_dict(data)
    except Exception:
        # Deliberate catch-all at this boundary: callers rely on None rather
        # than an exception, but the cause is recorded for diagnosis.
        logging.getLogger(__name__).warning(
            "Could not load thread from %s", path, exc_info=True
        )
        return None
|