435 lines
13 KiB
Python
435 lines
13 KiB
Python
"""FastAPI application — triggers agents, exposes scenarios, streams events."""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import uuid
|
|
from contextlib import asynccontextmanager
|
|
from datetime import datetime, timezone
|
|
|
|
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pydantic import BaseModel
|
|
|
|
from agents.fce import run_fce
|
|
from agents.handover import run_handover
|
|
from agents.shared.mcp_client import connect_servers
|
|
from mcp_servers.data.scenarios.manager import scenario_manager
|
|
|
|
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
    level=logging.INFO,
)
# Application logger; its records propagate to the root handler configured above.
logger = logging.getLogger("nova")
|
|
|
|
|
|
def _get_langfuse():
    """Lazy Langfuse client — returns None if not configured.

    Builds a fresh client from ``api.config.get_settings()`` on every call.
    Returns None when either Langfuse key is unset. Also returns None when
    loading settings, importing ``langfuse`` or constructing the client
    fails: tracing is optional, so an agent run must not fail because of it.
    Such failures are logged rather than silently swallowed, so a
    misconfiguration (bad host, missing package) is visible.
    """
    try:
        from api.config import get_settings

        s = get_settings()
        if not s.langfuse_public_key or not s.langfuse_secret_key:
            return None

        from langfuse import Langfuse

        return Langfuse(
            public_key=s.langfuse_public_key,
            secret_key=s.langfuse_secret_key,
            host=s.langfuse_host,
        )
    except Exception as e:
        logger.warning("langfuse_unavailable error=%s", e)
        return None
|
|
|
|
|
|
# ── WebSocket event hub ──
|
|
|
|
class EventHub:
    """Fans out agent events to connected WebSocket clients.

    Clients are registered by ``connect`` and removed either explicitly by
    ``disconnect`` or automatically when a send to them fails in
    ``broadcast``.
    """

    def __init__(self):
        self._clients: set[WebSocket] = set()

    async def connect(self, ws: WebSocket):
        """Accept the WebSocket handshake and start sending events to *ws*."""
        await ws.accept()
        self._clients.add(ws)

    def disconnect(self, ws: WebSocket):
        """Stop sending events to *ws*; a no-op if it is not registered."""
        self._clients.discard(ws)

    async def broadcast(self, event: dict):
        """Send *event* as JSON to every client; drop clients whose send raises."""
        dead = set()
        # Iterate over a snapshot: each send awaits, and while this coroutine
        # is suspended others may connect/disconnect clients. Iterating the
        # live set would then raise "Set changed size during iteration".
        for ws in tuple(self._clients):
            try:
                await ws.send_json(event)
            except Exception:
                dead.add(ws)
        if dead:
            before = len(self._clients)
            self._clients -= dead
            # Count only clients actually removed here (some may already have
            # been disconnected concurrently).
            logger.info(
                "ws_cleanup removed=%d remaining=%d",
                before - len(self._clients), len(self._clients),
            )


# Process-wide hub shared by the agent routes and the WebSocket endpoint.
event_hub = EventHub()
|
|
|
|
# ── In-memory run store ──
|
|
|
|
# run_id -> run record (status, agent, request params, result/error, ...).
# Process-local; finished entries are pruned by the periodic TTL sweep.
runs: dict[str, dict] = dict()

RUN_TTL_SECONDS = 60 * 60  # finished runs older than one hour are pruned
RUN_CLEANUP_INTERVAL = 5 * 60  # the sweep wakes up every five minutes
|
|
|
|
|
|
# ── App lifecycle ──
|
|
|
|
async def _cleanup_runs():
    """Periodically prune finished runs older than RUN_TTL_SECONDS.

    Every RUN_CLEANUP_INTERVAL seconds, removes entries of ``runs`` whose
    status is "completed" or "error" and whose ``created_at`` timestamp lies
    more than RUN_TTL_SECONDS in the past. Loops until cancelled.
    """
    while True:
        await asyncio.sleep(RUN_CLEANUP_INTERVAL)
        now = datetime.now(timezone.utc).timestamp()
        expired = []
        for rid, r in runs.items():
            if r.get("status") not in ("completed", "error"):
                continue
            # A finished record without ``created_at`` previously compared as
            # ``now - now == 0`` on every sweep and so leaked forever. Start
            # its TTL clock the first time the sweeper sees it instead.
            created = r.setdefault("created_at", now)
            if now - created > RUN_TTL_SECONDS:
                expired.append(rid)
        for rid in expired:
            runs.pop(rid, None)
        if expired:
            logger.info("runs_cleanup removed=%d remaining=%d", len(expired), len(runs))
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the ``_cleanup_runs`` sweeper for the lifetime of the app."""
    task = asyncio.create_task(_cleanup_runs())
    try:
        yield
    finally:
        # cancel() only *requests* cancellation. Awaiting the task lets it
        # unwind before shutdown proceeds (no "Task was destroyed but it is
        # pending" warning). The finally guarantees this even if an exception
        # is thrown in at the yield. return_exceptions=True absorbs the
        # task's own CancelledError without swallowing a cancellation aimed
        # at this coroutine.
        task.cancel()
        await asyncio.gather(task, return_exceptions=True)
|
|
|
|
app = FastAPI(title="United Ops MCP Demo", lifespan=lifespan)

# Fully permissive CORS: any origin, method and header may call the API.
# NOTE(review): no authentication is visible in this file, so this also lets
# any web page call PUT /config/llm; acceptable for a local demo only.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
|
|
|
|
|
|
# ── Health check ──
|
|
|
|
@app.get("/health")
async def health():
    # Liveness probe: also reports the active scenario id and run-store size.
    payload = {"status": "ok"}
    payload["scenario"] = scenario_manager.active_id
    payload["runs_in_memory"] = len(runs)
    return payload
|
|
|
|
|
|
# ── Request/Response models ──
|
|
|
|
class FCERequest(BaseModel):
    # Body of POST /agents/fce; flight_id is passed straight to run_fce.
    flight_id: str


class HandoverRequest(BaseModel):
    # Body of POST /agents/handover; hubs is passed to run_handover as-is.
    # NOTE(review): None presumably means "all hubs" — confirm in run_handover.
    hubs: list[str] | None = None


class ScenarioUpdate(BaseModel):
    # Body of PUT /scenarios/active; handed to scenario_manager.set_active.
    scenario_id: str
|
|
|
|
|
|
# ── Agent routes ──
|
|
|
|
# Strong references to in-flight agent tasks. The event loop keeps only weak
# references to tasks, so an unreferenced ``asyncio.create_task`` result may be
# garbage-collected before it finishes.
_agent_tasks: set[asyncio.Task] = set()


def _start_agent_run(agent: str, params: dict, servers: list[str], invoke, log_detail: tuple) -> dict:
    """Register a run in ``runs``, launch it in the background, and return a handle.

    Args:
        agent: Agent name. Used in the run record, WebSocket events, log
            lines and as the Langfuse observation name.
        params: Request parameters. Stored on the run record, merged into
            the ``agent_start`` event and sent as the Langfuse input.
        servers: MCP server names passed to ``connect_servers``.
        invoke: Callable ``invoke(mcp, on_event, **extra)`` returning an
            awaitable agent result. ``extra`` is ``{"langfuse": client}``
            when Langfuse is configured and empty otherwise.
        log_detail: ``(format, value)`` appended to the start/complete log
            lines, e.g. ``("flight=%s", flight_id)``.

    Returns:
        ``{"run_id": <id>, "status": "running"}``.
    """
    run_id = str(uuid.uuid4())[:8]
    now = datetime.now(timezone.utc)
    # Fields kept in every state of the record. Retaining ``created_at`` after
    # completion/error is what lets ``_cleanup_runs`` expire the entry.
    base = {"agent": agent, **params, "created_at": now.timestamp()}
    runs[run_id] = {"status": "running", **base}
    detail_fmt, detail_value = log_detail

    async def on_event(event):
        await event_hub.broadcast({"run_id": run_id, **event})

    async def _run():
        langfuse = _get_langfuse()
        try:
            # Inside the try so a failed broadcast marks the run as errored
            # instead of leaving it "running" forever.
            await event_hub.broadcast({
                "type": "agent_start", "run_id": run_id,
                "agent": agent, **params,
                "timestamp": now.isoformat(),
            })
            if langfuse:
                with langfuse.start_as_current_observation(
                    name=agent, as_type="agent",
                    input=params,
                    metadata={"run_id": run_id, "scenario": scenario_manager.active_id},
                ):
                    async with connect_servers(servers) as mcp:
                        result = await invoke(mcp, on_event, langfuse=langfuse)
                    langfuse.set_current_trace_io(output=result)
            else:
                async with connect_servers(servers) as mcp:
                    result = await invoke(mcp, on_event)
            runs[run_id] = {"status": "completed", **base, "result": result}
            logger.info(
                "agent_complete agent=%s run_id=%s " + detail_fmt,
                agent, run_id, detail_value,
            )
        except Exception as e:
            runs[run_id] = {"status": "error", **base, "error": str(e)}
            logger.error(
                "agent_error agent=%s run_id=%s error=%s",
                agent, run_id, e, exc_info=True,
            )
        finally:
            if langfuse:
                # flush() blocks on network I/O; keep it off the event loop so
                # other requests and WebSocket sends are not stalled.
                try:
                    await asyncio.to_thread(langfuse.flush)
                except Exception:
                    logger.warning("langfuse_flush_failed run_id=%s", run_id, exc_info=True)

    logger.info("agent_start agent=%s run_id=%s " + detail_fmt, agent, run_id, detail_value)
    task = asyncio.create_task(_run())
    _agent_tasks.add(task)
    task.add_done_callback(_agent_tasks.discard)
    return {"run_id": run_id, "status": "running"}


@app.post("/agents/fce")
async def trigger_fce(req: FCERequest):
    """Start the FCE agent for ``req.flight_id`` in the background.

    Returns immediately with the run id. Events are streamed over
    ``/ws/agent-events``; the outcome is stored under ``runs[run_id]``.
    """
    return _start_agent_run(
        agent="fce",
        params={"flight_id": req.flight_id},
        servers=["shared", "ops", "passenger"],
        invoke=lambda mcp, on_event, **extra: run_fce(
            req.flight_id, mcp, on_event=on_event, **extra
        ),
        log_detail=("flight=%s", req.flight_id),
    )


@app.post("/agents/handover")
async def trigger_handover(req: HandoverRequest):
    """Start the handover agent for ``req.hubs`` in the background.

    Returns immediately with the run id. Events are streamed over
    ``/ws/agent-events``; the outcome is stored under ``runs[run_id]``.
    """
    return _start_agent_run(
        agent="handover",
        params={"hubs": req.hubs},
        servers=["shared", "ops"],
        invoke=lambda mcp, on_event, **extra: run_handover(
            hubs=req.hubs, mcp=mcp, on_event=on_event, **extra
        ),
        log_detail=("hubs=%s", req.hubs),
    )
|
|
|
|
|
|
@app.get("/agents/runs/{run_id}")
async def get_run(run_id: str):
    # Return the stored record for one run; unknown ids map to HTTP 404.
    try:
        return runs[run_id]
    except KeyError:
        raise HTTPException(404, detail=f"Run {run_id} not found") from None
|
|
|
|
|
|
@app.get("/agents/runs")
async def list_runs():
    # Summaries of every run in memory, newest first. Run ids are random uuid4
    # prefixes, so the previous sort by id produced an arbitrary order. Sort
    # by creation time instead, with the id as a deterministic tie-breaker.
    # Records lacking ``created_at`` sort last.
    ordered = sorted(
        runs.items(),
        key=lambda item: (item[1].get("created_at", 0.0), item[0]),
        reverse=True,
    )
    return [
        {"run_id": rid, "status": r["status"], "agent": r.get("agent")}
        for rid, r in ordered
    ]
|
|
|
|
|
|
# ── Scenario routes ──
|
|
|
|
@app.get("/scenarios")
async def list_scenarios():
    # Pass-through to the scenario manager's listing.
    available = scenario_manager.list_scenarios()
    return available
|
|
|
|
|
|
@app.get("/scenarios/active")
async def get_active_scenario():
    # Pass-through to the scenario manager's metadata for the active scenario.
    metadata = scenario_manager.get_metadata()
    return metadata
|
|
|
|
|
|
@app.put("/scenarios/active")
async def set_active_scenario(req: ScenarioUpdate):
    # Switch the active scenario. A ValueError from the manager is reported
    # as HTTP 400 carrying the error message.
    try:
        result = scenario_manager.set_active(req.scenario_id)
    except ValueError as e:
        # Chain the cause so the original error is kept in tracebacks.
        raise HTTPException(400, detail=str(e)) from e
    logger.info("scenario_switch scenario=%s", req.scenario_id)
    return result
|
|
|
|
|
|
# ── LLM config routes ──
|
|
|
|
@app.get("/config/llm")
async def get_llm_config():
    """Current LLM provider configuration."""
    from api.config import get_settings

    cfg = get_settings()

    # Each entry says whether the provider's credential is set and which model
    # it uses; "template" is always reported as configured.
    groq = {"configured": bool(cfg.groq_api_key), "model": cfg.groq_model}
    anthropic = {"configured": bool(cfg.anthropic_api_key), "model": cfg.anthropic_model}
    bedrock = {"configured": bool(cfg.aws_access_key_id), "model": cfg.bedrock_model_id}
    openai = {
        "configured": bool(cfg.openai_api_key),
        "model": cfg.openai_model,
        "base_url": cfg.openai_base_url,
    }
    template = {"configured": True, "model": "none (structured fallback)"}

    return {
        "provider": cfg.llm_provider,
        "providers": {
            "groq": groq,
            "anthropic": anthropic,
            "bedrock": bedrock,
            "openai": openai,
            "template": template,
        },
    }
|
|
|
|
|
|
class LLMConfigUpdate(BaseModel):
    # Body of PUT /config/llm.
    provider: str  # value written to the LLM_PROVIDER environment variable
    api_key: str | None = None  # provider API key; not used for bedrock/template
    model: str | None = None  # model name/id for the selected provider
    base_url: str | None = None  # only consulted for the openai provider
|
|
|
|
|
@app.put("/config/llm")
async def set_llm_config(req: LLMConfigUpdate):
    """Switch LLM provider at runtime. Sets env vars for MCP subprocesses.

    Writes ``LLM_PROVIDER`` to ``os.environ``. For the selected provider it
    also writes any supplied API key and model, plus the base URL for openai.
    It then returns the same payload as ``GET /config/llm``.

    Raises:
        HTTPException: 400 if ``req.provider`` is not one of the providers
            reported by ``GET /config/llm``.
    """
    import os

    # Same provider set that get_llm_config reports.
    known_providers = {"groq", "anthropic", "bedrock", "openai", "template"}
    if req.provider not in known_providers:
        raise HTTPException(400, detail=f"Unknown LLM provider: {req.provider}")

    os.environ["LLM_PROVIDER"] = req.provider

    # provider -> (API-key env var, model env var); bedrock takes no key here,
    # template takes neither.
    key_var, model_var = {
        "groq": ("GROQ_API_KEY", "GROQ_MODEL"),
        "anthropic": ("ANTHROPIC_API_KEY", "ANTHROPIC_MODEL"),
        "openai": ("OPENAI_API_KEY", "OPENAI_MODEL"),
        "bedrock": (None, "BEDROCK_MODEL_ID"),
    }.get(req.provider, (None, None))

    if key_var and req.api_key:
        os.environ[key_var] = req.api_key
    # Apply the model even without a new key, so callers can switch models on
    # an already-configured provider (it used to be silently ignored).
    if model_var and req.model:
        os.environ[model_var] = req.model
    # base_url is honoured only together with a new API key. Otherwise any
    # caller could point the existing OPENAI_API_KEY at an arbitrary endpoint.
    if req.provider == "openai" and req.api_key and req.base_url:
        os.environ["OPENAI_BASE_URL"] = req.base_url

    logger.info("llm_config_change provider=%s", req.provider)
    # NOTE(review): the response comes from get_settings(); if that is cached,
    # it will not reflect the env vars written above — confirm.
    return await get_llm_config()
|
|
|
|
|
|
# ── Scenario data routes ──
|
|
|
|
@app.get("/scenarios/data/flights")
async def get_scenario_flights():
    # JSON-mode dump of every flight held by the scenario manager, in order.
    serialized = []
    for flight in scenario_manager.flights:
        serialized.append(flight.model_dump(mode="json"))
    return serialized
|
|
|
|
|
|
@app.get("/scenarios/data/crew")
async def get_scenario_crew():
    # Each crew member's JSON dump plus two derived fields: remaining duty
    # hours (rounded to 2 dp) and at_risk when 2.0 hours or fewer remain.
    def with_duty_status(member):
        remaining = round(member.duty_hours_limit - member.duty_hours_elapsed, 2)
        return {
            **member.model_dump(mode="json"),
            "hours_until_limit": remaining,
            "at_risk": remaining <= 2.0,
        }

    return [with_duty_status(member) for member in scenario_manager.crew]
|
|
|
|
|
|
@app.get("/scenarios/data/crew-notes")
async def get_scenario_crew_notes():
    # The manager's crew-notes mapping itself (keyed by flight id), not a copy.
    notes_by_flight = scenario_manager.crew_notes
    return notes_by_flight
|
|
|
|
|
|
@app.get("/scenarios/data/maintenance")
async def get_scenario_maintenance():
    # Maintenance items grouped by tail number, each item JSON-dumped.
    return {
        tail: [entry.model_dump(mode="json") for entry in entries]
        for tail, entries in scenario_manager.maintenance.items()
    }
|
|
|
|
|
|
@app.get("/scenarios/data/rebookings")
async def get_scenario_rebookings():
    # JSON-mode dump of every rebooking held by the scenario manager.
    dumped = []
    for rebooking in scenario_manager.rebookings:
        dumped.append(rebooking.model_dump(mode="json"))
    return dumped
|
|
|
|
|
|
class FlightPatch(BaseModel):
    # Partial update for PATCH /scenarios/data/flights/{flight_id}; fields
    # left as None are not applied to the flight.
    delay_minutes: int | None = None
    status: str | None = None  # free-form; not checked against a set of statuses here
    gate: str | None = None
|
|
|
|
|
|
@app.patch("/scenarios/data/flights/{flight_id}")
async def patch_flight(flight_id: str, patch: FlightPatch):
    # Apply the patch's non-None fields to the matching flight in place and
    # return its JSON dump; unknown flight ids map to HTTP 404.
    target = next(
        (flight for flight in scenario_manager.flights if flight.flight_id == flight_id),
        None,
    )
    if target is None:
        raise HTTPException(404, detail=f"Flight {flight_id} not found")
    for field_name, value in patch.model_dump(exclude_none=True).items():
        setattr(target, field_name, value)
    return target.model_dump(mode="json")
|
|
|
|
|
|
class CrewPatch(BaseModel):
    # Partial update for PATCH /scenarios/data/crew/{crew_id}; None leaves the
    # crew member's elapsed duty hours unchanged.
    duty_hours_elapsed: float | None = None
|
|
|
|
|
|
@app.patch("/scenarios/data/crew/{crew_id}")
async def patch_crew(crew_id: str, patch: CrewPatch):
    # Update elapsed duty hours (when given) and return the member with the
    # same derived hours_until_limit / at_risk fields as get_scenario_crew.
    member = next((c for c in scenario_manager.crew if c.crew_id == crew_id), None)
    if member is None:
        raise HTTPException(404, detail=f"Crew {crew_id} not found")
    if patch.duty_hours_elapsed is not None:
        member.duty_hours_elapsed = patch.duty_hours_elapsed
    remaining = round(member.duty_hours_limit - member.duty_hours_elapsed, 2)
    return {
        **member.model_dump(mode="json"),
        "hours_until_limit": remaining,
        "at_risk": remaining <= 2.0,
    }
|
|
|
|
|
|
class CrewNotesPatch(BaseModel):
    # Body of PUT /scenarios/data/crew-notes/{flight_id}; the list replaces
    # that flight's notes wholesale.
    notes: list[str]
|
|
|
|
|
|
@app.put("/scenarios/data/crew-notes/{flight_id}")
async def put_crew_notes(flight_id: str, patch: CrewNotesPatch):
    # Replace (not append to) the notes stored for this flight id.
    new_notes = patch.notes
    scenario_manager.crew_notes[flight_id] = new_notes
    return {"flight_id": flight_id, "notes": new_notes}
|
|
|
|
|
|
# ── WebSocket ──
|
|
|
|
@app.websocket("/ws/agent-events")
async def agent_events_ws(ws: WebSocket):
    """Stream agent events to a WebSocket client until it goes away.

    Incoming text messages are read and discarded; they only keep the
    connection alive (client pings).
    """
    await event_hub.connect(ws)
    try:
        while True:
            await ws.receive_text()
    except WebSocketDisconnect:
        pass
    finally:
        # Always unregister, including on errors other than a clean
        # disconnect (e.g. a binary frame makes receive_text() raise).
        # Otherwise the hub would keep a dead socket.
        event_hub.disconnect(ws)
|