phase 6
@@ -35,6 +35,25 @@ def stats(job_id: str | None, **kwargs) -> None:
    push_detect_event(job_id, "stats_update", dataclasses.asdict(s))


def frame_update(
    job_id: str | None,
    frame_ref: int,
    timestamp: float,
    jpeg_b64: str,
    boxes: list[dict],
) -> None:
    """Emit a frame_update event with the image and bounding boxes."""
    if not job_id:
        return
    payload = {
        "frame_ref": frame_ref,
        "timestamp": timestamp,
        "jpeg_b64": jpeg_b64,
        "boxes": boxes,
    }
    push_detect_event(job_id, "frame_update", payload)


def graph_update(job_id: str | None, nodes: list[dict]) -> None:
    """Emit a graph_update event with node states."""
    if not job_id:
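For orientation, a hypothetical caller of the new graph_update helper could look like the sketch below; the shape of each node dict ("name"/"status" keys) is an assumption, since graph_update itself only requires a list[dict].

# Minimal sketch, assuming a name/status shape for each node entry (not specified in this commit).
from detect import emit

def report_graph(job_id: str | None) -> None:
    nodes = [
        {"name": "extract_frames", "status": "done"},
        {"name": "detect_objects", "status": "running"},
    ]
    emit.graph_update(job_id, nodes)  # silently skipped when job_id is None, like the other emitters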
@@ -7,6 +7,8 @@ Each node emits graph_update events so the UI can visualize transitions.

from __future__ import annotations

import os

from langgraph.graph import END, StateGraph

from detect import emit
@@ -15,6 +17,9 @@ from detect.profiles import SoccerBroadcastProfile
from detect.state import DetectState
from detect.stages.frame_extractor import extract_frames
from detect.stages.scene_filter import scene_filter
from detect.stages.yolo_detector import detect_objects

INFERENCE_URL = os.environ.get("INFERENCE_URL")  # None = local mode

NODES = [
    "extract_frames",
@@ -84,10 +89,19 @@ def node_filter_scenes(state: DetectState) -> dict:

def node_detect_objects(state: DetectState) -> dict:
    _emit_transition(state, "detect_objects", "running")

    profile = _get_profile(state)
    config = profile.detection_config()
    frames = state.get("filtered_frames", [])
    job_id = state.get("job_id")
    emit.log(job_id, "YOLODetector", "INFO", "Stub: object detection not yet implemented")

    all_boxes = detect_objects(frames, config, inference_url=INFERENCE_URL, job_id=job_id)

    stats = state.get("stats", PipelineStats())
    stats.regions_detected = sum(len(boxes) for boxes in all_boxes.values())

    _emit_transition(state, "detect_objects", "done")
    return {}
    return {"stats": stats}


def node_run_ocr(state: DetectState) -> dict:
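The only switch between local and remote execution is the INFERENCE_URL environment variable, which the graph module reads once at import time. A minimal configuration sketch (the host and port are placeholders, not values from this commit):

import os

# Remote mode: point at the GPU inference server; must be set before the graph module is imported,
# because INFERENCE_URL is read at import time (placeholder address).
os.environ["INFERENCE_URL"] = "http://gpu-host:8000"

# Local mode: leave INFERENCE_URL unset and detect_objects() imports ultralytics in-process.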
detect/inference/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from .client import InferenceClient
from .types import DetectResult, OCRResult, VLMResult

__all__ = ["InferenceClient", "DetectResult", "OCRResult", "VLMResult"]
detect/inference/client.py (new file, 156 lines)
@@ -0,0 +1,156 @@
"""
HTTP client for the inference server.

The pipeline stages call this instead of importing ML libraries directly.
The inference server runs on the GPU machine (or spot instance).
"""

from __future__ import annotations

import base64
import io
import logging
import os

import numpy as np
import requests
from PIL import Image

from .types import DetectResult, OCRResult, ServerStatus, VLMResult

logger = logging.getLogger(__name__)

DEFAULT_URL = os.environ.get("INFERENCE_URL", "http://localhost:8000")


def _encode_image(image: np.ndarray) -> str:
    """Encode numpy array as base64 JPEG."""
    img = Image.fromarray(image)
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=85)
    return base64.b64encode(buf.getvalue()).decode()


class InferenceClient:
    """HTTP client for the GPU inference server."""

    def __init__(self, base_url: str | None = None, timeout: float = 60.0):
        self.base_url = (base_url or DEFAULT_URL).rstrip("/")
        self.timeout = timeout
        self.session = requests.Session()

    def health(self) -> ServerStatus:
        """Check server health and loaded models."""
        resp = self.session.get(f"{self.base_url}/health", timeout=self.timeout)
        resp.raise_for_status()
        data = resp.json()
        return ServerStatus(
            loaded_models=data.get("loaded_models", []),
            vram_used_mb=data.get("vram_used_mb", 0),
            vram_budget_mb=data.get("vram_budget_mb", 0),
            strategy=data.get("strategy", "sequential"),
        )

    def detect(
        self,
        image: np.ndarray,
        model: str = "yolov8n",
        confidence: float = 0.3,
        target_classes: list[str] | None = None,
    ) -> list[DetectResult]:
        """Run object detection on an image."""
        payload = {
            "image": _encode_image(image),
            "model": model,
            "confidence": confidence,
        }
        if target_classes:
            payload["target_classes"] = target_classes

        resp = self.session.post(
            f"{self.base_url}/detect",
            json=payload,
            timeout=self.timeout,
        )
        resp.raise_for_status()

        results = []
        for d in resp.json().get("detections", []):
            result = DetectResult(
                x=d["x"], y=d["y"], w=d["w"], h=d["h"],
                confidence=d["confidence"], label=d["label"],
            )
            results.append(result)
        return results

    def ocr(
        self,
        image: np.ndarray,
        languages: list[str] | None = None,
    ) -> list[OCRResult]:
        """Run OCR on an image region."""
        payload = {
            "image": _encode_image(image),
        }
        if languages:
            payload["languages"] = languages

        resp = self.session.post(
            f"{self.base_url}/ocr",
            json=payload,
            timeout=self.timeout,
        )
        resp.raise_for_status()

        results = []
        for d in resp.json().get("results", []):
            result = OCRResult(
                text=d["text"],
                confidence=d["confidence"],
                bbox=tuple(d["bbox"]),
            )
            results.append(result)
        return results

    def vlm(
        self,
        image: np.ndarray,
        prompt: str,
        model: str = "moondream2",
    ) -> VLMResult:
        """Query a visual language model with an image crop + prompt."""
        payload = {
            "image": _encode_image(image),
            "prompt": prompt,
            "model": model,
        }

        resp = self.session.post(
            f"{self.base_url}/vlm",
            json=payload,
            timeout=self.timeout,
        )
        resp.raise_for_status()

        data = resp.json()
        return VLMResult(
            brand=data.get("brand", ""),
            confidence=data.get("confidence", 0.0),
            reasoning=data.get("reasoning", ""),
        )

    def load_model(self, model: str, quantization: str = "fp16") -> None:
        """Request the server to load a model into VRAM."""
        self.session.post(
            f"{self.base_url}/models/load",
            json={"model": model, "quantization": quantization},
            timeout=self.timeout,
        ).raise_for_status()

    def unload_model(self, model: str) -> None:
        """Request the server to unload a model from VRAM."""
        self.session.post(
            f"{self.base_url}/models/unload",
            json={"model": model},
            timeout=self.timeout,
        ).raise_for_status()
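A quick usage sketch of the client (the server address and the zeroed test image are placeholders, not part of this commit):

# Hypothetical caller; assumes an inference server is reachable at the placeholder URL.
import numpy as np
from detect.inference import InferenceClient

client = InferenceClient(base_url="http://gpu-host:8000")
status = client.health()
print(status.strategy, status.vram_used_mb, status.vram_budget_mb)

frame = np.zeros((720, 1280, 3), dtype=np.uint8)  # stand-in for a decoded video frame
for det in client.detect(frame, model="yolov8n", confidence=0.3):
    print(det.label, det.confidence, (det.x, det.y, det.w, det.h))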
detect/inference/types.py (new file, 55 lines)
@@ -0,0 +1,55 @@
"""
Inference response types.

These are the shapes returned by the inference server.
Kept separate from detect.models to avoid coupling the
inference protocol to pipeline internals.
"""

from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class DetectResult:
    """Single object detection from YOLO or similar."""
    x: int
    y: int
    w: int
    h: int
    confidence: float
    label: str


@dataclass
class OCRResult:
    """Text extracted from a region."""
    text: str
    confidence: float
    bbox: tuple[int, int, int, int]  # x, y, w, h


@dataclass
class VLMResult:
    """Visual language model response for a crop."""
    brand: str
    confidence: float
    reasoning: str


@dataclass
class ModelInfo:
    """Info about a loaded model."""
    name: str
    vram_mb: float
    quantization: str  # fp32, fp16, int8, int4


@dataclass
class ServerStatus:
    """Inference server health response."""
    loaded_models: list[ModelInfo] = field(default_factory=list)
    vram_used_mb: float = 0.0
    vram_budget_mb: float = 0.0
    strategy: str = "sequential"  # sequential, concurrent, auto
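Read together with the parsing code in client.py, these dataclasses imply JSON response bodies along the following lines; this is an illustration inferred from the client, not a schema shipped in this commit:

# Illustrative /detect and /health response shapes, inferred from how client.py parses them.
detect_response = {
    "detections": [
        {"x": 120, "y": 40, "w": 64, "h": 32, "confidence": 0.81, "label": "logo"},
    ],
}
health_response = {
    "loaded_models": [{"name": "yolov8n", "vram_mb": 900.0, "quantization": "fp16"}],
    "vram_used_mb": 900.0,
    "vram_budget_mb": 8192.0,
    "strategy": "sequential",
}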
@@ -28,7 +28,7 @@ class SoccerBroadcastProfile:
        return DetectionConfig(
            model_name="yolov8n.pt",
            confidence_threshold=0.3,
            target_classes=["logo", "text", "banner", "scoreboard"],
            target_classes=[],  # empty = accept all COCO classes (until custom model)
        )

    def ocr_config(self) -> OCRConfig:
detect/stages/yolo_detector.py (new file, 127 lines)
@@ -0,0 +1,127 @@
"""
Stage 3 — YOLO Object Detection

Detects regions of interest (logos, text, banners) in frames.
Two modes:
- Remote: calls inference server over HTTP (GPU on another machine)
- Local: imports ultralytics directly (GPU on same machine)

Emits frame_update events with bounding boxes for the UI.
"""

from __future__ import annotations

import base64
import io
import logging

from PIL import Image

from detect import emit
from detect.models import BoundingBox, Frame
from detect.profiles.base import DetectionConfig

logger = logging.getLogger(__name__)


def _frame_to_b64(frame: Frame) -> str:
    """Encode frame as base64 JPEG for SSE frame_update events."""
    img = Image.fromarray(frame.image)
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=70)
    return base64.b64encode(buf.getvalue()).decode()


def _detect_remote(frame: Frame, config: DetectionConfig, inference_url: str) -> list[BoundingBox]:
    """Call the inference server over HTTP."""
    from detect.inference import InferenceClient
    client = InferenceClient(base_url=inference_url)
    results = client.detect(
        image=frame.image,
        model=config.model_name,
        confidence=config.confidence_threshold,
        target_classes=config.target_classes,
    )
    boxes = []
    for r in results:
        box = BoundingBox(
            x=r.x, y=r.y, w=r.w, h=r.h,
            confidence=r.confidence, label=r.label,
        )
        boxes.append(box)
    return boxes


def _detect_local(frame: Frame, config: DetectionConfig) -> list[BoundingBox]:
    """Run YOLO in-process (requires ultralytics installed)."""
    from ultralytics import YOLO
    model = YOLO(config.model_name)
    results = model(frame.image, conf=config.confidence_threshold, verbose=False)

    boxes = []
    for r in results:
        for det in r.boxes:
            x1, y1, x2, y2 = det.xyxy[0].tolist()
            label = r.names[int(det.cls[0])]

            if config.target_classes and label not in config.target_classes:
                continue

            box = BoundingBox(
                x=int(x1), y=int(y1),
                w=int(x2 - x1), h=int(y2 - y1),
                confidence=float(det.conf[0]),
                label=label,
            )
            boxes.append(box)
    return boxes


def detect_objects(
    frames: list[Frame],
    config: DetectionConfig,
    inference_url: str | None = None,
    job_id: str | None = None,
) -> dict[int, list[BoundingBox]]:
    """
    Run object detection on all frames.

    If inference_url is provided, calls the remote GPU server.
    Otherwise, imports ultralytics and runs locally.

    Returns a dict mapping frame sequence → list of bounding boxes.
    """
    mode = "remote" if inference_url else "local"
    emit.log(job_id, "YOLODetector", "INFO",
             f"Detecting objects in {len(frames)} frames "
             f"(model={config.model_name}, conf={config.confidence_threshold}, mode={mode})")

    all_boxes: dict[int, list[BoundingBox]] = {}
    total_regions = 0

    for frame in frames:
        if inference_url:
            boxes = _detect_remote(frame, config, inference_url)
        else:
            boxes = _detect_local(frame, config)

        all_boxes[frame.sequence] = boxes
        total_regions += len(boxes)

        if boxes and job_id:
            box_dicts = [{"x": b.x, "y": b.y, "w": b.w, "h": b.h,
                          "confidence": b.confidence, "label": b.label}
                         for b in boxes]
            emit.frame_update(
                job_id,
                frame_ref=frame.sequence,
                timestamp=frame.timestamp,
                jpeg_b64=_frame_to_b64(frame),
                boxes=box_dicts,
            )

    emit.log(job_id, "YOLODetector", "INFO",
             f"Detected {total_regions} regions across {len(frames)} frames")
    emit.stats(job_id, regions_detected=total_regions)

    return all_boxes
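A caller's-eye sketch of the new stage; the Frame construction is abbreviated (the real dataclass lives in detect.models and may take more fields than shown here):

# Hypothetical driver for detect_objects(); Frame fields are assumed from how the stage uses them.
import numpy as np
from detect.models import Frame
from detect.profiles import SoccerBroadcastProfile
from detect.stages.yolo_detector import detect_objects

frames = [Frame(image=np.zeros((720, 1280, 3), dtype=np.uint8), sequence=0, timestamp=0.0)]
config = SoccerBroadcastProfile().detection_config()

# inference_url=None runs YOLO in-process via ultralytics; pass the GPU server URL for remote mode.
all_boxes = detect_objects(frames, config, inference_url=None, job_id=None)
print({seq: len(boxes) for seq, boxes in all_boxes.items()})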