This commit is contained in:
2026-03-23 16:55:13 -03:00
parent 4fdbdfc6d3
commit 3df9ed5ada
17 changed files with 848 additions and 4 deletions

156
detect/inference/client.py Normal file
View File

@@ -0,0 +1,156 @@
"""
HTTP client for the inference server.
The pipeline stages call this instead of importing ML libraries directly.
The inference server runs on the GPU machine (or spot instance).
"""
from __future__ import annotations
import base64
import io
import logging
import os
import numpy as np
import requests
from PIL import Image
from .types import DetectResult, OCRResult, ServerStatus, VLMResult
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Base URL used when a client is built without an explicit one; the
# INFERENCE_URL environment variable overrides the localhost fallback.
DEFAULT_URL = os.getenv("INFERENCE_URL", "http://localhost:8000")
def _encode_image(image: np.ndarray) -> str:
    """Encode a numpy image array as a base64 JPEG string.

    Args:
        image: Pixel array accepted by ``PIL.Image.fromarray``.

    Returns:
        Base64 text of the JPEG bytes, encoded at quality 85.
    """
    img = Image.fromarray(image)

    # JPEG has no alpha channel and Pillow refuses to write RGBA/LA images,
    # so strip alpha first. Every other mode is passed through untouched so
    # inputs that already encoded successfully produce the same bytes.
    alpha_free_mode = {"RGBA": "RGB", "LA": "L"}.get(img.mode)
    if alpha_free_mode is not None:
        img = img.convert(alpha_free_mode)

    with io.BytesIO() as buf:
        img.save(buf, format="JPEG", quality=85)
        return base64.b64encode(buf.getvalue()).decode()
class InferenceClient:
    """HTTP client for the GPU inference server.

    Every request goes through a single ``requests.Session`` held on the
    instance. Call :meth:`close` when done, or use the client as a context
    manager, to release that session::

        with InferenceClient() as client:
            status = client.health()
    """

    def __init__(self, base_url: str | None = None, timeout: float = 60.0):
        """Create a client.

        Args:
            base_url: Server root URL. Falls back to ``DEFAULT_URL`` when
                omitted or empty. Trailing slashes are stripped.
            timeout: Timeout in seconds passed to every request.
        """
        self.base_url = (base_url or DEFAULT_URL).rstrip("/")
        self.timeout = timeout
        self.session = requests.Session()

    def __enter__(self) -> InferenceClient:
        """Return the client itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the session on leaving the ``with`` block."""
        self.close()

    def close(self) -> None:
        """Close the underlying HTTP session."""
        self.session.close()

    def _post(self, path: str, payload: dict) -> requests.Response:
        """POST ``payload`` as JSON to ``path`` and return the response.

        Raises whatever ``raise_for_status`` raises when the server answers
        with an error status.
        """
        resp = self.session.post(
            f"{self.base_url}{path}",
            json=payload,
            timeout=self.timeout,
        )
        resp.raise_for_status()
        return resp

    def health(self) -> ServerStatus:
        """Check server health and loaded models.

        Returns:
            A ``ServerStatus`` built from the ``/health`` JSON body. Fields
            missing from the body fall back to an empty model list, zero
            VRAM figures and the ``"sequential"`` strategy.
        """
        resp = self.session.get(f"{self.base_url}/health", timeout=self.timeout)
        resp.raise_for_status()
        data = resp.json()
        return ServerStatus(
            loaded_models=data.get("loaded_models", []),
            vram_used_mb=data.get("vram_used_mb", 0),
            vram_budget_mb=data.get("vram_budget_mb", 0),
            strategy=data.get("strategy", "sequential"),
        )

    def detect(
        self,
        image: np.ndarray,
        model: str = "yolov8n",
        confidence: float = 0.3,
        target_classes: list[str] | None = None,
    ) -> list[DetectResult]:
        """Run object detection on an image.

        Args:
            image: Image array; sent to the server as a base64 JPEG.
            model: Detection model name forwarded to the server.
            confidence: Confidence value forwarded to the server.
            target_classes: Optional class names; only sent when non-empty.

        Returns:
            One ``DetectResult`` per entry in the response's ``detections``
            list (empty when the key is absent).
        """
        payload: dict = {
            "image": _encode_image(image),
            "model": model,
            "confidence": confidence,
        }
        if target_classes:
            payload["target_classes"] = target_classes

        resp = self._post("/detect", payload)
        return [
            DetectResult(
                x=d["x"], y=d["y"], w=d["w"], h=d["h"],
                confidence=d["confidence"], label=d["label"],
            )
            for d in resp.json().get("detections", [])
        ]

    def ocr(
        self,
        image: np.ndarray,
        languages: list[str] | None = None,
    ) -> list[OCRResult]:
        """Run OCR on an image region.

        Args:
            image: Image array; sent to the server as a base64 JPEG.
            languages: Optional language codes; only sent when non-empty.

        Returns:
            One ``OCRResult`` per entry in the response's ``results`` list
            (empty when the key is absent), with ``bbox`` as a tuple.
        """
        payload: dict = {"image": _encode_image(image)}
        if languages:
            payload["languages"] = languages

        resp = self._post("/ocr", payload)
        return [
            OCRResult(
                text=d["text"],
                confidence=d["confidence"],
                bbox=tuple(d["bbox"]),
            )
            for d in resp.json().get("results", [])
        ]

    def vlm(
        self,
        image: np.ndarray,
        prompt: str,
        model: str = "moondream2",
    ) -> VLMResult:
        """Query a visual language model with an image crop + prompt.

        Args:
            image: Image array; sent to the server as a base64 JPEG.
            prompt: Text prompt forwarded to the server.
            model: VLM model name forwarded to the server.

        Returns:
            A ``VLMResult`` built from the response body. Missing fields
            default to an empty brand, 0.0 confidence and empty reasoning.
        """
        payload = {
            "image": _encode_image(image),
            "prompt": prompt,
            "model": model,
        }
        data = self._post("/vlm", payload).json()
        return VLMResult(
            brand=data.get("brand", ""),
            confidence=data.get("confidence", 0.0),
            reasoning=data.get("reasoning", ""),
        )

    def load_model(self, model: str, quantization: str = "fp16") -> None:
        """Request the server to load a model into VRAM."""
        self._post("/models/load", {"model": model, "quantization": quantization})

    def unload_model(self, model: str) -> None:
        """Request the server to unload a model from VRAM."""
        self._post("/models/unload", {"model": model})