phase 7
.env
@@ -10,5 +10,9 @@ STRATEGY=sequential # sequential | concurrent | auto
 YOLO_MODEL=yolov8n.pt
 YOLO_CONFIDENCE=0.3
 
+# OCR
+OCR_LANGUAGES=en,es
+OCR_MIN_CONFIDENCE=0.5
+
 # Device
 DEVICE=auto # auto | cpu | cuda | cuda:0
gpu/__init__.py (new file, empty)

gpu/config.py (new file, 39 lines)
@@ -0,0 +1,39 @@
"""
Runtime config — loaded from env, mutable via API.

The UI config panel is just a visual editor for these same values.
"""

from __future__ import annotations

import os

_config = {
    "device": os.environ.get("DEVICE", "auto"),
    "yolo_model": os.environ.get("YOLO_MODEL", "yolov8n.pt"),
    "yolo_confidence": float(os.environ.get("YOLO_CONFIDENCE", "0.3")),
    "vram_budget_mb": int(os.environ.get("VRAM_BUDGET_MB", "10240")),
    "strategy": os.environ.get("STRATEGY", "sequential"),
    "ocr_languages": os.environ.get("OCR_LANGUAGES", "en").split(","),
    "ocr_min_confidence": float(os.environ.get("OCR_MIN_CONFIDENCE", "0.5")),
}


def get_config() -> dict:
    return _config


def update_config(changes: dict) -> dict:
    _config.update(changes)
    return _config


def get_device() -> str:
    device = _config["device"]
    if device != "auto":
        return device
    try:
        import torch
        return "cuda" if torch.cuda.is_available() else "cpu"
    except ImportError:
        return "cpu"
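For orientation, a sketch of how the other modules consume this (the override values are made up; `config` imports flat because everything runs from inside gpu/):

import config

# Values resolve from env at import time, with defaults as fallback.
print(config.get_config()["yolo_confidence"])  # 0.3 unless YOLO_CONFIDENCE is set

# The PUT /config handler funnels into the same dict, so changes apply
# to the next inference call without a restart.
config.update_config({"device": "cuda:0", "yolo_confidence": 0.5})

# An explicit device short-circuits; "auto" probes torch.cuda.
print(config.get_device())  # "cuda:0"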
gpu/models/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from . import registry
from .yolo import detect
from .ocr import ocr

__all__ = ["registry", "detect", "ocr"]
gpu/models/ocr.py (new file, 105 lines)
@@ -0,0 +1,105 @@
"""PaddleOCR 3.x text extraction wrapper."""

from __future__ import annotations

import logging

from models import registry
from config import get_config

logger = logging.getLogger(__name__)


def _load(languages: list[str]):
    from paddleocr import PaddleOCR
    key = f"ocr_{'_'.join(languages)}"
    model = PaddleOCR(lang=languages[0])
    registry.put(key, model)
    return model


def _get(languages: list[str] | None = None):
    langs = languages or get_config()["ocr_languages"]
    key = f"ocr_{'_'.join(langs)}"
    model = registry.get(key)
    if model is None:
        model = _load(langs)
    return model


def _parse_raw(raw) -> list[tuple[list, str, float]]:
    """
    Parse PaddleOCR output into (points, text, confidence) tuples.

    PaddleOCR 3.x changed the result format. Two known layouts:

    Layout A — dict-based (new pipeline API):
        raw = [{'rec_texts': [...], 'rec_scores': [...], 'dt_polys': [...]}]

    Layout B — nested list (2.x compat / some 3.x builds):
        raw = [[ [points, [text, score]], ... ]]
        raw = [[ [points, [text, score], [cls, cls_score]], ... ]]  # with angle cls
    """
    results = []

    for page in raw:
        if not page:
            continue

        # Layout A: dict with parallel lists
        if isinstance(page, dict):
            texts = page.get("rec_texts", [])
            scores = page.get("rec_scores", [])
            polys = page.get("dt_polys", [])
            for points, text, confidence in zip(polys, texts, scores):
                results.append((points, text, float(confidence)))
            continue

        # Layout B: list of per-line entries
        for line in page:
            if not line:
                continue

            # line[0] is always the polygon points
            points = line[0]

            # line[1] is [text, score] — ignore any extra elements (angle cls etc.)
            rec = line[1]
            if isinstance(rec, (list, tuple)) and len(rec) >= 2:
                text, confidence = rec[0], rec[1]
            else:
                logger.warning("Unexpected OCR line format: %s", line)
                continue

            results.append((points, str(text), float(confidence)))

    return results


def ocr(image, languages: list[str] | None = None, min_confidence: float | None = None) -> list[dict]:
    """Run OCR on an image, return list of text result dicts."""
    cfg = get_config()
    min_conf = min_confidence if min_confidence is not None else cfg["ocr_min_confidence"]
    model = _get(languages)

    raw = model.ocr(image)
    logger.debug("OCR raw: %s", raw)

    parsed = _parse_raw(raw)

    results = []
    for points, text, confidence in parsed:
        if confidence < min_conf:
            continue

        xs = [p[0] for p in points]
        ys = [p[1] for p in points]

        results.append({
            "text": text,
            "confidence": confidence,
            "bbox": [int(min(xs)), int(min(ys)),
                     int(max(xs) - min(xs)), int(max(ys) - min(ys))],
        })

    return results
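A sanity check that both layouts normalize identically (the polygon and text are made-up sample data; real dt_polys values are typically numpy arrays, which _parse_raw passes through untouched):

from models.ocr import _parse_raw

points = [[10, 10], [60, 10], [60, 30], [10, 30]]

# Layout A: one dict per page with parallel lists.
layout_a = [{"rec_texts": ["STOP"], "rec_scores": [0.98], "dt_polys": [points]}]
# Layout B: one nested list per page, [points, [text, score]] per line.
layout_b = [[[points, ["STOP", 0.98]]]]

assert _parse_raw(layout_a) == _parse_raw(layout_b) == [(points, "STOP", 0.98)]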
gpu/models/registry.py (new file, 37 lines)
@@ -0,0 +1,37 @@
"""
Model registry — manages loaded models and VRAM lifecycle.
"""

from __future__ import annotations

import logging

logger = logging.getLogger(__name__)

_models: dict[str, object] = {}


def get(name: str) -> object | None:
    return _models.get(name)


def put(name: str, model: object) -> None:
    _models[name] = model
    logger.info("Loaded %s", name)


def unload(name: str) -> bool:
    if name in _models:
        del _models[name]
        logger.info("Unloaded %s", name)
        return True
    return False


def loaded() -> list[str]:
    return list(_models.keys())


def clear() -> None:
    _models.clear()
    logger.info("All models unloaded")
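The registry is deliberately dumb: a module-level dict plus logging. A lifecycle sketch, using a placeholder object instead of a real model:

from models import registry

registry.put("yolov8n.pt", object())  # wrappers store real model instances here
print(registry.loaded())              # ["yolov8n.pt"]
print(registry.unload("yolov8n.pt"))  # True, reference dropped
print(registry.unload("yolov8n.pt"))  # False, already gone

Note that unload only drops the Python reference; for a torch model, VRAM is freed when the object is garbage-collected, and the CUDA caching allocator may hold onto it until torch.cuda.empty_cache().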
gpu/models/yolo.py (new file, 54 lines)
@@ -0,0 +1,54 @@
"""YOLO object detection model wrapper."""

from __future__ import annotations

import logging

from models import registry
from config import get_config, get_device

logger = logging.getLogger(__name__)


def _load(model_name: str):
    from ultralytics import YOLO
    device = get_device()
    model = YOLO(model_name)
    model.to(device)
    registry.put(model_name, model)
    return model


def _get(model_name: str | None = None):
    name = model_name or get_config()["yolo_model"]
    model = registry.get(name)
    if model is None:
        model = _load(name)
    return model


def detect(image, model_name: str | None = None, confidence: float | None = None, target_classes: list[str] | None = None) -> list[dict]:
    """Run YOLO detection, return list of bbox dicts."""
    cfg = get_config()
    conf = confidence if confidence is not None else cfg["yolo_confidence"]
    model = _get(model_name)

    results = model(image, conf=conf, verbose=False)

    detections = []
    for r in results:
        for box in r.boxes:
            x1, y1, x2, y2 = box.xyxy[0].tolist()
            label = r.names[int(box.cls[0])]

            if target_classes and label not in target_classes:
                continue

            detections.append({
                "x": int(x1), "y": int(y1),
                "w": int(x2 - x1), "h": int(y2 - y1),
                "confidence": float(box.conf[0]),
                "label": label,
            })

    return detections
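A usage sketch (the file name and the printed output are illustrative, not from a real run):

from PIL import Image
from models.yolo import detect

image = Image.open("frame.jpg")  # ultralytics accepts PIL images, arrays, or paths
people = detect(image, confidence=0.5, target_classes=["person"])
# e.g. [{"x": 104, "y": 32, "w": 215, "h": 480, "confidence": 0.91, "label": "person"}]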
requirements.txt
@@ -1,4 +1,21 @@
 fastapi>=0.109.0
 uvicorn[standard]>=0.27.0
-ultralytics>=8.0.0
 rapidfuzz>=3.0.0
+Pillow>=10.0.0
+
+# --- GPU-specific installs (mcrn: RTX 3080, CUDA toolkit 12.8) ---
+#
+# torch: must be installed from the PyTorch index, NOT from PyPI.
+# cu126 is the closest build to CUDA 12.8 (no cu128 wheel yet; cu126 is forward-compatible).
+# Install with:
+#   uv pip install --reinstall torch torchvision --index-url https://download.pytorch.org/whl/cu126
+#
+# ultralytics pulls torch as a dependency — reinstall torch after ultralytics to ensure
+# the correct CUDA build. Mixing the PyPI torch with CUDA 12.8 causes NCCL symbol errors.
+ultralytics>=8.0.0
+
+# paddlepaddle-gpu: NOT available on PyPI. Install from PaddlePaddle's package index.
+# cu126 build works on CUDA 12.8.
+# Install with:
+#   uv pip install paddlepaddle-gpu==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
+paddleocr>=3.0.0
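After following the install comments above, a quick sanity check that both frameworks see the GPU (a sketch; run from the same venv):

import torch
import paddle

print(torch.__version__)          # expect a +cu126 suffix, not a bare PyPI build
print(torch.cuda.is_available())  # True on the RTX 3080

paddle.utils.run_check()          # prints whether the PaddlePaddle GPU build works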
gpu/server.py (169 lines changed)
@@ -1,16 +1,10 @@
 """
-Inference server — thin HTTP wrapper around ML models.
+Inference server — thin HTTP routes over model wrappers.
 
 Runs on the GPU machine. The detection pipeline calls this over HTTP,
 or imports the same logic locally if GPU is on the same machine.
 
-Config is loaded from env on startup, then editable at runtime via
-GET/PUT /config. The UI config panel is just a visual editor for these
-same values.
+Config lives in config.py, model logic in models/.
+This file is just the FastAPI glue.
 
-Usage:
-    cd gpu && uvicorn server:app --host 0.0.0.0 --port 8000
-    # or
-    cd gpu && python server.py
 """
@@ -27,45 +21,13 @@ from fastapi import FastAPI, HTTPException
 from PIL import Image
 from pydantic import BaseModel
 
+from config import get_config, get_device, update_config
+from models import registry
+from models.yolo import detect as yolo_detect
+from models.ocr import ocr as ocr_run
+
 logger = logging.getLogger(__name__)
 
-# --- Runtime config (loaded from env, mutable via API) ---
-_config = {
-    "device": os.environ.get("DEVICE", "auto"),
-    "yolo_model": os.environ.get("YOLO_MODEL", "yolov8n.pt"),
-    "yolo_confidence": float(os.environ.get("YOLO_CONFIDENCE", "0.3")),
-    "vram_budget_mb": int(os.environ.get("VRAM_BUDGET_MB", "10240")),
-    "strategy": os.environ.get("STRATEGY", "sequential"),
-}
-
-# --- Model registry ---
-_models: dict[str, object] = {}
-
-
-# --- Helpers ---
-
-def _get_device() -> str:
-    device = _config["device"]
-    if device != "auto":
-        return device
-    try:
-        import torch
-        return "cuda" if torch.cuda.is_available() else "cpu"
-    except ImportError:
-        return "cpu"
-
-
-def _get_yolo(model_name: str | None = None):
-    name = model_name or _config["yolo_model"]
-    if name not in _models:
-        from ultralytics import YOLO
-        device = _get_device()
-        logger.info("Loading %s on %s", name, device)
-        model = YOLO(name)
-        model.to(device)
-        _models[name] = model
-    return _models[name]
 
 
 def _decode_image(b64: str) -> np.ndarray:
     data = base64.b64decode(b64)
@@ -76,9 +38,9 @@ def _decode_image(b64: str) -> np.ndarray:
 # --- Request/Response models ---
 
 class DetectRequest(BaseModel):
-    image: str  # base64 JPEG
-    model: str | None = None  # defaults to config yolo_model
-    confidence: float | None = None  # defaults to config yolo_confidence
+    image: str
+    model: str | None = None
+    confidence: float | None = None
     target_classes: list[str] | None = None
 
 
@@ -95,23 +57,39 @@ class DetectResponse(BaseModel):
 detections: list[BBox]
 
 
+class OCRRequest(BaseModel):
+    image: str
+    languages: list[str] | None = None
+
+
+class OCRTextResult(BaseModel):
+    text: str
+    confidence: float
+    bbox: list[int]
+
+
+class OCRResponse(BaseModel):
+    results: list[OCRTextResult]
+
+
 class ConfigUpdate(BaseModel):
     """Partial config update — only provided fields are changed."""
     device: str | None = None
     yolo_model: str | None = None
     yolo_confidence: float | None = None
     vram_budget_mb: int | None = None
     strategy: str | None = None
+    ocr_languages: list[str] | None = None
+    ocr_min_confidence: float | None = None
 
 
 # --- App ---
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    logger.info("Inference server starting (device=%s)", _get_device())
+    logger.info("Inference server starting (device=%s)", get_device())
     yield
-    logger.info("Inference server shutting down")
-    _models.clear()
+    logger.info("Shutting down")
+    registry.clear()
 
 
 app = FastAPI(title="MPR Inference Server", lifespan=lifespan)
@@ -119,82 +97,77 @@ app = FastAPI(title="MPR Inference Server", lifespan=lifespan)
 @app.get("/health")
 def health():
+    cfg = get_config()
     return {
         "status": "ok",
-        "device": _get_device(),
-        "loaded_models": list(_models.keys()),
-        "vram_budget_mb": _config["vram_budget_mb"],
-        "strategy": _config["strategy"],
+        "device": get_device(),
+        "loaded_models": registry.loaded(),
+        "vram_budget_mb": cfg["vram_budget_mb"],
+        "strategy": cfg["strategy"],
     }
 
 
 @app.get("/config")
-def get_config():
-    """Current runtime config. Same values the .env sets at startup."""
-    return {**_config, "device_resolved": _get_device()}
+def read_config():
+    return {**get_config(), "device_resolved": get_device()}
 
 
 @app.put("/config")
-def update_config(update: ConfigUpdate):
-    """Update runtime config. Only provided fields are changed."""
+def write_config(update: ConfigUpdate):
     changes = update.model_dump(exclude_none=True)
     if not changes:
-        return _config
+        return get_config()
 
-    # If model changed, unload the old one so it gets reloaded on next request
-    if "yolo_model" in changes and changes["yolo_model"] != _config["yolo_model"]:
-        old = _config["yolo_model"]
-        if old in _models:
-            del _models[old]
-            logger.info("Unloaded %s (model changed)", old)
+    # Unload model if it changed
+    old_model = get_config().get("yolo_model")
+    if "yolo_model" in changes and changes["yolo_model"] != old_model:
+        registry.unload(old_model)
 
-    _config.update(changes)
+    update_config(changes)
     logger.info("Config updated: %s", changes)
-    return {**_config, "device_resolved": _get_device()}
+    return {**get_config(), "device_resolved": get_device()}
 
 
 @app.post("/models/unload")
 def unload_model(body: dict):
     """Unload a model from memory to free VRAM."""
     name = body.get("model", "")
-    if name in _models:
-        del _models[name]
-        logger.info("Unloaded %s", name)
-        return {"status": "unloaded", "model": name}
-    return {"status": "not_loaded", "model": name}
+    unloaded = registry.unload(name)
+    return {"status": "unloaded" if unloaded else "not_loaded", "model": name}
 
 
 @app.post("/detect", response_model=DetectResponse)
 def detect(req: DetectRequest):
-    model_name = req.model or _config["yolo_model"]
-    confidence = req.confidence if req.confidence is not None else _config["yolo_confidence"]
+    try:
+        image = _decode_image(req.image)
+    except Exception as e:
+        raise HTTPException(status_code=400, detail=f"Bad image: {e}")
 
     try:
-        model = _get_yolo(model_name)
+        results = yolo_detect(
+            image,
+            model_name=req.model,
+            confidence=req.confidence,
+            target_classes=req.target_classes,
+        )
     except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Failed to load model: {e}")
-
-    image = _decode_image(req.image)
-    results = model(image, conf=confidence, verbose=False)
-
-    detections = []
-    for r in results:
-        for box in r.boxes:
-            x1, y1, x2, y2 = box.xyxy[0].tolist()
-            label = r.names[int(box.cls[0])]
-
-            if req.target_classes and label not in req.target_classes:
-                continue
-
-            det = BBox(
-                x=int(x1), y=int(y1),
-                w=int(x2 - x1), h=int(y2 - y1),
-                confidence=float(box.conf[0]),
-                label=label,
-            )
-            detections.append(det)
-
-    return DetectResponse(detections=detections)
+        raise HTTPException(status_code=500, detail=f"Detection failed: {e}")
+
+    return DetectResponse(detections=[BBox(**r) for r in results])
+
+
+@app.post("/ocr", response_model=OCRResponse)
+def ocr(req: OCRRequest):
+    try:
+        image = _decode_image(req.image)
+    except Exception as e:
+        raise HTTPException(status_code=400, detail=f"Bad image: {e}")
+
+    try:
+        results = ocr_run(image, languages=req.languages)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"OCR failed: {e}")
+
+    return OCRResponse(results=[OCRTextResult(**r) for r in results])
 
 
 if __name__ == "__main__":
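To round this off, a client-side sketch against the new endpoints (host, port, and file name are placeholders; assumes the requests package):

import base64
import requests

with open("frame.jpg", "rb") as f:
    b64 = base64.b64encode(f.read()).decode()

detections = requests.post(
    "http://gpu-box:8000/detect",
    json={"image": b64, "confidence": 0.5, "target_classes": ["person"]},
).json()["detections"]

texts = requests.post(
    "http://gpu-box:8000/ocr",
    json={"image": b64},
).json()["results"]

# Retune at runtime; partial update, only the provided field changes.
requests.put("http://gpu-box:8000/config", json={"yolo_confidence": 0.4})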