2026-03-30 07:22:14 -03:00
parent d0707333fd
commit 4220b0418e
182 changed files with 3668 additions and 5231 deletions

core/gpu/server.py Normal file

@@ -0,0 +1,399 @@
"""
Inference server — thin HTTP routes over model wrappers.
Config lives in config.py, model logic in models/.
This file is just the FastAPI glue.
Usage:
cd core/gpu && python server.py
"""
from __future__ import annotations
import base64
import io
import logging
import os
import time
from contextlib import asynccontextmanager
import numpy as np
from fastapi import FastAPI, HTTPException, Request
from PIL import Image
from pydantic import BaseModel
from emit import log as emit_log
from config import get_config, get_device, update_config
from models import registry
from models.yolo import detect as yolo_detect
from models.ocr import ocr as ocr_run
from models.vlm import query as vlm_query
logger = logging.getLogger(__name__)
def _decode_image(b64: str) -> np.ndarray:
data = base64.b64decode(b64)
img = Image.open(io.BytesIO(data)).convert("RGB")
return np.array(img)
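# Note: this expects raw base64 with no data-URI prefix; a client would
# typically build the payload as base64.b64encode(open(path, "rb").read()).decode().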
def _job_ctx(request: Request) -> tuple[str, str]:
"""Extract job_id and log_level from request headers."""
job_id = request.headers.get("x-job-id", "")
log_level = request.headers.get("x-log-level", "INFO")
return job_id, log_level
def _gpu_log(job_id: str, log_level: str, stage: str, level: str, msg: str):
"""Emit a log event if job context is present."""
if job_id:
emit_log(job_id, stage, level, msg, log_level=log_level)
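# Client side of this contract (sketch; `requests` is illustrative, any HTTP
# client works): callers that want per-job GPU logs forward the job context
# as headers on every call, e.g.
#   requests.post(url, json=payload, headers={"x-job-id": job_id, "x-log-level": "DEBUG"})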
# --- Request/Response models (generated from core/schema/models/inference.py) ---
from models.models import (
AnalyzeRegionsDebugResponse,
AnalyzeRegionsRequest,
AnalyzeRegionsResponse,
BBox,
ConfigUpdate,
DetectRequest,
DetectResponse,
OCRRequest,
OCRResponse,
OCRTextResult,
PreprocessRequest,
PreprocessResponse,
RegionBox,
SegmentFieldRequest,
SegmentFieldResponse,
SegmentFieldDebugResponse,
VLMRequest,
VLMResponse,
)
# --- App ---
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Inference server starting (device=%s)", get_device())
yield
logger.info("Shutting down")
registry.clear()
app = FastAPI(title="MPR Inference Server", lifespan=lifespan)
@app.get("/health")
def health():
cfg = get_config()
return {
"status": "ok",
"device": get_device(),
"loaded_models": registry.loaded(),
"vram_budget_mb": cfg["vram_budget_mb"],
"strategy": cfg["strategy"],
}
@app.get("/config")
def read_config():
return {**get_config(), "device_resolved": get_device()}
@app.put("/config")
def write_config(update: ConfigUpdate):
changes = update.model_dump(exclude_none=True)
if not changes:
return get_config()
# Unload model if it changed
old_model = get_config().get("yolo_model")
if "yolo_model" in changes and changes["yolo_model"] != old_model:
registry.unload(old_model)
update_config(changes)
logger.info("Config updated: %s", changes)
return {**get_config(), "device_resolved": get_device()}
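# Example update (sketch; the model name is illustrative, not a shipped default):
#   curl -X PUT http://localhost:8000/config \
#        -H 'Content-Type: application/json' -d '{"yolo_model": "yolov8s.pt"}'
# Changing yolo_model unloads the old weights before the new value is stored.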
@app.post("/models/unload")
def unload_model(body: dict):
name = body.get("model", "")
unloaded = registry.unload(name)
return {"status": "unloaded" if unloaded else "not_loaded", "model": name}
@app.post("/detect", response_model=DetectResponse)
def detect(req: DetectRequest, request: Request):
job_id, log_level = _job_ctx(request)
try:
t0 = time.monotonic()
image = _decode_image(req.image)
decode_ms = (time.monotonic() - t0) * 1000
h, w = image.shape[:2]
_gpu_log(job_id, log_level, "GPU:YOLO", "DEBUG",
f"Decoded {w}x{h} image in {decode_ms:.0f}ms")
except Exception as e:
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
try:
t0 = time.monotonic()
results = yolo_detect(
image,
model_name=req.model,
confidence=req.confidence,
target_classes=req.target_classes,
)
infer_ms = (time.monotonic() - t0) * 1000
_gpu_log(job_id, log_level, "GPU:YOLO", "DEBUG",
f"Inference: {len(results)} detections in {infer_ms:.0f}ms "
f"(model={req.model}, conf={req.confidence})")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Detection failed: {e}")
return DetectResponse(detections=[BBox(**r) for r in results])
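# Example call (sketch; field values are illustrative, see DetectRequest in
# models.models for the full schema):
#   b64 = base64.b64encode(open("frame.jpg", "rb").read()).decode()
#   resp = requests.post("http://localhost:8000/detect",
#                        json={"image": b64, "confidence": 0.4})
#   boxes = resp.json()["detections"]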
@app.post("/ocr", response_model=OCRResponse)
def ocr(req: OCRRequest, request: Request):
job_id, log_level = _job_ctx(request)
try:
image = _decode_image(req.image)
h, w = image.shape[:2]
except Exception as e:
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
try:
t0 = time.monotonic()
results = ocr_run(image, languages=req.languages)
infer_ms = (time.monotonic() - t0) * 1000
texts = [r["text"][:20] for r in results]
_gpu_log(job_id, log_level, "GPU:OCR", "DEBUG",
f"OCR {w}x{h}: {infer_ms:.0f}ms → {len(results)} results {texts}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"OCR failed: {e}")
return OCRResponse(results=[OCRTextResult(**r) for r in results])
@app.post("/preprocess", response_model=PreprocessResponse)
def preprocess_image(req: PreprocessRequest):
try:
image = _decode_image(req.image)
except Exception as e:
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
try:
from models.preprocess import preprocess
processed = preprocess(
image,
do_binarize=req.binarize,
do_deskew=req.deskew,
do_contrast=req.contrast,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Preprocessing failed: {e}")
    img = Image.fromarray(processed)
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=90)
    result_b64 = base64.b64encode(buf.getvalue()).decode()
return PreprocessResponse(image=result_b64)
@app.post("/vlm", response_model=VLMResponse)
def vlm(req: VLMRequest, request: Request):
job_id, log_level = _job_ctx(request)
try:
image = _decode_image(req.image)
h, w = image.shape[:2]
except Exception as e:
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
try:
t0 = time.monotonic()
result = vlm_query(image, req.prompt)
infer_ms = (time.monotonic() - t0) * 1000
_gpu_log(job_id, log_level, "GPU:VLM", "DEBUG",
f"VLM {w}x{h}: {infer_ms:.0f}ms → "
f"brand='{result.get('brand', '')}' conf={result.get('confidence', 0):.2f}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"VLM failed: {e}")
return VLMResponse(**result)
@app.post("/detect_edges", response_model=AnalyzeRegionsResponse)
def detect_edges_endpoint(req: AnalyzeRegionsRequest, request: Request):
job_id, log_level = _job_ctx(request)
try:
image = _decode_image(req.image)
h, w = image.shape[:2]
except Exception as e:
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
try:
t0 = time.monotonic()
from models.cv.edges import detect_edges
edge_regions = detect_edges(
image,
canny_low=req.edge_canny_low,
canny_high=req.edge_canny_high,
hough_threshold=req.edge_hough_threshold,
hough_min_length=req.edge_hough_min_length,
hough_max_gap=req.edge_hough_max_gap,
pair_max_distance=req.edge_pair_max_distance,
pair_min_distance=req.edge_pair_min_distance,
)
infer_ms = (time.monotonic() - t0) * 1000
_gpu_log(job_id, log_level, "GPU:CV", "DEBUG",
f"Edge analysis {w}x{h}: {infer_ms:.0f}ms → {len(edge_regions)} regions")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Region analysis failed: {e}")
boxes = [RegionBox(**r) for r in edge_regions]
return AnalyzeRegionsResponse(regions=boxes)
@app.post("/detect_edges/debug", response_model=AnalyzeRegionsDebugResponse)
def detect_edges_debug_endpoint(req: AnalyzeRegionsRequest, request: Request):
job_id, log_level = _job_ctx(request)
try:
image = _decode_image(req.image)
h, w = image.shape[:2]
except Exception as e:
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
try:
t0 = time.monotonic()
from models.cv.edges import detect_edges_debug
result = detect_edges_debug(
image,
canny_low=req.edge_canny_low,
canny_high=req.edge_canny_high,
hough_threshold=req.edge_hough_threshold,
hough_min_length=req.edge_hough_min_length,
hough_max_gap=req.edge_hough_max_gap,
pair_max_distance=req.edge_pair_max_distance,
pair_min_distance=req.edge_pair_min_distance,
)
infer_ms = (time.monotonic() - t0) * 1000
_gpu_log(job_id, log_level, "GPU:CV", "DEBUG",
f"Edge debug {w}x{h}: {infer_ms:.0f}ms → {len(result['regions'])} regions, "
f"{result['horizontal_count']} horizontals, {result['pair_count']} pairs")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Region debug analysis failed: {e}")
boxes = [RegionBox(**r) for r in result["regions"]]
response = AnalyzeRegionsDebugResponse(
regions=boxes,
edge_overlay_b64=result["edge_overlay_b64"],
lines_overlay_b64=result["lines_overlay_b64"],
horizontal_count=result["horizontal_count"],
pair_count=result["pair_count"],
)
return response
@app.post("/segment_field", response_model=SegmentFieldResponse)
def segment_field_endpoint(req: SegmentFieldRequest, request: Request):
job_id, log_level = _job_ctx(request)
try:
image = _decode_image(req.image)
h, w = image.shape[:2]
except Exception as e:
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
try:
t0 = time.monotonic()
from models.cv.segmentation import segment_field
result = segment_field(
image,
hue_low=req.hue_low,
hue_high=req.hue_high,
sat_low=req.sat_low,
sat_high=req.sat_high,
val_low=req.val_low,
val_high=req.val_high,
morph_kernel=req.morph_kernel,
min_area_ratio=req.min_area_ratio,
)
infer_ms = (time.monotonic() - t0) * 1000
# Encode mask as base64 PNG for downstream use
import cv2
_, buf = cv2.imencode(".png", result["mask"])
mask_b64 = base64.b64encode(buf.tobytes()).decode()
_gpu_log(job_id, log_level, "GPU:CV", "DEBUG",
f"Field segmentation {w}x{h}: {infer_ms:.0f}ms, coverage={result['coverage']:.1%}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Field segmentation failed: {e}")
return SegmentFieldResponse(
boundary=result["boundary"],
coverage=result["coverage"],
mask_b64=mask_b64,
)
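# The mask comes back PNG-encoded; a client can recover it with cv2/numpy
# (sketch, assuming both are installed):
#   raw = base64.b64decode(resp["mask_b64"])
#   mask = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_UNCHANGED)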
@app.post("/segment_field/debug", response_model=SegmentFieldDebugResponse)
def segment_field_debug_endpoint(req: SegmentFieldRequest, request: Request):
job_id, log_level = _job_ctx(request)
try:
image = _decode_image(req.image)
except Exception as e:
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
try:
from models.cv.segmentation import segment_field_debug
result = segment_field_debug(
image,
hue_low=req.hue_low,
hue_high=req.hue_high,
sat_low=req.sat_low,
sat_high=req.sat_high,
val_low=req.val_low,
val_high=req.val_high,
morph_kernel=req.morph_kernel,
min_area_ratio=req.min_area_ratio,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Field segmentation debug failed: {e}")
return SegmentFieldDebugResponse(
boundary=result["boundary"],
coverage=result["coverage"],
mask_overlay_b64=result["mask_overlay_b64"],
)
if __name__ == "__main__":
import uvicorn
    logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s %(message)s")
host = os.environ.get("HOST", "0.0.0.0")
port = int(os.environ.get("PORT", "8000"))
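    # Binds all interfaces on port 8000 by default; override via env, e.g. PORT=8001 python server.py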
uvicorn.run(app, host=host, port=port)