2026-03-23 16:55:13 -03:00
parent 4fdbdfc6d3
commit 3df9ed5ada
17 changed files with 848 additions and 4 deletions

View File

@@ -35,6 +35,25 @@ def stats(job_id: str | None, **kwargs) -> None:
     push_detect_event(job_id, "stats_update", dataclasses.asdict(s))
 
 
+def frame_update(
+    job_id: str | None,
+    frame_ref: int,
+    timestamp: float,
+    jpeg_b64: str,
+    boxes: list[dict],
+) -> None:
+    """Emit a frame_update event with the image and bounding boxes."""
+    if not job_id:
+        return
+    payload = {
+        "frame_ref": frame_ref,
+        "timestamp": timestamp,
+        "jpeg_b64": jpeg_b64,
+        "boxes": boxes,
+    }
+    push_detect_event(job_id, "frame_update", payload)
+
+
 def graph_update(job_id: str | None, nodes: list[dict]) -> None:
     """Emit a graph_update event with node states."""
     if not job_id:
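
For orientation, this is the caller side of the new event: a stage pushes one frame_update per annotated frame and the UI renders the JPEG with its boxes. A minimal sketch (the job id, base64 string, and box values are invented; yolo_detector.py below emits the real thing):

    # Sketch: emitting one annotated frame to the UI (values are illustrative).
    from detect import emit

    emit.frame_update(
        "job-123",
        frame_ref=42,
        timestamp=12.48,
        jpeg_b64="/9j/4AAQSkZJRg",  # truncated base64 JPEG, stand-in only
        boxes=[{"x": 110, "y": 64, "w": 80, "h": 32,
                "confidence": 0.81, "label": "scoreboard"}],
    )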

View File

@@ -7,6 +7,8 @@ Each node emits graph_update events so the UI can visualize transitions.
 from __future__ import annotations
+import os
 from langgraph.graph import END, StateGraph
+from detect import emit
@@ -15,6 +17,9 @@ from detect.profiles import SoccerBroadcastProfile
 from detect.state import DetectState
 from detect.stages.frame_extractor import extract_frames
 from detect.stages.scene_filter import scene_filter
+from detect.stages.yolo_detector import detect_objects
+
+INFERENCE_URL = os.environ.get("INFERENCE_URL")  # None = local mode
 
 NODES = [
     "extract_frames",
@@ -84,10 +89,19 @@ def node_filter_scenes(state: DetectState) -> dict:
 def node_detect_objects(state: DetectState) -> dict:
     _emit_transition(state, "detect_objects", "running")
     profile = _get_profile(state)
+    config = profile.detection_config()
     frames = state.get("filtered_frames", [])
     job_id = state.get("job_id")
-    emit.log(job_id, "YOLODetector", "INFO", "Stub: object detection not yet implemented")
+    all_boxes = detect_objects(frames, config, inference_url=INFERENCE_URL, job_id=job_id)
+    stats = state.get("stats", PipelineStats())
+    stats.regions_detected = sum(len(boxes) for boxes in all_boxes.values())
     _emit_transition(state, "detect_objects", "done")
-    return {}
+    return {"stats": stats}
 
 
 def node_run_ocr(state: DetectState) -> dict:
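
In LangGraph, a node's return value is a partial state update that gets merged into the shared graph state, so returning {"stats": stats} instead of the old bare {} is what makes the region counts visible to later nodes and to the final result. A minimal sketch of that merge behavior (the state and node names here are illustrative, not from this repo):

    from typing import TypedDict

    from langgraph.graph import END, StateGraph

    class State(TypedDict, total=False):
        count: int

    def bump(state: State) -> dict:
        # Return only the keys you changed; LangGraph merges them into state.
        return {"count": state.get("count", 0) + 1}

    g = StateGraph(State)
    g.add_node("bump", bump)
    g.set_entry_point("bump")
    g.add_edge("bump", END)
    print(g.compile().invoke({}))  # -> {'count': 1}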

View File

@@ -0,0 +1,4 @@
from .client import InferenceClient
from .types import DetectResult, OCRResult, VLMResult
__all__ = ["InferenceClient", "DetectResult", "OCRResult", "VLMResult"]
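
With these re-exports, callers import from the package root rather than the submodules, e.g.:

    from detect.inference import InferenceClient, DetectResult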

detect/inference/client.py Normal file
View File

@@ -0,0 +1,156 @@
"""
HTTP client for the inference server.
The pipeline stages call this instead of importing ML libraries directly.
The inference server runs on the GPU machine (or spot instance).
"""
from __future__ import annotations
import base64
import io
import logging
import os
import numpy as np
import requests
from PIL import Image
from .types import DetectResult, OCRResult, ServerStatus, VLMResult
logger = logging.getLogger(__name__)
DEFAULT_URL = os.environ.get("INFERENCE_URL", "http://localhost:8000")
def _encode_image(image: np.ndarray) -> str:
"""Encode numpy array as base64 JPEG."""
img = Image.fromarray(image)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=85)
return base64.b64encode(buf.getvalue()).decode()
class InferenceClient:
    """HTTP client for the GPU inference server."""

    def __init__(self, base_url: str | None = None, timeout: float = 60.0):
        self.base_url = (base_url or DEFAULT_URL).rstrip("/")
        self.timeout = timeout
        self.session = requests.Session()

    def health(self) -> ServerStatus:
        """Check server health and loaded models."""
        resp = self.session.get(f"{self.base_url}/health", timeout=self.timeout)
        resp.raise_for_status()
        data = resp.json()
        return ServerStatus(
            # The server reports models as plain dicts; convert them so the
            # list[ModelInfo] annotation holds actual ModelInfo instances.
            loaded_models=[ModelInfo(**m) for m in data.get("loaded_models", [])],
            vram_used_mb=data.get("vram_used_mb", 0),
            vram_budget_mb=data.get("vram_budget_mb", 0),
            strategy=data.get("strategy", "sequential"),
        )
    def detect(
        self,
        image: np.ndarray,
        model: str = "yolov8n",
        confidence: float = 0.3,
        target_classes: list[str] | None = None,
    ) -> list[DetectResult]:
        """Run object detection on an image."""
        payload = {
            "image": _encode_image(image),
            "model": model,
            "confidence": confidence,
        }
        if target_classes:
            payload["target_classes"] = target_classes
        resp = self.session.post(
            f"{self.base_url}/detect",
            json=payload,
            timeout=self.timeout,
        )
        resp.raise_for_status()
        results = []
        for d in resp.json().get("detections", []):
            result = DetectResult(
                x=d["x"], y=d["y"], w=d["w"], h=d["h"],
                confidence=d["confidence"], label=d["label"],
            )
            results.append(result)
        return results

    def ocr(
        self,
        image: np.ndarray,
        languages: list[str] | None = None,
    ) -> list[OCRResult]:
        """Run OCR on an image region."""
        payload = {
            "image": _encode_image(image),
        }
        if languages:
            payload["languages"] = languages
        resp = self.session.post(
            f"{self.base_url}/ocr",
            json=payload,
            timeout=self.timeout,
        )
        resp.raise_for_status()
        results = []
        for d in resp.json().get("results", []):
            result = OCRResult(
                text=d["text"],
                confidence=d["confidence"],
                bbox=tuple(d["bbox"]),
            )
            results.append(result)
        return results

    def vlm(
        self,
        image: np.ndarray,
        prompt: str,
        model: str = "moondream2",
    ) -> VLMResult:
        """Query a visual language model with an image crop + prompt."""
        payload = {
            "image": _encode_image(image),
            "prompt": prompt,
            "model": model,
        }
        resp = self.session.post(
            f"{self.base_url}/vlm",
            json=payload,
            timeout=self.timeout,
        )
        resp.raise_for_status()
        data = resp.json()
        return VLMResult(
            brand=data.get("brand", ""),
            confidence=data.get("confidence", 0.0),
            reasoning=data.get("reasoning", ""),
        )

    def load_model(self, model: str, quantization: str = "fp16") -> None:
        """Request the server to load a model into VRAM."""
        self.session.post(
            f"{self.base_url}/models/load",
            json={"model": model, "quantization": quantization},
            timeout=self.timeout,
        ).raise_for_status()

    def unload_model(self, model: str) -> None:
        """Request the server to unload a model from VRAM."""
        self.session.post(
            f"{self.base_url}/models/unload",
            json={"model": model},
            timeout=self.timeout,
        ).raise_for_status()
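
A quick usage sketch (the server host is a placeholder; the methods and defaults are the ones defined above):

    import numpy as np
    from detect.inference import InferenceClient

    client = InferenceClient(base_url="http://gpu-box:8000")  # placeholder host

    status = client.health()
    print(f"{status.strategy}: {status.vram_used_mb}/{status.vram_budget_mb} MB")

    frame = np.zeros((720, 1280, 3), dtype=np.uint8)  # stand-in for a real frame
    for det in client.detect(frame, model="yolov8n", confidence=0.3):
        print(det.label, det.confidence, (det.x, det.y, det.w, det.h))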

detect/inference/types.py Normal file
View File

@@ -0,0 +1,55 @@
"""
Inference response types.
These are the shapes returned by the inference server.
Kept separate from detect.models to avoid coupling the
inference protocol to pipeline internals.
"""
from __future__ import annotations
from dataclasses import dataclass, field
@dataclass
class DetectResult:
"""Single object detection from YOLO or similar."""
x: int
y: int
w: int
h: int
confidence: float
label: str
@dataclass
class OCRResult:
"""Text extracted from a region."""
text: str
confidence: float
bbox: tuple[int, int, int, int] # x, y, w, h
@dataclass
class VLMResult:
"""Visual language model response for a crop."""
brand: str
confidence: float
reasoning: str
@dataclass
class ModelInfo:
"""Info about a loaded model."""
name: str
vram_mb: float
quantization: str # fp32, fp16, int8, int4
@dataclass
class ServerStatus:
"""Inference server health response."""
loaded_models: list[ModelInfo] = field(default_factory=list)
vram_used_mb: float = 0.0
vram_budget_mb: float = 0.0
strategy: str = "sequential" # sequential, concurrent, auto
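
As a concrete example, a /health body like the one below maps onto ServerStatus (the field names come from the dataclasses above; the values are invented):

    from detect.inference.types import ModelInfo, ServerStatus

    body = {  # hypothetical /health response
        "loaded_models": [{"name": "yolov8n", "vram_mb": 640.0, "quantization": "fp16"}],
        "vram_used_mb": 640.0,
        "vram_budget_mb": 8192.0,
        "strategy": "sequential",
    }
    status = ServerStatus(
        loaded_models=[ModelInfo(**m) for m in body["loaded_models"]],
        vram_used_mb=body["vram_used_mb"],
        vram_budget_mb=body["vram_budget_mb"],
        strategy=body["strategy"],
    )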

View File

@@ -28,7 +28,7 @@ class SoccerBroadcastProfile:
         return DetectionConfig(
             model_name="yolov8n.pt",
             confidence_threshold=0.3,
-            target_classes=["logo", "text", "banner", "scoreboard"],
+            target_classes=[],  # empty = accept all COCO classes (until custom model)
         )
 
     def ocr_config(self) -> OCRConfig:
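
The empty list works because the detection stage treats a falsy target_classes as "no filter" (see the check in yolo_detector.py below), so stock COCO labels such as "person" and "sports ball" pass through until a custom logo/text model exists. In miniature:

    # Filtering semantics as implemented in yolo_detector._detect_local:
    target_classes: list[str] = []  # empty -> falsy -> no filtering
    def keep(label: str) -> bool:
        return not target_classes or label in target_classes
    assert keep("person") and keep("sports ball")  # everything passes for now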

View File

@@ -0,0 +1,127 @@
"""
Stage 3 — YOLO Object Detection
Detects regions of interest (logos, text, banners) in frames.
Two modes:
- Remote: calls inference server over HTTP (GPU on another machine)
- Local: imports ultralytics directly (GPU on same machine)
Emits frame_update events with bounding boxes for the UI.
"""
from __future__ import annotations
import base64
import io
import logging
from PIL import Image
from detect import emit
from detect.models import BoundingBox, Frame
from detect.profiles.base import DetectionConfig
logger = logging.getLogger(__name__)
def _frame_to_b64(frame: Frame) -> str:
"""Encode frame as base64 JPEG for SSE frame_update events."""
img = Image.fromarray(frame.image)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=70)
return base64.b64encode(buf.getvalue()).decode()
# Caches so per-frame calls don't rebuild expensive objects: one HTTP client
# (and its session) per server URL, one YOLO model per model name.
_client_cache: dict[str, object] = {}
_model_cache: dict[str, object] = {}


def _detect_remote(frame: Frame, config: DetectionConfig, inference_url: str) -> list[BoundingBox]:
    """Call the inference server over HTTP."""
    from detect.inference import InferenceClient

    client = _client_cache.get(inference_url)
    if client is None:
        client = _client_cache[inference_url] = InferenceClient(base_url=inference_url)
    results = client.detect(
        image=frame.image,
        model=config.model_name,
        confidence=config.confidence_threshold,
        target_classes=config.target_classes,
    )
    boxes = []
    for r in results:
        box = BoundingBox(
            x=r.x, y=r.y, w=r.w, h=r.h,
            confidence=r.confidence, label=r.label,
        )
        boxes.append(box)
    return boxes


def _detect_local(frame: Frame, config: DetectionConfig) -> list[BoundingBox]:
    """Run YOLO in-process (requires ultralytics installed)."""
    from ultralytics import YOLO

    # Load each model once; instantiating YOLO per frame would reload weights.
    model = _model_cache.get(config.model_name)
    if model is None:
        model = _model_cache[config.model_name] = YOLO(config.model_name)
    results = model(frame.image, conf=config.confidence_threshold, verbose=False)
    boxes = []
    for r in results:
        for det in r.boxes:
            x1, y1, x2, y2 = det.xyxy[0].tolist()
            label = r.names[int(det.cls[0])]
            if config.target_classes and label not in config.target_classes:
                continue
            box = BoundingBox(
                x=int(x1), y=int(y1),
                w=int(x2 - x1), h=int(y2 - y1),
                confidence=float(det.conf[0]),
                label=label,
            )
            boxes.append(box)
    return boxes
def detect_objects(
    frames: list[Frame],
    config: DetectionConfig,
    inference_url: str | None = None,
    job_id: str | None = None,
) -> dict[int, list[BoundingBox]]:
    """
    Run object detection on all frames.

    If inference_url is provided, calls the remote GPU server.
    Otherwise, imports ultralytics and runs locally.

    Returns a dict mapping frame sequence → list of bounding boxes.
    """
    mode = "remote" if inference_url else "local"
    emit.log(job_id, "YOLODetector", "INFO",
             f"Detecting objects in {len(frames)} frames "
             f"(model={config.model_name}, conf={config.confidence_threshold}, mode={mode})")
    all_boxes: dict[int, list[BoundingBox]] = {}
    total_regions = 0
    for frame in frames:
        if inference_url:
            boxes = _detect_remote(frame, config, inference_url)
        else:
            boxes = _detect_local(frame, config)
        all_boxes[frame.sequence] = boxes
        total_regions += len(boxes)
        if boxes and job_id:
            box_dicts = [{"x": b.x, "y": b.y, "w": b.w, "h": b.h,
                          "confidence": b.confidence, "label": b.label}
                         for b in boxes]
            emit.frame_update(
                job_id,
                frame_ref=frame.sequence,
                timestamp=frame.timestamp,
                jpeg_b64=_frame_to_b64(frame),
                boxes=box_dicts,
            )
    emit.log(job_id, "YOLODetector", "INFO",
             f"Detected {total_regions} regions across {len(frames)} frames")
    emit.stats(job_id, regions_detected=total_regions)
    return all_boxes
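
An end-to-end sketch of the stage in local mode (the Frame constructor call assumes the sequence/timestamp/image fields used above; remote mode only needs the extra inference_url):

    import numpy as np
    from detect.models import Frame
    from detect.profiles import SoccerBroadcastProfile
    from detect.stages.yolo_detector import detect_objects

    profile = SoccerBroadcastProfile()
    frames = [Frame(sequence=0, timestamp=0.0,
                    image=np.zeros((720, 1280, 3), dtype=np.uint8))]
    boxes = detect_objects(frames, profile.detection_config())  # local: needs ultralytics
    # boxes = detect_objects(frames, profile.detection_config(),
    #                        inference_url="http://gpu-box:8000")  # remote mode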