better agent
This commit is contained in:
@@ -1,16 +1,28 @@
|
||||
"""
|
||||
Abstract base for agent providers.
|
||||
Agent data model — structured messages, connections, tools, threads.
|
||||
|
||||
Each provider takes a user message + session context and yields response
|
||||
text chunks for streaming into the UI.
|
||||
Replaces the old flat AgentProvider/SessionContext with:
|
||||
- Typed content blocks and messages (conversation model)
|
||||
- StreamEvent types (connection output)
|
||||
- AgentConnection protocol (replaces AgentProvider)
|
||||
- Tool protocol
|
||||
- Thread (session state, serializable)
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
from typing import Iterator, Literal, Protocol, runtime_checkable
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared refs (used by runner mention-parsing and tools)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class FrameRef:
|
||||
id: str # "F0001"
|
||||
@@ -26,40 +38,298 @@ class TranscriptRef:
|
||||
text: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Content blocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class TextBlock:
    """Content block holding a plain piece of text."""

    text: str
|
||||
|
||||
|
||||
@dataclass
class ImageBlock:
    """Content block that references a frame by id, file path and timestamp."""

    frame_id: str  # frame identifier, e.g. "F0042"
    path: Path
    timestamp: float
|
||||
|
||||
|
||||
@dataclass
class TranscriptBlock:
    """Content block carrying a transcript segment: id, start/end bounds and text."""

    transcript_id: str  # transcript identifier, e.g. "T0012"
    start: float
    end: float
    text: str
|
||||
|
||||
|
||||
# Union of the content block types that a message's `content` list may hold.
ContentBlock = TextBlock | ImageBlock | TranscriptBlock
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class TokenUsage:
    """Pair of input/output token counters; both default to zero."""

    input_tokens: int = 0
    output_tokens: int = 0
|
||||
|
||||
|
||||
def _msg_id() -> str:
|
||||
return uuid.uuid4().hex[:12]
|
||||
|
||||
|
||||
@dataclass
class UserMessage:
    """A user-authored message made of content blocks.

    ``id`` defaults to a fresh ``_msg_id()`` value and ``timestamp`` to the
    current UTC time at construction.
    """

    content: list[ContentBlock]
    id: str = field(default_factory=_msg_id)
    timestamp: datetime = field(
        default_factory=lambda: datetime.now(tz=timezone.utc),
    )
|
||||
|
||||
|
||||
@dataclass
class AssistantMessage:
    """An assistant-authored message made of content blocks.

    Besides the content it records the model name (empty string by default),
    a generated ``id``, a UTC creation ``timestamp`` and optional token usage.
    """

    content: list[ContentBlock]
    model: str = ""
    id: str = field(default_factory=_msg_id)
    timestamp: datetime = field(
        default_factory=lambda: datetime.now(tz=timezone.utc),
    )
    token_usage: TokenUsage | None = None
|
||||
|
||||
|
||||
@dataclass
class ToolUse:
    """Record of a tool invocation: which tool, with what input, and its status.

    ``status`` starts as ``"pending"``; the other permitted values are
    ``"running"``, ``"done"`` and ``"error"``.
    """

    tool_name: str
    input: dict
    id: str = field(default_factory=_msg_id)
    status: Literal["pending", "running", "done", "error"] = "pending"
|
||||
|
||||
|
||||
@dataclass
class ToolResult:
    """Outcome of a tool invocation, linked to its ToolUse via ``tool_use_id``.

    ``output`` and ``error`` are both optional and default to ``None``.
    """

    tool_use_id: str
    output: str | None = None
    error: str | None = None
|
||||
|
||||
|
||||
# Union of the entry types that a Thread's `messages` list may hold.
Message = UserMessage | AssistantMessage | ToolUse | ToolResult
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thread (session state)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class Thread:
    """Session state: an ordered list of messages plus bookkeeping metadata.

    Every field has a default: an empty message list, a 12-hex-digit random
    ``id``, ``created``/``updated`` set to the current UTC time, and no token
    usage.
    """

    messages: list[Message] = field(default_factory=list)
    id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
    created: datetime = field(
        default_factory=lambda: datetime.now(tz=timezone.utc),
    )
    updated: datetime = field(
        default_factory=lambda: datetime.now(tz=timezone.utc),
    )
    token_usage: TokenUsage | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stream events (yielded by AgentConnection)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class TextDelta:
    """Stream event carrying a chunk of text."""

    text: str
|
||||
|
||||
|
||||
@dataclass
class ToolCallStart:
    """Stream event announcing a tool call: its id, tool name and input dict."""

    id: str
    name: str
    input: dict
|
||||
|
||||
|
||||
@dataclass
class ToolCallEnd:
    """Stream event marking the end of the tool call with the given id."""

    id: str
|
||||
|
||||
|
||||
@dataclass
class Done:
    """Stream event signalling completion, with the reason the stream stopped."""

    stop_reason: str  # e.g. "end_turn", "tool_use", "max_tokens"
|
||||
|
||||
|
||||
@dataclass
class Error:
    """Stream event carrying an error message."""

    message: str
|
||||
|
||||
|
||||
# Union of the event types yielded by AgentConnection.prompt().
StreamEvent = TextDelta | ToolCallStart | ToolCallEnd | Done | Error
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool protocol
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@runtime_checkable
class Tool(Protocol):
    """Structural interface for a tool.

    An implementation exposes ``name`` and ``description`` attributes,
    returns a dict from ``input_schema()``, and produces a ToolResult from
    ``run()`` given an input dict and a ToolContext.
    """

    name: str
    description: str

    def input_schema(self) -> dict:
        ...

    def run(self, input: dict, context: ToolContext) -> ToolResult:
        ...
|
||||
|
||||
|
||||
@dataclass
class ToolContext:
    """Runtime context passed to tools — references to live app objects."""

    # Required fields come first: a dataclass rejects a field without a
    # default that follows one with a default.
    session_dir: Path
    frames_dir: Path
    transcript_dir: Path
    stream_mgr: object | None = None  # StreamManager, optional
    tracker: object | None = None  # RecordingTracker, optional
|
||||
|
||||
|
||||
class AgentProvider(ABC):
|
||||
@abstractmethod
|
||||
def stream(self, message: str, context: SessionContext) -> Iterator[str]:
|
||||
"""Yield response text chunks."""
|
||||
# ---------------------------------------------------------------------------
|
||||
# AgentConnection protocol (replaces AgentProvider)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@runtime_checkable
class AgentConnection(Protocol):
    """Structural interface for a model connection (replaces AgentProvider).

    An implementation exposes a ``name`` attribute, model selection methods,
    a streaming ``prompt()`` and a ``cancel()`` method.
    """

    name: str

    def available_models(self) -> list[str]:
        ...

    def get_model(self) -> str:
        ...

    def set_model(self, model: str) -> None:
        ...

    def prompt(
        self,
        messages: list[Message],
        tools: list[Tool],
    ) -> Iterator[StreamEvent]:
        """Send messages to the model, yield stream events."""
        ...

    def cancel(self) -> None:
        ...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def available_models(self) -> list[str]:
|
||||
"""Return list of model IDs this provider supports."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def model(self) -> str:
|
||||
...
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thread serialization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@model.setter
|
||||
@abstractmethod
|
||||
def model(self, value: str):
|
||||
...
|
||||
class _Encoder(json.JSONEncoder):
|
||||
def default(self, o):
|
||||
if isinstance(o, Path):
|
||||
return str(o)
|
||||
if isinstance(o, datetime):
|
||||
return o.isoformat()
|
||||
return super().default(o)
|
||||
|
||||
|
||||
def _block_to_dict(b: ContentBlock) -> dict:
|
||||
if isinstance(b, TextBlock):
|
||||
return {"type": "text", "text": b.text}
|
||||
if isinstance(b, ImageBlock):
|
||||
return {"type": "image", "frame_id": b.frame_id, "path": str(b.path), "timestamp": b.timestamp}
|
||||
if isinstance(b, TranscriptBlock):
|
||||
return {"type": "transcript", "transcript_id": b.transcript_id, "start": b.start, "end": b.end, "text": b.text}
|
||||
raise TypeError(f"Unknown block type: {type(b)}")
|
||||
|
||||
|
||||
def _block_from_dict(d: dict) -> ContentBlock:
|
||||
t = d["type"]
|
||||
if t == "text":
|
||||
return TextBlock(text=d["text"])
|
||||
if t == "image":
|
||||
return ImageBlock(frame_id=d["frame_id"], path=Path(d["path"]), timestamp=d["timestamp"])
|
||||
if t == "transcript":
|
||||
return TranscriptBlock(transcript_id=d["transcript_id"], start=d["start"], end=d["end"], text=d["text"])
|
||||
raise ValueError(f"Unknown block type: {t}")
|
||||
|
||||
|
||||
def _msg_to_dict(m: Message) -> dict:
|
||||
if isinstance(m, UserMessage):
|
||||
return {
|
||||
"type": "user",
|
||||
"id": m.id,
|
||||
"content": [_block_to_dict(b) for b in m.content],
|
||||
"timestamp": m.timestamp.isoformat(),
|
||||
}
|
||||
if isinstance(m, AssistantMessage):
|
||||
d = {
|
||||
"type": "assistant",
|
||||
"id": m.id,
|
||||
"content": [_block_to_dict(b) for b in m.content],
|
||||
"timestamp": m.timestamp.isoformat(),
|
||||
"model": m.model,
|
||||
}
|
||||
if m.token_usage:
|
||||
d["token_usage"] = {"input_tokens": m.token_usage.input_tokens, "output_tokens": m.token_usage.output_tokens}
|
||||
return d
|
||||
if isinstance(m, ToolUse):
|
||||
return {"type": "tool_use", "id": m.id, "tool_name": m.tool_name, "input": m.input, "status": m.status}
|
||||
if isinstance(m, ToolResult):
|
||||
return {"type": "tool_result", "tool_use_id": m.tool_use_id, "output": m.output, "error": m.error}
|
||||
raise TypeError(f"Unknown message type: {type(m)}")
|
||||
|
||||
|
||||
def _msg_from_dict(d: dict) -> Message:
|
||||
t = d["type"]
|
||||
if t == "user":
|
||||
return UserMessage(
|
||||
id=d["id"],
|
||||
content=[_block_from_dict(b) for b in d["content"]],
|
||||
timestamp=datetime.fromisoformat(d["timestamp"]),
|
||||
)
|
||||
if t == "assistant":
|
||||
tu = None
|
||||
if "token_usage" in d and d["token_usage"]:
|
||||
tu = TokenUsage(**d["token_usage"])
|
||||
return AssistantMessage(
|
||||
id=d["id"],
|
||||
content=[_block_from_dict(b) for b in d["content"]],
|
||||
timestamp=datetime.fromisoformat(d["timestamp"]),
|
||||
model=d.get("model", ""),
|
||||
token_usage=tu,
|
||||
)
|
||||
if t == "tool_use":
|
||||
return ToolUse(id=d["id"], tool_name=d["tool_name"], input=d["input"], status=d.get("status", "done"))
|
||||
if t == "tool_result":
|
||||
return ToolResult(tool_use_id=d["tool_use_id"], output=d.get("output"), error=d.get("error"))
|
||||
raise ValueError(f"Unknown message type: {t}")
|
||||
|
||||
|
||||
def thread_to_dict(thread: Thread) -> dict:
    """Convert a thread into a plain dict.

    The result holds ``id``, the converted ``messages`` and ISO-8601
    ``created``/``updated`` strings. ``token_usage`` is only included when
    the thread has usage set.
    """
    payload = dict(
        id=thread.id,
        messages=[_msg_to_dict(msg) for msg in thread.messages],
        created=thread.created.isoformat(),
        updated=thread.updated.isoformat(),
    )
    usage = thread.token_usage
    if usage:
        payload["token_usage"] = dict(
            input_tokens=usage.input_tokens,
            output_tokens=usage.output_tokens,
        )
    return payload
|
||||
|
||||
|
||||
def thread_from_dict(d: dict) -> Thread:
    """Rebuild a Thread from the dict form produced by thread_to_dict.

    A missing or empty ``"token_usage"`` entry yields ``token_usage=None``.

    Raises:
        KeyError: if ``"id"``, ``"messages"``, ``"created"`` or ``"updated"``
            is missing.
    """
    raw_usage = d.get("token_usage")
    usage = TokenUsage(**raw_usage) if raw_usage else None
    return Thread(
        id=d["id"],
        messages=[_msg_from_dict(raw) for raw in d["messages"]],
        created=datetime.fromisoformat(d["created"]),
        updated=datetime.fromisoformat(d["updated"]),
        token_usage=usage,
    )
|
||||
|
||||
|
||||
def save_thread(thread: Thread, session_dir: Path) -> None:
    """Write ``thread`` to ``<session_dir>/agent/thread.json``.

    The ``agent`` directory is created if needed and ``thread.updated`` is
    set to the current UTC time before serializing.

    The JSON goes to a sibling ``thread.json.tmp`` file first and is then
    renamed over the target, so an interrupted write cannot leave a
    truncated ``thread.json`` behind.
    """
    agent_dir = session_dir / "agent"
    agent_dir.mkdir(parents=True, exist_ok=True)
    path = agent_dir / "thread.json"
    thread.updated = datetime.now(timezone.utc)
    payload = json.dumps(thread_to_dict(thread), indent=2, cls=_Encoder)
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(payload, encoding="utf-8")
    # Path.replace overwrites an existing destination.
    tmp_path.replace(path)
|
||||
|
||||
|
||||
def load_thread(session_dir: Path) -> Thread | None:
    """Load the thread stored at ``<session_dir>/agent/thread.json``.

    Returns:
        The deserialized Thread, or ``None`` when the file does not exist or
        cannot be read or parsed.

    Loading stays best-effort (no exception propagates), but a failure is
    logged as a warning with its traceback rather than silently discarded.
    """
    import logging

    path = session_dir / "agent" / "thread.json"
    if not path.exists():
        return None
    try:
        # json.dumps escapes non-ASCII by default, so UTF-8 decoding is safe
        # for files written by save_thread.
        return thread_from_dict(json.loads(path.read_text(encoding="utf-8")))
    except Exception:
        logging.getLogger(__name__).warning(
            "Could not load thread from %s", path, exc_info=True
        )
        return None
|
||||
|
||||
Reference in New Issue
Block a user