phase 7
This commit is contained in:
54
gpu/models/yolo.py
Normal file
54
gpu/models/yolo.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""YOLO object detection model wrapper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from models import registry
|
||||
from config import get_config, get_device
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _load(model_name: str):
|
||||
from ultralytics import YOLO
|
||||
device = get_device()
|
||||
model = YOLO(model_name)
|
||||
model.to(device)
|
||||
registry.put(model_name, model)
|
||||
return model
|
||||
|
||||
|
||||
def _get(model_name: str | None = None):
|
||||
name = model_name or get_config()["yolo_model"]
|
||||
model = registry.get(name)
|
||||
if model is None:
|
||||
model = _load(name)
|
||||
return model
|
||||
|
||||
|
||||
def detect(image, model_name: str | None = None, confidence: float | None = None, target_classes: list[str] | None = None) -> list[dict]:
|
||||
"""Run YOLO detection, return list of bbox dicts."""
|
||||
cfg = get_config()
|
||||
conf = confidence if confidence is not None else cfg["yolo_confidence"]
|
||||
model = _get(model_name)
|
||||
|
||||
results = model(image, conf=conf, verbose=False)
|
||||
|
||||
detections = []
|
||||
for r in results:
|
||||
for box in r.boxes:
|
||||
x1, y1, x2, y2 = box.xyxy[0].tolist()
|
||||
label = r.names[int(box.cls[0])]
|
||||
|
||||
if target_classes and label not in target_classes:
|
||||
continue
|
||||
|
||||
detections.append({
|
||||
"x": int(x1), "y": int(y1),
|
||||
"w": int(x2 - x1), "h": int(y2 - y1),
|
||||
"confidence": float(box.conf[0]),
|
||||
"label": label,
|
||||
})
|
||||
|
||||
return detections
|
||||
Reference in New Issue
Block a user