"""
|
|
Inference server — thin HTTP routes over model wrappers.
|
|
|
|
Config lives in config.py, model logic in models/.
|
|
This file is just the FastAPI glue.
|
|
|
|
Usage:
|
|
cd gpu && python server.py
|
|
"""
|
|
|
|
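
# Example request (illustrative sketch, not part of the server). Assumes the
# default host/port from the __main__ block, that `requests` is installed, and
# that "yolov8n.pt" names weights that models/yolo accepts; field names follow
# DetectRequest.
#
#   import base64, requests
#   b64 = base64.b64encode(open("photo.jpg", "rb").read()).decode()
#   resp = requests.post(
#       "http://localhost:8000/detect",
#       json={"image": b64, "model": "yolov8n.pt", "confidence": 0.25},
#       headers={"x-job-id": "job-123", "x-log-level": "DEBUG"},
#   )
#   print(resp.json()["detections"])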

from __future__ import annotations

import base64
import io
import logging
import os
import time
from contextlib import asynccontextmanager

import numpy as np
from fastapi import FastAPI, HTTPException, Request
from PIL import Image

from emit import log as emit_log

from config import get_config, get_device, update_config
from models import registry
from models.yolo import detect as yolo_detect
from models.ocr import ocr as ocr_run
from models.vlm import query as vlm_query

logger = logging.getLogger(__name__)


def _decode_image(b64: str) -> np.ndarray:
    """Decode a base64-encoded image into an RGB numpy array."""
    data = base64.b64decode(b64)
    img = Image.open(io.BytesIO(data)).convert("RGB")
    return np.array(img)


def _job_ctx(request: Request) -> tuple[str, str]:
    """Extract job_id and log_level from request headers."""
    job_id = request.headers.get("x-job-id", "")
    log_level = request.headers.get("x-log-level", "INFO")
    return job_id, log_level


def _gpu_log(job_id: str, log_level: str, stage: str, level: str, msg: str):
    """Emit a log event if job context is present."""
    if job_id:
        emit_log(job_id, stage, level, msg, log_level=log_level)


# --- Request/Response models (generated from core/schema/models/inference.py) ---

from models.models import (
    AnalyzeRegionsDebugResponse,
    AnalyzeRegionsRequest,
    AnalyzeRegionsResponse,
    BBox,
    ConfigUpdate,
    DetectRequest,
    DetectResponse,
    OCRRequest,
    OCRResponse,
    OCRTextResult,
    PreprocessRequest,
    PreprocessResponse,
    RegionBox,
    SegmentFieldDebugResponse,
    SegmentFieldRequest,
    SegmentFieldResponse,
    VLMRequest,
    VLMResponse,
)


# --- App ---

@asynccontextmanager
async def lifespan(app: FastAPI):
    logger.info("Inference server starting (device=%s)", get_device())
    yield
    logger.info("Shutting down")
    registry.clear()


app = FastAPI(title="MPR Inference Server", lifespan=lifespan)
@app.get("/health")
|
|
def health():
|
|
cfg = get_config()
|
|
return {
|
|
"status": "ok",
|
|
"device": get_device(),
|
|
"loaded_models": registry.loaded(),
|
|
"vram_budget_mb": cfg["vram_budget_mb"],
|
|
"strategy": cfg["strategy"],
|
|
}
|
|
|
|
|
|
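
# Quick smoke test (sketch; assumes the default port from the __main__ block):
#   curl http://localhost:8000/health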
@app.get("/config")
|
|
def read_config():
|
|
return {**get_config(), "device_resolved": get_device()}
|
|
|
|
|
|
@app.put("/config")
|
|
def write_config(update: ConfigUpdate):
|
|
changes = update.model_dump(exclude_none=True)
|
|
if not changes:
|
|
return get_config()
|
|
|
|
# Unload model if it changed
|
|
old_model = get_config().get("yolo_model")
|
|
if "yolo_model" in changes and changes["yolo_model"] != old_model:
|
|
registry.unload(old_model)
|
|
|
|
update_config(changes)
|
|
logger.info("Config updated: %s", changes)
|
|
return {**get_config(), "device_resolved": get_device()}
|
|
|
|
|
|
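
# Example config update (sketch): "yolo_model" matches the key handled above,
# and the weights filename is a placeholder.
#   curl -X PUT http://localhost:8000/config \
#        -H 'Content-Type: application/json' \
#        -d '{"yolo_model": "yolov8s.pt"}'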
@app.post("/models/unload")
|
|
def unload_model(body: dict):
|
|
name = body.get("model", "")
|
|
unloaded = registry.unload(name)
|
|
return {"status": "unloaded" if unloaded else "not_loaded", "model": name}
|
|
|
|
|
|
@app.post("/detect", response_model=DetectResponse)
|
|
def detect(req: DetectRequest, request: Request):
|
|
job_id, log_level = _job_ctx(request)
|
|
|
|
try:
|
|
t0 = time.monotonic()
|
|
image = _decode_image(req.image)
|
|
decode_ms = (time.monotonic() - t0) * 1000
|
|
h, w = image.shape[:2]
|
|
_gpu_log(job_id, log_level, "GPU:YOLO", "DEBUG",
|
|
f"Decoded {w}x{h} image in {decode_ms:.0f}ms")
|
|
except Exception as e:
|
|
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
|
|
|
try:
|
|
t0 = time.monotonic()
|
|
results = yolo_detect(
|
|
image,
|
|
model_name=req.model,
|
|
confidence=req.confidence,
|
|
target_classes=req.target_classes,
|
|
)
|
|
infer_ms = (time.monotonic() - t0) * 1000
|
|
_gpu_log(job_id, log_level, "GPU:YOLO", "DEBUG",
|
|
f"Inference: {len(results)} detections in {infer_ms:.0f}ms "
|
|
f"(model={req.model}, conf={req.confidence})")
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=f"Detection failed: {e}")
|
|
|
|
return DetectResponse(detections=[BBox(**r) for r in results])
|
|
|
|
|
|
@app.post("/ocr", response_model=OCRResponse)
|
|
def ocr(req: OCRRequest, request: Request):
|
|
job_id, log_level = _job_ctx(request)
|
|
|
|
try:
|
|
image = _decode_image(req.image)
|
|
h, w = image.shape[:2]
|
|
except Exception as e:
|
|
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
|
|
|
try:
|
|
t0 = time.monotonic()
|
|
results = ocr_run(image, languages=req.languages)
|
|
infer_ms = (time.monotonic() - t0) * 1000
|
|
texts = [r["text"][:20] for r in results]
|
|
_gpu_log(job_id, log_level, "GPU:OCR", "DEBUG",
|
|
f"OCR {w}x{h}: {infer_ms:.0f}ms → {len(results)} results {texts}")
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=f"OCR failed: {e}")
|
|
|
|
return OCRResponse(results=[OCRTextResult(**r) for r in results])
|
|
|
|
|
|
@app.post("/preprocess", response_model=PreprocessResponse)
|
|
def preprocess_image(req: PreprocessRequest):
|
|
try:
|
|
image = _decode_image(req.image)
|
|
except Exception as e:
|
|
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
|
|
|
try:
|
|
from models.preprocess import preprocess
|
|
processed = preprocess(
|
|
image,
|
|
do_binarize=req.binarize,
|
|
do_deskew=req.deskew,
|
|
do_contrast=req.contrast,
|
|
)
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=f"Preprocessing failed: {e}")
|
|
|
|
from PIL import Image as PILImage
|
|
import io
|
|
img = PILImage.fromarray(processed)
|
|
buf = io.BytesIO()
|
|
img.save(buf, format="JPEG", quality=90)
|
|
result_b64 = base64.b64encode(buf.getvalue()).decode()
|
|
|
|
return PreprocessResponse(image=result_b64)
|
|
|
|
|
|
@app.post("/vlm", response_model=VLMResponse)
|
|
def vlm(req: VLMRequest, request: Request):
|
|
job_id, log_level = _job_ctx(request)
|
|
|
|
try:
|
|
image = _decode_image(req.image)
|
|
h, w = image.shape[:2]
|
|
except Exception as e:
|
|
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
|
|
|
try:
|
|
t0 = time.monotonic()
|
|
result = vlm_query(image, req.prompt)
|
|
infer_ms = (time.monotonic() - t0) * 1000
|
|
_gpu_log(job_id, log_level, "GPU:VLM", "DEBUG",
|
|
f"VLM {w}x{h}: {infer_ms:.0f}ms → "
|
|
f"brand='{result.get('brand', '')}' conf={result.get('confidence', 0):.2f}")
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=f"VLM failed: {e}")
|
|
|
|
return VLMResponse(**result)
|
|
|
|
|
|
@app.post("/detect_edges", response_model=AnalyzeRegionsResponse)
|
|
def detect_edges_endpoint(req: AnalyzeRegionsRequest, request: Request):
|
|
job_id, log_level = _job_ctx(request)
|
|
|
|
try:
|
|
image = _decode_image(req.image)
|
|
h, w = image.shape[:2]
|
|
except Exception as e:
|
|
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
|
|
|
try:
|
|
t0 = time.monotonic()
|
|
from models.cv.edges import detect_edges
|
|
|
|
edge_regions = detect_edges(
|
|
image,
|
|
canny_low=req.edge_canny_low,
|
|
canny_high=req.edge_canny_high,
|
|
hough_threshold=req.edge_hough_threshold,
|
|
hough_min_length=req.edge_hough_min_length,
|
|
hough_max_gap=req.edge_hough_max_gap,
|
|
pair_max_distance=req.edge_pair_max_distance,
|
|
pair_min_distance=req.edge_pair_min_distance,
|
|
)
|
|
infer_ms = (time.monotonic() - t0) * 1000
|
|
|
|
_gpu_log(job_id, log_level, "GPU:CV", "DEBUG",
|
|
f"Edge analysis {w}x{h}: {infer_ms:.0f}ms → {len(edge_regions)} regions")
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=f"Region analysis failed: {e}")
|
|
|
|
boxes = [RegionBox(**r) for r in edge_regions]
|
|
return AnalyzeRegionsResponse(regions=boxes)
|
|
|
|
|
|
@app.post("/detect_edges/debug", response_model=AnalyzeRegionsDebugResponse)
|
|
def detect_edges_debug_endpoint(req: AnalyzeRegionsRequest, request: Request):
|
|
job_id, log_level = _job_ctx(request)
|
|
|
|
try:
|
|
image = _decode_image(req.image)
|
|
h, w = image.shape[:2]
|
|
except Exception as e:
|
|
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
|
|
|
try:
|
|
t0 = time.monotonic()
|
|
from models.cv.edges import detect_edges_debug
|
|
|
|
result = detect_edges_debug(
|
|
image,
|
|
canny_low=req.edge_canny_low,
|
|
canny_high=req.edge_canny_high,
|
|
hough_threshold=req.edge_hough_threshold,
|
|
hough_min_length=req.edge_hough_min_length,
|
|
hough_max_gap=req.edge_hough_max_gap,
|
|
pair_max_distance=req.edge_pair_max_distance,
|
|
pair_min_distance=req.edge_pair_min_distance,
|
|
)
|
|
infer_ms = (time.monotonic() - t0) * 1000
|
|
|
|
_gpu_log(job_id, log_level, "GPU:CV", "DEBUG",
|
|
f"Edge debug {w}x{h}: {infer_ms:.0f}ms → {len(result['regions'])} regions, "
|
|
f"{result['horizontal_count']} horizontals, {result['pair_count']} pairs")
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=f"Region debug analysis failed: {e}")
|
|
|
|
boxes = [RegionBox(**r) for r in result["regions"]]
|
|
response = AnalyzeRegionsDebugResponse(
|
|
regions=boxes,
|
|
edge_overlay_b64=result["edge_overlay_b64"],
|
|
lines_overlay_b64=result["lines_overlay_b64"],
|
|
horizontal_count=result["horizontal_count"],
|
|
pair_count=result["pair_count"],
|
|
)
|
|
return response
|
|
|
|
|
|
@app.post("/segment_field", response_model=SegmentFieldResponse)
|
|
def segment_field_endpoint(req: SegmentFieldRequest, request: Request):
|
|
job_id, log_level = _job_ctx(request)
|
|
|
|
try:
|
|
image = _decode_image(req.image)
|
|
h, w = image.shape[:2]
|
|
except Exception as e:
|
|
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
|
|
|
try:
|
|
t0 = time.monotonic()
|
|
from models.cv.segmentation import segment_field
|
|
|
|
result = segment_field(
|
|
image,
|
|
hue_low=req.hue_low,
|
|
hue_high=req.hue_high,
|
|
sat_low=req.sat_low,
|
|
sat_high=req.sat_high,
|
|
val_low=req.val_low,
|
|
val_high=req.val_high,
|
|
morph_kernel=req.morph_kernel,
|
|
min_area_ratio=req.min_area_ratio,
|
|
)
|
|
infer_ms = (time.monotonic() - t0) * 1000
|
|
|
|
# Encode mask as base64 PNG for downstream use
|
|
import cv2
|
|
_, buf = cv2.imencode(".png", result["mask"])
|
|
mask_b64 = base64.b64encode(buf.tobytes()).decode()
|
|
|
|
_gpu_log(job_id, log_level, "GPU:CV", "DEBUG",
|
|
f"Field segmentation {w}x{h}: {infer_ms:.0f}ms, coverage={result['coverage']:.1%}")
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=f"Field segmentation failed: {e}")
|
|
|
|
return SegmentFieldResponse(
|
|
boundary=result["boundary"],
|
|
coverage=result["coverage"],
|
|
mask_b64=mask_b64,
|
|
)
|
|
|
|
|
|
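
# Client-side decode of the returned mask (sketch; assumes OpenCV and numpy
# are available on the caller's side):
#   import base64, cv2, numpy as np
#   raw = base64.b64decode(resp.json()["mask_b64"])
#   mask = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_GRAYSCALE)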
@app.post("/segment_field/debug", response_model=SegmentFieldDebugResponse)
|
|
def segment_field_debug_endpoint(req: SegmentFieldRequest, request: Request):
|
|
job_id, log_level = _job_ctx(request)
|
|
|
|
try:
|
|
image = _decode_image(req.image)
|
|
except Exception as e:
|
|
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
|
|
|
try:
|
|
from models.cv.segmentation import segment_field_debug
|
|
|
|
result = segment_field_debug(
|
|
image,
|
|
hue_low=req.hue_low,
|
|
hue_high=req.hue_high,
|
|
sat_low=req.sat_low,
|
|
sat_high=req.sat_high,
|
|
val_low=req.val_low,
|
|
val_high=req.val_high,
|
|
morph_kernel=req.morph_kernel,
|
|
min_area_ratio=req.min_area_ratio,
|
|
)
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=f"Field segmentation debug failed: {e}")
|
|
|
|
return SegmentFieldDebugResponse(
|
|
boundary=result["boundary"],
|
|
coverage=result["coverage"],
|
|
mask_overlay_b64=result["mask_overlay_b64"],
|
|
)
|
|
|
|
|
|


if __name__ == "__main__":
    import uvicorn

    logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s")
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", "8000"))
    uvicorn.run(app, host=host, port=port)