# granite-vision-3.3-2b-embedding / example_simple.py
from PIL import Image
import requests
from io import BytesIO
import torch
from transformers import AutoModel, AutoProcessor, AutoConfig, AutoModelForVision2Seq
# from granite_cola import ColGraniteVisionConfig, ColGraniteVision, ColGraniteVisionProcessor
# --- 1) Register your custom classes so AutoModel/AutoProcessor work out-of-the-box
# AutoConfig.register("colgranitevision", ColGraniteVisionConfig)
# AutoModel.register(ColGraniteVisionConfig, ColGraniteVision)
# AutoProcessor.register(ColGraniteVisionConfig, ColGraniteVisionProcessor)
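# (Manual registration like the commented lines above should not be needed here:
# with trust_remote_code=True below, the custom classes are resolved automatically,
# assuming the checkpoint ships the usual auto_map entries in its config.)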
# ─────────────────────────────────────────────
# 2) Load model & processor
# ─────────────────────────────────────────────
model_dir = "."
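# model_dir = "." assumes this script is run from inside the downloaded checkpoint
# directory; alternatively (untested assumption), a Hub model id could be passed
# here instead of a local path.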
model = AutoModelForVision2Seq.from_pretrained(
    model_dir,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
)
# Optionally wrap the base model with a PEFT adapter, e.g.:
# model = PeftModel.from_pretrained(model, peft_path).eval()
processor = AutoProcessor.from_pretrained(
    model_dir,
    trust_remote_code=True,
    use_fast=True,
)
# Set patch_size explicitly if needed
if hasattr(processor, "patch_size") and processor.patch_size is None:
    processor.patch_size = 14  # Default patch size for vision transformers
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()
# ─────────────────────────────────────────────
# 3) Download sample image + build a prompt containing <image>
# ─────────────────────────────────────────────
image_url = "https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg"
resp = requests.get(image_url)
resp.raise_for_status()
image = Image.open(BytesIO(resp.content)).convert("RGB")
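# If the download is not possible (e.g. offline environment), any local RGB image
# works just as well; the path below is only a placeholder:
# image = Image.open("path/to/local_image.jpg").convert("RGB")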
# ─────────────────────────────────────────────
# 4) Process image and text
# ─────────────────────────────────────────────
# Process image
image_inputs = processor.process_images([image])
image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
# Process text
text = "A photo of a tiger"
text_inputs = processor.process_queries([text])
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
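# process_queries takes a list, so several queries can be embedded in one call.
# A small sketch with hypothetical extra queries (not part of the original example):
# queries = ["A photo of a tiger", "A diagram of a neural network"]
# batched_text_inputs = processor.process_queries(queries)
# batched_text_inputs = {k: v.to(device) for k, v in batched_text_inputs.items()}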
# ─────────────────────────────────────────────
# 5) Get embeddings and score
# ─────────────────────────────────────────────
with torch.no_grad():
    # Get image embedding
    image_embedding = model(**image_inputs)
    # Get text embedding
    text_embedding = model(**text_inputs)
    # Calculate similarity score
    score = torch.matmul(text_embedding, image_embedding.T).item()
print(f"Similarity score between text and image: {score:.4f}")