import logging
from typing import ClassVar, Optional

import numpy as np
import torch
from torch import nn
from transformers import LlavaNextPreTrainedModel
from transformers.models.llava_next.modeling_llava_next import LlavaNextForConditionalGeneration
from transformers.models.llava_next.modeling_llava_next import unpad_image, get_anyres_image_grid_shape

from .granite_vision_embedding_config import GraniteVisionEmbConfig

logger = logging.getLogger(__name__)


class LlavaNextWithCustomPacking(LlavaNextForConditionalGeneration):
    """LLaVA-NeXT model whose image-feature packing can place the base image feature last.

    The placement is controlled by ``config.base_image_feature_location``:
    ``"last"`` appends the base feature after the tiled features; any other
    value prepends it.
    """

    def pack_image_features(
        self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None
    ):
        """
        Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.

        Args:
            image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
                List of image feature tensor, each contains all the visual feature of all patches.
                For images with more than one patch, patch 0 is treated as the base image feature.
            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
                Actual image size of each images (H, W).
            vision_feature_select_strategy (`str`)
                The feature selection strategy used to select the vision feature from the vision backbone.
            image_newline (`torch.Tensor` of shape `(embed_dim)`)
                New line embedding vector.

        Returns:
            image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
            feature_lens (`torch.Tensor` of dtype `torch.long`, shape `(num_images,)`)
                token length of each image in image_features
        """
        base_image_feature_location = self.config.base_image_feature_location
        new_image_features = []
        feature_lens = []
        for image_idx, image_feature in enumerate(image_features):
            if image_feature.shape[0] > 1:
                base_image_feature = image_feature[0]
                image_feature = image_feature[1:]
                height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size

                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    image_sizes[image_idx],
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )

                if (
                    np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
                    and vision_feature_select_strategy == "default"
                ):
                    logger.warning(
                        "Image feature shape does not line up with the provided patch size. "
                        "You may be using the `default` vision_feature_select_strategy with a"
                        " visual encoder that does not have CLS."
                    )

                # Rearrange the tiled patches into one (embed_dim, H, W) feature map so the
                # padding introduced by anyres resizing can be removed by unpad_image.
                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                image_feature = unpad_image(image_feature, image_sizes[image_idx])
                if image_newline is not None:
                    # Append one newline vector at the end of every row of the feature map.
                    image_feature = torch.cat(
                        (
                            image_feature,
                            image_newline[:, None, None]
                            .expand(*image_feature.shape[:-1], 1)
                            .to(image_feature.device, image_feature.dtype),
                        ),
                        dim=-1,
                    )
                # (embed_dim, H, W') -> (H * W', embed_dim)
                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                if base_image_feature_location == "last":
                    image_feature = torch.cat((image_feature, base_image_feature), dim=0)
                else:
                    image_feature = torch.cat((base_image_feature, image_feature), dim=0)
            else:
                image_feature = image_feature[0]
                if image_newline is not None:
                    image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
            new_image_features.append(image_feature)
            feature_lens.append(image_feature.size(0))
        image_features = torch.cat(new_image_features, dim=0)
        feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
        return image_features, feature_lens


def _last_true_indices(mask: torch.Tensor, k: int) -> torch.Tensor:
    """Return, per row, the column indices of the last ``k`` True entries of ``mask`` in ascending order.

    Args:
        mask: 2-D boolean tensor of shape ``(N, M)``.
        k: number of indices to keep per row.

    Returns:
        ``torch.long`` tensor of shape ``(N, k)``.

    Raises:
        ValueError: if any row has fewer than ``k`` True entries. Without this check,
            ``topk`` would return ``-1`` fill values (or fail outright when ``M < k``),
            and a later ``gather`` would fail with an opaque index error.
    """
    # NOTE: converting to a Python bool forces a device sync; it is one scalar per batch.
    if bool((mask.sum(dim=1) < k).any()):
        raise ValueError(
            f"Every sequence must contain at least {k} image tokens; "
            f"got per-row counts {mask.sum(dim=1).tolist()}."
        )
    positions = torch.arange(mask.size(1), device=mask.device).expand_as(mask)
    # Non-matching positions become -1. That sorts below every real index, so topk
    # keeps the largest (i.e. the last) matching positions.
    masked_positions = torch.where(mask, positions, torch.full_like(positions, -1))
    top_positions, _ = torch.topk(masked_positions, k=k, dim=1)
    return torch.sort(top_positions, dim=1).values


class GraniteVisionEmb(LlavaNextPreTrainedModel):
    """
    GraniteVisionEmb model implementation.

    Wraps a ``LlavaNextWithCustomPacking`` model and maps the last hidden state of
    each kept position to an L2-normalised vector of size ``self.dim`` via
    ``custom_text_proj``.
    """

    main_input_name: ClassVar[str] = "doc_input_ids"

    # Number of trailing image-token positions kept per row when pixel_values are given.
    # NOTE(review): 729 presumably equals the base image feature length (e.g. a 27x27
    # patch grid), which pack_image_features places last when
    # base_image_feature_location == "last" -- confirm against the vision config.
    num_image_embedding_tokens: ClassVar[int] = 729

    # transformers-related
    config_class = GraniteVisionEmbConfig

    def __init__(self, config: GraniteVisionEmbConfig):
        super().__init__(config=config)
        model = LlavaNextWithCustomPacking(config=config)
        if model.language_model._tied_weights_keys is not None:
            # Re-prefix the language model's tied keys with this wrapper's attribute path.
            self._tied_weights_keys = [f"model.language_model.{k}" for k in model.language_model._tied_weights_keys]
        self.model = model

        self.dim = 128
        self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim)

        self.post_init()

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Compute L2-normalised, attention-masked multi-vector embeddings.

        All arguments are forwarded to the wrapped model. The wrapped model is always
        called with ``output_hidden_states=True``; any caller-supplied value is dropped.
        When ``pixel_values`` is given, only the hidden states at the last
        ``num_image_embedding_tokens`` image-token positions of each row are kept.
        Otherwise every position is kept.

        Returns:
            Tensor of shape ``(batch_size, num_kept_positions, self.dim)``. Positions
            whose attention mask is 0 are zeroed.

        Raises:
            ValueError: if ``pixel_values`` is given without ``input_ids``, or if a row
                has fewer than ``num_image_embedding_tokens`` image tokens.
        """
        # Hidden states are always needed, so the caller's flag is overridden.
        kwargs.pop("output_hidden_states", None)
        has_images = kwargs.get("pixel_values") is not None
        if has_images:
            kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype=self.dtype)

        outputs = self.model(*args, output_hidden_states=True, **kwargs)
        last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)

        attention_mask = kwargs.get("attention_mask")
        if attention_mask is None:
            # No mask supplied: treat every position as valid.
            attention_mask = torch.ones(
                last_hidden_states.shape[:2], dtype=torch.long, device=last_hidden_states.device
            )

        if has_images:
            input_ids = kwargs.get("input_ids")
            if input_ids is None and args:
                # NOTE(review): assumes input_ids is the first positional argument of the wrapped model.
                input_ids = args[0]
            if input_ids is None:
                raise ValueError("`input_ids` is required to locate image tokens when `pixel_values` is given.")
            image_mask = input_ids == self.config.image_token_index
            last_k_indices = _last_true_indices(image_mask, self.num_image_embedding_tokens)
            gather_index = last_k_indices.unsqueeze(-1).expand(-1, -1, last_hidden_states.size(-1))
            last_hidden_states = torch.gather(last_hidden_states, 1, gather_index)
            attention_mask = torch.gather(attention_mask, 1, last_k_indices)

        proj = self.custom_text_proj(last_hidden_states)  # (batch_size, sequence_length, dim)
        # L2 normalization; the epsilon avoids division by zero.
        proj = proj / (proj.norm(dim=-1, keepdim=True) + 1e-8)
        proj = proj * attention_mask.unsqueeze(-1)  # (batch_size, sequence_length, dim)
        return proj

    def get_input_embeddings(self):
        return self.model.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.model.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.model.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.model.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.model.language_model.get_decoder()

    def tie_weights(self):
        return self.model.language_model.tie_weights()

    def resize_token_embeddings(
        self,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of=None,
    ) -> nn.Embedding:
        """Resize the language model's token embeddings and sync every stored vocab size."""
        model_embeds = self.model.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)

        # Update vocab size
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        self.model.vocab_size = model_embeds.num_embeddings

        return model_embeds

    @property
    def patch_size(self) -> int:
        return self.model.vision_tower.config.patch_size