"""FastAPI application — triggers agents, exposes scenarios, streams events.""" import asyncio import json import uuid from contextlib import asynccontextmanager from datetime import datetime, timezone from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from agents.efhas import run_efhas from agents.handover import run_handover from agents.shared.mcp_client import connect_servers from mcp_servers.data.scenarios.manager import scenario_manager # ── WebSocket event hub ── class EventHub: """Fans out agent events to connected WebSocket clients.""" def __init__(self): self._clients: set[WebSocket] = set() async def connect(self, ws: WebSocket): await ws.accept() self._clients.add(ws) def disconnect(self, ws: WebSocket): self._clients.discard(ws) async def broadcast(self, event: dict): dead = set() for ws in self._clients: try: await ws.send_json(event) except Exception: dead.add(ws) self._clients -= dead event_hub = EventHub() # ── In-memory run store ── runs: dict[str, dict] = {} # ── App lifecycle ── @asynccontextmanager async def lifespan(app: FastAPI): yield app = FastAPI(title="United Ops MCP Demo", lifespan=lifespan) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) # ── Request/Response models ── class EFHaSRequest(BaseModel): flight_id: str class HandoverRequest(BaseModel): hubs: list[str] | None = None class ScenarioUpdate(BaseModel): scenario_id: str # ── Agent routes ── @app.post("/agents/efhas") async def trigger_efhas(req: EFHaSRequest): run_id = str(uuid.uuid4())[:8] runs[run_id] = {"status": "running", "agent": "efhas", "flight_id": req.flight_id} async def _run(): await event_hub.broadcast({ "type": "agent_start", "run_id": run_id, "agent": "efhas", "flight_id": req.flight_id, "timestamp": datetime.now(timezone.utc).isoformat(), }) async def on_event(event): await event_hub.broadcast({"run_id": run_id, **event}) 
try: async with connect_servers(["shared", "ops", "passenger"]) as mcp: result = await run_efhas(req.flight_id, mcp, on_event=on_event) runs[run_id] = {"status": "completed", "agent": "efhas", "result": result} except Exception as e: runs[run_id] = {"status": "error", "agent": "efhas", "error": str(e)} asyncio.create_task(_run()) return {"run_id": run_id, "status": "running"} @app.post("/agents/handover") async def trigger_handover(req: HandoverRequest): run_id = str(uuid.uuid4())[:8] runs[run_id] = {"status": "running", "agent": "handover", "hubs": req.hubs} async def _run(): await event_hub.broadcast({ "type": "agent_start", "run_id": run_id, "agent": "handover", "hubs": req.hubs, "timestamp": datetime.now(timezone.utc).isoformat(), }) async def on_event(event): await event_hub.broadcast({"run_id": run_id, **event}) try: async with connect_servers(["shared", "ops"]) as mcp: result = await run_handover(hubs=req.hubs, mcp=mcp, on_event=on_event) runs[run_id] = {"status": "completed", "agent": "handover", "result": result} except Exception as e: runs[run_id] = {"status": "error", "agent": "handover", "error": str(e)} asyncio.create_task(_run()) return {"run_id": run_id, "status": "running"} @app.get("/agents/runs/{run_id}") async def get_run(run_id: str): if run_id not in runs: return {"error": f"Run {run_id} not found"} return runs[run_id] @app.get("/agents/runs") async def list_runs(): return [ {"run_id": rid, "status": r["status"], "agent": r.get("agent")} for rid, r in sorted(runs.items(), reverse=True) ] # ── Scenario routes ── @app.get("/scenarios") async def list_scenarios(): return scenario_manager.list_scenarios() @app.get("/scenarios/active") async def get_active_scenario(): return scenario_manager.get_metadata() @app.put("/scenarios/active") async def set_active_scenario(req: ScenarioUpdate): try: return scenario_manager.set_active(req.scenario_id) except ValueError as e: return {"error": str(e)} # ── Scenario data routes ── 
@app.get("/scenarios/data/flights")
async def get_scenario_flights():
    """Return the scenario manager's flights as JSON-serializable dicts."""
    return [f.model_dump(mode="json") for f in scenario_manager.flights]


# Crew with this many duty hours (or fewer) left before their limit are flagged.
_AT_RISK_HOURS = 2.0


def _crew_to_dict(c) -> dict:
    """Serialize a crew member and add derived duty-limit fields.

    Adds ``hours_until_limit`` (``duty_hours_limit - duty_hours_elapsed``,
    rounded to 2 decimals) and ``at_risk`` (that value <= ``_AT_RISK_HOURS``).
    Shared by the GET and PATCH crew routes so both report identical fields.
    """
    d = c.model_dump(mode="json")
    d["hours_until_limit"] = round(c.duty_hours_limit - c.duty_hours_elapsed, 2)
    d["at_risk"] = d["hours_until_limit"] <= _AT_RISK_HOURS
    return d


@app.get("/scenarios/data/crew")
async def get_scenario_crew():
    """Return all crew with derived ``hours_until_limit`` / ``at_risk`` fields."""
    return [_crew_to_dict(c) for c in scenario_manager.crew]


@app.get("/scenarios/data/crew-notes")
async def get_scenario_crew_notes():
    return scenario_manager.crew_notes


@app.get("/scenarios/data/maintenance")
async def get_scenario_maintenance():
    """Return maintenance items grouped by tail, each item JSON-serialized."""
    return {
        tail: [i.model_dump(mode="json") for i in items]
        for tail, items in scenario_manager.maintenance.items()
    }


@app.get("/scenarios/data/rebookings")
async def get_scenario_rebookings():
    return [r.model_dump(mode="json") for r in scenario_manager.rebookings]


class FlightPatch(BaseModel):
    # Fields left as None are not modified by patch_flight.
    delay_minutes: int | None = None
    status: str | None = None
    gate: str | None = None


@app.patch("/scenarios/data/flights/{flight_id}")
async def patch_flight(flight_id: str, patch: FlightPatch):
    """Apply the non-None fields of ``patch`` to the matching flight in place.

    Returns the updated flight as a dict, or an error dict if no flight
    has ``flight_id``.
    """
    for f in scenario_manager.flights:
        if f.flight_id == flight_id:
            # NOTE(review): plain attribute assignment — unless the flight
            # model enables pydantic `validate_assignment`, these values are
            # stored without validation. Confirm.
            if patch.delay_minutes is not None:
                f.delay_minutes = patch.delay_minutes
            if patch.status is not None:
                f.status = patch.status
            if patch.gate is not None:
                f.gate = patch.gate
            return f.model_dump(mode="json")
    return {"error": f"Flight {flight_id} not found"}


class CrewPatch(BaseModel):
    # Left as None means "do not modify".
    duty_hours_elapsed: float | None = None


@app.patch("/scenarios/data/crew/{crew_id}")
async def patch_crew(crew_id: str, patch: CrewPatch):
    """Update a crew member's elapsed duty hours and return it with derived fields.

    Returns an error dict if no crew member has ``crew_id``.
    """
    for c in scenario_manager.crew:
        if c.crew_id == crew_id:
            if patch.duty_hours_elapsed is not None:
                c.duty_hours_elapsed = patch.duty_hours_elapsed
            return _crew_to_dict(c)
    return {"error": f"Crew {crew_id} not found"}


class CrewNotesPatch(BaseModel):
    notes: list[str]


@app.put("/scenarios/data/crew-notes/{flight_id}")
async def put_crew_notes(flight_id: str, patch: CrewNotesPatch):
    """Replace the crew notes stored for ``flight_id``."""
    scenario_manager.crew_notes[flight_id] = patch.notes
    return {"flight_id": flight_id, "notes": patch.notes}


# ── WebSocket ──

@app.websocket("/ws/agent-events")
async def agent_events_ws(ws: WebSocket):
    """Subscribe a client to agent events until it disconnects."""
    await event_hub.connect(ws)
    try:
        while True:
            # Keep connection alive — client can send pings; content is ignored.
            await ws.receive_text()
    except WebSocketDisconnect:
        # Normal client-initiated close.
        pass
    finally:
        # Always unregister, including on unexpected errors or cancellation
        # at shutdown, so the hub does not retain a stale client.
        event_hub.disconnect(ws)