Files
mitus/cht/agent/base.py
2026-04-09 14:46:29 -03:00

336 lines
9.6 KiB
Python

"""
Agent data model — structured messages, connections, tools, threads.
Replaces the old flat AgentProvider/SessionContext with:
- Typed content blocks and messages (conversation model)
- StreamEvent types (connection output)
- AgentConnection protocol (replaces AgentProvider)
- Tool protocol
- Thread (session state, serializable)
"""
from __future__ import annotations
import json
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Iterator, Literal, Protocol, runtime_checkable
# ---------------------------------------------------------------------------
# Shared refs (used by runner mention-parsing and tools)
# ---------------------------------------------------------------------------
@dataclass
class FrameRef:
    """Reference to one frame image of a recording."""

    id: str  # frame identifier, e.g. "F0001"
    path: Path  # absolute path to the JPEG file
    timestamp: float  # offset into the recording, in seconds
@dataclass
class TranscriptRef:
    """Reference to one transcript segment of a recording."""

    id: str  # segment identifier, e.g. "T0001"
    start: float  # segment start, in seconds into the recording
    end: float  # segment end, in seconds into the recording
    text: str  # text of the segment
# ---------------------------------------------------------------------------
# Content blocks
# ---------------------------------------------------------------------------
@dataclass
class TextBlock:
    """Content block carrying plain text."""

    text: str
@dataclass
class ImageBlock:
    """Content block pointing at a frame image."""

    frame_id: str  # frame identifier, e.g. "F0042"
    path: Path  # location of the image file
    timestamp: float  # NOTE(review): presumably seconds into the recording, as in FrameRef
@dataclass
class TranscriptBlock:
    """Content block carrying a transcript segment."""

    transcript_id: str  # segment identifier, e.g. "T0012"
    start: float  # NOTE(review): presumably seconds into the recording, as in TranscriptRef
    end: float
    text: str  # text of the segment
# Union of the block types usable as message content (see the ``content``
# fields of UserMessage / AssistantMessage).
ContentBlock = TextBlock | ImageBlock | TranscriptBlock
# ---------------------------------------------------------------------------
# Messages
# ---------------------------------------------------------------------------
@dataclass
class TokenUsage:
    """Input/output token counters; both default to zero."""

    input_tokens: int = 0
    output_tokens: int = 0
def _msg_id() -> str:
    """Return a new identifier: the first 12 hex digits of a random UUID4."""
    full_hex = uuid.uuid4().hex
    return full_hex[:12]
@dataclass
class UserMessage:
    """A message authored by the user."""

    content: list[ContentBlock]  # ordered content blocks
    # Generated per instance when not supplied.
    id: str = field(default_factory=_msg_id)
    # Creation time; defaults to the current time in UTC (timezone-aware).
    timestamp: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc))
@dataclass
class AssistantMessage:
    """A message produced by the assistant."""

    content: list[ContentBlock]  # ordered content blocks
    model: str = ""  # model name; empty string when not recorded
    # Generated per instance when not supplied.
    id: str = field(default_factory=_msg_id)
    # Creation time; defaults to the current time in UTC (timezone-aware).
    timestamp: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc))
    token_usage: TokenUsage | None = None  # optional token counters for this message
@dataclass
class ToolUse:
    """A tool invocation recorded in the conversation."""

    tool_name: str  # name of the tool being invoked
    input: dict  # input payload for the tool
    # Generated per instance when not supplied.
    id: str = field(default_factory=_msg_id)
    # Lifecycle state; a freshly created ToolUse starts as "pending".
    status: Literal["pending", "running", "done", "error"] = "pending"
@dataclass
class ToolResult:
    """Outcome of a tool call: an output string and/or an error string."""

    tool_use_id: str  # NOTE(review): presumably the ``id`` of the matching ToolUse
    output: str | None = None  # tool output, if any
    error: str | None = None  # error text, if any
# Union of every entry type that can be stored in ``Thread.messages``.
Message = UserMessage | AssistantMessage | ToolUse | ToolResult
# ---------------------------------------------------------------------------
# Thread (session state)
# ---------------------------------------------------------------------------
@dataclass
class Thread:
    """Session state: an ordered message list plus bookkeeping metadata."""

    # Each instance gets its own fresh list.
    messages: list[Message] = field(default_factory=list)
    # First 12 hex digits of a random UUID4.
    id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
    # Both timestamps default to the current time in UTC (timezone-aware).
    created: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc))
    updated: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc))
    token_usage: TokenUsage | None = None  # optional token counters for the thread
# ---------------------------------------------------------------------------
# Stream events (yielded by AgentConnection)
# ---------------------------------------------------------------------------
@dataclass
class TextDelta:
    """Stream event carrying a piece of text."""

    text: str
@dataclass
class ToolCallStart:
    """Stream event announcing a tool call."""

    id: str  # identifier of the tool call
    name: str  # tool name
    input: dict  # input payload for the tool
@dataclass
class ToolCallEnd:
    """Stream event marking the end of a tool call."""

    id: str  # NOTE(review): presumably matches the ``id`` of the ToolCallStart
@dataclass
class Done:
    """Stream event signalling that the stream has finished."""

    stop_reason: str  # e.g. "end_turn", "tool_use", "max_tokens"
@dataclass
class Error:
    """Stream event carrying an error message."""

    message: str
# Union of every event type that ``AgentConnection.prompt`` may yield.
StreamEvent = TextDelta | ToolCallStart | ToolCallEnd | Done | Error
# ---------------------------------------------------------------------------
# Tool protocol
# ---------------------------------------------------------------------------
@runtime_checkable
class Tool(Protocol):
    """Structural interface for a tool.

    Being ``runtime_checkable``, ``isinstance(obj, Tool)`` verifies only that
    the members exist, not their signatures.
    """

    name: str
    description: str

    def input_schema(self) -> dict:
        """Return a dict describing the tool's input (presumably a JSON-schema-like dict)."""

    def run(self, input: dict, context: ToolContext) -> ToolResult:
        """Run the tool on *input* with the given *context* and return its result."""
@dataclass
class ToolContext:
    """Runtime context passed to tools — references to live app objects."""

    session_dir: Path
    frames_dir: Path
    transcript_dir: Path
    stream_mgr: object | None = None  # StreamManager, optional
    tracker: object | None = None  # RecordingTracker, optional
# ---------------------------------------------------------------------------
# AgentConnection protocol (replaces AgentProvider)
# ---------------------------------------------------------------------------
@runtime_checkable
class AgentConnection(Protocol):
    """Structural interface for a model backend (replaces AgentProvider).

    Being ``runtime_checkable``, ``isinstance(obj, AgentConnection)`` verifies
    only that the members exist, not their signatures.
    """

    name: str

    def available_models(self) -> list[str]:
        """List the available model names."""

    def get_model(self) -> str:
        """Return the model name (presumably the one currently selected)."""

    def set_model(self, model: str) -> None:
        """Set the model to *model*."""

    def prompt(
        self,
        messages: list[Message],
        tools: list[Tool],
    ) -> Iterator[StreamEvent]:
        """Send messages to the model, yield stream events."""

    def cancel(self) -> None:
        """Cancel; exact semantics are left to the implementation."""
# ---------------------------------------------------------------------------
# Thread serialization
# ---------------------------------------------------------------------------
class _Encoder(json.JSONEncoder):
    """JSON encoder that additionally understands ``Path`` and ``datetime``."""

    def default(self, o):
        """Encode datetimes as ISO-8601 strings and paths via ``str()``.

        Any other unsupported object is delegated to the base class, which
        raises ``TypeError``.
        """
        # The two types are unrelated, so the order of the checks is irrelevant.
        if isinstance(o, datetime):
            return o.isoformat()
        if isinstance(o, Path):
            return str(o)
        return super().default(o)
def _block_to_dict(b: ContentBlock) -> dict:
    """Convert a content block into a JSON-ready dict tagged with ``"type"``.

    Raises:
        TypeError: if *b* is not a TextBlock, ImageBlock or TranscriptBlock.
    """
    if isinstance(b, TextBlock):
        payload = {"type": "text", "text": b.text}
    elif isinstance(b, ImageBlock):
        payload = {
            "type": "image",
            "frame_id": b.frame_id,
            "path": str(b.path),  # Path is stored as a plain string
            "timestamp": b.timestamp,
        }
    elif isinstance(b, TranscriptBlock):
        payload = {
            "type": "transcript",
            "transcript_id": b.transcript_id,
            "start": b.start,
            "end": b.end,
            "text": b.text,
        }
    else:
        raise TypeError(f"Unknown block type: {type(b)}")
    return payload
def _block_from_dict(d: dict) -> ContentBlock:
    """Rebuild a content block from its dict form (inverse of ``_block_to_dict``).

    Raises:
        KeyError: if ``"type"`` or a field required by that type is missing.
        ValueError: if ``"type"`` is not ``"text"``, ``"image"`` or ``"transcript"``.
    """
    kind = d["type"]
    if kind == "text":
        return TextBlock(text=d["text"])
    elif kind == "image":
        frame_id = d["frame_id"]
        image_path = Path(d["path"])  # stored as a string; restore the Path
        return ImageBlock(frame_id=frame_id, path=image_path, timestamp=d["timestamp"])
    elif kind == "transcript":
        return TranscriptBlock(
            transcript_id=d["transcript_id"],
            start=d["start"],
            end=d["end"],
            text=d["text"],
        )
    raise ValueError(f"Unknown block type: {kind}")
def _msg_to_dict(m: Message) -> dict:
    """Convert a message into a JSON-ready dict tagged with ``"type"``.

    Timestamps are stored as ISO-8601 strings and content blocks are converted
    with ``_block_to_dict``.  An assistant message gains a ``"token_usage"``
    entry only when it carries usage data.

    Raises:
        TypeError: if *m* is not one of the known message classes.
    """
    if isinstance(m, UserMessage):
        payload = {
            "type": "user",
            "id": m.id,
            "content": list(map(_block_to_dict, m.content)),
            "timestamp": m.timestamp.isoformat(),
        }
    elif isinstance(m, AssistantMessage):
        payload = {
            "type": "assistant",
            "id": m.id,
            "content": list(map(_block_to_dict, m.content)),
            "timestamp": m.timestamp.isoformat(),
            "model": m.model,
        }
        usage = m.token_usage
        if usage:
            payload["token_usage"] = {
                "input_tokens": usage.input_tokens,
                "output_tokens": usage.output_tokens,
            }
    elif isinstance(m, ToolUse):
        payload = {
            "type": "tool_use",
            "id": m.id,
            "tool_name": m.tool_name,
            "input": m.input,
            "status": m.status,
        }
    elif isinstance(m, ToolResult):
        payload = {
            "type": "tool_result",
            "tool_use_id": m.tool_use_id,
            "output": m.output,
            "error": m.error,
        }
    else:
        raise TypeError(f"Unknown message type: {type(m)}")
    return payload
def _msg_from_dict(d: dict) -> Message:
    """Rebuild a message from its dict form (inverse of ``_msg_to_dict``).

    Optional keys fall back to defaults: ``"model"`` to ``""``, ``"status"`` to
    ``"done"``, ``"output"`` / ``"error"`` to ``None``; a missing or empty
    ``"token_usage"`` yields ``None``.

    Raises:
        KeyError: if ``"type"`` or a required field is missing.
        ValueError: if ``"type"`` is not a known message tag.
    """
    kind = d["type"]

    if kind == "user":
        return UserMessage(
            id=d["id"],
            content=[_block_from_dict(item) for item in d["content"]],
            timestamp=datetime.fromisoformat(d["timestamp"]),
        )

    if kind == "assistant":
        # Usage is parsed before the other fields, as in the original order.
        has_usage = "token_usage" in d and d["token_usage"]
        usage = TokenUsage(**d["token_usage"]) if has_usage else None
        return AssistantMessage(
            id=d["id"],
            content=[_block_from_dict(item) for item in d["content"]],
            timestamp=datetime.fromisoformat(d["timestamp"]),
            model=d.get("model", ""),
            token_usage=usage,
        )

    if kind == "tool_use":
        return ToolUse(
            id=d["id"],
            tool_name=d["tool_name"],
            input=d["input"],
            status=d.get("status", "done"),
        )

    if kind == "tool_result":
        return ToolResult(
            tool_use_id=d["tool_use_id"],
            output=d.get("output"),
            error=d.get("error"),
        )

    raise ValueError(f"Unknown message type: {kind}")
def thread_to_dict(thread: Thread) -> dict:
    """Convert *thread* into a JSON-ready dict.

    The result always has ``"id"``, ``"messages"``, ``"created"`` and
    ``"updated"`` (timestamps as ISO-8601 strings); ``"token_usage"`` is added
    only when the thread carries usage data.
    """
    result = {"id": thread.id}
    result["messages"] = [_msg_to_dict(msg) for msg in thread.messages]
    result["created"] = thread.created.isoformat()
    result["updated"] = thread.updated.isoformat()

    usage = thread.token_usage
    if usage:
        result["token_usage"] = {
            "input_tokens": usage.input_tokens,
            "output_tokens": usage.output_tokens,
        }
    return result
def thread_from_dict(d: dict) -> Thread:
    """Rebuild a Thread from its dict form (inverse of ``thread_to_dict``).

    A missing or empty ``"token_usage"`` entry yields ``token_usage=None``.

    Raises:
        KeyError: if ``"id"``, ``"messages"``, ``"created"`` or ``"updated"``
            is missing.
    """
    # Usage is parsed first, then the fields in their original order.
    has_usage = "token_usage" in d and d["token_usage"]
    usage = TokenUsage(**d["token_usage"]) if has_usage else None

    thread_id = d["id"]
    history = [_msg_from_dict(entry) for entry in d["messages"]]
    created_at = datetime.fromisoformat(d["created"])
    updated_at = datetime.fromisoformat(d["updated"])

    return Thread(
        id=thread_id,
        messages=history,
        created=created_at,
        updated=updated_at,
        token_usage=usage,
    )
def save_thread(thread: Thread, session_dir: Path) -> None:
    """Persist *thread* as JSON to ``<session_dir>/agent/thread.json``.

    The ``agent`` directory is created when missing, and ``thread.updated`` is
    set to the current UTC time before serializing (the caller's object is
    mutated).

    The JSON is first written to a temporary file in the same directory and
    then renamed over the destination.  A crash or a failed write therefore
    leaves the previous ``thread.json`` intact instead of a truncated file,
    which ``load_thread`` would treat as "no saved thread".

    Raises:
        OSError: if the directory cannot be created or the file cannot be
            written or renamed.
    """
    agent_dir = session_dir / "agent"
    agent_dir.mkdir(parents=True, exist_ok=True)
    path = agent_dir / "thread.json"
    thread.updated = datetime.now(timezone.utc)

    # Serialize before touching the filesystem, so a serialization error
    # cannot leave a partial file behind.
    payload = json.dumps(thread_to_dict(thread), indent=2, cls=_Encoder)

    tmp_path = agent_dir / "thread.json.tmp"
    try:
        tmp_path.write_text(payload, encoding="utf-8")
        tmp_path.replace(path)  # same directory, so the rename replaces in one step
    except OSError:
        # Do not leave a stray temp file after a failed write or rename.
        tmp_path.unlink(missing_ok=True)
        raise
def load_thread(session_dir: Path) -> Thread | None:
    """Load the thread saved at ``<session_dir>/agent/thread.json``.

    Loading is best-effort: a missing file returns ``None`` silently, while an
    unreadable or malformed file is logged as a warning and also returns
    ``None``.  Only I/O and data-shape errors are handled; any other exception
    points to a programming error and propagates instead of being swallowed.

    Returns:
        The deserialized Thread, or ``None`` when no usable file exists.
    """
    import logging  # local import: the module header does not import logging

    path = session_dir / "agent" / "thread.json"
    try:
        # Read directly instead of checking exists() first, so the file cannot
        # disappear between the check and the read.
        raw = path.read_text(encoding="utf-8")
        data = json.loads(raw)
        return thread_from_dict(data)
    except FileNotFoundError:
        # No thread has been saved for this session yet.
        return None
    except (OSError, ValueError, KeyError, TypeError) as exc:
        # OSError: unreadable path.  ValueError: bad JSON or encoding, an
        # unknown type tag, or a bad timestamp.  KeyError / TypeError: missing
        # or wrongly-typed fields.
        logging.getLogger(__name__).warning(
            "Could not load thread from %s: %s: %s", path, type(exc).__name__, exc
        )
        return None