"""moondream2 visual language model wrapper."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
|
|
from models import registry
|
|
from config import get_config
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
_MODEL_KEY = "vlm_moondream2"
|
|
|
|
|
|
def _load():
    """Load moondream2 from the Hugging Face hub and cache it in the registry.

    Device resolution: honors the configured "device" setting, and when it is
    "auto" picks CUDA if available, else CPU.  fp16 is used on CUDA, fp32 on CPU.

    Returns the cached wrapper dict: {"model": ..., "tokenizer": ...}.
    """
    # Heavy deps imported lazily so merely importing this module stays cheap.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    target = get_config().get("device", "auto")
    if target == "auto":
        target = "cuda" if torch.cuda.is_available() else "cpu"

    logger.info("Loading moondream2 (device=%s)...", target)

    repo = "vikhyatk/moondream2"
    # trust_remote_code is required: moondream2 ships custom model code.
    tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        repo,
        trust_remote_code=True,
        # Half precision on GPU halves memory; CPU keeps full fp32.
        dtype=torch.float16 if "cuda" in target else torch.float32,
        device_map=target,
    )

    wrapper = {"model": model, "tokenizer": tokenizer}
    registry.put(_MODEL_KEY, wrapper)
    logger.info("moondream2 loaded")
    return wrapper
|
|
|
|
|
|
def _get():
    """Return the cached moondream2 wrapper, loading it on first use."""
    cached = registry.get(_MODEL_KEY)
    if cached is not None:
        return cached
    return _load()
|
|
|
|
|
|
def query(image, prompt: str) -> dict:
    """
    Query moondream2 with an image crop and prompt.

    Returns {"brand": str, "confidence": float, "reasoning": str}
    """
    from PIL import Image as PILImage

    bundle = _get()
    model = bundle["model"]
    tokenizer = bundle["tokenizer"]

    # Accept numpy arrays as well as PIL images; encode_image wants PIL.
    if not isinstance(image, PILImage.Image):
        image = PILImage.fromarray(image)

    encoded = model.encode_image(image)
    answer = model.answer_question(encoded, prompt, tokenizer)

    # The model replies in free text; turn it into a structured dict.
    return _parse_vlm_response(answer)
|
|
|
|
|
|
def _parse_vlm_response(answer: str) -> dict:
|
|
"""
|
|
Parse moondream2 free-text response into structured output.
|
|
|
|
Expected format from prompt: "brand, confidence (0-1), reasoning"
|
|
Falls back gracefully if format doesn't match.
|
|
"""
|
|
answer = answer.strip()
|
|
parts = [p.strip() for p in answer.split(",", 2)]
|
|
|
|
brand = parts[0] if parts else ""
|
|
confidence = 0.5
|
|
reasoning = answer
|
|
|
|
if len(parts) >= 2:
|
|
try:
|
|
confidence = float(parts[1])
|
|
confidence = max(0.0, min(1.0, confidence))
|
|
except ValueError:
|
|
pass
|
|
|
|
if len(parts) >= 3:
|
|
reasoning = parts[2]
|
|
|
|
return {
|
|
"brand": brand,
|
|
"confidence": confidence,
|
|
"reasoning": reasoning,
|
|
}
|