"""
|
|
Inference server — thin HTTP wrapper around ML models.
|
|
|
|
Runs on the GPU machine. The detection pipeline calls this over HTTP,
|
|
or imports the same logic locally if GPU is on the same machine.
|
|
|
|
Config is loaded from env on startup, then editable at runtime via
|
|
GET/PUT /config. The UI config panel is just a visual editor for these
|
|
same values.
|
|
|
|
Usage:
|
|
cd gpu && uvicorn server:app --host 0.0.0.0 --port 8000
|
|
# or
|
|
cd gpu && python server.py
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import io
|
|
import logging
|
|
import os
|
|
from contextlib import asynccontextmanager
|
|
|
|
import numpy as np
|
|
from fastapi import FastAPI, HTTPException
|
|
from PIL import Image
|
|
from pydantic import BaseModel
|
|
|
|
logger = logging.getLogger(__name__)


# --- Runtime config (loaded from env, mutable via API) ---
# Single source of truth for tunables. Values are read at request time, so a
# PUT /config takes effect without restarting the server.
_config = {
    # "auto" is resolved to cuda/cpu lazily — see _get_device().
    "device": os.environ.get("DEVICE", "auto"),
    # Weights file or model name passed straight to ultralytics.YOLO().
    "yolo_model": os.environ.get("YOLO_MODEL", "yolov8n.pt"),
    # Minimum detection confidence forwarded as `conf` to the model call.
    "yolo_confidence": float(os.environ.get("YOLO_CONFIDENCE", "0.3")),
    # Advisory budget reported by /health; not enforced anywhere in this file.
    "vram_budget_mb": int(os.environ.get("VRAM_BUDGET_MB", "10240")),
    # Reported by /health; presumably consumed by the detection pipeline — TODO confirm.
    "strategy": os.environ.get("STRATEGY", "sequential"),
}

# --- Model registry ---
# Cache of loaded models keyed by model name; populated lazily by _get_yolo()
# and drained by /models/unload, /config changes, and shutdown.
_models: dict[str, object] = {}
|
|
|
|
|
|
# --- Helpers ---
|
|
|
|
def _get_device() -> str:
    """Resolve the configured device string, expanding "auto" to cuda/cpu."""
    configured = _config["device"]
    if configured != "auto":
        return configured
    # "auto": prefer CUDA when torch is importable and reports a usable GPU.
    try:
        import torch
    except ImportError:
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
def _get_yolo(model_name: str | None = None):
    """Return a cached YOLO model, loading it onto the resolved device on first use.

    Falls back to the configured ``yolo_model`` when *model_name* is None.
    """
    name = model_name or _config["yolo_model"]
    cached = _models.get(name)
    if cached is None:
        # Import lazily so the server can start without ultralytics installed.
        from ultralytics import YOLO

        device = _get_device()
        logger.info("Loading %s on %s", name, device)
        cached = YOLO(name)
        cached.to(device)
        _models[name] = cached
    return cached
|
|
|
|
|
|
def _decode_image(b64: str) -> np.ndarray:
    """Decode a base64-encoded image into an RGB numpy array (H, W, 3 uint8)."""
    raw = base64.b64decode(b64)
    with io.BytesIO(raw) as buf:
        rgb = Image.open(buf).convert("RGB")
        return np.array(rgb)
|
|
|
|
|
|
# --- Request/Response models ---
|
|
|
|
class DetectRequest(BaseModel):
    """Body of POST /detect."""

    image: str  # base64 JPEG
    model: str | None = None  # defaults to config yolo_model
    confidence: float | None = None  # defaults to config yolo_confidence
    target_classes: list[str] | None = None  # keep only these labels; None = keep all
|
|
|
|
|
|
class BBox(BaseModel):
    """One detection, as x/y/w/h pixel coordinates (converted from YOLO xyxy)."""

    x: int  # left edge
    y: int  # top edge
    w: int  # width  (x2 - x1)
    h: int  # height (y2 - y1)
    confidence: float  # model score — presumably in [0, 1]; verify against model output
    label: str  # class name from the model's names map
|
|
|
|
|
|
class DetectResponse(BaseModel):
    """Body of the POST /detect response."""

    detections: list[BBox]  # empty when nothing passes confidence/class filters
|
|
|
|
|
|
class ConfigUpdate(BaseModel):
    """Partial config update — only provided fields are changed.

    Field meanings mirror the ``_config`` keys; None means "leave unchanged".
    """

    device: str | None = None
    yolo_model: str | None = None
    yolo_confidence: float | None = None
    vram_budget_mb: int | None = None
    strategy: str | None = None
|
|
|
|
|
|
# --- App ---
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: log startup/shutdown, drop model refs on exit."""
    logger.info("Inference server starting (device=%s)", _get_device())
    yield
    logger.info("Inference server shutting down")
    # Release cached model references so GPU memory can be reclaimed.
    _models.clear()
|
|
|
|
|
|
# Module-level app instance served by uvicorn (see module docstring for usage).
app = FastAPI(title="MPR Inference Server", lifespan=lifespan)
|
|
|
|
|
|
@app.get("/health")
def health():
    """Liveness probe plus a small status summary (device, loaded models)."""
    status = {
        "status": "ok",
        "device": _get_device(),
        "loaded_models": list(_models),
        "vram_budget_mb": _config["vram_budget_mb"],
        "strategy": _config["strategy"],
    }
    return status
|
|
|
|
|
|
@app.get("/config")
def get_config():
    """Current runtime config. Same values the .env sets at startup."""
    snapshot = dict(_config)
    snapshot["device_resolved"] = _get_device()
    return snapshot
|
|
|
|
|
|
@app.put("/config")
def update_config(update: ConfigUpdate):
    """Update runtime config. Only provided fields are changed.

    Cached models are unloaded when the change would leave them stale:
    - ``yolo_model`` changed -> unload the previously configured model so the
      next /detect loads the new one.
    - ``device`` changed -> unload *all* models; they were ``.to()``'d onto the
      old device at load time and would otherwise keep running there.

    Returns the full config plus the resolved device, same shape as GET /config.
    """
    changes = update.model_dump(exclude_none=True)
    if not changes:
        # Keep the response shape identical to the non-empty path (and to
        # GET /config): always include device_resolved.
        return {**_config, "device_resolved": _get_device()}

    # If model changed, unload the old one so it gets reloaded on next request
    if "yolo_model" in changes and changes["yolo_model"] != _config["yolo_model"]:
        old = _config["yolo_model"]
        if old in _models:
            del _models[old]
            logger.info("Unloaded %s (model changed)", old)

    # If device changed, drop every cached model — without this, models loaded
    # before the change silently keep running on the old device.
    if "device" in changes and changes["device"] != _config["device"] and _models:
        _models.clear()
        logger.info("Unloaded all models (device changed)")

    _config.update(changes)
    logger.info("Config updated: %s", changes)
    return {**_config, "device_resolved": _get_device()}
|
|
|
|
|
|
@app.post("/models/unload")
def unload_model(body: dict):
    """Unload a model from memory to free VRAM."""
    name = body.get("model", "")
    if name not in _models:
        return {"status": "not_loaded", "model": name}
    del _models[name]
    logger.info("Unloaded %s", name)
    return {"status": "unloaded", "model": name}
|
|
|
|
|
|
@app.post("/detect", response_model=DetectResponse)
def detect(req: DetectRequest):
    """Run YOLO detection on a base64-encoded image.

    Model and confidence fall back to the runtime config when the request
    omits them; ``target_classes`` optionally filters results server-side.

    Raises:
        HTTPException 400: the image payload cannot be decoded.
        HTTPException 500: the model could not be loaded.
    """
    model_name = req.model or _config["yolo_model"]
    confidence = req.confidence if req.confidence is not None else _config["yolo_confidence"]

    try:
        model = _get_yolo(model_name)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to load model: {e}") from e

    # A malformed payload is a client error — report 400 instead of letting the
    # decode exception escape as an unhandled 500.
    try:
        image = _decode_image(req.image)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image payload: {e}") from e

    results = model(image, conf=confidence, verbose=False)

    detections = []
    for r in results:
        for box in r.boxes:
            x1, y1, x2, y2 = box.xyxy[0].tolist()
            label = r.names[int(box.cls[0])]

            # Server-side class filter: skip labels the caller didn't ask for.
            if req.target_classes and label not in req.target_classes:
                continue

            detections.append(
                BBox(
                    x=int(x1), y=int(y1),
                    w=int(x2 - x1), h=int(y2 - y1),
                    confidence=float(box.conf[0]),
                    label=label,
                )
            )

    return DetectResponse(detections=detections)
|
|
|
|
|
|
if __name__ == "__main__":
    # Direct-run entry point; production invocation is `uvicorn server:app ...`.
    import uvicorn

    logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s")
    bind_host = os.environ.get("HOST", "0.0.0.0")
    bind_port = int(os.environ.get("PORT", "8000"))
    uvicorn.run(app, host=bind_host, port=bind_port)
|