From 69441fa1802d46fc584dfc0f7ea8a62b59a565be Mon Sep 17 00:00:00 2001
From: buenosairesam
Date: Thu, 16 Apr 2026 02:37:19 -0300
Subject: [PATCH] extract parser.py and tool_runner.py from agent code

---
 agents/fce.py                |  36 +---------
 agents/handover.py           |  36 +---------
 agents/shared/mcp_client.py  | 128 ++++++++++-------------------
 agents/shared/parser.py      |  56 +++++++++++++++
 agents/shared/tool_runner.py |  73 ++++++++++++++++++++
 5 files changed, 170 insertions(+), 159 deletions(-)
 create mode 100644 agents/shared/parser.py
 create mode 100644 agents/shared/tool_runner.py

diff --git a/agents/fce.py b/agents/fce.py
index 7d4525a..cac77c7 100644
--- a/agents/fce.py
+++ b/agents/fce.py
@@ -12,6 +12,7 @@ from datetime import datetime, timezone
 from typing import Any
 
 from agents.shared.mcp_client import MCPMultiClient
+from agents.shared.tool_runner import build_tool_caller
 
 
 async def run_fce(
@@ -65,40 +66,7 @@ async def run_fce(
     origin = flight_status.get("origin", "")
     destination = flight_status.get("destination", "")
 
-    async def _call(server, tool, args, is_live=False, timeout=15.0):
-        t = time.time()
-        ctx = lf.start_as_current_observation(
-            name=tool, as_type="tool", input=args,
-            metadata={"server": server, "is_live": is_live},
-        ) if lf else None
-        if ctx:
-            ctx.__enter__()
-        try:
-            result = await asyncio.wait_for(
-                mcp.call_tool(server, tool, args), timeout=timeout,
-            )
-            lat = int((time.time() - t) * 1000)
-            if ctx:
-                lf.update_current_span(output=result, metadata={"latency_ms": lat})
-                ctx.__exit__(None, None, None)
-            await emit("tool_call_end", tool=tool, latency_ms=lat, is_live=is_live)
-            return result
-        except asyncio.TimeoutError:
-            lat = int((time.time() - t) * 1000)
-            if ctx:
-                lf.update_current_span(output={"error": "timeout"}, level="ERROR")
-                ctx.__exit__(None, None, None)
-            await emit("tool_call_error", tool=tool, error="timeout", latency_ms=lat)
-            errors.append(f"{tool}: timeout after {timeout}s")
-            return None
-        except Exception as e:
-            lat = int((time.time() - t) * 1000)
-            if ctx:
-                lf.update_current_span(output={"error": str(e)}, level="ERROR")
-                ctx.__exit__(None, None, None)
-            await emit("tool_call_error", tool=tool, error=str(e), latency_ms=lat)
-            errors.append(f"{tool}: {e}")
-            return None
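+    # Shared helper from agents/shared/tool_runner.py: wraps mcp.call_tool with a
+    # 15s timeout, an optional Langfuse span, and tool_call_end/tool_call_error
+    # events; failed calls append to `errors` and return None.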
+    _call = build_tool_caller(mcp, emit=emit, errors=errors, lf=lf)
 
     # Fire all independent calls in parallel
     ops_data_task = asyncio.create_task(

diff --git a/agents/handover.py b/agents/handover.py
index 0d0f229..02f2c85 100644
--- a/agents/handover.py
+++ b/agents/handover.py
@@ -11,6 +11,7 @@ from datetime import datetime, timezone
 from typing import Any
 
 from agents.shared.mcp_client import MCPMultiClient
+from agents.shared.tool_runner import build_tool_caller
 
 ALL_HUBS = ["ORD", "EWR", "IAH", "SFO", "DEN"]
 
@@ -43,40 +44,7 @@ async def run_handover(
 
     await emit("node_enter", node="gather_all")
 
-    async def _call(server, tool, args, is_live=False, timeout=15.0):
-        t = time.time()
-        ctx = lf.start_as_current_observation(
-            name=tool, as_type="tool", input=args,
-            metadata={"server": server, "is_live": is_live},
-        ) if lf else None
-        if ctx:
-            ctx.__enter__()
-        try:
-            result = await asyncio.wait_for(
-                mcp.call_tool(server, tool, args), timeout=timeout,
-            )
-            lat = int((time.time() - t) * 1000)
-            if ctx:
-                lf.update_current_span(output=result, metadata={"latency_ms": lat})
-                ctx.__exit__(None, None, None)
-            await emit("tool_call_end", tool=tool, latency_ms=lat, is_live=is_live)
-            return result
-        except asyncio.TimeoutError:
-            lat = int((time.time() - t) * 1000)
-            if ctx:
-                lf.update_current_span(output={"error": "timeout"}, level="ERROR")
-                ctx.__exit__(None, None, None)
-            await emit("tool_call_error", tool=tool, error="timeout", latency_ms=lat)
-            errors.append(f"{tool}: timeout after {timeout}s")
-            return None
-        except Exception as e:
-            lat = int((time.time() - t) * 1000)
-            if ctx:
-                lf.update_current_span(output={"error": str(e)}, level="ERROR")
-                ctx.__exit__(None, None, None)
-            await emit("tool_call_error", tool=tool, error=str(e), latency_ms=lat)
-            errors.append(f"{tool}: {e}")
-            return None
+    _call = build_tool_caller(mcp, emit=emit, errors=errors, lf=lf)
 
     # Per-hub calls
     hub_tasks = {}

diff --git a/agents/shared/mcp_client.py b/agents/shared/mcp_client.py
index 58641df..1f82ca4 100644
--- a/agents/shared/mcp_client.py
+++ b/agents/shared/mcp_client.py
@@ -4,42 +4,17 @@ Composes the three domain-scoped MCP servers into namespaced
 configurations that agents connect to as a single client.
 """
 
-import json
+import os
 from contextlib import asynccontextmanager
 from typing import Any
 
 from fastmcp import Client
 
-
-def _env() -> dict:
-    """Forward LLM-related env vars and active scenario to MCP server subprocesses."""
-    import os
-    from mcp_servers.data.scenarios.manager import scenario_manager
-
-    env = {}
-    for key in (
-        "LLM_PROVIDER", "GROQ_API_KEY", "GROQ_MODEL",
-        "ANTHROPIC_API_KEY", "ANTHROPIC_MODEL",
-        "OPENAI_API_KEY", "OPENAI_BASE_URL", "OPENAI_MODEL",
-        "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION",
-        "BEDROCK_MODEL_ID", "USE_BEDROCK",
-        "PATH",
-    ):
-        val = os.getenv(key)
-        if val:
-            env[key] = val
-    env["ACTIVE_SCENARIO"] = scenario_manager.active_id
-    return env
-
-
-def _server_config(module: str) -> dict:
-    """Build server config with current env vars (called at connect time, not import time)."""
-    return {
-        "command": "uv",
-        "args": ["run", "python", "-m", module],
-        "env": _env(),
-    }
-
+from agents.shared.parser import (
+    parse_prompt_result,
+    parse_resource_result,
+    parse_tool_result,
+)
 
 SERVER_MODULES = {
     "shared": "mcp_servers.shared",
@@ -47,12 +22,33 @@ SERVER_MODULES = {
     "passenger": "mcp_servers.passenger",
 }
 
-# Agent profiles — which servers each agent connects to
 AGENT_PROFILES = {
     "fce": ["shared", "ops", "passenger"],
     "handover": ["shared", "ops"],
 }
 
+_FORWARDED_ENV_KEYS = (
+    "LLM_PROVIDER", "GROQ_API_KEY", "GROQ_MODEL",
+    "ANTHROPIC_API_KEY", "ANTHROPIC_MODEL",
+    "OPENAI_API_KEY", "OPENAI_BASE_URL", "OPENAI_MODEL",
+    "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION",
+    "BEDROCK_MODEL_ID", "USE_BEDROCK",
+    "PATH",
+)
+
+
+def _subprocess_env() -> dict:
+    """Env vars forwarded to MCP server subprocesses (LLM config + active scenario)."""
+    from mcp_servers.data.scenarios.manager import scenario_manager
+
+    env = {k: os.environ[k] for k in _FORWARDED_ENV_KEYS if k in os.environ}
+    env["ACTIVE_SCENARIO"] = scenario_manager.active_id
+    return env
+
+
+def _server_config(module: str) -> dict:
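+    # e.g. _server_config("mcp_servers.ops") ->
+    #   {"command": "uv", "args": ["run", "python", "-m", "mcp_servers.ops"],
+    #    "env": {<forwarded LLM vars>, "ACTIVE_SCENARIO": scenario_manager.active_id}}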
+    return {"command": "uv", "args": ["run", "python", "-m", module], "env": _subprocess_env()}
+
 
 class MCPMultiClient:
     """Manages connections to multiple MCP servers via fastmcp Client."""
@@ -61,7 +57,6 @@ class MCPMultiClient:
         self._clients: dict[str, Client] = {}
 
     async def connect(self, server_names: list[str]) -> None:
-        """Connect to the specified MCP servers."""
         for name in server_names:
             if name not in SERVER_MODULES:
                 raise ValueError(f"Unknown server: {name}. Available: {list(SERVER_MODULES.keys())}")
@@ -71,7 +66,6 @@ class MCPMultiClient:
         self._clients[name] = client
 
     async def close(self) -> None:
-        """Close all server connections."""
         for client in self._clients.values():
             try:
                 await client.__aexit__(None, None, None)
@@ -79,71 +73,23 @@
                 pass
         self._clients.clear()
 
-    async def call_tool(self, server: str, tool_name: str, arguments: dict) -> Any:
-        """Call a tool on a specific server. Returns parsed result."""
+    def _client(self, server: str) -> Client:
         client = self._clients.get(server)
         if not client:
             raise ValueError(f"Not connected to server: {server}")
+        return client
 
-        result = await client.call_tool(tool_name, arguments)
-
-        # Parse the result content
-        if isinstance(result, list):
-            texts = [c.text for c in result if hasattr(c, "text")]
-        elif hasattr(result, "content"):
-            texts = [c.text for c in result.content if hasattr(c, "text")]
-        else:
-            return result
-
-        if len(texts) == 1:
-            try:
-                return json.loads(texts[0])
-            except (json.JSONDecodeError, TypeError):
-                return texts[0]
-        elif len(texts) > 1:
-            parsed = []
-            for t in texts:
-                try:
-                    parsed.append(json.loads(t))
-                except (json.JSONDecodeError, TypeError):
-                    parsed.append(t)
-            return parsed
-        return None
+    async def call_tool(self, server: str, tool_name: str, arguments: dict) -> Any:
+        result = await self._client(server).call_tool(tool_name, arguments)
+        return parse_tool_result(result)
 
     async def read_resource(self, server: str, uri: str) -> Any:
-        """Read a resource from a specific server."""
-        client = self._clients.get(server)
-        if not client:
-            raise ValueError(f"Not connected to server: {server}")
-
-        result = await client.read_resource(uri)
-        if isinstance(result, str):
-            try:
-                return json.loads(result)
-            except (json.JSONDecodeError, TypeError):
-                return result
-        return result
+        result = await self._client(server).read_resource(uri)
+        return parse_resource_result(result)
 
     async def get_prompt(self, server: str, prompt_name: str, arguments: dict) -> str:
-        """Get a rendered prompt from a specific server."""
-        client = self._clients.get(server)
-        if not client:
-            raise ValueError(f"Not connected to server: {server}")
-
-        result = await client.get_prompt(prompt_name, arguments)
-        if isinstance(result, str):
-            return result
-        # Handle structured prompt response
-        texts = []
-        if hasattr(result, "messages"):
-            for msg in result.messages:
-                if hasattr(msg.content, "text"):
-                    texts.append(msg.content.text)
-                elif isinstance(msg.content, list):
-                    for c in msg.content:
-                        if hasattr(c, "text"):
-                            texts.append(c.text)
-        return "\n".join(texts) if texts else str(result)
+        result = await self._client(server).get_prompt(prompt_name, arguments)
+        return parse_prompt_result(result)
 
 
 @asynccontextmanager

diff --git a/agents/shared/parser.py b/agents/shared/parser.py
new file mode 100644
index 0000000..9446357
--- /dev/null
+++ b/agents/shared/parser.py
@@ -0,0 +1,56 @@
+"""Parsers for MCP response content (tools, resources, prompts)."""
+
+import json
+from typing import Any
+
+
+def _extract_texts(result: Any) -> list[str]:
+    """Pull text fields out of an MCP response, handling list and `.content` shapes."""
+    if isinstance(result, list):
+        return [c.text for c in result if hasattr(c, "text")]
+    if hasattr(result, "content"):
+        return [c.text for c in result.content if hasattr(c, "text")]
+    return []
+
+
+def _maybe_json(text: str) -> Any:
+    """Parse JSON if possible; fall back to raw text."""
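+    # e.g. _maybe_json('{"ok": true}') -> {"ok": True}
+    #      _maybe_json("plain text")   -> "plain text"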
+    try:
+        return json.loads(text)
+    except (json.JSONDecodeError, TypeError):
+        return text
+
+
+def parse_tool_result(result: Any) -> Any:
+    """Parse `client.call_tool()` result into native Python structure."""
+    texts = _extract_texts(result)
+    if not texts:
+        return result
+    if len(texts) == 1:
+        return _maybe_json(texts[0])
+    return [_maybe_json(t) for t in texts]
+
+
+def parse_resource_result(result: Any) -> Any:
+    """Parse `client.read_resource()` result — often a JSON-encoded string."""
+    if isinstance(result, str):
+        return _maybe_json(result)
+    return result
+
+
+def parse_prompt_result(result: Any) -> str:
+    """Flatten a prompt response into a single string."""
+    if isinstance(result, str):
+        return result
+    if not hasattr(result, "messages"):
+        return str(result)
+
+    texts: list[str] = []
+    for msg in result.messages:
+        if hasattr(msg.content, "text"):
+            texts.append(msg.content.text)
+        elif isinstance(msg.content, list):
+            for c in msg.content:
+                if hasattr(c, "text"):
+                    texts.append(c.text)
+    return "\n".join(texts) if texts else str(result)

diff --git a/agents/shared/tool_runner.py b/agents/shared/tool_runner.py
new file mode 100644
index 0000000..1cea743
--- /dev/null
+++ b/agents/shared/tool_runner.py
@@ -0,0 +1,73 @@
+"""Shared tool-call helper: timeout, tracing, event emission, error collection.
+
+Every agent invokes MCP tools through this so the three concerns
+(observability, resilience, error state) stay in one place.
+"""
+
+import asyncio
+import time
+from typing import Any, Awaitable, Callable
+
+
+EmitFn = Callable[..., Awaitable[None]]
+
+
+def build_tool_caller(
+    mcp,
+    *,
+    emit: EmitFn,
+    errors: list[str],
+    lf: Any = None,
+    timeout: float = 15.0,
+):
+    """Return an async `_call(server, tool, args, is_live=False)` closure.
+
+    On each invocation:
+    - opens a Langfuse span (if `lf` is provided)
+    - runs `mcp.call_tool` with an `asyncio.wait_for` timeout
+    - emits `tool_call_end` / `tool_call_error` events
+    - appends to `errors` on failure, returns `None`
+    - closes the Langfuse span with output or error level
+    """
+
+    async def _call(server: str, tool: str, args: dict, is_live: bool = False) -> Any:
+        t0 = time.time()
+        span_ctx = lf.start_as_current_observation(
+            name=tool, as_type="tool", input=args,
+            metadata={"server": server, "is_live": is_live},
+        ) if lf else None
+        if span_ctx:
+            span_ctx.__enter__()
+
+        def _close_span(output: Any, level: str | None = None):
+            if not span_ctx:
+                return
+            kwargs: dict[str, Any] = {"output": output}
+            if level:
+                kwargs["level"] = level
+            else:
+                kwargs["metadata"] = {"latency_ms": int((time.time() - t0) * 1000)}
+            lf.update_current_span(**kwargs)
+            span_ctx.__exit__(None, None, None)
+
+        try:
+            result = await asyncio.wait_for(mcp.call_tool(server, tool, args), timeout=timeout)
+        except asyncio.TimeoutError:
+            latency = int((time.time() - t0) * 1000)
+            _close_span({"error": "timeout"}, level="ERROR")
+            await emit("tool_call_error", tool=tool, error="timeout", latency_ms=latency)
+            errors.append(f"{tool}: timeout after {timeout}s")
+            return None
+        except Exception as e:
+            latency = int((time.time() - t0) * 1000)
+            _close_span({"error": str(e)}, level="ERROR")
+            await emit("tool_call_error", tool=tool, error=str(e), latency_ms=latency)
+            errors.append(f"{tool}: {e}")
+            return None
+
+        latency = int((time.time() - t0) * 1000)
+        _close_span(result)
+        await emit("tool_call_end", tool=tool, latency_ms=latency, is_live=is_live)
+        return result
+
+    return _call
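+
+
+# Typical call-site wiring (sketch; the tool name and args are illustrative,
+# mirroring how fce.py and handover.py use this helper):
+#
+#   _call = build_tool_caller(mcp, emit=emit, errors=errors, lf=lf)
+#   result = await _call("ops", "flight_status", {"flight_no": "UA123"})
+#   # -> parsed tool output, or None (with the failure recorded in `errors`)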