Update modeling_colgranitevision.py
modeling_colgranitevision.py +87 -10
modeling_colgranitevision.py
CHANGED
@@ -1,18 +1,95 @@
 from typing import ClassVar, Optional
 
+import numpy as np
 import torch
 from torch import nn
-from transformers import
-
+from transformers import LlavaNextPreTrainedModel
+from transformers.models.llava_next.modeling_llava_next import LlavaNextForConditionalGeneration
+from transformers.models.llava_next.modeling_llava_next import unpad_image, get_anyres_image_grid_shape
 
-from .
-from .colgranitevision_config import ColGraniteVisionConfig
+from .colgranitevision_config import ColGraniteVisionConfig
 
-
-
-
-
-
+
+class LlavaNextWithCustomPacking(LlavaNextForConditionalGeneration):
+    def pack_image_features(
+        self,
+        image_features,
+        image_sizes,
+        vision_feature_select_strategy,
+        image_newline=None,
+        base_image_feature_location="last",
+    ):
+        """
+        Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
+
+        Args:
+            image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
+                List of image feature tensors; each contains all the visual features of all patches.
+            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
+                Actual image size of each image (H, W).
+            vision_feature_select_strategy (`str`)
+                The feature selection strategy used to select the vision feature from the vision backbone.
+            image_newline (`torch.Tensor` of shape `(embed_dim)`)
+                New line embedding vector.
+        Returns:
+            image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
+            feature_lens (`List[int]`)
+                Token length of each image in image_features.
+        """
+
+        new_image_features = []
+        feature_lens = []
+        for image_idx, image_feature in enumerate(image_features):
+            if image_feature.shape[0] > 1:
+                base_image_feature = image_feature[0]
+                image_feature = image_feature[1:]
+                height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
+
+                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+                    image_sizes[image_idx],
+                    self.config.image_grid_pinpoints,
+                    self.config.vision_config.image_size,
+                )
+
+                if (
+                    np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
+                    and vision_feature_select_strategy == "default"
+                ):
+                    logger.warning_once(
+                        "Image feature shape does not line up with the provided patch size. "
+                        "You may be using the `default` vision_feature_select_strategy with a"
+                        " visual encoder that does not have CLS."
+                    )
+
+                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                if image_newline is not None:
+                    image_feature = torch.cat(
+                        (
+                            image_feature,
+                            image_newline[:, None, None]
+                            .expand(*image_feature.shape[:-1], 1)
+                            .to(image_feature.device, image_feature.dtype),
+                        ),
+                        dim=-1,
+                    )
+                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                if base_image_feature_location == "last":
+                    image_feature = torch.cat((image_feature, base_image_feature), dim=0)
+                else:
+                    image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+
+            else:
+                image_feature = image_feature[0]
+                if image_newline is not None:
+                    image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
+            new_image_features.append(image_feature)
+            feature_lens.append(image_feature.size(0))
+        image_features = torch.cat(new_image_features, dim=0)
+        feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
+        return image_features, feature_lens
 
 
 class ColGraniteVision(LlavaNextPreTrainedModel):
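The override above reproduces the upstream LlavaNext packing logic and changes only where the base (global-view) image feature lands in the packed sequence: stock `LlavaNextForConditionalGeneration.pack_image_features` prepends it to the high-resolution patch features, while the new `base_image_feature_location="last"` default appends it after them. A minimal sketch of that ordering difference, with toy tensors standing in for real vision-tower outputs (names and shapes here are illustrative, not from the repo):

# Toy sketch of the two packing orders; shapes and names are made up for illustration.
import torch

embed_dim = 4
base_tokens = torch.zeros(3, embed_dim)   # stand-in for the base (global-view) feature
patch_tokens = torch.ones(5, embed_dim)   # stand-in for the unpadded high-res patch features

packed_base_first = torch.cat((base_tokens, patch_tokens), dim=0)  # upstream LlavaNext order
packed_base_last = torch.cat((patch_tokens, base_tokens), dim=0)   # base_image_feature_location="last"

assert packed_base_first[:3].eq(0).all()   # base tokens lead
assert packed_base_last[-3:].eq(0).all()   # base tokens trail

Either way the method returns the same `(all_feat_len, embed_dim)` tensor plus per-image `feature_lens`; only the token order inside each image's span changes.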
@@ -26,7 +103,7 @@ class ColGraniteVision(LlavaNextPreTrainedModel):
     def __init__(self, config: ColGraniteVisionConfig):
         super().__init__(config=config)
 
-        model =
+        model = LlavaNextWithCustomPacking(config=config)
         if model.language_model._tied_weights_keys is not None:
             self._tied_weights_keys = [f"model.language_model.{k}" for k in model.language_model._tied_weights_keys]
         self.model = model
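With the constructor now building `LlavaNextWithCustomPacking(config=config)`, every `ColGraniteVision` instance picks up the "base last" packing by default. One loose end worth flagging: the added `pack_image_features` calls `logger.warning_once`, but this hunk neither imports nor creates a `logger`. If the rest of modeling_colgranitevision.py does not already define one, the usual transformers pattern would be needed at module level (shown here as an assumption, not as part of the commit):

# Standard transformers logger setup; only needed if `logger` is not already
# defined elsewhere in modeling_colgranitevision.py.
from transformers.utils import logging

logger = logging.get_logger(__name__)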