from typing import ClassVar, Optional

import torch
from torch import nn
from transformers import LlavaNextConfig, LlavaNextPreTrainedModel

from custom_llava_next import LlavaNextWithCustomPacking as LlavaNextForConditionalGeneration


class ColGraniteVision(LlavaNextPreTrainedModel):
    """
    ColGraniteVision model implementation.

    Wraps a LLaVA-NeXT backbone and projects its last hidden states to one
    L2-normalized 128-dimensional vector per token, the multi-vector
    representation used for late-interaction retrieval.
    """

    main_input_name: ClassVar[str] = "doc_input_ids"

    def __init__(self, config: LlavaNextConfig):
        super().__init__(config=config)

        model = LlavaNextForConditionalGeneration(config=config)
        if model.language_model._tied_weights_keys is not None:
            # Re-prefix the tied-weight keys so they resolve against this wrapper.
            self._tied_weights_keys = [f"model.language_model.{k}" for k in model.language_model._tied_weights_keys]
        self.model = model

        # Projection from the language-model hidden size down to the
        # low-dimensional embedding space used for late-interaction scoring.
        self.dim = 128
        self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim)

        self.post_init()

    def forward(self, *args, **kwargs) -> torch.Tensor:
        # Hidden states are required to build the embeddings, so drop any
        # caller-supplied flag and force output_hidden_states=True below.
        kwargs.pop("output_hidden_states", None)
        if "pixel_values" in kwargs:
            kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype=self.dtype)

        outputs = self.model(*args, output_hidden_states=True, **kwargs)
        last_hidden_states = outputs.hidden_states[-1]  # (batch, seq_len, hidden)

        attention_mask = kwargs["attention_mask"]
        if "pixel_values" in kwargs:
            # Document (image) inputs: keep only the hidden states at the last
            # 729 image-token positions; every other position, including text
            # tokens, is dropped, so each document yields a fixed-size set of
            # visual embeddings.
            input_ids = kwargs["input_ids"]
            image_mask = (input_ids == self.config.image_token_index)

            N, M = image_mask.shape
            idx = torch.arange(M, device=image_mask.device).expand(N, M)
            # Non-image positions become -1, so topk picks the 729 largest
            # (i.e. last) image-token indices in each row.
            masked_idx = torch.where(image_mask, idx, torch.tensor(-1, device=image_mask.device))
            topk_values, _ = torch.topk(masked_idx, k=729, dim=1)
            last_k_indices, _ = torch.sort(topk_values, dim=1)
            last_k_indices_exp = last_k_indices.unsqueeze(-1).expand(-1, -1, last_hidden_states.size(-1))
            last_hidden_states = torch.gather(last_hidden_states, 1, last_k_indices_exp)
            attention_mask = torch.gather(attention_mask, 1, last_k_indices)

        attention_mask = attention_mask.unsqueeze(-1)

        proj = self.custom_text_proj(last_hidden_states)
        # L2-normalize each token embedding; the epsilon guards against
        # division by zero on fully masked rows.
        proj = proj / (proj.norm(dim=-1, keepdim=True) + 1e-8)
        # Zero out the embeddings of padding positions.
        proj = proj * attention_mask
        return proj

    def get_input_embeddings(self):
        return self.model.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.model.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.model.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.model.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.model.language_model.get_decoder()

    def tie_weights(self):
        return self.model.language_model.tie_weights()

    def resize_token_embeddings(
        self,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of=None,
    ) -> nn.Embedding:
        model_embeds = self.model.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)

        # Keep every cached copy of the vocab size in sync with the resized
        # embedding matrix.
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        self.model.vocab_size = model_embeds.num_embeddings

        return model_embeds

    @property
    def patch_size(self) -> int:
        return self.model.vision_tower.config.patch_size
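
# Illustrative addition, not part of the original model: forward() returns one
# normalized vector per token, the multi-vector shape used by ColBERT-style
# late interaction, so a MaxSim scorer is sketched below under that assumption.
def maxsim_score(query_emb: torch.Tensor, doc_emb: torch.Tensor) -> torch.Tensor:
    """Late-interaction score: sum over query tokens of the best doc-token match.

    Expects query_emb of shape (num_query_tokens, dim) and doc_emb of shape
    (num_doc_tokens, dim), both already L2-normalized by forward(), so dot
    products are cosine similarities.
    """
    sim = query_emb @ doc_emb.T  # (num_query_tokens, num_doc_tokens)
    return sim.max(dim=1).values.sum()  # best doc token per query token, summed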
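
# Usage sketch (hypothetical): the checkpoint path and image file below are
# placeholders, not published artifacts, and a LlavaNext-compatible processor
# is assumed to ship with the checkpoint.
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoProcessor

    ckpt = "path/to/colgranitevision-checkpoint"  # placeholder path
    processor = AutoProcessor.from_pretrained(ckpt)
    model = ColGraniteVision.from_pretrained(ckpt).eval()

    inputs = processor(images=Image.open("page.png"), text="<image>", return_tensors="pt")
    with torch.no_grad():
        embeddings = model(**inputs)  # (batch, 729, 128) for an image input
    print(embeddings.shape)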