"""moondream2 visual language model wrapper.""" from __future__ import annotations import logging from models import registry from config import get_config logger = logging.getLogger(__name__) _MODEL_KEY = "vlm_moondream2" def _load(): import torch from transformers import AutoModelForCausalLM, AutoTokenizer device = get_config().get("device", "auto") if device == "auto": device = "cuda" if torch.cuda.is_available() else "cpu" logger.info("Loading moondream2 (device=%s)...", device) model_id = "vikhyatk/moondream2" tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) dtype = torch.float16 if "cuda" in device else torch.float32 model = AutoModelForCausalLM.from_pretrained( model_id, trust_remote_code=True, dtype=dtype, device_map=device, ) wrapper = {"model": model, "tokenizer": tokenizer} registry.put(_MODEL_KEY, wrapper) logger.info("moondream2 loaded") return wrapper def _get(): wrapper = registry.get(_MODEL_KEY) if wrapper is None: wrapper = _load() return wrapper def query(image, prompt: str) -> dict: """ Query moondream2 with an image crop and prompt. Returns {"brand": str, "confidence": float, "reasoning": str} """ from PIL import Image as PILImage wrapper = _get() model = wrapper["model"] tokenizer = wrapper["tokenizer"] # Convert numpy array to PIL if needed if not isinstance(image, PILImage.Image): image = PILImage.fromarray(image) enc_image = model.encode_image(image) answer = model.answer_question(enc_image, prompt, tokenizer) # Parse response — moondream2 returns free text, extract brand + confidence result = _parse_vlm_response(answer) return result def _parse_vlm_response(answer: str) -> dict: """ Parse moondream2 free-text response into structured output. Expected format from prompt: "brand, confidence (0-1), reasoning" Falls back gracefully if format doesn't match. """ answer = answer.strip() parts = [p.strip() for p in answer.split(",", 2)] brand = parts[0] if parts else "" confidence = 0.5 reasoning = answer if len(parts) >= 2: try: confidence = float(parts[1]) confidence = max(0.0, min(1.0, confidence)) except ValueError: pass if len(parts) >= 3: reasoning = parts[2] return { "brand": brand, "confidence": confidence, "reasoning": reasoning, }