""" Inference server — thin HTTP routes over model wrappers. Config lives in config.py, model logic in models/. This file is just the FastAPI glue. Usage: cd gpu && python server.py """ from __future__ import annotations import base64 import io import logging import os from contextlib import asynccontextmanager import numpy as np from fastapi import FastAPI, HTTPException from PIL import Image from pydantic import BaseModel from config import get_config, get_device, update_config from models import registry from models.yolo import detect as yolo_detect from models.ocr import ocr as ocr_run from models.vlm import query as vlm_query logger = logging.getLogger(__name__) def _decode_image(b64: str) -> np.ndarray: data = base64.b64decode(b64) img = Image.open(io.BytesIO(data)).convert("RGB") return np.array(img) # --- Request/Response models --- class DetectRequest(BaseModel): image: str model: str | None = None confidence: float | None = None target_classes: list[str] | None = None class BBox(BaseModel): x: int y: int w: int h: int confidence: float label: str class DetectResponse(BaseModel): detections: list[BBox] class OCRRequest(BaseModel): image: str languages: list[str] | None = None class OCRTextResult(BaseModel): text: str confidence: float bbox: list[int] class OCRResponse(BaseModel): results: list[OCRTextResult] class VLMRequest(BaseModel): image: str prompt: str model: str | None = None class VLMResponse(BaseModel): brand: str confidence: float reasoning: str class ConfigUpdate(BaseModel): device: str | None = None yolo_model: str | None = None yolo_confidence: float | None = None vram_budget_mb: int | None = None strategy: str | None = None ocr_languages: list[str] | None = None ocr_min_confidence: float | None = None # --- App --- @asynccontextmanager async def lifespan(app: FastAPI): logger.info("Inference server starting (device=%s)", get_device()) yield logger.info("Shutting down") registry.clear() app = FastAPI(title="MPR Inference Server", lifespan=lifespan) @app.get("/health") def health(): cfg = get_config() return { "status": "ok", "device": get_device(), "loaded_models": registry.loaded(), "vram_budget_mb": cfg["vram_budget_mb"], "strategy": cfg["strategy"], } @app.get("/config") def read_config(): return {**get_config(), "device_resolved": get_device()} @app.put("/config") def write_config(update: ConfigUpdate): changes = update.model_dump(exclude_none=True) if not changes: return get_config() # Unload model if it changed old_model = get_config().get("yolo_model") if "yolo_model" in changes and changes["yolo_model"] != old_model: registry.unload(old_model) update_config(changes) logger.info("Config updated: %s", changes) return {**get_config(), "device_resolved": get_device()} @app.post("/models/unload") def unload_model(body: dict): name = body.get("model", "") unloaded = registry.unload(name) return {"status": "unloaded" if unloaded else "not_loaded", "model": name} @app.post("/detect", response_model=DetectResponse) def detect(req: DetectRequest): try: image = _decode_image(req.image) except Exception as e: raise HTTPException(status_code=400, detail=f"Bad image: {e}") try: results = yolo_detect( image, model_name=req.model, confidence=req.confidence, target_classes=req.target_classes, ) except Exception as e: raise HTTPException(status_code=500, detail=f"Detection failed: {e}") return DetectResponse(detections=[BBox(**r) for r in results]) @app.post("/ocr", response_model=OCRResponse) def ocr(req: OCRRequest): try: image = _decode_image(req.image) except Exception as e: raise HTTPException(status_code=400, detail=f"Bad image: {e}") try: results = ocr_run(image, languages=req.languages) except Exception as e: raise HTTPException(status_code=500, detail=f"OCR failed: {e}") return OCRResponse(results=[OCRTextResult(**r) for r in results]) @app.post("/vlm", response_model=VLMResponse) def vlm(req: VLMRequest): try: image = _decode_image(req.image) except Exception as e: raise HTTPException(status_code=400, detail=f"Bad image: {e}") try: result = vlm_query(image, req.prompt) except Exception as e: raise HTTPException(status_code=500, detail=f"VLM failed: {e}") return VLMResponse(**result) if __name__ == "__main__": import uvicorn logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s") host = os.environ.get("HOST", "0.0.0.0") port = int(os.environ.get("PORT", "8000")) uvicorn.run(app, host=host, port=port)