File size: 5,557 Bytes
c8ad458 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
from transformers.models.llava_next.modeling_llava_next import LlavaNextForConditionalGeneration
import torch
from transformers.models.llava_next.modeling_llava_next import unpad_image, get_anyres_image_grid_shape
import numpy as np
class LlavaNextWithCustomPacking(LlavaNextForConditionalGeneration):
    """LLaVA-NeXT model whose image-feature packing can put the base-image features before or after the tiles."""

    def pack_image_features(
        self,
        image_features,
        image_sizes,
        vision_feature_select_strategy,
        image_newline=None,
        base_image_feature_location="last",
    ):
        """
        Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.

        Args:
            image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
                List of image feature tensors, each containing the visual features of all patches. For an entry
                with more than one patch, patch 0 is used as the base image and the remaining patches are laid
                out on the anyres grid chosen for that image's size.
            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
                Actual size of each image (H, W).
            vision_feature_select_strategy (`str`)
                The feature selection strategy used to select the vision feature from the vision backbone.
                In this method it only decides whether a patch-count mismatch triggers a warning.
            image_newline (`torch.Tensor` of shape `(embed_dim)`, *optional*)
                New line embedding vector. For multi-patch images it is appended after every row of the
                stitched tile map; for single-patch images it is appended once after the features.
            base_image_feature_location (`str`, *optional*, defaults to `"last"`)
                Where the base-image features go relative to the tiled features: `"last"` appends them,
                `"first"` prepends them.

        Returns:
            image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
            feature_lens (`torch.Tensor` of dtype `torch.long`, shape `(num_images,)`)
                token length of each image in image_features

        Raises:
            ValueError: if `base_image_feature_location` is neither `"first"` nor `"last"`.
        """
        # Reject unknown values up front instead of silently treating them as "first".
        if base_image_feature_location not in ("first", "last"):
            raise ValueError(
                "base_image_feature_location must be 'first' or 'last', "
                f"got {base_image_feature_location!r}"
            )

        # Every tile is viewed as a square (height x width) grid of patch vectors.
        height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size

        new_image_features = []
        feature_lens = []
        for image_idx, image_feature in enumerate(image_features):
            if image_feature.shape[0] > 1:
                base_image_feature = image_feature[0]
                image_feature = image_feature[1:]
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    image_sizes[image_idx],
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                if (
                    image_feature.numel() % (num_patch_height * num_patch_width * height * width) != 0
                    and vision_feature_select_strategy == "default"
                ):
                    # This file defines no module-level `logger`; use the transformers
                    # logger, which provides `warning_once`.
                    from transformers.utils import logging as hf_logging

                    hf_logging.get_logger(__name__).warning_once(
                        "Image feature shape does not line up with the provided patch size. "
                        "You may be using the `default` vision_feature_select_strategy with a"
                        " visual encoder that does not have CLS."
                    )
                # (grid_h, grid_w, h, w, C) -> (C, grid_h * h, grid_w * w): stitch tiles into one map.
                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                image_feature = unpad_image(image_feature, image_sizes[image_idx])
                if image_newline is not None:
                    # Add one extra column holding the newline embedding, so each row ends with it.
                    image_feature = torch.cat(
                        (
                            image_feature,
                            image_newline[:, None, None]
                            .expand(*image_feature.shape[:-1], 1)
                            .to(image_feature.device, image_feature.dtype),
                        ),
                        dim=-1,
                    )
                # (C, H, W) -> (H * W, C): row-major token sequence.
                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                if base_image_feature_location == "last":
                    image_feature = torch.cat((image_feature, base_image_feature), dim=0)
                else:
                    image_feature = torch.cat((base_image_feature, image_feature), dim=0)
            else:
                image_feature = image_feature[0]
                if image_newline is not None:
                    image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
            new_image_features.append(image_feature)
            feature_lens.append(image_feature.size(0))
        image_features = torch.cat(new_image_features, dim=0)
        feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
        return image_features, feature_lens
def main():
    """Smoke-test the overridden ``pack_image_features`` with random features.

    Loads the Granite Vision config and weights from ``model_id``. It builds
    dummy per-image features whose view count matches the anyres tile grid
    the config selects for each image size. It packs them with the base-image
    features last and prints the packed shape and per-image token lengths.
    """
    from transformers import AutoConfig

    # Load config and model
    model_id = "ibm-granite/granite-vision-3.1-2b-preview"
    config = AutoConfig.from_pretrained(model_id)
    model = LlavaNextWithCustomPacking.from_pretrained(model_id, config=config)

    image_sizes = torch.tensor([[384, 384], [384, 384]])  # H, W for each image

    # pack_image_features views each tile as a side x side grid of patches,
    # so every dummy view must carry exactly side * side vectors.
    vision_cfg = model.config.vision_config
    side = vision_cfg.image_size // vision_cfg.patch_size
    num_patches = side * side
    # Feature width must match model.image_newline, which pack_image_features
    # concatenates onto the features (presumably the text hidden size; confirm).
    embed_dim = model.config.text_config.hidden_size

    # Each image has one base view plus one view per anyres tile. The tile grid
    # depends on the image size, so derive the view count from the config
    # instead of hard-coding it; a mismatch breaks the reshape in
    # pack_image_features.
    image_features = []
    for size in image_sizes:
        grid_h, grid_w = get_anyres_image_grid_shape(
            size,
            model.config.image_grid_pinpoints,
            vision_cfg.image_size,
        )
        image_features.append(torch.randn(1 + grid_h * grid_w, num_patches, embed_dim))

    # Inference only: avoid building an autograd graph through image_newline.
    with torch.no_grad():
        packed_feats, lengths = model.pack_image_features(
            image_features=image_features,
            image_sizes=image_sizes,
            vision_feature_select_strategy="default",
            image_newline=model.image_newline,
            base_image_feature_location="last",
        )
    print("Packed features shape:", packed_feats.shape)
    print("Feature lengths:", lengths)
# Run the smoke test only when executed as a script, not on import.
if __name__ == "__main__":
    main()