phase 9
This commit is contained in:
100
gpu/models/vlm.py
Normal file
100
gpu/models/vlm.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""moondream2 visual language model wrapper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from models import registry
|
||||
from config import get_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MODEL_KEY = "vlm_moondream2"
|
||||
|
||||
|
||||
def _load():
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
device = get_config().get("device", "auto")
|
||||
if device == "auto":
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
logger.info("Loading moondream2 (device=%s)...", device)
|
||||
|
||||
model_id = "vikhyatk/moondream2"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
||||
dtype = torch.float16 if "cuda" in device else torch.float32
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
trust_remote_code=True,
|
||||
dtype=dtype,
|
||||
device_map=device,
|
||||
)
|
||||
|
||||
wrapper = {"model": model, "tokenizer": tokenizer}
|
||||
registry.put(_MODEL_KEY, wrapper)
|
||||
logger.info("moondream2 loaded")
|
||||
return wrapper
|
||||
|
||||
|
||||
def _get():
|
||||
wrapper = registry.get(_MODEL_KEY)
|
||||
if wrapper is None:
|
||||
wrapper = _load()
|
||||
return wrapper
|
||||
|
||||
|
||||
def query(image, prompt: str) -> dict:
|
||||
"""
|
||||
Query moondream2 with an image crop and prompt.
|
||||
|
||||
Returns {"brand": str, "confidence": float, "reasoning": str}
|
||||
"""
|
||||
from PIL import Image as PILImage
|
||||
|
||||
wrapper = _get()
|
||||
model = wrapper["model"]
|
||||
tokenizer = wrapper["tokenizer"]
|
||||
|
||||
# Convert numpy array to PIL if needed
|
||||
if not isinstance(image, PILImage.Image):
|
||||
image = PILImage.fromarray(image)
|
||||
|
||||
enc_image = model.encode_image(image)
|
||||
answer = model.answer_question(enc_image, prompt, tokenizer)
|
||||
|
||||
# Parse response — moondream2 returns free text, extract brand + confidence
|
||||
result = _parse_vlm_response(answer)
|
||||
return result
|
||||
|
||||
|
||||
def _parse_vlm_response(answer: str) -> dict:
|
||||
"""
|
||||
Parse moondream2 free-text response into structured output.
|
||||
|
||||
Expected format from prompt: "brand, confidence (0-1), reasoning"
|
||||
Falls back gracefully if format doesn't match.
|
||||
"""
|
||||
answer = answer.strip()
|
||||
parts = [p.strip() for p in answer.split(",", 2)]
|
||||
|
||||
brand = parts[0] if parts else ""
|
||||
confidence = 0.5
|
||||
reasoning = answer
|
||||
|
||||
if len(parts) >= 2:
|
||||
try:
|
||||
confidence = float(parts[1])
|
||||
confidence = max(0.0, min(1.0, confidence))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if len(parts) >= 3:
|
||||
reasoning = parts[2]
|
||||
|
||||
return {
|
||||
"brand": brand,
|
||||
"confidence": confidence,
|
||||
"reasoning": reasoning,
|
||||
}
|
||||
Reference in New Issue
Block a user