import math from typing import ClassVar, List, Optional, Tuple, Union import torch from PIL import Image, ImageOps from transformers import BatchFeature, LlavaNextProcessor def round_by_factor(number: float, factor: int) -> int: """Returns the closest integer to 'number' that is divisible by 'factor'.""" return round(number / factor) * factor def ceil_by_factor(number: float, factor: int) -> int: """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" return math.ceil(number / factor) * factor def floor_by_factor(number: float, factor: int) -> int: """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" return math.floor(number / factor) * factor class GraniteVisionEmbProcessor(LlavaNextProcessor): """ Processor for GraniteVisionEmb. """ visual_prompt_prefix: ClassVar[str] = "<|user|>\n\nDescribe the image.\n" system_message: ClassVar[ str] = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions." query_prefix: ClassVar[str] = "Query: " query_start: ClassVar[str] = "<|user|>\n" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.factor = 14 self.min_size = 384 self.max_size = 384 * 2 self.suffix_len = 10 self.patch_size = 14 @property def query_augmentation_token(self) -> str: """ Return the query augmentation token. Query augmentation buffers are used as reasoning buffers during inference. """ return self.tokenizer.pad_token @staticmethod def smart_resize_helper( width: int, height: int, factor: int, min_size: int, max_size: int ) -> Tuple[int, int]: """ Returns the resized image dimensions such that: 1. The smaller dimension is set to 'min_size'. 2. The larger dimension is scaled proportionally to maintain aspect ratio. 3. If the larger dimension exceeds 'max_size', it is clipped to 'max_size', and the smaller dimension is adjusted accordingly to maintain aspect ratio. 4. Both dimensions are divisible by 'factor'. """ # Determine scale factor based on min_size if height < width: scale_factor = min_size / height else: scale_factor = min_size / width new_width = round(width * scale_factor) new_height = round(height * scale_factor) # If the longer dimension exceeds max_size, adjust accordingly if max(new_width, new_height) > max_size: clip_factor = max_size / max(new_width, new_height) new_width = round(new_width * clip_factor) new_height = round(new_height * clip_factor) # Ensure dimensions are divisible by factor # new_width = round_by_factor(new_width, factor) # new_height = round_by_factor(new_height, factor) return new_width, new_height @staticmethod def pad_image_center(image: Image.Image, target_width: int, target_height: int, fill_color=(0, 0, 0)) -> Image.Image: """ Pads the given image to be centered within the target dimensions. :param image: PIL Image to be padded. :param target_width: The desired width after padding. :param target_height: The desired height after padding. :param fill_color: Background color (default is black). :return: Padded image with centered content. """ # Get original image size img_width, img_height = image.size # Compute padding values pad_left = (target_width - img_width) // 2 pad_top = (target_height - img_height) // 2 pad_right = target_width - img_width - pad_left pad_bottom = target_height - img_height - pad_top # Apply padding padded_image = ImageOps.expand(image, (pad_left, pad_top, pad_right, pad_bottom), fill_color).convert("RGB") return padded_image def smart_resize(self, image: Image.Image) -> Image.Image: """ Resize and convert the image to the required format. """ image_size = image.size resized_height, resized_width = self.smart_resize_helper( width=image_size[0], height=image_size[1], factor=self.factor, min_size=self.min_size, max_size=self.max_size ) return image.convert("RGB").resize((resized_width, resized_height)) def smart_resize_and_pad(self, image: Image.Image) -> Image.Image: """ Resize and pad the image to the required format. """ return self.resize_and_pad_centered_to_long_side( image=image, factor=self.factor, min_size=self.min_size, max_size=self.max_size, fill_color=0 ) def resize_and_pad_centered_to_long_side( self, image: Image.Image, factor: int, min_size: int, max_size: int, fill_color=0 ) -> Image.Image: """ Resizes and pads an image such that: - The long side is set to `max_size`. - The short side is scaled proportionally but not below `min_size`. - The image is centered within the final padded area. :param image: PIL Image :param factor: Factor to make dimensions divisible by :param min_size: Minimum allowed size for the short side :param max_size: Target size for the long side :param fill_color: Background padding color (default black) :return: Resized and padded image """ # Get original size width, height = image.size if min_size == -1 or max_size == -1: return image.convert("RGB") # Step 1: scale long side to max_size, keep aspect ratio if width > height: scale_factor = max_size / width target_width = max_size max_scale_factor = max(min_size / height, scale_factor) target_height = round(height * max_scale_factor) else: scale_factor = max_size / height target_height = max_size max_scale_factor = max(min_size / width, scale_factor) target_width = round(width * max_scale_factor) # Resize the image resized_image = image.resize((target_width, target_height), Image.LANCZOS) final_image = resized_image.convert("RGB") return final_image def resize_and_pad_centered(self, image: Image.Image, factor: int, min_size: int, max_size: int, fill_color=0 ) -> Image.Image: """ Resizes and pads an image such that: - The short side is set to `min_size`. - The long side is scaled proportionally but clipped to `max_size`. - The image is centered within the final padded area. :param image: PIL Image :param factor: Factor to make dimensions divisible by :param min_size: Minimum size for the short side :param max_size: Maximum allowed size for the long side :param fill_color: Background padding color (default black) :return: Resized and padded image """ # Get original size width, height = image.size if min_size == -1 or max_size == -1: return image.convert("RGB") # Determine scale factor based on the short side (min_size) if width < height: scale_factor = min_size / width target_width = min_size max_scale_factor = min(max_size / height, scale_factor) target_height = round(height * max_scale_factor) else: scale_factor = min_size / height target_height = min_size max_scale_factor = min(max_size / width, scale_factor) target_width = round(width * max_scale_factor) # Ensure the longer side does not exceed max_size # if max(target_width, target_height) > max_size: # clip_factor = max_size / max(target_width, target_height) # target_width = round(target_width * clip_factor) # target_height = round(target_height * clip_factor) # Ensure dimensions are divisible by factor # target_width = round_by_factor(target_width, factor) # target_height = round_by_factor(target_height, factor) # Resize the image resized_image = image.resize((target_width, target_height), Image.LANCZOS) # Determine final padded dimensions (aligned to short side) if width < height: final_width, final_height = min_size, max_size else: final_width, final_height = max_size, min_size # Compute padding to center the image pad_left = (final_width - target_width) // 2 pad_top = (final_height - target_height) // 2 pad_right = final_width - target_width - pad_left pad_bottom = final_height - target_height - pad_top # Apply centered padding # final_image = ImageOps.expand(resized_image, (pad_left, pad_top, pad_right, pad_bottom), fill_color).convert("RGB") final_image = resized_image.convert("RGB") return final_image def format_data(self, question, image): return [ { "role": "system", "content": [{"type": "text", "text": self.system_message}], }, { "role": "user", "content": [ { "type": "image", "image": image, }, { "type": "text", "text": question, }, ], } ] def format_data_wo_role(self, question, image=None): return [ { "role": "user", "content": [ { "type": "image", "image": image, }, { "type": "text", "text": question, }, ], } ] def process_images( self, images: List[Image.Image], ) -> BatchFeature: """ Process images. """ # texts_doc = [self.apply_chat_template(self.format_data_wo_role(self.visual_prompt_prefix, img),tokenize=False ) for img in images] texts_doc = [self.visual_prompt_prefix for _ in images] images = [self.smart_resize_and_pad(image) for image in images] batch_doc = self( text=texts_doc, images=images, return_tensors="pt", padding="longest", ) return batch_doc def process_queries(self, queries, max_length=2048, suffix=None): if suffix is None: suffix = self.query_augmentation_token * self.suffix_len processed = [] for q in queries: q = self.query_start + self.query_prefix + q + ' ' + q q += suffix + "\n" processed.append(q) return self( text=processed, images=None, return_tensors="pt", padding="longest", truncation=True, max_length=max_length, ) def score( self, qs: List[torch.Tensor], ps: List[torch.Tensor], device: Optional[Union[str, torch.device]] = None, **kwargs, ) -> torch.Tensor: """ Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings. """ return self.score_multi_vector(qs, ps, device=device, **kwargs) def get_n_patches( self, image_size: Tuple[int, int], patch_size: int, ) -> Tuple[int, int]: n_patches_x = self.image_processor.size["width"] // patch_size n_patches_y = self.image_processor.size["height"] // patch_size return n_patches_x, n_patches_y def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor: return batch_images.input_ids == self.image_token_id @staticmethod def score_single_vector( qs: List[torch.Tensor], ps: List[torch.Tensor], device: Optional[Union[str, torch.device]] = None, ) -> torch.Tensor: """ Compute the dot product score for the given single-vector query and passage embeddings. """ if len(qs) == 0: raise ValueError("No queries provided") if len(ps) == 0: raise ValueError("No passages provided") qs_stacked = torch.stack(qs).to(device) ps_stacked = torch.stack(ps).to(device) scores = torch.einsum("bd,cd->bc", qs_stacked, ps_stacked) assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}" scores = scores.to(torch.float32) return scores @staticmethod def score_multi_vector( qs: Union[torch.Tensor, List[torch.Tensor]], ps: Union[torch.Tensor, List[torch.Tensor]], batch_size: int = 128, device: Optional[Union[str, torch.device]] = None, ) -> torch.Tensor: """ Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector query embeddings (`qs`) and passage embeddings (`ps`). For us, a passage is the image of a document page. Because the embedding tensors are multi-vector and can thus have different shapes, they should be fed as: (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim) (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually obtained by padding the list of tensors. Args: qs (`Union[torch.Tensor, List[torch.Tensor]`): Query embeddings. ps (`Union[torch.Tensor, List[torch.Tensor]`): Passage embeddings. batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores. device (`Union[str, torch.device]`, *optional*): Device to use for computation. If not provided, uses `get_torch_device("auto")`. Returns: `torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score tensor is saved on the "cpu" device. """ if len(qs) == 0: raise ValueError("No queries provided") if len(ps) == 0: raise ValueError("No passages provided") scores_list: List[torch.Tensor] = [] for i in range(0, len(qs), batch_size): scores_batch = [] qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i: i + batch_size], batch_first=True, padding_value=0).to( device ) for j in range(0, len(ps), batch_size): ps_batch = torch.nn.utils.rnn.pad_sequence( ps[j: j + batch_size], batch_first=True, padding_value=0 ).to(device) scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2)) scores_batch = torch.cat(scores_batch, dim=1).cpu() scores_list.append(scores_batch) scores = torch.cat(scores_list, dim=0) assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}" scores = scores.to(torch.float32) return scores