55 lines
1.4 KiB
Python
55 lines
1.4 KiB
Python
"""YOLO object detection model wrapper."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
|
|
from models import registry
|
|
from config import get_config, get_device
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _load(model_name: str):
    """Build the YOLO model *model_name*, place it on the configured device, and register it.

    The model is constructed with ``YOLO(model_name)``, moved to whatever
    ``get_device()`` returns, stored in ``registry`` under ``model_name``,
    and then returned to the caller.
    """
    # Function-scope import, kept from the original; presumably this defers
    # the ultralytics import until a model is actually needed.
    from ultralytics import YOLO

    target_device = get_device()
    yolo = YOLO(model_name)
    yolo.to(target_device)

    # Store under the same name so later registry lookups find this instance.
    registry.put(model_name, yolo)
    return yolo
|
|
|
|
|
|
def _get(model_name: str | None = None):
    """Return the model registered under *model_name*, loading it on a miss.

    A falsy ``model_name`` (``None`` or an empty string) falls back to the
    ``"yolo_model"`` entry of ``get_config()``.
    """
    if model_name:
        key = model_name
    else:
        key = get_config()["yolo_model"]

    cached = registry.get(key)
    # A ``None`` lookup result is treated as "not loaded yet".
    return _load(key) if cached is None else cached
|
|
|
|
|
|
def detect(
    image,
    model_name: str | None = None,
    confidence: float | None = None,
    target_classes: list[str] | None = None,
) -> list[dict]:
    """Run the YOLO model on *image* and return its boxes as plain dicts.

    Args:
        image: Passed straight through to the model call.
        model_name: Model to use; resolved by ``_get`` (which falls back to
            the configured model when this is falsy).
        confidence: Threshold forwarded as ``conf``. When ``None``, the
            ``"yolo_confidence"`` config value is used instead.
        target_classes: When non-empty, only detections whose label is in
            this collection are kept. ``None`` or empty keeps everything.

    Returns:
        One dict per kept box with keys ``"x"``, ``"y"``, ``"w"``, ``"h"``
        (ints derived from the box's ``xyxy`` coordinates),
        ``"confidence"`` (float) and ``"label"`` (looked up in the
        result's ``names`` mapping).
    """
    settings = get_config()
    if confidence is None:
        threshold = settings["yolo_confidence"]
    else:
        threshold = confidence
    detector = _get(model_name)

    outputs = detector(image, conf=threshold, verbose=False)

    found = []
    for result in outputs:
        for box in result.boxes:
            left, top, right, bottom = box.xyxy[0].tolist()
            class_name = result.names[int(box.cls[0])]

            # An absent or empty filter lets every class through.
            if not target_classes or class_name in target_classes:
                found.append({
                    "x": int(left), "y": int(top),
                    # int() is applied to the difference, not to each corner.
                    "w": int(right - left), "h": int(bottom - top),
                    "confidence": float(box.conf[0]),
                    "label": class_name,
                })

    return found
|