from typing import ClassVar, Optional

import numpy as np
import torch
from torch import nn
from transformers import LlavaNextPreTrainedModel
from transformers.models.llava_next.modeling_llava_next import (
    LlavaNextForConditionalGeneration,
    get_anyres_image_grid_shape,
    unpad_image,
)

from .granite_vision_embedding_config import GraniteVisionEmbConfig


class LlavaNextWithCustomPacking(LlavaNextForConditionalGeneration):
    def pack_image_features(
        self,
        image_features,
        image_sizes,
        vision_feature_select_strategy,
        image_newline=None,
    ):
""" |
|
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. |
|
|
|
Args: |
|
image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`) |
|
List of image feature tensor, each contains all the visual feature of all patches. |
|
image_sizes (`torch.Tensor` of shape `(num_images, 2)`) |
|
Actual image size of each images (H, W). |
|
vision_feature_select_strategy (`str`) |
|
The feature selection strategy used to select the vision feature from the vision backbone. |
|
image_newline (`torch.Tensor` of shape `(embed_dim)`) |
|
New line embedding vector. |
|
Returns: |
|
image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`) |
|
feature_lens (`List[int]`) |
|
token length of each image in image_features |
|
""" |
        base_image_feature_location = self.config.base_image_feature_location
        new_image_features = []
        feature_lens = []
        for image_idx, image_feature in enumerate(image_features):
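            # Multi-tile case: the first entry holds the features of the resized base image,
            # the remaining entries hold the high-resolution tiles.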
            if image_feature.shape[0] > 1:
                base_image_feature = image_feature[0]
                image_feature = image_feature[1:]
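                # Number of visual tokens per side of a tile: vision input resolution
                # divided by the ViT patch size.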
                height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size

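                # Tile layout (rows x columns) produced by the anyres preprocessing for this image size.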
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    image_sizes[image_idx],
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )

                if (
                    np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
                    and vision_feature_select_strategy == "default"
                ):
                    print(
                        "Image feature shape does not line up with the provided patch size. "
                        "You may be using the `default` vision_feature_select_strategy with a"
                        " visual encoder that does not have CLS."
                    )

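                # Reassemble the per-tile token sequences into a single 2-D grid of visual tokens
                # with shape (embed_dim, total_height, total_width), then strip the padding that
                # was added to fit the anyres tiling.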
                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                image_feature = unpad_image(image_feature, image_sizes[image_idx])
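                # Append the learned "newline" embedding at the end of each row of the grid
                # before flattening it back into a (num_tokens, embed_dim) sequence.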
                if image_newline is not None:
                    image_feature = torch.cat(
                        (
                            image_feature,
                            image_newline[:, None, None]
                            .expand(*image_feature.shape[:-1], 1)
                            .to(image_feature.device, image_feature.dtype),
                        ),
                        dim=-1,
                    )
                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
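                # Place the base (global, low-resolution) image features before or after the
                # tile features, depending on the configured `base_image_feature_location`.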
if base_image_feature_location == "last": |
|
image_feature = torch.cat((image_feature, base_image_feature), dim=0) |
|
else: |
|
image_feature = torch.cat((base_image_feature, image_feature), dim=0) |
|
|
|
else: |
|
image_feature = image_feature[0] |
|
if image_newline is not None: |
|
image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0) |
|
new_image_features.append(image_feature) |
|
feature_lens.append(image_feature.size(0)) |
|
        image_features = torch.cat(new_image_features, dim=0)
        feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
        return image_features, feature_lens

class GraniteVisionEmb(LlavaNextPreTrainedModel):
    """
    GraniteVisionEmb model implementation.
    """

    main_input_name: ClassVar[str] = "doc_input_ids"
    config_class = GraniteVisionEmbConfig

    def __init__(self, config: GraniteVisionEmbConfig):
        super().__init__(config=config)

        model = LlavaNextWithCustomPacking(config=config)
        if model.language_model._tied_weights_keys is not None:
            self._tied_weights_keys = [f"model.language_model.{k}" for k in model.language_model._tied_weights_keys]
        self.model = model

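        # Per-token projection head: each output token embedding is mapped down to a
        # `self.dim`-dimensional vector.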
        self.dim = 128
        self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim)

        self.post_init()

    def forward(self, *args, **kwargs) -> torch.Tensor:
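        # Drop any caller-supplied `output_hidden_states` so it does not clash with the
        # value forced below, and cast pixel values to the model's dtype.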
kwargs.pop("output_hidden_states", None) |
|
if "pixel_values" in kwargs: |
|
kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype=self.dtype) |
|
|
|
outputs = self.model(*args, output_hidden_states=True, **kwargs) |
|
last_hidden_states = outputs.hidden_states[-1] |
|
|
|
attention_mask = kwargs["attention_mask"] |
|
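        # When images are present, keep only the hidden states (and attention-mask entries)
        # at the last 729 image-token positions of each sequence. The image-token indices are
        # recovered from `input_ids`, the 729 largest are selected with `topk`, and then
        # re-sorted so their original order is preserved. This assumes every sequence in the
        # batch contains at least 729 image tokens.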
if "pixel_values" in kwargs: |
|
input_ids = kwargs['input_ids'] |
|
image_mask = (input_ids == self.config.image_token_index) |
|
|
|
N, M = image_mask.shape |
|
|
|
idx = torch.arange(M, device=image_mask.device).expand(N, M) |
|
|
|
masked_idx = torch.where(image_mask, idx, torch.tensor(-1, device=image_mask.device)) |
|
topk_values, _ = torch.topk(masked_idx, k=729, dim=1) |
|
last_k_indices, _ = torch.sort(topk_values, dim=1) |
|
last_k_indices_exp = last_k_indices.unsqueeze(-1).expand(-1, -1, last_hidden_states.size(-1)) |
|
last_hidden_states = torch.gather(last_hidden_states, 1, last_k_indices_exp) |
|
attention_mask = torch.gather(attention_mask, 1, last_k_indices) |
|
|
|
attention_mask = attention_mask.unsqueeze(-1) |
|
|
|
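        # Project every remaining token into the low-dimensional multi-vector embedding space,
        # L2-normalize each token vector, and zero out padded positions via the attention mask.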
        proj = self.custom_text_proj(last_hidden_states)
        proj = proj / (proj.norm(dim=-1, keepdim=True) + 1e-8)
        proj = proj * attention_mask

        return proj

    def get_input_embeddings(self):
        return self.model.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.model.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.model.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.model.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.model.language_model.get_decoder()

    def tie_weights(self):
        return self.model.language_model.tie_weights()

    def resize_token_embeddings(
        self,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of=None,
    ) -> nn.Embedding:
        model_embeds = self.model.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)

        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        self.model.vocab_size = model_embeds.num_embeddings

        return model_embeds

    @property
    def patch_size(self) -> int:
        return self.model.vision_tower.config.patch_size
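
# Minimal usage sketch (illustrative only; the checkpoint path is a placeholder and the
# processor choice is an assumption based on the LlavaNext backbone, not part of this module):
#
#   from transformers import LlavaNextProcessor
#
#   processor = LlavaNextProcessor.from_pretrained("path/to/checkpoint")
#   model = GraniteVisionEmb.from_pretrained("path/to/checkpoint", torch_dtype=torch.bfloat16)
#   inputs = processor(images=image, text="<image>", return_tensors="pt")
#   with torch.no_grad():
#       embeddings = model(**inputs)  # (batch, num_tokens, 128) multi-vector embeddings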