180 lines
4.3 KiB
Python
180 lines
4.3 KiB
Python
"""
|
|
Inference server — thin HTTP routes over model wrappers.
|
|
|
|
Config lives in config.py, model logic in models/.
|
|
This file is just the FastAPI glue.
|
|
|
|
Usage:
|
|
cd gpu && python server.py
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import io
|
|
import logging
|
|
import os
|
|
from contextlib import asynccontextmanager
|
|
|
|
import numpy as np
|
|
from fastapi import FastAPI, HTTPException
|
|
from PIL import Image
|
|
from pydantic import BaseModel
|
|
|
|
from config import get_config, get_device, update_config
|
|
from models import registry
|
|
from models.yolo import detect as yolo_detect
|
|
from models.ocr import ocr as ocr_run
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _decode_image(b64: str) -> np.ndarray:
    """Decode a base64-encoded image into an RGB pixel array.

    Accepts either a bare base64 payload or a data URI such as
    ``data:image/png;base64,...``; the URI header is stripped before decoding.

    Args:
        b64: Base64 text of an image file in any format PIL can open.

    Returns:
        The image converted to RGB, as a NumPy array.

    Raises:
        Whatever ``base64.b64decode`` / ``PIL.Image.open`` raise for malformed
        input; the route handlers translate any such error into HTTP 400.
    """
    # Many clients send data URIs; drop the "data:<mime>;base64," header.
    if b64.startswith("data:"):
        _, _, b64 = b64.partition(",")
    data = base64.b64decode(b64)
    # Close the lazily-opened source image deterministically. convert() loads
    # the pixels and returns a new in-memory image, so the array stays valid.
    with Image.open(io.BytesIO(data)) as src:
        return np.array(src.convert("RGB"))
|
|
|
|
|
|
# --- Request/Response models ---
|
|
|
|
class DetectRequest(BaseModel):
    """Request body for ``POST /detect``."""

    image: str  # base64-encoded image, decoded by _decode_image
    model: str | None = None  # forwarded to yolo_detect as model_name
    confidence: float | None = None  # forwarded to yolo_detect as confidence
    target_classes: list[str] | None = None  # forwarded to yolo_detect as target_classes
|
|
|
|
|
|
class BBox(BaseModel):
    """One detection in a ``DetectResponse``, built from a dict returned by yolo_detect."""

    # Box geometry. Presumably a top-left corner plus width/height in pixels;
    # NOTE(review): confirm the convention in models/yolo.
    x: int
    y: int
    w: int
    h: int
    confidence: float  # detector score for this box
    label: str  # detected class label (verify naming against models/yolo)
|
|
|
|
|
|
class DetectResponse(BaseModel):
    """Response body for ``POST /detect``."""

    detections: list[BBox]  # one entry per result returned by yolo_detect
|
|
|
|
|
|
class OCRRequest(BaseModel):
    """Request body for ``POST /ocr``."""

    image: str  # base64-encoded image, decoded by _decode_image
    languages: list[str] | None = None  # forwarded to the OCR runner as languages
|
|
|
|
|
|
class OCRTextResult(BaseModel):
    """One recognized text region, built from a dict returned by the OCR runner."""

    text: str  # recognized text
    confidence: float  # recognizer score for this region
    # Integer box coordinates; the layout (x/y/w/h vs. corner points) is not
    # visible here. NOTE(review): confirm against models/ocr.
    bbox: list[int]
|
|
|
|
|
|
class OCRResponse(BaseModel):
    """Response body for ``POST /ocr``."""

    results: list[OCRTextResult]  # one entry per result returned by the OCR runner
|
|
|
|
|
|
class ConfigUpdate(BaseModel):
    """Partial config update for ``PUT /config``.

    Every field is optional. Fields left as ``None`` are dropped before the
    update is applied, so only explicitly provided values change.
    """

    # Field names double as the keys handed to update_config.
    device: str | None = None
    yolo_model: str | None = None  # changing this unloads the old model
    yolo_confidence: float | None = None
    vram_budget_mb: int | None = None
    strategy: str | None = None
    ocr_languages: list[str] | None = None
    ocr_min_confidence: float | None = None
|
|
|
|
|
|
# --- App ---
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: log startup, then release models on shutdown.

    The cleanup runs in a ``finally`` so ``registry.clear()`` still executes
    if an exception or cancellation is thrown into the context while the app
    is running. Otherwise that path would skip model cleanup.
    """
    logger.info("Inference server starting (device=%s)", get_device())
    try:
        yield
    finally:
        logger.info("Shutting down")
        registry.clear()
|
|
|
|
|
|
# The single application instance; all routes below register on it, and
# `lifespan` runs around its startup/shutdown.
app = FastAPI(lifespan=lifespan, title="MPR Inference Server")
|
|
|
|
|
|
@app.get("/health")
def health():
    """Report liveness plus the device, loaded models, and key config values."""
    settings = get_config()
    return dict(
        status="ok",
        device=get_device(),
        loaded_models=registry.loaded(),
        vram_budget_mb=settings["vram_budget_mb"],
        strategy=settings["strategy"],
    )
|
|
|
|
|
|
@app.get("/config")
def read_config():
    """Return the current config merged with the resolved device under ``device_resolved``."""
    snapshot = dict(get_config())
    snapshot["device_resolved"] = get_device()
    return snapshot
|
|
|
|
|
|
@app.put("/config")
def write_config(update: ConfigUpdate):
    """Apply a partial config update and return the resulting config.

    Only fields sent with a non-null value are applied (``exclude_none=True``).
    If ``yolo_model`` changes, the previously configured model is unloaded
    from the registry before the new config is stored.

    Returns:
        The current config merged with ``device_resolved``. This is the same
        shape as ``GET /config``, whether or not anything changed.
    """
    changes = update.model_dump(exclude_none=True)
    if changes:
        old_model = get_config().get("yolo_model")
        # Free the previously configured detector when switching models;
        # skip when nothing was configured before (nothing to unload).
        if (
            "yolo_model" in changes
            and old_model is not None
            and changes["yolo_model"] != old_model
        ):
            registry.unload(old_model)

        update_config(changes)
        logger.info("Config updated: %s", changes)

    # Same response shape on both paths (previously the no-op path omitted
    # "device_resolved").
    return {**get_config(), "device_resolved": get_device()}
|
|
|
|
|
|
@app.post("/models/unload")
def unload_model(body: dict):
    """Unload a model from the registry by name.

    Expects a JSON object of the form ``{"model": "<name>"}``.

    Returns:
        ``{"status": "unloaded" | "not_loaded", "model": name}``.

    Raises:
        HTTPException: 400 if ``model`` is missing, empty, or not a string.
    """
    name = body.get("model")
    # A missing or malformed name is a client error. Previously it was passed
    # to the registry and silently reported as "not_loaded".
    if not isinstance(name, str) or not name:
        raise HTTPException(
            status_code=400, detail="Field 'model' must be a non-empty string"
        )
    unloaded = registry.unload(name)
    return {"status": "unloaded" if unloaded else "not_loaded", "model": name}
|
|
|
|
|
|
@app.post("/detect", response_model=DetectResponse)
def detect(req: DetectRequest):
    """Run object detection on a base64-encoded image.

    ``model``, ``confidence`` and ``target_classes`` from the request are
    forwarded to ``yolo_detect`` as-is, including ``None`` values.

    Raises:
        HTTPException: 400 if the image cannot be decoded, 500 if detection fails.
    """
    try:
        image = _decode_image(req.image)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Bad image: {e}") from e

    try:
        results = yolo_detect(
            image,
            model_name=req.model,
            confidence=req.confidence,
            target_classes=req.target_classes,
        )
    except Exception as e:
        # Keep the full traceback in the server log; the client only gets the message.
        logger.exception("Detection failed")
        raise HTTPException(status_code=500, detail=f"Detection failed: {e}") from e

    return DetectResponse(detections=[BBox(**r) for r in results])
|
|
|
|
|
|
@app.post("/ocr", response_model=OCRResponse)
def ocr(req: OCRRequest):
    """Run OCR on a base64-encoded image.

    ``languages`` from the request is forwarded to the OCR runner as-is,
    including ``None``.

    Raises:
        HTTPException: 400 if the image cannot be decoded, 500 if OCR fails.
    """
    try:
        image = _decode_image(req.image)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Bad image: {e}") from e

    try:
        results = ocr_run(image, languages=req.languages)
    except Exception as e:
        # Keep the full traceback in the server log; the client only gets the message.
        logger.exception("OCR failed")
        raise HTTPException(status_code=500, detail=f"OCR failed: {e}") from e

    return OCRResponse(results=[OCRTextResult(**r) for r in results])
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import uvicorn
|
|
|
|
logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s")
|
|
host = os.environ.get("HOST", "0.0.0.0")
|
|
port = int(os.environ.get("PORT", "8000"))
|
|
uvicorn.run(app, host=host, port=port)
|