phase 6
This commit is contained in:
127
detect/stages/yolo_detector.py
Normal file
127
detect/stages/yolo_detector.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
Stage 3 — YOLO Object Detection
|
||||
|
||||
Detects regions of interest (logos, text, banners) in frames.
|
||||
Two modes:
|
||||
- Remote: calls inference server over HTTP (GPU on another machine)
|
||||
- Local: imports ultralytics directly (GPU on same machine)
|
||||
|
||||
Emits frame_update events with bounding boxes for the UI.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from detect import emit
|
||||
from detect.models import BoundingBox, Frame
|
||||
from detect.profiles.base import DetectionConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _frame_to_b64(frame: Frame) -> str:
|
||||
"""Encode frame as base64 JPEG for SSE frame_update events."""
|
||||
img = Image.fromarray(frame.image)
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format="JPEG", quality=70)
|
||||
return base64.b64encode(buf.getvalue()).decode()
|
||||
|
||||
|
||||
def _detect_remote(frame: Frame, config: DetectionConfig, inference_url: str) -> list[BoundingBox]:
|
||||
"""Call the inference server over HTTP."""
|
||||
from detect.inference import InferenceClient
|
||||
client = InferenceClient(base_url=inference_url)
|
||||
results = client.detect(
|
||||
image=frame.image,
|
||||
model=config.model_name,
|
||||
confidence=config.confidence_threshold,
|
||||
target_classes=config.target_classes,
|
||||
)
|
||||
boxes = []
|
||||
for r in results:
|
||||
box = BoundingBox(
|
||||
x=r.x, y=r.y, w=r.w, h=r.h,
|
||||
confidence=r.confidence, label=r.label,
|
||||
)
|
||||
boxes.append(box)
|
||||
return boxes
|
||||
|
||||
|
||||
def _detect_local(frame: Frame, config: DetectionConfig) -> list[BoundingBox]:
|
||||
"""Run YOLO in-process (requires ultralytics installed)."""
|
||||
from ultralytics import YOLO
|
||||
model = YOLO(config.model_name)
|
||||
results = model(frame.image, conf=config.confidence_threshold, verbose=False)
|
||||
|
||||
boxes = []
|
||||
for r in results:
|
||||
for det in r.boxes:
|
||||
x1, y1, x2, y2 = det.xyxy[0].tolist()
|
||||
label = r.names[int(det.cls[0])]
|
||||
|
||||
if config.target_classes and label not in config.target_classes:
|
||||
continue
|
||||
|
||||
box = BoundingBox(
|
||||
x=int(x1), y=int(y1),
|
||||
w=int(x2 - x1), h=int(y2 - y1),
|
||||
confidence=float(det.conf[0]),
|
||||
label=label,
|
||||
)
|
||||
boxes.append(box)
|
||||
return boxes
|
||||
|
||||
|
||||
def detect_objects(
|
||||
frames: list[Frame],
|
||||
config: DetectionConfig,
|
||||
inference_url: str | None = None,
|
||||
job_id: str | None = None,
|
||||
) -> dict[int, list[BoundingBox]]:
|
||||
"""
|
||||
Run object detection on all frames.
|
||||
|
||||
If inference_url is provided, calls the remote GPU server.
|
||||
Otherwise, imports ultralytics and runs locally.
|
||||
|
||||
Returns a dict mapping frame sequence → list of bounding boxes.
|
||||
"""
|
||||
mode = "remote" if inference_url else "local"
|
||||
emit.log(job_id, "YOLODetector", "INFO",
|
||||
f"Detecting objects in {len(frames)} frames "
|
||||
f"(model={config.model_name}, conf={config.confidence_threshold}, mode={mode})")
|
||||
|
||||
all_boxes: dict[int, list[BoundingBox]] = {}
|
||||
total_regions = 0
|
||||
|
||||
for frame in frames:
|
||||
if inference_url:
|
||||
boxes = _detect_remote(frame, config, inference_url)
|
||||
else:
|
||||
boxes = _detect_local(frame, config)
|
||||
|
||||
all_boxes[frame.sequence] = boxes
|
||||
total_regions += len(boxes)
|
||||
|
||||
if boxes and job_id:
|
||||
box_dicts = [{"x": b.x, "y": b.y, "w": b.w, "h": b.h,
|
||||
"confidence": b.confidence, "label": b.label}
|
||||
for b in boxes]
|
||||
emit.frame_update(
|
||||
job_id,
|
||||
frame_ref=frame.sequence,
|
||||
timestamp=frame.timestamp,
|
||||
jpeg_b64=_frame_to_b64(frame),
|
||||
boxes=box_dicts,
|
||||
)
|
||||
|
||||
emit.log(job_id, "YOLODetector", "INFO",
|
||||
f"Detected {total_regions} regions across {len(frames)} frames")
|
||||
emit.stats(job_id, regions_detected=total_regions)
|
||||
|
||||
return all_boxes
|
||||
Reference in New Issue
Block a user