This commit is contained in:
2026-03-26 02:54:56 -03:00
parent dfa3c12514
commit 08b67f2bb7
21 changed files with 1622 additions and 16 deletions

100
gpu/models/vlm.py Normal file
View File

@@ -0,0 +1,100 @@
"""moondream2 visual language model wrapper."""
from __future__ import annotations
import logging
from models import registry
from config import get_config
logger = logging.getLogger(__name__)
_MODEL_KEY = "vlm_moondream2"
def _load():
    """Load moondream2 and its tokenizer, cache them in the registry, and return the wrapper dict."""
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    target_device = get_config().get("device", "auto")
    if target_device == "auto":
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info("Loading moondream2 (device=%s)...", target_device)

    model_id = "vikhyatk/moondream2"
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    # fp16 only when running on GPU; CPU inference stays in fp32.
    # NOTE(review): the `dtype=` kwarg needs a recent transformers release;
    # older versions spell it `torch_dtype=` — confirm the pinned version.
    use_half = "cuda" in target_device
    mdl = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        dtype=torch.float16 if use_half else torch.float32,
        device_map=target_device,
    )
    bundle = {"model": mdl, "tokenizer": tok}
    registry.put(_MODEL_KEY, bundle)
    logger.info("moondream2 loaded")
    return bundle
def _get():
    """Return the cached moondream2 wrapper, loading it on first use."""
    cached = registry.get(_MODEL_KEY)
    return cached if cached is not None else _load()
def query(image, prompt: str) -> dict:
    """
    Query moondream2 with an image crop and prompt.
    Returns {"brand": str, "confidence": float, "reasoning": str}
    """
    from PIL import Image as PILImage

    bundle = _get()
    # Accept either a PIL image or a numpy array (converted here).
    if not isinstance(image, PILImage.Image):
        image = PILImage.fromarray(image)
    encoded = bundle["model"].encode_image(image)
    raw_answer = bundle["model"].answer_question(encoded, prompt, bundle["tokenizer"])
    # moondream2 answers in free text; map it onto the structured schema.
    return _parse_vlm_response(raw_answer)
def _parse_vlm_response(answer: str) -> dict:
"""
Parse moondream2 free-text response into structured output.
Expected format from prompt: "brand, confidence (0-1), reasoning"
Falls back gracefully if format doesn't match.
"""
answer = answer.strip()
parts = [p.strip() for p in answer.split(",", 2)]
brand = parts[0] if parts else ""
confidence = 0.5
reasoning = answer
if len(parts) >= 2:
try:
confidence = float(parts[1])
confidence = max(0.0, min(1.0, confidence))
except ValueError:
pass
if len(parts) >= 3:
reasoning = parts[2]
return {
"brand": brand,
"confidence": confidence,
"reasoning": reasoning,
}