This commit is contained in:
2026-03-23 19:10:55 -03:00
parent 3df9ed5ada
commit 95246c5452
23 changed files with 1361 additions and 107 deletions

View File

@@ -1,16 +1,10 @@
"""
Inference server — thin HTTP wrapper around ML models.
Inference server — thin HTTP routes over model wrappers.
Runs on the GPU machine. The detection pipeline calls this over HTTP,
or imports the same logic locally if GPU is on the same machine.
Config is loaded from env on startup, then editable at runtime via
GET/PUT /config. The UI config panel is just a visual editor for these
same values.
Config lives in config.py, model logic in models/.
This file is just the FastAPI glue.
Usage:
cd gpu && uvicorn server:app --host 0.0.0.0 --port 8000
# or
cd gpu && python server.py
"""
@@ -27,45 +21,13 @@ from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
from config import get_config, get_device, update_config
from models import registry
from models.yolo import detect as yolo_detect
from models.ocr import ocr as ocr_run
logger = logging.getLogger(__name__)
# --- Runtime config (loaded from env, mutable via API) ---
_config = {
"device": os.environ.get("DEVICE", "auto"),
"yolo_model": os.environ.get("YOLO_MODEL", "yolov8n.pt"),
"yolo_confidence": float(os.environ.get("YOLO_CONFIDENCE", "0.3")),
"vram_budget_mb": int(os.environ.get("VRAM_BUDGET_MB", "10240")),
"strategy": os.environ.get("STRATEGY", "sequential"),
}
# --- Model registry ---
_models: dict[str, object] = {}
# --- Helpers ---
def _get_device() -> str:
device = _config["device"]
if device != "auto":
return device
try:
import torch
return "cuda" if torch.cuda.is_available() else "cpu"
except ImportError:
return "cpu"
def _get_yolo(model_name: str | None = None):
name = model_name or _config["yolo_model"]
if name not in _models:
from ultralytics import YOLO
device = _get_device()
logger.info("Loading %s on %s", name, device)
model = YOLO(name)
model.to(device)
_models[name] = model
return _models[name]
def _decode_image(b64: str) -> np.ndarray:
data = base64.b64decode(b64)
@@ -76,9 +38,9 @@ def _decode_image(b64: str) -> np.ndarray:
# --- Request/Response models ---
class DetectRequest(BaseModel):
image: str # base64 JPEG
model: str | None = None # defaults to config yolo_model
confidence: float | None = None # defaults to config yolo_confidence
image: str
model: str | None = None
confidence: float | None = None
target_classes: list[str] | None = None
@@ -95,23 +57,39 @@ class DetectResponse(BaseModel):
detections: list[BBox]
class OCRRequest(BaseModel):
image: str
languages: list[str] | None = None
class OCRTextResult(BaseModel):
text: str
confidence: float
bbox: list[int]
class OCRResponse(BaseModel):
results: list[OCRTextResult]
class ConfigUpdate(BaseModel):
"""Partial config update — only provided fields are changed."""
device: str | None = None
yolo_model: str | None = None
yolo_confidence: float | None = None
vram_budget_mb: int | None = None
strategy: str | None = None
ocr_languages: list[str] | None = None
ocr_min_confidence: float | None = None
# --- App ---
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Inference server starting (device=%s)", _get_device())
logger.info("Inference server starting (device=%s)", get_device())
yield
logger.info("Inference server shutting down")
_models.clear()
logger.info("Shutting down")
registry.clear()
app = FastAPI(title="MPR Inference Server", lifespan=lifespan)
@@ -119,82 +97,77 @@ app = FastAPI(title="MPR Inference Server", lifespan=lifespan)
@app.get("/health")
def health():
cfg = get_config()
return {
"status": "ok",
"device": _get_device(),
"loaded_models": list(_models.keys()),
"vram_budget_mb": _config["vram_budget_mb"],
"strategy": _config["strategy"],
"device": get_device(),
"loaded_models": registry.loaded(),
"vram_budget_mb": cfg["vram_budget_mb"],
"strategy": cfg["strategy"],
}
@app.get("/config")
def get_config():
"""Current runtime config. Same values the .env sets at startup."""
return {**_config, "device_resolved": _get_device()}
def read_config():
return {**get_config(), "device_resolved": get_device()}
@app.put("/config")
def update_config(update: ConfigUpdate):
"""Update runtime config. Only provided fields are changed."""
def write_config(update: ConfigUpdate):
changes = update.model_dump(exclude_none=True)
if not changes:
return _config
return get_config()
# If model changed, unload the old one so it gets reloaded on next request
if "yolo_model" in changes and changes["yolo_model"] != _config["yolo_model"]:
old = _config["yolo_model"]
if old in _models:
del _models[old]
logger.info("Unloaded %s (model changed)", old)
# Unload model if it changed
old_model = get_config().get("yolo_model")
if "yolo_model" in changes and changes["yolo_model"] != old_model:
registry.unload(old_model)
_config.update(changes)
update_config(changes)
logger.info("Config updated: %s", changes)
return {**_config, "device_resolved": _get_device()}
return {**get_config(), "device_resolved": get_device()}
@app.post("/models/unload")
def unload_model(body: dict):
"""Unload a model from memory to free VRAM."""
name = body.get("model", "")
if name in _models:
del _models[name]
logger.info("Unloaded %s", name)
return {"status": "unloaded", "model": name}
return {"status": "not_loaded", "model": name}
unloaded = registry.unload(name)
return {"status": "unloaded" if unloaded else "not_loaded", "model": name}
@app.post("/detect", response_model=DetectResponse)
def detect(req: DetectRequest):
model_name = req.model or _config["yolo_model"]
confidence = req.confidence if req.confidence is not None else _config["yolo_confidence"]
try:
image = _decode_image(req.image)
except Exception as e:
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
try:
model = _get_yolo(model_name)
results = yolo_detect(
image,
model_name=req.model,
confidence=req.confidence,
target_classes=req.target_classes,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to load model: {e}")
raise HTTPException(status_code=500, detail=f"Detection failed: {e}")
image = _decode_image(req.image)
results = model(image, conf=confidence, verbose=False)
return DetectResponse(detections=[BBox(**r) for r in results])
detections = []
for r in results:
for box in r.boxes:
x1, y1, x2, y2 = box.xyxy[0].tolist()
label = r.names[int(box.cls[0])]
if req.target_classes and label not in req.target_classes:
continue
@app.post("/ocr", response_model=OCRResponse)
def ocr(req: OCRRequest):
try:
image = _decode_image(req.image)
except Exception as e:
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
det = BBox(
x=int(x1), y=int(y1),
w=int(x2 - x1), h=int(y2 - y1),
confidence=float(box.conf[0]),
label=label,
)
detections.append(det)
try:
results = ocr_run(image, languages=req.languages)
except Exception as e:
raise HTTPException(status_code=500, detail=f"OCR failed: {e}")
return DetectResponse(detections=detections)
return OCRResponse(results=[OCRTextResult(**r) for r in results])
if __name__ == "__main__":