ValueError: Number of images does not match number of special image tokens in the input text. Got 256 image tokens in the text but 256 tokens from image embeddings.

#91
by zml31415 - opened

Hello everyone,
I get the error in the title when using inputs_embeds together with pixel_values. Here is a small code example to reproduce it:

import requests
from PIL import Image
from io import BytesIO
from transformers.models.gemma3 import modeling_gemma3
from transformers import AutoProcessor
import torch


model_name = "google/gemma-3-12b-it"
model = modeling_gemma3.Gemma3ForConditionalGeneration.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
device = model.device

processor = AutoProcessor.from_pretrained(model_name, use_fast=True)

img = Image.open(BytesIO(requests.get("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg").content)).convert("RGB")
prompt = ["Analyse and explain the image: <start_of_image>\n"]

inputs = processor(text=prompt, images=img, return_tensors="pt").to(device)

pixel_values = inputs.pixel_values
input_ids = inputs.input_ids

# Embedding the token ids manually and passing the embeddings instead of
# input_ids is what triggers the ValueError.
inputs_embeds = model.get_input_embeddings()(input_ids)
outputs = model(
    inputs_embeds=inputs_embeds,
    # input_ids=input_ids,  # passing input_ids instead of inputs_embeds works fine
    pixel_values=pixel_values,
    use_cache=False,
)

The issue is that in modeling_gemma3.py, lines 898 and 899 create a mask that identifies all the placeholders for the outputs of the vision tower (respectively the multi_modal_projector):

if input_ids is None:
    special_image_mask = inputs_embeds == self.get_input_embeddings()(
        torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
    )

which seems to mess up the special_image_mask, since the alternative path (lines 901-903) based on input_ids works perfectly fine:

else:
    special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
    special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
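To illustrate why the first branch can over-count, here is a toy sketch with synthetic 2-dim embeddings (nothing Gemma-specific): the elementwise == compares every coordinate separately, so any chance collision between one coordinate of a text-token embedding and the image-token embedding produces a stray True in the mask.

import torch

# Two "tokens" with 2-dim embeddings; the second row plays the image placeholder.
# The first row is a text token whose second coordinate happens to equal the
# image embedding's second coordinate.
toy_embeds = torch.tensor([[0.1, 0.2],
                           [0.5, 0.2]])
image_embed = torch.tensor([0.5, 0.2])

mask = toy_embeds == image_embed
print(mask)
# tensor([[False,  True],
#         [ True,  True]])
print(mask.sum().item())  # 3 -- one more than the 2 coordinates of the real image token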

Because of the if input_ids is None: case, the later check (line 905) that throws the error

if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():

has a weird thing with the numbers. In my case inputs_embeds[special_image_mask].numel() is 983041 while image_features.numel() is 983040, exactly 1 off. This seems weird to me, since image_features.shape is torch.Size([1, 256, 3840]), so 256 * 3840 gives the expected 983040. And inputs_embeds.shape is torch.Size([1, 269, 3840]); applying a correct mask and accounting for the 13 text tokens, the selected tensor should also be torch.Size([1, 256, 3840]), hence again 983040 elements, but it is 1 more. This is exactly the off-by-one the toy sketch above predicts: presumably a single coordinate of one text-token embedding happens to equal the corresponding coordinate of the image-token embedding. My current workaround is to feed inputs_embeds as well as input_ids and disable the check (lines 867 and 868) that forbids passing both, which effectively bypasses the strange if input_ids is None mask calculation (lines 897 to 899). Then everything works fine. But I cannot keep custom modifications in the modeling_gemma3.py code forever :)
Am I doing something wrong, or is there something strange in the modeling_gemma3.py code?
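For reference, this diagnostic sketch (reusing the variables from the repro above; config.image_token_id is the attribute name from the quoted modeling code) reproduces the count mismatch directly:

tok_id = model.config.image_token_id
image_tok_embed = model.get_input_embeddings()(
    torch.tensor(tok_id, dtype=torch.long, device=device)
)

# Buggy branch: elementwise comparison over every hidden dimension
mask_from_embeds = inputs_embeds == image_tok_embed
# Healthy branch: mask built from the token ids
mask_from_ids = (input_ids == tok_id).unsqueeze(-1).expand_as(inputs_embeds)

print(mask_from_embeds.sum().item())  # 983041 in my run: one stray coordinate match
print(mask_from_ids.sum().item())     # 983040 = 256 image tokens * 3840 hidden dims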

Hi @zml31415 ,

Thanks for reaching out to us. google/gemma-3-27b-it and google/gemma-3-12b-it are instruction-tuned (IT) models; they follow a specific prompt format and chat template to process your query, which means any IT Gemma model expects role-based instructions. Please find the following sample prompt message for your reference:

messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant."}]
    },
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG"},
            {"type": "text", "text": "What animal is on the candy?"}
        ]
    }
]

inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)
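After building the inputs this way, a typical generation call looks like the following (a minimal sketch; max_new_tokens=100 is an arbitrary choice):

with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)

# Decode only the newly generated tokens, skipping the prompt
input_len = inputs["input_ids"].shape[-1]
print(processor.decode(generation[0][input_len:], skip_special_tokens=True))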

Please adjust your prompt based on the sample message given above. Thanks for your interest in Gemma models.

Thanks.
