phase 4
This commit is contained in:
399
core/gpu/server.py
Normal file
399
core/gpu/server.py
Normal file
@@ -0,0 +1,399 @@
|
||||
"""
|
||||
Inference server — thin HTTP routes over model wrappers.
|
||||
|
||||
Config lives in config.py, model logic in models/.
|
||||
This file is just the FastAPI glue.
|
||||
|
||||
Usage:
|
||||
cd gpu && python server.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import numpy as np
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
|
||||
from emit import log as emit_log
|
||||
|
||||
from config import get_config, get_device, update_config
|
||||
from models import registry
|
||||
from models.yolo import detect as yolo_detect
|
||||
from models.ocr import ocr as ocr_run
|
||||
from models.vlm import query as vlm_query
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _decode_image(b64: str) -> np.ndarray:
    """Decode a base64-encoded image payload into an RGB numpy array (H, W, 3)."""
    raw_bytes = base64.b64decode(b64)
    pil_img = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
    return np.array(pil_img)
|
||||
|
||||
|
||||
def _job_ctx(request: Request) -> tuple[str, str]:
|
||||
"""Extract job_id and log_level from request headers."""
|
||||
job_id = request.headers.get("x-job-id", "")
|
||||
log_level = request.headers.get("x-log-level", "INFO")
|
||||
return job_id, log_level
|
||||
|
||||
|
||||
def _gpu_log(job_id: str, log_level: str, stage: str, level: str, msg: str):
    """Forward a log line to the job event stream when a job id is attached."""
    if not job_id:
        return
    emit_log(job_id, stage, level, msg, log_level=log_level)
|
||||
|
||||
|
||||
# --- Request/Response models (generated from core/schema/models/inference.py) ---
|
||||
|
||||
from models.models import (
|
||||
AnalyzeRegionsDebugResponse,
|
||||
AnalyzeRegionsRequest,
|
||||
AnalyzeRegionsResponse,
|
||||
BBox,
|
||||
ConfigUpdate,
|
||||
DetectRequest,
|
||||
DetectResponse,
|
||||
OCRRequest,
|
||||
OCRResponse,
|
||||
OCRTextResult,
|
||||
PreprocessRequest,
|
||||
PreprocessResponse,
|
||||
RegionBox,
|
||||
SegmentFieldRequest,
|
||||
SegmentFieldResponse,
|
||||
SegmentFieldDebugResponse,
|
||||
VLMRequest,
|
||||
VLMResponse,
|
||||
)
|
||||
|
||||
|
||||
# --- App ---
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan: log startup/shutdown and release models on exit."""
    # Startup: nothing is preloaded here — models are loaded lazily on first use.
    logger.info("Inference server starting (device=%s)", get_device())
    yield
    # Shutdown: drop all loaded models from the registry to free VRAM.
    logger.info("Shutting down")
    registry.clear()
|
||||
|
||||
|
||||
# Single FastAPI application; `lifespan` above handles startup/shutdown cleanup.
app = FastAPI(title="MPR Inference Server", lifespan=lifespan)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
cfg = get_config()
|
||||
return {
|
||||
"status": "ok",
|
||||
"device": get_device(),
|
||||
"loaded_models": registry.loaded(),
|
||||
"vram_budget_mb": cfg["vram_budget_mb"],
|
||||
"strategy": cfg["strategy"],
|
||||
}
|
||||
|
||||
|
||||
@app.get("/config")
|
||||
def read_config():
|
||||
return {**get_config(), "device_resolved": get_device()}
|
||||
|
||||
|
||||
@app.put("/config")
|
||||
def write_config(update: ConfigUpdate):
|
||||
changes = update.model_dump(exclude_none=True)
|
||||
if not changes:
|
||||
return get_config()
|
||||
|
||||
# Unload model if it changed
|
||||
old_model = get_config().get("yolo_model")
|
||||
if "yolo_model" in changes and changes["yolo_model"] != old_model:
|
||||
registry.unload(old_model)
|
||||
|
||||
update_config(changes)
|
||||
logger.info("Config updated: %s", changes)
|
||||
return {**get_config(), "device_resolved": get_device()}
|
||||
|
||||
|
||||
@app.post("/models/unload")
|
||||
def unload_model(body: dict):
|
||||
name = body.get("model", "")
|
||||
unloaded = registry.unload(name)
|
||||
return {"status": "unloaded" if unloaded else "not_loaded", "model": name}
|
||||
|
||||
|
||||
@app.post("/detect", response_model=DetectResponse)
|
||||
def detect(req: DetectRequest, request: Request):
|
||||
job_id, log_level = _job_ctx(request)
|
||||
|
||||
try:
|
||||
t0 = time.monotonic()
|
||||
image = _decode_image(req.image)
|
||||
decode_ms = (time.monotonic() - t0) * 1000
|
||||
h, w = image.shape[:2]
|
||||
_gpu_log(job_id, log_level, "GPU:YOLO", "DEBUG",
|
||||
f"Decoded {w}x{h} image in {decode_ms:.0f}ms")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
||||
|
||||
try:
|
||||
t0 = time.monotonic()
|
||||
results = yolo_detect(
|
||||
image,
|
||||
model_name=req.model,
|
||||
confidence=req.confidence,
|
||||
target_classes=req.target_classes,
|
||||
)
|
||||
infer_ms = (time.monotonic() - t0) * 1000
|
||||
_gpu_log(job_id, log_level, "GPU:YOLO", "DEBUG",
|
||||
f"Inference: {len(results)} detections in {infer_ms:.0f}ms "
|
||||
f"(model={req.model}, conf={req.confidence})")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Detection failed: {e}")
|
||||
|
||||
return DetectResponse(detections=[BBox(**r) for r in results])
|
||||
|
||||
|
||||
@app.post("/ocr", response_model=OCRResponse)
|
||||
def ocr(req: OCRRequest, request: Request):
|
||||
job_id, log_level = _job_ctx(request)
|
||||
|
||||
try:
|
||||
image = _decode_image(req.image)
|
||||
h, w = image.shape[:2]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
||||
|
||||
try:
|
||||
t0 = time.monotonic()
|
||||
results = ocr_run(image, languages=req.languages)
|
||||
infer_ms = (time.monotonic() - t0) * 1000
|
||||
texts = [r["text"][:20] for r in results]
|
||||
_gpu_log(job_id, log_level, "GPU:OCR", "DEBUG",
|
||||
f"OCR {w}x{h}: {infer_ms:.0f}ms → {len(results)} results {texts}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"OCR failed: {e}")
|
||||
|
||||
return OCRResponse(results=[OCRTextResult(**r) for r in results])
|
||||
|
||||
|
||||
@app.post("/preprocess", response_model=PreprocessResponse)
|
||||
def preprocess_image(req: PreprocessRequest):
|
||||
try:
|
||||
image = _decode_image(req.image)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
||||
|
||||
try:
|
||||
from models.preprocess import preprocess
|
||||
processed = preprocess(
|
||||
image,
|
||||
do_binarize=req.binarize,
|
||||
do_deskew=req.deskew,
|
||||
do_contrast=req.contrast,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Preprocessing failed: {e}")
|
||||
|
||||
from PIL import Image as PILImage
|
||||
import io
|
||||
img = PILImage.fromarray(processed)
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format="JPEG", quality=90)
|
||||
result_b64 = base64.b64encode(buf.getvalue()).decode()
|
||||
|
||||
return PreprocessResponse(image=result_b64)
|
||||
|
||||
|
||||
@app.post("/vlm", response_model=VLMResponse)
|
||||
def vlm(req: VLMRequest, request: Request):
|
||||
job_id, log_level = _job_ctx(request)
|
||||
|
||||
try:
|
||||
image = _decode_image(req.image)
|
||||
h, w = image.shape[:2]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
||||
|
||||
try:
|
||||
t0 = time.monotonic()
|
||||
result = vlm_query(image, req.prompt)
|
||||
infer_ms = (time.monotonic() - t0) * 1000
|
||||
_gpu_log(job_id, log_level, "GPU:VLM", "DEBUG",
|
||||
f"VLM {w}x{h}: {infer_ms:.0f}ms → "
|
||||
f"brand='{result.get('brand', '')}' conf={result.get('confidence', 0):.2f}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"VLM failed: {e}")
|
||||
|
||||
return VLMResponse(**result)
|
||||
|
||||
|
||||
@app.post("/detect_edges", response_model=AnalyzeRegionsResponse)
|
||||
def detect_edges_endpoint(req: AnalyzeRegionsRequest, request: Request):
|
||||
job_id, log_level = _job_ctx(request)
|
||||
|
||||
try:
|
||||
image = _decode_image(req.image)
|
||||
h, w = image.shape[:2]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
||||
|
||||
try:
|
||||
t0 = time.monotonic()
|
||||
from models.cv.edges import detect_edges
|
||||
|
||||
edge_regions = detect_edges(
|
||||
image,
|
||||
canny_low=req.edge_canny_low,
|
||||
canny_high=req.edge_canny_high,
|
||||
hough_threshold=req.edge_hough_threshold,
|
||||
hough_min_length=req.edge_hough_min_length,
|
||||
hough_max_gap=req.edge_hough_max_gap,
|
||||
pair_max_distance=req.edge_pair_max_distance,
|
||||
pair_min_distance=req.edge_pair_min_distance,
|
||||
)
|
||||
infer_ms = (time.monotonic() - t0) * 1000
|
||||
|
||||
_gpu_log(job_id, log_level, "GPU:CV", "DEBUG",
|
||||
f"Edge analysis {w}x{h}: {infer_ms:.0f}ms → {len(edge_regions)} regions")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Region analysis failed: {e}")
|
||||
|
||||
boxes = [RegionBox(**r) for r in edge_regions]
|
||||
return AnalyzeRegionsResponse(regions=boxes)
|
||||
|
||||
|
||||
@app.post("/detect_edges/debug", response_model=AnalyzeRegionsDebugResponse)
|
||||
def detect_edges_debug_endpoint(req: AnalyzeRegionsRequest, request: Request):
|
||||
job_id, log_level = _job_ctx(request)
|
||||
|
||||
try:
|
||||
image = _decode_image(req.image)
|
||||
h, w = image.shape[:2]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
||||
|
||||
try:
|
||||
t0 = time.monotonic()
|
||||
from models.cv.edges import detect_edges_debug
|
||||
|
||||
result = detect_edges_debug(
|
||||
image,
|
||||
canny_low=req.edge_canny_low,
|
||||
canny_high=req.edge_canny_high,
|
||||
hough_threshold=req.edge_hough_threshold,
|
||||
hough_min_length=req.edge_hough_min_length,
|
||||
hough_max_gap=req.edge_hough_max_gap,
|
||||
pair_max_distance=req.edge_pair_max_distance,
|
||||
pair_min_distance=req.edge_pair_min_distance,
|
||||
)
|
||||
infer_ms = (time.monotonic() - t0) * 1000
|
||||
|
||||
_gpu_log(job_id, log_level, "GPU:CV", "DEBUG",
|
||||
f"Edge debug {w}x{h}: {infer_ms:.0f}ms → {len(result['regions'])} regions, "
|
||||
f"{result['horizontal_count']} horizontals, {result['pair_count']} pairs")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Region debug analysis failed: {e}")
|
||||
|
||||
boxes = [RegionBox(**r) for r in result["regions"]]
|
||||
response = AnalyzeRegionsDebugResponse(
|
||||
regions=boxes,
|
||||
edge_overlay_b64=result["edge_overlay_b64"],
|
||||
lines_overlay_b64=result["lines_overlay_b64"],
|
||||
horizontal_count=result["horizontal_count"],
|
||||
pair_count=result["pair_count"],
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@app.post("/segment_field", response_model=SegmentFieldResponse)
|
||||
def segment_field_endpoint(req: SegmentFieldRequest, request: Request):
|
||||
job_id, log_level = _job_ctx(request)
|
||||
|
||||
try:
|
||||
image = _decode_image(req.image)
|
||||
h, w = image.shape[:2]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
||||
|
||||
try:
|
||||
t0 = time.monotonic()
|
||||
from models.cv.segmentation import segment_field
|
||||
|
||||
result = segment_field(
|
||||
image,
|
||||
hue_low=req.hue_low,
|
||||
hue_high=req.hue_high,
|
||||
sat_low=req.sat_low,
|
||||
sat_high=req.sat_high,
|
||||
val_low=req.val_low,
|
||||
val_high=req.val_high,
|
||||
morph_kernel=req.morph_kernel,
|
||||
min_area_ratio=req.min_area_ratio,
|
||||
)
|
||||
infer_ms = (time.monotonic() - t0) * 1000
|
||||
|
||||
# Encode mask as base64 PNG for downstream use
|
||||
import cv2
|
||||
_, buf = cv2.imencode(".png", result["mask"])
|
||||
mask_b64 = base64.b64encode(buf.tobytes()).decode()
|
||||
|
||||
_gpu_log(job_id, log_level, "GPU:CV", "DEBUG",
|
||||
f"Field segmentation {w}x{h}: {infer_ms:.0f}ms, coverage={result['coverage']:.1%}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Field segmentation failed: {e}")
|
||||
|
||||
return SegmentFieldResponse(
|
||||
boundary=result["boundary"],
|
||||
coverage=result["coverage"],
|
||||
mask_b64=mask_b64,
|
||||
)
|
||||
|
||||
|
||||
@app.post("/segment_field/debug", response_model=SegmentFieldDebugResponse)
|
||||
def segment_field_debug_endpoint(req: SegmentFieldRequest, request: Request):
|
||||
job_id, log_level = _job_ctx(request)
|
||||
|
||||
try:
|
||||
image = _decode_image(req.image)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
||||
|
||||
try:
|
||||
from models.cv.segmentation import segment_field_debug
|
||||
|
||||
result = segment_field_debug(
|
||||
image,
|
||||
hue_low=req.hue_low,
|
||||
hue_high=req.hue_high,
|
||||
sat_low=req.sat_low,
|
||||
sat_high=req.sat_high,
|
||||
val_low=req.val_low,
|
||||
val_high=req.val_high,
|
||||
morph_kernel=req.morph_kernel,
|
||||
min_area_ratio=req.min_area_ratio,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Field segmentation debug failed: {e}")
|
||||
|
||||
return SegmentFieldDebugResponse(
|
||||
boundary=result["boundary"],
|
||||
coverage=result["coverage"],
|
||||
mask_overlay_b64=result["mask_overlay_b64"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s")
|
||||
host = os.environ.get("HOST", "0.0.0.0")
|
||||
port = int(os.environ.get("PORT", "8000"))
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
Reference in New Issue
Block a user