phase 6
This commit is contained in:
206
gpu/server.py
Normal file
206
gpu/server.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""
|
||||
Inference server — thin HTTP wrapper around ML models.
|
||||
|
||||
Runs on the GPU machine. The detection pipeline calls this over HTTP,
|
||||
or imports the same logic locally if GPU is on the same machine.
|
||||
|
||||
Config is loaded from env on startup, then editable at runtime via
|
||||
GET/PUT /config. The UI config panel is just a visual editor for these
|
||||
same values.
|
||||
|
||||
Usage:
|
||||
cd gpu && uvicorn server:app --host 0.0.0.0 --port 8000
|
||||
# or
|
||||
cd gpu && python server.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import numpy as np
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)

# --- Runtime config (loaded from env, mutable via API) ---
# Seeded once from environment variables at import time; mutated in place
# by PUT /config, so all readers see updates immediately.
_config = {
    # "auto" is resolved to "cuda"/"cpu" lazily by _get_device().
    "device": os.environ.get("DEVICE", "auto"),
    # Weights name/path handed to YOLO(); default model for /detect.
    "yolo_model": os.environ.get("YOLO_MODEL", "yolov8n.pt"),
    # Default confidence threshold when a /detect request omits one.
    "yolo_confidence": float(os.environ.get("YOLO_CONFIDENCE", "0.3")),
    # Not used by inference in this module; surfaced via /health and /config.
    "vram_budget_mb": int(os.environ.get("VRAM_BUDGET_MB", "10240")),
    # Not used by inference in this module; surfaced via /health and /config.
    "strategy": os.environ.get("STRATEGY", "sequential"),
}
|
||||
|
||||
# --- Model registry ---
# Cache of loaded model objects keyed by model name; populated lazily by
# _get_yolo() and emptied on shutdown / explicit unload.
_models: dict[str, object] = {}
|
||||
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
def _get_device() -> str:
    """Resolve the device to run models on.

    An explicitly configured device is returned verbatim; "auto" picks
    "cuda" when torch reports CUDA available, otherwise "cpu". If torch
    is not installed at all, "cpu" is the fallback.
    """
    configured = _config["device"]
    if configured == "auto":
        try:
            import torch
        except ImportError:
            return "cpu"
        return "cuda" if torch.cuda.is_available() else "cpu"
    return configured
|
||||
|
||||
|
||||
def _get_yolo(model_name: str | None = None):
    """Return a cached YOLO model, loading it lazily on first use.

    Falls back to the configured default model when *model_name* is
    None/empty. Loaded models are kept in the module-level registry.
    """
    name = model_name or _config["yolo_model"]
    try:
        return _models[name]
    except KeyError:
        pass

    from ultralytics import YOLO

    device = _get_device()
    logger.info("Loading %s on %s", name, device)
    loaded = YOLO(name)
    loaded.to(device)
    _models[name] = loaded
    return loaded
|
||||
|
||||
|
||||
def _decode_image(b64: str) -> np.ndarray:
    """Decode a base64-encoded image into an RGB numpy array (H, W, 3)."""
    buf = io.BytesIO(base64.b64decode(b64))
    return np.array(Image.open(buf).convert("RGB"))
|
||||
|
||||
|
||||
# --- Request/Response models ---
|
||||
|
||||
class DetectRequest(BaseModel):
    """Body of POST /detect.

    Optional fields that are left unset fall back to the server's
    runtime config values.
    """

    image: str  # base64-encoded JPEG bytes
    model: str | None = None  # defaults to config yolo_model
    confidence: float | None = None  # defaults to config yolo_confidence
    target_classes: list[str] | None = None  # keep only these labels; None = keep all
|
||||
|
||||
|
||||
class BBox(BaseModel):
    """One detection as an axis-aligned box in integer pixel units."""

    x: int  # box origin x (from the model's xyxy x1)
    y: int  # box origin y (from the model's xyxy y1)
    w: int  # box width  (x2 - x1)
    h: int  # box height (y2 - y1)
    confidence: float  # detector confidence score
    label: str  # class name reported by the model
|
||||
|
||||
|
||||
class DetectResponse(BaseModel):
    """Body of the POST /detect response."""

    detections: list[BBox]  # one entry per detection that passed filtering
|
||||
|
||||
|
||||
class ConfigUpdate(BaseModel):
    """Partial config update — only provided fields are changed."""

    device: str | None = None  # "auto" or an explicit device string
    yolo_model: str | None = None  # weights name/path for YOLO()
    yolo_confidence: float | None = None  # default detection threshold
    vram_budget_mb: int | None = None
    strategy: str | None = None
|
||||
|
||||
|
||||
# --- App ---
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan: log startup/shutdown, drop model refs on exit."""
    logger.info("Inference server starting (device=%s)", _get_device())
    yield
    logger.info("Inference server shutting down")
    # Drop references so model memory can be reclaimed after shutdown.
    _models.clear()
|
||||
|
||||
|
||||
# FastAPI application; `lifespan` handles startup logging and model cleanup.
app = FastAPI(title="MPR Inference Server", lifespan=lifespan)
|
||||
|
||||
|
||||
@app.get("/health")
def health():
    """Liveness check plus a snapshot of device and model state."""
    return dict(
        status="ok",
        device=_get_device(),
        loaded_models=list(_models),
        vram_budget_mb=_config["vram_budget_mb"],
        strategy=_config["strategy"],
    )
|
||||
|
||||
|
||||
@app.get("/config")
def get_config():
    """Current runtime config. Same values the .env sets at startup."""
    snapshot = dict(_config)
    snapshot["device_resolved"] = _get_device()
    return snapshot
|
||||
|
||||
|
||||
@app.put("/config")
def update_config(update: ConfigUpdate):
    """Update runtime config. Only provided fields are changed.

    Always returns the full post-update config plus the resolved device,
    so the response shape matches GET /config — previously an empty
    update returned the bare config dict without "device_resolved".
    """
    changes = update.model_dump(exclude_none=True)

    # If model changed, unload the old one so it gets reloaded on next request
    if "yolo_model" in changes and changes["yolo_model"] != _config["yolo_model"]:
        old = _config["yolo_model"]
        if old in _models:
            del _models[old]
            logger.info("Unloaded %s (model changed)", old)

    # NOTE(review): changing "device" does not move already-loaded models;
    # they stay where they were loaded until unloaded — confirm intended.
    if changes:
        _config.update(changes)
        logger.info("Config updated: %s", changes)
    return {**_config, "device_resolved": _get_device()}
|
||||
|
||||
|
||||
@app.post("/models/unload")
def unload_model(body: dict):
    """Unload a model from memory to free VRAM."""
    name = body.get("model", "")
    if name not in _models:
        return {"status": "not_loaded", "model": name}
    del _models[name]
    logger.info("Unloaded %s", name)
    return {"status": "unloaded", "model": name}
|
||||
|
||||
|
||||
@app.post("/detect", response_model=DetectResponse)
def detect(req: DetectRequest):
    """Run object detection on a base64-encoded image.

    Model and confidence fall back to the runtime config when the
    request leaves them unset. Raises HTTP 400 for an undecodable
    image payload and HTTP 500 when the model cannot be loaded.
    """
    model_name = req.model or _config["yolo_model"]
    confidence = req.confidence if req.confidence is not None else _config["yolo_confidence"]

    try:
        model = _get_yolo(model_name)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to load model: {e}") from e

    # A bad payload is the client's fault: surface 400, not an opaque 500.
    try:
        image = _decode_image(req.image)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image payload: {e}") from e

    results = model(image, conf=confidence, verbose=False)

    # Hoist the filter to a set; an empty/None target list means "keep all"
    # (matching the original truthiness check).
    wanted = set(req.target_classes) if req.target_classes else None

    detections = []
    for r in results:
        for box in r.boxes:
            label = r.names[int(box.cls[0])]
            if wanted is not None and label not in wanted:
                continue

            x1, y1, x2, y2 = box.xyxy[0].tolist()
            detections.append(
                BBox(
                    x=int(x1), y=int(y1),
                    w=int(x2 - x1), h=int(y2 - y1),
                    confidence=float(box.conf[0]),
                    label=label,
                )
            )

    return DetectResponse(detections=detections)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Direct invocation: configure logging, then serve with uvicorn.
    import uvicorn

    logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s")
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
    )
|
||||
Reference in New Issue
Block a user