"""Shared tool-call helper: timeout, tracing, event emission, error collection. Every agent invokes MCP tools through this so the three concerns (observability, resilience, error state) stay in one place. """ import asyncio import time from typing import Any, Awaitable, Callable EmitFn = Callable[..., Awaitable[None]] def build_tool_caller( mcp, *, emit: EmitFn, errors: list[str], lf: Any = None, timeout: float = 15.0, ): """Return an async `_call(server, tool, args, is_live=False)` closure. On each invocation: - opens a Langfuse span (if `lf` is provided) - runs `mcp.call_tool` with an `asyncio.wait_for` timeout - emits `tool_call_end` / `tool_call_error` events - appends to `errors` on failure, returns `None` - closes the Langfuse span with output or error level """ async def _call(server: str, tool: str, args: dict, is_live: bool = False) -> Any: t0 = time.time() span_ctx = lf.start_as_current_observation( name=tool, as_type="tool", input=args, metadata={"server": server, "is_live": is_live}, ) if lf else None if span_ctx: span_ctx.__enter__() def _close_span(output: Any, level: str | None = None): if not span_ctx: return kwargs: dict[str, Any] = {"output": output} if level: kwargs["level"] = level else: kwargs["metadata"] = {"latency_ms": int((time.time() - t0) * 1000)} lf.update_current_span(**kwargs) span_ctx.__exit__(None, None, None) try: result = await asyncio.wait_for(mcp.call_tool(server, tool, args), timeout=timeout) except asyncio.TimeoutError: latency = int((time.time() - t0) * 1000) _close_span({"error": "timeout"}, level="ERROR") await emit("tool_call_error", tool=tool, error="timeout", latency_ms=latency) errors.append(f"{tool}: timeout after {timeout}s") return None except Exception as e: latency = int((time.time() - t0) * 1000) _close_span({"error": str(e)}, level="ERROR") await emit("tool_call_error", tool=tool, error=str(e), latency_ms=latency) errors.append(f"{tool}: {e}") return None latency = int((time.time() - t0) * 1000) _close_span(result) await emit("tool_call_end", tool=tool, latency_ms=latency, is_live=is_live) return result return _call