""" Inference server — thin HTTP routes over model wrappers. Config lives in config.py, model logic in models/. This file is just the FastAPI glue. Usage: cd gpu && python server.py """ from __future__ import annotations import base64 import io import logging import os import time from contextlib import asynccontextmanager import numpy as np from fastapi import FastAPI, HTTPException, Request from PIL import Image from pydantic import BaseModel from emit import log as emit_log from config import get_config, get_device, update_config from models import registry from models.yolo import detect as yolo_detect from models.ocr import ocr as ocr_run from models.vlm import query as vlm_query logger = logging.getLogger(__name__) def _decode_image(b64: str) -> np.ndarray: data = base64.b64decode(b64) img = Image.open(io.BytesIO(data)).convert("RGB") return np.array(img) def _job_ctx(request: Request) -> tuple[str, str]: """Extract job_id and log_level from request headers.""" job_id = request.headers.get("x-job-id", "") log_level = request.headers.get("x-log-level", "INFO") return job_id, log_level def _gpu_log(job_id: str, log_level: str, stage: str, level: str, msg: str): """Emit a log event if job context is present.""" if job_id: emit_log(job_id, stage, level, msg, log_level=log_level) # --- Request/Response models (generated from core/schema/models/inference.py) --- from models.inference_contract import ( AnalyzeRegionsDebugResponse, AnalyzeRegionsRequest, AnalyzeRegionsResponse, BBox, ConfigUpdate, DetectRequest, DetectResponse, OCRRequest, OCRResponse, OCRTextResult, PreprocessRequest, PreprocessResponse, RegionBox, VLMRequest, VLMResponse, ) # --- App --- @asynccontextmanager async def lifespan(app: FastAPI): logger.info("Inference server starting (device=%s)", get_device()) yield logger.info("Shutting down") registry.clear() app = FastAPI(title="MPR Inference Server", lifespan=lifespan) @app.get("/health") def health(): cfg = get_config() return { "status": "ok", "device": get_device(), "loaded_models": registry.loaded(), "vram_budget_mb": cfg["vram_budget_mb"], "strategy": cfg["strategy"], } @app.get("/config") def read_config(): return {**get_config(), "device_resolved": get_device()} @app.put("/config") def write_config(update: ConfigUpdate): changes = update.model_dump(exclude_none=True) if not changes: return get_config() # Unload model if it changed old_model = get_config().get("yolo_model") if "yolo_model" in changes and changes["yolo_model"] != old_model: registry.unload(old_model) update_config(changes) logger.info("Config updated: %s", changes) return {**get_config(), "device_resolved": get_device()} @app.post("/models/unload") def unload_model(body: dict): name = body.get("model", "") unloaded = registry.unload(name) return {"status": "unloaded" if unloaded else "not_loaded", "model": name} @app.post("/detect", response_model=DetectResponse) def detect(req: DetectRequest, request: Request): job_id, log_level = _job_ctx(request) try: t0 = time.monotonic() image = _decode_image(req.image) decode_ms = (time.monotonic() - t0) * 1000 h, w = image.shape[:2] _gpu_log(job_id, log_level, "GPU:YOLO", "DEBUG", f"Decoded {w}x{h} image in {decode_ms:.0f}ms") except Exception as e: raise HTTPException(status_code=400, detail=f"Bad image: {e}") try: t0 = time.monotonic() results = yolo_detect( image, model_name=req.model, confidence=req.confidence, target_classes=req.target_classes, ) infer_ms = (time.monotonic() - t0) * 1000 _gpu_log(job_id, log_level, "GPU:YOLO", "DEBUG", f"Inference: 

@app.post("/ocr", response_model=OCRResponse)
def ocr(req: OCRRequest, request: Request):
    job_id, log_level = _job_ctx(request)
    try:
        image = _decode_image(req.image)
        h, w = image.shape[:2]
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Bad image: {e}")
    try:
        t0 = time.monotonic()
        results = ocr_run(image, languages=req.languages)
        infer_ms = (time.monotonic() - t0) * 1000
        texts = [r["text"][:20] for r in results]
        _gpu_log(job_id, log_level, "GPU:OCR", "DEBUG",
                 f"OCR {w}x{h}: {infer_ms:.0f}ms → {len(results)} results {texts}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"OCR failed: {e}")
    return OCRResponse(results=[OCRTextResult(**r) for r in results])


@app.post("/preprocess", response_model=PreprocessResponse)
def preprocess_image(req: PreprocessRequest):
    try:
        image = _decode_image(req.image)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Bad image: {e}")
    try:
        from models.preprocess import preprocess

        processed = preprocess(
            image,
            do_binarize=req.binarize,
            do_deskew=req.deskew,
            do_contrast=req.contrast,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Preprocessing failed: {e}")
    # Re-encode the processed frame as JPEG for the JSON response
    # (Image and io are already imported at module level).
    img = Image.fromarray(processed)
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=90)
    result_b64 = base64.b64encode(buf.getvalue()).decode()
    return PreprocessResponse(image=result_b64)
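
# Decoding the /preprocess response on the client side (a sketch; assumes a
# local server, and the toggle names simply mirror PreprocessRequest above):
#
#   import base64, requests
#   out = requests.post(
#       "http://localhost:8000/preprocess",
#       json={"image": b64, "binarize": True, "deskew": True, "contrast": False},
#   ).json()
#   processed_jpeg = base64.b64decode(out["image"])  # raw JPEG bytes
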

@app.post("/vlm", response_model=VLMResponse)
def vlm(req: VLMRequest, request: Request):
    job_id, log_level = _job_ctx(request)
    try:
        image = _decode_image(req.image)
        h, w = image.shape[:2]
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Bad image: {e}")
    try:
        t0 = time.monotonic()
        result = vlm_query(image, req.prompt)
        infer_ms = (time.monotonic() - t0) * 1000
        _gpu_log(job_id, log_level, "GPU:VLM", "DEBUG",
                 f"VLM {w}x{h}: {infer_ms:.0f}ms → "
                 f"brand='{result.get('brand', '')}' conf={result.get('confidence', 0):.2f}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"VLM failed: {e}")
    return VLMResponse(**result)


@app.post("/detect_edges", response_model=AnalyzeRegionsResponse)
def detect_edges_endpoint(req: AnalyzeRegionsRequest, request: Request):
    job_id, log_level = _job_ctx(request)
    try:
        image = _decode_image(req.image)
        h, w = image.shape[:2]
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Bad image: {e}")
    try:
        t0 = time.monotonic()
        from models.cv.edges import detect_edges

        edge_regions = detect_edges(
            image,
            canny_low=req.edge_canny_low,
            canny_high=req.edge_canny_high,
            hough_threshold=req.edge_hough_threshold,
            hough_min_length=req.edge_hough_min_length,
            hough_max_gap=req.edge_hough_max_gap,
            pair_max_distance=req.edge_pair_max_distance,
            pair_min_distance=req.edge_pair_min_distance,
        )
        infer_ms = (time.monotonic() - t0) * 1000
        _gpu_log(job_id, log_level, "GPU:CV", "DEBUG",
                 f"Edge analysis {w}x{h}: {infer_ms:.0f}ms → {len(edge_regions)} regions")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Region analysis failed: {e}")
    boxes = [RegionBox(**r) for r in edge_regions]
    return AnalyzeRegionsResponse(regions=boxes)


@app.post("/detect_edges/debug", response_model=AnalyzeRegionsDebugResponse)
def detect_edges_debug_endpoint(req: AnalyzeRegionsRequest, request: Request):
    job_id, log_level = _job_ctx(request)
    try:
        image = _decode_image(req.image)
        h, w = image.shape[:2]
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Bad image: {e}")
    try:
        t0 = time.monotonic()
        from models.cv.edges import detect_edges_debug

        result = detect_edges_debug(
            image,
            canny_low=req.edge_canny_low,
            canny_high=req.edge_canny_high,
            hough_threshold=req.edge_hough_threshold,
            hough_min_length=req.edge_hough_min_length,
            hough_max_gap=req.edge_hough_max_gap,
            pair_max_distance=req.edge_pair_max_distance,
            pair_min_distance=req.edge_pair_min_distance,
        )
        infer_ms = (time.monotonic() - t0) * 1000
        _gpu_log(job_id, log_level, "GPU:CV", "DEBUG",
                 f"Edge debug {w}x{h}: {infer_ms:.0f}ms → {len(result['regions'])} regions, "
                 f"{result['horizontal_count']} horizontals, {result['pair_count']} pairs")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Region debug analysis failed: {e}")
    boxes = [RegionBox(**r) for r in result["regions"]]
    return AnalyzeRegionsDebugResponse(
        regions=boxes,
        edge_overlay_b64=result["edge_overlay_b64"],
        lines_overlay_b64=result["lines_overlay_b64"],
        horizontal_count=result["horizontal_count"],
        pair_count=result["pair_count"],
    )


if __name__ == "__main__":
    import uvicorn

    logging.basicConfig(level=logging.INFO,
                        format="%(levelname)-7s %(name)s — %(message)s")
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", "8000"))
    uvicorn.run(app, host=host, port=port)
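
# Quick smoke test once the server is up (a sketch; assumes the default
# host/port above):
#
#   import requests
#   print(requests.get("http://localhost:8000/health").json())
#   # expected keys: status, device, loaded_models, vram_budget_mb, strategy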