phase 7
This commit is contained in:
5
gpu/models/__init__.py
Normal file
5
gpu/models/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from . import registry
|
||||
from .yolo import detect
|
||||
from .ocr import ocr
|
||||
|
||||
__all__ = ["registry", "detect", "ocr"]
|
||||
105
gpu/models/ocr.py
Normal file
105
gpu/models/ocr.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""PaddleOCR 3.x text extraction wrapper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from models import registry
|
||||
from config import get_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _load(languages: list[str]):
|
||||
from paddleocr import PaddleOCR
|
||||
key = f"ocr_{'_'.join(languages)}"
|
||||
model = PaddleOCR(lang=languages[0])
|
||||
registry.put(key, model)
|
||||
return model
|
||||
|
||||
|
||||
def _get(languages: list[str] | None = None):
|
||||
langs = languages or get_config()["ocr_languages"]
|
||||
key = f"ocr_{'_'.join(langs)}"
|
||||
model = registry.get(key)
|
||||
if model is None:
|
||||
model = _load(langs)
|
||||
return model
|
||||
|
||||
|
||||
def _parse_raw(raw) -> list[tuple[list, str, float]]:
|
||||
"""
|
||||
Parse PaddleOCR output into (points, text, confidence) tuples.
|
||||
|
||||
PaddleOCR 3.x changed the result format. Two known layouts:
|
||||
|
||||
Layout A — dict-based (new pipeline API):
|
||||
raw = [{'rec_texts': [...], 'rec_scores': [...], 'dt_polys': [...]}]
|
||||
|
||||
Layout B — nested list (2.x compat / some 3.x builds):
|
||||
raw = [[ [points, [text, score]], ... ]]
|
||||
raw = [[ [points, [text, score], [cls, cls_score]], ... ]] # with angle cls
|
||||
"""
|
||||
results = []
|
||||
|
||||
for page in raw:
|
||||
if not page:
|
||||
continue
|
||||
|
||||
# Layout A: dict with parallel lists
|
||||
if isinstance(page, dict):
|
||||
texts = page.get("rec_texts", [])
|
||||
scores = page.get("rec_scores", [])
|
||||
polys = page.get("dt_polys", [])
|
||||
for points, text, confidence in zip(polys, texts, scores):
|
||||
results.append((points, text, float(confidence)))
|
||||
continue
|
||||
|
||||
# Layout B: list of per-line entries
|
||||
for line in page:
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# line[0] is always the polygon points
|
||||
points = line[0]
|
||||
|
||||
# line[1] is [text, score] — ignore any extra elements (angle cls etc.)
|
||||
rec = line[1]
|
||||
if isinstance(rec, (list, tuple)) and len(rec) >= 2:
|
||||
text, confidence = rec[0], rec[1]
|
||||
else:
|
||||
logger.warning("Unexpected OCR line format: %s", line)
|
||||
continue
|
||||
|
||||
results.append((points, str(text), float(confidence)))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def ocr(image, languages: list[str] | None = None, min_confidence: float | None = None) -> list[dict]:
|
||||
"""Run OCR on an image, return list of text result dicts."""
|
||||
cfg = get_config()
|
||||
min_conf = min_confidence if min_confidence is not None else cfg["ocr_min_confidence"]
|
||||
model = _get(languages)
|
||||
|
||||
raw = model.ocr(image)
|
||||
logger.debug("OCR raw: %s", raw)
|
||||
|
||||
parsed = _parse_raw(raw)
|
||||
|
||||
results = []
|
||||
for points, text, confidence in parsed:
|
||||
if confidence < min_conf:
|
||||
continue
|
||||
|
||||
xs = [p[0] for p in points]
|
||||
ys = [p[1] for p in points]
|
||||
|
||||
results.append({
|
||||
"text": text,
|
||||
"confidence": confidence,
|
||||
"bbox": [int(min(xs)), int(min(ys)),
|
||||
int(max(xs) - min(xs)), int(max(ys) - min(ys))],
|
||||
})
|
||||
|
||||
return results
|
||||
37
gpu/models/registry.py
Normal file
37
gpu/models/registry.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
Model registry — manages loaded models and VRAM lifecycle.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_models: dict[str, object] = {}
|
||||
|
||||
|
||||
def get(name: str) -> object | None:
|
||||
return _models.get(name)
|
||||
|
||||
|
||||
def put(name: str, model: object) -> None:
|
||||
_models[name] = model
|
||||
logger.info("Loaded %s", name)
|
||||
|
||||
|
||||
def unload(name: str) -> bool:
|
||||
if name in _models:
|
||||
del _models[name]
|
||||
logger.info("Unloaded %s", name)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def loaded() -> list[str]:
|
||||
return list(_models.keys())
|
||||
|
||||
|
||||
def clear() -> None:
|
||||
_models.clear()
|
||||
logger.info("All models unloaded")
|
||||
54
gpu/models/yolo.py
Normal file
54
gpu/models/yolo.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""YOLO object detection model wrapper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from models import registry
|
||||
from config import get_config, get_device
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _load(model_name: str):
|
||||
from ultralytics import YOLO
|
||||
device = get_device()
|
||||
model = YOLO(model_name)
|
||||
model.to(device)
|
||||
registry.put(model_name, model)
|
||||
return model
|
||||
|
||||
|
||||
def _get(model_name: str | None = None):
|
||||
name = model_name or get_config()["yolo_model"]
|
||||
model = registry.get(name)
|
||||
if model is None:
|
||||
model = _load(name)
|
||||
return model
|
||||
|
||||
|
||||
def detect(image, model_name: str | None = None, confidence: float | None = None, target_classes: list[str] | None = None) -> list[dict]:
|
||||
"""Run YOLO detection, return list of bbox dicts."""
|
||||
cfg = get_config()
|
||||
conf = confidence if confidence is not None else cfg["yolo_confidence"]
|
||||
model = _get(model_name)
|
||||
|
||||
results = model(image, conf=conf, verbose=False)
|
||||
|
||||
detections = []
|
||||
for r in results:
|
||||
for box in r.boxes:
|
||||
x1, y1, x2, y2 = box.xyxy[0].tolist()
|
||||
label = r.names[int(box.cls[0])]
|
||||
|
||||
if target_classes and label not in target_classes:
|
||||
continue
|
||||
|
||||
detections.append({
|
||||
"x": int(x1), "y": int(y1),
|
||||
"w": int(x2 - x1), "h": int(y2 - y1),
|
||||
"confidence": float(box.conf[0]),
|
||||
"label": label,
|
||||
})
|
||||
|
||||
return detections
|
||||
Reference in New Issue
Block a user