import numpy as np
import torch
from transformers.models.llava_next.modeling_llava_next import (
    LlavaNextForConditionalGeneration,
    get_anyres_image_grid_shape,
    unpad_image,
)
from transformers.utils.logging import get_logger

# Module logger used by the shape-mismatch warning in `pack_image_features`.
logger = get_logger(__name__)


class LlavaNextWithCustomPacking(LlavaNextForConditionalGeneration):
    """LLaVA-NeXT variant whose feature packing lets the caller choose whether
    the base-image features go before or after the unpadded grid features."""

    def pack_image_features(
        self,
        image_features,
        image_sizes,
        vision_feature_select_strategy,
        image_newline=None,
        base_image_feature_location="last",
    ):
        """
        Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.

        Args:
            image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
                List of image feature tensors, each containing the visual features of all patches (views) of one image.
                When an image has more than one view, view 0 is used as the base image and the remaining views are
                laid out on the grid returned by `get_anyres_image_grid_shape`.
            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
                Actual image size of each image (H, W).
            vision_feature_select_strategy (`str`)
                The feature selection strategy used to select the vision feature from the vision backbone.
                Only used to decide whether to emit the shape-mismatch warning.
            image_newline (`torch.Tensor` of shape `(embed_dim)`, *optional*)
                Embedding appended at the end of every row of the grid features (or once after a single-view
                image). Skipped when `None`.
            base_image_feature_location (`str`, *optional*, defaults to `"last"`)
                `"last"` appends the base-image features after the grid features; any other value prepends them.

        Returns:
            image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
                Concatenation of the packed features of every image.
            feature_lens (`torch.Tensor` of shape `(num_images,)`, dtype `torch.long`)
                Token length of each image in image_features.
        """
        vision_config = self.config.vision_config
        # Side length, in patches, of one view's square feature map (invariant across images).
        height = width = vision_config.image_size // vision_config.patch_size

        new_image_features = []
        feature_lens = []
        for image_idx, image_feature in enumerate(image_features):
            if image_feature.shape[0] > 1:
                base_image_feature = image_feature[0]
                image_feature = image_feature[1:]

                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    image_sizes[image_idx],
                    self.config.image_grid_pinpoints,
                    vision_config.image_size,
                )

                if (
                    np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
                    and vision_feature_select_strategy == "default"
                ):
                    logger.warning_once(
                        "Image feature shape does not line up with the provided patch size. "
                        "You may be using the `default` vision_feature_select_strategy with a"
                        " visual encoder that does not have CLS."
                    )

                # (grid_h * grid_w, h*w, C) -> (C, grid_h * h, grid_w * w): stitch the crops into one map.
                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                image_feature = unpad_image(image_feature, image_sizes[image_idx])

                if image_newline is not None:
                    # Add one extra column holding the newline embedding at the end of every row.
                    image_feature = torch.cat(
                        (
                            image_feature,
                            image_newline[:, None, None]
                            .expand(*image_feature.shape[:-1], 1)
                            .to(image_feature.device, image_feature.dtype),
                        ),
                        dim=-1,
                    )

                # (C, H, W) -> (H * W, C) token sequence.
                image_feature = image_feature.flatten(1, 2).transpose(0, 1)

                if base_image_feature_location == "last":
                    image_feature = torch.cat((image_feature, base_image_feature), dim=0)
                else:
                    image_feature = torch.cat((base_image_feature, image_feature), dim=0)
            else:
                image_feature = image_feature[0]
                if image_newline is not None:
                    image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)

            new_image_features.append(image_feature)
            feature_lens.append(image_feature.size(0))

        image_features = torch.cat(new_image_features, dim=0)
        feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
        return image_features, feature_lens


def main():
    """Smoke-test `pack_image_features` with random features shaped to match the checkpoint config."""
    from transformers import AutoConfig

    # Load config and model
    model_id = "ibm-granite/granite-vision-3.1-2b-preview"
    config = AutoConfig.from_pretrained(model_id)
    model = LlavaNextWithCustomPacking.from_pretrained(model_id, config=config)

    vision_config = model.config.vision_config
    embed_dim = model.config.text_config.hidden_size
    # Tokens per view must equal (image_size // patch_size) ** 2 for the reshape in
    # pack_image_features to succeed (e.g. 729 for a 384px encoder with 14px patches).
    side = vision_config.image_size // vision_config.patch_size
    tokens_per_view = side * side

    image_sizes = torch.tensor([[384, 384], [384, 384]])  # H, W for each image

    # Each image needs exactly 1 base view + grid_h * grid_w crop views, where the
    # grid is chosen from the config's pinpoints for that image's size.
    image_features = []
    for size in image_sizes:
        grid_h, grid_w = get_anyres_image_grid_shape(
            size,
            model.config.image_grid_pinpoints,
            vision_config.image_size,
        )
        num_views = 1 + grid_h * grid_w
        image_features.append(torch.randn(num_views, tokens_per_view, embed_dim))

    # Call overridden pack_image_features; no autograd graph is needed for a shape check.
    with torch.no_grad():
        packed_feats, lengths = model.pack_image_features(
            image_features=image_features,
            image_sizes=image_sizes,
            vision_feature_select_strategy="default",
            image_newline=model.image_newline,
            base_image_feature_location="last",
        )

    print("Packed features shape:", packed_feats.shape)
    print("Feature lengths:", lengths)


if __name__ == "__main__":
    main()