extract parser.py and tool_runner.py from agent code

This commit is contained in:
2026-04-16 02:37:19 -03:00
parent 6856d09986
commit 69441fa180
5 changed files with 170 additions and 159 deletions

View File

@@ -12,6 +12,7 @@ from datetime import datetime, timezone
from typing import Any from typing import Any
from agents.shared.mcp_client import MCPMultiClient from agents.shared.mcp_client import MCPMultiClient
from agents.shared.tool_runner import build_tool_caller
async def run_fce( async def run_fce(
@@ -65,40 +66,7 @@ async def run_fce(
origin = flight_status.get("origin", "") origin = flight_status.get("origin", "")
destination = flight_status.get("destination", "") destination = flight_status.get("destination", "")
async def _call(server, tool, args, is_live=False, timeout=15.0): _call = build_tool_caller(mcp, emit=emit, errors=errors, lf=lf)
t = time.time()
ctx = lf.start_as_current_observation(
name=tool, as_type="tool", input=args,
metadata={"server": server, "is_live": is_live},
) if lf else None
if ctx:
ctx.__enter__()
try:
result = await asyncio.wait_for(
mcp.call_tool(server, tool, args), timeout=timeout,
)
lat = int((time.time() - t) * 1000)
if ctx:
lf.update_current_span(output=result, metadata={"latency_ms": lat})
ctx.__exit__(None, None, None)
await emit("tool_call_end", tool=tool, latency_ms=lat, is_live=is_live)
return result
except asyncio.TimeoutError:
lat = int((time.time() - t) * 1000)
if ctx:
lf.update_current_span(output={"error": "timeout"}, level="ERROR")
ctx.__exit__(None, None, None)
await emit("tool_call_error", tool=tool, error="timeout", latency_ms=lat)
errors.append(f"{tool}: timeout after {timeout}s")
return None
except Exception as e:
lat = int((time.time() - t) * 1000)
if ctx:
lf.update_current_span(output={"error": str(e)}, level="ERROR")
ctx.__exit__(None, None, None)
await emit("tool_call_error", tool=tool, error=str(e), latency_ms=lat)
errors.append(f"{tool}: {e}")
return None
# Fire all independent calls in parallel # Fire all independent calls in parallel
ops_data_task = asyncio.create_task( ops_data_task = asyncio.create_task(

View File

@@ -11,6 +11,7 @@ from datetime import datetime, timezone
from typing import Any from typing import Any
from agents.shared.mcp_client import MCPMultiClient from agents.shared.mcp_client import MCPMultiClient
from agents.shared.tool_runner import build_tool_caller
ALL_HUBS = ["ORD", "EWR", "IAH", "SFO", "DEN"] ALL_HUBS = ["ORD", "EWR", "IAH", "SFO", "DEN"]
@@ -43,40 +44,7 @@ async def run_handover(
await emit("node_enter", node="gather_all") await emit("node_enter", node="gather_all")
async def _call(server, tool, args, is_live=False, timeout=15.0): _call = build_tool_caller(mcp, emit=emit, errors=errors, lf=lf)
t = time.time()
ctx = lf.start_as_current_observation(
name=tool, as_type="tool", input=args,
metadata={"server": server, "is_live": is_live},
) if lf else None
if ctx:
ctx.__enter__()
try:
result = await asyncio.wait_for(
mcp.call_tool(server, tool, args), timeout=timeout,
)
lat = int((time.time() - t) * 1000)
if ctx:
lf.update_current_span(output=result, metadata={"latency_ms": lat})
ctx.__exit__(None, None, None)
await emit("tool_call_end", tool=tool, latency_ms=lat, is_live=is_live)
return result
except asyncio.TimeoutError:
lat = int((time.time() - t) * 1000)
if ctx:
lf.update_current_span(output={"error": "timeout"}, level="ERROR")
ctx.__exit__(None, None, None)
await emit("tool_call_error", tool=tool, error="timeout", latency_ms=lat)
errors.append(f"{tool}: timeout after {timeout}s")
return None
except Exception as e:
lat = int((time.time() - t) * 1000)
if ctx:
lf.update_current_span(output={"error": str(e)}, level="ERROR")
ctx.__exit__(None, None, None)
await emit("tool_call_error", tool=tool, error=str(e), latency_ms=lat)
errors.append(f"{tool}: {e}")
return None
# Per-hub calls # Per-hub calls
hub_tasks = {} hub_tasks = {}

View File

@@ -4,42 +4,17 @@ Composes the three domain-scoped MCP servers into namespaced configurations
that agents connect to as a single client. that agents connect to as a single client.
""" """
import json import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any from typing import Any
from fastmcp import Client from fastmcp import Client
from agents.shared.parser import (
def _env() -> dict: parse_prompt_result,
"""Forward LLM-related env vars and active scenario to MCP server subprocesses.""" parse_resource_result,
import os parse_tool_result,
from mcp_servers.data.scenarios.manager import scenario_manager )
env = {}
for key in (
"LLM_PROVIDER", "GROQ_API_KEY", "GROQ_MODEL",
"ANTHROPIC_API_KEY", "ANTHROPIC_MODEL",
"OPENAI_API_KEY", "OPENAI_BASE_URL", "OPENAI_MODEL",
"AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION",
"BEDROCK_MODEL_ID", "USE_BEDROCK",
"PATH",
):
val = os.getenv(key)
if val:
env[key] = val
env["ACTIVE_SCENARIO"] = scenario_manager.active_id
return env
def _server_config(module: str) -> dict:
"""Build server config with current env vars (called at connect time, not import time)."""
return {
"command": "uv",
"args": ["run", "python", "-m", module],
"env": _env(),
}
SERVER_MODULES = { SERVER_MODULES = {
"shared": "mcp_servers.shared", "shared": "mcp_servers.shared",
@@ -47,12 +22,33 @@ SERVER_MODULES = {
"passenger": "mcp_servers.passenger", "passenger": "mcp_servers.passenger",
} }
# Agent profiles — which servers each agent connects to
AGENT_PROFILES = { AGENT_PROFILES = {
"fce": ["shared", "ops", "passenger"], "fce": ["shared", "ops", "passenger"],
"handover": ["shared", "ops"], "handover": ["shared", "ops"],
} }
_FORWARDED_ENV_KEYS = (
"LLM_PROVIDER", "GROQ_API_KEY", "GROQ_MODEL",
"ANTHROPIC_API_KEY", "ANTHROPIC_MODEL",
"OPENAI_API_KEY", "OPENAI_BASE_URL", "OPENAI_MODEL",
"AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION",
"BEDROCK_MODEL_ID", "USE_BEDROCK",
"PATH",
)
def _subprocess_env() -> dict:
"""Env vars forwarded to MCP server subprocesses (LLM config + active scenario)."""
from mcp_servers.data.scenarios.manager import scenario_manager
env = {k: os.environ[k] for k in _FORWARDED_ENV_KEYS if k in os.environ}
env["ACTIVE_SCENARIO"] = scenario_manager.active_id
return env
def _server_config(module: str) -> dict:
return {"command": "uv", "args": ["run", "python", "-m", module], "env": _subprocess_env()}
class MCPMultiClient: class MCPMultiClient:
"""Manages connections to multiple MCP servers via fastmcp Client.""" """Manages connections to multiple MCP servers via fastmcp Client."""
@@ -61,7 +57,6 @@ class MCPMultiClient:
self._clients: dict[str, Client] = {} self._clients: dict[str, Client] = {}
async def connect(self, server_names: list[str]) -> None: async def connect(self, server_names: list[str]) -> None:
"""Connect to the specified MCP servers."""
for name in server_names: for name in server_names:
if name not in SERVER_MODULES: if name not in SERVER_MODULES:
raise ValueError(f"Unknown server: {name}. Available: {list(SERVER_MODULES.keys())}") raise ValueError(f"Unknown server: {name}. Available: {list(SERVER_MODULES.keys())}")
@@ -71,7 +66,6 @@ class MCPMultiClient:
self._clients[name] = client self._clients[name] = client
async def close(self) -> None: async def close(self) -> None:
"""Close all server connections."""
for client in self._clients.values(): for client in self._clients.values():
try: try:
await client.__aexit__(None, None, None) await client.__aexit__(None, None, None)
@@ -79,71 +73,23 @@ class MCPMultiClient:
pass pass
self._clients.clear() self._clients.clear()
async def call_tool(self, server: str, tool_name: str, arguments: dict) -> Any: def _client(self, server: str) -> Client:
"""Call a tool on a specific server. Returns parsed result."""
client = self._clients.get(server) client = self._clients.get(server)
if not client: if not client:
raise ValueError(f"Not connected to server: {server}") raise ValueError(f"Not connected to server: {server}")
return client
result = await client.call_tool(tool_name, arguments) async def call_tool(self, server: str, tool_name: str, arguments: dict) -> Any:
result = await self._client(server).call_tool(tool_name, arguments)
# Parse the result content return parse_tool_result(result)
if isinstance(result, list):
texts = [c.text for c in result if hasattr(c, "text")]
elif hasattr(result, "content"):
texts = [c.text for c in result.content if hasattr(c, "text")]
else:
return result
if len(texts) == 1:
try:
return json.loads(texts[0])
except (json.JSONDecodeError, TypeError):
return texts[0]
elif len(texts) > 1:
parsed = []
for t in texts:
try:
parsed.append(json.loads(t))
except (json.JSONDecodeError, TypeError):
parsed.append(t)
return parsed
return None
async def read_resource(self, server: str, uri: str) -> Any: async def read_resource(self, server: str, uri: str) -> Any:
"""Read a resource from a specific server.""" result = await self._client(server).read_resource(uri)
client = self._clients.get(server) return parse_resource_result(result)
if not client:
raise ValueError(f"Not connected to server: {server}")
result = await client.read_resource(uri)
if isinstance(result, str):
try:
return json.loads(result)
except (json.JSONDecodeError, TypeError):
return result
return result
async def get_prompt(self, server: str, prompt_name: str, arguments: dict) -> str: async def get_prompt(self, server: str, prompt_name: str, arguments: dict) -> str:
"""Get a rendered prompt from a specific server.""" result = await self._client(server).get_prompt(prompt_name, arguments)
client = self._clients.get(server) return parse_prompt_result(result)
if not client:
raise ValueError(f"Not connected to server: {server}")
result = await client.get_prompt(prompt_name, arguments)
if isinstance(result, str):
return result
# Handle structured prompt response
texts = []
if hasattr(result, "messages"):
for msg in result.messages:
if hasattr(msg.content, "text"):
texts.append(msg.content.text)
elif isinstance(msg.content, list):
for c in msg.content:
if hasattr(c, "text"):
texts.append(c.text)
return "\n".join(texts) if texts else str(result)
@asynccontextmanager @asynccontextmanager

56
agents/shared/parser.py Normal file
View File

@@ -0,0 +1,56 @@
"""Parsers for MCP response content (tools, resources, prompts)."""
import json
from typing import Any
def _extract_texts(result: Any) -> list[str]:
"""Pull text fields out of an MCP response, handling list/content/str shapes."""
if isinstance(result, list):
return [c.text for c in result if hasattr(c, "text")]
if hasattr(result, "content"):
return [c.text for c in result.content if hasattr(c, "text")]
return []
def _maybe_json(text: str) -> Any:
"""Parse JSON if possible; fall back to raw text."""
try:
return json.loads(text)
except (json.JSONDecodeError, TypeError):
return text
def parse_tool_result(result: Any) -> Any:
"""Parse `client.call_tool()` result into native Python structure."""
texts = _extract_texts(result)
if not texts:
return result
if len(texts) == 1:
return _maybe_json(texts[0])
return [_maybe_json(t) for t in texts]
def parse_resource_result(result: Any) -> Any:
"""Parse `client.read_resource()` result — often a JSON-encoded string."""
if isinstance(result, str):
return _maybe_json(result)
return result
def parse_prompt_result(result: Any) -> str:
"""Flatten a prompt response into a single string."""
if isinstance(result, str):
return result
if not hasattr(result, "messages"):
return str(result)
texts: list[str] = []
for msg in result.messages:
if hasattr(msg.content, "text"):
texts.append(msg.content.text)
elif isinstance(msg.content, list):
for c in msg.content:
if hasattr(c, "text"):
texts.append(c.text)
return "\n".join(texts) if texts else str(result)

View File

@@ -0,0 +1,73 @@
"""Shared tool-call helper: timeout, tracing, event emission, error collection.
Every agent invokes MCP tools through this so the three concerns
(observability, resilience, error state) stay in one place.
"""
import asyncio
import time
from typing import Any, Awaitable, Callable
EmitFn = Callable[..., Awaitable[None]]
def build_tool_caller(
mcp,
*,
emit: EmitFn,
errors: list[str],
lf: Any = None,
timeout: float = 15.0,
):
"""Return an async `_call(server, tool, args, is_live=False)` closure.
On each invocation:
- opens a Langfuse span (if `lf` is provided)
- runs `mcp.call_tool` with an `asyncio.wait_for` timeout
- emits `tool_call_end` / `tool_call_error` events
- appends to `errors` on failure, returns `None`
- closes the Langfuse span with output or error level
"""
async def _call(server: str, tool: str, args: dict, is_live: bool = False) -> Any:
t0 = time.time()
span_ctx = lf.start_as_current_observation(
name=tool, as_type="tool", input=args,
metadata={"server": server, "is_live": is_live},
) if lf else None
if span_ctx:
span_ctx.__enter__()
def _close_span(output: Any, level: str | None = None):
if not span_ctx:
return
kwargs: dict[str, Any] = {"output": output}
if level:
kwargs["level"] = level
else:
kwargs["metadata"] = {"latency_ms": int((time.time() - t0) * 1000)}
lf.update_current_span(**kwargs)
span_ctx.__exit__(None, None, None)
try:
result = await asyncio.wait_for(mcp.call_tool(server, tool, args), timeout=timeout)
except asyncio.TimeoutError:
latency = int((time.time() - t0) * 1000)
_close_span({"error": "timeout"}, level="ERROR")
await emit("tool_call_error", tool=tool, error="timeout", latency_ms=latency)
errors.append(f"{tool}: timeout after {timeout}s")
return None
except Exception as e:
latency = int((time.time() - t0) * 1000)
_close_span({"error": str(e)}, level="ERROR")
await emit("tool_call_error", tool=tool, error=str(e), latency_ms=latency)
errors.append(f"{tool}: {e}")
return None
latency = int((time.time() - t0) * 1000)
_close_span(result)
await emit("tool_call_end", tool=tool, latency_ms=latency, is_live=is_live)
return result
return _call