# Files
# nova/api/main.py
#
# 388 lines
# 11 KiB
# Python

"""FastAPI application — triggers agents, exposes scenarios, streams events."""
import asyncio
import json
import logging
import uuid
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from agents.fce import run_fce
from agents.handover import run_handover
from agents.shared.mcp_client import connect_servers
from mcp_servers.data.scenarios.manager import scenario_manager
# Module-wide logger: every log call in this file goes through the "nova" logger.
logger = logging.getLogger("nova")
# Root logging configuration: INFO level, one line per record in the form
# "<time> <level> <logger name> <message>".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
# ── WebSocket event hub ──
class EventHub:
    """Fans out agent events to connected WebSocket clients.

    Connected clients are held in an in-memory set. ``broadcast`` sends one
    event to each of them and drops any client whose send fails.
    """

    def __init__(self):
        self._clients: set[WebSocket] = set()

    async def connect(self, ws: WebSocket):
        """Accept the WebSocket handshake, then register the client."""
        await ws.accept()
        self._clients.add(ws)

    def disconnect(self, ws: WebSocket):
        """Unregister a client; a no-op if it is not registered."""
        self._clients.discard(ws)

    async def broadcast(self, event: dict):
        """Send ``event`` as JSON to every registered client.

        Delivery is best effort: a client whose ``send_json`` raises is
        removed from the hub and the broadcast continues with the rest.
        """
        dead = set()
        # Iterate over a snapshot. Every send awaits, so connect(),
        # disconnect() or an overlapping broadcast can mutate the live set
        # mid-loop, which would raise "Set changed size during iteration".
        for ws in tuple(self._clients):
            try:
                await ws.send_json(event)
            except Exception:
                # Any send failure marks the client for removal.
                dead.add(ws)
        if dead:
            self._clients -= dead
            # Logged after removal so "remaining" is the real set size even
            # if a dead client was already unregistered in the meantime.
            logger.info("ws_cleanup removed=%d remaining=%d", len(dead), len(self._clients))
event_hub = EventHub()
# ── In-memory run store ──
# run_id -> run record. The route handlers store a "status" and "agent" in
# every record, plus the run's inputs, its result, or its error message.
runs: dict[str, dict] = {}
# Age in seconds after which a finished run is eligible for pruning.
RUN_TTL_SECONDS = 3600
# Seconds the cleanup loop sleeps between pruning passes.
RUN_CLEANUP_INTERVAL = 300
# ── App lifecycle ──
async def _cleanup_runs():
    """Prune completed/errored runs older than RUN_TTL_SECONDS.

    Runs forever: sleeps RUN_CLEANUP_INTERVAL seconds, then deletes every
    finished run whose ``created_at`` timestamp is more than RUN_TTL_SECONDS
    in the past. Runs in any other state (e.g. ``running``) are never pruned.

    A finished run that has no ``created_at`` is stamped with the current
    time the first time a pass sees it, so it expires one TTL later instead
    of staying in memory for the life of the process.
    """
    while True:
        await asyncio.sleep(RUN_CLEANUP_INTERVAL)
        now = datetime.now(timezone.utc).timestamp()
        expired = []
        for rid, record in runs.items():
            # .get() so a record without a status cannot kill the loop.
            if record.get("status") not in ("completed", "error"):
                continue
            # setdefault stores `now` when the timestamp is missing; the age
            # is then 0 on this pass and grows on the following ones.
            if now - record.setdefault("created_at", now) > RUN_TTL_SECONDS:
                expired.append(rid)
        # Deletion happens after the scan so the dict is not resized while
        # it is being iterated.
        for rid in expired:
            runs.pop(rid, None)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Keep the run-pruning task alive for the lifetime of the application.

    The task is started before the app begins serving and is cancelled, then
    awaited, when the app shuts down.
    """
    task = asyncio.create_task(_cleanup_runs())
    try:
        yield
    finally:
        # finally: stop the task even if an exception is thrown into the
        # context manager at the yield point.
        task.cancel()
        # Await the task so the cancellation is actually processed before
        # shutdown continues, rather than leaving a pending task behind.
        try:
            await task
        except asyncio.CancelledError:
            pass
# Application instance; `lifespan` is registered as the app's lifespan handler.
app = FastAPI(title="United Ops MCP Demo", lifespan=lifespan)
# NOTE(review): CORS is fully open here (any origin, method and header).
# Presumably acceptable for a demo; confirm before exposing this more widely.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── Health check ──
@app.get("/health")
async def health():
    # Liveness probe: a fixed "ok" status plus the active scenario id and the
    # number of run records currently held in memory.
    payload = {"status": "ok"}
    payload["scenario"] = scenario_manager.active_id
    payload["runs_in_memory"] = len(runs)
    return payload
# ── Request/Response models ──
# Request body for POST /agents/fce.
class FCERequest(BaseModel):
    # Passed to run_fce as the flight to process.
    flight_id: str
# Request body for POST /agents/handover.
class HandoverRequest(BaseModel):
    # Forwarded unchanged to run_handover; None when the client omits it.
    hubs: list[str] | None = None
# Request body for PUT /scenarios/active.
class ScenarioUpdate(BaseModel):
    # Passed to scenario_manager.set_active.
    scenario_id: str
# ── Agent routes ──
# Strong references to in-flight FCE tasks. The event loop only keeps weak
# references to tasks, so a task nobody references can be garbage-collected
# before it finishes.
_fce_tasks: set[asyncio.Task] = set()


@app.post("/agents/fce")
async def trigger_fce(req: FCERequest):
    """Start an FCE agent run in the background and return its run id."""
    run_id = str(uuid.uuid4())[:8]
    now = datetime.now(timezone.utc)
    record = {
        "status": "running", "agent": "fce",
        "flight_id": req.flight_id, "created_at": now.timestamp(),
    }
    runs[run_id] = record

    async def _run():
        await event_hub.broadcast({
            "type": "agent_start", "run_id": run_id,
            "agent": "fce", "flight_id": req.flight_id,
            "timestamp": now.isoformat(),
        })

        async def on_event(event):
            # Tag every agent event with this run's id before fanning it out.
            await event_hub.broadcast({"run_id": run_id, **event})

        try:
            async with connect_servers(["shared", "ops", "passenger"]) as mcp:
                result = await run_fce(req.flight_id, mcp, on_event=on_event)
            # Build the finished record on top of the original one so that
            # "created_at" (read by the run-pruning loop) and "flight_id"
            # survive completion.
            runs[run_id] = {**record, "status": "completed", "result": result}
            logger.info("agent_complete agent=fce run_id=%s flight=%s", run_id, req.flight_id)
        except Exception as e:
            runs[run_id] = {**record, "status": "error", "error": str(e)}
            # logger.exception keeps the same message and adds the traceback.
            logger.exception("agent_error agent=fce run_id=%s error=%s", run_id, e)

    logger.info("agent_start agent=fce run_id=%s flight=%s", run_id, req.flight_id)
    task = asyncio.create_task(_run())
    # Hold the task until it is done, then let it go.
    _fce_tasks.add(task)
    task.add_done_callback(_fce_tasks.discard)
    return {"run_id": run_id, "status": "running"}
# Strong references to in-flight handover tasks. The event loop only keeps
# weak references to tasks, so a task nobody references can be
# garbage-collected before it finishes.
_handover_tasks: set[asyncio.Task] = set()


@app.post("/agents/handover")
async def trigger_handover(req: HandoverRequest):
    """Start a handover agent run in the background and return its run id."""
    run_id = str(uuid.uuid4())[:8]
    now = datetime.now(timezone.utc)
    record = {
        "status": "running", "agent": "handover",
        "hubs": req.hubs, "created_at": now.timestamp(),
    }
    runs[run_id] = record

    async def _run():
        await event_hub.broadcast({
            "type": "agent_start", "run_id": run_id,
            "agent": "handover", "hubs": req.hubs,
            "timestamp": now.isoformat(),
        })

        async def on_event(event):
            # Tag every agent event with this run's id before fanning it out.
            await event_hub.broadcast({"run_id": run_id, **event})

        try:
            async with connect_servers(["shared", "ops"]) as mcp:
                result = await run_handover(hubs=req.hubs, mcp=mcp, on_event=on_event)
            # Build the finished record on top of the original one so that
            # "created_at" (read by the run-pruning loop) and "hubs" survive
            # completion.
            runs[run_id] = {**record, "status": "completed", "result": result}
            logger.info("agent_complete agent=handover run_id=%s hubs=%s", run_id, req.hubs)
        except Exception as e:
            runs[run_id] = {**record, "status": "error", "error": str(e)}
            # logger.exception keeps the same message and adds the traceback.
            logger.exception("agent_error agent=handover run_id=%s error=%s", run_id, e)

    logger.info("agent_start agent=handover run_id=%s hubs=%s", run_id, req.hubs)
    task = asyncio.create_task(_run())
    # Hold the task until it is done, then let it go.
    _handover_tasks.add(task)
    task.add_done_callback(_handover_tasks.discard)
    return {"run_id": run_id, "status": "running"}
@app.get("/agents/runs/{run_id}")
async def get_run(run_id: str):
    # Return the stored record for one run; an unknown id becomes a 404.
    record = runs.get(run_id)
    if record is None:
        raise HTTPException(404, detail=f"Run {run_id} not found")
    return record
@app.get("/agents/runs")
async def list_runs():
    # One summary per stored run, ordered by run id in descending string order.
    summaries = []
    for rid in sorted(runs, reverse=True):
        record = runs[rid]
        summaries.append(
            {"run_id": rid, "status": record["status"], "agent": record.get("agent")}
        )
    return summaries
# ── Scenario routes ──
@app.get("/scenarios")
async def list_scenarios():
    # Delegates to the scenario manager and returns its answer as-is.
    available = scenario_manager.list_scenarios()
    return available
@app.get("/scenarios/active")
async def get_active_scenario():
    # Delegates to the scenario manager and returns its metadata as-is.
    metadata = scenario_manager.get_metadata()
    return metadata
@app.put("/scenarios/active")
async def set_active_scenario(req: ScenarioUpdate):
    """Switch the active scenario; responds 400 if the manager rejects the id."""
    # Only the call that can raise ValueError sits inside the try.
    try:
        result = scenario_manager.set_active(req.scenario_id)
    except ValueError as e:
        # Chain the cause so the original error stays attached to the HTTP one.
        raise HTTPException(400, detail=str(e)) from e
    logger.info("scenario_switch scenario=%s", req.scenario_id)
    return result
# ── LLM config routes ──
@app.get("/config/llm")
async def get_llm_config():
    """Current LLM provider configuration."""
    from api.config import get_settings

    settings = get_settings()
    active_provider = settings.llm_provider

    # For each provider, "configured" only reports whether a credential is
    # set; the credential value itself is never included in the response.
    groq = {
        "configured": bool(settings.groq_api_key),
        "model": settings.groq_model,
    }
    anthropic = {
        "configured": bool(settings.anthropic_api_key),
        "model": settings.anthropic_model,
    }
    bedrock = {
        "configured": bool(settings.aws_access_key_id),
        "model": settings.bedrock_model_id,
    }
    openai = {
        "configured": bool(settings.openai_api_key),
        "model": settings.openai_model,
        "base_url": settings.openai_base_url,
    }
    # The template provider needs no credentials, so it is always available.
    template = {
        "configured": True,
        "model": "none (structured fallback)",
    }

    return {
        "provider": active_provider,
        "providers": {
            "groq": groq,
            "anthropic": anthropic,
            "bedrock": bedrock,
            "openai": openai,
            "template": template,
        },
    }
# Request body for PUT /config/llm.
class LLMConfigUpdate(BaseModel):
    # Written to the LLM_PROVIDER environment variable.
    provider: str
    # Optional credential; used for the groq, anthropic and openai providers.
    api_key: str | None = None
    # Optional model name / model id for the selected provider.
    model: str | None = None
    # Optional endpoint override; only used for the openai provider.
    base_url: str | None = None
@app.put("/config/llm")
async def set_llm_config(req: LLMConfigUpdate):
    """Switch LLM provider at runtime. Sets env vars for MCP subprocesses."""
    import os

    # provider -> {request field: environment variable} for the optional
    # fields that provider understands.
    env_by_provider = {
        "groq": {"api_key": "GROQ_API_KEY", "model": "GROQ_MODEL"},
        "anthropic": {"api_key": "ANTHROPIC_API_KEY", "model": "ANTHROPIC_MODEL"},
        "openai": {
            "api_key": "OPENAI_API_KEY",
            "model": "OPENAI_MODEL",
            "base_url": "OPENAI_BASE_URL",
        },
        "bedrock": {"model": "BEDROCK_MODEL_ID"},
    }

    os.environ["LLM_PROVIDER"] = req.provider
    # Each supplied field is applied on its own. A model or base_url change
    # therefore no longer requires re-sending the API key. A provider with no
    # entry above (e.g. "template") only updates LLM_PROVIDER.
    for field, env_name in env_by_provider.get(req.provider, {}).items():
        value = getattr(req, field)
        if value:
            os.environ[env_name] = value

    logger.info("llm_config_change provider=%s", req.provider)
    # NOTE(review): the response is built from get_settings(); if that helper
    # caches its result the reply may not reflect the new values. Confirm.
    return await get_llm_config()
# ── Scenario data routes ──
@app.get("/scenarios/data/flights")
async def get_scenario_flights():
    # Every flight of the scenario manager, dumped in JSON mode.
    flights = scenario_manager.flights
    return [flight.model_dump(mode="json") for flight in flights]
@app.get("/scenarios/data/crew")
async def get_scenario_crew():
    # Every crew member, dumped in JSON mode and extended with two derived
    # fields: the duty hours left before the limit, and an at-risk flag.
    def _with_duty_margin(member):
        row = member.model_dump(mode="json")
        remaining = round(member.duty_hours_limit - member.duty_hours_elapsed, 2)
        row["hours_until_limit"] = remaining
        # At risk when two hours or less remain.
        row["at_risk"] = remaining <= 2.0
        return row

    return [_with_duty_margin(member) for member in scenario_manager.crew]
@app.get("/scenarios/data/crew-notes")
async def get_scenario_crew_notes():
    # Returns the scenario manager's crew notes as-is.
    notes = scenario_manager.crew_notes
    return notes
@app.get("/scenarios/data/maintenance")
async def get_scenario_maintenance():
    # Same keys as the scenario manager's maintenance mapping; each value is
    # the list of that key's items dumped in JSON mode.
    return {
        tail: [item.model_dump(mode="json") for item in items]
        for tail, items in scenario_manager.maintenance.items()
    }
@app.get("/scenarios/data/rebookings")
async def get_scenario_rebookings():
    # Every rebooking of the scenario manager, dumped in JSON mode.
    rebookings = scenario_manager.rebookings
    return [rebooking.model_dump(mode="json") for rebooking in rebookings]
# Request body for PATCH /scenarios/data/flights/{flight_id}.
# A field left as None is not applied to the flight.
class FlightPatch(BaseModel):
    delay_minutes: int | None = None
    status: str | None = None
    gate: str | None = None
@app.patch("/scenarios/data/flights/{flight_id}")
async def patch_flight(flight_id: str, patch: FlightPatch):
    # Find the first flight with a matching id; no match becomes a 404.
    target = next(
        (flight for flight in scenario_manager.flights if flight.flight_id == flight_id),
        None,
    )
    if target is None:
        raise HTTPException(404, detail=f"Flight {flight_id} not found")

    # Copy over only the fields the client actually supplied.
    for field in ("delay_minutes", "status", "gate"):
        new_value = getattr(patch, field)
        if new_value is not None:
            setattr(target, field, new_value)
    return target.model_dump(mode="json")
# Request body for PATCH /scenarios/data/crew/{crew_id}.
class CrewPatch(BaseModel):
    # New elapsed duty hours; left unchanged when None.
    duty_hours_elapsed: float | None = None
@app.patch("/scenarios/data/crew/{crew_id}")
async def patch_crew(crew_id: str, patch: CrewPatch):
    # Find the first crew member with a matching id; no match becomes a 404.
    member = next(
        (candidate for candidate in scenario_manager.crew if candidate.crew_id == crew_id),
        None,
    )
    if member is None:
        raise HTTPException(404, detail=f"Crew {crew_id} not found")

    if patch.duty_hours_elapsed is not None:
        member.duty_hours_elapsed = patch.duty_hours_elapsed

    # Respond with the member plus the derived duty-margin fields, computed
    # after the update has been applied.
    row = member.model_dump(mode="json")
    remaining = round(member.duty_hours_limit - member.duty_hours_elapsed, 2)
    row["hours_until_limit"] = remaining
    row["at_risk"] = remaining <= 2.0
    return row
# Request body for PUT /scenarios/data/crew-notes/{flight_id}.
class CrewNotesPatch(BaseModel):
    # Stored as the complete notes list for the flight.
    notes: list[str]
@app.put("/scenarios/data/crew-notes/{flight_id}")
async def put_crew_notes(flight_id: str, patch: CrewNotesPatch):
    # Overwrite the notes stored for this flight and echo them back.
    notes = patch.notes
    scenario_manager.crew_notes[flight_id] = notes
    return {"flight_id": flight_id, "notes": notes}
# ── WebSocket ──
@app.websocket("/ws/agent-events")
async def agent_events_ws(ws: WebSocket):
    """Register a client on the event hub and hold the socket open until it closes."""
    await event_hub.connect(ws)
    try:
        while True:
            # Keep connection alive — client can send pings
            await ws.receive_text()
    except WebSocketDisconnect:
        # Normal client disconnect; nothing to report.
        pass
    finally:
        # Unregister on every exit path, not only on a clean disconnect, so a
        # socket that fails some other way does not linger in the hub.
        event_hub.disconnect(ws)