extract parser.py and tool_runner.py from agent code
@@ -12,6 +12,7 @@ from datetime import datetime, timezone
 from typing import Any
 
 from agents.shared.mcp_client import MCPMultiClient
+from agents.shared.tool_runner import build_tool_caller
 
 
 async def run_fce(
@@ -65,40 +66,7 @@ async def run_fce(
     origin = flight_status.get("origin", "")
     destination = flight_status.get("destination", "")
 
-    async def _call(server, tool, args, is_live=False, timeout=15.0):
-        t = time.time()
-        ctx = lf.start_as_current_observation(
-            name=tool, as_type="tool", input=args,
-            metadata={"server": server, "is_live": is_live},
-        ) if lf else None
-        if ctx:
-            ctx.__enter__()
-        try:
-            result = await asyncio.wait_for(
-                mcp.call_tool(server, tool, args), timeout=timeout,
-            )
-            lat = int((time.time() - t) * 1000)
-            if ctx:
-                lf.update_current_span(output=result, metadata={"latency_ms": lat})
-                ctx.__exit__(None, None, None)
-            await emit("tool_call_end", tool=tool, latency_ms=lat, is_live=is_live)
-            return result
-        except asyncio.TimeoutError:
-            lat = int((time.time() - t) * 1000)
-            if ctx:
-                lf.update_current_span(output={"error": "timeout"}, level="ERROR")
-                ctx.__exit__(None, None, None)
-            await emit("tool_call_error", tool=tool, error="timeout", latency_ms=lat)
-            errors.append(f"{tool}: timeout after {timeout}s")
-            return None
-        except Exception as e:
-            lat = int((time.time() - t) * 1000)
-            if ctx:
-                lf.update_current_span(output={"error": str(e)}, level="ERROR")
-                ctx.__exit__(None, None, None)
-            await emit("tool_call_error", tool=tool, error=str(e), latency_ms=lat)
-            errors.append(f"{tool}: {e}")
-            return None
+    _call = build_tool_caller(mcp, emit=emit, errors=errors, lf=lf)
 
     # Fire all independent calls in parallel
     ops_data_task = asyncio.create_task(
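Aside: the closure returned by build_tool_caller keeps the old inline `_call` signature, so the call sites below it need no edits. A hypothetical invocation (tool and argument names are illustrative, not taken from this commit):

    weather = await _call("ops", "get_hub_weather", {"hub": "ORD"}, is_live=True)
    if weather is None:
        ...  # failure already appended to `errors` and emitted as tool_call_error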
@@ -11,6 +11,7 @@ from datetime import datetime, timezone
 from typing import Any
 
 from agents.shared.mcp_client import MCPMultiClient
+from agents.shared.tool_runner import build_tool_caller
 
 ALL_HUBS = ["ORD", "EWR", "IAH", "SFO", "DEN"]
 
@@ -43,40 +44,7 @@ async def run_handover(
 
     await emit("node_enter", node="gather_all")
 
-    async def _call(server, tool, args, is_live=False, timeout=15.0):
-        t = time.time()
-        ctx = lf.start_as_current_observation(
-            name=tool, as_type="tool", input=args,
-            metadata={"server": server, "is_live": is_live},
-        ) if lf else None
-        if ctx:
-            ctx.__enter__()
-        try:
-            result = await asyncio.wait_for(
-                mcp.call_tool(server, tool, args), timeout=timeout,
-            )
-            lat = int((time.time() - t) * 1000)
-            if ctx:
-                lf.update_current_span(output=result, metadata={"latency_ms": lat})
-                ctx.__exit__(None, None, None)
-            await emit("tool_call_end", tool=tool, latency_ms=lat, is_live=is_live)
-            return result
-        except asyncio.TimeoutError:
-            lat = int((time.time() - t) * 1000)
-            if ctx:
-                lf.update_current_span(output={"error": "timeout"}, level="ERROR")
-                ctx.__exit__(None, None, None)
-            await emit("tool_call_error", tool=tool, error="timeout", latency_ms=lat)
-            errors.append(f"{tool}: timeout after {timeout}s")
-            return None
-        except Exception as e:
-            lat = int((time.time() - t) * 1000)
-            if ctx:
-                lf.update_current_span(output={"error": str(e)}, level="ERROR")
-                ctx.__exit__(None, None, None)
-            await emit("tool_call_error", tool=tool, error=str(e), latency_ms=lat)
-            errors.append(f"{tool}: {e}")
-            return None
+    _call = build_tool_caller(mcp, emit=emit, errors=errors, lf=lf)
 
     # Per-hub calls
     hub_tasks = {}
@@ -4,42 +4,17 @@ Composes the three domain-scoped MCP servers into namespaced configurations
 that agents connect to as a single client.
 """
 
-import json
+import os
 from contextlib import asynccontextmanager
 from typing import Any
 
 from fastmcp import Client
 
-def _env() -> dict:
-    """Forward LLM-related env vars and active scenario to MCP server subprocesses."""
-    import os
-    from mcp_servers.data.scenarios.manager import scenario_manager
-
-    env = {}
-    for key in (
-        "LLM_PROVIDER", "GROQ_API_KEY", "GROQ_MODEL",
-        "ANTHROPIC_API_KEY", "ANTHROPIC_MODEL",
-        "OPENAI_API_KEY", "OPENAI_BASE_URL", "OPENAI_MODEL",
-        "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION",
-        "BEDROCK_MODEL_ID", "USE_BEDROCK",
-        "PATH",
-    ):
-        val = os.getenv(key)
-        if val:
-            env[key] = val
-    env["ACTIVE_SCENARIO"] = scenario_manager.active_id
-    return env
-
-
-def _server_config(module: str) -> dict:
-    """Build server config with current env vars (called at connect time, not import time)."""
-    return {
-        "command": "uv",
-        "args": ["run", "python", "-m", module],
-        "env": _env(),
-    }
+from agents.shared.parser import (
+    parse_prompt_result,
+    parse_resource_result,
+    parse_tool_result,
+)
 
 
 SERVER_MODULES = {
     "shared": "mcp_servers.shared",
@@ -47,12 +22,33 @@ SERVER_MODULES = {
     "passenger": "mcp_servers.passenger",
 }
 
-# Agent profiles — which servers each agent connects to
 AGENT_PROFILES = {
     "fce": ["shared", "ops", "passenger"],
     "handover": ["shared", "ops"],
 }
 
+_FORWARDED_ENV_KEYS = (
+    "LLM_PROVIDER", "GROQ_API_KEY", "GROQ_MODEL",
+    "ANTHROPIC_API_KEY", "ANTHROPIC_MODEL",
+    "OPENAI_API_KEY", "OPENAI_BASE_URL", "OPENAI_MODEL",
+    "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION",
+    "BEDROCK_MODEL_ID", "USE_BEDROCK",
+    "PATH",
+)
+
+
+def _subprocess_env() -> dict:
+    """Env vars forwarded to MCP server subprocesses (LLM config + active scenario)."""
+    from mcp_servers.data.scenarios.manager import scenario_manager
+
+    env = {k: os.environ[k] for k in _FORWARDED_ENV_KEYS if k in os.environ}
+    env["ACTIVE_SCENARIO"] = scenario_manager.active_id
+    return env
+
+
+def _server_config(module: str) -> dict:
+    return {"command": "uv", "args": ["run", "python", "-m", module], "env": _subprocess_env()}
+
+
 class MCPMultiClient:
     """Manages connections to multiple MCP servers via fastmcp Client."""
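For reference, a sketch of what `_server_config` now returns at connect time, assuming only GROQ_API_KEY is set among the forwarded keys and the active scenario id is "default" (values illustrative):

    cfg = _server_config("mcp_servers.ops")
    # {
    #     "command": "uv",
    #     "args": ["run", "python", "-m", "mcp_servers.ops"],
    #     "env": {"GROQ_API_KEY": "gsk_...", "ACTIVE_SCENARIO": "default"},
    # }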
@@ -61,7 +57,6 @@ class MCPMultiClient:
         self._clients: dict[str, Client] = {}
 
     async def connect(self, server_names: list[str]) -> None:
-        """Connect to the specified MCP servers."""
         for name in server_names:
             if name not in SERVER_MODULES:
                 raise ValueError(f"Unknown server: {name}. Available: {list(SERVER_MODULES.keys())}")
@@ -71,7 +66,6 @@ class MCPMultiClient:
         self._clients[name] = client
 
     async def close(self) -> None:
-        """Close all server connections."""
         for client in self._clients.values():
             try:
                 await client.__aexit__(None, None, None)
@@ -79,71 +73,23 @@ class MCPMultiClient:
                 pass
         self._clients.clear()
 
-    async def call_tool(self, server: str, tool_name: str, arguments: dict) -> Any:
-        """Call a tool on a specific server. Returns parsed result."""
+    def _client(self, server: str) -> Client:
         client = self._clients.get(server)
         if not client:
             raise ValueError(f"Not connected to server: {server}")
+        return client
 
-        result = await client.call_tool(tool_name, arguments)
-        # Parse the result content
-        if isinstance(result, list):
-            texts = [c.text for c in result if hasattr(c, "text")]
-        elif hasattr(result, "content"):
-            texts = [c.text for c in result.content if hasattr(c, "text")]
-        else:
-            return result
-
-        if len(texts) == 1:
-            try:
-                return json.loads(texts[0])
-            except (json.JSONDecodeError, TypeError):
-                return texts[0]
-        elif len(texts) > 1:
-            parsed = []
-            for t in texts:
-                try:
-                    parsed.append(json.loads(t))
-                except (json.JSONDecodeError, TypeError):
-                    parsed.append(t)
-            return parsed
-        return None
+    async def call_tool(self, server: str, tool_name: str, arguments: dict) -> Any:
+        result = await self._client(server).call_tool(tool_name, arguments)
+        return parse_tool_result(result)
 
     async def read_resource(self, server: str, uri: str) -> Any:
-        """Read a resource from a specific server."""
-        client = self._clients.get(server)
-        if not client:
-            raise ValueError(f"Not connected to server: {server}")
-
-        result = await client.read_resource(uri)
-        if isinstance(result, str):
-            try:
-                return json.loads(result)
-            except (json.JSONDecodeError, TypeError):
-                return result
-        return result
+        result = await self._client(server).read_resource(uri)
+        return parse_resource_result(result)
 
     async def get_prompt(self, server: str, prompt_name: str, arguments: dict) -> str:
-        """Get a rendered prompt from a specific server."""
-        client = self._clients.get(server)
-        if not client:
-            raise ValueError(f"Not connected to server: {server}")
-
-        result = await client.get_prompt(prompt_name, arguments)
-        if isinstance(result, str):
-            return result
-        # Handle structured prompt response
-        texts = []
-        if hasattr(result, "messages"):
-            for msg in result.messages:
-                if hasattr(msg.content, "text"):
-                    texts.append(msg.content.text)
-                elif isinstance(msg.content, list):
-                    for c in msg.content:
-                        if hasattr(c, "text"):
-                            texts.append(c.text)
-        return "\n".join(texts) if texts else str(result)
+        result = await self._client(server).get_prompt(prompt_name, arguments)
+        return parse_prompt_result(result)
 
 
 @asynccontextmanager
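After this change each public method is a thin delegation: look up the client, call it, parse. E.g., a hypothetical read (server name and URI invented for illustration):

    hubs = await mcp.read_resource("ops", "resource://hubs")
    # equivalent to parse_resource_result(await mcp._client("ops").read_resource("resource://hubs"))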

agents/shared/parser.py (new file, 56 lines)
@@ -0,0 +1,56 @@
+"""Parsers for MCP response content (tools, resources, prompts)."""
+
+import json
+from typing import Any
+
+
+def _extract_texts(result: Any) -> list[str]:
+    """Pull text fields out of an MCP response, handling list/content/str shapes."""
+    if isinstance(result, list):
+        return [c.text for c in result if hasattr(c, "text")]
+    if hasattr(result, "content"):
+        return [c.text for c in result.content if hasattr(c, "text")]
+    return []
+
+
+def _maybe_json(text: str) -> Any:
+    """Parse JSON if possible; fall back to raw text."""
+    try:
+        return json.loads(text)
+    except (json.JSONDecodeError, TypeError):
+        return text
+
+
+def parse_tool_result(result: Any) -> Any:
+    """Parse `client.call_tool()` result into native Python structure."""
+    texts = _extract_texts(result)
+    if not texts:
+        return result
+    if len(texts) == 1:
+        return _maybe_json(texts[0])
+    return [_maybe_json(t) for t in texts]
+
+
+def parse_resource_result(result: Any) -> Any:
+    """Parse `client.read_resource()` result — often a JSON-encoded string."""
+    if isinstance(result, str):
+        return _maybe_json(result)
+    return result
+
+
+def parse_prompt_result(result: Any) -> str:
+    """Flatten a prompt response into a single string."""
+    if isinstance(result, str):
+        return result
+    if not hasattr(result, "messages"):
+        return str(result)
+
+    texts: list[str] = []
+    for msg in result.messages:
+        if hasattr(msg.content, "text"):
+            texts.append(msg.content.text)
+        elif isinstance(msg.content, list):
+            for c in msg.content:
+                if hasattr(c, "text"):
+                    texts.append(c.text)
+    return "\n".join(texts) if texts else str(result)
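A quick sanity check of the new parsers with synthetic inputs (SimpleNamespace stands in for fastmcp's text-content objects; not from the repo's tests):

    from types import SimpleNamespace
    from agents.shared.parser import parse_resource_result, parse_tool_result

    # one JSON text block -> parsed dict
    assert parse_tool_result([SimpleNamespace(text='{"status": "on_time"}')]) == {"status": "on_time"}

    # several blocks -> list, each JSON-decoded where possible
    blocks = [SimpleNamespace(text='{"a": 1}'), SimpleNamespace(text="plain")]
    assert parse_tool_result(blocks) == [{"a": 1}, "plain"]

    # resources: JSON strings decode, anything else passes through
    assert parse_resource_result('{"hubs": ["ORD"]}') == {"hubs": ["ORD"]}
    assert parse_resource_result({"already": "parsed"}) == {"already": "parsed"}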

agents/shared/tool_runner.py (new file, 73 lines)
@@ -0,0 +1,73 @@
+"""Shared tool-call helper: timeout, tracing, event emission, error collection.
+
+Every agent invokes MCP tools through this so the three concerns
+(observability, resilience, error state) stay in one place.
+"""
+
+import asyncio
+import time
+from typing import Any, Awaitable, Callable
+
+
+EmitFn = Callable[..., Awaitable[None]]
+
+
+def build_tool_caller(
+    mcp,
+    *,
+    emit: EmitFn,
+    errors: list[str],
+    lf: Any = None,
+    timeout: float = 15.0,
+):
+    """Return an async `_call(server, tool, args, is_live=False)` closure.
+
+    On each invocation:
+    - opens a Langfuse span (if `lf` is provided)
+    - runs `mcp.call_tool` with an `asyncio.wait_for` timeout
+    - emits `tool_call_end` / `tool_call_error` events
+    - appends to `errors` on failure, returns `None`
+    - closes the Langfuse span with output or error level
+    """
+
+    async def _call(server: str, tool: str, args: dict, is_live: bool = False) -> Any:
+        t0 = time.time()
+        span_ctx = lf.start_as_current_observation(
+            name=tool, as_type="tool", input=args,
+            metadata={"server": server, "is_live": is_live},
+        ) if lf else None
+        if span_ctx:
+            span_ctx.__enter__()
+
+        def _close_span(output: Any, level: str | None = None):
+            if not span_ctx:
+                return
+            kwargs: dict[str, Any] = {"output": output}
+            if level:
+                kwargs["level"] = level
+            else:
+                kwargs["metadata"] = {"latency_ms": int((time.time() - t0) * 1000)}
+            lf.update_current_span(**kwargs)
+            span_ctx.__exit__(None, None, None)
+
+        try:
+            result = await asyncio.wait_for(mcp.call_tool(server, tool, args), timeout=timeout)
+        except asyncio.TimeoutError:
+            latency = int((time.time() - t0) * 1000)
+            _close_span({"error": "timeout"}, level="ERROR")
+            await emit("tool_call_error", tool=tool, error="timeout", latency_ms=latency)
+            errors.append(f"{tool}: timeout after {timeout}s")
+            return None
+        except Exception as e:
+            latency = int((time.time() - t0) * 1000)
+            _close_span({"error": str(e)}, level="ERROR")
+            await emit("tool_call_error", tool=tool, error=str(e), latency_ms=latency)
+            errors.append(f"{tool}: {e}")
+            return None
+
+        latency = int((time.time() - t0) * 1000)
+        _close_span(result)
+        await emit("tool_call_end", tool=tool, latency_ms=latency, is_live=is_live)
+        return result
+
+    return _call
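A minimal usage sketch of build_tool_caller with a stubbed client and emitter (Langfuse disabled by leaving lf=None; tool names and values illustrative):

    import asyncio
    from agents.shared.tool_runner import build_tool_caller

    class StubMCP:
        # stands in for MCPMultiClient; "slow_tool" hangs to exercise the timeout path
        async def call_tool(self, server, tool, args):
            if tool == "slow_tool":
                await asyncio.sleep(60)
            return {"server": server, "tool": tool, "args": args}

    async def main():
        events, errors = [], []

        async def emit(event, **fields):
            events.append(event)

        _call = build_tool_caller(StubMCP(), emit=emit, errors=errors, timeout=0.5)
        ok = await _call("ops", "get_flight_status", {"flight": "UA123"})
        bad = await _call("ops", "slow_tool", {})

        assert ok["tool"] == "get_flight_status"
        assert bad is None and errors == ["slow_tool: timeout after 0.5s"]
        assert events == ["tool_call_end", "tool_call_error"]

    asyncio.run(main())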