import base64
import json
import math
import os
from io import BytesIO
from typing import Any, Dict, List, Literal, Optional, Union
from urllib.parse import urlparse

import requests
import torch
from PIL import Image
from torch import nn
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration


class Transformer(nn.Module):
    """Sentence-transformers module wrapping a Qwen2-VL checkpoint (by default
    llamaindex/vdr-2b-multi-v1) that embeds text queries and document images
    into a shared vector space."""

    save_in_root: bool = True

    def __init__(
        self,
        model_name_or_path: str = 'llamaindex/vdr-2b-multi-v1',
        processor_name_or_path: Optional[str] = None,
        max_pixels: int = 768 * 28 * 28,
        min_pixels: int = 1 * 28 * 28,
        dimension: int = 2048,
        max_seq_length: Optional[int] = None,
        model_args: Optional[Dict[str, Any]] = None,
        processor_args: Optional[Dict[str, Any]] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,  # accepted for interface compatibility; unused
        config_args: Optional[Dict[str, Any]] = None,  # accepted for interface compatibility; unused
        cache_dir: Optional[str] = None,
        backend: Literal['torch', 'onnx', 'openvino'] = 'torch',
        **kwargs,
    ) -> None:
        super().__init__()

        if backend != 'torch':
            raise ValueError(
                f"Backend '{backend}' is not supported, please use 'torch' instead"
            )

        self.dimension = dimension
        self.max_pixels = max_pixels
        self.min_pixels = min_pixels
        self.max_seq_length = max_seq_length

        # Copy the caller-supplied dicts so they are not mutated in place.
        model_kwargs = dict(model_args or {})
        model_kwargs.update(kwargs)

        processor_kwargs = dict(processor_args or {})
        processor_kwargs.update({
            'min_pixels': min_pixels,
            'max_pixels': max_pixels,
            'cache_dir': cache_dir,
        })

        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_name_or_path,
            cache_dir=cache_dir,
            **model_kwargs,
        ).eval()

        self.processor = AutoProcessor.from_pretrained(
            processor_name_or_path or model_name_or_path,
            **processor_kwargs,
        )

        # Pad on the left so the final token position always holds real content;
        # forward() reads the embedding from the last hidden state.
        self.model.padding_side = "left"
        self.processor.tokenizer.padding_side = "left"

        self.document_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is shown in this image?<|im_end|>\n<|endoftext|>"
        self.query_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Query: %s<|im_end|>\n<|endoftext|>"

        if self.max_seq_length is None:
            if (
                hasattr(self.model, 'config')
                and hasattr(self.model.config, 'max_position_embeddings')
                and hasattr(self.processor.tokenizer, 'model_max_length')
            ):
                self.max_seq_length = min(
                    self.model.config.max_position_embeddings,
                    self.processor.tokenizer.model_max_length,
                )

    def _smart_resize(self, height: int, width: int) -> tuple[int, int]:
        """Rescale (height, width) so both sides are multiples of 28 and the
        total pixel count lands within [min_pixels, max_pixels].

        Returns (width, height), matching PIL's resize() argument order.
        """
        h_bar = max(28, self._round_by_factor(height, 28))
        w_bar = max(28, self._round_by_factor(width, 28))
        if h_bar * w_bar > self.max_pixels:
            # Over budget: shrink both sides by the same factor, rounding down.
            beta = math.sqrt((height * width) / self.max_pixels)
            h_bar = self._floor_by_factor(height / beta, 28)
            w_bar = self._floor_by_factor(width / beta, 28)
        elif h_bar * w_bar < self.min_pixels:
            # Under budget: grow both sides by the same factor, rounding up.
            beta = math.sqrt(self.min_pixels / (height * width))
            h_bar = self._ceil_by_factor(height * beta, 28)
            w_bar = self._ceil_by_factor(width * beta, 28)
        return w_bar, h_bar
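
    # Worked example (illustrative numbers, default max_pixels = 768 * 28 * 28
    # = 602,112): a 1000 x 800 (height x width) image first rounds to
    # 1008 x 812 = 818,496 pixels, over budget, so beta = sqrt(800,000 /
    # 602,112) ~= 1.153 and the sides floor to 840 x 672 = 564,480 pixels,
    # returned as (672, 840) in PIL's (width, height) order.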

    @staticmethod
    def _round_by_factor(number: float, factor: int) -> int:
        return round(number / factor) * factor

    @staticmethod
    def _ceil_by_factor(number: float, factor: int) -> int:
        return math.ceil(number / factor) * factor

    @staticmethod
    def _floor_by_factor(number: float, factor: int) -> int:
        return math.floor(number / factor) * factor

    def _resize_image(self, image: Image.Image) -> Image.Image:
        new_size = self._smart_resize(image.height, image.width)
        return image.resize(new_size)

    @staticmethod
    def _decode_data_image(data_image_str: str) -> Image.Image:
        # Split off the "data:image/...;base64," header and decode the payload.
        _header, data = data_image_str.split(',', 1)
        image_data = base64.b64decode(data)
        return Image.open(BytesIO(image_data))

    @staticmethod
    def _is_valid_url(url: str) -> bool:
        try:
            result = urlparse(url)
            return all([result.scheme in ('http', 'https'), result.netloc])
        except Exception:
            return False

    @staticmethod
    def _is_safe_path(path: str) -> bool:
        """Return True if the normalized, absolute path points at an existing file."""
        try:
            abs_path = os.path.abspath(os.path.normpath(path))
            return os.path.isfile(abs_path)
        except Exception:
            return False

    @staticmethod
    def _load_image_from_url(url: str) -> Image.Image:
        try:
            response = requests.get(
                url,
                stream=True,
                timeout=10,
                headers={'User-Agent': 'Mozilla/5.0'},
            )
            response.raise_for_status()

            # Reject responses that do not declare an image content type.
            content_type = response.headers.get('content-type', '')
            if not content_type.startswith('image/'):
                raise ValueError(f"Invalid content type: {content_type}")

            # Stream the body with a hard 10 MB cap so an oversized response
            # cannot exhaust memory.
            content = BytesIO()
            size = 0
            max_size = 10 * 1024 * 1024

            for chunk in response.iter_content(chunk_size=8192):
                size += len(chunk)
                if size > max_size:
                    raise ValueError("File too large")
                content.write(chunk)

            content.seek(0)
            return Image.open(content)
        except Exception as e:
            raise ValueError(f"Failed to load image from URL: {e}") from e

    @staticmethod
    def _load_image_from_path(image_path: str) -> Image.Image:
        try:
            abs_path = os.path.abspath(os.path.normpath(image_path))

            # Apply the same 10 MB size cap used for remote images.
            file_size = os.path.getsize(abs_path)
            max_size = 10 * 1024 * 1024
            if file_size > max_size:
                raise ValueError("File too large")

            with Image.open(abs_path) as img:
                # Copy so the image stays usable after the file handle closes.
                return img.copy()
        except Exception as e:
            raise ValueError(f"Failed to load image from path: {e}") from e

    @staticmethod
    def _load_image_from_bytes(image_bytes: bytes) -> Image.Image:
        try:
            if len(image_bytes) > 10 * 1024 * 1024:
                raise ValueError("Image data too large")
            return Image.open(BytesIO(image_bytes))
        except Exception as e:
            raise ValueError(f"Failed to load image from bytes: {e}") from e

    def _process_input(self, texts: List[Union[str, Image.Image, bytes]]) -> tuple[List[str], List[Image.Image]]:
        """Route each input to the document (image) or query (text) prompt.

        Strings that resolve to a loadable URL or local file become document
        images; any other string becomes a text query. The processor expects
        one image per prompt, so text queries are paired with a small dummy
        image.
        """
        processed_texts = []
        processed_images = []
        dummy_image = Image.new('RGB', (56, 56))

        for sample in texts:
            if isinstance(sample, str):
                if self._is_valid_url(sample):
                    try:
                        img = self._load_image_from_url(sample)
                        processed_texts.append(self.document_prompt)
                        processed_images.append(self._resize_image(img))
                    except Exception:
                        # The URL did not yield an image; fall back to
                        # treating the string as a text query.
                        processed_texts.append(self.query_prompt % sample)
                        processed_images.append(dummy_image)
                elif self._is_safe_path(sample):
                    try:
                        img = self._load_image_from_path(sample)
                        processed_texts.append(self.document_prompt)
                        processed_images.append(self._resize_image(img))
                    except Exception:
                        # Not a readable image file; treat as a text query.
                        processed_texts.append(self.query_prompt % sample)
                        processed_images.append(dummy_image)
                else:
                    # Plain text query.
                    processed_texts.append(self.query_prompt % sample)
                    processed_images.append(dummy_image)
            elif isinstance(sample, Image.Image):
                processed_texts.append(self.document_prompt)
                processed_images.append(self._resize_image(sample))
            elif isinstance(sample, bytes):
                try:
                    img = self._load_image_from_bytes(sample)
                    processed_texts.append(self.document_prompt)
                    processed_images.append(self._resize_image(img))
                except Exception:
                    # Undecodable bytes: keep batch alignment with a dummy image.
                    processed_texts.append(self.document_prompt)
                    processed_images.append(dummy_image)

        return processed_texts, processed_images

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Build a generation-style input dict (position ids, image tensors)
        # for a single forward pass.
        cache_position = torch.arange(0, features['input_ids'].shape[1])
        inputs = self.model.prepare_inputs_for_generation(
            **features, cache_position=cache_position, use_cache=False
        )

        # Move tensor inputs to the model's device; non-tensor entries
        # (e.g. use_cache) are dropped here.
        device = next(self.model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}

        with torch.no_grad():
            output = self.model(
                **inputs,
                return_dict=True,
                output_hidden_states=True,
            )

        # Last-token hidden state (valid because padding is on the left),
        # truncated to the requested dimension and L2-normalized.
        embeddings = output.hidden_states[-1][:, -1]
        features['sentence_embedding'] = torch.nn.functional.normalize(
            embeddings[:, :self.dimension], p=2, dim=-1
        )
        return features

    def tokenize(self, texts: List[Union[str, Image.Image, bytes]], padding: str = 'longest') -> Dict[str, torch.Tensor]:
        processed_texts, processed_images = self._process_input(texts)

        return self.processor(
            text=processed_texts,
            images=processed_images,
            videos=None,
            padding=padding,
            return_tensors='pt',
        )

    def save(self, output_path: str, safe_serialization: bool = True) -> None:
        """Save the model, tokenizer and processor to the given path."""
        self.model.save_pretrained(output_path, safe_serialization=safe_serialization)
        self.processor.save_pretrained(output_path)

        # Persist the module configuration so load() can rebuild this instance.
        config = {
            'model_name_or_path': output_path,
            'max_pixels': self.max_pixels,
            'min_pixels': self.min_pixels,
            'dimension': self.dimension,
            'max_seq_length': self.max_seq_length,
        }

        config_path = os.path.join(output_path, 'sentence_bert_config.json')
        with open(config_path, 'w') as f:
            json.dump(config, f)

    @staticmethod
    def load(input_path: str) -> 'Transformer':
        """Load a saved model from the given path."""
        config_path = os.path.join(input_path, 'sentence_bert_config.json')
        if os.path.exists(config_path):
            with open(config_path) as f:
                config = json.load(f)
        else:
            config = {'model_name_or_path': input_path}

        return Transformer(**config)
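

# Minimal usage sketch (illustration only; running it downloads the default
# checkpoint). tokenize() builds the multimodal batch and forward() attaches
# the normalized embeddings.
if __name__ == '__main__':
    model = Transformer()
    # A plain string that is neither a URL nor an existing file path is
    # embedded as a text query; a URL or image path would be embedded as a
    # document image instead.
    batch = model.tokenize(['What percentage of revenue came from ads?'])
    out = model.forward(batch)
    print(out['sentence_embedding'].shape)  # expected: torch.Size([1, 2048])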