phase 6
This commit is contained in:
23
ctrl/sync.sh
Executable file
23
ctrl/sync.sh
Executable file
@@ -0,0 +1,23 @@
|
||||
#!/bin/bash
# Sync gpu/ folder to the GPU machine
# Usage: ./ctrl/sync.sh [HOST] [DEST]
#
# Examples:
#   ./ctrl/sync.sh                  # defaults: mcrndeb:~/wdir/mpr/gpu
#   ./ctrl/sync.sh 192.168.1.3      # custom host
#   ./ctrl/sync.sh mcrn ~/inference # custom host + dest

set -e
# Run from the repo root regardless of where the script was invoked.
cd "$(dirname "$0")/.."

HOST="${1:-mcrndeb}"
# Tilde is quoted locally on purpose: the remote shell expands it on the GPU host.
DEST="${2:-~/wdir/mpr/gpu}"

echo "Syncing gpu/ to ${HOST}:${DEST}..."
# .env is excluded so each machine keeps its own local configuration;
# the gitignore filter keeps build artifacts and local junk off the wire.
rsync -avz --exclude='.git' --exclude='__pycache__' \
    --exclude='*.pyc' --exclude='.env' \
    --filter=':- .gitignore' \
    gpu/ "${HOST}:${DEST}/"

echo "Done. Run on ${HOST}:"
echo "  cd ${DEST} && cp .env.template .env && ./run.sh"
|
||||
@@ -35,6 +35,25 @@ def stats(job_id: str | None, **kwargs) -> None:
|
||||
push_detect_event(job_id, "stats_update", dataclasses.asdict(s))
|
||||
|
||||
|
||||
def frame_update(
    job_id: str | None,
    frame_ref: int,
    timestamp: float,
    jpeg_b64: str,
    boxes: list[dict],
) -> None:
    """Emit a frame_update event with the image and bounding boxes.

    No-op when job_id is falsy (no job to attribute the event to).
    """
    if not job_id:
        return
    push_detect_event(job_id, "frame_update", {
        "frame_ref": frame_ref,
        "timestamp": timestamp,
        "jpeg_b64": jpeg_b64,
        "boxes": boxes,
    })
|
||||
|
||||
|
||||
def graph_update(job_id: str | None, nodes: list[dict]) -> None:
|
||||
"""Emit a graph_update event with node states."""
|
||||
if not job_id:
|
||||
|
||||
@@ -7,6 +7,8 @@ Each node emits graph_update events so the UI can visualize transitions.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from langgraph.graph import END, StateGraph
|
||||
|
||||
from detect import emit
|
||||
@@ -15,6 +17,9 @@ from detect.profiles import SoccerBroadcastProfile
|
||||
from detect.state import DetectState
|
||||
from detect.stages.frame_extractor import extract_frames
|
||||
from detect.stages.scene_filter import scene_filter
|
||||
from detect.stages.yolo_detector import detect_objects
|
||||
|
||||
INFERENCE_URL = os.environ.get("INFERENCE_URL") # None = local mode
|
||||
|
||||
NODES = [
|
||||
"extract_frames",
|
||||
@@ -84,10 +89,19 @@ def node_filter_scenes(state: DetectState) -> dict:
|
||||
|
||||
def node_detect_objects(state: DetectState) -> dict:
    """Graph node: run object detection on the filtered frames.

    Uses the remote inference server when INFERENCE_URL is set, otherwise
    runs YOLO locally. Returns the updated pipeline stats so LangGraph
    merges them into the state.
    """
    _emit_transition(state, "detect_objects", "running")

    profile = _get_profile(state)
    config = profile.detection_config()
    frames = state.get("filtered_frames", [])
    job_id = state.get("job_id")

    # Detection is implemented now — the old "Stub" log line was stale and
    # misleading; detect_objects emits its own progress logs.
    all_boxes = detect_objects(frames, config, inference_url=INFERENCE_URL, job_id=job_id)

    stats = state.get("stats", PipelineStats())
    stats.regions_detected = sum(len(boxes) for boxes in all_boxes.values())

    _emit_transition(state, "detect_objects", "done")
    # Fix: a bare `return {}` made the stats return unreachable, dropping
    # the region count from the pipeline state.
    return {"stats": stats}
|
||||
|
||||
|
||||
def node_run_ocr(state: DetectState) -> dict:
|
||||
|
||||
4
detect/inference/__init__.py
Normal file
4
detect/inference/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .client import InferenceClient
|
||||
from .types import DetectResult, OCRResult, VLMResult
|
||||
|
||||
__all__ = ["InferenceClient", "DetectResult", "OCRResult", "VLMResult"]
|
||||
156
detect/inference/client.py
Normal file
156
detect/inference/client.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
HTTP client for the inference server.
|
||||
|
||||
The pipeline stages call this instead of importing ML libraries directly.
|
||||
The inference server runs on the GPU machine (or spot instance).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
from .types import DetectResult, OCRResult, ServerStatus, VLMResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_URL = os.environ.get("INFERENCE_URL", "http://localhost:8000")
|
||||
|
||||
|
||||
def _encode_image(image: np.ndarray) -> str:
    """Encode numpy array as base64 JPEG."""
    buffer = io.BytesIO()
    # quality=85: good fidelity while keeping request payloads small
    Image.fromarray(image).save(buffer, format="JPEG", quality=85)
    return base64.b64encode(buffer.getvalue()).decode()
|
||||
|
||||
|
||||
class InferenceClient:
    """HTTP client for the GPU inference server.

    Thin wrapper over a persistent requests.Session; every endpoint
    raises on non-2xx responses via raise_for_status().
    """

    def __init__(self, base_url: str | None = None, timeout: float = 60.0):
        self.base_url = (base_url or DEFAULT_URL).rstrip("/")
        self.timeout = timeout
        self.session = requests.Session()

    def _post(self, path: str, payload: dict):
        # Shared POST plumbing: same session, timeout, and error policy
        # for every endpoint.
        response = self.session.post(
            f"{self.base_url}{path}",
            json=payload,
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response

    def health(self) -> ServerStatus:
        """Check server health and loaded models."""
        response = self.session.get(f"{self.base_url}/health", timeout=self.timeout)
        response.raise_for_status()
        body = response.json()
        return ServerStatus(
            loaded_models=body.get("loaded_models", []),
            vram_used_mb=body.get("vram_used_mb", 0),
            vram_budget_mb=body.get("vram_budget_mb", 0),
            strategy=body.get("strategy", "sequential"),
        )

    def detect(
        self,
        image: np.ndarray,
        model: str = "yolov8n",
        confidence: float = 0.3,
        target_classes: list[str] | None = None,
    ) -> list[DetectResult]:
        """Run object detection on an image."""
        payload: dict = {
            "image": _encode_image(image),
            "model": model,
            "confidence": confidence,
        }
        # Omit the key entirely when empty so the server keeps all labels.
        if target_classes:
            payload["target_classes"] = target_classes

        body = self._post("/detect", payload).json()
        return [
            DetectResult(
                x=d["x"], y=d["y"], w=d["w"], h=d["h"],
                confidence=d["confidence"], label=d["label"],
            )
            for d in body.get("detections", [])
        ]

    def ocr(
        self,
        image: np.ndarray,
        languages: list[str] | None = None,
    ) -> list[OCRResult]:
        """Run OCR on an image region."""
        payload: dict = {"image": _encode_image(image)}
        if languages:
            payload["languages"] = languages

        body = self._post("/ocr", payload).json()
        return [
            OCRResult(
                text=d["text"],
                confidence=d["confidence"],
                bbox=tuple(d["bbox"]),
            )
            for d in body.get("results", [])
        ]

    def vlm(
        self,
        image: np.ndarray,
        prompt: str,
        model: str = "moondream2",
    ) -> VLMResult:
        """Query a visual language model with an image crop + prompt."""
        data = self._post("/vlm", {
            "image": _encode_image(image),
            "prompt": prompt,
            "model": model,
        }).json()
        return VLMResult(
            brand=data.get("brand", ""),
            confidence=data.get("confidence", 0.0),
            reasoning=data.get("reasoning", ""),
        )

    def load_model(self, model: str, quantization: str = "fp16") -> None:
        """Request the server to load a model into VRAM."""
        self._post("/models/load", {"model": model, "quantization": quantization})

    def unload_model(self, model: str) -> None:
        """Request the server to unload a model from VRAM."""
        self._post("/models/unload", {"model": model})
|
||||
55
detect/inference/types.py
Normal file
55
detect/inference/types.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
Inference response types.
|
||||
|
||||
These are the shapes returned by the inference server.
|
||||
Kept separate from detect.models to avoid coupling the
|
||||
inference protocol to pipeline internals.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
class DetectResult:
    """Single object detection from YOLO or similar.

    Matches the JSON shape the inference server returns from /detect.
    """
    # Pixel-space box; (x, y) is the top-left corner.
    x: int
    y: int
    w: int
    h: int
    confidence: float  # model confidence in [0, 1]
    label: str  # class name from the model's label map
|
||||
|
||||
|
||||
@dataclass
class OCRResult:
    """Text extracted from a region."""
    text: str  # recognized text content
    confidence: float  # recognition confidence in [0, 1]
    bbox: tuple[int, int, int, int]  # x, y, w, h
|
||||
|
||||
|
||||
@dataclass
class VLMResult:
    """Visual language model response for a crop."""
    brand: str  # brand name the VLM identified (empty string when the server omits it)
    confidence: float  # model confidence in [0, 1]
    reasoning: str  # free-text justification from the model
|
||||
|
||||
|
||||
@dataclass
class ModelInfo:
    """Info about a loaded model."""
    name: str  # model identifier, e.g. "yolov8n.pt"
    vram_mb: float  # VRAM footprint in megabytes — presumably as reported by the server
    quantization: str  # fp32, fp16, int8, int4
|
||||
|
||||
|
||||
@dataclass
class ServerStatus:
    """Inference server health response."""
    loaded_models: list[ModelInfo] = field(default_factory=list)  # models currently resident
    vram_used_mb: float = 0.0  # current VRAM usage
    vram_budget_mb: float = 0.0  # configured VRAM ceiling
    strategy: str = "sequential"  # sequential, concurrent, auto
|
||||
@@ -28,7 +28,7 @@ class SoccerBroadcastProfile:
|
||||
return DetectionConfig(
|
||||
model_name="yolov8n.pt",
|
||||
confidence_threshold=0.3,
|
||||
target_classes=["logo", "text", "banner", "scoreboard"],
|
||||
target_classes=[], # empty = accept all COCO classes (until custom model)
|
||||
)
|
||||
|
||||
def ocr_config(self) -> OCRConfig:
|
||||
|
||||
127
detect/stages/yolo_detector.py
Normal file
127
detect/stages/yolo_detector.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
Stage 3 — YOLO Object Detection
|
||||
|
||||
Detects regions of interest (logos, text, banners) in frames.
|
||||
Two modes:
|
||||
- Remote: calls inference server over HTTP (GPU on another machine)
|
||||
- Local: imports ultralytics directly (GPU on same machine)
|
||||
|
||||
Emits frame_update events with bounding boxes for the UI.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from detect import emit
|
||||
from detect.models import BoundingBox, Frame
|
||||
from detect.profiles.base import DetectionConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _frame_to_b64(frame: Frame) -> str:
    """Encode frame as base64 JPEG for SSE frame_update events."""
    buffer = io.BytesIO()
    # quality=70: preview-only quality keeps the SSE payload small
    Image.fromarray(frame.image).save(buffer, format="JPEG", quality=70)
    return base64.b64encode(buffer.getvalue()).decode()
|
||||
|
||||
|
||||
def _detect_remote(frame: Frame, config: DetectionConfig, inference_url: str) -> list[BoundingBox]:
    """Call the inference server over HTTP."""
    # Deferred import: only pull in the HTTP client in remote mode.
    from detect.inference import InferenceClient

    client = InferenceClient(base_url=inference_url)
    detections = client.detect(
        image=frame.image,
        model=config.model_name,
        confidence=config.confidence_threshold,
        target_classes=config.target_classes,
    )
    # Translate server-side results into pipeline bounding boxes.
    return [
        BoundingBox(
            x=r.x, y=r.y, w=r.w, h=r.h,
            confidence=r.confidence, label=r.label,
        )
        for r in detections
    ]
|
||||
|
||||
|
||||
def _detect_local(frame: Frame, config: DetectionConfig) -> list[BoundingBox]:
    """Run YOLO in-process (requires ultralytics installed)."""
    # Deferred import: ultralytics is only needed in local mode.
    from ultralytics import YOLO

    model = YOLO(config.model_name)
    wanted = config.target_classes

    boxes: list[BoundingBox] = []
    for result in model(frame.image, conf=config.confidence_threshold, verbose=False):
        for det in result.boxes:
            label = result.names[int(det.cls[0])]
            # Skip classes outside the profile's target set (empty = keep all).
            if wanted and label not in wanted:
                continue

            x1, y1, x2, y2 = det.xyxy[0].tolist()
            boxes.append(BoundingBox(
                x=int(x1), y=int(y1),
                w=int(x2 - x1), h=int(y2 - y1),
                confidence=float(det.conf[0]),
                label=label,
            ))
    return boxes
|
||||
|
||||
|
||||
def detect_objects(
    frames: list[Frame],
    config: DetectionConfig,
    inference_url: str | None = None,
    job_id: str | None = None,
) -> dict[int, list[BoundingBox]]:
    """
    Run object detection on all frames.

    If inference_url is provided, calls the remote GPU server.
    Otherwise, imports ultralytics and runs locally.

    Returns a dict mapping frame sequence → list of bounding boxes.
    """
    mode = "remote" if inference_url else "local"
    emit.log(job_id, "YOLODetector", "INFO",
             f"Detecting objects in {len(frames)} frames "
             f"(model={config.model_name}, conf={config.confidence_threshold}, mode={mode})")

    all_boxes: dict[int, list[BoundingBox]] = {}
    total_regions = 0

    for frame in frames:
        if inference_url:
            boxes = _detect_remote(frame, config, inference_url)
        else:
            boxes = _detect_local(frame, config)

        all_boxes[frame.sequence] = boxes
        total_regions += len(boxes)

        # Stream an annotated frame to the UI, but only when there is a
        # job to attribute it to and something was actually detected.
        if boxes and job_id:
            emit.frame_update(
                job_id,
                frame_ref=frame.sequence,
                timestamp=frame.timestamp,
                jpeg_b64=_frame_to_b64(frame),
                boxes=[{"x": b.x, "y": b.y, "w": b.w, "h": b.h,
                        "confidence": b.confidence, "label": b.label}
                       for b in boxes],
            )

    emit.log(job_id, "YOLODetector", "INFO",
             f"Detected {total_regions} regions across {len(frames)} frames")
    emit.stats(job_id, regions_detected=total_regions)

    return all_boxes
|
||||
14
gpu/.env.template
Normal file
14
gpu/.env.template
Normal file
@@ -0,0 +1,14 @@
|
||||
# Inference server configuration
HOST=0.0.0.0
PORT=8000

# VRAM management
VRAM_BUDGET_MB=10240
# Strategy: sequential | concurrent | auto
# (comment on its own line — `docker run --env-file` does NOT strip
# inline '#' comments, so they would leak into the variable's value)
STRATEGY=sequential

# Model defaults
YOLO_MODEL=yolov8n.pt
YOLO_CONFIDENCE=0.3

# Device: auto | cpu | cuda | cuda:0
DEVICE=auto
|
||||
18
gpu/Dockerfile
Normal file
18
gpu/Dockerfile
Normal file
@@ -0,0 +1,18 @@
|
||||
FROM python:3.11-slim

# uv: fast pip-compatible installer used for the dependency layer below
RUN pip install --no-cache-dir uv

# Native libs — presumably required by OpenCV pulled in via ultralytics;
# confirm if the dependency set changes.
RUN apt-get update && apt-get install -y \
    libgl1 libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Copy requirements first so code edits don't invalidate the deps layer.
COPY requirements.txt .
RUN uv pip install --system --no-cache -r requirements.txt

COPY . .

EXPOSE 8000

CMD ["python", "server.py"]
|
||||
4
gpu/requirements.txt
Normal file
4
gpu/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
fastapi>=0.109.0
uvicorn[standard]>=0.27.0
ultralytics>=8.0.0
Pillow>=10.0.0
# server.py imports numpy directly — declare it instead of relying on
# ultralytics pulling it in transitively
numpy>=1.23.0
|
||||
54
gpu/run.sh
Executable file
54
gpu/run.sh
Executable file
@@ -0,0 +1,54 @@
|
||||
#!/bin/bash
# Run the inference server
#
# Usage:
#   ./run.sh              # Local (pip install -r requirements.txt first)
#   ./run.sh docker       # Docker (CPU)
#   ./run.sh docker-gpu   # Docker with GPU
#   ./run.sh stop         # Stop Docker container

set -e
cd "$(dirname "${BASH_SOURCE[0]}")"

# Load env (create from template if missing)
if [ ! -f .env ]; then
    if [ -f .env.template ]; then
        cp .env.template .env
        echo "Created .env from template — edit as needed"
    fi
fi

if [ -f .env ]; then
    set -a
    source .env
    set +a
fi

# Build the image and run it in the foreground.
# $1: extra `docker run` flags (e.g. "--gpus all"); may be empty.
# Extracted so the CPU and GPU branches share one code path.
run_docker() {
    docker build -t mpr-inference .
    ENV_FLAG=""; [ -f .env ] && ENV_FLAG="--env-file .env"
    # $1 and $ENV_FLAG are intentionally unquoted so empty values vanish
    # and multi-word flags split into separate arguments.
    docker run --rm $1 -p "${PORT:-8000}:8000" \
        $ENV_FLAG \
        --name mpr-inference \
        mpr-inference
}

case "${1:-local}" in
  local)
    python server.py
    ;;
  docker)
    run_docker ""
    ;;
  docker-gpu)
    run_docker "--gpus all"
    ;;
  stop)
    docker stop mpr-inference 2>/dev/null || true
    ;;
  *)
    echo "Usage: ./run.sh [local|docker|docker-gpu|stop]"
    exit 1
    ;;
esac
|
||||
206
gpu/server.py
Normal file
206
gpu/server.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""
|
||||
Inference server — thin HTTP wrapper around ML models.
|
||||
|
||||
Runs on the GPU machine. The detection pipeline calls this over HTTP,
|
||||
or imports the same logic locally if GPU is on the same machine.
|
||||
|
||||
Config is loaded from env on startup, then editable at runtime via
|
||||
GET/PUT /config. The UI config panel is just a visual editor for these
|
||||
same values.
|
||||
|
||||
Usage:
|
||||
cd gpu && uvicorn server:app --host 0.0.0.0 --port 8000
|
||||
# or
|
||||
cd gpu && python server.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import numpy as np
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- Runtime config (loaded from env, mutable via API) ---
|
||||
_config = {
|
||||
"device": os.environ.get("DEVICE", "auto"),
|
||||
"yolo_model": os.environ.get("YOLO_MODEL", "yolov8n.pt"),
|
||||
"yolo_confidence": float(os.environ.get("YOLO_CONFIDENCE", "0.3")),
|
||||
"vram_budget_mb": int(os.environ.get("VRAM_BUDGET_MB", "10240")),
|
||||
"strategy": os.environ.get("STRATEGY", "sequential"),
|
||||
}
|
||||
|
||||
# --- Model registry ---
|
||||
_models: dict[str, object] = {}
|
||||
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
def _get_device() -> str:
    """Resolve the configured device, auto-detecting CUDA for 'auto'."""
    configured = _config["device"]
    if configured != "auto":
        return configured
    # Auto mode: prefer CUDA when torch is installed and sees a GPU.
    try:
        import torch
    except ImportError:
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def _get_yolo(model_name: str | None = None):
    """Return a cached YOLO model, loading it into the registry on first use."""
    name = model_name or _config["yolo_model"]
    cached = _models.get(name)
    if cached is None:
        # Deferred import keeps server startup light when no detection runs.
        from ultralytics import YOLO
        device = _get_device()
        logger.info("Loading %s on %s", name, device)
        cached = YOLO(name)
        cached.to(device)
        _models[name] = cached
    return cached
|
||||
|
||||
|
||||
def _decode_image(b64: str) -> np.ndarray:
    """Decode a base64-encoded JPEG into an RGB numpy array."""
    raw = base64.b64decode(b64)
    image = Image.open(io.BytesIO(raw))
    return np.array(image.convert("RGB"))
|
||||
|
||||
|
||||
# --- Request/Response models ---
|
||||
|
||||
class DetectRequest(BaseModel):
    """Body for POST /detect."""
    image: str  # base64 JPEG
    model: str | None = None  # defaults to config yolo_model
    confidence: float | None = None  # defaults to config yolo_confidence
    target_classes: list[str] | None = None  # None/empty = keep all labels
|
||||
|
||||
|
||||
class BBox(BaseModel):
    """Detected box in pixel coordinates; (x, y) is the top-left corner."""
    x: int
    y: int
    w: int
    h: int
    confidence: float  # model confidence in [0, 1]
    label: str  # class name from the model's label map
|
||||
|
||||
|
||||
class DetectResponse(BaseModel):
    """Response for POST /detect: every detection that survived filtering."""
    detections: list[BBox]
|
||||
|
||||
|
||||
class ConfigUpdate(BaseModel):
    """Partial config update — only provided fields are changed."""
    device: str | None = None  # auto | cpu | cuda | cuda:0
    yolo_model: str | None = None  # changing this unloads the previous model
    yolo_confidence: float | None = None
    vram_budget_mb: int | None = None
    strategy: str | None = None  # sequential | concurrent | auto
|
||||
|
||||
|
||||
# --- App ---
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan: log the resolved device at startup, release models at shutdown."""
    logger.info("Inference server starting (device=%s)", _get_device())
    yield
    logger.info("Inference server shutting down")
    # Drop model references so their memory can be reclaimed.
    _models.clear()
|
||||
|
||||
|
||||
app = FastAPI(title="MPR Inference Server", lifespan=lifespan)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {
|
||||
"status": "ok",
|
||||
"device": _get_device(),
|
||||
"loaded_models": list(_models.keys()),
|
||||
"vram_budget_mb": _config["vram_budget_mb"],
|
||||
"strategy": _config["strategy"],
|
||||
}
|
||||
|
||||
|
||||
@app.get("/config")
|
||||
def get_config():
|
||||
"""Current runtime config. Same values the .env sets at startup."""
|
||||
return {**_config, "device_resolved": _get_device()}
|
||||
|
||||
|
||||
@app.put("/config")
|
||||
def update_config(update: ConfigUpdate):
|
||||
"""Update runtime config. Only provided fields are changed."""
|
||||
changes = update.model_dump(exclude_none=True)
|
||||
if not changes:
|
||||
return _config
|
||||
|
||||
# If model changed, unload the old one so it gets reloaded on next request
|
||||
if "yolo_model" in changes and changes["yolo_model"] != _config["yolo_model"]:
|
||||
old = _config["yolo_model"]
|
||||
if old in _models:
|
||||
del _models[old]
|
||||
logger.info("Unloaded %s (model changed)", old)
|
||||
|
||||
_config.update(changes)
|
||||
logger.info("Config updated: %s", changes)
|
||||
return {**_config, "device_resolved": _get_device()}
|
||||
|
||||
|
||||
@app.post("/models/unload")
|
||||
def unload_model(body: dict):
|
||||
"""Unload a model from memory to free VRAM."""
|
||||
name = body.get("model", "")
|
||||
if name in _models:
|
||||
del _models[name]
|
||||
logger.info("Unloaded %s", name)
|
||||
return {"status": "unloaded", "model": name}
|
||||
return {"status": "not_loaded", "model": name}
|
||||
|
||||
|
||||
@app.post("/detect", response_model=DetectResponse)
|
||||
def detect(req: DetectRequest):
|
||||
model_name = req.model or _config["yolo_model"]
|
||||
confidence = req.confidence if req.confidence is not None else _config["yolo_confidence"]
|
||||
|
||||
try:
|
||||
model = _get_yolo(model_name)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to load model: {e}")
|
||||
|
||||
image = _decode_image(req.image)
|
||||
results = model(image, conf=confidence, verbose=False)
|
||||
|
||||
detections = []
|
||||
for r in results:
|
||||
for box in r.boxes:
|
||||
x1, y1, x2, y2 = box.xyxy[0].tolist()
|
||||
label = r.names[int(box.cls[0])]
|
||||
|
||||
if req.target_classes and label not in req.target_classes:
|
||||
continue
|
||||
|
||||
det = BBox(
|
||||
x=int(x1), y=int(y1),
|
||||
w=int(x2 - x1), h=int(y2 - y1),
|
||||
confidence=float(box.conf[0]),
|
||||
label=label,
|
||||
)
|
||||
detections.append(det)
|
||||
|
||||
return DetectResponse(detections=detections)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)-7s %(name)s — %(message)s")
|
||||
host = os.environ.get("HOST", "0.0.0.0")
|
||||
port = int(os.environ.get("PORT", "8000"))
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
@@ -5,6 +5,7 @@ import 'mpr-ui-framework/src/tokens.css'
|
||||
import LogPanel from './panels/LogPanel.vue'
|
||||
import FunnelPanel from './panels/FunnelPanel.vue'
|
||||
import PipelineGraphPanel from './panels/PipelineGraphPanel.vue'
|
||||
import FramePanel from './panels/FramePanel.vue'
|
||||
import type { StatsUpdate } from './types/sse-contract'
|
||||
|
||||
const jobId = ref(new URLSearchParams(window.location.search).get('job') || 'test-job')
|
||||
@@ -41,7 +42,7 @@ source.connect()
|
||||
<span class="job-id">job: {{ jobId }}</span>
|
||||
</header>
|
||||
|
||||
<LayoutGrid :columns="2" :rows="2" gap="var(--space-2)">
|
||||
<LayoutGrid :columns="3" :rows="2" gap="var(--space-2)">
|
||||
<Panel title="Stats" :status="status">
|
||||
<div class="stats" v-if="stats">
|
||||
<div class="stat" v-for="s in [
|
||||
@@ -61,6 +62,8 @@ source.connect()
|
||||
|
||||
<FunnelPanel :source="source" :status="status" />
|
||||
|
||||
<FramePanel :source="source" :status="status" />
|
||||
|
||||
<PipelineGraphPanel :source="source" :status="status" />
|
||||
|
||||
<LogPanel :source="source" :status="status" />
|
||||
|
||||
31
ui/detection-app/src/panels/FramePanel.vue
Normal file
31
ui/detection-app/src/panels/FramePanel.vue
Normal file
@@ -0,0 +1,31 @@
|
||||
<script setup lang="ts">
|
||||
import { ref } from 'vue'
|
||||
import { Panel } from 'mpr-ui-framework'
|
||||
import FrameRenderer from 'mpr-ui-framework/src/renderers/FrameRenderer.vue'
|
||||
import type { FrameBBox } from 'mpr-ui-framework/src/renderers/FrameRenderer.vue'
|
||||
import type { DataSource } from 'mpr-ui-framework'
|
||||
|
||||
const props = defineProps<{
|
||||
source: DataSource
|
||||
status?: 'idle' | 'live' | 'processing' | 'error'
|
||||
}>()
|
||||
|
||||
const imageSrc = ref('')
|
||||
const boxes = ref<FrameBBox[]>([])
|
||||
|
||||
props.source.on<{
|
||||
frame_ref: number
|
||||
timestamp: number
|
||||
jpeg_b64: string
|
||||
boxes: FrameBBox[]
|
||||
}>('frame_update', (e) => {
|
||||
imageSrc.value = e.jpeg_b64
|
||||
boxes.value = e.boxes
|
||||
})
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<Panel title="Frame Viewer" :status="status">
|
||||
<FrameRenderer :image-src="imageSrc" :boxes="boxes" />
|
||||
</Panel>
|
||||
</template>
|
||||
@@ -12,3 +12,4 @@ export { default as LayoutGrid } from './components/LayoutGrid.vue'
|
||||
export { default as LogRenderer } from './renderers/LogRenderer.vue'
|
||||
export { default as TimeSeriesRenderer } from './renderers/TimeSeriesRenderer.vue'
|
||||
export { default as GraphRenderer } from './renderers/GraphRenderer.vue'
|
||||
export { default as FrameRenderer } from './renderers/FrameRenderer.vue'
|
||||
|
||||
115
ui/framework/src/renderers/FrameRenderer.vue
Normal file
115
ui/framework/src/renderers/FrameRenderer.vue
Normal file
@@ -0,0 +1,115 @@
|
||||
<script setup lang="ts">
|
||||
import { ref, watch, onMounted, onUnmounted, nextTick } from 'vue'
|
||||
|
||||
export interface FrameBBox {
|
||||
x: number
|
||||
y: number
|
||||
w: number
|
||||
h: number
|
||||
confidence: number
|
||||
label: string
|
||||
}
|
||||
|
||||
const props = defineProps<{
|
||||
/** Base64 JPEG image */
|
||||
imageSrc: string
|
||||
/** Bounding boxes to overlay */
|
||||
boxes: FrameBBox[]
|
||||
}>()
|
||||
|
||||
const canvas = ref<HTMLCanvasElement | null>(null)
|
||||
const container = ref<HTMLElement | null>(null)
|
||||
|
||||
function draw() {
|
||||
const cvs = canvas.value
|
||||
const ctr = container.value
|
||||
if (!cvs || !ctr || !props.imageSrc) return
|
||||
|
||||
const ctx = cvs.getContext('2d')
|
||||
if (!ctx) return
|
||||
|
||||
const img = new window.Image()
|
||||
img.onload = () => {
|
||||
cvs.width = ctr.clientWidth
|
||||
cvs.height = ctr.clientHeight
|
||||
|
||||
const scale = Math.min(cvs.width / img.width, cvs.height / img.height)
|
||||
const dx = (cvs.width - img.width * scale) / 2
|
||||
const dy = (cvs.height - img.height * scale) / 2
|
||||
|
||||
ctx.clearRect(0, 0, cvs.width, cvs.height)
|
||||
ctx.drawImage(img, dx, dy, img.width * scale, img.height * scale)
|
||||
|
||||
for (const box of props.boxes) {
|
||||
const bx = dx + box.x * scale
|
||||
const by = dy + box.y * scale
|
||||
const bw = box.w * scale
|
||||
const bh = box.h * scale
|
||||
|
||||
// Box outline
|
||||
ctx.strokeStyle = confidenceColor(box.confidence)
|
||||
ctx.lineWidth = 2
|
||||
ctx.strokeRect(bx, by, bw, bh)
|
||||
|
||||
// Label background
|
||||
const label = `${box.label} ${(box.confidence * 100).toFixed(0)}%`
|
||||
ctx.font = '11px var(--font-mono)'
|
||||
const metrics = ctx.measureText(label)
|
||||
const labelH = 16
|
||||
ctx.fillStyle = confidenceColor(box.confidence)
|
||||
ctx.fillRect(bx, by - labelH, metrics.width + 8, labelH)
|
||||
|
||||
// Label text
|
||||
ctx.fillStyle = '#000'
|
||||
ctx.fillText(label, bx + 4, by - 4)
|
||||
}
|
||||
}
|
||||
img.src = `data:image/jpeg;base64,${props.imageSrc}`
|
||||
}
|
||||
|
||||
function confidenceColor(conf: number): string {
|
||||
if (conf >= 0.7) return 'var(--conf-high)'
|
||||
if (conf >= 0.4) return 'var(--conf-mid)'
|
||||
return 'var(--conf-low)'
|
||||
}
|
||||
|
||||
watch(() => [props.imageSrc, props.boxes], () => nextTick(draw), { deep: true })
|
||||
|
||||
onMounted(() => {
|
||||
nextTick(draw)
|
||||
const observer = new ResizeObserver(() => draw())
|
||||
if (container.value) observer.observe(container.value)
|
||||
onUnmounted(() => observer.disconnect())
|
||||
})
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div ref="container" class="frame-renderer">
|
||||
<canvas ref="canvas" />
|
||||
<div v-if="!imageSrc" class="frame-empty">No frame</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.frame-renderer {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
min-height: 200px;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.frame-renderer canvas {
|
||||
display: block;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
.frame-empty {
|
||||
position: absolute;
|
||||
inset: 0;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
color: var(--text-dim);
|
||||
}
|
||||
</style>
|
||||
Reference in New Issue
Block a user