extract parser.py and tool_runner.py from agent code

This commit is contained in:
2026-04-16 02:37:19 -03:00
parent 6856d09986
commit 69441fa180
5 changed files with 170 additions and 159 deletions

View File

@@ -12,6 +12,7 @@ from datetime import datetime, timezone
from typing import Any
from agents.shared.mcp_client import MCPMultiClient
from agents.shared.tool_runner import build_tool_caller
async def run_fce(
@@ -65,40 +66,7 @@ async def run_fce(
origin = flight_status.get("origin", "")
destination = flight_status.get("destination", "")
async def _call(server, tool, args, is_live=False, timeout=15.0):
t = time.time()
ctx = lf.start_as_current_observation(
name=tool, as_type="tool", input=args,
metadata={"server": server, "is_live": is_live},
) if lf else None
if ctx:
ctx.__enter__()
try:
result = await asyncio.wait_for(
mcp.call_tool(server, tool, args), timeout=timeout,
)
lat = int((time.time() - t) * 1000)
if ctx:
lf.update_current_span(output=result, metadata={"latency_ms": lat})
ctx.__exit__(None, None, None)
await emit("tool_call_end", tool=tool, latency_ms=lat, is_live=is_live)
return result
except asyncio.TimeoutError:
lat = int((time.time() - t) * 1000)
if ctx:
lf.update_current_span(output={"error": "timeout"}, level="ERROR")
ctx.__exit__(None, None, None)
await emit("tool_call_error", tool=tool, error="timeout", latency_ms=lat)
errors.append(f"{tool}: timeout after {timeout}s")
return None
except Exception as e:
lat = int((time.time() - t) * 1000)
if ctx:
lf.update_current_span(output={"error": str(e)}, level="ERROR")
ctx.__exit__(None, None, None)
await emit("tool_call_error", tool=tool, error=str(e), latency_ms=lat)
errors.append(f"{tool}: {e}")
return None
_call = build_tool_caller(mcp, emit=emit, errors=errors, lf=lf)
# Fire all independent calls in parallel
ops_data_task = asyncio.create_task(

View File

@@ -11,6 +11,7 @@ from datetime import datetime, timezone
from typing import Any
from agents.shared.mcp_client import MCPMultiClient
from agents.shared.tool_runner import build_tool_caller
ALL_HUBS = ["ORD", "EWR", "IAH", "SFO", "DEN"]
@@ -43,40 +44,7 @@ async def run_handover(
await emit("node_enter", node="gather_all")
async def _call(server, tool, args, is_live=False, timeout=15.0):
t = time.time()
ctx = lf.start_as_current_observation(
name=tool, as_type="tool", input=args,
metadata={"server": server, "is_live": is_live},
) if lf else None
if ctx:
ctx.__enter__()
try:
result = await asyncio.wait_for(
mcp.call_tool(server, tool, args), timeout=timeout,
)
lat = int((time.time() - t) * 1000)
if ctx:
lf.update_current_span(output=result, metadata={"latency_ms": lat})
ctx.__exit__(None, None, None)
await emit("tool_call_end", tool=tool, latency_ms=lat, is_live=is_live)
return result
except asyncio.TimeoutError:
lat = int((time.time() - t) * 1000)
if ctx:
lf.update_current_span(output={"error": "timeout"}, level="ERROR")
ctx.__exit__(None, None, None)
await emit("tool_call_error", tool=tool, error="timeout", latency_ms=lat)
errors.append(f"{tool}: timeout after {timeout}s")
return None
except Exception as e:
lat = int((time.time() - t) * 1000)
if ctx:
lf.update_current_span(output={"error": str(e)}, level="ERROR")
ctx.__exit__(None, None, None)
await emit("tool_call_error", tool=tool, error=str(e), latency_ms=lat)
errors.append(f"{tool}: {e}")
return None
_call = build_tool_caller(mcp, emit=emit, errors=errors, lf=lf)
# Per-hub calls
hub_tasks = {}

View File

@@ -4,42 +4,17 @@ Composes the three domain-scoped MCP servers into namespaced configurations
that agents connect to as a single client.
"""
import json
import os
from contextlib import asynccontextmanager
from typing import Any
from fastmcp import Client
def _env() -> dict:
"""Forward LLM-related env vars and active scenario to MCP server subprocesses."""
import os
from mcp_servers.data.scenarios.manager import scenario_manager
env = {}
for key in (
"LLM_PROVIDER", "GROQ_API_KEY", "GROQ_MODEL",
"ANTHROPIC_API_KEY", "ANTHROPIC_MODEL",
"OPENAI_API_KEY", "OPENAI_BASE_URL", "OPENAI_MODEL",
"AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION",
"BEDROCK_MODEL_ID", "USE_BEDROCK",
"PATH",
):
val = os.getenv(key)
if val:
env[key] = val
env["ACTIVE_SCENARIO"] = scenario_manager.active_id
return env
def _server_config(module: str) -> dict:
"""Build server config with current env vars (called at connect time, not import time)."""
return {
"command": "uv",
"args": ["run", "python", "-m", module],
"env": _env(),
}
from agents.shared.parser import (
parse_prompt_result,
parse_resource_result,
parse_tool_result,
)
SERVER_MODULES = {
"shared": "mcp_servers.shared",
@@ -47,12 +22,33 @@ SERVER_MODULES = {
"passenger": "mcp_servers.passenger",
}
# Agent profiles — which servers each agent connects to
AGENT_PROFILES = {
"fce": ["shared", "ops", "passenger"],
"handover": ["shared", "ops"],
}
_FORWARDED_ENV_KEYS = (
"LLM_PROVIDER", "GROQ_API_KEY", "GROQ_MODEL",
"ANTHROPIC_API_KEY", "ANTHROPIC_MODEL",
"OPENAI_API_KEY", "OPENAI_BASE_URL", "OPENAI_MODEL",
"AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION",
"BEDROCK_MODEL_ID", "USE_BEDROCK",
"PATH",
)
def _subprocess_env() -> dict:
"""Env vars forwarded to MCP server subprocesses (LLM config + active scenario)."""
from mcp_servers.data.scenarios.manager import scenario_manager
env = {k: os.environ[k] for k in _FORWARDED_ENV_KEYS if k in os.environ}
env["ACTIVE_SCENARIO"] = scenario_manager.active_id
return env
def _server_config(module: str) -> dict:
return {"command": "uv", "args": ["run", "python", "-m", module], "env": _subprocess_env()}
class MCPMultiClient:
"""Manages connections to multiple MCP servers via fastmcp Client."""
@@ -61,7 +57,6 @@ class MCPMultiClient:
self._clients: dict[str, Client] = {}
async def connect(self, server_names: list[str]) -> None:
"""Connect to the specified MCP servers."""
for name in server_names:
if name not in SERVER_MODULES:
raise ValueError(f"Unknown server: {name}. Available: {list(SERVER_MODULES.keys())}")
@@ -71,7 +66,6 @@ class MCPMultiClient:
self._clients[name] = client
async def close(self) -> None:
"""Close all server connections."""
for client in self._clients.values():
try:
await client.__aexit__(None, None, None)
@@ -79,71 +73,23 @@ class MCPMultiClient:
pass
self._clients.clear()
async def call_tool(self, server: str, tool_name: str, arguments: dict) -> Any:
"""Call a tool on a specific server. Returns parsed result."""
def _client(self, server: str) -> Client:
client = self._clients.get(server)
if not client:
raise ValueError(f"Not connected to server: {server}")
return client
result = await client.call_tool(tool_name, arguments)
# Parse the result content
if isinstance(result, list):
texts = [c.text for c in result if hasattr(c, "text")]
elif hasattr(result, "content"):
texts = [c.text for c in result.content if hasattr(c, "text")]
else:
return result
if len(texts) == 1:
try:
return json.loads(texts[0])
except (json.JSONDecodeError, TypeError):
return texts[0]
elif len(texts) > 1:
parsed = []
for t in texts:
try:
parsed.append(json.loads(t))
except (json.JSONDecodeError, TypeError):
parsed.append(t)
return parsed
return None
async def call_tool(self, server: str, tool_name: str, arguments: dict) -> Any:
result = await self._client(server).call_tool(tool_name, arguments)
return parse_tool_result(result)
async def read_resource(self, server: str, uri: str) -> Any:
"""Read a resource from a specific server."""
client = self._clients.get(server)
if not client:
raise ValueError(f"Not connected to server: {server}")
result = await client.read_resource(uri)
if isinstance(result, str):
try:
return json.loads(result)
except (json.JSONDecodeError, TypeError):
return result
return result
result = await self._client(server).read_resource(uri)
return parse_resource_result(result)
async def get_prompt(self, server: str, prompt_name: str, arguments: dict) -> str:
"""Get a rendered prompt from a specific server."""
client = self._clients.get(server)
if not client:
raise ValueError(f"Not connected to server: {server}")
result = await client.get_prompt(prompt_name, arguments)
if isinstance(result, str):
return result
# Handle structured prompt response
texts = []
if hasattr(result, "messages"):
for msg in result.messages:
if hasattr(msg.content, "text"):
texts.append(msg.content.text)
elif isinstance(msg.content, list):
for c in msg.content:
if hasattr(c, "text"):
texts.append(c.text)
return "\n".join(texts) if texts else str(result)
result = await self._client(server).get_prompt(prompt_name, arguments)
return parse_prompt_result(result)
@asynccontextmanager

56
agents/shared/parser.py Normal file
View File

@@ -0,0 +1,56 @@
"""Parsers for MCP response content (tools, resources, prompts)."""
import json
from typing import Any
def _extract_texts(result: Any) -> list[str]:
"""Pull text fields out of an MCP response, handling list/content/str shapes."""
if isinstance(result, list):
return [c.text for c in result if hasattr(c, "text")]
if hasattr(result, "content"):
return [c.text for c in result.content if hasattr(c, "text")]
return []
def _maybe_json(text: str) -> Any:
"""Parse JSON if possible; fall back to raw text."""
try:
return json.loads(text)
except (json.JSONDecodeError, TypeError):
return text
def parse_tool_result(result: Any) -> Any:
"""Parse `client.call_tool()` result into native Python structure."""
texts = _extract_texts(result)
if not texts:
return result
if len(texts) == 1:
return _maybe_json(texts[0])
return [_maybe_json(t) for t in texts]
def parse_resource_result(result: Any) -> Any:
"""Parse `client.read_resource()` result — often a JSON-encoded string."""
if isinstance(result, str):
return _maybe_json(result)
return result
def parse_prompt_result(result: Any) -> str:
"""Flatten a prompt response into a single string."""
if isinstance(result, str):
return result
if not hasattr(result, "messages"):
return str(result)
texts: list[str] = []
for msg in result.messages:
if hasattr(msg.content, "text"):
texts.append(msg.content.text)
elif isinstance(msg.content, list):
for c in msg.content:
if hasattr(c, "text"):
texts.append(c.text)
return "\n".join(texts) if texts else str(result)

View File

@@ -0,0 +1,73 @@
"""Shared tool-call helper: timeout, tracing, event emission, error collection.
Every agent invokes MCP tools through this so the three concerns
(observability, resilience, error state) stay in one place.
"""
import asyncio
import time
from typing import Any, Awaitable, Callable
EmitFn = Callable[..., Awaitable[None]]
def build_tool_caller(
mcp,
*,
emit: EmitFn,
errors: list[str],
lf: Any = None,
timeout: float = 15.0,
):
"""Return an async `_call(server, tool, args, is_live=False)` closure.
On each invocation:
- opens a Langfuse span (if `lf` is provided)
- runs `mcp.call_tool` with an `asyncio.wait_for` timeout
- emits `tool_call_end` / `tool_call_error` events
- appends to `errors` on failure, returns `None`
- closes the Langfuse span with output or error level
"""
async def _call(server: str, tool: str, args: dict, is_live: bool = False) -> Any:
t0 = time.time()
span_ctx = lf.start_as_current_observation(
name=tool, as_type="tool", input=args,
metadata={"server": server, "is_live": is_live},
) if lf else None
if span_ctx:
span_ctx.__enter__()
def _close_span(output: Any, level: str | None = None):
if not span_ctx:
return
kwargs: dict[str, Any] = {"output": output}
if level:
kwargs["level"] = level
else:
kwargs["metadata"] = {"latency_ms": int((time.time() - t0) * 1000)}
lf.update_current_span(**kwargs)
span_ctx.__exit__(None, None, None)
try:
result = await asyncio.wait_for(mcp.call_tool(server, tool, args), timeout=timeout)
except asyncio.TimeoutError:
latency = int((time.time() - t0) * 1000)
_close_span({"error": "timeout"}, level="ERROR")
await emit("tool_call_error", tool=tool, error="timeout", latency_ms=latency)
errors.append(f"{tool}: timeout after {timeout}s")
return None
except Exception as e:
latency = int((time.time() - t0) * 1000)
_close_span({"error": str(e)}, level="ERROR")
await emit("tool_call_error", tool=tool, error=str(e), latency_ms=latency)
errors.append(f"{tool}: {e}")
return None
latency = int((time.time() - t0) * 1000)
_close_span(result)
await emit("tool_call_end", tool=tool, latency_ms=latency, is_live=is_live)
return result
return _call