Adirazgold committed
Commit 09dc281 · verified · 1 Parent(s): 6b2247e

Delete processing_colgranitevision.py

Files changed (1)
  1. processing_colgranitevision.py +0 -396
processing_colgranitevision.py DELETED
@@ -1,396 +0,0 @@
- import math
- from typing import ClassVar, List, Optional, Tuple, Union
-
- import torch
- from PIL import Image, ImageOps
- from transformers import BatchFeature, LlavaNextProcessor
-
-
- def round_by_factor(number: float, factor: int) -> int:
-     """Returns the closest integer to 'number' that is divisible by 'factor'."""
-     return round(number / factor) * factor
-
-
- def ceil_by_factor(number: float, factor: int) -> int:
-     """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
-     return math.ceil(number / factor) * factor
-
-
- def floor_by_factor(number: float, factor: int) -> int:
-     """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
-     return math.floor(number / factor) * factor
-
-
- class ColGraniteVisionProcessor(LlavaNextProcessor):
-     """
-     Processor for ColPali.
-     """
-
-     visual_prompt_prefix: ClassVar[str] = "<|user|>\n<image>\nDescribe the image.\n"
-     system_message: ClassVar[
-         str] = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
-     query_prefix: ClassVar[str] = "Query: "
-     query_start: ClassVar[str] = "<|user|>\n"
-
-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-         self.factor = 14
-         self.min_size = 384
-         self.max_size = 384 * 2
-         self.suffix_len = 10
-         self.patch_size = 14
-
-     @property
-     def query_augmentation_token(self) -> str:
-         """
-         Return the query augmentation token.
-         Query augmentation buffers are used as reasoning buffers during inference.
-         """
-         return self.tokenizer.pad_token
-
-     @staticmethod
-     def smart_resize_helper(
-         width: int,
-         height: int,
-         factor: int,
-         min_size: int,
-         max_size: int
-     ) -> Tuple[int, int]:
-         """
-         Returns the resized image dimensions such that:
-         1. The smaller dimension is set to 'min_size'.
-         2. The larger dimension is scaled proportionally to maintain aspect ratio.
-         3. If the larger dimension exceeds 'max_size', it is clipped to 'max_size',
-            and the smaller dimension is adjusted accordingly to maintain aspect ratio.
-         4. Both dimensions are divisible by 'factor'.
-         """
-
-         # Determine scale factor based on min_size
-         if height < width:
-             scale_factor = min_size / height
-         else:
-             scale_factor = min_size / width
-
-         new_width = round(width * scale_factor)
-         new_height = round(height * scale_factor)
-
-         # If the longer dimension exceeds max_size, adjust accordingly
-         if max(new_width, new_height) > max_size:
-             clip_factor = max_size / max(new_width, new_height)
-             new_width = round(new_width * clip_factor)
-             new_height = round(new_height * clip_factor)
-
-         # Ensure dimensions are divisible by factor
-         # new_width = round_by_factor(new_width, factor)
-         # new_height = round_by_factor(new_height, factor)
-
-         return new_width, new_height
-
-     @staticmethod
-     def pad_image_center(image: Image.Image,
-                          target_width: int,
-                          target_height: int,
-                          fill_color=(0, 0, 0)) -> Image.Image:
-         """
-         Pads the given image to be centered within the target dimensions.
-
-         :param image: PIL Image to be padded.
-         :param target_width: The desired width after padding.
-         :param target_height: The desired height after padding.
-         :param fill_color: Background color (default is black).
-         :return: Padded image with centered content.
-         """
-
-         # Get original image size
-         img_width, img_height = image.size
-
-         # Compute padding values
-         pad_left = (target_width - img_width) // 2
-         pad_top = (target_height - img_height) // 2
-         pad_right = target_width - img_width - pad_left
-         pad_bottom = target_height - img_height - pad_top
-
-         # Apply padding
-         padded_image = ImageOps.expand(image, (pad_left, pad_top, pad_right, pad_bottom), fill_color).convert("RGB")
-
-         return padded_image
-
-     def smart_resize(self, image: Image.Image) -> Image.Image:
-         """
-         Resize and convert the image to the required format.
-         """
-         image_size = image.size
-         resized_height, resized_width = self.smart_resize_helper(
-             width=image_size[0],
-             height=image_size[1],
-             factor=self.factor,
-             min_size=self.min_size,
-             max_size=self.max_size
-         )
-         return image.convert("RGB").resize((resized_width, resized_height))
-
-     def smart_resize_and_pad(self, image: Image.Image) -> Image.Image:
-         """
-         Resize and pad the image to the required format.
-         """
-         return self.resize_and_pad_centered(
-             image=image,
-             factor=self.factor,
-             min_size=self.min_size,
-             max_size=self.max_size,
-             fill_color=0
-         )
-
-     def resize_and_pad_centered(self,
-                                 image: Image.Image,
-                                 factor: int,
-                                 min_size: int,
-                                 max_size: int,
-                                 fill_color=0
-                                 ) -> Image.Image:
-         """
-         Resizes and pads an image such that:
-         - The short side is set to `min_size`.
-         - The long side is scaled proportionally but clipped to `max_size`.
-         - The image is centered within the final padded area.
-
-         :param image: PIL Image
-         :param factor: Factor to make dimensions divisible by
-         :param min_size: Minimum size for the short side
-         :param max_size: Maximum allowed size for the long side
-         :param fill_color: Background padding color (default black)
-         :return: Resized and padded image
-         """
-
-         # Get original size
-         width, height = image.size
-
-         if min_size == -1 or max_size == -1:
-             return image.convert("RGB")
-
-         # Determine scale factor based on the short side (min_size)
-         if width < height:
-             scale_factor = min_size / width
-             target_width = min_size
-             max_scale_factor = min(max_size / height, scale_factor)
-             target_height = round(height * max_scale_factor)
-         else:
-             scale_factor = min_size / height
-             target_height = min_size
-             max_scale_factor = min(max_size / width, scale_factor)
-             target_width = round(width * max_scale_factor)
-
-         # Ensure the longer side does not exceed max_size
-         # if max(target_width, target_height) > max_size:
-         #     clip_factor = max_size / max(target_width, target_height)
-         #     target_width = round(target_width * clip_factor)
-         #     target_height = round(target_height * clip_factor)
-
-         # Ensure dimensions are divisible by factor
-         # target_width = round_by_factor(target_width, factor)
-         # target_height = round_by_factor(target_height, factor)
-
-         # Resize the image
-         resized_image = image.resize((target_width, target_height), Image.LANCZOS)
-
-         # Determine final padded dimensions (aligned to short side)
-         if width < height:
-             final_width, final_height = min_size, max_size
-         else:
-             final_width, final_height = max_size, min_size
-
-         # Compute padding to center the image
-         pad_left = (final_width - target_width) // 2
-         pad_top = (final_height - target_height) // 2
-         pad_right = final_width - target_width - pad_left
-         pad_bottom = final_height - target_height - pad_top
-
-         # Apply centered padding
-         # final_image = ImageOps.expand(resized_image, (pad_left, pad_top, pad_right, pad_bottom), fill_color).convert("RGB")
-         final_image = resized_image.convert("RGB")
-
-         return final_image
-
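For reference, the scaling rule that `resize_and_pad_centered` applies (short side pinned to `min_size`, long side scaled proportionally but capped at `max_size`) can be reproduced in isolation. The helper below is an illustrative sketch, not part of the deleted file; the defaults 384 and 768 mirror `self.min_size` and `self.max_size` from `__init__`:

def scaled_size(width: int, height: int, min_size: int = 384, max_size: int = 768):
    # Portrait: width is the short side and is pinned to min_size; the long side
    # uses a scale factor capped so it never exceeds max_size.
    if width < height:
        scale = min(max_size / height, min_size / width)
        return min_size, round(height * scale)
    # Landscape (or square): height is the short side.
    scale = min(max_size / width, min_size / height)
    return round(width * scale), min_size

print(scaled_size(600, 1000))  # (384, 640)  -- long side fits under max_size
print(scaled_size(600, 3000))  # (384, 768)  -- long side capped at max_size

Note that when the cap kicks in, the short side stays at `min_size`, so the aspect ratio is no longer preserved exactly; this matches the behavior of the deleted code above.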
-     def format_data(self, question, image):
-         return [
-             {
-                 "role": "system",
-                 "content": [{"type": "text", "text": self.system_message}],
-             },
-             {
-                 "role": "user",
-                 "content": [
-                     {
-                         "type": "image",
-                         "image": image,
-                     },
-                     {
-                         "type": "text",
-                         "text": question,
-                     },
-                 ],
-             }
-         ]
-
-     def format_data_wo_role(self, question, image=None):
-         return [
-             {
-                 "role": "user",
-                 "content": [
-                     {
-                         "type": "image",
-                         "image": image,
-                     },
-                     {
-                         "type": "text",
-                         "text": question,
-                     },
-                 ],
-             }
-         ]
-
-     def process_images(
-         self,
-         images: List[Image.Image],
-     ) -> BatchFeature:
-         """
-         Process images for ColPali.
-         """
-         # texts_doc = [self.apply_chat_template(self.format_data_wo_role(self.visual_prompt_prefix, img),tokenize=False ) for img in images]
-         texts_doc = [self.visual_prompt_prefix for _ in images]
-         images = [self.smart_resize_and_pad(image) for image in images]
-
-         batch_doc = self(
-             text=texts_doc,
-             images=images,
-             return_tensors="pt",
-             padding="longest",
-         )
-         return batch_doc
-
-     def process_queries(self, queries, max_length=2048, suffix=None):
-         if suffix is None:
-             suffix = self.query_augmentation_token * self.suffix_len
-
-         processed = []
-         for q in queries:
-             q = self.query_start + self.query_prefix + q
-             # truncate before it eats actual query content
-             if len(q) + len(suffix) > max_length:
-                 q = q[: max_length - len(suffix) - 1]
-             q += suffix + "\n"
-             processed.append(q)
-
-         return self(
-             text=processed,
-             images=None,
-             return_tensors="pt",
-             padding="longest",
-             truncation=True,
-             max_length=max_length,
-         )
-
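As a rough illustration of what `process_queries` sends to the tokenizer, each query is wrapped with the chat prefix, the `Query: ` marker, and a short buffer of pad tokens. The literal `"<pad>"` below is only a stand-in; the real buffer uses `self.tokenizer.pad_token`, whatever that is for the loaded checkpoint:

query = "total revenue in 2023"
suffix_len = 10  # mirrors self.suffix_len

formatted = "<|user|>\n" + "Query: " + query + "<pad>" * suffix_len + "\n"
print(formatted)
# <|user|>
# Query: total revenue in 2023<pad><pad>...<pad>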
-     def score(
-         self,
-         qs: List[torch.Tensor],
-         ps: List[torch.Tensor],
-         device: Optional[Union[str, torch.device]] = None,
-         **kwargs,
-     ) -> torch.Tensor:
-         """
-         Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
-         """
-         return self.score_multi_vector(qs, ps, device=device, **kwargs)
-
-     def get_n_patches(
-         self,
-         image_size: Tuple[int, int],
-         patch_size: int,
-     ) -> Tuple[int, int]:
-         n_patches_x = self.image_processor.size["width"] // patch_size
-         n_patches_y = self.image_processor.size["height"] // patch_size
-
-         return n_patches_x, n_patches_y
-
-     def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor:
-         return batch_images.input_ids == self.image_token_id
-
-     @staticmethod
-     def score_single_vector(
-         qs: List[torch.Tensor],
-         ps: List[torch.Tensor],
-         device: Optional[Union[str, torch.device]] = None,
-     ) -> torch.Tensor:
-         """
-         Compute the dot product score for the given single-vector query and passage embeddings.
-         """
-
-         if len(qs) == 0:
-             raise ValueError("No queries provided")
-         if len(ps) == 0:
-             raise ValueError("No passages provided")
-
-         qs_stacked = torch.stack(qs).to(device)
-         ps_stacked = torch.stack(ps).to(device)
-
-         scores = torch.einsum("bd,cd->bc", qs_stacked, ps_stacked)
-         assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"
-
-         scores = scores.to(torch.float32)
-         return scores
-
-     @staticmethod
-     def score_multi_vector(
-         qs: Union[torch.Tensor, List[torch.Tensor]],
-         ps: Union[torch.Tensor, List[torch.Tensor]],
-         batch_size: int = 128,
-         device: Optional[Union[str, torch.device]] = None,
-     ) -> torch.Tensor:
-         """
-         Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
-         query embeddings (`qs`) and passage embeddings (`ps`). For ColPali, a passage is the
-         image of a document page.
-
-         Because the embedding tensors are multi-vector and can thus have different shapes, they
-         should be fed as:
-         (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim)
-         (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually
-             obtained by padding the list of tensors.
-
-         Args:
-             qs (`Union[torch.Tensor, List[torch.Tensor]`): Query embeddings.
-             ps (`Union[torch.Tensor, List[torch.Tensor]`): Passage embeddings.
-             batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
-             device (`Union[str, torch.device]`, *optional*): Device to use for computation. If not
-                 provided, uses `get_torch_device("auto")`.
-
-         Returns:
-             `torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score
-             tensor is saved on the "cpu" device.
-         """
-
-         if len(qs) == 0:
-             raise ValueError("No queries provided")
-         if len(ps) == 0:
-             raise ValueError("No passages provided")
-
-         scores_list: List[torch.Tensor] = []
-
-         for i in range(0, len(qs), batch_size):
-             scores_batch = []
-             qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i: i + batch_size], batch_first=True, padding_value=0).to(
-                 device
-             )
-             for j in range(0, len(ps), batch_size):
-                 ps_batch = torch.nn.utils.rnn.pad_sequence(
-                     ps[j: j + batch_size], batch_first=True, padding_value=0
-                 ).to(device)
-                 scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2))
-             scores_batch = torch.cat(scores_batch, dim=1).cpu()
-             scores_list.append(scores_batch)
-
-         scores = torch.cat(scores_list, dim=0)
-         assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"
-
-         scores = scores.to(torch.float32)
-         return scores
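Finally, a minimal self-contained sketch of the MaxSim reduction that `score_multi_vector` performs, using random tensors in place of real query and page embeddings (the sequence lengths and the embedding dimension of 128 are made up for illustration, not taken from the model config):

import torch

q = torch.randn(2, 12, 128)   # (n_queries, query_tokens, dim)
p = torch.randn(3, 700, 128)  # (n_passages, page_tokens, dim)

# Same einsum as above: for every query token, take its best-matching page
# vector (max over page tokens), then sum those maxima over the query tokens.
sim = torch.einsum("bnd,csd->bcns", q, p)  # (2, 3, 12, 700)
scores = sim.max(dim=3).values.sum(dim=2)  # (2, 3), one score per query/page pair
print(scores.shape)  # torch.Size([2, 3])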