Adirazgold committed
Commit d0eb3d9 · verified · 1 Parent(s): 7de027c

Update modeling_colgranitevision.py

Files changed (1): modeling_colgranitevision.py (+87 -10)
modeling_colgranitevision.py CHANGED
@@ -1,18 +1,95 @@
 from typing import ClassVar, Optional
 
+import numpy as np
 import torch
 from torch import nn
-from transformers import LlavaNextConfig, \
-    LlavaNextPreTrainedModel
+from transformers import LlavaNextPreTrainedModel
+from transformers.models.llava_next.modeling_llava_next import LlavaNextForConditionalGeneration
+from transformers.models.llava_next.modeling_llava_next import unpad_image, get_anyres_image_grid_shape
 
-from .custom_llava_next import LlavaNextWithCustomPacking as LlavaNextForConditionalGeneration
-from .colgranitevision_config import ColGraniteVisionConfig
+from .colgranitevision_config import ColGraniteVisionConfig
 
-# from transformers.models.paligemma.modeling_paligemma import (
-#     PaliGemmaConfig,
-#     PaliGemmaForConditionalGeneration,
-#     PaliGemmaPreTrainedModel,
-# )
+
+class LlavaNextWithCustomPacking(LlavaNextForConditionalGeneration):
+    def pack_image_features(
+        self,
+        image_features,
+        image_sizes,
+        vision_feature_select_strategy,
+        image_newline=None,
+        base_image_feature_location="last",
+    ):
+        """
+        Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
+
+        Args:
+            image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
+                List of image feature tensors, each containing all the visual features of all patches.
+            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
+                Actual image size of each image (H, W).
+            vision_feature_select_strategy (`str`)
+                The feature selection strategy used to select the vision feature from the vision backbone.
+            image_newline (`torch.Tensor` of shape `(embed_dim)`)
+                New line embedding vector.
+        Returns:
+            image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
+            feature_lens (`List[int]`)
+                Token length of each image in image_features.
+        """
+
+        new_image_features = []
+        feature_lens = []
+        for image_idx, image_feature in enumerate(image_features):
+            if image_feature.shape[0] > 1:
+                base_image_feature = image_feature[0]
+                image_feature = image_feature[1:]
+                height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
+
+                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+                    image_sizes[image_idx],
+                    self.config.image_grid_pinpoints,
+                    self.config.vision_config.image_size,
+                )
+
+                if (
+                    np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
+                    and vision_feature_select_strategy == "default"
+                ):
+                    logger.warning_once(
+                        "Image feature shape does not line up with the provided patch size. "
+                        "You may be using the `default` vision_feature_select_strategy with a"
+                        " visual encoder that does not have CLS."
+                    )
+
+                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                if image_newline is not None:
+                    image_feature = torch.cat(
+                        (
+                            image_feature,
+                            image_newline[:, None, None]
+                            .expand(*image_feature.shape[:-1], 1)
+                            .to(image_feature.device, image_feature.dtype),
+                        ),
+                        dim=-1,
+                    )
+                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                if base_image_feature_location == "last":
+                    image_feature = torch.cat((image_feature, base_image_feature), dim=0)
+                else:
+                    image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+
+            else:
+                image_feature = image_feature[0]
+                if image_newline is not None:
+                    image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
+            new_image_features.append(image_feature)
+            feature_lens.append(image_feature.size(0))
+        image_features = torch.cat(new_image_features, dim=0)
+        feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
+        return image_features, feature_lens
 
 
 class ColGraniteVision(LlavaNextPreTrainedModel):
@@ -26,7 +103,7 @@ class ColGraniteVision(LlavaNextPreTrainedModel):
     def __init__(self, config: ColGraniteVisionConfig):
         super().__init__(config=config)
 
-        model = LlavaNextForConditionalGeneration(config=config)
+        model = LlavaNextWithCustomPacking(config=config)
         if model.language_model._tied_weights_keys is not None:
             self._tied_weights_keys = [f"model.language_model.{k}" for k in model.language_model._tied_weights_keys]
         self.model = model
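
One caveat in the new file: pack_image_features calls logger.warning_once, but no logger is imported or defined in any of the hunks above, so the shape-mismatch warning would raise a NameError if it ever fired. A minimal fix, following the logging boilerplate that transformers modeling files normally use, would be to add near the imports:

    from transformers.utils import logging

    # Module-level logger, as in upstream transformers modeling files.
    logger = logging.get_logger(__name__)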
 
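For intuition about the new base_image_feature_location knob: upstream LlavaNext always packs the base (low-resolution) image feature before the high-resolution tile features, while this override defaults to appending it after them. A toy sketch with dummy tensors (the shapes below are made up for readability and are not the model's real dimensions):

    import torch

    embed_dim = 8
    base = torch.zeros(4, embed_dim)   # stand-in for the base image feature (4 tokens)
    tiles = torch.ones(12, embed_dim)  # stand-in for the unpadded high-res tile features

    packed_last = torch.cat((tiles, base), dim=0)   # base_image_feature_location="last" (this commit's default)
    packed_first = torch.cat((base, tiles), dim=0)  # "first": the upstream LlavaNext order

    assert packed_last.shape == packed_first.shape == (16, embed_dim)

Either way the packed length per image is identical; only the position of the base-image tokens within each image's packed sequence changes.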