262 lines
7.4 KiB
Python
262 lines
7.4 KiB
Python
"""FastAPI application — triggers agents, exposes scenarios, streams events."""
|
|
|
|
import asyncio
|
|
import json
|
|
import uuid
|
|
from contextlib import asynccontextmanager
|
|
from datetime import datetime, timezone
|
|
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pydantic import BaseModel
|
|
|
|
from agents.efhas import run_efhas
|
|
from agents.handover import run_handover
|
|
from agents.shared.mcp_client import connect_servers
|
|
from mcp_servers.data.scenarios.manager import scenario_manager
|
|
|
|
|
|
# ── WebSocket event hub ──
|
|
|
|
class EventHub:
    """Fans out agent events to connected WebSocket clients.

    The hub keeps the set of registered sockets. A socket whose send fails
    during a broadcast is dropped from the set.
    """

    def __init__(self):
        self._clients: set[WebSocket] = set()

    async def connect(self, ws: WebSocket):
        """Accept the WebSocket handshake, then register ``ws`` for broadcasts."""
        await ws.accept()
        self._clients.add(ws)

    def disconnect(self, ws: WebSocket):
        """Unregister ``ws``; does nothing if it is not registered."""
        self._clients.discard(ws)

    async def broadcast(self, event: dict):
        """Send ``event`` as JSON to every registered client.

        Clients whose ``send_json`` raises are removed from the hub; the
        failure itself is not propagated to the caller.
        """
        dead = set()
        # Iterate over a snapshot: every ``await`` yields to the event loop,
        # and a connect()/disconnect() running in the meantime would mutate
        # the live set mid-iteration and raise RuntimeError.
        for ws in tuple(self._clients):
            try:
                await ws.send_json(event)
            except Exception:
                # Best effort: a failed send marks the client as gone.
                dead.add(ws)
        self._clients -= dead
|
|
|
|
|
|
# Single module-level hub instance; the agent routes broadcast through it and
# the WebSocket endpoint registers clients on it.
event_hub = EventHub()
|
|
|
|
# ── In-memory run store ──
|
|
|
|
# Run records keyed by run id. Held only in this module-level dict.
runs: dict[str, dict] = dict()
|
|
|
|
|
|
# ── App lifecycle ──
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook; performs no startup or shutdown work."""
    # Nothing runs before or after this point: control simply passes to the
    # application for as long as the context is open.
    yield None
|
|
|
|
# Application instance, wired to the lifespan hook defined in this module.
app = FastAPI(
    title="United Ops MCP Demo",
    lifespan=lifespan,
)

# CORS is fully open: every origin, method and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
# ── Request/Response models ──
|
|
|
|
class EFHaSRequest(BaseModel):
    """Request body for POST /agents/efhas."""

    # Identifier of the flight handed to the EFHaS agent.
    flight_id: str
|
|
|
|
class HandoverRequest(BaseModel):
    """Request body for POST /agents/handover."""

    # Optional list of hubs, passed through to the handover agent unchanged
    # (``None`` when the client omits it).
    hubs: list[str] | None = None
|
|
|
|
class ScenarioUpdate(BaseModel):
    """Request body for PUT /scenarios/active."""

    # Identifier of the scenario to activate.
    scenario_id: str
|
|
|
|
|
|
# ── Agent routes ──
|
|
|
|
# Strong references to in-flight EFHaS runs: the event loop itself only holds
# weak references to tasks, so an otherwise unreferenced task may be garbage
# collected before it finishes.
_efhas_tasks: set[asyncio.Task] = set()


@app.post("/agents/efhas")
async def trigger_efhas(req: EFHaSRequest):
    """Start an EFHaS agent run in the background and return its run id.

    The run is recorded in ``runs`` as "running" and the response is returned
    immediately. The background task broadcasts an ``agent_start`` event,
    connects to the "shared", "ops" and "passenger" MCP servers, runs the
    agent while forwarding its events (tagged with ``run_id``) to WebSocket
    clients, and finally replaces the run record with either the result or
    the error message.

    Returns:
        ``{"run_id": <id>, "status": "running"}`` where the id is the first
        8 characters of a random UUID4.
    """
    run_id = str(uuid.uuid4())[:8]
    runs[run_id] = {"status": "running", "agent": "efhas", "flight_id": req.flight_id}

    async def _run():
        async def on_event(event):
            # Tag every agent event with the run it belongs to before fan-out.
            await event_hub.broadcast({"run_id": run_id, **event})

        try:
            # The start broadcast is inside the ``try`` so that a failure here
            # is recorded on the run instead of leaving it stuck in "running".
            await event_hub.broadcast({
                "type": "agent_start", "run_id": run_id,
                "agent": "efhas", "flight_id": req.flight_id,
                "timestamp": datetime.now(timezone.utc).isoformat(),
            })
            async with connect_servers(["shared", "ops", "passenger"]) as mcp:
                result = await run_efhas(req.flight_id, mcp, on_event=on_event)
            runs[run_id] = {"status": "completed", "agent": "efhas", "result": result}
        except Exception as e:
            # Task boundary: nobody awaits this task, so the failure is stored
            # on the run record where GET /agents/runs/{run_id} can report it.
            runs[run_id] = {"status": "error", "agent": "efhas", "error": str(e)}

    task = asyncio.create_task(_run())
    # Hold the task until it completes, then drop the reference.
    _efhas_tasks.add(task)
    task.add_done_callback(_efhas_tasks.discard)
    return {"run_id": run_id, "status": "running"}
|
|
|
|
|
|
# Strong references to in-flight handover runs: the event loop itself only
# holds weak references to tasks, so an otherwise unreferenced task may be
# garbage collected before it finishes.
_handover_tasks: set[asyncio.Task] = set()


@app.post("/agents/handover")
async def trigger_handover(req: HandoverRequest):
    """Start a handover agent run in the background and return its run id.

    The run is recorded in ``runs`` as "running" and the response is returned
    immediately. The background task broadcasts an ``agent_start`` event,
    connects to the "shared" and "ops" MCP servers, runs the agent while
    forwarding its events (tagged with ``run_id``) to WebSocket clients, and
    finally replaces the run record with either the result or the error
    message.

    Returns:
        ``{"run_id": <id>, "status": "running"}`` where the id is the first
        8 characters of a random UUID4.
    """
    run_id = str(uuid.uuid4())[:8]
    runs[run_id] = {"status": "running", "agent": "handover", "hubs": req.hubs}

    async def _run():
        async def on_event(event):
            # Tag every agent event with the run it belongs to before fan-out.
            await event_hub.broadcast({"run_id": run_id, **event})

        try:
            # The start broadcast is inside the ``try`` so that a failure here
            # is recorded on the run instead of leaving it stuck in "running".
            await event_hub.broadcast({
                "type": "agent_start", "run_id": run_id,
                "agent": "handover", "hubs": req.hubs,
                "timestamp": datetime.now(timezone.utc).isoformat(),
            })
            async with connect_servers(["shared", "ops"]) as mcp:
                result = await run_handover(hubs=req.hubs, mcp=mcp, on_event=on_event)
            runs[run_id] = {"status": "completed", "agent": "handover", "result": result}
        except Exception as e:
            # Task boundary: nobody awaits this task, so the failure is stored
            # on the run record where GET /agents/runs/{run_id} can report it.
            runs[run_id] = {"status": "error", "agent": "handover", "error": str(e)}

    task = asyncio.create_task(_run())
    # Hold the task until it completes, then drop the reference.
    _handover_tasks.add(task)
    task.add_done_callback(_handover_tasks.discard)
    return {"run_id": run_id, "status": "running"}
|
|
|
|
|
|
@app.get("/agents/runs/{run_id}")
async def get_run(run_id: str):
    """Return the stored record for ``run_id``.

    An unknown id yields an ``{"error": ...}`` payload rather than an HTTP
    error status.
    """
    try:
        return runs[run_id]
    except KeyError:
        return {"error": f"Run {run_id} not found"}
|
|
|
|
|
|
@app.get("/agents/runs")
async def list_runs():
    """Return a summary (id, status, agent) of every stored run.

    Entries are ordered by run id, descending.
    """
    summaries = []
    # Run ids are unique dict keys, so ordering by key alone gives the same
    # order as sorting the (id, record) pairs.
    for rid in sorted(runs, reverse=True):
        record = runs[rid]
        summaries.append(
            {"run_id": rid, "status": record["status"], "agent": record.get("agent")}
        )
    return summaries
|
|
|
|
|
|
# ── Scenario routes ──
|
|
|
|
@app.get("/scenarios")
async def list_scenarios():
    """Return the scenario manager's scenario listing as-is."""
    available = scenario_manager.list_scenarios()
    return available
|
|
|
|
|
|
@app.get("/scenarios/active")
async def get_active_scenario():
    """Return the scenario manager's metadata as-is."""
    metadata = scenario_manager.get_metadata()
    return metadata
|
|
|
|
|
|
@app.put("/scenarios/active")
async def set_active_scenario(req: ScenarioUpdate):
    """Activate the requested scenario and return the manager's response.

    A ``ValueError`` from the scenario manager is reported as an
    ``{"error": ...}`` payload instead of being raised.
    """
    try:
        outcome = scenario_manager.set_active(req.scenario_id)
    except ValueError as exc:
        return {"error": str(exc)}
    return outcome
|
|
|
|
|
|
# ── Scenario data routes ──
|
|
|
|
@app.get("/scenarios/data/flights")
async def get_scenario_flights():
    """Return every flight held by the scenario manager, dumped in JSON mode."""
    payload = []
    for flight in scenario_manager.flights:
        payload.append(flight.model_dump(mode="json"))
    return payload
|
|
|
|
|
|
@app.get("/scenarios/data/crew")
async def get_scenario_crew():
    """Return every crew member with two derived duty-time fields added.

    Each entry is the member's JSON-mode dump plus ``hours_until_limit``
    (limit minus elapsed, rounded to 2 decimals) and ``at_risk`` (true when
    that remainder is 2.0 or less).
    """

    def _serialize(member):
        row = member.model_dump(mode="json")
        row["hours_until_limit"] = round(
            member.duty_hours_limit - member.duty_hours_elapsed, 2
        )
        row["at_risk"] = row["hours_until_limit"] <= 2.0
        return row

    return [_serialize(member) for member in scenario_manager.crew]
|
|
|
|
|
|
@app.get("/scenarios/data/crew-notes")
async def get_scenario_crew_notes():
    """Return the scenario manager's crew notes as-is."""
    notes = scenario_manager.crew_notes
    return notes
|
|
|
|
|
|
@app.get("/scenarios/data/maintenance")
async def get_scenario_maintenance():
    """Return maintenance items grouped under the same keys the manager uses.

    Each item is dumped in JSON mode.
    """
    return {
        tail: [item.model_dump(mode="json") for item in items]
        for tail, items in scenario_manager.maintenance.items()
    }
|
|
|
|
|
|
@app.get("/scenarios/data/rebookings")
async def get_scenario_rebookings():
    """Return every rebooking held by the scenario manager, dumped in JSON mode."""
    payload = []
    for rebooking in scenario_manager.rebookings:
        payload.append(rebooking.model_dump(mode="json"))
    return payload
|
|
|
|
|
|
class FlightPatch(BaseModel):
    """Request body for PATCH /scenarios/data/flights/{flight_id}.

    Every field is optional and defaults to ``None``.
    """

    delay_minutes: int | None = None
    status: str | None = None
    gate: str | None = None
|
|
|
|
|
|
@app.patch("/scenarios/data/flights/{flight_id}")
async def patch_flight(flight_id: str, patch: FlightPatch):
    """Apply the non-``None`` fields of ``patch`` to the matching flight.

    Returns the updated flight dumped in JSON mode, or an ``{"error": ...}``
    payload when no flight has the given id. Only the first matching flight
    is modified.
    """
    target = next(
        (flight for flight in scenario_manager.flights if flight.flight_id == flight_id),
        None,
    )
    if target is None:
        return {"error": f"Flight {flight_id} not found"}

    # Same fields, same order as the patch model declares them.
    for field in ("delay_minutes", "status", "gate"):
        value = getattr(patch, field)
        if value is not None:
            setattr(target, field, value)
    return target.model_dump(mode="json")
|
|
|
|
|
|
class CrewPatch(BaseModel):
    """Request body for PATCH /scenarios/data/crew/{crew_id}."""

    # Optional new value for the crew member's elapsed duty hours.
    duty_hours_elapsed: float | None = None
|
|
|
|
|
|
@app.patch("/scenarios/data/crew/{crew_id}")
async def patch_crew(crew_id: str, patch: CrewPatch):
    """Update a crew member's elapsed duty hours and return the member.

    The response is the member's JSON-mode dump plus ``hours_until_limit``
    (limit minus elapsed, rounded to 2 decimals) and ``at_risk`` (true when
    that remainder is 2.0 or less). An unknown id yields an
    ``{"error": ...}`` payload.
    """
    for member in scenario_manager.crew:
        if member.crew_id != crew_id:
            continue
        new_elapsed = patch.duty_hours_elapsed
        if new_elapsed is not None:
            member.duty_hours_elapsed = new_elapsed
        row = member.model_dump(mode="json")
        row["hours_until_limit"] = round(
            member.duty_hours_limit - member.duty_hours_elapsed, 2
        )
        row["at_risk"] = row["hours_until_limit"] <= 2.0
        return row
    return {"error": f"Crew {crew_id} not found"}
|
|
|
|
|
|
class CrewNotesPatch(BaseModel):
    """Request body for PUT /scenarios/data/crew-notes/{flight_id}."""

    # Complete list of notes to store for the flight.
    notes: list[str]
|
|
|
|
|
|
@app.put("/scenarios/data/crew-notes/{flight_id}")
async def put_crew_notes(flight_id: str, patch: CrewNotesPatch):
    """Replace the stored crew notes for ``flight_id`` and echo them back."""
    notes = patch.notes
    scenario_manager.crew_notes[flight_id] = notes
    return {"flight_id": flight_id, "notes": notes}
|
|
|
|
|
|
# ── WebSocket ──
|
|
|
|
@app.websocket("/ws/agent-events")
async def agent_events_ws(ws: WebSocket):
    """Register a WebSocket client with the event hub until it goes away.

    Incoming text messages are read and discarded; the receive loop exists
    to keep the connection open and to notice when the client disconnects.
    """
    await event_hub.connect(ws)
    try:
        while True:
            # Keep connection alive — client can send pings
            await ws.receive_text()
    except WebSocketDisconnect:
        # Normal client departure; cleanup happens in ``finally``.
        pass
    finally:
        # Always unregister, including on unexpected errors or cancellation,
        # so the hub never keeps a socket whose handler has already exited.
        event_hub.disconnect(ws)
|