Files
mediaproc/gpu/server.py
2026-03-26 00:56:35 -03:00

180 lines
4.3 KiB
Python

"""
Inference server — thin HTTP routes over model wrappers.
Config lives in config.py, model logic in models/.
This file is just the FastAPI glue.
Usage:
cd gpu && python server.py
"""
from __future__ import annotations
import base64
import io
import logging
import os
from contextlib import asynccontextmanager
import numpy as np
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
from config import get_config, get_device, update_config
from models import registry
from models.yolo import detect as yolo_detect
from models.ocr import ocr as ocr_run
logger = logging.getLogger(__name__)
def _decode_image(b64: str) -> np.ndarray:
    """Decode a base64-encoded image string into an RGB numpy array (H, W, 3)."""
    raw = base64.b64decode(b64)
    with io.BytesIO(raw) as buf:
        rgb = Image.open(buf).convert("RGB")
        return np.array(rgb)
# --- Request/Response models ---
class DetectRequest(BaseModel):
    """Request body for POST /detect."""
    # Base64-encoded image bytes; decoded by _decode_image.
    image: str
    # Optional overrides; None presumably falls back to configured defaults
    # inside models/yolo — TODO confirm against yolo_detect.
    model: str | None = None
    confidence: float | None = None
    target_classes: list[str] | None = None
class BBox(BaseModel):
    """One detection: pixel-space box, score, and class label."""
    # Box origin and size in pixels (x/y is presumably the top-left
    # corner — convention set by models/yolo; verify there).
    x: int
    y: int
    w: int
    h: int
    confidence: float
    label: str
class DetectResponse(BaseModel):
    """Response body for POST /detect: all detections for the image."""
    detections: list[BBox]
class OCRRequest(BaseModel):
    """Request body for POST /ocr."""
    # Base64-encoded image bytes; decoded by _decode_image.
    image: str
    # Optional language list; None presumably uses ocr_languages from
    # config — TODO confirm in models/ocr.
    languages: list[str] | None = None
class OCRTextResult(BaseModel):
    """One recognized text span with its score and bounding box."""
    text: str
    confidence: float
    # Flat list of ints; exact layout (x/y/w/h vs. polygon points) is set
    # by models/ocr — verify there before relying on it.
    bbox: list[int]
class OCRResponse(BaseModel):
    """Response body for POST /ocr: all recognized text spans."""
    results: list[OCRTextResult]
class ConfigUpdate(BaseModel):
    """Partial update for PUT /config; None fields are left unchanged
    (write_config drops them via model_dump(exclude_none=True))."""
    device: str | None = None
    yolo_model: str | None = None
    yolo_confidence: float | None = None
    vram_budget_mb: int | None = None
    strategy: str | None = None
    ocr_languages: list[str] | None = None
    ocr_min_confidence: float | None = None
# --- App ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: log startup, and always release models on exit.

    The cleanup is wrapped in try/finally so registry.clear() runs even when
    an exception propagates through the served lifetime (the bare `yield`
    version would skip unloading and leak VRAM in that case).
    """
    logger.info("Inference server starting (device=%s)", get_device())
    try:
        yield
    finally:
        logger.info("Shutting down")
        registry.clear()
# Single module-level app instance; lifespan (above) handles model teardown.
app = FastAPI(title="MPR Inference Server", lifespan=lifespan)
@app.get("/health")
def health():
    """Liveness/status probe: device, loaded models, and VRAM settings."""
    cfg = get_config()
    payload = {"status": "ok", "device": get_device()}
    payload["loaded_models"] = registry.loaded()
    payload["vram_budget_mb"] = cfg["vram_budget_mb"]
    payload["strategy"] = cfg["strategy"]
    return payload
@app.get("/config")
def read_config():
    """Return the current config plus the concrete device in use."""
    merged = dict(get_config())
    merged["device_resolved"] = get_device()
    return merged
@app.put("/config")
def write_config(update: ConfigUpdate):
    """Apply a partial config update; evicts the YOLO model when it changes.

    Fields left as None in the request body are ignored; an empty update
    returns the current config untouched.
    """
    changes = update.model_dump(exclude_none=True)
    if not changes:
        return get_config()
    # A different YOLO model means the cached weights are stale — drop them
    # before the new name is written to config.
    current_model = get_config().get("yolo_model")
    if "yolo_model" in changes and changes["yolo_model"] != current_model:
        registry.unload(current_model)
    update_config(changes)
    logger.info("Config updated: %s", changes)
    return {**get_config(), "device_resolved": get_device()}
@app.post("/models/unload")
def unload_model(body: dict):
    """Explicitly evict a model from the registry by name."""
    name = body.get("model", "")
    if registry.unload(name):
        status = "unloaded"
    else:
        status = "not_loaded"
    return {"status": status, "model": name}
@app.post("/detect", response_model=DetectResponse)
def detect(req: DetectRequest):
    """Run object detection on a base64-encoded image.

    Raises 400 for an undecodable image and 500 when inference fails.
    """
    try:
        image = _decode_image(req.image)
    except Exception as e:
        # `from e` preserves the causal chain so the real decode error
        # shows up in server tracebacks (ruff B904).
        raise HTTPException(status_code=400, detail=f"Bad image: {e}") from e
    try:
        results = yolo_detect(
            image,
            model_name=req.model,
            confidence=req.confidence,
            target_classes=req.target_classes,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Detection failed: {e}") from e
    return DetectResponse(detections=[BBox(**r) for r in results])
@app.post("/ocr", response_model=OCRResponse)
def ocr(req: OCRRequest):
    """Run OCR on a base64-encoded image.

    Raises 400 for an undecodable image and 500 when recognition fails.
    """
    try:
        image = _decode_image(req.image)
    except Exception as e:
        # `from e` preserves the causal chain so the real decode error
        # shows up in server tracebacks (ruff B904).
        raise HTTPException(status_code=400, detail=f"Bad image: {e}") from e
    try:
        results = ocr_run(image, languages=req.languages)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"OCR failed: {e}") from e
    return OCRResponse(results=[OCRTextResult(**r) for r in results])
if __name__ == "__main__":
    import uvicorn

    # The format string needs a separator between %(name)s and %(message)s,
    # otherwise the logger name runs straight into the message text.
    logging.basicConfig(
        level=logging.INFO,
        format="%(levelname)-7s %(name)s: %(message)s",
    )
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", "8000"))
    uvicorn.run(app, host=host, port=port)