init commit
This commit is contained in:
0
agents/__init__.py
Normal file
0
agents/__init__.py
Normal file
183
agents/efhas.py
Normal file
183
agents/efhas.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""FCE agent — "Behind Every Departure" (Flight Context Engine) passenger notification agent.
|
||||
|
||||
Connects to: shared + passenger MCP servers.
|
||||
When a flight is disrupted, gathers all operational context and generates
|
||||
an empathetic passenger notification.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from agents.shared.mcp_client import MCPMultiClient
|
||||
|
||||
|
||||
async def run_efhas(
    flight_id: str,
    mcp: MCPMultiClient,
    on_event: Any = None,
) -> dict:
    """Run the EFHaS agent for a single flight.

    Pipeline: triage -> gather context (parallel tool calls) -> synthesize
    the passenger notification -> format the result.

    Args:
        flight_id: The flight to generate a notification for.
        mcp: Connected MCPMultiClient (shared + ops + passenger — crew notes
            are read from the ops server, the notification text from the
            passenger server).
        on_event: Optional async callback for real-time events.

    Returns:
        Notification result dict; an ``{"error": ...}`` dict when the flight
        lookup fails, or a ``NO_NOTIFICATION`` dict when the disruption is
        below the notification threshold.
    """
    run_start = time.time()
    errors = []       # human-readable tool failures, surfaced in the result
    tool_calls = []   # full event trace (node enter/exit, tool timings)

    async def emit(event_type: str, **data):
        # Record every event locally and forward to the caller if subscribed.
        event = {"type": event_type, "timestamp": datetime.now(timezone.utc).isoformat(), **data}
        tool_calls.append(event)
        if on_event:
            await on_event(event)

    # ── Node 1: Triage ──

    await emit("node_enter", node="triage")

    t0 = time.time()
    flight_status = await mcp.call_tool("shared", "get_flight_status", {"flight_id": flight_id})
    latency = int((time.time() - t0) * 1000)
    await emit("tool_call_end", tool="get_flight_status", latency_ms=latency, is_live=False)

    if isinstance(flight_status, dict) and "error" in flight_status:
        return {
            "error": flight_status["error"],
            "flight_id": flight_id,
            "duration_ms": int((time.time() - run_start) * 1000),
        }

    status = flight_status.get("status", "")
    delay_minutes = flight_status.get("delay_minutes", 0)
    # Only disruptions of at least 10 minutes warrant a passenger notification.
    should_notify = status in ("DELAYED", "CANCELLED", "DIVERTED") and delay_minutes >= 10

    await emit("node_exit", node="triage", result={"should_notify": should_notify})

    if not should_notify:
        return {
            "flight_id": flight_id,
            "type": "NO_NOTIFICATION",
            "reason": f"Status {status}, delay {delay_minutes}min — below threshold",
            "duration_ms": int((time.time() - run_start) * 1000),
        }

    # ── Node 2: Gather Context (parallel tool calls) ──

    await emit("node_enter", node="gather_context")

    origin = flight_status.get("origin", "")
    destination = flight_status.get("destination", "")

    async def _call(server, tool, args, is_live=False):
        # Tool-call wrapper: times the call, emits trace events, and converts
        # failures into None so one bad source doesn't abort the whole run.
        t = time.time()
        try:
            result = await mcp.call_tool(server, tool, args)
            lat = int((time.time() - t) * 1000)
            await emit("tool_call_end", tool=tool, latency_ms=lat, is_live=is_live)
            return result
        except Exception as e:
            lat = int((time.time() - t) * 1000)
            await emit("tool_call_error", tool=tool, error=str(e), latency_ms=lat)
            errors.append(f"{tool}: {e}")
            return None

    # Fire all independent calls in parallel
    ops_data_task = asyncio.create_task(
        _call("shared", "get_flight_details", {"flight_id": flight_id})
    )
    weather_task = asyncio.create_task(
        _call("shared", "get_route_weather", {"origin": origin, "destination": destination}, is_live=True)
    )
    airport_status_task = asyncio.create_task(
        _call("shared", "get_airport_status", {"airport_code": origin}, is_live=True)
    )
    airport_congestion_task = asyncio.create_task(
        _call("shared", "get_airport_congestion", {"airport_code": origin}, is_live=True)
    )
    crew_notes_task = asyncio.create_task(
        _call("ops", "get_crew_notes", {"flight_id": flight_id})
    )

    ops_data, weather, airport_status, airport_congestion, crew_notes = await asyncio.gather(
        ops_data_task, weather_task, airport_status_task, airport_congestion_task, crew_notes_task
    )

    await emit("node_exit", node="gather_context")

    # ── Node 3: Synthesize ──

    await emit("node_enter", node="synthesize")

    # Build weather summary: prefer en-route significant events, else fall
    # back to current conditions at the origin waypoint.
    weather_summary = ""
    if weather and isinstance(weather, dict):
        events = weather.get("significant_events", [])
        if events:
            weather_summary = ", ".join(e.get("condition", "") for e in events)
        else:
            origin_wp = weather.get("waypoints", {}).get("origin", {})
            if isinstance(origin_wp, dict) and "weather" in origin_wp:
                weather_summary = origin_wp["weather"].get("condition", "")

    # Build crew notes summary
    crew_summary = ""
    if crew_notes and isinstance(crew_notes, list):
        # Take first 2 notes for the notification.
        # NOTE(review): assumes each note is a str — confirm the ops server
        # never returns structured note objects.
        relevant = [n for n in crew_notes if not n.startswith("CANCELLED")][:2]
        crew_summary = " ".join(relevant)

    context = {
        "flight_id": flight_id,
        "origin": origin,
        "destination": destination,
        "status": status,
        "delay_minutes": delay_minutes,
        # Prefer the live status feed's delay cause; fall back to ops data.
        "delay_cause": flight_status.get("delay_cause")
        or (ops_data.get("delay_cause") if isinstance(ops_data, dict) else None),
        "gate": flight_status.get("gate", ""),
        "weather_summary": weather_summary,
        "crew_notes_summary": crew_summary,
        "get_airport_status": airport_status,
        "get_airport_congestion": airport_congestion,
    }

    t0 = time.time()
    notification_text = await mcp.call_tool("passenger", "generate_notification", {"context": context})
    latency = int((time.time() - t0) * 1000)
    await emit("tool_call_end", tool="generate_notification", latency_ms=latency, is_live=False)

    await emit("node_exit", node="synthesize")

    # ── Node 4: Format Output ──

    await emit("node_enter", node="format_output")

    data_sources = ["flight_ops"]
    if weather:
        data_sources.append("weather_live")
    if airport_status:
        data_sources.append("faa_status_live")
    if crew_notes:
        data_sources.append("get_crew_notes")

    notification = {
        "flight_id": flight_id,
        "type": "DELAY_NOTIFICATION" if status == "DELAYED" else f"{status}_NOTIFICATION",
        "status": status,
        "delay_minutes": delay_minutes,
        "notification_text": notification_text if isinstance(notification_text, str) else str(notification_text),
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "data_sources": data_sources,
        "human_approved": True,  # auto-approve in demo
        "errors": errors,
        "duration_ms": int((time.time() - run_start) * 1000),
        # Event trace, included for parity with run_handover's result.
        "tool_calls": tool_calls,
    }

    await emit("node_exit", node="format_output")
    await emit("agent_end", output_summary=f"{status} notification for {flight_id}")

    return notification
|
||||
293
agents/handover.py
Normal file
293
agents/handover.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""Shift Handover agent — compiles all active ops state into a prioritized brief.
|
||||
|
||||
Connects to: shared + ops MCP servers.
|
||||
Gathers data across all hubs, triages by severity and time sensitivity,
|
||||
generates a structured handover brief.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from agents.shared.mcp_client import MCPMultiClient
|
||||
|
||||
# The five hub airports a handover brief covers when no subset is requested.
ALL_HUBS = ["ORD", "EWR", "IAH", "SFO", "DEN"]


# WMO weather severity mapping
# Maps WMO weather-interpretation codes to a 1-10 severity score.
# NOTE(review): not referenced by the code visible in this module — confirm
# it is consumed elsewhere before relying on (or removing) it.
WMO_SEVERITY = {
    95: 9, 96: 9, 99: 10,  # thunderstorms
    65: 6, 75: 7, 82: 7, 86: 8,  # heavy precip
    45: 5, 48: 5,  # fog
}
|
||||
|
||||
|
||||
async def run_handover(
    hubs: list[str] | None = None,
    mcp: MCPMultiClient | None = None,
    on_event: Any = None,
) -> dict:
    """Run the Shift Handover agent.

    Pipeline: gather all hub state (parallel) -> triage into
    immediate/monitor/fyi -> synthesize a narrative brief -> format.

    Args:
        hubs: Hubs to cover (default: all 5).
        mcp: Connected MCPMultiClient (shared + ops). Must not be None —
            every gather step calls through it.
        on_event: Optional async callback for real-time events.

    Returns:
        Handover brief result dict.
    """
    target_hubs = hubs or ALL_HUBS
    run_start = time.time()
    errors = []       # human-readable tool failures, surfaced in the result
    tool_calls = []   # full event trace (node enter/exit, tool timings)

    async def emit(event_type: str, **data):
        # Record every event locally and forward to the caller if subscribed.
        event = {"type": event_type, "timestamp": datetime.now(timezone.utc).isoformat(), **data}
        tool_calls.append(event)
        if on_event:
            await on_event(event)

    # ── Node 1: Gather All (parallel tool calls across hubs) ──

    await emit("node_enter", node="gather_all")

    async def _call(server, tool, args, is_live=False):
        # Tool-call wrapper: times the call, emits trace events, and converts
        # failures into None so one bad source doesn't abort the whole run.
        t = time.time()
        try:
            result = await mcp.call_tool(server, tool, args)
            lat = int((time.time() - t) * 1000)
            await emit("tool_call_end", tool=tool, latency_ms=lat, is_live=is_live)
            return result
        except Exception as e:
            lat = int((time.time() - t) * 1000)
            await emit("tool_call_error", tool=tool, error=str(e), latency_ms=lat)
            errors.append(f"{tool}: {e}")
            return None

    # Per-hub calls
    hub_tasks = {}
    for hub in target_hubs:
        hub_tasks[hub] = {
            "irrops": asyncio.create_task(
                _call("shared", "get_irregular_ops", {"hub": hub})
            ),
            "airport": asyncio.create_task(
                _call("shared", "get_airport_status", {"airport_code": hub}, is_live=True)
            ),
            "rebookings": asyncio.create_task(
                _call("ops", "get_pending_rebookings", {"hub": hub})
            ),
        }

    # Global calls
    weather_task = asyncio.create_task(
        _call("shared", "get_hub_forecasts", {}, is_live=True)
    )

    # Gather all hub results (tasks are already running concurrently; these
    # awaits just collect them).
    hub_data = {}
    for hub in target_hubs:
        hub_data[hub] = {
            "irrops": await hub_tasks[hub]["irrops"],
            "airport": await hub_tasks[hub]["airport"],
            "rebookings": await hub_tasks[hub]["rebookings"],
        }

    weather_forecast = await weather_task

    # Fetch full flight details ONCE per disrupted flight and reuse them for
    # both crew IDs and maintenance lookups. (Previously this was two
    # identical get_flight_details loops — double the MCP round-trips.)
    all_crew_ids = []
    maintenance_data = {}
    for hub in target_hubs:
        irrops = hub_data[hub].get("irrops")
        if not isinstance(irrops, list):
            continue
        for irrop in irrops:
            if not isinstance(irrop, dict):
                continue
            flight_data = await _call(
                "shared", "get_flight_details",
                {"flight_id": irrop.get("flight_id", "")}
            )
            if not isinstance(flight_data, dict):
                continue
            all_crew_ids.extend(flight_data.get("crew_ids", []))
            tail = flight_data.get("aircraft_tail", "")
            # Only look up MEL flags once per airframe.
            if tail and tail not in maintenance_data:
                mel = await _call("shared", "get_maintenance_flags", {"aircraft_tail": tail})
                if mel:
                    maintenance_data[tail] = mel

    # Crew duty status for all crew on disrupted flights.
    crew_status = None
    if all_crew_ids:
        crew_status = await _call("ops", "get_crew_duty_status", {"crew_ids": all_crew_ids})

    await emit("node_exit", node="gather_all")

    # ── Node 2: Triage ──

    await emit("node_enter", node="triage")

    immediate = []
    monitor = []
    fyi = []

    # Crew at risk of hitting their duty-time limit go straight to immediate.
    if isinstance(crew_status, list):
        for crew in crew_status:
            if isinstance(crew, dict) and crew.get("at_risk"):
                hours_left = crew.get("hours_until_limit", 0)
                immediate.append(
                    f"{crew.get('next_scheduled_flight', '?')} — "
                    f"{crew.get('name', '?')} ({crew.get('role', '?')}) "
                    f"duty limit in {hours_left:.1f}h. "
                    f"Swap required if departure slips."
                )

    # IROPs
    for hub in target_hubs:
        irrops = hub_data[hub].get("irrops")
        if isinstance(irrops, list) and irrops:
            total_pax = sum(
                i.get("affected_pax_count", 0) for i in irrops if isinstance(i, dict)
            )
            cancelled = [i for i in irrops if isinstance(i, dict) and i.get("irrop_type") == "CANCELLED"]

            if cancelled:
                for c in cancelled:
                    immediate.append(
                        f"{c.get('flight_id', '?')} ({hub}→{c.get('destination', '?')}) — "
                        f"CANCELLED ({c.get('delay_cause', '?')}). "
                        f"{c.get('affected_pax_count', 0)} pax need rebooking."
                    )
            elif total_pax > 100:
                monitor.append(
                    f"{hub}: {len(irrops)} flights disrupted, {total_pax} pax affected."
                )

        # Rebookings
        rebookings = hub_data[hub].get("rebookings")
        if isinstance(rebookings, list) and rebookings:
            high_priority = [r for r in rebookings if isinstance(r, dict) and r.get("urgency") == "HIGH"]
            if high_priority:
                immediate.append(
                    f"{hub}: {len(rebookings)} pax awaiting rebooking "
                    f"({len(high_priority)} HIGH priority)."
                )

    # Weather risks
    if isinstance(weather_forecast, dict):
        for hub_code, forecast in weather_forecast.get("hubs", {}).items():
            if isinstance(forecast, dict) and forecast.get("risk_flag"):
                monitor.append(
                    f"Weather risk at {hub_code}: convective activity or low visibility "
                    f"forecast in next 4 hours."
                )

    # MEL items
    for tail, items in maintenance_data.items():
        if isinstance(items, list):
            for item in items:
                if isinstance(item, dict) and item.get("restriction"):
                    monitor.append(
                        f"MEL {item.get('mel_id', '?')} on {tail}: "
                        f"{item.get('system', '?')} — {item.get('restriction', '')}"
                    )

    # Airport status (live FAA)
    for hub in target_hubs:
        airport = hub_data[hub].get("airport")
        if isinstance(airport, dict) and airport.get("has_delays"):
            for delay in airport.get("delays", []):
                if isinstance(delay, dict):
                    if delay.get("type") == "ground_stop":
                        immediate.append(
                            f"{hub} GROUND STOP: {delay.get('reason', 'unknown')}. "
                            f"End time: {delay.get('end_time', 'TBD')}."
                        )
                    elif delay.get("type") == "ground_delay_program":
                        monitor.append(
                            f"{hub} GDP: {delay.get('reason', 'unknown')}. "
                            f"Avg delay: {delay.get('average_delay', 'unknown')}."
                        )

    # Nominal hubs go to FYI (note: pending rebookings alone do not mark a
    # hub as having issues here — only IROPs and FAA delays are checked).
    for hub in target_hubs:
        irrops = hub_data[hub].get("irrops")
        airport = hub_data[hub].get("airport")
        has_issues = (isinstance(irrops, list) and len(irrops) > 0) or (
            isinstance(airport, dict) and airport.get("has_delays")
        )
        if not has_issues:
            fyi.append(f"{hub} fully nominal. No open items.")

    await emit("node_exit", node="triage")

    # ── Node 3: Synthesize ──

    await emit("node_enter", node="synthesize")

    shift_time = datetime.now(timezone.utc).strftime("%H:%M UTC")
    # "< 5" mirrors len(ALL_HUBS): a full sweep is labelled ALL HUBS.
    hub_label = ", ".join(target_hubs) if len(target_hubs) < 5 else "ALL HUBS"

    t0 = time.time()
    brief_text = await mcp.call_tool("ops", "generate_narrative", {
        "context": {
            "hub": hub_label,
            "shift_time": shift_time,
            "immediate": immediate,
            "monitor": monitor,
            "fyi": fyi,
        }
    })
    latency = int((time.time() - t0) * 1000)
    await emit("tool_call_end", tool="generate_narrative", latency_ms=latency, is_live=False)

    await emit("node_exit", node="synthesize")

    # ── Node 4: Format Output ──

    await emit("node_enter", node="format_output")

    result = {
        "type": "HANDOVER_BRIEF",
        "hubs": target_hubs,
        "brief_text": brief_text if isinstance(brief_text, str) else str(brief_text),
        "summary": {
            "immediate_count": len(immediate),
            "monitor_count": len(monitor),
            "fyi_count": len(fyi),
        },
        "items": {
            "immediate": immediate,
            "monitor": monitor,
            "fyi": fyi,
        },
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "errors": errors,
        "duration_ms": int((time.time() - run_start) * 1000),
        "tool_calls": tool_calls,
    }

    await emit("node_exit", node="format_output")
    await emit("agent_end", output_summary=f"Handover brief for {hub_label}")

    # Store the brief for the ops://handover/latest resource
    try:
        from mcp_servers.ops.server import store_handover_brief
        store_handover_brief(result)
    except Exception:
        pass  # not critical — the brief is still returned to the caller

    return result
|
||||
0
agents/shared/__init__.py
Normal file
0
agents/shared/__init__.py
Normal file
24
agents/shared/llm.py
Normal file
24
agents/shared/llm.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""LLM factory — Bedrock Converse API (production) or direct Anthropic SDK (local dev)."""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def get_agent_llm():
    """Return a LangChain chat model for agent orchestration.

    The backend is chosen from the ``USE_BEDROCK`` environment variable:
    AWS Bedrock Converse when it is "true" (production), otherwise the
    direct Anthropic SDK (local dev). Imports are deferred so only the
    selected backend's package needs to be installed.
    """
    use_bedrock = os.getenv("USE_BEDROCK", "").lower() == "true"

    if not use_bedrock:
        from langchain_anthropic import ChatAnthropic

        return ChatAnthropic(
            model=os.getenv("ANTHROPIC_MODEL", "claude-sonnet-4-20250514"),
            temperature=0.3,
            max_tokens=4096,
        )

    from langchain_aws import ChatBedrockConverse

    return ChatBedrockConverse(
        model=os.getenv("BEDROCK_MODEL_ID", "anthropic.claude-sonnet-4-20250514-v1:0"),
        region_name=os.getenv("AWS_DEFAULT_REGION", "us-east-1"),
        temperature=0.3,
        max_tokens=4096,
    )
|
||||
137
agents/shared/mcp_client.py
Normal file
137
agents/shared/mcp_client.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Multi-server MCP client using fastmcp composition.
|
||||
|
||||
Composes the three domain-scoped MCP servers into namespaced configurations
|
||||
that agents connect to as a single client.
|
||||
"""
|
||||
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from fastmcp import Client
|
||||
|
||||
|
||||
# Server configurations for stdio transport
|
||||
# Server configurations for stdio transport
# Each entry is a fastmcp stdio launch spec: the command and args used to
# spawn the corresponding MCP server as a subprocess via `uv run`.
SERVERS = {
    "shared": {
        "command": "uv",
        "args": ["run", "python", "-m", "mcp_servers.shared"],
    },
    "ops": {
        "command": "uv",
        "args": ["run", "python", "-m", "mcp_servers.ops"],
    },
    "passenger": {
        "command": "uv",
        "args": ["run", "python", "-m", "mcp_servers.passenger"],
    },
}

# Agent profiles — which servers each agent connects to
# Keys are agent names; values are the SERVERS keys that agent needs.
AGENT_PROFILES = {
    "efhas": ["shared", "ops", "passenger"],
    "handover": ["shared", "ops"],
}
|
||||
|
||||
|
||||
class MCPMultiClient:
    """Manages connections to multiple MCP servers via fastmcp Client.

    Each named server gets its own fastmcp ``Client``; tool calls, resource
    reads, and prompt fetches are routed by server name and their content
    payloads are parsed from JSON text where possible.
    """

    def __init__(self) -> None:
        # Maps server name (a SERVERS key) -> entered fastmcp Client.
        self._clients: dict[str, Client] = {}

    async def connect(self, server_names: list[str]) -> None:
        """Connect to the specified MCP servers.

        Raises:
            ValueError: If a name is not a key of ``SERVERS``.
        """
        for name in server_names:
            if name not in SERVERS:
                raise ValueError(f"Unknown server: {name}. Available: {list(SERVERS.keys())}")
            config = {"mcpServers": {"default": SERVERS[name]}}
            client = Client(config)
            # Enter the client's async context manually; close() exits it.
            await client.__aenter__()
            self._clients[name] = client

    async def close(self) -> None:
        """Close all server connections (best-effort; errors are ignored)."""
        for client in self._clients.values():
            try:
                await client.__aexit__(None, None, None)
            except Exception:
                # Best-effort teardown only. Narrowed from
                # (Exception, BaseException) so CancelledError and
                # KeyboardInterrupt are no longer swallowed here.
                pass
        self._clients.clear()

    async def call_tool(self, server: str, tool_name: str, arguments: dict) -> Any:
        """Call a tool on a specific server. Returns parsed result.

        Text content is JSON-decoded when possible; a single text part is
        returned as one value, multiple parts as a list.

        Raises:
            ValueError: If not connected to ``server``.
        """
        client = self._clients.get(server)
        if not client:
            raise ValueError(f"Not connected to server: {server}")

        result = await client.call_tool(tool_name, arguments)

        # Parse the result content — fastmcp may return either a bare list
        # of content parts or an object with a .content attribute.
        if isinstance(result, list):
            texts = [c.text for c in result if hasattr(c, "text")]
        elif hasattr(result, "content"):
            texts = [c.text for c in result.content if hasattr(c, "text")]
        else:
            return result

        if len(texts) == 1:
            try:
                return json.loads(texts[0])
            except (json.JSONDecodeError, TypeError):
                return texts[0]
        elif len(texts) > 1:
            parsed = []
            for t in texts:
                try:
                    parsed.append(json.loads(t))
                except (json.JSONDecodeError, TypeError):
                    parsed.append(t)
            return parsed
        return None

    async def read_resource(self, server: str, uri: str) -> Any:
        """Read a resource from a specific server.

        String payloads are JSON-decoded when possible; anything else is
        returned as-is.

        Raises:
            ValueError: If not connected to ``server``.
        """
        client = self._clients.get(server)
        if not client:
            raise ValueError(f"Not connected to server: {server}")

        result = await client.read_resource(uri)
        if isinstance(result, str):
            try:
                return json.loads(result)
            except (json.JSONDecodeError, TypeError):
                return result
        return result

    async def get_prompt(self, server: str, prompt_name: str, arguments: dict) -> str:
        """Get a rendered prompt from a specific server.

        Structured responses (objects with ``.messages``) are flattened to
        the newline-joined text of their parts.

        Raises:
            ValueError: If not connected to ``server``.
        """
        client = self._clients.get(server)
        if not client:
            raise ValueError(f"Not connected to server: {server}")

        result = await client.get_prompt(prompt_name, arguments)
        if isinstance(result, str):
            return result
        # Handle structured prompt response
        texts = []
        if hasattr(result, "messages"):
            for msg in result.messages:
                if hasattr(msg.content, "text"):
                    texts.append(msg.content.text)
                elif isinstance(msg.content, list):
                    for c in msg.content:
                        if hasattr(c, "text"):
                            texts.append(c.text)
        return "\n".join(texts) if texts else str(result)
|
||||
|
||||
|
||||
@asynccontextmanager
async def connect_servers(server_names: list[str]):
    """Async context manager yielding a connected ``MCPMultiClient``.

    Connects to each named MCP server on entry; the connections are always
    torn down on exit, even when ``connect`` or the managed body raises.
    """
    multi_client = MCPMultiClient()
    try:
        await multi_client.connect(server_names)
        yield multi_client
    finally:
        await multi_client.close()
|
||||
Reference in New Issue
Block a user