Adirazgold committed on
Commit
27836c1
·
verified ·
1 Parent(s): d957813

Upload 3 files

Browse files
modeling_colgranitevision/config.json ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "adapter_path": "",
3
+ "architectures": [
4
+ "ColGraniteVision"
5
+ ],
6
+ "auto_map": {
7
+ "AutoModelForVision2Seq": "modeling_colgranitevision.ColGraniteVision",
8
+ "AutoProcessor": "processing_colgranitevision.ColGraniteVisionProcessor"
9
+ },
10
+ "base_model": "ibm-granite/granite-vision-3.1-2b-preview",
11
+ "emb_dim_doc": 128,
12
+ "emb_dim_query": 128,
13
+ "image_grid_pinpoints": [
14
+ [
15
+ 384,
16
+ 768
17
+ ],
18
+ [
19
+ 384,
20
+ 1152
21
+ ],
22
+ [
23
+ 384,
24
+ 1536
25
+ ],
26
+ [
27
+ 384,
28
+ 1920
29
+ ],
30
+ [
31
+ 384,
32
+ 2304
33
+ ],
34
+ [
35
+ 384,
36
+ 2688
37
+ ],
38
+ [
39
+ 384,
40
+ 3072
41
+ ],
42
+ [
43
+ 384,
44
+ 3456
45
+ ],
46
+ [
47
+ 384,
48
+ 3840
49
+ ],
50
+ [
51
+ 768,
52
+ 384
53
+ ],
54
+ [
55
+ 768,
56
+ 768
57
+ ],
58
+ [
59
+ 768,
60
+ 1152
61
+ ],
62
+ [
63
+ 768,
64
+ 1536
65
+ ],
66
+ [
67
+ 768,
68
+ 1920
69
+ ],
70
+ [
71
+ 1152,
72
+ 384
73
+ ],
74
+ [
75
+ 1152,
76
+ 768
77
+ ],
78
+ [
79
+ 1152,
80
+ 1152
81
+ ],
82
+ [
83
+ 1536,
84
+ 384
85
+ ],
86
+ [
87
+ 1536,
88
+ 768
89
+ ],
90
+ [
91
+ 1920,
92
+ 384
93
+ ],
94
+ [
95
+ 1920,
96
+ 768
97
+ ],
98
+ [
99
+ 2304,
100
+ 384
101
+ ],
102
+ [
103
+ 2688,
104
+ 384
105
+ ],
106
+ [
107
+ 3072,
108
+ 384
109
+ ],
110
+ [
111
+ 3456,
112
+ 384
113
+ ],
114
+ [
115
+ 3840,
116
+ 384
117
+ ]
118
+ ],
119
+ "image_seq_length": 576,
120
+ "image_token_index": 49155,
121
+ "model_type": "llava_next",
122
+ "multimodal_projector_bias": true,
123
+ "projector_hidden_act": "gelu",
124
+ "text_config": {
125
+ "architectures": [
126
+ "GraniteForCausalLM"
127
+ ],
128
+ "attention_dropout": 0.1,
129
+ "attention_multiplier": 0.015625,
130
+ "bos_token_id": 0,
131
+ "embedding_multiplier": 12.0,
132
+ "eos_token_id": 0,
133
+ "hidden_size": 2048,
134
+ "intermediate_size": 8192,
135
+ "logits_scaling": 8.0,
136
+ "max_position_embeddings": 16384,
137
+ "model_type": "granite",
138
+ "num_hidden_layers": 40,
139
+ "num_key_value_heads": 8,
140
+ "pad_token_id": 0,
141
+ "residual_multiplier": 0.22,
142
+ "rms_norm_eps": 1e-05,
143
+ "rope_theta": 300000,
144
+ "tie_word_embeddings": true,
145
+ "torch_dtype": "bfloat16",
146
+ "vocab_size": 49156
147
+ },
148
+ "torch_dtype": "float32",
149
+ "transformers_version": "4.50.0.dev0",
150
+ "use_image_newline_parameter": true,
151
+ "vision_config": {
152
+ "hidden_size": 1152,
153
+ "image_size": 384,
154
+ "intermediate_size": 4304,
155
+ "model_type": "siglip_vision_model",
156
+ "num_attention_heads": 16,
157
+ "num_hidden_layers": 27,
158
+ "patch_size": 14
159
+ },
160
+ "vision_feature_layer": [
161
+ -24,
162
+ -20,
163
+ -12,
164
+ -1
165
+ ],
166
+ "vision_feature_select_strategy": "full"
167
+ }
modeling_colgranitevision/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ddf363fafb46eb2bda156c288b4023345220f243dfa40cdf8dea7e4e3e2f6ab0
3
+ size 113333568
modeling_colgranitevision/modeling_colgranitevision.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import ClassVar, Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+ from transformers import LlavaNextConfig, \
6
+ LlavaNextPreTrainedModel
7
+
8
+ from custom_llava_next import LlavaNextWithCustomPacking as LlavaNextForConditionalGeneration
9
+
10
+
11
+ # from transformers.models.paligemma.modeling_paligemma import (
12
+ # PaliGemmaConfig,
13
+ # PaliGemmaForConditionalGeneration,
14
+ # PaliGemmaPreTrainedModel,
15
+ # )
16
+
17
+
18
class ColGraniteVision(LlavaNextPreTrainedModel):
    """
    ColGraniteVision model implementation.

    Wraps a LLaVA-Next style backbone (``self.model``) and projects the
    backbone's last hidden states through a linear layer to ``self.dim``-sized
    vectors, which are L2-normalised and multiplied by the attention mask.
    The forward pass therefore returns one embedding per kept token rather
    than logits.
    """

    main_input_name: ClassVar[str] = "doc_input_ids"  # transformers-related

    # Maximum number of trailing image-token positions kept per sequence when
    # pixel values are supplied.
    # NOTE(review): 729 presumably equals (image_size // patch_size) ** 2 of the
    # vision tower (27 * 27 for 384 / 14) -- confirm before deriving it from config.
    _num_kept_image_tokens: ClassVar[int] = 729

    def __init__(self, config: LlavaNextConfig):
        """Build the backbone and the embedding projection head from ``config``."""
        super().__init__(config=config)

        model = LlavaNextForConditionalGeneration(config=config)
        # Re-prefix the language model's tied-weight keys so they resolve from
        # this wrapper, where the backbone lives under ``self.model``.
        if model.language_model._tied_weights_keys is not None:
            self._tied_weights_keys = [f"model.language_model.{k}" for k in model.language_model._tied_weights_keys]
        self.model = model

        # TODO: Wait for ColPali2 to create a ColPaliConfig to allow specifying the embedding dimension.
        # We could do it now but it would break all the models trying to load the model from the checkpoint.
        self.dim = 128
        self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim)

        self.post_init()

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """
        Compute masked, L2-normalised token embeddings.

        ``attention_mask`` must be passed by keyword; ``input_ids`` must also be
        passed by keyword whenever ``pixel_values`` is given, because both are
        read back from ``kwargs`` after the backbone call (a ``KeyError`` is
        raised otherwise). All arguments are forwarded to the backbone.

        Returns:
            Tensor of shape (batch_size, kept_tokens, dim). Without pixel values
            ``kept_tokens`` is the sequence length; with pixel values it is the
            number of trailing image-token positions kept (at most
            ``_num_kept_image_tokens``). Masked positions are all zeros.
        """
        # Hidden states are always requested below, so drop any caller-supplied value
        # to avoid passing the keyword twice.
        kwargs.pop("output_hidden_states", None)

        # Treat an explicit ``pixel_values=None`` the same as an absent key.
        has_images = kwargs.get("pixel_values") is not None
        if has_images:
            kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype=self.dtype)

        outputs = self.model(*args, output_hidden_states=True, **kwargs)
        last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)

        attention_mask = kwargs["attention_mask"]
        if has_images:
            last_hidden_states, attention_mask = self._select_last_image_tokens(
                last_hidden_states, kwargs["input_ids"], attention_mask
            )

        proj = self.custom_text_proj(last_hidden_states)  # (batch_size, kept_tokens, dim)

        # L2 normalization; the epsilon guards against division by zero.
        proj = proj / (proj.norm(dim=-1, keepdim=True) + 1e-8)

        # Zero out the embeddings of masked positions.
        return proj * attention_mask.unsqueeze(-1)

    def _select_last_image_tokens(
        self,
        hidden_states: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ):
        """
        Keep only the last image-token positions of every sequence.

        For each row, the (up to) ``_num_kept_image_tokens`` highest positions
        whose id equals ``config.image_token_index`` are selected, in ascending
        order. Rows with fewer image tokens than that are left-padded: the
        padded slots point at position 0 and get a zero attention mask, so they
        contribute all-zero embeddings instead of triggering an out-of-range
        gather.

        Returns:
            ``(hidden_states, attention_mask)`` gathered along the sequence
            dimension, with shapes (batch, k, hidden) and (batch, k).
        """
        image_mask = input_ids == self.config.image_token_index
        batch_size, seq_len = image_mask.shape

        # topk raises if k exceeds the size of the dimension it searches.
        k = min(self._num_kept_image_tokens, seq_len)

        # Index matrix: each row is 0, 1, ..., seq_len - 1.
        positions = torch.arange(seq_len, device=image_mask.device).expand(batch_size, seq_len)
        # Non-image positions become -1 so topk ranks them below every real index.
        masked_positions = torch.where(image_mask, positions, torch.tensor(-1, device=image_mask.device))

        top_positions, _ = torch.topk(masked_positions, k=k, dim=1)
        kept_positions, _ = torch.sort(top_positions, dim=1)

        # -1 entries mean the row ran out of image tokens; gather cannot take
        # negative indices, so redirect them to 0 and mask them out afterwards.
        is_real = kept_positions >= 0
        safe_positions = kept_positions.clamp(min=0)

        gather_index = safe_positions.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1))
        hidden_states = torch.gather(hidden_states, 1, gather_index)
        attention_mask = torch.gather(attention_mask, 1, safe_positions) * is_real.to(attention_mask.dtype)
        return hidden_states, attention_mask

    def get_input_embeddings(self):
        """Return the input embeddings of the wrapped language model."""
        return self.model.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the input embeddings of the wrapped language model."""
        self.model.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Return the output embeddings of the wrapped language model."""
        return self.model.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        """Replace the output embeddings of the wrapped language model."""
        self.model.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        """Replace the decoder of the wrapped language model."""
        self.model.language_model.set_decoder(decoder)

    def get_decoder(self):
        """Return the decoder of the wrapped language model."""
        return self.model.language_model.get_decoder()

    def tie_weights(self):
        """Delegate weight tying to the wrapped language model."""
        return self.model.language_model.tie_weights()

    def resize_token_embeddings(
        self,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
    ) -> nn.Embedding:
        """
        Resize the language model's token embeddings and sync the vocab size.

        The resulting number of embeddings is written to
        ``config.text_config.vocab_size``, ``config.vocab_size`` and
        ``self.model.vocab_size``.

        Returns:
            The embedding module returned by the language model's resize call.
        """
        model_embeds = self.model.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)

        # Update vocab size everywhere it is mirrored.
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        self.model.vocab_size = model_embeds.num_embeddings

        return model_embeds

    @property
    def patch_size(self) -> int:
        """Patch size taken from the vision tower's config."""
        return self.model.vision_tower.config.patch_size