init commit
This commit is contained in:
261
api/main.py
Normal file
261
api/main.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""FastAPI application — triggers agents, exposes scenarios, streams events."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agents.efhas import run_efhas
|
||||
from agents.handover import run_handover
|
||||
from agents.shared.mcp_client import connect_servers
|
||||
from mcp_servers.data.scenarios.manager import scenario_manager
|
||||
|
||||
|
||||
# ── WebSocket event hub ──
|
||||
|
||||
class EventHub:
|
||||
"""Fans out agent events to connected WebSocket clients."""
|
||||
|
||||
def __init__(self):
|
||||
self._clients: set[WebSocket] = set()
|
||||
|
||||
async def connect(self, ws: WebSocket):
|
||||
await ws.accept()
|
||||
self._clients.add(ws)
|
||||
|
||||
def disconnect(self, ws: WebSocket):
|
||||
self._clients.discard(ws)
|
||||
|
||||
async def broadcast(self, event: dict):
|
||||
dead = set()
|
||||
for ws in self._clients:
|
||||
try:
|
||||
await ws.send_json(event)
|
||||
except Exception:
|
||||
dead.add(ws)
|
||||
self._clients -= dead
|
||||
|
||||
|
||||
event_hub = EventHub()
|
||||
|
||||
# ── In-memory run store ──
|
||||
|
||||
runs: dict[str, dict] = {}
|
||||
|
||||
|
||||
# ── App lifecycle ──
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
yield
|
||||
|
||||
app = FastAPI(title="United Ops MCP Demo", lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
# ── Request/Response models ──
|
||||
|
||||
class EFHaSRequest(BaseModel):
|
||||
flight_id: str
|
||||
|
||||
class HandoverRequest(BaseModel):
|
||||
hubs: list[str] | None = None
|
||||
|
||||
class ScenarioUpdate(BaseModel):
|
||||
scenario_id: str
|
||||
|
||||
|
||||
# ── Agent routes ──
|
||||
|
||||
@app.post("/agents/efhas")
|
||||
async def trigger_efhas(req: EFHaSRequest):
|
||||
run_id = str(uuid.uuid4())[:8]
|
||||
runs[run_id] = {"status": "running", "agent": "efhas", "flight_id": req.flight_id}
|
||||
|
||||
async def _run():
|
||||
await event_hub.broadcast({
|
||||
"type": "agent_start", "run_id": run_id,
|
||||
"agent": "efhas", "flight_id": req.flight_id,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
|
||||
async def on_event(event):
|
||||
await event_hub.broadcast({"run_id": run_id, **event})
|
||||
|
||||
try:
|
||||
async with connect_servers(["shared", "ops", "passenger"]) as mcp:
|
||||
result = await run_efhas(req.flight_id, mcp, on_event=on_event)
|
||||
runs[run_id] = {"status": "completed", "agent": "efhas", "result": result}
|
||||
except Exception as e:
|
||||
runs[run_id] = {"status": "error", "agent": "efhas", "error": str(e)}
|
||||
|
||||
asyncio.create_task(_run())
|
||||
return {"run_id": run_id, "status": "running"}
|
||||
|
||||
|
||||
@app.post("/agents/handover")
|
||||
async def trigger_handover(req: HandoverRequest):
|
||||
run_id = str(uuid.uuid4())[:8]
|
||||
runs[run_id] = {"status": "running", "agent": "handover", "hubs": req.hubs}
|
||||
|
||||
async def _run():
|
||||
await event_hub.broadcast({
|
||||
"type": "agent_start", "run_id": run_id,
|
||||
"agent": "handover", "hubs": req.hubs,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
|
||||
async def on_event(event):
|
||||
await event_hub.broadcast({"run_id": run_id, **event})
|
||||
|
||||
try:
|
||||
async with connect_servers(["shared", "ops"]) as mcp:
|
||||
result = await run_handover(hubs=req.hubs, mcp=mcp, on_event=on_event)
|
||||
runs[run_id] = {"status": "completed", "agent": "handover", "result": result}
|
||||
except Exception as e:
|
||||
runs[run_id] = {"status": "error", "agent": "handover", "error": str(e)}
|
||||
|
||||
asyncio.create_task(_run())
|
||||
return {"run_id": run_id, "status": "running"}
|
||||
|
||||
|
||||
@app.get("/agents/runs/{run_id}")
|
||||
async def get_run(run_id: str):
|
||||
if run_id not in runs:
|
||||
return {"error": f"Run {run_id} not found"}
|
||||
return runs[run_id]
|
||||
|
||||
|
||||
@app.get("/agents/runs")
|
||||
async def list_runs():
|
||||
return [
|
||||
{"run_id": rid, "status": r["status"], "agent": r.get("agent")}
|
||||
for rid, r in sorted(runs.items(), reverse=True)
|
||||
]
|
||||
|
||||
|
||||
# ── Scenario routes ──
|
||||
|
||||
@app.get("/scenarios")
|
||||
async def list_scenarios():
|
||||
return scenario_manager.list_scenarios()
|
||||
|
||||
|
||||
@app.get("/scenarios/active")
|
||||
async def get_active_scenario():
|
||||
return scenario_manager.get_metadata()
|
||||
|
||||
|
||||
@app.put("/scenarios/active")
|
||||
async def set_active_scenario(req: ScenarioUpdate):
|
||||
try:
|
||||
return scenario_manager.set_active(req.scenario_id)
|
||||
except ValueError as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
# ── Scenario data routes ──
|
||||
|
||||
@app.get("/scenarios/data/flights")
|
||||
async def get_scenario_flights():
|
||||
return [f.model_dump(mode="json") for f in scenario_manager.flights]
|
||||
|
||||
|
||||
@app.get("/scenarios/data/crew")
|
||||
async def get_scenario_crew():
|
||||
crew = []
|
||||
for c in scenario_manager.crew:
|
||||
d = c.model_dump(mode="json")
|
||||
d["hours_until_limit"] = round(c.duty_hours_limit - c.duty_hours_elapsed, 2)
|
||||
d["at_risk"] = d["hours_until_limit"] <= 2.0
|
||||
crew.append(d)
|
||||
return crew
|
||||
|
||||
|
||||
@app.get("/scenarios/data/crew-notes")
|
||||
async def get_scenario_crew_notes():
|
||||
return scenario_manager.crew_notes
|
||||
|
||||
|
||||
@app.get("/scenarios/data/maintenance")
|
||||
async def get_scenario_maintenance():
|
||||
result = {}
|
||||
for tail, items in scenario_manager.maintenance.items():
|
||||
result[tail] = [i.model_dump(mode="json") for i in items]
|
||||
return result
|
||||
|
||||
|
||||
@app.get("/scenarios/data/rebookings")
|
||||
async def get_scenario_rebookings():
|
||||
return [r.model_dump(mode="json") for r in scenario_manager.rebookings]
|
||||
|
||||
|
||||
class FlightPatch(BaseModel):
|
||||
delay_minutes: int | None = None
|
||||
status: str | None = None
|
||||
gate: str | None = None
|
||||
|
||||
|
||||
@app.patch("/scenarios/data/flights/{flight_id}")
|
||||
async def patch_flight(flight_id: str, patch: FlightPatch):
|
||||
for f in scenario_manager.flights:
|
||||
if f.flight_id == flight_id:
|
||||
if patch.delay_minutes is not None:
|
||||
f.delay_minutes = patch.delay_minutes
|
||||
if patch.status is not None:
|
||||
f.status = patch.status
|
||||
if patch.gate is not None:
|
||||
f.gate = patch.gate
|
||||
return f.model_dump(mode="json")
|
||||
return {"error": f"Flight {flight_id} not found"}
|
||||
|
||||
|
||||
class CrewPatch(BaseModel):
|
||||
duty_hours_elapsed: float | None = None
|
||||
|
||||
|
||||
@app.patch("/scenarios/data/crew/{crew_id}")
|
||||
async def patch_crew(crew_id: str, patch: CrewPatch):
|
||||
for c in scenario_manager.crew:
|
||||
if c.crew_id == crew_id:
|
||||
if patch.duty_hours_elapsed is not None:
|
||||
c.duty_hours_elapsed = patch.duty_hours_elapsed
|
||||
d = c.model_dump(mode="json")
|
||||
d["hours_until_limit"] = round(c.duty_hours_limit - c.duty_hours_elapsed, 2)
|
||||
d["at_risk"] = d["hours_until_limit"] <= 2.0
|
||||
return d
|
||||
return {"error": f"Crew {crew_id} not found"}
|
||||
|
||||
|
||||
class CrewNotesPatch(BaseModel):
|
||||
notes: list[str]
|
||||
|
||||
|
||||
@app.put("/scenarios/data/crew-notes/{flight_id}")
|
||||
async def put_crew_notes(flight_id: str, patch: CrewNotesPatch):
|
||||
scenario_manager.crew_notes[flight_id] = patch.notes
|
||||
return {"flight_id": flight_id, "notes": patch.notes}
|
||||
|
||||
|
||||
# ── WebSocket ──
|
||||
|
||||
@app.websocket("/ws/agent-events")
|
||||
async def agent_events_ws(ws: WebSocket):
|
||||
await event_hub.connect(ws)
|
||||
try:
|
||||
while True:
|
||||
# Keep connection alive — client can send pings
|
||||
await ws.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
event_hub.disconnect(ws)
|
||||
Reference in New Issue
Block a user