phase 9
This commit is contained in:
100
gpu/models/vlm.py
Normal file
100
gpu/models/vlm.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""moondream2 visual language model wrapper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from models import registry
|
||||
from config import get_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MODEL_KEY = "vlm_moondream2"
|
||||
|
||||
|
||||
def _load():
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
device = get_config().get("device", "auto")
|
||||
if device == "auto":
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
logger.info("Loading moondream2 (device=%s)...", device)
|
||||
|
||||
model_id = "vikhyatk/moondream2"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
||||
dtype = torch.float16 if "cuda" in device else torch.float32
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
trust_remote_code=True,
|
||||
dtype=dtype,
|
||||
device_map=device,
|
||||
)
|
||||
|
||||
wrapper = {"model": model, "tokenizer": tokenizer}
|
||||
registry.put(_MODEL_KEY, wrapper)
|
||||
logger.info("moondream2 loaded")
|
||||
return wrapper
|
||||
|
||||
|
||||
def _get():
|
||||
wrapper = registry.get(_MODEL_KEY)
|
||||
if wrapper is None:
|
||||
wrapper = _load()
|
||||
return wrapper
|
||||
|
||||
|
||||
def query(image, prompt: str) -> dict:
|
||||
"""
|
||||
Query moondream2 with an image crop and prompt.
|
||||
|
||||
Returns {"brand": str, "confidence": float, "reasoning": str}
|
||||
"""
|
||||
from PIL import Image as PILImage
|
||||
|
||||
wrapper = _get()
|
||||
model = wrapper["model"]
|
||||
tokenizer = wrapper["tokenizer"]
|
||||
|
||||
# Convert numpy array to PIL if needed
|
||||
if not isinstance(image, PILImage.Image):
|
||||
image = PILImage.fromarray(image)
|
||||
|
||||
enc_image = model.encode_image(image)
|
||||
answer = model.answer_question(enc_image, prompt, tokenizer)
|
||||
|
||||
# Parse response — moondream2 returns free text, extract brand + confidence
|
||||
result = _parse_vlm_response(answer)
|
||||
return result
|
||||
|
||||
|
||||
def _parse_vlm_response(answer: str) -> dict:
|
||||
"""
|
||||
Parse moondream2 free-text response into structured output.
|
||||
|
||||
Expected format from prompt: "brand, confidence (0-1), reasoning"
|
||||
Falls back gracefully if format doesn't match.
|
||||
"""
|
||||
answer = answer.strip()
|
||||
parts = [p.strip() for p in answer.split(",", 2)]
|
||||
|
||||
brand = parts[0] if parts else ""
|
||||
confidence = 0.5
|
||||
reasoning = answer
|
||||
|
||||
if len(parts) >= 2:
|
||||
try:
|
||||
confidence = float(parts[1])
|
||||
confidence = max(0.0, min(1.0, confidence))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if len(parts) >= 3:
|
||||
reasoning = parts[2]
|
||||
|
||||
return {
|
||||
"brand": brand,
|
||||
"confidence": confidence,
|
||||
"reasoning": reasoning,
|
||||
}
|
||||
@@ -19,3 +19,9 @@ ultralytics>=8.0.0
|
||||
# Install with:
|
||||
# uv pip install paddlepaddle-gpu==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
paddleocr>=3.0.0
|
||||
|
||||
# VLM (moondream2) — uses torch (already installed above)
|
||||
# Pinned <5: transformers 5.x broke moondream2's custom model code
|
||||
# (all_tied_weights_keys API change). Also needs accelerate for device_map.
|
||||
transformers>=4.40.0,<5
|
||||
accelerate>=0.27.0
|
||||
|
||||
@@ -25,6 +25,7 @@ from config import get_config, get_device, update_config
|
||||
from models import registry
|
||||
from models.yolo import detect as yolo_detect
|
||||
from models.ocr import ocr as ocr_run
|
||||
from models.vlm import query as vlm_query
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -72,6 +73,18 @@ class OCRResponse(BaseModel):
|
||||
results: list[OCRTextResult]
|
||||
|
||||
|
||||
class VLMRequest(BaseModel):
|
||||
image: str
|
||||
prompt: str
|
||||
model: str | None = None
|
||||
|
||||
|
||||
class VLMResponse(BaseModel):
|
||||
brand: str
|
||||
confidence: float
|
||||
reasoning: str
|
||||
|
||||
|
||||
class ConfigUpdate(BaseModel):
|
||||
device: str | None = None
|
||||
yolo_model: str | None = None
|
||||
@@ -170,6 +183,21 @@ def ocr(req: OCRRequest):
|
||||
return OCRResponse(results=[OCRTextResult(**r) for r in results])
|
||||
|
||||
|
||||
@app.post("/vlm", response_model=VLMResponse)
|
||||
def vlm(req: VLMRequest):
|
||||
try:
|
||||
image = _decode_image(req.image)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Bad image: {e}")
|
||||
|
||||
try:
|
||||
result = vlm_query(image, req.prompt)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"VLM failed: {e}")
|
||||
|
||||
return VLMResponse(**result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
|
||||
Reference in New Issue
Block a user