"""Shared tool-call helper: timeout, tracing, event emission, error collection.
|
|
|
|
Every agent invokes MCP tools through this so the three concerns
|
|
(observability, resilience, error state) stay in one place.
|
|
"""
|
|
|
|
import asyncio
|
|
import time
|
|
from typing import Any, Awaitable, Callable
|
|
|
|
|
|
EmitFn = Callable[..., Awaitable[None]]
|
|
|
|
|
|
def build_tool_caller(
|
|
mcp,
|
|
*,
|
|
emit: EmitFn,
|
|
errors: list[str],
|
|
lf: Any = None,
|
|
timeout: float = 15.0,
|
|
):
|
|
"""Return an async `_call(server, tool, args, is_live=False)` closure.
|
|
|
|
On each invocation:
|
|
- opens a Langfuse span (if `lf` is provided)
|
|
- runs `mcp.call_tool` with an `asyncio.wait_for` timeout
|
|
- emits `tool_call_end` / `tool_call_error` events
|
|
- appends to `errors` on failure, returns `None`
|
|
- closes the Langfuse span with output or error level
|
|
"""
|
|
|
|
async def _call(server: str, tool: str, args: dict, is_live: bool = False) -> Any:
|
|
t0 = time.time()
|
|
span_ctx = lf.start_as_current_observation(
|
|
name=tool, as_type="tool", input=args,
|
|
metadata={"server": server, "is_live": is_live},
|
|
) if lf else None
|
|
if span_ctx:
|
|
span_ctx.__enter__()
|
|
|
|
def _close_span(output: Any, level: str | None = None):
|
|
if not span_ctx:
|
|
return
|
|
kwargs: dict[str, Any] = {"output": output}
|
|
if level:
|
|
kwargs["level"] = level
|
|
else:
|
|
kwargs["metadata"] = {"latency_ms": int((time.time() - t0) * 1000)}
|
|
lf.update_current_span(**kwargs)
|
|
span_ctx.__exit__(None, None, None)
|
|
|
|
try:
|
|
result = await asyncio.wait_for(mcp.call_tool(server, tool, args), timeout=timeout)
|
|
except asyncio.TimeoutError:
|
|
latency = int((time.time() - t0) * 1000)
|
|
_close_span({"error": "timeout"}, level="ERROR")
|
|
await emit("tool_call_error", tool=tool, error="timeout", latency_ms=latency)
|
|
errors.append(f"{tool}: timeout after {timeout}s")
|
|
return None
|
|
except Exception as e:
|
|
latency = int((time.time() - t0) * 1000)
|
|
_close_span({"error": str(e)}, level="ERROR")
|
|
await emit("tool_call_error", tool=tool, error=str(e), latency_ms=latency)
|
|
errors.append(f"{tool}: {e}")
|
|
return None
|
|
|
|
latency = int((time.time() - t0) * 1000)
|
|
_close_span(result)
|
|
await emit("tool_call_end", tool=tool, latency_ms=latency, is_live=is_live)
|
|
return result
|
|
|
|
return _call
|