better agent

This commit is contained in:
2026-04-09 14:46:29 -03:00
parent ade92069c0
commit 64ecdca71e
11 changed files with 1424 additions and 434 deletions

View File

@@ -1,16 +1,28 @@
"""
Abstract base for agent providers.
Agent data model — structured messages, connections, tools, threads.
Each provider takes a user message + session context and yields response
text chunks for streaming into the UI.
Replaces the old flat AgentProvider/SessionContext with:
- Typed content blocks and messages (conversation model)
- StreamEvent types (connection output)
- AgentConnection protocol (replaces AgentProvider)
- Tool protocol
- Thread (session state, serializable)
"""
from abc import ABC, abstractmethod
from __future__ import annotations
import json
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Iterator
from typing import Iterator, Literal, Protocol, runtime_checkable
# ---------------------------------------------------------------------------
# Shared refs (used by runner mention-parsing and tools)
# ---------------------------------------------------------------------------
@dataclass
class FrameRef:
id: str # "F0001"
@@ -26,40 +38,298 @@ class TranscriptRef:
text: str
# ---------------------------------------------------------------------------
# Content blocks
# ---------------------------------------------------------------------------
@dataclass
class SessionContext:
class TextBlock:
text: str
@dataclass
class ImageBlock:
    """Content block that references a frame image by id and file path."""

    frame_id: str  # e.g. "F0042"
    path: Path
    timestamp: float
@dataclass
class TranscriptBlock:
    """Content block carrying a transcript segment's id, time span and text."""

    transcript_id: str  # e.g. "T0012"
    start: float
    end: float
    text: str
# Union of every block type that may appear in a message's ``content`` list.
ContentBlock = TextBlock | ImageBlock | TranscriptBlock
# ---------------------------------------------------------------------------
# Messages
# ---------------------------------------------------------------------------
@dataclass
class TokenUsage:
    """Input/output token counters; both default to zero."""

    input_tokens: int = 0
    output_tokens: int = 0
def _msg_id() -> str:
return uuid.uuid4().hex[:12]
@dataclass
class UserMessage:
    """A user-authored message made of content blocks.

    ``id`` defaults to a new ``_msg_id()`` value and ``timestamp`` to the
    current UTC time at construction.
    """

    content: list[ContentBlock]
    id: str = field(default_factory=_msg_id)
    timestamp: datetime = field(
        default_factory=lambda: datetime.now(tz=timezone.utc)
    )
@dataclass
class AssistantMessage:
    """An assistant reply: content blocks plus model name and optional usage.

    ``id`` defaults to a new ``_msg_id()`` value and ``timestamp`` to the
    current UTC time at construction; ``token_usage`` is ``None`` unless set.
    """

    content: list[ContentBlock]
    model: str = ""
    id: str = field(default_factory=_msg_id)
    timestamp: datetime = field(
        default_factory=lambda: datetime.now(tz=timezone.utc)
    )
    token_usage: TokenUsage | None = None
@dataclass
class ToolUse:
    """Record of a tool invocation: tool name, input dict, id and status.

    ``status`` starts as ``"pending"``; the other allowed values are
    ``"running"``, ``"done"`` and ``"error"``.
    """

    tool_name: str
    input: dict
    id: str = field(default_factory=_msg_id)
    status: Literal["pending", "running", "done", "error"] = "pending"
@dataclass
class ToolResult:
    """Outcome of a tool call, carrying optional output and error strings."""

    tool_use_id: str  # presumably the id of the matching ToolUse (confirm with callers)
    output: str | None = None
    error: str | None = None
# Union of every entry type that can be stored in ``Thread.messages``.
Message = UserMessage | AssistantMessage | ToolUse | ToolResult
# ---------------------------------------------------------------------------
# Thread (session state)
# ---------------------------------------------------------------------------
@dataclass
class Thread:
    """Session state: an ordered list of messages plus bookkeeping fields.

    Every field has a default: an empty message list, a 12-character hex id
    from a random UUID4, ``created``/``updated`` set to the current UTC time,
    and no token usage.
    """

    messages: list[Message] = field(default_factory=list)
    id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
    created: datetime = field(
        default_factory=lambda: datetime.now(tz=timezone.utc)
    )
    updated: datetime = field(
        default_factory=lambda: datetime.now(tz=timezone.utc)
    )
    token_usage: TokenUsage | None = None
# ---------------------------------------------------------------------------
# Stream events (yielded by AgentConnection)
# ---------------------------------------------------------------------------
@dataclass
class TextDelta:
    """Stream event carrying a piece of text."""

    text: str
@dataclass
class ToolCallStart:
    """Stream event announcing a tool call: its id, tool name and input dict."""

    id: str
    name: str
    input: dict
@dataclass
class ToolCallEnd:
    """Stream event marking the end of the tool call identified by ``id``."""

    id: str
@dataclass
class Done:
    """Stream event signalling that the response finished, with the reason."""

    stop_reason: str  # "end_turn", "tool_use", "max_tokens"
@dataclass
class Error:
    """Stream event carrying an error message."""

    message: str
# Union of the event types yielded while streaming (see AgentConnection.prompt).
StreamEvent = TextDelta | ToolCallStart | ToolCallEnd | Done | Error
# ---------------------------------------------------------------------------
# Tool protocol
# ---------------------------------------------------------------------------
@runtime_checkable
class Tool(Protocol):
    """Structural interface for tools.

    Any object exposing ``name``, ``description``, ``input_schema()`` and
    ``run()`` satisfies it; ``isinstance`` checks work because the protocol
    is runtime-checkable.
    """

    name: str
    description: str

    def input_schema(self) -> dict:
        """Return a dict describing the tool's input.

        Presumably a JSON-schema-style description; confirm with implementers.
        """
        ...

    def run(self, input: dict, context: ToolContext) -> ToolResult:
        """Run the tool on *input* with *context* and return a ToolResult."""
        ...
@dataclass
class ToolContext:
"""Runtime context passed to tools — references to live app objects."""
session_dir: Path
frames: list[FrameRef] # all captured frames so far
duration: float # current recording duration (seconds)
mentioned_frames: list[FrameRef] = field(default_factory=list)
transcript_segments: list[TranscriptRef] = field(default_factory=list)
mentioned_transcripts: list[TranscriptRef] = field(default_factory=list)
history: list[tuple[str, str]] = field(default_factory=list) # [(role, text), ...]
frames_dir: Path
transcript_dir: Path
stream_mgr: object | None = None # StreamManager, optional
tracker: object | None = None # RecordingTracker, optional
class AgentProvider(ABC):
@abstractmethod
def stream(self, message: str, context: SessionContext) -> Iterator[str]:
"""Yield response text chunks."""
# ---------------------------------------------------------------------------
# AgentConnection protocol (replaces AgentProvider)
# ---------------------------------------------------------------------------
@runtime_checkable
class AgentConnection(Protocol):
name: str
def available_models(self) -> list[str]: ...
def get_model(self) -> str: ...
def set_model(self, model: str) -> None: ...
def prompt(
self,
messages: list[Message],
tools: list[Tool],
) -> Iterator[StreamEvent]:
"""Send messages to the model, yield stream events."""
...
@property
@abstractmethod
def name(self) -> str:
...
def cancel(self) -> None: ...
@property
@abstractmethod
def available_models(self) -> list[str]:
"""Return list of model IDs this provider supports."""
...
@property
@abstractmethod
def model(self) -> str:
...
# ---------------------------------------------------------------------------
# Thread serialization
# ---------------------------------------------------------------------------
@model.setter
@abstractmethod
def model(self, value: str):
...
class _Encoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, Path):
return str(o)
if isinstance(o, datetime):
return o.isoformat()
return super().default(o)
def _block_to_dict(b: ContentBlock) -> dict:
if isinstance(b, TextBlock):
return {"type": "text", "text": b.text}
if isinstance(b, ImageBlock):
return {"type": "image", "frame_id": b.frame_id, "path": str(b.path), "timestamp": b.timestamp}
if isinstance(b, TranscriptBlock):
return {"type": "transcript", "transcript_id": b.transcript_id, "start": b.start, "end": b.end, "text": b.text}
raise TypeError(f"Unknown block type: {type(b)}")
def _block_from_dict(d: dict) -> ContentBlock:
t = d["type"]
if t == "text":
return TextBlock(text=d["text"])
if t == "image":
return ImageBlock(frame_id=d["frame_id"], path=Path(d["path"]), timestamp=d["timestamp"])
if t == "transcript":
return TranscriptBlock(transcript_id=d["transcript_id"], start=d["start"], end=d["end"], text=d["text"])
raise ValueError(f"Unknown block type: {t}")
def _msg_to_dict(m: Message) -> dict:
if isinstance(m, UserMessage):
return {
"type": "user",
"id": m.id,
"content": [_block_to_dict(b) for b in m.content],
"timestamp": m.timestamp.isoformat(),
}
if isinstance(m, AssistantMessage):
d = {
"type": "assistant",
"id": m.id,
"content": [_block_to_dict(b) for b in m.content],
"timestamp": m.timestamp.isoformat(),
"model": m.model,
}
if m.token_usage:
d["token_usage"] = {"input_tokens": m.token_usage.input_tokens, "output_tokens": m.token_usage.output_tokens}
return d
if isinstance(m, ToolUse):
return {"type": "tool_use", "id": m.id, "tool_name": m.tool_name, "input": m.input, "status": m.status}
if isinstance(m, ToolResult):
return {"type": "tool_result", "tool_use_id": m.tool_use_id, "output": m.output, "error": m.error}
raise TypeError(f"Unknown message type: {type(m)}")
def _msg_from_dict(d: dict) -> Message:
t = d["type"]
if t == "user":
return UserMessage(
id=d["id"],
content=[_block_from_dict(b) for b in d["content"]],
timestamp=datetime.fromisoformat(d["timestamp"]),
)
if t == "assistant":
tu = None
if "token_usage" in d and d["token_usage"]:
tu = TokenUsage(**d["token_usage"])
return AssistantMessage(
id=d["id"],
content=[_block_from_dict(b) for b in d["content"]],
timestamp=datetime.fromisoformat(d["timestamp"]),
model=d.get("model", ""),
token_usage=tu,
)
if t == "tool_use":
return ToolUse(id=d["id"], tool_name=d["tool_name"], input=d["input"], status=d.get("status", "done"))
if t == "tool_result":
return ToolResult(tool_use_id=d["tool_use_id"], output=d.get("output"), error=d.get("error"))
raise ValueError(f"Unknown message type: {t}")
def thread_to_dict(thread: Thread) -> dict:
    """Convert *thread* into a plain dict.

    ``created`` and ``updated`` are written as ISO-8601 strings, and
    ``"token_usage"`` is included only when the thread has one.
    """
    serialized = {
        "id": thread.id,
        "messages": [_msg_to_dict(msg) for msg in thread.messages],
        "created": thread.created.isoformat(),
        "updated": thread.updated.isoformat(),
    }
    usage = thread.token_usage
    if usage:
        serialized["token_usage"] = {
            "input_tokens": usage.input_tokens,
            "output_tokens": usage.output_tokens,
        }
    return serialized
def thread_from_dict(d: dict) -> Thread:
    """Rebuild a Thread from the dict form produced by ``thread_to_dict``.

    Raises:
        KeyError: if ``"id"``, ``"messages"``, ``"created"`` or ``"updated"``
            is missing.
        ValueError: if a timestamp is not valid ISO-8601 or a message has an
            unknown type.
    """
    # Token usage is optional; an absent or empty entry becomes None.
    raw_usage = d["token_usage"] if "token_usage" in d else None
    usage = TokenUsage(**raw_usage) if raw_usage else None
    return Thread(
        id=d["id"],
        messages=[_msg_from_dict(raw) for raw in d["messages"]],
        created=datetime.fromisoformat(d["created"]),
        updated=datetime.fromisoformat(d["updated"]),
        token_usage=usage,
    )
def save_thread(thread: Thread, session_dir: Path) -> None:
    """Persist *thread* as JSON at ``<session_dir>/agent/thread.json``.

    Creates the ``agent`` directory if needed and sets ``thread.updated`` to
    the current UTC time before serializing.

    The JSON goes to a sibling temporary file first and is then renamed over
    the target. An interrupted write therefore cannot leave a truncated
    ``thread.json`` in place of the previous good copy.

    Args:
        thread: The thread to save; its ``updated`` field is mutated.
        session_dir: Session root directory.

    Raises:
        OSError: if the directory or file cannot be written.
        TypeError: if the thread contains a value that cannot be serialized.
    """
    agent_dir = session_dir / "agent"
    agent_dir.mkdir(parents=True, exist_ok=True)
    path = agent_dir / "thread.json"
    thread.updated = datetime.now(timezone.utc)
    # Serialize fully before touching the disk so a serialization error
    # leaves the existing file untouched.
    payload = json.dumps(thread_to_dict(thread), indent=2, cls=_Encoder)
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(payload, encoding="utf-8")
    # Rename within the same directory; replaces any existing thread.json.
    tmp_path.replace(path)
def load_thread(session_dir: Path) -> Thread | None:
    """Load the thread stored at ``<session_dir>/agent/thread.json``.

    Loading is best-effort: a missing, unreadable or malformed file yields
    ``None`` rather than an exception.

    Args:
        session_dir: Session root directory.

    Returns:
        The deserialized Thread, or ``None`` if no usable file exists.
    """
    path = session_dir / "agent" / "thread.json"
    # Read directly instead of checking exists() first: this removes the
    # check-then-read race and keeps filesystem errors inside the handler.
    try:
        raw = path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        return None
    try:
        data = json.loads(raw)
        return thread_from_dict(data)
    except Exception:
        # Deliberate best-effort: a corrupt or incompatible file is treated
        # as "no saved thread".
        return None