phase 4
This commit is contained in:
6
core/gpu/models/__init__.py
Normal file
6
core/gpu/models/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# GPU models — standalone container imports.
|
||||
# When running as a container (cd gpu && python server.py), bare imports work.
|
||||
# When imported from the main app (core.gpu.models.preprocess), only
|
||||
# individual modules should be imported directly, not this __init__.
|
||||
#
|
||||
# The server.py imports detect/ocr/vlm directly, not through this file.
|
||||
1
core/gpu/models/cv/__init__.py
Normal file
1
core/gpu/models/cv/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""CV operations — pure OpenCV, no ML models."""
|
||||
258
core/gpu/models/cv/edges.py
Normal file
258
core/gpu/models/cv/edges.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""
|
||||
Edge detection — Canny + HoughLinesP → parallel line pairs → bounding boxes.
|
||||
|
||||
Finds horizontal line pairs with consistent spacing, which correspond to
|
||||
the top and bottom edges of advertising hoardings.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def detect_edges(
|
||||
image: np.ndarray,
|
||||
canny_low: int = 50,
|
||||
canny_high: int = 150,
|
||||
hough_threshold: int = 80,
|
||||
hough_min_length: int = 100,
|
||||
hough_max_gap: int = 10,
|
||||
pair_max_distance: int = 200,
|
||||
pair_min_distance: int = 15,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Find horizontal line pairs that likely bound advertising hoardings.
|
||||
|
||||
Returns list of dicts with keys: x, y, w, h, confidence, label.
|
||||
Each box represents the region between a detected pair of parallel
|
||||
horizontal lines.
|
||||
"""
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
||||
edges = cv2.Canny(gray, canny_low, canny_high)
|
||||
|
||||
raw_lines = cv2.HoughLinesP(
|
||||
edges,
|
||||
rho=1,
|
||||
theta=np.pi / 180,
|
||||
threshold=hough_threshold,
|
||||
minLineLength=hough_min_length,
|
||||
maxLineGap=hough_max_gap,
|
||||
)
|
||||
|
||||
if raw_lines is None:
|
||||
return []
|
||||
|
||||
# Filter to near-horizontal lines (within 10 degrees)
|
||||
horizontals = _filter_horizontal(raw_lines, max_angle_deg=10)
|
||||
|
||||
if len(horizontals) < 2:
|
||||
return []
|
||||
|
||||
# Find pairs of parallel horizontals with consistent spacing
|
||||
pairs = _find_line_pairs(
|
||||
horizontals,
|
||||
min_distance=pair_min_distance,
|
||||
max_distance=pair_max_distance,
|
||||
)
|
||||
|
||||
# Convert pairs to bounding boxes
|
||||
h, w = image.shape[:2]
|
||||
results = []
|
||||
for top_line, bottom_line in pairs:
|
||||
box = _pair_to_bbox(top_line, bottom_line, frame_width=w, frame_height=h)
|
||||
if box is not None:
|
||||
results.append(box)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _filter_horizontal(lines: np.ndarray, max_angle_deg: float = 10) -> list[tuple]:
|
||||
"""Keep only lines within max_angle_deg of horizontal."""
|
||||
max_slope = np.tan(np.radians(max_angle_deg))
|
||||
result = []
|
||||
for line in lines:
|
||||
x1, y1, x2, y2 = line[0]
|
||||
dx = x2 - x1
|
||||
if dx == 0:
|
||||
continue
|
||||
slope = abs((y2 - y1) / dx)
|
||||
if slope <= max_slope:
|
||||
y_mid = (y1 + y2) / 2
|
||||
x_min = min(x1, x2)
|
||||
x_max = max(x1, x2)
|
||||
length = np.sqrt(dx**2 + (y2 - y1) ** 2)
|
||||
result.append((x_min, x_max, y_mid, length))
|
||||
return result
|
||||
|
||||
|
||||
def _find_line_pairs(
|
||||
horizontals: list[tuple],
|
||||
min_distance: int,
|
||||
max_distance: int,
|
||||
) -> list[tuple]:
|
||||
"""
|
||||
Find pairs of horizontal lines that could be top/bottom of a hoarding.
|
||||
|
||||
Lines must overlap horizontally and be spaced within [min_distance, max_distance].
|
||||
"""
|
||||
# Sort by y position
|
||||
sorted_lines = sorted(horizontals, key=lambda l: l[2])
|
||||
|
||||
pairs = []
|
||||
used = set()
|
||||
|
||||
for i, top in enumerate(sorted_lines):
|
||||
if i in used:
|
||||
continue
|
||||
for j, bottom in enumerate(sorted_lines[i + 1 :], start=i + 1):
|
||||
if j in used:
|
||||
continue
|
||||
|
||||
y_gap = bottom[2] - top[2]
|
||||
if y_gap < min_distance:
|
||||
continue
|
||||
if y_gap > max_distance:
|
||||
break # sorted by y, no point checking further
|
||||
|
||||
# Check horizontal overlap
|
||||
overlap_start = max(top[0], bottom[0])
|
||||
overlap_end = min(top[1], bottom[1])
|
||||
overlap = overlap_end - overlap_start
|
||||
|
||||
# Require at least 50% overlap relative to shorter line
|
||||
shorter_length = min(top[1] - top[0], bottom[1] - bottom[0])
|
||||
if shorter_length > 0 and overlap / shorter_length >= 0.5:
|
||||
pairs.append((top, bottom))
|
||||
used.add(i)
|
||||
used.add(j)
|
||||
break
|
||||
|
||||
return pairs
|
||||
|
||||
|
||||
def _pair_to_bbox(
|
||||
top: tuple,
|
||||
bottom: tuple,
|
||||
frame_width: int,
|
||||
frame_height: int,
|
||||
) -> dict | None:
|
||||
"""Convert a line pair to a bounding box dict."""
|
||||
x = int(max(0, min(top[0], bottom[0])))
|
||||
y = int(max(0, top[2]))
|
||||
x2 = int(min(frame_width, max(top[1], bottom[1])))
|
||||
y2 = int(min(frame_height, bottom[2]))
|
||||
w = x2 - x
|
||||
h = y2 - y
|
||||
|
||||
if w < 20 or h < 5:
|
||||
return None
|
||||
|
||||
# Confidence based on line lengths relative to box width
|
||||
avg_line_length = (top[3] + bottom[3]) / 2
|
||||
coverage = min(1.0, avg_line_length / max(w, 1))
|
||||
|
||||
return {
|
||||
"x": x,
|
||||
"y": y,
|
||||
"w": w,
|
||||
"h": h,
|
||||
"confidence": round(coverage, 3),
|
||||
"label": "edge_region",
|
||||
}
|
||||
|
||||
|
||||
def _np_to_b64_jpeg(image: np.ndarray, quality: int = 70) -> str:
|
||||
"""Encode a numpy image (BGR or grayscale) as base64 JPEG."""
|
||||
ok, buf = cv2.imencode(".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, quality])
|
||||
if not ok:
|
||||
return ""
|
||||
return base64.b64encode(buf.tobytes()).decode()
|
||||
|
||||
|
||||
def detect_edges_debug(
|
||||
image: np.ndarray,
|
||||
canny_low: int = 50,
|
||||
canny_high: int = 150,
|
||||
hough_threshold: int = 80,
|
||||
hough_min_length: int = 100,
|
||||
hough_max_gap: int = 10,
|
||||
pair_max_distance: int = 200,
|
||||
pair_min_distance: int = 15,
|
||||
) -> dict:
|
||||
"""
|
||||
Same as detect_edges but returns intermediate visualizations.
|
||||
|
||||
Returns dict with:
|
||||
regions: list[dict] — same boxes as detect_edges
|
||||
edge_overlay_b64: str — Canny edge image as base64 JPEG
|
||||
lines_overlay_b64: str — frame with Hough lines drawn
|
||||
horizontal_count: int — number of horizontal lines found
|
||||
pair_count: int — number of line pairs found
|
||||
"""
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
||||
edges = cv2.Canny(gray, canny_low, canny_high)
|
||||
|
||||
# Edge overlay — Canny output as-is (white edges on black)
|
||||
edge_overlay_b64 = _np_to_b64_jpeg(edges)
|
||||
|
||||
raw_lines = cv2.HoughLinesP(
|
||||
edges,
|
||||
rho=1,
|
||||
theta=np.pi / 180,
|
||||
threshold=hough_threshold,
|
||||
minLineLength=hough_min_length,
|
||||
maxLineGap=hough_max_gap,
|
||||
)
|
||||
|
||||
# Lines overlay — draw all Hough lines on a copy of the frame
|
||||
lines_vis = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
||||
if raw_lines is not None:
|
||||
for line in raw_lines:
|
||||
x1, y1, x2, y2 = line[0]
|
||||
cv2.line(lines_vis, (x1, y1), (x2, y2), (0, 0, 255), 1)
|
||||
|
||||
horizontals = []
|
||||
if raw_lines is not None:
|
||||
horizontals = _filter_horizontal(raw_lines, max_angle_deg=10)
|
||||
|
||||
# Draw horizontal lines in cyan, thicker
|
||||
for h_line in horizontals:
|
||||
x_min, x_max, y_mid, _ = h_line
|
||||
cv2.line(lines_vis, (int(x_min), int(y_mid)), (int(x_max), int(y_mid)), (255, 255, 0), 2)
|
||||
|
||||
pairs = []
|
||||
if len(horizontals) >= 2:
|
||||
pairs = _find_line_pairs(
|
||||
horizontals,
|
||||
min_distance=pair_min_distance,
|
||||
max_distance=pair_max_distance,
|
||||
)
|
||||
|
||||
# Draw paired lines in green
|
||||
for top_line, bottom_line in pairs:
|
||||
cv2.line(lines_vis, (int(top_line[0]), int(top_line[2])),
|
||||
(int(top_line[1]), int(top_line[2])), (0, 255, 0), 2)
|
||||
cv2.line(lines_vis, (int(bottom_line[0]), int(bottom_line[2])),
|
||||
(int(bottom_line[1]), int(bottom_line[2])), (0, 255, 0), 2)
|
||||
|
||||
lines_overlay_b64 = _np_to_b64_jpeg(lines_vis)
|
||||
|
||||
# Build region boxes (same logic as detect_edges)
|
||||
h, w = image.shape[:2]
|
||||
regions = []
|
||||
for top_line, bottom_line in pairs:
|
||||
box = _pair_to_bbox(top_line, bottom_line, frame_width=w, frame_height=h)
|
||||
if box is not None:
|
||||
regions.append(box)
|
||||
|
||||
return {
|
||||
"regions": regions,
|
||||
"edge_overlay_b64": edge_overlay_b64,
|
||||
"lines_overlay_b64": lines_overlay_b64,
|
||||
"horizontal_count": len(horizontals),
|
||||
"pair_count": len(pairs),
|
||||
}
|
||||
86
core/gpu/models/cv/segmentation.py
Normal file
86
core/gpu/models/cv/segmentation.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
Field segmentation — HSV green mask → pitch boundary contour.
|
||||
|
||||
Pure OpenCV. Called by the inference server endpoint.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def segment_field(
|
||||
image: np.ndarray,
|
||||
hue_low: int = 30,
|
||||
hue_high: int = 85,
|
||||
sat_low: int = 30,
|
||||
sat_high: int = 255,
|
||||
val_low: int = 30,
|
||||
val_high: int = 255,
|
||||
morph_kernel: int = 15,
|
||||
min_area_ratio: float = 0.05,
|
||||
) -> dict:
|
||||
"""
|
||||
Detect the pitch area using HSV green thresholding.
|
||||
|
||||
Returns dict with:
|
||||
boundary: list of [x, y] points
|
||||
coverage: float (fraction of frame)
|
||||
"""
|
||||
hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
|
||||
|
||||
lower = np.array([hue_low, sat_low, val_low])
|
||||
upper = np.array([hue_high, sat_high, val_high])
|
||||
mask = cv2.inRange(hsv, lower, upper)
|
||||
|
||||
k = morph_kernel
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
|
||||
mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
|
||||
mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
|
||||
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
h, w = image.shape[:2]
|
||||
min_area = min_area_ratio * h * w
|
||||
boundary = []
|
||||
coverage = 0.0
|
||||
|
||||
if contours:
|
||||
large = [c for c in contours if cv2.contourArea(c) >= min_area]
|
||||
if large:
|
||||
pitch_contour = max(large, key=cv2.contourArea)
|
||||
boundary = pitch_contour.reshape(-1, 2).tolist()
|
||||
coverage = cv2.contourArea(pitch_contour) / (h * w)
|
||||
|
||||
refined = np.zeros_like(mask)
|
||||
cv2.drawContours(refined, [pitch_contour], -1, 255, cv2.FILLED)
|
||||
mask = refined
|
||||
|
||||
return {
|
||||
"boundary": boundary,
|
||||
"coverage": coverage,
|
||||
"mask": mask,
|
||||
}
|
||||
|
||||
|
||||
def segment_field_debug(
|
||||
image: np.ndarray,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""Same as segment_field but includes a mask overlay for the editor."""
|
||||
result = segment_field(image, **kwargs)
|
||||
mask = result["mask"]
|
||||
|
||||
# RGBA overlay: solid green where mask, fully transparent elsewhere
|
||||
h, w = image.shape[:2]
|
||||
overlay = np.zeros((h, w, 4), dtype=np.uint8)
|
||||
overlay[mask > 0] = [0, 255, 0, 255]
|
||||
_, buf = cv2.imencode(".png", overlay)
|
||||
result["mask_overlay_b64"] = base64.b64encode(buf.tobytes()).decode()
|
||||
|
||||
# Don't send the raw mask over HTTP
|
||||
del result["mask"]
|
||||
return result
|
||||
136
core/gpu/models/models.py
Normal file
136
core/gpu/models/models.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Pydantic Models - GENERATED FILE
|
||||
|
||||
Do not edit directly. Regenerate using modelgen.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class DetectRequest(BaseModel):
|
||||
"""Request body for object detection."""
|
||||
image: str
|
||||
model: Optional[str] = None
|
||||
confidence: Optional[float] = None
|
||||
target_classes: Optional[List[str]] = None
|
||||
|
||||
class BBox(BaseModel):
|
||||
"""A detected bounding box."""
|
||||
x: int
|
||||
y: int
|
||||
w: int
|
||||
h: int
|
||||
confidence: float
|
||||
label: str
|
||||
|
||||
class DetectResponse(BaseModel):
|
||||
"""Response from object detection."""
|
||||
detections: List[BBox] = Field(default_factory=list)
|
||||
|
||||
class OCRRequest(BaseModel):
|
||||
"""Request body for OCR."""
|
||||
image: str
|
||||
languages: Optional[List[str]] = None
|
||||
|
||||
class OCRTextResult(BaseModel):
|
||||
"""A single OCR text extraction result."""
|
||||
text: str
|
||||
confidence: float
|
||||
bbox: List[int] = Field(default_factory=list)
|
||||
|
||||
class OCRResponse(BaseModel):
|
||||
"""Response from OCR."""
|
||||
results: List[OCRTextResult] = Field(default_factory=list)
|
||||
|
||||
class PreprocessRequest(BaseModel):
|
||||
"""Request body for image preprocessing."""
|
||||
image: str
|
||||
binarize: bool = False
|
||||
deskew: bool = False
|
||||
contrast: bool = True
|
||||
|
||||
class PreprocessResponse(BaseModel):
|
||||
"""Response from preprocessing."""
|
||||
image: str
|
||||
|
||||
class VLMRequest(BaseModel):
|
||||
"""Request body for visual language model query."""
|
||||
image: str
|
||||
prompt: str
|
||||
model: Optional[str] = None
|
||||
|
||||
class VLMResponse(BaseModel):
|
||||
"""Response from VLM."""
|
||||
brand: str
|
||||
confidence: float
|
||||
reasoning: str
|
||||
|
||||
class AnalyzeRegionsRequest(BaseModel):
|
||||
"""Request body for CV region analysis."""
|
||||
image: str
|
||||
edge_canny_low: int = 50
|
||||
edge_canny_high: int = 150
|
||||
edge_hough_threshold: int = 80
|
||||
edge_hough_min_length: int = 100
|
||||
edge_hough_max_gap: int = 10
|
||||
edge_pair_max_distance: int = 200
|
||||
edge_pair_min_distance: int = 15
|
||||
|
||||
class RegionBox(BaseModel):
|
||||
"""A candidate region from CV analysis."""
|
||||
x: int
|
||||
y: int
|
||||
w: int
|
||||
h: int
|
||||
confidence: float
|
||||
label: str
|
||||
|
||||
class AnalyzeRegionsResponse(BaseModel):
|
||||
"""Response from CV region analysis."""
|
||||
regions: List[RegionBox] = Field(default_factory=list)
|
||||
|
||||
class AnalyzeRegionsDebugResponse(BaseModel):
|
||||
"""Response from CV region analysis with debug overlays."""
|
||||
regions: List[RegionBox] = Field(default_factory=list)
|
||||
edge_overlay_b64: str = ""
|
||||
lines_overlay_b64: str = ""
|
||||
horizontal_count: int = 0
|
||||
pair_count: int = 0
|
||||
|
||||
class SegmentFieldRequest(BaseModel):
|
||||
"""Request body for field segmentation."""
|
||||
image: str
|
||||
hue_low: int = 30
|
||||
hue_high: int = 85
|
||||
sat_low: int = 30
|
||||
sat_high: int = 255
|
||||
val_low: int = 30
|
||||
val_high: int = 255
|
||||
morph_kernel: int = 15
|
||||
min_area_ratio: float = 0.05
|
||||
|
||||
class SegmentFieldResponse(BaseModel):
|
||||
"""Response from field segmentation."""
|
||||
boundary: List[List[int]] = Field(default_factory=list)
|
||||
coverage: float = 0.0
|
||||
mask_b64: str = ""
|
||||
|
||||
class SegmentFieldDebugResponse(BaseModel):
|
||||
"""Response from field segmentation with debug overlay."""
|
||||
boundary: List[List[int]] = Field(default_factory=list)
|
||||
coverage: float = 0.0
|
||||
mask_overlay_b64: str = ""
|
||||
|
||||
class ConfigUpdate(BaseModel):
|
||||
"""Request body for updating server configuration."""
|
||||
device: Optional[str] = None
|
||||
yolo_model: Optional[str] = None
|
||||
yolo_confidence: Optional[float] = None
|
||||
vram_budget_mb: Optional[int] = None
|
||||
strategy: Optional[str] = None
|
||||
ocr_languages: Optional[List[str]] = None
|
||||
ocr_min_confidence: Optional[float] = None
|
||||
105
core/gpu/models/ocr.py
Normal file
105
core/gpu/models/ocr.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""PaddleOCR 3.x text extraction wrapper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from models import registry
|
||||
from config import get_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _load(languages: list[str]):
|
||||
from paddleocr import PaddleOCR
|
||||
key = f"ocr_{'_'.join(languages)}"
|
||||
model = PaddleOCR(lang=languages[0])
|
||||
registry.put(key, model)
|
||||
return model
|
||||
|
||||
|
||||
def _get(languages: list[str] | None = None):
|
||||
langs = languages or get_config()["ocr_languages"]
|
||||
key = f"ocr_{'_'.join(langs)}"
|
||||
model = registry.get(key)
|
||||
if model is None:
|
||||
model = _load(langs)
|
||||
return model
|
||||
|
||||
|
||||
def _parse_raw(raw) -> list[tuple[list, str, float]]:
|
||||
"""
|
||||
Parse PaddleOCR output into (points, text, confidence) tuples.
|
||||
|
||||
PaddleOCR 3.x changed the result format. Two known layouts:
|
||||
|
||||
Layout A — dict-based (new pipeline API):
|
||||
raw = [{'rec_texts': [...], 'rec_scores': [...], 'dt_polys': [...]}]
|
||||
|
||||
Layout B — nested list (2.x compat / some 3.x builds):
|
||||
raw = [[ [points, [text, score]], ... ]]
|
||||
raw = [[ [points, [text, score], [cls, cls_score]], ... ]] # with angle cls
|
||||
"""
|
||||
results = []
|
||||
|
||||
for page in raw:
|
||||
if not page:
|
||||
continue
|
||||
|
||||
# Layout A: dict with parallel lists
|
||||
if isinstance(page, dict):
|
||||
texts = page.get("rec_texts", [])
|
||||
scores = page.get("rec_scores", [])
|
||||
polys = page.get("dt_polys", [])
|
||||
for points, text, confidence in zip(polys, texts, scores):
|
||||
results.append((points, text, float(confidence)))
|
||||
continue
|
||||
|
||||
# Layout B: list of per-line entries
|
||||
for line in page:
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# line[0] is always the polygon points
|
||||
points = line[0]
|
||||
|
||||
# line[1] is [text, score] — ignore any extra elements (angle cls etc.)
|
||||
rec = line[1]
|
||||
if isinstance(rec, (list, tuple)) and len(rec) >= 2:
|
||||
text, confidence = rec[0], rec[1]
|
||||
else:
|
||||
logger.warning("Unexpected OCR line format: %s", line)
|
||||
continue
|
||||
|
||||
results.append((points, str(text), float(confidence)))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def ocr(image, languages: list[str] | None = None, min_confidence: float | None = None) -> list[dict]:
|
||||
"""Run OCR on an image, return list of text result dicts."""
|
||||
cfg = get_config()
|
||||
min_conf = min_confidence if min_confidence is not None else cfg["ocr_min_confidence"]
|
||||
model = _get(languages)
|
||||
|
||||
raw = model.ocr(image)
|
||||
logger.debug("OCR raw: %s", raw)
|
||||
|
||||
parsed = _parse_raw(raw)
|
||||
|
||||
results = []
|
||||
for points, text, confidence in parsed:
|
||||
if confidence < min_conf:
|
||||
continue
|
||||
|
||||
xs = [p[0] for p in points]
|
||||
ys = [p[1] for p in points]
|
||||
|
||||
results.append({
|
||||
"text": text,
|
||||
"confidence": confidence,
|
||||
"bbox": [int(min(xs)), int(min(ys)),
|
||||
int(max(xs) - min(xs)), int(max(ys) - min(ys))],
|
||||
})
|
||||
|
||||
return results
|
||||
117
core/gpu/models/preprocess.py
Normal file
117
core/gpu/models/preprocess.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
Image preprocessing pipeline for crops before OCR.
|
||||
|
||||
Each step is independently toggleable via config.
|
||||
Operates on numpy arrays (BGR or RGB), returns processed array.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def binarize(image: np.ndarray, threshold: int = 128) -> np.ndarray:
|
||||
"""Convert to grayscale and apply Otsu binarization."""
|
||||
import cv2
|
||||
|
||||
if len(image.shape) == 3:
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
||||
else:
|
||||
gray = image
|
||||
|
||||
_, binary = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
||||
|
||||
# Convert back to 3-channel for downstream compatibility
|
||||
result = cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB)
|
||||
return result
|
||||
|
||||
|
||||
def deskew(image: np.ndarray) -> np.ndarray:
|
||||
"""Correct slight rotation using minimum area rectangle."""
|
||||
import cv2
|
||||
|
||||
if len(image.shape) == 3:
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
||||
else:
|
||||
gray = image
|
||||
|
||||
coords = np.column_stack(np.where(gray < 128))
|
||||
if len(coords) < 10:
|
||||
return image
|
||||
|
||||
rect = cv2.minAreaRect(coords)
|
||||
angle = rect[-1]
|
||||
|
||||
# Normalize angle
|
||||
if angle < -45:
|
||||
angle = -(90 + angle)
|
||||
else:
|
||||
angle = -angle
|
||||
|
||||
if abs(angle) < 0.5:
|
||||
return image
|
||||
|
||||
h, w = image.shape[:2]
|
||||
center = (w // 2, h // 2)
|
||||
rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
|
||||
result = cv2.warpAffine(
|
||||
image, rotation_matrix, (w, h),
|
||||
flags=cv2.INTER_LINEAR,
|
||||
borderMode=cv2.BORDER_REPLICATE,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def enhance_contrast(image: np.ndarray) -> np.ndarray:
|
||||
"""Apply CLAHE (adaptive histogram equalization) for contrast normalization."""
|
||||
import cv2
|
||||
|
||||
if len(image.shape) == 3:
|
||||
lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
|
||||
l_channel = lab[:, :, 0]
|
||||
else:
|
||||
l_channel = image
|
||||
|
||||
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
||||
enhanced = clahe.apply(l_channel)
|
||||
|
||||
if len(image.shape) == 3:
|
||||
lab[:, :, 0] = enhanced
|
||||
result = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
|
||||
else:
|
||||
result = enhanced
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def preprocess(
|
||||
image: np.ndarray,
|
||||
do_binarize: bool = False,
|
||||
do_deskew: bool = False,
|
||||
do_contrast: bool = True,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Run the preprocessing pipeline on a crop image.
|
||||
|
||||
Each step is independently toggleable. Order: contrast → deskew → binarize.
|
||||
Contrast first (works best on color), binarize last (destroys color info).
|
||||
"""
|
||||
result = image
|
||||
|
||||
if do_contrast:
|
||||
result = enhance_contrast(result)
|
||||
logger.debug("Preprocessing: contrast enhanced")
|
||||
|
||||
if do_deskew:
|
||||
result = deskew(result)
|
||||
logger.debug("Preprocessing: deskewed")
|
||||
|
||||
if do_binarize:
|
||||
result = binarize(result)
|
||||
logger.debug("Preprocessing: binarized")
|
||||
|
||||
return result
|
||||
37
core/gpu/models/registry.py
Normal file
37
core/gpu/models/registry.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
Model registry — manages loaded models and VRAM lifecycle.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_models: dict[str, object] = {}
|
||||
|
||||
|
||||
def get(name: str) -> object | None:
|
||||
return _models.get(name)
|
||||
|
||||
|
||||
def put(name: str, model: object) -> None:
|
||||
_models[name] = model
|
||||
logger.info("Loaded %s", name)
|
||||
|
||||
|
||||
def unload(name: str) -> bool:
|
||||
if name in _models:
|
||||
del _models[name]
|
||||
logger.info("Unloaded %s", name)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def loaded() -> list[str]:
|
||||
return list(_models.keys())
|
||||
|
||||
|
||||
def clear() -> None:
|
||||
_models.clear()
|
||||
logger.info("All models unloaded")
|
||||
100
core/gpu/models/vlm.py
Normal file
100
core/gpu/models/vlm.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""moondream2 visual language model wrapper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from models import registry
|
||||
from config import get_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MODEL_KEY = "vlm_moondream2"
|
||||
|
||||
|
||||
def _load():
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
device = get_config().get("device", "auto")
|
||||
if device == "auto":
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
logger.info("Loading moondream2 (device=%s)...", device)
|
||||
|
||||
model_id = "vikhyatk/moondream2"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
||||
dtype = torch.float16 if "cuda" in device else torch.float32
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
trust_remote_code=True,
|
||||
dtype=dtype,
|
||||
device_map=device,
|
||||
)
|
||||
|
||||
wrapper = {"model": model, "tokenizer": tokenizer}
|
||||
registry.put(_MODEL_KEY, wrapper)
|
||||
logger.info("moondream2 loaded")
|
||||
return wrapper
|
||||
|
||||
|
||||
def _get():
|
||||
wrapper = registry.get(_MODEL_KEY)
|
||||
if wrapper is None:
|
||||
wrapper = _load()
|
||||
return wrapper
|
||||
|
||||
|
||||
def query(image, prompt: str) -> dict:
|
||||
"""
|
||||
Query moondream2 with an image crop and prompt.
|
||||
|
||||
Returns {"brand": str, "confidence": float, "reasoning": str}
|
||||
"""
|
||||
from PIL import Image as PILImage
|
||||
|
||||
wrapper = _get()
|
||||
model = wrapper["model"]
|
||||
tokenizer = wrapper["tokenizer"]
|
||||
|
||||
# Convert numpy array to PIL if needed
|
||||
if not isinstance(image, PILImage.Image):
|
||||
image = PILImage.fromarray(image)
|
||||
|
||||
enc_image = model.encode_image(image)
|
||||
answer = model.answer_question(enc_image, prompt, tokenizer)
|
||||
|
||||
# Parse response — moondream2 returns free text, extract brand + confidence
|
||||
result = _parse_vlm_response(answer)
|
||||
return result
|
||||
|
||||
|
||||
def _parse_vlm_response(answer: str) -> dict:
|
||||
"""
|
||||
Parse moondream2 free-text response into structured output.
|
||||
|
||||
Expected format from prompt: "brand, confidence (0-1), reasoning"
|
||||
Falls back gracefully if format doesn't match.
|
||||
"""
|
||||
answer = answer.strip()
|
||||
parts = [p.strip() for p in answer.split(",", 2)]
|
||||
|
||||
brand = parts[0] if parts else ""
|
||||
confidence = 0.5
|
||||
reasoning = answer
|
||||
|
||||
if len(parts) >= 2:
|
||||
try:
|
||||
confidence = float(parts[1])
|
||||
confidence = max(0.0, min(1.0, confidence))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if len(parts) >= 3:
|
||||
reasoning = parts[2]
|
||||
|
||||
return {
|
||||
"brand": brand,
|
||||
"confidence": confidence,
|
||||
"reasoning": reasoning,
|
||||
}
|
||||
54
core/gpu/models/yolo.py
Normal file
54
core/gpu/models/yolo.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""YOLO object detection model wrapper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from models import registry
|
||||
from config import get_config, get_device
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _load(model_name: str):
|
||||
from ultralytics import YOLO
|
||||
device = get_device()
|
||||
model = YOLO(model_name)
|
||||
model.to(device)
|
||||
registry.put(model_name, model)
|
||||
return model
|
||||
|
||||
|
||||
def _get(model_name: str | None = None):
|
||||
name = model_name or get_config()["yolo_model"]
|
||||
model = registry.get(name)
|
||||
if model is None:
|
||||
model = _load(name)
|
||||
return model
|
||||
|
||||
|
||||
def detect(image, model_name: str | None = None, confidence: float | None = None, target_classes: list[str] | None = None) -> list[dict]:
|
||||
"""Run YOLO detection, return list of bbox dicts."""
|
||||
cfg = get_config()
|
||||
conf = confidence if confidence is not None else cfg["yolo_confidence"]
|
||||
model = _get(model_name)
|
||||
|
||||
results = model(image, conf=conf, verbose=False)
|
||||
|
||||
detections = []
|
||||
for r in results:
|
||||
for box in r.boxes:
|
||||
x1, y1, x2, y2 = box.xyxy[0].tolist()
|
||||
label = r.names[int(box.cls[0])]
|
||||
|
||||
if target_classes and label not in target_classes:
|
||||
continue
|
||||
|
||||
detections.append({
|
||||
"x": int(x1), "y": int(y1),
|
||||
"w": int(x2 - x1), "h": int(y2 - y1),
|
||||
"confidence": float(box.conf[0]),
|
||||
"label": label,
|
||||
})
|
||||
|
||||
return detections
|
||||
Reference in New Issue
Block a user