import numpy as np
import torch
from transformers.models.llava_next.modeling_llava_next import (
    LlavaNextForConditionalGeneration,
    get_anyres_image_grid_shape,
    unpad_image,
)
from transformers.utils import logging

logger = logging.get_logger(__name__)


class LlavaNextWithCustomPacking(LlavaNextForConditionalGeneration):
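    """
    `LlavaNextForConditionalGeneration` subclass whose `pack_image_features` takes an extra
    `base_image_feature_location` argument, so the global (base) image feature can be appended
    after the high-resolution tile features (`"last"`) instead of prepended as in the upstream
    implementation.
    """
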
    def pack_image_features(
        self,
        image_features,
        image_sizes,
        vision_feature_select_strategy,
        image_newline=None,
        base_image_feature_location="last",
    ):
        """

        Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.



        Args:

            image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)

                List of image feature tensor, each contains all the visual feature of all patches.

            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)

                Actual image size of each images (H, W).

            vision_feature_select_strategy (`str`)

                The feature selection strategy used to select the vision feature from the vision backbone.

            image_newline (`torch.Tensor` of shape `(embed_dim)`)

                New line embedding vector.

        Returns:

            image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)

            feature_lens (`List[int]`)

                token length of each image in image_features

        """


        new_image_features = []
        feature_lens = []
        for image_idx, image_feature in enumerate(image_features):
            if image_feature.shape[0] > 1:
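                # The first view is the low-resolution "base" image; the remaining views are the high-resolution tiles.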
                base_image_feature = image_feature[0]
                image_feature = image_feature[1:]
                height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size

                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    image_sizes[image_idx],
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )

                if (
                    np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
                    and vision_feature_select_strategy == "default"
                ):
                    logger.warning_once(
                        "Image feature shape does not line up with the provided patch size. "
                        "You may be using the `default` vision_feature_select_strategy with a"
                        " visual encoder that does not have CLS."
                    )

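                # Rearrange the flat tile features into a (grid_h, grid_w, h, w, dim) layout, then stitch the
                # tiles back together into a single 2D feature map of shape (dim, grid_h * h, grid_w * w).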
                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                image_feature = unpad_image(image_feature, image_sizes[image_idx])
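                # Append the learned "newline" embedding to the end of every row of the feature map so the
                # language model can tell where each image row ends.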
                if image_newline is not None:
                    image_feature = torch.cat(
                        (
                            image_feature,
                            image_newline[:, None, None]
                            .expand(*image_feature.shape[:-1], 1)
                            .to(image_feature.device, image_feature.dtype),
                        ),
                        dim=-1,
                    )
                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
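                # Behavioural change vs. the upstream implementation: optionally append the base (global)
                # image feature after the tile features instead of prepending it.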
                if base_image_feature_location == "last":
                    image_feature = torch.cat((image_feature, base_image_feature), dim=0)
                else:
                    image_feature = torch.cat((base_image_feature, image_feature), dim=0)

            else:
                image_feature = image_feature[0]
                if image_newline is not None:
                    image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
            new_image_features.append(image_feature)
            feature_lens.append(image_feature.size(0))
        image_features = torch.cat(new_image_features, dim=0)
        feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
        return image_features, feature_lens



def main():
    from transformers import AutoConfig

    # Load config and model
    model_id = "ibm-granite/granite-vision-3.1-2b-preview"
    config = AutoConfig.from_pretrained(model_id)
    model = LlavaNextWithCustomPacking.from_pretrained(model_id, config=config)

    # Dummy image features: one low-resolution base view per image plus the high-resolution tile views
    # implied by `image_sizes`, the model's image_grid_pinpoints and its vision backbone resolution.
    B = 2  # batch size
    image_sizes = torch.tensor([[384, 384], [384, 384]])  # (H, W) of each original image
    patch_side = model.config.vision_config.image_size // model.config.vision_config.patch_size
    num_patches = patch_side * patch_side  # visual tokens per view (e.g. 27 * 27 = 729 for a 384px / 14px backbone)
    embed_dim = model.config.text_config.hidden_size
    image_features = []
    for idx in range(B):
        grid_h, grid_w = get_anyres_image_grid_shape(
            image_sizes[idx], model.config.image_grid_pinpoints, model.config.vision_config.image_size
        )
        num_views = 1 + grid_h * grid_w  # one base view plus one view per tile
        image_features.append(torch.randn(num_views, num_patches, embed_dim))

    # Call overridden pack_image_features
    packed_feats, lengths = model.pack_image_features(
        image_features=image_features,
        image_sizes=image_sizes,
        vision_feature_select_strategy="default",
        image_newline=model.image_newline,
        base_image_feature_location="last",
    )
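    # `packed_feats` stacks all per-image visual tokens along dim 0; `lengths` holds each image's token count.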

    print("Packed features shape:", packed_feats.shape)
    print("Feature lengths:", lengths)

if __name__ == "__main__":
    main()